[PyTorch] 모델 저장/불러오기 및 모델 수정하기

2023. 1. 4. 22:29·연구 Research/인공지능 Artificial Intelligent

빠른 학습을 위하여 이미 학습된 파라미터를 가지고 오고 싶을 때가 있다.

그래서 이 글에서는 기본적인 모델 저장, 불어오기 뿐 아니라 내가 좀 더 커스터마이징 할 수 있는 방법에 대해서 고민했다.

 

1. 모델 저장

 

파이토치 매뉴얼에서는 다음과 같이 모델을 저장하라고 권한다.

PATH = '(모델이름).pth'
torch.save(modelA.state_dict(), PATH)

 

이렇게 해서 pth 파일을 저장할 수 있다. state_dict는 내부 상태 사전(internal state dictionary)으로서 학습된 모델의 매개변수를 저장한다.

modelA.state_dict()를 출력하면 다음과 같은 형태가 나온다.

 

 

위의 OrderedDict는 dictionary 자료형인데 파이썬에서 순서 정렬에 유리하게 만든 딕셔너리 자료형이라고 한다.

 

 

또는 아예 모델 형태를 포함하여 저장하는 방법도 있다. 이 경우에는 모델 자체를 저장한 것이기 때문에 나중에 불러올 떄 바로 model 변수에 넣어서 사용할 수도 있다.

torch.save(model, 'model.pth')

 

 

2. 모델 불러오기

 

모델의 가중치만 저장한 경우

model.load_state_dict(torch.load('model_weights.pth'))

 

모델 자체를 저장한 경우

model = torch.load('model.pth')

 

다음과 같이 load할 수 있다.

 

 

위의 내용은 튜토리얼에 잘 나와있기 때문에 쉽게 이해할 수 있을 것이다.

그런데 내가 다루고자 하는 본론은 아래 내용이다.

 

 

 

 

3. 내가 모델에 추가적인 변수가 있을 때

 

나의 경우에는 일반 신경망에서 추가적으로 parameter를 nn.Parameter와 model.register_parameter를 사용하여서 모델에 추가했다. 그런데 나중에 그 모델을 저장하고 쓸 때는 그 파라미터를 제외해서 load를 하고 싶다고 해보자.

 

만약 그 파라미터를 삭제하지 않고 내가 정한 모델에 집어넣으려고 하면 다음과 같은 에러를 보게 될 것이다.

나는 추가적인 'Xcp'라는 파라미터를 넣었기 때문에 아무 생각없이 가중치를 옮기면 위와 같은 문제가 생긴다.

 

그래서 저렇게 모델을 저장한 다음에 파라미터를 삭제하는 방법에 대해서 알아야한다. 이는 위의 OrderedDict에서 파라미터를 제외하는 것과 동일하다.

 

OrderedDict에서 키와 value를 삭제하는 방법은 딕셔너리와 같다.

del dct[key]
또는
d.pop(your_key)

다음과 같이 파라미터를 key에 넣어서 모델의 파라미터가 저장된 딕셔너리에서 추가 파라미터를 제외하면 끝.

 


 

참고문헌

 

https://tutorials.pytorch.kr/beginner/saving_loading_models.html#warmstart

 

모델 저장하기 & 불러오기

Author: Matthew Inkawhich, 번역: 박정환,. 이 문서에서는 PyTorch 모델을 저장하고 불러오는 다양한 방법을 제공합니다. 이 문서 전체를 다 읽는 것도 좋은 방법이지만, 필요한 사용 예의 코드만 참고하

tutorials.pytorch.kr

https://tutorials.pytorch.kr/beginner/basics/saveloadrun_tutorial.html

 

모델 저장하고 불러오기

파이토치(PyTorch) 기본 익히기|| 빠른 시작|| 텐서(Tensor)|| Dataset과 Dataloader|| 변형(Transform)|| 신경망 모델 구성하기|| Autograd|| 최적화(Optimization)|| 모델 저장하고 불러오기 이번 장에서는 저장하기나

tutorials.pytorch.kr

 

저작자표시 비영리 변경금지 (새창열림)

'연구 Research > 인공지능 Artificial Intelligent' 카테고리의 다른 글

[인공지능] Windows 기준 CUDA & cuDNN 최신 버전 설치 (CUDA 11.6 이상)  (0) 2023.02.07
[PyTorch] retain_graph = True라고 했음에도 backward 문제가 발생하는 경우  (0) 2023.01.05
[딥러닝] Backpropagation을 위한 Automatic differentiation 이론/코딩  (0) 2022.07.09
[인공지능] CUDA & cuDNN 설치하는 방법  (0) 2021.08.30
[머신러닝] Boosting method  (0) 2021.05.26
'연구 Research/인공지능 Artificial Intelligent' 카테고리의 다른 글
  • [인공지능] Windows 기준 CUDA & cuDNN 최신 버전 설치 (CUDA 11.6 이상)
  • [PyTorch] retain_graph = True라고 했음에도 backward 문제가 발생하는 경우
  • [딥러닝] Backpropagation을 위한 Automatic differentiation 이론/코딩
  • [인공지능] CUDA & cuDNN 설치하는 방법
보통의공대생
보통의공대생
수학,프로그래밍,기계항공우주 등 공부하는 기록들을 남깁니다.
  • 보통의공대생
    뛰는 놈 위에 나는 공대생
    보통의공대생
  • 전체
    오늘
    어제
    • 분류 전체보기 (460)
      • 공지 (1)
      • 영어 공부 English Study (40)
        • 텝스 TEPS (7)
        • 글 Article (21)
        • 영상 Video (10)
      • 연구 Research (99)
        • 최적화 Optimization (3)
        • 데이터과학 Data Science (7)
        • 인공지능 Artificial Intelligent (40)
        • 제어 Control (45)
      • 프로그래밍 Programming (103)
        • 매트랩 MATLAB (25)
        • 파이썬 Python (33)
        • 줄리아 Julia (2)
        • C++ (3)
        • 리눅스 우분투 Ubuntu (6)
      • 항공우주 Aeronautical engineeri.. (21)
        • 항법 Navigation (0)
        • 유도 Guidance (0)
      • 기계공학 Mechanical engineering (13)
        • 열역학 Thermodynamics (0)
        • 고체역학 Statics & Solid mechan.. (10)
        • 동역학 Dynamics (1)
        • 유체역학 Fluid Dynamics (0)
      • 수학 Mathematics (34)
        • 선형대수학 Linear Algebra (18)
        • 미분방정식 Differential Equation (3)
        • 확률및통계 Probability & Sta.. (2)
        • 미적분학 Calculus (1)
        • 복소해석학 Complex Analysis (5)
        • 실해석학 Real Analysis (0)
      • 수치해석 Numerical Analysis (21)
      • 확률 및 랜덤프로세스 Random process (2)
      • 추론 & 추정 이론 Estimation (3)
      • 기타 (26)
        • 설계 프로젝트 System Design (8)
        • 논문작성 Writing (55)
        • 세미나 Seminar (2)
        • 생산성 Productivity (3)
      • 유학 생활 Daily (7)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    딥러닝
    Numerical Analysis
    옵시디언
    에러기록
    Linear algebra
    수치해석
    우분투
    인공지능
    고체역학
    생산성
    Python
    WOX
    matplotlib
    텝스공부
    Julia
    논문작성
    LaTeX
    IEEE
    Dear abby
    teps
    논문작성법
    서버
    Statics
    obsidian
    Zotero
    MATLAB
    JAX
    ChatGPT
    텝스
    pytorch
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
보통의공대생
[PyTorch] 모델 저장/불러오기 및 모델 수정하기
상단으로

티스토리툴바