[JAX] JAX 기반 Neural ODE 라이브러리 : diffrax

2023. 7. 28. 14:57·연구 Research/인공지능 Artificial Intelligent

 

Neural ODE를 구현해놓은 코드는 torchdiffeq인데 학습이 너무 느리다는 생각이 들었다.

여러가지를 테스트해봐야 하는 입장에서 아무리 좋은 GPU를 써도 코드가 뒷받침되지 않으면 학습하는 데 시간이 오래 걸린다.

최근 JAX가 이런 측면에서 효과적이라는 것을 알아서 JAX 기반의 Neural ODE 코드를 찾아보았다.

 

https://docs.kidger.site/diffrax/

 

Diffrax

Diffrax in a nutshell Diffrax is a JAX-based library providing numerical differential equation solvers. Features include: ODE/SDE/CDE (ordinary/stochastic/controlled) solvers; lots of different solvers (including Tsit5, Dopri8, symplectic solvers, implicit

docs.kidger.site

 

저자 분이 쓴 글을 정리하면서 라이브러리에 익숙해지려고 한다.

 


1. Neural network (NN) 구성

 

Diffrax는 Equinox 라는 직접 만든 NN Library를 사용하는데 이에 대한 설명은 다음 사이트에서 확인할 수 있다.

기존에도 JAX를 통해 쉽게 NN을 만들 수 있는 라이브러리가 있었다.

 

Equinox document에서 말하기를 JAX는 functional programming을 지향하고, 이는 PyTorch의 객체 지향 프로그래밍과 다르기 때문에 parameterized functions을 만드는 것이 두 가지 형태로 형성되었다. Stax이라는 라이브러리처럼 아예 객체 지향 접근을 배제하거나, Objax, Haiku, Flax처럼 객체 지향과 functional programming 사이를 연결하여 JAX와 통합하는 것이다.

 

Equinox는 PyTrees와 transformations(jit, grad, vmap 등)만을 이용해 객체 지향적으로 parameteried functions을 만들 수 있도록 하였다고 한다. PyTorch와 비슷한 문법으로 구현하지만 내부는 JAX를 쓰기 때문에 PyTorch보다 빠르게 수행할 수 있다.

 

 

2. Ordinary differential equation

 

ODE simulation을 다양한 solver를 통해 수행할 수 있다.

solver의 경우에는 확인한 바로는

Explicit solver

Implicit solver

Implicit-Explicit solver

세 종류가 있다.

 

ODE solver에는 stiff 또는 non-stiff problem으로 나뉜다.

stiff problem은 일반적인 방법으로도 잘 풀리기 때문에 explicit Runge-Kutta로도 충분하다.

반면에 non-stiff problem은 implicit RK 등의 연산량이 많은 방법이 필요하다.

 

 

 

3. Stochastic differential equation

 

 

라이브러리에서는 multivariate인 경우가 없어서 multivariate function인 경우 예시를 만들었다.

import jax.random as jrandom
from diffrax import diffeqsolve, ControlTerm, Euler, MultiTerm, ODETerm, SaveAt, VirtualBrownianTree

def f(t, y, args):
    return jnp.stack([0, 0, 0], axis=-1)

t0, t1 = 1, 3

drift = lambda t, y, args: f(t,y,args)
diffusion = lambda t, y, args: jnp.ones((3,2)) # check whether the result is same
diffusion = lambda t, y, args: jnp.array([[0.0,1.0],[1.0,0.0],[0.0,0.0]]) # check two noises are different

brownian_motion = VirtualBrownianTree(t0, t1, tol=1e-3, shape=(2,), key=jrandom.PRNGKey(0))
terms = MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion))
solver = Euler()
saveat = SaveAt(ts=ts, controller_state=True)

sol = diffeqsolve(terms, solver, t0, t1, dt0=0.05, y0=jnp.stack([1.0,1.0,1.0], axis=-1), saveat=saveat)
print(sol.controller_state)
print(sol.ys)

 

저작자표시 비영리 변경금지 (새창열림)

'연구 Research > 인공지능 Artificial Intelligent' 카테고리의 다른 글

[PyTorch] 인공지능 재현성을 위한 설정과 주의할 점  (0) 2023.08.11
Neural networks의 convergence, convexity에 대한 논문  (0) 2023.07.31
[JAX] 학습한 모델 저장 및 로드  (0) 2023.06.19
[JAX] 병렬컴퓨팅 예제 - jax.pmap으로 신경망 학습 예제  (0) 2023.06.13
[인공지능] 딥러닝, 머신러닝에서 uncertainty/error 개념  (0) 2023.05.16
'연구 Research/인공지능 Artificial Intelligent' 카테고리의 다른 글
  • [PyTorch] 인공지능 재현성을 위한 설정과 주의할 점
  • Neural networks의 convergence, convexity에 대한 논문
  • [JAX] 학습한 모델 저장 및 로드
  • [JAX] 병렬컴퓨팅 예제 - jax.pmap으로 신경망 학습 예제
보통의공대생
보통의공대생
수학,프로그래밍,기계항공우주 등 공부하는 기록들을 남깁니다.
  • 보통의공대생
    뛰는 놈 위에 나는 공대생
    보통의공대생
  • 전체
    오늘
    어제
    • 분류 전체보기 (468)
      • 공지 (1)
      • 영어 공부 English Study (40)
        • 텝스 TEPS (7)
        • 글 Article (21)
        • 영상 Video (10)
      • 연구 Research (99)
        • 최적화 Optimization (3)
        • 데이터과학 Data Science (7)
        • 인공지능 Artificial Intelligent (40)
        • 제어 Control (45)
      • 프로그래밍 Programming (103)
        • 매트랩 MATLAB (25)
        • 파이썬 Python (33)
        • 줄리아 Julia (2)
        • C++ (3)
        • 리눅스 우분투 Ubuntu (6)
      • 항공우주 Aeronautical engineeri.. (21)
        • 항법 Navigation (0)
        • 유도 Guidance (0)
      • 기계공학 Mechanical engineering (13)
        • 열역학 Thermodynamics (0)
        • 고체역학 Statics & Solid mechan.. (10)
        • 동역학 Dynamics (1)
        • 유체역학 Fluid Dynamics (0)
      • 수학 Mathematics (34)
        • 선형대수학 Linear Algebra (18)
        • 미분방정식 Differential Equation (3)
        • 확률및통계 Probability & Sta.. (2)
        • 미적분학 Calculus (1)
        • 복소해석학 Complex Analysis (5)
        • 실해석학 Real Analysis (0)
      • 수치해석 Numerical Analysis (27)
      • 확률 및 랜덤프로세스 Random process (2)
      • 추론 & 추정 이론 Estimation (3)
      • 기타 (26)
        • 설계 프로젝트 System Design (8)
        • 논문작성 Writing (55)
        • 세미나 Seminar (2)
        • 생산성 Productivity (3)
      • 실험 Experiment (1)
      • 유학 생활 Daily (8)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    LaTeX
    우분투
    텝스
    WOX
    텝스공부
    Zotero
    인공지능
    고체역학
    Linear algebra
    teps
    옵시디언
    matplotlib
    obsidian
    Julia
    MATLAB
    Numerical Analysis
    논문작성
    Python
    JAX
    IEEE
    ChatGPT
    논문작성법
    생산성
    딥러닝
    Statics
    서버
    수치해석
    에러기록
    pytorch
    Dear abby
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
보통의공대생
[JAX] JAX 기반 Neural ODE 라이브러리 : diffrax
상단으로

티스토리툴바