JAX를 통해 병렬로 뉴럴 네트워크를 학습하는 예제를 고민하였다.
JAX 공식 문서에도 병렬 학습 예제가 있지만, 아주 단순한 선형 모델의 파라미터를 회귀(regression)로 추정하는 문제라서 실제 뉴럴 네트워크 학습과는 차이가 있다. 그래서 직접 예제를 만들었다.
JAX 버전: 0.3.1
1. JAX에서 제공하는 예제
import jax
# Query the devices visible to JAX (the post's machine lists four GpuDevice entries).
jax.devices()
>> [GpuDevice(id=0, process_index=0),
GpuDevice(id=1, process_index=0),
GpuDevice(id=2, process_index=0),
GpuDevice(id=3, process_index=0)]
필자는 GPU 4개로 병렬 연산을 수행했다.
import numpy as np
import jax.numpy as jnp
from typing import NamedTuple, Tuple
import functools
# Parameters of the scalar linear model y = weight * x + bias, declared with
# the functional NamedTuple form; fields are (weight, bias) in that order.
Params = NamedTuple("Params", [("weight", jnp.ndarray), ("bias", jnp.ndarray)])
def init(rng) -> Params:
    """Return freshly initialised linear-model parameters.

    Args:
        rng: a JAX PRNG key; it is split into one key for the weight and one
            for the bias.

    Returns:
        ``Params`` whose ``weight`` and ``bias`` are scalar (shape ``()``)
        draws from a standard normal distribution.
    """
    w_key, b_key = jax.random.split(rng)
    # Keyword arguments are evaluated left to right: weight first, then bias.
    return Params(
        weight=jax.random.normal(w_key, ()),
        bias=jax.random.normal(b_key, ()),
    )
# Step size used by the plain SGD update (param - grad * LEARNING_RATE).
LEARNING_RATE = 5e-3
def loss_fn(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> jnp.ndarray:
    """Mean squared error of the linear model ``weight * xs + bias`` against ``ys``."""
    residuals = params.weight * xs + params.bias - ys
    return jnp.mean(residuals ** 2)
# One data-parallel SGD step (forward & backward pass); pmap runs a copy of
# this function on every local device.
@functools.partial(jax.pmap, axis_name='num_devices')
def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> Tuple[Params, jnp.ndarray]:
    """Run one SGD step on this device's shard and synchronise across devices.

    Args:
        params: this device's replica of the model parameters.
        xs: this device's shard of the inputs.
        ys: this device's shard of the targets.

    Returns:
        ``(new_params, loss)``. Gradients and loss are averaged over the
        ``'num_devices'`` axis with ``pmean``, so replicas that start out
        identical stay identical.
    """
    loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)
    # Average across devices so every replica applies the same gradient.
    grads = jax.lax.pmean(grads, axis_name='num_devices')
    loss = jax.lax.pmean(loss, axis_name='num_devices')
    # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias
    # (removed in recent JAX releases); it is also available in JAX 0.3.x.
    new_params = jax.tree_util.tree_map(
        lambda param, g: param - g * LEARNING_RATE, params, grads)
    return new_params, loss
위 코드에서는 파라미터를 초기화하는 함수, 손실 함수, 파라미터를 업데이트하는 함수를 각각 정의한다.
# Generate true data from y = w*x + b + noise
true_w, true_b = 2, -1
xs = np.random.normal(size=(128, 1))
noise = 0.5 * np.random.normal(size=(128, 1))
ys = xs * true_w + true_b + noise
# Initialise parameters and replicate across devices.
params = init(jax.random.PRNGKey(123))
n_devices = jax.local_device_count()
# Stack one copy per local device along a new leading axis (the axis pmap
# maps over). jax.tree_util.tree_map replaces the deprecated jax.tree_map.
replicated_params = jax.tree_util.tree_map(lambda x: jnp.array([x] * n_devices), params)
추정하고자 하는 실제 파라미터 true_w, true_b는 각각 2와 -1로 설정되어 있다.
이를 바탕으로 학습에 사용할 xs, ys를 생성한다.
추정할 파라미터(w, b)를 초기화한 뒤, 디바이스 개수만큼 복제하여 replicated_params에 넣는다.
def split(arr, num_shards=None):
    """Split the first axis of ``arr`` evenly into one shard per device.

    Args:
        arr: array whose first axis indexes the examples.
        num_shards: number of shards to create; defaults to the module-level
            ``n_devices``.

    Returns:
        ``arr`` reshaped to ``(num_shards, arr.shape[0] // num_shards, *arr.shape[1:])``.

    Raises:
        ValueError: if ``num_shards`` is less than 1 or ``arr.shape[0]`` is
            not divisible by it (pmap needs equally sized shards).
    """
    if num_shards is None:
        num_shards = n_devices
    num_rows = arr.shape[0]
    # Fail with a clear message instead of an opaque reshape error.
    if num_shards < 1 or num_rows % num_shards != 0:
        raise ValueError(
            f"cannot split {num_rows} rows evenly across {num_shards} devices")
    return arr.reshape(num_shards, num_rows // num_shards, *arr.shape[1:])
# Shard the training data along its first axis (one shard per device) so it
# matches the pmapped `update()`; then display the resulting container type.
x_split, y_split = split(xs), split(ys)
type(x_split)
학습할 데이터는 128개의 (xs, ys) 쌍인데, 이를 병렬로 사용하기 위해 (4, 32, 1) shape으로 바꾼다(디바이스가 4개이므로).
def type_after_update(name, obj):
    """Print the Python type of ``obj`` (labelled ``name``) after the first update step."""
    obj_type = type(obj)
    print(f"after first `update()`, `{name}` is a", obj_type)
# Actual training loop.
for i in range(1000):
    # This is where the params and data get communicated to the devices.
    replicated_params, loss = update(replicated_params, x_split, y_split)
    # The returned `replicated_params` and `loss` now live on the devices
    # (ShardedDeviceArray in JAX 0.3.x, as printed in the post; newer
    # releases may report jax.Array instead). `x_split` stays a host NumPy array.
    if i == 0:
        type_after_update('replicated_params.weight', replicated_params.weight)
        type_after_update('loss', loss)
        type_after_update('x_split', x_split)
    if i % 100 == 0:
        # `loss` has shape [num_devices] with identical entries (it is
        # pmean-ed inside `update`), so the first element is representative.
        print(f"Step {i:3d}, loss: {loss[0]:.3f}")
# Plot results.
# Like the loss, every leaf of the params carries an extra leading device
# axis, so take the first device's copy and bring it back to the host.
# jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], replicated_params))
그다음 업데이트 함수를 반복 호출하여 파라미터를 학습하면 끝이다.
아래는 학습에서 출력된 결과이다.
after first `update()`, `replicated_params.weight` is a <class 'jaxlib.xla_extension.pmap_lib.ShardedDeviceArray'>
after first `update()`, `loss` is a <class 'jaxlib.xla_extension.pmap_lib.ShardedDeviceArray'>
after first `update()`, `x_split` is a <class 'numpy.ndarray'>
Step 0, loss: 5.860
Step 100, loss: 0.837
Step 200, loss: 0.280
Step 300, loss: 0.218
Step 400, loss: 0.212
Step 500, loss: 0.211
Step 600, loss: 0.211
Step 700, loss: 0.211
Step 800, loss: 0.211
Step 900, loss: 0.211
결과를 그려 보면 다음과 같이 w, b를 비교적 정확하게 추정했음을 알 수 있다.
import matplotlib.pyplot as plt
# Scatter the training data and overlay the learned line in red.
plt.scatter(xs, ys)
plt.plot(xs, params.bias + params.weight * xs, label='Model Prediction', c='red')
plt.legend()
plt.show()
여기까지 하면 병렬로 한 것은 알겠는데 실제 뉴럴 네트워크를 학습하려면 어떻게 해야할지 막막할 수 있다.
2. 내가 만든 뉴럴 네트워크 예제
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, pmap
from jax import random
from jax.example_libraries import optimizers
from typing import NamedTuple, Tuple
import functools
# Named (weight, bias) parameter container, declared with the functional
# NamedTuple form; fields appear in that order.
Params = NamedTuple("Params", [("weight", jnp.ndarray), ("bias", jnp.ndarray)])
# Define the neural network
def relu(x):
    """Element-wise rectified linear unit: ``max(x, 0)``."""
    return jnp.maximum(x, 0)
def predict(params, inputs):
    """Apply one dense layer followed by ReLU: ``relu(dot(inputs, w) + b)``.

    ``params`` must unpack into exactly two arrays ``(w, b)``. The shapes of
    ``w`` and ``b`` are printed as a debugging aid; under ``jax.pmap`` this
    happens while tracing (the post's output shows them printed once).
    """
    kernel, offset = params
    for tensor in (kernel, offset):
        print(tensor.shape)
    return relu(jnp.dot(inputs, kernel) + offset)
def loss_fn(params, inputs, targets):
    """Mean squared error between ``predict(params, inputs)`` and ``targets``."""
    error = predict(params, inputs) - targets
    return jnp.mean(error ** 2)
def init_params(layer_sizes, key):
    """Build parameters for a stack of dense layers.

    Args:
        layer_sizes: sequence of layer widths, e.g. ``[n_in, n_hidden, n_out]``.
        key: JAX PRNG key; it is split once per layer.

    Returns:
        A flat list ``[w1, b1, w2, b2, ...]``. Each ``w`` is drawn from an
        (unscaled) standard normal with shape ``(fan_in, fan_out)`` and each
        ``b`` is a zero vector of length ``fan_out``. Fewer than two sizes
        yields an empty list.
    """
    flat_params = []
    for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        key, layer_key = random.split(key)
        flat_params += [random.normal(layer_key, (fan_in, fan_out)), jnp.zeros(fan_out)]
    return flat_params
def initialize_params(input_size, output_size, key):
    """Create ``[w, b]`` for a single dense layer.

    Args:
        input_size: number of input features.
        output_size: number of output units.
        key: JAX PRNG key; split into separate keys for ``w`` and ``b``.

    Returns:
        A list ``[w, b]`` with ``w`` of shape ``(input_size, output_size)``
        and ``b`` of shape ``(output_size,)``, both standard-normal draws.
    """
    weight_key, bias_key = random.split(key)
    weight = random.normal(weight_key, (input_size, output_size))
    bias = random.normal(bias_key, (output_size,))
    return [weight, bias]
# Initialize devices
num_devices = jax.local_device_count()
devices = jax.local_devices()
print(f'Using {num_devices} devices')
# Prepare data
input_size, output_size = 10, 1
batch_size = 64
num_batches = 1000
keys = random.split(random.PRNGKey(0), num_devices)
# The layer's output width must be `output_size` so predictions match the
# (..., output_size) targets. Passing `num_devices` here produced a (10, 4)
# weight whose 4 outputs were silently broadcast against 1-wide targets.
params = initialize_params(input_size, output_size, keys[0])
# Generate some dummy data. The leading axis has size num_devices so that
# pmap hands each device its own (batch_size, ...) shard.
inputs = random.normal(random.PRNGKey(1), (num_devices, batch_size, input_size))
targets = random.normal(random.PRNGKey(2), (num_devices, batch_size, output_size))
# Initialize the Adam optimizer
learning_rate = 0.001
init_fun, update_fun, get_params = optimizers.adam(learning_rate)
# NOTE(review): this opt_state is never passed to `update` by the training
# loop below.
opt_state = init_fun(params)
# One copy of the parameters per device, stacked on a new leading axis.
# jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
replicated_params = jax.tree_util.tree_map(lambda x: jnp.array([x] * num_devices), params)
# Define the update function.
# axis_size is not hard-coded: pmap infers it from the inputs' leading axis,
# so this runs on any number of local devices (previously fixed to 4).
@functools.partial(jax.pmap, axis_name='num_devices')
def update(params, xs: jnp.ndarray, ys: jnp.ndarray) -> Tuple[list, jnp.ndarray]:
    """One data-parallel Adam step on each local device.

    Args:
        params: this device's replica of ``[w, b]``.
        xs: this device's shard of the inputs.
        ys: this device's shard of the targets.

    Returns:
        ``([w, b], loss)``. Gradients and loss are averaged over the
        ``'num_devices'`` axis with ``pmean`` before the optimizer step.

    NOTE(review): the Adam state is rebuilt from ``params`` on every call and
    the step index is always 0. Adam's moment estimates therefore never
    accumulate between iterations, and each call behaves like the first step
    of a freshly started Adam. For real Adam, create the optimizer state once,
    replicate it, and thread it (plus the step count) through this function.
    """
    # Compute the gradients on the given minibatch (individually on each device).
    loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)
    grads = jax.lax.pmean(grads, axis_name='num_devices')
    loss = jax.lax.pmean(loss, axis_name='num_devices')
    opt_state = init_fun(params)
    updated_opt_state = update_fun(0, grads, opt_state)
    updated = get_params(updated_opt_state)
    # Keep the [w, b] list structure of the incoming params.
    params_new = [updated[0], updated[1]]
    print("update : ", params_new[0].shape)
    return params_new, loss
num_iter = 3
# Parallel training: each `update` call runs one step on every device.
for i in range(num_iter):
    replicated_params, loss = update(replicated_params, inputs, targets)
    # Leading axis of each parameter is the device axis (num_devices copies).
    print("replication : ", replicated_params[0].shape)
    # `loss` holds one (pmean-averaged, hence identical) value per device.
    print(f"iteration : {i}, loss : {loss}")
print("Training complete.")
출력 결과
(10, 4) #<- prediction에 있는 weight
(4,) #<- prediction에 있는 bias
update : (10, 4)
new params : (4, 10, 4)
replication : (4, 10, 4)
iteration : 0, loss : [5.9947453 5.9947453 5.9947453 5.9947453]
replication : (4, 10, 4)
iteration : 1, loss : [5.983531 5.983531 5.983531 5.983531]
replication : (4, 10, 4)
iteration : 2, loss : [5.972336 5.972336 5.972336 5.972336]
Training complete.
'연구 Research > 인공지능 Artificial Intelligent' 카테고리의 다른 글
[JAX] JAX 기반 Neural ODE 라이브러리 : diffrax (0) | 2023.07.28 |
---|---|
[JAX] 학습한 모델 저장 및 로드 (0) | 2023.06.19 |
[인공지능] 딥러닝, 머신러닝에서 uncertainty/error 개념 (0) | 2023.05.16 |
[JAX] 메모리 부족 문제 해결 (1) | 2023.04.26 |
[인공지능] Ubuntu 18.04에서 CUDA, CuDNN 설치 (1) | 2023.04.16 |