[에러기록] ImportError: cannot import name 'index_update' from 'jax.ops'
·
기타
from jax.ops import index_update, index = index_update(, index[], ) % example u = index_update(u, index[0, :], g(t)) 다음 코드를 실행할 때 발생하는 문제이다. JAX 0.3.2부터는 jax.opt.index_update, jax.opt.index가 사라졌기 때문에 이 기능을 쓰기 위해서는 jax와 jaxlib을 0.3.2 버전 전으로 돌려야 한다. 혹은 저 위의 기능은 특정 인덱스에 배열 값을 바꾸는 코드이기 때문에 x = x.at[idx].set(y) 와 같이 특정 인덱스에 y라는 값으로 바꾸는 코드로 바꿔줄 수 있다.
[JAX] JAX 설치 및 GPU 사용하기
·
프로그래밍 Programming/파이썬 Python
1. JAX 설치 JAX를 설치하는 방법에 대해서는 installment guide에 잘 나와있다.다음 링크로 가면 jax 설치법에 대한 문서를 볼 수 있다. 필자는 GPU를 쓰고 싶었기 때문에 다음을 설치했다. conda를 보통 사용하기 때문에 이렇게 했다. conda install jax cuda-nvcc -c conda-forge -c nvidia 이렇게 설치하고 나서 문제가 발생하였다. 2. JAX에서 GPU 사용하기다음과 같은 warning code를 만났다. WARNING - No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)  우분투에서 가상환경을 만들고 여기서 JAX를 설치했는..