[JAX] L-BFGS optimizer로 학습하는 예제 코드

2023. 11. 2. 15:20·연구 Research/인공지능 Artificial Intelligent

Quasi Newton method를 이용해서 JAX에서 최적화를 시키고 싶었는데 JAX 자체가 Deep learning에 포커스가 있고 대부분 딥러닝이 Gradient descent method로 최적화를 하다보니 라이브러리를 찾게 되었다.

 

대부분 jaxopt라는 라이브러리를 추천했기 때문에 이걸로 수행해보았다.

 

jaxopt에는 jaxopt.ScipyMinimize와 jaxopt.LBFGS가 있는데 다른 분들의 시도를 보니 ScipyMinimize가 더 성능이 괜찮은 것 같다. ScipyMinimize는 scipy에 있는 최적화를 사용한 것이고 LBFGS는 직접 만든 것 같은데 line search 방법 등이 다르다고 한다.

 

 

import jax
import jax.numpy as np
from jax import random, grad, vmap, jit, hessian
from jax.example_libraries import optimizers
from jax.experimental.ode import odeint
from jax.config import config
from jax.ops import index_update, index
from jax import lax
from jax.flatten_util import ravel_pytree
from jax import jacfwd, jacrev
import jax.nn as nn

# utility
import itertools

# Define MLP
def MLP(layers, activation=nn.relu):
  # Vanilla MLP
    def init(rng_key):
        def init_layer(key, d_in, d_out):
            k1, k2 = random.split(key)
            glorot_stddev = 1. / np.sqrt((d_in + d_out) / 2.)
            W = glorot_stddev * random.normal(k1, (d_in, d_out))
            b = np.zeros(d_out)
            return W, b
        key, *keys = random.split(rng_key, len(layers))
        params = list(map(init_layer, keys, layers[:-1], layers[1:]))
        return params
    def apply(params, inputs):
        for W, b in params[:-1]:
            outputs = np.dot(inputs, W) + b
            inputs = activation(outputs)
        W, b = params[-1]
        outputs = np.dot(inputs, W) + b
        return outputs
    return init, apply

 

다음과 같이 모델을 정의하고 아래 코드처럼 loss function과 regression하고 싶은 target function으로 데이터를 생성한다.

 

 

 

layer_sizes = [1, 100, 100, 100, 100, 1]

init, apply = MLP(layer_sizes, activation=nn.relu)
params = init(rng_key = random.PRNGKey(1234))
params = tuple(params)

# loss function to be minimized
def loss(params, inputs, targets):
    output = apply(params, inputs)
    loss = np.mean((output.flatten() - targets.flatten())**2)
    return loss

# target function for regression
def target_func(inputs):
    outputs = 0.15 * np.sqrt(inputs)
    return outputs

# generate dataset
inputs = np.linspace(0.0, 1.0, 100).reshape((-1,1))
targets = vmap(target_func)(inputs)

 

 

마지막으로 최적화를 한다.

# Run the optimization
solver_1 = jaxopt.ScipyMinimize(method = "l-bfgs-b", fun=loss, tol = 1e-12, maxiter = 5000)
solver_1_sol = solver_1.run(params, inputs, targets) # arguments for loss function

 

run을 사용할 때 들어가는 argument는 loss function에 들어가는 argument들이다. ScipyMinimize를 잘 살펴보면 추가적으로 보조적인 데이터를 넣을 수 있는 걸로 안다.

 

아래 링크는 코드를 좀 수정하거나 좀 더 풍부하게 사용하기 위해 참고할 수 있는 jaxopt 문서이다.

https://jaxopt.github.io/stable/_modules/jaxopt/_src/scipy_wrappers.html#ScipyMinimize.run

 

jaxopt._src.scipy_wrappers — JAXopt 0.8 documentation

© Copyright 2021-2022, the JAXopt authors.

jaxopt.github.io

 

 

결과를 확인하면 다음과 같다.

 

params_opt = solver_1_sol.params
predict = apply(params_opt, inputs)

plt.plot(inputs, targets, label='True')
plt.plot(inputs, predict, '--',label='Neural Network')
plt.legend()
plt.show()

 

 

 

저작자표시 비영리 변경금지 (새창열림)

'연구 Research > 인공지능 Artificial Intelligent' 카테고리의 다른 글

[연구] SciML 분야 라이브러리 기록  (0) 2024.03.11
[chatGPT] chatGPT 프롬프트 엔지니어링  (0) 2023.11.28
[Deep learning] Bayesian Neural Network (1)  (0) 2023.10.23
[인공지능] Learning에서 scaling이 중요한가  (0) 2023.09.20
[JAX] vmap과 jit의 속도  (0) 2023.09.20
'연구 Research/인공지능 Artificial Intelligent' 카테고리의 다른 글
  • [연구] SciML 분야 라이브러리 기록
  • [chatGPT] chatGPT 프롬프트 엔지니어링
  • [Deep learning] Bayesian Neural Network (1)
  • [인공지능] Learning에서 scaling이 중요한가
보통의공대생
보통의공대생
수학,프로그래밍,기계항공우주 등 공부하는 기록들을 남깁니다.
  • 보통의공대생
    뛰는 놈 위에 나는 공대생
    보통의공대생
  • 전체
    오늘
    어제
    • 분류 전체보기 (459)
      • 공지 (1)
      • 영어 공부 English Study (40)
        • 텝스 TEPS (7)
        • 글 Article (21)
        • 영상 Video (10)
      • 연구 Research (99)
        • 최적화 Optimization (3)
        • 데이터과학 Data Science (7)
        • 인공지능 Artificial Intelligent (40)
        • 제어 Control (45)
      • 프로그래밍 Programming (103)
        • 매트랩 MATLAB (25)
        • 파이썬 Python (33)
        • 줄리아 Julia (2)
        • C++ (3)
        • 리눅스 우분투 Ubuntu (6)
      • 항공우주 Aeronautical engineeri.. (21)
        • 항법 Navigation (0)
        • 유도 Guidance (0)
      • 기계공학 Mechanical engineering (13)
        • 열역학 Thermodynamics (0)
        • 고체역학 Statics & Solid mechan.. (10)
        • 동역학 Dynamics (1)
        • 유체역학 Fluid Dynamics (0)
      • 수학 Mathematics (34)
        • 선형대수학 Linear Algebra (18)
        • 미분방정식 Differential Equation (3)
        • 확률및통계 Probability & Sta.. (2)
        • 미적분학 Calculus (1)
        • 복소해석학 Complex Analysis (5)
        • 실해석학 Real Analysis (0)
      • 수치해석 Numerical Analysis (21)
      • 확률 및 랜덤프로세스 Random process (2)
      • 추론 & 추정 이론 Estimation (3)
      • 기타 (26)
        • 설계 프로젝트 System Design (8)
        • 논문작성 Writing (55)
        • 세미나 Seminar (2)
        • 생산성 Productivity (3)
      • 유학 생활 Daily (6)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    Numerical Analysis
    Zotero
    Julia
    MATLAB
    Python
    생산성
    obsidian
    IEEE
    고체역학
    우분투
    텝스
    teps
    인공지능
    Statics
    옵시디언
    LaTeX
    Linear algebra
    pytorch
    논문작성법
    에러기록
    텝스공부
    ChatGPT
    JAX
    matplotlib
    수치해석
    논문작성
    서버
    WOX
    딥러닝
    Dear abby
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
보통의공대생
[JAX] L-BFGS optimizer로 학습하는 예제 코드
상단으로

티스토리툴바