upload project source code
This commit is contained in:
@@ -0,0 +1 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
@@ -0,0 +1,537 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from fastapi import APIRouter, Depends, Path, Body, WebSocket, Form, UploadFile, File
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from app.common.response import StreamResponse, SuccessResponse
|
||||
from app.common.request import PaginationService
|
||||
from app.core.base_params import PaginationQueryParam
|
||||
from app.core.dependencies import AuthPermission
|
||||
from app.core.logger import log
|
||||
|
||||
from app.api.v1.module_system.auth.schema import AuthSchema
|
||||
from app.core.router_class import OperationLogRoute
|
||||
from .service import McpService, AIProviderService, EmbeddingConfigService, KnowledgeBaseService, AIModelConfigService, AIModelTrainingService, AIModelTestService
|
||||
from .schema import (
|
||||
McpCreateSchema, McpUpdateSchema, ChatQuerySchema, McpQueryParam,
|
||||
AIProviderCreateSchema, AIProviderUpdateSchema, AIProviderQueryParam,
|
||||
EmbeddingConfigCreateSchema, EmbeddingConfigUpdateSchema, EmbeddingConfigQueryParam,
|
||||
KnowledgeBaseCreateSchema, KnowledgeBaseUpdateSchema, KnowledgeBaseQueryParam,
|
||||
AIModelConfigUpdateSchema, AIModelConfigQueryParam, AIModelTrainingChatSchema,
|
||||
AIModelTestSchema
|
||||
)
|
||||
|
||||
|
||||
# Routers for the AI module. Every router uses OperationLogRoute so each
# request is recorded in the operation log; prefixes nest under /ai.
AIRouter = APIRouter(route_class=OperationLogRoute, prefix="/ai", tags=["MCP智能助手"])
AIProviderRouter = APIRouter(route_class=OperationLogRoute, prefix="/ai/provider", tags=["AI供应商配置"])
EmbeddingConfigRouter = APIRouter(route_class=OperationLogRoute, prefix="/ai/embedding", tags=["向量化配置"])
KnowledgeBaseRouter = APIRouter(route_class=OperationLogRoute, prefix="/ai/knowledge", tags=["知识库管理"])
AIModelConfigRouter = APIRouter(route_class=OperationLogRoute, prefix="/ai/model", tags=["AI模型配置"])
|
||||
|
||||
|
||||
@AIRouter.post("/chat", summary="智能对话", description="与MCP智能助手进行对话")
async def chat_controller(
    query: ChatQuerySchema,
    auth: AuthSchema = Depends(AuthPermission())
) -> StreamingResponse:
    """
    Stream a chat conversation with the MCP assistant.

    Args:
        query: Chat request payload.
        auth: Resolved authentication context.

    Returns:
        StreamingResponse: plain-text stream of response chunks.
    """
    if auth.user:
        user_name = auth.user.name
    else:
        user_name = "未知用户"
    log.info(f"用户 {user_name} 发起智能对话: {query.message[:50]}...")

    async def stream_chunks():
        try:
            async for piece in McpService.chat_query(query=query):
                if not piece:
                    continue
                # The service may yield str or bytes; normalise to bytes.
                if isinstance(piece, str):
                    yield piece.encode('utf-8')
                else:
                    yield piece
        except Exception as e:
            log.error(f"流式响应出错: {str(e)}")
            yield f"抱歉,处理您的请求时出现了错误: {str(e)}".encode('utf-8')

    return StreamResponse(stream_chunks(), media_type="text/plain; charset=utf-8")
|
||||
|
||||
|
||||
@AIRouter.get("/detail/{id}", summary="获取 MCP 服务器详情", description="获取 MCP 服务器详情")
async def detail_controller(
    id: int = Path(..., description="MCP ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Return the detail of a single MCP server identified by *id*."""
    detail = await McpService.detail_service(auth=auth, id=id)
    log.info(f"获取 MCP 服务器详情成功 {id}")
    return SuccessResponse(data=detail, msg="获取 MCP 服务器详情成功")
|
||||
|
||||
|
||||
@AIRouter.get("/list", summary="查询 MCP 服务器列表", description="查询 MCP 服务器列表")
async def list_controller(
    page: PaginationQueryParam = Depends(),
    search: McpQueryParam = Depends(),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """
    List MCP servers with pagination.

    Args:
        page: Pagination / ordering parameters.
        search: Filter parameters.
        auth: Resolved authentication context.

    Returns:
        JSONResponse: paginated MCP server list.
    """
    result_dict_list = await McpService.list_service(auth=auth, search=search, order_by=page.order_by)
    result_dict = await PaginationService.paginate(data_list=result_dict_list, page_no=page.page_no, page_size=page.page_size)
    # Fix: the original used an f-string with no placeholders.
    log.info("查询 MCP 服务器列表成功")
    return SuccessResponse(data=result_dict, msg="查询 MCP 服务器列表成功")
|
||||
|
||||
|
||||
@AIRouter.post("/create", summary="创建 MCP 服务器", description="创建 MCP 服务器")
async def create_controller(
    data: McpCreateSchema,
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Create a new MCP server from the submitted payload."""
    created = await McpService.create_service(auth=auth, data=data)
    log.info(f"创建 MCP 服务器成功: {created}")
    return SuccessResponse(data=created, msg="创建 MCP 服务器成功")
|
||||
|
||||
|
||||
@AIRouter.put("/update/{id}", summary="修改 MCP 服务器", description="修改 MCP 服务器")
async def update_controller(
    data: McpUpdateSchema,
    id: int = Path(..., description="MCP ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Update the MCP server identified by *id* with the given payload."""
    updated = await McpService.update_service(auth=auth, id=id, data=data)
    log.info(f"修改 MCP 服务器成功: {updated}")
    return SuccessResponse(data=updated, msg="修改 MCP 服务器成功")
|
||||
|
||||
|
||||
@AIRouter.delete("/delete", summary="删除 MCP 服务器", description="删除 MCP 服务器")
async def delete_controller(
    ids: list[int] = Body(..., description="ID列表"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Bulk-delete MCP servers by id list."""
    await McpService.delete_service(auth=auth, ids=ids)
    log.info(f"删除 MCP 服务器成功: {ids}")
    return SuccessResponse(msg="删除 MCP 服务器成功")
|
||||
|
||||
|
||||
@AIRouter.websocket("/ws/chat", name="WebSocket聊天")
async def websocket_chat_controller(
    websocket: WebSocket,
):
    """
    WebSocket chat endpoint.

    ws://127.0.0.1:8001/api/v1/ai/mcp/ws/chat

    Receives text frames and streams the assistant's reply back as
    text frames, one chunk per frame.
    """
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            # Stream the response chunks back to the client.
            try:
                async for chunk in McpService.chat_query(query=ChatQuerySchema(message=data)):
                    if chunk:
                        await websocket.send_text(chunk)
            except Exception as e:
                log.error(f"处理聊天查询出错: {str(e)}")
                await websocket.send_text(f"抱歉,处理您的请求时出现了错误: {str(e)}")
    except Exception as e:
        # NOTE(review): a normal client disconnect also raises here; consider
        # catching WebSocketDisconnect separately to avoid error-level noise.
        log.error(f"WebSocket聊天出错: {str(e)}")
    finally:
        # Fix: close() raises RuntimeError when the connection is already
        # closed (e.g. the client disconnected), so ignore cleanup failures.
        try:
            await websocket.close()
        except Exception:
            pass
|
||||
|
||||
|
||||
@AIModelConfigRouter.post("/test", summary="起名测试")
async def naming_test_controller(
    test_data: AIModelTestSchema,
) -> JSONResponse:
    """
    Naming test endpoint (for mini-program / external callers).

    Reads the model configuration and training info, streams from the AI
    provider internally, and returns the combined response in one shot.
    Read-only: no conversation records are saved.

    Args:
        test_data: carries model_type (e.g. enterprise_naming/personal_naming)
            and the user's input text.

    Returns:
        JSONResponse: the AI's full response text in *data*.
    """
    log.info(f"[起名测试] 收到请求: model_type={test_data.model_type}, text={test_data.text[:50]}...")

    answer = await AIModelTestService.test_naming(
        model_type=test_data.model_type,
        text=test_data.text
    )

    log.info(f"[起名测试] 响应完成,长度: {len(answer)}")
    return SuccessResponse(data=answer, msg="测试成功")
|
||||
|
||||
|
||||
# ============== AI供应商配置 ==============
|
||||
|
||||
@AIProviderRouter.get("/detail/{id}", summary="获取AI供应商详情")
async def provider_detail_controller(
    id: int = Path(..., description="AI供应商ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Return one AI provider configuration by primary key."""
    detail = await AIProviderService.detail_service(auth=auth, id=id)
    return SuccessResponse(data=detail, msg="获取AI供应商详情成功")
|
||||
|
||||
|
||||
@AIProviderRouter.get("/list", summary="查询AI供应商列表")
async def provider_list_controller(
    page: PaginationQueryParam = Depends(),
    search: AIProviderQueryParam = Depends(),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """List AI providers matching *search*, paginated by *page*."""
    rows = await AIProviderService.list_service(auth=auth, search=search, order_by=page.order_by)
    paged = await PaginationService.paginate(data_list=rows, page_no=page.page_no, page_size=page.page_size)
    return SuccessResponse(data=paged, msg="查询AI供应商列表成功")
|
||||
|
||||
|
||||
@AIProviderRouter.post("/create", summary="创建AI供应商")
async def provider_create_controller(
    data: AIProviderCreateSchema,
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Create a new AI provider configuration."""
    created = await AIProviderService.create_service(auth=auth, data=data)
    log.info(f"创建AI供应商成功: {created}")
    return SuccessResponse(data=created, msg="创建AI供应商成功")
|
||||
|
||||
|
||||
@AIProviderRouter.put("/update/{id}", summary="修改AI供应商")
async def provider_update_controller(
    data: AIProviderUpdateSchema,
    id: int = Path(..., description="AI供应商ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Update the AI provider identified by *id*."""
    updated = await AIProviderService.update_service(auth=auth, id=id, data=data)
    log.info(f"修改AI供应商成功: {updated}")
    return SuccessResponse(data=updated, msg="修改AI供应商成功")
|
||||
|
||||
|
||||
@AIProviderRouter.delete("/delete", summary="删除AI供应商")
async def provider_delete_controller(
    ids: list[int] = Body(..., description="ID列表"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Bulk-delete AI providers by id list."""
    await AIProviderService.delete_service(auth=auth, ids=ids)
    log.info(f"删除AI供应商成功: {ids}")
    return SuccessResponse(msg="删除AI供应商成功")
|
||||
|
||||
|
||||
# ============== 向量化配置 ==============
|
||||
|
||||
@EmbeddingConfigRouter.get("/detail/{id}", summary="获取向量化配置详情")
async def embedding_detail_controller(
    id: int = Path(..., description="向量化配置ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Return one embedding configuration by primary key."""
    detail = await EmbeddingConfigService.detail_service(auth=auth, id=id)
    return SuccessResponse(data=detail, msg="获取向量化配置详情成功")
|
||||
|
||||
|
||||
@EmbeddingConfigRouter.get("/list", summary="查询向量化配置列表")
async def embedding_list_controller(
    page: PaginationQueryParam = Depends(),
    search: EmbeddingConfigQueryParam = Depends(),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """List embedding configurations matching *search*, paginated by *page*."""
    rows = await EmbeddingConfigService.list_service(auth=auth, search=search, order_by=page.order_by)
    paged = await PaginationService.paginate(data_list=rows, page_no=page.page_no, page_size=page.page_size)
    return SuccessResponse(data=paged, msg="查询向量化配置列表成功")
|
||||
|
||||
|
||||
@EmbeddingConfigRouter.post("/create", summary="创建向量化配置")
async def embedding_create_controller(
    data: EmbeddingConfigCreateSchema,
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Create a new embedding configuration."""
    created = await EmbeddingConfigService.create_service(auth=auth, data=data)
    log.info(f"创建向量化配置成功: {created}")
    return SuccessResponse(data=created, msg="创建向量化配置成功")
|
||||
|
||||
|
||||
@EmbeddingConfigRouter.put("/update/{id}", summary="修改向量化配置")
async def embedding_update_controller(
    data: EmbeddingConfigUpdateSchema,
    id: int = Path(..., description="向量化配置ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Update the embedding configuration identified by *id*."""
    updated = await EmbeddingConfigService.update_service(auth=auth, id=id, data=data)
    log.info(f"修改向量化配置成功: {updated}")
    return SuccessResponse(data=updated, msg="修改向量化配置成功")
|
||||
|
||||
|
||||
@EmbeddingConfigRouter.delete("/delete", summary="删除向量化配置")
async def embedding_delete_controller(
    ids: list[int] = Body(..., description="ID列表"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Bulk-delete embedding configurations by id list."""
    await EmbeddingConfigService.delete_service(auth=auth, ids=ids)
    log.info(f"删除向量化配置成功: {ids}")
    return SuccessResponse(msg="删除向量化配置成功")
|
||||
|
||||
|
||||
# ============== 知识库管理 ==============
|
||||
|
||||
@KnowledgeBaseRouter.get("/detail/{id}", summary="获取知识库详情")
async def knowledge_detail_controller(
    id: int = Path(..., description="知识库ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Return one knowledge base by primary key."""
    detail = await KnowledgeBaseService.detail_service(auth=auth, id=id)
    return SuccessResponse(data=detail, msg="获取知识库详情成功")
|
||||
|
||||
|
||||
@KnowledgeBaseRouter.get("/list", summary="查询知识库列表")
async def knowledge_list_controller(
    page: PaginationQueryParam = Depends(),
    search: KnowledgeBaseQueryParam = Depends(),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """List knowledge bases matching *search*, paginated by *page*."""
    rows = await KnowledgeBaseService.list_service(auth=auth, search=search, order_by=page.order_by)
    paged = await PaginationService.paginate(data_list=rows, page_no=page.page_no, page_size=page.page_size)
    return SuccessResponse(data=paged, msg="查询知识库列表成功")
|
||||
|
||||
|
||||
@KnowledgeBaseRouter.post("/create", summary="创建知识库")
async def knowledge_create_controller(
    name: str = Form(..., max_length=100, description="知识库名称"),
    embedding_config_id: int | None = Form(None, description="向量化配置ID"),
    description: str | None = Form(None, max_length=255, description="备注"),
    files: list[UploadFile] = File(default=[], description="上传的文件列表"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Create a knowledge base from multipart form fields plus optional files."""
    # Assemble the schema object from the individual form fields.
    payload = KnowledgeBaseCreateSchema(
        name=name,
        embedding_config_id=embedding_config_id,
        description=description,
    )
    created = await KnowledgeBaseService.create_service(auth=auth, data=payload, files=files)
    log.info(f"创建知识库成功: {created}")
    return SuccessResponse(data=created, msg="创建知识库成功")
|
||||
|
||||
|
||||
@KnowledgeBaseRouter.put("/update/{id}", summary="修改知识库")
async def knowledge_update_controller(
    data: KnowledgeBaseUpdateSchema,
    id: int = Path(..., description="知识库ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Update the knowledge base identified by *id*."""
    updated = await KnowledgeBaseService.update_service(auth=auth, id=id, data=data)
    log.info(f"修改知识库成功: {updated}")
    return SuccessResponse(data=updated, msg="修改知识库成功")
|
||||
|
||||
|
||||
@KnowledgeBaseRouter.delete("/delete", summary="删除知识库")
async def knowledge_delete_controller(
    ids: list[int] = Body(..., description="ID列表"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Bulk-delete knowledge bases by id list."""
    await KnowledgeBaseService.delete_service(auth=auth, ids=ids)
    log.info(f"删除知识库成功: {ids}")
    return SuccessResponse(msg="删除知识库成功")
|
||||
|
||||
|
||||
@KnowledgeBaseRouter.post("/retry/{id}", summary="重试向量化")
async def knowledge_retry_controller(
    id: int = Path(..., description="知识库ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Re-trigger vectorisation for the knowledge base identified by *id*."""
    retry_result = await KnowledgeBaseService.retry_service(auth=auth, id=id)
    log.info(f"重试向量化成功: {id}")
    return SuccessResponse(data=retry_result, msg="已启动重新向量化")
|
||||
|
||||
|
||||
# ============== AI模型配置 ==============
|
||||
|
||||
@AIModelConfigRouter.get("/detail/{model_type}", summary="获取AI模型配置详情")
async def model_config_detail_controller(
    model_type: str = Path(..., description="模型类型(naming/renaming/scoring/report)"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Return the AI model configuration for the given *model_type*."""
    detail = await AIModelConfigService.detail_service(auth=auth, model_type=model_type)
    return SuccessResponse(data=detail, msg="获取AI模型配置详情成功")
|
||||
|
||||
|
||||
@AIModelConfigRouter.get("/list", summary="查询AI模型配置列表")
async def model_config_list_controller(
    page: PaginationQueryParam = Depends(),
    search: AIModelConfigQueryParam = Depends(),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """List AI model configurations matching *search*, paginated by *page*."""
    rows = await AIModelConfigService.list_service(auth=auth, search=search, order_by=page.order_by)
    paged = await PaginationService.paginate(data_list=rows, page_no=page.page_no, page_size=page.page_size)
    return SuccessResponse(data=paged, msg="查询AI模型配置列表成功")
|
||||
|
||||
|
||||
@AIModelConfigRouter.put("/update/{model_type}", summary="更新AI模型配置")
async def model_config_update_controller(
    data: AIModelConfigUpdateSchema,
    model_type: str = Path(..., description="模型类型(naming/renaming/scoring/report)"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Update the AI model configuration for the given *model_type*."""
    updated = await AIModelConfigService.update_service(auth=auth, model_type=model_type, data=data)
    log.info(f"更新AI模型配置成功: {model_type}")
    return SuccessResponse(data=updated, msg="更新AI模型配置成功")
|
||||
|
||||
|
||||
@AIModelConfigRouter.get("/available-models/{provider_id}", summary="获取可用模型列表")
async def available_models_controller(
    provider_id: int = Path(..., description="AI供应商ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Return the models offered by the provider identified by *provider_id*."""
    models = await AIModelConfigService.get_available_models(auth=auth, provider_id=provider_id)
    return SuccessResponse(data=models, msg="获取可用模型列表成功")
|
||||
|
||||
|
||||
# ============== AI模型训练对话 ==============
|
||||
|
||||
@AIModelConfigRouter.get("/messages/{model_type}", summary="获取训练对话记录")
async def training_messages_controller(
    model_type: str = Path(..., description="模型类型(naming/renaming/scoring/report)"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Return the training conversation history for *model_type*."""
    messages = await AIModelTrainingService.get_messages_service(auth=auth, model_type=model_type)
    return SuccessResponse(data=messages, msg="获取训练对话记录成功")
|
||||
|
||||
|
||||
@AIModelConfigRouter.delete("/message/{message_id}", summary="删除单条训练对话记录")
async def delete_training_message_controller(
    message_id: int = Path(..., description="对话记录ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Delete a single training conversation record by id."""
    await AIModelTrainingService.delete_message_service(auth=auth, message_id=message_id)
    log.info(f"删除训练对话记录成功: {message_id}")
    return SuccessResponse(msg="删除训练对话记录成功")
|
||||
|
||||
|
||||
@AIModelConfigRouter.delete("/messages/{model_type}", summary="清空训练对话记录")
async def clear_training_messages_controller(
    model_type: str = Path(..., description="模型类型(naming/renaming/scoring/report)"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Clear all training conversation records for *model_type*."""
    await AIModelTrainingService.clear_messages_service(auth=auth, model_type=model_type)
    log.info(f"清空训练对话记录成功: {model_type}")
    return SuccessResponse(msg="清空训练对话记录成功")
|
||||
|
||||
|
||||
@AIModelConfigRouter.post("/chat", summary="训练对话")
async def training_chat_controller(
    chat_data: AIModelTrainingChatSchema,
    auth: AuthSchema = Depends(AuthPermission())
) -> StreamingResponse:
    """Training chat endpoint; streams the model output as plain text."""
    log.info(f"[训练对话] 收到请求: model_type={chat_data.model_type}, message={chat_data.message[:50]}...")

    async def stream_reply():
        log.info("[训练对话] 开始生成响应")
        emitted = 0
        try:
            async for piece in AIModelTrainingService.chat_stream(auth=auth, chat_data=chat_data):
                if not piece:
                    continue
                emitted += 1
                # Normalise str chunks to bytes for the streaming body.
                if isinstance(piece, str):
                    yield piece.encode('utf-8')
                else:
                    yield piece
            log.info(f"[训练对话] 响应生成完成,共 {emitted} 个 chunk")
        except Exception as e:
            log.error(f"训练对话出错: {str(e)}")
            yield f"处理您的请求时出现了错误: {str(e)}".encode('utf-8')

    return StreamResponse(stream_reply(), media_type="text/plain; charset=utf-8")
|
||||
|
||||
|
||||
@AIModelConfigRouter.websocket("/ws/chat")
async def training_websocket_chat_controller(
    websocket: WebSocket,
):
    """
    Training-chat WebSocket endpoint.

    Receives JSON frames of the form
    {model_type, message, config_changed, config_data} and streams the
    model's reply back as text frames.
    """
    # Fix: removed the unused local `verify_token` import and the redundant
    # local AuthSchema import (already imported at module level).
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_json()
            model_type = data.get('model_type', 'naming')
            message = data.get('message', '')
            config_changed = data.get('config_changed', False)
            config_data = data.get('config_data')

            # Build the chat payload for the training service.
            chat_data = AIModelTrainingChatSchema(
                model_type=model_type,
                message=message,
                config_changed=config_changed,
                config_data=AIModelConfigUpdateSchema(**config_data) if config_data else None
            )

            # Simplified auth: this WebSocket has no real authentication flow.
            # TODO(review): implement token verification before production use.
            auth = AuthSchema()

            # Stream the response chunks back to the client.
            try:
                async for chunk in AIModelTrainingService.chat_stream(auth=auth, chat_data=chat_data):
                    if chunk:
                        await websocket.send_text(chunk)
            except Exception as e:
                log.error(f"处理训练对话出错: {str(e)}")
                await websocket.send_text(f"处理您的请求时出现了错误: {str(e)}")
    except Exception as e:
        # NOTE(review): a normal client disconnect also raises here; consider
        # handling WebSocketDisconnect separately to avoid error-level noise.
        log.error(f"WebSocket训练对话出错: {str(e)}")
    finally:
        # Fix: close() raises when the socket is already closed (e.g. the
        # client disconnected), so swallow cleanup failures.
        try:
            await websocket.close()
        except Exception:
            pass
|
||||
@@ -0,0 +1,270 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Sequence, Any
|
||||
|
||||
from app.core.base_crud import CRUDBase
|
||||
|
||||
from app.api.v1.module_system.auth.schema import AuthSchema
|
||||
from .model import McpModel, AIProviderModel, EmbeddingConfigModel, KnowledgeBaseModel, AIModelConfigModel, AIModelTrainingMessageModel
|
||||
from .schema import (
|
||||
McpCreateSchema, McpUpdateSchema,
|
||||
AIProviderCreateSchema, AIProviderUpdateSchema,
|
||||
EmbeddingConfigCreateSchema, EmbeddingConfigUpdateSchema,
|
||||
KnowledgeBaseCreateSchema, KnowledgeBaseUpdateSchema,
|
||||
AIModelConfigCreateSchema, AIModelConfigUpdateSchema,
|
||||
AIModelTrainingMessageCreateSchema
|
||||
)
|
||||
|
||||
|
||||
class McpCRUD(CRUDBase[McpModel, McpCreateSchema, McpUpdateSchema]):
    """Data-access layer for MCP servers."""

    def __init__(self, auth: AuthSchema) -> None:
        """Bind this CRUD helper to the caller's auth context."""
        self.auth = auth
        super().__init__(model=McpModel, auth=auth)

    async def get_by_id_crud(self, id: int, preload: list[str | Any] | None = None) -> McpModel | None:
        """Fetch one MCP server by primary key; None when absent.

        Args:
            id: MCP server id.
            preload: relations to eager-load; model defaults when omitted.
        """
        return await self.get(id=id, preload=preload)

    async def get_by_name_crud(self, name: str, preload: list[str | Any] | None = None) -> McpModel | None:
        """Fetch one MCP server by name; None when absent."""
        return await self.get(name=name, preload=preload)

    async def get_list_crud(self, search: dict | None = None, order_by: list[dict[str, str]] | None = None, preload: list[str | Any] | None = None) -> Sequence[McpModel]:
        """List MCP servers matching *search*; ordered by id ascending by default."""
        criteria = search if search else {}
        ordering = order_by if order_by else [{'id': 'asc'}]
        return await self.list(search=criteria, order_by=ordering, preload=preload)

    async def create_crud(self, data: McpCreateSchema) -> McpModel | None:
        """Insert a new MCP server row and return it."""
        return await self.create(data=data)

    async def update_crud(self, id: int, data: McpUpdateSchema) -> McpModel | None:
        """Update the MCP server row identified by *id* and return it."""
        return await self.update(id=id, data=data)

    async def delete_crud(self, ids: list[int]) -> None:
        """Bulk-delete MCP servers by id list."""
        return await self.delete(ids=ids)
|
||||
|
||||
|
||||
class AIProviderCRUD(CRUDBase[AIProviderModel, AIProviderCreateSchema, AIProviderUpdateSchema]):
    """Data-access layer for AI provider configurations."""

    def __init__(self, auth: AuthSchema) -> None:
        """Bind this CRUD helper to the caller's auth context."""
        self.auth = auth
        super().__init__(model=AIProviderModel, auth=auth)

    async def get_by_id_crud(self, id: int, preload: list[str | Any] | None = None) -> AIProviderModel | None:
        """Fetch one provider by primary key; None when absent."""
        return await self.get(id=id, preload=preload)

    async def get_by_name_crud(self, name: str, preload: list[str | Any] | None = None) -> AIProviderModel | None:
        """Fetch one provider by name; None when absent."""
        return await self.get(name=name, preload=preload)

    async def get_list_crud(self, search: dict | None = None, order_by: list[dict[str, str]] | None = None, preload: list[str | Any] | None = None) -> Sequence[AIProviderModel]:
        """List providers matching *search*; ordered by id ascending by default."""
        criteria = search if search else {}
        ordering = order_by if order_by else [{'id': 'asc'}]
        return await self.list(search=criteria, order_by=ordering, preload=preload)

    async def create_crud(self, data: AIProviderCreateSchema) -> AIProviderModel | None:
        """Insert a new provider row and return it."""
        return await self.create(data=data)

    async def update_crud(self, id: int, data: AIProviderUpdateSchema) -> AIProviderModel | None:
        """Update the provider row identified by *id* and return it."""
        return await self.update(id=id, data=data)

    async def delete_crud(self, ids: list[int]) -> None:
        """Bulk-delete providers by id list."""
        return await self.delete(ids=ids)

    async def clear_default_crud(self) -> None:
        """Unset the default flag on every provider currently marked default."""
        defaults = await self.list(search={'is_default': 1})
        for record in defaults:
            payload = AIProviderUpdateSchema(
                name=record.name,
                provider_type=record.provider_type,
                base_url=record.base_url,
                api_key=record.api_key,
                is_default=0,
                description=record.description
            )
            await self.update(id=record.id, data=payload)
|
||||
|
||||
|
||||
class EmbeddingConfigCRUD(CRUDBase[EmbeddingConfigModel, EmbeddingConfigCreateSchema, EmbeddingConfigUpdateSchema]):
    """Data-access layer for embedding (vectorisation) configurations."""

    def __init__(self, auth: AuthSchema) -> None:
        """Bind this CRUD helper to the caller's auth context."""
        self.auth = auth
        super().__init__(model=EmbeddingConfigModel, auth=auth)

    async def get_by_id_crud(self, id: int, preload: list[str | Any] | None = None) -> EmbeddingConfigModel | None:
        """Fetch one embedding config by primary key; None when absent."""
        return await self.get(id=id, preload=preload)

    async def get_by_name_crud(self, name: str, preload: list[str | Any] | None = None) -> EmbeddingConfigModel | None:
        """Fetch one embedding config by name; None when absent."""
        return await self.get(name=name, preload=preload)

    async def get_list_crud(self, search: dict | None = None, order_by: list[dict[str, str]] | None = None, preload: list[str | Any] | None = None) -> Sequence[EmbeddingConfigModel]:
        """List embedding configs matching *search*; id ascending by default."""
        criteria = search if search else {}
        ordering = order_by if order_by else [{'id': 'asc'}]
        return await self.list(search=criteria, order_by=ordering, preload=preload)

    async def create_crud(self, data: EmbeddingConfigCreateSchema) -> EmbeddingConfigModel | None:
        """Insert a new embedding config row and return it."""
        return await self.create(data=data)

    async def update_crud(self, id: int, data: EmbeddingConfigUpdateSchema) -> EmbeddingConfigModel | None:
        """Update the embedding config row identified by *id* and return it."""
        return await self.update(id=id, data=data)

    async def delete_crud(self, ids: list[int]) -> None:
        """Bulk-delete embedding configs by id list."""
        return await self.delete(ids=ids)

    async def clear_default_crud(self) -> None:
        """Unset the default flag on every config currently marked default."""
        defaults = await self.list(search={'is_default': 1})
        for record in defaults:
            payload = EmbeddingConfigUpdateSchema(
                name=record.name,
                embedding_type=record.embedding_type,
                model_name=record.model_name,
                base_url=record.base_url,
                api_key=record.api_key,
                is_default=0,
                description=record.description
            )
            await self.update(id=record.id, data=payload)
|
||||
|
||||
|
||||
class KnowledgeBaseCRUD(CRUDBase[KnowledgeBaseModel, KnowledgeBaseCreateSchema, KnowledgeBaseUpdateSchema]):
    """Data-access layer for knowledge bases."""

    def __init__(self, auth: AuthSchema) -> None:
        """Bind this CRUD helper to the caller's auth context."""
        self.auth = auth
        super().__init__(model=KnowledgeBaseModel, auth=auth)

    async def get_by_id_crud(self, id: int, preload: list[str | Any] | None = None) -> KnowledgeBaseModel | None:
        """Fetch one knowledge base by primary key; None when absent."""
        return await self.get(id=id, preload=preload)

    async def get_by_name_crud(self, name: str, preload: list[str | Any] | None = None) -> KnowledgeBaseModel | None:
        """Fetch one knowledge base by name; None when absent."""
        return await self.get(name=name, preload=preload)

    async def get_by_collection_crud(self, collection_name: str, preload: list[str | Any] | None = None) -> KnowledgeBaseModel | None:
        """Fetch one knowledge base by its vector-store collection name."""
        return await self.get(collection_name=collection_name, preload=preload)

    async def get_list_crud(self, search: dict | None = None, order_by: list[dict[str, str]] | None = None, preload: list[str | Any] | None = None) -> Sequence[KnowledgeBaseModel]:
        """List knowledge bases matching *search*; id descending by default."""
        criteria = search if search else {}
        ordering = order_by if order_by else [{'id': 'desc'}]
        return await self.list(search=criteria, order_by=ordering, preload=preload)

    async def create_crud(self, data: dict) -> KnowledgeBaseModel | None:
        """Insert a new knowledge base; takes a dict because extra fields are needed."""
        return await self.create(data=data)

    async def update_crud(self, id: int, data: KnowledgeBaseUpdateSchema | dict) -> KnowledgeBaseModel | None:
        """Update the knowledge base identified by *id* and return it."""
        return await self.update(id=id, data=data)

    async def delete_crud(self, ids: list[int]) -> None:
        """Bulk-delete knowledge bases by id list."""
        return await self.delete(ids=ids)
|
||||
|
||||
|
||||
class AIModelConfigCRUD(CRUDBase[AIModelConfigModel, AIModelConfigCreateSchema, AIModelConfigUpdateSchema]):
    """Data-access layer for AI model configurations."""

    def __init__(self, auth: AuthSchema) -> None:
        """Bind this CRUD helper to the caller's auth context."""
        self.auth = auth
        super().__init__(model=AIModelConfigModel, auth=auth)

    async def get_by_id_crud(self, id: int, preload: list[str | Any] | None = None) -> AIModelConfigModel | None:
        """Fetch one model config by primary key; None when absent."""
        return await self.get(id=id, preload=preload)

    async def get_by_model_type_crud(self, model_type: str, preload: list[str | Any] | None = None) -> AIModelConfigModel | None:
        """Fetch one model config by its model_type key; None when absent."""
        return await self.get(model_type=model_type, preload=preload)

    async def get_list_crud(self, search: dict | None = None, order_by: list[dict[str, str]] | None = None, preload: list[str | Any] | None = None) -> Sequence[AIModelConfigModel]:
        """List model configs matching *search*; id ascending by default."""
        criteria = search if search else {}
        ordering = order_by if order_by else [{'id': 'asc'}]
        return await self.list(search=criteria, order_by=ordering, preload=preload)

    async def create_crud(self, data: AIModelConfigCreateSchema | dict) -> AIModelConfigModel | None:
        """Insert a new model config row and return it."""
        return await self.create(data=data)

    async def update_crud(self, id: int, data: AIModelConfigUpdateSchema | dict) -> AIModelConfigModel | None:
        """Update the model config row identified by *id* and return it."""
        return await self.update(id=id, data=data)

    async def delete_crud(self, ids: list[int]) -> None:
        """Bulk-delete model configs by id list."""
        return await self.delete(ids=ids)
|
||||
|
||||
|
||||
class AIModelTrainingMessageCRUD(CRUDBase[AIModelTrainingMessageModel, AIModelTrainingMessageCreateSchema, AIModelTrainingMessageCreateSchema]):
    """Data-access layer for AI model training chat messages."""

    def __init__(self, auth: AuthSchema) -> None:
        """Bind the auth context and initialise the generic CRUD base."""
        self.auth = auth
        super().__init__(model=AIModelTrainingMessageModel, auth=auth)

    async def get_by_id_crud(self, id: int, preload: list[str | Any] | None = None) -> AIModelTrainingMessageModel | None:
        """Fetch one training message by primary key."""
        return await self.get(id=id, preload=preload)

    async def get_list_by_config_crud(self, model_config_id: int, preload: list[str | Any] | None = None) -> Sequence[AIModelTrainingMessageModel]:
        """Return every training message of one model config, oldest first."""
        return await self.list(search={'model_config_id': model_config_id}, order_by=[{'id': 'asc'}], preload=preload)

    async def create_crud(self, data: AIModelTrainingMessageCreateSchema | dict) -> AIModelTrainingMessageModel | None:
        """Insert one training message."""
        return await self.create(data=data)

    async def delete_crud(self, ids: list[int]) -> None:
        """Delete training messages by id list."""
        return await self.delete(ids=ids)

    async def delete_by_config_crud(self, model_config_id: int) -> None:
        """Delete every training message belonging to one model config."""
        rows = await self.list(search={'model_config_id': model_config_id})
        if not rows:
            return
        await self.delete(ids=[row.id for row in rows])
|
||||
@@ -0,0 +1,155 @@
|
||||
'''
|
||||
Author: caoziyuan ziyuan.cao@zhuying.com
|
||||
Date: 2025-12-22 17:42:10
|
||||
LastEditors: caoziyuan ziyuan.cao@zhuying.com
|
||||
LastEditTime: 2025-12-22 18:03:28
|
||||
FilePath: \naming-backend\app\api\v1\module_application\ai\model.py
|
||||
Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE
|
||||
'''
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from sqlalchemy import JSON, String, Integer, Text, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.core.base_model import ModelMixin, UserMixin
|
||||
|
||||
|
||||
class McpModel(ModelMixin, UserMixin):
    """
    MCP server table.

    MCP type:
    - 0: stdio (standard input/output)
    - 1: sse (Server-Sent Events)
    """
    __tablename__: str = 'app_ai_mcp'
    __table_args__: dict[str, str] = ({'comment': 'MCP 服务器表'})
    __loader_options__: list[str] = ["created_by", "updated_by"]

    # Display name of the MCP server.
    name: Mapped[str] = mapped_column(String(50), comment='MCP 名称')
    # Transport type: 0 = stdio, 1 = sse.
    type: Mapped[int] = mapped_column(Integer, default=0, comment='MCP 类型(0:stdio 1:sse)')
    # Remote SSE endpoint — presumably only used when type == 1; confirm in service layer.
    url: Mapped[str | None] = mapped_column(String(255), default=None, comment='远程 SSE 地址')
    # Command used to launch a stdio MCP server process.
    command: Mapped[str | None] = mapped_column(String(255), default=None, comment='MCP 命令')
    # Command arguments; comma-separated per McpCreateSchema's field description.
    args: Mapped[str | None] = mapped_column(String(255), default=None, comment='MCP 命令参数')
    # Environment variables passed to the MCP process.
    env: Mapped[dict[str, str] | None] = mapped_column(JSON(), default=None, comment='MCP 环境变量')
|
||||
|
||||
|
||||
class AIProviderModel(ModelMixin, UserMixin):
    """
    AI provider configuration table.

    Stores each AI service provider's endpoint and API key.
    """
    __tablename__: str = 'app_ai_provider'
    __table_args__: dict[str, str] = ({'comment': 'AI供应商配置表'})
    __loader_options__: list[str] = ["created_by", "updated_by"]

    # Display name of the provider.
    name: Mapped[str] = mapped_column(String(50), nullable=False, comment='供应商名称')
    # Provider kind (openai / deepseek / anthropic / gemini / qwen, ...).
    provider_type: Mapped[str] = mapped_column(String(50), nullable=False, comment='供应商类型(openai/deepseek/anthropic/gemini/qwen等)')
    # Base URL of the provider's API.
    base_url: Mapped[str] = mapped_column(String(255), nullable=False, comment='接口地址BaseURL')
    # NOTE(review): API key appears to be stored in plain text — consider encrypting at rest.
    api_key: Mapped[str] = mapped_column(String(255), nullable=False, comment='API Key')
    # Default-provider flag: 0 = no, 1 = yes.
    is_default: Mapped[int] = mapped_column(Integer, default=0, comment='是否默认供应商(0:否 1:是)')
|
||||
|
||||
|
||||
class EmbeddingConfigModel(ModelMixin, UserMixin):
    """
    Knowledge-base embedding configuration table.

    Supports both local and remote embedding services.
    """
    __tablename__: str = 'app_ai_embedding_config'
    __table_args__: dict[str, str] = ({'comment': '知识库向量化配置表'})
    __loader_options__: list[str] = ["created_by", "updated_by"]

    # Display name of the configuration.
    name: Mapped[str] = mapped_column(String(50), nullable=False, comment='配置名称')
    # Embedding mode: 0 = local model, 1 = remote service.
    embedding_type: Mapped[int] = mapped_column(Integer, default=0, comment='向量化类型(0:本地 1:远程)')
    # Embedding model name.
    model_name: Mapped[str] = mapped_column(String(100), nullable=False, comment='Embedding模型名称')
    # Remote endpoint — required when embedding_type == 1 (enforced in EmbeddingUtil, not here).
    base_url: Mapped[str | None] = mapped_column(String(255), default=None, comment='远程接口地址(远程模式必填)')
    # Remote API key — required when embedding_type == 1.
    api_key: Mapped[str | None] = mapped_column(String(255), default=None, comment='远程API Key(远程模式必填)')
    # Default-config flag: 0 = no, 1 = yes.
    is_default: Mapped[int] = mapped_column(Integer, default=0, comment='是否默认配置(0:否 1:是)')
|
||||
|
||||
|
||||
class KnowledgeBaseModel(ModelMixin, UserMixin):
    """
    Knowledge base table.

    Stores knowledge-base records; each links to an embedding configuration.
    """
    __tablename__: str = 'app_ai_knowledge_base'
    __table_args__: dict[str, str] = ({'comment': '知识库表'})
    __loader_options__: list[str] = ["created_by", "updated_by", "embedding_config"]

    # Knowledge base display name.
    name: Mapped[str] = mapped_column(String(100), nullable=False, comment='知识库名称')
    # FK to the embedding configuration; nulled out if the config is deleted.
    embedding_config_id: Mapped[int | None] = mapped_column(
        Integer,
        ForeignKey('app_ai_embedding_config.id', ondelete="SET NULL", onupdate="CASCADE"),
        default=None,
        nullable=True,
        index=True,
        comment='向量化配置ID'
    )
    # Name of the backing ChromaDB collection.
    collection_name: Mapped[str] = mapped_column(String(100), nullable=False, comment='ChromaDB集合名称')
    # Number of ingested source documents.
    document_count: Mapped[int] = mapped_column(Integer, default=0, comment='文档数量')
    # Number of stored vectors.
    vector_count: Mapped[int] = mapped_column(Integer, default=0, comment='向量数量')
    # Processing state: 0 = pending, 1 = processing, 2 = done, 3 = failed.
    kb_status: Mapped[int] = mapped_column(Integer, default=0, comment='知识库状态(0:待处理 1:处理中 2:已完成 3:处理失败)')
    # Last processing error — presumably populated when kb_status == 3; confirm in service layer.
    error_message: Mapped[str | None] = mapped_column(Text, default=None, comment='错误信息')
    # Paths of the uploaded source files.
    file_paths: Mapped[list[str] | None] = mapped_column(JSON, default=None, comment='文件路径列表')

    # Eagerly-loaded link to the embedding configuration.
    embedding_config: Mapped["EmbeddingConfigModel | None"] = relationship(
        "EmbeddingConfigModel",
        lazy="selectin",
        foreign_keys=[embedding_config_id],
        uselist=False
    )
|
||||
|
||||
|
||||
class AIModelConfigModel(ModelMixin, UserMixin):
    """
    AI model configuration table.

    Stores per-feature model configuration. Model types include:
    - enterprise_naming, enterprise_renaming, enterprise_scoring, enterprise_scoring_trial
    - personal_naming, personal_renaming, personal_scoring, personal_scoring_trial
    (see AI_MODEL_TYPES in the schema module for the full set)
    """
    __tablename__: str = 'app_ai_model_config'
    __table_args__: dict[str, str] = ({'comment': 'AI模型配置表'})
    __loader_options__: list[str] = ["created_by", "updated_by", "provider", "knowledge_bases"]

    # Feature key this config belongs to; unique, so one config per feature.
    model_type: Mapped[str] = mapped_column(String(50), nullable=False, unique=True, comment='模型类型(enterprise_naming/enterprise_renaming/enterprise_scoring/enterprise_scoring_trial/personal_naming/personal_renaming/personal_scoring/personal_scoring_trial)')
    # LLM model name used for this feature.
    model_name: Mapped[str | None] = mapped_column(String(100), default=None, comment='使用的模型名称')
    # FK to the provider; nulled out if the provider is deleted.
    provider_id: Mapped[int | None] = mapped_column(
        Integer,
        ForeignKey('app_ai_provider.id', ondelete="SET NULL", onupdate="CASCADE"),
        default=None,
        nullable=True,
        index=True,
        comment='AI供应商ID'
    )
    # System prompt injected into conversations for this feature.
    system_prompt: Mapped[str | None] = mapped_column(Text, default=None, comment='系统提示词')
    # Sampling temperature (0-2).
    temperature: Mapped[float] = mapped_column(default=1.0, comment='模型温度(0-2)')
    # Linked knowledge-base ids, stored as a JSON list (no FK constraint).
    knowledge_base_ids: Mapped[list[int] | None] = mapped_column(JSON, default=None, comment='关联的知识库ID列表')

    # Eagerly-loaded link to the provider.
    # NOTE(review): __loader_options__ lists "knowledge_bases" but no such
    # relationship is defined in this class — verify against the base model.
    provider: Mapped["AIProviderModel | None"] = relationship(
        "AIProviderModel",
        lazy="selectin",
        foreign_keys=[provider_id],
        uselist=False
    )
|
||||
|
||||
|
||||
class AIModelTrainingMessageModel(ModelMixin, UserMixin):
    """
    AI model training conversation table.

    Stores the history of training chat messages per model config.
    """
    __tablename__: str = 'app_ai_model_training_message'
    __table_args__: dict[str, str] = ({'comment': 'AI模型训练对话记录表'})
    __loader_options__: list[str] = ["created_by", "updated_by"]

    # FK to the model configuration; rows are removed with their config (CASCADE).
    model_config_id: Mapped[int] = mapped_column(
        Integer,
        ForeignKey('app_ai_model_config.id', ondelete="CASCADE", onupdate="CASCADE"),
        nullable=False,
        index=True,
        comment='模型配置ID'
    )
    # Chat role: "user" or "assistant".
    role: Mapped[str] = mapped_column(String(20), nullable=False, comment='角色(user/assistant)')
    # Message body.
    content: Mapped[str] = mapped_column(Text, nullable=False, comment='消息内容')
|
||||
@@ -0,0 +1,307 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from pydantic import ConfigDict, Field, HttpUrl, BaseModel
|
||||
from fastapi import Query
|
||||
|
||||
from app.core.base_schema import BaseSchema
|
||||
from app.common.enums import McpLLMProvider, EmbeddingType, KnowledgeBaseStatus
|
||||
from app.core.base_schema import BaseSchema, UserBySchema
|
||||
from app.common.enums import McpType
|
||||
from app.core.validator import DateTimeStr
|
||||
|
||||
|
||||
class ChatQuerySchema(BaseModel):
    """Chat request payload for the MCP assistant."""
    # User message, 1-4000 characters.
    message: str = Field(..., min_length=1, max_length=4000, description="聊天消息")
|
||||
|
||||
|
||||
class McpCreateSchema(BaseModel):
    """Payload for creating an MCP server."""
    name: str = Field(..., max_length=64, description='MCP 名称')
    # NOTE(review): schema allows 64 chars but McpModel.name is String(50) — verify.
    type: McpType = Field(McpType.stdio, description='MCP 类型')
    # NOTE(review): McpModel defines no 'description' column in this file —
    # presumably provided by ModelMixin; confirm.
    description: str | None = Field(None, max_length=255, description='MCP 描述')
    url: HttpUrl | None = Field(None, description='远程 SSE 地址')
    command: str | None = Field(None, max_length=255, description='MCP 命令')
    args: str | None = Field(None, max_length=255, description='MCP 命令参数,多个参数用英文逗号隔开')
    env: dict[str, str] | None = Field(None, description='MCP 环境变量')
|
||||
|
||||
|
||||
class McpUpdateSchema(McpCreateSchema):
    """Payload for updating an MCP server (same fields as create)."""
    ...
|
||||
|
||||
|
||||
class McpOutSchema(McpCreateSchema, BaseSchema, UserBySchema):
    """MCP server detail response (create fields + base/audit fields)."""
    # Built from ORM rows.
    model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class McpQueryParam:
    """Query parameters for filtering MCP server listings.

    Fuzzy fields become ("like", value) tuples, time ranges become
    ("between", (start, end)) tuples, and exact fields pass through,
    matching the convention of the other *QueryParam classes here.
    """

    def __init__(
        self,
        name: str | None = Query(None, description="MCP 名称"),
        type: McpType | None = Query(None, description="MCP 类型"),
        created_time: list[DateTimeStr] | None = Query(None, description="创建时间范围", examples=["2025-01-01 00:00:00", "2025-12-31 23:59:59"]),
        updated_time: list[DateTimeStr] | None = Query(None, description="更新时间范围", examples=["2025-01-01 00:00:00", "2025-12-31 23:59:59"]),
        created_id: int | None = Query(None, description="创建人"),
        updated_id: int | None = Query(None, description="更新人"),
    ) -> None:
        # Fuzzy match on name.
        self.name = ("like", name) if name else None

        # Exact-match fields.
        self.type = type
        self.created_id = created_id
        self.updated_id = updated_id

        # Time-range filters.
        # FIX: always assign these attributes. Previously they were only set
        # inside the conditionals, so they did not exist when no range was
        # supplied — unlike name/type, which are always present — and any
        # consumer reading them would raise AttributeError.
        if created_time and len(created_time) == 2:
            self.created_time = ("between", (created_time[0], created_time[1]))
        else:
            self.created_time = None
        if updated_time and len(updated_time) == 2:
            self.updated_time = ("between", (updated_time[0], updated_time[1]))
        else:
            self.updated_time = None
|
||||
|
||||
|
||||
class McpChatParam(BaseSchema):
    """Request parameters for chatting through selected MCP servers."""
    # MCP server ids to attach to the conversation.
    pk: list[int] = Field(..., description='MCP ID 列表')
    provider: McpLLMProvider = Field(McpLLMProvider.openai, description='LLM 供应商')
    model: str = Field(..., description='LLM 名称')
    key: str = Field(..., description='LLM API Key')
    # Custom endpoint — must be OpenAI-compatible.
    base_url: str | None = Field(None, description='自定义 LLM API 地址,必须兼容 openai 供应商')
    prompt: str = Field(..., description='用户提示词')
|
||||
|
||||
|
||||
# ============== AI供应商配置 ==============
|
||||
|
||||
class AIProviderCreateSchema(BaseModel):
    """Payload for creating an AI provider."""
    name: str = Field(..., max_length=50, description='供应商名称')
    provider_type: str = Field(..., max_length=50, description='供应商类型(openai/deepseek/anthropic/gemini/qwen等)')
    base_url: str = Field(..., max_length=255, description='接口地址BaseURL')
    api_key: str = Field(..., max_length=255, description='API Key')
    is_default: int = Field(0, description='是否默认供应商(0:否 1:是)')
    # NOTE(review): AIProviderModel defines no 'description' column in this
    # file — presumably provided by ModelMixin; confirm.
    description: str | None = Field(None, max_length=255, description='备注')
|
||||
|
||||
|
||||
class AIProviderUpdateSchema(AIProviderCreateSchema):
    """Payload for updating an AI provider (same fields as create)."""
    ...
|
||||
|
||||
|
||||
class AIProviderOutSchema(AIProviderCreateSchema, BaseSchema, UserBySchema):
    """AI provider detail response (create fields + base/audit fields)."""
    # Built from ORM rows.
    model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class AIProviderQueryParam:
    """Query parameters for filtering AI provider listings."""

    def __init__(
        self,
        name: str | None = Query(None, description="供应商名称"),
        provider_type: str | None = Query(None, description="供应商类型"),
        is_default: int | None = Query(None, description="是否默认"),
    ) -> None:
        """Translate raw query values into CRUD filter attributes."""
        # Fuzzy match on the provider name.
        if name:
            self.name = ("like", name)
        else:
            self.name = None
        # Exact-match fields.
        self.provider_type = provider_type
        self.is_default = is_default
|
||||
|
||||
|
||||
# ============== 向量化配置 ==============
|
||||
|
||||
class EmbeddingConfigCreateSchema(BaseModel):
    """Payload for creating an embedding configuration."""
    name: str = Field(..., max_length=50, description='配置名称')
    # 0 = local model, 1 = remote service.
    embedding_type: int = Field(0, description='向量化类型(0:本地 1:远程)')
    model_name: str = Field(..., max_length=100, description='Embedding模型名称')
    # Required in remote mode (validated in EmbeddingUtil, not here).
    base_url: str | None = Field(None, max_length=255, description='远程接口地址(远程模式必填)')
    api_key: str | None = Field(None, max_length=255, description='远程API Key(远程模式必填)')
    is_default: int = Field(0, description='是否默认配置(0:否 1:是)')
    description: str | None = Field(None, max_length=255, description='备注')
|
||||
|
||||
|
||||
class EmbeddingConfigUpdateSchema(EmbeddingConfigCreateSchema):
    """Payload for updating an embedding configuration (same fields as create)."""
    ...
|
||||
|
||||
|
||||
class EmbeddingConfigOutSchema(EmbeddingConfigCreateSchema, BaseSchema, UserBySchema):
    """Embedding configuration detail response."""
    # Built from ORM rows.
    model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class EmbeddingConfigQueryParam:
    """Query parameters for filtering embedding configuration listings."""

    def __init__(
        self,
        name: str | None = Query(None, description="配置名称"),
        embedding_type: int | None = Query(None, description="向量化类型"),
        is_default: int | None = Query(None, description="是否默认"),
    ) -> None:
        """Translate raw query values into CRUD filter attributes."""
        # Fuzzy match on the configuration name.
        if name:
            self.name = ("like", name)
        else:
            self.name = None
        # Exact-match fields.
        self.embedding_type = embedding_type
        self.is_default = is_default
|
||||
|
||||
|
||||
# ============== 知识库 ==============
|
||||
|
||||
class KnowledgeBaseCreateSchema(BaseModel):
    """Payload for creating a knowledge base."""
    name: str = Field(..., max_length=100, description='知识库名称')
    embedding_config_id: int | None = Field(None, description='向量化配置ID')
    description: str | None = Field(None, max_length=255, description='备注')
|
||||
|
||||
|
||||
class KnowledgeBaseUpdateSchema(BaseModel):
    """Payload for updating a knowledge base; all fields optional."""
    name: str | None = Field(None, max_length=100, description='知识库名称')
    embedding_config_id: int | None = Field(None, description='向量化配置ID')
    description: str | None = Field(None, max_length=255, description='备注')
|
||||
|
||||
|
||||
class EmbeddingConfigRefSchema(BaseModel):
    """Compact embedding-config reference embedded in knowledge-base output."""
    model_config = ConfigDict(from_attributes=True)

    id: int
    name: str
    # 0 = local, 1 = remote (mirrors EmbeddingConfigModel.embedding_type).
    embedding_type: int
    model_name: str
|
||||
|
||||
|
||||
class KnowledgeBaseOutSchema(BaseSchema, UserBySchema):
    """Knowledge base detail response."""
    model_config = ConfigDict(from_attributes=True)

    name: str
    embedding_config_id: int | None = None
    # Backing ChromaDB collection name.
    collection_name: str
    document_count: int = 0
    vector_count: int = 0
    # 0 = pending, 1 = processing, 2 = done, 3 = failed.
    kb_status: int = 0
    error_message: str | None = None
    description: str | None = None
    file_paths: list[str] | None = None
    # Resolved embedding configuration, if linked.
    embedding_config: EmbeddingConfigRefSchema | None = None
|
||||
|
||||
|
||||
class KnowledgeBaseQueryParam:
    """Query parameters for filtering knowledge-base listings."""

    def __init__(
        self,
        name: str | None = Query(None, description="知识库名称"),
        embedding_config_id: int | None = Query(None, description="向量化配置ID"),
        kb_status: int | None = Query(None, description="知识库状态"),
    ) -> None:
        """Translate raw query values into CRUD filter attributes."""
        # Fuzzy match on the knowledge-base name.
        if name:
            self.name = ("like", name)
        else:
            self.name = None
        # Exact-match fields.
        self.embedding_config_id = embedding_config_id
        self.kb_status = kb_status
|
||||
|
||||
|
||||
# ============== AI模型配置 ==============
|
||||
|
||||
# AI model type constants: feature key -> human-readable (Chinese) label.
# These keys are the valid values of AIModelConfigModel.model_type.
AI_MODEL_TYPES = {
    'enterprise_naming': '企业起名',
    'enterprise_renaming': '企业改名',
    'enterprise_scoring': '企业测名',
    'enterprise_scoring_trial': '企业测名试用',
    'personal_naming': '个人起名',
    'personal_renaming': '个人改名',
    'personal_scoring': '个人测名',
    'personal_scoring_trial': '个人测名试用',
    'yuanfen_hepan': '缘分合盘',
    'bazi_zeji': '八字择吉',
    'caiyun_jiexi': '财运解析',
    'caiyun_jiexi_qiye': '企业财运解析'
}
|
||||
|
||||
|
||||
class AIProviderRefSchema(BaseModel):
    """Compact provider reference embedded in model-config output."""
    model_config = ConfigDict(from_attributes=True)

    id: int
    name: str
    provider_type: str
    base_url: str
|
||||
|
||||
|
||||
class AIModelConfigCreateSchema(BaseModel):
    """Payload for creating an AI model configuration."""
    # FIX: the description previously listed a stale value set
    # (naming/renaming/scoring/report); the actual valid values are the keys
    # of AI_MODEL_TYPES, matching AIModelConfigModel.model_type.
    model_type: str = Field(..., max_length=50, description='模型类型(enterprise_naming/personal_naming等,见AI_MODEL_TYPES)')
    model_name: str | None = Field(None, max_length=100, description='使用的模型名称')
    provider_id: int | None = Field(None, description='AI供应商ID')
    system_prompt: str | None = Field(None, description='系统提示词')
    temperature: float = Field(1.0, ge=0, le=2, description='模型温度(0-2)')
    knowledge_base_ids: list[int] | None = Field(None, description='关联的知识库ID列表')
|
||||
|
||||
|
||||
class AIModelConfigUpdateSchema(BaseModel):
    """Payload for updating an AI model configuration; all fields optional."""
    # NOTE(review): fields prefixed with "model_" collide with pydantic v2's
    # protected "model_" namespace and may emit warnings — confirm.
    model_name: str | None = Field(None, max_length=100, description='使用的模型名称')
    provider_id: int | None = Field(None, description='AI供应商ID')
    system_prompt: str | None = Field(None, description='系统提示词')
    temperature: float | None = Field(None, ge=0, le=2, description='模型温度(0-2)')
    knowledge_base_ids: list[int] | None = Field(None, description='关联的知识库ID列表')
|
||||
|
||||
|
||||
class AIModelConfigOutSchema(BaseSchema, UserBySchema):
    """Model configuration detail response."""
    model_config = ConfigDict(from_attributes=True)

    model_type: str
    model_name: str | None = None
    provider_id: int | None = None
    system_prompt: str | None = None
    temperature: float = 1.0
    knowledge_base_ids: list[int] | None = None
    # Resolved provider reference, if linked.
    provider: AIProviderRefSchema | None = None
|
||||
|
||||
|
||||
class AIModelConfigQueryParam:
    """Query parameters for filtering AI model configuration listings."""

    def __init__(
        self,
        model_type: str | None = Query(None, description="模型类型"),
    ) -> None:
        """Exact match on the model type; None means no filter."""
        self.model_type = model_type
|
||||
|
||||
|
||||
# ============== AI模型训练对话 ==============
|
||||
|
||||
class AIModelTrainingMessageCreateSchema(BaseModel):
    """Payload for creating one training chat message."""
    model_config_id: int = Field(..., description='模型配置ID')
    # "user" or "assistant".
    role: str = Field(..., max_length=20, description='角色(user/assistant)')
    content: str = Field(..., description='消息内容')
|
||||
|
||||
|
||||
class AIModelTrainingMessageOutSchema(BaseSchema, UserBySchema):
    """Training chat message detail response."""
    model_config = ConfigDict(from_attributes=True)

    # NOTE(review): "model_config_id" collides with pydantic v2's protected
    # "model_" namespace (and sits next to model_config itself) — confirm no
    # warnings/conflicts at import time.
    model_config_id: int
    role: str
    content: str
|
||||
|
||||
|
||||
class AIModelTrainingChatSchema(BaseModel):
    """Training-chat request payload."""
    model_type: str = Field(..., description='模型类型')
    message: str = Field(..., min_length=1, description='用户消息')
    # When the config changed client-side, it is saved before chatting.
    config_changed: bool = Field(False, description='配置是否有变动')
    config_data: AIModelConfigUpdateSchema | None = Field(None, description='变动的配置数据')
|
||||
|
||||
|
||||
class AIModelTestSchema(BaseModel):
    """Naming-test request payload (for mini-program / external callers)."""
    model_type: str = Field(..., description='模型类型(enterprise_naming/personal_naming等)')
    text: str = Field(..., min_length=1, max_length=4000, description='用户输入的文本')
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,2 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
import httpx
|
||||
|
||||
from app.config.setting import settings
|
||||
from app.core.logger import log
|
||||
|
||||
|
||||
class AIClient:
    """
    AI client for interacting with an OpenAI-compatible chat API.

    Streams chat completions and maps provider/network errors to
    user-friendly Chinese messages.
    """

    def __init__(self):
        # Model name and credentials come from application settings.
        self.model = settings.OPENAI_MODEL
        # Dedicated httpx client without conflicting parameters.
        self.http_client = httpx.AsyncClient(
            timeout=30.0,
            follow_redirects=True
        )

        # Async OpenAI client backed by the custom http client.
        self.client = AsyncOpenAI(
            api_key=settings.OPENAI_API_KEY,
            base_url=settings.OPENAI_BASE_URL,
            http_client=self.http_client
        )

    def _friendly_error_message(self, e: Exception) -> str:
        """Translate an OpenAI/network exception into a friendly Chinese message."""
        # Try to extract the HTTP status code and structured error body
        # (both attributes exist on openai SDK exceptions).
        status_code = getattr(e, "status_code", None)
        body = getattr(e, "body", None)
        message = None
        error_type = None
        error_code = None
        try:
            if isinstance(body, dict) and "error" in body:
                err = body.get("error") or {}
                error_type = err.get("type")
                error_code = err.get("code")
                message = err.get("message")
        except Exception:
            # Ignore parse failures; fall back to str(e).
            pass

        text = str(e)
        msg = message or text

        # Specific error mappings (order matters: most specific first).
        # Arrearage / account-standing problems.
        if (error_code == "Arrearage") or (error_type == "Arrearage") or ("in good standing" in (msg or "")):
            return "账户欠费或结算异常,访问被拒绝。请检查账号状态或更换有效的 API Key。"
        # Authentication failure.
        if status_code == 401 or "invalid api key" in msg.lower():
            return "鉴权失败,API Key 无效或已过期。请检查系统配置中的 API Key。"
        # Permission denied / restricted account.
        if status_code == 403 or error_type in {"PermissionDenied", "permission_denied"}:
            return "访问被拒绝,权限不足或账号受限。请检查账户权限设置。"
        # Quota exhausted or rate limited.
        if status_code == 429 or error_type in {"insufficient_quota", "rate_limit_exceeded"}:
            return "请求过于频繁或配额已用尽。请稍后重试或提升账户配额。"
        # Client-side request errors.
        if status_code == 400:
            return f"请求参数错误或服务拒绝:{message or '请检查输入内容。'}"
        # Server-side errors.
        if status_code in {500, 502, 503, 504}:
            return "服务暂时不可用,请稍后重试。"

        # Default fallback.
        return f"处理您的请求时出现错误:{msg}"

    async def process(self, query: str) -> AsyncGenerator[str, None]:
        """
        Process a query and stream the response.

        Args:
            query: The user's query text.

        Yields:
            Chunks of the streamed completion content, or a single friendly
            error message if the API call fails.
        """
        system_prompt = """你是一个有用的AI助手,可以帮助用户回答问题和提供帮助。请用中文回答用户的问题。"""

        try:
            # Await the async client; stream=True returns an async iterator of chunks.
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": query}
                ],
                stream=True
            )

            # Relay streamed content deltas as they arrive.
            async for chunk in response:
                if chunk.choices and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content

        except Exception as e:
            # Log the raw error, surface only a friendly message to the caller.
            log.error(f"AI处理查询失败: {str(e)}")
            yield self._friendly_error_message(e)

    async def close(self) -> None:
        """
        Close the OpenAI client and the underlying httpx client.
        """
        if hasattr(self, 'client'):
            await self.client.close()
        if hasattr(self, 'http_client'):
            await self.http_client.aclose()
|
||||
@@ -0,0 +1,252 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
文档处理工具类
|
||||
支持 txt、pdf、doc、docx、md 格式解析
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, BinaryIO
|
||||
|
||||
from langchain.schema import Document
|
||||
from fastapi import UploadFile
|
||||
|
||||
from app.core.logger import log
|
||||
|
||||
|
||||
class DocumentProcessor:
    """Parsing helpers for uploaded knowledge-base documents.

    Supports txt / pdf / doc / docx / md. Each parser returns a list of
    langchain ``Document`` objects; failures are logged and yield [].

    FIX applied throughout: several log messages contained the literal
    placeholder "(unknown)" instead of interpolating the filename, which
    made the logs useless for diagnosing which file failed.
    """

    # Supported file extensions.
    SUPPORTED_EXTENSIONS = {'.txt', '.pdf', '.doc', '.docx', '.md'}

    @classmethod
    def is_supported(cls, filename: str) -> bool:
        """Return True if the file's extension is supported."""
        ext = Path(filename).suffix.lower()
        return ext in cls.SUPPORTED_EXTENSIONS

    @classmethod
    async def process_upload_file(cls, file: UploadFile) -> List[Document]:
        """
        Parse one uploaded file into documents.

        参数:
        - file: FastAPI UploadFile 对象

        返回:
        - 文档列表 ([] on unsupported type or parse failure)
        """
        if not file.filename:
            log.warning("文件名为空,跳过处理")
            return []

        ext = Path(file.filename).suffix.lower()

        if not cls.is_supported(file.filename):
            log.warning(f"不支持的文件类型: {ext}")
            return []

        # Read the whole file, then rewind so the caller can reuse it.
        content = await file.read()
        await file.seek(0)

        # Dispatch on the extension.
        try:
            if ext == '.txt':
                return cls._process_txt(content, file.filename)
            elif ext == '.md':
                return cls._process_markdown(content, file.filename)
            elif ext == '.pdf':
                return await cls._process_pdf(content, file.filename)
            elif ext in {'.doc', '.docx'}:
                return await cls._process_word(content, file.filename, ext)
            else:
                log.warning(f"未知的文件类型: {ext}")
                return []
        except Exception as e:
            log.error(f"处理文件失败: {file.filename}, 错误: {e}")
            return []

    @classmethod
    def _process_txt(cls, content: bytes, filename: str) -> List[Document]:
        """Parse a TXT file, trying several encodings in order."""
        try:
            text = None
            for encoding in ['utf-8', 'gbk', 'gb2312', 'latin-1']:
                try:
                    text = content.decode(encoding)
                    break
                except UnicodeDecodeError:
                    continue

            if text is None:
                # FIX: interpolate the filename (was the literal "(unknown)").
                log.error(f"无法解码文件: {filename}")
                return []

            return [Document(
                page_content=text,
                metadata={"source": filename, "type": "txt"}
            )]
        except Exception as e:
            log.error(f"处理 TXT 文件失败: {filename}, 错误: {e}")
            return []

    @classmethod
    def _process_markdown(cls, content: bytes, filename: str) -> List[Document]:
        """Parse a Markdown file (assumed UTF-8)."""
        try:
            text = content.decode('utf-8')
            return [Document(
                page_content=text,
                metadata={"source": filename, "type": "markdown"}
            )]
        except Exception as e:
            log.error(f"处理 Markdown 文件失败: {filename}, 错误: {e}")
            return []

    @classmethod
    async def _process_pdf(cls, content: bytes, filename: str) -> List[Document]:
        """Parse a PDF file, one Document per non-empty page."""
        try:
            # Use pypdf to extract text page by page.
            import pypdf
            from io import BytesIO

            pdf_file = BytesIO(content)
            reader = pypdf.PdfReader(pdf_file)

            documents = []
            for page_num, page in enumerate(reader.pages):
                text = page.extract_text()
                if text and text.strip():
                    documents.append(Document(
                        page_content=text,
                        metadata={
                            "source": filename,
                            "type": "pdf",
                            "page": page_num + 1
                        }
                    ))

            log.info(f"PDF 文件处理完成: {filename}, 共 {len(documents)} 页")
            return documents

        except ImportError:
            log.error("未安装 pypdf 库,请运行: pip install pypdf")
            return []
        except Exception as e:
            log.error(f"处理 PDF 文件失败: {filename}, 错误: {e}")
            return []

    @classmethod
    async def _process_word(cls, content: bytes, filename: str, ext: str) -> List[Document]:
        """Parse a Word file, dispatching to the docx or legacy doc handler."""
        try:
            if ext == '.docx':
                return cls._process_docx(content, filename)
            else:
                # Legacy .doc needs external tooling.
                return cls._process_doc(content, filename)
        except Exception as e:
            log.error(f"处理 Word 文件失败: {filename}, 错误: {e}")
            return []

    @classmethod
    def _process_docx(cls, content: bytes, filename: str) -> List[Document]:
        """Parse a DOCX file by joining its non-empty paragraphs."""
        try:
            from docx import Document as DocxDocument
            from io import BytesIO

            docx_file = BytesIO(content)
            doc = DocxDocument(docx_file)

            # Collect all non-empty paragraph texts.
            paragraphs = []
            for para in doc.paragraphs:
                if para.text.strip():
                    paragraphs.append(para.text)

            text = '\n'.join(paragraphs)

            if text.strip():
                return [Document(
                    page_content=text,
                    metadata={"source": filename, "type": "docx"}
                )]
            return []

        except ImportError:
            log.error("未安装 python-docx 库,请运行: pip install python-docx")
            return []
        except Exception as e:
            log.error(f"处理 DOCX 文件失败: {filename}, 错误: {e}")
            return []

    @classmethod
    def _process_doc(cls, content: bytes, filename: str) -> List[Document]:
        """
        Parse a legacy DOC file via the external `antiword` tool.

        Returns [] (with a warning) when antiword is unavailable or times
        out; callers are advised to convert to .docx instead.
        """
        try:
            import subprocess
            import tempfile

            # antiword only reads from disk, so stage a temp file.
            with tempfile.NamedTemporaryFile(suffix='.doc', delete=False) as tmp:
                tmp.write(content)
                tmp_path = tmp.name

            try:
                result = subprocess.run(
                    ['antiword', tmp_path],
                    capture_output=True,
                    text=True,
                    timeout=30
                )
                if result.returncode == 0 and result.stdout.strip():
                    return [Document(
                        page_content=result.stdout,
                        metadata={"source": filename, "type": "doc"}
                    )]
            except (subprocess.TimeoutExpired, FileNotFoundError):
                log.warning(f"无法处理 .doc 文件: {filename},建议转换为 .docx 格式")
            finally:
                # Always remove the staged temp file.
                os.unlink(tmp_path)

            return []

        except Exception as e:
            log.error(f"处理 DOC 文件失败: {filename}, 错误: {e}")
            return []

    @classmethod
    async def process_files(cls, files: List[UploadFile]) -> List[Document]:
        """
        Parse a batch of uploaded files.

        参数:
        - files: UploadFile 列表

        返回:
        - 所有解析出的文档列表
        """
        all_documents = []

        for file in files:
            if file.filename and cls.is_supported(file.filename):
                docs = await cls.process_upload_file(file)
                all_documents.extend(docs)
                log.info(f"文件处理完成: {file.filename}, 提取 {len(docs)} 个文档")
            else:
                log.warning(f"跳过不支持的文件: {file.filename}")

        log.info(f"批量处理完成,共 {len(all_documents)} 个文档")
        return all_documents
|
||||
@@ -0,0 +1,239 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
向量化工具类
|
||||
支持本地和远程 embedding 模型
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import chromadb
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from langchain_chroma import Chroma
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain.schema import Document
|
||||
|
||||
from app.core.logger import log
|
||||
from app.config.path_conf import BASE_DIR
|
||||
|
||||
|
||||
# ChromaDB persistence directory (under the project data folder).
CHROMA_PERSIST_DIR = BASE_DIR / "data" / "chroma_db"

# Global ChromaDB client instance (lazy singleton; see get_chroma_client).
_chroma_client = None
|
||||
|
||||
|
||||
def get_chroma_client():
    """Return the process-wide ChromaDB client, creating it on first use."""
    global _chroma_client
    if _chroma_client is not None:
        return _chroma_client
    # First call: make sure the persistence directory exists, then connect.
    CHROMA_PERSIST_DIR.mkdir(parents=True, exist_ok=True)
    _chroma_client = chromadb.PersistentClient(path=str(CHROMA_PERSIST_DIR))
    return _chroma_client
|
||||
|
||||
|
||||
class EmbeddingUtil:
    """Document embedding and vector-store helper backed by ChromaDB.

    Two embedding backends are supported, selected by ``embedding_type``:
      * 0 — local model loaded via HuggingFaceEmbeddings (sentence-transformers)
      * 1 — remote OpenAI-compatible embeddings API
    """

    def __init__(
        self,
        embedding_type: int = 0,
        model_name: str = "text-embedding-ada-002",
        base_url: Optional[str] = None,
        api_key: Optional[str] = None
    ):
        """Store the backend configuration; the model itself loads lazily.

        Args:
            embedding_type: 0 = local model, 1 = remote API.
            model_name: embedding model name.
            base_url: remote API base URL (required in remote mode).
            api_key: remote API key (required in remote mode).
        """
        self.embedding_type = embedding_type
        self.model_name = model_name
        self.base_url = base_url
        self.api_key = api_key

        # Built on first access through the ``embeddings`` property.
        self._embeddings = None

    @property
    def embeddings(self):
        """Lazily create and cache the embedding model instance."""
        if self._embeddings is None:
            self._embeddings = self._create_embeddings()
        return self._embeddings

    def _create_embeddings(self):
        """Instantiate the embedding model for the configured backend."""
        if self.embedding_type != 0:
            # Remote mode: an OpenAI-compatible HTTP endpoint.
            if not (self.base_url and self.api_key):
                raise ValueError("远程模式必须提供 base_url 和 api_key")

            # Normalize the base URL, appending /v1 when it is missing.
            endpoint = self.base_url.rstrip('/')
            if not endpoint.endswith('/v1'):
                endpoint = f"{endpoint}/v1"

            log.info(f"使用远程 Embedding 模型: {self.model_name}, URL: {endpoint}")
            return OpenAIEmbeddings(
                model=self.model_name,
                base_url=endpoint,
                api_key=self.api_key,
            )

        # Local mode — backed by sentence-transformers.
        # sentence-transformers==3.3.1
        # Optional local embedding support; install manually with:
        # pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
        try:
            from langchain_community.embeddings import HuggingFaceEmbeddings
            log.info(f"使用本地 Embedding 模型: {self.model_name}")
            return HuggingFaceEmbeddings(
                model_name=self.model_name,
                model_kwargs={'device': 'cpu'},
                encode_kwargs={'normalize_embeddings': True}
            )
        except Exception as e:
            log.error(f"加载本地 Embedding 模型失败: {e}")
            raise

    def get_vector_store(self, collection_name: str) -> Chroma:
        """Return a Chroma vector store bound to a collection.

        Args:
            collection_name: name of the target collection.

        Returns:
            A Chroma instance using this util's embedding model and the
            shared persistent client.
        """
        # The persistence directory must exist before Chroma touches it.
        CHROMA_PERSIST_DIR.mkdir(parents=True, exist_ok=True)

        return Chroma(
            collection_name=collection_name,
            embedding_function=self.embeddings,
            persist_directory=str(CHROMA_PERSIST_DIR),
            client=get_chroma_client(),
        )

    def split_documents(
        self,
        documents: List[Document],
        chunk_size: int = 1000,
        chunk_overlap: int = 200
    ) -> List[Document]:
        """Split documents into overlapping character chunks.

        Args:
            documents: documents to split.
            chunk_size: maximum characters per chunk.
            chunk_overlap: characters shared between adjacent chunks.

        Returns:
            The chunked documents.
        """
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
            separators=["\n\n", "\n", "。", "!", "?", ".", "!", "?", " ", ""]
        )
        return splitter.split_documents(documents)

    def add_documents(
        self,
        collection_name: str,
        documents: List[Document]
    ) -> int:
        """Chunk documents and store their vectors in a collection.

        Args:
            collection_name: target collection name.
            documents: documents to embed and store.

        Returns:
            The number of vectors added (0 for empty input).
        """
        if not documents:
            return 0

        # Chunk first, then embed and persist.
        chunks = self.split_documents(documents)
        log.info(f"文档分割完成,共 {len(chunks)} 个片段")

        store = self.get_vector_store(collection_name)
        store.add_documents(chunks)
        log.info(f"向量存储完成,集合: {collection_name}, 向量数: {len(chunks)}")

        return len(chunks)

    def delete_collection(self, collection_name: str) -> bool:
        """Delete a collection from ChromaDB.

        Args:
            collection_name: collection to remove.

        Returns:
            True on success, False when the delete failed.
        """
        try:
            get_chroma_client().delete_collection(collection_name)
            log.info(f"删除集合成功: {collection_name}")
            return True
        except Exception as e:
            log.error(f"删除集合失败: {collection_name}, 错误: {e}")
            return False

    def similarity_search(
        self,
        collection_name: str,
        query: str,
        k: int = 4
    ) -> List[Document]:
        """Run a similarity search against a collection.

        Args:
            collection_name: collection to search.
            query: query text.
            k: number of results to return.

        Returns:
            The matching documents.
        """
        store = self.get_vector_store(collection_name)
        return store.similarity_search(query, k=k)

    def get_collection_count(self, collection_name: str) -> int:
        """Count the vectors stored in a collection.

        Args:
            collection_name: collection to inspect.

        Returns:
            The vector count, or 0 when the collection cannot be read.
        """
        try:
            collection = get_chroma_client().get_collection(collection_name)
            return collection.count()
        except Exception as e:
            log.warning(f"获取集合数量失败: {collection_name}, 错误: {e}")
            return 0
|
||||
Reference in New Issue
Block a user