以下代码无法正常工作:
错误信息抱怨使用了
根据文档,一个潜在的解决方法是指定静态参数。但这不适用于我的情况。参数几乎在每个函数调用时都会改变。我将代码分成了预处理步骤和计算步骤,预处理步骤执行像
但我仍然想问一下,是否有我不知道的一些解决方法?
def get_unique(arr):
return jnp.unique(arr)
get_unique = jit(get_unique)
get_unique(jnp.ones((10,)))
错误信息抱怨使用了
jnp.unique
:FilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(float32[10])>with<DynamicJaxprTrace(level=0/1)>
The error arose in jnp.unique()
有关尖锐位的文档解释了如果内部数组的形状取决于参数值,则jit无法工作,这恰好是此处的情况。根据文档,一个潜在的解决方法是指定静态参数。但这不适用于我的情况。参数几乎在每个函数调用时都会改变。我将代码分成了预处理步骤和计算步骤,预处理步骤执行像
jnp.unique
这样的计算,而计算步骤可以进行jit编译。但我仍然想问一下,是否有我不知道的一些解决方法?