-
Notifications
You must be signed in to change notification settings - Fork 4.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix expert grad scaling problem with ZeRO optimizer #6546
Fix expert grad scaling problem with ZeRO optimizer #6546
Conversation
@microsoft-github-policy-service agree |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you @wyooyw! This looks good to me.
@inkcherry Do you have any suggestion? You worked on a similar issue in #5259.
@wyooyw It seems that you should also delete or comment https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage_1_and_2.py#L1072 when you delete https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage_1_and_2.py#L1079 |
Thank you for your suggestion. This redundant line of code has been removed. |
@@ -1115,8 +1114,7 @@ def average_tensor(self, tensor): | |||
curr_size += numel | |||
prev_id, prev_process_group = partition_id, process_group | |||
|
|||
if not self.ipg_bucket_has_moe_params: | |||
tensor.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If only grad for expert is not correct, we only need to make 'grad_reduc' divide edp_world_size -> divide dp_world_size, why we need use 'tensor' for divide, it may contain more data not only gradient ? I just feel confused about here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In my understanding, there are only gradients waiting to do all-reduce in 'tensor'.
From the code, 'tensor' may be a buffer in 'self.ipg_buffer' or the gradient of 'self.extra_large_param_to_reduce' . So, 'tensor' is composed of data from one or more weight gradients, and the data pointer of 'grad_reduc' points to an address within 'tensor'.
According to the comments in the code, the logic of the old version code is:
- Averages gradients at parameter level if ipg has a moe param, i.e. do average on 'grad_reduc'
- Otherwise averaging is done at the entire buffer level at the end of the loop, i.e. do average on 'tensor'.
He did this because he wanted to divide the expert gradient by edp_size and the non-expert gradient by dp_size, so he must do the average at the parameter level when there is a moe param. But in our PR, we divide all weight gradients by dp_size, so we can directly do the average at the entire buffer level.
In addition, maybe I need also delete those old comments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for clarification, I agree with you for deleting those old comments.
6e1e90c
to
b1231c4
Compare
Fix [#6545]
work: