Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support keyword completion #855

Merged
merged 2 commits into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,11 @@ impl Database {
.add_row_count(table_id, count);
Ok(true)
}

/// Return all available pragma options.
fn pragma_options() -> &'static [&'static str] {
&["enable_optimizer", "disable_optimizer"]
}
}

/// The error type of database operations.
Expand Down Expand Up @@ -230,3 +235,104 @@ pub enum Error {
#[error("Internal error: {0}")]
Internal(String),
}

impl rustyline::Helper for &Database {}
impl rustyline::validate::Validator for &Database {}
impl rustyline::highlight::Highlighter for &Database {}
impl rustyline::hint::Hinter for &Database {
type Hint = String;
}

/// Implement SQL completion.
impl rustyline::completion::Completer for &Database {
type Candidate = rustyline::completion::Pair;
fn complete(
&self,
line: &str,
pos: usize,
_ctx: &rustyline::Context<'_>,
) -> rustyline::Result<(usize, Vec<Self::Candidate>)> {
// find the word before cursor
let (prefix, last_word) = line[..pos].rsplit_once(' ').unwrap_or(("", &line[..pos]));

// completion for pragma options
if prefix.trim().eq_ignore_ascii_case("pragma") {
let candidates = Database::pragma_options()
.iter()
.filter(|option| option.starts_with(last_word))
.map(|option| rustyline::completion::Pair {
display: option.to_string(),
replacement: option.to_string(),
})
.collect();
return Ok((pos - last_word.len(), candidates));
}

// TODO: complete table and column names

// completion for keywords

// for a given prefix, all keywords starting with the prefix are returned as candidates
// they should be ordered in principle that frequently used ones come first
const KEYWORDS: &[&str] = &[
"AS", "ALL", "ANALYZE", "CREATE", "COPY", "DELETE", "DROP", "EXPLAIN", "FROM",
"FUNCTION", "INSERT", "JOIN", "ON", "PRAGMA", "SET", "SELECT", "TABLE", "UNION",
"VIEW", "WHERE", "WITH",
];
let last_word_upper = last_word.to_uppercase();
let candidates = KEYWORDS
.iter()
.filter(|command| command.starts_with(&last_word_upper))
.map(|command| rustyline::completion::Pair {
display: command.to_string(),
replacement: format!("{command} "),
})
.collect();
Ok((pos - last_word.len(), candidates))
}
}

#[cfg(test)]
mod tests {
use rustyline::history::DefaultHistory;

use super::*;

#[test]
fn test_completion() {
let db = Database::new_in_memory();
assert_complete(&db, "sel", "SELECT ");
assert_complete(&db, "sel|ect", "SELECT |ect");
assert_complete(&db, "select a f", "select a FROM ");
assert_complete(&db, "pragma en", "pragma enable_optimizer");
}

/// Assert that if complete (e.g. press tab) the given `line`, the result will be
/// `completed_line`.
///
/// Both `line` and `completed_line` can optionally contain a `|` which indicates the cursor
/// position. If not provided, the cursor is assumed to be at the end of the line.
#[track_caller]
fn assert_complete(db: &Database, line: &str, completed_line: &str) {
/// Find cursor position and remove it from the line.
fn get_line_and_cursor(line: &str) -> (String, usize) {
let (before_cursor, after_cursor) = line.split_once('|').unwrap_or((line, ""));
let pos = before_cursor.len();
(format!("{before_cursor}{after_cursor}"), pos)
}
let (mut line, pos) = get_line_and_cursor(line);

// complete
use rustyline::completion::Completer;
let (start_pos, candidates) = db
.complete(&line, pos, &rustyline::Context::new(&DefaultHistory::new()))
.unwrap();
let replacement = &candidates[0].replacement;
line.replace_range(start_pos..pos, replacement);

// assert
let (completed_line, completed_cursor_pos) = get_line_and_cursor(completed_line);
assert_eq!(line, completed_line);
assert_eq!(start_pos + replacement.len(), completed_cursor_pos);
}
}
8 changes: 5 additions & 3 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ use risinglight::storage::SecondaryStorageOptions;
use risinglight::utils::time::RoundingDuration;
use risinglight::Database;
use rustyline::error::ReadlineError;
use rustyline::DefaultEditor;
use rustyline::history::DefaultHistory;
use rustyline::Editor;
use sqllogictest::DefaultColumnType;
use tokio::{select, signal};
use tracing::{info, warn, Level};
Expand Down Expand Up @@ -149,7 +150,7 @@ async fn run_query_in_background(db: Arc<Database>, sql: String, output_format:
///
/// Note that `;` in string literals will also be treated as a terminator
/// as long as it is at the end of a line.
fn read_sql(rl: &mut DefaultEditor) -> Result<String, ReadlineError> {
fn read_sql(rl: &mut Editor<&Database, DefaultHistory>) -> Result<String, ReadlineError> {
let mut sql = String::new();
loop {
let prompt = if sql.is_empty() { "> " } else { "? " };
Expand All @@ -174,7 +175,7 @@ fn read_sql(rl: &mut DefaultEditor) -> Result<String, ReadlineError> {

/// Run RisingLight interactive mode
async fn interactive(db: Database, output_format: Option<String>) -> Result<()> {
let mut rl = DefaultEditor::new()?;
let mut rl = Editor::<&Database, DefaultHistory>::new()?;
let history_path = dirs::cache_dir().map(|p| {
let cache_dir = p.join("risinglight");
std::fs::create_dir_all(cache_dir.as_path()).ok();
Expand All @@ -192,6 +193,7 @@ async fn interactive(db: Database, output_format: Option<String>) -> Result<()>
}

let db = Arc::new(db);
rl.set_helper(Some(&db));

loop {
let read_sql = read_sql(&mut rl);
Expand Down
Loading