Pytorch修改指定模块权重的方法,即 torch.Tensor.detach()和Tensor.requires_grad方法的用法

0、前言

在学习pytorch的计算图和自动求导机制时,我们要想在心中建立一个“计算过程的图像”,需要深入了解其中的每个细节,这次主要说一下tensor的requires_grad参数。
无论如何定义计算过程、如何定义计算图,要谨记我们的核心目的是为了计算某些tensor的梯度。在pytorch的计算图中,其实只有两种元素:数据(tensor)和运算,运算就是加减乘除、开方、幂指对、三角函数等可求导运算,而tensor可细分为两类:叶子节点(leaf node)和非叶子节点。使用backward()函数反向传播计算tensor的梯度时,并不计算所有tensor的梯度,而是只计算满足这几个条件的tensor的梯度:1.类型为叶子节点、2.requires_grad=True、3.依赖该tensor的所有tensor的requires_grad=True
叶子节点和tensor的requires_grad参数

一、detach()那么这个函数有什么作用?

假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改B网络的参数,但是不想修改A网络的参数,这个时候就可以使用detcah()方法

a = A(input)
a = a.detach()

b = B(a)
loss = criterion(b, target)
loss.backward()

以下代码就说明了反向传播到y就结束了,没有到达x,所以x的grad属性为None

import torch as t
x = t.ones(1, requires_grad=True)
x.requires_grad   #True
y = t.ones(1, requires_grad=True)
y.requires_grad   #True

x = x.detach()   #分离之后
x.requires_grad   #False

y = x+y      	  #tensor([2.])
y.requires_grad   #我还是True
y.retain_grad()   #y不是叶子张量,要加上这一行

z = t.pow(y, 2)
z.backward()    #反向传播

y.grad        #tensor([4.])
x.grad        #None

二、Tensor.requires_grad属性

既然谈到了修改模型的权重问题,那么还有一种情况是:
假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改A网络的参数,但是不想修改B网络的参数,这个时候又应该怎么办了?
这时可以使用Tensor.requires_grad属性,只需要将requires_grad修改为False即可

for param in B.parameters():
	param.requires_grad = False

a = A(input)
b = B(a)
loss = criterion(b, target)
loss.backward()
上一篇:Error: The HTTP image filter module requires the GD library – Nginx / CentOS 7.3


下一篇:MySql安装问题This application requires Visual Studio 2013 Redistributable. Please install