hirohirohirohirosのブログ

地方国立大学に通う情報系学部4年

PyTorch実戦入門 第3章 演習問題 解答まとめ

1

import torch
a = torch.tensor(list(range(9)))
print(a.size())
print(a.storage_offset())
print(a.stride())

>>torch.Size([9])
>>0
>>(1,)

a

 .view()を使うと,

b = a.view(3, 3)
print(b)

>>tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])

となります.3*3の行列に変形されたことが分かります.同じストレージを共有しているかどうかを確認するにはid()関数を使います.これは,Pythonが持っている関数で,オブジェクトのidを返します.このidが等しければ同じストレージを共有していることになります.

print(id(a.storage()) == id(b.storage()))

>> True

同じストレージを共有しているため,このようなことが起こります.

import torch
a = torch.tensor(list(range(9)))
b = a.view(3, 3)
a[0] = 10
print(a)
print(b)

>>tensor([10,  1,  2,  3,  4,  5,  6,  7,  8])
>>tensor([[10,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]])

a[0]を10に書き換えたことにより,aが[10, 1, 2, 3, 4, 5, 6, 7, 8]となっているのはいいのですが,bも先頭が10に書き換わってしまっています.同じストレージを使っているためです.値を書き換える際は気をつける必要があります.

b

c = b[1:, 1:]
print(c)
print(c.size())
print(c.storage_offset())
print(c.stride())

>>tensor([[4, 5],
        [7, 8]])
>>torch.Size([2, 2])
>>4
>>(3, 1)

2

a

a = torch.tensor(list(range(9)))
print(torch.cos(a))

>>tensor([ 1.0000,  0.5403, -0.4161, -0.9900, -0.6536,  0.2837,  0.9602,  0.7539,
        -0.1455])

関数をaに適用してもエラーが出ません……公式の回答によるとaをfloat型にしないと動かないようですがそうしなくても動きます……バージョンが変わったことにより動くようになったのでしょうか……?

b

 エラーがでないため省略