Tensorflow version: 2.17.0
Keras version: 3.5.0

Basically, my question is: is it possible to sub-class a Model instance that contains different model components (like a GAN), write a custom train_step, and then perform distributed training in Tensorflow? I'm trying to combine these two tutorials:

WGAN-GP: https://keras.io/examples/generative/wgan_gp/
Multi-GPU in Keras/Tensorflow: https://keras.io/examples/generative/wgan_gp/
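For concreteness, here is a minimal, self-contained sketch of the general pattern I'm asking about: a subclassed Model with a custom train_step, built and compiled under MirroredStrategy. It is a toy regression model standing in for the WGAN, and all names in it are illustrative rather than taken from my gist.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

strategy = tf.distribute.MirroredStrategy()
print("Number of devices:", strategy.num_replicas_in_sync)

class ToyModel(keras.Model):
    """Stand-in for the WGAN: a subclassed Model with its own train_step."""

    def __init__(self):
        super().__init__()
        self.dense = keras.layers.Dense(1)

    def call(self, inputs):
        return self.dense(inputs)

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = tf.reduce_mean(tf.square(y - y_pred))
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {"loss": loss}

with strategy.scope():
    model = ToyModel()
    model.compile(optimizer=keras.optimizers.Adam())

# Dummy data just to exercise fit() through the distributed code path.
x = np.random.rand(256, 8).astype("float32")
y = np.random.rand(256, 1).astype("float32")
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(32)
model.fit(dataset, epochs=1)
```

The question is whether this pattern still holds once the Model wraps two sub-models (generator and discriminator) and a more involved train_step, as in the WGAN-GP tutorial.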
I should point out that on my local machine, I'm logically splitting a single RTX Ada A6000 (48 GB VRAM) into two logical GPUs purely for testing. I can run the vanilla multi-GPU test just fine this way.
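The split itself is done the standard way with logical devices, roughly like this (the memory limits are illustrative, not the exact values from my script):

```python
import tensorflow as tf

# Split the single physical GPU into two logical GPUs. This has to happen
# before TensorFlow initializes the GPU; memory_limit is in MB.
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    tf.config.set_logical_device_configuration(
        gpus[0],
        [
            tf.config.LogicalDeviceConfiguration(memory_limit=20480),
            tf.config.LogicalDeviceConfiguration(memory_limit=20480),
        ],
    )
print("Logical GPUs:", len(tf.config.list_logical_devices("GPU")))
```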
Here is a gist that documents progress so far. Running wgan.fit() out of the box raised an error saying I needed to implement a call() method so the model could be built. All right: I added one to the WGAN class, made sure it exercises both the generator and the discriminator, and removed the metrics to keep things simple.
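Roughly, the call() I added looks like the sketch below (written from memory, not copied from the gist; it assumes the generator, discriminator, and latent_dim attributes that the WGAN class gets in the WGAN-GP tutorial):

```python
    def call(self, inputs):
        # Run a real batch through the discriminator and a noise batch through
        # the generator, so that both sub-models get built.
        batch_size = tf.shape(inputs)[0]
        noise = tf.random.normal(shape=(batch_size, self.latent_dim))
        fake_images = self.generator(noise, training=False)
        real_scores = self.discriminator(inputs, training=False)
        fake_scores = self.discriminator(fake_images, training=False)
        return real_scores, fake_scores
```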
My new error looks something like this:
Number of devices: 2
Epoch 1/20
2024-08-26 10:18:33.210283: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:966] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape inStatefulPartitionedCall/cond/else/_208/cond/StatefulPartitionedCall/discriminator_3/dropout_1/stateless_dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer
2024-08-26 10:18:37.952827: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: INVALID_ARGUMENT: You must feed a value for placeholder tensor 'StatefulPartitionedCall/cond/else/_208/cond/Placeholder_5' with dtype int32
[[{{function_node cond_false_19469}}{{node cond/Placeholder_5}}]]
2024-08-26 10:18:37.952916: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: INVALID_ARGUMENT: You must feed a value for placeholder tensor 'StatefulPartitionedCall/cond/else/_208/cond/Placeholder_5' with dtype int32
[[{{function_node cond_false_19469}}{{node cond/Placeholder_5}}]]
[[StatefulPartitionedCall/cond/then/_207/cond/Placeholder_1/_1114]]
...
Traceback (most recent call last):
File "/home/dryglicki/code/testing/wgan-gp-parallel_test/from_website_wgan-gp_multi.py", line 381, in <module>
wgan.fit(train_dataset, epochs=epochs)
File "/ssd0/miniforge3_2024-04/envs/tensorflow_2d17_py3d11/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/ssd0/miniforge3_2024-04/envs/tensorflow_2d17_py3d11/lib/python3.11/site-packages/tensorflow/python/eager/execute.py", line 53, in quick_execute
tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:
Detected at node cond/Placeholder_5 defined at (most recent call last):
File "/home/dryglicki/code/testing/wgan-gp-parallel_test/from_website_wgan-gp_multi.py", line 381, in <module>
File "/ssd0/miniforge3_2024-04/envs/tensorflow_2d17_py3d11/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "/ssd0/miniforge3_2024-04/envs/tensorflow_2d17_py3d11/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py", line 320, in fit
File "/ssd0/miniforge3_2024-04/envs/tensorflow_2d17_py3d11/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py", line 121, in one_step_on_iterator
Detected at node cond/Placeholder_5 defined at (most recent call last):
File "/home/dryglicki/code/testing/wgan-gp-parallel_test/from_website_wgan-gp_multi.py", line 381, in <module>
File "/ssd0/miniforge3_2024-04/envs/tensorflow_2d17_py3d11/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "/ssd0/miniforge3_2024-04/envs/tensorflow_2d17_py3d11/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py", line 320, in fit
File "/ssd0/miniforge3_2024-04/envs/tensorflow_2d17_py3d11/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py", line 121, in one_step_on_iterator
2 root error(s) found.
(0) INVALID_ARGUMENT: You must feed a value for placeholder tensor 'StatefulPartitionedCall/cond/else/_208/cond/Placeholder_5' with dtype int32
[[{{node cond/Placeholder_5}}]]
[[StatefulPartitionedCall/cond/then/_207/cond/Placeholder_1/_1114]]
(1) INVALID_ARGUMENT: You must feed a value for placeholder tensor 'StatefulPartitionedCall/cond/else/_208/cond/Placeholder_5' with dtype int32
[[{{node cond/Placeholder_5}}]]
0 successful operations.
0 derived errors ignored. [Op:__inference_one_step_on_iterator_36547]
Let's ignore the loss reduction and other programmatic/thematic issues for a moment. If you do test this, you are also going to run into that annoying OUT_OF_RANGE issue with tf.data. Let's ignore that, too.
Can I even do this?