forked from jlcoo/infraAIService
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
152 additions
and
124 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1 @@ | ||
infra_ai_service | ||
tests |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,11 @@ | ||
from fastapi import APIRouter | ||
|
||
from infra_ai_service.api.ai_enhance.text_process import TextInput | ||
from infra_ai_service.api.common.utils import setup_qdrant_environment | ||
|
||
from infra_ai_service.model.model import EmbeddingOutput | ||
from infra_ai_service.service.embedding_service import create_embedding, \ | ||
get_collection_status | ||
from infra_ai_service.service.embedding_service import create_embedding | ||
|
||
router = APIRouter() | ||
|
||
|
||
@router.post("/embed/", response_model=EmbeddingOutput) | ||
async def embed_text(input_data: TextInput): | ||
return await create_embedding(input_data.content) | ||
|
||
|
||
@router.get("/status/") | ||
async def status(): | ||
return await get_collection_status() |
This file was deleted.
Oops, something went wrong.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from fastapi import HTTPException | ||
from fastembed.embedding import DefaultEmbedding | ||
from qdrant_client import QdrantClient | ||
from qdrant_client.http.models import Distance, VectorParams | ||
|
||
|
||
async def setup_database(pool): | ||
async with pool.connection() as conn: | ||
await conn.execute('CREATE EXTENSION IF NOT EXISTS vector') | ||
await conn.execute(''' | ||
CREATE TABLE IF NOT EXISTS documents ( | ||
id bigserial PRIMARY KEY, | ||
content text, | ||
embedding vector(384) | ||
) | ||
''') | ||
await conn.execute(''' | ||
CREATE INDEX IF NOT EXISTS documents_content_idx | ||
ON documents USING GIN (to_tsvector('english', content)) | ||
''') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from sentence_transformers import SentenceTransformer | ||
from psycopg_pool import AsyncConnectionPool | ||
from infra_ai_service.common.utils import setup_database | ||
|
||
# 初始化模型 | ||
model = None | ||
# 创建连接池(暂时不初始化) | ||
pool = None | ||
|
||
|
||
async def setup_model_and_pool(): | ||
global model, pool | ||
# 初始化模型 | ||
model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1') | ||
# 创建异步连接池 | ||
pool = AsyncConnectionPool( | ||
"dbname=pgvector user=postgres password=postgres " | ||
"host=localhost port=5432", | ||
open=True | ||
) | ||
# 设置数据库 | ||
await setup_database(pool) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,48 +1,39 @@ | ||
from fastapi import HTTPException | ||
import uuid | ||
|
||
from infra_ai_service.model.model import PointStruct, EmbeddingOutput | ||
from infra_ai_service.sdk.qdrant import fastembed_model, qdrant_client, \ | ||
collection_name | ||
import asyncio | ||
from concurrent.futures import ThreadPoolExecutor | ||
from fastapi import HTTPException | ||
from infra_ai_service.sdk import pgvector | ||
from infra_ai_service.model.model import EmbeddingOutput | ||
|
||
|
||
async def create_embedding(content): | ||
try: | ||
embeddings = list(fastembed_model.embed([content])) | ||
if not embeddings: | ||
# 确保模型已初始化 | ||
if pgvector.model is None: | ||
raise HTTPException(status_code=500, | ||
detail="Failed to generate embedding") | ||
|
||
embedding_vector = embeddings[0] | ||
point_id = str(uuid.uuid4()) | ||
|
||
qdrant_client.upsert( | ||
collection_name=collection_name, | ||
points=[ | ||
PointStruct( | ||
id=point_id, | ||
vector=embedding_vector.tolist(), | ||
payload={"text": content} | ||
detail="Model is not initialized") | ||
|
||
# 使用线程池执行同步的嵌入计算 | ||
loop = asyncio.get_running_loop() | ||
with ThreadPoolExecutor() as pool_executor: | ||
embedding_vector = await loop.run_in_executor( | ||
pool_executor, pgvector.model.encode, [content] | ||
) | ||
embedding_vector = embedding_vector[0] | ||
|
||
# 将 ndarray 转换为列表 | ||
embedding_vector_list = embedding_vector.tolist() | ||
|
||
# 从连接池获取连接 | ||
async with pgvector.pool.connection() as conn: | ||
async with conn.cursor() as cur: | ||
await cur.execute( | ||
'INSERT INTO documents (content, embedding) VALUES (%s, %s) RETURNING id', | ||
(content, embedding_vector_list) # 使用转换后的列表 | ||
) | ||
] | ||
) | ||
point_id = (await cur.fetchone())[0] | ||
|
||
return EmbeddingOutput(id=point_id, | ||
embedding=embedding_vector.tolist()) | ||
return EmbeddingOutput(id=point_id, embedding=embedding_vector_list) | ||
except Exception as e: | ||
raise HTTPException(status_code=400, | ||
detail=f"Error processing embedding: {e}") | ||
|
||
|
||
async def get_collection_status(): | ||
try: | ||
collection_info = qdrant_client.get_collection(collection_name) | ||
return { | ||
"collection_name": collection_name, | ||
"vectors_count": collection_info.vectors_count, | ||
"status": "ready" if collection_info.status == "green" | ||
else "not ready" | ||
} | ||
except Exception as e: | ||
raise HTTPException(status_code=400, | ||
detail=f"Error getting collection status: {e}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,53 +1,62 @@ | ||
from fastapi import HTTPException | ||
import logging | ||
# infraAIService/infra_ai_service/service/search_service.py | ||
|
||
from infra_ai_service.model.model import SearchOutput, SearchResult, \ | ||
SearchInput | ||
from infra_ai_service.sdk.qdrant import qdrant_client, collection_name, \ | ||
fastembed_model | ||
import asyncio | ||
import logging | ||
from concurrent.futures import ThreadPoolExecutor | ||
from fastapi import HTTPException | ||
from infra_ai_service.sdk import pgvector | ||
from infra_ai_service.model.model import SearchInput, SearchOutput, SearchResult | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
async def perform_vector_search(input_data: SearchInput): | ||
try: | ||
# 检查集合是否存在 | ||
collection_info = qdrant_client.get_collection(collection_name) | ||
if not collection_info: | ||
logger.error(f"Collection {collection_name} does not exist") | ||
raise HTTPException(status_code=404, | ||
detail=f"Collection {collection_name} does " | ||
f"not exist") | ||
|
||
# 生成查询文本的嵌入 | ||
query_vector = list(fastembed_model.embed([input_data.query_text])) | ||
if not query_vector: | ||
logger.error("Failed to generate query embedding") | ||
raise HTTPException(status_code=500, | ||
detail="Failed to generate query embedding") | ||
|
||
# 执行向量搜索 | ||
search_results = qdrant_client.search( | ||
collection_name=collection_name, | ||
query_vector=query_vector[0], | ||
limit=input_data.top_n, | ||
score_threshold=input_data.score_threshold | ||
) | ||
# 确保模型已初始化 | ||
if pgvector.model is None: | ||
logger.error("模型未初始化") | ||
raise HTTPException(status_code=500, detail="模型未初始化") | ||
|
||
# 生成查询文本的嵌入向量 | ||
loop = asyncio.get_running_loop() | ||
with ThreadPoolExecutor() as pool_executor: | ||
embedding_vector = await loop.run_in_executor( | ||
pool_executor, pgvector.model.encode, [input_data.query_text] | ||
) | ||
embedding_vector = embedding_vector[0] | ||
|
||
# 将 ndarray 转换为列表 | ||
embedding_vector_list = embedding_vector.tolist() | ||
|
||
# 从连接池获取连接 | ||
async with pgvector.pool.connection() as conn: | ||
async with conn.cursor() as cur: | ||
# 执行向量搜索查询,显式转换参数为 vector 类型 | ||
await cur.execute( | ||
""" | ||
SELECT id, content, embedding, 1 - (embedding <#> %s::vector) AS similarity | ||
FROM documents | ||
ORDER BY similarity DESC | ||
LIMIT %s | ||
""", | ||
(embedding_vector_list, input_data.top_n) | ||
) | ||
rows = await cur.fetchall() | ||
|
||
# 转换搜索结果为输出格式 | ||
results = [ | ||
SearchResult( | ||
id=str(result.id), | ||
score=result.score, | ||
text=result.payload.get('text', 'No text available') | ||
) | ||
for result in search_results | ||
] | ||
results = [] | ||
for row in rows: | ||
similarity = row[3] # 相似度得分 | ||
if similarity >= input_data.score_threshold: | ||
results.append( | ||
SearchResult( | ||
id=str(row[0]), | ||
score=similarity, | ||
text=row[1] # 内容 | ||
) | ||
) | ||
|
||
return SearchOutput(results=results) | ||
except Exception as e: | ||
logger.error(f"Error performing vector search: {str(e)}", | ||
exc_info=True) | ||
raise HTTPException(status_code=500, | ||
detail=f"Error performing vector search: " | ||
f"{str(e)}") | ||
logger.error(f"执行向量搜索时出错: {str(e)}", exc_info=True) | ||
raise HTTPException(status_code=500, detail=f"执行向量搜索时出错: {str(e)}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters