Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
support only u64 sized primes for smallfp
  • Loading branch information
z-tech committed Mar 30, 2026
commit 869d61ea995d0c587c2a7c48aa81be4fb366a85c
2 changes: 1 addition & 1 deletion ff-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub fn define_field(input: TokenStream) -> TokenStream {

let name = args.name;
let config_name = format_ident!("{}Config", name);
let is_small_modulus = modulus_big < (BigUint::from(1u128) << 127);
let is_small_modulus = modulus_big < (BigUint::from(1u128) << 64);

if is_small_modulus {
let modulus_u128: u128 = args
Expand Down
8 changes: 4 additions & 4 deletions ff-macros/src/small_fp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ pub(crate) fn small_fp_config_helper(
m if m < 1u128 << 8 => quote! { u8 },
m if m < 1u128 << 16 => quote! { u16 },
m if m < 1u128 << 32 => quote! { u32 },
m if m < 1u128 << 64 => quote! { u64 },
_ => quote! { u128 },
_ => quote! { u64 },
};

assert!(modulus < 1u128 << 127,
"SmallFpConfig montgomery backend supports only moduli < 2^127. Use MontConfig with BigInt instead of SmallFp."
assert!(
modulus < 1u128 << 64,
"SmallFpConfig supports only moduli < 2^64. Use MontConfig with BigInt instead of SmallFp."
);

let (backend_impl, r_mod_p) = montgomery_backend::backend_impl(&ty, modulus, generator);
Expand Down
135 changes: 13 additions & 122 deletions ff-macros/src/small_fp/montgomery_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,36 +15,18 @@ pub(crate) fn backend_impl(
"modulus must be odd for Montgomery multiplication"
);
assert!(
modulus < (1u128 << 127),
"modulus must be < 2^127 for u128-backed SmallFp"
modulus < (1u128 << 64),
"modulus must be < 2^64 for SmallFp"
);

let ty_str = ty.to_string();
let is_u128 = ty_str == "u128";

// For u128, we use R = 2^128 for smaller types, R = 2^k_bits
let k_bits = if is_u128 {
128u32
} else {
128 - modulus.leading_zeros()
};
let r: u128 = if k_bits == 128 {
0u128
} else {
1u128 << k_bits
};
// When R = 2^128 this doesn't fit in u128 but:
// (2^128 - n) mod n = 2^128 mod n
// and in u128 wrapping arithmetic:
// 0 - n wraps to 2^128 - n
// so:
// 2^128 mod n = (0 - n) mod n
let r_mod_n = if k_bits == 128 {
0u128.wrapping_sub(modulus) % modulus
} else {
r % modulus
};
let r_mask = if k_bits == 128 { u128::MAX } else { r - 1 };
// R = 2^k_bits where k_bits = ceil(log2(modulus))
// Since modulus < 2^64, k_bits <= 64 and R always fits in u128
let k_bits = 128 - modulus.leading_zeros();
let r: u128 = 1u128 << k_bits;
let r_mod_n = r % modulus;
let r_mask = r - 1;

let n_prime = mod_inverse_pow2(modulus, k_bits);
let one_mont = r_mod_n;
Expand All @@ -68,7 +50,6 @@ pub(crate) fn backend_impl(
"u16" => 16u32,
"u32" => 32u32,
"u64" => 64u32,
"u128" => 128u32,
_ => panic!("unsupported type"),
};

Expand Down Expand Up @@ -214,9 +195,8 @@ pub(crate) fn backend_impl(
(ts, r_mod_n)
}

// Selects the appropriate multiplication algorithm at compile time:
// if modulus <= u64, multiply by casting to the next largest primitive
// otherwise, multiply in parts to form a 256-bit product
// Selects the appropriate multiplication algorithm at compile time
// by widening to the next-largest primitive type for the product
fn generate_mul_impl(
ty: &proc_macro2::TokenStream,
modulus: u128,
Expand All @@ -227,99 +207,13 @@ fn generate_mul_impl(
let ty_str = ty.to_string();

match ty_str.as_str() {
"u128" => generate_u128_mul(modulus, n_prime),
"u64" => generate_u64_mul(modulus, k_bits, r_mask, n_prime),
"u32" => generate_u32_mul(modulus, k_bits, r_mask, n_prime),
"u8" | "u16" => generate_small_mul(ty, ty_str.as_str(), modulus, k_bits, r_mask, n_prime),
_ => panic!("Unsupported type: {}", ty_str),
}
}

// Montgomery multiplication for 2 limbs (similar to ff-asm/src/lib.rs)
fn generate_u128_mul(modulus: u128, n_prime: u128) -> proc_macro2::TokenStream {
let modulus_lo = (modulus & 0xFFFFFFFFFFFFFFFF) as u64;
let modulus_hi = (modulus >> 64) as u64;

quote! {
#[inline(always)]
fn mul_assign(a: &mut SmallFp<Self>, b: &SmallFp<Self>) {
const MODULUS: [u64; 2] = [#modulus_lo, #modulus_hi];
const INV: u64 = #n_prime as u64;

let a_limbs = [a.value as u64, (a.value >> 64) as u64];
let b_limbs = [b.value as u64, (b.value >> 64) as u64];

#[inline(always)]
fn umul(a: u64, b: u64) -> (u64, u64) {
let full = (a as u128) * (b as u128);
(full as u64, (full >> 64) as u64)
}

// r accumulator: 3 words (r[0], r[1], r[2]) to hold intermediate
let mut r0: u64 = 0;
let mut r1: u64 = 0;
let mut r2: u64 = 0;


let (lo, hi) = umul(a_limbs[0], b_limbs[0]);
let (r0_new, c) = r0.overflowing_add(lo);
r0 = r0_new;
let carry1 = c as u64;

let (lo, hi2) = umul(a_limbs[1], b_limbs[0]);
let (r1_new, c1) = r1.overflowing_add(lo);
let (r1_new, c2) = r1_new.overflowing_add(hi + carry1);
r1 = r1_new;
r2 = r2.wrapping_add(hi2).wrapping_add(c1 as u64 + c2 as u64);

let m = r0.wrapping_mul(INV);

let (lo, hi) = umul(m, MODULUS[0]);
let (_, c) = r0.overflowing_add(lo); // r0 + lo should be 0 mod 2^64
let carry = hi.wrapping_add(c as u64);

let (lo, hi) = umul(m, MODULUS[1]);
let (new_r0, c1) = r1.overflowing_add(lo);
let (new_r0, c2) = new_r0.overflowing_add(carry);
r0 = new_r0;
r1 = r2.wrapping_add(hi).wrapping_add(c1 as u64 + c2 as u64);
r2 = 0;


let (lo, hi) = umul(a_limbs[0], b_limbs[1]);
let (r0_new, c) = r0.overflowing_add(lo);
r0 = r0_new;
let carry1 = c as u64;

let (lo, hi2) = umul(a_limbs[1], b_limbs[1]);
let (r1_new, c1) = r1.overflowing_add(lo);
let (r1_new, c2) = r1_new.overflowing_add(hi + carry1);
r1 = r1_new;
r2 = r2.wrapping_add(hi2).wrapping_add(c1 as u64 + c2 as u64);

let m = r0.wrapping_mul(INV);

let (lo, hi) = umul(m, MODULUS[0]);
let (_, c) = r0.overflowing_add(lo);
let carry = hi.wrapping_add(c as u64);

let (lo, hi) = umul(m, MODULUS[1]);
let (new_r0, c1) = r1.overflowing_add(lo);
let (new_r0, c2) = new_r0.overflowing_add(carry);
r0 = new_r0;
r1 = r2.wrapping_add(hi).wrapping_add(c1 as u64 + c2 as u64);


let mut result = (r0 as u128) | ((r1 as u128) << 64);
let modulus_val = (#modulus_lo as u128) | ((#modulus_hi as u128) << 64);
if result >= modulus_val {
result -= modulus_val;
}
a.value = result;
}
}
}

fn generate_u64_mul(
modulus: u128,
k_bits: u32,
Expand Down Expand Up @@ -483,11 +377,8 @@ fn mod_inverse_pow2(n: u128, k_bits: u32) -> u128 {
for _ in 0..ITER {
inv = inv.wrapping_mul(2u128.wrapping_sub(n.wrapping_mul(inv)));
}
let mask = if k_bits == 128 {
u128::MAX
} else {
(1u128 << k_bits) - 1
};
// k_bits <= 64 since modulus < 2^64
let mask = (1u128 << k_bits) - 1;
inv.wrapping_neg() & mask
}

Expand All @@ -504,7 +395,7 @@ pub(crate) fn exit_impl(modulus: u128, r_mod_p: u128) -> proc_macro2::TokenStrea
const R_MOD_P: u128 = #r_mod_p;

// const-compatible modular multiplication via double-and-add
// Safe from overflow: modulus < 2^127 so a,result < 2^127 and all additions fit u128
// Safe from overflow: modulus < 2^64 so a,result < 2^64 and all additions fit u128
const fn mod_mul(mut a: u128, mut b: u128, m: u128) -> u128 {
a %= m;
let mut result = 0u128;
Expand Down
19 changes: 7 additions & 12 deletions ff-macros/src/small_fp/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,26 +83,23 @@ pub(crate) fn generate_montgomery_bigint_casts(
) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
(
quote! {
fn from_bigint(a: ark_ff::BigInt<2>) -> Option<SmallFp<Self>> {
let val = (a.0[0] as u128) + ((a.0[1] as u128) << 64);
fn from_bigint(a: ark_ff::BigInt<1>) -> Option<SmallFp<Self>> {
let val = a.0[0] as u128;
if val >= Self::MODULUS_U128 {
None
} else {
let reduced_val = val % Self::MODULUS_U128;
let val_t = Self::T::try_from(reduced_val).ok().unwrap();
let val_t = Self::T::try_from(val).ok().unwrap();
Some(Self::new(val_t))
}
}
},
quote! {
fn into_bigint(a: SmallFp<Self>) -> ark_ff::BigInt<2> {
fn into_bigint(a: SmallFp<Self>) -> ark_ff::BigInt<1> {
let mut tmp = a;
let one = SmallFp::from_raw(1 as Self::T);
Self::mul_assign(&mut tmp, &one);
let val = tmp.value as u128;
let lo = val as u64;
let hi = (val >> 64) as u64;
ark_ff::BigInt([lo, hi])
ark_ff::BigInt([val as u64])
}
},
)
Expand All @@ -116,12 +113,11 @@ pub(crate) fn generate_sqrt_precomputation(
if modulus % 4 == 3 {
let modulus_plus_one_div_four = (modulus + 1) / 4;
let lo = modulus_plus_one_div_four as u64;
let hi = (modulus_plus_one_div_four >> 64) as u64;

quote! {
// Case3Mod4 square root precomputation
const SQRT_PRECOMP: Option<ark_ff::SqrtPrecomputation<SmallFp<Self>>> = {
const MODULUS_PLUS_ONE_DIV_FOUR: [u64; 2] = [#lo, #hi];
const MODULUS_PLUS_ONE_DIV_FOUR: [u64; 1] = [#lo];
Some(ark_ff::SqrtPrecomputation::Case3Mod4 {
modulus_plus_one_div_four: &MODULUS_PLUS_ONE_DIV_FOUR,
})
Expand All @@ -131,7 +127,6 @@ pub(crate) fn generate_sqrt_precomputation(
let trace = (modulus - 1) >> two_adicity;
let trace_minus_one_div_two = trace / 2;
let lo = trace_minus_one_div_two as u64;
let hi = (trace_minus_one_div_two >> 64) as u64;
let qnr = find_quadratic_non_residue(modulus);
let qnr_to_trace = match r_mod_n {
None => pow_mod_const(qnr, trace, modulus),
Expand All @@ -141,7 +136,7 @@ pub(crate) fn generate_sqrt_precomputation(
quote! {
// TonelliShanks square root precomputation
const SQRT_PRECOMP: Option<ark_ff::SqrtPrecomputation<SmallFp<Self>>> = {
const TRACE_MINUS_ONE_DIV_TWO: [u64; 2] = [#lo, #hi];
const TRACE_MINUS_ONE_DIV_TWO: [u64; 1] = [#lo];
Some(ark_ff::SqrtPrecomputation::TonelliShanks {
two_adicity: #two_adicity,
quadratic_nonresidue_to_trace: SmallFp::from_raw(#qnr_to_trace as Self::T),
Expand Down
2 changes: 1 addition & 1 deletion ff/src/fields/models/small_fp/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl<P: SmallFpConfig> Field for SmallFp<P> {
None
} else {
let shave_bits = Self::num_bits_to_shave();
let mut result_bytes: crate::const_helpers::SerBuffer<2> =
let mut result_bytes: crate::const_helpers::SerBuffer<1> =
crate::const_helpers::SerBuffer::zeroed();
// Copy the input into a temporary buffer.
result_bytes.copy_from_u8_slice(bytes);
Expand Down
6 changes: 3 additions & 3 deletions ff/src/fields/models/small_fp/from.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,14 @@ impl<P: SmallFpConfig> From<SmallFp<P>> for num_bigint::BigUint {
}
}

impl<P: SmallFpConfig> From<SmallFp<P>> for BigInt<2> {
impl<P: SmallFpConfig> From<SmallFp<P>> for BigInt<1> {
fn from(fp: SmallFp<P>) -> Self {
fp.into_bigint()
}
}

impl<P: SmallFpConfig> From<BigInt<2>> for SmallFp<P> {
fn from(int: BigInt<2>) -> Self {
impl<P: SmallFpConfig> From<BigInt<1>> for SmallFp<P> {
fn from(int: BigInt<1>) -> Self {
Self::from_bigint(int).unwrap()
}
}
24 changes: 10 additions & 14 deletions ff/src/fields/models/small_fp/small_fp_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pub trait SmallFpConfig: Send + Sync + 'static + Sized {

// TODO: the value can be 1 or 2, it would be nice to have it generic.
/// Number of bigint limbs used to represent the field elements.
const NUM_BIG_INT_LIMBS: usize = 2;
const NUM_BIG_INT_LIMBS: usize = 1;

/// A multiplicative generator of the field.
/// `Self::GENERATOR` is an element having multiplicative order
Expand Down Expand Up @@ -120,11 +120,11 @@ pub trait SmallFpConfig: Send + Sync + 'static + Sized {
/// Construct a field element from an integer in the range
/// `0..(Self::MODULUS - 1)`. Returns `None` if the integer is outside
/// this range.
fn from_bigint(other: BigInt<2>) -> Option<SmallFp<Self>>;
fn from_bigint(other: BigInt<1>) -> Option<SmallFp<Self>>;

/// Convert a field element to an integer in the range `0..(Self::MODULUS -
/// 1)`.
fn into_bigint(other: SmallFp<Self>) -> BigInt<2>;
fn into_bigint(other: SmallFp<Self>) -> BigInt<1>;
}

/// Represents an element of the prime field F_p, where `p == P::MODULUS`.
Expand Down Expand Up @@ -219,10 +219,8 @@ impl<P: SmallFpConfig> AdditiveGroup for SmallFp<P> {
}
}

const fn const_to_bigint(value: u128) -> BigInt<2> {
let low = (value & 0xFFFFFFFFFFFFFFFF) as u64;
let high = (value >> 64) as u64;
BigInt::<2>::new([low, high])
const fn const_to_bigint(value: u128) -> BigInt<1> {
BigInt::<1>::new([value as u64])
}

const fn const_num_bits_u128(value: u128) -> u32 {
Expand All @@ -238,13 +236,12 @@ const fn primitive_type_bit_size(modulus_u128: u128) -> usize {
x if x <= u8::MAX as u128 => 8,
x if x <= u16::MAX as u128 => 16,
x if x <= u32::MAX as u128 => 32,
x if x <= u64::MAX as u128 => 64,
_ => 128,
_ => 64,
}
}

impl<P: SmallFpConfig> PrimeField for SmallFp<P> {
type BigInt = BigInt<2>;
type BigInt = BigInt<1>;

const MODULUS: Self::BigInt = const_to_bigint(P::MODULUS_U128);
const MODULUS_MINUS_ONE_DIV_TWO: Self::BigInt = Self::MODULUS.divide_by_2_round_down();
Expand All @@ -253,11 +250,11 @@ impl<P: SmallFpConfig> PrimeField for SmallFp<P> {
const TRACE_MINUS_ONE_DIV_TWO: Self::BigInt = Self::TRACE.divide_by_2_round_down();

#[inline]
fn from_bigint(r: BigInt<2>) -> Option<Self> {
fn from_bigint(r: BigInt<1>) -> Option<Self> {
P::from_bigint(r)
}

fn into_bigint(self) -> BigInt<2> {
fn into_bigint(self) -> BigInt<1> {
P::into_bigint(self)
}
}
Expand Down Expand Up @@ -320,8 +317,7 @@ impl<P: SmallFpConfig> ark_std::rand::distributions::Distribution<SmallFp<P>>
modulus if modulus <= u8::MAX as u128 => sample_loop!(u8),
modulus if modulus <= u16::MAX as u128 => sample_loop!(u16),
modulus if modulus <= u32::MAX as u128 => sample_loop!(u32),
modulus if modulus <= u64::MAX as u128 => sample_loop!(u64),
_ => sample_loop!(u128),
_ => sample_loop!(u64),
}
}
}
Expand Down
Loading
Loading