Skip to content

Commit

Permalink
real tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mpnowacki-reef committed Dec 9, 2024
1 parent 0c702bb commit 60e585a
Show file tree
Hide file tree
Showing 6 changed files with 597 additions and 12 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,5 @@ jobs:
run: python -m pip install --upgrade nox 'pdm==2.19.3'
- name: Install dependencies
run: pdm install -G test
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Run unit tests
run: pytest tests/integration_mock/
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@ A tool for generating responses to prompts using vLLM, primarily designed for us

This project provides a script for generating responses to prompts using the vLLM library. It's designed to be flexible and can be run in various environments, including Docker containers and directly from Python.

There is a `--mock` option for running smoke tests that validate the interface without actually downloading
a model or having a GPU.

## Features

- Generate responses for multiple prompts
- Configurable model parameters (temperature, top-p, max tokens)
- Support for multiple input files
- Deterministic output with seed setting
- Docker support for easy deployment
- Can be started either with a seed supplied at launch, or as an HTTP server that waits for a seed and then calls the model.
  The server is designed to serve a single request and then be told to shut down.

## Installation

Expand All @@ -22,6 +27,10 @@ The project uses `pdm` for dependency management. To install dependencies:
pdm install
```

## Testing

Tests in `integration_mock` are lightweight and can be run on any platform; the ones in `integration_real_llm` will only pass
on an actual NVIDIA A6000.
## Usage

### Running with Docker
Expand Down
11 changes: 2 additions & 9 deletions src/compute_horde_prompt_solver/prompt_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import torch
import vllm
from flask import Flask, Blueprint, jsonify
from flask import Flask, jsonify
from vllm import SamplingParams

# Import the set_deterministic function
Expand All @@ -30,14 +30,7 @@ class GPULLMProvider(BaseLLMProvider):
def __init__(self, model_name: str, dtype: str = "auto"):
self.model_name = model_name
self.dtype = dtype
self._model = None

@property
def model(self):
if self._model is None:
return self._model
self._model = self.setup_model()
return self._model
self.model = self.setup_model()

def setup_model(self) -> vllm.LLM:
gpu_count = torch.cuda.device_count()
Expand Down
102 changes: 102 additions & 0 deletions tests/integration_real_llm/test_real_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import hashlib
import pathlib
import subprocess
import sys
import tempfile
import time

import pytest
import requests
from huggingface_hub import snapshot_download

# Time limit, in seconds, shared by the tests in this module (e.g. the CLI
# subprocess run and the default polling window in get_url_within_time).
TIMEOUT = 180


@pytest.fixture(scope='session', autouse=True)
def download_model():
    """Download the Phi-3.5-mini-instruct snapshot once per test session.

    The files are stored in ``saved_model/`` inside the
    ``src/compute_horde_prompt_solver`` directory, which is resolved
    relative to the location of this test file.
    """
    # tests/integration_real_llm/<this file> -> three levels up is the repository root.
    repo_root = pathlib.Path(__file__).parent.parent.parent
    target_dir = repo_root / "src" / "compute_horde_prompt_solver" / "saved_model/"
    snapshot_download(
        repo_id="microsoft/Phi-3.5-mini-instruct",
        local_dir=target_dir,
    )


@pytest.mark.parametrize(
    "seed,expected_output_file",
    [
        ("1234567891", "expected_real_1234567891_output.json"),
        ("99", "expected_real_99_output.json"),
    ],
)
def test_cli(input_file, seed, expected_output_file):
    """Run the solver through its command-line entry point and compare its output.

    The solver is started as a module in a child process with fixed sampling
    parameters and the parametrized ``seed``. The file written next to
    ``input_file`` (``<input_file>.json``) must match, character for
    character, the reference file stored in ``tests/payload``.
    """
    command = [
        sys.executable,
        "-m",
        "src.compute_horde_prompt_solver",
        "--temperature", "0.5",
        "--top-p", "0.8",
        "--max-tokens", "256",
        "--seed", seed,
        "--output-dir", tempfile.gettempdir(),
        input_file,
    ]
    # check_call raises on a non-zero exit code or when the run exceeds TIMEOUT.
    subprocess.check_call(command, timeout=TIMEOUT)

    payload_dir = pathlib.Path(__file__).parent.parent / "payload"
    expected = (payload_dir / expected_output_file).read_text()
    actual = pathlib.Path(input_file + '.json').read_text()
    assert expected == actual


def get_url_within_time(url, timeout=TIMEOUT):
    """Poll ``url`` with GET requests until one succeeds or the deadline passes.

    Args:
        url: Address to request.
        timeout: Total number of seconds to keep trying.

    Returns:
        The first ``requests.Response`` whose status passes ``raise_for_status()``.

    Raises:
        TimeoutError: If no successful response was obtained within ``timeout`` seconds.
    """
    # The monotonic clock is immune to wall-clock adjustments during the wait.
    deadline = time.monotonic() + timeout

    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        try:
            # Bound every single request by the time that is left; without a
            # per-request timeout one hung connection could block far past the deadline.
            response = requests.get(url, timeout=remaining)
            response.raise_for_status()
            return response
        except (requests.HTTPError, requests.ConnectionError, requests.Timeout):
            # Server not up yet (or not healthy yet) - retry until the deadline.
            pass

        time.sleep(0.5)  # Wait a bit before trying again to not overload the server and your machine.

    raise TimeoutError(f"Could not get data from {url} within {timeout} seconds")


@pytest.mark.parametrize(
    "seed,expected_output_file",
    [
        ("1234567891", "expected_real_1234567891_output.json"),
        ("99", "expected_real_99_output.json"),
    ],
)
def test_http(input_file, seed, expected_output_file):
    """Run the solver in ``--server`` mode and verify the hashes it reports.

    The server is started without a seed. Once ``/health`` responds, the seed
    is posted to ``/execute-job``. The JSON reply must map the output file path
    to the SHA-256 hex digest of the file actually written on disk. The server
    is then asked to shut down via ``/terminate``.

    ``expected_output_file`` is supplied by the parametrization but is not
    consulted here; only the reported hash is checked against the written file.
    """
    server = subprocess.Popen(
        [
            sys.executable,
            "-m",
            "src.compute_horde_prompt_solver",
            "--temperature", "0.5",
            "--top-p", "0.8",
            "--max-tokens", "256",
            "--output-dir", tempfile.gettempdir(),
            "--server",
            input_file,
        ],
    )
    try:
        base_url = 'http://localhost:8000/'
        get_url_within_time(base_url + 'health')

        # Bound the job request so a stuck server fails the test instead of hanging it.
        with requests.post(base_url + 'execute-job', json={"seed": seed}, timeout=TIMEOUT) as resp:
            resp.raise_for_status()
            hashes = resp.json()

        try:
            requests.get(base_url + 'terminate', timeout=10)
        except requests.RequestException:
            # Best effort: the server may drop the connection while shutting down.
            # Only request-level errors are ignored, so KeyboardInterrupt and
            # SystemExit still propagate.
            pass

        output_path = input_file + '.json'
        expected_digest = hashlib.sha256(pathlib.Path(output_path).read_bytes()).hexdigest()
        assert hashes == {output_path: expected_digest}
    finally:
        server.terminate()
        try:
            server.wait(timeout=30)
        except subprocess.TimeoutExpired:
            # The process ignored SIGTERM; force it down so the test run cannot hang.
            server.kill()
            server.wait()
server.wait()
Loading

0 comments on commit 60e585a

Please sign in to comment.