📝
合成データアルゴリズム GAN

はじめに
Deep Learningベースの生成モデルは、近年目覚ましい発展を遂げています。2014年にGenerative Adversarial Networks(GAN) [1][2] というアルゴリズムが発表されてから、特に画像を学習し、生成するモデルは目覚ましい発展を遂げてきます。
GANは、ノイズベクトルを与え、そのノイズベクトルから画像を生成する Neural Network(NN) のモデル (Generator[生成器]と呼ぶ) を作り学習させ、その学習させた画像と本物の画像を見分ける NNのモデル(Discriminator[識別器]と呼ぶ)を作り、お互いに学習させていくというものです。
また、Variational Autoencoder(VAE)[3](2013年)や、Flow[4](2015年)といった生成モデルも別で開発されてきました(実は、GANよりVAEの方が先みたいですね)。VAEは、変分ベイズ推定法をNNで学習させるもの、Flowベースのモデルは、データ変形関数という関数をNNで学習させるような手法です。
最近では、Stable Diffusion (拡散モデルの一種)という手法が話題になっていました[5]。拡散モデルについては、2015年に初出論文はあったものの[6]、2020年にDDPM[7]という改良手法が出て、2021年にそこから改良が加えられたStable Diffusion[5] という手法が出て、精度の高さや拡張性の高さ、軽量でメモリの低いGPUでも学習できることなどから話題になりました。
今後のテーブルデータの合成データのアルゴリズムの解説の前準備として、最も基礎となっているGAN [1][2]の解説を書きます。Deep Learning の基礎知識を持っていることを前提とします。
GANについては、テーブルデータではなく画像の生成アルゴリズムですし、既にあらゆる人が解説を出しているので、今更ということもありますが、ここを理解していないと全てが理解できないので、今後の合成データの記事を理解するための最低限の知識として解説記事を出させていただきます。[8]あたりの記事をかなり参考にさせていただきました。
全体的な構造
GANは、敵対的生成ネットワークという名の通り、Generator という偽の画像を生成する NNと、 Discriminator という画像が本物か偽物かを判定する NN を交互に学習させていき、お互いが相手を越えようと学習をしていきます。
GeneratorはDiscriminatorに見分けられない画像を生成し、Discriminator はより精度高く本物と偽物を見分けられるように学習していきます。
Generator はランダムノイズベクトルを入力として画像を生成し、Discriminator が本物と見分けられなくなったら学習成功です。あとは Generator の NN に適当なノイズベクトルを入れれば、偽の画像が無限に生成できます。
GAN の loss 関数(損失関数)は、以下のようになっています。[1][9]
は学習に使う本物のデータの分布であり、 は Generator の入力のノイズベクトル の分布であり、 はGenerator 関数であり、 は出力が範囲 に及ぶ DIscriminator 関数です( 0~1 の連続値です)。 は、Discriminator がサンプル を偽物として分類し、 は本物 として分類することを意味します。 は期待値を意味します。期待値を取るということは、特定のイベントが発生する確率によって適切に重み付けされた、式が取る可能性のあるすべての値を「平均化」することを意味します。
学習アルゴリズム全体は以下のようになっています。
![[8] DCGAN論文: Alec Radford, et al. “Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks”(2016), page3, arxiv: https://arxiv.org/abs/1511.06434](https://s3.ap-northeast-1.amazonaws.com/wraptas-prod/acompany-engineer/9539301e-bea8-4e50-a232-cd06a5459bdd/d309b7743b47dd2adad76c2340925598.png)
これだけではよくわからないので、具体的に解説をしていきます。
Discriminator 学習
Discriminator からの方が説明しやすいので、まずはこちらから解説します。
![[8] 今さら聞けないGAN(1) 基本構造の理解, @triwave33, Qiita, 投稿日:2018-01-18, 閲覧日:2023-01-24, url: https://qiita.com/triwave33/items/1890ccc71fab6cbca87e](https://s3.ap-northeast-1.amazonaws.com/wraptas-prod/acompany-engineer/849f34b6-e343-4d04-86cf-6907710c56ab/8555d05c79f732aa352423767dafca2a.jpeg)
Discriminatorは、入力された画像が本物か偽物かを見分けるための NN です。なので、Generatorが出した偽の画像( 正解ラベル=0 )と、学習に使う本物の画像( 正解ラベル=1 )を入れて、偽物だと判断したときには 0 を出力し、本物だと判断したときには1 を出力するように学習させます。
ここで、Generatorはただ、偽の画像を生成させるために使っただけで、Discriminatorの学習とは全く関係ありません。”GAN”では、Discriminator の学習時には Generator は学習しません。また、Generatorの学習時には、Discriminatorは学習しません。それぞれの学習時はもう一方のネットワークは固定し、交互に学習を進めます。
Algorithm1 のところでも示されている通り、学習の単位は基本ミニバッチ単位です。ミニバッチとは、学習データの中から一定量のデータを抽出してそのまとまりで学習させていく、という単位です。後述しますが、バッチ学習(データ全体を一気に学習すること)で行うとうまく学習できなくなります。
loss 関数
もう一度 (1) 式を表示します。
(1) 式の定義通りの損失関数を計算します。ただ、Generatorは固定されているので、ただの関数とみなし、 の部分を学習させます。
の場合は、 なので、(1) 式を最大化させることが目標になります。この式は、ある程度機械学習等をやっていると見慣れた式ですが、バイナリクロスエントロピーですね。少し解説を加えます。2クラス分類を対象としたクロスエントロピーをバイナリクロスエントロピーと呼びます。知らない人は[10]あたりを読んでいただけるといいかなと思います。
は NN をsigmoid関数で最後に0~1の連続値として出力させたものです。sigmoid関数は0か1に限りなく近づきますが、0か1にはなりません。なので、 となって、負の無限大にぶっ飛んでいくことはありません。
理想的なのは、 と出力される場合なので、一旦期待値を考えずに、その時の loss 関数は
となります。逆に、これからずれていくごとに マイナスの値を loss 関数として出力することになります。
実際には、ミニバッチ単位での期待値として、データの平均値として計算されることになります。
そして、このloss関数が になるように(0に近づくように)、NN のパラメータを更新していくということになります。
Generator 学習
![[8] 今さら聞けないGAN(1) 基本構造の理解, @triwave33, Qiita, 投稿日:2018-01-18, 閲覧日:2023-01-24, url: https://qiita.com/triwave33/items/1890ccc71fab6cbca87e](https://s3.ap-northeast-1.amazonaws.com/wraptas-prod/acompany-engineer/02679f68-3cde-41f4-960a-26f7a393e31a/80de9ffc4a37135fe1faf9c140711ae3.jpeg)
Generator の目的変数は、Discriminator の本物か偽物かの判定になります。
VAEなどは、入力データそのものが目的変数になって損失関数を計算したりしますが、GANではDiscriminator を利用して損失関数を計算します。
Discriminator はこの時固定します。Back Propagation(損失関数からNNの重みを修正して学習させること) を行うのは、Generator のみであり、目的変数を出すために Discriminator を使うだけでこっちは学習させません。
loss 関数
もう一度(1)式を表示しておきます。
この際、第一項は Generator に関係ないですし、 も固定されているので、最適化問題からは除外します。
こちらは、 なので、この式ができるだけ小さくさせたいです。つまり、 になるようにしたいわけです。意味的には、Generator が出力した偽の画像を、Discriminator が本物と勘違いしてしまうのが、Generatorの目標になります。
あとは、NN のパラメータを、 に近づくように更新していけば良いことになります。
Generatorの出力を一回Discriminatorに通して損失関数を計算するというちょっと変わった形になります。Combind modelとか言ったりします。
NN 構造
基本的に、全て全結合層で構成されています。全結合層とは一番基本のニューロンが次の層のニューロンに全て繋がっている形のやつです。
GANの元論文には特に具体的なNN構造についての記述はありませんでした。なので、色々な実装を参考にしてこんな感じというものを作っていけたらと思います。
ただ、層の作り方は割と自由です。というか、いい感じにGeneratorとDiscriminatorが学習していけるように調整が必要です。
簡単な実装は、[11][12]を参考にすると、以下のような形になります。
以下は、Pytorchのクラスの表示形式の疑似コードです。具体的な実装は[13]を参照してください。input_dimは、入力の次元数です。MNISTの場合だと、(28,28)なので、 になります。
noize_dimは、generatorに入力するノイズベクトルの次元数です。128とかに設定しておくのが無難でしょうか。
論文中[1]には、Dropoutを使用したと書いてありましたので、Dropoutを適用するのがデフォルトだと思われます。Dropoutとは、ニューラルネットワークの学習時に、一定割合のノードを不活性化させながら学習を行うことで過学習を防ぎ(緩和し)、精度をあげるための手法です。[13][14]
また、論文中[1]の活性化関数として、Generatorには ReLU[15]と Sigmoid を使用していて、Discriminator には maxout [16][17]を使用していると書いてありました。
DiscriminatorのFlattan()は単純に次元を下げる物です。NNは2階か3階のテンソルとして表現されているので、最後の値は0~1の1つの値ですから、1階のテンソル=ベクトルに直すだけのものです。
複数の実装例を見ていると、上のような実装が一般的なように思われます。Generator、Discriminator共に、活性化関数は LeakyReLU [18]を使用して、Generatorの最後の活性化関数はTanhを使用します。これはおそらく後の GAN の改良アルゴリズムの影響を受けていると思われます。
ただ、もっと論文に準拠したシンプルな実装もあり[19][20][21]、学習がちゃんとできれば割と自由に実装して良さそうに思われます。
特徴と問題点
[22]がかなり色々考察していてとても参考になります。詳細はこちらの論文を読んでいただけると良いです。
まず、疑問ですが、Generator と Discriminator 両方一緒に学習しちゃダメなんですか?という疑問が出てきます。
これに対しては、明確な解答が見つけられませんでしたが、おそらくは学習が安定しないためだと思われます。
しかし、One Stage GAN(OSGAN) [23][24] という、Generator と Discriminator の学習を同時に行うアルゴリズムがのちに登場しています。
そして、GANの学習には大きな問題があります。[25]によると、
GANはGeneratorがDiscriminatorを騙せるようなきれいな画像を作り出すことができれば、よりよい画像を作ることができるが、ここにはよく知られた問題が2つある。 1つ目の問題はGeneratorあるいはDiscriminatorの一方だけが早く学習していまい、もう一方のネットワークの学習が進みづらくなる問題である。この問題は初期においてはDiscriminatorの学習が進みすぎて、Generatorがどのような画像を作っても簡単に見破ってしまい、結果としてGeneratorが学習できないというものだ。ただ、最近はGeneratorを安定して学習させる技術がいろいろ出てきたことで、以下にも説明するTTUR (two time-scale update rule)にもあるようにDiscriminatorの方を早く学習させるのが主流となっている。 2つ目の問題はmode collapseやmissing modeと呼ばれる問題で、Generatorが最も得意な画像、例えば数字なら0、だけを作ってDiscriminatorを騙すように学習してしまうことを指す。Generatorは1-9の数字は全く作ることなく、よりよい0を作ることを目指して学習してしまう。これを防ぐためには、Discriminatorが1枚1枚の画像だけではなく、バッチの中のデータの多様性などを判断できるようにするとよく、初期においてはfeature matchingやminibatch discriminationなどが提案された。
また、[26]によると、
GANの学習では、さまざまな「不安定性」が報告されています。まず、生成器 (generator) が似たようなデータしか出力しなくなるという、モード崩壊 (mode collapse) と呼ばれる現象が起きやすいことがよく知られています [22]。モード崩壊は、最適化の用語でいえば、モデルパラメータが悪い局所解にはまっている状態であるといえます。また、より悪い状況として、モデルパラメータがいずれの局所解にも収束せず、発散してしまうことがしばしばあります。さらに、Jensen–Shannonダイバージェンスに基づく標準的なGANでは、勾配消失現象が原理的に起きやすいことが指摘されています[27]。
この説明を見る限り、ミニバッチ学習でないと、学習の不安定化は避けられないと思われます。なので、ミニバッチ学習が基本となりそうです。Generator と Discriminator がいい感じに競い合って学習を進めていかないといけません。
また、GANは一般に学習が成功した時、鮮明な画像を出力できることで知られています。きれいで鮮明な画像を出力できたことから大きな話題を呼んだアルゴリズムでもあります。
loss 関数の最適解
最後に、少し loss 関数の最適解とその意味について考えます。[1]の元論文はかなりこの計算をしてくれています。[28][29]も参考にさせていただきました。
再度、(1)式を表示すると、
ここから、期待値を書き直します。
ここから、ノイズベクトル から Generatorが生成するデータの分布を とします。すると、
というふうに書き換えられます。元々のクロスエントロピーの式になります。
ここから、Generator を固定して、Discriminatorの最適解を導出します。なので、 は と独立になります。
(5) 式のクロスエントロピーは、 に関して上の凸の関数になります。なので、微分すると最大値が求められます。
(6)式の右辺 と置いて、 の最適解 を求めると、以下のようになります。
次に、Generatorの最適解を考えます。
(5)式に、 を代入して、 を最小化することを考えます。
3→4行目は、単純に の性質を使い、掛け算を足し算に直して分離しただけになります。また、 に直しています。
4→5行目は、Kullback-Leibler Divergence (KL-Divergence) と呼ばれる確率分布同士がどれだけ似ているかの尺度に直したものに変換しています[30]。
5→6行目は、Jensen–Shannon Divergence(JS-Divergence) と呼ばれる確率分布同士がどれだけ似ているかの尺度に直したものに変換しています[30]。
一応、これらの定義を書いておきます。
の最適解を考えたとき、 は、 の時に、最小値 を取ります。
また、最後の部分は、確率分布の元々の意味から、全ての領域(全てのx)で積分をすると、 になり、最適解 の時も全く同じになるので、
となります。
よって、損失関数 は , の時、最小値 を取ります。
まとめ
- GANに関する知識は、テーブルデータの合成データのアルゴリズムを理解するために重要である。
- GANは、GeneratorとDiscriminatorがそれぞれお互いを超えようと学習することで、鮮明な画像を生成するモデルである。
- GANの学習は不安定で難しい。
参考文献
[1] GAN論文 : Goodfellow, Ian, et al. "Generative adversarial Nets." (2014)arXiv url:https://arxiv.org/abs/1406.2661
[2] GAN論文 : Goodfellow, Ian, et al. "Generative adversarial networks." Communications of the ACM 63.11 (2020): 139-144. https://dl.acm.org/doi/abs/10.1145/3422622
[3] VAE論文 : Diederik P Kingma, et al. "Auto-Encoding Variational Bayes." (2013)arXiv url:https://arxiv.org/abs/1312.6114
[4] Flow論文 :Dinh, Laurent, David Krueger, and Yoshua Bengio. "Nice: Non-linear independent components estimation.”(2015)arXiv url : https://arxiv.org/abs/1410.8516
[5] Stable Diffusion論文 : Robin Rombach, et al. “High-Resolution Image Synthesis with Latent Diffusion Models” arXiv url: https://arxiv.org/abs/2112.10752
[6] Diffusion Model論文 : Jascha Sohl-Dickstein, et al. “Deep Unsupervised Learning using Nonequilibrium Thermodynamics”(2015) arXiv url :https://arxiv.org/abs/1503.03585
[7] DDPM論文 : Jonathan Ho, et al. “Denoising Diffusion Probabilistic Models” (2020) arXiv url :https://arxiv.org/abs/2006.11239
[8] GAN解説 : 今さら聞けないGAN(1) 基本構造の理解, @triwave33, Qiita, 投稿日:2018-01-18, 閲覧日:2023-01-24, url: https://qiita.com/triwave33/items/1890ccc71fab6cbca87e
[9] DCGAN論文: Alec Radford, et al. “Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks”(2016) arxiv: https://arxiv.org/abs/1511.06434
[10] クロスエントロピー解説 : 【初学者向け】クロスエントロピーを分かりやすく解説。, Biginaid,投稿日2022-12-23, 閲覧日: 2023-01-24, url : https://tips-memo.com/cross-entropy
[11] GAN実装 : Pytorch – GAN の仕組みと Pytorch による実装例, Pystyle, 投稿日2020-06-03, 閲覧日: 2023-01-24, url :https://pystyle.info/pytorch-gan/
[12] GAN実装 : PyTorch🔥 GAN Basic Tutorial for beginner, kaggle, 閲覧日2022-01-24, url: https://www.kaggle.com/code/songseungwon/pytorch-gan-basic-tutorial-for-beginner/notebook
[13] Dropout : Hinton, G. E., Srivastava, N., Krizhevsky, A., Sutskever, I., and Salakhutdinov, R. (2012b). Improving neural networks by preventing co-adaptation of feature detectors. Technical report, arXiv:1207.0580.
[14] Dropout解説 : 【ニューラルネットワーク】Dropout(ドロップアウト)についてまとめる, @shu_marubo, Qiita, 投稿日2017-07-18, 閲覧日: 2023-01-24, url: https://qiita.com/shu_marubo/items/70b20c3a6c172aaeb8de
[15] ReLU : Glorot, X., Bordes, A., and Bengio, Y. (2011). Deep sparse rectifier neural networks. In AISTATS’2011.
[16] Maxout : Goodfellow, I. J., Warde-Farley, D., Mirza, M., Courville, A., and Bengio, Y. (2013a). Maxout networks. In ICML’2013.
[17] Maxout解説 : ディープラーニングを実装から学ぶ(7-1)その他(活性化関数~MaxOut、ReLU関連), @Nezura, Qiita, 投稿日2018-07-15, 閲覧日: 2023-01-24, url: https://qiita.com/Nezura/items/f52fdc483e5e7eceb6b9
[18] leakyReLU : Xu, Bing, et al. "Empirical evaluation of rectified activations in convolutional network." arXiv preprint (2015). arXiv url: https://arxiv.org/pdf/1505.00853.pdf
[19] GAN実装 : github, firstcommit 2019-03-29, 閲覧日 2023-01-24, url : https://github.com/eriklindernoren/PyTorch-GAN
[20] GAN実装 : PytorchでGANを実装してみた。, @keiji_dl, Qiita, 投稿日2021-07-22, 閲覧日: 2023-01-24, url: https://qiita.com/keiji_dl/items/45a5775a361151f9189d
[21] GAN実装 : GAN (Generative Adversarial Network): Simple Implementation with PyTorch, KikaBeN, 投稿日2022-04-21, 閲覧日: 2023-01-24, url: https://kikaben.com/gangenerative-adversarial-network-simple-implementation-with-pytorch/
[22] GAN 考察 : Ian Goodfellow. NIPS 2016 Tutorial: Generative Adversarial Networks.
[23] OSGAN : Chengchao Shen, Youtan Yin, Xinchao Wang, Xubin Li, Jie Song, Mingli Song; Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2021, pp. 3350-3360 url: https://openaccess.thecvf.com/content/CVPR2021/html/Shen_Training_Generative_Adversarial_Networks_in_One_Stage_CVPR_2021_paper.html
[24] OSGAN解説 : GANの訓練はGとD同時にできる! 訓練を33%以上高速化するOne Stage GAN(OSGAN)の紹介, @koshian2, Qiita, 投稿日2022-02-14, 閲覧日: 2023-01-24, url:https://qiita.com/koshian2/items/da19a8baaad419ebcf78
[25] GAN問題 : GANの学習安定化テクニック, Programming for Beginners, 閲覧日:2023/01/24, url: https://tatsy.github.io/programming-for-beginners/python/stabilize-gan-training/
[26] GAN問題 :【ICLR2020採択論文】GANのなめらかさと安定性, Preferred Networks Blog, 投稿日2020-01-06, 閲覧日: 2023-01-24, url:https://tech.preferred.jp/ja/blog/smoothness-and-stability-in-gans/
[27] GANの不安定さ : Martin Arjovsky and Leon Bottou. Towards principled methods for training generative adversarial networks. In International Conference on Learning Representations (ICLR), 2017.
[28] GAN解説 : GANと損失関数の計算についてまとめた, @kzkadc, QIita, 投稿日:2018-12-15, 閲覧日:2023-01-24, url: https://qiita.com/kzkadc/items/f49718dc8aedbe8a1bee
[29] GAN解説 : 敵対的生成ネットワーク(GAN)の理論を解説, 努力のガリレオ, 投稿日2021-12-06, 閲覧日: 2023-01-24, url : https://dreamer-uma.com/gan-theory/
[30] JS,KL Divergence解説 : 機械学習で用いられるDivergence(KL-Divergence)やらEntropy (Cross-Entropy)の世界へ入門する, @harmegiddo, Qiita, 投稿日2020-04-29, 閲覧日: 2023-01-24, url: https://qiita.com/harmegiddo/items/2a24a36418fade0eaf44