评论

收藏

[C++] Pytorch修改指定模块权重的方法,即 torch.Tensor.detach()和Tensor.requires_grad方法的用法

编程语言 编程语言 发布于:2021-07-11 17:25 | 阅读数:492 | 评论:0

  
0、前言
  在学习pytorch的计算图和自动求导机制时,我们要想在心中建立一个“计算过程的图像”,需要深入了解其中的每个细节,这次主要说一下tensor的requires_grad参数。
无论如何定义计算过程、如何定义计算图,要谨记我们的核心目的是为了计算某些tensor的梯度。在pytorch的计算图中,其实只有两种元素:数据(tensor)和运算,运算就是加减乘除、开方、幂指对、三角函数等可求导运算,而tensor可细分为两类:叶子节点(leaf node)和非叶子节点。使用backward()函数反向传播计算tensor的梯度时,并不计算所有tensor的梯度,而是只计算满足这几个条件的tensor的梯度:1.类型为叶子节点、2.requires_grad=True、3.依赖该tensor的所有tensor的requires_grad=True
叶子节点和tensor的requires_grad参数

一、detach()那么这个函数有什么作用?
  假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改B网络的参数,但是不想修改A网络的参数,这个时候就可以使用detcah()方法
a = A(input)
a = a.detach()
b = B(a)
loss = criterion(b, target)
loss.backward()
  以下代码就说明了反向传播到y就结束了,没有到达x,所以x的grad属性为None
import torch as t
x = t.ones(1, requires_grad=True)
x.requires_grad   #True
y = t.ones(1, requires_grad=True)
y.requires_grad   #True
x = x.detach()   #分离之后
x.requires_grad   #False
y = x+y    #tensor([2.])
y.requires_grad   #我还是True
y.retain_grad()   #y不是叶子张量,要加上这一行
z = t.pow(y, 2)
z.backward()    #反向传播
y.grad        #tensor([4.])
x.grad        #None

二、Tensor.requires_grad属性
  既然谈到了修改模型的权重问题,那么还有一种情况是:
假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改A网络的参数,但是不想修改B网络的参数,这个时候又应该怎么办了?
这时可以使用Tensor.requires_grad属性,只需要将requires_grad修改为False即可
for param in B.parameters():
param.requires_grad = False
a = A(input)
b = B(a)
loss = criterion(b, target)
loss.backward()

  
关注下面的标签,发现更多相似文章