Skip to content

Commit

Permalink
feat: add start/end interval for bam/sam/cram (#186)
Browse files Browse the repository at this point in the history
* feat: rename interval to pos interval

* feat: add start/end interval for bam/sam/cram
  • Loading branch information
tshauck authored Sep 27, 2023
1 parent 46b7168 commit 49fbb36
Show file tree
Hide file tree
Showing 8 changed files with 332 additions and 54 deletions.
8 changes: 8 additions & 0 deletions exon/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,14 @@ mod tests {
BinaryExpr::new(left, Operator::GtEq, right)
}

/// Helper that combines `left` and `right` into a `BinaryExpr` using the
/// strict greater-than operator (`Operator::Gt`).
///
/// Both operands are moved into the resulting expression.
pub(crate) fn gt(left: Arc<dyn PhysicalExpr>, right: Arc<dyn PhysicalExpr>) -> BinaryExpr {
    let operator = Operator::Gt;
    BinaryExpr::new(left, operator, right)
}

/// Helper that combines `left` and `right` into a `BinaryExpr` using the
/// strict less-than operator (`Operator::Lt`).
///
/// Both operands are moved into the resulting expression.
pub(crate) fn lt(left: Arc<dyn PhysicalExpr>, right: Arc<dyn PhysicalExpr>) -> BinaryExpr {
    let operator = Operator::Lt;
    BinaryExpr::new(left, operator, right)
}

pub fn make_object_store() -> Arc<dyn ObjectStore> {
let local_file_system = LocalFileSystem::new();

Expand Down
10 changes: 6 additions & 4 deletions exon/src/physical_optimizer/interval_optimizer_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use datafusion::physical_plan::expressions::BinaryExpr;
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::{with_new_children_if_necessary, ExecutionPlan};

use crate::physical_plan::interval_physical_expr::IntervalPhysicalExpr;
use crate::physical_plan::pos_interval_physical_expr::PosIntervalPhysicalExpr;

fn optimize(plan: Arc<dyn ExecutionPlan>) -> Result<Transformed<Arc<dyn ExecutionPlan>>> {
let plan = if plan.children().is_empty() {
Expand Down Expand Up @@ -53,7 +53,7 @@ fn optimize(plan: Arc<dyn ExecutionPlan>) -> Result<Transformed<Arc<dyn Executio
None => return Ok(Transformed::No(plan)),
};

let interval_expr = match IntervalPhysicalExpr::try_from(pred.clone()) {
let interval_expr = match PosIntervalPhysicalExpr::try_from(pred.clone()) {
Ok(expr) => expr,
Err(_) => return Ok(Transformed::No(plan)),
};
Expand Down Expand Up @@ -94,7 +94,9 @@ mod tests {
use datafusion::{physical_plan::filter::FilterExec, prelude::SessionContext};
use noodles::core::region::Interval;

use crate::{physical_plan::interval_physical_expr::IntervalPhysicalExpr, ExonSessionExt};
use crate::{
physical_plan::pos_interval_physical_expr::PosIntervalPhysicalExpr, ExonSessionExt,
};

#[tokio::test]
async fn test_interval_rule_eq() {
Expand Down Expand Up @@ -123,7 +125,7 @@ mod tests {
let pred = filter_exec
.predicate()
.as_any()
.downcast_ref::<IntervalPhysicalExpr>()
.downcast_ref::<PosIntervalPhysicalExpr>()
.unwrap();

let expected_interval = Interval::from_str("1-1").unwrap();
Expand Down
40 changes: 20 additions & 20 deletions exon/src/physical_optimizer/merging.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use datafusion::error::{DataFusionError, Result};
use noodles::core::Region;

use crate::physical_plan::{
interval_physical_expr::{pos_schema, IntervalPhysicalExpr},
pos_interval_physical_expr::{pos_schema, PosIntervalPhysicalExpr},
region_name_physical_expr::RegionNamePhysicalExpr,
region_physical_expr::RegionPhysicalExpr,
};
Expand Down Expand Up @@ -68,20 +68,20 @@ pub(crate) fn try_merge_chrom_exprs(

/// Merge two `PosIntervalPhysicalExpr`s by intersecting their ranges.
pub fn try_merge_interval_exprs(
left: &IntervalPhysicalExpr,
right: &IntervalPhysicalExpr,
) -> Result<Option<IntervalPhysicalExpr>> {
left: &PosIntervalPhysicalExpr,
right: &PosIntervalPhysicalExpr,
) -> Result<Option<PosIntervalPhysicalExpr>> {
match intersect_ranges(left.interval_tuple(), right.interval_tuple()) {
Some((start, Some(end))) => {
return Ok(Some(IntervalPhysicalExpr::new(
return Ok(Some(PosIntervalPhysicalExpr::new(
start,
Some(end),
left.inner().clone(),
)))
}
Some((start, None)) => {
let schema = pos_schema();
let interval_expr = IntervalPhysicalExpr::from_interval(start, None, &schema)?;
let interval_expr = PosIntervalPhysicalExpr::from_interval(start, None, &schema)?;

Ok(Some(interval_expr))
}
Expand All @@ -91,13 +91,13 @@ pub fn try_merge_interval_exprs(

pub fn try_merge_region_with_interval(
left: &RegionPhysicalExpr,
right: &IntervalPhysicalExpr,
right: &PosIntervalPhysicalExpr,
) -> Result<Option<RegionPhysicalExpr>> {
let interval = match left.interval_expr() {
Some(interval_expr) => interval_expr,
None => {
let new_interval = right;
let new_interval = IntervalPhysicalExpr::new(
let new_interval = PosIntervalPhysicalExpr::new(
new_interval.start(),
new_interval.end(),
new_interval.inner().clone(),
Expand Down Expand Up @@ -212,7 +212,7 @@ mod tests {
use crate::{
physical_optimizer::merging::{try_merge_chrom_exprs, try_merge_interval_exprs},
physical_plan::{
interval_physical_expr::{pos_schema, IntervalPhysicalExpr},
pos_interval_physical_expr::{pos_schema, PosIntervalPhysicalExpr},
region_name_physical_expr::RegionNamePhysicalExpr,
},
tests::{and, gteq, lteq},
Expand Down Expand Up @@ -251,17 +251,17 @@ mod tests {
]));

let left_expr = gteq(col("pos", &schema).unwrap(), lit(3));
let left_interval = IntervalPhysicalExpr::try_from(left_expr).unwrap();
let left_interval = PosIntervalPhysicalExpr::try_from(left_expr).unwrap();

let right_expr = lteq(col("pos", &schema).unwrap(), lit(4));
let right_interval = IntervalPhysicalExpr::try_from(right_expr).unwrap();
let right_interval = PosIntervalPhysicalExpr::try_from(right_expr).unwrap();

try_merge_interval_exprs(&left_interval, &right_interval)
.unwrap()
.unwrap();

let right_expr = lteq(col("pos", &schema).unwrap(), lit(2));
let right_interval = IntervalPhysicalExpr::try_from(right_expr).unwrap();
let right_interval = PosIntervalPhysicalExpr::try_from(right_expr).unwrap();

let merged = try_merge_interval_exprs(&left_interval, &right_interval).unwrap();
assert!(merged.is_none());
Expand All @@ -280,8 +280,8 @@ mod tests {

let inner = Arc::new(inner_expression);

let right_interval = IntervalPhysicalExpr::new(1, Some(10), inner.clone());
let left_interval = IntervalPhysicalExpr::new(1, Some(10), inner.clone());
let right_interval = PosIntervalPhysicalExpr::new(1, Some(10), inner.clone());
let left_interval = PosIntervalPhysicalExpr::new(1, Some(10), inner.clone());

let merged = try_merge_interval_exprs(&left_interval, &right_interval)
.unwrap()
Expand All @@ -298,10 +298,10 @@ mod tests {
]));

let left_inner = gteq(col("pos", &schema).unwrap(), lit(ScalarValue::from(10)));
let left = IntervalPhysicalExpr::try_from(left_inner).unwrap();
let left = PosIntervalPhysicalExpr::try_from(left_inner).unwrap();

let right_inner = lteq(col("pos", &schema).unwrap(), lit(ScalarValue::from(20)));
let right = IntervalPhysicalExpr::try_from(right_inner).unwrap();
let right = PosIntervalPhysicalExpr::try_from(right_inner).unwrap();

let merged = try_merge_interval_exprs(&left, &right)
.unwrap()
Expand All @@ -324,10 +324,10 @@ mod tests {
]));

let gteq_expr = gteq(col("pos", &schema).unwrap(), lit(4));
let gt_interval = super::IntervalPhysicalExpr::try_from(gteq_expr).unwrap();
let gt_interval = super::PosIntervalPhysicalExpr::try_from(gteq_expr).unwrap();

let lteq_expr = lteq(col("pos", &schema).unwrap(), lit(10));
let lt_interval = super::IntervalPhysicalExpr::try_from(lteq_expr).unwrap();
let lt_interval = super::PosIntervalPhysicalExpr::try_from(lteq_expr).unwrap();

let interval = try_merge_interval_exprs(&gt_interval, &lt_interval)
.unwrap()
Expand All @@ -346,10 +346,10 @@ mod tests {
let schema = pos_schema();

let left_expr = gteq(col("pos", &schema).unwrap(), lit(3));
let left_interval = IntervalPhysicalExpr::try_from(left_expr).unwrap();
let left_interval = PosIntervalPhysicalExpr::try_from(left_expr).unwrap();

let right_expr = gteq(col("pos", &schema).unwrap(), lit(4));
let right_interval = IntervalPhysicalExpr::try_from(right_expr).unwrap();
let right_interval = PosIntervalPhysicalExpr::try_from(right_expr).unwrap();

let merged = try_merge_interval_exprs(&left_interval, &right_interval)
.unwrap()
Expand Down
16 changes: 10 additions & 6 deletions exon/src/physical_optimizer/region_filter_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use std::sync::Arc;
use datafusion::error::Result;
use datafusion::{common::tree_node::Transformed, physical_plan::PhysicalExpr};

use crate::physical_plan::interval_physical_expr::IntervalPhysicalExpr;
use crate::physical_plan::pos_interval_physical_expr::PosIntervalPhysicalExpr;
use crate::physical_plan::region_name_physical_expr::RegionNamePhysicalExpr;
use crate::physical_plan::region_physical_expr::RegionPhysicalExpr;

Expand All @@ -37,7 +37,7 @@ pub fn transform_region_expressions(
return Ok(Transformed::Yes(Arc::new(region_expr)));
}

if let Ok(interval_expr) = IntervalPhysicalExpr::try_from(be.clone()) {
if let Ok(interval_expr) = PosIntervalPhysicalExpr::try_from(be.clone()) {
return Ok(Transformed::Yes(Arc::new(interval_expr)));
}

Expand All @@ -58,8 +58,10 @@ pub fn transform_region_expressions(

// Case 2: left is a chrom expression and right is an interval expression
if let Some(_left_chrom) = be.left().as_any().downcast_ref::<RegionNamePhysicalExpr>() {
if let Some(_right_interval) =
be.right().as_any().downcast_ref::<IntervalPhysicalExpr>()
if let Some(_right_interval) = be
.right()
.as_any()
.downcast_ref::<PosIntervalPhysicalExpr>()
{
let new_expr =
RegionPhysicalExpr::new(be.left().clone(), Some(be.right().clone()));
Expand All @@ -70,8 +72,10 @@ pub fn transform_region_expressions(

// Case 3: left is a region expression and the right is an interval expression
if let Some(left_region) = be.left().as_any().downcast_ref::<RegionPhysicalExpr>() {
if let Some(right_interval) =
be.right().as_any().downcast_ref::<IntervalPhysicalExpr>()
if let Some(right_interval) = be
.right()
.as_any()
.downcast_ref::<PosIntervalPhysicalExpr>()
{
let new_region = try_merge_region_with_interval(left_region, right_interval)?;

Expand Down
5 changes: 4 additions & 1 deletion exon/src/physical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
pub mod region_name_physical_expr;

/// A physical expression that represents a genomic interval.
pub mod interval_physical_expr;
pub mod pos_interval_physical_expr;

/// A physical expression that represents a region, e.g. chr1:100-200.
pub mod region_physical_expr;
Expand All @@ -26,3 +26,6 @@ pub mod object_store;

/// Builder for a file scan configuration.
pub mod file_scan_config_builder;

/// A physical expression that represents a start/end interval.
pub mod start_end_interval_physical_expr;
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ pub(crate) fn pos_schema() -> SchemaRef {

/// A physical expression that represents a genomic interval.
#[derive(Debug)]
pub struct IntervalPhysicalExpr {
pub struct PosIntervalPhysicalExpr {
start: usize,
end: Option<usize>,
inner: Arc<dyn PhysicalExpr>,
}

impl IntervalPhysicalExpr {
impl PosIntervalPhysicalExpr {
/// Create a new interval physical expression from an interval and an inner expression.
pub fn new(start: usize, end: Option<usize>, inner: Arc<dyn PhysicalExpr>) -> Self {
Self { start, end, inner }
Expand Down Expand Up @@ -104,7 +104,7 @@ impl IntervalPhysicalExpr {
}
}

impl Display for IntervalPhysicalExpr {
impl Display for PosIntervalPhysicalExpr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
Expand All @@ -114,7 +114,7 @@ impl Display for IntervalPhysicalExpr {
}
}

impl TryFrom<BinaryExpr> for IntervalPhysicalExpr {
impl TryFrom<BinaryExpr> for PosIntervalPhysicalExpr {
type Error = DataFusionError;

fn try_from(expr: BinaryExpr) -> Result<Self, Self::Error> {
Expand Down Expand Up @@ -155,7 +155,7 @@ impl TryFrom<BinaryExpr> for IntervalPhysicalExpr {
}
}

impl TryFrom<Arc<dyn PhysicalExpr>> for IntervalPhysicalExpr {
impl TryFrom<Arc<dyn PhysicalExpr>> for PosIntervalPhysicalExpr {
type Error = DataFusionError;

fn try_from(expr: Arc<dyn PhysicalExpr>) -> Result<Self, Self::Error> {
Expand All @@ -167,23 +167,23 @@ impl TryFrom<Arc<dyn PhysicalExpr>> for IntervalPhysicalExpr {
}
}

impl PartialEq<dyn Any> for IntervalPhysicalExpr {
impl PartialEq<dyn Any> for PosIntervalPhysicalExpr {
fn eq(&self, other: &dyn Any) -> bool {
if let Some(other) = other.downcast_ref::<IntervalPhysicalExpr>() {
if let Some(other) = other.downcast_ref::<Self>() {
self.start == other.start && self.end == other.end
} else {
false
}
}
}

impl PartialEq for IntervalPhysicalExpr {
impl PartialEq for PosIntervalPhysicalExpr {
fn eq(&self, other: &Self) -> bool {
self.start == other.start && self.end == other.end
}
}

impl PhysicalExpr for IntervalPhysicalExpr {
impl PhysicalExpr for PosIntervalPhysicalExpr {
fn as_any(&self) -> &dyn std::any::Any {
self
}
Expand Down Expand Up @@ -217,7 +217,7 @@ impl PhysicalExpr for IntervalPhysicalExpr {
self: std::sync::Arc<Self>,
_children: Vec<std::sync::Arc<dyn PhysicalExpr>>,
) -> datafusion::error::Result<std::sync::Arc<dyn PhysicalExpr>> {
Ok(Arc::new(IntervalPhysicalExpr::new(
Ok(Arc::new(PosIntervalPhysicalExpr::new(
self.start,
self.end,
self.inner.clone(),
Expand Down Expand Up @@ -246,11 +246,11 @@ mod tests {
use noodles::core::Position;

use crate::{
physical_plan::interval_physical_expr,
physical_plan::pos_interval_physical_expr,
tests::{eq, gteq},
};

use super::IntervalPhysicalExpr;
use super::PosIntervalPhysicalExpr;

#[test]
fn test_call_interval_with_no_upper_bound() {
Expand All @@ -259,7 +259,8 @@ mod tests {
]));

let expr = gteq(col("pos", &schema).unwrap(), lit(4));
let interval_expr = interval_physical_expr::IntervalPhysicalExpr::try_from(expr).unwrap();
let interval_expr =
pos_interval_physical_expr::PosIntervalPhysicalExpr::try_from(expr).unwrap();

assert_eq!(interval_expr.start, 4);
assert_eq!(interval_expr.end, None);
Expand All @@ -275,7 +276,7 @@ mod tests {

let pos_expr = eq(col("pos", &schema).unwrap(), lit(4));

let interval = super::IntervalPhysicalExpr::try_from(pos_expr).unwrap();
let interval = super::PosIntervalPhysicalExpr::try_from(pos_expr).unwrap();

assert_eq!(
interval.interval().unwrap(),
Expand All @@ -297,8 +298,11 @@ mod tests {

let binary_expr = eq(col("pos", &batch.schema()).unwrap(), lit(1i64));

let expr =
interval_physical_expr::IntervalPhysicalExpr::new(1, Some(1), Arc::new(binary_expr));
let expr = pos_interval_physical_expr::PosIntervalPhysicalExpr::new(
1,
Some(1),
Arc::new(binary_expr),
);

let result = match expr.evaluate(&batch)? {
datafusion::physical_plan::ColumnarValue::Array(array) => array,
Expand Down Expand Up @@ -329,7 +333,7 @@ mod tests {
arrow::datatypes::Field::new("pos", arrow::datatypes::DataType::Int64, false),
]));

let interval_expr = IntervalPhysicalExpr::from_interval(1, Some(10), &schema).unwrap();
let interval_expr = PosIntervalPhysicalExpr::from_interval(1, Some(10), &schema).unwrap();

// The interval_expr should be a BinaryExpr with an AND operator
let inner_expr = interval_expr
Expand All @@ -344,7 +348,7 @@ mod tests {
}

// Now test that without an end, we get a GtEq
let interval_expr = IntervalPhysicalExpr::from_interval(1, None, &schema).unwrap();
let interval_expr = PosIntervalPhysicalExpr::from_interval(1, None, &schema).unwrap();

let inner_expr = interval_expr
.inner
Expand Down
Loading

0 comments on commit 49fbb36

Please sign in to comment.