-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
7697a2b
commit 53a5db3
Showing
4 changed files
with
48 additions
and
1 deletion.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import vectorlite_py | ||
import apsw | ||
import pytest | ||
import numpy as np | ||
|
||
|
||
@pytest.fixture(scope='module') | ||
def conn() -> None: | ||
conn = apsw.Connection(':memory:') | ||
conn.enable_load_extension(True) | ||
conn.load_extension(vectorlite_py.vectorlite_path()) | ||
return conn | ||
|
||
DIM = 32 | ||
NUM_ELEMENTS = 1000 | ||
|
||
# Generating sample data | ||
@pytest.fixture(scope='module') | ||
def random_vectors(): | ||
return np.float32(np.random.random((NUM_ELEMENTS, DIM))) | ||
|
||
def test_vectorlite_info(conn): | ||
cur = conn.cursor() | ||
cur.execute('select vectorlite_info()') | ||
output = cur.fetchone() | ||
assert f'vectorlite extension version {vectorlite_py.__version__}' in output[0] | ||
|
||
def test_l2_space_with_default_hnsw_param(conn, random_vectors): | ||
cur = conn.cursor() | ||
cur.execute('create virtual table x using vectorlite(my_embedding(32, "l2"), hnsw(max_elements=1000))') | ||
|
||
for i in range(NUM_ELEMENTS): | ||
cur.execute('insert into x (rowid, my_embedding) values (?, ?)', (i, random_vectors[i].tobytes())) | ||
|
||
result = cur.execute('select my_embedding from x where rowid = 0').fetchone() | ||
assert result[0] == random_vectors[0].tobytes() | ||
|
||
cur.execute('delete from x where rowid = 0') | ||
result = cur.execute('select my_embedding from x where rowid = 0').fetchone() | ||
assert result is None | ||
|
||
result = cur.execute('select rowid, distance from x where knn_search(my_embedding, knn_param(?, ?))', (random_vectors[0].tobytes(), 10)).fetchall() | ||
assert len(result) == 10 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters