[JAX] Gradient, Jacobian, Hessian 등 미분값 구하기

2023. 3. 22. 11:02·연구 Research/인공지능 Artificial Intelligent

본 글에서는 JAX로 미분값을 구하는 방법에 대해서 다룬다.

JAX에서는 미분값을 구하기 위해 grad, jacfwd, jacrev를 제공하기 때문에 몇 가지 예제를 통해서 익숙해지고자 한다.

일단 크게 scalar-valued function과 vector-valued function으로 나누고, 각 function이 한 개의 변수에만 의존하는지, 또는 두 개 이상의 변수에만 의존하는지를 따진다.

 

예제코드는 유튜브 튜토리얼 + JAX 매뉴얼을 참고하였다.

 

1. Scalar-valued function일 때

Gradient는 scalar-valued univariate function에 대한 기울기

Jacobian은 vector-valued or scalar-valued multivariate function 에 대한 기울기이다.

Hessian은 scalar-valued multivariate function을 두 번 미분한 것이다.

 

1-1) Gradient

x = 1.  # example input

f = lambda x: x**2 + x + 4  # simple 2nd order polynomial fn
visualize_fn(f, l=-1, r=2, n=100)

dfdx = grad(f)  # 2*x + 1, IF x=1.0, then 6.0
d2fdx = grad(dfdx)  # 2
d3fdx = grad(d2fdx)  # 0

 

 

1-2) Jacobian & Hessian

 

함수의 입력이 2개 이상일 때

from jax import jacfwd, jacrev

f = lambda x, y: x**2 + y**2  # simple paraboloid

# df/dx = 2x
# df/dy = 2y
# J = [df/dx, df/dy]

# d2f/dx = 2
# d2f/dy = 2
# d2f/dxdy = 0
# d2f/dydx = 0
# H = [[d2f/dx, d2f/dxdy], [d2f/dydx, d2f/dy]]

def hessian(f):
    return jit(jacfwd(jacrev(f, argnums=(0, 1)), argnums =(0, 1)))

print(f'Jacobian = {jacrev(f, argnums=(0, 1))(1., 1.)}')
print(f'Full Hessian = {hessian(f)(1., 1.)}')

 

위의 예제에서는 jacfwd, jacrev를 썼다.

 

jacfwd는 함수 설명을 보면 

Jacobian of fun evaluated column-by-column using forward-mode AD.

automatic differentiation을 forward-mode로 수행하여 자코비안을 구한다. 즉 jacfwd는 jacobian + forward에서 따온 함수명이다.

 

jacrev의 경우에는 reverse-mode automatic differentiation으로 계산한다.

Jacobian of fun evaluated row-by-row using reverse-mode AD.

 

그렇다면 이 두 개는 동일한 것이 아닌가, 생각할 수도 있다. 실제로 결과는 동일하게 나온다. 그러나 매뉴얼을 보면 jacfwd의 경우에는 행이 더 많은, 긴 직사각형 행렬에 더 유리하고, jacrev는 열이 더 많은, 넓은 직사각형 행렬에 더 유리하다. 또한 정사각형 행렬에 가까울 경우에는 jacfwd가 약간 더 좋다고 한다.

 

그리고 위의 함수는 scalar function이라서 굳이 jacfwd, jacrev를 쓰지 않고 grad로도 구할 수 있다.

 

f = lambda x, y: x**2 + y**2  # simple paraboloid
grad(f, argnums=(0,1))(1.,1.)
>>
(DeviceArray(2., dtype=float32, weak_type=True),
 DeviceArray(2., dtype=float32, weak_type=True))

 


 

2. Vector-valued function일 때

 

2-1) 변수 1개

 

vactor-valued function은 사실 상 스칼라함수가 여러 개 있는 것이다. 그러나 grad 함수는 스칼라 함수에 대해서만 적용할 수 있기 때문에 위의 jacfwd를 이용해야 한다.

앞서 본 Jacobian과 Hessian을 보면 jacrev와 jacfwd를 썼는데 

 

from jax import jacfwd, jacrev

f = lambda x: np.array( [ x**2, 4*x ] )  # simple paraboloid (google it...)

print(f'Jacobian = {jacrev(f)(1.)}')
Jacobian = [2. 4.]

결과는 다음과 같이 나온다. grad를 쓰면 오류가 출력되니 주의해야 한다.

 

2-2) 변수 2개 이상

from jax import jacfwd, jacrev

f = lambda x, y: np.array( [ x**3 + 2 * x**2 + 2 * y**3 + y**2, 4*x+3* y**2] )

# jacobian
# [ [ 3x^2 + 4x, 4 ], [ 6y^2 + 2y, 6y ] ]

# Hessian
# [ [6x + 4, 0 ],[ 0, 0 ]]
# [ [ 0, 0 ], [ 12y+2, 6 ] ]

def hessian(f):
    return jit(jacfwd(jacrev(f, argnums=(0, 1)), argnums =(0, 1)))

print(f'Jacobian = {jacrev(f, argnums=(0, 1))(1., 1.)}')
print(f'Full Hessian = {hessian(f)(1., 1.)}')

 

 

 

출력은 다음과 같다.

Jacobian = (DeviceArray([7., 4.], dtype=float32, weak_type=True), DeviceArray([8., 6.], dtype=float32, weak_type=True))
Full Hessian = ((DeviceArray([10.,  0.], dtype=float32, weak_type=True), DeviceArray([0., 0.], dtype=float32, weak_type=True)), (DeviceArray([0., 0.], dtype=float32, weak_type=True), DeviceArray([14.,  6.], dtype=float32, weak_type=True)))

 

Jacobian을 구하면 각 scalar function의 jacobian이 튜플 형태로 나오게 된다.

Hessian을 구하면 또 그 함수에 대해서 x,y에 대한 미분값을 구하므로 자세히 보면 

 

이렇게 각각 묶음이 나온다.

 

(DeviceArray([10., 0.], dtype=float32, weak_type=True), DeviceArray([0., 0.], dtype=float32, weak_type=True))

 

(DeviceArray([0., 0.], dtype=float32, weak_type=True), DeviceArray([14., 6.], dtype=float32, weak_type=True))

 

위의 경우에는 Jacobian 첫 번째에 y 항이 전혀 없어서 한 array는 0으로 나올 수 밖에 없다.

아래의 경우에도 Jacobian 두 번째에 x항이 전혀 없어서 첫 array는 0으로 나온다.

 

 

 

저작자표시 비영리 변경금지 (새창열림)

'연구 Research > 인공지능 Artificial Intelligent' 카테고리의 다른 글

[JAX] Gaussian process 파라미터에 따른 결과 visualization  (0) 2023.03.24
[JAX] 기본 Neural Networks 모델  (0) 2023.03.23
[JAX] Cholesky decomposition error 230318 기준  (0) 2023.03.18
[JAX] JAX vmap에 대한 설명  (0) 2023.03.17
[AI] Sampyl에 대한 간단한 설명  (0) 2023.03.05
'연구 Research/인공지능 Artificial Intelligent' 카테고리의 다른 글
  • [JAX] Gaussian process 파라미터에 따른 결과 visualization
  • [JAX] 기본 Neural Networks 모델
  • [JAX] Cholesky decomposition error 230318 기준
  • [JAX] JAX vmap에 대한 설명
보통의공대생
보통의공대생
수학,프로그래밍,기계항공우주 등 공부하는 기록들을 남깁니다.
  • 보통의공대생
    뛰는 놈 위에 나는 공대생
    보통의공대생
  • 전체
    오늘
    어제
    • 분류 전체보기 (460)
      • 공지 (1)
      • 영어 공부 English Study (40)
        • 텝스 TEPS (7)
        • 글 Article (21)
        • 영상 Video (10)
      • 연구 Research (99)
        • 최적화 Optimization (3)
        • 데이터과학 Data Science (7)
        • 인공지능 Artificial Intelligent (40)
        • 제어 Control (45)
      • 프로그래밍 Programming (103)
        • 매트랩 MATLAB (25)
        • 파이썬 Python (33)
        • 줄리아 Julia (2)
        • C++ (3)
        • 리눅스 우분투 Ubuntu (6)
      • 항공우주 Aeronautical engineeri.. (21)
        • 항법 Navigation (0)
        • 유도 Guidance (0)
      • 기계공학 Mechanical engineering (13)
        • 열역학 Thermodynamics (0)
        • 고체역학 Statics & Solid mechan.. (10)
        • 동역학 Dynamics (1)
        • 유체역학 Fluid Dynamics (0)
      • 수학 Mathematics (34)
        • 선형대수학 Linear Algebra (18)
        • 미분방정식 Differential Equation (3)
        • 확률및통계 Probability & Sta.. (2)
        • 미적분학 Calculus (1)
        • 복소해석학 Complex Analysis (5)
        • 실해석학 Real Analysis (0)
      • 수치해석 Numerical Analysis (21)
      • 확률 및 랜덤프로세스 Random process (2)
      • 추론 & 추정 이론 Estimation (3)
      • 기타 (26)
        • 설계 프로젝트 System Design (8)
        • 논문작성 Writing (55)
        • 세미나 Seminar (2)
        • 생산성 Productivity (3)
      • 유학 생활 Daily (7)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    teps
    텝스공부
    pytorch
    WOX
    Zotero
    생산성
    우분투
    Julia
    MATLAB
    Statics
    텝스
    ChatGPT
    고체역학
    LaTeX
    Linear algebra
    서버
    Dear abby
    논문작성
    옵시디언
    에러기록
    Python
    matplotlib
    논문작성법
    인공지능
    수치해석
    딥러닝
    IEEE
    JAX
    Numerical Analysis
    obsidian
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
보통의공대생
[JAX] Gradient, Jacobian, Hessian 등 미분값 구하기
상단으로

티스토리툴바