From cf7cf94d9cdb61a66ff034429e4363ba2bcb8842 Mon Sep 17 00:00:00 2001 From: Ishaan Sehgal Date: Mon, 1 Apr 2024 16:40:52 -0700 Subject: [PATCH] feat: Add API Docs and Improve Inference Readability (#331) **Reason for Change**: This change adds plenty more documentation and OpenAPI spec for our inference. It also enables the use of preferred non-NVIDIA nodes without crashing. **Requirements** - [x] added unit tests and e2e tests (if applicable). **Issue Fixed**: Fixes #321 **Notes for Reviewers**: --- .../inference/text-generation/api_spec.json | 599 ++++++++++++++++++ .../text-generation/inference_api.py | 284 ++++++++- .../text-generation/requirements.txt | 1 + .../tests/test_inference_api.py | 86 ++- 4 files changed, 919 insertions(+), 51 deletions(-) create mode 100644 presets/inference/text-generation/api_spec.json diff --git a/presets/inference/text-generation/api_spec.json b/presets/inference/text-generation/api_spec.json new file mode 100644 index 000000000..480fa97e4 --- /dev/null +++ b/presets/inference/text-generation/api_spec.json @@ -0,0 +1,599 @@ +{ + "openapi": "3.1.0", + "info": { + "title": "FastAPI", + "version": "0.1.0" + }, + "paths": { + "/": { + "get": { + "summary": "Home Endpoint", + "description": "A simple endpoint that indicates the server is running.\nNo parameters are required. Returns a message indicating the server status.", + "operationId": "home__get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HomeResponse" + } + } + } + } + } + } + }, + "/healthz": { + "get": { + "summary": "Health Check Endpoint", + "operationId": "health_check_healthz_get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HealthStatus" + }, + "example": { + "status": "Healthy" + } + } + } + }, + "500": { + "description": "Error Response", + "content": { + "application/json": { + "examples": { + "model_uninitialized": { + "summary": "Model not initialized", + "value": { + "detail": "Model not initialized" + } + }, + "pipeline_uninitialized": { + "summary": "Pipeline not initialized", + "value": { + "detail": "Pipeline not initialized" + } + } + } + } + } + } + } + } + }, + "/chat": { + "post": { + "summary": "Chat Endpoint", + "description": "Processes chat requests, generating text based on the specified pipeline (text generation or conversational).\nValidates required parameters based on the pipeline and returns the generated text.", + "operationId": "generate_text_chat_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UnifiedRequestModel" + }, + "examples": { + "text_generation_example": { + "summary": "Text Generation Example", + "description": "An example of a text generation request.", + "value": { + "prompt": "Tell me a joke", + "return_full_text": true, + "clean_up_tokenization_spaces": false, + "generate_kwargs": { + "max_length": 200, + "min_length": 0, + "do_sample": true, + "early_stopping": false, + "num_beams": 1, + "temperature": 1, + "top_k": 10, + "top_p": 1, + "typical_p": 1, + "repetition_penalty": 1, + "eos_token_id": 11 + } + } + }, + "conversation_example": { + "summary": "Conversation Example", + "description": "An example of a conversational request.", + "value": { + "messages": [ + { + "role": "user", + "content": "What is your favourite condiment?" 
+ }, + { + "role": "assistant", + "content": "Well, im quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever im cooking up in the kitchen!" + }, + { + "role": "user", + "content": "Do you have mayonnaise recipes?" + } + ], + "return_full_text": true, + "clean_up_tokenization_spaces": false, + "generate_kwargs": { + "max_length": 200, + "min_length": 0, + "do_sample": true, + "early_stopping": false, + "num_beams": 1, + "temperature": 1, + "top_k": 10, + "top_p": 1, + "typical_p": 1, + "repetition_penalty": 1, + "eos_token_id": 11 + } + } + } + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {}, + "examples": { + "text_generation": { + "summary": "Text Generation Response", + "value": { + "Result": "Generated text based on the prompt." + } + }, + "conversation": { + "summary": "Conversation Response", + "value": { + "Result": "Response to the last message in the conversation." + } + } + } + } + } + }, + "400": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "examples": { + "missing_prompt": { + "summary": "Missing Prompt", + "value": { + "detail": "Text generation parameter prompt required" + } + }, + "missing_messages": { + "summary": "Missing Messages", + "value": { + "detail": "Conversational parameter messages required" + } + } + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + }, + "500": { + "description": "Internal Server Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + } + } + } + } + } + } + }, + "/metrics": { + "get": { + "summary": "Metrics Endpoint", + "description": "Provides system metrics, including GPU details if available, or CPU and memory usage otherwise.\nUseful for monitoring the resource utilization of the server running the ML models.", + "operationId": "get_metrics_metrics_get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/MetricsResponse" + }, + "examples": { + "gpu_metrics": { + "summary": "Example when GPUs are available", + "value": { + "gpu_info": [ + { + "id": "GPU-1234", + "name": "GeForce GTX 950", + "load": "25.00%", + "temperature": "55 C", + "memory": { + "used": "1.00 GB", + "total": "2.00 GB" + } + } + ] + } + }, + "cpu_metrics": { + "summary": "Example when only CPU is available", + "value": { + "cpu_info": { + "load_percentage": 20, + "physical_cores": 4, + "total_cores": 8, + "memory": { + "used": "4.00 GB", + "total": "16.00 GB" + } + } + } + } + } + } + } + }, + "500": { + "description": "Internal Server Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + } + } + } + } + } + } + } + }, + "components": { + "schemas": { + "CPUInfo": { + "properties": { + "load_percentage": { + "type": "number", + "title": "Load Percentage" + }, + "physical_cores": { + "type": "integer", + "title": "Physical Cores" + }, + "total_cores": { + "type": "integer", + "title": "Total Cores" + }, + "memory": { + "$ref": "#/components/schemas/MemoryInfo" + } + }, + "type": "object", + "required": [ + "load_percentage", + "physical_cores", + "total_cores", + "memory" 
+ ], + "title": "CPUInfo" + }, + "ErrorResponse": { + "properties": { + "detail": { + "type": "string", + "title": "Detail" + } + }, + "type": "object", + "required": [ + "detail" + ], + "title": "ErrorResponse" + }, + "GPUInfo": { + "properties": { + "id": { + "type": "string", + "title": "Id" + }, + "name": { + "type": "string", + "title": "Name" + }, + "load": { + "type": "string", + "title": "Load" + }, + "temperature": { + "type": "string", + "title": "Temperature" + }, + "memory": { + "$ref": "#/components/schemas/MemoryInfo" + } + }, + "type": "object", + "required": [ + "id", + "name", + "load", + "temperature", + "memory" + ], + "title": "GPUInfo" + }, + "GenerateKwargs": { + "properties": { + "max_length": { + "type": "integer", + "title": "Max Length", + "default": 200 + }, + "min_length": { + "type": "integer", + "title": "Min Length", + "default": 0 + }, + "do_sample": { + "type": "boolean", + "title": "Do Sample", + "default": true + }, + "early_stopping": { + "type": "boolean", + "title": "Early Stopping", + "default": false + }, + "num_beams": { + "type": "integer", + "title": "Num Beams", + "default": 1 + }, + "temperature": { + "type": "number", + "title": "Temperature", + "default": 1 + }, + "top_k": { + "type": "integer", + "title": "Top K", + "default": 10 + }, + "top_p": { + "type": "number", + "title": "Top P", + "default": 1 + }, + "typical_p": { + "type": "number", + "title": "Typical P", + "default": 1 + }, + "repetition_penalty": { + "type": "number", + "title": "Repetition Penalty", + "default": 1 + }, + "pad_token_id": { + "type": "integer", + "title": "Pad Token Id" + }, + "eos_token_id": { + "type": "integer", + "title": "Eos Token Id", + "default": 11 + } + }, + "type": "object", + "title": "GenerateKwargs", + "example": { + "max_length": 200, + "temperature": 0.7, + "top_p": 0.9, + "additional_param": "Example value" + } + }, + "HTTPValidationError": { + "properties": { + "detail": { + "items": { + "$ref": "#/components/schemas/ValidationError" + }, + "type": "array", + "title": "Detail" + } + }, + "type": "object", + "title": "HTTPValidationError" + }, + "HealthStatus": { + "properties": { + "status": { + "type": "string", + "title": "Status", + "example": "Healthy" + } + }, + "type": "object", + "required": [ + "status" + ], + "title": "HealthStatus" + }, + "HomeResponse": { + "properties": { + "message": { + "type": "string", + "title": "Message", + "example": "Server is running" + } + }, + "type": "object", + "required": [ + "message" + ], + "title": "HomeResponse" + }, + "MemoryInfo": { + "properties": { + "used": { + "type": "string", + "title": "Used" + }, + "total": { + "type": "string", + "title": "Total" + } + }, + "type": "object", + "required": [ + "used", + "total" + ], + "title": "MemoryInfo" + }, + "Message": { + "properties": { + "role": { + "type": "string", + "title": "Role" + }, + "content": { + "type": "string", + "title": "Content" + } + }, + "type": "object", + "required": [ + "role", + "content" + ], + "title": "Message" + }, + "MetricsResponse": { + "properties": { + "gpu_info": { + "items": { + "$ref": "#/components/schemas/GPUInfo" + }, + "type": "array", + "title": "Gpu Info" + }, + "cpu_info": { + "$ref": "#/components/schemas/CPUInfo" + } + }, + "type": "object", + "title": "MetricsResponse" + }, + "UnifiedRequestModel": { + "properties": { + "prompt": { + "type": "string", + "title": "Prompt", + "description": "Prompt for text generation. Required for text-generation pipeline. Do not use with 'messages'." 
+ }, + "return_full_text": { + "type": "boolean", + "title": "Return Full Text", + "description": "Return full text if True, else only added text", + "default": true + }, + "clean_up_tokenization_spaces": { + "type": "boolean", + "title": "Clean Up Tokenization Spaces", + "description": "Clean up extra spaces in text output", + "default": false + }, + "prefix": { + "type": "string", + "title": "Prefix", + "description": "Prefix added to prompt" + }, + "handle_long_generation": { + "type": "string", + "title": "Handle Long Generation", + "description": "Strategy to handle long generation" + }, + "generate_kwargs": { + "allOf": [ + { + "$ref": "#/components/schemas/GenerateKwargs" + } + ], + "title": "Generate Kwargs", + "description": "Additional kwargs for generate method" + }, + "messages": { + "items": { + "$ref": "#/components/schemas/Message" + }, + "type": "array", + "title": "Messages", + "description": "Messages for conversational model. Required for conversational pipeline. Do not use with 'prompt'." + } + }, + "type": "object", + "title": "UnifiedRequestModel" + }, + "ValidationError": { + "properties": { + "loc": { + "items": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "integer" + } + ] + }, + "type": "array", + "title": "Location" + }, + "msg": { + "type": "string", + "title": "Message" + }, + "type": { + "type": "string", + "title": "Error Type" + } + }, + "type": "object", + "required": [ + "loc", + "msg", + "type" + ], + "title": "ValidationError" + } + } + } +} \ No newline at end of file diff --git a/presets/inference/text-generation/inference_api.py b/presets/inference/text-generation/inference_api.py index f6c604a54..c23a15c6b 100644 --- a/presets/inference/text-generation/inference_api.py +++ b/presets/inference/text-generation/inference_api.py @@ -2,13 +2,15 @@ # Licensed under the MIT license. 
import os from dataclasses import asdict, dataclass, field -from typing import Any, Dict, List, Optional +from typing import Annotated, Any, Dict, List, Optional import GPUtil +import psutil import torch import transformers import uvicorn -from fastapi import FastAPI, HTTPException +from fastapi import Body, FastAPI, HTTPException +from fastapi.responses import Response from pydantic import BaseModel, Extra, Field from transformers import (AutoModelForCausalLM, AutoTokenizer, GenerationConfig, HfArgumentParser) @@ -35,7 +37,7 @@ class ModelConfig: load_in_8bit: bool = field(default=False, metadata={"help": "Load model in 8-bit mode"}) torch_dtype: Optional[str] = field(default=None, metadata={"help": "The torch dtype for the pre-trained model"}) device_map: str = field(default="auto", metadata={"help": "The device map for the pre-trained model"}) - + # Method to process additional arguments def process_additional_args(self, addt_args: List[str]): """ @@ -51,7 +53,7 @@ def process_additional_args(self, addt_args: List[str]): else: value = True # Assign a True value for standalone flags i += 1 # Move to the next item - + addt_args_dict[key] = value # Update the ModelConfig instance with the additional args @@ -102,20 +104,57 @@ def __post_init__(self): try: # Attempt to load the generation configuration default_generate_config = GenerationConfig.from_pretrained( - args.pretrained_model_name_or_path, + args.pretrained_model_name_or_path, local_files_only=args.local_files_only ).to_dict() except Exception as e: default_generate_config = {} -@app.get('/') +class HomeResponse(BaseModel): + message: str = Field(..., example="Server is running") +@app.get('/', response_model=HomeResponse, summary="Home Endpoint") def home(): - return "Server is running", 200 + """ + A simple endpoint that indicates the server is running. + No parameters are required. Returns a message indicating the server status. + """ + return {"message": "Server is running"} -@app.get("/healthz") +class HealthStatus(BaseModel): + status: str = Field(..., example="Healthy") +@app.get( + "/healthz", + response_model=HealthStatus, + summary="Health Check Endpoint", + responses={ + 200: { + "description": "Successful Response", + "content": { + "application/json": { + "example": {"status": "Healthy"} + } + } + }, + 500: { + "description": "Error Response", + "content": { + "application/json": { + "examples": { + "model_uninitialized": { + "summary": "Model not initialized", + "value": {"detail": "Model not initialized"} + }, + "pipeline_uninitialized": { + "summary": "Pipeline not initialized", + "value": {"detail": "Pipeline not initialized"} + } + } + } + } + } + } +) def health_check(): - if not torch.cuda.is_available(): - raise HTTPException(status_code=500, detail="No GPU available") if not model: raise HTTPException(status_code=500, detail="Model not initialized") if not pipeline: @@ -137,10 +176,22 @@ class GenerateKwargs(BaseModel): eos_token_id: Optional[int] = tokenizer.eos_token_id class Config: extra = Extra.allow # Allows for additional fields not explicitly defined + schema_extra = { + "example": { + "max_length": 200, + "temperature": 0.7, + "top_p": 0.9, + "additional_param": "Example value" + } + } + +class Message(BaseModel): + role: str + content: str class UnifiedRequestModel(BaseModel): # Fields for text generation - prompt: Optional[str] = Field(None, description="Prompt for text generation") + prompt: Optional[str] = Field(None, description="Prompt for text generation. Required for text-generation pipeline. 
Do not use with 'messages'.") return_full_text: Optional[bool] = Field(True, description="Return full text if True, else only added text") clean_up_tokenization_spaces: Optional[bool] = Field(False, description="Clean up extra spaces in text output") prefix: Optional[str] = Field(None, description="Prefix added to prompt") @@ -148,10 +199,103 @@ class UnifiedRequestModel(BaseModel): generate_kwargs: Optional[GenerateKwargs] = Field(default_factory=GenerateKwargs, description="Additional kwargs for generate method") # Field for conversational model - messages: Optional[List[Dict[str, str]]] = Field(None, description="Messages for conversational model") + messages: Optional[List[Message]] = Field(None, description="Messages for conversational model. Required for conversational pipeline. Do not use with 'prompt'.") + def messages_to_dict_list(self): + return [message.dict() for message in self.messages] if self.messages else [] + +class ErrorResponse(BaseModel): + detail: str -@app.post("/chat") -def generate_text(request_model: UnifiedRequestModel): +@app.post( + "/chat", + summary="Chat Endpoint", + responses={ + 200: { + "description": "Successful Response", + "content": { + "application/json": { + "examples": { + "text_generation": { + "summary": "Text Generation Response", + "value": { + "Result": "Generated text based on the prompt." + } + }, + "conversation": { + "summary": "Conversation Response", + "value": { + "Result": "Response to the last message in the conversation." + } + } + } + } + } + }, + 400: { + "model": ErrorResponse, + "description": "Validation Error", + "content": { + "application/json": { + "examples": { + "missing_prompt": { + "summary": "Missing Prompt", + "value": {"detail": "Text generation parameter prompt required"} + }, + "missing_messages": { + "summary": "Missing Messages", + "value": {"detail": "Conversational parameter messages required"} + } + } + } + } + }, + 500: { + "model": ErrorResponse, + "description": "Internal Server Error" + } + } +) +def generate_text( + request_model: Annotated[ + UnifiedRequestModel, + Body( + openapi_examples={ + "text_generation_example": { + "summary": "Text Generation Example", + "description": "An example of a text generation request.", + "value": { + "prompt": "Tell me a joke", + "return_full_text": True, + "clean_up_tokenization_spaces": False, + "prefix": None, + "handle_long_generation": None, + "generate_kwargs": GenerateKwargs().dict(), + }, + }, + "conversation_example": { + "summary": "Conversation Example", + "description": "An example of a conversational request.", + "value": { + "messages": [ + {"role": "user", "content": "What is your favourite condiment?"}, + {"role": "assistant", "content": "Well, im quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever im cooking up in the kitchen!"}, + {"role": "user", "content": "Do you have mayonnaise recipes?"} + ], + "return_full_text": True, + "clean_up_tokenization_spaces": False, + "prefix": None, + "handle_long_generation": None, + "generate_kwargs": GenerateKwargs().dict(), + }, + }, + }, + ), + ], +): + """ + Processes chat requests, generating text based on the specified pipeline (text generation or conversational). + Validates required parameters based on the pipeline and returns the generated text. 
+ """ user_generate_kwargs = request_model.generate_kwargs.dict() if request_model.generate_kwargs else {} generate_kwargs = {**default_generate_config, **user_generate_kwargs} @@ -176,12 +320,12 @@ def generate_text(request_model: UnifiedRequestModel): return {"Result": result} - elif args.pipeline == "conversational": + elif args.pipeline == "conversational": if not request_model.messages: raise HTTPException(status_code=400, detail="Conversational parameter messages required") response = pipeline( - request_model.messages, + request_model.messages_to_dict_list(), clean_up_tokenization_spaces=request_model.clean_up_tokenization_spaces, **generate_kwargs ) @@ -190,27 +334,101 @@ def generate_text(request_model: UnifiedRequestModel): else: raise HTTPException(status_code=400, detail="Invalid pipeline type") -@app.get("/metrics") +class MemoryInfo(BaseModel): + used: str + total: str + +class CPUInfo(BaseModel): + load_percentage: float + physical_cores: int + total_cores: int + memory: MemoryInfo + +class GPUInfo(BaseModel): + id: str + name: str + load: str + temperature: str + memory: MemoryInfo + +class MetricsResponse(BaseModel): + gpu_info: Optional[List[GPUInfo]] = None + cpu_info: Optional[CPUInfo] = None + +@app.get( + "/metrics", + response_model=MetricsResponse, + summary="Metrics Endpoint", + responses={ + 200: { + "description": "Successful Response", + "content": { + "application/json": { + "examples": { + "gpu_metrics": { + "summary": "Example when GPUs are available", + "value": { + "gpu_info": [{"id": "GPU-1234", "name": "GeForce GTX 950", "load": "25.00%", "temperature": "55 C", "memory": {"used": "1.00 GB", "total": "2.00 GB"}}], + "cpu_info": None # Indicates CPUs info might not be present when GPUs are available + } + }, + "cpu_metrics": { + "summary": "Example when only CPU is available", + "value": { + "gpu_info": None, # Indicates GPU info might not be present when only CPU is available + "cpu_info": {"load_percentage": 20.0, "physical_cores": 4, "total_cores": 8, "memory": {"used": "4.00 GB", "total": "16.00 GB"}} + } + } + } + } + } + }, + 500: { + "description": "Internal Server Error", + "model": ErrorResponse, + } + } +) def get_metrics(): + """ + Provides system metrics, including GPU details if available, or CPU and memory usage otherwise. + Useful for monitoring the resource utilization of the server running the ML models. 
+ """ try: - gpus = GPUtil.getGPUs() - gpu_info = [] - for gpu in gpus: - gpu_info.append({ - "id": gpu.id, - "name": gpu.name, - "load": f"{gpu.load * 100:.2f}%", # Format as percentage - "temperature": f"{gpu.temperature} C", - "memory": { - "used": f"{gpu.memoryUsed / 1024:.2f} GB", - "total": f"{gpu.memoryTotal / 1024:.2f} GB" - } - }) - return {"gpu_info": gpu_info} + if torch.cuda.is_available(): + gpus = GPUtil.getGPUs() + gpu_info = [GPUInfo( + id=gpu.id, + name=gpu.name, + load=f"{gpu.load * 100:.2f}%", + temperature=f"{gpu.temperature} C", + memory=MemoryInfo( + used=f"{gpu.memoryUsed / (1024 ** 3):.2f} GB", + total=f"{gpu.memoryTotal / (1024 ** 3):.2f} GB" + ) + ) for gpu in gpus] + return MetricsResponse(gpu_info=gpu_info) + else: + # Gather CPU metrics + cpu_usage = psutil.cpu_percent(interval=1, percpu=False) + physical_cores = psutil.cpu_count(logical=False) + total_cores = psutil.cpu_count(logical=True) + virtual_memory = psutil.virtual_memory() + memory = MemoryInfo( + used=f"{virtual_memory.used / (1024 ** 3):.2f} GB", + total=f"{virtual_memory.total / (1024 ** 3):.2f} GB" + ) + cpu_info = CPUInfo( + load_percentage=cpu_usage, + physical_cores=physical_cores, + total_cores=total_cores, + memory=memory + ) + return MetricsResponse(cpu_info=cpu_info) except Exception as e: - return {"error": str(e)} + raise HTTPException(status_code=500, detail=str(e)) if __name__ == "__main__": local_rank = int(os.environ.get("LOCAL_RANK", 0)) # Default to 0 if not set port = 5000 + local_rank # Adjust port based on local rank - uvicorn.run(app=app, host='0.0.0.0', port=port) + uvicorn.run(app=app, host='0.0.0.0', port=port) \ No newline at end of file diff --git a/presets/inference/text-generation/requirements.txt b/presets/inference/text-generation/requirements.txt index 8a7c50dbe..1d1c845a5 100644 --- a/presets/inference/text-generation/requirements.txt +++ b/presets/inference/text-generation/requirements.txt @@ -8,6 +8,7 @@ uvicorn[standard]==0.23.2 bitsandbytes==0.42.0 deepspeed==0.11.1 gputil==1.4.0 +psutil==5.9.8 # For UTs pytest==8.0.0 httpx==0.26.0 \ No newline at end of file diff --git a/presets/inference/text-generation/tests/test_inference_api.py b/presets/inference/text-generation/tests/test_inference_api.py index c15b0f38f..535d0f4e2 100644 --- a/presets/inference/text-generation/tests/test_inference_api.py +++ b/presets/inference/text-generation/tests/test_inference_api.py @@ -4,7 +4,6 @@ from unittest.mock import patch import pytest -import torch from fastapi.testclient import TestClient from transformers import AutoTokenizer @@ -44,7 +43,7 @@ def test_conversational(configured_app): client = TestClient(configured_app) messages = [ {"role": "user", "content": "What is your favourite condiment?"}, - {"role": "assistant", "content": "Well, Im quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever Im cooking up in the kitchen!"}, + {"role": "assistant", "content": "Well, im quite partial to a good squeeze of fresh lemon juice. 
It adds just the right amount of zesty flavour to whatever im cooking up in the kitchen!"}, {"role": "user", "content": "Do you have mayonnaise recipes?"} ] request_data = { @@ -102,17 +101,12 @@ def test_missing_prompt(configured_app): def test_read_main(configured_app): client = TestClient(configured_app) response = client.get("/") - server_msg, status_code = response.json() - assert server_msg == "Server is running" - assert status_code == 200 + assert response.status_code == 200 + assert response.json() == {"message": "Server is running"} def test_health_check(configured_app): - device = "GPU" if torch.cuda.is_available() else "CPU" - if device != "GPU": - pytest.skip("Skipping healthz endpoint check - running on CPU") client = TestClient(configured_app) response = client.get("/healthz") - # Assuming we have a GPU available assert response.status_code == 200 assert response.json() == {"status": "Healthy"} @@ -122,17 +116,73 @@ def test_get_metrics(configured_app): assert response.status_code == 200 assert "gpu_info" in response.json() +def test_get_metrics_with_gpus(configured_app): + client = TestClient(configured_app) + # Define a simple mock GPU object with the necessary attributes + class MockGPU: + def __init__(self, id, name, load, temperature, memoryUsed, memoryTotal): + self.id = id + self.name = name + self.load = load + self.temperature = temperature + self.memoryUsed = memoryUsed + self.memoryTotal = memoryTotal + + # Create a mock GPU object with the desired attributes + mock_gpu = MockGPU( + id="GPU-1234", + name="GeForce GTX 950", + load=0.25, # 25% + temperature=55, # 55 C + memoryUsed=1 * (1024 ** 3), # 1 GB + memoryTotal=2 * (1024 ** 3) # 2 GB + ) + + # Mock torch.cuda.is_available to simulate an environment with GPUs + # Mock GPUtil.getGPUs to return a list containing the mock GPU object + with patch('torch.cuda.is_available', return_value=True), \ + patch('GPUtil.getGPUs', return_value=[mock_gpu]): + response = client.get("/metrics") + assert response.status_code == 200 + data = response.json() + + # Assertions to verify that the GPU info is correctly returned in the response + assert data["gpu_info"] != [] + assert len(data["gpu_info"]) == 1 + gpu_data = data["gpu_info"][0] + + assert gpu_data["id"] == "GPU-1234" + assert gpu_data["name"] == "GeForce GTX 950" + assert gpu_data["load"] == "25.00%" + assert gpu_data["temperature"] == "55 C" + assert gpu_data["memory"]["used"] == "1.00 GB" + assert gpu_data["memory"]["total"] == "2.00 GB" + assert data["cpu_info"] is None # Assuming CPU info is not present when GPUs are available + def test_get_metrics_no_gpus(configured_app): client = TestClient(configured_app) - with patch('GPUtil.getGPUs', return_value=[]) as mock_getGPUs: + # Mock GPUtil.getGPUs to simulate an environment without GPUs + with patch('torch.cuda.is_available', return_value=False), \ + patch('psutil.cpu_percent', return_value=20.0), \ + patch('psutil.cpu_count', side_effect=[4, 8]), \ + patch('psutil.virtual_memory') as mock_virtual_memory: + mock_virtual_memory.return_value.used = 4 * (1024 ** 3) # 4 GB + mock_virtual_memory.return_value.total = 16 * (1024 ** 3) # 16 GB response = client.get("/metrics") assert response.status_code == 200 - assert response.json()["gpu_info"] == [] + data = response.json() + assert data["gpu_info"] is None # No GPUs available + assert data["cpu_info"] is not None # CPU info should be present + assert data["cpu_info"]["load_percentage"] == 20.0 + assert data["cpu_info"]["physical_cores"] == 4 + assert 
data["cpu_info"]["total_cores"] == 8 + assert data["cpu_info"]["memory"]["used"] == "4.00 GB" + assert data["cpu_info"]["memory"]["total"] == "16.00 GB" def test_default_generation_params(configured_app): if configured_app.test_config['pipeline'] != 'text-generation': pytest.skip("Skipping non-text-generation tests") - + client = TestClient(configured_app) request_data = { @@ -144,14 +194,14 @@ def test_default_generation_params(configured_app): with patch('inference_api.pipeline') as mock_pipeline: mock_pipeline.return_value = [{"generated_text": "Mocked response"}] # Mock the response of the pipeline function - + response = client.post("/chat", json=request_data) - + assert response.status_code == 200 data = response.json() assert "Result" in data assert data["Result"] == "Mocked response", "The response content doesn't match the expected mock response" - + # Check the default args _, kwargs = mock_pipeline.call_args assert kwargs['max_length'] == 200 @@ -187,7 +237,7 @@ def test_generation_with_max_length(configured_app): data = response.json() print("Response: ", data["Result"]) assert "Result" in data, "The response should contain a 'Result' key" - + tokenizer = AutoTokenizer.from_pretrained(configured_app.test_config['model_path']) prompt_tokens = tokenizer.tokenize(prompt) total_tokens = tokenizer.tokenize(data["Result"]) # data["Result"] includes the input prompt @@ -207,7 +257,7 @@ def test_generation_with_min_length(configured_app): client = TestClient(configured_app) prompt = "This prompt requests a response of a certain minimum length to test the functionality." min_length = 30 - max_length = 40 + max_length = 40 request_data = { "prompt": prompt, @@ -221,7 +271,7 @@ def test_generation_with_min_length(configured_app): assert response.status_code == 200 data = response.json() assert "Result" in data, "The response should contain a 'Result' key" - + tokenizer = AutoTokenizer.from_pretrained(configured_app.test_config['model_path']) prompt_tokens = tokenizer.tokenize(prompt) total_tokens = tokenizer.tokenize(data["Result"]) # data["Result"] includes the input prompt