有效地返回数组中第一个满足条件的值的索引

2025-03-10 08:52:00
admin
原创
61
摘要:问题描述:我需要找到满足条件的 1d NumPy 数组或 Pandas 数字系列中第一个值的索引。数组很大,索引可能靠近数组的开头或结尾,或者条件可能根本不满足。我无法提前判断哪种情况更有可能发生。如果不满足条件,则返回值应该是-1。我考虑了几种方法。尝试 1# func(arr) returns a Boo...

问题描述:

我需要找到满足条件的 1d NumPy 数组或 Pandas 数字系列中第一个值的索引。数组很大,索引可能靠近数组的开头或结尾,或者条件可能根本不满足。我无法提前判断哪种情况更有可能发生。如果不满足条件,则返回值应该是-1。我考虑了几种方法。

尝试 1

# func(arr) returns a Boolean array
idx = next(iter(np.where(func(arr))[0]), -1)

但这通常太慢了,因为会在整个func(arr)数组上应用矢量化函数,而不是在满足条件时停止。具体来说,当条件在数组开头附近满足时,成本会很高。

第二次尝试

np.argmax稍微快一点,但无法识别何时条件从未得到满足:

np.random.seed(0)
arr = np.random.rand(10**7)

assert next(iter(np.where(arr > 0.999999)[0]), -1) == np.argmax(arr > 0.999999)

%timeit next(iter(np.where(arr > 0.999999)[0]), -1)  # 21.2 ms
%timeit np.argmax(arr > 0.999999)                    # 17.7 ms

np.argmax(arr > 1.0)返回0,即条件不满足时的实例。

第三次尝试

# func(arr) returns a Boolean scalar
idx = next((idx for idx, val in enumerate(arr) if func(arr)), -1)

但是当条件在数组末尾附近满足时,速度就太慢了。这可能是因为生成器表达式有大量调用,开销很大__next__

这是否总是一种妥协,或者是否有一种方法,对于通用的func,可以有效地提取第一个索引?

基准测试

对于基准测试,假设func当某个值大于给定常数时找到索引:

# Python 3.6.5, NumPy 1.14.3, Numba 0.38.0
import numpy as np

np.random.seed(0)
arr = np.random.rand(10**7)
m = 0.9
n = 0.999999

# Start of array benchmark
%timeit next(iter(np.where(arr > m)[0]), -1)                       # 43.5 ms
%timeit next((idx for idx, val in enumerate(arr) if val > m), -1)  # 2.5 µs

# End of array benchmark
%timeit next(iter(np.where(arr > n)[0]), -1)                       # 21.4 ms
%timeit next((idx for idx, val in enumerate(arr) if val > n), -1)  # 39.2 ms

解决方案 1:

numba

有了numba它,就可以优化这两种情况。从语法上讲,你只需要构造一个带有简单for循环的函数:

from numba import njit

@njit
def get_first_index_nb(A, k):
    for i in range(len(A)):
        if A[i] > k:
            return i
    return -1

idx = get_first_index_nb(A, 0.9)

Numba 通过 JIT(“即时”)编译代码并利用CPU 级优化来提高性能。没有装饰器的常规 循环通常会比您已经尝试过的方法更慢,因为条件在后期才得到满足。for`@njit`

对于 Pandas 数字系列df['data'],您可以简单地将 NumPy 表示形式提供给 JIT 编译函数:

idx = get_first_index_nb(df['data'].values, 0.9)

概括

由于numba允许函数作为参数,并且假设传递的函数也可以进行 JIT 编译,因此可以得出一种方法来计算满足任意条件的第 nfunc个索引。

@njit
def get_nth_index_count(A, func, count):
    c = 0
    for i in range(len(A)):
        if func(A[i]):
            c += 1
            if c == count:
                return i
    return -1

@njit
def func(val):
    return val > 0.9

# get index of 3rd value where func evaluates to True
idx = get_nth_index_count(arr, func, 3)

对于倒数第三个值,您可以输入反转的arr[::-1],并对结果取反len(arr) - 1,这- 1需要考虑 0 索引。

绩效基准测试

# Python 3.6.5, NumPy 1.14.3, Numba 0.38.0

np.random.seed(0)
arr = np.random.rand(10**7)
m = 0.9
n = 0.999999

@njit
def get_first_index_nb(A, k):
    for i in range(len(A)):
        if A[i] > k:
            return i
    return -1

def get_first_index_np(A, k):
    for i in range(len(A)):
        if A[i] > k:
            return i
    return -1

%timeit get_first_index_nb(arr, m)                                 # 375 ns
%timeit get_first_index_np(arr, m)                                 # 2.71 µs
%timeit next(iter(np.where(arr > m)[0]), -1)                       # 43.5 ms
%timeit next((idx for idx, val in enumerate(arr) if val > m), -1)  # 2.5 µs

%timeit get_first_index_nb(arr, n)                                 # 204 µs
%timeit get_first_index_np(arr, n)                                 # 44.8 ms
%timeit next(iter(np.where(arr > n)[0]), -1)                       # 21.4 ms
%timeit next((idx for idx, val in enumerate(arr) if val > n), -1)  # 39.2 ms

解决方案 2:

我也想做类似的事情,但发现这个问题中提出的解决方案并没有真正的帮助我。特别是,numba对我来说,这个解决方案比问题本身中提出的更传统的方法要慢得多。我有一个times_all列表,通常有几万个元素,并且想要找到第一个times_all大于 a 的元素的索引time_event。我有成千上万个time_events。我的解决方案是将其分成times_all例如 100 个元素的块,首先确定time_event属于哪个时间段,保留该段第一个元素的索引,然后找到该段中的哪个索引,然后将两个索引相加。这是一个最小代码。对我来说,它的运行速度比本页中的其他解决方案快几个数量级。

def event_time_2_index(time_event, times_all, STEPS=100):
    import numpy as np
    time_indices_jumps = np.arange(0, len(times_all), STEPS)
    time_list_jumps = [times_all[idx] for idx in time_indices_jumps]

    time_list_jumps_idx = next((idx for idx, val in enumerate(time_list_jumps)\n                          if val > time_event), -1)
    index_in_jumps = time_indices_jumps[time_list_jumps_idx-1]
    times_cropped = times_all[index_in_jumps:]
    event_index_rel = next((idx for idx, val in enumerate(times_cropped) \n                      if val > time_event), -1)

    event_index = event_index_rel + index_in_jumps
    return event_index
相关推荐
  政府信创国产化的10大政策解读一、信创国产化的背景与意义信创国产化,即信息技术应用创新国产化,是当前中国信息技术领域的一个重要发展方向。其核心在于通过自主研发和创新,实现信息技术应用的自主可控,减少对外部技术的依赖,并规避潜在的技术制裁和风险。随着全球信息技术竞争的加剧,以及某些国家对中国在科技领域的打压,信创国产化显...
工程项目管理   4008  
  为什么项目管理通常仍然耗时且低效?您是否还在反复更新电子表格、淹没在便利贴中并参加每周更新会议?这确实是耗费时间和精力。借助软件工具的帮助,您可以一目了然地全面了解您的项目。如今,国内外有足够多优秀的项目管理软件可以帮助您掌控每个项目。什么是项目管理软件?项目管理软件是广泛行业用于项目规划、资源分配和调度的软件。它使项...
项目管理软件   2751  
  本文介绍了以下10款项目管理软件工具:禅道项目管理软件、Freshdesk、ClickUp、nTask、Hubstaff、Plutio、Productive、Targa、Bonsai、Wrike。在当今快速变化的商业环境中,项目管理已成为企业成功的关键因素之一。然而,许多企业在项目管理过程中面临着诸多痛点,如任务分配不...
项目管理系统   86  
  本文介绍了以下10款项目管理软件工具:禅道项目管理软件、Monday、TeamGantt、Filestage、Chanty、Visor、Smartsheet、Productive、Quire、Planview。在当今快速变化的商业环境中,项目管理已成为企业成功的关键因素之一。然而,许多项目经理和团队在管理复杂项目时,常...
开源项目管理工具   97  
  本文介绍了以下10款项目管理软件工具:禅道项目管理软件、Smartsheet、GanttPRO、Backlog、Visor、ResourceGuru、Productive、Xebrio、Hive、Quire。在当今快节奏的商业环境中,项目管理已成为企业成功的关键因素之一。然而,许多企业在选择项目管理工具时常常面临困惑:...
项目管理系统   85  
热门文章
项目管理软件有哪些?
曾咪二维码

扫码咨询,免费领取项目管理大礼包!

云禅道AD
禅道项目管理软件

云端的项目管理软件

尊享禅道项目软件收费版功能

无需维护,随时随地协同办公

内置subversion和git源码管理

每天备份,随时转为私有部署

免费试用