일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | |||||
3 | 4 | 5 | 6 | 7 | 8 | 9 |
10 | 11 | 12 | 13 | 14 | 15 | 16 |
17 | 18 | 19 | 20 | 21 | 22 | 23 |
24 | 25 | 26 | 27 | 28 | 29 | 30 |
- 에러기록
- Dear abby
- teps
- LaTeX
- 딥러닝
- JAX
- 고체역학
- Julia
- Linear algebra
- 생산성
- obsidian
- WOX
- Python
- Zotero
- IEEE
- Numerical Analysis
- 텝스공부
- 논문작성법
- pytorch
- 수치해석
- 우분투
- 인공지능
- 옵시디언
- 논문작성
- 텝스
- 수식삽입
- matplotlib
- ChatGPT
- Statics
- MATLAB
- Today
- Total
뛰는 놈 위에 나는 공대생
[JAX] JAX 설치 및 GPU 사용하기 본문
1. JAX 설치
JAX를 설치하는 방법에 대해서는 installment guide에 잘 나와있다.
다음 링크로 가면 jax 설치법에 대한 문서를 볼 수 있다.
필자는 GPU를 쓰고 싶었기 때문에 다음을 설치했다. conda를 보통 사용하기 때문에 이렇게 했다.
conda install jax cuda-nvcc -c conda-forge -c nvidia
이렇게 설치하고 나서 문제가 발생하였다.
2. JAX에서 GPU 사용하기
다음과 같은 warning code를 만났다.
WARNING - No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
우분투에서 가상환경을 만들고 여기서 JAX를 설치했는데 분명 installation에 나온 것처럼 따라갔는데도 아래와 같이 CPU를 인식하지 못했다. 참고로 Python 버전은 3.9, CUDA는 11.4를 쓰고 있는 상황이었으므로 GPU를 쓰기에는 문제가 없었다.
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
GPU가 있는데도 CPU로 출력된다. 무엇보다도 PyTorch에서는 문제없이 GPU를 사용할 수 있었던 경험으로 보아, CUDA와 연결이 잘 안된 것은 아니었다.
인터넷에서 검색해보니 다음과 같은 내용을 찾았다.
conda-forge 채널에서 설치한 cudatoolkit은 ptxas가 없기 때문에 cuda-nvcc를 nvidia channel로 설치해야한다는 것이었다. 그 컴퓨터에 있는 cuda는 내가 설치한 게 아니라서 conda-forge 채널로 설치한지도 몰랐는데 잘 안되는 것을 보니 그런 것으로 보였다.
Note the cudatoolkit distributed by conda-forge is missing ptxas, which JAX requires. You must therefore either install the cuda-nvcc package from the nvidia channel, or install CUDA on your machine separately so that ptxas is in your path. The channel order above is important (conda-forge before nvidia). We are working on simplifying this.
이것 때문에 cuda를 다 지우고 다시 시작하기는 너무 번거로워서 포기했다.
그래서 새로 가상환경을 파고 아래와 같이 PyPI 채널로 JAX를 다운받았고, GPU를 인식하는 것을 확인했다.
pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.6 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# 버전 지정하고 싶을 때
pip install --upgrade "jax[cuda]"==0.4.23 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
PyPI 채널과 conda-forge 채널을 섞어 쓰는 게 별로 안 좋아서 PyPI를 잘 안 쓴다. 그래서 구글을 열심히 뒤져서 사람들이 해보라는 거 다 해봤는데 다 안 통해서 결국 pip으로 설치한 것이다.
나와 비슷한 문제가 있다면 pip으로 설치해보는 것을 권한다..
더 구체적으로 cudnn과 cuda version을 지정할 수 있다.
다음 링크에서 cuda,cudnn버전을 확인한 다음에, 다음과 같이 jaxlib를 지정하고 jaxlib와 호환되는 jax 버전을 지정해서 설치해주면 된다. 보통은 jax와 jaxlib 버전 숫자를 맞춰주는 게 편하다.
pip install jaxlib==0.4.29+cuda12.cudnn91 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
JAX에 대한 Document : https://jax.readthedocs.io/en/latest/user_guides.html#user-guides
'프로그래밍 Programming > 파이썬 Python' 카테고리의 다른 글
[Python] MCMC Sampling library (0) | 2023.08.24 |
---|---|
[Jupyter notebook] 내가 설정한 주피터 노트북 테마 (0) | 2023.02.23 |
[Matplotlib] legend 그림 바깥에 배치/원하는 위치에 배치 (0) | 2023.01.14 |
[Matplotlib] Matplotlib 폰트 스타일 바꾸기 (0) | 2023.01.13 |
[PyTorch] 특정 조건에 맞는 텐서 출력/인덱싱 등 (2) | 2023.01.08 |