
Commit

Fixed the use of an incorrect date formatter in Azure sources
MatsMoll committed Mar 26, 2024
1 parent 246ce6d commit bd59831
Showing 4 changed files with 63 additions and 32 deletions.
2 changes: 1 addition & 1 deletion aligned/local/job.py
@@ -151,7 +151,7 @@ def decode_timestamps(df: pl.LazyFrame, request: RetrivalRequest, formatter: Dat
exprs = []

for column, time_zone in columns:
logger.info(f'Decoding column {column} with timezone {time_zone}')
logger.info(f'Decoding column {column} using {formatter} with timezone {time_zone}')

if time_zone is None:
exprs.append(formatter.decode_polars(column).alias(column))
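The log line now names the formatter because each source can carry its own DateFormatter, and its decode_polars method is what turns a raw column into a datetime expression. Below is a minimal, self-contained sketch of that idea; StringDateFormatter and its pattern are hypothetical stand-ins, not aligned's real DateFormatter API.

```python
import polars as pl


class StringDateFormatter:
    """Hypothetical stand-in for a source-specific date formatter: it knows the
    pattern a source uses for timestamps and returns a polars expression that decodes it."""

    def __init__(self, pattern: str = "%Y-%m-%dT%H:%M:%S") -> None:
        self.pattern = pattern

    def decode_polars(self, column: str) -> pl.Expr:
        # Parse the string column into a (naive) datetime using the source's pattern.
        return pl.col(column).str.strptime(pl.Datetime, self.pattern, strict=False)


df = pl.DataFrame({"created_at": ["2024-03-26T10:00:00", "2024-03-26T11:30:00"]}).lazy()
formatter = StringDateFormatter()
decoded = df.with_columns(formatter.decode_polars("created_at").alias("created_at"))
print(decoded.collect())
```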
37 changes: 8 additions & 29 deletions aligned/retrival_job.py
@@ -1900,53 +1900,32 @@ async def to_lazy_polars(self) -> pl.LazyFrame:
if request.aggregated_features:
features_to_check.update({feature.derived_feature for feature in request.aggregated_features})

if request.event_timestamp:
features_to_check.add(request.event_timestamp.as_feature())

for feature in features_to_check:

if feature.name not in org_schema:
continue

if feature.dtype.polars_type.is_(org_schema[feature.name]):
logger.debug(f'Skipping feature {feature.name}, already correct type')
continue

if feature.dtype == FeatureType.bool():
df = df.with_columns(pl.col(feature.name).cast(pl.Int8).cast(pl.Boolean))
elif feature.dtype.is_datetime:

current_dtype = df.select([feature.name]).dtypes[0]

tz_value = feature.dtype.datetime_timezone

if not isinstance(current_dtype, pl.Datetime):
expr = self.date_formatter.decode_polars(feature.name)
else:
expr = pl.col(feature.name)

if tz_value and tz_value != current_dtype.time_zone:
df = df.with_columns(expr.dt.convert_time_zone(tz_value))
else:
df = df.with_columns(expr)

elif (feature.dtype.is_array) or (feature.dtype == FeatureType.embedding()):
dtype = df.select(feature.name).dtypes[0]
if dtype == pl.Utf8:
df = df.with_columns(pl.col(feature.name).str.json_extract(pl.List(pl.Utf8)))
elif (feature.dtype == FeatureType.json()) or feature.dtype.is_datetime:
logger.debug(f'Converting {feature.name} to {feature.dtype.name}')
pass
else:
df = df.with_columns(pl.col(feature.name).cast(feature.dtype.polars_type, strict=False))

if request.event_timestamp:
feature = request.event_timestamp
if feature.name not in df.columns:
continue
current_dtype = df.select([feature.name]).dtypes[0]

if not isinstance(current_dtype, pl.Datetime):
df = df.with_columns(
(pl.col(feature.name).cast(pl.Int64) * 1000)
.cast(pl.Datetime(time_zone='UTC'))
.alias(feature.name)
logger.debug(
f'Converting {feature.name} to {feature.dtype.name} - {feature.dtype.polars_type}'
)
df = df.with_columns(pl.col(feature.name).cast(feature.dtype.polars_type, strict=False))

return df

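The rewritten to_lazy_polars replaces the per-dtype branches shown above (the boolean round trip, the manual timezone handling, and the integer-to-UTC cast for the event timestamp) with a single path: decode datetime-typed features through self.date_formatter, then apply one cast with strict=False to the expected polars dtype. A hedged sketch of that flow, with an illustrative schema and a stand-in decode function in place of date_formatter.decode_polars:

```python
import polars as pl

# Illustrative expected schema; the real code reads polars dtypes from Feature objects.
expected = {"event_timestamp": pl.Datetime, "amount": pl.Float64}

df = pl.DataFrame(
    {"event_timestamp": ["2024-03-26 10:00:00", "2024-03-26 11:30:00"], "amount": ["1.5", "2.0"]}
).lazy()


def decode(column: str) -> pl.Expr:
    # Stand-in for self.date_formatter.decode_polars(column).
    return pl.col(column).str.strptime(pl.Datetime, "%Y-%m-%d %H:%M:%S", strict=False)


for column, dtype in expected.items():
    if dtype == pl.Datetime and df.schema[column] == pl.Utf8:
        # String timestamps go through the formatter instead of ad-hoc casts.
        df = df.with_columns(decode(column).alias(column))
    else:
        # Everything else gets one lenient cast to the expected dtype.
        df = df.with_columns(pl.col(column).cast(dtype, strict=False))

print(df.collect())
```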
54 changes: 53 additions & 1 deletion aligned/sources/azure_blob_storage.py
@@ -330,9 +330,26 @@ async def write(self, job: RetrivalJob, requests: list[RetrivalRequest]) -> None
raise ValueError(f"Only support writing on request, got {len(requests)}.")

features = requests[0].all_returned_columns
df = await job.to_polars()
df = await job.to_lazy_polars()
await self.write_polars(df.select(features))

@classmethod
def multi_source_features_for(
cls, facts: RetrivalJob, requests: list[tuple[AzureBlobCsvDataSource, RetrivalRequest]]
) -> RetrivalJob:

source = requests[0][0]
if not isinstance(source, cls):
raise ValueError(f'Only {cls} is supported, received: {source}')

# Group based on config
return FileFactualJob(
source=source,
requests=[request for _, request in requests],
facts=facts,
date_formatter=source.date_formatter,
)

def features_for(self, facts: RetrivalJob, request: RetrivalRequest) -> RetrivalJob:
return FileFactualJob(self, [request], facts, date_formatter=self.date_formatter)

@@ -430,8 +447,26 @@ async def write_pandas(self, df: pd.DataFrame) -> None:
async def write_polars(self, df: pl.LazyFrame) -> None:
url = f"az://{self.path}"
creds = self.config.read_creds()
df.collect().write_parquet(url, storage_options=creds)
df.collect().to_pandas().to_parquet(url, storage_options=creds)

@classmethod
def multi_source_features_for(
cls, facts: RetrivalJob, requests: list[tuple[AzureBlobParquetDataSource, RetrivalRequest]]
) -> RetrivalJob:

source = requests[0][0]
if not isinstance(source, cls):
raise ValueError(f'Only {cls} is supported, received: {source}')

# Group based on config
return FileFactualJob(
source=source,
requests=[request for _, request in requests],
facts=facts,
date_formatter=source.date_formatter,
)

def features_for(self, facts: RetrivalJob, request: RetrivalRequest) -> RetrivalJob:
return FileFactualJob(self, [request], facts, date_formatter=self.date_formatter)

@@ -511,6 +546,23 @@ async def schema(self) -> dict[str, FeatureType]:
except HTTPStatusError as error:
raise UnableToFindFileException() from error

@classmethod
def multi_source_features_for(
cls, facts: RetrivalJob, requests: list[tuple[AzureBlobDeltaDataSource, RetrivalRequest]]
) -> RetrivalJob:

source = requests[0][0]
if not isinstance(source, cls):
raise ValueError(f'Only {cls} is supported, received: {source}')

# Group based on config
return FileFactualJob(
source=source,
requests=[request for _, request in requests],
facts=facts,
date_formatter=source.date_formatter,
)

def features_for(self, facts: RetrivalJob, request: RetrivalRequest) -> RetrivalJob:
return FileFactualJob(self, [request], facts, date_formatter=self.date_formatter)

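Each of the three Azure sources (CSV, Parquet, Delta) now overrides multi_source_features_for so that the FileFactualJob is constructed with source.date_formatter, which is the fix the commit title describes; before this, the generic code path presumably supplied a default formatter that did not match how these sources serialise timestamps. The toy, runnable sketch below shows the shared pattern; BlobSource, FactualJob, and IsoFormatter are illustrative stand-ins, not the aligned API.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class IsoFormatter:
    pattern: str = "%Y-%m-%dT%H:%M:%S"


@dataclass
class FactualJob:
    source: "BlobSource"
    requests: list[str]
    facts: list[dict]
    date_formatter: IsoFormatter


@dataclass
class BlobSource:
    path: str
    # Each source carries the formatter matching how it serialises timestamps.
    date_formatter: IsoFormatter = field(default_factory=IsoFormatter)

    @classmethod
    def multi_source_features_for(
        cls, facts: list[dict], requests: list[tuple["BlobSource", str]]
    ) -> FactualJob:
        source = requests[0][0]
        if not isinstance(source, cls):
            raise ValueError(f"Only {cls} is supported, received: {source}")
        return FactualJob(
            source=source,
            requests=[request for _, request in requests],
            facts=facts,
            # The fix: forward the source's own formatter instead of a global default.
            date_formatter=source.date_formatter,
        )


job = BlobSource.multi_source_features_for(
    facts=[{"entity_id": 1}],
    requests=[(BlobSource("container/data.csv"), "my_feature_view")],
)
print(job.date_formatter.pattern)
```

The line that matters for the bug is date_formatter=source.date_formatter; the grouping of requests simply mirrors the diff above.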
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "aligned"
version = "0.0.89"
version = "0.0.90"
description = "A data management and lineage tool for ML applications."
authors = ["Mats E. Mollestad <mats@mollestad.no>"]
license = "Apache-2.0"
