From 91e839aae54df18cde0dd923102ca9cdc9f2ce0a Mon Sep 17 00:00:00 2001 From: Jan Bujak Date: Tue, 11 Apr 2023 11:09:19 +0000 Subject: [PATCH 01/11] Add extra #[inline]; this speeds up the avx2 backend slightly --- src/backend/vector/avx2/field.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/backend/vector/avx2/field.rs b/src/backend/vector/avx2/field.rs index 9f278723d..614c32750 100644 --- a/src/backend/vector/avx2/field.rs +++ b/src/backend/vector/avx2/field.rs @@ -765,6 +765,7 @@ impl<'a, 'b> Mul<&'b FieldElement2625x4> for &'a FieldElement2625x4 { /// The coefficients of the result are bounded with \\( b < 0.007 \\). /// #[rustfmt::skip] // keep alignment of z* calculations + #[inline] fn mul(self, rhs: &'b FieldElement2625x4) -> FieldElement2625x4 { #[inline(always)] fn m(x: u32x8, y: u32x8) -> u64x4 { From 0db8783be8879662e110e4472d8ae06fc919f59e Mon Sep 17 00:00:00 2001 From: Jan Bujak Date: Tue, 11 Apr 2023 11:13:18 +0000 Subject: [PATCH 02/11] Runtime backend autodetection --- .github/workflows/rust.yml | 33 +- Cargo.toml | 17 +- Makefile | 1 - README.md | 50 +-- src/backend/mod.rs | 300 +++++++++++++++++- src/backend/serial/mod.rs | 4 - src/backend/vector/avx2/edwards.rs | 26 +- src/backend/vector/avx2/field.rs | 11 + src/backend/vector/avx2/mod.rs | 2 + src/backend/vector/ifma/edwards.rs | 24 +- src/backend/vector/ifma/field.rs | 20 +- src/backend/vector/ifma/mod.rs | 2 + src/backend/vector/mod.rs | 51 +-- src/backend/vector/packed_simd.rs | 34 +- src/backend/vector/scalar_mul/pippenger.rs | 25 +- .../vector/scalar_mul/precomputed_straus.rs | 21 +- src/backend/vector/scalar_mul/straus.rs | 15 +- .../vector/scalar_mul/variable_base.rs | 15 +- .../vector/scalar_mul/vartime_double_base.rs | 26 +- src/edwards.rs | 25 +- src/lib.rs | 10 +- src/ristretto.rs | 25 +- 22 files changed, 570 insertions(+), 167 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index be98f9751..45aa87b0f 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -55,19 +55,19 @@ jobs: - run: cargo build --target thumbv7em-none-eabi --release - run: cargo build --target thumbv7em-none-eabi --release --features serde - build-simd-nightly: - name: Build simd backend (nightly) + test-simd-native: + name: Test simd backend (native) runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: dtolnay/rust-toolchain@nightly - # Build with AVX2 features, then with AVX512 features - env: - RUSTFLAGS: '--cfg curve25519_dalek_backend="simd" -C target_feature=+avx2' - run: cargo build --target x86_64-unknown-linux-gnu - - env: - RUSTFLAGS: '--cfg curve25519_dalek_backend="simd" -C target_feature=+avx512ifma' - run: cargo build --target x86_64-unknown-linux-gnu + # This will: + # 1) build all of the x86_64 SIMD code, + # 2) run all of the SIMD-specific tests that the test runner supports, + # 3) run all of the normal tests using the best available SIMD backend. + RUSTFLAGS: '-C target_cpu=native' + run: cargo test --features simd --target x86_64-unknown-linux-gnu test-simd-avx2: name: Test simd backend (avx2) @@ -76,8 +76,10 @@ jobs: - uses: actions/checkout@v3 - uses: dtolnay/rust-toolchain@stable - env: - RUSTFLAGS: '--cfg curve25519_dalek_backend="simd" -C target_feature=+avx2' - run: cargo test --target x86_64-unknown-linux-gnu + # This will run AVX2-specific tests and run all of the normal tests + # with the AVX2 backend, even if the runner supports AVX512. 
+ RUSTFLAGS: '-C target_feature=+avx2' + run: cargo test --no-default-features --features alloc,precomputed-tables,zeroize,simd_avx2 --target x86_64-unknown-linux-gnu build-docs: name: Build docs @@ -131,12 +133,7 @@ jobs: - uses: dtolnay/rust-toolchain@nightly with: components: clippy - - env: - RUSTFLAGS: '--cfg curve25519_dalek_backend="simd" -C target_feature=+avx2' - run: cargo clippy --target x86_64-unknown-linux-gnu - - env: - RUSTFLAGS: '--cfg curve25519_dalek_backend="simd" -C target_feature=+avx512ifma' - run: cargo clippy --target x86_64-unknown-linux-gnu + - run: cargo clippy --target x86_64-unknown-linux-gnu rustfmt: name: Check formatting @@ -162,9 +159,7 @@ jobs: - uses: dtolnay/rust-toolchain@1.60.0 - run: cargo build --no-default-features --features serde # Also make sure the AVX2 build works - - env: - RUSTFLAGS: '--cfg curve25519_dalek_backend="simd" -C target_feature=+avx2' - run: cargo build --target x86_64-unknown-linux-gnu + - run: cargo build --target x86_64-unknown-linux-gnu bench: name: Check that benchmarks compile diff --git a/Cargo.toml b/Cargo.toml index a1dafcb24..1a76f8742 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,6 @@ rustdoc-args = [ "--html-in-header", "docs/assets/rustdoc-include-katex-header.html", "--cfg", "docsrs", ] -rustc-args = ["--cfg", "curve25519_dalek_backend=\"simd\""] features = ["serde", "rand_core", "digest", "legacy_compatibility"] [dev-dependencies] @@ -54,15 +53,29 @@ digest = { version = "0.10", default-features = false, optional = true } subtle = { version = "2.3.0", default-features = false } serde = { version = "1.0", default-features = false, optional = true, features = ["derive"] } zeroize = { version = "1", default-features = false, optional = true } +unsafe_target_feature = { version = "0.1.1", optional = true } + +[target.'cfg(target_arch = "x86_64")'.dependencies] +cpufeatures = "0.2.6" [target.'cfg(curve25519_dalek_backend = "fiat")'.dependencies] fiat-crypto = "0.1.19" [features] -default = ["alloc", "precomputed-tables", "zeroize"] +default = ["alloc", "precomputed-tables", "zeroize", "simd"] alloc = ["zeroize?/alloc"] precomputed-tables = [] legacy_compatibility = [] +# Whether to allow the use of the AVX2 SIMD backend. +simd_avx2 = ["unsafe_target_feature"] + +# Whether to allow the use of the AVX512 SIMD backend. +# (Note: This requires Rust nightly; on Rust stable this feature will be ignored.) +simd_avx512 = ["unsafe_target_feature"] + +# A meta-feature to allow all SIMD backends to be used. +simd = ["simd_avx2", "simd_avx512"] + [profile.dev] opt-level = 2 diff --git a/Makefile b/Makefile index 3b41b1756..bb61cc844 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,5 @@ FEATURES := serde rand_core digest legacy_compatibility -export RUSTFLAGS := --cfg=curve25519_dalek_backend="simd" export RUSTDOCFLAGS := \ --cfg docsrs \ --html-in-header docs/assets/rustdoc-include-katex-header.html diff --git a/README.md b/README.md index 429bae9ac..02434ff14 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,9 @@ curve25519-dalek = "4.0.0-rc.2" | `alloc` | ✓ | Enables Edwards and Ristretto multiscalar multiplication, batch scalar inversion, and batch Ristretto double-and-compress. Also enables `zeroize`. | | `zeroize` | ✓ | Enables [`Zeroize`][zeroize-trait] for all scalar and curve point types. | | `precomputed-tables` | ✓ | Includes precomputed basepoint multiplication tables. This speeds up `EdwardsPoint::mul_base` and `RistrettoPoint::mul_base` by ~4x, at the cost of ~30KB added to the code size. 
|
+| `simd_avx2` | ✓ | Allows the AVX2 SIMD backend to be used, if available. |
+| `simd_avx512` | ✓ | Allows the AVX512 SIMD backend to be used, if available. |
+| `simd` | ✓ | Allows every SIMD backend to be used, if available. |
 | `rand_core` | | Enables `Scalar::random` and `RistrettoPoint::random`. This is an optional dependency whose version is not subject to SemVer. See [below](#public-api-semver-exemptions) for more details. |
 | `digest` | | Enables `RistrettoPoint::{from_hash, hash_from_bytes}` and `Scalar::{from_hash, hash_from_bytes}`. This is an optional dependency whose version is not subject to SemVer. See [below](#public-api-semver-exemptions) for more details. |
 | `serde` | | Enables `serde` serialization/deserialization for all the point and scalar types. |
@@ -95,18 +98,17 @@ See tracking issue: [curve25519-dalek/issues/521](https://github.com/dalek-crypt
 Curve arithmetic is implemented and used by selecting one of the following backends:
-| Backend | Implementation | Target backends |
-| :--- | :--- | :--- |
-| `[default]` | Serial formulas | `u32` <br> `u64` |
-| `simd` | [Parallel][parallel_doc], using Advanced Vector Extensions | `avx2` <br> `avx512ifma` |
-| `fiat` | Formally verified field arithmetic from [fiat-crypto] | `fiat_u32` <br> `fiat_u64` |
+| Backend | Implementation | Target backends |
+| :--- | :--- | :--- |
+| `[default]` | Automatic runtime backend selection (either serial or SIMD) | `u32` <br> `u64` <br> `avx2` <br> `avx512` |
+| `fiat` | Formally verified field arithmetic from [fiat-crypto] | `fiat_u32` <br>
`fiat_u64` | -To choose a backend other than the `[default]` serial backend, set the +To choose a backend other than the `[default]` backend, set the environment variable: ```sh RUSTFLAGS='--cfg curve25519_dalek_backend="BACKEND"' ``` -where `BACKEND` is `simd` or `fiat`. Equivalently, you can write to +where `BACKEND` is `fiat`. Equivalently, you can write to `~/.cargo/config`: ```toml [build] @@ -114,11 +116,8 @@ rustflags = ['--cfg=curve25519_dalek_backend="BACKEND"'] ``` More info [here](https://doc.rust-lang.org/cargo/reference/config.html#buildrustflags). -The `simd` backend requires extra configuration. See [the SIMD -section](#simd-target-backends). - Note for contributors: The target backends are not entirely independent of each -other. The `simd` backend directly depends on parts of the the `u64` backend to +other. The SIMD backend directly depends on parts of the the `u64` backend to function. ## Word size for serial backends @@ -137,7 +136,7 @@ RUSTFLAGS='--cfg curve25519_dalek_bits="SIZE"' where `SIZE` is `32` or `64`. As in the above section, this can also be placed in `~/.cargo/config`. -**NOTE:** The `simd` backend CANNOT be used with word size 32. +**NOTE:** Using a word size of 32 will automatically disable SIMD support. ### Cross-compilation @@ -152,18 +151,19 @@ $ cargo build --target i686-unknown-linux-gnu ## SIMD target backends -Target backend selection within `simd` must be done manually by setting the -`RUSTFLAGS` environment variable to one of the below options: +The SIMD target backend selection is done automatically at runtime depending +on the available CPU features, provided the appropriate feature flag is enabled. -| CPU feature | `RUSTFLAGS` | Requires nightly? | -| :--- | :--- | :--- | -| avx2 | `-C target_feature=+avx2` | no | -| avx512ifma | `-C target_feature=+avx512ifma` | yes | +You can also specify an appropriate `-C target_feature` to build a binary +which assumes the required SIMD instructions are always available. -Or you can use `-C target_cpu=native` if you don't know what to set. +| Backend | Feature flag | `RUSTFLAGS` | Requires nightly? | +| :--- | :--- | :--- | :--- | +| avx2 | `simd_avx2` | `-C target_feature=+avx2` | no | +| avx512 | `simd_avx512` | `-C target_feature=+avx512ifma,+avx512vl` | yes | -The AVX512 backend requires Rust nightly. If enabled and when compiled on a non-nightly -compiler it will fall back to using the AVX2 backend. +The AVX512 backend requires Rust nightly. When compiled on a non-nightly +compiler it will always be disabled. # Documentation @@ -243,7 +243,8 @@ The implementation is memory-safe, and contains no significant `unsafe` code. The SIMD backend uses `unsafe` internally to call SIMD intrinsics. These are marked `unsafe` only because invoking them on an inappropriate CPU would cause `SIGILL`, but the entire backend is only -compiled with appropriate `target_feature`s, so this cannot occur. +invoked when the appropriate CPU features are detected at runtime, or +when the whole program is compiled with the appropriate `target_feature`s. # Performance @@ -251,8 +252,7 @@ Benchmarks are run using [`criterion.rs`][criterion]: ```sh cargo bench --features "rand_core" -# Uses avx2 or ifma only if compiled for an appropriate target. -export RUSTFLAGS='--cfg curve25519_dalek_backend="simd" -C target_cpu=native' +export RUSTFLAGS='-C target_cpu=native' cargo +nightly bench --features "rand_core" ``` @@ -294,7 +294,7 @@ universe's beauty, but also his deep hatred of the Daleks. 
Rusty destroys the other Daleks and departs the ship, determined to track down and bring an end to the Dalek race.* -`curve25519-dalek` is authored by Isis Agora Lovecruft and Henry de Valence. +`curve25519-dalek` is authored by Isis Agora Lovecruft and Henry de Valence. Portions of this library were originally a port of [Adam Langley's Golang ed25519 library](https://github.com/agl/ed25519), which was in diff --git a/src/backend/mod.rs b/src/backend/mod.rs index b6cea7ebf..09cfaf8bc 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -34,7 +34,305 @@ //! The [`vector`] backend is selected by the `simd_backend` cargo //! feature; it uses the [`serial`] backend for non-vectorized operations. +use crate::EdwardsPoint; +use crate::Scalar; + pub mod serial; -#[cfg(any(curve25519_dalek_backend = "simd", docsrs))] +#[cfg(all( + target_arch = "x86_64", + any(feature = "simd_avx2", all(feature = "simd_avx512", nightly)), + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") +))] pub mod vector; + +#[derive(Copy, Clone)] +enum BackendKind { + #[cfg(all( + target_arch = "x86_64", + feature = "simd_avx2", + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + Avx2, + #[cfg(all( + target_arch = "x86_64", + all(feature = "simd_avx512", nightly), + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + Avx512, + Serial, +} + +#[inline] +fn get_selected_backend() -> BackendKind { + #[cfg(all( + target_arch = "x86_64", + all(feature = "simd_avx512", nightly), + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + { + cpufeatures::new!(cpuid_avx512, "avx512ifma", "avx512vl"); + let token_avx512: cpuid_avx512::InitToken = cpuid_avx512::init(); + if token_avx512.get() { + return BackendKind::Avx512; + } + } + + #[cfg(all( + target_arch = "x86_64", + feature = "simd_avx2", + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + { + cpufeatures::new!(cpuid_avx2, "avx2"); + let token_avx2: cpuid_avx2::InitToken = cpuid_avx2::init(); + if token_avx2.get() { + return BackendKind::Avx2; + } + } + + BackendKind::Serial +} + +#[cfg(feature = "alloc")] +pub fn pippenger_optional_multiscalar_mul(scalars: I, points: J) -> Option +where + I: IntoIterator, + I::Item: core::borrow::Borrow, + J: IntoIterator>, +{ + use crate::traits::VartimeMultiscalarMul; + + match get_selected_backend() { + #[cfg(all(target_arch = "x86_64", feature = "simd_avx2", curve25519_dalek_bits = "64", not(curve25519_dalek_backend = "fiat")))] + BackendKind::Avx2 => + self::vector::scalar_mul::pippenger::spec_avx2::Pippenger::optional_multiscalar_mul::(scalars, points), + #[cfg(all(target_arch = "x86_64", all(feature = "simd_avx512", nightly), curve25519_dalek_bits = "64", not(curve25519_dalek_backend = "fiat")))] + BackendKind::Avx512 => + self::vector::scalar_mul::pippenger::spec_avx512ifma_avx512vl::Pippenger::optional_multiscalar_mul::(scalars, points), + BackendKind::Serial => + self::serial::scalar_mul::pippenger::Pippenger::optional_multiscalar_mul::(scalars, points), + } +} + +#[cfg(feature = "alloc")] +pub(crate) enum VartimePrecomputedStraus { + #[cfg(all( + target_arch = "x86_64", + feature = "simd_avx2", + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + Avx2(self::vector::scalar_mul::precomputed_straus::spec_avx2::VartimePrecomputedStraus), + #[cfg(all( + target_arch = "x86_64", + all(feature = "simd_avx512", nightly), + curve25519_dalek_bits = "64", + 
not(curve25519_dalek_backend = "fiat") + ))] + Avx512ifma( + self::vector::scalar_mul::precomputed_straus::spec_avx512ifma_avx512vl::VartimePrecomputedStraus, + ), + Scalar(self::serial::scalar_mul::precomputed_straus::VartimePrecomputedStraus), +} + +#[cfg(feature = "alloc")] +impl VartimePrecomputedStraus { + pub fn new(static_points: I) -> Self + where + I: IntoIterator, + I::Item: core::borrow::Borrow, + { + use crate::traits::VartimePrecomputedMultiscalarMul; + + match get_selected_backend() { + #[cfg(all(target_arch = "x86_64", feature = "simd_avx2", curve25519_dalek_bits = "64", not(curve25519_dalek_backend = "fiat")))] + BackendKind::Avx2 => + VartimePrecomputedStraus::Avx2(self::vector::scalar_mul::precomputed_straus::spec_avx2::VartimePrecomputedStraus::new(static_points)), + #[cfg(all(target_arch = "x86_64", all(feature = "simd_avx512", nightly), curve25519_dalek_bits = "64", not(curve25519_dalek_backend = "fiat")))] + BackendKind::Avx512 => + VartimePrecomputedStraus::Avx512ifma(self::vector::scalar_mul::precomputed_straus::spec_avx512ifma_avx512vl::VartimePrecomputedStraus::new(static_points)), + BackendKind::Serial => + VartimePrecomputedStraus::Scalar(self::serial::scalar_mul::precomputed_straus::VartimePrecomputedStraus::new(static_points)) + } + } + + pub fn optional_mixed_multiscalar_mul( + &self, + static_scalars: I, + dynamic_scalars: J, + dynamic_points: K, + ) -> Option + where + I: IntoIterator, + I::Item: core::borrow::Borrow, + J: IntoIterator, + J::Item: core::borrow::Borrow, + K: IntoIterator>, + { + use crate::traits::VartimePrecomputedMultiscalarMul; + + match self { + #[cfg(all( + target_arch = "x86_64", + feature = "simd_avx2", + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + VartimePrecomputedStraus::Avx2(inner) => inner.optional_mixed_multiscalar_mul( + static_scalars, + dynamic_scalars, + dynamic_points, + ), + #[cfg(all( + target_arch = "x86_64", + all(feature = "simd_avx512", nightly), + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + VartimePrecomputedStraus::Avx512ifma(inner) => inner.optional_mixed_multiscalar_mul( + static_scalars, + dynamic_scalars, + dynamic_points, + ), + VartimePrecomputedStraus::Scalar(inner) => inner.optional_mixed_multiscalar_mul( + static_scalars, + dynamic_scalars, + dynamic_points, + ), + } + } +} + +#[cfg(feature = "alloc")] +pub fn straus_multiscalar_mul(scalars: I, points: J) -> EdwardsPoint +where + I: IntoIterator, + I::Item: core::borrow::Borrow, + J: IntoIterator, + J::Item: core::borrow::Borrow, +{ + use crate::traits::MultiscalarMul; + + match get_selected_backend() { + #[cfg(all( + target_arch = "x86_64", + feature = "simd_avx2", + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + BackendKind::Avx2 => { + self::vector::scalar_mul::straus::spec_avx2::Straus::multiscalar_mul::( + scalars, points, + ) + } + #[cfg(all( + target_arch = "x86_64", + all(feature = "simd_avx512", nightly), + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + BackendKind::Avx512 => { + self::vector::scalar_mul::straus::spec_avx512ifma_avx512vl::Straus::multiscalar_mul::< + I, + J, + >(scalars, points) + } + BackendKind::Serial => { + self::serial::scalar_mul::straus::Straus::multiscalar_mul::(scalars, points) + } + } +} + +#[cfg(feature = "alloc")] +pub fn straus_optional_multiscalar_mul(scalars: I, points: J) -> Option +where + I: IntoIterator, + I::Item: core::borrow::Borrow, + J: IntoIterator>, +{ + use 
crate::traits::VartimeMultiscalarMul; + + match get_selected_backend() { + #[cfg(all( + target_arch = "x86_64", + feature = "simd_avx2", + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + BackendKind::Avx2 => { + self::vector::scalar_mul::straus::spec_avx2::Straus::optional_multiscalar_mul::( + scalars, points, + ) + } + #[cfg(all( + target_arch = "x86_64", + all(feature = "simd_avx512", nightly), + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + BackendKind::Avx512 => { + self::vector::scalar_mul::straus::spec_avx512ifma_avx512vl::Straus::optional_multiscalar_mul::< + I, + J, + >(scalars, points) + } + BackendKind::Serial => { + self::serial::scalar_mul::straus::Straus::optional_multiscalar_mul::( + scalars, points, + ) + } + } +} + +/// Perform constant-time, variable-base scalar multiplication. +pub fn variable_base_mul(point: &EdwardsPoint, scalar: &Scalar) -> EdwardsPoint { + match get_selected_backend() { + #[cfg(all( + target_arch = "x86_64", + feature = "simd_avx2", + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + BackendKind::Avx2 => self::vector::scalar_mul::variable_base::spec_avx2::mul(point, scalar), + #[cfg(all( + target_arch = "x86_64", + all(feature = "simd_avx512", nightly), + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + BackendKind::Avx512 => { + self::vector::scalar_mul::variable_base::spec_avx512ifma_avx512vl::mul(point, scalar) + } + BackendKind::Serial => self::serial::scalar_mul::variable_base::mul(point, scalar), + } +} + +/// Compute \\(aA + bB\\) in variable time, where \\(B\\) is the Ed25519 basepoint. +#[allow(non_snake_case)] +pub fn vartime_double_base_mul(a: &Scalar, A: &EdwardsPoint, b: &Scalar) -> EdwardsPoint { + match get_selected_backend() { + #[cfg(all( + target_arch = "x86_64", + feature = "simd_avx2", + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + BackendKind::Avx2 => self::vector::scalar_mul::vartime_double_base::spec_avx2::mul(a, A, b), + #[cfg(all( + target_arch = "x86_64", + all(feature = "simd_avx512", nightly), + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + BackendKind::Avx512 => { + self::vector::scalar_mul::vartime_double_base::spec_avx512ifma_avx512vl::mul(a, A, b) + } + BackendKind::Serial => self::serial::scalar_mul::vartime_double_base::mul(a, A, b), + } +} diff --git a/src/backend/serial/mod.rs b/src/backend/serial/mod.rs index 933bb88de..13fef5c63 100644 --- a/src/backend/serial/mod.rs +++ b/src/backend/serial/mod.rs @@ -42,8 +42,4 @@ cfg_if! 
{ pub mod curve_models; -#[cfg(not(all( - curve25519_dalek_backend = "simd", - any(target_feature = "avx2", target_feature = "avx512ifma") -)))] pub mod scalar_mul; diff --git a/src/backend/vector/avx2/edwards.rs b/src/backend/vector/avx2/edwards.rs index 032265069..7bb58b1ee 100644 --- a/src/backend/vector/avx2/edwards.rs +++ b/src/backend/vector/avx2/edwards.rs @@ -41,8 +41,13 @@ use core::ops::{Add, Neg, Sub}; use subtle::Choice; use subtle::ConditionallySelectable; +use unsafe_target_feature::unsafe_target_feature; + use crate::edwards; -use crate::window::{LookupTable, NafLookupTable5, NafLookupTable8}; +use crate::window::{LookupTable, NafLookupTable5}; + +#[cfg(any(feature = "precomputed-tables", feature = "alloc"))] +use crate::window::NafLookupTable8; use crate::traits::Identity; @@ -59,12 +64,14 @@ use super::field::{FieldElement2625x4, Lanes, Shuffle}; #[derive(Copy, Clone, Debug)] pub struct ExtendedPoint(pub(super) FieldElement2625x4); +#[unsafe_target_feature("avx2")] impl From for ExtendedPoint { fn from(P: edwards::EdwardsPoint) -> ExtendedPoint { ExtendedPoint(FieldElement2625x4::new(&P.X, &P.Y, &P.Z, &P.T)) } } +#[unsafe_target_feature("avx2")] impl From for edwards::EdwardsPoint { fn from(P: ExtendedPoint) -> edwards::EdwardsPoint { let tmp = P.0.split(); @@ -77,6 +84,7 @@ impl From for edwards::EdwardsPoint { } } +#[unsafe_target_feature("avx2")] impl ConditionallySelectable for ExtendedPoint { fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { ExtendedPoint(FieldElement2625x4::conditional_select(&a.0, &b.0, choice)) @@ -87,18 +95,21 @@ impl ConditionallySelectable for ExtendedPoint { } } +#[unsafe_target_feature("avx2")] impl Default for ExtendedPoint { fn default() -> ExtendedPoint { ExtendedPoint::identity() } } +#[unsafe_target_feature("avx2")] impl Identity for ExtendedPoint { fn identity() -> ExtendedPoint { constants::EXTENDEDPOINT_IDENTITY } } +#[unsafe_target_feature("avx2")] impl ExtendedPoint { /// Compute the double of this point. pub fn double(&self) -> ExtendedPoint { @@ -184,6 +195,7 @@ impl ExtendedPoint { #[derive(Copy, Clone, Debug)] pub struct CachedPoint(pub(super) FieldElement2625x4); +#[unsafe_target_feature("avx2")] impl From for CachedPoint { fn from(P: ExtendedPoint) -> CachedPoint { let mut x = P.0; @@ -202,18 +214,21 @@ impl From for CachedPoint { } } +#[unsafe_target_feature("avx2")] impl Default for CachedPoint { fn default() -> CachedPoint { CachedPoint::identity() } } +#[unsafe_target_feature("avx2")] impl Identity for CachedPoint { fn identity() -> CachedPoint { constants::CACHEDPOINT_IDENTITY } } +#[unsafe_target_feature("avx2")] impl ConditionallySelectable for CachedPoint { fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { CachedPoint(FieldElement2625x4::conditional_select(&a.0, &b.0, choice)) @@ -224,6 +239,7 @@ impl ConditionallySelectable for CachedPoint { } } +#[unsafe_target_feature("avx2")] impl<'a> Neg for &'a CachedPoint { type Output = CachedPoint; /// Lazily negate the point. 
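The runtime selection that drives all of these AVX2-gated items lives in `get_selected_backend` in `src/backend/mod.rs` above. A minimal standalone sketch of that pattern, assuming only the `cpufeatures` crate and a hypothetical `select_backend_sketch` name:

```rust
#[cfg(target_arch = "x86_64")]
fn select_backend_sketch() -> &'static str {
    // `cpufeatures::new!` generates a module whose `init()` probes
    // CPUID once and caches the result, so repeated calls are cheap.
    cpufeatures::new!(cpuid_avx512, "avx512ifma", "avx512vl");
    cpufeatures::new!(cpuid_avx2, "avx2");

    // Prefer the widest instruction set the CPU actually supports;
    // the real function additionally gates each arm behind `cfg` so
    // that backends which were not compiled in can never be selected.
    if cpuid_avx512::init().get() {
        "avx512"
    } else if cpuid_avx2::init().get() {
        "avx2"
    } else {
        "serial"
    }
}
```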
@@ -238,6 +254,7 @@ impl<'a> Neg for &'a CachedPoint { } } +#[unsafe_target_feature("avx2")] impl<'a, 'b> Add<&'b CachedPoint> for &'a ExtendedPoint { type Output = ExtendedPoint; @@ -275,6 +292,7 @@ impl<'a, 'b> Add<&'b CachedPoint> for &'a ExtendedPoint { } } +#[unsafe_target_feature("avx2")] impl<'a, 'b> Sub<&'b CachedPoint> for &'a ExtendedPoint { type Output = ExtendedPoint; @@ -288,6 +306,7 @@ impl<'a, 'b> Sub<&'b CachedPoint> for &'a ExtendedPoint { } } +#[unsafe_target_feature("avx2")] impl<'a> From<&'a edwards::EdwardsPoint> for LookupTable { fn from(point: &'a edwards::EdwardsPoint) -> Self { let P = ExtendedPoint::from(*point); @@ -299,6 +318,7 @@ impl<'a> From<&'a edwards::EdwardsPoint> for LookupTable { } } +#[unsafe_target_feature("avx2")] impl<'a> From<&'a edwards::EdwardsPoint> for NafLookupTable5 { fn from(point: &'a edwards::EdwardsPoint) -> Self { let A = ExtendedPoint::from(*point); @@ -312,6 +332,8 @@ impl<'a> From<&'a edwards::EdwardsPoint> for NafLookupTable5 { } } +#[cfg(any(feature = "precomputed-tables", feature = "alloc"))] +#[unsafe_target_feature("avx2")] impl<'a> From<&'a edwards::EdwardsPoint> for NafLookupTable8 { fn from(point: &'a edwards::EdwardsPoint) -> Self { let A = ExtendedPoint::from(*point); @@ -325,6 +347,7 @@ impl<'a> From<&'a edwards::EdwardsPoint> for NafLookupTable8 { } } +#[cfg(target_feature = "avx2")] #[cfg(test)] mod test { use super::*; @@ -524,6 +547,7 @@ mod test { doubling_test_helper(P); } + #[cfg(any(feature = "precomputed-tables", feature = "alloc"))] #[test] fn basepoint_odd_lookup_table_verify() { use crate::backend::vector::avx2::constants::BASEPOINT_ODD_LOOKUP_TABLE; diff --git a/src/backend/vector/avx2/field.rs b/src/backend/vector/avx2/field.rs index 614c32750..bdb55efa5 100644 --- a/src/backend/vector/avx2/field.rs +++ b/src/backend/vector/avx2/field.rs @@ -48,6 +48,8 @@ use crate::backend::vector::avx2::constants::{ P_TIMES_16_HI, P_TIMES_16_LO, P_TIMES_2_HI, P_TIMES_2_LO, }; +use unsafe_target_feature::unsafe_target_feature; + /// Unpack 32-bit lanes into 64-bit lanes: /// ```ascii,no_run /// (a0, b0, a1, b1, c0, d0, c1, d1) @@ -57,6 +59,7 @@ use crate::backend::vector::avx2::constants::{ /// (a0, 0, b0, 0, c0, 0, d0, 0) /// (a1, 0, b1, 0, c1, 0, d1, 0) /// ``` +#[unsafe_target_feature("avx2")] #[inline(always)] fn unpack_pair(src: u32x8) -> (u32x8, u32x8) { let a: u32x8; @@ -80,6 +83,7 @@ fn unpack_pair(src: u32x8) -> (u32x8, u32x8) { /// ```ascii,no_run /// (a0, b0, a1, b1, c0, d0, c1, d1) /// ``` +#[unsafe_target_feature("avx2")] #[inline(always)] fn repack_pair(x: u32x8, y: u32x8) -> u32x8 { unsafe { @@ -151,6 +155,7 @@ pub struct FieldElement2625x4(pub(crate) [u32x8; 5]); use subtle::Choice; use subtle::ConditionallySelectable; +#[unsafe_target_feature("avx2")] impl ConditionallySelectable for FieldElement2625x4 { fn conditional_select( a: &FieldElement2625x4, @@ -179,6 +184,7 @@ impl ConditionallySelectable for FieldElement2625x4 { } } +#[unsafe_target_feature("avx2")] impl FieldElement2625x4 { pub const ZERO: FieldElement2625x4 = FieldElement2625x4([u32x8::splat_const::<0>(); 5]); @@ -675,6 +681,7 @@ impl FieldElement2625x4 { } } +#[unsafe_target_feature("avx2")] impl Neg for FieldElement2625x4 { type Output = FieldElement2625x4; @@ -703,6 +710,7 @@ impl Neg for FieldElement2625x4 { } } +#[unsafe_target_feature("avx2")] impl Add for FieldElement2625x4 { type Output = FieldElement2625x4; /// Add two `FieldElement2625x4`s, without performing a reduction. 
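The `#[unsafe_target_feature("avx2")]` attribute used throughout these hunks comes from the `unsafe_target_feature` crate added to Cargo.toml. It builds on the compiler's `#[target_feature]` mechanism, sketched below with a hypothetical `add_lanes` function: the annotated function is compiled with AVX2 codegen even when the rest of the crate is not, and calling it is `unsafe` because the caller must guarantee CPU support, which is exactly what the runtime detection establishes.

```rust
#[cfg(target_arch = "x86_64")]
mod sketch {
    use core::arch::x86_64::{__m256i, _mm256_add_epi32};

    /// Compiled with AVX2 enabled regardless of the global target
    /// settings. Safety: the caller must have verified that the CPU
    /// supports AVX2 (e.g. via `cpufeatures`); calling this on an
    /// older CPU is undefined behavior (in practice, SIGILL).
    #[target_feature(enable = "avx2")]
    pub unsafe fn add_lanes(a: __m256i, b: __m256i) -> __m256i {
        _mm256_add_epi32(a, b)
    }
}
```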
@@ -718,6 +726,7 @@ impl Add for FieldElement2625x4 { } } +#[unsafe_target_feature("avx2")] impl Mul<(u32, u32, u32, u32)> for FieldElement2625x4 { type Output = FieldElement2625x4; /// Perform a multiplication by a vector of small constants. @@ -750,6 +759,7 @@ impl Mul<(u32, u32, u32, u32)> for FieldElement2625x4 { } } +#[unsafe_target_feature("avx2")] impl<'a, 'b> Mul<&'b FieldElement2625x4> for &'a FieldElement2625x4 { type Output = FieldElement2625x4; /// Multiply `self` by `rhs`. @@ -860,6 +870,7 @@ impl<'a, 'b> Mul<&'b FieldElement2625x4> for &'a FieldElement2625x4 { } } +#[cfg(target_feature = "avx2")] #[cfg(test)] mod test { use super::*; diff --git a/src/backend/vector/avx2/mod.rs b/src/backend/vector/avx2/mod.rs index b3e2d14ea..fba39f05c 100644 --- a/src/backend/vector/avx2/mod.rs +++ b/src/backend/vector/avx2/mod.rs @@ -16,3 +16,5 @@ pub(crate) mod field; pub(crate) mod edwards; pub(crate) mod constants; + +pub(crate) use self::edwards::{CachedPoint, ExtendedPoint}; diff --git a/src/backend/vector/ifma/edwards.rs b/src/backend/vector/ifma/edwards.rs index 5bdc3ce07..ccfe092c8 100644 --- a/src/backend/vector/ifma/edwards.rs +++ b/src/backend/vector/ifma/edwards.rs @@ -16,8 +16,13 @@ use core::ops::{Add, Neg, Sub}; use subtle::Choice; use subtle::ConditionallySelectable; +use unsafe_target_feature::unsafe_target_feature; + use crate::edwards; -use crate::window::{LookupTable, NafLookupTable5, NafLookupTable8}; +use crate::window::{LookupTable, NafLookupTable5}; + +#[cfg(any(feature = "precomputed-tables", feature = "alloc"))] +use crate::window::NafLookupTable8; use super::constants; use super::field::{F51x4Reduced, F51x4Unreduced, Lanes, Shuffle}; @@ -28,12 +33,14 @@ pub struct ExtendedPoint(pub(super) F51x4Unreduced); #[derive(Copy, Clone, Debug)] pub struct CachedPoint(pub(super) F51x4Reduced); +#[unsafe_target_feature("avx512ifma,avx512vl")] impl From for ExtendedPoint { fn from(P: edwards::EdwardsPoint) -> ExtendedPoint { ExtendedPoint(F51x4Unreduced::new(&P.X, &P.Y, &P.Z, &P.T)) } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl From for edwards::EdwardsPoint { fn from(P: ExtendedPoint) -> edwards::EdwardsPoint { let reduced = F51x4Reduced::from(P.0); @@ -47,6 +54,7 @@ impl From for edwards::EdwardsPoint { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl From for CachedPoint { fn from(P: ExtendedPoint) -> CachedPoint { let mut x = P.0; @@ -59,18 +67,21 @@ impl From for CachedPoint { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl Default for ExtendedPoint { fn default() -> ExtendedPoint { ExtendedPoint::identity() } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl Identity for ExtendedPoint { fn identity() -> ExtendedPoint { constants::EXTENDEDPOINT_IDENTITY } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl ExtendedPoint { pub fn double(&self) -> ExtendedPoint { // (Y1 X1 T1 Z1) -- uses vpshufd (1c latency @ 1/c) @@ -122,6 +133,7 @@ impl ExtendedPoint { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl<'a, 'b> Add<&'b CachedPoint> for &'a ExtendedPoint { type Output = ExtendedPoint; @@ -151,18 +163,21 @@ impl<'a, 'b> Add<&'b CachedPoint> for &'a ExtendedPoint { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl Default for CachedPoint { fn default() -> CachedPoint { CachedPoint::identity() } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl Identity for CachedPoint { fn identity() -> CachedPoint { constants::CACHEDPOINT_IDENTITY } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl ConditionallySelectable 
for CachedPoint { fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { CachedPoint(F51x4Reduced::conditional_select(&a.0, &b.0, choice)) @@ -173,6 +188,7 @@ impl ConditionallySelectable for CachedPoint { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl<'a> Neg for &'a CachedPoint { type Output = CachedPoint; @@ -182,6 +198,7 @@ impl<'a> Neg for &'a CachedPoint { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl<'a, 'b> Sub<&'b CachedPoint> for &'a ExtendedPoint { type Output = ExtendedPoint; @@ -191,6 +208,7 @@ impl<'a, 'b> Sub<&'b CachedPoint> for &'a ExtendedPoint { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl<'a> From<&'a edwards::EdwardsPoint> for LookupTable { fn from(point: &'a edwards::EdwardsPoint) -> Self { let P = ExtendedPoint::from(*point); @@ -202,6 +220,7 @@ impl<'a> From<&'a edwards::EdwardsPoint> for LookupTable { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl<'a> From<&'a edwards::EdwardsPoint> for NafLookupTable5 { fn from(point: &'a edwards::EdwardsPoint) -> Self { let A = ExtendedPoint::from(*point); @@ -215,6 +234,8 @@ impl<'a> From<&'a edwards::EdwardsPoint> for NafLookupTable5 { } } +#[cfg(any(feature = "precomputed-tables", feature = "alloc"))] +#[unsafe_target_feature("avx512ifma,avx512vl")] impl<'a> From<&'a edwards::EdwardsPoint> for NafLookupTable8 { fn from(point: &'a edwards::EdwardsPoint) -> Self { let A = ExtendedPoint::from(*point); @@ -228,6 +249,7 @@ impl<'a> From<&'a edwards::EdwardsPoint> for NafLookupTable8 { } } +#[cfg(target_feature = "avx512ifma,avx512vl")] #[cfg(test)] mod test { use super::*; diff --git a/src/backend/vector/ifma/field.rs b/src/backend/vector/ifma/field.rs index fd1955315..5928e14a2 100644 --- a/src/backend/vector/ifma/field.rs +++ b/src/backend/vector/ifma/field.rs @@ -16,15 +16,19 @@ use core::ops::{Add, Mul, Neg}; use crate::backend::serial::u64::field::FieldElement51; +use unsafe_target_feature::unsafe_target_feature; + /// A wrapper around `vpmadd52luq` that works on `u64x4`. -#[inline(always)] +#[unsafe_target_feature("avx512ifma,avx512vl")] +#[inline] unsafe fn madd52lo(z: u64x4, x: u64x4, y: u64x4) -> u64x4 { use core::arch::x86_64::_mm256_madd52lo_epu64; _mm256_madd52lo_epu64(z.into(), x.into(), y.into()).into() } /// A wrapper around `vpmadd52huq` that works on `u64x4`. 
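// A hypothetical scalar model (not part of this patch) of what the
// wrapped IFMA intrinsics compute per 64-bit lane: the low 52 bits of
// x and y are multiplied into a 104-bit product, and either its low
// (`vpmadd52luq`) or high (`vpmadd52huq`) 52 bits are added into z.
fn madd52lo_model(z: u64, x: u64, y: u64) -> u64 {
    let mask = (1u64 << 52) - 1;
    let prod = ((x & mask) as u128) * ((y & mask) as u128);
    z.wrapping_add((prod as u64) & mask)
}
fn madd52hi_model(z: u64, x: u64, y: u64) -> u64 {
    let mask = (1u64 << 52) - 1;
    let prod = ((x & mask) as u128) * ((y & mask) as u128);
    z.wrapping_add((prod >> 52) as u64)
}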
-#[inline(always)] +#[unsafe_target_feature("avx512ifma,avx512vl")] +#[inline] unsafe fn madd52hi(z: u64x4, x: u64x4, y: u64x4) -> u64x4 { use core::arch::x86_64::_mm256_madd52hi_epu64; _mm256_madd52hi_epu64(z.into(), x.into(), y.into()).into() @@ -53,6 +57,7 @@ pub enum Shuffle { CACA, } +#[unsafe_target_feature("avx512ifma,avx512vl")] #[inline(always)] fn shuffle_lanes(x: u64x4, control: Shuffle) -> u64x4 { unsafe { @@ -84,6 +89,7 @@ pub enum Lanes { BCD, } +#[unsafe_target_feature("avx512ifma,avx512vl")] #[inline] fn blend_lanes(x: u64x4, y: u64x4, control: Lanes) -> u64x4 { unsafe { @@ -100,6 +106,7 @@ fn blend_lanes(x: u64x4, y: u64x4, control: Lanes) -> u64x4 { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl F51x4Unreduced { pub const ZERO: F51x4Unreduced = F51x4Unreduced([u64x4::splat_const::<0>(); 5]); @@ -198,6 +205,7 @@ impl F51x4Unreduced { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl Neg for F51x4Reduced { type Output = F51x4Reduced; @@ -209,6 +217,7 @@ impl Neg for F51x4Reduced { use subtle::Choice; use subtle::ConditionallySelectable; +#[unsafe_target_feature("avx512ifma,avx512vl")] impl ConditionallySelectable for F51x4Reduced { #[inline] fn conditional_select(a: &F51x4Reduced, b: &F51x4Reduced, choice: Choice) -> F51x4Reduced { @@ -235,6 +244,7 @@ impl ConditionallySelectable for F51x4Reduced { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl F51x4Reduced { #[inline] pub fn shuffle(&self, control: Shuffle) -> F51x4Reduced { @@ -373,6 +383,7 @@ impl F51x4Reduced { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl From for F51x4Unreduced { #[inline] fn from(x: F51x4Reduced) -> F51x4Unreduced { @@ -380,6 +391,7 @@ impl From for F51x4Unreduced { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl From for F51x4Reduced { #[inline] fn from(x: F51x4Unreduced) -> F51x4Reduced { @@ -405,6 +417,7 @@ impl From for F51x4Reduced { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl Add for F51x4Unreduced { type Output = F51x4Unreduced; #[inline] @@ -419,6 +432,7 @@ impl Add for F51x4Unreduced { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl<'a> Mul<(u32, u32, u32, u32)> for &'a F51x4Reduced { type Output = F51x4Unreduced; #[inline] @@ -470,6 +484,7 @@ impl<'a> Mul<(u32, u32, u32, u32)> for &'a F51x4Reduced { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl<'a, 'b> Mul<&'b F51x4Reduced> for &'a F51x4Reduced { type Output = F51x4Unreduced; #[inline] @@ -614,6 +629,7 @@ impl<'a, 'b> Mul<&'b F51x4Reduced> for &'a F51x4Reduced { } } +#[cfg(target_feature = "avx512ifma,avx512vl")] #[cfg(test)] mod test { use super::*; diff --git a/src/backend/vector/ifma/mod.rs b/src/backend/vector/ifma/mod.rs index 79a61ff3b..f48748d21 100644 --- a/src/backend/vector/ifma/mod.rs +++ b/src/backend/vector/ifma/mod.rs @@ -16,3 +16,5 @@ pub mod field; pub mod edwards; pub mod constants; + +pub(crate) use self::edwards::{CachedPoint, ExtendedPoint}; diff --git a/src/backend/vector/mod.rs b/src/backend/vector/mod.rs index 51c9e81e3..d720f4acb 100644 --- a/src/backend/vector/mod.rs +++ b/src/backend/vector/mod.rs @@ -11,60 +11,13 @@ #![doc = include_str!("../../../docs/parallel-formulas.md")] -#[cfg(not(any( - target_feature = "avx2", - all(target_feature = "avx512ifma", nightly), - docsrs -)))] -compile_error!("'simd' backend selected without target_feature=+avx2 or +avx512ifma"); - #[allow(missing_docs)] pub mod packed_simd; -#[cfg(any( - all( - target_feature = "avx2", - not(all(target_feature = "avx512ifma", nightly)) - ), - all(docsrs, 
target_arch = "x86_64") -))] +#[cfg(feature = "simd_avx2")] pub mod avx2; -#[cfg(any( - all( - target_feature = "avx2", - not(all(target_feature = "avx512ifma", nightly)) - ), - all(docsrs, target_arch = "x86_64") -))] -pub(crate) use self::avx2::{edwards::CachedPoint, edwards::ExtendedPoint}; -#[cfg(any( - all(target_feature = "avx512ifma", nightly), - all(docsrs, target_arch = "x86_64") -))] +#[cfg(all(feature = "simd_avx512", nightly))] pub mod ifma; -#[cfg(all(target_feature = "avx512ifma", nightly))] -pub(crate) use self::ifma::{edwards::CachedPoint, edwards::ExtendedPoint}; -#[cfg(any( - target_feature = "avx2", - all(target_feature = "avx512ifma", nightly), - all(docsrs, target_arch = "x86_64") -))] -#[allow(missing_docs)] pub mod scalar_mul; - -// Precomputed table re-exports - -#[cfg(any( - all( - target_feature = "avx2", - not(all(target_feature = "avx512ifma", nightly)), - feature = "precomputed-tables" - ), - all(docsrs, target_arch = "x86_64") -))] -pub(crate) use self::avx2::constants::BASEPOINT_ODD_LOOKUP_TABLE; - -#[cfg(all(target_feature = "avx512ifma", nightly, feature = "precomputed-tables"))] -pub(crate) use self::ifma::constants::BASEPOINT_ODD_LOOKUP_TABLE; diff --git a/src/backend/vector/packed_simd.rs b/src/backend/vector/packed_simd.rs index 6a3484d72..2491754e4 100644 --- a/src/backend/vector/packed_simd.rs +++ b/src/backend/vector/packed_simd.rs @@ -11,6 +11,8 @@ ///! by the callers of this code. use core::ops::{Add, AddAssign, BitAnd, BitAndAssign, BitXor, BitXorAssign, Sub}; +use unsafe_target_feature::unsafe_target_feature; + macro_rules! impl_shared { ( $ty:ident, @@ -26,6 +28,7 @@ macro_rules! impl_shared { #[repr(transparent)] pub struct $ty(core::arch::x86_64::__m256i); + #[unsafe_target_feature("avx2")] impl From<$ty> for core::arch::x86_64::__m256i { #[inline] fn from(value: $ty) -> core::arch::x86_64::__m256i { @@ -33,6 +36,7 @@ macro_rules! impl_shared { } } + #[unsafe_target_feature("avx2")] impl From for $ty { #[inline] fn from(value: core::arch::x86_64::__m256i) -> $ty { @@ -40,6 +44,7 @@ macro_rules! impl_shared { } } + #[unsafe_target_feature("avx2")] impl PartialEq for $ty { #[inline] fn eq(&self, rhs: &$ty) -> bool { @@ -72,6 +77,7 @@ macro_rules! impl_shared { impl Eq for $ty {} + #[unsafe_target_feature("avx2")] impl Add for $ty { type Output = Self; @@ -81,6 +87,8 @@ macro_rules! impl_shared { } } + #[allow(clippy::assign_op_pattern)] + #[unsafe_target_feature("avx2")] impl AddAssign for $ty { #[inline] fn add_assign(&mut self, rhs: $ty) { @@ -88,6 +96,7 @@ macro_rules! impl_shared { } } + #[unsafe_target_feature("avx2")] impl Sub for $ty { type Output = Self; @@ -97,6 +106,7 @@ macro_rules! impl_shared { } } + #[unsafe_target_feature("avx2")] impl BitAnd for $ty { type Output = Self; @@ -106,6 +116,7 @@ macro_rules! impl_shared { } } + #[unsafe_target_feature("avx2")] impl BitXor for $ty { type Output = Self; @@ -115,6 +126,8 @@ macro_rules! impl_shared { } } + #[allow(clippy::assign_op_pattern)] + #[unsafe_target_feature("avx2")] impl BitAndAssign for $ty { #[inline] fn bitand_assign(&mut self, rhs: $ty) { @@ -122,6 +135,8 @@ macro_rules! impl_shared { } } + #[allow(clippy::assign_op_pattern)] + #[unsafe_target_feature("avx2")] impl BitXorAssign for $ty { #[inline] fn bitxor_assign(&mut self, rhs: $ty) { @@ -129,6 +144,7 @@ macro_rules! impl_shared { } } + #[unsafe_target_feature("avx2")] #[allow(dead_code)] impl $ty { #[inline] @@ -152,6 +168,7 @@ macro_rules! impl_shared { macro_rules! 
impl_conv { ($src:ident => $($dst:ident),+) => { $( + #[unsafe_target_feature("avx2")] impl From<$src> for $dst { #[inline] fn from(value: $src) -> $dst { @@ -235,8 +252,9 @@ impl u64x4 { } /// Constructs a new instance. + #[unsafe_target_feature("avx2")] #[inline] - pub fn new(x0: u64, x1: u64, x2: u64, x3: u64) -> Self { + pub fn new(x0: u64, x1: u64, x2: u64, x3: u64) -> u64x4 { unsafe { // _mm256_set_epi64 sets the underlying vector in reverse order of the args Self(core::arch::x86_64::_mm256_set_epi64x( @@ -246,8 +264,9 @@ impl u64x4 { } /// Constructs a new instance with all of the elements initialized to the given value. + #[unsafe_target_feature("avx2")] #[inline] - pub fn splat(x: u64) -> Self { + pub fn splat(x: u64) -> u64x4 { unsafe { Self(core::arch::x86_64::_mm256_set1_epi64x(x as i64)) } } } @@ -257,6 +276,7 @@ impl u32x8 { /// A constified variant of `new`. /// /// Should only be called from `const` contexts. At runtime `new` is going to be faster. + #[allow(clippy::too_many_arguments)] #[inline] pub const fn new_const( x0: u32, @@ -282,8 +302,10 @@ impl u32x8 { } /// Constructs a new instance. + #[allow(clippy::too_many_arguments)] + #[unsafe_target_feature("avx2")] #[inline] - pub fn new(x0: u32, x1: u32, x2: u32, x3: u32, x4: u32, x5: u32, x6: u32, x7: u32) -> Self { + pub fn new(x0: u32, x1: u32, x2: u32, x3: u32, x4: u32, x5: u32, x6: u32, x7: u32) -> u32x8 { unsafe { // _mm256_set_epi32 sets the underlying vector in reverse order of the args Self(core::arch::x86_64::_mm256_set_epi32( @@ -294,11 +316,15 @@ impl u32x8 { } /// Constructs a new instance with all of the elements initialized to the given value. + #[unsafe_target_feature("avx2")] #[inline] - pub fn splat(x: u32) -> Self { + pub fn splat(x: u32) -> u32x8 { unsafe { Self(core::arch::x86_64::_mm256_set1_epi32(x as i32)) } } +} +#[unsafe_target_feature("avx2")] +impl u32x8 { /// Multiplies the low unsigned 32-bits from each packed 64-bit element /// and returns the unsigned 64-bit results. /// diff --git a/src/backend/vector/scalar_mul/pippenger.rs b/src/backend/vector/scalar_mul/pippenger.rs index f7c161620..6d4b5aaa7 100644 --- a/src/backend/vector/scalar_mul/pippenger.rs +++ b/src/backend/vector/scalar_mul/pippenger.rs @@ -9,12 +9,23 @@ #![allow(non_snake_case)] +#[unsafe_target_feature::unsafe_target_feature_specialize( + conditional("avx2", feature = "simd_avx2"), + conditional("avx512ifma,avx512vl", all(feature = "simd_avx512", nightly)) +)] +pub mod spec { + use alloc::vec::Vec; use core::borrow::Borrow; use core::cmp::Ordering; -use crate::backend::vector::{CachedPoint, ExtendedPoint}; +#[for_target_feature("avx2")] +use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint}; + +#[for_target_feature("avx512ifma")] +use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint}; + use crate::edwards::EdwardsPoint; use crate::scalar::Scalar; use crate::traits::{Identity, VartimeMultiscalarMul}; @@ -49,7 +60,7 @@ impl VartimeMultiscalarMul for Pippenger { // Collect optimized scalars and points in a buffer for repeated access // (scanning the whole collection per each digit position). 
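// A hypothetical helper (not part of this patch) modeling the signed
// radix-2^w recoding done by `Scalar::as_radix_2w`: digits are
// recentered into [-2^(w-1), 2^(w-1)], which is why only 2^w / 2
// buckets are needed per digit position.
fn to_signed_radix_2w(mut n: i64, w: u32) -> Vec<i64> {
    let base = 1i64 << w;
    let mut digits = Vec::new();
    while n != 0 {
        let mut d = n % base;
        if d >= base / 2 {
            d -= base; // recenter and carry into the next digit
        }
        digits.push(d);
        n = (n - d) >> w;
    }
    digits
}
// e.g. to_signed_radix_2w(250, 4) == [-6, 0, 1], since -6 + 0*16 + 1*256 == 250.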
- let scalars = scalars.into_iter().map(|s| s.borrow().as_radix_2w(w)); + let scalars = scalars.map(|s| s.borrow().as_radix_2w(w)); let points = points .into_iter() @@ -127,12 +138,12 @@ impl VartimeMultiscalarMul for Pippenger { #[cfg(test)] mod test { - use super::*; - use crate::constants; - use crate::scalar::Scalar; - #[test] fn test_vartime_pippenger() { + use super::*; + use crate::constants; + use crate::scalar::Scalar; + // Reuse points across different tests let mut n = 512; let x = Scalar::from(2128506u64).invert(); @@ -163,3 +174,5 @@ mod test { } } } + +} diff --git a/src/backend/vector/scalar_mul/precomputed_straus.rs b/src/backend/vector/scalar_mul/precomputed_straus.rs index 359846173..8c7d725b4 100644 --- a/src/backend/vector/scalar_mul/precomputed_straus.rs +++ b/src/backend/vector/scalar_mul/precomputed_straus.rs @@ -11,12 +11,23 @@ #![allow(non_snake_case)] +#[unsafe_target_feature::unsafe_target_feature_specialize( + conditional("avx2", feature = "simd_avx2"), + conditional("avx512ifma,avx512vl", all(feature = "simd_avx512", nightly)) +)] +pub mod spec { + use alloc::vec::Vec; use core::borrow::Borrow; use core::cmp::Ordering; -use crate::backend::vector::{CachedPoint, ExtendedPoint}; +#[for_target_feature("avx2")] +use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint}; + +#[for_target_feature("avx512ifma")] +use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint}; + use crate::edwards::EdwardsPoint; use crate::scalar::Scalar; use crate::traits::Identity; @@ -33,7 +44,7 @@ impl VartimePrecomputedMultiscalarMul for VartimePrecomputedStraus { fn new(static_points: I) -> Self where I: IntoIterator, - I::Item: Borrow, + I::Item: Borrow, { Self { static_lookup_tables: static_points @@ -48,13 +59,13 @@ impl VartimePrecomputedMultiscalarMul for VartimePrecomputedStraus { static_scalars: I, dynamic_scalars: J, dynamic_points: K, - ) -> Option + ) -> Option where I: IntoIterator, I::Item: Borrow, J: IntoIterator, J::Item: Borrow, - K: IntoIterator>, + K: IntoIterator>, { let static_nafs = static_scalars .into_iter() @@ -113,3 +124,5 @@ impl VartimePrecomputedMultiscalarMul for VartimePrecomputedStraus { Some(R.into()) } } + +} diff --git a/src/backend/vector/scalar_mul/straus.rs b/src/backend/vector/scalar_mul/straus.rs index 693415361..1f3e784e2 100644 --- a/src/backend/vector/scalar_mul/straus.rs +++ b/src/backend/vector/scalar_mul/straus.rs @@ -11,6 +11,12 @@ #![allow(non_snake_case)] +#[unsafe_target_feature::unsafe_target_feature_specialize( + conditional("avx2", feature = "simd_avx2"), + conditional("avx512ifma,avx512vl", all(feature = "simd_avx512", nightly)) +)] +pub mod spec { + use alloc::vec::Vec; use core::borrow::Borrow; @@ -18,7 +24,12 @@ use core::cmp::Ordering; use zeroize::Zeroizing; -use crate::backend::vector::{CachedPoint, ExtendedPoint}; +#[for_target_feature("avx2")] +use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint}; + +#[for_target_feature("avx512ifma")] +use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint}; + use crate::edwards::EdwardsPoint; use crate::scalar::Scalar; use crate::traits::{Identity, MultiscalarMul, VartimeMultiscalarMul}; @@ -110,3 +121,5 @@ impl VartimeMultiscalarMul for Straus { Some(Q.into()) } } + +} diff --git a/src/backend/vector/scalar_mul/variable_base.rs b/src/backend/vector/scalar_mul/variable_base.rs index 52e855dd1..0653d2709 100644 --- a/src/backend/vector/scalar_mul/variable_base.rs +++ b/src/backend/vector/scalar_mul/variable_base.rs @@ -1,6 +1,17 @@ #![allow(non_snake_case)] -use 
crate::backend::vector::{CachedPoint, ExtendedPoint}; +#[unsafe_target_feature::unsafe_target_feature_specialize( + conditional("avx2", feature = "simd_avx2"), + conditional("avx512ifma,avx512vl", all(feature = "simd_avx512", nightly)) +)] +pub mod spec { + +#[for_target_feature("avx2")] +use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint}; + +#[for_target_feature("avx512ifma")] +use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint}; + use crate::edwards::EdwardsPoint; use crate::scalar::Scalar; use crate::traits::Identity; @@ -30,3 +41,5 @@ pub fn mul(point: &EdwardsPoint, scalar: &Scalar) -> EdwardsPoint { } Q.into() } + +} diff --git a/src/backend/vector/scalar_mul/vartime_double_base.rs b/src/backend/vector/scalar_mul/vartime_double_base.rs index 5ec69ed52..842a729ef 100644 --- a/src/backend/vector/scalar_mul/vartime_double_base.rs +++ b/src/backend/vector/scalar_mul/vartime_double_base.rs @@ -11,9 +11,28 @@ #![allow(non_snake_case)] +#[unsafe_target_feature::unsafe_target_feature_specialize( + conditional("avx2", feature = "simd_avx2"), + conditional("avx512ifma,avx512vl", all(feature = "simd_avx512", nightly)) +)] +pub mod spec { + use core::cmp::Ordering; -use crate::backend::vector::{CachedPoint, ExtendedPoint}; +#[for_target_feature("avx2")] +use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint}; + +#[for_target_feature("avx512ifma")] +use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint}; + +#[cfg(feature = "precomputed-tables")] +#[for_target_feature("avx2")] +use crate::backend::vector::avx2::constants::BASEPOINT_ODD_LOOKUP_TABLE; + +#[cfg(feature = "precomputed-tables")] +#[for_target_feature("avx512ifma")] +use crate::backend::vector::ifma::constants::BASEPOINT_ODD_LOOKUP_TABLE; + use crate::edwards::EdwardsPoint; use crate::scalar::Scalar; use crate::traits::Identity; @@ -40,7 +59,8 @@ pub fn mul(a: &Scalar, A: &EdwardsPoint, b: &Scalar) -> EdwardsPoint { let table_A = NafLookupTable5::::from(A); #[cfg(feature = "precomputed-tables")] - let table_B = &crate::backend::vector::BASEPOINT_ODD_LOOKUP_TABLE; + let table_B = &BASEPOINT_ODD_LOOKUP_TABLE; + #[cfg(not(feature = "precomputed-tables"))] let table_B = &NafLookupTable5::::from(&crate::constants::ED25519_BASEPOINT_POINT); @@ -77,3 +97,5 @@ pub fn mul(a: &Scalar, A: &EdwardsPoint, b: &Scalar) -> EdwardsPoint { Q.into() } + +} diff --git a/src/edwards.rs b/src/edwards.rs index fae296f66..5d799cd6d 100644 --- a/src/edwards.rs +++ b/src/edwards.rs @@ -144,17 +144,6 @@ use crate::traits::MultiscalarMul; #[cfg(feature = "alloc")] use crate::traits::{VartimeMultiscalarMul, VartimePrecomputedMultiscalarMul}; -#[cfg(not(all( - curve25519_dalek_backend = "simd", - any(target_feature = "avx2", target_feature = "avx512ifma") -)))] -use crate::backend::serial::scalar_mul; -#[cfg(all( - curve25519_dalek_backend = "simd", - any(target_feature = "avx2", target_feature = "avx512ifma") -))] -use crate::backend::vector::scalar_mul; - // ------------------------------------------------------------------------ // Compressed points // ------------------------------------------------------------------------ @@ -696,7 +685,7 @@ impl<'a, 'b> Mul<&'b Scalar> for &'a EdwardsPoint { /// For scalar multiplication of a basepoint, /// `EdwardsBasepointTable` is approximately 4x faster. 
fn mul(self, scalar: &'b Scalar) -> EdwardsPoint { - scalar_mul::variable_base::mul(self, scalar) + crate::backend::variable_base_mul(self, scalar) } } @@ -793,7 +782,7 @@ impl MultiscalarMul for EdwardsPoint { // size-dependent algorithm dispatch, use this as the hint. let _size = s_lo; - scalar_mul::straus::Straus::multiscalar_mul(scalars, points) + crate::backend::straus_multiscalar_mul(scalars, points) } } @@ -825,9 +814,9 @@ impl VartimeMultiscalarMul for EdwardsPoint { let size = s_lo; if size < 190 { - scalar_mul::straus::Straus::optional_multiscalar_mul(scalars, points) + crate::backend::straus_optional_multiscalar_mul(scalars, points) } else { - scalar_mul::pippenger::Pippenger::optional_multiscalar_mul(scalars, points) + crate::backend::pippenger_optional_multiscalar_mul(scalars, points) } } } @@ -837,7 +826,7 @@ impl VartimeMultiscalarMul for EdwardsPoint { // decouple stability of the inner type from the stability of the // outer type. #[cfg(feature = "alloc")] -pub struct VartimeEdwardsPrecomputation(scalar_mul::precomputed_straus::VartimePrecomputedStraus); +pub struct VartimeEdwardsPrecomputation(crate::backend::VartimePrecomputedStraus); #[cfg(feature = "alloc")] impl VartimePrecomputedMultiscalarMul for VartimeEdwardsPrecomputation { @@ -848,7 +837,7 @@ impl VartimePrecomputedMultiscalarMul for VartimeEdwardsPrecomputation { I: IntoIterator, I::Item: Borrow, { - Self(scalar_mul::precomputed_straus::VartimePrecomputedStraus::new(static_points)) + Self(crate::backend::VartimePrecomputedStraus::new(static_points)) } fn optional_mixed_multiscalar_mul( @@ -876,7 +865,7 @@ impl EdwardsPoint { A: &EdwardsPoint, b: &Scalar, ) -> EdwardsPoint { - scalar_mul::vartime_double_base::mul(a, A, b) + crate::backend::vartime_double_base_mul(a, A, b) } } diff --git a/src/lib.rs b/src/lib.rs index 83ccdadd4..ecbbe5a2b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,13 +11,13 @@ #![no_std] #![cfg_attr( - all( - curve25519_dalek_backend = "simd", - target_feature = "avx512ifma", - nightly - ), + all(target_arch = "x86_64", feature = "simd_avx512", nightly), feature(stdsimd) )] +#![cfg_attr( + all(target_arch = "x86_64", feature = "simd_avx512", nightly), + feature(avx512_target_feature) +)] #![cfg_attr(docsrs, feature(doc_auto_cfg, doc_cfg, doc_cfg_hide))] #![cfg_attr(docsrs, doc(cfg_hide(docsrs)))] //------------------------------------------------------------------------ diff --git a/src/ristretto.rs b/src/ristretto.rs index 705bb91d4..198320519 100644 --- a/src/ristretto.rs +++ b/src/ristretto.rs @@ -180,9 +180,6 @@ use digest::Digest; use crate::constants; use crate::field::FieldElement; -#[cfg(feature = "alloc")] -use cfg_if::cfg_if; - use subtle::Choice; use subtle::ConditionallyNegatable; use subtle::ConditionallySelectable; @@ -203,18 +200,6 @@ use crate::traits::Identity; #[cfg(feature = "alloc")] use crate::traits::{MultiscalarMul, VartimeMultiscalarMul, VartimePrecomputedMultiscalarMul}; -#[cfg(feature = "alloc")] -cfg_if! { - if #[cfg(all( - curve25519_dalek_backend = "simd", - any(target_feature = "avx2", target_feature = "avx512ifma") - ))] { - use crate::backend::vector::scalar_mul; - } else { - use crate::backend::serial::scalar_mul; - } -} - // ------------------------------------------------------------------------ // Compressed points // ------------------------------------------------------------------------ @@ -999,7 +984,7 @@ impl VartimeMultiscalarMul for RistrettoPoint { // decouple stability of the inner type from the stability of the // outer type. 
#[cfg(feature = "alloc")] -pub struct VartimeRistrettoPrecomputation(scalar_mul::precomputed_straus::VartimePrecomputedStraus); +pub struct VartimeRistrettoPrecomputation(crate::backend::VartimePrecomputedStraus); #[cfg(feature = "alloc")] impl VartimePrecomputedMultiscalarMul for VartimeRistrettoPrecomputation { @@ -1010,11 +995,9 @@ impl VartimePrecomputedMultiscalarMul for VartimeRistrettoPrecomputation { I: IntoIterator, I::Item: Borrow, { - Self( - scalar_mul::precomputed_straus::VartimePrecomputedStraus::new( - static_points.into_iter().map(|P| P.borrow().0), - ), - ) + Self(crate::backend::VartimePrecomputedStraus::new( + static_points.into_iter().map(|P| P.borrow().0), + )) } fn optional_mixed_multiscalar_mul( From 219995dbc9070383fd1eccb1929d7d420c5098fb Mon Sep 17 00:00:00 2001 From: Jan Bujak Date: Tue, 11 Apr 2023 11:13:28 +0000 Subject: [PATCH 03/11] `rustfmt src/backend/vector/scalar_mul` (no changes besides formatting) --- src/backend/vector/scalar_mul/pippenger.rs | 297 +++++++++--------- .../vector/scalar_mul/precomputed_straus.rs | 181 ++++++----- src/backend/vector/scalar_mul/straus.rs | 191 ++++++----- .../vector/scalar_mul/variable_base.rs | 63 ++-- .../vector/scalar_mul/vartime_double_base.rs | 124 ++++---- 5 files changed, 426 insertions(+), 430 deletions(-) diff --git a/src/backend/vector/scalar_mul/pippenger.rs b/src/backend/vector/scalar_mul/pippenger.rs index 6d4b5aaa7..b00cb87c5 100644 --- a/src/backend/vector/scalar_mul/pippenger.rs +++ b/src/backend/vector/scalar_mul/pippenger.rs @@ -15,164 +15,163 @@ )] pub mod spec { -use alloc::vec::Vec; - -use core::borrow::Borrow; -use core::cmp::Ordering; - -#[for_target_feature("avx2")] -use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint}; - -#[for_target_feature("avx512ifma")] -use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint}; - -use crate::edwards::EdwardsPoint; -use crate::scalar::Scalar; -use crate::traits::{Identity, VartimeMultiscalarMul}; - -/// Implements a version of Pippenger's algorithm. -/// -/// See the documentation in the serial `scalar_mul::pippenger` module for details. -pub struct Pippenger; - -impl VartimeMultiscalarMul for Pippenger { - type Point = EdwardsPoint; - - fn optional_multiscalar_mul(scalars: I, points: J) -> Option - where - I: IntoIterator, - I::Item: Borrow, - J: IntoIterator>, - { - let mut scalars = scalars.into_iter(); - let size = scalars.by_ref().size_hint().0; - let w = if size < 500 { - 6 - } else if size < 800 { - 7 - } else { - 8 - }; - - let max_digit: usize = 1 << w; - let digits_count: usize = Scalar::to_radix_2w_size_hint(w); - let buckets_count: usize = max_digit / 2; // digits are signed+centered hence 2^w/2, excluding 0-th bucket - - // Collect optimized scalars and points in a buffer for repeated access - // (scanning the whole collection per each digit position). - let scalars = scalars.map(|s| s.borrow().as_radix_2w(w)); - - let points = points - .into_iter() - .map(|p| p.map(|P| CachedPoint::from(ExtendedPoint::from(P)))); - - let scalars_points = scalars - .zip(points) - .map(|(s, maybe_p)| maybe_p.map(|p| (s, p))) - .collect::>>()?; - - // Prepare 2^w/2 buckets. - // buckets[i] corresponds to a multiplication factor (i+1). - let mut buckets: Vec = (0..buckets_count) - .map(|_| ExtendedPoint::identity()) - .collect(); - - let mut columns = (0..digits_count).rev().map(|digit_index| { - // Clear the buckets when processing another digit. 
- for bucket in &mut buckets { - *bucket = ExtendedPoint::identity(); - } + use alloc::vec::Vec; + + use core::borrow::Borrow; + use core::cmp::Ordering; + + #[for_target_feature("avx2")] + use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint}; + + #[for_target_feature("avx512ifma")] + use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint}; + + use crate::edwards::EdwardsPoint; + use crate::scalar::Scalar; + use crate::traits::{Identity, VartimeMultiscalarMul}; + + /// Implements a version of Pippenger's algorithm. + /// + /// See the documentation in the serial `scalar_mul::pippenger` module for details. + pub struct Pippenger; + + impl VartimeMultiscalarMul for Pippenger { + type Point = EdwardsPoint; + + fn optional_multiscalar_mul(scalars: I, points: J) -> Option + where + I: IntoIterator, + I::Item: Borrow, + J: IntoIterator>, + { + let mut scalars = scalars.into_iter(); + let size = scalars.by_ref().size_hint().0; + let w = if size < 500 { + 6 + } else if size < 800 { + 7 + } else { + 8 + }; + + let max_digit: usize = 1 << w; + let digits_count: usize = Scalar::to_radix_2w_size_hint(w); + let buckets_count: usize = max_digit / 2; // digits are signed+centered hence 2^w/2, excluding 0-th bucket + + // Collect optimized scalars and points in a buffer for repeated access + // (scanning the whole collection per each digit position). + let scalars = scalars.map(|s| s.borrow().as_radix_2w(w)); + + let points = points + .into_iter() + .map(|p| p.map(|P| CachedPoint::from(ExtendedPoint::from(P)))); + + let scalars_points = scalars + .zip(points) + .map(|(s, maybe_p)| maybe_p.map(|p| (s, p))) + .collect::>>()?; + + // Prepare 2^w/2 buckets. + // buckets[i] corresponds to a multiplication factor (i+1). + let mut buckets: Vec = (0..buckets_count) + .map(|_| ExtendedPoint::identity()) + .collect(); + + let mut columns = (0..digits_count).rev().map(|digit_index| { + // Clear the buckets when processing another digit. + for bucket in &mut buckets { + *bucket = ExtendedPoint::identity(); + } - // Iterate over pairs of (point, scalar) - // and add/sub the point to the corresponding bucket. - // Note: if we add support for precomputed lookup tables, - // we'll be adding/subtractiong point premultiplied by `digits[i]` to buckets[0]. - for (digits, pt) in scalars_points.iter() { - // Widen digit so that we don't run into edge cases when w=8. - let digit = digits[digit_index] as i16; - match digit.cmp(&0) { - Ordering::Greater => { - let b = (digit - 1) as usize; - buckets[b] = &buckets[b] + pt; - } - Ordering::Less => { - let b = (-digit - 1) as usize; - buckets[b] = &buckets[b] - pt; + // Iterate over pairs of (point, scalar) + // and add/sub the point to the corresponding bucket. + // Note: if we add support for precomputed lookup tables, + // we'll be adding/subtractiong point premultiplied by `digits[i]` to buckets[0]. + for (digits, pt) in scalars_points.iter() { + // Widen digit so that we don't run into edge cases when w=8. + let digit = digits[digit_index] as i16; + match digit.cmp(&0) { + Ordering::Greater => { + let b = (digit - 1) as usize; + buckets[b] = &buckets[b] + pt; + } + Ordering::Less => { + let b = (-digit - 1) as usize; + buckets[b] = &buckets[b] - pt; + } + Ordering::Equal => {} } - Ordering::Equal => {} } - } - // Add the buckets applying the multiplication factor to each bucket. - // The most efficient way to do that is to have a single sum with two running sums: - // an intermediate sum from last bucket to the first, and a sum of intermediate sums. 
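// The same two-running-sums trick that the diagram below illustrates, as a
// toy integer check (sketch only, not library code): with buckets [a, b, c]
// standing for 1*a + 2*b + 3*c,
//
//     let (a, b, c) = (5i64, 7, 11);
//     let mut inter = c;          // running sum, last bucket downward
//     let mut total = c;
//     inter += b; total += inter; // total = c + (c + b)
//     inter += a; total += inter; // total = c + (c + b) + (c + b + a)
//     assert_eq!(total, a + 2 * b + 3 * c);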
- // - // For example, to add buckets 1*A, 2*B, 3*C we need to add these points: - // C - // C B - // C B A Sum = C + (C+B) + (C+B+A) - let mut buckets_intermediate_sum = buckets[buckets_count - 1]; - let mut buckets_sum = buckets[buckets_count - 1]; - for i in (0..(buckets_count - 1)).rev() { - buckets_intermediate_sum = - &buckets_intermediate_sum + &CachedPoint::from(buckets[i]); - buckets_sum = &buckets_sum + &CachedPoint::from(buckets_intermediate_sum); - } + // Add the buckets applying the multiplication factor to each bucket. + // The most efficient way to do that is to have a single sum with two running sums: + // an intermediate sum from last bucket to the first, and a sum of intermediate sums. + // + // For example, to add buckets 1*A, 2*B, 3*C we need to add these points: + // C + // C B + // C B A Sum = C + (C+B) + (C+B+A) + let mut buckets_intermediate_sum = buckets[buckets_count - 1]; + let mut buckets_sum = buckets[buckets_count - 1]; + for i in (0..(buckets_count - 1)).rev() { + buckets_intermediate_sum = + &buckets_intermediate_sum + &CachedPoint::from(buckets[i]); + buckets_sum = &buckets_sum + &CachedPoint::from(buckets_intermediate_sum); + } - buckets_sum - }); + buckets_sum + }); - // Take the high column as an initial value to avoid wasting time doubling the identity element in `fold()`. - // `unwrap()` always succeeds because we know we have more than zero digits. - let hi_column = columns.next().unwrap(); + // Take the high column as an initial value to avoid wasting time doubling the identity element in `fold()`. + // `unwrap()` always succeeds because we know we have more than zero digits. + let hi_column = columns.next().unwrap(); - Some( - columns - .fold(hi_column, |total, p| { - &total.mul_by_pow_2(w as u32) + &CachedPoint::from(p) - }) - .into(), - ) + Some( + columns + .fold(hi_column, |total, p| { + &total.mul_by_pow_2(w as u32) + &CachedPoint::from(p) + }) + .into(), + ) + } } -} -#[cfg(test)] -mod test { - #[test] - fn test_vartime_pippenger() { - use super::*; - use crate::constants; - use crate::scalar::Scalar; - - // Reuse points across different tests - let mut n = 512; - let x = Scalar::from(2128506u64).invert(); - let y = Scalar::from(4443282u64).invert(); - let points: Vec<_> = (0..n) - .map(|i| constants::ED25519_BASEPOINT_POINT * Scalar::from(1 + i as u64)) - .collect(); - let scalars: Vec<_> = (0..n) - .map(|i| x + (Scalar::from(i as u64) * y)) // fast way to make ~random but deterministic scalars - .collect(); - - let premultiplied: Vec = scalars - .iter() - .zip(points.iter()) - .map(|(sc, pt)| sc * pt) - .collect(); - - while n > 0 { - let scalars = &scalars[0..n].to_vec(); - let points = &points[0..n].to_vec(); - let control: EdwardsPoint = premultiplied[0..n].iter().sum(); - - let subject = Pippenger::vartime_multiscalar_mul(scalars.clone(), points.clone()); - - assert_eq!(subject.compress(), control.compress()); - - n = n / 2; + #[cfg(test)] + mod test { + #[test] + fn test_vartime_pippenger() { + use super::*; + use crate::constants; + use crate::scalar::Scalar; + + // Reuse points across different tests + let mut n = 512; + let x = Scalar::from(2128506u64).invert(); + let y = Scalar::from(4443282u64).invert(); + let points: Vec<_> = (0..n) + .map(|i| constants::ED25519_BASEPOINT_POINT * Scalar::from(1 + i as u64)) + .collect(); + let scalars: Vec<_> = (0..n) + .map(|i| x + (Scalar::from(i as u64) * y)) // fast way to make ~random but deterministic scalars + .collect(); + + let premultiplied: Vec = scalars + .iter() + .zip(points.iter()) 
+ .map(|(sc, pt)| sc * pt) + .collect(); + + while n > 0 { + let scalars = &scalars[0..n].to_vec(); + let points = &points[0..n].to_vec(); + let control: EdwardsPoint = premultiplied[0..n].iter().sum(); + + let subject = Pippenger::vartime_multiscalar_mul(scalars.clone(), points.clone()); + + assert_eq!(subject.compress(), control.compress()); + + n = n / 2; + } } } } - -} diff --git a/src/backend/vector/scalar_mul/precomputed_straus.rs b/src/backend/vector/scalar_mul/precomputed_straus.rs index 8c7d725b4..8c45c29cf 100644 --- a/src/backend/vector/scalar_mul/precomputed_straus.rs +++ b/src/backend/vector/scalar_mul/precomputed_straus.rs @@ -17,112 +17,111 @@ )] pub mod spec { -use alloc::vec::Vec; + use alloc::vec::Vec; -use core::borrow::Borrow; -use core::cmp::Ordering; + use core::borrow::Borrow; + use core::cmp::Ordering; -#[for_target_feature("avx2")] -use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint}; + #[for_target_feature("avx2")] + use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint}; -#[for_target_feature("avx512ifma")] -use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint}; + #[for_target_feature("avx512ifma")] + use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint}; -use crate::edwards::EdwardsPoint; -use crate::scalar::Scalar; -use crate::traits::Identity; -use crate::traits::VartimePrecomputedMultiscalarMul; -use crate::window::{NafLookupTable5, NafLookupTable8}; + use crate::edwards::EdwardsPoint; + use crate::scalar::Scalar; + use crate::traits::Identity; + use crate::traits::VartimePrecomputedMultiscalarMul; + use crate::window::{NafLookupTable5, NafLookupTable8}; -pub struct VartimePrecomputedStraus { - static_lookup_tables: Vec>, -} + pub struct VartimePrecomputedStraus { + static_lookup_tables: Vec>, + } -impl VartimePrecomputedMultiscalarMul for VartimePrecomputedStraus { - type Point = EdwardsPoint; + impl VartimePrecomputedMultiscalarMul for VartimePrecomputedStraus { + type Point = EdwardsPoint; + + fn new(static_points: I) -> Self + where + I: IntoIterator, + I::Item: Borrow, + { + Self { + static_lookup_tables: static_points + .into_iter() + .map(|P| NafLookupTable8::::from(P.borrow())) + .collect(), + } + } - fn new(static_points: I) -> Self - where - I: IntoIterator, - I::Item: Borrow, - { - Self { - static_lookup_tables: static_points + fn optional_mixed_multiscalar_mul( + &self, + static_scalars: I, + dynamic_scalars: J, + dynamic_points: K, + ) -> Option + where + I: IntoIterator, + I::Item: Borrow, + J: IntoIterator, + J::Item: Borrow, + K: IntoIterator>, + { + let static_nafs = static_scalars .into_iter() - .map(|P| NafLookupTable8::::from(P.borrow())) - .collect(), - } - } + .map(|c| c.borrow().non_adjacent_form(5)) + .collect::>(); + let dynamic_nafs: Vec<_> = dynamic_scalars + .into_iter() + .map(|c| c.borrow().non_adjacent_form(5)) + .collect::>(); - fn optional_mixed_multiscalar_mul( - &self, - static_scalars: I, - dynamic_scalars: J, - dynamic_points: K, - ) -> Option - where - I: IntoIterator, - I::Item: Borrow, - J: IntoIterator, - J::Item: Borrow, - K: IntoIterator>, - { - let static_nafs = static_scalars - .into_iter() - .map(|c| c.borrow().non_adjacent_form(5)) - .collect::>(); - let dynamic_nafs: Vec<_> = dynamic_scalars - .into_iter() - .map(|c| c.borrow().non_adjacent_form(5)) - .collect::>(); - - let dynamic_lookup_tables = dynamic_points - .into_iter() - .map(|P_opt| P_opt.map(|P| NafLookupTable5::::from(&P))) - .collect::>>()?; - - let sp = self.static_lookup_tables.len(); - let dp = 
dynamic_lookup_tables.len(); - assert_eq!(sp, static_nafs.len()); - assert_eq!(dp, dynamic_nafs.len()); - - // We could save some doublings by looking for the highest - // nonzero NAF coefficient, but since we might have a lot of - // them to search, it's not clear it's worthwhile to check. - let mut R = ExtendedPoint::identity(); - for j in (0..256).rev() { - R = R.double(); - - for i in 0..dp { - let t_ij = dynamic_nafs[i][j]; - match t_ij.cmp(&0) { - Ordering::Greater => { - R = &R + &dynamic_lookup_tables[i].select(t_ij as usize); - } - Ordering::Less => { - R = &R - &dynamic_lookup_tables[i].select(-t_ij as usize); + let dynamic_lookup_tables = dynamic_points + .into_iter() + .map(|P_opt| P_opt.map(|P| NafLookupTable5::::from(&P))) + .collect::>>()?; + + let sp = self.static_lookup_tables.len(); + let dp = dynamic_lookup_tables.len(); + assert_eq!(sp, static_nafs.len()); + assert_eq!(dp, dynamic_nafs.len()); + + // We could save some doublings by looking for the highest + // nonzero NAF coefficient, but since we might have a lot of + // them to search, it's not clear it's worthwhile to check. + let mut R = ExtendedPoint::identity(); + for j in (0..256).rev() { + R = R.double(); + + for i in 0..dp { + let t_ij = dynamic_nafs[i][j]; + match t_ij.cmp(&0) { + Ordering::Greater => { + R = &R + &dynamic_lookup_tables[i].select(t_ij as usize); + } + Ordering::Less => { + R = &R - &dynamic_lookup_tables[i].select(-t_ij as usize); + } + Ordering::Equal => {} } - Ordering::Equal => {} } - } - #[allow(clippy::needless_range_loop)] - for i in 0..sp { - let t_ij = static_nafs[i][j]; - match t_ij.cmp(&0) { - Ordering::Greater => { - R = &R + &self.static_lookup_tables[i].select(t_ij as usize); + #[allow(clippy::needless_range_loop)] + for i in 0..sp { + let t_ij = static_nafs[i][j]; + match t_ij.cmp(&0) { + Ordering::Greater => { + R = &R + &self.static_lookup_tables[i].select(t_ij as usize); + } + Ordering::Less => { + R = &R - &self.static_lookup_tables[i].select(-t_ij as usize); + } + Ordering::Equal => {} } - Ordering::Less => { - R = &R - &self.static_lookup_tables[i].select(-t_ij as usize); - } - Ordering::Equal => {} } } - } - Some(R.into()) + Some(R.into()) + } } } - -} diff --git a/src/backend/vector/scalar_mul/straus.rs b/src/backend/vector/scalar_mul/straus.rs index 1f3e784e2..046bcd14c 100644 --- a/src/backend/vector/scalar_mul/straus.rs +++ b/src/backend/vector/scalar_mul/straus.rs @@ -17,109 +17,108 @@ )] pub mod spec { -use alloc::vec::Vec; - -use core::borrow::Borrow; -use core::cmp::Ordering; - -use zeroize::Zeroizing; - -#[for_target_feature("avx2")] -use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint}; - -#[for_target_feature("avx512ifma")] -use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint}; - -use crate::edwards::EdwardsPoint; -use crate::scalar::Scalar; -use crate::traits::{Identity, MultiscalarMul, VartimeMultiscalarMul}; -use crate::window::{LookupTable, NafLookupTable5}; - -/// Multiscalar multiplication using interleaved window / Straus' -/// method. See the `Straus` struct in the serial backend for more -/// details. -/// -/// This exists as a seperate implementation from that one because the -/// AVX2 code uses different curve models (it does not pass between -/// multiple models during scalar mul), and it has to convert the -/// point representation on the fly. 
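// Rough cost arithmetic for the interleaving (not from the patch, just
// reading off the loops below): with n inputs in radix 16, the main loop
// runs 64 iterations, each doing one shared `mul_by_pow_2(4)` (~256
// doublings in total) plus one table select-and-add per input, i.e. about
// 256 + 64*n point operations instead of n separate runs of ~256 doublings
// and 64 additions each.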
-pub struct Straus {} - -impl MultiscalarMul for Straus { - type Point = EdwardsPoint; - - fn multiscalar_mul(scalars: I, points: J) -> EdwardsPoint - where - I: IntoIterator, - I::Item: Borrow, - J: IntoIterator, - J::Item: Borrow, - { - // Construct a lookup table of [P,2P,3P,4P,5P,6P,7P,8P] - // for each input point P - let lookup_tables: Vec<_> = points - .into_iter() - .map(|point| LookupTable::::from(point.borrow())) - .collect(); - - let scalar_digits_vec: Vec<_> = scalars - .into_iter() - .map(|s| s.borrow().as_radix_16()) - .collect(); - // Pass ownership to a `Zeroizing` wrapper - let scalar_digits = Zeroizing::new(scalar_digits_vec); - - let mut Q = ExtendedPoint::identity(); - for j in (0..64).rev() { - Q = Q.mul_by_pow_2(4); - let it = scalar_digits.iter().zip(lookup_tables.iter()); - for (s_i, lookup_table_i) in it { - // Q = Q + s_{i,j} * P_i - Q = &Q + &lookup_table_i.select(s_i[j]); + use alloc::vec::Vec; + + use core::borrow::Borrow; + use core::cmp::Ordering; + + use zeroize::Zeroizing; + + #[for_target_feature("avx2")] + use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint}; + + #[for_target_feature("avx512ifma")] + use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint}; + + use crate::edwards::EdwardsPoint; + use crate::scalar::Scalar; + use crate::traits::{Identity, MultiscalarMul, VartimeMultiscalarMul}; + use crate::window::{LookupTable, NafLookupTable5}; + + /// Multiscalar multiplication using interleaved window / Straus' + /// method. See the `Straus` struct in the serial backend for more + /// details. + /// + /// This exists as a seperate implementation from that one because the + /// AVX2 code uses different curve models (it does not pass between + /// multiple models during scalar mul), and it has to convert the + /// point representation on the fly. 
+ pub struct Straus {} + + impl MultiscalarMul for Straus { + type Point = EdwardsPoint; + + fn multiscalar_mul(scalars: I, points: J) -> EdwardsPoint + where + I: IntoIterator, + I::Item: Borrow, + J: IntoIterator, + J::Item: Borrow, + { + // Construct a lookup table of [P,2P,3P,4P,5P,6P,7P,8P] + // for each input point P + let lookup_tables: Vec<_> = points + .into_iter() + .map(|point| LookupTable::::from(point.borrow())) + .collect(); + + let scalar_digits_vec: Vec<_> = scalars + .into_iter() + .map(|s| s.borrow().as_radix_16()) + .collect(); + // Pass ownership to a `Zeroizing` wrapper + let scalar_digits = Zeroizing::new(scalar_digits_vec); + + let mut Q = ExtendedPoint::identity(); + for j in (0..64).rev() { + Q = Q.mul_by_pow_2(4); + let it = scalar_digits.iter().zip(lookup_tables.iter()); + for (s_i, lookup_table_i) in it { + // Q = Q + s_{i,j} * P_i + Q = &Q + &lookup_table_i.select(s_i[j]); + } } + Q.into() } - Q.into() } -} -impl VartimeMultiscalarMul for Straus { - type Point = EdwardsPoint; - - fn optional_multiscalar_mul(scalars: I, points: J) -> Option - where - I: IntoIterator, - I::Item: Borrow, - J: IntoIterator>, - { - let nafs: Vec<_> = scalars - .into_iter() - .map(|c| c.borrow().non_adjacent_form(5)) - .collect(); - let lookup_tables: Vec<_> = points - .into_iter() - .map(|P_opt| P_opt.map(|P| NafLookupTable5::::from(&P))) - .collect::>>()?; - - let mut Q = ExtendedPoint::identity(); - - for i in (0..256).rev() { - Q = Q.double(); - - for (naf, lookup_table) in nafs.iter().zip(lookup_tables.iter()) { - match naf[i].cmp(&0) { - Ordering::Greater => { - Q = &Q + &lookup_table.select(naf[i] as usize); + impl VartimeMultiscalarMul for Straus { + type Point = EdwardsPoint; + + fn optional_multiscalar_mul(scalars: I, points: J) -> Option + where + I: IntoIterator, + I::Item: Borrow, + J: IntoIterator>, + { + let nafs: Vec<_> = scalars + .into_iter() + .map(|c| c.borrow().non_adjacent_form(5)) + .collect(); + let lookup_tables: Vec<_> = points + .into_iter() + .map(|P_opt| P_opt.map(|P| NafLookupTable5::::from(&P))) + .collect::>>()?; + + let mut Q = ExtendedPoint::identity(); + + for i in (0..256).rev() { + Q = Q.double(); + + for (naf, lookup_table) in nafs.iter().zip(lookup_tables.iter()) { + match naf[i].cmp(&0) { + Ordering::Greater => { + Q = &Q + &lookup_table.select(naf[i] as usize); + } + Ordering::Less => { + Q = &Q - &lookup_table.select(-naf[i] as usize); + } + Ordering::Equal => {} } - Ordering::Less => { - Q = &Q - &lookup_table.select(-naf[i] as usize); - } - Ordering::Equal => {} } } - } - Some(Q.into()) + Some(Q.into()) + } } } - -} diff --git a/src/backend/vector/scalar_mul/variable_base.rs b/src/backend/vector/scalar_mul/variable_base.rs index 0653d2709..2da479926 100644 --- a/src/backend/vector/scalar_mul/variable_base.rs +++ b/src/backend/vector/scalar_mul/variable_base.rs @@ -6,40 +6,39 @@ )] pub mod spec { -#[for_target_feature("avx2")] -use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint}; + #[for_target_feature("avx2")] + use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint}; -#[for_target_feature("avx512ifma")] -use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint}; + #[for_target_feature("avx512ifma")] + use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint}; -use crate::edwards::EdwardsPoint; -use crate::scalar::Scalar; -use crate::traits::Identity; -use crate::window::LookupTable; + use crate::edwards::EdwardsPoint; + use crate::scalar::Scalar; + use crate::traits::Identity; + use crate::window::LookupTable; 
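// The radix-16 loop in `mul` below is Horner's rule in base 16. The same
// recomposition on plain integers, as a sketch (i128 only fits a short
// digit string; a full 64-digit scalar would need a bigint):
//
//     fn recompose(digits: &[i8]) -> i128 {
//         digits.iter().rev().fold(0, |acc, &d| 16 * acc + i128::from(d))
//     }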
-/// Perform constant-time, variable-base scalar multiplication.
-pub fn mul(point: &EdwardsPoint, scalar: &Scalar) -> EdwardsPoint {
-    // Construct a lookup table of [P,2P,3P,4P,5P,6P,7P,8P]
-    let lookup_table = LookupTable::<CachedPoint>::from(point);
-    // Setting s = scalar, compute
-    //
-    // s = s_0 + s_1*16^1 + ... + s_63*16^63,
-    //
-    // with `-8 ≤ s_i < 8` for `0 ≤ i < 63` and `-8 ≤ s_63 ≤ 8`.
-    let scalar_digits = scalar.as_radix_16();
-    // Compute s*P as
-    //
-    // s*P = P*(s_0 + s_1*16^1 + s_2*16^2 + ... + s_63*16^63)
-    // s*P = P*s_0 + P*s_1*16^1 + P*s_2*16^2 + ... + P*s_63*16^63
-    // s*P = P*s_0 + 16*(P*s_1 + 16*(P*s_2 + 16*( ... + P*s_63)...))
-    //
-    // We sum right-to-left.
-    let mut Q = ExtendedPoint::identity();
-    for i in (0..64).rev() {
-        Q = Q.mul_by_pow_2(4);
-        Q = &Q + &lookup_table.select(scalar_digits[i]);

+    /// Perform constant-time, variable-base scalar multiplication.
+    pub fn mul(point: &EdwardsPoint, scalar: &Scalar) -> EdwardsPoint {
+        // Construct a lookup table of [P,2P,3P,4P,5P,6P,7P,8P]
+        let lookup_table = LookupTable::<CachedPoint>::from(point);
+        // Setting s = scalar, compute
+        //
+        // s = s_0 + s_1*16^1 + ... + s_63*16^63,
+        //
+        // with `-8 ≤ s_i < 8` for `0 ≤ i < 63` and `-8 ≤ s_63 ≤ 8`.
+        let scalar_digits = scalar.as_radix_16();
+        // Compute s*P as
+        //
+        // s*P = P*(s_0 + s_1*16^1 + s_2*16^2 + ... + s_63*16^63)
+        // s*P = P*s_0 + P*s_1*16^1 + P*s_2*16^2 + ... + P*s_63*16^63
+        // s*P = P*s_0 + 16*(P*s_1 + 16*(P*s_2 + 16*( ... + P*s_63)...))
+        //
+        // We sum right-to-left.
+        let mut Q = ExtendedPoint::identity();
+        for i in (0..64).rev() {
+            Q = Q.mul_by_pow_2(4);
+            Q = &Q + &lookup_table.select(scalar_digits[i]);
+        }
+        Q.into()
     }
-    Q.into()
-}
-
 }
diff --git a/src/backend/vector/scalar_mul/vartime_double_base.rs b/src/backend/vector/scalar_mul/vartime_double_base.rs
index 842a729ef..191572bb1 100644
--- a/src/backend/vector/scalar_mul/vartime_double_base.rs
+++ b/src/backend/vector/scalar_mul/vartime_double_base.rs
@@ -17,85 +17,85 @@
 )]
 pub mod spec {
-use core::cmp::Ordering;
+    use core::cmp::Ordering;
 
-#[for_target_feature("avx2")]
-use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint};
+    #[for_target_feature("avx2")]
+    use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint};
 
-#[for_target_feature("avx512ifma")]
-use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint};
+    #[for_target_feature("avx512ifma")]
+    use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint};
 
-#[cfg(feature = "precomputed-tables")]
-#[for_target_feature("avx2")]
-use crate::backend::vector::avx2::constants::BASEPOINT_ODD_LOOKUP_TABLE;
-
-#[cfg(feature = "precomputed-tables")]
-#[for_target_feature("avx512ifma")]
-use crate::backend::vector::ifma::constants::BASEPOINT_ODD_LOOKUP_TABLE;
-
-use crate::edwards::EdwardsPoint;
-use crate::scalar::Scalar;
-use crate::traits::Identity;
-use crate::window::NafLookupTable5;
-
-/// Compute \\(aA + bB\\) in variable time, where \\(B\\) is the Ed25519 basepoint.
-pub fn mul(a: &Scalar, A: &EdwardsPoint, b: &Scalar) -> EdwardsPoint { - let a_naf = a.non_adjacent_form(5); + #[cfg(feature = "precomputed-tables")] + #[for_target_feature("avx2")] + use crate::backend::vector::avx2::constants::BASEPOINT_ODD_LOOKUP_TABLE; #[cfg(feature = "precomputed-tables")] - let b_naf = b.non_adjacent_form(8); - #[cfg(not(feature = "precomputed-tables"))] - let b_naf = b.non_adjacent_form(5); - - // Find starting index - let mut i: usize = 255; - for j in (0..256).rev() { - i = j; - if a_naf[i] != 0 || b_naf[i] != 0 { - break; + #[for_target_feature("avx512ifma")] + use crate::backend::vector::ifma::constants::BASEPOINT_ODD_LOOKUP_TABLE; + + use crate::edwards::EdwardsPoint; + use crate::scalar::Scalar; + use crate::traits::Identity; + use crate::window::NafLookupTable5; + + /// Compute \\(aA + bB\\) in variable time, where \\(B\\) is the Ed25519 basepoint. + pub fn mul(a: &Scalar, A: &EdwardsPoint, b: &Scalar) -> EdwardsPoint { + let a_naf = a.non_adjacent_form(5); + + #[cfg(feature = "precomputed-tables")] + let b_naf = b.non_adjacent_form(8); + #[cfg(not(feature = "precomputed-tables"))] + let b_naf = b.non_adjacent_form(5); + + // Find starting index + let mut i: usize = 255; + for j in (0..256).rev() { + i = j; + if a_naf[i] != 0 || b_naf[i] != 0 { + break; + } } - } - let table_A = NafLookupTable5::::from(A); + let table_A = NafLookupTable5::::from(A); - #[cfg(feature = "precomputed-tables")] - let table_B = &BASEPOINT_ODD_LOOKUP_TABLE; + #[cfg(feature = "precomputed-tables")] + let table_B = &BASEPOINT_ODD_LOOKUP_TABLE; - #[cfg(not(feature = "precomputed-tables"))] - let table_B = &NafLookupTable5::::from(&crate::constants::ED25519_BASEPOINT_POINT); + #[cfg(not(feature = "precomputed-tables"))] + let table_B = + &NafLookupTable5::::from(&crate::constants::ED25519_BASEPOINT_POINT); - let mut Q = ExtendedPoint::identity(); + let mut Q = ExtendedPoint::identity(); - loop { - Q = Q.double(); + loop { + Q = Q.double(); - match a_naf[i].cmp(&0) { - Ordering::Greater => { - Q = &Q + &table_A.select(a_naf[i] as usize); + match a_naf[i].cmp(&0) { + Ordering::Greater => { + Q = &Q + &table_A.select(a_naf[i] as usize); + } + Ordering::Less => { + Q = &Q - &table_A.select(-a_naf[i] as usize); + } + Ordering::Equal => {} } - Ordering::Less => { - Q = &Q - &table_A.select(-a_naf[i] as usize); - } - Ordering::Equal => {} - } - match b_naf[i].cmp(&0) { - Ordering::Greater => { - Q = &Q + &table_B.select(b_naf[i] as usize); + match b_naf[i].cmp(&0) { + Ordering::Greater => { + Q = &Q + &table_B.select(b_naf[i] as usize); + } + Ordering::Less => { + Q = &Q - &table_B.select(-b_naf[i] as usize); + } + Ordering::Equal => {} } - Ordering::Less => { - Q = &Q - &table_B.select(-b_naf[i] as usize); + + if i == 0 { + break; } - Ordering::Equal => {} + i -= 1; } - if i == 0 { - break; - } - i -= 1; + Q.into() } - - Q.into() -} - } From 1b6fee354d78455e71ac65c11f4c8e02fe41e0a2 Mon Sep 17 00:00:00 2001 From: Jan Bujak Date: Tue, 11 Apr 2023 11:44:32 +0000 Subject: [PATCH 04/11] Make clippy happy --- src/backend/serial/u64/constants.rs | 2 +- src/constants.rs | 2 +- src/scalar.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/backend/serial/u64/constants.rs b/src/backend/serial/u64/constants.rs index 1aaed3109..67d51492d 100644 --- a/src/backend/serial/u64/constants.rs +++ b/src/backend/serial/u64/constants.rs @@ -327,7 +327,7 @@ pub const EIGHT_TORSION_INNER_DOC_HIDDEN: [EdwardsPoint; 8] = [ /// Table containing precomputed multiples of the Ed25519 basepoint 
\\(B = (x, 4/5)\\). #[cfg(feature = "precomputed-tables")] -pub static ED25519_BASEPOINT_TABLE: &'static EdwardsBasepointTable = +pub static ED25519_BASEPOINT_TABLE: &EdwardsBasepointTable = &ED25519_BASEPOINT_TABLE_INNER_DOC_HIDDEN; /// Inner constant, used to avoid filling the docs with precomputed points. diff --git a/src/constants.rs b/src/constants.rs index 36cba9db1..93e226838 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -99,7 +99,7 @@ use crate::ristretto::RistrettoBasepointTable; /// The Ristretto basepoint, as a `RistrettoBasepointTable` for scalar multiplication. #[cfg(feature = "precomputed-tables")] -pub static RISTRETTO_BASEPOINT_TABLE: &'static RistrettoBasepointTable = unsafe { +pub static RISTRETTO_BASEPOINT_TABLE: &RistrettoBasepointTable = unsafe { // SAFETY: `RistrettoBasepointTable` is a `#[repr(transparent)]` newtype of // `EdwardsBasepointTable` &*(ED25519_BASEPOINT_TABLE as *const EdwardsBasepointTable as *const RistrettoBasepointTable) diff --git a/src/scalar.rs b/src/scalar.rs index 829f56021..406a12fee 100644 --- a/src/scalar.rs +++ b/src/scalar.rs @@ -184,7 +184,7 @@ cfg_if! { } /// The `Scalar` struct holds an element of \\(\mathbb Z / \ell\mathbb Z \\). -#[allow(clippy::derive_hash_xor_eq)] +#[allow(clippy::derived_hash_with_manual_eq)] #[derive(Copy, Clone, Hash)] pub struct Scalar { /// `bytes` is a little-endian byte encoding of an integer representing a scalar modulo the From 996b1e9077bf0554706328151d87ae18811eeb9a Mon Sep 17 00:00:00 2001 From: Jan Bujak Date: Tue, 11 Apr 2023 11:48:58 +0000 Subject: [PATCH 05/11] Make cargodoc happy --- src/backend/mod.rs | 3 +++ src/backend/vector/scalar_mul/mod.rs | 7 +++++++ 2 files changed, 10 insertions(+) diff --git a/src/backend/mod.rs b/src/backend/mod.rs index 09cfaf8bc..18c8c2251 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -99,6 +99,7 @@ fn get_selected_backend() -> BackendKind { BackendKind::Serial } +#[allow(missing_docs)] #[cfg(feature = "alloc")] pub fn pippenger_optional_multiscalar_mul(scalars: I, points: J) -> Option where @@ -209,6 +210,7 @@ impl VartimePrecomputedStraus { } } +#[allow(missing_docs)] #[cfg(feature = "alloc")] pub fn straus_multiscalar_mul(scalars: I, points: J) -> EdwardsPoint where @@ -249,6 +251,7 @@ where } } +#[allow(missing_docs)] #[cfg(feature = "alloc")] pub fn straus_optional_multiscalar_mul(scalars: I, points: J) -> Option where diff --git a/src/backend/vector/scalar_mul/mod.rs b/src/backend/vector/scalar_mul/mod.rs index 36a7047a2..fed3470e7 100644 --- a/src/backend/vector/scalar_mul/mod.rs +++ b/src/backend/vector/scalar_mul/mod.rs @@ -9,15 +9,22 @@ // - isis agora lovecruft // - Henry de Valence +//! Implementations of various multiplication algorithms for the SIMD backends. + +#[allow(missing_docs)] pub mod variable_base; +#[allow(missing_docs)] pub mod vartime_double_base; +#[allow(missing_docs)] #[cfg(feature = "alloc")] pub mod straus; +#[allow(missing_docs)] #[cfg(feature = "alloc")] pub mod precomputed_straus; +#[allow(missing_docs)] #[cfg(feature = "alloc")] pub mod pippenger; From 738cfee02019c7bc3788d8529c0a6dd3a0c9f4e3 Mon Sep 17 00:00:00 2001 From: Jan Bujak Date: Tue, 11 Apr 2023 11:55:47 +0000 Subject: [PATCH 06/11] Get rid of the `unused_unsafe` warning on old versions of Rust. 
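
On rustc 1.64 and older the body of an `unsafe fn` already counts as an
unsafe context, so an explicit inner `unsafe {}` block trips the
`unused_unsafe` lint there; newer toolchains accept it. A minimal
illustration of the warning silenced by the version gate below (sketch,
not code from this patch):

    unsafe fn read(p: *const u32) -> u32 {
        unsafe { *p } // warns on rustc <= 1.64, clean on newer compilers
    }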
--- build.rs | 7 +++++++ src/lib.rs | 1 + 2 files changed, 8 insertions(+) diff --git a/build.rs b/build.rs index 80c0eb1fb..04f4d9ca3 100644 --- a/build.rs +++ b/build.rs @@ -27,6 +27,13 @@ fn main() { { println!("cargo:rustc-cfg=nightly"); } + + let rustc_version = rustc_version::version().expect("failed to detect rustc version"); + if rustc_version.major == 1 && rustc_version.minor <= 64 { + // Old versions of Rust complain when you have an `unsafe fn` and you use `unsafe {}` inside, + // so for those we want to apply the `#[allow(unused_unsafe)]` attribute to get rid of that warning. + println!("cargo:rustc-cfg=allow_unused_unsafe"); + } } // Deterministic cfg(curve25519_dalek_bits) when this is not explicitly set. diff --git a/src/lib.rs b/src/lib.rs index ecbbe5a2b..f4d1d8223 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,7 @@ )] #![cfg_attr(docsrs, feature(doc_auto_cfg, doc_cfg, doc_cfg_hide))] #![cfg_attr(docsrs, doc(cfg_hide(docsrs)))] +#![cfg_attr(allow_unused_unsafe, allow(unused_unsafe))] //------------------------------------------------------------------------ // Documentation: //------------------------------------------------------------------------ From a7df9c7918076b30f2aa2bbd5b2b865c56f1b391 Mon Sep 17 00:00:00 2001 From: Jan Bujak Date: Wed, 17 May 2023 04:21:17 +0000 Subject: [PATCH 07/11] Remove `Self`s which don't compile anymore --- src/backend/vector/packed_simd.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/backend/vector/packed_simd.rs b/src/backend/vector/packed_simd.rs index 2491754e4..371410d6f 100644 --- a/src/backend/vector/packed_simd.rs +++ b/src/backend/vector/packed_simd.rs @@ -257,7 +257,7 @@ impl u64x4 { pub fn new(x0: u64, x1: u64, x2: u64, x3: u64) -> u64x4 { unsafe { // _mm256_set_epi64 sets the underlying vector in reverse order of the args - Self(core::arch::x86_64::_mm256_set_epi64x( + u64x4(core::arch::x86_64::_mm256_set_epi64x( x3 as i64, x2 as i64, x1 as i64, x0 as i64, )) } @@ -267,7 +267,7 @@ impl u64x4 { #[unsafe_target_feature("avx2")] #[inline] pub fn splat(x: u64) -> u64x4 { - unsafe { Self(core::arch::x86_64::_mm256_set1_epi64x(x as i64)) } + unsafe { u64x4(core::arch::x86_64::_mm256_set1_epi64x(x as i64)) } } } @@ -308,7 +308,7 @@ impl u32x8 { pub fn new(x0: u32, x1: u32, x2: u32, x3: u32, x4: u32, x5: u32, x6: u32, x7: u32) -> u32x8 { unsafe { // _mm256_set_epi32 sets the underlying vector in reverse order of the args - Self(core::arch::x86_64::_mm256_set_epi32( + u32x8(core::arch::x86_64::_mm256_set_epi32( x7 as i32, x6 as i32, x5 as i32, x4 as i32, x3 as i32, x2 as i32, x1 as i32, x0 as i32, )) @@ -319,7 +319,7 @@ impl u32x8 { #[unsafe_target_feature("avx2")] #[inline] pub fn splat(x: u32) -> u32x8 { - unsafe { Self(core::arch::x86_64::_mm256_set1_epi32(x as i32)) } + unsafe { u32x8(core::arch::x86_64::_mm256_set1_epi32(x as i32)) } } } From c67e430cfdf9699cf9b90226ab08a3b48cadacc6 Mon Sep 17 00:00:00 2001 From: Jan Bujak Date: Wed, 17 May 2023 05:06:26 +0000 Subject: [PATCH 08/11] (work-in-progress) Partially remove `unsafe_target_feature` --- src/backend/vector/avx2/field.rs | 383 ++++++++++++--------- src/backend/vector/packed_simd.rs | 25 +- src/backend/vector/scalar_mul/pippenger.rs | 330 +++++++++--------- 3 files changed, 406 insertions(+), 332 deletions(-) diff --git a/src/backend/vector/avx2/field.rs b/src/backend/vector/avx2/field.rs index bdb55efa5..5a81afc81 100644 --- a/src/backend/vector/avx2/field.rs +++ b/src/backend/vector/avx2/field.rs @@ -48,8 +48,6 @@ use 
crate::backend::vector::avx2::constants::{ P_TIMES_16_HI, P_TIMES_16_LO, P_TIMES_2_HI, P_TIMES_2_LO, }; -use unsafe_target_feature::unsafe_target_feature; - /// Unpack 32-bit lanes into 64-bit lanes: /// ```ascii,no_run /// (a0, b0, a1, b1, c0, d0, c1, d1) @@ -59,9 +57,9 @@ use unsafe_target_feature::unsafe_target_feature; /// (a0, 0, b0, 0, c0, 0, d0, 0) /// (a1, 0, b1, 0, c1, 0, d1, 0) /// ``` -#[unsafe_target_feature("avx2")] -#[inline(always)] -fn unpack_pair(src: u32x8) -> (u32x8, u32x8) { +#[target_feature(enable = "avx2")] +#[inline] +unsafe fn unpack_pair(src: u32x8) -> (u32x8, u32x8) { let a: u32x8; let b: u32x8; let zero = u32x8::splat(0); @@ -83,9 +81,9 @@ fn unpack_pair(src: u32x8) -> (u32x8, u32x8) { /// ```ascii,no_run /// (a0, b0, a1, b1, c0, d0, c1, d1) /// ``` -#[unsafe_target_feature("avx2")] -#[inline(always)] -fn repack_pair(x: u32x8, y: u32x8) -> u32x8 { +#[target_feature(enable = "avx2")] +#[inline] +unsafe fn repack_pair(x: u32x8, y: u32x8) -> u32x8 { unsafe { use core::arch::x86_64::_mm256_blend_epi32; use core::arch::x86_64::_mm256_shuffle_epi32; @@ -155,43 +153,62 @@ pub struct FieldElement2625x4(pub(crate) [u32x8; 5]); use subtle::Choice; use subtle::ConditionallySelectable; -#[unsafe_target_feature("avx2")] impl ConditionallySelectable for FieldElement2625x4 { + #[inline(always)] fn conditional_select( a: &FieldElement2625x4, b: &FieldElement2625x4, choice: Choice, ) -> FieldElement2625x4 { - let mask = (-(choice.unwrap_u8() as i32)) as u32; - let mask_vec = u32x8::splat(mask); - FieldElement2625x4([ - a.0[0] ^ (mask_vec & (a.0[0] ^ b.0[0])), - a.0[1] ^ (mask_vec & (a.0[1] ^ b.0[1])), - a.0[2] ^ (mask_vec & (a.0[2] ^ b.0[2])), - a.0[3] ^ (mask_vec & (a.0[3] ^ b.0[3])), - a.0[4] ^ (mask_vec & (a.0[4] ^ b.0[4])), - ]) + #[target_feature(enable = "avx2")] + unsafe fn inner( + a: &FieldElement2625x4, + b: &FieldElement2625x4, + choice: Choice, + ) -> FieldElement2625x4 { + let mask = (-(choice.unwrap_u8() as i32)) as u32; + let mask_vec = u32x8::splat(mask); + FieldElement2625x4([ + a.0[0] ^ (mask_vec & (a.0[0] ^ b.0[0])), + a.0[1] ^ (mask_vec & (a.0[1] ^ b.0[1])), + a.0[2] ^ (mask_vec & (a.0[2] ^ b.0[2])), + a.0[3] ^ (mask_vec & (a.0[3] ^ b.0[3])), + a.0[4] ^ (mask_vec & (a.0[4] ^ b.0[4])), + ]) + } + + unsafe { inner(a, b, choice) } } + #[inline(always)] fn conditional_assign(&mut self, other: &FieldElement2625x4, choice: Choice) { - let mask = (-(choice.unwrap_u8() as i32)) as u32; - let mask_vec = u32x8::splat(mask); - self.0[0] ^= mask_vec & (self.0[0] ^ other.0[0]); - self.0[1] ^= mask_vec & (self.0[1] ^ other.0[1]); - self.0[2] ^= mask_vec & (self.0[2] ^ other.0[2]); - self.0[3] ^= mask_vec & (self.0[3] ^ other.0[3]); - self.0[4] ^= mask_vec & (self.0[4] ^ other.0[4]); + #[target_feature(enable = "avx2")] + unsafe fn inner( + itself: &mut FieldElement2625x4, + other: &FieldElement2625x4, + choice: Choice, + ) { + let mask = (-(choice.unwrap_u8() as i32)) as u32; + let mask_vec = u32x8::splat(mask); + itself.0[0] ^= mask_vec & (itself.0[0] ^ other.0[0]); + itself.0[1] ^= mask_vec & (itself.0[1] ^ other.0[1]); + itself.0[2] ^= mask_vec & (itself.0[2] ^ other.0[2]); + itself.0[3] ^= mask_vec & (itself.0[3] ^ other.0[3]); + itself.0[4] ^= mask_vec & (itself.0[4] ^ other.0[4]); + } + + unsafe { inner(self, other, choice) } } } -#[unsafe_target_feature("avx2")] impl FieldElement2625x4 { pub const ZERO: FieldElement2625x4 = FieldElement2625x4([u32x8::splat_const::<0>(); 5]); /// Split this vector into an array of four (serial) field /// elements. 
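// The rewrite in this file follows a single pattern: each public operation
// keeps a safe, #[inline(always)] signature and forwards to a private
// `unsafe fn` carrying #[target_feature(enable = "avx2")], since stable
// Rust at this point accepts `target_feature` only on unsafe functions.
// The shape, with hypothetical names (sketch only):
//
//     pub fn op(x: Wide) -> Wide {
//         #[target_feature(enable = "avx2")]
//         unsafe fn inner(x: Wide) -> Wide {
//             x + x // AVX2-accelerated body
//         }
//         // Sound only because callers run after runtime AVX2 detection.
//         unsafe { inner(x) }
//     }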
#[rustfmt::skip] // keep alignment of extracted lanes - pub fn split(&self) -> [FieldElement51; 4] { + #[target_feature(enable = "avx2")] + pub unsafe fn split(&self) -> [FieldElement51; 4] { let mut out = [FieldElement51::ZERO; 4]; for i in 0..5 { let a_2i = self.0[i].extract::<0>() as u64; // @@ -218,7 +235,8 @@ impl FieldElement2625x4 { /// that when this function is inlined, LLVM is able to lower the /// shuffle using an immediate. #[inline] - pub fn shuffle(&self, control: Shuffle) -> FieldElement2625x4 { + #[target_feature(enable = "avx2")] + pub unsafe fn shuffle(&self, control: Shuffle) -> FieldElement2625x4 { #[inline(always)] fn shuffle_lanes(x: u32x8, control: Shuffle) -> u32x8 { unsafe { @@ -258,7 +276,8 @@ impl FieldElement2625x4 { /// that this function can be inlined and LLVM can lower it to a /// blend instruction using an immediate. #[inline] - pub fn blend(&self, other: FieldElement2625x4, control: Lanes) -> FieldElement2625x4 { + #[target_feature(enable = "avx2")] + pub unsafe fn blend(&self, other: FieldElement2625x4, control: Lanes) -> FieldElement2625x4 { #[inline(always)] fn blend_lanes(x: u32x8, y: u32x8, control: Lanes) -> u32x8 { unsafe { @@ -322,7 +341,8 @@ impl FieldElement2625x4 { } /// Convenience wrapper around `new(x,x,x,x)`. - pub fn splat(x: &FieldElement51) -> FieldElement2625x4 { + #[target_feature(enable = "avx2")] + pub unsafe fn splat(x: &FieldElement51) -> FieldElement2625x4 { FieldElement2625x4::new(x, x, x, x) } @@ -332,7 +352,8 @@ impl FieldElement2625x4 { /// /// The resulting `FieldElement2625x4` is bounded with \\( b < 0.0002 \\). #[rustfmt::skip] // keep alignment of computed lanes - pub fn new( + #[target_feature(enable = "avx2")] + pub unsafe fn new( x0: &FieldElement51, x1: &FieldElement51, x2: &FieldElement51, @@ -371,7 +392,8 @@ impl FieldElement2625x4 { /// /// The coefficients of the result are bounded with \\( b < 1 \\). #[inline] - pub fn negate_lazy(&self) -> FieldElement2625x4 { + #[target_feature(enable = "avx2")] + pub unsafe fn negate_lazy(&self) -> FieldElement2625x4 { // The limbs of self are bounded with b < 0.999, while the // smallest limb of 2*p is 67108845 > 2^{26+0.9999}, so // underflows are not possible. @@ -394,7 +416,8 @@ impl FieldElement2625x4 { /// /// The coefficients of the result are bounded with \\( b < 1.6 \\). #[inline] - pub fn diff_sum(&self) -> FieldElement2625x4 { + #[target_feature(enable = "avx2")] + pub unsafe fn diff_sum(&self) -> FieldElement2625x4 { // tmp1 = (B, A, D, C) let tmp1 = self.shuffle(Shuffle::BADC); // tmp2 = (-A, B, -C, D) @@ -409,7 +432,8 @@ impl FieldElement2625x4 { /// /// The coefficients of the result are bounded with \\( b < 0.0002 \\). #[inline] - pub fn reduce(&self) -> FieldElement2625x4 { + #[target_feature(enable = "avx2")] + pub unsafe fn reduce(&self) -> FieldElement2625x4 { let shifts = u32x8::new(26, 26, 25, 25, 26, 26, 25, 25); let masks = u32x8::new( (1 << 26) - 1, @@ -518,7 +542,8 @@ impl FieldElement2625x4 { /// The coefficients of the result are bounded with \\( b < 0.007 \\). 
#[inline] #[rustfmt::skip] // keep alignment of carry chain - fn reduce64(mut z: [u64x4; 10]) -> FieldElement2625x4 { + #[target_feature(enable = "avx2")] + unsafe fn reduce64(mut z: [u64x4; 10]) -> FieldElement2625x4 { // These aren't const because splat isn't a const fn let LOW_25_BITS: u64x4 = u64x4::splat((1 << 25) - 1); let LOW_26_BITS: u64x4 = u64x4::splat((1 << 26) - 1); @@ -594,7 +619,8 @@ impl FieldElement2625x4 { /// /// The coefficients of the result are bounded with \\( b < 0.007 \\). #[rustfmt::skip] // keep alignment of z* calculations - pub fn square_and_negate_D(&self) -> FieldElement2625x4 { + #[target_feature(enable = "avx2")] + pub unsafe fn square_and_negate_D(&self) -> FieldElement2625x4 { #[inline(always)] fn m(x: u32x8, y: u32x8) -> u64x4 { x.mul32(y) @@ -681,7 +707,6 @@ impl FieldElement2625x4 { } } -#[unsafe_target_feature("avx2")] impl Neg for FieldElement2625x4 { type Output = FieldElement2625x4; @@ -697,36 +722,46 @@ impl Neg for FieldElement2625x4 { /// # Postconditions /// /// The coefficients of the result are bounded with \\( b < 0.0002 \\). - #[inline] + #[inline(always)] fn neg(self) -> FieldElement2625x4 { - FieldElement2625x4([ - P_TIMES_16_LO - self.0[0], - P_TIMES_16_HI - self.0[1], - P_TIMES_16_HI - self.0[2], - P_TIMES_16_HI - self.0[3], - P_TIMES_16_HI - self.0[4], - ]) - .reduce() + #[inline] + #[target_feature(enable = "avx2")] + unsafe fn inner(itself: FieldElement2625x4) -> FieldElement2625x4 { + FieldElement2625x4([ + P_TIMES_16_LO - itself.0[0], + P_TIMES_16_HI - itself.0[1], + P_TIMES_16_HI - itself.0[2], + P_TIMES_16_HI - itself.0[3], + P_TIMES_16_HI - itself.0[4], + ]) + .reduce() + } + + unsafe { inner(self) } } } -#[unsafe_target_feature("avx2")] impl Add for FieldElement2625x4 { type Output = FieldElement2625x4; /// Add two `FieldElement2625x4`s, without performing a reduction. - #[inline] + #[inline(always)] fn add(self, rhs: FieldElement2625x4) -> FieldElement2625x4 { - FieldElement2625x4([ - self.0[0] + rhs.0[0], - self.0[1] + rhs.0[1], - self.0[2] + rhs.0[2], - self.0[3] + rhs.0[3], - self.0[4] + rhs.0[4], - ]) + #[inline] + #[target_feature(enable = "avx2")] + unsafe fn inner(itself: FieldElement2625x4, rhs: FieldElement2625x4) -> FieldElement2625x4 { + FieldElement2625x4([ + itself.0[0] + rhs.0[0], + itself.0[1] + rhs.0[1], + itself.0[2] + rhs.0[2], + itself.0[3] + rhs.0[3], + itself.0[4] + rhs.0[4], + ]) + } + + unsafe { inner(self, rhs) } } } -#[unsafe_target_feature("avx2")] impl Mul<(u32, u32, u32, u32)> for FieldElement2625x4 { type Output = FieldElement2625x4; /// Perform a multiplication by a vector of small constants. @@ -734,32 +769,40 @@ impl Mul<(u32, u32, u32, u32)> for FieldElement2625x4 { /// # Postconditions /// /// The coefficients of the result are bounded with \\( b < 0.007 \\). 
- #[inline] + #[inline(always)] fn mul(self, scalars: (u32, u32, u32, u32)) -> FieldElement2625x4 { - let consts = u32x8::new(scalars.0, 0, scalars.1, 0, scalars.2, 0, scalars.3, 0); - - let (b0, b1) = unpack_pair(self.0[0]); - let (b2, b3) = unpack_pair(self.0[1]); - let (b4, b5) = unpack_pair(self.0[2]); - let (b6, b7) = unpack_pair(self.0[3]); - let (b8, b9) = unpack_pair(self.0[4]); - - FieldElement2625x4::reduce64([ - b0.mul32(consts), - b1.mul32(consts), - b2.mul32(consts), - b3.mul32(consts), - b4.mul32(consts), - b5.mul32(consts), - b6.mul32(consts), - b7.mul32(consts), - b8.mul32(consts), - b9.mul32(consts), - ]) + #[inline] + #[target_feature(enable = "avx2")] + unsafe fn inner( + itself: FieldElement2625x4, + scalars: (u32, u32, u32, u32), + ) -> FieldElement2625x4 { + let consts = u32x8::new(scalars.0, 0, scalars.1, 0, scalars.2, 0, scalars.3, 0); + + let (b0, b1) = unpack_pair(itself.0[0]); + let (b2, b3) = unpack_pair(itself.0[1]); + let (b4, b5) = unpack_pair(itself.0[2]); + let (b6, b7) = unpack_pair(itself.0[3]); + let (b8, b9) = unpack_pair(itself.0[4]); + + FieldElement2625x4::reduce64([ + b0.mul32(consts), + b1.mul32(consts), + b2.mul32(consts), + b3.mul32(consts), + b4.mul32(consts), + b5.mul32(consts), + b6.mul32(consts), + b7.mul32(consts), + b8.mul32(consts), + b9.mul32(consts), + ]) + } + + unsafe { inner(self, scalars) } } } -#[unsafe_target_feature("avx2")] impl<'a, 'b> Mul<&'b FieldElement2625x4> for &'a FieldElement2625x4 { type Output = FieldElement2625x4; /// Multiply `self` by `rhs`. @@ -775,98 +818,106 @@ impl<'a, 'b> Mul<&'b FieldElement2625x4> for &'a FieldElement2625x4 { /// The coefficients of the result are bounded with \\( b < 0.007 \\). /// #[rustfmt::skip] // keep alignment of z* calculations - #[inline] + #[inline(always)] fn mul(self, rhs: &'b FieldElement2625x4) -> FieldElement2625x4 { - #[inline(always)] - fn m(x: u32x8, y: u32x8) -> u64x4 { - x.mul32(y) - } - - #[inline(always)] - fn m_lo(x: u32x8, y: u32x8) -> u32x8 { - x.mul32(y).into() - } - - let (x0, x1) = unpack_pair(self.0[0]); - let (x2, x3) = unpack_pair(self.0[1]); - let (x4, x5) = unpack_pair(self.0[2]); - let (x6, x7) = unpack_pair(self.0[3]); - let (x8, x9) = unpack_pair(self.0[4]); + #[inline] + #[target_feature(enable = "avx2")] + unsafe fn inner<'a, 'b>(itself: &'a FieldElement2625x4, rhs: &'b FieldElement2625x4) -> FieldElement2625x4 { + #[inline(always)] + fn m(x: u32x8, y: u32x8) -> u64x4 { + x.mul32(y) + } - let (y0, y1) = unpack_pair(rhs.0[0]); - let (y2, y3) = unpack_pair(rhs.0[1]); - let (y4, y5) = unpack_pair(rhs.0[2]); - let (y6, y7) = unpack_pair(rhs.0[3]); - let (y8, y9) = unpack_pair(rhs.0[4]); + #[inline(always)] + fn m_lo(x: u32x8, y: u32x8) -> u32x8 { + x.mul32(y).into() + } - let v19 = u32x8::new(19, 0, 19, 0, 19, 0, 19, 0); + let (x0, x1) = unpack_pair(itself.0[0]); + let (x2, x3) = unpack_pair(itself.0[1]); + let (x4, x5) = unpack_pair(itself.0[2]); + let (x6, x7) = unpack_pair(itself.0[3]); + let (x8, x9) = unpack_pair(itself.0[4]); + + let (y0, y1) = unpack_pair(rhs.0[0]); + let (y2, y3) = unpack_pair(rhs.0[1]); + let (y4, y5) = unpack_pair(rhs.0[2]); + let (y6, y7) = unpack_pair(rhs.0[3]); + let (y8, y9) = unpack_pair(rhs.0[4]); + + let v19 = u32x8::new(19, 0, 19, 0, 19, 0, 19, 0); + + let y1_19 = m_lo(v19, y1); // This fits in a u32 + let y2_19 = m_lo(v19, y2); // iff 26 + b + lg(19) < 32 + let y3_19 = m_lo(v19, y3); // if b < 32 - 26 - 4.248 = 1.752 + let y4_19 = m_lo(v19, y4); + let y5_19 = m_lo(v19, y5); + let y6_19 = m_lo(v19, y6); + let y7_19 = 
m_lo(v19, y7); + let y8_19 = m_lo(v19, y8); + let y9_19 = m_lo(v19, y9); + + let x1_2 = x1 + x1; // This fits in a u32 iff 25 + b + 1 < 32 + let x3_2 = x3 + x3; // iff b < 6 + let x5_2 = x5 + x5; + let x7_2 = x7 + x7; + let x9_2 = x9 + x9; + + let z0 = m(x0, y0) + m(x1_2, y9_19) + m(x2, y8_19) + m(x3_2, y7_19) + m(x4, y6_19) + m(x5_2, y5_19) + m(x6, y4_19) + m(x7_2, y3_19) + m(x8, y2_19) + m(x9_2, y1_19); + let z1 = m(x0, y1) + m(x1, y0) + m(x2, y9_19) + m(x3, y8_19) + m(x4, y7_19) + m(x5, y6_19) + m(x6, y5_19) + m(x7, y4_19) + m(x8, y3_19) + m(x9, y2_19); + let z2 = m(x0, y2) + m(x1_2, y1) + m(x2, y0) + m(x3_2, y9_19) + m(x4, y8_19) + m(x5_2, y7_19) + m(x6, y6_19) + m(x7_2, y5_19) + m(x8, y4_19) + m(x9_2, y3_19); + let z3 = m(x0, y3) + m(x1, y2) + m(x2, y1) + m(x3, y0) + m(x4, y9_19) + m(x5, y8_19) + m(x6, y7_19) + m(x7, y6_19) + m(x8, y5_19) + m(x9, y4_19); + let z4 = m(x0, y4) + m(x1_2, y3) + m(x2, y2) + m(x3_2, y1) + m(x4, y0) + m(x5_2, y9_19) + m(x6, y8_19) + m(x7_2, y7_19) + m(x8, y6_19) + m(x9_2, y5_19); + let z5 = m(x0, y5) + m(x1, y4) + m(x2, y3) + m(x3, y2) + m(x4, y1) + m(x5, y0) + m(x6, y9_19) + m(x7, y8_19) + m(x8, y7_19) + m(x9, y6_19); + let z6 = m(x0, y6) + m(x1_2, y5) + m(x2, y4) + m(x3_2, y3) + m(x4, y2) + m(x5_2, y1) + m(x6, y0) + m(x7_2, y9_19) + m(x8, y8_19) + m(x9_2, y7_19); + let z7 = m(x0, y7) + m(x1, y6) + m(x2, y5) + m(x3, y4) + m(x4, y3) + m(x5, y2) + m(x6, y1) + m(x7, y0) + m(x8, y9_19) + m(x9, y8_19); + let z8 = m(x0, y8) + m(x1_2, y7) + m(x2, y6) + m(x3_2, y5) + m(x4, y4) + m(x5_2, y3) + m(x6, y2) + m(x7_2, y1) + m(x8, y0) + m(x9_2, y9_19); + let z9 = m(x0, y9) + m(x1, y8) + m(x2, y7) + m(x3, y6) + m(x4, y5) + m(x5, y4) + m(x6, y3) + m(x7, y2) + m(x8, y1) + m(x9, y0); + + // The bounds on z[i] are the same as in the serial 32-bit code + // and the comment below is copied from there: + + // How big is the contribution to z[i+j] from x[i], y[j]? + // + // Using the bounds above, we get: + // + // i even, j even: x[i]*y[j] < 2^(26+b)*2^(26+b) = 2*2^(51+2*b) + // i odd, j even: x[i]*y[j] < 2^(25+b)*2^(26+b) = 1*2^(51+2*b) + // i even, j odd: x[i]*y[j] < 2^(26+b)*2^(25+b) = 1*2^(51+2*b) + // i odd, j odd: 2*x[i]*y[j] < 2*2^(25+b)*2^(25+b) = 1*2^(51+2*b) + // + // We perform inline reduction mod p by replacing 2^255 by 19 + // (since 2^255 - 19 = 0 mod p). This adds a factor of 19, so + // we get the bounds (z0 is the biggest one, but calculated for + // posterity here in case finer estimation is needed later): + // + // z0 < ( 2 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 )*2^(51 + 2b) = 249*2^(51 + 2*b) + // z1 < ( 1 + 1 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 )*2^(51 + 2b) = 154*2^(51 + 2*b) + // z2 < ( 2 + 1 + 2 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 )*2^(51 + 2b) = 195*2^(51 + 2*b) + // z3 < ( 1 + 1 + 1 + 1 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 )*2^(51 + 2b) = 118*2^(51 + 2*b) + // z4 < ( 2 + 1 + 2 + 1 + 2 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 )*2^(51 + 2b) = 141*2^(51 + 2*b) + // z5 < ( 1 + 1 + 1 + 1 + 1 + 1 + 1*19 + 1*19 + 1*19 + 1*19 )*2^(51 + 2b) = 82*2^(51 + 2*b) + // z6 < ( 2 + 1 + 2 + 1 + 2 + 1 + 2 + 1*19 + 2*19 + 1*19 )*2^(51 + 2b) = 87*2^(51 + 2*b) + // z7 < ( 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1*19 + 1*19 )*2^(51 + 2b) = 46*2^(51 + 2*b) + // z8 < ( 2 + 1 + 2 + 1 + 2 + 1 + 2 + 1 + 2 + 1*19 )*2^(51 + 2b) = 33*2^(51 + 2*b) + // z9 < ( 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 )*2^(51 + 2b) = 10*2^(51 + 2*b) + // + // So z[0] fits into a u64 if 51 + 2*b + lg(249) < 64 + // if b < 2.5. 
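//             (Checking the arithmetic: lg(249) ≈ 7.96, so the condition
//             51 + 2*b + 7.96 < 64 rearranges to b < 2.52, consistent
//             with the b < 2.5 stated above.)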
+ + // In fact this bound is slightly sloppy, since it treats both + // inputs x and y as being bounded by the same parameter b, + // while they are in fact bounded by b_x and b_y, and we + // already require that b_y < 1.75 in order to fit the + // multiplications by 19 into a u32. The tighter bound on b_y + // means we could get a tighter bound on the outputs, or a + // looser bound on b_x. + FieldElement2625x4::reduce64([z0, z1, z2, z3, z4, z5, z6, z7, z8, z9]) + } - let y1_19 = m_lo(v19, y1); // This fits in a u32 - let y2_19 = m_lo(v19, y2); // iff 26 + b + lg(19) < 32 - let y3_19 = m_lo(v19, y3); // if b < 32 - 26 - 4.248 = 1.752 - let y4_19 = m_lo(v19, y4); - let y5_19 = m_lo(v19, y5); - let y6_19 = m_lo(v19, y6); - let y7_19 = m_lo(v19, y7); - let y8_19 = m_lo(v19, y8); - let y9_19 = m_lo(v19, y9); - - let x1_2 = x1 + x1; // This fits in a u32 iff 25 + b + 1 < 32 - let x3_2 = x3 + x3; // iff b < 6 - let x5_2 = x5 + x5; - let x7_2 = x7 + x7; - let x9_2 = x9 + x9; - - let z0 = m(x0, y0) + m(x1_2, y9_19) + m(x2, y8_19) + m(x3_2, y7_19) + m(x4, y6_19) + m(x5_2, y5_19) + m(x6, y4_19) + m(x7_2, y3_19) + m(x8, y2_19) + m(x9_2, y1_19); - let z1 = m(x0, y1) + m(x1, y0) + m(x2, y9_19) + m(x3, y8_19) + m(x4, y7_19) + m(x5, y6_19) + m(x6, y5_19) + m(x7, y4_19) + m(x8, y3_19) + m(x9, y2_19); - let z2 = m(x0, y2) + m(x1_2, y1) + m(x2, y0) + m(x3_2, y9_19) + m(x4, y8_19) + m(x5_2, y7_19) + m(x6, y6_19) + m(x7_2, y5_19) + m(x8, y4_19) + m(x9_2, y3_19); - let z3 = m(x0, y3) + m(x1, y2) + m(x2, y1) + m(x3, y0) + m(x4, y9_19) + m(x5, y8_19) + m(x6, y7_19) + m(x7, y6_19) + m(x8, y5_19) + m(x9, y4_19); - let z4 = m(x0, y4) + m(x1_2, y3) + m(x2, y2) + m(x3_2, y1) + m(x4, y0) + m(x5_2, y9_19) + m(x6, y8_19) + m(x7_2, y7_19) + m(x8, y6_19) + m(x9_2, y5_19); - let z5 = m(x0, y5) + m(x1, y4) + m(x2, y3) + m(x3, y2) + m(x4, y1) + m(x5, y0) + m(x6, y9_19) + m(x7, y8_19) + m(x8, y7_19) + m(x9, y6_19); - let z6 = m(x0, y6) + m(x1_2, y5) + m(x2, y4) + m(x3_2, y3) + m(x4, y2) + m(x5_2, y1) + m(x6, y0) + m(x7_2, y9_19) + m(x8, y8_19) + m(x9_2, y7_19); - let z7 = m(x0, y7) + m(x1, y6) + m(x2, y5) + m(x3, y4) + m(x4, y3) + m(x5, y2) + m(x6, y1) + m(x7, y0) + m(x8, y9_19) + m(x9, y8_19); - let z8 = m(x0, y8) + m(x1_2, y7) + m(x2, y6) + m(x3_2, y5) + m(x4, y4) + m(x5_2, y3) + m(x6, y2) + m(x7_2, y1) + m(x8, y0) + m(x9_2, y9_19); - let z9 = m(x0, y9) + m(x1, y8) + m(x2, y7) + m(x3, y6) + m(x4, y5) + m(x5, y4) + m(x6, y3) + m(x7, y2) + m(x8, y1) + m(x9, y0); - - // The bounds on z[i] are the same as in the serial 32-bit code - // and the comment below is copied from there: - - // How big is the contribution to z[i+j] from x[i], y[j]? - // - // Using the bounds above, we get: - // - // i even, j even: x[i]*y[j] < 2^(26+b)*2^(26+b) = 2*2^(51+2*b) - // i odd, j even: x[i]*y[j] < 2^(25+b)*2^(26+b) = 1*2^(51+2*b) - // i even, j odd: x[i]*y[j] < 2^(26+b)*2^(25+b) = 1*2^(51+2*b) - // i odd, j odd: 2*x[i]*y[j] < 2*2^(25+b)*2^(25+b) = 1*2^(51+2*b) - // - // We perform inline reduction mod p by replacing 2^255 by 19 - // (since 2^255 - 19 = 0 mod p). 
This adds a factor of 19, so - // we get the bounds (z0 is the biggest one, but calculated for - // posterity here in case finer estimation is needed later): - // - // z0 < ( 2 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 )*2^(51 + 2b) = 249*2^(51 + 2*b) - // z1 < ( 1 + 1 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 )*2^(51 + 2b) = 154*2^(51 + 2*b) - // z2 < ( 2 + 1 + 2 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 )*2^(51 + 2b) = 195*2^(51 + 2*b) - // z3 < ( 1 + 1 + 1 + 1 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 )*2^(51 + 2b) = 118*2^(51 + 2*b) - // z4 < ( 2 + 1 + 2 + 1 + 2 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 )*2^(51 + 2b) = 141*2^(51 + 2*b) - // z5 < ( 1 + 1 + 1 + 1 + 1 + 1 + 1*19 + 1*19 + 1*19 + 1*19 )*2^(51 + 2b) = 82*2^(51 + 2*b) - // z6 < ( 2 + 1 + 2 + 1 + 2 + 1 + 2 + 1*19 + 2*19 + 1*19 )*2^(51 + 2b) = 87*2^(51 + 2*b) - // z7 < ( 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1*19 + 1*19 )*2^(51 + 2b) = 46*2^(51 + 2*b) - // z8 < ( 2 + 1 + 2 + 1 + 2 + 1 + 2 + 1 + 2 + 1*19 )*2^(51 + 2b) = 33*2^(51 + 2*b) - // z9 < ( 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 )*2^(51 + 2b) = 10*2^(51 + 2*b) - // - // So z[0] fits into a u64 if 51 + 2*b + lg(249) < 64 - // if b < 2.5. - - // In fact this bound is slightly sloppy, since it treats both - // inputs x and y as being bounded by the same parameter b, - // while they are in fact bounded by b_x and b_y, and we - // already require that b_y < 1.75 in order to fit the - // multiplications by 19 into a u32. The tighter bound on b_y - // means we could get a tighter bound on the outputs, or a - // looser bound on b_x. - FieldElement2625x4::reduce64([z0, z1, z2, z3, z4, z5, z6, z7, z8, z9]) + unsafe { + inner(self, rhs) + } } } diff --git a/src/backend/vector/packed_simd.rs b/src/backend/vector/packed_simd.rs index 371410d6f..201f1eb0f 100644 --- a/src/backend/vector/packed_simd.rs +++ b/src/backend/vector/packed_simd.rs @@ -252,9 +252,9 @@ impl u64x4 { } /// Constructs a new instance. - #[unsafe_target_feature("avx2")] + #[target_feature(enable = "avx2")] #[inline] - pub fn new(x0: u64, x1: u64, x2: u64, x3: u64) -> u64x4 { + pub unsafe fn new(x0: u64, x1: u64, x2: u64, x3: u64) -> u64x4 { unsafe { // _mm256_set_epi64 sets the underlying vector in reverse order of the args u64x4(core::arch::x86_64::_mm256_set_epi64x( @@ -264,9 +264,9 @@ impl u64x4 { } /// Constructs a new instance with all of the elements initialized to the given value. - #[unsafe_target_feature("avx2")] + #[target_feature(enable = "avx2")] #[inline] - pub fn splat(x: u64) -> u64x4 { + pub unsafe fn splat(x: u64) -> u64x4 { unsafe { u64x4(core::arch::x86_64::_mm256_set1_epi64x(x as i64)) } } } @@ -303,9 +303,18 @@ impl u32x8 { /// Constructs a new instance. #[allow(clippy::too_many_arguments)] - #[unsafe_target_feature("avx2")] + #[target_feature(enable = "avx2")] #[inline] - pub fn new(x0: u32, x1: u32, x2: u32, x3: u32, x4: u32, x5: u32, x6: u32, x7: u32) -> u32x8 { + pub unsafe fn new( + x0: u32, + x1: u32, + x2: u32, + x3: u32, + x4: u32, + x5: u32, + x6: u32, + x7: u32, + ) -> u32x8 { unsafe { // _mm256_set_epi32 sets the underlying vector in reverse order of the args u32x8(core::arch::x86_64::_mm256_set_epi32( @@ -316,9 +325,9 @@ impl u32x8 { } /// Constructs a new instance with all of the elements initialized to the given value. 
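// With `new` and `splat` becoming `unsafe fn` here, every construction
// site needs an AVX2-enabled unsafe context. A minimal caller, sketched
// under that assumption:
//
//     #[target_feature(enable = "avx2")]
//     unsafe fn splat19() -> u32x8 {
//         u32x8::splat(19)
//     }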
-    #[unsafe_target_feature("avx2")]
+    #[target_feature(enable = "avx2")]
     #[inline]
-    pub fn splat(x: u32) -> u32x8 {
+    pub unsafe fn splat(x: u32) -> u32x8 {
         unsafe { u32x8(core::arch::x86_64::_mm256_set1_epi32(x as i32)) }
     }
 }
diff --git a/src/backend/vector/scalar_mul/pippenger.rs b/src/backend/vector/scalar_mul/pippenger.rs
index b00cb87c5..8ac06f8f5 100644
--- a/src/backend/vector/scalar_mul/pippenger.rs
+++ b/src/backend/vector/scalar_mul/pippenger.rs
@@ -9,169 +9,183 @@
 
 #![allow(non_snake_case)]
 
-#[unsafe_target_feature::unsafe_target_feature_specialize(
-    conditional("avx2", feature = "simd_avx2"),
-    conditional("avx512ifma,avx512vl", all(feature = "simd_avx512", nightly))
-)]
-pub mod spec {
-
-    use alloc::vec::Vec;
-
-    use core::borrow::Borrow;
-    use core::cmp::Ordering;
-
-    #[for_target_feature("avx2")]
-    use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint};
-
-    #[for_target_feature("avx512ifma")]
-    use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint};
-
-    use crate::edwards::EdwardsPoint;
-    use crate::scalar::Scalar;
-    use crate::traits::{Identity, VartimeMultiscalarMul};
-
-    /// Implements a version of Pippenger's algorithm.
-    ///
-    /// See the documentation in the serial `scalar_mul::pippenger` module for details.
-    pub struct Pippenger;
-
-    impl VartimeMultiscalarMul for Pippenger {
-        type Point = EdwardsPoint;
-
-        fn optional_multiscalar_mul<I, J>(scalars: I, points: J) -> Option<EdwardsPoint>
-        where
-            I: IntoIterator,
-            I::Item: Borrow<Scalar>,
-            J: IntoIterator<Item = Option<EdwardsPoint>>,
-        {
-            let mut scalars = scalars.into_iter();
-            let size = scalars.by_ref().size_hint().0;
-            let w = if size < 500 {
-                6
-            } else if size < 800 {
-                7
-            } else {
-                8
-            };
-
-            let max_digit: usize = 1 << w;
-            let digits_count: usize = Scalar::to_radix_2w_size_hint(w);
-            let buckets_count: usize = max_digit / 2; // digits are signed+centered hence 2^w/2, excluding 0-th bucket
-
-            // Collect optimized scalars and points in a buffer for repeated access
-            // (scanning the whole collection per each digit position).
-            let scalars = scalars.map(|s| s.borrow().as_radix_2w(w));
-
-            let points = points
-                .into_iter()
-                .map(|p| p.map(|P| CachedPoint::from(ExtendedPoint::from(P))));
-
-            let scalars_points = scalars
-                .zip(points)
-                .map(|(s, maybe_p)| maybe_p.map(|p| (s, p)))
-                .collect::<Option<Vec<_>>>()?;
-
-            // Prepare 2^w/2 buckets.
-            // buckets[i] corresponds to a multiplication factor (i+1).
-            let mut buckets: Vec<ExtendedPoint> = (0..buckets_count)
-                .map(|_| ExtendedPoint::identity())
-                .collect();
-
-            let mut columns = (0..digits_count).rev().map(|digit_index| {
-                // Clear the buckets when processing another digit.
-                for bucket in &mut buckets {
-                    *bucket = ExtendedPoint::identity();
-                }
-
-                // Iterate over pairs of (point, scalar)
-                // and add/sub the point to the corresponding bucket.
-                // Note: if we add support for precomputed lookup tables,
-                // we'll be adding/subtracting point premultiplied by `digits[i]` to buckets[0].
-                for (digits, pt) in scalars_points.iter() {
-                    // Widen digit so that we don't run into edge cases when w=8.
-                    let digit = digits[digit_index] as i16;
-                    match digit.cmp(&0) {
-                        Ordering::Greater => {
-                            let b = (digit - 1) as usize;
-                            buckets[b] = &buckets[b] + pt;
-                        }
-                        Ordering::Less => {
-                            let b = (-digit - 1) as usize;
-                            buckets[b] = &buckets[b] - pt;
-                        }
-                        Ordering::Equal => {}
-                    }
-                }
-
-                // Add the buckets applying the multiplication factor to each bucket.
-                // The most efficient way to do that is to have a single sum with two running sums:
-                // an intermediate sum from last bucket to the first, and a sum of intermediate sums.
-                //
-                // For example, to add buckets 1*A, 2*B, 3*C we need to add these points:
-                //     C
-                //     C B
-                //     C B A   Sum = C + (C+B) + (C+B+A)
-                let mut buckets_intermediate_sum = buckets[buckets_count - 1];
-                let mut buckets_sum = buckets[buckets_count - 1];
-                for i in (0..(buckets_count - 1)).rev() {
-                    buckets_intermediate_sum =
-                        &buckets_intermediate_sum + &CachedPoint::from(buckets[i]);
-                    buckets_sum = &buckets_sum + &CachedPoint::from(buckets_intermediate_sum);
-                }
-
-                buckets_sum
-            });
-
-            // Take the high column as an initial value to avoid wasting time doubling the identity element in `fold()`.
-            // `unwrap()` always succeeds because we know we have more than zero digits.
-            let hi_column = columns.next().unwrap();
-
-            Some(
-                columns
-                    .fold(hi_column, |total, p| {
-                        &total.mul_by_pow_2(w as u32) + &CachedPoint::from(p)
-                    })
-                    .into(),
-            )
-        }
-    }
-
-    #[cfg(test)]
-    mod test {
-        #[test]
-        fn test_vartime_pippenger() {
-            use super::*;
-            use crate::constants;
-            use crate::scalar::Scalar;
-
-            // Reuse points across different tests
-            let mut n = 512;
-            let x = Scalar::from(2128506u64).invert();
-            let y = Scalar::from(4443282u64).invert();
-            let points: Vec<_> = (0..n)
-                .map(|i| constants::ED25519_BASEPOINT_POINT * Scalar::from(1 + i as u64))
-                .collect();
-            let scalars: Vec<_> = (0..n)
-                .map(|i| x + (Scalar::from(i as u64) * y)) // fast way to make ~random but deterministic scalars
-                .collect();
-
-            let premultiplied: Vec<EdwardsPoint> = scalars
-                .iter()
-                .zip(points.iter())
-                .map(|(sc, pt)| sc * pt)
-                .collect();
-
-            while n > 0 {
-                let scalars = &scalars[0..n].to_vec();
-                let points = &points[0..n].to_vec();
-                let control: EdwardsPoint = premultiplied[0..n].iter().sum();
-
-                let subject = Pippenger::vartime_multiscalar_mul(scalars.clone(), points.clone());
-
-                assert_eq!(subject.compress(), control.compress());
-
-                n = n / 2;
-            }
-        }
-    }
-}
+macro_rules! implement {
+    ($module:ident, $backend_module:ident, $features:expr) => {
+        pub mod $module {
+            use alloc::vec::Vec;
+
+            use core::borrow::Borrow;
+            use core::cmp::Ordering;
+
+            use crate::backend::vector::$backend_module::{CachedPoint, ExtendedPoint};
+
+            use crate::edwards::EdwardsPoint;
+            use crate::scalar::Scalar;
+            use crate::traits::{Identity, VartimeMultiscalarMul};
+
+            /// Implements a version of Pippenger's algorithm.
+            ///
+            /// See the documentation in the serial `scalar_mul::pippenger` module for details.
+            pub struct Pippenger;
+
+            impl VartimeMultiscalarMul for Pippenger {
+                type Point = EdwardsPoint;
+
+                #[inline(always)]
+                fn optional_multiscalar_mul<I, J>(scalars: I, points: J) -> Option<EdwardsPoint>
+                where
+                    I: IntoIterator,
+                    I::Item: Borrow<Scalar>,
+                    J: IntoIterator<Item = Option<EdwardsPoint>>,
+                {
+                    #[target_feature(enable = $features)]
+                    unsafe fn inner<I, J>(scalars: I, points: J) -> Option<EdwardsPoint>
+                    where
+                        I: IntoIterator,
+                        I::Item: Borrow<Scalar>,
+                        J: IntoIterator<Item = Option<EdwardsPoint>>,
+                    {
+                        let mut scalars = scalars.into_iter();
+                        let size = scalars.by_ref().size_hint().0;
+                        let w = if size < 500 {
+                            6
+                        } else if size < 800 {
+                            7
+                        } else {
+                            8
+                        };
+
+                        let max_digit: usize = 1 << w;
+                        let digits_count: usize = Scalar::to_radix_2w_size_hint(w);
+                        let buckets_count: usize = max_digit / 2; // digits are signed+centered hence 2^w/2, excluding 0-th bucket
+
+                        // Collect optimized scalars and points in a buffer for repeated access
+                        // (scanning the whole collection per each digit position).
+                        let scalars = scalars.map(|s| s.borrow().as_radix_2w(w));
+
+                        let points = points
+                            .into_iter()
+                            .map(|p| p.map(|P| CachedPoint::from(ExtendedPoint::from(P))));
+
+                        let scalars_points = scalars
+                            .zip(points)
+                            .map(|(s, maybe_p)| maybe_p.map(|p| (s, p)))
+                            .collect::<Option<Vec<_>>>()?;
+
+                        // Prepare 2^w/2 buckets.
+                        // buckets[i] corresponds to a multiplication factor (i+1).
+                        let mut buckets: Vec<ExtendedPoint> = (0..buckets_count)
+                            .map(|_| ExtendedPoint::identity())
+                            .collect();
+
+                        let mut columns = (0..digits_count).rev().map(|digit_index| {
+                            // Clear the buckets when processing another digit.
+                            for bucket in &mut buckets {
+                                *bucket = ExtendedPoint::identity();
+                            }
+
+                            // Iterate over pairs of (point, scalar)
+                            // and add/sub the point to the corresponding bucket.
+                            // Note: if we add support for precomputed lookup tables,
+                            // we'll be adding/subtracting point premultiplied by `digits[i]` to buckets[0].
+                            for (digits, pt) in scalars_points.iter() {
+                                // Widen digit so that we don't run into edge cases when w=8.
+                                let digit = digits[digit_index] as i16;
+                                match digit.cmp(&0) {
+                                    Ordering::Greater => {
+                                        let b = (digit - 1) as usize;
+                                        buckets[b] = &buckets[b] + pt;
+                                    }
+                                    Ordering::Less => {
+                                        let b = (-digit - 1) as usize;
+                                        buckets[b] = &buckets[b] - pt;
+                                    }
+                                    Ordering::Equal => {}
+                                }
+                            }
+
+                            // Add the buckets applying the multiplication factor to each bucket.
+                            // The most efficient way to do that is to have a single sum with two running sums:
+                            // an intermediate sum from last bucket to the first, and a sum of intermediate sums.
+                            //
+                            // For example, to add buckets 1*A, 2*B, 3*C we need to add these points:
+                            //     C
+                            //     C B
+                            //     C B A   Sum = C + (C+B) + (C+B+A)
+                            let mut buckets_intermediate_sum = buckets[buckets_count - 1];
+                            let mut buckets_sum = buckets[buckets_count - 1];
+                            for i in (0..(buckets_count - 1)).rev() {
+                                buckets_intermediate_sum =
+                                    &buckets_intermediate_sum + &CachedPoint::from(buckets[i]);
+                                buckets_sum =
+                                    &buckets_sum + &CachedPoint::from(buckets_intermediate_sum);
+                            }
+
+                            buckets_sum
+                        });
+
+                        // Take the high column as an initial value to avoid wasting time doubling the identity element in `fold()`.
+                        // `unwrap()` always succeeds because we know we have more than zero digits.
+                        let hi_column = columns.next().unwrap();
+
+                        Some(
+                            columns
+                                .fold(hi_column, |total, p| {
+                                    &total.mul_by_pow_2(w as u32) + &CachedPoint::from(p)
+                                })
+                                .into(),
+                        )
+                    }
+                    unsafe { inner(scalars, points) }
+                }
+            }
+
+            #[cfg(test)]
+            #[cfg(target_feature = $features)]
+            mod test {
+                #[test]
+                fn test_vartime_pippenger() {
+                    use super::*;
+                    use crate::constants;
+                    use crate::scalar::Scalar;
+
+                    // Reuse points across different tests
+                    let mut n = 512;
+                    let x = Scalar::from(2128506u64).invert();
+                    let y = Scalar::from(4443282u64).invert();
+                    let points: Vec<_> = (0..n)
+                        .map(|i| constants::ED25519_BASEPOINT_POINT * Scalar::from(1 + i as u64))
+                        .collect();
+                    let scalars: Vec<_> = (0..n)
+                        .map(|i| x + (Scalar::from(i as u64) * y)) // fast way to make ~random but deterministic scalars
+                        .collect();
+
+                    let premultiplied: Vec<EdwardsPoint> = scalars
+                        .iter()
+                        .zip(points.iter())
+                        .map(|(sc, pt)| sc * pt)
+                        .collect();
+
+                    while n > 0 {
+                        let scalars = &scalars[0..n].to_vec();
+                        let points = &points[0..n].to_vec();
+                        let control: EdwardsPoint =
premultiplied[0..n].iter().sum(); + + let subject = + Pippenger::vartime_multiscalar_mul(scalars.clone(), points.clone()); + + assert_eq!(subject.compress(), control.compress()); + + n = n / 2; + } + } } } - } + }; } + +#[cfg(feature = "simd_avx2")] +implement!(spec_avx2, avx2, "avx2"); + +#[cfg(all(feature = "simd_avx512", nightly))] +implement!(spec_avx512ifma_avx512vl, ifma, "avx512ifma,avx512vl"); From 94247a79d190155062149ec06a3ce263c0588eed Mon Sep 17 00:00:00 2001 From: Jan Bujak Date: Mon, 5 Jun 2023 07:38:55 +0000 Subject: [PATCH 09/11] Revert "(work-in-progress) Partially remove `unsafe_target_feature`" This reverts commit c67e430cfdf9699cf9b90226ab08a3b48cadacc6. --- src/backend/vector/avx2/field.rs | 383 +++++++++------------ src/backend/vector/packed_simd.rs | 25 +- src/backend/vector/scalar_mul/pippenger.rs | 330 +++++++++--------- 3 files changed, 332 insertions(+), 406 deletions(-) diff --git a/src/backend/vector/avx2/field.rs b/src/backend/vector/avx2/field.rs index 5a81afc81..bdb55efa5 100644 --- a/src/backend/vector/avx2/field.rs +++ b/src/backend/vector/avx2/field.rs @@ -48,6 +48,8 @@ use crate::backend::vector::avx2::constants::{ P_TIMES_16_HI, P_TIMES_16_LO, P_TIMES_2_HI, P_TIMES_2_LO, }; +use unsafe_target_feature::unsafe_target_feature; + /// Unpack 32-bit lanes into 64-bit lanes: /// ```ascii,no_run /// (a0, b0, a1, b1, c0, d0, c1, d1) @@ -57,9 +59,9 @@ use crate::backend::vector::avx2::constants::{ /// (a0, 0, b0, 0, c0, 0, d0, 0) /// (a1, 0, b1, 0, c1, 0, d1, 0) /// ``` -#[target_feature(enable = "avx2")] -#[inline] -unsafe fn unpack_pair(src: u32x8) -> (u32x8, u32x8) { +#[unsafe_target_feature("avx2")] +#[inline(always)] +fn unpack_pair(src: u32x8) -> (u32x8, u32x8) { let a: u32x8; let b: u32x8; let zero = u32x8::splat(0); @@ -81,9 +83,9 @@ unsafe fn unpack_pair(src: u32x8) -> (u32x8, u32x8) { /// ```ascii,no_run /// (a0, b0, a1, b1, c0, d0, c1, d1) /// ``` -#[target_feature(enable = "avx2")] -#[inline] -unsafe fn repack_pair(x: u32x8, y: u32x8) -> u32x8 { +#[unsafe_target_feature("avx2")] +#[inline(always)] +fn repack_pair(x: u32x8, y: u32x8) -> u32x8 { unsafe { use core::arch::x86_64::_mm256_blend_epi32; use core::arch::x86_64::_mm256_shuffle_epi32; @@ -153,62 +155,43 @@ pub struct FieldElement2625x4(pub(crate) [u32x8; 5]); use subtle::Choice; use subtle::ConditionallySelectable; +#[unsafe_target_feature("avx2")] impl ConditionallySelectable for FieldElement2625x4 { - #[inline(always)] fn conditional_select( a: &FieldElement2625x4, b: &FieldElement2625x4, choice: Choice, ) -> FieldElement2625x4 { - #[target_feature(enable = "avx2")] - unsafe fn inner( - a: &FieldElement2625x4, - b: &FieldElement2625x4, - choice: Choice, - ) -> FieldElement2625x4 { - let mask = (-(choice.unwrap_u8() as i32)) as u32; - let mask_vec = u32x8::splat(mask); - FieldElement2625x4([ - a.0[0] ^ (mask_vec & (a.0[0] ^ b.0[0])), - a.0[1] ^ (mask_vec & (a.0[1] ^ b.0[1])), - a.0[2] ^ (mask_vec & (a.0[2] ^ b.0[2])), - a.0[3] ^ (mask_vec & (a.0[3] ^ b.0[3])), - a.0[4] ^ (mask_vec & (a.0[4] ^ b.0[4])), - ]) - } - - unsafe { inner(a, b, choice) } + let mask = (-(choice.unwrap_u8() as i32)) as u32; + let mask_vec = u32x8::splat(mask); + FieldElement2625x4([ + a.0[0] ^ (mask_vec & (a.0[0] ^ b.0[0])), + a.0[1] ^ (mask_vec & (a.0[1] ^ b.0[1])), + a.0[2] ^ (mask_vec & (a.0[2] ^ b.0[2])), + a.0[3] ^ (mask_vec & (a.0[3] ^ b.0[3])), + a.0[4] ^ (mask_vec & (a.0[4] ^ b.0[4])), + ]) } - #[inline(always)] fn conditional_assign(&mut self, other: &FieldElement2625x4, choice: Choice) { - 
#[target_feature(enable = "avx2")] - unsafe fn inner( - itself: &mut FieldElement2625x4, - other: &FieldElement2625x4, - choice: Choice, - ) { - let mask = (-(choice.unwrap_u8() as i32)) as u32; - let mask_vec = u32x8::splat(mask); - itself.0[0] ^= mask_vec & (itself.0[0] ^ other.0[0]); - itself.0[1] ^= mask_vec & (itself.0[1] ^ other.0[1]); - itself.0[2] ^= mask_vec & (itself.0[2] ^ other.0[2]); - itself.0[3] ^= mask_vec & (itself.0[3] ^ other.0[3]); - itself.0[4] ^= mask_vec & (itself.0[4] ^ other.0[4]); - } - - unsafe { inner(self, other, choice) } + let mask = (-(choice.unwrap_u8() as i32)) as u32; + let mask_vec = u32x8::splat(mask); + self.0[0] ^= mask_vec & (self.0[0] ^ other.0[0]); + self.0[1] ^= mask_vec & (self.0[1] ^ other.0[1]); + self.0[2] ^= mask_vec & (self.0[2] ^ other.0[2]); + self.0[3] ^= mask_vec & (self.0[3] ^ other.0[3]); + self.0[4] ^= mask_vec & (self.0[4] ^ other.0[4]); } } +#[unsafe_target_feature("avx2")] impl FieldElement2625x4 { pub const ZERO: FieldElement2625x4 = FieldElement2625x4([u32x8::splat_const::<0>(); 5]); /// Split this vector into an array of four (serial) field /// elements. #[rustfmt::skip] // keep alignment of extracted lanes - #[target_feature(enable = "avx2")] - pub unsafe fn split(&self) -> [FieldElement51; 4] { + pub fn split(&self) -> [FieldElement51; 4] { let mut out = [FieldElement51::ZERO; 4]; for i in 0..5 { let a_2i = self.0[i].extract::<0>() as u64; // @@ -235,8 +218,7 @@ impl FieldElement2625x4 { /// that when this function is inlined, LLVM is able to lower the /// shuffle using an immediate. #[inline] - #[target_feature(enable = "avx2")] - pub unsafe fn shuffle(&self, control: Shuffle) -> FieldElement2625x4 { + pub fn shuffle(&self, control: Shuffle) -> FieldElement2625x4 { #[inline(always)] fn shuffle_lanes(x: u32x8, control: Shuffle) -> u32x8 { unsafe { @@ -276,8 +258,7 @@ impl FieldElement2625x4 { /// that this function can be inlined and LLVM can lower it to a /// blend instruction using an immediate. #[inline] - #[target_feature(enable = "avx2")] - pub unsafe fn blend(&self, other: FieldElement2625x4, control: Lanes) -> FieldElement2625x4 { + pub fn blend(&self, other: FieldElement2625x4, control: Lanes) -> FieldElement2625x4 { #[inline(always)] fn blend_lanes(x: u32x8, y: u32x8, control: Lanes) -> u32x8 { unsafe { @@ -341,8 +322,7 @@ impl FieldElement2625x4 { } /// Convenience wrapper around `new(x,x,x,x)`. - #[target_feature(enable = "avx2")] - pub unsafe fn splat(x: &FieldElement51) -> FieldElement2625x4 { + pub fn splat(x: &FieldElement51) -> FieldElement2625x4 { FieldElement2625x4::new(x, x, x, x) } @@ -352,8 +332,7 @@ impl FieldElement2625x4 { /// /// The resulting `FieldElement2625x4` is bounded with \\( b < 0.0002 \\). #[rustfmt::skip] // keep alignment of computed lanes - #[target_feature(enable = "avx2")] - pub unsafe fn new( + pub fn new( x0: &FieldElement51, x1: &FieldElement51, x2: &FieldElement51, @@ -392,8 +371,7 @@ impl FieldElement2625x4 { /// /// The coefficients of the result are bounded with \\( b < 1 \\). #[inline] - #[target_feature(enable = "avx2")] - pub unsafe fn negate_lazy(&self) -> FieldElement2625x4 { + pub fn negate_lazy(&self) -> FieldElement2625x4 { // The limbs of self are bounded with b < 0.999, while the // smallest limb of 2*p is 67108845 > 2^{26+0.9999}, so // underflows are not possible. @@ -416,8 +394,7 @@ impl FieldElement2625x4 { /// /// The coefficients of the result are bounded with \\( b < 1.6 \\). 
#[inline] - #[target_feature(enable = "avx2")] - pub unsafe fn diff_sum(&self) -> FieldElement2625x4 { + pub fn diff_sum(&self) -> FieldElement2625x4 { // tmp1 = (B, A, D, C) let tmp1 = self.shuffle(Shuffle::BADC); // tmp2 = (-A, B, -C, D) @@ -432,8 +409,7 @@ impl FieldElement2625x4 { /// /// The coefficients of the result are bounded with \\( b < 0.0002 \\). #[inline] - #[target_feature(enable = "avx2")] - pub unsafe fn reduce(&self) -> FieldElement2625x4 { + pub fn reduce(&self) -> FieldElement2625x4 { let shifts = u32x8::new(26, 26, 25, 25, 26, 26, 25, 25); let masks = u32x8::new( (1 << 26) - 1, @@ -542,8 +518,7 @@ impl FieldElement2625x4 { /// The coefficients of the result are bounded with \\( b < 0.007 \\). #[inline] #[rustfmt::skip] // keep alignment of carry chain - #[target_feature(enable = "avx2")] - unsafe fn reduce64(mut z: [u64x4; 10]) -> FieldElement2625x4 { + fn reduce64(mut z: [u64x4; 10]) -> FieldElement2625x4 { // These aren't const because splat isn't a const fn let LOW_25_BITS: u64x4 = u64x4::splat((1 << 25) - 1); let LOW_26_BITS: u64x4 = u64x4::splat((1 << 26) - 1); @@ -619,8 +594,7 @@ impl FieldElement2625x4 { /// /// The coefficients of the result are bounded with \\( b < 0.007 \\). #[rustfmt::skip] // keep alignment of z* calculations - #[target_feature(enable = "avx2")] - pub unsafe fn square_and_negate_D(&self) -> FieldElement2625x4 { + pub fn square_and_negate_D(&self) -> FieldElement2625x4 { #[inline(always)] fn m(x: u32x8, y: u32x8) -> u64x4 { x.mul32(y) @@ -707,6 +681,7 @@ impl FieldElement2625x4 { } } +#[unsafe_target_feature("avx2")] impl Neg for FieldElement2625x4 { type Output = FieldElement2625x4; @@ -722,46 +697,36 @@ impl Neg for FieldElement2625x4 { /// # Postconditions /// /// The coefficients of the result are bounded with \\( b < 0.0002 \\). - #[inline(always)] + #[inline] fn neg(self) -> FieldElement2625x4 { - #[inline] - #[target_feature(enable = "avx2")] - unsafe fn inner(itself: FieldElement2625x4) -> FieldElement2625x4 { - FieldElement2625x4([ - P_TIMES_16_LO - itself.0[0], - P_TIMES_16_HI - itself.0[1], - P_TIMES_16_HI - itself.0[2], - P_TIMES_16_HI - itself.0[3], - P_TIMES_16_HI - itself.0[4], - ]) - .reduce() - } - - unsafe { inner(self) } + FieldElement2625x4([ + P_TIMES_16_LO - self.0[0], + P_TIMES_16_HI - self.0[1], + P_TIMES_16_HI - self.0[2], + P_TIMES_16_HI - self.0[3], + P_TIMES_16_HI - self.0[4], + ]) + .reduce() } } +#[unsafe_target_feature("avx2")] impl Add for FieldElement2625x4 { type Output = FieldElement2625x4; /// Add two `FieldElement2625x4`s, without performing a reduction. - #[inline(always)] + #[inline] fn add(self, rhs: FieldElement2625x4) -> FieldElement2625x4 { - #[inline] - #[target_feature(enable = "avx2")] - unsafe fn inner(itself: FieldElement2625x4, rhs: FieldElement2625x4) -> FieldElement2625x4 { - FieldElement2625x4([ - itself.0[0] + rhs.0[0], - itself.0[1] + rhs.0[1], - itself.0[2] + rhs.0[2], - itself.0[3] + rhs.0[3], - itself.0[4] + rhs.0[4], - ]) - } - - unsafe { inner(self, rhs) } + FieldElement2625x4([ + self.0[0] + rhs.0[0], + self.0[1] + rhs.0[1], + self.0[2] + rhs.0[2], + self.0[3] + rhs.0[3], + self.0[4] + rhs.0[4], + ]) } } +#[unsafe_target_feature("avx2")] impl Mul<(u32, u32, u32, u32)> for FieldElement2625x4 { type Output = FieldElement2625x4; /// Perform a multiplication by a vector of small constants. @@ -769,40 +734,32 @@ impl Mul<(u32, u32, u32, u32)> for FieldElement2625x4 { /// # Postconditions /// /// The coefficients of the result are bounded with \\( b < 0.007 \\). 
- #[inline(always)] + #[inline] fn mul(self, scalars: (u32, u32, u32, u32)) -> FieldElement2625x4 { - #[inline] - #[target_feature(enable = "avx2")] - unsafe fn inner( - itself: FieldElement2625x4, - scalars: (u32, u32, u32, u32), - ) -> FieldElement2625x4 { - let consts = u32x8::new(scalars.0, 0, scalars.1, 0, scalars.2, 0, scalars.3, 0); - - let (b0, b1) = unpack_pair(itself.0[0]); - let (b2, b3) = unpack_pair(itself.0[1]); - let (b4, b5) = unpack_pair(itself.0[2]); - let (b6, b7) = unpack_pair(itself.0[3]); - let (b8, b9) = unpack_pair(itself.0[4]); - - FieldElement2625x4::reduce64([ - b0.mul32(consts), - b1.mul32(consts), - b2.mul32(consts), - b3.mul32(consts), - b4.mul32(consts), - b5.mul32(consts), - b6.mul32(consts), - b7.mul32(consts), - b8.mul32(consts), - b9.mul32(consts), - ]) - } - - unsafe { inner(self, scalars) } + let consts = u32x8::new(scalars.0, 0, scalars.1, 0, scalars.2, 0, scalars.3, 0); + + let (b0, b1) = unpack_pair(self.0[0]); + let (b2, b3) = unpack_pair(self.0[1]); + let (b4, b5) = unpack_pair(self.0[2]); + let (b6, b7) = unpack_pair(self.0[3]); + let (b8, b9) = unpack_pair(self.0[4]); + + FieldElement2625x4::reduce64([ + b0.mul32(consts), + b1.mul32(consts), + b2.mul32(consts), + b3.mul32(consts), + b4.mul32(consts), + b5.mul32(consts), + b6.mul32(consts), + b7.mul32(consts), + b8.mul32(consts), + b9.mul32(consts), + ]) } } +#[unsafe_target_feature("avx2")] impl<'a, 'b> Mul<&'b FieldElement2625x4> for &'a FieldElement2625x4 { type Output = FieldElement2625x4; /// Multiply `self` by `rhs`. @@ -818,106 +775,98 @@ impl<'a, 'b> Mul<&'b FieldElement2625x4> for &'a FieldElement2625x4 { /// The coefficients of the result are bounded with \\( b < 0.007 \\). /// #[rustfmt::skip] // keep alignment of z* calculations - #[inline(always)] + #[inline] fn mul(self, rhs: &'b FieldElement2625x4) -> FieldElement2625x4 { - #[inline] - #[target_feature(enable = "avx2")] - unsafe fn inner<'a, 'b>(itself: &'a FieldElement2625x4, rhs: &'b FieldElement2625x4) -> FieldElement2625x4 { - #[inline(always)] - fn m(x: u32x8, y: u32x8) -> u64x4 { - x.mul32(y) - } - - #[inline(always)] - fn m_lo(x: u32x8, y: u32x8) -> u32x8 { - x.mul32(y).into() - } - - let (x0, x1) = unpack_pair(itself.0[0]); - let (x2, x3) = unpack_pair(itself.0[1]); - let (x4, x5) = unpack_pair(itself.0[2]); - let (x6, x7) = unpack_pair(itself.0[3]); - let (x8, x9) = unpack_pair(itself.0[4]); - - let (y0, y1) = unpack_pair(rhs.0[0]); - let (y2, y3) = unpack_pair(rhs.0[1]); - let (y4, y5) = unpack_pair(rhs.0[2]); - let (y6, y7) = unpack_pair(rhs.0[3]); - let (y8, y9) = unpack_pair(rhs.0[4]); - - let v19 = u32x8::new(19, 0, 19, 0, 19, 0, 19, 0); - - let y1_19 = m_lo(v19, y1); // This fits in a u32 - let y2_19 = m_lo(v19, y2); // iff 26 + b + lg(19) < 32 - let y3_19 = m_lo(v19, y3); // if b < 32 - 26 - 4.248 = 1.752 - let y4_19 = m_lo(v19, y4); - let y5_19 = m_lo(v19, y5); - let y6_19 = m_lo(v19, y6); - let y7_19 = m_lo(v19, y7); - let y8_19 = m_lo(v19, y8); - let y9_19 = m_lo(v19, y9); - - let x1_2 = x1 + x1; // This fits in a u32 iff 25 + b + 1 < 32 - let x3_2 = x3 + x3; // iff b < 6 - let x5_2 = x5 + x5; - let x7_2 = x7 + x7; - let x9_2 = x9 + x9; - - let z0 = m(x0, y0) + m(x1_2, y9_19) + m(x2, y8_19) + m(x3_2, y7_19) + m(x4, y6_19) + m(x5_2, y5_19) + m(x6, y4_19) + m(x7_2, y3_19) + m(x8, y2_19) + m(x9_2, y1_19); - let z1 = m(x0, y1) + m(x1, y0) + m(x2, y9_19) + m(x3, y8_19) + m(x4, y7_19) + m(x5, y6_19) + m(x6, y5_19) + m(x7, y4_19) + m(x8, y3_19) + m(x9, y2_19); - let z2 = m(x0, y2) + m(x1_2, y1) + m(x2, y0) + m(x3_2, 
y9_19) + m(x4, y8_19) + m(x5_2, y7_19) + m(x6, y6_19) + m(x7_2, y5_19) + m(x8, y4_19) + m(x9_2, y3_19); - let z3 = m(x0, y3) + m(x1, y2) + m(x2, y1) + m(x3, y0) + m(x4, y9_19) + m(x5, y8_19) + m(x6, y7_19) + m(x7, y6_19) + m(x8, y5_19) + m(x9, y4_19); - let z4 = m(x0, y4) + m(x1_2, y3) + m(x2, y2) + m(x3_2, y1) + m(x4, y0) + m(x5_2, y9_19) + m(x6, y8_19) + m(x7_2, y7_19) + m(x8, y6_19) + m(x9_2, y5_19); - let z5 = m(x0, y5) + m(x1, y4) + m(x2, y3) + m(x3, y2) + m(x4, y1) + m(x5, y0) + m(x6, y9_19) + m(x7, y8_19) + m(x8, y7_19) + m(x9, y6_19); - let z6 = m(x0, y6) + m(x1_2, y5) + m(x2, y4) + m(x3_2, y3) + m(x4, y2) + m(x5_2, y1) + m(x6, y0) + m(x7_2, y9_19) + m(x8, y8_19) + m(x9_2, y7_19); - let z7 = m(x0, y7) + m(x1, y6) + m(x2, y5) + m(x3, y4) + m(x4, y3) + m(x5, y2) + m(x6, y1) + m(x7, y0) + m(x8, y9_19) + m(x9, y8_19); - let z8 = m(x0, y8) + m(x1_2, y7) + m(x2, y6) + m(x3_2, y5) + m(x4, y4) + m(x5_2, y3) + m(x6, y2) + m(x7_2, y1) + m(x8, y0) + m(x9_2, y9_19); - let z9 = m(x0, y9) + m(x1, y8) + m(x2, y7) + m(x3, y6) + m(x4, y5) + m(x5, y4) + m(x6, y3) + m(x7, y2) + m(x8, y1) + m(x9, y0); - - // The bounds on z[i] are the same as in the serial 32-bit code - // and the comment below is copied from there: - - // How big is the contribution to z[i+j] from x[i], y[j]? - // - // Using the bounds above, we get: - // - // i even, j even: x[i]*y[j] < 2^(26+b)*2^(26+b) = 2*2^(51+2*b) - // i odd, j even: x[i]*y[j] < 2^(25+b)*2^(26+b) = 1*2^(51+2*b) - // i even, j odd: x[i]*y[j] < 2^(26+b)*2^(25+b) = 1*2^(51+2*b) - // i odd, j odd: 2*x[i]*y[j] < 2*2^(25+b)*2^(25+b) = 1*2^(51+2*b) - // - // We perform inline reduction mod p by replacing 2^255 by 19 - // (since 2^255 - 19 = 0 mod p). This adds a factor of 19, so - // we get the bounds (z0 is the biggest one, but calculated for - // posterity here in case finer estimation is needed later): - // - // z0 < ( 2 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 )*2^(51 + 2b) = 249*2^(51 + 2*b) - // z1 < ( 1 + 1 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 )*2^(51 + 2b) = 154*2^(51 + 2*b) - // z2 < ( 2 + 1 + 2 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 )*2^(51 + 2b) = 195*2^(51 + 2*b) - // z3 < ( 1 + 1 + 1 + 1 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 )*2^(51 + 2b) = 118*2^(51 + 2*b) - // z4 < ( 2 + 1 + 2 + 1 + 2 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 )*2^(51 + 2b) = 141*2^(51 + 2*b) - // z5 < ( 1 + 1 + 1 + 1 + 1 + 1 + 1*19 + 1*19 + 1*19 + 1*19 )*2^(51 + 2b) = 82*2^(51 + 2*b) - // z6 < ( 2 + 1 + 2 + 1 + 2 + 1 + 2 + 1*19 + 2*19 + 1*19 )*2^(51 + 2b) = 87*2^(51 + 2*b) - // z7 < ( 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1*19 + 1*19 )*2^(51 + 2b) = 46*2^(51 + 2*b) - // z8 < ( 2 + 1 + 2 + 1 + 2 + 1 + 2 + 1 + 2 + 1*19 )*2^(51 + 2b) = 33*2^(51 + 2*b) - // z9 < ( 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 )*2^(51 + 2b) = 10*2^(51 + 2*b) - // - // So z[0] fits into a u64 if 51 + 2*b + lg(249) < 64 - // if b < 2.5. - - // In fact this bound is slightly sloppy, since it treats both - // inputs x and y as being bounded by the same parameter b, - // while they are in fact bounded by b_x and b_y, and we - // already require that b_y < 1.75 in order to fit the - // multiplications by 19 into a u32. The tighter bound on b_y - // means we could get a tighter bound on the outputs, or a - // looser bound on b_x. 
- FieldElement2625x4::reduce64([z0, z1, z2, z3, z4, z5, z6, z7, z8, z9]) + #[inline(always)] + fn m(x: u32x8, y: u32x8) -> u64x4 { + x.mul32(y) } - unsafe { - inner(self, rhs) + #[inline(always)] + fn m_lo(x: u32x8, y: u32x8) -> u32x8 { + x.mul32(y).into() } + + let (x0, x1) = unpack_pair(self.0[0]); + let (x2, x3) = unpack_pair(self.0[1]); + let (x4, x5) = unpack_pair(self.0[2]); + let (x6, x7) = unpack_pair(self.0[3]); + let (x8, x9) = unpack_pair(self.0[4]); + + let (y0, y1) = unpack_pair(rhs.0[0]); + let (y2, y3) = unpack_pair(rhs.0[1]); + let (y4, y5) = unpack_pair(rhs.0[2]); + let (y6, y7) = unpack_pair(rhs.0[3]); + let (y8, y9) = unpack_pair(rhs.0[4]); + + let v19 = u32x8::new(19, 0, 19, 0, 19, 0, 19, 0); + + let y1_19 = m_lo(v19, y1); // This fits in a u32 + let y2_19 = m_lo(v19, y2); // iff 26 + b + lg(19) < 32 + let y3_19 = m_lo(v19, y3); // if b < 32 - 26 - 4.248 = 1.752 + let y4_19 = m_lo(v19, y4); + let y5_19 = m_lo(v19, y5); + let y6_19 = m_lo(v19, y6); + let y7_19 = m_lo(v19, y7); + let y8_19 = m_lo(v19, y8); + let y9_19 = m_lo(v19, y9); + + let x1_2 = x1 + x1; // This fits in a u32 iff 25 + b + 1 < 32 + let x3_2 = x3 + x3; // iff b < 6 + let x5_2 = x5 + x5; + let x7_2 = x7 + x7; + let x9_2 = x9 + x9; + + let z0 = m(x0, y0) + m(x1_2, y9_19) + m(x2, y8_19) + m(x3_2, y7_19) + m(x4, y6_19) + m(x5_2, y5_19) + m(x6, y4_19) + m(x7_2, y3_19) + m(x8, y2_19) + m(x9_2, y1_19); + let z1 = m(x0, y1) + m(x1, y0) + m(x2, y9_19) + m(x3, y8_19) + m(x4, y7_19) + m(x5, y6_19) + m(x6, y5_19) + m(x7, y4_19) + m(x8, y3_19) + m(x9, y2_19); + let z2 = m(x0, y2) + m(x1_2, y1) + m(x2, y0) + m(x3_2, y9_19) + m(x4, y8_19) + m(x5_2, y7_19) + m(x6, y6_19) + m(x7_2, y5_19) + m(x8, y4_19) + m(x9_2, y3_19); + let z3 = m(x0, y3) + m(x1, y2) + m(x2, y1) + m(x3, y0) + m(x4, y9_19) + m(x5, y8_19) + m(x6, y7_19) + m(x7, y6_19) + m(x8, y5_19) + m(x9, y4_19); + let z4 = m(x0, y4) + m(x1_2, y3) + m(x2, y2) + m(x3_2, y1) + m(x4, y0) + m(x5_2, y9_19) + m(x6, y8_19) + m(x7_2, y7_19) + m(x8, y6_19) + m(x9_2, y5_19); + let z5 = m(x0, y5) + m(x1, y4) + m(x2, y3) + m(x3, y2) + m(x4, y1) + m(x5, y0) + m(x6, y9_19) + m(x7, y8_19) + m(x8, y7_19) + m(x9, y6_19); + let z6 = m(x0, y6) + m(x1_2, y5) + m(x2, y4) + m(x3_2, y3) + m(x4, y2) + m(x5_2, y1) + m(x6, y0) + m(x7_2, y9_19) + m(x8, y8_19) + m(x9_2, y7_19); + let z7 = m(x0, y7) + m(x1, y6) + m(x2, y5) + m(x3, y4) + m(x4, y3) + m(x5, y2) + m(x6, y1) + m(x7, y0) + m(x8, y9_19) + m(x9, y8_19); + let z8 = m(x0, y8) + m(x1_2, y7) + m(x2, y6) + m(x3_2, y5) + m(x4, y4) + m(x5_2, y3) + m(x6, y2) + m(x7_2, y1) + m(x8, y0) + m(x9_2, y9_19); + let z9 = m(x0, y9) + m(x1, y8) + m(x2, y7) + m(x3, y6) + m(x4, y5) + m(x5, y4) + m(x6, y3) + m(x7, y2) + m(x8, y1) + m(x9, y0); + + // The bounds on z[i] are the same as in the serial 32-bit code + // and the comment below is copied from there: + + // How big is the contribution to z[i+j] from x[i], y[j]? + // + // Using the bounds above, we get: + // + // i even, j even: x[i]*y[j] < 2^(26+b)*2^(26+b) = 2*2^(51+2*b) + // i odd, j even: x[i]*y[j] < 2^(25+b)*2^(26+b) = 1*2^(51+2*b) + // i even, j odd: x[i]*y[j] < 2^(26+b)*2^(25+b) = 1*2^(51+2*b) + // i odd, j odd: 2*x[i]*y[j] < 2*2^(25+b)*2^(25+b) = 1*2^(51+2*b) + // + // We perform inline reduction mod p by replacing 2^255 by 19 + // (since 2^255 - 19 = 0 mod p). 
This adds a factor of 19, so + // we get the bounds (z0 is the biggest one, but calculated for + // posterity here in case finer estimation is needed later): + // + // z0 < ( 2 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 )*2^(51 + 2b) = 249*2^(51 + 2*b) + // z1 < ( 1 + 1 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 )*2^(51 + 2b) = 154*2^(51 + 2*b) + // z2 < ( 2 + 1 + 2 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 )*2^(51 + 2b) = 195*2^(51 + 2*b) + // z3 < ( 1 + 1 + 1 + 1 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 )*2^(51 + 2b) = 118*2^(51 + 2*b) + // z4 < ( 2 + 1 + 2 + 1 + 2 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 )*2^(51 + 2b) = 141*2^(51 + 2*b) + // z5 < ( 1 + 1 + 1 + 1 + 1 + 1 + 1*19 + 1*19 + 1*19 + 1*19 )*2^(51 + 2b) = 82*2^(51 + 2*b) + // z6 < ( 2 + 1 + 2 + 1 + 2 + 1 + 2 + 1*19 + 2*19 + 1*19 )*2^(51 + 2b) = 87*2^(51 + 2*b) + // z7 < ( 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1*19 + 1*19 )*2^(51 + 2b) = 46*2^(51 + 2*b) + // z8 < ( 2 + 1 + 2 + 1 + 2 + 1 + 2 + 1 + 2 + 1*19 )*2^(51 + 2b) = 33*2^(51 + 2*b) + // z9 < ( 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 )*2^(51 + 2b) = 10*2^(51 + 2*b) + // + // So z[0] fits into a u64 if 51 + 2*b + lg(249) < 64 + // if b < 2.5. + + // In fact this bound is slightly sloppy, since it treats both + // inputs x and y as being bounded by the same parameter b, + // while they are in fact bounded by b_x and b_y, and we + // already require that b_y < 1.75 in order to fit the + // multiplications by 19 into a u32. The tighter bound on b_y + // means we could get a tighter bound on the outputs, or a + // looser bound on b_x. + FieldElement2625x4::reduce64([z0, z1, z2, z3, z4, z5, z6, z7, z8, z9]) } } diff --git a/src/backend/vector/packed_simd.rs b/src/backend/vector/packed_simd.rs index 201f1eb0f..371410d6f 100644 --- a/src/backend/vector/packed_simd.rs +++ b/src/backend/vector/packed_simd.rs @@ -252,9 +252,9 @@ impl u64x4 { } /// Constructs a new instance. - #[target_feature(enable = "avx2")] + #[unsafe_target_feature("avx2")] #[inline] - pub unsafe fn new(x0: u64, x1: u64, x2: u64, x3: u64) -> u64x4 { + pub fn new(x0: u64, x1: u64, x2: u64, x3: u64) -> u64x4 { unsafe { // _mm256_set_epi64 sets the underlying vector in reverse order of the args u64x4(core::arch::x86_64::_mm256_set_epi64x( @@ -264,9 +264,9 @@ impl u64x4 { } /// Constructs a new instance with all of the elements initialized to the given value. - #[target_feature(enable = "avx2")] + #[unsafe_target_feature("avx2")] #[inline] - pub unsafe fn splat(x: u64) -> u64x4 { + pub fn splat(x: u64) -> u64x4 { unsafe { u64x4(core::arch::x86_64::_mm256_set1_epi64x(x as i64)) } } } @@ -303,18 +303,9 @@ impl u32x8 { /// Constructs a new instance. #[allow(clippy::too_many_arguments)] - #[target_feature(enable = "avx2")] + #[unsafe_target_feature("avx2")] #[inline] - pub unsafe fn new( - x0: u32, - x1: u32, - x2: u32, - x3: u32, - x4: u32, - x5: u32, - x6: u32, - x7: u32, - ) -> u32x8 { + pub fn new(x0: u32, x1: u32, x2: u32, x3: u32, x4: u32, x5: u32, x6: u32, x7: u32) -> u32x8 { unsafe { // _mm256_set_epi32 sets the underlying vector in reverse order of the args u32x8(core::arch::x86_64::_mm256_set_epi32( @@ -325,9 +316,9 @@ impl u32x8 { } /// Constructs a new instance with all of the elements initialized to the given value. 
-    #[target_feature(enable = "avx2")]
+    #[unsafe_target_feature("avx2")]
     #[inline]
-    pub unsafe fn splat(x: u32) -> u32x8 {
+    pub fn splat(x: u32) -> u32x8 {
         unsafe { u32x8(core::arch::x86_64::_mm256_set1_epi32(x as i32)) }
     }
 }
diff --git a/src/backend/vector/scalar_mul/pippenger.rs b/src/backend/vector/scalar_mul/pippenger.rs
index 8ac06f8f5..b00cb87c5 100644
--- a/src/backend/vector/scalar_mul/pippenger.rs
+++ b/src/backend/vector/scalar_mul/pippenger.rs
@@ -9,183 +9,169 @@
 
 #![allow(non_snake_case)]
 
-macro_rules! implement {
-    ($module:ident, $backend_module:ident, $features:expr) => {
-        pub mod $module {
-            use alloc::vec::Vec;
-
-            use core::borrow::Borrow;
-            use core::cmp::Ordering;
-
-            use crate::backend::vector::$backend_module::{CachedPoint, ExtendedPoint};
-
-            use crate::edwards::EdwardsPoint;
-            use crate::scalar::Scalar;
-            use crate::traits::{Identity, VartimeMultiscalarMul};
-
-            /// Implements a version of Pippenger's algorithm.
-            ///
-            /// See the documentation in the serial `scalar_mul::pippenger` module for details.
-            pub struct Pippenger;
-
-            impl VartimeMultiscalarMul for Pippenger {
-                type Point = EdwardsPoint;
-
-                #[inline(always)]
-                fn optional_multiscalar_mul<I, J>(scalars: I, points: J) -> Option<EdwardsPoint>
-                where
-                    I: IntoIterator,
-                    I::Item: Borrow<Scalar>,
-                    J: IntoIterator<Item = Option<EdwardsPoint>>,
-                {
-                    #[target_feature(enable = $features)]
-                    unsafe fn inner<I, J>(scalars: I, points: J) -> Option<EdwardsPoint>
-                    where
-                        I: IntoIterator,
-                        I::Item: Borrow<Scalar>,
-                        J: IntoIterator<Item = Option<EdwardsPoint>>,
-                    {
-                        let mut scalars = scalars.into_iter();
-                        let size = scalars.by_ref().size_hint().0;
-                        let w = if size < 500 {
-                            6
-                        } else if size < 800 {
-                            7
-                        } else {
-                            8
-                        };
-
-                        let max_digit: usize = 1 << w;
-                        let digits_count: usize = Scalar::to_radix_2w_size_hint(w);
-                        let buckets_count: usize = max_digit / 2; // digits are signed+centered hence 2^w/2, excluding 0-th bucket
-
-                        // Collect optimized scalars and points in a buffer for repeated access
-                        // (scanning the whole collection per each digit position).
-                        let scalars = scalars.map(|s| s.borrow().as_radix_2w(w));
-
-                        let points = points
-                            .into_iter()
-                            .map(|p| p.map(|P| CachedPoint::from(ExtendedPoint::from(P))));
-
-                        let scalars_points = scalars
-                            .zip(points)
-                            .map(|(s, maybe_p)| maybe_p.map(|p| (s, p)))
-                            .collect::<Option<Vec<_>>>()?;
-
-                        // Prepare 2^w/2 buckets.
-                        // buckets[i] corresponds to a multiplication factor (i+1).
-                        let mut buckets: Vec<ExtendedPoint> = (0..buckets_count)
-                            .map(|_| ExtendedPoint::identity())
-                            .collect();
-
-                        let mut columns = (0..digits_count).rev().map(|digit_index| {
-                            // Clear the buckets when processing another digit.
-                            for bucket in &mut buckets {
-                                *bucket = ExtendedPoint::identity();
-                            }
-
-                            // Iterate over pairs of (point, scalar)
-                            // and add/sub the point to the corresponding bucket.
-                            // Note: if we add support for precomputed lookup tables,
-                            // we'll be adding/subtracting point premultiplied by `digits[i]` to buckets[0].
-                            for (digits, pt) in scalars_points.iter() {
-                                // Widen digit so that we don't run into edge cases when w=8.
-                                let digit = digits[digit_index] as i16;
-                                match digit.cmp(&0) {
-                                    Ordering::Greater => {
-                                        let b = (digit - 1) as usize;
-                                        buckets[b] = &buckets[b] + pt;
-                                    }
-                                    Ordering::Less => {
-                                        let b = (-digit - 1) as usize;
-                                        buckets[b] = &buckets[b] - pt;
-                                    }
-                                    Ordering::Equal => {}
-                                }
-                            }
-
-                            // Add the buckets applying the multiplication factor to each bucket.
-                            // The most efficient way to do that is to have a single sum with two running sums:
-                            // an intermediate sum from last bucket to the first, and a sum of intermediate sums.
-                            //
-                            // For example, to add buckets 1*A, 2*B, 3*C we need to add these points:
-                            //     C
-                            //     C B
-                            //     C B A   Sum = C + (C+B) + (C+B+A)
-                            let mut buckets_intermediate_sum = buckets[buckets_count - 1];
-                            let mut buckets_sum = buckets[buckets_count - 1];
-                            for i in (0..(buckets_count - 1)).rev() {
-                                buckets_intermediate_sum =
-                                    &buckets_intermediate_sum + &CachedPoint::from(buckets[i]);
-                                buckets_sum =
-                                    &buckets_sum + &CachedPoint::from(buckets_intermediate_sum);
-                            }
-
-                            buckets_sum
-                        });
-
-                        // Take the high column as an initial value to avoid wasting time doubling the identity element in `fold()`.
-                        // `unwrap()` always succeeds because we know we have more than zero digits.
-                        let hi_column = columns.next().unwrap();
-
-                        Some(
-                            columns
-                                .fold(hi_column, |total, p| {
-                                    &total.mul_by_pow_2(w as u32) + &CachedPoint::from(p)
-                                })
-                                .into(),
-                        )
-                    }
-                    unsafe { inner(scalars, points) }
-                }
-            }
-
-            #[cfg(test)]
-            #[cfg(target_feature = $features)]
-            mod test {
-                #[test]
-                fn test_vartime_pippenger() {
-                    use super::*;
-                    use crate::constants;
-                    use crate::scalar::Scalar;
-
-                    // Reuse points across different tests
-                    let mut n = 512;
-                    let x = Scalar::from(2128506u64).invert();
-                    let y = Scalar::from(4443282u64).invert();
-                    let points: Vec<_> = (0..n)
-                        .map(|i| constants::ED25519_BASEPOINT_POINT * Scalar::from(1 + i as u64))
-                        .collect();
-                    let scalars: Vec<_> = (0..n)
-                        .map(|i| x + (Scalar::from(i as u64) * y)) // fast way to make ~random but deterministic scalars
-                        .collect();
-
-                    let premultiplied: Vec<EdwardsPoint> = scalars
-                        .iter()
-                        .zip(points.iter())
-                        .map(|(sc, pt)| sc * pt)
-                        .collect();
-
-                    while n > 0 {
-                        let scalars = &scalars[0..n].to_vec();
-                        let points = &points[0..n].to_vec();
-                        let control: EdwardsPoint =
-                            premultiplied[0..n].iter().sum();
-
-                        let subject =
-                            Pippenger::vartime_multiscalar_mul(scalars.clone(), points.clone());
-
-                        assert_eq!(subject.compress(), control.compress());
-
-                        n = n / 2;
-                    }
-                }
-            }
-        }
-    };
-}
-
-#[cfg(feature = "simd_avx2")]
-implement!(spec_avx2, avx2, "avx2");
-
-#[cfg(all(feature = "simd_avx512", nightly))]
-implement!(spec_avx512ifma_avx512vl, ifma, "avx512ifma,avx512vl");
+#[unsafe_target_feature::unsafe_target_feature_specialize(
+    conditional("avx2", feature = "simd_avx2"),
+    conditional("avx512ifma,avx512vl", all(feature = "simd_avx512", nightly))
+)]
+pub mod spec {
+
+    use alloc::vec::Vec;
+
+    use core::borrow::Borrow;
+    use core::cmp::Ordering;
+
+    #[for_target_feature("avx2")]
+    use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint};
+
+    #[for_target_feature("avx512ifma")]
+    use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint};
+
+    use crate::edwards::EdwardsPoint;
+    use crate::scalar::Scalar;
+    use crate::traits::{Identity, VartimeMultiscalarMul};
+
+    /// Implements a version of Pippenger's algorithm.
+    ///
+    /// See the documentation in the serial `scalar_mul::pippenger` module for details.
+    pub struct Pippenger;
+
+    impl VartimeMultiscalarMul for Pippenger {
+        type Point = EdwardsPoint;
+
+        fn optional_multiscalar_mul<I, J>(scalars: I, points: J) -> Option<EdwardsPoint>
+        where
+            I: IntoIterator,
+            I::Item: Borrow<Scalar>,
+            J: IntoIterator<Item = Option<EdwardsPoint>>,
+        {
+            let mut scalars = scalars.into_iter();
+            let size = scalars.by_ref().size_hint().0;
+            let w = if size < 500 {
+                6
+            } else if size < 800 {
+                7
+            } else {
+                8
+            };
+
+            let max_digit: usize = 1 << w;
+            let digits_count: usize = Scalar::to_radix_2w_size_hint(w);
+            let buckets_count: usize = max_digit / 2; // digits are signed+centered hence 2^w/2, excluding 0-th bucket
+
+            // Collect optimized scalars and points in a buffer for repeated access
+            // (scanning the whole collection per each digit position).
+            let scalars = scalars.map(|s| s.borrow().as_radix_2w(w));
+
+            let points = points
+                .into_iter()
+                .map(|p| p.map(|P| CachedPoint::from(ExtendedPoint::from(P))));
+
+            let scalars_points = scalars
+                .zip(points)
+                .map(|(s, maybe_p)| maybe_p.map(|p| (s, p)))
+                .collect::<Option<Vec<_>>>()?;
+
+            // Prepare 2^w/2 buckets.
+            // buckets[i] corresponds to a multiplication factor (i+1).
+            let mut buckets: Vec<ExtendedPoint> = (0..buckets_count)
+                .map(|_| ExtendedPoint::identity())
+                .collect();
+
+            let mut columns = (0..digits_count).rev().map(|digit_index| {
+                // Clear the buckets when processing another digit.
+                for bucket in &mut buckets {
+                    *bucket = ExtendedPoint::identity();
+                }
+
+                // Iterate over pairs of (point, scalar)
+                // and add/sub the point to the corresponding bucket.
+                // Note: if we add support for precomputed lookup tables,
+                // we'll be adding/subtracting point premultiplied by `digits[i]` to buckets[0].
+                for (digits, pt) in scalars_points.iter() {
+                    // Widen digit so that we don't run into edge cases when w=8.
+                    let digit = digits[digit_index] as i16;
+                    match digit.cmp(&0) {
+                        Ordering::Greater => {
+                            let b = (digit - 1) as usize;
+                            buckets[b] = &buckets[b] + pt;
+                        }
+                        Ordering::Less => {
+                            let b = (-digit - 1) as usize;
+                            buckets[b] = &buckets[b] - pt;
+                        }
+                        Ordering::Equal => {}
+                    }
+                }
+
+                // Add the buckets applying the multiplication factor to each bucket.
+                // The most efficient way to do that is to have a single sum with two running sums:
+                // an intermediate sum from last bucket to the first, and a sum of intermediate sums.
+                //
+                // For example, to add buckets 1*A, 2*B, 3*C we need to add these points:
+                //     C
+                //     C B
+                //     C B A   Sum = C + (C+B) + (C+B+A)
+                let mut buckets_intermediate_sum = buckets[buckets_count - 1];
+                let mut buckets_sum = buckets[buckets_count - 1];
+                for i in (0..(buckets_count - 1)).rev() {
+                    buckets_intermediate_sum =
+                        &buckets_intermediate_sum + &CachedPoint::from(buckets[i]);
+                    buckets_sum = &buckets_sum + &CachedPoint::from(buckets_intermediate_sum);
+                }
+
+                buckets_sum
+            });
+
+            // Take the high column as an initial value to avoid wasting time doubling the identity element in `fold()`.
+            // `unwrap()` always succeeds because we know we have more than zero digits.
+            let hi_column = columns.next().unwrap();
+
+            Some(
+                columns
+                    .fold(hi_column, |total, p| {
+                        &total.mul_by_pow_2(w as u32) + &CachedPoint::from(p)
+                    })
+                    .into(),
+            )
+        }
+    }
+
+    #[cfg(test)]
+    mod test {
+        #[test]
+        fn test_vartime_pippenger() {
+            use super::*;
+            use crate::constants;
+            use crate::scalar::Scalar;
+
+            // Reuse points across different tests
+            let mut n = 512;
+            let x = Scalar::from(2128506u64).invert();
+            let y = Scalar::from(4443282u64).invert();
+            let points: Vec<_> = (0..n)
+                .map(|i| constants::ED25519_BASEPOINT_POINT * Scalar::from(1 + i as u64))
+                .collect();
+            let scalars: Vec<_> = (0..n)
+                .map(|i| x + (Scalar::from(i as u64) * y)) // fast way to make ~random but deterministic scalars
+                .collect();
+
+            let premultiplied: Vec<EdwardsPoint> = scalars
+                .iter()
+                .zip(points.iter())
+                .map(|(sc, pt)| sc * pt)
+                .collect();
+
+            while n > 0 {
+                let scalars = &scalars[0..n].to_vec();
+                let points = &points[0..n].to_vec();
+                let control: EdwardsPoint = premultiplied[0..n].iter().sum();
+
+                let subject = Pippenger::vartime_multiscalar_mul(scalars.clone(), points.clone());
+
+                assert_eq!(subject.compress(), control.compress());
+
+                n = n / 2;
+            }
+        }
+    }
+}

From 502897109c13f6fcf0dbf8e172ff1b43307917ab Mon Sep 17 00:00:00 2001
From: Jan Bujak
Date: Mon, 5 Jun 2023 07:40:31 +0000
Subject: [PATCH 10/11] Pin the version of `unsafe_target_feature`

---
 Cargo.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Cargo.toml b/Cargo.toml
index 1a76f8742..33bbb29a5 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -53,7 +53,7 @@ digest = { version = "0.10", default-features = false, optional = true }
 subtle = { version = "2.3.0", default-features = false }
 serde = { version = "1.0", default-features = false, optional = true, features = ["derive"] }
 zeroize = { version = "1", default-features = false, optional = true }
-unsafe_target_feature = { version = "0.1.1", optional = true }
+unsafe_target_feature = { version = "= 0.1.1", optional = true }
 
 [target.'cfg(target_arch = "x86_64")'.dependencies]
 cpufeatures = "0.2.6"

From 50aa63532b3012ce79b74f3677ee243e97e62b60 Mon Sep 17 00:00:00 2001
From: Jan Bujak
Date: Mon, 5 Jun 2023 07:42:36 +0000
Subject: [PATCH 11/11] Fix the doc comment in `packed_simd.rs`

---
 src/backend/vector/packed_simd.rs | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/backend/vector/packed_simd.rs b/src/backend/vector/packed_simd.rs
index 371410d6f..6ab5dcc9c 100644
--- a/src/backend/vector/packed_simd.rs
+++ b/src/backend/vector/packed_simd.rs
@@ -3,12 +3,12 @@
 // This file is part of curve25519-dalek.
 // See LICENSE for licensing information.
 
-///! This module defines wrappers over platform-specific SIMD types to make them
-///! more convenient to use.
-///!
-///! UNSAFETY: Everything in this module assumes that we're running on hardware
-///! which supports at least AVX2. This invariant *must* be enforced
-///! by the callers of this code.
+//! This module defines wrappers over platform-specific SIMD types to make them
+//! more convenient to use.
+//!
+//! UNSAFETY: Everything in this module assumes that we're running on hardware
+//! which supports at least AVX2. This invariant *must* be enforced
+//! by the callers of this code.
 
use core::ops::{Add, AddAssign, BitAnd, BitAndAssign, BitXor, BitXorAssign, Sub}; use unsafe_target_feature::unsafe_target_feature;
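
A short illustrative sketch (not part of the patches above) of the two pieces this
series combines: the safe-wrapper shape that `#[unsafe_target_feature("avx2")]`
expands to, which the reverted work-in-progress commit wrote out by hand, and the
runtime CPUID probe provided by the `cpufeatures` crate added in Cargo.toml. All
names below (`sketch`, `avx2_available`, `add_lanes`, `cpuid_avx2`) are hypothetical,
and the snippet assumes the `cpufeatures::new!`/`get` macro API:

    // Sketch only -- not code from this repository.
    #[cfg(target_arch = "x86_64")]
    mod sketch {
        use core::arch::x86_64::{__m256i, _mm256_add_epi32};

        // Generates a `cpuid_avx2` module whose `get()` caches a CPUID probe.
        cpufeatures::new!(cpuid_avx2, "avx2");

        /// Hypothetical helper: true if the AVX2 backend may be selected at runtime.
        pub fn avx2_available() -> bool {
            cpuid_avx2::get()
        }

        /// Roughly the shape `#[unsafe_target_feature("avx2")]` produces, and the
        /// same pattern the reverted commit wrote by hand for `mul`, `neg`, `add`:
        /// a `#[target_feature]` unsafe inner function plus a thin safe caller.
        pub fn add_lanes(a: __m256i, b: __m256i) -> __m256i {
            #[target_feature(enable = "avx2")]
            unsafe fn inner(a: __m256i, b: __m256i) -> __m256i {
                _mm256_add_epi32(a, b)
            }
            // SAFETY: callers must only reach this code after `avx2_available()`
            // returned true, mirroring the UNSAFETY contract in `packed_simd.rs`.
            unsafe { inner(a, b) }
        }
    }

Probing once and only dispatching into feature-gated code behind a positive probe is
what lets the crate keep the `#[target_feature]` backends sound while choosing among
them at runtime instead of at compile time.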