diff --git a/CHANGELOG.md b/CHANGELOG.md index ba0e4068f..f03048573 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,9 +5,11 @@ ## Bug Fixes ## Features +* Use the latest version of ndx-pose for `DeepLabCutInterface` [PR #1128](https://github.com/catalystneuro/neuroconv/pull/1128) ## Improvements + # v0.6.9 (Upcoming) Small fixes should be here. diff --git a/pyproject.toml b/pyproject.toml index cd92852ec..de1dd693f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,7 +118,7 @@ sleap = [ "sleap-io>=0.0.2; python_version>='3.9'", ] deeplabcut = [ - "ndx-pose==0.1.1", + "ndx-pose>=0.2", "tables; platform_system != 'Darwin'", "tables>=3.10.1; platform_system == 'Darwin' and python_version >= '3.10'", ] @@ -128,7 +128,7 @@ video = [ "opencv-python-headless>=4.8.1.78", ] lightningpose = [ - "ndx-pose==0.1.1", + "ndx-pose>=0.1.1", "neuroconv[video]", ] medpc = [ diff --git a/src/neuroconv/datainterfaces/behavior/deeplabcut/_dlc_utils.py b/src/neuroconv/datainterfaces/behavior/deeplabcut/_dlc_utils.py index 14866510d..bf087afda 100644 --- a/src/neuroconv/datainterfaces/behavior/deeplabcut/_dlc_utils.py +++ b/src/neuroconv/datainterfaces/behavior/deeplabcut/_dlc_utils.py @@ -1,4 +1,3 @@ -import importlib import pickle import warnings from pathlib import Path @@ -93,7 +92,7 @@ def _get_cv2_timestamps(file_path: Union[Path, str]): return timestamps -def _get_movie_timestamps(movie_file, VARIABILITYBOUND=1000, infer_timestamps=True): +def _get_video_timestamps(movie_file, VARIABILITYBOUND=1000, infer_timestamps=True): """ Return numpy array of the timestamps for a video. 
@@ -263,13 +262,52 @@ def _write_pes_to_nwbfile( exclude_nans, pose_estimation_container_kwargs: Optional[dict] = None, ): - - from ndx_pose import PoseEstimation, PoseEstimationSeries + """ + Updated version of _write_pes_to_nwbfile to work with ndx-pose v0.2.0+ + """ + from ndx_pose import PoseEstimation, PoseEstimationSeries, Skeleton, Skeletons + from pynwb.file import Subject pose_estimation_container_kwargs = pose_estimation_container_kwargs or dict() + pose_estimation_name = pose_estimation_container_kwargs.get("name", "PoseEstimationDeepLabCut") + + # Create a subject if it doesn't exist + if nwbfile.subject is None: + subject = Subject(subject_id=animal) + nwbfile.subject = subject + else: + subject = nwbfile.subject + + # Create skeleton from the keypoints + keypoints = df_animal.columns.get_level_values("bodyparts").unique() + animal = animal if animal else "" + subject = subject if animal == subject.subject_id else None + skeleton_name = f"Skeleton{pose_estimation_name}_{animal.capitalize()}" + skeleton = Skeleton( + name=skeleton_name, + nodes=list(keypoints), + edges=np.array(paf_graph) if paf_graph else None, # Convert paf_graph to numpy array + subject=subject, + ) + + # Create Skeletons container + if "behavior" not in nwbfile.processing: + behavior_processing_module = nwbfile.create_processing_module( + name="behavior", description="processed behavioral data" + ) + skeletons = Skeletons(skeletons=[skeleton]) + behavior_processing_module.add(skeletons) + else: + behavior_processing_module = nwbfile.processing["behavior"] + if "Skeletons" not in behavior_processing_module.data_interfaces: + skeletons = Skeletons(skeletons=[skeleton]) + behavior_processing_module.add(skeletons) + else: + skeletons = behavior_processing_module["Skeletons"] + skeletons.add_skeletons(skeleton) pose_estimation_series = [] - for keypoint in df_animal.columns.get_level_values("bodyparts").unique(): + for keypoint in keypoints: data = df_animal.xs(keypoint, 
level="bodyparts", axis=1).to_numpy() if exclude_nans: @@ -292,35 +330,31 @@ def _write_pes_to_nwbfile( ) pose_estimation_series.append(pes) - deeplabcut_version = None - is_deeplabcut_installed = importlib.util.find_spec(name="deeplabcut") is not None - if is_deeplabcut_installed: - deeplabcut_version = importlib.metadata.version(distribution_name="deeplabcut") + camera_name = pose_estimation_name + if camera_name not in nwbfile.devices: + camera = nwbfile.create_device( + name=camera_name, + description="Camera used for behavioral recording and pose estimation.", + ) + else: + camera = nwbfile.devices[camera_name] - # TODO, taken from the original implementation, improve it if the video is passed + # Create PoseEstimation container with updated arguments dimensions = [list(map(int, image_shape.split(",")))[1::2]] dimensions = np.array(dimensions, dtype="uint32") pose_estimation_default_kwargs = dict( pose_estimation_series=pose_estimation_series, description="2D keypoint coordinates estimated using DeepLabCut.", - original_videos=[video_file_path], + original_videos=[video_file_path] if video_file_path else None, dimensions=dimensions, + devices=[camera], scorer=scorer, source_software="DeepLabCut", - source_software_version=deeplabcut_version, - nodes=[pes.name for pes in pose_estimation_series], - edges=paf_graph if paf_graph else None, - **pose_estimation_container_kwargs, + skeleton=skeleton, ) pose_estimation_default_kwargs.update(pose_estimation_container_kwargs) pose_estimation_container = PoseEstimation(**pose_estimation_default_kwargs) - if "behavior" in nwbfile.processing: # TODO: replace with get_module - behavior_processing_module = nwbfile.processing["behavior"] - else: - behavior_processing_module = nwbfile.create_processing_module( - name="behavior", description="processed behavioral data" - ) behavior_processing_module.add(pose_estimation_container) return nwbfile @@ -387,7 +421,7 @@ def _add_subject_to_nwbfile( if video_file_path is None: 
timestamps = df.index.tolist() # setting timestamps to dummy else: - timestamps = _get_movie_timestamps(video_file_path, infer_timestamps=True) + timestamps = _get_video_timestamps(video_file_path, infer_timestamps=True) # Fetch the corresponding metadata pickle file, we extract the edges graph from here # TODO: This is the original implementation way to extract the file name but looks very brittle. Improve it diff --git a/src/neuroconv/datainterfaces/behavior/deeplabcut/deeplabcutdatainterface.py b/src/neuroconv/datainterfaces/behavior/deeplabcut/deeplabcutdatainterface.py index 147fcf6ea..e42926c4d 100644 --- a/src/neuroconv/datainterfaces/behavior/deeplabcut/deeplabcutdatainterface.py +++ b/src/neuroconv/datainterfaces/behavior/deeplabcut/deeplabcutdatainterface.py @@ -12,7 +12,7 @@ class DeepLabCutInterface(BaseTemporalAlignmentInterface): """Data interface for DeepLabCut datasets.""" display_name = "DeepLabCut" - keywords = ("DLC",) + keywords = ("DLC", "DeepLabCut", "pose estimation", "behavior") associated_suffixes = (".h5", ".csv") info = "Interface for handling data from DeepLabCut." @@ -62,6 +62,8 @@ def __init__( self.config_dict = _read_config(config_file_path=config_file_path) self.subject_name = subject_name self.verbose = verbose + self.pose_estimation_container_kwargs = dict() + super().__init__(file_path=file_path, config_file_path=config_file_path) def get_metadata(self): @@ -101,7 +103,7 @@ def add_to_nwbfile( self, nwbfile: NWBFile, metadata: Optional[dict] = None, - container_name: str = "PoseEstimation", + container_name: str = "PoseEstimationDeepLabCut", ): """ Conversion from DLC output files to nwb. Derived from dlc2nwb library. @@ -112,8 +114,9 @@ def add_to_nwbfile( nwb file to which the recording information is to be added metadata: dict metadata info for constructing the nwb file (optional). - container_name: str, default: "PoseEstimation" - Name of the container to store the pose estimation. 
+ container_name: str, default: "PoseEstimationDeepLabCut" + name of the PoseEstimation container in the nwb + """ from ._dlc_utils import _add_subject_to_nwbfile @@ -123,5 +126,5 @@ def add_to_nwbfile( individual_name=self.subject_name, config_file=self.source_data["config_file_path"], timestamps=self._timestamps, - pose_estimation_container_kwargs=dict(name=container_name), + pose_estimation_container_kwargs=self.pose_estimation_container_kwargs, ) diff --git a/src/neuroconv/datainterfaces/behavior/lightningpose/lightningposedatainterface.py b/src/neuroconv/datainterfaces/behavior/lightningpose/lightningposedatainterface.py index dbd425b5b..0ec17c810 100644 --- a/src/neuroconv/datainterfaces/behavior/lightningpose/lightningposedatainterface.py +++ b/src/neuroconv/datainterfaces/behavior/lightningpose/lightningposedatainterface.py @@ -80,14 +80,22 @@ def __init__( verbose : bool, default: True controls verbosity. ``True`` by default. """ + from importlib.metadata import version + # This import is to assure that the ndx_pose is in the global namespace when an pynwb.io object is created # For more detail, see https://github.com/rly/ndx-pose/issues/36 import ndx_pose # noqa: F401 + from packaging import version as version_parse from neuroconv.datainterfaces.behavior.video.video_utils import ( VideoCaptureContext, ) + ndx_pose_version = version("ndx-pose") + + if version_parse.parse(ndx_pose_version) >= version_parse.parse("0.2.0"): + raise ImportError("The ndx-pose version must be less than 0.2.0.") + self._vc = VideoCaptureContext self.file_path = Path(file_path) diff --git a/tests/test_on_data/behavior/test_behavior_interfaces.py b/tests/test_on_data/behavior/test_behavior_interfaces.py index 0b5c63376..00fcf7c5d 100644 --- a/tests/test_on_data/behavior/test_behavior_interfaces.py +++ b/tests/test_on_data/behavior/test_behavior_interfaces.py @@ -40,7 +40,16 @@ except ImportError: from setup_paths import BEHAVIOR_DATA_PATH, OUTPUT_PATH +from importlib.metadata 
import version +from packaging import version as version_parse + +ndx_pose_version = version("ndx-pose") + + +@pytest.mark.skipif( + version_parse.parse(ndx_pose_version) >= version_parse.parse("0.2"), reason="ndx-pose version is 0.2 or greater; LightningPose interface requires ndx-pose<0.2" +) class TestLightningPoseDataInterface(DataInterfaceTestMixin, TemporalAlignmentMixin): data_interface_cls = LightningPoseDataInterface interface_kwargs = dict( @@ -156,6 +165,9 @@ def check_read_nwb(self, nwbfile_path: str): assert_array_equal(pose_estimation_series.data[:], test_data[["x", "y"]].values) +@pytest.mark.skipif( + version_parse.parse(ndx_pose_version) >= version_parse.parse("0.2"), reason="ndx-pose version is 0.2 or greater; LightningPose interface requires ndx-pose<0.2" +) class TestLightningPoseDataInterfaceWithStubTest(DataInterfaceTestMixin, TemporalAlignmentMixin): data_interface_cls = LightningPoseDataInterface interface_kwargs = dict( @@ -363,7 +375,7 @@ def check_renaming_instance(self, nwbfile_path: str): with NWBHDF5IO(path=nwbfile_path, mode="r", load_namespaces=True) as io: nwbfile = io.read() assert "behavior" in nwbfile.processing - assert "PoseEstimation" not in nwbfile.processing["behavior"].data_interfaces + assert "PoseEstimationDeepLabCut" not in nwbfile.processing["behavior"].data_interfaces assert custom_container_name in nwbfile.processing["behavior"].data_interfaces def check_read_nwb(self, nwbfile_path: str): @@ -371,9 +383,11 @@ def check_read_nwb(self, nwbfile_path: str): nwbfile = io.read() assert "behavior" in nwbfile.processing processing_module_interfaces = nwbfile.processing["behavior"].data_interfaces - assert "PoseEstimation" in processing_module_interfaces + assert "PoseEstimationDeepLabCut" in processing_module_interfaces - pose_estimation_series_in_nwb = processing_module_interfaces["PoseEstimation"].pose_estimation_series + pose_estimation_series_in_nwb = processing_module_interfaces[ + "PoseEstimationDeepLabCut" + ].pose_estimation_series expected_pose_estimation_series = ["ind1_leftear", "ind1_rightear",
"ind1_snout", "ind1_tailbase"] expected_pose_estimation_series_are_in_nwb_file = [ @@ -449,9 +463,11 @@ def check_read_nwb(self, nwbfile_path: str): nwbfile = io.read() assert "behavior" in nwbfile.processing processing_module_interfaces = nwbfile.processing["behavior"].data_interfaces - assert "PoseEstimation" in processing_module_interfaces + assert "PoseEstimationDeepLabCut" in processing_module_interfaces - pose_estimation_series_in_nwb = processing_module_interfaces["PoseEstimation"].pose_estimation_series + pose_estimation_series_in_nwb = processing_module_interfaces[ + "PoseEstimationDeepLabCut" + ].pose_estimation_series expected_pose_estimation_series = ["ind1_leftear", "ind1_rightear", "ind1_snout", "ind1_tailbase"] expected_pose_estimation_series_are_in_nwb_file = [ @@ -500,9 +516,11 @@ def check_custom_timestamps(self, nwbfile_path: str): nwbfile = io.read() assert "behavior" in nwbfile.processing processing_module_interfaces = nwbfile.processing["behavior"].data_interfaces - assert "PoseEstimation" in processing_module_interfaces + assert "PoseEstimationDeepLabCut" in processing_module_interfaces - pose_estimation_series_in_nwb = processing_module_interfaces["PoseEstimation"].pose_estimation_series + pose_estimation_series_in_nwb = processing_module_interfaces[ + "PoseEstimationDeepLabCut" + ].pose_estimation_series for pose_estimation in pose_estimation_series_in_nwb.values(): pose_timestamps = pose_estimation.timestamps diff --git a/tests/test_on_data/behavior/test_lightningpose_converter.py b/tests/test_on_data/behavior/test_lightningpose_converter.py index dd93632a4..ebf2f59f5 100644 --- a/tests/test_on_data/behavior/test_lightningpose_converter.py +++ b/tests/test_on_data/behavior/test_lightningpose_converter.py @@ -1,10 +1,14 @@ import shutil import tempfile from datetime import datetime +from importlib.metadata import version from pathlib import Path from warnings import warn +import pytest from hdmf.testing import TestCase +from packaging 
import version +from packaging import version as version_parse from pynwb import NWBHDF5IO from pynwb.image import ImageSeries @@ -15,7 +19,12 @@ from ..setup_paths import BEHAVIOR_DATA_PATH +ndx_pose_version = version("ndx-pose") + +@pytest.mark.skipif( + version_parse.parse(ndx_pose_version) >= version_parse.parse("0.2"), reason="ndx_pose version is smaller than 0.2" +) class TestLightningPoseConverter(TestCase): @classmethod def setUpClass(cls) -> None: