Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make dyanmo execute async #4425

Merged
merged 4 commits into from
Jan 11, 2023
Merged

Make dyanmo execute async #4425

merged 4 commits into from
Jan 11, 2023

Conversation

JackCaoG
Copy link
Collaborator

@JackCaoG JackCaoG commented Jan 7, 2023

This is to implement #4402,

I verified test/dynamo/test_dynamo.py worked. Next I will try to rerun the inference and training benchmark to verified this does not regress inference and hopefully make training faster.

Update:
for training

command I used

XLA_IR_DEBUG=0 XLA_HLO_DEBUG=0 USE_FAKE_TENSOR=0 python benchmarks/dynamo/torchbench.py --randomize-input --performance --trace-on-xla --training --backend=aot_torchxla_trace_once --only $model -n 10

on TPUv4 nightly with/without this change

model with this change without this change speed up
resnet50 0.758x 0.724x 1.047
resnet18 0.665x 0.653x 1.018
BERT_pytorch 1.441x 1.400x 1.029
resnext50_32x4d 0.870x 0.849x 1.025
alexnet 0.632x 0.662x 0.955
mobilenet_v2 0.549x 0.539x 1.019
mnasnet1_0 0.698x 0.669x 1.043
vgg16 0.712x 0.721x 0.988
average 1.0155

At least for the single step training, overlapping the execution and training does not help the speed too much.

FYI @wconstab @shunting314 @alanwaketan @wonjoolee95

@JackCaoG JackCaoG added the dynamo label Jan 7, 2023
@JackCaoG JackCaoG force-pushed the JackCaoG/Dyanmo_execute_async branch from ea9054d to 1bc9d72 Compare January 7, 2023 00:52
@JackCaoG JackCaoG changed the title Jack cao g/dyanmo execute async Jan 7, 2023
@JackCaoG
Copy link
Collaborator Author

JackCaoG commented Jan 7, 2023

OK With training I am seeing some slight improvement, with master

--- a/third_party/xla_client/pjrt_computation_client.cc
+++ b/third_party/xla_client/pjrt_computation_client.cc
@@ -334,7 +334,12 @@ PjRtComputationClient::ExecuteComputation(
 
   std::vector<DataPtr> datas;
   datas.reserve(results.size());
+  bool waited = false;
   for (auto& result : results) {
+    if (!waited) {
+      auto status = result->GetReadyFuture().Await();
+      waited = true;
+    }
     std::unique_ptr<xla::PjRtBuffer> buffer = std::move(result);
 
     std::shared_ptr<PjRtData> data = std::make_shared<PjRtData>(

(diff is needed to make PJRT benchmark accurate)

resnet50 training output

cpu  train resnet50                           0.709x SAME

with this change

cpu  train resnet50                           0.754x SAME

so roughly a 6% improvement. I will test it on more models and post the update here.

torch_xla/csrc/xla_graph_executor.cpp Outdated Show resolved Hide resolved
torch_xla/csrc/xla_graph_executor.cpp Show resolved Hide resolved
torch_xla/csrc/xla_graph_executor.h Show resolved Hide resolved
@@ -209,12 +209,21 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
const std::vector<size_t>* indices,
DebugUtil::GraphFormat format = DebugUtil::GetDefaultGraphFormat());

void SaveOutputShapes(torch::lazy::hash_t hash,
std::vector<xla::Shape> outptu_shapes);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

outptu_shapes => output_shapes, and const std::vector<xla::Shape>& perhaps?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

outptu_shapes will be saved in a global map and outlive the stack object from the caller, I think it has to be a copy.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is also std::move so we are not making an extra copy.

Copy link
Collaborator

@alanwaketan alanwaketan Jan 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess my hobby has always been explicitly making the copy in the code instead of relying on value passing. That's why I always ask. Otherwise, it's too much thinking on deciding whether a parameter should be passed by value or reference. Just easier to determine if it's const reference or r reference. And then it becomes a ownership management problem.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sg, will update.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, actually in this case we want to keep it as std::vector<xla::Shape>. This way caller can use std::move to avoid the extra copy we need to do if we only pass in a reference.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haha, it's tricky. I'm actually fine with either ways. Just trying to share some of my thoughts here.

torch_xla/csrc/xla_graph_executor.cpp Outdated Show resolved Hide resolved
@JackCaoG
Copy link
Collaborator Author

command I used

XLA_IR_DEBUG=0 XLA_HLO_DEBUG=0 USE_FAKE_TENSOR=0 python benchmarks/dynamo/torchbench.py --randomize-input --performance --trace-on-xla --training --backend=aot_torchxla_trace_once --only $model -n 10

on TPUv4 nightly with/without this change

model with this change without this change speed up
resnet50 0.758x 0.724x 1.047
resnet18 0.665x 0.653x 1.018
BERT_pytorch 1.441x 1.400x 1.029
resnext50_32x4d 0.870x 0.849x 1.025
alexnet 0.632x 0.662x 0.955
mobilenet_v2 0.549x 0.539x 1.019
mnasnet1_0 0.698x 0.669x 1.043
vgg16 0.712x 0.721x 0.988
average 1.0155

At least for the single step training, overlapping the execution and training does not help the speed too much.

@JackCaoG
Copy link
Collaborator Author

I will clean up this pr and consider making this execution async configurable.

Copy link
Collaborator

@alanwaketan alanwaketan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@shunting314 shunting314 self-requested a review January 10, 2023 19:13
@JackCaoG JackCaoG merged commit 53e5d1d into master Jan 11, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
3 participants