Skip to content

Commit

Permalink
update single_cell.py, utils.py and classify_utils.py: show sct on gr…
Browse files Browse the repository at this point in the history
…id and refactor
  • Loading branch information
dummyindex committed Nov 12, 2023
1 parent c9088c2 commit b97669a
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 38 deletions.
143 changes: 107 additions & 36 deletions livecellx/core/single_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union
from collections import deque
import matplotlib
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -1114,42 +1115,11 @@ def get_next_largest_element(sorted_list, elem):
return get_next_largest_element(self.times, time)

@staticmethod
def show_trajectory_on_grid(
trajectory: "SingleCellTrajectory",
nr=4,
nc=4,
start_timeframe=20,
interval=5,
padding=20,
):
fig, axes = plt.subplots(nr, nc, figsize=(nc * 4, nr * 4))
if nr == 1:
axes = np.array([axes])
span_range = trajectory.get_timeframe_span()
traj_start, traj_end = span_range
if start_timeframe < traj_start:
print(
"start timeframe larger than the first timeframe of the trajectory, replace start_timeframe with the first timeframe..."
)
start_timeframe = span_range[0]
for r in range(nr):
for c in range(nc):
ax = axes[r, c]
ax.axis("off")
timeframe = start_timeframe + interval * (r * nc + c)
if timeframe > traj_end:
break
if timeframe not in trajectory.timeframe_set:
continue
sc = trajectory.get_single_cell(timeframe)
sc_img = sc.get_img_crop(padding=padding)
ax.imshow(sc_img)
# contour_coords = sc.get_img_crop_contour_coords(padding=padding)
contour_coords = sc.get_contour_coords_on_crop(padding=padding)
ax.scatter(contour_coords[:, 1], contour_coords[:, 0], s=1, c="r")
# trajectory_collection[timeframe].plot(axes[r, c])
ax.set_title(f"timeframe: {timeframe}")
fig.tight_layout(pad=0.5, h_pad=0.4, w_pad=0.4)
def show_trajectory_on_grid(**kwargs):
return show_sct_on_grid(**kwargs)

def show_on_grid(self, **kwargs):
return show_sct_on_grid(self, **kwargs)


class SingleCellTrajectoryCollection:
Expand Down Expand Up @@ -1356,3 +1326,104 @@ def create_sc_table(
continue
df[key] = [sc.meta[key] for sc in scs]
return df


def show_sct_on_grid(
trajectory: "SingleCellTrajectory",
nr=4,
nc=4,
start=0,
interval=5,
padding=20,
dims: Tuple[int, int] = None,
dims_offset: Tuple[int, int] = (0, 0),
pad_dims=True,
ax_width=4,
ax_height=4,
ax_title_fontsize=8,
cmap="viridis",
ax_contour_polygon_kwargs=dict(fill=None, edgecolor="r"),
) -> matplotlib.axes.Axes:
"""
Display a grid of single cell images with contours overlaid.
Parameters:
-----------
trajectory : SingleCellTrajectory
The trajectory object containing the single cell images.
nr : int, optional
Number of rows in the grid, by default 4.
nc : int, optional
Number of columns in the grid, by default 4.
start : int, optional
The starting timeframe, by default 0.
interval : int, optional
The interval between timeframes, by default 5.
padding : int, optional
The padding around the single cell image, by default 20.
dims : Tuple[int, int], optional
The dimensions to crop the single cell image to, by default None.
dims_offset : Tuple[int, int], optional
The offset to apply to the cropped image, by default (0, 0).
pad_dims : bool, optional
Whether to pad the cropped image to match the specified dimensions, by default True.
ax_width : int, optional
The width of each subplot, by default 4.
ax_height : int, optional
The height of each subplot, by default 4.
ax_title_fontsize : int, optional
The fontsize of the subplot titles, by default 8.
cmap : str, optional
The colormap to use for displaying the single cell images, by default "viridis".
ax_contour_polygon_kwargs : dict, optional
The keyword arguments to pass to the Polygon object for drawing the contour, by default dict(fill=None, edgecolor='r').
Returns:
--------
matplotlib.axes.Axes
The axes object containing the grid of subplots.
"""
fig, axes = plt.subplots(nr, nc, figsize=(nc * ax_width, nr * ax_height))
if nr == 1:
axes = np.array([axes])
span_range = trajectory.get_timeframe_span()
traj_start, traj_end = span_range
if start < traj_start:
print(
"start timeframe larger than the first timeframe of the trajectory, replace start_timeframe with the first timeframe..."
)
start = span_range[0]
for r in range(nr):
for c in range(nc):
ax = axes[r, c]
ax.axis("off")
timeframe = start + interval * (r * nc + c)
if timeframe > traj_end:
break
if timeframe not in trajectory.timeframe_set:
continue
sc = trajectory.get_single_cell(timeframe)
sc_img = sc.get_img_crop(padding=padding)
contour_coords = sc.get_contour_coords_on_crop(padding=padding)

if dims is not None:
sc_img = sc_img[dims_offset[0] : dims_offset[0] + dims[0], dims_offset[1] : dims_offset[1] + dims[1]]
contour_coords[:, 0] -= dims_offset[0]
contour_coords[:, 1] -= dims_offset[1]

if pad_dims:
_pad_pixels = [max(0, dims[i] - sc_img.shape[i]) for i in range(len(dims))]
sc_img = np.pad(sc_img, _pad_pixels, mode="constant", constant_values=0)
contour_coords[:, 0] += _pad_pixels[0]
contour_coords[:, 1] += _pad_pixels[1]
ax.imshow(sc_img, cmap=cmap)
# draw a polygon based on contour coordinates
from matplotlib.patches import Polygon

polygon = Polygon(
np.array([contour_coords[:, 1], contour_coords[:, 0]]).transpose(), **ax_contour_polygon_kwargs
)
ax.add_patch(polygon)
ax.set_title(f"time: {timeframe}", fontsize=ax_title_fontsize)
fig.tight_layout(pad=0.5, h_pad=0.4, w_pad=0.4)
return axes
4 changes: 2 additions & 2 deletions livecellx/track/classify_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def load_all_json_dirs(
return all_class2samples, all_class2sample_extra_info


def gen_one_sc_samples_by_window(sctc: SingleCellTrajectoryCollection, window_size=7, step_size=1):
def gen_tid2samples_by_window(sctc: SingleCellTrajectoryCollection, window_size=7, step_size=1):
tid2samples = {}
tid2start_end_times = {}
for tid, sct in sctc:
Expand Down Expand Up @@ -86,7 +86,7 @@ def gen_inference_sctc_sample_videos(
sc_samples = []
samples_info_list = []

tid2samples, tid2start_end_times = gen_one_sc_samples_by_window(sctc, window_size=window_size, step_size=step_size)
tid2samples, tid2start_end_times = gen_tid2samples_by_window(sctc, window_size=window_size, step_size=step_size)
for tid, samples in tid2samples.items():
start_end_times = tid2start_end_times[tid]
for i, sample in enumerate(samples):
Expand Down

0 comments on commit b97669a

Please sign in to comment.