From a720c6c54605589b54366f2686ed0b3cbfec6654 Mon Sep 17 00:00:00 2001 From: kushalbakshi Date: Wed, 27 Mar 2024 14:34:24 -0500 Subject: [PATCH 1/9] Add ROI drawing widget for manual curation --- .gitignore | 3 + element_calcium_imaging/plotting/draw_rois.py | 201 ++++++++++++++++++ element_calcium_imaging/plotting/utilities.py | 196 +++++++++++++++++ 3 files changed, 400 insertions(+) create mode 100644 element_calcium_imaging/plotting/draw_rois.py create mode 100644 element_calcium_imaging/plotting/utilities.py diff --git a/.gitignore b/.gitignore index 0f635c3c..b94ad25a 100644 --- a/.gitignore +++ b/.gitignore @@ -124,3 +124,6 @@ example_data # vscode *.code-workspace + +# dash widget +file_system_backend \ No newline at end of file diff --git a/element_calcium_imaging/plotting/draw_rois.py b/element_calcium_imaging/plotting/draw_rois.py new file mode 100644 index 00000000..8b63a981 --- /dev/null +++ b/element_calcium_imaging/plotting/draw_rois.py @@ -0,0 +1,201 @@ +import yaml +import datajoint as dj +import numpy as np +import plotly.express as px +import plotly.graph_objects as go +from dash import no_update +from dash_extensions.enrich import ( + DashProxy, + Input, + Output, + State, + html, + dcc, + Serverside, + ServersideOutputTransform, +) +from scipy import ndimage +from skimage import draw, measure +from tifffile import TiffFile + +from .utilities import * + + +logger = dj.logger + + +def draw_rois(db_prefix: str): + imaging = dj.create_virtual_module("imaging", f"{db_prefix}imaging") + all_keys = (imaging.MotionCorrection).fetch("KEY") + + colors = {"background": "#111111", "text": "#00a0df"} + + app = DashProxy(transforms=[ServersideOutputTransform()]) + app.layout = html.Div( + [ + html.H2("Draw ROIs", style={"color": colors["text"]}), + html.Label("Select data key from dropdown", style={"color": colors["text"]}), + dcc.Dropdown( + id="toplevel-dropdown", options=[str(key) for key in all_keys] + ), + html.Br(), + html.Div( + [ + html.Button("Load Image", id="load-image-button", style={"margin-right": "20px"}), + dcc.RadioItems( + id='image-type-radio', + options=[ + {'label': 'Average Image', 'value': 'average_image'}, + {'label': 'Max Projection Image', 'value': 'max_projection_image'} + ], + value='average_image', # Default value + labelStyle={'display': 'inline-block', 'margin-right': '10px'}, # Inline display with some margin + style={'display': 'inline-block', 'color': colors['text']} # Inline display to keep it on the same line + ), + html.Div( + [ + html.Button("Submit Curated Masks", id="submit-button"), + ], + style={"textAlign": "right", "flex": "1", "display": "inline-block"}, + ), + ], + style={ + "display": "flex", + "justify-content": "flex-start", + "align-items": "center", + }, + ), + html.Br(), + html.Br(), + html.Div( + [ + dcc.Graph( + id="avg-image", + config={ + "modeBarButtonsToAdd": [ + "drawclosedpath", + "drawrect", + "drawcircle", + "eraseshape", + ], + }, + style={"width": "100%", "height": "100%"}, + ) + ], + style={ + "display": "flex", + "justify-content": "center", # Centers the child horizontally + "align-items": "center", # Centers the child vertically (if you have vertical space to work with) + "padding": "0.0", + "margin": "auto" # Automatically adjust the margins to center the div + }, + ), + html.Pre(id="annotations"), + html.Div(id="button-output"), + dcc.Store(id="store-key"), + dcc.Store(id="store-mask"), + dcc.Store(id="store-movie"), + html.Div(id="submit-output"), + ] + ) + + @app.callback( + Output("store-key", "value"), + Input("toplevel-dropdown", "value"), + ) + def store_key(value): + if value is not None: + return Serverside(value) + else: + return no_update + + @app.callback( + Output("avg-image", "figure"), + Output("store-movie", "average_images"), + State("store-key", "value"), + Input("load-image-button", "n_clicks"), + Input("image-type-radio", "value"), + prevent_initial_call=True, + ) + def create_figure(value, render_n_clicks, image_type): + if render_n_clicks is not None: + if image_type == "average_image": + summary_images = (imaging.MotionCorrection.Summary & yaml.safe_load(value)).fetch("average_image") + else: + summary_images = (imaging.MotionCorrection.Summary & yaml.safe_load(value)).fetch("max_proj_image") + average_images = [image.astype("float") for image in summary_images] + roi_contours = get_contours(yaml.safe_load(value), db_prefix) + logger.info("Generating figure.") + fig = px.imshow(np.asarray(average_images), animation_frame=0, binary_string=True, labels=dict(animation_frame="plane")) + for contour in roi_contours: + # Note: contour[:, 1] are x-coordinates, contour[:, 0] are y-coordinates + fig.add_trace( + go.Scatter( + x=contour[:, 1], # Plotly uses x, y order for coordinates + y=contour[:, 0], + mode="lines", # Display as lines (not markers) + line=dict(color="white", width=0.5), # Set line color and width + showlegend=False, # Do not show legend for each contour + ) + ) + fig.update_layout( + dragmode="drawrect", + autosize=True, + height=550, + newshape=dict(opacity=0.6, fillcolor="#00a0df"), + plot_bgcolor=colors["background"], + paper_bgcolor=colors["background"], + font_color=colors["text"], + ) + fig.update_annotations(bgcolor="#00a0df") + else: + return no_update + return fig, Serverside(average_images) + + @app.callback( + Output("store-mask", "annotation_list"), + Input("avg-image", "relayoutData"), + prevent_initial_call=True, + ) + def on_relayout(relayout_data): + if not relayout_data: + return no_update + else: + if "shapes" in relayout_data: + global shape_type + try: + shape_type = relayout_data["shapes"][-1]["type"] + return Serverside(relayout_data) + except IndexError: + return no_update + elif any(["shapes" in key for key in relayout_data]): + return Serverside(relayout_data) + + + @app.callback( + Output("submit-output", "children"), + Input("submit-button", "n_clicks"), + State("store-mask", "annotation_list"), + State("store-key", "value") + ) + def submit_annotations(n_clicks, annotation_list, value): + print("submitting annotations") + x_mask_li = [] + y_mask_li = [] + if n_clicks is not None: + if "shapes" in annotation_list: + shapes = [d["type"] for d in annotation_list["shapes"]] + for shape, annotation in zip(shapes, annotation_list["shapes"]): + mask = create_mask(annotation, shape) + y_mask_li.append(mask[0]) + x_mask_li.append(mask[1]) + + suite2p_masks = convert_masks_to_suite2p_format( + [np.array([x_mask_li, y_mask_li])], (512, 512) + ) + fluo_traces = extract_signals_suite2p(yaml.safe_load(value), suite2p_masks) + + else: + return no_update + + app.run_server(port=8000) diff --git a/element_calcium_imaging/plotting/utilities.py b/element_calcium_imaging/plotting/utilities.py new file mode 100644 index 00000000..0308a4f8 --- /dev/null +++ b/element_calcium_imaging/plotting/utilities.py @@ -0,0 +1,196 @@ +import datajoint as dj +import numpy as np +from scipy import ndimage +from skimage import draw, measure + + +def path_to_indices(path): + """From SVG path to numpy array of coordinates, each row being a (row, col) point""" + indices_str = [ + el.replace("M", "").replace("Z", "").split(",") for el in path.split("L") + ] + return np.rint(np.array(indices_str, dtype=float)).astype(int) + + +def path_to_mask(path, shape): + """From SVG path to a boolean array where all pixels enclosed by the path + are True, and the other pixels are False. + """ + cols, rows = path_to_indices(path).T + rr, cc = draw.polygon(rows, cols) + mask = np.zeros(shape, dtype=bool) + mask[rr, cc] = True + mask = ndimage.binary_fill_holes(mask) + return mask + + +def create_ellipse_mask(vertices, image_shape): + """ + Create a mask for an ellipse given its vertices. + + :param vertices: Tuple of (x0, y0, x1, y1) representing the bounding box of the ellipse. + :param image_shape: Shape of the image (height, width) to create a mask for. + :return: Binary mask with the ellipse. + """ + x0, x1, y0, y1 = vertices + center = ((x0 + x1) / 2, (y0 + y1) / 2) + axis_lengths = (abs(x1 - x0) / 2, abs(y1 - y0) / 2) + + rr, cc = draw.ellipse( + center[1], center[0], axis_lengths[1], axis_lengths[0], shape=image_shape + ) + mask = np.zeros(image_shape, dtype=np.bool_) + mask[rr, cc] = True + mask = ndimage.binary_fill_holes(mask) + + return mask + + +def create_rectangle_mask(vertices, image_shape): + """ + Create a mask for a rectangle given its vertices. + + :param vertices: Tuple of (x0, y0, x1, y1) representing the top-left and bottom-right corners of the rectangle. + :param image_shape: Shape of the image (height, width) to create a mask for. + :return: Binary mask with the rectangle. + """ + x0, x1, y0, y1 = vertices + rr, cc = draw.rectangle(start=(y0, x0), end=(y1, x1), shape=image_shape) + mask = np.zeros(image_shape, dtype=np.bool_) + mask[rr, cc] = True + mask = ndimage.binary_fill_holes(mask) + + return mask + + +def create_mask(coordinates, shape_type): + if shape_type == "path": + try: + mask = np.asarray(path_to_mask(coordinates["path"], (512, 512))).nonzero() + except KeyError: + for key, info in coordinates.items(): + mask = np.asarray(path_to_mask(info, (512, 512))).nonzero() + + elif shape_type == "circle": + try: + mask = np.asarray( + create_ellipse_mask( + [ + int(coordinates["x0"]), + int(coordinates["x1"]), + int(coordinates["y0"]), + int(coordinates["y1"]), + ], + (512, 512), + ) + ).nonzero() + except KeyError: + xy_coordinates = np.asarray( + [item for item in coordinates.values()], dtype="int" + ) + mask = np.asarray( + create_ellipse_mask(xy_coordinates, (512, 512)) + ).nonzero() + elif shape_type == "rect": + try: + mask = np.asarray( + create_rectangle_mask( + [ + int(coordinates["x0"]), + int(coordinates["x1"]), + int(coordinates["y0"]), + int(coordinates["y1"]), + ], + (512, 512), + ) + ).nonzero() + except KeyError: + xy_coordinates = np.asarray( + [item for item in coordinates.values()], dtype="int" + ) + mask = np.asarray( + create_rectangle_mask(xy_coordinates, (512, 512)) + ).nonzero() + elif shape_type == "line": + try: + mask = np.array( + ( + int(coordinates["x0"]), + int(coordinates["x1"]), + int(coordinates["y0"]), + int(coordinates["y1"]), + ) + ) + except KeyError: + mask = np.asarray([item for item in coordinates.values()], dtype="int") + return mask + + +def get_contours(image_key, prefix): + scan = dj.create_virtual_module("scan", f"{prefix}scan") + imaging = dj.create_virtual_module("imaging", f"{prefix}imaging") + yshape, xshape = (scan.ScanInfo.Field & image_key).fetch1("px_height", "px_width") + mask_xpix, mask_ypix = (imaging.Segmentation.Mask & image_key).fetch( + "mask_xpix", "mask_ypix" + ) + mask_image = np.zeros((yshape, xshape), dtype=bool) + for xpix, ypix in zip(mask_xpix, mask_ypix): + mask_image[ypix, xpix] = True + contours = measure.find_contours(mask_image.astype(float), 0.5) + return contours + + +def convert_masks_to_suite2p_format(masks, frame_dims): + """ + Convert masks to the format expected by Suite2P. + + Parameters: + masks (list of np.ndarray): A list where each item is an array representing a mask, + with non-zero values for the ROI and zeros elsewhere. + frame_dims (tuple): The dimensions of the imaging frame, (height, width). + + Returns: + np.ndarray: A 2D array where each column represents a flattened binary mask for an ROI. + """ + # Calculate the total number of pixels in a frame + num_pixels = frame_dims[0] * frame_dims[1] + + # Initialize an empty array to store the flattened binary masks + suite2p_masks = np.zeros((num_pixels, len(masks)), dtype=np.float32) + + # Convert each mask + for idx, mask in enumerate(masks): + # Ensure the mask is binary (1 for ROI, 0 for background) + binary_mask = np.where(mask > 0, 1, 0).astype(np.float32) + + # Flatten the binary mask and add it as a column in the suite2p_masks array + suite2p_masks[:, idx] = binary_mask.flatten() + + return suite2p_masks + + +def load_imaging_data_for_session(key): + image_files = (scan.ScanInfo.ScanFile & key).fetch("file_path") + image_files = [ + find_full_path(get_imaging_root_data_dir(), image_file) + for image_file in image_files + ] + acq_software = (scan.Scan & key).fetch1("acq_software") + if acq_software == "ScanImage": + imaging_data = tifffile.imread(image_files[0]) + elif acq_software == "NIS": + imaging_data = nd2.imread(image_files[0]) + else: + raise ValueError(f"Support for images with acquisition software: {acq_software} is not yet implemented into the widget.") + return imaging_data + + +def extract_signals_suite2p(key, masks): + from suite2p.extraction.extract import extrace_traces + + F, _ = extrace_traces(load_imaging_data_for_session(key), masks, neuropil_masks=np.zeros_like(masks)) + + +def insert_signals_into_datajoint(signals, session_key): + # Implement logic to insert the extracted signals into DataJoint + pass \ No newline at end of file From 70db8cb12c8e49e49ef25b0f2c8d1de0cc4b8c15 Mon Sep 17 00:00:00 2001 From: kushalbakshi Date: Mon, 8 Apr 2024 21:58:38 -0500 Subject: [PATCH 2/9] Simplify code throughout + add testing notebook for codespaces --- element_calcium_imaging/plotting/draw_rois.py | 93 +++++++++------ element_calcium_imaging/plotting/utilities.py | 112 +++++++++++------- notebooks/test_widget.ipynb | 50 ++++++++ 3 files changed, 178 insertions(+), 77 deletions(-) create mode 100644 notebooks/test_widget.ipynb diff --git a/element_calcium_imaging/plotting/draw_rois.py b/element_calcium_imaging/plotting/draw_rois.py index 8b63a981..8308b78a 100644 --- a/element_calcium_imaging/plotting/draw_rois.py +++ b/element_calcium_imaging/plotting/draw_rois.py @@ -14,9 +14,6 @@ Serverside, ServersideOutputTransform, ) -from scipy import ndimage -from skimage import draw, measure -from tifffile import TiffFile from .utilities import * @@ -25,6 +22,7 @@ def draw_rois(db_prefix: str): + scan = dj.create_virtual_module("scan", f"{db_prefix}scan") imaging = dj.create_virtual_module("imaging", f"{db_prefix}imaging") all_keys = (imaging.MotionCorrection).fetch("KEY") @@ -34,29 +32,42 @@ def draw_rois(db_prefix: str): app.layout = html.Div( [ html.H2("Draw ROIs", style={"color": colors["text"]}), - html.Label("Select data key from dropdown", style={"color": colors["text"]}), + html.Label( + "Select data key from dropdown", style={"color": colors["text"]} + ), dcc.Dropdown( id="toplevel-dropdown", options=[str(key) for key in all_keys] ), html.Br(), html.Div( [ - html.Button("Load Image", id="load-image-button", style={"margin-right": "20px"}), + html.Button( + "Load Image", + id="load-image-button", + style={"margin-right": "20px"}, + ), dcc.RadioItems( - id='image-type-radio', + id="image-type-radio", options=[ - {'label': 'Average Image', 'value': 'average_image'}, - {'label': 'Max Projection Image', 'value': 'max_projection_image'} + {"label": "Average Image", "value": "average_image"}, + { + "label": "Max Projection Image", + "value": "max_projection_image", + }, ], - value='average_image', # Default value - labelStyle={'display': 'inline-block', 'margin-right': '10px'}, # Inline display with some margin - style={'display': 'inline-block', 'color': colors['text']} # Inline display to keep it on the same line + value="average_image", + labelStyle={"display": "inline-block", "margin-right": "10px"}, + style={"display": "inline-block", "color": colors["text"]}, ), html.Div( [ html.Button("Submit Curated Masks", id="submit-button"), ], - style={"textAlign": "right", "flex": "1", "display": "inline-block"}, + style={ + "textAlign": "right", + "flex": "1", + "display": "inline-block", + }, ), ], style={ @@ -76,6 +87,7 @@ def draw_rois(db_prefix: str): "drawclosedpath", "drawrect", "drawcircle", + "drawline", "eraseshape", ], }, @@ -84,10 +96,10 @@ def draw_rois(db_prefix: str): ], style={ "display": "flex", - "justify-content": "center", # Centers the child horizontally - "align-items": "center", # Centers the child vertically (if you have vertical space to work with) + "justify-content": "center", + "align-items": "center", "padding": "0.0", - "margin": "auto" # Automatically adjust the margins to center the div + "margin": "auto", }, ), html.Pre(id="annotations"), @@ -120,13 +132,22 @@ def store_key(value): def create_figure(value, render_n_clicks, image_type): if render_n_clicks is not None: if image_type == "average_image": - summary_images = (imaging.MotionCorrection.Summary & yaml.safe_load(value)).fetch("average_image") + summary_images = ( + imaging.MotionCorrection.Summary & yaml.safe_load(value) + ).fetch("average_image") else: - summary_images = (imaging.MotionCorrection.Summary & yaml.safe_load(value)).fetch("max_proj_image") + summary_images = ( + imaging.MotionCorrection.Summary & yaml.safe_load(value) + ).fetch("max_proj_image") average_images = [image.astype("float") for image in summary_images] roi_contours = get_contours(yaml.safe_load(value), db_prefix) logger.info("Generating figure.") - fig = px.imshow(np.asarray(average_images), animation_frame=0, binary_string=True, labels=dict(animation_frame="plane")) + fig = px.imshow( + np.asarray(average_images), + animation_frame=0, + binary_string=True, + labels=dict(animation_frame="plane"), + ) for contour in roi_contours: # Note: contour[:, 1] are x-coordinates, contour[:, 0] are y-coordinates fig.add_trace( @@ -171,30 +192,36 @@ def on_relayout(relayout_data): elif any(["shapes" in key for key in relayout_data]): return Serverside(relayout_data) - @app.callback( Output("submit-output", "children"), Input("submit-button", "n_clicks"), State("store-mask", "annotation_list"), - State("store-key", "value") + State("store-key", "value"), ) def submit_annotations(n_clicks, annotation_list, value): - print("submitting annotations") x_mask_li = [] y_mask_li = [] if n_clicks is not None: - if "shapes" in annotation_list: - shapes = [d["type"] for d in annotation_list["shapes"]] - for shape, annotation in zip(shapes, annotation_list["shapes"]): - mask = create_mask(annotation, shape) - y_mask_li.append(mask[0]) - x_mask_li.append(mask[1]) - - suite2p_masks = convert_masks_to_suite2p_format( - [np.array([x_mask_li, y_mask_li])], (512, 512) - ) - fluo_traces = extract_signals_suite2p(yaml.safe_load(value), suite2p_masks) - + if annotation_list: + if "shapes" in annotation_list: + logger.info("Creating Masks.") + shapes = [d["type"] for d in annotation_list["shapes"]] + for shape, annotation in zip(shapes, annotation_list["shapes"]): + mask = create_mask(annotation, shape) + y_mask_li.append(mask[0]) + x_mask_li.append(mask[1]) + print("Masks created") + insert_into_database( + scan, imaging, yaml.safe_load(value), x_mask_li, y_mask_li + ) + else: + logger.warn( + "Incorrect annotation list format. This is a known bug. Please draw a line anywhere on the image and click `Submit Curated Masks`. It will be ignored in the final submission but will format the list correctly." + ) + return no_update + else: + logger.warn("No annotations to submit.") + return no_update else: return no_update diff --git a/element_calcium_imaging/plotting/utilities.py b/element_calcium_imaging/plotting/utilities.py index 0308a4f8..83afbe1b 100644 --- a/element_calcium_imaging/plotting/utilities.py +++ b/element_calcium_imaging/plotting/utilities.py @@ -1,7 +1,25 @@ +import pathlib import datajoint as dj import numpy as np from scipy import ndimage from skimage import draw, measure +from element_interface.utils import find_full_path + + +logger = dj.logger + + +def get_imaging_root_data_dir(): + """Retrieve imaging root data directory.""" + imaging_root_dirs = dj.config.get("custom", {}).get("imaging_root_data_dir", None) + if not imaging_root_dirs: + return None + elif isinstance(imaging_root_dirs, (str, pathlib.Path)): + return [imaging_root_dirs] + elif isinstance(imaging_root_dirs, list): + return imaging_root_dirs + else: + raise TypeError("`imaging_root_data_dir` must be a string, pathlib, or list") def path_to_indices(path): @@ -88,9 +106,7 @@ def create_mask(coordinates, shape_type): xy_coordinates = np.asarray( [item for item in coordinates.values()], dtype="int" ) - mask = np.asarray( - create_ellipse_mask(xy_coordinates, (512, 512)) - ).nonzero() + mask = np.asarray(create_ellipse_mask(xy_coordinates, (512, 512))).nonzero() elif shape_type == "rect": try: mask = np.asarray( @@ -140,36 +156,7 @@ def get_contours(image_key, prefix): return contours -def convert_masks_to_suite2p_format(masks, frame_dims): - """ - Convert masks to the format expected by Suite2P. - - Parameters: - masks (list of np.ndarray): A list where each item is an array representing a mask, - with non-zero values for the ROI and zeros elsewhere. - frame_dims (tuple): The dimensions of the imaging frame, (height, width). - - Returns: - np.ndarray: A 2D array where each column represents a flattened binary mask for an ROI. - """ - # Calculate the total number of pixels in a frame - num_pixels = frame_dims[0] * frame_dims[1] - - # Initialize an empty array to store the flattened binary masks - suite2p_masks = np.zeros((num_pixels, len(masks)), dtype=np.float32) - - # Convert each mask - for idx, mask in enumerate(masks): - # Ensure the mask is binary (1 for ROI, 0 for background) - binary_mask = np.where(mask > 0, 1, 0).astype(np.float32) - - # Flatten the binary mask and add it as a column in the suite2p_masks array - suite2p_masks[:, idx] = binary_mask.flatten() - - return suite2p_masks - - -def load_imaging_data_for_session(key): +def load_imaging_data_for_session(scan, key): image_files = (scan.ScanInfo.ScanFile & key).fetch("file_path") image_files = [ find_full_path(get_imaging_root_data_dir(), image_file) @@ -177,20 +164,57 @@ def load_imaging_data_for_session(key): ] acq_software = (scan.Scan & key).fetch1("acq_software") if acq_software == "ScanImage": - imaging_data = tifffile.imread(image_files[0]) + import tifffile + + imaging_data = tifffile.imread(image_files[0]) elif acq_software == "NIS": + import nd2 + imaging_data = nd2.imread(image_files[0]) else: - raise ValueError(f"Support for images with acquisition software: {acq_software} is not yet implemented into the widget.") + raise ValueError( + f"Support for images with acquisition software: {acq_software} is not yet implemented into the widget." + ) return imaging_data -def extract_signals_suite2p(key, masks): - from suite2p.extraction.extract import extrace_traces - - F, _ = extrace_traces(load_imaging_data_for_session(key), masks, neuropil_masks=np.zeros_like(masks)) - - -def insert_signals_into_datajoint(signals, session_key): - # Implement logic to insert the extracted signals into DataJoint - pass \ No newline at end of file +def insert_into_database(scan_module, imaging_module, session_key, x_masks, y_masks): + images = load_imaging_data_for_session(scan_module, session_key) + print(f"Images shape: {images.shape}") + mask_id = (imaging_module.Segmentation.Mask & session_key).fetch( + "mask_id", order_by="DESC mask_id", limit=1 + ) + print(f"Mask ID: {mask_id}") + logger.info(f"Inserting {len(x_masks)} masks into the database.") + # imaging_module.Segmentation.Mask.insert( + # [ + # dict( + # **session_key, + # mask=mask_id + mask_num, + # segmentation_channel=1, + # mask_npix=y_mask.shape[0], + # mask_center_x=int(sum(x_mask) / x_mask.shape[0]), + # mask_center_y=int(sum(y_mask) / y_mask.shape[0]), + # mask_center_z=0, + # mask_xpix=x_mask, + # mask_ypix=y_mask, + # mask_zpix=0, + # mask_weights=np.ones_like(y_mask), + # ) + # for mask_num, (x_mask, y_mask) in enumerate(zip(x_masks, y_masks)) + # ], + # allow_direct_insert=True, + # ) + logger.info(f"Inserting {len(x_masks)} traces into the database.") + # imaging_module.Fluorescence.Trace.insert( + # [ + # dict( + # **session_key, + # mask=mask_id + mask_num, + # fluo_channel=1, + # fluorescence=images[:, y_mask, x_mask].mean(axis=1), + # ) + # for mask_num, (x_mask, y_mask) in enumerate(zip(x_masks, y_masks)) + # ], + # allow_direct_insert=True, + # ) diff --git a/notebooks/test_widget.ipynb b/notebooks/test_widget.ipynb new file mode 100644 index 00000000..55a166d1 --- /dev/null +++ b/notebooks/test_widget.ipynb @@ -0,0 +1,50 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "# change to the upper level folder to detect dj_local_conf.json\n", + "if os.path.basename(os.getcwd()) == \"notebooks\":\n", + " os.chdir(\"..\")\n", + "\n", + "import datajoint as dj\n", + "from element_calcium_imaging.plotting.draw_rois import draw_rois" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "draw_rois(\"neuro_\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "elements", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 101b2b591c41e494220322beef18ad28e010fa20 Mon Sep 17 00:00:00 2001 From: kushalbakshi Date: Mon, 8 Apr 2024 22:16:26 -0500 Subject: [PATCH 3/9] Add Dash to package requirements --- setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.py b/setup.py index 259b1def..d02da1f5 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,8 @@ "ipykernel>=6.0.1", "ipywidgets", "plotly", + "dash-extensions", + "scikit-image", "element-interface @ git+https://github.com/datajoint/element-interface.git", ], extras_require={ From 5231850d417fd4aa5d5148d80a893d0e7616df93 Mon Sep 17 00:00:00 2001 From: Kushal Bakshi <52367253+kushalbakshi@users.noreply.github.com> Date: Tue, 9 Apr 2024 04:00:00 +0000 Subject: [PATCH 4/9] Improve app handling from Jupyter notebook --- element_calcium_imaging/plotting/draw_rois.py | 3 +- notebooks/test_widget.ipynb | 39 ++++++++++++++++--- 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/element_calcium_imaging/plotting/draw_rois.py b/element_calcium_imaging/plotting/draw_rois.py index 8308b78a..ef49b151 100644 --- a/element_calcium_imaging/plotting/draw_rois.py +++ b/element_calcium_imaging/plotting/draw_rois.py @@ -224,5 +224,6 @@ def submit_annotations(n_clicks, annotation_list, value): return no_update else: return no_update + + return app - app.run_server(port=8000) diff --git a/notebooks/test_widget.ipynb b/notebooks/test_widget.ipynb index 55a166d1..a5cc8a8b 100644 --- a/notebooks/test_widget.ipynb +++ b/notebooks/test_widget.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -18,12 +18,41 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "draw_rois(\"neuro_\")" + "draw_rois(\"neuro_\").run_server(debug=True, host='0.0.0.0')" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -42,7 +71,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.8" + "version": "3.9.17" } }, "nbformat": 4, From 7d9af74e63ab56a9172dedcf58ff4c3256d9bfe9 Mon Sep 17 00:00:00 2001 From: kushalbakshi Date: Tue, 9 Apr 2024 00:23:59 -0500 Subject: [PATCH 5/9] Update database insert utility after testing --- element_calcium_imaging/plotting/utilities.py | 67 +++++++++---------- 1 file changed, 33 insertions(+), 34 deletions(-) diff --git a/element_calcium_imaging/plotting/utilities.py b/element_calcium_imaging/plotting/utilities.py index 83afbe1b..f1ab8fbb 100644 --- a/element_calcium_imaging/plotting/utilities.py +++ b/element_calcium_imaging/plotting/utilities.py @@ -180,41 +180,40 @@ def load_imaging_data_for_session(scan, key): def insert_into_database(scan_module, imaging_module, session_key, x_masks, y_masks): images = load_imaging_data_for_session(scan_module, session_key) - print(f"Images shape: {images.shape}") mask_id = (imaging_module.Segmentation.Mask & session_key).fetch( - "mask_id", order_by="DESC mask_id", limit=1 + "mask", order_by="mask desc", limit=1 ) - print(f"Mask ID: {mask_id}") logger.info(f"Inserting {len(x_masks)} masks into the database.") - # imaging_module.Segmentation.Mask.insert( - # [ - # dict( - # **session_key, - # mask=mask_id + mask_num, - # segmentation_channel=1, - # mask_npix=y_mask.shape[0], - # mask_center_x=int(sum(x_mask) / x_mask.shape[0]), - # mask_center_y=int(sum(y_mask) / y_mask.shape[0]), - # mask_center_z=0, - # mask_xpix=x_mask, - # mask_ypix=y_mask, - # mask_zpix=0, - # mask_weights=np.ones_like(y_mask), - # ) - # for mask_num, (x_mask, y_mask) in enumerate(zip(x_masks, y_masks)) - # ], - # allow_direct_insert=True, - # ) + imaging_module.Segmentation.Mask.insert( + [ + dict( + **session_key, + mask=mask_id + mask_num, + segmentation_channel=1, + mask_npix=y_mask.shape[0], + mask_center_x=int(sum(x_mask) / x_mask.shape[0]), + mask_center_y=int(sum(y_mask) / y_mask.shape[0]), + mask_center_z=0, + mask_xpix=x_mask, + mask_ypix=y_mask, + mask_zpix=0, + mask_weights=np.ones_like(y_mask), + ) + for mask_num, (x_mask, y_mask) in enumerate(zip(x_masks, y_masks)) + ], + allow_direct_insert=True, + ) logger.info(f"Inserting {len(x_masks)} traces into the database.") - # imaging_module.Fluorescence.Trace.insert( - # [ - # dict( - # **session_key, - # mask=mask_id + mask_num, - # fluo_channel=1, - # fluorescence=images[:, y_mask, x_mask].mean(axis=1), - # ) - # for mask_num, (x_mask, y_mask) in enumerate(zip(x_masks, y_masks)) - # ], - # allow_direct_insert=True, - # ) + imaging_module.Fluorescence.Trace.insert( + [ + dict( + **session_key, + mask=mask_id + mask_num, + fluo_channel=1, + fluorescence=images[:, y_mask, x_mask].mean(axis=1), + ) + for mask_num, (x_mask, y_mask) in enumerate(zip(x_masks, y_masks)) + ], + allow_direct_insert=True, + ) + logger.info("Inserts complete.") From 1298e08afa7e7a0f45f2f79e3a3e0a98fe3af5ef Mon Sep 17 00:00:00 2001 From: kushalbakshi Date: Tue, 9 Apr 2024 00:26:16 -0500 Subject: [PATCH 6/9] Remove testing notebook --- notebooks/test_widget.ipynb | 79 ------------------------------------- 1 file changed, 79 deletions(-) delete mode 100644 notebooks/test_widget.ipynb diff --git a/notebooks/test_widget.ipynb b/notebooks/test_widget.ipynb deleted file mode 100644 index a5cc8a8b..00000000 --- a/notebooks/test_widget.ipynb +++ /dev/null @@ -1,79 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "\n", - "# change to the upper level folder to detect dj_local_conf.json\n", - "if os.path.basename(os.getcwd()) == \"notebooks\":\n", - " os.chdir(\"..\")\n", - "\n", - "import datajoint as dj\n", - "from element_calcium_imaging.plotting.draw_rois import draw_rois" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "draw_rois(\"neuro_\").run_server(debug=True, host='0.0.0.0')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "elements", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.17" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From 8671a62ca66278cdd81594cb81e64429abcc1fff Mon Sep 17 00:00:00 2001 From: kushalbakshi Date: Tue, 9 Apr 2024 00:26:59 -0500 Subject: [PATCH 7/9] Apply black formatting --- element_calcium_imaging/plotting/draw_rois.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/element_calcium_imaging/plotting/draw_rois.py b/element_calcium_imaging/plotting/draw_rois.py index ef49b151..0b8481d4 100644 --- a/element_calcium_imaging/plotting/draw_rois.py +++ b/element_calcium_imaging/plotting/draw_rois.py @@ -224,6 +224,5 @@ def submit_annotations(n_clicks, annotation_list, value): return no_update else: return no_update - - return app + return app From 594587b2521dee95a9d13444d5a947c805084845 Mon Sep 17 00:00:00 2001 From: kushalbakshi Date: Tue, 9 Apr 2024 14:44:56 -0500 Subject: [PATCH 8/9] Add documentation for the widget --- docs/src/roadmap.md | 2 ++ docs/src/tutorials/index.md | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/docs/src/roadmap.md b/docs/src/roadmap.md index 41715c9e..518209a8 100644 --- a/docs/src/roadmap.md +++ b/docs/src/roadmap.md @@ -18,6 +18,8 @@ the common motifs to create Element Calcium Imaging. Major features include: - [ ] Deepinterpolation - [x] Data export to NWB - [x] Data publishing to DANDI +- [x] Widgets for manual ROI mask creation and curation for cell segmentation of Fluorescent voltage sensitive indicators, neurotransmitter imaging, and neuromodulator imaging +- [ ] Expand creation widget to provide pixel weights for each mask based on Fluorescence intensity traces at each pixel Further development of this Element is community driven. Upon user requests and based on guidance from the Scientific Steering Group we will continue adding features to this diff --git a/docs/src/tutorials/index.md b/docs/src/tutorials/index.md index f7bb8108..8329add9 100644 --- a/docs/src/tutorials/index.md +++ b/docs/src/tutorials/index.md @@ -29,3 +29,8 @@ please set `processing_method="extract"` in the ProcessingParamSet table, and provide the `params` attribute of the ProcessingParamSet table in the `{'suite2p': {...}, 'extract': {...}}` dictionary format. Please also install the [MATLAB engine](https://pypi.org/project/matlabengine/) API for Python. + +## Manual ROI Mask Creation and Curation + ++ Manual creation of ROI masks for fluorescence activity extraction is supported by the `draw_rois.py` plotly/dash widget. This widget allows the user to draw new ROI masks and submit them to the database. The widget can be launched in a Jupyter notebook after following the [installation instructions](#installation-instructions-for-active-projects) and importing `draw_rois` from the module. ++ ROI masks can be curated using the `widget.py` jupyter widget that allows the user to mark each regions as either a `cell` or `non-cell`. This widget can be launched in a Jupyter notebook after following the [installation instructions](#installation-instructions-for-active-projects) and importing `main` from the module. From df2c8341e2c97f0ebe860af8baf2419fc2475ad4 Mon Sep 17 00:00:00 2001 From: kushalbakshi Date: Tue, 9 Apr 2024 14:47:56 -0500 Subject: [PATCH 9/9] Update CHANGELOG and version --- CHANGELOG.md | 8 ++++++++ element_calcium_imaging/version.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b54b8e0..6791dff9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,11 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention. +## [0.10.0] - 2024-04-09 + ++ Add - ROI mask creation widget ++ Update documentation for using the included widgets in the package + ## [0.9.5] - 2024-03-22 + Add - pytest @@ -209,6 +214,9 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and + Add - `scan` and `imaging` modules + Add - Readers for `ScanImage`, `ScanBox`, `Suite2p`, `CaImAn` +[0.10.0]: https://github.com/datajoint/element-calcium-imaging/releases/tag/0.10.0 +[0.9.5]: https://github.com/datajoint/element-calcium-imaging/releases/tag/0.9.5 +[0.9.4]: https://github.com/datajoint/element-calcium-imaging/releases/tag/0.9.4 [0.9.3]: https://github.com/datajoint/element-calcium-imaging/releases/tag/0.9.3 [0.9.2]: https://github.com/datajoint/element-calcium-imaging/releases/tag/0.9.2 [0.9.1]: https://github.com/datajoint/element-calcium-imaging/releases/tag/0.9.1 diff --git a/element_calcium_imaging/version.py b/element_calcium_imaging/version.py index 3e909f53..e51b9d68 100644 --- a/element_calcium_imaging/version.py +++ b/element_calcium_imaging/version.py @@ -1,3 +1,3 @@ """Package metadata.""" -__version__ = "0.9.5" +__version__ = "0.10.0"