[에러기록] Pytorch 모델 weight가 업데이트되는지 확인

2022. 12. 30. 21:34·프로그래밍 Programming

Pytorch 모델을 업데이트하려고 optimizer를 쓰는데 loss가 전혀 변하지 않는 것을 보고 업데이트 된 것인지 확인하고 싶었다.

 

다음 코드와 같이 일치하지 않는 부분이 있는지 확인하였다.

 

optimizer.zero_grad()
a = list(model.parameters())[0]
print(list(model.parameters())[0].grad) # gradient가 잘 계산되었는지 확인, None이면 이상한 것

loss.backward()
optimizer.step()
b = list(model.parameters())[0]

print(torch.equal(a.data, b.data)) # a와 b가 일치하는지 확인

 

업데이트 전 후의 parameter 값을 비교함으로써 업데이트되는지 확인할 수 있다.

다만 주의할 점은 learning rate를 너무 작게 잡고 gradient가 작으면 한 스텝마다 거의 업데이트되지 않는다. 이 와중에 자료형이 float32이면 분명 업데이트 되었음에도 수치적으로 표현이 되지 않아 a,b가 동일하게 나올 수 있다.

 

그래서 이런 경우에는 동일한 input을 넣어서 model의 output이 동일한지 비교해보도록 한다.

 


나의 경우에는 업데이트가 되지 않는 것을 위의 코드로 확인했고

알고보니 model을 optimizer를 정의한 후에 만들어서 그런 것이었다(...멍청)

저작자표시 비영리 변경금지 (새창열림)

'프로그래밍 Programming' 카테고리의 다른 글

[JupyterLab] 코드 줄 번호 default 표시, 폰트 사이즈, family 변경  (0) 2023.01.02
[에러기록] RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'  (0) 2022.12.30
[에러기록] oserror: [winerror 182] 운영 체제가 %1을(를) 실행할 수 없습니다. Error loading "\lib\site-packages\torch\lib\shm.dll" or one of its dependencies  (2) 2022.12.29
[에러기록] 아나콘다와 관련된 수많은 에러들  (0) 2022.12.29
[에러기록] CondaHTTPError: HTTP 000 CONNECTION FAILED for url  (0) 2022.12.26
'프로그래밍 Programming' 카테고리의 다른 글
  • [JupyterLab] 코드 줄 번호 default 표시, 폰트 사이즈, family 변경
  • [에러기록] RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
  • [에러기록] oserror: [winerror 182] 운영 체제가 %1을(를) 실행할 수 없습니다. Error loading "\lib\site-packages\torch\lib\shm.dll" or one of its dependencies
  • [에러기록] 아나콘다와 관련된 수많은 에러들
보통의공대생
보통의공대생
수학,프로그래밍,기계항공우주 등 공부하는 기록들을 남깁니다.
  • 보통의공대생
    뛰는 놈 위에 나는 공대생
    보통의공대생
  • 전체
    오늘
    어제
    • 분류 전체보기 (468)
      • 공지 (1)
      • 영어 공부 English Study (40)
        • 텝스 TEPS (7)
        • 글 Article (21)
        • 영상 Video (10)
      • 연구 Research (99)
        • 최적화 Optimization (3)
        • 데이터과학 Data Science (7)
        • 인공지능 Artificial Intelligent (40)
        • 제어 Control (45)
      • 프로그래밍 Programming (103)
        • 매트랩 MATLAB (25)
        • 파이썬 Python (33)
        • 줄리아 Julia (2)
        • C++ (3)
        • 리눅스 우분투 Ubuntu (6)
      • 항공우주 Aeronautical engineeri.. (21)
        • 항법 Navigation (0)
        • 유도 Guidance (0)
      • 기계공학 Mechanical engineering (13)
        • 열역학 Thermodynamics (0)
        • 고체역학 Statics & Solid mechan.. (10)
        • 동역학 Dynamics (1)
        • 유체역학 Fluid Dynamics (0)
      • 수학 Mathematics (34)
        • 선형대수학 Linear Algebra (18)
        • 미분방정식 Differential Equation (3)
        • 확률및통계 Probability & Sta.. (2)
        • 미적분학 Calculus (1)
        • 복소해석학 Complex Analysis (5)
        • 실해석학 Real Analysis (0)
      • 수치해석 Numerical Analysis (27)
      • 확률 및 랜덤프로세스 Random process (2)
      • 추론 & 추정 이론 Estimation (3)
      • 기타 (26)
        • 설계 프로젝트 System Design (8)
        • 논문작성 Writing (55)
        • 세미나 Seminar (2)
        • 생산성 Productivity (3)
      • 실험 Experiment (1)
      • 유학 생활 Daily (8)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    LaTeX
    ChatGPT
    Zotero
    Julia
    obsidian
    논문작성법
    pytorch
    텝스
    고체역학
    Numerical Analysis
    Linear algebra
    Statics
    인공지능
    우분투
    Python
    matplotlib
    옵시디언
    생산성
    수치해석
    텝스공부
    에러기록
    Dear abby
    논문작성
    JAX
    MATLAB
    teps
    IEEE
    딥러닝
    서버
    WOX
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
보통의공대생
[에러기록] Pytorch 모델 weight가 업데이트되는지 확인
상단으로

티스토리툴바