[JAX] Gaussian process 파라미터에 따른 결과 visualization

2023. 3. 24. 15:17·연구 Research/인공지능 Artificial Intelligent

Gaussian process에서 사용하는 커널 종류는 다양할 수 있지만 여기서는 Radial Basis Fuction을 이용해서 gaussian process 샘플들을 구하고 이에 대한 관찰을 시각화하는 방법에 대해서 이야기한다.

 

RBF 함수는 Paris Perdikaris 교수님의 수업자료를 참고하였다.

 

 

$k(x_1,x_2)=\eta \exp\left( \dfrac{(x_1 -x_2)^{2}}{2l^2}\right)$

커널함수가 이렇게 설정되어 있을 때 우리가 조절할 수 있는 파라미터는 scale factor인 $\eta$와 length인 $l$이다.

 

개념적으로 생각하였을 때 random process인 gaussian process는

$\mathbf{x}~\mathcal{N}(\mathbf{0}, K(\mathbf{x},\mathbf{x}))$ 분포를 가졌고 여기서 $K(\mathbf{x},\mathbf{x})$를 구성하는 커널이 위의 $k(x_1,x_2)$이다.

 

여기에서 gaussian process는 어떤 독립변수(여기서는 $x$라고 하자)에 따라 스칼라 값을 내뱉는 랜덤 프로세스이다.

$x=x_1$이라는 시점과 $x=x_{2}$라는 시점은 개별 시점에서 random variable으로 작용한다. 그리고 전체 궤적 ($f(X)$)는 어떤 sample space에서 sampling된 trajectory라고 볼 수 있다. 즉, sample space에서 한 번 추출된 trajectory는 deterministic function이 되는 것이다. 이걸 sample path라고 한다.

 

그런데 이런 random process는 각 독립변수에서 random variable이고 무한 차원까지 확장될 수 있기 때문에 이해하는 것이 정말 어렵다. 예를 들어 $x=0,1,2,\ldots,\infty$마다 다 다른 확률 분포를 가지고 있으면 이 프로세스를 이해할 방법이 별로 없다. 그래서 gaussian process는 각 $x_{1},x_{2}$ 사이의 관계를 특정 커널로 표현하여서 좀 더 간단하게 어떤 프로세스를 표현하고자 하였다.

 

 

실제로 관찰하면 아래와 같이 RBF를 정의한다. 

 

def RBF(x1, x2, params):
    output_scale, lengthscales = params
    diffs = np.expand_dims(x1 / lengthscales, 1) - \
            np.expand_dims(x2 / lengthscales, 0)
    r2 = np.sum(diffs**2, axis=2)
    return output_scale * np.exp(-0.5 * r2)

 

이를 output_scale과 lengthscales를 바꿔가면서 구하면 다음과 같다.

key = random.PRNGKey(100)
N = 100
keys = random.split(key, N)
length_scale = np.array([1.0, 2.0, 4.0, 8.0]) # 4
ctr_scaler = np.array([1.0, 4.0, 25.0, 100.0]) # 4
gp_sample_mapping = vmap(vmap(vmap(draw_gp_sample, in_axes=(None, None, 0)), in_axes=(None, 0, None)), in_axes=(0,None,None))

gp_samples = gp_sample_mapping(keys, length_scale, ctr_scaler)
print(gp_samples.shape)
>>
(100, 4, 4, 512)
import matplotlib.pyplot as plt
import matplotlib as mpl

fig = plt.figure(figsize=(15,15))

std_array = np.array([1.,2.,5.,10.])
for i in range(length_scale.shape[0]):
    for j in range(ctr_scaler.shape[0]):
        num = i * 4 + j
        ax = fig.add_subplot(4,4,num+1)
        for k in range(N):
            plt.plot(gp_samples[k, i, j, :], alpha = 0.4)
            plt.title('length : '+str(length_scale[i])+' scale : '+str(ctr_scaler[j]))
        rect1 = mpl.patches.Rectangle((0,-2*std_array[j]), gp_samples.shape[3], 4*std_array[j], color='lightgray')
        ax.add_patch(rect1)
        plt.axhline(y=-std_array[j], color='yellow', linestyle='dotted')
        plt.axhline(y=std_array[j], color='yellow', linestyle='dotted')
plt.show()

 

 

 

다음 그림을 보면 $1\sigma$만큼은 노란색 점선으로 표기하고 $2\sigma$는 회색 직사각형으로 표시하였다.

다음과 같이 length scale이 증가하면 각 x 시점에서의 조밀성이 낮아진다. 또한 scale은 말그대로 결과로 나오는 값(y축 값)이 커지는데, scale =$1^2, 2^2, 5^2, 10^2$으로 했기 때문에 standard deviation이 1,2,5,10으로 증가한 것을 볼 수 있다.

 

kernel이 covariance matrix로 들어가기 때문에 평균은 0이지만 분산은 scale에 따라서 변하게 된다.

length scale은 클 수록 K의 element 값이 작아지는 효과가 있는데 그래서 $l$이 커질수록 샘플링에서의 거리감이 커진다.

 

그리고 위 그래프를 보면 length가 같으면 모두 같은 결과가 나오는데 이는 JAX에서 Key를 고정해놓고 랜덤을 돌려서 그런 것이다. scale은 값의 크기만 바꿀 뿐, K matrix를 바꿀 수 없지만, length를 바꾸면 K matrix의 eigenvalue 자체가 바뀌기 때문에 length가 바뀌면 샘플링 또한 달라지게 되었다.

 

이를 통해 gaussian process를 적절하게 생성하기 위한 파라미터 선택을 이해할 수 있다. 

 

저작자표시 비영리 변경금지 (새창열림)

'연구 Research > 인공지능 Artificial Intelligent' 카테고리의 다른 글

[JAX] JAX와 Torch, CUDA, cudnn 버전 맞추기  (0) 2023.04.12
[JAX] 학습 중 NaN 값이 나올 때 찾는 방법  (0) 2023.03.28
[JAX] 기본 Neural Networks 모델  (0) 2023.03.23
[JAX] Gradient, Jacobian, Hessian 등 미분값 구하기  (2) 2023.03.22
[JAX] Cholesky decomposition error 230318 기준  (0) 2023.03.18
'연구 Research/인공지능 Artificial Intelligent' 카테고리의 다른 글
  • [JAX] JAX와 Torch, CUDA, cudnn 버전 맞추기
  • [JAX] 학습 중 NaN 값이 나올 때 찾는 방법
  • [JAX] 기본 Neural Networks 모델
  • [JAX] Gradient, Jacobian, Hessian 등 미분값 구하기
보통의공대생
보통의공대생
수학,프로그래밍,기계항공우주 등 공부하는 기록들을 남깁니다.
  • 보통의공대생
    뛰는 놈 위에 나는 공대생
    보통의공대생
  • 전체
    오늘
    어제
    • 분류 전체보기 (460)
      • 공지 (1)
      • 영어 공부 English Study (40)
        • 텝스 TEPS (7)
        • 글 Article (21)
        • 영상 Video (10)
      • 연구 Research (99)
        • 최적화 Optimization (3)
        • 데이터과학 Data Science (7)
        • 인공지능 Artificial Intelligent (40)
        • 제어 Control (45)
      • 프로그래밍 Programming (103)
        • 매트랩 MATLAB (25)
        • 파이썬 Python (33)
        • 줄리아 Julia (2)
        • C++ (3)
        • 리눅스 우분투 Ubuntu (6)
      • 항공우주 Aeronautical engineeri.. (21)
        • 항법 Navigation (0)
        • 유도 Guidance (0)
      • 기계공학 Mechanical engineering (13)
        • 열역학 Thermodynamics (0)
        • 고체역학 Statics & Solid mechan.. (10)
        • 동역학 Dynamics (1)
        • 유체역학 Fluid Dynamics (0)
      • 수학 Mathematics (34)
        • 선형대수학 Linear Algebra (18)
        • 미분방정식 Differential Equation (3)
        • 확률및통계 Probability & Sta.. (2)
        • 미적분학 Calculus (1)
        • 복소해석학 Complex Analysis (5)
        • 실해석학 Real Analysis (0)
      • 수치해석 Numerical Analysis (21)
      • 확률 및 랜덤프로세스 Random process (2)
      • 추론 & 추정 이론 Estimation (3)
      • 기타 (26)
        • 설계 프로젝트 System Design (8)
        • 논문작성 Writing (55)
        • 세미나 Seminar (2)
        • 생산성 Productivity (3)
      • 유학 생활 Daily (7)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    teps
    논문작성법
    고체역학
    생산성
    Linear algebra
    matplotlib
    Numerical Analysis
    pytorch
    서버
    딥러닝
    논문작성
    텝스
    ChatGPT
    JAX
    Zotero
    MATLAB
    Python
    Julia
    우분투
    수치해석
    LaTeX
    WOX
    에러기록
    텝스공부
    Statics
    IEEE
    인공지능
    Dear abby
    obsidian
    옵시디언
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
보통의공대생
[JAX] Gaussian process 파라미터에 따른 결과 visualization
상단으로

티스토리툴바