From 1959c96e68e8281eb694faaf269da1e784c1bcd0 Mon Sep 17 00:00:00 2001 From: mertakman Date: Thu, 7 Nov 2024 01:16:12 +0000 Subject: [PATCH] fix:add tests for rate_limited_client --- Cargo.toml | 1 + src/events.rs | 5 + src/llm.rs | 13 --- src/llm_types.rs | 189 ++++++++++++++++++++++++++----------- src/rate_limited_client.rs | 139 +++++++++++++++++++++++++++ 5 files changed, 279 insertions(+), 68 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 76a3a02..0967fea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,3 +14,4 @@ urlencoding = "2.1.3" http = "1.1.0" tokio = { version="1.41.0", features = ["full"] } futures = "0.3" +httpmock = "0.7.0" \ No newline at end of file diff --git a/src/events.rs b/src/events.rs index 2a0855e..b48f7eb 100644 --- a/src/events.rs +++ b/src/events.rs @@ -27,3 +27,8 @@ impl QstashClient { Ok(response) } } + +#[cfg(test)] +mod tests { + use super::*; +} diff --git a/src/llm.rs b/src/llm.rs index c4a39ba..2bdf901 100644 --- a/src/llm.rs +++ b/src/llm.rs @@ -41,17 +41,4 @@ impl QstashClient { } } } - - pub async fn a(&self) { - let a = self.create_chat_completion(todo!()).await.unwrap(); - match a { - ChatCompletionResponse::Direct(d) => { - println!("Direct response: {:?}", d); - } - ChatCompletionResponse::Stream(s) => { - let s = s; - s.get_next_stream_message().await.unwrap(); - } - } - } } diff --git a/src/llm_types.rs b/src/llm_types.rs index e73fab6..8358f46 100644 --- a/src/llm_types.rs +++ b/src/llm_types.rs @@ -221,7 +221,8 @@ impl StreamResponse { let chunk = self.poll_chunk().await?; match chunk { ChunkType::Message(data) => { - let message = serde_json::from_slice(&data).map_err(QstashError::ResponseStreamParseError)?; + let message = + serde_json::from_slice(&data).map_err(QstashError::ResponseStreamParseError)?; Ok(Some(message)) } ChunkType::Done() => Ok(None), @@ -244,7 +245,10 @@ impl StreamResponse { // Now we can mutably borrow self for extract_next_message if let Some(message) = self.extract_next_message(&chunk) { match message.as_slice() { - b"[DONE]" => return Ok(ChunkType::Done()), + b"[DONE]" => { + self.response = None; + return Ok(ChunkType::Done()); + } _ => return Ok(ChunkType::Message(message)), } } @@ -273,66 +277,141 @@ impl StreamResponse { #[cfg(test)] mod tests { - use super::*; - - #[test] - fn test_single_complete_message() { - let mut processor = StreamResponse::default(); - let chunk = Bytes::from("Hello\n\nWorld"); + use crate::rate_limited_client::RateLimitedClient; - let message = processor.extract_next_message(&chunk); - assert_eq!(message, Some(b"Hello".to_vec())); - assert_eq!(processor.buffer, b"World"); + use super::*; + use reqwest::{Method, Url}; + use httpmock::prelude::*; + + #[tokio::test] + async fn test_send_request_success() { + // Arrange + let server = MockServer::start_async().await; + let mock = server.mock(|when, then| { + when.method(GET) + .path("/test"); + then.status(200); + }); + + let client = RateLimitedClient::new("test_api_key".to_string()); + let url = Url::parse(&format!("{}/test", &server.base_url())).unwrap(); + let request_builder = client.get_request_builder(Method::GET, url); + + // Act + let result = client.send_request(request_builder).await; + + // Assert + assert!(result.is_ok()); + mock.assert(); } - #[test] - fn test_message_in_multiple_chunks() { - let mut processor = StreamResponse::default(); - - assert_eq!(processor.extract_next_message(&Bytes::from("Hel")), None); - assert_eq!(processor.extract_next_message(&Bytes::from("lo Wo")), None); - assert_eq!( - processor.extract_next_message(&Bytes::from("rld\n\nNext")), - Some(b"Hello World".to_vec()) - ); - assert_eq!(processor.buffer, b"Next"); + #[tokio::test] + async fn test_send_request_daily_rate_limit_exceeded() { + // Arrange + let server = MockServer::start_async().await; + let mock = server.mock(|when, then| { + when.method(GET) + .path("/test"); + then.status(429) + .header("RateLimit-Limit", "1000") + .header("RateLimit-Reset", "3600"); + }); + + let client = RateLimitedClient::new("test_api_key".to_string()); + let url = Url::parse(&format!("{}/test", &server.base_url())).unwrap(); + let request_builder = client.get_request_builder(Method::GET, url); + + // Act + let result = client.send_request(request_builder).await; + + // Assert + match result { + Err(QstashError::DailyRateLimitExceeded { reset }) => assert_eq!(reset, 3600), + _ => panic!("Expected DailyRateLimitExceeded error"), + } + mock.assert(); } - #[test] - fn test_multiple_messages_in_single_chunk() { - let mut processor = StreamResponse::default(); - let chunk = Bytes::from("First\n\nSecond\n\nThird"); - - assert_eq!( - processor.extract_next_message(&chunk), - Some(b"First".to_vec()) - ); - assert_eq!(processor.buffer, b"Second\n\nThird"); - - assert_eq!( - processor.extract_next_message(&Bytes::from("")), - Some(b"Second".to_vec()) - ); - assert_eq!(processor.buffer, b"Third"); + #[tokio::test] + async fn test_send_request_burst_rate_limit_exceeded() { + // Arrange + let server = MockServer::start_async().await; + let mock = server.mock(|when, then| { + when.method(GET) + .path("/test"); + then.status(429) + .header("Burst-RateLimit-Limit", "100") + .header("Burst-RateLimit-Reset", "60"); + }); + + let client = RateLimitedClient::new("test_api_key".to_string()); + let url = Url::parse(&format!("{}/test", &server.base_url())).unwrap(); + let request_builder = client.get_request_builder(Method::GET, url); + + // Act + let result = client.send_request(request_builder).await; + + // Assert + match result { + Err(QstashError::BurstRateLimitExceeded { reset }) => assert_eq!(reset, 60), + _ => panic!("Expected BurstRateLimitExceeded error"), + } + mock.assert(); } - #[test] - fn test_empty_message() { - let mut processor = StreamResponse::default(); - - assert_eq!( - processor.extract_next_message(&Bytes::from("\n\nAfter")), - Some(b"".to_vec()) - ); - assert_eq!(processor.buffer, b"After"); + #[tokio::test] + async fn test_send_request_chat_rate_limit_exceeded() { + // Arrange + let server = MockServer::start_async().await; + let mock = server.mock(|when, then| { + when.method(GET) + .path("/test"); + then.status(429) + .header("x-ratelimit-limit-requests", "100") + .header("x-ratelimit-reset-requests", "30") + .header("x-ratelimit-reset-tokens", "45"); + }); + + let client = RateLimitedClient::new("test_api_key".to_string()); + let url = Url::parse(&format!("{}/test", &server.base_url())).unwrap(); + let request_builder = client.get_request_builder(Method::GET, url); + + // Act + let result = client.send_request(request_builder).await; + + // Assert + match result { + Err(QstashError::ChatRateLimitExceeded { reset_requests, reset_tokens }) => { + assert_eq!(reset_requests, 30); + assert_eq!(reset_tokens, 45); + }, + _ => panic!("Expected ChatRateLimitExceeded error"), + } + mock.assert(); } - #[test] - fn test_no_complete_message() { - let mut processor = StreamResponse::default(); - - assert_eq!(processor.extract_next_message(&Bytes::from("Hello")), None); - assert_eq!(processor.extract_next_message(&Bytes::from(" World")), None); - assert_eq!(processor.buffer, b"Hello World"); + #[tokio::test] + async fn test_send_request_unspecified_rate_limit_exceeded() { + // Arrange + let server = MockServer::start_async().await; + let mock = server.mock(|when, then| { + when.method(GET) + .path("/test"); + then.status(429); + }); + + let client = RateLimitedClient::new("test_api_key".to_string()); + let url = Url::parse(&format!("{}/test", &server.base_url())).unwrap(); + let request_builder = client.get_request_builder(Method::GET, url); + + // Act + let result = client.send_request(request_builder).await; + + // Assert + match result { + Err(QstashError::UnspecifiedRateLimitExceeded) => (), + _ => panic!("Expected UnspecifiedRateLimitExceeded error"), + } + mock.assert(); } -} +} \ No newline at end of file diff --git a/src/rate_limited_client.rs b/src/rate_limited_client.rs index ef7dd6b..4cddf82 100644 --- a/src/rate_limited_client.rs +++ b/src/rate_limited_client.rs @@ -74,3 +74,142 @@ fn parse_reset_time(headers: &HeaderMap, header_name: &str) -> u64 { .and_then(|s| s.parse::().ok()) .unwrap_or(0) } + +#[cfg(test)] +mod tests { + use super::*; + use reqwest::Method; + use httpmock::prelude::*; + + #[tokio::test] + async fn test_send_request_success() { + // Arrange + let server = MockServer::start_async().await; + let mock = server.mock(|when, then| { + when.method(GET) + .path("/test"); + then.status(200); + }); + + let client = RateLimitedClient::new("test_api_key".to_string()); + let url = Url::parse(&format!("{}/test", &server.base_url())).unwrap(); + let request_builder = client.get_request_builder(Method::GET, url); + + // Act + let result = client.send_request(request_builder).await; + + // Assert + assert!(result.is_ok()); + mock.assert(); + } + + #[tokio::test] + async fn test_send_request_daily_rate_limit_exceeded() { + // Arrange + let server = MockServer::start_async().await; + let mock = server.mock(|when, then| { + when.method(GET) + .path("/test"); + then.status(429) + .header("RateLimit-Limit", "1000") + .header("RateLimit-Reset", "3600"); + }); + + let client = RateLimitedClient::new("test_api_key".to_string()); + let url = Url::parse(&format!("{}/test", &server.base_url())).unwrap(); + let request_builder = client.get_request_builder(Method::GET, url); + + // Act + let result = client.send_request(request_builder).await; + + // Assert + match result { + Err(QstashError::DailyRateLimitExceeded { reset }) => assert_eq!(reset, 3600), + _ => panic!("Expected DailyRateLimitExceeded error"), + } + mock.assert(); + } + + #[tokio::test] + async fn test_send_request_burst_rate_limit_exceeded() { + // Arrange + let server = MockServer::start_async().await; + let mock = server.mock(|when, then| { + when.method(GET) + .path("/test"); + then.status(429) + .header("Burst-RateLimit-Limit", "100") + .header("Burst-RateLimit-Reset", "60"); + }); + + let client = RateLimitedClient::new("test_api_key".to_string()); + let url = Url::parse(&format!("{}/test", &server.base_url())).unwrap(); + let request_builder = client.get_request_builder(Method::GET, url); + + // Act + let result = client.send_request(request_builder).await; + + // Assert + match result { + Err(QstashError::BurstRateLimitExceeded { reset }) => assert_eq!(reset, 60), + _ => panic!("Expected BurstRateLimitExceeded error"), + } + mock.assert(); + } + + #[tokio::test] + async fn test_send_request_chat_rate_limit_exceeded() { + // Arrange + let server = MockServer::start_async().await; + let mock = server.mock(|when, then| { + when.method(GET) + .path("/test"); + then.status(429) + .header("x-ratelimit-limit-requests", "100") + .header("x-ratelimit-reset-requests", "30") + .header("x-ratelimit-reset-tokens", "45"); + }); + + let client = RateLimitedClient::new("test_api_key".to_string()); + let url = Url::parse(&format!("{}/test", &server.base_url())).unwrap(); + let request_builder = client.get_request_builder(Method::GET, url); + + // Act + let result = client.send_request(request_builder).await; + + // Assert + match result { + Err(QstashError::ChatRateLimitExceeded { reset_requests, reset_tokens }) => { + assert_eq!(reset_requests, 30); + assert_eq!(reset_tokens, 45); + }, + _ => panic!("Expected ChatRateLimitExceeded error"), + } + mock.assert(); + } + + #[tokio::test] + async fn test_send_request_unspecified_rate_limit_exceeded() { + // Arrange + let server = MockServer::start_async().await; + let mock = server.mock(|when, then| { + when.method(GET) + .path("/test"); + then.status(429); + }); + + let client = RateLimitedClient::new("test_api_key".to_string()); + let url = Url::parse(&format!("{}/test", &server.base_url())).unwrap(); + let request_builder = client.get_request_builder(Method::GET, url); + + // Act + let result = client.send_request(request_builder).await; + + // Assert + match result { + Err(QstashError::UnspecifiedRateLimitExceeded) => (), + _ => panic!("Expected UnspecifiedRateLimitExceeded error"), + } + mock.assert(); + } +} \ No newline at end of file