Notice
Recent Posts
Recent Comments
Link
관리 메뉴

뛰는 놈 위에 나는 공대생

[JAX] Gradient, Jacobian, Hessian 등 미분값 구하기 본문

연구 Research/인공지능 Artificial Intelligent

[JAX] Gradient, Jacobian, Hessian 등 미분값 구하기

보통의공대생 2023. 3. 22. 11:02

본 글에서는 JAX로 미분값을 구하는 방법에 대해서 다룬다.

JAX에서는 미분값을 구하기 위해 grad, jacfwd, jacrev를 제공하기 때문에 몇 가지 예제를 통해서 익숙해지고자 한다.

일단 크게 scalar-valued function과 vector-valued function으로 나누고, 각 function이 한 개의 변수에만 의존하는지, 또는 두 개 이상의 변수에만 의존하는지를 따진다.

 

예제코드는 유튜브 튜토리얼 + JAX 매뉴얼을 참고하였다.

 

1. Scalar-valued function일 때

Gradient는 scalar-valued univariate function에 대한 기울기

Jacobian은 vector-valued or scalar-valued multivariate function 에 대한 기울기이다.

Hessian은 scalar-valued multivariate function을 두 번 미분한 것이다.

 

1-1) Gradient

x = 1.  # example input

f = lambda x: x**2 + x + 4  # simple 2nd order polynomial fn
visualize_fn(f, l=-1, r=2, n=100)

dfdx = grad(f)  # 2*x + 1, IF x=1.0, then 6.0
d2fdx = grad(dfdx)  # 2
d3fdx = grad(d2fdx)  # 0

 

 

1-2) Jacobian & Hessian

 

함수의 입력이 2개 이상일 때

from jax import jacfwd, jacrev

f = lambda x, y: x**2 + y**2  # simple paraboloid

# df/dx = 2x
# df/dy = 2y
# J = [df/dx, df/dy]

# d2f/dx = 2
# d2f/dy = 2
# d2f/dxdy = 0
# d2f/dydx = 0
# H = [[d2f/dx, d2f/dxdy], [d2f/dydx, d2f/dy]]

def hessian(f):
    return jit(jacfwd(jacrev(f, argnums=(0, 1)), argnums =(0, 1)))

print(f'Jacobian = {jacrev(f, argnums=(0, 1))(1., 1.)}')
print(f'Full Hessian = {hessian(f)(1., 1.)}')

 

위의 예제에서는 jacfwd, jacrev를 썼다.

 

jacfwd는 함수 설명을 보면 

Jacobian of fun evaluated column-by-column using forward-mode AD.

automatic differentiation을 forward-mode로 수행하여 자코비안을 구한다. 즉 jacfwd는 jacobian + forward에서 따온 함수명이다.

 

jacrev의 경우에는 reverse-mode automatic differentiation으로 계산한다.

Jacobian of fun evaluated row-by-row using reverse-mode AD.

 

그렇다면 이 두 개는 동일한 것이 아닌가, 생각할 수도 있다. 실제로 결과는 동일하게 나온다. 그러나 매뉴얼을 보면 jacfwd의 경우에는 행이 더 많은, 긴 직사각형 행렬에 더 유리하고, jacrev는 열이 더 많은, 넓은 직사각형 행렬에 더 유리하다. 또한 정사각형 행렬에 가까울 경우에는 jacfwd가 약간 더 좋다고 한다.

 

그리고 위의 함수는 scalar function이라서 굳이 jacfwd, jacrev를 쓰지 않고 grad로도 구할 수 있다.

 

f = lambda x, y: x**2 + y**2  # simple paraboloid
grad(f, argnums=(0,1))(1.,1.)
>>
(DeviceArray(2., dtype=float32, weak_type=True),
 DeviceArray(2., dtype=float32, weak_type=True))

 


 

2. Vector-valued function일 때

 

2-1) 변수 1개

 

vactor-valued function은 사실 상 스칼라함수가 여러 개 있는 것이다. 그러나 grad 함수는 스칼라 함수에 대해서만 적용할 수 있기 때문에 위의 jacfwd를 이용해야 한다.

앞서 본 Jacobian과 Hessian을 보면 jacrev와 jacfwd를 썼는데 

 

from jax import jacfwd, jacrev

f = lambda x: np.array( [ x**2, 4*x ] )  # simple paraboloid (google it...)

print(f'Jacobian = {jacrev(f)(1.)}')
Jacobian = [2. 4.]

결과는 다음과 같이 나온다. grad를 쓰면 오류가 출력되니 주의해야 한다.

 

2-2) 변수 2개 이상

from jax import jacfwd, jacrev

f = lambda x, y: np.array( [ x**3 + 2 * x**2 + 2 * y**3 + y**2, 4*x+3* y**2] )

# jacobian
# [ [ 3x^2 + 4x, 4 ], [ 6y^2 + 2y, 6y ] ]

# Hessian
# [ [6x + 4, 0 ],[ 0, 0 ]]
# [ [ 0, 0 ], [ 12y+2, 6 ] ]

def hessian(f):
    return jit(jacfwd(jacrev(f, argnums=(0, 1)), argnums =(0, 1)))

print(f'Jacobian = {jacrev(f, argnums=(0, 1))(1., 1.)}')
print(f'Full Hessian = {hessian(f)(1., 1.)}')

 

 

 

출력은 다음과 같다.

Jacobian = (DeviceArray([7., 4.], dtype=float32, weak_type=True), DeviceArray([8., 6.], dtype=float32, weak_type=True))
Full Hessian = ((DeviceArray([10.,  0.], dtype=float32, weak_type=True), DeviceArray([0., 0.], dtype=float32, weak_type=True)), (DeviceArray([0., 0.], dtype=float32, weak_type=True), DeviceArray([14.,  6.], dtype=float32, weak_type=True)))

 

Jacobian을 구하면 각 scalar function의 jacobian이 튜플 형태로 나오게 된다.

Hessian을 구하면 또 그 함수에 대해서 x,y에 대한 미분값을 구하므로 자세히 보면 

 

이렇게 각각 묶음이 나온다.

 

(DeviceArray([10., 0.], dtype=float32, weak_type=True), DeviceArray([0., 0.], dtype=float32, weak_type=True))

 

(DeviceArray([0., 0.], dtype=float32, weak_type=True), DeviceArray([14., 6.], dtype=float32, weak_type=True))

 

위의 경우에는 Jacobian 첫 번째에 y 항이 전혀 없어서 한 array는 0으로 나올 수 밖에 없다.

아래의 경우에도 Jacobian 두 번째에 x항이 전혀 없어서 첫 array는 0으로 나온다.

 

 

 

Comments