-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: added quickstart tutorial for self-hosted models (#191)
* feat: added docker-compose and test for llama3:8b via ollama
* feat: added quick-start tutorial for self-hosted models
* chore: formatting fix
* fix: fixed ci test for ollama model
* chore: bumped the default Ollama model to Llama 3.1
* chore: fixed doc
* feat: added instructions for self-hosted vision models
* fix: updated env for ollama ci test
* chore: divided a single quickstart ci job into separate jobs
* chore: simplified names of the ci jobs
* fix: migrated setup.py script for ollama to httpx
* feat: added ci test for self-hosted embedding model
* feat: used .env file instead of env vars in self-hosted model tutorial
* fix: increased timeout in the ollama setup script
* review
* feat: added progress bar for model downloading

---------

Co-authored-by: sr-remsha <sr.remsha@gmail.com>
- Loading branch information
Showing
25 changed files
with
561 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
DIAL_DIR="./ollama" | ||
OLLAMA_CHAT_MODEL=llama3.1:8b-instruct-q4_0 | ||
OLLAMA_VISION_MODEL=llava-phi3:3.8b-mini-q4_0 | ||
OLLAMA_EMBEDDING_MODEL=nomic-embed-text:137m-v1.5-fp16 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Reuse the Ollama compose stack and feed it the variables from the local .env file.
# NOTE(review): nesting was reconstructed from a flattened listing — confirm
# that env_file belongs to the include entry.
include:
  - path: ../../ollama/docker-compose.yml
    env_file: ./.env

services:
  # Test container, built from the ./test directory.
  test:
    build: test
    # Connection settings passed to the test container as environment variables.
    environment:
      DIAL_URL: "http://core:8080"
      DIAL_API_KEY: "dial_api_key"
      DIAL_API_VERSION: "2024-02-01"
    # Wait until both services report healthy before starting the test.
    # Presumably ollama-setup and core are defined in the included file — verify.
    depends_on:
      ollama-setup:
        condition: service_healthy
      core:
        condition: service_healthy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Dockerfile |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
FROM python:3.11-alpine

WORKDIR /app

# Install dependencies first so this layer stays cached until
# requirements.txt itself changes.
COPY requirements.txt /app/
RUN pip install --no-cache-dir -r requirements.txt

# A wildcard can expand to several sources, in which case the destination
# must be a directory path ending with a slash.
COPY * /app/

CMD ["python", "app.py"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
import base64 | ||
import os | ||
from pathlib import Path | ||
from typing import Any | ||
import aiohttp | ||
import asyncio | ||
import backoff | ||
|
||
import logging | ||
import time | ||
from contextlib import asynccontextmanager | ||
|
||
|
||
def get_env(name: str) -> str:
    """Return the value of the environment variable *name*.

    Raises:
        ValueError: If the variable is not set.
    """
    if name not in os.environ:
        raise ValueError(f"'{name}' environment variable must be defined")
    return os.environ[name]
|
||
|
||
# Connection settings for the DIAL endpoint under test. They are read from the
# environment at import time; get_env raises ValueError if one is missing.
DIAL_URL = get_env("DIAL_URL")
DIAL_API_KEY = get_env("DIAL_API_KEY")
DIAL_API_VERSION = get_env("DIAL_API_VERSION")

# Configure root logging at DEBUG so the log.debug() records emitted by this
# script (responses, timings) are shown.
logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger(__name__)
|
||
|
||
@asynccontextmanager
async def timer(name: str):
    """Log how long the enclosed ``async with`` body takes.

    A "Starting..." record is logged on entry and the elapsed time (measured
    with ``time.perf_counter``) on exit, both at DEBUG level and prefixed
    with ``[name]``.

    Args:
        name: Label placed in square brackets at the start of each record.
    """
    log.debug(f"[{name}] Starting...")
    start = time.perf_counter()
    try:
        yield
    finally:
        # Report the duration even when the body raises, so a failing step
        # still shows how long it ran; the exception keeps propagating.
        elapsed = time.perf_counter() - start
        log.debug(f"[{name}] Executed in {elapsed:.2f} seconds")
|
||
|
||
@backoff.on_exception(
    backoff.expo,
    (aiohttp.ClientError, aiohttp.ServerTimeoutError),
    max_time=60,
)
async def post_with_retry(url: str, payload: dict, headers: dict, params: dict):
    """POST *payload* as JSON to *url* and return the decoded JSON response.

    The call is retried with exponential backoff, for up to 60 seconds in
    total, whenever one of the aiohttp exceptions listed in the decorator is
    raised. ``raise_for_status()`` is called on every response before its
    body is decoded.

    Args:
        url: Target URL.
        payload: Object sent as the JSON request body.
        headers: HTTP headers for the request.
        params: Query-string parameters for the request.
    """
    request_kwargs = {"json": payload, "headers": headers, "params": params}
    # A new client session is opened for every attempt.
    async with aiohttp.ClientSession() as http:
        async with http.post(url, **request_kwargs) as reply:
            reply.raise_for_status()
            decoded = await reply.json()
    return decoded
|
||
|
||
def read_image_base64(png_file: Path) -> str:
    """Return the raw bytes of *png_file* as base64-encoded text."""
    raw = png_file.read_bytes()
    encoded = base64.b64encode(raw)
    return encoded.decode("utf-8")
|
||
async def dial_chat_completion(deployment_id: str, messages: list) -> str:
    """Send a non-streaming chat completion request to a DIAL deployment.

    Args:
        deployment_id: Deployment name; used in the URL path and as the
            ``model`` field of the request payload.
        messages: Chat messages, sent unchanged as the ``messages`` field.

    Returns:
        The ``content`` of the first choice's message, or an empty string
        when that message carries no content.

    Raises:
        ValueError: If the response contains no choices.
    """
    api_url = f"{DIAL_URL}/openai/deployments/{deployment_id}/chat/completions"

    payload = {
        "model": deployment_id,
        "messages": messages,
        "stream": False,
    }
    headers = {"api-key": DIAL_API_KEY}
    params = {"api-version": DIAL_API_VERSION}

    body = await post_with_retry(api_url, payload, headers, params)
    log.debug(f"Response: {body}")

    # Fail with a descriptive error instead of a bare IndexError when the
    # "choices" list is missing, null or empty.
    choices = body.get("choices") or []
    if not choices:
        raise ValueError(f"No choices in the response from {deployment_id!r}")

    # "message" or "content" may be present but null; normalise to "" so the
    # declared return type holds and callers can safely run string checks.
    message = choices[0].get("message") or {}
    content = message.get("content") or ""

    log.debug(f"Content: {content}")

    return content
|
||
async def dial_embeddings(deployment_id: str, input: Any) -> list[float]:
    """Request an embedding from a DIAL deployment.

    Args:
        deployment_id: Deployment name; used in the URL path and as the
            ``model`` field of the request payload.
        input: Value sent unchanged as the ``input`` field of the payload.

    Returns:
        The ``embedding`` of the first item in the response's ``data`` list,
        or an empty list when that item carries no embedding.

    Raises:
        ValueError: If the response contains no data items.
    """
    api_url = f"{DIAL_URL}/openai/deployments/{deployment_id}/embeddings"

    payload = {
        "model": deployment_id,
        "input": input,
    }
    headers = {"api-key": DIAL_API_KEY}
    params = {"api-version": DIAL_API_VERSION}

    body = await post_with_retry(api_url, payload, headers, params)
    log.debug(f"Response: {body}")

    # Fail with a descriptive error instead of a bare IndexError when the
    # "data" list is missing, null or empty.
    data = body.get("data") or []
    if not data:
        raise ValueError(f"No embeddings in the response from {deployment_id!r}")

    # A null "embedding" is normalised to an empty list so len() below and in
    # callers is always valid.
    embedding = data[0].get("embedding") or []

    log.debug(f"Len embedding vector: {len(embedding)}")

    return embedding
|
||
async def test_chat_model(deployment_id: str):
    """Ask the chat deployment a simple sum and check the reply contains "5".

    Raises:
        ValueError: If "5" does not appear in the returned content.
    """
    prompt = [
        {"role": "user", "content": "2 + 3 = ? Reply with a single number:"},
    ]
    answer = await dial_chat_completion(deployment_id, prompt)

    if "5" in answer:
        return
    raise ValueError(f"Test failed for {deployment_id!r}")
|
||
|
||
async def test_vision_model(deployment_id: str):
    """Ask the vision deployment to describe ./image.png and check the reply.

    The image is sent inline as a base64 ``data:`` URL next to a text
    instruction. The check passes when the reply contains "vision" in any
    letter case.
    NOTE(review): presumably image.png shows that word — confirm.

    Raises:
        ValueError: If "vision" does not appear in the returned content.
    """
    encoded_png = read_image_base64(Path("./image.png"))
    text_part = {"type": "text", "text": "Describe the image"}
    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/png;base64,{encoded_png}"},
    }
    conversation = [{"role": "user", "content": [text_part, image_part]}]

    description = await dial_chat_completion(deployment_id, conversation)

    if "vision" in description.lower():
        return
    raise ValueError(f"Test failed for {deployment_id!r}")
|
||
async def test_embedding_model(deployment_id: str):
    """Embed the word "cat" and check the result looks like a float vector.

    Raises:
        ValueError: If the embedding is empty or its first element is not a
            ``float``.
    """
    vector = await dial_embeddings(deployment_id, "cat")

    looks_valid = len(vector) > 0 and isinstance(vector[0], float)
    if not looks_valid:
        raise ValueError(f"Test failed for {deployment_id!r}")
|
||
|
||
async def tests():
    """Run the chat, vision and embedding checks in order, timing each one."""
    # (timer label, check coroutine function, deployment id), in run order.
    suite = (
        ("Testing chat-model", test_chat_model, "chat-model"),
        ("Testing vision-model", test_vision_model, "vision-model"),
        ("Testing embedding-model", test_embedding_model, "embedding-model"),
    )
    for label, check, deployment in suite:
        async with timer(label):
            await check(deployment)
|
||
if __name__ == "__main__":
    # asyncio.run() creates a fresh event loop, runs the coroutine to
    # completion and closes the loop afterwards. Calling
    # asyncio.get_event_loop() with no running loop is deprecated and left
    # the loop unclosed.
    asyncio.run(tests())
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
aiohttp==3.9.4 | ||
backoff==2.2.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.