Skip to content

Commit

Permalink
Passing the store to custom transformations
Browse files: browse the repository at this point in the history
  • Loading branch information
MatsMoll committed Oct 13, 2024
1 parent b441458 commit 9bc1f29
Show file tree
Hide file tree
Showing 11 changed files with 557 additions and 209 deletions.
7 changes: 7 additions & 0 deletions aligned/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
CustomAggregation,
List,
Embedding,
transform_polars,
transform_row,
transform_pandas,
)
from aligned.compiler.model import model_contract, FeatureInputVersions
from aligned.data_source.stream_data_source import HttpStreamSource
Expand Down Expand Up @@ -78,6 +81,10 @@
'EmbeddingModel',
'feature_view',
'model_contract',
# Transformations
'transform_polars',
'transform_row',
'transform_pandas',
# Aggregation
'CustomAggregation',
# Schemas
Expand Down
2 changes: 1 addition & 1 deletion aligned/compiler/aggregation_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def compile(self) -> Transformation:
from aligned.schemas.transformation import PolarsFunctionTransformation, PolarsLambdaTransformation

if isinstance(self.method, pl.Expr):
method = lambda df, alias: self.method # noqa: E731
method = lambda df, alias, store: self.method # noqa: E731
code = ''
return PolarsLambdaTransformation(method=dill.dumps(method), code=code, dtype=self.dtype.dtype)
else:
Expand Down
23 changes: 14 additions & 9 deletions aligned/compiler/feature_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

if TYPE_CHECKING:
from aligned.sources.s3 import AwsS3Config
from aligned.feature_store import ContractStore


class TransformationFactory:
Expand Down Expand Up @@ -527,7 +528,9 @@ def fill_na(self: T, value: FeatureFactory | Any) -> T:
return instance # type: ignore [return-value]

def transformed_using_features_pandas(
self: T, using_features: list[FeatureFactory], transformation: Callable[[pd.DataFrame], pd.Series]
self: T,
using_features: list[FeatureFactory],
transformation: Callable[[pd.DataFrame, ContractStore], pd.Series],
) -> T:
from aligned.compiler.transformation_factory import PandasTransformationFactory

Expand All @@ -536,7 +539,9 @@ def transformed_using_features_pandas(
dtype.transformation = PandasTransformationFactory(dtype, transformation, using_features or [self])
return dtype # type: ignore [return-value]

def transform_pandas(self, transformation: Callable[[pd.DataFrame], pd.Series], as_dtype: T) -> T:
def transform_pandas(
self, transformation: Callable[[pd.DataFrame, ContractStore], pd.Series], as_dtype: T
) -> T:
from aligned.compiler.transformation_factory import PandasTransformationFactory

dtype: FeatureFactory = as_dtype # type: ignore [assignment]
Expand All @@ -547,7 +552,7 @@ def transform_pandas(self, transformation: Callable[[pd.DataFrame], pd.Series],
def transformed_using_features_polars(
self: T,
using_features: list[FeatureFactory],
transformation: Callable[[pl.LazyFrame, str], pl.LazyFrame] | pl.Expr,
transformation: Callable[[pl.LazyFrame, str, ContractStore], pl.LazyFrame] | pl.Expr,
) -> T:
from aligned.compiler.transformation_factory import PolarsTransformationFactory

Expand Down Expand Up @@ -1845,8 +1850,8 @@ def percentile(self, percentile: float) -> Float:

def transform_polars(
using_features: list[FeatureFactory], return_type: T
) -> Callable[[Callable[[Any, pl.LazyFrame, str], pl.LazyFrame]], T]:
def wrapper(method: Callable[[Any, pl.LazyFrame, str], pl.LazyFrame]) -> T:
) -> Callable[[Callable[[Any, pl.LazyFrame, str, ContractStore], pl.LazyFrame]], T]:
def wrapper(method: Callable[[Any, pl.LazyFrame, str, ContractStore], pl.LazyFrame]) -> T:
return return_type.transformed_using_features_polars(
using_features=using_features, transformation=method # type: ignore
)
Expand All @@ -1856,8 +1861,8 @@ def wrapper(method: Callable[[Any, pl.LazyFrame, str], pl.LazyFrame]) -> T:

def transform_pandas(
using_features: list[FeatureFactory], return_type: T
) -> Callable[[Callable[[Any, pd.DataFrame], pd.Series]], T]:
def wrapper(method: Callable[[Any, pd.DataFrame], pd.Series]) -> T:
) -> Callable[[Callable[[Any, pd.DataFrame, ContractStore], pd.Series]], T]:
def wrapper(method: Callable[[Any, pd.DataFrame, ContractStore], pd.Series]) -> T:
return return_type.transformed_using_features_pandas(
using_features=using_features, transformation=method # type: ignore
)
Expand All @@ -1867,8 +1872,8 @@ def wrapper(method: Callable[[Any, pd.DataFrame], pd.Series]) -> T:

def transform_row(
using_features: list[FeatureFactory], return_type: T
) -> Callable[[Callable[[Any, dict[str, Any]], Any]], T]:
def wrapper(method: Callable[[Any, dict[str, Any]], Any]) -> T:
) -> Callable[[Callable[[Any, dict[str, Any], ContractStore], Any]], T]:
def wrapper(method: Callable[[Any, dict[str, Any], ContractStore], Any]) -> T:
from aligned.compiler.transformation_factory import MapRowTransformation

new_value = return_type.copy_type()
Expand Down
9 changes: 6 additions & 3 deletions aligned/compiler/transformation_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from dataclasses import dataclass, field
from datetime import timedelta # noqa: TC003
from typing import Any, Callable
from typing import TYPE_CHECKING, Any, Callable

import polars as pl

Expand All @@ -12,6 +12,9 @@
from aligned.compiler.feature_factory import FeatureFactory, Transformation, TransformationFactory
from aligned.schemas.transformation import FillNaValuesColumns, LiteralValue, EmbeddingModel

if TYPE_CHECKING:
from aligned.feature_store import ContractStore

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -683,7 +686,7 @@ def compile(self) -> Transformation:
class PandasTransformationFactory(TransformationFactory):

dtype: FeatureFactory
method: Callable[[pd.DataFrame], pd.Series]
method: Callable[[pd.DataFrame, ContractStore], pd.Series]
_using_features: list[FeatureFactory]

@property
Expand Down Expand Up @@ -739,7 +742,7 @@ def compile(self) -> Transformation:
class PolarsTransformationFactory(TransformationFactory):

dtype: FeatureFactory
method: pl.Expr | Callable[[pl.LazyFrame, pl.Expr], pl.LazyFrame]
method: pl.Expr | Callable[[pl.LazyFrame, pl.Expr, ContractStore], pl.LazyFrame]
_using_features: list[FeatureFactory]

@property
Expand Down
20 changes: 16 additions & 4 deletions aligned/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def features_for(
return entities

new_request = FeatureRequest(requests.location, requests.features_to_include, loaded_requests)
return self.features_for_request(new_request, entities, feature_names)
return self.features_for_request(new_request, entities, feature_names).inject_store(self)

def model(self, model: str | ModelContractWrapper) -> ModelFeatureStore:
"""
Expand All @@ -485,7 +485,10 @@ def model(self, model: str | ModelContractWrapper) -> ModelFeatureStore:

return ModelFeatureStore(self.models[name], self)

def vector_index(self, name: str) -> VectorIndexStore:
def vector_index(self, name: str | ModelContractWrapper) -> VectorIndexStore:
if isinstance(name, ModelContractWrapper):
name = name.location.name

return VectorIndexStore(self, self.vector_indexes[name], index_name=name)

def event_triggers_for(self, feature_view: str) -> set[EventTrigger]:
Expand Down Expand Up @@ -659,6 +662,10 @@ class MyFeatureView:
if view.name in self.feature_views:
raise ValueError(f'Feature view with name "{view.name}" already exists')

if isinstance(view.source, VectorIndex):
index_name = view.source.vector_index_name() or view.name
self.vector_indexes[index_name] = view

self.feature_views[view.name] = view
if isinstance(self.feature_source, BatchFeatureSource):
assert isinstance(self.feature_source.sources, dict)
Expand Down Expand Up @@ -1295,7 +1302,11 @@ def all_predictions(self, limit: int | None = None) -> RetrivalJob:
source = selected_source.sources[location.identifier]
request = self.model.predictions_view.request(self.model.name)

return source.all_data(request, limit=limit).select_columns(set(request.all_returned_columns))
return (
source.all_data(request, limit=limit)
.inject_store(self.store)
.select_columns(set(request.all_returned_columns))
)

def using_source(self, source: FeatureSourceable | BatchDataSource) -> ModelFeatureStore:

Expand Down Expand Up @@ -1650,6 +1661,7 @@ def all_columns(self, limit: int | None = None) -> RetrivalJob:
self.source.all_for(request, limit)
.ensure_types(request.needed_requests)
.derive_features(request.needed_requests)
.inject_store(self.store)
)
if self.feature_filter:
selected_columns = self.feature_filter
Expand All @@ -1671,7 +1683,7 @@ def between_dates(self, start_date: datetime, end_date: datetime) -> RetrivalJob
)

request = self.view.request_all
return self.source.all_between(start_date, end_date, request)
return self.source.all_between(start_date, end_date, request).inject_store(self.store)

def previous(self, days: int = 0, minutes: int = 0, seconds: int = 0) -> RetrivalJob:
end_date = datetime.utcnow()
Expand Down
52 changes: 47 additions & 5 deletions aligned/jobs/tests/test_combined_job.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,81 @@
import pytest

from aligned import feature_view, String, Bool
from aligned.compiler.model import model_contract
from aligned.feature_store import ContractStore
from aligned.sources.in_mem_source import InMemorySource
from aligned.retrival_job import CombineFactualJob, RetrivalJob, RetrivalRequest
from aligned.compiler.feature_factory import transform_polars, transform_pandas, transform_row
from aligned.compiler.feature_factory import (
Embedding,
List,
transform_polars,
transform_pandas,
transform_row,
)

import polars as pl
from aligned.lazy_imports import pandas as pd


# Model contract fixture for the tests in this module.
# Its output source is an in-memory table with three rows; every row carries a
# two-element embedding, matching `embedding_size=2` on the class below.
# `CombinedData.related_entities` looks this contract up by its name,
# 'test_embedding', through `store.vector_index(...)`.
@model_contract(
    name='test_embedding',
    input_features=[],
    output_source=InMemorySource.from_values(
        {
            'vec_id': ['a', 'b', 'c'],
            'value': ['hello there', 'no', 'something else'],
            'embedding': [[1, 2], [1, 0], [0, 9]],
        }
    ),
)
class TestEmbedding:
    # Entity column identifying each stored vector.
    vec_id = String().as_entity()
    # Free-text payload stored next to the vector.
    value = String()
    # Two-dimensional embedding column.
    embedding = Embedding(embedding_size=2)


@feature_view(source=InMemorySource.empty())
class CombinedData:
query = String()
contains_mr = query.contains('mr')
embedding = Embedding(embedding_size=2)

@transform_polars(using_features=[query], return_type=Bool())
def contains_something(self, df: pl.LazyFrame, return_value: str) -> pl.LazyFrame:
def contains_something(self, df: pl.LazyFrame, return_value: str, store: ContractStore) -> pl.LazyFrame:
return df.with_columns((pl.col('query').str.len_chars() > 5).alias(return_value))

@transform_pandas(using_features=[query], return_type=String())
def append_someting(self, df: pd.DataFrame) -> pd.Series:
def append_someting(self, df: pd.DataFrame, store: ContractStore) -> pd.Series:
return df['query'] + ' something'

@transform_row(using_features=[query], return_type=String())
def using_row(self, row: dict) -> str:
def using_row(self, row: dict, store: ContractStore) -> str:
return row['query'] + ' something'

# Row-level transformation that receives the contract store as its last argument.
# For the given row it asks the 'test_embedding' vector index for 2 records
# (presumably the closest ones to the row's embedding -- confirm against
# `nearest_n_to`) and returns their `vec_id` values as a list.
@transform_row(using_features=[embedding], return_type=List(String()))
async def related_entities(self, row: dict, store: ContractStore) -> list[str]:
    index = store.vector_index('test_embedding')
    lookup_job = index.nearest_n_to(entities=[row], number_of_records=2)
    # The job is awaited and read as a polars frame.
    neighbours = await lookup_job.to_polars()
    # Debug output kept exactly as in the original.
    print(neighbours)
    return neighbours['vec_id'].to_list()

not_contains = contains_something.not_equals(True)


@pytest.mark.asyncio
async def test_feature_view_without_entity():
store = ContractStore.empty()
store.add_model(TestEmbedding)
store.add_feature_view(CombinedData)

job = CombinedData.query().features_for({'query': ['Hello', 'Hello mr']})
job = store.feature_view(CombinedData).features_for(
{'query': ['Hello', 'Hello mr'], 'embedding': [[1, 3], [0, 10]]}
)
df = await job.to_polars()
print(df)

assert df['contains_mr'].sum() == 1

Expand Down
12 changes: 9 additions & 3 deletions aligned/local/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ async def to_lazy_polars(self) -> pl.LazyFrame:


async def aggregate(request: RetrivalRequest, core_data: pl.LazyFrame) -> pl.LazyFrame:
from aligned import ContractStore

aggregate_over = request.aggregate_over()

Expand All @@ -65,7 +66,9 @@ async def aggregate(request: RetrivalRequest, core_data: pl.LazyFrame) -> pl.Laz

exprs = []
for feat in aggregate_over[first_over]:
tran = await feat.derived_feature.transformation.transform_polars(core_data, feat.name)
tran = await feat.derived_feature.transformation.transform_polars(
core_data, feat.name, ContractStore.empty()
)

if not isinstance(tran, pl.Expr):
raise ValueError(f'Aggregation needs to be an expression, got {tran}')
Expand All @@ -86,7 +89,9 @@ async def aggregate(request: RetrivalRequest, core_data: pl.LazyFrame) -> pl.Laz
for over, features in aggregate_over.items():
exprs = []
for feat in features:
tran = await feat.derived_feature.transformation.transform_polars(core_data, feat.name)
tran = await feat.derived_feature.transformation.transform_polars(
core_data, feat.name, ContractStore.empty()
)

if not isinstance(tran, pl.Expr):
raise ValueError(f'Aggregation needs to be an expression, got {tran}')
Expand Down Expand Up @@ -321,6 +326,7 @@ async def aggregate_over(
event_timestamp_col: str,
group_by: list[str] | None = None,
) -> pl.LazyFrame:
from aligned import ContractStore

if not group_by:
group_by = ['row_id']
Expand All @@ -338,7 +344,7 @@ async def aggregate_over(
transformations = []
for feature in features:
expr = await feature.derived_feature.transformation.transform_polars(
subset, feature.derived_feature.name
subset, feature.derived_feature.name, ContractStore.empty()
)
if isinstance(expr, pl.Expr):
transformations.append(expr.alias(feature.name))
Expand Down
Loading

0 comments on commit 9bc1f29

Please sign in to comment.