Add support for pgvector

MaskerPRC committed Sep 23, 2024
1 parent c0d1a08 commit 5bc01f3
Showing 13 changed files with 152 additions and 124 deletions.
22 changes: 21 additions & 1 deletion infra_ai_service.egg-info/SOURCES.txt
@@ -2,9 +2,29 @@ README.md
setup.py
infra_ai_service/__init__.py
infra_ai_service/demo.py
infra_ai_service/server.py
infra_ai_service.egg-info/PKG-INFO
infra_ai_service.egg-info/SOURCES.txt
infra_ai_service.egg-info/dependency_links.txt
infra_ai_service.egg-info/top_level.txt
tests/__init__.py
infra_ai_service/api/__init__.py
infra_ai_service/api/router.py
infra_ai_service/api/ai_enhance/__init__.py
infra_ai_service/api/ai_enhance/embedding.py
infra_ai_service/api/ai_enhance/text_process.py
infra_ai_service/api/ai_enhance/vector_search.py
infra_ai_service/common/__init__.py
infra_ai_service/common/utils.py
infra_ai_service/config/__init__.py
infra_ai_service/config/config.py
infra_ai_service/core/__init__.py
infra_ai_service/core/app.py
infra_ai_service/model/__init__.py
infra_ai_service/model/model.py
infra_ai_service/sdk/__init__.py
infra_ai_service/sdk/pgvector.py
infra_ai_service/service/__init__.py
infra_ai_service/service/embedding_service.py
infra_ai_service/service/search_service.py
infra_ai_service/service/text_service.py
tests/test_demo.py
1 change: 0 additions & 1 deletion infra_ai_service.egg-info/top_level.txt
@@ -1,2 +1 @@
 infra_ai_service
-tests
11 changes: 1 addition & 10 deletions infra_ai_service/api/ai_enhance/embedding.py
@@ -1,20 +1,11 @@
 from fastapi import APIRouter
 
 from infra_ai_service.api.ai_enhance.text_process import TextInput
-from infra_ai_service.api.common.utils import setup_qdrant_environment
 
 from infra_ai_service.model.model import EmbeddingOutput
-from infra_ai_service.service.embedding_service import create_embedding, \
-    get_collection_status
+from infra_ai_service.service.embedding_service import create_embedding
 
 router = APIRouter()
 
 
 @router.post("/embed/", response_model=EmbeddingOutput)
 async def embed_text(input_data: TextInput):
     return await create_embedding(input_data.content)
-
-
-@router.get("/status/")
-async def status():
-    return await get_collection_status()
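Note that the Qdrant-backed /status/ endpoint is removed together with get_collection_status, and this commit does not add a pgvector replacement. If equivalent status reporting were wanted, a minimal sketch (an assumption, not part of this change) could report the row count of the documents table:

# Hypothetical pgvector-backed replacement for the removed /status/ endpoint,
# written as if added to the router defined in embedding.py above.
from fastapi import APIRouter

from infra_ai_service.sdk import pgvector

router = APIRouter()


@router.get("/status/")
async def status():
    # Count stored embeddings instead of querying a Qdrant collection.
    async with pgvector.pool.connection() as conn:
        async with conn.cursor() as cur:
            await cur.execute('SELECT count(*) FROM documents')
            count = (await cur.fetchone())[0]
    return {"vectors_count": count, "status": "ready"}
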
29 changes: 0 additions & 29 deletions infra_ai_service/api/common/utils.py

This file was deleted.

File renamed without changes.
20 changes: 20 additions & 0 deletions infra_ai_service/common/utils.py
@@ -0,0 +1,20 @@
from fastapi import HTTPException
from fastembed.embedding import DefaultEmbedding
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams


async def setup_database(pool):
    async with pool.connection() as conn:
        await conn.execute('CREATE EXTENSION IF NOT EXISTS vector')
        await conn.execute('''
            CREATE TABLE IF NOT EXISTS documents (
                id bigserial PRIMARY KEY,
                content text,
                embedding vector(384)
            )
        ''')
        await conn.execute('''
            CREATE INDEX IF NOT EXISTS documents_content_idx
            ON documents USING GIN (to_tsvector('english', content))
        ''')
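setup_database() enables the vector extension, creates the documents table with a 384-dimension embedding column, and adds a GIN full-text index on content. It does not create an index on the embedding column itself, so the vector searches further down run as sequential scans. If an approximate-nearest-neighbour index were wanted, a minimal sketch (assuming pgvector's ivfflat index with cosine distance; the index name and lists value are illustrative, not part of this commit) could be:

# Hypothetical helper, not part of this commit: add an ANN index on the
# embedding column using pgvector's ivfflat access method.
async def create_vector_index(pool):
    async with pool.connection() as conn:
        await conn.execute('''
            CREATE INDEX IF NOT EXISTS documents_embedding_idx
            ON documents USING ivfflat (embedding vector_cosine_ops)
            WITH (lists = 100)
        ''')
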
6 changes: 6 additions & 0 deletions infra_ai_service/core/app.py
@@ -2,6 +2,8 @@
 from fastapi import FastAPI
 from fastapi.responses import UJSONResponse
 
+from infra_ai_service.sdk.pgvector import setup_model_and_pool
+
 
 def get_app() -> FastAPI:
     """
@@ -23,4 +25,8 @@ def get_app() -> FastAPI:
 
     app.include_router(router=api_router, prefix="/api")
 
+    @app.on_event("startup")
+    async def startup_event():
+        await setup_model_and_pool()
+
     return app
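The startup hook is what initializes the model and connection pool before the first request. Newer FastAPI releases deprecate @app.on_event("startup") in favour of a lifespan handler; an equivalent wiring, shown only as a sketch (the commit itself keeps the on_event hook), would be:

# Sketch of lifespan-based startup wiring; assumes the same
# setup_model_and_pool() helper introduced in this commit.
from contextlib import asynccontextmanager

from fastapi import FastAPI

from infra_ai_service.sdk.pgvector import setup_model_and_pool


@asynccontextmanager
async def lifespan(app: FastAPI):
    await setup_model_and_pool()  # load the model and open the pool
    yield  # the application serves requests here


app = FastAPI(lifespan=lifespan)
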
22 changes: 22 additions & 0 deletions infra_ai_service/sdk/pgvector.py
@@ -0,0 +1,22 @@
from sentence_transformers import SentenceTransformer
from psycopg_pool import AsyncConnectionPool
from infra_ai_service.common.utils import setup_database

# The embedding model (initialized at startup)
model = None
# The connection pool (not created yet)
pool = None


async def setup_model_and_pool():
    global model, pool
    # Initialize the model
    model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
    # Create the async connection pool
    pool = AsyncConnectionPool(
        "dbname=pgvector user=postgres password=postgres "
        "host=localhost port=5432",
        open=True
    )
    # Set up the database schema
    await setup_database(pool)
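model and pool are module-level globals, so other services import the module (from infra_ai_service.sdk import pgvector) and read pgvector.model / pgvector.pool to pick up the instances created at startup. Two side notes: AsyncConnectionPool comes from the separate psycopg-pool package (also installable via the psycopg[pool] extra), which the updated requirements.txt does not appear to list, and the pool is opened at startup but never closed. A possible companion teardown (an assumption, not part of this change) could be:

# Hypothetical shutdown counterpart to setup_model_and_pool(); closes the
# async pool so connections are released cleanly.
async def teardown_pool():
    global pool
    if pool is not None:
        await pool.close()
        pool = None
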
3 changes: 0 additions & 3 deletions infra_ai_service/sdk/qdrant.py

This file was deleted.

65 changes: 28 additions & 37 deletions infra_ai_service/service/embedding_service.py
@@ -1,48 +1,39 @@
-from fastapi import HTTPException
-import uuid
-
-from infra_ai_service.model.model import PointStruct, EmbeddingOutput
-from infra_ai_service.sdk.qdrant import fastembed_model, qdrant_client, \
-    collection_name
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+from fastapi import HTTPException
+from infra_ai_service.sdk import pgvector
+from infra_ai_service.model.model import EmbeddingOutput
 
 
 async def create_embedding(content):
     try:
-        embeddings = list(fastembed_model.embed([content]))
-        if not embeddings:
-            raise HTTPException(status_code=500,
-                                detail="Failed to generate embedding")
-
-        embedding_vector = embeddings[0]
-        point_id = str(uuid.uuid4())
-
-        qdrant_client.upsert(
-            collection_name=collection_name,
-            points=[
-                PointStruct(
-                    id=point_id,
-                    vector=embedding_vector.tolist(),
-                    payload={"text": content}
-                )
-            ]
-        )
-
-        return EmbeddingOutput(id=point_id,
-                               embedding=embedding_vector.tolist())
+        # Make sure the model has been initialized
+        if pgvector.model is None:
+            raise HTTPException(status_code=500,
+                                detail="Model is not initialized")
+
+        # Run the synchronous embedding computation in a thread pool
+        loop = asyncio.get_running_loop()
+        with ThreadPoolExecutor() as pool_executor:
+            embedding_vector = await loop.run_in_executor(
+                pool_executor, pgvector.model.encode, [content]
+            )
+            embedding_vector = embedding_vector[0]
+
+        # Convert the ndarray to a plain list
+        embedding_vector_list = embedding_vector.tolist()
+
+        # Acquire a connection from the pool
+        async with pgvector.pool.connection() as conn:
+            async with conn.cursor() as cur:
+                await cur.execute(
+                    'INSERT INTO documents (content, embedding) VALUES (%s, %s) RETURNING id',
+                    (content, embedding_vector_list)  # use the converted list
+                )
+                point_id = (await cur.fetchone())[0]
+
+        return EmbeddingOutput(id=point_id, embedding=embedding_vector_list)
     except Exception as e:
         raise HTTPException(status_code=400,
                             detail=f"Error processing embedding: {e}")
-
-
-async def get_collection_status():
-    try:
-        collection_info = qdrant_client.get_collection(collection_name)
-        return {
-            "collection_name": collection_name,
-            "vectors_count": collection_info.vectors_count,
-            "status": "ready" if collection_info.status == "green"
-            else "not ready"
-        }
-    except Exception as e:
-        raise HTTPException(status_code=400,
-                            detail=f"Error getting collection status: {e}")
91 changes: 50 additions & 41 deletions infra_ai_service/service/search_service.py
@@ -1,53 +1,62 @@
-from fastapi import HTTPException
-import logging
+# infraAIService/infra_ai_service/service/search_service.py
 
-from infra_ai_service.model.model import SearchOutput, SearchResult, \
-    SearchInput
-from infra_ai_service.sdk.qdrant import qdrant_client, collection_name, \
-    fastembed_model
+import asyncio
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from fastapi import HTTPException
+from infra_ai_service.sdk import pgvector
+from infra_ai_service.model.model import SearchInput, SearchOutput, SearchResult
 
 logger = logging.getLogger(__name__)
 
 
 async def perform_vector_search(input_data: SearchInput):
     try:
-        # Check that the collection exists
-        collection_info = qdrant_client.get_collection(collection_name)
-        if not collection_info:
-            logger.error(f"Collection {collection_name} does not exist")
-            raise HTTPException(status_code=404,
-                                detail=f"Collection {collection_name} does "
-                                       f"not exist")
-
-        # Generate the embedding for the query text
-        query_vector = list(fastembed_model.embed([input_data.query_text]))
-        if not query_vector:
-            logger.error("Failed to generate query embedding")
-            raise HTTPException(status_code=500,
-                                detail="Failed to generate query embedding")
-
-        # Perform the vector search
-        search_results = qdrant_client.search(
-            collection_name=collection_name,
-            query_vector=query_vector[0],
-            limit=input_data.top_n,
-            score_threshold=input_data.score_threshold
-        )
+        # Make sure the model has been initialized
+        if pgvector.model is None:
+            logger.error("Model is not initialized")
+            raise HTTPException(status_code=500, detail="Model is not initialized")
+
+        # Generate the embedding vector for the query text
+        loop = asyncio.get_running_loop()
+        with ThreadPoolExecutor() as pool_executor:
+            embedding_vector = await loop.run_in_executor(
+                pool_executor, pgvector.model.encode, [input_data.query_text]
+            )
+            embedding_vector = embedding_vector[0]
+
+        # Convert the ndarray to a plain list
+        embedding_vector_list = embedding_vector.tolist()
+
+        # Acquire a connection from the pool
+        async with pgvector.pool.connection() as conn:
+            async with conn.cursor() as cur:
+                # Run the vector search query, explicitly casting the parameter to the vector type
+                await cur.execute(
+                    """
+                    SELECT id, content, embedding, 1 - (embedding <#> %s::vector) AS similarity
+                    FROM documents
+                    ORDER BY similarity DESC
+                    LIMIT %s
+                    """,
+                    (embedding_vector_list, input_data.top_n)
+                )
+                rows = await cur.fetchall()
 
         # Convert the search results into the output format
-        results = [
-            SearchResult(
-                id=str(result.id),
-                score=result.score,
-                text=result.payload.get('text', 'No text available')
-            )
-            for result in search_results
-        ]
+        results = []
+        for row in rows:
+            similarity = row[3]  # similarity score
+            if similarity >= input_data.score_threshold:
+                results.append(
+                    SearchResult(
+                        id=str(row[0]),
+                        score=similarity,
+                        text=row[1]  # document content
+                    )
+                )
 
         return SearchOutput(results=results)
     except Exception as e:
-        logger.error(f"Error performing vector search: {str(e)}",
-                     exc_info=True)
-        raise HTTPException(status_code=500,
-                            detail=f"Error performing vector search: "
-                                   f"{str(e)}")
+        logger.error(f"Error performing vector search: {str(e)}", exc_info=True)
+        raise HTTPException(status_code=500, detail=f"Error performing vector search: {str(e)}")
4 changes: 3 additions & 1 deletion requirements.txt
@@ -15,4 +15,6 @@ requests==2.31.0
 httpx==0.23.0
 pydantic==1.10.12
 fastembed==0.3.6
-qdrant-client==1.11.1
+setuptools~=74.1.2
+psycopg~=3.2.1
+pgvector~=0.3.3
2 changes: 1 addition & 1 deletion tox.ini
@@ -10,7 +10,6 @@ deps =
     pytest_asyncio
     asyncpg
     fastapi
-    qdrant_client
     fastembed
 commands =
     pytest tests/ --cov=infra_ai_service --cov-report=term-missing
@@ -28,6 +27,7 @@ commands =
 [testenv:coverage]
 deps =
     coverage
+    {[testenv]deps}
 commands =
     coverage report --fail-under=0
     coverage html
