-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support to AG for Jax array single assignment #717
Conversation
559a59f
to
2c22446
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi Raul, we should also test setting a dynamic index :)
2fd85a8
to
3710821
Compare
Done |
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #717 +/- ##
=======================================
Coverage 98.08% 98.08%
=======================================
Files 69 69
Lines 9468 9474 +6
Branches 746 747 +1
=======================================
+ Hits 9287 9293 +6
Misses 147 147
Partials 34 34 ☔ View full report in Codecov by Sentry. |
3710821
to
d85a5b8
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me :) Just needs a changelog entry
I was thinking of adding such entry once all in place array assignment cases are covered. What do you think? |
You can still create one now, since this PR should be attached to the changelog entry anyways. |
|
d85a5b8
to
8d08ebf
Compare
Done! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good 👍
8d08ebf
to
93d8705
Compare
93d8705
to
cc840ee
Compare
…es (#1143) **Context:** #717 added support for converting in-place array updates (`arr[i] = x`) into the equivalent JAX traceable code (`arr.at[i].set(x)`). This PR extends that support to operator assignment array updates. **Description of the Change:** - Add new Autograph converter to map `AugAssign` ast nodes assigning to a single index or a slice subscript to calls to `update_item_with_op` - Implement `update_item_with_op` method that map to the corresponding `jax.numpy.ndarray.at` equivalent methods for JAX arrays and the normal Python operator assignment otherwise - Overload `transform_ast` in `CatalystTransformer` to invoke the new converter **Benefits:** We can use `arr[i] += x` instead of `arr.at[i].add(x)`. **Possible Drawbacks:** It would be cleaner to have the new converter live in the DiastaticMalt project. **Related GitHub Issues:** #757 **Based on the solution presented in this PR:** #769 Note that this PR was originally implemented externally by #769. This PR aims to revisit that PR. --------- Co-authored-by: Spencer Comin <scomin@me.com>
Context: We want to support NumPy-style in-place array updates, such as arr[i] = x
Description of the Change: Overload set_item function from Autograph and enable converter.Feature.LISTS option.
Benefits: We can use directly arr[i] = x instead of arr = arr.at[i].set(x)
Related GitHub Issues: #516
Based on the solution presented in this PR: #582
[sc-60313]