我对jax还比较陌生,遇到了以下问题:我需要在给定索引的情况下计算数组中的函数(求和/最小值/最大值等),为了解决这个问题,我找到了jnp.ops.segment_sum函数。这个函数对于一个数组非常有效,但是如何将这种方法推广到一批数组呢?例如:
import jax.numpy as jnp
indexes = jnp.array([[1,0,1],[0,0,1]])
batch_of_matrixes = jnp.array([
np.arange(9).reshape((3,3)),
np.arange(9).reshape((3, 3))
])
# The following works for one array but not multiple
jax.ops.segment_sum(
data=batch_of_matrixes[0],
segment_ids=indexes[0],
num_segments=2)
# How can I get this to work with the full dataset along the 0 dimension?
# Intended Outcome:
[
[
[ 3 4 5],
[ 6 8 10]
],
[
[3 5 7],
[6 7 8]
]
]
如果有比obs.segment_*系列更通用的方法,请也告诉我。感谢您提供帮助和建议!