Skip to content

Commit

Permalink
Refactor table heap iterator
Browse files Browse the repository at this point in the history
  • Loading branch information
lewiszlw committed Feb 24, 2024
1 parent 5f31fcd commit c563dae
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 19 deletions.
15 changes: 10 additions & 5 deletions bustubx/src/execution/physical_plan/seq_scan.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use log::debug;

Check warning on line 1 in bustubx/src/execution/physical_plan/seq_scan.rs

View workflow job for this annotation

GitHub Actions / Test Suite

unused import: `log::debug`
use std::ops::{Bound, RangeBounds, RangeFull};

Check warning on line 2 in bustubx/src/execution/physical_plan/seq_scan.rs

View workflow job for this annotation

GitHub Actions / Test Suite

unused import: `RangeBounds`
use std::sync::Mutex;

use crate::catalog::SchemaRef;
use crate::common::rid::Rid;

Check warning on line 6 in bustubx/src/execution/physical_plan/seq_scan.rs

View workflow job for this annotation

GitHub Actions / Test Suite

unused import: `crate::common::rid::Rid`
use crate::common::TableReference;
use crate::{
execution::{ExecutionContext, VolcanoExecutor},
Expand All @@ -22,18 +24,21 @@ impl PhysicalSeqScan {
PhysicalSeqScan {
table,
table_schema,
iterator: Mutex::new(TableIterator::new(None, None)),
iterator: Mutex::new(TableIterator::new(
Bound::Unbounded,
Bound::Unbounded,
None,
false,
false,
)),
}
}
}

impl VolcanoExecutor for PhysicalSeqScan {
// Prepares the sequential scan: fetches the table heap from the catalog and
// resets the shared iterator to a full-range scan.
// NOTE(review): this span is a rendered diff — the `debug!`/`iter(None, None)`
// lines are the pre-refactor body; only the `scan(RangeFull)` line survives
// this commit.
fn init(&self, context: &mut ExecutionContext) -> BustubxResult<()> {
debug!("init table scan executor");
let table_heap = context.catalog.table_heap(&self.table)?;
let inited_iterator = table_heap.iter(None, None);
let mut iterator = self.iterator.lock().unwrap();
*iterator = inited_iterator;
*self.iterator.lock().unwrap() = table_heap.scan(RangeFull);
Ok(())
}

Expand Down
34 changes: 32 additions & 2 deletions bustubx/src/storage/index.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::collections::VecDeque;
use std::ops::{Bound, Range, RangeBounds};

Check warning on line 2 in bustubx/src/storage/index.rs

View workflow job for this annotation

GitHub Actions / Test Suite

unused imports: `Bound`, `Range`
use std::sync::{Arc, RwLock};

use crate::buffer::{Page, PageId, INVALID_PAGE_ID};
use crate::catalog::SchemaRef;
use crate::catalog::{Schema, SchemaRef};

Check warning on line 6 in bustubx/src/storage/index.rs

View workflow job for this annotation

GitHub Actions / Test Suite

unused import: `Schema`
use crate::common::util::page_bytes_to_array;
use crate::storage::codec::{
BPlusTreeInternalPageCodec, BPlusTreeLeafPageCodec, BPlusTreePageCodec,
Expand Down Expand Up @@ -41,6 +42,8 @@ pub struct BPlusTreeIndex {
pub root_page_id: PageId,
}

/// Placeholder for a B+ tree index range-scan iterator.
/// NOTE(review): introduced empty in this commit — no fields or impl yet;
/// presumably the counterpart to `TableIterator` for index scans.
pub struct TreeIndexIterator {}

impl BPlusTreeIndex {
pub fn new(
key_schema: SchemaRef,
Expand Down Expand Up @@ -205,7 +208,11 @@ impl BPlusTreeIndex {
Ok(())
}

// NOTE(review): rendered diff — the old key-based signature on the next line
// is replaced by the range-based generic signature below it.
pub fn scan(&self, _key: &Tuple) -> Vec<Rid> {
/// Scans the index for all RIDs whose keys fall within `range`.
/// NOTE(review): the body is a stub — `range.start_bound()` is evaluated only
/// to touch the parameter, then the function panics via `unimplemented!()`.
pub fn scan<R>(&self, range: R) -> Vec<Rid>
where
R: RangeBounds<Tuple>,
{
range.start_bound();
unimplemented!()
}

Expand Down Expand Up @@ -749,4 +756,27 @@ B+ Tree Level No.2:
+--------------------------------------+--------------------------------------+--------------------------------------+
");
}

#[test]
pub fn test_index_get() {
    // Point lookups against the index built by `build_index()`:
    // each (i8, i16) key pair (k, k) is expected to resolve to Rid::new(k, k).
    let (mut index, key_schema) = build_index();
    let cases = [
        (3i8, 3i16, Rid::new(3, 3)),
        (10i8, 10i16, Rid::new(10, 10)),
    ];
    for (k1, k2, expected) in cases {
        let key = Tuple::new(key_schema.clone(), vec![k1.into(), k2.into()]);
        assert_eq!(index.get(&key).unwrap(), Some(expected));
    }
}
}
76 changes: 64 additions & 12 deletions bustubx/src/storage/table_heap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use crate::common::util::page_bytes_to_array;
use crate::storage::codec::TablePageCodec;
use crate::storage::{TablePage, TupleMeta};
use crate::{buffer::BufferPoolManager, common::rid::Rid, BustubxResult};
use std::collections::Bound;
use std::ops::RangeBounds;
use std::sync::atomic::Ordering;
use std::sync::Arc;

Expand Down Expand Up @@ -182,35 +184,85 @@ impl TableHeap {
}
}

// NOTE(review): rendered diff — the old `iter(start_at, stop_at)` signature
// and its `rid`/`stop_at` initializers still appear here; only the
// range-based `scan` lines survive this commit.
pub fn iter(&self, start_at: Option<Rid>, stop_at: Option<Rid>) -> TableIterator {
// Builds a `TableIterator` over the RID range `rang`.
// NOTE(review): parameter name `rang` looks like a typo for `range` — worth
// renaming in a follow-up.
pub fn scan<R: RangeBounds<Rid>>(&self, rang: R) -> TableIterator {
TableIterator {
// pre-refactor fields (removed by this commit):
rid: start_at.or(self.get_first_rid()),
stop_at,
// `cloned()` because `RangeBounds::start_bound` yields `Bound<&Rid>`.
start_bound: rang.start_bound().cloned(),
end_bound: rang.end_bound().cloned(),
cursor: None,
started: false,
ended: false,
}
}
}

// Iterator over the tuples of a `TableHeap`, bounded by a RID range.
// `derive_new::new` generates
// `TableIterator::new(start_bound, end_bound, cursor, started, ended)`,
// matching the call site in seq_scan.rs.
#[derive(derive_new::new, Debug)]
pub struct TableIterator {
// NOTE(review): rendered diff — the two `pub` fields below are the
// pre-refactor state and are removed by this commit.
pub rid: Option<Rid>,
pub stop_at: Option<Rid>,
// Lower edge of the RID range (inclusive, exclusive, or unbounded).
start_bound: Bound<Rid>,
// Upper edge of the RID range.
end_bound: Bound<Rid>,
// RID of the most recently returned tuple; `None` before iteration begins.
cursor: Option<Rid>,
// Intended to flip to true after the first tuple is yielded.
// NOTE(review): no visible assignment ever sets this to true — confirm
// against the full file.
started: bool,
// Set once the end bound has been passed; `next` then short-circuits to `None`.
ended: bool,
}

impl TableIterator {
// Yields the next (meta, tuple) pair from `table_heap`, honoring the start
// and end bounds captured at construction; returns `None` when exhausted.
//
// NOTE(review): rendered diff — the first statements below are the
// pre-refactor body (rid/stop_at based) and are removed by this commit.
pub fn next(&mut self, table_heap: &TableHeap) -> Option<(TupleMeta, Tuple)> {
self.rid?;
let rid = self.rid.unwrap();
if self.stop_at.is_some() && rid == self.stop_at.unwrap() {
if self.ended {
return None;
}
let result = table_heap.tuple(rid).unwrap();
self.rid = table_heap.get_next_rid(rid);
Some(result)

// NOTE(review): `started` is never assigned `true` anywhere in the visible
// code, so this branch looks unreachable and every call would re-run the
// start-bound logic below, repeatedly yielding the first tuple — confirm
// against the full file.
if self.started {
match self.end_bound {
Bound::Included(rid) => {
if let Some(next_rid) = table_heap.get_next_rid(self.cursor.unwrap()) {
// Yield the bound RID itself, then stop on the following call.
if next_rid == rid {
self.ended = true;
}
self.cursor = Some(next_rid);
self.cursor.map(|rid| table_heap.tuple(rid).unwrap())
} else {
None
}
}
Bound::Excluded(rid) => {
if let Some(next_rid) = table_heap.get_next_rid(self.cursor.unwrap()) {
if next_rid == rid {
// NOTE(review): `ended` is not set here, so later calls recompute
// the same `None` — harmless but redundant.
None
} else {
self.cursor = Some(next_rid);
self.cursor.map(|rid| table_heap.tuple(rid).unwrap())
}
} else {
None
}
}
Bound::Unbounded => {
let next_rid = table_heap.get_next_rid(self.cursor.unwrap());
self.cursor = next_rid;
self.cursor.map(|rid| table_heap.tuple(rid).unwrap())
}
}
} else {
// First call: position the cursor from the start bound and yield.
// NOTE(review): the end bound is not consulted here, so a range whose
// end excludes its own start would still yield that first tuple —
// confirm intended semantics.
match self.start_bound {
Bound::Included(rid) => {
self.cursor = Some(rid.clone());
Some(table_heap.tuple(rid).unwrap())
}
Bound::Excluded(rid) => {
self.cursor = table_heap.get_next_rid(rid);
self.cursor.map(|rid| table_heap.tuple(rid).unwrap())
}
Bound::Unbounded => {
self.cursor = table_heap.get_first_rid();
self.cursor.map(|rid| table_heap.tuple(rid).unwrap())
}
}
}
}
}

#[cfg(test)]
mod tests {
use std::ops::RangeFull;
use std::sync::Arc;
use tempfile::TempDir;

Expand Down Expand Up @@ -377,7 +429,7 @@ mod tests {
)
.unwrap();

let mut iterator = table_heap.iter(None, None);
let mut iterator = table_heap.scan(RangeFull);

let (meta, tuple) = iterator.next(&table_heap).unwrap();
assert_eq!(meta, meta1);
Expand Down

0 comments on commit c563dae

Please sign in to comment.