Skip to content

Commit

Permalink
CrateDB: Conversational Memory
Browse files Browse the repository at this point in the history
  • Loading branch information
amotl committed Oct 29, 2024
1 parent 447c0dd commit abb2daf
Show file tree
Hide file tree
Showing 9 changed files with 314 additions and 5 deletions.
2 changes: 2 additions & 0 deletions libs/community/extended_testing_deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ chardet>=5.1.0,<6
cloudpathlib>=0.18,<0.19
cloudpickle>=2.0.0
cohere>=4,<6
crate==1.0.0dev1
databricks-vectorsearch>=0.21,<0.22
datasets>=2.15.0,<3
dgml-utils>=0.3.0,<0.4
Expand Down Expand Up @@ -76,6 +77,7 @@ requests-toolbelt>=1.0.0,<2
rspace_client>=2.5.0,<3
scikit-learn>=1.2.2,<2
simsimd>=5.0.0,<6
sqlalchemy-cratedb>=0.40.0,<1
sqlite-vss>=0.1.2,<0.2
sqlite-vec>=0.1.0,<0.2
sseclient-py>=1.8.0,<2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
from langchain_community.chat_message_histories.cosmos_db import (
CosmosDBChatMessageHistory,
)
from langchain_community.chat_message_histories.cratedb import (
CrateDBChatMessageHistory,
)
from langchain_community.chat_message_histories.dynamodb import (
DynamoDBChatMessageHistory,
)
Expand Down Expand Up @@ -94,6 +97,7 @@
"CassandraChatMessageHistory",
"ChatMessageHistory",
"CosmosDBChatMessageHistory",
"CrateDBChatMessageHistory",
"DynamoDBChatMessageHistory",
"ElasticsearchChatMessageHistory",
"FileChatMessageHistory",
Expand All @@ -120,6 +124,7 @@
"CassandraChatMessageHistory": "langchain_community.chat_message_histories.cassandra", # noqa: E501
"ChatMessageHistory": "langchain_community.chat_message_histories.in_memory",
"CosmosDBChatMessageHistory": "langchain_community.chat_message_histories.cosmos_db", # noqa: E501
"CrateDBChatMessageHistory": "langchain_community.chat_message_histories.cratedb", # noqa: E501
"DynamoDBChatMessageHistory": "langchain_community.chat_message_histories.dynamodb",
"ElasticsearchChatMessageHistory": "langchain_community.chat_message_histories.elasticsearch", # noqa: E501
"FileChatMessageHistory": "langchain_community.chat_message_histories.file",
Expand Down
109 changes: 109 additions & 0 deletions libs/community/langchain_community/chat_message_histories/cratedb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import json
import typing as t

import sqlalchemy as sa
from langchain.schema import BaseMessage, _message_to_dict, messages_from_dict

from langchain_community.chat_message_histories.sql import (
BaseMessageConverter,
SQLChatMessageHistory,
)


def create_message_model(table_name, DynamicBase):  # type: ignore
    """
    Build the SQLAlchemy message model for a caller-chosen table name.

    CrateDB-specific variant of the generic model: the primary key is a
    `BigInteger` column whose value is produced on the server by `now()`.

    TODO: Find a way to converge CrateDB's generate_random_uuid() into a variant
          returning its integer value.

    Args:
        table_name: The name of the table to use.
        DynamicBase: The base class to use for the model.

    Returns:
        The model class.
    """

    # The class statement lives inside the function because `__tablename__`
    # is only known at call time.
    class Message(DynamicBase):
        __tablename__ = table_name

        # Key is server-generated; the model itself never supplies a value.
        id = sa.Column(
            sa.BigInteger,
            primary_key=True,
            server_default=sa.func.now(),
        )
        session_id = sa.Column(sa.Text)
        message = sa.Column(sa.Text)

    return Message


class CrateDBMessageConverter(BaseMessageConverter):
    """
    Message converter used by default for CrateDB chat message histories.

    Messages are stored as JSON text, as with the generic
    `SQLChatMessageHistory` converter; the difference is that the SQLAlchemy
    model is produced by this module's `create_message_model`.
    """

    def __init__(self, table_name: str):
        # Each converter gets its own declarative base for its model class.
        base = sa.orm.declarative_base()
        self.model_class = create_message_model(table_name, base)

    def from_sql_model(self, sql_message: t.Any) -> BaseMessage:
        """Decode the JSON held in `sql_message.message` into a message."""
        payload = json.loads(sql_message.message)
        restored = messages_from_dict([payload])
        return restored[0]

    def to_sql_model(self, message: BaseMessage, session_id: str) -> t.Any:
        """Wrap `message`, encoded as JSON, in a model row for `session_id`."""
        serialized = json.dumps(_message_to_dict(message))
        return self.model_class(session_id=session_id, message=serialized)

    def get_sql_model_class(self) -> t.Any:
        """Return the model class created for this converter's table."""
        return self.model_class


class CrateDBChatMessageHistory(SQLChatMessageHistory):
    """
    Chat message history stored in CrateDB.

    It is the same as the generic `SQLChatMessageHistory` implementation, but:

    - swaps in `CrateDBMessageConverter` as the default message converter,
    - issues `REFRESH TABLE` after DML operations and after `clear()`,
    - queries messages ordered by primary key.
    """

    DEFAULT_MESSAGE_CONVERTER: t.Type[BaseMessageConverter] = CrateDBMessageConverter

    def __init__(
        self,
        session_id: str,
        connection_string: str,
        table_name: str = "message_store",
        session_id_field_name: str = "session_id",
        custom_message_converter: t.Optional[BaseMessageConverter] = None,
    ):
        """
        Set up the history and hook up CrateDB's table refresh.

        All arguments are forwarded unchanged to `SQLChatMessageHistory`.

        Args:
            session_id: Identifier of the conversation session.
            connection_string: Database connection string.
            table_name: Name of the table holding the messages.
            session_id_field_name: Name of the model attribute holding the
                session identifier.
            custom_message_converter: Converter to use instead of
                `DEFAULT_MESSAGE_CONVERTER`.
        """
        # Imported here rather than at module level, so importing this module
        # only needs `sqlalchemy`, not `sqlalchemy-cratedb`.
        from sqlalchemy_cratedb.support import refresh_after_dml

        super().__init__(
            session_id,
            connection_string,
            table_name=table_name,
            session_id_field_name=session_id_field_name,
            custom_message_converter=custom_message_converter,
        )

        # Patch dialect to invoke `REFRESH TABLE` after each DML operation.
        # Uses `session_maker`: the `Session` property is deprecated in its
        # favour and would emit a deprecation warning on every access.
        refresh_after_dml(self.session_maker)

    def _messages_query(self) -> sa.sql.Select:
        """
        Construct an SQLAlchemy selectable to query for messages.

        For CrateDB, the result is ordered by the primary key.
        """
        selectable = super()._messages_query()
        # Reset any ordering inherited from the base query first, so the
        # primary key appears in `ORDER BY` exactly once.
        return selectable.order_by(None).order_by(self.sql_model_class.id)

    def clear(self) -> None:
        """
        Clear the session's messages, then refresh the table.

        Needed for CrateDB to synchronize data because `on_flush` does not
        catch it.
        """
        from sqlalchemy_cratedb.support import refresh_table

        outcome = super().clear()
        # `session_maker` replaces the deprecated `Session` property.
        with self.session_maker() as session:
            refresh_table(session, self.sql_model_class)
        return outcome
27 changes: 23 additions & 4 deletions libs/community/langchain_community/chat_message_histories/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
List,
Optional,
Sequence,
Type,
Union,
cast,
)

from langchain_core._api import deprecated, warn_deprecated
from sqlalchemy import Column, Integer, Text, delete, select
from sqlalchemy import Column, Integer, Text, create_engine, delete, select
from sqlalchemy.sql import Select

try:
from sqlalchemy.orm import declarative_base
Expand All @@ -27,7 +29,6 @@
message_to_dict,
messages_from_dict,
)
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.ext.asyncio import (
AsyncEngine,
Expand All @@ -38,7 +39,6 @@
Session as SQLSession,
)
from sqlalchemy.orm import (
declarative_base,
scoped_session,
sessionmaker,
)
Expand All @@ -55,6 +55,10 @@
class BaseMessageConverter(ABC):
"""Convert BaseMessage to the SQLAlchemy model."""

@abstractmethod
def __init__(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError

@abstractmethod
def from_sql_model(self, sql_message: Any) -> BaseMessage:
"""Convert a SQLAlchemy model to a BaseMessage instance."""
Expand Down Expand Up @@ -146,6 +150,8 @@ class SQLChatMessageHistory(BaseChatMessageHistory):
"""

DEFAULT_MESSAGE_CONVERTER: Type[BaseMessageConverter] = DefaultMessageConverter

@property
@deprecated("0.2.2", removal="1.0", alternative="session_maker")
def Session(self) -> Union[scoped_session, async_sessionmaker]:
Expand Down Expand Up @@ -220,7 +226,9 @@ def __init__(
self.session_maker = scoped_session(sessionmaker(bind=self.engine))

self.session_id_field_name = session_id_field_name
self.converter = custom_message_converter or DefaultMessageConverter(table_name)
self.converter = custom_message_converter or self.DEFAULT_MESSAGE_CONVERTER(
table_name
)
self.sql_model_class = self.converter.get_sql_model_class()
if not hasattr(self.sql_model_class, session_id_field_name):
raise ValueError("SQL model class must have session_id column")
Expand All @@ -241,6 +249,17 @@ async def _acreate_table_if_not_exists(self) -> None:
await conn.run_sync(self.sql_model_class.metadata.create_all)
self._table_created = True

def _messages_query(self) -> Select:
    """Construct an SQLAlchemy selectable to query for messages.

    Selects the rows of the model class whose session-id attribute equals
    this history's `session_id`, ordered by ascending `id`.
    """
    model = self.sql_model_class
    # The session column name is configurable, hence the dynamic lookup.
    session_column = getattr(model, self.session_id_field_name)
    belongs_to_session = session_column == self.session_id
    return select(model).where(belongs_to_session).order_by(model.id.asc())

@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve all messages from db"""
Expand Down
Loading

0 comments on commit abb2daf

Please sign in to comment.