Skip to content

Commit

Permalink
fix: anthropic adapter with multiple consecutive user prompts
Browse files Browse the repository at this point in the history
  • Loading branch information
olimorris committed Mar 18, 2024
1 parent 145d04c commit ea789e6
Showing 1 changed file with 42 additions and 17 deletions.
59 changes: 42 additions & 17 deletions lua/codecompanion/adapters/anthropic.lua
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,46 @@ local function get_system_prompt(tbl)
end
end

local function merge_messages(messages)
local new_messages = {}
local temp_content = {} -- Temporary storage for consecutive "user" messages content
local last_role = nil -- Track the last role we encountered

for _, message in ipairs(messages) do
if message.role == "user" then
if last_role == "user" then
-- If the last role was also "user", we continue accumulating the content
table.insert(temp_content, message.content)
else
-- If we encounter "user" after a different role, start a new accumulation
temp_content = { message.content }
end
else
-- For any non-user message:
if last_role == "user" then
-- If the last message was a user message, we need to insert the accumulated content first
table.insert(new_messages, {
role = "user",
content = table.concat(temp_content, " "),
})
end
-- Insert the current non-user message
table.insert(new_messages, message)
end
last_role = message.role
end

-- After looping, check if the last messages were from "user" and need to be inserted
if last_role == "user" then
table.insert(new_messages, {
role = "user",
content = table.concat(temp_content, " "),
})
end

return new_messages
end

---@class CodeCompanion.Adapter
---@field name string
---@field url string
Expand Down Expand Up @@ -56,23 +96,8 @@ return {
table.remove(messages, system_prompt_index)
end

-- Combine user prompts into a single prompt
local user_role_content = {}
for i = #messages, 1, -1 do
local message = messages[i]
if message.role == "user" then
table.insert(user_role_content, 1, message.content)
table.remove(messages, i)
end
end
local output = table.concat(user_role_content, " ")

if output ~= "" then
table.insert(messages, {
role = "user",
content = output,
})
end
-- Combine consecutive user prompts into a single prompt
messages = merge_messages(messages)

return { messages = messages }
end,
Expand Down

0 comments on commit ea789e6

Please sign in to comment.