Notice
Recent Posts
Recent Comments
Link
관리 메뉴

뛰는 놈 위에 나는 공대생

[JAX] L-BFGS optimizer로 학습하는 예제 코드 본문

연구 Research/인공지능 Artificial Intelligent

[JAX] L-BFGS optimizer로 학습하는 예제 코드

보통의공대생 2023. 11. 2. 15:20

Quasi Newton method를 이용해서 JAX에서 최적화를 시키고 싶었는데 JAX 자체가 Deep learning에 포커스가 있고 대부분 딥러닝이 Gradient descent method로 최적화를 하다보니 라이브러리를 찾게 되었다.

 

대부분 jaxopt라는 라이브러리를 추천했기 때문에 이걸로 수행해보았다.

 

jaxopt에는 jaxopt.ScipyMinimize와 jaxopt.LBFGS가 있는데 다른 분들의 시도를 보니 ScipyMinimize가 더 성능이 괜찮은 것 같다. ScipyMinimize는 scipy에 있는 최적화를 사용한 것이고 LBFGS는 직접 만든 것 같은데 line search 방법 등이 다르다고 한다.

 

 

import jax
import jax.numpy as np
from jax import random, grad, vmap, jit, hessian
from jax.example_libraries import optimizers
from jax.experimental.ode import odeint
from jax.config import config
from jax.ops import index_update, index
from jax import lax
from jax.flatten_util import ravel_pytree
from jax import jacfwd, jacrev
import jax.nn as nn

# utility
import itertools

# Define MLP
def MLP(layers, activation=nn.relu):
  # Vanilla MLP
    def init(rng_key):
        def init_layer(key, d_in, d_out):
            k1, k2 = random.split(key)
            glorot_stddev = 1. / np.sqrt((d_in + d_out) / 2.)
            W = glorot_stddev * random.normal(k1, (d_in, d_out))
            b = np.zeros(d_out)
            return W, b
        key, *keys = random.split(rng_key, len(layers))
        params = list(map(init_layer, keys, layers[:-1], layers[1:]))
        return params
    def apply(params, inputs):
        for W, b in params[:-1]:
            outputs = np.dot(inputs, W) + b
            inputs = activation(outputs)
        W, b = params[-1]
        outputs = np.dot(inputs, W) + b
        return outputs
    return init, apply

 

다음과 같이 모델을 정의하고 아래 코드처럼 loss function과 regression하고 싶은 target function으로 데이터를 생성한다.

 

 

 

layer_sizes = [1, 100, 100, 100, 100, 1]

init, apply = MLP(layer_sizes, activation=nn.relu)
params = init(rng_key = random.PRNGKey(1234))
params = tuple(params)

# loss function to be minimized
def loss(params, inputs, targets):
    output = apply(params, inputs)
    loss = np.mean((output.flatten() - targets.flatten())**2)
    return loss

# target function for regression
def target_func(inputs):
    outputs = 0.15 * np.sqrt(inputs)
    return outputs

# generate dataset
inputs = np.linspace(0.0, 1.0, 100).reshape((-1,1))
targets = vmap(target_func)(inputs)

 

 

마지막으로 최적화를 한다.

# Run the optimization
solver_1 = jaxopt.ScipyMinimize(method = "l-bfgs-b", fun=loss, tol = 1e-12, maxiter = 5000)
solver_1_sol = solver_1.run(params, inputs, targets) # arguments for loss function

 

run을 사용할 때 들어가는 argument는 loss function에 들어가는 argument들이다. ScipyMinimize를 잘 살펴보면 추가적으로 보조적인 데이터를 넣을 수 있는 걸로 안다.

 

아래 링크는 코드를 좀 수정하거나 좀 더 풍부하게 사용하기 위해 참고할 수 있는 jaxopt 문서이다.

https://jaxopt.github.io/stable/_modules/jaxopt/_src/scipy_wrappers.html#ScipyMinimize.run

 

jaxopt._src.scipy_wrappers — JAXopt 0.8 documentation

© Copyright 2021-2022, the JAXopt authors.

jaxopt.github.io

 

 

결과를 확인하면 다음과 같다.

 

params_opt = solver_1_sol.params
predict = apply(params_opt, inputs)

plt.plot(inputs, targets, label='True')
plt.plot(inputs, predict, '--',label='Neural Network')
plt.legend()
plt.show()

 

 

 

Comments