Add LLM Chat Support (IDA, some Binja) (#60)
* Add LLM Chat Support (IDA, some Binja)

* Use local model cost for LiteLLM
mahaloz authored Oct 2, 2024
1 parent 333bd33 commit 23e961d
Showing 7 changed files with 470 additions and 5 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -8,8 +8,8 @@ DAILA was featured in the keynote talk at [HITCON CMT 2023](https://youtu.be/Hbr
## Supported Decompilers and AI Systems
DAILA interacts with the decompiler abstractly through the [LibBS](https://github.com/binsync/libbs) library.
This allows DAILA to support the following decompilers:
- IDA Pro: **>= 7.3**
- Ghidra: **>= 10.1**
- IDA Pro: **>= 8.4**
- Ghidra: **>= 11.1**
- Binary Ninja: **>= 2.4**
- angr-management: **>= 9.0**
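
The abstraction above is what the rest of this commit relies on. As a rough, hedged sketch (not part of the README), the same LibBS calls used in `dailalib/llm_chat/__init__.py` can ask which supported decompiler the plugin is currently running inside:

```python
# Sketch only: DecompilerInterface and the decompiler-name constants below are the
# ones imported elsewhere in this commit; the printing is just for illustration.
from libbs.api import DecompilerInterface
from libbs.decompilers import IDA_DECOMPILER, BINJA_DECOMPILER

# LibBS reports which decompiler's Python environment the code is running in
current_decompiler = DecompilerInterface.find_current_decompiler()
if current_decompiler == IDA_DECOMPILER:
    print("Running inside IDA Pro")
elif current_decompiler == BINJA_DECOMPILER:
    print("Running inside Binary Ninja")
else:
    print(f"Running inside: {current_decompiler}")
```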

@@ -110,6 +110,7 @@ Currently, DAILA supports the following prompts:
- Identify the source of a function
- Find potential vulnerabilities in a function
- Summarize the man page of a library call
- Free prompting... just type in your own prompt!
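
As a hedged illustration (a sketch, not documented usage), the prompts listed above live on the `LiteLLMAIAPI` object, keyed by name in its `prompts_by_name` mapping, so they can be enumerated programmatically:

```python
# Sketch under assumptions: dailalib.api re-exports LiteLLMAIAPI (as dailalib/__init__.py
# suggests), prompts_by_name maps names to prompt objects with a .desc field (as used in
# create_plugin), and the API key below is a placeholder (OPENAI_API_KEY also works).
from dailalib.api import LiteLLMAIAPI

api = LiteLLMAIAPI(api_key="sk-<your-key>", model="gpt-4o")
for prompt_name, prompt in api.prompts_by_name.items():
    # each prompt name is also a callable attribute on the API object, which is
    # what DAILA registers in the decompiler's context menu
    print(f"{prompt_name}: {prompt.desc}")
```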

### VarBERT
VarBERT is a local BERT model from the S&P 2024 paper [""Len or index or count, anything but v1": Predicting Variable Names in Decompilation Output with Transfer Learning"]().
14 changes: 12 additions & 2 deletions dailalib/__init__.py
@@ -1,10 +1,17 @@
__version__ = "3.8.2"
__version__ = "3.9.0"

import os
# stop LiteLLM from querying at all to the remote server
# https://github.com/BerriAI/litellm/blob/4d29c1fb6941e49191280c4fd63961dec1a1e7c5/litellm/__init__.py#L286C20-L286C48
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"

from .api import AIAPI, LiteLLMAIAPI
from libbs.api import DecompilerInterface

from dailalib.llm_chat import get_llm_chat_creator


def create_plugin(*args, **kwargs):
from libbs.api import DecompilerInterface

#
# LLM API (through LiteLLM api)
@@ -16,6 +23,9 @@ def create_plugin(*args, **kwargs):
f"DAILA/LLM/{prompt_name}": (prompt.desc, getattr(litellm_api, prompt_name))
for prompt_name, prompt in litellm_api.prompts_by_name.items()
}
# create context menu for llm chat
gui_ctx_menu_actions["DAILA/LLM/chat"] = ("Open LLM Chat...", get_llm_chat_creator(litellm_api))

# create context menus for others
gui_ctx_menu_actions["DAILA/LLM/Settings/update_api_key"] = ("Update API key...", litellm_api.ask_api_key)
gui_ctx_menu_actions["DAILA/LLM/Settings/update_pmpt_style"] = ("Change prompt style...", litellm_api.ask_prompt_style)
4 changes: 3 additions & 1 deletion dailalib/api/litellm/litellm_api.py
@@ -1,11 +1,11 @@
import typing
from typing import Optional
import os

import tiktoken

from ..ai_api import AIAPI


class LiteLLMAIAPI(AIAPI):
prompts_by_name = []
DEFAULT_MODEL = "gpt-4o"
@@ -30,6 +30,7 @@ def __init__(
model: str = DEFAULT_MODEL,
prompts: Optional[list] = None,
fit_to_tokens: bool = False,
chat_use_ctx: bool = True,
**kwargs
):
super().__init__(**kwargs)
@@ -38,6 +39,7 @@ def __init__(
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
self.model = model
self.fit_to_tokens = fit_to_tokens
self.chat_use_ctx = chat_use_ctx

# delay prompt import
from .prompts import PROMPTS, DEFAULT_STYLE
26 changes: 26 additions & 0 deletions dailalib/llm_chat/__init__.py
@@ -0,0 +1,26 @@
import logging

from libbs.api import DecompilerInterface
from libbs.decompilers import IDA_DECOMPILER, ANGR_DECOMPILER, BINJA_DECOMPILER, GHIDRA_DECOMPILER

from ..api import AIAPI

_l = logging.getLogger(__name__)


def get_llm_chat_creator(ai_api: AIAPI) -> callable:
# determine the current decompiler
current_decompiler = DecompilerInterface.find_current_decompiler()
# fall back to a no-op so decompilers without an LLM chat UI simply do nothing
add_llm_chat_to_ui = lambda *args, **kwargs: None
if current_decompiler == IDA_DECOMPILER:
from dailalib.llm_chat.ida import add_llm_chat_to_ui
elif current_decompiler == BINJA_DECOMPILER:
from dailalib.llm_chat.binja import add_llm_chat_to_ui
else:
_l.warning(f"LLM Chat not supported for decompiler %s", current_decompiler)

def llm_chat_creator_wrapper(*args, **kwargs):
ai_api.info(f"Opening LLM Chat with model {ai_api.model}...")
return add_llm_chat_to_ui(ai_api=ai_api, *args, **kwargs)

return llm_chat_creator_wrapper
181 changes: 181 additions & 0 deletions dailalib/llm_chat/binja.py
@@ -0,0 +1,181 @@
from binaryninjaui import (
UIContext,
DockHandler,
DockContextHandler,
UIAction,
UIActionHandler,
Menu,
)

import traceback
import sys
import logging

from PySide6.QtWidgets import (
QDockWidget,
QWidget,
QApplication,
QMenu,
QMainWindow,
QMenuBar, QVBoxLayout,
)
from PySide6.QtCore import Qt
import binaryninja

from .llm_chat_ui import LLMChatClient
from ..api import AIAPI

_l = logging.getLogger(__name__)

# this code is based on:
# https://github.com/binsync/binsync/blob/fa754795b7d55e4b5de4e12ea7d4b5f9706c23af/plugins/binja_binsync/binja_binsync.py

def find_main_window():
main_window = None
for x in QApplication.allWidgets():
if not isinstance(x, QDockWidget):
continue
main_window = x.parent()
if isinstance(main_window, (QMainWindow, QWidget)):
break
else:
main_window = None

if main_window is None:
# could not locate the main window
raise Exception("Main window not found.")
return main_window


dockwidgets = [ ]


# shamelessly copied from https://github.com/Vector35/debugger
def create_widget(widget_class, name, parent, data, *args):
# It is imperative this function return *some* value because Shiboken will try to deref what we return
# If we return nothing (or throw) there will be a null pointer deref (and we won't even get to see why)
# So in the event of an error or a nothing, return an empty widget that at least stops the crash
try:
# only create the widget for a real BinaryView
if not isinstance(data, binaryninja.BinaryView):
raise Exception('expected a binary view')
new_bv = args[0]
# forward the extra positional argument as the widget's first constructor argument
widget = widget_class(new_bv, parent=parent, name=name, data=data)
if not widget:
raise Exception('expected widget, got None')

global dockwidgets

found = False
for (bv, widgets) in dockwidgets:
if bv == data:
widgets[name] = widget
found = True

if not found:
dockwidgets.append((data, {
name: widget
}))

widget.destroyed.connect(lambda destroyed: destroy_widget(destroyed, widget, data, name))

return widget
except Exception:
traceback.print_exc(file=sys.stderr)
return QWidget(parent)


def destroy_widget(destroyed, old, data, name):
# Gotta be careful to delete the correct widget here
for (bv, widgets) in dockwidgets:
if bv == data:
for (widget_name, widget) in widgets.items():
if widget == old:
# If there are no other references to it, this will be the only one and the call
# will delete it and invoke __del__.
widgets.pop(widget_name)
return


class BinjaWidgetBase:
def __init__(self):
self._main_window = None
self._menu_bar = None
self._tool_menu = None

@property
def main_window(self):
if self._main_window is None:
self._main_window = find_main_window()
return self._main_window

@property
def menu_bar(self):
if self._menu_bar is None:
self._menu_bar = next(
iter(x for x in self.main_window.children() if isinstance(x, QMenuBar))
)
return self._menu_bar

@property
def tool_menu(self):
if self._tool_menu is None:
self._tool_menu = next(
iter(
x
for x in self.menu_bar.children()
if isinstance(x, QMenu) and x.title() == u"Tools"
)
)
return self._tool_menu

def add_tool_menu_action(self, name, func):
self.tool_menu.addAction(name, func)


class BinjaDockWidget(QWidget, DockContextHandler):
def __init__(self, name, parent=None):
QWidget.__init__(self, parent)
DockContextHandler.__init__(self, self, name)

self.base = BinjaWidgetBase()

self.show()

def toggle(self):
if self.isVisible():
self.hide()
else:
self.show()


class BinjaWidget(QWidget):
def __init__(self, tabname):
super(BinjaWidget, self).__init__()


class LLMChatClientDockWidget(BinjaDockWidget):
def __init__(self, ai_api, parent=None, name=None, data=None):
super().__init__(name, parent=parent)
self.data = data
self._widget = None
self.ai_api = ai_api

self._widget = LLMChatClient(self.ai_api)
layout = QVBoxLayout()
layout.addWidget(self._widget)
self.setLayout(layout)


def add_llm_chat_to_ui(*args, ai_api: AIAPI = None, **kwargs):
# LLM chat dock widget (per BinaryView)
dock_handler = DockHandler.getActiveDockHandler()
dock_handler.addDockWidget(
"LLM Chat",
# the dock widget expects the AI API (not the BinaryView) as its first constructor argument
lambda n, p, d: create_widget(LLMChatClientDockWidget, n, p, d, ai_api),
Qt.RightDockWidgetArea,
Qt.Vertical,
True
)
61 changes: 61 additions & 0 deletions dailalib/llm_chat/ida.py
@@ -0,0 +1,61 @@
import logging

from PyQt5 import sip
from PyQt5.QtWidgets import QWidget, QVBoxLayout

import idaapi

from libbs.ui.version import set_ui_version
set_ui_version("PyQt5")

from dailalib.llm_chat.llm_chat_ui import LLMChatClient

_l = logging.getLogger(__name__)

# disable the annoying "Running Python script" wait box that freezes IDA at times
idaapi.set_script_timeout(0)


class LLMChatWrapper(object):
NAME = "LLM Chat"

def __init__(self, ai_api, context=None):
# create a dockable view
self.twidget = idaapi.create_empty_widget(LLMChatWrapper.NAME)
self.widget = sip.wrapinstance(int(self.twidget), QWidget)
self.widget.name = LLMChatWrapper.NAME
self.width_hint = 250

self._ai_api = ai_api
self._context = context
self._w = None

self._init_widgets()

def _init_widgets(self):
self._w = LLMChatClient(self._ai_api, context=self._context)
layout = QVBoxLayout()
layout.addWidget(self._w)
layout.setContentsMargins(2,2,2,2)
self.widget.setLayout(layout)


def add_llm_chat_to_ui(*args, ai_api=None, context=None, **kwargs):
"""
Open the LLM chat view and attach it to IDA View-A or Pseudocode-A.
"""
wrapper = LLMChatWrapper(ai_api, context=context)
if not wrapper.twidget:
_l.info("Unable to find a widget to attach to. You are likely running headlessly")
return None

flags = idaapi.PluginForm.WOPN_TAB | idaapi.PluginForm.WOPN_RESTORE | idaapi.PluginForm.WOPN_PERSIST
idaapi.display_widget(wrapper.twidget, flags)
wrapper.widget.visible = True
target = "Pseudocode-A"
dwidget = idaapi.find_widget(target)

if not dwidget:
target = "IDA View-A"

idaapi.set_dock_pos(LLMChatWrapper.NAME, target, idaapi.DP_RIGHT)
