Fix expert grad scaling problem with ZeRO optimizer #6546

Merged
merged 7 commits into microsoft:master from fix_expert_weight_grad_with_zero
Oct 23, 2024

Conversation

Contributor

@wyooyw wyooyw commented Sep 17, 2024

Fix [#6545]

Work:

  • expert gradient averaging: divide by dp_world_size instead of edp_world_size (see the sketch after this list)
  • unit test: verify that models with different dp/ep configurations produce the same expert gradients
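To make the scaling change concrete, here is a minimal sketch, not DeepSpeed's actual implementation; the names dense_grads, expert_grads, dp_group, and edp_group are hypothetical placeholders for the non-expert gradients, the expert gradients, the data-parallel group, and the expert-data-parallel group:

```python
import torch.distributed as dist

def average_gradients(dense_grads, expert_grads, dp_group, edp_group):
    dp_world_size = dist.get_world_size(group=dp_group)

    # Non-expert gradients: reduced over the full data-parallel group and
    # divided by dp_world_size, exactly as before this PR.
    for grad in dense_grads:
        dist.all_reduce(grad, group=dp_group)
        grad.div_(dp_world_size)

    # Expert gradients: still reduced over the expert-data-parallel group,
    # but now divided by dp_world_size instead of the old (incorrect)
    # dist.get_world_size(group=edp_group).
    for grad in expert_grads:
        dist.all_reduce(grad, group=edp_group)
        grad.div_(dp_world_size)
```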
Contributor Author

wyooyw commented Sep 17, 2024

@microsoft-github-policy-service agree

@wyooyw wyooyw changed the title Fix Expert Grad Scaling Problem With Zero Optimizer Sep 17, 2024
@tjruwase tjruwase requested review from tohtana and removed request for loadams September 17, 2024 16:21
Contributor

@tohtana tohtana left a comment


Thank you @wyooyw! This looks good to me.
@inkcherry Do you have any suggestion? You worked on a similar issue in #5259.

Contributor Author

wyooyw commented Sep 18, 2024

> @wyooyw It seems that you should also delete or comment out https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage_1_and_2.py#L1072 when you delete https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage_1_and_2.py#L1079

Thank you for your suggestion. This redundant line of code has been removed.

```diff
@@ -1115,8 +1114,7 @@ def average_tensor(self, tensor):
                 curr_size += numel
                 prev_id, prev_process_group = partition_id, process_group

-            if not self.ipg_bucket_has_moe_params:
-                tensor.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size))
+            tensor.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size))
```
Contributor


If only the expert gradients are incorrect, shouldn't we only need to change 'grad_reduc' from dividing by edp_world_size to dividing by dp_world_size? Why do we need to do the division on 'tensor', which may contain more data than just that gradient? I'm just a bit confused here.

Contributor Author


In my understanding, 'tensor' contains only gradients that are waiting to be all-reduced.

From the code, 'tensor' is either a buffer in 'self.ipg_buffer' or the gradient of 'self.extra_large_param_to_reduce'. So 'tensor' is composed of the data of one or more weight gradients, and the data pointer of 'grad_reduc' points to an address within 'tensor' (see the small sketch below).
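As a rough illustration of that layout (a hedged sketch only; the buffer size and offsets are made up), dividing the flat buffer in place also scales every gradient view that points into it:

```python
import torch

# Stand-in for a flat IPG buffer that packs several flattened weight
# gradients back to back.
tensor = torch.ones(10)

# Stand-in for 'grad_reduc': a view into the buffer, sharing its storage.
grad_reduc = tensor.narrow(0, 3, 4)

# Dividing the whole buffer therefore also divides the region that
# grad_reduc refers to; no separate per-parameter division is needed.
tensor.div_(4.0)
print(grad_reduc)  # tensor([0.2500, 0.2500, 0.2500, 0.2500])
```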

According to the comments in the code, the logic of the old version was:

  • average gradients at the parameter level if the IPG bucket has a MoE param, i.e. do the average on 'grad_reduc';
  • otherwise, average at the entire-buffer level at the end of the loop, i.e. do the average on 'tensor'.

The old code did this because it divided expert gradients by edp_size and non-expert gradients by dp_size, so it had to average at the parameter level whenever the bucket contained a MoE param. In this PR, all weight gradients are divided by dp_size, so we can do the averaging directly at the buffer level (a sketch follows at the end of this comment).

In addition, maybe I should also delete those old comments.
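For reference, a simplified, hedged sketch of the two control flows described above (the helper names grads_in_bucket and is_expert are made up; this is not the real average_tensor):

```python
import torch.distributed as dist

def average_old(tensor, grads_in_bucket, bucket_has_moe_params, dp_group, edp_group):
    # Old behavior: parameter-level averaging when the bucket holds MoE params,
    # so expert grads could be divided by edp_size and the rest by dp_size.
    if bucket_has_moe_params:
        for grad, is_expert in grads_in_bucket:
            group = edp_group if is_expert else dp_group
            grad.div_(dist.get_world_size(group=group))
    else:
        # Buckets without MoE params were already averaged at the buffer level.
        tensor.div_(dist.get_world_size(group=dp_group))

def average_new(tensor, dp_group):
    # After this PR: every gradient is divided by dp_size, so the whole flat
    # buffer can be averaged in one shot, whether or not it holds MoE params.
    tensor.div_(dist.get_world_size(group=dp_group))
```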

Contributor


Thank you for the clarification. I agree with deleting those old comments.

@wyooyw wyooyw force-pushed the fix_expert_weight_grad_with_zero branch from 6e1e90c to b1231c4 on September 18, 2024 07:01
@tohtana tohtana added this pull request to the merge queue Oct 14, 2024
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Oct 15, 2024
@tohtana tohtana added this pull request to the merge queue Oct 17, 2024
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Oct 18, 2024
@tohtana tohtana added this pull request to the merge queue Oct 18, 2024
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to no response for status checks Oct 18, 2024
@tohtana tohtana added this pull request to the merge queue Oct 23, 2024
Merged via the queue into microsoft:master with commit b647fb2 Oct 23, 2024
14 checks passed