Skip to content

Commit

Permalink
update openai API call and fix output
Browse files Browse the repository at this point in the history
  • Loading branch information
francescofuggitti committed Feb 15, 2024
1 parent 3349d09 commit 88df692
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 14 deletions.
9 changes: 5 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ pip install -e .
Once you have installed all dependencies you are ready to go with:
```python
from nl2ltl import translate
from nl2ltl.engines.rasa.core import RasaEngine
from nl2ltl.engines.gpt.core import GPTEngine, Models
from nl2ltl.filters.simple_filters import BasicFilter
from nl2ltl.engines.utils import pretty

engine = RasaEngine()
engine = GPTEngine()
filter = BasicFilter()
utterance = "Eventually send me a Slack after receiving a Gmail"

Expand All @@ -65,9 +65,10 @@ For instance, Rasa requires a `.tar.gz` format trained model in the
- [x] [Rasa](https://rasa.com/) intents/entities classifier (to use Rasa, please install it with `pip install -e ".[rasa]"`)
- [ ] [Watson Assistant](https://www.ibm.com/products/watson-assistant) intents/entities classifier -- Planned

To use GPT models you need to have the OPEN_API_KEY set as environment variable. To set it:
**NOTE**: To use OpenAI GPT models don't forget to add your `OPEN_API_KEY` in a `.env` file under your project folder.
The `.env` file should look like:
```bash
export OPENAI_API_KEY=your_api_key
OPENAI_API_KEY=your_api_key
```

## Write your own Engine
Expand Down
11 changes: 5 additions & 6 deletions nl2ltl/engines/gpt/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,19 @@
"""
import json
import os
from enum import Enum
from pathlib import Path
from typing import Dict, Set

import openai
from openai import OpenAI
from pylogics.syntax.base import Formula

from nl2ltl.engines.base import Engine
from nl2ltl.engines.gpt import ENGINE_ROOT
from nl2ltl.engines.gpt.output import GPTOutput, parse_gpt_output, parse_gpt_result
from nl2ltl.filters.base import Filter

openai.api_key = os.getenv("OPENAI_API_KEY")
client = OpenAI()
engine_root = ENGINE_ROOT
DATA_DIR = engine_root / "data"
PROMPT_PATH = engine_root / DATA_DIR / "prompt.json"
Expand Down Expand Up @@ -75,7 +74,7 @@ def _check_consistency(self) -> None:

def __check_openai_version(self):
"""Check that the GPT tool is at the right version."""
is_right_version = openai.__version__ == "1.12.0"
is_right_version = client._version == "1.12.0"
if not is_right_version:
raise Exception(
"OpenAI needs to be at version 1.12.0. "
Expand Down Expand Up @@ -149,7 +148,7 @@ def _process_utterance(
query = f"NL: {utterance}\n"
messages = [{"role": "user", "content": prompt + query}]
if operation_mode == OperationModes.CHAT.value:
prediction = openai.chat.completions.create(
prediction = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
Expand All @@ -160,7 +159,7 @@ def _process_utterance(
stop=["\n\n"],
)
else:
prediction = openai.completions.create(
prediction = client.completions.create(
model=model,
prompt=messages[0]["content"],
temperature=temperature,
Expand Down
8 changes: 4 additions & 4 deletions nl2ltl/engines/gpt/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ def pattern(self) -> str:
Match,
re.search(
"PATTERN: (.*)\n",
self.output["choices"][0]["message"]["content"],
self.output.choices[0].message.content,
),
).group(1)
)
else:
return str(
cast(
Match,
re.search("PATTERN: (.*)\n", self.output["choices"][0]["text"]),
re.search("PATTERN: (.*)\n", self.output.choices[0].text),
).group(1)
)

Expand All @@ -61,14 +61,14 @@ def entities(self) -> Tuple[str]:
return tuple(
cast(
Match,
re.search("SYMBOLS: (.*)", self.output["choices"][0]["message"]["content"]),
re.search("SYMBOLS: (.*)", self.output.choices[0].message.content),
)
.group(1)
.split(", ")
)
else:
return tuple(
cast(Match, re.search("SYMBOLS: (.*)", self.output["choices"][0]["text"])).group(1).split(", ")
cast(Match, re.search("SYMBOLS: (.*)", self.output.choices[0].text)).group(1).split(", ")
)


Expand Down

0 comments on commit 88df692

Please sign in to comment.