[JAX] 기본 Neural Networks 모델

2023. 3. 23. 13:16·연구 Research/인공지능 Artificial Intelligent

가장 simple하게 신경망을 구성하는 방법에 대해서 저장해놓은 글이다.

차츰 업데이트 할 예정

 

1. 기본 학습 코드

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

# Define a simple neural network model
def init_params(layer_sizes, key):
    params = []
    for i in range(1, len(layer_sizes)):
        key, subkey = random.split(key)
        w = random.normal(subkey, (layer_sizes[i-1], layer_sizes[i]))
        b = jnp.zeros(layer_sizes[i])
        params.extend([w, b])
    return params

def relu(x):
    return jnp.maximum(0, x)

def predict(params, inputs):
    for i in range(0, len(params)-1, 2):
        w, b = params[i], params[i+1]
        outputs = jnp.dot(inputs, w) + b
        inputs = relu(outputs)
    return outputs

# Define a loss function
def mean_squared_error(params, inputs, targets):
    predictions = vmap(predict, in_axes=(None,0))(params, inputs)
    return jnp.mean((predictions - targets)**2)

# Define a function to update the parameters using gradient descent
@jit
def update(params, inputs, targets, learning_rate):
    grads = grad(mean_squared_error)(params, inputs, targets)
    return [(param - learning_rate * grad) for param, grad in zip(params, grads)]

def target_func(inputs):
    targets = jnp.heaviside(inputs, 0.0)
    return targets

# Generate some training data
key = random.PRNGKey(0)
inputs = random.normal(key, (100, 1))
targets = target_func(inputs)

# Initialize the parameters
layer_sizes = [1, 20, 1]
params = init_params(layer_sizes, key)

# Train the network
learning_rate = 0.1
for i in range(1000):
    params = update(params, inputs, targets, learning_rate)
    loss = mean_squared_error(params, inputs, targets)
    if (i % 100) == 0:
        print(f"iteration {i+1}, loss {loss}")

 

 

여기서는 간단하게 gradient를 구하고 이를 바탕으로 gradient descent를 적용하였다.

JAX에 있는 experimental.optimizers를 사용하는 방법이 있기 때문에 이를 이용한다.

 

 

2. JAX 내의 optimizer를 사용하는 경우

 

import jax
import jax.numpy as jnp
from jax import random, vmap
from jax.experimental import optimizers

# Define a simple neural network model
def init_params(layer_sizes, key):
    params = []
    for i in range(1, len(layer_sizes)):
        key, subkey = random.split(key)
        w = random.normal(subkey, (layer_sizes[i-1], layer_sizes[i]))
        b = jnp.zeros(layer_sizes[i])
        params.extend([w, b])
    return params

def relu(x):
    return jnp.maximum(0, x)

def predict(params, inputs):
    for i in range(0, len(params)-1, 2):
        w, b = params[i], params[i+1]
        outputs = jnp.dot(inputs, w) + b
        inputs = relu(outputs)
    return outputs

# Define a loss function
def mean_squared_error(params, inputs, targets):
	predictions = predict(params, inputs)
    return jnp.mean((predictions - targets)**2)

# Define a function to calculate the gradients of the loss function
@jax.jit
def grad_fun(params, inputs, targets):
    return jax.grad(mean_squared_error)(params, inputs, targets)

# Generate some training data
key = random.PRNGKey(0)
inputs = random.normal(key, (100, 10))
targets = random.normal(key, (100, 1))

# Initialize the parameters
layer_sizes = [10, 20, 1]
params = init_params(layer_sizes, key)

# Create an optimizer
opt_init, opt_update, get_params = optimizers.adam(learning_rate=0.1)
opt_state = opt_init(params)

# Define a function to update the parameters using the optimizer
@jax.jit
def update(i, opt_state, inputs, targets):
    params = get_params(opt_state)
    grads = grad_fun(params, inputs, targets)
    return opt_update(i, grads, opt_state)

# Train the network
num_iterations = 100
for i in range(num_iterations):
    opt_state = update(i, opt_state, inputs, targets)
    params = get_params(opt_state)
    loss = mean_squared_error(params, inputs, targets)
    print(f"iteration {i+1}, loss {loss}")

 

업데이트 ( JAX 0.4.14 버전 기준 )

업데이트 이후에는 experimental.optimizers 가 사라졌다.

대신 optimizers를 example_libraries에 있는 것을 사용했는데, JAX에서는 optax 등의 다른 라이브러리 사용을 권장하고 있긴 하다.

 

import jax
import jax.numpy as jnp
from jax import random, vmap
from jax.example_libraries import optimizers

def target_func(y):
    sigma = 0.15
    diff = sigma * jnp.sqrt(y)
    return diff

# Define a simple neural network model
def init_params(layer_sizes, key):
    params = []
    for i in range(1, len(layer_sizes)):
        key, subkey = random.split(key)
        w = random.normal(subkey, (layer_sizes[i-1], layer_sizes[i]))
        b = jnp.zeros(layer_sizes[i])
        params.extend([w, b])
    return params

def relu(x):
    return jnp.maximum(0, x)

def predict(params, inputs):
    for i in range(0, len(params)-1, 2):
        w, b = params[i], params[i+1]
        outputs = jnp.dot(inputs, w) + b
        inputs = jnp.tanh(outputs)
    return outputs

# Define a loss function
def mean_squared_error(params, inputs, targets):
    predictions = predict(params, inputs)
    return jnp.mean((predictions - targets)**2)

# Define a function to calculate the gradients of the loss function
@jax.jit
def grad_fun(params, inputs, targets):
    return jax.grad(mean_squared_error)(params, inputs, targets)

# Generate some training data
key = random.PRNGKey(0)

inputs = random.uniform(key, (100, 1), minval=0.0, maxval=1.0)
targets = jax.vmap(target_func)(inputs)

# Initialize the parameters
layer_sizes = [1, 64, 64, 1]
params = init_params(layer_sizes, key)

# Create an optimizer
opt_init, opt_update, get_params = optimizers.adam(0.001)
opt_state = opt_init(params)

# Define a function to update the parameters using the optimizer
@jax.jit
def update(i, opt_state, inputs, targets):
    params = get_params(opt_state)
    grads = grad_fun(params, inputs, targets)
    return opt_update(i, grads, opt_state)

# Train the network
num_iterations = 5000
for i in range(num_iterations):
    opt_state = update(i, opt_state, inputs, targets)
    params = get_params(opt_state)
    if i % 100 == 0:
        loss = mean_squared_error(params, inputs, targets)
        print(f"iteration {i+1}, loss {loss}")
저작자표시 비영리 변경금지 (새창열림)

'연구 Research > 인공지능 Artificial Intelligent' 카테고리의 다른 글

[JAX] 학습 중 NaN 값이 나올 때 찾는 방법  (0) 2023.03.28
[JAX] Gaussian process 파라미터에 따른 결과 visualization  (0) 2023.03.24
[JAX] Gradient, Jacobian, Hessian 등 미분값 구하기  (2) 2023.03.22
[JAX] Cholesky decomposition error 230318 기준  (0) 2023.03.18
[JAX] JAX vmap에 대한 설명  (0) 2023.03.17
'연구 Research/인공지능 Artificial Intelligent' 카테고리의 다른 글
  • [JAX] 학습 중 NaN 값이 나올 때 찾는 방법
  • [JAX] Gaussian process 파라미터에 따른 결과 visualization
  • [JAX] Gradient, Jacobian, Hessian 등 미분값 구하기
  • [JAX] Cholesky decomposition error 230318 기준
보통의공대생
보통의공대생
수학,프로그래밍,기계항공우주 등 공부하는 기록들을 남깁니다.
  • 보통의공대생
    뛰는 놈 위에 나는 공대생
    보통의공대생
  • 전체
    오늘
    어제
    • 분류 전체보기 (460)
      • 공지 (1)
      • 영어 공부 English Study (40)
        • 텝스 TEPS (7)
        • 글 Article (21)
        • 영상 Video (10)
      • 연구 Research (99)
        • 최적화 Optimization (3)
        • 데이터과학 Data Science (7)
        • 인공지능 Artificial Intelligent (40)
        • 제어 Control (45)
      • 프로그래밍 Programming (103)
        • 매트랩 MATLAB (25)
        • 파이썬 Python (33)
        • 줄리아 Julia (2)
        • C++ (3)
        • 리눅스 우분투 Ubuntu (6)
      • 항공우주 Aeronautical engineeri.. (21)
        • 항법 Navigation (0)
        • 유도 Guidance (0)
      • 기계공학 Mechanical engineering (13)
        • 열역학 Thermodynamics (0)
        • 고체역학 Statics & Solid mechan.. (10)
        • 동역학 Dynamics (1)
        • 유체역학 Fluid Dynamics (0)
      • 수학 Mathematics (34)
        • 선형대수학 Linear Algebra (18)
        • 미분방정식 Differential Equation (3)
        • 확률및통계 Probability & Sta.. (2)
        • 미적분학 Calculus (1)
        • 복소해석학 Complex Analysis (5)
        • 실해석학 Real Analysis (0)
      • 수치해석 Numerical Analysis (21)
      • 확률 및 랜덤프로세스 Random process (2)
      • 추론 & 추정 이론 Estimation (3)
      • 기타 (26)
        • 설계 프로젝트 System Design (8)
        • 논문작성 Writing (55)
        • 세미나 Seminar (2)
        • 생산성 Productivity (3)
      • 유학 생활 Daily (7)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    텝스
    우분투
    텝스공부
    Linear algebra
    JAX
    논문작성법
    서버
    에러기록
    Statics
    Zotero
    Numerical Analysis
    Python
    고체역학
    Dear abby
    ChatGPT
    수치해석
    IEEE
    WOX
    MATLAB
    pytorch
    matplotlib
    논문작성
    LaTeX
    옵시디언
    인공지능
    teps
    딥러닝
    obsidian
    생산성
    Julia
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
보통의공대생
[JAX] 기본 Neural Networks 모델
상단으로

티스토리툴바