以channel Attention Block为例子
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
|
class CAB(nn.Module): def __init__( self , in_channels, out_channels): super (CAB, self ).__init__() self .global_pooling = nn.AdaptiveAvgPool2d(output_size = 1 ) self .conv1 = nn.Conv2d(in_channels, out_channels, kernel_size = 1 , stride = 1 , padding = 0 ) self .relu = nn.ReLU() self .conv2 = nn.Conv2d(out_channels, out_channels, kernel_size = 1 , stride = 1 , padding = 0 ) self .sigmod = nn.Sigmoid() def forward( self , x): x1, x2 = x # high, low x = torch.cat([x1,x2],dim = 1 ) x = self .global_pooling(x) x = self .conv1(x) x = self .relu(x) x = self .conv2(x) x = self .sigmod(x) x2 = x * x2 res = x2 + x1 return res |
以上这篇pytorch forward两个参数实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/weixin_41950276/article/details/89069659