Commit
return hashes of output files from http prompt answering
mpnowacki-reef committed Dec 9, 2024
1 parent b591ff7 commit 0c702bb
Showing 12 changed files with 2,023 additions and 933 deletions.
33 changes: 33 additions & 0 deletions .github/workflows/ci.yml
@@ -0,0 +1,33 @@
name: Run QA

on:
push:
branches: [master, main]
pull_request:
branches: [master, main]

env:
PYTHON_DEFAULT_VERSION: "3.11"

jobs:
test:
timeout-minutes: 10
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Set up Python ${{ env.PYTHON_DEFAULT_VERSION }}
uses: actions/setup-python@v4
with:
python-version: ${{ env.PYTHON_DEFAULT_VERSION }}
cache: "pip"
- name: Install dependencies
run: python -m pip install --upgrade nox 'pdm==2.19.3'
      - name: Install project dependencies
        run: pdm install -G test
      - name: Run tests
        run: pdm run pytest tests/integration_mock/
1,562 changes: 818 additions & 744 deletions pdm.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions pyproject.toml
@@ -13,6 +13,10 @@ dependencies = [
"setuptools>=74.1.2",
"flask>=3.0.3",
]
[dependency-groups]
test = [
"pytest",
]
requires-python = "==3.11.*"
readme = "README.md"
license = {text = "MIT"}
7 changes: 7 additions & 0 deletions pytest.ini
@@ -0,0 +1,7 @@
[pytest]
python_files = tests.py test_*.py *_tests.py
filterwarnings =
error
default::DeprecationWarning
default:Error when trying to teardown test databases
addopts = -s
96 changes: 96 additions & 0 deletions src/compute_horde_prompt_solver/config.py
@@ -0,0 +1,96 @@
import argparse
import dataclasses
import pathlib
from typing import Optional, List


@dataclasses.dataclass
class Config:
input_files: List[pathlib.Path]
output_dir: pathlib.Path
model: str
max_tokens: int
temperature: float
top_p: float
dtype: str
seed: Optional[int]
server: Optional[bool]
server_port: int
mock: bool


def parse_arguments() -> Config:
parser = argparse.ArgumentParser(
description="Generate responses for prompts using vLLM."
)
parser.add_argument(
"input_files",
nargs="+",
type=pathlib.Path,
help="Input files containing prompts",
)
parser.add_argument(
"--output-dir",
default="./output",
type=pathlib.Path,
help="Directory to save output files",
)
parser.add_argument(
"--model", default="microsoft/Phi-3.5-mini-instruct", help="Model name or path"
)
parser.add_argument(
"--max-tokens",
type=int,
default=256,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--temperature", type=float, default=0, help="Sampling temperature"
)
parser.add_argument(
"--top-p", type=float, default=0.1, help="Top-p sampling parameter"
)
parser.add_argument(
"--dtype", default="auto",
choices=("auto", "half", "float16", "bfloat16", "float", "float32"),
help=(
"model dtype - setting `float32` helps with deterministic prompts in different batches"
)
)

seed_or_server_group = parser.add_mutually_exclusive_group(required=True)
seed_or_server_group.add_argument(
"--seed", type=int, help="Random seed for reproducibility"
)
seed_or_server_group.add_argument(
"--server",
action="store_true",
help="Spin up a temporary HTTP server to receive the seed",
)

parser.add_argument(
"--server-port",
type=int,
default=8000,
help="Port for temporary HTTP server",
)
parser.add_argument(
"--mock",
action="store_true",
help="Don't use an actual model, generate random gibberish based on the input and the seed",
)
args = parser.parse_args()

return Config(
input_files=args.input_files,
output_dir=args.output_dir,
model=args.model,
max_tokens=args.max_tokens,
temperature=args.temperature,
top_p=args.top_p,
dtype=args.dtype,
seed=args.seed,
server=args.server,
server_port=args.server_port,
mock=args.mock,
)
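
For reference, a minimal sketch of how this parser could be exercised; the module path follows the file above, while the program name and the "prompts.txt" input file are hypothetical:

import sys
from compute_horde_prompt_solver.config import parse_arguments

# Simulate a CLI invocation; --server and --seed are mutually exclusive and exactly one is required.
sys.argv = [
    "compute-horde-prompt-solver",
    "prompts.txt",
    "--output-dir", "./output",
    "--mock",
    "--server",
    "--server-port", "8000",
]
config = parse_arguments()
print(config.server, config.server_port, config.mock)  # True 8000 True
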
215 changes: 215 additions & 0 deletions src/compute_horde_prompt_solver/prompt_solver.py
@@ -0,0 +1,215 @@
import abc
import hashlib
import json
import multiprocessing as mp
import pathlib
import queue
import random
import string
from typing import List, Dict

import torch
import vllm
from flask import Flask, Blueprint, jsonify
from vllm import SamplingParams

# Import the set_deterministic function
from deterministic_ml.v1 import set_deterministic

from .config import Config

TIMEOUT = 5 * 60


class BaseLLMProvider(abc.ABC):
@abc.abstractmethod
def generate_responses(self, prompts: List[str], sampling_params: SamplingParams) -> Dict[str, str]: ...


class GPULLMProvider(BaseLLMProvider):
def __init__(self, model_name: str, dtype: str = "auto"):
self.model_name = model_name
self.dtype = dtype
self._model = None

    @property
    def model(self):
        # Lazily create the vLLM engine on first access.
        if self._model is None:
            self._model = self.setup_model()
        return self._model

def setup_model(self) -> vllm.LLM:
gpu_count = torch.cuda.device_count()
return vllm.LLM(
model=self.model_name,
tensor_parallel_size=gpu_count,
max_model_len=6144,
enforce_eager=True,
dtype=self.dtype,
)

def make_prompt(self, prompt: str) -> str:
system_msg = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{{{{ You are a helpful AI assistant }}}}<|eot_id|>"
user_msg = f"<|start_header_id|>user<|end_header_id|>\n{{{{ {prompt} }}}}<|eot_id|>"
assistant_start = "<|start_header_id|>assistant<|end_header_id|>"
return f"{system_msg}{user_msg}{assistant_start}"

def generate_responses(
self, prompts: List[str], sampling_params: SamplingParams
) -> Dict[str, str]:
requests = [self.make_prompt(prompt) for prompt in prompts]
responses = self.model.generate(requests, sampling_params, use_tqdm=True)
return {
prompt: response.outputs[0].text for prompt, response in zip(prompts, responses)
}


class MockLLMProvider(BaseLLMProvider):
def generate_responses(self, prompts: List[str], sampling_params: SamplingParams) -> Dict[str, str]:
result = {}
for prompt in prompts:
generator = random.Random(str(sampling_params.seed) + prompt)
answer_length = generator.randint(100, 200)
answer = ''.join(generator.choice(string.ascii_letters) for _ in range(answer_length))
result[prompt] = answer
return result


def _run_server(
start_server_event: mp.Event,
seed_queue: mp.Queue,
result_queue: mp.Queue,
ready_to_terminate_event: mp.Event,
config: Config,
):
start_server_event.wait()

app = Flask("compute_horde_prompt_solver")

@app.route("/health")
def server_healthcheck():
return {"status": "ok"}

@app.route("/execute-job", methods=["POST"])
def execute_job():
try:
from flask import request

seed_raw = request.json.get("seed")
seed = int(seed_raw)
seed_queue.put(seed)
result = result_queue.get(timeout=TIMEOUT)
return jsonify(result)
finally:
            # seed_queue.put(seed) may never run (e.g. when the request does not contain a valid int seed),
            # so we always put a None as well, to make sure the main process gets unblocked
            # and terminates when the view returns.
            seed_queue.put(None)

@app.route("/terminate")
def terminate():
ready_to_terminate_event.set()
return {"status": "ok"}

app.run(
host="0.0.0.0",
port=config.server_port,
debug=False,
)


class BaseSolver(abc.ABC):
def __init__(
self,
provider: BaseLLMProvider,
config: Config
):
self.provider = provider
self.config = config

def process_file(self, input_file, sampling_params):
with open(input_file, "r") as f:
prompts = [line.strip() for line in f if line.strip()]

responses = self.provider.generate_responses(prompts, sampling_params)

output_file = self.config.output_dir / f"{input_file.stem}.json"
self.save_output_file(responses, output_file)

def save_output_file(self, responses: Dict[str, str], output_file: pathlib.Path):
with open(output_file, "w") as f:
json.dump(responses, f, indent=2)

def get_sampling_params(self, seed):
set_deterministic(seed)

return SamplingParams(
max_tokens=self.config.max_tokens,
temperature=self.config.temperature,
top_p=self.config.top_p,
seed=seed,
)

@abc.abstractmethod
def run(self): ...


class CLISolver(BaseSolver):

def run(self):
self.config.output_dir.mkdir(parents=True, exist_ok=True)

sampling_params = self.get_sampling_params(self.config.seed)

for input_file in self.config.input_files:
self.process_file(input_file, sampling_params)


class HttpSolver(BaseSolver):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.start_server_event = mp.Event()
self.seed_queue = mp.Queue()
self.result_queue = mp.Queue()
self.ready_to_terminate_event = mp.Event()
self.response_hashes: Dict[str, str] = {}

def save_output_file(self, responses: Dict[str, str], output_file: pathlib.Path):
response_body = json.dumps(responses, indent=2).encode()
self.response_hashes[output_file.as_posix()] = hashlib.sha256(response_body).hexdigest()
pathlib.Path(output_file).write_bytes(response_body)

def run(self):
process = mp.Process(
target=_run_server,
args=(
self.start_server_event,
self.seed_queue,
self.result_queue,
self.ready_to_terminate_event,
self.config,
)
)
process.start()

self.config.output_dir.mkdir(parents=True, exist_ok=True)

self.start_server_event.set()

try:
seed = self.seed_queue.get(block=True, timeout=TIMEOUT)
except queue.Empty:
seed = None

        if seed is None:
            raise SystemExit("ERROR: no valid seed was received (malformed request or timeout)")

sampling_params = self.get_sampling_params(seed)

try:
for input_file in self.config.input_files:
self.process_file(input_file, sampling_params)
self.result_queue.put(self.response_hashes)
self.ready_to_terminate_event.wait(timeout=TIMEOUT)
finally:
process.terminate()
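
To illustrate the behavior this commit adds (the /execute-job response now carries the SHA-256 hashes of the output files), here is a minimal client-side sketch using only the standard library. It assumes the solver was started with --server --server-port 8000 and that the paths returned by the server are reachable from the client's working directory; this client code is not part of the commit.

import hashlib
import json
import pathlib
import urllib.request

# Send the seed; the response maps each output file path to the sha256 of its contents,
# e.g. {"output/prompts.json": "<sha256 hex digest>"}.
request = urllib.request.Request(
    "http://127.0.0.1:8000/execute-job",
    data=json.dumps({"seed": 42}).encode(),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(request, timeout=5 * 60) as response:
    hashes = json.load(response)

# Verify each reported hash against the file the solver wrote.
for path, expected_digest in hashes.items():
    actual_digest = hashlib.sha256(pathlib.Path(path).read_bytes()).hexdigest()
    assert actual_digest == expected_digest, f"hash mismatch for {path}"

# Tell the solver it can shut down.
urllib.request.urlopen("http://127.0.0.1:8000/terminate", timeout=10)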