From a720c6c54605589b54366f2686ed0b3cbfec6654 Mon Sep 17 00:00:00 2001 From: kushalbakshi Date: Wed, 27 Mar 2024 14:34:24 -0500 Subject: [PATCH] Add ROI drawing widget for manual curation --- .gitignore | 3 + element_calcium_imaging/plotting/draw_rois.py | 201 ++++++++++++++++++ element_calcium_imaging/plotting/utilities.py | 196 +++++++++++++++++ 3 files changed, 400 insertions(+) create mode 100644 element_calcium_imaging/plotting/draw_rois.py create mode 100644 element_calcium_imaging/plotting/utilities.py diff --git a/.gitignore b/.gitignore index 0f635c3c..b94ad25a 100644 --- a/.gitignore +++ b/.gitignore @@ -124,3 +124,6 @@ example_data # vscode *.code-workspace + +# dash widget +file_system_backend \ No newline at end of file diff --git a/element_calcium_imaging/plotting/draw_rois.py b/element_calcium_imaging/plotting/draw_rois.py new file mode 100644 index 00000000..8b63a981 --- /dev/null +++ b/element_calcium_imaging/plotting/draw_rois.py @@ -0,0 +1,201 @@ +import yaml +import datajoint as dj +import numpy as np +import plotly.express as px +import plotly.graph_objects as go +from dash import no_update +from dash_extensions.enrich import ( + DashProxy, + Input, + Output, + State, + html, + dcc, + Serverside, + ServersideOutputTransform, +) +from scipy import ndimage +from skimage import draw, measure +from tifffile import TiffFile + +from .utilities import * + + +logger = dj.logger + + +def draw_rois(db_prefix: str): + imaging = dj.create_virtual_module("imaging", f"{db_prefix}imaging") + all_keys = (imaging.MotionCorrection).fetch("KEY") + + colors = {"background": "#111111", "text": "#00a0df"} + + app = DashProxy(transforms=[ServersideOutputTransform()]) + app.layout = html.Div( + [ + html.H2("Draw ROIs", style={"color": colors["text"]}), + html.Label("Select data key from dropdown", style={"color": colors["text"]}), + dcc.Dropdown( + id="toplevel-dropdown", options=[str(key) for key in all_keys] + ), + html.Br(), + html.Div( + [ + html.Button("Load Image", id="load-image-button", style={"margin-right": "20px"}), + dcc.RadioItems( + id='image-type-radio', + options=[ + {'label': 'Average Image', 'value': 'average_image'}, + {'label': 'Max Projection Image', 'value': 'max_projection_image'} + ], + value='average_image', # Default value + labelStyle={'display': 'inline-block', 'margin-right': '10px'}, # Inline display with some margin + style={'display': 'inline-block', 'color': colors['text']} # Inline display to keep it on the same line + ), + html.Div( + [ + html.Button("Submit Curated Masks", id="submit-button"), + ], + style={"textAlign": "right", "flex": "1", "display": "inline-block"}, + ), + ], + style={ + "display": "flex", + "justify-content": "flex-start", + "align-items": "center", + }, + ), + html.Br(), + html.Br(), + html.Div( + [ + dcc.Graph( + id="avg-image", + config={ + "modeBarButtonsToAdd": [ + "drawclosedpath", + "drawrect", + "drawcircle", + "eraseshape", + ], + }, + style={"width": "100%", "height": "100%"}, + ) + ], + style={ + "display": "flex", + "justify-content": "center", # Centers the child horizontally + "align-items": "center", # Centers the child vertically (if you have vertical space to work with) + "padding": "0.0", + "margin": "auto" # Automatically adjust the margins to center the div + }, + ), + html.Pre(id="annotations"), + html.Div(id="button-output"), + dcc.Store(id="store-key"), + dcc.Store(id="store-mask"), + dcc.Store(id="store-movie"), + html.Div(id="submit-output"), + ] + ) + + @app.callback( + Output("store-key", "value"), + Input("toplevel-dropdown", "value"), + ) + def store_key(value): + if value is not None: + return Serverside(value) + else: + return no_update + + @app.callback( + Output("avg-image", "figure"), + Output("store-movie", "average_images"), + State("store-key", "value"), + Input("load-image-button", "n_clicks"), + Input("image-type-radio", "value"), + prevent_initial_call=True, + ) + def create_figure(value, render_n_clicks, image_type): + if render_n_clicks is not None: + if image_type == "average_image": + summary_images = (imaging.MotionCorrection.Summary & yaml.safe_load(value)).fetch("average_image") + else: + summary_images = (imaging.MotionCorrection.Summary & yaml.safe_load(value)).fetch("max_proj_image") + average_images = [image.astype("float") for image in summary_images] + roi_contours = get_contours(yaml.safe_load(value), db_prefix) + logger.info("Generating figure.") + fig = px.imshow(np.asarray(average_images), animation_frame=0, binary_string=True, labels=dict(animation_frame="plane")) + for contour in roi_contours: + # Note: contour[:, 1] are x-coordinates, contour[:, 0] are y-coordinates + fig.add_trace( + go.Scatter( + x=contour[:, 1], # Plotly uses x, y order for coordinates + y=contour[:, 0], + mode="lines", # Display as lines (not markers) + line=dict(color="white", width=0.5), # Set line color and width + showlegend=False, # Do not show legend for each contour + ) + ) + fig.update_layout( + dragmode="drawrect", + autosize=True, + height=550, + newshape=dict(opacity=0.6, fillcolor="#00a0df"), + plot_bgcolor=colors["background"], + paper_bgcolor=colors["background"], + font_color=colors["text"], + ) + fig.update_annotations(bgcolor="#00a0df") + else: + return no_update + return fig, Serverside(average_images) + + @app.callback( + Output("store-mask", "annotation_list"), + Input("avg-image", "relayoutData"), + prevent_initial_call=True, + ) + def on_relayout(relayout_data): + if not relayout_data: + return no_update + else: + if "shapes" in relayout_data: + global shape_type + try: + shape_type = relayout_data["shapes"][-1]["type"] + return Serverside(relayout_data) + except IndexError: + return no_update + elif any(["shapes" in key for key in relayout_data]): + return Serverside(relayout_data) + + + @app.callback( + Output("submit-output", "children"), + Input("submit-button", "n_clicks"), + State("store-mask", "annotation_list"), + State("store-key", "value") + ) + def submit_annotations(n_clicks, annotation_list, value): + print("submitting annotations") + x_mask_li = [] + y_mask_li = [] + if n_clicks is not None: + if "shapes" in annotation_list: + shapes = [d["type"] for d in annotation_list["shapes"]] + for shape, annotation in zip(shapes, annotation_list["shapes"]): + mask = create_mask(annotation, shape) + y_mask_li.append(mask[0]) + x_mask_li.append(mask[1]) + + suite2p_masks = convert_masks_to_suite2p_format( + [np.array([x_mask_li, y_mask_li])], (512, 512) + ) + fluo_traces = extract_signals_suite2p(yaml.safe_load(value), suite2p_masks) + + else: + return no_update + + app.run_server(port=8000) diff --git a/element_calcium_imaging/plotting/utilities.py b/element_calcium_imaging/plotting/utilities.py new file mode 100644 index 00000000..0308a4f8 --- /dev/null +++ b/element_calcium_imaging/plotting/utilities.py @@ -0,0 +1,196 @@ +import datajoint as dj +import numpy as np +from scipy import ndimage +from skimage import draw, measure + + +def path_to_indices(path): + """From SVG path to numpy array of coordinates, each row being a (row, col) point""" + indices_str = [ + el.replace("M", "").replace("Z", "").split(",") for el in path.split("L") + ] + return np.rint(np.array(indices_str, dtype=float)).astype(int) + + +def path_to_mask(path, shape): + """From SVG path to a boolean array where all pixels enclosed by the path + are True, and the other pixels are False. + """ + cols, rows = path_to_indices(path).T + rr, cc = draw.polygon(rows, cols) + mask = np.zeros(shape, dtype=bool) + mask[rr, cc] = True + mask = ndimage.binary_fill_holes(mask) + return mask + + +def create_ellipse_mask(vertices, image_shape): + """ + Create a mask for an ellipse given its vertices. + + :param vertices: Tuple of (x0, y0, x1, y1) representing the bounding box of the ellipse. + :param image_shape: Shape of the image (height, width) to create a mask for. + :return: Binary mask with the ellipse. + """ + x0, x1, y0, y1 = vertices + center = ((x0 + x1) / 2, (y0 + y1) / 2) + axis_lengths = (abs(x1 - x0) / 2, abs(y1 - y0) / 2) + + rr, cc = draw.ellipse( + center[1], center[0], axis_lengths[1], axis_lengths[0], shape=image_shape + ) + mask = np.zeros(image_shape, dtype=np.bool_) + mask[rr, cc] = True + mask = ndimage.binary_fill_holes(mask) + + return mask + + +def create_rectangle_mask(vertices, image_shape): + """ + Create a mask for a rectangle given its vertices. + + :param vertices: Tuple of (x0, y0, x1, y1) representing the top-left and bottom-right corners of the rectangle. + :param image_shape: Shape of the image (height, width) to create a mask for. + :return: Binary mask with the rectangle. + """ + x0, x1, y0, y1 = vertices + rr, cc = draw.rectangle(start=(y0, x0), end=(y1, x1), shape=image_shape) + mask = np.zeros(image_shape, dtype=np.bool_) + mask[rr, cc] = True + mask = ndimage.binary_fill_holes(mask) + + return mask + + +def create_mask(coordinates, shape_type): + if shape_type == "path": + try: + mask = np.asarray(path_to_mask(coordinates["path"], (512, 512))).nonzero() + except KeyError: + for key, info in coordinates.items(): + mask = np.asarray(path_to_mask(info, (512, 512))).nonzero() + + elif shape_type == "circle": + try: + mask = np.asarray( + create_ellipse_mask( + [ + int(coordinates["x0"]), + int(coordinates["x1"]), + int(coordinates["y0"]), + int(coordinates["y1"]), + ], + (512, 512), + ) + ).nonzero() + except KeyError: + xy_coordinates = np.asarray( + [item for item in coordinates.values()], dtype="int" + ) + mask = np.asarray( + create_ellipse_mask(xy_coordinates, (512, 512)) + ).nonzero() + elif shape_type == "rect": + try: + mask = np.asarray( + create_rectangle_mask( + [ + int(coordinates["x0"]), + int(coordinates["x1"]), + int(coordinates["y0"]), + int(coordinates["y1"]), + ], + (512, 512), + ) + ).nonzero() + except KeyError: + xy_coordinates = np.asarray( + [item for item in coordinates.values()], dtype="int" + ) + mask = np.asarray( + create_rectangle_mask(xy_coordinates, (512, 512)) + ).nonzero() + elif shape_type == "line": + try: + mask = np.array( + ( + int(coordinates["x0"]), + int(coordinates["x1"]), + int(coordinates["y0"]), + int(coordinates["y1"]), + ) + ) + except KeyError: + mask = np.asarray([item for item in coordinates.values()], dtype="int") + return mask + + +def get_contours(image_key, prefix): + scan = dj.create_virtual_module("scan", f"{prefix}scan") + imaging = dj.create_virtual_module("imaging", f"{prefix}imaging") + yshape, xshape = (scan.ScanInfo.Field & image_key).fetch1("px_height", "px_width") + mask_xpix, mask_ypix = (imaging.Segmentation.Mask & image_key).fetch( + "mask_xpix", "mask_ypix" + ) + mask_image = np.zeros((yshape, xshape), dtype=bool) + for xpix, ypix in zip(mask_xpix, mask_ypix): + mask_image[ypix, xpix] = True + contours = measure.find_contours(mask_image.astype(float), 0.5) + return contours + + +def convert_masks_to_suite2p_format(masks, frame_dims): + """ + Convert masks to the format expected by Suite2P. + + Parameters: + masks (list of np.ndarray): A list where each item is an array representing a mask, + with non-zero values for the ROI and zeros elsewhere. + frame_dims (tuple): The dimensions of the imaging frame, (height, width). + + Returns: + np.ndarray: A 2D array where each column represents a flattened binary mask for an ROI. + """ + # Calculate the total number of pixels in a frame + num_pixels = frame_dims[0] * frame_dims[1] + + # Initialize an empty array to store the flattened binary masks + suite2p_masks = np.zeros((num_pixels, len(masks)), dtype=np.float32) + + # Convert each mask + for idx, mask in enumerate(masks): + # Ensure the mask is binary (1 for ROI, 0 for background) + binary_mask = np.where(mask > 0, 1, 0).astype(np.float32) + + # Flatten the binary mask and add it as a column in the suite2p_masks array + suite2p_masks[:, idx] = binary_mask.flatten() + + return suite2p_masks + + +def load_imaging_data_for_session(key): + image_files = (scan.ScanInfo.ScanFile & key).fetch("file_path") + image_files = [ + find_full_path(get_imaging_root_data_dir(), image_file) + for image_file in image_files + ] + acq_software = (scan.Scan & key).fetch1("acq_software") + if acq_software == "ScanImage": + imaging_data = tifffile.imread(image_files[0]) + elif acq_software == "NIS": + imaging_data = nd2.imread(image_files[0]) + else: + raise ValueError(f"Support for images with acquisition software: {acq_software} is not yet implemented into the widget.") + return imaging_data + + +def extract_signals_suite2p(key, masks): + from suite2p.extraction.extract import extrace_traces + + F, _ = extrace_traces(load_imaging_data_for_session(key), masks, neuropil_masks=np.zeros_like(masks)) + + +def insert_signals_into_datajoint(signals, session_key): + # Implement logic to insert the extracted signals into DataJoint + pass \ No newline at end of file