Commit

fix: update key_source, update FieldPostProcessing
ttngu207 committed Mar 18, 2024
1 parent ece145c commit f1bbafb
Showing 2 changed files with 31 additions and 37 deletions.
52 changes: 30 additions & 22 deletions element_calcium_imaging/field_processing.py
@@ -3,19 +3,15 @@
import pathlib
from collections.abc import Callable
from datetime import datetime
import re

import datajoint as dj
import numpy as np
from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory

from . import scan
from .scan import (
get_calcium_imaging_files,
get_imaging_root_data_dir,
get_processed_root_data_dir,
)

log = dj.logger()
log = dj.logger

schema = dj.schema()

@@ -30,8 +26,8 @@ def activate(
create_tables=True,
):
"""
activate(schema_name, *, create_schema=True, create_tables=True, activated_ephys=None)
:param schema_name: schema name on the database server to activate the `spike_sorting` schema
activate(schema_name, *, imaging_module, create_schema=True, create_tables=True)
:param schema_name: schema name on the database server to activate the `field_processing` schema
:param imaging_module: the activated imaging element that this `field_processing` schema is downstream from
:param create_schema: when True (default), create schema in the database if it does not yet exist.
:param create_tables: when True (default), create tables in the database if they do not yet exist.
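
Note: a hypothetical activation call matching the updated docstring above; the schema name is a placeholder and the `imaging_module` parameter name is taken from the docstring, so the exact signature may differ.

```python
# Hypothetical usage sketch; "my_lab_field_processing" is a placeholder schema
# name and `imaging` stands for the already-activated imaging element module.
from element_calcium_imaging import imaging_no_curation as imaging
from element_calcium_imaging import field_processing

# ... activate the imaging element (and its upstream elements) first ...

field_processing.activate(
    "my_lab_field_processing",   # schema_name on the database server
    imaging_module=imaging,      # the activated imaging element
    create_schema=True,
    create_tables=True,
)
```
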
@@ -44,6 +40,7 @@ def activate(
create_tables=create_tables,
add_objects=imaging.__dict__,
)
imaging.Processing.key_source -= FieldPreprocessing.key_source.proj()
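
The new `key_source` line deserves a note: in DataJoint, subtracting one query expression from another is an antijoin, so this reassignment hides every key that `FieldPreprocessing` will claim from `imaging.Processing.populate()`. A schematic restatement (requires an activated imaging element and a database connection):

```python
# Schematic only -- equivalent to the `-=` above. After the reassignment,
# Processing.populate() skips keys destined for per-field processing; those
# sessions presumably re-enter imaging.Processing via the direct insert in the
# final hunk of this file (note the allow_direct_insert=True there).
imaging.Processing.key_source = (
    imaging.Processing.key_source - FieldPreprocessing.key_source.proj()
)
```
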


# ---------------- Multi-field Processing (per-field basis) ----------------
@@ -83,9 +80,10 @@ def key_source(self):

def make(self, key):
execution_time = datetime.utcnow()
processed_root_data_dir = scan.get_processed_root_data_dir()

output_dir = (imaging.ProcessingTask & key).fetch1("processing_output_dir")
output_dir = find_full_path(get_imaging_root_data_dir(), output_dir)
output_dir = find_full_path(processed_root_data_dir, output_dir)

method, params = (
imaging.ProcessingTask * imaging.ProcessingParamSet & key
@@ -141,7 +139,9 @@ def make(self, key):
"image_files": [f.as_posix() for f in image_files],
},
},
"processing_output_dir": pln_output_dir,
"processing_output_dir": pln_output_dir.relative_to(
processed_root_data_dir
).as_posix(),
}
)
elif method == "suite2p" and acq_software == "ScanImage" and nrois > 0:
@@ -150,7 +150,7 @@

image_files = (scan.ScanInfo.ScanFile & key).fetch("file_path")
image_files = [
find_full_path(get_imaging_root_data_dir(), image_file).as_posix()
find_full_path(scan.get_imaging_root_data_dir(), image_file).as_posix()
for image_file in image_files
]

@@ -170,6 +170,7 @@

ops.update(
{
"mesoscan": True,
"input_format": "mesoscan",
"nrois": nfields,
"dx": [], # x-offset for each field
@@ -185,7 +186,6 @@
ops["lines"].append(
np.arange(field_info.yslices[0].start, field_info.yslices[0].stop)
)
ops["extra_dj_params"] = {"field_idx": field_idx}

# generate binary files for each field
save_folder = output_dir / ops["save_folder"]
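
For context on the mesoscan fields populated above: as I understand suite2p's `mesoscan` input format, `dx`/`dy` give each field's pixel offset and `lines` lists the scan lines belonging to it, which is how suite2p splits one mesoscope frame into separate fields. A toy sketch with invented geometry:

```python
import numpy as np

# Invented geometry: two fields stacked vertically within one mesoscope frame.
fields = [
    {"dx": 0, "dy": 0, "lines": (0, 512)},       # field 0 -> frame lines 0..511
    {"dx": 0, "dy": 520, "lines": (520, 1032)},  # field 1 -> frame lines 520..1031
]

ops = {"mesoscan": True, "input_format": "mesoscan", "nrois": len(fields),
       "dx": [], "dy": [], "lines": []}
for f in fields:
    ops["dx"].append(f["dx"])
    ops["dy"].append(f["dy"])
    ops["lines"].append(np.arange(*f["lines"]))

print(ops["nrois"], [len(l) for l in ops["lines"]])  # 2 [512, 512]
```
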
@@ -198,13 +198,20 @@
field_processing_tasks = []
for ops_path in ops_paths:
ops = np.load(ops_path, allow_pickle=True).item()
ops["extra_dj_params"]["ops_path"] = ops_path.as_posix()
ops["extra_dj_params"] = {
"ops_path": ops_path.as_posix(),
"field_idx": int(
re.search(r"plane(\d+)", ops_path.parent.name).group(1)
),
}
field_processing_tasks.append(
{
**key,
"field_idx": ops["extra_dj_params"]["field_idx"],
"params": ops,
"processing_output_dir": ops_path.parent.as_posix(),
"processing_output_dir": ops_path.parent.relative_to(
processed_root_data_dir
).as_posix(),
}
)
else:
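
A note on the `field_idx` recovery above: suite2p writes one `plane<N>` folder per field under its save folder (each holding a pickled `ops.npy`, hence the `allow_pickle=True` / `.item()` load), so the field index can be read back from the folder name. A runnable illustration:

```python
import re
from pathlib import Path

# Hypothetical suite2p-style layout: one plane<N> directory per field.
ops_paths = [Path("suite2p/plane0/ops.npy"), Path("suite2p/plane12/ops.npy")]

for ops_path in ops_paths:
    field_idx = int(re.search(r"plane(\d+)", ops_path.parent.name).group(1))
    print(ops_path.parent.name, "->", field_idx)  # plane0 -> 0, plane12 -> 12
```
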
@@ -235,11 +242,11 @@ class FieldProcessing(dj.Computed):
def make(self, key):
execution_time = datetime.utcnow()

output_dir, params = (FieldPreprocessing & key).fetch1(
output_dir, params = (FieldPreprocessing.Field & key).fetch1(
"processing_output_dir", "params"
)
extra_params = params.pop("extra_dj_params", {})
output_dir = find_full_path(get_imaging_root_data_dir(), output_dir)
output_dir = find_full_path(scan.get_imaging_root_data_dir(), output_dir)

acq_software = (scan.Scan & key).fetch1("acq_software")
method = (imaging.ProcessingParamSet * imaging.ProcessingTask & key).fetch1(
@@ -307,15 +314,15 @@ def key_source(self):

def make(self, key):
execution_time = datetime.utcnow()
method = (imaging.ProcessingTask * imaging.ProcessingParamSet & key).fetch1(
"processing_method"
)
method, params = (
imaging.ProcessingTask * imaging.ProcessingParamSet & key
).fetch1("processing_method", "params")

if method == "suite2p":
if method == "suite2p" and params.get("combined", True):
from suite2p import io

output_dir = (imaging.ProcessingTask & key).fetch1("processing_output_dir")
output_dir = find_full_path(get_imaging_root_data_dir(), output_dir)
output_dir = find_full_path(scan.get_imaging_root_data_dir(), output_dir)

io.combined(output_dir / "suite2p", save=True)
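
The new guard only runs the combine step when the paramset enables suite2p's `combined` option; reading it with a default of `True` matches, to the best of my knowledge, suite2p's own default for that setting. A trivial sketch of the gate:

```python
# Hypothetical paramsets; the combine step runs unless "combined" is
# explicitly set to False.
for params in ({}, {"combined": True}, {"combined": False}):
    print(params, "->", params.get("combined", True))
# {} -> True, {'combined': True} -> True, {'combined': False} -> False
```
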

@@ -332,5 +339,6 @@ def make(self, key):
**key,
"processing_time": datetime.utcnow(),
"package_version": "",
}
},
allow_direct_insert=True,
)
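
Finally, the added `allow_direct_insert=True`: DataJoint's auto-populated (Imported/Computed) tables reject `insert()` calls that do not originate from their own `make()` unless this flag is passed, and here the insert happens from `FieldPostProcessing.make()` into what appears to be the upstream `imaging.Processing` table. A schematic reminder with placeholder names:

```python
# Schematic; `imaging.Processing` stands in for whichever auto-populated table
# receives the insert above, and `key` is a placeholder primary key.
imaging.Processing.insert1(
    {**key, "processing_time": datetime.utcnow(), "package_version": ""},
    allow_direct_insert=True,  # permit insertion from outside Processing.make()
)
# Without the flag, DataJoint raises an error for the same call.
```
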
16 changes: 1 addition & 15 deletions element_calcium_imaging/imaging_no_curation.py
@@ -354,21 +354,7 @@ class Processing(dj.Computed):
# Run processing only on Scan with ScanInfo inserted
@property
def key_source(self):
"""Limit the Processing to Scans that have their metadata ingested to the
database."""
ks = ProcessingTask & scan.ScanInfo
per_plane_proc = (
ProcessingTask.aggr(
PerPlaneProcessingTask.proj(), task_count="count(*)", keep_all_rows=True
)
* ProcessingTask.aggr(
PerPlaneProcessing.proj(),
finished_task_count="count(*)",
keep_all_rows=True,
)
& "task_count = finished_task_count"
)
return ks & per_plane_proc
return ProcessingTask & scan.ScanInfo

def make(self, key):
"""
