# [Captured from a web file viewer; the lines below are viewer metadata, not program code.]
# File: ----/后端源码/yifan.action-ai.cn/app/api/v1/module_application/ai/controller.py
# 538 lines · 22 KiB · Python (viewer actions: Raw | Blame | History)
# Viewer warning: this file contains ambiguous Unicode characters that might be confused
# with other characters. If this is intentional, the warning can be safely ignored
# (the viewer's Escape button reveals them).
# -*- coding: utf-8 -*-
from fastapi import APIRouter, Depends, Path, Body, WebSocket, Form, UploadFile, File
from fastapi.responses import JSONResponse, StreamingResponse
from app.common.response import StreamResponse, SuccessResponse
from app.common.request import PaginationService
from app.core.base_params import PaginationQueryParam
from app.core.dependencies import AuthPermission
from app.core.logger import log
from app.api.v1.module_system.auth.schema import AuthSchema
from app.core.router_class import OperationLogRoute
from .service import McpService, AIProviderService, EmbeddingConfigService, KnowledgeBaseService, AIModelConfigService, AIModelTrainingService, AIModelTestService
from .schema import (
McpCreateSchema, McpUpdateSchema, ChatQuerySchema, McpQueryParam,
AIProviderCreateSchema, AIProviderUpdateSchema, AIProviderQueryParam,
EmbeddingConfigCreateSchema, EmbeddingConfigUpdateSchema, EmbeddingConfigQueryParam,
KnowledgeBaseCreateSchema, KnowledgeBaseUpdateSchema, KnowledgeBaseQueryParam,
AIModelConfigUpdateSchema, AIModelConfigQueryParam, AIModelTrainingChatSchema,
AIModelTestSchema
)
# Routers for the AI module. All of them use OperationLogRoute as their route class; each
# prefix is relative to wherever the router is mounted by the application.
AIRouter = APIRouter(prefix="/ai", tags=["MCP智能助手"], route_class=OperationLogRoute)
AIProviderRouter = APIRouter(prefix="/ai/provider", tags=["AI供应商配置"], route_class=OperationLogRoute)
EmbeddingConfigRouter = APIRouter(prefix="/ai/embedding", tags=["向量化配置"], route_class=OperationLogRoute)
KnowledgeBaseRouter = APIRouter(prefix="/ai/knowledge", tags=["知识库管理"], route_class=OperationLogRoute)
AIModelConfigRouter = APIRouter(prefix="/ai/model", tags=["AI模型配置"], route_class=OperationLogRoute)
@AIRouter.post("/chat", summary="智能对话", description="与MCP智能助手进行对话")
async def chat_controller(
    query: ChatQuerySchema,
    auth: AuthSchema = Depends(AuthPermission())
) -> StreamingResponse:
    """
    Chat with the MCP assistant and stream the answer back.

    Args:
        query (ChatQuerySchema): the chat request; ``query.message`` is the user's text.
        auth (AuthSchema): authenticated caller injected by ``AuthPermission``; only
            used here to log the user's name.

    Returns:
        A ``StreamResponse`` with media type ``text/plain; charset=utf-8``. Its body is
        every non-empty chunk yielded by ``McpService.chat_query``; ``str`` chunks are
        UTF-8 encoded and other chunks are passed through unchanged. If the service
        raises mid-stream, the error is logged and one final error-message chunk is
        emitted instead of aborting the response.
    """
    caller = auth.user
    user_name = caller.name if caller else "未知用户"
    log.info(f"用户 {user_name} 发起智能对话: {query.message[:50]}...")

    async def stream_utf8_bytes():
        try:
            async for piece in McpService.chat_query(query=query):
                if not piece:
                    continue
                # The HTTP body is written as bytes; chunks that are not str pass through.
                yield piece if not isinstance(piece, str) else piece.encode('utf-8')
        except Exception as err:
            log.error(f"流式响应出错: {str(err)}")
            yield f"抱歉,处理您的请求时出现了错误: {str(err)}".encode('utf-8')

    return StreamResponse(stream_utf8_bytes(), media_type="text/plain; charset=utf-8")
@AIRouter.get("/detail/{id}", summary="获取 MCP 服务器详情", description="获取 MCP 服务器详情")
async def detail_controller(
    id: int = Path(..., description="MCP ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """
    Return the details of a single MCP server.

    Args:
        id (int): MCP server ID taken from the URL path.
        auth (AuthSchema): authenticated caller, forwarded to the service.

    Returns:
        JSONResponse: ``SuccessResponse`` wrapping the value returned by
        ``McpService.detail_service``.
    """
    detail = await McpService.detail_service(auth=auth, id=id)
    log.info(f"获取 MCP 服务器详情成功 {id}")
    return SuccessResponse(msg="获取 MCP 服务器详情成功", data=detail)
@AIRouter.get("/list", summary="查询 MCP 服务器列表", description="查询 MCP 服务器列表")
async def list_controller(
    page: PaginationQueryParam = Depends(),
    search: McpQueryParam = Depends(),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """
    List MCP servers, filtered and paginated.

    Args:
        page (PaginationQueryParam): page number, page size and ordering.
        search (McpQueryParam): filter criteria, forwarded to the service.
        auth (AuthSchema): authenticated caller, forwarded to the service.

    Returns:
        JSONResponse: ``SuccessResponse`` wrapping the page produced by
        ``PaginationService.paginate`` over the service's result list.
    """
    rows = await McpService.list_service(auth=auth, search=search, order_by=page.order_by)
    # Pagination is applied to the list returned above, not in the query.
    page_result = await PaginationService.paginate(
        data_list=rows,
        page_no=page.page_no,
        page_size=page.page_size,
    )
    log.info("查询 MCP 服务器列表成功")
    return SuccessResponse(msg="查询 MCP 服务器列表成功", data=page_result)
@AIRouter.post("/create", summary="创建 MCP 服务器", description="创建 MCP 服务器")
async def create_controller(
    data: McpCreateSchema,
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """
    Create an MCP server.

    Args:
        data (McpCreateSchema): request body describing the new server.
        auth (AuthSchema): authenticated caller, forwarded to the service.

    Returns:
        JSONResponse: ``SuccessResponse`` wrapping the value returned by
        ``McpService.create_service``.
    """
    created = await McpService.create_service(auth=auth, data=data)
    log.info(f"创建 MCP 服务器成功: {created}")
    return SuccessResponse(msg="创建 MCP 服务器成功", data=created)
@AIRouter.put("/update/{id}", summary="修改 MCP 服务器", description="修改 MCP 服务器")
async def update_controller(
    data: McpUpdateSchema,
    id: int = Path(..., description="MCP ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """
    Update an existing MCP server.

    Args:
        data (McpUpdateSchema): request body with the new values.
        id (int): ID of the MCP server to update, taken from the URL path.
        auth (AuthSchema): authenticated caller, forwarded to the service.

    Returns:
        JSONResponse: ``SuccessResponse`` wrapping the value returned by
        ``McpService.update_service``.
    """
    updated = await McpService.update_service(auth=auth, id=id, data=data)
    log.info(f"修改 MCP 服务器成功: {updated}")
    return SuccessResponse(msg="修改 MCP 服务器成功", data=updated)
@AIRouter.delete("/delete", summary="删除 MCP 服务器", description="删除 MCP 服务器")
async def delete_controller(
    ids: list[int] = Body(..., description="ID列表"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """
    Delete MCP servers in bulk.

    Args:
        ids (list[int]): IDs of the servers to delete, sent in the request body.
        auth (AuthSchema): authenticated caller, forwarded to the service.

    Returns:
        JSONResponse: ``SuccessResponse`` with a message only (no data payload).
    """
    await McpService.delete_service(auth=auth, ids=ids)
    log.info(f"删除 MCP 服务器成功: {ids}")
    return SuccessResponse(msg="删除 MCP 服务器成功")
@AIRouter.websocket("/ws/chat", name="WebSocket聊天")
async def websocket_chat_controller(
    websocket: WebSocket,
):
    """
    WebSocket chat with the MCP assistant.

    Route: router prefix ``/ai`` + ``/ws/chat``, i.e. ``.../ai/ws/chat`` under wherever
    ``AIRouter`` is mounted (for example ``ws://127.0.0.1:8001/api/v1/ai/ws/chat`` if it
    is mounted at ``/api/v1`` -- NOTE(review): the mount prefix is not visible in this
    file, confirm).

    Protocol: each received text frame is one chat message. The answer is sent back as
    one text frame per non-empty chunk from ``McpService.chat_query``. A failure while
    answering one message is logged and reported to the client, and the loop keeps
    serving later messages. A client disconnect ends the loop without being logged as
    an error. The socket is closed from this side only if the client has not already
    closed it.

    NOTE(review): unlike the HTTP ``/chat`` endpoint, this handler performs no
    authentication.
    """
    # Function-scope import so the module's import block stays unchanged.
    from fastapi.websockets import WebSocketDisconnect, WebSocketState

    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            try:
                async for chunk in McpService.chat_query(query=ChatQuerySchema(message=data)):
                    if chunk:
                        await websocket.send_text(chunk)
            except WebSocketDisconnect:
                # The client left mid-answer: don't try to send an error frame to it.
                raise
            except Exception as e:
                log.error(f"处理聊天查询出错: {str(e)}")
                await websocket.send_text(f"抱歉,处理您的请求时出现了错误: {str(e)}")
    except WebSocketDisconnect:
        # Normal client-initiated close; not an error.
        pass
    except Exception as e:
        log.error(f"WebSocket聊天出错: {str(e)}")
    finally:
        # Closing a socket the client has already closed raises in Starlette, and that
        # exception would escape from this finally block. So only close a socket that is
        # still open, and treat the close itself as best-effort cleanup.
        if websocket.client_state != WebSocketState.DISCONNECTED:
            try:
                await websocket.close()
            except (RuntimeError, OSError, WebSocketDisconnect):
                pass
@AIModelConfigRouter.post("/test", summary="起名测试")
async def naming_test_controller(
    test_data: AIModelTestSchema,
) -> JSONResponse:
    """
    起名测试接口(用于小程序/外部调用)
    - 读取模型配置和训练信息
    - 内部流式调用AI供应商
    - 组合响应后一次性返回
    - 仅读数据,不保存对话记录
    参数:
    - model_type: 模型类型(enterprise_naming/personal_naming等)
    - text: 用户输入的文本
    返回:
    - data: AI的完整响应文本
    """
    # Naming test endpoint: forwards (model_type, text) to AIModelTestService.test_naming
    # and returns its result as the response data, logging its length.
    # The Chinese docstring above is kept verbatim: this route has no `description=`, so
    # FastAPI publishes that docstring as the route's OpenAPI description.
    # NOTE(review): there is no AuthPermission dependency, so this endpoint is
    # callable without authentication -- confirm that this is intended.
    model_type = test_data.model_type
    text = test_data.text
    log.info(f"[起名测试] 收到请求: model_type={model_type}, text={text[:50]}...")
    answer = await AIModelTestService.test_naming(model_type=model_type, text=text)
    log.info(f"[起名测试] 响应完成,长度: {len(answer)}")
    return SuccessResponse(msg="测试成功", data=answer)
# ============== AI provider configuration ==============
@AIProviderRouter.get("/detail/{id}", summary="获取AI供应商详情")
async def provider_detail_controller(
    id: int = Path(..., description="AI供应商ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    # Return one AI provider by ID (from the URL path), as produced by
    # AIProviderService.detail_service.
    provider = await AIProviderService.detail_service(auth=auth, id=id)
    return SuccessResponse(msg="获取AI供应商详情成功", data=provider)
@AIProviderRouter.get("/list", summary="查询AI供应商列表")
async def provider_list_controller(
    page: PaginationQueryParam = Depends(),
    search: AIProviderQueryParam = Depends(),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    # List AI providers matching `search`, ordered by page.order_by, then paginate
    # the returned list with PaginationService.
    providers = await AIProviderService.list_service(auth=auth, search=search, order_by=page.order_by)
    page_result = await PaginationService.paginate(
        data_list=providers,
        page_no=page.page_no,
        page_size=page.page_size,
    )
    return SuccessResponse(msg="查询AI供应商列表成功", data=page_result)
@AIProviderRouter.post("/create", summary="创建AI供应商")
async def provider_create_controller(
    data: AIProviderCreateSchema,
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    # Create an AI provider from the request body and return the service's result.
    # NOTE(review): the full service result is logged below. If it includes provider
    # credentials (e.g. API keys -- the schema isn't visible here), consider logging
    # only an identifier.
    created = await AIProviderService.create_service(auth=auth, data=data)
    log.info(f"创建AI供应商成功: {created}")
    return SuccessResponse(msg="创建AI供应商成功", data=created)
@AIProviderRouter.put("/update/{id}", summary="修改AI供应商")
async def provider_update_controller(
    data: AIProviderUpdateSchema,
    id: int = Path(..., description="AI供应商ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    # Update the AI provider identified by the path ID with the request body, and
    # return the service's result.
    updated = await AIProviderService.update_service(auth=auth, id=id, data=data)
    log.info(f"修改AI供应商成功: {updated}")
    return SuccessResponse(msg="修改AI供应商成功", data=updated)
@AIProviderRouter.delete("/delete", summary="删除AI供应商")
async def provider_delete_controller(
    ids: list[int] = Body(..., description="ID列表"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    # Bulk-delete AI providers by the IDs in the request body; responds with a message only.
    await AIProviderService.delete_service(auth=auth, ids=ids)
    log.info(f"删除AI供应商成功: {ids}")
    return SuccessResponse(msg="删除AI供应商成功")
# ============== Embedding (vectorization) configuration ==============
@EmbeddingConfigRouter.get("/detail/{id}", summary="获取向量化配置详情")
async def embedding_detail_controller(
    id: int = Path(..., description="向量化配置ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    # Return one embedding configuration by ID (from the URL path).
    config = await EmbeddingConfigService.detail_service(auth=auth, id=id)
    return SuccessResponse(msg="获取向量化配置详情成功", data=config)
@EmbeddingConfigRouter.get("/list", summary="查询向量化配置列表")
async def embedding_list_controller(
    page: PaginationQueryParam = Depends(),
    search: EmbeddingConfigQueryParam = Depends(),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    # List embedding configurations matching `search`, then paginate the returned list.
    configs = await EmbeddingConfigService.list_service(auth=auth, search=search, order_by=page.order_by)
    page_result = await PaginationService.paginate(
        data_list=configs,
        page_no=page.page_no,
        page_size=page.page_size,
    )
    return SuccessResponse(msg="查询向量化配置列表成功", data=page_result)
@EmbeddingConfigRouter.post("/create", summary="创建向量化配置")
async def embedding_create_controller(
    data: EmbeddingConfigCreateSchema,
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    # Create an embedding configuration from the request body and return the service's result.
    created = await EmbeddingConfigService.create_service(auth=auth, data=data)
    log.info(f"创建向量化配置成功: {created}")
    return SuccessResponse(msg="创建向量化配置成功", data=created)
@EmbeddingConfigRouter.put("/update/{id}", summary="修改向量化配置")
async def embedding_update_controller(
    data: EmbeddingConfigUpdateSchema,
    id: int = Path(..., description="向量化配置ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    # Update the embedding configuration identified by the path ID with the request body.
    updated = await EmbeddingConfigService.update_service(auth=auth, id=id, data=data)
    log.info(f"修改向量化配置成功: {updated}")
    return SuccessResponse(msg="修改向量化配置成功", data=updated)
@EmbeddingConfigRouter.delete("/delete", summary="删除向量化配置")
async def embedding_delete_controller(
    ids: list[int] = Body(..., description="ID列表"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    # Bulk-delete embedding configurations by the IDs in the request body.
    await EmbeddingConfigService.delete_service(auth=auth, ids=ids)
    log.info(f"删除向量化配置成功: {ids}")
    return SuccessResponse(msg="删除向量化配置成功")
# ============== Knowledge base management ==============
@KnowledgeBaseRouter.get("/detail/{id}", summary="获取知识库详情")
async def knowledge_detail_controller(
    id: int = Path(..., description="知识库ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    # Return one knowledge base by ID (from the URL path).
    knowledge_base = await KnowledgeBaseService.detail_service(auth=auth, id=id)
    return SuccessResponse(msg="获取知识库详情成功", data=knowledge_base)
@KnowledgeBaseRouter.get("/list", summary="查询知识库列表")
async def knowledge_list_controller(
    page: PaginationQueryParam = Depends(),
    search: KnowledgeBaseQueryParam = Depends(),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    # List knowledge bases matching `search`, then paginate the returned list.
    bases = await KnowledgeBaseService.list_service(auth=auth, search=search, order_by=page.order_by)
    page_result = await PaginationService.paginate(
        data_list=bases,
        page_no=page.page_no,
        page_size=page.page_size,
    )
    return SuccessResponse(msg="查询知识库列表成功", data=page_result)
@KnowledgeBaseRouter.post("/create", summary="创建知识库")
async def knowledge_create_controller(
    name: str = Form(..., max_length=100, description="知识库名称"),
    embedding_config_id: int | None = Form(None, description="向量化配置ID"),
    description: str | None = Form(None, max_length=255, description="备注"),
    files: list[UploadFile] = File(default=[], description="上传的文件列表"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    # Create a knowledge base from a multipart form. The scalar form fields are packed
    # into KnowledgeBaseCreateSchema, and the uploaded files are passed to the service
    # alongside it.
    schema_in = KnowledgeBaseCreateSchema(
        name=name,
        embedding_config_id=embedding_config_id,
        description=description,
    )
    created = await KnowledgeBaseService.create_service(auth=auth, data=schema_in, files=files)
    log.info(f"创建知识库成功: {created}")
    return SuccessResponse(msg="创建知识库成功", data=created)
@KnowledgeBaseRouter.put("/update/{id}", summary="修改知识库")
async def knowledge_update_controller(
    data: KnowledgeBaseUpdateSchema,
    id: int = Path(..., description="知识库ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    # Update the knowledge base identified by the path ID with the JSON request body.
    updated = await KnowledgeBaseService.update_service(auth=auth, id=id, data=data)
    log.info(f"修改知识库成功: {updated}")
    return SuccessResponse(msg="修改知识库成功", data=updated)
@KnowledgeBaseRouter.delete("/delete", summary="删除知识库")
async def knowledge_delete_controller(
    ids: list[int] = Body(..., description="ID列表"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    # Bulk-delete knowledge bases by the IDs in the request body.
    await KnowledgeBaseService.delete_service(auth=auth, ids=ids)
    log.info(f"删除知识库成功: {ids}")
    return SuccessResponse(msg="删除知识库成功")
@KnowledgeBaseRouter.post("/retry/{id}", summary="重试向量化")
async def knowledge_retry_controller(
    id: int = Path(..., description="知识库ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    # Ask KnowledgeBaseService to retry vectorization of the knowledge base with the
    # given ID, and return whatever the service reports.
    retry_result = await KnowledgeBaseService.retry_service(auth=auth, id=id)
    log.info(f"重试向量化成功: {id}")
    return SuccessResponse(msg="已启动重新向量化", data=retry_result)
# ============== AI model configuration ==============
@AIModelConfigRouter.get("/detail/{model_type}", summary="获取AI模型配置详情")
async def model_config_detail_controller(
    model_type: str = Path(..., description="模型类型(naming/renaming/scoring/report)"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    # Return the AI model configuration for the given model type (from the URL path).
    config = await AIModelConfigService.detail_service(auth=auth, model_type=model_type)
    return SuccessResponse(msg="获取AI模型配置详情成功", data=config)
@AIModelConfigRouter.get("/list", summary="查询AI模型配置列表")
async def model_config_list_controller(
    page: PaginationQueryParam = Depends(),
    search: AIModelConfigQueryParam = Depends(),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    # List AI model configurations matching `search`, then paginate the returned list.
    configs = await AIModelConfigService.list_service(auth=auth, search=search, order_by=page.order_by)
    page_result = await PaginationService.paginate(
        data_list=configs,
        page_no=page.page_no,
        page_size=page.page_size,
    )
    return SuccessResponse(msg="查询AI模型配置列表成功", data=page_result)
@AIModelConfigRouter.put("/update/{model_type}", summary="更新AI模型配置")
async def model_config_update_controller(
    data: AIModelConfigUpdateSchema,
    model_type: str = Path(..., description="模型类型(naming/renaming/scoring/report)"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    # Update the AI model configuration for the path's model type with the request body.
    updated = await AIModelConfigService.update_service(auth=auth, model_type=model_type, data=data)
    log.info(f"更新AI模型配置成功: {model_type}")
    return SuccessResponse(msg="更新AI模型配置成功", data=updated)
@AIModelConfigRouter.get("/available-models/{provider_id}", summary="获取可用模型列表")
async def available_models_controller(
    provider_id: int = Path(..., description="AI供应商ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    # Return the models that AIModelConfigService reports as available for the given
    # AI provider ID (from the URL path).
    models = await AIModelConfigService.get_available_models(auth=auth, provider_id=provider_id)
    return SuccessResponse(msg="获取可用模型列表成功", data=models)
# ============== AI model training conversations ==============
@AIModelConfigRouter.get("/messages/{model_type}", summary="获取训练对话记录")
async def training_messages_controller(
    model_type: str = Path(..., description="模型类型(naming/renaming/scoring/report)"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    # Return the stored training-conversation messages for the given model type.
    messages = await AIModelTrainingService.get_messages_service(auth=auth, model_type=model_type)
    return SuccessResponse(msg="获取训练对话记录成功", data=messages)
@AIModelConfigRouter.delete("/message/{message_id}", summary="删除单条训练对话记录")
async def delete_training_message_controller(
    message_id: int = Path(..., description="对话记录ID"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    # Delete a single training-conversation message by its ID (from the URL path).
    await AIModelTrainingService.delete_message_service(auth=auth, message_id=message_id)
    log.info(f"删除训练对话记录成功: {message_id}")
    return SuccessResponse(msg="删除训练对话记录成功")
@AIModelConfigRouter.delete("/messages/{model_type}", summary="清空训练对话记录")
async def clear_training_messages_controller(
    model_type: str = Path(..., description="模型类型(naming/renaming/scoring/report)"),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    # Clear all training-conversation messages for the given model type.
    await AIModelTrainingService.clear_messages_service(auth=auth, model_type=model_type)
    log.info(f"清空训练对话记录成功: {model_type}")
    return SuccessResponse(msg="清空训练对话记录成功")
@AIModelConfigRouter.post("/chat", summary="训练对话")
async def training_chat_controller(
    chat_data: AIModelTrainingChatSchema,
    auth: AuthSchema = Depends(AuthPermission())
) -> StreamingResponse:
    """
    训练对话接口,流式输出
    """
    # Training chat: streams the output of AIModelTrainingService.chat_stream as a
    # text/plain UTF-8 body. Non-empty str chunks are encoded, and other chunks pass
    # through unchanged. On a mid-stream error, the error is logged and a final
    # error-message chunk is emitted. The Chinese docstring above is kept verbatim
    # because FastAPI publishes it as this route's OpenAPI description.
    log.info(f"[训练对话] 收到请求: model_type={chat_data.model_type}, message={chat_data.message[:50]}...")

    async def stream_training_reply():
        log.info("[训练对话] 开始生成响应")
        emitted = 0
        try:
            async for piece in AIModelTrainingService.chat_stream(auth=auth, chat_data=chat_data):
                if not piece:
                    continue
                emitted += 1
                yield piece if not isinstance(piece, str) else piece.encode('utf-8')
            log.info(f"[训练对话] 响应生成完成,共 {emitted} 个 chunk")
        except Exception as err:
            log.error(f"训练对话出错: {str(err)}")
            yield f"处理您的请求时出现了错误: {str(err)}".encode('utf-8')

    return StreamResponse(stream_training_reply(), media_type="text/plain; charset=utf-8")
@AIModelConfigRouter.websocket("/ws/chat")
async def training_websocket_chat_controller(
    websocket: WebSocket,
):
    """
    Training-chat WebSocket endpoint.

    Each received frame must be a JSON object with the optional keys ``model_type``
    (default ``'naming'``), ``message`` (default ``''``), ``config_changed`` (default
    ``False``) and ``config_data`` (a mapping passed to ``AIModelConfigUpdateSchema``).
    The reply is streamed back as one text frame per non-empty chunk from
    ``AIModelTrainingService.chat_stream``.

    A malformed message (not a JSON object, or fields the schemas reject) or a service
    failure is logged and reported to the client, and the session stays open. A client
    disconnect ends the loop quietly, and the socket is closed from this side only if
    it is still open.

    NOTE(review): there is no real authentication here. An empty ``AuthSchema()`` is
    passed to the service (see the TODO below).
    """
    # Function-scope import so the module's import block stays unchanged.
    from fastapi.websockets import WebSocketDisconnect, WebSocketState

    await websocket.accept()
    try:
        while True:
            # Receiving stays outside the per-message try: a receive/decode failure
            # ends the session instead of looping on a broken socket.
            data = await websocket.receive_json()
            try:
                if not isinstance(data, dict):
                    raise ValueError("消息必须是 JSON 对象")
                config_data = data.get('config_data')
                chat_data = AIModelTrainingChatSchema(
                    model_type=data.get('model_type', 'naming'),
                    message=data.get('message', ''),
                    config_changed=data.get('config_changed', False),
                    config_data=AIModelConfigUpdateSchema(**config_data) if config_data else None
                )
                # TODO(security): WebSocket requests bypass AuthPermission; this empty
                # AuthSchema is a placeholder until token verification is implemented.
                auth = AuthSchema()
                async for chunk in AIModelTrainingService.chat_stream(auth=auth, chat_data=chat_data):
                    if chunk:
                        await websocket.send_text(chunk)
            except WebSocketDisconnect:
                # The client left mid-answer: don't try to send an error frame to it.
                raise
            except Exception as e:
                log.error(f"处理训练对话出错: {str(e)}")
                await websocket.send_text(f"处理您的请求时出现了错误: {str(e)}")
    except WebSocketDisconnect:
        # Normal client-initiated close; not an error.
        pass
    except Exception as e:
        log.error(f"WebSocket训练对话出错: {str(e)}")
    finally:
        # Closing a socket the client already closed raises in Starlette, and that
        # exception would escape this finally block. So close only when still open, as
        # best-effort cleanup.
        if websocket.client_state != WebSocketState.DISCONNECTED:
            try:
                await websocket.close()
            except (RuntimeError, OSError, WebSocketDisconnect):
                pass