diff --git a/Cargo.lock b/Cargo.lock index 4142b7b..037f023 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2077,21 +2077,6 @@ dependencies = [ "adler", ] -[[package]] -name = "minreq" -version = "2.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00a000cf8bbbfb123a9bdc66b61c2885a4bb038df4f2629884caafabeb76b0f9" -dependencies = [ - "log", - "once_cell", - "rustls 0.21.10", - "rustls-webpki 0.101.7", - "serde", - "serde_json", - "webpki-roots 0.25.4", -] - [[package]] name = "mio" version = "0.8.11" @@ -2317,13 +2302,14 @@ checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" [[package]] name = "openai-api-rs" -version = "2.1.7" +version = "5.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6290b87380dad483fd1d8a001d8dd125292657864fb37d8e4a57e432b006fc49" +checksum = "33f94182337c31471e7da9dea9add6e69005aa072f5b4d667fbc39191ee96356" dependencies = [ - "minreq", + "reqwest", "serde", "serde_json", + "tokio", ] [[package]] @@ -3390,6 +3376,7 @@ dependencies = [ "js-sys", "log", "mime", + "mime_guess", "native-tls", "once_cell", "percent-encoding", @@ -4986,12 +4973,6 @@ dependencies = [ "rustls-webpki 0.100.3", ] -[[package]] -name = "webpki-roots" -version = "0.25.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" - [[package]] name = "whoami" version = "1.5.1" diff --git a/Cargo.toml b/Cargo.toml index 2f95caa..a9b7691 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,7 +79,7 @@ itertools = "0.13.0" urlencoding = "2.1.3" # Models -openai-api-rs = "2.1.4" +openai-api-rs = "5.0.4" # Algorithms kiddo = "4.2.0" # for KNN diff --git a/examples/example.sql b/examples/example.sql index 7a73e91..9216109 100644 --- a/examples/example.sql +++ b/examples/example.sql @@ -17,7 +17,8 @@ SELECT vector_to_float4(rte.embedding, 400, false), vector_to_float4(ee2.embedding, 400, false), 12.0, - true + true, + false ) AS score FROM biomedgps_entity_embedding ee1, diff --git a/migrations/20230912_enable_searching.up.sql b/migrations/20230912_enable_searching.up.sql index 9b79245..34a157a 100644 --- a/migrations/20230912_enable_searching.up.sql +++ b/migrations/20230912_enable_searching.up.sql @@ -4,7 +4,7 @@ ALTER TABLE biomedgps_knowledge_curation ADD COLUMN payload JSONB DEFAULT '{"project_id": "0", "organization_id": "0"}'; -- Enable intelligent searching for the entity table -CREATE EXTENSION pg_trgm; +CREATE EXTENSION IF NOT EXISTS pg_trgm; CREATE INDEX IF NOT EXISTS idx_trgm_id_entity_table ON biomedgps_entity USING gin(id gin_trgm_ops); diff --git a/src/api/publication.rs b/src/api/publication.rs index 3dfa972..8496134 100644 --- a/src/api/publication.rs +++ b/src/api/publication.rs @@ -1,3 +1,4 @@ +use crate::model::llm::ChatBot; use anyhow; use log::info; use poem_openapi::Object; @@ -216,6 +217,37 @@ impl Publication { }) } + pub async fn fetch_summary_by_chatgpt( + question: &str, + publications: &Vec, + ) -> Result { + let openai_api_key = std::env::var("OPENAI_API_KEY").unwrap(); + if openai_api_key.is_empty() { + return Err(anyhow::Error::msg("OPENAI_API_KEY not found")); + } + + let chatbot = ChatBot::new("GPT4", &openai_api_key); + + let publications = publications.iter().map(|p| { + format!("Title: {}\nAuthors: {}\nJournal: {}\nYear: {}\nSummary: {}\nAbstract: {}\nDOI: {}\n", p.title, p.authors.join(", "), p.journal, p.year.unwrap_or(0), p.summary, p.article_abstract.as_ref().unwrap_or(&"".to_string()), p.doi.as_ref().unwrap_or(&"".to_string())) + }).collect::>(); + + let prompt = format!( + "I have a collection of papers wrappered by the ```:\n```\n{}\n```\n\nPlease carefully analyze these papers to answer the following question: \n{}\n\nIn your response, please provide a well-integrated analysis that directly answers the question. Include citations from specific papers to support your answer, and ensure that the reasoning behind your answer is clearly explained. Reference relevant details from the papers' summaries or abstracts as needed.", + publications.join("\n"), + question, + ); + + let response = chatbot.answer(prompt).await?; + Ok(PublicationsSummary { + summary: response, + daily_limit_reached: false, + is_disputed: false, + is_incomplete: false, + results_analyzed_count: 0, + }) + } + pub async fn fetch_summary(search_id: &str) -> Result { let api_token = match std::env::var("GUIDESCOPER_API_TOKEN") { Ok(token) => token, diff --git a/src/api/route.rs b/src/api/route.rs index 4faf3b4..b09e005 100644 --- a/src/api/route.rs +++ b/src/api/route.rs @@ -80,6 +80,31 @@ impl BiomedgpsApi { } } + /// Call `/api/v1/publications-summary` with query params to fetch publication summary. + #[oai( + path = "/publications-summary", + method = "post", + tag = "ApiTags::KnowledgeGraph", + operation_id = "answerQuestionWithPublications" + )] + async fn answer_question_with_publications( + &self, + publications: Json>, + question: Query, + _token: CustomSecurityScheme, + ) -> GetPublicationsSummaryResponse { + let question = question.0; + let publications = publications.0; + match Publication::fetch_summary_by_chatgpt(&question, &publications).await { + Ok(result) => GetPublicationsSummaryResponse::ok(result), + Err(e) => { + let err = format!("Failed to fetch publications summary: {}", e); + warn!("{}", err); + return GetPublicationsSummaryResponse::bad_request(err); + } + } + } + /// Call `/api/v1/publications/:id` to fetch a publication. #[oai( path = "/publications/:id", diff --git a/src/bin/biomedgps-cli.rs b/src/bin/biomedgps-cli.rs index 0930c3d..40f3cfd 100644 --- a/src/bin/biomedgps-cli.rs +++ b/src/bin/biomedgps-cli.rs @@ -79,9 +79,9 @@ pub struct CleanDBArguments { #[structopt(name = "database_url", short = "d", long = "database-url")] database_url: Option, - /// Which table to clean. e.g. entity, relation, entity_metadata, relation_metadata, knowledge_curation, subgraph, entity2d, compound-disease-symptom, gene-disease-symptom, knowledge-score, embedding, graph etc. - #[structopt(name = "table", short = "t", long = "table", possible_values = &["entity", "entity2d", "relation", "relation_metadata", "entity_metadata", "knowledge_curation", "subgraph", "compound-disease-symptom", "gene-disease-symptom", "knowledge-score", "embedding", "graph"])] - table: String, + /// [Required] The table name to clean. e.g We will empty all entity-related tables if you use the entity table name. such as entity, entity_metadata, entity2d. + #[structopt(name = "table", short = "t", long = "table", possible_values = &["entity", "relation", "embedding", "subgraph", "curation", "score", "message", "metadata"], multiple = true)] + table: Vec, } /// Import data files into database, such as entity, relation, entity_metadata, relation_metadata, knowledge_curation, subgraph, entity2d etc. When you import the entity data, we will also sync the entity data to the graph database. But the relation data will be synced to the graph database in the cachetable command, because we need to compute the score for the relation data first. The entity_metadata and relation_metadata are generated by the importdb command automatically, actually, you don't need to prepare the entity_metadata and relation_metadata files. But you must use the importdb command manually to upgrade the entity_metadata and relation_metadata tables after the entity and relation tables are upgraded or the first time you run the application. In the most cases, you don't need to import knowledge_curation and subgraph data, we might import them at the migration stage. The entity_2d table is used to store the 2D embedding data, you need to prepare the 2D embedding data manually. If you have multiple models, you might need to choose one model to compute the 2D embedding data. The 2D embedding data is used to visualize the entity data in the 2D space. @@ -739,7 +739,61 @@ async fn main() { .await } SubCommands::CleanDB(arguments) => { - info!("To be implemented.") + let database_url = if arguments.database_url.is_none() { + match std::env::var("DATABASE_URL") { + Ok(v) => v, + Err(_) => { + error!("{}", "DATABASE_URL is not set."); + std::process::exit(1); + } + } + } else { + arguments.database_url.unwrap() + }; + + let pool = match sqlx::PgPool::connect(&database_url).await { + Ok(v) => v, + Err(e) => { + error!("Connect to database failed: {}", e); + std::process::exit(1); + } + }; + + let mut table_names_map = HashMap::<&str, Vec<&str>>::new(); + let pairs = vec![ + ("message", vec!["biomedgps_ai_message"]), + ("score", vec!["biomedgps_compound_disease_symptom_score", "biomedgps_gene_disease_symptom_score", "biomedgps_relation_with_score"]), + ("metadata", vec!["biomedgps_compound_metadata", "biomedgps_journal_metadata"]), + ("entity", vec!["biomedgps_entity", "biomedgps_entity2d", "biomedgps_entity_metadata"]), + ("relation", vec!["biomedgps_relation", "biomedgps_relation_metadata"]), + ("embedding", vec!["biomedgps_entity_embedding", "biomedgps_relation_embedding", "biomedgps_embedding_metadata"]), + ("subgraph", vec!["biomedgps_subgraph"]), + ("curation", vec!["biomedgps_knowledge_curation"]) + ]; + + for pair in pairs { + table_names_map.insert(pair.0, pair.1); + } + + + let tables = arguments.table; + for table in tables { + let table_names = table_names_map.get(table.as_str()); + if table_names.is_none() { + error!("The table name is not supported."); + std::process::exit(1); + } + + let table_names = table_names.unwrap(); + for table_name in table_names { + let sql = format!("TRUNCATE TABLE {}", table_name); + match sqlx::query(&sql).execute(&pool).await { + Ok(_) => info!("Clean the {} table successfully.", table_name), + Err(e) => error!("Clean the {} table failed: {}", table_name, e), + } + } + } + } SubCommands::StatDB(arguments) => { let database_url = if arguments.database_url.is_none() { diff --git a/src/model/graph.rs b/src/model/graph.rs index 8480aa3..83a3323 100644 --- a/src/model/graph.rs +++ b/src/model/graph.rs @@ -1247,7 +1247,8 @@ impl Graph { // vector_to_float4(rte.embedding, 400, false), // vector_to_float4(ee2.embedding, 400, false), // 12.0, - // true + // true, + // false // ) AS score // FROM // biomedgps_entity_embedding ee1, diff --git a/src/model/llm.rs b/src/model/llm.rs index 9f5e089..05e2a54 100644 --- a/src/model/llm.rs +++ b/src/model/llm.rs @@ -5,9 +5,9 @@ use chrono::serde::ts_seconds; use chrono::{DateTime, Utc}; use lazy_static::lazy_static; use log::warn; -use openai_api_rs::v1::api::Client; -use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest, FunctionCall, MessageRole}; -use openai_api_rs::v1::common::{GPT3_5_TURBO, GPT4, GPT4_1106_PREVIEW}; +use openai_api_rs::v1::api::OpenAIClient; +use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest, MessageRole, ToolCall}; +use openai_api_rs::v1::common::{GPT3_5_TURBO, GPT4_O}; use openssl::hash::{hash, MessageDigest}; use poem_openapi::{Enum, Object}; use regex::Regex; @@ -474,7 +474,7 @@ where if self.message.len() > 0 { return Ok(self); } else { - self.message = match chatbot.answer(prompt) { + self.message = match chatbot.answer(prompt).await { Ok(message) => message, Err(e) => { warn!("Failed to answer the question: {}", e.to_string()); @@ -508,9 +508,9 @@ pub struct ChatBot { role: MessageRole, name: Option, content: Option, - function_call: Option, + tool_call: Option>, model_name: String, - client: Client, + client: OpenAIClient, } impl ChatBot { @@ -519,37 +519,39 @@ impl ChatBot { // GPT4 or GPT4_1106_PREVIEW // https://platform.openai.com/account/limits // - GPT4_1106_PREVIEW.to_string() + GPT4_O.to_string() } else { GPT3_5_TURBO.to_string() }; - let client = Client::new(openai_api_key.to_string()); + let client = OpenAIClient::new(openai_api_key.to_string()); ChatBot { role: MessageRole::user, name: None, content: None, - function_call: None, + tool_call: None, model_name: model, client: client, } } - pub fn answer(&self, prompt: String) -> Result { + pub async fn answer(&self, prompt: String) -> Result { let model_name = self.model_name.clone(); let req = ChatCompletionRequest::new( model_name, vec![chat_completion::ChatCompletionMessage { role: self.role.clone(), - content: prompt, + content: chat_completion::Content::Text(prompt), name: self.name.clone(), - function_call: self.function_call.clone(), + // TODO: How to use the tool_call? + tool_calls: None, + tool_call_id: None, }], ); let req = req.temperature(0.5); - let result = self.client.chat_completion(req)?; + let result = self.client.chat_completion(req).await?; let message = result.choices[0].message.content.clone(); match message { @@ -580,7 +582,9 @@ mod tests { xrefs: None, }; - let mut llm_msg = super::LlmMessage::new("node_summary", node, None).unwrap(); + super::init_prompt_templates(); + + let mut llm_msg = super::LlmMessage::new("explain_node_summary", node, None).unwrap(); let answer = llm_msg.answer(&chatbot, None).await.unwrap(); println!("Prompt: {}", answer.prompt); println!("Answer: {}", answer.message); diff --git a/studio/src/EdgeInfoPanel/PublicationDesc.tsx b/studio/src/EdgeInfoPanel/PublicationDesc.tsx index cd39369..7042561 100644 --- a/studio/src/EdgeInfoPanel/PublicationDesc.tsx +++ b/studio/src/EdgeInfoPanel/PublicationDesc.tsx @@ -3,27 +3,22 @@ import { Button, Tag } from 'antd'; import parse from 'html-react-parser'; import type { PublicationDetail } from 'biominer-components/dist/typings'; -export const SEPARATOR = '#'; - const Desc: React.FC<{ publication: PublicationDetail, + abstract: string, showAbstract: (doc_id: string) => Promise, showPublication: (publication: PublicationDetail) => void, - queryStr: string + startNode?: string, + endNode?: string, }> = (props) => { const { publication } = props; - const [abstract, setAbstract] = useState(''); - const [abstractVisible, setAbstractVisible] = useState(false); + const words = [props.startNode || '', props.endNode || '']; const fetchAbstract = (doc_id: string) => { props.showAbstract(doc_id).then((publication) => { - console.log('fetchAbstract for a publication: ', publication); - setAbstract(publication.article_abstract || ''); - setAbstractVisible(true); + console.log('Publication: ', publication); }).catch((error) => { console.error('Error: ', error); - setAbstract(''); - setAbstractVisible(false); }); }; @@ -45,20 +40,18 @@ const Desc: React.FC<{ return (

- {parse(highlightWords(publication.summary, props.queryStr.split(SEPARATOR)))} + {parse(highlightWords(publication.summary, words))}

{ - abstractVisible ? -

{parse(highlightWords(abstract, props.queryStr.split(SEPARATOR)))}

: null + props.abstract ? +

{parse(highlightWords(props.abstract, words))}

: null }

{publication.year}Journal | {publication.journal} {publication.authors ? publication.authors.join(', ') : 'Unknown'} diff --git a/studio/src/EdgeInfoPanel/PublicationPanel.tsx b/studio/src/EdgeInfoPanel/PublicationPanel.tsx index d5b552f..cd288fd 100644 --- a/studio/src/EdgeInfoPanel/PublicationPanel.tsx +++ b/studio/src/EdgeInfoPanel/PublicationPanel.tsx @@ -1,27 +1,32 @@ import React, { useEffect, useState } from 'react'; +import { MarkdownViewer } from 'biominer-components'; +import RehypeRaw from 'rehype-raw'; import { Button, List, message, Row, Col, Tag } from 'antd'; import { FileProtectOutlined } from '@ant-design/icons'; import type { Publication, PublicationDetail } from 'biominer-components/dist/typings'; import PublicationDesc from './PublicationDesc'; -import { fetchPublication, fetchPublications, fetchPublicationsSummary } from '@/services/swagger/KnowledgeGraph'; +import { fetchPublication, fetchPublications, fetchPublicationsSummary, answerQuestionWithPublications } from '@/services/swagger/KnowledgeGraph'; import './index.less'; export type PublicationPanelProps = { queryStr: string; + startNode?: string; + endNode?: string; }; const PublicationPanel: React.FC = (props) => { - const [publications, setPublications] = useState([]); const [page, setPage] = useState(0); const [total, setTotal] = useState(0); const [pageSize, setPageSize] = useState(10); const [loading, setLoading] = useState(false); const [publicationMap, setPublicationMap] = useState>({}); + const [abstractMap, setAbstractMap] = useState>({}); const [searchId, setSearchId] = useState(''); - const [publicationSummary, setPublicationSummary] = useState(''); + const [publicationSummary, setPublicationSummary] = useState('Loading...'); + const [generating, setGenerating] = useState(false); - const showAbstract = (doc_id: string): Promise => { + const showAbstract = async (doc_id: string): Promise => { console.log('Show Abstract: ', doc_id); return new Promise((resolve, reject) => { fetchPublication({ id: doc_id }).then((publication) => { @@ -55,7 +60,11 @@ const PublicationPanel: React.FC = (props) => { fetchPublicationSummary(data.search_id); } - setPublications(data.records); + let publicationMap: Record = {}; + data.records.forEach((publication) => { + publicationMap[publication.doc_id] = publication; + }); + setPublicationMap(publicationMap); setPage(data.page); setTotal(data.total); setPageSize(data.page_size); @@ -67,6 +76,42 @@ const PublicationPanel: React.FC = (props) => { }); }, [props.queryStr, page, pageSize]); + const loadAbstractsAndAnswer = async (docIds: string[]) => { + const tempAbstractMap: Record = {}; + for (let i = 0; i < docIds.length; i++) { + const docId = docIds[i]; + if (!publicationMap[docId].article_abstract) { + const msg = `Load ${i} publication...`; + setPublicationSummary(msg); + await showAbstract(docId).then((publication) => { + tempAbstractMap[docId] = publication.article_abstract || ''; + }).catch((error) => { + setGenerating(false); + console.error('Error: ', error); + }); + setTimeout(() => console.log(msg), 200 * i) + } + } + + setAbstractMap(tempAbstractMap); + setPublicationSummary('Publications loaded, answering question...'); + + answerQuestionWithPublications( + { + question: props.queryStr, + }, + // @ts-ignore Don't need to care about this warning, just because the authors field is not defined as a string[]. + Object.values(publicationMap) + ).then((response) => { + console.log('Answer: ', response); + setPublicationSummary(response.summary); + }).catch((error) => { + setGenerating(false); + console.error('Error: ', error); + setPublicationSummary('Failed to answer question, because of the following error: ' + error); + }); + } + const showPublication = async (publication: PublicationDetail) => { console.log('Show Publication: ', publication); if (publication) { @@ -109,15 +154,23 @@ const PublicationPanel: React.FC = (props) => { return ( - Question + Summary Question - {props.queryStr} + {props.queryStr}

- Answer by AI - {publicationSummary.length > 0 ? publicationSummary : `Generating answers for the question above...`} + {/* Answer by AI */} +

@@ -127,7 +180,7 @@ const PublicationPanel: React.FC = (props) => { loading={loading} itemLayout="horizontal" rowKey={'doc_id'} - dataSource={publications} + dataSource={Object.values(publicationMap)} size="large" pagination={{ disabled: false, @@ -146,8 +199,8 @@ const PublicationPanel: React.FC = (props) => { avatar={} title={ { onClickPublication(item); }}>{item.title}} description={ - onClickPublication(publication)} /> } diff --git a/studio/src/services/swagger/KnowledgeGraph.ts b/studio/src/services/swagger/KnowledgeGraph.ts index 77aeca8..0817623 100644 --- a/studio/src/services/swagger/KnowledgeGraph.ts +++ b/studio/src/services/swagger/KnowledgeGraph.ts @@ -288,6 +288,26 @@ export async function fetchPublicationsConsensus( }); } +/** Call `/api/v1/publications-summary` with query params to fetch publication summary. POST /api/v1/publications-summary */ +export async function answerQuestionWithPublications( + // 叠加生成的Param类型 (非body参数swagger默认没有生成对象) + params: swagger.answerQuestionWithPublicationsParams, + body: swagger.Publication[], + options?: { [key: string]: any }, +) { + return request('/api/v1/publications-summary', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + params: { + ...params, + }, + data: body, + ...(options || {}), + }); +} + /** Call `/api/v1/publications-summary` with query params to fetch publication summary. GET /api/v1/publications-summary/${param0} */ export async function fetchPublicationsSummary( // 叠加生成的Param类型 (非body参数swagger默认没有生成对象) diff --git a/studio/src/services/swagger/typings.d.ts b/studio/src/services/swagger/typings.d.ts index c2398f3..aa0a1e6 100644 --- a/studio/src/services/swagger/typings.d.ts +++ b/studio/src/services/swagger/typings.d.ts @@ -1,4 +1,8 @@ declare namespace swagger { + type answerQuestionWithPublicationsParams = { + question: string; + }; + type Article = { ref_id: string; pubmed_id: string;