numpy 是 Python 科学计算库,效率非常高,其自带的数据结构 ndarray 能够非常方便的处理多维数据,在 numpy 中定义了很多多维数据拼接的方法,本篇简要介绍它们。

hstack, vstack, dstack

hstack 水平拼接

hstack 能够沿水平方向拼接两个行数相同的多维数组。如下

1
2
3
4
import numpy as np

a = np.reshape(np.arange(1, 7), newshape=(3, 2))
a
array([[1, 2],
       [3, 4],
       [5, 6]])
1
2
b = np.reshape(np.arange(11, 17), newshape=(3, 2))
b
array([[11, 12],
       [13, 14],
       [15, 16]])
1
2
c = np.hstack((a, b))
c
array([[ 1,  2, 11, 12],
       [ 3,  4, 13, 14],
       [ 5,  6, 15, 16]])
1
a.shape, b.shape, c.shape
((3, 2), (3, 2), (3, 4))

vstack 垂直拼接

vstack 能够沿垂直方向拼接两个列数相同的多维数组。如下

1
2
3
4
import numpy as np

a = np.reshape(np.arange(1, 7), newshape=(2, 3))
a
array([[1, 2, 3],
       [4, 5, 6]])
1
2
b = np.reshape(np.arange(11, 17), newshape=(2, 3))
b
array([[11, 12, 13],
       [14, 15, 16]])
1
2
c = np.vstack((a, b))
c
array([[ 1,  2,  3],
       [ 4,  5,  6],
       [11, 12, 13],
       [14, 15, 16]])
1
a.shape, b.shape, c.shape
((2, 3), (2, 3), (4, 3))

dstack 深度拼接

dstack 能够沿着深度方向拼接两个行数和列数都相同的多维数组。如下

1
2
3
4
import numpy as np

a = np.reshape(np.arange(1, 7), newshape=(2, 3))
a
array([[1, 2, 3],
       [4, 5, 6]])
1
2
b = np.reshape(np.arange(11, 17), newshape=(2, 3))
b
array([[11, 12, 13],
       [14, 15, 16]])
1
2
c = np.dstack((a, b))
c
array([[[ 1, 11],
        [ 2, 12],
        [ 3, 13]],

       [[ 4, 14],
        [ 5, 15],
        [ 6, 16]]])
1
a.shape, b.shape, c.shape
((2, 3), (2, 3), (2, 3, 2))

concatenate 指定拼接方向

concatenate 能够指定拼接的方式,但只能指定水平和垂直方向,不包含深度方向。如下

1
2
3
4
import numpy as np

a = np.reshape(np.arange(1, 7), newshape=(2, 3))
a
array([[1, 2, 3],
       [4, 5, 6]])
1
2
b = np.reshape(np.arange(11, 17), newshape=(2, 3))
b
array([[11, 12, 13],
       [14, 15, 16]])
1
2
c = np.concatenate((a, b), axis=0)  # axis=0 表示在 axis=0 维度上扩展,即垂直方向上扩展,等效于 np.vstack
c
array([[ 1,  2,  3],
       [ 4,  5,  6],
       [11, 12, 13],
       [14, 15, 16]])
1
a.shape, b.shape, c.shape
((2, 3), (2, 3), (4, 3))
1
2
c = np.concatenate((a, b), axis=1)  # axis=1 表示在 axis=1 维度上扩展,即水平方向上扩展,等效于 np.hstack
c
array([[ 1,  2,  3, 11, 12, 13],
       [ 4,  5,  6, 14, 15, 16]])
1
a.shape, b.shape, c.shape
((2, 3), (2, 3), (2, 6))

stack 维度扩充

stack 维度扩展,把多个二维的矩阵扩展到三维,如下

1
2
3
4
5
6
7
8
import matplotlib.pyplot as plt
import numpy as np

a = np.zeros(shape=(1024, 1024), dtype=np.uint8)
a[:100, :100] = 255
print(a.shape)
plt.imshow(a, cmap="gray")
plt.show()
(1024, 1024)

png

1
2
3
4
5
b = np.ones(shape=(1024, 1024), dtype=np.uint8) * 255
b[100:200, 100:200] = 0
print(b.shape)
plt.imshow(b, cmap="gray")
plt.show()
(1024, 1024)

png

1
2
3
4
c2 = np.stack((a, b, a), axis=2)  # 增加到 3 维,沿第三维增加。沿深度方向拼接,类似于 dstack
print(c2.shape)
plt.imshow(c2)
plt.show()
(1024, 1024, 3)

png

1
2
c1 = np.stack((a, b, b), axis=1)  # 增加到 3 维,沿第二维增加
print(c1.shape)
(1024, 3, 1024)
1
2
3
4
d1 = np.transpose(c1, [0, 2, 1])
print(d1.shape)
plt.imshow(d1)
plt.show()
(1024, 1024, 3)

png

1
2
c0 = np.stack((b, b, a), axis=0)  #增加到 3 维,沿着第一维增加
print(c0.shape)
(3, 1024, 1024)
1
2
3
4
d0 = np.transpose(c0, [2, 1, 0])
print(d0.shape)
plt.imshow(d0)
plt.show()
(1024, 1024, 3)

png

参考文献

  1. numpy数组的拼接(扩维拼接和非扩维拼接)