일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | 6 | 7 |
8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | 17 | 18 | 19 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 30 | 31 |
- LaTeX
- WOX
- obsidian
- Julia
- matplotlib
- MATLAB
- Statics
- Numerical Analysis
- 생산성
- pytorch
- 우분투
- 텝스
- 딥러닝
- Linear algebra
- IEEE
- 수식삽입
- teps
- ChatGPT
- 에러기록
- 옵시디언
- 논문작성법
- 수치해석
- Dear abby
- 텝스공부
- Zotero
- 고체역학
- JAX
- 인공지능
- 논문작성
- Python
- Today
- Total
뛰는 놈 위에 나는 공대생
[JAX] Gaussian process 파라미터에 따른 결과 visualization 본문
[JAX] Gaussian process 파라미터에 따른 결과 visualization
보통의공대생 2023. 3. 24. 15:17Gaussian process에서 사용하는 커널 종류는 다양할 수 있지만 여기서는 Radial Basis Fuction을 이용해서 gaussian process 샘플들을 구하고 이에 대한 관찰을 시각화하는 방법에 대해서 이야기한다.
RBF 함수는 Paris Perdikaris 교수님의 수업자료를 참고하였다.
$k(x_1,x_2)=\eta \exp\left( \dfrac{(x_1 -x_2)^{2}}{2l^2}\right)$
커널함수가 이렇게 설정되어 있을 때 우리가 조절할 수 있는 파라미터는 scale factor인 $\eta$와 length인 $l$이다.
개념적으로 생각하였을 때 random process인 gaussian process는
$\mathbf{x}~\mathcal{N}(\mathbf{0}, K(\mathbf{x},\mathbf{x}))$ 분포를 가졌고 여기서 $K(\mathbf{x},\mathbf{x})$를 구성하는 커널이 위의 $k(x_1,x_2)$이다.
여기에서 gaussian process는 어떤 독립변수(여기서는 $x$라고 하자)에 따라 스칼라 값을 내뱉는 랜덤 프로세스이다.
$x=x_1$이라는 시점과 $x=x_{2}$라는 시점은 개별 시점에서 random variable으로 작용한다. 그리고 전체 궤적 ($f(X)$)는 어떤 sample space에서 sampling된 trajectory라고 볼 수 있다. 즉, sample space에서 한 번 추출된 trajectory는 deterministic function이 되는 것이다. 이걸 sample path라고 한다.
그런데 이런 random process는 각 독립변수에서 random variable이고 무한 차원까지 확장될 수 있기 때문에 이해하는 것이 정말 어렵다. 예를 들어 $x=0,1,2,\ldots,\infty$마다 다 다른 확률 분포를 가지고 있으면 이 프로세스를 이해할 방법이 별로 없다. 그래서 gaussian process는 각 $x_{1},x_{2}$ 사이의 관계를 특정 커널로 표현하여서 좀 더 간단하게 어떤 프로세스를 표현하고자 하였다.
실제로 관찰하면 아래와 같이 RBF를 정의한다.
def RBF(x1, x2, params):
output_scale, lengthscales = params
diffs = np.expand_dims(x1 / lengthscales, 1) - \
np.expand_dims(x2 / lengthscales, 0)
r2 = np.sum(diffs**2, axis=2)
return output_scale * np.exp(-0.5 * r2)
이를 output_scale과 lengthscales를 바꿔가면서 구하면 다음과 같다.
key = random.PRNGKey(100)
N = 100
keys = random.split(key, N)
length_scale = np.array([1.0, 2.0, 4.0, 8.0]) # 4
ctr_scaler = np.array([1.0, 4.0, 25.0, 100.0]) # 4
gp_sample_mapping = vmap(vmap(vmap(draw_gp_sample, in_axes=(None, None, 0)), in_axes=(None, 0, None)), in_axes=(0,None,None))
gp_samples = gp_sample_mapping(keys, length_scale, ctr_scaler)
print(gp_samples.shape)
>>
(100, 4, 4, 512)
import matplotlib.pyplot as plt
import matplotlib as mpl
fig = plt.figure(figsize=(15,15))
std_array = np.array([1.,2.,5.,10.])
for i in range(length_scale.shape[0]):
for j in range(ctr_scaler.shape[0]):
num = i * 4 + j
ax = fig.add_subplot(4,4,num+1)
for k in range(N):
plt.plot(gp_samples[k, i, j, :], alpha = 0.4)
plt.title('length : '+str(length_scale[i])+' scale : '+str(ctr_scaler[j]))
rect1 = mpl.patches.Rectangle((0,-2*std_array[j]), gp_samples.shape[3], 4*std_array[j], color='lightgray')
ax.add_patch(rect1)
plt.axhline(y=-std_array[j], color='yellow', linestyle='dotted')
plt.axhline(y=std_array[j], color='yellow', linestyle='dotted')
plt.show()
다음 그림을 보면 $1\sigma$만큼은 노란색 점선으로 표기하고 $2\sigma$는 회색 직사각형으로 표시하였다.
다음과 같이 length scale이 증가하면 각 x 시점에서의 조밀성이 낮아진다. 또한 scale은 말그대로 결과로 나오는 값(y축 값)이 커지는데, scale =$1^2, 2^2, 5^2, 10^2$으로 했기 때문에 standard deviation이 1,2,5,10으로 증가한 것을 볼 수 있다.
kernel이 covariance matrix로 들어가기 때문에 평균은 0이지만 분산은 scale에 따라서 변하게 된다.
length scale은 클 수록 K의 element 값이 작아지는 효과가 있는데 그래서 $l$이 커질수록 샘플링에서의 거리감이 커진다.
그리고 위 그래프를 보면 length가 같으면 모두 같은 결과가 나오는데 이는 JAX에서 Key를 고정해놓고 랜덤을 돌려서 그런 것이다. scale은 값의 크기만 바꿀 뿐, K matrix를 바꿀 수 없지만, length를 바꾸면 K matrix의 eigenvalue 자체가 바뀌기 때문에 length가 바뀌면 샘플링 또한 달라지게 되었다.
이를 통해 gaussian process를 적절하게 생성하기 위한 파라미터 선택을 이해할 수 있다.
'연구 Research > 인공지능 Artificial Intelligent' 카테고리의 다른 글
[JAX] JAX와 Torch, CUDA, cudnn 버전 맞추기 (0) | 2023.04.12 |
---|---|
[JAX] 학습 중 NaN 값이 나올 때 찾는 방법 (0) | 2023.03.28 |
[JAX] 기본 Neural Networks 모델 (0) | 2023.03.23 |
[JAX] Gradient, Jacobian, Hessian 등 미분값 구하기 (2) | 2023.03.22 |
[JAX] Cholesky decomposition error 230318 기준 (0) | 2023.03.18 |