Skip to content

Commit

Permalink
Add Chain-of-Thought Prompting Style (#56)
Browse files Browse the repository at this point in the history
* Add Chain-of-Thought Prompting Style

* Add all CoT prompts and bump libbs
  • Loading branch information
mahaloz authored Sep 12, 2024
1 parent b832788 commit b14764c
Show file tree
Hide file tree
Showing 8 changed files with 407 additions and 35 deletions.
7 changes: 4 additions & 3 deletions dailalib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "3.5.0"
__version__ = "3.6.0"

from .api import AIAPI, LiteLLMAIAPI
from libbs.api import DecompilerInterface
Expand All @@ -13,7 +13,7 @@ def create_plugin(*args, **kwargs):
litellm_api = LiteLLMAIAPI(delay_init=True)
# create context menus for prompts
gui_ctx_menu_actions = {
f"DAILA/LLM/{prompt_name}": (prompt.desc, getattr(litellm_api, prompt_name))
f"DAILA/LLM/{prompt_name}": (prompt.desc, lambda *x, **y: getattr(litellm_api, prompt_name)(*x, **y))
for prompt_name, prompt in litellm_api.prompts_by_name.items()
}
# create context menus for others
Expand All @@ -27,12 +27,13 @@ def create_plugin(*args, **kwargs):

VARBERT_AVAILABLE = True
try:
from varbert.api import VariableRenamingAPI
import varbert
except ImportError:
VARBERT_AVAILABLE = False

var_api = None
if VARBERT_AVAILABLE:
from varbert.api import VariableRenamingAPI
var_api = VariableRenamingAPI(delay_init=True)

# add single interface, which is to rename variables
Expand Down
4 changes: 2 additions & 2 deletions dailalib/api/litellm/litellm_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def ask_api_key(self, *args, **kwargs):
api_key = api_key_or_path
self.api_key = api_key

def ask_prompt_style(self):
def ask_prompt_style(self, *args, **kwargs):
if self._dec_interface is not None:
from .prompts import ALL_STYLES

Expand All @@ -165,7 +165,7 @@ def ask_prompt_style(self):
self.prompt_style = p_style
self._dec_interface.info(f"Prompt style set to {p_style}")

def ask_model(self):
def ask_model(self, *args, **kwargs):
if self._dec_interface is not None:
model_choices = list(LiteLLMAIAPI.MODEL_TO_TOKENS.keys())
model_choices.remove(self.model)
Expand Down
41 changes: 34 additions & 7 deletions dailalib/api/litellm/prompts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,61 @@
from pathlib import Path
from .prompt_type import PromptType, DEFAULT_STYLE, ALL_STYLES
from .prompt import Prompt
from .prompts import SUMMARIZE_FUNCTION, IDENTIFY_SOURCE, RENAME_FUNCTION, RENAME_VARIABLES

FILE_DIR = Path(__file__).absolute().parent

class PromptNames:
RENAME_FUNC = "RENAME_FUNCTION"
RENAME_VARS = "RENAME_VARIABLES"
SUMMARIZE_FUNC = "SUMMARIZE_FUNCTION"
ID_SRC = "IDENTIFY_SOURCE"


def get_prompt_template(prompt_name, prompt_style):
if prompt_style in [PromptType.FEW_SHOT, PromptType.ZERO_SHOT]:
from .few_shot_prompts import RENAME_FUNCTION, RENAME_VARIABLES, SUMMARIZE_FUNCTION, IDENTIFY_SOURCE
d = {
PromptNames.RENAME_FUNC: RENAME_FUNCTION,
PromptNames.RENAME_VARS: RENAME_VARIABLES,
PromptNames.SUMMARIZE_FUNC: SUMMARIZE_FUNCTION,
PromptNames.ID_SRC: IDENTIFY_SOURCE
}
elif prompt_style == PromptType.COT:
from .cot_prompts import RENAME_FUNCTION, RENAME_VARIABLES, SUMMARIZE_FUNCTION, IDENTIFY_SOURCE
d = {
PromptNames.RENAME_FUNC: RENAME_FUNCTION,
PromptNames.RENAME_VARS: RENAME_VARIABLES,
PromptNames.SUMMARIZE_FUNC: SUMMARIZE_FUNCTION,
PromptNames.ID_SRC: IDENTIFY_SOURCE
}
else:
raise ValueError("Invalid prompt style")

return d[prompt_name]


PROMPTS = [
Prompt(
"summarize",
SUMMARIZE_FUNCTION,
PromptNames.SUMMARIZE_FUNC,
desc="Summarize the function",
response_key="summary",
gui_result_callback=Prompt.comment_function
),
Prompt(
"identify_source",
IDENTIFY_SOURCE,
PromptNames.ID_SRC,
desc="Identify the source of the function",
response_key="link",
gui_result_callback=Prompt.comment_function
),
Prompt(
"rename_variables",
RENAME_VARIABLES,
PromptNames.RENAME_VARS,
desc="Suggest variable names",
gui_result_callback=Prompt.rename_variables
),
Prompt(
"rename_function",
RENAME_FUNCTION,
PromptNames.RENAME_FUNC,
desc="Suggest a function name",
gui_result_callback=Prompt.rename_function
),
Expand Down
Loading

0 comments on commit b14764c

Please sign in to comment.