Skip to content

Commit

Permalink
Simplify code throughout + add testing notebook for codespaces
Browse files Browse the repository at this point in the history
  • Loading branch information
kushalbakshi committed Apr 9, 2024
1 parent a720c6c commit 70db8cb
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 77 deletions.
93 changes: 60 additions & 33 deletions element_calcium_imaging/plotting/draw_rois.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
Serverside,
ServersideOutputTransform,
)
from scipy import ndimage
from skimage import draw, measure
from tifffile import TiffFile

from .utilities import *

Expand All @@ -25,6 +22,7 @@


def draw_rois(db_prefix: str):
scan = dj.create_virtual_module("scan", f"{db_prefix}scan")
imaging = dj.create_virtual_module("imaging", f"{db_prefix}imaging")
all_keys = (imaging.MotionCorrection).fetch("KEY")

Expand All @@ -34,29 +32,42 @@ def draw_rois(db_prefix: str):
app.layout = html.Div(
[
html.H2("Draw ROIs", style={"color": colors["text"]}),
html.Label("Select data key from dropdown", style={"color": colors["text"]}),
html.Label(
"Select data key from dropdown", style={"color": colors["text"]}
),
dcc.Dropdown(
id="toplevel-dropdown", options=[str(key) for key in all_keys]
),
html.Br(),
html.Div(
[
html.Button("Load Image", id="load-image-button", style={"margin-right": "20px"}),
html.Button(
"Load Image",
id="load-image-button",
style={"margin-right": "20px"},
),
dcc.RadioItems(
id='image-type-radio',
id="image-type-radio",
options=[
{'label': 'Average Image', 'value': 'average_image'},
{'label': 'Max Projection Image', 'value': 'max_projection_image'}
{"label": "Average Image", "value": "average_image"},
{
"label": "Max Projection Image",
"value": "max_projection_image",
},
],
value='average_image', # Default value
labelStyle={'display': 'inline-block', 'margin-right': '10px'}, # Inline display with some margin
style={'display': 'inline-block', 'color': colors['text']} # Inline display to keep it on the same line
value="average_image",
labelStyle={"display": "inline-block", "margin-right": "10px"},
style={"display": "inline-block", "color": colors["text"]},
),
html.Div(
[
html.Button("Submit Curated Masks", id="submit-button"),
],
style={"textAlign": "right", "flex": "1", "display": "inline-block"},
style={
"textAlign": "right",
"flex": "1",
"display": "inline-block",
},
),
],
style={
Expand All @@ -76,6 +87,7 @@ def draw_rois(db_prefix: str):
"drawclosedpath",
"drawrect",
"drawcircle",
"drawline",
"eraseshape",
],
},
Expand All @@ -84,10 +96,10 @@ def draw_rois(db_prefix: str):
],
style={
"display": "flex",
"justify-content": "center", # Centers the child horizontally
"align-items": "center", # Centers the child vertically (if you have vertical space to work with)
"justify-content": "center",
"align-items": "center",
"padding": "0.0",
"margin": "auto" # Automatically adjust the margins to center the div
"margin": "auto",
},
),
html.Pre(id="annotations"),
Expand Down Expand Up @@ -120,13 +132,22 @@ def store_key(value):
def create_figure(value, render_n_clicks, image_type):
if render_n_clicks is not None:
if image_type == "average_image":
summary_images = (imaging.MotionCorrection.Summary & yaml.safe_load(value)).fetch("average_image")
summary_images = (
imaging.MotionCorrection.Summary & yaml.safe_load(value)
).fetch("average_image")
else:
summary_images = (imaging.MotionCorrection.Summary & yaml.safe_load(value)).fetch("max_proj_image")
summary_images = (
imaging.MotionCorrection.Summary & yaml.safe_load(value)
).fetch("max_proj_image")
average_images = [image.astype("float") for image in summary_images]
roi_contours = get_contours(yaml.safe_load(value), db_prefix)
logger.info("Generating figure.")
fig = px.imshow(np.asarray(average_images), animation_frame=0, binary_string=True, labels=dict(animation_frame="plane"))
fig = px.imshow(
np.asarray(average_images),
animation_frame=0,
binary_string=True,
labels=dict(animation_frame="plane"),
)
for contour in roi_contours:
# Note: contour[:, 1] are x-coordinates, contour[:, 0] are y-coordinates
fig.add_trace(
Expand Down Expand Up @@ -171,30 +192,36 @@ def on_relayout(relayout_data):
elif any(["shapes" in key for key in relayout_data]):
return Serverside(relayout_data)


@app.callback(
Output("submit-output", "children"),
Input("submit-button", "n_clicks"),
State("store-mask", "annotation_list"),
State("store-key", "value")
State("store-key", "value"),
)
def submit_annotations(n_clicks, annotation_list, value):
print("submitting annotations")
x_mask_li = []
y_mask_li = []
if n_clicks is not None:
if "shapes" in annotation_list:
shapes = [d["type"] for d in annotation_list["shapes"]]
for shape, annotation in zip(shapes, annotation_list["shapes"]):
mask = create_mask(annotation, shape)
y_mask_li.append(mask[0])
x_mask_li.append(mask[1])

suite2p_masks = convert_masks_to_suite2p_format(
[np.array([x_mask_li, y_mask_li])], (512, 512)
)
fluo_traces = extract_signals_suite2p(yaml.safe_load(value), suite2p_masks)

if annotation_list:
if "shapes" in annotation_list:
logger.info("Creating Masks.")
shapes = [d["type"] for d in annotation_list["shapes"]]
for shape, annotation in zip(shapes, annotation_list["shapes"]):
mask = create_mask(annotation, shape)
y_mask_li.append(mask[0])
x_mask_li.append(mask[1])
print("Masks created")
insert_into_database(
scan, imaging, yaml.safe_load(value), x_mask_li, y_mask_li
)
else:
logger.warn(
"Incorrect annotation list format. This is a known bug. Please draw a line anywhere on the image and click `Submit Curated Masks`. It will be ignored in the final submission but will format the list correctly."
)
return no_update
else:
logger.warn("No annotations to submit.")
return no_update
else:
return no_update

Expand Down
112 changes: 68 additions & 44 deletions element_calcium_imaging/plotting/utilities.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,25 @@
import pathlib
import datajoint as dj
import numpy as np
from scipy import ndimage
from skimage import draw, measure
from element_interface.utils import find_full_path


logger = dj.logger


def get_imaging_root_data_dir():
"""Retrieve imaging root data directory."""
imaging_root_dirs = dj.config.get("custom", {}).get("imaging_root_data_dir", None)
if not imaging_root_dirs:
return None
elif isinstance(imaging_root_dirs, (str, pathlib.Path)):
return [imaging_root_dirs]
elif isinstance(imaging_root_dirs, list):
return imaging_root_dirs
else:
raise TypeError("`imaging_root_data_dir` must be a string, pathlib, or list")


def path_to_indices(path):
Expand Down Expand Up @@ -88,9 +106,7 @@ def create_mask(coordinates, shape_type):
xy_coordinates = np.asarray(
[item for item in coordinates.values()], dtype="int"
)
mask = np.asarray(
create_ellipse_mask(xy_coordinates, (512, 512))
).nonzero()
mask = np.asarray(create_ellipse_mask(xy_coordinates, (512, 512))).nonzero()
elif shape_type == "rect":
try:
mask = np.asarray(
Expand Down Expand Up @@ -140,57 +156,65 @@ def get_contours(image_key, prefix):
return contours


def convert_masks_to_suite2p_format(masks, frame_dims):
"""
Convert masks to the format expected by Suite2P.
Parameters:
masks (list of np.ndarray): A list where each item is an array representing a mask,
with non-zero values for the ROI and zeros elsewhere.
frame_dims (tuple): The dimensions of the imaging frame, (height, width).
Returns:
np.ndarray: A 2D array where each column represents a flattened binary mask for an ROI.
"""
# Calculate the total number of pixels in a frame
num_pixels = frame_dims[0] * frame_dims[1]

# Initialize an empty array to store the flattened binary masks
suite2p_masks = np.zeros((num_pixels, len(masks)), dtype=np.float32)

# Convert each mask
for idx, mask in enumerate(masks):
# Ensure the mask is binary (1 for ROI, 0 for background)
binary_mask = np.where(mask > 0, 1, 0).astype(np.float32)

# Flatten the binary mask and add it as a column in the suite2p_masks array
suite2p_masks[:, idx] = binary_mask.flatten()

return suite2p_masks


def load_imaging_data_for_session(key):
def load_imaging_data_for_session(scan, key):
image_files = (scan.ScanInfo.ScanFile & key).fetch("file_path")
image_files = [
find_full_path(get_imaging_root_data_dir(), image_file)
for image_file in image_files
]
acq_software = (scan.Scan & key).fetch1("acq_software")
if acq_software == "ScanImage":
imaging_data = tifffile.imread(image_files[0])
import tifffile

imaging_data = tifffile.imread(image_files[0])
elif acq_software == "NIS":
import nd2

imaging_data = nd2.imread(image_files[0])
else:
raise ValueError(f"Support for images with acquisition software: {acq_software} is not yet implemented into the widget.")
raise ValueError(
f"Support for images with acquisition software: {acq_software} is not yet implemented into the widget."
)
return imaging_data


def extract_signals_suite2p(key, masks):
from suite2p.extraction.extract import extrace_traces

F, _ = extrace_traces(load_imaging_data_for_session(key), masks, neuropil_masks=np.zeros_like(masks))


def insert_signals_into_datajoint(signals, session_key):
# Implement logic to insert the extracted signals into DataJoint
pass
def insert_into_database(scan_module, imaging_module, session_key, x_masks, y_masks):
images = load_imaging_data_for_session(scan_module, session_key)
print(f"Images shape: {images.shape}")
mask_id = (imaging_module.Segmentation.Mask & session_key).fetch(
"mask_id", order_by="DESC mask_id", limit=1
)
print(f"Mask ID: {mask_id}")
logger.info(f"Inserting {len(x_masks)} masks into the database.")
# imaging_module.Segmentation.Mask.insert(
# [
# dict(
# **session_key,
# mask=mask_id + mask_num,
# segmentation_channel=1,
# mask_npix=y_mask.shape[0],
# mask_center_x=int(sum(x_mask) / x_mask.shape[0]),
# mask_center_y=int(sum(y_mask) / y_mask.shape[0]),
# mask_center_z=0,
# mask_xpix=x_mask,
# mask_ypix=y_mask,
# mask_zpix=0,
# mask_weights=np.ones_like(y_mask),
# )
# for mask_num, (x_mask, y_mask) in enumerate(zip(x_masks, y_masks))
# ],
# allow_direct_insert=True,
# )
logger.info(f"Inserting {len(x_masks)} traces into the database.")
# imaging_module.Fluorescence.Trace.insert(
# [
# dict(
# **session_key,
# mask=mask_id + mask_num,
# fluo_channel=1,
# fluorescence=images[:, y_mask, x_mask].mean(axis=1),
# )
# for mask_num, (x_mask, y_mask) in enumerate(zip(x_masks, y_masks))
# ],
# allow_direct_insert=True,
# )
50 changes: 50 additions & 0 deletions notebooks/test_widget.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# change to the upper level folder to detect dj_local_conf.json\n",
"if os.path.basename(os.getcwd()) == \"notebooks\":\n",
" os.chdir(\"..\")\n",
"\n",
"import datajoint as dj\n",
"from element_calcium_imaging.plotting.draw_rois import draw_rois"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"draw_rois(\"neuro_\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "elements",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 70db8cb

Please sign in to comment.