如何在numpy数组中允许浮点数和数组?

3

使用numpy和matplotlib,函数通常允许数字(浮点或整数)或numpy数组作为参数,如下:

import numpy as np

print np.sin(0)
# 0

x = np.arange(0,4,0.1)
y = np.sin(x)

在这个例子中,我使用一个整数参数和一个numpy数组 x 分别调用np.sin函数。现在我想编写一个类似的函数来实现相同的操作,但我不知道如何实现。例如:

def fun(foo, n):
    a = np.zeros(n)
    for i in range(n):
        a[i] = foo
    return a

我希望能够像这样调用函数:fun(1, 5),但不能像这样:fun(x, 5)。当然,我的实际计算要复杂得多。

如何初始化a,使得它既可以是简单数字,也可以是由数字组成的数组元素?

非常感谢您的帮助!


a 不能是列表的原因是什么? - C_Z_
@chepner,文档显示不是这样的。http://docs.scipy.org/doc/numpy-1.10.1/user/basics.rec.html 列必须是一个数据类型,但不是整个数组。 - user1121588
你想从 fun(x,5) 中获得什么?x 的 5 个副本?数组的形状是什么? - hpaulj
@CactusWoman:后来我依赖于a是一个numpy数组。我想我可以将a创建为列表,然后填充列表,最后写入a = np.array(a)。但是难道没有更“简洁”的解决方案吗? - Daniel
@hpaulj 我只是举了一个极其简化的函数作为例子。在我的实际问题中,我正在计算相当复杂的概率。 - Daniel
3个回答

2
内置的numpy函数通常以

开头。

 def foo(a, ...):
     a = np.asarray(a)
     ...

也就是说,它们将输入参数转换为数组(如果已经是数组,则不进行复制)。这使它们可以与标量和列表一起使用。
一旦参数是一个数组,它就有了形状,并且可以与其他参数进行广播。
在您的示例中,当foo是一个数组时,不清楚应该发生什么。
def fun(foo, n):
    a = np.zeros(n)
    for i in range(n):
        a[i] = foo
    return a

a 被初始化为一个浮点数数组。这意味着只有当 foo 是单个元素数字(标量,可能是单个元素数组)时,a[i]=foo 才有效。如果 foo 是具有多个值的数组,则可能会出现关于尝试使用序列设置元素的错误。

a[i] 等同于 a[i,...]。它在第一个维度上进行索引。因此,如果正确初始化了 a,则可以接受数组作为输入(遵循广播规则)。

如果将 a 初始化为 np.zeros(n, dtype=object),则 a[i]=foo 可以接受任何内容,因为 a 仅包含指向 Python 对象的指针。

np.frompyfunc 是从函数生成数组的一种方法。但它返回一个 dtype=object 的数组。 np.vectorize 使用它,但可以更好地控制输出类型。但是,两者都只适用于标量。如果给出数组作为参数,则按元素逐个传递到函数中。


0

你需要 a 来继承 foo 的尺寸:

def fun(foo, n):
    a = np.zeros((n,) + np.shape(foo))
    for i in range(n):
        a[i] = foo
    return a

我尝试过类似的东西,但这样我就不能再执行 fun(1, 5) 了:整数或浮点数没有“形状”。 - Daniel
@Daniel:这就是为什么我使用np.shape(foo)而不是foo.shape,它在标量上运行良好(返回())。 - Eric
哦,我错过了这个区别!事实上,我只尝试了foo.shape。我会尝试一下并回来告诉你! - Daniel

-1

您可以使用类型识别:

import numpy as np

def double(a):
    if type(a)==int:
        return 2*a
    elif type(a)==float:
        return 2.0*a
    elif type(a)==list:
        return [double(x) for x in a]
    elif type(a)==np.ndarray:
        return 2*a
    else:
        print "bad type"

print double(7)
print double(7.2)
print double([2,9,7])
print double(np.array([[9,8],[2,3]]))

结果:

>14
>14.4
>[4, 18, 14]
>[[18 16]
 [ 4  6]]

最终可能需要像我在列表上所做的那样进行递归处理


isinstance(var, type) 这样写更好。此外,你的整个函数可以重写为 np.asarray(a) * 2 - Eric
@Eric,如果使用类型的继承,则需要isinstance,这不是问题的目标。然后np.asarray将列表转换为数组,这与我的函数有所不同...顺便说一句,当然这种方法没有任何意义,只是为了提供一种解决方案来识别int和数组。 - Vince
继承是目标,因为使用type(...) = ...时会出现错误: a = np.zeros(2); double(a[0])。同样的错误也会在使用double(2L)时出现。 - Eric
啊,是的,你说得对,我没有想到继承会如此普遍。 - Vince

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接