
[Pytorch] multi-output일 때 input gradient 구하기
·
프로그래밍 Programming/파이썬 Python
output에 대한 input의 gradient를 구하는 방법은 여러가지인데 input tensor에 대해 requires_grad는 반드시 True로 지정해야 한다. 그래야 gradient 추적이 가능하다. 그 다음에는 두 가지 방법으로 나뉜다. 1. backward()를 이용하고 (input tensor).grad를 통해 가져오는 방법 2. torch.autograd.grad를 이용해서 직접 계산하는 방법 1번이 전형적으로 gradient를 구하는 방법이지만 나는 얻어낸 gradient에 대하여 graph를 연장해야했다. 그러려면 1번 방법으로는 충분하지 않다. 그래서 2번을 해야하는데 multioutput일 때는 gradient에 대한 이해도가 좀 더 요구되는 편이다. 1번 방법. backward(..