Skip to content

Commit

Permalink
Merge branch 'main' into autoquant
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelfeil authored Oct 14, 2024
2 parents b9c59b5 + 0f1b786 commit 55a8b0e
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 38 deletions.
23 changes: 23 additions & 0 deletions libs/infinity_emb/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import socket

import pytest
import requests
from sentence_transformers import InputExample, util # type: ignore

pytest.DEFAULT_BERT_MODEL = "michaelfeil/bge-small-en-v1.5"
Expand All @@ -23,6 +24,28 @@ def anyio_backend():
return "asyncio"


def _download(url: str, **kwargs) -> requests.Response:
for i in range(5):
try:
response = requests.get(url, **kwargs)
if response.status_code == 200:
return response
except Exception:
pass
else:
raise Exception(f"Failed to download {url}")


@pytest.fixture(scope="function")
def audio_sample() -> tuple[requests.Response, str]:
return (_download(pytest.AUDIO_SAMPLE_URL)), pytest.AUDIO_SAMPLE_URL # type: ignore


@pytest.fixture(scope="function")
def image_sample() -> tuple[requests.Response, str]:
return (_download(pytest.IMAGE_SAMPLE_URL, stream=True)), pytest.IMAGE_SAMPLE_URL # type: ignore


def internet_available():
try:
# Attempt to connect to a well-known public DNS server (Google's)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,15 @@ async def client():

def url_to_base64(url, modality="image"):
"""small helper to convert url to base64 without server requiring access to the url"""
response = requests.get(url)
for i in range(3):
try:
response = requests.get(url)
if response.status_code == 200:
break
except Exception:
pass
else:
raise Exception(f"Failed to download {url}")
response.raise_for_status()
base64_encoded = base64.b64encode(response.content).decode("utf-8")
mimetype = f"{modality}/{url.split('.')[-1]}"
Expand Down
11 changes: 5 additions & 6 deletions libs/infinity_emb/tests/end_to_end/test_torch_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import numpy as np
import pytest
import requests
import torch
from asgi_lifespan import LifespanManager
from fastapi import status
Expand Down Expand Up @@ -144,8 +143,8 @@ async def test_audio_multiple(client):


@pytest.mark.anyio
async def test_audio_base64(client):
bytes_downloaded = requests.get(pytest.AUDIO_SAMPLE_URL).content
async def test_audio_base64(client, audio_sample):
bytes_downloaded = audio_sample[0].content
base_64_audio = base64.b64encode(bytes_downloaded).decode("utf-8")

response = await client.post(
Expand All @@ -154,7 +153,7 @@ async def test_audio_base64(client):
"model": MODEL,
"input": [
"data:audio/wav;base64," + base_64_audio,
pytest.AUDIO_SAMPLE_URL,
audio_sample[1],
],
},
)
Expand All @@ -166,8 +165,8 @@ async def test_audio_base64(client):
assert rdata_results[0]["object"] == "embedding"
assert len(rdata_results[0]["embedding"]) > 0

np.testing.assert_array_equal(
rdata_results[0]["embedding"], rdata_results[1]["embedding"]
np.testing.assert_array_almost_equal(
rdata_results[0]["embedding"], rdata_results[1]["embedding"], decimal=4
)


Expand Down
9 changes: 4 additions & 5 deletions libs/infinity_emb/tests/end_to_end/test_torch_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import numpy as np
import pytest
import requests
import torch
from asgi_lifespan import LifespanManager
from fastapi import status
Expand Down Expand Up @@ -84,8 +83,8 @@ async def test_vision_single_text_only(client):


@pytest.mark.anyio
async def test_vision_base64(client):
bytes_downloaded = requests.get(pytest.IMAGE_SAMPLE_URL).content
async def test_vision_base64(client, image_sample):
bytes_downloaded = image_sample[0].content
base_64_image = base64.b64encode(bytes_downloaded).decode("utf-8")

response = await client.post(
Expand All @@ -106,8 +105,8 @@ async def test_vision_base64(client):
assert rdata_results[0]["object"] == "embedding"
assert len(rdata_results[0]["embedding"]) > 0

np.testing.assert_array_equal(
rdata_results[0]["embedding"], rdata_results[1]["embedding"]
np.testing.assert_array_almost_equal(
rdata_results[0]["embedding"], rdata_results[1]["embedding"], decimal=4
)


Expand Down
20 changes: 8 additions & 12 deletions libs/infinity_emb/tests/unit_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import numpy as np
import pytest
import requests
import torch
from PIL import Image
from sentence_transformers import CrossEncoder # type: ignore[import-untyped]
Expand Down Expand Up @@ -217,13 +216,13 @@ async def test_torch_clip_embed():


@pytest.mark.anyio
async def test_clap_like_model():
model_name = "laion/clap-htsat-unfused"
async def test_clap_like_model(audio_sample):
model_name = pytest.DEFAULT_AUDIO_MODEL
engine = AsyncEmbeddingEngine.from_args(
EngineArgs(model_name_or_path=model_name, dtype="float32")
)
url = pytest.AUDIO_SAMPLE_URL
bytes_url = requests.get(url).content
url = audio_sample[1]
bytes_url = audio_sample[0].content

inputs = ["a sound of a cat", "a sound of a cat"]
audios = [url, bytes_url]
Expand All @@ -240,11 +239,8 @@ async def test_clap_like_model():


@pytest.mark.anyio
async def test_clip_embed_pil_image_input():
response = requests.get(pytest.IMAGE_SAMPLE_URL, stream=True)

assert response.status_code == 200
img_data = response.raw
async def test_clip_embed_pil_image_input(image_sample):
img_data = image_sample[0].raw
img_obj = Image.open(img_data)
images = [img_obj] # a photo of two cats
sentences = [
Expand All @@ -255,7 +251,7 @@ async def test_clip_embed_pil_image_input():
]
engine = AsyncEmbeddingEngine.from_args(
EngineArgs(
model_name_or_path="wkcn/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M",
model_name_or_path=pytest.DEFAULT_IMAGE_MODEL,
engine=InferenceEngine.torch,
model_warmup=True,
)
Expand Down Expand Up @@ -301,7 +297,7 @@ async def test_async_api_torch_embedding_quant(embedding_dtype: EmbeddingDtype):
device = "cpu"
engine = AsyncEmbeddingEngine.from_args(
EngineArgs(
model_name_or_path="michaelfeil/bge-small-en-v1.5",
model_name_or_path=pytest.DEFAULT_BERT_MODEL, # type: ignore
engine=InferenceEngine.torch,
device=Device[device],
lengths_via_tokenize=True,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import io

import numpy as np
import requests # type: ignore
import pytest
import soundfile as sf # type: ignore
import torch
from transformers import ClapModel, ClapProcessor # type: ignore
Expand All @@ -10,13 +10,10 @@
from infinity_emb.transformer.audio.torch import ClapLikeModel


def test_clap_like_model():
model_name = "laion/clap-htsat-unfused"
model = ClapLikeModel(
engine_args=EngineArgs(model_name_or_path=model_name, dtype="float16")
)
url = "https://github.com/michaelfeil/infinity/raw/3b72eb7c14bae06e68ddd07c1f23fe0bf403f220/libs/infinity_emb/tests/data/audio/beep.wav"
raw_bytes = requests.get(url, stream=True).content
def test_clap_like_model(audio_sample):
model_name = pytest.DEFAULT_AUDIO_MODEL
model = ClapLikeModel(engine_args=EngineArgs(model_name_or_path=model_name))
raw_bytes = audio_sample[0].content
data, samplerate = sf.read(io.BytesIO(raw_bytes))

assert samplerate == model.sampling_rate
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import requests # type: ignore
import pytest
import torch
from PIL import Image # type: ignore
from transformers import CLIPModel, CLIPProcessor # type: ignore
Expand All @@ -8,13 +8,12 @@
from infinity_emb.transformer.vision.torch_vision import ClipLikeModel


def test_clip_like_model():
model_name = "openai/clip-vit-base-patch32"
def test_clip_like_model(image_sample):
model_name = pytest.DEFAULT_IMAGE_MODEL
model = ClipLikeModel(
engine_args=EngineArgs(model_name_or_path=model_name, dtype="float16")
engine_args=EngineArgs(model_name_or_path=model_name, dtype="auto")
)
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
image = Image.open(image_sample[0].raw)

inputs = [
"a photo of a cat",
Expand Down

0 comments on commit 55a8b0e

Please sign in to comment.