A Fully Differentiable Beam Search Decoder
第11回最先端NLP勉強会
2019年 9月 28日
Ronan Collobert, Awni Hannun, Gabriel Synnaeve
In ICML, 2019
読み手:柯遠志(慶應大 萩原研 D3)
簡単にまとめると(by 発表者)
選んだ理由
H H H E E E L L _ L O _
H H E E E E L _ _ L O _
・・・・・・
HELLO
HELLO
重いし、文字の所属するアライメントも考慮しなくて性能も低い
昔の報告はLER 30%台
現在はほぼ使われない
繰り返した文字を認識するためのblank token
H H H E E E L L _ L O _
H H E E E E L _ _ L O _
・・・・・・
HELLO
HELLO
Bayes Ruleで計算するError Riskを最小化
Mutual Informationで出力と正解の関係を表現
正解になるPathのスコアを最大化
動的計画法で計算量を軽減
CTCの改良�(文字レベルの正則化しない、blankラベルを不使用)
H H H E E E L L _ L O _
H H E E E E L _ _ L O _
・・・・・・
HELLO
HELLO
H H H E E E L L _ L O _
H H E E E E L _ _ L O _
・・・・・・
HELLO
HELLO
今から見ると性能が低い(LER 30%台)
なんとWSJデータセットのSOTAらしい
よく使われてる音声認識の代表的なLoss関数たち
しかし、どちらにも、学習とInferenceが不一致のため、exposure biasとlabel biasがある
C
A
T
<B>
C
A
T
<E>
学習する時はGround Truthを使用
<B>
C
A
T
<E>
Inferenceする時は自分の出力を頼る
1個間違ったら後全部間違う
(Accumulated Error)
<B>
C
D
A
O
T
G
<E>
0.4
0.6
0.9
0.1
0.4
0.6
0.9
0.1
0.4
0.6
1.0
1.0
P(‘<B>CAT<E>’) = 0.324
P(‘<B>DOG<E>’) = 0.216
しかし、Greedy Searchなら‘<B>DOG<E>’
Beam Searchで解決できるが、InferenceだけにBeam Searchすれば、TrainingとInferenceが不一致
ビームから落ちるまで
はモデル自分の予測
落ちた箇所から次の予測はTeacher-forcing
ビームから落ちるまで
はモデル自分の予測
tまではビーム内であれば0, 落ちたら正の数
t時刻の正解のスコア
t時刻のK番目スコア
最終時刻K=1, 他の時刻K=beam size
Forward: ビームから落ちる箇所を見つけて修正
Backward: 一つの系列にまとめる
ビームから落ちるまで
はモデル自分の予測
落ちた箇所から次の予測はTeacher-forcing
ビームから落ちるまで
はモデル自分の予測
提案手法:Differentiable Beam Search Decoder (DBD)
DBDは何:微分できるBeam Search Decoder
どうやって:尤度の計算手法に工夫をかけて実現
特徴:
DBDの仕組み
正解の尤度を最大化するため、�正解になるのアライメントの尤度を全員最大化したい
しかし、普通のSoftmaxを計算すれば、重い
全部の可能の出力に対応する全部アライメントの集合、めっちゃでかい
DBDの仕組み
Beam Searchするので�全体的な一番いい系列≈Beam内の一番いい系列�
だから、Beam外の出力を考慮しなくてもいい
全部の可能→Beam内の可能、正則化項の計算が一気に軽く
モデルのHypothesis (=全部の予測?)
DBDのトレーニング
Beam Searchするので�全体的な一番いい系列≈Beam内の一番いい系列�
だから、Beam外の出力を考慮しなくてもいい
全部の可能→Beam内の可能、正則化項の計算が一気に軽く
モデルのHypothesis (=全部の予測?)
しかし、Trainingするとき、Ground Truthは必ずBにあるわけがないじゃない?
学習時、Groud Truthがビームから落ちれば?
DBDの答え:学習時は、正解を入れてBeamと正解アライメント集合の和集合正則化すればいい
学習時、Groud Truthがビームから落ちれば?
DBDの答え:学習時は、正解を入れてBeamと正解アライメント集合の和集合正則化すればいい
ので、この正則化項の計算も容易
動的計画法で計算 (CTC、ASGのForward Passと同じ)
普通にBeam Searchで計算
普通にBeam Searchで計算
全部微分できるのため、普通に連鎖律でBPで学習できる
音声モデルの出力スコア
文字レベルの言語モデル
単語レベルの言語モデル
音声モデルが出力した
t時刻の正解文字のスコア
文字モデルが出力したt-1時刻の文字からt時刻の正解文字の推移スコア
単語モデルが評価した出力全体のスコア
1-gramからK-gramの中で単語の関連性をモデリングする系列評価として最低限のモデル
実験1 音声データ+学習済み言語モデル
Model | nov93dev (Validation Set) | nov92 (Test Set) |
ASG 10M AM (beam size 8000) | 8.5 | 5.6 |
ASG 10M AM (beam size 500) | 8.9 | 5.7 |
ASG 7.5M AM (beam size 8000) | 8.8 | 6.0 |
ASG 7.5M AM (beam size 500) | 9.4 | 6.1 |
DBD 10M AM (beam size 500) | 8.7 | 5.9 |
DBD 7.5M AM (beam size 500) | 7.7 | 5.3 |
DBD 7.5M AM (beam size 1000) | 7.7 | 5.1 |
Attention RNN + CTC [Bahdanau+] | | 9.3 |
CNN + ASG [Zeghidour+] | 9.5 | 5.6 |
CNN + ASG (wav + convLM) [Zeghidour+] | 6.8 | 3.5 |
RNN + E2E-LF-MMI [Hadian+] | | 4.1 |
BILSTM + PAPB + CE [Baskar+] | | 3.8 |
Improved LF-MMI [Hadian+] | 4.3 | 2.5 |
実験1 音声データ+学習済み言語モデル
Model | nov93dev (Validation Set) | nov92 (Test Set) |
ASG 10M AM (beam size 8000) | 8.5 | 5.6 |
ASG 10M AM (beam size 500) | 8.9 | 5.7 |
ASG 7.5M AM (beam size 8000) | 8.8 | 6.0 |
ASG 7.5M AM (beam size 500) | 9.4 | 6.1 |
DBD 10M AM (beam size 500) | 8.7 | 5.9 |
DBD 7.5M AM (beam size 500) | 7.7 | 5.3 |
DBD 7.5M AM (beam size 1000) | 7.7 | 5.1 |
Attention RNN + CTC [Bahdanau+] | | 9.3 |
CNN + ASG [Zeghidour+] | 9.5 | 5.6 |
CNN + ASG (wav + convLM) [Zeghidour+] | 6.8 | 3.5 |
RNN + E2E-LF-MMI [Hadian+] | | 4.1 |
BILSTM + PAPB + CE [Baskar+] | | 3.8 |
Improved LF-MMI [Hadian+] | 4.3 | 2.5 |
もっと複雑なSOTAモデルに劣るけど...
小さいなビームサイズでも高性能でASGに勝る
実験2 音声データのみ (言語モデルはJoint Learning)
Model | nov93dev (Validation Set) | nov92 (Test Set) |
ASG (zero LM decoding) | 18.3 | 13.2 |
ASG (2-gram LM decoding) | 14.8 | 11.0 |
ASG (4-gram LM decoding) | 14.7 | 11.3 |
DBD zero LM | 16.9 | 11.6 |
DBD 2-gram LM | 14.6 | 10.4 |
DBD 2-gram-bilinear LM | 14.2 | 10.0 |
DBD 4-gram LM | 13.9 | 9.9 |
DBD 4-gram-bilinear LM | 14.0 | 9.8 |
RNN + CTC [Graves+] | | 30.1 |
Attention RNN + CTC [Bahdanau+] | | 18.6 |
Attention RNN + CTC + TLE [Bahdanau+] | | 17.6 |
Attention RNN + seq2seq + CNN [Chan+] | | 9.6 |
BILSTM + PAPB + CE [Baskar+] | | 10.8 |
実験2 音声データのみ (言語モデルはJoint Learning)
Model | nov93dev (Validation Set) | nov92 (Test Set) |
ASG (zero LM decoding) | 18.3 | 13.2 |
ASG (2-gram LM decoding) | 14.8 | 11.0 |
ASG (4-gram LM decoding) | 14.7 | 11.3 |
DBD zero LM | 16.9 | 11.6 |
DBD 2-gram LM | 14.6 | 10.4 |
DBD 2-gram-bilinear LM | 14.2 | 10.0 |
DBD 4-gram LM | 13.9 | 9.9 |
DBD 4-gram-bilinear LM | 14.0 | 9.8 |
RNN + CTC [Graves+] | | 30.1 |
Attention RNN + CTC [Bahdanau+] | | 18.6 |
Attention RNN + CTC + TLE [Bahdanau+] | | 17.6 |
Attention RNN + seq2seq + CNN [Chan+] | | 9.6 |
BILSTM + PAPB + CE [Baskar+] | | 10.8 |
簡単の言語モデルだけでも�ただ同時学習するだけで性能を向上できる
既存モデルからの転移学習の考察
Beam sizeについて考察
まとめ
参考文献