Skip to content

Commit

Permalink
CrateDB: Conversational Memory
Browse files Browse the repository at this point in the history
  • Loading branch information
amotl committed Nov 6, 2024
1 parent 90189f5 commit 1ebbde1
Show file tree
Hide file tree
Showing 9 changed files with 311 additions and 5 deletions.
1 change: 1 addition & 0 deletions libs/community/extended_testing_deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ requests-toolbelt>=1.0.0,<2
rspace_client>=2.5.0,<3
scikit-learn>=1.2.2,<2
simsimd>=5.0.0,<6
sqlalchemy-cratedb>=0.40.1,<1
sqlite-vss>=0.1.2,<0.2
sqlite-vec>=0.1.0,<0.2
sseclient-py>=1.8.0,<2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
from langchain_community.chat_message_histories.cosmos_db import (
CosmosDBChatMessageHistory,
)
from langchain_community.chat_message_histories.cratedb import (
CrateDBChatMessageHistory,
)
from langchain_community.chat_message_histories.dynamodb import (
DynamoDBChatMessageHistory,
)
Expand Down Expand Up @@ -94,6 +97,7 @@
"CassandraChatMessageHistory",
"ChatMessageHistory",
"CosmosDBChatMessageHistory",
"CrateDBChatMessageHistory",
"DynamoDBChatMessageHistory",
"ElasticsearchChatMessageHistory",
"FileChatMessageHistory",
Expand All @@ -120,6 +124,7 @@
"CassandraChatMessageHistory": "langchain_community.chat_message_histories.cassandra", # noqa: E501
"ChatMessageHistory": "langchain_community.chat_message_histories.in_memory",
"CosmosDBChatMessageHistory": "langchain_community.chat_message_histories.cosmos_db", # noqa: E501
"CrateDBChatMessageHistory": "langchain_community.chat_message_histories.cratedb", # noqa: E501
"DynamoDBChatMessageHistory": "langchain_community.chat_message_histories.dynamodb",
"ElasticsearchChatMessageHistory": "langchain_community.chat_message_histories.elasticsearch", # noqa: E501
"FileChatMessageHistory": "langchain_community.chat_message_histories.file",
Expand Down
107 changes: 107 additions & 0 deletions libs/community/langchain_community/chat_message_histories/cratedb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import json
import typing as t

import sqlalchemy as sa
from langchain.schema import BaseMessage, _message_to_dict, messages_from_dict

from langchain_community.chat_message_histories.sql import (
BaseMessageConverter,
SQLChatMessageHistory,
)


def create_message_model(table_name: str, DynamicBase: t.Any) -> t.Any:
    """
    Create a message model for a given table name.

    This is a specialized version for CrateDB: the `id` primary key is declared
    as a `BigInteger` column with a server-side default of `now()`.

    Args:
        table_name: The name of the table to use.
        DynamicBase: The base class to use for the model.

    Returns:
        The model class.
    """

    # Model is declared inside a function to be able to use a dynamic table name.
    class Message(DynamicBase):  # type: ignore[valid-type, misc]
        __tablename__ = table_name
        # The key value is not supplied by the application; the database
        # evaluates `now()` when a row is inserted without an explicit id.
        id = sa.Column(sa.BigInteger, primary_key=True, server_default=sa.func.now())
        session_id = sa.Column(sa.Text)
        message = sa.Column(sa.Text)

    return Message


class CrateDBMessageConverter(BaseMessageConverter):
    """
    The default message converter for `CrateDBChatMessageHistory`.

    It is the same as the generic `SQLChatMessageHistory` converter,
    but swaps in a different `create_message_model` function.
    Messages are stored as JSON text in the `message` column of the model.
    """

    def __init__(self, table_name: str):
        """Build the model class for `table_name` on a fresh declarative base."""
        dynamic_base = sa.orm.declarative_base()
        self.model_class = create_message_model(table_name, dynamic_base)

    def from_sql_model(self, sql_message: t.Any) -> BaseMessage:
        """Decode the JSON stored on `sql_message.message` into a message object."""
        payload = json.loads(sql_message.message)
        restored = messages_from_dict([payload])
        return restored[0]

    def to_sql_model(self, message: BaseMessage, session_id: str) -> t.Any:
        """Serialize `message` to JSON and wrap it in a model row for `session_id`."""
        serialized = json.dumps(_message_to_dict(message))
        return self.model_class(session_id=session_id, message=serialized)

    def get_sql_model_class(self) -> t.Any:
        """Return the model class created in the constructor."""
        return self.model_class


class CrateDBChatMessageHistory(SQLChatMessageHistory):
    """
    Chat message history backed by CrateDB.

    It is the same as the generic `SQLChatMessageHistory` implementation,
    but swaps in a different message converter by default, and invokes
    `REFRESH TABLE` after DML operations to synchronize data.
    """

    DEFAULT_MESSAGE_CONVERTER: t.Type[BaseMessageConverter] = CrateDBMessageConverter

    def __init__(
        self,
        session_id: str,
        connection_string: str,
        table_name: str = "message_store",
        session_id_field_name: str = "session_id",
        custom_message_converter: t.Optional[BaseMessageConverter] = None,
    ):
        """
        Set up the generic SQL history, then patch its sessions for CrateDB.

        All arguments are forwarded unchanged to `SQLChatMessageHistory`.
        """
        # Imported lazily so that `sqlalchemy-cratedb` is only required
        # when this class is actually instantiated.
        from sqlalchemy_cratedb.support import refresh_after_dml

        super().__init__(
            session_id,
            connection_string,
            table_name=table_name,
            session_id_field_name=session_id_field_name,
            custom_message_converter=custom_message_converter,
        )

        # Patch dialect to invoke `REFRESH TABLE` after each DML operation.
        # Use `session_maker` directly: the `Session` property of the base
        # class is deprecated in its favour.
        refresh_after_dml(self.session_maker)

    def _messages_query(self) -> sa.sql.Select:
        """
        Construct an SQLAlchemy selectable to query for messages.
        For CrateDB, add an `ORDER BY` clause on the primary key.
        """
        # NOTE(review): the generic `_messages_query` appears to order by `id`
        # already; confirm whether this extra clause is still needed.
        selectable = super()._messages_query()
        selectable = selectable.order_by(self.sql_model_class.id)
        return selectable

    def clear(self) -> None:
        """
        Clear the session's messages, then refresh the table.

        Needed for CrateDB to synchronize data because `on_flush` does not catch it.
        """
        from sqlalchemy_cratedb.support import refresh_table

        super().clear()
        with self.session_maker() as session:
            refresh_table(session, self.sql_model_class)
27 changes: 23 additions & 4 deletions libs/community/langchain_community/chat_message_histories/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
List,
Optional,
Sequence,
Type,
Union,
cast,
)

from langchain_core._api import deprecated, warn_deprecated
from sqlalchemy import Column, Integer, Text, delete, select
from sqlalchemy import Column, Integer, Text, create_engine, delete, select
from sqlalchemy.sql import Select

try:
from sqlalchemy.orm import declarative_base
Expand All @@ -27,7 +29,6 @@
message_to_dict,
messages_from_dict,
)
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.ext.asyncio import (
AsyncEngine,
Expand All @@ -38,7 +39,6 @@
Session as SQLSession,
)
from sqlalchemy.orm import (
declarative_base,
scoped_session,
sessionmaker,
)
Expand All @@ -55,6 +55,10 @@
class BaseMessageConverter(ABC):
"""Convert BaseMessage to the SQLAlchemy model."""

def __init__(self, *args: Any, **kwargs: Any) -> None:
    """
    Declare the constructor signature shared by message converters.

    Deliberately not abstract: an abstract `__init__` would make every
    subclass that does not define its own constructor impossible to
    instantiate. The permissive signature is kept so that a converter
    class can be called with arguments such as a table name.
    Subclasses are expected to override this; the base implementation
    ignores its arguments.
    """

@abstractmethod
def from_sql_model(self, sql_message: Any) -> BaseMessage:
"""Convert a SQLAlchemy model to a BaseMessage instance."""
Expand Down Expand Up @@ -146,6 +150,8 @@ class SQLChatMessageHistory(BaseChatMessageHistory):
"""

DEFAULT_MESSAGE_CONVERTER: Type[BaseMessageConverter] = DefaultMessageConverter

@property
@deprecated("0.2.2", removal="1.0", alternative="session_maker")
def Session(self) -> Union[scoped_session, async_sessionmaker]:
Expand Down Expand Up @@ -220,7 +226,9 @@ def __init__(
self.session_maker = scoped_session(sessionmaker(bind=self.engine))

self.session_id_field_name = session_id_field_name
self.converter = custom_message_converter or DefaultMessageConverter(table_name)
self.converter = custom_message_converter or self.DEFAULT_MESSAGE_CONVERTER(
table_name
)
self.sql_model_class = self.converter.get_sql_model_class()
if not hasattr(self.sql_model_class, session_id_field_name):
raise ValueError("SQL model class must have session_id column")
Expand All @@ -241,6 +249,17 @@ async def _acreate_table_if_not_exists(self) -> None:
await conn.run_sync(self.sql_model_class.metadata.create_all)
self._table_created = True

def _messages_query(self) -> Select:
    """Build the SELECT for this session's messages, ordered by ascending id."""
    statement = select(self.sql_model_class)
    # The session column name is configurable, so it is looked up dynamically.
    session_column = getattr(self.sql_model_class, self.session_id_field_name)
    statement = statement.where(session_column == self.session_id)
    return statement.order_by(self.sql_model_class.id.asc())

@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve all messages from db"""
Expand Down
169 changes: 169 additions & 0 deletions libs/community/tests/integration_tests/memory/test_cratedb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import json
import os
from typing import Any, Generator, Tuple

import pytest
import sqlalchemy as sa
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import CrateDBChatMessageHistory
from langchain.memory.chat_message_histories.sql import DefaultMessageConverter
from langchain.schema.messages import AIMessage, HumanMessage, _message_to_dict
from sqlalchemy import Column, Integer, Text
from sqlalchemy.orm import DeclarativeBase


@pytest.fixture()
def connection_string() -> str:
    """
    Return the database URL for the tests.

    Read from `TEST_CRATEDB_CONNECTION_STRING`, with a localhost default.
    """
    fallback_url = "crate://crate@localhost/?schema=testdrive"
    return os.environ.get("TEST_CRATEDB_CONNECTION_STRING", fallback_url)


@pytest.fixture()
def engine(connection_string: str) -> Generator[sa.Engine, None, None]:
    """
    Provide an SQLAlchemy engine object, and dispose it after the test.
    """
    db_engine = sa.create_engine(connection_string, echo=True)
    yield db_engine
    # Release pooled connections instead of leaving them open between tests.
    db_engine.dispose()


@pytest.fixture(autouse=True)
def reset_database(engine: sa.Engine) -> None:
    """
    Drop the `test_table` table, if it exists, so tests start without it.
    """
    drop_statement = sa.text("DROP TABLE IF EXISTS test_table;")
    with engine.connect() as conn:
        conn.execute(drop_statement)
        conn.commit()


@pytest.fixture()
def sql_histories(
    connection_string: str,
) -> Generator[Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory], None, None]:
    """
    Provide two histories on `test_table` with different session ids.

    Both histories are cleared once the test has finished.
    """

    def open_history(session_id: str) -> CrateDBChatMessageHistory:
        return CrateDBChatMessageHistory(
            session_id=session_id,
            connection_string=connection_string,
            table_name="test_table",
        )

    primary = open_history("123")
    # Second history, bound to a different session.
    secondary = open_history("456")

    yield primary, secondary

    for history in (primary, secondary):
        history.clear()


def test_add_messages(
    sql_histories: Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory],
) -> None:
    """A user message and an AI message are read back in insertion order."""
    history, _other = sql_histories
    history.add_user_message("Hello!")
    history.add_ai_message("Hi there!")

    stored = history.messages
    assert len(stored) == 2
    first, second = stored
    assert isinstance(first, HumanMessage)
    assert isinstance(second, AIMessage)
    assert first.content == "Hello!"
    assert second.content == "Hi there!"


def test_multiple_sessions(
    sql_histories: Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory],
) -> None:
    """Messages written under one session id do not show up under the other."""
    history1, history2 = sql_histories

    # first session
    history1.add_user_message("Hello!")
    history1.add_ai_message("Hi there!")
    history1.add_user_message("Whats cracking?")

    # second session
    history2.add_user_message("Hellox")

    messages1 = history1.messages
    messages2 = history2.messages

    # Comparing whole lists checks count, content and order in one step,
    # and shows the actual contents when the assertion fails.
    assert [message.content for message in messages1] == [
        "Hello!",
        "Hi there!",
        "Whats cracking?",
    ]
    assert [message.content for message in messages2] == ["Hellox"]


def test_clear_messages(
    sql_histories: Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory],
) -> None:
    """Clearing one session leaves the other session's messages in place."""
    cleared, untouched = sql_histories

    cleared.add_user_message("Hello!")
    cleared.add_ai_message("Hi there!")
    assert len(cleared.messages) == 2

    # Write to the history that uses a different session id.
    untouched.add_user_message("Hellox")
    assert len(untouched.messages) == 1
    assert len(cleared.messages) == 2

    # Only the first history is cleared.
    cleared.clear()
    assert len(cleared.messages) == 0
    assert len(untouched.messages) == 1


def test_model_no_session_id_field_error(connection_string: str) -> None:
    """A model class without a `session_id` attribute is rejected with ValueError."""

    class Base(DeclarativeBase):
        pass

    class Model(Base):
        __tablename__ = "test_table"
        id = Column(Integer, primary_key=True)
        test_field = Column(Text)

    class CustomMessageConverter(DefaultMessageConverter):
        def get_sql_model_class(self) -> Any:
            return Model

    # Match on the message so that an unrelated ValueError cannot satisfy the test.
    with pytest.raises(ValueError, match="must have session_id column"):
        CrateDBChatMessageHistory(
            "test",
            connection_string,
            custom_message_converter=CustomMessageConverter("test_table"),
        )


def test_memory_with_message_store(connection_string: str) -> None:
    """
    Exercise ConversationBufferMemory on top of a CrateDB message store.
    """
    # CrateDB acts as the message store behind the memory object.
    store = CrateDBChatMessageHistory(
        connection_string=connection_string, session_id="test-session"
    )
    memory = ConversationBufferMemory(
        memory_key="baz", chat_memory=store, return_messages=True
    )

    # Write one AI message and one human message.
    memory.chat_memory.add_ai_message("This is me, the AI")
    memory.chat_memory.add_user_message("This is me, the human")

    # Read the history back through the memory object and serialize it to JSON.
    history = memory.chat_memory.messages
    messages_json = json.dumps([_message_to_dict(entry) for entry in history])

    # Both messages must appear in the serialized history.
    for expected in ("This is me, the AI", "This is me, the human"):
        assert expected in messages_json

    # After clearing, the history must be empty.
    memory.chat_memory.clear()
    assert memory.chat_memory.messages == []
Loading

0 comments on commit 1ebbde1

Please sign in to comment.