Notice
Recent Posts
Recent Comments
Link
관리 메뉴

뛰는 놈 위에 나는 공대생

[JAX] Gaussian process 파라미터에 따른 결과 visualization 본문

연구 Research/인공지능 Artificial Intelligent

[JAX] Gaussian process 파라미터에 따른 결과 visualization

보통의공대생 2023. 3. 24. 15:17

Gaussian process에서 사용하는 커널 종류는 다양할 수 있지만 여기서는 Radial Basis Fuction을 이용해서 gaussian process 샘플들을 구하고 이에 대한 관찰을 시각화하는 방법에 대해서 이야기한다.

 

RBF 함수는 Paris Perdikaris 교수님의 수업자료를 참고하였다.

 

 

$k(x_1,x_2)=\eta \exp\left( \dfrac{(x_1 -x_2)^{2}}{2l^2}\right)$

커널함수가 이렇게 설정되어 있을 때 우리가 조절할 수 있는 파라미터는 scale factor인 $\eta$와 length인 $l$이다.

 

개념적으로 생각하였을 때 random process인 gaussian process는

$\mathbf{x}~\mathcal{N}(\mathbf{0}, K(\mathbf{x},\mathbf{x}))$ 분포를 가졌고 여기서 $K(\mathbf{x},\mathbf{x})$를 구성하는 커널이 위의 $k(x_1,x_2)$이다.

 

여기에서 gaussian process는 어떤 독립변수(여기서는 $x$라고 하자)에 따라 스칼라 값을 내뱉는 랜덤 프로세스이다.

$x=x_1$이라는 시점과 $x=x_{2}$라는 시점은 개별 시점에서 random variable으로 작용한다. 그리고 전체 궤적 ($f(X)$)는 어떤 sample space에서 sampling된 trajectory라고 볼 수 있다. 즉, sample space에서 한 번 추출된 trajectory는 deterministic function이 되는 것이다. 이걸 sample path라고 한다.

 

그런데 이런 random process는 각 독립변수에서 random variable이고 무한 차원까지 확장될 수 있기 때문에 이해하는 것이 정말 어렵다. 예를 들어 $x=0,1,2,\ldots,\infty$마다 다 다른 확률 분포를 가지고 있으면 이 프로세스를 이해할 방법이 별로 없다. 그래서 gaussian process는 각 $x_{1},x_{2}$ 사이의 관계를 특정 커널로 표현하여서 좀 더 간단하게 어떤 프로세스를 표현하고자 하였다.

 

 

실제로 관찰하면 아래와 같이 RBF를 정의한다. 

 

def RBF(x1, x2, params):
    output_scale, lengthscales = params
    diffs = np.expand_dims(x1 / lengthscales, 1) - \
            np.expand_dims(x2 / lengthscales, 0)
    r2 = np.sum(diffs**2, axis=2)
    return output_scale * np.exp(-0.5 * r2)

 

이를 output_scale과 lengthscales를 바꿔가면서 구하면 다음과 같다.

key = random.PRNGKey(100)
N = 100
keys = random.split(key, N)
length_scale = np.array([1.0, 2.0, 4.0, 8.0]) # 4
ctr_scaler = np.array([1.0, 4.0, 25.0, 100.0]) # 4
gp_sample_mapping = vmap(vmap(vmap(draw_gp_sample, in_axes=(None, None, 0)), in_axes=(None, 0, None)), in_axes=(0,None,None))

gp_samples = gp_sample_mapping(keys, length_scale, ctr_scaler)
print(gp_samples.shape)
>>
(100, 4, 4, 512)
import matplotlib.pyplot as plt
import matplotlib as mpl

fig = plt.figure(figsize=(15,15))

std_array = np.array([1.,2.,5.,10.])
for i in range(length_scale.shape[0]):
    for j in range(ctr_scaler.shape[0]):
        num = i * 4 + j
        ax = fig.add_subplot(4,4,num+1)
        for k in range(N):
            plt.plot(gp_samples[k, i, j, :], alpha = 0.4)
            plt.title('length : '+str(length_scale[i])+' scale : '+str(ctr_scaler[j]))
        rect1 = mpl.patches.Rectangle((0,-2*std_array[j]), gp_samples.shape[3], 4*std_array[j], color='lightgray')
        ax.add_patch(rect1)
        plt.axhline(y=-std_array[j], color='yellow', linestyle='dotted')
        plt.axhline(y=std_array[j], color='yellow', linestyle='dotted')
plt.show()

 

 

 

다음 그림을 보면 $1\sigma$만큼은 노란색 점선으로 표기하고 $2\sigma$는 회색 직사각형으로 표시하였다.

다음과 같이 length scale이 증가하면 각 x 시점에서의 조밀성이 낮아진다. 또한 scale은 말그대로 결과로 나오는 값(y축 값)이 커지는데, scale =$1^2, 2^2, 5^2, 10^2$으로 했기 때문에 standard deviation이 1,2,5,10으로 증가한 것을 볼 수 있다.

 

kernel이 covariance matrix로 들어가기 때문에 평균은 0이지만 분산은 scale에 따라서 변하게 된다.

length scale은 클 수록 K의 element 값이 작아지는 효과가 있는데 그래서 $l$이 커질수록 샘플링에서의 거리감이 커진다.

 

그리고 위 그래프를 보면 length가 같으면 모두 같은 결과가 나오는데 이는 JAX에서 Key를 고정해놓고 랜덤을 돌려서 그런 것이다. scale은 값의 크기만 바꿀 뿐, K matrix를 바꿀 수 없지만, length를 바꾸면 K matrix의 eigenvalue 자체가 바뀌기 때문에 length가 바뀌면 샘플링 또한 달라지게 되었다.

 

이를 통해 gaussian process를 적절하게 생성하기 위한 파라미터 선택을 이해할 수 있다. 

 

Comments