-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
utils/mmr: safety & performance rework #274
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,82 +9,95 @@ | |
use error::MerkleError; | ||
use hasher::{Hash, MerkleHasher}; | ||
|
||
fn zero() -> Hash { | ||
[0; 32] | ||
} | ||
|
||
fn is_zero(h: Hash) -> bool { | ||
h.iter().all(|b| *b == 0) | ||
} | ||
|
||
/// Compact representation of the MMR that should be borsh serializable easily. | ||
#[derive(Debug, Clone, PartialEq, Eq, BorshSerialize, BorshDeserialize, Arbitrary)] | ||
pub struct CompactMmr { | ||
entries: u64, | ||
cap_log2: u8, | ||
roots: Vec<Hash>, | ||
element_count: u64, | ||
peaks: Vec<Hash>, | ||
} | ||
|
||
const ZERO: [u8; 32] = [0; 32]; | ||
|
||
#[derive(Clone)] | ||
pub struct MerkleMr<H: MerkleHasher + Clone> { | ||
// number of elements inserted into mmr | ||
pub num: u64, | ||
// Buffer of all possible peaks in mmr. only some of them will be valid at a time | ||
/// number of elements inserted into mmr | ||
pub element_count: u64, | ||
/// Buffer of all possible peaks in a MMR. | ||
/// Only some of them will be valid at a time. | ||
pub peaks: Box<[Hash]>, | ||
// phantom data for hasher | ||
/// [`PhantomData`] for the hasher. | ||
pub hasher: PhantomData<H>, | ||
} | ||
|
||
impl<H: MerkleHasher + Clone> MerkleMr<H> { | ||
pub fn new(cap_log2: usize) -> Self { | ||
pub fn new(peak_count: usize) -> Self { | ||
Self { | ||
num: 0, | ||
peaks: vec![[0; 32]; cap_log2].into_boxed_slice(), | ||
element_count: 0, | ||
peaks: vec![ZERO; peak_count].into_boxed_slice(), | ||
hasher: PhantomData, | ||
} | ||
} | ||
|
||
pub fn from_compact(compact: &CompactMmr) -> Self { | ||
// FIXME this is somewhat inefficient, we could consume the vec and just | ||
// slice out its elements, but this is fine for now | ||
let mut roots = vec![zero(); compact.cap_log2 as usize]; | ||
let mut at = 0; | ||
for i in 0..compact.cap_log2 { | ||
if compact.entries >> i & 1 != 0 { | ||
roots[i as usize] = compact.roots[at as usize]; | ||
at += 1; | ||
} | ||
/// Returns the minimum peaks needed to store an MMR of `element_count` elements | ||
#[inline] | ||
fn min_peaks(element_count: u64) -> usize { | ||
match element_count { | ||
0 => 0, | ||
c => c.ilog2() as usize + 1, | ||
Comment on lines +41 to +46
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should split this out into a top-level function so we don't have to call it with `Self::`. |
||
} | ||
} | ||
|
||
/// Restores from a [`CompactMmr`] | ||
pub fn from_compact(compact: CompactMmr) -> Self { | ||
let required_peaks = Self::min_peaks(compact.element_count); | ||
let peaks = match required_peaks { | ||
0 => vec![ZERO], | ||
required => { | ||
let mut peaks = compact.peaks; | ||
// this shouldn't ever need to run with the below Self::to_compact | ||
// as that will truncate it automatically to the correct length. | ||
// this is mostly for safety. | ||
if peaks.len() < required { | ||
let num_to_add = required - peaks.len(); | ||
peaks.reserve_exact(num_to_add); | ||
// NOTE: we add in a loop so we don't need to make 2 allocs | ||
(0..num_to_add).for_each(|_| peaks.push(ZERO)) | ||
} | ||
peaks | ||
} | ||
}; | ||
|
||
Self { | ||
num: compact.entries, | ||
peaks: roots.into(), | ||
element_count: compact.element_count, | ||
peaks: peaks.into(), | ||
hasher: PhantomData, | ||
} | ||
} | ||
|
||
/// Exports a "compact" version of the MMR that can be easily serialized | ||
pub fn to_compact(&self) -> CompactMmr { | ||
let min_peaks = Self::min_peaks(self.element_count); | ||
// self.peaks should always have enough peaks to hold its elements | ||
assert!(self.peaks.len() >= min_peaks); | ||
CompactMmr { | ||
entries: self.num, | ||
cap_log2: self.peaks.len() as u8, | ||
roots: self | ||
.peaks | ||
.iter() | ||
.filter(|h| !is_zero(**h)) | ||
.copied() | ||
.collect(), | ||
element_count: self.element_count, | ||
peaks: match min_peaks { | ||
0 => vec![ZERO], | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Might want to tweak the representation on the decode side so that an empty compact MMR has no peaks at all and this can just be an empty buffer. |
||
required => self.peaks.iter().take(required).copied().collect(), | ||
}, | ||
} | ||
} | ||
|
||
/// Adds a leaf's hash to the MMR | ||
pub fn add_leaf(&mut self, hash_arr: Hash) { | ||
if self.num == 0 { | ||
if self.element_count == 0 { | ||
self.peaks[0] = hash_arr; | ||
self.num += 1; | ||
self.element_count += 1; | ||
return; | ||
} | ||
|
||
// the number of elements in MMR is also the mask of peaks | ||
let peak_mask = self.num; | ||
let peak_mask = self.element_count; | ||
|
||
let mut current_node = hash_arr; | ||
// we iterate through the height | ||
|
@@ -101,26 +114,26 @@ | |
} | ||
|
||
self.peaks[current_height] = current_node; | ||
self.num += 1; | ||
self.element_count += 1; | ||
} | ||
|
||
pub fn get_single_root(&self) -> Result<Hash, MerkleError> { | ||
if self.num == 0 { | ||
if self.element_count == 0 { | ||
return Err(MerkleError::NoElements); | ||
} | ||
if !self.num.is_power_of_two() && self.num != 1 { | ||
if !self.element_count.is_power_of_two() && self.element_count != 1 { | ||
return Err(MerkleError::NotPowerOfTwo); | ||
} | ||
|
||
Ok(self.peaks[(self.num.ilog2()) as usize]) | ||
Ok(self.peaks[(self.element_count.ilog2()) as usize]) | ||
} | ||
|
||
pub fn add_leaf_updating_proof( | ||
&mut self, | ||
next: Hash, | ||
proof: &MerkleProof<H>, | ||
) -> MerkleProof<H> { | ||
if self.num == 0 { | ||
if self.element_count == 0 { | ||
self.add_leaf(next); | ||
return MerkleProof { | ||
cohashes: vec![], | ||
|
@@ -130,8 +143,8 @@ | |
} | ||
let mut updated_proof = proof.clone(); | ||
|
||
let new_leaf_index = self.num; | ||
let peak_mask = self.num; | ||
let new_leaf_index = self.element_count; | ||
let peak_mask = self.element_count; | ||
let mut current_node = next; | ||
let mut current_height = 0; | ||
while (peak_mask >> current_height) & 1 == 1 { | ||
|
@@ -152,7 +165,7 @@ | |
} | ||
|
||
self.peaks[current_height] = current_node; | ||
self.num += 1; | ||
self.element_count += 1; | ||
|
||
updated_proof | ||
} | ||
|
@@ -184,7 +197,7 @@ | |
next: Hash, | ||
proof_list: &mut [MerkleProof<H>], | ||
) -> MerkleProof<H> { | ||
if self.num == 0 { | ||
if self.element_count == 0 { | ||
self.add_leaf(next); | ||
return MerkleProof { | ||
cohashes: vec![], | ||
|
@@ -194,12 +207,12 @@ | |
} | ||
let mut new_proof = MerkleProof { | ||
cohashes: vec![], | ||
index: self.num, | ||
index: self.element_count, | ||
_pd: PhantomData, | ||
}; | ||
|
||
let new_leaf_index = self.num; | ||
let peak_mask = self.num; | ||
let new_leaf_index = self.element_count; | ||
let peak_mask = self.element_count; | ||
let mut current_node = next; | ||
let mut current_height = 0; | ||
while (peak_mask >> current_height) & 1 == 1 { | ||
|
@@ -230,7 +243,7 @@ | |
} | ||
|
||
self.peaks[current_height] = current_node; | ||
self.num += 1; | ||
self.element_count += 1; | ||
|
||
new_proof | ||
} | ||
|
@@ -267,7 +280,7 @@ | |
proof_list: &[MerkleProof<H>], | ||
index: u64, | ||
) -> Result<Option<MerkleProof<H>>, MerkleError> { | ||
if index > self.num { | ||
if index > self.element_count { | ||
return Err(MerkleError::IndexOutOfBounds); | ||
} | ||
|
||
|
@@ -315,7 +328,9 @@ | |
use super::{hasher::Hash, MerkleMr, MerkleProof}; | ||
use crate::error::MerkleError; | ||
|
||
fn generate_for_n_integers(n: usize) -> (MerkleMr<Sha256>, Vec<MerkleProof<Sha256>>) { | ||
fn generate_for_n_integers( | ||
n: usize, | ||
) -> (MerkleMr<Sha256>, Vec<MerkleProof<Sha256>>, Vec<[u8; 32]>) { | ||
let mut mmr: MerkleMr<Sha256> = MerkleMr::new(14); | ||
|
||
let mut proof = Vec::new(); | ||
|
@@ -325,7 +340,7 @@ | |
let new_proof = mmr.add_leaf_updating_proof_list(list_of_hashes[i], &mut proof); | ||
proof.push(new_proof); | ||
}); | ||
(mmr, proof) | ||
(mmr, proof, list_of_hashes) | ||
} | ||
|
||
fn generate_hashes_for_n_integers(n: usize) -> Vec<Hash> { | ||
|
@@ -335,7 +350,7 @@ | |
} | ||
|
||
fn mmr_proof_for_specific_nodes(n: usize, specific_nodes: Vec<u64>) { | ||
let (mmr, proof_list) = generate_for_n_integers(n); | ||
let (mmr, proof_list, _) = generate_for_n_integers(n); | ||
let proof: Vec<MerkleProof<Sha256>> = specific_nodes | ||
.iter() | ||
.map(|i| { | ||
|
@@ -368,7 +383,7 @@ | |
|
||
#[test] | ||
fn check_single_element() { | ||
let (mmr, proof_list) = generate_for_n_integers(1); | ||
let (mmr, proof_list, _) = generate_for_n_integers(1); | ||
|
||
let proof = mmr | ||
.gen_proof(&proof_list, 0) | ||
|
@@ -438,7 +453,7 @@ | |
|
||
#[test] | ||
fn check_invalid_proof() { | ||
let (mmr, _) = generate_for_n_integers(5); | ||
let (mmr, ..) = generate_for_n_integers(5); | ||
let invalid_proof = MerkleProof::<Sha256> { | ||
cohashes: vec![], | ||
index: 6, | ||
|
@@ -485,18 +500,20 @@ | |
|
||
#[test] | ||
fn check_compact_and_non_compact() { | ||
let (mmr, _) = generate_for_n_integers(5); | ||
let (mmr, proofs, hashes) = generate_for_n_integers(5); | ||
|
||
let compact_mmr = mmr.to_compact(); | ||
let deserialized_mmr = MerkleMr::<Sha256>::from_compact(&compact_mmr); | ||
let deserialized_mmr = MerkleMr::<Sha256>::from_compact(compact_mmr); | ||
|
||
assert_eq!(mmr.num, deserialized_mmr.num); | ||
assert_eq!(mmr.peaks, deserialized_mmr.peaks); | ||
assert_eq!(mmr.element_count, deserialized_mmr.element_count); | ||
for (i, proof) in proofs.into_iter().enumerate() { | ||
assert!(deserialized_mmr.verify(&proof, &hashes[i])) | ||
} | ||
} | ||
|
||
#[test] | ||
fn arbitrary_index_proof() { | ||
let (mut mmr, _) = generate_for_n_integers(20); | ||
let (mut mmr, ..) = generate_for_n_integers(20); | ||
// update proof for 21st element | ||
let mut proof: MerkleProof<Sha256> = MerkleProof { | ||
cohashes: vec![], | ||
|
@@ -518,7 +535,7 @@ | |
|
||
#[test] | ||
fn update_proof_list_from_arbitrary_index() { | ||
let (mut mmr, _) = generate_for_n_integers(20); | ||
let (mut mmr, ..) = generate_for_n_integers(20); | ||
// update proof for 21st element | ||
let mut proof_list = Vec::new(); | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What would be really cool is to write manual borsh serialization so that we don't have to store the length of the peaks list and can infer it from the element count (since it's redundant, it's just popcnt).