我有一个形如 (2,10) 的数组,例如:
或者
我需要使用
arr = jnp.ones(shape=(2,10)) * 2
或者
[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]
还有一个数组,例如[2,4]
。
我希望第二个数组告诉我们应该从哪个索引开始掩盖arr
中的元素。在这个例子中,结果将是:
[[2. 2. -1. -1. -1. -1. -1. -1. -1. -1.]
[2. 2. 2. 2. -1. -1. -1. -1. -1. -1.]]
我需要使用
jax.numpy
,并希望答案向量化且尽可能快,即不使用循环。
numpy
非常简单:for i in range(2): arr[i, idx[i]:] = -1
。你是在寻找一些个人魔法吗? - hpauljwhere
的速度取决于数组形状(n,m)。当m
相对较大时,使用n
个切片更快。 - hpaulj