如果 cumsum 大于值,则重新启动 cumsum 并获取索引
- 2025-03-12 08:50:00
- admin 原创
- 82
问题描述:
假设我有一个距离数组x=[1,2,1,3,3,2,1,5,1,1]
。
我想从 x 中获取 cumsum 达到 10 的索引,在本例中为 idx=[4,9]。
因此,在满足条件后,cumsum 会重新启动。
我可以用循环来完成它,但是对于大型数组来说循环很慢,我想知道是否可以用某种vectorized
方式来完成它。
解决方案 1:
一种有趣的方法
sumlm = np.frompyfunc(lambda a,b:a+b if a < 10 else b,2,1)
newx=sumlm.accumulate(x, dtype=np.object)
newx
array([1, 3, 4, 7, 10, 2, 3, 8, 9, 10], dtype=object)
np.nonzero(newx==10)
(array([4, 9]),)
解决方案 2:
这是一个带有 numba 和数组初始化的代码 -
from numba import njit
@njit
def cumsum_breach_numba2(x, target, result):
total = 0
iterID = 0
for i,x_i in enumerate(x):
total += x_i
if total >= target:
result[iterID] = i
iterID += 1
total = 0
return iterID
def cumsum_breach_array_init(x, target):
x = np.asarray(x)
result = np.empty(len(x),dtype=np.uint64)
idx = cumsum_breach_numba2(x, target, result)
return result[:idx]
时间安排
包括@piRSquared's solutions
并使用来自同一篇文章的基准测试设置 -
In [58]: np.random.seed([3, 1415])
...: x = np.random.randint(100, size=1000000).tolist()
# @piRSquared soln1
In [59]: %timeit list(cumsum_breach(x, 10))
10 loops, best of 3: 73.2 ms per loop
# @piRSquared soln2
In [60]: %timeit cumsum_breach_numba(np.asarray(x), 10)
10 loops, best of 3: 69.2 ms per loop
# From this post
In [61]: %timeit cumsum_breach_array_init(x, 10)
10 loops, best of 3: 39.1 ms per loop
Numba:附加与数组初始化
为了仔细观察数组初始化如何提供帮助,这似乎是两个 numba 实现之间的巨大差异,让我们对数组数据进行计时,因为数组数据创建本身在运行时很繁重,并且它们都依赖于它 -
In [62]: x = np.array(x)
In [63]: %timeit cumsum_breach_numba(x, 10)# with appending
10 loops, best of 3: 31.5 ms per loop
In [64]: %timeit cumsum_breach_array_init(x, 10)
1000 loops, best of 3: 1.8 ms per loop
为了强制输出拥有自己的内存空间,我们可以进行复制。不过不会对事情产生太大影响 -
In [65]: %timeit cumsum_breach_array_init(x, 10).copy()
100 loops, best of 3: 2.67 ms per loop
解决方案 3:
循环并不总是坏事(尤其是当你需要循环时)。此外,没有任何工具或算法可以比 O(n) 更快。所以让我们做一个好的循环吧。
生成器函数
def cumsum_breach(x, target):
total = 0
for i, y in enumerate(x):
total += y
if total >= target:
yield i
total = 0
list(cumsum_breach(x, 10))
[4, 9]
使用 Numba 进行即时编译
Numba 是一个需要安装的第三方库。Numba
可能对支持哪些功能很挑剔。但这个有用。
此外,正如 Divakar 指出的那样,Numba 在数组方面表现更好
from numba import njit
@njit
def cumsum_breach_numba(x, target):
total = 0
result = []
for i, y in enumerate(x):
total += y
if total >= target:
result.append(i)
total = 0
return result
cumsum_breach_numba(x, 10)
测试两者
因为我觉得¯_(ツ)_/¯
设置
np.random.seed([3, 1415])
x0 = np.random.randint(100, size=1_000_000)
x1 = x0.tolist()
准确性
i0 = cumsum_breach_numba(x0, 200_000)
i1 = list(cumsum_breach(x1, 200_000))
assert i0 == i1
时间
%timeit cumsum_breach_numba(x0, 200_000)
%timeit list(cumsum_breach(x1, 200_000))
582 µs ± 40.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
64.3 ms ± 5.66 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Numba 的速度大约快 100 倍。
为了进行更真实的苹果与苹果的测试,我将列表转换为 Numpy 数组
%timeit cumsum_breach_numba(np.array(x1), 200_000)
%timeit list(cumsum_breach(x1, 200_000))
43.1 ms ± 202 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
62.8 ms ± 327 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
这使得他们的比分大致相同。
相关推荐
热门文章
项目管理软件有哪些?
热门标签
曾咪二维码
扫码咨询,免费领取项目管理大礼包!
云禅道AD