Wrapping an optax optimizer with `optax.inject_hyperparams` exposes the optimizer's hyperparameters in its state, so you can observe values such as the current (scheduled) learning rate during training.
# Wrap the optimizer to inject the hyperparameters
optimizer = optax.inject_hyperparams(optax.adamw)(learning_rate=schedule)

def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)

  # Since we injected hyperparams, we can access them directly here
  print(f'Available hyperparams: {" ".join(opt_state.hyperparams.keys())}\n')

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      # Get the updated learning rate
      lr = opt_state.hyperparams['learning_rate']
      print(f'Step {i:3}, Loss: {loss_value:.3f}, Learning rate: {lr:.3f}')

  return params

params = fit(initial_params, optimizer)
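The snippet above assumes that `schedule`, `step`, `TRAINING_DATA`, `LABELS`, and `initial_params` were defined earlier. A minimal sketch of how those pieces might look is below; the toy data, linear model, and warmup-cosine schedule are illustrative assumptions, not the original definitions, and they would need to appear before the snippet above.

import jax
import jax.numpy as jnp
import optax

# Hypothetical toy dataset: 1000 batches of 8 two-dimensional inputs with scalar targets.
TRAINING_DATA = jnp.ones((1000, 8, 2))
LABELS = jnp.ones((1000, 8, 1))

# Parameters of a tiny linear model (any pytree works as optax.Params).
initial_params = {'w': jnp.zeros((2, 1)), 'b': jnp.zeros((1,))}

# Example schedule: linear warmup followed by cosine decay.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=1e-2, warmup_steps=100, decay_steps=1000)

def loss_fn(params, batch, labels):
  # Mean squared error of a linear model; purely illustrative.
  preds = batch @ params['w'] + params['b']
  return jnp.mean((preds - labels) ** 2)

@jax.jit
def step(params, opt_state, batch, labels):
  # Closes over the global `optimizer` created with inject_hyperparams above.
  loss_value, grads = jax.value_and_grad(loss_fn)(params, batch, labels)
  updates, opt_state = optimizer.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  return params, opt_state, loss_value

Because the optimizer is wrapped with `inject_hyperparams`, the value of each hyperparameter for the current step is stored in `opt_state.hyperparams`, which is what makes logging the effective learning rate inside the training loop this simple.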