[JAX] vmap과 jit의 속도
·
연구 Research/인공지능 Artificial Intelligent
JAX를 쓰다보니 분명 vmap을 사용했음에도 안에 있는 루프는 빨리 되지만 정작 vmap을 나올 때 느려지는 현상을 발견하였다. 구체적으로 알아보려면 jax.make_jaxpr 또는 jax.block_until_ready를 써보라고 하는데 make_jaxpr 같은 경우에는 컴파일할 때 각 변수 flow를 보여주는 역할을 해서 도움이 될 수도 있다. vmap과 jit의 시간 차이를 알아보려면 아래의 예제를 사용해볼 수 있다. from functools import partial from timeit import timeit from jax import vmap, jit, random, numpy as jnp n, d = 512, 64 a = random.normal(random.PRNGKey(0), (n..