Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add valuesort, new control result mode and value normalizer #237

Merged
merged 9 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 103 additions & 52 deletions CHANGELOG.md

Large diffs are not rendered by default.

305 changes: 149 additions & 156 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ resolver = "2"
members = ["sqllogictest", "sqllogictest-bin", "sqllogictest-engines", "tests"]

[workspace.package]
version = "0.23.1"
version = "0.24.0"
edition = "2021"
homepage = "https://github.com/risinglightdb/sqllogictest-rs"
keywords = ["sql", "database", "parser", "cli"]
Expand Down
4 changes: 2 additions & 2 deletions sqllogictest-bin/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ glob = "0.3"
itertools = "0.13"
quick-junit = { version = "0.5" }
rand = "0.8"
sqllogictest = { path = "../sqllogictest", version = "0.23" }
sqllogictest-engines = { path = "../sqllogictest-engines", version = "0.23" }
sqllogictest = { path = "../sqllogictest", version = "0.24" }
sqllogictest-engines = { path = "../sqllogictest-engines", version = "0.24" }
tokio = { version = "1", features = [
"rt",
"rt-multi-thread",
Expand Down
5 changes: 3 additions & 2 deletions sqllogictest-bin/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ use quick_junit::{NonSuccessKind, Report, TestCase, TestCaseStatus, TestSuite};
use rand::distributions::DistString;
use rand::seq::SliceRandom;
use sqllogictest::{
default_column_validator, default_validator, update_record_with_output, AsyncDB, Injected,
MakeConnection, Record, Runner,
default_column_validator, default_normalizer, default_validator, update_record_with_output,
AsyncDB, Injected, MakeConnection, Record, Runner,
};

#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, ValueEnum)]
Expand Down Expand Up @@ -770,6 +770,7 @@ async fn update_record<M: MakeConnection>(
&record_output,
"\t",
default_validator,
default_normalizer,
default_column_validator,
) {
Some(new_record) => {
Expand Down
2 changes: 1 addition & 1 deletion sqllogictest-engines/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ postgres-types = { version = "0.2.8", features = ["derive", "with-chrono-0_4"] }
rust_decimal = { version = "1.36.0", features = ["tokio-pg"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
sqllogictest = { path = "../sqllogictest", version = "0.23" }
sqllogictest = { path = "../sqllogictest", version = "0.24" }
thiserror = "2"
tokio = { version = "1", features = [
"rt",
Expand Down
49 changes: 49 additions & 0 deletions sqllogictest/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ pub enum QueryExpect<T: ColumnType> {
Results {
types: Vec<T>,
sort_mode: Option<SortMode>,
result_mode: Option<ResultMode>,
label: Option<String>,
results: Vec<String>,
},
Expand All @@ -98,6 +99,7 @@ impl<T: ColumnType> QueryExpect<T> {
Self::Results {
types: Vec::new(),
sort_mode: None,
result_mode: None,
label: None,
results: Vec::new(),
}
Expand Down Expand Up @@ -287,6 +289,7 @@ impl<T: ColumnType> std::fmt::Display for Record<T> {
}
Record::Control(c) => match c {
Control::SortMode(m) => write!(f, "control sortmode {}", m.as_str()),
Control::ResultMode(m) => write!(f, "control resultmode {}", m.as_str()),
Control::Substitution(s) => write!(f, "control substitution {}", s.as_str()),
},
Record::Condition(cond) => match cond {
Expand Down Expand Up @@ -435,6 +438,8 @@ impl PartialEq for ExpectedError {
pub enum Control {
/// Control sort mode.
SortMode(SortMode),
/// control result mode.
ResultMode(ResultMode),
/// Control whether or not to substitute variables in the SQL.
Substitution(bool),
}
Expand Down Expand Up @@ -545,6 +550,38 @@ impl ControlItem for SortMode {
}
}

/// Whether the results should be parsed as value-wise or row-wise
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ResultMode {
/// Results are in a single column
ValueWise,
/// The default option where results are in columns separated by spaces
RowWise,
}

impl ControlItem for ResultMode {
fn try_from_str(s: &str) -> Result<Self, ParseErrorKind> {
match s {
"rowwise" => Ok(Self::RowWise),
"valuewise" => Ok(Self::ValueWise),
_ => Err(ParseErrorKind::InvalidSortMode(s.to_string())),
}
}

fn as_str(&self) -> &'static str {
match self {
Self::RowWise => "rowwise",
Self::ValueWise => "valuewise",
}
}
}

impl fmt::Display for ResultMode {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{self:?}")
}
}

/// The error type for parsing sqllogictest.
#[derive(thiserror::Error, Debug, PartialEq, Eq, Clone)]
#[error("parse error at {loc}: {kind}")]
Expand Down Expand Up @@ -754,6 +791,7 @@ fn parse_inner<T: ColumnType>(loc: &Location, script: &str) -> Result<Vec<Record
QueryExpect::Results {
types,
sort_mode,
result_mode: None,
label,
results: Vec::new(),
}
Expand Down Expand Up @@ -812,6 +850,12 @@ fn parse_inner<T: ColumnType>(loc: &Location, script: &str) -> Result<Vec<Record
});
}
["control", res @ ..] => match res {
["resultmode", result_mode] => match ResultMode::try_from_str(result_mode) {
Ok(result_mode) => {
records.push(Record::Control(Control::ResultMode(result_mode)))
}
Err(k) => return Err(k.at(loc)),
},
["sortmode", sort_mode] => match SortMode::try_from_str(sort_mode) {
Ok(sort_mode) => records.push(Record::Control(Control::SortMode(sort_mode))),
Err(k) => return Err(k.at(loc)),
Expand Down Expand Up @@ -988,6 +1032,11 @@ mod tests {
parse_roundtrip::<DefaultColumnType>("../tests/slt/rowsort.slt")
}

#[test]
fn test_valuesort() {
parse_roundtrip::<DefaultColumnType>("../tests/slt/valuesort.slt")
}

#[test]
fn test_substitution() {
parse_roundtrip::<DefaultColumnType>("../tests/substitution/basic.slt")
Expand Down
85 changes: 73 additions & 12 deletions sqllogictest/src/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -449,26 +449,39 @@ fn format_column_diff(expected: &str, actual: &str, colorize: bool) -> String {
format!("[Expected] {expected}\n[Actual ] {actual}")
}

/// Normalizer will be used by [`Runner`] to normalize the result values
///
/// # Default
///
/// By default, the ([`default_normalizer`]) will be used to normalize values.
pub type Normalizer = fn(s: &String) -> String;

/// Trim and replace multiple whitespaces with one.
#[allow(clippy::ptr_arg)]
fn normalize_string(s: &String) -> String {
pub fn default_normalizer(s: &String) -> String {
s.trim().split_ascii_whitespace().join(" ")
}

/// Validator will be used by [`Runner`] to validate the output.
///
/// # Default
///
/// By default ([`default_validator`]), we will use compare normalized results.
pub type Validator = fn(actual: &[Vec<String>], expected: &[String]) -> bool;

pub fn default_validator(actual: &[Vec<String>], expected: &[String]) -> bool {
let expected_results = expected.iter().map(normalize_string).collect_vec();
/// By default, the ([`default_validator`]) will be used compare normalized results.
pub type Validator =
fn(normalizer: Normalizer, actual: &[Vec<String>], expected: &[String]) -> bool;

pub fn default_validator(
normalizer: Normalizer,
actual: &[Vec<String>],
expected: &[String],
) -> bool {
let expected_results = expected.iter().map(normalizer).collect_vec();
// Default, we compare normalized results. Whitespace characters are ignored.
let normalized_rows = actual
.iter()
.map(|strs| strs.iter().map(normalize_string).join(" "))
.map(|strs| strs.iter().map(normalizer).join(" "))
.collect_vec();

normalized_rows == expected_results
}

Expand Down Expand Up @@ -502,9 +515,12 @@ pub struct Runner<D: AsyncDB, M: MakeConnection> {
conn: Connections<D, M>,
// validator is used for validate if the result of query equals to expected.
validator: Validator,
// normalizer is used to normalize the result text
normalizer: Normalizer,
column_type_validator: ColumnTypeValidator<D::ColumnType>,
substitution: Option<Substitution>,
sort_mode: Option<SortMode>,
result_mode: Option<ResultMode>,
/// 0 means never hashing
hash_threshold: usize,
/// Labels for condition `skipif` and `onlyif`.
Expand All @@ -518,9 +534,11 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
pub fn new(make_conn: M) -> Self {
Runner {
validator: default_validator,
normalizer: default_normalizer,
column_type_validator: default_column_validator,
substitution: None,
sort_mode: None,
result_mode: None,
hash_threshold: 0,
labels: HashSet::new(),
conn: Connections::new(make_conn),
Expand All @@ -532,6 +550,9 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
self.labels.insert(label.to_string());
}

pub fn with_normalizer(&mut self, normalizer: Normalizer) {
self.normalizer = normalizer;
}
pub fn with_validator(&mut self, validator: Validator) {
self.validator = validator;
}
Expand Down Expand Up @@ -769,15 +790,31 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
QueryExpect::Error(_) => None,
}
.or(self.sort_mode);

let mut value_sort = false;
match sort_mode {
None | Some(SortMode::NoSort) => {}
Some(SortMode::RowSort) => {
rows.sort_unstable();
}
Some(SortMode::ValueSort) => todo!("value sort"),
Some(SortMode::ValueSort) => {
rows = rows
.iter()
.flat_map(|row| row.iter())
.map(|s| vec![s.to_owned()])
.collect();
rows.sort_unstable();
value_sort = true;
}
};

if self.hash_threshold > 0 && rows.len() * types.len() > self.hash_threshold {
let num_values = if value_sort {
rows.len()
} else {
rows.len() * types.len()
};

if self.hash_threshold > 0 && num_values > self.hash_threshold {
let mut md5 = md5::Md5::new();
for line in &rows {
for value in line {
Expand Down Expand Up @@ -808,6 +845,9 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
Control::SortMode(sort_mode) => {
self.sort_mode = Some(sort_mode);
}
Control::ResultMode(result_mode) => {
self.result_mode = Some(result_mode);
}
Control::Substitution(on_off) => match (&mut self.substitution, on_off) {
(s @ None, true) => *s = Some(Substitution::default()),
(s @ Some(_), false) => *s = None,
Expand Down Expand Up @@ -996,7 +1036,17 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
.at(loc));
}

if !(self.validator)(rows, &expected_results) {
let actual_results = match self.result_mode {
Some(ResultMode::ValueWise) => rows
.iter()
.flat_map(|strs| strs.iter())
.map(|str| vec![str.to_string()])
.collect_vec(),
// default to rowwise
_ => rows.clone(),
};

if !(self.validator)(self.normalizer, &actual_results, &expected_results) {
let output_rows =
rows.iter().map(|strs| strs.iter().join(" ")).collect_vec();
return Err(TestErrorKind::QueryResultMismatch {
Expand Down Expand Up @@ -1167,9 +1217,11 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
conn_builder(target.clone(), db_name.clone()).map(Ok)
}),
validator: self.validator,
normalizer: self.normalizer,
column_type_validator: self.column_type_validator,
substitution: self.substitution.clone(),
sort_mode: self.sort_mode,
result_mode: self.result_mode,
hash_threshold: self.hash_threshold,
labels: self.labels.clone(),
};
Expand Down Expand Up @@ -1240,6 +1292,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
filename: impl AsRef<Path>,
col_separator: &str,
validator: Validator,
normalizer: Normalizer,
column_type_validator: ColumnTypeValidator<D::ColumnType>,
) -> Result<(), Box<dyn std::error::Error>> {
use std::io::{Read, Seek, SeekFrom, Write};
Expand Down Expand Up @@ -1355,6 +1408,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
&record_output,
col_separator,
validator,
normalizer,
column_type_validator,
)
.unwrap_or(record);
Expand Down Expand Up @@ -1384,6 +1438,7 @@ pub fn update_record_with_output<T: ColumnType>(
record_output: &RecordOutput<T>,
col_separator: &str,
validator: Validator,
normalizer: Normalizer,
column_type_validator: ColumnTypeValidator<T>,
) -> Option<Record<T>> {
match (record.clone(), record_output) {
Expand Down Expand Up @@ -1523,7 +1578,7 @@ pub fn update_record_with_output<T: ColumnType>(
QueryExpect::Results {
results: expected_results,
..
} if validator(rows, expected_results) => expected_results.clone(),
} if validator(normalizer, rows, expected_results) => expected_results.clone(),
_ => rows.iter().map(|cols| cols.join(col_separator)).collect(),
};
let types = match &expected {
Expand All @@ -1541,17 +1596,22 @@ pub fn update_record_with_output<T: ColumnType>(
connection,
expected: match expected {
QueryExpect::Results {
sort_mode, label, ..
sort_mode,
label,
result_mode,
..
} => QueryExpect::Results {
results,
types,
sort_mode,
result_mode,
label,
},
QueryExpect::Error(_) => QueryExpect::Results {
results,
types,
sort_mode: None,
result_mode: None,
label: None,
},
},
Expand Down Expand Up @@ -2009,6 +2069,7 @@ Caused by:
&record_output,
" ",
default_validator,
default_normalizer,
strict_column_validator,
);

Expand Down
5 changes: 4 additions & 1 deletion tests/custom_type/custom_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,8 @@ fn test() {
let mut tester = sqllogictest::Runner::new(|| async { Ok(FakeDB) });
tester.with_column_validator(strict_column_validator);

tester.run_file("./custom_type/custom_type.slt").unwrap();
let r = tester.run_file("./custom_type/custom_type.slt");
if let Err(err) = r {
eprintln!("{:?}", err);
}
}
Loading
Loading