能否对使用jax.numpy.unique的函数进行JIT编译?

3
以下代码无法正常工作:
def get_unique(arr):
    return jnp.unique(arr)

get_unique = jit(get_unique)
get_unique(jnp.ones((10,)))

错误信息抱怨使用了jnp.unique
FilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(float32[10])>with<DynamicJaxprTrace(level=0/1)>
The error arose in jnp.unique()
有关尖锐位的文档解释了如果内部数组的形状取决于参数值,则jit无法工作,这恰好是此处的情况。
根据文档,一个潜在的解决方法是指定静态参数。但这不适用于我的情况。参数几乎在每个函数调用时都会改变。我将代码分成了预处理步骤和计算步骤,预处理步骤执行像jnp.unique这样的计算,而计算步骤可以进行jit编译。
但我仍然想问一下,是否有我不知道的一些解决方法?
1个回答

1

目前,由于您提到的原因,无法使用jnp.unique处理非静态值。

在类似情况下,JAX有时会添加额外的参数,用于指定输出的静态大小(例如,在jax.numpy.nonzero中的size参数),但目前没有类似jnp.unique的实现。如果您需要此功能,请提交特性请求


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接