Skip to content

Commit

Permalink
fix:add tests for rate_limited_client
Browse files Browse the repository at this point in the history
  • Loading branch information
mertakman committed Nov 7, 2024
1 parent 0641046 commit 1959c96
Show file tree
Hide file tree
Showing 5 changed files with 279 additions and 68 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ urlencoding = "2.1.3"
http = "1.1.0"
tokio = { version="1.41.0", features = ["full"] }
futures = "0.3"
httpmock = "0.7.0"
5 changes: 5 additions & 0 deletions src/events.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,8 @@ impl QstashClient {
Ok(response)
}
}

#[cfg(test)]
mod tests {
use super::*;
}
13 changes: 0 additions & 13 deletions src/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,4 @@ impl QstashClient {
}
}
}

pub async fn a(&self) {
let a = self.create_chat_completion(todo!()).await.unwrap();
match a {
ChatCompletionResponse::Direct(d) => {
println!("Direct response: {:?}", d);
}
ChatCompletionResponse::Stream(s) => {
let s = s;
s.get_next_stream_message().await.unwrap();
}
}
}
}
189 changes: 134 additions & 55 deletions src/llm_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,8 @@ impl StreamResponse {
let chunk = self.poll_chunk().await?;
match chunk {
ChunkType::Message(data) => {
let message = serde_json::from_slice(&data).map_err(QstashError::ResponseStreamParseError)?;
let message =
serde_json::from_slice(&data).map_err(QstashError::ResponseStreamParseError)?;
Ok(Some(message))
}
ChunkType::Done() => Ok(None),
Expand All @@ -244,7 +245,10 @@ impl StreamResponse {
// Now we can mutably borrow self for extract_next_message
if let Some(message) = self.extract_next_message(&chunk) {
match message.as_slice() {
b"[DONE]" => return Ok(ChunkType::Done()),
b"[DONE]" => {
self.response = None;
return Ok(ChunkType::Done());
}
_ => return Ok(ChunkType::Message(message)),
}
}
Expand Down Expand Up @@ -273,66 +277,141 @@ impl StreamResponse {

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_single_complete_message() {
let mut processor = StreamResponse::default();
let chunk = Bytes::from("Hello\n\nWorld");
use crate::rate_limited_client::RateLimitedClient;

let message = processor.extract_next_message(&chunk);
assert_eq!(message, Some(b"Hello".to_vec()));
assert_eq!(processor.buffer, b"World");
use super::*;
use reqwest::{Method, Url};
use httpmock::prelude::*;

#[tokio::test]
async fn test_send_request_success() {
// Arrange
let server = MockServer::start_async().await;
let mock = server.mock(|when, then| {
when.method(GET)
.path("/test");
then.status(200);
});

let client = RateLimitedClient::new("test_api_key".to_string());
let url = Url::parse(&format!("{}/test", &server.base_url())).unwrap();
let request_builder = client.get_request_builder(Method::GET, url);

// Act
let result = client.send_request(request_builder).await;

// Assert
assert!(result.is_ok());
mock.assert();
}

#[test]
fn test_message_in_multiple_chunks() {
let mut processor = StreamResponse::default();

assert_eq!(processor.extract_next_message(&Bytes::from("Hel")), None);
assert_eq!(processor.extract_next_message(&Bytes::from("lo Wo")), None);
assert_eq!(
processor.extract_next_message(&Bytes::from("rld\n\nNext")),
Some(b"Hello World".to_vec())
);
assert_eq!(processor.buffer, b"Next");
#[tokio::test]
async fn test_send_request_daily_rate_limit_exceeded() {
// Arrange
let server = MockServer::start_async().await;
let mock = server.mock(|when, then| {
when.method(GET)
.path("/test");
then.status(429)
.header("RateLimit-Limit", "1000")
.header("RateLimit-Reset", "3600");
});

let client = RateLimitedClient::new("test_api_key".to_string());
let url = Url::parse(&format!("{}/test", &server.base_url())).unwrap();
let request_builder = client.get_request_builder(Method::GET, url);

// Act
let result = client.send_request(request_builder).await;

// Assert
match result {
Err(QstashError::DailyRateLimitExceeded { reset }) => assert_eq!(reset, 3600),
_ => panic!("Expected DailyRateLimitExceeded error"),
}
mock.assert();
}

#[test]
fn test_multiple_messages_in_single_chunk() {
let mut processor = StreamResponse::default();
let chunk = Bytes::from("First\n\nSecond\n\nThird");

assert_eq!(
processor.extract_next_message(&chunk),
Some(b"First".to_vec())
);
assert_eq!(processor.buffer, b"Second\n\nThird");

assert_eq!(
processor.extract_next_message(&Bytes::from("")),
Some(b"Second".to_vec())
);
assert_eq!(processor.buffer, b"Third");
#[tokio::test]
async fn test_send_request_burst_rate_limit_exceeded() {
// Arrange
let server = MockServer::start_async().await;
let mock = server.mock(|when, then| {
when.method(GET)
.path("/test");
then.status(429)
.header("Burst-RateLimit-Limit", "100")
.header("Burst-RateLimit-Reset", "60");
});

let client = RateLimitedClient::new("test_api_key".to_string());
let url = Url::parse(&format!("{}/test", &server.base_url())).unwrap();
let request_builder = client.get_request_builder(Method::GET, url);

// Act
let result = client.send_request(request_builder).await;

// Assert
match result {
Err(QstashError::BurstRateLimitExceeded { reset }) => assert_eq!(reset, 60),
_ => panic!("Expected BurstRateLimitExceeded error"),
}
mock.assert();
}

#[test]
fn test_empty_message() {
let mut processor = StreamResponse::default();

assert_eq!(
processor.extract_next_message(&Bytes::from("\n\nAfter")),
Some(b"".to_vec())
);
assert_eq!(processor.buffer, b"After");
#[tokio::test]
async fn test_send_request_chat_rate_limit_exceeded() {
// Arrange
let server = MockServer::start_async().await;
let mock = server.mock(|when, then| {
when.method(GET)
.path("/test");
then.status(429)
.header("x-ratelimit-limit-requests", "100")
.header("x-ratelimit-reset-requests", "30")
.header("x-ratelimit-reset-tokens", "45");
});

let client = RateLimitedClient::new("test_api_key".to_string());
let url = Url::parse(&format!("{}/test", &server.base_url())).unwrap();
let request_builder = client.get_request_builder(Method::GET, url);

// Act
let result = client.send_request(request_builder).await;

// Assert
match result {
Err(QstashError::ChatRateLimitExceeded { reset_requests, reset_tokens }) => {
assert_eq!(reset_requests, 30);
assert_eq!(reset_tokens, 45);
},
_ => panic!("Expected ChatRateLimitExceeded error"),
}
mock.assert();
}

#[test]
fn test_no_complete_message() {
let mut processor = StreamResponse::default();

assert_eq!(processor.extract_next_message(&Bytes::from("Hello")), None);
assert_eq!(processor.extract_next_message(&Bytes::from(" World")), None);
assert_eq!(processor.buffer, b"Hello World");
#[tokio::test]
async fn test_send_request_unspecified_rate_limit_exceeded() {
// Arrange
let server = MockServer::start_async().await;
let mock = server.mock(|when, then| {
when.method(GET)
.path("/test");
then.status(429);
});

let client = RateLimitedClient::new("test_api_key".to_string());
let url = Url::parse(&format!("{}/test", &server.base_url())).unwrap();
let request_builder = client.get_request_builder(Method::GET, url);

// Act
let result = client.send_request(request_builder).await;

// Assert
match result {
Err(QstashError::UnspecifiedRateLimitExceeded) => (),
_ => panic!("Expected UnspecifiedRateLimitExceeded error"),
}
mock.assert();
}
}
}
139 changes: 139 additions & 0 deletions src/rate_limited_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,142 @@ fn parse_reset_time(headers: &HeaderMap, header_name: &str) -> u64 {
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(0)
}

#[cfg(test)]
mod tests {
use super::*;
use reqwest::Method;
use httpmock::prelude::*;

#[tokio::test]
async fn test_send_request_success() {
// Arrange
let server = MockServer::start_async().await;
let mock = server.mock(|when, then| {
when.method(GET)
.path("/test");
then.status(200);
});

let client = RateLimitedClient::new("test_api_key".to_string());
let url = Url::parse(&format!("{}/test", &server.base_url())).unwrap();
let request_builder = client.get_request_builder(Method::GET, url);

// Act
let result = client.send_request(request_builder).await;

// Assert
assert!(result.is_ok());
mock.assert();
}

#[tokio::test]
async fn test_send_request_daily_rate_limit_exceeded() {
// Arrange
let server = MockServer::start_async().await;
let mock = server.mock(|when, then| {
when.method(GET)
.path("/test");
then.status(429)
.header("RateLimit-Limit", "1000")
.header("RateLimit-Reset", "3600");
});

let client = RateLimitedClient::new("test_api_key".to_string());
let url = Url::parse(&format!("{}/test", &server.base_url())).unwrap();
let request_builder = client.get_request_builder(Method::GET, url);

// Act
let result = client.send_request(request_builder).await;

// Assert
match result {
Err(QstashError::DailyRateLimitExceeded { reset }) => assert_eq!(reset, 3600),
_ => panic!("Expected DailyRateLimitExceeded error"),
}
mock.assert();
}

#[tokio::test]
async fn test_send_request_burst_rate_limit_exceeded() {
// Arrange
let server = MockServer::start_async().await;
let mock = server.mock(|when, then| {
when.method(GET)
.path("/test");
then.status(429)
.header("Burst-RateLimit-Limit", "100")
.header("Burst-RateLimit-Reset", "60");
});

let client = RateLimitedClient::new("test_api_key".to_string());
let url = Url::parse(&format!("{}/test", &server.base_url())).unwrap();
let request_builder = client.get_request_builder(Method::GET, url);

// Act
let result = client.send_request(request_builder).await;

// Assert
match result {
Err(QstashError::BurstRateLimitExceeded { reset }) => assert_eq!(reset, 60),
_ => panic!("Expected BurstRateLimitExceeded error"),
}
mock.assert();
}

#[tokio::test]
async fn test_send_request_chat_rate_limit_exceeded() {
// Arrange
let server = MockServer::start_async().await;
let mock = server.mock(|when, then| {
when.method(GET)
.path("/test");
then.status(429)
.header("x-ratelimit-limit-requests", "100")
.header("x-ratelimit-reset-requests", "30")
.header("x-ratelimit-reset-tokens", "45");
});

let client = RateLimitedClient::new("test_api_key".to_string());
let url = Url::parse(&format!("{}/test", &server.base_url())).unwrap();
let request_builder = client.get_request_builder(Method::GET, url);

// Act
let result = client.send_request(request_builder).await;

// Assert
match result {
Err(QstashError::ChatRateLimitExceeded { reset_requests, reset_tokens }) => {
assert_eq!(reset_requests, 30);
assert_eq!(reset_tokens, 45);
},
_ => panic!("Expected ChatRateLimitExceeded error"),
}
mock.assert();
}

#[tokio::test]
async fn test_send_request_unspecified_rate_limit_exceeded() {
// Arrange
let server = MockServer::start_async().await;
let mock = server.mock(|when, then| {
when.method(GET)
.path("/test");
then.status(429);
});

let client = RateLimitedClient::new("test_api_key".to_string());
let url = Url::parse(&format!("{}/test", &server.base_url())).unwrap();
let request_builder = client.get_request_builder(Method::GET, url);

// Act
let result = client.send_request(request_builder).await;

// Assert
match result {
Err(QstashError::UnspecifiedRateLimitExceeded) => (),
_ => panic!("Expected UnspecifiedRateLimitExceeded error"),
}
mock.assert();
}
}

0 comments on commit 1959c96

Please sign in to comment.