Jax Error with Stochastic Depth #18404

Open
anas-rz opened this issue Sep 16, 2023 · 2 comments

@anas-rz
Contributor

anas-rz commented Sep 16, 2023

While converting a model from TensorFlow to Keras Core, I am hitting an error with the JAX backend at the end of the epoch. The error is raised by the StochasticDepth layer used in the network. The code works with the PyTorch and TensorFlow backends, but JAX reports a tracer leak. I understand the concept of pure functions in JAX, but I still can't figure out whether the problem lies in my implementation or in Keras Core. The StochasticDepth layer works perfectly in the CCT example.
Here's the stack trace if that can help:

Epoch 1/5
48/49 ━━━━━━━━━━━━━━━━━━━━ 0s 967ms/step - accuracy: 0.1990 - loss: 6.7065
---------------------------------------------------------------------------
UnexpectedTracerError                     Traceback (most recent call last)
[<ipython-input-14-61822f405d65>](https://localhost:8080/#) in <cell line: 12>()
     10         ],
     11     )
---> 12 history = model.fit(
     13     pipeline_train,
     14     batch_size=BATCH_SIZE,

4 frames
[/usr/local/lib/python3.10/dist-packages/keras_core/src/utils/traceback_utils.py](https://localhost:8080/#) in error_handler(*args, **kwargs)
    121             # To get the full stack trace, call:
    122             # `keras_core.config.disable_traceback_filtering()`
--> 123             raise e.with_traceback(filtered_tb) from None
    124         finally:
    125             del filtered_tb

    [... skipping hidden 20 frame]

[/content/./focalnet-keras-core/focalnet_keras_core/layers.py](https://localhost:8080/#) in call(self, x, training)
    154             keep_prob = 1 - self.drop_path_rate
    155             shape = (ops.shape(x)[0],) + (1,) * (len(ops.shape(x)) - 1)
--> 156             random_tensor = keep_prob + keras.random.uniform(shape, 0, 1)
    157             random_tensor = ops.floor(random_tensor)
    158             return (x / keep_prob) * random_tensor

[/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/array_methods.py](https://localhost:8080/#) in op(self, *args)
    721 def _forward_operator_to_aval(name):
    722   def op(self, *args):
--> 723     return getattr(self.aval, f"_{name}")(self, *args)
    724   return op
    725 

[/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/array_methods.py](https://localhost:8080/#) in deferring_binary_op(self, other)
    254     args = (other, self) if swap else (self, other)
    255     if isinstance(other, _accepted_binop_types):
--> 256       return binary_op(*args)
    257     if isinstance(other, _rejected_binop_types):
    258       raise TypeError(f"unsupported operand type(s) for {opchar}: "

    [... skipping hidden 5 frame]

[/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/partial_eval.py](https://localhost:8080/#) in _assert_live(self)
   1577   def _assert_live(self) -> None:
   1578     if not self._trace.main.jaxpr_stack:  # type: ignore
-> 1579       raise core.escaped_tracer_error(self, None)
   1580 
   1581   def get_referent(self):

UnexpectedTracerError: Exception encountered when calling StochasticDepth.call().

Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type uint32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was compiled_train_step at /usr/local/lib/python3.10/dist-packages/keras_core/src/backend/jax/trainer.py:203 traced for jit.
------------------------------
The leaked intermediate value was created on line /usr/local/lib/python3.10/dist-packages/keras_core/src/backend/jax/core.py:19 (_initialize). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/usr/local/lib/python3.10/dist-packages/keras_core/src/layers/layer.py:867 (stateless_call)
/usr/local/lib/python3.10/dist-packages/keras_core/src/backend/common/stateless_scope.py:66 (__exit__)
/usr/local/lib/python3.10/dist-packages/keras_core/src/backend/common/variables.py:370 (initialize_all_variables)
/usr/local/lib/python3.10/dist-packages/keras_core/src/backend/common/variables.py:87 (_deferred_initialize)
/usr/local/lib/python3.10/dist-packages/keras_core/src/backend/jax/core.py:19 (_initialize)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

Arguments received by StochasticDepth.call():
  • x=jnp.ndarray(shape=(48, 3136, 96), dtype=float32)
  • training=True
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
@fchollet
Member

fchollet commented Sep 17, 2023

So you've verified the StochasticDepth layer works fine in JAX otherwise?

Typically the "tracer leak" type of issue happens when you take a tensor from an intermediate computation and store it as an attribute of a permanent object, like a layer or model.
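
A minimal JAX sketch of that pattern, for illustration only. The `LeakyScaler` class and `pure_scale` function are hypothetical and are not taken from the reported code. The first stores an intermediate value on a long-lived object while being traced by `jax.jit`; the second returns the intermediate explicitly instead.

```python
import jax
import jax.numpy as jnp


class LeakyScaler:
    """Hypothetical stand-in for a layer that caches an intermediate value."""

    def __init__(self):
        self.last = None

    def __call__(self, x):
        y = x * 2.0
        # Side effect: while jit is tracing, `y` is a tracer, and this
        # attribute keeps a reference to it after tracing has finished.
        self.last = y
        return y


def pure_scale(x):
    """Same computation, but the intermediate is returned explicitly."""
    y = x * 2.0
    return y, y


leaky = LeakyScaler()
out = jax.jit(lambda x: leaky(x))(jnp.ones(3))
# The stored attribute is a leaked tracer, not a concrete array.
leaked_is_tracer = isinstance(leaky.last, jax.core.Tracer)

out_pure, cached = jax.jit(pure_scale)(jnp.ones(3))
# Values returned from the jitted function are ordinary arrays.
cached_is_tracer = isinstance(cached, jax.core.Tracer)

print(leaked_is_tracer, cached_is_tracer)
```

Using `leaky.last` in a later computation is the kind of access that JAX reports as an `UnexpectedTracerError`, whereas `cached` can be used freely.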

I think the line `random_tensor = keep_prob + keras.random.uniform(shape, 0, 1)` in the trace is interesting. What happens if you replace it with a constant, like `0.6`?
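
For comparison, here is a rough sketch of the same computation in plain JAX, with the random key passed in as an explicit argument so that no random state is created or stored while the function is traced. This is not the Keras implementation or the code from the report. The function name `stochastic_depth`, its arguments, and the test input are assumptions made for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("drop_path_rate", "training"))
def stochastic_depth(x, key, drop_path_rate=0.1, training=True):
    """Drop whole samples with probability `drop_path_rate` (illustrative)."""
    if not training or drop_path_rate == 0.0:
        return x
    keep_prob = 1.0 - drop_path_rate
    # One random draw per sample, broadcast over the remaining axes.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = jnp.floor(keep_prob + jax.random.uniform(key, shape))
    return (x / keep_prob) * random_tensor


key = jax.random.PRNGKey(0)
# Split the key so each call gets fresh randomness without hidden state.
key, subkey = jax.random.split(key)

x = jnp.ones((8, 4, 3))
y = stochastic_depth(x, subkey, drop_path_rate=0.5)
y_eval = stochastic_depth(x, subkey, drop_path_rate=0.5, training=False)
```

Each sample in `y` is either zeroed out or scaled by `1 / keep_prob`, and `y_eval` is the input unchanged. If this version runs cleanly under `jit`, that would point toward how the random state is handled rather than toward the arithmetic in the layer.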

@fchollet transferred this issue from keras-team/keras-core on Sep 22, 2023
@sachinprasadhs self-assigned this on Jan 18, 2024
@sachinprasadhs
Collaborator

@anas-rz, could you please check with the latest Keras 3 package and let us know if you're still facing the issue?

If you could provide a small reproducible code sample, it would make it easier for us to debug.
