[Feature Request] Support pytree (nested list) in optimizer build #18443

Open
refraction-ray opened this issue Jul 13, 2023 · 3 comments
Labels: stat:awaiting keras-eng (Awaiting response from Keras engineer), type:feature (The user is asking for a new feature.)

Comments

@refraction-ray

import jax
import keras

w = jax.numpy.ones([4, 1])
b1 = jax.numpy.ones([1])
b2 = jax.numpy.ones([1])

# Flat list of variables
opt = keras.optimizers.Adam(1e-2)
opt.build([w, b1, b2])  # ok

# Nested list (pytree) of variables
opt = keras.optimizers.Adam(1e-2)
opt.build([w, [b1, b2]])  # fails
# AttributeError: 'list' object has no attribute 'shape'

The latter case is very common in a functional programming style: model.variables is a list of tensors (like [b1, b2] above), and there may be other variables outside the model (like w above) that the user also wants to optimize together. Full pytree support in optimizer.build would make this much easier.

@refraction-ray refraction-ray changed the title Support pytree (nested list) in optimizer build Jul 13, 2023
@fchollet fchollet transferred this issue from keras-team/keras-core Sep 22, 2023
@sachinprasadhs sachinprasadhs added the type:feature The user is asking for a new feature. label Feb 16, 2024
@sachinprasadhs sachinprasadhs self-assigned this Feb 16, 2024
@sachinprasadhs
Collaborator

Both cases now fail with the error AttributeError: 'tensorflow.python.framework.ops.EagerTensor' object has no attribute '_unique_id'

Attached Gist here for reference

@sachinprasadhs sachinprasadhs added the keras-team-review-pending Pending review by a Keras team member. label Feb 16, 2024
@fchollet
Member

fchollet commented Feb 22, 2024

  1. Why not just call flatten on your structure before calling build()?
  2. If you pass a nested structure to build(), do you expect to also pass the same nested structure in stateless_apply(optimizer_variables, grads, trainable_variables)? (as optimizer_variables)
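
For readers following the two questions above, here is a minimal sketch of what "flatten before build(), then restore the structure afterwards" could look like. The helpers `flatten` and `unflatten` are hypothetical names written for this illustration and are not part of Keras; in JAX, `jax.tree_util.tree_flatten` and `jax.tree_util.tree_unflatten` play the same role. Strings stand in for the arrays so the sketch has no dependencies.

```python
# Hypothetical helpers, not part of Keras or JAX: flatten a nested
# list/tuple of leaves and rebuild the same nesting afterwards.

def flatten(tree):
    """Return (leaves, spec), where spec records the nesting of `tree`."""
    if isinstance(tree, (list, tuple)):
        leaves, child_specs = [], []
        for child in tree:
            child_leaves, child_spec = flatten(child)
            leaves.extend(child_leaves)
            child_specs.append(child_spec)
        return leaves, (type(tree), child_specs)
    # Anything that is not a list or tuple is treated as a leaf.
    return [tree], None


def unflatten(spec, leaves):
    """Rebuild the nested structure described by `spec` from flat `leaves`."""
    it = iter(leaves)

    def build(node):
        if node is None:
            return next(it)
        container, child_specs = node
        return container(build(child) for child in child_specs)

    return build(spec)


# Stand-ins for w, b1, b2 from the report.
w, b1, b2 = "w", "b1", "b2"
nested = [w, [b1, b2]]

flat, spec = flatten(nested)      # flat list, the form the report marks as ok
restored = unflatten(spec, flat)  # same nesting as the original input
```

With this approach the optimizer only ever sees `flat`, and the caller keeps `spec` around to map flat results (for example, updated variables) back onto the original nesting. Native pytree support would move that bookkeeping inside the optimizer.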
@hertschuh hertschuh removed the keras-team-review-pending Pending review by a Keras team member. label Feb 22, 2024
@refraction-ray
Author

@fchollet

  1. That question applies to every pytree-compatible API. It is simply more elegant and easier to use when an API accepts pytree structures directly, as most Keras APIs do.
  2. Yes.
@sachinprasadhs sachinprasadhs added the stat:awaiting keras-eng Awaiting response from Keras engineer label Feb 26, 2024
4 participants