[JAX] optax에서 learning rate 확인하는 방법

2023. 8. 23. 20:05·프로그래밍 Programming

inject_hyperparams라는 함수로 optax의 optimizer를 묶어서 사용하면 hyperparams를 관찰할 수 있다.

 

# Wrap the optimizer to inject the hyperparameters
optimizer = optax.inject_hyperparams(optax.adamw)(learning_rate=schedule)

def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)

  # Since we injected hyperparams, we can access them directly here
  print(f'Available hyperparams: {" ".join(opt_state.hyperparams.keys())}\n')

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      # Get the updated learning rate
      lr = opt_state.hyperparams['learning_rate']
      print(f'Step {i:3}, Loss: {loss_value:.3f}, Learning rate: {lr:.3f}')

  return params

params = fit(initial_params, optimizer)
저작자표시 비영리 변경금지 (새창열림)

'프로그래밍 Programming' 카테고리의 다른 글

[Julia] Julia 프로그래밍 공부자료  (0) 2024.02.06
[JAX] 버전에 따른 변화  (0) 2023.11.22
[JAX] NaN, Inf 값 처리 및 조건에 맞는 요소 찾기  (0) 2023.08.13
[에러기록] cudnn 버전 문제 (E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:433] Loaded runtime CuDNN library: 8.8.0 but source was compiled with: 8.9.1. CuDNN library needs to have matching major version and equal or higher minor version. ..  (0) 2023.08.02
[git blog] jekyll 테마 적용하면서 발생한 에러들  (0) 2023.05.10
'프로그래밍 Programming' 카테고리의 다른 글
  • [Julia] Julia 프로그래밍 공부자료
  • [JAX] 버전에 따른 변화
  • [JAX] NaN, Inf 값 처리 및 조건에 맞는 요소 찾기
  • [에러기록] cudnn 버전 문제 (E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:433] Loaded runtime CuDNN library: 8.8.0 but source was compiled with: 8.9.1. CuDNN library needs to have matching major version and equal or higher minor version. ..
보통의공대생
보통의공대생
수학,프로그래밍,기계항공우주 등 공부하는 기록들을 남깁니다.
  • 보통의공대생
    뛰는 놈 위에 나는 공대생
    보통의공대생
  • 전체
    오늘
    어제
    • 분류 전체보기 (459) N
      • 공지 (1)
      • 영어 공부 English Study (40)
        • 텝스 TEPS (7)
        • 글 Article (21)
        • 영상 Video (10)
      • 연구 Research (99)
        • 최적화 Optimization (3)
        • 데이터과학 Data Science (7)
        • 인공지능 Artificial Intelligent (40)
        • 제어 Control (45)
      • 프로그래밍 Programming (103)
        • 매트랩 MATLAB (25)
        • 파이썬 Python (33)
        • 줄리아 Julia (2)
        • C++ (3)
        • 리눅스 우분투 Ubuntu (6)
      • 항공우주 Aeronautical engineeri.. (21)
        • 항법 Navigation (0)
        • 유도 Guidance (0)
      • 기계공학 Mechanical engineering (13)
        • 열역학 Thermodynamics (0)
        • 고체역학 Statics & Solid mechan.. (10)
        • 동역학 Dynamics (1)
        • 유체역학 Fluid Dynamics (0)
      • 수학 Mathematics (34)
        • 선형대수학 Linear Algebra (18)
        • 미분방정식 Differential Equation (3)
        • 확률및통계 Probability & Sta.. (2)
        • 미적분학 Calculus (1)
        • 복소해석학 Complex Analysis (5)
        • 실해석학 Real Analysis (0)
      • 수치해석 Numerical Analysis (21)
      • 확률 및 랜덤프로세스 Random process (2)
      • 추론 & 추정 이론 Estimation (3)
      • 기타 (26)
        • 설계 프로젝트 System Design (8)
        • 논문작성 Writing (55)
        • 세미나 Seminar (2)
        • 생산성 Productivity (3)
      • 유학 생활 Daily (6)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    서버
    WOX
    수치해석
    Numerical Analysis
    고체역학
    옵시디언
    Python
    JAX
    딥러닝
    논문작성
    인공지능
    생산성
    Linear algebra
    MATLAB
    Dear abby
    에러기록
    teps
    Julia
    텝스
    Statics
    텝스공부
    obsidian
    ChatGPT
    matplotlib
    우분투
    Zotero
    pytorch
    IEEE
    논문작성법
    LaTeX
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
보통의공대생
[JAX] optax에서 learning rate 확인하는 방법
상단으로

티스토리툴바