Skip to content

Commit

Permalink
Refactor config validation_policy to not store policies on the config
Browse files Browse the repository at this point in the history
  • Loading branch information
brianjlai authored Jul 11, 2023
1 parent f3ea989 commit f79aa72
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Dict, List, Mapping, Optional, Union

from airbyte_cdk.models import ConfiguredAirbyteCatalog
from pydantic import BaseModel, root_validator, validator
from pydantic import BaseModel, validator

PrimaryKeyType = Optional[Union[str, List[str], List[List[str]]]]

Expand Down Expand Up @@ -66,31 +66,25 @@ class FileBasedStreamConfig(BaseModel):
file_type: str
globs: Optional[List[str]]
validation_policy: Union[str, Any]
validation_policies: Dict[str, Any]
catalog_schema: Optional[ConfiguredAirbyteCatalog]
input_schema: Optional[Dict[str, Any]]
primary_key: PrimaryKeyType
max_history_size: Optional[int]
days_to_sync_if_history_is_full: Optional[int]
format: Optional[Mapping[str, CsvFormat]] # this will eventually be a Union once we have more than one format type

@validator("file_type", pre=True)
def validate_file_type(cls, v):
    """
    Accept `file_type` only when it is one of VALID_FILE_TYPES.

    Registered as a `pre` validator on the `file_type` field.

    :param v: the `file_type` value taken from the stream config
    :return: `v`, unchanged, when it is a supported file type
    :raises ValueError: if `v` is not in VALID_FILE_TYPES
    """
    if v in VALID_FILE_TYPES:
        return v
    raise ValueError(f"Format filetype {v} is not a supported file type")

@validator("format", pre=True)
def transform_format(cls, v):
    """
    Normalize the `format` option into a mapping keyed by file type.

    A mapping that carries a non-empty "filetype" entry (a single format object) is wrapped as
    `{filetype: <shallow copy of the mapping>}`. Every other value is returned untouched: a mapping
    without a "filetype" entry, such as one already keyed by file type, or a value that is not a
    mapping at all.

    :param v: the raw `format` value taken from the stream config
    :return: the normalized mapping, or `v` itself when no transformation applies
    :raises ValueError: if "filetype" is set and its casefolded value is not in VALID_FILE_TYPES
    """
    if not isinstance(v, Mapping):
        return v

    file_type = v.get("filetype", "")
    if not file_type:
        # No "filetype" entry means there is nothing to re-key; the supported-type check must not
        # run here, otherwise the empty default would be rejected as an unsupported file type.
        return v

    if file_type.casefold() not in VALID_FILE_TYPES:
        raise ValueError(f"Format filetype {file_type} is not a supported file type")
    return {file_type: dict(v)}

@root_validator
def set_validation_policy(cls, values):
    """
    Replace the configured `validation_policy` key with the matching entry of `validation_policies`.

    :param values: the model's field values; `validation_policy` and `validation_policies` are read from it
    :return: `values`, with `validation_policy` overwritten by the looked-up policy
    :raises ValueError: if the configured key is not present in `validation_policies`
    """
    policies = values.get("validation_policies")
    policy_key = values.get("validation_policy")

    if policy_key in policies:
        values["validation_policy"] = policies[policy_key]
        return values

    raise ValueError(f"validation_policy must be one of {list(policies.keys())}")
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class FileBasedSourceError(Enum):
CONFIG_VALIDATION_ERROR = "Error creating stream config object."
MISSING_SCHEMA = "Expected `json_schema` in the configured catalog but it is missing."
UNDEFINED_PARSER = "No parser is defined for this file type."
UNDEFINED_VALIDATION_POLICY = "The validation policy defined in the config does not exist for the source."


class BaseFileBasedSourceError(Exception):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,11 @@ def streams(self, config: Mapping[str, Any]) -> List[AbstractFileBasedStream]:
try:
streams = []
for stream in config["streams"]:
stream_config = FileBasedStreamConfig(validation_policies=self.validation_policies, **stream)
stream_config = FileBasedStreamConfig(**stream)
if stream_config.validation_policy not in self.validation_policies:
raise ValidationError(
f"validation_policy must be one of {list(self.validation_policies.keys())}", model=FileBasedStreamConfig
)
streams.append(
DefaultFileBasedStream(
config=stream_config,
Expand All @@ -90,6 +94,7 @@ def streams(self, config: Mapping[str, Any]) -> List[AbstractFileBasedStream]:
availability_strategy=self.availability_strategy,
discovery_policy=self.discovery_policy,
parsers=self.parsers,
validation_policies=self.validation_policies,
cursor=DefaultFileBasedCursor(stream_config.max_history_size, stream_config.days_to_sync_if_history_is_full),
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
from airbyte_cdk.models import ConfiguredAirbyteCatalog, SyncMode
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig, PrimaryKeyType
from airbyte_cdk.sources.file_based.discovery_policy import AbstractDiscoveryPolicy
from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError, UndefinedParserError
from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError, RecordParseError, UndefinedParserError
from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader
from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from airbyte_cdk.sources.file_based.schema_validation_policies import AbstractSchemaValidationPolicy
from airbyte_cdk.sources.file_based.types import StreamSlice, StreamState
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy
Expand Down Expand Up @@ -42,6 +43,7 @@ def __init__(
availability_strategy: AvailabilityStrategy,
discovery_policy: AbstractDiscoveryPolicy,
parsers: Dict[str, FileTypeParser],
validation_policies: Dict[str, AbstractSchemaValidationPolicy],
):
super().__init__()
self.config = config
Expand All @@ -50,6 +52,7 @@ def __init__(
self._discovery_policy = discovery_policy
self._availability_strategy = availability_strategy
self._parsers = parsers
self._validation_policies = validation_policies

@property
@abstractmethod
Expand Down Expand Up @@ -122,7 +125,13 @@ def get_parser(self, file_type: str) -> FileTypeParser:
raise UndefinedParserError(FileBasedSourceError.UNDEFINED_PARSER, stream=self.name, file_type=file_type)

def record_passes_validation_policy(self, record: Mapping[str, Any]) -> bool:
    """
    Check `record` against the schema validation policy named in this stream's config.

    `self.config.validation_policy` is used as a key into `self._validation_policies`; the policy
    found there is asked whether the record passes, given `self._catalog_schema`.

    :param record: the record to validate
    :return: whatever the policy's `record_passes_validation_policy` returns for this record and schema
    :raises RecordParseError: if no policy is registered under the configured key
    """
    validation_policy = self._validation_policies.get(self.config.validation_policy)
    # Compare against None rather than relying on truthiness, so a policy object that happens to
    # be falsy is still used instead of being reported as undefined.
    if validation_policy is None:
        raise RecordParseError(
            FileBasedSourceError.UNDEFINED_VALIDATION_POLICY, stream=self.name, validation_policy=self.config.validation_policy
        )
    return validation_policy.record_passes_validation_policy(record=record, schema=self._catalog_schema)

@cached_property
def availability_strategy(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[Mapping
type=Type.LOG,
log=AirbyteLogMessage(
level=Level.INFO,
message=f"Stopping sync in accordance with the configured validation policy. Records in file did not conform to the schema. stream={self.name} file={file.uri} validation_policy={self.config.validation_policy.name} n_skipped={n_skipped}",
message=f"Stopping sync in accordance with the configured validation policy. Records in file did not conform to the schema. stream={self.name} file={file.uri} validation_policy={self.config.validation_policy} n_skipped={n_skipped}",
),
)
break
Expand All @@ -114,7 +114,7 @@ def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[Mapping
type=Type.LOG,
log=AirbyteLogMessage(
level=Level.INFO,
message=f"Records in file did not pass validation policy. stream={self.name} file={file.uri} n_skipped={n_skipped} validation_policy={self.config.validation_policy.name}",
message=f"Records in file did not pass validation policy. stream={self.name} file={file.uri} n_skipped={n_skipped} validation_policy={self.config.validation_policy}",
),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
#

import pytest as pytest
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig, QuotingBehavior
from airbyte_cdk.sources.file_based.schema_validation_policies import EmitRecordPolicy
from airbyte_cdk.sources.file_based.config.file_based_stream_config import CsvFormat, FileBasedStreamConfig, QuotingBehavior
from pydantic import ValidationError


Expand All @@ -27,14 +26,53 @@ def test_csv_config(file_type, input_format, expected_format, expected_error):
"file_type": file_type,
"globs": ["*"],
"validation_policy": "emit_record",
"validation_policies": {"emit_record": EmitRecordPolicy()},
"format": input_format,
"format": {
file_type: input_format
},
}

if expected_error:
with pytest.raises(expected_error):
FileBasedStreamConfig(**stream_config)
else:
actual_config = FileBasedStreamConfig(**stream_config)
assert not hasattr(actual_config.format[file_type], "filetype")
for expected_format_field, expected_format_value in expected_format.items():
assert isinstance(actual_config.format[file_type], CsvFormat)
assert getattr(actual_config.format[file_type], expected_format_field) == expected_format_value


def test_legacy_format():
    """
    The existing S3 source supplies a single `format` object that names its own `filetype`, instead of
    the current file_type -> format mapping. This test verifies that such a config is still processed
    and that the CSV options end up on a CsvFormat stored under the "csv" key.
    """
    legacy_format = {
        "filetype": "csv",
        "delimiter": "d",
        "quote_char": "q",
        "escape_char": "e",
        "encoding": "ascii",
        "double_quote": True,
        "quoting_behavior": "Quote All",
    }
    stream_config = {
        "name": "stream1",
        "file_type": "csv",
        "globs": ["*"],
        "validation_policy": "emit_record_on_schema_mismatch",
        "format": legacy_format,
    }

    parsed_config = FileBasedStreamConfig(**stream_config)
    assert isinstance(parsed_config.format["csv"], CsvFormat)

    # Each option of the single format object is expected on the parsed CsvFormat.
    expected_fields = {
        "delimiter": "d",
        "quote_char": "q",
        "escape_char": "e",
        "encoding": "ascii",
        "double_quote": True,
        "quoting_behavior": QuotingBehavior.QUOTE_ALL,
    }
    for field_name, expected_value in expected_fields.items():
        assert getattr(parsed_config.format["csv"], field_name) == expected_value
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

from airbyte_cdk.sources.file_based.exceptions import ConfigValidationError, FileBasedSourceError
from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError
from unit_tests.sources.file_based.helpers import (
FailingSchemaValidationPolicy,
TestErrorListMatchingFilesInMemoryFilesStreamReader,
Expand Down Expand Up @@ -174,8 +174,8 @@
],
}
)
.set_validation_policies(FailingSchemaValidationPolicy)
.set_expected_check_error(ConfigValidationError, FileBasedSourceError.ERROR_VALIDATING_RECORD)
.set_validation_policies({FailingSchemaValidationPolicy.ALWAYS_FAIL: FailingSchemaValidationPolicy()})
.set_expected_check_error(None, FileBasedSourceError.ERROR_VALIDATING_RECORD)
).build()


Expand Down

0 comments on commit f79aa72

Please sign in to comment.