upload project source code
This commit is contained in:
@@ -0,0 +1,537 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from fastapi import APIRouter, Depends, Path, Body, WebSocket, Form, UploadFile, File
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from app.common.response import StreamResponse, SuccessResponse
|
||||
from app.common.request import PaginationService
|
||||
from app.core.base_params import PaginationQueryParam
|
||||
from app.core.dependencies import AuthPermission
|
||||
from app.core.logger import log
|
||||
|
||||
from app.api.v1.module_system.auth.schema import AuthSchema
|
||||
from app.core.router_class import OperationLogRoute
|
||||
from .service import McpService, AIProviderService, EmbeddingConfigService, KnowledgeBaseService, AIModelConfigService, AIModelTrainingService, AIModelTestService
|
||||
from .schema import (
|
||||
McpCreateSchema, McpUpdateSchema, ChatQuerySchema, McpQueryParam,
|
||||
AIProviderCreateSchema, AIProviderUpdateSchema, AIProviderQueryParam,
|
||||
EmbeddingConfigCreateSchema, EmbeddingConfigUpdateSchema, EmbeddingConfigQueryParam,
|
||||
KnowledgeBaseCreateSchema, KnowledgeBaseUpdateSchema, KnowledgeBaseQueryParam,
|
||||
AIModelConfigUpdateSchema, AIModelConfigQueryParam, AIModelTrainingChatSchema,
|
||||
AIModelTestSchema
|
||||
)
|
||||
|
||||
|
||||
AIRouter = APIRouter(route_class=OperationLogRoute, prefix="/ai", tags=["MCP智能助手"])
|
||||
AIProviderRouter = APIRouter(route_class=OperationLogRoute, prefix="/ai/provider", tags=["AI供应商配置"])
|
||||
EmbeddingConfigRouter = APIRouter(route_class=OperationLogRoute, prefix="/ai/embedding", tags=["向量化配置"])
|
||||
KnowledgeBaseRouter = APIRouter(route_class=OperationLogRoute, prefix="/ai/knowledge", tags=["知识库管理"])
|
||||
AIModelConfigRouter = APIRouter(route_class=OperationLogRoute, prefix="/ai/model", tags=["AI模型配置"])
|
||||
|
||||
|
||||
@AIRouter.post("/chat", summary="智能对话", description="与MCP智能助手进行对话")
|
||||
async def chat_controller(
|
||||
query: ChatQuerySchema,
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
智能对话接口
|
||||
|
||||
参数:
|
||||
- query (ChatQuerySchema): 聊天查询模型
|
||||
|
||||
返回:
|
||||
- StreamingResponse: 流式响应,每次返回一个聊天响应
|
||||
"""
|
||||
user_name = auth.user.name if auth.user else "未知用户"
|
||||
log.info(f"用户 {user_name} 发起智能对话: {query.message[:50]}...")
|
||||
|
||||
async def generate_response():
|
||||
try:
|
||||
async for chunk in McpService.chat_query(query=query):
|
||||
# 确保返回的是字节串
|
||||
if chunk:
|
||||
yield chunk.encode('utf-8') if isinstance(chunk, str) else chunk
|
||||
except Exception as e:
|
||||
log.error(f"流式响应出错: {str(e)}")
|
||||
yield f"抱歉,处理您的请求时出现了错误: {str(e)}".encode('utf-8')
|
||||
|
||||
return StreamResponse(generate_response(), media_type="text/plain; charset=utf-8")
|
||||
|
||||
|
||||
@AIRouter.get("/detail/{id}", summary="获取 MCP 服务器详情", description="获取 MCP 服务器详情")
|
||||
async def detail_controller(
|
||||
id: int = Path(..., description="MCP ID"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
获取 MCP 服务器详情接口
|
||||
|
||||
参数:
|
||||
- id (int): MCP 服务器ID
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含 MCP 服务器详情的 JSON 响应
|
||||
"""
|
||||
result_dict = await McpService.detail_service(auth=auth, id=id)
|
||||
log.info(f"获取 MCP 服务器详情成功 {id}")
|
||||
return SuccessResponse(data=result_dict, msg="获取 MCP 服务器详情成功")
|
||||
|
||||
|
||||
@AIRouter.get("/list", summary="查询 MCP 服务器列表", description="查询 MCP 服务器列表")
|
||||
async def list_controller(
|
||||
page: PaginationQueryParam = Depends(),
|
||||
search: McpQueryParam = Depends(),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
查询 MCP 服务器列表接口
|
||||
|
||||
参数:
|
||||
- page (PaginationQueryParam): 分页查询参数模型
|
||||
- search (McpQueryParam): 查询参数模型
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含 MCP 服务器列表的 JSON 响应
|
||||
"""
|
||||
result_dict_list = await McpService.list_service(auth=auth, search=search, order_by=page.order_by)
|
||||
result_dict = await PaginationService.paginate(data_list=result_dict_list, page_no=page.page_no, page_size=page.page_size)
|
||||
log.info(f"查询 MCP 服务器列表成功")
|
||||
return SuccessResponse(data=result_dict, msg="查询 MCP 服务器列表成功")
|
||||
|
||||
|
||||
@AIRouter.post("/create", summary="创建 MCP 服务器", description="创建 MCP 服务器")
|
||||
async def create_controller(
|
||||
data: McpCreateSchema,
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
创建 MCP 服务器接口
|
||||
|
||||
参数:
|
||||
- data (McpCreateSchema): 创建 MCP 服务器模型
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含创建 MCP 服务器结果的 JSON 响应
|
||||
"""
|
||||
result_dict = await McpService.create_service(auth=auth, data=data)
|
||||
log.info(f"创建 MCP 服务器成功: {result_dict}")
|
||||
return SuccessResponse(data=result_dict, msg="创建 MCP 服务器成功")
|
||||
|
||||
|
||||
@AIRouter.put("/update/{id}", summary="修改 MCP 服务器", description="修改 MCP 服务器")
|
||||
async def update_controller(
|
||||
data: McpUpdateSchema,
|
||||
id: int = Path(..., description="MCP ID"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
修改 MCP 服务器接口
|
||||
|
||||
参数:
|
||||
- data (McpUpdateSchema): 修改 MCP 服务器模型
|
||||
- id (int): MCP 服务器ID
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含修改 MCP 服务器结果的 JSON 响应
|
||||
"""
|
||||
result_dict = await McpService.update_service(auth=auth, id=id, data=data)
|
||||
log.info(f"修改 MCP 服务器成功: {result_dict}")
|
||||
return SuccessResponse(data=result_dict, msg="修改 MCP 服务器成功")
|
||||
|
||||
|
||||
@AIRouter.delete("/delete", summary="删除 MCP 服务器", description="删除 MCP 服务器")
|
||||
async def delete_controller(
|
||||
ids: list[int] = Body(..., description="ID列表"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
删除 MCP 服务器接口
|
||||
|
||||
参数:
|
||||
- ids (list[int]): MCP 服务器ID列表
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含删除 MCP 服务器结果的 JSON 响应
|
||||
"""
|
||||
await McpService.delete_service(auth=auth, ids=ids)
|
||||
log.info(f"删除 MCP 服务器成功: {ids}")
|
||||
return SuccessResponse(msg="删除 MCP 服务器成功")
|
||||
|
||||
|
||||
@AIRouter.websocket("/ws/chat", name="WebSocket聊天")
|
||||
async def websocket_chat_controller(
|
||||
websocket: WebSocket,
|
||||
):
|
||||
"""
|
||||
WebSocket聊天接口
|
||||
|
||||
ws://127.0.0.1:8001/api/v1/ai/mcp/ws/chat
|
||||
"""
|
||||
await websocket.accept()
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
# 流式发送响应
|
||||
try:
|
||||
async for chunk in McpService.chat_query(query=ChatQuerySchema(message=data)):
|
||||
if chunk:
|
||||
await websocket.send_text(chunk)
|
||||
except Exception as e:
|
||||
log.error(f"处理聊天查询出错: {str(e)}")
|
||||
await websocket.send_text(f"抱歉,处理您的请求时出现了错误: {str(e)}")
|
||||
except Exception as e:
|
||||
log.error(f"WebSocket聊天出错: {str(e)}")
|
||||
finally:
|
||||
await websocket.close()
|
||||
|
||||
|
||||
@AIModelConfigRouter.post("/test", summary="起名测试")
|
||||
async def naming_test_controller(
|
||||
test_data: AIModelTestSchema,
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
起名测试接口(用于小程序/外部调用)
|
||||
|
||||
- 读取模型配置和训练信息
|
||||
- 内部流式调用AI供应商
|
||||
- 组合响应后一次性返回
|
||||
- 仅读数据,不保存对话记录
|
||||
|
||||
参数:
|
||||
- model_type: 模型类型(enterprise_naming/personal_naming等)
|
||||
- text: 用户输入的文本
|
||||
|
||||
返回:
|
||||
- data: AI的完整响应文本
|
||||
"""
|
||||
log.info(f"[起名测试] 收到请求: model_type={test_data.model_type}, text={test_data.text[:50]}...")
|
||||
|
||||
result = await AIModelTestService.test_naming(
|
||||
model_type=test_data.model_type,
|
||||
text=test_data.text
|
||||
)
|
||||
|
||||
log.info(f"[起名测试] 响应完成,长度: {len(result)}")
|
||||
return SuccessResponse(data=result, msg="测试成功")
|
||||
|
||||
|
||||
# ============== AI供应商配置 ==============
|
||||
|
||||
@AIProviderRouter.get("/detail/{id}", summary="获取AI供应商详情")
|
||||
async def provider_detail_controller(
|
||||
id: int = Path(..., description="AI供应商ID"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict = await AIProviderService.detail_service(auth=auth, id=id)
|
||||
return SuccessResponse(data=result_dict, msg="获取AI供应商详情成功")
|
||||
|
||||
|
||||
@AIProviderRouter.get("/list", summary="查询AI供应商列表")
|
||||
async def provider_list_controller(
|
||||
page: PaginationQueryParam = Depends(),
|
||||
search: AIProviderQueryParam = Depends(),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict_list = await AIProviderService.list_service(auth=auth, search=search, order_by=page.order_by)
|
||||
result_dict = await PaginationService.paginate(data_list=result_dict_list, page_no=page.page_no, page_size=page.page_size)
|
||||
return SuccessResponse(data=result_dict, msg="查询AI供应商列表成功")
|
||||
|
||||
|
||||
@AIProviderRouter.post("/create", summary="创建AI供应商")
|
||||
async def provider_create_controller(
|
||||
data: AIProviderCreateSchema,
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict = await AIProviderService.create_service(auth=auth, data=data)
|
||||
log.info(f"创建AI供应商成功: {result_dict}")
|
||||
return SuccessResponse(data=result_dict, msg="创建AI供应商成功")
|
||||
|
||||
|
||||
@AIProviderRouter.put("/update/{id}", summary="修改AI供应商")
|
||||
async def provider_update_controller(
|
||||
data: AIProviderUpdateSchema,
|
||||
id: int = Path(..., description="AI供应商ID"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict = await AIProviderService.update_service(auth=auth, id=id, data=data)
|
||||
log.info(f"修改AI供应商成功: {result_dict}")
|
||||
return SuccessResponse(data=result_dict, msg="修改AI供应商成功")
|
||||
|
||||
|
||||
@AIProviderRouter.delete("/delete", summary="删除AI供应商")
|
||||
async def provider_delete_controller(
|
||||
ids: list[int] = Body(..., description="ID列表"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
await AIProviderService.delete_service(auth=auth, ids=ids)
|
||||
log.info(f"删除AI供应商成功: {ids}")
|
||||
return SuccessResponse(msg="删除AI供应商成功")
|
||||
|
||||
|
||||
# ============== 向量化配置 ==============
|
||||
|
||||
@EmbeddingConfigRouter.get("/detail/{id}", summary="获取向量化配置详情")
|
||||
async def embedding_detail_controller(
|
||||
id: int = Path(..., description="向量化配置ID"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict = await EmbeddingConfigService.detail_service(auth=auth, id=id)
|
||||
return SuccessResponse(data=result_dict, msg="获取向量化配置详情成功")
|
||||
|
||||
|
||||
@EmbeddingConfigRouter.get("/list", summary="查询向量化配置列表")
|
||||
async def embedding_list_controller(
|
||||
page: PaginationQueryParam = Depends(),
|
||||
search: EmbeddingConfigQueryParam = Depends(),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict_list = await EmbeddingConfigService.list_service(auth=auth, search=search, order_by=page.order_by)
|
||||
result_dict = await PaginationService.paginate(data_list=result_dict_list, page_no=page.page_no, page_size=page.page_size)
|
||||
return SuccessResponse(data=result_dict, msg="查询向量化配置列表成功")
|
||||
|
||||
|
||||
@EmbeddingConfigRouter.post("/create", summary="创建向量化配置")
|
||||
async def embedding_create_controller(
|
||||
data: EmbeddingConfigCreateSchema,
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict = await EmbeddingConfigService.create_service(auth=auth, data=data)
|
||||
log.info(f"创建向量化配置成功: {result_dict}")
|
||||
return SuccessResponse(data=result_dict, msg="创建向量化配置成功")
|
||||
|
||||
|
||||
@EmbeddingConfigRouter.put("/update/{id}", summary="修改向量化配置")
|
||||
async def embedding_update_controller(
|
||||
data: EmbeddingConfigUpdateSchema,
|
||||
id: int = Path(..., description="向量化配置ID"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict = await EmbeddingConfigService.update_service(auth=auth, id=id, data=data)
|
||||
log.info(f"修改向量化配置成功: {result_dict}")
|
||||
return SuccessResponse(data=result_dict, msg="修改向量化配置成功")
|
||||
|
||||
|
||||
@EmbeddingConfigRouter.delete("/delete", summary="删除向量化配置")
|
||||
async def embedding_delete_controller(
|
||||
ids: list[int] = Body(..., description="ID列表"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
await EmbeddingConfigService.delete_service(auth=auth, ids=ids)
|
||||
log.info(f"删除向量化配置成功: {ids}")
|
||||
return SuccessResponse(msg="删除向量化配置成功")
|
||||
|
||||
|
||||
# ============== 知识库管理 ==============
|
||||
|
||||
@KnowledgeBaseRouter.get("/detail/{id}", summary="获取知识库详情")
|
||||
async def knowledge_detail_controller(
|
||||
id: int = Path(..., description="知识库ID"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict = await KnowledgeBaseService.detail_service(auth=auth, id=id)
|
||||
return SuccessResponse(data=result_dict, msg="获取知识库详情成功")
|
||||
|
||||
|
||||
@KnowledgeBaseRouter.get("/list", summary="查询知识库列表")
|
||||
async def knowledge_list_controller(
|
||||
page: PaginationQueryParam = Depends(),
|
||||
search: KnowledgeBaseQueryParam = Depends(),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict_list = await KnowledgeBaseService.list_service(auth=auth, search=search, order_by=page.order_by)
|
||||
result_dict = await PaginationService.paginate(data_list=result_dict_list, page_no=page.page_no, page_size=page.page_size)
|
||||
return SuccessResponse(data=result_dict, msg="查询知识库列表成功")
|
||||
|
||||
|
||||
@KnowledgeBaseRouter.post("/create", summary="创建知识库")
|
||||
async def knowledge_create_controller(
|
||||
name: str = Form(..., max_length=100, description="知识库名称"),
|
||||
embedding_config_id: int | None = Form(None, description="向量化配置ID"),
|
||||
description: str | None = Form(None, max_length=255, description="备注"),
|
||||
files: list[UploadFile] = File(default=[], description="上传的文件列表"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
# 构造数据对象
|
||||
data = KnowledgeBaseCreateSchema(name=name, embedding_config_id=embedding_config_id, description=description)
|
||||
result_dict = await KnowledgeBaseService.create_service(auth=auth, data=data, files=files)
|
||||
log.info(f"创建知识库成功: {result_dict}")
|
||||
return SuccessResponse(data=result_dict, msg="创建知识库成功")
|
||||
|
||||
|
||||
@KnowledgeBaseRouter.put("/update/{id}", summary="修改知识库")
|
||||
async def knowledge_update_controller(
|
||||
data: KnowledgeBaseUpdateSchema,
|
||||
id: int = Path(..., description="知识库ID"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict = await KnowledgeBaseService.update_service(auth=auth, id=id, data=data)
|
||||
log.info(f"修改知识库成功: {result_dict}")
|
||||
return SuccessResponse(data=result_dict, msg="修改知识库成功")
|
||||
|
||||
|
||||
@KnowledgeBaseRouter.delete("/delete", summary="删除知识库")
|
||||
async def knowledge_delete_controller(
|
||||
ids: list[int] = Body(..., description="ID列表"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
await KnowledgeBaseService.delete_service(auth=auth, ids=ids)
|
||||
log.info(f"删除知识库成功: {ids}")
|
||||
return SuccessResponse(msg="删除知识库成功")
|
||||
|
||||
|
||||
@KnowledgeBaseRouter.post("/retry/{id}", summary="重试向量化")
|
||||
async def knowledge_retry_controller(
|
||||
id: int = Path(..., description="知识库ID"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict = await KnowledgeBaseService.retry_service(auth=auth, id=id)
|
||||
log.info(f"重试向量化成功: {id}")
|
||||
return SuccessResponse(data=result_dict, msg="已启动重新向量化")
|
||||
|
||||
|
||||
# ============== AI模型配置 ==============
|
||||
|
||||
@AIModelConfigRouter.get("/detail/{model_type}", summary="获取AI模型配置详情")
|
||||
async def model_config_detail_controller(
|
||||
model_type: str = Path(..., description="模型类型(naming/renaming/scoring/report)"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict = await AIModelConfigService.detail_service(auth=auth, model_type=model_type)
|
||||
return SuccessResponse(data=result_dict, msg="获取AI模型配置详情成功")
|
||||
|
||||
|
||||
@AIModelConfigRouter.get("/list", summary="查询AI模型配置列表")
|
||||
async def model_config_list_controller(
|
||||
page: PaginationQueryParam = Depends(),
|
||||
search: AIModelConfigQueryParam = Depends(),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict_list = await AIModelConfigService.list_service(auth=auth, search=search, order_by=page.order_by)
|
||||
result_dict = await PaginationService.paginate(data_list=result_dict_list, page_no=page.page_no, page_size=page.page_size)
|
||||
return SuccessResponse(data=result_dict, msg="查询AI模型配置列表成功")
|
||||
|
||||
|
||||
@AIModelConfigRouter.put("/update/{model_type}", summary="更新AI模型配置")
|
||||
async def model_config_update_controller(
|
||||
data: AIModelConfigUpdateSchema,
|
||||
model_type: str = Path(..., description="模型类型(naming/renaming/scoring/report)"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict = await AIModelConfigService.update_service(auth=auth, model_type=model_type, data=data)
|
||||
log.info(f"更新AI模型配置成功: {model_type}")
|
||||
return SuccessResponse(data=result_dict, msg="更新AI模型配置成功")
|
||||
|
||||
|
||||
@AIModelConfigRouter.get("/available-models/{provider_id}", summary="获取可用模型列表")
|
||||
async def available_models_controller(
|
||||
provider_id: int = Path(..., description="AI供应商ID"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_list = await AIModelConfigService.get_available_models(auth=auth, provider_id=provider_id)
|
||||
return SuccessResponse(data=result_list, msg="获取可用模型列表成功")
|
||||
|
||||
|
||||
# ============== AI模型训练对话 ==============
|
||||
|
||||
@AIModelConfigRouter.get("/messages/{model_type}", summary="获取训练对话记录")
|
||||
async def training_messages_controller(
|
||||
model_type: str = Path(..., description="模型类型(naming/renaming/scoring/report)"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_list = await AIModelTrainingService.get_messages_service(auth=auth, model_type=model_type)
|
||||
return SuccessResponse(data=result_list, msg="获取训练对话记录成功")
|
||||
|
||||
|
||||
@AIModelConfigRouter.delete("/message/{message_id}", summary="删除单条训练对话记录")
|
||||
async def delete_training_message_controller(
|
||||
message_id: int = Path(..., description="对话记录ID"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
await AIModelTrainingService.delete_message_service(auth=auth, message_id=message_id)
|
||||
log.info(f"删除训练对话记录成功: {message_id}")
|
||||
return SuccessResponse(msg="删除训练对话记录成功")
|
||||
|
||||
|
||||
@AIModelConfigRouter.delete("/messages/{model_type}", summary="清空训练对话记录")
|
||||
async def clear_training_messages_controller(
|
||||
model_type: str = Path(..., description="模型类型(naming/renaming/scoring/report)"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
await AIModelTrainingService.clear_messages_service(auth=auth, model_type=model_type)
|
||||
log.info(f"清空训练对话记录成功: {model_type}")
|
||||
return SuccessResponse(msg="清空训练对话记录成功")
|
||||
|
||||
|
||||
@AIModelConfigRouter.post("/chat", summary="训练对话")
|
||||
async def training_chat_controller(
|
||||
chat_data: AIModelTrainingChatSchema,
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
训练对话接口,流式输出
|
||||
"""
|
||||
log.info(f"[训练对话] 收到请求: model_type={chat_data.model_type}, message={chat_data.message[:50]}...")
|
||||
|
||||
async def generate_response():
|
||||
log.info("[训练对话] 开始生成响应")
|
||||
chunk_count = 0
|
||||
try:
|
||||
async for chunk in AIModelTrainingService.chat_stream(auth=auth, chat_data=chat_data):
|
||||
if chunk:
|
||||
chunk_count += 1
|
||||
yield chunk.encode('utf-8') if isinstance(chunk, str) else chunk
|
||||
log.info(f"[训练对话] 响应生成完成,共 {chunk_count} 个 chunk")
|
||||
except Exception as e:
|
||||
log.error(f"训练对话出错: {str(e)}")
|
||||
yield f"处理您的请求时出现了错误: {str(e)}".encode('utf-8')
|
||||
|
||||
return StreamResponse(generate_response(), media_type="text/plain; charset=utf-8")
|
||||
|
||||
|
||||
@AIModelConfigRouter.websocket("/ws/chat")
|
||||
async def training_websocket_chat_controller(
|
||||
websocket: WebSocket,
|
||||
):
|
||||
"""
|
||||
训练对话 WebSocket 接口
|
||||
"""
|
||||
from app.core.security import verify_token
|
||||
from app.api.v1.module_system.auth.schema import AuthSchema
|
||||
|
||||
await websocket.accept()
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
model_type = data.get('model_type', 'naming')
|
||||
message = data.get('message', '')
|
||||
config_changed = data.get('config_changed', False)
|
||||
config_data = data.get('config_data')
|
||||
|
||||
# 构建chat_data
|
||||
chat_data = AIModelTrainingChatSchema(
|
||||
model_type=model_type,
|
||||
message=message,
|
||||
config_changed=config_changed,
|
||||
config_data=AIModelConfigUpdateSchema(**config_data) if config_data else None
|
||||
)
|
||||
|
||||
# 构建简化的auth(WebSocket没有正常的认证流程,实际使用时需要实现认证)
|
||||
auth = AuthSchema()
|
||||
|
||||
# 流式发送响应
|
||||
try:
|
||||
async for chunk in AIModelTrainingService.chat_stream(auth=auth, chat_data=chat_data):
|
||||
if chunk:
|
||||
await websocket.send_text(chunk)
|
||||
except Exception as e:
|
||||
log.error(f"处理训练对话出错: {str(e)}")
|
||||
await websocket.send_text(f"处理您的请求时出现了错误: {str(e)}")
|
||||
except Exception as e:
|
||||
log.error(f"WebSocket训练对话出错: {str(e)}")
|
||||
finally:
|
||||
await websocket.close()
|
||||
Reference in New Issue
Block a user