Skip to content

Commit

Permalink
feat: add regex positional match (#453)
Browse files Browse the repository at this point in the history
  • Loading branch information
tshauck authored Mar 27, 2024
1 parent 6978b8a commit 2f526d6
Show file tree
Hide file tree
Showing 5 changed files with 244 additions and 2 deletions.
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions exon/exon-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ url = { workspace = true }
fxhash = "0.2.1"
lazy_static = "1.4.0"
serde = { version = "1.0.197", features = ["derive"] }
regex = "1.10.4"

[dev-dependencies]
exon-test = { path = "../exon-test" }
Expand Down
216 changes: 216 additions & 0 deletions exon/exon-core/src/udfs/sequence/locate_regex.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
// Copyright 2023 WHERE TRUE Technologies.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use core::str;
use std::sync::Arc;

use arrow::{
array::{Array, GenericStringBuilder, Int32Builder, ListBuilder, StructBuilder},
datatypes::{DataType, Field, Fields},
};
use datafusion::{
common::cast::as_string_array,
error::Result,
logical_expr::{ColumnarValue, ScalarUDFImpl, Volatility},
scalar::ScalarValue,
};

/// Scalar UDF registered under the name `locate_regex`.
///
/// Given a sequence and a regular-expression pattern, it reports every
/// non-overlapping match of the pattern in the sequence as a list of
/// `{start, end, match}` structs.
#[derive(Debug)]
pub(crate) struct LocateRegex {
    /// Argument signature of the function, built in the `Default` impl.
    signature: datafusion::logical_expr::Signature,
}

impl Default for LocateRegex {
    /// Creates the UDF with an exact two-argument `(Utf8, Utf8)` signature and
    /// `Immutable` volatility.
    fn default() -> Self {
        // Argument order: the sequence to search, then the regex pattern.
        let arg_types = vec![DataType::Utf8, DataType::Utf8];

        Self {
            signature: datafusion::logical_expr::Signature::exact(
                arg_types,
                Volatility::Immutable,
            ),
        }
    }
}

// Converts an `Option` into a `Result`, turning `None` into a
// `DataFusionError::Execution` that carries `$msg`.
//
// `ok_or_else` is used so the error (and its `String` allocation) is only
// built when the option is actually `None`; this macro is invoked once per
// field for every regex match, so the success path should stay allocation-free.
macro_rules! ok_or {
    ($e:expr, $msg:expr) => {
        $e.ok_or_else(|| datafusion::error::DataFusionError::Execution($msg.to_string()))
    };
}

impl ScalarUDFImpl for LocateRegex {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &str {
"locate_regex"
}

fn signature(&self) -> &datafusion::logical_expr::Signature {
&self.signature
}

fn invoke(
&self,
args: &[datafusion::logical_expr::ColumnarValue],
) -> Result<datafusion::logical_expr::ColumnarValue> {
if args.len() != 2 {
return Err(datafusion::error::DataFusionError::Execution(format!(
"{} takes two arguments",
self.name()
)));
}

let sequence = &args[0];
let pattern = &args[1];

match (sequence, pattern) {
(
ColumnarValue::Array(sequence_arr),
ColumnarValue::Scalar(ScalarValue::Utf8(Some(pattern_scalar))),
) => {
let regex_pattern = regex::Regex::new(pattern_scalar).map_err(|e| {
datafusion::error::DataFusionError::Execution(format!(
"Error creating regex pattern: {}",
e
))
})?;

let sequence_arr = as_string_array(sequence_arr)?;

let struct_builder = StructBuilder::from_fields(
vec![
Field::new("start", DataType::Int32, true),
Field::new("end", DataType::Int32, true),
Field::new("match", DataType::Utf8, true),
],
1,
);

let mut list_builder =
ListBuilder::with_capacity(struct_builder, sequence_arr.len());

for sequence in sequence_arr.iter() {
if let Some(sequence) = sequence {
let struct_builder = list_builder.values();

for m in regex_pattern.find_iter(sequence) {
let start_builder = ok_or!(
struct_builder.field_builder::<Int32Builder>(0),
"Error creating start builder"
)?;
start_builder.append_value((m.start() as i32) + 1);

let end_builder = ok_or!(
struct_builder.field_builder::<Int32Builder>(1),
"Error creating end builder"
)?;
end_builder.append_value((m.end() as i32) + 1);

let match_builder = ok_or!(
struct_builder.field_builder::<GenericStringBuilder<i32>>(2),
"Error creating match builder"
)?;
match_builder.append_value(&sequence[m.start()..m.end()]);

struct_builder.append(true);
}

list_builder.append(true);
} else {
list_builder.append_null();
continue;
}
}

let list_array = list_builder.finish();

Ok(ColumnarValue::Array(Arc::new(list_array)))
}
(
ColumnarValue::Scalar(ScalarValue::Utf8(Some(scalar_arr))),
ColumnarValue::Scalar(ScalarValue::Utf8(Some(pattern_scalar))),
) => {
let regex_pattern = regex::Regex::new(pattern_scalar).unwrap();

let struct_builder = StructBuilder::from_fields(
vec![
Field::new("start", DataType::Int32, true),
Field::new("end", DataType::Int32, true),
Field::new("match", DataType::Utf8, true),
],
1,
);
let mut list_builder = ListBuilder::new(struct_builder);

let struct_builder = list_builder.values();

regex_pattern.find_iter(scalar_arr).for_each(|m| {
let start_builder = struct_builder.field_builder::<Int32Builder>(0).unwrap();
start_builder.append_value((m.start() as i32) + 1);

let end_builder = struct_builder.field_builder::<Int32Builder>(1).unwrap();
end_builder.append_value((m.end() as i32) + 1);

let match_builder = struct_builder
.field_builder::<GenericStringBuilder<i32>>(2)
.unwrap();
match_builder.append_value(&scalar_arr[m.start()..m.end()]);

struct_builder.append(true);
});

list_builder.append(true);

let list_array = list_builder.finish();

Ok(ColumnarValue::Array(Arc::new(list_array)))
}
_ => Err(datafusion::error::DataFusionError::Execution(format!(
"{} takes different types of arguments, the a sequence and a pattern",
self.name()
))),
}
}

fn return_type(
&self,
arg_types: &[arrow::datatypes::DataType],
) -> Result<arrow::datatypes::DataType> {
if arg_types.len() != 2 {
return Err(datafusion::error::DataFusionError::Execution(format!(
"{} takes two arguments",
self.name()
)));
}

// List of structs with start and end fields
let fields = Fields::from(vec![
Field::new("start", DataType::Int32, true),
Field::new("end", DataType::Int32, true),
Field::new("match", DataType::Utf8, true),
]);
let struct_type = DataType::Struct(fields);

let dt = DataType::List(Arc::new(Field::new("item", struct_type, true)));

Ok(dt)
}
}
5 changes: 5 additions & 0 deletions exon/exon-core/src/udfs/sequence/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

mod alignment_score;
mod gc_content;
mod locate_regex;
mod quality_score_list_to_string;
mod quality_score_string_to_list;
mod reverse_complement;
Expand Down Expand Up @@ -54,4 +55,8 @@ pub fn register_udfs(ctx: &SessionContext) {
let reverse_complement = ReverseComplement::default();
let reverse_complement_scalar = ScalarUDF::from(reverse_complement);
ctx.register_udf(reverse_complement_scalar);

let locate_regex = locate_regex::LocateRegex::default();
let locate_regex_udf = ScalarUDF::from(locate_regex);
ctx.register_udf(locate_regex_udf);
}
19 changes: 19 additions & 0 deletions exon/exon-core/tests/sqllogictests/slt/sequence-udfs.slt
Original file line number Diff line number Diff line change
Expand Up @@ -76,5 +76,24 @@ SELECT alignment_score(s1, s2) score FROM dna_sequences
statement error
SELECT alignment_score('A', 'T', 'hi')

query I
SELECT locate_regex('agctggagctacc', 'agc')
----
[{start: 1, end: 4, match: agc}, {start: 7, end: 10, match: agc}]

query I
SELECT locate_regex('agctggagctacc', 'a[atcg]c')
----
[{start: 1, end: 4, match: agc}, {start: 7, end: 10, match: agc}, {start: 11, end: 14, match: acc}]

query I
SELECT locate_regex(s1, 'A[TCG]C') FROM dna_sequences
----
[{start: 1, end: 4, match: ATC}]
[]

statement error Execution error: locate_regex takes different types of arguments, the a sequence and a pattern
SELECT locate_regex('A[TCG]C', s1) FROM dna_sequences

statement ok
DROP TABLE dna_sequences;

0 comments on commit 2f526d6

Please sign in to comment.