如何快速检查numpy数组的所有元素是否为浮点数?

5
我需要编写一个函数F,它接受一个dtype为object的numpy数组,并返回该数组中的所有元素是浮点数、整数或字符串的情况。例如:
F(np.array([1., 2.], dtype=object))  --> float
F(np.array(['1.', '2.'], dtype=object))  --> string
F(np.array([1, 2], dtype=object))  --> int
F(np.array([1, 2.], dtype=object))  --> float
F(np.array(['hello'], dtype=object))  --> string

F(np.array([1, 'hello'], dtype=object))  --> ERROR

有什么好的想法可以有效地完成这个任务吗?(==使用numpy内置函数)非常感谢。

1
你想让 F(np.array([1, 2.], dtype=object)) 因为混合使用整数和浮点数而抛出错误吗? - bcollins
返回 F(np.array([1, 2.], dtype=object)) 的浮点数是可以的。 - Olexiy
2个回答

4

最简单的方法可能是通过np.array将内容传递,并检查结果类型:

a = np.array([1., 2.], dtype=object)
b = np.array(['1.', '2.'], dtype=object)
c = np.array([1, 2], dtype=object)
d = np.array([1, 2.], dtype=object)
e = np.array(['hello'], dtype=object)
f = np.array([1, 'hello'], dtype=object)

>>> np.array(list(a)).dtype
dtype('float64')
>>> np.array(list(b)).dtype
dtype('S2')
>>> np.array(list(c)).dtype
dtype('int32')
>>> np.array(list(d)).dtype
dtype('float64')
>>> np.array(list(e)).dtype
dtype('S5')

如果类型不兼容,它不会引发错误,因为这不是numpy的行为:

>>> np.array(list(f)).dtype
dtype('S5')

非常感谢!引发一个错误实际上很容易 - 只需执行 np.all(new_array == old_array)。 - Olexiy
FYI list(array_to_check) 会复制 array_to_check - Vladimir Shteyn

1

不确定这种方法在对象管理方面是否最有效,但是怎么样:

def F(a):
    unique_types = set([type(i) for i in list(a)])
    if len(unique_types) > 1:
        raise ValueError('data types not consistent')
    else:
        return unique_types.pop()

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接