Skip to content

Commit

Permalink
feat: use stac-api's new python feature (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
gadomski authored Dec 6, 2024
1 parent 00f1fab commit df59d1d
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 94 deletions.
14 changes: 7 additions & 7 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ pyo3-async-runtimes = { version = "0.23.0", features = [
"tokio-runtime",
] }
pythonize = "0.23.0"
serde = "1.0.215"
serde_json = "1.0.133"
stac-api = { version = "0.6.2", git = "https://github.com/stac-utils/stac-rs" }
stac-api = { version = "0.6.2", features = [
"python",
], git = "https://github.com/stac-utils/stac-rs" }
stac = { version = "0.11.0", git = "https://github.com/stac-utils/stac-rs" }
thiserror = "2.0.4"
tokio = "1.41.1"
tokio-postgres = { version = "0.7.12", features = ["with-serde_json-1"] }
95 changes: 11 additions & 84 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![deny(unused_crate_dependencies)]

use bb8::{Pool, RunError};
use bb8_postgres::PostgresConnectionManager;
use geojson::Geometry;
use pgstac::Pgstac;
use pyo3::{
create_exception,
Expand All @@ -9,8 +10,7 @@ use pyo3::{
types::{PyDict, PyList, PyType},
};
use serde_json::Value;
use stac::Bbox;
use stac_api::{Fields, Filter, Items, Search, Sortby};
use stac_api::python::{StringOrDict, StringOrList};
use std::{future::Future, str::FromStr};
use thiserror::Error;
use tokio_postgres::{Config, NoTls};
Expand All @@ -20,18 +20,6 @@ create_exception!(pgstacrs, StacError, PyException);

type PgstacPool = Pool<PostgresConnectionManager<NoTls>>;

#[derive(FromPyObject)]
pub enum StringOrDict {
String(String),
Dict(Py<PyDict>),
}

#[derive(FromPyObject)]
pub enum StringOrList {
String(String),
List(Vec<String>),
}

#[derive(Debug, Error)]
enum Error {
#[error(transparent)]
Expand Down Expand Up @@ -290,71 +278,19 @@ impl Client {
query: Option<Bound<'a, PyDict>>,
limit: Option<u64>,
) -> PyResult<Bound<'a, PyAny>> {
// TODO refactor to use https://github.com/gadomski/stacrs/blob/1528d7e1b7185a86efe9fc7c42b0620093c5e9c6/src/search.rs#L128-L162
let mut fields = Fields::default();
if let Some(include) = include {
fields.include = include.into();
}
if let Some(exclude) = exclude {
fields.exclude = exclude.into();
}
let fields = if fields.include.is_empty() && fields.exclude.is_empty() {
None
} else {
Some(fields)
};
let query = query
.map(|query| pythonize::depythonize(&query))
.transpose()?;
let bbox = bbox
.map(|bbox| Bbox::try_from(bbox))
.transpose()
.map_err(Error::from)?;
let sortby = sortby.map(|sortby| {
Vec::<String>::from(sortby)
.into_iter()
.map(|s| s.parse::<Sortby>().unwrap()) // the parse is infallible
.collect::<Vec<_>>()
});
let filter = filter
.map(|filter| match filter {
StringOrDict::Dict(cql_json) => {
pythonize::depythonize(&cql_json.bind_borrowed(py)).map(Filter::Cql2Json)
}
StringOrDict::String(cql2_text) => Ok(Filter::Cql2Text(cql2_text)),
})
.transpose()?;
let filter = filter
.map(|filter| filter.into_cql2_json())
.transpose()
.map_err(Error::from)?;
let items = Items {
let search = stac_api::python::search(
intersects,
ids,
collections,
limit,
bbox,
datetime,
query,
fields,
include,
exclude,
sortby,
filter,
..Default::default()
};

let intersects = intersects
.map(|intersects| match intersects {
StringOrDict::Dict(json) => pythonize::depythonize(&json.bind_borrowed(py))
.map_err(Error::from)
.and_then(|json| Geometry::from_json_object(json).map_err(Error::from)),
StringOrDict::String(s) => s.parse().map_err(Error::from),
})
.transpose()?;
let ids = ids.map(|ids| ids.into());
let collections = collections.map(|ids| ids.into());
let search = Search {
items,
intersects,
ids,
collections,
};
query,
)?;
self.run(py, |pool| async move {
let connection = pool.get().await?;
let page = connection.search(search).await?;
Expand Down Expand Up @@ -406,15 +342,6 @@ impl From<Error> for PyErr {
}
}

impl From<StringOrList> for Vec<String> {
fn from(value: StringOrList) -> Vec<String> {
match value {
StringOrList::List(list) => list,
StringOrList::String(s) => vec![s],
}
}
}

#[pymodule]
fn pgstacrs(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<Client>()?;
Expand Down

0 comments on commit df59d1d

Please sign in to comment.