hirohirohirohirosのブログ

地方国立大学に通う情報系学部4年

ゼロから作るDeep Learning 3 ステップ37~ステップ40 まとめ

hirohirohirohiros.hatenablog.com

ステップ38

実際のtranpose関数

 本書で実装したtranspose関数は2次元テンソルの転置のみを行うように実装されていましたが,numpyのtransopose関数はより汎用性が高いです.
 n次元テンソルでも,軸の順番を書くことで,軸の入れ替えを行うことが出来ます.4次元テンソルxがあった時,x.transpose(1, 0, 3, 2)と書くと,元のxの0次元目は1次元目に,1次元目は0次元目に,2次元目は3次元目に,3次元目は2次元目に,軸が入れ替わります.
 今まで2次元テンソルでx.transpose()と書いたときは,x.transpose(1, 0)の省略形であるイメージだと思います.
 これをdezeroで実装したコードは本書では掲載されておらず,githubに載せられているのみになっています.ここでは,そのコードの解説もしようと思います.
 実際に実装したコードが以下になります.

class Transpose(Function):
    def __init__(self, axes=None):
        self.axes = axes

    def forward(self, x):
        y = x.transpose(self.axes)
        return y

    def backward(self, gy):
        if self.axes is None:
            return transpose(gy)

        axes_len = len(self.axes)
        inv_axes = tuple(np.argsort([ax % axes_len for ax in self.axes]))
        return transpose(gy, inv_axes)

 始めに,__init__で,入れ替える軸の順番をaxesで保存します.2次元テンソルで軸を指定しなかった場合,axesはNoneになります.
 forward関数は,2次元テンソルの時と同じようにそのままtranspose関数を使います.
 backward関数では,軸を指定したときとしなかったときで処理が少し異なります.2次元テンソルで軸を指定しなかったときは,axesがNoneとなるので,if self.axes is Noneとして場合分けをし,本書で書かれてたコードと同じように実装します.
 軸を指定していたとき処理が少し複雑です.結局,backwardでもテンソルを転置して返すという処理をすべきなのは,2次元テンソルから変わらないので,最終的に返すのはtransposeです.
 inv_axesで,self.axesをnp.argsortとして,ソート後のインデックスを出力します.こうすることで,逆向きの入れ替えを表現することが出来ます.
 [ax % axes_len for ax in self.axes]の処理は正直何の意味があるのかよく分かりませんでした……0からlen(x)まで全ての数字が入っているならself.axesと[ax % axes_len for ax in self.axes]は同じ値になるはずだからです.