Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing gather batching rule not implemented when multiple qjits call them #1096

Merged
merged 5 commits into from
Sep 4, 2024

Conversation

paul0403
Copy link
Contributor

@paul0403 paul0403 commented Sep 4, 2024

Context:
There is a bug if multiple qjits call into gather primitive #1094 . With how we currently patch them, jax retains the first custom gather2_p object and keeps using that in the second run.

Description of the Change:
Making sure there is only ever one gather2_p primitive object by pulling it out into patches.py as a global.
This way, jax will not complain about "gather batching rule not find", since we patch with the same gather2_p object every qjit run.

Benefits:
Multiple qjits in the same program can call into gather primitives. Although in the standard recommended usage there is just one qjit in the overall entry function, this unblocks potential pytests (where multiple qjits cannot be avoided).

Related GitHub Issues: closes #1094 , closes #894

[sc-72775]

…lling it out into patches.py as a global.

This way, jax will not complain about "gather batching rule not find", since we patch with the same gather2_p object every qjit run.
Copy link
Contributor

@joeycarter joeycarter left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, thanks @paul0403!

Copy link
Contributor

@erick-xanadu erick-xanadu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are just moving the standard_primitive to the global scope, right? That sounds good. Do you know if this also fixes these tests?

test/pytest/test_callback.py:@pytest.mark.parametrize("order", ["good", "bad"])
test/pytest/test_callback.py:def test_vjp_as_residual(arg, order):
test/pytest/test_callback.py:    if order == "bad":
test/pytest/test_callback.py:    if order == "bad":
test/pytest/test_callback.py:@pytest.mark.parametrize("order", ["good", "bad"])
test/pytest/test_callback.py:def test_vjp_as_residual_automatic(arg, order):
test/pytest/test_callback.py:    if order == "bad":
test/pytest/test_callback.py:    if order == "bad":

Copy link

codecov bot commented Sep 4, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 97.65%. Comparing base (fc5c02e) to head (351185e).
Report is 2 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1096   +/-   ##
=======================================
  Coverage   97.65%   97.65%           
=======================================
  Files          76       76           
  Lines       10769    10769           
  Branches     1245     1245           
=======================================
  Hits        10517    10517           
  Misses        203      203           
  Partials       49       49           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@paul0403
Copy link
Contributor Author

paul0403 commented Sep 4, 2024

You are just moving the standard_primitive to the global scope, right? That sounds good. Do you know if this also fixes these tests?

test/pytest/test_callback.py:@pytest.mark.parametrize("order", ["good", "bad"])
test/pytest/test_callback.py:def test_vjp_as_residual(arg, order):
test/pytest/test_callback.py:    if order == "bad":
test/pytest/test_callback.py:    if order == "bad":
test/pytest/test_callback.py:@pytest.mark.parametrize("order", ["good", "bad"])
test/pytest/test_callback.py:def test_vjp_as_residual_automatic(arg, order):
test/pytest/test_callback.py:    if order == "bad":
test/pytest/test_callback.py:    if order == "bad":

just checked, these are fixed as well

@erick-xanadu
Copy link
Contributor

erick-xanadu commented Sep 4, 2024

@paul0403 if you want, you can remove the skips in this PR or in another one that follows. Happy with either.

@paul0403
Copy link
Contributor Author

paul0403 commented Sep 4, 2024

@paul0403 if you want, you can remove the skips in this PR or in another one that follows. Happy with either.

I'll just unskip them here.

@paul0403 paul0403 merged commit 2d36275 into main Sep 4, 2024
42 checks passed
@paul0403 paul0403 deleted the gather_rule_reuse branch September 4, 2024 17:37
rauletorresc pushed a commit that referenced this pull request Sep 11, 2024
…l them (#1096)

**Context:**
There is a bug if multiple qjits call into `gather` primitive #1094 .
With how we currently patch them, jax retains the first custom
`gather2_p` object and keeps using that in the second run.

**Description of the Change:**
Making sure there is only ever one `gather2_p` primitive object by
pulling it out into patches.py as a global.
This way, jax will not complain about "gather batching rule not find",
since we patch with the same gather2_p object every qjit run.

**Benefits:**
Multiple qjits in the same program can call into gather primitives.
Although in the standard recommended usage there is just one qjit in the
overall entry function, this unblocks potential pytests (where multiple
qjits cannot be avoided).

**Related GitHub Issues:** closes #1094 , closes #894

[sc-72775]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
3 participants