Notice
Recent Posts
Recent Comments
Link
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | |||||
3 | 4 | 5 | 6 | 7 | 8 | 9 |
10 | 11 | 12 | 13 | 14 | 15 | 16 |
17 | 18 | 19 | 20 | 21 | 22 | 23 |
24 | 25 | 26 | 27 | 28 | 29 | 30 |
Tags
- Linear algebra
- 수식삽입
- 텝스공부
- IEEE
- 생산성
- 텝스
- 옵시디언
- LaTeX
- MATLAB
- 우분투
- teps
- Julia
- 고체역학
- 논문작성법
- JAX
- obsidian
- pytorch
- 논문작성
- 딥러닝
- ChatGPT
- Numerical Analysis
- Python
- 에러기록
- 인공지능
- Dear abby
- matplotlib
- WOX
- 수치해석
- Zotero
- Statics
Archives
- Today
- Total
뛰는 놈 위에 나는 공대생
[JAX] L-BFGS optimizer로 학습하는 예제 코드 본문
연구 Research/인공지능 Artificial Intelligent
[JAX] L-BFGS optimizer로 학습하는 예제 코드
보통의공대생 2023. 11. 2. 15:20Quasi Newton method를 이용해서 JAX에서 최적화를 시키고 싶었는데 JAX 자체가 Deep learning에 포커스가 있고 대부분 딥러닝이 Gradient descent method로 최적화를 하다보니 라이브러리를 찾게 되었다.
대부분 jaxopt라는 라이브러리를 추천했기 때문에 이걸로 수행해보았다.
jaxopt에는 jaxopt.ScipyMinimize와 jaxopt.LBFGS가 있는데 다른 분들의 시도를 보니 ScipyMinimize가 더 성능이 괜찮은 것 같다. ScipyMinimize는 scipy에 있는 최적화를 사용한 것이고 LBFGS는 직접 만든 것 같은데 line search 방법 등이 다르다고 한다.
import jax
import jax.numpy as np
from jax import random, grad, vmap, jit, hessian
from jax.example_libraries import optimizers
from jax.experimental.ode import odeint
from jax.config import config
from jax.ops import index_update, index
from jax import lax
from jax.flatten_util import ravel_pytree
from jax import jacfwd, jacrev
import jax.nn as nn
# utility
import itertools
# Define MLP
def MLP(layers, activation=nn.relu):
# Vanilla MLP
def init(rng_key):
def init_layer(key, d_in, d_out):
k1, k2 = random.split(key)
glorot_stddev = 1. / np.sqrt((d_in + d_out) / 2.)
W = glorot_stddev * random.normal(k1, (d_in, d_out))
b = np.zeros(d_out)
return W, b
key, *keys = random.split(rng_key, len(layers))
params = list(map(init_layer, keys, layers[:-1], layers[1:]))
return params
def apply(params, inputs):
for W, b in params[:-1]:
outputs = np.dot(inputs, W) + b
inputs = activation(outputs)
W, b = params[-1]
outputs = np.dot(inputs, W) + b
return outputs
return init, apply
다음과 같이 모델을 정의하고 아래 코드처럼 loss function과 regression하고 싶은 target function으로 데이터를 생성한다.
layer_sizes = [1, 100, 100, 100, 100, 1]
init, apply = MLP(layer_sizes, activation=nn.relu)
params = init(rng_key = random.PRNGKey(1234))
params = tuple(params)
# loss function to be minimized
def loss(params, inputs, targets):
output = apply(params, inputs)
loss = np.mean((output.flatten() - targets.flatten())**2)
return loss
# target function for regression
def target_func(inputs):
outputs = 0.15 * np.sqrt(inputs)
return outputs
# generate dataset
inputs = np.linspace(0.0, 1.0, 100).reshape((-1,1))
targets = vmap(target_func)(inputs)
마지막으로 최적화를 한다.
# Run the optimization
solver_1 = jaxopt.ScipyMinimize(method = "l-bfgs-b", fun=loss, tol = 1e-12, maxiter = 5000)
solver_1_sol = solver_1.run(params, inputs, targets) # arguments for loss function
run을 사용할 때 들어가는 argument는 loss function에 들어가는 argument들이다. ScipyMinimize를 잘 살펴보면 추가적으로 보조적인 데이터를 넣을 수 있는 걸로 안다.
아래 링크는 코드를 좀 수정하거나 좀 더 풍부하게 사용하기 위해 참고할 수 있는 jaxopt 문서이다.
https://jaxopt.github.io/stable/_modules/jaxopt/_src/scipy_wrappers.html#ScipyMinimize.run
결과를 확인하면 다음과 같다.
params_opt = solver_1_sol.params
predict = apply(params_opt, inputs)
plt.plot(inputs, targets, label='True')
plt.plot(inputs, predict, '--',label='Neural Network')
plt.legend()
plt.show()
'연구 Research > 인공지능 Artificial Intelligent' 카테고리의 다른 글
[연구] SciML 분야 라이브러리 기록 (0) | 2024.03.11 |
---|---|
[chatGPT] chatGPT 프롬프트 엔지니어링 (0) | 2023.11.28 |
[Deep learning] Bayesian Neural Network (1) (0) | 2023.10.23 |
[인공지능] Learning에서 scaling이 중요한가 (0) | 2023.09.20 |
[JAX] vmap과 jit의 속도 (0) | 2023.09.20 |
Comments