added a way to filter views
Mats E. Mollestad committed Oct 31, 2023
1 parent 5031038 commit 9972f89
Showing 8 changed files with 309 additions and 43 deletions.
7 changes: 6 additions & 1 deletion aligned/compiler/model.py
@@ -69,6 +69,7 @@ class ModelMetadata:
description: str | None = field(default=None)
predictions_source: BatchDataSource | None = field(default=None)
predictions_stream: StreamDataSource | None = field(default=None)
historical_source: BatchDataSource | None = field(default=None)
dataset_folder: Folder | None = field(default=None)


@@ -95,6 +96,7 @@ def model_contract(
description: str | None = None,
predictions_source: BatchDataSource | None = None,
predictions_stream: StreamDataSource | None = None,
historical_source: BatchDataSource | None = None,
dataset_folder: Folder | None = None,
) -> Callable[[Type[T]], ModelContractWrapper[T]]:
def decorator(cls: Type[T]) -> ModelContractWrapper[T]:
@@ -106,6 +108,7 @@ def decorator(cls: Type[T]) -> ModelContractWrapper[T]:
description=description,
predictions_source=predictions_source,
predictions_stream=predictions_stream,
historical_source=historical_source,
dataset_folder=dataset_folder,
)
return ModelContractWrapper(metadata, cls)
@@ -123,6 +126,7 @@ def metadata_with(
tags: dict[str, str] | None = None,
predictions_source: BatchDataSource | None = None,
predictions_stream: StreamDataSource | None = None,
historical_source: BatchDataSource | None = None,
dataset_folder: Folder | None = None,
) -> ModelMetadata:
return ModelMetadata(
@@ -133,7 +137,8 @@ def metadata_with(
description,
predictions_source,
predictions_stream,
dataset_folder,
historical_source=historical_source,
dataset_folder=dataset_folder,
)

@abstractproperty
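For orientation, a minimal sketch of how the new historical_source argument might be passed when declaring a model contract. The contract class, the FileSource.parquet_at helper, the file paths, and the top-level import path are assumptions for illustration; only the historical_source keyword itself comes from this commit, and any other required contract arguments are omitted.

from aligned import model_contract, FileSource  # assumed import path

@model_contract(
    # ... name, features and other required arguments omitted for brevity ...
    predictions_source=FileSource.parquet_at('taxi_predictions.parquet'),        # assumed helper and path
    historical_source=FileSource.parquet_at('historical_predictions.parquet'),   # new in this commit
)
class TaxiModel:
    ...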
81 changes: 72 additions & 9 deletions aligned/data_source/batch_data_source.py
@@ -2,10 +2,12 @@

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, TypeVar, Any
from dataclasses import dataclass

from mashumaro.types import SerializableType

from aligned.schemas.codable import Codable
from aligned.schemas.derivied_feature import DerivedFeature
from aligned.schemas.feature import EventTimestamp, Feature

if TYPE_CHECKING:
@@ -27,14 +29,18 @@ def __init__(self) -> None:
from aligned.sources.redshift import RedshiftSQLDataSource
from aligned.sources.s3 import AwsS3CsvDataSource, AwsS3ParquetDataSource

self.supported_data_sources = {
PostgreSQLDataSource.type_name: PostgreSQLDataSource,
ParquetFileSource.type_name: ParquetFileSource,
CsvFileSource.type_name: CsvFileSource,
AwsS3CsvDataSource.type_name: AwsS3CsvDataSource,
AwsS3ParquetDataSource.type_name: AwsS3ParquetDataSource,
RedshiftSQLDataSource.type_name: RedshiftSQLDataSource,
}
source_types = [
PostgreSQLDataSource,
ParquetFileSource,
CsvFileSource,
AwsS3CsvDataSource,
AwsS3ParquetDataSource,
RedshiftSQLDataSource,
JoinDataSource,
FilteredDataSource,
]

self.supported_data_sources = {source.type_name: source for source in source_types}

@classmethod
def shared(cls) -> BatchDataSourceFactory:
@@ -47,6 +53,14 @@ def shared(cls) -> BatchDataSourceFactory:
T = TypeVar('T')


class BatchSourceModification:

source: BatchDataSource

def wrap_job(self, job: RetrivalJob) -> RetrivalJob:
raise NotImplementedError()


class BatchDataSource(ABC, Codable, SerializableType):
"""
A definition to where a specific pice of data can be found.
@@ -112,6 +126,8 @@ def _deserialize(cls, value: dict) -> BatchDataSource:
return data_class.from_dict(value)

def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob:
if isinstance(self, BatchSourceModification):
return self.wrap_job(self.source.all_data(request, limit))
raise NotImplementedError()

def all_between_dates(
@@ -120,13 +136,22 @@ def all_between_dates(
start_date: datetime,
end_date: datetime,
) -> RetrivalJob:
if isinstance(self, BatchSourceModification):
return self.wrap_job(self.source.all_between_dates(request, start_date, end_date))
raise NotImplementedError()

@classmethod
def multi_source_features_for(
cls: type[T], facts: RetrivalJob, requests: list[tuple[T, RetrivalRequest]]
) -> RetrivalJob:
raise NotImplementedError()
if len(requests) != 1:
raise NotImplementedError()

source, _ = requests[0]
if not isinstance(source, BatchSourceModification):
raise NotImplementedError()

return source.wrap_job(type(source.source).multi_source_features_for(facts, requests))

def features_for(self, facts: RetrivalJob, request: RetrivalRequest) -> RetrivalJob:
return type(self).multi_source_features_for(facts, [(self, request)])
@@ -146,6 +171,8 @@ async def schema(self) -> dict[str, FeatureFactory]:
Returns:
dict[str, FeatureType]: A dictionary containing the column name and the feature type
"""
if isinstance(self, BatchSourceModification):
return await self.source.schema()
raise NotImplementedError(f'`schema()` is not implemented for {type(self)}.')

async def feature_view_code(self, view_name: str) -> str:
@@ -197,6 +224,42 @@ async def freshness(self, event_timestamp: EventTimestamp) -> datetime | None:
raise NotImplementedError(f'Freshness is not implemented for {type(self)}.')


@dataclass
class FilteredDataSource(BatchSourceModification, BatchDataSource):

source: BatchDataSource
condition: DerivedFeature | Feature

type_name: str = 'subset'

def job_group_key(self) -> str:
return f'subset/{self.source.job_group_key()}'

def wrap_job(self, job: RetrivalJob) -> RetrivalJob:
return job.filter(self.condition)


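To make the new dispatch concrete: FilteredDataSource is a BatchSourceModification, so BatchDataSource.all_data, all_between_dates and multi_source_features_for build the wrapped source's job and pass it through wrap_job, which here applies job.filter(condition). A minimal sketch under stated assumptions; the helper function and its arguments are illustrative, not part of this commit:

from aligned.data_source.batch_data_source import BatchDataSource, FilteredDataSource
from aligned.schemas.feature import Feature

def filtered(source: BatchDataSource, condition: Feature) -> FilteredDataSource:
    """Wrap `source` so every retrieval job it produces is row-filtered on `condition`."""
    return FilteredDataSource(source=source, condition=condition)

# filtered(raw_source, is_valid).all_data(request, limit=None) resolves to
# raw_source.all_data(request, limit=None).filter(is_valid) via wrap_job.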
@dataclass
class JoinDataSource(BatchSourceModification, BatchDataSource):

source: BatchDataSource
right_source: BatchDataSource
right_request: RetrivalRequest
left_on: str
right_on: str
method: str

type_name: str = 'join'

def job_group_key(self) -> str:
return f'join/{self.source.job_group_key()}'

def wrap_job(self, job: RetrivalJob) -> RetrivalJob:

right_job = self.right_source.all_data(self.right_request, limit=None)
return job.join(right_job, self.method, (self.left_on, self.right_on))

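JoinDataSource follows the same pattern: wrap_job reads all rows from right_source and joins them onto the wrapped job on (left_on, right_on). A hedged sketch; the helper function and the 'inner' default are illustrative, and the set of values accepted for method is not specified by this diff:

from aligned.data_source.batch_data_source import BatchDataSource, JoinDataSource
from aligned.request.retrival_request import RetrivalRequest

def joined(
    left: BatchDataSource,
    right: BatchDataSource,
    right_request: RetrivalRequest,
    left_on: str,
    right_on: str,
    method: str = 'inner',  # assumed to be a join method accepted by RetrivalJob.join
) -> JoinDataSource:
    """Wrap `left` so its jobs are joined with all data from `right` on (left_on, right_on)."""
    return JoinDataSource(
        source=left,
        right_source=right,
        right_request=right_request,
        left_on=left_on,
        right_on=right_on,
        method=method,
    )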

class ColumnFeatureMappable:
mapping_keys: dict[str, str]

36 changes: 25 additions & 11 deletions aligned/feature_store.py
@@ -24,7 +24,7 @@
from aligned.feature_view.feature_view import FeatureView
from aligned.request.retrival_request import FeatureRequest, RetrivalRequest
from aligned.retrival_job import (
FilterJob,
SelectColumnsJob,
RetrivalJob,
StreamAggregationJob,
SupervisedJob,
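The import rename above (FilterJob to SelectColumnsJob, mirrored by .filter(...) becoming .select_columns(...) throughout this file) frees `filter` to mean row filtering on a condition, which is what FilteredDataSource.wrap_job relies on. A hedged contrast; the job, the column names and the is_valid feature are illustrative placeholders:

from aligned.retrival_job import RetrivalJob
from aligned.schemas.feature import Feature

def project_and_filter(job: RetrivalJob, is_valid: Feature) -> RetrivalJob:
    """Illustrative only: project to two assumed columns, then row-filter on `is_valid`."""
    projected = job.select_columns(['passenger_count', 'trip_distance'])  # was job.filter([...]) before this commit
    return projected.filter(is_valid)                                     # `filter` now means row filtering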
@@ -267,7 +267,7 @@ def features_for_request(
):
raise ValueError(f'Missing {self.event_timestamp_column} in entities')

return self.feature_source.features_for(entity_request, requests).filter(feature_names)
return self.feature_source.features_for(entity_request, requests).select_columns(feature_names)

def features_for(
self, entities: ConvertableToRetrivalJob | RetrivalJob, features: list[str]
@@ -676,7 +676,7 @@ def features_for(self, entities: ConvertableToRetrivalJob | RetrivalJob) -> Retr
if subset_request.request_result.feature_columns != request.request_result.feature_columns:
job = job.derive_features(request.needed_requests)

return job.filter(request.features_to_include)
return job.select_columns(request.features_to_include)

async def freshness(self) -> dict[FeatureLocation, datetime]:
from aligned.schemas.feature import EventTimestamp
@@ -754,22 +754,22 @@ def cached_at(self, location: DataFileReference) -> RetrivalJob:
features = {f'{feature.location.identifier}:{feature.name}' for feature in self.model.features}
request = self.store.requests_for(RawStringFeatureRequest(features))

return FileFullJob(location, RetrivalRequest.unsafe_combine(request.needed_requests)).filter(
return FileFullJob(location, RetrivalRequest.unsafe_combine(request.needed_requests)).select_columns(
request.features_to_include
)

def process_features(self, input: RetrivalJob | ConvertableToRetrivalJob) -> RetrivalJob:
request = self.request()

if isinstance(input, RetrivalJob):
job = input.filter(request.features_to_include)
job = input.select_columns(request.features_to_include)
else:
job = RetrivalJob.from_convertable(input, request=request.needed_requests)

return (
job.ensure_types(request.needed_requests)
.derive_features(request.needed_requests)
.filter(request.features_to_include)
.select_columns(request.features_to_include)
)

def predictions_for(self, entities: ConvertableToRetrivalJob | RetrivalJob) -> RetrivalJob:
@@ -958,7 +958,7 @@ def features_for(self, entities: ConvertableToRetrivalJob | RetrivalJob) -> Supe
)
job = self.store.features_for_request(total_request, entities, total_request.features_to_include)
return SupervisedJob(
job.filter(total_request.features_to_include),
job.select_columns(total_request.features_to_include),
target_columns=targets,
)

@@ -1100,7 +1100,7 @@ def all(self, limit: int | None = None) -> RetrivalJob:
.derive_features(request.needed_requests)
)
if self.feature_filter:
return FilterJob(include_features=self.feature_filter, job=job)
return SelectColumnsJob(include_features=self.feature_filter, job=job)
else:
return job

@@ -1112,7 +1112,9 @@ def between_dates(self, start_date: datetime, end_date: datetime) -> RetrivalJob

if self.feature_filter:
request = self.view.request_for(self.feature_filter)
return FilterJob(self.feature_filter, self.source.all_between(start_date, end_date, request))
return SelectColumnsJob(
self.feature_filter, self.source.all_between(start_date, end_date, request)
)

request = self.view.request_all
return self.source.all_between(start_date, end_date, request)
@@ -1135,7 +1137,7 @@ def features_for(self, entities: ConvertableToRetrivalJob | RetrivalJob) -> Retr

job = self.source.features_for(entity_job, request)
if self.feature_filter:
return job.filter(self.feature_filter)
return job.select_columns(self.feature_filter)
else:
return job

@@ -1194,7 +1196,7 @@ async def write(self, values: ConvertableToRetrivalJob) -> None:
job = job.derive_features([request])

if self.feature_filter:
job = job.filter(self.feature_filter)
job = job.select_columns(self.feature_filter)

await self.batch_write(job)

@@ -1243,3 +1245,15 @@ async def batch_write(self, values: ConvertableToRetrivalJob | RetrivalJob) -> N

with feature_view_write_time.labels(self.view.name).time():
await self.source.write(job, job.retrival_requests)

async def freshness(self) -> datetime:

view = self.view
if not view.event_timestamp:
raise ValueError(
f"View named '{view.name}' have no event timestamp. Therefore, unable to compute freshness"
)

location = FeatureLocation.feature_view(view.name)

return (await self.source.freshness_for({location: view.event_timestamp}))[location]
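The new view-level freshness helper resolves the view's event timestamp through the underlying source's freshness_for. A minimal usage sketch; the view_store argument, the store.feature_view('taxi') accessor and the view name are assumptions, since the enclosing class is not visible in this hunk:

import asyncio

async def print_freshness(view_store) -> None:
    """Print when the view's batch source last saw data, based on its event timestamp."""
    print('latest event timestamp:', await view_store.freshness())

# e.g. asyncio.run(print_freshness(store.feature_view('taxi')))  # assumed accessor and view name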