[BUG] Array slicing with strides causes gradients to crash #305

Closed
josh146 opened this issue Oct 6, 2023 · 1 comment
Labels
bug Something isn't working


@josh146
Member

josh146 commented Oct 6, 2023

Consider the following function:

def f(x):
    return jnp.sum(x[::2])

When compiled with qjit, this function returns the correct result:

>>> x = jnp.array([0.1, 0.2, 0.3, 0.4])
>>> qjit(f)(x)
array(0.4)
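For comparison, the same function can be evaluated and differentiated with plain JAX, without qjit. Because f sums the elements at even indices, its gradient should be 1 at those positions and 0 elsewhere. This is a minimal sketch that assumes only jax is installed; it does not exercise Catalyst.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Sum of the elements at even indices (0, 2, ...) via a strided slice
    return jnp.sum(x[::2])


x = jnp.array([0.1, 0.2, 0.3, 0.4])

value = f(x)             # plain JAX, no Catalyst compilation
gradient = jax.grad(f)(x)  # gradient: 1 at even indices, 0 elsewhere
print(value, gradient)
```

This is the reference result that qjit(grad(f))(x) and qjit(jax.grad(f))(x) would be expected to reproduce.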

However, if I attempt to compute the gradient, I either get a CompilationPassError (I think; I haven't been able to reproduce it lately without a crash) or the Python kernel crashes:

>>> qjit(grad(f))(x)
*crashes*

Interestingly, it also fails with jax.grad:

>>> qjit(jax.grad(f))(x)
---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

<ipython-input-11-8ec72a4dd125> in <cell line: 1>()
----> 1 qjit(jax.grad(f))(jnp.array([0.1, 0.2, 0.3, 0.4]))


/usr/local/lib/python3.10/dist-packages/catalyst/compilation_pipelines.py in __call__(self, *args, **kwargs)
    645             return self.user_function(*args, **kwargs)
    646 
--> 647         function, args = self._ensure_real_arguments_and_formal_parameters_are_compatible(
    648             self.compiled_function, *args
    649         )

/usr/local/lib/python3.10/dist-packages/catalyst/compilation_pipelines.py in _ensure_real_arguments_and_formal_parameters_are_compatible(self, function, *args)
    620             if not self.compiling_from_textual_ir:
    621                 self.mlir_module = self.get_mlir(*r_sig)
--> 622             function = self.compile()
    623         else:
    624             assert next_action == TypeCompatibility.CAN_SKIP_PROMOTION

/usr/local/lib/python3.10/dist-packages/catalyst/compilation_pipelines.py in compile(self)
    579             qfunc_name = str(self.mlir_module.body.operations[0].name).replace('"', "")
    580 
--> 581             shared_object, llvm_ir, inferred_func_data = self.compiler.run(
    582                 self.mlir_module, pipelines=self.compile_options.pipelines
    583             )

/usr/local/lib/python3.10/dist-packages/catalyst/compiler.py in run(self, mlir_module, *args, **kwargs)
    399         """
    400 
--> 401         return self.run_from_ir(
    402             mlir_module.operation.get_asm(
    403                 binary=False, print_generic_op_form=False, assume_verified=True

/usr/local/lib/python3.10/dist-packages/catalyst/compiler.py in run_from_ir(self, ir, module_name, pipelines, lower_to_llvm)
    356             print(f"[LIB] Running compiler driver in {workspace}", file=self.options.logfile)
    357 
--> 358         compiler_output = run_compiler_driver(
    359             ir,
    360             workspace,

RuntimeError: Compilation failed:
'tensor.extract' op incorrect number of indices for extract_element
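One possible workaround, until a fix is available, is to express the strided selection as a gather with an explicit index array instead of x[::2]. This is a hypothetical workaround that has not been verified against the affected Catalyst version, and f_gather is an illustrative name. The sketch below only checks, in plain JAX, that the gather form computes the same value and gradient as the original f.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Original formulation: strided slice
    return jnp.sum(x[::2])


def f_gather(x):
    # Hypothetical alternative (not verified under qjit): gather the even
    # indices explicitly. x.shape[0] is static during tracing, so jnp.arange
    # receives a concrete bound.
    idx = jnp.arange(0, x.shape[0], 2)
    return jnp.sum(x[idx])


x = jnp.array([0.1, 0.2, 0.3, 0.4])
same_value = bool(jnp.allclose(f(x), f_gather(x)))
same_grad = bool(jnp.allclose(jax.grad(f)(x), jax.grad(f_gather)(x)))
print(same_value, same_grad)
```

If the gather form behaves under Catalyst, it would be compiled the same way as in the report, e.g. qjit(grad(f_gather))(x).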
josh146 added the bug label on Oct 6, 2023
dime10 added and then removed the dependencies label on Nov 8, 2023
@dime10
Collaborator

dime10 commented Mar 25, 2024

This bug has been fixed in #552.

dime10 closed this as completed on Mar 25, 2024