Skip to content

Commit

Permalink
add: llm stream response
Browse files Browse the repository at this point in the history
  • Loading branch information
mertakman committed Nov 7, 2024
1 parent c4d8f32 commit 0641046
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 50 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ edition = "2021"
base64 = "0.22.1"
serde = { version = "1.0", features = ["derive"] }
reqwest = { version = "0.11", features = ["json"] }
bytes = "1.7.2"
serde_json = "=1.0.1"
urlencoding = "2.1.3"
http = "1.1.0"
Expand Down
5 changes: 5 additions & 0 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub enum QstashError {
InvalidRequestUrl(String),
RequestFailed(reqwest::Error),
ResponseBodyParseError(reqwest::Error),
ResponseStreamParseError(serde_json::Error),
DailyRateLimitExceeded {
reset: u64,
},
Expand All @@ -31,6 +32,9 @@ impl fmt::Display for QstashError {
QstashError::ResponseBodyParseError(err) => {
write!(f, "Failed to parse response body: {}", err)
}
QstashError::ResponseStreamParseError(err) => {
write!(f, "Failed to parse response stream: {}", err)
}
QstashError::DailyRateLimitExceeded { reset } => {
write!(f, "Daily rate limit exceeded. Retry after: {}", reset)
}
Expand Down Expand Up @@ -60,6 +64,7 @@ impl error::Error for QstashError {
QstashError::InvalidRequestUrl(_) => None,
QstashError::RequestFailed(err) => Some(err),
QstashError::ResponseBodyParseError(err) => Some(err),
QstashError::ResponseStreamParseError(err) => Some(err),
QstashError::DailyRateLimitExceeded { .. } => None,
QstashError::BurstRateLimitExceeded { .. } => None,
QstashError::ChatRateLimitExceeded { .. } => None,
Expand Down
1 change: 0 additions & 1 deletion src/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,3 @@ impl QstashClient {
}
}
}

196 changes: 147 additions & 49 deletions src/llm_types.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::cell::RefCell;
use std::pin::Pin;
use std::task::{Context, Poll};

use crate::errors::QstashError;
use bytes::Bytes;
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug)]
pub struct ChatCompletionRequest {
Expand Down Expand Up @@ -158,49 +154,6 @@ pub struct Usage {
pub total_tokens: i32,
}

/// Holds an optional `reqwest::Response` together with a byte buffer.
pub struct StreamResponse {
// The wrapped HTTP response. Stored as a plain `Option`; no `RefCell` is used here.
response: Option<reqwest::Response>,
// Byte buffer; `new` initialises it empty.
pending: Vec<u8>,
}

impl StreamResponse {
/// Wraps `response` and starts with an empty `pending` buffer.
pub fn new(response: reqwest::Response) -> Self {
Self {
response: Some(response),
pending: Vec::new(),
}
}

/// Awaits `poll_chunk` and propagates its error.
///
/// NOTE: the conversion of the chunk into a `StreamMessage` is still a stub;
/// on the success path this currently hits `todo!()` and panics.
pub async fn get_next_stream_message(&mut self) -> Result<Option<StreamMessage>, QstashError> {
let chunk = self.poll_chunk().await?;

todo!()
}

/// Reads the next body chunk from the wrapped response.
///
/// Returns `Ok(None)` when there is no response or the body is exhausted, and
/// maps transport failures to `QstashError::RequestFailed`.
///
/// NOTE: handling of a received chunk is still a stub (`todo!()`), so receiving
/// any data currently panics.
async fn poll_chunk(&mut self) -> Result<Option<Vec<u8>>, QstashError> {
loop {
match self.response {
Some(ref mut response) => {
let chunk = response
.chunk()
.await
.map_err(QstashError::RequestFailed)?;

match chunk {
Some(chunk) => {
todo!()
}
// `chunk()` yielded nothing more: the body is finished.
None => return Ok(None),
}
}
// No response to read from.
None => {
return Ok(None);
}
}
};
}
}

#[derive(Debug, Serialize, Deserialize)]
pub struct StreamMessage {
// A unique identifier for the chat completion. Each chunk has the same ID
Expand Down Expand Up @@ -238,3 +191,148 @@ pub struct Delta {
// The contents of the chunk message
pub content: Option<String>,
}

/// Result of one `StreamResponse::poll_chunk` call.
enum ChunkType {
// Raw bytes of one delimited message taken from the buffer.
Message(Vec<u8>),
// Nothing more to read: no response, body exhausted, or a `[DONE]` message was seen.
Done(),
}

/// Reads a `reqwest::Response` body chunk by chunk and splits it into
/// `\n\n`-delimited messages.
pub struct StreamResponse {
// The wrapped HTTP response; `None` when built via `default()`.
// Stored as a plain `Option`; no `RefCell` is used here.
response: Option<reqwest::Response>,
// Bytes received so far that have not yet been returned as a message.
buffer: Vec<u8>,
}

impl StreamResponse {
    /// Wraps `response` so its body can be consumed one stream message at a time.
    pub fn new(response: reqwest::Response) -> Self {
        Self {
            response: Some(response),
            buffer: Vec::new(),
        }
    }

    /// Returns the next message of the stream.
    ///
    /// Yields `Ok(None)` once the stream is finished: there is no response, the
    /// body is exhausted, or a `[DONE]` message was received.
    ///
    /// # Errors
    ///
    /// * `QstashError::RequestFailed` if reading the next body chunk fails.
    /// * `QstashError::ResponseStreamParseError` if a message is not valid JSON
    ///   for a `StreamMessage`.
    pub async fn get_next_stream_message(&mut self) -> Result<Option<StreamMessage>, QstashError> {
        match self.poll_chunk().await? {
            ChunkType::Message(data) => serde_json::from_slice(&data)
                .map(Some)
                .map_err(QstashError::ResponseStreamParseError),
            ChunkType::Done() => Ok(None),
        }
    }

    /// Produces the next complete message, reading more body chunks only when
    /// the buffer holds no complete message.
    async fn poll_chunk(&mut self) -> Result<ChunkType, QstashError> {
        loop {
            // Serve what is already buffered before touching the network. One
            // chunk can carry several messages, and reading first would strand
            // the extra ones (and lose them when the body ends).
            while let Some(message) = self.take_buffered_message() {
                let payload = Self::event_payload(&message);
                if payload.is_empty() {
                    // Consecutive delimiters yield an empty event with nothing to parse.
                    continue;
                }
                if payload == b"[DONE]" {
                    // Terminal marker: stop reading so later calls keep returning `Done`.
                    self.response = None;
                    self.buffer.clear();
                    return Ok(ChunkType::Done());
                }
                return Ok(ChunkType::Message(payload.to_vec()));
            }

            let next_chunk = match self.response.as_mut() {
                Some(response) => response
                    .chunk()
                    .await
                    .map_err(QstashError::RequestFailed)?,
                None => None,
            };

            match next_chunk {
                Some(chunk) => self.buffer.extend_from_slice(&chunk),
                None => {
                    // Body exhausted (or no response). An unterminated trailing
                    // fragment is not a complete message, so it is discarded.
                    self.response = None;
                    self.buffer.clear();
                    return Ok(ChunkType::Done());
                }
            }
        }
    }

    /// Appends `chunk` to the buffer and returns the first complete message, if any.
    ///
    /// A message is everything before the first `\n\n`. The message and its
    /// delimiter are removed from the buffer; any remainder stays buffered.
    fn extract_next_message(&mut self, chunk: &Bytes) -> Option<Vec<u8>> {
        self.buffer.extend_from_slice(chunk);
        self.take_buffered_message()
    }

    /// Removes and returns the first `\n\n`-terminated message in the buffer,
    /// without the delimiter. Returns `None` when no delimiter is present.
    fn take_buffered_message(&mut self) -> Option<Vec<u8>> {
        let msg_end = self.buffer.windows(2).position(|w| w == b"\n\n")?;

        // `drain` shifts the tail in place instead of re-allocating it.
        let mut message: Vec<u8> = self.buffer.drain(..msg_end + 2).collect();
        message.truncate(msg_end);
        Some(message)
    }

    /// Strips an optional `data:` field prefix (and one following space) from a message.
    ///
    /// NOTE(review): assumes the server may frame messages as server-sent events
    /// (`data: <payload>`); confirm against the API. Messages without the prefix
    /// are returned unchanged.
    fn event_payload(message: &[u8]) -> &[u8] {
        match message.strip_prefix(b"data:") {
            Some(rest) => rest.strip_prefix(b" ").unwrap_or(rest),
            None => message,
        }
    }
}

impl Default for StreamResponse {
    /// Creates a `StreamResponse` with no underlying response and an empty buffer.
    fn default() -> Self {
        Self {
            response: None,
            buffer: Vec::new(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Hands `input` to the stream as a single chunk and returns the message it yields, if any.
    fn feed(stream: &mut StreamResponse, input: &'static str) -> Option<Vec<u8>> {
        stream.extract_next_message(&Bytes::from(input))
    }

    // A delimiter inside one chunk splits it into a message and a buffered remainder.
    #[test]
    fn test_single_complete_message() {
        let mut stream = StreamResponse::default();

        let yielded = feed(&mut stream, "Hello\n\nWorld");

        assert_eq!(yielded, Some(b"Hello".to_vec()));
        assert_eq!(stream.buffer, b"World".to_vec());
    }

    // Fragments accumulate until the delimiter finally arrives.
    #[test]
    fn test_message_in_multiple_chunks() {
        let mut stream = StreamResponse::default();

        for fragment in ["Hel", "lo Wo"] {
            assert_eq!(feed(&mut stream, fragment), None);
        }

        let yielded = feed(&mut stream, "rld\n\nNext");
        assert_eq!(yielded, Some(b"Hello World".to_vec()));
        assert_eq!(stream.buffer, b"Next".to_vec());
    }

    // Only the first message is returned per call; later ones stay buffered.
    #[test]
    fn test_multiple_messages_in_single_chunk() {
        let mut stream = StreamResponse::default();

        let first = feed(&mut stream, "First\n\nSecond\n\nThird");
        assert_eq!(first, Some(b"First".to_vec()));
        assert_eq!(stream.buffer, b"Second\n\nThird".to_vec());

        // An empty chunk adds nothing but still releases the next buffered message.
        let second = feed(&mut stream, "");
        assert_eq!(second, Some(b"Second".to_vec()));
        assert_eq!(stream.buffer, b"Third".to_vec());
    }

    // A leading delimiter produces an empty message.
    #[test]
    fn test_empty_message() {
        let mut stream = StreamResponse::default();

        let yielded = feed(&mut stream, "\n\nAfter");

        assert_eq!(yielded, Some(Vec::new()));
        assert_eq!(stream.buffer, b"After".to_vec());
    }

    // Without a delimiter nothing is yielded and all bytes remain buffered.
    #[test]
    fn test_no_complete_message() {
        let mut stream = StreamResponse::default();

        for fragment in ["Hello", " World"] {
            assert_eq!(feed(&mut stream, fragment), None);
        }

        assert_eq!(stream.buffer, b"Hello World".to_vec());
    }
}

0 comments on commit 0641046

Please sign in to comment.