numpy 是 Python 科学计算库,效率非常高,其自带的数据结构 ndarray 能够非常方便的处理多维数据,在 numpy 中定义了很多多维数据拼接的方法,本篇简要介绍它们。
hstack, vstack, dstack
hstack 水平拼接
hstack 能够沿水平方向拼接两个行数相同的多维数组。如下
1 2 3 4
| import numpy as np
a = np.reshape(np.arange(1, 7), newshape=(3, 2)) a
|
array([[1, 2],
[3, 4],
[5, 6]])
1 2
| b = np.reshape(np.arange(11, 17), newshape=(3, 2)) b
|
array([[11, 12],
[13, 14],
[15, 16]])
array([[ 1, 2, 11, 12],
[ 3, 4, 13, 14],
[ 5, 6, 15, 16]])
1
| a.shape, b.shape, c.shape
|
((3, 2), (3, 2), (3, 4))
vstack 垂直拼接
vstack 能够沿垂直方向拼接两个列数相同的多维数组。如下
1 2 3 4
| import numpy as np
a = np.reshape(np.arange(1, 7), newshape=(2, 3)) a
|
array([[1, 2, 3],
[4, 5, 6]])
1 2
| b = np.reshape(np.arange(11, 17), newshape=(2, 3)) b
|
array([[11, 12, 13],
[14, 15, 16]])
array([[ 1, 2, 3],
[ 4, 5, 6],
[11, 12, 13],
[14, 15, 16]])
1
| a.shape, b.shape, c.shape
|
((2, 3), (2, 3), (4, 3))
dstack 深度拼接
dstack 能够沿着深度方向拼接两个行数和列数都相同的多维数组。如下
1 2 3 4
| import numpy as np
a = np.reshape(np.arange(1, 7), newshape=(2, 3)) a
|
array([[1, 2, 3],
[4, 5, 6]])
1 2
| b = np.reshape(np.arange(11, 17), newshape=(2, 3)) b
|
array([[11, 12, 13],
[14, 15, 16]])
array([[[ 1, 11],
[ 2, 12],
[ 3, 13]],
[[ 4, 14],
[ 5, 15],
[ 6, 16]]])
1
| a.shape, b.shape, c.shape
|
((2, 3), (2, 3), (2, 3, 2))
concatenate 指定拼接方向
concatenate 能够指定拼接的方式,但只能指定水平和垂直方向,不包含深度方向。如下
1 2 3 4
| import numpy as np
a = np.reshape(np.arange(1, 7), newshape=(2, 3)) a
|
array([[1, 2, 3],
[4, 5, 6]])
1 2
| b = np.reshape(np.arange(11, 17), newshape=(2, 3)) b
|
array([[11, 12, 13],
[14, 15, 16]])
1 2
| c = np.concatenate((a, b), axis=0) c
|
array([[ 1, 2, 3],
[ 4, 5, 6],
[11, 12, 13],
[14, 15, 16]])
1
| a.shape, b.shape, c.shape
|
((2, 3), (2, 3), (4, 3))
1 2
| c = np.concatenate((a, b), axis=1) c
|
array([[ 1, 2, 3, 11, 12, 13],
[ 4, 5, 6, 14, 15, 16]])
1
| a.shape, b.shape, c.shape
|
((2, 3), (2, 3), (2, 6))
stack 维度扩充
stack 维度扩展,把多个二维的矩阵扩展到三维,如下
1 2 3 4 5 6 7 8
| import matplotlib.pyplot as plt import numpy as np
a = np.zeros(shape=(1024, 1024), dtype=np.uint8) a[:100, :100] = 255 print(a.shape) plt.imshow(a, cmap="gray") plt.show()
|
(1024, 1024)
1 2 3 4 5
| b = np.ones(shape=(1024, 1024), dtype=np.uint8) * 255 b[100:200, 100:200] = 0 print(b.shape) plt.imshow(b, cmap="gray") plt.show()
|
(1024, 1024)
1 2 3 4
| c2 = np.stack((a, b, a), axis=2) print(c2.shape) plt.imshow(c2) plt.show()
|
(1024, 1024, 3)
1 2
| c1 = np.stack((a, b, b), axis=1) print(c1.shape)
|
(1024, 3, 1024)
1 2 3 4
| d1 = np.transpose(c1, [0, 2, 1]) print(d1.shape) plt.imshow(d1) plt.show()
|
(1024, 1024, 3)
1 2
| c0 = np.stack((b, b, a), axis=0) print(c0.shape)
|
(3, 1024, 1024)
1 2 3 4
| d0 = np.transpose(c0, [2, 1, 0]) print(d0.shape) plt.imshow(d0) plt.show()
|
(1024, 1024, 3)
参考文献
- numpy数组的拼接(扩维拼接和非扩维拼接)