Commit 717a733

[FEATURE] Compile time model selection, summarization

darrensiegel authored Nov 27, 2023
2 parents 36c955f + 718d96c
Showing 9 changed files with 226 additions and 90 deletions.
20 changes: 20 additions & 0 deletions lib/oli/conversation/common.ex
@@ -0,0 +1,20 @@
defmodule Oli.Conversation.Common do
def estimate_token_length(function) when is_map(function) do
Jason.encode!(function)
|> estimate_token_length()
end

def estimate_token_length(content) do
String.length(content) |> div(4)
end

def summarize_prompt() do
"""
You are a conversational summarization agent tasked with summarizing long conversations between
human users and a large language model conversational agent. Your primary goal is to summarize the conversation
into a single "summary" paragraph of no more than 5 or 6 sentences that captures the essence of the conversation.
You are given a conversation below as a series of "assistant" and "user" messages. You can summarize the conversation
along the lines of "user asked about X. Assistant responded with Y. User asked about Z. Assistant responded with A. etc"
"""
end
end
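
The chars/4 heuristic mirrors the common rule of thumb that one token is roughly four characters of English text; an exact count would require a real tokenizer (e.g. tiktoken). A quick illustration of the two clauses (hypothetical IEx session, not part of the commit):

iex> Oli.Conversation.Common.estimate_token_length("Hello, how are you today?")
6
iex> # Maps (such as function schemas) are JSON-encoded first, then measured
iex> Oli.Conversation.Common.estimate_token_length(%{name: "up_next"})
4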
141 changes: 68 additions & 73 deletions lib/oli/conversation/dialogue.ex
@@ -1,89 +1,44 @@
defmodule Oli.Conversation.Dialogue do
require Logger
alias Oli.Conversation.Message
alias Oli.Conversation.Functions
import Oli.Conversation.Common
alias Oli.Conversation.Model

defstruct [
+ :model,
+ :rendered_messages,
:messages,
:response_handler_fn,
- :functions
+ :functions,
+ :functions_token_length
]

- def init(system_message, response_handler_fn) do
+ @token_usage_high_watermark 0.9
+
+ def new(system_message, response_handler_fn, options \\ []) do
+ model = options[:model] || Oli.Conversation.Model.default()
+
+ system_message = Message.new(:system, system_message)

%__MODULE__{
- messages: [%Message{role: :system, content: system_message}],
+ model: Oli.Conversation.Model.model(model),
+ rendered_messages: [],
+ messages: [system_message],
response_handler_fn: response_handler_fn,
- functions: [
- %{
- name: "up_next",
- description:
- "Returns the next scheduled lessons in the course as a list of objects with the following keys: title, url, due_date, num_attempts_taken",
- parameters: %{
- type: "object",
- properties: %{
- current_user_id: %{
- type: "integer",
- description: "The current student's user id"
- },
- section_id: %{
- type: "integer",
- description: "The current course section's id"
- }
- },
- required: ["current_user_id", "section_id"]
- }
- },
- %{
- name: "avg_score_for",
- description:
- "Returns average score across all scored assessments, as a floating point number between 0 and 1, for a given user and section",
- parameters: %{
- type: "object",
- properties: %{
- current_user_id: %{
- type: "integer",
- description: "The current student's user id"
- },
- section_id: %{
- type: "integer",
- description: "The current course section's id"
- }
- },
- required: ["current_user_id", "section_id"]
- }
- },
- %{
- name: "relevant_course_content",
- description: """
- Useful when a question asked by a student cannot be adequately answered by the context of the current lesson.
- Allows the retrieval of relevant course content from other lessons in the course based on the
- student's question. Returns an array of course lessons with the following keys: title, url, content.
- """,
- parameters: %{
- type: "object",
- properties: %{
- student_input: %{
- type: "string",
- description: "The student question or input"
- },
- section_id: %{
- type: "integer",
- description: "The current course section's id"
- }
- },
- required: ["student_input", "section_id"]
- }
- }
- ]
+ functions: Functions.functions(),
+ functions_token_length: Functions.total_token_length()
}
end

def engage(
- %__MODULE__{messages: messages, response_handler_fn: response_handler_fn} = dialogue,
+ %__MODULE__{model: model, messages: messages, response_handler_fn: response_handler_fn} =
+ dialogue,
:async
) do
OpenAI.chat_completion(
[
model: "gpt-3.5-turbo",
model: model,
messages: encode_messages(messages),
functions: dialogue.functions,
stream: true
@@ -95,20 +50,18 @@ defmodule Oli.Conversation.Dialogue do
{:delta, type, content} ->
response_handler_fn.(dialogue, type, content)

- e ->
- IO.inspect(e)
+ _e ->
+ Logger.info("Response finished")
end
end)
|> Enum.to_list()
- |> IO.inspect()
end

- def engage(%__MODULE__{messages: messages} = dialogue, :sync) do
+ def engage(%__MODULE__{messages: messages, model: model} = dialogue, :sync) do
result =
OpenAI.chat_completion(
[
model: "gpt-3.5-turbo",
model: model,
messages: encode_messages(messages),
functions: dialogue.functions
],
@@ -127,10 +80,52 @@ defmodule Oli.Conversation.Dialogue do
end)
end

- def add_message(%__MODULE__{messages: messages} = dialog, message) do
+ def add_message(
+ %__MODULE__{messages: messages, rendered_messages: rendered_messages} = dialog,
+ message
+ ) do
+ dialog = %{dialog | rendered_messages: rendered_messages ++ [message]}
%{dialog | messages: messages ++ [message]}
end

def summarize(%__MODULE__{messages: messages, model: model} = dialog) do
summarize_messages =
case messages do
[_system | rest] -> [Message.new(:system, summarize_prompt()) | rest]
end

[system | _rest] = messages

case OpenAI.chat_completion(
[model: model, messages: encode_messages(summarize_messages)],
config(:sync)
) do
{:ok, %{choices: [first | _rest]}} ->
summary = Message.new(:system, first["message"]["content"])

messages = [system, summary]

%{dialog | messages: messages}

_e ->
IO.inspect("Failed to summarize")
dialog
end
end

def should_summarize?(%__MODULE__{model: model} = dialog) do
total_token_length(dialog) > Model.token_limit(model) * @token_usage_high_watermark
end

def total_token_length(%__MODULE__{
messages: messages,
functions_token_length: functions_token_length
}) do
Enum.reduce(messages, functions_token_length, fn message, acc ->
acc + message.token_length
end)
end

defp delta(chunk) do
case chunk["choices"] do
[] -> {:finished}
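Taken together, should_summarize?/1 and summarize/1 enable a compaction loop at the call site: once estimated usage crosses 90% of the model's token limit (the @token_usage_high_watermark), the message list is collapsed to the original system message plus a one-paragraph summary. A minimal sketch of such a caller, assuming compaction happens before each engagement (hypothetical code, not part of this commit):

alias Oli.Conversation.{Dialogue, Message}

# Hypothetical call site: compact the dialogue when it nears the
# model's context window, then continue the conversation.
def respond(dialogue, user_input) do
  dialogue = Dialogue.add_message(dialogue, Message.new(:user, user_input))

  dialogue =
    if Dialogue.should_summarize?(dialogue),
      do: Dialogue.summarize(dialogue),
      else: dialogue

  Dialogue.engage(dialogue, :async)
end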
Empty file removed lib/oli/conversation/function.ex
Empty file.
71 changes: 71 additions & 0 deletions lib/oli/conversation/functions.ex
@@ -0,0 +1,71 @@
defmodule Oli.Conversation.Functions do
import Oli.Conversation.Common

@functions [
%{
name: "up_next",
description:
"Returns the next scheduled lessons in the course as a list of objects with the following keys: title, url, due_date, num_attempts_taken",
parameters: %{
type: "object",
properties: %{
current_user_id: %{
type: "integer",
description: "The current student's user id"
},
section_id: %{
type: "integer",
description: "The current course section's id"
}
},
required: ["current_user_id", "section_id"]
}
},
%{
name: "avg_score_for",
description:
"Returns average score across all scored assessments, as a floating point number between 0 and 1, for a given user and section",
parameters: %{
type: "object",
properties: %{
current_user_id: %{
type: "integer",
description: "The current student's user id"
},
section_id: %{
type: "integer",
description: "The current course section's id"
}
},
required: ["current_user_id", "section_id"]
}
},
%{
name: "relevant_course_content",
description: """
Useful when a question asked by a student cannot be adequately answered by the context of the current lesson.
Allows the retrieval of relevant course content from other lessons in the course based on the
student's question. Returns an array of course lessons with the following keys: title, url, content.
""",
parameters: %{
type: "object",
properties: %{
student_input: %{
type: "string",
description: "The student question or input"
},
section_id: %{
type: "integer",
description: "The current course section's id"
}
},
required: ["student_input", "section_id"]
}
}
]

def functions, do: @functions

def total_token_length,
do: Enum.reduce(@functions, 0, fn f, acc -> acc + estimate_token_length(f) end)
end
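
Since these schemas accompany every chat completion request, their token cost is constant per request; Dialogue.new/3 therefore computes the estimate once and seeds each dialogue's running total with it. Illustrative check (hypothetical IEx session):

iex> length(Oli.Conversation.Functions.functions())
3
iex> Oli.Conversation.Functions.total_token_length() > 0
true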
11 changes: 8 additions & 3 deletions lib/oli/conversation/message.ex
@@ -1,23 +1,28 @@
defmodule Oli.Conversation.Message do
import Oli.Conversation.Common

@derive Jason.Encoder
defstruct [
:role,
:content,
:name
:name,
:token_length
]

def new(role, content) do
%__MODULE__{
role: role,
content: content
content: content,
token_length: estimate_token_length(content)
}
end

def new(role, content, name) do
%__MODULE__{
role: role,
content: content,
name: name
name: name,
token_length: estimate_token_length(content)
}
end
end
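
Each message now carries its own estimate from construction time, so Dialogue.total_token_length/1 can sum cached values rather than re-measuring content on every check. For example (hypothetical IEx session; 22 characters / 4 ≈ 5 tokens):

iex> Oli.Conversation.Message.new(:user, "What is due next week?")
%Oli.Conversation.Message{
  role: :user,
  content: "What is due next week?",
  name: nil,
  token_length: 5
}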
18 changes: 18 additions & 0 deletions lib/oli/conversation/model.ex
@@ -0,0 +1,18 @@
defmodule Oli.Conversation.Model do
@default_model :fast

def default(), do: @default_model

def model(:fast), do: "gpt-3.5-turbo"
def model(:large_context), do: "gpt-4"
def model(:largest_context), do: "gpt-4-1106-preview"

def token_limit("gpt-3.5-turbo"), do: 4096
def token_limit("gpt-4"), do: 8192
def token_limit("gpt-4-1106-preview"), do: 128_000

def token_limit(atom) when is_atom(atom) do
model(atom)
|> token_limit()
end
end
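
The atoms give call sites a stable compile-time vocabulary while the concrete model strings and context limits live in one place. A hypothetical selection at dialogue creation (system_prompt and handler_fn are placeholders; model: is the option Dialogue.new/3 reads):

# Default model (:fast -> "gpt-3.5-turbo", 4,096-token limit)
dialogue = Oli.Conversation.Dialogue.new(system_prompt, handler_fn)

# Opt into the 128k-context model for long-running conversations
dialogue = Oli.Conversation.Dialogue.new(system_prompt, handler_fn, model: :largest_context)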
1 change: 1 addition & 0 deletions lib/oli/search/embedding_worker.ex
@@ -44,6 +44,7 @@ defmodule Oli.Search.EmbeddingWorker do
|> Map.delete(:id)
|> Map.delete(:updated_at)
|> Map.delete(:inserted_at)
|> Map.delete(:distance)
end)

expected_num_inserts = Enum.count(attrs)
6 changes: 3 additions & 3 deletions lib/oli/search/embeddings.ex
@@ -21,8 +21,8 @@ defmodule Oli.Search.Embeddings do
on: r.id == re.revision_id,
where: pr.publication_id == ^publication_id,
where: re.chunk_type == :paragraph,
- order_by: cosine_distance(re.embedding, ^embedding),
- limit: 5,
+ order_by: l2_distance(re.embedding, ^embedding),
+ limit: 10,
select_merge: %{title: r.title, distance: cosine_distance(re.embedding, ^embedding)}

Repo.all(query)
@@ -48,7 +48,7 @@ defmodule Oli.Search.Embeddings do
[spp, _p, re, _r],
spp.section_id == ^section_id and re.chunk_type == :paragraph
)
- |> order_by([_spp, _p, re, _r], cosine_distance(re.embedding, ^embedding))
+ |> order_by([_spp, _p, re, _r], l2_distance(re.embedding, ^embedding))
|> limit(10)
|> select([_spp, _p, re, _r], re.revision_id)

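These changes order retrieval candidates by Euclidean (L2) distance while still reporting cosine distance alongside each result, and widen the candidate pool to 10. Both distance functions are Ecto query macros from the pgvector-elixir integration; a minimal import sketch, assuming that package is the source:

import Ecto.Query
# Provided by pgvector-elixir for use inside Ecto queries
import Pgvector.Ecto.Query, only: [l2_distance: 2, cosine_distance: 2]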