From edcf0a9d8603aed9ae9568c44f408c67d4bd2c6e Mon Sep 17 00:00:00 2001 From: dummyindex Date: Thu, 2 Nov 2023 17:57:36 -0400 Subject: [PATCH] update sc_video_utils.py, classify_utils.py and classify_mmdetection_mitosis_eval.py --- livecellx/core/sc_video_utils.py | 7 +- livecellx/track/classify_utils.py | 40 ++++++++ .../classify_mmdetection_mitosis_eval.py | 99 +++++++++++++++++-- 3 files changed, 135 insertions(+), 11 deletions(-) diff --git a/livecellx/core/sc_video_utils.py b/livecellx/core/sc_video_utils.py index 8bf83ea..3c9b078 100644 --- a/livecellx/core/sc_video_utils.py +++ b/livecellx/core/sc_video_utils.py @@ -239,15 +239,16 @@ def video_frames_and_masks_from_sample( return video_frames, video_frame_masks -def combine_video_frames_and_masks(video_frames, video_frame_masks, edt_transform=True): +def combine_video_frames_and_masks(video_frames, video_frame_masks, edt_transform=True, is_gray=False): """returns a list of combined video frames and masks, each item contains a 3-channel image with first channel as frame and second channel as mask""" if edt_transform: video_frame_masks = [label_mask_to_edt_mask(x) for x in video_frame_masks] res_frames = [] for frame, mask in zip(video_frames, video_frame_masks): - frame = rgb_img_to_gray(frame) - mask = rgb_img_to_gray(mask) + if not is_gray: + frame = rgb_img_to_gray(frame) + mask = rgb_img_to_gray(mask) res_frame = np.array([frame, mask, mask]).transpose(1, 2, 0) res_frames.append(res_frame) return res_frames diff --git a/livecellx/track/classify_utils.py b/livecellx/track/classify_utils.py index 8807a6d..5da5632 100644 --- a/livecellx/track/classify_utils.py +++ b/livecellx/track/classify_utils.py @@ -102,3 +102,43 @@ def gen_inference_sctc_sample_videos( fps=fps, ) return saved_sample_info_df + + +def save_data_input(data_input, file_path): + from livecellx.core.sc_video_utils import gen_mp4_from_frames + + imgs = data_input[1][2].detach().cpu().numpy() # 8 x 224 x 224 + masks = data_input[1][0].detach().cpu().numpy() # 8 x 224 x 224 + imgs = list(imgs) + masks = list(masks) + imgs = [normalize_img_to_uint8(img) for img in imgs] + masks = [normalize_img_to_uint8(mask) for mask in masks] + + # already edt transformed + frames = combine_video_frames_and_masks(imgs, masks, is_gray=True, edt_transform=False) + gen_mp4_from_frames(frames, file_path) + + +def is_decord_invalid_video(path): + """More information: https://github.com/dmlc/decord/issues/150""" + import decord + + reader = decord.VideoReader(str(path)) + reader.seek(0) + imgs = list() + frame_inds = range(0, len(reader)) + for idx in frame_inds: + reader.seek(idx) + frame = reader.next() + imgs.append(frame.asnumpy()) + frame = frame.asnumpy() + + num_channels = frame.shape[-1] + if num_channels != 3: + print("invalid video for decord (https://github.com/dmlc/decord/issues/150): ", path) + return True + # fig, axes = plt.subplots(1, num_channels, figsize=(20, 10)) + # for i in range(num_channels): + # axes[i].imshow(frame[:, :, i]) + # plt.show() + return False diff --git a/notebooks/scripts/mmdetection_classify/classify_mmdetection_mitosis_eval.py b/notebooks/scripts/mmdetection_classify/classify_mmdetection_mitosis_eval.py index 5529823..9f7aa1e 100644 --- a/notebooks/scripts/mmdetection_classify/classify_mmdetection_mitosis_eval.py +++ b/notebooks/scripts/mmdetection_classify/classify_mmdetection_mitosis_eval.py @@ -17,6 +17,9 @@ from livecellx.core.datasets import LiveCellImageDataset, SingleImageDataset from skimage import measure from livecellx.core import SingleCellTrajectory, SingleCellStatic +from livecellx.core.sc_video_utils import gen_mp4_from_frames, combine_video_frames_and_masks +from livecellx.preprocess.utils import normalize_img_to_uint8 +from livecellx.core.io_utils import save_png # import detectron2 # from detectron2.utils.logger import setup_logger @@ -78,10 +81,20 @@ ) parser.add_argument("--device", type=str, help="Device to use", default="cuda:0") parser.add_argument("--add-random-crop", action="store_true", help="Add random crop to the pipeline") +parser.add_argument("--is-tsn", action="store_true", help="if it is tsn model") +parser.add_argument( + "--raw-video-treat-as-negative", + action="store_true", + help="In combined ver, raw videos input labels should be all WRONG", + default=False, +) args = parser.parse_args() +out_wrong_video_dir = Path(args.out_dir) / "wrong_videos" +out_wrong_video_dir.mkdir(parents=True, exist_ok=True) + print( "#" * 40, "args", @@ -116,7 +129,38 @@ print("test data frame:", test_data_df.columns[:2]) -if args.add_random_crop: +if args.add_random_crop and args.is_tsn: + model.cfg.test_pipeline = [ + dict(io_backend="disk", type="DecordInit"), + dict(clip_len=3, frame_interval=1, num_clips=3, test_mode=True, type="SampleFrames"), + dict(type="DecordDecode"), + dict( + scale_range=( + 224, + 300, + ), + type="RandomRescale", + ), + dict(size=224, type="RandomCrop"), + dict( + scale=( + -1, + 224, + ), + type="Resize", + ), + dict( + scale=( + -1, + 256, + ), + type="Resize", + ), + dict(crop_size=224, type="TenCrop"), + dict(input_format="NCHW", type="FormatShape"), + dict(type="PackActionInputs"), + ] +elif args.add_random_crop: model.cfg.test_pipeline = [ dict(io_backend="disk", type="DecordInit"), dict(clip_len=8, frame_interval=1, num_clips=1, test_mode=True, type="SampleFrames"), @@ -182,6 +226,8 @@ for row_ in tqdm(all_rows): idx, row_series = row_ video_path = str(video_dir / row_series["path"]) + input_frame_type = row_series["frame_type"] + padding_pixels = row_series["padding_pixels"] try: model.zero_grad() results, data, grad = livecellx.track.timesformer_inference.inference_recognizer( @@ -194,21 +240,52 @@ traceback.print_exc() print("video_path:", video_path) + continue # predicted_label = results.pred_labels.item.cpu().numpy()[0] + predicted_label = results.pred_label.item() - test_gt_label = row_series["label"] - if test_gt_label not in gt2total: - gt2total[test_gt_label] = 0 - gt2correct[test_gt_label] = 0 + if args.raw_video_treat_as_negative: + test_gt_label = 2 # no focus! + else: + test_gt_label = row_series["label"] + if test_gt_label not in gt2total: + gt2total[test_gt_label] = 0 + gt2correct[test_gt_label] = 0 gt2total[test_gt_label] += 1 row_series = row_series.copy() row_series["predicted_label"] = predicted_label row_series["true_label"] = test_gt_label row_series["correct"] = predicted_label == test_gt_label all_predictions.append(row_series) + if predicted_label != test_gt_label: print("wrong prediction:", video_path, "predicted_label:", predicted_label, "gt_label:", test_gt_label) wrong_predictions.append(row_series) + + data_input = data["inputs"][0] # 3 x 3 x 8 x 224 x 224 + if not args.is_tsn: + # timeSformer + imgs = data_input[1][2].detach().cpu().numpy() # 8 x 224 x 224 + masks = data_input[1][0].detach().cpu().numpy() # 8 x 224 x 224 + imgs = list(imgs) + masks = list(masks) + imgs = [normalize_img_to_uint8(img) for img in imgs] + masks = [normalize_img_to_uint8(mask) for mask in masks] + + # already edt transformed, so set to false + frames = combine_video_frames_and_masks(imgs, masks, is_gray=True, edt_transform=False) + gen_mp4_from_frames( + frames, out_wrong_video_dir / f"wrong_{idx}-{input_frame_type}-padding_{padding_pixels}.mp4", fps=3 + ) + else: + # tsn + imgs = data_input.detach().cpu().numpy() # 90 x 3 x h x w + imgs = list(imgs) + imgs = [normalize_img_to_uint8(img) for img in imgs] + tmp_out_sample_dir = out_wrong_video_dir / f"wrong_{idx}-{input_frame_type}-padding_{padding_pixels}" + tmp_out_sample_dir.mkdir(parents=True, exist_ok=True) + for i, img in enumerate(imgs): + save_png(tmp_out_sample_dir / f"sample_dim0-{i}.png", img.swapaxes(0, 2), mode="RGB") else: gt2correct[test_gt_label] += 1 @@ -242,7 +319,7 @@ def report_classification_metrics(true_labels, predicted_labels): f1 = f1_score(true_labels, predicted_labels, average="weighted") # generate a classification report - report = classification_report(true_labels, predicted_labels) + report = classification_report(true_labels, predicted_labels, digits=4) # print the metrics and classification report print(f"Accuracy: {accuracy:.2f}") @@ -260,8 +337,14 @@ def report_classification_metrics(true_labels, predicted_labels): ) # %% -indexer = all_predictions_df["frame_type"] == "combined" -report_classification_metrics(all_predictions_df[indexer]["true_label"], all_predictions_df[indexer]["predicted_label"]) +frame_types = all_predictions_df["frame_type"].unique() + +for frame_type in frame_types: + indexer = all_predictions_df["frame_type"] == frame_type + print("#" * 40, "frame_type:", frame_type, "#" * 40) + report_classification_metrics( + all_predictions_df[indexer]["true_label"], all_predictions_df[indexer]["predicted_label"] + ) # %% all_predictions_df.to_csv(out_dir / "all_predictions.csv", index=False)