ゼロから作るDeep Learning 3 ステップ59~ステップ60 まとめ dezeroでRNN,LSTMを実装する
hirohirohirohiros.hatenablog.com
前回のVGGに引き続き,dezeroでRNNとLSTMを実装します!
ステップ55, 56
RNN
RNNは時系列データに対して効果を発揮するモデルです.それは,RNNが,出力を新たに入力してループする構造を持っているからです.これにより,RNNは状態を持つ事が出来ます.RNNにデータが入力されると,状態が更新されその状態に応じて出力が決まります.ループする構造を持ち,状態を保持すると言われると複雑そうに感じますが,dezeroではシンプルに実装が出来ます.
RNNの順伝播はという式で表されます.この式から分かるように,RNNは重みを二つ持ちます.一つは入力xに対する重み,もう一つは過去の出力を次の時刻の出力にする重みです.これらの入力は別々の重みとして管理します.
これをdezeroで書くとこうなります.
class RNN(Layer): def __init__(self, hidden_size, in_size=None): super().__init__() self.x2h = Linear(hidden_size, in_size=in_size) self.h2h = Linear(hidden_size, in_size=in_size, nobias=True) self.h = None def forward(self, x): if self.h is None: h_new = F.tanh(self.x2h(x)) else: h_new = F.tanh(self.x2h(x) + self.h2h(self.h)) self.h = h_new return h_new
__init__でレイヤーを二つ用意しています.そしてforwardでF.tanh(self.x2h(x) + self.h2h(self.h))として,上の式を実装しています.もちろん,一番最初の入力では,過去のデータが存在しないため,h_new=Noneとなり,F.tanh(self.x2h(x))となります.
Truncated BPTT
BPTTとはBackpropagation Through Timeの略です.これは時間を遡って逆伝播を行う事を意味します.RNNはデータの時系列の並びを学習するので,時間という単語が使われています.Truncatedは打ち切るといった意味があります.よって,Truncated BPTTは時間を遡る逆伝播を打ち切る処理になります.
RNNはデータをいくつでも与えることが出来ます.そして,その数に応じて計算グラフが長く伸びていきます.しかし,それだと逆伝播が上手く行えないので,ある程度の長さで打ち切る必要があります.これをTruncated BPTTと言います.しかし,本当に打ち切る必要性があるのか,打ち切ることでどの程度効果があるのか次の実験で試してみたいと思います.
サイン波の予測
サイン波にノイズを与えたデータを学習データとして与えて,波の形状を予測するモデルを作成します.学習データはサイン波で,テストデータはコサイン波とします.BPTTの長さを30としたときのlossの推移はこうなります.
更に,予測結果はこうなります.
少しデータに乱れがありますが,おおよそ正しく予測できていることが分かります.
Truncated BPTTの効果検証
本書ではTruncated BPTTを割と唐突に紹介され,実際なぜ必要なのか,どれくらい効果があるのかは分かりにくいため,検証してみます.同じサイン波の学習とコサイン波の予測を行ってみます.
まず,Truncated BPTTを1回も行わず,データ全てを逆伝播させて学習を行ってみます.Truncatedする長さを指定するbptt_lengthを3000に設定し,一度もTruncatedしないようにして,学習させます.lossの推移はこうなります.
bptt_length=30の時と比べ,lossの減少がだいぶ遅い事が分かります.更に,bptt_length=30の時はepoch100の時ほぼ0だったのに対し,bptt_lenght=3000では,10程度あるので学習があまり進んでないことも分かります.この状態でコサイン波を予測させるとこうなります.
それっぽい波の形状は出来てますが,全くコサイン波と一致してないことが分かります.正直ここまで学習が出来ていないとは思いませんでした……途中で打ち切るTruncated BPTTの大切さが分かります.
逆に,2回に1回は打ち切るbptt_lenghtも試してみます.bptt_length=1は過去のデータを一つも使わないので普通のニューラルネットワークとなってしまうため,bptt_length=2で試してみます.
bptt_length=30の時と同じように適切に学習が進んでいることが分かります.コサイン波の予測はこうなります.
正確に予測することが出来ています.bptt_lengthを適切に設定し,途中で逆伝播を切ってやることで,適切に学習が出来る事が分かります.