pytorch中torch.no_grad()、requires_grad、eval()

python

requires_grad

requires_grad=True 要求计算梯度;
requires_grad=False 不要求计算梯度;

在pytorch中,tensor有一个 requires_grad参数,如果设置为True,则反向传播时,该tensor就会自动求导。 tensor的requires_grad的属性默认为False,若一个节点(叶子变量:自己创建的tensor)requires_grad被设置为True,那么 所有依赖它的节点requires_grad都为True (即使其他相依赖的tensor的requires_grad = False)

x = torch.randn(10, 5, requires_grad = True)

y = torch.randn(10, 5, requires_grad = False)

z = torch.randn(10, 5, requires_grad = False)

w = x + y + z

w.requires_grad

输出:

True

volatile

volatile是Variable的另一个重要的标识,它能够将所有依赖它的节点全部设为volatile=True,优先级比requires_grad=True高。

而volatile=True的节点不会求导,即使requires_grad=True,也不会进行反向传播,对于不需要反向传播的情景(inference,测试阶段推断阶段),该参数可以实现一定速度的提升,并节省一半的显存,因为其不需要保存梯度。

但是, 注意 volatile已经取消了,使用with torch.no_grad()来替代 。

torch.no_grad()

是一个上下文管理器,被该语句 wrap 起来的部分将不会track 梯度。

with torch.no_grad()或者@torch.no_grad()中的数据不需要计算梯度,也不会进行反向传播。

(torch.no_grad()是新版本pytorch中volatile的替代)

x = torch.randn(2, 3, requires_grad = True)

y = torch.randn(2, 3, requires_grad = False)

z = torch.randn(2, 3, requires_grad = False)

m=x+y+z

with torch.no_grad():

w = x + y + z

print(w)

print(m)

print(w.requires_grad)

print(w.grad_fn)

print(w.requires_grad)

输出:

tensor([[-2.7066, -0.7406,  0.5740],

[-0.7071, -1.6057, 1.9732]])

tensor([[-2.7066, -0.7406, 0.5740],

[-0.7071, -1.6057, 1.9732]], grad_fn=<AddBackward0>)

False

None

False

model.eval()与with torch.no_grad()

共同点:

在PyTorch中进行validation时,使用这两者均可切换到测试模式。

如用于通知dropout层和batchnorm层在train和val模式间切换。

在train模式下,dropout网络层会按照设定的参数p设置保留激活单元的概率(保留概率=p); batchnorm层会继续计算数据的mean和var等参数并更新。

在val模式下,dropout层会让所有的激活单元都通过,而batchnorm层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值。

不同点:

model.eval()会影响各层的gradient计算行为,即gradient计算和存储与training模式一样,只是不进行反传。

with torch.zero_grad()则停止autograd模块的工作,也就是停止gradient计算,以起到加速和节省显存的作用,从而节省了GPU算力和显存,但是并不会影响dropout和batchnorm层的行为。

也就是说,如果不在意显存大小和计算时间的话,仅使用model.eval()已足够得到正确的validation的结果;而with torch.zero_grad()则是更进一步加速和节省gpu空间(因为不用计算和存储gradient),从而可以更快计算,也可以跑更大的batch来测试。

参考

1.https://www.jianshu.com/p/1cea017f5d11

2.csdn博客

以上是 pytorch中torch.no_grad()、requires_grad、eval() 的全部内容, 来源链接: utcz.com/z/530723.html

回到顶部