hirohirohirohirosのブログ

地方国立大学に通う情報系学部4年

ゼロから作るDeep Learning 3 ステップ47~ステップ51 まとめ

hirohirohirohiros.hatenablog.com

ステップ49

Datasetクラス

 大規模なデータからなるデータセットを使って学習しようとしたとき,その大規模データを一つのインスタンスにして扱おうとすると,メモリが足りなくなります.この問題を解決するために,Datasetクラスを作ります.
 dezeroでは,Datasetクラスは__getitem__と__len__の二つのメゾットを持つ事を必須としています.これはPytorchでも同様です.
 大規模データでは__getitem__メゾットで,indexを指定して呼び出したときに,ファイルにある{index}.npyのようにデータを取り出すようにすることで,大量にメモリを使用することを防げます.
 前処理もDatasetに書くことで,Datasetクラスの中で処理を完結されることが出来ます.
 Pytorchでやっていた事がどういうことだったのか分かってきて楽しいです.

ステップ50

DataLoader

 DataLoaderはミニバッチを取り出すために使います.ステップ49で,大規模データを一度に読み込むとメモリが足りなくなるという話があったように,学習時でも大規模データを一度に学習しようとすると,メモリが足りなくなります.よって,小さな学習データに小分けして,その小分けした学習データで学習します.この小分けした学習データをミニバッチと言います.ミニバッチを全て学習させると1エポックが終了したことになります.
 ステップ49の学習コードとステップ50の学習コードを比較すると,DataLoaderクラスの役目がよく分かります.
 ステップ49の学習コードはこのようになっています.(一部抜粋)

train_set = dezero.datasets.Spiral()
model = MLP((hidden_size, 10))
...
for epoch in range(max_epoch):
    index = np.random.permutation(data_size)
    sum_loss = 0
    
    for i in range(max_iter):
        batch_index = index[i*batch_size:(i+1)*batch_size]
        batch = [train_set[i] for i in batch_index]
        batch_x = np.array([example[0] for example in batch])
        batch_t = np.array([example[1] for example in batch])

        y = model(batch_x)
        ...

 対してステップ50のコードはこのようになります.

train_set = datasets.Spiral(train=True)
test_set = datasets.Spiral(train=False)
train_loader = DataLoader(train_set, batch_size)
test_loader = DataLoader(test_set, batch_size, shuffle=False)
model = MLP((hidden_size, 3))
...
for epoch in range(max_epoch):
    sum_loss= 0
    for x, t in train_loader:
        y = model(x)
        ...
    with dezero.no_grad():
        for x, t in test_loader:
            y = model(x)
            ...

 ステップ49ではbatch_indexという,どのインデックスのデータをミニバッチとするのか示したリストを作り,trainから抜き出していました.ステップ50では,その処理を全てDataloaderというクラスの中にしまい,train_loader = DataLoader(train_set, batch_size)と書くだけで,ミニバッチの取り出しを可能にしています.
 Dataloaderクラスは__iter__と__next__メゾットを記入しているため,イテレータです.イテレータなのでfor文を書くことで一つずつ値を取り出す事が出来ます.だからfor x, t in train_loader:という書き方が出来る訳です.
 Pytorchで使っているDataloaderが大体どうなっているのかを知れておもしろいです.