Skip to content

Commit

Permalink
Clean up returning table to Python
Browse files · Browse the repository at this point in the history
  • Loading branch information
kylebarron committed Dec 6, 2024
1 parent e35e12c commit 1086c43
Show file tree
Hide file tree
Showing 10 changed files with 72 additions and 60 deletions.
7 changes: 3 additions & 4 deletions python/geoarrow-io/src/io/csv.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::error::PyGeoArrowResult;
use crate::io::input::sync::{FileReader, FileWriter};
use crate::util::table_to_pytable;
use crate::util::Arro3Table;
use geoarrow::io::csv::read_csv as _read_csv;
use geoarrow::io::csv::write_csv as _write_csv;
use geoarrow::io::csv::CSVReaderOptions;
Expand All @@ -10,14 +10,13 @@ use pyo3_arrow::input::AnyRecordBatch;
#[pyfunction]
#[pyo3(signature = (file, geometry_column_name, *, batch_size=65536))]
pub fn read_csv(
py: Python,
mut file: FileReader,
geometry_column_name: &str,
batch_size: usize,
) -> PyGeoArrowResult<PyObject> {
) -> PyGeoArrowResult<Arro3Table> {
let options = CSVReaderOptions::new(Default::default(), batch_size);
let table = _read_csv(&mut file, geometry_column_name, options)?;
Ok(table_to_pytable(table).to_arro3(py)?)
Ok(Arro3Table::from_geoarrow(table))
}

#[pyfunction]
Expand Down
22 changes: 10 additions & 12 deletions python/geoarrow-io/src/io/flatgeobuf/async.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::error::{PyGeoArrowError, PyGeoArrowResult};
use crate::error::PyGeoArrowError;
use crate::io::input::construct_async_reader;
use crate::util::table_to_pytable;
use crate::util::Arro3Table;
use geoarrow::io::flatgeobuf::read_flatgeobuf_async as _read_flatgeobuf_async;
use geoarrow::io::flatgeobuf::FlatGeobufReaderOptions;
use pyo3::prelude::*;
Expand All @@ -9,16 +9,16 @@ use pyo3_geoarrow::PyCoordType;

#[pyfunction]
#[pyo3(signature = (path, *, store=None, batch_size=65536, bbox=None, coord_type=None))]
pub fn read_flatgeobuf_async(
py: Python,
path: Bound<PyAny>,
store: Option<Bound<PyAny>>,
pub fn read_flatgeobuf_async<'py>(
py: Python<'py>,
path: Bound<'py, PyAny>,
store: Option<Bound<'py, PyAny>>,
batch_size: usize,
bbox: Option<(f64, f64, f64, f64)>,
coord_type: Option<PyCoordType>,
) -> PyGeoArrowResult<PyObject> {
) -> PyResult<Bound<'py, PyAny>> {
let reader = construct_async_reader(path, store)?;
let fut = future_into_py(py, async move {
future_into_py(py, async move {
let options = FlatGeobufReaderOptions {
batch_size: Some(batch_size),
bbox,
Expand All @@ -27,8 +27,6 @@ pub fn read_flatgeobuf_async(
let table = _read_flatgeobuf_async(reader.store, reader.path, options)
.await
.map_err(PyGeoArrowError::GeoArrowError)?;

Ok(table_to_pytable(table))
})?;
Ok(fut.into())
Ok(Arro3Table::from_geoarrow(table))
})
}
8 changes: 4 additions & 4 deletions python/geoarrow-io/src/io/flatgeobuf/sync.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::error::{PyGeoArrowError, PyGeoArrowResult};
use crate::io::input::sync::FileWriter;
use crate::io::input::{construct_reader, AnyFileReader};
use crate::util::table_to_pytable;
use crate::util::Arro3Table;
use geoarrow::io::flatgeobuf::{
read_flatgeobuf as _read_flatgeobuf, write_flatgeobuf_with_options as _write_flatgeobuf,
FlatGeobufReaderOptions, FlatGeobufWriterOptions,
Expand All @@ -18,7 +18,7 @@ pub fn read_flatgeobuf(
store: Option<Bound<PyAny>>,
batch_size: usize,
bbox: Option<(f64, f64, f64, f64)>,
) -> PyGeoArrowResult<PyObject> {
) -> PyGeoArrowResult<Arro3Table> {
let reader = construct_reader(file, store)?;
match reader {
#[cfg(feature = "async")]
Expand All @@ -39,7 +39,7 @@ pub fn read_flatgeobuf(
.await
.map_err(PyGeoArrowError::GeoArrowError)?;

Ok(table_to_pytable(table).to_arro3(py)?)
Ok(Arro3Table::from_geoarrow(table))
})
}
AnyFileReader::Sync(mut sync_reader) => {
Expand All @@ -49,7 +49,7 @@ pub fn read_flatgeobuf(
..Default::default()
};
let table = _read_flatgeobuf(&mut sync_reader, options)?;
Ok(table_to_pytable(table).to_arro3(py)?)
Ok(Arro3Table::from_geoarrow(table))
}
}
}
Expand Down
10 changes: 3 additions & 7 deletions python/geoarrow-io/src/io/geojson.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
use crate::error::PyGeoArrowResult;
use crate::io::input::sync::{FileReader, FileWriter};
use crate::util::table_to_pytable;
use crate::util::Arro3Table;
use geoarrow::io::geojson::read_geojson as _read_geojson;
use geoarrow::io::geojson::write_geojson as _write_geojson;
use pyo3::prelude::*;
use pyo3_arrow::PyRecordBatchReader;

#[pyfunction]
#[pyo3(signature = (file, *, batch_size=65536))]
pub fn read_geojson(
py: Python,
mut file: FileReader,
batch_size: usize,
) -> PyGeoArrowResult<PyObject> {
pub fn read_geojson(mut file: FileReader, batch_size: usize) -> PyGeoArrowResult<Arro3Table> {
let table = _read_geojson(&mut file, Some(batch_size))?;
Ok(table_to_pytable(table).to_arro3(py)?)
Ok(Arro3Table::from_geoarrow(table))
}

#[pyfunction]
Expand Down
10 changes: 3 additions & 7 deletions python/geoarrow-io/src/io/geojson_lines.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
use crate::error::PyGeoArrowResult;
use crate::io::input::sync::{FileReader, FileWriter};
use crate::util::table_to_pytable;
use crate::util::Arro3Table;
use geoarrow::io::geojson_lines::read_geojson_lines as _read_geojson_lines;
use geoarrow::io::geojson_lines::write_geojson_lines as _write_geojson_lines;
use pyo3::prelude::*;
use pyo3_arrow::input::AnyRecordBatch;

#[pyfunction]
#[pyo3(signature = (file, *, batch_size=65536))]
pub fn read_geojson_lines(
py: Python,
mut file: FileReader,
batch_size: usize,
) -> PyGeoArrowResult<PyObject> {
pub fn read_geojson_lines(mut file: FileReader, batch_size: usize) -> PyGeoArrowResult<Arro3Table> {
let table = _read_geojson_lines(&mut file, Some(batch_size))?;
Ok(table_to_pytable(table).to_arro3(py)?)
Ok(Arro3Table::from_geoarrow(table))
}

#[pyfunction]
Expand Down
16 changes: 8 additions & 8 deletions python/geoarrow-io/src/io/parquet/async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::io::input::{construct_reader, AnyFileReader};
use crate::io::parquet::options::create_options;
#[cfg(feature = "async")]
use crate::runtime::get_runtime;
use crate::util::table_to_pytable;
use crate::util::Arro3Table;

use geo_traits::CoordTrait;
use geoarrow::error::GeoArrowError;
Expand Down Expand Up @@ -66,7 +66,7 @@ pub fn read_parquet_async(
.await
.map_err(PyGeoArrowError::GeoArrowError)?;

Ok(table_to_pytable(table))
Ok(Arro3Table::from_geoarrow(table))
})?;
Ok(fut.into())
}
Expand Down Expand Up @@ -202,7 +202,7 @@ impl ParquetFile {
.read_table()
.await
.map_err(PyGeoArrowError::GeoArrowError)?;
Ok(table_to_pytable(table))
Ok(Arro3Table::from_geoarrow(table))
})?;
Ok(fut.into())
}
Expand All @@ -216,7 +216,7 @@ impl ParquetFile {
offset: Option<usize>,
bbox: Option<[f64; 4]>,
bbox_paths: Option<Bound<'_, PyAny>>,
) -> PyGeoArrowResult<PyObject> {
) -> PyGeoArrowResult<Arro3Table> {
let runtime = get_runtime(py)?;
let reader = ParquetObjectReader::new(self.store.clone(), self.object_meta.clone());
let options = create_options(batch_size, limit, offset, bbox, bbox_paths)?;
Expand All @@ -231,7 +231,7 @@ impl ParquetFile {
.read_table()
.await
.map_err(PyGeoArrowError::GeoArrowError)?;
Ok(table_to_pytable(table).to_arro3(py)?)
Ok(Arro3Table::from_geoarrow(table))
})
}
}
Expand Down Expand Up @@ -425,7 +425,7 @@ impl ParquetDataset {
});
let table = Table::try_new(all_batches, output_schema)
.map_err(PyGeoArrowError::GeoArrowError)?;
Ok(table_to_pytable(table))
Ok(Arro3Table::from_geoarrow(table))
})?;
Ok(fut.into())
}
Expand All @@ -439,7 +439,7 @@ impl ParquetDataset {
offset: Option<usize>,
bbox: Option<[f64; 4]>,
bbox_paths: Option<Bound<'_, PyAny>>,
) -> PyGeoArrowResult<PyObject> {
) -> PyGeoArrowResult<Arro3Table> {
let runtime = get_runtime(py)?;
let options = create_options(batch_size, limit, offset, bbox, bbox_paths)?;
let readers = self.to_readers(options)?;
Expand All @@ -460,7 +460,7 @@ impl ParquetDataset {
});
let table = Table::try_new(all_batches, output_schema)
.map_err(PyGeoArrowError::GeoArrowError)?;
Ok(table_to_pytable(table).to_arro3(py)?)
Ok(Arro3Table::from_geoarrow(table))
})
}
}
8 changes: 4 additions & 4 deletions python/geoarrow-io/src/io/parquet/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::sync::Mutex;
use crate::error::{PyGeoArrowError, PyGeoArrowResult};
use crate::io::input::sync::{FileReader, FileWriter};
use crate::io::input::{construct_reader, AnyFileReader};
use crate::util::table_to_pytable;
use crate::util::Arro3Table;

use geoarrow::io::parquet::{GeoParquetReaderOptions, GeoParquetRecordBatchReaderBuilder};
use parquet::arrow::arrow_reader::ArrowReaderOptions;
Expand All @@ -27,7 +27,7 @@ pub fn read_parquet(
path: Bound<PyAny>,
store: Option<Bound<PyAny>>,
batch_size: Option<usize>,
) -> PyGeoArrowResult<PyObject> {
) -> PyGeoArrowResult<Arro3Table> {
let reader = construct_reader(path, store)?;
match reader {
#[cfg(feature = "async")]
Expand Down Expand Up @@ -63,7 +63,7 @@ pub fn read_parquet(
.read_table()
.await?;

Ok::<_, PyGeoArrowError>(table_to_pytable(table).to_arro3(py)?)
Ok::<_, PyGeoArrowError>(Arro3Table::from_geoarrow(table))
})?;
Ok(table)
}
Expand All @@ -85,7 +85,7 @@ pub fn read_parquet(
)?
.build()?
.read_table()?;
Ok(table_to_pytable(table).to_arro3(py)?)
Ok(Arro3Table::from_geoarrow(table))
}
_ => Err(PyValueError::new_err("File objects not supported in Parquet reader.").into()),
},
Expand Down
12 changes: 5 additions & 7 deletions python/geoarrow-io/src/io/postgis.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
use crate::error::PyGeoArrowError;
use crate::util::table_to_pytable;
use crate::util::Arro3Table;
use geoarrow::error::GeoArrowError;
use geoarrow::io::postgis::read_postgis as _read_postgis;
use pyo3::prelude::*;
use pyo3_arrow::PyTable;
use pyo3_async_runtimes::tokio::future_into_py;
use sqlx::postgres::PgPoolOptions;

#[pyfunction]
pub fn read_postgis(py: Python, connection_url: String, sql: String) -> PyResult<Option<PyObject>> {
pub fn read_postgis(connection_url: String, sql: String) -> PyResult<Option<Arro3Table>> {
// https://tokio.rs/tokio/topics/bridging#what-tokiomain-expands-to
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();

// TODO: py.allow_threads
let out = runtime.block_on(read_postgis_inner(connection_url, sql))?;
out.map(|table| table.to_arro3(py)).transpose()
runtime.block_on(read_postgis_inner(connection_url, sql))
}

#[pyfunction]
Expand All @@ -29,7 +27,7 @@ pub fn read_postgis_async(
future_into_py(py, read_postgis_inner(connection_url, sql))
}

async fn read_postgis_inner(connection_url: String, sql: String) -> PyResult<Option<PyTable>> {
async fn read_postgis_inner(connection_url: String, sql: String) -> PyResult<Option<Arro3Table>> {
let pool = PgPoolOptions::new()
.connect(&connection_url)
.await
Expand All @@ -39,5 +37,5 @@ async fn read_postgis_inner(connection_url: String, sql: String) -> PyResult<Opt
.await
.map_err(PyGeoArrowError::GeoArrowError)?;

Ok(table.map(table_to_pytable))
Ok(table.map(Arro3Table::from_geoarrow))
}
7 changes: 3 additions & 4 deletions python/geoarrow-io/src/io/shapefile.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
use crate::error::PyGeoArrowResult;
use crate::io::input::sync::FileReader;
use crate::util::table_to_pytable;
use crate::util::Arro3Table;
use geoarrow::io::shapefile::read_shapefile as _read_shapefile;
use pyo3::prelude::*;

#[pyfunction]
// #[pyo3(signature = (file, *, batch_size=65536))]
pub fn read_shapefile(
py: Python,
mut shp_file: FileReader,
mut dbf_file: FileReader,
) -> PyGeoArrowResult<PyObject> {
) -> PyGeoArrowResult<Arro3Table> {
let table = _read_shapefile(&mut shp_file, &mut dbf_file)?;
Ok(table_to_pytable(table).to_arro3(py)?)
Ok(Arro3Table::from_geoarrow(table))
}
32 changes: 29 additions & 3 deletions python/geoarrow-io/src/util.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,32 @@
use pyo3::prelude::*;
use pyo3_arrow::PyTable;

pub(crate) fn table_to_pytable(table: geoarrow::table::Table) -> PyTable {
let (batches, schema) = table.into_inner();
PyTable::try_new(batches, schema).unwrap()
/// A wrapper around a [PyTable] that implements [IntoPyObject] to convert to a runtime-available
/// arro3.core.Table
///
/// This ensures that we return with the user's runtime-provided arro3.core.Table and not the one
/// we linked from Rust.
pub struct Arro3Table(PyTable);

impl Arro3Table {
pub fn from_geoarrow(table: geoarrow::table::Table) -> Self {
let (batches, schema) = table.into_inner();
PyTable::try_new(batches, schema).unwrap().into()
}
}

impl From<PyTable> for Arro3Table {
fn from(value: PyTable) -> Self {
Self(value)
}
}

impl<'py> IntoPyObject<'py> for Arro3Table {
type Target = PyAny;
type Output = Bound<'py, PyAny>;
type Error = PyErr;

fn into_pyobject(self, py: pyo3::Python<'py>) -> Result<Self::Output, Self::Error> {
Ok(self.0.to_arro3(py)?.bind(py).clone())
}
}

0 comments on commit 1086c43

Please sign in to comment.