다음과 같이 jax.lax.stop_gradient로 묶어준 결과를 사용하면 그 이전까지는 gradient가 기록되지 않는다.
jax.lax.stop_gradient(sol.ts)
'연구 Research > 인공지능 Artificial Intelligent' 카테고리의 다른 글
[JAX] vmap과 jit의 속도 (0) | 2023.09.20 |
---|---|
[인공지능] 인공지능 라이브러리 정리 (0) | 2023.08.24 |
[PyTorch] 인공지능 재현성을 위한 설정과 주의할 점 (0) | 2023.08.11 |
Neural networks의 convergence, convexity에 대한 논문 (0) | 2023.07.31 |
[JAX] JAX 기반 Neural ODE 라이브러리 : diffrax (0) | 2023.07.28 |