아직 jax가 한글화가 많이 안 되어있어서 기본적인 기능은 내가 적어놓으려고 한다.
jax.vmap(fun, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None)
function을 argument axes에 대해서 mapping해주는 기능.
- fun : mapping할 function
- in_axes : function에 들어가는 input을 의미한다. 정수, None, Python container(tuple/list/dict 모두 가능)을 지원한다. 이는 모두 mapping할 input array 축을 의미한다.
- 만약 fun의 argument가 array이면 in_axes에는 정수, None, 튜플(Integer, Nones로 구성된 튜플로 입력 array와 크기 동일)이 들어갈 수 있다. Integer를 in_axes로 넣으면 어떤 축으로 모든 arguments에 대해 map over할 것인지 결정한다.
- None을 넣는다면 어떤 axis로도 map하지 않겠다는 의미이다. 튜플을 넣는다면 각 positional argument가 어떤 축에 map될 것인지를 결정한다. 여기에서 axis integer는 input array의 차원 수(ndim)만큼 [-ndim,ndim] 사이의 값을 넣어야 한다.
- fun의 positional arguements가 container pytree 타입일 경우에는 in_axes의 해당 element가 곧 matching container가 된다고 한다. 이 부분은 잘 이해가 안되어서 아래에 전문을 써놓는다.
- vmap의 argument 중에는 option으로 axis_size를 지정해줄 수 있는데 이렇게 axis_size를 명백하게 지정해놓거나, in_axes의 적어도 하나 argument는 none이 되면 안된다. 이는 입력으로 들어온 in_axes의 사이즈와 mapped arguments와 사이즈가 일치해야하기 때문이라고 한다.
If the positional arguments to fun are container (pytree) types, the corresponding element of in_axes can itself be a matching container, so that distinct array axes can be mapped for different container elements. in_axes must be a container tree prefix of the positional argument tuple passed to fun. See this link for more detail: https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees
- out_axes : output에서 나타나는 mapped axis.
- axis_name : Optional, a hashable Python object used to identify the mapped axis so that parallel collectives can be applied.
- axis_size : Optional, an integer indicating the size of the axis to be mapped. If not provided, the mapped axis size is inferred from arguments.
솔직히 이 내용만 보면 이해하기가 좀 난해했다. 예제를 보면서 이해한다.
import jax.numpy as jnp
>>>
vv = lambda x, y: jnp.vdot(x, y) # ([a], [a]) -> []
mv = vmap(vv, (0, None), 0) # ([b,a], [a]) -> [b] (b is the mapped axis)
mm = vmap(mv, (None, 1), 1) # ([b,a], [a,c]) -> [b,c] (c is the mapped axis)
위의 예제를 보면 vv라는 함수는 벡터끼리 내적을 해주는 함수이다. 그런데 내가 입력 x를 벡터가 아닌 배치 사이즈만큼의 벡터를 넣고 싶다고 하자. 그러면 [a]크기의 1차원의 벡터가 [b,a]크기의 2차원 행렬이 된다. 그리고 입력 y는 [a] 크기의 1차원 벡터를 넣는다고 하자.
생각해서보면 계산은 [b,a] * [a] = [b]가 되어야 한다. 그러나 vv라는 함수에 바로 이 x,y를 넣으면 오류가 난다. 대신 for문을 배치사이즈만큼 반복하는 대신 vmap을 이용하여 빠르게 계산할 수 있다.
이 때 x의 0차원은 broadcasting을 통해 계산되어야 하는 부분이다.
그래서 mv를 보면 in_axes=(0,None)으로 되어있다. x에 해당하는 input은 0차원에 대해서 mapping하고, y에 해당하는 input은 아무 차원도 mapping하지 않겠다는 뜻이다. 그리고 out_axes=0으로 지정해줌으로써 mapping되는 0차원을 그대로 0차원에 넣었다.
또한 mm을 보면 x가 [b,a]이고 y가 [a,c]인 상황을 말한다. 여기서 vmap할 함수는 mv라는 점을 주의하자. y가 [a,c]가 되면 [c]에 해당하는 차원이 늘었다. 따라서 c에 대해서 mv를 mapping해줘야 하는 것이다.
따라서 mm=vmap(mv, (None, 1), 1)을 보면 in_axes=(None,1)로 설정함으로써 [b,a]차원은 그대로 두고 [a,c]에서 c에 해당하는 1축에 대해서 mapping한 다음에, [b,c]차원을 만들기 위해 이 mapping한 c를 out_axes=1로 설정함으로써 두번째 인덱스에 넣는다.
다른 예제로는 아래와 같다. 주석 일부는 내가 달았다.
A, B, C, D = 2, 3, 4, 5
x = jnp.ones((A, B)) # [2,3]
y = jnp.ones((B, C)) # [3,4]
z = jnp.ones((C, D)) # [4,5]
def foo(tree_arg):
x, (y, z) = tree_arg
return jnp.dot(x, jnp.dot(y, z)) # [2,3] * ([3,4] * [4,5])
tree = (x, (y, z))
print(foo(tree)) # [2,5]
[[12. 12. 12. 12. 12.]
[12. 12. 12. 12. 12.]]
from jax import vmap
K = 6 # batch size
x = jnp.ones((K, A, B)) # batch axis in different locations
y = jnp.ones((B, K, C))
z = jnp.ones((C, D, K))
tree = (x, (y, z))
vfoo = vmap(foo, in_axes=((0, (1, 2)),)) # out_axes=0
print(vfoo(tree).shape)
(6, 2, 5)
위의 vmap을 import하기 전의 코드를 보면 x,y,z는 모두 2차원 행렬이다. 그런데 아래에서는 x,y,z가 3차원이 되고 batch 축이 각자 다른 위치에 있는 상황이다. 그래서 vmap을 할 때 in_axes=((0,(1,2)),) 으로 설정이 되었다.
엄청 복잡하게 써져있어서 이해하기 힘든데 생각하면 다음과 같다.
in_axes=(1,)과 같은 형태처럼 1 대신 (0,(1,2))로 넣은 것이다. 그 이유는 foo에 들어가는 형태가 tree=(x,(y,z))라는 형태 딱 하나의 argument만 받기 때문이다.
그 tree안에 (0,(1,2)) -> (x,(y,z)) ; 즉, x의 0번 인덱스, y의 1번 인덱스, z의 2번 인덱스를 mapping하겠다는 뜻이다.
out_axes는 따로 지정해주지 않아서 결과를 보면 (6,2,5)로 그냥 맨 앞의 인덱스가 배치 사이즈에 해당하는 축이 되었다.
다른 예제로는 입력이 dictionary인 경우이다.
dct = {'a': 0., 'b': jnp.arange(5.)}
x = 1.
def foo(dct, x):
return dct['a'] + dct['b'] + x
out = vmap(foo, in_axes=({'a': None, 'b': 0}, None))(dct, x)
print(out)
[1. 2. 3. 4. 5.]
입력에 들어가는 dct에서 key['b']가 벡터인 것을 볼 수 있다. 따라서 다른 값들에 대해 broadcasting하기 위해
in_axes=({'a':None, 'b':0},None)라고 설정해서 'b'에서 0번 인덱스를 mapping하였다.
다음은 out_axes에 대한 특징을 알 수 있는 예제들이다.
print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None),out_axes=(0, None))(jnp.arange(2.), 4.))
(Array([4., 5.], dtype=float32), 8.0)
위를 보면 out_axes=(0,None)으로 했기 때문에 output를 보면 ( mapping된 결과, 그렇지 않은 결과)가 된다.
[0+4, 0+5] 와 (4.)*(2.)인 결과를 튜플로 묶어서 나온 것이다.
out_axes=0으로 두면 mapped axis에 대해서 모두 broadcasting한다.
print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=0)(jnp.arange(2.), 4.))
(Array([4., 5.], dtype=float32), Array([8., 8.], dtype=float32, weak_type=True))
axis_name을 사용한 예제인데 잘 모르겠다.
xs = jnp.arange(3. * 4.).reshape(3, 4)
print(vmap(lambda x: lax.psum(x, 'i'), axis_name='i')(xs))
[[12. 15. 18. 21.]
[12. 15. 18. 21.]
[12. 15. 18. 21.]]
'연구 Research > 인공지능 Artificial Intelligent' 카테고리의 다른 글
[JAX] Gradient, Jacobian, Hessian 등 미분값 구하기 (2) | 2023.03.22 |
---|---|
[JAX] Cholesky decomposition error 230318 기준 (0) | 2023.03.18 |
[AI] Sampyl에 대한 간단한 설명 (0) | 2023.03.05 |
[PyTorch] 개별 파라미터 learning rate 다르게 설정 및 learning rate 확인 (2) | 2023.03.05 |
[JAX] Windows에서도 JAX 사용하기 (0) | 2023.02.24 |