While converting a model from TensorFlow to Keras Core, I am facing an error with the JAX backend at the end of the epoch. The error is thrown by the StochasticDepth layer used in the network. The code works with the PyTorch and TensorFlow backends but shows the leakage with JAX. I do understand the concept of pure functions in JAX, but I still can't figure out whether the problem is in my code or in Keras Core. The StochasticDepth layer works perfectly in the CCT example.
Here's the stack trace if that can help:
Epoch 1/5
48/49 ━━━━━━━━━━━━━━━━━━━━ 0s 967ms/step - accuracy: 0.1990 - loss: 6.7065
---------------------------------------------------------------------------
UnexpectedTracerError Traceback (most recent call last)
<ipython-input-14-61822f405d65> in <cell line: 12>()
10 ],
11 )
---> 12 history = model.fit(
13 pipeline_train,
14 batch_size=BATCH_SIZE,
4 frames
/usr/local/lib/python3.10/dist-packages/keras_core/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
121 # To get the full stack trace, call:
122 # `keras_core.config.disable_traceback_filtering()`
--> 123 raise e.with_traceback(filtered_tb) from None
124 finally:
125 del filtered_tb
[... skipping hidden 20 frame]
/content/./focalnet-keras-core/focalnet_keras_core/layers.py in call(self, x, training)
154 keep_prob = 1 - self.drop_path_rate
155 shape = (ops.shape(x)[0],) + (1,) * (len(ops.shape(x)) - 1)
--> 156 random_tensor = keep_prob + keras.random.uniform(shape, 0, 1)
157 random_tensor = ops.floor(random_tensor)
158 return (x / keep_prob) * random_tensor
/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/array_methods.py in op(self, *args)
721 def _forward_operator_to_aval(name):
722 def op(self, *args):
--> 723 return getattr(self.aval, f"_{name}")(self, *args)
724 return op
725
/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/array_methods.py in deferring_binary_op(self, other)
254 args = (other, self) if swap else (self, other)
255 if isinstance(other, _accepted_binop_types):
--> 256 return binary_op(*args)
257 if isinstance(other, _rejected_binop_types):
258 raise TypeError(f"unsupported operand type(s) for {opchar}: "
[... skipping hidden 5 frame]
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/partial_eval.py in _assert_live(self)
1577 def _assert_live(self) -> None:
1578 if not self._trace.main.jaxpr_stack: # type: ignore
-> 1579 raise core.escaped_tracer_error(self, None)
1580
1581 def get_referent(self):
UnexpectedTracerError: Exception encountered when calling StochasticDepth.call().
Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type uint32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was compiled_train_step at /usr/local/lib/python3.10/dist-packages/keras_core/src/backend/jax/trainer.py:203 traced for jit.
------------------------------
The leaked intermediate value was created on line /usr/local/lib/python3.10/dist-packages/keras_core/src/backend/jax/core.py:19 (_initialize).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/usr/local/lib/python3.10/dist-packages/keras_core/src/layers/layer.py:867 (stateless_call)
/usr/local/lib/python3.10/dist-packages/keras_core/src/backend/common/stateless_scope.py:66 (__exit__)
/usr/local/lib/python3.10/dist-packages/keras_core/src/backend/common/variables.py:370 (initialize_all_variables)
/usr/local/lib/python3.10/dist-packages/keras_core/src/backend/common/variables.py:87 (_deferred_initialize)
/usr/local/lib/python3.10/dist-packages/keras_core/src/backend/jax/core.py:19 (_initialize)
------------------------------
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
Arguments received by StochasticDepth.call():
• x=jnp.ndarray(shape=(48, 3136, 96), dtype=float32)
• training=True
So you've verified the StochasticDepth layer works fine in JAX otherwise?
Typically the "tracer leak" type of issue happens when you take a tensor from an intermediate computation and store it as an attribute of a permanent object, like a layer or model.
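For illustration, here is a minimal reproduction of that failure mode in plain JAX (the function name and the `leaked` list are made up for the example, standing in for an attribute on a layer or model): a jitted function stashes an intermediate tensor in outer state, and touching that value after tracing raises the same UnexpectedTracerError.

```python
import jax
import jax.numpy as jnp

leaked = []  # outer "global" state, like an attribute on a permanent object
caught = False

@jax.jit
def double(x):
    y = x * 2.0
    leaked.append(y)  # side effect: the tracer for y escapes the trace
    return y

double(jnp.ones(3))  # tracing succeeds; leaked[0] is now a stale tracer

try:
    _ = leaked[0] + 1.0  # using the escaped tracer outside the transform
except jax.errors.UnexpectedTracerError:
    caught = True

print("caught:", caught)
```

The error only fires when the escaped tracer is *used*, which is why it can surface far from the line that leaked it, e.g. at the end of an epoch.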
I think the line random_tensor = keep_prob + keras.random.uniform(shape, 0, 1) in the trace is interesting. What happens if you replace it with a constant, like 0.6?
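As an aside, the JAX-idiomatic way to keep drop-path pure is to pass the PRNG key in explicitly instead of drawing randomness from hidden state. A rough sketch in plain JAX (not the Keras Core layer itself; `drop_path` and its arguments are illustrative names):

```python
import jax
import jax.numpy as jnp
from functools import partial

# drop_rate and training control Python branches, so they must be static.
@partial(jax.jit, static_argnames=("drop_rate", "training"))
def drop_path(key, x, drop_rate, training=True):
    # Pure drop-path: randomness comes from the explicit key argument,
    # so nothing stateful can leak out of the jit trace.
    if not training or drop_rate == 0.0:
        return x
    keep_prob = 1.0 - drop_rate
    # One Bernoulli draw per sample, broadcast over the remaining axes.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + jax.random.uniform(key, shape)
    return (x / keep_prob) * jnp.floor(random_tensor)

key = jax.random.PRNGKey(0)
out = drop_path(key, jnp.ones((4, 3)), drop_rate=0.5)
```

With drop_rate = 0.5, each sample is either zeroed or rescaled to 1 / keep_prob = 2. On the Keras Core side, the analogous fix is to have the layer hold a keras.random.SeedGenerator and pass it as the seed to keras.random.uniform, so the RNG state is tracked as a layer variable rather than leaking a raw tracer.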