Numpy squeeze() 函数
Numpy squeeze() 函数用于从数组形状中删除一维元素。
此函数可用于消除大小为 1 的维度,从而简化数组操作。例如,如果我们有一个形状为 (1, 3, 1, 5) 的数组,应用 squeeze() 会将其形状转换为 (3, 5),并移除其中的单例维度。
此函数接受一个可选的 axis 参数来指定要移除的维度,但如果未提供,则会移除所有单例维度。
结果是一个维度更少但数据相同的新数组。
语法
Numpy squeeze() 函数的语法如下:
numpy.squeeze(a, axis=None)
参数
以下是 Numpy squeeze() 函数的参数−
- a(array_like): 这是输入数据,应为数组或类数组对象。
- axis(None、int 或 int 元组,可选): 此参数选择形状中单维条目的子集。如果指定了轴,则仅挤压该轴或这些轴。如果未指定轴,则将删除所有单维条目。如果指定的轴的大小不是 1,则会引发错误。
返回值
此函数返回输入数组,但所有或部分大小为 1 的维度已被删除。这不会修改原始数组,而是返回一个新数组。
示例 2
以下是使用 Numpy squeeze() 函数的示例。在此示例中,形状为 (1, 3, 1) 的数组 'a' 被压缩,删除了所有一维元素,最终得到一个形状为 (3,) 的数组。-
import numpy as np # 原始数组形状为 (1, 3, 1) a = np.array([[[1], [2], [3]]]) print("原始数组形状:", a.shape) # 压缩后的数组 squeezed_a = np.squeeze(a) print("压缩后的数组形状:", squeezed_a.shape) print("压缩后的数组:", squeezed_a)
输出
原始数组形状:(1, 3, 1) 压缩后的数组形状: (3,) 压缩后的数组:[1 2 3]
示例 2
在此示例中,我们尝试压缩非一维轴,即轴 1,由于轴 1 的大小为 3,因此导致 ValueError -
import numpy as np # 原始数组形状为 (1, 3, 1) a = np.array([[[1], [2], [3]]]) print("原始数组形状:", a.shape) try: # 尝试压缩非一维轴 squeezed_a = np.squeeze(a, axis=1) except ValueError as e: print("Error:", e)
输出
原始数组形状:(1, 3, 1) Error: cannot select an axis to squeeze out which has size not equal to one
示例 3
以下示例展示了如何使用 numpy.squeeze() 从数组形状中删除一维元素 -
import numpy as np # 创建形状为 (1, 3, 3) 的三维数组 x = np.arange(9).reshape(1, 3, 3) print('数组 X:') print(x) print(' ') # 从 x 的形状中删除一维元素 y = np.squeeze(x) print('数组 Y:') print(y) print(' ') # 打印数组形状 print('X 和 Y 数组的形状:') print(x.shape, y.shape)
输出
数组 X: [[[0 1 2] [3 4 5] [6 7 8]]] 数组 Y: [[0 1 2] [3 4 5] [6 7 8]] X 和 Y 数组的形状: (1, 3, 3) (3, 3)