[JAX] 지속적인 kernel crash 여러 가지 원인
·
카테고리 없음
jupyer lab/notebook을 쓰다가 kernel이 터지는 경우가 있는데 다음과 같은 경우들이 있다. https://github.com/microsoft/vscode-jupyter/wiki/Kernel-crashes Kernel crashesVS Code Jupyter extension. Contribute to microsoft/vscode-jupyter development by creating an account on GitHub.github.com  내가 겪은 대표적인 원인들은 다음과 같다.1. 라이브러리 설치 실패 - 호환 안됨 등의 문제2. gpu에 업로드한 데이터가 너무 많아서 문제  1번의 경우에는 재설치하고 버전 호환을 신경써서 설치해야한다. 아나콘다 등의 버전 관리 시스템을 쓰면..
[JAX] L-BFGS optimizer로 학습하는 예제 코드
·
연구 Research/인공지능 Artificial Intelligent
Quasi Newton method를 이용해서 JAX에서 최적화를 시키고 싶었는데 JAX 자체가 Deep learning에 포커스가 있고 대부분 딥러닝이 Gradient descent method로 최적화를 하다보니 라이브러리를 찾게 되었다. 대부분 jaxopt라는 라이브러리를 추천했기 때문에 이걸로 수행해보았다. jaxopt에는 jaxopt.ScipyMinimize와 jaxopt.LBFGS가 있는데 다른 분들의 시도를 보니 ScipyMinimize가 더 성능이 괜찮은 것 같다. ScipyMinimize는 scipy에 있는 최적화를 사용한 것이고 LBFGS는 직접 만든 것 같은데 line search 방법 등이 다르다고 한다. import jax import jax.numpy as np from jax ..
[JAX] optax에서 learning rate 확인하는 방법
·
프로그래밍 Programming
inject_hyperparams라는 함수로 optax의 optimizer를 묶어서 사용하면 hyperparams를 관찰할 수 있다. # Wrap the optimizer to inject the hyperparameters optimizer = optax.inject_hyperparams(optax.adamw)(learning_rate=schedule) def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params: opt_state = optimizer.init(params) # Since we injected hyperparams, we can access them directly here print(f'A..
[JAX] NaN, Inf 값 처리 및 조건에 맞는 요소 찾기
·
프로그래밍 Programming
JAX에서 값을 쓰다가 어려운 부분이 있으면 대체로 numpy에 있는 함수들과 비슷한 것이 많아서 찾기 쉽다. 1. NaN 값 찾기 import jax.numpy as jnp a = jnp.array([jnp.nan,1,0,jnp.nan]) x = jnp.isnan(a) print(x) >> [ True False False True] 2. +-Inf 값 찾기 import jax.numpy as jnp a = jnp.array([jnp.inf,1,0,-jnp.inf]) x = jnp.isinf(a) print(x) >> [ True False False True] 1,2번 항목을 보면 boolean array으로 나오기 때문에 위 코드에서 'a'라는 array에 대해 indexing하면 NaN값을 추출할..
[JAX] 학습한 모델 저장 및 로드
·
연구 Research/인공지능 Artificial Intelligent
기존에 자주 사용되면 파이토치나 텐서플로에서도 그렇듯이 학습한 모델을 저장하는 것은 필수이다. 나중에 다시 결과를 출력해야하거나 Transfer learning 등에 활용해야하기 때문이다. 이 글에서는 JAX 모델을 학습한 다음, 저장하고 다시 로드하는 방법에 대해서 다룬다. 1. 학습 후 저장 이 글에서는 model이라는 class 안에 optimizers를 정의하고 그 안에 있는 loss 등의 함수로 학습을 하고 있었다. class 안에 self.opt_init, \ self.opt_update, \ self.get_params = optimizers.adam(optimizers.exponential_decay(1e-3, decay_steps=1000, decay_rate=0.9)) self.opt_..
[JAX] 메모리 부족 문제 해결
·
연구 Research/인공지능 Artificial Intelligent
JAX를 쓰다가 너무 많은 양의 데이터를 쓰다보니 메모리 부족(OOM: Out of memory) 현상을 겪었다. 근본적인 해결책은 달리 없다. 데이터가 너무 많아서 생기는 문제이니 데이터 양을 줄이던가 아니면 병렬 컴퓨팅을 하는 방법이 있다. 병렬 컴퓨팅을 간단한 코드에서는 실행해보았는데 큰 네트워크에서는 해본 적이 없다. 일단 임시방편으로는 다음과 같다. 메모리 부족이 쉽게 발생할 수 있는 이유는 JAX에서 처음에 import를 할 때 대부분의 메모리를 미리 할당해놓기 때문이다. 따라서 이 preallocation을 막거나 줄이면 도움이 된다. 1. Preallocation 중단 XLA_PYTHON_CLIENT_PREALLOCATE=false # 구체적으로는 다음과 같이 구현한다. import os ..
[JAX] device 확인, default device 설정
·
연구 Research/인공지능 Artificial Intelligent
JAX에서 사용 가능한 device를 찾는 방법은 다음과 같다. import jax jax.devices() >> [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0), GpuDevice(id=2, process_index=0), GpuDevice(id=3, process_index=0)] 여기에서는 device가 4개이기 때문에 하나를 지정해서 쓰고 싶을 수 있다. 방법을 찾아보니 document에서는 jax.default_device = jax.devices("gpu")[2] # default로 세번째 gpu를 쓰고 싶은 경우 다음과 같이 쓰면 된다고 했지만 실제로는 적용이 되지 않았다. 좀 더 확실한 방법으로는, JAX를 impor..
[JAX] 학습 중 NaN 값이 나올 때 찾는 방법
·
연구 Research/인공지능 Artificial Intelligent
JAX로 학습하는 도중에 NAN값이 나와서 어디부터 원인인지 찾기가 어려웠다. 이럴 때는 아래 코드를 추가하면 된다. from jax.config import config config.update("jax_debug_nans", True) 이렇게 할 경우에 NAN이 발생하는 즉시 어떤 코드에서 문제가 발생하는지를 알려주고 코드가 종료된다.
[JAX] 기본 Neural Networks 모델
·
연구 Research/인공지능 Artificial Intelligent
가장 simple하게 신경망을 구성하는 방법에 대해서 저장해놓은 글이다.차츰 업데이트 할 예정 1. 기본 학습 코드import jaximport jax.numpy as jnpfrom jax import grad, jit, vmapfrom jax import random# Define a simple neural network modeldef init_params(layer_sizes, key): params = [] for i in range(1, len(layer_sizes)): key, subkey = random.split(key) w = random.normal(subkey, (layer_sizes[i-1], layer_sizes[i])) b = j..
[JAX] Cholesky decomposition error 230318 기준
·
연구 Research/인공지능 Artificial Intelligent
현재 기준(230318)으로 cholesky decomposition을 사용할 때 행렬 크기가 50정도 넘어가면 nan을 출력하는 오류가 있다. jnp.linalg.cholesky(K) jax.random.multivariate_normal(subkeys[0], np.zeros((N_samples,)), K) 이 때문에 cholesky decomposition을 쓰는 다른 함수들도 영향을 받았는데 jax.random.multivariate_normal의 경우에도 랜덤하게 추출하는 과정에서 cholesky decomposition을 쓴다. cholesky decomposition은 어쩔 수 없을 것 같고 jax.random.multivariate_normal(subkeys[0], np.zeros((N_sa..