Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix validation and test split not being reproducible #218

Merged
merged 18 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions crabs/detector/datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,12 @@ def _compute_splits(
A tuple with the train, test and validation datasets
"""

# Optionally fix the generator for a reproducible split of data
generator = None
# Optionally fix the random number generators for reproducible
# splits of data
rng_train_split, rng_val_split = None, None
if self.split_seed:
generator = torch.Generator().manual_seed(self.split_seed)
rng_train_split = torch.Generator().manual_seed(self.split_seed)
rng_val_split = torch.Generator().manual_seed(self.split_seed)

# Create dataset (combining all datasets passed)
full_dataset = CrabsCocoDetection(
Expand All @@ -189,7 +191,7 @@ def _compute_splits(
train_dataset, test_val_dataset = random_split(
full_dataset,
[self.config["train_fraction"], 1 - self.config["train_fraction"]],
generator=generator,
generator=rng_train_split,
)

# Split test/val sets from the remainder
Expand All @@ -199,6 +201,7 @@ def _compute_splits(
1 - self.config["val_over_test_fraction"],
self.config["val_over_test_fraction"],
],
generator=rng_val_split,
)

return train_dataset, test_dataset, val_dataset
Expand All @@ -216,9 +219,9 @@ def setup(self, stage: str):
Define the transforms for each split of the data and compute them.
"""
# Assign transforms
self.train_transform = self._get_train_transform()
# right now assuming validation and test get the same transforms
test_and_val_transform = self._get_test_val_transform()
self.train_transform = self._get_train_transform()
self.test_transform = test_and_val_transform
self.val_transform = test_and_val_transform

Expand Down
75 changes: 74 additions & 1 deletion crabs/detector/utils/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import datetime
import os
from pathlib import Path
from typing import Any
from typing import Any, Optional

import torch
from lightning.pytorch.loggers import MLFlowLogger

DEFAULT_ANNOTATIONS_FILENAME = "VIA_JSON_combined_coco_gen.json"
Expand Down Expand Up @@ -242,3 +243,75 @@ def slurm_logs_as_artifacts(logger: MLFlowLogger, slurm_job_id: str):
logger.run_id,
f"{log_filename}.{ext}",
)


def bbox_tensors_to_COCO_dict(
    bbox_tensors: list[torch.Tensor], list_img_filenames: Optional[list] = None
) -> dict:
    """Convert a list of bounding-box tensors to a COCO-format dictionary.

    Parameters
    ----------
    bbox_tensors : list[torch.Tensor]
        List of tensors with bounding boxes for each image.
        Each element of the list corresponds to an image, and each tensor in
        the list contains the bounding boxes for that image. Each tensor is of
        size (n, 4) where n is the number of bounding boxes in the image.
        The 4 values in the second dimension are x_min, y_min, x_max, y_max.
    list_img_filenames : list[str], optional
        List of image filenames. If not provided, filenames are generated
        as "frame_{i:04d}.png" where i is the 0-based index of the image in
        the list of bounding boxes.

    Returns
    -------
    dict
        COCO format dictionary with bounding boxes.

    """
    # Create list of image filenames if not provided
    if list_img_filenames is None:
        list_img_filenames = [
            f"frame_{i:04d}.png" for i in range(len(bbox_tensors))
        ]

    # Create list of dictionaries for images.
    # Width and height are unknown at this point, so they are 0 placeholders.
    list_images: list[dict] = []
    for img_id, img_name in enumerate(list_img_filenames):
        image_entry = {
            "id": img_id + 1,  # 1-based
            "width": 0,
            "height": 0,
            "file_name": img_name,
        }
        list_images.append(image_entry)

    # Create list of dictionaries for annotations
    list_annotations: list[dict] = []
    for img_id, img_bboxes in enumerate(bbox_tensors):
        # loop thru bboxes in image
        for bbox_row in img_bboxes:
            # Convert the tensor row straight to a plain Python list to make
            # it JSON serializable (no intermediate numpy array needed)
            x_min, y_min, x_max, y_max = bbox_row.tolist()

            annotation = {
                "id": len(list_annotations)
                + 1,  # 1-based by default in VIA tool
                "image_id": img_id + 1,  # 1-based by default in VIA tool
                "category_id": 1,
                # COCO bbox convention is [x, y, width, height]
                "bbox": [x_min, y_min, x_max - x_min, y_max - y_min],
                "area": (x_max - x_min) * (y_max - y_min),
                "iscrowd": 0,
            }

            list_annotations.append(annotation)

    # Assemble the top-level COCO dictionary
    coco_dict = {
        "info": {},
        "licenses": [],
        "categories": [{"id": 1, "name": "crab", "supercategory": "animal"}],
        "images": list_images,
        "annotations": list_annotations,
    }

    return coco_dict
Loading