Skip to content

Commit

Permalink
Add tests for generators and add MPG + Abalone
Browse files Browse the repository at this point in the history
  • Loading branch information
SamTov committed Sep 26, 2023
1 parent b6f3763 commit ca30b41
Show file tree
Hide file tree
Showing 10 changed files with 395 additions and 2 deletions.
47 changes: 47 additions & 0 deletions CI/unit_tests/data/test_abalone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
ZnNL: A Zincwarecode package.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: zincwarecode@gmail.com
github: https://github.com/zincware
web: https://zincwarecode.com/
Citation
--------
If you use this module please cite us with:
Summary
-------
Test Abalone generator.
"""
from znnl.data import AbaloneDataGenerator


class TestAbaloneGenerator:
    """
    Unit tests for the Abalone data generator.
    """

    def test_creation(self):
        """
        Build the generator with an 80 % train fraction and check the shapes
        of the resulting train and test splits.
        """
        generator = AbaloneDataGenerator(train_fraction=0.8)
        assert generator is not None

        # Training split.
        train_split = generator.train_ds
        assert train_split["inputs"].shape == (3342, 10)
        assert train_split["targets"].shape == (3342, 1)

        # Held-out split.
        test_split = generator.test_ds
        assert test_split["inputs"].shape == (835, 10)
        assert test_split["targets"].shape == (835, 1)
60 changes: 60 additions & 0 deletions CI/unit_tests/data/test_cifar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""
ZnNL: A Zincwarecode package.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: zincwarecode@gmail.com
github: https://github.com/zincware
web: https://zincwarecode.com/
Citation
--------
If you use this module please cite us with:
Summary
-------
Test CIFAR10 generator.
"""
from znnl.data import CIFAR10Generator


class TestCIFARGenerator:
    """
    Unit tests for the CIFAR10 data generator.
    """

    def test_one_hot_creation(self):
        """
        Build the generator with its default target encoding and check that
        the targets have ten columns.
        """
        generator = CIFAR10Generator(ds_size=500)
        assert generator is not None

        # Training split.
        train_split = generator.train_ds
        assert train_split["inputs"].shape == (500, 32, 32, 3)
        assert train_split["targets"].shape == (500, 10)

        # Test split.
        test_split = generator.test_ds
        assert test_split["inputs"].shape == (500, 32, 32, 3)
        assert test_split["targets"].shape == (500, 10)

    def test_serial_creation(self):
        """
        Build the generator with one-hot encoding disabled and check that the
        targets are a single column.
        """
        generator = CIFAR10Generator(ds_size=500, one_hot_encoding=False)
        assert generator is not None

        # Training split.
        train_split = generator.train_ds
        assert train_split["inputs"].shape == (500, 32, 32, 3)
        assert train_split["targets"].shape == (500, 1)

        # Test split.
        test_split = generator.test_ds
        assert test_split["inputs"].shape == (500, 32, 32, 3)
        assert test_split["targets"].shape == (500, 1)
34 changes: 34 additions & 0 deletions CI/unit_tests/data/test_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,38 @@
Summary
-------
Test MNIST generator.
"""
from znnl.data import MNISTGenerator


class TestMNISTGenerator:
    """
    Unit tests for the MNIST data generator.
    """

    def test_one_hot_creation(self):
        """
        Build the generator with its default target encoding and check that
        the targets have ten columns.
        """
        generator = MNISTGenerator(ds_size=500)
        assert generator is not None

        # Training split.
        train_split = generator.train_ds
        assert train_split["inputs"].shape == (500, 28, 28, 1)
        assert train_split["targets"].shape == (500, 10)

        # Test split.
        test_split = generator.test_ds
        assert test_split["inputs"].shape == (500, 28, 28, 1)
        assert test_split["targets"].shape == (500, 10)

    def test_serial_creation(self):
        """
        Build the generator with one-hot encoding disabled and check that the
        targets are a single column.
        """
        generator = MNISTGenerator(ds_size=500, one_hot_encoding=False)
        assert generator is not None

        # Training split.
        train_split = generator.train_ds
        assert train_split["inputs"].shape == (500, 28, 28, 1)
        assert train_split["targets"].shape == (500, 1)

        # Test split.
        test_split = generator.test_ds
        assert test_split["inputs"].shape == (500, 28, 28, 1)
        assert test_split["targets"].shape == (500, 1)
47 changes: 47 additions & 0 deletions CI/unit_tests/data/test_mpg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
ZnNL: A Zincwarecode package.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: zincwarecode@gmail.com
github: https://github.com/zincware
web: https://zincwarecode.com/
Citation
--------
If you use this module please cite us with:
Summary
-------
Test MPG generator.
"""
from znnl.data import MPGDataGenerator


class TestMPGGenerator:
    """
    Unit tests for the MPG data generator.
    """

    def test_creation(self):
        """
        Build the generator with an 80 % train fraction and check the shapes
        of the resulting train and test splits.
        """
        generator = MPGDataGenerator(train_fraction=0.8)
        assert generator is not None

        # Training split.
        train_split = generator.train_ds
        assert train_split["inputs"].shape == (314, 9)
        assert train_split["targets"].shape == (314, 1)

        # Held-out split.
        test_split = generator.test_ds
        assert test_split["inputs"].shape == (78, 9)
        assert test_split["targets"].shape == (78, 1)
4 changes: 4 additions & 0 deletions znnl/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@
Summary
-------
"""
from znnl.data.abalone import AbaloneDataGenerator
from znnl.data.cifar10 import CIFAR10Generator
from znnl.data.confined_particles import ConfinedParticles
from znnl.data.data_generator import DataGenerator
from znnl.data.mnist import MNISTGenerator
from znnl.data.mpg_generator import MPGDataGenerator
from znnl.data.points_on_a_circle import PointsOnCircle
from znnl.data.points_on_a_lattice import PointsOnLattice

Expand All @@ -38,4 +40,6 @@
PointsOnCircle.__name__,
MNISTGenerator.__name__,
CIFAR10Generator.__name__,
MPGDataGenerator.__name__,
AbaloneDataGenerator.__name__,
]
116 changes: 116 additions & 0 deletions znnl/data/abalone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""
ZnNL: A Zincwarecode package.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: zincwarecode@gmail.com
github: https://github.com/zincware
web: https://zincwarecode.com/
Citation
--------
If you use this module please cite us with:
Summary
-------
Abalone dataset generator.
"""
import urllib.request
import zipfile

import pandas as pd

from znnl.data.data_generator import DataGenerator


class AbaloneDataGenerator(DataGenerator):
    """
    Generator for the Abalone data-set.

    On construction the archive is downloaded and extracted into the current
    working directory, the "Sex" column is one-hot encoded, every column
    (including the "Rings" target) is standardised to zero mean and unit
    standard deviation, and the rows are split into a train and a test set.

    Attributes
    ----------
    data_file : str
            Name of the extracted CSV file that is read.
    columns : list
            Column names assigned to the raw CSV data.
    train_ds : dict
            Training data with "inputs" and "targets" numpy arrays.
    test_ds : dict
            Test data with "inputs" and "targets" numpy arrays.
    data_pool : np.ndarray
            The training inputs.
    """

    def __init__(self, train_fraction: float):
        """
        Constructor for the abalone dataset.

        Parameters
        ----------
        train_fraction : float
                Fraction of the data to use for training.
        """
        self.data_file = "abalone.data"
        self.columns = [
            "Sex",
            "Length",
            "Diameter",
            "Height",
            "Whole weight",
            "Shucked weight",
            "Viscera weight",
            "Shell weight",
            "Rings",
        ]

        self._load_data()

        # Collect the processed data
        processed_data = self._process_raw_data()

        # Create the data-sets. The fixed random_state keeps the split
        # reproducible between instantiations.
        train_ds = processed_data.sample(frac=train_fraction, random_state=0)
        # The test set is every row that was not sampled for training. This
        # must be computed before "Rings" is popped so the target is kept.
        test_ds = processed_data.drop(train_ds.index)

        train_labels = train_ds.pop("Rings")
        test_labels = test_ds.pop("Rings")

        self.train_ds = {
            "inputs": train_ds.to_numpy(),
            "targets": train_labels.to_numpy().reshape(-1, 1),
        }
        self.test_ds = {
            "inputs": test_ds.to_numpy(),
            "targets": test_labels.to_numpy().reshape(-1, 1),
        }

        self.data_pool = self.train_ds["inputs"]

    def _load_data(self):
        """
        Download the data archive and extract it into the working directory.

        The temporary file holding the downloaded archive is removed once the
        extraction has finished or failed.
        """
        # Function-scope import so this block does not rely on a change to the
        # module-level imports.
        import os

        # urlretrieve returns the path of a temporary file, not a file handle.
        archive_path, _ = urllib.request.urlretrieve(
            "http://archive.ics.uci.edu/static/public/1/abalone.zip"
        )
        try:
            with zipfile.ZipFile(archive_path, "r") as zip_ref:
                zip_ref.extractall()
        finally:
            os.remove(archive_path)

    def _process_raw_data(self):
        """
        Read, clean, encode and standardise the raw data.

        Returns
        -------
        processed_data : pd.DataFrame
                Data frame in which rows with missing values are dropped, the
                "Sex" column is replaced by one-hot columns and all columns
                are standardised.
        """
        raw_data = pd.read_csv(
            self.data_file,
            names=self.columns,
            na_values="?",
            comment="#",
            sep=",",
            skipinitialspace=True,
        )
        # dropna returns a new frame; the result must be kept for the rows
        # with missing values to actually be removed.
        raw_data = raw_data.dropna()

        # encode the sex data
        raw_data = pd.get_dummies(raw_data, columns=["Sex"], prefix="", prefix_sep="")
        # Normalize every column, the "Rings" target included.
        raw_data = (raw_data - raw_data.mean()) / raw_data.std()

        return raw_data
3 changes: 3 additions & 0 deletions znnl/data/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ def __init__(self, ds_size: int = 500, one_hot_encoding: bool = True):
self.test_ds["targets"] = nn.one_hot(
self.test_ds["targets"], num_classes=10
)
else:
self.train_ds["targets"] = self.train_ds["targets"].reshape(-1, 1)
self.test_ds["targets"] = self.test_ds["targets"].reshape(-1, 1)

def plot_image(self, indices: list = None, data_list: list = None):
"""
Expand Down
6 changes: 4 additions & 2 deletions znnl/data/decision_boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def linear_boundary(data: onp.ndarray, gradient: float, intercept: float) -> np.
Parameters
----------
data : np.ndarray (n_samples, 2)
data : onp.ndarray (n_samples, 2)
Data to be converted into classes.
gradient : float
Gradient of the line, default 1.0.
Expand All @@ -67,7 +67,7 @@ class 0 and outside are class 1.
Parameters
----------
data : np.ndarray
data : onp.ndarray
Data to be converted into classes.
radius : float
Radius of the circle.
Expand Down Expand Up @@ -131,6 +131,8 @@ def __init__(
self.train_ds = self._build_dataset(n_samples=n_samples)
self.test_ds = self._build_dataset(n_samples=n_samples)

self.data_pool = self.train_ds["inputs"]

def _build_dataset(self, n_samples: int):
"""
Helper method to create datasets quickly.
Expand Down
3 changes: 3 additions & 0 deletions znnl/data/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ def __init__(self, ds_size: int = 500, one_hot_encoding: bool = True):
self.test_ds["targets"] = nn.one_hot(
self.test_ds["targets"], num_classes=10
)
else:
self.train_ds["targets"] = self.train_ds["targets"].reshape(-1, 1)
self.test_ds["targets"] = self.test_ds["targets"].reshape(-1, 1)

def plot_image(self, indices: list = None, data_list: list = None):
"""
Expand Down
Loading

0 comments on commit ca30b41

Please sign in to comment.