Subclassing a model, writing a custom train_step, and distributed training in Tensorflow #20164

Open
dryglicki opened this issue Aug 26, 2024 · 1 comment
Labels: stat:awaiting keras-eng (Awaiting response from Keras engineer), type:Bug

@dryglicki (Contributor)

TensorFlow version: 2.17.0
Keras version: 3.5.0

Basically, my question is:

Is it possible to subclass Model to wrap several component models (as in a GAN), write a custom train_step, and then run distributed training in TensorFlow?

I'm trying to combine these two tutorials:
WGAN-GP: https://keras.io/examples/generative/wgan_gp/
Multi-GPU in Keras/TensorFlow: https://keras.io/guides/distributed_training_with_tensorflow/

I should point out that on my local machine I'm logically splitting a single RTX Ada A6000 (48 GB VRAM) into two devices, purely for testing. The vanilla multi-GPU example runs fine this way.
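Splitting one physical GPU into logical devices like this is done with `tf.config.set_logical_device_configuration`, which has to run before TensorFlow initializes its devices. A minimal sketch follows; the helper name `split_first_device` and the 20000 MB per-device limit are hypothetical, and the CPU fallback exists only so the sketch also runs on a machine without a GPU.

```python
import tensorflow as tf


def split_first_device(num_logical=2, memory_limit_mb=20000):
    """Hypothetical helper: split one physical device into `num_logical` logical ones.

    Uses the first GPU if there is one (memory_limit_mb is a hypothetical
    per-logical-GPU limit in MB), otherwise the CPU, so the sketch also runs
    without a GPU. Must be called before TensorFlow initializes its devices.
    """
    gpus = tf.config.list_physical_devices("GPU")
    if gpus:
        configs = [
            tf.config.LogicalDeviceConfiguration(memory_limit=memory_limit_mb)
            for _ in range(num_logical)
        ]
        tf.config.set_logical_device_configuration(gpus[0], configs)
        return tf.config.list_logical_devices("GPU")
    cpus = tf.config.list_physical_devices("CPU")
    configs = [tf.config.LogicalDeviceConfiguration() for _ in range(num_logical)]
    tf.config.set_logical_device_configuration(cpus[0], configs)
    return tf.config.list_logical_devices("CPU")


devices = split_first_device()
# One replica per logical device.
strategy = tf.distribute.MirroredStrategy([d.name for d in devices])
print("Number of devices:", strategy.num_replicas_in_sync)
```

With two logical devices, `MirroredStrategy` reports two replicas, which is presumably what the `Number of devices: 2` line in the log below reflects.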

Here is a gist that documents my progress so far. When I ran wgan.fit() out of the box, it said that I needed to implement a call() method so the model could be built. So I added call() to the WGAN class and made sure it runs both the generator and the discriminator. I also removed the metrics to keep things simple.
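One way to write such a `call()` is sketched below, under assumptions: the `GANWrapper` name and the layer sizes are hypothetical, and `call()` is written to accept a batch shaped like the training data (rather than latent noise), on the assumption that `fit()` builds the model by calling it on a batch of training data. It draws noise with a matching batch size and scores both real and generated samples, so both sub-models run.

```python
import tensorflow as tf
import keras

LATENT_DIM, DATA_DIM = 8, 4  # hypothetical sizes


class GANWrapper(keras.Model):
    """Hypothetical wrapper whose call() takes a batch shaped like the training data."""

    def __init__(self, latent_dim=LATENT_DIM, data_dim=DATA_DIM):
        super().__init__()
        self.latent_dim = latent_dim
        self.generator = keras.Sequential(
            [keras.Input(shape=(latent_dim,)), keras.layers.Dense(data_dim)],
            name="generator",
        )
        self.discriminator = keras.Sequential(
            [keras.Input(shape=(data_dim,)), keras.layers.Dense(1)],
            name="discriminator",
        )

    def call(self, real, training=False):
        # Draw noise with the same (possibly dynamic) batch size as the input...
        z = tf.random.normal((tf.shape(real)[0], self.latent_dim))
        fake = self.generator(z, training=training)
        # ...then score real and generated samples, so both sub-models run.
        real_scores = self.discriminator(real, training=training)
        fake_scores = self.discriminator(fake, training=training)
        return real_scores, fake_scores


model = GANWrapper()
real_scores, fake_scores = model(tf.zeros((3, DATA_DIM)))
print(real_scores.shape, fake_scores.shape, len(model.trainable_weights))
```

Here the wrapper owns four trainable weights (a kernel and a bias in each sub-model's Dense layer), all created before any training step runs.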

My new error looks something like this:

Number of devices: 2
Epoch 1/20
2024-08-26 10:18:33.210283: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:966] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape inStatefulPartitionedCall/cond/else/_208/cond/StatefulPartitionedCall/discriminator_3/dropout_1/stateless_dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer
2024-08-26 10:18:37.952827: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: INVALID_ARGUMENT: You must feed a value for placeholder tensor 'StatefulPartitionedCall/cond/else/_208/cond/Placeholder_5' with dtype int32
	 [[{{function_node cond_false_19469}}{{node cond/Placeholder_5}}]]
2024-08-26 10:18:37.952916: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: INVALID_ARGUMENT: You must feed a value for placeholder tensor 'StatefulPartitionedCall/cond/else/_208/cond/Placeholder_5' with dtype int32
	 [[{{function_node cond_false_19469}}{{node cond/Placeholder_5}}]]
	 [[StatefulPartitionedCall/cond/then/_207/cond/Placeholder_1/_1114]]

...

Traceback (most recent call last):
  File "/home/dryglicki/code/testing/wgan-gp-parallel_test/from_website_wgan-gp_multi.py", line 381, in <module>
    wgan.fit(train_dataset, epochs=epochs)
  File "/ssd0/miniforge3_2024-04/envs/tensorflow_2d17_py3d11/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/ssd0/miniforge3_2024-04/envs/tensorflow_2d17_py3d11/lib/python3.11/site-packages/tensorflow/python/eager/execute.py", line 53, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:

Detected at node cond/Placeholder_5 defined at (most recent call last):
  File "/home/dryglicki/code/testing/wgan-gp-parallel_test/from_website_wgan-gp_multi.py", line 381, in <module>

  File "/ssd0/miniforge3_2024-04/envs/tensorflow_2d17_py3d11/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler

  File "/ssd0/miniforge3_2024-04/envs/tensorflow_2d17_py3d11/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py", line 320, in fit

  File "/ssd0/miniforge3_2024-04/envs/tensorflow_2d17_py3d11/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py", line 121, in one_step_on_iterator

Detected at node cond/Placeholder_5 defined at (most recent call last):
  File "/home/dryglicki/code/testing/wgan-gp-parallel_test/from_website_wgan-gp_multi.py", line 381, in <module>

  File "/ssd0/miniforge3_2024-04/envs/tensorflow_2d17_py3d11/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler

  File "/ssd0/miniforge3_2024-04/envs/tensorflow_2d17_py3d11/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py", line 320, in fit

  File "/ssd0/miniforge3_2024-04/envs/tensorflow_2d17_py3d11/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py", line 121, in one_step_on_iterator

2 root error(s) found.
  (0) INVALID_ARGUMENT:  You must feed a value for placeholder tensor 'StatefulPartitionedCall/cond/else/_208/cond/Placeholder_5' with dtype int32
	 [[{{node cond/Placeholder_5}}]]
	 [[StatefulPartitionedCall/cond/then/_207/cond/Placeholder_1/_1114]]
  (1) INVALID_ARGUMENT:  You must feed a value for placeholder tensor 'StatefulPartitionedCall/cond/else/_208/cond/Placeholder_5' with dtype int32
	 [[{{node cond/Placeholder_5}}]]
0 successful operations.
0 derived errors ignored. [Op:__inference_one_step_on_iterator_36547]

Let's ignore the loss reduction and other programmatic/thematic issues for a moment. If you do test this, you are also going to run into that annoying OUT_OF_RANGE issue with tf.data. Let's ignore that, too.
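On the loss-reduction point: under `tf.distribute`, a common convention (the one TensorFlow's custom-training tutorials follow) is to sum each replica's per-example losses and divide by the global batch size, so that if gradients are summed across replicas, the result is the gradient of the mean over the whole global batch rather than roughly the number of replicas times it. `tf.nn.compute_average_loss` does this scaling. A minimal sketch; `GLOBAL_BATCH_SIZE` and the `replica_loss` helper are hypothetical.

```python
import tensorflow as tf

# Hypothetical setup: a per-replica batch of 4 on 2 replicas.
GLOBAL_BATCH_SIZE = 8


def replica_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE):
    """Hypothetical helper: scale one replica's loss by the *global* batch size."""
    # Sums the per-example losses and divides by global_batch_size, so that
    # summing the replicas' gradients gives the gradient of the global mean.
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=global_batch_size)


# Made-up per-example losses for one replica's shard of the batch.
shard = tf.constant([1.0, 2.0, 3.0, 4.0])
print("scaled by global batch:", float(replica_loss(shard)))
print("local mean:", float(tf.reduce_mean(shard)))
```

Here the replica contributes 10 / 8 = 1.25 instead of its local mean of 2.5.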

Can I even do this?

mehtamansi29 added the type:Bug and keras-team-review-pending (Pending review by a Keras team member) labels on Sep 2, 2024
divyashreepathihalli removed the keras-team-review-pending (Pending review by a Keras team member) label on Sep 5, 2024
sachinprasadhs added the stat:awaiting keras-eng (Awaiting response from Keras engineer) label on Oct 3, 2024