diff --git a/python/cdshealpix/skymap/skymap.py b/python/cdshealpix/skymap/skymap.py index f79b6f4..0710515 100644 --- a/python/cdshealpix/skymap/skymap.py +++ b/python/cdshealpix/skymap/skymap.py @@ -1,8 +1,9 @@ """Manipulation of skymaps. -SkyMaps are described in _ -This sub-module supports skymaps in the nested scheme, and in the implicit format where the first pixels. -The coordsystem should be 'CEL'. +SkyMaps are described in _ +This sub-module supports skymaps in the nested scheme, and in the implicit format. +The coordinates system should be 'CEL'. """ from .. import cdshealpix @@ -28,7 +29,7 @@ def from_fits(cls, path: Union[str, Path]): Parameters ---------- - path : Union[str | Path] + path : str, `pathlib.Path` The file's path. Returns @@ -37,3 +38,13 @@ def from_fits(cls, path: Union[str, Path]): The map in a numpy array. Its dtype is inferred from the fits header. """ return cls(cdshealpix.read_skymap(str(path))) + + def to_fits(self, path): + """Write a Skymap in a fits file. + + Parameters + ---------- + path : str, pathlib.Path + The file's path. + """ + cdshealpix.write_skymap(self.values, str(path)) diff --git a/python/cdshealpix/tests/test_skymaps.py b/python/cdshealpix/tests/test_skymaps.py index f7b447d..9d8114f 100644 --- a/python/cdshealpix/tests/test_skymaps.py +++ b/python/cdshealpix/tests/test_skymaps.py @@ -1,13 +1,22 @@ from pathlib import Path +from tempfile import NamedTemporaryFile import numpy as np from ..skymap import Skymap +path_to_test_skymap = Path(__file__).parent.resolve() / "resources" / "skymap.fits" + def test_read(): - values = Skymap.from_fits( - Path(__file__).parent.resolve() / "resources" / "skymap.fits" - ).values + values = Skymap.from_fits(path_to_test_skymap).values assert values.dtype == np.int32 assert len(values) == 49152 + + +def test_read_write_read_conservation(): + skymap = Skymap.from_fits(path_to_test_skymap) + with NamedTemporaryFile() as fp: + skymap.to_fits(fp.name) + skymap2 = Skymap.from_fits(fp.name) + assert all(skymap.values == skymap2.values) diff --git a/src/lib.rs b/src/lib.rs index 2299d46..5c42042 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -32,6 +32,8 @@ fn cdshealpix(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { // add skymap pyfunctions here m.add_function(wrap_pyfunction!(skymap_functions::read_skymap, m)?) .unwrap(); + m.add_function(wrap_pyfunction!(skymap_functions::write_skymap, m)?) + .unwrap(); // wrapper of to_ring and from_ring #[pyfn(m)] diff --git a/src/skymap_functions.rs b/src/skymap_functions.rs index db72cb2..60d0307 100644 --- a/src/skymap_functions.rs +++ b/src/skymap_functions.rs @@ -1,16 +1,20 @@ extern crate healpix; +use std::fs::File; use std::i64; -use numpy::IntoPyArray; +use numpy::{IntoPyArray, PyReadonlyArray1}; use pyo3::{ exceptions::PyIOError, prelude::*, types::{PyAny, PyModule}, - Bound, PyResult, + Bound, PyErr, PyResult, }; -use healpix::nested::map::skymap::{SkyMap, SkyMapEnum}; +use healpix::nested::map::{ + fits::write::write_implicit_skymap_fits, + skymap::{SkyMap, SkyMapEnum}, +}; #[pyfunction] #[pyo3(pass_module)] @@ -59,3 +63,45 @@ pub fn read_skymap<'py>( .into_any(), }) } + +// we define an enum for the supported numpy dtypes +#[derive(FromPyObject)] +pub enum SupportedArray<'py> { + F64(PyReadonlyArray1<'py, f64>), + I64(PyReadonlyArray1<'py, i64>), + F32(PyReadonlyArray1<'py, f32>), + I32(PyReadonlyArray1<'py, i32>), + I16(PyReadonlyArray1<'py, i16>), + U8(PyReadonlyArray1<'py, u8>), +} + +#[pyfunction] +pub fn write_skymap<'py>(values: SupportedArray<'py>, path: String) -> Result<(), PyErr> { + let writer = File::create(path).map_err(|err| PyIOError::new_err(err.to_string()))?; + match values { + SupportedArray::F64(values) => values.as_slice().map_err(|e| e.into()).and_then(|slice| { + write_implicit_skymap_fits(writer, slice) + .map_err(|err| PyIOError::new_err(err.to_string()).into()) + }), + SupportedArray::I64(values) => values.as_slice().map_err(|e| e.into()).and_then(|slice| { + write_implicit_skymap_fits(writer, slice) + .map_err(|err| PyIOError::new_err(err.to_string()).into()) + }), + SupportedArray::F32(values) => values.as_slice().map_err(|e| e.into()).and_then(|slice| { + write_implicit_skymap_fits(writer, slice) + .map_err(|err| PyIOError::new_err(err.to_string()).into()) + }), + SupportedArray::I32(values) => values.as_slice().map_err(|e| e.into()).and_then(|slice| { + write_implicit_skymap_fits(writer, slice) + .map_err(|err| PyIOError::new_err(err.to_string()).into()) + }), + SupportedArray::I16(values) => values.as_slice().map_err(|e| e.into()).and_then(|slice| { + write_implicit_skymap_fits(writer, slice) + .map_err(|err| PyIOError::new_err(err.to_string()).into()) + }), + SupportedArray::U8(values) => values.as_slice().map_err(|e| e.into()).and_then(|slice| { + write_implicit_skymap_fits(writer, slice) + .map_err(|err| PyIOError::new_err(err.to_string()).into()) + }), + } +}