Make dynamo execute async #4425
Conversation
Force-pushed from ea9054d to 1bc9d72.
OK. With training I am seeing some slight improvement. Comparing the resnet50 training output on master (the diff is needed to make the PJRT benchmark accurate) against the output with this change, I see roughly a 6% improvement. I will test it on more models and post the update here.
@@ -209,12 +209,21 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
      const std::vector<size_t>* indices,
      DebugUtil::GraphFormat format = DebugUtil::GetDefaultGraphFormat());

  void SaveOutputShapes(torch::lazy::hash_t hash,
                        std::vector<xla::Shape> outptu_shapes);
outptu_shapes => output_shapes, and const std::vector<xla::Shape>& perhaps?
outptu_shapes will be saved in a global map and will outlive the caller's stack object, so I think it has to be a copy.
It is also std::move'd, so we are not making an extra copy.
I guess my habit has always been to make the copy explicitly in the code instead of relying on pass-by-value. That's why I always ask. Otherwise, it takes too much thought to decide whether a parameter should be passed by value or by reference; it's easier to just decide between a const reference and an rvalue reference. And then it becomes an ownership-management problem.
sg, will update.
Ah, actually in this case we want to keep it as std::vector<xla::Shape>. This way the caller can use std::move to avoid the extra copy we would need if we only passed in a reference.
Haha, it's tricky. I'm actually fine with either way. Just trying to share some of my thoughts here.
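To make the trade-off being discussed concrete, here is a minimal sketch of the by-value + std::move pattern. This is not the actual XLAGraphExecutor code; the Shape type, hash type, and global map below are simplified stand-ins.

```cpp
#include <cstdint>
#include <unordered_map>
#include <utility>
#include <vector>

struct Shape {};               // stand-in for xla::Shape
using hash_t = std::uint64_t;  // stand-in for torch::lazy::hash_t

// Stand-in for the global map that outlives the caller's stack object.
std::unordered_map<hash_t, std::vector<Shape>> output_shape_map;

// Taking the vector by value lets the caller decide: pass an lvalue to get a
// copy, or std::move an expiring vector so only moves happen end to end.
void SaveOutputShapes(hash_t hash, std::vector<Shape> output_shapes) {
  output_shape_map[hash] = std::move(output_shapes);
}

int main() {
  std::vector<Shape> shapes(3);
  // Caller moves; no element-wise copy is made on the way into the map.
  SaveOutputShapes(/*hash=*/42, std::move(shapes));
  return 0;
}
```

With a const reference parameter instead, the function body would have to copy the vector into the map; with pass-by-value, a moving caller avoids that copy while a copying caller still works.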
I will clean up this PR and consider making this async execution configurable.
LGTM.
This is to implement #4402. I verified that test/dynamo/test_dynamo.py works. Next I will rerun the inference and training benchmarks to verify this does not regress inference and hopefully makes training faster.

Update:
For training, I ran my benchmark command on TPU v4 nightly with and without this change. At least for single-step training, overlapping the execution and training does not help the speed too much.
FYI @wconstab @shunting314 @alanwaketan @wonjoolee95
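For readers unfamiliar with the idea behind the PR, here is a minimal sketch of what "making execution async" means in general: dispatch the device computation on a background thread and return immediately, so the caller can overlap the next step's work with device execution. This is an assumption-laden illustration, not the actual pytorch/xla implementation; names like RunComputation and AsyncExecute are hypothetical placeholders.

```cpp
#include <future>
#include <iostream>
#include <utility>
#include <vector>

// Stand-in for running a compiled computation on the device.
std::vector<float> RunComputation(const std::vector<float>& inputs) {
  std::vector<float> outputs(inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i) outputs[i] = inputs[i] * 2.0f;
  return outputs;
}

// Returns a future right away instead of blocking until execution finishes.
std::future<std::vector<float>> AsyncExecute(std::vector<float> inputs) {
  return std::async(std::launch::async,
                    [inputs = std::move(inputs)] { return RunComputation(inputs); });
}

int main() {
  auto result = AsyncExecute({1.0f, 2.0f, 3.0f});
  // ... the caller can prepare/trace the next step here while the device runs ...
  std::vector<float> outputs = result.get();  // block only when the data is needed
  std::cout << outputs[0] << "\n";
  return 0;
}
```

The benchmark note above suggests that for single-step training this kind of overlap yields only a modest win, since there is little host-side work left to hide behind the device execution.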