Skip to content

Commit

Permalink
Fixes (#82)
Browse files Browse the repository at this point in the history
* wip

* Update training-loop-build-rust.yml

Signed-off-by: Rui Campos <mail@ruicampos.org>

---------

Signed-off-by: Rui Campos <mail@ruicampos.org>
  • Loading branch information
RuiFilipeCampos authored Mar 13, 2024
1 parent 7c1e865 commit 87d0c30
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 16 deletions.
1 change: 1 addition & 0 deletions .github/workflows/training-loop-build-rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ jobs:
echo "RUST_BACKTRACE=full" >> $GITHUB_ENV
- run: cargo check
- run: cargo test
- if: ${{ github.event_name == 'release' }}
run: cargo build --release --target-dir build
- if: ${{ github.event_name == 'release' }}
Expand Down
54 changes: 38 additions & 16 deletions pipelines/text_classification/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,42 +130,51 @@ def prepare_slices(conn, rng, epochs: int, number_of_partions: int, data_source:


limit = number_of_rows // number_of_partions
excess = number_of_rows % number_of_partions




step = 1

logger.info(f"Generating data for {folder} ...")

for epoch in range(epochs):

generate_permutation(conn, rng, number_of_rows)

for slice_idx in range(number_of_partions):

logger.info(f"Constructing epoch {epoch} slice {slice_idx}")

generate_permutation(conn, rng, number_of_rows)

data_from_db: list[tuple[Sentiment, str]] = fetch_data(
conn,
data_source,
offset=limit*slice_idx,
limit = number_of_rows // number_of_partions
offset = limit*slice_idx,
limit = limit if not slice_idx == number_of_partions - 1 else limit + excess
)

sentiments, reviews = parse_data_from_db(data_from_db)
safetensors = raw_data_to_tensor(sentiments, reviews)
stt.save_file(safetensors, f"{folder}/{step}_output.safetensors")
step += 1






@pytest.mark.parametrize("epochs", [2, 1])
@pytest.mark.parametrize("slices", [2, 1])
@pytest.mark.parametrize("slices", [3, 2, 1])
def test_full(epochs: int, slices: int):
import os
import shutil

SEED = 42

POS_SENTIMENT = 1
NEG_SENTIMENT = 0
A_TOKEN = 64
B_TOKEN = 65

Expand Down Expand Up @@ -206,45 +215,58 @@ def test_full(epochs: int, slices: int):
for idx, (dataset_id, permutation_idx) in enumerate(dataset):
assert idx + 1 == dataset_id
assert permutation_idx is not None, f"Permutation column was not constructed properly, value is None. {dataset=}"
assert isinstance(permutation_idx, int)
assert 1 <= permutation_idx <= 4
assert isinstance(permutation_idx, int), "Invalid type for permutation index."
assert 1 <= permutation_idx <= 4, f"Found invalid permutation index: {permutation_idx}"
permutations.append(permutation_idx)
ids.append(dataset_id)

ids.sort()
permutations.sort()
assert ids == permutations


# replicate the permutation column
test_rng = default_rng(seed=SEED)

slice_idx= 1
for _ in range(epochs):

# to reconstruct the data from the slices
epoch_data = []

for _ in range(slices):
safetensors_file = f"{folder}/{slice_idx+1}_output.safetensors"
assert os.path.exists(safetensors_file), "Missing data from disk"
safetensors_file = f"{folder}/{slice_idx}_output.safetensors"
assert os.path.exists(safetensors_file), f"Missing data from disk: {safetensors_file}"

data = stt.load_file(safetensors_file)

assert 'X' in data
assert 'Y' in data


for val in data['Y']:
epoch_data.append(val)
epoch_data.append(float(val))

for token, sentiment in zip(data['X'], data['Y']):
if token == A_TOKEN:
assert sentiment == 1
assert sentiment == POS_SENTIMENT
elif token == B_TOKEN:
assert sentiment == 0
assert sentiment == NEG_SENTIMENT
else:
assert False, f"Invalid token: {token}"

slice_idx += 1

epoch_permutation = test_rng.permutation(len(raw_data))

assert len(epoch_data) == len(raw_data), "Generated data does not have the correct lenght"
assert sum(epoch_data) == 2., "Generated data has incorrect values"


for val_1, val_idx in zip(epoch_data, epoch_permutation)):
val_2 = 1 if raw_data[val_idx][1] == "pos" else 0
assert float(val_1) == float(val_2)
epoch_permutation = test_rng.permutation(len(raw_data))

for val_1, val_idx in zip(epoch_data, epoch_permutation):
val_2 = POS_SENTIMENT if raw_data[val_idx][1] == "pos" else NEG_SENTIMENT
assert float(val_1) == float(val_2), epoch_data



Expand Down

0 comments on commit 87d0c30

Please sign in to comment.