测试 Numpy 数组是否包含给定行
- 2025-01-21 09:01:00
- admin 原创
- 139
问题描述:
有没有一种 Pythonic 且有效的方法来检查 Numpy 数组是否包含给定行的至少一个实例?我所说的“有效”是指它在找到第一个匹配的行时终止,而不是迭代整个数组,即使已经找到结果。
使用 Python 数组可以非常干净地完成此操作if row in array:
,但对于 Numpy 数组,这并不像我预期的那样工作,如下所示。
使用 Python 数组:
>>> a = [[1,2],[10,20],[100,200]]
>>> [1,2] in a
True
>>> [1,20] in a
False
但 Numpy 数组给出的结果不同,而且看起来相当奇怪。(该__contains__
方法ndarray
似乎没有记录。)
>>> a = np.array([[1,2],[10,20],[100,200]])
>>> np.array([1,2]) in a
True
>>> np.array([1,20]) in a
True
>>> np.array([1,42]) in a
True
>>> np.array([42,1]) in a
False
解决方案 1:
您可以使用 .tolist()
>>> a = np.array([[1,2],[10,20],[100,200]])
>>> [1,2] in a.tolist()
True
>>> [1,20] in a.tolist()
False
>>> [1,20] in a.tolist()
False
>>> [1,42] in a.tolist()
False
>>> [42,1] in a.tolist()
False
或者使用视图:
>>> any((a[:]==[1,2]).all(1))
True
>>> any((a[:]==[1,20]).all(1))
False
或者通过 numpy 列表生成(可能非常慢):
any(([1,2] == x).all() for x in a) # stops on first occurrence
或者使用 numpy 逻辑函数:
any(np.equal(a,[1,2]).all(1))
如果你对这些进行计时:
import numpy as np
import time
n=300000
a=np.arange(n*3).reshape(n,3)
b=a.tolist()
t1,t2,t3=a[n//100][0],a[n//2][0],a[-10][0]
tests=[ ('early hit',[t1, t1+1, t1+2]),
('middle hit',[t2,t2+1,t2+2]),
('late hit', [t3,t3+1,t3+2]),
('miss',[0,2,0])]
fmt=' {:20}{:.5f} seconds and is {}'
for test, tgt in tests:
print('
{}: {} in {:,} elements:'.format(test,tgt,n))
name='view'
t1=time.time()
result=(a[...]==tgt).all(1).any()
t2=time.time()
print(fmt.format(name,t2-t1,result))
name='python list'
t1=time.time()
result = True if tgt in b else False
t2=time.time()
print(fmt.format(name,t2-t1,result))
name='gen over numpy'
t1=time.time()
result=any((tgt == x).all() for x in a)
t2=time.time()
print(fmt.format(name,t2-t1,result))
name='logic equal'
t1=time.time()
np.equal(a,tgt).all(1).any()
t2=time.time()
print(fmt.format(name,t2-t1,result))
您可以看到,无论命中与否,numpy 例程搜索数组的速度都相同。对于早期命中,Pythonin
运算符可能要快得多,而如果您必须遍历整个数组,生成器只是一个坏消息。
以下是 300,000 x 3 元素阵列的结果:
early hit: [9000, 9001, 9002] in 300,000 elements:
view 0.01002 seconds and is True
python list 0.00305 seconds and is True
gen over numpy 0.06470 seconds and is True
logic equal 0.00909 seconds and is True
middle hit: [450000, 450001, 450002] in 300,000 elements:
view 0.00915 seconds and is True
python list 0.15458 seconds and is True
gen over numpy 3.24386 seconds and is True
logic equal 0.00937 seconds and is True
late hit: [899970, 899971, 899972] in 300,000 elements:
view 0.00936 seconds and is True
python list 0.30604 seconds and is True
gen over numpy 6.47660 seconds and is True
logic equal 0.00965 seconds and is True
miss: [0, 2, 0] in 300,000 elements:
view 0.00936 seconds and is False
python list 0.01287 seconds and is False
gen over numpy 6.49190 seconds and is False
logic equal 0.00965 seconds and is False
对于 3,000,000 x 3 阵列:
early hit: [90000, 90001, 90002] in 3,000,000 elements:
view 0.10128 seconds and is True
python list 0.02982 seconds and is True
gen over numpy 0.66057 seconds and is True
logic equal 0.09128 seconds and is True
middle hit: [4500000, 4500001, 4500002] in 3,000,000 elements:
view 0.09331 seconds and is True
python list 1.48180 seconds and is True
gen over numpy 32.69874 seconds and is True
logic equal 0.09438 seconds and is True
late hit: [8999970, 8999971, 8999972] in 3,000,000 elements:
view 0.09868 seconds and is True
python list 3.01236 seconds and is True
gen over numpy 65.15087 seconds and is True
logic equal 0.09591 seconds and is True
miss: [0, 2, 0] in 3,000,000 elements:
view 0.09588 seconds and is False
python list 0.12904 seconds and is False
gen over numpy 64.46789 seconds and is False
logic equal 0.09671 seconds and is False
这似乎表明这np.equal
是实现此目的的最快的纯 numpy 方法......
解决方案 2:
__contains__
在撰写本文时,(a == b).any()
Numpys可以说只有当是标量时才是正确的b
(这有点复杂,但我相信 – 只有在 1.7 或更高版本中才能这样工作 – 这将是正确的通用方法,对于和维数(a == b).all(np.arange(a.ndim - b.ndim, a.ndim)).any()
的所有组合都是有意义的)...a
`b`
编辑:需要说明的是,当涉及广播时,这不一定a
是预期的结果。另外,有人可能会认为应该像现在这样单独处理项目np.in1d
。我不确定是否有一种明确的工作方式。
现在你想让 numpy 在找到第一个出现时停止。据我所知,目前还不存在这种情况。这很困难,因为 numpy 主要基于 ufuncs,它对整个数组执行相同的操作。Numpy 确实优化了这些类型的缩减,但实际上只有当被缩减的数组已经是布尔数组(即np.ones(10, dtype=bool).any()
)时才有效。
否则,它将需要一个__contains__
不存在的特殊函数。这可能看起来很奇怪,但你必须记住,numpy 支持许多数据类型,并且有更大的机制来选择正确的数据类型并选择正确的函数来处理它。换句话说,ufunc 机制无法做到这一点,而且__contains__
由于数据类型的原因,实现或这样的特殊功能实际上并不那么简单。
你当然可以用 python 编写它,或者因为你可能知道你的数据类型,所以用 Cython/C 自己编写它非常简单。
话虽如此。通常,对于这些事情,使用基于排序的方法要好得多。这有点乏味,而且没有这样的事情searchsorted
,lexsort
但它有效(如果你愿意,你也可以滥用它scipy.spatial.cKDTree
)。这假设你只想沿着最后一个轴进行比较:
# Unfortunatly you need to use structured arrays:
sorted = np.ascontiguousarray(a).view([('', a.dtype)] * a.shape[-1]).ravel()
# Actually at this point, you can also use np.in1d, if you already have many b
# then that is even better.
sorted.sort()
b_comp = np.ascontiguousarray(b).view(sorted.dtype)
ind = sorted.searchsorted(b_comp)
result = sorted[ind] == b_comp
这也适用于数组b
,如果你保留排序后的数组,那么如果你一次对单个值(行)执行此操作b
,效果会更好,因为a
保持不变(否则我会np.in1d
在将其视为重新数组之后执行此操作)。重要提示:为了安全起见,你必须执行np.ascontiguousarray
。它通常不会执行任何操作,但如果执行了,否则将是一个很大的潜在错误。
解决方案 3:
我认为
equal([1,2], a).all(axis=1) # also, ([1,2]==a).all(axis=1)
# array([ True, False, False], dtype=bool)
将列出匹配的行。正如 Jamie 指出的那样,要知道是否存在至少一行这样的行,请使用any
:
equal([1,2], a).all(axis=1).any()
# True
另外:
我怀疑in
(和__contains__
)与上面的一样,但使用any
而不是all
。
解决方案 4:
我将建议的解决方案与perfplot进行了比较,发现如果你在一个较长的未排序列表中寻找一个 2 元组,
np.any(np.all(a == b, axis=1))
是最快的解决方案。如果在前几行中找到匹配项,则显式短路循环总是可以更快。
在此处输入图片描述
重现情节的代码:
import numpy as np
import perfplot
target = [6, 23]
def setup(n):
return np.random.randint(0, 100, (n, 2))
def any_all(data):
return np.any(np.all(target == data, axis=1))
def tolist(data):
return target in data.tolist()
def loop(data):
for row in data:
if np.all(row == target):
return True
return False
def searchsorted(a):
s = np.ascontiguousarray(a).view([('', a.dtype)] * a.shape[-1]).ravel()
s.sort()
t = np.ascontiguousarray(target).view(s.dtype)
ind = s.searchsorted(t)
return (s[ind] == t)[0]
perfplot.save(
"out02.png",
setup=setup,
kernels=[any_all, tolist, loop, searchsorted],
n_range=[2 ** k for k in range(2, 20)],
xlabel="len(array)",
)
解决方案 5:
如果你确实想在第一次出现时停止,你可以编写一个循环,例如:
import numpy as np
needle = np.array([10, 20])
haystack = np.array([[1,2],[10,20],[100,200]])
found = False
for row in haystack:
if np.all(row == needle):
found = True
break
print("Found: ", found)
然而,我强烈怀疑,它会比使用 numpy 例程对整个数组执行此操作的其他建议慢得多。
解决方案 6:
要知道二维 numpy 数组中是否存在特定的一维 numpy 数组(行),一个更简单的方法是使用以下条件。
if np.sum(np.prod(2-darray == 1-darray),axis = 1)) > 0
如果np.sum(np.prod(2-darray == 1-darray),axis = 1))
大于0
,则该行存在于二维数组中,否则不存在。
扫码咨询,免费领取项目管理大礼包!