Commit: rebase
andreea-popescu-reef committed Dec 17, 2024
1 parent f96fb0f commit 29ce2e6
Showing 6 changed files with 106 additions and 62 deletions.
5 changes: 3 additions & 2 deletions src/compute_horde_prompt_solver/config.py
@@ -51,11 +51,12 @@ def parse_arguments() -> Config:
"--top-p", type=float, default=0.1, help="Top-p sampling parameter"
)
parser.add_argument(
"--dtype", default="auto",
"--dtype",
default="auto",
choices=("auto", "half", "float16", "bfloat16", "float", "float32"),
help=(
"model dtype - setting `float32` helps with deterministic prompts in different batches"
-        )
+        ),
)

seed_or_server_group = parser.add_mutually_exclusive_group(required=True)
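For context, the --dtype option being rewrapped above is plain argparse; below is a minimal, self-contained sketch of the same flag (only the flag itself comes from the diff, the surrounding parser is reduced to the essentials and the parse_args call is illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--dtype",
    default="auto",
    choices=("auto", "half", "float16", "bfloat16", "float", "float32"),
    help=(
        "model dtype - setting `float32` helps with deterministic prompts"
        " in different batches"
    ),
)

args = parser.parse_args(["--dtype", "float32"])  # e.g. force full precision
print(args.dtype)  # float32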
44 changes: 25 additions & 19 deletions src/compute_horde_prompt_solver/prompt_solver.py
@@ -23,7 +23,9 @@

class BaseLLMProvider(abc.ABC):
@abc.abstractmethod
-    def generate_responses(self, prompts: List[str], sampling_params: SamplingParams) -> Dict[str, str]: ...
+    def generate_responses(
+        self, prompts: List[str], sampling_params: SamplingParams
+    ) -> Dict[str, str]: ...


class GPULLMProvider(BaseLLMProvider):
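The method split across lines in this hunk is the abstract interface both providers implement. A compact sketch of the pattern, with SamplingParams replaced by a plain seed so the example stands alone:

import abc
from typing import Dict, List

class Provider(abc.ABC):
    @abc.abstractmethod
    def generate_responses(self, prompts: List[str], seed: int) -> Dict[str, str]: ...

class EchoProvider(Provider):
    def generate_responses(self, prompts: List[str], seed: int) -> Dict[str, str]:
        # Trivial implementation: echo each prompt back as its own answer.
        return {prompt: prompt for prompt in prompts}

# Provider() would raise TypeError (abstract method not overridden);
# EchoProvider() is instantiable because it implements generate_responses.
print(EchoProvider().generate_responses(["hi"], seed=0))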
@@ -44,37 +46,44 @@ def setup_model(self) -> vllm.LLM:

def make_prompt(self, prompt: str) -> str:
system_msg = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{{{{ You are a helpful AI assistant }}}}<|eot_id|>"
-        user_msg = f"<|start_header_id|>user<|end_header_id|>\n{{{{ {prompt} }}}}<|eot_id|>"
+        user_msg = (
+            f"<|start_header_id|>user<|end_header_id|>\n{{{{ {prompt} }}}}<|eot_id|>"
+        )
assistant_start = "<|start_header_id|>assistant<|end_header_id|>"
return f"{system_msg}{user_msg}{assistant_start}"

def generate_responses(
-            self, prompts: List[str], sampling_params: SamplingParams
+        self, prompts: List[str], sampling_params: SamplingParams
) -> Dict[str, str]:
requests = [self.make_prompt(prompt) for prompt in prompts]
responses = self.model.generate(requests, sampling_params, use_tqdm=True)
return {
-            prompt: response.outputs[0].text for prompt, response in zip(prompts, responses)
+            prompt: response.outputs[0].text
+            for prompt, response in zip(prompts, responses)
}


class MockLLMProvider(BaseLLMProvider):
-    def generate_responses(self, prompts: List[str], sampling_params: SamplingParams) -> Dict[str, str]:
+    def generate_responses(
+        self, prompts: List[str], sampling_params: SamplingParams
+    ) -> Dict[str, str]:
result = {}
for prompt in prompts:
generator = random.Random(str(sampling_params.seed) + prompt)
answer_length = generator.randint(100, 200)
-            answer = ''.join(generator.choice(string.ascii_letters) for _ in range(answer_length))
+            answer = "".join(
+                generator.choice(string.ascii_letters) for _ in range(answer_length)
+            )
result[prompt] = answer
return result


def _run_server(
-        start_server_event: mp.Event,
-        seed_queue: mp.Queue,
-        result_queue: mp.Queue,
-        ready_to_terminate_event: mp.Event,
-        config: Config,
+    start_server_event: mp.Event,
+    seed_queue: mp.Queue,
+    result_queue: mp.Queue,
+    ready_to_terminate_event: mp.Event,
+    config: Config,
):
start_server_event.wait()

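Worth noting: MockLLMProvider seeds random.Random with the sampling seed concatenated with the prompt, so the fake answers are fully reproducible. A standalone sketch of that idea (the function name is ours, the logic mirrors the diff):

import random
import string

def mock_answer(seed: int, prompt: str) -> str:
    # The same (seed, prompt) pair always produces the same pseudo-answer.
    generator = random.Random(str(seed) + prompt)
    answer_length = generator.randint(100, 200)
    return "".join(
        generator.choice(string.ascii_letters) for _ in range(answer_length)
    )

assert mock_answer(42, "hello") == mock_answer(42, "hello")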
@@ -112,11 +121,7 @@ def terminate():


class BaseSolver(abc.ABC):
-    def __init__(
-        self,
-        provider: BaseLLMProvider,
-        config: Config
-    ):
+    def __init__(self, provider: BaseLLMProvider, config: Config):
self.provider = provider
self.config = config

@@ -148,7 +153,6 @@ def run(self): ...


class CLISolver(BaseSolver):
-
def run(self):
self.config.output_dir.mkdir(parents=True, exist_ok=True)

@@ -169,7 +173,9 @@ def __init__(self, *args, **kwargs):

def save_output_file(self, responses: Dict[str, str], output_file: pathlib.Path):
response_body = json.dumps(responses, indent=2, sort_keys=True).encode()
-        self.response_hashes[output_file.as_posix()] = hashlib.sha256(response_body).hexdigest()
+        self.response_hashes[output_file.as_posix()] = hashlib.sha256(
+            response_body
+        ).hexdigest()
pathlib.Path(output_file).write_bytes(response_body)

def run(self):
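The hashing rewrapped above pairs with json.dumps(..., sort_keys=True): a deterministic byte serialization gives a stable digest that a verifier can compare without re-reading the file. A minimal sketch of the same idiom (the sample data is ours):

import hashlib
import json
import pathlib
import tempfile

responses = {"what is 2+2?": "4"}
# sort_keys=True makes the serialization, and hence the digest, deterministic.
response_body = json.dumps(responses, indent=2, sort_keys=True).encode()
digest = hashlib.sha256(response_body).hexdigest()

output_file = pathlib.Path(tempfile.gettempdir()) / "out.json"
output_file.write_bytes(response_body)
assert hashlib.sha256(output_file.read_bytes()).hexdigest() == digest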
@@ -181,7 +187,7 @@ def run(self):
self.result_queue,
self.ready_to_terminate_event,
self.config,
-            )
+            ),
)
process.start()

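The args tuple above feeds _run_server, which runs in a child process coordinated through events and queues. A stripped-down sketch of that handshake (the real server logic is replaced by a single put):

import multiprocessing as mp

def _serve(start_event, result_queue):
    # Block until the parent signals that setup is complete.
    start_event.wait()
    result_queue.put("ready")

if __name__ == "__main__":
    start_event = mp.Event()
    result_queue = mp.Queue()
    process = mp.Process(target=_serve, args=(start_event, result_queue))
    process.start()
    start_event.set()          # let the child proceed
    print(result_queue.get())  # -> "ready"
    process.join()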
2 changes: 2 additions & 0 deletions src/compute_horde_prompt_solver/run.py
@@ -1,3 +1,5 @@
+import json
+import pathlib
from .prompt_solver import CLISolver, GPULLMProvider, MockLLMProvider, HttpSolver
from .config import parse_arguments

2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -15,6 +15,6 @@ def input_file() -> str:
if os.path.isfile(tmp_path):
os.remove(tmp_path)

-    output_file = tmp_path + '.json'
+    output_file = tmp_path + ".json"
if os.path.isfile(output_file):
os.remove(output_file)
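The fixture fragment above removes both the temp input file and its derived .json output after the test. A plausible reconstruction of the whole fixture, with the creation half (which the diff does not show) assumed:

import os
import tempfile

import pytest

@pytest.fixture
def input_file() -> str:
    fd, tmp_path = tempfile.mkstemp()
    os.close(fd)
    yield tmp_path
    # Teardown: remove the input file and the solver's output next to it.
    if os.path.isfile(tmp_path):
        os.remove(tmp_path)
    output_file = tmp_path + ".json"
    if os.path.isfile(output_file):
        os.remove(output_file)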
54 changes: 35 additions & 19 deletions tests/integration_mock/test_mock.py
@@ -24,18 +24,25 @@ def test_cli(input_file, seed, expected_output_file):
sys.executable,
"-m",
"src.compute_horde_prompt_solver",
"--temperature", "0.5",
"--top-p", "0.8",
"--max-tokens", "256",
"--seed", seed,
"--output-dir", tempfile.gettempdir(),
"--temperature",
"0.5",
"--top-p",
"0.8",
"--max-tokens",
"256",
"--seed",
seed,
"--output-dir",
tempfile.gettempdir(),
"--mock",
input_file,
],
timeout=TIMEOUT,
)
-    expected = (pathlib.Path(__file__).parent.parent / "payload" / expected_output_file).read_text()
-    actual = pathlib.Path(input_file + '.json').read_text()
+    expected = (
+        pathlib.Path(__file__).parent.parent / "payload" / expected_output_file
+    ).read_text()
+    actual = pathlib.Path(input_file + ".json").read_text()
assert expected == actual


@@ -50,7 +57,9 @@ def get_url_within_time(url, timeout=TIMEOUT):
except (requests.HTTPError, requests.ConnectionError):
pass

-        time.sleep(0.5)  # Wait a bit before trying again to not overload the server and your machine.
+        time.sleep(
+            0.5
+        )  # Wait a bit before trying again to not overload the server and your machine.

raise TimeoutError(f"Could not get data from {url} within {timeout} seconds")

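The loop being rewrapped here is a readiness probe: poll the URL until it answers or a deadline passes. A self-contained version, with the loop head assumed from the visible fragments:

import time

import requests

def get_url_within_time(url, timeout=180):
    start = time.time()
    while time.time() - start < timeout:
        try:
            response = requests.get(url)
            response.raise_for_status()
            return response
        except (requests.HTTPError, requests.ConnectionError):
            pass
        time.sleep(0.5)  # brief back-off so the server is not hammered
    raise TimeoutError(f"Could not get data from {url} within {timeout} seconds")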
@@ -68,28 +77,35 @@ def test_http(input_file, seed, expected_output_file):
sys.executable,
"-m",
"src.compute_horde_prompt_solver",
"--temperature", "0.5",
"--top-p", "0.8",
"--max-tokens", "256",
"--output-dir", tempfile.gettempdir(),
"--temperature",
"0.5",
"--top-p",
"0.8",
"--max-tokens",
"256",
"--output-dir",
tempfile.gettempdir(),
"--mock",
"--server",
input_file,
],
)
try:
-        base_url = 'http://localhost:8000/'
-        get_url_within_time(base_url + 'health')
+        base_url = "http://localhost:8000/"
+        get_url_within_time(base_url + "health")

import time
-        with requests.post(base_url + 'execute-job', json={"seed": seed}) as resp:
+        with requests.post(base_url + "execute-job", json={"seed": seed}) as resp:
resp.raise_for_status()
hashes = resp.json()
try:
-            requests.get(base_url + 'terminate')
-        except:
+            requests.get(base_url + "terminate")
+        except Exception:
pass
-        assert hashes == {input_file + '.json': hashlib.sha256(pathlib.Path(input_file + '.json').read_bytes()).hexdigest()}
+        assert hashes == {
+            input_file + ".json": hashlib.sha256(
+                pathlib.Path(input_file + ".json").read_bytes()
+            ).hexdigest()
+        }
finally:
server.terminate()
server.wait()
61 changes: 40 additions & 21 deletions tests/integration_real_llm/test_real_llm.py
@@ -12,11 +12,14 @@
TIMEOUT = 180


-@pytest.fixture(scope='session', autouse=True)
+@pytest.fixture(scope="session", autouse=True)
def download_model():
snapshot_download(
repo_id="microsoft/Phi-3.5-mini-instruct",
-        local_dir=pathlib.Path(__file__).parent.parent.parent / "src" / "compute_horde_prompt_solver" / "saved_model/",
+        local_dir=pathlib.Path(__file__).parent.parent.parent
+        / "src"
+        / "compute_horde_prompt_solver"
+        / "saved_model/",
revision="cd6881a82d62252f5a84593c61acf290f15d89e3",
)

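Pinning revision in snapshot_download keeps the downloaded model bit-for-bit stable across runs, which the deterministic-output tests rely on. The call as used above, reduced to its essentials (local_dir shortened):

import pathlib
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="microsoft/Phi-3.5-mini-instruct",
    local_dir=pathlib.Path("src/compute_horde_prompt_solver/saved_model"),
    revision="cd6881a82d62252f5a84593c61acf290f15d89e3",  # exact commit pin
)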
@@ -34,17 +37,24 @@ def test_cli(input_file, seed, expected_output_file):
sys.executable,
"-m",
"src.compute_horde_prompt_solver",
"--temperature", "0.5",
"--top-p", "0.8",
"--max-tokens", "256",
"--seed", seed,
"--output-dir", tempfile.gettempdir(),
"--temperature",
"0.5",
"--top-p",
"0.8",
"--max-tokens",
"256",
"--seed",
seed,
"--output-dir",
tempfile.gettempdir(),
input_file,
],
timeout=TIMEOUT,
)
-    expected = (pathlib.Path(__file__).parent.parent / "payload" / expected_output_file).read_text()
-    actual = pathlib.Path(input_file + '.json').read_text()
+    expected = (
+        pathlib.Path(__file__).parent.parent / "payload" / expected_output_file
+    ).read_text()
+    actual = pathlib.Path(input_file + ".json").read_text()
assert expected == actual


@@ -59,7 +69,9 @@ def get_url_within_time(url, timeout=TIMEOUT):
except (requests.HTTPError, requests.ConnectionError):
pass

-        time.sleep(0.5)  # Wait a bit before trying again to not overload the server and your machine.
+        time.sleep(
+            0.5
+        )  # Wait a bit before trying again to not overload the server and your machine.

raise TimeoutError(f"Could not get data from {url} within {timeout} seconds")

@@ -77,27 +89,34 @@ def test_http(input_file, seed, expected_output_file):
sys.executable,
"-m",
"src.compute_horde_prompt_solver",
"--temperature", "0.5",
"--top-p", "0.8",
"--max-tokens", "256",
"--output-dir", tempfile.gettempdir(),
"--temperature",
"0.5",
"--top-p",
"0.8",
"--max-tokens",
"256",
"--output-dir",
tempfile.gettempdir(),
"--server",
input_file,
],
)
try:
-        base_url = 'http://localhost:8000/'
-        get_url_within_time(base_url + 'health')
+        base_url = "http://localhost:8000/"
+        get_url_within_time(base_url + "health")

import time
-        with requests.post(base_url + 'execute-job', json={"seed": seed}) as resp:
+        with requests.post(base_url + "execute-job", json={"seed": seed}) as resp:
resp.raise_for_status()
hashes = resp.json()
try:
-            requests.get(base_url + 'terminate')
-        except:
+            requests.get(base_url + "terminate")
+        except Exception:
pass
-        assert hashes == {input_file + '.json': hashlib.sha256(pathlib.Path(input_file + '.json').read_bytes()).hexdigest()}
+        assert hashes == {
+            input_file + ".json": hashlib.sha256(
+                pathlib.Path(input_file + ".json").read_bytes()
+            ).hexdigest()
+        }
finally:
server.terminate()
server.wait()
