对pytorch中的梯度更新方法详解

背景

使用pytorch时,有一个yolov3的bug,我认为涉及到学习率的调整。收集到tencent yolov3和mxnet开源的yolov3,两个优化器中的学习率设置不一样,而且使用GPU数目和batch的更新也不太一样。据此,我简单的了解了下pytorch的权重梯度的更新策略,看看能否一窥究竟。

对代码说明

共三个实验,分布写在代码中的(一)(二)(三)三个地方。运行实验时注释掉其他两个

实验及其结果

实验(三):

不使用zero_grad()时,grad累加在一起,官网是使用accumulate 来表述的,所以不太清楚是取的和还是均值(这两种最有可能)。

不使用zero_grad()时,是直接叠加add的方式累加的。

tensor([[[ 1., 1.],……torch.Size([2, 2, 2])

0 2 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *

tensor([[[ 2., 2.],…… torch.Size([2, 2, 2])

1 2 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *

tensor([[[ 3., 3.],…… torch.Size([2, 2, 2])

2 2 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *

实验(二):

单卡上不同的batchsize对梯度是怎么作用的。 mini-batch SGD中的batch是加快训练,同时保持一定的噪声。但设置不同的batchsize的权重的梯度是怎么计算的呢。

设置运行实验(二),可以看到结果如下:所以单卡batchsize计算梯度是取均值的

tensor([[[ 3., 3.],…… torch.Size([2, 2, 2])

实验(一):

多gpu情况下,梯度怎么合并在一起的。

在《training imagenet in 1 hours》中提到grad是allreduce的,是累加的形式。但是当设置g=2,实验一运行时,结果也是取均值的,类同于实验(二)

tensor([[[ 3., 3.],…… torch.Size([2, 2, 2])

实验代码

import torch

import torch.nn as nn

from torch.autograd import Variable

class model(nn.Module):

def __init__(self, w):

super(model, self).__init__()

self.w = w

def forward(self, xx):

b, c, _, _ = xx.shape

# extra = xx.device.index + 1 ## 实验(一)

y = xx.reshape(b, -1).mm(self.w.cuda(xx.device).reshape(-1, 2) * extra)

return y.reshape(len(xx), -1)

g = 1

x = Variable(torch.ones(2, 1, 2, 2))

# x[1] += 1 ## 实验(二)

w = Variable(torch.ones(2, 2, 2) * 2, requires_grad=True)

# optim = torch.optim.SGD({'params': x},

lr = 0.01

momentum = 0.9

M = model(w)

M = torch.nn.DataParallel(M, device_ids=range(g))

for i in range(3):

b = len(x)

z = M(x)

zz = z.sum(1)

l = (zz - Variable(torch.ones(b).cuda())).mean()

# zz.backward(Variable(torch.ones(b).cuda()))

l.backward()

print(w.grad, w.grad.shape)

# w.grad.zero_() ## 实验(三)

print(i, b, '* * ' * 20)

以上这篇对pytorch中的梯度更新方法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。

以上是 对pytorch中的梯度更新方法详解 的全部内容, 来源链接: utcz.com/z/332581.html

回到顶部