任务要求:
自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下:
1
2
3
|
import torch from torch.autograd import Function from torch.autograd import Variable |
定义二值化函数
1
2
3
4
5
6
7
8
9
10
11
12
13
14
|
class BinarizedF(Function): def forward( self , input ): self .save_for_backward( input ) a = torch.ones_like( input ) b = - torch.ones_like( input ) output = torch.where( input > = 0 ,a,b) return output def backward( self , output_grad): input , = self .saved_tensors input_abs = torch. abs ( input ) ones = torch.ones_like( input ) zeros = torch.zeros_like( input ) input_grad = torch.where(input_abs< = 1 ,ones, zeros) return input_grad |
定义一个module
1
2
3
4
5
6
7
8
|
class BinarizedModule(nn.Module): def __init__( self ): super (BinarizedModule, self ).__init__() self .BF = BinarizedF() def forward( self , input ): print ( input .shape) output = self .BF( input ) return output |
进行测试
1
2
3
4
5
|
a = Variable(torch.randn( 4 , 480 , 640 ), requires_grad = True ) output = BinarizedModule()(a) output.backward(torch.ones(a.size())) print (a) print (a.grad) |
其中, 二值化函数部分也可以按照方式写,但是速度慢了0.05s
1
2
3
4
5
6
7
8
9
10
11
12
|
class BinarizedF(Function): def forward( self , input ): self .save_for_backward( input ) output = torch.ones_like( input ) output[ input < 0 ] = - 1 return output def backward( self , output_grad): input , = self .saved_tensors input_grad = output_grad.clone() input_abs = torch. abs ( input ) input_grad[input_abs> 1 ] = 0 return input_grad |
以上这篇pytorch自定义二值化网络层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/weixin_42696356/article/details/100899711