-
Notifications
You must be signed in to change notification settings - Fork 16k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
314 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
109 changes: 109 additions & 0 deletions
109
libs/community/langchain_community/chat_message_histories/cratedb.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import json | ||
import typing as t | ||
|
||
import sqlalchemy as sa | ||
from langchain.schema import BaseMessage, _message_to_dict, messages_from_dict | ||
|
||
from langchain_community.chat_message_histories.sql import ( | ||
BaseMessageConverter, | ||
SQLChatMessageHistory, | ||
) | ||
|
||
|
||
def create_message_model(table_name, DynamicBase):  # type: ignore
    """
    Build the SQLAlchemy model class that stores chat messages.

    CrateDB-specific variant: the primary key is an integer column whose
    value is produced by the server (`now()` as server default).

    TODO: Find a way to converge CrateDB's generate_random_uuid() into a variant
          returning its integer value.

    Args:
        table_name: The name of the table to use.
        DynamicBase: The base class to use for the model.

    Returns:
        The model class.
    """

    # The class statement lives inside the function so that `__tablename__`
    # can be taken from the `table_name` argument.
    class Message(DynamicBase):
        __tablename__ = table_name

        # Integer primary key, filled in by the server via `now()`.
        # NOTE(review): presumably relies on CrateDB yielding an integer
        # timestamp here - confirm against the database documentation.
        id = sa.Column(
            sa.BigInteger,
            primary_key=True,
            server_default=sa.func.now(),
        )
        # Identifier of the conversation the row belongs to.
        session_id = sa.Column(sa.Text)
        # Message payload, stored as text.
        message = sa.Column(sa.Text)

    return Message
|
||
|
||
class CrateDBMessageConverter(BaseMessageConverter):
    """
    Message converter used by default by `CrateDBChatMessageHistory`.

    It is the same as the generic `SQLChatMessageHistory` converter,
    but swaps in a different `create_message_model` function.
    """

    def __init__(self, table_name: str):
        """Create the model class for `table_name` on a fresh declarative base."""
        declarative_base = sa.orm.declarative_base()
        self.model_class = create_message_model(table_name, declarative_base)

    def from_sql_model(self, sql_message: t.Any) -> BaseMessage:
        """Decode the JSON text in `sql_message.message` into a message object."""
        message_dict = json.loads(sql_message.message)
        restored = messages_from_dict([message_dict])
        return restored[0]

    def to_sql_model(self, message: BaseMessage, session_id: str) -> t.Any:
        """Serialize `message` to JSON and wrap it in a model instance."""
        serialized = json.dumps(_message_to_dict(message))
        return self.model_class(session_id=session_id, message=serialized)

    def get_sql_model_class(self) -> t.Any:
        """Return the model class created in the constructor."""
        return self.model_class
|
||
|
||
class CrateDBChatMessageHistory(SQLChatMessageHistory):
    """
    Chat message history variant for CrateDB.

    It is the same as the generic `SQLChatMessageHistory` implementation,
    but swaps in a different message converter by default.
    """

    DEFAULT_MESSAGE_CONVERTER: t.Type[BaseMessageConverter] = CrateDBMessageConverter

    def __init__(
        self,
        session_id: str,
        connection_string: str,
        table_name: str = "message_store",
        session_id_field_name: str = "session_id",
        custom_message_converter: t.Optional[BaseMessageConverter] = None,
    ):
        """
        Set up the history through the parent constructor, then patch the session.

        All arguments are forwarded unchanged to `SQLChatMessageHistory`.
        """
        # Imported at call time rather than at module import time.
        from sqlalchemy_cratedb.support import refresh_after_dml

        super().__init__(
            session_id,
            connection_string,
            table_name=table_name,
            session_id_field_name=session_id_field_name,
            custom_message_converter=custom_message_converter,
        )

        # Patch dialect to invoke `REFRESH TABLE` after each DML operation.
        refresh_after_dml(self.Session)

    def _messages_query(self) -> sa.sql.Select:
        """
        Construct an SQLAlchemy selectable to query for messages.

        For CrateDB, add an `ORDER BY` clause on the primary key.
        """
        return super()._messages_query().order_by(self.sql_model_class.id)

    def clear(self) -> None:
        """
        Clear via the parent implementation, then refresh the table.

        Needed for CrateDB to synchronize data because `on_flush` does not catch it.
        """
        from sqlalchemy_cratedb.support import refresh_table

        result = super().clear()
        # A separate session is opened only to issue the table refresh.
        with self.Session() as db_session:
            refresh_table(db_session, self.sql_model_class)
        return result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.