Skip to content

Commit

Permalink
Adding resultmode control and custom value normalizer. Bumped version.
Browse files Browse the repository at this point in the history
Signed-off-by: Bruce Ritchie <[email protected]>
  • Loading branch information
Omega359 committed Dec 12, 2024
1 parent c35b5bc commit 236a913
Show file tree
Hide file tree
Showing 10 changed files with 116 additions and 34 deletions.
6 changes: 3 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ resolver = "2"
members = ["sqllogictest", "sqllogictest-bin", "sqllogictest-engines", "tests"]

[workspace.package]
version = "0.23.0"
version = "0.24.0"
edition = "2021"
homepage = "https://github.com/risinglightdb/sqllogictest-rs"
keywords = ["sql", "database", "parser", "cli"]
Expand Down
4 changes: 2 additions & 2 deletions sqllogictest-bin/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ glob = "0.3"
itertools = "0.13"
quick-junit = { version = "0.5" }
rand = "0.8"
sqllogictest = { path = "../sqllogictest", version = "0.23" }
sqllogictest-engines = { path = "../sqllogictest-engines", version = "0.23" }
sqllogictest = { path = "../sqllogictest", version = "0.24" }
sqllogictest-engines = { path = "../sqllogictest-engines", version = "0.24" }
tokio = { version = "1", features = [
"rt",
"rt-multi-thread",
Expand Down
6 changes: 2 additions & 4 deletions sqllogictest-bin/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@ use itertools::Itertools;
use quick_junit::{NonSuccessKind, Report, TestCase, TestCaseStatus, TestSuite};
use rand::distributions::DistString;
use rand::seq::SliceRandom;
use sqllogictest::{
default_column_validator, default_validator, update_record_with_output, AsyncDB, Injected,
MakeConnection, Record, Runner,
};
use sqllogictest::{default_column_validator, default_normalizer, default_validator, update_record_with_output, AsyncDB, Injected, MakeConnection, Record, Runner};

#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, ValueEnum)]
#[must_use]
Expand Down Expand Up @@ -750,6 +747,7 @@ async fn update_record<M: MakeConnection>(
&record_output,
"\t",
default_validator,
default_normalizer,
default_column_validator,
) {
Some(new_record) => {
Expand Down
2 changes: 1 addition & 1 deletion sqllogictest-engines/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ postgres-types = { version = "0.2.8", features = ["derive", "with-chrono-0_4"] }
rust_decimal = { version = "1.36.0", features = ["tokio-pg"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
sqllogictest = { path = "../sqllogictest", version = "0.23" }
sqllogictest = { path = "../sqllogictest", version = "0.24" }
thiserror = "2"
tokio = { version = "1", features = [
"rt",
Expand Down
44 changes: 43 additions & 1 deletion sqllogictest/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ pub enum QueryExpect<T: ColumnType> {
Results {
types: Vec<T>,
sort_mode: Option<SortMode>,
result_mode: Option<ResultMode>,
label: Option<String>,
results: Vec<String>,
},
Expand All @@ -98,6 +99,7 @@ impl<T: ColumnType> QueryExpect<T> {
Self::Results {
types: Vec::new(),
sort_mode: None,
result_mode: None,
label: None,
results: Vec::new(),
}
Expand Down Expand Up @@ -287,6 +289,7 @@ impl<T: ColumnType> std::fmt::Display for Record<T> {
}
Record::Control(c) => match c {
Control::SortMode(m) => write!(f, "control sortmode {}", m.as_str()),
Control::ResultMode(m) => write!(f, "control resultmode {}", m.as_str()),
Control::Substitution(s) => write!(f, "control substitution {}", s.as_str()),
},
Record::Condition(cond) => match cond {
Expand Down Expand Up @@ -435,6 +438,8 @@ impl PartialEq for ExpectedError {
pub enum Control {
/// Control sort mode.
SortMode(SortMode),
/// control result mode.
ResultMode(ResultMode),
/// Control whether or not to substitute variables in the SQL.
Substitution(bool),
}
Expand Down Expand Up @@ -545,6 +550,38 @@ impl ControlItem for SortMode {
}
}

/// Whether the results should be parsed as value-wise or row-wise
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ResultMode {
/// Results are in a single column
ValueWise,
/// The default option where results are in columns separated by spaces
RowWise,
}

impl ControlItem for ResultMode {
fn try_from_str(s: &str) -> Result<Self, ParseErrorKind> {
match s {
"rowwise" => Ok(Self::RowWise),
"valuewise" => Ok(Self::ValueWise),
_ => Err(ParseErrorKind::InvalidSortMode(s.to_string())),
}
}

fn as_str(&self) -> &'static str {
match self {
Self::RowWise => "rowwise",
Self::ValueWise => "valuewise",
}
}
}

impl fmt::Display for ResultMode {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{self:?}")
}
}

/// The error type for parsing sqllogictest.
#[derive(thiserror::Error, Debug, PartialEq, Eq, Clone)]
#[error("parse error at {loc}: {kind}")]
Expand Down Expand Up @@ -754,6 +791,7 @@ fn parse_inner<T: ColumnType>(loc: &Location, script: &str) -> Result<Vec<Record
QueryExpect::Results {
types,
sort_mode,
result_mode: None,
label,
results: Vec::new(),
}
Expand Down Expand Up @@ -812,6 +850,10 @@ fn parse_inner<T: ColumnType>(loc: &Location, script: &str) -> Result<Vec<Record
});
}
["control", res @ ..] => match res {
["resultmode", result_mode] => match ResultMode::try_from_str(result_mode) {
Ok(result_mode) => records.push(Record::Control(Control::ResultMode(result_mode))),
Err(k) => return Err(k.at(loc)),
},
["sortmode", sort_mode] => match SortMode::try_from_str(sort_mode) {
Ok(sort_mode) => records.push(Record::Control(Control::SortMode(sort_mode))),
Err(k) => return Err(k.at(loc)),
Expand All @@ -829,7 +871,7 @@ fn parse_inner<T: ColumnType>(loc: &Location, script: &str) -> Result<Vec<Record
ParseErrorKind::InvalidNumber((*threshold).into()).at(loc.clone())
})?,
});
}
},
_ => return Err(ParseErrorKind::InvalidLine(line.into()).at(loc)),
}
}
Expand Down
64 changes: 44 additions & 20 deletions sqllogictest/src/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -449,26 +449,34 @@ fn format_column_diff(expected: &str, actual: &str, colorize: bool) -> String {
format!("[Expected] {expected}\n[Actual ] {actual}")
}

/// Normalizer will be used by [`Runner`] to normalize the result values
///
/// # Default
///
/// By default, the ([`default_normalizer`]) will be used to normalize values.
pub type Normalizer = fn(s: &String) -> String;

/// Trim and replace multiple whitespaces with one.
#[allow(clippy::ptr_arg)]
fn normalize_string(s: &String) -> String {
pub fn default_normalizer(s: &String) -> String {
s.trim().split_ascii_whitespace().join(" ")
}

/// Validator will be used by [`Runner`] to validate the output.
///
/// # Default
///
/// By default ([`default_validator`]), we will use compare normalized results.
pub type Validator = fn(actual: &[Vec<String>], expected: &[String]) -> bool;
/// By default, the ([`default_validator`]) will be used compare normalized results.
pub type Validator = fn(normalizer: Normalizer, actual: &[Vec<String>], expected: &[String]) -> bool;

pub fn default_validator(actual: &[Vec<String>], expected: &[String]) -> bool {
let expected_results = expected.iter().map(normalize_string).collect_vec();
pub fn default_validator(normalizer: Normalizer, actual: &[Vec<String>], expected: &[String]) -> bool {
let expected_results = expected.iter().map(normalizer).collect_vec();
// Default, we compare normalized results. Whitespace characters are ignored.
let normalized_rows = actual
.iter()
.map(|strs| strs.iter().map(normalize_string).join(" "))
.map(|strs| strs.iter().map(normalizer).join(" "))
.collect_vec();

normalized_rows == expected_results
}

Expand Down Expand Up @@ -502,9 +510,12 @@ pub struct Runner<D: AsyncDB, M: MakeConnection> {
conn: Connections<D, M>,
// validator is used for validate if the result of query equals to expected.
validator: Validator,
// normalizer is used to normalize the result text
normalizer: Normalizer,
column_type_validator: ColumnTypeValidator<D::ColumnType>,
substitution: Option<Substitution>,
sort_mode: Option<SortMode>,
result_mode: Option<ResultMode>,
/// 0 means never hashing
hash_threshold: usize,
/// Labels for condition `skipif` and `onlyif`.
Expand All @@ -518,9 +529,11 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
pub fn new(make_conn: M) -> Self {
Runner {
validator: default_validator,
normalizer: default_normalizer,
column_type_validator: default_column_validator,
substitution: None,
sort_mode: None,
result_mode: None,
hash_threshold: 0,
labels: HashSet::new(),
conn: Connections::new(make_conn),
Expand All @@ -532,6 +545,9 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
self.labels.insert(label.to_string());
}

pub fn with_normalizer(&mut self, normalizer: Normalizer) {
self.normalizer = normalizer;
}
pub fn with_validator(&mut self, validator: Validator) {
self.validator = validator;
}
Expand Down Expand Up @@ -824,6 +840,9 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
Control::SortMode(sort_mode) => {
self.sort_mode = Some(sort_mode);
}
Control::ResultMode(result_mode) => {
self.result_mode = Some(result_mode);
}
Control::Substitution(on_off) => match (&mut self.substitution, on_off) {
(s @ None, true) => *s = Some(Substitution::default()),
(s @ Some(_), false) => *s = None,
Expand Down Expand Up @@ -1012,20 +1031,17 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
.at(loc));
}

let actual_results =
if types.len() > 1 && rows.len() * types.len() == expected_results.len() {
// value-wise mode
let actual_results = match self.result_mode {
Some(ResultMode::ValueWise) =>
rows.into_iter()
.flat_map(|strs| strs.iter().map(normalize_string).collect_vec())
.collect_vec()
} else {
// row-wise mode
rows.into_iter()
.map(|strs| strs.iter().map(normalize_string).join(" "))
.collect_vec()
};
.flat_map(|strs| strs.into_iter())
.map(|str| vec![str.to_string()])
.collect_vec(),
// default to rowwise
_ => rows.clone(),
};

if !(self.validator)(&[actual_results], &expected_results) {
if !(self.validator)(self.normalizer, &actual_results, &expected_results) {
let output_rows =
rows.iter().map(|strs| strs.iter().join(" ")).collect_vec();
return Err(TestErrorKind::QueryResultMismatch {
Expand Down Expand Up @@ -1196,9 +1212,11 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
conn_builder(target.clone(), db_name.clone()).map(Ok)
}),
validator: self.validator,
normalizer: self.normalizer,
column_type_validator: self.column_type_validator,
substitution: self.substitution.clone(),
sort_mode: self.sort_mode,
result_mode: self.result_mode,
hash_threshold: self.hash_threshold,
labels: self.labels.clone(),
};
Expand Down Expand Up @@ -1269,6 +1287,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
filename: impl AsRef<Path>,
col_separator: &str,
validator: Validator,
normalizer: Normalizer,
column_type_validator: ColumnTypeValidator<D::ColumnType>,
) -> Result<(), Box<dyn std::error::Error>> {
use std::io::{Read, Seek, SeekFrom, Write};
Expand Down Expand Up @@ -1384,6 +1403,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
&record_output,
col_separator,
validator,
normalizer,
column_type_validator,
)
.unwrap_or(record);
Expand Down Expand Up @@ -1413,6 +1433,7 @@ pub fn update_record_with_output<T: ColumnType>(
record_output: &RecordOutput<T>,
col_separator: &str,
validator: Validator,
normalizer: Normalizer,
column_type_validator: ColumnTypeValidator<T>,
) -> Option<Record<T>> {
match (record.clone(), record_output) {
Expand Down Expand Up @@ -1552,7 +1573,7 @@ pub fn update_record_with_output<T: ColumnType>(
QueryExpect::Results {
results: expected_results,
..
} if validator(rows, expected_results) => expected_results.clone(),
} if validator(normalizer, rows, expected_results) => expected_results.clone(),
_ => rows.iter().map(|cols| cols.join(col_separator)).collect(),
};
let types = match &expected {
Expand All @@ -1570,17 +1591,19 @@ pub fn update_record_with_output<T: ColumnType>(
connection,
expected: match expected {
QueryExpect::Results {
sort_mode, label, ..
sort_mode, label, result_mode, ..
} => QueryExpect::Results {
results,
types,
sort_mode,
result_mode,
label,
},
QueryExpect::Error(_) => QueryExpect::Results {
results,
types,
sort_mode: None,
result_mode: None,
label: None,
},
},
Expand Down Expand Up @@ -2038,6 +2061,7 @@ Caused by:
&record_output,
" ",
default_validator,
default_normalizer,
strict_column_validator,
);

Expand Down
5 changes: 4 additions & 1 deletion tests/custom_type/custom_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,8 @@ fn test() {
let mut tester = sqllogictest::Runner::new(|| async { Ok(FakeDB) });
tester.with_column_validator(strict_column_validator);

tester.run_file("./custom_type/custom_type.slt").unwrap();
let r = tester.run_file("./custom_type/custom_type.slt");
if let Err(err) = r {
eprintln!("{:?}", err);
}
}
15 changes: 15 additions & 0 deletions tests/slt/rowsort.slt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,18 @@ select * from example_sort
1 10 2333
10 100 2333
2 20 2333

control resultmode valuewise

query III rowsort
select * from example_sort
----
1
10
2333
10
100
2333
2
20
2333
Loading

0 comments on commit 236a913

Please sign in to comment.