線形回帰モデルと正則化
1. はじめに
機械学習を勉強し始めると早々に登場してくるのが線形回帰モデルです.
線形回帰モデルは解析的に扱いやすく,モデルの説明がしやすいので,回帰というタスクを理解する上では必須の項目なのでしょう.
本記事では線形回帰モデルの概略と例をあげます.また,過学習を抑制する正則化についても述べます.
2. 線形回帰モデル
線形回帰モデルとは
と表されます.ここでは基底関数と呼びます.基底関数
が非線形関数であっても,パラメータ
に関しては線形であるため,線形回帰モデルと呼ばれます.
パラメータに関して線形であるため,訓練データに対しての最適なパラメータを求めることが容易であるということがメリットです.
最も単純なモデルは基底関数を
とした時で
と表されます.これは線形回帰モデルの中でも特に線形回帰と呼ばれます.入力に関しても線形なモデルです. また,基底関数を
とすると
となり多項式回帰と呼ばれます.入力に関しては非線形なモデルで,次数を大きくするほどモデルの表現力が上がります.
基底関数は他にもシグモイド関数やガウス関数などがありますが,今回は取り上げません.
それでは実際に線形回帰モデルを使用してみます.
訓練データとモデルの二乗和誤差を最小にするようなパラメータを求めます.
ここで.パラメータを解析的に求めるために,変数の表現方法を整理します.
入力変数が
行あるとすると入力変数
は
$$
\boldsymbol{X} = (\mathbf{x}_1, \mathbf{x}_2,... \mathbf{x}_n)^T=
\begin{bmatrix}
x_{11} & x_{12} & \cdots & x_{1m-1}\\
x_{21} & x_{22} & \cdots & x_{2m-1}\\
\vdots & \vdots & & \vdots\\
x_{n1} & x_{n2} & \cdots & x_{nm-1}\\
\end{bmatrix}
$$
となり,
行
列の行列で表されます.
また,モデルの定数項(バイアス)も行列に含めたいため
の左側に
を追加すると
$$
\boldsymbol{X} =
\begin{bmatrix}
1 & x_{11} & x_{12} & \cdots & x_{1m-1}\\
1 & x_{21} & x_{22} & \cdots & x_{2m-1}\\
\vdots & \vdots & \vdots & & \vdots\\
1 & x_{n1} & x_{n2} & \cdots & x_{nm-1}\\
\end{bmatrix}
$$
と表現され,
は
行
列の行列となります.
より
は
$$
\hat{\boldsymbol{y}} = \boldsymbol{X}\mathbf{w}
$$
と表されます.(1)式より簡単に表せるようになりました.
実際のデータをとすると二乗和誤差
は
$$
L = (\boldsymbol{y} - \hat{\boldsymbol{y}})^T(\boldsymbol{y} - \hat{\boldsymbol{y}}) \tag{2}
$$
と表されます.(2)式を展開すると
$$
L = (\boldsymbol{y} - \hat{\boldsymbol{y}})^T(\boldsymbol{y} - \hat{\boldsymbol{y}}) = \boldsymbol{y}^T\boldsymbol{y}\ -2\mathbf{w}^T\boldsymbol{X}^T\boldsymbol{y} + \mathbf{w}^T\boldsymbol{X}^T \boldsymbol{X}\mathbf{w}
$$
となります.
二乗和誤差(損失関数)を最小にするパラメータを求めたいため,
を偏微分し,勾配が0となる
を求めます.
$$
\nabla_{\bf{w}} L = -2\boldsymbol{X}^T\boldsymbol{y} +2 \boldsymbol{X}^T \boldsymbol{X}\mathbf{w} = 0
$$
より
$$
\mathbf{w} = (\boldsymbol{X}^T\boldsymbol{X})^{-1}\boldsymbol{X}^T\boldsymbol{y}
\tag{3}
$$
とパラメータを求めることができます.パラメータを直接与えてくれるような(3)式は正規方程式と呼ばれます.
この正規方程式を使用して回帰を行います. 次のようなデータがあるとします.
このようなと
の関係は直線で表すことができそうです.
$$
y = w_0 + w_1x
$$
とし,パラメータを(3)式によって求めて直線を画像上に引いたのが,下の図です.
実際,このデータは
と
の1次式にガウスノイズを加えたものなので,
と
の関係は1次式で表されるわけです.
では,次のような分布はどうでしょう.
直線よりは曲線の方が分布に合ってそうです.
このような時は曲線を表すことができる多項式回帰を使ってみましょう.
3次多項式をモデルに使用した場合が次の画像です.
モデルの自由度を増やし,表現力をあげた場合はどうなるでしょうか.
次数を大幅に増やし,9次多項式でフィッティングさせてみます.
曲線が全ての点を通っているように見えます.
実際,訓練データに対しての誤差はほとんど0です.
しかし,予測において重要なのは汎化性能です.未知のデータに対しての予測誤差が小さくなくては意味がありません.
実際このデータはsinにガウスノイズを足したもので生成されています.
3. 過学習と正則化
訓練データ(赤)に対しては非常に当てはまりがよいですが,点を増やした時の分布(青)に当てはまっていないことがわかります.このように訓練データに対しては異常に適合し,一方で未知データに対しての誤差が非常に大きい状態を過学習と言います. 次の学習曲線を見ると過学習の様子がわかります.
訓練データが少ない時,訓練データに対しては誤差は小さく,検証データに対しては異常に誤差が大きいことがわかります.しかし,学習させるデータ数を増やすと検証データに対する誤差が小さくなっている様子がわかります.自由度の高さゆえに少ないデータに対してはほぼ完全に適合することができるわけです.
過学習を抑制する方法は
・訓練データ数を増やす.
・損失関数に罰則項(正則化項)を付加する.
という方法があります.後者は特に正則化と言います.以下は正則化項を付加した場合の損失関数です.
特に=1の時,Lasso回帰,
=2の時,Ridge回帰と呼ばれます.
実際にRidgeとLassoを用いて,過学習を抑制してみましょう(最適化問題の解法は上記で述べた正規方程式とRidge,Lasso共に異なります).
以下の図がRidgeとLassoを用いた結果です.
が大きくなるにつれて,直線に近づいています.
は二乗和誤差と正則化項の相対的な重要度を決めるパラメータになっています.
が大きいほど正則化項を最小化するという要素が大きくなるため,パラメータ
の大きさが小さくなります.この作用により,高次のパラメータが小さくなったことから直線に近付いたと考えられます.
正則化が過学習を抑制するおおよその理解はこれで良いと思うのですが,ここでLassoとRidgeの違いについて見てみます. 違いを説明するための道具として,
と
が等しいことを示します.
後者は条件付き最適化問題なので,ラグランジュの未定乗数法を使用します.
制約式は
となります. そこからラグランジュ関数は
となることがわかります.ここで
を満たす必要があります.(KKT条件)
よって
となるを求める問題となり,(4)式と同じ問題となります.
つまりLassoは
の制約の元での誤差関数の最小化,Ridgeは
の制約の元での最小化と言えます. この制約のもとでの最小化を表す図を以下に示します. LassoとRidgeの違いを説明するときによく使用されます.
![f:id:rakAcHIkARA:20190515180113p:plain:w400 f:id:rakAcHIkARA:20190515180113p:plain:w400](https://cdn-ak.f.st-hatena.com/images/fotolife/r/rakAcHIkARA/20190515/20190515180113.png)
![f:id:rakAcHIkARA:20190515180128p:plain:w400 f:id:rakAcHIkARA:20190515180128p:plain:w400](https://cdn-ak.f.st-hatena.com/images/fotolife/r/rakAcHIkARA/20190515/20190515180128.png)
緑の線が損失関数の等高線表示で,赤の線が制約式が等号の場合を表しています.同じ等高線でもLassoの方が疎な(パラメータのベクトルの成分のうち0が多い)解を得やすいということが示唆されます.(PRMLに記載されている図とほぼ一緒です.)
先ほどのRidgeとLassoの係数の違いを見てもLassoが疎な解を得やすいことがわかります.
横軸が次数で縦軸が係数です.これは先ほど示した
の時のRidgeとLassoのそれぞれの次数ごとの係数です.Lassoの方は0となっている係数がRidgeに比べて多いことがわかります.
Lassoは疎な解を得やすいことから特徴量選択に使用されるようです. Ridgeは単に過学習を抑制するためには効果的です.
終わりに
線形回帰モデルと過学習の抑制方法について学んだことを記述しました. Elastic Netまで書くのが目標でしたが,疲れ果てました.余裕ができたときに追記しようと思います.