diff --git a/database/gpt_client.py b/database/gpt_client.py
index 14e56b3..df434b6 100644
--- a/database/gpt_client.py
+++ b/database/gpt_client.py
@@ -17,14 +17,14 @@ class GptClient: # pylint: disable=too-few-public-methods
     """
     A client for interacting with the OpenAI GPT API.
     """
+    insights_pattern = re.compile(
+        r'(\s\(Sentiment:\s(Positive|Neutral|Negative)\))',
+        flags=re.IGNORECASE
+    )
 
     def __init__(self, api_key: str, model: str = MODEL):
         openai.api_key = api_key
         self.model = model
-        self.insights_pattern = re.compile(
-            r'\d+\.\s(.*?)(?:\((positive|neutral|negative) sentiment\))',
-            flags=re.IGNORECASE
-        )
         self.last_prompted_datetime = None
         # pylint: disable=line-too-long
         self.behavior_instruction = dedent("""You are a stock market expert, capable of quickly analysing trends and outliers in stock data.
@@ -62,6 +62,7 @@ async def prompt(self, stock_data: list[dict]) -> InsightsResponse:
         logger.info("Sending prompt to GPT API for insights.")
         try:
             response = await self._send_prompt(messages)
+            self.clean_insights(response["items"])
         except Exception as exc: # pylint: disable=broad-except
             logger.error("Failed to get insights from GPT API.")
             logger.error(exc)
@@ -70,7 +71,9 @@ async def prompt(self, stock_data: list[dict]) -> InsightsResponse:
         self.cached_insights = InsightsResponse(**response)
         return InsightsResponse(**response)
 
-    async def _send_prompt(self, messages: list[Message]) -> dict:
+    async def _send_prompt(self, messages: list[Message]) -> dict[
+        str, int | list[dict[str, str | list[dict[str, str]]]]
+    ]:
         """Sends a prompt to the OpenAI GPT API and returns the response."""
         func_response = await openai.ChatCompletion.acreate(
             model=self.model,
@@ -94,3 +97,23 @@ async def _send_prompt(self, messages: list[Message]) -> dict:
         for item in response["items"]:
             del item["insights"][5:]
         return response
+
+    @classmethod
+    def clean_insights(cls, insights: list[dict[str, str | list[dict[str, str]]]]):
+        """Cleans the insights in place by removing the sentiment and other noise.
+
+        The text of each insight is taken from its "message" key, falling back
+        to "insight", and is written back under "message".
+
+        Raises:
+            TypeError: If an insight has no string under either key.
+        """
+        for record in insights:
+            for insight in record["insights"]:
+                message = insight.pop("message", None) or insight.pop("insight", None)
+                if not isinstance(message, str):
+                    # Fail with a readable message instead of the opaque error re raises.
+                    raise TypeError(
+                        f'Insight has no text under "message" or "insight": {insight!r}'
+                    )
+                insight["message"] = cls.insights_pattern.sub("", message)
diff --git a/database/tests/test_gpt_client.py b/database/tests/test_gpt_client.py
index b2cf8ea..bf2be40 100644
--- a/database/tests/test_gpt_client.py
+++ b/database/tests/test_gpt_client.py
@@ -53,3 +53,92 @@ async def test_prompt(gpt_client_fixture, data, prompt_call_count):
     assert gpt_client.last_prompted_datetime == msg[-1]["datetime"]
     # Call Count is an indication of whether the cached insights were used or not.
     assert gpt_client._send_prompt.call_count == prompt_call_count
+
+
+@pytest.mark.parametrize("insights, expected", [
+    (
+        [
+            {
+                "datetime": "2021-01-01 00:00:00",
+                "insights": [
+                    {
+                        "message": "This is a positive insight.",
+                        "sentiment": "positive"
+                    },
+                    {
+                        "message": "This is a negative insight.",
+                        "sentiment": "negative"
+                    },
+                    {
+                        "message": "This is a neutral insight.",
+                        "sentiment": "neutral"
+                    }
+                ]
+            }
+        ],
+        [
+            {
+                "datetime": "2021-01-01 00:00:00",
+                "insights": [
+                    {
+                        "message": "This is a positive insight.",
+                        "sentiment": "positive"
+                    },
+                    {
+                        "message": "This is a negative insight.",
+                        "sentiment": "negative"
+                    },
+                    {
+                        "message": "This is a neutral insight.",
+                        "sentiment": "neutral"
+                    }
+                ]
+            }
+        ]
+    ),
+    (
+        [
+            {
+                "datetime": "2021-01-01 00:00:00",
+                "insights": [
+                    {
+                        "message": "This is a positive insight. (Sentiment: Positive)",
+                        "sentiment": "positive"
+                    },
+                    {
+                        "message": "This is a negative insight. (Sentiment: Negative)",
+                        "sentiment": "negative"
+                    },
+                    {
+                        "message": "This is a neutral insight. (Sentiment: neutral)",
+                        "sentiment": "neutral"
+                    }
+                ]
+            }
+        ],
+        [
+            {
+                "datetime": "2021-01-01 00:00:00",
+                "insights": [
+                    {
+                        "message": "This is a positive insight.",
+                        "sentiment": "positive"
+                    },
+                    {
+                        "message": "This is a negative insight.",
+                        "sentiment": "negative"
+                    },
+                    {
+                        "message": "This is a neutral insight.",
+                        "sentiment": "neutral"
+                    }
+                ]
+            }
+        ]
+    )
+])
+def test_clean_insights(insights, expected):
+    """Tests the clean_insights method."""
+    input_insights = insights[:]
+    GptClient.clean_insights(input_insights)
+    assert input_insights == expected