From 29ce2e6c2d019c42585961f1793987370ba64519 Mon Sep 17 00:00:00 2001
From: Andreea Popescu
Date: Tue, 17 Dec 2024 13:29:56 +0000
Subject: [PATCH] rebase

---
 src/compute_horde_prompt_solver/config.py    |  5 +-
 .../prompt_solver.py                         | 44 +++++++------
 src/compute_horde_prompt_solver/run.py       |  2 +
 tests/conftest.py                            |  2 +-
 tests/integration_mock/test_mock.py          | 54 ++++++++++------
 tests/integration_real_llm/test_real_llm.py  | 61 ++++++++++++-------
 6 files changed, 106 insertions(+), 62 deletions(-)

diff --git a/src/compute_horde_prompt_solver/config.py b/src/compute_horde_prompt_solver/config.py
index f20be20..29640de 100644
--- a/src/compute_horde_prompt_solver/config.py
+++ b/src/compute_horde_prompt_solver/config.py
@@ -51,11 +51,12 @@ def parse_arguments() -> Config:
         "--top-p", type=float, default=0.1, help="Top-p sampling parameter"
     )
     parser.add_argument(
-        "--dtype", default="auto",
+        "--dtype",
+        default="auto",
         choices=("auto", "half", "float16", "bfloat16", "float", "float32"),
         help=(
             "model dtype - setting `float32` helps with deterministic prompts in different batches"
-        )
+        ),
     )
 
     seed_or_server_group = parser.add_mutually_exclusive_group(required=True)
diff --git a/src/compute_horde_prompt_solver/prompt_solver.py b/src/compute_horde_prompt_solver/prompt_solver.py
index dc37635..8d895fc 100644
--- a/src/compute_horde_prompt_solver/prompt_solver.py
+++ b/src/compute_horde_prompt_solver/prompt_solver.py
@@ -23,7 +23,9 @@ class BaseLLMProvider(abc.ABC):
 
     @abc.abstractmethod
-    def generate_responses(self, prompts: List[str], sampling_params: SamplingParams) -> Dict[str, str]: ...
+    def generate_responses(
+        self, prompts: List[str], sampling_params: SamplingParams
+    ) -> Dict[str, str]: ...
 
 
 class GPULLMProvider(BaseLLMProvider):
@@ -44,37 +46,44 @@ def setup_model(self) -> vllm.LLM:
 
     def make_prompt(self, prompt: str) -> str:
         system_msg = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{{{{ You are a helpful AI assistant }}}}<|eot_id|>"
-        user_msg = f"<|start_header_id|>user<|end_header_id|>\n{{{{ {prompt} }}}}<|eot_id|>"
+        user_msg = (
+            f"<|start_header_id|>user<|end_header_id|>\n{{{{ {prompt} }}}}<|eot_id|>"
+        )
         assistant_start = "<|start_header_id|>assistant<|end_header_id|>"
         return f"{system_msg}{user_msg}{assistant_start}"
 
     def generate_responses(
-            self, prompts: List[str], sampling_params: SamplingParams
+        self, prompts: List[str], sampling_params: SamplingParams
     ) -> Dict[str, str]:
         requests = [self.make_prompt(prompt) for prompt in prompts]
         responses = self.model.generate(requests, sampling_params, use_tqdm=True)
         return {
-            prompt: response.outputs[0].text for prompt, response in zip(prompts, responses)
+            prompt: response.outputs[0].text
+            for prompt, response in zip(prompts, responses)
         }
 
 
 class MockLLMProvider(BaseLLMProvider):
-    def generate_responses(self, prompts: List[str], sampling_params: SamplingParams) -> Dict[str, str]:
+    def generate_responses(
+        self, prompts: List[str], sampling_params: SamplingParams
+    ) -> Dict[str, str]:
         result = {}
         for prompt in prompts:
             generator = random.Random(str(sampling_params.seed) + prompt)
             answer_length = generator.randint(100, 200)
-            answer = ''.join(generator.choice(string.ascii_letters) for _ in range(answer_length))
+            answer = "".join(
+                generator.choice(string.ascii_letters) for _ in range(answer_length)
+            )
             result[prompt] = answer
         return result
 
 
 def _run_server(
-        start_server_event: mp.Event,
-        seed_queue: mp.Queue,
-        result_queue: mp.Queue,
-        ready_to_terminate_event: mp.Event,
-        config: Config,
+    start_server_event: mp.Event,
+    seed_queue: mp.Queue,
+    result_queue: mp.Queue,
+    ready_to_terminate_event: mp.Event,
+    config: Config,
 ):
     start_server_event.wait()
 
@@ -112,11 +121,7 @@ def terminate():
 
 
 class BaseSolver(abc.ABC):
-    def __init__(
-        self,
-        provider: BaseLLMProvider,
-        config: Config
-    ):
+    def __init__(self, provider: BaseLLMProvider, config: Config):
         self.provider = provider
         self.config = config
 
@@ -148,7 +153,6 @@ def run(self): ...
 
 
 class CLISolver(BaseSolver):
-
     def run(self):
         self.config.output_dir.mkdir(parents=True, exist_ok=True)
 
@@ -169,7 +173,9 @@ def __init__(self, *args, **kwargs):
 
     def save_output_file(self, responses: Dict[str, str], output_file: pathlib.Path):
         response_body = json.dumps(responses, indent=2, sort_keys=True).encode()
-        self.response_hashes[output_file.as_posix()] = hashlib.sha256(response_body).hexdigest()
+        self.response_hashes[output_file.as_posix()] = hashlib.sha256(
+            response_body
+        ).hexdigest()
         pathlib.Path(output_file).write_bytes(response_body)
 
     def run(self):
@@ -181,7 +187,7 @@ def run(self):
                 self.result_queue,
                 self.ready_to_terminate_event,
                 self.config,
-            )
+            ),
         )
         process.start()
 
diff --git a/src/compute_horde_prompt_solver/run.py b/src/compute_horde_prompt_solver/run.py
index c42959f..c5d1405 100644
--- a/src/compute_horde_prompt_solver/run.py
+++ b/src/compute_horde_prompt_solver/run.py
@@ -1,3 +1,5 @@
+import json
+import pathlib
 from .prompt_solver import CLISolver, GPULLMProvider, MockLLMProvider, HttpSolver
 from .config import parse_arguments
 
diff --git a/tests/conftest.py b/tests/conftest.py
index c528c13..6fd3a40 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -15,6 +15,6 @@ def input_file() -> str:
     if os.path.isfile(tmp_path):
         os.remove(tmp_path)
 
-    output_file = tmp_path + '.json'
+    output_file = tmp_path + ".json"
     if os.path.isfile(output_file):
         os.remove(output_file)
diff --git a/tests/integration_mock/test_mock.py b/tests/integration_mock/test_mock.py
index 662ed8d..ab8fc4d 100644
--- a/tests/integration_mock/test_mock.py
+++ b/tests/integration_mock/test_mock.py
@@ -24,18 +24,25 @@ def test_cli(input_file, seed, expected_output_file):
             sys.executable,
             "-m",
             "src.compute_horde_prompt_solver",
-            "--temperature", "0.5",
-            "--top-p", "0.8",
-            "--max-tokens", "256",
-            "--seed", seed,
-            "--output-dir", tempfile.gettempdir(),
+            "--temperature",
+            "0.5",
+            "--top-p",
+            "0.8",
+            "--max-tokens",
+            "256",
+            "--seed",
+            seed,
+            "--output-dir",
+            tempfile.gettempdir(),
             "--mock",
             input_file,
         ],
         timeout=TIMEOUT,
     )
-    expected = (pathlib.Path(__file__).parent.parent / "payload" / expected_output_file).read_text()
-    actual = pathlib.Path(input_file + '.json').read_text()
+    expected = (
+        pathlib.Path(__file__).parent.parent / "payload" / expected_output_file
+    ).read_text()
+    actual = pathlib.Path(input_file + ".json").read_text()
     assert expected == actual
 
 
@@ -50,7 +57,9 @@ def get_url_within_time(url, timeout=TIMEOUT):
         except (requests.HTTPError, requests.ConnectionError):
            pass
 
-        time.sleep(0.5)  # Wait a bit before trying again to not overload the server and your machine.
+        time.sleep(
+            0.5
+        )  # Wait a bit before trying again to not overload the server and your machine.
raise TimeoutError(f"Could not get data from {url} within {timeout} seconds") @@ -68,28 +77,35 @@ def test_http(input_file, seed, expected_output_file): sys.executable, "-m", "src.compute_horde_prompt_solver", - "--temperature", "0.5", - "--top-p", "0.8", - "--max-tokens", "256", - "--output-dir", tempfile.gettempdir(), + "--temperature", + "0.5", + "--top-p", + "0.8", + "--max-tokens", + "256", + "--output-dir", + tempfile.gettempdir(), "--mock", "--server", input_file, ], ) try: - base_url = 'http://localhost:8000/' - get_url_within_time(base_url + 'health') + base_url = "http://localhost:8000/" + get_url_within_time(base_url + "health") - import time - with requests.post(base_url + 'execute-job', json={"seed": seed}) as resp: + with requests.post(base_url + "execute-job", json={"seed": seed}) as resp: resp.raise_for_status() hashes = resp.json() try: - requests.get(base_url + 'terminate') - except: + requests.get(base_url + "terminate") + except Exception: pass - assert hashes == {input_file + '.json': hashlib.sha256(pathlib.Path(input_file + '.json').read_bytes()).hexdigest()} + assert hashes == { + input_file + ".json": hashlib.sha256( + pathlib.Path(input_file + ".json").read_bytes() + ).hexdigest() + } finally: server.terminate() server.wait() diff --git a/tests/integration_real_llm/test_real_llm.py b/tests/integration_real_llm/test_real_llm.py index 8d3877e..6b3e7e5 100644 --- a/tests/integration_real_llm/test_real_llm.py +++ b/tests/integration_real_llm/test_real_llm.py @@ -12,11 +12,14 @@ TIMEOUT = 180 -@pytest.fixture(scope='session', autouse=True) +@pytest.fixture(scope="session", autouse=True) def download_model(): snapshot_download( repo_id="microsoft/Phi-3.5-mini-instruct", - local_dir=pathlib.Path(__file__).parent.parent.parent / "src" / "compute_horde_prompt_solver" / "saved_model/", + local_dir=pathlib.Path(__file__).parent.parent.parent + / "src" + / "compute_horde_prompt_solver" + / "saved_model/", revision="cd6881a82d62252f5a84593c61acf290f15d89e3", ) @@ -34,17 +37,24 @@ def test_cli(input_file, seed, expected_output_file): sys.executable, "-m", "src.compute_horde_prompt_solver", - "--temperature", "0.5", - "--top-p", "0.8", - "--max-tokens", "256", - "--seed", seed, - "--output-dir", tempfile.gettempdir(), + "--temperature", + "0.5", + "--top-p", + "0.8", + "--max-tokens", + "256", + "--seed", + seed, + "--output-dir", + tempfile.gettempdir(), input_file, ], timeout=TIMEOUT, ) - expected = (pathlib.Path(__file__).parent.parent / "payload" / expected_output_file).read_text() - actual = pathlib.Path(input_file + '.json').read_text() + expected = ( + pathlib.Path(__file__).parent.parent / "payload" / expected_output_file + ).read_text() + actual = pathlib.Path(input_file + ".json").read_text() assert expected == actual @@ -59,7 +69,9 @@ def get_url_within_time(url, timeout=TIMEOUT): except (requests.HTTPError, requests.ConnectionError): pass - time.sleep(0.5) # Wait a bit before trying again to not overload the server and your machine. + time.sleep( + 0.5 + ) # Wait a bit before trying again to not overload the server and your machine. 
raise TimeoutError(f"Could not get data from {url} within {timeout} seconds") @@ -77,27 +89,34 @@ def test_http(input_file, seed, expected_output_file): sys.executable, "-m", "src.compute_horde_prompt_solver", - "--temperature", "0.5", - "--top-p", "0.8", - "--max-tokens", "256", - "--output-dir", tempfile.gettempdir(), + "--temperature", + "0.5", + "--top-p", + "0.8", + "--max-tokens", + "256", + "--output-dir", + tempfile.gettempdir(), "--server", input_file, ], ) try: - base_url = 'http://localhost:8000/' - get_url_within_time(base_url + 'health') + base_url = "http://localhost:8000/" + get_url_within_time(base_url + "health") - import time - with requests.post(base_url + 'execute-job', json={"seed": seed}) as resp: + with requests.post(base_url + "execute-job", json={"seed": seed}) as resp: resp.raise_for_status() hashes = resp.json() try: - requests.get(base_url + 'terminate') - except: + requests.get(base_url + "terminate") + except Exception: pass - assert hashes == {input_file + '.json': hashlib.sha256(pathlib.Path(input_file + '.json').read_bytes()).hexdigest()} + assert hashes == { + input_file + ".json": hashlib.sha256( + pathlib.Path(input_file + ".json").read_bytes() + ).hexdigest() + } finally: server.terminate() server.wait()