Skip to content

Commit

Permalink
feat: Make batch size a CLI argument (#19)
Browse files Browse the repository at this point in the history
Makes it possible to set the number of files per batch for validation in
the command line. Changes the random sampler to change its batch sample
size based on "NO_OF_BATCHES".
  • Loading branch information
simentha authored Jan 8, 2024
1 parent 17120c7 commit 270c97c
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 13 deletions.
8 changes: 6 additions & 2 deletions atmos_validation/validate_netcdf/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
\t\t\t\t\t Can be extremely slow for large datasets. Default behaviour is taking random samples.
--{validation_settings.SKIP_MIN_MAX_CHECK} \t Skip random sample check for min/max values.
--{validation_settings.SKIP_WARNINGS} \t\t\t Skip all checks that would only output a "WARNING".
--{validation_settings.BATCH_SIZE} \t\t\t Set the amount of .nc files per batch to validate. Defaults to 50.
"""


Expand Down Expand Up @@ -70,7 +71,7 @@ def validate(
path: str,
injected_logger: Optional[logging.Logger] = None,
additional_args: Optional[List[str]] = None,
batch_size: int = 50,
batch_size: Optional[int] = None,
) -> ValidationResult:
"""
Execute validation on a directory or file.
Expand All @@ -90,10 +91,13 @@ def validate(
log.create_or_update_logger(injected_logger)
if additional_args:
validation_settings.apply_settings(additional_args)
if batch_size:
validation_settings.set_batch_size(override=batch_size)

try:
log.info("load dataset from path %s", path)
batches = load_paths(path, batch_size)
batches = load_paths(path, batch_size=validation_settings.get_batch_size())
validation_settings.NO_OF_BATCHES = len(batches)
if not batches:
raise OSError("No NetCDF files in dir")
except Exception as err:
Expand Down
29 changes: 28 additions & 1 deletion atmos_validation/validate_netcdf/validation_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
CHECK_MIN_MAX_FULL: str = "--check-min-max-full"
SKIP_MIN_MAX_CHECK: str = "--skip-random-min-max-check"
SKIP_WARNINGS: str = "--skip-warnings"
URL_TO_PARAMETERS: str = "--set-url-to-parameters ="
URL_TO_PARAMETERS: str = "--set-url-to-parameters="
DEFAULT_URL_TO_PARAMETERS: str = "https://atmos.app.radix.equinor.com/config/parameters"
BATCH_SIZE: str = "--batch-size="
DEFAULT_BATCH_SIZE = 50
NO_OF_BATCHES: int
SETTINGS = set()


Expand All @@ -27,6 +30,30 @@ def get_url_to_parameters() -> str:
return DEFAULT_URL_TO_PARAMETERS


def set_batch_size(override: int):
batch_size_argument = None
for arg in SETTINGS:
if str(arg).startswith(BATCH_SIZE):
batch_size_argument = arg

if batch_size_argument:
SETTINGS.remove(batch_size_argument)

SETTINGS.add(f"{BATCH_SIZE}{override}")


def get_batch_size() -> int:
batch_size = DEFAULT_BATCH_SIZE
for arg in SETTINGS:
if str(arg).startswith(BATCH_SIZE):
try:
batch_size = int(str(arg).split(BATCH_SIZE, 1)[1].strip())
except ValueError as e:
raise TypeError("Batch size must be an integer") from e
print(f"using batch_size = {batch_size}")
return batch_size


def should_skip_min_max_check() -> bool:
return SKIP_MIN_MAX_CHECK in SETTINGS

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import random
import sys
from typing import List, Tuple, Union

import xarray as xr
Expand All @@ -15,25 +16,26 @@
from ...utils import Severity, validation_node
from ...validation_logger import log

SEED = random.randrange(sys.maxsize)
print(f"using random seed: {SEED}")


def _get_random_time_slice(actual: xr.DataArray, rand: random.Random) -> slice:
"""To save processing time we only take 500 timestamps
per batch from a random starting time
"""
"""To save processing time we only check 5000 timestamps random samples"""
len_time = len(actual.Time)
if len_time > 1000:
start = rand.randint(0, len_time - 500)
sample_size = 5000 // validation_settings.NO_OF_BATCHES
if len_time > sample_size * 2:
start = rand.randint(0, len_time - sample_size - 1)
time_slice = slice(
start,
start + 500,
start + sample_size,
)
else:
start = rand.randint(0, len_time)
start = rand.randint(0, len_time // 2)
time_slice = slice(
start,
start + len_time,
start + len_time // 2,
)

return time_slice


Expand Down Expand Up @@ -123,7 +125,7 @@ def varinterval_validator(actual: xr.DataArray, expected: ParameterConfig) -> Li
def _check_randomly_selected_intervals_min_max(
actual: xr.DataArray, expected: ParameterConfig, dims: List[str]
):
rand = random.Random(1)
rand = random.Random(SEED)
result = []

slice_tuple = _get_slice_tuple(dims, actual, rand)
Expand Down

0 comments on commit 270c97c

Please sign in to comment.