hirohirohirohirosのブログ

地方国立大学に通う情報系学部4年

ゼロから作るDeep Learning 3 ステップ25~ステップ32 まとめ

hirohirohirohiros.hatenablog.com

ステップ25, 26

計算グラフの可視化

 計算グラフを可視化するにはgraphvizというツールを使います.本書ではwindows版でのインストールの仕方が説明ありませんでしたが,graphvizのホームページに行き,windows版をダウンロードしてインストールすれば無事に出来ました.
 ステップ24で本書には無い独自に実装したBeale関数について計算グラフを作りました.それが下の画像になります.

ステップ27

テイラー展開微分

 sin関数の微分は解析的にcosと求められますが,テイラー展開することで近似的に微分することも出来ます.テイラー展開は展開する項数を指定することで近似の程度を調整することが出来ます.今回はthresholdという変数を用意し,項の値がthreshold未満になったら終了するという処理にしています.
 threshold=1e-04の時の計算グラフが以下です.

 それに対し,誤差をより少なくしたthreshold=1e-08としたときの計算グラフが以下です.

 thresholdを小さくすることで,深い計算グラフを作る事が出来ました.thresholdを小さくすることでfor分の繰り返される回数が増えたことによります.

ステップ32

高階微分の実装

 高階微分を実装ための最大のポイントはbackwardで受け渡す値をnp.arrayからVariableクラスに変える事で,計算グラフを構築させることでした.そのために,関数クラスの逆伝播を書き換える必要がありました.今まではndarrayを使っていたため,dataを取り出すようなことをしていましたが,これからはVariableクラスを直接扱うためそのまま書くことで実装出来ます.
 本書では,Add, Mulのみ変更点を記述していましたが,ここでは全ての関数について変更後のコードを載せてみます.変更する点はMulと同じようにdataを取り出していたところをinputsのみにする所です.また,Powクラスで少し躓いたのでその点も記述しておきます.

class Add(Function):
    def forward(self, x0, x1):
        y = x0 + x1
        return y

    def backward(self, gy):
        return gy, gy

class Mul(Function):
    def forward(self, x0, x1):
        y = x0*x1
        return y
    
    def backward(self, gy):
        x0, x1 = self.inputs
        return gy*x1, gy*x0

class Neg(Function):
    def forward(self, x):
        return -x

    def backward(self, gy):
        return -gy

class Sub(Function):
    def forward(self, x0, x1):
        y = x0 - x1
        return y

    def backward(self, gy):
        return gy, -gy

class Div(Function):
    def forward(self, x0, x1):
        y = x0 / x1
        return y

    def backward(self, gy):
        x0, x1 = self.inputs
        gx0 = gy / x1
        gx1 = gy*(-x0 / x1**2)
        return gx0, gx1

class Pow(Function):
    def __init__(self, c):
        self.c = c

    def forward(self, x):
        y = x**self.c
        return y

    def backward(self, gy):
        x, = self.inputs
        c = self.c
        gx = c*x**(c-1)*gy
        return gx

 Powクラスについて,(今までもそうなっていたかも知れませんが)x = self.inputs[0].dataの所を,x, = self.inputsとする必要があります.
 inputsはリストとして渡されるので,x = self.inputsとするとxがリストになってしまいエラーとなります.x, = self.inputsとすることで,inputsがアンパックされ,xが数字として扱われ,無事累乗することが出来ます.