📝
表形式の合成データを生成する言語モデル GReaT

はじめに
表データは最も一般的な形式のデータであるにも関わらず、機械学習での学習が困難であることが知られています。なぜなら、表形式のデータにはカテゴリデータと数値データの両方が混在しており、さらに、欠落やノイズを含む場合が多いからです。また、プライバシー保護の観点から共有できない情報を含む場合、データそのものが使用できないことも少なくありません。表形式のデータを人工的に自由に生成できれば、上記の困難を克服できる可能性があります。そのため、現実的な表形式の合成データに関する研究が注目されています[1-3]。
表形式の合成データの生成に関する深層学習を用いた最新のアプローチのほとんどには、変分オートエンコーダー(Variational AutoEncoder, VAE)や敵対的生成ネットワーク(Generative Adversarial Networks, GAN)が用いられています。一方、深層学習は自然言語処理(Natural Language Processing, NLP)の分野にも革命をもたらし、2023年3月現在、OpenAIによって開発された人工知能チャットボットであるChatGPTが世間を賑わせています。ここで、GPTはGenerative Pre-trained Transformerの略であり、2018年に最初の論文[4]が公開されて以来、GPT-2[5]、GPT-3[6]、GPT-4[7,8]が発表されています。このような背景から、NLPのアーキテクチャが表形式の合成データの生成にどの程度有用なのか疑問が生じていました。この疑問に対する答えとして提案されたものが、今回ご紹介するGReaT[9]というモデルです。
GReaT(Generation of Realistic Tabular data)は、表形式データの生成にTransformer-decoderネットワークアーキテクチャを適用した最初のモデルです。本記事では、GReaTのモデル構造やGReaTによって生成される合成データの評価についてまとめます。
GReaT
Transformerベースのニューラルネットワークを使用したGReaTは、主に2つの段階で構成されています。1つ目は、テキストでエンコードされた表形式データセットに関する事前学習済み大規模言語モデル(Large Language Model, LLM)のファインチューニングです。ここで、ファインチューニングとは、学習済みモデルに新たな層を追加した後にモデル全体を再学習するプロセスのことです。2つ目は、ファインチューニングされたLLMを用いた表形式の合成データの生成です。
ファインチューニング
概要
![GReaTにおける事前学習済みLLMへのデータ変換の流れ。Borisov et al.(2022)[9]の図2より引用。](https://s3.ap-northeast-1.amazonaws.com/wraptas-prod/acompany-engineer/fca8eb2f-7344-4120-b5a2-41b6297e02b2/a6c5bd944124c5aefb1d569e8408726e.png)
上の図は、表形式データを用いてLLMのファインチューニングを行うためのデータ変換の流れを示しています。まず、図中の(a)で、表データを意味のある文に変換します。次に、図中の(b)で、特徴量の順序列を変更します。最後に、図中の(c)で、(b)で得られた文をLLMのファインチューニングに使用します。上記のプロセスでは、たとえば、”Occupation is doctor, Gender is female, Age is 34”のような文字列が使用されます。これは、前処理をほとんど必要とせず、情報の損失もないことを意味します。また、自然言語では語順が重要と考えられますが、提案手法では語順は重要でないことがわかります。以下に、提案手法の詳細を書きます。
テキストのエンコーディング
上図の(a)の操作です。特徴量名の列とサンプルの行で構成された表データを考えます。また、番目の特徴量の番目のデータ値をで表すことにします。このとき、表の各サンプルは次の主語-述語-目的語変換を用いてあるテキスト表現に変換されます。
ここで、はある特徴量名とそれに対応するデータ値の情報を含む句であり、は連結演算子を表します。たとえば、上図の左上で示したOriginal tabular data setの一行目は、”Age is 39, Education is Bachelors, Occupation is Adm-clerical, Gender is Male, Income is 50K”というテキスト表現に変換されます。
特徴量順序のランダム化
上図の(b)の操作です。表データの特徴量の間に順序関係はないので、テキスト表現の中のカンマ””で区切られた短い文をランダムに置換します。ここで、任意ので構成されるランダムな置換ベクトルを導入し、それをに作用させた結果を以下のように書くことにします。
この操作により、たとえば、上で示した”Age is 39, Education is Bachelors, Occupation is Adm-clerical, Gender is Male, Income is 50K”というテキスト表現は、”Education is Bachelors, Income is 50K, Age is 39, Occupation is Adm-clerical, Gender is Male”のように変換されます。このように変換されたテキスト表現を用いて、順序に依存関係を持たない生成言語モデルをファインチューニングします。
事前学習済みLLMのファインチューニング
上図の(c)の操作です。生成タスク用のエンコードされた表形式データに対する事前学習済みLLMのファインチューニングについて説明します。上記のエンコーディングとランダム化で得られたテキスト表現のセットを考えます。ここで、はランダムに抽出された置換ベクトルです。
LLMで処理するためには、入力文 を離散的かつ有限な語句の集合から意味をなす最小単位の字句であるトークンのシーケンスにエンコードする必要があります。これを”トークン化”と呼びます。ここで、であり、はを表現するために必要なトークンの数です。
トークンは文字や単語、サブ単語で構成されます。たとえば、”walking is good for health”という入力文が与えられたとします。これを文字ベースでトークン化すると、
になります。文字ベースを用いたトークン化では、トークンに含まれない未知語が現れにくくなります。一方、データ数が膨大になり、かつ、入力文の情報が失われるといった欠点があります。上記の入力文を単語ベースでトークン化すると、
になります。単語ベースを用いたトークン化では、データ数を抑えられ、入力文の情報も失われにくくなります。しかし、”walking”の原形である”walk”は未知語となり、学習することはできません。このように、文字ベースでのトークン化と単語ベースでのトークン化には、一長一短があります。そこで、サブ単語でのトークン化が提案されました。サブ単語を用いると、”walking”は”walk”と”ing”に分解され、”walk”という動詞の意味と”ing”という進行形の意味を分けて学習することができます。動詞と”ing”の組み合わせは動詞の数だけ存在するので、”動詞”と”動詞+ing”のペアを学習する場合に比べてデータ数を半分に抑えることができます。また、新たな動詞を学習した場合、その動詞の進行形も同時に学習することができます。以上が、文字・単語・サブ単語によるトークン化の説明です。
通常、自然言語では、次に現れるトークンは今までに現れたトークンの条件付き確率の積として次のように表現されます。
つまり、モデルは任意の長さのトークンの入力シーケンスから次のトークンを予測するための確率分布を出力するように訓練されます。この際、モデルの全体は、訓練データセット全体の確率を最大化するようなパラメータ最適化によって得られます。
実際にGReaTを使用するユーザーは、既存の生成言語モデルを選択することで、そのモデルに存在する膨大な量の知識を活用できます。このような大規模なデータの集まりを用いて学習することで、高精度なテキスト表現を生成するモデルが構築されます。
合成データの生成
概要
![GReaTを用いた合成データ生成過程の流れ。Borisov et al.(2022)[9]の図3より引用。](https://s3.ap-northeast-1.amazonaws.com/wraptas-prod/acompany-engineer/ac2c6736-4b73-4b7d-b4c4-9bcb4328a702/7d060c2a5c3412789e550f75b636cf12.png)
GReaTを用いた表形式の合成データ生成過程の流れを上図に示します。まず、図中の(a)で、単一の特徴量名または特徴量名とその値の組み合わせをテキストに変換したものをファインチューニングされたLLMに入力します。次に、図中の(b)で、ファインチューニングされたLLMから新しい文を生成します。最後に、図中の(c)で、生成された新しい文を表形式に変換します。
データ生成の式
上図の(b)の操作です。ファインチューニングにより、入力をもとにカテゴリ分布を返す回帰モデルを構築したあとは、いくつかのサンプリング方法を適用することができます。たとえば、次のトークンは、LLMの出力から次式のように生成できます。
ここで、は出力の分布を調整する温度パラメータです。上式の詳細については、文献[10]をご参照ください。
表形式データのサンプリングについて
![GReaTを用いたAdultデータセット[11]に対するサンプリングの例。Borisov et al.(2022)[9]の図9より引用。](https://s3.ap-northeast-1.amazonaws.com/wraptas-prod/acompany-engineer/adfb6a80-7e7d-444e-8a8b-b59d96e54789/0e18b46cf92e1a8fdec7f50d5c972d1e.png)
GReaTを用いたAdultデータセット[11]に対するサンプリングの例を上図に示します。この例では、教育(Education)、収入(Income)、年齢(Age)の3つの変数のみを使用しています。最終学歴が高等学校(HS-school)の場合の収入は74%の確率で$50K以下である一方、最終学歴が大学(Bachelors)の場合の収入は49%の確率で$50Kを上回るという現実に則した条件付き確率が得られています。同様に、最終学歴が高等学校(HS-school)かつ年収が$50Kを上回る場合の年齢は7%の確率で61歳であることなどもわかります。このような条件付き確率をもとに、新しい文が生成されます。その際、要求された形式を満足しないまれなケースに対しては、そのサンプルを却下します。論文[9]によると、却下されるサンプルの生成率は1%未満であることが言及されています。
評価
ここからは、複数の定性的および定量的な実験によって得られたGReaTのパフォーマンスを示します。
GReaTの構築に使用した事前学習済みLLM
![使用した事前学習済みLLMの詳細。Borisov et al.(2022)[9]の表9より引用。](https://s3.ap-northeast-1.amazonaws.com/wraptas-prod/acompany-engineer/d419fca3-9438-400c-825a-66483b872ba4/a6c3d8a69109eec97da0900ca9f2c72a.png)
上の表に示した2つの異なる事前学習済みLLMを使用してGReaTを構築します。1つ目は、8,200万の学習パラメーターを持つGPT-2のdistillバージョン[11]です。これに対応する表形式データの生成モデルを、”Distill-GReaT”と呼びます。2つ目は、3 億5,500万の学習パラメーターを持つGPT-2のオリジナルバージョン[5]です。これに対応する表形式データの生成モデルを、単に”GReaT”と呼びます。
比較手法
比較する手法として、TVAE[12]、CTGAN[12]、CopulaGAN[13]を使用します。TVAEはVAEを表形式データに適用させたもの、CTGANはGANを表形式データに適用させたものです。CopulaGANは、SDVオープンソースライブラリで公開されているCTGAN[12]の変化モデルです。したがって、TVAE[12]、CTGAN[12]、CopulaGAN[13]、Distill-GReaT[9]、GReaT[9]の5つを評価に用います。
データセット
以下に示す4つのデータセットを使用します。
- Adult Income[14]
- HELOC[15]
- California Housing[16]
- Travel Customers[17]
機械学習効率
![機械学習効率の比較。Borisov et al.(2022)[9]の表1より引用。TRはTravel Customers[17]、HEはHELOC[15]、ADはAdult Income[14]、CHはCalifornia Housing[16]のデータセットを表しています。LRは線形/ロジスティック回帰、DTは決定木、RFはランダムフォレストを表しています。分類では精度スコアを、回帰では平均二乗誤差を示しています。最も優れた結果は太字で、2番目に優れた結果は下線で示しています。](https://s3.ap-northeast-1.amazonaws.com/wraptas-prod/acompany-engineer/837cff23-3314-40be-bd76-a1e61640449b/725a1a0bc381082c6299fcc440b8d479.png)
各手法によって得られた機械学習効率を上の表に示します。Originalで示した値は、元データで学習して元データでテストした結果です。各手法で示した値は、合成データで学習して元データでテストした結果です。この表により、GReaTとDistill-GREaTのいずれかがすべての手法の中で最も優れており、高い機械学習効率を示すことがわかります。
ROCAUCとF1スコア
![ROCAUCおよびF1スコアの比較。Borisov et al.(2022)[9]の表4より引用。TRはTravel Customers[17]、HEはHELOC[15]、ADはAdult Income[14]のデータセットを表しています。LRは線形/ロジスティック回帰、DTは決定木、RFはランダムフォレストを表しています。最も優れた結果は太字で、2番目に優れた結果は下線で示しています。](https://s3.ap-northeast-1.amazonaws.com/wraptas-prod/acompany-engineer/87a7afa4-0447-48e9-8fc2-50f27a4c1e8f/62e0894aae95db7692a106c647fc0df8.png)
各手法によって得られたROCAUCおよびF1スコアを上の表に示します。ここで、ROCAUCは受信者動作特性曲線(ROC)よりも下の部分の面積(AUC)を示しており、値が高いほど判別性能に優れていることを表します。F1スコアは2値分類問題に対する評価指標の一つであり、適合率(PRE)と再現率(REC)の調和平均として定義されます。
ここで、PREとRECは真陽性(TP)・偽陽性(FP)・偽陰性(FN)を用いて、それぞれ
で定義されます。上の表から、GReaTやDistill-GReaTを用いるとROCAUCおよびF1スコアで優れた結果が得られることがわかります。
最近接レコード距離のヒストグラム
![California Housing[16]データセットを用いて得られた最近接レコード距離(DCR)のヒストグラムの比較。Borisov et al.(2022)[9]の図4より引用。”Original Test Data Set”は、元の訓練データセットと元のテストデータセットの間のDCRを表しています。](https://s3.ap-northeast-1.amazonaws.com/wraptas-prod/acompany-engineer/6b7878f5-4d74-4097-a088-06855b667e36/e1f391af990643b36d07f46634936a28.png)
最近接レコード距離(DCR)は、合成データが元データの単純なコピーではないことや、元データにノイズを追加しただけの結果ではないことを確認する指標です。DCRは、合成データセットと元データセットにおけるそれぞれの元を用いて、以下の式によって計算されます。
ここで、は2データ間の-normを表します。は、が元データセットにおいて、少なくとも1 つの実データと同一であることを意味します。内の合成レコードごとにDCRを計算し、結果の値をヒストグラムとしてプロットしたものを上図に示します。各手法においてであるデータはないため、プライバシーの観点から使用することができない合成データは生成されていないことがわかります。DCRにおいては、各手法で得られた結果の間に顕著な違いは確認されません。
Discriminatorによる識別
![Discriminatorによる識別結果の比較。Borisov et al.(2022)[9]の表2より引用。小さな値ほどDiscriminatorが識別しにくいことを表し、完全に識別できないデータセットの値は50を示します。最も優れた結果は太字で、2番目に優れた結果は下線で示しています。](https://s3.ap-northeast-1.amazonaws.com/wraptas-prod/acompany-engineer/d560d311-1862-4a1d-b4ed-34d5a8b88f5b/e2ace1a7f3934237fb1b4077751848c0.png)
生成された合成データが元データと簡単に区別できないかどうかを調べるために、生成された合成データの訓練セット(ラベル0)と元データの訓練セット(ラベル1)の混合でDiscriminatorをトレーニングします。そして、生成された合成データのテストセットと元データのテストセットから均等な割合のサンプルを抽出したデータセットにおける識別精度を比較します。その結果を表した上の表は、Distill-GReaTおよびGReaTで得られた合成データと元データは識別しにくいことを示しています。
二変量同時分布
![California Housing[16]データセットの元データと合成データの比較。Borisov et al.(2022)[9]の図1より引用。物件の特徴を表す変数の経度(Longitude)と緯度(Latitude)の結合ヒストグラムを示しています。黒線はカリフォルニア州の実際の境界線を表しています。](https://s3.ap-northeast-1.amazonaws.com/wraptas-prod/acompany-engineer/bbccaf94-52f1-4503-9a3a-b71d2959bd26/e7df855dce289a201b4fbbbf64168d0f.png)
GReaTによって生成される特徴分布と他の手法によって生成される特徴分布を定性的に比較するため、California Housing[16]データセットの経度と緯度の結合ヒストグラムを示します。CTGANとCopulaGANで得られた合成データの分布がカリフォルニア州の境界から大きくはみ出ています。このことから、CTGANとCopulaGANは変数間の依存関係をうまくモデル化できていないと言えます。TVAEで得られた合成データの分布はCTGANやCopulaGANのものと比べて良好ではありますが、それでもカリフォルニア州の境界の外にデータがはみ出ている領域が観察されます。一方、Distill-GReaTとGReaTでは、カリフォルニア州の実際の形状とよく一致するデータを生成できていることが示されています。これは、Distill-GReaTとGReaTで得られる合成データと元データの二変量同時分布はよく一致することを表しています。
元データと合成データの直接的な比較
![Adult Income[14]データセットの元データと合成データの比較。Borisov et al.(2022)[9]の図7より引用。年齢(age)と教育レベル(education_num)の結合分布を示しています。](https://s3.ap-northeast-1.amazonaws.com/wraptas-prod/acompany-engineer/7f2cd79b-5819-4624-9f2c-a84632eee7f0/af49cba7f4092831ec1abd2d8386fb2f.png)
上図は、カーネル密度推定器を使用して計算されたAdult Income[14]データセットの年齢と教育レベルの結合分布です。等高線の形状を見ると、Distill-GReaTとGReaTは他の手法に比べて、元データに類似した特徴量分布を持つ合成データを生成することがわかります。
計算時間
![計算時間の比較。Borisov et al.(2022)[9]の表7より引用。100エポックでの訓練及びファインチューニングに要した時間(training / fine-tuning time)と1000サンプルの生成に要した時間(sampling time)を示しています。](https://s3.ap-northeast-1.amazonaws.com/wraptas-prod/acompany-engineer/b14cc9f4-992f-43ba-bb32-86f82bbe9e1a/1197c930bdfa8a04f5411c2d15013e17.png)
各手法における訓練時間とサンプル生成時間を上の表に示します。TVAE、CopulaGAN、CTGANの訓練時間は数分である一方、Distill-GReaTやGReaTの訓練時間は1時間を超えていることがわかります。また、サンプル生成時間においても、Distill-GReaTやGReaTは他の手法に比べて数十倍以上の時間を要しています。このように、Distill-GReaTやGReaTは他の手法に比べて計算時間を要する手法だと言えます。
Pythonパッケージ
GReaTを使用するためのPythonパッケージが提供されています。詳しくは、参考文献[18,19]をご参照ください。
まとめ
本記事では、表形式のデータ生成にTransformer-decoderネットワークアーキテクチャを適用した最初のモデルであるGReaT[9]を紹介しました。GReaTの特徴は以下のとおりです。
- 事前学習済みLLMを使用するため、テキストデータベースに蓄積された膨大な知識を活用できる。
- データの前処理を必要としないため使いやすく、カテゴリデータや数値データを変換する必要がないため情報損失を最小限に抑えることができる。
- 他の手法(TVAE, CTGAN, CopulaGAN)に比べて、GReaTから得られる合成データは有用性や類似性の観点から優れていると言える。
- GReaTは、他の手法に比べて多くの計算時間を要する。
参考文献
[1] S. Bourou, A. El Saer, T. H. Velivassaki, A. Voulkidis, T. Zahariadis, A Review of Tabular Data Synthesis Using GANs on an IDS Dataset, Information, 12, 375 (2021).
[2] M. Hernandez, G. Epelde, A. Alberdi, R. Cilla, and D. Rankin, Synthetic data generation for tabular health records: A systematic review, Neurocomputing, 493, 28 (2022).
[3] V. Borisov, T. Leemann , K. Seßler , J. Haug , M. Pawelczyk , and G. Kasneci, Deep neural networks and tabular data: a survey, IEEE Trans. Neural Netw., 1 (2022).
[4] (GPT) A. Radford, K. Narasimhan, T. Salimans, and I. Sutskever, Improving language understanding by generative pre-training, (2018).
[5] (GPT-2) A. Radford, J. Wu, R. Child, D. Luan, D. Amodei, and I. Sutskever, Language models are unsupervised multitask learners, (2019).
[6] (GPT-3) T. Brown, B. Mann, N., M. Subbiah, J. D. Kaplan, P. Dhariwal, A. Neelakantan, P. Shyam, G. Sastry, A. Askell, S. Agarwal, A. Herbert-Voss, G. Krueger, T. Henighan, R. Child, A. Ramesh, D. Ziegler, J. Wu, C. Winter, C. Hesse, M. Chen, E. Sigler, M. Litwin, S. Gray, B. Chess, J. Clark, C. Berner, S. McCandlish, A. Radford, I. Sutskever, and D. Amodei, Language models are few-shot learners, In Advances in Neural Information Processing Systems, H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, and H. Lin (Eds.), 33, Curran Associates, Inc., 1877 (2020).
[7] (GPT-4) OpenAI, GPT-4 Technical Report, (2023).
[8] (GPT-4) OpenAI, GPT-4 System Card, (2023).
[9] (GReaT) V. Borisov, K. Seßler, T. Leemann, M. Pawelczyk, and G. Kasneci, Language models are realistic tabular data generators, arXiv:2210.06280 (2022).
[10] G. Hinton, O. Vinyals, and J. Dean, Distilling the knowledge in a neural network, arXiv:1503.02531 (2015).
[11] V. Sanh, L. Debut, J. Chaumond, and T. Wolf, DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter, In NeurIPS EMC2 Workshop, (2019).
[12] (TVAE,CTGAN) L. Xu, M. Skoularidou, A. Cuesta-Infante, and K. Veeramachaneni, Modeling tabular data using conditional GAN, in Proc. Adv. Neural Inf. Process. Syst., 33, 1 (2019).
[17] (Travel Customersデータセット) https://www.kaggle.com/datasets/tejashvi14/tour-travels-customer-churn-prediction