ゼロから作るDeep Learning 3 ステップ41~ステップ44 まとめ
hirohirohirohiros.hatenablog.com
ステップ42
トイデータセット
トイデータセットとは,実験用の小さなデータセットの事を言います.本書では線形回帰の実験をするためにy=ax+bとなるような点を100個用意していることが分かります.
scikit-learnでもトイデータセットが用意されています.例えば,
- ボストンの住宅価格
- アヤメの種類分別
- 糖尿病の進行状況
- 手書き文字分類
などがあります.そして,データ数はボストンの住宅価格で506,アヤメの種類分別で150と少なめになっていることが分かります.
https://scikit-learn.org/stable/datasets/toy_dataset.html
線形回帰
線形回帰を実行します.今回は損失関数はMSEを使っています.本書のコードに加えて,損失関数がどのように減少していっているかをプ六tpしてみようと思います.
コードは以下のようになります
np.random.seed(0) x = np.random.rand(100, 1) y = 5 + 2*x + np.random.rand(100, 1) W = Variable(np.zeros((1, 1))) b = Variable(np.zeros(1)) def predict(x): y = F.matmul(x, W) + b return y def mse(x0, x1): diff = x0 - x1 return F.sum(diff**2)/len(diff) lr = 0.1 iters = 100 loss_list = [] for i in range(iters): y_pred = predict(x) loss = mse(y, y_pred) W.cleargrad() b.cleargrad() loss.backward() W.data -= lr*W.grad.data b.data -= lr*b.grad.data print(W, b, loss) loss_list.append(loss.data) plt.plot(np.arange(100), np.array(loss_list)) plt.show()
plotするためのloss_listにlossを入れるとき,matplotlibは勿論自作クラスのVariableを扱うことは出来ないので,loss.dataとして数値型を入れるように気をつけます.すると表示されるグラフはこうなります.
今回のような簡単なタスクならdezeroでもすぐ損失関数を収束させることが出来ていると分かります.
ステップ43
functionクラスを継承したlinear関数
本書では通常の関数としてsimple_linear関数を作り,funcitonクラスを継承した関数はコードを載せずgithubのみの掲載となっていました.
そもそも,そのまま関数を作るのでなく,functionクラスを継承して作った方がよい理由は,メモリの使用効率が上がるからです.通常の関数で+やーといった演算を行うと,それぞれに対してノードが作られ計算グラフが出来ます.しかし,functionクラスを継承させて関数を作ればその関数1つでノードが作られます.こうすることで,計算を終えるとすぐ全てのデータがメモリから消去され,メモリの効率が上がります.
本書では,functionクラスを継承する利点がメモリの効率であることを使い,クラスを継承しなくても,使わない値をNoneとすることでメモリ効率を上げる工夫を説明していました.その代わり,functionクラスを継承しをここで掲載します(公式のgithubにあるものです)
class Linear(Function): def forward(self, x, W, b): y = x.dot(W) if b is not None: y += b return y def backward(self, gy): x, W, b = self.inputs gb = None if b.data is None else sum_to(gy, b.shape) gx = matmul(gy, W.T) gW = matmul(x.T, gy) return gx, gW, gb def linear(x, W, b=None):return Linear()(x, W, b)
bが0の時,普通に使うときはわざわざ0と書かず,bにあたる数をそもそも書かないのでbはNoneになることに気をつけます.その他backwardはステップ41でやった行列の積の微分を使います.