在使用 PyTorch 进行深度学习模型构建时,经常需要改变矩阵的形状,常用的方法有 resize,view, reshape 等。这和 Numpy 中的 resize 和 reshape 之间的区别是什么。本篇以 JupyterLab 为平台演示两者转换矩阵形状的区别。

PyTorch Tensor

PyTorch 中改变 Tensor 形状的方法有 resize, view, reshape 等,

PyTorch 中 Tensor 的存储方式

PyTorch 中张量 Tensor 存储分为头信息区(包括大小size, 步长stride, 维度等)和存储区(Storage)。如下图,Tensor B 是对 Tensor A 进行截取或转置或改变形状后得到的,此时,Tensor B 共享 Tensor A 的存储区(内存开销小),只不过头信息不同而已。类似于浅拷贝。

png

可以通过如下方法判断两个Tensor是否是共用一个存储区

1
import torch
1
2
x = torch.ones(5)
x
tensor([1., 1., 1., 1., 1.])
1
2
y = x[2:]
y
tensor([1., 1., 1.])
1
2
# 比较存储区的内存地址
x.storage().data_ptr() == y.storage().data_ptr()
True

当两个张量共用一个存储区时,改变一个张量的元素值(数据),另一个也相应的改变。这就是浅拷贝的特点。

Tensor 的步长 Stride 与连续性

Tensor 的步长属性可以简单理解为从 Tensor 的一个维度跨到下一个维度的跨度。Tensor 的值在内存中是顺序存储的,stride[0] 表示从当前行跨到下一行需要调过几个元素,stride[1] 表示从当前列跨到下一列需要跳过几个元素。其他以此类推。如下例子

1
2
x = torch.ones(8).reshape(2, 4)
x
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.]])
1
x.stride()
(4, 1)
1
x.size()
torch.Size([2, 4])
1
2
# 同 size() 函数
x.shape
torch.Size([2, 4])

对于 Tensor 的 stride 和 size,如果满足如下公式,则称该 Tensor 是连续的:
$$stride[i] = stride[i + 1] \times size[i + 1]$$

如上例子中,stride[0] = 4, stride[1] = 1, size[1] = 4,满足上面的公式。因此 Tensor x 是连续的。当一个张量不是连续时,可以通过一些方法转化为连续的。请见 view 小结。

resize

resize 是改变张量形状的一种方法,它不仅能够保持数据区不变改变形状,还能够截取部分数据区或则填充数据区。但,该方法不能处理带有 requires_grad=True 的张量。

1
2
x = torch.Tensor(range(9))
x
tensor([0., 1., 2., 3., 4., 5., 6., 7., 8.])
1
x.shape
torch.Size([9])
1
x.size()
torch.Size([9])
1
2
# 获取存储区的内存地址
x.storage().data_ptr()
94208512799360

当 resize 设置的尺寸小于原 Tensor,则按原 Tensor 从左到右,从上到下截取

1
2
y = x.resize_(2, 2)
y
tensor([[0., 1.],
        [2., 3.]])
1
2
# 存储区地址不变
y.storage().data_ptr()
94208512799360

当 resize 设置的尺寸等于原 Tensor,则按原 Tensor 从左到右,从上到下改变形状

1
2
y = x.resize_(3, 3)
y
tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])
1
2
# 存储区地址不变
y.storage().data_ptr()
94208512799360

当 resize 设置的尺寸大于原 Tensor,则按原 Tensor 从左到右,从上到下填充,不足部分填充 0

1
2
y = x.resize_(4, 4)
y
tensor([[ 0.0000e+00,  1.0000e+00,  2.0000e+00,  3.0000e+00],
        [ 4.0000e+00,  5.0000e+00,  6.0000e+00,  7.0000e+00],
        [ 8.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [-3.3085e-19,  3.0736e-41, -1.0771e-19,  3.0736e-41]])
1
2
# 存储区地址改变
y.storage().data_ptr()
94208512507328

当 Tensor 中定义 requires_grad=True 时,resize 无法使用

1
2
x = torch.randn(6, requires_grad=True)
x
tensor([ 0.1426, -0.1536, -1.1960,  1.5682, -0.3277, -0.4954],
       requires_grad=True)
1
2
# 将出现错误
x.resize_(2, 3)
---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

/tmp/ipykernel_32385/1762610407.py in <module>
      1 # 将出现错误
----> 2 x.resize_(2, 3)


RuntimeError: cannot resize variables that require grad

view

view 是另一种改变张量形状的方法,它也属于浅拷贝,但它不能截取或填补数据区,并且只能处理那些满足连续性的张量。

1
2
3
x = torch.arange(8).reshape(2, 4).type(torch.float32)
x.requires_grad = True
x
tensor([[0., 1., 2., 3.],
        [4., 5., 6., 7.]], requires_grad=True)
1
x.stride()
(4, 1)
1
x.size()
torch.Size([2, 4])
1
x.storage().data_ptr()
94208510273472
1
2
y = x.view(4, -1)
y
tensor([[0., 1.],
        [2., 3.],
        [4., 5.],
        [6., 7.]], grad_fn=<ViewBackward0>)
1
2
# 内存地址不变,属于浅拷贝
y.storage().data_ptr()
94208510273472
1
2
# 当新形状设置的与原形状不匹配时,将发生错误
x.view(2, 2)
---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

/tmp/ipykernel_32385/1087069803.py in <module>
      1 # 当新形状设置的与原形状不匹配时,将发生错误
----> 2 x.view(2, 2)


RuntimeError: shape '[2, 2]' is invalid for input of size 8
1
x.view(6, 6)
---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

/tmp/ipykernel_32385/3418528958.py in <module>
----> 1 x.view(6, 6)


RuntimeError: shape '[6, 6]' is invalid for input of size 8
1
2
3
# 转置
y = x.permute(1, 0)
y
tensor([[0., 4.],
        [1., 5.],
        [2., 6.],
        [3., 7.]], grad_fn=<PermuteBackward0>)
1
y.stride()
(1, 4)
1
y.size()
torch.Size([4, 2])
1
y.storage().data_ptr()
94208510273472

通过上面的连续性公式判断,张量 y 不满足连续性。因此不是连续的。

1
2
# 非连续的张量,不能使用view
y.view(2, -1)
---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

/tmp/ipykernel_32385/3470641071.py in <module>
      1 # 非连续的张量,不能使用view
----> 2 y.view(2, -1)


RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

可以通过函数 contiguous() 将一个非连续的张量变成一个连续性的张量。其实,这种方法属于深拷贝,它使用原张量的数据新创建了一个连续性的张量。

1
2
z = y.contiguous()
z
tensor([[0., 4.],
        [1., 5.],
        [2., 6.],
        [3., 7.]], grad_fn=<CloneBackward0>)
1
z.stride()
(2, 1)
1
z.size()
torch.Size([4, 2])
1
z.storage().data_ptr()
94208510995136
1
z.view(2, -1)
tensor([[0., 4., 1., 5.],
        [2., 6., 3., 7.]], grad_fn=<ViewBackward0>)

reshape

reshape 是比 view 引入稍晚,但比 view 更健壮的一种方法。当张量满足连续性条件时,它等于 view,属于浅拷贝;当张量不满足连续性条件时,它先将张量连续化,然后再使用 view 改变形状,属于深拷贝。

1
2
3
x = torch.arange(8).reshape(2, 4).type(torch.float32)
x.requires_grad = True
x
tensor([[0., 1., 2., 3.],
        [4., 5., 6., 7.]], requires_grad=True)
1
x.stride()
(4, 1)
1
x.size()
torch.Size([2, 4])
1
x.storage().data_ptr()
94208509045056
1
torch.reshape(x, (4, 2))
tensor([[0., 1.],
        [2., 3.],
        [4., 5.],
        [6., 7.]], grad_fn=<ReshapeAliasBackward0>)
1
2
y = x.reshape(4, 2)
y
tensor([[0., 1.],
        [2., 3.],
        [4., 5.],
        [6., 7.]], grad_fn=<ReshapeAliasBackward0>)
1
y.storage().data_ptr()
94208509045056
1
2
y = x.permute(1, 0)
y
tensor([[0., 4.],
        [1., 5.],
        [2., 6.],
        [3., 7.]], grad_fn=<PermuteBackward0>)
1
y.stride()
(1, 4)
1
y.size()
torch.Size([4, 2])

可见张量 y 不满足连续性

1
y.storage().data_ptr()
94208509045056
1
2
3
# 不连续的张量也可以使用 reshape 改变形状
z = y.reshape(2, -1)
z
tensor([[0., 4., 1., 5.],
        [2., 6., 3., 7.]], grad_fn=<UnsafeViewBackward0>)
1
2
# 但存储区的地址改变了,即新创建了一片内存区域
z.storage().data_ptr()
94208510283328

Numpy array

本小节简单回顾下 numpy 中的类似用法。

resize vs. view vs. reshape

1
import numpy as np
1
2
x = np.arange(8).reshape(2, 4)
x
array([[0, 1, 2, 3],
       [4, 5, 6, 7]])
1
2
x.resize(4, 2)
x
array([[0, 1],
       [2, 3],
       [4, 5],
       [6, 7]])
1
np.resize(x, (2, 2))
array([[0, 1],
       [2, 3]])
1
np.resize(x, (4, 3))
array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 0],
       [1, 2, 3]])
1
x
array([[0, 1],
       [2, 3],
       [4, 5],
       [6, 7]])
1
x.view(np.float64)
array([[0.0e+000, 4.9e-324],
       [9.9e-324, 1.5e-323],
       [2.0e-323, 2.5e-323],
       [3.0e-323, 3.5e-323]])
1
x.reshape(1, 8)
array([[0, 1, 2, 3, 4, 5, 6, 7]])
1
x.reshape((8, 1))
array([[0],
       [1],
       [2],
       [3],
       [4],
       [5],
       [6],
       [7]])

参考文献

  1. PyTorch:view() 与 reshape() 区别详解
  2. np.resize和np.reshape()的区别
  3. Python numpy.ndarray.view用法及代碼示例