Skip to content

Commit

Permalink
Fix code style CI and address code review
Browse files Browse the repository at this point in the history
  • Loading branch information
MaskerPRC committed Sep 23, 2024
1 parent cdf91d3 commit e9969fd
Show file tree
Hide file tree
Showing 12 changed files with 80 additions and 70 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,4 @@ dmypy.json
cython_debug/

.python-version
infra_ai_service.egg-info
3 changes: 0 additions & 3 deletions infra_ai_service.egg-info/PKG-INFO

This file was deleted.

30 changes: 0 additions & 30 deletions infra_ai_service.egg-info/SOURCES.txt

This file was deleted.

1 change: 0 additions & 1 deletion infra_ai_service.egg-info/dependency_links.txt

This file was deleted.

1 change: 0 additions & 1 deletion infra_ai_service.egg-info/top_level.txt

This file was deleted.

15 changes: 10 additions & 5 deletions infra_ai_service/api/router.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from fastapi.routing import APIRouter

from infra_ai_service.api.ai_enhance.embedding import router as embedding_router
from infra_ai_service.api.ai_enhance.text_process import router as text_process_router
from infra_ai_service.api.ai_enhance.vector_search import router as vector_search_router
from infra_ai_service.api.ai_enhance.embedding import \
router as embedding_router
from infra_ai_service.api.ai_enhance.text_process import \
router as text_process_router
from infra_ai_service.api.ai_enhance.vector_search import \
router as vector_search_router

api_router = APIRouter()
api_router.include_router(text_process_router, prefix="/text", tags=["text"])
api_router.include_router(embedding_router, prefix="/embedding", tags=["embedding"])
api_router.include_router(vector_search_router, prefix="/search", tags=["search"])
api_router.include_router(embedding_router, prefix="/embedding",
tags=["embedding"])
api_router.include_router(vector_search_router, prefix="/search",
tags=["search"])
29 changes: 17 additions & 12 deletions infra_ai_service/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
from fastapi import HTTPException
from fastembed.embedding import DefaultEmbedding
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
# infra_ai_service/common/utils.py

from infra_ai_service.config.config import settings


async def setup_database(pool):
async with pool.connection() as conn:
await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
# 创建扩展名,使用配置项
await conn.execute(
"""
CREATE TABLE IF NOT EXISTS documents (
f"CREATE EXTENSION IF NOT EXISTS {settings.VECTOR_EXTENSION}"
)
# 创建表,使用配置项
await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {settings.TABLE_NAME} (
id bigserial PRIMARY KEY,
content text,
embedding vector(384)
embedding vector({settings.VECTOR_DIMENSION})
)
"""
"""
)
# 创建索引,使用配置项
await conn.execute(
f"""
CREATE INDEX IF NOT EXISTS {settings.TABLE_NAME}_content_idx
ON {settings.TABLE_NAME}
USING GIN (to_tsvector('{settings.LANGUAGE}', content))
"""
CREATE INDEX IF NOT EXISTS documents_content_idx
ON documents USING GIN (to_tsvector('english', content))
"""
)
45 changes: 34 additions & 11 deletions infra_ai_service/config/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# infra_ai_service/config/config.py

from pathlib import Path
from sys import modules

from pydantic import BaseSettings

Expand All @@ -12,26 +13,48 @@ class Settings(BaseSettings):
ENV: str = "dev"
HOST: str = "0.0.0.0"
PORT: int = 8000
_BASE_URL: str = f"https://{HOST}:{PORT}"
# quantity of workers for uvicorn
_BASE_URL: str = f"http://{HOST}:{PORT}"
WORKERS_COUNT: int = 1
# Enable uvicorn reloading
RELOAD: bool = False

# 数据库配置项
DB_NAME: str = "pgvector"
DB_USER: str = "postgres"
DB_PASSWORD: str = "postgres"
DB_HOST: str = "localhost"
DB_PORT: int = 5432

# 模型名称配置项
MODEL_NAME: str = "multi-qa-MiniLM-L6-cos-v1"

# 新增的配置项
VECTOR_EXTENSION: str = "vector"
TABLE_NAME: str = "documents"
VECTOR_DIMENSION: int = 384
LANGUAGE: str = "english"

@property
def BASE_URL(self) -> str:
return self._BASE_URL if self._BASE_URL.endswith("/") else f"{self._BASE_URL}/"
if self._BASE_URL.endswith("/"):
return self._BASE_URL
else:
return f"{self._BASE_URL}/"

class Config:
env_file = f"{BASE_DIR}/.env"
env_file_encoding = "utf-8"
fields = {
"_BASE_URL": {
"env": "BASE_URL",
},
"_DB_BASE": {
"env": "DB_BASE",
},
"_BASE_URL": {"env": "BASE_URL"},
"DB_NAME": {"env": "DB_NAME"},
"DB_USER": {"env": "DB_USER"},
"DB_PASSWORD": {"env": "DB_PASSWORD"},
"DB_HOST": {"env": "DB_HOST"},
"DB_PORT": {"env": "DB_PORT"},
"MODEL_NAME": {"env": "MODEL_NAME"},
"VECTOR_EXTENSION": {"env": "VECTOR_EXTENSION"},
"TABLE_NAME": {"env": "TABLE_NAME"},
"VECTOR_DIMENSION": {"env": "VECTOR_DIMENSION"},
"LANGUAGE": {"env": "LANGUAGE"},
}


Expand Down
15 changes: 11 additions & 4 deletions infra_ai_service/sdk/pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from sentence_transformers import SentenceTransformer

from infra_ai_service.common.utils import setup_database
from infra_ai_service.config.config import settings

# 初始化模型
model = None
Expand All @@ -12,11 +13,17 @@
async def setup_model_and_pool():
global model, pool
# 初始化模型
model = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1")
model = SentenceTransformer(settings.MODEL_NAME)
# 创建异步连接池
pool = AsyncConnectionPool(
"dbname=pgvector user=postgres password=postgres " "host=localhost port=5432",
open=True,
conn_str = (
f"dbname={settings.DB_NAME} "
f"user={settings.DB_USER} "
f"password={settings.DB_PASSWORD} "
f"host={settings.DB_HOST} "
f"port={settings.DB_PORT}"
)
pool = AsyncConnectionPool(conn_str, open=True)

# 设置数据库
await setup_database(pool)

3 changes: 2 additions & 1 deletion infra_ai_service/service/embedding_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ async def create_embedding(content):
async with pgvector.pool.connection() as conn:
async with conn.cursor() as cur:
await cur.execute(
"INSERT INTO documents (content, embedding) VALUES (%s, %s) RETURNING id",
"INSERT INTO documents (content, embedding) "
"VALUES (%s, %s) RETURNING id",
(content, embedding_vector_list), # 使用转换后的列表
)
point_id = (await cur.fetchone())[0]
Expand Down
4 changes: 3 additions & 1 deletion infra_ai_service/service/search_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ async def perform_vector_search(input_data: SearchInput):
# 执行向量搜索查询,显式转换参数为 vector 类型
await cur.execute(
"""
SELECT id, content, embedding, 1 - (embedding <#> %s::vector) AS similarity
SELECT id, content, embedding,
1 - (embedding <#> %s::vector)
AS similarity
FROM documents
ORDER BY similarity DESC
LIMIT %s
Expand Down
3 changes: 2 additions & 1 deletion infra_ai_service/service/text_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ async def process_text(input_content: str) -> TextOutput:
return TextOutput(modified_content=modified_text)
except Exception as e:
logger.error(f"Error processing text: {str(e)}", exc_info=True)
raise HTTPException(status_code=400, detail=f"Error processing text: {str(e)}")
raise HTTPException(status_code=400,
detail=f"Error processing text: {str(e)}")

0 comments on commit e9969fd

Please sign in to comment.