Skip to content

Commit

Permalink
Start executor refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
blythed committed Jan 2, 2025
1 parent a7a4094 commit 70b49d2
Show file tree
Hide file tree
Showing 76 changed files with 1,721 additions and 3,121 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add create events waiting on db apply.
- Refactor secrets loading method.
- Add db.load in db wait
- Deprecate "free" queries

#### New Features & Functionality

Expand All @@ -37,6 +38,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add a standalone flag in Streamlit to mark the page as independent.
- Add secrets directory mount for loading secret env vars.
- Remove components recursively
- Enforce strict and developer friendly query developer contract

#### Bug Fixes

Expand Down
3 changes: 1 addition & 2 deletions plugins/ibis/superduper_ibis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .data_backend import IbisDataBackend as DataBackend
from .query import IbisQuery as Query

__version__ = "0.4.7"

__all__ = ["Query", "DataBackend"]
__all__ = ["DataBackend"]
204 changes: 86 additions & 118 deletions plugins/ibis/superduper_ibis/data_backend.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,26 @@
import glob
import os
import typing as t
import uuid
from warnings import warn

import click
import ibis
import pandas
from pandas.core.frame import DataFrame
from sqlalchemy.exc import NoSuchTableError
from superduper import CFG, logging
from superduper.backends.base.data_backend import BaseDataBackend
from superduper.backends.base.metadata import MetaDataStoreProxy
from superduper.backends.base.query import Query, QueryPart
from superduper.backends.local.artifacts import FileSystemArtifactStore
from superduper.base import exceptions
from superduper.components.datatype import BaseDataType
from superduper.components.schema import Schema
from superduper.components.table import Table

from superduper_ibis.db_helper import get_db_helper
from superduper_ibis.field_types import FieldType, dtype
from superduper_ibis.query import IbisQuery
from superduper_ibis.utils import convert_schema_to_fields

BASE64_PREFIX = "base64:"
# TODO make this a global variable in main project
INPUT_KEY = "_source"


Expand Down Expand Up @@ -95,11 +93,19 @@ def __init__(self, uri: str, plugin: t.Any, flavour: t.Optional[str] = None):

self.datatype_presets = {'vector': 'superduper.ext.numpy.encoder.Array'}

if uri.startswith('snowflake://'):
if uri.startswith('snowflake://') or uri.startswith('clickhouse://'):
self.bytes_encoding = 'base64'
self.datatype_presets = {
'vector': 'superduper.components.datatype.NativeVector'
}
self.datatype_presets.update(
{'vector': 'superduper.components.datatype.NativeVector'}
)

def random_id(self):
"""Generate a random ID."""
return str(uuid.uuid4())

def to_id(self, id):
"""Convert the ID to a string."""
return str(id)

def _setup(self, conn):
self.dialect = getattr(conn, "name", "base")
Expand Down Expand Up @@ -135,116 +141,9 @@ def build_metadata(self):
logging.warn(f"Falling back to using the uri: {self.uri}.")
return MetaDataStoreProxy(SQLAlchemyMetadata(uri=self.uri))

def insert(self, table_name, raw_documents):
"""Insert data into the database.
:param table_name: The name of the table.
:param raw_documents: The data to insert.
"""
for doc in raw_documents:
for k, v in doc.items():
doc[k] = self.db_helper.convert_data_format(v)
table_name, raw_documents = self.db_helper.process_before_insert(
table_name,
raw_documents,
self.conn,
)
if not self.in_memory:
self.conn.insert(table_name, raw_documents)
else:
# CAUTION: The following is only tested with pandas.
if table_name in self.conn.tables:
t = self.conn.tables[table_name]
df = pandas.concat([t.to_pandas(), raw_documents])
self.conn.create_table(table_name, df, overwrite=True)
else:
df = pandas.DataFrame(raw_documents)
self.conn.create_table(table_name, df)

if self.conn.backend_table_type == DataFrame:
df.to_csv(os.path.join(self.name, table_name + ".csv"), index=False)

def check_ready_ids(
self, query: IbisQuery, keys: t.List[str], ids: t.Optional[t.List[t.Any]] = None
):
"""Check if all the keys are ready in the ids.
:param query: The query object.
:param keys: The keys to check.
:param ids: The ids to check.
"""
if ids:
query = query.filter(query[query.primary_id].isin(ids))
conditions = []
for key in keys:
conditions.append(query[key].notnull())

# TODO: Hotfix, will be removed by the refactor PR
try:
docs = query.filter(*conditions).select(query.primary_id).execute()
except Exception as e:
if "Table not found" in str(e) or "Can't find table" in str(e):
return []
else:
raise e
ready_ids = [doc[query.primary_id] for doc in docs]
self._log_check_ready_ids_message(ids, ready_ids)
return ready_ids

def drop_outputs(self):
def drop_table(self, table):
"""Drop the outputs."""
for table in self.conn.list_tables():
logging.info(f"Dropping table: {table}")
if CFG.output_prefix in table:
self.conn.drop_table(table)

def drop_table_or_collection(self, name: str):
"""Drop the table or collection.
Please use with caution as you will lose all data.
:param name: Table name to drop.
"""
try:
return self.db.databackend.conn.drop_table(name)
except Exception as e:
msg = "Object found is of type 'VIEW'"
if msg in str(e):
return self.db.databackend.conn.drop_view(name)
raise

def create_output_dest(
self,
predict_id: str,
datatype: t.Union[FieldType, BaseDataType],
flatten: bool = False,
):
"""Create a table for the output of the model.
:param predict_id: The identifier of the prediction.
:param datatype: The data type of the output.
:param flatten: Whether to flatten the output.
"""
# TODO: Support output schema
msg = (
"Model must have an encoder to create with the"
f" {type(self).__name__} backend."
)
assert datatype is not None, msg
if isinstance(datatype, FieldType):
output_type = dtype(datatype.identifier)
else:
output_type = datatype

fields = {
INPUT_KEY: "string",
"_source": "string",
"id": "string",
f"{CFG.output_prefix}{predict_id}": output_type,
}
return Table(
identifier=f"{CFG.output_prefix}{predict_id}",
schema=Schema(identifier=f"_schema/{predict_id}", fields=fields),
)
self.conn.drop_table(table)

def check_output_dest(self, predict_id) -> bool:
"""Check if the output destination exists.
Expand Down Expand Up @@ -310,3 +209,72 @@ def disconnect(self):
def list_tables(self):
"""List all tables or collections in the database."""
return self.conn.list_tables()

def insert(self, table, documents):
"""Insert data into the database."""
primary_id = self.primary_id(self.db[table])
for r in documents:
if primary_id not in r:
r[primary_id] = str(uuid.uuid4())
ids = [r[primary_id] for r in documents]
self.conn.insert(table, documents)
return ids

def missing_outputs(self, query, predict_id: str) -> t.List[str]:
"""Get missing outputs from the database."""
query = self._build_native_query(query)
pid = self.primary_id(query)
output_table = self.conn.table(f"{CFG.output_prefix}{predict_id}")
q = query.anti_join(output_table, output_table['_source'] == query[pid])
return q.execute().to_dict(orient='records')

def primary_id(self, query):
"""Get the primary ID of the query."""
return self.db.load('table', query.table).primary_id

def select(self, query):
"""Select data from the database."""
native_query = self._build_native_query(query)
return native_query.execute().to_dict(orient='records')

def _build_native_query(self, query):
q = self.conn.table(query.table)
pid = None

for part in query.parts:
if isinstance(part, QueryPart) and part.name != 'outputs':
args = []
for a in part.args:
if isinstance(a, Query):
args.append(self._build_native_query(a))
else:
args.append(a)
kwargs = {}
for k, v in part.kwargs.items():
if isinstance(v, Query):
kwargs[k] = self._build_native_query(v)
else:
kwargs[k] = v
if part.name == 'select' and len(args) == 0:
pass
else:
q = getattr(q, part.name)(*args, **kwargs)

elif isinstance(part, QueryPart) and part.name == 'outputs':
if pid is None:
pid = self.primary_id(query)

original_q = q
for predict_id in part.args:
output_t = self.conn.table(f"{CFG.output_prefix}{predict_id}")
q = q.join(output_t, output_t['_source'] == original_q[pid])

elif isinstance(part, str):
if part == 'primary_id':
if pid is None:
pid = self.primary_id(query)
part = pid
q = q[part]
else:
raise ValueError(f'Unknown query part: {part}')
return q
1 change: 1 addition & 0 deletions plugins/ibis/superduper_ibis/db_helper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# TODO remove, no longer relevant
import base64
import collections

Expand Down
Loading

0 comments on commit 70b49d2

Please sign in to comment.