Skip to content

Commit

Permalink
Add CircuitPartialOutputs. (#5725)
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyalesokhin-starkware committed Jun 5, 2024
1 parent c0b59b8 commit f4b7b14
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 22 deletions.
10 changes: 8 additions & 2 deletions corelib/src/circuit.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,10 @@ extern fn init_circuit_data<C>() -> CircuitInputAccumulator<C> implicits(RangeCh
/// Returns the descriptor for the circuit.
extern fn get_circuit_descriptor<C>() -> CircuitDescriptor<C> nopanic;


/// The result of filling an input in the circuit instance's data.
pub enum EvalCircuitResult<C> {
/// The circuit evaluation failed.
Failure: CircuitFailureGuarantee,
Failure: (CircuitPartialOutputs<C>, CircuitFailureGuarantee),
/// The circuit was evaluated successfully.
Success: CircuitOutputs<C>,
}
Expand Down Expand Up @@ -110,13 +109,20 @@ extern type CircuitData<C>;
/// A type representing a circuit instance where the outputs are filled.
extern type CircuitOutputs<C>;

/// A type representing a circuit instance where the outputs are partially filled as
/// the evaluation of one of the inverse gates failed.
/// The type is defined for future-compatibility, there is currently no libfunc to extract
/// the partial outputs.
extern type CircuitPartialOutputs<C>;

/// A type representing a circuit descriptor.
extern type CircuitDescriptor<C>;

impl CircuitInputAccumulatorDrop<C> of Drop<CircuitInputAccumulator<C>>;
impl CircuitDataDrop<C> of Drop<CircuitData<C>>;
impl CircuitDescriptorDrop<C> of Drop<CircuitDescriptor<C>>;
impl CircuitOutputsDrop<C> of Drop<CircuitOutputs<C>>;
impl CircuitPartialOutputsDrop<C> of Drop<CircuitPartialOutputs<C>>;


/// A wrapper for circuit elements, used to construct circuits..
Expand Down
2 changes: 1 addition & 1 deletion corelib/src/test/circuit_test.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ fn test_circuit() {
.unwrap();

match circ.get_descriptor().eval(data, modulus) {
EvalCircuitResult::Failure(_) => {},
EvalCircuitResult::Failure((_, _)) => {},
EvalCircuitResult::Success(outputs) => { outputs.get_output(out1); }
}
}
3 changes: 3 additions & 0 deletions crates/cairo-lang-sierra-to-casm/src/invocations/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,9 @@ fn build_circuit_eval(
&[
&[new_add_mod],
&[failure_mul_mod],
// CircuitPartialOutputs
&[values, modulus0, modulus1, modulus2, modulus3, computed_gates],
// CircuitFailureGuarantee
&[
mul_mod_offsets,
n_muls,
Expand Down
1 change: 1 addition & 0 deletions crates/cairo-lang-sierra-type-size/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ pub fn get_type_size_map(
CoreTypeConcrete::Circuit(CircuitTypeConcrete::CircuitFailureGuarantee(_)) => Some(8),
CoreTypeConcrete::Circuit(CircuitTypeConcrete::U384LessThanGuarantee(_)) => Some(8),
CoreTypeConcrete::Circuit(CircuitTypeConcrete::CircuitOutputs(_)) => Some(5),
CoreTypeConcrete::Circuit(CircuitTypeConcrete::CircuitPartialOutputs(_)) => Some(6),
CoreTypeConcrete::Circuit(CircuitTypeConcrete::CircuitData(_))
| CoreTypeConcrete::Circuit(CircuitTypeConcrete::AddMod(_))
| CoreTypeConcrete::Circuit(CircuitTypeConcrete::MulMod(_)) => Some(1),
Expand Down
58 changes: 57 additions & 1 deletion crates/cairo-lang-sierra/src/extensions/modules/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ define_type_hierarchy! {
Circuit(Circuit),
CircuitData(CircuitData),
CircuitOutputs(CircuitOutputs),
CircuitPartialOutputs(CircuitPartialOutputs),
CircuitDescriptor(CircuitDescriptor),
CircuitFailureGuarantee(CircuitFailureGuarantee),
CircuitInput(CircuitInput),
Expand Down Expand Up @@ -492,6 +493,55 @@ impl ConcreteType for ConcreteCircuitOutputs {
}
}

/// A type representing a circuit instance where the outputs are partially filled as
/// the evaluation of one of the inverse gates failed.
#[derive(Default)]
pub struct CircuitPartialOutputs {}
impl NamedType for CircuitPartialOutputs {
type Concrete = ConcreteCircuitPartialOutputs;
const ID: GenericTypeId = GenericTypeId::new_inline("CircuitPartialOutputs");

fn specialize(
&self,
context: &dyn TypeSpecializationContext,
args: &[GenericArg],
) -> Result<Self::Concrete, SpecializationError> {
Self::Concrete::new(context, args)
}
}

pub struct ConcreteCircuitPartialOutputs {
pub info: TypeInfo,
}

impl ConcreteCircuitPartialOutputs {
fn new(
context: &dyn TypeSpecializationContext,
args: &[GenericArg],
) -> Result<Self, SpecializationError> {
let circ_ty = args_as_single_type(args)?;
validate_is_circuit(context, circ_ty)?;
Ok(Self {
info: TypeInfo {
long_id: ConcreteTypeLongId {
generic_id: "CircuitPartialOutputs".into(),
generic_args: args.to_vec(),
},
duplicatable: false,
droppable: true,
storable: true,
zero_sized: false,
},
})
}
}

impl ConcreteType for ConcreteCircuitPartialOutputs {
fn info(&self) -> &TypeInfo {
&self.info
}
}

/// A type whose destruction guarantees that the circuit instance invocation failed.
#[derive(Default)]
pub struct CircuitFailureGuarantee {}
Expand Down Expand Up @@ -840,11 +890,17 @@ impl SignatureAndTypeGenericLibfunc for EvalCircuitLibFuncWrapped {
vars: vec![
OutputVarInfo::new_builtin(add_mod_builtin_ty.clone(), 0),
OutputVarInfo::new_builtin(mul_mod_builtin_ty.clone(), 1),
OutputVarInfo {
ty: context.get_concrete_type(
CircuitPartialOutputs::id(),
&[GenericArg::Type(ty.clone())],
)?,
ref_info: OutputVarReferenceInfo::SimpleDerefs,
},
OutputVarInfo {
ty: context.get_concrete_type(CircuitFailureGuarantee::id(), &[])?,
ref_info: OutputVarReferenceInfo::SimpleDerefs,
},
// TODO(ilya): Add CircuitFailedEvalOutputs.
],

ap_change: SierraApChange::Known { new_vars_only: false },
Expand Down
52 changes: 34 additions & 18 deletions tests/e2e_test_data/libfuncs/circuit
Original file line number Diff line number Diff line change
Expand Up @@ -293,15 +293,21 @@ fn foo(
[ap + -1] = [[fp + -12] + 6];
[fp + -8] = [ap + 0] + [ap + -1], ap++;
jmp rel 4 if [ap + -1] != 0;
jmp rel 18;
jmp rel 24;
[ap + 0] = [ap + -2] * 7, ap++;
[ap + 0] = [fp + -13] + 7, ap++;
[ap + 0] = [fp + -12] + [ap + -2], ap++;
[ap + 0] = 0, ap++;
[ap + 0] = [ap + -7], ap++;
[ap + 0] = [fp + -6], ap++;
[ap + 0] = [fp + -5], ap++;
[ap + 0] = [fp + -4], ap++;
[ap + 0] = [fp + -3], ap++;
[ap + 0] = [ap + -11], ap++;
[ap + 0] = [fp + -9], ap++;
[ap + 0] = [fp + -8], ap++;
[ap + 0] = [ap + -8], ap++;
[ap + 0] = [ap + -10], ap++;
[ap + 0] = [ap + -14], ap++;
[ap + 0] = [ap + -16], ap++;
[ap + 0] = [fp + -6], ap++;
[ap + 0] = [fp + -5], ap++;
[ap + 0] = [fp + -4], ap++;
Expand All @@ -314,7 +320,13 @@ ap += 1;
[ap + 0] = 0, ap++;
[ap + 0] = 0, ap++;
[ap + 0] = 0, ap++;
[ap + 0] = [ap + -10], ap++;
[ap + 0] = 0, ap++;
[ap + 0] = 0, ap++;
[ap + 0] = 0, ap++;
[ap + 0] = 0, ap++;
[ap + 0] = 0, ap++;
[ap + 0] = 0, ap++;
[ap + 0] = [ap + -16], ap++;
[ap + 0] = [fp + -6], ap++;
[ap + 0] = [fp + -5], ap++;
[ap + 0] = [fp + -4], ap++;
Expand All @@ -323,9 +335,11 @@ ret;

//! > sierra_code
type BoundedInt<0, 0> = BoundedInt<0, 0> [storable: true, drop: true, dup: true, zero_sized: false];
type CircuitPartialOutputs<Circuit<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>> = CircuitPartialOutputs<Circuit<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>> [storable: true, drop: true, dup: false, zero_sized: false];
type CircuitFailureGuarantee = CircuitFailureGuarantee [storable: true, drop: false, dup: false, zero_sized: false];
type Tuple<CircuitPartialOutputs<Circuit<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>>, CircuitFailureGuarantee> = Struct<ut@Tuple, CircuitPartialOutputs<Circuit<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>>, CircuitFailureGuarantee> [storable: true, drop: false, dup: false, zero_sized: false];
type CircuitOutputs<Circuit<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>> = CircuitOutputs<Circuit<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>> [storable: true, drop: true, dup: false, zero_sized: false];
type core::circuit::EvalCircuitResult::<core::circuit::Circuit::<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>> = Enum<ut@core::circuit::EvalCircuitResult::<core::circuit::Circuit::<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>>, CircuitFailureGuarantee, CircuitOutputs<Circuit<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>>> [storable: true, drop: false, dup: false, zero_sized: false];
type core::circuit::EvalCircuitResult::<core::circuit::Circuit::<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>> = Enum<ut@core::circuit::EvalCircuitResult::<core::circuit::Circuit::<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>>, Tuple<CircuitPartialOutputs<Circuit<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>>, CircuitFailureGuarantee>, CircuitOutputs<Circuit<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>>> [storable: true, drop: false, dup: false, zero_sized: false];
type Circuit<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)> = Circuit<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)> [storable: true, drop: true, dup: true, zero_sized: false];
type core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>> = AddModGate<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>> [storable: false, drop: false, dup: false, zero_sized: false];
type core::circuit::CircuitInput::<2> = CircuitInput<2> [storable: false, drop: false, dup: false, zero_sized: false];
Expand All @@ -349,6 +363,7 @@ libfunc store_temp<BoundedInt<0, 0>> = store_temp<BoundedInt<0, 0>>;
libfunc store_temp<BoundedInt<1, 1>> = store_temp<BoundedInt<1, 1>>;
libfunc eval_circuit<Circuit<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>> = eval_circuit<Circuit<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>>;
libfunc branch_align = branch_align;
libfunc struct_construct<Tuple<CircuitPartialOutputs<Circuit<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>>, CircuitFailureGuarantee>> = struct_construct<Tuple<CircuitPartialOutputs<Circuit<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>>, CircuitFailureGuarantee>>;
libfunc enum_init<core::circuit::EvalCircuitResult::<core::circuit::Circuit::<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>>, 0> = enum_init<core::circuit::EvalCircuitResult::<core::circuit::Circuit::<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>>, 0>;
libfunc store_temp<AddMod> = store_temp<AddMod>;
libfunc store_temp<MulMod> = store_temp<MulMod>;
Expand All @@ -359,24 +374,25 @@ const_as_immediate<Const<BoundedInt<0, 0>, 0>>() -> ([5]); // 0
const_as_immediate<Const<BoundedInt<1, 1>, 1>>() -> ([6]); // 1
store_temp<BoundedInt<0, 0>>([5]) -> ([5]); // 2
store_temp<BoundedInt<1, 1>>([6]) -> ([6]); // 3
eval_circuit<Circuit<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>>([0], [1], [2], [3], [4], [5], [6]) { fallthrough([7], [8], [9]) 11([10], [11], [12]) }; // 4
eval_circuit<Circuit<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>>([0], [1], [2], [3], [4], [5], [6]) { fallthrough([7], [8], [9], [10]) 12([11], [12], [13]) }; // 4
branch_align() -> (); // 5
enum_init<core::circuit::EvalCircuitResult::<core::circuit::Circuit::<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>>, 0>([9]) -> ([13]); // 6
store_temp<AddMod>([7]) -> ([7]); // 7
store_temp<MulMod>([8]) -> ([8]); // 8
store_temp<core::circuit::EvalCircuitResult::<core::circuit::Circuit::<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>>>([13]) -> ([13]); // 9
return([7], [8], [13]); // 10
branch_align() -> (); // 11
enum_init<core::circuit::EvalCircuitResult::<core::circuit::Circuit::<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>>, 1>([12]) -> ([14]); // 12
store_temp<AddMod>([10]) -> ([10]); // 13
store_temp<MulMod>([11]) -> ([11]); // 14
store_temp<core::circuit::EvalCircuitResult::<core::circuit::Circuit::<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>>>([14]) -> ([14]); // 15
return([10], [11], [14]); // 16
struct_construct<Tuple<CircuitPartialOutputs<Circuit<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>>, CircuitFailureGuarantee>>([9], [10]) -> ([14]); // 6
enum_init<core::circuit::EvalCircuitResult::<core::circuit::Circuit::<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>>, 0>([14]) -> ([15]); // 7
store_temp<AddMod>([7]) -> ([7]); // 8
store_temp<MulMod>([8]) -> ([8]); // 9
store_temp<core::circuit::EvalCircuitResult::<core::circuit::Circuit::<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>>>([15]) -> ([15]); // 10
return([7], [8], [15]); // 11
branch_align() -> (); // 12
enum_init<core::circuit::EvalCircuitResult::<core::circuit::Circuit::<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>>, 1>([13]) -> ([16]); // 13
store_temp<AddMod>([11]) -> ([11]); // 14
store_temp<MulMod>([12]) -> ([12]); // 15
store_temp<core::circuit::EvalCircuitResult::<core::circuit::Circuit::<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>>>([16]) -> ([16]); // 16
return([11], [12], [16]); // 17

test::foo@0([0]: AddMod, [1]: MulMod, [2]: CircuitDescriptor<Circuit<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>>, [3]: CircuitData<Circuit<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>>, [4]: NonZero<core::circuit::u384>) -> (AddMod, MulMod, core::circuit::EvalCircuitResult::<core::circuit::Circuit::<(core::circuit::AddModGate::<core::circuit::CircuitInput::<0>, core::circuit::CircuitInput::<1>>, core::circuit::CircuitInput::<2>)>>);

//! > function_costs
test::foo: OrderedHashMap({AddMod: 1, MulMod: 3, Const: 3610})
test::foo: OrderedHashMap({AddMod: 1, MulMod: 3, Const: 4210})

//! > ==========================================================================

Expand Down

0 comments on commit f4b7b14

Please sign in to comment.