Skip to content

Commit

Permalink
Made bounded-int div_rem rhs - non-zero.
Browse files Browse the repository at this point in the history
commit-id:170bf0ef
  • Loading branch information
orizi committed Jun 5, 2024
1 parent f4b7b14 commit db8aa7f
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 16 deletions.
5 changes: 3 additions & 2 deletions corelib/src/test/integer_test.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2003,14 +2003,15 @@ mod bounded_int {
type RemT;
}
extern fn bounded_int_div_rem<T1, T2, impl DRR: DivRemRes<T1, T2>>(
a: T1, b: T2
a: T1, b: NonZero<T2>
) -> (DRR::DivT, DRR::RemT) implicits(RangeCheck) nopanic;
extern fn bounded_int_wrap_non_zero<T>(v: T) -> NonZero<T> nopanic;

/// Same as `bounded_int_div_rem`, but unwraps the result into felt252s.
fn bounded_int_div_rem_unwrapped<T1, T2, impl DRR: DivRemRes<T1, T2>>(
a: T1, b: T2
) -> (felt252, felt252) {
let (q, r) = bounded_int_div_rem(a, b);
let (q, r) = bounded_int_div_rem(a, bounded_int_wrap_non_zero(b));
(upcast(q), upcast(r))
}

Expand Down
11 changes: 6 additions & 5 deletions crates/cairo-lang-sierra/src/extensions/modules/bounded_int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,23 +163,24 @@ impl NamedLibfunc for BoundedIntDivRemLibfunc {
let (lhs, rhs) = args_as_two_types(args)?;
let lhs_range = Range::from_type(context, lhs.clone())?;
let rhs_range = Range::from_type(context, rhs.clone())?;
// Supporting only division of a non-negative number by a positive number.
// Supporting only division of a non-negative number by a positive number (non zero + and
// non negative).
// TODO(orizi): Consider relaxing the constraint, and defining the div_rem of negatives.
if lhs_range.lower.is_negative() || !rhs_range.lower.is_positive() {
if lhs_range.lower.is_negative() || rhs_range.lower.is_negative() {
return Err(SpecializationError::UnsupportedGenericArg);
}
// Making sure the algorithm is runnable.
if BoundedIntDivRemAlgorithm::try_new(&lhs_range, &rhs_range).is_none() {
return Err(SpecializationError::UnsupportedGenericArg);
}
let quotient_min = lhs_range.lower / (&rhs_range.upper - 1);
let quotient_max = (&lhs_range.upper - 1) / rhs_range.lower;
let quotient_max = (&lhs_range.upper - 1) / std::cmp::max(rhs_range.lower, BigInt::one());
let range_check_type = context.get_concrete_type(RangeCheckType::id(), &[])?;
Ok(LibfuncSignature::new_non_branch_ex(
vec![
ParamSignature::new(range_check_type.clone()).with_allow_add_const(),
ParamSignature::new(lhs.clone()),
ParamSignature::new(rhs.clone()),
ParamSignature::new(nonzero_ty(context, &rhs)?),
],
vec![
OutputVarInfo::new_builtin(range_check_type.clone(), 0),
Expand Down Expand Up @@ -240,7 +241,7 @@ impl BoundedIntDivRemAlgorithm {
/// Assumption: `lhs` is non-negative and `rhs` is positive.
pub fn try_new(lhs: &Range, rhs: &Range) -> Option<Self> {
let prime = Felt252::prime().to_bigint().unwrap();
let q_max = (&lhs.upper - 1) / &rhs.lower;
let q_max = (&lhs.upper - 1) / std::cmp::max(&rhs.lower, &BigInt::one());
let u128_limit = BigInt::one().shl(128);
// `q` is range checked in all algorithm variants, so `q_max` must be smaller than `2**128`.
require(q_max < u128_limit)?;
Expand Down
21 changes: 12 additions & 9 deletions tests/e2e_test_data/libfuncs/bounded_int
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ type DivRemType = (
BoundedInt<0, 7>,
);

extern fn bounded_int_div_rem<T1, T2>(a: T1, b: T2) -> DivRemType implicits(RangeCheck) nopanic;
extern fn bounded_int_div_rem<T1, T2>(a: T1, b: NonZero<T2>) -> DivRemType implicits(RangeCheck) nopanic;

fn foo(a: BoundedInt<128, 255>, b: BoundedInt<3, 8>) -> DivRemType {
fn foo(a: BoundedInt<128, 255>, b: NonZero<BoundedInt<3, 8>>) -> DivRemType {
bounded_int_div_rem(a, b)
}

Expand All @@ -198,6 +198,7 @@ type BoundedInt<16, 85> = BoundedInt<16, 85> [storable: true, drop: true, dup: t
type BoundedInt<0, 7> = BoundedInt<0, 7> [storable: true, drop: true, dup: true, zero_sized: false];
type Tuple<BoundedInt<16, 85>, BoundedInt<0, 7>> = Struct<ut@Tuple, BoundedInt<16, 85>, BoundedInt<0, 7>> [storable: true, drop: true, dup: true, zero_sized: false];
type BoundedInt<3, 8> = BoundedInt<3, 8> [storable: true, drop: true, dup: true, zero_sized: false];
type NonZero<BoundedInt<3, 8>> = NonZero<BoundedInt<3, 8>> [storable: true, drop: true, dup: true, zero_sized: false];
type BoundedInt<128, 255> = BoundedInt<128, 255> [storable: true, drop: true, dup: true, zero_sized: false];

libfunc bounded_int_div_rem<BoundedInt<128, 255>, BoundedInt<3, 8>> = bounded_int_div_rem<BoundedInt<128, 255>, BoundedInt<3, 8>>;
Expand All @@ -211,7 +212,7 @@ store_temp<RangeCheck>([3]) -> ([3]); // 2
store_temp<Tuple<BoundedInt<16, 85>, BoundedInt<0, 7>>>([6]) -> ([6]); // 3
return([3], [6]); // 4

test::foo@0([0]: RangeCheck, [1]: BoundedInt<128, 255>, [2]: BoundedInt<3, 8>) -> (RangeCheck, Tuple<BoundedInt<16, 85>, BoundedInt<0, 7>>);
test::foo@0([0]: RangeCheck, [1]: BoundedInt<128, 255>, [2]: NonZero<BoundedInt<3, 8>>) -> (RangeCheck, Tuple<BoundedInt<16, 85>, BoundedInt<0, 7>>);

//! > ==========================================================================

Expand All @@ -227,9 +228,9 @@ type DivRemType = (
BoundedInt<0, 0xfffffffffffffffffffffffffffffffe>,
);

extern fn bounded_int_div_rem<T1, T2>(a: T1, b: T2) -> DivRemType implicits(RangeCheck) nopanic;
extern fn bounded_int_div_rem<T1, T2>(a: T1, b: NonZero<T2>) -> DivRemType implicits(RangeCheck) nopanic;

fn foo(a: u128, b: BoundedInt<1, 0xffffffffffffffffffffffffffffffff>) -> DivRemType {
fn foo(a: u128, b: NonZero<BoundedInt<1, 0xffffffffffffffffffffffffffffffff>>) -> DivRemType {
bounded_int_div_rem(a, b)
}

Expand Down Expand Up @@ -262,6 +263,7 @@ type BoundedInt<0, 340282366920938463463374607431768211455> = BoundedInt<0, 3402
type BoundedInt<0, 340282366920938463463374607431768211454> = BoundedInt<0, 340282366920938463463374607431768211454> [storable: true, drop: true, dup: true, zero_sized: false];
type Tuple<BoundedInt<0, 340282366920938463463374607431768211455>, BoundedInt<0, 340282366920938463463374607431768211454>> = Struct<ut@Tuple, BoundedInt<0, 340282366920938463463374607431768211455>, BoundedInt<0, 340282366920938463463374607431768211454>> [storable: true, drop: true, dup: true, zero_sized: false];
type BoundedInt<1, 340282366920938463463374607431768211455> = BoundedInt<1, 340282366920938463463374607431768211455> [storable: true, drop: true, dup: true, zero_sized: false];
type NonZero<BoundedInt<1, 340282366920938463463374607431768211455>> = NonZero<BoundedInt<1, 340282366920938463463374607431768211455>> [storable: true, drop: true, dup: true, zero_sized: false];
type u128 = u128 [storable: true, drop: true, dup: true, zero_sized: false];

libfunc bounded_int_div_rem<u128, BoundedInt<1, 340282366920938463463374607431768211455>> = bounded_int_div_rem<u128, BoundedInt<1, 340282366920938463463374607431768211455>>;
Expand All @@ -275,7 +277,7 @@ store_temp<RangeCheck>([3]) -> ([3]); // 2
store_temp<Tuple<BoundedInt<0, 340282366920938463463374607431768211455>, BoundedInt<0, 340282366920938463463374607431768211454>>>([6]) -> ([6]); // 3
return([3], [6]); // 4

test::foo@0([0]: RangeCheck, [1]: u128, [2]: BoundedInt<1, 340282366920938463463374607431768211455>) -> (RangeCheck, Tuple<BoundedInt<0, 340282366920938463463374607431768211455>, BoundedInt<0, 340282366920938463463374607431768211454>>);
test::foo@0([0]: RangeCheck, [1]: u128, [2]: NonZero<BoundedInt<1, 340282366920938463463374607431768211455>>) -> (RangeCheck, Tuple<BoundedInt<0, 340282366920938463463374607431768211455>, BoundedInt<0, 340282366920938463463374607431768211454>>);

//! > ==========================================================================

Expand All @@ -291,9 +293,9 @@ type DivRemType = (
BoundedInt<0, 0xfffffffffffffffffffffffffffffff>,
);

extern fn bounded_int_div_rem<T1, T2>(a: T1, b: T2) -> DivRemType implicits(RangeCheck) nopanic;
extern fn bounded_int_div_rem<T1, T2>(a: T1, b: NonZero<T2>) -> DivRemType implicits(RangeCheck) nopanic;

fn foo(a: u128, b: BoundedInt<0x10000000000000000000000000000000, 0x10000000000000000000000000000000>) -> DivRemType {
fn foo(a: u128, b: NonZero<BoundedInt<0x10000000000000000000000000000000, 0x10000000000000000000000000000000>>) -> DivRemType {
bounded_int_div_rem(a, b)
}

Expand Down Expand Up @@ -322,6 +324,7 @@ type BoundedInt<0, 15> = BoundedInt<0, 15> [storable: true, drop: true, dup: tru
type BoundedInt<0, 21267647932558653966460912964485513215> = BoundedInt<0, 21267647932558653966460912964485513215> [storable: true, drop: true, dup: true, zero_sized: false];
type Tuple<BoundedInt<0, 15>, BoundedInt<0, 21267647932558653966460912964485513215>> = Struct<ut@Tuple, BoundedInt<0, 15>, BoundedInt<0, 21267647932558653966460912964485513215>> [storable: true, drop: true, dup: true, zero_sized: false];
type BoundedInt<21267647932558653966460912964485513216, 21267647932558653966460912964485513216> = BoundedInt<21267647932558653966460912964485513216, 21267647932558653966460912964485513216> [storable: true, drop: true, dup: true, zero_sized: false];
type NonZero<BoundedInt<21267647932558653966460912964485513216, 21267647932558653966460912964485513216>> = NonZero<BoundedInt<21267647932558653966460912964485513216, 21267647932558653966460912964485513216>> [storable: true, drop: true, dup: true, zero_sized: false];
type u128 = u128 [storable: true, drop: true, dup: true, zero_sized: false];

libfunc bounded_int_div_rem<u128, BoundedInt<21267647932558653966460912964485513216, 21267647932558653966460912964485513216>> = bounded_int_div_rem<u128, BoundedInt<21267647932558653966460912964485513216, 21267647932558653966460912964485513216>>;
Expand All @@ -335,7 +338,7 @@ store_temp<RangeCheck>([3]) -> ([3]); // 2
store_temp<Tuple<BoundedInt<0, 15>, BoundedInt<0, 21267647932558653966460912964485513215>>>([6]) -> ([6]); // 3
return([3], [6]); // 4

test::foo@0([0]: RangeCheck, [1]: u128, [2]: BoundedInt<21267647932558653966460912964485513216, 21267647932558653966460912964485513216>) -> (RangeCheck, Tuple<BoundedInt<0, 15>, BoundedInt<0, 21267647932558653966460912964485513215>>);
test::foo@0([0]: RangeCheck, [1]: u128, [2]: NonZero<BoundedInt<21267647932558653966460912964485513216, 21267647932558653966460912964485513216>>) -> (RangeCheck, Tuple<BoundedInt<0, 15>, BoundedInt<0, 21267647932558653966460912964485513215>>);

//! > ==========================================================================

Expand Down

0 comments on commit db8aa7f

Please sign in to comment.