일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | |||||
3 | 4 | 5 | 6 | 7 | 8 | 9 |
10 | 11 | 12 | 13 | 14 | 15 | 16 |
17 | 18 | 19 | 20 | 21 | 22 | 23 |
24 | 25 | 26 | 27 | 28 | 29 | 30 |
- 수식삽입
- 텝스공부
- Dear abby
- WOX
- pytorch
- 수치해석
- Numerical Analysis
- teps
- 딥러닝
- Statics
- obsidian
- 논문작성
- 고체역학
- 생산성
- JAX
- MATLAB
- Linear algebra
- Python
- ChatGPT
- IEEE
- Zotero
- 논문작성법
- 텝스
- LaTeX
- 인공지능
- 에러기록
- Julia
- 옵시디언
- matplotlib
- 우분투
- Today
- Total
뛰는 놈 위에 나는 공대생
[PyTorch] 모델 저장/불러오기 및 모델 수정하기 본문
빠른 학습을 위하여 이미 학습된 파라미터를 가지고 오고 싶을 때가 있다.
그래서 이 글에서는 기본적인 모델 저장, 불어오기 뿐 아니라 내가 좀 더 커스터마이징 할 수 있는 방법에 대해서 고민했다.
1. 모델 저장
파이토치 매뉴얼에서는 다음과 같이 모델을 저장하라고 권한다.
PATH = '(모델이름).pth'
torch.save(modelA.state_dict(), PATH)
이렇게 해서 pth 파일을 저장할 수 있다. state_dict는 내부 상태 사전(internal state dictionary)으로서 학습된 모델의 매개변수를 저장한다.
modelA.state_dict()를 출력하면 다음과 같은 형태가 나온다.
위의 OrderedDict는 dictionary 자료형인데 파이썬에서 순서 정렬에 유리하게 만든 딕셔너리 자료형이라고 한다.
또는 아예 모델 형태를 포함하여 저장하는 방법도 있다. 이 경우에는 모델 자체를 저장한 것이기 때문에 나중에 불러올 떄 바로 model 변수에 넣어서 사용할 수도 있다.
torch.save(model, 'model.pth')
2. 모델 불러오기
모델의 가중치만 저장한 경우
model.load_state_dict(torch.load('model_weights.pth'))
모델 자체를 저장한 경우
model = torch.load('model.pth')
다음과 같이 load할 수 있다.
위의 내용은 튜토리얼에 잘 나와있기 때문에 쉽게 이해할 수 있을 것이다.
그런데 내가 다루고자 하는 본론은 아래 내용이다.
3. 내가 모델에 추가적인 변수가 있을 때
나의 경우에는 일반 신경망에서 추가적으로 parameter를 nn.Parameter와 model.register_parameter를 사용하여서 모델에 추가했다. 그런데 나중에 그 모델을 저장하고 쓸 때는 그 파라미터를 제외해서 load를 하고 싶다고 해보자.
만약 그 파라미터를 삭제하지 않고 내가 정한 모델에 집어넣으려고 하면 다음과 같은 에러를 보게 될 것이다.
나는 추가적인 'Xcp'라는 파라미터를 넣었기 때문에 아무 생각없이 가중치를 옮기면 위와 같은 문제가 생긴다.
그래서 저렇게 모델을 저장한 다음에 파라미터를 삭제하는 방법에 대해서 알아야한다. 이는 위의 OrderedDict에서 파라미터를 제외하는 것과 동일하다.
OrderedDict에서 키와 value를 삭제하는 방법은 딕셔너리와 같다.
del dct[key]
또는
d.pop(your_key)
다음과 같이 파라미터를 key에 넣어서 모델의 파라미터가 저장된 딕셔너리에서 추가 파라미터를 제외하면 끝.
참고문헌
https://tutorials.pytorch.kr/beginner/saving_loading_models.html#warmstart
https://tutorials.pytorch.kr/beginner/basics/saveloadrun_tutorial.html
'연구 Research > 인공지능 Artificial Intelligent' 카테고리의 다른 글
[인공지능] CUDA & cuDNN 최신 버전 설치 (CUDA 11.6 이상) (0) | 2023.02.07 |
---|---|
[PyTorch] retain_graph = True라고 했음에도 backward 문제가 발생하는 경우 (0) | 2023.01.05 |
[딥러닝] Backpropagation을 위한 Automatic differentiation 이론/코딩 (0) | 2022.07.09 |
[인공지능] CUDA & cuDNN 설치하는 방법 (0) | 2021.08.30 |
[머신러닝] Boosting method (0) | 2021.05.26 |