From 87d0c30b34fcd415a4922bbc01da6f4c7eef8290 Mon Sep 17 00:00:00 2001
From: Rui Campos
Date: Wed, 13 Mar 2024 14:32:50 +0000
Subject: [PATCH] Fixes (#82)

* wip

* Update training-loop-build-rust.yml

Signed-off-by: Rui Campos

---------

Signed-off-by: Rui Campos
---
 .../workflows/training-loop-build-rust.yml |  1 +
 pipelines/text_classification/datagen.py   | 54 +++++++++++++------
 2 files changed, 39 insertions(+), 16 deletions(-)

diff --git a/.github/workflows/training-loop-build-rust.yml b/.github/workflows/training-loop-build-rust.yml
index cf68ca5d..838959c0 100644
--- a/.github/workflows/training-loop-build-rust.yml
+++ b/.github/workflows/training-loop-build-rust.yml
@@ -30,6 +30,7 @@ jobs:
           echo "RUST_BACKTRACE=full" >> $GITHUB_ENV
 
       - run: cargo check
+      - run: cargo test
       - if: ${{ github.event_name == 'release' }}
         run: cargo build --release --target-dir build
       - if: ${{ github.event_name == 'release' }}
diff --git a/pipelines/text_classification/datagen.py b/pipelines/text_classification/datagen.py
index a054a29e..881be4e8 100644
--- a/pipelines/text_classification/datagen.py
+++ b/pipelines/text_classification/datagen.py
@@ -130,22 +130,29 @@ def prepare_slices(conn, rng, epochs: int, number_of_partions: int, data_source:
     limit = number_of_rows // number_of_partions
+    excess = number_of_rows % number_of_partions
+
+
+
     step = 1
 
     logger.info(f"Generating data for {folder} ...")
 
     for epoch in range(epochs):
+
+        generate_permutation(conn, rng, number_of_rows)
+
         for slice_idx in range(number_of_partions):
             logger.info(f"Constructing epoch {epoch} slice {slice_idx}")
 
-            generate_permutation(conn, rng, number_of_rows)
 
             data_from_db: list[tuple[Sentiment, str]] = fetch_data(
                 conn,
                 data_source,
-                offset=limit*slice_idx,
-                limit = number_of_rows // number_of_partions
+                offset = limit*slice_idx,
+                limit = limit if not slice_idx == number_of_partions - 1 else limit + excess
             )
 
             sentiments, reviews = parse_data_from_db(data_from_db)
@@ -153,19 +160,21 @@ def prepare_slices(conn, rng, epochs: int, number_of_partions: int, data_source:
 
             stt.save_file(safetensors, f"{folder}/{step}_output.safetensors")
 
             step += 1
 
-    
+
 
 
 @pytest.mark.parametrize("epochs", [2, 1])
-@pytest.mark.parametrize("slices", [2, 1])
+@pytest.mark.parametrize("slices", [3, 2, 1])
 def test_full(epochs: int, slices: int):
     import os
     import shutil
 
     SEED = 42
 
+    POS_SENTIMENT = 1
+    NEG_SENTIMENT = 0
 
     A_TOKEN = 64
     B_TOKEN = 65
@@ -206,8 +215,8 @@ def test_full(epochs: int, slices: int):
     for idx, (dataset_id, permutation_idx) in enumerate(dataset):
         assert idx + 1 == dataset_id
         assert permutation_idx is not None, f"Permutation column was not constructed properly, value is None. {dataset=}"
-        assert isinstance(permutation_idx, int)
-        assert 1 <= permutation_idx <= 4
+        assert isinstance(permutation_idx, int), "Invalid type for permutation index."
+        assert 1 <= permutation_idx <= 4, f"Found invalid permutation index: {permutation_idx}"
 
         permutations.append(permutation_idx)
         ids.append(dataset_id)
@@ -215,36 +224,49 @@ def test_full(epochs: int, slices: int):
     permutations.sort()
 
     assert ids == permutations
+
+    # replicate the permutation column
     test_rng = default_rng(seed=SEED)
 
     slice_idx= 1
 
     for _ in range(epochs):
+        # to reconstruct the data from the slices
        epoch_data = []
+
         for _ in range(slices):
-            safetensors_file = f"{folder}/{slice_idx+1}_output.safetensors"
-            assert os.path.exists(safetensors_file), "Missing data from disk"
+            safetensors_file = f"{folder}/{slice_idx}_output.safetensors"
+            assert os.path.exists(safetensors_file), f"Missing data from disk: {safetensors_file}"
 
             data = stt.load_file(safetensors_file)
+
+            assert 'X' in data
+            assert 'Y' in data
+
             for val in data['Y']:
-                epoch_data.append(val)
+                epoch_data.append(float(val))
 
             for token, sentiment in zip(data['X'], data['Y']):
                 if token == A_TOKEN:
-                    assert sentiment == 1
+                    assert sentiment == POS_SENTIMENT
                 elif token == B_TOKEN:
-                    assert sentiment == 0
+                    assert sentiment == NEG_SENTIMENT
                 else:
                     assert False, f"Invalid token: {token}"
 
             slice_idx += 1
 
-        epoch_permutation = test_rng.permutation(len(raw_data))
+
+        assert len(epoch_data) == len(raw_data), "Generated data does not have the correct length"
+        assert sum(epoch_data) == 2., "Generated data has incorrect values"
+
 
-        for val_1, val_idx in zip(epoch_data, epoch_permutation)):
-            val_2 = 1 if raw_data[val_idx][1] == "pos" else 0
-            assert float(val_1) == float(val_2)
+        epoch_permutation = test_rng.permutation(len(raw_data))
+
+        for val_1, val_idx in zip(epoch_data, epoch_permutation):
+            val_2 = POS_SENTIMENT if raw_data[val_idx][1] == "pos" else NEG_SENTIMENT
+            assert float(val_1) == float(val_2), epoch_data