-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fixing gather batching rule not implemented
when multiple qjits call them
#1096
Conversation
…lling it out into patches.py as a global. This way, jax will not complain about "gather batching rule not find", since we patch with the same gather2_p object every qjit run.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, thanks @paul0403!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are just moving the standard_primitive to the global scope, right? That sounds good. Do you know if this also fixes these tests?
test/pytest/test_callback.py:@pytest.mark.parametrize("order", ["good", "bad"])
test/pytest/test_callback.py:def test_vjp_as_residual(arg, order):
test/pytest/test_callback.py: if order == "bad":
test/pytest/test_callback.py: if order == "bad":
test/pytest/test_callback.py:@pytest.mark.parametrize("order", ["good", "bad"])
test/pytest/test_callback.py:def test_vjp_as_residual_automatic(arg, order):
test/pytest/test_callback.py: if order == "bad":
test/pytest/test_callback.py: if order == "bad":
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #1096 +/- ##
=======================================
Coverage 97.65% 97.65%
=======================================
Files 76 76
Lines 10769 10769
Branches 1245 1245
=======================================
Hits 10517 10517
Misses 203 203
Partials 49 49 ☔ View full report in Codecov by Sentry. |
just checked, these are fixed as well |
@paul0403 if you want, you can remove the skips in this PR or in another one that follows. Happy with either. |
I'll just unskip them here. |
…l them (#1096) **Context:** There is a bug if multiple qjits call into `gather` primitive #1094 . With how we currently patch them, jax retains the first custom `gather2_p` object and keeps using that in the second run. **Description of the Change:** Making sure there is only ever one `gather2_p` primitive object by pulling it out into patches.py as a global. This way, jax will not complain about "gather batching rule not find", since we patch with the same gather2_p object every qjit run. **Benefits:** Multiple qjits in the same program can call into gather primitives. Although in the standard recommended usage there is just one qjit in the overall entry function, this unblocks potential pytests (where multiple qjits cannot be avoided). **Related GitHub Issues:** closes #1094 , closes #894 [sc-72775]
Context:
There is a bug if multiple qjits call into
gather
primitive #1094 . With how we currently patch them, jax retains the first customgather2_p
object and keeps using that in the second run.Description of the Change:
Making sure there is only ever one
gather2_p
primitive object by pulling it out into patches.py as a global.This way, jax will not complain about "gather batching rule not find", since we patch with the same gather2_p object every qjit run.
Benefits:
Multiple qjits in the same program can call into gather primitives. Although in the standard recommended usage there is just one qjit in the overall entry function, this unblocks potential pytests (where multiple qjits cannot be avoided).
Related GitHub Issues: closes #1094 , closes #894
[sc-72775]