Skip to content

Commit

Permalink
Fix bug with IdentityFilter
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulSimpetru committed Sep 11, 2024
1 parent 00e5704 commit cc5b201
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 17 deletions.
7 changes: 5 additions & 2 deletions doc_octopy/datasets/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,13 @@ def __init__(
ground_truth_data_path: Path,
save_path: Path,
tasks_to_use: Sequence[str] = EXPERIMENTS_TO_USE,
debug: bool = False,
):
self.emg_data_path = emg_data_path
self.ground_truth_data_path = ground_truth_data_path
self.save_path = save_path
self.tasks_to_use = tasks_to_use
self.debug = debug

def create_dataset(self):
EMGDataset(
Expand All @@ -79,13 +81,13 @@ def create_dataset(self):
emg_representations_to_filter_after_chunking=["Last"],
ground_truth_filter_pipeline_before_chunking=[
[
ApplyFunctionFilter(function=np.reshape, newshape=(63, -1)),
ApplyFunctionFilter(function=np.reshape, name="Reshape", newshape=(63, -1)),
IndexDataFilter(indices=(slice(3, 63),)),
]
],
ground_truth_representations_to_filter_before_chunking=["Input"],
ground_truth_filter_after_pipeline_chunking=[
[ApplyFunctionFilter(function=np.mean, axis=-1, is_output=True)]
[ApplyFunctionFilter(function=np.mean, name="Mean", axis=-1, is_output=True)]
],
ground_truth_representations_to_filter_after_pipeline_chunking=["Last"],
augmentation_pipelines=[
Expand All @@ -94,6 +96,7 @@ def create_dataset(self):
[WaveletDecomposition(level=3, is_output=True, nr_of_grids=5)],
],
amount_of_chunks_to_augment_at_once=500,
debug=self.debug,
).create_dataset()


Expand Down
16 changes: 1 addition & 15 deletions doc_octopy/datasets/filters/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,21 +186,12 @@ class IdentityFilter(FilterBaseClass):
This filter is useful for debugging and testing purposes.
.. important:: If the input array is provided at initialization, the filter will always return that array.
Parameters
----------
input_is_chunked : bool
Whether the input is chunked or not.
is_output : bool
Whether the filter is an output filter. If True, the resulting signal will be outputted by any dataset pipeline.
input_array : np.ndarray
The input array that will be returned by the filter. If provided, the filter will always return this array.
Attributes
----------
input_array : np.ndarray
The input array that was filtered. This is stored after the filter is called.
Methods
-------
Expand All @@ -212,7 +203,6 @@ def __init__(
self,
input_is_chunked: bool = None,
is_output: bool = False,
input_array: np.ndarray = None,
name: str = None,
):
super().__init__(
Expand All @@ -222,9 +212,5 @@ def __init__(
name=name,
)

self.input_array = input_array

def _filter(self, input_array: np.ndarray) -> np.ndarray:
if self.input_array is None:
self.input_array = input_array
return self.input_array
return input_array
25 changes: 25 additions & 0 deletions doc_octopy/datasets/supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,12 @@ def create_dataset(self):
)

if self.debug:
print("After loading:")

print(emg_data_from_task)
emg_data_from_task.plot_graph()

print(ground_truth_data_from_task)
ground_truth_data_from_task.plot_graph()

if not emg_data_from_task.is_chunked["Input"]:
Expand Down Expand Up @@ -262,7 +267,12 @@ def create_dataset(self):
chunked_ground_truth_data_from_task = ground_truth_data_from_task

if self.debug:
print("After chunking:")

print(chunked_emg_data_from_task)
chunked_emg_data_from_task.plot_graph()

print(chunked_ground_truth_data_from_task)
chunked_ground_truth_data_from_task.plot_graph()
else:
chunked_emg_data_from_task = emg_data_from_task # [:min_length]
Expand Down Expand Up @@ -298,7 +308,12 @@ def create_dataset(self):
)

if self.debug:
print("After filtering the chunked data:")

print(emg_data_from_task)
chunked_emg_data_from_task.plot_graph()

print(ground_truth_data_from_task)
chunked_ground_truth_data_from_task.plot_graph()

for group_name, chunked_data_from_task in zip(
Expand All @@ -321,6 +336,16 @@ def create_dataset(self):
chunked_emg_data_from_task.output_representations.values()
)[-1].shape[0]

data_length_ground_truth = list(
chunked_ground_truth_data_from_task.output_representations.values()
)[-1].shape[0]

assert (
data_length == data_length_ground_truth
), "The data lengths of the EMG and ground truth data should be the same. For task {}, the EMG data has length {} and the ground truth data has length {}.".format(
task, data_length, data_length_ground_truth
)

for g in (training_group, testing_group, validation_group):
_add_to_dataset(
g,
Expand Down

0 comments on commit cc5b201

Please sign in to comment.