Notice
Recent Posts
Recent Comments
Link
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | |||||
3 | 4 | 5 | 6 | 7 | 8 | 9 |
10 | 11 | 12 | 13 | 14 | 15 | 16 |
17 | 18 | 19 | 20 | 21 | 22 | 23 |
24 | 25 | 26 | 27 | 28 | 29 | 30 |
Tags
- obsidian
- 수식삽입
- Dear abby
- ChatGPT
- Linear algebra
- 딥러닝
- WOX
- 우분투
- Statics
- matplotlib
- 텝스공부
- 에러기록
- 텝스
- LaTeX
- 옵시디언
- JAX
- 인공지능
- Zotero
- 수치해석
- teps
- Julia
- Numerical Analysis
- 논문작성
- pytorch
- 논문작성법
- Python
- 고체역학
- 생산성
- MATLAB
- IEEE
Archives
- Today
- Total
뛰는 놈 위에 나는 공대생
[JAX] vmap과 jit의 속도 본문
JAX를 쓰다보니 분명 vmap을 사용했음에도 안에 있는 루프는 빨리 되지만 정작 vmap을 나올 때 느려지는 현상을 발견하였다. 구체적으로 알아보려면
jax.block_until_ready를 써보라고 하는데
make_jaxpr 같은 경우에는 컴파일할 때 각 변수 flow를 보여주는 역할을 해서 도움이 될 수도 있다.
vmap과 jit의 시간 차이를 알아보려면 아래의 예제를 사용해볼 수 있다.
from functools import partial
from timeit import timeit
from jax import vmap, jit, random, numpy as jnp
n, d = 512, 64
a = random.normal(random.PRNGKey(0), (n, d))
b = random.normal(random.PRNGKey(0), (d, d))
mm = jnp.matmul
v = partial(vmap, in_axes=(0, None))
for f in [mm, v(mm), jit(mm), v(jit(mm)), jit(v(mm))]:
run = lambda: f(a, b).block_until_ready()
t = timeit(run, setup=run, number=1000)
print(f'{t:.3f}')
각 소요시간을 보면
0.104
0.412
0.083
0.274
0.084
vmap 안에 jit 함수를 넣는 것보다 vmap을 jit으로 묶는 것이 더 좋음을 알 수 있다.
vmap이 jit과 거의 비슷하다고 느꼈는데 실제로는 jit을 쓰고 안 쓰고가 성능에 큰 영향을 준다. 대신 jit을 사용하려면 class를 처리하거나 static 변수를 입력을 넣으면 안되는 등 신경써야 할 부분이 많다.
'연구 Research > 인공지능 Artificial Intelligent' 카테고리의 다른 글
[Deep learning] Bayesian Neural Network (1) (0) | 2023.10.23 |
---|---|
[인공지능] Learning에서 scaling이 중요한가 (0) | 2023.09.20 |
[인공지능] 인공지능 라이브러리 정리 (0) | 2023.08.24 |
[JAX] JAX에서 gradient 추척을 멈추는 방법 (0) | 2023.08.22 |
[PyTorch] 인공지능 재현성을 위한 설정과 주의할 점 (0) | 2023.08.11 |
Comments