Skip to content

Commit aed6067

Browse files
committed
optimize Poseidon for larger number of rounds
1 parent 6696dae commit aed6067

File tree

3 files changed

+36
-11
lines changed

3 files changed

+36
-11
lines changed

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ serde_derive = "1.0.80"
2525
tiny-keccak = "1.4.2"
2626
rust-crypto = "0.2"
2727

28-
#bellman_ce = { path = "../bellman"}
29-
bellman_ce = { version = "0.3.0", default-features = false}
28+
bellman_ce = { path = "../bellman"}
29+
#bellman_ce = { version = "0.3.0", default-features = false}
3030
blake2-rfc_bellman_edition = "0.0.1"
3131

3232
[dev-dependencies]

src/circuit/num.rs

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -778,9 +778,30 @@ impl<E: Engine> Num<E> {
778778

779779
self.value = newval;
780780
let mut lc = LinearCombination::zero();
781-
std::mem::swap(&mut self.lc, &mut lc);
781+
// std::mem::swap(&mut self.lc, &mut lc);
782+
use std::collections::HashMap;
783+
let mut final_coeffs: HashMap<bellman::Variable, E::Fr> = HashMap::new();
784+
for (var, coeff) in self.lc.as_ref() {
785+
if final_coeffs.get(var).is_some() {
786+
if let Some(existing_coeff) = final_coeffs.get_mut(var) {
787+
existing_coeff.add_assign(&coeff);
788+
}
789+
} else {
790+
final_coeffs.insert(*var, *coeff);
791+
}
792+
}
793+
782794
for (var, coeff) in other.lc.as_ref() {
783-
lc = lc + (*coeff, var.clone());
795+
if final_coeffs.get(var).is_some() {
796+
if let Some(existing_coeff) = final_coeffs.get_mut(var) {
797+
existing_coeff.add_assign(&coeff);
798+
}
799+
} else {
800+
final_coeffs.insert(*var, *coeff);
801+
}
802+
}
803+
for (var, coeff) in final_coeffs.into_iter() {
804+
lc = lc + (coeff, var);
784805
}
785806
self.lc = lc;
786807
}

src/circuit/poseidon_hash.rs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,6 @@ impl<E: PoseidonEngine> QuinticSBox<E> {
100100
}
101101
}
102102

103-
// pub fn poseidon_tree_hash<E: PoseidonEngine<SBox = QuinticSBox<E> >, CS>(
104-
// mut cs: CS,
105-
// elements: &[AllocatedNum<E>],
106-
// params: &E::Params
107-
108103
pub fn poseidon_hash<E: PoseidonEngine<SBox = QuinticSBox<E> >, CS>(
109104
mut cs: CS,
110105
input: &[AllocatedNum<E>],
@@ -311,6 +306,12 @@ fn poseidon_mimc_round<E: PoseidonEngine<SBox = QuinticSBox<E> >, CS>(
311306
round += 1;
312307
}
313308

309+
// up to this point we always made a well-formed LC that later was collapsed into
310+
// a signel variable after non-linearity application
311+
// now we need to make linear combinations of linear combinations, so basically make
312+
// filtering and joining. It's actually possible to just separate MSD matrix into
313+
// three in later optimizations
314+
314315
// now we need to apply full SBox of the last full round, then do linear
315316
// transformation and add first round constants before going through partial rounds
316317
{
@@ -329,6 +330,9 @@ fn poseidon_mimc_round<E: PoseidonEngine<SBox = QuinticSBox<E> >, CS>(
329330
add_round_constants::<E, CS>(params, &mut linear_transformation_results[..], 0, false);
330331
state = linear_transformation_results;
331332

333+
// up to this point linear combinations are well-formed and have number
334+
// of terms equal to the number of variables in the state
335+
332336
round += 1;
333337
}
334338

@@ -462,7 +466,7 @@ mod test {
462466
let mut rng = XorShiftRng::from_seed([0x3dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]);
463467
let params = Bn256PoseidonParams::new::<BlakeHasher>();
464468
let input: Vec<Fr> = (0..params.t()).map(|_| rng.gen()).collect();
465-
let expected = poseidon::poseidon_hash::<Bn256>(&params, &input[..]);
469+
let expected = poseidon::poseidon_mimc::<Bn256>(&params, &input[..]);
466470

467471
{
468472
let mut cs = TestConstraintSystem::<Bn256>::new();
@@ -482,7 +486,7 @@ mod test {
482486
).unwrap();
483487

484488
assert!(cs.is_satisfied());
485-
assert!(res.len() == 1);
489+
assert!(res.len() == (params.t() as usize));
486490

487491
assert_eq!(res[0].get_value().unwrap(), expected[0]);
488492
}

0 commit comments

Comments
 (0)