Skip to content

Commit

Permalink
Sam tov new data generators (#102)
Browse files Browse the repository at this point in the history
* Add new data generators

* Update pre-commit hooks and reformat.

* remove protobuf from requirements.txt

* Specify NT version

* Add tests for generators and add MPG + Abalone

* Update comment
  • Loading branch information
SamTov committed Sep 27, 2023
1 parent 5bb2d5f commit aac3c53
Show file tree
Hide file tree
Showing 15 changed files with 690 additions and 10 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@ fail_fast: true

repos:
- repo: https://github.com/psf/black
rev: 22.8.0
rev: 23.9.1
hooks:
- id: black

- repo: https://github.com/timothycrosley/isort
rev: 5.10.1
rev: 5.12.0
hooks:
- id: isort

- repo: https://github.com/pycqa/flake8
rev: 5.0.4
rev: 6.1.0
hooks:
- id: flake8
additional_dependencies: [flake8-isort]
47 changes: 47 additions & 0 deletions CI/unit_tests/data/test_abalone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
ZnNL: A Zincwarecode package.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: zincwarecode@gmail.com
github: https://github.com/zincware
web: https://zincwarecode.com/
Citation
--------
If you use this module please cite us with:
Summary
-------
Test Abaone generator.
"""
from znnl.data import AbaloneDataGenerator


class TestAbaloneGenerator:
"""
Class for testing the Abalone generator.
"""

def test_creation(self):
"""
Test if one can create the generator.
"""
generator = AbaloneDataGenerator(train_fraction=0.8)

assert generator is not None
assert generator.train_ds["inputs"].shape == (3342, 10)
assert generator.train_ds["targets"].shape == (3342, 1)

assert generator.test_ds["inputs"].shape == (835, 10)
assert generator.test_ds["targets"].shape == (835, 1)
60 changes: 60 additions & 0 deletions CI/unit_tests/data/test_cifar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""
ZnNL: A Zincwarecode package.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: zincwarecode@gmail.com
github: https://github.com/zincware
web: https://zincwarecode.com/
Citation
--------
If you use this module please cite us with:
Summary
-------
Test CIFAR10 generator.
"""
from znnl.data import CIFAR10Generator


class TestCIFARGenerator:
"""
Class for testing the CIFAR generator.
"""

def test_one_hot_creation(self):
"""
Test if one can create the generator.
"""
generator = CIFAR10Generator(ds_size=500)

assert generator is not None
assert generator.train_ds["inputs"].shape == (500, 32, 32, 3)
assert generator.train_ds["targets"].shape == (500, 10)

assert generator.test_ds["inputs"].shape == (500, 32, 32, 3)
assert generator.test_ds["targets"].shape == (500, 10)

def test_serial_creation(self):
"""
Test if one can create the generator.
"""
generator = CIFAR10Generator(ds_size=500, one_hot_encoding=False)

assert generator is not None
assert generator.train_ds["inputs"].shape == (500, 32, 32, 3)
assert generator.train_ds["targets"].shape == (500, 1)

assert generator.test_ds["inputs"].shape == (500, 32, 32, 3)
assert generator.test_ds["targets"].shape == (500, 1)
93 changes: 93 additions & 0 deletions CI/unit_tests/data/test_decision_boundary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""
ZnNL: A Zincwarecode package.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: zincwarecode@gmail.com
github: https://github.com/zincware
web: https://zincwarecode.com/
Citation
--------
If you use this module please cite us with:
Summary
-------
Unit test for the decision boundary.
"""
import jax.numpy as np
import numpy as onp
from pytest import approx

from znnl.data.decision_boundary import (
DecisionBoundaryGenerator,
circle,
linear_boundary,
)


class TestDecisionBoundary:
"""
Unit test for the decision boundary.
"""

def test_linear_boundary(self):
"""
Test the linear boundary.
"""
target_ratio = 0.0
for _ in range(10):
input_data = onp.random.uniform(0, 1, size=(10000, 2))
target_ratio += linear_boundary(input_data, 1.0, 0.0).mean()

assert target_ratio / 10 == approx(0.5, rel=0.01)

def test_circle(self):
"""
Test the circle boundary.
"""
target_ratio = 0.0
for _ in range(10):
input_data = onp.random.uniform(0, 1, size=(10000, 2))
target_ratio += circle(input_data, 0.25).mean()

# P(x in class 1) = 1 - (pi / 16)
assert target_ratio / 10 == approx(1 - (np.pi / 16), abs=0.01)

def test_one_hot_decision_boundary_generator(self):
"""
Test the actual generator.
"""
generator = DecisionBoundaryGenerator(
n_samples=10000, discriminator="circle", one_hot=True
)

# Check the dataset shapes
assert generator.train_ds["inputs"].shape == (10000, 2)
assert generator.train_ds["targets"].shape == (10000, 2)
assert generator.test_ds["inputs"].shape == (10000, 2)
assert generator.test_ds["targets"].shape == (10000, 2)

def test_serial_decision_boundary_generator(self):
"""
Test the actual generator.
"""
generator = DecisionBoundaryGenerator(
n_samples=10000, discriminator="circle", one_hot=False
)

# Check the dataset shapes
assert generator.train_ds["inputs"].shape == (10000, 2)
assert generator.train_ds["targets"].shape == (10000, 1)
assert generator.test_ds["inputs"].shape == (10000, 2)
assert generator.test_ds["targets"].shape == (10000, 1)
34 changes: 34 additions & 0 deletions CI/unit_tests/data/test_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,38 @@
Summary
-------
Test MNIST generator.
"""
from znnl.data import MNISTGenerator


class TestMNISTGenerator:
"""
Class for testing the MNIST generator.
"""

def test_one_hot_creation(self):
"""
Test if one can create the generator.
"""
generator = MNISTGenerator(ds_size=500)

assert generator is not None
assert generator.train_ds["inputs"].shape == (500, 28, 28, 1)
assert generator.train_ds["targets"].shape == (500, 10)

assert generator.test_ds["inputs"].shape == (500, 28, 28, 1)
assert generator.test_ds["targets"].shape == (500, 10)

def test_serial_creation(self):
"""
Test if one can create the generator.
"""
generator = MNISTGenerator(ds_size=500, one_hot_encoding=False)

assert generator is not None
assert generator.train_ds["inputs"].shape == (500, 28, 28, 1)
assert generator.train_ds["targets"].shape == (500, 1)

assert generator.test_ds["inputs"].shape == (500, 28, 28, 1)
assert generator.test_ds["targets"].shape == (500, 1)
47 changes: 47 additions & 0 deletions CI/unit_tests/data/test_mpg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
ZnNL: A Zincwarecode package.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: zincwarecode@gmail.com
github: https://github.com/zincware
web: https://zincwarecode.com/
Citation
--------
If you use this module please cite us with:
Summary
-------
Test MPG generator.
"""
from znnl.data import MPGDataGenerator


class TestMPGGenerator:
"""
Class for testing the MPG generator.
"""

def test_creation(self):
"""
Test if one can create the generator.
"""
generator = MPGDataGenerator(train_fraction=0.8)

assert generator is not None
assert generator.train_ds["inputs"].shape == (314, 9)
assert generator.train_ds["targets"].shape == (314, 1)

assert generator.test_ds["inputs"].shape == (78, 9)
assert generator.test_ds["targets"].shape == (78, 1)
11 changes: 5 additions & 6 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
numpy
matplotlib
sphinx
flake8==5.0.4
black==22.8.0
flake8
black
ipython
numpydoc
optax
Expand All @@ -18,10 +18,9 @@ plotly
flax
tqdm
pandas
neural-tangents
neural-tangents==0.6.4
tensorflow-datasets
isort==5.10.1
isort
tensorflow
pyyaml
jupyter
protobuf==3.20.*
jupyter
4 changes: 4 additions & 0 deletions znnl/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@
Summary
-------
"""
from znnl.data.abalone import AbaloneDataGenerator
from znnl.data.cifar10 import CIFAR10Generator
from znnl.data.confined_particles import ConfinedParticles
from znnl.data.data_generator import DataGenerator
from znnl.data.mnist import MNISTGenerator
from znnl.data.mpg_generator import MPGDataGenerator
from znnl.data.points_on_a_circle import PointsOnCircle
from znnl.data.points_on_a_lattice import PointsOnLattice

Expand All @@ -38,4 +40,6 @@
PointsOnCircle.__name__,
MNISTGenerator.__name__,
CIFAR10Generator.__name__,
MPGDataGenerator.__name__,
AbaloneDataGenerator.__name__,
]
Loading

0 comments on commit aac3c53

Please sign in to comment.