diff --git a/notebooks/scripts/mmdetection_classify/centroid_aug.py b/notebooks/scripts/mmdetection_classify/centroid_aug.py index a129e7a..ed2ef70 100644 --- a/notebooks/scripts/mmdetection_classify/centroid_aug.py +++ b/notebooks/scripts/mmdetection_classify/centroid_aug.py @@ -62,14 +62,24 @@ def __call__(self, results, exact_match_rg_channel=False): properties.sort(key=lambda x: x.area, reverse=True) props = properties[0] centroid = props.centroid + # centroid_box = np.array( + # [ + # centroid[0] - centroid_box_width / 2, + # centroid[1] - centroid_box_width / 2, + # centroid[0] + centroid_box_width / 2, + # centroid[1] + centroid_box_width / 2, + # ] + # ) + centroid_box = np.array( [ - centroid[0] - centroid_box_width / 2, centroid[1] - centroid_box_width / 2, - centroid[0] + centroid_box_width / 2, + centroid[0] - centroid_box_width / 2, centroid[1] + centroid_box_width / 2, + centroid[0] + centroid_box_width / 2, ] ) + # Clip to image size centroid_box[0] = max(0, centroid_box[0]) centroid_box[1] = max(0, centroid_box[1])