Skip to content

Commit

Permalink
Migrated to polars 1 (#30)
Browse files Browse the repository at this point in the history
* Migrated to polars 1

* Minor improvements

* Updated read call

* another one
  • Loading branch information
MatsMoll authored Oct 28, 2024
1 parent e1f30f4 commit f845d59
Show file tree
Hide file tree
Showing 27 changed files with 575 additions and 322 deletions.
26 changes: 23 additions & 3 deletions aligned/data_source/batch_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,10 +467,24 @@ def depends_on(self) -> set[FeatureLocation]:
class FilteredDataSource(CodableBatchDataSource):

source: CodableBatchDataSource
condition: DerivedFeature | Feature | str
condition: DerivedFeature | Feature | str | bytes

type_name: str = 'subset'

@property
def polars_expression(self) -> pl.Expr:
if isinstance(self.condition, bytes):
return pl.Expr.deserialize(self.condition, format='binary')
elif isinstance(self.condition, str):
try:
return pl.Expr.deserialize(self.condition, format='json')
except:
return pl.col(self.condition)
elif isinstance(self.condition, (DerivedFeature, Feature)):
return pl.col(self.condition.name)
else:
raise ValueError(f"Unable to `{self.condition}`")

def job_group_key(self) -> str:
return f'subset/{self.source.job_group_key()}'

Expand Down Expand Up @@ -518,7 +532,7 @@ def all_between_dates(

return (
self.source.all_between_dates(request, start_date, end_date)
.filter(self.condition)
.filter(self.polars_expression)
.aggregate(request)
.derive_features([request])
)
Expand All @@ -527,12 +541,18 @@ def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob:

if isinstance(self.condition, DerivedFeature):
request.derived_features.add(self.condition)
return (
self.source.all_data(request, limit)
.aggregate(request)
.derive_features([request])
.filter(self.polars_expression)
)
elif isinstance(self.condition, Feature):
request.features.add(self.condition)

return (
self.source.all_data(request, limit)
.filter(self.condition)
.filter(self.polars_expression)
.aggregate(request)
.derive_features([request])
)
Expand Down
26 changes: 26 additions & 0 deletions aligned/data_source/model_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
class PredictModelSource(BatchDataSource):

store: ModelFeatureStore
cache_source: BatchDataSource | None = None
type_name: str = 'pred_model_source'

@property
Expand Down Expand Up @@ -77,6 +78,31 @@ def all_between_dates(
return self.store.predict_over(entities).with_request([request])

def features_for(self, facts: RetrivalJob, request: RetrivalRequest) -> RetrivalJob:
import polars as pl

if self.cache_source:
preds = self.cache_source.features_for(facts, request)

async def add_missing(df: pl.LazyFrame) -> pl.LazyFrame:
request.feature_names
full_features = df.filter(
pl.all_horizontal([pl.col(feat.name).is_not_null() for feat in request.features])
)
missing_features = df.filter(
pl.all_horizontal([pl.col(feat.name).is_not_null() for feat in request.features]).not_()
)
preds = await self.store.predict_over(
missing_features.select(request.entity_names)
).to_polars()

return (
full_features.collect()
.vstack(preds.select(full_features.columns).cast(full_features.schema)) # type: ignore
.lazy()
)

return preds.transform_polars(add_missing)

return self.store.predict_over(facts).with_request([request])

@classmethod
Expand Down
6 changes: 4 additions & 2 deletions aligned/exposed_model/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ async def function_wrapper(values: RetrivalJob, store: ModelFeatureStore) -> pl.


def openai_embedding(
model: str, batch_on_n_chunks: int | None, prompt_template: str | None = None
model: str, batch_on_n_chunks: int | None = 100, prompt_template: str | None = None
) -> ExposedModel:
"""
Returns an OpenAI embedding model.
Expand Down Expand Up @@ -469,4 +469,6 @@ class MyEmbedding:
"""
from aligned.exposed_model.openai import OpenAiEmbeddingPredictor

return OpenAiEmbeddingPredictor(model=model, prompt_template=prompt_template or '')
return OpenAiEmbeddingPredictor(
model=model, batch_on_n_chunks=batch_on_n_chunks, prompt_template=prompt_template or ''
)
14 changes: 9 additions & 5 deletions aligned/exposed_model/ollama.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from aligned.compiler.model import ModelContractWrapper
from aligned.compiler.feature_factory import (
Expand Down Expand Up @@ -108,10 +108,11 @@ class OllamaEmbeddingPredictorWithRef(ExposedModel, PromptModel):
endpoint: str
model_name: str

embedding_name: str = ''
feature_references: list[FeatureReference] = []
embedding_name: str = field(default='')
feature_references: list[FeatureReference] = field(default_factory=list)
prompt_template: str = field(default='')

precomputed_prompt_key_overwrite: str = 'full_prompt'
prompt_template: str = ''
model_type: str = 'ollama_embedding'

@property
Expand Down Expand Up @@ -328,7 +329,10 @@ async def run_polars(self, values: RetrivalJob, store: ModelFeatureStore) -> pl.
missing_cols = set(expected_cols) - set(entities.columns)
if missing_cols:
entities = (
await store.using_version(self.input_features_versions).features_for(values).to_polars()
await store.using_version(self.input_features_versions)
.features_for(values)
.with_subfeatures()
.to_polars()
)

for index, value in enumerate(entities.iter_rows(named=True)):
Expand Down
2 changes: 1 addition & 1 deletion aligned/exposed_model/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ async def embed_texts(

chunk_size += token_size

if number_of_texts - 1 > chunks[-1]:
if not chunks or number_of_texts - 1 > chunks[-1]:
chunks.append(number_of_texts - 1)

embeddings: list[list[float]] = []
Expand Down
69 changes: 64 additions & 5 deletions aligned/exposed_model/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ async def test_mlflow() -> None:

mlflow_client = MlflowClient()

with suppress(mlflow.exceptions.MlflowException):
with suppress(mlflow.MlflowException):
mlflow_client.delete_registered_model(model_name)

def predict(data):
Expand Down Expand Up @@ -200,7 +200,7 @@ class MyModelContract2:
.to_polars()
)
assert preds['other_pred'].null_count() == 1
assert not first_preds['model_version'].series_equal(preds['model_version'])
assert not first_preds['model_version'].equals(preds['model_version'])

preds = (
await store.model(MyModelContract2)
Expand All @@ -213,7 +213,7 @@ class MyModelContract2:
.to_polars()
)
assert preds['other_pred'].null_count() == 0
assert not first_preds['model_version'].series_equal(preds['model_version'])
assert not first_preds['model_version'].equals(preds['model_version'])

preds = (
await without_cache.model(MyModelContract2)
Expand All @@ -225,7 +225,7 @@ class MyModelContract2:
.to_polars()
)
assert preds['other_pred'].null_count() == 0
assert not first_preds['model_version'].series_equal(preds['model_version'])
assert not first_preds['model_version'].equals(preds['model_version'])

preds = (
await without_cache.model(MyModelContract2)
Expand All @@ -235,4 +235,63 @@ class MyModelContract2:
input_features = InputFeatureView.query().request.all_returned_columns
assert set(input_features) - set(preds.columns) == set(), 'Missing some columns'
assert preds['other_pred'].null_count() == 0
assert not first_preds['model_version'].series_equal(preds['model_version'])
assert not first_preds['model_version'].equals(preds['model_version'])


@pytest.mark.asyncio
async def test_if_is_missing() -> None:
@feature_view(
name='input',
source=InMemorySource.from_values(
{'entity_id': ['a', 'b', 'c'], 'x': [1, 2, 3], 'other': [9, 8, 7]} # type: ignore
),
)
class InputFeatureView:
entity_id = String().as_entity()
x = Int32()
other = Int32()

input = InputFeatureView()

@model_contract(
input_features=[InputFeatureView().x],
exposed_model=python_function(lambda df: df['x'] * 2),
output_source=InMemorySource.from_values(
{'entity_id': ['a', 'b'], 'prediction': [4, 4]} # type: ignore
),
)
class MyModelContract:
entity_id = String().as_entity()

prediction = input.x.as_regression_target()

model_version = String().as_model_version()

@model_contract(
input_features=[InputFeatureView().x, MyModelContract().prediction],
exposed_model=python_function(lambda df: df['prediction'] * 3 + df['x']),
)
class MyModelContract2:
entity_id = String().as_entity()

other_pred = input.other.as_regression_target()

model_version = String().as_model_version()

store = ContractStore.empty()
store.add_view(InputFeatureView)
store.add_model(MyModelContract)
store.add_model(MyModelContract2)

predict_when_missing = store.predict_when_missing()
preds = (
await predict_when_missing.model(MyModelContract2)
.predict_over(
{
'entity_id': ['a', 'c'],
}
)
.to_polars()
)
assert preds['other_pred'].null_count() == 0
assert preds['prediction'].null_count() == 0
80 changes: 79 additions & 1 deletion aligned/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,30 @@ def dummy_store(self) -> ContractStore:
self.feature_views, self.models, BatchFeatureSource(sources), self.vector_indexes
)

def source_for(self, location: FeatureLocation) -> BatchDataSource | None:
if not isinstance(self.feature_source, BatchFeatureSource):
return None
return self.feature_source.sources.get(location.identifier)

def predict_when_missing(self) -> ContractStore:
from aligned.data_source.model_predictor import PredictModelSource

new_store = self

for model_name, model in self.models.items():
if not model.exposed_model:
continue

new_store = new_store.update_source_for(
FeatureLocation.model(model_name),
PredictModelSource(
new_store.model(model_name),
cache_source=self.source_for(FeatureLocation.model(model_name)),
),
)

return new_store

def without_model_cache(self) -> ContractStore:
from aligned.data_source.model_predictor import PredictModelSource

Expand Down Expand Up @@ -677,6 +701,55 @@ class MyFeatureView:
view.materialized_source or view.source
)

def remove(self, location: str | FeatureLocation) -> None:
"""
Removing a feature view or a model contract from the store.
```python
store.remove("feature_view:titanic")
# or
location = FeatureLocation.feature_view("titanic")
store.remove(location)
```
Args:
location (str | FeatureLocation): The contract to remove
"""
if isinstance(location, str):
location = FeatureLocation.from_string(location)

if location.location_type == 'feature_view':
del self.feature_views[location.name]
else:
del self.models[location.name]

if not isinstance(self.feature_source, BatchFeatureSource):
return
if not isinstance(self.feature_source.sources, dict):
return
del self.feature_source.sources[location.identifier]

def add(self, contract: FeatureViewWrapper | ModelContractWrapper) -> None:
"""
Adds a feature view or a model contract
```python
@feature_view(...)
class MyFeatures:
feature_id = String().as_entity()
feature = Int32()
store.add(MyFeatures)
```
Args:
contract (FeatureViewWrapper | ModelContractWrappe): The contract to add
"""
if isinstance(contract, FeatureViewWrapper):
self.add_feature_view(contract)
else:
self.add_model(contract)

def add_feature_view(self, feature_view: FeatureView | FeatureViewWrapper | CompiledFeatureView) -> None:
if isinstance(feature_view, FeatureViewWrapper):
self.add_compiled_view(feature_view.compile())
Expand Down Expand Up @@ -1389,7 +1462,12 @@ class NewModel:
```
"""
return {req.location for req in self.request().needed_requests}
locs = {req.location for req in self.request().needed_requests}
label_refs = self.model.predictions_view.labels_estimates_refs()
if label_refs:
for ref in label_refs:
locs.add(ref.location)
return locs

async def upsert_predictions(self, predictions: ConvertableToRetrivalJob | RetrivalJob) -> None:
"""
Expand Down
5 changes: 3 additions & 2 deletions aligned/local/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class LiteralRetrivalJob(RetrivalJob):

def __init__(self, df: pl.LazyFrame | pd.DataFrame, requests: list[RetrivalRequest]) -> None:
self.requests = requests

if isinstance(df, pl.DataFrame):
self.df = df.lazy()
elif isinstance(df, pl.LazyFrame):
Expand Down Expand Up @@ -113,10 +114,10 @@ async def aggregate(request: RetrivalRequest, core_data: pl.LazyFrame) -> pl.Laz
.with_columns(pl.col(time_name) + over.window.time_window)
).filter(pl.col(time_name) <= sorted_data.select(pl.col(time_name).max()).collect()[0, 0])
else:
sub = sorted_data.group_by_rolling(
sub = sorted_data.rolling(
time_name,
period=over.window.time_window,
by=over.group_by_names,
group_by=over.group_by_names,
).agg(exprs)

if over.window.offset_interval:
Expand Down
2 changes: 1 addition & 1 deletion aligned/psql/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ async def to_pandas(self) -> pd.DataFrame:

async def to_lazy_polars(self) -> pl.LazyFrame:
try:
return pl.read_database(self.query, self.config.url).lazy()
return pl.read_database_uri(self.query, self.config.url).lazy()
except Exception as e:
logger.error(f'Error running query: {self.query}')
logger.error(f'Error: {e}')
Expand Down
Loading

0 comments on commit f845d59

Please sign in to comment.