Notice
Recent Posts
Recent Comments
Link
관리 메뉴

뛰는 놈 위에 나는 공대생

[JAX] vmap과 jit의 속도 본문

연구 Research/인공지능 Artificial Intelligent

[JAX] vmap과 jit의 속도

보통의공대생 2023. 9. 20. 14:04

JAX를 쓰다보니 분명 vmap을 사용했음에도 안에 있는 루프는 빨리 되지만 정작 vmap을 나올 때 느려지는 현상을 발견하였다. 구체적으로 알아보려면

 

jax.make_jaxpr 또는

jax.block_until_ready를 써보라고 하는데 

 

make_jaxpr 같은 경우에는 컴파일할 때 각 변수 flow를 보여주는 역할을 해서 도움이 될 수도 있다.

 

vmap과 jit의 시간 차이를 알아보려면 아래의 예제를 사용해볼 수 있다.

 

from functools import partial
from timeit import timeit
from jax import vmap, jit, random, numpy as jnp

n, d = 512, 64
a = random.normal(random.PRNGKey(0), (n, d))
b = random.normal(random.PRNGKey(0), (d, d))

mm = jnp.matmul
v = partial(vmap, in_axes=(0, None))

for f in [mm, v(mm), jit(mm), v(jit(mm)), jit(v(mm))]:
  run = lambda: f(a, b).block_until_ready()
  t = timeit(run, setup=run, number=1000)
  print(f'{t:.3f}')

각 소요시간을 보면

0.104
0.412
0.083
0.274
0.084

vmap 안에 jit 함수를 넣는 것보다 vmap을 jit으로 묶는 것이 더 좋음을 알 수 있다.

 

vmap이 jit과 거의 비슷하다고 느꼈는데 실제로는 jit을 쓰고 안 쓰고가 성능에 큰 영향을 준다. 대신 jit을 사용하려면 class를 처리하거나 static 변수를 입력을 넣으면 안되는 등 신경써야 할 부분이 많다.

Comments