Neural ODE를 구현해놓은 코드는 torchdiffeq인데 학습이 너무 느리다는 생각이 들었다.
여러가지를 테스트해봐야 하는 입장에서 아무리 좋은 GPU를 써도 코드가 뒷받침되지 않으면 학습하는 데 시간이 오래 걸린다.
최근 JAX가 이런 측면에서 효과적이라는 것을 알아서 JAX 기반의 Neural ODE 코드를 찾아보았다.
https://docs.kidger.site/diffrax/
Diffrax
Diffrax in a nutshell Diffrax is a JAX-based library providing numerical differential equation solvers. Features include: ODE/SDE/CDE (ordinary/stochastic/controlled) solvers; lots of different solvers (including Tsit5, Dopri8, symplectic solvers, implicit
docs.kidger.site
저자 분이 쓴 글을 정리하면서 라이브러리에 익숙해지려고 한다.
1. Neural network (NN) 구성
Diffrax는 Equinox 라는 직접 만든 NN Library를 사용하는데 이에 대한 설명은 다음 사이트에서 확인할 수 있다.
기존에도 JAX를 통해 쉽게 NN을 만들 수 있는 라이브러리가 있었다.
Equinox document에서 말하기를 JAX는 functional programming을 지향하고, 이는 PyTorch의 객체 지향 프로그래밍과 다르기 때문에 parameterized functions을 만드는 것이 두 가지 형태로 형성되었다. Stax이라는 라이브러리처럼 아예 객체 지향 접근을 배제하거나, Objax, Haiku, Flax처럼 객체 지향과 functional programming 사이를 연결하여 JAX와 통합하는 것이다.
Equinox는 PyTrees와 transformations(jit, grad, vmap 등)만을 이용해 객체 지향적으로 parameteried functions을 만들 수 있도록 하였다고 한다. PyTorch와 비슷한 문법으로 구현하지만 내부는 JAX를 쓰기 때문에 PyTorch보다 빠르게 수행할 수 있다.
2. Ordinary differential equation
ODE simulation을 다양한 solver를 통해 수행할 수 있다.
solver의 경우에는 확인한 바로는
Explicit solver
Implicit solver
Implicit-Explicit solver
세 종류가 있다.
ODE solver에는 stiff 또는 non-stiff problem으로 나뉜다.
stiff problem은 일반적인 방법으로도 잘 풀리기 때문에 explicit Runge-Kutta로도 충분하다.
반면에 non-stiff problem은 implicit RK 등의 연산량이 많은 방법이 필요하다.
3. Stochastic differential equation
라이브러리에서는 multivariate인 경우가 없어서 multivariate function인 경우 예시를 만들었다.
import jax.random as jrandom
from diffrax import diffeqsolve, ControlTerm, Euler, MultiTerm, ODETerm, SaveAt, VirtualBrownianTree
def f(t, y, args):
return jnp.stack([0, 0, 0], axis=-1)
t0, t1 = 1, 3
drift = lambda t, y, args: f(t,y,args)
diffusion = lambda t, y, args: jnp.ones((3,2)) # check whether the result is same
diffusion = lambda t, y, args: jnp.array([[0.0,1.0],[1.0,0.0],[0.0,0.0]]) # check two noises are different
brownian_motion = VirtualBrownianTree(t0, t1, tol=1e-3, shape=(2,), key=jrandom.PRNGKey(0))
terms = MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion))
solver = Euler()
saveat = SaveAt(ts=ts, controller_state=True)
sol = diffeqsolve(terms, solver, t0, t1, dt0=0.05, y0=jnp.stack([1.0,1.0,1.0], axis=-1), saveat=saveat)
print(sol.controller_state)
print(sol.ys)
'연구 Research > 인공지능 Artificial Intelligent' 카테고리의 다른 글
[PyTorch] 인공지능 재현성을 위한 설정과 주의할 점 (0) | 2023.08.11 |
---|---|
Neural networks의 convergence, convexity에 대한 논문 (0) | 2023.07.31 |
[JAX] 학습한 모델 저장 및 로드 (0) | 2023.06.19 |
[JAX] 병렬컴퓨팅 예제 - jax.pmap으로 신경망 학습 예제 (0) | 2023.06.13 |
[인공지능] 딥러닝, 머신러닝에서 uncertainty/error 개념 (0) | 2023.05.16 |