[인공지능] Learning에서 scaling이 중요한가

2023. 9. 20. 20:30·연구 Research/인공지능 Artificial Intelligent

 

인공지능을 하다보면 경험적으로 알게 되는 것들이 있는데 그 중 하나가 scaling의 문제이다.

간단한 regression 문제를 풀어보자.

 

import jax
import jax.numpy as jnp
from jax import random, vmap
from jax.example_libraries import optimizers

def diffusion(t, y, args):
    sigma = 0.15
    diff = sigma * jnp.sqrt(y)
    return diff

# Define a simple neural network model
def init_params(layer_sizes, key):
    params = []
    for i in range(1, len(layer_sizes)):
        key, subkey = random.split(key)
        w = random.normal(subkey, (layer_sizes[i-1], layer_sizes[i]))
        b = jnp.zeros(layer_sizes[i])
        params.extend([w, b])
    return params

def relu(x):
    return jnp.maximum(0, x)

def predict(params, inputs):
    for i in range(0, len(params)-1, 2):
        w, b = params[i], params[i+1]
        outputs = jnp.dot(inputs, w) + b
        inputs = jax.nn.relu(outputs)
    return outputs

# Define a loss function
def mean_squared_error(params, inputs, targets):
    predictions = predict(params, inputs)
    return jnp.mean((predictions - targets)**2)

# Define a function to calculate the gradients of the loss function
@jax.jit
def grad_fun(params, inputs, targets):
    return jax.grad(mean_squared_error)(params, inputs, targets)

# Generate some training data
key = random.PRNGKey(0)

inputs = random.uniform(key, (100, 1), minval=0.0, maxval=1.0)
targets = jax.vmap(diffusion)([],inputs,[])

inputs = (inputs  - jnp.mean(inputs)) / (jnp.std(inputs))
# targets = (targets - jnp.min(targets)) / (jnp.max(targets)-jnp.min(targets))
targets = (targets - jnp.mean(targets)) / (jnp.std(targets))

# Initialize the parameters
layer_sizes = [1, 64, 64,64, 1]
params = init_params(layer_sizes, key)

# Create an optimizer
opt_init, opt_update, get_params = optimizers.adam(0.001)
opt_state = opt_init(params)

# Define a function to update the parameters using the optimizer
@jax.jit
def update(i, opt_state, inputs, targets):
    params = get_params(opt_state)
    grads = grad_fun(params, inputs, targets)
    return opt_update(i, grads, opt_state)

# Train the network
num_iterations = 2000
for i in range(num_iterations):
    opt_state = update(i, opt_state, inputs, targets)
    params = get_params(opt_state)
    if i % 100 == 0:
        loss = mean_squared_error(params, inputs, targets)
        print(f"iteration {i+1}, loss {loss}")

 

코드는 길지만 Neural network를 정의하고 그걸 예측하는 문제를 만드는 것이다.

case는 다음과 같다.

 

input는 0에서 1 사이로 랜덤 샘플링한 값이고 regression하고자 하는 함수는 정해져있다.

 

1) input 값 그대로, target 값을 scaling없이 그대로 사용

2) input값 그대로, target 값을 min-max scaling 적용

3) input값, target값 모두 standardization scaling 적용

 

target값은 0에서 1 사이의 값으로 되어있기 때문에 scaling을 적용하지 않아도 괜찮다고 생각했다.

하지만 같은 네트워크 구조에 같은 iteration이어도 결과가 확연히 차이난다.

 

1)  target 값을 scaling없이 그대로 사용

 

빠르게 수렴하는 것이 불가능하다.

 

2) input값 그대로, target 값을 min-max scaling 적용

 

앞선 경우에 비해 훨씬 잘 수렴한다.

 

3) input값, target값 모두 standardization scaling 적용

 


 

다변수 함수를 regression하려고 하는 경우라면 정규화는 기본적으로 필요하다. 하지만 변수가 한 개밖에 없는 함수여도 scaling을 어떻게 하느냐에 따라 수렴에 영향을 미칠 수 있다.

저작자표시 비영리 변경금지 (새창열림)

'연구 Research > 인공지능 Artificial Intelligent' 카테고리의 다른 글

[JAX] L-BFGS optimizer로 학습하는 예제 코드  (0) 2023.11.02
[Deep learning] Bayesian Neural Network (1)  (0) 2023.10.23
[JAX] vmap과 jit의 속도  (0) 2023.09.20
[인공지능] 인공지능 라이브러리 정리  (0) 2023.08.24
[JAX] JAX에서 gradient 추척을 멈추는 방법  (0) 2023.08.22
'연구 Research/인공지능 Artificial Intelligent' 카테고리의 다른 글
  • [JAX] L-BFGS optimizer로 학습하는 예제 코드
  • [Deep learning] Bayesian Neural Network (1)
  • [JAX] vmap과 jit의 속도
  • [인공지능] 인공지능 라이브러리 정리
보통의공대생
보통의공대생
수학,프로그래밍,기계항공우주 등 공부하는 기록들을 남깁니다.
  • 보통의공대생
    뛰는 놈 위에 나는 공대생
    보통의공대생
  • 전체
    오늘
    어제
    • 분류 전체보기 (460)
      • 공지 (1)
      • 영어 공부 English Study (40)
        • 텝스 TEPS (7)
        • 글 Article (21)
        • 영상 Video (10)
      • 연구 Research (99)
        • 최적화 Optimization (3)
        • 데이터과학 Data Science (7)
        • 인공지능 Artificial Intelligent (40)
        • 제어 Control (45)
      • 프로그래밍 Programming (103)
        • 매트랩 MATLAB (25)
        • 파이썬 Python (33)
        • 줄리아 Julia (2)
        • C++ (3)
        • 리눅스 우분투 Ubuntu (6)
      • 항공우주 Aeronautical engineeri.. (21)
        • 항법 Navigation (0)
        • 유도 Guidance (0)
      • 기계공학 Mechanical engineering (13)
        • 열역학 Thermodynamics (0)
        • 고체역학 Statics & Solid mechan.. (10)
        • 동역학 Dynamics (1)
        • 유체역학 Fluid Dynamics (0)
      • 수학 Mathematics (34)
        • 선형대수학 Linear Algebra (18)
        • 미분방정식 Differential Equation (3)
        • 확률및통계 Probability & Sta.. (2)
        • 미적분학 Calculus (1)
        • 복소해석학 Complex Analysis (5)
        • 실해석학 Real Analysis (0)
      • 수치해석 Numerical Analysis (21)
      • 확률 및 랜덤프로세스 Random process (2)
      • 추론 & 추정 이론 Estimation (3)
      • 기타 (26)
        • 설계 프로젝트 System Design (8)
        • 논문작성 Writing (55)
        • 세미나 Seminar (2)
        • 생산성 Productivity (3)
      • 유학 생활 Daily (7)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    수치해석
    우분투
    인공지능
    ChatGPT
    옵시디언
    Statics
    Dear abby
    논문작성법
    텝스
    obsidian
    생산성
    딥러닝
    Python
    텝스공부
    MATLAB
    Numerical Analysis
    LaTeX
    matplotlib
    에러기록
    WOX
    teps
    논문작성
    pytorch
    Julia
    JAX
    Zotero
    고체역학
    서버
    IEEE
    Linear algebra
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
보통의공대생
[인공지능] Learning에서 scaling이 중요한가
상단으로

티스토리툴바