diff --git a/README.md b/README.md index a9f385c..d45ebc4 100644 --- a/README.md +++ b/README.md @@ -37,11 +37,11 @@ pip install -e . Once you have installed all dependencies you are ready to go with: ```python from nl2ltl import translate -from nl2ltl.engines.rasa.core import RasaEngine +from nl2ltl.engines.gpt.core import GPTEngine, Models from nl2ltl.filters.simple_filters import BasicFilter from nl2ltl.engines.utils import pretty -engine = RasaEngine() +engine = GPTEngine() filter = BasicFilter() utterance = "Eventually send me a Slack after receiving a Gmail" @@ -65,9 +65,10 @@ For instance, Rasa requires a `.tar.gz` format trained model in the - [x] [Rasa](https://rasa.com/) intents/entities classifier (to use Rasa, please install it with `pip install -e ".[rasa]"`) - [ ] [Watson Assistant](https://www.ibm.com/products/watson-assistant) intents/entities classifier -- Planned -To use GPT models you need to have the OPEN_API_KEY set as environment variable. To set it: +**NOTE**: To use OpenAI GPT models don't forget to add your `OPENAI_API_KEY` in a `.env` file under your project folder. 
+The `.env` file should look like: ```bash -export OPENAI_API_KEY=your_api_key +OPENAI_API_KEY=your_api_key ``` ## Write your own Engine diff --git a/nl2ltl/engines/gpt/core.py b/nl2ltl/engines/gpt/core.py index 4cce8d0..949cc00 100644 --- a/nl2ltl/engines/gpt/core.py +++ b/nl2ltl/engines/gpt/core.py @@ -6,12 +6,11 @@ """ import json -import os from enum import Enum from pathlib import Path from typing import Dict, Set -import openai +from openai import OpenAI from pylogics.syntax.base import Formula from nl2ltl.engines.base import Engine @@ -19,7 +18,7 @@ from nl2ltl.engines.gpt.output import GPTOutput, parse_gpt_output, parse_gpt_result from nl2ltl.filters.base import Filter -openai.api_key = os.getenv("OPENAI_API_KEY") +client = OpenAI() engine_root = ENGINE_ROOT DATA_DIR = engine_root / "data" PROMPT_PATH = engine_root / DATA_DIR / "prompt.json" @@ -75,7 +74,7 @@ def _check_consistency(self) -> None: def __check_openai_version(self): """Check that the GPT tool is at the right version.""" - is_right_version = openai.__version__ == "1.12.0" + is_right_version = client._version == "1.12.0" if not is_right_version: raise Exception( "OpenAI needs to be at version 1.12.0. 
" @@ -149,7 +148,7 @@ def _process_utterance( query = f"NL: {utterance}\n" messages = [{"role": "user", "content": prompt + query}] if operation_mode == OperationModes.CHAT.value: - prediction = openai.chat.completions.create( + prediction = client.chat.completions.create( model=model, messages=messages, temperature=temperature, @@ -160,7 +159,7 @@ def _process_utterance( stop=["\n\n"], ) else: - prediction = openai.completions.create( + prediction = client.completions.create( model=model, prompt=messages[0]["content"], temperature=temperature, diff --git a/nl2ltl/engines/gpt/output.py b/nl2ltl/engines/gpt/output.py index 5377933..05f5877 100644 --- a/nl2ltl/engines/gpt/output.py +++ b/nl2ltl/engines/gpt/output.py @@ -40,7 +40,7 @@ def pattern(self) -> str: Match, re.search( "PATTERN: (.*)\n", - self.output["choices"][0]["message"]["content"], + self.output.choices[0].message.content, ), ).group(1) ) @@ -48,7 +48,7 @@ def pattern(self) -> str: return str( cast( Match, - re.search("PATTERN: (.*)\n", self.output["choices"][0]["text"]), + re.search("PATTERN: (.*)\n", self.output.choices[0].text), ).group(1) ) @@ -61,14 +61,14 @@ def entities(self) -> Tuple[str]: return tuple( cast( Match, - re.search("SYMBOLS: (.*)", self.output["choices"][0]["message"]["content"]), + re.search("SYMBOLS: (.*)", self.output.choices[0].message.content), ) .group(1) .split(", ") ) else: return tuple( - cast(Match, re.search("SYMBOLS: (.*)", self.output["choices"][0]["text"])).group(1).split(", ") + cast(Match, re.search("SYMBOLS: (.*)", self.output.choices[0].text)).group(1).split(", ") )