Skip to content

Commit

Permalink
Clean the sentiments in the insight messages.
Browse files Browse the repository at this point in the history
  • Loading branch information
Hyperclaw79 committed Sep 17, 2023
1 parent 94e8d29 commit 4604c47
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 5 deletions.
21 changes: 16 additions & 5 deletions database/gpt_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ class GptClient: # pylint: disable=too-few-public-methods
"""
A client for interacting with the OpenAI GPT API.
"""
insights_pattern = re.compile(
r'(\s\(Sentiment:\s(Positive|Neutral|Negative)\))',
flags=re.IGNORECASE
)

def __init__(self, api_key: str, model: str = MODEL):
openai.api_key = api_key
self.model = model
self.insights_pattern = re.compile(
r'\d+\.\s(.*?)(?:\((positive|neutral|negative) sentiment\))',
flags=re.IGNORECASE
)
self.last_prompted_datetime = None
# pylint: disable=line-too-long
self.behavior_instruction = dedent("""You are a stock market expert, capable of quickly analysing trends and outliers in stock data.
Expand Down Expand Up @@ -62,6 +62,7 @@ async def prompt(self, stock_data: list[dict]) -> InsightsResponse:
logger.info("Sending prompt to GPT API for insights.")
try:
response = await self._send_prompt(messages)
self.clean_insights(response["items"])
except Exception as exc: # pylint: disable=broad-except
logger.error("Failed to get insights from GPT API.")
logger.error(exc)
Expand All @@ -70,7 +71,9 @@ async def prompt(self, stock_data: list[dict]) -> InsightsResponse:
self.cached_insights = InsightsResponse(**response)
return InsightsResponse(**response)

async def _send_prompt(self, messages: list[Message]) -> dict:
async def _send_prompt(self, messages: list[Message]) -> dict[
str, int | list[dict[str, str | list[str]]]
]:
"""Sends a prompt to the OpenAI GPT API and returns the response."""
func_response = await openai.ChatCompletion.acreate(
model=self.model,
Expand All @@ -94,3 +97,11 @@ async def _send_prompt(self, messages: list[Message]) -> dict:
for item in response["items"]:
del item["insights"][5:]
return response

@classmethod
def clean_insights(cls, insights: list[dict[str, str | list[dict[str, str]]]]):
"""Cleans the insights by removing the sentiment and other noise."""
for record in insights:
for insight in record["insights"]:
message = insight.pop("message", None) or insight.pop("insight", None)
insight["message"] = cls.insights_pattern.sub("", message)
89 changes: 89 additions & 0 deletions database/tests/test_gpt_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,92 @@ async def test_prompt(gpt_client_fixture, data, prompt_call_count):
assert gpt_client.last_prompted_datetime == msg[-1]["datetime"]
# Call Count is an indication of whether the cached insights were used or not.
assert gpt_client._send_prompt.call_count == prompt_call_count


@pytest.mark.parametrize("insights, expected", [
(
[
{
"datetime": "2021-01-01 00:00:00",
"insights": [
{
"message": "This is a positive insight.",
"sentiment": "positive"
},
{
"message": "This is a negative insight.",
"sentiment": "negative"
},
{
"message": "This is a neutral insight.",
"sentiment": "neutral"
}
]
}
],
[
{
"datetime": "2021-01-01 00:00:00",
"insights": [
{
"message": "This is a positive insight.",
"sentiment": "positive"
},
{
"message": "This is a negative insight.",
"sentiment": "negative"
},
{
"message": "This is a neutral insight.",
"sentiment": "neutral"
}
]
}
]
),
(
[
{
"datetime": "2021-01-01 00:00:00",
"insights": [
{
"message": "This is a positive insight. (Sentiment: Positive)",
"sentiment": "positive"
},
{
"message": "This is a negative insight. (Sentiment: Negative)",
"sentiment": "negative"
},
{
"message": "This is a neutral insight. (Sentiment: neutral)",
"sentiment": "neutral"
}
]
}
],
[
{
"datetime": "2021-01-01 00:00:00",
"insights": [
{
"message": "This is a positive insight.",
"sentiment": "positive"
},
{
"message": "This is a negative insight.",
"sentiment": "negative"
},
{
"message": "This is a neutral insight.",
"sentiment": "neutral"
}
]
}
]
)
])
async def test_clean_insights(insights, expected):
"""Tests the clean_insights method."""
input_insights = insights[:]
GptClient.clean_insights(input_insights)
assert input_insights == expected

0 comments on commit 4604c47

Please sign in to comment.