From eb976e6f341de5176d030e817885be79d6daab50 Mon Sep 17 00:00:00 2001 From: mahaloz Date: Tue, 22 Aug 2023 08:45:32 +0800 Subject: [PATCH] Add code from HITCON demo --- dailalib/binsync_plugin/__init__.py | 86 +++++++++ dailalib/binsync_plugin/ai_bs_user.py | 183 +++++++++++++++++++ dailalib/binsync_plugin/ai_user_config_ui.py | 178 ++++++++++++++++++ dailalib/binsync_plugin/openai_bs_user.py | 128 +++++++++++++ dailalib/binsync_plugin/varmodel_bs_user.py | 40 ++++ dailalib/interfaces/openai_interface.py | 91 ++++----- setup.cfg | 1 + 7 files changed, 662 insertions(+), 45 deletions(-) create mode 100644 dailalib/binsync_plugin/__init__.py create mode 100644 dailalib/binsync_plugin/ai_bs_user.py create mode 100644 dailalib/binsync_plugin/ai_user_config_ui.py create mode 100644 dailalib/binsync_plugin/openai_bs_user.py create mode 100644 dailalib/binsync_plugin/varmodel_bs_user.py diff --git a/dailalib/binsync_plugin/__init__.py b/dailalib/binsync_plugin/__init__.py new file mode 100644 index 0000000..371151f --- /dev/null +++ b/dailalib/binsync_plugin/__init__.py @@ -0,0 +1,86 @@ +from pathlib import Path +import argparse +import subprocess +import sys + +from dailalib.binsync_plugin.ai_bs_user import AIBSUser +from dailalib.binsync_plugin.openai_bs_user import OpenAIBSUser +from dailalib.binsync_plugin.varmodel_bs_user import VARModelBSUser + + +def add_ai_user_to_project( + openai_api_key: str, binary_path: Path, bs_proj_path: Path, username: str = AIBSUser.DEFAULT_USERNAME, + base_on=None, headless=False, copy_proj=False, decompiler_backend=None, model=None, controller=None, progress_callback=None, + range_str="" +): + if headless: + _headlessly_add_ai_user(openai_api_key, binary_path, bs_proj_path, username=username, decompiler_backend=decompiler_backend, base_on=base_on, model=model) + else: + if model is None or model.startswith("gpt"): + ai_user = OpenAIBSUser( + openai_api_key=openai_api_key, binary_path=binary_path, bs_proj_path=bs_proj_path, model=model, + username=username, copy_project=copy_proj, decompiler_backend=decompiler_backend, base_on=base_on, + controller=controller, progress_callback=progress_callback, range_str=range_str + ) + elif model == "VARModel": + ai_user = VARModelBSUser( + openai_api_key=openai_api_key, binary_path=binary_path, bs_proj_path=bs_proj_path, model=model, + username=username, copy_project=copy_proj, decompiler_backend=decompiler_backend, base_on=base_on, + controller=controller, progress_callback=progress_callback, range_str=range_str + ) + else: + raise ValueError(f"Model: {model} is not supported. Please use a supported model.") + + ai_user.add_ai_user_to_project() + + +def _headlessly_add_ai_user( + openai_api_key: str, binary_path: Path, bs_proj_path: Path, username: str = AIBSUser.DEFAULT_USERNAME, + decompiler_backend=None, base_on=None, model=None +): + script_path = Path(__file__).absolute() + python_path = sys.executable + optional_args = [] + if decompiler_backend: + optional_args += ["--dec", decompiler_backend] + if base_on: + optional_args += ["--base-on", base_on] + if model: + optional_args += ["--model", model] + + subpproc = subprocess.Popen([ + python_path, + str(script_path), + openai_api_key, + str(binary_path), + "--username", + username, + "--proj-path", + str(bs_proj_path), + ] + optional_args) + return subpproc + + +def _headless_main(): + parser = argparse.ArgumentParser() + parser.add_argument("openai_api_key", type=str) + parser.add_argument("binary_path", type=Path) + parser.add_argument("--proj-path", type=Path) + parser.add_argument("--username", type=str) + parser.add_argument("--dec", type=str) + parser.add_argument("--base-on", type=str) + parser.add_argument("--model", type=str) + + args = parser.parse_args() + if args.username is None: + args.username = AIBSUser.DEFAULT_USERNAME + + add_ai_user_to_project( + args.openai_api_key, args.binary_path, args.proj_path, username=args.username, headless=False, + copy_proj=True, decompiler_backend=args.dec if args.dec else None, base_on=args.base_on, + model=args.model if args.model else None + ) + + +if __name__ == "__main__": + _headless_main() diff --git a/dailalib/binsync_plugin/ai_bs_user.py b/dailalib/binsync_plugin/ai_bs_user.py new file mode 100644 index 0000000..16abd2c --- /dev/null +++ b/dailalib/binsync_plugin/ai_bs_user.py @@ -0,0 +1,183 @@ +import logging +import os +import shutil +from pathlib import Path +import tempfile +from typing import Union, Dict +import math +import threading + +from binsync.api import load_decompiler_controller, BSController +from binsync.decompilers import ANGR_DECOMPILER +from binsync.data.state import State +from binsync.data import ( + Function, Comment, StackVariable +) +from binsync.ui.qt_objects import ( + QDialog, QMessageBox +) +from binsync.ui.utils import QProgressBarDialog + +from dailalib.interfaces import OpenAIInterface +from tqdm import tqdm + +_l = logging.getLogger(__name__) +_l.setLevel(logging.INFO) + + +class AIBSUser: + MAX_FUNC_SIZE = 0xffff + MIN_FUNC_SIZE = 0x25 + DEFAULT_USERNAME = "ai_user" + + def __init__( + self, + openai_api_key: str, + binary_path: Path, + bs_proj_path: Path = None, + username: str = DEFAULT_USERNAME, + copy_project=True, + decompiler_backend=None, + base_on=None, + controller=None, + model=None, + progress_callback=None, + range_str="", + ): + self._base_on = base_on + self.username = username + self._model = model + self._progress_callback = progress_callback + if bs_proj_path is not None: + bs_proj_path = Path(bs_proj_path) + + # compute the range + if range_str: + range_strings = range_str.split("-") + self.analysis_min = int(range_strings[0], 0) + self.analysis_max = int(range_strings[1], 0) + else: + self.analysis_max = None + self.analysis_min = None + + # copy or create the project path into the temp dir + self.decompiler_backend = decompiler_backend + self.project_path = bs_proj_path or Path(binary_path).with_name(f"{binary_path.with_suffix('').name}.bsproj") + self._is_tmp = False + + self._on_main_thread = True if self.decompiler_backend is None else False + if copy_project and self.project_path.exists(): + proj_dir = Path(tempfile.mkdtemp()) + shutil.copytree(self.project_path, proj_dir / self.project_path.name) + self.project_path = proj_dir / self.project_path.name + self._is_tmp = True + + create = False + if not self.project_path.exists(): + create = True + os.mkdir(self.project_path) + + # connect the controller to a GitClient + _l.info(f"AI User working on copied project at: {self.project_path}") + self.controller: BSController = load_decompiler_controller( + force_decompiler=self.decompiler_backend, headless=True, binary_path=binary_path, callback_on_push=False + ) + self.controller.connect(username, str(self.project_path), init_repo=create, single_thread=True) + self.comments = {} + + def add_ai_user_to_project(self): + # base all changes on another user's state + if self._base_on: + _l.info(f"Basing all AI changes on user {self._base_on}...") + master_state = self.controller.get_state(user=self._base_on) + master_state.user = self.username + else: + _l.info("Basing AI on current decompiler changes...") + master_state = self.controller.get_state() + + # collect decompiled functions + decompiled_functions = self._collect_decompiled_functions() + t = threading.Thread( + target=self._query_and_commit_changes, + args=(master_state, decompiled_functions,) + ) + t.daemon = True + t.start() + + def _collect_decompiled_functions(self) -> Dict: + valid_funcs = [ + addr + for addr, func in self.controller.functions().items() + if self._function_is_large_enough(func) + ] + + if not valid_funcs: + _l.info("No functions with valid size (small or big), to work on...") + return {} + + # open a loading bar for progress updates + pbar = QProgressBarDialog(label_text=f"Decompiling {len(valid_funcs)} functions...") + pbar.show() + self._progress_callback = pbar.update_progress + + # decompile important functions first + decompiled_functions = {} + update_amt_per_func = math.ceil(100 / len(valid_funcs)) + callback_stub = self._progress_callback if self._progress_callback is not None else lambda x: x + for func_addr in tqdm(valid_funcs, desc=f"Decompiling {len(valid_funcs)} functions for analysis..."): + if self.analysis_max is not None and func_addr > self.analysis_max: + callback_stub(update_amt_per_func) + continue + if self.analysis_min is not None and func_addr < self.analysis_min: + callback_stub(update_amt_per_func) + continue + + func = self.controller.function(func_addr) + if func is None: + callback_stub(update_amt_per_func) + continue + + decompilation = self.controller.decompile(func_addr) + if not decompilation: + callback_stub(update_amt_per_func) + continue + + decompiled_functions[func.addr] = (OpenAIInterface.fit_decompilation_to_token_max(decompilation), func) + callback_stub(update_amt_per_func) + + dlg = QMessageBox(None) + dlg.setWindowTitle("Locking Changes Done") + dlg.setText("We've finished decompiling for use with the AI backend. " + "We will now run the rest of our AI tasks in the background. You can use your decompiler normally now.") + dlg.exec_() + return decompiled_functions + + def _query_and_commit_changes(self, state, decompiled_functions): + total_ai_changes = self.commit_ai_changes_to_state(state, decompiled_functions) + if total_ai_changes: + self.controller.client.commit_state(state, msg="AI initiated change to full state") + self.controller.client.push() + + _l.info(f"Pushed {total_ai_changes} AI initiated changes to user {self.username}") + + def _function_is_large_enough(self, func: Function): + return self.MIN_FUNC_SIZE <= func.size <= self.MAX_FUNC_SIZE + + def commit_ai_changes_to_state(self, state: State, decompiled_functions): + ai_initiated_changes = 0 + update_cnt = 0 + for func_addr, (decompilation, func) in tqdm(decompiled_functions.items(), desc=f"Querying AI for {len(decompiled_functions)} funcs..."): + ai_initiated_changes += self.run_all_ai_commands_for_dec(decompilation, func, state) + if ai_initiated_changes: + update_cnt += 1 + + if update_cnt >= 1: + update_cnt = 0 + self.controller.client.commit_state(state, msg="AI Initiated change to functions") + self.controller.client.push() + _l.info(f"Pushed some changes to user {self.username}...") + + return ai_initiated_changes + + def run_all_ai_commands_for_dec(self, decompilation: str, func: Function, state: State): + return 0 diff --git a/dailalib/binsync_plugin/ai_user_config_ui.py b/dailalib/binsync_plugin/ai_user_config_ui.py new file mode 100644 index 0000000..c3240e9 --- /dev/null +++ b/dailalib/binsync_plugin/ai_user_config_ui.py @@ -0,0 +1,178 @@ +import os +from pathlib import Path +import logging +from threading import Thread + +from binsync.ui.qt_objects import ( + QComboBox, + QDialog, + QDir, + QFileDialog, + QGridLayout, + QHBoxLayout, + QLabel, + QLineEdit, + QMessageBox, + QPushButton, + QVBoxLayout, + QTableWidget, + QTableWidgetItem, + QHeaderView +) +from binsync.ui.utils import QProgressBarDialog +from . import AIBSUser, add_ai_user_to_project +from binsync.api.controller import BSController +from binsync.decompilers import ANGR_DECOMPILER, IDA_DECOMPILER + +_l = logging.getLogger(__name__) +AUTO_DECOMPILER = "automatic" + + +class AIUserConfigDialog(QDialog): + TITLE = "AI User Configuration" + + def __init__(self, controller: BSController, parent=None): + super().__init__(parent) + self._controller = controller + self.api_key = os.getenv("OPENAI_API_KEY") or "" + self.username = f"{self._controller.client.master_user}_{AIBSUser.DEFAULT_USERNAME}" + self.project_path = str(Path(controller.client.repo_root).absolute()) + self.binary_path = str(Path(controller.binary_path()).absolute()) if controller.binary_path() else "" + self.base_on = "" + self.model = "gpt-4" + + self.setWindowTitle(self.TITLE) + self._main_layout = QVBoxLayout() + self._grid_layout = QGridLayout() + self.row = 0 + + self._init_widgets() + self._main_layout.addLayout(self._grid_layout) + self.setLayout(self._main_layout) + + def _init_widgets(self): + # model selection + self._model_label = QLabel("AI Model") + self._grid_layout.addWidget(self._model_label, self.row, 0) + self._model_dropdown = QComboBox() + # TODO: add more decompilers + self._model_dropdown.addItems(["gpt-4", "gpt-3.5-turbo", "VARModel"]) + self._grid_layout.addWidget(self._model_dropdown, self.row, 1) + self.row += 1 + + # api key label + self._api_key_label = QLabel("API Key") + self._grid_layout.addWidget(self._api_key_label, self.row, 0) + # api key input + self._api_key_input = QLineEdit(self.api_key) + self._grid_layout.addWidget(self._api_key_input, self.row, 1) + self.row += 1 + + # username label + self._username_label = QLabel("Username") + self._grid_layout.addWidget(self._username_label, self.row, 0) + # username input + self._username_input = QLineEdit(self.username) + self._grid_layout.addWidget(self._username_input, self.row, 1) + self.row += 1 + + # binary label + self._binary_path_label = QLabel("Binary Path") + self._grid_layout.addWidget(self._binary_path_label, self.row, 0) + # binary input + self._binary_path_input = QLineEdit(self.binary_path) + self._grid_layout.addWidget(self._binary_path_input, self.row, 1) + # project button + self._binary_path_button = QPushButton("...") + self._binary_path_button.clicked.connect(self._on_binary_path_button_blocked) + self._grid_layout.addWidget(self._binary_path_button, self.row, 2) + self.row += 1 + + # decompiler dropdown selection + #self._decompiler_label = QLabel("Decompiler Backend") + #self._grid_layout.addWidget(self._decompiler_label, self.row, 0) + #self._decompiler_dropdown = QComboBox() + ## TODO: add more decompilers + #self._decompiler_dropdown.addItems([AUTO_DECOMPILER, ANGR_DECOMPILER]) + #self._grid_layout.addWidget(self._decompiler_dropdown, self.row, 1) + #self.row += 1 + + # user_base dropdown selection + self._user_base_label = QLabel("Base On") + self._grid_layout.addWidget(self._user_base_label, self.row, 0) + self._user_base_dropdown = QComboBox() + + all_users = [user.name for user in self._controller.users()] + curr_user = self._controller.client.master_user + all_users.remove(curr_user) + all_users = [""] + [curr_user] + all_users + self._user_base_dropdown.addItems(all_users) + self._grid_layout.addWidget(self._user_base_dropdown, self.row, 1) + self.row += 1 + + # range selection + self._range_label = QLabel("View Range") + self._grid_layout.addWidget(self._range_label, self.row, 0) + self._range_input = QLineEdit("") + self._grid_layout.addWidget(self._range_input, self.row, 1) + self.row += 1 + + # ok/cancel buttons + self._ok_button = QPushButton("OK") + self._ok_button.clicked.connect(self._on_ok_button_clicked) + self._cancel_button = QPushButton("Cancel") + self._cancel_button.clicked.connect(self._on_cancel_button_clicked) + self._button_layout = QHBoxLayout() + self._button_layout.addWidget(self._ok_button) + self._button_layout.addWidget(self._cancel_button) + + self._main_layout.addLayout(self._grid_layout) + self._main_layout.addLayout(self._button_layout) + + def _on_binary_path_button_blocked(self): + # get the path to the binary + binary_path = QFileDialog.getOpenFileName(self, "Select Binary", QDir.homePath()) + if binary_path[0]: + self._binary_path_input.setText(binary_path[0]) + + def _on_ok_button_clicked(self): + self.api_key = self._api_key_input.text() + self.binary_path = self._binary_path_input.text() + self.username = self._username_input.text() + #self.decompiler_backend = self._decompiler_dropdown.currentText() + self.decompiler_backend = AUTO_DECOMPILER + if self.decompiler_backend == AUTO_DECOMPILER: + self.decompiler_backend = None + + self.base_on = self._user_base_dropdown.currentText() + self.model = self._model_dropdown.currentText() + self.range_str = self._range_input.text() + + if not (self.api_key and self.binary_path and self.username): + if self.model is not None and "gpt" in self.model: + _l.critical("You did not provide a path, username, and API key for the AI user.") + return + + _l.info(f"Starting AI user now! Commits from user {self.username} should appear soon...") + self.hide() + try: + self.threaded_add_ai_user_to_project() + except Exception as e: + _l.info(f"Ran into issue: {e}") + + self.close() + + def _on_cancel_button_clicked(self): + self.close() + + def threaded_add_ai_user_to_project(self): + # angr hack to make sure the workspace is visible! + if hasattr(self._controller, "workspace"): + globals()['workspace'] = self._controller.workspace + + add_ai_user_to_project( + self.api_key, self.binary_path, self.project_path, username=self.username, + base_on=self.base_on, headless=True if self.decompiler_backend else False, copy_proj=True, model=self.model, + decompiler_backend=self.decompiler_backend, range_str=self.range_str + ) + self._controller.used_ai_user = True \ No newline at end of file diff --git a/dailalib/binsync_plugin/openai_bs_user.py b/dailalib/binsync_plugin/openai_bs_user.py new file mode 100644 index 0000000..0540f25 --- /dev/null +++ b/dailalib/binsync_plugin/openai_bs_user.py @@ -0,0 +1,128 @@ +import logging +from typing import Dict + +from binsync.data import Function, StackVariable, Comment, State, FunctionHeader + +from dailalib.interfaces import OpenAIInterface +from dailalib.binsync_plugin.ai_bs_user import AIBSUser + +_l = logging.getLogger(__name__) + + +class OpenAIBSUser(AIBSUser): + DEFAULT_USERNAME = "chatgpt_user" + + def __init__(self, openai_api_key, *args, **kwargs): + super().__init__(openai_api_key, *args, **kwargs) + self.ai_interface = OpenAIInterface(openai_api_key=openai_api_key, decompiler_controller=self.controller, model=self._model) + + def run_all_ai_commands_for_dec(self, decompilation: str, func: Function, state: State): + changes = 0 + artifact_edit_cmds = { + self.ai_interface.RETYPE_VARS_CMD, self.ai_interface.RENAME_VARS_CMD, self.ai_interface.RENAME_FUNCS_CMD, + self.ai_interface.ANSWER_QUESTION_CMD + } + cmt_prepends = { + self.ai_interface.SUMMARIZE_CMD: "==== AI Summarization ====\n", + self.ai_interface.ID_SOURCE_CMD: "==== AI Source Guess ====\n", + self.ai_interface.FIND_VULN_CMD: "==== AI Vuln Guess ====\n", + } + + func_cmt = "" + new_func = Function(func.addr, func.size, header=FunctionHeader("", func.addr, args={}), stack_vars={}) + for cmd in self.ai_interface.AI_COMMANDS: + # TODO: make this more explicit and change what is run + if cmd not in {self.ai_interface.ANSWER_QUESTION_CMD, + self.ai_interface.RENAME_VARS_CMD, + self.ai_interface.SUMMARIZE_CMD, + self.ai_interface.RENAME_FUNCS_CMD + }: + continue + + try: + resp = self.ai_interface.query_for_cmd(cmd, decompilation=decompilation) + except Exception: + continue + + if not resp: + continue + + if cmd not in artifact_edit_cmds: + if cmd == self.ai_interface.ID_SOURCE_CMD: + if "http" not in resp: + continue + + func_cmt += cmt_prepends.get(cmd, "") + resp + "\n" + # fake the comment actually being added to decomp + decompilation = f"/* {Comment.linewrap_comment(resp)} */\n" + decompilation + changes += 1 + + elif cmd == self.ai_interface.RENAME_VARS_CMD: + all_names = set(sv.name for _, sv in func.stack_vars.items()) + for off, sv in func.stack_vars.items(): + old_name = sv.name + if old_name in resp: + proposed_name = resp[old_name] + if not proposed_name or proposed_name == old_name or proposed_name in all_names: + continue + + if off not in new_func.stack_vars: + new_func.stack_vars[off] = StackVariable(sv.offset, "", None, func.stack_vars[off].size, func.addr) + + new_func.stack_vars[off].name = proposed_name + decompilation = decompilation.replace(old_name, proposed_name) + changes += 1 + + elif cmd == self.ai_interface.RETYPE_VARS_CMD: + for off, sv in func.stack_vars.items(): + old_name = sv.name + if old_name in resp: + proposed_type = resp[old_name] + if not proposed_type or proposed_type == sv.type: + continue + + if off not in new_func.stack_vars: + new_func.stack_vars[off] = StackVariable(sv.offset, "", None, func.stack_vars[off].size, func.addr) + + new_func.stack_vars[off].type = proposed_type + # we dont update decompilation here because it would be too weird + changes += 1 + + elif cmd == self.ai_interface.RENAME_FUNCS_CMD: + if func.name in resp: + proposed_name = resp[func.name] + if proposed_name in self.controller.functions() or not proposed_name or proposed_name == func.name: + continue + + new_func.name = proposed_name + _l.info(f"Proposing new name for function {func.name} to {proposed_name}") + changes += 1 + + elif cmd == self.ai_interface.ANSWER_QUESTION_CMD: + answers: Dict[str, str] = resp + current_cmts = state.get_func_comments(func.addr) + for question, answer in answers.items(): + for _, current_cmt in current_cmts.items(): + if question in current_cmt.comment: + current_cmt.comment += f"\n{answer}" + state.set_comment(current_cmt) + changes += 1 + break + + if changes: + _l.info(f"Suggesting updates to {func} with diff: {new_func}") + state.set_function(new_func) + + # send full function comment + if func_cmt: + state.set_comment(Comment(new_func.addr, func_cmt, func_addr=new_func.addr, decompiled=True), append=True) + #self.controller.push_artifact(Comment(new_func.addr, func_cmt, func_addr=new_func.addr, decompiled=True), append=True) + #self.controller.fill_comment(new_func.addr, user=self.username, artifact=Comment(new_func.addr, func_cmt, func_addr=new_func.addr, decompiled=True), append=True) + #self.controller.schedule_job( + # self.controller.push_artifact, + # Comment(new_func.addr, func_cmt, func_addr=new_func.addr, decompiled=True), + # blocking=False, + # append=True + #) + + return changes diff --git a/dailalib/binsync_plugin/varmodel_bs_user.py b/dailalib/binsync_plugin/varmodel_bs_user.py new file mode 100644 index 0000000..a55da06 --- /dev/null +++ b/dailalib/binsync_plugin/varmodel_bs_user.py @@ -0,0 +1,40 @@ +import logging + +from binsync.data import Function, State + +from dailalib.binsync_plugin.ai_bs_user import AIBSUser + +_l = logging.getLogger(__name__) + + +class VARModelBSUser(AIBSUser): + DEFAULT_USERNAME = "varmodel_user" + + """ + Variable Annotation Recovery (VAR) Model + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + try: + from varmodel import VariableRenamingAPI + except ImportError: + _l.error("VARModel is not installed and is still closed source. You will be unable to use this BinSync user.") + return + + self._renaming_api = VariableRenamingAPI() + + def run_all_ai_commands_for_dec(self, decompilation: str, func: Function, state: State): + try: + updated_func: Function = self._renaming_api.predict_variable_names(decompilation, func) + except Exception as e: + _l.warning(f"Skipping {func} due to exception: {e}") + return 0 + + if updated_func is not None and (updated_func.args or updated_func.stack_vars): + # count changes + changes = len(updated_func.args) + len(updated_func.stack_vars) + state.set_function(updated_func) + return changes + + return 0 diff --git a/dailalib/interfaces/openai_interface.py b/dailalib/interfaces/openai_interface.py index 21999a4..8df182e 100644 --- a/dailalib/interfaces/openai_interface.py +++ b/dailalib/interfaces/openai_interface.py @@ -6,6 +6,7 @@ from functools import wraps import openai +import tiktoken from .generic_ai_interface import GenericAIInterface from ..utils import HYPERLINK_REGEX @@ -18,15 +19,19 @@ def _addr_ctx_when_none(self: "OpenAIInterface", *args, **kwargs): if func_addr is None: func_addr = self._current_function_addr() kwargs.update({"func_addr": func_addr, "edit_dec": True}) - return f(self, *args, **kwargs) return _addr_ctx_when_none - +JSON_REGEX = re.compile(r"\{.*\}", flags=re.DOTALL) QUESTION_START = "Q>" QUESTION_REGEX = re.compile(rf"({QUESTION_START})([^?]*\?)") ANSWER_START = "A>" +DEFAULT_MODEL = "gpt-4" +MODEL_TO_TOKENS = { + "gpt-4": 8000, + "gpt-3.5-turbo": 4096 +} class OpenAIInterface(GenericAIInterface): # API Command Constants @@ -53,26 +58,26 @@ class OpenAIInterface(GenericAIInterface): SNIPPET_TEXT = f"\n\"\"\"{SNIPPET_REPLACEMENT_LABEL}\"\"\"" DECOMP_TEXT = f"\n\"\"\"{DECOMP_REPLACEMENT_LABEL}\"\"\"" PROMPTS = { - SUMMARIZE_CMD: f"Please summarize the following code:{DECOMP_TEXT}", - RENAME_FUNCS_CMD: "Rename the functions in this code. Reply with only a JSON array where keys are the " - f"original names and values are the proposed names:{DECOMP_TEXT}", RENAME_VARS_CMD: "Analyze what the following function does. Suggest better variable names. " "Reply with only a JSON array where keys are the original names and values are " f"the proposed names:{DECOMP_TEXT}", RETYPE_VARS_CMD: "Analyze what the following function does. Suggest better C types for the variables. " "Reply with only a JSON where keys are the original names and values are the " f"proposed types: {DECOMP_TEXT}", + SUMMARIZE_CMD: f"Please summarize the following code:{DECOMP_TEXT}", FIND_VULN_CMD: "Can you find the vulnerability in the following function and suggest the " f"possible way to exploit it?{DECOMP_TEXT}", ID_SOURCE_CMD: "What open source project is this code from. Please only give me the program name and " f"package name:{DECOMP_TEXT}", + RENAME_FUNCS_CMD: "The following code is C/C++. Rename the function according to its purpose using underscore_case. Reply with only a JSON array where keys are the " + f"original names and values are the proposed names:{DECOMP_TEXT}", ANSWER_QUESTION_CMD: "You are a code comprehension assistant. You answer questions based on code that is " f"provided. Here is some code: {DECOMP_TEXT}. Focus on this snippet of the code: " f"{SNIPPET_TEXT}\n\n Answer the following question as concisely as possible, guesses " f"are ok: " } - def __init__(self, openai_api_key=None, model="gpt-4", decompiler_controller=None): + def __init__(self, openai_api_key=None, model=DEFAULT_MODEL, decompiler_controller=None): super().__init__(decompiler_controller=decompiler_controller) self.model = model @@ -156,8 +161,13 @@ def _query_openai(self, prompt: str, json_response=False, increase_new_text=True if "}" not in resp: resp += "}" + json_matches = JSON_REGEX.findall(resp) + if not json_matches: + return default_response + + json_data = json_matches[0] try: - data = json.loads(resp) + data = json.loads(json_data) except Exception: data = {} @@ -306,14 +316,6 @@ def rename_functions_in_function(self, *args, func_addr=None, decompilation=None resp: Dict = self.query_for_cmd(self.RENAME_FUNCS_CMD, func_addr=func_addr, decompilation=decompilation, **kwargs) if edit_dec and resp: # TODO: reimplement this code with self.decompiler_controller.set_function(func) - """ - for addr, _ in self.decompiler_controller.functions().items(): - func = self.decompiler_controller.functions(addr) - if func.name in resp: - new_name = resp[func.name] - func.name = new_name - self.decompiler_controller.fill_function(artifact=func) - """ pass return resp @@ -323,21 +325,6 @@ def rename_variables_in_function(self, *args, func_addr=None, decompilation=None resp: Dict = self.query_for_cmd(self.RENAME_VARS_CMD, func_addr=func_addr, decompilation=decompilation, **kwargs) if edit_dec and resp: # TODO: reimplement this code with self.decompiler_controller.set_function(func) - """ - func = self.decompiler_controller.function(func_addr) - if func: - updates = False - for soff in func.stack_vars: - svar = func.stack_vars[soff] - if svar.name in resp: - new_name = resp[svar.name] - svar.name = new_name - func.stack_vars[soff] = svar - updates = True - - if updates: - self.decompiler_controller.fill_function(artifact=func) - """ pass return resp @@ -347,21 +334,35 @@ def retype_variables_in_function(self, *args, func_addr=None, decompilation=None resp: Dict = self.query_for_cmd(self.RETYPE_VARS_CMD, func_addr=func_addr, decompilation=decompilation, **kwargs) if edit_dec and resp: # TODO: reimplement this code with self.decompiler_controller.set_function(func) - """ - func = self.decompiler_controller.function(func_addr) - if func: - updates = False - for soff in func.stack_vars: - svar = func.stack_vars[soff] - if svar.name in resp: - new_type = resp[svar.name] - svar.type = new_type - func.stack_vars[soff] = svar - updates = True - - if updates: - self.decompiler_controller.fill_function(artifact=func) - """ pass return resp + + # + # helpers + # + + @staticmethod + def estimate_token_amount(content: str, model=DEFAULT_MODEL): + enc = tiktoken.encoding_for_model(model) + tokens = enc.encode(content) + return len(tokens) + + @staticmethod + def content_fits_tokens(content: str, model=DEFAULT_MODEL): + max_token_count = MODEL_TO_TOKENS[model] + token_count = OpenAIInterface.estimate_token_amount(content, model=model) + return token_count <= max_token_count - 1000 + + @staticmethod + def fit_decompilation_to_token_max(decompilation: str, delta_step=10, model=DEFAULT_MODEL): + if OpenAIInterface.content_fits_tokens(decompilation, model=model): + return decompilation + + dec_lines = decompilation.split("\n") + last_idx = len(dec_lines) - 1 + # should be: [func_prototype] + [nop] + [mid] + [nop] + [end_of_code] + dec_lines = dec_lines[0:2] + ["// ..."] + dec_lines[delta_step:last_idx-delta_step] + ["// ..."] + dec_lines[-2:-1] + decompilation = "\n".join(dec_lines) + + return OpenAIInterface.fit_decompilation_to_token_max(decompilation, delta_step=delta_step, model=model) diff --git a/setup.cfg b/setup.cfg index d21a7f5..1221e04 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,6 +17,7 @@ install_requires = openai>=0.27.4 binsync PySide6-Essentials + tiktoken python_requires = >= 3.5 include_package_data = True