ゼロから作るDeep Learning 3 ステップ37~ステップ40 まとめ
hirohirohirohiros.hatenablog.com
ステップ38
実際のtranpose関数
本書で実装したtranspose関数は2次元テンソルの転置のみを行うように実装されていましたが,numpyのtransopose関数はより汎用性が高いです.
n次元テンソルでも,軸の順番を書くことで,軸の入れ替えを行うことが出来ます.4次元テンソルxがあった時,x.transpose(1, 0, 3, 2)と書くと,元のxの0次元目は1次元目に,1次元目は0次元目に,2次元目は3次元目に,3次元目は2次元目に,軸が入れ替わります.
今まで2次元テンソルでx.transpose()と書いたときは,x.transpose(1, 0)の省略形であるイメージだと思います.
これをdezeroで実装したコードは本書では掲載されておらず,githubに載せられているのみになっています.ここでは,そのコードの解説もしようと思います.
実際に実装したコードが以下になります.
class Transpose(Function): def __init__(self, axes=None): self.axes = axes def forward(self, x): y = x.transpose(self.axes) return y def backward(self, gy): if self.axes is None: return transpose(gy) axes_len = len(self.axes) inv_axes = tuple(np.argsort([ax % axes_len for ax in self.axes])) return transpose(gy, inv_axes)
始めに,__init__で,入れ替える軸の順番をaxesで保存します.2次元テンソルで軸を指定しなかった場合,axesはNoneになります.
forward関数は,2次元テンソルの時と同じようにそのままtranspose関数を使います.
backward関数では,軸を指定したときとしなかったときで処理が少し異なります.2次元テンソルで軸を指定しなかったときは,axesがNoneとなるので,if self.axes is Noneとして場合分けをし,本書で書かれてたコードと同じように実装します.
軸を指定していたとき処理が少し複雑です.結局,backwardでもテンソルを転置して返すという処理をすべきなのは,2次元テンソルから変わらないので,最終的に返すのはtransposeです.
inv_axesで,self.axesをnp.argsortとして,ソート後のインデックスを出力します.こうすることで,逆向きの入れ替えを表現することが出来ます.
[ax % axes_len for ax in self.axes]の処理は正直何の意味があるのかよく分かりませんでした……0からlen(x)まで全ての数字が入っているならself.axesと[ax % axes_len for ax in self.axes]は同じ値になるはずだからです.