请问MinkowskiEngine中怎么实现PyTorch: view( ), torch.bmm和torch.nn.Parameter()?
我最近在用 MinkowskiEngine 给 ResNet 添加注意力模块,其中 NonLocal 注意力模块的 PyTorch 实现如下:
class NonLocalModule(nn.Module):
    """Non-local (self-attention) block over a batch of 1-D feature sequences.

    The input is projected by three 1x1 convolutions into a reduced channel
    space of size ``C // latent``. Two projections form an ``n x n`` attention
    matrix per sample, the third is re-weighted by that matrix, and a final
    1x1 convolution restores ``C`` channels. The result is added to the input
    scaled by the learnable scalar ``gamma``, which starts at zero, so the
    block returns its input unchanged right after construction.

    Args:
        C: Number of input (and output) channels.
        latent: Channel reduction factor; the attention runs on
            ``C // latent`` channels.

    Raises:
        ValueError: If ``C // latent`` is smaller than 1, which would create
            layers with no channels.
    """

    def __init__(self, C, latent=8):
        super(NonLocalModule, self).__init__()
        hidden = C // latent
        if hidden < 1:
            raise ValueError(
                "C // latent must be at least 1, got C=%r, latent=%r" % (C, latent)
            )
        self.inputChannel = C
        self.latentChannel = hidden
        # The batch-norm layers stay registered both as direct attributes and
        # inside the Sequential containers so existing state_dict keys
        # ("bn1.*" as well as "cov1.1.*") remain valid.
        self.bn1 = nn.BatchNorm1d(hidden)
        self.bn2 = nn.BatchNorm1d(hidden)
        self.bn3 = nn.BatchNorm1d(hidden)
        self.bn4 = nn.BatchNorm1d(C)
        self.cov1 = nn.Sequential(
            nn.Conv1d(in_channels=C, out_channels=hidden, kernel_size=1, bias=False),
            self.bn1,
            nn.ReLU(),
        )
        self.cov2 = nn.Sequential(
            nn.Conv1d(in_channels=C, out_channels=hidden, kernel_size=1, bias=False),
            self.bn2,
            nn.ReLU(),
        )
        self.cov3 = nn.Sequential(
            nn.Conv1d(in_channels=C, out_channels=hidden, kernel_size=1, bias=False),
            self.bn3,
            nn.ReLU(),
        )
        self.out_conv = nn.Sequential(
            nn.Conv1d(in_channels=hidden, out_channels=C, kernel_size=1, bias=False),
            self.bn4,
            nn.ReLU(),
        )
        # Zero-initialised residual gate: the attention branch contributes
        # nothing until training moves gamma away from zero.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply the attention block.

        Args:
            x: Tensor of shape ``(b, c, n)``.

        Returns:
            Tensor of shape ``(b, c, n)``: ``gamma * attention_branch(x) + x``.
        """
        b, c, n = x.shape  # also rejects inputs that are not 3-D
        query = self.cov1(x).transpose(1, 2)  # (b, n, c // latent)
        key = self.cov2(x)  # (b, c // latent, n)
        value = self.cov3(x)  # (b, c // latent, n)
        # Row i holds the softmax-normalised similarity of position i to
        # every position of the same sample.
        attention_matrix = self.softmax(torch.bmm(query, key))  # (b, n, n)
        attended = torch.bmm(value, attention_matrix.transpose(1, 2))  # (b, c // latent, n)
        out = self.out_conv(attended)  # (b, c, n)
        return self.gamma * out + x
nn.BatchNorm1d、nn.Conv1d 和 nn.ReLU 在 MinkowskiEngine 中都有对应的实现。但我查阅了 MinkowskiEngine 官方文档,没有找到 view()、torch.bmm 和 torch.nn.Parameter() 的对应写法。请问上述代码应该如何改写成 MinkowskiEngine 版本?
回答:
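大概思路:MinkowskiEngine 的层本身就是 torch.nn.Module,所以 nn.Parameter 可以照常使用;view、bmm 这类操作不需要专门的 MinkowskiEngine 版本,直接作用在稀疏张量的特征矩阵(普通的 torch.Tensor)上,最后再把结果重新包装成 ME.SparseTensor 即可。示例代码如下: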
import torch
import torch.nn as nn
import MinkowskiEngine as ME


class NonLocalModule(ME.MinkowskiNetwork):
    """Non-local (self-attention) block operating on an ``ME.SparseTensor``.

    NOTE(review): written against the MinkowskiEngine 0.5 naming
    (``coordinate_map_key`` / ``coordinate_manager``); older releases spell
    these keyword arguments differently -- confirm against the installed
    version.

    Sparse features are a 2-D ``(N, C)`` matrix holding the points of all
    samples, so there is no ``(b, c, n)`` layout to ``view`` into. Attention
    is instead computed separately for each sample, using the batch index
    stored in the first coordinate column. ``nn.Parameter`` needs no special
    handling because the module is an ordinary ``torch.nn.Module``.

    Args:
        C: Number of input (and output) feature channels.
        latent: Channel reduction factor; attention runs on ``C // latent``
            channels.
        D: Spatial dimension of the sparse tensor, forwarded to
            ``ME.MinkowskiNetwork`` and the convolutions.

    Raises:
        ValueError: If ``C // latent`` is smaller than 1.
    """

    def __init__(self, C, latent=8, D=3):
        super(NonLocalModule, self).__init__(D)
        hidden = C // latent
        if hidden < 1:
            raise ValueError(
                "C // latent must be at least 1, got C=%r, latent=%r" % (C, latent)
            )
        self.inputChannel = C
        self.latentChannel = hidden
        self.cov1 = self._pointwise(C, hidden, D)
        self.cov2 = self._pointwise(C, hidden, D)
        self.cov3 = self._pointwise(C, hidden, D)
        self.out_conv = self._pointwise(hidden, C, D)
        # Zero-initialised residual gate, registered like any other parameter.
        self.gamma = nn.Parameter(torch.zeros(1))

    @staticmethod
    def _pointwise(in_channels, out_channels, D):
        """Build a kernel-size-1 convolution + batch norm + ReLU stack."""
        return nn.Sequential(
            ME.MinkowskiConvolution(
                in_channels, out_channels, kernel_size=1, dimension=D
            ),
            ME.MinkowskiBatchNorm(out_channels),
            ME.MinkowskiReLU(),
        )

    def forward(self, x):
        """Apply per-sample self-attention.

        Args:
            x: An ``ME.SparseTensor`` with ``C`` feature channels.

        Returns:
            An ``ME.SparseTensor`` on the same coordinates as ``x`` whose
            features are ``gamma * attention_branch(x) + x.F``.
        """
        # The layers take the sparse tensor itself; ``.F`` exposes the
        # resulting (N, channels) feature matrix as a plain torch.Tensor.
        query = self.cov1(x).F
        key = self.cov2(x).F
        value = self.cov3(x).F

        # First coordinate column is the batch index of each point.
        batch_index = x.C[:, 0].to(query.device)
        attended = torch.zeros_like(value)
        for sample in torch.unique(batch_index):
            rows = torch.nonzero(batch_index == sample, as_tuple=True)[0]
            # Dense (n_i, n_i) attention for this sample; memory grows
            # quadratically with the number of points in the sample.
            scores = torch.mm(query[rows], key[rows].t())
            weights = torch.softmax(scores, dim=-1)
            attended[rows] = torch.mm(weights, value[rows])

        # Re-wrap the features on the input's coordinates so the sparse
        # output convolution can consume them.
        attended = ME.SparseTensor(
            attended,
            coordinate_map_key=x.coordinate_map_key,
            coordinate_manager=x.coordinate_manager,
        )
        out = self.out_conv(attended)
        return ME.SparseTensor(
            self.gamma * out.F + x.F,
            coordinate_map_key=x.coordinate_map_key,
            coordinate_manager=x.coordinate_manager,
        )
以上是 请问MinkowskiEngine中怎么实现PyTorch: view( ), torch.bmm和torch.nn.Parameter()? 的全部内容, 来源链接: utcz.com/p/938975.html