[JAX] 병렬컴퓨팅 예제 - jax.pmap으로 신경망 학습 예제

2023. 6. 13. 20:00·연구 Research/인공지능 Artificial Intelligent

JAX를 통해 병렬로 뉴럴 네트워크를 학습하는 예제를 고민하였다.

JAX에서 제공해주는 예제도 있지만 이는 아주 심플한 선형 모델의 파라미터를 regression하는 문제이기 때문에 실제 뉴럴 네트워크 모델과는 괴리가 좀 있어서 직접 예제를 만들었다.

 

JAX 0.3.1. 버전

 

1. JAX에서 제공하는 예제

 

import jax
jax.devices()

>> [GpuDevice(id=0, process_index=0),
 GpuDevice(id=1, process_index=0),
 GpuDevice(id=2, process_index=0),
 GpuDevice(id=3, process_index=0)]

 

필자는 gpu 4개를 가지고 병렬 컴퓨팅을 사용했다.

 

import numpy as np
import jax.numpy as jnp

from typing import NamedTuple, Tuple
import functools
 
# class for storing model parameters
class Params(NamedTuple):
    weight: jnp.ndarray
    bias: jnp.ndarray
    
# function for initializing model parameters
def init(rng) -> Params:
    """Returns the initial model params."""
    weights_key, bias_key = jax.random.split(rng)
    weight = jax.random.normal(weights_key, ())
    bias = jax.random.normal(bias_key, ())
    return Params(weight, bias)

LEARNING_RATE = 0.005

# function for computing the MSE loss
def loss_fn(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> jnp.ndarray:
    """Computes the least squares error of the model's predictions on x against y."""
    pred = params.weight * xs + params.bias
    return jnp.mean((pred - ys) ** 2)
 
# function for performing one SGD update step (fwd & bwd pass)
@functools.partial(jax.pmap, axis_name='num_devices')
def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> Tuple[Params, jnp.ndarray]:
    loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)
    grads = jax.lax.pmean(grads, axis_name='num_devices')
    loss = jax.lax.pmean(loss, axis_name='num_devices')
    new_params = jax.tree_map(
        lambda param, g: param - g * LEARNING_RATE, params, grads)
    return new_params, loss

위 코드에서는 파라미터를 초기화하는 함수, 손실 함수, 파라미터를 업데이트하는 함수를 각각 정의한다.

 

# Generate true data from y = w*x + b + noise
true_w, true_b = 2, -1
xs = np.random.normal(size=(128, 1))
noise = 0.5 * np.random.normal(size=(128, 1))
ys = xs * true_w + true_b + noise

# Initialise parameters and replicate across devices.
params = init(jax.random.PRNGKey(123))
n_devices = jax.local_device_count()
replicated_params = jax.tree_map(lambda x: jnp.array([x] * n_devices), params)

추정하고자 하는 파라미터 true_w, true_b는 2와 -1로 설정되어있다.

이를 바탕으로 학습에 사용할 xs,ys를 얻는다.

추정하고자 하는 파라미터(w,b)를 일단 각 디바이스 갯수만큼 불러와서 replicated_params에 넣는다.

 

def split(arr):
    """Splits the first axis of `arr` evenly across the number of devices."""
    return arr.reshape(n_devices, arr.shape[0] // n_devices, *arr.shape[1:])

# Reshape xs and ys for the pmapped `update()`.
x_split = split(xs)
y_split = split(ys)

type(x_split)

 

학습할 데이터는 128개의 (xs,ys)쌍이었는데 이를 병렬로 사용하기 위해 (4,32,1) shape으로 바꾼다. (device 개수가 4개이므로)

 

def type_after_update(name, obj):
  print(f"after first `update()`, `{name}` is a", type(obj))

# Actual training loop.
for i in range(1000):
    # This is where the params and data gets communicated to devices:
    replicated_params, loss = update(replicated_params, x_split, y_split)
    
    # The returned `replicated_params` and `loss` are now both ShardedDeviceArrays,
    # indicating that they're on the devices.
    # `x_split`, of course, remains a NumPy array on the host.
    if i == 0:
        type_after_update('replicated_params.weight', replicated_params.weight)
        type_after_update('loss', loss)
        type_after_update('x_split', x_split)
        
    if i % 100 == 0:
        # Note that loss is actually an array of shape [num_devices], with identical
        # entries, because each device returns its copy of the loss.
        # So, we take the first element to print it.
        print(f"Step {i:3d}, loss: {loss[0]:.3f}")


# Plot results.

# Like the loss, the leaves of params have an extra leading dimension,
# so we take the params from the first device.
params = jax.device_get(jax.tree_map(lambda x: x[0], replicated_params))

그 다음에 업데이트 함수를 통해 파라미터를 구하면 끝.

 

아래는 학습에서 출력된 결과이다.

after first `update()`, `replicated_params.weight` is a <class 'jaxlib.xla_extension.pmap_lib.ShardedDeviceArray'>
after first `update()`, `loss` is a <class 'jaxlib.xla_extension.pmap_lib.ShardedDeviceArray'>
after first `update()`, `x_split` is a <class 'numpy.ndarray'>
Step   0, loss: 5.860
Step 100, loss: 0.837
Step 200, loss: 0.280
Step 300, loss: 0.218
Step 400, loss: 0.212
Step 500, loss: 0.211
Step 600, loss: 0.211
Step 700, loss: 0.211
Step 800, loss: 0.211
Step 900, loss: 0.211

 

 

결과를 출력하면 다음과 같이 비교적 정확한 w,b를 추정한 것을 알 수 있다.

 

import matplotlib.pyplot as plt
plt.scatter(xs, ys)
plt.plot(xs, params.weight * xs + params.bias, c='red', label='Model Prediction')
plt.legend()
plt.show()

 

 

여기까지 하면 병렬로 한 것은 알겠는데 실제 뉴럴 네트워크를 학습하려면 어떻게 해야할지 막막할 수 있다.

 

 

2. 내가 만든 뉴럴 네트워크 예제

 

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, pmap
from jax import random
from jax.example_libraries import optimizers
from typing import NamedTuple, Tuple
import functools
 
# class for storing model parameters
class Params(NamedTuple):
    weight: jnp.ndarray
    bias: jnp.ndarray

# Define the neural network
def relu(x):
    return jnp.maximum(0, x)

def predict(params, inputs):
    w, b = params
    print(w.shape)
    print(b.shape)
    return relu(jnp.dot(inputs, w) + b)

def loss_fn(params, inputs, targets):
    predictions = predict(params, inputs)
    return jnp.mean((predictions - targets) ** 2)

# Define a simple neural network model
def init_params(layer_sizes, key):
    params = []
    for i in range(1, len(layer_sizes)):
        key, subkey = random.split(key)
        w = random.normal(subkey, (layer_sizes[i-1], layer_sizes[i]))
        b = jnp.zeros(layer_sizes[i])
        params.extend([w, b])
    return params

# Initialize the parameters
def initialize_params(input_size, output_size, key):
    params = []
    w_key, b_key = random.split(key)
    w = random.normal(w_key, (input_size, output_size))
    b = random.normal(b_key, (output_size,))
    params.extend([w, b])
    return params

# Initialize devices
num_devices = jax.local_device_count()
devices = jax.local_devices()
print(f'Using {num_devices} devices')

# Prepare data
input_size, output_size = 10, 1
batch_size = 64
num_batches = 1000

keys = random.split(random.PRNGKey(0), num_devices)
params = initialize_params(input_size, num_devices, keys[0])

# Generate some dummy data
inputs = random.normal(random.PRNGKey(1), (num_devices, batch_size, input_size))
targets = random.normal(random.PRNGKey(2), (num_devices, batch_size, output_size))

# Initialize the Adam optimizer
learning_rate = 0.001
init_fun, update_fun, get_params = optimizers.adam(learning_rate)
opt_state = init_fun(params)
replicated_params = jax.tree_map(lambda x: jnp.array([x] * num_devices), params)
# Define the update function
@functools.partial(jax.pmap, axis_size=4,axis_name='num_devices')
def update(params, xs: jnp.ndarray, ys: jnp.ndarray) -> Tuple[Params, jnp.ndarray]:

    # Compute the gradients on the given minibatch (individually on each device).
    loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)
    
    grads = jax.lax.pmean(grads, axis_name='num_devices')
    loss = jax.lax.pmean(loss, axis_name='num_devices')

    opt_state = init_fun(params)
    update_opt_state= update_fun(0, grads, opt_state)
    update_params = get_params( update_opt_state )
    
    # new params
    params_new = []
    params_new.extend([update_params[0], update_params[1]])
    print("update : ", update_params[0].shape)
    new_params = jax.tree_map(lambda x: jnp.array([x] * num_devices), params_new)
    print("new params : ", new_params[0].shape)
    return params_new, loss

num_iter = 3
# Parallel training
for i in range(num_iter):

    replicated_params, loss = update(replicated_params, inputs, targets)
    print("replication : ", replicated_params[0].shape)
    print(f"iteration : {i}, loss : {loss}")
    
print("Training complete.")

 

 

출력 결과

(10, 4) #<- prediction에 있는 weight
(4,) #<- prediction에 있는 bias
update :  (10, 4)
new params :  (4, 10, 4)
replication :  (4, 10, 4)
iteration : 0, loss : [5.9947453 5.9947453 5.9947453 5.9947453]
replication :  (4, 10, 4)
iteration : 1, loss : [5.983531 5.983531 5.983531 5.983531]
replication :  (4, 10, 4)
iteration : 2, loss : [5.972336 5.972336 5.972336 5.972336]
Training complete.

 

저작자표시 비영리 변경금지

'연구 Research > 인공지능 Artificial Intelligent' 카테고리의 다른 글

[JAX] JAX 기반 Neural ODE 라이브러리 : diffrax  (0) 2023.07.28
[JAX] 학습한 모델 저장 및 로드  (0) 2023.06.19
[인공지능] 딥러닝, 머신러닝에서 uncertainty/error 개념  (0) 2023.05.16
[JAX] 메모리 부족 문제 해결  (1) 2023.04.26
[인공지능] Ubuntu 18.04에서 CUDA, CuDNN 설치  (1) 2023.04.16
'연구 Research/인공지능 Artificial Intelligent' 카테고리의 다른 글
  • [JAX] JAX 기반 Neural ODE 라이브러리 : diffrax
  • [JAX] 학습한 모델 저장 및 로드
  • [인공지능] 딥러닝, 머신러닝에서 uncertainty/error 개념
  • [JAX] 메모리 부족 문제 해결
보통의공대생
보통의공대생
수학,프로그래밍,기계항공우주 등 공부하는 기록들을 남깁니다.
  • 보통의공대생
    뛰는 놈 위에 나는 공대생
    보통의공대생
  • 전체
    오늘
    어제
    • 분류 전체보기 (458)
      • 공지 (1)
      • 영어 공부 English Study (40)
        • 텝스 TEPS (7)
        • 글 Article (21)
        • 영상 Video (10)
      • 연구 Research (99)
        • 최적화 Optimization (3)
        • 데이터과학 Data Science (7)
        • 인공지능 Artificial Intelligent (40)
        • 제어 Control (45)
      • 프로그래밍 Programming (103)
        • 매트랩 MATLAB (25)
        • 파이썬 Python (33)
        • 줄리아 Julia (2)
        • C++ (3)
        • 리눅스 우분투 Ubuntu (6)
      • 항공우주 Aeronautical engineeri.. (21)
        • 항법 Navigation (0)
        • 유도 Guidance (0)
      • 기계공학 Mechanical engineering (13)
        • 열역학 Thermodynamics (0)
        • 고체역학 Statics & Solid mechan.. (10)
        • 동역학 Dynamics (1)
        • 유체역학 Fluid Dynamics (0)
      • 수학 Mathematics (34)
        • 선형대수학 Linear Algebra (18)
        • 미분방정식 Differential Equation (3)
        • 확률및통계 Probability &amp; Sta.. (2)
        • 미적분학 Calculus (1)
        • 복소해석학 Complex Analysis (5)
        • 실해석학 Real Analysis (0)
      • 수치해석 Numerical Analysis (21)
      • 확률 및 랜덤프로세스 Random process (2)
      • 추론 & 추정 이론 Estimation (3)
      • 기타 (26)
        • 설계 프로젝트 System Design (8)
        • 논문작성 Writing (55)
        • 세미나 Seminar (2)
        • 생산성 Productivity (3)
      • 유학 생활 Daily (6)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    MATLAB
    딥러닝
    Statics
    Linear algebra
    수식삽입
    LaTeX
    논문작성법
    JAX
    obsidian
    Numerical Analysis
    Python
    옵시디언
    텝스공부
    에러기록
    Zotero
    고체역학
    teps
    ChatGPT
    matplotlib
    인공지능
    텝스
    Julia
    WOX
    pytorch
    수치해석
    우분투
    논문작성
    Dear abby
    생산성
    IEEE
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
보통의공대생
[JAX] 병렬컴퓨팅 예제 - jax.pmap으로 신경망 학습 예제
상단으로

티스토리툴바