diff --git a/Cargo.toml b/Cargo.toml index 2415241..76a3a02 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ edition = "2021" base64 = "0.22.1" serde = { version = "1.0", features = ["derive"] } reqwest = { version = "0.11", features = ["json"] } +bytes = "1.7.2" serde_json = "=1.0.1" urlencoding = "2.1.3" http = "1.1.0" diff --git a/src/errors.rs b/src/errors.rs index 4a4dfa3..d851cb1 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -8,6 +8,7 @@ pub enum QstashError { InvalidRequestUrl(String), RequestFailed(reqwest::Error), ResponseBodyParseError(reqwest::Error), + ResponseStreamParseError(serde_json::Error), DailyRateLimitExceeded { reset: u64, }, @@ -31,6 +32,9 @@ impl fmt::Display for QstashError { QstashError::ResponseBodyParseError(err) => { write!(f, "Failed to parse response body: {}", err) } + QstashError::ResponseStreamParseError(err) => { + write!(f, "Failed to parse response stream: {}", err) + } QstashError::DailyRateLimitExceeded { reset } => { write!(f, "Daily rate limit exceeded. Retry after: {}", reset) } @@ -60,6 +64,7 @@ impl error::Error for QstashError { QstashError::InvalidRequestUrl(_) => None, QstashError::RequestFailed(err) => Some(err), QstashError::ResponseBodyParseError(err) => Some(err), + QstashError::ResponseStreamParseError(err) => Some(err), QstashError::DailyRateLimitExceeded { .. } => None, QstashError::BurstRateLimitExceeded { .. } => None, QstashError::ChatRateLimitExceeded { .. 
} => None, diff --git a/src/llm.rs b/src/llm.rs index cb694b5..c4a39ba 100644 --- a/src/llm.rs +++ b/src/llm.rs @@ -55,4 +55,3 @@ impl QstashClient { } } } - diff --git a/src/llm_types.rs b/src/llm_types.rs index 345464a..e73fab6 100644 --- a/src/llm_types.rs +++ b/src/llm_types.rs @@ -1,10 +1,6 @@ -use futures::stream::Stream; -use serde::{Deserialize, Serialize}; -use std::cell::RefCell; -use std::pin::Pin; -use std::task::{Context, Poll}; - use crate::errors::QstashError; +use bytes::Bytes; +use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, Debug)] pub struct ChatCompletionRequest { @@ -158,49 +154,6 @@ pub struct Usage { pub total_tokens: i32, } -pub struct StreamResponse { - response: Option, // Use RefCell for interior mutability - pending: Vec, -} - -impl StreamResponse { - pub fn new(response: reqwest::Response) -> Self { - Self { - response: Some(response), - pending: Vec::new(), - } - } - - pub async fn get_next_stream_message(&mut self) -> Result, QstashError> { - let chunk = self.poll_chunk().await?; - - todo!() - } - - async fn poll_chunk(&mut self) -> Result>, QstashError> { - loop { - match self.response { - Some(ref mut response) => { - let chunk = response - .chunk() - .await - .map_err(QstashError::RequestFailed)?; - - match chunk { - Some(chunk) => { - todo!() - } - None => return Ok(None), - } - } - None => { - return Ok(None); - } - } - }; - } -} - #[derive(Debug, Serialize, Deserialize)] pub struct StreamMessage { // A unique identifier for the chat completion. 
Each chunk has the same ID
@@ -238,3 +191,220 @@ pub struct Delta {
     // The contents of the chunk message
     pub content: Option<String>,
 }
+
+/// Byte sequence that terminates one message in the response stream.
+const MESSAGE_DELIMITER: &[u8] = b"\n\n";
+
+/// Payload that marks the end of the stream.
+const DONE_MARKER: &[u8] = b"[DONE]";
+
+/// Outcome of asking the stream for its next message.
+enum ChunkType {
+    /// Raw bytes of one complete, non-empty message.
+    Message(Vec<u8>),
+    /// The stream has ended.
+    Done,
+}
+
+/// Splits a streamed response body into `\n\n`-delimited messages.
+///
+/// Body chunks are appended to an internal buffer and complete messages are
+/// cut from its front, so a message may span several chunks and one chunk
+/// may hold several messages.
+#[derive(Default)]
+pub struct StreamResponse {
+    /// Response being read; `None` when none is attached or the stream ended.
+    response: Option<reqwest::Response>,
+    /// Bytes received so far that have not been returned as a message.
+    buffer: Vec<u8>,
+}
+
+impl StreamResponse {
+    /// Creates a reader over the body of `response`.
+    pub fn new(response: reqwest::Response) -> Self {
+        Self {
+            response: Some(response),
+            buffer: Vec::new(),
+        }
+    }
+
+    /// Returns the next message, or `Ok(None)` once the stream has ended.
+    ///
+    /// # Errors
+    ///
+    /// [`QstashError::RequestFailed`] if reading the response body fails, and
+    /// [`QstashError::ResponseStreamParseError`] if a message cannot be
+    /// deserialized from JSON.
+    pub async fn get_next_stream_message(&mut self) -> Result<Option<StreamMessage>, QstashError> {
+        match self.poll_chunk().await? {
+            ChunkType::Message(data) => serde_json::from_slice(&data)
+                .map(Some)
+                .map_err(QstashError::ResponseStreamParseError),
+            ChunkType::Done => Ok(None),
+        }
+    }
+
+    /// Classifies the next non-empty message as a payload or the end marker.
+    async fn poll_chunk(&mut self) -> Result<ChunkType, QstashError> {
+        while let Some(message) = self.next_raw_message().await? {
+            let payload = strip_data_prefix(&message);
+            if payload == DONE_MARKER {
+                // Make the end sticky: later calls must not read any further.
+                self.response = None;
+                self.buffer.clear();
+                return Ok(ChunkType::Done);
+            }
+            // An empty frame has nothing to deserialize, so it is skipped.
+            if !payload.is_empty() {
+                return Ok(ChunkType::Message(payload.to_vec()));
+            }
+        }
+        Ok(ChunkType::Done)
+    }
+
+    /// Returns the next delimited message, reading more of the body only when
+    /// the buffer does not already contain a complete one.
+    async fn next_raw_message(&mut self) -> Result<Option<Vec<u8>>, QstashError> {
+        // A single chunk may carry several messages; the ones left behind by
+        // an earlier call are served first so they are neither delayed until
+        // more data arrives nor lost when the body ends.
+        let mut message = self.take_buffered_message();
+        while message.is_none() {
+            let response = match self.response.as_mut() {
+                Some(response) => response,
+                None => return Ok(None),
+            };
+            let chunk = response.chunk().await.map_err(QstashError::RequestFailed)?;
+            match chunk {
+                Some(chunk) => message = self.extract_next_message(&chunk),
+                None => {
+                    self.response = None;
+                    return Ok(None);
+                }
+            }
+        }
+        Ok(message)
+    }
+
+    /// Appends `chunk` to the buffer and returns the first complete message.
+    fn extract_next_message(&mut self, chunk: &Bytes) -> Option<Vec<u8>> {
+        self.buffer.extend_from_slice(chunk);
+        self.take_buffered_message()
+    }
+
+    /// Removes the first complete message from the buffer and returns it
+    /// without its delimiter, or `None` if no delimiter is buffered yet.
+    fn take_buffered_message(&mut self) -> Option<Vec<u8>> {
+        let msg_end = self
+            .buffer
+            .windows(MESSAGE_DELIMITER.len())
+            .position(|window| window == MESSAGE_DELIMITER)?;
+        // Drain message and delimiter in place instead of copying the rest of
+        // the buffer into a new allocation.
+        let mut message: Vec<u8> = self
+            .buffer
+            .drain(..msg_end + MESSAGE_DELIMITER.len())
+            .collect();
+        message.truncate(msg_end);
+        Some(message)
+    }
+}
+
+/// Strips a leading `data:` field name and one optional space, if present.
+///
+/// NOTE(review): assumes the service may frame messages as Server-Sent Events
+/// (`data: <payload>`); confirm against the API. Unprefixed input is returned
+/// unchanged.
+fn strip_data_prefix(message: &[u8]) -> &[u8] {
+    match message.strip_prefix(b"data:") {
+        Some(rest) => rest.strip_prefix(b" ").unwrap_or(rest),
+        None => message,
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_single_complete_message() {
+        let mut processor = StreamResponse::default();
+        let chunk = Bytes::from("Hello\n\nWorld");
+
+        let message = processor.extract_next_message(&chunk);
+        assert_eq!(message, Some(b"Hello".to_vec()));
+        assert_eq!(processor.buffer, b"World");
+    }
+
+    #[test]
+    fn test_message_in_multiple_chunks() {
+        let mut processor = StreamResponse::default();
+
+        assert_eq!(processor.extract_next_message(&Bytes::from("Hel")), None);
+        assert_eq!(processor.extract_next_message(&Bytes::from("lo Wo")), None);
+        assert_eq!(
+            processor.extract_next_message(&Bytes::from("rld\n\nNext")),
+            Some(b"Hello World".to_vec())
+        );
+        assert_eq!(processor.buffer, b"Next");
+    }
+
+    #[test]
+    fn test_multiple_messages_in_single_chunk() {
+        let mut processor = StreamResponse::default();
+        let chunk = Bytes::from("First\n\nSecond\n\nThird");
+
+        assert_eq!(
+            processor.extract_next_message(&chunk),
+            Some(b"First".to_vec())
+        );
+        assert_eq!(processor.buffer, b"Second\n\nThird");
+
+        assert_eq!(
+            processor.extract_next_message(&Bytes::from("")),
+            Some(b"Second".to_vec())
+        );
+        assert_eq!(processor.buffer, b"Third");
+    }
+
+    #[test]
+    fn test_buffered_messages_drain_without_new_data() {
+        let mut processor = StreamResponse::default();
+        let chunk = Bytes::from("First\n\nSecond\n\nThird");
+
+        assert_eq!(
+            processor.extract_next_message(&chunk),
+            Some(b"First".to_vec())
+        );
+        assert_eq!(processor.take_buffered_message(), Some(b"Second".to_vec()));
+        assert_eq!(processor.take_buffered_message(), None);
+        assert_eq!(processor.buffer, b"Third");
+    }
+
+    #[test]
+    fn test_empty_message() {
+        let mut processor = StreamResponse::default();
+
+        assert_eq!(
+            processor.extract_next_message(&Bytes::from("\n\nAfter")),
+            Some(b"".to_vec())
+        );
+        assert_eq!(processor.buffer, b"After");
+    }
+
+    #[test]
+    fn test_no_complete_message() {
+        let mut processor = StreamResponse::default();
+
+        assert_eq!(processor.extract_next_message(&Bytes::from("Hello")), None);
+        assert_eq!(processor.extract_next_message(&Bytes::from(" World")), None);
+        assert_eq!(processor.buffer, b"Hello World");
+    }
+
+    #[test]
+    fn test_strip_data_prefix() {
+        assert_eq!(strip_data_prefix(b"data: [DONE]"), b"[DONE]");
+        assert_eq!(strip_data_prefix(b"data:{}"), b"{}");
+        assert_eq!(strip_data_prefix(b"{}"), b"{}");
+    }
+}