upload project source code

This commit is contained in:
2026-04-30 18:49:43 +08:00
commit 9b394ba682
2277 changed files with 660945 additions and 0 deletions

View File

@@ -0,0 +1 @@
# -*- coding: utf-8 -*-

View File

@@ -0,0 +1,537 @@
# -*- coding: utf-8 -*-
from fastapi import APIRouter, Depends, Path, Body, WebSocket, Form, UploadFile, File
from fastapi.responses import JSONResponse, StreamingResponse
from app.common.response import StreamResponse, SuccessResponse
from app.common.request import PaginationService
from app.core.base_params import PaginationQueryParam
from app.core.dependencies import AuthPermission
from app.core.logger import log
from app.api.v1.module_system.auth.schema import AuthSchema
from app.core.router_class import OperationLogRoute
from .service import McpService, AIProviderService, EmbeddingConfigService, KnowledgeBaseService, AIModelConfigService, AIModelTrainingService, AIModelTestService
from .schema import (
McpCreateSchema, McpUpdateSchema, ChatQuerySchema, McpQueryParam,
AIProviderCreateSchema, AIProviderUpdateSchema, AIProviderQueryParam,
EmbeddingConfigCreateSchema, EmbeddingConfigUpdateSchema, EmbeddingConfigQueryParam,
KnowledgeBaseCreateSchema, KnowledgeBaseUpdateSchema, KnowledgeBaseQueryParam,
AIModelConfigUpdateSchema, AIModelConfigQueryParam, AIModelTrainingChatSchema,
AIModelTestSchema
)
# Router instances, one per feature area. All share OperationLogRoute so each
# request is captured by the operation-log middleware.
AIRouter = APIRouter(route_class=OperationLogRoute, prefix="/ai", tags=["MCP智能助手"])
AIProviderRouter = APIRouter(route_class=OperationLogRoute, prefix="/ai/provider", tags=["AI供应商配置"])
EmbeddingConfigRouter = APIRouter(route_class=OperationLogRoute, prefix="/ai/embedding", tags=["向量化配置"])
KnowledgeBaseRouter = APIRouter(route_class=OperationLogRoute, prefix="/ai/knowledge", tags=["知识库管理"])
AIModelConfigRouter = APIRouter(route_class=OperationLogRoute, prefix="/ai/model", tags=["AI模型配置"])
@AIRouter.post("/chat", summary="智能对话", description="与MCP智能助手进行对话")
async def chat_controller(
    query: ChatQuerySchema,
    auth: AuthSchema = Depends(AuthPermission())
) -> StreamingResponse:
    """Stream a chat reply from the MCP assistant as plain text.

    Args:
        query: chat request payload.
        auth: resolved authentication context.

    Returns:
        StreamingResponse yielding UTF-8 encoded chunks.
    """
    user_name = auth.user.name if auth.user else "未知用户"
    log.info(f"用户 {user_name} 发起智能对话: {query.message[:50]}...")

    async def stream_chunks():
        try:
            async for piece in McpService.chat_query(query=query):
                if not piece:
                    continue
                # Normalize every chunk to bytes before it leaves the server.
                yield piece.encode('utf-8') if isinstance(piece, str) else piece
        except Exception as e:
            log.error(f"流式响应出错: {str(e)}")
            yield f"抱歉,处理您的请求时出现了错误: {str(e)}".encode('utf-8')

    return StreamResponse(stream_chunks(), media_type="text/plain; charset=utf-8")
@AIRouter.get("/detail/{id}", summary="获取 MCP 服务器详情", description="获取 MCP 服务器详情")
async def detail_controller(
    id: int = Path(..., description="MCP ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Return the detail of a single MCP server by primary key."""
    detail = await McpService.detail_service(auth=auth, id=id)
    log.info(f"获取 MCP 服务器详情成功 {id}")
    return SuccessResponse(data=detail, msg="获取 MCP 服务器详情成功")
@AIRouter.get("/list", summary="查询 MCP 服务器列表", description="查询 MCP 服务器列表")
async def list_controller(
    page: PaginationQueryParam = Depends(),
    search: McpQueryParam = Depends(),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Return a paginated list of MCP servers matching the search filters."""
    records = await McpService.list_service(auth=auth, search=search, order_by=page.order_by)
    paged = await PaginationService.paginate(
        data_list=records, page_no=page.page_no, page_size=page.page_size
    )
    log.info(f"查询 MCP 服务器列表成功")
    return SuccessResponse(data=paged, msg="查询 MCP 服务器列表成功")
@AIRouter.post("/create", summary="创建 MCP 服务器", description="创建 MCP 服务器")
async def create_controller(
    data: McpCreateSchema,
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Create a new MCP server record and return it."""
    created = await McpService.create_service(auth=auth, data=data)
    log.info(f"创建 MCP 服务器成功: {created}")
    return SuccessResponse(data=created, msg="创建 MCP 服务器成功")
@AIRouter.put("/update/{id}", summary="修改 MCP 服务器", description="修改 MCP 服务器")
async def update_controller(
    data: McpUpdateSchema,
    id: int = Path(..., description="MCP ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Update an existing MCP server by primary key and return the result."""
    updated = await McpService.update_service(auth=auth, id=id, data=data)
    log.info(f"修改 MCP 服务器成功: {updated}")
    return SuccessResponse(data=updated, msg="修改 MCP 服务器成功")
@AIRouter.delete("/delete", summary="删除 MCP 服务器", description="删除 MCP 服务器")
async def delete_controller(
    ids: list[int] = Body(..., description="ID列表"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Bulk-delete MCP servers by id list."""
    await McpService.delete_service(auth=auth, ids=ids)
    log.info(f"删除 MCP 服务器成功: {ids}")
    return SuccessResponse(msg="删除 MCP 服务器成功")
@AIRouter.websocket("/ws/chat", name="WebSocket聊天")
async def websocket_chat_controller(
    websocket: WebSocket,
):
    """
    WebSocket chat endpoint.

    ws://127.0.0.1:8001/api/v1/ai/mcp/ws/chat

    Receives plain-text messages and streams the assistant's reply back
    chunk by chunk over the same socket.
    """
    # WebSocketDisconnect is re-exported by fastapi; imported locally so this
    # fix is self-contained.
    from fastapi import WebSocketDisconnect
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            # Stream the response for this message.
            try:
                async for chunk in McpService.chat_query(query=ChatQuerySchema(message=data)):
                    if chunk:
                        await websocket.send_text(chunk)
            except WebSocketDisconnect:
                # Client left mid-stream; let the outer handler deal with it,
                # otherwise send_text below would raise again.
                raise
            except Exception as e:
                log.error(f"处理聊天查询出错: {str(e)}")
                await websocket.send_text(f"抱歉,处理您的请求时出现了错误: {str(e)}")
    except WebSocketDisconnect:
        # A normal client disconnect is not an error condition.
        log.info("WebSocket连接已断开")
    except Exception as e:
        log.error(f"WebSocket聊天出错: {str(e)}")
    finally:
        # close() raises RuntimeError when the close handshake has already
        # completed (e.g. after a client-initiated disconnect) — suppress it.
        try:
            await websocket.close()
        except RuntimeError:
            pass
@AIModelConfigRouter.post("/test", summary="起名测试")
async def naming_test_controller(
    test_data: AIModelTestSchema,
) -> JSONResponse:
    """Naming test endpoint for mini-program / external callers.

    Loads the model configuration and training data, streams from the AI
    provider internally, and returns the assembled response in one shot.
    Read-only: no conversation record is persisted.

    Args:
        test_data: carries `model_type` (e.g. enterprise_naming/personal_naming)
            and the user `text`.

    Returns:
        JSONResponse whose `data` is the full AI response text.
    """
    log.info(f"[起名测试] 收到请求: model_type={test_data.model_type}, text={test_data.text[:50]}...")
    answer = await AIModelTestService.test_naming(
        model_type=test_data.model_type,
        text=test_data.text
    )
    log.info(f"[起名测试] 响应完成,长度: {len(answer)}")
    return SuccessResponse(data=answer, msg="测试成功")
# ============== AI供应商配置 ==============
@AIProviderRouter.get("/detail/{id}", summary="获取AI供应商详情")
async def provider_detail_controller(
    id: int = Path(..., description="AI供应商ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Return the detail of a single AI provider by primary key."""
    detail = await AIProviderService.detail_service(auth=auth, id=id)
    return SuccessResponse(data=detail, msg="获取AI供应商详情成功")
@AIProviderRouter.get("/list", summary="查询AI供应商列表")
async def provider_list_controller(
    page: PaginationQueryParam = Depends(),
    search: AIProviderQueryParam = Depends(),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Return a paginated list of AI providers matching the search filters."""
    records = await AIProviderService.list_service(auth=auth, search=search, order_by=page.order_by)
    paged = await PaginationService.paginate(
        data_list=records, page_no=page.page_no, page_size=page.page_size
    )
    return SuccessResponse(data=paged, msg="查询AI供应商列表成功")
@AIProviderRouter.post("/create", summary="创建AI供应商")
async def provider_create_controller(
    data: AIProviderCreateSchema,
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Create a new AI provider record and return it."""
    created = await AIProviderService.create_service(auth=auth, data=data)
    log.info(f"创建AI供应商成功: {created}")
    return SuccessResponse(data=created, msg="创建AI供应商成功")
@AIProviderRouter.put("/update/{id}", summary="修改AI供应商")
async def provider_update_controller(
    data: AIProviderUpdateSchema,
    id: int = Path(..., description="AI供应商ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Update an AI provider by primary key and return the result."""
    updated = await AIProviderService.update_service(auth=auth, id=id, data=data)
    log.info(f"修改AI供应商成功: {updated}")
    return SuccessResponse(data=updated, msg="修改AI供应商成功")
@AIProviderRouter.delete("/delete", summary="删除AI供应商")
async def provider_delete_controller(
    ids: list[int] = Body(..., description="ID列表"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Bulk-delete AI providers by id list."""
    await AIProviderService.delete_service(auth=auth, ids=ids)
    log.info(f"删除AI供应商成功: {ids}")
    return SuccessResponse(msg="删除AI供应商成功")
# ============== 向量化配置 ==============
@EmbeddingConfigRouter.get("/detail/{id}", summary="获取向量化配置详情")
async def embedding_detail_controller(
    id: int = Path(..., description="向量化配置ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Return the detail of a single embedding configuration by primary key."""
    detail = await EmbeddingConfigService.detail_service(auth=auth, id=id)
    return SuccessResponse(data=detail, msg="获取向量化配置详情成功")
@EmbeddingConfigRouter.get("/list", summary="查询向量化配置列表")
async def embedding_list_controller(
    page: PaginationQueryParam = Depends(),
    search: EmbeddingConfigQueryParam = Depends(),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Return a paginated list of embedding configurations."""
    records = await EmbeddingConfigService.list_service(auth=auth, search=search, order_by=page.order_by)
    paged = await PaginationService.paginate(
        data_list=records, page_no=page.page_no, page_size=page.page_size
    )
    return SuccessResponse(data=paged, msg="查询向量化配置列表成功")
@EmbeddingConfigRouter.post("/create", summary="创建向量化配置")
async def embedding_create_controller(
    data: EmbeddingConfigCreateSchema,
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Create a new embedding configuration and return it."""
    created = await EmbeddingConfigService.create_service(auth=auth, data=data)
    log.info(f"创建向量化配置成功: {created}")
    return SuccessResponse(data=created, msg="创建向量化配置成功")
@EmbeddingConfigRouter.put("/update/{id}", summary="修改向量化配置")
async def embedding_update_controller(
    data: EmbeddingConfigUpdateSchema,
    id: int = Path(..., description="向量化配置ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Update an embedding configuration by primary key and return the result."""
    updated = await EmbeddingConfigService.update_service(auth=auth, id=id, data=data)
    log.info(f"修改向量化配置成功: {updated}")
    return SuccessResponse(data=updated, msg="修改向量化配置成功")
@EmbeddingConfigRouter.delete("/delete", summary="删除向量化配置")
async def embedding_delete_controller(
    ids: list[int] = Body(..., description="ID列表"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Bulk-delete embedding configurations by id list."""
    await EmbeddingConfigService.delete_service(auth=auth, ids=ids)
    log.info(f"删除向量化配置成功: {ids}")
    return SuccessResponse(msg="删除向量化配置成功")
# ============== 知识库管理 ==============
@KnowledgeBaseRouter.get("/detail/{id}", summary="获取知识库详情")
async def knowledge_detail_controller(
    id: int = Path(..., description="知识库ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Return the detail of a single knowledge base by primary key."""
    detail = await KnowledgeBaseService.detail_service(auth=auth, id=id)
    return SuccessResponse(data=detail, msg="获取知识库详情成功")
@KnowledgeBaseRouter.get("/list", summary="查询知识库列表")
async def knowledge_list_controller(
    page: PaginationQueryParam = Depends(),
    search: KnowledgeBaseQueryParam = Depends(),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Return a paginated list of knowledge bases matching the search filters."""
    records = await KnowledgeBaseService.list_service(auth=auth, search=search, order_by=page.order_by)
    paged = await PaginationService.paginate(
        data_list=records, page_no=page.page_no, page_size=page.page_size
    )
    return SuccessResponse(data=paged, msg="查询知识库列表成功")
@KnowledgeBaseRouter.post("/create", summary="创建知识库")
async def knowledge_create_controller(
    name: str = Form(..., max_length=100, description="知识库名称"),
    embedding_config_id: int | None = Form(None, description="向量化配置ID"),
    description: str | None = Form(None, max_length=255, description="备注"),
    files: list[UploadFile] = File(default=[], description="上传的文件列表"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Create a knowledge base from multipart form fields plus optional files."""
    # Assemble the schema from the individual form fields.
    payload = KnowledgeBaseCreateSchema(
        name=name,
        embedding_config_id=embedding_config_id,
        description=description,
    )
    created = await KnowledgeBaseService.create_service(auth=auth, data=payload, files=files)
    log.info(f"创建知识库成功: {created}")
    return SuccessResponse(data=created, msg="创建知识库成功")
@KnowledgeBaseRouter.put("/update/{id}", summary="修改知识库")
async def knowledge_update_controller(
    data: KnowledgeBaseUpdateSchema,
    id: int = Path(..., description="知识库ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Update a knowledge base by primary key and return the result."""
    updated = await KnowledgeBaseService.update_service(auth=auth, id=id, data=data)
    log.info(f"修改知识库成功: {updated}")
    return SuccessResponse(data=updated, msg="修改知识库成功")
@KnowledgeBaseRouter.delete("/delete", summary="删除知识库")
async def knowledge_delete_controller(
    ids: list[int] = Body(..., description="ID列表"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Bulk-delete knowledge bases by id list."""
    await KnowledgeBaseService.delete_service(auth=auth, ids=ids)
    log.info(f"删除知识库成功: {ids}")
    return SuccessResponse(msg="删除知识库成功")
@KnowledgeBaseRouter.post("/retry/{id}", summary="重试向量化")
async def knowledge_retry_controller(
    id: int = Path(..., description="知识库ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Kick off re-vectorization for the given knowledge base."""
    outcome = await KnowledgeBaseService.retry_service(auth=auth, id=id)
    log.info(f"重试向量化成功: {id}")
    return SuccessResponse(data=outcome, msg="已启动重新向量化")
# ============== AI模型配置 ==============
@AIModelConfigRouter.get("/detail/{model_type}", summary="获取AI模型配置详情")
async def model_config_detail_controller(
    model_type: str = Path(..., description="模型类型(naming/renaming/scoring/report)"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Return the AI model configuration keyed by model type."""
    detail = await AIModelConfigService.detail_service(auth=auth, model_type=model_type)
    return SuccessResponse(data=detail, msg="获取AI模型配置详情成功")
@AIModelConfigRouter.get("/list", summary="查询AI模型配置列表")
async def model_config_list_controller(
    page: PaginationQueryParam = Depends(),
    search: AIModelConfigQueryParam = Depends(),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Return a paginated list of AI model configurations."""
    records = await AIModelConfigService.list_service(auth=auth, search=search, order_by=page.order_by)
    paged = await PaginationService.paginate(
        data_list=records, page_no=page.page_no, page_size=page.page_size
    )
    return SuccessResponse(data=paged, msg="查询AI模型配置列表成功")
@AIModelConfigRouter.put("/update/{model_type}", summary="更新AI模型配置")
async def model_config_update_controller(
    data: AIModelConfigUpdateSchema,
    model_type: str = Path(..., description="模型类型(naming/renaming/scoring/report)"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Update the AI model configuration for the given model type."""
    updated = await AIModelConfigService.update_service(auth=auth, model_type=model_type, data=data)
    log.info(f"更新AI模型配置成功: {model_type}")
    return SuccessResponse(data=updated, msg="更新AI模型配置成功")
@AIModelConfigRouter.get("/available-models/{provider_id}", summary="获取可用模型列表")
async def available_models_controller(
    provider_id: int = Path(..., description="AI供应商ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Return the models available from the given AI provider."""
    models = await AIModelConfigService.get_available_models(auth=auth, provider_id=provider_id)
    return SuccessResponse(data=models, msg="获取可用模型列表成功")
# ============== AI模型训练对话 ==============
@AIModelConfigRouter.get("/messages/{model_type}", summary="获取训练对话记录")
async def training_messages_controller(
    model_type: str = Path(..., description="模型类型(naming/renaming/scoring/report)"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Return the training conversation history for the given model type."""
    messages = await AIModelTrainingService.get_messages_service(auth=auth, model_type=model_type)
    return SuccessResponse(data=messages, msg="获取训练对话记录成功")
@AIModelConfigRouter.delete("/message/{message_id}", summary="删除单条训练对话记录")
async def delete_training_message_controller(
    message_id: int = Path(..., description="对话记录ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Delete one training conversation record by primary key."""
    await AIModelTrainingService.delete_message_service(auth=auth, message_id=message_id)
    log.info(f"删除训练对话记录成功: {message_id}")
    return SuccessResponse(msg="删除训练对话记录成功")
@AIModelConfigRouter.delete("/messages/{model_type}", summary="清空训练对话记录")
async def clear_training_messages_controller(
    model_type: str = Path(..., description="模型类型(naming/renaming/scoring/report)"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """Clear the entire training conversation history for a model type."""
    await AIModelTrainingService.clear_messages_service(auth=auth, model_type=model_type)
    log.info(f"清空训练对话记录成功: {model_type}")
    return SuccessResponse(msg="清空训练对话记录成功")
@AIModelConfigRouter.post("/chat", summary="训练对话")
async def training_chat_controller(
    chat_data: AIModelTrainingChatSchema,
    auth: AuthSchema = Depends(AuthPermission())
) -> StreamingResponse:
    """Stream a training-chat reply as plain text chunks."""
    log.info(f"[训练对话] 收到请求: model_type={chat_data.model_type}, message={chat_data.message[:50]}...")

    async def stream_reply():
        log.info("[训练对话] 开始生成响应")
        emitted = 0
        try:
            async for piece in AIModelTrainingService.chat_stream(auth=auth, chat_data=chat_data):
                if not piece:
                    continue
                emitted += 1
                # Normalize every chunk to bytes before sending downstream.
                yield piece.encode('utf-8') if isinstance(piece, str) else piece
            log.info(f"[训练对话] 响应生成完成,共 {emitted} 个 chunk")
        except Exception as e:
            log.error(f"训练对话出错: {str(e)}")
            yield f"处理您的请求时出现了错误: {str(e)}".encode('utf-8')

    return StreamResponse(stream_reply(), media_type="text/plain; charset=utf-8")
@AIModelConfigRouter.websocket("/ws/chat")
async def training_websocket_chat_controller(
    websocket: WebSocket,
):
    """
    Training-chat WebSocket endpoint.

    Expects JSON frames like {model_type, message, config_changed, config_data}
    and streams the assistant's reply back as text frames.
    """
    # WebSocketDisconnect is re-exported by fastapi; imported locally so this
    # fix is self-contained. The previous unused `verify_token` import and the
    # redundant re-import of AuthSchema (already imported at module level)
    # have been removed.
    from fastapi import WebSocketDisconnect
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_json()
            model_type = data.get('model_type', 'naming')
            message = data.get('message', '')
            config_changed = data.get('config_changed', False)
            config_data = data.get('config_data')
            # Build the chat payload from the raw frame.
            chat_data = AIModelTrainingChatSchema(
                model_type=model_type,
                message=message,
                config_changed=config_changed,
                config_data=AIModelConfigUpdateSchema(**config_data) if config_data else None
            )
            # NOTE(review): this WebSocket has no real authentication flow; a
            # bare AuthSchema is used as a placeholder — implement proper auth
            # before exposing this endpoint.
            auth = AuthSchema()
            # Stream the response for this frame.
            try:
                async for chunk in AIModelTrainingService.chat_stream(auth=auth, chat_data=chat_data):
                    if chunk:
                        await websocket.send_text(chunk)
            except WebSocketDisconnect:
                # Client left mid-stream; let the outer handler deal with it,
                # otherwise send_text below would raise again.
                raise
            except Exception as e:
                log.error(f"处理训练对话出错: {str(e)}")
                await websocket.send_text(f"处理您的请求时出现了错误: {str(e)}")
    except WebSocketDisconnect:
        # A normal client disconnect is not an error condition.
        log.info("WebSocket训练对话连接已断开")
    except Exception as e:
        log.error(f"WebSocket训练对话出错: {str(e)}")
    finally:
        # close() raises RuntimeError when the close handshake has already
        # completed (e.g. after a client-initiated disconnect) — suppress it.
        try:
            await websocket.close()
        except RuntimeError:
            pass

View File

@@ -0,0 +1,270 @@
# -*- coding: utf-8 -*-
from typing import Sequence, Any
from app.core.base_crud import CRUDBase
from app.api.v1.module_system.auth.schema import AuthSchema
from .model import McpModel, AIProviderModel, EmbeddingConfigModel, KnowledgeBaseModel, AIModelConfigModel, AIModelTrainingMessageModel
from .schema import (
McpCreateSchema, McpUpdateSchema,
AIProviderCreateSchema, AIProviderUpdateSchema,
EmbeddingConfigCreateSchema, EmbeddingConfigUpdateSchema,
KnowledgeBaseCreateSchema, KnowledgeBaseUpdateSchema,
AIModelConfigCreateSchema, AIModelConfigUpdateSchema,
AIModelTrainingMessageCreateSchema
)
class McpCRUD(CRUDBase[McpModel, McpCreateSchema, McpUpdateSchema]):
    """Data-access layer for MCP servers; thin delegation to CRUDBase."""

    def __init__(self, auth: AuthSchema) -> None:
        """
        Initialize the CRUD layer.

        Args:
            auth (AuthSchema): authentication context, kept for CRUDBase.
        """
        self.auth = auth
        super().__init__(model=McpModel, auth=auth)

    async def get_by_id_crud(self, id: int, preload: list[str | Any] | None = None) -> McpModel | None:
        """
        Fetch one MCP server by primary key.

        Args:
            id (int): MCP server ID.
            preload (list[str | Any] | None): relationships to eager-load; model defaults when None.

        Returns:
            McpModel | None: the model instance, if it exists.
        """
        return await self.get(id=id, preload=preload)

    async def get_by_name_crud(self, name: str, preload: list[str | Any] | None = None) -> McpModel | None:
        """
        Fetch one MCP server matching the given name.

        Args:
            name (str): MCP server name.
            preload (list[str | Any] | None): relationships to eager-load; model defaults when None.

        Returns:
            McpModel | None: the model instance, if it exists.
        """
        return await self.get(name=name, preload=preload)

    async def get_list_crud(self, search: dict | None = None, order_by: list[dict[str, str]] | None = None, preload: list[str | Any] | None = None) -> Sequence[McpModel]:
        """
        List MCP servers.

        Args:
            search (dict | None): filter criteria; empty dict when None.
            order_by (list[dict[str, str]] | None): sort spec; ascending id when None.
            preload (list[str | Any] | None): relationships to eager-load; model defaults when None.

        Returns:
            Sequence[McpModel]: matching model instances.
        """
        return await self.list(search=search or {}, order_by=order_by or [{'id': 'asc'}], preload=preload)

    async def create_crud(self, data: McpCreateSchema) -> McpModel | None:
        """
        Create an MCP server record.

        Args:
            data (McpCreateSchema): creation payload.

        Returns:
            McpModel | None: the created instance on success.
        """
        return await self.create(data=data)

    async def update_crud(self, id: int, data: McpUpdateSchema) -> McpModel | None:
        """
        Update an MCP server record.

        Args:
            id (int): MCP server ID.
            data (McpUpdateSchema): update payload.

        Returns:
            McpModel | None: the updated instance on success.
        """
        return await self.update(id=id, data=data)

    async def delete_crud(self, ids: list[int]) -> None:
        """
        Bulk-delete MCP servers.

        Args:
            ids (list[int]): MCP server ID list.

        Returns:
            None
        """
        return await self.delete(ids=ids)
class AIProviderCRUD(CRUDBase[AIProviderModel, AIProviderCreateSchema, AIProviderUpdateSchema]):
    """Data-access layer for AI provider configurations."""

    def __init__(self, auth: AuthSchema) -> None:
        # Keep the auth context for CRUDBase.
        self.auth = auth
        super().__init__(model=AIProviderModel, auth=auth)

    async def get_by_id_crud(self, id: int, preload: list[str | Any] | None = None) -> AIProviderModel | None:
        """Fetch one provider by primary key."""
        return await self.get(id=id, preload=preload)

    async def get_by_name_crud(self, name: str, preload: list[str | Any] | None = None) -> AIProviderModel | None:
        """Fetch one provider matching the given name."""
        return await self.get(name=name, preload=preload)

    async def get_list_crud(self, search: dict | None = None, order_by: list[dict[str, str]] | None = None, preload: list[str | Any] | None = None) -> Sequence[AIProviderModel]:
        """List providers; defaults to ascending id order."""
        return await self.list(search=search or {}, order_by=order_by or [{'id': 'asc'}], preload=preload)

    async def create_crud(self, data: AIProviderCreateSchema) -> AIProviderModel | None:
        """Create a provider record."""
        return await self.create(data=data)

    async def update_crud(self, id: int, data: AIProviderUpdateSchema) -> AIProviderModel | None:
        """Update a provider record by id."""
        return await self.update(id=id, data=data)

    async def delete_crud(self, ids: list[int]) -> None:
        """Bulk-delete providers by id list."""
        return await self.delete(ids=ids)

    async def clear_default_crud(self) -> None:
        """Clear the default flag on every provider currently marked default."""
        obj_list = await self.list(search={'is_default': 1})
        for obj in obj_list:
            # Rebuild the full update schema from the row; only is_default
            # changes to 0.
            await self.update(id=obj.id, data=AIProviderUpdateSchema(
                name=obj.name,
                provider_type=obj.provider_type,
                base_url=obj.base_url,
                api_key=obj.api_key,
                is_default=0,
                description=obj.description
            ))
class EmbeddingConfigCRUD(CRUDBase[EmbeddingConfigModel, EmbeddingConfigCreateSchema, EmbeddingConfigUpdateSchema]):
    """Data-access layer for embedding (vectorization) configurations."""

    def __init__(self, auth: AuthSchema) -> None:
        # Keep the auth context for CRUDBase.
        self.auth = auth
        super().__init__(model=EmbeddingConfigModel, auth=auth)

    async def get_by_id_crud(self, id: int, preload: list[str | Any] | None = None) -> EmbeddingConfigModel | None:
        """Fetch one embedding configuration by primary key."""
        return await self.get(id=id, preload=preload)

    async def get_by_name_crud(self, name: str, preload: list[str | Any] | None = None) -> EmbeddingConfigModel | None:
        """Fetch one embedding configuration matching the given name."""
        return await self.get(name=name, preload=preload)

    async def get_list_crud(self, search: dict | None = None, order_by: list[dict[str, str]] | None = None, preload: list[str | Any] | None = None) -> Sequence[EmbeddingConfigModel]:
        """List embedding configurations; defaults to ascending id order."""
        return await self.list(search=search or {}, order_by=order_by or [{'id': 'asc'}], preload=preload)

    async def create_crud(self, data: EmbeddingConfigCreateSchema) -> EmbeddingConfigModel | None:
        """Create an embedding configuration record."""
        return await self.create(data=data)

    async def update_crud(self, id: int, data: EmbeddingConfigUpdateSchema) -> EmbeddingConfigModel | None:
        """Update an embedding configuration by id."""
        return await self.update(id=id, data=data)

    async def delete_crud(self, ids: list[int]) -> None:
        """Bulk-delete embedding configurations by id list."""
        return await self.delete(ids=ids)

    async def clear_default_crud(self) -> None:
        """Clear the default flag on every configuration currently marked default."""
        obj_list = await self.list(search={'is_default': 1})
        for obj in obj_list:
            # Rebuild the full update schema from the row; only is_default
            # changes to 0.
            await self.update(id=obj.id, data=EmbeddingConfigUpdateSchema(
                name=obj.name,
                embedding_type=obj.embedding_type,
                model_name=obj.model_name,
                base_url=obj.base_url,
                api_key=obj.api_key,
                is_default=0,
                description=obj.description
            ))
class KnowledgeBaseCRUD(CRUDBase[KnowledgeBaseModel, KnowledgeBaseCreateSchema, KnowledgeBaseUpdateSchema]):
    """Data-access layer for knowledge bases."""

    def __init__(self, auth: AuthSchema) -> None:
        # Keep the auth context for CRUDBase.
        self.auth = auth
        super().__init__(model=KnowledgeBaseModel, auth=auth)

    async def get_by_id_crud(self, id: int, preload: list[str | Any] | None = None) -> KnowledgeBaseModel | None:
        """Fetch one knowledge base by primary key."""
        return await self.get(id=id, preload=preload)

    async def get_by_name_crud(self, name: str, preload: list[str | Any] | None = None) -> KnowledgeBaseModel | None:
        """Fetch one knowledge base matching the given name."""
        return await self.get(name=name, preload=preload)

    async def get_by_collection_crud(self, collection_name: str, preload: list[str | Any] | None = None) -> KnowledgeBaseModel | None:
        """Fetch one knowledge base by its vector-store collection name."""
        return await self.get(collection_name=collection_name, preload=preload)

    async def get_list_crud(self, search: dict | None = None, order_by: list[dict[str, str]] | None = None, preload: list[str | Any] | None = None) -> Sequence[KnowledgeBaseModel]:
        """List knowledge bases; defaults to descending id order (newest first)."""
        return await self.list(search=search or {}, order_by=order_by or [{'id': 'desc'}], preload=preload)

    async def create_crud(self, data: dict) -> KnowledgeBaseModel | None:
        """Create a knowledge base. Accepts a dict because extra fields beyond
        the create schema are required."""
        return await self.create(data=data)

    async def update_crud(self, id: int, data: KnowledgeBaseUpdateSchema | dict) -> KnowledgeBaseModel | None:
        """Update a knowledge base by id; accepts a schema or a raw dict."""
        return await self.update(id=id, data=data)

    async def delete_crud(self, ids: list[int]) -> None:
        """Bulk-delete knowledge bases by id list."""
        return await self.delete(ids=ids)
class AIModelConfigCRUD(CRUDBase[AIModelConfigModel, AIModelConfigCreateSchema, AIModelConfigUpdateSchema]):
    """Data-access layer for AI model configurations."""

    def __init__(self, auth: AuthSchema) -> None:
        # Keep the auth context for CRUDBase.
        self.auth = auth
        super().__init__(model=AIModelConfigModel, auth=auth)

    async def get_by_id_crud(self, id: int, preload: list[str | Any] | None = None) -> AIModelConfigModel | None:
        """Fetch one model configuration by primary key."""
        return await self.get(id=id, preload=preload)

    async def get_by_model_type_crud(self, model_type: str, preload: list[str | Any] | None = None) -> AIModelConfigModel | None:
        """Fetch one model configuration by its model_type key."""
        return await self.get(model_type=model_type, preload=preload)

    async def get_list_crud(self, search: dict | None = None, order_by: list[dict[str, str]] | None = None, preload: list[str | Any] | None = None) -> Sequence[AIModelConfigModel]:
        """List model configurations; defaults to ascending id order."""
        return await self.list(search=search or {}, order_by=order_by or [{'id': 'asc'}], preload=preload)

    async def create_crud(self, data: AIModelConfigCreateSchema | dict) -> AIModelConfigModel | None:
        """Create a model configuration; accepts a schema or a raw dict."""
        return await self.create(data=data)

    async def update_crud(self, id: int, data: AIModelConfigUpdateSchema | dict) -> AIModelConfigModel | None:
        """Update a model configuration by id; accepts a schema or a raw dict."""
        return await self.update(id=id, data=data)

    async def delete_crud(self, ids: list[int]) -> None:
        """Bulk-delete model configurations by id list."""
        return await self.delete(ids=ids)
class AIModelTrainingMessageCRUD(CRUDBase[AIModelTrainingMessageModel, AIModelTrainingMessageCreateSchema, AIModelTrainingMessageCreateSchema]):
    """Data-access layer for AI model training conversation records."""

    def __init__(self, auth: AuthSchema) -> None:
        # Keep the auth context for CRUDBase.
        self.auth = auth
        super().__init__(model=AIModelTrainingMessageModel, auth=auth)

    async def get_by_id_crud(self, id: int, preload: list[str | Any] | None = None) -> AIModelTrainingMessageModel | None:
        """Fetch one training message by primary key."""
        return await self.get(id=id, preload=preload)

    async def get_list_by_config_crud(self, model_config_id: int, preload: list[str | Any] | None = None) -> Sequence[AIModelTrainingMessageModel]:
        """List all training messages of one model configuration, oldest first."""
        return await self.list(search={'model_config_id': model_config_id}, order_by=[{'id': 'asc'}], preload=preload)

    async def create_crud(self, data: AIModelTrainingMessageCreateSchema | dict) -> AIModelTrainingMessageModel | None:
        """Create a training message; accepts a schema or a raw dict."""
        return await self.create(data=data)

    async def delete_crud(self, ids: list[int]) -> None:
        """Bulk-delete training messages by id list."""
        return await self.delete(ids=ids)

    async def delete_by_config_crud(self, model_config_id: int) -> None:
        """Delete every training message of one model configuration."""
        messages = await self.list(search={'model_config_id': model_config_id})
        if messages:
            ids = [m.id for m in messages]
            await self.delete(ids=ids)

View File

@@ -0,0 +1,155 @@
'''
Author: caoziyuan ziyuan.cao@zhuying.com
Date: 2025-12-22 17:42:10
LastEditors: caoziyuan ziyuan.cao@zhuying.com
LastEditTime: 2025-12-22 18:03:28
FilePath: \naming-backend\app\api\v1\module_application\ai\model.py
Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE
'''
# -*- coding: utf-8 -*-
from sqlalchemy import JSON, String, Integer, Text, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.core.base_model import ModelMixin, UserMixin
class McpModel(ModelMixin, UserMixin):
    """
    MCP server table.

    MCP type:
    - 0: stdio (standard input/output)
    - 1: sse (Server-Sent Events)
    """
    __tablename__: str = 'app_ai_mcp'
    __table_args__: dict[str, str] = ({'comment': 'MCP 服务器表'})
    # Relationships loaded eagerly by default (consumed by the project's CRUDBase).
    __loader_options__: list[str] = ["created_by", "updated_by"]
    name: Mapped[str] = mapped_column(String(50), comment='MCP 名称')
    type: Mapped[int] = mapped_column(Integer, default=0, comment='MCP 类型(0:stdio 1:sse)')
    # url is used for type 1 (sse); command/args/env for type 0 (stdio).
    url: Mapped[str | None] = mapped_column(String(255), default=None, comment='远程 SSE 地址')
    command: Mapped[str | None] = mapped_column(String(255), default=None, comment='MCP 命令')
    args: Mapped[str | None] = mapped_column(String(255), default=None, comment='MCP 命令参数')
    env: Mapped[dict[str, str] | None] = mapped_column(JSON(), default=None, comment='MCP 环境变量')
class AIProviderModel(ModelMixin, UserMixin):
    """
    AI provider configuration table.

    Stores each AI service provider's endpoint URL and API key.
    """
    __tablename__: str = 'app_ai_provider'
    __table_args__: dict[str, str] = ({'comment': 'AI供应商配置表'})
    # Relationships loaded eagerly by default (consumed by the project's CRUDBase).
    __loader_options__: list[str] = ["created_by", "updated_by"]
    name: Mapped[str] = mapped_column(String(50), nullable=False, comment='供应商名称')
    provider_type: Mapped[str] = mapped_column(String(50), nullable=False, comment='供应商类型(openai/deepseek/anthropic/gemini/qwen等)')
    base_url: Mapped[str] = mapped_column(String(255), nullable=False, comment='接口地址BaseURL')
    # NOTE(review): api_key appears to be stored in plaintext — consider encrypting at rest.
    api_key: Mapped[str] = mapped_column(String(255), nullable=False, comment='API Key')
    is_default: Mapped[int] = mapped_column(Integer, default=0, comment='是否默认供应商(0:否 1:是)')
class EmbeddingConfigModel(ModelMixin, UserMixin):
    """
    Knowledge-base embedding configuration table.

    Supports either a local or a remote vectorization service.
    """
    __tablename__: str = 'app_ai_embedding_config'
    __table_args__: dict[str, str] = ({'comment': '知识库向量化配置表'})
    # Relationships loaded eagerly by default (consumed by the project's CRUDBase).
    __loader_options__: list[str] = ["created_by", "updated_by"]
    name: Mapped[str] = mapped_column(String(50), nullable=False, comment='配置名称')
    embedding_type: Mapped[int] = mapped_column(Integer, default=0, comment='向量化类型(0:本地 1:远程)')
    model_name: Mapped[str] = mapped_column(String(100), nullable=False, comment='Embedding模型名称')
    # base_url/api_key are only meaningful in remote mode (embedding_type=1).
    base_url: Mapped[str | None] = mapped_column(String(255), default=None, comment='远程接口地址(远程模式必填)')
    api_key: Mapped[str | None] = mapped_column(String(255), default=None, comment='远程API Key(远程模式必填)')
    is_default: Mapped[int] = mapped_column(Integer, default=0, comment='是否默认配置(0:否 1:是)')
class KnowledgeBaseModel(ModelMixin, UserMixin):
    """
    Knowledge base table.

    Stores knowledge-base metadata and links to an embedding configuration.
    """
    __tablename__: str = 'app_ai_knowledge_base'
    __table_args__: dict[str, str] = ({'comment': '知识库表'})
    # Relationships loaded eagerly by default (consumed by the project's CRUDBase).
    __loader_options__: list[str] = ["created_by", "updated_by", "embedding_config"]
    name: Mapped[str] = mapped_column(String(100), nullable=False, comment='知识库名称')
    # SET NULL on delete: removing an embedding config orphans, not deletes, the KB.
    embedding_config_id: Mapped[int | None] = mapped_column(
        Integer,
        ForeignKey('app_ai_embedding_config.id', ondelete="SET NULL", onupdate="CASCADE"),
        default=None,
        nullable=True,
        index=True,
        comment='向量化配置ID'
    )
    collection_name: Mapped[str] = mapped_column(String(100), nullable=False, comment='ChromaDB集合名称')
    document_count: Mapped[int] = mapped_column(Integer, default=0, comment='文档数量')
    vector_count: Mapped[int] = mapped_column(Integer, default=0, comment='向量数量')
    kb_status: Mapped[int] = mapped_column(Integer, default=0, comment='知识库状态(0:待处理 1:处理中 2:已完成 3:处理失败)')
    error_message: Mapped[str | None] = mapped_column(Text, default=None, comment='错误信息')
    file_paths: Mapped[list[str] | None] = mapped_column(JSON, default=None, comment='文件路径列表')
    # Relationships
    embedding_config: Mapped["EmbeddingConfigModel | None"] = relationship(
        "EmbeddingConfigModel",
        lazy="selectin",
        foreign_keys=[embedding_config_id],
        uselist=False
    )
class AIModelConfigModel(ModelMixin, UserMixin):
    """
    AI model configuration table.

    One row per business model type (``model_type`` is unique), pointing at
    an AI provider, a concrete model name, a system prompt and an optional
    list of knowledge-base ids.

    Known model types:
    - enterprise_naming, enterprise_renaming, enterprise_scoring, enterprise_scoring_trial
    - personal_naming, personal_renaming, personal_scoring, personal_scoring_trial
    """
    __tablename__: str = 'app_ai_model_config'
    __table_args__: dict[str, str] = ({'comment': 'AI模型配置表'})
    # relationship attributes to eager-load together with this model
    # NOTE(review): "knowledge_bases" is listed here but no such relationship
    # is defined on this class (only the JSON id list below) — confirm that
    # the loader machinery tolerates or uses this entry.
    __loader_options__: list[str] = ["created_by", "updated_by", "provider", "knowledge_bases"]
    # unique business model type key
    model_type: Mapped[str] = mapped_column(String(50), nullable=False, unique=True, comment='模型类型(enterprise_naming/enterprise_renaming/enterprise_scoring/enterprise_scoring_trial/personal_naming/personal_renaming/personal_scoring/personal_scoring_trial)')
    # concrete LLM model name used for this type
    model_name: Mapped[str | None] = mapped_column(String(100), default=None, comment='使用的模型名称')
    # FK to the AI provider; set to NULL if the provider is deleted
    provider_id: Mapped[int | None] = mapped_column(
        Integer,
        ForeignKey('app_ai_provider.id', ondelete="SET NULL", onupdate="CASCADE"),
        default=None,
        nullable=True,
        index=True,
        comment='AI供应商ID'
    )
    # system prompt sent ahead of every conversation
    system_prompt: Mapped[str | None] = mapped_column(Text, default=None, comment='系统提示词')
    # sampling temperature, expected range 0-2
    temperature: Mapped[float] = mapped_column(default=1.0, comment='模型温度(0-2)')
    # associated knowledge-base ids stored as a JSON array (no FK constraint)
    knowledge_base_ids: Mapped[list[int] | None] = mapped_column(JSON, default=None, comment='关联的知识库ID列表')
    # eager-loaded provider (None when unset or deleted)
    provider: Mapped["AIProviderModel | None"] = relationship(
        "AIProviderModel",
        lazy="selectin",
        foreign_keys=[provider_id],
        uselist=False
    )
class AIModelTrainingMessageModel(ModelMixin, UserMixin):
    """
    AI model training-chat message table.

    Stores the per-configuration history of training conversations; rows are
    removed together with their model configuration (ON DELETE CASCADE).
    """
    __tablename__: str = 'app_ai_model_training_message'
    __table_args__: dict[str, str] = ({'comment': 'AI模型训练对话记录表'})
    # relationship attributes to eager-load together with this model
    __loader_options__: list[str] = ["created_by", "updated_by"]
    # owning model configuration; cascade-deleted with it
    model_config_id: Mapped[int] = mapped_column(
        Integer,
        ForeignKey('app_ai_model_config.id', ondelete="CASCADE", onupdate="CASCADE"),
        nullable=False,
        index=True,
        comment='模型配置ID'
    )
    # message author: "user" or "assistant"
    role: Mapped[str] = mapped_column(String(20), nullable=False, comment='角色(user/assistant)')
    # message body
    content: Mapped[str] = mapped_column(Text, nullable=False, comment='消息内容')

View File

@@ -0,0 +1,307 @@
# -*- coding: utf-8 -*-
from pydantic import ConfigDict, Field, HttpUrl, BaseModel
from fastapi import Query
from app.core.base_schema import BaseSchema
from app.common.enums import McpLLMProvider, EmbeddingType, KnowledgeBaseStatus
from app.core.base_schema import BaseSchema, UserBySchema
from app.common.enums import McpType
from app.core.validator import DateTimeStr
class ChatQuerySchema(BaseModel):
    """Request body for the chat endpoint: a single user message."""
    message: str = Field(..., min_length=1, max_length=4000, description="聊天消息")
class McpCreateSchema(BaseModel):
    """Payload for creating an MCP server entry."""
    name: str = Field(..., max_length=64, description='MCP 名称')
    type: McpType = Field(McpType.stdio, description='MCP 类型')
    description: str | None = Field(None, max_length=255, description='MCP 描述')
    # remote (SSE) servers use `url`; stdio servers use command/args/env
    url: HttpUrl | None = Field(None, description='远程 SSE 地址')
    command: str | None = Field(None, max_length=255, description='MCP 命令')
    args: str | None = Field(None, max_length=255, description='MCP 命令参数,多个参数用英文逗号隔开')
    env: dict[str, str] | None = Field(None, description='MCP 环境变量')


class McpUpdateSchema(McpCreateSchema):
    """Payload for updating an MCP server entry (same fields as create)."""
    ...


class McpOutSchema(McpCreateSchema, BaseSchema, UserBySchema):
    """MCP server detail returned to clients."""
    model_config = ConfigDict(from_attributes=True)
class McpQueryParam:
    """
    Query-string filters for listing MCP servers.

    Every filter attribute is always defined (None when unused) so that
    downstream filter builders can read them unconditionally.
    """

    def __init__(
        self,
        name: str | None = Query(None, description="MCP 名称"),
        type: McpType | None = Query(None, description="MCP 类型"),
        created_time: list[DateTimeStr] | None = Query(None, description="创建时间范围", examples=["2025-01-01 00:00:00", "2025-12-31 23:59:59"]),
        updated_time: list[DateTimeStr] | None = Query(None, description="更新时间范围", examples=["2025-01-01 00:00:00", "2025-12-31 23:59:59"]),
        created_id: int | None = Query(None, description="创建人"),
        updated_id: int | None = Query(None, description="更新人"),
    ) -> None:
        # fuzzy (LIKE) match
        self.name = ("like", name) if name else None
        # exact-match filters
        self.type = type
        self.created_id = created_id
        self.updated_id = updated_id
        # time-range filters.
        # FIX: previously these attributes were only assigned inside the `if`,
        # leaving them undefined (AttributeError risk) when no range was
        # supplied — unlike the sibling QueryParam classes, which always set
        # every filter attribute.
        self.created_time = (
            ("between", (created_time[0], created_time[1]))
            if created_time and len(created_time) == 2 else None
        )
        self.updated_time = (
            ("between", (updated_time[0], updated_time[1]))
            if updated_time and len(updated_time) == 2 else None
        )
class McpChatParam(BaseSchema):
    """Parameters for chatting through one or more MCP servers."""
    pk: list[int] = Field(..., description='MCP ID 列表')
    provider: McpLLMProvider = Field(McpLLMProvider.openai, description='LLM 供应商')
    model: str = Field(..., description='LLM 名称')
    key: str = Field(..., description='LLM API Key')
    # custom endpoint must be OpenAI-compatible
    base_url: str | None = Field(None, description='自定义 LLM API 地址,必须兼容 openai 供应商')
    prompt: str = Field(..., description='用户提示词')
# ============== AI provider configuration schemas ==============
class AIProviderCreateSchema(BaseModel):
    """Payload for creating an AI provider."""
    name: str = Field(..., max_length=50, description='供应商名称')
    provider_type: str = Field(..., max_length=50, description='供应商类型(openai/deepseek/anthropic/gemini/qwen等)')
    base_url: str = Field(..., max_length=255, description='接口地址BaseURL')
    api_key: str = Field(..., max_length=255, description='API Key')
    # 0 = no, 1 = yes
    is_default: int = Field(0, description='是否默认供应商(0:否 1:是)')
    description: str | None = Field(None, max_length=255, description='备注')


class AIProviderUpdateSchema(AIProviderCreateSchema):
    """Payload for updating an AI provider (same fields as create)."""
    ...


class AIProviderOutSchema(AIProviderCreateSchema, BaseSchema, UserBySchema):
    """AI provider detail returned to clients."""
    model_config = ConfigDict(from_attributes=True)


class AIProviderQueryParam:
    """Query-string filters for listing AI providers."""

    def __init__(
        self,
        name: str | None = Query(None, description="供应商名称"),
        provider_type: str | None = Query(None, description="供应商类型"),
        is_default: int | None = Query(None, description="是否默认"),
    ) -> None:
        # name uses fuzzy (LIKE) matching; the rest are exact filters
        self.name = ("like", name) if name else None
        self.provider_type = provider_type
        self.is_default = is_default
# ============== Embedding configuration schemas ==============
class EmbeddingConfigCreateSchema(BaseModel):
    """Payload for creating an embedding configuration."""
    name: str = Field(..., max_length=50, description='配置名称')
    # 0 = local model, 1 = remote (OpenAI-compatible) endpoint
    embedding_type: int = Field(0, description='向量化类型(0:本地 1:远程)')
    model_name: str = Field(..., max_length=100, description='Embedding模型名称')
    # base_url / api_key are required when embedding_type == 1
    base_url: str | None = Field(None, max_length=255, description='远程接口地址(远程模式必填)')
    api_key: str | None = Field(None, max_length=255, description='远程API Key(远程模式必填)')
    is_default: int = Field(0, description='是否默认配置(0:否 1:是)')
    description: str | None = Field(None, max_length=255, description='备注')


class EmbeddingConfigUpdateSchema(EmbeddingConfigCreateSchema):
    """Payload for updating an embedding configuration (same fields as create)."""
    ...


class EmbeddingConfigOutSchema(EmbeddingConfigCreateSchema, BaseSchema, UserBySchema):
    """Embedding configuration detail returned to clients."""
    model_config = ConfigDict(from_attributes=True)


class EmbeddingConfigQueryParam:
    """Query-string filters for listing embedding configurations."""

    def __init__(
        self,
        name: str | None = Query(None, description="配置名称"),
        embedding_type: int | None = Query(None, description="向量化类型"),
        is_default: int | None = Query(None, description="是否默认"),
    ) -> None:
        # name uses fuzzy (LIKE) matching; the rest are exact filters
        self.name = ("like", name) if name else None
        self.embedding_type = embedding_type
        self.is_default = is_default
# ============== Knowledge base schemas ==============
class KnowledgeBaseCreateSchema(BaseModel):
    """Payload for creating a knowledge base."""
    name: str = Field(..., max_length=100, description='知识库名称')
    embedding_config_id: int | None = Field(None, description='向量化配置ID')
    description: str | None = Field(None, max_length=255, description='备注')


class KnowledgeBaseUpdateSchema(BaseModel):
    """Payload for partially updating a knowledge base (all fields optional)."""
    name: str | None = Field(None, max_length=100, description='知识库名称')
    embedding_config_id: int | None = Field(None, description='向量化配置ID')
    description: str | None = Field(None, max_length=255, description='备注')


class EmbeddingConfigRefSchema(BaseModel):
    """Slim embedding-config reference embedded in knowledge-base output."""
    model_config = ConfigDict(from_attributes=True)
    id: int
    name: str
    embedding_type: int
    model_name: str


class KnowledgeBaseOutSchema(BaseSchema, UserBySchema):
    """Knowledge base detail returned to clients."""
    model_config = ConfigDict(from_attributes=True)
    name: str
    embedding_config_id: int | None = None
    collection_name: str
    document_count: int = 0
    vector_count: int = 0
    # 0 pending, 1 processing, 2 done, 3 failed
    kb_status: int = 0
    error_message: str | None = None
    description: str | None = None
    file_paths: list[str] | None = None
    embedding_config: EmbeddingConfigRefSchema | None = None


class KnowledgeBaseQueryParam:
    """Query-string filters for listing knowledge bases."""

    def __init__(
        self,
        name: str | None = Query(None, description="知识库名称"),
        embedding_config_id: int | None = Query(None, description="向量化配置ID"),
        kb_status: int | None = Query(None, description="知识库状态"),
    ) -> None:
        # name uses fuzzy (LIKE) matching; the rest are exact filters
        self.name = ("like", name) if name else None
        self.embedding_config_id = embedding_config_id
        self.kb_status = kb_status
# ============== AI model configuration ==============
# Supported AI model types: internal key -> Chinese display label.
AI_MODEL_TYPES = dict(
    enterprise_naming='企业起名',
    enterprise_renaming='企业改名',
    enterprise_scoring='企业测名',
    enterprise_scoring_trial='企业测名试用',
    personal_naming='个人起名',
    personal_renaming='个人改名',
    personal_scoring='个人测名',
    personal_scoring_trial='个人测名试用',
    yuanfen_hepan='缘分合盘',
    bazi_zeji='八字择吉',
    caiyun_jiexi='财运解析',
    caiyun_jiexi_qiye='企业财运解析',
)
class AIProviderRefSchema(BaseModel):
    """Slim provider reference embedded in model-config output."""
    model_config = ConfigDict(from_attributes=True)
    id: int
    name: str
    provider_type: str
    base_url: str
class AIModelConfigCreateSchema(BaseModel):
    """Payload for creating an AI model configuration."""
    # FIX: the description previously said "naming/renaming/scoring/report",
    # which does not match the model types actually used by this module
    # (see AI_MODEL_TYPES and the unique model_type column).
    model_type: str = Field(..., max_length=50, description='模型类型(见 AI_MODEL_TYPES例如 enterprise_naming/personal_naming 等)')
    model_name: str | None = Field(None, max_length=100, description='使用的模型名称')
    provider_id: int | None = Field(None, description='AI供应商ID')
    system_prompt: str | None = Field(None, description='系统提示词')
    temperature: float = Field(1.0, ge=0, le=2, description='模型温度(0-2)')
    knowledge_base_ids: list[int] | None = Field(None, description='关联的知识库ID列表')
class AIModelConfigUpdateSchema(BaseModel):
    """Payload for partially updating an AI model configuration (all fields optional)."""
    model_name: str | None = Field(None, max_length=100, description='使用的模型名称')
    provider_id: int | None = Field(None, description='AI供应商ID')
    system_prompt: str | None = Field(None, description='系统提示词')
    temperature: float | None = Field(None, ge=0, le=2, description='模型温度(0-2)')
    knowledge_base_ids: list[int] | None = Field(None, description='关联的知识库ID列表')


class AIModelConfigOutSchema(BaseSchema, UserBySchema):
    """AI model configuration detail returned to clients."""
    model_config = ConfigDict(from_attributes=True)
    model_type: str
    model_name: str | None = None
    provider_id: int | None = None
    system_prompt: str | None = None
    temperature: float = 1.0
    knowledge_base_ids: list[int] | None = None
    provider: AIProviderRefSchema | None = None


class AIModelConfigQueryParam:
    """Query-string filters for listing AI model configurations."""

    def __init__(
        self,
        model_type: str | None = Query(None, description="模型类型"),
    ) -> None:
        # exact filter
        self.model_type = model_type
# ============== AI model training-chat schemas ==============
class AIModelTrainingMessageCreateSchema(BaseModel):
    """Payload for persisting one training-chat message."""
    model_config_id: int = Field(..., description='模型配置ID')
    role: str = Field(..., max_length=20, description='角色(user/assistant)')
    content: str = Field(..., description='消息内容')


class AIModelTrainingMessageOutSchema(BaseSchema, UserBySchema):
    """Training-chat message detail returned to clients."""
    model_config = ConfigDict(from_attributes=True)
    model_config_id: int
    role: str
    content: str


class AIModelTrainingChatSchema(BaseModel):
    """Request body for one round of training chat."""
    model_type: str = Field(..., description='模型类型')
    message: str = Field(..., min_length=1, description='用户消息')
    # when the configuration changed client-side, it is saved before chatting
    config_changed: bool = Field(False, description='配置是否有变动')
    config_data: AIModelConfigUpdateSchema | None = Field(None, description='变动的配置数据')
class AIModelTestSchema(BaseModel):
    """Request body for naming/scoring test calls (mini-program / external callers)."""
    model_type: str = Field(..., description='模型类型(enterprise_naming/personal_naming等)')
    text: str = Field(..., min_length=1, max_length=4000, description='用户输入的文本')

View File

@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-

View File

@@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-
from typing import AsyncGenerator
from openai import AsyncOpenAI, OpenAI
from openai.types.chat.chat_completion import ChatCompletion
import httpx
from app.config.setting import settings
from app.core.logger import log
class AIClient:
    """
    Thin async wrapper around the OpenAI chat-completions API.

    Reads the model, API key and base URL from application settings and
    streams chat responses; API/network failures are converted into friendly
    Chinese messages instead of being raised to the caller.
    """

    def __init__(self):
        self.model = settings.OPENAI_MODEL
        # dedicated httpx client so we control timeout/redirect behaviour
        # without passing conflicting options through the SDK
        self.http_client = httpx.AsyncClient(timeout=30.0, follow_redirects=True)
        # OpenAI SDK client wired to the custom transport
        self.client = AsyncOpenAI(
            api_key=settings.OPENAI_API_KEY,
            base_url=settings.OPENAI_BASE_URL,
            http_client=self.http_client,
        )

    def _friendly_error_message(self, e: Exception) -> str:
        """Map an OpenAI/network exception to a user-friendly Chinese message."""
        status = getattr(e, "status_code", None)
        payload = getattr(e, "body", None)
        detail = err_type = err_code = None
        try:
            if isinstance(payload, dict) and "error" in payload:
                info = payload.get("error") or {}
                err_type = info.get("type")
                err_code = info.get("code")
                detail = info.get("message")
        except Exception:
            pass  # malformed error body — fall back to str(e)
        msg = detail or str(e)
        # checks are ordered most-specific first; keep that order
        # arrears / account standing problems
        if (err_code == "Arrearage") or (err_type == "Arrearage") or ("in good standing" in (msg or "")):
            return "账户欠费或结算异常,访问被拒绝。请检查账号状态或更换有效的 API Key。"
        # authentication failure
        if status == 401 or "invalid api key" in msg.lower():
            return "鉴权失败API Key 无效或已过期。请检查系统配置中的 API Key。"
        # permission denied
        if status == 403 or err_type in {"PermissionDenied", "permission_denied"}:
            return "访问被拒绝,权限不足或账号受限。请检查账户权限设置。"
        # quota exhausted / rate limited
        if status == 429 or err_type in {"insufficient_quota", "rate_limit_exceeded"}:
            return "请求过于频繁或配额已用尽。请稍后重试或提升账户配额。"
        # client-side request error
        if status == 400:
            return f"请求参数错误或服务拒绝:{detail or '请检查输入内容。'}"
        # server-side error
        if status in {500, 502, 503, 504}:
            return "服务暂时不可用,请稍后重试。"
        # generic fallback
        return f"处理您的请求时出现错误:{msg}"

    async def process(self, query: str) -> AsyncGenerator[str, None]:
        """
        Stream the model's answer to *query* chunk by chunk.

        参数:
        - query (str): the user's question.

        返回:
        - AsyncGenerator[str, None]: text fragments as they arrive; on failure
          a single friendly error message is yielded instead of raising.
        """
        system_prompt = """你是一个有用的AI助手可以帮助用户回答问题和提供帮助。请用中文回答用户的问题。"""
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": query},
        ]
        try:
            stream = await self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                stream=True,
            )
            async for part in stream:
                if part.choices and part.choices[0].delta.content:
                    yield part.choices[0].delta.content
        except Exception as e:
            # log the raw error, surface only the friendly message
            log.error(f"AI处理查询失败: {str(e)}")
            yield self._friendly_error_message(e)

    async def close(self) -> None:
        """Release the OpenAI client and the underlying httpx transport."""
        if hasattr(self, 'client'):
            await self.client.close()
        if hasattr(self, 'http_client'):
            await self.http_client.aclose()

View File

@@ -0,0 +1,252 @@
# -*- coding: utf-8 -*-
"""
文档处理工具类
支持 txt、pdf、doc、docx、md 格式解析
"""
import os
import tempfile
from pathlib import Path
from typing import List, Optional, BinaryIO
from langchain.schema import Document
from fastapi import UploadFile
from app.core.logger import log
class DocumentProcessor:
    """
    Document parser for uploaded files.

    Supports txt / pdf / doc / docx / md and converts each file into a list
    of langchain ``Document`` objects with source metadata.
    """

    # supported lowercase file extensions (with leading dot)
    SUPPORTED_EXTENSIONS = {'.txt', '.pdf', '.doc', '.docx', '.md'}

    @classmethod
    def is_supported(cls, filename: str) -> bool:
        """Return True if *filename* has a supported extension."""
        ext = Path(filename).suffix.lower()
        return ext in cls.SUPPORTED_EXTENSIONS

    @classmethod
    async def process_upload_file(cls, file: UploadFile) -> List[Document]:
        """
        Parse one uploaded file into Documents.

        参数:
        - file: FastAPI UploadFile object

        返回:
        - extracted Documents (empty list when unsupported or on failure)
        """
        if not file.filename:
            log.warning("文件名为空,跳过处理")
            return []
        ext = Path(file.filename).suffix.lower()
        if not cls.is_supported(file.filename):
            log.warning(f"不支持的文件类型: {ext}")
            return []
        # read the whole upload, then reset so callers can re-read it
        content = await file.read()
        await file.seek(0)
        try:
            if ext == '.txt':
                return cls._process_txt(content, file.filename)
            elif ext == '.md':
                return cls._process_markdown(content, file.filename)
            elif ext == '.pdf':
                return await cls._process_pdf(content, file.filename)
            elif ext in {'.doc', '.docx'}:
                return await cls._process_word(content, file.filename, ext)
            else:
                log.warning(f"未知的文件类型: {ext}")
                return []
        except Exception as e:
            log.error(f"处理文件失败: {file.filename}, 错误: {e}")
            return []

    @classmethod
    def _process_txt(cls, content: bytes, filename: str) -> List[Document]:
        """Decode a plain-text file, trying several common encodings."""
        try:
            text = None
            for encoding in ['utf-8', 'gbk', 'gb2312', 'latin-1']:
                try:
                    text = content.decode(encoding)
                    break
                except UnicodeDecodeError:
                    continue
            if text is None:
                # FIX: log messages previously contained the literal
                # "(unknown)" instead of interpolating the filename
                log.error(f"无法解码文件: {filename}")
                return []
            return [Document(
                page_content=text,
                metadata={"source": filename, "type": "txt"}
            )]
        except Exception as e:
            log.error(f"处理 TXT 文件失败: {filename}, 错误: {e}")
            return []

    @classmethod
    def _process_markdown(cls, content: bytes, filename: str) -> List[Document]:
        """Decode a Markdown file (UTF-8 only)."""
        try:
            text = content.decode('utf-8')
            return [Document(
                page_content=text,
                metadata={"source": filename, "type": "markdown"}
            )]
        except Exception as e:
            log.error(f"处理 Markdown 文件失败: {filename}, 错误: {e}")
            return []

    @classmethod
    async def _process_pdf(cls, content: bytes, filename: str) -> List[Document]:
        """Extract text page-by-page from a PDF using pypdf."""
        try:
            import pypdf
            from io import BytesIO
            pdf_file = BytesIO(content)
            reader = pypdf.PdfReader(pdf_file)
            documents = []
            for page_num, page in enumerate(reader.pages):
                text = page.extract_text()
                # skip blank pages
                if text and text.strip():
                    documents.append(Document(
                        page_content=text,
                        metadata={
                            "source": filename,
                            "type": "pdf",
                            "page": page_num + 1
                        }
                    ))
            log.info(f"PDF 文件处理完成: {filename}, 共 {len(documents)} 页")
            return documents
        except ImportError:
            log.error("未安装 pypdf 库,请运行: pip install pypdf")
            return []
        except Exception as e:
            log.error(f"处理 PDF 文件失败: {filename}, 错误: {e}")
            return []

    @classmethod
    async def _process_word(cls, content: bytes, filename: str, ext: str) -> List[Document]:
        """Dispatch a Word file to the docx or legacy doc handler."""
        try:
            if ext == '.docx':
                return cls._process_docx(content, filename)
            else:
                # legacy .doc needs external tooling
                return cls._process_doc(content, filename)
        except Exception as e:
            log.error(f"处理 Word 文件失败: {filename}, 错误: {e}")
            return []

    @classmethod
    def _process_docx(cls, content: bytes, filename: str) -> List[Document]:
        """Extract paragraph text from a DOCX file via python-docx."""
        try:
            from docx import Document as DocxDocument
            from io import BytesIO
            docx_file = BytesIO(content)
            doc = DocxDocument(docx_file)
            # collect non-empty paragraphs
            paragraphs = []
            for para in doc.paragraphs:
                if para.text.strip():
                    paragraphs.append(para.text)
            text = '\n'.join(paragraphs)
            if text.strip():
                return [Document(
                    page_content=text,
                    metadata={"source": filename, "type": "docx"}
                )]
            return []
        except ImportError:
            log.error("未安装 python-docx 库,请运行: pip install python-docx")
            return []
        except Exception as e:
            log.error(f"处理 DOCX 文件失败: {filename}, 错误: {e}")
            return []

    @classmethod
    def _process_doc(cls, content: bytes, filename: str) -> List[Document]:
        """
        Extract text from a legacy .doc file.

        Relies on the external ``antiword`` binary; when it is missing or
        times out, an empty list is returned and the user is advised to
        convert the file to .docx.
        """
        try:
            import subprocess
            import tempfile
            # antiword only reads from disk, so spill to a temp file
            with tempfile.NamedTemporaryFile(suffix='.doc', delete=False) as tmp:
                tmp.write(content)
                tmp_path = tmp.name
            try:
                result = subprocess.run(
                    ['antiword', tmp_path],
                    capture_output=True,
                    text=True,
                    timeout=30
                )
                if result.returncode == 0 and result.stdout.strip():
                    return [Document(
                        page_content=result.stdout,
                        metadata={"source": filename, "type": "doc"}
                    )]
            except (subprocess.TimeoutExpired, FileNotFoundError):
                log.warning(f"无法处理 .doc 文件: {filename},建议转换为 .docx 格式")
            finally:
                os.unlink(tmp_path)
            return []
        except Exception as e:
            log.error(f"处理 DOC 文件失败: {filename}, 错误: {e}")
            return []

    @classmethod
    async def process_files(cls, files: List[UploadFile]) -> List[Document]:
        """
        Parse a batch of uploads.

        参数:
        - files: list of UploadFile objects

        返回:
        - concatenated Documents from all supported files
        """
        all_documents = []
        for file in files:
            if file.filename and cls.is_supported(file.filename):
                docs = await cls.process_upload_file(file)
                all_documents.extend(docs)
                log.info(f"文件处理完成: {file.filename}, 提取 {len(docs)} 个文档")
            else:
                log.warning(f"跳过不支持的文件: {file.filename}")
        log.info(f"批量处理完成,共 {len(all_documents)} 个文档")
        return all_documents

View File

@@ -0,0 +1,239 @@
# -*- coding: utf-8 -*-
"""
向量化工具类
支持本地和远程 embedding 模型
"""
import asyncio
from typing import List, Optional
from concurrent.futures import ThreadPoolExecutor
import chromadb
from langchain_openai import OpenAIEmbeddings
from langchain_chroma import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from app.core.logger import log
from app.config.path_conf import BASE_DIR
# Directory where ChromaDB persists its data
CHROMA_PERSIST_DIR = BASE_DIR / "data" / "chroma_db"

# Lazily-created, process-wide ChromaDB client (singleton)
_chroma_client = None


def get_chroma_client():
    """Return the shared ChromaDB ``PersistentClient``, creating it on first use."""
    global _chroma_client
    if _chroma_client is None:
        # make sure the persistence directory exists before first open
        CHROMA_PERSIST_DIR.mkdir(parents=True, exist_ok=True)
        _chroma_client = chromadb.PersistentClient(path=str(CHROMA_PERSIST_DIR))
    return _chroma_client
class EmbeddingUtil:
    """
    Embedding helper supporting local (sentence-transformers) and remote
    (OpenAI-compatible) backends, storing vectors in a persistent ChromaDB.
    """

    def __init__(
        self,
        embedding_type: int = 0,
        model_name: str = "text-embedding-ada-002",
        base_url: Optional[str] = None,
        api_key: Optional[str] = None
    ):
        """
        参数:
        - embedding_type: 0 = local, 1 = remote
        - model_name: embedding model name
        - base_url: remote endpoint (required in remote mode)
        - api_key: remote API key (required in remote mode)
        """
        self.embedding_type = embedding_type
        self.model_name = model_name
        self.base_url = base_url
        self.api_key = api_key
        # created lazily by the `embeddings` property
        self._embeddings = None

    @property
    def embeddings(self):
        """Lazily build and cache the embedding model instance."""
        if self._embeddings is None:
            self._embeddings = self._create_embeddings()
        return self._embeddings

    def _create_embeddings(self):
        """Instantiate the embedding backend according to ``embedding_type``."""
        if self.embedding_type == 0:
            # Local mode — sentence-transformers via HuggingFaceEmbeddings.
            # sentence-transformers==3.3.1; optional CPU-only torch install:
            # pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
            try:
                from langchain_community.embeddings import HuggingFaceEmbeddings
                log.info(f"使用本地 Embedding 模型: {self.model_name}")
                return HuggingFaceEmbeddings(
                    model_name=self.model_name,
                    model_kwargs={'device': 'cpu'},
                    encode_kwargs={'normalize_embeddings': True}
                )
            except Exception as e:
                log.error(f"加载本地 Embedding 模型失败: {e}")
                raise
        else:
            # Remote mode — any OpenAI-compatible embeddings endpoint.
            if not self.base_url or not self.api_key:
                raise ValueError("远程模式必须提供 base_url 和 api_key")
            # append /v1 when the configured URL does not already end in it
            api_base = self.base_url.rstrip('/')
            if not api_base.endswith('/v1'):
                api_base = f"{api_base}/v1"
            log.info(f"使用远程 Embedding 模型: {self.model_name}, URL: {api_base}")
            return OpenAIEmbeddings(
                model=self.model_name,
                base_url=api_base,
                api_key=self.api_key,
            )

    def get_vector_store(self, collection_name: str) -> Chroma:
        """
        Return a Chroma vector store bound to *collection_name*.

        参数:
        - collection_name: ChromaDB collection name
        """
        CHROMA_PERSIST_DIR.mkdir(parents=True, exist_ok=True)
        return Chroma(
            collection_name=collection_name,
            embedding_function=self.embeddings,
            persist_directory=str(CHROMA_PERSIST_DIR),
            client=get_chroma_client(),
        )

    def split_documents(
        self,
        documents: List[Document],
        chunk_size: int = 1000,
        chunk_overlap: int = 200
    ) -> List[Document]:
        """
        Split documents into overlapping chunks.

        参数:
        - documents: documents to split
        - chunk_size: target chunk size in characters
        - chunk_overlap: overlap between consecutive chunks
        """
        # FIX: the separator list previously contained empty strings in the
        # middle ("", "", "" before "."), which makes
        # RecursiveCharacterTextSplitter fall back to character-level
        # splitting far too early. Use Chinese sentence punctuation there,
        # mirroring the ASCII set; "" stays last as the final fallback.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
            separators=["\n\n", "\n", "。", "!", "?", ".", "!", "?", " ", ""]
        )
        return text_splitter.split_documents(documents)

    def add_documents(
        self,
        collection_name: str,
        documents: List[Document]
    ) -> int:
        """
        Split *documents* and store them in the collection.

        返回:
        - number of vectors added
        """
        if not documents:
            return 0
        split_docs = self.split_documents(documents)
        log.info(f"文档分割完成,共 {len(split_docs)} 个片段")
        vector_store = self.get_vector_store(collection_name)
        vector_store.add_documents(split_docs)
        log.info(f"向量存储完成,集合: {collection_name}, 向量数: {len(split_docs)}")
        return len(split_docs)

    def delete_collection(self, collection_name: str) -> bool:
        """Delete a ChromaDB collection; return True on success."""
        try:
            client = get_chroma_client()
            client.delete_collection(collection_name)
            log.info(f"删除集合成功: {collection_name}")
            return True
        except Exception as e:
            log.error(f"删除集合失败: {collection_name}, 错误: {e}")
            return False

    def similarity_search(
        self,
        collection_name: str,
        query: str,
        k: int = 4
    ) -> List[Document]:
        """
        Return the *k* documents most similar to *query*.

        参数:
        - collection_name: ChromaDB collection name
        - query: query text
        - k: number of results
        """
        vector_store = self.get_vector_store(collection_name)
        return vector_store.similarity_search(query, k=k)

    def get_collection_count(self, collection_name: str) -> int:
        """Return the number of vectors in the collection (0 on failure)."""
        try:
            client = get_chroma_client()
            collection = client.get_collection(collection_name)
            return collection.count()
        except Exception as e:
            log.warning(f"获取集合数量失败: {collection_name}, 错误: {e}")
            return 0