[BUG] Can't run jax.scipy.linalg.solve and jax.scipy.linalg.expm from separate @qjit blocks within the same program #1094

Closed
joeycarter opened this issue Sep 3, 2024 · 1 comment · Fixed by #1096
Labels
bug Something isn't working


@joeycarter
Contributor

If you try to run `jax.scipy.linalg.solve` and `jax.scipy.linalg.expm` from separate `@qjit` blocks within the same program, you get a `NotImplementedError: Batching rule for 'gather' not implemented` error. For example:

```python
import numpy as np

import jax.numpy as jnp
import jax.scipy as jsp

from catalyst import qjit


rng = np.random.default_rng(42)


def bad():
    A = jnp.array(rng.uniform(-1, 1, (3, 3)))
    b = jnp.array(rng.uniform(-1, 1, (3, 1)))

    # Two independently qjit-compiled functions in the same program.
    @qjit
    def f(A, b):
        return jsp.linalg.solve(A, b)

    @qjit
    def g(A):
        return jsp.linalg.expm(A)

    return f(A, b), g(A)


solve_result, expm_result = bad()
print(solve_result)
print(expm_result)
```

results in:

```
Traceback (most recent call last):
  File ".../test.py", line 73, in <module>
    solve_result, expm_result = bad()
  File ".../test.py", line 70, in bad
    return f(A, b), g(A)
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File ".../work/pennylane/catalyst/frontend/catalyst/jit.py", line 457, in __call__
    requires_promotion = self.jit_compile(args, **kwargs)
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File ".../work/pennylane/catalyst/frontend/catalyst/jit.py", line 528, in jit_compile
    self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
  File ".../work/pennylane/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
    return fn(*args, **kwargs)
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File ".../work/pennylane/catalyst/frontend/catalyst/jit.py", line 610, in capture
    jaxpr, out_type, treedef = trace_to_jaxpr(
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File ".../work/pennylane/catalyst/frontend/catalyst/jax_tracer.py", line 536, in trace_to_jaxpr
    jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
  File ".../work/pennylane/catalyst/frontend/catalyst/jax_extras/tracing.py", line 555, in make_jaxpr_f
    jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
  File ".../work/pennylane/catalyst/frontend/catalyst/jit.py", line 608, in fn_with_transform_named_sequence
    return self.user_function(*args, **kwargs)
  File ".../work/pennylane/catalyst/frontend/test.py", line 64, in f
    return jsp.linalg.solve(A, b)
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/jax/_src/scipy/linalg.py", line 1022, in solve
    return _solve(a, b, assume_a, lower)
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/jax/_src/scipy/linalg.py", line 944, in _solve
    return jnp.linalg.solve(a, b)
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/jax/_src/numpy/linalg.py", line 1320, in solve
    return jnp.vectorize(lax_linalg._solve, signature=signature)(a, b)
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/jax/_src/numpy/vectorize.py", line 321, in wrapped
    result = vectorized_func(*squeezed_args)
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/jax/_src/numpy/vectorize.py", line 138, in wrapped
    out = func(*args)
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/jax/_src/numpy/vectorize.py", line 321, in wrapped
    result = vectorized_func(*squeezed_args)
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/jax/_src/numpy/vectorize.py", line 138, in wrapped
    out = func(*args)
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/jax/_src/numpy/vectorize.py", line 181, in new_func
    return func(*args, **kwargs, **static_kwargs)
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/jax/_src/numpy/array_methods.py", line 739, in op
    return getattr(self.aval, f"_{name}")(self, *args)
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/jax/_src/numpy/array_methods.py", line 352, in _getitem
    return lax_numpy._rewriting_take(self, item)
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 5616, in _rewriting_take
    return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 5646, in _gather
    y = lax.gather(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Batching rule for 'gather' not implemented

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File ".../work/pennylane/catalyst/frontend/test.py", line 73, in <module>
    solve_result, expm_result = bad()
  File ".../work/pennylane/catalyst/frontend/test.py", line 70, in bad
    return f(A, b), g(A)
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File ".../work/pennylane/catalyst/frontend/catalyst/jit.py", line 457, in __call__
    requires_promotion = self.jit_compile(args, **kwargs)
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File ".../work/pennylane/catalyst/frontend/catalyst/jit.py", line 528, in jit_compile
    self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
  File ".../work/pennylane/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
    return fn(*args, **kwargs)
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File ".../work/pennylane/catalyst/frontend/catalyst/jit.py", line 610, in capture
    jaxpr, out_type, treedef = trace_to_jaxpr(
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File ".../work/pennylane/catalyst/frontend/catalyst/jax_tracer.py", line 536, in trace_to_jaxpr
    jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
  File ".../work/pennylane/catalyst/frontend/catalyst/jax_extras/tracing.py", line 555, in make_jaxpr_f
    jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/jax/_src/profiler.py", line 335, in wrapper
    return func(*args, **kwargs)
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/jax/_src/interpreters/partial_eval.py", line 2362, in trace_to_jaxpr_dynamic2
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/jax/_src/interpreters/partial_eval.py", line 2377, in trace_to_subjaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers_)
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/jax/_src/linear_util.py", line 192, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File ".../work/pennylane/catalyst/frontend/catalyst/jit.py", line 608, in fn_with_transform_named_sequence
    return self.user_function(*args, **kwargs)
  File ".../work/pennylane/catalyst/frontend/test.py", line 68, in g
    jsp.linalg.expm(A)
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/jax/_src/scipy/linalg.py", line 1163, in expm
    R = lax.cond(n_squarings > max_squarings, _nan, _compute, (A, P, Q))
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/jax/_src/scipy/linalg.py", line 1159, in _compute
    R = _solve_P_Q(P, Q, upper_triangular)
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/jax/_src/scipy/linalg.py", line 1201, in _solve_P_Q
    return jnp.linalg.solve(Q, P)
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/jax/_src/numpy/linalg.py", line 1320, in solve
    return jnp.vectorize(lax_linalg._solve, signature=signature)(a, b)
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/jax/_src/numpy/vectorize.py", line 321, in wrapped
    result = vectorized_func(*squeezed_args)
  File ".../.conda/envs/xanadu/lib/python3.9/site-packages/jax/_src/numpy/vectorize.py", line 138, in wrapped
    out = func(*args)
NotImplementedError: Batching rule for 'gather' not implemented
```

Interestingly, if you call `solve` and `expm` within the same `@qjit` block, it works:

```python
import numpy as np

import jax.numpy as jnp
import jax.scipy as jsp

from catalyst import qjit


rng = np.random.default_rng(42)


def good():
    A = jnp.array(rng.uniform(-1, 1, (3, 3)))
    b = jnp.array(rng.uniform(-1, 1, (3, 1)))

    # Both linear-algebra calls live in a single qjit-compiled function.
    @qjit
    def f(A, b):
        return jsp.linalg.solve(A, b), jsp.linalg.expm(A)

    return f(A, b)
```

```pycon
>>> solve_result, expm_result = good()
>>> print(solve_result)
[[ 0.76916628]
 [-0.19792595]
 [-0.75970703]]
>>> print(expm_result)
[[1.94944403 0.0351314  0.73593198]
 [0.62005861 0.57107875 0.63147519]
 [0.66094589 0.28554113 0.77824354]]
```

Note that I've run these examples with a modified version of Catalyst that provides the BLAS and LAPACK wrapper functions required by JAX.

joeycarter added the bug (Something isn't working) label on Sep 3, 2024
@joeycarter
Contributor Author

@paul0403 suggested that a good first place to look might be `make_jaxpr2`.

paul0403 added a commit that referenced this issue Sep 4, 2024
…l them (#1096)

**Context:**
There is a bug when multiple qjits call into the `gather` primitive (#1094).
With the way we currently patch it, JAX retains the first custom
`gather2_p` object and keeps using it in the second run.

**Description of the Change:**
Make sure there is only ever one `gather2_p` primitive object by
moving it into `patches.py` as a global.
This way, JAX no longer raises "Batching rule for 'gather' not implemented",
since every qjit run patches with the same `gather2_p` object.

**Benefits:**
Multiple qjits in the same program can now call into gather primitives.
Although the recommended usage has just one qjit in the
overall entry function, this unblocks tests (e.g. with pytest) where multiple
qjits cannot be avoided.

**Related GitHub Issues:** closes #1094, closes #894

[sc-72775]
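The failure mode described in this commit can be illustrated with a toy, pure-Python model. This is a hypothetical sketch and not Catalyst's or JAX's actual code. `Primitive`, `batching_rules`, `patched_gather`, `solve_impl` and `run_qjits` are all invented stand-ins. The model assumes, purely for illustration, two things. First, each qjit run registers a batching rule only for the gather primitive it installs and removes it afterwards. Second, a library-side cache keeps whichever primitive object it captured on the first run. Under those assumptions, a fresh primitive per run fails on the second run, while a single module-level primitive (the approach of #1096) keeps working.

```python
from contextlib import contextmanager
from functools import lru_cache


class Primitive:
    """Hypothetical stand-in for a JAX primitive; lookups use object identity."""

    def __init__(self, name):
        self.name = name


# Hypothetical stand-in for a global table of batching rules keyed by primitive object.
batching_rules = {}

# The gather primitive installed for the current qjit run (None otherwise).
active = {"gather": None}


@contextmanager
def patched_gather(make_primitive):
    """Install a gather primitive plus its batching rule for one qjit run."""
    prim = make_primitive()
    active["gather"] = prim
    batching_rules[prim] = lambda: f"batched {prim.name}"
    try:
        yield
    finally:
        # Undo the patch when the run ends.
        del batching_rules[prim]
        active["gather"] = None


@lru_cache(maxsize=None)
def solve_impl():
    """Hypothetical library-side cache: keeps the primitive seen on first use."""
    return active["gather"]


def batch(prim):
    """Look up the batching rule for exactly this primitive object."""
    if prim not in batching_rules:
        raise NotImplementedError(f"Batching rule for '{prim.name}' not implemented")
    return batching_rules[prim]()


def run_qjits(make_primitive, n_runs=2):
    """Simulate n separate qjit runs that each batch the cached gather."""
    solve_impl.cache_clear()
    results = []
    for _ in range(n_runs):
        with patched_gather(make_primitive):
            try:
                results.append(batch(solve_impl()))
            except NotImplementedError as err:
                results.append(f"NotImplementedError: {err}")
    return results


# Buggy pattern: a fresh primitive object on every run.
def fresh_gather():
    return Primitive("gather")


# Fixed pattern: one module-level primitive reused by every run.
GATHER2_P = Primitive("gather")


def shared_gather():
    return GATHER2_P


print(run_qjits(fresh_gather))
# → ['batched gather', "NotImplementedError: Batching rule for 'gather' not implemented"]
print(run_qjits(shared_gather))
# → ['batched gather', 'batched gather']
```

With `n_runs=1`, the per-run primitive also works. This matches the observation in the issue that calling both functions inside one `@qjit` block succeeds.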
rauletorresc pushed a commit that referenced this issue Sep 11, 2024