Implement missing XLASymNodeImpl::Sub #4551
Conversation
LGTM.
fbfed72 to 908d4f6 (Compare)
torch_xla/csrc/ops/dynamic_ir.cpp (Outdated)
@@ -96,6 +96,37 @@ XlaOpVector SizeAdd::Lower(LoweringContext* loctx) const {
   return ReturnOp((input1 + input2), loctx);
 }

 SizeSub::SizeSub(torch::lazy::Value a, torch::lazy::Value b)
     : XlaNode(torch::lazy::OpKind{c10::Symbol::fromQualString(
           "aten::sub")},  // TODO: should it be something like
I feel the OpKind here could be more specific, e.g. "aten::size_sub", rather than a general one like "aten::sub". wdyt? @miladm @JackCaoG @alanwaketan
c10::Symbol is something defined upstream. If you can find the corresponding enum upstream, I guess that's better. Yeah, we don't store "aten::sub" as a string; it's an enum.
I guess the enums are defined in pytorch/aten/src/ATen/core/interned_strings.h? size_sub is not one of the enums, but I could still compile locally and run the test.
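For what it's worth, here is a minimal standalone sketch (not part of this PR's diff; the include path and exact behavior are assumptions based on the discussion above) of why this compiles and runs even without a predefined enum entry: fromQualString interns any unseen qualified string at runtime.

#include <ATen/core/interned_strings.h>
#include <iostream>

int main() {
  // Predefined aten symbol: also available as the named constant at::aten::sub.
  c10::Symbol sub = c10::Symbol::fromQualString("aten::sub");
  std::cout << (sub == at::aten::sub) << std::endl;  // prints 1

  // Custom name: not listed in interned_strings.h, but interned on first use,
  // so constructing the Symbol still succeeds at runtime.
  c10::Symbol size_sub = c10::Symbol::fromQualString("aten::size_sub");
  std::cout << size_sub.toQualString() << std::endl;  // prints aten::size_sub
  return 0;
}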
torch_xla/csrc/ops/dynamic_ir.cpp
@@ -96,6 +96,36 @@ XlaOpVector SizeAdd::Lower(LoweringContext* loctx) const {
   return ReturnOp((input1 + input2), loctx);
 }

 SizeSub::SizeSub(torch::lazy::Value a, torch::lazy::Value b)
     : XlaNode(
           torch::lazy::OpKind{c10::Symbol::fromQualString("aten::size_sub")},
Well, I am ok with either way as long as it runs, but if you want to update this for SizeSub, you need to update this for all nodes, including SizeAdd above.
From the code, it looks like c10::Symbol would just assign any unseen string an incremental ID. I guess the problem I'm having is that there is no straightforward way to do a comparison later on. For any aten op, you can simply do:
op == at::aten::sub
For any custom op, I guess you have to do:
op == c10::Symbol::fromQualString("aten::size_sub")
unless we define an enum-like structure on our end. So what benefit do you think we get from aten::size_sub?
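To make that trade-off concrete, a rough sketch (hypothetical helper names, assuming the torch::lazy::Node/OpKind API) of what later checks would look like under each naming choice:

#include <torch/csrc/lazy/core/ir.h>
#include <ATen/core/interned_strings.h>

// Built-in aten op: the symbol is a named constant, so the check is terse.
bool IsTensorSub(const torch::lazy::Node& node) {
  return node.op() == torch::lazy::OpKind(at::aten::sub);
}

// Custom op name: there is no named constant, so the string has to be
// interned (here cached in a static) before it can be compared.
bool IsSizeSub(const torch::lazy::Node& node) {
  static const torch::lazy::OpKind kSizeSub(
      c10::Symbol::fromQualString("aten::size_sub"));
  return node.op() == kSizeSub;
}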
It seems semantically cleaner to me: at::aten::sub feels like we are doing subtraction between two tensors, but here it's between two shape dimensions (e.g. tensor.shape[0]). wdyt?
From the code perspective, yes. But I wonder if this helps with anything practical, for instance the LazyIR dump or the HLO dump.
2685418 to af7f1d8 (Compare)
* implementing xlasymnodeimpl::sub
* the code builds fine and the test passed.
* fix linter
* change opkind to be more specific
* check counter in the test.
* fix linter
* unify the op kind
* fix linter
* fix a failing test.
This is to unblock the failing dynamic shape tests when functionalization is enabled.