Skip to content

Commit

Permalink
Fixing airgroupvalues to be compatible with new airvalues handling
Browse files Browse the repository at this point in the history
  • Loading branch information
RogerTaule committed Dec 10, 2024
1 parent c0625c6 commit dd6e4c5
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 94 deletions.
102 changes: 22 additions & 80 deletions common/src/air_instance.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::sync::Arc;
use std::ptr;
use std::path::PathBuf;
use p3_field::Field;
use proofman_util::create_buffer_fast;

use crate::{trace::Trace, trace::Values, SetupCtx, Setup, StarkInfo};
use crate::{trace::Trace, trace::Values, SetupCtx, Setup, StarkInfo, PolMap};

#[repr(C)]
pub struct StepsParams {
Expand Down Expand Up @@ -41,11 +42,12 @@ impl Default for StepsParams {
p_const_pols: ptr::null_mut(),
p_const_tree: ptr::null_mut(),
custom_commits: [ptr::null_mut(); 10],
custom_commits_extended: [ptr::null_mut(); 10],
}
}
}

#[derive(Default)]
#[derive(Default, Clone)]
pub struct CustomCommitsInfo<F> {
pub buffer: Vec<F>,
pub cached_file: PathBuf,
Expand Down Expand Up @@ -255,24 +257,23 @@ impl<F: Field> AirInstance<F> {
self.custom_commits[commit_id as usize].buffer = buffer;
}

pub fn set_airvalue(&mut self, name: &str, lengths: Option<Vec<u64>>, value: F) {
let airvalues_map = self.stark_info.airvalues_map.as_ref().unwrap();
fn find_value_map(values_map: &Vec<PolMap>, name: &str, lengths: Option<Vec<u64>>) -> usize {
let mut id = 0;
let mut found = false;
for air_value in airvalues_map {
for value in values_map {
// Check if name matches
let name_matches = air_value.name == name;
let name_matches = value.name == name;


// If lengths is provided, check that it matches airvalue.lengths
let lengths_match = if let Some(ref provided_lengths) = lengths {
Some(&air_value.lengths) == Some(provided_lengths)
Some(&value.lengths) == Some(provided_lengths)
} else {
true // If lengths is None, skip the lengths check
};

if !name_matches || !lengths_match {
if air_value.stage == 1 {
if value.stage == 1 {
id += 1;
} else {
id += 3;
Expand All @@ -287,40 +288,18 @@ impl<F: Field> AirInstance<F> {
panic!("Name {} with specified lengths {:?} not found in airvalues", name, lengths);
}

id
}

pub fn set_airvalue(&mut self, name: &str, lengths: Option<Vec<u64>>, value: F) {
let airvalues_map = self.stark_info.airvalues_map.as_ref().unwrap();
let id = Self::find_value_map(airvalues_map, name, lengths);
self.airvalues[id] = value;
}

pub fn set_airvalue_ext(&mut self, name: &str, lengths: Option<Vec<u64>>, value: Vec<F>) {
let airvalues_map = self.stark_info.airvalues_map.as_ref().unwrap();
let mut id = 0;
let mut found = false;
for air_value in airvalues_map {
// Check if name matches
let name_matches = air_value.name == name;


// If lengths is provided, check that it matches airvalue.lengths
let lengths_match = if let Some(ref provided_lengths) = lengths {
Some(&air_value.lengths) == Some(provided_lengths)
} else {
true // If lengths is None, skip the lengths check
};

if !name_matches || !lengths_match {
if air_value.stage == 1 {
id += 1;
} else {
id += 3;
}
} else {
found = true;
break;
}
}

if !found {
panic!("Name {} with specified lengths {:?} not found in airvalues", name, lengths);
}
let id = Self::find_value_map(airvalues_map, name, lengths);

assert!(value.len() == 3, "Value vector must have exactly 3 elements");

Expand All @@ -333,58 +312,21 @@ impl<F: Field> AirInstance<F> {

pub fn set_airgroupvalue(&mut self, name: &str, lengths: Option<Vec<u64>>, value: F) {
let airgroupvalues_map = self.stark_info.airgroupvalues_map.as_ref().unwrap();
let airgroupvalue_id = (0..airgroupvalues_map.len())
.find(|&i| {
let airgroupvalue = airgroupvalues_map.get(i).unwrap();

// Check if name matches
let name_matches = airgroupvalues_map[i].name == name;

// If lengths is provided, check that it matches airgroupvalues.lengths
let lengths_match = if let Some(ref provided_lengths) = lengths {
Some(&airgroupvalue.lengths) == Some(provided_lengths)
} else {
true // If lengths is None, skip the lengths check
};

name_matches && lengths_match
})
.unwrap_or_else(|| {
panic!("Name {} with specified lengths {:?} not found in airgroupvalues", name, lengths)
});

self.airgroup_values[airgroupvalue_id * 3] = value;
let id = Self::find_value_map(airgroupvalues_map, name, lengths);
self.airgroup_values[id] = value;
}

pub fn set_airgroupvalue_ext(&mut self, name: &str, lengths: Option<Vec<u64>>, value: Vec<F>) {
let airgroupvalues_map = self.stark_info.airgroupvalues_map.as_ref().unwrap();
let airgroupvalue_id = (0..airgroupvalues_map.len())
.find(|&i| {
let airgroupvalue = airgroupvalues_map.get(i).unwrap();

// Check if name matches
let name_matches = airgroupvalues_map[i].name == name;

// If lengths is provided, check that it matches airgroupvalues.lengths
let lengths_match = if let Some(ref provided_lengths) = lengths {
Some(&airgroupvalue.lengths) == Some(provided_lengths)
} else {
true // If lengths is None, skip the lengths check
};

name_matches && lengths_match
})
.unwrap_or_else(|| {
panic!("Name {} with specified lengths {:?} not found in airgroupvalues", name, lengths)
});
let id = Self::find_value_map(airgroupvalues_map, name, lengths);

assert!(value.len() == 3, "Value vector must have exactly 3 elements");

let mut value_iter = value.into_iter();

self.airgroup_values[airgroupvalue_id * 3] = value_iter.next().unwrap();
self.airgroup_values[airgroupvalue_id * 3 + 1] = value_iter.next().unwrap();
self.airgroup_values[airgroupvalue_id * 3 + 2] = value_iter.next().unwrap();
self.airgroup_values[id] = value_iter.next().unwrap();
self.airgroup_values[id + 1] = value_iter.next().unwrap();
self.airgroup_values[id + 2] = value_iter.next().unwrap();
}

pub fn set_air_instance_id(&mut self, air_instance_id: usize, idx: usize) {
Expand Down
17 changes: 13 additions & 4 deletions pil2-stark/src/starkpil/expressions_avx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,14 +306,23 @@ class ExpressionsAvx : public ExpressionsCtx {
}

Goldilocks3::Element_avx airgroupValues[setupCtx.starkInfo.airgroupValuesMap.size()];
uint64_t p = 0;
for(uint64_t i = 0; i < setupCtx.starkInfo.airgroupValuesMap.size(); ++i) {
airgroupValues[i][0] = _mm256_set1_epi64x(params.airgroupValues[i * FIELD_EXTENSION].fe);
airgroupValues[i][1] = _mm256_set1_epi64x(params.airgroupValues[i * FIELD_EXTENSION + 1].fe);
airgroupValues[i][2] = _mm256_set1_epi64x(params.airgroupValues[i * FIELD_EXTENSION + 2].fe);
if(setupCtx.starkInfo.airgroupValuesMap[i].stage == 1) {
airgroupValues[i][0] = _mm256_set1_epi64x(params.airgroupValues[p].fe);
airgroupValues[i][1] = _mm256_set1_epi64x(0);
airgroupValues[i][2] = _mm256_set1_epi64x(0);
p += 1;
} else {
airgroupValues[i][0] = _mm256_set1_epi64x(params.airgroupValues[p].fe);
airgroupValues[i][1] = _mm256_set1_epi64x(params.airgroupValues[p + 1].fe);
airgroupValues[i][2] = _mm256_set1_epi64x(params.airgroupValues[p + 2].fe);
p += 3;
}
}

Goldilocks3::Element_avx airValues[setupCtx.starkInfo.airValuesMap.size()];
uint64_t p = 0;
p = 0;
for(uint64_t i = 0; i < setupCtx.starkInfo.airValuesMap.size(); ++i) {
if(setupCtx.starkInfo.airValuesMap[i].stage == 1) {
airValues[i][0] = _mm256_set1_epi64x(params.airValues[p].fe);
Expand Down
18 changes: 14 additions & 4 deletions pil2-stark/src/starkpil/expressions_avx512.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,13 +305,23 @@ class ExpressionsAvx512 : public ExpressionsCtx {
}

Goldilocks3::Element_avx512 airgroupValues[setupCtx.starkInfo.airgroupValuesMap.size()];
uint64_t p = 0;
for(uint64_t i = 0; i < setupCtx.starkInfo.airgroupValuesMap.size(); ++i) {
airgroupValues[i][0] = _mm512_set1_epi64(params.airgroupValues[i * FIELD_EXTENSION].fe);
airgroupValues[i][1] = _mm512_set1_epi64(params.airgroupValues[i * FIELD_EXTENSION + 1].fe);
airgroupValues[i][2] = _mm512_set1_epi64(params.airgroupValues[i * FIELD_EXTENSION + 2].fe);
if(setupCtx.starkInfo.airgroupValuesMap[i].stage == 1) {
airgroupValues[i][0] = _mm512_set1_epi64x(params.airgroupValues[p].fe);
airgroupValues[i][1] = _mm512_set1_epi64x(0);
airgroupValues[i][2] = _mm512_set1_epi64x(0);
p += 1;
} else {
airgroupValues[i][0] = _mm512_set1_epi64x(params.airgroupValues[p].fe);
airgroupValues[i][1] = _mm512_set1_epi64x(params.airgroupValues[p + 1].fe);
airgroupValues[i][2] = _mm512_set1_epi64x(params.airgroupValues[p + 2].fe);
p += 3;
}
}

uint64_t p = 0;
Goldilocks3::Element_avx512 airValues[setupCtx.starkInfo.airValuesMap.size()];
p = 0;
for(uint64_t i = 0; i < setupCtx.starkInfo.airValuesMap.size(); ++i) {
if(setupCtx.starkInfo.airValuesMap[i].stage == 1) {
airValues[i][0] = _mm512_set1_epi64x(params.airValues[p].fe);
Expand Down
17 changes: 13 additions & 4 deletions pil2-stark/src/starkpil/expressions_pack.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,18 +319,27 @@ class ExpressionsPack : public ExpressionsCtx {
}

Goldilocks::Element airgroupValues[setupCtx.starkInfo.airgroupValuesMap.size()*FIELD_EXTENSION*nrowsPack];
uint64_t p = 0;
if(!compilation_time) {
for(uint64_t i = 0; i < setupCtx.starkInfo.airgroupValuesMap.size(); ++i) {
for(uint64_t j = 0; j < nrowsPack; ++j) {
airgroupValues[(i*FIELD_EXTENSION)*nrowsPack + j] = params.airgroupValues[i * FIELD_EXTENSION];
airgroupValues[(i*FIELD_EXTENSION + 1)*nrowsPack + j] = params.airgroupValues[i * FIELD_EXTENSION + 1];
airgroupValues[(i*FIELD_EXTENSION + 2)*nrowsPack + j] = params.airgroupValues[i * FIELD_EXTENSION + 2];
if(setupCtx.starkInfo.airgroupValuesMap[i].stage == 1) {
airgroupValues[(i*FIELD_EXTENSION)*nrowsPack + j] = params.airgroupValues[p];
airgroupValues[(i*FIELD_EXTENSION + 1)*nrowsPack + j] = Goldilocks::zero();
airgroupValues[(i*FIELD_EXTENSION + 2)*nrowsPack + j] = Goldilocks::zero();
p += 1;
} else {
airgroupValues[(i*FIELD_EXTENSION)*nrowsPack + j] = params.airgroupValues[p];
airgroupValues[(i*FIELD_EXTENSION + 1)*nrowsPack + j] = params.airgroupValues[p + 1];
airgroupValues[(i*FIELD_EXTENSION + 2)*nrowsPack + j] = params.airgroupValues[p + 2];
p += 3;
}
}
}
}

Goldilocks::Element airValues[setupCtx.starkInfo.airValuesMap.size()*FIELD_EXTENSION*nrowsPack];
uint64_t p = 0;
p = 0;
if(!compilation_time) {
for(uint64_t i = 0; i < setupCtx.starkInfo.airValuesMap.size(); ++i) {
for(uint64_t j = 0; j < nrowsPack; ++j) {
Expand Down
12 changes: 10 additions & 2 deletions pil2-stark/src/starkpil/proof_stark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,17 @@ class Proofs
}

void setAirgroupValues(Goldilocks::Element *_airgroupValues) {
for (uint64_t i = 0; i < airgroupValues.size(); i++)
uint64_t p = 0;
for (uint64_t i = 0; i < starkInfo.airgroupValuesMap.size(); i++)
{
std::memcpy(&airgroupValues[i][0], &_airgroupValues[i * FIELD_EXTENSION], FIELD_EXTENSION * sizeof(Goldilocks::Element));
if(starkInfo.airgroupValuesMap[i].stage == 1) {
airgroupValues[i][0] = _airgroupValues[p++];
airgroupValues[i][1] = Goldilocks::zero();
airgroupValues[i][2] = Goldilocks::zero();
} else {
std::memcpy(&airgroupValues[i][0], &_airgroupValues[p], FIELD_EXTENSION * sizeof(Goldilocks::Element));
p += 3;
}
}
}

Expand Down

0 comments on commit dd6e4c5

Please sign in to comment.