diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3f167e3..37846d4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -33,7 +33,6 @@ jobs: uses: actions-rs/cargo@v1 with: command: test - args: --features fuzz-tests - name: Run cargo fmt uses: actions-rs/cargo@v1 diff --git a/light-poseidon/Cargo.toml b/light-poseidon/Cargo.toml index b952bb5..deea8d9 100644 --- a/light-poseidon/Cargo.toml +++ b/light-poseidon/Cargo.toml @@ -19,10 +19,6 @@ criterion = "0.5" rand = "0.8" hex = "0.4.3" -[features] -fuzz-tests = [] - - [[bench]] name = "bn254_x5" harness = false diff --git a/light-poseidon/src/lib.rs b/light-poseidon/src/lib.rs index 0b5792d..99c5d5f 100644 --- a/light-poseidon/src/lib.rs +++ b/light-poseidon/src/lib.rs @@ -415,6 +415,11 @@ impl PoseidonHasher for Poseidon { impl PoseidonBytesHasher for Poseidon { fn hash_bytes_le(&mut self, inputs: &[&[u8]]) -> Result<[u8; HASH_LEN], PoseidonError> { + let inputs: Result, _> = inputs + .iter() + .map(|input| validate_bytes_length::(input)) + .collect(); + let inputs = inputs?; let inputs: Result, _> = inputs .iter() .map(|input| bytes_to_prime_field_element(input)) @@ -429,6 +434,11 @@ impl PoseidonBytesHasher for Poseidon { } fn hash_bytes_be(&mut self, inputs: &[&[u8]]) -> Result<[u8; HASH_LEN], PoseidonError> { + let inputs: Result, _> = inputs + .iter() + .map(|input| validate_bytes_length::(input)) + .collect(); + let inputs = inputs?; let inputs: Result, _> = inputs .iter() .map(|input| { @@ -448,22 +458,17 @@ impl PoseidonBytesHasher for Poseidon { } } -/// Converts a slice of bytes into a prime field element, represented by the -/// [`ark_ff::PrimeField`](ark_ff::PrimeField)) trait. +/// Checks whether a slice of bytes is not empty or its length does not exceed +/// the modulus size od the prime field. If it does, an error is returned. /// /// # Safety /// -/// Unlike the /// [`PrimeField::from_be_bytes_mod_order`](ark_ff::PrimeField::from_be_bytes_mod_order) -/// and [`Field::from_random_bytes`](ark_ff::Field::from_random_bytes) -/// methods, this function ensures that the input byte slice's length exactly matches -/// the modulus size of the prime field. If the size doesn't match, an error is returned. -/// -/// This strict check is designed to prevent unexpected behaviors and collisions -/// that might occur when using `from_be_bytes_mod_order` or `from_random_bytes`, -/// which simply take a subslice of the input if it's too large, potentially -/// leading to collisions. -fn bytes_to_prime_field_element(input: &[u8]) -> Result +/// just takes a subslice of the input if it's too large, potentially leading +/// to collisions. The purpose of this function is to prevent them by returning +/// and error. It should be always used before converting byte slices to +/// prime field elements. +fn validate_bytes_length(input: &[u8]) -> Result<&[u8], PoseidonError> where F: PrimeField, { @@ -477,6 +482,15 @@ where modulus_bytes_len, }); } + Ok(input) +} + +/// Converts a slice of bytes into a prime field element, represented by the +/// [`ark_ff::PrimeField`](ark_ff::PrimeField)) trait. +fn bytes_to_prime_field_element(input: &[u8]) -> Result +where + F: PrimeField, +{ F::from_random_bytes(input).ok_or(PoseidonError::InputLargerThanModulus) } diff --git a/light-poseidon/tests/bn254_fq_x5.rs b/light-poseidon/tests/bn254_fq_x5.rs index 532c968..2460ba1 100644 --- a/light-poseidon/tests/bn254_fq_x5.rs +++ b/light-poseidon/tests/bn254_fq_x5.rs @@ -1,7 +1,8 @@ use ark_bn254::Fr; -use ark_ff::{BigInteger, One, PrimeField, Zero}; +use ark_ff::{BigInteger, BigInteger256, One, PrimeField, Zero}; use light_poseidon::{Poseidon, PoseidonError}; use light_poseidon::{PoseidonBytesHasher, PoseidonHasher}; +use rand::Rng; #[test] fn test_poseidon_one() { @@ -148,128 +149,120 @@ fn test_poseidon_bn254_x5_fq_hash_bytes_le() { ); } -#[cfg(feature = "fuzz-tests")] -mod fuzz_tests { - use ark_ff::BigInteger256; - use rand::Rng; - - use super::*; - - macro_rules! test_random_input_same_results { - ($name:ident, $method:ident) => { - #[test] - fn $name() { - let input = [1u8; 32]; - - for nr_inputs in 1..12 { - let mut hasher = Poseidon::::new_circom(nr_inputs).unwrap(); - - let mut inputs = Vec::with_capacity(nr_inputs); - for _ in 0..nr_inputs { - inputs.push(input.as_slice()); - } +macro_rules! test_random_input_same_results { + ($name:ident, $method:ident) => { + #[test] + fn $name() { + let input = [1u8; 32]; - let hash1 = hasher.$method(inputs.as_slice()).unwrap(); - let hash2 = hasher.$method(inputs.as_slice()).unwrap(); + for nr_inputs in 1..12 { + let mut hasher = Poseidon::::new_circom(nr_inputs).unwrap(); - assert_eq!(hash1, hash2); + let mut inputs = Vec::with_capacity(nr_inputs); + for _ in 0..nr_inputs { + inputs.push(input.as_slice()); } - } - }; - } - - test_random_input_same_results!( - test_poseidon_bn254_x5_fq_hash_bytes_be_random_input_same_results, - hash_bytes_be - ); - test_random_input_same_results!( - test_poseidon_bn254_x5_fq_hash_bytes_le_random_input_same_results, - hash_bytes_le - ); + let hash1 = hasher.$method(inputs.as_slice()).unwrap(); + let hash2 = hasher.$method(inputs.as_slice()).unwrap(); - macro_rules! test_invalid_input_length { - ($name:ident, $method:ident) => { - #[test] - fn $name() { - let mut rng = rand::thread_rng(); - - for _ in 0..100 { - let len = rng.gen_range(33..524_288_000); // Maximum 500 MB. - let input = vec![1u8; len]; - - for nr_inputs in 1..12 { - let mut hasher = Poseidon::::new_circom(nr_inputs).unwrap(); - - let mut inputs = Vec::with_capacity(nr_inputs); - for _ in 0..nr_inputs { - inputs.push(input.as_slice()); - } - - let hash = hasher.$method(inputs.as_slice()); - assert_eq!( - hash, - Err(PoseidonError::InvalidInputLength { - len, - modulus_bytes_len: 32, - }) - ); - } - } + assert_eq!(hash1, hash2); } - }; - } + } + }; +} - test_invalid_input_length!( - test_poseidon_bn254_x5_fq_hash_bytes_be_invalid_input_length, - hash_bytes_be - ); +test_random_input_same_results!( + test_poseidon_bn254_x5_fq_hash_bytes_be_random_input_same_results, + hash_bytes_be +); - test_invalid_input_length!( - test_poseidon_bn254_x5_fq_hash_bytes_le_invalid_input_length, - hash_bytes_le - ); +test_random_input_same_results!( + test_poseidon_bn254_x5_fq_hash_bytes_le_random_input_same_results, + hash_bytes_le +); - macro_rules! test_input_gt_field_size { - ($name:ident, $method:ident, $to_bytes_method:ident) => { - #[test] - fn $name() { - let mut greater_than_field_size = Fr::MODULUS; - let mut rng = rand::thread_rng(); - let random_number = rng.gen_range(1u64..1_000_000u64); - greater_than_field_size.add_with_carry(&BigInteger256::from(random_number)); - let greater_than_field_size = greater_than_field_size.$to_bytes_method(); +macro_rules! test_invalid_input_length { + ($name:ident, $method:ident) => { + #[test] + fn $name() { + let mut rng = rand::thread_rng(); - assert_eq!(greater_than_field_size.len(), 32); + for _ in 0..100 { + let len = rng.gen_range(33..524_288_000); // Maximum 500 MB. + let input = vec![1u8; len]; for nr_inputs in 1..12 { let mut hasher = Poseidon::::new_circom(nr_inputs).unwrap(); let mut inputs = Vec::with_capacity(nr_inputs); for _ in 0..nr_inputs { - inputs.push(&greater_than_field_size[..]); + inputs.push(input.as_slice()); } let hash = hasher.$method(inputs.as_slice()); - assert_eq!(hash, Err(PoseidonError::InputLargerThanModulus)); + assert_eq!( + hash, + Err(PoseidonError::InvalidInputLength { + len, + modulus_bytes_len: 32, + }) + ); } } - }; - } + } + }; +} - test_input_gt_field_size!( - test_poseidon_bn254_fq_hash_bytes_be_input_gt_field_size, - hash_bytes_be, - to_bytes_be - ); +test_invalid_input_length!( + test_poseidon_bn254_x5_fq_hash_bytes_be_invalid_input_length, + hash_bytes_be +); - test_input_gt_field_size!( - test_poseidon_bn254_fq_hash_bytes_le_input_gt_field_size, - hash_bytes_le, - to_bytes_le - ); +test_invalid_input_length!( + test_poseidon_bn254_x5_fq_hash_bytes_le_invalid_input_length, + hash_bytes_le +); + +macro_rules! test_fuzz_input_gt_field_size { + ($name:ident, $method:ident, $to_bytes_method:ident) => { + #[test] + fn $name() { + let mut greater_than_field_size = Fr::MODULUS; + let mut rng = rand::thread_rng(); + let random_number = rng.gen_range(1u64..1_000_000u64); + greater_than_field_size.add_with_carry(&BigInteger256::from(random_number)); + let greater_than_field_size = greater_than_field_size.$to_bytes_method(); + + assert_eq!(greater_than_field_size.len(), 32); + + for nr_inputs in 1..12 { + let mut hasher = Poseidon::::new_circom(nr_inputs).unwrap(); + + let mut inputs = Vec::with_capacity(nr_inputs); + for _ in 0..nr_inputs { + inputs.push(&greater_than_field_size[..]); + } + + let hash = hasher.$method(inputs.as_slice()); + assert_eq!(hash, Err(PoseidonError::InputLargerThanModulus)); + } + } + }; } +test_fuzz_input_gt_field_size!( + test_fuzz_poseidon_bn254_fq_hash_bytes_be_input_gt_field_size, + hash_bytes_be, + to_bytes_be +); + +test_fuzz_input_gt_field_size!( + test_fuzz_poseidon_bn254_fq_hash_bytes_le_input_gt_field_size, + hash_bytes_le, + to_bytes_le +); + macro_rules! test_input_gt_field_size { ($name:ident, $method:ident, $greater_than_field_size:expr) => { #[test]