Improved sequential model prediction
MatsMoll committed May 1, 2024
1 parent aa96c6b commit 829cb6b
Showing 35 changed files with 761 additions and 328 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -299,6 +299,9 @@ store.add_model(MyModel)
This makes it possible to define different contracts per project or team. As a result, you can also combine different stores into one:

```python
forecasting_store = await ContractStore.from_dir("path/for/forecasting")
recommendation_store = await ContractStore.from_dir("path/for/recommendation")

combined_store = recommendation_store.combined_with(forecasting_store)
```

2 changes: 2 additions & 0 deletions aligned/__init__.py
@@ -3,6 +3,7 @@
Bool,
Entity,
EventTimestamp,
ValidFrom,
Float,
Int8,
Int16,
@@ -69,6 +70,7 @@
'Int64',
'Float',
'EventTimestamp',
'ValidFrom',
'Timestamp',
'List',
'Embedding',
22 changes: 13 additions & 9 deletions aligned/compiler/feature_factory.py
@@ -1118,10 +1118,10 @@ def dtype(self) -> FeatureType:
def aggregate(self) -> StringAggregation:
return StringAggregation(self)

def ollama_embedding(self, model: str, host_env: str | None = None) -> Embedding:
def ollama_embedding(self, model: str, embedding_size: int, host_env: str | None = None) -> Embedding:
from aligned.compiler.transformation_factory import OllamaEmbedding

feature = Embedding()
feature = Embedding(embedding_size)
feature.transformation = OllamaEmbedding(model, self, host_env)
return feature

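The `ollama_embedding` factory now takes the embedding size up front, so the resulting `Embedding` feature knows its dimension when the contract is compiled, rather than leaving it unset. A minimal sketch of the new call, assuming a view with a `String` feature (the decorator usage, `FileSource` path, and model name here are illustrative assumptions, not part of this diff):

```python
from aligned import feature_view, String, FileSource

@feature_view(name='documents', source=FileSource.csv_at('documents.csv'))
class Documents:
    doc_id = String().as_entity()
    body = String()
    # embedding_size must now match the model's output dimension
    body_embedding = body.ollama_embedding('nomic-embed-text', embedding_size=768)
```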
@@ -1156,9 +1156,8 @@ def contains(self, value: str) -> Bool:
def sentence_vector(self, model: EmbeddingModel) -> Embedding:
from aligned.compiler.transformation_factory import WordVectoriserFactory

feature = Embedding()
feature = Embedding(model.embedding_size or 0)
feature.transformation = WordVectoriserFactory(self, model)
feature.embedding_size = model.embedding_size
return feature

def embedding(self, model: EmbeddingModel) -> Embedding:
@@ -1291,20 +1290,25 @@ def event_timestamp(self) -> EventTimestampFeature:
)


ValidFrom = EventTimestamp


@dataclass
class Embedding(FeatureFactory):

sub_type: FeatureFactory
embedding_size: int | None = None
embedding_size: int
indexes: list[VectorIndexFactory] | None = None
sub_type: FeatureFactory = field(default_factory=Float)

def copy_type(self) -> Embedding:
if self.constraints and Optional() in self.constraints:
return Embedding().is_optional()
return Embedding()
return Embedding(sub_type=self.sub_type, embedding_size=self.embedding_size).is_optional()

return Embedding(sub_type=self.sub_type, embedding_size=self.embedding_size)

@property
def dtype(self) -> FeatureType:
return FeatureType.embedding()
return FeatureType.embedding(self.embedding_size or 0)

def indexed(
self,
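Two behavioural consequences of this hunk: `Embedding` now requires an explicit `embedding_size`, which flows into its `FeatureType`, and `ValidFrom` is a plain alias for `EventTimestamp`, not a new feature type. A small sketch, assuming the dataclass can be constructed directly and that `FeatureType` compares by value:

```python
from aligned import ValidFrom, EventTimestamp
from aligned.compiler.feature_factory import Embedding
from aligned.schemas.feature import FeatureType

assert ValidFrom is EventTimestamp  # alias, so existing EventTimestamp code is unaffected

emb = Embedding(embedding_size=384)
assert emb.dtype == FeatureType.embedding(384)  # the size is now part of the dtype
```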
30 changes: 16 additions & 14 deletions aligned/data_source/batch_data_source.py
Expand Up @@ -4,7 +4,6 @@

from typing import TYPE_CHECKING, Awaitable, TypeVar, Any, Callable, Coroutine
from dataclasses import dataclass
from uuid import uuid4

from mashumaro.types import SerializableType
from aligned.data_file import DataFileReference
@@ -350,6 +349,9 @@ def location_id(self) -> set[FeatureLocation]:
def depends_on(self) -> set[FeatureLocation]:
return set()

def tags(self) -> list[str]:
return [self.type_name]


@dataclass
class CustomMethodDataSource(BatchDataSource):
@@ -469,10 +471,10 @@ def multi_source_features_for(
)
source, request = requests[0]

if isinstance(source.condition, Feature):
request.features.add(source.condition)
else:
if isinstance(source.condition, DerivedFeature):
request.derived_features.add(source.condition)
else:
request.features.add(source.condition)

return source.source.features_for(facts, request).filter(source.condition)

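The branch reordering above (and in `all_data` below) matters because `DerivedFeature` evidently subclasses `Feature`: the old `isinstance(condition, Feature)` check matched derived features too, so they were registered as plain features and their derivations presumably never computed. Testing the most specific type first fixes that. A standalone illustration of the pitfall:

```python
class Feature: ...
class DerivedFeature(Feature): ...

cond = DerivedFeature()
assert isinstance(cond, Feature)  # why the old base-class check swallowed the subclass

# New order: check the subclass first
bucket = 'derived_features' if isinstance(cond, DerivedFeature) else 'features'
assert bucket == 'derived_features'
```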
@@ -497,10 +499,10 @@ def all_between_dates(

def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob:

if isinstance(self.condition, Feature):
request.features.add(self.condition)
else:
if isinstance(self.condition, DerivedFeature):
request.derived_features.add(self.condition)
else:
request.features.add(self.condition)

return (
self.source.all_data(request, limit)
@@ -801,7 +803,7 @@ class StackSource(BatchDataSource):
type_name: str = 'stack'

@property
def source_column_config(self):
def source_column_config(self): # type: ignore
from aligned.retrival_job import StackSourceColumn

if not self.source_column:
@@ -813,7 +815,7 @@ def source_column_config(self):
source_column=self.source_column,
)

def sub_request(self, request: RetrivalRequest, config) -> RetrivalRequest:
def sub_request(self, request: RetrivalRequest, config) -> RetrivalRequest: # type: ignore
return RetrivalRequest(
name=request.name,
location=request.location,
@@ -1041,7 +1043,7 @@ def data_for_request(request: RetrivalRequest, size: int) -> pl.DataFrame:
needed_features = request.features.union(request.entities)
schema = {feature.name: feature.dtype.polars_type for feature in needed_features}

exprs = []
exprs = {}

for feature in needed_features:
dtype = feature.dtype
@@ -1073,7 +1075,7 @@ def data_for_request(request: RetrivalRequest, size: int) -> pl.DataFrame:
if is_unique:
values = np.arange(0, size, dtype=dtype.pandas_type)
else:
values = np.random.random(size)
values = np.random.random(size) * 1000

if max_value is not None:
values = values * max_value
@@ -1093,12 +1095,12 @@ def data_for_request(request: RetrivalRequest, size: int) -> pl.DataFrame:
if is_optional:
values = np.where(np.random.random(size) > 0.5, values, np.NaN)

exprs.append(pl.lit(values).alias(feature.name))
exprs[feature.name] = values

return pl.DataFrame(exprs, schema=schema)

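Switching `exprs` from a list of `pl.lit(...)` expressions to a name-to-array dict means the generated values can be handed straight to the `pl.DataFrame` constructor together with the schema, and the `* 1000` widens the spread of the random values. A minimal sketch of the construction pattern:

```python
import numpy as np
import polars as pl

values = {'age': np.random.random(3) * 1000}
df = pl.DataFrame(values, schema={'age': pl.Float64})
print(df)  # one Float64 column named 'age'
```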

class DummyDataBatchSource(BatchDataSource):
class DummyDataSource(BatchDataSource):
"""
The DummyDataSource is a data source that generates random data for a given request.
This can be useful for testing and development purposes.
@@ -1123,7 +1125,7 @@ class MyView:
type_name: str = 'dummy_data'

def job_group_key(self) -> str:
return str(uuid4())
return self.type_name

@classmethod
def multi_source_features_for(
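With the rename to `DummyDataSource` and a stable `job_group_key`, dummy sources now group and cache like ordinary sources instead of getting a fresh UUID per job. A hypothetical usage sketch (the decorator and `.as_entity()` call are assumptions based on the docstring's `MyView` example, not shown in this diff):

```python
from aligned import feature_view, Int32, Float
from aligned.data_source.batch_data_source import DummyDataSource

@feature_view(name='passengers', source=DummyDataSource())
class Passengers:
    passenger_id = Int32().as_entity()
    age = Float()

# Reading from the view should now yield random, dtype-respecting data (assumption):
# df = await Passengers.query().all(limit=10).to_polars()
```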
85 changes: 85 additions & 0 deletions aligned/data_source/model_predictor.py
@@ -0,0 +1,85 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime

from aligned.feature_store import ModelFeatureStore
from aligned.request.retrival_request import RetrivalRequest
from aligned.schemas.feature import FeatureLocation, FeatureType
from aligned.schemas.model import Model
from aligned.retrival_job import RetrivalJob


@dataclass
class PredictModelSource:

store: ModelFeatureStore
type_name: str = 'pred_model_source'

@property
def model(self) -> Model:
return self.store.model

def job_group_key(self) -> str:
loc = FeatureLocation.model(self.model.name).identifier
return f"{loc}_pred"

def location_id(self) -> set[FeatureLocation]:
return {FeatureLocation.model(self.model.name)}

async def schema(self) -> dict[str, FeatureType]:
if self.model.predictions_view.source:
return await self.model.predictions_view.source.schema()
return {}

def all_data(self, request: RetrivalRequest, limit: int | None = None) -> RetrivalJob:
reqs = self.store.request()
if len(reqs.needed_requests) != 1:
raise NotImplementedError(
f'Type: {type(self)} has not implemented how to load fact data with multiple sources.'
)

location = reqs.needed_requests[0].location
if location.location != 'feature_view':
raise NotImplementedError(
f'Type: {type(self)} has not implemented how to load fact data from a {location.location} location.'
)

entities = self.store.store.feature_view(location.name).all_columns(limit=limit)
return self.store.predict_over(entities).with_request([request])

def all_between_dates(
self, request: RetrivalRequest, start_date: datetime, end_date: datetime
) -> RetrivalJob:
reqs = self.store.request()
if len(reqs.needed_requests) != 1:
raise NotImplementedError(
f'Type: {type(self)} has not implemented how to load fact data with multiple sources.'
)

location = reqs.needed_requests[0].location
if location.location != 'feature_view':
raise NotImplementedError(
f'Type: {type(self)} has not implemented how to load fact data from a {location.location} location.'
)

entities = self.store.store.feature_view(location.name).between_dates(start_date, end_date)
return self.store.predict_over(entities).with_request([request])

def features_for(self, facts: RetrivalJob, request: RetrivalRequest) -> RetrivalJob:
return self.store.predict_over(facts).with_request([request])

@classmethod
def multi_source_features_for(
cls, facts: RetrivalJob, requests: list[tuple[PredictModelSource, RetrivalRequest]]
) -> RetrivalJob:

if len(requests) != 1:
raise NotImplementedError(
f'Type: {cls} has not implemented how to load fact data with multiple sources.'
)

source, request = requests[0]
return source.features_for(facts, request)

def depends_on(self) -> set[FeatureLocation]:
return {FeatureLocation.model(self.model.name)}
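`PredictModelSource` is the piece that backs this commit's sequential-prediction improvement: it wraps a `ModelFeatureStore` so that one model's predictions can be read like any other batch source, with `all_data` and `all_between_dates` pulling the upstream feature view's rows and running `predict_over` on them. A hedged sketch of the wiring (the store layout and model name are hypothetical, and `store.model(...)` returning a `ModelFeatureStore` is assumed):

```python
from aligned import ContractStore
from aligned.data_source.model_predictor import PredictModelSource
from aligned.schemas.feature import FeatureLocation

async def wire_up() -> None:
    store = await ContractStore.from_dir('.')

    # Wrap the upstream model; reads trigger predict_over on its source rows
    source = PredictModelSource(store=store.model('forecaster'))

    assert source.depends_on() == {FeatureLocation.model('forecaster')}
    assert source.job_group_key().endswith('_pred')
```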
56 changes: 55 additions & 1 deletion aligned/exposed_model/interface.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import polars as pl
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable
from dataclasses import dataclass
from aligned.retrival_job import RetrivalJob
from aligned.schemas.codable import Codable
@@ -290,3 +290,57 @@ async def run_polars(self, values: RetrivalJob, store: ModelFeatureStore) -> pl.DataFrame:

def ab_test_model(models: list[tuple[ExposedModel, float]]) -> ABTestModel:
return ABTestModel(models=models)


@dataclass
class DillFunction(ExposedModel):

function: bytes

model_type: str = 'dill_function'

@property
def exposed_at_url(self) -> str | None:
return None

@property
def as_markdown(self) -> str:
return 'A function stored in a dill file.'

async def needed_features(self, store: ModelFeatureStore) -> list[FeatureReference]:
default = store.model.features.default_version
return store.feature_references_for(store.selected_version or default)

async def needed_entities(self, store: ModelFeatureStore) -> set[Feature]:
return store.request().request_result.entities

async def run_polars(self, values: RetrivalJob, store: ModelFeatureStore) -> pl.DataFrame:
import dill
import inspect

function = dill.loads(self.function)
if inspect.iscoroutinefunction(function):
return await function(values, store)
else:
return function(values, store)


def python_function(function: Callable[[pl.DataFrame], pl.Series]) -> DillFunction:
import dill

async def function_wrapper(values: RetrivalJob, store: ModelFeatureStore) -> pl.DataFrame:

pred_columns = store.model.predictions_view.labels()
if len(pred_columns) != 1:
raise ValueError(f"Expected exactly one prediction column, got {len(pred_columns)} columns.")

feature_request = store.features_for(values).log_each_job()
input_features = feature_request.request_result.feature_columns
features = await feature_request.to_polars()

result = features.with_columns(
function(features.select(input_features)).alias(list(pred_columns)[0].name)
)
return result

return DillFunction(function=dill.dumps(function_wrapper))
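`python_function` closes the loop for simple models: the callable is serialized with `dill`, and at serving time the wrapper loads the model's input features, applies the function, and writes the result into the single prediction column. A minimal sketch (the feature names are hypothetical):

```python
import polars as pl

def predict(features: pl.DataFrame) -> pl.Series:
    # Receives a frame holding only the model's input feature columns
    return features['age'] * 0.1 + features['fare'] * 0.01

exposed = python_function(predict)
# `exposed` is a DillFunction that can be attached to a model contract (assumption)
```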