Fix dynamo issue #6527

Merged
merged 16 commits into from Oct 25, 2024
Conversation

@oraluben (Contributor) commented Sep 12, 2024

Dynamo uses FakeTensor to trace tensor ops. In some cases, this mechanism breaks compilation with DeepSpeed.

An example can be found at https://gist.github.com/oraluben/9b8240c2fe482eb4382453d6c97a5f76; to reproduce the issues, install deepspeed==0.14.4 instead of my fork.

Without this PR, Llama cannot be compiled.

Detailed explanation:

  1. ZeROOrderedDict
    Dynamo uses deepcopy to copy tensors, which goes through object.__reduce__. When copying a ZeROOrderedDict, the default implementation does not carry over its _parent_module, which leads to a failure (see the sketch after this list).
  2. A param may be a FakeTensor that does not yet have ds_status; during tracing it is fine to simply skip register_external_parameter, since that should already have been done well before.
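
For item 1, the fix amounts to rebuilding the mapping with its parent reference passed back through the constructor args. Below is a minimal, standalone sketch of the idea; the class and attribute names mirror ZeROOrderedDict in deepspeed/runtime/zero/parameter_offload.py, but this is a simplified illustration rather than the exact code in the PR:

from collections import OrderedDict
import copy


class ZeROOrderedDict(OrderedDict):
    """Sketch: an OrderedDict that remembers the module owning its parameters."""

    def __init__(self, parent_module, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._parent_module = parent_module

    def __reduce__(self):
        # OrderedDict.__reduce__ reconstructs the object with an empty args
        # tuple, so the deepcopy done by dynamo rebuilds the dict without its
        # parent module. Inject _parent_module back into the constructor args.
        r0, _, *rest = super().__reduce__()
        return (r0, (self._parent_module,), *rest)


# A string stands in for the real parent nn.Module here.
d = ZeROOrderedDict("parent_module_placeholder", {"weight": 1})
assert copy.deepcopy(d)._parent_module == "parent_module_placeholder"
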
@oraluben oraluben marked this pull request as ready for review September 12, 2024 06:17
@oraluben oraluben changed the title from Fix dynamo issue in llama to Fix dynamo issue Sep 12, 2024
@tohtana (Contributor) left a comment

@oraluben Thank you for the great investigation! I think this is a clean and simple solution to the issue.

@oraluben (Contributor, Author) commented Sep 13, 2024

torch.compiler.is_compiling() should be better for this case; however, there is still an issue, presumably on the dynamo side (since we have a FakeTensor, we are definitely tracing). So keep it as is for now.

[rank1]:   File "/home/yyc/accelerate-demo/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1720, in __getattr__
[rank1]:     return _parameters[name]
[rank1]:   File "/home/yyc/repo/DeepSpeed/deepspeed/runtime/zero/parameter_offload.py", line 67, in __getitem__
[rank1]:     if not is_compiling() and param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
[rank1]: torch._dynamo.exc.TorchRuntimeError: Failed running call_module L__self___self_attn_q_proj(*(FakeTensor(..., device='cuda:1', size=(1, s0, 4096), dtype=torch.float16,
[rank1]:            grad_fn=<MulBackward0>),), **{}):
[rank1]: 'FakeTensor' object has no attribute 'ds_status'

My patch in deepspeed.runtime:

diff --git a/deepspeed/runtime/compiler.py b/deepspeed/runtime/compiler.py
index 879c0a1a..3994c1f5 100644
--- a/deepspeed/runtime/compiler.py
+++ b/deepspeed/runtime/compiler.py
@@ -10,6 +10,15 @@ def is_compile_supported():
     return hasattr(torch, "compiler") and hasattr(torch.nn.Module, "compile")
 
 
+def is_compiling():
+    if not is_compile_supported():
+        return False
+    elif hasattr(torch.compiler, 'is_compiling'):  # torch >= 2.3
+        return torch.compiler.is_compiling()
+    else:
+        return torch._dynamo.is_compiling()
+
+
 def disable(func):
     if is_compile_supported():
         return torch.compiler.disable(func)
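
For context, the guarded lookup that appears in the traceback above then looks roughly like this. This is a hypothetical, simplified sketch of ZeROOrderedDict.__getitem__ (the import paths for is_compiling and ZeroParamStatus are assumptions), not the exact DeepSpeed code:

from collections import OrderedDict

from deepspeed.runtime.compiler import is_compiling  # helper added by the patch above
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus


class ZeROOrderedDict(OrderedDict):
    def __getitem__(self, key):
        param = super().__getitem__(key)
        # Under dynamo tracing, `param` can be a FakeTensor that has no
        # ds_status attribute, so skip the ZeRO-3 fetch logic while compiling.
        if not is_compiling() and param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            pass  # here the real code fetches/registers the partitioned parameter
        return param
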
@loadams (Contributor) commented Oct 23, 2024

@oraluben - sorry this PR has taken so long to merge; I think it just needed master merged in again to pick up the XPU fixes.

@loadams loadams added this pull request to the merge queue Oct 23, 2024
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Oct 24, 2024
@tohtana tohtana added this pull request to the merge queue Oct 25, 2024
Merged via the queue into microsoft:master with commit 3d5cf73 Oct 25, 2024
13 checks passed
@oraluben oraluben deleted the fix-compile-deepcopy branch October 25, 2024 03:39