Skip to content

Commit

Permalink
Merge pull request #36 from marcpinet/refactor-more-robust-utils
Browse files Browse the repository at this point in the history
Refactor more robust utils
  • Loading branch information
marcpinet authored Apr 24, 2024
2 parents dece1cf + 7466989 commit 955e0ed
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 40 deletions.
32 changes: 2 additions & 30 deletions .github/workflows/pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
publish:
needs: check-version-changed
if: needs.check-version-changed.outputs.version_changed == 'true' && github.ref == 'refs/heads/main'
if: needs.check-version-changed.outputs.version_changed == 'true'
runs-on: ubuntu-latest
steps:
- name: Check out code
Expand All @@ -62,32 +62,4 @@ jobs:
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
run: |
python setup.py sdist bdist_wheel
twine upload dist/*
publish-dev:
needs: check-version-changed
if: needs.check-version-changed.outputs.version_changed == 'true' && github.ref == 'refs/heads/develop'
runs-on: ubuntu-latest
steps:
- name: Check out code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install setuptools wheel twine
- name: Get version
id: get_version
run: |
echo "VERSION=$(python setup.py --version)" >> $GITHUB_OUTPUT
- name: Build and publish dev version
env:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
run: |
VERSION="${{ steps.get_version.outputs.VERSION }}-dev"
python setup.py sdist bdist_wheel
twine upload --repository-url https://test.pypi.org/legacy/ dist/*$VERSION*
twine upload dist/*
4 changes: 2 additions & 2 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
release:
needs: check-version-changed
if: needs.check-version-changed.outputs.version_changed == 'true' && github.ref == 'refs/heads/main'
if: needs.check-version-changed.outputs.version_changed == 'true'
runs-on: ubuntu-latest
steps:
- name: Checkout Repository
Expand Down Expand Up @@ -86,4 +86,4 @@ jobs:
tag_name: v${{ steps.get_version.outputs.VERSION }}
body: ${{ steps.generate_release_notes.outputs.RELEASE_NOTES }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
21 changes: 14 additions & 7 deletions neuralnetlib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ def dict_with_list_to_dict_with_ndarray(d: dict) -> dict:
return d


def shuffle(x, y, random_state: int = None) -> tuple:
def shuffle(x: np.ndarray, y: np.ndarray, random_state: int = None) -> tuple:
"""Shuffles the data along the first axis."""
x = np.array(x)
y = np.array(y)
rng = np.random.default_rng(random_state if random_state is not None else int(time.time_ns()))
indices = rng.permutation(len(x))
return x[indices], y[indices]
Expand All @@ -44,23 +46,28 @@ def progress_bar(current: int, total: int, width: int = 30, message: str = "") -
sys.stdout.flush()


def train_test_split(x, y, test_size: float = 0.2, random_state: int = None) -> tuple:
def train_test_split(x: np.ndarray, y: np.ndarray, test_size: float = 0.2, random_state: int = None) -> tuple:
"""
Splits the data into training and test sets.
Args:
x (np.ndarray): input data
y (np.ndarray): target data
x (np.ndarray or list): input data
y (np.ndarray or list): target data
test_size (float): the proportion of the dataset to include in the test split
random_state (int): seed for the random number generator
Returns:
tuple: x_train, x_test, y_train, y_test
tuple: (x_train, x_test, y_train, y_test)
"""
x = np.array(x)
y = np.array(y)
rng = np.random.default_rng(random_state if random_state is not None else int(time.time_ns()))
indices = np.arange(len(x))
rng.shuffle(indices)
split_index = int(len(x) * (1 - test_size))
x_train, x_test = x[indices[:split_index]], x[indices[split_index:]]
y_train, y_test = y[indices[:split_index]], y[indices[split_index:]]
x_train = x[indices[:split_index]]
x_test = x[indices[split_index:]]
y_train = y[indices[:split_index]]
y_test = y[indices[split_index:]]
return x_train, x_test, y_train, y_test

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name='neuralnetlib',
version='2.4.1',
version='2.4.2',
author='Marc Pinet',
description='A simple convolutional neural network library with only numpy as dependency',
long_description=open('README.md', encoding="utf-8").read(),
Expand Down

0 comments on commit 955e0ed

Please sign in to comment.