기존에 자주 사용되면 파이토치나 텐서플로에서도 그렇듯이 학습한 모델을 저장하는 것은 필수이다.
나중에 다시 결과를 출력해야하거나 Transfer learning 등에 활용해야하기 때문이다. 이 글에서는 JAX 모델을 학습한 다음, 저장하고 다시 로드하는 방법에 대해서 다룬다.
1. 학습 후 저장
이 글에서는 model이라는 class 안에 optimizers를 정의하고 그 안에 있는 loss 등의 함수로 학습을 하고 있었다.
class 안에
self.opt_init, \
self.opt_update, \
self.get_params = optimizers.adam(optimizers.exponential_decay(1e-3,
decay_steps=1000,
decay_rate=0.9))
self.opt_state = self.opt_init(params)
다음과 같이 JAX의 optimizer는 세 가지를 return하는데, opt_init은 학습할 파라미터를 처음에 초기화하는 것이다. 여기서 말하는 파라미터는 당연히 뉴럴네트워크의 weight와 bias이다.
그래서 맨 아래 코드를 보면 opt_init으로 초기화하여 self.opt_state에 저장한 다음에 opt_update로 파라미터를 업데이트하고, 이 파라미터를 가져올 때는 get_params를 써서 가져온다.
따라서 아래 코드를 보면 model.opt_state에 저장된 최종적으로 학습이 끝난 파라미터를 get_params로 가져오고 이를 ravel_pytree라는 함수로 pytree 형태로 만든다. 이를 npy 파일로 저장하는 것이다.
정확히는 ravel_pytree는 pytree of arrays를 1D array로 바꿔주는 함수이다. 이 형태로 변경해야 npy형태로 저장할 수 있기 때문이다.
from jax.flatten_util import ravel_pytree
# Save the trained model
flat_params, _ = ravel_pytree(model.get_params(model.opt_state))
np.save('save_model.npy', flat_params)
2. 다시 load
save_model_path = 'save_model.npy'
layers = [50, 50, 50, 50, 50, 50]
model = NN(layers)
flat_params = np.load(save_model_path)
params = model.unravel_params(flat_params)
model.opt_state = model.opt_init(params)
그래서 다시 모델에 학습한 데이터를 집어넣으려면 다음과 같이 model 안에 있는 optimizer의 opt_init을 통해 파라미터를 초기화시켜주면 된다.
'연구 Research > 인공지능 Artificial Intelligent' 카테고리의 다른 글
Neural networks의 convergence, convexity에 대한 논문 (0) | 2023.07.31 |
---|---|
[JAX] JAX 기반 Neural ODE 라이브러리 : diffrax (0) | 2023.07.28 |
[JAX] 병렬컴퓨팅 예제 - jax.pmap으로 신경망 학습 예제 (0) | 2023.06.13 |
[인공지능] 딥러닝, 머신러닝에서 uncertainty/error 개념 (0) | 2023.05.16 |
[JAX] 메모리 부족 문제 해결 (1) | 2023.04.26 |