NumPy - 交集
NumPy 中的交集
在 NumPy 中,"交集"是指两个或多个数组之间共有的元素。
NumPy 提供了一个名为 numpy.intersect1d() 的内置函数,用于查找两个数组之间的交集。
什么是数组交集?
使用数组时,您可能经常需要查找同时出现在两个数组中的元素。这个过程称为求交集。
例如,如果您有两组数字,并且需要确定哪些数字同时出现在两组中,则可以执行交集运算。
NumPy intersect1d() 函数
在 NumPy 中,intersect1d() 函数用于求两个一维数组的交集,如果需要,甚至可以求多个数组的交集。
以下是 NumPy intersect1d() 函数的基本语法。它的工作原理是比较两个输入数组并返回包含公共元素的数组 -
numpy.intersect1d(ar1, ar2, assume_unique=False, return_indices=False)
其中,
- ar1, ar2:这两个输入数组是我们想要查找公共元素的。
- assume_unique:如果设置为 True,则假定两个输入数组都只包含唯一元素,从而加快计算速度。
- return_indices:如果设置为 True,则该函数不仅返回交集元素,还返回它们在原始数组中的索引数组。
示例
在下面的示例中,我们使用 numpy.intersect1d() 函数查找两个数组之间的共同元素 -
import numpy as np # 定义两个数组 array1 = np.array([1, 2, 3, 4, 5]) array2 = np.array([4, 5, 6, 7, 8]) # 求两个数组的交集 intersection = np.intersect1d(array1, array2) print("array1 与 array2 的交集:", intersection)
以下是得到的结果 -
array1 与 array2 的交集:[4 5]
假设元素唯一以加快计算速度
如果您确定输入数组仅包含唯一元素(即无重复元素),则可以将 True 传递给 assume_unique 参数。这样可以避免检查重复元素,从而加快计算速度:
示例
与上例一样,交集保持不变,但由于假设了唯一性,该函数效率更高 -
import numpy as np # 定义两个包含唯一元素的数组 array1 = np.array([1, 2, 3, 4, 5]) array2 = np.array([4, 5, 6, 7, 8]) # 假设元素唯一,求交集 intersection = np.intersect1d(array1, array2, assume_unique=True) print("假设元素唯一,求交集:", intersect)
输出结果如下:
假设元素唯一,求交集:[4 5]
返回交集元素的索引
除了返回交集元素之外,numpy.intersect1d() 函数还可以返回这些元素在输入数组中的索引。
当你想知道原始数组中的公共元素。为此,请将 return_indices 参数设置为 True。
示例
在此示例中,交集元素 4 和 5 分别出现在 array1 中的索引 3 和 4 处以及 array2 中的索引 0 和 1 处 -
import numpy as np # 定义两个数组 array1 = np.array([1, 2, 3, 4, 5]) array2 = np.array([4, 5, 6, 7, 8]) # 查找交集并返回索引 intersection, indices1, indices2 = np.intersect1d(array1, array2, return_indices=True) print("交集元素:", intersection) print("数组1中的索引:", indices1) print("数组2中的索引:", indices2)
执行上述代码后,我们得到以下输出 -
交集元素:[4 5] 数组1中的索引:[3 4] 数组2中的索引:[0 1]
多个元素的交集数组
numpy.intersect1d() 函数也可用于求两个以上数组的交集。
虽然该函数本身设计用于同时处理两个数组,但您可以使用循环或 functools 模块中的 reduce() 函数轻松将其扩展为处理多个数组。
示例
如下例所示,三个数组的公共元素为 5,因此构成了交集 -
import numpy as np from functools import reduce # 定义多个数组 array1 = np.array([1, 2, 3, 4, 5]) array2 = np.array([4, 5, 6, 7, 8]) array3 = np.array([5, 6, 7, 8, 9]) # 求所有数组的交集 intersection = reduce(np.intersect1d, [array1, array2, array3]) print("多个数组的交集:", intersection)
结果如下:-
多个数组的交集:[5]
处理不同数据类型的数组
NumPy 的 intersect1d() 函数也可以处理不同数据类型的数组,例如整数、浮点数和字符串。
但是,该函数会根据元素的数据类型进行比较,这意味着它会执行类型敏感的比较匹配。
示例
在此示例中,交集元素 4 以浮点数形式返回,因为第一个数组包含浮点数 -
import numpy as np # 定义具有不同数据类型的数组 array1 = np.array([1.0, 2.0, 3.0, 4.0]) array2 = np.array([4, 5, 6, 7]) # 查找交集元素 intersection = np.intersect1d(array1, array2) print("交集元素:", intersection)
输出结果如下 -
交集元素: [4.]
处理浮点精度问题
处理浮点数时,可能会出现精度问题,尤其是当值彼此非常接近但由于浮点运算的方式而不完全相同时。为了避免这种情况,您可以在执行交集运算之前对数组进行四舍五入。
示例
通过将数组四舍五入到小数点后两位,交集运算可以更精确地计算,即使浮点数差异很小,如下例所示 -
import numpy as np # 定义浮点数组 array1 = np.array([1.234, 2.345, 3.456, 4.567]) array2 = np.array([4.567, 5.678, 6.789]) # 对数组进行四舍五入并求交集 array1_rounded = np.round(array1, 2) array2_rounded = np.round(array2, 2) intersection = np.intersect1d(array1_rounded, array2_rounded) print("四舍五入后的交集:", Intersection)
输出结果如下 -
四舍五入后的交集:[4.57]