현재 기준(230318)으로 cholesky decomposition을 사용할 때 행렬 크기가 50정도 넘어가면 nan을 출력하는 오류가 있다.
jnp.linalg.cholesky(K)
jax.random.multivariate_normal(subkeys[0], np.zeros((N_samples,)), K)
이 때문에 cholesky decomposition을 쓰는 다른 함수들도 영향을 받았는데
jax.random.multivariate_normal의 경우에도 랜덤하게 추출하는 과정에서 cholesky decomposition을 쓴다.
cholesky decomposition은 어쩔 수 없을 것 같고
jax.random.multivariate_normal(subkeys[0], np.zeros((N_samples,)), K, method='svd')
jax.random.multivariate_normal의 경우에는 다른 메소드(위에서는 svd)를 쓰면 에러없이 사용할 수 있다.
'연구 Research > 인공지능 Artificial Intelligent' 카테고리의 다른 글
[JAX] 기본 Neural Networks 모델 (0) | 2023.03.23 |
---|---|
[JAX] Gradient, Jacobian, Hessian 등 미분값 구하기 (2) | 2023.03.22 |
[JAX] JAX vmap에 대한 설명 (0) | 2023.03.17 |
[AI] Sampyl에 대한 간단한 설명 (0) | 2023.03.05 |
[PyTorch] 개별 파라미터 learning rate 다르게 설정 및 learning rate 확인 (2) | 2023.03.05 |