Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NEON backend for aarch64 #457

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
2 changes: 1 addition & 1 deletion curve25519-dalek/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ legacy_compatibility = []
group = ["dep:group", "rand_core"]
group-bits = ["group", "ff/bits"]

[target.'cfg(all(not(curve25519_dalek_backend = "fiat"), not(curve25519_dalek_backend = "serial"), target_arch = "x86_64"))'.dependencies]
[target.'cfg(all(not(curve25519_dalek_backend = "fiat"), not(curve25519_dalek_backend = "serial"), any(target_arch = "x86_64", target_arch = "aarch64")))'.dependencies]
curve25519-dalek-derive = { version = "0.1", path = "../curve25519-dalek-derive" }

[lints.rust.unexpected_cfgs]
Expand Down
2 changes: 1 addition & 1 deletion curve25519-dalek/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ fn main() {

// Is the target arch & curve25519_dalek_bits potentially simd capable ?
fn is_capable_simd(arch: &str, bits: DalekBits) -> bool {
arch == "x86_64" && bits == DalekBits::Dalek64
(arch == "x86_64" || arch == "aarch64") && bits == DalekBits::Dalek64
}

// Deterministic cfg(curve25519_dalek_bits) when this is not explicitly set.
Expand Down
74 changes: 54 additions & 20 deletions curve25519-dalek/src/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,22 @@ pub mod vector;

#[derive(Copy, Clone)]
enum BackendKind {
#[cfg(curve25519_dalek_backend = "simd")]
#[cfg(all(curve25519_dalek_backend = "simd", target_arch = "x86_64"))]
Avx2,
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "x86_64"))]
Avx512,
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "aarch64"))]
Neon,
Serial,
}

#[inline]
fn get_selected_backend() -> BackendKind {
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "aarch64"))]
{
return BackendKind::Neon;
}
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "x86_64"))]
{
cpufeatures::new!(cpuid_avx512, "avx512ifma", "avx512vl");
let token_avx512: cpuid_avx512::InitToken = cpuid_avx512::init();
Expand All @@ -62,7 +68,7 @@ fn get_selected_backend() -> BackendKind {
}
}

#[cfg(curve25519_dalek_backend = "simd")]
#[cfg(all(curve25519_dalek_backend = "simd", target_arch = "x86_64"))]
{
cpufeatures::new!(cpuid_avx2, "avx2");
let token_avx2: cpuid_avx2::InitToken = cpuid_avx2::init();
Expand All @@ -85,25 +91,30 @@ where
use crate::traits::VartimeMultiscalarMul;

match get_selected_backend() {
#[cfg(curve25519_dalek_backend = "simd")]
#[cfg(all(curve25519_dalek_backend = "simd", target_arch="x86_64"))]
BackendKind::Avx2 =>
vector::scalar_mul::pippenger::spec_avx2::Pippenger::optional_multiscalar_mul::<I, J>(scalars, points),
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="x86_64"))]
BackendKind::Avx512 =>
vector::scalar_mul::pippenger::spec_avx512ifma_avx512vl::Pippenger::optional_multiscalar_mul::<I, J>(scalars, points),
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="aarch64"))]
BackendKind::Neon =>
vector::scalar_mul::pippenger::spec_neon::Pippenger::optional_multiscalar_mul::<I, J>(scalars, points),
BackendKind::Serial =>
serial::scalar_mul::pippenger::Pippenger::optional_multiscalar_mul::<I, J>(scalars, points),
}
}

#[cfg(feature = "alloc")]
pub(crate) enum VartimePrecomputedStraus {
#[cfg(curve25519_dalek_backend = "simd")]
#[cfg(all(curve25519_dalek_backend = "simd", target_arch = "x86_64"))]
Avx2(vector::scalar_mul::precomputed_straus::spec_avx2::VartimePrecomputedStraus),
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "x86_64"))]
Avx512ifma(
vector::scalar_mul::precomputed_straus::spec_avx512ifma_avx512vl::VartimePrecomputedStraus,
),
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "aarch64"))]
Neon(vector::scalar_mul::precomputed_straus::spec_neon::VartimePrecomputedStraus),
Scalar(serial::scalar_mul::precomputed_straus::VartimePrecomputedStraus),
}

Expand All @@ -117,12 +128,15 @@ impl VartimePrecomputedStraus {
use crate::traits::VartimePrecomputedMultiscalarMul;

match get_selected_backend() {
#[cfg(curve25519_dalek_backend = "simd")]
#[cfg(all(curve25519_dalek_backend = "simd", target_arch="x86_64"))]
BackendKind::Avx2 =>
VartimePrecomputedStraus::Avx2(vector::scalar_mul::precomputed_straus::spec_avx2::VartimePrecomputedStraus::new(static_points)),
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="x86_64"))]
BackendKind::Avx512 =>
VartimePrecomputedStraus::Avx512ifma(vector::scalar_mul::precomputed_straus::spec_avx512ifma_avx512vl::VartimePrecomputedStraus::new(static_points)),
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="aarch64"))]
BackendKind::Neon =>
VartimePrecomputedStraus::Neon(vector::scalar_mul::precomputed_straus::spec_neon::VartimePrecomputedStraus::new(static_points)),
BackendKind::Serial =>
VartimePrecomputedStraus::Scalar(serial::scalar_mul::precomputed_straus::VartimePrecomputedStraus::new(static_points))
}
Expand All @@ -144,18 +158,24 @@ impl VartimePrecomputedStraus {
use crate::traits::VartimePrecomputedMultiscalarMul;

match self {
#[cfg(curve25519_dalek_backend = "simd")]
#[cfg(all(curve25519_dalek_backend = "simd", target_arch = "x86_64"))]
VartimePrecomputedStraus::Avx2(inner) => inner.optional_mixed_multiscalar_mul(
static_scalars,
dynamic_scalars,
dynamic_points,
),
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "x86_64"))]
VartimePrecomputedStraus::Avx512ifma(inner) => inner.optional_mixed_multiscalar_mul(
static_scalars,
dynamic_scalars,
dynamic_points,
),
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "aarch64"))]
VartimePrecomputedStraus::Neon(inner) => inner.optional_mixed_multiscalar_mul(
static_scalars,
dynamic_scalars,
dynamic_points,
),
VartimePrecomputedStraus::Scalar(inner) => inner.optional_mixed_multiscalar_mul(
static_scalars,
dynamic_scalars,
Expand All @@ -177,16 +197,20 @@ where
use crate::traits::MultiscalarMul;

match get_selected_backend() {
#[cfg(curve25519_dalek_backend = "simd")]
#[cfg(all(curve25519_dalek_backend = "simd", target_arch = "x86_64"))]
BackendKind::Avx2 => {
vector::scalar_mul::straus::spec_avx2::Straus::multiscalar_mul::<I, J>(scalars, points)
}
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "x86_64"))]
BackendKind::Avx512 => {
vector::scalar_mul::straus::spec_avx512ifma_avx512vl::Straus::multiscalar_mul::<I, J>(
scalars, points,
)
}
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "aarch64"))]
BackendKind::Neon => {
vector::scalar_mul::straus::spec_neon::Straus::multiscalar_mul::<I, J>(scalars, points)
}
BackendKind::Serial => {
serial::scalar_mul::straus::Straus::multiscalar_mul::<I, J>(scalars, points)
}
Expand All @@ -204,19 +228,25 @@ where
use crate::traits::VartimeMultiscalarMul;

match get_selected_backend() {
#[cfg(curve25519_dalek_backend = "simd")]
#[cfg(all(curve25519_dalek_backend = "simd", target_arch = "x86_64"))]
BackendKind::Avx2 => {
vector::scalar_mul::straus::spec_avx2::Straus::optional_multiscalar_mul::<I, J>(
scalars, points,
)
}
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "x86_64"))]
BackendKind::Avx512 => {
vector::scalar_mul::straus::spec_avx512ifma_avx512vl::Straus::optional_multiscalar_mul::<
I,
J,
>(scalars, points)
}
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "aarch64"))]
BackendKind::Neon => {
vector::scalar_mul::straus::spec_neon::Straus::optional_multiscalar_mul::<I, J>(
scalars, points,
)
}
BackendKind::Serial => {
serial::scalar_mul::straus::Straus::optional_multiscalar_mul::<I, J>(scalars, points)
}
Expand All @@ -226,12 +256,14 @@ where
/// Perform constant-time, variable-base scalar multiplication.
pub fn variable_base_mul(point: &EdwardsPoint, scalar: &Scalar) -> EdwardsPoint {
match get_selected_backend() {
#[cfg(curve25519_dalek_backend = "simd")]
#[cfg(all(curve25519_dalek_backend = "simd", target_arch = "x86_64"))]
BackendKind::Avx2 => vector::scalar_mul::variable_base::spec_avx2::mul(point, scalar),
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "x86_64"))]
BackendKind::Avx512 => {
vector::scalar_mul::variable_base::spec_avx512ifma_avx512vl::mul(point, scalar)
}
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "aarch64"))]
BackendKind::Neon => vector::scalar_mul::variable_base::spec_neon::mul(point, scalar),
BackendKind::Serial => serial::scalar_mul::variable_base::mul(point, scalar),
}
}
Expand All @@ -240,12 +272,14 @@ pub fn variable_base_mul(point: &EdwardsPoint, scalar: &Scalar) -> EdwardsPoint
#[allow(non_snake_case)]
pub fn vartime_double_base_mul(a: &Scalar, A: &EdwardsPoint, b: &Scalar) -> EdwardsPoint {
match get_selected_backend() {
#[cfg(curve25519_dalek_backend = "simd")]
#[cfg(all(curve25519_dalek_backend = "simd", target_arch = "x86_64"))]
BackendKind::Avx2 => vector::scalar_mul::vartime_double_base::spec_avx2::mul(a, A, b),
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "x86_64"))]
BackendKind::Avx512 => {
vector::scalar_mul::vartime_double_base::spec_avx512ifma_avx512vl::mul(a, A, b)
}
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch = "aarch64"))]
BackendKind::Neon => vector::scalar_mul::vartime_double_base::spec_neon::mul(a, A, b),
BackendKind::Serial => serial::scalar_mul::vartime_double_base::mul(a, A, b),
}
}
7 changes: 6 additions & 1 deletion curve25519-dalek/src/backend/vector/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,16 @@
#![doc = include_str!("../../../docs/parallel-formulas.md")]

#[allow(missing_docs)]
#[cfg(target_arch = "x86_64")]
pub mod packed_simd;

#[cfg(target_arch = "x86_64")]
pub mod avx2;

#[cfg(nightly)]
#[cfg(all(nightly, target_arch = "x86_64"))]
pub mod ifma;

#[cfg(all(nightly, target_arch = "aarch64"))]
pub mod neon;

pub mod scalar_mul;
Loading
Loading