first version of flash_attention for jax #19743
Conversation
Thanks for the PR! Have you tried to time it on GPU compared to regular attention? I was under the impression that we were going to need a custom Pallas kernel for this.
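For context, a timing comparison along those lines could be sketched as follows; the shapes, sizes, and the naive baseline are illustrative assumptions, not code from the PR:

```python
import time

import jax
import jax.numpy as jnp


def naive_attention(q, k, v):
    # Materializes the full (seq, seq) score matrix; this is the baseline a
    # flash attention implementation is meant to beat on memory and runtime.
    scores = q @ k.swapaxes(-1, -2) / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ v


# Illustrative shapes: (batch, heads, seq_len, head_dim).
keys = jax.random.split(jax.random.PRNGKey(0), 3)
q, k, v = (jax.random.normal(kk, (4, 8, 2048, 64)) for kk in keys)

fn = jax.jit(naive_attention)
fn(q, k, v).block_until_ready()  # warm-up: compile before timing
start = time.perf_counter()
fn(q, k, v).block_until_ready()
print(f"naive attention: {time.perf_counter() - start:.4f} s")
```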
I used the /keras/src/layers/attention/ directory as a template for implementing flash attention, but I don't understand how the mask is generated in the benchmark. I need one, but I don't see it.
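For what it's worth, a benchmark mask is often just a boolean causal mask; a hypothetical sketch (not the actual Keras benchmark code):

```python
import jax.numpy as jnp

# Hypothetical causal mask, as a benchmark might construct one.
# True means "attend", False means "masked out".
seq_len = 2048
causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))  # (T, T)
# Broadcast to (batch, heads, T, T) when combining with attention scores.
mask = causal_mask[None, None, :, :]
```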
Hi @fchollet, can you please review this PR? Thank you!
@@ -76,6 +82,44 @@ def relu6(x):
        return Relu6().symbolic_call(x)
    return backend.nn.relu6(x)


@keras_export(["keras.ops.flash_attention", "keras.ops.nn.flash_attention"])
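For reference, a hypothetical sketch of how the exported op might mirror the `relu6` pattern shown above; the `FlashAttention` operation class, the signature, and the `backend.nn.flash_attention` call are assumptions, not necessarily what this PR defines:

```python
# Hypothetical sketch only, assumed to live in keras/src/ops/nn.py where
# keras_export, any_symbolic_tensors, and backend are already imported.
@keras_export(["keras.ops.flash_attention", "keras.ops.nn.flash_attention"])
def flash_attention(query, key, value, mask=None):
    # Symbolic (functional-model) path vs. eager backend path, as with relu6.
    if any_symbolic_tensors((query, key, value)):
        return FlashAttention().symbolic_call(query, key, value, mask=mask)
    return backend.nn.flash_attention(query, key, value, mask=mask)
```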
Hello there,
That's a very wonderful addition. I have a bit of a doubt here:
I suppose `ops.flash_attention` is an operation, while `nn.flash_attention` is just a neural network layer. The basic difference between the two is that an operation may not have any trainable parameters, while a neural network layer should have trainable parameters. Am I right so far?
If yes, please provide separate examples of each of them in the docs!
Best Regards,
Abhas Kumar Sinha
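For illustration, a minimal sketch of the op/layer distinction described in this comment. Note that both exported names in the diff above point at the same stateless op: `keras.ops.nn` is an alias namespace for neural-network ops, not the layers API.

```python
import numpy as np
import keras

x = np.random.rand(2, 4).astype("float32")

# keras.ops.* functions are stateless operations: pure computation, no weights.
y = keras.ops.relu(x)

# keras.layers.* classes can carry trainable weights (here: kernel and bias).
dense = keras.layers.Dense(8)
z = dense(x)
print(len(dense.trainable_weights))  # 2
```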
Returns:
    A tensor with the same shape as `x`.


Example:
Enclose the Example with "```":

"""
This is an example docstring.

Example:

```python
example = example()
```
"""

This helps automated doc renderers find the code examples in docstrings and render those parts accordingly.
An alternative contribution has been merged. Thanks for the PR in any case!
This is my first version of the flash attention implementation. It is just for JAX.
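For readers unfamiliar with the technique, here is a minimal, self-contained JAX sketch of the blocked online-softmax idea behind flash attention. It is a single-head reference for illustration only, not the PR's actual implementation.

```python
import jax
import jax.numpy as jnp


def flash_attention_reference(q, k, v, block_size=128):
    """Single-head blocked attention with an online softmax.

    q, k, v: arrays of shape (seq_len, head_dim); seq_len must be divisible
    by block_size. The full (seq_len, seq_len) score matrix is never
    materialized, which is the core memory saving of flash attention.
    """
    seq_len, head_dim = q.shape
    scale = 1.0 / jnp.sqrt(head_dim)
    k_blocks = k.reshape(-1, block_size, head_dim)
    v_blocks = v.reshape(-1, block_size, head_dim)

    def scan_block(carry, kv_block):
        acc, row_max, row_sum = carry
        k_blk, v_blk = kv_block
        scores = (q @ k_blk.T) * scale                    # (seq_len, block_size)
        new_max = jnp.maximum(row_max, scores.max(axis=-1))
        correction = jnp.exp(row_max - new_max)           # rescale old statistics
        p = jnp.exp(scores - new_max[:, None])
        acc = acc * correction[:, None] + p @ v_blk
        row_sum = row_sum * correction + p.sum(axis=-1)
        return (acc, new_max, row_sum), None

    init = (
        jnp.zeros_like(q),                # running weighted sum of values
        jnp.full((seq_len,), -jnp.inf),   # running row-wise max of scores
        jnp.zeros((seq_len,)),            # running softmax denominator
    )
    (acc, _, row_sum), _ = jax.lax.scan(scan_block, init, (k_blocks, v_blocks))
    return acc / row_sum[:, None]


# Quick numerical check against plain softmax attention.
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (512, 64))
k = jax.random.normal(kk, (512, 64))
v = jax.random.normal(kv, (512, 64))
out = flash_attention_reference(q, k, v)
ref = jax.nn.softmax(q @ k.T / jnp.sqrt(64.0), axis=-1) @ v
print(jnp.allclose(out, ref, atol=1e-4))  # expected: True
```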