538 lines
22 KiB
Python
538 lines
22 KiB
Python
# -*- coding: utf-8 -*-
|
||
|
||
from fastapi import APIRouter, Depends, Path, Body, WebSocket, Form, UploadFile, File
|
||
from fastapi.responses import JSONResponse, StreamingResponse
|
||
|
||
from app.common.response import StreamResponse, SuccessResponse
|
||
from app.common.request import PaginationService
|
||
from app.core.base_params import PaginationQueryParam
|
||
from app.core.dependencies import AuthPermission
|
||
from app.core.logger import log
|
||
|
||
from app.api.v1.module_system.auth.schema import AuthSchema
|
||
from app.core.router_class import OperationLogRoute
|
||
from .service import McpService, AIProviderService, EmbeddingConfigService, KnowledgeBaseService, AIModelConfigService, AIModelTrainingService, AIModelTestService
|
||
from .schema import (
|
||
McpCreateSchema, McpUpdateSchema, ChatQuerySchema, McpQueryParam,
|
||
AIProviderCreateSchema, AIProviderUpdateSchema, AIProviderQueryParam,
|
||
EmbeddingConfigCreateSchema, EmbeddingConfigUpdateSchema, EmbeddingConfigQueryParam,
|
||
KnowledgeBaseCreateSchema, KnowledgeBaseUpdateSchema, KnowledgeBaseQueryParam,
|
||
AIModelConfigUpdateSchema, AIModelConfigQueryParam, AIModelTrainingChatSchema,
|
||
AIModelTestSchema
|
||
)
|
||
|
||
|
||
# Routers exposed by this module. All of them use OperationLogRoute as the
# route class; each owns its own URL prefix and OpenAPI tag.
AIRouter = APIRouter(prefix="/ai", tags=["MCP智能助手"], route_class=OperationLogRoute)
AIProviderRouter = APIRouter(prefix="/ai/provider", tags=["AI供应商配置"], route_class=OperationLogRoute)
EmbeddingConfigRouter = APIRouter(prefix="/ai/embedding", tags=["向量化配置"], route_class=OperationLogRoute)
KnowledgeBaseRouter = APIRouter(prefix="/ai/knowledge", tags=["知识库管理"], route_class=OperationLogRoute)
AIModelConfigRouter = APIRouter(prefix="/ai/model", tags=["AI模型配置"], route_class=OperationLogRoute)
|
||
|
||
|
||
@AIRouter.post("/chat", description="与MCP智能助手进行对话", summary="智能对话")
async def chat_controller(
    query: ChatQuerySchema,
    auth: AuthSchema = Depends(AuthPermission()),
) -> StreamingResponse:
    """
    Chat with the MCP assistant and stream the reply back.

    Args:
        query: Chat request; ``query.message`` is the user's text (its first
            50 characters are logged).
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        A ``text/plain; charset=utf-8`` streamed response built from the chunks
        yielded by ``McpService.chat_query``. If the service raises while
        streaming, the error text is written into the stream instead of
        failing the request.
    """
    caller = "未知用户"
    if auth.user:
        caller = auth.user.name
    log.info(f"用户 {caller} 发起智能对话: {query.message[:50]}...")

    async def _stream_chunks():
        try:
            async for piece in McpService.chat_query(query=query):
                if not piece:
                    continue
                # The response body is emitted as bytes; text chunks are UTF-8 encoded.
                if isinstance(piece, str):
                    piece = piece.encode('utf-8')
                yield piece
        except Exception as exc:
            log.error(f"流式响应出错: {str(exc)}")
            yield f"抱歉,处理您的请求时出现了错误: {str(exc)}".encode('utf-8')

    return StreamResponse(_stream_chunks(), media_type="text/plain; charset=utf-8")
|
||
|
||
|
||
@AIRouter.get("/detail/{id}", description="获取 MCP 服务器详情", summary="获取 MCP 服务器详情")
async def detail_controller(
    id: int = Path(..., description="MCP ID"),
    auth: AuthSchema = Depends(AuthPermission()),
) -> JSONResponse:
    """
    Fetch one MCP server record.

    Args:
        id: MCP server ID (path parameter).
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse wrapping the result of ``McpService.detail_service``.
    """
    mcp_detail = await McpService.detail_service(auth=auth, id=id)
    log.info(f"获取 MCP 服务器详情成功 {id}")
    return SuccessResponse(data=mcp_detail, msg="获取 MCP 服务器详情成功")
|
||
|
||
|
||
@AIRouter.get("/list", summary="查询 MCP 服务器列表", description="查询 MCP 服务器列表")
async def list_controller(
    page: PaginationQueryParam = Depends(),
    search: McpQueryParam = Depends(),
    auth: AuthSchema = Depends(AuthPermission())
) -> JSONResponse:
    """
    List MCP servers, filtered and paginated.

    Args:
        page: Pagination parameters (``page_no``, ``page_size``, ``order_by``).
        search: Filter parameters forwarded to ``McpService.list_service``.
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse wrapping the page produced by ``PaginationService.paginate``.
    """
    result_dict_list = await McpService.list_service(auth=auth, search=search, order_by=page.order_by)
    # The full list returned by the service is handed to the paginator.
    result_dict = await PaginationService.paginate(data_list=result_dict_list, page_no=page.page_no, page_size=page.page_size)
    # Plain string literal: the message has no placeholders (ruff F541).
    log.info("查询 MCP 服务器列表成功")
    return SuccessResponse(data=result_dict, msg="查询 MCP 服务器列表成功")
|
||
|
||
|
||
@AIRouter.post("/create", description="创建 MCP 服务器", summary="创建 MCP 服务器")
async def create_controller(
    data: McpCreateSchema,
    auth: AuthSchema = Depends(AuthPermission()),
) -> JSONResponse:
    """
    Create an MCP server record.

    Args:
        data: Creation payload.
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse wrapping the result of ``McpService.create_service``.
    """
    created = await McpService.create_service(auth=auth, data=data)
    log.info(f"创建 MCP 服务器成功: {created}")
    return SuccessResponse(data=created, msg="创建 MCP 服务器成功")
|
||
|
||
|
||
@AIRouter.put("/update/{id}", description="修改 MCP 服务器", summary="修改 MCP 服务器")
async def update_controller(
    data: McpUpdateSchema,
    id: int = Path(..., description="MCP ID"),
    auth: AuthSchema = Depends(AuthPermission()),
) -> JSONResponse:
    """
    Update an existing MCP server record.

    Args:
        data: Update payload.
        id: MCP server ID (path parameter).
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse wrapping the result of ``McpService.update_service``.
    """
    updated = await McpService.update_service(auth=auth, id=id, data=data)
    log.info(f"修改 MCP 服务器成功: {updated}")
    return SuccessResponse(data=updated, msg="修改 MCP 服务器成功")
|
||
|
||
|
||
@AIRouter.delete("/delete", description="删除 MCP 服务器", summary="删除 MCP 服务器")
async def delete_controller(
    ids: list[int] = Body(..., description="ID列表"),
    auth: AuthSchema = Depends(AuthPermission()),
) -> JSONResponse:
    """
    Delete MCP servers in bulk.

    Args:
        ids: IDs to delete, sent as the JSON request body.
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse with a confirmation message and no data.
    """
    done_msg = "删除 MCP 服务器成功"
    await McpService.delete_service(auth=auth, ids=ids)
    log.info(f"删除 MCP 服务器成功: {ids}")
    return SuccessResponse(msg=done_msg)
|
||
|
||
|
||
@AIRouter.websocket("/ws/chat", name="WebSocket聊天")
async def websocket_chat_controller(
    websocket: WebSocket,
):
    """
    WebSocket chat endpoint for the MCP assistant.

    Each text frame received is treated as one chat message
    (``ChatQuerySchema(message=<frame text>)``). Every non-empty chunk yielded
    by ``McpService.chat_query`` is sent back as its own text frame. If the
    service fails for one message, an error text frame is sent and the
    connection stays open. The session ends when the client disconnects or an
    unexpected error occurs.

    Route: ``/ws/chat`` on ``AIRouter`` (prefix ``/ai``). The previous docstring
    gave ``ws://127.0.0.1:8001/api/v1/ai/mcp/ws/chat``; the ``/mcp`` segment
    is not in this router's prefix. TODO: confirm the full URL against how the
    router is mounted.

    NOTE(review): this endpoint performs no authentication.
    """
    # Local import, matching the lazy-import style used elsewhere in this module.
    from fastapi.websockets import WebSocketDisconnect, WebSocketState

    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            # Stream the reply back chunk by chunk
            try:
                async for chunk in McpService.chat_query(query=ChatQuerySchema(message=data)):
                    if chunk:
                        await websocket.send_text(chunk)
            except WebSocketDisconnect:
                # The client left mid-stream. Do not try to send an error on a dead
                # socket; let the outer handler end the session.
                raise
            except Exception as e:
                log.error(f"处理聊天查询出错: {str(e)}")
                await websocket.send_text(f"抱歉,处理您的请求时出现了错误: {str(e)}")
    except WebSocketDisconnect:
        # A normal client disconnect is not an error.
        log.info("WebSocket聊天连接已断开")
    except Exception as e:
        log.error(f"WebSocket聊天出错: {str(e)}")
    finally:
        # Only send a close frame while both sides are still connected. Closing an
        # already-disconnected socket can raise RuntimeError (depends on the ASGI server).
        if (
            websocket.client_state != WebSocketState.DISCONNECTED
            and websocket.application_state != WebSocketState.DISCONNECTED
        ):
            await websocket.close()
|
||
|
||
|
||
@AIModelConfigRouter.post("/test", summary="起名测试")
async def naming_test_controller(
    test_data: AIModelTestSchema,
) -> JSONResponse:
    """
    Naming test endpoint for mini-program / external callers.

    Per this endpoint's original description, ``AIModelTestService.test_naming``
    reads the model configuration and training data and calls the AI provider
    with internal streaming. The combined reply is returned in one response,
    and no conversation record is saved.

    NOTE(review): unlike the other endpoints in this module, this one declares
    no ``AuthPermission`` dependency, so it is callable without authentication.
    Confirm this is intended.

    Args:
        test_data: Request body. ``model_type`` selects the model (e.g.
            enterprise_naming / personal_naming); ``text`` is the user's input.

    Returns:
        SuccessResponse whose ``data`` is the full AI reply text.
    """
    log.info(f"[起名测试] 收到请求: model_type={test_data.model_type}, text={test_data.text[:50]}...")

    reply = await AIModelTestService.test_naming(model_type=test_data.model_type, text=test_data.text)

    log.info(f"[起名测试] 响应完成,长度: {len(reply)}")
    return SuccessResponse(data=reply, msg="测试成功")
|
||
|
||
|
||
# ============== AI供应商配置 ==============
|
||
|
||
@AIProviderRouter.get("/detail/{id}", summary="获取AI供应商详情")
async def provider_detail_controller(
    id: int = Path(..., description="AI供应商ID"),
    auth: AuthSchema = Depends(AuthPermission()),
) -> JSONResponse:
    """
    Fetch one AI provider configuration.

    Args:
        id: AI provider ID (path parameter).
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse wrapping the result of ``AIProviderService.detail_service``.
    """
    provider = await AIProviderService.detail_service(auth=auth, id=id)
    return SuccessResponse(data=provider, msg="获取AI供应商详情成功")
|
||
|
||
|
||
@AIProviderRouter.get("/list", summary="查询AI供应商列表")
async def provider_list_controller(
    page: PaginationQueryParam = Depends(),
    search: AIProviderQueryParam = Depends(),
    auth: AuthSchema = Depends(AuthPermission()),
) -> JSONResponse:
    """
    List AI providers, filtered and paginated.

    Args:
        page: Pagination parameters (``page_no``, ``page_size``, ``order_by``).
        search: Filter parameters forwarded to ``AIProviderService.list_service``.
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse wrapping the page produced by ``PaginationService.paginate``.
    """
    providers = await AIProviderService.list_service(auth=auth, search=search, order_by=page.order_by)
    paged = await PaginationService.paginate(
        data_list=providers,
        page_no=page.page_no,
        page_size=page.page_size,
    )
    return SuccessResponse(data=paged, msg="查询AI供应商列表成功")
|
||
|
||
|
||
@AIProviderRouter.post("/create", summary="创建AI供应商")
async def provider_create_controller(
    data: AIProviderCreateSchema,
    auth: AuthSchema = Depends(AuthPermission()),
) -> JSONResponse:
    """
    Create an AI provider configuration.

    Args:
        data: Creation payload.
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse wrapping the result of ``AIProviderService.create_service``.
    """
    created = await AIProviderService.create_service(auth=auth, data=data)
    # NOTE(review): the whole result is logged. If provider records carry
    # credentials (e.g. API keys), consider redacting them here. TODO confirm schema.
    log.info(f"创建AI供应商成功: {created}")
    return SuccessResponse(data=created, msg="创建AI供应商成功")
|
||
|
||
|
||
@AIProviderRouter.put("/update/{id}", summary="修改AI供应商")
async def provider_update_controller(
    data: AIProviderUpdateSchema,
    id: int = Path(..., description="AI供应商ID"),
    auth: AuthSchema = Depends(AuthPermission()),
) -> JSONResponse:
    """
    Update an AI provider configuration.

    Args:
        data: Update payload.
        id: AI provider ID (path parameter).
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse wrapping the result of ``AIProviderService.update_service``.
    """
    updated = await AIProviderService.update_service(auth=auth, id=id, data=data)
    # NOTE(review): the whole result is logged. If provider records carry
    # credentials (e.g. API keys), consider redacting them here. TODO confirm schema.
    log.info(f"修改AI供应商成功: {updated}")
    return SuccessResponse(data=updated, msg="修改AI供应商成功")
|
||
|
||
|
||
@AIProviderRouter.delete("/delete", summary="删除AI供应商")
async def provider_delete_controller(
    ids: list[int] = Body(..., description="ID列表"),
    auth: AuthSchema = Depends(AuthPermission()),
) -> JSONResponse:
    """
    Delete AI providers in bulk.

    Args:
        ids: IDs to delete, sent as the JSON request body.
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse with a confirmation message and no data.
    """
    done_msg = "删除AI供应商成功"
    await AIProviderService.delete_service(auth=auth, ids=ids)
    log.info(f"删除AI供应商成功: {ids}")
    return SuccessResponse(msg=done_msg)
|
||
|
||
|
||
# ============== 向量化配置 ==============
|
||
|
||
@EmbeddingConfigRouter.get("/detail/{id}", summary="获取向量化配置详情")
async def embedding_detail_controller(
    id: int = Path(..., description="向量化配置ID"),
    auth: AuthSchema = Depends(AuthPermission()),
) -> JSONResponse:
    """
    Fetch one embedding configuration.

    Args:
        id: Embedding configuration ID (path parameter).
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse wrapping the result of ``EmbeddingConfigService.detail_service``.
    """
    embedding_config = await EmbeddingConfigService.detail_service(auth=auth, id=id)
    return SuccessResponse(data=embedding_config, msg="获取向量化配置详情成功")
|
||
|
||
|
||
@EmbeddingConfigRouter.get("/list", summary="查询向量化配置列表")
async def embedding_list_controller(
    page: PaginationQueryParam = Depends(),
    search: EmbeddingConfigQueryParam = Depends(),
    auth: AuthSchema = Depends(AuthPermission()),
) -> JSONResponse:
    """
    List embedding configurations, filtered and paginated.

    Args:
        page: Pagination parameters (``page_no``, ``page_size``, ``order_by``).
        search: Filter parameters forwarded to ``EmbeddingConfigService.list_service``.
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse wrapping the page produced by ``PaginationService.paginate``.
    """
    configs = await EmbeddingConfigService.list_service(auth=auth, search=search, order_by=page.order_by)
    paged = await PaginationService.paginate(
        data_list=configs,
        page_no=page.page_no,
        page_size=page.page_size,
    )
    return SuccessResponse(data=paged, msg="查询向量化配置列表成功")
|
||
|
||
|
||
@EmbeddingConfigRouter.post("/create", summary="创建向量化配置")
async def embedding_create_controller(
    data: EmbeddingConfigCreateSchema,
    auth: AuthSchema = Depends(AuthPermission()),
) -> JSONResponse:
    """
    Create an embedding configuration.

    Args:
        data: Creation payload.
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse wrapping the result of ``EmbeddingConfigService.create_service``.
    """
    created = await EmbeddingConfigService.create_service(auth=auth, data=data)
    log.info(f"创建向量化配置成功: {created}")
    return SuccessResponse(data=created, msg="创建向量化配置成功")
|
||
|
||
|
||
@EmbeddingConfigRouter.put("/update/{id}", summary="修改向量化配置")
async def embedding_update_controller(
    data: EmbeddingConfigUpdateSchema,
    id: int = Path(..., description="向量化配置ID"),
    auth: AuthSchema = Depends(AuthPermission()),
) -> JSONResponse:
    """
    Update an embedding configuration.

    Args:
        data: Update payload.
        id: Embedding configuration ID (path parameter).
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse wrapping the result of ``EmbeddingConfigService.update_service``.
    """
    updated = await EmbeddingConfigService.update_service(auth=auth, id=id, data=data)
    log.info(f"修改向量化配置成功: {updated}")
    return SuccessResponse(data=updated, msg="修改向量化配置成功")
|
||
|
||
|
||
@EmbeddingConfigRouter.delete("/delete", summary="删除向量化配置")
async def embedding_delete_controller(
    ids: list[int] = Body(..., description="ID列表"),
    auth: AuthSchema = Depends(AuthPermission()),
) -> JSONResponse:
    """
    Delete embedding configurations in bulk.

    Args:
        ids: IDs to delete, sent as the JSON request body.
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse with a confirmation message and no data.
    """
    done_msg = "删除向量化配置成功"
    await EmbeddingConfigService.delete_service(auth=auth, ids=ids)
    log.info(f"删除向量化配置成功: {ids}")
    return SuccessResponse(msg=done_msg)
|
||
|
||
|
||
# ============== 知识库管理 ==============
|
||
|
||
@KnowledgeBaseRouter.get("/detail/{id}", summary="获取知识库详情")
async def knowledge_detail_controller(
    id: int = Path(..., description="知识库ID"),
    auth: AuthSchema = Depends(AuthPermission()),
) -> JSONResponse:
    """
    Fetch one knowledge base.

    Args:
        id: Knowledge base ID (path parameter).
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse wrapping the result of ``KnowledgeBaseService.detail_service``.
    """
    knowledge_base = await KnowledgeBaseService.detail_service(auth=auth, id=id)
    return SuccessResponse(data=knowledge_base, msg="获取知识库详情成功")
|
||
|
||
|
||
@KnowledgeBaseRouter.get("/list", summary="查询知识库列表")
async def knowledge_list_controller(
    page: PaginationQueryParam = Depends(),
    search: KnowledgeBaseQueryParam = Depends(),
    auth: AuthSchema = Depends(AuthPermission()),
) -> JSONResponse:
    """
    List knowledge bases, filtered and paginated.

    Args:
        page: Pagination parameters (``page_no``, ``page_size``, ``order_by``).
        search: Filter parameters forwarded to ``KnowledgeBaseService.list_service``.
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse wrapping the page produced by ``PaginationService.paginate``.
    """
    knowledge_bases = await KnowledgeBaseService.list_service(auth=auth, search=search, order_by=page.order_by)
    paged = await PaginationService.paginate(
        data_list=knowledge_bases,
        page_no=page.page_no,
        page_size=page.page_size,
    )
    return SuccessResponse(data=paged, msg="查询知识库列表成功")
|
||
|
||
|
||
@KnowledgeBaseRouter.post("/create", summary="创建知识库")
async def knowledge_create_controller(
    name: str = Form(..., max_length=100, description="知识库名称"),
    embedding_config_id: int | None = Form(None, description="向量化配置ID"),
    description: str | None = Form(None, max_length=255, description="备注"),
    files: list[UploadFile] = File(default=[], description="上传的文件列表"),
    auth: AuthSchema = Depends(AuthPermission()),
) -> JSONResponse:
    """
    Create a knowledge base from multipart form fields and optional uploaded files.

    Args:
        name: Knowledge base name (form field, at most 100 characters).
        embedding_config_id: Optional embedding configuration ID (form field).
        description: Optional remark (form field, at most 255 characters).
        files: Uploaded files, forwarded unchanged to the service.
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse wrapping the result of ``KnowledgeBaseService.create_service``.
    """
    # Multipart forms cannot bind directly to a schema, so build it from the fields.
    payload = KnowledgeBaseCreateSchema(
        name=name,
        embedding_config_id=embedding_config_id,
        description=description,
    )
    created = await KnowledgeBaseService.create_service(auth=auth, data=payload, files=files)
    log.info(f"创建知识库成功: {created}")
    return SuccessResponse(data=created, msg="创建知识库成功")
|
||
|
||
|
||
@KnowledgeBaseRouter.put("/update/{id}", summary="修改知识库")
async def knowledge_update_controller(
    data: KnowledgeBaseUpdateSchema,
    id: int = Path(..., description="知识库ID"),
    auth: AuthSchema = Depends(AuthPermission()),
) -> JSONResponse:
    """
    Update a knowledge base.

    Args:
        data: Update payload.
        id: Knowledge base ID (path parameter).
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse wrapping the result of ``KnowledgeBaseService.update_service``.
    """
    updated = await KnowledgeBaseService.update_service(auth=auth, id=id, data=data)
    log.info(f"修改知识库成功: {updated}")
    return SuccessResponse(data=updated, msg="修改知识库成功")
|
||
|
||
|
||
@KnowledgeBaseRouter.delete("/delete", summary="删除知识库")
async def knowledge_delete_controller(
    ids: list[int] = Body(..., description="ID列表"),
    auth: AuthSchema = Depends(AuthPermission()),
) -> JSONResponse:
    """
    Delete knowledge bases in bulk.

    Args:
        ids: IDs to delete, sent as the JSON request body.
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse with a confirmation message and no data.
    """
    done_msg = "删除知识库成功"
    await KnowledgeBaseService.delete_service(auth=auth, ids=ids)
    log.info(f"删除知识库成功: {ids}")
    return SuccessResponse(msg=done_msg)
|
||
|
||
|
||
@KnowledgeBaseRouter.post("/retry/{id}", summary="重试向量化")
async def knowledge_retry_controller(
    id: int = Path(..., description="知识库ID"),
    auth: AuthSchema = Depends(AuthPermission()),
) -> JSONResponse:
    """
    Re-run vectorization for a knowledge base.

    Args:
        id: Knowledge base ID (path parameter).
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse wrapping the result of ``KnowledgeBaseService.retry_service``.
        The message says re-vectorization "has been started"; whether the work
        runs asynchronously is decided by the service (not visible here).
    """
    retry_result = await KnowledgeBaseService.retry_service(auth=auth, id=id)
    log.info(f"重试向量化成功: {id}")
    return SuccessResponse(data=retry_result, msg="已启动重新向量化")
|
||
|
||
|
||
# ============== AI模型配置 ==============
|
||
|
||
@AIModelConfigRouter.get("/detail/{model_type}", summary="获取AI模型配置详情")
async def model_config_detail_controller(
    model_type: str = Path(..., description="模型类型(naming/renaming/scoring/report)"),
    auth: AuthSchema = Depends(AuthPermission()),
) -> JSONResponse:
    """
    Fetch the AI model configuration for one model type.

    Args:
        model_type: Model type key (path parameter), e.g. naming / renaming /
            scoring / report.
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse wrapping the result of ``AIModelConfigService.detail_service``.
    """
    model_config = await AIModelConfigService.detail_service(auth=auth, model_type=model_type)
    return SuccessResponse(data=model_config, msg="获取AI模型配置详情成功")
|
||
|
||
|
||
@AIModelConfigRouter.get("/list", summary="查询AI模型配置列表")
async def model_config_list_controller(
    page: PaginationQueryParam = Depends(),
    search: AIModelConfigQueryParam = Depends(),
    auth: AuthSchema = Depends(AuthPermission()),
) -> JSONResponse:
    """
    List AI model configurations, filtered and paginated.

    Args:
        page: Pagination parameters (``page_no``, ``page_size``, ``order_by``).
        search: Filter parameters forwarded to ``AIModelConfigService.list_service``.
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse wrapping the page produced by ``PaginationService.paginate``.
    """
    model_configs = await AIModelConfigService.list_service(auth=auth, search=search, order_by=page.order_by)
    paged = await PaginationService.paginate(
        data_list=model_configs,
        page_no=page.page_no,
        page_size=page.page_size,
    )
    return SuccessResponse(data=paged, msg="查询AI模型配置列表成功")
|
||
|
||
|
||
@AIModelConfigRouter.put("/update/{model_type}", summary="更新AI模型配置")
async def model_config_update_controller(
    data: AIModelConfigUpdateSchema,
    model_type: str = Path(..., description="模型类型(naming/renaming/scoring/report)"),
    auth: AuthSchema = Depends(AuthPermission()),
) -> JSONResponse:
    """
    Update the AI model configuration for one model type.

    Args:
        data: Update payload.
        model_type: Model type key (path parameter).
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse wrapping the result of ``AIModelConfigService.update_service``.
    """
    updated = await AIModelConfigService.update_service(auth=auth, model_type=model_type, data=data)
    log.info(f"更新AI模型配置成功: {model_type}")
    return SuccessResponse(data=updated, msg="更新AI模型配置成功")
|
||
|
||
|
||
@AIModelConfigRouter.get("/available-models/{provider_id}", summary="获取可用模型列表")
async def available_models_controller(
    provider_id: int = Path(..., description="AI供应商ID"),
    auth: AuthSchema = Depends(AuthPermission()),
) -> JSONResponse:
    """
    List the models available from one AI provider.

    Args:
        provider_id: AI provider ID (path parameter).
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse wrapping the result of ``AIModelConfigService.get_available_models``.
    """
    models = await AIModelConfigService.get_available_models(auth=auth, provider_id=provider_id)
    return SuccessResponse(data=models, msg="获取可用模型列表成功")
|
||
|
||
|
||
# ============== AI模型训练对话 ==============
|
||
|
||
@AIModelConfigRouter.get("/messages/{model_type}", summary="获取训练对话记录")
async def training_messages_controller(
    model_type: str = Path(..., description="模型类型(naming/renaming/scoring/report)"),
    auth: AuthSchema = Depends(AuthPermission()),
) -> JSONResponse:
    """
    Fetch the training-chat history for one model type.

    Args:
        model_type: Model type key (path parameter).
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse wrapping the result of ``AIModelTrainingService.get_messages_service``.
    """
    history = await AIModelTrainingService.get_messages_service(auth=auth, model_type=model_type)
    return SuccessResponse(data=history, msg="获取训练对话记录成功")
|
||
|
||
|
||
@AIModelConfigRouter.delete("/message/{message_id}", summary="删除单条训练对话记录")
async def delete_training_message_controller(
    message_id: int = Path(..., description="对话记录ID"),
    auth: AuthSchema = Depends(AuthPermission()),
) -> JSONResponse:
    """
    Delete a single training-chat message.

    Args:
        message_id: Message record ID (path parameter).
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse with a confirmation message and no data.
    """
    done_msg = "删除训练对话记录成功"
    await AIModelTrainingService.delete_message_service(auth=auth, message_id=message_id)
    log.info(f"删除训练对话记录成功: {message_id}")
    return SuccessResponse(msg=done_msg)
|
||
|
||
|
||
@AIModelConfigRouter.delete("/messages/{model_type}", summary="清空训练对话记录")
async def clear_training_messages_controller(
    model_type: str = Path(..., description="模型类型(naming/renaming/scoring/report)"),
    auth: AuthSchema = Depends(AuthPermission()),
) -> JSONResponse:
    """
    Clear all training-chat messages for one model type.

    Args:
        model_type: Model type key (path parameter).
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        SuccessResponse with a confirmation message and no data.
    """
    done_msg = "清空训练对话记录成功"
    await AIModelTrainingService.clear_messages_service(auth=auth, model_type=model_type)
    log.info(f"清空训练对话记录成功: {model_type}")
    return SuccessResponse(msg=done_msg)
|
||
|
||
|
||
@AIModelConfigRouter.post("/chat", summary="训练对话")
async def training_chat_controller(
    chat_data: AIModelTrainingChatSchema,
    auth: AuthSchema = Depends(AuthPermission()),
) -> StreamingResponse:
    """
    Training chat endpoint that streams the model's reply.

    Args:
        chat_data: Training chat request (``model_type``, ``message``, ...).
        auth: Authenticated caller injected by ``AuthPermission``.

    Returns:
        A ``text/plain; charset=utf-8`` streamed response built from the chunks
        yielded by ``AIModelTrainingService.chat_stream``. If the service raises
        while streaming, the error text is written into the stream.
    """
    log.info(f"[训练对话] 收到请求: model_type={chat_data.model_type}, message={chat_data.message[:50]}...")

    async def _relay():
        log.info("[训练对话] 开始生成响应")
        emitted = 0
        try:
            async for piece in AIModelTrainingService.chat_stream(auth=auth, chat_data=chat_data):
                if not piece:
                    continue
                emitted += 1
                # The response body is emitted as bytes; text chunks are UTF-8 encoded.
                yield piece.encode('utf-8') if isinstance(piece, str) else piece
            log.info(f"[训练对话] 响应生成完成,共 {emitted} 个 chunk")
        except Exception as exc:
            log.error(f"训练对话出错: {str(exc)}")
            yield f"处理您的请求时出现了错误: {str(exc)}".encode('utf-8')

    return StreamResponse(_relay(), media_type="text/plain; charset=utf-8")
|
||
|
||
|
||
@AIModelConfigRouter.websocket("/ws/chat")
async def training_websocket_chat_controller(
    websocket: WebSocket,
):
    """
    WebSocket endpoint for model-training chat.

    Each incoming frame must be a JSON object with these keys:
        - ``model_type`` (default ``'naming'``)
        - ``message`` (default ``''``)
        - ``config_changed`` (default ``False``)
        - ``config_data``: optional dict of ``AIModelConfigUpdateSchema`` fields

    Every non-empty chunk yielded by ``AIModelTrainingService.chat_stream`` is
    sent back as a text frame. Invalid JSON, a non-object payload, a schema
    validation error or a service failure is reported to the client as a text
    frame, and the connection stays open for the next message. The session
    ends when the client disconnects or an unexpected error occurs.

    NOTE(review): no authentication is performed. An empty ``AuthSchema()`` is
    passed to the service. Token verification (presumably via
    ``app.core.security.verify_token``) should be added before exposing this
    endpoint.
    """
    # Local import, matching the lazy-import style used elsewhere in this module.
    from fastapi.websockets import WebSocketDisconnect, WebSocketState

    await websocket.accept()
    try:
        while True:
            try:
                data = await websocket.receive_json()
            except ValueError as e:
                # Invalid JSON: the frame was consumed, so report and wait for the next one.
                log.error(f"处理训练对话出错: {str(e)}")
                await websocket.send_text(f"处理您的请求时出现了错误: {str(e)}")
                continue

            try:
                if not isinstance(data, dict):
                    raise ValueError("消息必须是 JSON 对象")

                config_data = data.get('config_data')
                chat_data = AIModelTrainingChatSchema(
                    model_type=data.get('model_type', 'naming'),
                    message=data.get('message', ''),
                    config_changed=data.get('config_changed', False),
                    config_data=AIModelConfigUpdateSchema(**config_data) if config_data else None
                )

                # Unauthenticated placeholder; see the NOTE in the docstring.
                auth = AuthSchema()

                # Stream the reply back chunk by chunk
                async for chunk in AIModelTrainingService.chat_stream(auth=auth, chat_data=chat_data):
                    if chunk:
                        await websocket.send_text(chunk)
            except WebSocketDisconnect:
                # The client left mid-stream. Do not try to send an error on a dead
                # socket; let the outer handler end the session.
                raise
            except Exception as e:
                log.error(f"处理训练对话出错: {str(e)}")
                await websocket.send_text(f"处理您的请求时出现了错误: {str(e)}")
    except WebSocketDisconnect:
        # A normal client disconnect is not an error.
        log.info("WebSocket训练对话连接已断开")
    except Exception as e:
        log.error(f"WebSocket训练对话出错: {str(e)}")
    finally:
        # Only send a close frame while both sides are still connected. Closing an
        # already-disconnected socket can raise RuntimeError (depends on the ASGI server).
        if (
            websocket.client_state != WebSocketState.DISCONNECTED
            and websocket.application_state != WebSocketState.DISCONNECTED
        ):
            await websocket.close()
|