Notice
Recent Posts
Recent Comments
Link
관리 메뉴

뛰는 놈 위에 나는 공대생

[JAX] device 확인, default device 설정 본문

연구 Research/인공지능 Artificial Intelligent

[JAX] device 확인, default device 설정

보통의공대생 2023. 4. 13. 16:38

JAX에서 사용 가능한 device를 찾는 방법은 다음과 같다.

import jax
jax.devices()
>> [GpuDevice(id=0, process_index=0),
 GpuDevice(id=1, process_index=0),
 GpuDevice(id=2, process_index=0),
 GpuDevice(id=3, process_index=0)]

 

 

여기에서는 device가 4개이기 때문에 하나를 지정해서 쓰고 싶을 수 있다.

방법을 찾아보니 document에서는

jax.default_device = jax.devices("gpu")[2] # default로 세번째 gpu를 쓰고 싶은 경우

다음과 같이 쓰면 된다고 했지만 실제로는 적용이 되지 않았다.

 

좀 더 확실한 방법으로는, JAX를 import하기 전에

import os
os.environ['CUDA_VISIBLE_DEVICES'] = "2"

이렇게 인식할 수 있는 devices를 지정한다.

 

import jax
nmp = jax.numpy.ones(4)
print(nmp.device())
print(jax.devices())
>> gpu:0
[GpuDevice(id=0, process_index=0)]

그 다음에 배열을 만들어서 확인을 했는데 device를 조회해보면 id=0 밖에 인식이 되지 않는다.

gpu:0이지만 실제로는 내가 지정한 세 번째 gpu이다.

이를 확인하는 방법은 nvidia-smi를 터미널에 입력하는 것이다.

 

 

다음과 같이 JAX는 처음에 import할 때 90퍼센트 가까이 preallocation하기 때문에 위와 같이 세 번째 gpu의 메모리가 차지된 것을 확인할 수 있다.

Comments