일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | 6 | 7 |
8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | 17 | 18 | 19 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 30 | 31 |
- 텝스공부
- Dear abby
- Python
- 논문작성법
- obsidian
- Julia
- ChatGPT
- IEEE
- pytorch
- 수치해석
- Numerical Analysis
- Zotero
- 인공지능
- Linear algebra
- Statics
- 딥러닝
- 생산성
- teps
- 논문작성
- JAX
- 우분투
- MATLAB
- 에러기록
- 고체역학
- 수식삽입
- WOX
- LaTeX
- 텝스
- 옵시디언
- matplotlib
- Today
- Total
목록연구 Research (93)
뛰는 놈 위에 나는 공대생
데이터 결과를 보고나서 이를 저장하기 위해 일일이 옮겨적지 않고 dataframe으로 만든 다음에 csv로 export하는 방법이 훨씬 편하다. 간단하게 표현하면 아래와 같은 코드를 사용한다. result_name = (어쩌구저쩌구) filename = result_folder + result_name + '.csv' write_csv = pd.DataFrame( record_matrix ) # data_frame 생성 # 또는 write_csv = pd.DataFrame( record_matrix, columns = ['A','B','C'] ) write_csv.to_csv(filename) dataframe을 사용하면 그 dataframe에서의 데이터를 종합하여 보여주는 기능이 있다. result_m..
JAX를 쓰다가 너무 많은 양의 데이터를 쓰다보니 메모리 부족(OOM: Out of memory) 현상을 겪었다. 근본적인 해결책은 달리 없다. 데이터가 너무 많아서 생기는 문제이니 데이터 양을 줄이던가 아니면 병렬 컴퓨팅을 하는 방법이 있다. 병렬 컴퓨팅을 간단한 코드에서는 실행해보았는데 큰 네트워크에서는 해본 적이 없다. 일단 임시방편으로는 다음과 같다. 메모리 부족이 쉽게 발생할 수 있는 이유는 JAX에서 처음에 import를 할 때 대부분의 메모리를 미리 할당해놓기 때문이다. 따라서 이 preallocation을 막거나 줄이면 도움이 된다. 1. Preallocation 중단 XLA_PYTHON_CLIENT_PREALLOCATE=false # 구체적으로는 다음과 같이 구현한다. import os ..
이번에 리눅스 환경에서 CUDA, cuDNN을 설치하면서 있었던 시행착오를 기록하는 글이다. 기본적으로 Linux 환경에서 설치하는 방법은 매뉴얼에 잘 나와있어서 이 링크를 참고하면 되기는 하는데 디테일하게는 고민할 부분들이 있다. 이 글에서 주의할 점은 1) 나는 이미 CUDA, cuDNN을 다른 사람이 설치해놓은 버전이 있었다. (그 버전들은 root에 설치되어 있었다.) 2) 추가적으로 CUDA 다른 버전을 쓰고 싶어서 설치하기 시작했다. 3) 우분투 user 중에 나의 계정이 있지만 내 계정에 local하게 설치하는 것이 아니라 모든 유저가 쓸 수 있도록 설치하였다. 나만 쓰고 싶으면 내 home directory에서 시작되는 경로에 설치하면 된다. 공용 서버컴의 경우에는 root 계정에 이미 설..
JAX에서 사용 가능한 device를 찾는 방법은 다음과 같다. import jax jax.devices() >> [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0), GpuDevice(id=2, process_index=0), GpuDevice(id=3, process_index=0)] 여기에서는 device가 4개이기 때문에 하나를 지정해서 쓰고 싶을 수 있다. 방법을 찾아보니 document에서는 jax.default_device = jax.devices("gpu")[2] # default로 세번째 gpu를 쓰고 싶은 경우 다음과 같이 쓰면 된다고 했지만 실제로는 적용이 되지 않았다. 좀 더 확실한 방법으로는, JAX를 impor..
이 글은 JAX 버전 맞추느라 여러 🐶고생한 경험을 바탕으로 작성하였다. 0. 요구 버전에 대한 이해 JAX는 설치할 때 요구하는 버전이 있다. 개별 gpu에 따라도 달라져서 까다롭긴한데 JAX currently ships three CUDA wheel variants: CUDA 12.0 and CuDNN 8.8. CUDA 11.8 and CuDNN 8.6. CUDA 11.4 and CuDNN 8.2. This wheel is deprecated and will be discontinued with jax 0.4.8. 위의 세 버전이 가능하다고 하는데, 이는 최신 JAX 버전(230412 기준)에 따른 것이다. CUDA 11.4라고 적혀있는 경우에는 11.4 이상이면서 CuDNN 8.2 이상이면 된다. ..
JAX로 학습하는 도중에 NAN값이 나와서 어디부터 원인인지 찾기가 어려웠다. 이럴 때는 아래 코드를 추가하면 된다. from jax.config import config config.update("jax_debug_nans", True) 이렇게 할 경우에 NAN이 발생하는 즉시 어떤 코드에서 문제가 발생하는지를 알려주고 코드가 종료된다.
Gaussian process에서 사용하는 커널 종류는 다양할 수 있지만 여기서는 Radial Basis Fuction을 이용해서 gaussian process 샘플들을 구하고 이에 대한 관찰을 시각화하는 방법에 대해서 이야기한다. RBF 함수는 Paris Perdikaris 교수님의 수업자료를 참고하였다. $k(x_1,x_2)=\eta \exp\left( \dfrac{(x_1 -x_2)^{2}}{2l^2}\right)$ 커널함수가 이렇게 설정되어 있을 때 우리가 조절할 수 있는 파라미터는 scale factor인 $\eta$와 length인 $l$이다. 개념적으로 생각하였을 때 random process인 gaussian process는 $\mathbf{x}~\mathcal{N}(\mathbf{0}, ..
가장 simple하게 신경망을 구성하는 방법에 대해서 저장해놓은 글이다.차츰 업데이트 할 예정 1. 기본 학습 코드import jaximport jax.numpy as jnpfrom jax import grad, jit, vmapfrom jax import random# Define a simple neural network modeldef init_params(layer_sizes, key): params = [] for i in range(1, len(layer_sizes)): key, subkey = random.split(key) w = random.normal(subkey, (layer_sizes[i-1], layer_sizes[i])) b = j..
본 글에서는 JAX로 미분값을 구하는 방법에 대해서 다룬다. JAX에서는 미분값을 구하기 위해 grad, jacfwd, jacrev를 제공하기 때문에 몇 가지 예제를 통해서 익숙해지고자 한다. 일단 크게 scalar-valued function과 vector-valued function으로 나누고, 각 function이 한 개의 변수에만 의존하는지, 또는 두 개 이상의 변수에만 의존하는지를 따진다. 예제코드는 유튜브 튜토리얼 + JAX 매뉴얼을 참고하였다. 1. Scalar-valued function일 때 Gradient는 scalar-valued univariate function에 대한 기울기 Jacobian은 vector-valued or scalar-valued multivariate funct..
현재 기준(230318)으로 cholesky decomposition을 사용할 때 행렬 크기가 50정도 넘어가면 nan을 출력하는 오류가 있다. jnp.linalg.cholesky(K) jax.random.multivariate_normal(subkeys[0], np.zeros((N_samples,)), K) 이 때문에 cholesky decomposition을 쓰는 다른 함수들도 영향을 받았는데 jax.random.multivariate_normal의 경우에도 랜덤하게 추출하는 과정에서 cholesky decomposition을 쓴다. cholesky decomposition은 어쩔 수 없을 것 같고 jax.random.multivariate_normal(subkeys[0], np.zeros((N_sa..