Skip to content

Commit

Permalink
refactor: return record output in run
Browse files Browse the repository at this point in the history
Signed-off-by: Bugen Zhao <[email protected]>
  • Loading branch information
BugenZhao committed Jun 24, 2024
1 parent 17d81db commit a71c006
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 15 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

* runner: `RecordOutput` is now returned by `Runner::run` (or `Runner::run_async`). This allows users to access the output of each record, or check whether the record is skipped.

## [0.20.6] - 2024-06-21

* runner: add logs for `system` command (with target `sqllogictest::system_command`) for ease of debugging.
Expand Down
49 changes: 34 additions & 15 deletions sqllogictest/src/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,22 @@ use crate::{ColumnType, Connections, MakeConnection};
/// Type-erased error type.
type AnyError = Arc<dyn std::error::Error + Send + Sync>;

/// Output of a record.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum RecordOutput<T: ColumnType> {
/// No output. Occurs when the record is skipped or not a `query`, `statement`, or `system`
/// command.
Nothing,
/// The output of a `query`.
Query {
types: Vec<T>,
rows: Vec<Vec<String>>,
error: Option<AnyError>,
},
Statement {
count: u64,
error: Option<AnyError>,
},
/// The output of a `statement`.
Statement { count: u64, error: Option<AnyError> },
/// The output of a `system` command.
#[non_exhaustive]
System {
stdout: Option<String>,
Expand Down Expand Up @@ -833,10 +836,13 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
}

/// Run a single record.
pub async fn run_async(&mut self, record: Record<D::ColumnType>) -> Result<(), TestError> {
pub async fn run_async(
&mut self,
record: Record<D::ColumnType>,
) -> Result<RecordOutput<D::ColumnType>, TestError> {
let result = self.apply_record(record.clone()).await;

match (record, result) {
match (record, &result) {
(_, RecordOutput::Nothing) => {}
// Tolerate the mismatched return type...
(
Expand Down Expand Up @@ -894,7 +900,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
.at(loc))
}
(None, StatementExpect::Count(expected_count)) => {
if expected_count != count {
if expected_count != *count {
return Err(TestErrorKind::StatementResultMismatch {
sql,
expected: expected_count,
Expand All @@ -908,7 +914,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
if !expected_error.is_match(&e.to_string()) {
return Err(TestErrorKind::ErrorMismatch {
sql,
err: Arc::new(e),
err: Arc::clone(e),
expected_err: expected_error.to_string(),
kind: RecordKind::Statement,
}
Expand All @@ -918,7 +924,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
(Some(e), StatementExpect::Count(_) | StatementExpect::Ok) => {
return Err(TestErrorKind::Fail {
sql,
err: Arc::new(e),
err: Arc::clone(e),
kind: RecordKind::Statement,
}
.at(loc));
Expand Down Expand Up @@ -946,7 +952,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
if !expected_error.is_match(&e.to_string()) {
return Err(TestErrorKind::ErrorMismatch {
sql,
err: Arc::new(e),
err: Arc::clone(e),
expected_err: expected_error.to_string(),
kind: RecordKind::Query,
}
Expand All @@ -956,7 +962,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
(Some(e), QueryExpect::Results { .. }) => {
return Err(TestErrorKind::Fail {
sql,
err: Arc::new(e),
err: Arc::clone(e),
kind: RecordKind::Query,
}
.at(loc));
Expand Down Expand Up @@ -1006,12 +1012,16 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
},
) => {
if let Some(err) = error {
return Err(TestErrorKind::SystemFail { command, err }.at(loc));
return Err(TestErrorKind::SystemFail {
command,
err: Arc::clone(err),
}
.at(loc));
}
match (expected_stdout, actual_stdout) {
(None, _) => {}
(Some(expected_stdout), actual_stdout) => {
let actual_stdout = actual_stdout.unwrap_or_default();
let actual_stdout = actual_stdout.clone().unwrap_or_default();
// TODO: support newlines contained in expected_stdout
if expected_stdout != actual_stdout.trim() {
return Err(TestErrorKind::SystemStdoutMismatch {
Expand All @@ -1027,17 +1037,24 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
_ => unreachable!(),
}

Ok(())
Ok(result)
}

/// Run a single record.
pub fn run(&mut self, record: Record<D::ColumnType>) -> Result<(), TestError> {
///
/// Returns the output of the record if successful.
pub fn run(
&mut self,
record: Record<D::ColumnType>,
) -> Result<RecordOutput<D::ColumnType>, TestError> {
futures::executor::block_on(self.run_async(record))
}

/// Run multiple records.
///
/// The runner will stop early once a halt record is seen.
///
/// To acquire the result of each record, manually call `run_async` for each record instead.
pub async fn run_multi_async(
&mut self,
records: impl IntoIterator<Item = Record<D::ColumnType>>,
Expand All @@ -1054,6 +1071,8 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
/// Run multiple records.
///
/// The runner will stop early once a halt record is seen.
///
/// To acquire the result of each record, manually call `run` for each record instead.
pub fn run_multi(
&mut self,
records: impl IntoIterator<Item = Record<D::ColumnType>>,
Expand Down

0 comments on commit a71c006

Please sign in to comment.