Skip to content

Commit

Permalink
Improve the answer using ChatGPT.
Browse files Browse the repository at this point in the history
  • Loading branch information
yjcyxky committed Aug 12, 2024
1 parent a0cf4cb commit 7d7227b
Show file tree
Hide file tree
Showing 13 changed files with 243 additions and 75 deletions.
29 changes: 5 additions & 24 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ itertools = "0.13.0"
urlencoding = "2.1.3"

# Models
openai-api-rs = "2.1.4"
openai-api-rs = "5.0.4"

# Algorithms
kiddo = "4.2.0" # for KNN
Expand Down
3 changes: 2 additions & 1 deletion examples/example.sql
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ SELECT
vector_to_float4(rte.embedding, 400, false),
vector_to_float4(ee2.embedding, 400, false),
12.0,
true
true,
false
) AS score
FROM
biomedgps_entity_embedding ee1,
Expand Down
2 changes: 1 addition & 1 deletion migrations/20230912_enable_searching.up.sql
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ ALTER TABLE biomedgps_knowledge_curation
ADD COLUMN payload JSONB DEFAULT '{"project_id": "0", "organization_id": "0"}';

-- Enable intelligent searching for the entity table
CREATE EXTENSION pg_trgm;
CREATE EXTENSION IF NOT EXISTS pg_trgm;


CREATE INDEX IF NOT EXISTS idx_trgm_id_entity_table ON biomedgps_entity USING gin(id gin_trgm_ops);
Expand Down
32 changes: 32 additions & 0 deletions src/api/publication.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::model::llm::ChatBot;
use anyhow;
use log::info;
use poem_openapi::Object;
Expand Down Expand Up @@ -216,6 +217,37 @@ impl Publication {
})
}

pub async fn fetch_summary_by_chatgpt(
question: &str,
publications: &Vec<Publication>,
) -> Result<PublicationsSummary, anyhow::Error> {
let openai_api_key = std::env::var("OPENAI_API_KEY").unwrap();
if openai_api_key.is_empty() {
return Err(anyhow::Error::msg("OPENAI_API_KEY not found"));
}

let chatbot = ChatBot::new("GPT4", &openai_api_key);

let publications = publications.iter().map(|p| {
format!("Title: {}\nAuthors: {}\nJournal: {}\nYear: {}\nSummary: {}\nAbstract: {}\nDOI: {}\n", p.title, p.authors.join(", "), p.journal, p.year.unwrap_or(0), p.summary, p.article_abstract.as_ref().unwrap_or(&"".to_string()), p.doi.as_ref().unwrap_or(&"".to_string()))
}).collect::<Vec<String>>();

let prompt = format!(
"I have a collection of papers wrappered by the ```:\n```\n{}\n```\n\nPlease carefully analyze these papers to answer the following question: \n{}\n\nIn your response, please provide a well-integrated analysis that directly answers the question. Include citations from specific papers to support your answer, and ensure that the reasoning behind your answer is clearly explained. Reference relevant details from the papers' summaries or abstracts as needed.",
publications.join("\n"),
question,
);

let response = chatbot.answer(prompt).await?;
Ok(PublicationsSummary {
summary: response,
daily_limit_reached: false,
is_disputed: false,
is_incomplete: false,
results_analyzed_count: 0,
})
}

pub async fn fetch_summary(search_id: &str) -> Result<PublicationsSummary, anyhow::Error> {
let api_token = match std::env::var("GUIDESCOPER_API_TOKEN") {
Ok(token) => token,
Expand Down
25 changes: 25 additions & 0 deletions src/api/route.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,31 @@ impl BiomedgpsApi {
}
}

/// Call `/api/v1/publications-summary` with query params to fetch publication summary.
#[oai(
path = "/publications-summary",
method = "post",
tag = "ApiTags::KnowledgeGraph",
operation_id = "answerQuestionWithPublications"
)]
async fn answer_question_with_publications(
&self,
publications: Json<Vec<Publication>>,
question: Query<String>,
_token: CustomSecurityScheme,
) -> GetPublicationsSummaryResponse {
let question = question.0;
let publications = publications.0;
match Publication::fetch_summary_by_chatgpt(&question, &publications).await {
Ok(result) => GetPublicationsSummaryResponse::ok(result),
Err(e) => {
let err = format!("Failed to fetch publications summary: {}", e);
warn!("{}", err);
return GetPublicationsSummaryResponse::bad_request(err);
}
}
}

/// Call `/api/v1/publications/:id` to fetch a publication.
#[oai(
path = "/publications/:id",
Expand Down
62 changes: 58 additions & 4 deletions src/bin/biomedgps-cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ pub struct CleanDBArguments {
#[structopt(name = "database_url", short = "d", long = "database-url")]
database_url: Option<String>,

/// Which table to clean. e.g. entity, relation, entity_metadata, relation_metadata, knowledge_curation, subgraph, entity2d, compound-disease-symptom, gene-disease-symptom, knowledge-score, embedding, graph etc.
#[structopt(name = "table", short = "t", long = "table", possible_values = &["entity", "entity2d", "relation", "relation_metadata", "entity_metadata", "knowledge_curation", "subgraph", "compound-disease-symptom", "gene-disease-symptom", "knowledge-score", "embedding", "graph"])]
table: String,
/// [Required] The table name to clean. e.g We will empty all entity-related tables if you use the entity table name. such as entity, entity_metadata, entity2d.
#[structopt(name = "table", short = "t", long = "table", possible_values = &["entity", "relation", "embedding", "subgraph", "curation", "score", "message", "metadata"], multiple = true)]
table: Vec<String>,
}

/// Import data files into database, such as entity, relation, entity_metadata, relation_metadata, knowledge_curation, subgraph, entity2d etc. When you import the entity data, we will also sync the entity data to the graph database. But the relation data will be synced to the graph database in the cachetable command, because we need to compute the score for the relation data first. The entity_metadata and relation_metadata are generated by the importdb command automatically, actually, you don't need to prepare the entity_metadata and relation_metadata files. But you must use the importdb command manually to upgrade the entity_metadata and relation_metadata tables after the entity and relation tables are upgraded or the first time you run the application. In the most cases, you don't need to import knowledge_curation and subgraph data, we might import them at the migration stage. The entity_2d table is used to store the 2D embedding data, you need to prepare the 2D embedding data manually. If you have multiple models, you might need to choose one model to compute the 2D embedding data. The 2D embedding data is used to visualize the entity data in the 2D space.
Expand Down Expand Up @@ -739,7 +739,61 @@ async fn main() {
.await
}
SubCommands::CleanDB(arguments) => {
info!("To be implemented.")
// Resolve the database URL: prefer the CLI flag, fall back to the
// DATABASE_URL environment variable, and abort if neither is set.
let database_url = if arguments.database_url.is_none() {
match std::env::var("DATABASE_URL") {
Ok(v) => v,
Err(_) => {
error!("{}", "DATABASE_URL is not set.");
std::process::exit(1);
}
}
} else {
arguments.database_url.unwrap()
};

// Open a Postgres connection pool; cleaning cannot proceed without one.
let pool = match sqlx::PgPool::connect(&database_url).await {
Ok(v) => v,
Err(e) => {
error!("Connect to database failed: {}", e);
std::process::exit(1);
}
};

// Map each logical table-group name accepted by `--table` to the
// physical tables it covers, so e.g. `entity` empties all three
// entity-related tables in one go.
let mut table_names_map = HashMap::<&str, Vec<&str>>::new();
let pairs = vec![
("message", vec!["biomedgps_ai_message"]),
("score", vec!["biomedgps_compound_disease_symptom_score", "biomedgps_gene_disease_symptom_score", "biomedgps_relation_with_score"]),
("metadata", vec!["biomedgps_compound_metadata", "biomedgps_journal_metadata"]),
("entity", vec!["biomedgps_entity", "biomedgps_entity2d", "biomedgps_entity_metadata"]),
("relation", vec!["biomedgps_relation", "biomedgps_relation_metadata"]),
("embedding", vec!["biomedgps_entity_embedding", "biomedgps_relation_embedding", "biomedgps_embedding_metadata"]),
("subgraph", vec!["biomedgps_subgraph"]),
("curation", vec!["biomedgps_knowledge_curation"])
];

for pair in pairs {
table_names_map.insert(pair.0, pair.1);
}


// Truncate every physical table behind each requested group.
// An unknown group name aborts the whole run immediately.
let tables = arguments.table;
for table in tables {
let table_names = table_names_map.get(table.as_str());
if table_names.is_none() {
error!("The table name is not supported.");
std::process::exit(1);
}

let table_names = table_names.unwrap();
for table_name in table_names {
// NOTE(review): plain TRUNCATE (no CASCADE) — this will fail if
// another table holds a foreign key into the target; confirm
// intended ordering/constraints against the schema.
let sql = format!("TRUNCATE TABLE {}", table_name);
match sqlx::query(&sql).execute(&pool).await {
Ok(_) => info!("Clean the {} table successfully.", table_name),
Err(e) => error!("Clean the {} table failed: {}", table_name, e),
}
}
}

}
SubCommands::StatDB(arguments) => {
let database_url = if arguments.database_url.is_none() {
Expand Down
3 changes: 2 additions & 1 deletion src/model/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1247,7 +1247,8 @@ impl Graph {
// vector_to_float4(rte.embedding, 400, false),
// vector_to_float4(ee2.embedding, 400, false),
// 12.0,
// true
// true,
// false
// ) AS score
// FROM
// biomedgps_entity_embedding ee1,
Expand Down
32 changes: 18 additions & 14 deletions src/model/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ use chrono::serde::ts_seconds;
use chrono::{DateTime, Utc};
use lazy_static::lazy_static;
use log::warn;
use openai_api_rs::v1::api::Client;
use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest, FunctionCall, MessageRole};
use openai_api_rs::v1::common::{GPT3_5_TURBO, GPT4, GPT4_1106_PREVIEW};
use openai_api_rs::v1::api::OpenAIClient;
use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest, MessageRole, ToolCall};
use openai_api_rs::v1::common::{GPT3_5_TURBO, GPT4_O};
use openssl::hash::{hash, MessageDigest};
use poem_openapi::{Enum, Object};
use regex::Regex;
Expand Down Expand Up @@ -474,7 +474,7 @@ where
if self.message.len() > 0 {
return Ok(self);
} else {
self.message = match chatbot.answer(prompt) {
self.message = match chatbot.answer(prompt).await {
Ok(message) => message,
Err(e) => {
warn!("Failed to answer the question: {}", e.to_string());
Expand Down Expand Up @@ -508,9 +508,9 @@ pub struct ChatBot {
role: MessageRole,
name: Option<String>,
content: Option<String>,
function_call: Option<FunctionCall>,
tool_call: Option<Vec<ToolCall>>,
model_name: String,
client: Client,
client: OpenAIClient,
}

impl ChatBot {
Expand All @@ -519,37 +519,39 @@ impl ChatBot {
// GPT4 or GPT4_1106_PREVIEW
// https://platform.openai.com/account/limits
//
GPT4_1106_PREVIEW.to_string()
GPT4_O.to_string()
} else {
GPT3_5_TURBO.to_string()
};

let client = Client::new(openai_api_key.to_string());
let client = OpenAIClient::new(openai_api_key.to_string());

ChatBot {
role: MessageRole::user,
name: None,
content: None,
function_call: None,
tool_call: None,
model_name: model,
client: client,
}
}

pub fn answer(&self, prompt: String) -> Result<String, anyhow::Error> {
pub async fn answer(&self, prompt: String) -> Result<String, anyhow::Error> {
let model_name = self.model_name.clone();
let req = ChatCompletionRequest::new(
model_name,
vec![chat_completion::ChatCompletionMessage {
role: self.role.clone(),
content: prompt,
content: chat_completion::Content::Text(prompt),
name: self.name.clone(),
function_call: self.function_call.clone(),
// TODO: How to use the tool_call?
tool_calls: None,
tool_call_id: None,
}],
);

let req = req.temperature(0.5);
let result = self.client.chat_completion(req)?;
let result = self.client.chat_completion(req).await?;
let message = result.choices[0].message.content.clone();

match message {
Expand Down Expand Up @@ -580,7 +582,9 @@ mod tests {
xrefs: None,
};

let mut llm_msg = super::LlmMessage::new("node_summary", node, None).unwrap();
super::init_prompt_templates();

let mut llm_msg = super::LlmMessage::new("explain_node_summary", node, None).unwrap();
let answer = llm_msg.answer(&chatbot, None).await.unwrap();
println!("Prompt: {}", answer.prompt);
println!("Answer: {}", answer.message);
Expand Down
Loading

0 comments on commit 7d7227b

Please sign in to comment.