【ホワイトペーパー】DeepLearning (上級編)

はじめに

前回のDeep Learning(入門編)ではDeep Learningの特徴や仕組みについて解説しました。今回は、Deep Learning(上級編)としてモデル学習時の流れや学習時に設定するハイパーパラメータ、過学習問題について紹介したいと思います。

学習の手順

教師あり学習の場合を例に、Deep Learningモデルの学習手順を解説します。

データセットの種類

まずはデータセットを準備します。はじめにすべてのデータを学習データとテストデータに分け、さらに学習データを訓練データと検証データの2つに分けます。モデルは訓練データから特徴を学習し、検証データで学習の結果を評価します。データサイエンティストはこの評価をみてハイパーパラメータ(後述) を調整し、一通り調整が終わったら最後にテストデータでモデルの性能を確認します。検証データとテストデータを分ける理由は、ハイパーパラメータの調節をくりかえし行うことで、モデルの検証データへの人為的な最適化が起こってしまうためです。こうした問題を「情報の漏れ」と言います。そのため、チュー二ングに関与する検証データとは別にテストデータで最終的なスコアを評価することで、モデルの真の汎化性能を計測できます。

図1:データセットの分割

検証方法の設定

次に検証方法を設定します。機械学習における代表的な検証方法は以下の2つです。

  • ホールドアウト検証

ホールドアウト検証は、訓練データで学習したモデルを検証データを使って評価する最も単純な検証方法です。しかし、利用可能データが十分ではない場合、検証データのサンプルが少なくデータに偏りが生じるため、正確にモデルの汎化性能を評価するには不適切です。利用可能データが十分か確かめる方法は、異なる訓練データと検証データの組み合わせで学習を行うことです。検証データが不足していた場合、組み合わせを変える度にモデルのスコアに大きなばらつきが生じます。

図2:ホールドアウト検証
  • K-Fold 交差検証

K-Fold交差検証はこのようなホールドアウト検証の問題に対処した方法です。この方法では、学習データを同じサイズのK個のサブセット(フォールド:Fold)に分割して1つをテストデータ、残りのK-1個を学習データとしてスコアの評価を行います。そして全てのサブセットが一回ずつ検証データになるようにK通りの組み合わせでモデルを学習させ、それぞれの学習のスコアの平均値を用いて検証を行います。これによりデータセットごとのスコアのばらつきを考慮した検証が可能になります。

図3:K-交差検証

しかし、1Foldの学習に数十時間以上かかる場合、全てのFoldの学習には数日かかることもあります。そのため利用可能なデータ量や学習にかかる時間を考慮して、どちらの検証方法にするかを決定します。

モデルの構築

次は学習に使用するモデルを構築します。従来、画像処理にはCNN、自然言語処理にはRNNが利用されることが一般的でした、最近では transformer なども登場していますが、モデルの構築の仕方はデータの種類や学習の目的によって様々なため、また別の記事で紹介します。

ハイパーパラメータ等の設定

学習に必要なハイパーパラメータを設定します。主なハイパーパラメータは、活性化関数、バッチサイズ(イテレーション)、エポック数です。

  • 活性化関数

Neural Networkを構築するパーセプトロン(前回記事で解説)で、重みによって線形変換された入力値に対し非線形変換を行う関数です。

図4:パーセプトロンの式

活性化関数によりニューロン間で非線形変換が行われることによって、モデルは各階層での情報を失うことなく様々な粒度の特徴を学習できます。Deep Learningでは、中間層にはReLU関数、出力層にはSigmoid関数の利用が一般的です。ReLU関数とSigmoid関数は以下のような形であり、ReLU関数は0以上の入力をそのまま次のニューロンに渡し、Sigmoid関数は予測確率を出力するため0-1の範囲の値を返します。

図5:ReLU関数とSigmoid関数
  • 損失関数

損失関数はモデルの予測が正解とどれくらいかけ離れているかを計算するための関数です。損失関数で計算された損失値はフィードバックとして重みの更新に利用されます。損失関数にはMAEやMSEがよく利用されます。MAEは予測値と目的値の差の絶対値の平均で、MSEは予測値と目的値の差の二乗の平均であり以下の式で表されます。MSEはMAEに比べ、誤差が大きいほど過大に評価する、つまり間違いをより重要視するという特徴があります。

図6:平均二乗誤差と平均絶対誤差
  • オプティマイザ

重みを更新するアルゴリズムです。Deep Neural Networkでは勾配降下法と呼ばれる方法で重みの更新を行います。この方法では損失関数を重みで微分することによって損失関数曲線の勾配(傾き)を求め、現在の重みから勾配の逆方向に動かすことによって重みを更新し、これをくりかえすことで損失値は徐々に極小値に近づいていきます。このとき一回の更新でどれだけ重みを動かすかは学習率に依存します。学習率が大きければ最適な重みに近づくスピードは速くなりますが、最適な重みで止まるのが難しくなります。

図7:学習率に依存する重みの更新幅

こうした問題を解決する手法として、勾配降下法の式を基本として、移動平均を利用して重みの振動を抑えるモーメンタム、勾配の大きさに応じて学習率を調整するRMSProp、モーメンタムとRMSProp両方の考えを取り入れたAdamなどが考案されています。

  • バッチサイズ(イテレーション)
図8:バッチサイズとイテレーション

一回の重みの更新に利用するデータの数です。バッチサイズ(m)に応じて訓練データをサブセット(n)に分割します。このとき訓練データはm×n個あることになります。サブセット毎の学習をイテレーションといい、イテレーションごとに重みを更新します。すべてのサブセットで学習が終わったら、検証データで現時点の重みにおけるスコアを計算します。(重みの更新なし)

  • エポック数

学習回数です。ここでいう学習は訓練データの各サブセットによる重みの更新と検証データによる評価を指し、学習回数は学習データ全体を何回ループしたかに対応します。n回目の学習とn+1回目の学習の間には訓練データがランダムにシャッフルされます。これにより、毎回異なるデータセットで重みの更新を行うことになるため、より汎化性能が高いモデルを作ることができます。

図9:epoch間の訓練データのシャッフル

モデルの学習

ハイパーパラメータを設定したら、いよいよモデルを学習させます。これはコンピュータが自動で行うため、人がすることはありません。

学習の評価

モデルの学習が終わったら学習の評価を行います。学習の評価には学習曲線を利用します。学習曲線とは各エポックにおける訓練データと検証データのスコアの記録をグラフ化したものです。

図10:学習曲線

学習曲線から最終的な検証データのスコアを確認し、過学習の有無を判断します。

過学習

  • 過学習とは?

過学習とは、モデルが学習データの特徴を学習しすぎた結果、未知のデータに対する予測性能が下がってしまうことです。過学習が生じているとき、訓練データの予測精度はepoch数に応じて上昇するのに対し、評価データの予測精度はある点を境に緩やかに減少していきます。

図11:過学習発生時の学習曲線

過学習が起きる要因はモデルが訓練データに過剰に適合することです。過学習が起きないようにするには、学習データを増やしてより網羅的なパターンを学習させることが一番ですが、他にも以下のような工夫が考えられます。

  • 重みの正則化

重みの大きさに対する制約条件をネットワークの損失関数に追加することで、モデルの自由度を制限する。

  • 学習回数の最適化

学習曲線をみて手動で設定、もしくは検証データのスコアの低下を感知して学習を自動で停止させる。

  • モデルの軽量化

モデルの学習可能なパラメータの数を削減する。

  • ドロップアウトの追加

ドロップアウトが適用された層に対して、モデル訓練時に層の出力の一部をランダムに0に設定する。これにより一部の局所的特徴を過剰に評価する現象を回避できる。

  • オーグメンテーション

訓練データに対して様々なテクニックを用いてデータ量を人為的に増加させる。画像データの場合、反転、輝度変更、拡大縮小、平行移動、アフィン変換など。

Discussion

Deep Learningを利用するときの課題は過学習だけではありません。もう一つの大きな課題は、Deep Learningによる学習が人間の理解を超越していることです。Deep Learningでは数千にも及ぶ膨大なデータを何層にも渡る複雑なネットワークで処理することで学習しており、そのすべての過程を人間が理解することは不可能です。こうした問題を「AIのブラックボックス問題」といいます。これは、機械学習を医療分野で利用するにあたって非常に大きな問題です。なぜなら、医師はなぜAIがそのように診断し、なぜその治療法を提案するのか患者に説明しなければならないからです。このような問題を解決すべく、提案されたのが「Interpretation Model」や「Explainable AI」です。これらは複雑で難解な機械学習モデルによる予測の理解を助けてくれます。これらがどういったアプローチでモデルによる予測の理解を助けてくれるのかは、また別の記事で紹介したいと思います。