Notice
Recent Posts
Recent Comments
Link
관리 메뉴

뛰는 놈 위에 나는 공대생

[PyTorch] 모델 저장/불러오기 및 모델 수정하기 본문

연구 Research/인공지능 Artificial Intelligent

[PyTorch] 모델 저장/불러오기 및 모델 수정하기

보통의공대생 2023. 1. 4. 22:29

빠른 학습을 위하여 이미 학습된 파라미터를 가지고 오고 싶을 때가 있다.

그래서 이 글에서는 기본적인 모델 저장, 불어오기 뿐 아니라 내가 좀 더 커스터마이징 할 수 있는 방법에 대해서 고민했다.

 

1. 모델 저장

 

파이토치 매뉴얼에서는 다음과 같이 모델을 저장하라고 권한다.

PATH = '(모델이름).pth'
torch.save(modelA.state_dict(), PATH)

 

이렇게 해서 pth 파일을 저장할 수 있다. state_dict는 내부 상태 사전(internal state dictionary)으로서 학습된 모델의 매개변수를 저장한다.

modelA.state_dict()를 출력하면 다음과 같은 형태가 나온다.

 

 

위의 OrderedDict는 dictionary 자료형인데 파이썬에서 순서 정렬에 유리하게 만든 딕셔너리 자료형이라고 한다.

 

 

또는 아예 모델 형태를 포함하여 저장하는 방법도 있다. 이 경우에는 모델 자체를 저장한 것이기 때문에 나중에 불러올 떄 바로 model 변수에 넣어서 사용할 수도 있다.

torch.save(model, 'model.pth')

 

 

2. 모델 불러오기

 

모델의 가중치만 저장한 경우

model.load_state_dict(torch.load('model_weights.pth'))

 

모델 자체를 저장한 경우

model = torch.load('model.pth')

 

다음과 같이 load할 수 있다.

 

 

위의 내용은 튜토리얼에 잘 나와있기 때문에 쉽게 이해할 수 있을 것이다.

그런데 내가 다루고자 하는 본론은 아래 내용이다.

 

 

 

 

3. 내가 모델에 추가적인 변수가 있을 때

 

나의 경우에는 일반 신경망에서 추가적으로 parameter를 nn.Parameter와 model.register_parameter를 사용하여서 모델에 추가했다. 그런데 나중에 그 모델을 저장하고 쓸 때는 그 파라미터를 제외해서 load를 하고 싶다고 해보자.

 

만약 그 파라미터를 삭제하지 않고 내가 정한 모델에 집어넣으려고 하면 다음과 같은 에러를 보게 될 것이다.

나는 추가적인 'Xcp'라는 파라미터를 넣었기 때문에 아무 생각없이 가중치를 옮기면 위와 같은 문제가 생긴다.

 

그래서 저렇게 모델을 저장한 다음에 파라미터를 삭제하는 방법에 대해서 알아야한다. 이는 위의 OrderedDict에서 파라미터를 제외하는 것과 동일하다.

 

OrderedDict에서 키와 value를 삭제하는 방법은 딕셔너리와 같다.

del dct[key]
또는
d.pop(your_key)

다음과 같이 파라미터를 key에 넣어서 모델의 파라미터가 저장된 딕셔너리에서 추가 파라미터를 제외하면 끝.

 


 

참고문헌

 

https://tutorials.pytorch.kr/beginner/saving_loading_models.html#warmstart

 

모델 저장하기 & 불러오기

Author: Matthew Inkawhich, 번역: 박정환,. 이 문서에서는 PyTorch 모델을 저장하고 불러오는 다양한 방법을 제공합니다. 이 문서 전체를 다 읽는 것도 좋은 방법이지만, 필요한 사용 예의 코드만 참고하

tutorials.pytorch.kr

https://tutorials.pytorch.kr/beginner/basics/saveloadrun_tutorial.html

 

모델 저장하고 불러오기

파이토치(PyTorch) 기본 익히기|| 빠른 시작|| 텐서(Tensor)|| Dataset과 Dataloader|| 변형(Transform)|| 신경망 모델 구성하기|| Autograd|| 최적화(Optimization)|| 모델 저장하고 불러오기 이번 장에서는 저장하기나

tutorials.pytorch.kr

 

Comments