[에러기록] ImportError: cannot import name 'index_update' from 'jax.ops'
·
기타
from jax.ops import index_update, index = index_update(, index[], ) % example u = index_update(u, index[0, :], g(t)) 다음 코드를 실행할 때 발생하는 문제이다. JAX 0.3.2부터는 jax.opt.index_update, jax.opt.index가 사라졌기 때문에 이 기능을 쓰기 위해서는 jax와 jaxlib을 0.3.2 버전 전으로 돌려야 한다. 혹은 저 위의 기능은 특정 인덱스에 배열 값을 바꾸는 코드이기 때문에 x = x.at[idx].set(y) 와 같이 특정 인덱스에 y라는 값으로 바꾸는 코드로 바꿔줄 수 있다.