ニューラルネットワークの予測の不確実性(deep Gaussian process・概要編)


はじめに

deep Gaussian processでニューラルネットワークの予測の不確実性を算出する手法を紹介します。

Gaussian process (ガウス過程)

ガウス過程では、データセット$${{(x_1, y_1), ..., (x_n, y_n)}}$$の関数$$${y=f(x)}$$を、多変量正規分布を用いてモデル化します。
平均関数$${m(x_i)}$$とカーネル関数$${k(x_i,x_j)}$$を用いて下記のように定義します。

$$
p\left(\begin{bmatrix}
y_1 \\
\vdots \\
y_n
\end{bmatrix}\right)
= \mathcal{N}\left(
\begin{bmatrix}
y_1 \\
\vdots \\
y_n
\end{bmatrix}
\middle|
\begin{bmatrix}
m(x_1) \\
\vdots \\
m(x_n)
\end{bmatrix}
,
\begin{bmatrix}
k(x_1,x_1) & \dots & k(x_1,x_n)\\
\vdots & \ddots & \vdots \\
k(x_n,x_1) & \dots & k(x_n,x_n)
\end{bmatrix}
\right)
$$

このように定義される関数を$${f\sim GP(m,k)}$$と書きます。

$${Y=[y_1,...,y_n]^T, X=[x_1, ..., x_n]^T,m(x_i)=0,}$$

$$
k(X,X)=\begin{bmatrix}
k(x_1,x_1) & \dots & k(x_1,x_n)\\
\vdots & \ddots & \vdots \\
k(x_n,x_1) & \dots & k(x_n,x_n)
\end{bmatrix}
$$

の場合を例に説明します。
また、ノイズ$${\epsilon \sim \mathcal{N}(0,\beta^{-1})}$$を考慮して、$${y=f(x)+\epsilon}$$を考えると、$${f \sim GP(0,k(x_i,x_j)+\delta_{ij}\beta^{-1})}$$となります。

教師データ$${X,Y}$$を用いて、テストデータ$${X^*}$$の予測を行う場合は、同時分布を考えます。

$$
p(Y^*,Y|X^*,X) = \mathcal{N}\left(
\begin{bmatrix}
Y^* \\
Y
\end{bmatrix}
\middle|
0,
\begin{bmatrix}
k(X^*,X^*) + \beta^{-1}I & k(X^*,X) \\
K(X,X^*) & K(X,X)+ \beta^{-1}I
\end{bmatrix}
\right)
$$

この同時分布から条件付き確率を求めると下記のように書くことができます。

$$
p(Y^*|Y,X^*,X) = \mathcal{N}(Y^*|\mu^*,\Sigma^*) \\
\mu^* = k(X^*,X)(k(X,X)+\beta^{-1}I)Y \\
\Sigma^* = k(X^*,X^*)+\beta^{-1}I-k(X^*,X)(k(X,X)+\beta^{-1}I)^{-1}k(X,X^*)
$$

このように、テストデータの予測値の事後分布を求めることができます。

カーネル関数としては、$${k(x_i,x_j)=\sigma^2\exp(-\frac{1}{2}w_l\sum_{l}(x_{il}-x_{jl})^2)}$$のような関数が用いられます。
$${l}$$は$${x}$$の要素のインデックスで、$${w_l}$$はパラメータです。
$${w_l}$$を最適化することで、予測性能が高いモデルを得ることができます。

deep Gaussian process (深層ガウス過程)

ニューラルネットワークの各層にガウス過程を用いたものが、深層ガウス過程になります。
3層のニューラルネットワークを考え、入力層を$${X^1}$$、中間層を$${X^2}$$、出力層を$${Y}$$とします。
中間層と出力層は下記のようにガウス過程で定義されます。

$$
f^{X^2} \sim GP(0,k^1(X^1,X^1)) \\
f^Y \sim GP(0,k^2(X^2,X^2))
$$

深層ガウス過程では、学習のためにinducing points(誘導点)を導入します。
3層のニューラルネットワークの場合は、入力層の誘導点$${Z_1}$$と、中間層の誘導点$${Z_2}$$を導入します。
誘導点は変分推論で最適化する変数で、観測データを用いて、誘導点、平均関数、カーネル関数の最適化を行います。
この最適化は、大規模な観測データを少数の誘導点に圧縮していると捉えることができます。
新たなデータが与えられた場合は、学習結果を用いて事後分布を計算します。

参考資料

  • 須山敦志, ベイズ深層学習, 2019.

  • A. C. Damianou and N. D. Lawrence, Deep Gaussian Processes, AISTATS, 2013.

  • H. Salimbeni and M. P. Deisenroth, Doubly Stochastic Variational Inference
    for Deep Gaussian Processes, NeurIPS, 2017.

この記事が気に入ったらサポートをしてみませんか?