diff --git a/README.md b/README.md
index 39cde68..0598abf 100644
--- a/README.md
+++ b/README.md
@@ -8,8 +8,8 @@ DAILA was featured in the keynote talk at [HITCON CMT 2023](https://youtu.be/Hbr
 
 ## Supported Decompilers and AI Systems
 DAILA interacts with the decompiler abstractly through the [LibBS](https://github.com/binsync/libbs) library. This allows DAILA to support the following decompilers:
-- IDA Pro: **>= 7.3**
-- Ghidra: **>= 10.1**
+- IDA Pro: **>= 8.4**
+- Ghidra: **>= 11.1**
 - Binary Ninja: **>= 2.4**
 - angr-management: **>= 9.0**
 
@@ -110,6 +110,7 @@ Currently, DAILA supports the following prompts:
 - Identify the source of a function
 - Find potential vulnerabilities in a function
 - Summarize the man page of a library call
+- Free prompting... just type in your own prompt!
 
 ### VarBERT
 VarBERT is a local BERT model from the S&P 2024 paper [""Len or index or count, anything but v1": Predicting Variable Names in Decompilation Output with Transfer Learning"]().
diff --git a/dailalib/__init__.py b/dailalib/__init__.py
index ebaa895..e29c024 100644
--- a/dailalib/__init__.py
+++ b/dailalib/__init__.py
@@ -1,10 +1,17 @@
-__version__ = "3.8.2"
+__version__ = "3.9.0"
+
+import os
+# stop LiteLLM from querying the remote server at all
+# https://github.com/BerriAI/litellm/blob/4d29c1fb6941e49191280c4fd63961dec1a1e7c5/litellm/__init__.py#L286C20-L286C48
+os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
 
 from .api import AIAPI, LiteLLMAIAPI
-from libbs.api import DecompilerInterface
+
+from dailalib.llm_chat import get_llm_chat_creator
 
 
 def create_plugin(*args, **kwargs):
+    from libbs.api import DecompilerInterface
     #
     # LLM API (through LiteLLM api)
     #
@@ -16,6 +23,9 @@ def create_plugin(*args, **kwargs):
         f"DAILA/LLM/{prompt_name}": (prompt.desc, getattr(litellm_api, prompt_name))
         for prompt_name, prompt in litellm_api.prompts_by_name.items()
     }
+    # create context menu for llm chat
+    gui_ctx_menu_actions["DAILA/LLM/chat"] = ("Open LLM Chat...", get_llm_chat_creator(litellm_api))
+    # create context menus for others
     gui_ctx_menu_actions["DAILA/LLM/Settings/update_api_key"] = ("Update API key...", litellm_api.ask_api_key)
     gui_ctx_menu_actions["DAILA/LLM/Settings/update_pmpt_style"] = ("Change prompt style...", litellm_api.ask_prompt_style)
diff --git a/dailalib/api/litellm/litellm_api.py b/dailalib/api/litellm/litellm_api.py
index a91d681..c788fd4 100644
--- a/dailalib/api/litellm/litellm_api.py
+++ b/dailalib/api/litellm/litellm_api.py
@@ -1,4 +1,3 @@
-import typing
 from typing import Optional
 import os
 
@@ -6,6 +5,7 @@
 from ..ai_api import AIAPI
 
+
 class LiteLLMAIAPI(AIAPI):
     prompts_by_name = []
     DEFAULT_MODEL = "gpt-4o"
@@ -30,6 +30,7 @@ def __init__(
         model: str = DEFAULT_MODEL,
         prompts: Optional[list] = None,
         fit_to_tokens: bool = False,
+        chat_use_ctx: bool = True,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -38,6 +39,7 @@
         self.api_key = api_key or os.getenv("OPENAI_API_KEY")
         self.model = model
         self.fit_to_tokens = fit_to_tokens
+        self.chat_use_ctx = chat_use_ctx
 
         # delay prompt import
         from .prompts import PROMPTS, DEFAULT_STYLE
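
The new `chat_use_ctx` flag defaults to `True`, so the chat seeds itself with the active function's decompilation, and the environment variable set in `dailalib/__init__.py` keeps LiteLLM's model-cost map local. A minimal sketch of constructing the API with these options outside of `create_plugin()`; the import paths and keyword arguments come from the hunks above, the standalone usage itself is an assumption:

```python
import os

# keep LiteLLM's model-cost map local, mirroring dailalib/__init__.py above
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"

from dailalib.api import LiteLLMAIAPI

# hypothetical standalone construction; inside the plugin, create_plugin() does this
litellm_api = LiteLLMAIAPI(
    api_key=os.getenv("OPENAI_API_KEY"),  # also the internal fallback
    model="gpt-4o",                       # DEFAULT_MODEL in the class
    fit_to_tokens=False,
    chat_use_ctx=False,                   # open the chat without decompilation context
)
```
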
diff --git a/dailalib/llm_chat/__init__.py b/dailalib/llm_chat/__init__.py
new file mode 100644
index 0000000..97691c8
--- /dev/null
+++ b/dailalib/llm_chat/__init__.py
@@ -0,0 +1,26 @@
+import logging
+
+from libbs.api import DecompilerInterface
+from libbs.decompilers import IDA_DECOMPILER, ANGR_DECOMPILER, BINJA_DECOMPILER, GHIDRA_DECOMPILER
+
+from ..api import AIAPI
+
+_l = logging.getLogger(__name__)
+
+
+def get_llm_chat_creator(ai_api: AIAPI) -> callable:
+    # determine the current decompiler
+    current_decompiler = DecompilerInterface.find_current_decompiler()
+    add_llm_chat_to_ui = lambda *args, **kwargs: None
+    if current_decompiler == IDA_DECOMPILER:
+        from dailalib.llm_chat.ida import add_llm_chat_to_ui
+    elif current_decompiler == BINJA_DECOMPILER:
+        from dailalib.llm_chat.binja import add_llm_chat_to_ui
+    else:
+        _l.warning("LLM Chat not supported for decompiler %s", current_decompiler)
+
+    def llm_chat_creator_wrapper(*args, **kwargs):
+        ai_api.info(f"Opening LLM Chat with model {ai_api.model}...")
+        return add_llm_chat_to_ui(ai_api=ai_api, *args, **kwargs)
+
+    return llm_chat_creator_wrapper
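
The factory defers the decompiler-specific import until plugin load and hands back a callback that `create_plugin()` registers as a context-menu action (falling back to a no-op lambda on unsupported decompilers). A hedged sketch of that wiring, with `litellm_api` assumed to be an already-constructed `LiteLLMAIAPI`:

```python
from dailalib.llm_chat import get_llm_chat_creator

# build the callback once; decompiler detection happens inside the factory
open_llm_chat = get_llm_chat_creator(litellm_api)

# shape of the menu registration used in dailalib/__init__.py: {path: (description, callback)}
gui_ctx_menu_actions = {
    "DAILA/LLM/chat": ("Open LLM Chat...", open_llm_chat),
}
```
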
diff --git a/dailalib/llm_chat/binja.py b/dailalib/llm_chat/binja.py
new file mode 100644
index 0000000..b8a169e
--- /dev/null
+++ b/dailalib/llm_chat/binja.py
@@ -0,0 +1,181 @@
+from binaryninjaui import (
+    UIContext,
+    DockHandler,
+    DockContextHandler,
+    UIAction,
+    UIActionHandler,
+    Menu,
+)
+
+import traceback
+import sys
+import logging
+
+from PySide6.QtWidgets import (
+    QDockWidget,
+    QWidget,
+    QApplication,
+    QMenu,
+    QMainWindow,
+    QMenuBar, QVBoxLayout,
+)
+from PySide6.QtCore import Qt
+from binaryninjaui import DockContextHandler
+import binaryninja
+
+from .llm_chat_ui import LLMChatClient
+from ..api import AIAPI
+
+_l = logging.getLogger(__name__)
+
+# this code is based on:
+# https://github.com/binsync/binsync/blob/fa754795b7d55e4b5de4e12ea7d4b5f9706c23af/plugins/binja_binsync/binja_binsync.py
+
+def find_main_window():
+    main_window = None
+    for x in QApplication.allWidgets():
+        if not isinstance(x, QDockWidget):
+            continue
+
+        main_window = x.parent()
+        if isinstance(main_window, (QMainWindow, QWidget)):
+            break
+        else:
+            main_window = None
+
+    if main_window is None:
+        # oops cannot find the main window
+        raise Exception("Main window is not found.")
+    return main_window
+
+
+dockwidgets = [ ]
+
+
+# shamelessly copied from https://github.com/Vector35/debugger
+def create_widget(widget_class, name, parent, data, *args):
+    # It is imperative this function return *some* value because Shiboken will try to deref what we return
+    # If we return nothing (or throw) there will be a null pointer deref (and we won't even get to see why)
+    # So in the event of an error or a nothing, return an empty widget that at least stops the crash
+    try:
+        # binsync specific code
+        if not isinstance(data, binaryninja.BinaryView):
+            raise Exception('expected a binary view')
+
+        new_bv = args[0]
+        # uses only a bv_controller
+        widget = widget_class(new_bv, parent=parent, name=name, data=data)
+        if not widget:
+            raise Exception('expected widget, got None')
+
+        global dockwidgets
+
+        found = False
+        for (bv, widgets) in dockwidgets:
+            if bv == data:
+                widgets[name] = widget
+                found = True
+
+        if not found:
+            dockwidgets.append((data, {
+                name: widget
+            }))
+
+        widget.destroyed.connect(lambda destroyed: destroy_widget(destroyed, widget, data, name))
+
+        return widget
+    except Exception:
+        traceback.print_exc(file=sys.stderr)
+        return QWidget(parent)
+
+
+def destroy_widget(destroyed, old, data, name):
+    # Gotta be careful to delete the correct widget here
+    for (bv, widgets) in dockwidgets:
+        if bv == data:
+            for (name, widget) in widgets.items():
+                if widget == old:
+                    # If there are no other references to it, this will be the only one and the call
+                    # will delete it and invoke __del__.
+                    widgets.pop(name)
+                    return
+
+
+class BinjaWidgetBase:
+    def __init__(self):
+        self._main_window = None
+        self._menu_bar = None
+        self._tool_menu = None
+
+    @property
+    def main_window(self):
+        if self._main_window is None:
+            self._main_window = find_main_window()
+        return self._main_window
+
+    @property
+    def menu_bar(self):
+        if self._menu_bar is None:
+            self._menu_bar = next(
+                iter(x for x in self._main_window.children() if isinstance(x, QMenuBar))
+            )
+        return self._menu_bar
+
+    @property
+    def tool_menu(self):
+        if self._tool_menu is None:
+            self._tool_menu = next(
+                iter(
+                    x
+                    for x in self._menu_bar.children()
+                    if isinstance(x, QMenu) and x.title() == u"Tools"
+                )
+            )
+        return self._tool_menu
+
+    def add_tool_menu_action(self, name, func):
+        self.tool_menu.addAction(name, func)
+
+
+class BinjaDockWidget(QWidget, DockContextHandler):
+    def __init__(self, name, parent=None):
+        QWidget.__init__(self, parent)
+        DockContextHandler.__init__(self, self, name)
+
+        self.base = BinjaWidgetBase()
+
+        self.show()
+
+    def toggle(self):
+        if self.isVisible():
+            self.hide()
+        else:
+            self.show()
+
+
+class BinjaWidget(QWidget):
+    def __init__(self, tabname):
+        super(BinjaWidget, self).__init__()
+
+
+class LLMChatClientDockWidget(BinjaDockWidget):
+    def __init__(self, ai_api, parent=None, name=None, data=None):
+        super().__init__(name, parent=parent)
+        self.data = data
+        self._widget = None
+        self.ai_api = ai_api
+
+        self._widget = LLMChatClient(self.ai_api)
+        layout = QVBoxLayout()
+        layout.addWidget(self._widget)
+        self.setLayout(layout)
+
+
+def add_llm_chat_to_ui(*args, ai_api: AIAPI = None, **kwargs):
+    # control panel (per BV)
+    dock_handler = DockHandler.getActiveDockHandler()
+    dock_handler.addDockWidget(
+        "LLM Chat",
+        lambda n, p, d: create_widget(LLMChatClientDockWidget, n, p, d, ai_api._dec_interface.bv),
+        Qt.RightDockWidgetArea,
+        Qt.Vertical,
+        True
+    )
\ No newline at end of file
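
`create_widget()` keeps one chat widget per BinaryView in the module-level `dockwidgets` list of `(BinaryView, {name: widget})` pairs. A small sketch of reading that bookkeeping back, for example from Binary Ninja's scripting console; `find_chat_widget` is a hypothetical helper, not part of the diff:

```python
from dailalib.llm_chat.binja import dockwidgets

def find_chat_widget(bv, name="LLM Chat"):
    # dockwidgets holds (BinaryView, {widget_name: widget}) pairs, one entry per view
    for view, widgets in dockwidgets:
        if view == bv:
            return widgets.get(name)
    return None
```
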
diff --git a/dailalib/llm_chat/ida.py b/dailalib/llm_chat/ida.py
new file mode 100644
index 0000000..accbe27
--- /dev/null
+++ b/dailalib/llm_chat/ida.py
@@ -0,0 +1,61 @@
+import logging
+
+from PyQt5 import sip
+from PyQt5.QtWidgets import QWidget, QVBoxLayout
+
+import idaapi
+
+from libbs.ui.version import set_ui_version
+set_ui_version("PyQt5")
+
+from dailalib.llm_chat.llm_chat_ui import LLMChatClient
+
+_l = logging.getLogger(__name__)
+
+# disable the annoying "Running Python script" wait box that freezes IDA at times
+idaapi.set_script_timeout(0)
+
+
+class LLMChatWrapper(object):
+    NAME = "LLM Chat"
+
+    def __init__(self, ai_api, context=None):
+        # create a dockable view
+        self.twidget = idaapi.create_empty_widget(LLMChatWrapper.NAME)
+        self.widget = sip.wrapinstance(int(self.twidget), QWidget)
+        self.widget.name = LLMChatWrapper.NAME
+        self.width_hint = 250
+
+        self._ai_api = ai_api
+        self._context = context
+        self._w = None
+
+        self._init_widgets()
+
+    def _init_widgets(self):
+        self._w = LLMChatClient(self._ai_api, context=self._context)
+        layout = QVBoxLayout()
+        layout.addWidget(self._w)
+        layout.setContentsMargins(2, 2, 2, 2)
+        self.widget.setLayout(layout)
+
+
+def add_llm_chat_to_ui(*args, ai_api=None, context=None, **kwargs):
+    """
+    Open the LLM Chat view and attach it to IDA View-A or Pseudocode-A.
+    """
+    wrapper = LLMChatWrapper(ai_api, context=context)
+    if not wrapper.twidget:
+        _l.info("Unable to find a widget to attach to. You are likely running headlessly.")
+        return None
+
+    flags = idaapi.PluginForm.WOPN_TAB | idaapi.PluginForm.WOPN_RESTORE | idaapi.PluginForm.WOPN_PERSIST
+    idaapi.display_widget(wrapper.twidget, flags)
+    wrapper.widget.visible = True
+
+    target = "Pseudocode-A"
+    dwidget = idaapi.find_widget(target)
+    if not dwidget:
+        target = "IDA View-A"
+
+    idaapi.set_dock_pos(LLMChatWrapper.NAME, target, idaapi.DP_RIGHT)
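
On IDA the chat is attached with `display_widget()` and then pinned to the right of Pseudocode-A (or IDA View-A as a fallback). A hedged sketch of invoking it by hand from IDA's Python console, assuming `litellm_api` is the `LiteLLMAIAPI` instance DAILA has already wired to the decompiler:

```python
from dailalib.llm_chat.ida import add_llm_chat_to_ui

# opens the "LLM Chat" dock and places it next to Pseudocode-A / IDA View-A
add_llm_chat_to_ui(ai_api=litellm_api)
```
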
diff --git a/dailalib/llm_chat/llm_chat_ui.py b/dailalib/llm_chat/llm_chat_ui.py
new file mode 100644
index 0000000..3c1427d
--- /dev/null
+++ b/dailalib/llm_chat/llm_chat_ui.py
@@ -0,0 +1,184 @@
+import typing
+
+from PyQt5.QtWidgets import (
+    QApplication, QWidget, QVBoxLayout, QHBoxLayout, QTextEdit,
+    QPushButton, QLabel, QScrollArea, QFrame
+)
+from PyQt5.QtCore import Qt, QThread, pyqtSignal, QCoreApplication
+from PyQt5.QtGui import QFont
+
+from libbs.artifacts.context import Context
+
+if typing.TYPE_CHECKING:
+    from ..api.litellm.litellm_api import LiteLLMAIAPI
+
+CONTEXT_PROMPT = """
+You are a reverse engineering assistant that helps to understand binaries in a decompiler. Given decompilation
+and questions, you answer them to the best of your ability. Here is the function you are currently working on:
+```
+DEC_TEXT
+```
+
+Acknowledge the context by responding with:
+"I see you are working on function . How can I help you today?"
+"""
+
+
+class LLMChatClient(QWidget):
+    def __init__(self, ai_api: "LiteLLMAIAPI", parent=None, context: Context = None):
+        super(LLMChatClient, self).__init__(parent)
+        self.ai_api = ai_api
+        self.context = context
+        self.setWindowTitle('LLM Chat')
+        self.setGeometry(100, 100, 600, 800)
+
+        # Main layout
+        self.layout = QVBoxLayout(self)
+        self.setLayout(self.layout)
+
+        # Scroll area for chat messages
+        self.chat_area = QScrollArea()
+        self.chat_area.setWidgetResizable(True)
+        self.chat_content = QWidget()
+        self.chat_layout = QVBoxLayout(self.chat_content)
+        self.chat_layout.addStretch(1)
+        self.chat_area.setWidget(self.chat_content)
+
+        # Input area
+        self.input_text = QTextEdit()
+        self.input_text.setFixedHeight(80)
+
+        # Send button
+        self.send_button = QPushButton('Send')
+        self.send_button.setFixedHeight(40)
+        self.send_button.clicked.connect(lambda: self.send_message())
+
+        # Arrange input and send button horizontally
+        self.input_layout = QHBoxLayout()
+        self.input_layout.addWidget(self.input_text)
+        self.input_layout.addWidget(self.send_button)
+
+        # Add widgets to the main layout
+        self.layout.addWidget(self.chat_area)
+        self.layout.addLayout(self.input_layout)
+
+        # Chat history
+        self.chat_history = []
+
+        # create a context for this first message
+        if ai_api.chat_use_ctx:
+            ai_api.info("Collecting context for the current function...")
+            #import remote_pdb; remote_pdb.RemotePdb('localhost', 4444).set_trace()
+            if context is None:
+                context = ai_api._dec_interface.gui_active_context()
+            dec = ai_api._dec_interface.decompile(context.func_addr)
+            dec_text = dec.text if dec is not None else None
+            if dec_text:
+                # put a number in front of each line
+                dec_lines = dec_text.split("\n")
+                dec_text = "\n".join([f"{i + 1} {line}" for i, line in enumerate(dec_lines)])
+                prompt = CONTEXT_PROMPT.replace("DEC_TEXT", dec_text)
+                # set the text to the prompt
+                self.input_text.setText(prompt)
+                self.send_message(add_text=False, role="system")
+
+    def add_message(self, text, is_user):
+        # Message bubble
+        message_label = QLabel(text)
+        message_label.setWordWrap(True)
+        message_label.setFont(QFont('Arial', 12))
+        message_label.setTextInteractionFlags(Qt.TextSelectableByMouse)
+
+        # Bubble styling
+        bubble = QFrame()
+        bubble_layout = QHBoxLayout()
+        bubble.setLayout(bubble_layout)
+
+        if is_user:
+            # User message on the right
+            message_label.setStyleSheet("""
+                background-color: #DCF8C6;
+                color: black;
+                padding: 10px;
+                border-radius: 10px;
+            """)
+            bubble_layout.addStretch()
+            bubble_layout.addWidget(message_label)
+        else:
+            # Assistant message on the left
+            message_label.setStyleSheet("""
+                background-color: #FFFFFF;
+                color: black;
+                padding: 10px;
+                border-radius: 10px;
+            """)
+            bubble_layout.addWidget(message_label)
+            bubble_layout.addStretch()
+
+        self.chat_layout.insertWidget(self.chat_layout.count() - 1, bubble)
+        QCoreApplication.processEvents()
+        self.chat_area.verticalScrollBar().setValue(self.chat_area.verticalScrollBar().maximum())
+
+    def send_message(self, add_text=True, role="user"):
+        user_text = self.input_text.toPlainText().strip()
+        if not user_text:
+            return
+
+        # Display user message
+        if add_text:
+            self.add_message(user_text, is_user=True)
+        self.input_text.clear()
+
+        # Append to chat history
+        self.chat_history.append({"role": role, "content": user_text})
+
+        # Disable input while waiting for response
+        self.input_text.setDisabled(True)
+        self.send_button.setDisabled(True)
+
+        # Start a thread to get the response
+        self.thread = LLMThread(self.chat_history, self.ai_api.model if self.ai_api else "gpt-4o")
+        self.thread.response_received.connect(lambda msg: self.receive_message(msg))
+        self.thread.start()
+
+    def receive_message(self, assistant_message):
+        # Display assistant message
+        self.add_message(assistant_message, is_user=False)
+
+        # Append to chat history
+        self.chat_history.append({"role": "assistant", "content": assistant_message})
+
+        # Re-enable input
+        self.input_text.setDisabled(False)
+        self.send_button.setDisabled(False)
+
+    def closeEvent(self, event):
+        # Ensure that the thread is properly terminated when the window is closed
+        if hasattr(self, 'thread') and self.thread.isRunning():
+            self.thread.terminate()
+        event.accept()
+
+
+class LLMThread(QThread):
+    response_received = pyqtSignal(str)
+
+    def __init__(self, chat_history, model_name):
+        super().__init__()
+        self.chat_history = chat_history.copy()
+        self.model_name = model_name
+
+    def run(self):
+        from litellm import completion
+
+        response = completion(
+            model=self.model_name,
+            messages=self.chat_history,
+            timeout=60,
+        )
+
+        try:
+            answer = response.choices[0].message.content
+        except (KeyError, IndexError) as e:
+            answer = f"Error: {e}. Please close the window and try again."
+
+        self.response_received.emit(answer)
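
`LLMThread` hands the accumulated history straight to `litellm.completion()`, so every entry is an OpenAI-style `{"role", "content"}` message; when `chat_use_ctx` is enabled, the decompilation context prompt is simply the first, system-role entry. A standalone sketch of that call shape (the message text is made up, and an `OPENAI_API_KEY` in the environment is assumed):

```python
from litellm import completion

chat_history = [
    {"role": "system", "content": "You are a reverse engineering assistant."},
    {"role": "user", "content": "What does this function's loop over argv appear to do?"},
]

# mirrors LLMThread.run(): one blocking completion call over the whole history
response = completion(model="gpt-4o", messages=chat_history, timeout=60)
print(response.choices[0].message.content)
```
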