创建 autograd.Function
的子类,需要实现两个静态的方法 forward 和 backward。应用该 op 时,调用 apply 方法,不要直接调用 forward 方法。
forward 静态方法中第一个参数为 ctx,它可以理解 Function 对象本身,其方法 save_for_backward 用于在前向计算时,将反向计算用到的中间结果进行缓存。在 backward 静态方法中,也有 ctx 方法,我们可以通过它的 saved_tensors 来取出前向计算时缓存的中间结果。
前向计算方法 forward 的其他参数为该 op 需要传递进来的参数,例如:该 op 用作乘法运算,则其他的两个参数就是需要相乘的两个张量。后向计算方法 backward 的另一个参数为 grad_outputs,我们知道神经网络是通过链式求导的方法来计算参数梯度,这里的 grad_outputs 为上一个步输出的梯度。backward 方法最终要返回相应张量的梯度,例如:forward 函数中按顺序传入了 w,x 则在 backward 中计算完 w 和 x 梯度之后,以此返回即可。
示例代码:
import torch
from torchviz import make_dot
class CustomMul(torch.autograd.Function):
@staticmethod
def forward(ctx, w, x):
# print('CustomMul forward')
ctx.save_for_backward(w, x)
return w * x
@staticmethod
def backward(ctx, grad_outputs):
# print('CustomMul backward')
# print('CustomMul grad_outputs:', grad_outputs)
w = ctx.saved_tensors[0]
x = ctx.saved_tensors[1]
w_grad = grad_outputs * x
x_grad = grad_outputs * w
return w_grad, x_grad
class CustomAdd(torch.autograd.Function):
@staticmethod
def forward(ctx, x, b):
# print('CustomAdd forward')
return x + b
@staticmethod
def backward(ctx, grad_outputs):
# print('CustomAdd backward')
# print('CustomAdd grad_outputs:', grad_outputs)
b_grad = grad_outputs * 1
x_grad = grad_outputs * 1
return x_grad, b_grad
def mul(w, input):
return CustomMul.apply(w, input)
def add(b, input):
return CustomAdd.apply(b, input)
if __name__ == '__main__':
w = torch.tensor([[2.0, 3.0]], requires_grad=True)
b = torch.tensor([3.0], requires_grad=True)
x = torch.tensor([[4.0, 5.0]])
outputs = mul(w, x)
outputs = add(b, torch.sum(outputs))
dot = make_dot(outputs, show_attrs=True)
dot.render('temp', format='png', cleanup=False)
outputs.backward()
print(w.grad, x.grad, b.grad)
程序执行结果:
tensor([[4., 5.]], dtype=torch.float64) None tensor([1.])
计算图如下:
最后,可以通过 torch.autograd.gradcheck
检查下自定义的 CustomMul 工作是否正常:
print(torch.autograd.gradcheck(func=CustomMul.apply, inputs=(w, x)))
输出结果为:True