Skip to content

Commit

Permalink
Improved handling of datetimes
Browse files Browse the repository at this point in the history
  • Loading branch information
MatsMoll committed Mar 25, 2024
1 parent 7ddf763 commit 1001fb4
Show file tree
Hide file tree
Showing 20 changed files with 915 additions and 141 deletions.
6 changes: 5 additions & 1 deletion aligned/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from aligned.feature_view import feature_view, combined_feature_view, check_schema
from aligned.schemas.text_vectoriser import EmbeddingModel
from aligned.sources.kafka import KafkaConfig
from aligned.sources.local import FileSource
from aligned.sources.local import FileSource, Directory, ParquetConfig, CsvConfig
from aligned.sources.psql import PostgreSQLConfig
from aligned.sources.redis import RedisConfig
from aligned.sources.redshift import RedshiftSQLConfig
Expand Down Expand Up @@ -66,4 +66,8 @@
'FeatureLocation',
'FeatureInputVersions',
'check_schema',
'Directory',
# File Configs
'CsvConfig',
'ParquetConfig',
]
3 changes: 3 additions & 0 deletions aligned/feature_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ async def insert(self, job: RetrivalJob, requests: list[RetrivalRequest]) -> Non
async def upsert(self, job: RetrivalJob, requests: list[RetrivalRequest]) -> None:
    """Default upsert hook: always raises ``NotImplementedError``.

    The error message names the concrete type of ``self`` so the caller can
    see which source lacks upsert support.
    """
    message = f'Upsert write is not implemented for {type(self)}.'
    raise NotImplementedError(message)

async def overwrite(self, job: RetrivalJob, requests: list[RetrivalRequest]) -> None:
    """Default overwrite hook: always raises ``NotImplementedError``.

    The error message names the concrete type of ``self`` so the caller can
    see which source lacks overwrite support.
    """
    message = f'Overwrite write is not implemented for {type(self)}.'
    raise NotImplementedError(message)


class RangeFeatureSource:
def all_for(self, request: FeatureRequest, limit: int | None = None) -> RetrivalJob:
Expand Down
18 changes: 13 additions & 5 deletions aligned/local/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,28 +124,36 @@ async def aggregate(request: RetrivalRequest, core_data: pl.LazyFrame) -> pl.Laz

def decode_timestamps(df: pl.LazyFrame, request: RetrivalRequest, formatter: DateFormatter) -> pl.LazyFrame:
    """Decode datetime columns that are not yet stored as ``pl.Datetime``.

    A column is decoded when it is present in ``df``, its current dtype is not
    ``pl.Datetime``, and it is either a datetime feature of ``request`` or the
    request's event timestamp.

    Args:
        df: The frame whose columns should be decoded.
        request: Supplies ``all_features`` and the optional ``event_timestamp``.
        formatter: Provides ``decode_polars(column)``, the decoding expression.

    Returns:
        ``df`` unchanged when nothing needs decoding, otherwise ``df`` with the
        selected columns replaced by their decoded expressions.
    """
    dtypes = dict(zip(df.columns, df.dtypes))

    def needs_decoding(name: str) -> bool:
        return name in dtypes and not isinstance(dtypes[name], pl.Datetime)

    # Keyed by column name so a column can never produce two expressions,
    # e.g. when the event timestamp is also listed among the features.
    time_zones: dict[str, str | None] = {}

    for feature in request.all_features:
        if feature.dtype.is_datetime and needs_decoding(feature.name):
            # Carry the time zone declared on the feature type; previously this
            # was always None, which left the conversion branch unreachable.
            time_zones[feature.name] = feature.dtype.datetime_timezone

    event_timestamp = request.event_timestamp
    if event_timestamp and needs_decoding(event_timestamp.name):
        # A time zone declared through a feature takes precedence.
        time_zones.setdefault(event_timestamp.name, None)

    if not time_zones:
        return df

    exprs = []
    for column, time_zone in time_zones.items():
        decoded = formatter.decode_polars(column)
        if time_zone is not None:
            decoded = decoded.dt.convert_time_zone(time_zone)
        exprs.append(decoded.alias(column))

    return df.with_columns(exprs)


@dataclass
Expand Down
25 changes: 25 additions & 0 deletions aligned/local/tests/test_directory_interfaces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from aligned import Directory, AwsS3Config, FileSource, ParquetConfig, CsvConfig


def test_directory_interfaces() -> None:
    """Check that each directory provider can build parquet, CSV and JSON sources.

    Both an ``AwsS3Config`` instance and ``FileSource`` are asked for a
    directory; a sub-directory of it must then return a non-``None`` source
    for every supported file format.
    """
    providers = [AwsS3Config('', '', '', ''), FileSource]

    parquet_settings = ParquetConfig(compression='snappy')
    csv_settings = CsvConfig(seperator=',')
    key_mapping = {'key': 'value'}

    for provider in providers:
        root: Directory = provider.directory('path')
        nested = root.sub_directory('sub_path')

        sources = (
            nested.parquet_at('test.parquet', mapping_keys=key_mapping, config=parquet_settings),
            nested.csv_at('test.csv', mapping_keys=key_mapping, csv_config=csv_settings),
            nested.json_at('test.json'),
        )

        for source in sources:
            assert source is not None
25 changes: 17 additions & 8 deletions aligned/retrival_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -1905,11 +1905,11 @@ async def to_pandas(self) -> pd.DataFrame:
~mask, other=df.loc[mask, feature.name].str.strip('"')
)

if feature.dtype == FeatureType.datetime():
if feature.dtype.is_datetime:
df[feature.name] = pd.to_datetime(df[feature.name], infer_datetime_format=True, utc=True)
elif feature.dtype == FeatureType.datetime() or feature.dtype == FeatureType.string():
elif feature.dtype == FeatureType.string():
continue
elif (feature.dtype == FeatureType.array()) or (feature.dtype == FeatureType.embedding()):
elif (feature.dtype.is_array) or (feature.dtype == FeatureType.embedding()):
import json

if df[feature.name].dtype == 'object':
Expand Down Expand Up @@ -1950,14 +1950,23 @@ async def to_lazy_polars(self) -> pl.LazyFrame:

if feature.dtype == FeatureType.bool():
df = df.with_columns(pl.col(feature.name).cast(pl.Int8).cast(pl.Boolean))
elif feature.dtype == FeatureType.datetime():
elif feature.dtype.is_datetime:

current_dtype = df.select([feature.name]).dtypes[0]

if isinstance(current_dtype, pl.Datetime):
continue
tz_value = feature.dtype.datetime_timezone

if not isinstance(current_dtype, pl.Datetime):
expr = self.date_formatter.decode_polars(feature.name)
else:
expr = pl.col(feature.name)

if tz_value and tz_value != current_dtype.time_zone:
df = df.with_columns(expr.dt.convert_time_zone(tz_value))
else:
df = df.with_columns(expr)

df = df.with_columns(self.date_formatter.decode_polars(feature.name))
elif (feature.dtype == FeatureType.array()) or (feature.dtype == FeatureType.embedding()):
elif (feature.dtype.is_array) or (feature.dtype == FeatureType.embedding()):
dtype = df.select(feature.name).dtypes[0]
if dtype == pl.Utf8:
df = df.with_columns(pl.col(feature.name).str.json_extract(pl.List(pl.Utf8)))
Expand Down
11 changes: 11 additions & 0 deletions aligned/schemas/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@ def is_numeric(self) -> bool:
def is_datetime(self) -> bool:
return self.name.startswith('datetime')

@property
def is_array(self) -> bool:
    """Whether this type's name begins with the ``array`` prefix."""
    type_name = self.name
    return type_name.startswith('array')

@property
def datetime_timezone(self) -> str | None:
    """Return the time zone encoded in a datetime type name, if any.

    The time zone is whatever follows the first ``-`` in ``self.name``.

    Returns:
        ``None`` when this is not a datetime type or when the name contains
        no ``-``; otherwise the text after the first ``-``.
    """
    if not self.is_datetime:
        return None

    # Split on the first '-' only: time zone names may contain hyphens
    # themselves (e.g. 'Etc/GMT-5', 'America/Port-au-Prince').
    _, separator, time_zone = self.name.partition('-')
    return time_zone if separator else None

@property
def python_type(self) -> type:
from datetime import date, datetime, time, timedelta
Expand Down
Loading

0 comments on commit 1001fb4

Please sign in to comment.