Skip to content

Commit

Permalink
Fused computing lookback goodness and argmax (#258)
Browse files: browse the repository at this point in the history
  • Loading branch information
Skielex authored Dec 18, 2024
1 parent 6abb0f8 commit 59de262
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 25 deletions.
40 changes: 16 additions & 24 deletions pco/src/delta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,37 +125,33 @@ fn lookback_hash_lookup(
}
}

// NOTE(review): this span is a rendered diff whose +/- markers were stripped, so it is not
// compilable as shown. Going by the commit title ("Fused computing lookback goodness and
// argmax"), the lines belonging to `lookback_compute_goodness` and `lookback_goodness_argmax`
// are presumably the removed side and those belonging to `find_best_lookback` the added side.
//
// Both variants score every proposed lookback for the latent `l` at position `i` and pick the
// highest-scoring one; the fused variant returns the lookback itself rather than an index.
// Presumably the pre-commit signature: wrote one score per candidate into `goodnesses`.
fn lookback_compute_goodness<L: Latent>(
// Presumably the post-commit signature: scores and selects in a single pass.
#[inline(never)]
fn find_best_lookback<L: Latent>(
l: L,
i: usize,
latents: &[L],
proposed_lookbacks: &[usize; PROPOSED_LOOKBACKS],
lookback_counts: &mut [u32],
goodnesses: &mut [Bitlen; PROPOSED_LOOKBACKS],
) {
for lookback_idx in 0..PROPOSED_LOOKBACKS {
let lookback = proposed_lookbacks[lookback_idx];
// Bounds-checked read of the count in this variant; only the latent read was unchecked.
let lookback_count = lookback_counts[lookback - 1];
let other = unsafe { *latents.get_unchecked(i - lookback) };
) -> usize {
// NOTE(review): both accumulators start at 0 and the comparison below is strict, so if no
// candidate scored above 0 this variant would return 0, which `lookback - 1` suggests is not
// a valid lookback. TODO confirm that a candidate's score can never be 0.
let mut best_goodness = 0;
let mut best_lookback: usize = 0;
for &lookback in proposed_lookbacks {
// SAFETY (not verifiable from this view): both unchecked reads require
// `1 <= lookback <= i`, `i - lookback < latents.len()` and
// `lookback - 1 < lookback_counts.len()`. TODO confirm against the caller.
let (lookback_count, other) = unsafe {
(
*lookback_counts.get_unchecked(lookback - 1),
*latents.get_unchecked(i - lookback),
)
};
// Assuming `Bitlen` is as wide as the `u32` counts, this is the bit length of the count,
// so lookbacks with larger counts score higher.
let lookback_goodness = Bitlen::BITS - lookback_count.leading_zeros();
// Smaller of the two wrapping differences, i.e. the wrap-around distance between the values.
let delta = L::min(l.wrapping_sub(other), other.wrapping_sub(l));
// More leading zeros means a smaller delta, so a closer match scores higher.
let delta_goodness = delta.leading_zeros();
goodnesses[lookback_idx] = lookback_goodness + delta_goodness;
}
}

// Presumably removed by this commit: returned the index of the first maximal score.
fn lookback_goodness_argmax(goodnesses: &[Bitlen; PROPOSED_LOOKBACKS]) -> usize {
let mut best_goodness = goodnesses[0];
let mut best_idx = 0;

for (i, &goodness) in goodnesses.iter().enumerate().skip(1) {
let goodness = lookback_goodness + delta_goodness;
// Strict `>`: on ties the earliest candidate is kept.
if goodness > best_goodness {
best_goodness = goodness;
best_idx = i;
best_lookback = lookback;
}
}

best_idx
best_lookback
}

#[inline(never)]
Expand All @@ -178,7 +174,6 @@ fn choose_lookbacks<L: Latent>(config: DeltaLookbackConfig, latents: &[L]) -> Ve
let mut lookbacks = vec![MaybeUninit::uninit(); latents.len() - state_n];
let mut idx_hash_table = vec![0_usize; COARSENESSES.len() * hash_table_n];
let mut proposed_lookbacks = array::from_fn::<_, PROPOSED_LOOKBACKS, _>(|i| (i + 1).min(state_n));
let mut goodnesses = [0; PROPOSED_LOOKBACKS];
let mut best_lookback = 1;
let mut repeating_lookback_idx: usize = 0;
for i in state_n..latents.len() {
Expand All @@ -195,16 +190,13 @@ fn choose_lookbacks<L: Latent>(config: DeltaLookbackConfig, latents: &[L]) -> Ve
&mut idx_hash_table,
&mut proposed_lookbacks,
);
lookback_compute_goodness(
let new_best_lookback = find_best_lookback(
l,
i,
latents,
&proposed_lookbacks,
&mut lookback_counts,
&mut goodnesses,
);
let best_lookback_idx = lookback_goodness_argmax(&goodnesses);
let new_best_lookback = proposed_lookbacks[best_lookback_idx];
if new_best_lookback != best_lookback {
repeating_lookback_idx += 1;
}
Expand Down
2 changes: 1 addition & 1 deletion pco_cli/src/dtypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ impl Parquetable for f16 {
nums.iter().map(|x| x.to_f32()).collect()
}
// Converts each `f32` in `vec` to `f16` via `f16::from_f32` and collects the results.
// NOTE(review): rendered diff without +/- markers — the two `map` lines below are alternative
// bodies (closure form vs. passing `f16::from_f32` directly), not sequential statements.
fn parquet_to_nums(vec: Vec<f32>) -> Vec<Self> {
vec.into_iter().map(|x| f16::from_f32(x)).collect()
vec.into_iter().map(f16::from_f32).collect()
}
}

Expand Down

0 comments on commit 59de262

Please sign in to comment.