[JAX] vmap과 jit의 속도

2023. 9. 20. 14:04·연구 Research/인공지능 Artificial Intelligent

JAX를 쓰다보니 분명 vmap을 사용했음에도 안에 있는 루프는 빨리 되지만 정작 vmap을 나올 때 느려지는 현상을 발견하였다. 구체적으로 알아보려면

 

jax.make_jaxpr 또는

jax.block_until_ready를 써보라고 하는데 

 

make_jaxpr 같은 경우에는 컴파일할 때 각 변수 flow를 보여주는 역할을 해서 도움이 될 수도 있다.

 

vmap과 jit의 시간 차이를 알아보려면 아래의 예제를 사용해볼 수 있다.

 

from functools import partial
from timeit import timeit
from jax import vmap, jit, random, numpy as jnp

n, d = 512, 64
a = random.normal(random.PRNGKey(0), (n, d))
b = random.normal(random.PRNGKey(0), (d, d))

mm = jnp.matmul
v = partial(vmap, in_axes=(0, None))

for f in [mm, v(mm), jit(mm), v(jit(mm)), jit(v(mm))]:
  run = lambda: f(a, b).block_until_ready()
  t = timeit(run, setup=run, number=1000)
  print(f'{t:.3f}')

각 소요시간을 보면

0.104
0.412
0.083
0.274
0.084

vmap 안에 jit 함수를 넣는 것보다 vmap을 jit으로 묶는 것이 더 좋음을 알 수 있다.

 

vmap이 jit과 거의 비슷하다고 느꼈는데 실제로는 jit을 쓰고 안 쓰고가 성능에 큰 영향을 준다. 대신 jit을 사용하려면 class를 처리하거나 static 변수를 입력을 넣으면 안되는 등 신경써야 할 부분이 많다.

저작자표시 비영리 변경금지 (새창열림)

'연구 Research > 인공지능 Artificial Intelligent' 카테고리의 다른 글

[Deep learning] Bayesian Neural Network (1)  (0) 2023.10.23
[인공지능] Learning에서 scaling이 중요한가  (0) 2023.09.20
[인공지능] 인공지능 라이브러리 정리  (0) 2023.08.24
[JAX] JAX에서 gradient 추척을 멈추는 방법  (0) 2023.08.22
[PyTorch] 인공지능 재현성을 위한 설정과 주의할 점  (0) 2023.08.11
'연구 Research/인공지능 Artificial Intelligent' 카테고리의 다른 글
  • [Deep learning] Bayesian Neural Network (1)
  • [인공지능] Learning에서 scaling이 중요한가
  • [인공지능] 인공지능 라이브러리 정리
  • [JAX] JAX에서 gradient 추척을 멈추는 방법
보통의공대생
보통의공대생
수학,프로그래밍,기계항공우주 등 공부하는 기록들을 남깁니다.
  • 보통의공대생
    뛰는 놈 위에 나는 공대생
    보통의공대생
  • 전체
    오늘
    어제
    • 분류 전체보기 (459)
      • 공지 (1)
      • 영어 공부 English Study (40)
        • 텝스 TEPS (7)
        • 글 Article (21)
        • 영상 Video (10)
      • 연구 Research (99)
        • 최적화 Optimization (3)
        • 데이터과학 Data Science (7)
        • 인공지능 Artificial Intelligent (40)
        • 제어 Control (45)
      • 프로그래밍 Programming (103)
        • 매트랩 MATLAB (25)
        • 파이썬 Python (33)
        • 줄리아 Julia (2)
        • C++ (3)
        • 리눅스 우분투 Ubuntu (6)
      • 항공우주 Aeronautical engineeri.. (21)
        • 항법 Navigation (0)
        • 유도 Guidance (0)
      • 기계공학 Mechanical engineering (13)
        • 열역학 Thermodynamics (0)
        • 고체역학 Statics & Solid mechan.. (10)
        • 동역학 Dynamics (1)
        • 유체역학 Fluid Dynamics (0)
      • 수학 Mathematics (34)
        • 선형대수학 Linear Algebra (18)
        • 미분방정식 Differential Equation (3)
        • 확률및통계 Probability & Sta.. (2)
        • 미적분학 Calculus (1)
        • 복소해석학 Complex Analysis (5)
        • 실해석학 Real Analysis (0)
      • 수치해석 Numerical Analysis (21)
      • 확률 및 랜덤프로세스 Random process (2)
      • 추론 & 추정 이론 Estimation (3)
      • 기타 (26)
        • 설계 프로젝트 System Design (8)
        • 논문작성 Writing (55)
        • 세미나 Seminar (2)
        • 생산성 Productivity (3)
      • 유학 생활 Daily (6)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    IEEE
    텝스
    생산성
    수치해석
    텝스공부
    딥러닝
    Numerical Analysis
    ChatGPT
    WOX
    Julia
    옵시디언
    LaTeX
    Zotero
    matplotlib
    obsidian
    Dear abby
    에러기록
    고체역학
    Statics
    pytorch
    우분투
    JAX
    인공지능
    Linear algebra
    teps
    서버
    논문작성
    MATLAB
    Python
    논문작성법
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
보통의공대생
[JAX] vmap과 jit의 속도
상단으로

티스토리툴바