JAX로 학습하는 도중에 NAN값이 나와서 어디부터 원인인지 찾기가 어려웠다. 이럴 때는 아래 코드를 추가하면 된다.
from jax.config import config
config.update("jax_debug_nans", True)
이렇게 할 경우에 NAN이 발생하는 즉시 어떤 코드에서 문제가 발생하는지를 알려주고 코드가 종료된다.
'연구 Research > 인공지능 Artificial Intelligent' 카테고리의 다른 글
[JAX] device 확인, default device 설정 (0) | 2023.04.13 |
---|---|
[JAX] JAX와 Torch, CUDA, cudnn 버전 맞추기 (0) | 2023.04.12 |
[JAX] Gaussian process 파라미터에 따른 결과 visualization (0) | 2023.03.24 |
[JAX] 기본 Neural Networks 모델 (0) | 2023.03.23 |
[JAX] Gradient, Jacobian, Hessian 등 미분값 구하기 (2) | 2023.03.22 |