hirohirohirohirosのブログ

地方国立大学に通う情報系学部4年

ゼロから作るDeep Learning 3 ステップ11~ステップ16 まとめ

hirohirohirohiros.hatenablog.com

ステップ14

累積代入文の仕様

 a = a+bとa += bは全く同じ意味だと思っていましたが,厳密には違うようです.

x = np.array(1)
print(id(x))
x += np.array(1)
print(id(x))
x = x + np.array(1)
print(id(x))
>>1903444783952
>>1903444783952
>>1903444551024

 xに1を足すとき,累積代入文+=を使うとxは元のxと同じidになっています.つまり,元のxのメモリの値をそのまま書き換えていることになります.しかし,x=x+1とした時はidが変わっていてこの処理はコピーということになります.
 ただの整数ではこのような仕様は無く全てコピーされるようです.

x = 1
print(id(x))
x += 1
print(id(x))
x = x + 1
print(id(x))
>>1638183136
>>1638183168
>>1638183200

 ただし,xをリストにすると,ndarrayと同じ挙動を取ります.

x = [1]
print(id(x))
x += [1]
print(id(x))
x = x + [1]
print(id(x))
>>2698927754184
>>2698927754184
>>2698927754888

 これはpythonにおいて,intは変更不可能なイミュータブル,listは変更可能なミュータブルな事が関係しているのかなと思います.intは元から変更することが出来ないので,加算などしても,それはコピーされるだけなのは納得できます.listは変更可能なので同じメモリに上書きされるような挙動も起こすのだと思います.ただそれがa+=bかa=a+bかで変わるというのはいささか不思議な仕様だと思います……
 そして,np.array(1)のように0次元のndarrayでも,listと同じように変更可能なミュータブルだから,このような挙動になるのだと思われます.

ステップ16

Variableクラスのbackwardをheapqで実装する

 今までのVariableクラスのbackwardは世代順関係なく取り出していたため,適切な順番でbackwardされないという問題がありました.そこで世代順でソートし,一番大きい世代から取り出すことで正しい順番でbackwardする事が出来ます.ただし,ここでやりたいことは,一番大きい値を取り出す事です.よってfuncを全ての要素についてソートする必要はありません.
 全ての要素でソートすると,ソートするたびにO(Nlog N)の計算量が掛かってしまいます.今回やりたいことは一番大きな値を取り出す事だけなのでヒープ構造を使えばよさそうです.ヒープ構造で一番大きな値を取り出す操作はO(log N)で出来ます.本書ではソートする方法を載せておりヒープは実装されていません.ここで実装したいと思います.
 ソートするkeyは世代です.世代が同じならどの順番で取り出されても良いです.ヒープの要素にはタプルを取れるので,ヒープに値を追加するには

heapq.heappush(funcs, (-f.generation, f))

とすればよさそうです(ヒープは最小値を取り出す事が出来ないので最大値を取り出すには-を掛ける必要があります).
 しかし,これではエラーが出てしまいます.値を追加するとき既にヒープに同じ世代が含まれていたら,タプルの次の値を比較して順番を決めようとします.しかしタプルの二番目はf,つまりVariableクラスであり,>や<が定義されていないのでエラーとなってしまいます.
 これに対処するには色々方法があるかと思いますが,公式リファレンスでは項目番号を要素に追加する方法が提示されていました.同じ値が含まれていないなら,それ以降のタプルの値を見ることはないので,pushするたびに+1される項目番号を要素に追加すれば,完全な順序分けができることになります.
 この方法で実装しようと思います.countという変数を用意し,add_funcを呼び出すたび+1するようにします.そしてヒープに(世代, count, f)をpushします.取り出すとき,始め二つの要素は不必要で使わないで,それを示すように名前は_とします.

import heapq
count = 0
def add_func(f, count):
    if f not in seen_set:
        heapq.heappush(funcs, (-f.generation, count, f))
        seen_set.add(f)
        count += 1
   return count
        
count = add_func(self.creator, count)
while funcs:
    _, _, f = heapq.heappop(funcs)
...

このようにすればヒープで正しくbackwardが実装出来ているはずです.今のところ要素数が少ないので目に見えた差はありませんが,より大きなデータを扱うようになれば,ヒープによってこのDezeroが高速に動くようになるはずです.