일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | |||||
3 | 4 | 5 | 6 | 7 | 8 | 9 |
10 | 11 | 12 | 13 | 14 | 15 | 16 |
17 | 18 | 19 | 20 | 21 | 22 | 23 |
24 | 25 | 26 | 27 | 28 | 29 | 30 |
- JAX
- pytorch
- Dear abby
- 우분투
- teps
- 고체역학
- LaTeX
- 딥러닝
- 에러기록
- Python
- matplotlib
- Statics
- WOX
- Zotero
- 텝스공부
- MATLAB
- 옵시디언
- IEEE
- 수치해석
- 인공지능
- Numerical Analysis
- Linear algebra
- 생산성
- Julia
- obsidian
- 수식삽입
- 텝스
- 논문작성
- 논문작성법
- ChatGPT
- Today
- Total
뛰는 놈 위에 나는 공대생
[JAX] Gradient, Jacobian, Hessian 등 미분값 구하기 본문
[JAX] Gradient, Jacobian, Hessian 등 미분값 구하기
보통의공대생 2023. 3. 22. 11:02본 글에서는 JAX로 미분값을 구하는 방법에 대해서 다룬다.
JAX에서는 미분값을 구하기 위해 grad, jacfwd, jacrev를 제공하기 때문에 몇 가지 예제를 통해서 익숙해지고자 한다.
일단 크게 scalar-valued function과 vector-valued function으로 나누고, 각 function이 한 개의 변수에만 의존하는지, 또는 두 개 이상의 변수에만 의존하는지를 따진다.
예제코드는 유튜브 튜토리얼 + JAX 매뉴얼을 참고하였다.
1. Scalar-valued function일 때
Gradient는 scalar-valued univariate function에 대한 기울기
Jacobian은 vector-valued or scalar-valued multivariate function 에 대한 기울기이다.
Hessian은 scalar-valued multivariate function을 두 번 미분한 것이다.
1-1) Gradient
x = 1. # example input
f = lambda x: x**2 + x + 4 # simple 2nd order polynomial fn
visualize_fn(f, l=-1, r=2, n=100)
dfdx = grad(f) # 2*x + 1, IF x=1.0, then 6.0
d2fdx = grad(dfdx) # 2
d3fdx = grad(d2fdx) # 0
1-2) Jacobian & Hessian
함수의 입력이 2개 이상일 때
from jax import jacfwd, jacrev
f = lambda x, y: x**2 + y**2 # simple paraboloid
# df/dx = 2x
# df/dy = 2y
# J = [df/dx, df/dy]
# d2f/dx = 2
# d2f/dy = 2
# d2f/dxdy = 0
# d2f/dydx = 0
# H = [[d2f/dx, d2f/dxdy], [d2f/dydx, d2f/dy]]
def hessian(f):
return jit(jacfwd(jacrev(f, argnums=(0, 1)), argnums =(0, 1)))
print(f'Jacobian = {jacrev(f, argnums=(0, 1))(1., 1.)}')
print(f'Full Hessian = {hessian(f)(1., 1.)}')
위의 예제에서는 jacfwd, jacrev를 썼다.
jacfwd는 함수 설명을 보면
Jacobian of fun evaluated column-by-column using forward-mode AD.
automatic differentiation을 forward-mode로 수행하여 자코비안을 구한다. 즉 jacfwd는 jacobian + forward에서 따온 함수명이다.
jacrev의 경우에는 reverse-mode automatic differentiation으로 계산한다.
Jacobian of fun evaluated row-by-row using reverse-mode AD.
그렇다면 이 두 개는 동일한 것이 아닌가, 생각할 수도 있다. 실제로 결과는 동일하게 나온다. 그러나 매뉴얼을 보면 jacfwd의 경우에는 행이 더 많은, 긴 직사각형 행렬에 더 유리하고, jacrev는 열이 더 많은, 넓은 직사각형 행렬에 더 유리하다. 또한 정사각형 행렬에 가까울 경우에는 jacfwd가 약간 더 좋다고 한다.
그리고 위의 함수는 scalar function이라서 굳이 jacfwd, jacrev를 쓰지 않고 grad로도 구할 수 있다.
f = lambda x, y: x**2 + y**2 # simple paraboloid
grad(f, argnums=(0,1))(1.,1.)
>>
(DeviceArray(2., dtype=float32, weak_type=True),
DeviceArray(2., dtype=float32, weak_type=True))
2. Vector-valued function일 때
2-1) 변수 1개
vactor-valued function은 사실 상 스칼라함수가 여러 개 있는 것이다. 그러나 grad 함수는 스칼라 함수에 대해서만 적용할 수 있기 때문에 위의 jacfwd를 이용해야 한다.
앞서 본 Jacobian과 Hessian을 보면 jacrev와 jacfwd를 썼는데
from jax import jacfwd, jacrev
f = lambda x: np.array( [ x**2, 4*x ] ) # simple paraboloid (google it...)
print(f'Jacobian = {jacrev(f)(1.)}')
Jacobian = [2. 4.]
결과는 다음과 같이 나온다. grad를 쓰면 오류가 출력되니 주의해야 한다.
2-2) 변수 2개 이상
from jax import jacfwd, jacrev
f = lambda x, y: np.array( [ x**3 + 2 * x**2 + 2 * y**3 + y**2, 4*x+3* y**2] )
# jacobian
# [ [ 3x^2 + 4x, 4 ], [ 6y^2 + 2y, 6y ] ]
# Hessian
# [ [6x + 4, 0 ],[ 0, 0 ]]
# [ [ 0, 0 ], [ 12y+2, 6 ] ]
def hessian(f):
return jit(jacfwd(jacrev(f, argnums=(0, 1)), argnums =(0, 1)))
print(f'Jacobian = {jacrev(f, argnums=(0, 1))(1., 1.)}')
print(f'Full Hessian = {hessian(f)(1., 1.)}')
출력은 다음과 같다.
Jacobian = (DeviceArray([7., 4.], dtype=float32, weak_type=True), DeviceArray([8., 6.], dtype=float32, weak_type=True))
Full Hessian = ((DeviceArray([10., 0.], dtype=float32, weak_type=True), DeviceArray([0., 0.], dtype=float32, weak_type=True)), (DeviceArray([0., 0.], dtype=float32, weak_type=True), DeviceArray([14., 6.], dtype=float32, weak_type=True)))
Jacobian을 구하면 각 scalar function의 jacobian이 튜플 형태로 나오게 된다.
Hessian을 구하면 또 그 함수에 대해서 x,y에 대한 미분값을 구하므로 자세히 보면
이렇게 각각 묶음이 나온다.
(DeviceArray([10., 0.], dtype=float32, weak_type=True), DeviceArray([0., 0.], dtype=float32, weak_type=True))
(DeviceArray([0., 0.], dtype=float32, weak_type=True), DeviceArray([14., 6.], dtype=float32, weak_type=True))
위의 경우에는 Jacobian 첫 번째에 y 항이 전혀 없어서 한 array는 0으로 나올 수 밖에 없다.
아래의 경우에도 Jacobian 두 번째에 x항이 전혀 없어서 첫 array는 0으로 나온다.
'연구 Research > 인공지능 Artificial Intelligent' 카테고리의 다른 글
[JAX] Gaussian process 파라미터에 따른 결과 visualization (0) | 2023.03.24 |
---|---|
[JAX] 기본 Neural Networks 모델 (0) | 2023.03.23 |
[JAX] Cholesky decomposition error 230318 기준 (0) | 2023.03.18 |
[JAX] JAX vmap에 대한 설명 (0) | 2023.03.17 |
[AI] Sampyl에 대한 간단한 설명 (0) | 2023.03.05 |