Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

utils/mmr: safety & performance rework #274

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 83 additions & 66 deletions crates/util/mmr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,82 +9,95 @@
use error::MerkleError;
use hasher::{Hash, MerkleHasher};

/// Returns a hash whose 32 entries are all zero.
fn zero() -> Hash {
    [0u8; 32]
}

/// Reports whether every entry of `h` is zero.
fn is_zero(h: Hash) -> bool {
    // Equivalent to "all entries are zero": no entry may be non-zero.
    !h.iter().any(|&byte| byte != 0)
}

/// Compact representation of the MMR that should be borsh serializable easily.
#[derive(Debug, Clone, PartialEq, Eq, BorshSerialize, BorshDeserialize, Arbitrary)]
pub struct CompactMmr {
Comment on lines 12 to 14
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What would be really cool is to write the borsh serialization manually, so that we don't have to store the length of the peaks list: it's redundant, since it can be inferred from the element count (it's just its popcount).

entries: u64,
cap_log2: u8,
roots: Vec<Hash>,
element_count: u64,
peaks: Vec<Hash>,
}

const ZERO: [u8; 32] = [0; 32];

#[derive(Clone)]
pub struct MerkleMr<H: MerkleHasher + Clone> {
// number of elements inserted into mmr
pub num: u64,
// Buffer of all possible peaks in mmr. only some of them will be valid at a time
/// number of elements inserted into mmr
pub element_count: u64,
/// Buffer of all possible peaks in a MMR.
/// Only some of them will be valid at a time.
pub peaks: Box<[Hash]>,
// phantom data for hasher
/// [`PhantomData`] for the hasher.
pub hasher: PhantomData<H>,
}

impl<H: MerkleHasher + Clone> MerkleMr<H> {
pub fn new(cap_log2: usize) -> Self {
pub fn new(peak_count: usize) -> Self {
Self {
num: 0,
peaks: vec![[0; 32]; cap_log2].into_boxed_slice(),
element_count: 0,
peaks: vec![ZERO; peak_count].into_boxed_slice(),
hasher: PhantomData,
}
}

pub fn from_compact(compact: &CompactMmr) -> Self {
// FIXME this is somewhat inefficient, we could consume the vec and just
// slice out its elements, but this is fine for now
let mut roots = vec![zero(); compact.cap_log2 as usize];
let mut at = 0;
for i in 0..compact.cap_log2 {
if compact.entries >> i & 1 != 0 {
roots[i as usize] = compact.roots[at as usize];
at += 1;
}
/// Returns the minimum peaks needed to store an MMR of `element_count` elements
#[inline]
fn min_peaks(element_count: u64) -> usize {
match element_count {
0 => 0,

Check warning on line 45 in crates/util/mmr/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

crates/util/mmr/src/lib.rs#L45

Added line #L45 was not covered by tests
c => c.ilog2() as usize + 1,
Comment on lines +41 to +46
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should split this out into a top-level function so we don't have to call it with `Self::`.

}
}

/// Restores from a [`CompactMmr`]
pub fn from_compact(compact: CompactMmr) -> Self {
let required_peaks = Self::min_peaks(compact.element_count);
let peaks = match required_peaks {
0 => vec![ZERO],

Check warning on line 54 in crates/util/mmr/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

crates/util/mmr/src/lib.rs#L54

Added line #L54 was not covered by tests
required => {
let mut peaks = compact.peaks;
// this shouldn't ever need to run with the below Self::to_compact
// as that will truncate it automatically to the correct length.
// this is mostly for safety.
if peaks.len() < required {
let num_to_add = required - peaks.len();
peaks.reserve_exact(num_to_add);
// NOTE: we add in a loop so we don't need to make 2 allocs
(0..num_to_add).for_each(|_| peaks.push(ZERO))

Check warning on line 64 in crates/util/mmr/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

crates/util/mmr/src/lib.rs#L61-L64

Added lines #L61 - L64 were not covered by tests
}
peaks
}
};

Self {
num: compact.entries,
peaks: roots.into(),
element_count: compact.element_count,
peaks: peaks.into(),
hasher: PhantomData,
}
}

/// Exports a "compact" version of the MMR that can be easily serialized
pub fn to_compact(&self) -> CompactMmr {
let min_peaks = Self::min_peaks(self.element_count);
// self.peaks should always have enough peaks to hold its elements
assert!(self.peaks.len() >= min_peaks);
CompactMmr {
entries: self.num,
cap_log2: self.peaks.len() as u8,
roots: self
.peaks
.iter()
.filter(|h| !is_zero(**h))
.copied()
.collect(),
element_count: self.element_count,
peaks: match min_peaks {
0 => vec![ZERO],

Check warning on line 85 in crates/util/mmr/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

crates/util/mmr/src/lib.rs#L85

Added line #L85 was not covered by tests
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might want to tweak the representation on the decode side so an empty compact MMR has no peaks at all and this can just be an empty buffer.

required => self.peaks.iter().take(required).copied().collect(),
},
}
}

/// Adds a leaf's hash to the MMR
pub fn add_leaf(&mut self, hash_arr: Hash) {
if self.num == 0 {
if self.element_count == 0 {
self.peaks[0] = hash_arr;
self.num += 1;
self.element_count += 1;
return;
}

// the number of elements in MMR is also the mask of peaks
let peak_mask = self.num;
let peak_mask = self.element_count;

let mut current_node = hash_arr;
// we iterate through the height
Expand All @@ -101,26 +114,26 @@
}

self.peaks[current_height] = current_node;
self.num += 1;
self.element_count += 1;
}

pub fn get_single_root(&self) -> Result<Hash, MerkleError> {
if self.num == 0 {
if self.element_count == 0 {
return Err(MerkleError::NoElements);
}
if !self.num.is_power_of_two() && self.num != 1 {
if !self.element_count.is_power_of_two() && self.element_count != 1 {
return Err(MerkleError::NotPowerOfTwo);
}

Ok(self.peaks[(self.num.ilog2()) as usize])
Ok(self.peaks[(self.element_count.ilog2()) as usize])
}

pub fn add_leaf_updating_proof(
&mut self,
next: Hash,
proof: &MerkleProof<H>,
) -> MerkleProof<H> {
if self.num == 0 {
if self.element_count == 0 {
self.add_leaf(next);
return MerkleProof {
cohashes: vec![],
Expand All @@ -130,8 +143,8 @@
}
let mut updated_proof = proof.clone();

let new_leaf_index = self.num;
let peak_mask = self.num;
let new_leaf_index = self.element_count;
let peak_mask = self.element_count;
let mut current_node = next;
let mut current_height = 0;
while (peak_mask >> current_height) & 1 == 1 {
Expand All @@ -152,7 +165,7 @@
}

self.peaks[current_height] = current_node;
self.num += 1;
self.element_count += 1;

updated_proof
}
Expand Down Expand Up @@ -184,7 +197,7 @@
next: Hash,
proof_list: &mut [MerkleProof<H>],
) -> MerkleProof<H> {
if self.num == 0 {
if self.element_count == 0 {
self.add_leaf(next);
return MerkleProof {
cohashes: vec![],
Expand All @@ -194,12 +207,12 @@
}
let mut new_proof = MerkleProof {
cohashes: vec![],
index: self.num,
index: self.element_count,
_pd: PhantomData,
};

let new_leaf_index = self.num;
let peak_mask = self.num;
let new_leaf_index = self.element_count;
let peak_mask = self.element_count;
let mut current_node = next;
let mut current_height = 0;
while (peak_mask >> current_height) & 1 == 1 {
Expand Down Expand Up @@ -230,7 +243,7 @@
}

self.peaks[current_height] = current_node;
self.num += 1;
self.element_count += 1;

new_proof
}
Expand Down Expand Up @@ -267,7 +280,7 @@
proof_list: &[MerkleProof<H>],
index: u64,
) -> Result<Option<MerkleProof<H>>, MerkleError> {
if index > self.num {
if index > self.element_count {
return Err(MerkleError::IndexOutOfBounds);
}

Expand Down Expand Up @@ -315,7 +328,9 @@
use super::{hasher::Hash, MerkleMr, MerkleProof};
use crate::error::MerkleError;

fn generate_for_n_integers(n: usize) -> (MerkleMr<Sha256>, Vec<MerkleProof<Sha256>>) {
fn generate_for_n_integers(
n: usize,
) -> (MerkleMr<Sha256>, Vec<MerkleProof<Sha256>>, Vec<[u8; 32]>) {
let mut mmr: MerkleMr<Sha256> = MerkleMr::new(14);

let mut proof = Vec::new();
Expand All @@ -325,7 +340,7 @@
let new_proof = mmr.add_leaf_updating_proof_list(list_of_hashes[i], &mut proof);
proof.push(new_proof);
});
(mmr, proof)
(mmr, proof, list_of_hashes)
}

fn generate_hashes_for_n_integers(n: usize) -> Vec<Hash> {
Expand All @@ -335,7 +350,7 @@
}

fn mmr_proof_for_specific_nodes(n: usize, specific_nodes: Vec<u64>) {
let (mmr, proof_list) = generate_for_n_integers(n);
let (mmr, proof_list, _) = generate_for_n_integers(n);
let proof: Vec<MerkleProof<Sha256>> = specific_nodes
.iter()
.map(|i| {
Expand Down Expand Up @@ -368,7 +383,7 @@

#[test]
fn check_single_element() {
let (mmr, proof_list) = generate_for_n_integers(1);
let (mmr, proof_list, _) = generate_for_n_integers(1);

let proof = mmr
.gen_proof(&proof_list, 0)
Expand Down Expand Up @@ -438,7 +453,7 @@

#[test]
fn check_invalid_proof() {
let (mmr, _) = generate_for_n_integers(5);
let (mmr, ..) = generate_for_n_integers(5);
let invalid_proof = MerkleProof::<Sha256> {
cohashes: vec![],
index: 6,
Expand Down Expand Up @@ -485,18 +500,20 @@

#[test]
fn check_compact_and_non_compact() {
let (mmr, _) = generate_for_n_integers(5);
let (mmr, proofs, hashes) = generate_for_n_integers(5);

let compact_mmr = mmr.to_compact();
let deserialized_mmr = MerkleMr::<Sha256>::from_compact(&compact_mmr);
let deserialized_mmr = MerkleMr::<Sha256>::from_compact(compact_mmr);

assert_eq!(mmr.num, deserialized_mmr.num);
assert_eq!(mmr.peaks, deserialized_mmr.peaks);
assert_eq!(mmr.element_count, deserialized_mmr.element_count);
for (i, proof) in proofs.into_iter().enumerate() {
assert!(deserialized_mmr.verify(&proof, &hashes[i]))
}
}

#[test]
fn arbitrary_index_proof() {
let (mut mmr, _) = generate_for_n_integers(20);
let (mut mmr, ..) = generate_for_n_integers(20);
// update proof for 21st element
let mut proof: MerkleProof<Sha256> = MerkleProof {
cohashes: vec![],
Expand All @@ -518,7 +535,7 @@

#[test]
fn update_proof_list_from_arbitrary_index() {
let (mut mmr, _) = generate_for_n_integers(20);
let (mut mmr, ..) = generate_for_n_integers(20);
// update proof for 21st element
let mut proof_list = Vec::new();

Expand Down
Loading