From a517a588c4927aa5c5c2a93e4f82a58f0599d251 Mon Sep 17 00:00:00 2001
From: Pablo Orgaz
Date: Mon, 30 Oct 2023 21:54:41 +0100
Subject: [PATCH] fix: sagemaker config and chat methods (#1142)

---
 .../components/llm/custom/sagemaker.py      | 31 +++++++++++++++++--
 private_gpt/components/llm/llm_component.py |  2 --
 2 files changed, 29 insertions(+), 4 deletions(-)

diff --git a/private_gpt/components/llm/custom/sagemaker.py b/private_gpt/components/llm/custom/sagemaker.py
index 070f8ed67..284ee2c93 100644
--- a/private_gpt/components/llm/custom/sagemaker.py
+++ b/private_gpt/components/llm/custom/sagemaker.py
@@ -4,7 +4,7 @@
 import io
 import json
 import logging
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
 
 import boto3  # type: ignore
 from llama_index.bridge.pydantic import Field
@@ -13,7 +13,14 @@
     CustomLLM,
     LLMMetadata,
 )
-from llama_index.llms.base import llm_completion_callback
+from llama_index.llms.base import (
+    llm_chat_callback,
+    llm_completion_callback,
+)
+from llama_index.llms.generic_utils import (
+    completion_response_to_chat_response,
+    stream_completion_response_to_chat_response,
+)
 from llama_index.llms.llama_utils import (
     completion_to_prompt as generic_completion_to_prompt,
 )
@@ -22,8 +29,14 @@
 )
 
 if TYPE_CHECKING:
+    from collections.abc import Sequence
+    from typing import Any
+
     from llama_index.callbacks import CallbackManager
     from llama_index.llms import (
+        ChatMessage,
+        ChatResponse,
+        ChatResponseGen,
         CompletionResponseGen,
     )
 
@@ -247,3 +260,17 @@ def get_stream():
                 yield CompletionResponse(delta=delta, text=text, raw=data)
 
         return get_stream()
+
+    @llm_chat_callback()
+    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
+        prompt = self.messages_to_prompt(messages)
+        completion_response = self.complete(prompt, formatted=True, **kwargs)
+        return completion_response_to_chat_response(completion_response)
+
+    @llm_chat_callback()
+    def stream_chat(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> ChatResponseGen:
+        prompt = self.messages_to_prompt(messages)
+        completion_response = self.stream_complete(prompt, formatted=True, **kwargs)
+        return stream_completion_response_to_chat_response(completion_response)
diff --git a/private_gpt/components/llm/llm_component.py b/private_gpt/components/llm/llm_component.py
index cbd71ce1f..2c32897c6 100644
--- a/private_gpt/components/llm/llm_component.py
+++ b/private_gpt/components/llm/llm_component.py
@@ -37,8 +37,6 @@ def __init__(self) -> None:
 
                 self.llm = SagemakerLLM(
                     endpoint_name=settings.sagemaker.endpoint_name,
-                    messages_to_prompt=messages_to_prompt,
-                    completion_to_prompt=completion_to_prompt,
                 )
             case "openai":
                 from llama_index.llms import OpenAI
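
For reference, a minimal usage sketch of the chat methods this patch adds. It is
not part of the patch itself; the endpoint name and message contents are
illustrative assumptions.

    # Usage sketch (illustrative only; endpoint name and messages are assumptions).
    from llama_index.llms import ChatMessage

    from private_gpt.components.llm.custom.sagemaker import SagemakerLLM

    llm = SagemakerLLM(endpoint_name="my-sagemaker-endpoint")  # hypothetical endpoint

    messages = [ChatMessage(role="user", content="Hello!")]

    # chat() folds the message list into a single prompt via messages_to_prompt,
    # runs a plain completion, and wraps the result as a ChatResponse.
    response = llm.chat(messages)
    print(response.message.content)

    # stream_chat() does the same through stream_complete, yielding incremental
    # ChatResponse chunks converted from the completion stream.
    for chunk in llm.stream_chat(messages):
        print(chunk.delta, end="", flush=True)

Note the design choice: rather than talking to the endpoint twice, both methods
delegate to the existing complete()/stream_complete() paths and rely on the
llama_index generic_utils converters, which is also why messages_to_prompt and
completion_to_prompt no longer need to be passed in from LLMComponent.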