从另一个数组中包含的索引修改数组。

3
我有一个形如 (2,10) 的数组,例如:
arr = jnp.ones(shape=(2,10)) * 2

或者
[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]

还有一个数组,例如[2,4]

我希望第二个数组告诉我们应该从哪个索引开始掩盖arr中的元素。在这个例子中,结果将是:

[[2. 2. -1. -1. -1. -1. -1. -1. -1. -1.]
 [2. 2. 2. 2.  -1. -1. -1. -1. -1. -1.]]

我需要使用jax.numpy,并希望答案向量化且尽可能快,即不使用循环。

1
使用numpy非常简单:for i in range(2): arr[i, idx[i]:] = -1。你是在寻找一些个人魔法吗? - hpaulj
谢谢你的回答,但我正在寻找使用jax.numpy进行矢量化且快速的解决方案。 - Valentin Macé
我和其他人可能都希望看到更明确的问题,一个展示你正在做什么的问题 - 即使它像我建议的那样迭代。不要让神奇的“向量化”目标未被说明! - hpaulj
我同意。我修改了我的问题,以便未来的读者参考。 - Valentin Macé
在一些快速测试中,循环相对于掩码where的速度取决于数组形状(n,m)。当m相对较大时,使用n个切片更快。 - hpaulj
1个回答

3

您可以使用vmapped的三项 jnp.where 语句来完成此操作。例如:

import jax.numpy as jnp
import jax

arr = jnp.ones(shape=(2,10)) * 2
idx = jnp.array([2, 4])

@jax.vmap
def f(row, ind):
  return jnp.where(jnp.arange(len(row)) < ind, row, -1)

f(arr, idx)
# DeviceArray([[ 2.,  2., -1., -1., -1., -1., -1., -1., -1., -1.],
#              [ 2.,  2.,  2.,  2., -1., -1., -1., -1., -1., -1.]], dtype=float32)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接