JAX가 버전에 따라 조금씩 달라지는 부분이 있어서 정리해놓는 글.
1) 배열 원소 업데이트 방식
# JAX 0.3 이전
input_data_test = index_update(input_data_test, index[i,:], input_data_tmp)
# JAX 0.4
input_data_test = input_data_test.at[i,:].set(input_data_tmp)
2) optimizers
0.3버전에서는 optimizer가 기본적으로 제공이 되는데 이 방식이 0.4에서는 바뀐다.