Skip to content

Commit

Permalink
chat: flake fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
rmackay9 committed Jan 3, 2024
1 parent eacee97 commit 90f18a4
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 29 deletions.
2 changes: 2 additions & 0 deletions MAVProxy/modules/mavproxy_chat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from threading import Thread
import time


class chat(mp_module.MPModule):
def __init__(self, mpstate):

Expand Down Expand Up @@ -110,6 +111,7 @@ def wait_for_command_ack(self, mav_cmd, timeout=1):
del self.command_ack_waiting[mav_cmd]
return False


# initialise module
def init(mpstate):
return chat(mpstate)
47 changes: 29 additions & 18 deletions MAVProxy/modules/mavproxy_chat/chat_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,20 @@
'''

from pymavlink import mavutil
import time, re
import time
import re
from datetime import datetime
from threading import Thread, Lock
import json
import math

try:
from openai import OpenAI
except:
except Exception:
print("chat: failed to import openai. See https://ardupilot.org/mavproxy/docs/modules/chat.html")
exit()


class chat_openai():
def __init__(self, mpstate, status_cb=None, wait_for_command_ack_fn=None):
# keep reference to mpstate
Expand Down Expand Up @@ -51,7 +53,7 @@ def check_connection(self):
if self.client is None:
try:
self.client = OpenAI()
except:
except Exception:
print("chat: failed to connect to OpenAI")
return False

Expand Down Expand Up @@ -90,7 +92,7 @@ def check_connection(self):

# set the OpenAI API key
def set_api_key(self, api_key_str):
self.client = OpenAI(api_key = api_key_str)
self.client = OpenAI(api_key=api_key_str)
self.assistant = None
self.assistant_thread = None

Expand Down Expand Up @@ -126,7 +128,7 @@ def send_to_assistant(self, text):
# wait for one second
time.sleep(0.1)

# retrieve the run
# retrieve the run
latest_run = self.client.beta.threads.runs.retrieve(
thread_id=self.assistant_thread.id,
run_id=self.run.id
Expand Down Expand Up @@ -157,7 +159,9 @@ def send_to_assistant(self, text):
self.send_status(status_message)

# retrieve messages on the thread
reply_messages = self.client.beta.threads.messages.list(self.assistant_thread.id, order = "asc", after=input_message.id)
reply_messages = self.client.beta.threads.messages.list(self.assistant_thread.id,
order="asc",
after=input_message.id)
if reply_messages is None:
return "chat: failed to retrieve messages"

Expand Down Expand Up @@ -228,7 +232,7 @@ def handle_function_call(self, run):
# convert to json
stage_str = "convert output to json"
output = json.dumps(output)
except:
except Exception:
error_message = str(func_name) + ": " + stage_str + " failed"
print("chat: " + error_message)
output = error_message
Expand All @@ -242,12 +246,11 @@ def handle_function_call(self, run):

# send function replies to assistant
try:
run_reply = self.client.beta.threads.runs.submit_tool_outputs(
self.client.beta.threads.runs.submit_tool_outputs(
thread_id=run.thread_id,
run_id=run.id,
tool_outputs=tool_outputs
)
except:
tool_outputs=tool_outputs)
except Exception:
print("chat: error replying to function call")
print(tool_outputs)

Expand Down Expand Up @@ -278,7 +281,7 @@ def get_vehicle_type(self, arguments):
mavutil.mavlink.MAV_TYPE_OCTOROTOR,
mavutil.mavlink.MAV_TYPE_TRICOPTER,
mavutil.mavlink.MAV_TYPE_DODECAROTOR]:
vehicle_type_str = "Copter"
vehicle_type_str = "Copter"
if hearbeat_msg.type == mavutil.mavlink.MAV_TYPE_HELICOPTER:
vehicle_type_str = "Heli"
if hearbeat_msg.type == mavutil.mavlink.MAV_TYPE_ANTENNA_TRACKER:
Expand Down Expand Up @@ -306,7 +309,7 @@ def get_mode_mapping(self, arguments):
# prepare list of modes
mode_list = []
mode_mapping = self.mpstate.master().mode_mapping()

# handle request for all modes
if mode_name is None and mode_number is None:
for mname in mode_mapping:
Expand Down Expand Up @@ -335,7 +338,7 @@ def get_vehicle_state(self, arguments):
hearbeat_msg = self.mpstate.master().messages.get('HEARTBEAT', None)
if hearbeat_msg is None:
mode_number = 0
print ("chat: get_vehicle_state: vehicle mode is unknown")
print("chat: get_vehicle_state: vehicle mode is unknown")
else:
mode_number = hearbeat_msg.custom_mode
return {
Expand Down Expand Up @@ -397,7 +400,10 @@ def send_mavlink_command_int(self, arguments):
x = arguments.get("x", 0)
y = arguments.get("y", 0)
z = arguments.get("z", 0)
self.mpstate.master().mav.command_int_send(target_system, target_component, frame, command, current, autocontinue, param1, param2, param3, param4, x, y, z)
self.mpstate.master().mav.command_int_send(target_system, target_component,
frame, command, current, autocontinue,
param1, param2, param3, param4,
x, y, z)

# wait for command ack
mav_result = self.wait_for_command_ack_fn(command)
Expand Down Expand Up @@ -441,7 +447,12 @@ def send_mavlink_set_position_target_global_int(self, arguments):
afz = arguments.get("afz", 0)
yaw = arguments.get("yaw", 0)
yaw_rate = arguments.get("yaw_rate", 0)
self.mpstate.master().mav.set_position_target_global_int_send(time_boot_ms, target_system, target_component, coordinate_frame, type_mask, lat_int, lon_int, alt, vx, vy, vz, afx, afy, afz, yaw, yaw_rate)
self.mpstate.master().mav.set_position_target_global_int_send(time_boot_ms, target_system, target_component,
coordinate_frame, type_mask,
lat_int, lon_int, alt,
vx, vy, vz,
afx, afy, afz,
yaw, yaw_rate)
return "set_position_target_global_int sent"

# get a list of mavlink message names that can be retrieved using the get_mavlink_message function
Expand Down Expand Up @@ -631,7 +642,7 @@ def wrap_latitude(self, latitude_deg):
if latitude_deg < -90:
return -(180 + latitude_deg)
return latitude_deg

# wrap longitude to range -180 to 180
def wrap_longitude(self, longitude_deg):
if longitude_deg > 180:
Expand All @@ -647,7 +658,7 @@ def send_status(self, status):

# returns true if string contains regex characters
def contains_regex(self, string):
regex_characters = ".^$*+?{}[]\|()"
regex_characters = ".^$*+?{}[]\\|()"
for x in regex_characters:
if string.count(x):
return True
Expand Down
13 changes: 7 additions & 6 deletions MAVProxy/modules/mavproxy_chat/chat_voice_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
import pyaudio # install using, "sudo apt-get install python3-pyaudio"
import wave # install with "pip3 install wave"
from openai import OpenAI
except:
except Exception:
print("chat: failed to import pyaudio, wave or openai. See https://ardupilot.org/mavproxy/docs/modules/chat.html")
exit()


class chat_voice_to_text():
def __init__(self):
# initialise OpenAI connection
Expand All @@ -21,7 +22,7 @@ def __init__(self):

# set the OpenAI API key
def set_api_key(self, api_key_str):
self.client = OpenAI(api_key = api_key_str)
self.client = OpenAI(api_key=api_key_str)

# check connection to OpenAI assistant and connect if necessary
# returns True if connection is good, False if not
Expand All @@ -30,7 +31,7 @@ def check_connection(self):
if self.client is None:
try:
self.client = OpenAI()
except:
except Exception:
print("chat: failed to connect to OpenAI")
return False

Expand All @@ -46,7 +47,7 @@ def record_audio(self):
# Open stream
try:
stream = p.open(format=pyaudio.paInt16, channels=1, rate=44100, input=True, frames_per_buffer=1024)
except:
except Exception:
print("chat: failed to connect to microphone")
return None

Expand Down Expand Up @@ -85,7 +86,7 @@ def convert_audio_to_text(self, audio_filename):
# Process with Whisper
audio_file = open(audio_filename, "rb")
transcript = self.client.audio.transcriptions.create(
model="whisper-1",
file=audio_file,
model="whisper-1",
file=audio_file,
response_format="text")
return transcript
11 changes: 6 additions & 5 deletions MAVProxy/modules/mavproxy_chat/chat_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from MAVProxy.modules.mavproxy_chat import chat_openai, chat_voice_to_text
from threading import Thread


class chat_window():
def __init__(self, mpstate, wait_for_command_ack_fn):
# keep reference to mpstate
Expand All @@ -35,7 +36,7 @@ def __init__(self, mpstate, wait_for_command_ack_fn):

# add api key input window
self.apikey_frame = wx.Frame(None, title="Input OpenAI API Key", size=(560, 50))
self.apikey_text_input = wx.TextCtrl(self.apikey_frame, id=-1, pos=(10, 10), size=(450, -1), style = wx.TE_PROCESS_ENTER)
self.apikey_text_input = wx.TextCtrl(self.apikey_frame, id=-1, pos=(10, 10), size=(450, -1), style=wx.TE_PROCESS_ENTER)
self.apikey_set_button = wx.Button(self.apikey_frame, id=-1, label="Set", pos=(470, 10), size=(75, 25))
self.apikey_frame.Bind(wx.EVT_BUTTON, self.apikey_set_button_click, self.apikey_set_button)
self.apikey_frame.Bind(wx.EVT_TEXT_ENTER, self.apikey_set_button_click, self.apikey_text_input)
Expand All @@ -48,17 +49,17 @@ def __init__(self, mpstate, wait_for_command_ack_fn):
# add a record button
self.record_button = wx.Button(self.frame, id=-1, label="Rec", size=(75, 25))
self.frame.Bind(wx.EVT_BUTTON, self.record_button_click, self.record_button)
self.horiz_sizer.Add(self.record_button, proportion = 0, flag = wx.ALIGN_TOP | wx.ALL, border = 5)
self.horiz_sizer.Add(self.record_button, proportion=0, flag=wx.ALIGN_TOP | wx.ALL, border=5)

# add an input text box
self.text_input = wx.TextCtrl(self.frame, id=-1, value="", size=(450, -1), style = wx.TE_PROCESS_ENTER)
self.text_input = wx.TextCtrl(self.frame, id=-1, value="", size=(450, -1), style=wx.TE_PROCESS_ENTER)
self.frame.Bind(wx.EVT_TEXT_ENTER, self.text_input_change, self.text_input)
self.horiz_sizer.Add(self.text_input, proportion = 1, flag = wx.ALIGN_TOP | wx.ALL, border = 5)
self.horiz_sizer.Add(self.text_input, proportion=1, flag=wx.ALIGN_TOP | wx.ALL, border=5)

# add a send button
self.send_button = wx.Button(self.frame, id=-1, label="Send", size=(75, 25))
self.frame.Bind(wx.EVT_BUTTON, self.send_button_click, self.send_button)
self.horiz_sizer.Add(self.send_button, proportion = 0, flag = wx.ALIGN_TOP | wx.ALL, border = 5)
self.horiz_sizer.Add(self.send_button, proportion=0, flag=wx.ALIGN_TOP | wx.ALL, border=5)

# add a reply box and read-only text box
self.text_reply = wx.TextCtrl(self.frame, id=-1, size=(600, 80), style=wx.TE_READONLY | wx.TE_MULTILINE | wx.TE_RICH)
Expand Down

0 comments on commit 90f18a4

Please sign in to comment.