diff --git a/.github/workflows/build_docs_html.yml b/.github/workflows/build_docs_html.yml
new file mode 100644
index 0000000..0b6d809
--- /dev/null
+++ b/.github/workflows/build_docs_html.yml
@@ -0,0 +1,44 @@
+# This is a basic workflow to help you get started with Actions
+
+name: CI
+
+# Controls when the workflow will run
+on:
+  # Triggers the workflow on push or pull request events, but only for the "main" and "develop" branches
+ push:
+ branches: [ "main", "develop" ]
+ pull_request:
+ branches: [ "main", "develop" ]
+
+ # Allows you to run this workflow manually from the Actions tab
+ workflow_dispatch:
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+ # This workflow contains a single job called "build"
+ build:
+ # The type of runner that the job will run on
+ runs-on: ubuntu-latest
+
+ # Steps represent a sequence of tasks that will be executed as part of the job
+ steps:
+ # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
+ - uses: actions/checkout@v3
+
+ - name: Set up Python 3.7
+ uses: actions/setup-python@v2
+ with:
+ python-version: 3.7
+
+ #- name: Install Dependencies
+ # run: |
+ # python -m pip install --upgrade pip
+ # pip install -r docs/requirements-docs.txt
+ # sudo apt-get install pandoc -y
+
+      - name: Test sphinx-build
+        uses: ammaraskar/sphinx-action@master
+ with:
+ docs-folder: "docs/"
+ build-command: "sphinx-build -W -nT -b dummy ./docs/source build/html"
+
diff --git a/README.md b/README.md
index d37f7d0..d1884a3 100644
--- a/README.md
+++ b/README.md
@@ -6,6 +6,27 @@ An earlier, Matlab-based version of Bento is available [here](https://github.com
## New In This Release
+### Version 0.3.0-beta
+#### Added Features
+- Plug-ins supporting display of MARS, DLC, and SLEAP pose files
+- Support for loading and displaying annotations made in Bento, BORIS, SimBA, and the Caltech Behavior Annotator
+- Plug-ins for event-triggered averaging of neural activity, and for k-means clustering of neurons based on their activity over the full trial
+- Export of your experimental data (neural recording, pose, behavior annotations) to an NWB file
+- Main window now supports jumping to a specific time or frame number in a video
+- Added a button to delete files from a trial (the delete key still works as well)
+- Simplified setup of the conda environment, removing OS-specific environment files
+
+#### Bugs Fixed
+- Scrolling and display of annotations and neural traces could become desynchronized; this has been fixed
+- Editing trials in v0.2.0-beta caused an increment in the trial number; this is now fixed
+- Widgets are now cleared out when new data is loaded
+
+## Getting Started
+
+- For installation instructions, see [Installation Instructions](https://github.com/neuroethology/bento/blob/main/documentation/installation.md)
+- For a detailed step-by-step walkthrough, see the [Tutorial](https://github.com/neuroethology/bento/blob/main/documentation/tutorial.md)
+
+## Previous Release Updates
### Version 0.2.0-beta
#### Added Features
- A plug-in interface to support import and display of pose data
@@ -18,12 +39,6 @@ An earlier, Matlab-based version of Bento is available [here](https://github.com
- On initial startup when no Investigators yet exist, v0.1.0-beta would prompt for the selection of an Investigator anyway.
With this release, it takes you to the "Add Investigators" dialog instead.
- The vertical scaling of annotations has been fixed.
-## Features
-
-## Getting Started
-
-- Please look for the installation instructions at [Installation Instructions](https://github.com/neuroethology/bento/blob/main/documentation/installation.md)
-- Please look for the detailed step by step instructions at [Tutorial](https://github.com/neuroethology/bento/blob/main/documentation/tutorial.md)
## Citation
diff --git a/bento.yml b/bento.yml
new file mode 100755
index 0000000..ef5bd9d
--- /dev/null
+++ b/bento.yml
@@ -0,0 +1,31 @@
+name: bento
+channels:
+ - conda-forge
+ - defaults
+dependencies:
+ - cryptography
+ - gst-plugins-bad
+ - gst-plugins-good
+ - gst-plugins-base
+ - gstreamer
+ - h5py
+ - matplotlib-base
+ - pip
+ - progressbar
+ - pynwb
+ - pyside2
+ - python
+ - qtpy
+ - scikit-learn
+ - scikit-video
+ - sortedcontainers
+ - sqlalchemy
+ - vispy
+ - xlrd
+ - pip:
+ - colour-demosaicing
+ - colour-science
+ - ndx-pose
+ - opencv-python-headless
+ - pymatreader
+ - qimage2ndarray
\ No newline at end of file
diff --git a/bento_mac.yml b/bento_mac.yml
deleted file mode 100644
index f1f523e..0000000
--- a/bento_mac.yml
+++ /dev/null
@@ -1,131 +0,0 @@
-name: bento
-channels:
- - conda-forge
- - defaults
-dependencies:
- - aom=3.2.0
- - brotli=1.0.9
- - brotli-bin=1.0.9
- - bzip2=1.0.8
- - c-ares=1.18.1
- - ca-certificates=2021.10.8
- - cairo=1.16.0
- - certifi=2021.10.8
- - cffi=1.15.0
- - cryptography=35.0.0
- - cycler=0.10.0
- - ffmpeg=4.3.2
- - fontconfig=2.13.1
- - fonttools=4.25.0
- - freetype=2.11.0
- - gettext=0.21.0
- - glib=2.70.0
- - glib-tools=2.70.0
- - gmp=6.2.1
- - gnutls=3.6.15
- - graphite2=1.3.14
- - greenlet=1.1.1
- - harfbuzz=3.1.1
- - hdf5=1.12.1
- - icu=68.2
- - jasper=2.0.14
- - jbig=2.1
- - jpeg=9d
- - kiwisolver=1.3.1
- - krb5=1.19.2
- - lame=3.100
- - lcms2=2.12
- - lerc=3.0
- - libblas=3.9.0
- - libbrotlicommon=1.0.9
- - libbrotlidec=1.0.9
- - libbrotlienc=1.0.9
- - libcblas=3.9.0
- - libclang=11.1.0
- - libcurl=7.79.1
- - libcxx=12.0.0
- - libdeflate=1.8
- - libedit=3.1.20210714
- - libev=4.33
- - libffi=3.4.2
- - libgfortran=5.0.0
- - libgfortran5=9.3.0
- - libglib=2.70.0
- - libiconv=1.16
- - libidn2=2.3.2
- - liblapack=3.9.0
- - liblapacke=3.9.0
- - libllvm11=11.1.0
- - libnghttp2=1.43.0
- - libopenblas=0.3.18
- - libopencv=4.5.3
- - libpng=1.6.37
- - libpq=13.3
- - libprotobuf=3.18.1
- - libssh2=1.10.0
- - libtasn1=4.16.0
- - libtiff=4.3.0
- - libunistring=0.9.10
- - libvpx=1.11.0
- - libwebp-base=1.2.1
- - libxml2=2.9.12
- - libxslt=1.1.33
- - libzlib=1.2.11
- - llvm-openmp=12.0.0
- - lz4-c=1.9.3
- - matplotlib=3.4.3
- - matplotlib-base=3.4.3
- - munkres=1.1.4
- - mysql-common=8.0.27
- - mysql-libs=8.0.27
- - ncurses=6.2
- - nettle=3.7.3
- - nspr=4.32
- - nss=3.71
- - numpy=1.22.2
- - olefile=0.46
- - openh264=2.1.1
- - openjpeg=2.4.0
- - openssl=1.1.1l
- - pcre=8.45
- - pip=21.2.4
- - pixman=0.40.0
- - progressbar=2.5
- - pycparser=2.20
- - pyparsing=3.0.4
- - python=3.9.7
- - python-dateutil=2.8.2
- - python_abi=3.9
- - qtpy=1.11.2
- - readline=8.1
- - scikit-video=1.1.11
- - scipy=1.7.2
- - setuptools=58.0.4
- - six=1.16.0
- - sortedcontainers=2.4.0
- - sqlalchemy=1.4.26
- - sqlite=3.36.0
- - svt-av1=0.8.7
- - tk=8.6.11
- - tornado=6.1
- - tzdata=2021e
- - wheel=0.37.0
- - x264=1!161.3030
- - x265=3.5
- - xlrd=2.0.1
- - xz=5.2.5
- - zlib=1.2.11
- - zstd=1.5.0
- - pip:
- - colour-demosaicing==0.1.6
- - colour-science==0.3.16
- - future==0.18.2
- - h5py==3.5.0
- - imageio==2.10.3
- - opencv-contrib-python-headless==4.5.5.62
- - pillow==9.0.1
- - pymatreader==0.0.25
- - pyside2==5.15.1
- - qimage2ndarray==1.9.0
- - shiboken2==5.15.1
- - xmltodict==0.12.0
diff --git a/bento_ubuntu.yml b/bento_ubuntu.yml
deleted file mode 100755
index 6852e46..0000000
--- a/bento_ubuntu.yml
+++ /dev/null
@@ -1,82 +0,0 @@
-name: bento
-channels:
- - conda-forge
- - defaults
-dependencies:
- - _libgcc_mutex=0.1
- - _openmp_mutex=4.5
- - ca-certificates=2021.10.8
- - certifi=2021.10.8
- - cffi=1.14.6
- - cryptography=35.0.0
- - freetype=2.10.4
- - greenlet=1.1.0
- - jbig=2.1
- - jpeg=9d
- - lcms2=2.12
- - ld_impl_linux-64=2.35.1
- - lerc=2.2.1
- - libblas=3.9.0
- - libcblas=3.9.0
- - libdeflate=1.7
- - libffi=3.3
- - libgcc-ng=9.3.0
- - libgfortran-ng=11.2.0
- - libgfortran5=11.2.0
- - libgomp=9.3.0
- - liblapack=3.9.0
- - libopenblas=0.3.17
- - libpng=1.6.37
- - libstdcxx-ng=9.3.0
- - libtiff=4.3.0
- - libwebp-base=1.2.0
- - lz4-c=1.9.3
- - ncurses=6.3
- - numpy=1.21.1
- - olefile=0.46
- - openjpeg=2.4.0
- - openssl=1.1.1l
- - pip=21.2.4
- - progressbar=2.5
- - pycparser=2.21
- - python=3.9.7
- - python_abi=3.9
- - readline=8.1
- - scikit-video=1.1.11
- - setuptools=58.0.4
- - sortedcontainers=2.4.0
- - sqlalchemy=1.4.22
- - sqlite=3.36.0
- - tk=8.6.11
- - tzdata=2021e
- - wheel=0.37.0
- - xlrd=2.0.1
- - xz=5.2.5
- - zlib=1.2.11
- - zstd=1.5.0
- - pip:
- - colour-demosaicing==0.1.6
- - colour-science==0.3.16
- - cycler==0.11.0
- - fonttools==4.28.2
- - future==0.18.2
- - h5py==3.6.0
- - imageio==2.12.0
- - kiwisolver==1.3.2
- - matplotlib==3.5.0
- - opencv-contrib-python-headless==4.5.4.60
- - packaging==21.3
- - pillow==8.4.0
- - pymatreader==0.0.26
- - pyparsing==3.0.6
- - pyside2==5.15.2
- - python-dateutil==2.8.2
- - qimage2ndarray==1.9.0
- - qtpy==1.11.2
- - scipy==1.7.3
- - setuptools-scm==6.3.2
- - shiboken2==5.15.2
- - six==1.16.0
- - tomli==1.2.2
- - xmltodict==0.12.0
-
diff --git a/bento_windows.yml b/bento_windows.yml
deleted file mode 100755
index 1cf1a25..0000000
--- a/bento_windows.yml
+++ /dev/null
@@ -1,82 +0,0 @@
-name: bento
-channels:
- - conda-forge
- - defaults
-dependencies:
- - ca-certificates=2021.10.8
- - certifi=2021.10.8
- - cffi=1.15.0
- - cryptography=36.0.0
- - freetype=2.10.4
- - greenlet=1.1.2
- - intel-openmp=2021.4.0
- - jbig=2.1
- - jpeg=9d
- - lcms2=2.12
- - lerc=3.0
- - libblas=3.9.0
- - libcblas=3.9.0
- - libdeflate=1.8
- - liblapack=3.9.0
- - libpng=1.6.37
- - libtiff=4.3.0
- - libzlib=1.2.11
- - lz4-c=1.9.3
- - m2w64-gcc-libgfortran=5.3.0
- - m2w64-gcc-libs=5.3.0
- - m2w64-gcc-libs-core=5.3.0
- - m2w64-gmp=6.1.0
- - m2w64-libwinpthread-git=5.0.0.4634.697f757
- - mkl=2021.4.0
- - msys2-conda-epoch=20160418
- - numpy=1.21.4
- - olefile=0.46
- - openjpeg=2.4.0
- - openssl=1.1.1l
- - pillow=8.4.0
- - pip=21.2.4
- - progressbar=2.5
- - pycparser=2.21
- - python=3.9.7
- - python_abi=3.9
- - scikit-video=1.1.11
- - scipy=1.7.2
- - setuptools=58.0.4
- - sortedcontainers=2.4.0
- - sqlalchemy=1.4.27
- - sqlite=3.36.0
- - tbb=2021.4.0
- - tk=8.6.11
- - tzdata=2021e
- - vc=14.2
- - vs2015_runtime=14.27.29016
- - wheel=0.37.0
- - wincertstore=0.2
- - xlrd=2.0.1
- - xz=5.2.5
- - zlib=1.2.11
- - zstd=1.5.0
- - pip:
- - colour-demosaicing==0.1.6
- - colour-science==0.3.16
- - cycler==0.11.0
- - fonttools==4.28.2
- - future==0.18.2
- - h5py==3.6.0
- - imageio==2.11.1
- - kiwisolver==1.3.2
- - matplotlib==3.5.0
- - opencv-contrib-python-headless==4.5.4.60
- - packaging==21.3
- - pymatreader==0.0.26
- - pyparsing==3.0.6
- - pyside2==5.15.2
- - python-dateutil==2.8.2
- - qtpy==1.11.2
- - qimage2ndarray==1.9.0
- - setuptools-scm==6.3.2
- - shiboken2==5.15.2
- - shiboken6==6.2.1
- - six==1.16.0
- - tomli==1.2.2
- - xmltodict==0.12.0
diff --git a/documentation/_gifs/adding_trials.gif b/documentation/_gifs/adding_trials.gif
index 953f210..0eb9388 100644
Binary files a/documentation/_gifs/adding_trials.gif and b/documentation/_gifs/adding_trials.gif differ
diff --git a/documentation/installation.md b/documentation/installation.md
index 8727420..2b9b811 100644
--- a/documentation/installation.md
+++ b/documentation/installation.md
@@ -21,10 +21,9 @@ cd path_to_bento_folder
- Ex : cd /Users/KennedyLab/all_codes/bento
5. Execute the following command to install all the dependencies/packages required for Bento.
- - Note : filename in the command should be replaced by bento_windows.yml or bento_mac.yml or bento_ubuntu.yml (based on your OS)
```
-conda env create -f filename
+conda env create -f bento.yml
```
6. Execute the following two commands to open the Bento User Interface:
diff --git a/documentation/tutorial.md b/documentation/tutorial.md
index deb2033..17ad0d2 100644
--- a/documentation/tutorial.md
+++ b/documentation/tutorial.md
@@ -89,10 +89,16 @@ Now that you have experimenters, animals, and cameras on record, you can start u
3. Select a session in the **Select Session** table by clicking the corresponding row.
4. Click on **Add New Trial...** button. **Add or Edit Trial dialog** will pop up.
5. Add Stimulus in the **Stimulus** field.
-6. You can add **Video files**, **Annotation files** and **Neural Files** in the window. **Pose files** and **Audio files** are not yet supported.
+6. You can add **Video files**, **Pose files**, **Annotation files**, and **Neural files** in this window. **Audio files** are not yet supported.
+- Note: **Pose files** can be added only if a video file is present.
7. Click **OK** button. You will see a trial added in the **Trial** table along with files you selected for the trial.
8. You can add multiple trials under the same session. Repeat steps 1-7, every time you need to add a trial under a particular session.
+#### Adding a Pose file
+1. Click on a **Video file**, then click the **Add Pose...** button.
+2. Select the appropriate format for the file in the file selection window. Bento supports MARS, SLEAP, and DeepLabCut files.
+3. You can add one pose file to each video file present in a trial.
+
![alt-text](_gifs/adding_trials.gif)
@@ -115,7 +121,7 @@ Now that you have experimenters, animals, and cameras on record, you can start u
> Note : Assuming you are a loading a trial with no annotation file in it. You can also load a trial with an exisiing annotation file and do all the steps.
-1. Load a trial as mentioned in [Loading data from a trial](https://github.com/neuroethology/bento/blob/feature/documentation/documentation/tutorial.md#Loading-data-from-a-trial).
+1. Load a trial as mentioned in [Loading data from a trial](https://github.com/neuroethology/bento/blob/main/documentation/tutorial.md#Loading-data-from-a-trial).
2. In the **Main Window**, click on **New Channel** button. **New Channel Dialog** will pop up.
3. Add any relevant channel name in the **Channel Name** text box and click **OK** button.
4. You will see a channel added in the drop down box left to the **New Channel** button.
@@ -150,8 +156,8 @@ Now that you have experimenters, animals, and cameras on record, you can start u
> Note : Assuming you are a loading a trial with no annotation file in it. You can also load a trial with an exisiing annotation file and do all the steps.
-1. Add a channel as mentioned in [Creating a new annotation channel](https://github.com/neuroethology/bento/blob/feature/documentation/documentation/tutorial.md#Creating-a-new-annotation-channel)
-2. [Create new behaviors](https://github.com/neuroethology/bento/blob/feature/documentation/documentation/tutorial.md#Creating-a-new-behavior) or [edit behavior properties](https://github.com/neuroethology/bento/blob/feature/documentation/documentation/tutorial.md#Editing-behavior-properties) based on your requirement for adding annotations.
+1. Add a channel as mentioned in [Creating a new annotation channel](https://github.com/neuroethology/bento/blob/main/documentation/tutorial.md#Creating-a-new-annotation-channel)
+2. [Create new behaviors](https://github.com/neuroethology/bento/blob/main/documentation/tutorial.md#Creating-a-new-behavior) or [edit behavior properties](https://github.com/neuroethology/bento/blob/main/documentation/tutorial.md#Editing-behavior-properties) based on your requirement for adding annotations.
3. For adding annotations, select the channel for which you need to add annotations.
4. Click on the rectangular annotations view, press the **hot_key** corresponding to the behavior.
5. After pressing the **hot_key**, drag the cursor in the view for the length of time you need to add an annotation.
diff --git a/src/annot/annot.py b/src/annot/annot.py
index c5ee4d8..c8d18f7 100644
--- a/src/annot/annot.py
+++ b/src/annot/annot.py
@@ -1,11 +1,20 @@
# annot.py
+from __future__ import annotations
from random import sample
import timecode as tc
from annot.behavior import Behavior
from sortedcontainers import SortedKeyList
from qtpy.QtCore import QObject, QRectF, Signal, Slot
from qtpy.QtGui import QColor
-from qtpy.QtWidgets import QGraphicsItem
+from qtpy.QtWidgets import QGraphicsItem, QMessageBox
+from dataExporter import DataExporter
+from datetime import datetime
+from pynwb import NWBFile
+from pynwb.epoch import TimeIntervals
+import csv
+import numpy as np
+import pandas as pd
+import math
class Bout(object):
"""
@@ -76,6 +85,7 @@ class Channel(QGraphicsItem):
def __init__(self, chan = None):
super().__init__()
+
if not chan is None:
self._bouts_by_start = chan._bouts_by_start
self._bouts_by_end = chan._bouts_by_end
@@ -286,45 +296,97 @@ def coalesce_bouts(self, start, end):
for item in to_delete:
self.remove(item)
-class Annotations(QObject):
+ def exportToNWBFile(self, chanName: str, nwbFile: NWBFile):
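+        # drop any interval set previously exported for this channel,
+        # so that re-exporting after edits does not duplicate data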
+ if f"annotation_{chanName}" in nwbFile.intervals:
+ nwbFile.intervals.pop(f"annotation_{chanName}")
+
+ annotationData = TimeIntervals(name=f"annotation_{chanName}",
+ description="animal behavior annotations")
+ if len(self._bouts_by_start)>0:
+ annotationData.add_column(name="behaviorName",
+ description="type of behavior")
+
+ for bout in self._bouts_by_start:
+ annotationData.add_row(start_time=bout.start().float, stop_time=bout.end().float, behaviorName=bout.name())
+
+ nwbFile.add_time_intervals(annotationData)
+
+ return nwbFile
+
+
+class Annotations(QObject, DataExporter):
"""
"""
+ # backward-compatible change, so only update minor version number
+ _current_annotation_version_ = 1.1
# Signals
annotations_changed = Signal()
active_annotations_changed = Signal()
def __init__(self, behaviors):
- super().__init__()
+ QObject.__init__(self)
+ DataExporter.__init__(self)
+ self.dataExportType = "annotations"
self._channels = {}
self._behaviors = behaviors
- self._movies = []
+ self._version = self._current_annotation_version_
self._start_frame = None
self._end_frame = None
self._sample_rate = None
- self._stimulus = None
self._format = None
+ self._start_date_time = None
+ self._offset_time = tc.Timecode('30.0', '0:0:0:0')
self.annotation_names = []
behaviors.behaviors_changed.connect(self.note_annotations_changed)
+
def read(self, fn):
+ if self._validate_bento(fn):
+ self._format = "Bento"
+ self._read_bento(fn)
+ elif self._validate_caltech(fn):
+ self._format = 'Caltech'
+ self._read_caltech(fn)
+ elif self._validate_boris(fn):
+ self._format = 'Boris'
+ self._read_boris(fn)
+ elif self._validate_simba(fn):
+ self._format = 'SimBA'
+ self._read_simba(fn)
+ else:
+            QMessageBox.about(None, "File Read Error", "Unsupported annotation file format")
+ return
+
+ def _validate_bento(self, fn):
with open(fn, "r") as f:
line = f.readline()
line = line.strip().lower()
- if line.endswith("annotation file"):
- self._format = 'Caltech'
- self._read_caltech(f)
- elif line.startswith("scorevideo log"):
- self._format = 'Ethovision'
- self._read_ethovision(f)
- else:
- print("Unsupported annotation file format")
-
- def _read_caltech(self, f):
- found_movies = False
+ return line.startswith('bento annotation file')
+
+ def _validate_boris(self, fn):
+ with open(fn) as csv_file:
+ reader = csv.reader(csv_file)
+ row = next(reader)
+ return row[0].lower() == 'observation id'
+
+ def _validate_simba(self, fn):
+ with open(fn) as csv_file:
+ reader = csv.reader(csv_file)
+ row = next(reader)
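+            # SimBA CSVs begin with tracking columns named like <part>_x, <part>_y, <part>_p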
+ return '_x' in row[1] and '_y' in row[2] and '_p' in row[3]
+
+ def _validate_caltech(self, fn):
+ header = 'Caltech Behavior Annotator - Annotation File'
+ with open(fn, "r") as f:
+ lines = f.read().splitlines()
+ return lines[0].rstrip().lower() == header.lower()
+
+ def _read_bento(self, fn):
found_timecode = False
- found_stimulus = False
found_channel_names = False
+ found_version = False
+ found_start_date_time = False
found_annotation_names = False
found_all_channels = False
found_all_annotations = False
@@ -335,121 +397,289 @@ def _read_caltech(self, f):
current_channel = None
current_bout = None
- self._format = 'Caltech'
-
- line = f.readline()
- while line:
- if found_annotation_names and not new_behaviors_activated:
- self.ensure_and_activate_behaviors(to_activate)
- new_behaviors_activated = True
-
- line.strip()
-
- if not line:
- if reading_channel:
- reading_channel = False
- current_channel = None
- current_bout = None
- elif line.lower().startswith("movie file"):
- items = line.split()
- for item in items:
- if item.lower().startswith("movie"):
- continue
- if item.lower().startswith("file"):
- continue
- self._movies.append(item)
- found_movies = True
- elif line.lower().startswith("stimulus name"):
- # TODO: do something when we know what one of these looks like
- found_stimulus = True
- elif line.lower().startswith("annotation start frame"):
- items = line.split()
- if len(items) > 3:
- self._start_frame = int(items[3])
- if self._end_frame and self._sample_rate:
- found_timecode = True
- elif line.lower().startswith("annotation stop frame"):
- items = line.split()
- if len(items) > 3:
- self._end_frame = int(items[3])
- if self._start_frame and self._sample_rate:
- found_timecode = True
- elif line.lower().startswith("annotation framerate"):
- items = line.split()
- if len(items) > 2:
- self._sample_rate = float(items[2])
- if self._start_frame and self._end_frame:
- found_timecode = True
- elif line.lower().startswith("list of channels"):
- line = f.readline()
- while line:
- line = line.strip()
- if not line:
- break # blank line -- end of section
- channel_names.append(line)
+ with open(fn, "r") as f:
+ line = f.readline()
+ while line:
+ if found_annotation_names and not new_behaviors_activated:
+ self.ensure_and_activate_behaviors(to_activate)
+ new_behaviors_activated = True
+
+ line.strip()
+
+ if not line:
+ if reading_channel:
+ reading_channel = False
+ current_channel = None
+ current_bout = None
+ elif line.lower().startswith("bento annotation file v"):
+ items = line.split()
+ if len(items)>3:
+ version = str(items[3])
+ if version[0] == "v":
+ version = version[1:]
+ self._version = version
+ found_version = True
+ elif line.lower().startswith("start date time"):
+ items = line.split()
+ if len(items)>3:
+ self._start_date_time = datetime.fromisoformat(items[-2]+' '+items[-1])
+ found_start_date_time = True
+ elif line.lower().startswith("annotation start frame"):
+ items = line.split()
+ if len(items) > 3:
+ self._start_frame = int(items[3])
+ if self._end_frame and self._sample_rate:
+ found_timecode = True
+ elif line.lower().startswith("annotation stop frame"):
+ items = line.split()
+ if len(items) > 3:
+ self._end_frame = int(items[3])
+ if self._start_frame and self._sample_rate:
+ found_timecode = True
+ elif line.lower().startswith("annotation framerate"):
+ items = line.split()
+ if len(items) > 2:
+ self._sample_rate = float(items[2])
+ if self._start_frame and self._end_frame:
+ found_timecode = True
+ elif line.lower().startswith("list of channels"):
line = f.readline()
- found_channel_names = True
- elif line.lower().startswith("list of annotations"):
- line = f.readline()
- while line:
- line = line.strip()
- if not line:
- break # blank line -- end of section
- to_activate.append(line)
- line = f.readline().strip()
- found_annotation_names = True
- elif line.strip().lower().endswith("---"):
- for ch_name in channel_names:
- if line.startswith(ch_name):
- self._channels[ch_name] = Channel()
- current_channel = ch_name
- reading_channel = True
- break
- if reading_channel:
- reading_annot = False
+ while line:
+ line = line.strip()
+ if not line:
+ break # blank line -- end of section
+ channel_names.append(line)
+ line = f.readline()
+ found_channel_names = True
+ elif line.lower().startswith("list of annotations"):
line = f.readline()
while line:
line = line.strip()
- if not line: # blank line
- if reading_annot:
- reading_annot = False
- current_bout = None
+ if not line:
+ break # blank line -- end of section
+ to_activate.append(line)
+ line = f.readline().strip()
+ found_annotation_names = True
+ elif line.strip().lower().endswith("---"):
+ for ch_name in channel_names:
+ if line.startswith(ch_name):
+ ix = len(self._channels)
+ self._channels[ch_name] = Channel()
+ current_channel = ch_name
+ reading_channel = True
+ break
+ if reading_channel:
+ reading_annot = False
+ line = f.readline()
+ while line:
+ line = line.strip()
+ if not line: # blank line
+ if reading_annot:
+ reading_annot = False
+ current_bout = None
+ else:
+ # second blank line, so done with channel
+ reading_channel = False
+ current_channel = None
+ break
+ elif line.startswith(">"):
+ current_bout = line[1:]
+ reading_annot = True
+ elif line.lower().startswith("start"):
+ pass # skip a header line
else:
- # second blank line, so done with channel
- reading_channel = False
- current_channel = None
- break
- elif line.startswith(">"):
- current_bout = line[1:]
- reading_annot = True
- elif line.lower().startswith("start"):
- pass # skip a header line
- else:
- items = line.split()
- is_float = '.' in items[0] or '.' in items[1] or '.' in items[2]
- self.add_bout(
- Bout(
- tc.Timecode(self._sample_rate, start_seconds=float(items[0])) if is_float
- else tc.Timecode(self._sample_rate, frames=int(items[0])),
- tc.Timecode(self._sample_rate, start_seconds=float(items[1])) if is_float
- else tc.Timecode(self._sample_rate, frames=int(items[1])),
- self._behaviors.get(current_bout)),
+ items = line.split()
+ is_float = '.' in items[0] or '.' in items[1] or '.' in items[2]
+ self.add_bout(
+ Bout(
+ tc.Timecode(self._sample_rate, start_seconds=self._offset_time.float+float(items[0])) if is_float
+ else tc.Timecode(self._sample_rate, frames=self._offset_time.frames+int(items[0])),
+ tc.Timecode(self._sample_rate, start_seconds=self._offset_time.float+float(items[1])) if is_float
+ else tc.Timecode(self._sample_rate, frames=self._offset_time.frames+int(items[1])),
+ self._behaviors.get(current_bout)),
+ current_channel)
+ line = f.readline()
+ line = f.readline()
+ print(f"Done reading Bento annotation file {f.name}")
+ self.note_annotations_changed()
+
+ def _read_caltech(self, f):
+ header = 'Caltech Behavior Annotator - Annotation File'
+ conf = 'Configuration file:'
+ fid = open(f)
+ lines = fid.read().splitlines()
+ fid.close()
+        nFrames = 0  # total frame count; set while parsing the first stream
+ # check the header
+ assert lines[0].rstrip() == header
+ assert lines[1].rstrip() == ''
+ assert lines[2].rstrip() == conf
+ # parse action list
+ line = 3
+ behavior_names = [None] * 1000
+ behaviors = []
+ k = -1
+
+ #get config keys and names
+ while True:
+ lines[line] = lines[line].rstrip()
+ if not isinstance(lines[line], str) or not lines[line]:
+ line+=1
+ break
+ values = lines[line].split()
+ k += 1
+ behavior_names[k] = values[0].lower()
+ line += 1
+ behavior_names = behavior_names[:k+1]
+
+ #read in each stream in turn until end of file
+ bnds0 = [None]*10000
+ actions0 = [None]*10000
+ nStrm1 = 0
+ while True:
+ lines[line] = lines[line].rstrip()
+ nStrm1 += 1
+ t = lines[line].split(":")
+ line += 1
+ lines[line] = lines[line].rstrip()
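+            # stream header lines start with a token like "S1", so t[0][1] is the stream number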
+ assert int(t[0][1]) == nStrm1
+ assert lines[line] == '-----------------------------'
+ current_channel = f'Ch{int(nStrm1)}'
+ self._channels[current_channel] = Channel()
+ line += 1
+ bnds1 = np.ones((10000, 2),dtype=int)
+ k = 0
+ # start the annotations
+ while True:
+ lines[line] = lines[line].rstrip()
+ t = lines[line]
+ if not isinstance(t, str) or not t:
+ line += 1
+ break
+ t = lines[line].split()
+                bhvr_ind = [i for i in range(len(behavior_names)) if t[2].lower() == behavior_names[i]]
+                if not bhvr_ind:
+                    print('undefined behavior ' + t[2])
+                bhvr_ind = bhvr_ind[0] if bhvr_ind else None
+                if k > 0 and bnds1[k-1,1]+1 != int(t[0])-1:
+                    print('%d ~= %d' % (bnds1[k-1,1] + 1, int(t[0]) - 1))
+ bnds1[k,:] = [int(t[0]) - 1, int(t[1]) - 1] # added -1 so that we're 0-indexing!
+ #actions1[k] = behavior_names[bhvr_ind]
+ start, end, bev = int(t[0])-1, int(t[1])-1, str(t[2].lower())
+ if bev != 'other' and self.sample_rate():
+ is_float = isinstance(start, float) or isinstance(end, float)
+ if bev not in behaviors:
+ behaviors.append(bev)
+ self.ensure_and_activate_behaviors(behaviors)
+ self.add_bout(Bout(
+ tc.Timecode(self._sample_rate, start_seconds=self._offset_time.float+float(start)) if is_float
+ else tc.Timecode(self._sample_rate, frames=self._offset_time.frames+int(start)),
+ tc.Timecode(self._sample_rate, start_seconds=self._offset_time.float+float(end)) if is_float
+ else tc.Timecode(self._sample_rate, frames=self._offset_time.frames+int(end)),
+ self._behaviors.get(bev)),
current_channel)
- line = f.readline()
- line = f.readline()
- print(f"Done reading Caltech annotation file {f.name}")
+ k += 1
+ line += 1
+ if line == len(lines):
+ break
+
+            if nStrm1 == 1:
+                nFrames = bnds1[k-1, 1] + 1
+            else:
+                # every stream should span the same number of frames
+                assert nFrames == bnds1[k-1, 1] + 1
+ if line == len(lines):
+ break
+ while not lines[line]:
+ line += 1
+ self._start_frame, self._end_frame = 1, nFrames
+ self.note_annotations_changed()
+
+ def _read_boris(self, f):
+ self._channels['Ch1'] = Channel()
+ current_channel = 'Ch1'
+ to_activate = []
+ with open(f) as csv_file:
+ reader = csv.reader(csv_file)
+ header_flag = 1
+ nFrames = 0
+ for row in reader:
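+                # skip preamble rows until the real header row is reached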
+ if header_flag and not (row[0] == 'Time' and row[1] == 'Media file path' and row[2] == 'Total length'):
+ continue
+ elif header_flag:
+ header_flag = 0
+ continue
+
+ fps = float(row[3])
+ self._sample_rate = fps
+ if not nFrames:
+ nFrames = math.ceil(float(row[2]) * fps) + 1
+ self._start_frame, self._end_frame = 1, nFrames
+ bev = row[-4]
+ if bev not in to_activate:
+ to_activate.append(bev)
+ if row[-1] == 'POINT':
+ start_end = [round(float(row[0]) * fps), round(float(row[0]) * fps)+1]
+ elif row[-1] == 'START':
+ start_end = [round(float(row[0]) * fps), 0]
+ elif row[-1] == 'STOP':
+ start_end[1] = round(float(row[0]) * fps)
+
+ if start_end[1] > 0:
+ is_float = isinstance(start_end[0], float) or isinstance(start_end[1], float)
+ self.add_bout(Bout(
+ tc.Timecode(self._sample_rate, start_seconds=self._offset_time.float+float(start_end[0])) if is_float
+ else tc.Timecode(self._sample_rate, frames=self._offset_time.frames+int(start_end[0])),
+ tc.Timecode(self._sample_rate, start_seconds=self._offset_time.float+float(start_end[1])) if is_float
+ else tc.Timecode(self._sample_rate, frames=self._offset_time.frames+int(start_end[1])),
+ self._behaviors.get(bev)),
+ current_channel)
+ else:
+ continue
+ self.ensure_and_activate_behaviors(to_activate)
self.note_annotations_changed()
- def write_caltech(self, f, video_files, stimulus):
+
+ def _read_simba(self, f):
+ self._channels['Ch1'] = Channel()
+ current_channel = 'Ch1'
+ with open(f) as csv_file:
+ reader = csv.reader(csv_file)
+ ordered_headers = next(reader)
+ behaviors = []
+ df = pd.read_csv(f)
+
+ # figure out the behaviors scored
+ while True:
+            if 'Low_prob' in ordered_headers[-1]: # 'Low_prob' columns may themselves be binary, so stop on the name
+ break
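+            # x == x is False only for NaN, so missing values are ignored below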
+ if any([x not in [0, 1] for x in df[ordered_headers[-1]] if x == x]):
+ break
+ behaviors.append(ordered_headers[-1])
+ del ordered_headers[-1]
+
+ self.ensure_and_activate_behaviors(behaviors)
+ self._start_frame, self._end_frame = 1, df.index[-1] + 1
+
+ for b in behaviors:
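+            # a 0->1 transition in the per-frame score marks a bout start;
+            # a 1->0 transition marks its end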
+ start = ([0] if df[b][0] else []) + df.index[df[b].diff() == 1].tolist()
+ end = [i for i in df.index[df[b].diff()==-1].tolist()] + ([df.index[-1]] if df[b][df.index[-1]]==1 else [])
+ start_end = np.column_stack((start, end))
+ if self.sample_rate():
+ for start, end in start_end:
+ is_float = isinstance(start, float) or isinstance(end, float)
+ self.add_bout(Bout(
+ tc.Timecode(self._sample_rate, start_seconds=self._offset_time.float+float(start)) if is_float
+ else tc.Timecode(self._sample_rate, frames=self._offset_time.frames+int(start)),
+ tc.Timecode(self._sample_rate, start_seconds=self._offset_time.float+float(end)) if is_float
+ else tc.Timecode(self._sample_rate, frames=self._offset_time.frames+int(end)),
+ self._behaviors.get(b)),
+ current_channel)
+
+ def write_caltech(self, f, video_files):
if not f.writable():
raise RuntimeError("File not writable")
- f.write("Bento annotation file\n")
- f.write("Movie file(s):")
- for file in video_files:
- f.write(' ' + file)
- f.write('\n\n')
-
- f.write(f"Stimulus name: {stimulus}\n")
+ f.write(f"Bento annotation file v{str(self._current_annotation_version_)}\n")
+ f.write('\n')
+ f.write(f"Start date time: {str(self._start_date_time)}\n")
f.write(f"Annotation start frame: {self._start_frame}\n")
f.write(f"Annotation stop frame: {self._end_frame}\n")
f.write(f"Annotation framerate: {self._sample_rate}\n")
@@ -478,15 +708,15 @@ def write_caltech(self, f, video_files, stimulus):
f.write(f">{annot}\n")
f.write("Start\tStop\tDuration\n")
for bout in by_name[annot]:
- start = bout.start().frames
- end = bout.end().frames
+ start = bout.start().float
+ end = bout.end().float
f.write(f"{start}\t{end}\t{end - start}\n")
f.write("\n")
f.write("\n")
f.close()
- print(f"Done writing Caltech annotation file {f.name}")
+ print(f"Done writing Bento annotation file {f.name}")
def _read_ethovision(self, f):
print("Ethovision annotations not yet supported")
@@ -511,6 +741,18 @@ def add_bout(self, bout, channel):
if bout.end() > self.end_time():
self.set_end_frame(bout.end())
+ def set_start_date_time(self, dt):
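+        # dt is a POSIX timestamp in seconds; subtracting the annotation
+        # offset recovers the actual start of the recording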
+ if isinstance(dt, float):
+ self._start_date_time = datetime.fromtimestamp(dt-self._offset_time.float).isoformat(sep=' ', timespec='milliseconds')
+ else:
+ raise TypeError("Expected start date time in seconds.")
+
+ def start_date_time(self):
+ return self._start_date_time
+
+ def set_offset_time(self, t):
+ self._offset_time = t
+
def start_time(self):
"""
At some point we will need to support a start time distinct from
@@ -609,4 +851,11 @@ def coalesce_bouts(self, start, end, chan):
@Slot()
def note_annotations_changed(self):
- self.annotations_changed.emit()
\ No newline at end of file
+ self.annotations_changed.emit()
+
+ def exportToNWBFile(self, nwbFile: NWBFile):
+ print(f"Export data from {self.dataExportType} to NWBFile")
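+        # each annotation channel becomes its own TimeIntervals table in the file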
+ for chanName in self._channels:
+ nwbFile = self._channels[chanName].exportToNWBFile(chanName, nwbFile)
+
+ return nwbFile
diff --git a/src/annot/behavior.py b/src/annot/behavior.py
index ee6dc98..b86593a 100644
--- a/src/annot/behavior.py
+++ b/src/annot/behavior.py
@@ -168,7 +168,7 @@ def save(self, f):
if h == '':
h = '_'
color = beh.get_color()
- f.write(f"{h} {beh.get_name()} {color.redF()} {color.greenF()} {color.blueF()}" + os.linesep)
+ f.write(f"{h} {beh.get_name()} {color.redF()} {color.greenF()} {color.blueF()}" + "\n")
def get(self, name):
if name not in self._by_name.keys():
diff --git a/src/bento.py b/src/bento.py
index 95d6432..9f1d7ee 100644
--- a/src/bento.py
+++ b/src/bento.py
@@ -1,16 +1,18 @@
# bento.py
-# import faulthandler
-# faulthandler.enable()
+import faulthandler
+
+faulthandler.enable()
from timecode import Timecode
-from qtpy.QtCore import QMarginsF, QObject, QRectF, QTimer, Qt, Signal, Slot
+from qtpy.QtCore import QMarginsF, QObject, QRectF, Qt, Signal, Slot
from qtpy.QtGui import QColor
from qtpy.QtWidgets import QApplication, QFileDialog, QMessageBox, QProgressDialog
from annot.annot import Annotations, Bout
from annot.behavior import Behaviors
from mainWindow import MainWindow
+from timeSource import TimeSourceAbstractBase, TimeSourceQMediaPlayer, TimeSourceQTimer
from video.videoWindow import VideoFrame
from widgets.annotationsWidget import AnnotationsScene
-from db.schema_sqlalchemy import (AnnotationsData, Investigator, Session, Trial,
+from db.schema_sqlalchemy import (Animal, AnnotationsData, Investigator, Session, Trial,
VideoData, new_session, create_tables)
from db.investigatorDialog import InvestigatorDialog
from db.animalDialog import AnimalDialog
@@ -25,64 +27,86 @@
from pose.pose import PoseRegistry
from channelDialog import ChannelDialog
from os.path import expanduser, isabs, sep, relpath, splitext
+from dataExporter import DataExporter
+from pynwb import NWBFile, NWBHDF5IO
+from pynwb.file import Subject
from utils import fix_path, padded_rectf
-import sys, traceback, time
+import sys, traceback, time, itertools
+from datetime import datetime, timezone
+from dateutil.tz import tzlocal
+import numpy as np
class Player(QObject):
def __init__(self, bento):
super().__init__()
- self.playing = False
- self.timer = QTimer()
- self.frame_interval = self.default_frame_interval = 1000./30.
- self.timer.setInterval(round(self.frame_interval))
- self.timer.timeout.connect(bento.incrementTime)
+ self._playing = False
+ self._timeSource = None
@Slot()
def togglePlayer(self):
- # print(f"Setting playing to {not self.playing}")
- self.playing = not self.playing
- if self.playing:
- self.timer.start()
+ if not self._timeSource:
+ return
+ self._playing = not self._playing
+ if self._playing:
+ self._timeSource.start()
else:
- self.timer.stop()
+ self._timeSource.stop()
@Slot()
def doubleFrameRate(self):
- if self.frame_interval > self.default_frame_interval / 8.:
- self.frame_interval /= 2.
- print(f"setting frame interval to {round(self.frame_interval)}")
- self.timer.setInterval(round(self.frame_interval))
+ if self._timeSource:
+ self._timeSource.doubleFrameRate()
@Slot()
def halveFrameRate(self):
- if self.frame_interval < self.default_frame_interval * 8.:
- self.frame_interval *= 2.
- print(f"setting frame interval to {round(self.frame_interval)}")
- self.timer.setInterval(round(self.frame_interval))
+ if self._timeSource:
+ self._timeSource.halveFrameRate()
@Slot()
def resetFrameRate(self):
- self.frame_interval = self.default_frame_interval
- print(f"resetting frame interval to {round(self.frame_interval)}")
- self.timer.setInterval(round(self.frame_interval))
+ if self._timeSource:
+ self._timeSource.resetFrameRate()
@Slot()
def quit(self):
- if self.timer.isActive():
- self.timer.stop()
+ if self._timeSource:
+ self._timeSource.quit()
+
+ def setTimeSource(self, timeSource: TimeSourceAbstractBase):
+ self._timeSource = timeSource
+
+ def timeSource(self) -> TimeSourceAbstractBase:
+ return self._timeSource
+
+ def currentTime(self) -> Timecode:
+ if self._timeSource:
+ return self._timeSource.currentTime()
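+        # with no time source attached, fall back to a default 30 fps timecode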
+        return Timecode("30.0", "0:0:0:1")
+
+ def setCurrentTime(self, t: Timecode):
+ if self._timeSource:
+ self._timeSource.setCurrentTime(t)
-class Bento(QObject):
+class Bento(QObject, DataExporter):
"""
Bento - class representing core machinery (no UI)
"""
def __init__(self):
- super().__init__()
+ QObject.__init__(self)
+ DataExporter.__init__(self)
+ self.dataExportType = "bento"
self.config = BentoConfig()
goodConfig = self.config.read()
self.time_start = Timecode('30.0', '0:0:0:1')
self.time_end = Timecode('30.0', '23:59:59:29')
- self.current_time = self.time_start
+ self.time_start_end = {
+ 'video':[],
+ 'annotations':[],
+ 'neural':[]
+        } # 'data_type' : [[start_time, end_time], ...]
+ self.time_start_end_timecode = dict()
+ self.min_max_times = list()
self.investigator_id = None
self.current_annotations = [] # tuples ('ch_key', bout)
self.behaviors = Behaviors()
@@ -91,22 +115,25 @@ def __init__(self):
self.loadBehaviors()
self.behaviorsDialog = BehaviorsDialog(self)
self.behaviorsDialog.show()
-
+ self.nwbFile = None
self.session_id = None
self.trial_id = None
+ self.trial_start_time = None
self.player = Player(self)
self.annotationsScene = AnnotationsScene()
self.newAnnotations = False
- self.video_widgets = []
- self.neural_widgets = []
+ self.widgets = []
+ self.annotations_format = ['Bento', 'SimBA', 'Boris', 'Caltech']
self.annotations = Annotations(self.behaviors)
self.annotations.annotations_changed.connect(self.noteAnnotationsChanged)
+ self.annotations.annotations_changed.connect(self.updateNWBFile)
+ self.newChannelAdded.connect(self.updateNWBFile)
self.pose_registry = PoseRegistry()
self.pose_registry.load_plugins()
self.mainWindow = MainWindow(self)
- self.current_time.set_fractional(False)
self.active_channels = []
self.quitting.connect(self.player.quit)
+ self.timeChanged.connect(self.noteTimeChanged)
self.timeChanged.connect(self.mainWindow.updateTime)
self.currentAnnotsChanged.connect(self.mainWindow.updateAnnotLabel)
self.active_channel_changed.connect(self.mainWindow.selectChannelByName)
@@ -131,7 +158,7 @@ def __init__(self):
if not investigators:
self.edit_investigator()
self.set_investigator()
-
+
self.investigator_id = self.config.investigator_id()
self.mainWindow.show()
@@ -140,7 +167,7 @@ def setInvestigatorId(self, investigator_id):
self.config.set_investigator_id(investigator_id)
self.config.write()
- def load_or_init_annotations(self, fn, sample_rate = 30., running_time = None):
+ def load_or_init_annotations(self, fn, sample_rate = 30., start_time = None):
self.annotationsScene.setSampleRate(sample_rate)
self.annotations.clear_channels()
self.active_channels.clear()
@@ -150,17 +177,17 @@ def load_or_init_annotations(self, fn, sample_rate = 30., running_time = None):
print(f"Try loading annotations from {fn}")
try:
self.annotationsScene.clear()
+ self.annotations.set_sample_rate(sample_rate)
+ self.annotations.set_offset_time(self.time_start_end_timecode['annotations'][0][0]-self.time_start)
+ self.annotations.set_start_date_time(start_time)
self.annotations.read(fn)
if self.annotations.channel_names():
self.mainWindow.addChannelToCombo(self.annotations.channel_names())
- print(f"channel_names: {self.annotations.channel_names()}")
self.setActiveChannel(self.annotations.channel_names()[0])
self.annotationsScene.loadAnnotations(self.annotations, self.annotations.channel_names(), sample_rate)
height = len(self.annotations.channel_names()) - self.annotationsScene.sceneRect().height()
self.annotationsScene.setSceneRect(padded_rectf(self.annotationsScene.sceneRect()) + QMarginsF(0., 0., 0., float(height)))
self.annotationsSceneHeightChanged.emit(float(self.annotationsScene.sceneRect().height()))
- self.time_start = self.annotations.start_time()
- self.time_end = self.annotations.end_time()
loaded = True
except Exception as e:
pass
@@ -170,8 +197,9 @@ def load_or_init_annotations(self, fn, sample_rate = 30., running_time = None):
print("Initializing new annotations")
self.active_channels.clear()
self.annotationsScene.clear()
- self.annotationsScene.setSceneRect(padded_rectf(QRectF(0., 0., running_time, 1.)))
- self.time_end = Timecode(self.time_start.framerate, start_seconds=self.time_start.float + running_time)
+ self.annotationsScene.setSceneRect(padded_rectf(QRectF(0., 0., self.time_end.float, 1.)))
+ self.time_start_end_timecode['annotations'] = [[self.time_start, self.time_end]]
+ self.annotations.set_start_date_time(min(self.min_max_times))
self.annotations.set_sample_rate(sample_rate)
self.annotations.set_start_frame(self.time_start)
self.annotations.set_end_frame(self.time_end)
@@ -197,6 +225,7 @@ def addChannel(self, chanName):
self.annotationsScene.height = self.annotationsScene.sceneRect().height()
self.annotationsSceneHeightChanged.emit(float(self.annotationsScene.sceneRect().height()))
self.setActiveChannel(chanName)
+ self.newChannelAdded.emit()
@Slot()
def setActiveChannel(self, chanName):
@@ -225,6 +254,9 @@ def saveBehaviors(self):
except Exception as e:
print(f"Caught Exception {e}")
QMessageBox.about(self.mainWindow, "Error", f"Saving behaviors to {fn} failed. {e}")
+ self.behaviorsChanged.emit()
+
+ # File menu actions
@Slot()
def save_annotations(self):
@@ -260,7 +292,6 @@ def save_annotations(self):
self.annotations.write_caltech(
file,
[video_data.file_path for video_data in trial.video_data],
- trial.stimulus
)
# Is the annotation filename a new one, or does it exist already?
existingAnnot = None
@@ -275,7 +306,8 @@ def save_annotations(self):
newAnnot.file_path = relpath(fileName, base_directory)
newAnnot.sample_rate = self.annotations.sample_rate()
newAnnot.format = self.annotations.format()
- newAnnot.start_time = self.time_start.float
+ newAnnot.offset_time = datetime.timestamp(datetime.fromisoformat(self.annotations.start_date_time())) - \
+ self.trial_start_time
newAnnot.start_frame = self.time_start.frame_number
newAnnot.stop_frame = self.time_end.frame_number
newAnnot.annotator_name = investigator.user_name
@@ -287,7 +319,26 @@ def save_annotations(self):
db_sess.commit()
self.newAnnotations = False
- # File menu actions
+ @Slot()
+ def export_data(self):
+        if self.session_id is None or self.trial_id is None:
+            msgBox = QMessageBox(QMessageBox.Warning, "Export Data",
+                "Please select a session and trial before trying to export data.")
+            msgBox.exec()
+            return
+ with self.db_sessionMaker() as db_sess:
+ base_directory = db_sess.query(Session).filter(Session.id == self.session_id).one().base_directory
+ fileName, selectedFilter = QFileDialog.getSaveFileName(
+ self.mainWindow,
+ caption="Data Export File Name",
+ # filter="HDF5 file (*.h5);;Neurodata Without Borders file (*.nwb)",
+ filter="NWB file (*.nwb)",
+ selectedFilter="NWB file (*.nwb)",
+ dir=base_directory)
+ if selectedFilter == "NWB file (*.nwb)":
+ #write to nwb file
+ with NWBHDF5IO(fileName, 'w') as io:
+ io.write(self.nwbFile)
+ else:
+ raise NotImplementedError(f"Data export format {selectedFilter} not supported")
@Slot()
def set_investigator(self):
@@ -375,10 +426,10 @@ def create_db(self):
# State-related methods
- def update_current_annotations(self):
+ def update_current_annotations(self, t):
self.current_annotations.clear()
for ch in self.active_channels:
- bouts = self.annotations.channel(ch).get_at(self.current_time)
+ bouts = self.annotations.channel(ch).get_at(t)
for bout in bouts:
if bout.is_visible():
self.current_annotations.append((ch, bout))
@@ -388,20 +439,30 @@ def update_current_annotations(self):
bout.color())
for (c, bout) in self.current_annotations])
+ @Slot()
+ def updateNWBFile(self):
+ if isinstance(self.nwbFile, NWBFile):
+ self.nwbFile = self.annotations.exportToNWBFile(self.nwbFile)
+ self.nwbFileUpdated.emit()
+
+ def current_time(self) -> Timecode:
+ return self.player.currentTime()
+
+ @Slot(Timecode)
+ def noteTimeChanged(self, t: Timecode):
+ self.update_current_annotations(t)
+
def set_time(self, new_tc: Timecode):
if not isinstance(new_tc, Timecode):
new_tc = Timecode('30.0', new_tc)
new_tc = max(self.time_start, min(self.time_end, new_tc))
- if self.current_time != new_tc:
- self.current_time = new_tc
- self.update_current_annotations()
- self.timeChanged.emit(self.current_time)
+ self.player.setCurrentTime(new_tc)
def change_time(self, increment: Timecode):
- self.set_time(self.current_time + increment)
+ self.set_time(self.player.currentTime() + increment)
def get_time(self):
- return self.current_time
+ return self.player.currentTime()
@Slot()
def incrementTime(self):
@@ -433,7 +494,7 @@ def toNextEvent(self):
for (ch, bout) in self.current_annotations:
next_event = min(next_event, bout.end() + 1)
for ch in self.active_channels:
- next_bout = self.annotations.channel(ch).get_next_start(self.current_time)
+ next_bout = self.annotations.channel(ch).get_next_start(self.player.currentTime())
next_event = min(next_event, next_bout.start())
self.set_time(next_event)
@@ -443,10 +504,21 @@ def toPrevEvent(self):
for (ch, bout) in self.current_annotations:
prev_event = max(prev_event, bout.start() - 1)
for ch in self.active_channels:
- prev_bout = self.annotations.channel(ch).get_prev_end(self.current_time - 1)
+ prev_bout = self.annotations.channel(ch).get_prev_end(self.player.currentTime() - 1)
prev_event = max(prev_event, prev_bout.end())
self.set_time(prev_event)
+    @Slot()
+    def jumpToTime(self):
+ time = self.mainWindow.ui.currentTimeEdit.time().toPython()
+ dt = datetime.combine(datetime.fromtimestamp(0., tz=timezone.utc).date(),
+ time, tzinfo=timezone.utc)
+ self.set_time(Timecode(str(self.time_start.framerate), start_seconds=dt.timestamp()))
+
+ @Slot()
+ def jumpToFrame(self):
+ frame = self.mainWindow.ui.currentFrameBox.value()
+ self.set_time(Timecode(str(self.time_start.framerate), frames=frame))
+
# @Slot(QObject.event)
def processHotKey(self, event: QObject.event):
"""
@@ -473,13 +545,21 @@ def processHotKey(self, event: QObject.event):
# Is there a pending bout? If so, complete the annotation activity
if self.pending_bout:
- chan = self.active_channels[0]
- if self.pending_bout.start() > self.current_time:
+ if self.active_channels:
+ chan = self.active_channels[0]
+ else:
+ msgBox = QMessageBox(QMessageBox.Warning,
+ "No annotation channel found",
+                    "No annotation channel found. Please add a new annotation "
+                    "channel before annotating.")
+ msgBox.exec()
+ raise RuntimeError("Annotation channel should be added before doing annotations.")
+ if self.pending_bout.start() > self.player.currentTime():
# swap start and end before completing
self.pending_bout.set_end(self.pending_bout.start())
- self.pending_bout.set_start(self.current_time)
+ self.pending_bout.set_start(self.player.currentTime())
else:
- self.pending_bout.set_end(self.current_time)
+ self.pending_bout.set_end(self.player.currentTime())
if do_delete:
# truncate or remove any bouts of the same behavior as pending_bout
@@ -502,7 +582,7 @@ def processHotKey(self, event: QObject.event):
self.noteAnnotationsChanged(start, end)
else:
# Start a new annotation activity by saving a pending_bout
- self.pending_bout = Bout(self.current_time, self.current_time, beh)
+ self.pending_bout = Bout(self.player.currentTime(), self.player.currentTime(), beh)
@Slot()
def quit(self, event):
@@ -521,10 +601,9 @@ def quit(self, event):
time.sleep(3./30.) # wait for threads to shut down
QApplication.instance().quit()
- def newVideoWidget(self, video_path: str) -> VideoFrame:
+ def newVideoWidget(self, video_path: str, start_time: Timecode, forcePixmapMode: bool) -> VideoFrame:
video = VideoFrame(self)
- video.load_video(video_path)
- self.timeChanged.connect(video.updateFrame)
+ video.load_video(video_path, start_time, forcePixmapMode)
self.currentAnnotsChanged.connect(video.updateAnnots)
return video
@@ -535,15 +614,114 @@ def newNeuralWidget(self, neuralData, base_dir: str) -> NeuralFrame:
self.active_channel_changed.connect(neuralWidget.setActiveChannel)
return neuralWidget
+ def timeToTimecode(self, time_start_end, video_info, sample_rate=30.):
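+        # gather start/end times from all media (natively playable videos take
+        # precedence), rebase so the earliest time becomes zero, and convert to Timecodes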
+ times = list()
+ for ix, item in enumerate(video_info):
+ if VideoFrame(self).supported_by_native_player(item[1].file_path):
+ times.extend(time_start_end['video'][ix])
+ else:
+ continue
+ if len(times)==0:
+ for key in time_start_end:
+ times.extend(list(itertools.chain(*time_start_end[key])))
+ min_time, max_time = min(times), max(times)
+ self.min_max_times = [min_time, max_time]
+
+ for key in time_start_end:
+ # We need to subtract min_time from a list of lists, so convert the inner list
+ # into a numpy array, so that the subtraction is dispatched across all the elements.
+ time_start_end[key] = [list(np.array(t) - min_time) for t in time_start_end[key]]
+
+ for key in time_start_end:
+ if time_start_end[key]:
+ self.time_start_end_timecode[key] = [[Timecode(sample_rate, frames=1)+Timecode(sample_rate, start_seconds=start),
+ Timecode(sample_rate, frames=1)+Timecode(sample_rate, start_seconds=end)]
+ for start, end in time_start_end[key]]
+ else:
+ self.time_start_end_timecode[key] = list()
+
+ timecodes = []
+ for key in self.time_start_end_timecode:
+ timecodes.extend(itertools.chain(*self.time_start_end_timecode[key]))
+
+        # keep the existing start time if the earliest timecode is negative
+        if min(timecodes).float >= 0:
+            self.time_start = min(timecodes)
+        self.time_end = max(timecodes)
+
+ # set the start time for each video widget, pulling it out of the list pair
+ ix = 0
+ start_end_list = self.time_start_end_timecode['video']
+ for widget in self.widgets:
+ if widget.dataExportType.lower() != 'video':
+ continue
+ assert(ix < len(start_end_list))
+ widget.set_start_time(start_end_list[ix][0])
+ ix += 1
+
+ self.set_time(self.time_start)
+
+ def getTimes(self, videos, annotation, loadNeural, loadAudio):
+
+ # just clearing each list was not working for some reason
+ # so, re-initializing the dict
+ self.time_start_end = {
+ 'video':[],
+ 'annotations':[],
+ 'neural':[]
+        } # 'data_type' : [[start_time, end_time], ...]
+ sample_rate = 30.0
+ sample_rate_set = False
+
+ for video in videos:
+ video_start_time = self.trial_start_time + video.offset_time
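+            # the real duration isn't known until the video widget loads the
+            # file, so use 24 hours as a placeholder end time for now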
+ self.time_start_end['video'].append([video_start_time,
+ video_start_time + float(24*60*60)])
+ if not sample_rate_set:
+ sample_rate = video.sample_rate
+ sample_rate_set = True
+ if annotation:
+ sample_rate = annotation.sample_rate
+ start_time = self.trial_start_time + annotation.offset_time
+ end_time = start_time + Timecode(sample_rate, frames=annotation.stop_frame).float
+ self.time_start_end['annotations'].append([start_time,
+ end_time])
+ if loadNeural:
+ with self.db_sessionMaker() as db_sess:
+ trial = db_sess.query(Trial).filter(Trial.id == self.trial_id).one()
+ if trial.neural_data:
+ running_time = Timecode(str(trial.neural_data[0].sample_rate),
+ frames=trial.neural_data[0].stop_frame-trial.neural_data[0].start_frame).float
+ neural_start_time = self.trial_start_time + trial.neural_data[0].offset_time
+ self.time_start_end['neural'].append([neural_start_time,
+ neural_start_time + running_time])
+ if loadAudio:
+ # not supported yet
+ pass
+
+ return sample_rate
+
@Slot()
def loadTrial(self, videos, annotation, loadPose, loadNeural, loadAudio):
- self.video_widgets.clear()
+ if self.player.timeSource():
+ self.player.timeSource().disconnectSignals(self)
+ self.player.setTimeSource(None)
+ for widget in self.widgets:
+ if widget.dataExportType.lower() != 'video':
+ continue
+ widget.reset()
+ widget.hide()
+ widget.deleteLater()
+ self.widgets.clear()
+ sample_rate = self.getTimes(videos, annotation, loadNeural, loadAudio)
progressTotal = (
len(videos) +
(1 if annotation else 0) +
(len(videos) if loadPose else 0) + # potentially one pose per video
(1 if loadNeural else 0) +
(1 if loadAudio else 0))
+ timeSource = None
progressCompleted = 0
progress = QProgressDialog("Loading Trial ...", "Cancel", 0, progressTotal, None)
progress.setWindowModality(Qt.WindowModal)
@@ -553,19 +731,30 @@ def loadTrial(self, videos, annotation, loadPose, loadNeural, loadAudio):
base_directory = session.base_directory
# Qt converts paths to platform-specific separators under the hood,
# so it's correct to use forward-slash ("/") here across all platforms
+ # rather than os.sep
base_dir = base_directory + "/"
runningTime = 0.
- sample_rate = 30.0
- sample_rate_set = False
- for ix, video_data in enumerate(videos):
+ # Find the native video that starts earliest, if there are native videos
+ # Start a native player for that, and then force pixmap players for any
+ # other videos.
+ # We give native videos priority by artificially pushing their start times 6 hours earlier
+ pairs = [[self.trial_start_time + video_data.offset_time, video_data] for video_data in videos]
+ for item in pairs:
+ if VideoFrame(self).supported_by_native_player(item[1].file_path):
+ item[0] -= 6 * 60 * 60
+        ordered_pairs = sorted(pairs, key=lambda pair: pair[0])
+ for ix, item in enumerate(ordered_pairs):
+ video_data = item[1]
progress.setLabelText(f"Loading video #{ix}...")
progress.setValue(progressCompleted)
if not isabs(video_data.file_path):
path = base_dir + video_data.file_path
else:
path = video_data.file_path
- widget = self.newVideoWidget(fix_path(path))
- self.video_widgets.append(widget)
+ # force pixmap mode if we already are using a native player as a time source
+ widget = self.newVideoWidget(fix_path(path),
+ self.trial_start_time + video_data.offset_time,
+ bool(timeSource))
if loadPose:
video = db_sess.query(VideoData).filter(VideoData.id == video_data.id).one()
if len(video.pose_data) > 0:
@@ -575,12 +764,18 @@ def loadTrial(self, videos, annotation, loadPose, loadNeural, loadAudio):
pose_path = video.pose_data[0].file_path
if not isabs(pose_path):
pose_path = base_dir + pose_path
- pose_class = self.pose_registry(video.pose_data[0].format)
+ pose_registry = PoseRegistry()
+ pose_registry.load_plugins()
+ pose_class = pose_registry(video.pose_data[0].format)
+ pose_class.loadPoses(self.mainWindow, pose_path, path)
widget.set_pose_class(pose_class)
- pose_class.loadPoses(self.mainWindow, pose_path)
else:
print("No pose data in trial to load.")
progressCompleted += 1
+ self.widgets.append(widget)
+ # if this widget can be a time source, use it as such
+ if not timeSource and widget.getPlayer():
+ timeSource = TimeSourceQMediaPlayer(self.timeChanged, widget.scene)
qr = widget.frameGeometry()
# qr.moveCenter(self.screen_center + spacing)
@@ -588,30 +783,44 @@ def loadTrial(self, videos, annotation, loadPose, loadNeural, loadAudio):
widget.move(qr.topLeft()) #TODO: need to space multiple videos out
widget.show()
runningTime = max(runningTime, widget.running_time())
- if not sample_rate_set:
- sample_rate = widget.sample_rate()
- sample_rate_set = True
+ self.time_start_end['video'][ix][1] = self.time_start_end['video'][ix][0] + runningTime
progressCompleted += 1
+ # instantiate a time source from a QTimer if we don't already have one
+ if not timeSource:
+ timeSource = TimeSourceQTimer(self.timeChanged)
+ # set up the time source parameters and connect it
+ timeSource.setMaxFrameRate(2.)
+ timeSource.setMinFrameRate(0.125)
+ self.timeChanged.connect(timeSource.setCurrentTime)
+ self.player.setTimeSource(timeSource)
+ # Compute the trial start and end times from the min and max across all media.
+ # This also sets the current time in the timeSource: self.set_time() calls
+ # self.player.setCurrentTime(), which in turn updates the timeSource.
+ self.timeToTimecode(self.time_start_end, ordered_pairs, sample_rate)
+
+ # Load Annotations
if annotation:
if not isabs(annotation.file_path):
annot_path = base_dir + annotation.file_path
else:
annot_path = annotation.file_path
annot_path = fix_path(annot_path)
- sample_rate = annotation.sample_rate
+ annot_start_time = self.trial_start_time + annotation.offset_time
else:
annot_path = None
+ annot_start_time = None
progress.setLabelText("Loading annotations...")
progress.setValue(progressCompleted)
- self.load_or_init_annotations(annot_path, sample_rate, runningTime)
+ self.load_or_init_annotations(annot_path, sample_rate, annot_start_time)
# try:
# self.load_or_init_annotations(annot_path, sample_rate, runningTime)
# except Exception as e:
# QMessageBox.about(self.selectTrialWindow, "Error", f"Attempt to load annotations from {annot_path} "
# f"failed with error {str(e)}")
- # for widget in self.video_widgets:
- # widget.close()
- # self.video_widgets.clear()
+ # for widget in self.widgets:
+ # if widget.dataExportType.lower() == 'video':
+ # widget.close()
+ # self.widgets.clear()
# return False
progressCompleted += 1
self.noteAnnotationsChanged(self.time_start, self.time_end)
@@ -624,12 +833,9 @@ def loadTrial(self, videos, annotation, loadPose, loadNeural, loadAudio):
progress.setLabelText("Loading neural data...")
progress.setValue(progressCompleted)
neuralWidget = self.newNeuralWidget(trial.neural_data[0], base_dir)
- self.neural_widgets.append(neuralWidget)
+ self.widgets.append(neuralWidget)
if self.annotationsScene:
neuralWidget.overlayAnnotations(self.annotationsScene)
- if runningTime==0 and self.newAnnotations and len(videos)==0:
- running_time = trial.neural_data[0].sample_rate * (trial.neural_data[0].stop_frame-trial.neural_data[0].start_frame)
- self.time_end = Timecode(self.time_start.framerate, start_seconds=self.time_start.float + running_time)
neuralWidget.show()
progressCompleted += 1
else:
@@ -647,6 +853,7 @@ def loadTrial(self, videos, annotation, loadPose, loadNeural, loadAudio):
progress.setLabelText("Done")
progress.setValue(progressCompleted)
self.set_time(self.time_start)
+ self.exportToNWBFile()
return True
@Slot()
@@ -663,20 +870,85 @@ def noteAnnotationsChanged(self, start=None, end=None):
end = self.time_end.float
elif isinstance(end, Timecode):
end = end.float
- self.newAnnotations = True
+ #self.newAnnotations = True
self.annotationsScene.sceneChanged(start, end)
+ @Slot(int)
+ def noteVideoDurationChanged(self, duration):
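+ # Qt reports media duration in milliseconds; convert to seconds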
+ video_end = duration / 1000.
+ time_end = max(video_end, self.time_end.float)
+ self.time_end = Timecode(self.time_end.framerate, start_seconds=time_end)
+ self.annotations.set_end_frame(self.time_end)
+
def deleteAnnotationsByName(self, behaviorName):
beh = self.behaviors.get(behaviorName)
for chan in self.annotations.channel_names():
self.annotations.truncate_or_remove_bouts(beh, self.time_start, self.time_end, chan)
+ def exportToNWBFile(self):
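+ # Gather investigator, animal, and trial metadata from the database,
+ # build an NWBFile from it, and then let each widget (and the
+ # annotations object) append its own data to the file.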
+ print(f"Export data from {self.dataExportType} to NWB file")
+ # export session and trial metadata here
+ investigatorInfo, animalInfo, trialInfo = dict(), dict(), dict()
+ with self.db_sessionMaker() as db_sess:
+ # get session from DB
+ session = db_sess.query(Session).filter(Session.id == self.session_id).scalar()
+ assert session is not None
+
+ # get investigator data from DB
+ investigator = db_sess.query(Investigator).filter(Investigator.id == session.investigator_id).scalar()
+ assert investigator is not None
+ investigatorInfo['user_name'] = investigator.user_name
+ investigatorInfo['first_name'] = investigator.first_name
+ investigatorInfo['last_name'] = investigator.last_name
+ investigatorInfo['institution'] = investigator.institution
+ investigatorInfo['e_mail'] = investigator.e_mail
+
+ # get animal data from DB
+ animal = db_sess.query(Animal).filter(Animal.id == session.animal_id).scalar()
+ assert animal is not None
+ animalInfo['animal_services_id'] = animal.animal_services_id
+ animalInfo['nickname'] = animal.nickname
+ animalInfo['genotype'] = animal.genotype
+ animalInfo['date_of_birth'] = (datetime.fromisoformat(animal.dob.isoformat())).replace(tzinfo=tzlocal())
+ animalInfo['sex'] = animal.sex.name
+
+ # get trial data from DB
+ trial = db_sess.query(Trial).filter(Trial.session_id == self.session_id, Trial.id == self.trial_id).scalar()
+ assert trial is not None
+ trialInfo['trial_num'] = trial.trial_num
+ trialInfo['stimulus'] = trial.stimulus
+ #TODO: add date of trial
+
+ self.nwbFile = NWBFile(session_description=f"investigator : {investigatorInfo['user_name']}, animal : {animalInfo['nickname']}, date : {animalInfo['date_of_birth'].date().isoformat()}",
+ identifier=f"{animalInfo['nickname']}_{datetime.fromtimestamp(self.min_max_times[0], tz=tzlocal()).date().isoformat()}",
+ session_start_time=datetime.fromtimestamp(self.min_max_times[0], tz=tzlocal()),
+ session_id=f"session_{self.session_id}",
+ institution=investigatorInfo['institution'],
+ experimenter=f"{investigatorInfo['first_name']} {investigatorInfo['last_name']} (Email : {investigatorInfo['e_mail']})",
+ notes=f"Stimulus : {trialInfo['stimulus']}",
+ file_create_date=datetime.now(tz=tzlocal()))
+ self.nwbFile.subject = Subject(description=animalInfo['nickname'],
+ genotype=animalInfo['genotype'],
+ sex=animalInfo['sex'],
+ subject_id=str(animalInfo['animal_services_id']),
+ date_of_birth=animalInfo['date_of_birth'])
+
+ #ask each widget to export its data
+ for widget in self.widgets:
+ self.nwbFile = widget.exportToNWBFile(self.nwbFile)
+ # export annotations data
+ self.nwbFile = self.annotations.exportToNWBFile(self.nwbFile)
+
+
# Signals
quitting = Signal()
timeChanged = Signal(Timecode)
currentAnnotsChanged = Signal(list)
active_channel_changed = Signal(str)
annotationsSceneHeightChanged = Signal(float)
+ newChannelAdded = Signal()
+ behaviorsChanged = Signal()
+ nwbFileUpdated = Signal()
if __name__ == "__main__":
# Create the Qt Application
diff --git a/src/buildui b/src/buildui
index 7b2d3e5..3d10afc 100755
--- a/src/buildui
+++ b/src/buildui
@@ -1,2 +1,2 @@
#!/bin/zsh
-pyside6-uic $1 | sed s/PySide6/qtpy/ >${1/.ui/_ui.py}
\ No newline at end of file
+pyside2-uic $1 | sed s/PySide2/qtpy/ >${1/.ui/_ui.py}
\ No newline at end of file
diff --git a/src/dataExporter.py b/src/dataExporter.py
new file mode 100644
index 0000000..c0f9934
--- /dev/null
+++ b/src/dataExporter.py
@@ -0,0 +1,22 @@
+# dataExporter.py
+"""
+This module provides the base class for data export.
+Classes (including widgets, etc.) should derive from it, probably along with
+another class such as a Qt base class.
+"""
+
+from pynwb import NWBFile
+
+class DataExporter:
+ """
+ Data export base class
+ """
+
+ def __init__(self):
+ self.dataExportType = "None" # derived class should override this
+
+ def exportToNWBFile(self, nwbFile: NWBFile) -> NWBFile:
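+ """Add this object's data to nwbFile and return the (possibly updated) file."""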
+ raise NotImplementedError("Derived class needs to implement this")
+
+ def exportToDict(self, d: dict):
+ raise NotImplementedError("Derived class needs to implement this")
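+
+# A minimal sketch of a derived exporter (hypothetical class, assuming a
+# Qt widget mixin; qtpy import shown for completeness):
+#
+# from qtpy.QtWidgets import QWidget
+#
+# class TraceWidget(QWidget, DataExporter):
+#     def __init__(self):
+#         QWidget.__init__(self)
+#         DataExporter.__init__(self)
+#         self.dataExportType = "trace"
+#
+#     def exportToNWBFile(self, nwbFile: NWBFile) -> NWBFile:
+#         # append this widget's data to nwbFile here, then hand it back
+#         return nwbFile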
\ No newline at end of file
diff --git a/src/db/editTrialDialog.py b/src/db/editTrialDialog.py
index 3b3f6aa..6b2a297 100644
--- a/src/db/editTrialDialog.py
+++ b/src/db/editTrialDialog.py
@@ -6,11 +6,12 @@
from db.editTrialDialog_ui import Ui_EditTrialDialog
from annot.annot import Annotations
from qtpy.QtCore import QModelIndex, Qt, Signal, Slot
-from qtpy.QtGui import QBrush, QIntValidator
+from qtpy.QtGui import QBrush, QIntValidator, QStandardItemModel
from qtpy.QtWidgets import (QDialog, QFileDialog, QHeaderView, QMessageBox,
- QTreeWidgetItem, QTreeWidgetItemIterator)
+ QTreeWidgetItem, QTreeWidgetItemIterator, QLineEdit)
from models.tableModel import EditableTableModel
-from widgets.deleteableViews import DeleteableTreeWidget
+from widgets.deleteableViews import (DeleteableTreeWidget,
+ OffsetTimeItemDelegate, CustomComboBoxDelegate)
# from models.videoTreeModel import VideoTreeModel
from timecode import Timecode
from os.path import expanduser, getmtime, basename
@@ -29,6 +30,7 @@ def addPoseHeaderIfNeeded(parent):
flags &= ~Qt.ItemIsEditable
flags &= ~Qt.ItemIsSelectable
poseHeaderItem.setFlags(flags)
+ poseHeaderItem.setToolTip(header.index('Offset Time'), 'ss.ms')
font = poseHeaderItem.font(0)
font.setBold(True)
for column in range(poseHeaderItem.columnCount()):
@@ -55,6 +57,14 @@ def __init__(self, bento, session_id, trial_id=None):
self.ui.annotationsSearchPushButton.clicked.connect(self.addAnnotationFiles)
self.ui.addPosePushButton.clicked.connect(self.addPoseFileToVideo)
self.ui.trialNumLineEdit.setValidator(QIntValidator())
+ self.ui.trialDateTimeEdit.setDisplayFormat("yyyy-MM-dd HH:mm:ss.zzz")
+ self.ui.trialDateTimeEdit.setDateTime(datetime.strptime(
+ str(datetime.now().isoformat(sep=" ", timespec="milliseconds")),
+ "%Y-%m-%d %H:%M:%S.%f"
+ ))
+ self.ui.annotationsDeleteButton.clicked.connect(self.deleteAnnotationFiles)
+ self.ui.neuralsDeleteButton.clicked.connect(self.deleteNeuralFiles)
+ self.ui.videoPoseDeleteButton.clicked.connect(self.deleteVideoPoseFiles)
self.video_data = None
self.trial_id = trial_id
@@ -69,6 +79,8 @@ def updateUIFromCurrentTrial(self):
if trial:
self.ui.trialNumLineEdit.setText(str(trial.trial_num))
self.ui.stimulusLineEdit.setText(trial.stimulus)
+ trial_start_time = str(datetime.fromtimestamp(trial.trial_start_time).isoformat(sep=' ', timespec='milliseconds'))
+ self.ui.trialDateTimeEdit.setDateTime(datetime.strptime(trial_start_time, "%Y-%m-%d %H:%M:%S.%f"))
self.populateVideosTreeWidget(True)
self.populateNeuralsTableView(True)
self.populateAnnotationsTableView(True)
@@ -107,13 +119,16 @@ def populateVideosTreeWidget(self, updateTrialNum):
self.ui.videosTreeWidget.setHeaderLabels(header)
self.ui.videosTreeWidget.hideColumn(header.index('id'))
self.ui.videosTreeWidget.hideColumn(header.index('trial_id'))
+ self.ui.videosTreeWidget.setItemDelegate(OffsetTimeItemDelegate())
headerItem = self.ui.videosTreeWidget.headerItem()
+ headerItem.setToolTip(header.index('Offset Time'), 'ss.ms')
font = headerItem.font(0)
font.setBold(True)
for column in range(headerItem.columnCount()):
headerItem.setFont(column, font)
headerItem.setTextAlignment(column, Qt.AlignCenter)
headerSet = True
+ videoTreeItem.setToolTip(header.index('Offset Time'), 'ss.ms')
for ix, key in enumerate(header):
videoTreeItem.setData(ix, Qt.EditRole, videoDict[key])
if len(elem.pose_data) > 0:
@@ -121,6 +136,7 @@ def populateVideosTreeWidget(self, updateTrialNum):
for poseItem in elem.pose_data:
poseTreeItem = QTreeWidgetItem(videoTreeItem)
poseTreeItem.setFlags(poseTreeItem.flags() | Qt.ItemIsEditable)
+ poseTreeItem.setToolTip(poseItem.header().index('Offset Time'), 'ss.ms')
poseDict = poseItem.toDict()
for iy, poseKey in enumerate(poseItem.header()):
poseTreeItem.setData(iy, Qt.EditRole, poseDict[poseKey])
@@ -130,15 +146,15 @@ def populateVideosTreeWidget(self, updateTrialNum):
for column in range(self.ui.videosTreeWidget.columnCount()):
self.ui.videosTreeWidget.resizeColumnToContents(column)
- if updateTrialNum:
- self.ui.videosTreeWidget.itemSelectionChanged.connect(self.populateTrialNum)
- else:
- try:
- self.ui.videosTreeWidget.itemSelectionChanged.disconnect(self.populateTrialNum)
- except RuntimeError:
- # The above call raises RuntimeError if the signal is not connected,
- # which we can safely ignore.
- pass
+ # if updateTrialNum:
+ # self.ui.videosTreeWidget.itemSelectionChanged.connect(self.populateTrialNum)
+ # else:
+ # try:
+ # self.ui.videosTreeWidget.itemSelectionChanged.disconnect(self.populateTrialNum)
+ # except RuntimeError:
+ # # The above call raises RuntimeError if the signal is not connected,
+ # # which we can safely ignore.
+ # pass
def addVideoFile(self, file_path, baseDir, available_cameras):
"""
@@ -150,8 +166,10 @@ def addVideoFile(self, file_path, baseDir, available_cameras):
try:
if ext=='mp4'or ext=='avi':
reader = mp4Io_reader(file_path)
+ #create_time = getmtime(file_path)
elif ext=='seq':
reader = seqIo_reader(file_path, buildTable=False)
+ #create_time = 0.
else:
raise Exception(f"video format {ext} not supported.")
except Exception:
@@ -161,9 +179,8 @@ def addVideoFile(self, file_path, baseDir, available_cameras):
sample_rate = float(reader.header['fps'])
ts = reader.getTs(1)[0]
reader.close()
- dt = datetime.fromtimestamp(ts)
- # set the video start time
- start_time = Timecode(sample_rate, dt.time().isoformat()).float
+ # set the video offset time
+ offset_time = float(0.000)
if file_path.startswith(baseDir):
file_path = file_path[len(baseDir):]
@@ -177,7 +194,7 @@ def addVideoFile(self, file_path, baseDir, available_cameras):
'id': None,
'Video File Path': file_path,
'Sample Rate': sample_rate,
- 'Start Time': start_time,
+ 'Offset Time': offset_time,
'Camera Position': this_camera_position,
'trial_id': self.trial_id,
'pose_data': [],
@@ -190,9 +207,11 @@ def addVideoFile(self, file_path, baseDir, available_cameras):
self.ui.videosTreeWidget.setColumnCount(len(videoKeys))
self.ui.videosTreeWidget.setHeaderLabels(videoKeys)
self.ui.videosTreeWidget.hideColumn(videoKeys.index('id'))
+ self.ui.videosTreeWidget.setItemDelegate(OffsetTimeItemDelegate())
# Attach the video file to treeWidget as a top-level item
videoItem = QTreeWidgetItem(self.ui.videosTreeWidget)
videoItem.setFlags(videoItem.flags() | Qt.ItemIsEditable)
+ videoItem.setToolTip(videoKeys.index('Offset Time'), 'ss.ms')
# insert the data into item
for key in videoKeys:
videoItem.setData(videoKeys.index(key), Qt.EditRole, item[key])
@@ -237,9 +256,10 @@ def addPoseFileToVideo(self):
videosHeader = [videosHeaderItem.data(ix, Qt.DisplayRole) for ix in range(videosHeaderItem.columnCount())]
# insert the data into the pose child item
poseKeys = PoseData().keys
+ poseItem.setToolTip(poseKeys.index('Offset Time'), 'ss.ms')
poseItem.setData(poseKeys.index('Pose File Path'), Qt.EditRole, poseFilePath)
poseItem.setData(poseKeys.index('Sample Rate'), Qt.EditRole, videoItem.data(videosHeader.index('Sample Rate'), Qt.DisplayRole))
- poseItem.setData(poseKeys.index('Start Time'), Qt.EditRole, videoItem.data(videosHeader.index('Start Time'), Qt.DisplayRole))
+ poseItem.setData(poseKeys.index('Offset Time'), Qt.EditRole, videoItem.data(videosHeader.index('Offset Time'), Qt.DisplayRole))
poseItem.setData(poseKeys.index('Format'), Qt.EditRole, format)
poseItem.setData(poseKeys.index('video_id'), Qt.EditRole, videoItem.data(videosHeader.index('id'), Qt.DisplayRole))
poseItem.setData(poseKeys.index('trial_id'), Qt.EditRole, videoItem.data(videosHeader.index('trial_id'), Qt.DisplayRole))
@@ -264,12 +284,38 @@ def addVideoFiles(self):
self,
"Select Video Files to add to Trial",
baseDir,
- "Seq files (*.seq);;mp4 files (*.mp4);;Generic video files (*.avi)",
- "Seq files (*.seq)")
+ "Seq files (*.seq);;mp4 files (*.mp4);;Generic video files (*.avi);; Supported videos (*.seq *.mp4 *.avi)",
+ "Supported videos (*.seq *.mp4 *.avi)")
if len(videoFiles) > 0:
for file_path in videoFiles:
self.addVideoFile(file_path, baseDir, available_cameras)
+ @Slot()
+ def deleteVideoPoseFiles(self):
+ msgBox = QMessageBox(
+ QMessageBox.Question,
+ "Delete Rows",
+ "This will delete the selected file(s) along with any dependent "
+ "items from the database and cannot be undone. Okay to continue?",
+ buttons=QMessageBox.Yes | QMessageBox.Cancel)
+ result = msgBox.exec()
+ if result == QMessageBox.Yes:
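+ # rows are only hidden here; the corresponding database entries are
+ # deleted when the dialog is accepted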
+ for item in self.ui.videosTreeWidget.selectedItems():
+ item.setHidden(True)
+ if bool(item.parent()):
+ # pose item: hide the header if no children remain unhidden
+ nonHeaderChildCount = 0
+ poseHeaderIx = -1
+ parentChildCount = item.parent().childCount()
+ for child_ix in range(parentChildCount):
+ child = item.parent().child(child_ix)
+ if child.type() == DeleteableTreeWidget.PoseHeaderType:
+ poseHeaderIx = child_ix
+ elif not child.isHidden():
+ nonHeaderChildCount += 1
+ if nonHeaderChildCount == 0 and poseHeaderIx >= 0:
+ item.parent().child(poseHeaderIx).setHidden(True)
+
def updateVideoData(self, trial, db_sess):
videoHeader = VideoData().header()
poseHeader = PoseData().header()
@@ -376,15 +422,15 @@ def populateNeuralsTableView(self, updateTrialNum):
self.setNeuralModel(model)
selectionModel = self.ui.neuralsTableView.selectionModel()
- if updateTrialNum:
- selectionModel.selectionChanged.connect(self.populateTrialNum)
- else:
- try:
- selectionModel.selectionChanged.disconnect(self.populateTrialNum)
- except RuntimeError:
- # The above call raises RuntimeError if the signal is not connected,
- # which we can safely ignore.
- pass
+ # if updateTrialNum:
+ # selectionModel.selectionChanged.connect(self.populateTrialNum)
+ # else:
+ # try:
+ # selectionModel.selectionChanged.disconnect(self.populateTrialNum)
+ # except RuntimeError:
+ # # The above call raises RuntimeError if the signal is not connected,
+ # # which we can safely ignore.
+ # pass
def addNeuralFile(self, file_path, baseDir):
"""
@@ -405,13 +451,10 @@ def addNeuralFile(self, file_path, baseDir):
# Otherwise, we can only guess from the video file info.
if isinstance(self.video_data, dict):
sample_rate = self.video_data['sample_rate']
- start_time = self.video_data['start_time']
+ offset_time = float(0.)
else:
sample_rate = 30.0
- # get start time (seconds from midnight) from file create time
- create_time = datetime.fromtimestamp(getmtime(file_path))
- create_day_midnight = datetime.fromordinal(create_time.toordinal())
- start_time = create_time.timestamp() - create_day_midnight.timestamp()
+ offset_time = float(0.)
start_frame = 1
stop_frame = data.shape[1]
@@ -423,7 +466,7 @@ def addNeuralFile(self, file_path, baseDir):
'Neural File Path': file_path,
'Sample Rate': sample_rate,
'Format': 'CNMFE', # by default
- 'Start Time': start_time,
+ 'Offset Time': offset_time,
'Start Frame': start_frame,
'Stop Frame': stop_frame,
'trial_id': self.trial_id,
@@ -463,6 +506,18 @@ def addNeuralFiles(self):
for file_path in neuralFiles:
self.addNeuralFile(file_path, baseDir)
+ @Slot()
+ def deleteNeuralFiles(self):
+ msgBox = QMessageBox(
+ QMessageBox.Question,
+ "Delete Rows",
+ "This will delete the selected file(s) from the database and cannot be undone. Okay to continue?",
+ buttons=QMessageBox.Yes | QMessageBox.Cancel)
+ result = msgBox.exec()
+ if result == QMessageBox.Yes:
+ for ix in self.ui.neuralsTableView.selectedIndexes():
+ self.ui.neuralsTableView.hideRow(ix.row())
+
def updateNeuralData(self, trial, db_sess):
model = self.ui.neuralsTableView.model()
if model:
@@ -499,11 +554,15 @@ def setAnnotationsModel(self, model):
font = self.ui.annotationsTableView.horizontalHeader().font()
font.setBold(True)
self.ui.annotationsTableView.horizontalHeader().setFont(font)
- self.ui.annotationsTableView.horizontalHeader().setSectionResizeMode(QHeaderView.ResizeToContents)
+ self.ui.annotationsTableView.setItemDelegate(CustomComboBoxDelegate(self.bento.annotations_format))
+
+ #self.ui.annotationsTableView.horizontalHeader().setSectionResizeMode(QHeaderView.ResizeToContents)
self.ui.annotationsTableView.resizeColumnsToContents()
keys = AnnotationsData().keys
self.ui.annotationsTableView.hideColumn(keys.index('id')) # don't show the ID field, but we need it for reference
self.ui.annotationsTableView.hideColumn(keys.index('trial_id')) # also don't show the internal trial_id field
+
+
self.ui.annotationsTableView.setSortingEnabled(False)
self.ui.annotationsTableView.setAutoScroll(False)
if oldModel:
@@ -519,15 +578,15 @@ def populateAnnotationsTableView(self, updateTrialNum):
self.setAnnotationsModel(model)
selectionModel = self.ui.annotationsTableView.selectionModel()
- if updateTrialNum:
- selectionModel.selectionChanged.connect(self.populateTrialNum)
- else:
- try:
- selectionModel.selectionChanged.disconnect(self.populateTrialNum)
- except RuntimeError:
- # The above call raises RuntimeError if the signal is not connected,
- # which we can safely ignore.
- pass
+ # if updateTrialNum:
+ # selectionModel.selectionChanged.connect(self.populateTrialNum)
+ # else:
+ # try:
+ # selectionModel.selectionChanged.disconnect(self.populateTrialNum)
+ # except RuntimeError:
+ # # The above call raises RuntimeError if the signal is not connected,
+ # # which we can safely ignore.
+ # pass
def addAnnotationFile(self, file_path, baseDir):
"""
@@ -543,10 +602,10 @@ def addAnnotationFile(self, file_path, baseDir):
# if the file is a h5 file, we can read it using caiman.utils, as below
# neural_dict = load_dict_from_hdf5(file_path)
# Otherwise, we can only guess from the video file info.
- if isinstance(annotations, Annotations):
+ if isinstance(annotations, Annotations) and annotations.sample_rate():
sample_rate = annotations.sample_rate()
else:
- sample_rate = 30.0
+ sample_rate = None
if file_path.startswith(baseDir):
file_path = file_path[len(baseDir):]
@@ -558,12 +617,19 @@ def addAnnotationFile(self, file_path, baseDir):
if investigator:
annotator_name = investigator.user_name
+ if annotations.start_date_time():
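+ # offset = the annotation file's absolute start time minus the trial
+ # start time currently shown in the dialog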
+ trial_start_time = datetime.timestamp(datetime.fromisoformat(
+ self.ui.trialDateTimeEdit.textFromDateTime(self.ui.trialDateTimeEdit.dateTime())))
+ offset_time = datetime.timestamp(datetime.fromisoformat(str(annotations.start_date_time()))) - trial_start_time
+ else:
+ offset_time = float(0.)
+
item = {
'id': None,
'Annotations File Path': file_path,
'Sample Rate': sample_rate,
'Format': annotations.format(),
- 'Start Time': self.bento.time_start.float,
+ 'Offset Time': offset_time,
'Start Frame': annotations.start_frame(),
'Stop Frame': annotations.end_frame(),
'Annotator Name': annotator_name,
@@ -627,12 +693,25 @@ def addAnnotationFiles(self):
self,
"Select Annotation Files to add to Trial",
baseDir,
- "Caltech Annotation files (*.annot)",
- "Caltech Annotation files (*.annot)")
+ "Bento Annotation files (*.annot);;Boris Annotation files (*.csv);;\
+ Simba Annotation files (*.csv);;Caltech Annotation files (*.txt)",
+ "Bento Annotation files (*.annot)")
if len(annotationFiles) > 0:
for file_path in annotationFiles:
self.addAnnotationFile(file_path, baseDir)
+ @Slot()
+ def deleteAnnotationFiles(self):
+ msgBox = QMessageBox(
+ QMessageBox.Question,
+ "Delete Rows",
+ "This will delete the selected file(s) from the database and cannot be undone. Okay to continue?",
+ buttons=QMessageBox.Yes | QMessageBox.Cancel)
+ result = msgBox.exec()
+ if result == QMessageBox.Yes:
+ for ix in self.ui.annotationsTableView.selectedIndexes():
+ self.ui.annotationsTableView.hideRow(ix.row())
+
@Slot()
def accept(self):
with self.bento.db_sessionMaker() as db_sess:
@@ -646,6 +725,8 @@ def accept(self):
trial.session_id = self.session_id
trial.trial_num = self.ui.trialNumLineEdit.text()
trial.stimulus = self.ui.stimulusLineEdit.text()
+ trial.trial_start_time = datetime.timestamp(datetime.fromisoformat(
+ self.ui.trialDateTimeEdit.textFromDateTime(self.ui.trialDateTimeEdit.dateTime())))
self.updateVideoData(trial, db_sess)
self.updateNeuralData(trial, db_sess)
diff --git a/src/db/editTrialDialog.ui b/src/db/editTrialDialog.ui
index 8a33551..d3916cf 100644
--- a/src/db/editTrialDialog.ui
+++ b/src/db/editTrialDialog.ui
@@ -32,6 +32,23 @@
 <item>
+ <item>
+ <widget class="QLabel" name="trialStartTimeLabel">
+ <property name="text">
+ <string>Trial Start Time:</string>
+ </property>
+ </widget>
+ </item>
+ <item>
+ <widget class="QDateTimeEdit" name="trialDateTimeEdit">
+ <property name="sizePolicy">
+ <sizepolicy hsizetype="Minimum" vsizetype="Fixed">
+ <horstretch>1</horstretch>
+ <verstretch>0</verstretch>
+ </sizepolicy>
+ </property>
+ </widget>
+ </item>
 <item>
@@ -104,6 +121,13 @@
+ <item>
+ <widget class="QPushButton" name="videoPoseDeleteButton">
+ <property name="text">
+ <string>Delete</string>
+ </property>
+ </widget>
+ </item>
 <item>
@@ -145,6 +169,13 @@
+ <item>
+ <widget class="QPushButton" name="annotationsDeleteButton">
+ <property name="text">
+ <string>Delete</string>
+ </property>
+ </widget>
+ </item>
 <item>
@@ -192,6 +223,13 @@
+ <item>
+ <widget class="QPushButton" name="neuralsDeleteButton">
+ <property name="text">
+ <string>Delete</string>
+ </property>
+ </widget>
+ </item>
 <item>
@@ -249,6 +287,13 @@
+ <item>
+ <widget class="QPushButton" name="audiosDeleteButton">
+ <property name="text">
+ <string>Delete</string>
+ </property>
+ </widget>
+ </item>
 <item>
@@ -306,6 +351,13 @@
+ <item>
+ <widget class="QPushButton" name="othersDeleteButton">
+ <property name="text">
+ <string>Delete</string>
+ </property>
+ </widget>
+ </item>
 <item>
diff --git a/src/db/editTrialDialog_ui.py b/src/db/editTrialDialog_ui.py
index 13631bd..823c3ac 100644
--- a/src/db/editTrialDialog_ui.py
+++ b/src/db/editTrialDialog_ui.py
@@ -3,24 +3,18 @@
################################################################################
## Form generated from reading UI file 'editTrialDialog.ui'
##
-## Created by: Qt User Interface Compiler version 6.2.2
+## Created by: Qt User Interface Compiler version 5.15.1
##
## WARNING! All changes made in this file will be lost when recompiling UI file!
################################################################################
-from qtpy.QtCore import (QCoreApplication, QDate, QDateTime, QLocale,
- QMetaObject, QObject, QPoint, QRect,
- QSize, QTime, QUrl, Qt)
-from qtpy.QtGui import (QBrush, QColor, QConicalGradient, QCursor,
- QFont, QFontDatabase, QGradient, QIcon,
- QImage, QKeySequence, QLinearGradient, QPainter,
- QPalette, QPixmap, QRadialGradient, QTransform)
-from qtpy.QtWidgets import (QAbstractButton, QAbstractItemView, QApplication, QDialog,
- QDialogButtonBox, QHBoxLayout, QHeaderView, QLabel,
- QLineEdit, QPushButton, QSizePolicy, QSpacerItem,
- QTreeView, QTreeWidgetItem, QVBoxLayout, QWidget)
+from qtpy.QtCore import *
+from qtpy.QtGui import *
+from qtpy.QtWidgets import *
+
+from widgets.deleteableViews import DeleteableTableView
+from widgets.deleteableViews import DeleteableTreeWidget
-from widgets.deleteableViews import (DeleteableTableView, DeleteableTreeWidget)
class Ui_EditTrialDialog(object):
def setupUi(self, EditTrialDialog):
@@ -42,6 +36,21 @@ def setupUi(self, EditTrialDialog):
self.generalInfoHorizontalLayout.addWidget(self.trialNumLineEdit)
+ self.trialStartTimeLabel = QLabel(EditTrialDialog)
+ self.trialStartTimeLabel.setObjectName(u"trialStartTimeLabel")
+
+ self.generalInfoHorizontalLayout.addWidget(self.trialStartTimeLabel)
+
+ self.trialDateTimeEdit = QDateTimeEdit(EditTrialDialog)
+ self.trialDateTimeEdit.setObjectName(u"trialDateTimeEdit")
+ sizePolicy = QSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed)
+ sizePolicy.setHorizontalStretch(1)
+ sizePolicy.setVerticalStretch(0)
+ sizePolicy.setHeightForWidth(self.trialDateTimeEdit.sizePolicy().hasHeightForWidth())
+ self.trialDateTimeEdit.setSizePolicy(sizePolicy)
+
+ self.generalInfoHorizontalLayout.addWidget(self.trialDateTimeEdit)
+
self.stimulusLabel = QLabel(EditTrialDialog)
self.stimulusLabel.setObjectName(u"stimulusLabel")
@@ -49,11 +58,11 @@ def setupUi(self, EditTrialDialog):
self.stimulusLineEdit = QLineEdit(EditTrialDialog)
self.stimulusLineEdit.setObjectName(u"stimulusLineEdit")
- sizePolicy = QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
- sizePolicy.setHorizontalStretch(1)
- sizePolicy.setVerticalStretch(0)
- sizePolicy.setHeightForWidth(self.stimulusLineEdit.sizePolicy().hasHeightForWidth())
- self.stimulusLineEdit.setSizePolicy(sizePolicy)
+ sizePolicy1 = QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
+ sizePolicy1.setHorizontalStretch(1)
+ sizePolicy1.setVerticalStretch(0)
+ sizePolicy1.setHeightForWidth(self.stimulusLineEdit.sizePolicy().hasHeightForWidth())
+ self.stimulusLineEdit.setSizePolicy(sizePolicy1)
self.generalInfoHorizontalLayout.addWidget(self.stimulusLineEdit)
@@ -91,6 +100,11 @@ def setupUi(self, EditTrialDialog):
self.videosSearchVerticalLayout.addWidget(self.addPosePushButton)
+ self.videoPoseDeleteButton = QPushButton(EditTrialDialog)
+ self.videoPoseDeleteButton.setObjectName(u"videoPoseDeleteButton")
+
+ self.videosSearchVerticalLayout.addWidget(self.videoPoseDeleteButton)
+
self.videosSearchVerticalSpacer = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding)
self.videosSearchVerticalLayout.addItem(self.videosSearchVerticalSpacer)
@@ -121,6 +135,11 @@ def setupUi(self, EditTrialDialog):
self.annotationsSearchVerticalLayout.addWidget(self.annotationsSearchPushButton)
+ self.annotationsDeleteButton = QPushButton(EditTrialDialog)
+ self.annotationsDeleteButton.setObjectName(u"annotationsDeleteButton")
+
+ self.annotationsSearchVerticalLayout.addWidget(self.annotationsDeleteButton)
+
self.annotationsSearchVerticalSpacer = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding)
self.annotationsSearchVerticalLayout.addItem(self.annotationsSearchVerticalSpacer)
@@ -152,6 +171,11 @@ def setupUi(self, EditTrialDialog):
self.neuralsSearchVerticalLayout.addWidget(self.neuralsSearchPushButton)
+ self.neuralsDeleteButton = QPushButton(EditTrialDialog)
+ self.neuralsDeleteButton.setObjectName(u"neuralsDeleteButton")
+
+ self.neuralsSearchVerticalLayout.addWidget(self.neuralsDeleteButton)
+
self.neuralsSearchVerticalSpacer = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding)
self.neuralsSearchVerticalLayout.addItem(self.neuralsSearchVerticalSpacer)
@@ -186,6 +210,11 @@ def setupUi(self, EditTrialDialog):
self.audiosSearchVerticalLayout.addWidget(self.audiosSearchPushButton)
+ self.audiosDeleteButton = QPushButton(EditTrialDialog)
+ self.audiosDeleteButton.setObjectName(u"audiosDeleteButton")
+
+ self.audiosSearchVerticalLayout.addWidget(self.audiosDeleteButton)
+
self.audiosSearchVerticalSpacer = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding)
self.audiosSearchVerticalLayout.addItem(self.audiosSearchVerticalSpacer)
@@ -220,6 +249,11 @@ def setupUi(self, EditTrialDialog):
self.othersSearchVerticalLayout.addWidget(self.othersSearchPushButton)
+ self.othersDeleteButton = QPushButton(EditTrialDialog)
+ self.othersDeleteButton.setObjectName(u"othersDeleteButton")
+
+ self.othersSearchVerticalLayout.addWidget(self.othersDeleteButton)
+
self.othersSearchVerticalSpacer = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding)
self.othersSearchVerticalLayout.addItem(self.othersSearchVerticalSpacer)
@@ -253,17 +287,23 @@ def setupUi(self, EditTrialDialog):
def retranslateUi(self, EditTrialDialog):
EditTrialDialog.setWindowTitle(QCoreApplication.translate("EditTrialDialog", u"Add or Edit Trial", None))
self.trialNumLabel.setText(QCoreApplication.translate("EditTrialDialog", u"Trial Num: ", None))
+ self.trialStartTimeLabel.setText(QCoreApplication.translate("EditTrialDialog", u"Trial Start Time: ", None))
self.stimulusLabel.setText(QCoreApplication.translate("EditTrialDialog", u"Stimulus: ", None))
self.videosLabel.setText(QCoreApplication.translate("EditTrialDialog", u"Video Files: ", None))
self.videosSearchPushButton.setText(QCoreApplication.translate("EditTrialDialog", u"Search...", None))
self.addPosePushButton.setText(QCoreApplication.translate("EditTrialDialog", u"Add Pose...", None))
+ self.videoPoseDeleteButton.setText(QCoreApplication.translate("EditTrialDialog", u"Delete", None))
self.annotationsLabel.setText(QCoreApplication.translate("EditTrialDialog", u"Annotation Files: ", None))
self.annotationsSearchPushButton.setText(QCoreApplication.translate("EditTrialDialog", u"Search...", None))
+ self.annotationsDeleteButton.setText(QCoreApplication.translate("EditTrialDialog", u"Delete", None))
self.neuralsLabel.setText(QCoreApplication.translate("EditTrialDialog", u"Neural Files: ", None))
self.neuralsSearchPushButton.setText(QCoreApplication.translate("EditTrialDialog", u"Search...", None))
+ self.neuralsDeleteButton.setText(QCoreApplication.translate("EditTrialDialog", u"Delete", None))
self.audiosLabel.setText(QCoreApplication.translate("EditTrialDialog", u"Audio Files: ", None))
self.audiosSearchPushButton.setText(QCoreApplication.translate("EditTrialDialog", u"Search...", None))
+ self.audiosDeleteButton.setText(QCoreApplication.translate("EditTrialDialog", u"Delete", None))
self.othersLabel.setText(QCoreApplication.translate("EditTrialDialog", u"Other Files: ", None))
self.othersSearchPushButton.setText(QCoreApplication.translate("EditTrialDialog", u"Search...", None))
+ self.othersDeleteButton.setText(QCoreApplication.translate("EditTrialDialog", u"Delete", None))
# retranslateUi
diff --git a/src/db/schema_sqlalchemy.py b/src/db/schema_sqlalchemy.py
index 5f79119..4ed8020 100644
--- a/src/db/schema_sqlalchemy.py
+++ b/src/db/schema_sqlalchemy.py
@@ -7,7 +7,7 @@
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, Date, Enum, Float, ForeignKey, Integer, String, Time, create_engine, func
from sqlalchemy.orm import relationship, sessionmaker
-from datetime import date
+from datetime import date, datetime
import enum
from os.path import expanduser, exists, sep
@@ -81,11 +81,11 @@ class NeuralData(Base):
file_path = Column(String(512))
sample_rate = Column(Float)
format = Column(String(128))
- start_time = Column(Float) # needs to be convertible to timecode
+ offset_time = Column(Float) # needs to be convertible to timecode
start_frame = Column(Integer)
stop_frame = Column(Integer)
trial = Column(Integer, ForeignKey('trial.id')) # 'Trial.trial_id' is quoted because it's a forward reference
- keys = ['id', 'Neural File Path', 'Sample Rate', 'Format', 'Start Time', 'Start Frame', 'Stop Frame', 'trial_id']
+ keys = ['id', 'Neural File Path', 'Sample Rate', 'Format', 'Offset Time', 'Start Frame', 'Stop Frame', 'trial_id']
def __init__(self, d=None, db_sess=None):
super().__init__()
@@ -93,8 +93,8 @@ def __init__(self, d=None, db_sess=None):
self.fromDict(d, db_sess)
def __repr__(self):
- return "" % (
- self.file_path, self.sample_rate, self.format, self.start_time, self.start_frame, self.stop_frame
+ return "" % (
+ self.file_path, self.sample_rate, self.format, self.offset_time, self.start_frame, self.stop_frame
)
def header(self):
@@ -106,7 +106,7 @@ def toDict(self):
'Neural File Path': self.file_path,
'Sample Rate': self.sample_rate,
'Format': self.format,
- 'Start Time': self.start_time,
+ 'Offset Time': self.offset_time,
'Start Frame': self.start_frame,
'Stop Frame': self.stop_frame,
'trial_id': self.trial
@@ -124,8 +124,8 @@ def fromDict(self, d, db_sess):
self.sample_rate = d['Sample Rate']
if 'Format' in d.keys() and d['Format'] != self.format:
self.format = d['Format']
- if 'Start Time' in d.keys() and d['Start Time'] != self.start_time:
- self.start_time = d['Start Time']
+ if 'Offset Time' in d.keys() and d['Offset Time'] != self.offset_time:
+ self.offset_time = d['Offset Time']
if 'Start Frame' in d.keys() and d['Start Frame'] != self.start_frame:
self.start_frame = d['Start Frame']
if 'Stop Frame' in d.keys() and d['Stop Frame'] != self.stop_frame:
@@ -142,13 +142,13 @@ class VideoData(Base):
id = Column(Integer, primary_key=True)
file_path = Column(String(512))
sample_rate = Column(Float)
- start_time = Column(Float) # needs to be convertible to timecode
+ offset_time = Column(Float) # needs to be convertible to timecode
camera_id = Column(Integer, ForeignKey('camera.id'))
trial = Column(Integer, ForeignKey('trial.id'))
camera = relationship('Camera')
pose_data = relationship('PoseData')
# don't put id first, because then we can't hide it in a QTreeWidget as column 0
- keys = ['Video File Path', 'Sample Rate', 'Start Time', 'Camera Position', 'trial_id', 'id']
+ keys = ['Video File Path', 'Sample Rate', 'Offset Time', 'Camera Position', 'trial_id', 'id']
def __init__(self, d=None, db_sess=None):
super().__init__()
@@ -156,8 +156,8 @@ def __init__(self, d=None, db_sess=None):
self.fromDict(d, db_sess)
def __repr__(self):
- return "" % (
- self.file_path, self.sample_rate, self.start_time
+ return "" % (
+ self.file_path, self.sample_rate, self.offset_time
)
def header(self):
@@ -169,7 +169,7 @@ def toDict(self):
'id': self.id,
'Video File Path': self.file_path,
'Sample Rate': self.sample_rate,
- 'Start Time': self.start_time,
+ 'Offset Time': self.offset_time,
'Camera Position': self.camera.position,
'trial_id': self.trial,
'pose_data': pose_data
@@ -189,8 +189,8 @@ def fromDict(self, d, db_sess):
self.file_path = d['Video File Path']
if 'Sample Rate' in d.keys() and d['Sample Rate'] != self.sample_rate:
self.sample_rate = d['Sample Rate']
- if 'Start Time' in d.keys() and d['Start Time'] != self.start_time:
- self.start_time = d['Start Time']
+ if 'Offset Time' in d.keys() and d['Offset Time'] != self.offset_time:
+ self.offset_time = d['Offset Time']
if self.camera_id != camera.id:
self.camera_id = camera.id
if self.camera != camera:
@@ -209,13 +209,13 @@ class AnnotationsData(Base):
file_path = Column(String(512))
sample_rate = Column(Float)
format = Column(String(128), nullable=False)
- start_time = Column(Float) # needs to be convertible to timecode
+ offset_time = Column(Float) # needs to be convertible to timecode
start_frame = Column(Integer)
stop_frame = Column(Integer)
annotator_name = Column(String(128))
method = Column(String(128)) # e.g. manual, MARS v1_8
trial = Column(Integer, ForeignKey('trial.id'))
- keys = ['id', 'Annotations File Path', 'Sample Rate', 'Format', 'Start Time', 'Start Frame',
+ keys = ['id', 'Annotations File Path', 'Sample Rate', 'Format', 'Offset Time', 'Start Frame',
'Stop Frame', 'Annotator Name', 'Method', 'trial_id']
def __init__(self, d=None, db_sess=None):
@@ -225,10 +225,10 @@ def __init__(self, d=None, db_sess=None):
def __repr__(self):
return ( "" % (
self.file_path, self.sample_rate, self.format,
- self.start_time, self.start_frame, self.stop_frame,
+ self.offset_time, self.start_frame, self.stop_frame,
self.annotator_name, self.method )
)
@@ -241,7 +241,7 @@ def toDict(self):
'Annotations File Path': self.file_path,
'Sample Rate': self.sample_rate,
'Format': self.format,
- 'Start Time': self.start_time,
+ 'Offset Time': self.offset_time,
'Start Frame': self.start_frame,
'Stop Frame': self.stop_frame,
'Annotator Name': self.annotator_name,
@@ -261,8 +261,8 @@ def fromDict(self, d, db_sess):
self.sample_rate = d['Sample Rate']
if 'Format' in d.keys() and d['Format'] != self.format:
self.format = d['Format']
- if 'Start Time' in d.keys() and d['Start Time'] != self.start_time:
- self.start_time = d['Start Time']
+ if 'Offset Time' in d.keys() and d['Offset Time'] != self.offset_time:
+ self.offset_time = d['Offset Time']
if 'Start Frame' in d.keys() and d['Start Frame'] != self.start_frame:
self.start_frame = d['Start Frame']
if 'Stop Frame' in d.keys() and d['Stop Frame'] != self.stop_frame:
@@ -283,12 +283,12 @@ class AudioData(Base):
id = Column(Integer, primary_key=True)
file_path = Column(String(512))
sample_rate = Column(Float)
- start_time = Column(Float) # needs to be convertible to timecode
+ offset_time = Column(Float) # needs to be convertible to timecode
processed_audio_file_path = Column(String(512))
# annotations_id = Column(Integer, ForeignKey('annotations.id'))
trial = Column(Integer, ForeignKey('trial.id'))
- # keys = ['id', 'file_path', 'sample_rate', 'start_time', 'processed_audio_file_path', 'annotations_id', 'trial_id']
- keys = ['id', 'Audio File Path', 'Sample Rate', 'Start Time', 'Processed Audio File Path', 'trial_id']
+ # keys = ['id', 'file_path', 'sample_rate', 'offset_time', 'processed_audio_file_path', 'annotations_id', 'trial_id']
+ keys = ['id', 'Audio File Path', 'Sample Rate', 'Offset Time', 'Processed Audio File Path', 'trial_id']
def __init__(self, d=None, db_sess=None):
super().__init__()
@@ -296,9 +296,9 @@ def __init__(self, d=None, db_sess=None):
self.fromDict(d, db_sess)
def __repr__(self):
- return ( "" % (
- self.file_path, self.sample_rate, self.start_time,
+ self.file_path, self.sample_rate, self.offset_time,
self.start_frame, self.stop_time, self.processed_audio_file_path )
)
@@ -310,7 +310,7 @@ def toDict(self):
'id': self.id,
'Audio File Path': self.file_path,
'Sample Rate': self.sample_rate,
- 'Start Time': self.start_time,
+ 'Offset Time': self.offset_time,
'Processed Audio File Path': self.processed_audio_file_path,
# 'annotations_id': self.annotations_id,
'trial_id': self.trial
@@ -326,8 +326,8 @@ def fromDict(self, d, db_sess):
self.file_path = d['Audio File Path']
if 'Sample Rate' in d.keys() and d['Sample Rate'] != self.sample_rate:
self.sample_rate = d['Sample Rate']
- if 'Start Time' in d.keys() and d['Start Time'] != self.start_time:
- self.start_time = d['Start Time']
+ if 'Offset Time' in d.keys() and d['Offset Time'] != self.offset_time:
+ self.offset_time = d['Offset Time']
if ('Processed Audio File Path' in d.keys()
and d['Processed Audio File Path'] != self.processed_audio_file_path):
self.processed_audio_file_path = d['Processed Audio File Path']
@@ -345,12 +345,12 @@ class PoseData(Base):
pose_id = Column(Integer, primary_key=True)
file_path = Column(String(512))
sample_rate = Column(Float)
- start_time = Column(Float) # needs to be convertible to timecode
+ offset_time = Column(Float) # needs to be convertible to timecode
format = Column(String(128), nullable=False)
video = Column(Integer, ForeignKey('video_data.id'))
trial = Column(Integer, ForeignKey('trial.id'))
# don't put id first, because then we can't hide it in a QTreeWidget as column 0
- keys = ['Pose File Path', 'Sample Rate', 'Start Time', 'Format', 'video_id', 'trial_id', 'id']
+ keys = ['Pose File Path', 'Sample Rate', 'Offset Time', 'Format', 'video_id', 'trial_id', 'id']
def __init__(self, d=None, db_sess=None):
super().__init__()
@@ -358,9 +358,9 @@ def __init__(self, d=None, db_sess=None):
self.fromDict(d, db_sess)
def __repr__(self):
- return ( "" % (
- self.file_path, self.sample_rate, self.start_time, self.format )
+ self.file_path, self.sample_rate, self.offset_time, self.format )
)
def header(self):
@@ -371,7 +371,7 @@ def toDict(self):
'id': self.pose_id,
'Pose File Path': self.file_path,
'Sample Rate': self.sample_rate,
- 'Start Time': self.start_time,
+ 'Offset Time': self.offset_time,
'Format': self.format,
'video_id': self.video,
'trial_id': self.trial
@@ -387,8 +387,8 @@ def fromDict(self, d, db_sess):
self.file_path = d['Pose File Path']
if 'Sample Rate' in d.keys() and d['Sample Rate'] != self.sample_rate:
self.sample_rate = d['Sample Rate']
- if 'Start Time' in d.keys() and d['Start Time'] != self.start_time:
- self.start_time = d['Start Time']
+ if 'Offset Time' in d.keys() and d['Offset Time'] != self.offset_time:
+ self.offset_time = d['Offset Time']
if 'Format' in d.keys() and d['Format'] != self.format:
self.format = d['Format']
if 'video_id' in d.keys() and d['video_id'] != self.video:
@@ -405,12 +405,12 @@ class OtherData(Base):
id = Column(Integer, primary_key=True)
file_path = Column(String(512))
sample_rate = Column(Float)
- start_time = Column(Float) # needs to be convertible to timecode
+ offset_time = Column(Float) # needs to be convertible to timecode
start_frame = Column(Integer)
stop_frame = Column(Integer)
format = Column(String(128))
trial = Column(Integer, ForeignKey('trial.id'))
- keys = ['id', 'File Path', 'Sample Rate', 'Start Time', 'Start Frame', 'Stop Frame', 'Format', 'trial_id']
+ keys = ['id', 'File Path', 'Sample Rate', 'Offset Time', 'Start Frame', 'Stop Frame', 'Format', 'trial_id']
def __init__(self, d=None, db_sess=None):
super().__init__()
@@ -418,10 +418,10 @@ def __init__(self, d=None, db_sess=None):
self.fromDict(d, db_sess)
def __repr__(self):
- return ( "" % (
- self.file_path, self.sample_rate, self.start_time,
- self.start_time, self.stop_frame, self.format )
+ self.file_path, self.sample_rate, self.offset_time,
+ self.offset_time, self.stop_frame, self.format )
)
def header(self):
@@ -432,7 +432,7 @@ def toDict(self):
'id': self.id,
'File Path': self.file_path,
'Sample Rate': self.sample_rate,
- 'Start Time': self.start_time,
+ 'Offset Time': self.offset_time,
'Start Frame': self.start_frame,
'Stop Frame': self.stop_frame,
'Format': self.format,
@@ -449,8 +449,8 @@ def fromDict(self, d, db_sess):
self.file_path = d['File Path']
if 'Sample Rate' in d.keys() and d['Sample Rate'] != self.sample_rate:
self.sample_rate = d['Sample Rate']
- if 'Start Time' in d.keys() and d['Start Time'] != self.start_time:
- self.start_time = d['Start Time']
+ if 'Offset Time' in d.keys() and d['Offset Time'] != self.offset_time:
+ self.offset_time = d['Offset Time']
if 'Start Frame' in d.keys() and d['Start Frame'] != self.start_frame:
self.start_frame = d['Start Frame']
if 'Stop Frame' in d.keys() and d['Stop Frame'] != self.stop_frame:
@@ -493,6 +493,7 @@ class Trial(Base):
id = Column(Integer, primary_key=True)
session_id = Column(Integer, ForeignKey('session.id'))
trial_num = Column(Integer)
+ trial_start_time = Column(Float)
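+ # POSIX timestamp of the trial start; per-file offset_time values are seconds relative to this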
stimulus = Column(String(128))
video_data = relationship('VideoData',
cascade='all, delete, delete-orphan') # could be more than one, e.g. top view, front view
diff --git a/src/db/trialWindow.py b/src/db/trialWindow.py
index 5448809..59db9dd 100644
--- a/src/db/trialWindow.py
+++ b/src/db/trialWindow.py
@@ -231,6 +231,7 @@ def loadTrial(self):
trial = db_session.query(Trial).where(Trial.id == trial_id).one()
self.bento.session_id = trial.session_id
self.bento.trial_id = trial.id
+ self.bento.trial_start_time = trial.trial_start_time
videos = []
videoSelectionModel = self.ui.videoTableView.selectionModel()
annotateSelectionModel = self.ui.annotationTableView.selectionModel()
diff --git a/src/mainWindow.py b/src/mainWindow.py
index 877480f..1941b52 100644
--- a/src/mainWindow.py
+++ b/src/mainWindow.py
@@ -6,6 +6,7 @@
from qtpy.QtCore import Qt, QEvent, Signal, Slot
from qtpy.QtWidgets import QMainWindow, QMenuBar
from db.trialWindow import TrialDockWidget
+from datetime import datetime, timezone
class MainWindow(QMainWindow):
@@ -27,6 +28,8 @@ def __init__(self, bento):
self.ui.nextButton.clicked.connect(bento.toNextEvent)
self.ui.previousButton.clicked.connect(bento.toPrevEvent)
self.ui.quitButton.clicked.connect(bento.quit)
+ self.ui.currentTimeEdit.set_bento(bento)
+ self.ui.currentFrameBox.textChanged.connect(bento.jumpToFrame)
bento.quitting.connect(self.close)
self.quitting.connect(bento.quit)
@@ -36,6 +39,7 @@ def __init__(self, bento):
self.ui.doubleFrameRateButton.clicked.connect(bento.player.doubleFrameRate)
self.ui.oneXFrameRateButton.clicked.connect(bento.player.resetFrameRate)
self.ui.annotationsView.set_bento(bento)
+ self.ui.annotationsView.set_showTickLabels(False)
self.ui.annotationsView.setScene(bento.annotationsScene)
bento.annotationsScene.sceneRectChanged.connect(self.ui.annotationsView.update)
self.ui.annotationsView.scale(10., self.ui.annotationsView.height())
@@ -51,6 +55,8 @@ def __init__(self, bento):
self.fileMenu = self.menuBar.addMenu("File")
self.saveAnnotationsAction = self.fileMenu.addAction("Save Annotations...")
self.saveAnnotationsAction.triggered.connect(bento.save_annotations)
+ self.exportDataAction = self.fileMenu.addAction("Export Data...")
+ self.exportDataAction.triggered.connect(bento.export_data)
self.fileMenuSeparatorAction = self.fileMenu.addSeparator()
self.setInvestigatorAction = self.fileMenu.addAction("Set Investigator...")
self.setInvestigatorAction.triggered.connect(bento.set_investigator)
@@ -115,7 +121,12 @@ def keyPressEvent(self, event):
@Slot(tc.Timecode)
def updateTime(self, t):
- self.ui.timeLabel.setText(f"{t} ({t.frame_number})")
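+ # render the Timecode as a wall-clock time of day; using UTC keeps the
+ # seconds-to-time conversion free of local-timezone offsets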
+ time = datetime.fromtimestamp(t.float, tz=timezone.utc).time()
+ maxTime = datetime.fromtimestamp(self.bento.time_end.float, tz=timezone.utc).time()
+ self.ui.currentTimeEdit.setMaximumTime(maxTime)
+ self.ui.currentTimeEdit.setTime(time)
+ self.ui.currentFrameBox.setValue(int(t.frames))
+ self.ui.currentFrameBox.setMaximum(int(self.bento.time_end.frames))
self.ui.annotationsView.updatePosition(t)
self.show()
diff --git a/src/mainWindow.ui b/src/mainWindow.ui
index 1cde407..646619e 100644
--- a/src/mainWindow.ui
+++ b/src/mainWindow.ui
@@ -7,7 +7,7 @@
 <x>0</x>
 <y>0</y>
 <width>460</width>
- <height>328</height>
+ <height>350</height>
@@ -19,13 +19,13 @@
 <width>460</width>
- <height>328</height>
+ <height>350</height>
- <width>457</width>
- <height>328</height>
+ <width>493</width>
+ <height>496</height>
@@ -40,11 +40,52 @@
 <item>
- <widget class="QLabel" name="timeLabel">
- <property name="text">
- <string>Current Time</string>
- </property>
- </widget>
+ <layout class="QHBoxLayout" name="currentTimeFrameLayout">
+ <item>
+ <widget class="QLabel" name="currentTimeLabel">
+ <property name="text">
+ <string>Current Time</string>
+ </property>
+ </widget>
+ </item>
+ <item>
+ <widget class="CustomTimeEdit" name="currentTimeEdit">
+ <property name="displayFormat">
+ <string>HH:mm:ss.zzz</string>
+ </property>
+ </widget>
+ </item>
+ <item>
+ <widget class="QLabel" name="currentFrameLabel">
+ <property name="text">
+ <string>Current Frame</string>
+ </property>
+ </widget>
+ </item>
+ <item>
+ <widget class="QSpinBox" name="currentFrameBox">
+ <property name="minimum">
+ <number>1</number>
+ </property>
+ <property name="maximum">
+ <number>10000000</number>
+ </property>
+ </widget>
+ </item>
+ <item>
+ <spacer name="horizontalSpacer">
+ <property name="orientation">
+ <enum>Qt::Horizontal</enum>
+ </property>
+ <property name="sizeHint" stdset="0">
+ <size>
+ <width>40</width>
+ <height>20</height>
+ </size>
+ </property>
+ </spacer>
+ </item>
+ </layout>
 </item>
@@ -133,14 +174,14 @@
 <property name="text">
- <string>Previous</string>
+ <string>Previous Annotation</string>
 <property name="text">
- <string>Next</string>
+ <string>Next Annotation</string>
@@ -151,7 +192,7 @@
 <property name="text">
- <string>/2</string>
+ <string>0.5x</string>
@@ -165,7 +206,7 @@
 <property name="text">
- <string>* 2</string>
+ <string>2x</string>
@@ -232,6 +273,11 @@
 <extends>QGraphicsView</extends>
 <header>widgets.annotationsWidget.h</header>
+ <customwidget>
+ <class>CustomTimeEdit</class>
+ <extends>QTimeEdit</extends>
+ <header>timeEdit.h</header>
+ </customwidget>
diff --git a/src/mainWindow_ui.py b/src/mainWindow_ui.py
index 0a05075..e04d48f 100644
--- a/src/mainWindow_ui.py
+++ b/src/mainWindow_ui.py
@@ -3,30 +3,31 @@
################################################################################
## Form generated from reading UI file 'mainWindow.ui'
##
-## Created by: Qt User Interface Compiler version 6.1.2
+## Created by: Qt User Interface Compiler version 5.15.1
##
## WARNING! All changes made in this file will be lost when recompiling UI file!
################################################################################
-from qtpy.QtCore import * # type: ignore
-from qtpy.QtGui import * # type: ignore
-from qtpy.QtWidgets import * # type: ignore
+from qtpy.QtCore import *
+from qtpy.QtGui import *
+from qtpy.QtWidgets import *
from widgets.annotationsWidget import AnnotationsView
+from timeEdit import CustomTimeEdit
class Ui_MainWindow(object):
def setupUi(self, MainWindow):
if not MainWindow.objectName():
MainWindow.setObjectName(u"MainWindow")
- MainWindow.resize(460, 328)
+ MainWindow.resize(460, 350)
sizePolicy = QSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
sizePolicy.setHorizontalStretch(0)
sizePolicy.setVerticalStretch(0)
sizePolicy.setHeightForWidth(MainWindow.sizePolicy().hasHeightForWidth())
MainWindow.setSizePolicy(sizePolicy)
- MainWindow.setMinimumSize(QSize(460, 328))
- MainWindow.setMaximumSize(QSize(457, 328))
+ MainWindow.setMinimumSize(QSize(460, 350))
+ MainWindow.setMaximumSize(QSize(493, 496))
self.centralwidget = QWidget(MainWindow)
self.centralwidget.setObjectName(u"centralwidget")
sizePolicy1 = QSizePolicy(QSizePolicy.Minimum, QSizePolicy.Minimum)
@@ -36,10 +37,36 @@ def setupUi(self, MainWindow):
self.centralwidget.setSizePolicy(sizePolicy1)
self.verticalLayout = QVBoxLayout(self.centralwidget)
self.verticalLayout.setObjectName(u"verticalLayout")
- self.timeLabel = QLabel(self.centralwidget)
- self.timeLabel.setObjectName(u"timeLabel")
+ self.currentTimeFrameLayout = QHBoxLayout()
+ self.currentTimeFrameLayout.setObjectName(u"currentTimeFrameLayout")
+ self.currentTimeLabel = QLabel(self.centralwidget)
+ self.currentTimeLabel.setObjectName(u"currentTimeLabel")
- self.verticalLayout.addWidget(self.timeLabel)
+ self.currentTimeFrameLayout.addWidget(self.currentTimeLabel)
+
+ self.currentTimeEdit = CustomTimeEdit(self.centralwidget)
+ self.currentTimeEdit.setObjectName(u"currentTimeEdit")
+
+ self.currentTimeFrameLayout.addWidget(self.currentTimeEdit)
+
+ self.currentFrameLabel = QLabel(self.centralwidget)
+ self.currentFrameLabel.setObjectName(u"currentFrameLabel")
+
+ self.currentTimeFrameLayout.addWidget(self.currentFrameLabel)
+
+ self.currentFrameBox = QSpinBox(self.centralwidget)
+ self.currentFrameBox.setObjectName(u"currentFrameBox")
+ self.currentFrameBox.setMinimum(1)
+ self.currentFrameBox.setMaximum(10000000)
+
+ self.currentTimeFrameLayout.addWidget(self.currentFrameBox)
+
+ self.horizontalSpacer = QSpacerItem(40, 20, QSizePolicy.Expanding, QSizePolicy.Minimum)
+
+ self.currentTimeFrameLayout.addItem(self.horizontalSpacer)
+
+
+ self.verticalLayout.addLayout(self.currentTimeFrameLayout)
self.annotLabel = QLabel(self.centralwidget)
self.annotLabel.setObjectName(u"annotLabel")
@@ -176,7 +203,9 @@ def setupUi(self, MainWindow):
def retranslateUi(self, MainWindow):
MainWindow.setWindowTitle(QCoreApplication.translate("MainWindow", u"MainWindow", None))
- self.timeLabel.setText(QCoreApplication.translate("MainWindow", u"Current Time", None))
+ self.currentTimeLabel.setText(QCoreApplication.translate("MainWindow", u"Current Time", None))
+ self.currentTimeEdit.setDisplayFormat(QCoreApplication.translate("MainWindow", u"HH:mm:ss.zzz", None))
+ self.currentFrameLabel.setText(QCoreApplication.translate("MainWindow", u"Current Frame", None))
self.annotLabel.setText(QCoreApplication.translate("MainWindow", u"annotation label", None))
self.toStartButton.setText(QCoreApplication.translate("MainWindow", u"|<", None))
self.fbButton.setText(QCoreApplication.translate("MainWindow", u"<<", None))
@@ -185,11 +214,11 @@ def retranslateUi(self, MainWindow):
self.stepButton.setText(QCoreApplication.translate("MainWindow", u">", None))
self.ffButton.setText(QCoreApplication.translate("MainWindow", u">>", None))
self.toEndButton.setText(QCoreApplication.translate("MainWindow", u">|", None))
- self.previousButton.setText(QCoreApplication.translate("MainWindow", u"Previous", None))
- self.nextButton.setText(QCoreApplication.translate("MainWindow", u"Next", None))
- self.halveFrameRateButton.setText(QCoreApplication.translate("MainWindow", u"/2", None))
+ self.previousButton.setText(QCoreApplication.translate("MainWindow", u"Previous Annotation", None))
+ self.nextButton.setText(QCoreApplication.translate("MainWindow", u"Next Annotation", None))
+ self.halveFrameRateButton.setText(QCoreApplication.translate("MainWindow", u"0.5x", None))
self.oneXFrameRateButton.setText(QCoreApplication.translate("MainWindow", u"1x", None))
- self.doubleFrameRateButton.setText(QCoreApplication.translate("MainWindow", u"* 2", None))
+ self.doubleFrameRateButton.setText(QCoreApplication.translate("MainWindow", u"2x", None))
self.newChannelPushButton.setText(QCoreApplication.translate("MainWindow", u"New Channel", None))
self.trialPushButton.setText(QCoreApplication.translate("MainWindow", u"Select Trial...", None))
self.quitButton.setText(QCoreApplication.translate("MainWindow", u"Quit", None))
diff --git a/src/models/tableModel.py b/src/models/tableModel.py
index 968c842..80fde6f 100644
--- a/src/models/tableModel.py
+++ b/src/models/tableModel.py
@@ -10,6 +10,8 @@ def __init__(self, parent, mylist, header, *args):
self.mylist = mylist
self.header = header
self.colorRoleColumns = set()
+ self.toolTipColumns = set()
+ self.setToolTipColumn()
def rowCount(self, parent):
return len(self.mylist)
@@ -20,7 +22,7 @@ def columnCount(self, parent):
def data(self, index, role):
if not isinstance(index, QModelIndex) or not index.isValid():
raise RuntimeError("Index is not valid")
- if role not in (Qt.DisplayRole, Qt.BackgroundRole, Qt.EditRole):
+ if role not in (Qt.DisplayRole, Qt.BackgroundRole, Qt.EditRole, Qt.ToolTipRole):
return None
row = self.mylist[index.row()]
if isinstance(row, (tuple, list)):
@@ -29,7 +31,9 @@ def data(self, index, role):
datum = row[self.header[index.column()]]
else:
raise RuntimeError(f"Can't handle indexing with data of type {type(row)}")
- if (index.column() in self.colorRoleColumns) and (role == Qt.BackgroundRole or role == Qt.EditRole):
+ if (index.column() in self.toolTipColumns) and (role == Qt.ToolTipRole):
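+ # Offset Time columns advertise their expected format (seconds.milliseconds)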
+ return 'ss.ms'
+ elif (index.column() in self.colorRoleColumns) and (role == Qt.BackgroundRole or role == Qt.EditRole):
return QColor(datum)
elif (index.column() not in self.colorRoleColumns) and (role in (Qt.DisplayRole, Qt.EditRole)):
return str(datum)
@@ -39,6 +43,8 @@ def data(self, index, role):
def headerData(self, col, orientation, role):
if orientation == Qt.Horizontal and role == Qt.DisplayRole:
return self.header[col]
+ if col in self.toolTipColumns and role == Qt.ToolTipRole:
+ return 'ss.ms'
return None
def sort(self, col, order):
@@ -72,6 +78,10 @@ def appendData(self, newData):
def __iter__(self):
return TableModelIterator(self.mylist)
+ def setToolTipColumn(self):
+ if 'Offset Time' in self.header:
+ self.toolTipColumns.add(self.header.index('Offset Time'))
+
def setColorRoleColumn(self, column):
self.colorRoleColumns.add(column)
@@ -89,7 +99,7 @@ def __init__(self, parent, mylist, header, *args):
def flags(self, index):
flags = super().flags(index)
if index.column() not in self.immutableColumns:
- flags |= Qt.ItemIsEditable
+ # Qt.ToolTip is a window type, not an item flag; tooltips come from
+ # Qt.ToolTipRole in data(), so only the editable flag belongs here
+ flags |= Qt.ItemIsEditable
return flags
def setData(self, index, value, role=Qt.EditRole):
diff --git a/src/neural/neuralFrame.py b/src/neural/neuralFrame.py
index a8f6a3e..5c6d153 100644
--- a/src/neural/neuralFrame.py
+++ b/src/neural/neuralFrame.py
@@ -1,17 +1,22 @@
# neuralFrame.py
from neural.neuralFrame_ui import Ui_neuralFrame
+from processing.processing import ProcessingRegistry
+from plugin.plugin_processing_kMeansClustering import kMeansClustering
from qtpy.QtCore import Qt, Signal, Slot
from qtpy.QtGui import QPixmap
-from qtpy.QtWidgets import QFrame
+from qtpy.QtWidgets import QFrame, QMenu
import time
from timecode import Timecode
from widgets.neuralWidget import NeuralScene
+from dataExporter import DataExporter
+from pynwb import NWBFile
import numpy as np
from utils import fix_path
from os.path import isabs
-class NeuralFrame(QFrame):
+
+class NeuralFrame(QFrame, DataExporter):
openReader = Signal(str)
quitting = Signal()
@@ -19,8 +24,9 @@ class NeuralFrame(QFrame):
active_channel_changed = Signal(str)
def __init__(self, bento):
- # super(NeuralDockWidget, self).__init__()
- super(NeuralFrame, self).__init__()
+ QFrame.__init__(self)
+ DataExporter.__init__(self)
+ self.dataExportType = "neural"
self.bento = bento
# self.ui = Ui_NeuralDockWidget()
self.ui = Ui_neuralFrame()
@@ -29,6 +35,15 @@ def __init__(self, bento):
self.quitting.connect(self.bento.quit)
self.neuralScene = NeuralScene()
self.active_channel_changed.connect(self.neuralScene.setActiveChannel)
+ self.neuralPluginsMenu = QMenu("neural plugins")
+ self.ui.launchPlugin.setMenu(self.neuralPluginsMenu)
+ self.ui.launchPlugin.setToolTip("click to see neural plugin options")
+ self.eventTriggeredAvg = self.neuralPluginsMenu.addAction("Event Triggered Average")
+ self.clusteringMenu = QMenu("Clustering")
+ self.kMeansClustering = self.clusteringMenu.addAction("K-Means Clustering")
+ self.neuralPluginsMenu.addMenu(self.clusteringMenu)
+ self.eventTriggeredAvg.triggered.connect(self.launchEventTriggeredAvg)
+ self.kMeansClustering.triggered.connect(self.launchKMeansClustering)
self.ui.neuralView.setScene(self.neuralScene)
self.ui.neuralView.set_bento(bento)
self.ui.neuralView.scale(10., self.ui.neuralView.height())
@@ -58,10 +73,10 @@ def load(self, neuralData, base_dir):
neuralData.sample_rate,
neuralData.start_frame,
neuralData.stop_frame,
- self.bento.time_start,
+ self.bento.time_start_end_timecode['neural'][0][0],
self.ui.showTraceRadioButton.isChecked(),
self.ui.showHeatMapRadioButton.isChecked(),
- self.ui.showAnnotationsCheckBox.checkState()
+ self.ui.showAnnotationsCheckBox.isChecked()
)
self.ui.dataMinLabel.setText(f"{self.neuralScene.data_min:.3f}")
self.ui.dataMaxLabel.setText(f"{self.neuralScene.data_max:.3f}")
@@ -124,6 +139,7 @@ def showNeuralTraces(self, checked):
self.neuralScene.showTraces(checked)
if checked:
self.ui.showAnnotationsCheckBox.setEnabled(True)
+ self.ui.showAnnotationsCheckBox.setChecked(True)
self.neuralScene.showAnnotations(
self.ui.showAnnotationsCheckBox.isChecked()
)
@@ -134,6 +150,7 @@ def showNeuralHeatMap(self, checked):
self.neuralScene.showHeatmap(checked)
if checked:
self.ui.showAnnotationsCheckBox.setEnabled(False)
+ self.ui.showAnnotationsCheckBox.setChecked(False)
self.neuralScene.showAnnotations(False)
@Slot(int)
@@ -141,3 +158,22 @@ def showNeuralAnnotations(self, state):
if isinstance(self.neuralScene, NeuralScene):
self.neuralScene.showAnnotations(state > 0)
+ def launchEventTriggeredAvg(self):
+ self.processing_registry = ProcessingRegistry(self.nwbFile, self.bento)
+ self.processing_registry.load_plugins()
+ self.processing_class = self.processing_registry('BTA')
+ self.processing_class.show()
+
+ def launchKMeansClustering(self):
+ self.processing_registry = ProcessingRegistry(self.nwbFile, self.bento)
+ self.processing_registry.load_plugins()
+ self.processing_class = self.processing_registry('kMeansClustering')
+ self.processing_class.setNeural(self)
+ self.processing_class.show()
+
+ def exportToNWBFile(self, nwbFile: NWBFile):
+ print(f"Export data from {self.dataExportType} to NWB file")
+ if isinstance(self.neuralScene, NeuralScene):
+ self.nwbFile = self.neuralScene.exportToNWBFile(nwbFile)
+
+ return self.nwbFile
\ No newline at end of file
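[Reviewer note, not part of the diff] NeuralFrame now mixes in DataExporter. A minimal sketch of the contract implied by the calls above; the names are inferred from this diff, not checked against the real dataExporter.py.

    from pynwb import NWBFile

    class DataExporter:
        """Mixin: anything that can serialize itself into an NWB file."""
        def __init__(self):
            self.dataExportType = "generic"  # overridden, e.g. "neural"

        def exportToNWBFile(self, nwbFile: NWBFile) -> NWBFile:
            raise NotImplementedError("Implement in the derived class")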
diff --git a/src/neural/neuralFrame.ui b/src/neural/neuralFrame.ui
index 0bc2925..23caa32 100644
--- a/src/neural/neuralFrame.ui
+++ b/src/neural/neuralFrame.ui
@@ -39,13 +39,80 @@
[Qt Designer XML garbled in extraction: the markup tags of this hunk were stripped, leaving only property values. The hunk adds a left horizontal spacer and a "launchPlugin" QToolButton (strong focus, 13 pt font, InstantPopup mode, text beside icon, down arrow, text "Launch Neural Plugin") to the top layout; the regenerated neuralFrame_ui.py below shows the equivalent widget setup.]
diff --git a/src/neural/neuralFrame_ui.py b/src/neural/neuralFrame_ui.py
index f981c69..9b9344a 100644
--- a/src/neural/neuralFrame_ui.py
+++ b/src/neural/neuralFrame_ui.py
@@ -3,26 +3,19 @@
################################################################################
## Form generated from reading UI file 'neuralFrame.ui'
##
-## Created by: Qt User Interface Compiler version 6.2.2
+## Created by: Qt User Interface Compiler version 5.15.1
##
## WARNING! All changes made in this file will be lost when recompiling UI file!
################################################################################
-from qtpy.QtCore import (QCoreApplication, QDate, QDateTime, QLocale,
- QMetaObject, QObject, QPoint, QRect,
- QSize, QTime, QUrl, Qt)
-from qtpy.QtGui import (QBrush, QColor, QConicalGradient, QCursor,
- QFont, QFontDatabase, QGradient, QIcon,
- QImage, QKeySequence, QLinearGradient, QPainter,
- QPalette, QPixmap, QRadialGradient, QTransform)
-from qtpy.QtWidgets import (QAbstractScrollArea, QApplication, QCheckBox, QFrame,
- QGraphicsView, QHBoxLayout, QLabel, QLayout,
- QRadioButton, QSizePolicy, QSpacerItem, QVBoxLayout,
- QWidget)
+from qtpy.QtCore import *
+from qtpy.QtGui import *
+from qtpy.QtWidgets import *
from widgets.annotationsWidget import AnnotationsView
from widgets.neuralWidget import NeuralView
+
class Ui_neuralFrame(object):
def setupUi(self, neuralFrame):
if not neuralFrame.objectName():
@@ -42,6 +35,31 @@ def setupUi(self, neuralFrame):
self.horizontalLayout.setObjectName(u"horizontalLayout")
self.horizontalLayout.setSizeConstraint(QLayout.SetMaximumSize)
self.horizontalLayout.setContentsMargins(-1, -1, -1, 0)
+ self.leftHorizontalSpacer = QSpacerItem(40, 20, QSizePolicy.Expanding, QSizePolicy.Minimum)
+
+ self.horizontalLayout.addItem(self.leftHorizontalSpacer)
+
+ self.launchPlugin = QToolButton(neuralFrame)
+ self.launchPlugin.setObjectName(u"launchPlugin")
+ sizePolicy1 = QSizePolicy(QSizePolicy.Preferred, QSizePolicy.Preferred)
+ sizePolicy1.setHorizontalStretch(0)
+ sizePolicy1.setVerticalStretch(0)
+ sizePolicy1.setHeightForWidth(self.launchPlugin.sizePolicy().hasHeightForWidth())
+ self.launchPlugin.setSizePolicy(sizePolicy1)
+ self.launchPlugin.setMinimumSize(QSize(0, 0))
+ self.launchPlugin.setMaximumSize(QSize(16777215, 16777215))
+ font = QFont()
+ font.setPointSize(13)
+ self.launchPlugin.setFont(font)
+ self.launchPlugin.setFocusPolicy(Qt.StrongFocus)
+ self.launchPlugin.setAutoFillBackground(False)
+ self.launchPlugin.setIconSize(QSize(8, 8))
+ self.launchPlugin.setPopupMode(QToolButton.InstantPopup)
+ self.launchPlugin.setToolButtonStyle(Qt.ToolButtonTextBesideIcon)
+ self.launchPlugin.setArrowType(Qt.DownArrow)
+
+ self.horizontalLayout.addWidget(self.launchPlugin)
+
self.horizontalSpacer = QSpacerItem(40, 20, QSizePolicy.Expanding, QSizePolicy.Minimum)
self.horizontalLayout.addItem(self.horizontalSpacer)
@@ -95,11 +113,11 @@ def setupUi(self, neuralFrame):
self.neuralView = NeuralView(neuralFrame)
self.neuralView.setObjectName(u"neuralView")
- sizePolicy1 = QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
- sizePolicy1.setHorizontalStretch(1)
- sizePolicy1.setVerticalStretch(1)
- sizePolicy1.setHeightForWidth(self.neuralView.sizePolicy().hasHeightForWidth())
- self.neuralView.setSizePolicy(sizePolicy1)
+ sizePolicy2 = QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
+ sizePolicy2.setHorizontalStretch(1)
+ sizePolicy2.setVerticalStretch(1)
+ sizePolicy2.setHeightForWidth(self.neuralView.sizePolicy().hasHeightForWidth())
+ self.neuralView.setSizePolicy(sizePolicy2)
self.neuralView.setFrameShape(QFrame.NoFrame)
self.neuralView.setFrameShadow(QFrame.Plain)
self.neuralView.setLineWidth(0)
@@ -113,11 +131,11 @@ def setupUi(self, neuralFrame):
self.annotationsView = AnnotationsView(neuralFrame)
self.annotationsView.setObjectName(u"annotationsView")
- sizePolicy2 = QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Preferred)
- sizePolicy2.setHorizontalStretch(4)
- sizePolicy2.setVerticalStretch(0)
- sizePolicy2.setHeightForWidth(self.annotationsView.sizePolicy().hasHeightForWidth())
- self.annotationsView.setSizePolicy(sizePolicy2)
+ sizePolicy3 = QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Preferred)
+ sizePolicy3.setHorizontalStretch(4)
+ sizePolicy3.setVerticalStretch(0)
+ sizePolicy3.setHeightForWidth(self.annotationsView.sizePolicy().hasHeightForWidth())
+ self.annotationsView.setSizePolicy(sizePolicy3)
self.annotationsView.setMinimumSize(QSize(0, 64))
self.annotationsView.setMaximumSize(QSize(16777215, 64))
self.annotationsView.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
@@ -134,6 +152,10 @@ def setupUi(self, neuralFrame):
def retranslateUi(self, neuralFrame):
neuralFrame.setWindowTitle(QCoreApplication.translate("neuralFrame", u"Neural Data", None))
+#if QT_CONFIG(tooltip)
+ self.launchPlugin.setToolTip("")
+#endif // QT_CONFIG(tooltip)
+ self.launchPlugin.setText(QCoreApplication.translate("neuralFrame", u"Launch Neural Plugin", None))
self.showTraceRadioButton.setText(QCoreApplication.translate("neuralFrame", u"Show Trace", None))
self.showHeatMapRadioButton.setText(QCoreApplication.translate("neuralFrame", u"Show HeatMap", None))
self.showAnnotationsCheckBox.setText(QCoreApplication.translate("neuralFrame", u"Show Annotations", None))
diff --git a/src/plugin/behaviorTriggeredAverage.ui b/src/plugin/behaviorTriggeredAverage.ui
new file mode 100644
index 0000000..a81d7ca
--- /dev/null
+++ b/src/plugin/behaviorTriggeredAverage.ui
@@ -0,0 +1,436 @@
[Qt Designer XML garbled in extraction: the markup tags of this new file were stripped, leaving only property values. The form "BTAFrame" (945x860, titled "Behavior Triggered Average") holds two scroll areas: a left options panel (Save tool button; behavior-trigger and channel combo boxes; "Align at start"/"Align at end" radio buttons; window, bin size, "Discard bouts under" and "Merge bouts under" spin boxes; element-to-analyze combo box; channel list; z-score checkbox; "Behaviors to show" grid) and a right plot area. The regenerated behaviorTriggeredAverage_ui.py below shows the equivalent widget setup.]
diff --git a/src/plugin/behaviorTriggeredAverage_ui.py b/src/plugin/behaviorTriggeredAverage_ui.py
new file mode 100644
index 0000000..4058666
--- /dev/null
+++ b/src/plugin/behaviorTriggeredAverage_ui.py
@@ -0,0 +1,292 @@
+# -*- coding: utf-8 -*-
+
+################################################################################
+## Form generated from reading UI file 'behaviorTriggeredAverage.ui'
+##
+## Created by: Qt User Interface Compiler version 5.15.1
+##
+## WARNING! All changes made in this file will be lost when recompiling UI file!
+################################################################################
+
+from qtpy.QtCore import *
+from qtpy.QtGui import *
+from qtpy.QtWidgets import *
+
+
+class Ui_BTAFrame(object):
+ def setupUi(self, BTAFrame):
+ if not BTAFrame.objectName():
+ BTAFrame.setObjectName(u"BTAFrame")
+ BTAFrame.resize(945, 860)
+ self.horizontalLayout = QHBoxLayout(BTAFrame)
+ self.horizontalLayout.setObjectName(u"horizontalLayout")
+ self.userOptionsScrollArea = QScrollArea(BTAFrame)
+ self.userOptionsScrollArea.setObjectName(u"userOptionsScrollArea")
+ self.userOptionsScrollArea.setWidgetResizable(True)
+ self.userOptionScrollAreaWidget = QWidget()
+ self.userOptionScrollAreaWidget.setObjectName(u"userOptionScrollAreaWidget")
+ self.userOptionScrollAreaWidget.setGeometry(QRect(0, 0, 302, 834))
+ self.scrollAreaLayout = QVBoxLayout(self.userOptionScrollAreaWidget)
+ self.scrollAreaLayout.setObjectName(u"scrollAreaLayout")
+ self.verticalSpacer_2 = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding)
+
+ self.scrollAreaLayout.addItem(self.verticalSpacer_2)
+
+ self.saveButton = QToolButton(self.userOptionScrollAreaWidget)
+ self.saveButton.setObjectName(u"saveButton")
+ self.saveButton.setBaseSize(QSize(0, 0))
+ font = QFont()
+ font.setPointSize(14)
+ font.setBold(False)
+ self.saveButton.setFont(font)
+ self.saveButton.setIconSize(QSize(16, 16))
+ self.saveButton.setPopupMode(QToolButton.InstantPopup)
+ self.saveButton.setToolButtonStyle(Qt.ToolButtonTextBesideIcon)
+ self.saveButton.setArrowType(Qt.DownArrow)
+
+ self.scrollAreaLayout.addWidget(self.saveButton)
+
+ self.behaviorLabel = QLabel(self.userOptionScrollAreaWidget)
+ self.behaviorLabel.setObjectName(u"behaviorLabel")
+
+ self.scrollAreaLayout.addWidget(self.behaviorLabel)
+
+ self.behaviorComboBox = QComboBox(self.userOptionScrollAreaWidget)
+ self.behaviorComboBox.setObjectName(u"behaviorComboBox")
+
+ self.scrollAreaLayout.addWidget(self.behaviorComboBox)
+
+ self.channelLabel = QLabel(self.userOptionScrollAreaWidget)
+ self.channelLabel.setObjectName(u"channelLabel")
+ self.channelLabel.setMaximumSize(QSize(16777215, 80))
+
+ self.scrollAreaLayout.addWidget(self.channelLabel)
+
+ self.channelComboBox = QComboBox(self.userOptionScrollAreaWidget)
+ self.channelComboBox.setObjectName(u"channelComboBox")
+
+ self.scrollAreaLayout.addWidget(self.channelComboBox)
+
+ self.alignAtStartButton = QRadioButton(self.userOptionScrollAreaWidget)
+ self.alignAtStartButton.setObjectName(u"alignAtStartButton")
+ self.alignAtStartButton.setChecked(True)
+
+ self.scrollAreaLayout.addWidget(self.alignAtStartButton)
+
+ self.alignAtEndButton = QRadioButton(self.userOptionScrollAreaWidget)
+ self.alignAtEndButton.setObjectName(u"alignAtEndButton")
+
+ self.scrollAreaLayout.addWidget(self.alignAtEndButton)
+
+ self.windowBinLayout = QGridLayout()
+ self.windowBinLayout.setObjectName(u"windowBinLayout")
+ self.windowLabel = QLabel(self.userOptionScrollAreaWidget)
+ self.windowLabel.setObjectName(u"windowLabel")
+
+ self.windowBinLayout.addWidget(self.windowLabel, 0, 0, 1, 1)
+
+ self.secAfterLabel = QLabel(self.userOptionScrollAreaWidget)
+ self.secAfterLabel.setObjectName(u"secAfterLabel")
+ self.secAfterLabel.setMinimumSize(QSize(60, 30))
+
+ self.windowBinLayout.addWidget(self.secAfterLabel, 1, 2, 1, 1)
+
+ self.binSizeLabel = QLabel(self.userOptionScrollAreaWidget)
+ self.binSizeLabel.setObjectName(u"binSizeLabel")
+
+ self.windowBinLayout.addWidget(self.binSizeLabel, 2, 0, 1, 1)
+
+ self.secBeforeLabel = QLabel(self.userOptionScrollAreaWidget)
+ self.secBeforeLabel.setObjectName(u"secBeforeLabel")
+
+ self.windowBinLayout.addWidget(self.secBeforeLabel, 0, 2, 1, 1)
+
+ self.horizontalSpacer = QSpacerItem(40, 20, QSizePolicy.Expanding, QSizePolicy.Minimum)
+
+ self.windowBinLayout.addItem(self.horizontalSpacer, 1, 3, 1, 1)
+
+ self.binSizeBox = QDoubleSpinBox(self.userOptionScrollAreaWidget)
+ self.binSizeBox.setObjectName(u"binSizeBox")
+ self.binSizeBox.setMinimumSize(QSize(60, 30))
+ self.binSizeBox.setDecimals(4)
+ self.binSizeBox.setMinimum(0.000100000000000)
+ self.binSizeBox.setSingleStep(0.000100000000000)
+ self.binSizeBox.setStepType(QAbstractSpinBox.DefaultStepType)
+ self.binSizeBox.setValue(0.033300000000000)
+
+ self.windowBinLayout.addWidget(self.binSizeBox, 2, 1, 1, 1)
+
+ self.windowBox_2 = QDoubleSpinBox(self.userOptionScrollAreaWidget)
+ self.windowBox_2.setObjectName(u"windowBox_2")
+ self.windowBox_2.setMinimumSize(QSize(60, 30))
+ self.windowBox_2.setSingleStep(0.010000000000000)
+ self.windowBox_2.setStepType(QAbstractSpinBox.DefaultStepType)
+ self.windowBox_2.setValue(10.000000000000000)
+
+ self.windowBinLayout.addWidget(self.windowBox_2, 1, 1, 1, 1)
+
+ self.windowBox_1 = QDoubleSpinBox(self.userOptionScrollAreaWidget)
+ self.windowBox_1.setObjectName(u"windowBox_1")
+ self.windowBox_1.setMinimumSize(QSize(60, 30))
+ self.windowBox_1.setDecimals(2)
+ self.windowBox_1.setSingleStep(0.010000000000000)
+ self.windowBox_1.setStepType(QAbstractSpinBox.DefaultStepType)
+ self.windowBox_1.setValue(10.000000000000000)
+
+ self.windowBinLayout.addWidget(self.windowBox_1, 0, 1, 1, 1)
+
+
+ self.scrollAreaLayout.addLayout(self.windowBinLayout)
+
+ self.boutsLayout = QGridLayout()
+ self.boutsLayout.setObjectName(u"boutsLayout")
+ self.secApartLabel = QLabel(self.userOptionScrollAreaWidget)
+ self.secApartLabel.setObjectName(u"secApartLabel")
+
+ self.boutsLayout.addWidget(self.secApartLabel, 1, 2, 1, 1)
+
+ self.secLongLabel = QLabel(self.userOptionScrollAreaWidget)
+ self.secLongLabel.setObjectName(u"secLongLabel")
+
+ self.boutsLayout.addWidget(self.secLongLabel, 0, 2, 1, 1)
+
+ self.discardBoutsLabel = QLabel(self.userOptionScrollAreaWidget)
+ self.discardBoutsLabel.setObjectName(u"discardBoutsLabel")
+
+ self.boutsLayout.addWidget(self.discardBoutsLabel, 0, 0, 1, 1)
+
+ self.mergeBoutsLabel = QLabel(self.userOptionScrollAreaWidget)
+ self.mergeBoutsLabel.setObjectName(u"mergeBoutsLabel")
+ self.mergeBoutsLabel.setMinimumSize(QSize(0, 0))
+
+ self.boutsLayout.addWidget(self.mergeBoutsLabel, 1, 0, 1, 1)
+
+ self.horizontalSpacer_2 = QSpacerItem(40, 20, QSizePolicy.Expanding, QSizePolicy.Minimum)
+
+ self.boutsLayout.addItem(self.horizontalSpacer_2, 0, 3, 1, 1)
+
+ self.discardBoutsBox = QDoubleSpinBox(self.userOptionScrollAreaWidget)
+ self.discardBoutsBox.setObjectName(u"discardBoutsBox")
+ self.discardBoutsBox.setMinimumSize(QSize(60, 30))
+ self.discardBoutsBox.setSingleStep(0.010000000000000)
+ self.discardBoutsBox.setStepType(QAbstractSpinBox.DefaultStepType)
+
+ self.boutsLayout.addWidget(self.discardBoutsBox, 0, 1, 1, 1)
+
+ self.mergeBoutsBox = QDoubleSpinBox(self.userOptionScrollAreaWidget)
+ self.mergeBoutsBox.setObjectName(u"mergeBoutsBox")
+ self.mergeBoutsBox.setMinimumSize(QSize(60, 30))
+ self.mergeBoutsBox.setSingleStep(0.010000000000000)
+ self.mergeBoutsBox.setStepType(QAbstractSpinBox.DefaultStepType)
+ self.mergeBoutsBox.setValue(2.000000000000000)
+
+ self.boutsLayout.addWidget(self.mergeBoutsBox, 1, 1, 1, 1)
+
+
+ self.scrollAreaLayout.addLayout(self.boutsLayout)
+
+ self.analyzeLabel = QLabel(self.userOptionScrollAreaWidget)
+ self.analyzeLabel.setObjectName(u"analyzeLabel")
+
+ self.scrollAreaLayout.addWidget(self.analyzeLabel)
+
+ self.analyzeComboBox = QComboBox(self.userOptionScrollAreaWidget)
+ self.analyzeComboBox.setObjectName(u"analyzeComboBox")
+
+ self.scrollAreaLayout.addWidget(self.analyzeComboBox)
+
+ self.channelLabel_2 = QLabel(self.userOptionScrollAreaWidget)
+ self.channelLabel_2.setObjectName(u"channelLabel_2")
+
+ self.scrollAreaLayout.addWidget(self.channelLabel_2)
+
+ self.channelsList = QListView(self.userOptionScrollAreaWidget)
+ self.channelsList.setObjectName(u"channelsList")
+ self.channelsList.setMaximumSize(QSize(16777215, 80))
+
+ self.scrollAreaLayout.addWidget(self.channelsList)
+
+ self.zscoreLayout = QHBoxLayout()
+ self.zscoreLayout.setObjectName(u"zscoreLayout")
+ self.zscoreLabel = QLabel(self.userOptionScrollAreaWidget)
+ self.zscoreLabel.setObjectName(u"zscoreLabel")
+
+ self.zscoreLayout.addWidget(self.zscoreLabel)
+
+ self.zscoreCheckBox = QCheckBox(self.userOptionScrollAreaWidget)
+ self.zscoreCheckBox.setObjectName(u"zscoreCheckBox")
+
+ self.zscoreLayout.addWidget(self.zscoreCheckBox)
+
+ self.horizontalSpacer_3 = QSpacerItem(40, 20, QSizePolicy.Expanding, QSizePolicy.Minimum)
+
+ self.zscoreLayout.addItem(self.horizontalSpacer_3)
+
+
+ self.scrollAreaLayout.addLayout(self.zscoreLayout)
+
+ self.behaviorSelectionLabel = QLabel(self.userOptionScrollAreaWidget)
+ self.behaviorSelectionLabel.setObjectName(u"behaviorSelectionLabel")
+
+ self.scrollAreaLayout.addWidget(self.behaviorSelectionLabel)
+
+ self.behaviorSelectLayout = QGridLayout()
+ self.behaviorSelectLayout.setObjectName(u"behaviorSelectLayout")
+
+ self.scrollAreaLayout.addLayout(self.behaviorSelectLayout)
+
+ self.verticalSpacer = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding)
+
+ self.scrollAreaLayout.addItem(self.verticalSpacer)
+
+ self.userOptionsScrollArea.setWidget(self.userOptionScrollAreaWidget)
+
+ self.horizontalLayout.addWidget(self.userOptionsScrollArea)
+
+ self.plotScrollArea = QScrollArea(BTAFrame)
+ self.plotScrollArea.setObjectName(u"plotScrollArea")
+ self.plotScrollArea.setWidgetResizable(True)
+ self.plotScrollAreaWidget = QWidget()
+ self.plotScrollAreaWidget.setObjectName(u"plotScrollAreaWidget")
+ self.plotScrollAreaWidget.setGeometry(QRect(0, 0, 605, 834))
+ self.plotScrollAreaLayout = QHBoxLayout(self.plotScrollAreaWidget)
+ self.plotScrollAreaLayout.setObjectName(u"plotScrollAreaLayout")
+ self.plotLayout = QVBoxLayout()
+ self.plotLayout.setObjectName(u"plotLayout")
+
+ self.plotScrollAreaLayout.addLayout(self.plotLayout)
+
+ self.plotScrollArea.setWidget(self.plotScrollAreaWidget)
+
+ self.horizontalLayout.addWidget(self.plotScrollArea)
+
+ self.horizontalLayout.setStretch(0, 1)
+ self.horizontalLayout.setStretch(1, 2)
+
+ self.retranslateUi(BTAFrame)
+
+ QMetaObject.connectSlotsByName(BTAFrame)
+ # setupUi
+
+ def retranslateUi(self, BTAFrame):
+ BTAFrame.setWindowTitle(QCoreApplication.translate("BTAFrame", u"Behavior Triggered Average", None))
+ self.saveButton.setText(QCoreApplication.translate("BTAFrame", u"Save", None))
+ self.behaviorLabel.setText(QCoreApplication.translate("BTAFrame", u"Behavior trigger :", None))
+ self.channelLabel.setText(QCoreApplication.translate("BTAFrame", u"in channel :", None))
+ self.alignAtStartButton.setText(QCoreApplication.translate("BTAFrame", u"Align at start", None))
+ self.alignAtEndButton.setText(QCoreApplication.translate("BTAFrame", u"Align at end", None))
+ self.windowLabel.setText(QCoreApplication.translate("BTAFrame", u"Window :", None))
+ self.secAfterLabel.setText(QCoreApplication.translate("BTAFrame", u"sec after", None))
+ self.binSizeLabel.setText(QCoreApplication.translate("BTAFrame", u"Bin size :", None))
+ self.secBeforeLabel.setText(QCoreApplication.translate("BTAFrame", u"sec before", None))
+ self.secApartLabel.setText(QCoreApplication.translate("BTAFrame", u"sec apart", None))
+ self.secLongLabel.setText(QCoreApplication.translate("BTAFrame", u"sec long", None))
+ self.discardBoutsLabel.setText(QCoreApplication.translate("BTAFrame", u"Discard bouts under", None))
+ self.mergeBoutsLabel.setText(QCoreApplication.translate("BTAFrame", u"Merge bouts under", None))
+ self.analyzeLabel.setText(QCoreApplication.translate("BTAFrame", u"Element to analyze :", None))
+ self.channelLabel_2.setText(QCoreApplication.translate("BTAFrame", u"in channel(s) :", None))
+ self.zscoreLabel.setText(QCoreApplication.translate("BTAFrame", u"z-score traces :", None))
+ self.zscoreCheckBox.setText("")
+ self.behaviorSelectionLabel.setText(QCoreApplication.translate("BTAFrame", u"Behaviors to show :", None))
+ # retranslateUi
+
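[Reviewer note, not part of the diff] For readers unfamiliar with uic-generated classes: Ui_BTAFrame is a plain container that a host widget instantiates and applies to itself. A minimal sketch; the host class and the default value set here are illustrative, not code from this PR.

    from qtpy.QtWidgets import QFrame
    from plugin.behaviorTriggeredAverage_ui import Ui_BTAFrame

    class BTAFrame(QFrame):
        def __init__(self, parent=None):
            super().__init__(parent)
            self.ui = Ui_BTAFrame()
            self.ui.setupUi(self)  # builds the options panel and plot area
            self.ui.windowBox_1.setValue(5.0)  # seconds before the trigger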
diff --git a/src/plugin/kMeansClustersDialog.ui b/src/plugin/kMeansClustersDialog.ui
new file mode 100644
index 0000000..1796a33
--- /dev/null
+++ b/src/plugin/kMeansClustersDialog.ui
@@ -0,0 +1,202 @@
[Qt Designer XML garbled in extraction: the markup tags of this new file were stripped, leaving only property values. The dialog "kMeansClustersDialog" (619x200, titled "Clusters Input For K-Means Clustering") offers two radio-button alternatives: a fixed "Number of Clusters" spin box (default 3), or evaluating goodness of fit over a cluster range (defaults 1 to 11) with k-fold cross validation (default 5 folds), plus an OK/Cancel button box wired to accept()/reject(). The regenerated kMeansClustersDialog_ui.py below shows the equivalent widget setup.]
diff --git a/src/plugin/kMeansClustersDialog_ui.py b/src/plugin/kMeansClustersDialog_ui.py
new file mode 100644
index 0000000..9822f80
--- /dev/null
+++ b/src/plugin/kMeansClustersDialog_ui.py
@@ -0,0 +1,138 @@
+# -*- coding: utf-8 -*-
+
+################################################################################
+## Form generated from reading UI file 'kMeansClustersDialog.ui'
+##
+## Created by: Qt User Interface Compiler version 5.15.1
+##
+## WARNING! All changes made in this file will be lost when recompiling UI file!
+################################################################################
+
+from qtpy.QtCore import *
+from qtpy.QtGui import *
+from qtpy.QtWidgets import *
+
+
+class Ui_kMeansClustersDialog(object):
+ def setupUi(self, kMeansClustersDialog):
+ if not kMeansClustersDialog.objectName():
+ kMeansClustersDialog.setObjectName(u"kMeansClustersDialog")
+ kMeansClustersDialog.resize(619, 200)
+ sizePolicy = QSizePolicy(QSizePolicy.Preferred, QSizePolicy.Preferred)
+ sizePolicy.setHorizontalStretch(0)
+ sizePolicy.setVerticalStretch(0)
+ sizePolicy.setHeightForWidth(kMeansClustersDialog.sizePolicy().hasHeightForWidth())
+ kMeansClustersDialog.setSizePolicy(sizePolicy)
+ kMeansClustersDialog.setMinimumSize(QSize(0, 0))
+ kMeansClustersDialog.setMaximumSize(QSize(16777215, 16777215))
+ self.verticalLayout_2 = QVBoxLayout(kMeansClustersDialog)
+ self.verticalLayout_2.setObjectName(u"verticalLayout_2")
+ self.verticalLayout = QVBoxLayout()
+ self.verticalLayout.setObjectName(u"verticalLayout")
+ self.numOfClustersLayout = QHBoxLayout()
+ self.numOfClustersLayout.setObjectName(u"numOfClustersLayout")
+ self.numOfClusterRadioButton = QRadioButton(kMeansClustersDialog)
+ self.numOfClusterRadioButton.setObjectName(u"numOfClusterRadioButton")
+
+ self.numOfClustersLayout.addWidget(self.numOfClusterRadioButton)
+
+ self.numOfClustersLabel = QLabel(kMeansClustersDialog)
+ self.numOfClustersLabel.setObjectName(u"numOfClustersLabel")
+ self.numOfClustersLabel.setMinimumSize(QSize(0, 25))
+
+ self.numOfClustersLayout.addWidget(self.numOfClustersLabel)
+
+ self.numOfClustersBox = QSpinBox(kMeansClustersDialog)
+ self.numOfClustersBox.setObjectName(u"numOfClustersBox")
+ self.numOfClustersBox.setMinimum(1)
+ self.numOfClustersBox.setValue(3)
+
+ self.numOfClustersLayout.addWidget(self.numOfClustersBox)
+
+ self.horizontalSpacer = QSpacerItem(40, 20, QSizePolicy.Expanding, QSizePolicy.Minimum)
+
+ self.numOfClustersLayout.addItem(self.horizontalSpacer)
+
+
+ self.verticalLayout.addLayout(self.numOfClustersLayout)
+
+ self.gofLayout = QHBoxLayout()
+ self.gofLayout.setObjectName(u"gofLayout")
+ self.gofRadioButton = QRadioButton(kMeansClustersDialog)
+ self.gofRadioButton.setObjectName(u"gofRadioButton")
+ self.gofRadioButton.setChecked(True)
+
+ self.gofLayout.addWidget(self.gofRadioButton)
+
+ self.gofLabel1 = QLabel(kMeansClustersDialog)
+ self.gofLabel1.setObjectName(u"gofLabel1")
+
+ self.gofLayout.addWidget(self.gofLabel1)
+
+ self.clusterRangeBox1 = QSpinBox(kMeansClustersDialog)
+ self.clusterRangeBox1.setObjectName(u"clusterRangeBox1")
+ self.clusterRangeBox1.setMinimum(1)
+
+ self.gofLayout.addWidget(self.clusterRangeBox1)
+
+ self.gofLabel2 = QLabel(kMeansClustersDialog)
+ self.gofLabel2.setObjectName(u"gofLabel2")
+
+ self.gofLayout.addWidget(self.gofLabel2)
+
+ self.clusterRangeBox2 = QSpinBox(kMeansClustersDialog)
+ self.clusterRangeBox2.setObjectName(u"clusterRangeBox2")
+ self.clusterRangeBox2.setMinimum(1)
+ self.clusterRangeBox2.setValue(11)
+
+ self.gofLayout.addWidget(self.clusterRangeBox2)
+
+ self.gofLabel3 = QLabel(kMeansClustersDialog)
+ self.gofLabel3.setObjectName(u"gofLabel3")
+
+ self.gofLayout.addWidget(self.gofLabel3)
+
+ self.crossValidationFoldsBox = QSpinBox(kMeansClustersDialog)
+ self.crossValidationFoldsBox.setObjectName(u"crossValidationFoldsBox")
+ self.crossValidationFoldsBox.setMinimum(2)
+ self.crossValidationFoldsBox.setValue(5)
+
+ self.gofLayout.addWidget(self.crossValidationFoldsBox)
+
+ self.gofLabel4 = QLabel(kMeansClustersDialog)
+ self.gofLabel4.setObjectName(u"gofLabel4")
+
+ self.gofLayout.addWidget(self.gofLabel4)
+
+
+ self.verticalLayout.addLayout(self.gofLayout)
+
+
+ self.verticalLayout_2.addLayout(self.verticalLayout)
+
+ self.buttonBox = QDialogButtonBox(kMeansClustersDialog)
+ self.buttonBox.setObjectName(u"buttonBox")
+ self.buttonBox.setOrientation(Qt.Horizontal)
+ self.buttonBox.setStandardButtons(QDialogButtonBox.Cancel|QDialogButtonBox.Ok)
+
+ self.verticalLayout_2.addWidget(self.buttonBox)
+
+
+ self.retranslateUi(kMeansClustersDialog)
+ self.buttonBox.accepted.connect(kMeansClustersDialog.accept)
+ self.buttonBox.rejected.connect(kMeansClustersDialog.reject)
+
+ QMetaObject.connectSlotsByName(kMeansClustersDialog)
+ # setupUi
+
+ def retranslateUi(self, kMeansClustersDialog):
+ kMeansClustersDialog.setWindowTitle(QCoreApplication.translate("kMeansClustersDialog", u"Clusters Input For K-Means Clustering", None))
+ self.numOfClusterRadioButton.setText("")
+ self.numOfClustersLabel.setText(QCoreApplication.translate("kMeansClustersDialog", u"Number of Clusters : ", None))
+ self.gofRadioButton.setText("")
+ self.gofLabel1.setText(QCoreApplication.translate("kMeansClustersDialog", u"Evaluate goodness of fit from", None))
+ self.gofLabel2.setText(QCoreApplication.translate("kMeansClustersDialog", u"to", None))
+ self.gofLabel3.setText(QCoreApplication.translate("kMeansClustersDialog", u"clusters using ", None))
+ self.gofLabel4.setText(QCoreApplication.translate("kMeansClustersDialog", u"folds cross validation", None))
+ # retranslateUi
+
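[Reviewer note, not part of the diff] A sketch of typical modal use of the generated dialog, assuming a running QApplication; the wrapper class is illustrative, not code from this PR.

    from qtpy.QtWidgets import QDialog
    from plugin.kMeansClustersDialog_ui import Ui_kMeansClustersDialog

    class KMeansClustersDialog(QDialog):
        def __init__(self, parent=None):
            super().__init__(parent)
            self.ui = Ui_kMeansClustersDialog()
            self.ui.setupUi(self)

    dlg = KMeansClustersDialog()
    if dlg.exec_() == QDialog.Accepted:
        if dlg.ui.numOfClusterRadioButton.isChecked():
            k = dlg.ui.numOfClustersBox.value()  # fixed cluster count
        else:  # sweep cluster counts and cross-validate
            k_range = range(dlg.ui.clusterRangeBox1.value(),
                            dlg.ui.clusterRangeBox2.value() + 1)
            folds = dlg.ui.crossValidationFoldsBox.value()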
diff --git a/src/plugin/plugin_pose_DeepLabCut.py b/src/plugin/plugin_pose_DeepLabCut.py
index 42b0e8d..732084c 100644
--- a/src/plugin/plugin_pose_DeepLabCut.py
+++ b/src/plugin/plugin_pose_DeepLabCut.py
@@ -4,10 +4,15 @@
from qtpy.QtGui import QColor, QPainter, QPen, QPolygonF
from qtpy.QtWidgets import QMessageBox, QWidget
from pose.pose import PoseBase
+from collections import OrderedDict
from utils import get_colormap
import csv
+import os
import numpy as np
+import pandas as pd
import h5py
+from pynwb import NWBFile
+from ndx_pose import PoseEstimationSeries, PoseEstimation
from os.path import splitext
@@ -32,24 +37,26 @@ def getFileSearchPattern(self) -> str:
return "*.h5 *.csv"
def _validateFileH5(self, parent_widget, file_path: str) -> bool:
- h5 = h5py.File(file_path, 'r')
- default_group = h5[list(h5.keys())[0]]
- if 'table' not in default_group.keys():
+ df = pd.read_hdf(file_path)
+ inds = list(df.columns.names)
+ if ('scorer' in inds and
+ 'bodyparts' in inds and
+ 'coords' in inds):
+ return True
+ else:
QMessageBox.warning(parent_widget, "Add Pose ...", "No pose data found in pose file")
return False
- return True
def _validateFileCSV(self, parent_widget, file_path: str) -> bool:
- result = True
- with open(file_path, 'r') as csvFile:
- reader = csv.reader(csvFile)
- header1 = next(reader)
- if (not isinstance(header1[1], str) or
- header1[1] != 'bodyparts' or
- (len(header1) - 2) % 3 != 0): # "", "bodyparts", 3 x , ...
- result = False
- csvFile.seek(0)
- return result
+ df = pd.read_csv(file_path, header=[0,1,2], index_col=0)
+ inds = list(df.columns.names)
+ if ('scorer' in inds and
+ 'bodyparts' in inds and
+ 'coords' in inds):
+ return True
+ else:
+ QMessageBox.warning(parent_widget, "Add Pose ...", "No pose data found in pose file")
+ return False
def validateFile(self, parent_widget: QWidget, file_path: str) -> bool:
"""
@@ -66,19 +73,19 @@ def validateFile(self, parent_widget: QWidget, file_path: str) -> bool:
QMessageBox.warning(parent_widget, "Extension not supported",
f"The file extension {ext} is not supported.")
- def _loadPoses_h5(self, parent_widget: QWidget, path: str):
+ def _loadPoses_h5(self, parent_widget: QWidget, path: str, video_path: str):
raise NotImplementedError("Please implement this in your derived class")
- def _loadPoses_csv(self, parent_widget: QWidget, path: str):
+ def _loadPoses_csv(self, parent_widget: QWidget, path: str, video_path: str):
raise NotImplementedError("Please implement this in your derived class")
- def loadPoses(self, parent_widget: QWidget, path: str):
+ def loadPoses(self, parent_widget: QWidget, path: str, video_path: str):
_, self.file_extension = splitext(path)
self.file_extension = self.file_extension.lower()
if self.file_extension == '.h5':
- self._loadPoses_h5(parent_widget, path)
+ self._loadPoses_h5(parent_widget, path, video_path)
elif self.file_extension == '.csv':
- self._loadPoses_csv(parent_widget, path)
+ self._loadPoses_csv(parent_widget, path, video_path)
else:
QMessageBox.warning(parent_widget, "Extension not supported",
f"The file extension {self.file_extension} is not supported.")
@@ -89,6 +96,9 @@ def __init__(self):
super().__init__()
self.file_extension = None
self.frame_points = []
+ self.pose_data = np.array([])
+ self.body_parts = np.array([])
+ self.video_path = None
self.num_frames = 0
# Construct colors based on the number of body parts and the color map
self.colormap_data = get_colormap('turbo')
@@ -118,18 +128,18 @@ def getFileSearchDescription(self) -> str:
def getFileFormat(self) -> str:
return "DeepLabCut_generic"
- def _loadPoses_h5(self, parent_widget, path: str):
+ def _loadPoses_h5(self, parent_widget, path: str, video_path: str):
# TODO: need to change this from the mouse form to a more generic
# version that just stores the points.
- h5 = h5py.File(path, 'r')
- default_group = h5[list(h5.keys())[0]]
- table = default_group['table']
- self.num_frames = len(table)
+ df = pd.read_hdf(path)
+ self.body_parts = np.array(df.columns.get_level_values(1))
+ self.pose_data = np.array(df)
+ self.video_path = video_path
+ self.num_frames = self.pose_data.shape[0]
isColorsGenerated = False
for frame_ix in range(self.num_frames):
- # DLC stores each frame as (frame_idx, (pose_array))
- # we don't care about the frame_idx
- this_frame_data = table[frame_ix][1]
+ # DLC stores each frame as row of a table
+ this_frame_data = self.pose_data[frame_ix]
this_frame_points = []
vals_per_pt = 3
num_points = int(len(this_frame_data) / vals_per_pt)
@@ -143,33 +153,77 @@ def _loadPoses_h5(self, parent_widget, path: str):
this_frame_points.append(QPointF(this_frame_data[pt_x_ix], this_frame_data[pt_y_ix]))
self.frame_points.append(this_frame_points)
- def _loadPoses_csv(self, parent_widget, path: str):
- with open(path, 'r') as csvFile:
- vals_per_pt = 3 # x, y, confidence. We don't care about confidence.
- reader = csv.reader(csvFile)
- header1 = next(reader)
- num_pts = int((len(header1) - 2) / vals_per_pt)
- self.generatePoseColors(num_pts)
- _ = next(reader) # skip over second row of header
- while True:
- try:
- row = next(reader)
- except StopIteration:
- break
- this_frame_points = []
- for point_ix in range(num_pts):
- pt_x_ix = (vals_per_pt * point_ix) + 2
- pt_y_ix = (vals_per_pt * point_ix) + 3
- this_frame_points.append(QPointF(float(row[pt_x_ix]), float(row[pt_y_ix])))
- self.frame_points.append(this_frame_points)
- self.num_frames = len(self.frame_points)
-
+ def _loadPoses_csv(self, parent_widget, path: str, video_path: str):
+ df = pd.read_csv(path, header=[0,1,2], index_col=0)
+ self.body_parts = np.array(df.columns.get_level_values(1))
+ self.pose_data = np.array(df)
+ self.video_path = video_path
+ self.num_frames = self.pose_data.shape[0]
+ for frame_ix in range(self.num_frames):
+ # DLC stores each frame as row of a table
+ this_frame_data = self.pose_data[frame_ix]
+ this_frame_points = []
+ vals_per_pt = 3
+ num_points = int(len(this_frame_data) / vals_per_pt)
+ self.generatePoseColors(num_points)
+ for point_ix in range(num_points):
+ # each pose point has x, y and confidence
+ # we don't care about confidence
+ pt_x_ix = (vals_per_pt * point_ix) + 0
+ pt_y_ix = (vals_per_pt * point_ix) + 1
+ this_frame_points.append(QPointF(this_frame_data[pt_x_ix], this_frame_data[pt_y_ix]))
+ self.frame_points.append(this_frame_points)
+
+ def exportPosesToNWBFile(self, nwbFile: NWBFile):
+ processing_module_name = f"Pose data for video {os.path.basename(self.video_path)}"
+ # Reshape pose data to [frames, body_parts, coords]; the 3 coords are x, y, confidence.
+ pose_data = self.pose_data.reshape(self.num_frames, -1, 3)
+ # self.body_parts lists each name once per coordinate column; an
+ # OrderedDict deduplicates the names while preserving their order.
+ order_dict = OrderedDict()
+ for node in self.body_parts:
+ order_dict[node] = 1
+ body_parts = np.array(list(order_dict.keys()))
+
+
+ pose_estimation_series = []
+ for nodes_ix in range(body_parts.shape[0]):
+ pose_estimation_series.append(
+ PoseEstimationSeries(
+ name = f"{body_parts[nodes_ix]}",
+ description = f"Pose keypoint placed around {body_parts[nodes_ix]}",
+ data = pose_data[:,nodes_ix,:2],
+ reference_frame = "The coordinates are in (x, y) relative to the top-left of the image",
+ timestamps = np.arange(self.num_frames, dtype=float),
+ confidence = pose_data[:,nodes_ix,2]
+ )
+ )
+ pose_estimation = PoseEstimation(
+ pose_estimation_series = pose_estimation_series,
+ name = "animal_0",
+ description = f"Estimated position for animal_0 in video {os.path.basename(self.video_path)}",
+ nodes = body_parts,
+ )
+ if processing_module_name in nwbFile.processing:
+ nwbFile.processing[processing_module_name].add(pose_estimation)
+ else:
+ pose_pm = nwbFile.create_processing_module(
+ name = processing_module_name,
+ description = f"Pose Data from {self.getFileFormat().split('_')[0]}"
+ )
+ pose_pm.add(pose_estimation)
+
+ return nwbFile
+
class PoseDLC_mouse(PoseDLCBase):
def __init__(self):
super().__init__()
self.pose_polys = []
+ self.body_parts = np.array([])
+ self.pose_data = np.array([])
self.num_frames = 0
+ self.num_mice = 0
+ self.video_path = None
self.pose_colors = [Qt.blue, Qt.green]
def drawPoses(self, painter: QPainter, frame_ix: int):
@@ -191,22 +245,27 @@ def getFileSearchDescription(self) -> str:
def getFileFormat(self) -> str:
return "DeepLabCut_mouse"
- def _loadPoses_h5(self, parent_widget: QWidget, path: str):
- h5 = h5py.File(path, 'r')
- default_group = h5[list(h5.keys())[0]]
- table = default_group['table']
+ def _loadPoses_h5(self, parent_widget: QWidget, path: str, video_path: str):
+ df = pd.read_hdf(path)
+ self.body_parts = np.array(df.columns.get_level_values(1))
+ self.pose_data = np.array(df)
self.pose_polys = []
- self.num_frames = len(table)
+ self.video_path = video_path
+ self.num_frames = self.pose_data.shape[0]
for frame_ix in range(self.num_frames):
- # DLC stores each frame as (frame_idx, (pose_array))
- # we don't care about the frame_idx
- this_frame_data = table[frame_ix][1]
+ # DLC stores each frame as row of a table
+ this_frame_data = self.pose_data[frame_ix]
frame_polys = []
vals_per_pt = 3
pts_per_mouse = 7
- vals_per_mouse = vals_per_pt * pts_per_mouse
- num_mice = int(len(this_frame_data) / vals_per_mouse)
- for mouse_ix in range(num_mice):
+ vals_per_mouse = vals_per_pt * pts_per_mouse  # 21 values per mouse
+ self.num_mice = int(len(this_frame_data) / vals_per_mouse)
+ num_pts = int(self.pose_data.shape[1]/vals_per_pt)
+ if num_pts % pts_per_mouse != 0:
+ QMessageBox.warning(parent_widget, "Improper number of points for mice",
+ f"Expected a multiple of 7 points, got {num_pts}")
+ break
+ for mouse_ix in range(self.num_mice):
# each pose point has x, y and confidence
# we don't care about confidence
poly = QPolygonF()
@@ -239,64 +298,107 @@ def _loadPoses_h5(self, parent_widget: QWidget, path: str):
frame_polys.append(poly)
self.pose_polys.append(frame_polys)
- def _loadPoses_csv(self, parent_widget: QWidget, path: str):
- with open(path, 'r') as csvFile:
- vals_per_pt = 3 # x, y, confidence. We don't care about confidence.
- vals_per_mouse = 7
- reader = csv.reader(csvFile)
- header1 = next(reader)
- num_pts = int((len(header1) - 2) / vals_per_pt)
- if num_pts % vals_per_mouse != 0:
+ def _loadPoses_csv(self, parent_widget: QWidget, path: str, video_path: str):
+ # reading multi-index csv file with first three rows
+ # as header and first column as row index
+ df = pd.read_csv(path, header=[0,1,2], index_col=0)
+ self.body_parts = np.array(df.columns.get_level_values(1))
+ self.pose_data = np.array(df)
+ self.pose_polys = []
+ self.video_path = video_path
+ self.num_frames = self.pose_data.shape[0]
+ for frame_ix in range(self.num_frames):
+ # DLC stores each frame as row of a table
+ this_frame_data = self.pose_data[frame_ix]
+ frame_polys = []
+ vals_per_pt = 3
+ pts_per_mouse = 7
+ vals_per_mouse = vals_per_pt * pts_per_mouse
+ self.num_mice = int(len(this_frame_data) / vals_per_mouse)
+ num_pts = int(self.pose_data.shape[1]/vals_per_pt)
+ if num_pts % pts_per_mouse != 0:
QMessageBox.warning(parent_widget, "Improper number of points for mice",
f"Expected a multiple of 7 points, got {num_pts}")
- return
- _ = next(reader) # skip over second row of header
- while True:
- try:
- row = next(reader)
- except StopIteration:
- break
- num_mice = int(num_pts / vals_per_mouse)
- frame_polys = []
- for mouse_ix in range(num_mice):
- # each pose point has x, y and confidence
- # we don't care about confidence
- poly = QPolygonF()
- pt_x_ix = (vals_per_mouse * mouse_ix) + 2
- pt_y_ix = (vals_per_mouse * mouse_ix) + 3
- nose = QPointF(row[pt_x_ix], row[pt_y_ix])
- poly.append(nose)
- poly.append(QPointF(
- row[pt_x_ix + (1 * vals_per_pt)],
- row[pt_y_ix + (1 * vals_per_pt)]))
- poly.append(QPointF(
- row[pt_x_ix + (3 * vals_per_pt)],
- row[pt_y_ix + (3 * vals_per_pt)]))
- poly.append(QPointF(
- row[pt_x_ix + (4 * vals_per_pt)],
- row[pt_y_ix + (4 * vals_per_pt)]))
- poly.append(QPointF(
- row[pt_x_ix + (6 * vals_per_pt)],
- row[pt_y_ix + (6 * vals_per_pt)]))
- poly.append(QPointF(
- row[pt_x_ix + (5 * vals_per_pt)],
- row[pt_y_ix + (5 * vals_per_pt)]))
- poly.append(QPointF(
- row[pt_x_ix + (3 * vals_per_pt)],
- row[pt_y_ix + (3 * vals_per_pt)]))
- poly.append(QPointF(
- row[pt_x_ix + (2 * vals_per_pt)],
- row[pt_y_ix + (2 * vals_per_pt)]))
- poly.append(nose)
- frame_polys.append(poly)
- self.pose_polys.append(frame_polys)
- self.num_frames = len(self.frame_points)
+ break
+ for mouse_ix in range(self.num_mice):
+ # each pose point has x, y and confidence
+ # we don't care about confidence
+ poly = QPolygonF()
+ pt_x_ix = (vals_per_mouse * mouse_ix) + 0
+ pt_y_ix = (vals_per_mouse * mouse_ix) + 1
+ nose = QPointF(this_frame_data[pt_x_ix], this_frame_data[pt_y_ix])
+ poly.append(nose)
+ poly.append(QPointF(
+ this_frame_data[pt_x_ix + (1 * vals_per_pt)],
+ this_frame_data[pt_y_ix + (1 * vals_per_pt)]))
+ poly.append(QPointF(
+ this_frame_data[pt_x_ix + (3 * vals_per_pt)],
+ this_frame_data[pt_y_ix + (3 * vals_per_pt)]))
+ poly.append(QPointF(
+ this_frame_data[pt_x_ix + (4 * vals_per_pt)],
+ this_frame_data[pt_y_ix + (4 * vals_per_pt)]))
+ poly.append(QPointF(
+ this_frame_data[pt_x_ix + (6 * vals_per_pt)],
+ this_frame_data[pt_y_ix + (6 * vals_per_pt)]))
+ poly.append(QPointF(
+ this_frame_data[pt_x_ix + (5 * vals_per_pt)],
+ this_frame_data[pt_y_ix + (5 * vals_per_pt)]))
+ poly.append(QPointF(
+ this_frame_data[pt_x_ix + (3 * vals_per_pt)],
+ this_frame_data[pt_y_ix + (3 * vals_per_pt)]))
+ poly.append(QPointF(
+ this_frame_data[pt_x_ix + (2 * vals_per_pt)],
+ this_frame_data[pt_y_ix + (2 * vals_per_pt)]))
+ poly.append(nose)
+ frame_polys.append(poly)
+ self.pose_polys.append(frame_polys)
+
+ def exportPosesToNWBFile(self, nwbFile: NWBFile):
+ processing_module_name = f"Pose data for video {os.path.basename(self.video_path)}"
+ # Reshape pose data to [frames, num_mice, body_parts, coords]; the 3 coords are x, y, confidence.
+ pose_data = self.pose_data.reshape(self.num_frames, self.num_mice, -1, 3)
+ # Deduplicate body part names (each appears once per coordinate column)
+ # while preserving order, then split into one row of names per animal.
+ order_dict = OrderedDict()
+ for node in self.body_parts:
+ order_dict[node] = 1
+ body_parts = np.array(list(order_dict.keys())).reshape(-1, pose_data.shape[2])
+
+ for mouse_ix in range(self.num_mice):
+ pose_estimation_series = []
+ for nodes_ix in range(body_parts.shape[1]):
+ pose_estimation_series.append(
+ PoseEstimationSeries(
+ name = f"{body_parts[mouse_ix, nodes_ix]}",
+ description = f"Pose keypoint placed around {body_parts[mouse_ix, nodes_ix]}",
+ data = pose_data[:,mouse_ix,nodes_ix,:2],
+ reference_frame = "The coordinates are in (x, y) relative to the top-left of the image",
+ timestamps = np.arange(self.num_frames, dtype=float),
+ confidence = pose_data[:,mouse_ix,nodes_ix,2]
+ )
+ )
+ pose_estimation = PoseEstimation(
+ pose_estimation_series = pose_estimation_series,
+ name = f"animal_{mouse_ix}",
+ description = f"Estimated position for animal_{mouse_ix} in video {os.path.basename(self.video_path)}",
+ nodes = body_parts[mouse_ix,:],
+ edges = np.array([[0,1], [1,3], [3,4], [4,6], [6,5], [5,3], [3,2], [2,0]], dtype='uint8')
+ )
+ if processing_module_name in nwbFile.processing:
+ nwbFile.processing[processing_module_name].add(pose_estimation)
+ else:
+ pose_pm = nwbFile.create_processing_module(
+ name = processing_module_name,
+ description = f"Pose Data from {self.getFileFormat().split('_')[0]}"
+ )
+ pose_pm.add(pose_estimation)
+
+ return nwbFile
def register(registry):
# construct and register the generic plugin
- pose_plugin = PoseDLC_generic()
- registry.register(pose_plugin.getFileFormat(), pose_plugin)
+ pose_plugin_generic = PoseDLC_generic()
+ registry.register(pose_plugin_generic.getFileFormat(), pose_plugin_generic)
# construct and register the MARS-style mouse-specific plugin
- pose_plugin = PoseDLC_mouse()
- registry.register(pose_plugin.getFileFormat(), pose_plugin)
\ No newline at end of file
+ pose_plugin_mouse = PoseDLC_mouse()
+ registry.register(pose_plugin_mouse.getFileFormat(), pose_plugin_mouse)
\ No newline at end of file
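[Reviewer note, not part of the diff] A sketch of how the new exportPosesToNWBFile() would land on disk with pynwb; the session metadata and file paths are illustrative only.

    from datetime import datetime
    from dateutil import tz
    from pynwb import NWBFile, NWBHDF5IO
    from plugin.plugin_pose_DeepLabCut import PoseDLC_mouse

    nwbFile = NWBFile(session_description="Bento export",
                      identifier="trial_0",
                      session_start_time=datetime.now(tz.tzlocal()))
    plugin = PoseDLC_mouse()
    plugin.loadPoses(None, "poses.csv", "video.avi")  # paths illustrative
    with NWBHDF5IO("trial_0.nwb", "w") as io:
        io.write(plugin.exportPosesToNWBFile(nwbFile))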
diff --git a/src/plugin/plugin_pose_MARS.py b/src/plugin/plugin_pose_MARS.py
index 8bf57a3..b8240a2 100644
--- a/src/plugin/plugin_pose_MARS.py
+++ b/src/plugin/plugin_pose_MARS.py
@@ -4,6 +4,13 @@
from qtpy.QtGui import QPainter, QPen, QPolygonF
from qtpy.QtWidgets import QMessageBox, QWidget
from pose.pose import PoseBase
+import h5py as h5
+import os
+import json
+from os.path import splitext
+import numpy as np
+from pynwb import NWBFile
+from ndx_pose import PoseEstimationSeries, PoseEstimation
import pymatreader as pmr
import warnings
@@ -20,7 +27,11 @@ def __init__(self):
super().__init__()
self.pose_polys = []
self.num_frames = 0
+ self.num_mice = 0
self.pose_colors = [Qt.blue, Qt.green]
+ self.keypoints = None
+ self.confidence = None
+ self.video_path = None
def drawPoses(self, painter: QPainter, frame_ix: int):
frame_ix = min(frame_ix, self.num_frames)
@@ -39,42 +50,73 @@ def getFileSearchDescription(self) -> str:
return "MARS pose files"
def getFileSearchPattern(self) -> str:
- return "*.mat"
+ return "*.mat *.json"
def getFileFormat(self) -> str:
return "MARS"
- def validateFile(self, parent_widget: QWidget, file_path: str) -> bool:
- """
- Default implementation does no checking,
- but we can do better than that
- """
+ def _validateFileJSON(self, parent_widget: QWidget, file_path: str) -> bool:
+ with open(file_path, 'r') as f:
+ keys = list(json.load(f).keys())
+ if 'keypoints' not in keys:
+ QMessageBox.warning(parent_widget, "Add Pose ...", "No keypoints found in pose file")
+ return False
+ return True
+
+ def _validateFileMAT(self, parent_widget: QWidget, file_path: str) -> bool:
with warnings.catch_warnings():
# suppress warning coming from checking the mat file contents
warnings.simplefilter('ignore', category=UserWarning)
- poseMat = pmr.read_mat(file_path)
- if 'keypoints' not in poseMat.keys():
+ keys = list(pmr.read_mat(file_path).keys())
+ if 'keypoints' not in keys:
QMessageBox.warning(parent_widget, "Add Pose ...", "No keypoints found in pose file")
return False
return True
+
+ def validateFile(self, parent_widget: QWidget, file_path: str) -> bool:
+ """
+ Default implementation does no checking,
+ but we can do better than that
+ """
+ _, ext = splitext(file_path)
+ ext = ext.lower()
+ if ext == '.mat':
+ return self._validateFileMAT(parent_widget, file_path)
+ elif ext == '.json':
+ return self._validateFileJSON(parent_widget, file_path)
+ else:
+ QMessageBox.warning(parent_widget, "Extension not supported",
+ f"The file extension {ext} is not supported.")
+ return False
- def loadPoses(self, parent_widget: QWidget, path: str):
- mat = None
- with warnings.catch_warnings():
- # suppress warning coming from checking the mat file contents
- warnings.simplefilter('ignore', category=UserWarning)
- mat = pmr.read_mat(path)
+ def loadPoses(self, parent_widget: QWidget, path: str, video_path: str):
+ data = None
+ _, ext = splitext(path)
+ ext = ext.lower()
+ if ext == '.mat':
+ with warnings.catch_warnings():
+ # suppress warning coming from checking the mat file contents
+ warnings.simplefilter('ignore', category=UserWarning)
+ data = pmr.read_mat(path)
+ elif ext == '.json':
+ with open(path, 'r') as f:
+ data = json.load(f)
+ else:
+ QMessageBox.warning(parent_widget, "Extension not supported",
+ f"The file extension {ext} is not supported.")
+ return None
try:
- keypoints = mat['keypoints']
+ self.keypoints = np.array(data['keypoints'])
+ self.confidence = np.array(data['scores'])
except Exception as e:
QMessageBox.about(parent_widget, "Load Error", f"Error loading pose data from file {path}: {e}")
return None
self.pose_polys = []
- self.num_frames = len(keypoints)
+ self.video_path = video_path
+ self.num_frames = len(self.keypoints)
for frame_ix in range(self.num_frames):
- frame_keypoints = keypoints[frame_ix]
+ frame_keypoints = self.keypoints[frame_ix]
frame_polys = []
- for mouse_ix in range(len(frame_keypoints)):
+ self.num_mice = len(frame_keypoints)
+ for mouse_ix in range(self.num_mice):
mouse_keypoints = frame_keypoints[mouse_ix]
pose_x = mouse_keypoints[0]
pose_y = mouse_keypoints[1]
@@ -92,6 +134,101 @@ def loadPoses(self, parent_widget: QWidget, path: str):
frame_polys.append(poly)
self.pose_polys.append(frame_polys)
+ def exportPosesToNWBFile(self, nwbFile: NWBFile):
+
+ processing_module_name = f"Pose data for video {os.path.basename(self.video_path)}"
+ # The seven MARS keypoints, in the order they appear in the pose array.
+ node_names = ['nose', 'left ear', 'right ear', 'neck', 'left hip', 'right hip', 'tail']
+ for mouse_ix in range(self.num_mice):
+     # Build one PoseEstimationSeries per keypoint instead of repeating
+     # the same construction seven times.
+     pose_estimation_series = []
+     for node_ix, node_name in enumerate(node_names):
+         pose_estimation_series.append(
+             PoseEstimationSeries(
+                 name = node_name,
+                 description = f"Pose keypoint placed around {node_name}",
+                 data = self.keypoints[:,mouse_ix,:,node_ix],
+                 reference_frame = "The coordinates are in (x, y) relative to the top-left of the image",
+                 timestamps = np.arange(self.num_frames, dtype=float),
+                 confidence = self.confidence[:,mouse_ix,node_ix]
+             )
+         )
+     pose_estimation = PoseEstimation(
+         pose_estimation_series = pose_estimation_series,
+         name = f"animal_{mouse_ix}",
+         description = f"Estimated position for animal_{mouse_ix} in video {os.path.basename(self.video_path)}",
+         nodes = node_names,
+         edges = np.array([[0,1], [1,3], [3,4], [4,6], [6,5], [5,3], [3,2], [2,0]], dtype='uint8')
+     )
+
+ if processing_module_name in nwbFile.processing:
+ nwbFile.processing[processing_module_name].add(pose_estimation)
+ else:
+ pose_pm = nwbFile.create_processing_module(
+ name = processing_module_name,
+ description = f"Pose Data from {self.getFileFormat().split('_')[0]}"
+ )
+ pose_pm.add(pose_estimation)
+
+ return nwbFile
+
+
def register(registry):
mars_pose_plugin = PoseMARS()
registry.register(mars_pose_plugin.getFileFormat(), mars_pose_plugin)
\ No newline at end of file
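[Reviewer note, not part of the diff] Reading the exported keypoints back, to show how the module and series names chosen above are addressed; the file name is illustrative and the container access follows the usual pynwb/ndx_pose pattern, so treat this as a sketch.

    import ndx_pose  # registers the pose extension with pynwb
    from pynwb import NWBHDF5IO

    with NWBHDF5IO("trial_0.nwb", "r") as io:
        nwb = io.read()
        pm = nwb.processing["Pose data for video video.avi"]
        nose = pm["animal_0"].pose_estimation_series["nose"]
        print(nose.data[:5], nose.confidence[:5])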
diff --git a/src/plugin/plugin_pose_SLEAP.py b/src/plugin/plugin_pose_SLEAP.py
new file mode 100644
index 0000000..5056498
--- /dev/null
+++ b/src/plugin/plugin_pose_SLEAP.py
@@ -0,0 +1,301 @@
+from typing import List, Tuple
+from qtpy.QtCore import QPointF, QLineF
+from qtpy.QtGui import QColor, QPolygonF, QPainter, QPen
+from qtpy.QtWidgets import QMessageBox, QWidget
+from pose.pose import PoseBase, PoseRegistry
+from os.path import splitext
+from pynwb import NWBFile
+from ndx_pose import PoseEstimationSeries, PoseEstimation
+import os
+import numpy as np
+import h5py
+
+
+class PoseSLEAP(PoseBase):
+ """
+ SLEAP plugin that provides support for SLEAP-style pose files
+ represented as HDF5 .h5 files.
+
+ Implements a class derived from PoseBase, which is an abstract class.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.pose_colors: List[QColor] = [
+ QColor(0, 144, 189), # Blue
+ QColor(217, 83, 25), # Orange
+ QColor(237, 177, 32), # Yellow
+ QColor(126, 47, 142), # Purple
+ QColor(119, 172, 48), # Light green
+ QColor(77, 190, 238), # Light blue
+ QColor(162, 20, 47), # Red
+ ] # Standard SLEAP colors
+ self.pose_polys: List = []
+ self.num_frames: int = 0
+ self.num_nodes: int = 0
+ self.num_instances: int = 0
+ self.has_edges: bool = False
+ self.has_scores: bool = False
+ self.pose_data: np.ndarray = np.array([])
+ self.confidence: np.ndarray = np.array([])
+ self.edge_inds: np.ndarray = np.array([])
+ self.node_names: List = []
+
+ def _drawPoses_noEdges(self, painter: QPainter, frame_ix: int):
+ """
+ Draw poses for SLEAP analysis files that contain no edge data.
+ """
+ for instance_ix in range(len(self.pose_polys[frame_ix])):
+ instance_color = self.pose_colors[instance_ix % len(self.pose_colors)]
+ instance_poly = self.pose_polys[frame_ix][instance_ix]
+ if not instance_poly.isEmpty():
+ painter.setPen(QPen(instance_color, 1.5))
+ painter.drawPolyline(instance_poly)
+ painter.setBrush(instance_color)
+ for node in instance_poly:
+ painter.drawEllipse(node, 5.0, 5.0)
+
+ def _drawPoses_hasEdges(self, painter: QPainter, frame_ix: int):
+ """
+ Draw poses for SLEAP analysis files that contain edge data.
+ """
+ frame_polys, frame_edges = self.pose_polys[frame_ix]
+ for instance_ix in range(len(frame_polys)):
+ instance_color = self.pose_colors[instance_ix % len(self.pose_colors)]
+ instance_poly = frame_polys[instance_ix]
+ instance_edges = frame_edges[instance_ix]
+ if not instance_poly.isEmpty():
+ # Draw nodes
+ painter.setPen(QPen(instance_color, 1.5))
+ painter.setBrush(instance_color)
+ for node in instance_poly:
+ painter.drawEllipse(node, 5.0, 5.0)
+ # Draw edges
+ painter.drawLines(instance_edges)
+
+ def drawPoses(self, painter: QPainter, frame_ix: int):
+ """
+ Determines whether edge data is available and calls the corresponding drawing method.
+ """
+ if self.has_edges:
+ self._drawPoses_hasEdges(painter, frame_ix)
+ else:
+ self._drawPoses_noEdges(painter, frame_ix)
+
+ def getFileSearchDescription(self) -> str:
+ """
+ Defines the name of the file class supported by this derived class. This
+ description will be displayed as part of the "Open File" dialog.
+ """
+ return "SLEAP analysis HDF5 file"
+
+ def getFileSearchPattern(self) -> str:
+ """
+ Defines the file search pattern that should be used to filter files that
+ contain pose data supported by the PoseSLEAP class.
+ """
+ return "*.analysis.h5"
+
+ def getFileFormat(self) -> str:
+ """
+ Defines the file format as a string.
+ """
+ return "SLEAP_analysis_HDF5"
+
+ def _validateFileH5(self, parent_widget: QWidget, file_path: str) -> bool:
+ """
+ Validates that the loaded h5 file has the necessary pose data. For SLEAP analysis HDF5
+ files, the pose data information is stored in the "tracks" dataset.
+ """
+ with h5py.File(file_path, "r") as f:
+ dset_names = list(f.keys())
+
+ try:
+ assert "tracks" and "node_names" in dset_names
+ except AssertionError:
+ QMessageBox.warning(
+ parent_widget,
+ "Add Pose ...",
+ "No SLEAP pose data found in pose file. Ensure loaded as correct file format.",
+ )
+ return False
+
+ return True
+
+ def validateFile(self, parent_widget: QWidget, file_path: str) -> bool:
+ """
+ Check which extension is used, then perform validation specific to file type.
+ """
+ _, ext = splitext(file_path.lower())
+ if ext == ".h5":
+ return self._validateFileH5(parent_widget, file_path)
+ else:
+ QMessageBox.warning(
+ parent_widget,
+ "Extension not supported",
+ f"The fileextension {ext} is not supported for SLEAP pose data.",
+ )
+ return False
+
+ def _loadPoses_noEdges_h5(self, parent_widget: QWidget, file_path: str, video_path: str):
+ """
+ Method for parsing and importing pose data when no edge data is available.
+ Save all pose information in self.pose_polys, self.num_frames, self.num_nodes,
+ and self.num_instances.
+ """
+
+ def append_keypoints(body: QPolygonF, appendage: np.ndarray):
+ """
+ Append nodes to instance poly.
+ """
+ if not np.any(np.isnan(appendage)):
+ body.append(QPointF(appendage[0], appendage[1]))
+
+ with h5py.File(file_path, "r") as f:
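+ # "tracks" is stored as (instances, 2, nodes, frames); the transpose
+ # yields (frames, nodes, 2, instances), matching the unpacking of
+ # self.pose_data.shape below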
+ self.pose_data = f["tracks"][:].T
+ self.node_names = [n.decode() for n in f["node_names"][:]]
+ if self.has_scores:
+ self.confidence = f["point_scores"][:].T
+
+ self.pose_polys: List[List[QPolygonF]] = []
+ self.num_frames, self.num_nodes, _, self.num_instances = self.pose_data.shape
+ self.video_path = video_path
+
+ for frame_ix in range(self.num_frames):
+ # Get all keypoints in current frame using self.pose_data[frame, node, coor, tracks]
+ frame_keypoints = self.pose_data[frame_ix, :, :, :]
+ frame_polys: List[QPolygonF] = []
+ for instance_ix in range(self.num_instances):
+ # Get all points for selected instance
+ instance_keypoints = frame_keypoints[:, :, instance_ix]
+ first_node = instance_keypoints[0, :]
+ poly = QPolygonF()
+ for node_ix in range(self.num_nodes):
+ # Get points for specific node
+ node_keypoints = instance_keypoints[node_ix, :]
+ append_keypoints(poly, node_keypoints)
+ append_keypoints(poly, first_node) # Complete the poly
+ frame_polys.append(poly) # Add instance poly to current frame polys
+ self.pose_polys.append(frame_polys) # Add frame polys to all pose polys
+
+ def _loadPoses_hasEdges_h5(self, parent_widget: QWidget, file_path: str, video_path: str):
+ """
+ Method for parsing and importing pose data when edge data is available.
+ Save all pose information in self.pose_polys, self.num_frames, self.num_nodes,
+ and self.num_instances.
+ """
+
+ def append_keypoints(
+ body: QPolygonF, appendage: np.ndarray, node_lookup: dict, node_name: str
+ ):
+ """
+ Append nodes to instance poly. Also add node to dictionary for creating edges
+ """
+ if not np.any(np.isnan(appendage)):
+ point = QPointF(appendage[0], appendage[1])
+ body.append(point)
+ node_lookup[node_name] = point
+
+ with h5py.File(file_path, "r") as f:
+ self.pose_data = f["tracks"][:].T
+ if self.has_scores:
+ self.confidence = f["point_scores"][:].T
+ if self.has_edges:
+ self.node_names = [n.decode() for n in f["node_names"][:]]
+ edge_names = [(s.decode(), d.decode()) for (s, d) in f["edge_names"][:]]
+ self.edge_inds = f["edge_inds"][:].T
+
+ self.pose_polys: List[Tuple[List[QPolygonF], List[QLineF]]] = []
+ self.num_frames, self.num_nodes, _, self.num_instances = self.pose_data.shape
+ self.video_path = video_path
+
+ for frame_ix in range(self.num_frames):
+ # Get all keypoints in current frame using self.pose_data[frame,node,coor,tracks]
+ frame_keypoints = self.pose_data[frame_ix, :, :, :]
+ frame_polys: List[QPolygonF] = []
+ frame_edges: List[QLineF] = []
+ for instance_ix in range(self.num_instances):
+ # Get all nodes for selected instance
+ instance_keypoints = frame_keypoints[:, :, instance_ix]
+ node_dict = {} # Temporary dict to help build edges
+ instance_edges = [] # All edges in an instance
+ poly = QPolygonF() # Polygon for nodes in instance
+ for node_ix, node_name in enumerate(self.node_names):
+ # Get points for specific node
+ node_keypoints = instance_keypoints[node_ix, :]
+ append_keypoints(poly, node_keypoints, node_dict, node_name)
+ for (src_node, des_node) in edge_names:
+ if (src_node in node_dict) and (des_node in node_dict):
+ # Add edges using node points
+ edge = QLineF()
+ edge.setPoints(node_dict[src_node], node_dict[des_node])
+ instance_edges.append(edge)
+ frame_edges.append(instance_edges)
+ frame_polys.append(poly)
+ self.pose_polys.append(
+ (frame_polys, frame_edges)
+ ) # Add frame polys to all pose polys
+
+ def loadPoses(self, parent_widget: QWidget, file_path: str, video_path: str):
+ """
+ Method for parsing and importing all the pose data. Determines whether edge
+ data is present and calls respective method for loading poses.
+ """
+ # Determine if edge data is available in analysis file.
+ self.has_edges = False
+ self.has_scores = False
+ with h5py.File(file_path, "r") as f:
+ dset_names = list(f.keys())
+ if "edge_names" in dset_names:
+ self.has_edges = True
+ if "point_scores" in dset_names:
+ self.has_scores = True
+
+ if self.has_edges:
+ # Load poses with edge data available.
+ self._loadPoses_hasEdges_h5(parent_widget, file_path, video_path)
+ else:
+ # Load poses without edge data available.
+ self._loadPoses_noEdges_h5(parent_widget, file_path, video_path)
+
+ def exportPosesToNWBFile(self, nwbFile: NWBFile):
+ processing_module_name = f"Pose data for video {os.path.basename(self.video_path)}"
+
+ for instance_ix in range(self.num_instances):
+ pose_estimation_series = []
+ for nodes_ix in range(self.num_nodes):
+ pose_estimation_series.append(
+ PoseEstimationSeries(
+ name = f"{self.node_names[nodes_ix]}",
+ description = f"Pose keypoint placed around {self.node_names[nodes_ix]}",
+ data =self.pose_data[:,nodes_ix,:,instance_ix],
+ reference_frame = "The coordinates are in (x, y) relative to the top-left of the image",
+ timestamps = np.arange(self.num_frames, dtype=float),
+ confidence = self.confidence[:,nodes_ix,instance_ix] if self.has_scores
+ else np.full(self.pose_data.shape[0], -1, dtype='float64')
+ )
+ )
+ pose_estimation = PoseEstimation(
+ pose_estimation_series = pose_estimation_series,
+ name = f"animal_{instance_ix}",
+ description = f"Estimated position for animal_{instance_ix} in video {os.path.basename(self.video_path)}",
+ nodes = self.node_names,
+ edges = self.edge_inds.astype('uint64') if self.has_edges else None,
+ )
+ if processing_module_name in nwbFile.processing:
+ nwbFile.processing[processing_module_name].add(pose_estimation)
+ else:
+ pose_pm = nwbFile.create_processing_module(
+ name = processing_module_name,
+ description = f"Pose Data from {self.getFileFormat().split('_')[0]}"
+ )
+ pose_pm.add(pose_estimation)
+
+ return nwbFile
+
+
+def register(registry: PoseRegistry):
+ """
+ Method to register pose plugin in pose registry.
+ """
+ pose_plugin = PoseSLEAP()
+ registry.register(pose_plugin.getFileFormat(), pose_plugin)
diff --git a/src/plugin/plugin_processing_behaviorTriggeredAverage.py b/src/plugin/plugin_processing_behaviorTriggeredAverage.py
new file mode 100644
index 0000000..d5f97ca
--- /dev/null
+++ b/src/plugin/plugin_processing_behaviorTriggeredAverage.py
@@ -0,0 +1,527 @@
+from qtpy.QtWidgets import QFrame, QMessageBox, QCheckBox, QMenu, QFileDialog, QLabel
+from plugin.behaviorTriggeredAverage_ui import Ui_BTAFrame
+from qtpy.QtCore import Slot
+from os.path import expanduser, sep
+from qtpy.QtGui import QColor
+import numpy as np
+import pandas as pd
+from scipy import signal
+import math
+import os
+import timecode as tc
+from processing.processing import ProcessingBase
+from vispy.scene import SceneCanvas, visuals, AxisWidget
+from vispy.visuals.transforms import STTransform
+from PIL import Image
+
+
+class behaviorTriggeredAverage(QFrame, ProcessingBase):
+ def __init__(self, nwbFile, bento):
+ QFrame.__init__(self)
+ ProcessingBase.__init__(self, nwbFile, bento)
+ self.bento = bento
+ self.checkData()
+ if not self.annotationsExists or not self.neuralExists:
+ msgBox = QMessageBox(QMessageBox.Warning,
+ "Required Data not found",
+ "Both neural data and annotation data \
+ must exist for this plugin to work")
+ msgBox.exec()
+ raise RuntimeError("Either Neural data or Annotation data does not exist for this plugin to work.")
+ self.bento.nwbFileUpdated.connect(self.getAnnotationsData)
+ self.bento.behaviorsChanged.connect(self.getBehaviors)
+ self.bento.behaviorsChanged.connect(self.populateBehaviorSelection)
+ self.bento.behaviorsChanged.connect(self.getBehaviorTriggeredTrials)
+
+ self.getAnnotationsData()
+ self.getNeuralData()
+ self.getBehaviors()
+ self.invokeUI()
+
+ def invokeUI(self):
+ #setting up UI
+ self.ui = Ui_BTAFrame()
+ self.ui.setupUi(self)
+ # initialize a few variables
+ self.checkboxState = {}
+ self.bev, self.ch = None, None
+ self.combineBehaviorNames, self.combineChannels = [], []
+
+ # populating combo boxes
+ self.populateBehaviorCombo()
+ self.populateChannelsCombo()
+ self.populateAnalyzeCombo()
+
+ # connect nwb file update signal to populate functions
+ self.bento.nwbFileUpdated.connect(self.populateBehaviorCombo)
+ self.bento.nwbFileUpdated.connect(self.populateChannelsCombo)
+
+ # creating save menu and connect them to saving functions
+ self.saveMenu = QMenu("Save Options")
+ self.ui.saveButton.setMenu(self.saveMenu)
+ self.ui.saveButton.setToolTip("click to see save options")
+ self.saveh5 = self.saveMenu.addAction("Save BTA to h5 file")
+ self.saveFigure = self.saveMenu.addAction("Save Figure")
+ self.saveh5.triggered.connect(self.saveBTAtoh5)
+ self.saveFigure.triggered.connect(self.savePlots)
+
+ # setting default value for bin size
+ self.ui.binSizeBox.setValue(float(1/self.neuralSampleRate))
+
+ # connecting the various user options to the getBehaviorTriggeredTrials function
+ self.ui.binSizeBox.valueChanged.connect(self.setBinSizeBoxValueToMin)
+ self.ui.mergeBoutsBox.textChanged.connect(self.getBehaviorTriggeredTrials)
+ self.ui.discardBoutsBox.textChanged.connect(self.getBehaviorTriggeredTrials)
+ self.ui.channelComboBox.currentTextChanged.connect(self.getBehaviorTriggeredTrials)
+ self.ui.behaviorComboBox.currentTextChanged.connect(self.getBehaviorTriggeredTrials)
+ self.ui.analyzeComboBox.currentTextChanged.connect(self.getBehaviorTriggeredTrials)
+ self.ui.alignAtStartButton.toggled.connect(self.getBehaviorTriggeredTrials)
+ self.ui.alignAtEndButton.toggled.connect(self.getBehaviorTriggeredTrials)
+ self.ui.windowBox_1.textChanged.connect(self.getBehaviorTriggeredTrials)
+ self.ui.windowBox_2.textChanged.connect(self.getBehaviorTriggeredTrials)
+ self.ui.binSizeBox.valueChanged.connect(self.getBehaviorTriggeredTrials)
+ self.ui.zscoreCheckBox.stateChanged.connect(self.getBehaviorTriggeredTrials)
+
+ # populating behavior selection checkboxes
+ self.populateBehaviorSelection()
+
+ self.getBehaviorTriggeredTrials()
+
+ def getType(self):
+ return "BTA"
+
+ @Slot()
+ def setBinSizeBoxValueToMin(self):
+ if self.ui.binSizeBox.value() < float(1/self.neuralSampleRate):
+ self.ui.binSizeBox.setValue(float(1/self.neuralSampleRate))
+
+ def createTrial(self, neuralData, index, window):
+ # extract a window of neural samples around the alignment index,
+ # padding with NaN where the window extends past the recording
+ if window[0] < index and neuralData.shape[0] > (index+window[1]):
+ trial = neuralData[index-window[0]-1:index+window[1]]
+ elif window[0]>=index and neuralData.shape[0]>(index+window[1]):
+ mismatch = window[0]-index+1
+ trial = np.zeros(window[0]+window[1]+1)
+ trial[:mismatch] = np.nan
+ trial[mismatch:] = neuralData[:index+window[1]]
+ else:
+ mismatch = (index+window[1])-neuralData.shape[0]
+ trial1 = np.zeros(mismatch)
+ trial1[:] = np.nan
+ trial2 = neuralData[index-window[0]-1:neuralData.shape[0]]
+ trial = np.concatenate((trial2, trial1))
+
+ return trial
+
+ def mergeAndDiscardBouts(self, startTime, stopTime):
+ self.mergeBoutsTime = float(self.ui.mergeBoutsBox.value())
+ self.discardBoutsTime = float(self.ui.discardBoutsBox.value())
+ startTime, stopTime = np.array(startTime), np.array(stopTime)
+ keptIndicesStart, keptIndicesStop = [0], []
+ for i in range(startTime.shape[0]-1):
+ if startTime[i+1]-stopTime[i] < self.mergeBoutsTime:
+ # gap to the next bout is below the merge threshold: skip this
+ # boundary so the two bouts merge into one
+ continue
+ keptIndicesStop.append(i)
+ keptIndicesStart.append(i+1)
+ keptIndicesStop.append(startTime.shape[0]-1)
+ startTime, stopTime = startTime[keptIndicesStart], stopTime[keptIndicesStop]
+ # discard merged bouts that are still shorter than the discard threshold
+ keptIndices = np.where((stopTime-startTime)-self.discardBoutsTime > 0)[0]
+ return startTime[keptIndices], stopTime[keptIndices]
+
+ def zscore(self, trial):
+ # z-score a single trial, ignoring NaN padding
+ return (trial - np.nanmean(trial))/np.nanstd(trial)
+
+ @Slot()
+ def getBehaviorTriggeredTrials(self):
+ # read the current UI selections
+ self.bev = self.ui.behaviorComboBox.currentText()
+ self.ch = self.ui.channelComboBox.currentText()
+ self.window = [float(self.ui.windowBox_1.value()), float(self.ui.windowBox_2.value())]
+ self.sampleRate = 1./float(self.ui.binSizeBox.value())
+ self.windowNumTs = [int(round(w*self.sampleRate)) for w in self.window]
+ self.trialsTs = np.arange(-self.windowNumTs[0], self.windowNumTs[1]+1)/self.sampleRate
+ self.neuralDataTs = self.neuralStartTime + np.arange(self.neuralData.shape[0])/self.neuralSampleRate
+ num = int(round(self.neuralData.shape[0]*self.sampleRate/self.neuralSampleRate))
+ if self.ch in self.annotationsData and self.bev in self.behaviorNames.get(self.ch, []):
+ bouts = self.annotationsData[self.ch]
+ bouts = bouts[bouts['behaviorName']==self.bev]
+ startTime, stopTime = self.mergeAndDiscardBouts(list(bouts['start_time']), list(bouts['stop_time']))
+ # align trials at bout start or bout end, per the radio buttons
+ if self.ui.alignAtStartButton.isChecked():
+ self.alignTime = startTime
+ else:
+ self.alignTime = stopTime
+ self.alignTime = self.alignTime[np.where(self.alignTime > 0)[0]]
+ self.trials = np.full((self.alignTime.shape[0], self.trialsTs.shape[0]), np.nan)
+ self.backgroundAnnotations = dict()
+ resampledData = signal.resample(self.checkAnalyzeComboAndGetData(),
+ num=num, t=self.neuralDataTs)[0]
+
+ # getting all the trials along with background annotations for each trial
+ for ix in range(self.alignTime.shape[0]):
+ idx = int(round((self.alignTime[ix])*self.sampleRate))
+ if self.ui.zscoreCheckBox.isChecked():
+ self.trials[ix, :] = self.zscore(self.createTrial(resampledData,
+ idx,
+ self.windowNumTs))
+ else:
+ self.trials[ix, :] = self.createTrial(resampledData,
+ idx,
+ self.windowNumTs)
+ start_time = np.array(self.annotationsData[self.ch]['start_time']) - self.alignTime[ix]
+ stop_time = np.array(self.annotationsData[self.ch]['stop_time']) - self.alignTime[ix]
+ temp_df = self.annotationsData[self.ch].iloc[np.where((stop_time>=-self.window[0])
+ & (start_time<=self.window[1]))[0],:]
+ copy_df = temp_df.copy()
+ copy_df.loc[:,'start_time'] = temp_df['start_time'] - self.alignTime[ix]
+ copy_df.loc[:,'stop_time'] = temp_df['stop_time'] - self.alignTime[ix]
+ for bev in list(self.checkboxState.keys()):
+ if not self.checkboxState[bev].isChecked():
+ copy_df.drop(copy_df[copy_df['behaviorName']==bev].index, inplace=True)
+ self.backgroundAnnotations[str(ix)] = copy_df
+
+ # handling a case when there is only one trial or no trials at all
+ if self.trials.shape[0]==0:
+ self.clearPlotLayout()
+ self.addWarningForNoData()
+ elif self.trials.shape[0]==1:
+ self.avgTrials = self.trials[0]
+ self.errTrials = np.zeros(self.trialsTs.shape[0])
+ # plotting trials along with annotations in the background
+ self.plotBTA()
+ else:
+ self.avgTrials = np.nanmean(self.trials, axis=0)
+ self.errTrials = np.nanstd(self.trials, axis=0)/math.sqrt(self.trials.shape[0])
+ # plotting trials along with annotations in the background
+ self.plotBTA()
+ else:
+ self.clearPlotLayout()
+ self.addWarningForNoData()
+
+ def addWarningForNoData(self):
+ self.label_top = QLabel("Data corresponding to this behavior/channel does not exist")
+ self.label_bot = QLabel("Data corresponding to this behavior/channel does not exist")
+ self.ui.plotLayout.addWidget(self.label_top, stretch=1)
+ self.ui.plotLayout.addWidget(self.label_bot, stretch=2)
+
+ def clearPlotLayout(self):
+ for i in reversed(range(self.ui.plotLayout.count())):
+ self.ui.plotLayout.itemAt(i).widget().deleteLater()
+
+ def plotBTA(self):
+ self.clearPlotLayout()
+ self.createAnnotationsImgArray()
+ self.canvas_top = SceneCanvas(size=(self.width,200))
+ self.grid_top = self.canvas_top.central_widget.add_grid()
+ self.xaxis_top = AxisWidget(orientation='bottom',
+ axis_label='Time (s)',
+ axis_color='black',
+ tick_color='black',
+ text_color='black',
+ axis_font_size=12,
+ font_size= 8,
+ axis_label_margin=16,
+ tick_label_margin=14)
+ self.xaxis_top.height_max = 35
+ self.xaxis_top.bgcolor = 'white'
+ self.grid_top.add_widget(self.xaxis_top, row=1, col=1)
+
+ self.yaxis_top = AxisWidget(orientation='left',
+ axis_color='black',
+ tick_color='black',
+ text_color='black',
+ axis_font_size=12,
+ font_size= 8,
+ axis_label_margin=16,
+ tick_label_margin=8)
+ self.yaxis_top.width_max = 35
+ self.yaxis_top.bgcolor = 'white'
+ self.grid_top.add_widget(self.yaxis_top, row=0, col=0)
+
+ self.view_top = self.grid_top.add_view(0, 1, bgcolor='white')
+ self.lineAvgTrials = visuals.Line(
+ np.column_stack((self.trialsTs, self.avgTrials)),
+ parent=self.view_top.scene,
+ color='black'
+ )
+ errPosValues = self.avgTrials + self.errTrials
+ errNegValues = self.avgTrials - self.errTrials
+ if np.count_nonzero(self.errTrials)!=0:
+ notZeroIndices = np.where(~(self.errTrials==0))[0]
+ # handling a case when there are two trials and one has a few NaN values
+ errPosValues = self.avgTrials[notZeroIndices] + self.errTrials[notZeroIndices]
+ errNegValues = self.avgTrials[notZeroIndices] - self.errTrials[notZeroIndices]
+ self.errPosLine = visuals.Line(
+ np.column_stack((self.trialsTs[notZeroIndices], errPosValues)),
+ parent=self.view_top.scene,
+ color='white'
+ )
+ self.errNegLine = visuals.Line(
+ np.column_stack((self.trialsTs[notZeroIndices], errNegValues)),
+ parent=self.view_top.scene,
+ color='white'
+ )
+ self.fillBetween = visuals.Polygon(
+ pos=np.vstack((self.errPosLine.pos, self.errNegLine.pos[::-1])),
+ color=(0.85, 0.85, 0.85, 1),
+ parent=self.view_top.scene
+ )
+ self.centerLineTop = visuals.InfiniteLine(0,
+ color=[0,0,0,1],
+ vertical=True,
+ line_width=2,
+ parent=self.view_top.scene)
+ self.view_top.camera = "panzoom"
+ self.view_top.interactive = False
+ self.view_top.camera.set_range(x=(self.trialsTs[0], self.trialsTs[-1]),
+ y=(np.nanmin(errNegValues), np.nanmax(errPosValues)))
+ self.xaxis_top.link_view(self.view_top)
+ self.yaxis_top.link_view(self.view_top)
+
+ self.rememberPos = []
+ height = 0
+ for t in range(self.trials.shape[0]):
+ pos = np.empty((self.trials.shape[1], 2), dtype=np.float32)
+ pos[:, 0] = self.trialsTs
+ pos[:, 1] = np.interp(self.trials[t,:],
+ (np.nanmin(self.trials[t,:]), np.nanmax(self.trials[t,:])),
+ (height+2, height+self.trialHeight-2))
+ self.rememberPos.append(pos)
+ height += self.trialHeight
+
+ self.canvas_bot = SceneCanvas(size=(self.width,self.height))
+ self.grid_bot = self.canvas_bot.central_widget.add_grid()
+ self.xaxis_bot = AxisWidget(orientation='bottom',
+ axis_label='Time (s)',
+ axis_color='black',
+ tick_color='black',
+ text_color='black',
+ axis_font_size=12,
+ font_size= 8,
+ axis_label_margin=16,
+ tick_label_margin=14)
+ self.xaxis_bot.height_max = 35
+ self.xaxis_bot.bgcolor = 'white'
+ self.grid_bot.add_widget(self.xaxis_bot, row=1, col=0)
+ self.view_bot = self.grid_bot.add_view(0, 0, bgcolor='white')
+ for p in range(len(self.rememberPos)):
+ # keep only the rows (time points) with no NaN values
+ nanIndices = np.where(~np.isnan(self.rememberPos[p]).any(axis=1))[0]
+ self.lineTrials = visuals.Line(self.rememberPos[p][nanIndices,:],
+ parent=self.view_bot.scene,
+ color='black',
+ width=1)
+ self.annotationsImage = visuals.Image(self.imgArray,
+ texture_format="auto",
+ parent=self.view_bot.scene)
+ self.annotationsImage.transform = STTransform(scale=((self.window[0]+self.window[1])/self.width),
+ translate=(-(self.window[0]), 0))
+ self.centerLineBot = visuals.InfiniteLine(0,
+ color=[0,0,0,1],
+ vertical=True,
+ line_width=2,
+ parent=self.view_bot.scene)
+ self.view_bot.camera = "panzoom"
+ self.view_bot.camera.interactive = False
+ self.view_bot.camera.set_range(x=(-self.window[0], self.window[1]),
+ y=(0, self.height))
+ self.xaxis_bot.link_view(self.view_bot)
+
+ self.ui.plotLayout.addWidget(self.canvas_top.native, stretch=1)
+ self.ui.plotLayout.addWidget(self.canvas_bot.native, stretch=2)
+
+ def createAnnotationsImgArray(self):
+ self.trialHeight = 10
+ self.height, self.width, channels = self.trials.shape[0]*self.trialHeight, int((self.trials.shape[1])*(self.window[0]+self.window[1])), 3
+ secondLength = int(self.width/(self.window[0]+self.window[1]))
+ self.imgArray = np.ones((self.height, self.width, channels), dtype=np.float32)
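+ # start from a white background; bouts are painted in below as
+ # float RGB values in [0, 1] taken from the behavior color profile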
+ height = 0
+ for t in range(self.trials.shape[0]):
+ annotations = np.array(self.backgroundAnnotations[str(t)])
+ annotations[:,0], annotations[:,1] = annotations[:,0] + self.window[0], annotations[:,1] + self.window[0]
+ for j in range(annotations.shape[0]):
+ start = int(round((annotations[j,0]) * secondLength))
+ end = int(round((annotations[j,1] +
+ tc.Timecode(self.bento.annotationsScene.sample_rate, frames=1).float) * secondLength))
+ bev = self.behaviors[annotations[j,2]]
+ self.imgArray[height:int(height+self.trialHeight),start:end,0] = bev[0]
+ self.imgArray[height:int(height+self.trialHeight),start:end,1] = bev[1]
+ self.imgArray[height:int(height+self.trialHeight),start:end,2] = bev[2]
+ height += self.trialHeight
+
+ def saveBTAtoh5(self):
+ fileName, selectedFilter = QFileDialog.getSaveFileName(
+ self,
+ caption="Save Behavior Triggered Average Plots",
+ filter="h5 file (*.h5)",
+ selectedFilter="h5 file (*.h5)",
+ dir=expanduser('~'))
+ if selectedFilter == "h5 file (*.h5)":
+ cols = ['Bout_'+str(i) for i in np.arange(1,self.trials.shape[0]+1)]
+ df = pd.DataFrame(self.trials.T, columns=cols)
+ df.to_hdf(fileName, key='dataframe', mode='w')
+ print("BTA saved to h5 file")
+ else:
+ raise NotImplementedError(f"File format {selectedFilter} not supported")
+
+ def savePlots(self):
+ fileName, selectedFilter = QFileDialog.getSaveFileName(
+ self,
+ caption="Save Behavior Triggered Average Plots",
+ filter="eps file (*.eps)",
+ selectedFilter="eps file (*.eps)",
+ dir=expanduser('~'))
+ if selectedFilter == "eps file (*.eps)":
+ dirname = os.path.dirname(fileName)
+ files = [os.path.join(dirname, os.path.basename(fileName).split('.')[0]+'_top.eps'),
+ os.path.join(dirname, os.path.basename(fileName).split('.')[0]+'_bottom.eps')]
+ self.imageArr = [self.canvas_top.render(alpha=False), self.canvas_bot.render(alpha=False)]
+ for i in range(len(self.imageArr)):
+ img = Image.fromarray(self.imageArr[i])
+ img.save(files[i], format='eps', resolution=100.)
+ print("BTA plots saved.")
+ else:
+ raise NotImplementedError(f"File format {selectedFilter} not supported")
+
+
+
+
+def register(registry, nwbFile=None, bento=None):
+ btaProcessingPlugin = behaviorTriggeredAverage(nwbFile, bento)
+ registry.register(btaProcessingPlugin.getType(), btaProcessingPlugin)
+
+
+
diff --git a/src/plugin/plugin_processing_kMeansClustering.py b/src/plugin/plugin_processing_kMeansClustering.py
new file mode 100644
index 0000000..0c2a53d
--- /dev/null
+++ b/src/plugin/plugin_processing_kMeansClustering.py
@@ -0,0 +1,196 @@
+import os
+# to avoid a memory leak in scikit-learn's KMeans on Windows
+if os.name == 'nt':
+ os.environ["OMP_NUM_THREADS"] = '1'
+import math
+import numpy as np
+import matplotlib.pyplot as plt
+from vispy.scene import SceneCanvas, visuals, AxisWidget
+from qtpy.QtWidgets import QDialog, QMessageBox, QWidget, QHBoxLayout
+from qtpy.QtCore import Signal, Slot
+from plugin.kMeansClustersDialog_ui import Ui_kMeansClustersDialog
+from processing.processing import ProcessingBase
+from sklearn.cluster import KMeans
+from sklearn.preprocessing import StandardScaler
+from sklearn.model_selection import KFold
+
+
+class kMeansClustering(QDialog, ProcessingBase):
+ def __init__(self, nwbFile, bento):
+ QDialog.__init__(self)
+ ProcessingBase.__init__(self, nwbFile, bento)
+ self.bento = bento
+ self.checkData()
+ if not self.neuralExists:
+ msgBox = QMessageBox(QMessageBox.Warning,
+ "Required Data not found",
+ "Neural data \
+ must exist for this plugin to work")
+ msgBox.exec()
+ raise RuntimeError("Neural Data does not exist for this plugin to work.")
+
+ self.getNeuralData()
+ self.invokeUI()
+
+ def invokeUI(self):
+ #setting up UI
+ self.ui = Ui_kMeansClustersDialog()
+ self.ui.setupUi(self)
+
+ def getType(self):
+ return "kMeansClustering"
+
+ def setNeural(self, neural):
+ self.neuralFrame = neural
+
+ def preprocessing(self, data):
+ scaler = StandardScaler()
+ self.scaledData = scaler.fit_transform(data)
+
+ def calculateBIC(self, data, labels):
+ nPoints = len(labels)
+ nClusters = len(set(labels))
+ nDimensions = data.shape[1]
+
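+ # free parameters of a spherical Gaussian mixture with a shared variance:
+ # (k-1) mixture weights + k*d centroid coordinates + 1 variance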
+ nParameters = (nClusters - 1) + (nDimensions * nClusters) + 1
+
+ loglikelihood = 0
+ for labelName in set(labels):
+ dataCluster = data[labels == labelName]
+ nPointsCluster = len(dataCluster)
+ if nPointsCluster==1:
+ loglikelihood += -(nPointsCluster * np.log(nPoints))
+ else:
+ centroid = np.mean(dataCluster, axis=0)
+ variance = np.mean((dataCluster - centroid) ** 2)
+ loglikelihood += \
+ nPointsCluster * np.log(nPointsCluster) \
+ - nPointsCluster * np.log(nPoints) \
+ - nPointsCluster * nDimensions / 2 * np.log(2 * math.pi * variance) \
+ - (nPointsCluster - 1) / 2
+
+ bic = loglikelihood - (nParameters / 2) * np.log(nPoints)
+
+ return bic
+
+ def computeCrossValidation(self, data):
+ # get values from UI
+ self.clusterStart = int(self.ui.clusterRangeBox1.text())
+ self.clusterStop = int(self.ui.clusterRangeBox2.text())
+ self.numFolds = int(self.ui.crossValidationFoldsBox.text())
+ self.preprocessing(data)
+ kMeansKwargs = {
+ "init": "k-means++",
+ "n_init": 10,
+ "max_iter": 300
+ }
+ # cross validation
+ kf = KFold(n_splits=self.numFolds, shuffle=True)
+ self.avgBicValues = []
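+ # for each candidate k, fit on the training folds and score the held-out
+ # fold with BIC; the fold average is what gets plotted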
+ for k in range(self.clusterStart, self.clusterStop):
+ bic_values = []
+ for trainInd, valInd in kf.split(self.scaledData):
+ scaledDataTrain, scaledDataVal = self.scaledData[trainInd], self.scaledData[valInd]
+ kmeans = KMeans(n_clusters=k, **kMeansKwargs)
+ kmeans.fit(scaledDataTrain)
+ valLabels = kmeans.predict(scaledDataVal)
+ # Calculate bic
+ bic = self.calculateBIC(scaledDataVal, valLabels)
+ bic_values.append(bic)
+ self.avgBicValues.append(np.mean(bic_values))
+
+
+ def plotBicValues(self):
+ self.bicCanvas = SceneCanvas(size=(600,600))
+ self.bicGrid = self.bicCanvas.central_widget.add_grid()
+ self.bicXaxis = AxisWidget(orientation='bottom',
+ axis_label='Number of Clusters',
+ axis_color='black',
+ tick_color='black',
+ text_color='black',
+ axis_font_size=12,
+ font_size= 8,
+ axis_label_margin=20,
+ tick_label_margin=14)
+ self.bicXaxis.height_max = 50
+ self.bicXaxis.bgcolor = 'white'
+ self.bicGrid.add_widget(self.bicXaxis, row=1, col=1)
+
+ self.bicYaxis = AxisWidget(orientation='left',
+ axis_label='BIC Values',
+ axis_color='black',
+ tick_color='black',
+ text_color='black',
+ axis_font_size=12,
+ font_size= 8,
+ axis_label_margin=70,
+ tick_label_margin=8)
+ self.bicYaxis.width_max = 100
+ self.bicYaxis.bgcolor = 'white'
+ self.bicGrid.add_widget(self.bicYaxis, row=0, col=0)
+
+
+ self.bicView = self.bicGrid.add_view(0, 1, bgcolor='white')
+ self.bicPlot = visuals.Line(
+ np.column_stack((np.arange(self.clusterStart, self.clusterStop), self.avgBicValues)),
+ parent=self.bicView.scene,
+ color='black'
+ )
+ self.bicView.camera = "panzoom"
+ self.bicView.interactive = False
+ self.bicView.camera.set_range(x=(self.clusterStart, self.clusterStop),
+ y=(min(self.avgBicValues), max(self.avgBicValues)))
+ self.bicXaxis.link_view(self.bicView)
+ self.bicYaxis.link_view(self.bicView)
+
+ # show plot on a widget
+ self.plotWidget = QWidget()
+ self.plotLayout = QHBoxLayout()
+ self.plotLayout.addWidget(self.bicCanvas.native)
+ self.plotWidget.setWindowTitle("Bayesion Information Criterion Values vs # of Clusters")
+ self.plotWidget.setLayout(self.plotLayout)
+ self.plotWidget.show()
+
+
+ def kMeansClustering(self, data):
+ kMeansKwargs = {
+ "init": "k-means++",
+ "n_init": 10,
+ "max_iter": 300
+ }
+ self.numOfClusters = int(self.ui.numOfClustersBox.text())
+ self.preprocessing(data)
+ self.kmeans = KMeans(n_clusters=self.numOfClusters, **kMeansKwargs)
+ self.kmeans.fit(self.scaledData)
+
+ def getNewOrderAfterClustering(self):
+ newOrder = np.array([], dtype=int)
+ endOfLabel = np.array([], dtype=int)
+ labels = self.kmeans.labels_
+ for c in range(self.numOfClusters):
+ indices = np.where(labels==c)[0]
+ if c<1:
+ endOfLabel = np.concatenate((endOfLabel, [indices.shape[0]]), axis=0, dtype=int)
+ else:
+ endOfLabel = np.concatenate((endOfLabel, [endOfLabel[-1] + indices.shape[0]]), axis=0, dtype=int)
+ newOrder = np.concatenate((newOrder, indices), axis=0, dtype=int)
+ self.neuralFrame.neuralScene.reorderTracesAndHeatmap(newOrder, endOfLabel)
+
+ @Slot()
+ def accept(self):
+ if self.ui.gofRadioButton.isChecked():
+ self.computeCrossValidation(self.neuralData)
+ self.plotBicValues()
+ else:
+ self.kMeansClustering(self.neuralData)
+ self.getNewOrderAfterClustering()
+
+ super().accept()
+
+ @Slot()
+ def reject(self):
+ super().reject()
+
+def register(registry, nwbFile=None, bento=None):
+ kMeansClusteringPlugin = kMeansClustering(nwbFile, bento)
+ registry.register(kMeansClusteringPlugin.getType(), kMeansClusteringPlugin)
\ No newline at end of file
diff --git a/src/pose/pose.py b/src/pose/pose.py
index ee73f76..3f8b57b 100644
--- a/src/pose/pose.py
+++ b/src/pose/pose.py
@@ -5,6 +5,7 @@
from os import curdir, listdir
from os.path import abspath, sep, splitext
from importlib import import_module
+from pynwb import NWBFile
import sys
class PoseBase():
@@ -50,7 +51,7 @@ def validateFile(self, parent_widget: QWidget, file_path: str) -> bool:
"""
return True
- def loadPoses(self, parent_widget: QWidget, file_path: str):
+ def loadPoses(self, parent_widget: QWidget, file_path: str, video_path: str):
"""
Base class template for parsing and importing all the
pose data. Real implementations should save away everything
@@ -59,6 +60,13 @@ def loadPoses(self, parent_widget: QWidget, file_path: str):
raise NotImplementedError("PoseBase: abstract base class. ",
"Please implement this in your derived class")
+ def exportPosesToNWBFile(self, nwbFile: NWBFile):
+ """
+ Base class template for exporting pose data to NWB file object
+ """
+ raise NotImplementedError("PoseBase: abstract base class. ",
+ "Please implement this in your derived class")
+
class PoseRegistry():
"""
Class that loads and manages pose plug-ins
diff --git a/src/processing/processing.py b/src/processing/processing.py
new file mode 100644
index 0000000..bf10a93
--- /dev/null
+++ b/src/processing/processing.py
@@ -0,0 +1,154 @@
+from qtpy.QtWidgets import QFrame
+from qtpy.QtCore import Slot
+from pynwb import NWBFile
+from qtpy.QtWidgets import QFileDialog, QMessageBox, QWidget
+from os import curdir, listdir
+from os.path import abspath, sep, splitext, expanduser
+from importlib import import_module
+import numpy as np
+import sys
+
+
+
+class ProcessingBase():
+ """
+ Abstract class from which to derive post processing plug-ins
+ """
+
+ def __init__(self, nwbFile, bento):
+ self.nwbFile = nwbFile
+ self.bento = bento
+ self.neuralExists = False
+ self.annotationsExists = False
+ self.poseExists = False
+
+ def checkData(self):
+ """
+ Base class template for checking if the required data exists
+ in the NWBFile object for launching a selected post-processing
+ module.
+ """
+ if list(self.nwbFile.acquisition.keys()):
+ self.neuralExists = True
+ if list(self.nwbFile.intervals.keys()):
+ self.annotationsExists = True
+ if list(self.nwbFile.processing.keys()):
+ self.poseExists = True
+
+ @Slot()
+ def getAnnotationsData(self):
+ """
+ Base class template for getting annotations data from the NWBFile object.
+ Implementation should save all the required information in class
+ variables in order to be able to plot the results
+ """
+ self.channels = list(self.nwbFile.intervals.keys())
+ self.annotationsData = {}
+ self.behaviorNames = {}
+ for ch in self.channels:
+ channelName = ch.split('_', 1)[-1]
+ self.annotationsData[channelName] = self.nwbFile.intervals[ch].to_dataframe()
+ if 'behaviorName' in list(self.annotationsData[channelName].columns):
+ self.behaviorNames[channelName] = np.unique(np.array(self.annotationsData[channelName]['behaviorName']))
+
+ def getNeuralData(self):
+ """
+ Base class template for getting neural data from the NWBFile object.
+ Implementation should save all the required information in class
+ variables in order to be able to plot the results
+ """
+ self.neuralData = self.nwbFile.acquisition['neural_data'].data[:]
+ self.neuralSampleRate = self.nwbFile.acquisition['neural_data'].rate
+ self.neuralStartTime = self.nwbFile.acquisition['neural_data'].starting_time
+
+
+ def getPoseData(self):
+ """
+ Base class template for getting pose data from the NWBFile object.
+ Implementation should save all the required information in class
+ variables in order to be able to plot the results
+ """
+
+ @Slot()
+ def getBehaviors(self):
+ """
+ Base class template for getting behavior names and color codes
+ from the color_profiles file.
+ Implementation should save all the required information in class
+ variables in order to be able to plot the results
+ """
+ path = expanduser("~") + sep + ".bento" + sep
+ profilePaths = [path, ""]
+ self.behaviors = {}
+ for path in profilePaths:
+ try:
+ fn = path + 'color_profiles.txt'
+ with open(fn,'r') as f:
+ line = f.readline()
+ while line:
+ hot_key, name, r, g, b = line.strip().split(' ')
+ if hot_key == '_':
+ hot_key = ''
+ self.behaviors[name] = [float(r), float(g), float(b)]
+ line = f.readline()
+ break # no exception, so success
+ except Exception as e:
+ print(f"Exception caught: {e}")
+ continue
+
+
+ def invokeUI(self) -> QFrame:
+ """
+ Base class template to invoke the UI with all the necessary user-defined
+ options, necessary to process the data
+ """
+
+
+class ProcessingRegistry():
+ """
+ Class that loads and manages processing plug-ins
+ """
+
+ def __init__(self, nwbFile=None, bento=None):
+ self.processing_modules = {}
+ self.plugin_dir = None
+ self.nwbFile = nwbFile
+ self.bento = bento
+
+ def __call__(self, type: str):
+ if type not in self.processing_modules:
+ return None
+ return self.processing_modules[type]
+
+ def register(self, type: str, module):
+ self.processing_modules[type] = module
+
+ def load_plugins(self):
+ """
+ Search the plugin directory for python files of
+ the form "plugin_processing_*.py"
+ Import any that are found, and call their "register" function if there is one
+ If no "register" function exists, the plugin will not be available for use.
+ """
+ plugin_dir = sys.path[0]
+ if 'src' in listdir(plugin_dir):
+ plugin_dir += sep + 'src'
+ if 'plugin' in listdir(plugin_dir):
+ plugin_dir += sep + 'plugin'
+ else:
+ return
+ self.plugin_dir = abspath(plugin_dir)
+ sys.path.append(self.plugin_dir)
+ paths = listdir(self.plugin_dir)
+ for path in paths:
+ if not path.lower().startswith('plugin_processing_'):
+ continue
+ stem, _ = splitext(path)
+ m = import_module(stem)
+ if 'register' not in dir(m):
+ continue
+ m.register(self, self.nwbFile, self.bento)
+
+ def getPluginDir(self) -> str:
+ return self.plugin_dir
diff --git a/src/timeEdit.py b/src/timeEdit.py
new file mode 100644
index 0000000..0058691
--- /dev/null
+++ b/src/timeEdit.py
@@ -0,0 +1,23 @@
+from qtpy.QtWidgets import QTimeEdit
+
+
+class CustomTimeEdit(QTimeEdit):
+ def __init__(self, parent=None):
+ super().__init__(parent)
+ self.bento = None
+ self.setDisplayFormat("HH:mm:ss.zzz")
+
+ def set_bento(self, bento):
+ self.bento = bento
+
+ def keyPressEvent(self, event):
+ # Override keyPressEvent to capture changes when keys are pressed
+ super().keyPressEvent(event)
+ self.bento.jumpToTime()
+
+ def mousePressEvent(self, event):
+ # Override mousePressEvent to capture changes when the mouse is pressed
+ super().mousePressEvent(event)
+ self.bento.jumpToTime()
+
+
\ No newline at end of file
diff --git a/src/timeSource.py b/src/timeSource.py
new file mode 100644
index 0000000..aab329c
--- /dev/null
+++ b/src/timeSource.py
@@ -0,0 +1,188 @@
+# timeSource.py
+
+from qtpy.QtCore import QObject, QTimer, Signal, Slot
+from qtpy.QtWidgets import QGraphicsScene
+from qtpy.QtMultimedia import QMediaPlayer
+from timecode import Timecode
+from video.videoScene import VideoSceneNative
+
+class TimeSourceAbstractBase(QObject):
+ """
+ Abstract base class for sources of timing ticks,
+ whether internal or external
+ """
+
+ timeChanged: Signal = Signal(Timecode)
+ tickRateChanged: Signal = Signal(float)
+
+ def __init__(self, notifyTimeChanged: Slot):
+ super().__init__()
+ self._currentTime = Timecode(30., "00:00:00:01")
+ self._maxFrameRate: float = 1.
+ self._minFrameRate: float = 1.
+ self._frameRate: float = 1.
+ self.timeChanged.connect(notifyTimeChanged)
+
+ def start(self):
+ raise NotImplementedError("The derived class needs to implement this method")
+
+ def stop(self):
+ raise NotImplementedError("The derived class needs to implement this method")
+
+ @Slot()
+ def doubleFrameRate(self):
+ raise NotImplementedError("The derived class needs to implement this method")
+
+ @Slot()
+ def halveFrameRate(self):
+ raise NotImplementedError("The derived class needs to implement this method")
+
+ @Slot()
+ def resetFrameRate(self):
+ raise NotImplementedError("The derived class needs to implement this method")
+
+ def setMaxFrameRate(self, maxRate: float):
+ self._maxFrameRate = maxRate
+
+ def setMinFrameRate(self, minRate: float):
+ self._minFrameRate = minRate
+
+ def setCurrentTime(self, currentTime: Timecode):
+ if self._currentTime != currentTime:
+ self._currentTime = currentTime
+ self.timeChanged.emit(self._currentTime)
+
+ def currentTime(self) -> Timecode:
+ return self._currentTime
+
+ @Slot()
+ def quit(self):
+ raise NotImplementedError("The derived class needs to implement this method")
+
+ def disconnectSignals(self, bento: QObject):
+ pass
+
+
+class TimeSourceQTimer(TimeSourceAbstractBase):
+ """
+ Tick source coming from a QTimer
+ """
+
+ def __init__(self, notifyTimeChanged: Slot):
+ super().__init__(notifyTimeChanged)
+ self.default_frame_interval: float = 1000./30.
+ self.timer: QTimer = QTimer()
+ self.timer.setInterval(round(self.default_frame_interval))
+ self.timer.timeout.connect(self.doTick)
+
+ def start(self):
+ self.timer.start()
+
+ def stop(self):
+ self.timer.stop()
+
+ @Slot()
+ def doubleFrameRate(self):
+ if self._frameRate * 2. <= self._maxFrameRate:
+ self._frameRate *= 2.
+ self.timer.setInterval(round(self.default_frame_interval / self._frameRate))
+ self.tickRateChanged.emit(self._frameRate)
+
+ @Slot()
+ def halveFrameRate(self):
+ if self._frameRate / 2. >= self._minFrameRate:
+ self._frameRate /= 2.
+ self.timer.setInterval(round(self.default_frame_interval / self._frameRate))
+ self.tickRateChanged.emit(self._frameRate)
+
+ @Slot()
+ def resetFrameRate(self):
+ if self._frameRate != 1.:
+ self._frameRate = 1.
+ self.timer.setInterval(round(self.default_frame_interval / self._frameRate))
+ self.tickRateChanged.emit(self._frameRate)
+
+ @Slot()
+ def doTick(self):
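+ # Timecode addition is frame-based, so this advances time by one frame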
+ currentTime = self._currentTime + 1
+ self._currentTime = currentTime
+ self.timeChanged.emit(self._currentTime)
+
+ @Slot()
+ def quit(self):
+ if self.timer.isActive():
+ self.timer.stop()
+
+class TimeSourceQMediaPlayer(TimeSourceAbstractBase):
+ """
+ Tick source coming from a video player
+ This is a little tricky, in order to avoid endless signal looping. The key
+ to understanding it is that the QMediaPlayer is the source of time while
+ it is playing, and ignores incoming time updates; when the player is
+ stopped, it acts on time updates but doesn't propagate them. So it either
+ sends or receives time updates, but never both.
+ """
+
+ #Signals
+ startCalled = Signal()
+ stopCalled = Signal()
+ quitCalled = Signal()
+
+ def __init__(self, notifyTimeChanged: Slot, scene: QGraphicsScene):
+ super().__init__(notifyTimeChanged)
+ assert isinstance(scene, VideoSceneNative)
+ self.scene = scene
+ self.player = scene.getPlayer()
+ self.player.positionChanged.connect(self.doTick)
+ self.startCalled.connect(scene.play)
+ self.stopCalled.connect(scene.stop)
+ self.tickRateChanged.connect(self.player.setPlaybackRate)
+ self.quitCalled.connect(self.player.stop)
+
+ @Slot(int)
+ def doTick(self, msec: int):
+ # We need to avoid a recursive "set time" loop with the media player,
+ # so only propagate position changes while the player itself is playing
+ if self.player.state() == QMediaPlayer.PlayingState:
+ self._currentTime = Timecode(self._currentTime.framerate, start_seconds=msec / 1000.)
+ self.timeChanged.emit(self._currentTime)
+
+ def start(self):
+ self.startCalled.emit()
+
+ def stop(self):
+ self.stopCalled.emit()
+
+ @Slot()
+ def doubleFrameRate(self):
+ if self._frameRate * 2. <= self._maxFrameRate:
+ self._frameRate *= 2.
+ self.tickRateChanged.emit(self._frameRate)
+
+ @Slot()
+ def halveFrameRate(self):
+ if self._frameRate / 2. >= self._minFrameRate:
+ self._frameRate /= 2.
+ self.tickRateChanged.emit(self._frameRate)
+
+ @Slot()
+ def resetFrameRate(self):
+ if self._frameRate != 1.:
+ self._frameRate = 1.
+ self.tickRateChanged.emit(self._frameRate)
+
+ def quit(self):
+ self.quitCalled.emit()
+
+ def setCurrentTime(self, currentTime: Timecode):
+ if self.player.state() != QMediaPlayer.PlayingState:
+ super().setCurrentTime(currentTime)
+
+ def disconnectSignals(self, bento: QObject):
+ self.player.stop()
+ self.scene.setIsTimeSource(False)
+ self.player.positionChanged.disconnect(self.doTick)
+ bento.timeChanged.disconnect(self.scene.updateFrame)
+ self.quitCalled.disconnect(self.player.stop)
+ self.tickRateChanged.disconnect(self.player.setPlaybackRate)
+ self.stopCalled.disconnect(self.scene.stop)
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
index ef0ed5e..efdc94a 100644
--- a/src/utils/__init__.py
+++ b/src/utils/__init__.py
@@ -3,6 +3,8 @@
from os.path import abspath, sep
from qtpy.QtCore import QMarginsF, QRectF
from qtpy.QtGui import QColor
+from bisect import bisect_left
+from math import floor, log10, pow
def fix_path(path):
return abspath(path.replace('\\', sep).replace('/', sep))
@@ -11,6 +13,35 @@ def fix_path(path):
def padded_rectf(rectf: QRectF):
return rectf + QMarginsF(SCENE_PADDING, 0., SCENE_PADDING, 0.)
+def take_nearest(myList: list, value: float) -> float:
+ """
+ Return nearest entry in myList to value
+ """
+ pos = bisect_left(myList, value)
+ if pos == 0:
+ return myList[0]
+ if pos == len(myList):
+ return myList[-1]
+ before = myList[pos - 1]
+ after = myList[pos]
+ if after - value < value - before:
+ return after
+ else:
+ return before
+
+def quantizeTicksScale(ticksScale: float) -> float:
+ """
+ Quantize the ticks scale so that the leading digit is 1, 2 or 5
+ and all the other digits are 0, e.g. 2, 500, 10, 2000.
+ The minimum tick scale is 1.0 (that is, a tick every second)
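+ For example, 7.3 quantizes to 5.0 and 347.0 quantizes to 200.0.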
+ """
+ fact = pow(10., floor(log10(ticksScale)))
+ sigDig = take_nearest([1., 2., 5.], ticksScale / fact)
+ return sigDig * fact
+
+def round_to_3(x):
+ """
+ Round x to 3 significant figures, e.g. 0.012345 -> 0.0123, 12345 -> 12300.
+ """
+ return round(x, -(int(floor(log10(abs(x))))-2))
+
cm_data_parula = [
[ 0.26710521, 0.03311059, 0.6188155 ],
[ 0.26493929, 0.04780926, 0.62261795],
@@ -786,11 +817,23 @@ def padded_rectf(rectf: QRectF):
[ 0.47960, 0.01583, 0.01055]]
def get_colormap(colormap_name: str) -> list:
+ """
+ Return a color map (list of QColor) given the color map name.
+
+ Args:
+ colormap_name (str): The name of a supported colormap.
+
+ Returns:
+ A list of QColor corresponding to the color map.
+
+ Raises:
+ Exception: If the color map name is not among those supported.
+ """
if colormap_name.lower() == "parula":
cm_data = cm_data_parula
elif colormap_name.lower() == "turbo":
cm_data = cm_data_turbo
- elif colormap_name.lower == "viridis":
+ elif colormap_name.lower() == "viridis":
cm_data = cm_data_viridis
else:
raise Exception(f"get_colormap: colormap name {colormap_name} not supported")
diff --git a/src/video/mp4Io.py b/src/video/mp4Io.py
index 60cab12..c6c64cd 100644
--- a/src/video/mp4Io.py
+++ b/src/video/mp4Io.py
@@ -1,6 +1,7 @@
import cv2
import numpy as np
import os
+from qtpy.QtGui import QImage, QPixmap
class mp4Io_reader():
def __init__(self, filename, info=[]):
@@ -28,7 +29,7 @@ def seek(self, index):
self.file.set(cv2.CAP_PROP_POS_FRAMES, index)
def getTs(self,n=None):
- if n==None:
+ if n is None:
n = self.header['numFrames']
ts = np.zeros(n+1)
@@ -48,6 +49,13 @@ def getFrame(self, index, decode=True):
ts = self.file.get(cv2.CAP_PROP_POS_MSEC)/1000.
return frame, ts
+ def getFrameAsQPixmap(self, index, decode=True):
+ image, _ = self.getFrame(index, decode)
+ h, w, ch = image.shape
+ bytes_per_line = ch * w
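+ # OpenCV frames are in BGR channel order, hence QImage.Format_BGR888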
+ convert_to_Qt_format = QImage(image.data, w, h, bytes_per_line, QImage.Format_BGR888)
+ return QPixmap.fromImage(convert_to_Qt_format)
+
def close(self):
self.file.release()
diff --git a/src/video/seqIo.py b/src/video/seqIo.py
index ff2c576..3f3ef4d 100755
--- a/src/video/seqIo.py
+++ b/src/video/seqIo.py
@@ -14,6 +14,7 @@
import pdb
import cv2
import progressbar as pb
+from qtpy.QtGui import QPixmap
# Create interface sr for reading seq files.
# sr = seqIo_reader( fName )
@@ -46,7 +47,7 @@ def fread(fid, nelements, dtype):
"""Equivalent to Matlab fread function"""
- if dtype is np.str:
+ if dtype is np.str_:
dt = np.uint8 # WARNING: assuming 8-bit ASCII for np.str!
else:
dt = dtype
@@ -56,9 +57,9 @@ def fread(fid, nelements, dtype):
data_array = data_array[0]
return data_array
-def fwrite(fid,a,dtype=np.str):
+def fwrite(fid,a,dtype=np.str_):
+ # assuming 8-bit ASCII for string
- if dtype is np.str:
+ if dtype is np.str_:
dt = np.uint8 # WARNING: assuming 8-bit ASCII for np.str!
else:
dt = dtype
@@ -654,6 +655,15 @@ def getFrame(self,index,decode=True):
ts = fread(self.file,1,np.uint32)+fread(self.file,1,np.uint16)/1000.
return np.array(I), ts
+ def getFrameAsQPixmap(self, index, decode=True):
+ """
+ Get the frame in the form of a Qt QPixmap
+ """
+ image, _ = self.getFrame(index, decode)
+ pixmap = QPixmap()
+ pixmap.loadFromData(image.tobytes())
+ return pixmap
+
# Close the file
def close(self):
self.file.close()
diff --git a/src/video/videoScene.py b/src/video/videoScene.py
new file mode 100644
index 0000000..460a0ce
--- /dev/null
+++ b/src/video/videoScene.py
@@ -0,0 +1,293 @@
+# videoScene.py
+"""
+Implement base class and derived classes for showing videos with pose and annotation overlays
+"""
+from pose.pose import PoseBase
+import video.seqIo as seqIo
+import video.mp4Io as mp4Io
+from qtpy.QtCore import QMargins, QObject, QRectF, Qt, QUrl, Slot
+from qtpy.QtGui import QBrush, QFontMetrics, QPainter, QPixmap
+from qtpy.QtWidgets import QGraphicsScene, QGraphicsItem
+from qtpy.QtMultimedia import QMediaContent, QMediaPlayer, QVideoSurfaceFormat
+from qtpy.QtMultimediaWidgets import QGraphicsVideoItem
+from dataExporter import DataExporter
+from pynwb import NWBFile
+from timecode import Timecode
+import os
+from typing import List
+
+class VideoSceneAbstractBase(QGraphicsScene, DataExporter):
+ """
+ An abstract scene that knows how to draw annotations text and pose data
+ into its foreground
+ """
+
+ @staticmethod
+ def supportedFormats() -> List:
+ raise NotImplementedError("Derived class needs to override this method")
+
+ def __init__(self, bento: QObject, start_time: Timecode, parent: QObject=None):
+ QGraphicsScene.__init__(self, parent)
+ DataExporter.__init__(self)
+ self.bento = bento
+ self.annots = None
+ self.pose_class = None
+ self.showPoseData = False
+ self.frame_ix = 0
+ self._start_time = start_time
+ self._frameWidth = 0.
+ self._frameHeight = 0.
+ self._aspectRatio = 0.
+
+ def setAnnots(self, annots: list):
+ self.annots = annots
+
+ def setPoseClass(self, pose_class: PoseBase):
+ self.pose_class = pose_class
+
+ def setShowPoseData(self, showPoseData: bool):
+ self.showPoseData = showPoseData
+
+ def drawPoses(self, painter: QPainter, frame_ix: int):
+ if self.showPoseData and self.pose_class:
+ self.pose_class.drawPoses(painter, frame_ix)
+
+ def exportToNWBFile(self, nwbFile: NWBFile):
+ if self.pose_class:
+ nwbFile = self.pose_class.exportPosesToNWBFile(nwbFile)
+ return nwbFile
+
+ def setStartTime(self, t: Timecode):
+ self._start_time = t
+
+ def drawForeground(self, painter: QPainter, rect: QRectF):
+ # add poses
+ self.drawPoses(painter, self.frame_ix)
+ # add annotations
+ font = painter.font()
+ pointSize = font.pointSize()+10
+ font.setPointSize(pointSize)
+ painter.setFont(font)
+ margins = QMargins(10, 10, 0, 0)
+ fm = QFontMetrics(font)
+ flags = Qt.AlignLeft | Qt.AlignTop
+ rectWithMargins = rect.toRect()
+ rectWithMargins -= margins
+ whiteBrush = QBrush(Qt.white)
+ blackBrush = QBrush(Qt.black)
+ if self.annots:
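+ # draw one line of text per active annotation, stacking the lines
+ # downward by growing the top margin after each one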
+ for annot in self.annots:
+ painter.setBrush(whiteBrush if annot[2].lightnessF() < 0.5 else blackBrush)
+ text = annot[0] + ": " + annot[1]
+ bounds = fm.boundingRect(rectWithMargins, flags, text)
+ painter.setPen(Qt.NoPen)
+ painter.drawRect(bounds)
+ painter.setPen(annot[2])
+ painter.drawText(bounds, text)
+ margins.setTop(margins.top() + pointSize + 3)
+ rectWithMargins = rect.toRect()
+ rectWithMargins -= margins
+
+ def frameWidth(self) -> float:
+ return self._frameWidth
+
+ def frameHeight(self) -> float:
+ return self._frameHeight
+
+ def aspectRatio(self) -> float:
+ return self._aspectRatio
+
+ def videoItem(self) -> QGraphicsItem:
+ raise NotImplementedError("Derived class needs to override this method")
+
+ def running_time(self) -> float:
+ raise NotImplementedError("Derived class needs to override this method")
+
+ def sample_rate(self) -> float:
+ raise NotImplementedError("Derived class needs to override this method")
+
+ @Slot(Timecode)
+ def updateFrame(self, t: Timecode):
+ raise NotImplementedError("Derived class needs to override this method")
+
+ def getPlayer(self) -> QMediaPlayer:
+ return None
+
+ def reset(self):
+ pass
+
+class VideoSceneNative(VideoSceneAbstractBase):
+ """
+ A scene that knows how to play video in various standard formats
+ Avoiding endless update loops is tricky. The player's position is only set
+ explicitly when the player is *not* playing.
+ """
+
+ @staticmethod
+ def supportedFormats() -> List:
+ return ['mp4', 'avi']
+
+ # update current time every 1/10 second
+ time_update_msec: int = round(1000 / 10)
+
+ def __init__(self, bento: QObject, start_time: Timecode, parent: QObject=None):
+ super().__init__(bento, start_time, parent)
+ self.player = QMediaPlayer()
+ self.playerItem = QGraphicsVideoItem()
+ self.player.setVideoOutput(self.playerItem)
+ self.addItem(self.playerItem)
+ self.player.durationChanged.connect(bento.noteVideoDurationChanged)
+ self.player.setNotifyInterval(self.time_update_msec)
+ self.frameRate = 30.0
+ self.duration = 0.0
+ self._isTimeSource = False
+ self._running = False
+
+ def drawPoses(self, painter: QPainter, frame_ix: int):
+ super().drawPoses(painter, frame_ix)
+
+ def drawForeground(self, painter: QPainter, rect: QRectF):
+ self.frame_ix = round(self.player.position() * self.frameRate / 1000.) # position() is in msec
+ if self.playerItem:
+ videoBounds = self.playerItem.boundingRect()
+ videoNativeSize = self.playerItem.nativeSize()
+ if not videoNativeSize.isEmpty() and not videoBounds.isEmpty():
+ # prevent divide by zero or pointless work
+ painter.save()
+ sx = videoBounds.width() / videoNativeSize.width()
+ sy = videoBounds.height() / videoNativeSize.height()
+ painter.scale(sx, sy)
+ super().drawForeground(painter, rect)
+ painter.restore()
+
+ @Slot(QVideoSurfaceFormat)
+ def noteSurfaceFormatChanged(self, surfaceFormat: QVideoSurfaceFormat):
+ frameRate = surfaceFormat.frameRate()
+ self._frameWidth = surfaceFormat.frameWidth()
+ self._frameHeight = surfaceFormat.frameHeight()
+ self._aspectRatio = surfaceFormat.pixelAspectRatio()
+ if frameRate > 0.:
+ print(f"Setting frameRate to {frameRate}")
+ self.frameRate = frameRate
+
+ def setVideoPath(self, videoPath: str):
+ if self.duration == 0.0:
+ reader = mp4Io.mp4Io_reader(videoPath)
+ self.duration = float(reader.header['numFrames']) / float(reader.header['fps'])
+ self.player.setMedia(QUrl.fromLocalFile(videoPath))
+ self.playerItem.videoSurface().surfaceFormatChanged.connect(self.noteSurfaceFormatChanged)
+ # force the player to load the media
+ self.player.play()
+ self.player.pause()
+ # reset to beginning
+ self.player.setPosition(0)
+ # do some other setup
+ frameRate = self.playerItem.videoSurface().surfaceFormat().frameRate()
+ if frameRate > 0:
+ self.frameRate = frameRate
+ self.bento.timeChanged.connect(self.updateFrame)
+
+ @Slot(Timecode)
+ def updateFrame(self, t: Timecode):
+ """
+ Only act on external time updates if we're not currently playing.
+ This is to avoid an endless time update loop when this player is
+ acting as the timeSource for bento.
+ Note that the check for self._running is needed to avoid a problem when
+ playing has just stopped and we set bento's current time (below).
+ """
+ if self.player.state() != QMediaPlayer.PlayingState and not self._running:
+ self.player.setPosition(round(t.float * 1000.))
+
+ @Slot()
+ def play(self):
+ self._running = True
+ self.player.play()
+
+ @Slot()
+ def stop(self):
+ self.player.pause()
+ if self._isTimeSource:
+ self.bento.set_time(Timecode(30.0, start_seconds=self.player.position()/1000.))
+ self._running = False
+
+ @Slot(float)
+ def setPlaybackRate(self, rate: float):
+ self.player.setPlaybackRate(rate)
+
+ def videoItem(self) -> QGraphicsItem:
+ return self.playerItem
+
+ def running_time(self) -> float:
+ return float(self.duration)
+
+ def sample_rate(self) -> float:
+ if self.frameRate == 0.:
+ return 30.0
+ else:
+ return self.frameRate
+
+ def setIsTimeSource(self, isTimeSource: bool):
+ self._isTimeSource = isTimeSource
+
+ def getPlayer(self) -> QMediaPlayer:
+ return self.player
+
+ def reset(self):
+ if self.player:
+ self.player.setMedia(QMediaContent(None))
+
+class VideoScenePixmap(VideoSceneAbstractBase):
+ """
+ A scene that knows how to play videos in Caltech Anderson Lab .seq format
+ """
+
+ @staticmethod
+ def supportedFormats() -> List:
+ return ['seq', 'mp4', 'avi']
+
+ def __init__(self, bento: QObject, start_time: Timecode, parent: QObject=None):
+ super().__init__(bento, start_time, parent)
+ self.reader = None
+ self.frame_ix: int = 0
+ self.pixmap = QPixmap()
+ self.pixmapItem = self.addPixmap(self.pixmap)
+ # self.region = None
+
+ def setVideoPath(self, videoPath: str):
+ _, ext = os.path.splitext(videoPath)
+ ext = ext.lower()
+ print(videoPath)
+ if ext == '.seq':
+ self.reader = seqIo.seqIo_reader(videoPath)
+ elif ext in ['.avi', '.mp4']:
+ self.reader = mp4Io.mp4Io_reader(videoPath)
+ else:
+ raise ValueError("Expected .seq file")
+ self._frameWidth = self.reader.header['width']
+ self._frameHeight = self.reader.header['height']
+ self._aspectRatio = self._frameHeight / self._frameWidth
+ self.bento.timeChanged.connect(self.updateFrame)
+
+ @Slot(Timecode)
+ def updateFrame(self, t: Timecode):
+ if not self.reader or t < self._start_time:
+ return
+ myTc = Timecode(self.reader.header['fps'], start_seconds = t.float)
+ self.frame_ix = min(myTc.frames, self.reader.header['numFrames']-1)
+ self.pixmap = self.reader.getFrameAsQPixmap(self.frame_ix, decode=False)
+ self.pixmapItem.setPixmap(self.pixmap)
+
+ def videoItem(self) -> QGraphicsItem:
+ return self.pixmapItem
+
+ def running_time(self) -> float:
+ if not self.reader:
+ return 0.
+ return float(self.reader.header['numFrames']) / float(self.reader.header['fps'])
+
+ def sample_rate(self) -> float:
+ if not self.reader:
+ return 30.0
+ else:
+ return self.reader.header['fps']
diff --git a/src/video/videoWindow.py b/src/video/videoWindow.py
index 70e7d72..9a5764e 100644
--- a/src/video/videoWindow.py
+++ b/src/video/videoWindow.py
@@ -1,78 +1,24 @@
# videoWindow.py
from video.videoWindow_ui import Ui_videoFrame
-import video.seqIo as seqIo
-import video.mp4Io as mp4Io
-from qtpy.QtCore import Signal, Slot, QMargins, QPointF, Qt
-from qtpy.QtGui import QBrush, QFontMetrics, QPen, QPixmap, QImage, QPolygonF
-from qtpy.QtWidgets import QFrame, QGraphicsScene
-from timecode import Timecode
-import numpy as np
+from video.videoScene import VideoSceneAbstractBase, VideoSceneNative, VideoScenePixmap
+from qtpy.QtCore import QEvent, Qt, Signal, Slot
+from qtpy.QtWidgets import QFrame
+from qtpy.QtMultimedia import QMediaPlayer
+from dataExporter import DataExporter
+from pynwb import NWBFile
import os
-import time
-
-class VideoScene(QGraphicsScene):
- """
- A scene that knows how to draw annotations text into its foreground
- """
-
- def __init__(self, parent=None):
- super().__init__(parent)
- self.annots = None
- self.pose_class = None
- self.pose_frame_ix = 0
- self.showPoseData = False
-
- def setAnnots(self, annots):
- self.annots = annots
-
- def setPoseClass(self, pose_class):
- self.pose_class = pose_class
-
- def setShowPoseData(self, showPoseData):
- self.showPoseData = showPoseData
-
- def setPoseFrameIx(self, ix):
- self.pose_frame_ix = ix
-
- def drawPoses(self, painter):
- if self.showPoseData and self.pose_class:
- self.pose_class.drawPoses(painter, self.pose_frame_ix)
-
- def drawForeground(self, painter, rect):
- # add poses
- self.drawPoses(painter)
- # add annotations
- font = painter.font()
- pointSize = font.pointSize()+10
- font.setPointSize(pointSize)
- painter.setFont(font)
- margins = QMargins(10, 10, 0, 0)
- fm = QFontMetrics(font)
- flags = Qt.AlignLeft | Qt.AlignTop
- rectWithMargins = rect.toRect()
- rectWithMargins -= margins
- whiteBrush = QBrush(Qt.white)
- blackBrush = QBrush(Qt.black)
- for annot in self.annots:
- painter.setBrush(whiteBrush if annot[2].lightnessF() < 0.5 else blackBrush)
- text = annot[0] + ": " + annot[1]
- bounds = fm.boundingRect(rectWithMargins, flags, text)
- painter.setPen(Qt.NoPen)
- painter.drawRect(bounds)
- painter.setPen(annot[2])
- painter.drawText(bounds, text)
- margins.setTop(margins.top() + pointSize + 3)
- rectWithMargins = rect.toRect()
- rectWithMargins -= margins
-
-class VideoFrame(QFrame):
+from timecode import Timecode
+
+class VideoFrame(QFrame, DataExporter):
openReader = Signal(str)
quitting = Signal()
def __init__(self, bento):
- super().__init__()
+ QFrame.__init__(self)
+ DataExporter.__init__(self)
+ self.dataExportType = "video"
self.bento = bento
self.sizePolicy().setHeightForWidth(True)
self.ui = Ui_videoFrame()
@@ -81,17 +27,16 @@ def __init__(self, bento):
bento.quitting.connect(self.close)
# data related to video
- self.reader = None
- self.scene = VideoScene()
- self.ui.videoView.setScene(self.scene)
- self.ui.showPoseCheckBox.stateChanged.connect(self.showPoseDataChanged)
- self.pixmap = QPixmap()
- self.pixmapItem = self.scene.addPixmap(self.pixmap)
- self.active_annots = []
+ self.scene = None
self.aspect_ratio = 1.
+ self.active_annots = []
+
+ def resizeFrame(self):
+        self.ui.videoView.fitInView(self.scene.videoItem(), aspectRatioMode=Qt.KeepAspectRatio)
+
def resizeEvent(self, event):
- self.ui.videoView.fitInView(self.pixmapItem, aspectRadioMode=Qt.KeepAspectRatio)
+ self.resizeFrame()
def mouseReleaseEvent(self, event):
viewport = self.ui.videoView.viewport()
@@ -106,38 +51,43 @@ def mouseReleaseEvent(self, event):
# too tall
print("too tall")
- def load_video(self, fn):
- self.ext = os.path.basename(fn).rsplit('.',1)[-1]
- if self.ext=='mp4' or self.ext=='avi':
- self.reader = mp4Io.mp4Io_reader(fn)
- elif self.ext=='seq':
- self.reader = seqIo.seqIo_reader(fn)
+ def supported_by_native_player(self, fn: str) -> bool:
+ ext = os.path.basename(fn).rsplit('.',1)[-1].lower()
+ if ext in VideoSceneNative(self.bento, self.bento.current_time()).supportedFormats():
+ return True
+        if ext not in VideoScenePixmap(self.bento, self.bento.current_time()).supportedFormats():
+ raise Exception(f"video format {ext} not supported.")
+ return False
+
+ def load_video(self, fn: str, start_time: Timecode, forcePixmapMode: bool):
+ if not forcePixmapMode and self.supported_by_native_player(fn):
+ self.scene = VideoSceneNative(self.bento, start_time)
else:
- raise Exception(f"video format {self.ext} not supported.")
- frame_width = self.reader.header['width']
- frame_height = self.reader.header['height']
- self.aspect_ratio = float(frame_height) / float(frame_width)
+ self.scene = VideoScenePixmap(self.bento, start_time)
+ self.scene.setVideoPath(fn)
+ self.ui.videoView.setScene(self.scene)
+ self.ui.showPoseCheckBox.stateChanged.connect(self.showPoseDataChanged)
+ self.ui.videoView.show()
+ self.aspect_ratio = self.scene.aspectRatio()
print(f"aspect_ratio set to {self.aspect_ratio}")
- self.updateFrame(self.bento.current_time)
- self.ui.videoView.fitInView(self.pixmapItem, aspectRadioMode=Qt.KeepAspectRatio)
+ self.resizeFrame()
+ self.scene.updateFrame(self.bento.current_time())
def set_pose_class(self, pose_class):
self.scene.setPoseClass(pose_class)
self.ui.showPoseCheckBox.setEnabled(bool(pose_class))
self.scene.setShowPoseData(bool(pose_class) and self.ui.showPoseCheckBox.isChecked())
- def sample_rate(self):
- if not self.reader:
- return 30.0
- else:
- return self.reader.header['fps']
+ def set_start_time(self, t):
+ self.scene.setStartTime(t)
+
+ def sample_rate(self) -> float:
+ return self.scene.sample_rate()
- def running_time(self):
- if not self.reader:
- return 0.
- return float(self.reader.header['numFrames']) / float(self.reader.header['fps'])
+ def running_time(self) -> float:
+ return self.scene.running_time()
- def keyPressEvent(self, event):
+    def keyPressEvent(self, event: QKeyEvent):
if event.key() == Qt.Key_Left:
if event.modifiers() & Qt.ShiftModifier:
self.bento.skipBackward()
@@ -158,39 +108,29 @@ def keyPressEvent(self, event):
return
event.accept()
- @Slot(Timecode)
- def updateFrame(self, t):
- if not self.reader:
- return
- myTc = Timecode(self.reader.header['fps'], start_seconds=t.float)
- i = min(myTc.frames, self.reader.header['numFrames']-1)
- image, _ = self.reader.getFrame(i, decode=False)
- if self.ext=='seq':
- self.pixmap.loadFromData(image.tobytes())
- self.pixmapItem.setPixmap(self.pixmap)
- elif self.ext=='mp4' or self.ext=='avi':
- h, w, ch = image.shape
- bytes_per_line = ch * w
- convert_to_Qt_format = QImage(image.data, w, h, bytes_per_line, QImage.Format_BGR888)
- convert_to_Qt_format = QPixmap.fromImage(convert_to_Qt_format)
- self.pixmapItem.setPixmap(convert_to_Qt_format)
- else:
- raise Exception(f"video format {self.ext} not supported")
-
- # get the frame number for this frame and set it into the scene,
- # whether we have and are showing pose data or not
- self.scene.setPoseFrameIx(myTc.frames)
-
- if isinstance(self.scene, VideoScene):
- self.scene.setAnnots(self.active_annots)
- self.show()
+ def reset(self):
+ if isinstance(self.scene, VideoSceneAbstractBase):
+ self.scene.reset()
@Slot(list)
- def updateAnnots(self, annots):
+ def updateAnnots(self, annots: list):
self.active_annots = annots
+ if isinstance(self.scene, VideoSceneAbstractBase):
+ self.scene.setAnnots(self.active_annots)
@Slot(Qt.CheckState)
- def showPoseDataChanged(self, showPoseData):
+ def showPoseDataChanged(self, showPoseData: Qt.CheckState):
if self.scene:
self.scene.setShowPoseData(bool(showPoseData))
- self.updateFrame(self.bento.current_time) # force redraw
+ self.scene.updateFrame(self.bento.current_time()) # force redraw
+
+ def getPlayer(self) -> QMediaPlayer:
+ if not self.scene:
+ return None
+ return self.scene.getPlayer()
+
+ def exportToNWBFile(self, nwbFile: NWBFile):
+ print(f"Export data from {self.dataExportType} to NWB file")
+ if self.scene:
+ nwbFile = self.scene.exportToNWBFile(nwbFile)
+ return nwbFile
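
# A hedged sketch of how the DataExporter chaining above composes: each
# exporter takes the NWBFile, adds its own data, and returns it, so a trial
# can fold all of its exporters into one file. The `exporters` list and the
# NWBFile arguments are illustrative, not Bento's actual export driver.
from datetime import datetime, timezone
from pynwb import NWBFile

nwbFile = NWBFile(session_description="Bento trial export",
                  identifier="trial-0",
                  session_start_time=datetime.now(timezone.utc))
exporters = []  # hypothetical, e.g. [videoFrame, neuralScene, annotations]
for exporter in exporters:
    nwbFile = exporter.exportToNWBFile(nwbFile)
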
diff --git a/src/widgets/annotationsWidget.py b/src/widgets/annotationsWidget.py
index 94a122d..c773abf 100644
--- a/src/widgets/annotationsWidget.py
+++ b/src/widgets/annotationsWidget.py
@@ -2,11 +2,12 @@
"""
"""
-from qtpy.QtCore import Qt, QPointF, QRectF, Slot
-from qtpy.QtGui import (QBrush, QPen, QKeyEvent, QMouseEvent,
+from qtpy.QtCore import Qt, QMargins, QPointF, QRectF, Slot
+from qtpy.QtGui import (QBrush, QFontMetrics, QPen, QKeyEvent, QMouseEvent,
QTransform, QWheelEvent)
from qtpy.QtWidgets import QGraphicsScene, QGraphicsView
from timecode import Timecode
+from utils import quantizeTicksScale, round_to_3
class AnnotationsView(QGraphicsView):
"""
@@ -18,7 +19,7 @@ class AnnotationsView(QGraphicsView):
from QtGraphicsScene
"""
- def __init__(self, parent=None):
+    def __init__(self, parent=None, showTickLabels=True):
super().__init__(parent)
self.bento = None
self.start_x = 0.
@@ -33,10 +34,14 @@ def __init__(self, parent=None):
self.ticksScale = 1.
self.setInteractive(False)
self.setViewportUpdateMode(QGraphicsView.FullViewportUpdate)
+ self.showTickLabels = showTickLabels
def set_bento(self, bento):
self.bento = bento
+ def set_showTickLabels(self, showTickLabels):
+ self.showTickLabels = showTickLabels
+
#def set_v_factor(self, v_factor):
# self.v_factor = self.height
@@ -69,6 +74,8 @@ def setScale(self, hScale: float, vScale: float) -> None:
def setHScale(self, hScale):
self.setTransformScale(self.transform(), scale_h=hScale)
+ initialTicksScale = 100./hScale
+ self.ticksScale = quantizeTicksScale(initialTicksScale)
def setVScale(self, vScale):
self.setTransformScale(self.transform(), scale_v=vScale)
@@ -77,7 +84,7 @@ def setVScale(self, vScale):
def setHScaleAndShow(self, hScale):
self.scale_h = hScale
self.setHScale(hScale)
- self.show()
+ self.updatePosition(self.bento.current_time()) # calls show()
@Slot(float)
def setVScaleAndShow(self, v_factor):
@@ -145,40 +152,83 @@ def maybeDrawPendingBout(self, painter, rect):
painter.setBrush(Qt.NoBrush)
def drawForeground(self, painter, rect):
+ """
+ drawForeground
+ Draws a pending annotation bout, if one is in the process of being added, deleted or changed,
+ and timing tick marks and labels, if enabled. Overrides the default drawForeground, which
+ does nothing.
+
+ In order to draw the tick labels, we need to change the transform to identity so that
+        the labels (and tick marks themselves) don't scale when the user changes
+ the view's horizontal scale. To figure out where to draw the ticks and labels,
+ we need to get their positions in pixels before we change the transform, do the drawing
+ in pixel coordinates, and finally restore the transform.
+ """
self.maybeDrawPendingBout(painter, rect)
- # draw current time indicator
+
+ # gather the position data we need in pixel coordinates
+        # Note that mapFromScene takes both x and y coordinates (expressed in one of
+        # several ways) and returns a QPoint in viewport (pixel) coordinates
now = self.bento.get_time().float
pen = QPen(Qt.black)
pen.setWidth(0)
painter.setPen(pen)
+ nowTop = QPointF(now, rect.top())
+ nowBottom = QPointF(now, rect.bottom())
+ device_nowTop = self.mapFromScene(nowTop)
+ device_nowBottom = self.mapFromScene(nowBottom)
+ device_now = device_nowTop.x()
+ device_left, device_top = self.mapFromScene(rect.topLeft()).toTuple()
+ device_right, device_bottom = self.mapFromScene(rect.bottomRight()).toTuple()
+ device_ticksScale = self.mapFromScene(QPointF(now + self.ticksScale, rect.top())).x() - device_now
+
+ # transform to identity to facilitate drawing tick labels that don't scale
+ savedTransform = painter.transform()
+ painter.setTransform(QTransform())
painter.drawLine(
- QPointF(now, rect.top()),
- QPointF(now, rect.bottom())
+ device_nowTop,
+ device_nowBottom
)
+ # draw the time label
+ if self.showTickLabels:
+ font = painter.font()
+ fm = QFontMetrics(font)
+ text = '0.0'
+ tickLabelY = device_top + fm.ascent() + 2
+ painter.drawText(device_now + 4, tickLabelY, text)
# draw tick marks
if self.scene().loaded:
+ eighth = (device_bottom - device_top) / 8.
+ eighthDown = device_top + eighth
+ eighthUp = device_bottom - eighth
+ device_offset = device_ticksScale
offset = self.ticksScale
- eighth = (rect.bottom() - rect.top()) / 8.
- eighthDown = rect.top() + eighth
- eighthUp = rect.bottom() - eighth
- while now + offset < rect.right():
+ while device_now + device_offset < device_right:
painter.drawLine(
- QPointF(now + offset, rect.top()),
- QPointF(now + offset, eighthDown)
+ QPointF(device_now + device_offset, device_top),
+ QPointF(device_now + device_offset, eighthDown)
)
painter.drawLine(
- QPointF(now + offset, eighthUp),
- QPointF(now + offset, rect.bottom())
+ QPointF(device_now + device_offset, eighthUp),
+ QPointF(device_now + device_offset, device_bottom)
)
painter.drawLine(
- QPointF(now - offset, rect.top()),
- QPointF(now - offset, eighthDown)
+ QPointF(device_now - device_offset, device_top),
+ QPointF(device_now - device_offset, eighthDown)
)
painter.drawLine(
- QPointF(now - offset, eighthUp),
- QPointF(now - offset, rect.bottom())
+ QPointF(device_now - device_offset, eighthUp),
+ QPointF(device_now - device_offset, device_bottom)
)
- offset += self.ticksScale
+ if self.showTickLabels:
+ text = str(round_to_3(offset))
+ painter.drawText(device_now + device_offset + 4, tickLabelY, text)
+ text = "-" + text
+ painter.drawText(device_now - device_offset + 4, tickLabelY, text)
+ offset += self.ticksScale
+ device_offset += device_ticksScale
+ # restore original transform
+ painter.setTransform(savedTransform)
class AnnotationsScene(QGraphicsScene):
"""
diff --git a/src/widgets/deleteableViews.py b/src/widgets/deleteableViews.py
index 4c19054..1d6d430 100644
--- a/src/widgets/deleteableViews.py
+++ b/src/widgets/deleteableViews.py
@@ -2,7 +2,38 @@
from qtpy.QtCore import Qt
from qtpy.QtGui import QKeyEvent
-from qtpy.QtWidgets import QTableView, QMessageBox, QTreeWidget, QTreeWidgetItem
+from qtpy.QtWidgets import (QTableView, QMessageBox, QTreeWidget, QTreeWidgetItem, QComboBox,
+ QStyledItemDelegate, QLineEdit, QDoubleSpinBox)
+
+
+class OffsetTimeItemDelegate(QStyledItemDelegate):
+
+ def __init__(self, parent=None):
+ super().__init__(parent)
+
+    def createEditor(self, parent, option, index):
+        editor = super().createEditor(parent, option, index)
+        if (isinstance(editor, QDoubleSpinBox) and
+            index.model().headerData(index.column(), Qt.Horizontal, Qt.DisplayRole) == 'Offset Time'):
+            # replace the default double spin box with a free-form line edit
+            editor.close()
+            editor = QLineEdit(parent)
+        return editor
+
+class CustomComboBoxDelegate(QStyledItemDelegate):
+ def __init__(self, comboItems):
+ super().__init__()
+ self.comboItems = comboItems
+
+ def createEditor(self, parent, option, index):
+ editor = super().createEditor(parent, option, index)
+ if (isinstance(editor, QLineEdit) and
+ index.model().headerData(index.column(), Qt.Horizontal, Qt.DisplayRole) == 'Format'):
+ editor.close()
+ comboBox = QComboBox(parent)
+ comboBox.addItems(self.comboItems)
+ comboBox.setSizeAdjustPolicy(QComboBox.AdjustToContents)
+ return comboBox
+ return editor
class DeleteableTableView(QTableView):
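
# Usage sketch for the delegates above (the view and the column index are
# hypothetical): installed on a QTableView, they make 'Format' cells edit via
# a combo box and 'Offset Time' cells via a free-form line edit. Qt does not
# take ownership of delegates, so keep references to them alive.
from qtpy.QtWidgets import QApplication, QTableView

app = QApplication([])
view = QTableView()
offsetDelegate = OffsetTimeItemDelegate(view)
formatDelegate = CustomComboBoxDelegate(["seq", "mp4", "avi"])
view.setItemDelegate(offsetDelegate)
view.setItemDelegateForColumn(2, formatDelegate)  # 2: hypothetical 'Format' column
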
diff --git a/src/widgets/neuralWidget.py b/src/widgets/neuralWidget.py
index 037b2c8..25a9d88 100644
--- a/src/widgets/neuralWidget.py
+++ b/src/widgets/neuralWidget.py
@@ -4,16 +4,20 @@
from os import X_OK
from qtpy.QtCore import Qt, QPointF, QRectF, Signal, Slot
-from qtpy.QtWidgets import (QGraphicsItem, QGraphicsItemGroup, QGraphicsPathItem,
+from qtpy.QtWidgets import (QGraphicsItem, QGraphicsItemGroup, QGraphicsPathItem, QGraphicsPixmapItem,
QGraphicsScene, QGraphicsView, QMessageBox)
-from qtpy.QtGui import (QBrush, QColor, QImage, QMouseEvent, QPainterPath, QPen,
- QPixmap, QTransform, QWheelEvent)
+from qtpy.QtGui import (QBrush, QColor, QImage, QKeyEvent, QMouseEvent, QPainterPath, QPen,
+ QPixmap, QTransform, QWheelEvent, QPolygonF)
+from qtpy.QtCharts import QtCharts
import numpy as np
import pymatreader as pmr
from qimage2ndarray import gray2qimage
from timecode import Timecode
-from utils import get_colormap, padded_rectf
+from utils import get_colormap, padded_rectf, quantizeTicksScale
+from pynwb import NWBFile, TimeSeries
import warnings
+import shiboken2 as shiboken
+import ctypes
class QGraphicsSubSceneItem(QGraphicsItem):
"""
@@ -75,7 +79,6 @@ def __init__(self, parent=None):
self.start_transform = None
self.start_x = 0.
self.scale_h = 1.
- self.min_scale_v = 0.2
self.center_y = 0.
self.horizontalScrollBar().sliderReleased.connect(self.updateFromScroll)
self.verticalScrollBar().sliderReleased.connect(self.updateFromScroll)
@@ -90,6 +93,7 @@ def setScene(self, scene):
# give the sceneRect some additional padding
# so that the start and end can be centered in the view
super().setScene(scene)
+ scene.setParent(self)
self.time_x = Timecode(str(scene.sample_rate), '0:0:0:1')
self.center_y = float(scene.num_chans) / 2.
@@ -107,33 +111,46 @@ def setTransformScaleV(self, t, scale_v: float):
)
self.setTransform(t, combine=False)
+ def setTransformScaleH(self, t, scale_h: float):
+ t.setMatrix(
+ scale_h,
+ t.m12(),
+ t.m13(),
+ t.m21(),
+ t.m22(),
+ t.m23(),
+ t.m31(),
+ t.m32(),
+ t.m33()
+ )
+ self.setTransform(t, combine=False)
+
def resizeEvent(self, event):
oldHeight = float(event.oldSize().height())
if oldHeight < 0.:
return
newHeight = float(event.size().height())
- self.min_scale_v *= newHeight / oldHeight
+ min_scale_v = newHeight / self.sceneRect().height()
t = self.transform()
- if t.m22() < self.min_scale_v:
- self.setTransformScaleV(t, self.min_scale_v)
- self.update()
+ if t.m22() < min_scale_v:
+ self.setTransformScaleV(t, min_scale_v)
+ self.updatePosition(self.bento.current_time())
+ self.synchronizeHScale()
@Slot()
def updateScene(self):
self.sample_rate = self.scene().sample_rate
- # self.time_x.framerate(self.sample_rate)
self.center_y = self.scene().height() / 2.
- scale_v = max(self.viewport().height() / self.scene().height(), self.min_scale_v)
+ scale_v = self.viewport().height() / self.scene().height()
self.scale(10., scale_v)
- self.min_scale_v = self.transform().m22()
- self.updatePosition(self.bento.current_time)
+ self.updatePosition(self.bento.current_time())
@Slot(Timecode)
def updatePosition(self, t):
pt = QPointF(t.float, self.center_y)
self.centerOn(pt)
- self.show()
+ self.update()
def synchronizeHScale(self):
self.hScaleChanged.emit(self.transform().m11())
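
# For reference: in a QTransform, m11() is the horizontal scale factor and
# m22() the vertical one (absent rotation), which is why the helpers above
# rewrite individual matrix cells. A tiny self-check:
from qtpy.QtGui import QTransform

t = QTransform()
t.scale(10., 0.5)
assert t.m11() == 10. and t.m22() == 0.5
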
@@ -141,9 +158,10 @@ def synchronizeHScale(self):
def mousePressEvent(self, event):
assert isinstance(event, QMouseEvent)
assert self.bento
- assert not self.transform().isRotating()
- self.start_transform = QTransform(self.transform())
- self.scale_h = self.transform().m11()
+ t = self.transform()
+ assert not t.isRotating()
+ self.start_transform = QTransform(t)
+ self.scale_h = t.m11()
self.start_x = event.localPos().x()
self.start_y = event.localPos().y()
self.time_x = self.bento.get_time()
@@ -153,11 +171,17 @@ def mouseMoveEvent(self, event):
assert isinstance(event, QMouseEvent)
assert self.bento
if event.modifiers() & Qt.ShiftModifier:
- factor_x = event.localPos().x() / self.start_x
+ factor_x = max(0.1, event.localPos().x()) / self.start_x
factor_y = event.localPos().y() / self.start_y
t = QTransform(self.start_transform)
t.scale(factor_x, factor_y)
- self.setTransformScaleV(t, min(64., max(self.min_scale_v, t.m22())))
+ min_scale_v = self.viewport().rect().height() / self.sceneRect().height()
+ self.setTransformScaleV(t, max(min_scale_v, t.m22()))
+ min_scale_h = self.viewport().rect().width() / self.sceneRect().width()
+ h_scale = max(min_scale_h, t.m11())
+ initialTicksScale = 100./h_scale
+ self.ticksScale = quantizeTicksScale(initialTicksScale)
+ self.setTransformScaleH(t, h_scale)
self.synchronizeHScale()
else:
x = event.localPos().x() / self.scale_h
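
# quantizeTicksScale() comes from utils and is not shown in this patch; the
# zoom handler above relies on it to snap the raw 100/scale spacing to a
# readable tick interval. A plausible 1-2-5 quantization, offered only as an
# assumption about its behavior:
import math

def quantize_ticks_scale(raw: float) -> float:
    exponent = math.floor(math.log10(raw))
    base = raw / 10 ** exponent
    nice = 1. if base < 1.5 else 2. if base < 3.5 else 5. if base < 7.5 else 10.
    return nice * 10 ** exponent
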
@@ -169,6 +193,11 @@ def mouseMoveEvent(self, event):
start_seconds=self.time_x.float + (start_x - x)
))
+ def keyPressEvent(self, event: QKeyEvent) -> None:
+ # Override the widget behavior on key strokes
+ # to let the parent window handle the event
+ event.ignore()
+
def updateFromScroll(self):
assert self.bento
center = self.viewport().rect().center()
@@ -179,11 +208,6 @@ def updateFromScroll(self):
start_seconds=sceneCenter.x()
))
- def wheelEvent(self, event):
- assert isinstance(event, QWheelEvent)
- super(NeuralView, self).wheelEvent(event)
- self.updateFromScroll()
-
def drawForeground(self, painter, rect):
now = self.bento.get_time().float
pen = QPen(Qt.white if self.scene().heatmap.isVisible() else Qt.black)
@@ -247,6 +271,8 @@ def __init__(self):
self.heatmap = None
self.annotations = None
self.activeChannel = None
+ self.tracesCache = dict()
+ self.tracesOrder = None
"""
.mat files can be either old-style (MatLab 7.2 and earlier), in which case we need to use
@@ -268,43 +294,43 @@ def loadNeural(self, ca_file, sample_rate, start_frame, stop_frame, time_start,
warnings.simplefilter('ignore', category=UserWarning)
mat = pmr.read_mat(ca_file)
try:
- data = mat['results']['C_raw']
+ self.data = mat['results']['C_raw']
except Exception as e:
QMessageBox.about(self, "Load Error", f"Error loading neural data from file {ca_file}: {e}")
return
- self.range = data.max() - data.min()
+ self.range = self.data.max() - self.data.min()
# Provide for a little space between traces
- self.minimum = data.min() + self.range * 0.05
+ self.minimum = self.data.min() + self.range * 0.05
self.range *= 0.9
self.sample_rate = sample_rate
self.start_frame = start_frame
self.stop_frame = stop_frame
- self.num_chans = data.shape[0]
- self.data_min = data.min()
- self.data_max = data.max()
+ self.num_chans = self.data.shape[0]
+ self.data_min = self.data.min()
+ self.data_max = self.data.max()
self.colorMapper = NeuralColorMapper(self.data_min, self.data_max, "parula")
# for chan in range(self.num_chans):
- for chan in range(self.num_chans):
- self.loadChannel(data, chan)
-
- # Image has a pixel for each frame for each channel
- self.heatmapImage = self.colorMapper.mappedImage(data)
- self.heatmap = self.addPixmap(QPixmap.fromImageInPlace(self.heatmapImage, Qt.NoFormatConversion))
-
- # Scale the heatmap's time axis by the 1 / sample rate so that it corresponds correctly
- # to the time scale
- transform = QTransform()
- transform.scale(1. / self.sample_rate, 1.)
- self.heatmap.setTransform(transform)
- self.heatmap.setOpacity(0.5)
+ self.t_values = ((np.arange(self.data.shape[1]) - self.start_frame)/self.sample_rate) + self.time_start.float
+ self.y_values = np.arange(self.num_chans).reshape(-1,1) + 0.5 + self.normalize(self.data)
+ self.tracesOrder = self.y_values[:, self.start_frame]
+ self.drawTraces(self.t_values[self.start_frame:self.stop_frame],
+ self.y_values[:,self.start_frame:self.stop_frame])
+ self.putTracesIntoAGroup()
+ # create heatmap
+ heatmapData = self.data[:,self.start_frame:self.stop_frame]
+        self.createHeatmap(heatmapData)
# finally, add the traces on top of everything
self.addItem(self.traces)
# pad some time on left and right to allow centering
sceneRect = padded_rectf(self.sceneRect())
- sceneRect.setY(-1.)
sceneRect.setHeight(float(self.num_chans) + 1.)
self.setSceneRect(sceneRect)
+ self.setVisibility(showTraces, showHeatmap, showAnnotations)
+
+ def setVisibility(self, showTraces, showHeatmap, showAnnotations):
if isinstance(self.traces, QGraphicsItem):
self.traces.setVisible(showTraces)
if isinstance(self.heatmap, QGraphicsItem):
@@ -315,26 +341,103 @@ def loadNeural(self, ca_file, sample_rate, start_frame, stop_frame, time_start,
self.heatmap.setVisible(showHeatmap)
if isinstance(self.annotations, QGraphicsItem):
self.annotations.setVisible(showAnnotations)
+
+ def createHeatmap(self, data):
+ self.heatmapImage, self.heatmap = None, None
+ # Image has a pixel for each frame for each channel
+ self.heatmapImage = self.colorMapper.mappedImage(data)
+ self.heatmap = self.addPixmap(QPixmap.fromImageInPlace(self.heatmapImage, Qt.NoFormatConversion))
- def loadChannel(self, data, chan):
+ # Scale the heatmap's time axis by the 1 / sample rate so that it corresponds correctly
+ # to the time scale
+ transform = QTransform()
+ transform.scale(1. / self.sample_rate, 1.)
+ self.heatmap.setTransform(transform)
+ self.heatmap.setOpacity(0.5)
+
+ def drawTraces(self, t_values, y_values):
+ for chan in range(self.num_chans):
+ trace = QPainterPath()
+ trace.reserve(self.stop_frame - self.start_frame)
+ trace.addPolygon(self.createPoly(t_values, y_values[chan, :]))
+ self.tracesCache[str(chan)] = trace
+
+    def createPoly(self, x_values, y_values):
+        if not (x_values.size == y_values.size == x_values.shape[0] == y_values.shape[0]):
+            raise ValueError("Arguments must be 1D NumPy arrays of the same size")
+        size = x_values.size
+        poly = QPolygonF(size)
+        # Fill the polygon's point buffer directly from NumPy instead of looping
+        # in Python: shiboken exposes the C++ address of the QPointF array, which
+        # we wrap as a flat float64 view of interleaved (x, y) pairs.
+        address = shiboken.getCppPointer(poly.data())[0]
+        buffer = (ctypes.c_double * 2 * size).from_address(address)
+        memory = np.frombuffer(buffer, np.float64)
+        memory[0::2] = np.array(x_values, dtype=np.float64, copy=False)
+        memory[1::2] = np.array(y_values, dtype=np.float64, copy=False)
+
+        return poly
+
+ def putTracesIntoAGroup(self):
+ self.traces = QGraphicsItemGroup()
pen = QPen()
- pen.setWidth(0)
- trace = QPainterPath()
- trace.reserve(self.stop_frame - self.start_frame + 1)
- y = float(chan) + self.normalize(data[chan][self.start_frame])
- trace.moveTo(self.time_start.float, y)
- time_start_float = self.time_start.float
-
- for ix in range(self.start_frame + 1, self.stop_frame):
- t = (ix - self.start_frame)/self.sample_rate + time_start_float
- val = self.normalize(data[chan][ix])
- # Add a section to the trace path
- y = float(chan) + val
- trace.lineTo(t, y)
- traceItem = QGraphicsPathItem(trace)
- traceItem.setPen(pen)
- self.traces.addToGroup(traceItem)
+ pen.setWidthF(0)
+ for k in list(self.tracesCache.keys()):
+ traceItem = QGraphicsPathItem(self.tracesCache[k])
+ traceItem.setPen(pen)
+ self.traces.addToGroup(traceItem)
+    def addDarkLines(self, partitionIdx):
+        # insert horizontal separator lines between partitions of the reordered traces
+        pen = QPen()
+        pen.setWidthF(0.75)
+        pen.setColor(QColor('darkBlue'))
+        y_values = self.y_values[:,self.start_frame:self.stop_frame]
+        t_values = self.t_values[self.start_frame:self.stop_frame]
+        self.lines = dict()
+        for idx in partitionIdx[:-1]:
+            traceMin = np.amin(y_values[idx-1, :])
+            midpt = traceMin - 0.5
+            lineY = np.full((y_values.shape[1],), midpt)
+            line = QPainterPath()
+            line.reserve(self.stop_frame - self.start_frame)
+            line.addPolygon(self.createPoly(t_values, lineY))
+            self.lines[str(idx)] = line
+
+        for key in self.lines:
+            lineItem = QGraphicsPathItem(self.lines[key])
+            lineItem.setPen(pen)
+            self.traces.addToGroup(lineItem)
+
+ def reorderTracesAndHeatmap(self, newOrder, partitionIdx):
+ showTraces = self.parent().parent().ui.showTraceRadioButton.isChecked()
+ showHeatMap = self.parent().parent().ui.showHeatMapRadioButton.isChecked()
+ showAnnotations = self.parent().parent().ui.showAnnotationsCheckBox.isChecked()
+ if isinstance(newOrder, np.ndarray):
+ _, counts = np.unique(newOrder, return_counts=True)
+ else:
+ raise ValueError(f"newOrder array {newOrder} should be a Numpy array")
+ if np.any(counts > 1):
+ raise ValueError(f"Unique integer values required in newOrder array {newOrder}")
+ if newOrder.size != self.tracesOrder.size:
+ raise RuntimeError(f"Size of newOrder array and traces order should be the same.")
+        for item in self.items():
+            if isinstance(item, (QGraphicsItemGroup, QGraphicsPixmapItem)):
+                self.removeItem(item)
+ heatmapData = self.data[newOrder, self.start_frame:self.stop_frame]
+ darkLines = np.full((partitionIdx.shape[0], heatmapData.shape[1]), self.data_min - 100.)
+ newHeatmapData = np.insert(heatmapData,
+ partitionIdx,
+ darkLines,
+ axis=0)
+ self.createHeatmap(newHeatmapData)
+ offset = (self.tracesOrder[newOrder] - self.tracesOrder)
+ self.tracesOrder = self.tracesOrder + offset
+ for chan in list(self.tracesCache.keys()):
+ self.tracesCache[chan] = self.tracesCache[chan].translated(0., offset[int(chan)])
+        self.putTracesIntoAGroup()
+ self.addDarkLines(partitionIdx)
+ self.addItem(self.traces)
+ self.setVisibility(showTraces, showHeatMap, showAnnotations)
+
def normalize(self, y_val):
return 1.0 - (y_val - self.minimum) / self.range
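
# How the vectorized trace construction above replaces the old per-frame
# loop: broadcasting offsets each normalized trace by its channel index, and
# createPoly() then converts one row at a time into a QPolygonF without
# iterating in Python. A small stand-alone demo of the same broadcasting:
import numpy as np

data = np.random.rand(3, 5)                    # 3 channels, 5 frames
norm = 1.0 - (data - data.min()) / np.ptp(data)
y = np.arange(3).reshape(-1, 1) + 0.5 + norm   # row i sits near y = i + 0.5
assert y.shape == (3, 5)
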
@@ -368,4 +471,15 @@ def showAnnotations(self, enabled):
if isinstance(self.annotations, QGraphicsItem):
self.annotations.setVisible(enabled)
if isinstance(self.heatmap, QGraphicsItem):
- self.heatmap.setOpacity(0.5 if enabled else 1.)
\ No newline at end of file
+ self.heatmap.setOpacity(0.5 if enabled else 1.)
+
+ def exportToNWBFile(self, nwbFile: NWBFile):
+ neuralData = TimeSeries(name=f"neural_data",
+ data = self.data[:,self.start_frame+1:self.stop_frame],
+ rate=self.sample_rate,
+ starting_time = self.time_start.float,
+ unit = "None",
+ )
+ nwbFile.add_acquisition(neuralData)
+
+ return nwbFile
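
# Round-trip check for the export above (the file name is hypothetical): the
# TimeSeries lands in the NWB file's acquisition group under "neural_data".
from pynwb import NWBHDF5IO

with NWBHDF5IO("trial.nwb", "r") as io:
    nwb = io.read()
    ts = nwb.acquisition["neural_data"]
    print(ts.rate, ts.starting_time, ts.data.shape)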