Jax在数组维度上对分段求和的处理

3

我对jax还比较陌生,遇到了以下问题:我需要在给定索引的情况下计算数组中的函数(求和/最小值/最大值等),为了解决这个问题,我找到了jnp.ops.segment_sum函数。这个函数对于一个数组非常有效,但是如何将这种方法推广到一批数组呢?例如:

import jax.numpy as jnp
indexes = jnp.array([[1,0,1],[0,0,1]])
batch_of_matrixes = jnp.array([
    np.arange(9).reshape((3,3)),
    np.arange(9).reshape((3, 3))
])
# The following works for one array but not multiple
jax.ops.segment_sum(
    data=batch_of_matrixes[0],
    segment_ids=indexes[0],
    num_segments=2)
# How can I get this to work with the full dataset along the 0 dimension?
# Intended Outcome:
[
    [
        [ 3  4  5],
        [ 6  8 10]
    ],
    [
        [3  5  7],
        [6  7  8]
   ]
]


如果有比obs.segment_*系列更通用的方法,请也告诉我。感谢您提供帮助和建议!
1个回答

3
JAX的vmap转换是专门为这种情况设计的。在您的情况下,您可以像这样使用它:
@jax.vmap
def f(data, index):
  return jax.ops.segment_sum(data, index, num_segments=2)

print(f(batch_of_matrixes, indexes))
# [[[ 3  4  5]
#   [ 6  8 10]]

#  [[ 3  5  7]
#   [ 6  7  8]]]

更多关于此的讨论,请参见JAX 101:自动向量化

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接