Notice
Recent Posts
Recent Comments
Link
관리 메뉴

뛰는 놈 위에 나는 공대생

[JAX] JAX와 Torch, CUDA, cudnn 버전 맞추기 본문

연구 Research/인공지능 Artificial Intelligent

[JAX] JAX와 Torch, CUDA, cudnn 버전 맞추기

보통의공대생 2023. 4. 12. 19:04

이 글은 JAX 버전 맞추느라 여러 🐶고생한 경험을 바탕으로 작성하였다.

 

 

0. 요구 버전에 대한 이해

 

JAX는 설치할 때 요구하는 버전이 있다. 개별 gpu에 따라도 달라져서 까다롭긴한데

 

JAX currently ships three CUDA wheel variants:
CUDA 12.0 and CuDNN 8.8.
CUDA 11.8 and CuDNN 8.6.
CUDA 11.4 and CuDNN 8.2. This wheel is deprecated and will be discontinued with jax 0.4.8.

 

위의 세 버전이 가능하다고 하는데, 이는 최신 JAX 버전(230412 기준)에 따른 것이다.

CUDA 11.4라고 적혀있는 경우에는 11.4 이상이면서 CuDNN 8.2 이상이면 된다.

 

아래 그림은 각 라이브러리 버전에 대한 설명이다.

 

 

클릭하면 확대

여기서 내가 어려움을 겪었던 것은 Torch와 JAX, 그리고 CUDA & cudnn 버전이 잘 안 맞으면서 발생한 문제이다.

가능하면 모든 CUDA 버전을 설치하고 cudnn은 CUDA 버전과 호환되는 것들 중에 가장 높은 것을 설치해주는 게 좋다고 생각한다.

 

1. 문제 상황

 

원래는 윈도우에서 JAX를 돌리다가 이제 리눅스에서 돌리고 싶어서 환경을 옮겼다.

윈도우에서 JAX를 설치하는 법은 다음 글에서 확인할 수 있다.

 

그런데 아래와 같은 오류를 발견하였다.

 

E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:417] Loaded runtime CuDNN library: 8.5.0 but source was compiled with: 8.6.0.  CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library.  If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.

 

 

이 오류는 JAX가 cudnn 버전을 8.6.0을 원하는데 실제로 사용되는 것은 8.5.0이라서 문제가 있다고 말해주고 있다.

내가 CUDA를 설치할 때 cudnn 버전을 8.8.0으로 사용했기 때문에 문제가 없을 줄 알았지만 파이토치가 1.13.0에 CUDA 11.7 - cudnn 8.5.0 을 사용하기 때문에 발생하는 문제였다.

그러나 파이토치에서 cudnn 8.5.0 이상으로 쓰려면 파이토치 자체를 2.0 이상으로 써야했다. (이미 파이토치 2.0 이상 - CUDA 12.1 cudnn 8.8.1 을 해놓은 상태여서 이렇게 해결할 수도 있다.) 그런데 나는 파이토치, CUDA, cudnn 버전을 그대로 두고 싶었다.

 

그래서 일단 JAX/jaxlib을 uninstall하면

 

이렇게 jaxlib을 설치할 때 필요한 CUDA, cudnn 버전이 나오는 것을 확인할 수 있다. 문제는 JAX를 설치할 때는 어떤 cudnn 버전 이상이 필요한지 알 수가 없다. (혹시 알 수 있는 방법이 있다면 알려주세요..)

 

당장 드는 생각은 jaxlib 하위 버전을 설치해보면서 저 글을 확인하는 수밖에 없을 것 같다.

jax와 jaxlib의 change log을 작성한 페이지도 있지만 안타깝게도 호환되는 cudnn에 대한 내용은 거의 없다.

 

 

 

2. 특정 jax/jaxlib 버전 설치방법

 

그래서 그냥 훨씬 옛날 버전인 jax 0.3.1 & jaxlib 0.3.0을 쓰기로 했다.

 

예전 버전은 아래 사이트에서 확인할 수 있다.

 

(CUDA 말고 CPU 버전)

https://storage.googleapis.com/jax-releases/jax_releases.html

 

https://storage.googleapis.com/jax-releases/jax_releases.html

 

storage.googleapis.com

 

(CUDA 버전)

https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

 

https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

 

storage.googleapis.com

 

위 사이트에 들어가면 

 

 

 

이렇게 다운받을 수 있는 wheel이 있는데 여기에서 사용할 cuda 버전과 cudnn 버전에 대한 힌트를 얻을 수 있다.

예를 들면,

cudnn88 -> cudnn 8.8.x

cuda12 -> cuda 12.x

이렇게 이해할 수 있다.

 

따라서 아래와 같이 작성해서 설치한다.

pip install jaxlib==<버전> -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install jax==<버전> -f https://storage.googleapis.com/jax-releases/jax_releases.html

# CUDA 버전
pip install jaxlib==<버전> -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install jax==<버전> -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

 

 

3. 제대로 됐는지 확인 절차

 

1) GPU 사용 여부 확인

 

GPU를 사용하는지 확인하는 코드는 다음과 같다.

from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

 

 

다음 코드를 입력할 때 GPU라고 나와야 제대로 CUDA와 cudnn을 인식한 것이라 볼 수 있다.

보통은 버전 호환문제이고, 인식이 안되는 또다른 경우로는 환경변수에 cuda 경로가 없어서이다.

또한 위에서 jax 중에서도 cuda 버전으로 설치하지 않으면 당연히 GPU를 인식하지 못한다.

 

2) PyTorch가 사용하는 CUDA, cudnn 확인

 

import torch
print("cudnn version:{}".format(torch.backends.cudnn.version()))
print("cuda version: {}".format(torch.version.cuda))

내가 의도한 cuda와 cudnn 버전이 사용되는지 확인한다.

 


참고자료

 

https://github.com/google/jax

Comments