如何从一个结构化的NumPy数组中删除一列?

17

假设你有一个结构化的NumPy数组,从CSV文件中生成,其中第一行是字段名。数组的形式如下:

dtype([('A', '<f8'), ('B', '<f8'), ('C', '<f8'), ..., ('n','<f8'])

现在,假设你想从这个数组中移除第'i'列。有没有方便的方法可以这样做?

我希望它能像删除一样工作:

new_array = np.delete(old_array, 'i')

有任何想法吗?

3个回答

21

这不完全是一个函数调用,但以下是一种删除第i个字段的方法:

In [67]: a
Out[67]: 
array([(1.0, 2.0, 3.0), (4.0, 5.0, 6.0)], 
      dtype=[('A', '<f8'), ('B', '<f8'), ('C', '<f8')])

In [68]: i = 1   # Drop the 'B' field

In [69]: names = list(a.dtype.names)

In [70]: names
Out[70]: ['A', 'B', 'C']

In [71]: new_names = names[:i] + names[i+1:]

In [72]: new_names
Out[72]: ['A', 'C']

In [73]: b = a[new_names]

In [74]: b
Out[74]: 
array([(1.0, 3.0), (4.0, 6.0)], 
      dtype=[('A', '<f8'), ('C', '<f8')])

作为一个函数包装:

def remove_field_num(a, i):
    names = list(a.dtype.names)
    new_names = names[:i] + names[i+1:]
    b = a[new_names]
    return b

也许更自然的做法是删除给定的字段 名称

def remove_field_name(a, name):
    names = list(a.dtype.names)
    if name in names:
        names.remove(name)
    b = a[names]
    return b

另外,请查看属于matplotlib的mlab模块中的drop_rec_fields函数


< p > 更新:请参阅我的答案如何在不复制的情况下从结构化numpy数组中删除列?,以了解一种创建结构化数组子集视图的方法,而无需复制整个数组。


7

通过谷歌搜索并从Warren的回答中了解了我需要知道的内容后,我忍不住发布了一个更加简洁的版本,并添加了一次性高效地删除多个字段的选项:

def rmfield( a, *fieldnames_to_remove ):
    return a[ [ name for name in a.dtype.names if name not in fieldnames_to_remove ] ]

示例:

a = rmfield(a, 'foo')
a = rmfield(a, 'foo', 'bar')  # remove multiple fields at once

如果我们真的要用高尔夫球来比喻,下面的内容是等效的:
rmfield=lambda a,*f:a[[n for n in a.dtype.names if n not in f]]

1
如果我可以这么说,你的第二个解决方案相当丑陋。特别是我不喜欢你使用lambda表达式来实现函数声明的效果。这不是一个好的风格,也很难阅读。其他人似乎也同意我的看法:https://dev59.com/hnVC5IYBdhLWcg3w9GLM#134638 - Konstantin Schubert
3
也许你没有看到这句话:“如果我们真的要击打高尔夫球……”。“代码高尔夫”的目标是创建最短的代码,而不考虑可读性,这几乎总是很丑的。 - jez
1
我不知道那个短语。我仍然看不出重点,但在那种情况下,也许我的回应有点严厉。 - Konstantin Schubert
1
看看吧,它在codegolf.stackexchange.com上有自己的堆栈 :-) - jez

0
最简单的解决方案是使用内置函数。
让我们有一个 `points_array = np.array`。 这个 `np.array` 有多列,其中之一是 "classes"。
import numpy.lib.recfunctions as recfc

points_array = recfc.drop_fields(points_array, "classes", usemask=False)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接