医療現場では検査手法の一般化が進んでいるが、医療施設毎で検査の独自プロセスや使用機器、検査担当者の技量等の違いにより得られるデータが異なることがある。そのため、医療現場で用いられるAIはしばしば汎化性能が問われることがある。今回は「最先端機械学習技術の医療画像応用」のテーマの第二弾として、汎化性能の向上に寄与する、SoTA (State-of-the-Art) を獲得した最適化アルゴリズム SAM (Sharpness-Aware Minimization) (1) を使って、医療用画像の予測に適用させられるかを検証していく。
1. 始めに
損失関数で表現されるグラフは、局所的に損失が小さい所(図1左)より周辺が平坦で損失が小さい所(図1右)の方が汎化性能が高くなることが、知られている (2)。SAM (3)は、周辺が平坦かつ損失が最小となる所を探索しているため、単に損失が最小となる所を探索している従来の最適化アルゴリズムと比べて、汎化性能が高くなる。それにより、ImageNetやCIFAR-10、CIFAR-100等の9つのデータセットのSoTAを更新した。
図1. 損失関数で表現されたグラフ
2. 方法
モデル構築及び推論の実装には深層学習フレームワークPyTorch Lightningを用いており、SAMはPytorchベースのオープンソースコードを組み込んだ。SAMのベースの最適化アルゴリズムには、Adamを用いている。医療画像データは、肺疾患有無に関する2値分類問題である、アメリカ国立衛生研究所が提供する胸部X線画像データセット(NIH Chest X-rays (4))を用いた。データを学習用とテスト用に分け、最適化アルゴリズムをAdamにした場合とSAM(最適化アルゴリズムのベースはAdam)にした場合のテストデータの推論結果の違いを評価した。また、SAMを用いた学習では、損失探索の範囲であるρを異なる値(0.1、0.05、0.01)及びで学習を行い、3つのモデルを作成した。
3. 結果
テストデータ25596枚に対する予測結果から混同行列を作成し、accuracy、マクロ平均precision、マクロ平均recallの精度を算出した。結果、SAMモデルは、SAMを用いずに学習したモデル(Adamモデル)に対して、全指標で上回った。
accuracy | macro avg recall | macro avg precision | |
Adam | 65.54% | 63.92% | 63.77% |
SAM (ρ=0.1) | 69.53% | 64.72% | 68.29% |
SAM (ρ=0.05) | 69.39% | 66.24% | 67.51% |
SAM (ρ=0.01) | 68.10% | 64.80% | 66.05% |
4. 議論
Foret et al. (1) ではρ=0.05が推奨されているが、ρが0.1の時、0.05の時の結果を上回った。ρの値を上げたことで探索範囲が広がり、損失関数グラフにおいてより平坦な箇所を見つけられたと考えられる。今回、SAMを用いるだけで正解率が3%以上向上し、SAMが精度向上に寄与したことを示した。2024年現在では、SAMが派生して、ASAM(Adaptive SAM)、ESAM(Efficient SAM)等も登場しているため、それらを使うことでさらに良い結果が得られる可能性がある。
参考文献
1. Pierre Foret, Ariel Kleiner, Hossein Mobahi, Behnam Neyshabur. Sharpness-Aware Minimization for Efficiently Improving Generalization. International Conference on Learning Representations, 2021.
2. Nitish Shirish Keskar, Dheevatsa Mudigere, Jorge Nocedal, Mikhail Smelyanskiy, and Ping Tak Peter Tang. On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima. arXiv e-prints, art. arXiv:1609.04836, September 2016.
3 Xiaosong Wang, Yifan Peng, Le Lu, Zhiyong Lu, Mohammadhadi Bagheri, Ronald Summers. ChestX-ray8: Hospital-scale Chest X-ray Database and Benchmarks on Weakly-Supervised Classification and Localization of Common Thorax Diseases. IEEE CVPR, pp. 3462-3471, 2017.