


[JAX] Basic Neural Networks Model

보통의공대생 2023. 3. 23. 13:16

This post records the simplest way to build a neural network in JAX.

I plan to update it gradually.

 

1. Basic training code

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

# Define a simple neural network model
def init_params(layer_sizes, key):
    params = []
    for i in range(1, len(layer_sizes)):
        key, subkey = random.split(key)
        w = random.normal(subkey, (layer_sizes[i-1], layer_sizes[i]))
        b = jnp.zeros(layer_sizes[i])
        params.extend([w, b])
    return params

def relu(x):
    return jnp.maximum(0, x)

def predict(params, inputs):
    for i in range(0, len(params)-1, 2):
        w, b = params[i], params[i+1]
        outputs = jnp.dot(inputs, w) + b
        inputs = relu(outputs)
    return outputs

# Define a loss function
def mean_squared_error(params, inputs, targets):
    predictions = vmap(predict, in_axes=(None,0))(params, inputs)
    return jnp.mean((predictions - targets)**2)

# Define a function to update the parameters using gradient descent
@jit
def update(params, inputs, targets, learning_rate):
    grads = grad(mean_squared_error)(params, inputs, targets)
    return [param - learning_rate * g for param, g in zip(params, grads)]

def target_func(inputs):
    targets = jnp.heaviside(inputs, 0.0)
    return targets

# Generate some training data
key = random.PRNGKey(0)
inputs = random.normal(key, (100, 1))
targets = target_func(inputs)

# Initialize the parameters
layer_sizes = [1, 20, 1]
params = init_params(layer_sizes, key)

# Train the network
learning_rate = 0.1
for i in range(1000):
    params = update(params, inputs, targets, learning_rate)
    loss = mean_squared_error(params, inputs, targets)
    if (i % 100) == 0:
        print(f"iteration {i+1}, loss {loss}")

 

 

Here, the gradient is computed directly and plain gradient descent is applied based on it.

JAX also provides experimental.optimizers, so the next example uses that instead.

 

 

2. Using an optimizer provided by JAX

 

import jax
import jax.numpy as jnp
from jax import random, vmap
from jax.experimental import optimizers

# Define a simple neural network model
def init_params(layer_sizes, key):
    params = []
    for i in range(1, len(layer_sizes)):
        key, subkey = random.split(key)
        w = random.normal(subkey, (layer_sizes[i-1], layer_sizes[i]))
        b = jnp.zeros(layer_sizes[i])
        params.extend([w, b])
    return params

def relu(x):
    return jnp.maximum(0, x)

def predict(params, inputs):
    for i in range(0, len(params)-1, 2):
        w, b = params[i], params[i+1]
        outputs = jnp.dot(inputs, w) + b
        inputs = relu(outputs)
    return outputs

# Define a loss function
def mean_squared_error(params, inputs, targets):
    predictions = predict(params, inputs)
    return jnp.mean((predictions - targets)**2)

# Define a function to calculate the gradients of the loss function
@jax.jit
def grad_fun(params, inputs, targets):
    return jax.grad(mean_squared_error)(params, inputs, targets)

# Generate some training data
key = random.PRNGKey(0)
inputs = random.normal(key, (100, 10))
targets = random.normal(key, (100, 1))

# Initialize the parameters
layer_sizes = [10, 20, 1]
params = init_params(layer_sizes, key)

# Create an optimizer
opt_init, opt_update, get_params = optimizers.adam(learning_rate=0.1)
opt_state = opt_init(params)

# Define a function to update the parameters using the optimizer
@jax.jit
def update(i, opt_state, inputs, targets):
    params = get_params(opt_state)
    grads = grad_fun(params, inputs, targets)
    return opt_update(i, grads, opt_state)

# Train the network
num_iterations = 100
for i in range(num_iterations):
    opt_state = update(i, opt_state, inputs, targets)
    params = get_params(opt_state)
    loss = mean_squared_error(params, inputs, targets)
    print(f"iteration {i+1}, loss {loss}")

 

Update (as of JAX 0.4.14)

After this update, experimental.optimizers no longer exists.

Instead, I used the optimizers module from jax.example_libraries; note that JAX itself now recommends separate libraries such as optax (a sketch using optax is included after the code below).

 

import jax
import jax.numpy as jnp
from jax import random, vmap
from jax.example_libraries import optimizers

def target_func(y):
    sigma = 0.15
    diff = sigma * jnp.sqrt(y)
    return diff

# Define a simple neural network model
def init_params(layer_sizes, key):
    params = []
    for i in range(1, len(layer_sizes)):
        key, subkey = random.split(key)
        w = random.normal(subkey, (layer_sizes[i-1], layer_sizes[i]))
        b = jnp.zeros(layer_sizes[i])
        params.extend([w, b])
    return params

def relu(x):
    return jnp.maximum(0, x)

def predict(params, inputs):
    for i in range(0, len(params)-1, 2):
        w, b = params[i], params[i+1]
        outputs = jnp.dot(inputs, w) + b
        inputs = jnp.tanh(outputs)
    return outputs

# Define a loss function
def mean_squared_error(params, inputs, targets):
    predictions = predict(params, inputs)
    return jnp.mean((predictions - targets)**2)

# Define a function to calculate the gradients of the loss function
@jax.jit
def grad_fun(params, inputs, targets):
    return jax.grad(mean_squared_error)(params, inputs, targets)

# Generate some training data
key = random.PRNGKey(0)

inputs = random.uniform(key, (100, 1), minval=0.0, maxval=1.0)
targets = vmap(target_func)(inputs)

# Initialize the parameters
layer_sizes = [1, 64, 64, 1]
params = init_params(layer_sizes, key)

# Create an optimizer
opt_init, opt_update, get_params = optimizers.adam(0.001)
opt_state = opt_init(params)

# Define a function to update the parameters using the optimizer
@jax.jit
def update(i, opt_state, inputs, targets):
    params = get_params(opt_state)
    grads = grad_fun(params, inputs, targets)
    return opt_update(i, grads, opt_state)

# Train the network
num_iterations = 5000
for i in range(num_iterations):
    opt_state = update(i, opt_state, inputs, targets)
    params = get_params(opt_state)
    if i % 100 == 0:
        loss = mean_squared_error(params, inputs, targets)
        print(f"iteration {i+1}, loss {loss}")