線形回帰モデルと正則化

1. はじめに

機械学習を勉強し始めると早々に登場してくるのが線形回帰モデルです.
線形回帰モデルは解析的に扱いやすく,モデルの説明がしやすいので,回帰というタスクを理解する上では必須の項目なのでしょう.

本記事では線形回帰モデルの概略と例をあげます.また,過学習を抑制する正則化についても述べます.

2. 線形回帰モデル

線形回帰モデルとは

\displaystyle
\hat{y}(\mathbf{x},\mathbf{w}) = w_0 + \sum_{j=1}^{m-1} w_j \phi_j(\mathbf{x})

と表されます.ここで \phi_j(\mathbf{x})は基底関数と呼びます.基底関数 \phi_j(\mathbf{x})非線形関数であっても,パラメータ \mathbf{w}に関しては線形であるため,線形回帰モデルと呼ばれます.

パラメータに関して線形であるため,訓練データに対しての最適なパラメータを求めることが容易であるということがメリットです.

最も単純なモデルは基底関数を

\displaystyle
\phi_j(x) = x_j

とした時で

\displaystyle
\hat{y}(\mathbf{x},\mathbf{w}) = w_0 + \sum_{j=1}^{m-1} w_j x_j
\tag{1}

と表されます.これは線形回帰モデルの中でも特に線形回帰と呼ばれます.入力に関しても線形なモデルです. また,基底関数を

\displaystyle
\phi_j(x) = x^j

とすると

\displaystyle
\hat{y}(\mathbf{x},\mathbf{w}) = w_0 + \sum_{j=1}^{m-1} w_j x^j

となり多項式回帰と呼ばれます.入力に関しては非線形なモデルで,次数 m-1を大きくするほどモデルの表現力が上がります.

基底関数は他にもシグモイド関数ガウス関数などがありますが,今回は取り上げません.

それでは実際に線形回帰モデルを使用してみます. 訓練データとモデルの二乗和誤差を最小にするようなパラメータ \mathbf{w}を求めます.

ここで.パラメータを解析的に求めるために,変数の表現方法を整理します.

入力変数 \mathbf{x} n行あるとすると入力変数 \boldsymbol{X}は $$ \boldsymbol{X} = (\mathbf{x}_1, \mathbf{x}_2,... \mathbf{x}_n)^T= \begin{bmatrix} x_{11} & x_{12} & \cdots & x_{1m-1}\\ x_{21} & x_{22} & \cdots & x_{2m-1}\\ \vdots & \vdots & & \vdots\\ x_{n1} & x_{n2} & \cdots & x_{nm-1}\\ \end{bmatrix} $$ となり, n m-1列の行列で表されます. また,モデルの定数項(バイアス)も行列に含めたいため  \boldsymbol{X}の左側に (1,1,...,1)^Tを追加すると $$ \boldsymbol{X} = \begin{bmatrix} 1 & x_{11} & x_{12} & \cdots & x_{1m-1}\\ 1 & x_{21} & x_{22} & \cdots & x_{2m-1}\\ \vdots & \vdots & \vdots & & \vdots\\ 1 & x_{n1} & x_{n2} & \cdots & x_{nm-1}\\ \end{bmatrix} $$ と表現され, \boldsymbol{X} n m列の行列となります.  \mathbf{w} = (w_0, w_1, ..., w_{m-1})^Tより \hat{\boldsymbol{y}}=(\hat{y}_1, \hat{y}_2, ..., \hat{y}_n) ^Tは $$ \hat{\boldsymbol{y}} = \boldsymbol{X}\mathbf{w} $$ と表されます.(1)式より簡単に表せるようになりました.

実際のデータを \boldsymbol{y} = (y_1,y_2, ..., y_n)^Tとすると二乗和誤差 Lは $$ L = (\boldsymbol{y} - \hat{\boldsymbol{y}})^T(\boldsymbol{y} - \hat{\boldsymbol{y}}) \tag{2} $$ と表されます.(2)式を展開すると $$ L = (\boldsymbol{y} - \hat{\boldsymbol{y}})^T(\boldsymbol{y} - \hat{\boldsymbol{y}}) = \boldsymbol{y}^T\boldsymbol{y}\ -2\mathbf{w}^T\boldsymbol{X}^T\boldsymbol{y} + \mathbf{w}^T\boldsymbol{X}^T \boldsymbol{X}\mathbf{w} $$ となります. 二乗和誤差(損失関数)を最小にするパラメータを求めたいため, L偏微分し,勾配が0となる \mathbf{w}を求めます. $$ \nabla_{\bf{w}} L = -2\boldsymbol{X}^T\boldsymbol{y} +2 \boldsymbol{X}^T \boldsymbol{X}\mathbf{w} = 0 $$ より $$ \mathbf{w} = (\boldsymbol{X}^T\boldsymbol{X})^{-1}\boldsymbol{X}^T\boldsymbol{y} \tag{3} $$ とパラメータを求めることができます.パラメータを直接与えてくれるような(3)式は正規方程式と呼ばれます.

この正規方程式を使用して回帰を行います. 次のようなデータがあるとします.

f:id:rakAcHIkARA:20190516121655j:plain

このような x yの関係は直線で表すことができそうです. $$ y = w_0 + w_1x $$ とし,パラメータを(3)式によって求めて直線を画像上に引いたのが,下の図です. 実際,このデータは x yの1次式にガウスノイズを加えたものなので, x yの関係は1次式で表されるわけです. f:id:rakAcHIkARA:20190516121731j:plain

では,次のような分布はどうでしょう. f:id:rakAcHIkARA:20190516121806j:plain 直線よりは曲線の方が分布に合ってそうです. このような時は曲線を表すことができる多項式回帰を使ってみましょう. 3次多項式をモデルに使用した場合が次の画像です. f:id:rakAcHIkARA:20190516121827j:plain

モデルの自由度を増やし,表現力をあげた場合はどうなるでしょうか. 次数を大幅に増やし,9次多項式でフィッティングさせてみます. f:id:rakAcHIkARA:20190516121853j:plain

曲線が全ての点を通っているように見えます.
実際,訓練データに対しての誤差はほとんど0です.

しかし,予測において重要なのは汎化性能です.未知のデータに対しての予測誤差が小さくなくては意味がありません.

実際このデータはsinにガウスノイズを足したもので生成されています.

3. 過学習正則化

f:id:rakAcHIkARA:20190516133958j:plain

訓練データ(赤)に対しては非常に当てはまりがよいですが,点を増やした時の分布(青)に当てはまっていないことがわかります.このように訓練データに対しては異常に適合し,一方で未知データに対しての誤差が非常に大きい状態を過学習と言います. 次の学習曲線を見ると過学習の様子がわかります.

f:id:rakAcHIkARA:20190516121956j:plain

訓練データが少ない時,訓練データに対しては誤差は小さく,検証データに対しては異常に誤差が大きいことがわかります.しかし,学習させるデータ数を増やすと検証データに対する誤差が小さくなっている様子がわかります.自由度の高さゆえに少ないデータに対してはほぼ完全に適合することができるわけです. 過学習を抑制する方法は
・訓練データ数を増やす.
・損失関数に罰則項(正則化項)を付加する.
という方法があります.後者は特に正則化と言います.以下は正則化項を付加した場合の損失関数です.

 \displaystyle
  E(\mathbf{w}) = \sum_{i=1}^{n} (y_{i} - \mathbf
{w}^{T} \boldsymbol{\phi}(\mathbf{x}_{i}))^{2} +\lambda \sum_{j=1}^{m-1} |w_{j}|^{q}

特に q=1の時,Lasso回帰, q=2の時,Ridge回帰と呼ばれます.

実際にRidgeとLassoを用いて,過学習を抑制してみましょう(最適化問題の解法は上記で述べた正規方程式とRidge,Lasso共に異なります).

以下の図がRidgeとLassoを用いた結果です.

f:id:rakAcHIkARA:20190516141944j:plain
Ridge

f:id:rakAcHIkARA:20190516141859j:plain
Lasso
どちらも \lambdaが大きくなるにつれて,直線に近づいています.

 \lambdaは二乗和誤差と正則化項の相対的な重要度を決めるパラメータになっています. \lambdaが大きいほど正則化項を最小化するという要素が大きくなるため,パラメータ \boldsymbol{w}の大きさが小さくなります.この作用により,高次のパラメータが小さくなったことから直線に近付いたと考えられます.

正則化過学習を抑制するおおよその理解はこれで良いと思うのですが,ここでLassoとRidgeの違いについて見てみます. 違いを説明するための道具として,

 \displaystyle
  \min E(\mathbf{w}) = \min \left\{\sum_{i=1}^{n} (y_{i} - \mathbf
{w}^{T} \boldsymbol{\phi}(\mathbf{x}_{i}))^{2} +\lambda \sum_{j=1}^{m-1} |w_{j}|^{q}\right\}

 \displaystyle
\min E(\mathbf{w}) = \min \left\{\sum_{i=1}^{n} (y_{i} - \mathbf
{w}^{T} \boldsymbol{\phi}(\mathbf{x}_{i}))^{2} \right\}
 \displaystyle
\text{subject to}~~~~~~~\sum_{j=1}^{m-1} |w_{j}|^{q} \leq \eta

が等しいことを示します. 後者は条件付き最適化問題なので,ラグランジュの未定乗数法を使用します. 制約式 g

 \displaystyle
g(\mathbf{w}) = \eta - \sum_{j=1}^{m-1} |w_{j}|^{q} \geq 0

となります. そこからラグランジュ関数は

 \displaystyle
L(\mathbf{w}, \lambda) = \left\{\sum_{i=1}^{n} (y_{i} - \mathbf
{w}^{T} \boldsymbol{\phi}(\mathbf{x}_{i}))^{2} \right\} - \lambda \left\{\eta -\sum_{j=1}^{m-1} |w_{j}|^{q}\right\}

となることがわかります.ここで

 \displaystyle
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\lambda \geq 0 \\
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\lambda g = 0

を満たす必要があります.(KKT条件)
よって

 \displaystyle
\nabla L(\mathbf{w}, \lambda) = \nabla \left\{\sum_{i=1}^{n} (y_{i} - \mathbf
{w}^{T} \boldsymbol{\phi}(\mathbf{x}_{i}))^{2} \right\} + \lambda \nabla \sum_{j=1}^{m-1} |w_{j}|^{q} = 0

となる \boldsymbol{w}を求める問題となり,(4)式と同じ問題となります. つまりLassoは

 \displaystyle
\sum_{j=1}^{m-1} |w_{j}| \leq \eta

の制約の元での誤差関数の最小化,Ridgeは

 \displaystyle
\sum_{j=1}^{m-1} w_{j}^2 \leq \eta

の制約の元での最小化と言えます. この制約のもとでの最小化を表す図を以下に示します. LassoとRidgeの違いを説明するときによく使用されます.

f:id:rakAcHIkARA:20190515180113p:plain:w400
f:id:rakAcHIkARA:20190515180128p:plain:w400

緑の線が損失関数の等高線表示で,赤の線が制約式が等号の場合を表しています.同じ等高線でもLassoの方が疎な(パラメータのベクトルの成分のうち0が多い)解を得やすいということが示唆されます.(PRMLに記載されている図とほぼ一緒です.)
先ほどのRidgeとLassoの係数の違いを見てもLassoが疎な解を得やすいことがわかります. f:id:rakAcHIkARA:20190516132747p:plain 横軸が次数で縦軸が係数です.これは先ほど示した \lambda=0.01の時のRidgeとLassoのそれぞれの次数ごとの係数です.Lassoの方は0となっている係数がRidgeに比べて多いことがわかります.

Lassoは疎な解を得やすいことから特徴量選択に使用されるようです. Ridgeは単に過学習を抑制するためには効果的です.

終わりに

線形回帰モデルと過学習の抑制方法について学んだことを記述しました. Elastic Netまで書くのが目標でしたが,疲れ果てました.余裕ができたときに追記しようと思います.