[JAX] JAX와 Torch, CUDA, cudnn 버전 맞추기
·
연구 Research/인공지능 Artificial Intelligent
이 글은 JAX 버전 맞추느라 여러 🐶고생한 경험을 바탕으로 작성하였다. 0. 요구 버전에 대한 이해 JAX는 설치할 때 요구하는 버전이 있다. 개별 gpu에 따라도 달라져서 까다롭긴한데 JAX currently ships three CUDA wheel variants: CUDA 12.0 and CuDNN 8.8. CUDA 11.8 and CuDNN 8.6. CUDA 11.4 and CuDNN 8.2. This wheel is deprecated and will be discontinued with jax 0.4.8. 위의 세 버전이 가능하다고 하는데, 이는 최신 JAX 버전(230412 기준)에 따른 것이다. CUDA 11.4라고 적혀있는 경우에는 11.4 이상이면서 CuDNN 8.2 이상이면 된다. ..
[JAX] 학습 중 NaN 값이 나올 때 찾는 방법
·
연구 Research/인공지능 Artificial Intelligent
JAX로 학습하는 도중에 NAN값이 나와서 어디부터 원인인지 찾기가 어려웠다. 이럴 때는 아래 코드를 추가하면 된다. from jax.config import config config.update("jax_debug_nans", True) 이렇게 할 경우에 NAN이 발생하는 즉시 어떤 코드에서 문제가 발생하는지를 알려주고 코드가 종료된다.
[JAX] Gaussian process 파라미터에 따른 결과 visualization
·
연구 Research/인공지능 Artificial Intelligent
Gaussian process에서 사용하는 커널 종류는 다양할 수 있지만 여기서는 Radial Basis Fuction을 이용해서 gaussian process 샘플들을 구하고 이에 대한 관찰을 시각화하는 방법에 대해서 이야기한다. RBF 함수는 Paris Perdikaris 교수님의 수업자료를 참고하였다. k(x1,x2)=ηexp((x1x2)22l2)k(x1,x2)=ηexp((x1x2)22l2) 커널함수가 이렇게 설정되어 있을 때 우리가 조절할 수 있는 파라미터는 scale factor인 ηη와 length인 ll이다. 개념적으로 생각하였을 때 random process인 gaussian process는 $\mathbf{x}~\mathcal{N}(\mathbf{0}, ..
[JAX] 기본 Neural Networks 모델
·
연구 Research/인공지능 Artificial Intelligent
가장 simple하게 신경망을 구성하는 방법에 대해서 저장해놓은 글이다.차츰 업데이트 할 예정 1. 기본 학습 코드import jaximport jax.numpy as jnpfrom jax import grad, jit, vmapfrom jax import random# Define a simple neural network modeldef init_params(layer_sizes, key): params = [] for i in range(1, len(layer_sizes)): key, subkey = random.split(key) w = random.normal(subkey, (layer_sizes[i-1], layer_sizes[i])) b = j..
[JAX] Gradient, Jacobian, Hessian 등 미분값 구하기
·
연구 Research/인공지능 Artificial Intelligent
본 글에서는 JAX로 미분값을 구하는 방법에 대해서 다룬다. JAX에서는 미분값을 구하기 위해 grad, jacfwd, jacrev를 제공하기 때문에 몇 가지 예제를 통해서 익숙해지고자 한다. 일단 크게 scalar-valued function과 vector-valued function으로 나누고, 각 function이 한 개의 변수에만 의존하는지, 또는 두 개 이상의 변수에만 의존하는지를 따진다. 예제코드는 유튜브 튜토리얼 + JAX 매뉴얼을 참고하였다. 1. Scalar-valued function일 때 Gradient는 scalar-valued univariate function에 대한 기울기 Jacobian은 vector-valued or scalar-valued multivariate funct..
[JAX] Cholesky decomposition error 230318 기준
·
연구 Research/인공지능 Artificial Intelligent
현재 기준(230318)으로 cholesky decomposition을 사용할 때 행렬 크기가 50정도 넘어가면 nan을 출력하는 오류가 있다. jnp.linalg.cholesky(K) jax.random.multivariate_normal(subkeys[0], np.zeros((N_samples,)), K) 이 때문에 cholesky decomposition을 쓰는 다른 함수들도 영향을 받았는데 jax.random.multivariate_normal의 경우에도 랜덤하게 추출하는 과정에서 cholesky decomposition을 쓴다. cholesky decomposition은 어쩔 수 없을 것 같고 jax.random.multivariate_normal(subkeys[0], np.zeros((N_sa..
[JAX] JAX vmap에 대한 설명
·
연구 Research/인공지능 Artificial Intelligent
아직 jax가 한글화가 많이 안 되어있어서 기본적인 기능은 내가 적어놓으려고 한다. jax.vmap(fun, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None) function을 argument axes에 대해서 mapping해주는 기능. fun : mapping할 function in_axes : function에 들어가는 input을 의미한다. 정수, None, Python container(tuple/list/dict 모두 가능)을 지원한다. 이는 모두 mapping할 input array 축을 의미한다. 만약 fun의 argument가 array이면 in_axes에는 정수, None, 튜플(Integer, None..
[AI] Sampyl에 대한 간단한 설명
·
연구 Research/인공지능 Artificial Intelligent
Sampyl이라는 라이브러리를 사용하기 전에 간단한 예제가 있어서 posterior distribution과 sampler에 대한 tutorial을 좀 정리해보았다. YN(μ,σ2)μ=β0+β1x1+β2x2 다음과 같이 β가 계수의 어떤 y와 x 간의 모델이 있다고 해보자. 이 문제에서는 β=[2,1,4]인 경우이다. 따라서 아래와 같이 x0=1,x1,x2에 대하여 y가 분포되어있다. # Number of data points N = 200 # True parameters sigma = 1 true_b = ..
[PyTorch] 개별 파라미터 learning rate 다르게 설정 및 learning rate 확인
·
연구 Research/인공지능 Artificial Intelligent
코드 안에 네트워크가 2개가 있고 이 2개의 네트워크를 각각 다른 learning rate로 학습하고 싶을 때 사용하는 코드다. 아래와 같이 개별로 learning rate를 설정하면 net2 안에 있는 파라미터는 0.001로 학습되고 net1 안에 있는 파라미터는 0.01로 학습된다. optimizer = optim.Adam([ {'params': func.net1.parameters()}, {'params': func.net2.parameters(), 'lr': 0.001} ], lr=0.01) optimizer.param_groups[0]['capturable'] = True print(optimizer.param_groups[0]['lr']) print(optimizer.param_groups[1..
[JAX] Windows에서도 JAX 사용하기
·
연구 Research/인공지능 Artificial Intelligent
JAX는 아직 리눅스에서밖에 사용이 안된다. 그래서 윈도우에서 돌릴 수 있는 방법을 찾아보았는데 최신 버전은 불가능하고 예전 버전은 가능하다. JAX가 아직 초기이다보니 버전마다 많이 바뀌어서 불편한 점이 있지만 일단 시도한 경험을 공유한다. 아래 링크를 들어가면 대략적인 instruction을 알 수 있다. https://github.com/cloudhan/jax-windows-builder GitHub - cloudhan/jax-windows-builder: A community supported Windows build for jax. A community supported Windows build for jax. Contribute to cloudhan/jax-windows-builder deve..