数组中查找重复元素的查找函数

4
我有一个正整数的类似列表的Python对象,我想要获取该列表中重复值的位置。例如,如果输入为[0,1,1],则函数应该返回[1,2],因为元素1在输入数组的位置1和2处出现了两次。同样地: [0,13,13] 应该返回 [[1, 2]] [0,1,2,1,3,4,2,2] 应该返回 [[1, 3], [2, 6, 7]],因为1 在输入数组的位置[1, 3]处出现了两次,2 在位置[2, 6, 7]处出现了3次。 [1, 2, 3] 应该返回一个空列表[] 我的代码如下:
def get_locations(labels):
    out = []
    label_set = set(labels)
    for label in list(label_set):
        temp = [i for i, j in enumerate(labels) if j == label]
        if len(temp) > 1:
            out.append(np.array(temp))

    return np.array(out)

当输入数组较小时,它可以正常工作,但随着规模的增长,速度变得过慢。例如,下面的代码在我的电脑上,当n = 1000时从0.14秒飙升到12秒

from timeit import default_timer as timer
start = timer()
n = 10000
a = np.arange(n)
b = np.append(a, a[-1]) # append the last element to the end
out = get_locations(b)
end = timer()
print(out)
print(end - start) # Time in seconds

请问有什么方法可以加速这个过程吗?非常感谢任何建议。


一个非常相似的问题:高效地将列表中相同元素的索引分组 - Georgy
5个回答

4

您的嵌套循环导致时间复杂度为 O(n ^ 2)。相反,您可以创建一个字典列表来将索引映射到每个标签,并且仅在子列表的长度大于1时提取字典的子列表,这将将时间复杂度降低到O(n)

def get_locations(labels):
    positions = {}
    for index, label in enumerate(labels):
        positions.setdefault(label, []).append(index)
    return [indices for indices in positions.values() if len(indices) > 1]

get_locations([0, 1, 2, 1, 3, 4, 2, 2])返回:

[[1, 3], [2, 6, 7]]

也许值得一提的是,此解决方案依赖于字典的插入顺序,这在Python 3.6及以上版本中是成立的。 - sdcbr

3

您的代码由于嵌套的for循环而变得缓慢。您可以通过使用另一种数据结构以更高效的方式解决此问题:

from collections import defaultdict
mylist = [0,1,2,1,3,4,2,2]

output = defaultdict(list)
# Loop once over mylist, store the indices of all unique elements
for i, el in enumerate(mylist):
    output[el].append(i)

# Filter out elements that occur only once
output = {k:v for k, v in output.items() if len(v) > 1}

这将为您的示例 b 生成以下输出:
{1: [1, 3], 2: [2, 6, 7]}

你可以将此结果转换为所需格式:
list(output.values())
> [[1, 3], [2, 6, 7]]

需要知道的是,这取决于字典是否按插入顺序排序,这只适用于 Python 3.6 及以上版本。


非常感谢大家。blhsing和Vicrobot的建议与这个非常相似。但愿我能接受全部三个建议。你们的方法并没有跨越我的思维,看起来非常简单。再次感谢。 - Aenaon

1

这是我实现的一段代码。它以线性时间运行:

l = [0,1,2,1,3,4,2,2]
dict1 = {}

for j,i in enumerate(l): # O(n)
    temp = dict1.get(i) # O(1) most cases
    if not temp:
        dict1[i] = [j]
    else:
        dict1[i].append(j) # O(1) 

print([item for item in dict1.values() if len(item) > 1]) # O(n)

输出:

[[1, 3], [2, 6, 7]]

-1

这实际上是一个时间复杂度问题。您的算法具有嵌套的for循环,两次迭代列表,因此时间复杂度的级别为n^2(其中n是列表的大小)。因此,当您将列表的大小乘以10(从1,000增加到10,000)时,您会看到大约增加了10^2 = 100的时间。这就是为什么时间从0.14秒增加到12秒的原因。

这里有一个简单的解决方案,不需要额外的库:

def get_locations(labels):
    locations = {}
    for index, label in enumerate(labels):
        if label in locations:
            locations[label].append(index)
        else:
            locations[label] = [index]

    return [locations[i] for i in locations if len(locations[i]) > 1]

由于for循环没有嵌套,时间复杂度大约为2n,因此当问题规模加倍时,您应该看到时间增加了大约4倍。

-2
你可以尝试使用 "collections" 模块中的 "Counter" 函数。
from collections import Counter
list1 = [1,1,2,3,4,4,4]
Counter(list1)

你将会得到类似于这样的输出

Counter({4: 3, 1: 2, 2: 1, 3: 1})

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接