Notice
Recent Posts
Recent Comments
Link
관리 메뉴

뛰는 놈 위에 나는 공대생

[JAX] 학습한 모델 저장 및 로드 본문

연구 Research/인공지능 Artificial Intelligent

[JAX] 학습한 모델 저장 및 로드

보통의공대생 2023. 6. 19. 17:44

기존에 자주 사용되면 파이토치나 텐서플로에서도 그렇듯이 학습한 모델을 저장하는 것은 필수이다.

나중에 다시 결과를 출력해야하거나 Transfer learning 등에 활용해야하기 때문이다. 이 글에서는 JAX 모델을 학습한 다음, 저장하고 다시 로드하는 방법에 대해서 다룬다.

 

 

1. 학습 후 저장

 

이 글에서는 model이라는 class 안에 optimizers를 정의하고 그 안에 있는 loss 등의 함수로 학습을 하고 있었다.

class 안에

self.opt_init, \
self.opt_update, \
self.get_params = optimizers.adam(optimizers.exponential_decay(1e-3, 
                                                              decay_steps=1000, 
                                                              decay_rate=0.9))
self.opt_state = self.opt_init(params)

다음과 같이 JAX의 optimizer는 세 가지를 return하는데, opt_init은 학습할 파라미터를 처음에 초기화하는 것이다. 여기서 말하는 파라미터는 당연히 뉴럴네트워크의 weight와 bias이다.

그래서 맨 아래 코드를 보면 opt_init으로 초기화하여 self.opt_state에 저장한 다음에 opt_update로 파라미터를 업데이트하고, 이 파라미터를 가져올 때는 get_params를 써서 가져온다.

 

따라서 아래 코드를 보면 model.opt_state에 저장된 최종적으로 학습이 끝난 파라미터를 get_params로 가져오고 이를 ravel_pytree라는 함수로 pytree 형태로 만든다. 이를 npy 파일로 저장하는 것이다.

정확히는 ravel_pytree는 pytree of arrays를 1D array로 바꿔주는 함수이다. 이 형태로 변경해야 npy형태로 저장할 수 있기 때문이다.

from jax.flatten_util import ravel_pytree
# Save the trained model
flat_params, _  = ravel_pytree(model.get_params(model.opt_state))
np.save('save_model.npy', flat_params)

 

 

2. 다시 load 

save_model_path = 'save_model.npy'

layers = [50, 50, 50, 50, 50, 50]
model = NN(layers)
flat_params = np.load(save_model_path)
params = model.unravel_params(flat_params)
model.opt_state = model.opt_init(params)

 

그래서 다시 모델에 학습한 데이터를 집어넣으려면 다음과 같이 model 안에 있는 optimizer의 opt_init을 통해 파라미터를 초기화시켜주면 된다.

 

Comments