프로그래밍 Programming/파이썬 Python
[PyTorch] gradient descent로 변수를 직접 update할 때 주의할 점
보통의공대생
2022. 7. 14. 16:05
코드 상에서 특정 변수를 따로 gradient descent 방법으로 업데이트해야할 일이 있는데 이상하게 에러가 났다.
그래서 쉬운 예제를 통해서 이해를 해보고자 했다.
a = torch.linspace(0., 2. * math.pi, steps=25, requires_grad=True)
b = torch.sin(a)
c = 2 * b
d = c + 1
out = d.sum()
out.backward(retain_graph=True)
gradient = a.grad.clone().detach()
a -= 0.001 * gradient
print(a.requires_grad)
이렇게 코드를 짜면
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
다음과 같은 error가 return된다.
a는 여러 개의 하위 값들을 가지고 있어서 gradient에 learning rate를 곱해서 바로 업데이트 할 수 없다.
대신
a.data -= 0.001 * gradient
으로 업데이트한다.
그리고 굳이 gradient를 따로 두지 않고 a.grad를 사용해도 된다.