Skip to content

Commit

Permalink
Fix cache for callbacks (#907)
Browse files Browse the repository at this point in the history
**Context:** callback cache should take into account input types as well
as output types. The following simple example illustrates when a
callback would have different mismatching inputs yet always returns the
same type.

```python
@pure_callback
def always_return_float(x) -> float:
    if x == 0.0:
        return x
    else:
        return x + 0.0

```

**Description of the Change:** Adds input types as part of the cache key

**Benefits:** No error.

**Possible Drawbacks:**

**Related GitHub Issues:** Fixes #851
  • Loading branch information
erick-xanadu authored Jul 5, 2024
1 parent da0ec10 commit 5d910b6
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 2 deletions.
1 change: 1 addition & 0 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
[(#822)](https://github.com/PennyLaneAI/catalyst/pull/822)
[(#834)](https://github.com/PennyLaneAI/catalyst/pull/834)
[(#882)](https://github.com/PennyLaneAI/catalyst/pull/882)
[(#907)](https://github.com/PennyLaneAI/catalyst/pull/907)

- When using callbacks that do not return any values, such as `catalyst.debug.callback` and
`catalyst.debug.print`, these functions are marked as 'inactive' and do not contribute to or
Expand Down
6 changes: 4 additions & 2 deletions frontend/catalyst/api_extensions/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,10 +434,12 @@ class MemrefCallable(FlatCallable):

CACHE = {}

def __new__(cls, func, results_aval, *_args, **_kwargs):
def __new__(cls, func, results_aval, *args, **kwargs):
# Hash-cons: https://en.wikipedia.org/wiki/Hash_consing
absargs, abskwargs = tree_map(shaped_abstractify, (args, kwargs))
flat_params, _ = tree_flatten((absargs, abskwargs))
flat_results_aval, _ = tree_flatten(results_aval)
cache_key = (func, *flat_results_aval)
cache_key = (func, *flat_params, *flat_results_aval)
if cls.CACHE.get(cache_key):
return cls.CACHE.get(cache_key)

Expand Down
60 changes: 60 additions & 0 deletions frontend/test/lit/test_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# RUN: %PYTHON %s | FileCheck %s

import pennylane as qml
from catalyst import pure_callback


def i(x):
return x


# CHECK-LABEL: module @one_callback_cached
@qml.qjit
# CHECK-NOT: catalyst.callback @callback
# CHECK-LABEL: func.func public @jit_one_callback_cached
def one_callback_cached(x: float):
"""Single callback is created, but called twice"""
c = pure_callback(i, float)
return c(x), c(x)


# CHECK-LABEL: catalyst.callback @callback
# CHECK-NOT: catalyst.callback @callback
print(one_callback_cached.mlir)


@pure_callback
def always_return_float(x) -> float:
if x == 0.0:
return x
else:
return x + 0.0


# CHECK-LABEL: module @test2
@qml.qjit
# CHECK-NOT: catalyst.callback @callback
# CHECK-LABEL: func.func public @jit_test2
def test2():
return always_return_float(0.0), always_return_float(1)


# CHECK-LABEL: catalyst.callback @callback
# CHECK-LABEL: catalyst.callback @callback
# CHECK-NOT: catalyst.callback @callback

print(test2.mlir)

0 comments on commit 5d910b6

Please sign in to comment.