Notice
Recent Posts
Recent Comments
Link
관리 메뉴

뛰는 놈 위에 나는 공대생

[JAX] Windows에서도 JAX 사용하기 본문

연구 Research/인공지능 Artificial Intelligent

[JAX] Windows에서도 JAX 사용하기

보통의공대생 2023. 2. 24. 15:20

JAX는 아직 리눅스에서밖에 사용이 안된다.

그래서 윈도우에서 돌릴 수 있는 방법을 찾아보았는데 최신 버전은 불가능하고 예전 버전은 가능하다. JAX가 아직 초기이다보니 버전마다 많이 바뀌어서 불편한 점이 있지만 일단 시도한 경험을 공유한다.

 

 

아래 링크를 들어가면 대략적인 instruction을 알 수 있다.

 

https://github.com/cloudhan/jax-windows-builder

 

GitHub - cloudhan/jax-windows-builder: A community supported Windows build for jax.

A community supported Windows build for jax. Contribute to cloudhan/jax-windows-builder development by creating an account on GitHub.

github.com

 

1. JAX 설치

 

위 링크에 들어가보면 pip으로 install 할 수 있는데 이는 jax만 설치하는 것이다. 필자의 경우에는 아래 있는 명령어를 사용해서 설치했다. 참고로 호환문제 때문에 저렇게만 쓰면 곤란하고 jax[cuda111]==버전 이렇게 지정해야 된다.

 

아래에서 설명할 jaxlib 설치할 때 보면 jaxlib이 0.3번대 밖에 없어서 jax 0.4이상 버전은 호환이 불가능하다.

 

다음과 같이 ==0.3.1처럼 버전을 지정해줄 수 있다.

 

pip install jax[cuda111]==0.3.1 -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver

 

2.  Jaxlib 설치

 

위와 같이 jax만 설치하면 안되고 jaxlib이 필요하다. 그래서 다음 링크에 들어가서 wheel을 다운 받아야 한다.

 

https://whls.blob.core.windows.net/unstable/index.html

 

https://whls.blob.core.windows.net/unstable/index.html

 

whls.blob.core.windows.net

위에서 보면 여러 개의 whl 파일이 있는데 본인이 설치하고 싶은 버전으로 다운 받는다.

 

일단 필자는 테스트용으로 0.3.7을 다운 받았다.

 

cp39가 무슨 뜻인지 몰라서 알아보니 다음과 같다.

These stand for the version of CPython (i.e. the Python official distribution you get from python.org) which the wheel files are built for.

본인이 설치하고자 하는 가상환경의 파이썬 버전을 맞춰주면 된다.

이걸 다운 받고 

 

# download jaxlib from https://whls.blob.core.windows.net/unstable/index.html
pip install <jaxlib_whl> # <jaxlib_whl>에 아까 다운받은 파일명을 넣어야 한다.
pip install jax
pip install jax[cuda111] -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver

위 코드를 실행시켜서 jaxlib까지 설치한다. jaxlib을 설치할 때는 당연히 경로가 그 파일(.whl)이 있는 경로에서 실행해야 한다.

아래의 jax 설치는 jax==0.x.x 과 같이 버전을 정해줘야한다.

 

 

3. 버전 맞춰주기

 

하지만 여기서 문제가 있는 게 우리가 설치한 jax는 버전을 특정하지 않으면 가장 최신인 0.4.4가 설치되고, 위의 jaxlib 0.3.7은 너무 옛날 것이라서 호환이 안된다. 따라서 jax를 다운그레이드 시키는 것을 권한다.

 

버전을 고려할 때는 개발 과정을 찾아봐야한다.

https://jax.readthedocs.io/en/latest/changelog.html

 

나의 경우에는 jax 0.3.1을 써야해서 jaxlib은 0.3.0을 설치해야했다.

 

 

 

 

Comments