Skip to content

Commit

Permalink
update sc_video_utils.py, classify_utils.py and classify_mmdetection_…
Browse files Browse the repository at this point in the history
…mitosis_eval.py
  • Loading branch information
dummyindex committed Nov 2, 2023
1 parent 0faf8b2 commit edcf0a9
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 11 deletions.
7 changes: 4 additions & 3 deletions livecellx/core/sc_video_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,15 +239,16 @@ def video_frames_and_masks_from_sample(
return video_frames, video_frame_masks


def combine_video_frames_and_masks(video_frames, video_frame_masks, edt_transform=True, is_gray=False):
    """Merge each frame with its mask into a single 3-channel image.

    Channel 0 of every returned image is the frame; channels 1 and 2 both
    hold the mask.

    Args:
        video_frames: iterable of frame images.
        video_frame_masks: iterable of masks, paired with ``video_frames`` by
            position (``zip`` stops at the shorter of the two).
        edt_transform: when true, every mask is first passed through
            ``label_mask_to_edt_mask``.
        is_gray: when true, frames and masks are used as given; otherwise both
            are passed through ``rgb_img_to_gray`` before being combined.

    Returns:
        A list with one array per frame/mask pair, channels on the last axis.
    """
    if edt_transform:
        video_frame_masks = [label_mask_to_edt_mask(x) for x in video_frame_masks]

    res_frames = []
    for frame, mask in zip(video_frames, video_frame_masks):
        if not is_gray:
            frame = rgb_img_to_gray(frame)
            mask = rgb_img_to_gray(mask)
        # Stacking three 2-D arrays gives (3, H, W); the transpose moves the
        # channel axis to the end.
        res_frame = np.array([frame, mask, mask]).transpose(1, 2, 0)
        res_frames.append(res_frame)
    return res_frames
40 changes: 40 additions & 0 deletions livecellx/track/classify_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,43 @@ def gen_inference_sctc_sample_videos(
fps=fps,
)
return saved_sample_info_df


def save_data_input(data_input, file_path):
    """Write one model input sample to ``file_path`` as an mp4.

    Args:
        data_input: nested indexable whose ``[1][2]`` entry is read as the
            image frames and whose ``[1][0]`` entry is read as the masks; both
            must support ``.detach().cpu().numpy()``. The original inline
            notes give each as 8 x 224 x 224 -- NOTE(review): confirm against
            the caller.
        file_path: output path handed to ``gen_mp4_from_frames``.
    """
    # Imported locally so the function does not depend on what the enclosing
    # module happens to import at file level.
    from livecellx.core.sc_video_utils import combine_video_frames_and_masks, gen_mp4_from_frames
    from livecellx.preprocess.utils import normalize_img_to_uint8

    imgs = data_input[1][2].detach().cpu().numpy()
    masks = data_input[1][0].detach().cpu().numpy()
    # Iterating the arrays walks their first axis, one frame at a time.
    imgs = [normalize_img_to_uint8(img) for img in imgs]
    masks = [normalize_img_to_uint8(mask) for mask in masks]

    # already edt transformed
    frames = combine_video_frames_and_masks(imgs, masks, is_gray=True, edt_transform=False)
    gen_mp4_from_frames(frames, file_path)


def is_decord_invalid_video(path):
    """Return True if decord decodes any frame of ``path`` without 3 channels.

    More information: https://github.com/dmlc/decord/issues/150

    Args:
        path: location of the video file; converted with ``str`` before being
            given to ``decord.VideoReader``.

    Returns:
        True as soon as a frame whose last axis is not of size 3 is found,
        False when every frame has 3 channels.
    """
    import decord

    reader = decord.VideoReader(str(path))
    reader.seek(0)
    for idx in range(len(reader)):
        reader.seek(idx)
        # Decode once and discard after the check; frames are not retained,
        # so memory use does not grow with video length.
        frame = reader.next().asnumpy()
        if frame.shape[-1] != 3:
            print("invalid video for decord (https://github.com/dmlc/decord/issues/150): ", path)
            return True
    return False
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from livecellx.core.datasets import LiveCellImageDataset, SingleImageDataset
from skimage import measure
from livecellx.core import SingleCellTrajectory, SingleCellStatic
from livecellx.core.sc_video_utils import gen_mp4_from_frames, combine_video_frames_and_masks
from livecellx.preprocess.utils import normalize_img_to_uint8
from livecellx.core.io_utils import save_png

# import detectron2
# from detectron2.utils.logger import setup_logger
Expand Down Expand Up @@ -78,10 +81,20 @@
)
parser.add_argument("--device", type=str, help="Device to use", default="cuda:0")
parser.add_argument("--add-random-crop", action="store_true", help="Add random crop to the pipeline")
# Boolean switch (off unless passed) marking the evaluated model as a TSN model.
parser.add_argument("--is-tsn", action="store_true", help="if it is tsn model")
# Boolean switch (off unless passed); see the help text for the intended labelling behaviour.
parser.add_argument(
"--raw-video-treat-as-negative",
action="store_true",
help="In combined ver, raw videos input labels should be all WRONG",
default=False,
)


args = parser.parse_args()

# Sub-directory of the output directory named "wrong_videos"; created up front,
# including missing parents, and left untouched if it already exists.
out_wrong_video_dir = Path(args.out_dir) / "wrong_videos"
out_wrong_video_dir.mkdir(parents=True, exist_ok=True)

print(
"#" * 40,
"args",
Expand Down Expand Up @@ -116,7 +129,38 @@

print("test data frame:", test_data_df.columns[:2])

if args.add_random_crop:
if args.add_random_crop and args.is_tsn:
model.cfg.test_pipeline = [
dict(io_backend="disk", type="DecordInit"),
dict(clip_len=3, frame_interval=1, num_clips=3, test_mode=True, type="SampleFrames"),
dict(type="DecordDecode"),
dict(
scale_range=(
224,
300,
),
type="RandomRescale",
),
dict(size=224, type="RandomCrop"),
dict(
scale=(
-1,
224,
),
type="Resize",
),
dict(
scale=(
-1,
256,
),
type="Resize",
),
dict(crop_size=224, type="TenCrop"),
dict(input_format="NCHW", type="FormatShape"),
dict(type="PackActionInputs"),
]
elif args.add_random_crop:
model.cfg.test_pipeline = [
dict(io_backend="disk", type="DecordInit"),
dict(clip_len=8, frame_interval=1, num_clips=1, test_mode=True, type="SampleFrames"),
Expand Down Expand Up @@ -182,6 +226,8 @@
for row_ in tqdm(all_rows):
idx, row_series = row_
video_path = str(video_dir / row_series["path"])
input_frame_type = row_series["frame_type"]
padding_pixels = row_series["padding_pixels"]
try:
model.zero_grad()
results, data, grad = livecellx.track.timesformer_inference.inference_recognizer(
Expand All @@ -194,21 +240,52 @@

traceback.print_exc()
print("video_path:", video_path)
continue
# predicted_label = results.pred_labels.item.cpu().numpy()[0]

predicted_label = results.pred_label.item()
test_gt_label = row_series["label"]
if test_gt_label not in gt2total:
gt2total[test_gt_label] = 0
gt2correct[test_gt_label] = 0
if args.raw_video_treat_as_negative:
test_gt_label = 2 # no focus!
else:
test_gt_label = row_series["label"]
if test_gt_label not in gt2total:
gt2total[test_gt_label] = 0
gt2correct[test_gt_label] = 0
gt2total[test_gt_label] += 1
row_series = row_series.copy()
row_series["predicted_label"] = predicted_label
row_series["true_label"] = test_gt_label
row_series["correct"] = predicted_label == test_gt_label
all_predictions.append(row_series)

if predicted_label != test_gt_label:
print("wrong prediction:", video_path, "predicted_label:", predicted_label, "gt_label:", test_gt_label)
wrong_predictions.append(row_series)

data_input = data["inputs"][0] # 3 x 3 x 8 x 224 x 224
if not args.is_tsn:
# timeSformer
imgs = data_input[1][2].detach().cpu().numpy() # 8 x 224 x 224
masks = data_input[1][0].detach().cpu().numpy() # 8 x 224 x 224
imgs = list(imgs)
masks = list(masks)
imgs = [normalize_img_to_uint8(img) for img in imgs]
masks = [normalize_img_to_uint8(mask) for mask in masks]

# already edt transformed, so set to false
frames = combine_video_frames_and_masks(imgs, masks, is_gray=True, edt_transform=False)
gen_mp4_from_frames(
frames, out_wrong_video_dir / f"wrong_{idx}-{input_frame_type}-padding_{padding_pixels}.mp4", fps=3
)
else:
# tsn
imgs = data_input.detach().cpu().numpy() # 90 x 3 x h x w
imgs = list(imgs)
imgs = [normalize_img_to_uint8(img) for img in imgs]
tmp_out_sample_dir = out_wrong_video_dir / f"wrong_{idx}-{input_frame_type}-padding_{padding_pixels}"
tmp_out_sample_dir.mkdir(parents=True, exist_ok=True)
for i, img in enumerate(imgs):
save_png(tmp_out_sample_dir / f"sample_dim0-{i}.png", img.swapaxes(0, 2), mode="RGB")
else:
gt2correct[test_gt_label] += 1

Expand Down Expand Up @@ -242,7 +319,7 @@ def report_classification_metrics(true_labels, predicted_labels):
f1 = f1_score(true_labels, predicted_labels, average="weighted")

# generate a classification report
report = classification_report(true_labels, predicted_labels)
report = classification_report(true_labels, predicted_labels, digits=4)

# print the metrics and classification report
print(f"Accuracy: {accuracy:.2f}")
Expand All @@ -260,8 +337,14 @@ def report_classification_metrics(true_labels, predicted_labels):
)

# %%
# Report metrics separately for every frame type present in the predictions,
# each under its own banner line.
frame_types = all_predictions_df["frame_type"].unique()

for frame_type in frame_types:
    indexer = all_predictions_df["frame_type"] == frame_type
    print("#" * 40, "frame_type:", frame_type, "#" * 40)
    report_classification_metrics(
        all_predictions_df[indexer]["true_label"], all_predictions_df[indexer]["predicted_label"]
    )

# %%
all_predictions_df.to_csv(out_dir / "all_predictions.csv", index=False)
Expand Down

0 comments on commit edcf0a9

Please sign in to comment.