Skip to content

Commit

Permalink
Add unit tests and fix required modules
Browse files Browse the repository at this point in the history
  • Loading branch information
MaskerPRC committed Sep 24, 2024
1 parent 93b568e commit ab91cd0
Show file tree
Hide file tree
Showing 22 changed files with 452 additions and 396 deletions.
8 changes: 0 additions & 8 deletions etc/user_cases/test_embedding.sh

This file was deleted.

8 changes: 0 additions & 8 deletions etc/user_cases/text_process.sh

This file was deleted.

10 changes: 0 additions & 10 deletions etc/user_cases/vector_search.sh

This file was deleted.

23 changes: 11 additions & 12 deletions infra_ai_service/api/ai_enhance/spec_repair_process.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from fastapi import APIRouter, File, UploadFile
from fastapi.responses import JSONResponse
from infra_ai_service.service.spec_repair import SpecBot

from infra_ai_service.service.spec_repair import SpecBot

router = APIRouter()


@router.post("/")
async def spec_repair_process(err_spec_file: UploadFile = File(...),
err_log_file: UploadFile = File(...)):
async def spec_repair_process(
err_spec_file: UploadFile = File(...), err_log_file: UploadFile = File(...)
):
try:
err_spec_lines = await err_spec_file.read()
err_log_lines = await err_log_file.read()
Expand All @@ -18,18 +19,16 @@ async def spec_repair_process(err_spec_file: UploadFile = File(...),

bot = SpecBot()
suggestion, is_repaired, repaired_spec_lines, log_content = bot.repair(
err_spec_lines, err_log_lines)
err_spec_lines, err_log_lines
)

resp_data = {
'suggestions': suggestion,
'repair_status': is_repaired,
'repair_spec': repaired_spec_lines,
'log': log_content
"suggestions": suggestion,
"repair_status": is_repaired,
"repair_spec": repaired_spec_lines,
"log": log_content,
}
return JSONResponse(content=resp_data)
except Exception as e:
resp_data = {
'status': 'error',
'message': str(e)
}
resp_data = {"status": "error", "message": str(e)}
return JSONResponse(content=resp_data)
29 changes: 0 additions & 29 deletions infra_ai_service/api/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,29 +0,0 @@
from fastapi import HTTPException
from fastembed.embedding import DefaultEmbedding
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams


def setup_qdrant_environment(collection_name='test_simi',
                             url="http://localhost:6333"):
    """Initialize the FastEmbed model and a Qdrant client, ensuring the
    target collection exists.

    Args:
        collection_name: Name of the Qdrant collection to use/create.
            Defaults to 'test_simi' (previous hard-coded value).
        url: Base URL of the Qdrant server. Defaults to the previous
            hard-coded local instance.

    Returns:
        Tuple of (fastembed_model, qdrant_client, collection_name).
    """
    # Initialize the FastEmbed model and the Qdrant client
    fastembed_model = DefaultEmbedding()
    qdrant_client = QdrantClient(url=url)

    # Check whether the collection exists; create it if it does not.
    try:
        qdrant_client.get_collection(collection_name)
        print(f"Collection {collection_name} already exists")
    except Exception:
        # NOTE(review): the client raises a generic exception when the
        # collection is missing; any other server error also lands here.
        # Determine the vector dimension from a sample embedding.
        sample_embedding = next(fastembed_model.embed(["Sample text"]))
        vector_size = len(sample_embedding)

        # Create the collection with cosine distance.
        qdrant_client.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(size=vector_size,
                                        distance=Distance.COSINE),
        )
        print(f"Created collection: {collection_name}")
    return fastembed_model, qdrant_client, collection_name
35 changes: 21 additions & 14 deletions infra_ai_service/api/router.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
from fastapi.routing import APIRouter

from infra_ai_service.api.ai_enhance.embedding import \
router as embedding_router
from infra_ai_service.api.ai_enhance.spec_repair_process import \
router as spec_repair_process
from infra_ai_service.api.ai_enhance.text_process import \
router as text_process_router
from infra_ai_service.api.ai_enhance.vector_search import \
router as vector_search_router
from infra_ai_service.api.ai_enhance.embedding import (
router as embedding_router,
)
from infra_ai_service.api.ai_enhance.spec_repair_process import (
router as spec_repair_process,
)
from infra_ai_service.api.ai_enhance.text_process import (
router as text_process_router,
)
from infra_ai_service.api.ai_enhance.vector_search import (
router as vector_search_router,
)

api_router = APIRouter()
api_router.include_router(spec_repair_process, prefix="/spec-repair",
tags=["repair"])
api_router.include_router(
spec_repair_process, prefix="/spec-repair", tags=["repair"]
)
api_router.include_router(text_process_router, prefix="/text", tags=["text"])
api_router.include_router(embedding_router, prefix="/embedding",
tags=["embedding"])
api_router.include_router(vector_search_router, prefix="/search",
tags=["search"])
api_router.include_router(
embedding_router, prefix="/embedding", tags=["embedding"]
)
api_router.include_router(
vector_search_router, prefix="/search", tags=["search"]
)
36 changes: 14 additions & 22 deletions infra_ai_service/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class Settings(BaseSettings):
"""Application settings."""

ENV: str = "dev"
HOST: str = 'localhost'
HOST: str = "localhost"
PORT: int = 8000
_BASE_URL: str = f"http://{HOST}:{PORT}"
WORKERS_COUNT: int = 1
Expand All @@ -25,7 +25,7 @@ class Settings(BaseSettings):
DB_PORT: int = 0

# 模型名称配置项
MODEL_NAME: str = ""
MODEL_NAME: str = "model-name-here"

# 新增的配置项
VECTOR_EXTENSION: str = ""
Expand All @@ -34,10 +34,10 @@ class Settings(BaseSettings):
LANGUAGE: str = ""

# SpecBot config
SPECBOT_AI_MODEL: str = ''
REPAIR_PRO_AI_MODEL: str = ''
OPENAI_API_KEY: str = ''
OPENAI_BASE_URL: str = ''
SPECBOT_AI_MODEL: str = ""
REPAIR_PRO_AI_MODEL: str = ""
OPENAI_API_KEY: str = ""
OPENAI_BASE_URL: str = ""

@property
def BASE_URL(self) -> str:
Expand All @@ -61,24 +61,16 @@ class Config:
"TABLE_NAME": {"env": "TABLE_NAME"},
"VECTOR_DIMENSION": {"env": "VECTOR_DIMENSION"},
"LANGUAGE": {"env": "LANGUAGE"},
'HOST': {
'env': 'HOST',
"HOST": {
"env": "HOST",
},
'PORT': {
'env': 'PORT',
"PORT": {
"env": "PORT",
},
'SPECBOT_AI_MODEL': {
'env': 'SPECBOT_AI_MODEL'
},
'REPAIR_PRO_AI_MODEL': {
'env': 'REPAIR_PRO_AI_MODEL'
},
'OPENAI_API_KEY': {
'env': 'OPENAI_API_KEY'
},
'OPENAI_BASE_URL': {
'env': 'OPENAI_BASE_URL'
}
"SPECBOT_AI_MODEL": {"env": "SPECBOT_AI_MODEL"},
"REPAIR_PRO_AI_MODEL": {"env": "REPAIR_PRO_AI_MODEL"},
"OPENAI_API_KEY": {"env": "OPENAI_API_KEY"},
"OPENAI_BASE_URL": {"env": "OPENAI_BASE_URL"},
}


Expand Down
16 changes: 0 additions & 16 deletions infra_ai_service/demo.py

This file was deleted.

23 changes: 15 additions & 8 deletions infra_ai_service/service/embedding_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import uuid
from concurrent.futures import ThreadPoolExecutor

import numpy
from fastapi import HTTPException

from infra_ai_service.model.model import EmbeddingOutput
Expand All @@ -12,8 +12,9 @@ async def create_embedding(content):
try:
# 确保模型已初始化
if pgvector.model is None:
raise HTTPException(status_code=500,
detail="Model is not initialized")
raise HTTPException(
status_code=500, detail="Model is not initialized"
)

# 使用线程池执行同步的嵌入计算
loop = asyncio.get_running_loop()
Expand All @@ -23,8 +24,11 @@ async def create_embedding(content):
)
embedding_vector = embedding_vector[0]

# 将 ndarray 转换为列表
embedding_vector_list = embedding_vector.tolist()
# 检查返回类型是否为 ndarray,如果是,则转换为列表
if isinstance(embedding_vector, numpy.ndarray):
embedding_vector_list = embedding_vector.tolist()
else:
embedding_vector_list = embedding_vector # 假设已经是列表

# 从连接池获取连接
async with pgvector.pool.connection() as conn:
Expand All @@ -36,7 +40,10 @@ async def create_embedding(content):
)
point_id = (await cur.fetchone())[0]

return EmbeddingOutput(id=point_id, embedding=embedding_vector_list)
return EmbeddingOutput(
id=str(point_id), embedding=embedding_vector_list
)
except Exception as e:
raise HTTPException(status_code=400,
detail=f"Error processing embedding: {e}")
raise HTTPException(
status_code=400, detail=f"Error processing embedding: {e}"
)
39 changes: 30 additions & 9 deletions infra_ai_service/service/search_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,20 @@
import logging
from concurrent.futures import ThreadPoolExecutor

import numpy
from fastapi import HTTPException

from infra_ai_service.model.model import SearchInput, SearchOutput, \
SearchResult
from infra_ai_service.model.model import (
SearchInput,
SearchOutput,
SearchResult,
)
from infra_ai_service.sdk import pgvector

logger = logging.getLogger(__name__)


async def perform_vector_search(input_data: SearchInput):
async def prepare_vector(input_data: SearchInput):
try:
# 确保模型已初始化
if pgvector.model is None:
Expand All @@ -28,9 +32,24 @@ async def perform_vector_search(input_data: SearchInput):
)
embedding_vector = embedding_vector[0]

# 将 ndarray 转换为列表
embedding_vector_list = embedding_vector.tolist()
# 检查返回类型是否为 ndarray,如果是,则转换为列表
if isinstance(embedding_vector, numpy.ndarray):
embedding_vector_list = embedding_vector.tolist()
else:
embedding_vector_list = embedding_vector # 假设已经是列表

return embedding_vector_list
except Exception as e:
logger.error(f"准备向量时出错: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500, detail=f"准备向量时出错: {str(e)}"
)


async def perform_vector_search(input_data: SearchInput):
embedding_vector_list = await prepare_vector(input_data)

try:
# 从连接池获取连接
async with pgvector.pool.connection() as conn:
async with conn.cursor() as cur:
Expand All @@ -54,12 +73,14 @@ async def perform_vector_search(input_data: SearchInput):
similarity = row[3] # 相似度得分
if similarity >= input_data.score_threshold:
results.append(
SearchResult(id=str(row[0]), score=similarity,
text=row[1]) # 内容
SearchResult(
id=str(row[0]), score=similarity, text=row[1]
) # 内容
)

return SearchOutput(results=results)
except Exception as e:
logger.error(f"执行向量搜索时出错: {str(e)}", exc_info=True)
raise HTTPException(status_code=500,
detail=f"执行向量搜索时出错: {str(e)}")
raise HTTPException(
status_code=500, detail=f"执行向量搜索时出错: {str(e)}"
)
Loading

0 comments on commit ab91cd0

Please sign in to comment.