PyTorch実戦入門 第3章 演習問題 解答まとめ
1
import torch a = torch.tensor(list(range(9))) print(a.size()) print(a.storage_offset()) print(a.stride()) >>torch.Size([9]) >>0 >>(1,)
a
.view()を使うと,
b = a.view(3, 3) print(b) >>tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
となります.3*3の行列に変形されたことが分かります.同じストレージを共有しているかどうかを確認するにはid()関数を使います.これは,Pythonが持っている関数で,オブジェクトのidを返します.このidが等しければ同じストレージを共有していることになります.
print(id(a.storage()) == id(b.storage())) >> True
同じストレージを共有しているため,このようなことが起こります.
import torch a = torch.tensor(list(range(9))) b = a.view(3, 3) a[0] = 10 print(a) print(b) >>tensor([10, 1, 2, 3, 4, 5, 6, 7, 8]) >>tensor([[10, 1, 2], [ 3, 4, 5], [ 6, 7, 8]])
a[0]を10に書き換えたことにより,aが[10, 1, 2, 3, 4, 5, 6, 7, 8]となっているのはいいのですが,bも先頭が10に書き換わってしまっています.同じストレージを使っているためです.値を書き換える際は気をつける必要があります.
b
c = b[1:, 1:] print(c) print(c.size()) print(c.storage_offset()) print(c.stride()) >>tensor([[4, 5], [7, 8]]) >>torch.Size([2, 2]) >>4 >>(3, 1)
2
a
a = torch.tensor(list(range(9))) print(torch.cos(a)) >>tensor([ 1.0000, 0.5403, -0.4161, -0.9900, -0.6536, 0.2837, 0.9602, 0.7539, -0.1455])
関数をaに適用してもエラーが出ません……公式の回答によるとaをfloat型にしないと動かないようですがそうしなくても動きます……バージョンが変わったことにより動くようになったのでしょうか……?
b
エラーがでないため省略