如何使用传统的Python或者Pandas/Numpy/Scipy来选择列表中序列中重复项的第一个出现。

3

假设有一个名为“series”的列表,其中在多个索引值处有一些重复的元素。是否有一种方法可以找到数字重复序列的第一次出现。

series = [2,3,7,10,11,16,16,9,11,12,14,16,16,16,5,7,9,17,17,4,8,18,18]

返回值应类似于 [5,11,17,21],这些是重复序列 [16,16],[16,16,16],[17,17] 和 [18,18] 的第一次出现的索引值。
6个回答

4

这里有一个性能优化的方法,使用数组切片,类似于@piRSquared的第二种解决方案,但没有任何追加/连接操作 -

a = np.array(series)
out = np.flatnonzero((a[2:] == a[1:-1]) & (a[1:-1] != a[:-2]))+1

样例运行 -

In [28]: a = np.array(series)

In [29]: np.flatnonzero((a[2:] == a[1:-1]) & (a[1:-1] != a[:-2]))+1
Out[29]: array([ 5, 11, 17, 21])

运行时测试(用于工作解决方案)

方法 -

def piRSquared1(series):
    d = np.flatnonzero(np.diff(series) == 0)
    w = np.append(True, np.diff(d) > 1)
    return d[w].tolist()

def piRSquared2(series):
    s = np.array(series)
    return np.flatnonzero(
        np.append(s[:-1] == s[1:], True) &
        np.append(True, s[1:] != s[:-1])
    ).tolist()

def Zach(series):
    s = pd.Series(series)
    i = [g.index[0] for _, g in s.groupby((s != s.shift()).cumsum()) if len(g) > 1]
    return i

def jezrael(series):
    s = pd.Series(series)
    s1 = s.shift(1).ne(s).cumsum()
    m = ~s1.duplicated() & s1.duplicated(keep=False)
    s2 = m.index[m].tolist()
    return s2    

def divakar(series):
    a = np.array(series)
    x = a[1:-1]
    return (np.flatnonzero((a[2:] == x) & (x != a[:-2]))+1).tolist()

针对设置,我们只需将样本输入平铺多次。

时间 -

案例#1:大规模设置

In [34]: series0 = [2,3,7,10,11,16,16,9,11,12,14,16,16,16,5,7,9,17,17,4,8,18,18]

In [35]: series = np.tile(series0,10000).tolist()

In [36]: %timeit piRSquared1(series)
    ...: %timeit piRSquared2(series)
    ...: %timeit Zach(series)
    ...: %timeit jezrael(series)
    ...: %timeit divakar(series)
    ...: 
100 loops, best of 3: 8.06 ms per loop
100 loops, best of 3: 7.79 ms per loop
1 loop, best of 3: 3.88 s per loop
10 loops, best of 3: 24.3 ms per loop
100 loops, best of 3: 7.97 ms per loop

案例2:更大的数据集(前2个解决方案)

In [40]: series = np.tile(series0,1000000).tolist()

In [41]: %timeit piRSquared2(series)
1 loop, best of 3: 823 ms per loop

In [42]: %timeit divakar(series)
1 loop, best of 3: 823 ms per loop

现在,这两种解决方案仅在避免附加的方式上有所不同。让我们更仔细地看一下它们,在较小的数据集上运行 -

In [43]: series = np.tile(series0,100).tolist()

In [44]: %timeit piRSquared2(series)
10000 loops, best of 3: 89.4 µs per loop

In [45]: %timeit divakar(series)
10000 loops, best of 3: 82.8 µs per loop

因此,后一种解决方案中避免连接/追加在处理较小数据集时非常有帮助,但是在处理更大的数据集时,它们变得可比。

在使用该方法时,在较大的数据集上可能会有微小的改进。 因此,最后一步可以重写为:

np.flatnonzero(np.concatenate(([False],(a[2:] == a[1:-1]) & (a[1:-1] != a[:-2]))))

好的,我们要比速度吗?让我拿出numba :) 请看下面我的答案。 - Daniel F
@DanielF 我以为 OP 在使用 pandas/numpy/scipy ;) - Divakar
谁说作弊者永远不会成功? - Daniel F
感谢您详细的方法! - RTM

2
您可以使用 shift 键。
In [3815]: s = pd.Series(series)

In [3816]: cond = (s == s.shift(-1))

In [3817]: cond.index[cond]
Out[3817]: Int64Index([5, 11, 12, 17, 21], dtype='int64')

或者,diff
In [3828]: cond = s.diff(-1).eq(0)

In [3829]: cond.index[cond]
Out[3829]: Int64Index([5, 11, 12, 17, 21], dtype='int64')

使用tolist进行列表输出

In [3833]: cond.index[cond].tolist()
Out[3833]: [5, 11, 12, 17, 21]

细节说明
In [3823]: s.head(10)
Out[3823]:
0     2
1     3
2     7
3    10
4    11
5    16
6    16
7     9
8    11
9    12
dtype: int64

In [3824]: cond.head(10)
Out[3824]:
0    False
1    False
2    False
3    False
4    False
5     True
6    False
7    False
8    False
9    False
dtype: bool

需要 [5, 11, 17, 21] - jezrael
我之前尝试过这种方法,直到我意识到问题所在。假设重复的序列是s=pd.Series(1,2,3,3,3)。因此,如果您向左移动1位,当运行cond = (s == s.shift(-1))时,您将有2个True。但是返回值只需要是索引2,因为那是任何重复序列的第一个出现。 - RTM

2

np.diff & np.flatnonzero
这个答案使用了np.diff函数,并测试了差值是否为零。在这些点上,我们知道有重复项。我们使用np.flatnonzero函数来给出这些差异为零的位置。然而,我们只需要连续差异的第一个位置。所以我们再次使用np.diff函数来过滤掉重复序列中的第一个。这一次,我们将结果用作布尔掩码。

d = np.flatnonzero(np.diff(series) == 0)
w = np.append(True, np.diff(d) > 1)
d[w]

array([ 5, 11, 17, 21])

np.flatnonzero
我认为这是一种更好的方法。我们创建一个布尔数组,评估值何时等于下一个值但不等于前一个值。我们利用np.flatnonzero告诉我们True值的位置。

我也觉得这个答案的对称性很吸引人。

s = np.array(series)

np.flatnonzero(
    np.append(s[:-1] == s[1:], True) &
    np.append(True, s[1:] != s[:-1])
)

array([ 5, 11, 17, 21])

我认为,OP正在要求返回连续的重复索引。 - Zero
@Zero已修复。谢谢关注(-: - piRSquared

2

首先,使用 shiftcumsum 创建唯一的组,然后获取第一个重复项的掩码并通过布尔索引进行过滤:

s = pd.Series([2,3,7,10,11,16,16,9,11,12,14,16,16,16,5,7,9,17,17,4,8,18,18])

s1 = s.shift(1).ne(s).cumsum()
m = ~s1.duplicated() & s1.duplicated(keep=False)
s2 = m.index[m].tolist()
print (s2)
[5, 11, 17, 21]

print (s1)
0      1
1      2
2      3
3      4
4      5
5      6
6      6
7      7
8      8
9      9
10    10
11    11
12    11
13    11
14    12
15    13
16    14
17    15
18    15
19    16
20    17
21    18
22    18
dtype: int32

print (m)
dtype: int32
0     False
1     False
2     False
3     False
4     False
5      True
6     False
7     False
8     False
9     False
10    False
11     True
12    False
13    False
14    False
15    False
16    False
17     True
18    False
19    False
20    False
21     True
22    False
dtype: bool

1
在duplicated函数中聪明地使用keep参数。谢谢! - RTM

2

既然我们似乎在竞速比赛中竞争,而且没有人可能在不作弊的情况下击败Divakar / piRsquared,因为需要使用 pandas/numpy/scipy,所以这里是我的numba解决方案:

from numba import jit
import numpy as np

@jit
def rpt_idx(s):
    out = []
    j = True
    for i in range(len(s)):
        if s[i] == s[i+1]:
            if j:
                out.append(i)
                j = False
        else:
            j = True
    return out

rpt_idx(series)
Out: array([ 5, 11, 17, 21])

对于这样一个微不足道的情况,使用 jit 可能有点过度杀伤力,但它确实可以大大提高速度。

%timeit rpt_idx(series)
The slowest run took 10.50 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 1.99 µs per loop

%timeit divakar(series)
The slowest run took 7.73 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 12.5 µs per loop

series_ = np.tile(series,10000).tolist()

%timeit divakar(series_)
100 loops, best of 3: 20.1 ms per loop

%timeit rpt_idx(series_)
100 loops, best of 3: 5.84 ms per loop

不再是这样了 :) 而且更快了。 - Daniel F
承认作弊部分并在努力方面有所表现,值得点赞 :) - Divakar
在我看来,微秒级别的基准测试从来不是很好的选择。而且我们是在一个包含20个元素的数组上进行测量的吗? :) - Zero

1
你可以简单地模仿 Python 的 itertools.groupby,将相邻的重复项分组在一起。
>>> import pandas
>>> s = pandas.Series([2, 3, 7, 10, 11, 16, 16, 9, 11, 12, 14, 16, 16, 16, 5, 7, 9, 17, 17, 4, 8, 18, 18])
>>> for _, group in s.groupby((s != s.shift()).cumsum()):
...     if len(group) > 1:
...         print(group.index[0])
5
11
17
21

或者作为列表:
>>> [g.index[0] for _, g in s.groupby((s != s.shift()).cumsum()) if len(g) > 1]
[5, 11, 17, 21]

感谢您的方法。我不确定为什么,但是@jezrael的答案运行速度快了3倍,尽管他使用了cumsum()函数! - RTM

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接