Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VJP/JVP support pytree #501

Merged
merged 21 commits into from
Feb 12, 2024
Merged

VJP/JVP support pytree #501

merged 21 commits into from
Feb 12, 2024

Conversation

rmoyard
Copy link
Contributor

@rmoyard rmoyard commented Feb 9, 2024

Context:

Following #500, we aim to add support for arbitrary return of functions for VJP and JVP.

Description of the Change:

  • JVP and VJP are updated to support pytree as return.
  • Clean the tests.

@rmoyard
Copy link
Contributor Author

rmoyard commented Feb 9, 2024

[sc-55113]

@rmoyard rmoyard marked this pull request as ready for review February 9, 2024 21:55
Base automatically changed from gradient_pytree to main February 12, 2024 16:25
Copy link

codecov bot commented Feb 12, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (69634c4) 99.55% compared to head (f02523f) 99.55%.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #501   +/-   ##
=======================================
  Coverage   99.55%   99.55%           
=======================================
  Files          43       43           
  Lines        7786     7802   +16     
  Branches      540      542    +2     
=======================================
+ Hits         7751     7767   +16     
  Misses         18       18           
  Partials       17       17           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Contributor

@erick-xanadu erick-xanadu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not too familiar with JVPs not VJPs, so some more comments would be nice. But the code looks good! I.e., why is there a midpoint in one of the options, and why is the shape of the VJP the same as the parameters?

@rmoyard
Copy link
Contributor Author

rmoyard commented Feb 12, 2024

@erick-xanadu It is because the JVP have the same shape as the returns, where VJP have the same shape as the parameters.

@rmoyard rmoyard merged commit 45351e2 into main Feb 12, 2024
35 checks passed
@rmoyard rmoyard deleted the vjp_jvp_pytree branch February 12, 2024 20:27
Copy link
Collaborator

@dime10 dime10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does the behaviour match JAX, is it 1-1 or are there certain deviations?

else:
func_res = results[: len(jaxpr.out_avals)]
vjps = results[len(jaxpr.out_avals) :]
results = tuple([*func_res, tuple(vjps)])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This structure seems a bit strange no? The function results are expanded but the vjps are in another tuple

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the general question: our vjp is very different from Jax vjp https://jax.readthedocs.io/en/latest/_autosummary/jax.vjp.html where they return.

res, f_vjp = tuple(res, f_vjp)

Here res are unflatten, after that you need to use the function to get the vjps

vjps = f_vjp(cot)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's true about the vjp, I was mainly thinking of the PyTree behaviour for inputs, outputs, tangents, cotangents, and gradients. Those should ideally all match JAX's version.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

About the vjp difference, I think we should still return a tuple of (results, gradients) just like for the jvp, because like you say we use the same function style for both.

frontend/catalyst/pennylane_extensions.py Show resolved Hide resolved
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants