在NumPy中如何对对象数组进行排序?

3

如何在Numpy中高效地按两个或多个属性对对象数组进行排序?

class Obj():
    def __init__(self,a,b):
        self.a = a
        self.b = b

arr = np.array([],dtype=Obj)        

for i in range(10):
    arr = np.append(arr,Obj(i, 10-i))

arr_sort = np.sort(arr, order=a,b) ???

谢谢,Willem-Jan


NumPy是否支持一个数据类型的类:np.array([],dtype=Obj) - J. P. Petersen
我会使用列表而不是对象数组。列表的附加速度更快。而且列表排序允许排序关键参数。 - hpaulj
也许你正在寻找结构化数组。但它们不能直接与Python类一起使用。 - user7138814
1
dtype=Obj 被视为通用的 object 数据类型。这种数组的元素可以是任何东西,包括 None - hpaulj
1个回答

1

order参数仅适用于结构化数组:

In [383]: arr=np.zeros((10,),dtype='i,i')
In [385]: for i in range(10):
     ...:     arr[i] = (i,10-i)  
In [386]: arr
Out[386]: 
array([(0, 10), (1, 9), (2, 8), (3, 7), (4, 6), (5, 5), (6, 4), (7, 3), (8, 2), (9, 1)], 
      dtype=[('f0', '<i4'), ('f1', '<i4')])
In [387]: np.sort(arr, order=['f0','f1'])
Out[387]: 
array([(0, 10), (1, 9), (2, 8), (3, 7), (4, 6), (5, 5), (6, 4), (7, 3), (8, 2), (9, 1)], 
      dtype=[('f0', '<i4'), ('f1', '<i4')])
In [388]: np.sort(arr, order=['f1','f0'])
Out[388]: 
array([(9, 1), (8, 2), (7, 3), (6, 4), (5, 5), (4, 6), (3, 7), (2, 8),
       (1, 9), (0, 10)], 
      dtype=[('f0', '<i4'), ('f1', '<i4')])

使用二维数组,lexsort提供类似的“有序”排序。

In [402]: arr=np.column_stack((np.arange(10),10-np.arange(10)))
In [403]: np.lexsort((arr[:,1],arr[:,0]))
Out[403]: array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)
In [404]: np.lexsort((arr[:,0],arr[:,1]))
Out[404]: array([9, 8, 7, 6, 5, 4, 3, 2, 1, 0], dtype=int32)

有了您的对象数组,我可以将属性提取到以下任一结构中:

In [407]: np.array([(a.a, a.b) for a in arr])
Out[407]: 
array([[ 0, 10],
       [ 1,  9],
       [ 2,  8],
      ....
       [ 7,  3],
       [ 8,  2],
       [ 9,  1]])
In [408]: np.array([(a.a, a.b) for a in arr],dtype='i,i')
Out[408]: 
array([(0, 10), (1, 9), (2, 8), (3, 7), (4, 6), (5, 5), (6, 4), (7, 3),
       (8, 2), (9, 1)], 
      dtype=[('f0', '<i4'), ('f1', '<i4')])

Python的sorted函数可用于arr(或其列表等价物)。
In [421]: arr
Out[421]: 
array([<__main__.Obj object at 0xb0f2d24c>,
       <__main__.Obj object at 0xb0f2dc0c>,
       ....
       <__main__.Obj object at 0xb0f35ecc>], dtype=object)
In [422]: sorted(arr, key=lambda a: (a.b,a.a))
Out[422]: 
[<__main__.Obj at 0xb0f35ecc>,
 <__main__.Obj at 0xb0f3570c>,
 ...
 <__main__.Obj at 0xb0f2dc0c>,
 <__main__.Obj at 0xb0f2d24c>]

你的Obj类缺少一个好的__str__方法。我必须使用像[(i.a, i.b) for i in arr]这样的东西来查看arr元素的值。
正如我在评论中所述,对于这个例子,列表比对象数组更好。
In [423]: alist=[]
In [424]: for i in range(10):
     ...:     alist.append(Obj(i,10-i))

列表append比重复的数组追加要快。而且与列表相比,对象数组在1D时不会增加太多功能,特别是当对象是自定义类时。你无法对arr进行任何数学运算,正如你所看到的,排序也不容易。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接