Skip to content

Commit

Permalink
Simplify databackend contract
Browse files Browse the repository at this point in the history
  • Loading branch information
blythed committed Jan 1, 2025
1 parent a18c1aa commit a7a4094
Show file tree
Hide file tree
Showing 19 changed files with 66 additions and 256 deletions.
4 changes: 2 additions & 2 deletions plugins/ibis/superduper_ibis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .data_backend import IbisDataBackend as DataBackend
from .query import IbisQuery
from .query import IbisQuery as Query

__version__ = "0.4.7"

__all__ = ["IbisQuery", "DataBackend"]
__all__ = ["Query", "DataBackend"]
32 changes: 4 additions & 28 deletions plugins/ibis/superduper_ibis/data_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from superduper.backends.base.metadata import MetaDataStoreProxy
from superduper.backends.local.artifacts import FileSystemArtifactStore
from superduper.base import exceptions
from superduper.base.enums import DBType
from superduper.components.datatype import BaseDataType
from superduper.components.schema import Schema
from superduper.components.table import Table
Expand Down Expand Up @@ -84,12 +83,10 @@ class IbisDataBackend(BaseDataBackend):
:param flavour: Flavour of the databackend.
"""

db_type = DBType.SQL

def __init__(self, uri: str, flavour: t.Optional[str] = None):
def __init__(self, uri: str, plugin: t.Any, flavour: t.Optional[str] = None):
self.connection_callback = lambda: _connection_callback(uri, flavour)
conn, name, in_memory = self.connection_callback()
super().__init__(uri=uri, flavour=flavour)
super().__init__(uri=uri, flavour=flavour, plugin=plugin)
self.conn = conn
self.name = name
self.in_memory = in_memory
Expand All @@ -110,18 +107,10 @@ def _setup(self, conn):

def reconnect(self):
"""Reconnect to the database client."""
# Reconnect to database.
conn, _, _ = self.connection_callback()
self.conn = conn
self._setup(conn)

def get_query_builder(self, table_name):
"""Get the query builder for the data backend.
:param table_name: Which table to get the query builder for
"""
return IbisQuery(table=table_name, db=self.datalayer)

def url(self):
"""Get the URL of the database."""
return self.conn.con.url + self.name
Expand Down Expand Up @@ -302,7 +291,7 @@ def drop(self, force: bool = False):
logging.info(f"Dropping table: {table}")
self.conn.drop_table(table)

def get_table_or_collection(self, identifier):
def get_table(self, identifier):
"""Get a table or collection from the database.
:param identifier: The identifier of the table or collection.
Expand All @@ -316,21 +305,8 @@ def get_table_or_collection(self, identifier):

def disconnect(self):
"""Disconnect the client."""

# TODO: implement me

def list_tables_or_collections(self):
def list_tables(self):
"""List all tables or collections in the database."""
return self.conn.list_tables()

@staticmethod
def infer_schema(data: t.Mapping[str, t.Any], identifier: t.Optional[str] = None):
"""Infer a schema from a given data object.
:param data: The data object
:param identifier: The identifier for the schema, if None, it will be generated
:return: The inferred schema
"""
from superduper.misc.auto_schema import infer_schema

return infer_schema(data, identifier=identifier)
2 changes: 1 addition & 1 deletion plugins/ibis/superduper_ibis/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def _execute_insert(self, parent):
return ids

def _create_table_if_not_exists(self):
tables = self.db.databackend.list_tables_or_collections()
tables = self.db.databackend.list_tables()
if self.table in tables:
return
self.db.databackend.create_table_and_schema(
Expand Down
6 changes: 3 additions & 3 deletions plugins/mongodb/plugin_test/test_atlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from superduper.components.model import ObjectModel
from superduper.components.vector_index import VectorIndex

from superduper_mongodb.query import MongoQuery
from superduper_mongodb.query import MongoDBQuery

try:
client = pymongo.MongoClient(CFG.data_backend)
Expand Down Expand Up @@ -56,7 +56,7 @@ def test_setup_atlas_vector_search(atlas_search_config):
encoder=Vector(dtype='float64', shape=(16,)),
)
db = superduper()
collection = MongoQuery(table='docs')
collection = MongoDBQuery(table='docs')

vector_indexes = db.data_backend.list_vector_indexes()

Expand Down Expand Up @@ -87,7 +87,7 @@ def test_setup_atlas_vector_search(atlas_search_config):
@pytest.mark.skipif(DO_SKIP, reason='Only atlas deployments relevant.')
def test_use_atlas_vector_search(atlas_search_config):
db = superduper()
collection = MongoQuery(table='docs')
collection = MongoDBQuery(table='docs')

query = collection.like(
Document({'text': 'This is a test'}), n=5, vector_index='test-vector-index'
Expand Down
4 changes: 2 additions & 2 deletions plugins/mongodb/plugin_test/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import pytest
from superduper import CFG

from superduper_mongodb.metadata import MongoMetaDataStore
from superduper_mongodb.metadata import MongoDBMetaDataStore

DATABASE_URL = CFG.metadata_store or CFG.data_backend or "mongomock://test_db"


@pytest.fixture
def metadata():
store = MongoMetaDataStore(DATABASE_URL)
store = MongoDBMetaDataStore(DATABASE_URL)
yield store
store.drop(force=True)

Expand Down
8 changes: 4 additions & 4 deletions plugins/mongodb/plugin_test/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from superduper.components.table import Table
from superduper.ext.numpy.encoder import Array

from superduper_mongodb.query import MongoQuery
from superduper_mongodb.query import MongoDBQuery


@pytest.fixture
Expand Down Expand Up @@ -63,15 +63,15 @@ def test_mongo_schema(db, schema):


def test_select_missing_outputs(db):
docs = list(db.execute(MongoQuery(table='documents').find({}, {'_id': 1})))
docs = list(db.execute(MongoDBQuery(table='documents').find({}, {'_id': 1})))
ids = [r['_id'] for r in docs[: len(docs) // 2]]
db.execute(
MongoQuery(table='documents').update_many(
MongoDBQuery(table='documents').update_many(
{'_id': {'$in': ids}},
Document({'$set': {'_outputs__x::test_model_output::0::0': 'test'}}),
)
)
select = MongoQuery(table='documents').find({}, {'_id': 1})
select = MongoDBQuery(table='documents').find({}, {'_id': 1})
select.db = db
modified_select = select.select_ids_of_missing_outputs('x::test_model_output::0::0')
out = list(db.execute(modified_select))
Expand Down
8 changes: 4 additions & 4 deletions plugins/mongodb/superduper_mongodb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from .artifacts import MongoArtifactStore as ArtifactStore
from .artifacts import MongoDBArtifactStore as ArtifactStore
from .data_backend import MongoDBDataBackend as DataBackend
from .metadata import MongoMetaDataStore as MetaDataStore
from .query import MongoQuery
from .metadata import MongoDBMetaDataStore as MetaDataStore
from .query import MongoDBQuery as Query
from .vector_search import MongoAtlasVectorSearcher as VectorSearcher

__version__ = "0.4.5"

__all__ = [
"ArtifactStore",
"MongoQuery",
"Query",
"DataBackend",
"MetaDataStore",
"VectorSearcher",
Expand Down
2 changes: 1 addition & 1 deletion plugins/mongodb/superduper_mongodb/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from superduper_mongodb.utils import connection_callback


class MongoArtifactStore(ArtifactStore):
class MongoDBArtifactStore(ArtifactStore):
"""
Artifact store for MongoDB.
Expand Down
Loading

0 comments on commit a7a4094

Please sign in to comment.