[JAX] 학습한 모델 저장 및 로드

2023. 6. 19. 17:44·연구 Research/인공지능 Artificial Intelligent

기존에 자주 사용되면 파이토치나 텐서플로에서도 그렇듯이 학습한 모델을 저장하는 것은 필수이다.

나중에 다시 결과를 출력해야하거나 Transfer learning 등에 활용해야하기 때문이다. 이 글에서는 JAX 모델을 학습한 다음, 저장하고 다시 로드하는 방법에 대해서 다룬다.

 

 

1. 학습 후 저장

 

이 글에서는 model이라는 class 안에 optimizers를 정의하고 그 안에 있는 loss 등의 함수로 학습을 하고 있었다.

class 안에

self.opt_init, \
self.opt_update, \
self.get_params = optimizers.adam(optimizers.exponential_decay(1e-3, 
                                                              decay_steps=1000, 
                                                              decay_rate=0.9))
self.opt_state = self.opt_init(params)

다음과 같이 JAX의 optimizer는 세 가지를 return하는데, opt_init은 학습할 파라미터를 처음에 초기화하는 것이다. 여기서 말하는 파라미터는 당연히 뉴럴네트워크의 weight와 bias이다.

그래서 맨 아래 코드를 보면 opt_init으로 초기화하여 self.opt_state에 저장한 다음에 opt_update로 파라미터를 업데이트하고, 이 파라미터를 가져올 때는 get_params를 써서 가져온다.

 

따라서 아래 코드를 보면 model.opt_state에 저장된 최종적으로 학습이 끝난 파라미터를 get_params로 가져오고 이를 ravel_pytree라는 함수로 pytree 형태로 만든다. 이를 npy 파일로 저장하는 것이다.

정확히는 ravel_pytree는 pytree of arrays를 1D array로 바꿔주는 함수이다. 이 형태로 변경해야 npy형태로 저장할 수 있기 때문이다.

from jax.flatten_util import ravel_pytree
# Save the trained model
flat_params, _  = ravel_pytree(model.get_params(model.opt_state))
np.save('save_model.npy', flat_params)

 

 

2. 다시 load 

save_model_path = 'save_model.npy'

layers = [50, 50, 50, 50, 50, 50]
model = NN(layers)
flat_params = np.load(save_model_path)
params = model.unravel_params(flat_params)
model.opt_state = model.opt_init(params)

 

그래서 다시 모델에 학습한 데이터를 집어넣으려면 다음과 같이 model 안에 있는 optimizer의 opt_init을 통해 파라미터를 초기화시켜주면 된다.

 

저작자표시 비영리 변경금지 (새창열림)

'연구 Research > 인공지능 Artificial Intelligent' 카테고리의 다른 글

Neural networks의 convergence, convexity에 대한 논문  (0) 2023.07.31
[JAX] JAX 기반 Neural ODE 라이브러리 : diffrax  (0) 2023.07.28
[JAX] 병렬컴퓨팅 예제 - jax.pmap으로 신경망 학습 예제  (0) 2023.06.13
[인공지능] 딥러닝, 머신러닝에서 uncertainty/error 개념  (0) 2023.05.16
[JAX] 메모리 부족 문제 해결  (1) 2023.04.26
'연구 Research/인공지능 Artificial Intelligent' 카테고리의 다른 글
  • Neural networks의 convergence, convexity에 대한 논문
  • [JAX] JAX 기반 Neural ODE 라이브러리 : diffrax
  • [JAX] 병렬컴퓨팅 예제 - jax.pmap으로 신경망 학습 예제
  • [인공지능] 딥러닝, 머신러닝에서 uncertainty/error 개념
보통의공대생
보통의공대생
수학,프로그래밍,기계항공우주 등 공부하는 기록들을 남깁니다.
  • 보통의공대생
    뛰는 놈 위에 나는 공대생
    보통의공대생
  • 전체
    오늘
    어제
    • 분류 전체보기 (460)
      • 공지 (1)
      • 영어 공부 English Study (40)
        • 텝스 TEPS (7)
        • 글 Article (21)
        • 영상 Video (10)
      • 연구 Research (99)
        • 최적화 Optimization (3)
        • 데이터과학 Data Science (7)
        • 인공지능 Artificial Intelligent (40)
        • 제어 Control (45)
      • 프로그래밍 Programming (103)
        • 매트랩 MATLAB (25)
        • 파이썬 Python (33)
        • 줄리아 Julia (2)
        • C++ (3)
        • 리눅스 우분투 Ubuntu (6)
      • 항공우주 Aeronautical engineeri.. (21)
        • 항법 Navigation (0)
        • 유도 Guidance (0)
      • 기계공학 Mechanical engineering (13)
        • 열역학 Thermodynamics (0)
        • 고체역학 Statics & Solid mechan.. (10)
        • 동역학 Dynamics (1)
        • 유체역학 Fluid Dynamics (0)
      • 수학 Mathematics (34)
        • 선형대수학 Linear Algebra (18)
        • 미분방정식 Differential Equation (3)
        • 확률및통계 Probability & Sta.. (2)
        • 미적분학 Calculus (1)
        • 복소해석학 Complex Analysis (5)
        • 실해석학 Real Analysis (0)
      • 수치해석 Numerical Analysis (21)
      • 확률 및 랜덤프로세스 Random process (2)
      • 추론 & 추정 이론 Estimation (3)
      • 기타 (26)
        • 설계 프로젝트 System Design (8)
        • 논문작성 Writing (55)
        • 세미나 Seminar (2)
        • 생산성 Productivity (3)
      • 유학 생활 Daily (7)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    Numerical Analysis
    수치해석
    MATLAB
    Zotero
    matplotlib
    텝스
    옵시디언
    teps
    딥러닝
    텝스공부
    우분투
    Python
    인공지능
    Statics
    논문작성법
    Dear abby
    논문작성
    생산성
    Julia
    JAX
    ChatGPT
    obsidian
    LaTeX
    고체역학
    서버
    에러기록
    Linear algebra
    WOX
    pytorch
    IEEE
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
보통의공대생
[JAX] 학습한 모델 저장 및 로드
상단으로

티스토리툴바