JAX 사용시 발생할 수 있는 문제점이다.
JAX의 최신버전은 가장 최신의 CUDA, cuDNN 버전을 요구하기 때문에 다음과 같은 에러를 만나게 된다.
E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:433] Loaded runtime CuDNN library: 8.8.0 but source was compiled with: 8.9.1. CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgr
CUDA12.1 버전에 cuDNN 8.9.1버전(현재 기준으로 가장 최신버전은 8.9.2인데 보통 요구하는 cuDNN보다 더 높은 버전을 맞춰주면 괜찮다)을 맞춰줘야하는데 cuDNN 8.8.0버전을 쓰고 있어서 문제가 되었다. cuDNN을 다시 다운받아서 CUDA 파일에 옮겨넣어주면 된다. 자세한 방법은 다음 글을 참고한다.
이렇게 한 다음에 내가 원하는 CUDA가 있는 폴더에서 cudnn_version.h 파일을 열어보면 버전을 확인할 수 있다.
cat /usr/local/cuda/include/cudnn_version.h | grep CUDNN_MAJOR -A 2
#define CUDNN_MAJOR 8
#define CUDNN_MINOR 9
#define CUDNN_PATCHLEVEL 2
--
#define CUDNN_VERSION (CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL)
/* cannot use constexpr here since this is a C-only file */
8.9.2로 제대로 되어있음을 확인하였다.
그 다음에 다시 시도를 했는데 아래와 같은 오류를 봤다.
2023-08-02 13:02:50.302414: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:445] Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED 2023-08-02 13:02:50.302471: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:449] Memory usage: 3735552 bytes free, 25438126080 bytes total. 2023-08-02 13:02:50.302500: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:459] Possibly insufficient driver version: 530.30.2
이 오류는 driver 버전 문제일 수도 있다고 했지만
그래서 살펴보니 JAX 설치할 때 525.60.13이므로 문제가 없었다.
There are two ways to install JAX with NVIDIA GPU support: using CUDA and CUDNN installed from pip wheels, and using a self-installed CUDA/CUDNN. We recommend installing CUDA and CUDNN using the pip wheels, since it is much easier!
JAX supports NVIDIA GPUs that have SM version 5.2 (Maxwell) or newer. Note that Kepler-series GPUs are no longer supported by JAX since NVIDIA has dropped support for Kepler GPUs in its software.
You must first install the NVIDIA driver. We recommend installing the newest driver available from NVIDIA, but the driver must be version >= 525.60.13 for CUDA 12 and >= 450.80.02 for CUDA 11 on Linux. If you need to use an newer CUDA toolkit with an older driver, for example on a cluster where you cannot update the NVIDIA driver easily, you may be able to use the CUDA forward compatibility packages that NVIDIA provides for this purpose.
결론
내가 가상환경을 많이 만들고 그 안에서 JAX 버전이 여러개가 설치가 되었었다.
그리고 CUDA 역시 11.2부터 12.1까지 다양하게 보유하고 있었는데 모든 경로를 환경변수로 추가를 해둔 상태였다.
비록 JAX 버전을 CUDA12에 맞게 설치했다해도 11.8버전 경로가 먼저 있기 때문에 그 안에서 cuDNN을 찾다가 자꾸 cuDNN 8.8.0을 만나서 문제가 되었던 것이었다.
아래 그림처럼 리눅스 계정 안에 가상환경1, 가상환경2가 있는데
환경변수를 CUDA 11.8 다음에 CUDA 12.1이 나오도록 하였다. 그랬더니 가상환경1이 문제를 일으킨 것인데
환경변수의 경로 순서를 12.1이 먼저 나오게 하니까 해결이 되었다.
'프로그래밍 Programming' 카테고리의 다른 글
[JAX] optax에서 learning rate 확인하는 방법 (0) | 2023.08.23 |
---|---|
[JAX] NaN, Inf 값 처리 및 조건에 맞는 요소 찾기 (0) | 2023.08.13 |
[git blog] jekyll 테마 적용하면서 발생한 에러들 (0) | 2023.05.10 |
[JupyterLab] 코드 줄 번호 default 표시, 폰트 사이즈, family 변경 (0) | 2023.01.02 |
[에러기록] RuntimeError: Expected a 'cuda' device type for generator but found 'cpu' (0) | 2022.12.30 |