diff --git a/.gitignore b/.gitignore index 0f635c3c..b94ad25a 100644 --- a/.gitignore +++ b/.gitignore @@ -124,3 +124,6 @@ example_data # vscode *.code-workspace + +# dash widget +file_system_backend \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b54b8e0..6791dff9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,11 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention. +## [0.10.0] - 2024-04-09 + ++ Add - ROI mask creation widget ++ Update documentation for using the included widgets in the package + ## [0.9.5] - 2024-03-22 + Add - pytest @@ -209,6 +214,9 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and + Add - `scan` and `imaging` modules + Add - Readers for `ScanImage`, `ScanBox`, `Suite2p`, `CaImAn` +[0.10.0]: https://github.com/datajoint/element-calcium-imaging/releases/tag/0.10.0 +[0.9.5]: https://github.com/datajoint/element-calcium-imaging/releases/tag/0.9.5 +[0.9.4]: https://github.com/datajoint/element-calcium-imaging/releases/tag/0.9.4 [0.9.3]: https://github.com/datajoint/element-calcium-imaging/releases/tag/0.9.3 [0.9.2]: https://github.com/datajoint/element-calcium-imaging/releases/tag/0.9.2 [0.9.1]: https://github.com/datajoint/element-calcium-imaging/releases/tag/0.9.1 diff --git a/docs/src/roadmap.md b/docs/src/roadmap.md index 41715c9e..518209a8 100644 --- a/docs/src/roadmap.md +++ b/docs/src/roadmap.md @@ -18,6 +18,8 @@ the common motifs to create Element Calcium Imaging. Major features include: - [ ] Deepinterpolation - [x] Data export to NWB - [x] Data publishing to DANDI +- [x] Widgets for manual ROI mask creation and curation for cell segmentation of Fluorescent voltage sensitive indicators, neurotransmitter imaging, and neuromodulator imaging +- [ ] Expand creation widget to provide pixel weights for each mask based on Fluorescence intensity traces at each pixel Further development of this Element is community driven. Upon user requests and based on guidance from the Scientific Steering Group we will continue adding features to this diff --git a/docs/src/tutorials/index.md b/docs/src/tutorials/index.md index f7bb8108..8329add9 100644 --- a/docs/src/tutorials/index.md +++ b/docs/src/tutorials/index.md @@ -29,3 +29,8 @@ please set `processing_method="extract"` in the ProcessingParamSet table, and provide the `params` attribute of the ProcessingParamSet table in the `{'suite2p': {...}, 'extract': {...}}` dictionary format. Please also install the [MATLAB engine](https://pypi.org/project/matlabengine/) API for Python. + +## Manual ROI Mask Creation and Curation + ++ Manual creation of ROI masks for fluorescence activity extraction is supported by the `draw_rois.py` plotly/dash widget. This widget allows the user to draw new ROI masks and submit them to the database. The widget can be launched in a Jupyter notebook after following the [installation instructions](#installation-instructions-for-active-projects) and importing `draw_rois` from the module. ++ ROI masks can be curated using the `widget.py` jupyter widget that allows the user to mark each regions as either a `cell` or `non-cell`. This widget can be launched in a Jupyter notebook after following the [installation instructions](#installation-instructions-for-active-projects) and importing `main` from the module. diff --git a/element_calcium_imaging/plotting/draw_rois.py b/element_calcium_imaging/plotting/draw_rois.py new file mode 100644 index 00000000..0b8481d4 --- /dev/null +++ b/element_calcium_imaging/plotting/draw_rois.py @@ -0,0 +1,228 @@ +import yaml +import datajoint as dj +import numpy as np +import plotly.express as px +import plotly.graph_objects as go +from dash import no_update +from dash_extensions.enrich import ( + DashProxy, + Input, + Output, + State, + html, + dcc, + Serverside, + ServersideOutputTransform, +) + +from .utilities import * + + +logger = dj.logger + + +def draw_rois(db_prefix: str): + scan = dj.create_virtual_module("scan", f"{db_prefix}scan") + imaging = dj.create_virtual_module("imaging", f"{db_prefix}imaging") + all_keys = (imaging.MotionCorrection).fetch("KEY") + + colors = {"background": "#111111", "text": "#00a0df"} + + app = DashProxy(transforms=[ServersideOutputTransform()]) + app.layout = html.Div( + [ + html.H2("Draw ROIs", style={"color": colors["text"]}), + html.Label( + "Select data key from dropdown", style={"color": colors["text"]} + ), + dcc.Dropdown( + id="toplevel-dropdown", options=[str(key) for key in all_keys] + ), + html.Br(), + html.Div( + [ + html.Button( + "Load Image", + id="load-image-button", + style={"margin-right": "20px"}, + ), + dcc.RadioItems( + id="image-type-radio", + options=[ + {"label": "Average Image", "value": "average_image"}, + { + "label": "Max Projection Image", + "value": "max_projection_image", + }, + ], + value="average_image", + labelStyle={"display": "inline-block", "margin-right": "10px"}, + style={"display": "inline-block", "color": colors["text"]}, + ), + html.Div( + [ + html.Button("Submit Curated Masks", id="submit-button"), + ], + style={ + "textAlign": "right", + "flex": "1", + "display": "inline-block", + }, + ), + ], + style={ + "display": "flex", + "justify-content": "flex-start", + "align-items": "center", + }, + ), + html.Br(), + html.Br(), + html.Div( + [ + dcc.Graph( + id="avg-image", + config={ + "modeBarButtonsToAdd": [ + "drawclosedpath", + "drawrect", + "drawcircle", + "drawline", + "eraseshape", + ], + }, + style={"width": "100%", "height": "100%"}, + ) + ], + style={ + "display": "flex", + "justify-content": "center", + "align-items": "center", + "padding": "0.0", + "margin": "auto", + }, + ), + html.Pre(id="annotations"), + html.Div(id="button-output"), + dcc.Store(id="store-key"), + dcc.Store(id="store-mask"), + dcc.Store(id="store-movie"), + html.Div(id="submit-output"), + ] + ) + + @app.callback( + Output("store-key", "value"), + Input("toplevel-dropdown", "value"), + ) + def store_key(value): + if value is not None: + return Serverside(value) + else: + return no_update + + @app.callback( + Output("avg-image", "figure"), + Output("store-movie", "average_images"), + State("store-key", "value"), + Input("load-image-button", "n_clicks"), + Input("image-type-radio", "value"), + prevent_initial_call=True, + ) + def create_figure(value, render_n_clicks, image_type): + if render_n_clicks is not None: + if image_type == "average_image": + summary_images = ( + imaging.MotionCorrection.Summary & yaml.safe_load(value) + ).fetch("average_image") + else: + summary_images = ( + imaging.MotionCorrection.Summary & yaml.safe_load(value) + ).fetch("max_proj_image") + average_images = [image.astype("float") for image in summary_images] + roi_contours = get_contours(yaml.safe_load(value), db_prefix) + logger.info("Generating figure.") + fig = px.imshow( + np.asarray(average_images), + animation_frame=0, + binary_string=True, + labels=dict(animation_frame="plane"), + ) + for contour in roi_contours: + # Note: contour[:, 1] are x-coordinates, contour[:, 0] are y-coordinates + fig.add_trace( + go.Scatter( + x=contour[:, 1], # Plotly uses x, y order for coordinates + y=contour[:, 0], + mode="lines", # Display as lines (not markers) + line=dict(color="white", width=0.5), # Set line color and width + showlegend=False, # Do not show legend for each contour + ) + ) + fig.update_layout( + dragmode="drawrect", + autosize=True, + height=550, + newshape=dict(opacity=0.6, fillcolor="#00a0df"), + plot_bgcolor=colors["background"], + paper_bgcolor=colors["background"], + font_color=colors["text"], + ) + fig.update_annotations(bgcolor="#00a0df") + else: + return no_update + return fig, Serverside(average_images) + + @app.callback( + Output("store-mask", "annotation_list"), + Input("avg-image", "relayoutData"), + prevent_initial_call=True, + ) + def on_relayout(relayout_data): + if not relayout_data: + return no_update + else: + if "shapes" in relayout_data: + global shape_type + try: + shape_type = relayout_data["shapes"][-1]["type"] + return Serverside(relayout_data) + except IndexError: + return no_update + elif any(["shapes" in key for key in relayout_data]): + return Serverside(relayout_data) + + @app.callback( + Output("submit-output", "children"), + Input("submit-button", "n_clicks"), + State("store-mask", "annotation_list"), + State("store-key", "value"), + ) + def submit_annotations(n_clicks, annotation_list, value): + x_mask_li = [] + y_mask_li = [] + if n_clicks is not None: + if annotation_list: + if "shapes" in annotation_list: + logger.info("Creating Masks.") + shapes = [d["type"] for d in annotation_list["shapes"]] + for shape, annotation in zip(shapes, annotation_list["shapes"]): + mask = create_mask(annotation, shape) + y_mask_li.append(mask[0]) + x_mask_li.append(mask[1]) + print("Masks created") + insert_into_database( + scan, imaging, yaml.safe_load(value), x_mask_li, y_mask_li + ) + else: + logger.warn( + "Incorrect annotation list format. This is a known bug. Please draw a line anywhere on the image and click `Submit Curated Masks`. It will be ignored in the final submission but will format the list correctly." + ) + return no_update + else: + logger.warn("No annotations to submit.") + return no_update + else: + return no_update + + return app diff --git a/element_calcium_imaging/plotting/utilities.py b/element_calcium_imaging/plotting/utilities.py new file mode 100644 index 00000000..f1ab8fbb --- /dev/null +++ b/element_calcium_imaging/plotting/utilities.py @@ -0,0 +1,219 @@ +import pathlib +import datajoint as dj +import numpy as np +from scipy import ndimage +from skimage import draw, measure +from element_interface.utils import find_full_path + + +logger = dj.logger + + +def get_imaging_root_data_dir(): + """Retrieve imaging root data directory.""" + imaging_root_dirs = dj.config.get("custom", {}).get("imaging_root_data_dir", None) + if not imaging_root_dirs: + return None + elif isinstance(imaging_root_dirs, (str, pathlib.Path)): + return [imaging_root_dirs] + elif isinstance(imaging_root_dirs, list): + return imaging_root_dirs + else: + raise TypeError("`imaging_root_data_dir` must be a string, pathlib, or list") + + +def path_to_indices(path): + """From SVG path to numpy array of coordinates, each row being a (row, col) point""" + indices_str = [ + el.replace("M", "").replace("Z", "").split(",") for el in path.split("L") + ] + return np.rint(np.array(indices_str, dtype=float)).astype(int) + + +def path_to_mask(path, shape): + """From SVG path to a boolean array where all pixels enclosed by the path + are True, and the other pixels are False. + """ + cols, rows = path_to_indices(path).T + rr, cc = draw.polygon(rows, cols) + mask = np.zeros(shape, dtype=bool) + mask[rr, cc] = True + mask = ndimage.binary_fill_holes(mask) + return mask + + +def create_ellipse_mask(vertices, image_shape): + """ + Create a mask for an ellipse given its vertices. + + :param vertices: Tuple of (x0, y0, x1, y1) representing the bounding box of the ellipse. + :param image_shape: Shape of the image (height, width) to create a mask for. + :return: Binary mask with the ellipse. + """ + x0, x1, y0, y1 = vertices + center = ((x0 + x1) / 2, (y0 + y1) / 2) + axis_lengths = (abs(x1 - x0) / 2, abs(y1 - y0) / 2) + + rr, cc = draw.ellipse( + center[1], center[0], axis_lengths[1], axis_lengths[0], shape=image_shape + ) + mask = np.zeros(image_shape, dtype=np.bool_) + mask[rr, cc] = True + mask = ndimage.binary_fill_holes(mask) + + return mask + + +def create_rectangle_mask(vertices, image_shape): + """ + Create a mask for a rectangle given its vertices. + + :param vertices: Tuple of (x0, y0, x1, y1) representing the top-left and bottom-right corners of the rectangle. + :param image_shape: Shape of the image (height, width) to create a mask for. + :return: Binary mask with the rectangle. + """ + x0, x1, y0, y1 = vertices + rr, cc = draw.rectangle(start=(y0, x0), end=(y1, x1), shape=image_shape) + mask = np.zeros(image_shape, dtype=np.bool_) + mask[rr, cc] = True + mask = ndimage.binary_fill_holes(mask) + + return mask + + +def create_mask(coordinates, shape_type): + if shape_type == "path": + try: + mask = np.asarray(path_to_mask(coordinates["path"], (512, 512))).nonzero() + except KeyError: + for key, info in coordinates.items(): + mask = np.asarray(path_to_mask(info, (512, 512))).nonzero() + + elif shape_type == "circle": + try: + mask = np.asarray( + create_ellipse_mask( + [ + int(coordinates["x0"]), + int(coordinates["x1"]), + int(coordinates["y0"]), + int(coordinates["y1"]), + ], + (512, 512), + ) + ).nonzero() + except KeyError: + xy_coordinates = np.asarray( + [item for item in coordinates.values()], dtype="int" + ) + mask = np.asarray(create_ellipse_mask(xy_coordinates, (512, 512))).nonzero() + elif shape_type == "rect": + try: + mask = np.asarray( + create_rectangle_mask( + [ + int(coordinates["x0"]), + int(coordinates["x1"]), + int(coordinates["y0"]), + int(coordinates["y1"]), + ], + (512, 512), + ) + ).nonzero() + except KeyError: + xy_coordinates = np.asarray( + [item for item in coordinates.values()], dtype="int" + ) + mask = np.asarray( + create_rectangle_mask(xy_coordinates, (512, 512)) + ).nonzero() + elif shape_type == "line": + try: + mask = np.array( + ( + int(coordinates["x0"]), + int(coordinates["x1"]), + int(coordinates["y0"]), + int(coordinates["y1"]), + ) + ) + except KeyError: + mask = np.asarray([item for item in coordinates.values()], dtype="int") + return mask + + +def get_contours(image_key, prefix): + scan = dj.create_virtual_module("scan", f"{prefix}scan") + imaging = dj.create_virtual_module("imaging", f"{prefix}imaging") + yshape, xshape = (scan.ScanInfo.Field & image_key).fetch1("px_height", "px_width") + mask_xpix, mask_ypix = (imaging.Segmentation.Mask & image_key).fetch( + "mask_xpix", "mask_ypix" + ) + mask_image = np.zeros((yshape, xshape), dtype=bool) + for xpix, ypix in zip(mask_xpix, mask_ypix): + mask_image[ypix, xpix] = True + contours = measure.find_contours(mask_image.astype(float), 0.5) + return contours + + +def load_imaging_data_for_session(scan, key): + image_files = (scan.ScanInfo.ScanFile & key).fetch("file_path") + image_files = [ + find_full_path(get_imaging_root_data_dir(), image_file) + for image_file in image_files + ] + acq_software = (scan.Scan & key).fetch1("acq_software") + if acq_software == "ScanImage": + import tifffile + + imaging_data = tifffile.imread(image_files[0]) + elif acq_software == "NIS": + import nd2 + + imaging_data = nd2.imread(image_files[0]) + else: + raise ValueError( + f"Support for images with acquisition software: {acq_software} is not yet implemented into the widget." + ) + return imaging_data + + +def insert_into_database(scan_module, imaging_module, session_key, x_masks, y_masks): + images = load_imaging_data_for_session(scan_module, session_key) + mask_id = (imaging_module.Segmentation.Mask & session_key).fetch( + "mask", order_by="mask desc", limit=1 + ) + logger.info(f"Inserting {len(x_masks)} masks into the database.") + imaging_module.Segmentation.Mask.insert( + [ + dict( + **session_key, + mask=mask_id + mask_num, + segmentation_channel=1, + mask_npix=y_mask.shape[0], + mask_center_x=int(sum(x_mask) / x_mask.shape[0]), + mask_center_y=int(sum(y_mask) / y_mask.shape[0]), + mask_center_z=0, + mask_xpix=x_mask, + mask_ypix=y_mask, + mask_zpix=0, + mask_weights=np.ones_like(y_mask), + ) + for mask_num, (x_mask, y_mask) in enumerate(zip(x_masks, y_masks)) + ], + allow_direct_insert=True, + ) + logger.info(f"Inserting {len(x_masks)} traces into the database.") + imaging_module.Fluorescence.Trace.insert( + [ + dict( + **session_key, + mask=mask_id + mask_num, + fluo_channel=1, + fluorescence=images[:, y_mask, x_mask].mean(axis=1), + ) + for mask_num, (x_mask, y_mask) in enumerate(zip(x_masks, y_masks)) + ], + allow_direct_insert=True, + ) + logger.info("Inserts complete.") diff --git a/element_calcium_imaging/version.py b/element_calcium_imaging/version.py index 3e909f53..e51b9d68 100644 --- a/element_calcium_imaging/version.py +++ b/element_calcium_imaging/version.py @@ -1,3 +1,3 @@ """Package metadata.""" -__version__ = "0.9.5" +__version__ = "0.10.0" diff --git a/setup.py b/setup.py index 2586295c..4dfed13a 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,8 @@ "ipykernel>=6.0.1", "ipywidgets", "plotly", + "dash-extensions", + "scikit-image", "element-interface @ git+https://github.com/datajoint/element-interface.git", ], extras_require={