Skip to content

Commit

Permalink
Add ST_MakeBox2D, ST_Expand, fix RectArray round trip (#946)
Browse files Browse the repository at this point in the history
### Change list

- Add `ST_MakeBox2D`, `ST_Expand`.
- Add test for each.
- Fix round-tripping `RectArray` to an `ArrayRef`
- Add test of round-tripping `RectArray` to an `ArrayRef`
  • Loading branch information
kylebarron authored Dec 13, 2024
1 parent b999078 commit f51443a
Show file tree
Hide file tree
Showing 9 changed files with 381 additions and 31 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

73 changes: 49 additions & 24 deletions rust/geoarrow/src/array/rect/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use std::sync::Arc;

use arrow::array::AsArray;
use arrow::datatypes::Float64Type;
use arrow_array::{Array, ArrayRef, Float64Array, StructArray};
use arrow_buffer::NullBuffer;
use arrow_array::{Array, ArrayRef, StructArray};
use arrow_buffer::{NullBuffer, ScalarBuffer};
use arrow_schema::{DataType, Field};

use crate::array::metadata::ArrayMetadata;
Expand Down Expand Up @@ -182,14 +182,12 @@ impl IntoArrow for RectArray {
fn into_arrow(self) -> Self::ArrowArray {
let fields = rect_fields(self.data_type.dimension().unwrap());
let mut arrays: Vec<ArrayRef> = vec![];
for buf in self.lower.buffers {
arrays.push(Arc::new(Float64Array::new(buf, None)));
}
for buf in self.upper.buffers {
arrays.push(Arc::new(Float64Array::new(buf, None)));
}
let validity = self.validity;

// values_array takes care of the correct number of dimensions
arrays.extend_from_slice(self.lower.values_array().as_slice());
arrays.extend_from_slice(self.upper.values_array().as_slice());

let validity = self.validity;
StructArray::new(fields, arrays, validity)
}
}
Expand All @@ -202,23 +200,24 @@ impl TryFrom<(&StructArray, Dimension)> for RectArray {
let columns = value.columns();
assert_eq!(columns.len(), dim.size() * 2);

let lower = match dim {
Dimension::XY => {
core::array::from_fn(|i| columns[i].as_primitive::<Float64Type>().values().clone())
}
Dimension::XYZ => {
core::array::from_fn(|i| columns[i].as_primitive::<Float64Type>().values().clone())
let dim_size = dim.size();
let lower = core::array::from_fn(|i| {
if i < dim_size {
columns[i].as_primitive::<Float64Type>().values().clone()
} else {
ScalarBuffer::from(vec![])
}
};
let upper = match dim {
Dimension::XY => {
core::array::from_fn(|i| columns[i].as_primitive::<Float64Type>().values().clone())
});
let upper = core::array::from_fn(|i| {
if i < dim_size {
columns[dim_size + i]
.as_primitive::<Float64Type>()
.values()
.clone()
} else {
ScalarBuffer::from(vec![])
}
Dimension::XYZ => {
core::array::from_fn(|i| columns[i].as_primitive::<Float64Type>().values().clone())
}
};

});
Ok(Self::new(
SeparatedCoordBuffer::new(lower, dim),
SeparatedCoordBuffer::new(upper, dim),
Expand Down Expand Up @@ -271,3 +270,29 @@ impl<G: RectTrait<T = f64>> From<(Vec<Option<G>>, Dimension)> for RectArray {
mut_arr.into()
}
}

#[cfg(test)]
mod test {
use super::*;
use crate::algorithm::native::eq::rect_eq;
use crate::array::RectBuilder;
use crate::datatypes::Dimension;

#[test]
fn rect_array_round_trip() {
let rect = geo::Rect::new(
geo::coord! { x: 0.0, y: 5.0 },
geo::coord! { x: 10.0, y: 15.0 },
);
let mut builder =
RectBuilder::with_capacity_and_options(Dimension::XY, 1, Default::default());
builder.push_rect(Some(&rect));
builder.push_min_max(&rect.min(), &rect.max());
let rect_arr = builder.finish();

let arrow_arr = rect_arr.into_array_ref();
let rect_arr_again = RectArray::try_from((arrow_arr.as_ref(), Dimension::XY)).unwrap();
let rect_again = rect_arr_again.value(0);
assert!(rect_eq(&rect, &rect_again));
}
}
10 changes: 9 additions & 1 deletion rust/geoarrow/src/array/rect/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::scalar::Rect;
use crate::trait_::IntoArrow;
use arrow_array::{Array, StructArray};
use arrow_buffer::NullBufferBuilder;
use geo_traits::RectTrait;
use geo_traits::{CoordTrait, RectTrait};
use std::sync::Arc;

/// The GeoArrow equivalent to `Vec<Option<Rect>>`: a mutable collection of Rects.
Expand Down Expand Up @@ -168,6 +168,14 @@ impl RectBuilder {
}
}

/// Push min and max coordinates of a rect to the builder.
#[inline]
pub fn push_min_max(&mut self, min: &impl CoordTrait<T = f64>, max: &impl CoordTrait<T = f64>) {
self.lower.push_coord(min);
self.upper.push_coord(max);
self.validity.append_non_null()
}

/// Create this builder from a iterator of Rects.
pub fn from_rects<'a>(
geoms: impl ExactSizeIterator<Item = &'a (impl RectTrait<T = f64> + 'a)>,
Expand Down
1 change: 1 addition & 0 deletions rust/geodatafusion/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,5 @@ geoarrow = { path = "../geoarrow", features = ["flatgeobuf"] }
thiserror = "1"

[dev-dependencies]
approx = "0.5.1"
tokio = { version = "1.9", features = ["macros", "fs", "rt-multi-thread"] }
4 changes: 2 additions & 2 deletions rust/geodatafusion/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -337,10 +337,10 @@ Spatial extensions for [Apache DataFusion](https://datafusion.apache.org/), an e
| Box2D || Returns a BOX2D representing the 2D extent of a geometry. |
| Box3D | | Returns a BOX3D representing the 3D extent of a geometry. |
| ST_EstimatedExtent | | Returns the estimated extent of a spatial table. |
| ST_Expand | | Returns a bounding box expanded from another bounding box or a geometry. |
| ST_Expand | | Returns a bounding box expanded from another bounding box or a geometry. |
| ST_Extent | | Aggregate function that returns the bounding box of geometries. |
| ST_3DExtent | | Aggregate function that returns the 3D bounding box of geometries. |
| ST_MakeBox2D | | Creates a BOX2D defined by two 2D point geometries. |
| ST_MakeBox2D | | Creates a BOX2D defined by two 2D point geometries. |
| ST_3DMakeBox | | Creates a BOX3D defined by two 3D point geometries. |
| ST_XMax || Returns the X maxima of a 2D or 3D bounding box or a geometry. |
| ST_XMin || Returns the X minima of a 2D or 3D bounding box or a geometry. |
Expand Down
179 changes: 179 additions & 0 deletions rust/geodatafusion/src/udf/native/bounding_box/expand.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
use std::any::Any;
use std::sync::OnceLock;

use arrow::array::AsArray;
use arrow::datatypes::Float64Type;
use arrow_schema::DataType;
use datafusion::logical_expr::scalar_doc_sections::DOC_SECTION_OTHER;
use datafusion::logical_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility,
};
use geo_traits::{CoordTrait, RectTrait};
use geoarrow::array::{RectArray, RectBuilder};
use geoarrow::datatypes::Dimension;
use geoarrow::error::GeoArrowError;
use geoarrow::trait_::ArrayAccessor;
use geoarrow::ArrayBase;

use crate::data_types::BOX2D_TYPE;
use crate::error::GeoDataFusionResult;

#[derive(Debug)]
pub(super) struct Expand {
signature: Signature,
}

impl Expand {
pub fn new() -> Self {
Self {
signature: Signature::one_of(
vec![
TypeSignature::Exact(vec![BOX2D_TYPE.into(), DataType::Float64]),
TypeSignature::Exact(vec![
BOX2D_TYPE.into(),
DataType::Float64,
DataType::Float64,
]),
],
Volatility::Immutable,
),
}
}
}

static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();

impl ScalarUDFImpl for Expand {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"st_expand"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> datafusion::error::Result<DataType> {
Ok(arg_types.first().unwrap().clone())
}

fn invoke(&self, args: &[ColumnarValue]) -> datafusion::error::Result<ColumnarValue> {
Ok(expand_impl(args)?)
}

fn documentation(&self) -> Option<&Documentation> {
Some(DOCUMENTATION.get_or_init(|| {
Documentation::builder(
DOC_SECTION_OTHER,
"Returns a bounding box expanded from the bounding box of the input, either by specifying a single distance with which the box should be expanded on both axes, or by specifying an expansion distance for each axis. Uses double-precision. Can be used for distance queries, or to add a bounding box filter to a query to take advantage of a spatial index.",
"ST_Expand(box)",
)
.with_argument("box", "box2d")
.build()
}))
}
}

fn expand_impl(args: &[ColumnarValue]) -> GeoDataFusionResult<ColumnarValue> {
let mut args = ColumnarValue::values_to_arrays(args)?.into_iter();
let rect_array = args.next().unwrap();
let factor1 = args.next().unwrap();
let factor2 = args.next();

let dx = factor1.as_primitive::<Float64Type>();

if BOX2D_TYPE
.to_data_type()
.equals_datatype(rect_array.data_type())
{
let rect_array = RectArray::try_from((rect_array.as_ref(), Dimension::XY))?;
let mut builder = RectBuilder::with_capacity_and_options(
Dimension::XY,
rect_array.len(),
rect_array.metadata().clone(),
);

if let Some(dy) = factor2 {
let dy = dy.as_primitive::<Float64Type>();

for val in rect_array.iter().zip(dx.iter()).zip(dy.iter()) {
if let ((Some(rect), Some(dx)), Some(dy)) = val {
builder.push_rect(Some(&expand_2d_rect(rect, dx, dy)));
} else {
builder.push_null();
}
}
} else {
for val in rect_array.iter().zip(dx.iter()) {
if let (Some(rect), Some(dx)) = val {
builder.push_rect(Some(&expand_2d_rect(rect, dx, dx)));
} else {
builder.push_null();
}
}
}

return Ok(builder.finish().into_array_ref().into());
}

Err(Err(GeoArrowError::General(format!(
"Unexpected data type: {:?}",
rect_array.data_type()
)))?)
}

#[inline]
fn expand_2d_rect(rect: impl RectTrait<T = f64>, dx: f64, dy: f64) -> geo::Rect<f64> {
let min = rect.min();
let max = rect.max();

let new_min = geo::coord! { x: min.x() - dx, y: min.y() - dy };
let new_max = geo::coord! { x: max.x() + dx, y: max.y() + dy };

geo::Rect::new(new_min, new_max)
}

#[cfg(test)]
mod test {
use approx::relative_eq;
use datafusion::prelude::*;
use geo_traits::{CoordTrait, RectTrait};
use geoarrow::array::RectArray;
use geoarrow::datatypes::Dimension;
use geoarrow::trait_::ArrayAccessor;

use crate::data_types::BOX2D_TYPE;
use crate::udf::native::register_native;

#[tokio::test]
async fn test() {
let ctx = SessionContext::new();
register_native(&ctx);

let out = ctx
.sql("SELECT ST_Expand(ST_MakeBox2D(ST_Point(0, 5), ST_Point(10, 20)), 10, 20);")
.await
.unwrap();

let batches = out.collect().await.unwrap();
assert_eq!(batches.len(), 1);
let batch = batches.into_iter().next().unwrap();
assert_eq!(batch.columns().len(), 1);
assert!(batch
.schema()
.field(0)
.data_type()
.equals_datatype(&BOX2D_TYPE.into()));

let rect_array = RectArray::try_from((batch.columns()[0].as_ref(), Dimension::XY)).unwrap();
let rect = rect_array.value(0);

assert!(relative_eq!(rect.min().x(), -10.0));
assert!(relative_eq!(rect.min().y(), -15.0));
assert!(relative_eq!(rect.max().x(), 20.0));
assert!(relative_eq!(rect.max().y(), 40.0));
}
}
Loading

0 comments on commit f51443a

Please sign in to comment.