pytorch自定义二值化网络层方式

任务要求:

自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下:

import torch

from torch.autograd import Function

from torch.autograd import Variable

定义二值化函数

class BinarizedF(Function):

def forward(self, input):

self.save_for_backward(input)

a = torch.ones_like(input)

b = -torch.ones_like(input)

output = torch.where(input>=0,a,b)

return output

def backward(self, output_grad):

input, = self.saved_tensors

input_abs = torch.abs(input)

ones = torch.ones_like(input)

zeros = torch.zeros_like(input)

input_grad = torch.where(input_abs<=1,ones, zeros)

return input_grad

定义一个module

class BinarizedModule(nn.Module):

def __init__(self):

super(BinarizedModule, self).__init__()

self.BF = BinarizedF()

def forward(self,input):

print(input.shape)

output =self.BF(input)

return output

进行测试

a = Variable(torch.randn(4,480,640), requires_grad=True)

output = BinarizedModule()(a)

output.backward(torch.ones(a.size()))

print(a)

print(a.grad)

其中, 二值化函数部分也可以按照方式写,但是速度慢了0.05s

class BinarizedF(Function):

def forward(self, input):

self.save_for_backward(input)

output = torch.ones_like(input)

output[input<0] = -1

return output

def backward(self, output_grad):

input, = self.saved_tensors

input_grad = output_grad.clone()

input_abs = torch.abs(input)

input_grad[input_abs>1] = 0

return input_grad

以上这篇pytorch自定义二值化网络层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。

以上是 pytorch自定义二值化网络层方式 的全部内容, 来源链接: utcz.com/z/323412.html

回到顶部