请问 MinkowskiEngine 中怎么实现 PyTorch 的 view()、torch.bmm 和 torch.nn.Parameter()?

请问 MinkowskiEngine 中怎么实现 PyTorch 的 view()、torch.bmm 和 torch.nn.Parameter()?

我最近在用 MinkowskiEngine 给 ResNet 添加注意力模块,其中 NonLocal 注意力模块的 PyTorch 实现代码如下:

class NonLocalModule(nn.Module):
    """Non-local (self-attention) block for dense ``(b, C, n)`` tensors.

    Three 1x1 ``Conv1d`` -> BatchNorm -> ReLU branches project the input to
    ``C // latent`` channels (query ``cov1``, key ``cov2``, value ``cov3``).
    An ``(n, n)`` softmax attention map built from query and key mixes the
    value branch across the ``n`` positions. ``out_conv`` projects the result
    back to ``C`` channels. The output is added to the input, scaled by the
    learnable scalar ``gamma``; ``gamma`` starts at 0, so a freshly built
    block is the identity.

    Args:
        C: number of input (and output) channels.
        latent: channel reduction factor for the query/key/value branches.

    Raises:
        ValueError: if ``latent`` is not positive or ``C // latent`` is 0
            (which would otherwise build zero-channel branches).
    """

    def __init__(self, C, latent=8):
        super().__init__()
        if latent <= 0:
            raise ValueError(f"latent must be positive, got {latent}")
        if C // latent < 1:
            raise ValueError(
                f"C // latent must be >= 1, got C={C}, latent={latent}"
            )

        self.inputChannel = C
        self.latentChannel = C // latent

        # The BN layers are kept as named attributes AND shared inside the
        # Sequential branches, so the state_dict exposes both "bnX.*" and
        # "covX.1.*" keys, exactly as before (checkpoint compatibility).
        self.bn1 = nn.BatchNorm1d(self.latentChannel)
        self.bn2 = nn.BatchNorm1d(self.latentChannel)
        self.bn3 = nn.BatchNorm1d(self.latentChannel)
        self.bn4 = nn.BatchNorm1d(C)

        self.cov1 = self._conv_bn_relu(C, self.bn1)  # query branch
        self.cov2 = self._conv_bn_relu(C, self.bn2)  # key branch
        self.cov3 = self._conv_bn_relu(C, self.bn3)  # value branch
        self.out_conv = self._conv_bn_relu(self.latentChannel, self.bn4)

        # Residual scale; zero-initialised so training starts from identity.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    @staticmethod
    def _conv_bn_relu(in_channels, bn):
        """Return ``Conv1d(k=1, bias=False) -> bn -> ReLU`` with ``bn.num_features`` outputs."""
        return nn.Sequential(
            nn.Conv1d(
                in_channels=in_channels,
                out_channels=bn.num_features,
                kernel_size=1,
                bias=False,
            ),
            bn,
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply non-local attention to ``x`` of shape ``(b, C, n)``.

        Returns:
            Tensor of shape ``(b, C, n)``: ``gamma * attention_output + x``.

        Raises:
            ValueError: if ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(f"expected a (b, C, n) tensor, got shape {tuple(x.shape)}")

        # A 1x1 Conv1d already yields (b, C/latent, n), so no view() is needed.
        query = self.cov1(x).permute(0, 2, 1)  # (b, n, C/latent)
        key = self.cov2(x)  # (b, C/latent, n)
        # attention_matrix[:, i, j] = softmax_j(<query_i, key_j>)
        attention_matrix = self.softmax(torch.bmm(query, key))  # (b, n, n)

        value = self.cov3(x)  # (b, C/latent, n)
        # Column i of the result = sum_j attention_matrix[:, i, j] * value[:, :, j]
        attention = torch.bmm(value, attention_matrix.permute(0, 2, 1))  # (b, C/latent, n)

        out = self.out_conv(attention)  # (b, C, n)
        return self.gamma * out + x

nn.BatchNorm1d、nn.Conv1d 和 nn.ReLU 在 MinkowskiEngine 中都有对应的实现版本。但是我查阅 MinkowskiEngine 官方文档,没有找到 PyTorch 的 view()、torch.bmm 和 torch.nn.Parameter() 在 MinkowskiEngine 中该如何实现。有人知道该怎么把上述代码改成 MinkowskiEngine 版本吗?


回答:

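大概思路:torch.nn.Parameter 可以直接使用,因为 MinkowskiEngine 的网络(ME.MinkowskiNetwork)本身就是 torch.nn.Module 的子类。view() 和 torch.bmm 则不能照搬:SparseTensor 把一个 batch 里所有样本的特征拼成一个 (N, C) 的二维矩阵,而且每个样本的点数可以不同,所以无法 reshape 成 (b, c, n)。需要先按样本拆分特征,对每个样本单独计算注意力矩阵(用循环代替 bmm),再把结果写回原来的行: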

import MinkowskiEngine as ME

class NonLocalModule(ME.MinkowskiNetwork):
    """Non-local (self-attention) block for a MinkowskiEngine ``SparseTensor``.

    This is the sparse counterpart of a dense ``(b, C, n)`` non-local block.
    A ``SparseTensor`` stores the features of *all* samples in the batch as a
    single ``(N, C)`` matrix (``x.F``). Each sample may have a different
    number of points, so the dense ``view(b, -1, n)`` + ``torch.bmm`` approach
    cannot be used. Instead:

    * Each 1x1 ``Conv1d`` -> BatchNorm -> ReLU branch becomes
      ``MinkowskiLinear`` -> ``MinkowskiBatchNorm`` -> ``MinkowskiReLU``,
      applied to all points at once.
    * Attention is computed separately per sample. The rows belonging to each
      sample come from ``x.decomposition_permutations``.
    * ``gamma`` is a plain ``torch.nn.Parameter``. ME networks are ordinary
      ``torch.nn.Module`` subclasses, so nothing ME-specific is needed.

    Written against the MinkowskiEngine 0.5 API (``features`` /
    ``coordinate_map_key`` / ``coordinate_manager``), not the 0.4 names
    (``feats`` / ``coords_key`` / ``coords_man``).
    NOTE(review): confirm against the installed ME version.

    Args:
        C: number of input (and output) feature channels.
        latent: channel reduction factor for the query/key/value branches.
        D: spatial dimension, forwarded to ``ME.MinkowskiNetwork`` (which
            requires it).

    Raises:
        ValueError: if ``latent`` is not positive or ``C // latent`` is 0.
    """

    def __init__(self, C, latent=8, D=3):
        super().__init__(D)
        if latent <= 0:
            raise ValueError(f"latent must be positive, got {latent}")
        if C // latent < 1:
            raise ValueError(
                f"C // latent must be >= 1, got C={C}, latent={latent}"
            )

        self.inputChannel = C
        self.latentChannel = C // latent

        self.cov1 = self._branch(C, self.latentChannel)  # query
        self.cov2 = self._branch(C, self.latentChannel)  # key
        self.cov3 = self._branch(C, self.latentChannel)  # value
        self.out_conv = self._branch(self.latentChannel, C)

        # Residual scale; zero-initialised so the block starts as identity.
        self.gamma = torch.nn.Parameter(torch.zeros(1))
        self.softmax = torch.nn.Softmax(dim=-1)

    @staticmethod
    def _branch(in_channels, out_channels):
        """Return the sparse equivalent of ``Conv1d(k=1, bias=False) -> BatchNorm1d -> ReLU``."""
        return torch.nn.Sequential(
            ME.MinkowskiLinear(in_channels, out_channels, bias=False),
            ME.MinkowskiBatchNorm(out_channels),
            ME.MinkowskiReLU(),
        )

    def forward(self, x):
        """Apply per-sample non-local attention to the sparse tensor ``x``.

        Args:
            x: ``ME.SparseTensor`` with ``C`` feature channels.

        Returns:
            ``ME.SparseTensor`` on the same coordinates as ``x`` with features
            ``gamma * attention_output + x.F``.
        """
        # Row-wise branches keep the row order of x.F.
        query = self.cov1(x).F  # (N, C/latent)
        key = self.cov2(x).F  # (N, C/latent)
        value = self.cov3(x).F  # (N, C/latent)

        # Replacement for torch.bmm: one attention matrix per sample, because
        # samples have different point counts. NOTE(review): memory is
        # O(n_i^2) per sample, as in the dense version.
        attended = torch.zeros_like(value)
        for rows in x.decomposition_permutations:
            q, k, v = query[rows], key[rows], value[rows]  # (n_i, C/latent)
            attention_matrix = self.softmax(q @ k.t())  # (n_i, n_i)
            attended[rows] = attention_matrix @ v  # (n_i, C/latent)

        attended = ME.SparseTensor(
            attended,
            coordinate_map_key=x.coordinate_map_key,
            coordinate_manager=x.coordinate_manager,
        )
        out = self.out_conv(attended)  # (N, C) features, same rows as x.F

        return ME.SparseTensor(
            self.gamma * out.F + x.F,
            coordinate_map_key=x.coordinate_map_key,
            coordinate_manager=x.coordinate_manager,
        )

以上是 请问MinkowskiEngine中怎么实现PyTorch: view( ), torch.bmm和torch.nn.Parameter()? 的全部内容, 来源链接: utcz.com/p/938975.html

回到顶部