Notice
Recent Posts
Recent Comments
Link
관리 메뉴

뛰는 놈 위에 나는 공대생

[인공지능] Learning에서 scaling이 중요한가 본문

연구 Research/인공지능 Artificial Intelligent

[인공지능] Learning에서 scaling이 중요한가

보통의공대생 2023. 9. 20. 20:30

 

인공지능을 하다보면 경험적으로 알게 되는 것들이 있는데 그 중 하나가 scaling의 문제이다.

간단한 regression 문제를 풀어보자.

 

import jax
import jax.numpy as jnp
from jax import random, vmap
from jax.example_libraries import optimizers

def diffusion(t, y, args):
    sigma = 0.15
    diff = sigma * jnp.sqrt(y)
    return diff

# Define a simple neural network model
def init_params(layer_sizes, key):
    params = []
    for i in range(1, len(layer_sizes)):
        key, subkey = random.split(key)
        w = random.normal(subkey, (layer_sizes[i-1], layer_sizes[i]))
        b = jnp.zeros(layer_sizes[i])
        params.extend([w, b])
    return params

def relu(x):
    return jnp.maximum(0, x)

def predict(params, inputs):
    for i in range(0, len(params)-1, 2):
        w, b = params[i], params[i+1]
        outputs = jnp.dot(inputs, w) + b
        inputs = jax.nn.relu(outputs)
    return outputs

# Define a loss function
def mean_squared_error(params, inputs, targets):
    predictions = predict(params, inputs)
    return jnp.mean((predictions - targets)**2)

# Define a function to calculate the gradients of the loss function
@jax.jit
def grad_fun(params, inputs, targets):
    return jax.grad(mean_squared_error)(params, inputs, targets)

# Generate some training data
key = random.PRNGKey(0)

inputs = random.uniform(key, (100, 1), minval=0.0, maxval=1.0)
targets = jax.vmap(diffusion)([],inputs,[])

inputs = (inputs  - jnp.mean(inputs)) / (jnp.std(inputs))
# targets = (targets - jnp.min(targets)) / (jnp.max(targets)-jnp.min(targets))
targets = (targets - jnp.mean(targets)) / (jnp.std(targets))

# Initialize the parameters
layer_sizes = [1, 64, 64,64, 1]
params = init_params(layer_sizes, key)

# Create an optimizer
opt_init, opt_update, get_params = optimizers.adam(0.001)
opt_state = opt_init(params)

# Define a function to update the parameters using the optimizer
@jax.jit
def update(i, opt_state, inputs, targets):
    params = get_params(opt_state)
    grads = grad_fun(params, inputs, targets)
    return opt_update(i, grads, opt_state)

# Train the network
num_iterations = 2000
for i in range(num_iterations):
    opt_state = update(i, opt_state, inputs, targets)
    params = get_params(opt_state)
    if i % 100 == 0:
        loss = mean_squared_error(params, inputs, targets)
        print(f"iteration {i+1}, loss {loss}")

 

코드는 길지만 Neural network를 정의하고 그걸 예측하는 문제를 만드는 것이다.

case는 다음과 같다.

 

input는 0에서 1 사이로 랜덤 샘플링한 값이고 regression하고자 하는 함수는 정해져있다.

 

1) input 값 그대로, target 값을 scaling없이 그대로 사용

2) input값 그대로, target 값을 min-max scaling 적용

3) input값, target값 모두 standardization scaling 적용

 

target값은 0에서 1 사이의 값으로 되어있기 때문에 scaling을 적용하지 않아도 괜찮다고 생각했다.

하지만 같은 네트워크 구조에 같은 iteration이어도 결과가 확연히 차이난다.

 

1)  target 값을 scaling없이 그대로 사용

 

빠르게 수렴하는 것이 불가능하다.

 

2) input값 그대로, target 값을 min-max scaling 적용

 

앞선 경우에 비해 훨씬 잘 수렴한다.

 

3) input값, target값 모두 standardization scaling 적용

 


 

다변수 함수를 regression하려고 하는 경우라면 정규화는 기본적으로 필요하다. 하지만 변수가 한 개밖에 없는 함수여도 scaling을 어떻게 하느냐에 따라 수렴에 영향을 미칠 수 있다.

Comments