Update data loaders. Now label column filter can be an integer or a list of integers.

PiperOrigin-RevId: 635061726
raj-sinha committed May 18, 2024
1 parent 246cef9 commit deba662
Showing 7 changed files with 49 additions and 17 deletions.
CHANGELOG.md (9 changes: 7 additions & 2 deletions)
@@ -23,14 +23,19 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
 
 ## [Unreleased]
 
-## [0.2.0] - 2022-05-05
+## [0.2.1] - 2024-05-18
+
+* Updates to data loaders. Label column filter can now be a list of integers.
+
+## [0.2.0] - 2024-05-05
 
 * Add PyPi support. Minor reorganization of repository.
 
 ## [0.1.0] - 2024-04-17
 
 * Initial release
 
-[Unreleased]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.0...HEAD
+[Unreleased]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.1...HEAD
+[0.2.1]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.0...v0.2.1
 [0.2.0]: https://github.com/google-research/spade_anomaly_detection/compare/v0.1.0...v0.2.0
 [0.1.0]: https://github.com/google-research/spade_anomaly_detection/releases/tag/v0.1.0
pyproject.toml (3 changes: 1 addition & 2 deletions)
@@ -2,7 +2,6 @@
 # Project metadata. Available keys are documented at:
 # https://packaging.python.org/en/latest/specifications/declaring-project-metadata
 name = "spade_anomaly_detection"
-version = "0.2.0"
 description = "Semi-supervised Pseudo Labeler Anomaly Detection with Ensembling (SPADE) is a semi-supervised anomaly detection method that uses an ensemble of one class classifiers as the pseudo-labelers and supervised classifiers to achieve state of the art results especially on datasets with distribution mismatch between labeled and unlabeled samples."
 readme = "README.md"
 requires-python = ">=3.8"
@@ -35,7 +34,7 @@ dependencies = [
 ]
 
 # `version` is automatically set by flit to use `spade_anomaly_detection.__version__`
-# dynamic = ["version"]
+dynamic = ["version"]
 
 [project.urls]
 homepage = "https://github.com/google-research/spade_anomaly_detection"
spade_anomaly_detection/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -31,4 +31,4 @@
 
 # A new PyPI release will be pushed every time `__version__` is increased.
 # When changing this, also update the CHANGELOG.md.
-__version__ = '0.2.0'
+__version__ = '0.2.1'
spade_anomaly_detection/data_loader.py (31 changes: 25 additions & 6 deletions)
@@ -328,8 +328,9 @@ def load_tf_dataset_from_bigquery(
       where_statements: Optional[List[str]] = None,
       ignore_columns: Optional[Sequence[str]] = None,
       batch_size: Optional[int] = None,
-      label_column_filter_value: Optional[int] = None,
+      label_column_filter_value: Optional[int | list[int]] = None,
       convert_features_to_float64: bool = False,
+      page_size: Optional[int] = None,
   ) -> tf.data.Dataset:
     """Loads a TensorFlow dataset from a BigQuery Table.
@@ -346,10 +347,13 @@
         dataset is not batched. In this case, when iterating through the
         dataset, it will yield one record per call instead of a batch of
         records.
-      label_column_filter_value: An integer used when filtering the label column
-        values. No value will result in all data returned from the table.
+      label_column_filter_value: An integer or list of integers used when
+        filtering the label column values. No value will result in all data
+        returned from the table.
       convert_features_to_float64: Set to True to cast the contents of the
         features columns to float64.
+      page_size: the pagination size to use when retrieving data from BigQuery.
+        A large value can result in fewer BQ calls, hence time savings.
 
     Returns:
       A TensorFlow dataset.
@@ -362,7 +366,15 @@
       where_statements = (
          list() if where_statements is None else where_statements
       )
-      where_statements.append(f'{label_col_name} = {label_column_filter_value}')
+      if isinstance(label_column_filter_value, int):
+        where_statements.append(
+            f'{label_col_name} = {label_column_filter_value}'
+        )
+      else:
+        where_statements.append(
+            f'CAST({label_col_name} AS INT64) IN '
+            f'UNNEST({label_column_filter_value})'
+        )
 
     if ignore_columns is not None:
       metadata_builder = feature_metadata.BigQueryMetadataBuilder(
@@ -373,8 +385,10 @@
       )
 
     if where_statements:
-      metadata_retrieval_options = feature_metadata.MetadataRetrievalOptions(
-          where_clauses=where_statements
+      metadata_retrieval_options = (
+          feature_metadata.MetadataRetrievalOptions.get_none(
+              where_clauses=where_statements
+          )
       )
 
     tf_dataset, metadata = bq_dataset.get_dataset_and_metadata_for_table(
@@ -384,7 +398,12 @@
         drop_remainder=True,
         metadata_options=metadata_retrieval_options,
         metadata_builder=metadata_builder,
+        page_size=page_size,
     )
+    options = tf.data.Options()
+    # Avoid a large warning output by TF Dataset.
+    options.deterministic = False
+    tf_dataset = tf_dataset.with_options(options)
 
     self.input_feature_metadata = metadata
 
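For quick reference, here is a minimal standalone sketch of the new filter construction in load_tf_dataset_from_bigquery; the helper name build_label_filter and the column name 'label' are hypothetical, while the f-string logic mirrors the diff above:

    # Sketch of the WHERE-clause construction added in data_loader.py.
    def build_label_filter(label_col_name, label_column_filter_value):
      if isinstance(label_column_filter_value, int):
        return f'{label_col_name} = {label_column_filter_value}'
      # A Python list such as [0, -1] formats as '[0, -1]', which BigQuery
      # reads as an array literal inside UNNEST.
      return (
          f'CAST({label_col_name} AS INT64) IN '
          f'UNNEST({label_column_filter_value})'
      )

    build_label_filter('label', 1)        # -> 'label = 1'
    build_label_filter('label', [0, -1])  # -> 'CAST(label AS INT64) IN UNNEST([0, -1])'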
spade_anomaly_detection/data_loader_test.py (8 changes: 5 additions & 3 deletions)
@@ -341,7 +341,9 @@ def test_load_bigquery_dataset_unlabeled_value(self, metadata_mock):
         label_column_filter_value=label_column_filter_value,
     )
 
-    mock_metadata_call_actual = metadata_mock.call_args.kwargs['where_clauses']
+    mock_metadata_call_actual = metadata_mock.get_none.call_args.kwargs[
+        'where_clauses'
+    ]
 
     # Ensure that a where statement was not created when we don't pass in label
     # values.
@@ -399,7 +401,7 @@ def test_where_statement_construction_no_error(self, mock_metadata):
 
     self.assertListEqual(
         expected_where_statements,
-        mock_metadata.call_args.kwargs['where_clauses'],
+        mock_metadata.get_none.call_args.kwargs['where_clauses'],
     )
 
   @mock.patch.object(
@@ -419,7 +421,7 @@ def test_where_statements_with_label_filter_no_error(self, mock_metadata):
 
     self.assertListEqual(
         where_statements,
-        mock_metadata.call_args.kwargs['where_clauses'],
+        mock_metadata.get_none.call_args.kwargs['where_clauses'],
     )
 
   @mock.patch.object(feature_metadata, 'BigQueryMetadataBuilder', autospec=True)
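Because the loader now builds its options through MetadataRetrievalOptions.get_none(...) instead of the class constructor, the assertions above read the recorded call from the mocked classmethod. A minimal sketch of that pattern, using a hypothetical stand-in class rather than the real feature_metadata module:

    from unittest import mock

    class MetadataRetrievalOptions:  # hypothetical stand-in for the real class
      @classmethod
      def get_none(cls, where_clauses=()):
        return cls()

    metadata_mock = mock.create_autospec(MetadataRetrievalOptions, spec_set=True)
    metadata_mock.get_none(where_clauses=['label = 1'])
    # The call is recorded on the classmethod attribute, not on the class mock.
    assert metadata_mock.get_none.call_args.kwargs['where_clauses'] == ['label = 1']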
spade_anomaly_detection/data_utils/bq_dataset.py (4 changes: 4 additions & 0 deletions)
@@ -721,6 +721,7 @@ def get_dataset_and_metadata_for_table(
     batch_size: int = 64,
     with_mask: bool = _WITH_MASK_DEFAULT,
     drop_remainder: bool = False,
+    page_size: Optional[int] = None,
 ) -> Tuple[tf.data.Dataset, feature_metadata.BigQueryTableMetadata]:
   """Gets the metadata and dataset for a BigQuery table.
@@ -735,6 +736,8 @@
     with_mask: Whether the dataset should be returned with a mask format. For
       more information see get_bigquery_dataset.
     drop_remainder: If true no partial batches will be yielded.
+    page_size: the pagination size to use when retrieving data from BigQuery. A
+      large value can result in fewer BQ calls, hence time savings.
 
   Returns:
     A tuple of the output dataset and metadata for the specified table.
@@ -783,6 +786,7 @@
       cache_location=None,
       where_clauses=metadata_options.where_clauses,
       drop_remainder=drop_remainder,
+      page_size=page_size,
   )
 
   return dataset, all_metadata
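A hedged sketch of how the new parameter reaches a caller, assuming the table is addressed by a single path argument (as the test name test_get_dataset_and_metadata_for_table_path suggests) and that the client objects are already configured; all names here are placeholders:

    # Hypothetical call site for the page_size plumbing.
    dataset, metadata = bq_dataset.get_dataset_and_metadata_for_table(
        'my-project.my_dataset.my_table',
        bigquery_client=bq_client,
        bigquery_storage_client=bq_storage_client,
        metadata_options=feature_metadata.MetadataRetrievalOptions.get_all(),
        batch_size=256,
        page_size=10_000,  # larger pages can mean fewer BigQuery read calls
    )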
spade_anomaly_detection/data_utils/bq_dataset_test.py (9 changes: 6 additions & 3 deletions)
@@ -683,6 +683,7 @@ def _create_rand_df():
     self.assertIsInstance(output_dataset, tf.data.Dataset)
 
     # Loop through the batches and make sure they all match.
+    epoch = None
     for epoch in range(1, 3):
       batch_index = 0
       for cur_batch in output_dataset:
@@ -760,7 +761,7 @@ def test_get_dataset_and_metadata_for_table_path(
     mock_bq_storage_client = mock.create_autospec(
         bigquery_storage.BigQueryReadClient, spec_set=True, instance=True)
     metadata_options = feature_metadata.MetadataRetrievalOptions.get_all()
-    batch_sie = 128
+    batch_size = 128
     with_mask = False
 
     output_dataset, output_metadata = (
@@ -769,7 +770,7 @@
             bigquery_client=mock_bq_client,
             bigquery_storage_client=mock_bq_storage_client,
             metadata_options=metadata_options,
-            batch_size=batch_sie,
+            batch_size=batch_size,
             with_mask=with_mask,
         )
     )
@@ -780,11 +781,12 @@
         get_metadata_mock.return_value,
         mock_bq_client,
         bqstorage_client=mock_bq_storage_client,
-        batch_size=batch_sie,
+        batch_size=batch_size,
         with_mask=with_mask,
         cache_location=None,
         where_clauses=(),
         drop_remainder=False,
+        page_size=None,
     )
 
     self.assertEqual(output_dataset, get_bigquery_dataset_mock.return_value)
@@ -832,6 +834,7 @@ def test_get_dataset_and_metadata_for_table_parts_defaults(
         cache_location=None,
         where_clauses=(),
         drop_remainder=False,
+        page_size=None,
     )
 
     self.assertEqual(output_dataset, get_bigquery_dataset_mock.return_value)