将itertools数组转换为numpy数组

8
我正在创建这个数组:
A=itertools.combinations(range(6),2)

我需要使用numpy操作这个数组,例如:

A.reshape(..

如果维度A很高,命令list(A)会非常慢。

我如何将itertools数组“转换”为numpy数组?

更新1: 我已经尝试了hpaulj的解决方案,在这种特定情况下略微慢一些,有什么想法吗?

start=time.clock()

A=it.combinations(range(495),3)
A=np.array(list(A))
print A

stop=time.clock()
print stop-start
start=time.clock()

A=np.fromiter(it.chain(*it.combinations(range(495),3)),dtype=int).reshape (-1,3)
print A

stop=time.clock()
print stop-start

结果:

[[  0   1   2]
 [  0   1   3]
 [  0   1   4]
 ..., 
 [491 492 494]
 [491 493 494]
 [492 493 494]]
10.323822
[[  0   1   2]
 [  0   1   3]
 [  0   1   4]
 ..., 
 [491 492 494]
 [491 493 494]
 [492 493 494]]
12.289898

你好,你有什么问题需要问吗? - Kotshi
我该如何将itertools数组“转换”为numpy数组? - stef_B.
4
如何从一个生成器构建一个NumPy数组?您可以使用NumPy的fromiter函数来从生成器中构建NumPy数组。fromiter函数需要两个参数:一个生成器和一个数据类型。下面是一个例子:import numpy as np def generate_numbers(): for i in range(10): yield i arr = np.fromiter(generate_numbers(), dtype=np.int) print(arr)输出:array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])在这个例子中,我们定义了一个生成器generate_numbers(),它生成0到9的整数。然后,我们使用NumPy的fromiter函数将生成器转换为一个整数类型的NumPy数组。最后,我们打印出这个数组的值。 - postelrich
你确定不是因为组合数过于庞大而导致速度"太慢"吗?如果你在尝试创建十亿个元素之类的东西,那总需要一些时间的。itertools.combinations函数会立即返回,因为它实际上并不提前创建任何组合,而是一个生成器。 - Blckknght
2个回答

6
我重新打开这个问题是因为我不喜欢链接的答案。接受的答案建议使用

标签来关闭

标签,但实际上这种做法可能会导致HTML错误。正确的做法是直接在

标签后面添加文本内容,而不是显式地关闭它。

np.array(list(A))  # producing a (15,2) array

但是OP显然已经尝试过list(A),并发现它很慢。

另一个答案建议使用np.fromiter。但是在其注释中隐藏的是需要1d数组的注意事项。

In [102]: A=itertools.combinations(range(6),2)
In [103]: np.fromiter(A,dtype=int)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-103-29db40e69c08> in <module>()
----> 1 np.fromiter(A,dtype=int)

ValueError: setting an array element with a sequence.

使用fromiter与此itertools需要以某种方式展开迭代器。

一个快速的计时集表明list不是缓慢的步骤。转换列表为数组是缓慢的步骤:

In [104]: timeit itertools.combinations(range(6),2)
1000000 loops, best of 3: 1.1 µs per loop
In [105]: timeit list(itertools.combinations(range(6),2))
100000 loops, best of 3: 3.1 µs per loop
In [106]: timeit np.array(list(itertools.combinations(range(6),2)))
100000 loops, best of 3: 14.7 µs per loop

我认为使用fromiter的最快方法是通过巧妙地使用itertools.chaincombinations展开:

In [112]: timeit
np.fromiter(itertools.chain(*itertools.combinations(range(6),2)),dtype=int)
   .reshape(-1,2)
100000 loops, best of 3: 12.1 µs per loop

在这种小规模的情况下,时间并没有节省多少。 (fromiter还需要一个count,可以再缩短一微秒。对于更大的情况,range(60)fromiter需要的时间是array的一半。


[numpy] itertools上进行快速搜索会发现许多纯numpy方法来生成所有组合。 itertools快速生成纯Python结构,但将其转换为数组是一个缓慢的步骤。


关于问题的一个挑剔的观点。

A是一个生成器,不是数组。list(A)确实会产生嵌套的列表,可以粗略地描述为一个数组。 但它不是np.array,也没有reshape方法。


你可以通过指定最终数组的大小来从 np.fromiter 中挤出更多的性能,这可以使用 scipy.special.binom(6, 2) 计算。 - ali_m
@hpaulj,我已经尝试了您的解决方案,请查看问题中的更新。 - stef_B.
有使用纯numpy方法生成组合的方式可能更快。@all_m建议使用triu,我相信其他一些方法已经在之前的SO问题中提出过。 - hpaulj
在 [numpy] itertools 上进行快速搜索,会得到许多纯 numpy 方式生成所有组合的建议。@hpaulj 您能否提供其中一些链接,因为我找不到? - winkmal
https://stackoverflow.com/search?q=%5Bnumpy%5D+itertools - hpaulj

3

获取N个元素的所有成对组合的另一种方法是使用np.triu_indices(N, k=1)生成一个(N, N)矩阵的上三角形索引,例如:

np.vstack(np.triu_indices(6, k=1)).T

对于小数组,itertools.combinations 更胜一筹,但对于较大的 Ntriu_indices 技巧可能更快:

In [1]: %timeit np.fromiter(itertools.chain.from_iterable(itertools.combinations(range(6), 2)), np.int)
The slowest run took 10.46 times longer than the fastest. This could mean that an intermediate result is being cached 
100000 loops, best of 3: 4.04 µs per loop

In [2]: %timeit np.array(np.triu_indices(6, 1)).T
The slowest run took 10.97 times longer than the fastest. This could mean that an intermediate result is being cached 
10000 loops, best of 3: 22.3 µs per loop

In [3]: %timeit np.fromiter(itertools.chain.from_iterable(itertools.combinations(range(1000), 2)), np.int)
10 loops, best of 3: 69.7 ms per loop

In [4]: %timeit np.array(np.triu_indices(1000, 1)).T
100 loops, best of 3: 10.6 ms per loop

我认为这个解决方案不会生成超过两个元素的组合。 - stef_B.
是的,我提到这个是因为你最初的问题是关于两个元素的组合。我认为可能可以将这种方法推广到处理超过两个元素的组合,但需要更多的思考。 - ali_m
我之前不知道 chain.fromiterable 这个函数。对于大数据情况,它的速度比 chain(*...) 快两倍。 - hpaulj

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接