Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
dummyindex committed Oct 30, 2023
2 parents 3b25fb6 + 5ab130e commit b10b60b
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 33 deletions.
9 changes: 5 additions & 4 deletions livecellx/track/data_prep_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ def make_one_cell_per_timeframe_for_class2samples(
) -> Dict:
class2samples = class2samples.copy()
if class2sample_extra_info is not None:
class2sample_extra_info = class2sample_extra_info.copy()
import copy

class2sample_extra_info = copy.deepcopy(class2sample_extra_info)
for cls in tar_classes:
tmp_samples = []
tmp_sample_extra_info = []
Expand All @@ -96,9 +98,8 @@ def make_one_cell_per_timeframe_for_class2samples(
sct_samples = make_one_cell_per_timeframe_samples(sample)
tmp_samples.extend(sct_samples)
if class2sample_extra_info is not None:
tmp_sample_extra_info.extend(
[dict(class2sample_extra_info[cls][sample_idx]) for _ in range(len(sct_samples))]
)
sample_extra_info = [dict(class2sample_extra_info[cls][sample_idx]) for _ in range(len(sct_samples))]
tmp_sample_extra_info.extend(sample_extra_info)

# check that the length of sample is the same as the length of tmp_samples[-1]
sample_times = set([sc.timeframe for sc in sample])
Expand Down
170 changes: 141 additions & 29 deletions notebooks/classify_mitosis_data_prep.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,15 @@
" sample_extra_info[\"first_sc_id\"] = sample[0].id\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"[len(all_class2sample_extra_info[cls]) for cls in all_class2sample_extra_info.keys()]"
]
},
{
"attachments": {},
"cell_type": "markdown",
Expand Down Expand Up @@ -374,9 +383,15 @@
"# MAKE_SINGLE_CELL_TRAJ_SAMPLES = False\n",
"# DROP_MITOSIS_DIV = False\n",
"\n",
"ver = \"test-all\"\n",
"# ver = \"test-all\"\n",
"\n",
"# ver = \"13-inclusive-with-mitosis-type\"\n",
"ver = \"13-inclusive-corrected\"\n",
"MAKE_SINGLE_CELL_TRAJ_SAMPLES = False\n",
"DROP_MITOSIS_DIV = False"
"DROP_MITOSIS_DIV = False\n",
"INCLUDE_ALL = True\n",
"\n",
"DEBUG = True"
]
},
{
Expand All @@ -385,7 +400,6 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"data_dir = Path(f'notebook_results/mmaction_train_data_v{ver}')\n",
"class_labels = ['mitosis', 'apoptosis', 'normal']\n",
"class_label = \"mitosis\"\n",
Expand All @@ -394,6 +408,15 @@
"\n",
"# Use 1 instead of 0 to avoid an error in the decord Python package (used by mmdetection)\n",
"padding_pixels = [1, 20, 40, 50, 100, 200, 400]\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"\n",
Expand All @@ -408,8 +431,6 @@
"test_class2sample_extra_info = {}\n",
"\n",
"# randomize train and test data\n",
"\n",
"\n",
"for key in all_class2samples.keys():\n",
" randomized_indices = np.random.permutation(len(all_class2samples[key])).astype(int)\n",
" split_idx = int(len(all_class2samples[key]) * _split)\n",
Expand All @@ -418,8 +439,13 @@
" train_class2samples[key] = np.array(all_class2samples[key], dtype=object)[_train_indices]\n",
" test_class2samples[key] = np.array(all_class2samples[key], dtype=object)[_test_indices]\n",
"\n",
" train_class2samples[key] = list(train_class2samples[key])\n",
" test_class2samples[key] = list(test_class2samples[key])\n",
"\n",
" train_class2sample_extra_info[key] = np.array(all_class2sample_extra_info[key], dtype=object)[_train_indices]\n",
" test_class2sample_extra_info[key] = np.array(all_class2sample_extra_info[key], dtype=object)[_test_indices]\n",
" train_class2sample_extra_info[key] = list(train_class2sample_extra_info[key])\n",
" test_class2sample_extra_info[key] = list(test_class2sample_extra_info[key])\n",
"\n"
]
},
Expand Down Expand Up @@ -579,14 +605,91 @@
" test_class2samples = drop_multiple_cell_frames_in_samples(test_class2samples)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Include the original, drop-div, and st (one cell per timeframe) samples"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import copy\n",
"\n",
"def modify_class2sample_extra_info(class2sample_extra_info: Dict[str, List[Dict]], mitosis_traj_type, copy_info=True):\n",
" for cls in class2sample_extra_info:\n",
" for _info in class2sample_extra_info[cls]:\n",
" if copy_info:\n",
" _info = copy_info.deepcopy(_info)\n",
" _info[\"mitosis_traj_type\"] = mitosis_traj_type\n",
"\n",
"modify_class2sample_extra_info(train_class2sample_extra_info, \"full\")\n",
"modify_class2sample_extra_info(test_class2sample_extra_info, \"full\")\n",
"\n",
"if INCLUDE_ALL:\n",
"\n",
"\n",
" # Extra caution: the drop_div extra-info lists below are shallow copies, so their dict entries are shared by reference with the originals.\n",
" train_drop_div_class2samples = drop_multiple_cell_frames_in_samples(train_class2samples)\n",
" train_drop_div_class2extra_info = {cls: list(train_class2sample_extra_info[cls]) for cls in train_class2sample_extra_info.keys()}\n",
" modify_class2sample_extra_info(train_drop_div_class2extra_info, \"drop_div\")\n",
"\n",
" test_drop_div_class2samples = drop_multiple_cell_frames_in_samples(test_class2samples)\n",
" test_drop_div_class2extra_info = {cls: list(test_class2sample_extra_info[cls]) for cls in test_class2sample_extra_info.keys()}\n",
" modify_class2sample_extra_info(test_drop_div_class2extra_info, \"drop_div\")\n",
"\n",
" train_st_class2samples, train_st_class2sample_extra_info = make_one_cell_per_timeframe_for_class2samples(train_class2samples, train_class2sample_extra_info)\n",
" test_st_class2samples, test_st_class2sample_extra_info = make_one_cell_per_timeframe_for_class2samples(test_class2samples, test_class2sample_extra_info)\n",
" modify_class2sample_extra_info(train_st_class2sample_extra_info, \"st\")\n",
" modify_class2sample_extra_info(test_st_class2sample_extra_info, \"st\")\n",
"\n",
" # check st info and sample length are the same\n",
" assert len(train_st_class2samples[\"mitosis\"]) == len(train_st_class2sample_extra_info[\"mitosis\"])\n",
" assert len(train_st_class2samples[\"normal\"]) == len(train_st_class2sample_extra_info[\"normal\"])\n",
"\n",
" # check drop-div info and sample length are the same\n",
" assert len(train_drop_div_class2samples[\"mitosis\"]) == len(train_drop_div_class2extra_info[\"mitosis\"])\n",
" assert len(train_drop_div_class2samples[\"normal\"]) == len(train_drop_div_class2extra_info[\"normal\"])\n",
"\n",
" for cls in train_drop_div_class2samples:\n",
" assert len(train_drop_div_class2extra_info[cls]) == len(train_drop_div_class2samples[cls])\n",
" \n",
" for cls in train_class2samples:\n",
" print(\"cls:\", cls, \"len(train_class2samples[cls]):\", len(train_class2samples[cls]), \"len(train_class2sample_extra_info[cls]):\", len(train_class2sample_extra_info[cls]))\n",
" assert len(train_class2samples[cls]) == len(train_class2sample_extra_info[cls]), \\\n",
" f\"length of train_class2samples[{cls}] != length of train_class2sample_extra_info[{cls}], {len(train_class2samples[cls])} != {len(train_class2sample_extra_info[cls])}\"\n",
" \n",
" assert len(train_st_class2sample_extra_info[cls]) == len(train_st_class2samples[cls]), f\"flag1: {cls}\"\n",
" train_class2samples[cls].extend(train_st_class2samples[cls])\n",
" train_class2sample_extra_info[cls].extend(train_st_class2sample_extra_info[cls])\n",
" assert len(train_class2samples[cls]) == len(train_class2sample_extra_info[cls]), \\\n",
" f\"length of train_class2samples[{cls}] != length of train_class2sample_extra_info[{cls}], {len(train_class2samples[cls])} != {len(train_class2sample_extra_info[cls])}\"\n",
" \n",
" assert len(train_drop_div_class2extra_info[cls]) == len(train_drop_div_class2samples[cls]), f\"flag2: {cls}, {len(train_drop_div_class2extra_info[cls])} != {len(train_drop_div_class2samples[cls])}\"\n",
" train_class2samples[cls].extend(train_drop_div_class2samples[cls])\n",
" train_class2sample_extra_info[cls].extend(train_drop_div_class2extra_info[cls])\n",
" # print(\"cls:\", cls, \"len(train_class2samples[cls]):\", len(train_class2samples[cls]), \"len(train_class2sample_extra_info[cls]):\", len(train_class2sample_extra_info[cls]))\n",
"\n",
" test_class2samples[cls].extend(test_st_class2samples[cls])\n",
" test_class2sample_extra_info[cls].extend(test_st_class2sample_extra_info[cls])\n",
"\n",
" test_class2samples[cls].extend(test_drop_div_class2samples[cls])\n",
" test_class2sample_extra_info[cls].extend(test_drop_div_class2extra_info[cls])\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for key, val in train_class2samples.items():\n",
" assert len(val) == len(train_class2sample_extra_info[key]), f\"key: {key}, len(val): {len(val)}, len(train_class2sample_extra_info[key]): {len(train_class2sample_extra_info[key])}\""
" assert len(train_class2samples[key]) == len(train_class2sample_extra_info[key]), f\"key: {key}, len(val): {len(val)}, len(train_class2sample_extra_info[key]): {len(train_class2sample_extra_info[key])}\""
]
},
{
Expand All @@ -610,12 +713,13 @@
"importlib.reload(livecellx.core.sc_video_utils)\n",
"\n",
"# # for debug\n",
"# test_sample_num = 3\n",
"# padding_pixels = [1, 20]\n",
"# train_class2samples = {key: value[:test_sample_num] for key, value in all_class2samples.items()}\n",
"# test_class2samples = {key: value[:test_sample_num] for key, value in all_class2samples.items()}\n",
"# train_class2sample_extra_info = {key: value[:test_sample_num] for key, value in all_class2sample_extra_info.items()}\n",
"# test_class2sample_extra_info = {key: value[:test_sample_num] for key, value in all_class2sample_extra_info.items()}\n",
"if DEBUG:\n",
" test_sample_num = 3\n",
" padding_pixels = [1, 20]\n",
" train_class2samples = {key: value[:test_sample_num] for key, value in all_class2samples.items()}\n",
" test_class2samples = {key: value[:test_sample_num] for key, value in all_class2samples.items()}\n",
" train_class2sample_extra_info = {key: value[:test_sample_num] for key, value in all_class2sample_extra_info.items()}\n",
" test_class2sample_extra_info = {key: value[:test_sample_num] for key, value in all_class2sample_extra_info.items()}\n",
"\n",
"# padding_pixels = [20]\n",
"\n",
Expand Down Expand Up @@ -668,26 +772,43 @@
" index=False,\n",
" header=True,\n",
" sep=\" \",\n",
")\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_sample_info_df = pd.read_csv(data_dir / f\"train_data.txt\", sep=\" \")\n",
"test_sample_info_df = pd.read_csv(data_dir / f\"test_data.txt\", sep=\" \")\n",
"\n",
"mmaction_df_paths = []\n",
"for selected_frame_type in frame_types:\n",
"for selected_frame_type in frame_types + [\"all\"]:\n",
" train_df_path = data_dir / f\"mmaction_train_data_{selected_frame_type}.txt\"\n",
" train_selected_frame_type_df = train_sample_info_df[train_sample_info_df[\"frame_type\"] == selected_frame_type]\n",
" if selected_frame_type == \"all\":\n",
" train_selected_frame_type_df = train_sample_info_df\n",
" else:\n",
" train_selected_frame_type_df = train_sample_info_df[train_sample_info_df[\"frame_type\"] == selected_frame_type]\n",
" train_selected_frame_type_df = train_selected_frame_type_df.reset_index(drop=True)\n",
" train_selected_frame_type_df = train_selected_frame_type_df[[\"path\", \"label_index\"]]\n",
" train_selected_frame_type_df.to_csv(train_df_path, index=False, header=False, sep=\" \")\n",
"\n",
" test_df_path = data_dir / f\"mmaction_test_data_{selected_frame_type}.txt\"\n",
" test_selected_frame_type_df = test_sample_info_df[test_sample_info_df[\"frame_type\"] == selected_frame_type]\n",
"\n",
" if selected_frame_type == \"all\":\n",
" test_selected_frame_type_df = test_sample_info_df\n",
" else:\n",
" test_selected_frame_type_df = test_sample_info_df[test_sample_info_df[\"frame_type\"] == selected_frame_type]\n",
" test_selected_frame_type_df = test_selected_frame_type_df[[\"path\", \"label_index\"]]\n",
" test_selected_frame_type_df = test_selected_frame_type_df.reset_index(drop=True)\n",
" test_selected_frame_type_df.to_csv(test_df_path, index=False, header=False, sep=\" \")\n",
"\n",
" mmaction_df_paths.append(train_df_path)\n",
" mmaction_df_paths.append(test_df_path)\n",
"\n",
"\n",
"mmaction_df_paths\n",
"# # The following code generates v1-v7 test data. The issue is that some of the test data shows up in the train data, through different padding values.\n",
"# data_df_path = data_dir/'all_data.txt'\n",
"# sample_df = gen_samples_df(data_dir, class_labels, padding_pixels, frame_types, fps)\n",
Expand All @@ -708,15 +829,6 @@
"# test_df.to_csv(test_df_path, index=False, header=False, sep=' ')\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_class2samples"
]
},
{
"attachments": {},
"cell_type": "markdown",
Expand All @@ -731,8 +843,8 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"video_paths = list(Path(data_dir/'videos').glob('*.mp4'))"
"video_paths = list(Path(data_dir/'videos').glob('*.mp4'))\n",
"print(\"len(video_paths):\", len(video_paths))"
]
},
{
Expand Down Expand Up @@ -761,7 +873,7 @@
"source": [
"import decord\n",
"invalid_decord_paths = []\n",
"for path in video_paths:\n",
"for path in tqdm(video_paths):\n",
"# for path in [\"./notebook_results/train_normal_6_raw_padding-0.mp4\"]:\n",
"# for path in [\"./test_video_output.mp4\"]:\n",
" reader = decord.VideoReader(str(path))\n",
Expand Down

0 comments on commit b10b60b

Please sign in to comment.