[JAX] JAX 기반 Neural ODE 라이브러리 : diffrax
·
연구 Research/인공지능 Artificial Intelligent
Neural ODE를 구현해놓은 코드는 torchdiffeq인데 학습이 너무 느리다는 생각이 들었다. 여러가지를 테스트해봐야 하는 입장에서 아무리 좋은 GPU를 써도 코드가 뒷받침되지 않으면 학습하는 데 시간이 오래 걸린다. 최근 JAX가 이런 측면에서 효과적이라는 것을 알아서 JAX 기반의 Neural ODE 코드를 찾아보았다. https://docs.kidger.site/diffrax/ Diffrax Diffrax in a nutshell Diffrax is a JAX-based library providing numerical differential equation solvers. Features include: ODE/SDE/CDE (ordinary/stochastic/controlled) sol..
[JAX] 학습한 모델 저장 및 로드
·
연구 Research/인공지능 Artificial Intelligent
기존에 자주 사용되면 파이토치나 텐서플로에서도 그렇듯이 학습한 모델을 저장하는 것은 필수이다. 나중에 다시 결과를 출력해야하거나 Transfer learning 등에 활용해야하기 때문이다. 이 글에서는 JAX 모델을 학습한 다음, 저장하고 다시 로드하는 방법에 대해서 다룬다. 1. 학습 후 저장 이 글에서는 model이라는 class 안에 optimizers를 정의하고 그 안에 있는 loss 등의 함수로 학습을 하고 있었다. class 안에 self.opt_init, \ self.opt_update, \ self.get_params = optimizers.adam(optimizers.exponential_decay(1e-3, decay_steps=1000, decay_rate=0.9)) self.opt_..
[JAX] 병렬컴퓨팅 예제 - jax.pmap으로 신경망 학습 예제
·
연구 Research/인공지능 Artificial Intelligent
JAX를 통해 병렬로 뉴럴 네트워크를 학습하는 예제를 고민하였다.JAX에서 제공해주는 예제도 있지만 이는 아주 심플한 선형 모델의 파라미터를 regression하는 문제이기 때문에 실제 뉴럴 네트워크 모델과는 괴리가 좀 있어서 직접 예제를 만들었다. JAX 0.3.1. 버전 1. JAX에서 제공하는 예제 import jaxjax.devices()>> [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0), GpuDevice(id=2, process_index=0), GpuDevice(id=3, process_index=0)] 필자는 gpu 4개를 가지고 병렬 컴퓨팅을 사용했다. import numpy as npimport jax.numpy..
[matplotlib] x,y축 format 지정하는 방법
·
연구 Research/데이터과학 Data Science
matplotlib에서 log scale그래프를 그리다가 다음과 같이 y축 숫자표기가 너무 크다는 것을 발견하고 이를 수정하기 위한 코드를 작성하였다. 여러 방법을 찾아보긴 했는데 내가 느끼기에 가장 간단하고 범용성이 높은 방법은 다음과 같다. 1. axes 인스턴스 필요 대부분의 matplotlib 그림에서 고급 기능을 쓰기 위해서는 axes 인스턴스를 필요로 한다. 이 axes는 내가 그리고자 하는 figure에 할당된 class인데 그 내부에서 구체적으로 설정하는 매서드가 담겨있어서 이것에 접근해야한다. plt.plot(num_history, train_mse_history) plt.ylabel('MSE') plt.xlabel('epoch') plt.yscale('symlog') ax = plt.g..
[인공지능] 딥러닝, 머신러닝에서 uncertainty/error 개념
·
연구 Research/인공지능 Artificial Intelligent
머신러닝, 딥러닝을 공부하다보면 다양한 에러에 대해서 배우게 된다.이 글에서는 이러한 용어들의 혼동을 막고자 종합적으로 정리하는 글이다. 1. Model bias, Estimation bias, Estimation variance  그 중에 하나가 model bias, estimation bias, estimation variance이다.    위 그림을 보면 이 개념을 설명하는 결과라고 볼 수 있다. Model bias는 실제 값과 내가 모델로 만든 값 중에서 가장 실제값을 가깝게 반영하는 모델에서 발생하는 바이어스이다. 즉, 내가 정한 모델은 아무리 최적화를 시켜도 truth 값과 완벽하게 일치하지 않을 수 있다. 예를 들어 비선형 모델인데 내가 아무리 선형 모델로 fitting을 하려고 해도 모델 ..
[Matplotlib] 3D scatter plot 그리는 코드
·
연구 Research/데이터과학 Data Science
matplotlib에서 3D scatter plot을 그리는 방법 3D plot을 그리기 위한 코드는 여러 방식이 있을 텐데 아래 방식이 가장 스탠다드인 것 같아서 이렇게 사용한다. max_value = np.amax( np.abs(data) ) # max값으로 축 제한 fig = plt.figure(constrained_layout=True) ax = fig.add_subplot(projection='3d') ax.scatter(data[:,0], data[:,1], data[:,2], marker='o', color ='r', alpha=1.0) ax.set_xlim([-max_value, max_value]) ax.set_ylim([-max_value, max_value]) ax.set_zlim([..
[데이터과학] Pandas에서 dataframe 생성 및 export
·
연구 Research/데이터과학 Data Science
데이터 결과를 보고나서 이를 저장하기 위해 일일이 옮겨적지 않고 dataframe으로 만든 다음에 csv로 export하는 방법이 훨씬 편하다. 간단하게 표현하면 아래와 같은 코드를 사용한다. result_name = (어쩌구저쩌구) filename = result_folder + result_name + '.csv' write_csv = pd.DataFrame( record_matrix ) # data_frame 생성 # 또는 write_csv = pd.DataFrame( record_matrix, columns = ['A','B','C'] ) write_csv.to_csv(filename) dataframe을 사용하면 그 dataframe에서의 데이터를 종합하여 보여주는 기능이 있다. result_m..
[JAX] 메모리 부족 문제 해결
·
연구 Research/인공지능 Artificial Intelligent
JAX를 쓰다가 너무 많은 양의 데이터를 쓰다보니 메모리 부족(OOM: Out of memory) 현상을 겪었다. 근본적인 해결책은 달리 없다. 데이터가 너무 많아서 생기는 문제이니 데이터 양을 줄이던가 아니면 병렬 컴퓨팅을 하는 방법이 있다. 병렬 컴퓨팅을 간단한 코드에서는 실행해보았는데 큰 네트워크에서는 해본 적이 없다. 일단 임시방편으로는 다음과 같다. 메모리 부족이 쉽게 발생할 수 있는 이유는 JAX에서 처음에 import를 할 때 대부분의 메모리를 미리 할당해놓기 때문이다. 따라서 이 preallocation을 막거나 줄이면 도움이 된다. 1. Preallocation 중단 XLA_PYTHON_CLIENT_PREALLOCATE=false # 구체적으로는 다음과 같이 구현한다. import os ..
[인공지능] Ubuntu 18.04에서 CUDA, CuDNN 설치
·
연구 Research/인공지능 Artificial Intelligent
이번에 리눅스 환경에서 CUDA, cuDNN을 설치하면서 있었던 시행착오를 기록하는 글이다. 기본적으로 Linux 환경에서 설치하는 방법은 매뉴얼에 잘 나와있어서 이 링크를 참고하면 되기는 하는데 디테일하게는 고민할 부분들이 있다. 이 글에서 주의할 점은 1) 나는 이미 CUDA, cuDNN을 다른 사람이 설치해놓은 버전이 있었다. (그 버전들은 root에 설치되어 있었다.) 2) 추가적으로 CUDA 다른 버전을 쓰고 싶어서 설치하기 시작했다. 3) 우분투 user 중에 나의 계정이 있지만 내 계정에 local하게 설치하는 것이 아니라 모든 유저가 쓸 수 있도록 설치하였다. 나만 쓰고 싶으면 내 home directory에서 시작되는 경로에 설치하면 된다. 공용 서버컴의 경우에는 root 계정에 이미 설..
[JAX] device 확인, default device 설정
·
연구 Research/인공지능 Artificial Intelligent
JAX에서 사용 가능한 device를 찾는 방법은 다음과 같다. import jax jax.devices() >> [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0), GpuDevice(id=2, process_index=0), GpuDevice(id=3, process_index=0)] 여기에서는 device가 4개이기 때문에 하나를 지정해서 쓰고 싶을 수 있다. 방법을 찾아보니 document에서는 jax.default_device = jax.devices("gpu")[2] # default로 세번째 gpu를 쓰고 싶은 경우 다음과 같이 쓰면 된다고 했지만 실제로는 적용이 되지 않았다. 좀 더 확실한 방법으로는, JAX를 impor..