沿着给定轴线洗牌NumPy数组

30

给定以下 NumPy 数组,

> a = array([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5],[1, 2, 3, 4, 5]])

洗牌单行很简单,

> shuffle(a[0])
> a
array([[4, 2, 1, 3, 5],[1, 2, 3, 4, 5],[1, 2, 3, 4, 5]])

使用索引符号独立地打乱每行是否可能?或者必须迭代数组。我想到的是这样的内容:

> numpy.shuffle(a[:])
> a
array([[4, 2, 3, 5, 1],[3, 1, 4, 5, 2],[4, 2, 1, 3, 5]]) # Not the real output

虽然这显然行不通。

3个回答

30

使用 rand+argsort 技巧的向量化解决方案

我们可以沿指定的轴生成唯一索引,并使用 高级索引 将其索引到输入数组中。为了生成唯一索引,我们将使用随机浮点数生成 + 排序 技巧,从而提供给我们一个向量化解决方案。我们还将使用np.take_along_axis来涵盖通用的 n-dim 数组和通用的 axes。最终实现看起来应该像这样-

def shuffle_along_axis(a, axis):
    idx = np.random.rand(*a.shape).argsort(axis=axis)
    return np.take_along_axis(a,idx,axis=axis)

请注意,这种洗牌不会在原地进行,并返回一个打乱顺序的副本。

示例运行 -

In [33]: a
Out[33]: 
array([[18, 95, 45, 33],
       [40, 78, 31, 52],
       [75, 49, 42, 94]])

In [34]: shuffle_along_axis(a, axis=0)
Out[34]: 
array([[75, 78, 42, 94],
       [40, 49, 45, 52],
       [18, 95, 31, 33]])

In [35]: shuffle_along_axis(a, axis=1)
Out[35]: 
array([[45, 18, 33, 95],
       [31, 78, 52, 40],
       [42, 75, 94, 49]])

有趣的解决方案!然而,我进行了一个快速实验,发现它比下面的朴素解决方案慢得多(大约慢了1000倍),后者反复调用rng.shuffle。有人能确认这一点吗?为什么会这么慢? - Nils
@Nils,我不确定你所提到的朴素解决方案是否仍然存在,但是一个解释是rng.shuffle只进行原地洗牌(O(n)时间复杂度)。对于这个解决方案,您必须为唯一索引分配内存,使用argsort进行排序(O(nlogn)时间复杂度),然后还必须为结果分配新内存。因此,对于大型数组,朴素解决方案的扩展性更好。 - Naphat Amundsen

24

由于您正在独立地洗牌多个序列,因此需要多次调用numpy.random.shuffle()numpy.random.shuffle()适用于任何可变序列,并且实际上不是ufunc。 可能最短、最高效的代码来单独洗牌二维数组a的所有行可能是:

list(map(numpy.random.shuffle, a))

有些人更喜欢将其编写成列表推导式:

[numpy.random.shuffle(x) for x in a]

至少对于Python 3.5和NumPy 1.10.2,这不起作用,a保持不变。 - drevicko
@drevicko:你的数组有多少维?这个答案是为了打乱一个二维数组的所有行(我相信它也适用于你的Python和Numpy版本的组合)。 - Sven Marnach
1
啊哈!我明白发生了什么:在Python 3.5中,map是惰性的,会生成一个迭代器,并且只有在你遍历它时才进行映射。如果你这样做:for _ in map(...): pass 它就会工作。 - drevicko
1
@drevicko 那很有道理。最好将代码编写为 for x in a: numpy.random.shuffle(x) - Sven Marnach
我猜是这样的...当你迭代a时,你确实会得到一个视图,不是吗?还有一个混乱的一行代码:如果a不太大,可以使用list(map(...)),但是for循环开始变得更加有吸引力 ;) - drevicko
显示剩余3条评论

7
对于最近查看此问题的人,numpy 提供了 permuted 方法,可以在指定的轴上 独立地 洗牌数组。
从他们的文档中(使用 random.Generator
rng = np.random.default_rng()
x = np.arange(24).reshape(3, 8)
x
array([[ 0,  1,  2,  3,  4,  5,  6,  7],
       [ 8,  9, 10, 11, 12, 13, 14, 15],
       [16, 17, 18, 19, 20, 21, 22, 23]])

y = rng.permuted(x, axis=1)
y
array([[ 4,  3,  6,  7,  1,  2,  5,  0],  
       [15, 10, 14,  9, 12, 11,  8, 13],
       [17, 16, 20, 21, 18, 22, 23, 19]])

1
非常好的答案,正是我在寻找的 - 这是现在执行此操作的规范方式。 - Praveen

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接