hirohirohirohirosのブログ

地方国立大学に通う情報系学部4年

ゼロから作るDeep Learning 3 ステップ41~ステップ44 まとめ

hirohirohirohiros.hatenablog.com

ステップ42

トイデータセット

 トイデータセットとは,実験用の小さなデータセットの事を言います.本書では線形回帰の実験をするためにy=ax+bとなるような点を100個用意していることが分かります.
 scikit-learnでもトイデータセットが用意されています.例えば,

  • ボストンの住宅価格
  • アヤメの種類分別
  • 糖尿病の進行状況
  • 手書き文字分類

などがあります.そして,データ数はボストンの住宅価格で506,アヤメの種類分別で150と少なめになっていることが分かります.
https://scikit-learn.org/stable/datasets/toy_dataset.html

線形回帰

 線形回帰を実行します.今回は損失関数はMSEを使っています.本書のコードに加えて,損失関数がどのように減少していっているかをプ六tpしてみようと思います.
 コードは以下のようになります

np.random.seed(0)
x = np.random.rand(100, 1)
y = 5 + 2*x + np.random.rand(100, 1)

W = Variable(np.zeros((1, 1)))
b = Variable(np.zeros(1))

def predict(x):
    y = F.matmul(x, W) + b
    return y

def mse(x0, x1):
    diff = x0 - x1
    return F.sum(diff**2)/len(diff)

lr = 0.1
iters = 100
loss_list = []

for i in range(iters):
    y_pred = predict(x)
    loss = mse(y, y_pred)

    W.cleargrad()
    b.cleargrad()
    loss.backward()

    W.data -= lr*W.grad.data
    b.data -= lr*b.grad.data
    print(W, b, loss)
    loss_list.append(loss.data)

plt.plot(np.arange(100), np.array(loss_list))
plt.show()

 plotするためのloss_listにlossを入れるとき,matplotlibは勿論自作クラスのVariableを扱うことは出来ないので,loss.dataとして数値型を入れるように気をつけます.すると表示されるグラフはこうなります.

 今回のような簡単なタスクならdezeroでもすぐ損失関数を収束させることが出来ていると分かります.

ステップ43

functionクラスを継承したlinear関数

 本書では通常の関数としてsimple_linear関数を作り,funcitonクラスを継承した関数はコードを載せずgithubのみの掲載となっていました.
 そもそも,そのまま関数を作るのでなく,functionクラスを継承して作った方がよい理由は,メモリの使用効率が上がるからです.通常の関数で+やーといった演算を行うと,それぞれに対してノードが作られ計算グラフが出来ます.しかし,functionクラスを継承させて関数を作ればその関数1つでノードが作られます.こうすることで,計算を終えるとすぐ全てのデータがメモリから消去され,メモリの効率が上がります.
 本書では,functionクラスを継承する利点がメモリの効率であることを使い,クラスを継承しなくても,使わない値をNoneとすることでメモリ効率を上げる工夫を説明していました.その代わり,functionクラスを継承しをここで掲載します(公式のgithubにあるものです)

class Linear(Function):
    def forward(self, x, W, b):
        y = x.dot(W)
        if b is not None:
            y += b
        return y

    def backward(self, gy):
        x, W, b = self.inputs
        gb = None if b.data is None else sum_to(gy, b.shape)
        gx = matmul(gy, W.T)
        gW = matmul(x.T, gy)
        return gx, gW, gb

def linear(x, W, b=None):return Linear()(x, W, b)

 bが0の時,普通に使うときはわざわざ0と書かず,bにあたる数をそもそも書かないのでbはNoneになることに気をつけます.その他backwardはステップ41でやった行列の積の微分を使います.