[JAX] Basic Neural Network Models
This post records the simplest way to build a neural network in JAX.
It will be updated over time.
1. Basic training code
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

# Define a simple neural network model
def init_params(layer_sizes, key):
    params = []
    for i in range(1, len(layer_sizes)):
        key, subkey = random.split(key)
        w = random.normal(subkey, (layer_sizes[i-1], layer_sizes[i]))
        b = jnp.zeros(layer_sizes[i])
        params.extend([w, b])
    return params

def relu(x):
    return jnp.maximum(0, x)

def predict(params, inputs):
    for i in range(0, len(params)-1, 2):
        w, b = params[i], params[i+1]
        outputs = jnp.dot(inputs, w) + b
        inputs = relu(outputs)
    return outputs

# Define a loss function
def mean_squared_error(params, inputs, targets):
    predictions = vmap(predict, in_axes=(None, 0))(params, inputs)
    return jnp.mean((predictions - targets)**2)

# Define a function to update the parameters using gradient descent
@jit
def update(params, inputs, targets, learning_rate):
    grads = grad(mean_squared_error)(params, inputs, targets)
    return [param - learning_rate * g for param, g in zip(params, grads)]

def target_func(inputs):
    targets = jnp.heaviside(inputs, 0.0)
    return targets

# Generate some training data
key = random.PRNGKey(0)
inputs = random.normal(key, (100, 1))
targets = target_func(inputs)

# Initialize the parameters
layer_sizes = [1, 20, 1]
params = init_params(layer_sizes, key)

# Train the network
learning_rate = 0.1
for i in range(1000):
    params = update(params, inputs, targets, learning_rate)
    loss = mean_squared_error(params, inputs, targets)
    if (i % 100) == 0:
        print(f"iteration {i+1}, loss {loss}")
Here, the gradient is computed directly with grad and the parameters are updated by plain gradient descent.
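As a quick sanity check (not from the original post), the trained network can be evaluated on a fresh grid and compared against the heaviside target. This sketch assumes the script above has just been run, so params, predict, target_func, vmap, and jnp are already in scope; x_test and y_pred are purely illustrative names.

x_test = jnp.linspace(-3.0, 3.0, 50).reshape(-1, 1)          # 50 test points, shape (50, 1)
y_pred = vmap(predict, in_axes=(None, 0))(params, x_test)    # same batching as in the loss
y_true = target_func(x_test)
print("max abs error on the test grid:", jnp.max(jnp.abs(y_pred - y_true)))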
JAX also ships an optimizer module, experimental.optimizers, so the next example uses it instead.
2. Using an optimizer built into JAX
import jax
import jax.numpy as jnp
from jax import random
from jax.experimental import optimizers

# Define a simple neural network model
def init_params(layer_sizes, key):
    params = []
    for i in range(1, len(layer_sizes)):
        key, subkey = random.split(key)
        w = random.normal(subkey, (layer_sizes[i-1], layer_sizes[i]))
        b = jnp.zeros(layer_sizes[i])
        params.extend([w, b])
    return params

def relu(x):
    return jnp.maximum(0, x)

def predict(params, inputs):
    for i in range(0, len(params)-1, 2):
        w, b = params[i], params[i+1]
        outputs = jnp.dot(inputs, w) + b
        inputs = relu(outputs)
    return outputs

# Define a loss function
def mean_squared_error(params, inputs, targets):
    predictions = predict(params, inputs)
    return jnp.mean((predictions - targets)**2)

# Define a function to calculate the gradients of the loss function
@jax.jit
def grad_fun(params, inputs, targets):
    return jax.grad(mean_squared_error)(params, inputs, targets)

# Generate some training data
key = random.PRNGKey(0)
inputs = random.normal(key, (100, 10))
targets = random.normal(key, (100, 1))

# Initialize the parameters
layer_sizes = [10, 20, 1]
params = init_params(layer_sizes, key)

# Create an optimizer
opt_init, opt_update, get_params = optimizers.adam(learning_rate=0.1)
opt_state = opt_init(params)

# Define a function to update the parameters using the optimizer
@jax.jit
def update(i, opt_state, inputs, targets):
    params = get_params(opt_state)
    grads = grad_fun(params, inputs, targets)
    return opt_update(i, grads, opt_state)

# Train the network
num_iterations = 100
for i in range(num_iterations):
    opt_state = update(i, opt_state, inputs, targets)
    params = get_params(opt_state)
    loss = mean_squared_error(params, inputs, targets)
    print(f"iteration {i+1}, loss {loss}")
Update (as of JAX 0.4.14)
After this update, experimental.optimizers no longer exists. The same optimizers module now lives in jax.example_libraries, which is what the code below uses, although JAX recommends switching to a separate library such as Optax; a short Optax sketch follows after the code below.
import jax
import jax.numpy as jnp
from jax import random
from jax.example_libraries import optimizers

def target_func(y):
    sigma = 0.15
    diff = sigma * jnp.sqrt(y)
    return diff

# Define a simple neural network model
def init_params(layer_sizes, key):
    params = []
    for i in range(1, len(layer_sizes)):
        key, subkey = random.split(key)
        w = random.normal(subkey, (layer_sizes[i-1], layer_sizes[i]))
        b = jnp.zeros(layer_sizes[i])
        params.extend([w, b])
    return params

def relu(x):
    return jnp.maximum(0, x)

def predict(params, inputs):
    for i in range(0, len(params)-1, 2):
        w, b = params[i], params[i+1]
        outputs = jnp.dot(inputs, w) + b
        inputs = jnp.tanh(outputs)
    return outputs

# Define a loss function
def mean_squared_error(params, inputs, targets):
    predictions = predict(params, inputs)
    return jnp.mean((predictions - targets)**2)

# Define a function to calculate the gradients of the loss function
@jax.jit
def grad_fun(params, inputs, targets):
    return jax.grad(mean_squared_error)(params, inputs, targets)

# Generate some training data
key = random.PRNGKey(0)
inputs = random.uniform(key, (100, 1), minval=0.0, maxval=1.0)
targets = jax.vmap(target_func)(inputs)

# Initialize the parameters
layer_sizes = [1, 64, 64, 1]
params = init_params(layer_sizes, key)

# Create an optimizer
opt_init, opt_update, get_params = optimizers.adam(0.001)
opt_state = opt_init(params)

# Define a function to update the parameters using the optimizer
@jax.jit
def update(i, opt_state, inputs, targets):
    params = get_params(opt_state)
    grads = grad_fun(params, inputs, targets)
    return opt_update(i, grads, opt_state)

# Train the network
num_iterations = 5000
for i in range(num_iterations):
    opt_state = update(i, opt_state, inputs, targets)
    params = get_params(opt_state)
    if i % 100 == 0:
        loss = mean_squared_error(params, inputs, targets)
        print(f"iteration {i+1}, loss {loss}")