diff --git a/sleap/nn/data/pipelines.py b/sleap/nn/data/pipelines.py index 2e334456a..d15c4491f 100644 --- a/sleap/nn/data/pipelines.py +++ b/sleap/nn/data/pipelines.py @@ -775,6 +775,7 @@ def make_viz_pipeline(self, data_provider: Provider) -> Pipeline: provider=data_provider, ) pipeline += Normalizer.from_config(self.data_config.preprocessing) + pipeline += Resizer.from_config(self.data_config.preprocessing) pipeline += InstanceCentroidFinder.from_config( self.data_config.instance_cropping, skeletons=self.data_config.labels.skeletons, @@ -1250,6 +1251,7 @@ def make_viz_pipeline(self, data_provider: Provider) -> Pipeline: provider=data_provider, ) pipeline += Normalizer.from_config(self.data_config.preprocessing) + pipeline += Resizer.from_config(self.data_config.preprocessing) pipeline += InstanceCentroidFinder.from_config( self.data_config.instance_cropping, skeletons=self.data_config.labels.skeletons, diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index c27382e52..7e0de5aa2 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -727,11 +727,18 @@ class CentroidCropGroundTruth(tf.keras.layers.Layer): Attributes: crop_size: The length of the square box to extract around each centroid. + input_scale: Float indicating if the images should be resized before being + passed to the model. """ - def __init__(self, crop_size: int): + def __init__( + self, + crop_size: int, + input_scale: float = 1.0, + ): super().__init__() self.crop_size = crop_size + self.input_scale = input_scale def call(self, example_gt: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]: """Return the ground truth instance crops. @@ -758,6 +765,9 @@ def call(self, example_gt: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]: """ # Pull out data from example. full_imgs = example_gt["image"] + if self.input_scale != 1.0: + full_imgs = sleap.nn.data.resizing.resize_image(full_imgs, self.input_scale) + example_gt["centroids"] *= self.input_scale crop_sample_inds = example_gt["centroids"].value_rowids() # (n_peaks,) n_peaks = tf.shape(crop_sample_inds)[0] # total number of peaks in the batch centroid_points = example_gt["centroids"].flat_values # (n_peaks, 2) @@ -927,11 +937,12 @@ def __init__( self.ensure_grayscale = ensure_grayscale self.ensure_float = ensure_float - def preprocess(self, imgs: tf.Tensor) -> tf.Tensor: + def preprocess(self, imgs: tf.Tensor, resize_img: bool = True) -> tf.Tensor: """Apply all preprocessing operations configured for this layer. Args: imgs: A batch of images as a tensor. + resize_img: Bool indicating whether the images should be resized by `input_scale`. Returns: The input tensor after applying preprocessing operations. The tensor will @@ -947,7 +958,7 @@ def preprocess(self, imgs: tf.Tensor) -> tf.Tensor: if self.ensure_float: imgs = sleap.nn.data.normalization.ensure_float(imgs) - if self.input_scale != 1.0: + if resize_img and self.input_scale != 1.0: imgs = sleap.nn.data.resizing.resize_image(imgs, self.input_scale) if self.pad_to_stride > 1: @@ -1638,6 +1649,9 @@ class CentroidCrop(InferenceLayer): crop_size: Integer scalar specifying the height/width of the centered crops. input_scale: Float indicating if the images should be resized before being passed to the model. + precrop_resize: Float factor by which the original images (not the images + already resized for the centroid model) should be resized before cropping. + Note: this resize is applied only after the centroid model predictions are obtained.
pad_to_stride: If not 1, input image will be paded to ensure that it is divisible by this value (after scaling). This should be set to the max stride of the model. @@ -1678,6 +1692,7 @@ def __init__( self, keras_model: tf.keras.Model, crop_size: int, input_scale: float = 1.0, + precrop_resize: float = 1.0, pad_to_stride: int = 1, output_stride: Optional[int] = None, peak_threshold: float = 0.2, @@ -1698,6 +1713,7 @@ def __init__( ) self.crop_size = crop_size + self.precrop_resize = precrop_resize self.confmaps_ind = confmaps_ind self.offsets_ind = offsets_ind @@ -1816,6 +1832,13 @@ def call(self, inputs): # See: https://github.com/tensorflow/tensorflow/issues/6720 centroid_points = (centroid_points / self.input_scale) + 0.5 + # Resize the original full images before cropping. + if self.precrop_resize != 1.0: + full_imgs = sleap.nn.data.resizing.resize_image( + full_imgs, self.precrop_resize + ) + centroid_points *= self.precrop_resize + # Store crop offsets. crop_offsets = centroid_points - (self.crop_size / 2) @@ -1956,6 +1979,11 @@ class FindInstancePeaks(InferenceLayer): centered instance confidence maps. input_scale: Float indicating if the images should be resized before being passed to the model. + resize_input_image: Bool indicating whether the input crops should be resized + by `input_scale` during preprocessing. If `CentroidCropGroundTruth` or + `CentroidCrop` is used upstream of `FindInstancePeaks`, the full images are + already resized before cropping, so this should be set to `False`. The output + keypoints are still mapped back to the original image scale using `input_scale`. output_stride: Output stride of the model, denoting the scale of the output confidence maps relative to the images (after input scaling). This is used for adjusting the peak coordinates to the image grid. This will be inferred @@ -1986,6 +2014,7 @@ def __init__( self, keras_model: tf.keras.Model, input_scale: float = 1.0, + resize_input_image: bool = True, output_stride: Optional[int] = None, peak_threshold: float = 0.2, refinement: Optional[str] = "local", @@ -1998,6 +2027,7 @@ def __init__( super().__init__( keras_model=keras_model, input_scale=input_scale, pad_to_stride=1, **kwargs ) + self.resize_input_image = resize_input_image self.peak_threshold = peak_threshold self.refinement = refinement self.integral_patch_size = integral_patch_size @@ -2095,7 +2125,7 @@ def call( crop_sample_inds = tf.range(samples, dtype=tf.int32) # Preprocess inputs (scaling, padding, colorspace, int to float). - crops = self.preprocess(crops) + crops = self.preprocess(crops, resize_img=self.resize_input_image) # Network forward pass. out = self.keras_model(crops) @@ -2142,7 +2172,9 @@ def call( if "crop_offsets" in inputs: # Flatten (samples, ?, 2) -> (n_peaks, 2). crop_offsets = inputs["crop_offsets"].merge_dims(0, 1) - peak_points = peak_points + tf.expand_dims(crop_offsets, axis=1) + peak_points = peak_points + ( + tf.expand_dims(crop_offsets, axis=1) / self.input_scale + ) # Group peaks by sample (samples, ?, nodes, 2).
peaks = tf.RaggedTensor.from_value_rowids( @@ -2345,7 +2377,7 @@ def _initialize_inference_model(self): if use_gt_centroid: centroid_crop_layer = CentroidCropGroundTruth( - crop_size=self.confmap_config.data.instance_cropping.crop_size + crop_size=self.confmap_config.data.instance_cropping.crop_size, ) else: if use_gt_confmap: @@ -2356,6 +2388,7 @@ def _initialize_inference_model(self): keras_model=self.centroid_model.keras_model, crop_size=crop_size, input_scale=self.centroid_config.data.preprocessing.input_scaling, + precrop_resize=1.0, pad_to_stride=self.centroid_config.data.preprocessing.pad_to_stride, output_stride=self.centroid_config.model.heads.centroid.output_stride, peak_threshold=self.peak_threshold, @@ -2377,7 +2410,14 @@ def _initialize_inference_model(self): refinement="integral" if self.integral_refinement else "local", integral_patch_size=self.integral_patch_size, return_confmaps=False, + resize_input_image=False, ) + if use_gt_centroid: + centroid_crop_layer.input_scale = cfg.data.preprocessing.input_scaling + else: + centroid_crop_layer.precrop_resize = ( + cfg.data.preprocessing.input_scaling + ) self.inference_model = TopDownInferenceModel( centroid_crop=centroid_crop_layer, instance_peaks=instance_peaks_layer @@ -3834,6 +3874,11 @@ class TopDownMultiClassFindPeaks(InferenceLayer): centered instance confidence maps and classification. input_scale: Float indicating if the images should be resized before being passed to the model. + resize_input_image: Bool indicating whether the input crops should be resized + by `input_scale` during preprocessing. If `CentroidCropGroundTruth` is used + upstream of `TopDownMultiClassFindPeaks`, the full images are already resized + before cropping, so this should be set to `False`. The output keypoints are + still mapped back to the original image scale using `input_scale`. output_stride: Output stride of the model, denoting the scale of the output confidence maps relative to the images (after input scaling). This is used for adjusting the peak coordinates to the image grid. This will be inferred @@ -3875,6 +3920,7 @@ def __init__( self, keras_model: tf.keras.Model, input_scale: float = 1.0, + resize_input_image: bool = True, output_stride: Optional[int] = None, peak_threshold: float = 0.2, refinement: Optional[str] = "local", @@ -3890,6 +3936,7 @@ def __init__( super().__init__( keras_model=keras_model, input_scale=input_scale, pad_to_stride=1, **kwargs ) + self.resize_input_image = resize_input_image self.peak_threshold = peak_threshold self.refinement = refinement self.integral_patch_size = integral_patch_size @@ -4007,7 +4054,7 @@ def call( crop_sample_inds = tf.range(samples, dtype=tf.int32) # Preprocess inputs (scaling, padding, colorspace, int to float). - crops = self.preprocess(crops) + crops = self.preprocess(crops, resize_img=self.resize_input_image) # Network forward pass.
out = self.keras_model(crops) @@ -4256,7 +4303,10 @@ def _initialize_inference_model(self): refinement="integral" if self.integral_refinement else "local", integral_patch_size=self.integral_patch_size, return_confmaps=False, + resize_input_image=False, ) + if use_gt_centroid: + centroid_crop_layer.input_scale = cfg.data.preprocessing.input_scaling self.inference_model = TopDownMultiClassInferenceModel( centroid_crop=centroid_crop_layer, instance_peaks=instance_peaks_layer diff --git a/sleap/nn/training.py b/sleap/nn/training.py index c3692637c..7d32dd797 100644 --- a/sleap/nn/training.py +++ b/sleap/nn/training.py @@ -1315,10 +1315,11 @@ def _setup_visualization(self): # Create an instance peak finding layer. find_peaks = FindInstancePeaks( keras_model=self.keras_model, - input_scale=self.config.data.preprocessing.input_scaling, + input_scale=1.0, peak_threshold=0.2, refinement="local", return_confmaps=True, + resize_input_image=False, ) def visualize_example(example): @@ -1755,10 +1756,11 @@ def _setup_visualization(self): # Create an instance peak finding layer. find_peaks = FindInstancePeaks( keras_model=self.keras_model, - input_scale=self.config.data.preprocessing.input_scaling, + input_scale=1.0, peak_threshold=0.2, refinement="local", return_confmaps=True, + resize_input_image=False, ) def visualize_example(example): diff --git a/tests/data/models/minimal_instance.UNet.centered_instance_with_scaling/best_model.h5 b/tests/data/models/minimal_instance.UNet.centered_instance_with_scaling/best_model.h5 new file mode 100644 index 000000000..c379b0860 Binary files /dev/null and b/tests/data/models/minimal_instance.UNet.centered_instance_with_scaling/best_model.h5 differ diff --git a/tests/data/models/minimal_instance.UNet.centered_instance_with_scaling/initial_config.json b/tests/data/models/minimal_instance.UNet.centered_instance_with_scaling/initial_config.json new file mode 100644 index 000000000..f622ec415 --- /dev/null +++ b/tests/data/models/minimal_instance.UNet.centered_instance_with_scaling/initial_config.json @@ -0,0 +1,164 @@ +{ + "data": { + "labels": { + "training_labels": null, + "validation_labels": null, + "validation_fraction": 0.1, + "test_labels": null, + "split_by_inds": false, + "training_inds": null, + "validation_inds": null, + "test_inds": null, + "search_path_hints": [], + "skeletons": [] + }, + "preprocessing": { + "ensure_rgb": false, + "ensure_grayscale": false, + "imagenet_mode": null, + "input_scaling": 0.5, + "pad_to_stride": null, + "resize_and_pad_to_target": true, + "target_height": null, + "target_width": null + }, + "instance_cropping": { + "center_on_part": null, + "crop_size": null, + "crop_size_detection_padding": 16 + } + }, + "model": { + "backbone": { + "leap": null, + "unet": { + "stem_stride": null, + "max_stride": 8, + "output_stride": 2, + "filters": 16, + "filters_rate": 1.5, + "middle_block": true, + "up_interpolate": false, + "stacks": 1 + }, + "hourglass": null, + "resnet": null, + "pretrained_encoder": null + }, + "heads": { + "single_instance": null, + "centroid": null, + "centered_instance": { + "anchor_part": null, + "part_names": null, + "sigma": 1.5, + "output_stride": 2, + "loss_weight": 1.0, + "offset_refinement": true + }, + "multi_instance": null, + "multi_class_bottomup": null, + "multi_class_topdown": null + }, + "base_checkpoint": null + }, + "optimization": { + "preload_data": false, + "augmentation_config": { + "rotate": false, + "rotation_min_angle": -180.0, + "rotation_max_angle": 180.0, + "translate": false, + 
"translate_min": -5, + "translate_max": 5, + "scale": false, + "scale_min": 0.9, + "scale_max": 1.1, + "uniform_noise": false, + "uniform_noise_min_val": 0.0, + "uniform_noise_max_val": 10.0, + "gaussian_noise": false, + "gaussian_noise_mean": 5.0, + "gaussian_noise_stddev": 1.0, + "contrast": false, + "contrast_min_gamma": 0.5, + "contrast_max_gamma": 2.0, + "brightness": false, + "brightness_min_val": 0.0, + "brightness_max_val": 10.0, + "random_crop": false, + "random_crop_height": 256, + "random_crop_width": 256, + "random_flip": false, + "flip_horizontal": true + }, + "online_shuffling": true, + "shuffle_buffer_size": 128, + "prefetch": true, + "batch_size": 4, + "batches_per_epoch": null, + "min_batches_per_epoch": 100, + "val_batches_per_epoch": null, + "min_val_batches_per_epoch": 1, + "epochs": 10, + "optimizer": "adam", + "initial_learning_rate": 0.0001, + "learning_rate_schedule": { + "reduce_on_plateau": true, + "reduction_factor": 0.5, + "plateau_min_delta": 1e-06, + "plateau_patience": 5, + "plateau_cooldown": 3, + "min_learning_rate": 1e-08 + }, + "hard_keypoint_mining": { + "online_mining": false, + "hard_to_easy_ratio": 2.0, + "min_hard_keypoints": 2, + "max_hard_keypoints": null, + "loss_scale": 5.0 + }, + "early_stopping": { + "stop_training_on_plateau": true, + "plateau_min_delta": 1e-06, + "plateau_patience": 10 + } + }, + "outputs": { + "save_outputs": true, + "run_name": "minimal_instance.UNet.centered_instance_with_scaling", + "run_name_prefix": "", + "run_name_suffix": null, + "runs_folder": "models", + "tags": [], + "save_visualizations": false, + "keep_viz_images": false, + "zip_outputs": false, + "log_to_csv": true, + "checkpointing": { + "initial_model": false, + "best_model": true, + "every_epoch": false, + "latest_model": false, + "final_model": false + }, + "tensorboard": { + "write_logs": false, + "loss_frequency": "epoch", + "architecture_graph": false, + "profile_graph": false, + "visualizations": true + }, + "zmq": { + "subscribe_to_controller": false, + "controller_address": "tcp://127.0.0.1:9000", + "controller_polling_timeout": 10, + "publish_updates": false, + "publish_address": "tcp://127.0.0.1:9001" + } + }, + "name": "", + "description": "", + "sleap_version": "1.4.1", + "filename": "models\\minimal_instance.UNet.centered_instance_with_scaling\\initial_config.json" +} \ No newline at end of file diff --git a/tests/data/models/minimal_instance.UNet.centered_instance_with_scaling/training_config.json b/tests/data/models/minimal_instance.UNet.centered_instance_with_scaling/training_config.json new file mode 100644 index 000000000..af563c1bf --- /dev/null +++ b/tests/data/models/minimal_instance.UNet.centered_instance_with_scaling/training_config.json @@ -0,0 +1,226 @@ +{ + "data": { + "labels": { + "training_labels": null, + "validation_labels": null, + "validation_fraction": 0.1, + "test_labels": null, + "split_by_inds": false, + "training_inds": null, + "validation_inds": null, + "test_inds": null, + "search_path_hints": [ + "" + ], + "skeletons": [ + { + "directed": true, + "graph": { + "name": "Skeleton-0", + "num_edges_inserted": 1 + }, + "links": [ + { + "edge_insert_idx": 0, + "key": 0, + "source": { + "py/object": "sleap.skeleton.Node", + "py/state": { + "py/tuple": [ + "A", + 1.0 + ] + } + }, + "target": { + "py/object": "sleap.skeleton.Node", + "py/state": { + "py/tuple": [ + "B", + 1.0 + ] + } + }, + "type": { + "py/reduce": [ + { + "py/type": "sleap.skeleton.EdgeType" + }, + { + "py/tuple": [ + 1 + ] + } + ] + } + } + ], + "multigraph": 
true, + "nodes": [ + { + "id": { + "py/id": 1 + } + }, + { + "id": { + "py/id": 2 + } + } + ] + } + ] + }, + "preprocessing": { + "ensure_rgb": false, + "ensure_grayscale": false, + "imagenet_mode": null, + "input_scaling": 0.5, + "pad_to_stride": 1, + "resize_and_pad_to_target": true, + "target_height": 384, + "target_width": 384 + }, + "instance_cropping": { + "center_on_part": null, + "crop_size": 56, + "crop_size_detection_padding": 16 + } + }, + "model": { + "backbone": { + "leap": null, + "unet": { + "stem_stride": null, + "max_stride": 8, + "output_stride": 2, + "filters": 16, + "filters_rate": 1.5, + "middle_block": true, + "up_interpolate": false, + "stacks": 1 + }, + "hourglass": null, + "resnet": null, + "pretrained_encoder": null + }, + "heads": { + "single_instance": null, + "centroid": null, + "centered_instance": { + "anchor_part": null, + "part_names": [ + "A", + "B" + ], + "sigma": 1.5, + "output_stride": 2, + "loss_weight": 1.0, + "offset_refinement": true + }, + "multi_instance": null, + "multi_class_bottomup": null, + "multi_class_topdown": null + }, + "base_checkpoint": null + }, + "optimization": { + "preload_data": false, + "augmentation_config": { + "rotate": false, + "rotation_min_angle": -180.0, + "rotation_max_angle": 180.0, + "translate": false, + "translate_min": -5, + "translate_max": 5, + "scale": false, + "scale_min": 0.9, + "scale_max": 1.1, + "uniform_noise": false, + "uniform_noise_min_val": 0.0, + "uniform_noise_max_val": 10.0, + "gaussian_noise": false, + "gaussian_noise_mean": 5.0, + "gaussian_noise_stddev": 1.0, + "contrast": false, + "contrast_min_gamma": 0.5, + "contrast_max_gamma": 2.0, + "brightness": false, + "brightness_min_val": 0.0, + "brightness_max_val": 10.0, + "random_crop": false, + "random_crop_height": 256, + "random_crop_width": 256, + "random_flip": false, + "flip_horizontal": true + }, + "online_shuffling": true, + "shuffle_buffer_size": 128, + "prefetch": true, + "batch_size": 4, + "batches_per_epoch": 100, + "min_batches_per_epoch": 100, + "val_batches_per_epoch": 1, + "min_val_batches_per_epoch": 1, + "epochs": 10, + "optimizer": "adam", + "initial_learning_rate": 0.0001, + "learning_rate_schedule": { + "reduce_on_plateau": true, + "reduction_factor": 0.5, + "plateau_min_delta": 1e-06, + "plateau_patience": 5, + "plateau_cooldown": 3, + "min_learning_rate": 1e-08 + }, + "hard_keypoint_mining": { + "online_mining": false, + "hard_to_easy_ratio": 2.0, + "min_hard_keypoints": 2, + "max_hard_keypoints": null, + "loss_scale": 5.0 + }, + "early_stopping": { + "stop_training_on_plateau": true, + "plateau_min_delta": 1e-06, + "plateau_patience": 10 + } + }, + "outputs": { + "save_outputs": true, + "run_name": "minimal_instance.UNet.centered_instance_with_scaling", + "run_name_prefix": "", + "run_name_suffix": "", + "runs_folder": "models", + "tags": [], + "save_visualizations": false, + "keep_viz_images": false, + "zip_outputs": false, + "log_to_csv": true, + "checkpointing": { + "initial_model": false, + "best_model": true, + "every_epoch": false, + "latest_model": false, + "final_model": false + }, + "tensorboard": { + "write_logs": false, + "loss_frequency": "epoch", + "architecture_graph": false, + "profile_graph": false, + "visualizations": true + }, + "zmq": { + "subscribe_to_controller": false, + "controller_address": "tcp://127.0.0.1:9000", + "controller_polling_timeout": 10, + "publish_updates": false, + "publish_address": "tcp://127.0.0.1:9001" + } + }, + "name": "", + "description": "", + "sleap_version": "1.4.1", + "filename": 
"models\\minimal_instance.UNet.centered_instance_with_scaling\\training_config.json" +} \ No newline at end of file diff --git a/tests/fixtures/models.py b/tests/fixtures/models.py index 2f7c168fb..c7edb1501 100644 --- a/tests/fixtures/models.py +++ b/tests/fixtures/models.py @@ -12,6 +12,11 @@ def min_centered_instance_model_path(): return "tests/data/models/minimal_instance.UNet.centered_instance" +@pytest.fixture +def min_centered_instance_with_scaling_model_path(): + return "tests/data/models/minimal_instance.UNet.centered_instance_with_scaling" + + @pytest.fixture def min_bottomup_model_path(): return "tests/data/models/minimal_instance.UNet.bottomup" diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index 0a978de0a..c7ebf7983 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -705,6 +705,55 @@ def test_topdown_predictor_centered_instance( assert_allclose(points_gt[inds1.numpy()], points_pr[inds2.numpy()], atol=1.5) +def test_topdown_predictor_centered_instance_with_scaling( + min_labels, min_centered_instance_with_scaling_model_path +): + predictor = TopDownPredictor.from_trained_models( + confmap_model_path=min_centered_instance_with_scaling_model_path + ) + + predictor.verbosity = "none" + labels_pr = predictor.predict(min_labels) + assert len(labels_pr) == 1 + assert len(labels_pr[0].instances) == 2 + + assert predictor.is_grayscale == True + + points_gt = np.concatenate( + [min_labels[0][0].numpy(), min_labels[0][1].numpy()], axis=0 + ) + points_pr = np.concatenate( + [labels_pr[0][0].numpy(), labels_pr[0][1].numpy()], axis=0 + ) + inds1, inds2 = sleap.nn.utils.match_points(points_gt, points_pr) + assert_allclose(points_gt[inds1.numpy()], points_pr[inds2.numpy()], atol=1.5) + + +def test_topdown_predictor_centroid_centered_instance_with_scaling( + min_labels, min_centered_instance_with_scaling_model_path, min_centroid_model_path +): + predictor = TopDownPredictor.from_trained_models( + centroid_model_path=min_centroid_model_path, + confmap_model_path=min_centered_instance_with_scaling_model_path, + ) + + predictor.verbosity = "none" + labels_pr = predictor.predict(min_labels) + assert len(labels_pr) == 1 + assert len(labels_pr[0].instances) == 2 + + assert predictor.is_grayscale == True + + points_gt = np.concatenate( + [min_labels[0][0].numpy(), min_labels[0][1].numpy()], axis=0 + ) + points_pr = np.concatenate( + [labels_pr[0][0].numpy(), labels_pr[0][1].numpy()], axis=0 + ) + inds1, inds2 = sleap.nn.utils.match_points(points_gt, points_pr) + assert_allclose(points_gt[inds1.numpy()], points_pr[inds2.numpy()], atol=1.5) + + def test_topdown_predictor_centered_instance_high_threshold( min_labels, min_centered_instance_model_path ):