This function returns correct results:
>>> x = jnp.array([0.1, 0.2, 0.3, 0.4])
>>> qjit(f)(x)
array(0.4)
But when I attempt to compute the gradient, I either get a CompilationPassError (I think; lately I haven't been able to reproduce it without a crash) or the Python kernel crashes:
>>> qjit(grad(f))(x)
*crashes*
Interestingly, it also fails with jax.grad:
>>> qjit(jax.grad(f))(x)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
[<ipython-input-11-8ec72a4dd125>](https://localhost:8080/#) in <cell line: 1>()
----> 1 qjit(jax.grad(f))(jnp.array([0.1, 0.2, 0.3, 0.4]))
4 frames
[/usr/local/lib/python3.10/dist-packages/catalyst/compilation_pipelines.py](https://localhost:8080/#) in __call__(self, *args, **kwargs)
645 return self.user_function(*args, **kwargs)
646
--> 647 function, args = self._ensure_real_arguments_and_formal_parameters_are_compatible(
648 self.compiled_function, *args
649 )
[/usr/local/lib/python3.10/dist-packages/catalyst/compilation_pipelines.py](https://localhost:8080/#) in _ensure_real_arguments_and_formal_parameters_are_compatible(self, function, *args)
620 if not self.compiling_from_textual_ir:
621 self.mlir_module = self.get_mlir(*r_sig)
--> 622 function = self.compile()
623 else:
624 assert next_action == TypeCompatibility.CAN_SKIP_PROMOTION
[/usr/local/lib/python3.10/dist-packages/catalyst/compilation_pipelines.py](https://localhost:8080/#) in compile(self)
579 qfunc_name = str(self.mlir_module.body.operations[0].name).replace('"', "")
580
--> 581 shared_object, llvm_ir, inferred_func_data = self.compiler.run(
582 self.mlir_module, pipelines=self.compile_options.pipelines
583 )
[/usr/local/lib/python3.10/dist-packages/catalyst/compiler.py](https://localhost:8080/#) in run(self, mlir_module, *args, **kwargs)
399 """
400
--> 401 return self.run_from_ir(
402 mlir_module.operation.get_asm(
403 binary=False, print_generic_op_form=False, assume_verified=True
[/usr/local/lib/python3.10/dist-packages/catalyst/compiler.py](https://localhost:8080/#) in run_from_ir(self, ir, module_name, pipelines, lower_to_llvm)
356 print(f"[LIB] Running compiler driver in {workspace}", file=self.options.logfile)
357
--> 358 compiler_output = run_compiler_driver(
359 ir,
360 workspace,
RuntimeError: Compilation failed:
'tensor.extract' op incorrect number of indices for extract_element
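For comparison once compilation succeeds, a reference gradient can be estimated without Catalyst or JAX. The sketch below is only an illustration: the definition of `f` is not shown in this report, so the `f` used here (returning the last element, which happens to give 0.4 for this input) is a hypothetical stand-in, and `finite_difference_grad` is a helper name introduced for this example.

```python
def f(x):
    # Hypothetical stand-in: the report does not show the real definition of f.
    # Returning the last element matches the reported value 0.4 for this input.
    return x[-1]


def finite_difference_grad(func, x, eps=1e-6):
    """Central-difference estimate of the gradient of func at x (a list of floats)."""
    grad = []
    for i in range(len(x)):
        up = list(x)
        down = list(x)
        up[i] += eps
        down[i] -= eps
        grad.append((func(up) - func(down)) / (2 * eps))
    return grad


x = [0.1, 0.2, 0.3, 0.4]
reference = finite_difference_grad(f, x)
print(reference)  # roughly zero everywhere except the last entry, which is close to 1
```

With the real `f` substituted in, the output of `qjit(grad(f))(x)` should agree with this estimate up to finite-difference error.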