diff --git a/Cargo.toml b/Cargo.toml index 5fb1e98..396f5cf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ dotenv = "0.15" sqlx = { version = "*", features = ["runtime-tokio-rustls", "sqlite"] } tokio = { version = "*", features = ["full"] } axum = "*" -reqwest = { version = "*", features = ["json"]} +reqwest = { version = "*", features = ["json", "multipart"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" diff --git a/src/assistant.rs b/src/assistant.rs index 52c459c..874c560 100644 --- a/src/assistant.rs +++ b/src/assistant.rs @@ -7,7 +7,7 @@ use axum::{ Extension, Json, }; -use reqwest::Client; +use reqwest::{Client, multipart::Form}; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -116,6 +116,7 @@ pub struct Assistant { model: String, instructions: String, folder_path: String, + file_ids: Vec, scrape_urls: Vec, } impl Assistant { @@ -132,6 +133,7 @@ impl Assistant { }); let response = client .post("https://api.openai.com/v1/assistants") + .header("OpenAI-Beta", "assistants=v1") .bearer_auth(&api_key) .json(&payload) .send() @@ -197,46 +199,61 @@ impl Assistant { } Ok(()) } - /// upload a file to OpenAI and return the file ID - pub async fn upload_file(&self, file_path: &str) -> Result { + /// upload a all files in a folder and return a vector of file IDs + pub async fn upload_files(&mut self) -> Result<(), AssistantError> { let api_key = env::var("OPENAI_API_KEY") .map_err(|_| AssistantError::OpenAIError("OPENAI_API_KEY not set".to_string()))?; let client = Client::new(); - let payload = json!({ - "purpose": "assistants", - "file": file_path, - }); - let response = client - .post("https://api.openai.com/v1/files") - .bearer_auth(&api_key) - .json(&payload) - .send() - .await; - match response { - Ok(res) if res.status().is_success() => match res.json::().await { - Ok(file_response) => Ok(file_response.id), - Err(_) => Err(AssistantError::OpenAIError( - "Failed to parse response from OpenAI".to_string(), - )), - }, - Ok(res) => { - let error_message = res.text().await.unwrap_or_default(); - Err(AssistantError::OpenAIError(error_message)) + let paths = fs::read_dir(Path::new(&self.folder_path)) + .map_err(|e| AssistantError::DatabaseError(e.to_string()))?; + for path in paths { + let path = path + .map_err(|e| AssistantError::DatabaseError(e.to_string()))? + .path(); + if path.is_file() { + let form = reqwest::multipart::Form::new() + .file("file", path.to_str().unwrap()) + .map_err(|e| AssistantError::OpenAIError(e.to_string()))?; + let response = client + .post("https://api.openai.com/v1/files") + .header("OpenAI-Beta", "assistants=v1") + .bearer_auth(&api_key) + .multipart(form) + .send() + .await; + match response { + Ok(res) if res.status().is_success() => { + if let Ok(file_response) = res.json::().await { + self.file_ids.push(file_response.id); + } else { + return Err(AssistantError::OpenAIError( + "Failed to parse response from OpenAI".to_string(), + )); + } + } + Ok(res) => { + let error_message = res.text().await.unwrap_or_default(); + return Err(AssistantError::OpenAIError(error_message)); + } + Err(e) => return Err(AssistantError::OpenAIError(e.to_string())), + } } - Err(e) => Err(AssistantError::OpenAIError(e.to_string())), } + Ok(()) } - pub async fn attach_files(&self, file_ids: Vec) -> Result<(), AssistantError> { + /// Attach the files with IDs stored in the file_ids field to the assistant. + pub async fn attach_files(&self) -> Result<(), AssistantError> { let api_key = env::var("OPENAI_API_KEY") .map_err(|_| AssistantError::OpenAIError("OPENAI_API_KEY not set".to_string()))?; let client = Client::new(); - for file_id in file_ids { - let payload = AttachFilesRequest { file_id }; + for file_id in &self.file_ids { + let payload = json!({ "file_id": file_id }); let response = client .post(format!( "https://api.openai.com/v1/assistants/{}/files", self.id )) + .header("OpenAI-Beta", "assistants=v1") .bearer_auth(&api_key) .json(&payload) .send() @@ -266,23 +283,21 @@ pub async fn create_assistant( instructions: instructions.to_string(), folder_path: folder_path.to_string(), scrape_urls: Vec::new(), + file_ids: Vec::new(), // Make sure this field exists in the Assistant struct }; + // Initialize the assistant by creating it on the OpenAI platform assistant.initialize().await?; - let paths = fs::read_dir(Path::new(folder_path)) - .map_err(|e| AssistantError::DatabaseError(e.to_string()))?; - let mut file_ids = Vec::new(); - for path in paths { - let path = path - .map_err(|e| AssistantError::DatabaseError(e.to_string()))? - .path(); - if path.is_file() { - let file_id = assistant.upload_file(path.to_str().unwrap()).await?; - file_ids.push(file_id); - } + // Scrape context if needed (optional, based on whether scrape_urls are provided) + if !assistant.scrape_urls.is_empty() { + assistant.scrape_context().await?; } - assistant.attach_files(file_ids).await?; + // Upload files from the folder_path to OpenAI and store the file IDs in the assistant + assistant.upload_files().await?; + // Attach the uploaded files to the assistant using the stored file IDs + assistant.attach_files().await?; Ok(assistant) } + struct Chat { id: String, user_id: String, diff --git a/src/main.rs b/src/main.rs index 526c13e..b1760ec 100644 --- a/src/main.rs +++ b/src/main.rs @@ -21,7 +21,7 @@ async fn main() { "My Assistant", "gpt-4", "Your instructions here", - "path/to/folder", + "context", ) .await {