1 change: 0 additions & 1 deletion .github/workflows/test.yml
@@ -33,7 +33,6 @@ jobs:
uses: actions-rs/cargo@v1
with:
command: test
args: --features fuzz-tests

- name: Run cargo fmt
uses: actions-rs/cargo@v1
4 changes: 0 additions & 4 deletions light-poseidon/Cargo.toml
@@ -19,10 +19,6 @@ criterion = "0.5"
rand = "0.8"
hex = "0.4.3"

[features]
fuzz-tests = []


[[bench]]
name = "bn254_x5"
harness = false
38 changes: 26 additions & 12 deletions light-poseidon/src/lib.rs
@@ -415,6 +415,11 @@ impl<F: PrimeField> PoseidonHasher<F> for Poseidon<F> {

impl<F: PrimeField> PoseidonBytesHasher for Poseidon<F> {
fn hash_bytes_le(&mut self, inputs: &[&[u8]]) -> Result<[u8; HASH_LEN], PoseidonError> {
let inputs: Result<Vec<_>, _> = inputs
.iter()
.map(|input| validate_bytes_length::<F>(input))
.collect();
let inputs = inputs?;
let inputs: Result<Vec<_>, _> = inputs
.iter()
.map(|input| bytes_to_prime_field_element(input))
@@ -429,6 +434,11 @@ impl<F: PrimeField> PoseidonBytesHasher for Poseidon<F> {
}

fn hash_bytes_be(&mut self, inputs: &[&[u8]]) -> Result<[u8; HASH_LEN], PoseidonError> {
let inputs: Result<Vec<_>, _> = inputs
.iter()
.map(|input| validate_bytes_length::<F>(input))
.collect();
let inputs = inputs?;
let inputs: Result<Vec<_>, _> = inputs
.iter()
.map(|input| {
@@ -448,22 +458,17 @@ impl<F: PrimeField> PoseidonBytesHasher for Poseidon<F> {
}
}

/// Converts a slice of bytes into a prime field element, represented by the
/// [`ark_ff::PrimeField`](ark_ff::PrimeField) trait.
/// Checks that a slice of bytes is not empty and that its length does not
/// exceed the modulus size of the prime field. Otherwise, an error is returned.
///
/// # Safety
///
/// Unlike the
/// [`PrimeField::from_be_bytes_mod_order`](ark_ff::PrimeField::from_be_bytes_mod_order)
/// and [`Field::from_random_bytes`](ark_ff::Field::from_random_bytes)
/// methods, this function ensures that the input byte slice's length exactly matches
/// the modulus size of the prime field. If the size doesn't match, an error is returned.
///
/// This strict check is designed to prevent unexpected behaviors and collisions
/// that might occur when using `from_be_bytes_mod_order` or `from_random_bytes`,
/// which simply take a subslice of the input if it's too large, potentially
/// leading to collisions.
fn bytes_to_prime_field_element<F>(input: &[u8]) -> Result<F, PoseidonError>
/// just takes a subslice of the input if it's too large, potentially leading
/// to collisions. The purpose of this function is to prevent them by returning
/// an error. It should always be used before converting byte slices to
/// prime field elements.
fn validate_bytes_length<F>(input: &[u8]) -> Result<&[u8], PoseidonError>
where
F: PrimeField,
{
@@ -477,6 +482,15 @@ where
modulus_bytes_len,
});
}
Ok(input)
}

/// Converts a slice of bytes into a prime field element, represented by the
/// [`ark_ff::PrimeField`](ark_ff::PrimeField) trait.
fn bytes_to_prime_field_element<F>(input: &[u8]) -> Result<F, PoseidonError>
where
F: PrimeField,
{
F::from_random_bytes(input).ok_or(PoseidonError::InputLargerThanModulus)
}

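To make the collision risk described in the doc comment concrete, here is a minimal sketch, assuming the public API shown in this diff (`Poseidon::new_circom`, `PoseidonBytesHasher::hash_bytes_be`, and the `PoseidonError` variants): `from_be_bytes_mod_order` silently reduces the 32-byte big-endian encoding of the BN254 modulus to zero, while the byte hashers reject out-of-range values and, with this change, over-long inputs.

```rust
use ark_bn254::Fr;
use ark_ff::{BigInteger, PrimeField, Zero};
use light_poseidon::{Poseidon, PoseidonBytesHasher, PoseidonError};

fn main() {
    // `from_be_bytes_mod_order` reduces modulo the field order, so the
    // 32-byte encoding of the modulus collides with the zero element.
    let modulus_be = Fr::MODULUS.to_bytes_be();
    assert_eq!(Fr::from_be_bytes_mod_order(&modulus_be), Fr::zero());

    let mut hasher = Poseidon::<Fr>::new_circom(1).unwrap();

    // The byte hashers reject a 32-byte value that is not a canonical
    // field element instead of silently reducing it.
    assert_eq!(
        hasher.hash_bytes_be(&[modulus_be.as_slice()]),
        Err(PoseidonError::InputLargerThanModulus)
    );

    // With the validation added here, over-long inputs are rejected up
    // front rather than being truncated to a subslice.
    let too_long = [1u8; 33];
    assert_eq!(
        hasher.hash_bytes_be(&[too_long.as_slice()]),
        Err(PoseidonError::InvalidInputLength {
            len: 33,
            modulus_bytes_len: 32,
        })
    );
}
```

The assertions mirror the patterns in the crate's own tests below; only the concrete modulus-bytes example is added for illustration.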
189 changes: 91 additions & 98 deletions light-poseidon/tests/bn254_fq_x5.rs
@@ -1,7 +1,8 @@
use ark_bn254::Fr;
use ark_ff::{BigInteger, One, PrimeField, Zero};
use ark_ff::{BigInteger, BigInteger256, One, PrimeField, Zero};
use light_poseidon::{Poseidon, PoseidonError};
use light_poseidon::{PoseidonBytesHasher, PoseidonHasher};
use rand::Rng;

#[test]
fn test_poseidon_one() {
@@ -148,128 +149,120 @@ fn test_poseidon_bn254_x5_fq_hash_bytes_le() {
);
}

#[cfg(feature = "fuzz-tests")]
mod fuzz_tests {
use ark_ff::BigInteger256;
use rand::Rng;

use super::*;

macro_rules! test_random_input_same_results {
($name:ident, $method:ident) => {
#[test]
fn $name() {
let input = [1u8; 32];

for nr_inputs in 1..12 {
let mut hasher = Poseidon::<Fr>::new_circom(nr_inputs).unwrap();

let mut inputs = Vec::with_capacity(nr_inputs);
for _ in 0..nr_inputs {
inputs.push(input.as_slice());
}
macro_rules! test_random_input_same_results {
($name:ident, $method:ident) => {
#[test]
fn $name() {
let input = [1u8; 32];

let hash1 = hasher.$method(inputs.as_slice()).unwrap();
let hash2 = hasher.$method(inputs.as_slice()).unwrap();
for nr_inputs in 1..12 {
let mut hasher = Poseidon::<Fr>::new_circom(nr_inputs).unwrap();

assert_eq!(hash1, hash2);
let mut inputs = Vec::with_capacity(nr_inputs);
for _ in 0..nr_inputs {
inputs.push(input.as_slice());
}
}
};
}

test_random_input_same_results!(
test_poseidon_bn254_x5_fq_hash_bytes_be_random_input_same_results,
hash_bytes_be
);

test_random_input_same_results!(
test_poseidon_bn254_x5_fq_hash_bytes_le_random_input_same_results,
hash_bytes_le
);
let hash1 = hasher.$method(inputs.as_slice()).unwrap();
let hash2 = hasher.$method(inputs.as_slice()).unwrap();

macro_rules! test_invalid_input_length {
($name:ident, $method:ident) => {
#[test]
fn $name() {
let mut rng = rand::thread_rng();

for _ in 0..100 {
let len = rng.gen_range(33..524_288_000); // Maximum 500 MiB.
let input = vec![1u8; len];

for nr_inputs in 1..12 {
let mut hasher = Poseidon::<Fr>::new_circom(nr_inputs).unwrap();

let mut inputs = Vec::with_capacity(nr_inputs);
for _ in 0..nr_inputs {
inputs.push(input.as_slice());
}

let hash = hasher.$method(inputs.as_slice());
assert_eq!(
hash,
Err(PoseidonError::InvalidInputLength {
len,
modulus_bytes_len: 32,
})
);
}
}
assert_eq!(hash1, hash2);
}
};
}
}
};
}

test_invalid_input_length!(
test_poseidon_bn254_x5_fq_hash_bytes_be_invalid_input_length,
hash_bytes_be
);
test_random_input_same_results!(
test_poseidon_bn254_x5_fq_hash_bytes_be_random_input_same_results,
hash_bytes_be
);

test_invalid_input_length!(
test_poseidon_bn254_x5_fq_hash_bytes_le_invalid_input_length,
hash_bytes_le
);
test_random_input_same_results!(
test_poseidon_bn254_x5_fq_hash_bytes_le_random_input_same_results,
hash_bytes_le
);

macro_rules! test_input_gt_field_size {
($name:ident, $method:ident, $to_bytes_method:ident) => {
#[test]
fn $name() {
let mut greater_than_field_size = Fr::MODULUS;
let mut rng = rand::thread_rng();
let random_number = rng.gen_range(1u64..1_000_000u64);
greater_than_field_size.add_with_carry(&BigInteger256::from(random_number));
let greater_than_field_size = greater_than_field_size.$to_bytes_method();
macro_rules! test_invalid_input_length {
($name:ident, $method:ident) => {
#[test]
fn $name() {
let mut rng = rand::thread_rng();

assert_eq!(greater_than_field_size.len(), 32);
for _ in 0..100 {
let len = rng.gen_range(33..524_288_000); // Maximum 500 MiB.
let input = vec![1u8; len];

for nr_inputs in 1..12 {
let mut hasher = Poseidon::<Fr>::new_circom(nr_inputs).unwrap();

let mut inputs = Vec::with_capacity(nr_inputs);
for _ in 0..nr_inputs {
inputs.push(&greater_than_field_size[..]);
inputs.push(input.as_slice());
}

let hash = hasher.$method(inputs.as_slice());
assert_eq!(hash, Err(PoseidonError::InputLargerThanModulus));
assert_eq!(
hash,
Err(PoseidonError::InvalidInputLength {
len,
modulus_bytes_len: 32,
})
);
}
}
};
}
}
};
}

test_input_gt_field_size!(
test_poseidon_bn254_fq_hash_bytes_be_input_gt_field_size,
hash_bytes_be,
to_bytes_be
);
test_invalid_input_length!(
test_poseidon_bn254_x5_fq_hash_bytes_be_invalid_input_length,
hash_bytes_be
);

test_input_gt_field_size!(
test_poseidon_bn254_fq_hash_bytes_le_input_gt_field_size,
hash_bytes_le,
to_bytes_le
);
test_invalid_input_length!(
test_poseidon_bn254_x5_fq_hash_bytes_le_invalid_input_length,
hash_bytes_le
);

macro_rules! test_fuzz_input_gt_field_size {
($name:ident, $method:ident, $to_bytes_method:ident) => {
#[test]
fn $name() {
let mut greater_than_field_size = Fr::MODULUS;
let mut rng = rand::thread_rng();
let random_number = rng.gen_range(1u64..1_000_000u64);
greater_than_field_size.add_with_carry(&BigInteger256::from(random_number));
let greater_than_field_size = greater_than_field_size.$to_bytes_method();

assert_eq!(greater_than_field_size.len(), 32);

for nr_inputs in 1..12 {
let mut hasher = Poseidon::<Fr>::new_circom(nr_inputs).unwrap();

let mut inputs = Vec::with_capacity(nr_inputs);
for _ in 0..nr_inputs {
inputs.push(&greater_than_field_size[..]);
}

let hash = hasher.$method(inputs.as_slice());
assert_eq!(hash, Err(PoseidonError::InputLargerThanModulus));
}
}
};
}

test_fuzz_input_gt_field_size!(
test_fuzz_poseidon_bn254_fq_hash_bytes_be_input_gt_field_size,
hash_bytes_be,
to_bytes_be
);

test_fuzz_input_gt_field_size!(
test_fuzz_poseidon_bn254_fq_hash_bytes_le_input_gt_field_size,
hash_bytes_le,
to_bytes_le
);
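As a usage note, the first instantiation above roughly expands to the test below. This is a hand-written sketch based on the macro body, not compiler output, and it assumes the imports already present at the top of this test file.

```rust
// Hand-expanded sketch of:
// test_fuzz_input_gt_field_size!(
//     test_fuzz_poseidon_bn254_fq_hash_bytes_be_input_gt_field_size,
//     hash_bytes_be,
//     to_bytes_be
// );
#[test]
fn test_fuzz_poseidon_bn254_fq_hash_bytes_be_input_gt_field_size() {
    // Build a 32-byte big-endian value slightly larger than the field modulus.
    let mut greater_than_field_size = Fr::MODULUS;
    let mut rng = rand::thread_rng();
    let random_number = rng.gen_range(1u64..1_000_000u64);
    greater_than_field_size.add_with_carry(&BigInteger256::from(random_number));
    let greater_than_field_size = greater_than_field_size.to_bytes_be();

    assert_eq!(greater_than_field_size.len(), 32);

    for nr_inputs in 1..12 {
        let mut hasher = Poseidon::<Fr>::new_circom(nr_inputs).unwrap();
        let inputs = vec![&greater_than_field_size[..]; nr_inputs];

        // Every supported arity must reject a value above the modulus.
        let hash = hasher.hash_bytes_be(inputs.as_slice());
        assert_eq!(hash, Err(PoseidonError::InputLargerThanModulus));
    }
}
```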

macro_rules! test_input_gt_field_size {
($name:ident, $method:ident, $greater_than_field_size:expr) => {
#[test]