Numpy,长数组问题

14

我有两个数组(a和b),它们各自包含n个整数,在范围(0,N)内。

笔误:数组实际上拥有2^n个整数,其中最大的整数取值为N = 3^n

我想要计算在所有可能的组合情况下a和b元素之和(sum_ij_ = a_i_ + b_j_ for all i,j),然后对N取模(sum_ij_ = sum_ij_ % N),最后计算不同求和结果的频率。

为了使用numpy快速地完成这个过程,而且没有任何循环,我尝试使用meshgrid函数和bincount函数。

A,B = numpy.meshgrid(a,b)
A = A + B
A = A % N
A = numpy.reshape(A,A.size)
result = numpy.bincount(A)

现在的问题是我的输入数组很长。当我使用具有2 ^ 13个元素的输入时,meshgrid会给出MemoryError错误。我想为具有2 ^ 15-2 ^ 20个元素的数组计算它。

这是n处于15到20的范围内

是否有任何聪明的技巧可以在numpy中完成此操作?

非常感谢任何帮助。

-- jon


NumPy真的会那么高效吗?我猜你最好用c++,编写自己的函数并尽可能地进行优化。从听起来的情况来看,NumPy无法处理如此大的数组。虽然我必须说,如果您有两个具有2^15到2^20个元素的数组,那么如果您查看它们所有不同的总和,那么您最终将得到一个具有2^30到2^40个元素的数组。这是很多的... - JSchlather
@unutbu: N~3^n@liberalkid: 我想你是对的。尽管我的C++技能不是很好。 - jonalm
3个回答

7

尝试对它进行分块处理。你的网格是一个NxN矩阵,将其分成10x10 N/10xN/10的块,只计算100个bin,在最后将它们加起来即可。这种方法只使用了整个过程的1%左右的内存。


我想这是一个可行的方式,但是有没有使用numpy数组的聪明方法来实现它。尽量减少for循环的使用。 - jonalm
嘿,一个块的最佳大小是多少? - jonalm
可能是您可以制作的最大块,并且仍然可以安全地存储在RAM中。 - Autoplectic

2

根据jonalm的评论进行编辑:

jonalm: N~3^n而不是n~3^N。其中N是a中的最大元素,n是a的元素数量。

n约为2^20。如果N约为3^n,则N约为3^(2^20) > 10^(500207)。科学家估计(http://www.stormloader.com/ajy/reallife.html),宇宙中只有大约10^87个粒子。因此,计算机无法处理大小为10^(500207)的int型变量(天真的方法)。

jonalm: 不过我对你定义的pv()函数有点好奇。(我无法运行它,因为text.find()未定义(猜测它在另一个模块中)。)这个函数如何工作,它有什么优势?

pv是我编写的一个小助手函数,用于调试变量的值。它的工作方式类似于print(),但当你输入pv(x)时,它会同时打印出字面上的变量名(或表达式字符串)、一个冒号和变量的值。

如果你放进去

#!/usr/bin/env python
import traceback
def pv(var):
    (filename,line_number,function_name,text)=traceback.extract_stack()[-2]
    print('%s: %s'%(text[text.find('(')+1:-1],var))
x=1
pv(x)

在脚本中,你应该得到:
x: 1

使用 pv 而不是 print 的小优势在于它节省了您的打字时间。您不必再写出以下内容:
print('x: %s'%x)

你可以直接放下。
pv(x)

当有多个变量需要跟踪时,给变量贴上标签会很有帮助。我只是厌倦了一遍又一遍地写出它们。 pv函数通过使用traceback模块来查看调用pv函数本身的代码行。 (参见http://docs.python.org/library/traceback.html#module-traceback)该行代码以字符串形式存储在变量text中。 text.find()是对通常的字符串方法find()的调用。例如,如果

text='pv(x)'

那么

text.find('(') == 2               # The index of the '(' in string text
text[text.find('(')+1:-1] == 'x'  # Everything in between the parentheses

我假设 n ~ 3^N,且 n ~ 2 ** 20。
这个想法是在模 N 的情况下工作。这可以减少数组的大小。第二个想法(当 n 很大时非常重要)是使用 numpy ndarrays 的 'object' 类型,因为如果使用整数 dtype,可能会导致超过允许的最大整数大小。
#!/usr/bin/env python
import traceback
import numpy as np

def pv(var):
    (filename,line_number,function_name,text)=traceback.extract_stack()[-2]
    print('%s: %s'%(text[text.find('(')+1:-1],var))

您可以将n更改为2 ** 20,但是下面我展示小n的情况,以便输出更易于阅读。
n=100
N=int(np.exp(1./3*np.log(n)))
pv(N)
# N: 4

a=np.random.randint(N,size=n)
b=np.random.randint(N,size=n)
pv(a)
pv(b)
# a: [1 0 3 0 1 0 1 2 0 2 1 3 1 0 1 2 2 0 2 3 3 3 1 0 1 1 2 0 1 2 3 1 2 1 0 0 3
#  1 3 2 3 2 1 1 2 2 0 3 0 2 0 0 2 2 1 3 0 2 1 0 2 3 1 0 1 1 0 1 3 0 2 2 0 2
#  0 2 3 0 2 0 1 1 3 2 2 3 2 0 3 1 1 1 1 2 3 3 2 2 3 1]
# b: [1 3 2 1 1 2 1 1 1 3 0 3 0 2 2 3 2 0 1 3 1 0 0 3 3 2 1 1 2 0 1 2 0 3 3 1 0
#  3 3 3 1 1 3 3 3 1 1 0 2 1 0 0 3 0 2 1 0 2 2 0 0 0 1 1 3 1 1 1 2 1 1 3 2 3
#  3 1 2 1 0 0 2 3 1 0 2 1 1 1 1 3 3 0 2 2 3 2 0 1 3 1]

wa存储了a中0、1、2、3的数量, wb存储了b中0、1、2、3的数量。
wa=np.bincount(a)
wb=np.bincount(b)
pv(wa)
pv(wb)
# wa: [24 28 28 20]
# wb: [21 34 20 25]
result=np.zeros(N,dtype='object')

将以下文本翻译为中文:

将0看作一个代币或筹码。1、2、3同理。

将wa=[24 28 28 20]看作意味着有一个袋子,里面有24个0代币,28个1代币,28个2代币和20个3代币。

你有一个wa袋和一个wb袋。当你从每个袋子中取出一个代币时,你将它们“相加”并形成一个新的代币。你对答案进行“模运算”(模N)。

想象一下从wb袋中取出一个1代币,并将其与wa袋中的每个代币相加。

1-chip + 0-chip = 1-chip
1-chip + 1-chip = 2-chip
1-chip + 2-chip = 3-chip
1-chip + 3-chip = 4-chip = 0-chip  (we are mod'ing by N=4)

由于wb袋中有34个1芯片,当你将它们与wa=[24 28 28 20]袋中的所有芯片相加时,你会得到:
34*24 1-chips
34*28 2-chips
34*28 3-chips
34*20 0-chips

这只是由于34个1芯片造成的部分计数。您还需要处理wb袋中的其他类型芯片,但以下内容展示了使用的方法:
for i,count in enumerate(wb):
    partial_count=count*wa
    pv(partial_count)
    shifted_partial_count=np.roll(partial_count,i)
    pv(shifted_partial_count)
    result+=shifted_partial_count
# partial_count: [504 588 588 420]
# shifted_partial_count: [504 588 588 420]
# partial_count: [816 952 952 680]
# shifted_partial_count: [680 816 952 952]
# partial_count: [480 560 560 400]
# shifted_partial_count: [560 400 480 560]
# partial_count: [600 700 700 500]
# shifted_partial_count: [700 700 500 600]

pv(result)    
# result: [2444 2504 2520 2532]

这是最终结果:2444个0,2504个1,2520个2,2532个3。
# This is a test to make sure the result is correct.
# This uses a very memory intensive method.
# c is too huge when n is large.
if n>1000:
    print('n is too large to run the check')
else:
    c=(a[:]+b[:,np.newaxis])
    c=c.ravel()
    c=c%N
    result2=np.bincount(c)
    pv(result2)
    assert(all(r1==r2 for r1,r2 in zip(result,result2)))
# result2: [2444 2504 2520 2532]

请注意,c %= N 是可行的(并且可能使用两倍的内存)。 - Eric O. Lebigot
@EOL,是的,c %= N更好。然而,定义c=(a[:]+b[:,np.newaxis])意味着你已经输了这场战斗,因为这是一个巨大的2-d数组,形状为(n,n),而上面的解决方案仅使用了几个形状为(N)的1-d数组。 - unutbu
非常感谢您的回答,我喜欢这种方法。但是我认为这对我没有帮助,因为数组a(和b中的数字)都不同(我没有提到,我的错)。bincount(a)将仅包含1和0。N〜3 ^ n而不是n〜3 ^ N。N是a中的最大元素,n是a中的元素数量。然而,我对您定义的pv()函数有点好奇。(我无法运行它,因为text.find()未定义(猜测它在另一个模块中))。这个函数如何工作,它的优点是什么? - jonalm
亲爱的Ubuntu。我发现我的符号表示存在不一致。我真正想表达的是size(a)=2^n(而不是我在第一篇帖子中写的n),max(a)=3^n(=N),其中n尽可能高。a[:]+b[:,np.newaxis] %N可以实现n=14,但不能更高。我想要n〜20 => max(a)=3^20 < 2^32,因此我需要巧妙地处理数据。我试图说明的是,对于我需要的n和N,bincount(a)包含的元素比a多,因此我不认为您的方法对该问题有效。话虽如此,我喜欢pv()函数。 - jonalm

1

检查一下你的数学,你需要的空间太大了:

2^20*2^20 = 2^40 = 1 099 511 627 776

如果每个元素只有一个字节,那么这已经是一兆字节的内存。

加上一个或两个循环。这个问题不适合用最大化内存和最小化计算来解决。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接