diff --git a/README.md b/README.md index a63089b..066408e 100644 --- a/README.md +++ b/README.md @@ -13,14 +13,15 @@ The library also includes a command line interface for converting files from a g making capabilities for easy visualisation of the files. Opening the CLI - 1. In order to open it, navigate to path where the library is installed, in case of problems download you should download the project from github and follow the following instructions: + 1. If the install with pip worked perfectly, you can now type `aertb` in a terminal window and the CLI will open. + + 2. If you are installing it from Github: download you should download the project from github and follow the following instructions: - a `git clone ...` - - b Create a virual environment, if not installed run `pip install virtualenv`, + - b Create a virual environment, if venv is not installed run `pip install virtualenv`, then `python3 -m venv aertb_env` - - c On Linux/MacOS run `source aertb_env/bin/activate` + - c Run `source aertb_env/bin/activate` - d Run the following command: 'pip install -r requirements.txt' - - 3. Run `python3 .` or execute the `__main__.py` file + - e Open the cli with `python3 .` or with the `__main__.py` file Using the CLI 1. Once the CLI is open you get a a similar output on your terminal: diff --git a/aertb/core/__init__.py b/aertb/core/__init__.py index ff0cfd0..6681579 100644 --- a/aertb/core/__init__.py +++ b/aertb/core/__init__.py @@ -1,3 +1,4 @@ from .file_loader import FileLoader from .viz import make_gif -from .types import event_dtype \ No newline at end of file +from .types import event_dtype +from .hdf5tools import HDF5FileIterator \ No newline at end of file diff --git a/aertb/core/file_loader.py b/aertb/core/file_loader.py index c652ad8..56ee70e 100644 --- a/aertb/core/file_loader.py +++ b/aertb/core/file_loader.py @@ -11,10 +11,9 @@ # ============================================================================= -from os.path import join, isfile, isdir, splitext +from os.path import join, isfile, splitext from tqdm import tqdm import logging -import click import h5py import os @@ -30,6 +29,11 @@ # ============================================================================= class FileLoader: + ''' + A File loader for a given file extension + + :param extension: the file extension of the file + ''' def __init__(self, extension): diff --git a/aertb/core/hdf5tools.py b/aertb/core/hdf5tools.py new file mode 100644 index 0000000..74aa3b0 --- /dev/null +++ b/aertb/core/hdf5tools.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# ============================================================================= + +__author__ = "Rafael Mosca" +__email__ = "rafael.mosca@mail.polimi.it" +__copyright__ = "Copyright 2020 - Rafael Mosca" +__license__ = "MIT" +__version__ = "1.0" + +# ============================================================================= +import numpy as np +import random +import click +import h5py + +from .types import Sample, event_dtype +# ============================================================================= + +class HDF5FileIterator: + + def __init__(self, filename, groups='all', n_samples='all', rand=-1): + """ + Returns an iterator over an HDF5 file + Suggested usage is: + ``` + iterator = HDF5FileIterator(..) + + for elem in iterator: + # do something ... + ``` + + Params + ------ + :param groups: the groups in the HDF5 that will be considered + by default all groups + :param n_samples: the number of samples that will be considered + by default every sample in the group + :param rand: if greater than zero it specifies the seed for the + random selection, if negative it is sequential + + Returns + ------- + nothing + + """ + dataset = h5py.File(filename, 'r') + self.dataset = dataset + + if groups == 'all': groups = list(dataset.keys()) + + samples = [] + for group in groups: + + group_samples = list(dataset[group].keys()) + + if n_samples == 'all': + n_samples = len(group_samples) + + elif len(group_samples) < n_samples: + err_msg = f'There are insufficient samples in groupĀ {group}' + click.secho(err_msg, bg='yellow') + n_samples = group_samples + + random.seed(rand) + indices = random.sample(range(0, len(group_samples)), n_samples) + + for i in indices: + samples.append((group_samples[i], group)) + + if rand > 0: + random.Random(rand).shuffle(samples) + + self.samples = samples + self.index = 0 + + def __iter__(self): + return self + + def __next__(self): + + while self.index < len(self.samples): + sample, group = self.samples[self.index] + data = self.dataset[group][sample] + events_np = np.array(data, dtype=event_dtype) + self.index += 1 + + return Sample(sample, group, events_np) + + else: + self.dataset.close() + raise StopIteration + + def reset(self): + self.index = 0 \ No newline at end of file diff --git a/aertb/core/types.py b/aertb/core/types.py index 0260bfc..0af5c5b 100644 --- a/aertb/core/types.py +++ b/aertb/core/types.py @@ -10,9 +10,11 @@ # ============================================================================= import numpy as np +from collections import namedtuple # ============================================================================= event_dtype = np.dtype([('x', np.uint16), ('y', np.uint16), ('ts', np.float32), ('p', np.int8)]) +Sample = namedtuple('Sample', ['name', 'label', 'events']) # ============================================================================= \ No newline at end of file diff --git a/__main__.py b/cli.py similarity index 100% rename from __main__.py rename to cli.py diff --git a/gitignore.txt b/gitignore.txt index eae99ed..d01742a 100644 --- a/gitignore.txt +++ b/gitignore.txt @@ -130,4 +130,5 @@ dmypy.json # Mac OS .Ds_Store +.DS_store .DS_Store \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index dbe67ba..aa91afc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ -matplotlib==3.2.2 +click_shell==2.0 tqdm==4.46.1 -h5py==2.10.0 gif==1.0.4 -numpy==1.18.5 -click_shell==2.0 click==7.1.2 +numpy==1.18.5 +h5py==2.10.0 +matplotlib==3.2.2 diff --git a/setup.py b/setup.py index f292a16..f23bcef 100644 --- a/setup.py +++ b/setup.py @@ -18,13 +18,20 @@ setup( name='aertb', - version="0.1.3", + version="0.2.0", author="Rafael Mosca", author_email="rafael.mosca@mail.polimi.it", url='https://github.com/rfma23', + # scripts=['bin/aertb_cli.py'], + entry_points={ + 'console_scripts': [ + 'aertb = cli:aertb_shell', + ], + }, packages=["aertb", "aertb.core", "aertb.core.loaders"], keywords = ['aedat', 'aer', 'dat', 'event', 'camera'], classifiers=list(filter(None, metadata.split('\n'))), long_description=long_description, long_description_content_type='text/markdown' ) +