Skip to content

Commit

Permalink
Merge pull request datajoint#5 from kushalbakshi/lu-lab
Browse files Browse the repository at this point in the history
Add symlink to `Processing` for `ZDriftMetrics`
  • Loading branch information
ttngu207 authored Oct 6, 2023
2 parents 94249dc + 3438f79 commit b51c500
Showing 1 changed file with 54 additions and 31 deletions.
85 changes: 54 additions & 31 deletions element_calcium_imaging/imaging_no_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,45 +291,44 @@ def infer_output_dir(cls, key, relative=False, mkdir=False):
@schema
class ZDriftParamSet(dj.Manual):
definition = """
paramset_idx: int
zdrift_paramset_idx: int
---
paramset_desc: varchar(1280) # Parameter-set description
param_set_hash: uuid # A universally unique identifier for the parameter set
unique index (param_set_hash)
params: longblob # Parameter Set, a dictionary of all z-drift parameters.
z_paramset_desc: varchar(1280) # Parameter-set description
z_param_set_hash: uuid # A universally unique identifier for the parameter set.
z_params: longblob # Parameter Set, a dictionary of all z-drift parameters.
"""

@classmethod
def insert_new_params(
cls,
paramset_idx: int,
paramset_desc: str,
params: dict,
zdrift_paramset_idx: int,
z_paramset_desc: str,
z_params: dict,
):
"""Insert a parameter set into ProcessingParamSet table.
"""Insert a parameter set into ZDriftParamSet table.
This function automates the parameter set hashing and avoids insertion of an
existing parameter set.
Attributes:
paramset_idx (int): Unique parameter set ID.
paramset_desc (str): Parameter set description.
params (dict): Parameter Set, all applicable parameters to the
zdrift_paramset_idx (int): Unique parameter set ID.
z_paramset_desc (str): Parameter set description.
z_params (dict): Parameter Set, all applicable parameters to the
z-axis correlation analysis.
"""
param_dict = {
"paramset_idx": paramset_idx,
"paramset_desc": paramset_desc,
"params": params,
"param_set_hash": dict_to_uuid(params),
"zdrift_paramset_idx": zdrift_paramset_idx,
"z_paramset_desc": z_paramset_desc,
"z_params": z_params,
"z_param_set_hash": dict_to_uuid(z_params),
}
q_param = cls & {"param_set_hash": param_dict["param_set_hash"]}
q_param = cls & {"z_param_set_hash": param_dict["z_param_set_hash"]}

if q_param: # If the specified param-set already exists
p_name = q_param.fetch1("paramset_idx")
if p_name == paramset_idx: # If the existed set has the same name: job done
p_name = q_param.fetch1("zdrift_paramset_idx")
if (
p_name == zdrift_paramset_idx
): # If the existed set has the same name: job done
return
else: # If not same name: human error, trying to add the same paramset with different name
raise dj.DataJointError(
Expand All @@ -353,6 +352,7 @@ class ZDriftMetrics(dj.Computed):
-> scan.Scan
-> ZDriftParamSet
---
bad_frames=NULL: longblob # `True` if any value in z_drift > threshold from drift_params.
z_drift: longblob # Amount of drift in microns per frame in Z direction.
"""

Expand All @@ -363,7 +363,7 @@ def _make_taper(size, width):
return np.convolve(m, k, mode="full") / k.sum()

nchannels = (scan.ScanInfo & key).fetch1("nchannels")
drift_params = (ZDriftParamSet & key).fetch1("params")
drift_params = (ZDriftParamSet & key).fetch1("z_params")
image_files = (scan.ScanInfo.ScanFile & key).fetch("file_path")
image_files = [
find_full_path(get_imaging_root_data_dir()[0], image_file)
Expand Down Expand Up @@ -455,8 +455,10 @@ def _make_taper(size, width):
"slice_interval"
]

bad_frames_idx = np.where(drift >= drift_params["bad_frames_threshold"])[0]

self.insert1(
dict(**key, z_drift=drift),
dict(**key, bad_frames=bad_frames_idx, z_drift=drift),
)


Expand Down Expand Up @@ -485,7 +487,7 @@ def key_source(self):
"""Limit the Processing to Scans that have their metadata ingested to the
database."""

return ProcessingTask & scan.ScanInfo
return ProcessingTask & scan.ScanInfo & ZDriftMetrics

def make(self, key):
"""Execute the calcium imaging analysis defined by the ProcessingTask."""
Expand Down Expand Up @@ -522,16 +524,38 @@ def make(self, key):
else:
raise NotImplementedError("Unknown method: {}".format(method))
elif task_mode == "trigger":
try:
drop_frames = (ZDriftMetrics & key).fetch1("bad_frames")
except dj.DataJointError:
raise dj.DataJointError(
"Processing more than 1 set of `bad_frames` per scan is not currently supported."
)
if drop_frames.size > 0:
np.save(pathlib.Path(output_dir) / "bad_frames.npy", drop_frames)
raw_image_files = (scan.ScanInfo.ScanFile & key).fetch("file_path")
files_to_link = [
find_full_path(get_imaging_root_data_dir(), raw_image_file)
for raw_image_file in raw_image_files
]
image_files = []
for file in files_to_link:
if not (pathlib.Path(output_dir) / file.name).is_symlink():
(pathlib.Path(output_dir) / file.name).symlink_to(file)
image_files.append((pathlib.Path(output_dir) / file.name))
else:
image_files.append((pathlib.Path(output_dir) / file.name))

else:
image_files = (scan.ScanInfo.ScanFile & key).fetch("file_path")
image_files = [
find_full_path(get_imaging_root_data_dir(), image_file)
for image_file in image_files
]

method = (ProcessingParamSet * ProcessingTask & key).fetch1(
"processing_method"
)

image_files = (scan.ScanInfo.ScanFile & key).fetch("file_path")
image_files = [
find_full_path(get_imaging_root_data_dir(), image_file)
for image_file in image_files
]

if method == "suite2p":
import suite2p

Expand All @@ -552,7 +576,6 @@ def make(self, key):
"data_path": [image_files[0].parent.as_posix()],
"tiff_list": [f.as_posix() for f in image_files],
}

suite2p.run_s2p(ops=suite2p_params, db=suite2p_paths) # Run suite2p

_, imaging_dataset = get_loader_result(key, ProcessingTask)
Expand Down

0 comments on commit b51c500

Please sign in to comment.