[JAX] optax에서 learning rate 확인하는 방법
·
프로그래밍 Programming
inject_hyperparams라는 함수로 optax의 optimizer를 묶어서 사용하면 hyperparams를 관찰할 수 있다. # Wrap the optimizer to inject the hyperparameters optimizer = optax.inject_hyperparams(optax.adamw)(learning_rate=schedule) def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params: opt_state = optimizer.init(params) # Since we injected hyperparams, we can access them directly here print(f'A..