[JAX] JAX 기반 Neural ODE 라이브러리 : diffrax
·
연구 Research/인공지능 Artificial Intelligent
Neural ODE를 구현해놓은 코드는 torchdiffeq인데 학습이 너무 느리다는 생각이 들었다. 여러가지를 테스트해봐야 하는 입장에서 아무리 좋은 GPU를 써도 코드가 뒷받침되지 않으면 학습하는 데 시간이 오래 걸린다. 최근 JAX가 이런 측면에서 효과적이라는 것을 알아서 JAX 기반의 Neural ODE 코드를 찾아보았다. https://docs.kidger.site/diffrax/ Diffrax Diffrax in a nutshell Diffrax is a JAX-based library providing numerical differential equation solvers. Features include: ODE/SDE/CDE (ordinary/stochastic/controlled) sol..