[JAX] JAX vmap에 대한 설명

2023. 3. 17. 20:43·연구 Research/인공지능 Artificial Intelligent

아직 jax가 한글화가 많이 안 되어있어서 기본적인 기능은 내가 적어놓으려고 한다.

 

jax.vmap(fun, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None)

function을 argument axes에 대해서 mapping해주는 기능.

 

  • fun : mapping할 function
  • in_axes : function에 들어가는 input을 의미한다. 정수, None, Python container(tuple/list/dict 모두 가능)을 지원한다. 이는 모두 mapping할 input array 축을 의미한다.
    • 만약 fun의 argument가 array이면 in_axes에는 정수, None, 튜플(Integer, Nones로 구성된 튜플로 입력 array와 크기 동일)이 들어갈 수 있다. Integer를 in_axes로 넣으면 어떤 축으로 모든 arguments에 대해 map over할 것인지 결정한다.
    • None을 넣는다면 어떤 axis로도 map하지 않겠다는 의미이다. 튜플을 넣는다면 각 positional argument가 어떤 축에 map될 것인지를 결정한다. 여기에서 axis integer는 input array의 차원 수(ndim)만큼 [-ndim,ndim] 사이의 값을 넣어야 한다.
    • fun의 positional arguements가 container pytree 타입일 경우에는 in_axes의 해당 element가 곧 matching container가 된다고 한다. 이 부분은 잘 이해가 안되어서 아래에 전문을 써놓는다.
    • vmap의 argument 중에는 option으로 axis_size를 지정해줄 수 있는데 이렇게 axis_size를 명백하게 지정해놓거나, in_axes의 적어도 하나 argument는 none이 되면 안된다. 이는 입력으로 들어온 in_axes의 사이즈와 mapped arguments와 사이즈가 일치해야하기 때문이라고 한다.
If the positional arguments to fun are container (pytree) types, the corresponding element of in_axes can itself be a matching container, so that distinct array axes can be mapped for different container elements. in_axes must be a container tree prefix of the positional argument tuple passed to fun. See this link for more detail: https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees

 

  • out_axes : output에서 나타나는 mapped axis.
  • axis_name : Optional, a hashable Python object used to identify the mapped axis so that parallel collectives can be applied.
  • axis_size : Optional, an integer indicating the size of the axis to be mapped. If not provided, the mapped axis size is inferred from arguments.

솔직히 이 내용만 보면 이해하기가 좀 난해했다. 예제를 보면서 이해한다.

import jax.numpy as jnp
>>>
vv = lambda x, y: jnp.vdot(x, y)  #  ([a], [a]) -> []
mv = vmap(vv, (0, None), 0)      #  ([b,a], [a]) -> [b]      (b is the mapped axis)
mm = vmap(mv, (None, 1), 1)      #  ([b,a], [a,c]) -> [b,c]  (c is the mapped axis)

 

위의 예제를 보면 vv라는 함수는 벡터끼리 내적을 해주는 함수이다. 그런데 내가 입력 x를 벡터가 아닌 배치 사이즈만큼의 벡터를 넣고 싶다고 하자. 그러면 [a]크기의 1차원의 벡터가 [b,a]크기의 2차원 행렬이 된다. 그리고 입력 y는 [a] 크기의 1차원 벡터를 넣는다고 하자.

 

생각해서보면 계산은 [b,a] * [a] = [b]가 되어야 한다. 그러나 vv라는 함수에 바로 이 x,y를 넣으면 오류가 난다. 대신 for문을 배치사이즈만큼 반복하는 대신 vmap을 이용하여 빠르게 계산할 수 있다.

이 때 x의 0차원은 broadcasting을 통해 계산되어야 하는 부분이다.

그래서 mv를 보면 in_axes=(0,None)으로 되어있다. x에 해당하는 input은 0차원에 대해서 mapping하고, y에 해당하는 input은 아무 차원도 mapping하지 않겠다는 뜻이다. 그리고 out_axes=0으로 지정해줌으로써 mapping되는 0차원을 그대로 0차원에 넣었다.

 

또한 mm을 보면 x가 [b,a]이고 y가 [a,c]인 상황을 말한다. 여기서 vmap할 함수는 mv라는 점을 주의하자. y가 [a,c]가 되면 [c]에 해당하는 차원이 늘었다. 따라서 c에 대해서 mv를 mapping해줘야 하는 것이다.

따라서 mm=vmap(mv, (None, 1), 1)을 보면 in_axes=(None,1)로 설정함으로써 [b,a]차원은 그대로 두고 [a,c]에서 c에 해당하는 1축에 대해서 mapping한 다음에, [b,c]차원을 만들기 위해 이 mapping한 c를 out_axes=1로 설정함으로써 두번째 인덱스에 넣는다.

 


다른 예제로는 아래와 같다. 주석 일부는 내가 달았다.

A, B, C, D = 2, 3, 4, 5
x = jnp.ones((A, B)) # [2,3]
y = jnp.ones((B, C)) # [3,4]
z = jnp.ones((C, D)) # [4,5]
def foo(tree_arg):
  x, (y, z) = tree_arg
  return jnp.dot(x, jnp.dot(y, z)) # [2,3] * ([3,4] * [4,5])
tree = (x, (y, z))
print(foo(tree)) # [2,5]
[[12. 12. 12. 12. 12.]
 [12. 12. 12. 12. 12.]]
from jax import vmap
K = 6  # batch size
x = jnp.ones((K, A, B))  # batch axis in different locations
y = jnp.ones((B, K, C))
z = jnp.ones((C, D, K))
tree = (x, (y, z))
vfoo = vmap(foo, in_axes=((0, (1, 2)),)) # out_axes=0
print(vfoo(tree).shape)
(6, 2, 5)

위의 vmap을 import하기 전의 코드를 보면 x,y,z는 모두 2차원 행렬이다. 그런데 아래에서는 x,y,z가 3차원이 되고 batch 축이 각자 다른 위치에 있는 상황이다. 그래서 vmap을 할 때 in_axes=((0,(1,2)),) 으로 설정이 되었다.

엄청 복잡하게 써져있어서 이해하기 힘든데 생각하면 다음과 같다.

 

in_axes=(1,)과 같은 형태처럼 1 대신 (0,(1,2))로 넣은 것이다. 그 이유는 foo에 들어가는 형태가 tree=(x,(y,z))라는 형태 딱 하나의 argument만 받기 때문이다.

그 tree안에 (0,(1,2)) -> (x,(y,z)) ; 즉, x의 0번 인덱스, y의 1번 인덱스, z의 2번 인덱스를 mapping하겠다는 뜻이다.

out_axes는 따로 지정해주지 않아서 결과를 보면 (6,2,5)로 그냥 맨 앞의 인덱스가 배치 사이즈에 해당하는 축이 되었다.


다른 예제로는 입력이 dictionary인 경우이다. 

dct = {'a': 0., 'b': jnp.arange(5.)}
x = 1.
def foo(dct, x):
 return dct['a'] + dct['b'] + x
out = vmap(foo, in_axes=({'a': None, 'b': 0}, None))(dct, x)
print(out)
[1. 2. 3. 4. 5.]

 

입력에 들어가는 dct에서 key['b']가 벡터인 것을 볼 수 있다. 따라서 다른 값들에 대해 broadcasting하기 위해

in_axes=({'a':None, 'b':0},None)라고 설정해서 'b'에서 0번 인덱스를 mapping하였다.

 


다음은 out_axes에 대한 특징을 알 수 있는 예제들이다.

 

print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None),out_axes=(0, None))(jnp.arange(2.), 4.))
(Array([4., 5.], dtype=float32), 8.0)

위를 보면 out_axes=(0,None)으로 했기 때문에 output를 보면 (  mapping된 결과, 그렇지 않은 결과)가 된다.

 

[0+4, 0+5] 와 (4.)*(2.)인 결과를 튜플로 묶어서 나온 것이다.

 

out_axes=0으로 두면 mapped axis에 대해서 모두 broadcasting한다.

print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=0)(jnp.arange(2.), 4.))
(Array([4., 5.], dtype=float32), Array([8., 8.], dtype=float32, weak_type=True))

axis_name을 사용한 예제인데 잘 모르겠다.

xs = jnp.arange(3. * 4.).reshape(3, 4)
print(vmap(lambda x: lax.psum(x, 'i'), axis_name='i')(xs))
[[12. 15. 18. 21.]
 [12. 15. 18. 21.]
 [12. 15. 18. 21.]]
저작자표시 비영리 변경금지 (새창열림)

'연구 Research > 인공지능 Artificial Intelligent' 카테고리의 다른 글

[JAX] Gradient, Jacobian, Hessian 등 미분값 구하기  (2) 2023.03.22
[JAX] Cholesky decomposition error 230318 기준  (0) 2023.03.18
[AI] Sampyl에 대한 간단한 설명  (0) 2023.03.05
[PyTorch] 개별 파라미터 learning rate 다르게 설정 및 learning rate 확인  (2) 2023.03.05
[JAX] Windows에서도 JAX 사용하기  (0) 2023.02.24
'연구 Research/인공지능 Artificial Intelligent' 카테고리의 다른 글
  • [JAX] Gradient, Jacobian, Hessian 등 미분값 구하기
  • [JAX] Cholesky decomposition error 230318 기준
  • [AI] Sampyl에 대한 간단한 설명
  • [PyTorch] 개별 파라미터 learning rate 다르게 설정 및 learning rate 확인
보통의공대생
보통의공대생
수학,프로그래밍,기계항공우주 등 공부하는 기록들을 남깁니다.
  • 보통의공대생
    뛰는 놈 위에 나는 공대생
    보통의공대생
  • 전체
    오늘
    어제
    • 분류 전체보기 (471)
      • 공지 (1)
      • 영어 공부 English Study (40)
        • 텝스 TEPS (7)
        • 글 Article (21)
        • 영상 Video (10)
      • 연구 Research (99)
        • 최적화 Optimization (3)
        • 데이터과학 Data Science (7)
        • 인공지능 Artificial Intelligent (40)
        • 제어 Control (45)
      • 프로그래밍 Programming (103)
        • 매트랩 MATLAB (25)
        • 파이썬 Python (33)
        • 줄리아 Julia (2)
        • C++ (3)
        • 리눅스 우분투 Ubuntu (6)
      • 항공우주 Aeronautical engineeri.. (21)
        • 항법 Navigation (0)
        • 유도 Guidance (0)
      • 기계공학 Mechanical engineering (13)
        • 열역학 Thermodynamics (0)
        • 고체역학 Statics & Solid mechan.. (10)
        • 동역학 Dynamics (1)
        • 유체역학 Fluid Dynamics (0)
      • 수학 Mathematics (35)
        • 선형대수학 Linear Algebra (18)
        • 미분방정식 Differential Equation (3)
        • 확률및통계 Probability & Sta.. (2)
        • 미적분학 Calculus (1)
        • 복소해석학 Complex Analysis (5)
        • 실해석학 Real Analysis (1)
      • 수치해석 Numerical Analysis (28)
      • 확률 및 랜덤프로세스 Random process (2)
      • 추론 & 추정 이론 Estimation (3)
      • 기타 (95)
        • 설계 프로젝트 System Design (8)
        • 논문작성 Writing (56)
        • 세미나 Seminar (2)
        • 생산성 Productivity (3)
      • 실험 Experiment (1)
      • 유학 생활 Daily (8)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    teps
    딥러닝
    수치해석
    obsidian
    우분투
    MATLAB
    IEEE
    고체역학
    LaTeX
    서버
    에러기록
    JAX
    텝스공부
    텝스
    Python
    Statics
    생산성
    Julia
    ChatGPT
    옵시디언
    Numerical Analysis
    pytorch
    Zotero
    Linear algebra
    논문작성
    WOX
    인공지능
    논문작성법
    Dear abby
    matplotlib
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
보통의공대생
[JAX] JAX vmap에 대한 설명
상단으로

티스토리툴바