返回一个元素的Numpy数组的简洁方法

3

有没有一种简洁的方法编写函数,使其返回一个元素的numpy数组作为该元素本身?

假设我想将一个简单的平方函数向量化,并且希望我的返回值与输入的数据类型相同。我可以编写类似于以下内容的代码:

def foo(x):
    result = np.square(x)
    if len(result) == 1:
        return result[0]
    return result

或者

def foo(x):
    if len(x) == 1:
        return x**2
    return np.square(x)

有没有更简便的方法来实现这个功能?这样我就可以同时用于标量和数组了?

我知道可以直接检查输入的数据类型并使用 IF 语句使其工作,但有没有更简洁的方式呢?


这是一个针对StackOverflow的问题,但是这里有一个提示:请澄清您所期望的结果与默认的NumPy行为有何不同,例如np.square(2.)np.square([2.])np.square([1,2,3])。此外,请尝试展示相同输入的代码示例输出。 - mjul
1
这个对你有用吗 - (x**2).squeeze() - Divakar
2个回答

3

我不确定我是否完全理解了这个问题,但也许这样做可以帮到你?

def square(x):
    if 'numpy' in str(type(x)):
        return np.square(x)
    else:
        if isinstance(x, list):
            return list(np.square(x))
        if isinstance(x, int):
            return int(np.square(x))
        if isinstance(x, float):
            return float(np.square(x))

我定义了一些测试用例:

np_array_one = np.array([3.4])
np_array_mult = np.array([3.4, 2, 6])
int_ = 5
list_int = [2, 4, 2.9]
float_ = float(5.3)
list_float = [float(4.5), float(9.1), float(7.5)]

examples = [np_array_one, np_array_mult, int_, list_int, float_, list_float]

所以我们可以看到函数的行为。
for case in examples:
    print 'Input type: {}.'.format(type(case))
    out = square(case)
    print out
    print 'Output type: {}'.format(type(out))
    print '-----------------'

输出结果:

Input type: <type 'numpy.ndarray'>.
[ 11.56]
Output type: <type 'numpy.ndarray'>
-----------------
Input type: <type 'numpy.ndarray'>.
[ 11.56   4.    36.  ]
Output type: <type 'numpy.ndarray'>
-----------------
Input type: <type 'int'>.
25
Output type: <type 'int'>
-----------------
Input type: <type 'list'>.
[4.0, 16.0, 8.4100000000000001]
Output type: <type 'list'>
-----------------
Input type: <type 'float'>.
28.09
Output type: <type 'float'>
-----------------
Input type: <type 'list'>.
[20.25, 82.809999999999988, 56.25]
Output type: <type 'list'>
-----------------

在测试用例中,输入和输出总是相同的。但是,这个函数并不是真正的干净。

我使用了一些来自这个问题的代码。


{btsdaf} - Valentin Calomme
那么或许直接在SO上提问会更好。 - HonzaB

1
我认为你需要一个非常好的理由才能想要那个。 (你能解释一下为什么需要这个吗?)
所有该函数的客户端都必须检查结果是数组还是单个元素,或者将其转换为数组。通常,即使只有一个元素,如果迭代数组的所有元素,您也会得到非常优雅的代码。
除非它总是必须是单个元素(这是一个转换函数),否则返回语句应在空/长数组上抛出异常。
除此之外,您拥有的代码完全可以理解/可读。任何“改进”它的聪明技巧都会成为未来您或同事每次阅读它时的心理负担。
-编辑-
我明白你的意思。可能你已经遇到了len(1)不允许的问题(int / float没有len()),因此您可以对输入参数进行类型检查。例如:
if (type(x) == list) ...

{btsdaf} - Valentin Calomme

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接