upload project source code
This commit is contained in:
@@ -0,0 +1 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
@@ -0,0 +1 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
@@ -0,0 +1,537 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from fastapi import APIRouter, Depends, Path, Body, WebSocket, Form, UploadFile, File
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from app.common.response import StreamResponse, SuccessResponse
|
||||
from app.common.request import PaginationService
|
||||
from app.core.base_params import PaginationQueryParam
|
||||
from app.core.dependencies import AuthPermission
|
||||
from app.core.logger import log
|
||||
|
||||
from app.api.v1.module_system.auth.schema import AuthSchema
|
||||
from app.core.router_class import OperationLogRoute
|
||||
from .service import McpService, AIProviderService, EmbeddingConfigService, KnowledgeBaseService, AIModelConfigService, AIModelTrainingService, AIModelTestService
|
||||
from .schema import (
|
||||
McpCreateSchema, McpUpdateSchema, ChatQuerySchema, McpQueryParam,
|
||||
AIProviderCreateSchema, AIProviderUpdateSchema, AIProviderQueryParam,
|
||||
EmbeddingConfigCreateSchema, EmbeddingConfigUpdateSchema, EmbeddingConfigQueryParam,
|
||||
KnowledgeBaseCreateSchema, KnowledgeBaseUpdateSchema, KnowledgeBaseQueryParam,
|
||||
AIModelConfigUpdateSchema, AIModelConfigQueryParam, AIModelTrainingChatSchema,
|
||||
AIModelTestSchema
|
||||
)
|
||||
|
||||
|
||||
AIRouter = APIRouter(route_class=OperationLogRoute, prefix="/ai", tags=["MCP智能助手"])
|
||||
AIProviderRouter = APIRouter(route_class=OperationLogRoute, prefix="/ai/provider", tags=["AI供应商配置"])
|
||||
EmbeddingConfigRouter = APIRouter(route_class=OperationLogRoute, prefix="/ai/embedding", tags=["向量化配置"])
|
||||
KnowledgeBaseRouter = APIRouter(route_class=OperationLogRoute, prefix="/ai/knowledge", tags=["知识库管理"])
|
||||
AIModelConfigRouter = APIRouter(route_class=OperationLogRoute, prefix="/ai/model", tags=["AI模型配置"])
|
||||
|
||||
|
||||
@AIRouter.post("/chat", summary="智能对话", description="与MCP智能助手进行对话")
|
||||
async def chat_controller(
|
||||
query: ChatQuerySchema,
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
智能对话接口
|
||||
|
||||
参数:
|
||||
- query (ChatQuerySchema): 聊天查询模型
|
||||
|
||||
返回:
|
||||
- StreamingResponse: 流式响应,每次返回一个聊天响应
|
||||
"""
|
||||
user_name = auth.user.name if auth.user else "未知用户"
|
||||
log.info(f"用户 {user_name} 发起智能对话: {query.message[:50]}...")
|
||||
|
||||
async def generate_response():
|
||||
try:
|
||||
async for chunk in McpService.chat_query(query=query):
|
||||
# 确保返回的是字节串
|
||||
if chunk:
|
||||
yield chunk.encode('utf-8') if isinstance(chunk, str) else chunk
|
||||
except Exception as e:
|
||||
log.error(f"流式响应出错: {str(e)}")
|
||||
yield f"抱歉,处理您的请求时出现了错误: {str(e)}".encode('utf-8')
|
||||
|
||||
return StreamResponse(generate_response(), media_type="text/plain; charset=utf-8")
|
||||
|
||||
|
||||
@AIRouter.get("/detail/{id}", summary="获取 MCP 服务器详情", description="获取 MCP 服务器详情")
|
||||
async def detail_controller(
|
||||
id: int = Path(..., description="MCP ID"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
获取 MCP 服务器详情接口
|
||||
|
||||
参数:
|
||||
- id (int): MCP 服务器ID
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含 MCP 服务器详情的 JSON 响应
|
||||
"""
|
||||
result_dict = await McpService.detail_service(auth=auth, id=id)
|
||||
log.info(f"获取 MCP 服务器详情成功 {id}")
|
||||
return SuccessResponse(data=result_dict, msg="获取 MCP 服务器详情成功")
|
||||
|
||||
|
||||
@AIRouter.get("/list", summary="查询 MCP 服务器列表", description="查询 MCP 服务器列表")
|
||||
async def list_controller(
|
||||
page: PaginationQueryParam = Depends(),
|
||||
search: McpQueryParam = Depends(),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
查询 MCP 服务器列表接口
|
||||
|
||||
参数:
|
||||
- page (PaginationQueryParam): 分页查询参数模型
|
||||
- search (McpQueryParam): 查询参数模型
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含 MCP 服务器列表的 JSON 响应
|
||||
"""
|
||||
result_dict_list = await McpService.list_service(auth=auth, search=search, order_by=page.order_by)
|
||||
result_dict = await PaginationService.paginate(data_list=result_dict_list, page_no=page.page_no, page_size=page.page_size)
|
||||
log.info(f"查询 MCP 服务器列表成功")
|
||||
return SuccessResponse(data=result_dict, msg="查询 MCP 服务器列表成功")
|
||||
|
||||
|
||||
@AIRouter.post("/create", summary="创建 MCP 服务器", description="创建 MCP 服务器")
|
||||
async def create_controller(
|
||||
data: McpCreateSchema,
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
创建 MCP 服务器接口
|
||||
|
||||
参数:
|
||||
- data (McpCreateSchema): 创建 MCP 服务器模型
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含创建 MCP 服务器结果的 JSON 响应
|
||||
"""
|
||||
result_dict = await McpService.create_service(auth=auth, data=data)
|
||||
log.info(f"创建 MCP 服务器成功: {result_dict}")
|
||||
return SuccessResponse(data=result_dict, msg="创建 MCP 服务器成功")
|
||||
|
||||
|
||||
@AIRouter.put("/update/{id}", summary="修改 MCP 服务器", description="修改 MCP 服务器")
|
||||
async def update_controller(
|
||||
data: McpUpdateSchema,
|
||||
id: int = Path(..., description="MCP ID"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
修改 MCP 服务器接口
|
||||
|
||||
参数:
|
||||
- data (McpUpdateSchema): 修改 MCP 服务器模型
|
||||
- id (int): MCP 服务器ID
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含修改 MCP 服务器结果的 JSON 响应
|
||||
"""
|
||||
result_dict = await McpService.update_service(auth=auth, id=id, data=data)
|
||||
log.info(f"修改 MCP 服务器成功: {result_dict}")
|
||||
return SuccessResponse(data=result_dict, msg="修改 MCP 服务器成功")
|
||||
|
||||
|
||||
@AIRouter.delete("/delete", summary="删除 MCP 服务器", description="删除 MCP 服务器")
|
||||
async def delete_controller(
|
||||
ids: list[int] = Body(..., description="ID列表"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
删除 MCP 服务器接口
|
||||
|
||||
参数:
|
||||
- ids (list[int]): MCP 服务器ID列表
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含删除 MCP 服务器结果的 JSON 响应
|
||||
"""
|
||||
await McpService.delete_service(auth=auth, ids=ids)
|
||||
log.info(f"删除 MCP 服务器成功: {ids}")
|
||||
return SuccessResponse(msg="删除 MCP 服务器成功")
|
||||
|
||||
|
||||
@AIRouter.websocket("/ws/chat", name="WebSocket聊天")
|
||||
async def websocket_chat_controller(
|
||||
websocket: WebSocket,
|
||||
):
|
||||
"""
|
||||
WebSocket聊天接口
|
||||
|
||||
ws://127.0.0.1:8001/api/v1/ai/mcp/ws/chat
|
||||
"""
|
||||
await websocket.accept()
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
# 流式发送响应
|
||||
try:
|
||||
async for chunk in McpService.chat_query(query=ChatQuerySchema(message=data)):
|
||||
if chunk:
|
||||
await websocket.send_text(chunk)
|
||||
except Exception as e:
|
||||
log.error(f"处理聊天查询出错: {str(e)}")
|
||||
await websocket.send_text(f"抱歉,处理您的请求时出现了错误: {str(e)}")
|
||||
except Exception as e:
|
||||
log.error(f"WebSocket聊天出错: {str(e)}")
|
||||
finally:
|
||||
await websocket.close()
|
||||
|
||||
|
||||
@AIModelConfigRouter.post("/test", summary="起名测试")
|
||||
async def naming_test_controller(
|
||||
test_data: AIModelTestSchema,
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
起名测试接口(用于小程序/外部调用)
|
||||
|
||||
- 读取模型配置和训练信息
|
||||
- 内部流式调用AI供应商
|
||||
- 组合响应后一次性返回
|
||||
- 仅读数据,不保存对话记录
|
||||
|
||||
参数:
|
||||
- model_type: 模型类型(enterprise_naming/personal_naming等)
|
||||
- text: 用户输入的文本
|
||||
|
||||
返回:
|
||||
- data: AI的完整响应文本
|
||||
"""
|
||||
log.info(f"[起名测试] 收到请求: model_type={test_data.model_type}, text={test_data.text[:50]}...")
|
||||
|
||||
result = await AIModelTestService.test_naming(
|
||||
model_type=test_data.model_type,
|
||||
text=test_data.text
|
||||
)
|
||||
|
||||
log.info(f"[起名测试] 响应完成,长度: {len(result)}")
|
||||
return SuccessResponse(data=result, msg="测试成功")
|
||||
|
||||
|
||||
# ============== AI供应商配置 ==============
|
||||
|
||||
@AIProviderRouter.get("/detail/{id}", summary="获取AI供应商详情")
|
||||
async def provider_detail_controller(
|
||||
id: int = Path(..., description="AI供应商ID"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict = await AIProviderService.detail_service(auth=auth, id=id)
|
||||
return SuccessResponse(data=result_dict, msg="获取AI供应商详情成功")
|
||||
|
||||
|
||||
@AIProviderRouter.get("/list", summary="查询AI供应商列表")
|
||||
async def provider_list_controller(
|
||||
page: PaginationQueryParam = Depends(),
|
||||
search: AIProviderQueryParam = Depends(),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict_list = await AIProviderService.list_service(auth=auth, search=search, order_by=page.order_by)
|
||||
result_dict = await PaginationService.paginate(data_list=result_dict_list, page_no=page.page_no, page_size=page.page_size)
|
||||
return SuccessResponse(data=result_dict, msg="查询AI供应商列表成功")
|
||||
|
||||
|
||||
@AIProviderRouter.post("/create", summary="创建AI供应商")
|
||||
async def provider_create_controller(
|
||||
data: AIProviderCreateSchema,
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict = await AIProviderService.create_service(auth=auth, data=data)
|
||||
log.info(f"创建AI供应商成功: {result_dict}")
|
||||
return SuccessResponse(data=result_dict, msg="创建AI供应商成功")
|
||||
|
||||
|
||||
@AIProviderRouter.put("/update/{id}", summary="修改AI供应商")
|
||||
async def provider_update_controller(
|
||||
data: AIProviderUpdateSchema,
|
||||
id: int = Path(..., description="AI供应商ID"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict = await AIProviderService.update_service(auth=auth, id=id, data=data)
|
||||
log.info(f"修改AI供应商成功: {result_dict}")
|
||||
return SuccessResponse(data=result_dict, msg="修改AI供应商成功")
|
||||
|
||||
|
||||
@AIProviderRouter.delete("/delete", summary="删除AI供应商")
|
||||
async def provider_delete_controller(
|
||||
ids: list[int] = Body(..., description="ID列表"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
await AIProviderService.delete_service(auth=auth, ids=ids)
|
||||
log.info(f"删除AI供应商成功: {ids}")
|
||||
return SuccessResponse(msg="删除AI供应商成功")
|
||||
|
||||
|
||||
# ============== 向量化配置 ==============
|
||||
|
||||
@EmbeddingConfigRouter.get("/detail/{id}", summary="获取向量化配置详情")
|
||||
async def embedding_detail_controller(
|
||||
id: int = Path(..., description="向量化配置ID"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict = await EmbeddingConfigService.detail_service(auth=auth, id=id)
|
||||
return SuccessResponse(data=result_dict, msg="获取向量化配置详情成功")
|
||||
|
||||
|
||||
@EmbeddingConfigRouter.get("/list", summary="查询向量化配置列表")
|
||||
async def embedding_list_controller(
|
||||
page: PaginationQueryParam = Depends(),
|
||||
search: EmbeddingConfigQueryParam = Depends(),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict_list = await EmbeddingConfigService.list_service(auth=auth, search=search, order_by=page.order_by)
|
||||
result_dict = await PaginationService.paginate(data_list=result_dict_list, page_no=page.page_no, page_size=page.page_size)
|
||||
return SuccessResponse(data=result_dict, msg="查询向量化配置列表成功")
|
||||
|
||||
|
||||
@EmbeddingConfigRouter.post("/create", summary="创建向量化配置")
|
||||
async def embedding_create_controller(
|
||||
data: EmbeddingConfigCreateSchema,
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict = await EmbeddingConfigService.create_service(auth=auth, data=data)
|
||||
log.info(f"创建向量化配置成功: {result_dict}")
|
||||
return SuccessResponse(data=result_dict, msg="创建向量化配置成功")
|
||||
|
||||
|
||||
@EmbeddingConfigRouter.put("/update/{id}", summary="修改向量化配置")
|
||||
async def embedding_update_controller(
|
||||
data: EmbeddingConfigUpdateSchema,
|
||||
id: int = Path(..., description="向量化配置ID"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict = await EmbeddingConfigService.update_service(auth=auth, id=id, data=data)
|
||||
log.info(f"修改向量化配置成功: {result_dict}")
|
||||
return SuccessResponse(data=result_dict, msg="修改向量化配置成功")
|
||||
|
||||
|
||||
@EmbeddingConfigRouter.delete("/delete", summary="删除向量化配置")
|
||||
async def embedding_delete_controller(
|
||||
ids: list[int] = Body(..., description="ID列表"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
await EmbeddingConfigService.delete_service(auth=auth, ids=ids)
|
||||
log.info(f"删除向量化配置成功: {ids}")
|
||||
return SuccessResponse(msg="删除向量化配置成功")
|
||||
|
||||
|
||||
# ============== 知识库管理 ==============
|
||||
|
||||
@KnowledgeBaseRouter.get("/detail/{id}", summary="获取知识库详情")
|
||||
async def knowledge_detail_controller(
|
||||
id: int = Path(..., description="知识库ID"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict = await KnowledgeBaseService.detail_service(auth=auth, id=id)
|
||||
return SuccessResponse(data=result_dict, msg="获取知识库详情成功")
|
||||
|
||||
|
||||
@KnowledgeBaseRouter.get("/list", summary="查询知识库列表")
|
||||
async def knowledge_list_controller(
|
||||
page: PaginationQueryParam = Depends(),
|
||||
search: KnowledgeBaseQueryParam = Depends(),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict_list = await KnowledgeBaseService.list_service(auth=auth, search=search, order_by=page.order_by)
|
||||
result_dict = await PaginationService.paginate(data_list=result_dict_list, page_no=page.page_no, page_size=page.page_size)
|
||||
return SuccessResponse(data=result_dict, msg="查询知识库列表成功")
|
||||
|
||||
|
||||
@KnowledgeBaseRouter.post("/create", summary="创建知识库")
|
||||
async def knowledge_create_controller(
|
||||
name: str = Form(..., max_length=100, description="知识库名称"),
|
||||
embedding_config_id: int | None = Form(None, description="向量化配置ID"),
|
||||
description: str | None = Form(None, max_length=255, description="备注"),
|
||||
files: list[UploadFile] = File(default=[], description="上传的文件列表"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
# 构造数据对象
|
||||
data = KnowledgeBaseCreateSchema(name=name, embedding_config_id=embedding_config_id, description=description)
|
||||
result_dict = await KnowledgeBaseService.create_service(auth=auth, data=data, files=files)
|
||||
log.info(f"创建知识库成功: {result_dict}")
|
||||
return SuccessResponse(data=result_dict, msg="创建知识库成功")
|
||||
|
||||
|
||||
@KnowledgeBaseRouter.put("/update/{id}", summary="修改知识库")
|
||||
async def knowledge_update_controller(
|
||||
data: KnowledgeBaseUpdateSchema,
|
||||
id: int = Path(..., description="知识库ID"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict = await KnowledgeBaseService.update_service(auth=auth, id=id, data=data)
|
||||
log.info(f"修改知识库成功: {result_dict}")
|
||||
return SuccessResponse(data=result_dict, msg="修改知识库成功")
|
||||
|
||||
|
||||
@KnowledgeBaseRouter.delete("/delete", summary="删除知识库")
|
||||
async def knowledge_delete_controller(
|
||||
ids: list[int] = Body(..., description="ID列表"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
await KnowledgeBaseService.delete_service(auth=auth, ids=ids)
|
||||
log.info(f"删除知识库成功: {ids}")
|
||||
return SuccessResponse(msg="删除知识库成功")
|
||||
|
||||
|
||||
@KnowledgeBaseRouter.post("/retry/{id}", summary="重试向量化")
|
||||
async def knowledge_retry_controller(
|
||||
id: int = Path(..., description="知识库ID"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict = await KnowledgeBaseService.retry_service(auth=auth, id=id)
|
||||
log.info(f"重试向量化成功: {id}")
|
||||
return SuccessResponse(data=result_dict, msg="已启动重新向量化")
|
||||
|
||||
|
||||
# ============== AI模型配置 ==============
|
||||
|
||||
@AIModelConfigRouter.get("/detail/{model_type}", summary="获取AI模型配置详情")
|
||||
async def model_config_detail_controller(
|
||||
model_type: str = Path(..., description="模型类型(naming/renaming/scoring/report)"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict = await AIModelConfigService.detail_service(auth=auth, model_type=model_type)
|
||||
return SuccessResponse(data=result_dict, msg="获取AI模型配置详情成功")
|
||||
|
||||
|
||||
@AIModelConfigRouter.get("/list", summary="查询AI模型配置列表")
|
||||
async def model_config_list_controller(
|
||||
page: PaginationQueryParam = Depends(),
|
||||
search: AIModelConfigQueryParam = Depends(),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict_list = await AIModelConfigService.list_service(auth=auth, search=search, order_by=page.order_by)
|
||||
result_dict = await PaginationService.paginate(data_list=result_dict_list, page_no=page.page_no, page_size=page.page_size)
|
||||
return SuccessResponse(data=result_dict, msg="查询AI模型配置列表成功")
|
||||
|
||||
|
||||
@AIModelConfigRouter.put("/update/{model_type}", summary="更新AI模型配置")
|
||||
async def model_config_update_controller(
|
||||
data: AIModelConfigUpdateSchema,
|
||||
model_type: str = Path(..., description="模型类型(naming/renaming/scoring/report)"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_dict = await AIModelConfigService.update_service(auth=auth, model_type=model_type, data=data)
|
||||
log.info(f"更新AI模型配置成功: {model_type}")
|
||||
return SuccessResponse(data=result_dict, msg="更新AI模型配置成功")
|
||||
|
||||
|
||||
@AIModelConfigRouter.get("/available-models/{provider_id}", summary="获取可用模型列表")
|
||||
async def available_models_controller(
|
||||
provider_id: int = Path(..., description="AI供应商ID"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_list = await AIModelConfigService.get_available_models(auth=auth, provider_id=provider_id)
|
||||
return SuccessResponse(data=result_list, msg="获取可用模型列表成功")
|
||||
|
||||
|
||||
# ============== AI模型训练对话 ==============
|
||||
|
||||
@AIModelConfigRouter.get("/messages/{model_type}", summary="获取训练对话记录")
|
||||
async def training_messages_controller(
|
||||
model_type: str = Path(..., description="模型类型(naming/renaming/scoring/report)"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
result_list = await AIModelTrainingService.get_messages_service(auth=auth, model_type=model_type)
|
||||
return SuccessResponse(data=result_list, msg="获取训练对话记录成功")
|
||||
|
||||
|
||||
@AIModelConfigRouter.delete("/message/{message_id}", summary="删除单条训练对话记录")
|
||||
async def delete_training_message_controller(
|
||||
message_id: int = Path(..., description="对话记录ID"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
await AIModelTrainingService.delete_message_service(auth=auth, message_id=message_id)
|
||||
log.info(f"删除训练对话记录成功: {message_id}")
|
||||
return SuccessResponse(msg="删除训练对话记录成功")
|
||||
|
||||
|
||||
@AIModelConfigRouter.delete("/messages/{model_type}", summary="清空训练对话记录")
|
||||
async def clear_training_messages_controller(
|
||||
model_type: str = Path(..., description="模型类型(naming/renaming/scoring/report)"),
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> JSONResponse:
|
||||
await AIModelTrainingService.clear_messages_service(auth=auth, model_type=model_type)
|
||||
log.info(f"清空训练对话记录成功: {model_type}")
|
||||
return SuccessResponse(msg="清空训练对话记录成功")
|
||||
|
||||
|
||||
@AIModelConfigRouter.post("/chat", summary="训练对话")
|
||||
async def training_chat_controller(
|
||||
chat_data: AIModelTrainingChatSchema,
|
||||
auth: AuthSchema = Depends(AuthPermission())
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
训练对话接口,流式输出
|
||||
"""
|
||||
log.info(f"[训练对话] 收到请求: model_type={chat_data.model_type}, message={chat_data.message[:50]}...")
|
||||
|
||||
async def generate_response():
|
||||
log.info("[训练对话] 开始生成响应")
|
||||
chunk_count = 0
|
||||
try:
|
||||
async for chunk in AIModelTrainingService.chat_stream(auth=auth, chat_data=chat_data):
|
||||
if chunk:
|
||||
chunk_count += 1
|
||||
yield chunk.encode('utf-8') if isinstance(chunk, str) else chunk
|
||||
log.info(f"[训练对话] 响应生成完成,共 {chunk_count} 个 chunk")
|
||||
except Exception as e:
|
||||
log.error(f"训练对话出错: {str(e)}")
|
||||
yield f"处理您的请求时出现了错误: {str(e)}".encode('utf-8')
|
||||
|
||||
return StreamResponse(generate_response(), media_type="text/plain; charset=utf-8")
|
||||
|
||||
|
||||
@AIModelConfigRouter.websocket("/ws/chat")
|
||||
async def training_websocket_chat_controller(
|
||||
websocket: WebSocket,
|
||||
):
|
||||
"""
|
||||
训练对话 WebSocket 接口
|
||||
"""
|
||||
from app.core.security import verify_token
|
||||
from app.api.v1.module_system.auth.schema import AuthSchema
|
||||
|
||||
await websocket.accept()
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
model_type = data.get('model_type', 'naming')
|
||||
message = data.get('message', '')
|
||||
config_changed = data.get('config_changed', False)
|
||||
config_data = data.get('config_data')
|
||||
|
||||
# 构建chat_data
|
||||
chat_data = AIModelTrainingChatSchema(
|
||||
model_type=model_type,
|
||||
message=message,
|
||||
config_changed=config_changed,
|
||||
config_data=AIModelConfigUpdateSchema(**config_data) if config_data else None
|
||||
)
|
||||
|
||||
# 构建简化的auth(WebSocket没有正常的认证流程,实际使用时需要实现认证)
|
||||
auth = AuthSchema()
|
||||
|
||||
# 流式发送响应
|
||||
try:
|
||||
async for chunk in AIModelTrainingService.chat_stream(auth=auth, chat_data=chat_data):
|
||||
if chunk:
|
||||
await websocket.send_text(chunk)
|
||||
except Exception as e:
|
||||
log.error(f"处理训练对话出错: {str(e)}")
|
||||
await websocket.send_text(f"处理您的请求时出现了错误: {str(e)}")
|
||||
except Exception as e:
|
||||
log.error(f"WebSocket训练对话出错: {str(e)}")
|
||||
finally:
|
||||
await websocket.close()
|
||||
@@ -0,0 +1,270 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Sequence, Any
|
||||
|
||||
from app.core.base_crud import CRUDBase
|
||||
|
||||
from app.api.v1.module_system.auth.schema import AuthSchema
|
||||
from .model import McpModel, AIProviderModel, EmbeddingConfigModel, KnowledgeBaseModel, AIModelConfigModel, AIModelTrainingMessageModel
|
||||
from .schema import (
|
||||
McpCreateSchema, McpUpdateSchema,
|
||||
AIProviderCreateSchema, AIProviderUpdateSchema,
|
||||
EmbeddingConfigCreateSchema, EmbeddingConfigUpdateSchema,
|
||||
KnowledgeBaseCreateSchema, KnowledgeBaseUpdateSchema,
|
||||
AIModelConfigCreateSchema, AIModelConfigUpdateSchema,
|
||||
AIModelTrainingMessageCreateSchema
|
||||
)
|
||||
|
||||
|
||||
class McpCRUD(CRUDBase[McpModel, McpCreateSchema, McpUpdateSchema]):
|
||||
"""MCP 服务器数据层"""
|
||||
|
||||
def __init__(self, auth: AuthSchema) -> None:
|
||||
"""
|
||||
初始化CRUD
|
||||
|
||||
参数:
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
"""
|
||||
self.auth = auth
|
||||
super().__init__(model=McpModel, auth=auth)
|
||||
|
||||
async def get_by_id_crud(self, id: int, preload: list[str | Any] | None = None) -> McpModel | None:
|
||||
"""
|
||||
获取MCP服务器详情
|
||||
|
||||
参数:
|
||||
- id (int): MCP服务器ID
|
||||
- preload (list[str | Any] | None): 预加载关系,未提供时使用模型默认项
|
||||
|
||||
返回:
|
||||
- McpModel | None: MCP服务器模型实例(如果存在)
|
||||
"""
|
||||
return await self.get(id=id, preload=preload)
|
||||
|
||||
async def get_by_name_crud(self, name: str, preload: list[str | Any] | None = None) -> McpModel | None:
|
||||
"""
|
||||
通过名称获取MCP服务器
|
||||
|
||||
参数:
|
||||
- name (str): MCP服务器名称
|
||||
- preload (list[str | Any] | None): 预加载关系,未提供时使用模型默认项
|
||||
|
||||
返回:
|
||||
- Optional[McpModel]: MCP服务器模型实例(如果存在)
|
||||
"""
|
||||
return await self.get(name=name, preload=preload)
|
||||
|
||||
async def get_list_crud(self, search: dict | None = None, order_by: list[dict[str, str]] | None = None, preload: list[str | Any] | None = None) -> Sequence[McpModel]:
|
||||
"""
|
||||
列表查询MCP服务器
|
||||
|
||||
参数:
|
||||
- search (dict | None): 查询参数字典
|
||||
- order_by (list[dict[str, str]] | None): 排序参数列表
|
||||
- preload (list[str | Any] | None): 预加载关系,未提供时使用模型默认项
|
||||
|
||||
返回:
|
||||
- Sequence[McpModel]: MCP服务器模型实例序列
|
||||
"""
|
||||
return await self.list(search=search or {}, order_by=order_by or [{'id': 'asc'}], preload=preload)
|
||||
|
||||
async def create_crud(self, data: McpCreateSchema) -> McpModel | None:
|
||||
"""
|
||||
创建MCP服务器
|
||||
|
||||
参数:
|
||||
- data (McpCreateSchema): 创建MCP服务器模型
|
||||
|
||||
返回:
|
||||
- Optional[McpModel]: 创建的MCP服务器模型实例(如果成功)
|
||||
"""
|
||||
return await self.create(data=data)
|
||||
|
||||
async def update_crud(self, id: int, data: McpUpdateSchema) -> McpModel | None:
|
||||
"""
|
||||
更新MCP服务器
|
||||
|
||||
参数:
|
||||
- id (int): MCP服务器ID
|
||||
- data (McpUpdateSchema): 更新MCP服务器模型
|
||||
|
||||
返回:
|
||||
- McpModel | None: 更新的MCP服务器模型实例(如果成功)
|
||||
"""
|
||||
return await self.update(id=id, data=data)
|
||||
|
||||
async def delete_crud(self, ids: list[int]) -> None:
|
||||
"""
|
||||
批量删除MCP服务器
|
||||
|
||||
参数:
|
||||
- ids (list[int]): MCP服务器ID列表
|
||||
|
||||
返回:
|
||||
- None
|
||||
"""
|
||||
return await self.delete(ids=ids)
|
||||
|
||||
|
||||
class AIProviderCRUD(CRUDBase[AIProviderModel, AIProviderCreateSchema, AIProviderUpdateSchema]):
|
||||
"""AI供应商数据层"""
|
||||
|
||||
def __init__(self, auth: AuthSchema) -> None:
|
||||
self.auth = auth
|
||||
super().__init__(model=AIProviderModel, auth=auth)
|
||||
|
||||
async def get_by_id_crud(self, id: int, preload: list[str | Any] | None = None) -> AIProviderModel | None:
|
||||
return await self.get(id=id, preload=preload)
|
||||
|
||||
async def get_by_name_crud(self, name: str, preload: list[str | Any] | None = None) -> AIProviderModel | None:
|
||||
return await self.get(name=name, preload=preload)
|
||||
|
||||
async def get_list_crud(self, search: dict | None = None, order_by: list[dict[str, str]] | None = None, preload: list[str | Any] | None = None) -> Sequence[AIProviderModel]:
|
||||
return await self.list(search=search or {}, order_by=order_by or [{'id': 'asc'}], preload=preload)
|
||||
|
||||
async def create_crud(self, data: AIProviderCreateSchema) -> AIProviderModel | None:
|
||||
return await self.create(data=data)
|
||||
|
||||
async def update_crud(self, id: int, data: AIProviderUpdateSchema) -> AIProviderModel | None:
|
||||
return await self.update(id=id, data=data)
|
||||
|
||||
async def delete_crud(self, ids: list[int]) -> None:
|
||||
return await self.delete(ids=ids)
|
||||
|
||||
async def clear_default_crud(self) -> None:
|
||||
"""清除所有默认设置"""
|
||||
obj_list = await self.list(search={'is_default': 1})
|
||||
for obj in obj_list:
|
||||
await self.update(id=obj.id, data=AIProviderUpdateSchema(
|
||||
name=obj.name,
|
||||
provider_type=obj.provider_type,
|
||||
base_url=obj.base_url,
|
||||
api_key=obj.api_key,
|
||||
is_default=0,
|
||||
description=obj.description
|
||||
))
|
||||
|
||||
|
||||
class EmbeddingConfigCRUD(CRUDBase[EmbeddingConfigModel, EmbeddingConfigCreateSchema, EmbeddingConfigUpdateSchema]):
|
||||
"""向量化配置数据层"""
|
||||
|
||||
def __init__(self, auth: AuthSchema) -> None:
|
||||
self.auth = auth
|
||||
super().__init__(model=EmbeddingConfigModel, auth=auth)
|
||||
|
||||
async def get_by_id_crud(self, id: int, preload: list[str | Any] | None = None) -> EmbeddingConfigModel | None:
|
||||
return await self.get(id=id, preload=preload)
|
||||
|
||||
async def get_by_name_crud(self, name: str, preload: list[str | Any] | None = None) -> EmbeddingConfigModel | None:
|
||||
return await self.get(name=name, preload=preload)
|
||||
|
||||
async def get_list_crud(self, search: dict | None = None, order_by: list[dict[str, str]] | None = None, preload: list[str | Any] | None = None) -> Sequence[EmbeddingConfigModel]:
|
||||
return await self.list(search=search or {}, order_by=order_by or [{'id': 'asc'}], preload=preload)
|
||||
|
||||
async def create_crud(self, data: EmbeddingConfigCreateSchema) -> EmbeddingConfigModel | None:
|
||||
return await self.create(data=data)
|
||||
|
||||
async def update_crud(self, id: int, data: EmbeddingConfigUpdateSchema) -> EmbeddingConfigModel | None:
|
||||
return await self.update(id=id, data=data)
|
||||
|
||||
async def delete_crud(self, ids: list[int]) -> None:
|
||||
return await self.delete(ids=ids)
|
||||
|
||||
async def clear_default_crud(self) -> None:
|
||||
"""清除所有默认设置"""
|
||||
obj_list = await self.list(search={'is_default': 1})
|
||||
for obj in obj_list:
|
||||
await self.update(id=obj.id, data=EmbeddingConfigUpdateSchema(
|
||||
name=obj.name,
|
||||
embedding_type=obj.embedding_type,
|
||||
model_name=obj.model_name,
|
||||
base_url=obj.base_url,
|
||||
api_key=obj.api_key,
|
||||
is_default=0,
|
||||
description=obj.description
|
||||
))
|
||||
|
||||
|
||||
class KnowledgeBaseCRUD(CRUDBase[KnowledgeBaseModel, KnowledgeBaseCreateSchema, KnowledgeBaseUpdateSchema]):
|
||||
"""知识库数据层"""
|
||||
|
||||
def __init__(self, auth: AuthSchema) -> None:
|
||||
self.auth = auth
|
||||
super().__init__(model=KnowledgeBaseModel, auth=auth)
|
||||
|
||||
async def get_by_id_crud(self, id: int, preload: list[str | Any] | None = None) -> KnowledgeBaseModel | None:
|
||||
return await self.get(id=id, preload=preload)
|
||||
|
||||
async def get_by_name_crud(self, name: str, preload: list[str | Any] | None = None) -> KnowledgeBaseModel | None:
|
||||
return await self.get(name=name, preload=preload)
|
||||
|
||||
async def get_by_collection_crud(self, collection_name: str, preload: list[str | Any] | None = None) -> KnowledgeBaseModel | None:
|
||||
return await self.get(collection_name=collection_name, preload=preload)
|
||||
|
||||
async def get_list_crud(self, search: dict | None = None, order_by: list[dict[str, str]] | None = None, preload: list[str | Any] | None = None) -> Sequence[KnowledgeBaseModel]:
|
||||
return await self.list(search=search or {}, order_by=order_by or [{'id': 'desc'}], preload=preload)
|
||||
|
||||
async def create_crud(self, data: dict) -> KnowledgeBaseModel | None:
|
||||
"""create方法接受字典,因为需要额外字段"""
|
||||
return await self.create(data=data)
|
||||
|
||||
async def update_crud(self, id: int, data: KnowledgeBaseUpdateSchema | dict) -> KnowledgeBaseModel | None:
|
||||
return await self.update(id=id, data=data)
|
||||
|
||||
async def delete_crud(self, ids: list[int]) -> None:
|
||||
return await self.delete(ids=ids)
|
||||
|
||||
|
||||
class AIModelConfigCRUD(CRUDBase[AIModelConfigModel, AIModelConfigCreateSchema, AIModelConfigUpdateSchema]):
|
||||
"""AI模型配置数据层"""
|
||||
|
||||
def __init__(self, auth: AuthSchema) -> None:
|
||||
self.auth = auth
|
||||
super().__init__(model=AIModelConfigModel, auth=auth)
|
||||
|
||||
async def get_by_id_crud(self, id: int, preload: list[str | Any] | None = None) -> AIModelConfigModel | None:
|
||||
return await self.get(id=id, preload=preload)
|
||||
|
||||
async def get_by_model_type_crud(self, model_type: str, preload: list[str | Any] | None = None) -> AIModelConfigModel | None:
|
||||
return await self.get(model_type=model_type, preload=preload)
|
||||
|
||||
async def get_list_crud(self, search: dict | None = None, order_by: list[dict[str, str]] | None = None, preload: list[str | Any] | None = None) -> Sequence[AIModelConfigModel]:
|
||||
return await self.list(search=search or {}, order_by=order_by or [{'id': 'asc'}], preload=preload)
|
||||
|
||||
async def create_crud(self, data: AIModelConfigCreateSchema | dict) -> AIModelConfigModel | None:
|
||||
return await self.create(data=data)
|
||||
|
||||
async def update_crud(self, id: int, data: AIModelConfigUpdateSchema | dict) -> AIModelConfigModel | None:
|
||||
return await self.update(id=id, data=data)
|
||||
|
||||
async def delete_crud(self, ids: list[int]) -> None:
|
||||
return await self.delete(ids=ids)
|
||||
|
||||
|
||||
class AIModelTrainingMessageCRUD(CRUDBase[AIModelTrainingMessageModel, AIModelTrainingMessageCreateSchema, AIModelTrainingMessageCreateSchema]):
|
||||
"""AI模型训练对话数据层"""
|
||||
|
||||
def __init__(self, auth: AuthSchema) -> None:
|
||||
self.auth = auth
|
||||
super().__init__(model=AIModelTrainingMessageModel, auth=auth)
|
||||
|
||||
async def get_by_id_crud(self, id: int, preload: list[str | Any] | None = None) -> AIModelTrainingMessageModel | None:
|
||||
return await self.get(id=id, preload=preload)
|
||||
|
||||
async def get_list_by_config_crud(self, model_config_id: int, preload: list[str | Any] | None = None) -> Sequence[AIModelTrainingMessageModel]:
|
||||
"""获取某个模型配置的所有训练对话"""
|
||||
return await self.list(search={'model_config_id': model_config_id}, order_by=[{'id': 'asc'}], preload=preload)
|
||||
|
||||
async def create_crud(self, data: AIModelTrainingMessageCreateSchema | dict) -> AIModelTrainingMessageModel | None:
|
||||
return await self.create(data=data)
|
||||
|
||||
async def delete_crud(self, ids: list[int]) -> None:
|
||||
return await self.delete(ids=ids)
|
||||
|
||||
async def delete_by_config_crud(self, model_config_id: int) -> None:
|
||||
"""删除某个模型配置的所有训练对话"""
|
||||
messages = await self.list(search={'model_config_id': model_config_id})
|
||||
if messages:
|
||||
ids = [m.id for m in messages]
|
||||
await self.delete(ids=ids)
|
||||
@@ -0,0 +1,155 @@
|
||||
'''
|
||||
Author: caoziyuan ziyuan.cao@zhuying.com
|
||||
Date: 2025-12-22 17:42:10
|
||||
LastEditors: caoziyuan ziyuan.cao@zhuying.com
|
||||
LastEditTime: 2025-12-22 18:03:28
|
||||
FilePath: \naming-backend\app\api\v1\module_application\ai\model.py
|
||||
Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE
|
||||
'''
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from sqlalchemy import JSON, String, Integer, Text, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.core.base_model import ModelMixin, UserMixin
|
||||
|
||||
|
||||
class McpModel(ModelMixin, UserMixin):
|
||||
"""
|
||||
MCP 服务器表
|
||||
MCP类型:
|
||||
- 0: stdio (标准输入输出)
|
||||
- 1: sse (Server-Sent Events)
|
||||
"""
|
||||
__tablename__: str = 'app_ai_mcp'
|
||||
__table_args__: dict[str, str] = ({'comment': 'MCP 服务器表'})
|
||||
__loader_options__: list[str] = ["created_by", "updated_by"]
|
||||
|
||||
name: Mapped[str] = mapped_column(String(50), comment='MCP 名称')
|
||||
type: Mapped[int] = mapped_column(Integer, default=0, comment='MCP 类型(0:stdio 1:sse)')
|
||||
url: Mapped[str | None] = mapped_column(String(255), default=None, comment='远程 SSE 地址')
|
||||
command: Mapped[str | None] = mapped_column(String(255), default=None, comment='MCP 命令')
|
||||
args: Mapped[str | None] = mapped_column(String(255), default=None, comment='MCP 命令参数')
|
||||
env: Mapped[dict[str, str] | None] = mapped_column(JSON(), default=None, comment='MCP 环境变量')
|
||||
|
||||
|
||||
class AIProviderModel(ModelMixin, UserMixin):
|
||||
"""
|
||||
AI供应商配置表
|
||||
存储AI服务供应商的接口地址和API Key
|
||||
"""
|
||||
__tablename__: str = 'app_ai_provider'
|
||||
__table_args__: dict[str, str] = ({'comment': 'AI供应商配置表'})
|
||||
__loader_options__: list[str] = ["created_by", "updated_by"]
|
||||
|
||||
name: Mapped[str] = mapped_column(String(50), nullable=False, comment='供应商名称')
|
||||
provider_type: Mapped[str] = mapped_column(String(50), nullable=False, comment='供应商类型(openai/deepseek/anthropic/gemini/qwen等)')
|
||||
base_url: Mapped[str] = mapped_column(String(255), nullable=False, comment='接口地址BaseURL')
|
||||
api_key: Mapped[str] = mapped_column(String(255), nullable=False, comment='API Key')
|
||||
is_default: Mapped[int] = mapped_column(Integer, default=0, comment='是否默认供应商(0:否 1:是)')
|
||||
|
||||
|
||||
class EmbeddingConfigModel(ModelMixin, UserMixin):
|
||||
"""
|
||||
知识库向量化配置表
|
||||
支持本地或远程向量化服务
|
||||
"""
|
||||
__tablename__: str = 'app_ai_embedding_config'
|
||||
__table_args__: dict[str, str] = ({'comment': '知识库向量化配置表'})
|
||||
__loader_options__: list[str] = ["created_by", "updated_by"]
|
||||
|
||||
name: Mapped[str] = mapped_column(String(50), nullable=False, comment='配置名称')
|
||||
embedding_type: Mapped[int] = mapped_column(Integer, default=0, comment='向量化类型(0:本地 1:远程)')
|
||||
model_name: Mapped[str] = mapped_column(String(100), nullable=False, comment='Embedding模型名称')
|
||||
base_url: Mapped[str | None] = mapped_column(String(255), default=None, comment='远程接口地址(远程模式必填)')
|
||||
api_key: Mapped[str | None] = mapped_column(String(255), default=None, comment='远程API Key(远程模式必填)')
|
||||
is_default: Mapped[int] = mapped_column(Integer, default=0, comment='是否默认配置(0:否 1:是)')
|
||||
|
||||
|
||||
class KnowledgeBaseModel(ModelMixin, UserMixin):
|
||||
"""
|
||||
知识库表
|
||||
存储知识库信息,关联向量化配置
|
||||
"""
|
||||
__tablename__: str = 'app_ai_knowledge_base'
|
||||
__table_args__: dict[str, str] = ({'comment': '知识库表'})
|
||||
__loader_options__: list[str] = ["created_by", "updated_by", "embedding_config"]
|
||||
|
||||
name: Mapped[str] = mapped_column(String(100), nullable=False, comment='知识库名称')
|
||||
embedding_config_id: Mapped[int | None] = mapped_column(
|
||||
Integer,
|
||||
ForeignKey('app_ai_embedding_config.id', ondelete="SET NULL", onupdate="CASCADE"),
|
||||
default=None,
|
||||
nullable=True,
|
||||
index=True,
|
||||
comment='向量化配置ID'
|
||||
)
|
||||
collection_name: Mapped[str] = mapped_column(String(100), nullable=False, comment='ChromaDB集合名称')
|
||||
document_count: Mapped[int] = mapped_column(Integer, default=0, comment='文档数量')
|
||||
vector_count: Mapped[int] = mapped_column(Integer, default=0, comment='向量数量')
|
||||
kb_status: Mapped[int] = mapped_column(Integer, default=0, comment='知识库状态(0:待处理 1:处理中 2:已完成 3:处理失败)')
|
||||
error_message: Mapped[str | None] = mapped_column(Text, default=None, comment='错误信息')
|
||||
file_paths: Mapped[list[str] | None] = mapped_column(JSON, default=None, comment='文件路径列表')
|
||||
|
||||
# 关联关系
|
||||
embedding_config: Mapped["EmbeddingConfigModel | None"] = relationship(
|
||||
"EmbeddingConfigModel",
|
||||
lazy="selectin",
|
||||
foreign_keys=[embedding_config_id],
|
||||
uselist=False
|
||||
)
|
||||
|
||||
|
||||
class AIModelConfigModel(ModelMixin, UserMixin):
|
||||
"""
|
||||
AI模型配置表
|
||||
存储不同类型AI模型的配置信息
|
||||
模型类型:
|
||||
- enterprise_naming(企业起名), enterprise_renaming(企业改名), enterprise_scoring(企业测名), enterprise_scoring_trial(企业测名试用)
|
||||
- personal_naming(个人起名), personal_renaming(个人改名), personal_scoring(个人测名), personal_scoring_trial(个人测名试用)
|
||||
"""
|
||||
__tablename__: str = 'app_ai_model_config'
|
||||
__table_args__: dict[str, str] = ({'comment': 'AI模型配置表'})
|
||||
__loader_options__: list[str] = ["created_by", "updated_by", "provider", "knowledge_bases"]
|
||||
|
||||
model_type: Mapped[str] = mapped_column(String(50), nullable=False, unique=True, comment='模型类型(enterprise_naming/enterprise_renaming/enterprise_scoring/enterprise_scoring_trial/personal_naming/personal_renaming/personal_scoring/personal_scoring_trial)')
|
||||
model_name: Mapped[str | None] = mapped_column(String(100), default=None, comment='使用的模型名称')
|
||||
provider_id: Mapped[int | None] = mapped_column(
|
||||
Integer,
|
||||
ForeignKey('app_ai_provider.id', ondelete="SET NULL", onupdate="CASCADE"),
|
||||
default=None,
|
||||
nullable=True,
|
||||
index=True,
|
||||
comment='AI供应商ID'
|
||||
)
|
||||
system_prompt: Mapped[str | None] = mapped_column(Text, default=None, comment='系统提示词')
|
||||
temperature: Mapped[float] = mapped_column(default=1.0, comment='模型温度(0-2)')
|
||||
knowledge_base_ids: Mapped[list[int] | None] = mapped_column(JSON, default=None, comment='关联的知识库ID列表')
|
||||
|
||||
# 关联关系
|
||||
provider: Mapped["AIProviderModel | None"] = relationship(
|
||||
"AIProviderModel",
|
||||
lazy="selectin",
|
||||
foreign_keys=[provider_id],
|
||||
uselist=False
|
||||
)
|
||||
|
||||
|
||||
class AIModelTrainingMessageModel(ModelMixin, UserMixin):
|
||||
"""
|
||||
AI模型训练对话记录表
|
||||
存储训练对话的历史记录
|
||||
"""
|
||||
__tablename__: str = 'app_ai_model_training_message'
|
||||
__table_args__: dict[str, str] = ({'comment': 'AI模型训练对话记录表'})
|
||||
__loader_options__: list[str] = ["created_by", "updated_by"]
|
||||
|
||||
model_config_id: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
ForeignKey('app_ai_model_config.id', ondelete="CASCADE", onupdate="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment='模型配置ID'
|
||||
)
|
||||
role: Mapped[str] = mapped_column(String(20), nullable=False, comment='角色(user/assistant)')
|
||||
content: Mapped[str] = mapped_column(Text, nullable=False, comment='消息内容')
|
||||
@@ -0,0 +1,307 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from pydantic import ConfigDict, Field, HttpUrl, BaseModel
|
||||
from fastapi import Query
|
||||
|
||||
from app.core.base_schema import BaseSchema
|
||||
from app.common.enums import McpLLMProvider, EmbeddingType, KnowledgeBaseStatus
|
||||
from app.core.base_schema import BaseSchema, UserBySchema
|
||||
from app.common.enums import McpType
|
||||
from app.core.validator import DateTimeStr
|
||||
|
||||
|
||||
class ChatQuerySchema(BaseModel):
|
||||
"""聊天查询模型"""
|
||||
message: str = Field(..., min_length=1, max_length=4000, description="聊天消息")
|
||||
|
||||
|
||||
class McpCreateSchema(BaseModel):
|
||||
"""创建 MCP 服务器参数"""
|
||||
name: str = Field(..., max_length=64, description='MCP 名称')
|
||||
type: McpType = Field(McpType.stdio, description='MCP 类型')
|
||||
description: str | None = Field(None, max_length=255, description='MCP 描述')
|
||||
url: HttpUrl | None = Field(None, description='远程 SSE 地址')
|
||||
command: str | None = Field(None, max_length=255, description='MCP 命令')
|
||||
args: str | None = Field(None, max_length=255, description='MCP 命令参数,多个参数用英文逗号隔开')
|
||||
env: dict[str, str] | None = Field(None, description='MCP 环境变量')
|
||||
|
||||
|
||||
class McpUpdateSchema(McpCreateSchema):
|
||||
"""更新 MCP 服务器参数"""
|
||||
...
|
||||
|
||||
|
||||
class McpOutSchema(McpCreateSchema, BaseSchema, UserBySchema):
|
||||
"""MCP 服务器详情"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class McpQueryParam:
|
||||
"""MCP 服务器查询参数"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str | None = Query(None, description="MCP 名称"),
|
||||
type: McpType | None = Query(None, description="MCP 类型"),
|
||||
created_time: list[DateTimeStr] | None = Query(None, description="创建时间范围", examples=["2025-01-01 00:00:00", "2025-12-31 23:59:59"]),
|
||||
updated_time: list[DateTimeStr] | None = Query(None, description="更新时间范围", examples=["2025-01-01 00:00:00", "2025-12-31 23:59:59"]),
|
||||
created_id: int | None = Query(None, description="创建人"),
|
||||
updated_id: int | None = Query(None, description="更新人"),
|
||||
) -> None:
|
||||
|
||||
# 模糊查询字段
|
||||
self.name = ("like", name) if name else None
|
||||
|
||||
# 精确查询字段
|
||||
self.type = type
|
||||
self.created_id = created_id
|
||||
self.updated_id = updated_id
|
||||
|
||||
# 时间范围查询
|
||||
if created_time and len(created_time) == 2:
|
||||
self.created_time = ("between", (created_time[0], created_time[1]))
|
||||
if updated_time and len(updated_time) == 2:
|
||||
self.updated_time = ("between", (updated_time[0], updated_time[1]))
|
||||
|
||||
|
||||
class McpChatParam(BaseSchema):
|
||||
"""MCP 聊天参数"""
|
||||
pk: list[int] = Field(..., description='MCP ID 列表')
|
||||
provider: McpLLMProvider = Field(McpLLMProvider.openai, description='LLM 供应商')
|
||||
model: str = Field(..., description='LLM 名称')
|
||||
key: str = Field(..., description='LLM API Key')
|
||||
base_url: str | None = Field(None, description='自定义 LLM API 地址,必须兼容 openai 供应商')
|
||||
prompt: str = Field(..., description='用户提示词')
|
||||
|
||||
|
||||
# ============== AI供应商配置 ==============
|
||||
|
||||
class AIProviderCreateSchema(BaseModel):
|
||||
"""创建 AI供应商参数"""
|
||||
name: str = Field(..., max_length=50, description='供应商名称')
|
||||
provider_type: str = Field(..., max_length=50, description='供应商类型(openai/deepseek/anthropic/gemini/qwen等)')
|
||||
base_url: str = Field(..., max_length=255, description='接口地址BaseURL')
|
||||
api_key: str = Field(..., max_length=255, description='API Key')
|
||||
is_default: int = Field(0, description='是否默认供应商(0:否 1:是)')
|
||||
description: str | None = Field(None, max_length=255, description='备注')
|
||||
|
||||
|
||||
class AIProviderUpdateSchema(AIProviderCreateSchema):
|
||||
"""更新 AI供应商参数"""
|
||||
...
|
||||
|
||||
|
||||
class AIProviderOutSchema(AIProviderCreateSchema, BaseSchema, UserBySchema):
|
||||
"""AI供应商详情"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class AIProviderQueryParam:
|
||||
"""AI供应商查询参数"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str | None = Query(None, description="供应商名称"),
|
||||
provider_type: str | None = Query(None, description="供应商类型"),
|
||||
is_default: int | None = Query(None, description="是否默认"),
|
||||
) -> None:
|
||||
self.name = ("like", name) if name else None
|
||||
self.provider_type = provider_type
|
||||
self.is_default = is_default
|
||||
|
||||
|
||||
# ============== 向量化配置 ==============
|
||||
|
||||
class EmbeddingConfigCreateSchema(BaseModel):
|
||||
"""创建 向量化配置参数"""
|
||||
name: str = Field(..., max_length=50, description='配置名称')
|
||||
embedding_type: int = Field(0, description='向量化类型(0:本地 1:远程)')
|
||||
model_name: str = Field(..., max_length=100, description='Embedding模型名称')
|
||||
base_url: str | None = Field(None, max_length=255, description='远程接口地址(远程模式必填)')
|
||||
api_key: str | None = Field(None, max_length=255, description='远程API Key(远程模式必填)')
|
||||
is_default: int = Field(0, description='是否默认配置(0:否 1:是)')
|
||||
description: str | None = Field(None, max_length=255, description='备注')
|
||||
|
||||
|
||||
class EmbeddingConfigUpdateSchema(EmbeddingConfigCreateSchema):
|
||||
"""更新 向量化配置参数"""
|
||||
...
|
||||
|
||||
|
||||
class EmbeddingConfigOutSchema(EmbeddingConfigCreateSchema, BaseSchema, UserBySchema):
|
||||
"""向量化配置详情"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class EmbeddingConfigQueryParam:
|
||||
"""向量化配置查询参数"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str | None = Query(None, description="配置名称"),
|
||||
embedding_type: int | None = Query(None, description="向量化类型"),
|
||||
is_default: int | None = Query(None, description="是否默认"),
|
||||
) -> None:
|
||||
self.name = ("like", name) if name else None
|
||||
self.embedding_type = embedding_type
|
||||
self.is_default = is_default
|
||||
|
||||
|
||||
# ============== 知识库 ==============
|
||||
|
||||
class KnowledgeBaseCreateSchema(BaseModel):
|
||||
"""创建 知识库参数"""
|
||||
name: str = Field(..., max_length=100, description='知识库名称')
|
||||
embedding_config_id: int | None = Field(None, description='向量化配置ID')
|
||||
description: str | None = Field(None, max_length=255, description='备注')
|
||||
|
||||
|
||||
class KnowledgeBaseUpdateSchema(BaseModel):
|
||||
"""更新 知识库参数"""
|
||||
name: str | None = Field(None, max_length=100, description='知识库名称')
|
||||
embedding_config_id: int | None = Field(None, description='向量化配置ID')
|
||||
description: str | None = Field(None, max_length=255, description='备注')
|
||||
|
||||
|
||||
class EmbeddingConfigRefSchema(BaseModel):
|
||||
"""向量化配置引用"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
name: str
|
||||
embedding_type: int
|
||||
model_name: str
|
||||
|
||||
|
||||
class KnowledgeBaseOutSchema(BaseSchema, UserBySchema):
|
||||
"""知识库详情"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
name: str
|
||||
embedding_config_id: int | None = None
|
||||
collection_name: str
|
||||
document_count: int = 0
|
||||
vector_count: int = 0
|
||||
kb_status: int = 0
|
||||
error_message: str | None = None
|
||||
description: str | None = None
|
||||
file_paths: list[str] | None = None
|
||||
embedding_config: EmbeddingConfigRefSchema | None = None
|
||||
|
||||
|
||||
class KnowledgeBaseQueryParam:
|
||||
"""知识库查询参数"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str | None = Query(None, description="知识库名称"),
|
||||
embedding_config_id: int | None = Query(None, description="向量化配置ID"),
|
||||
kb_status: int | None = Query(None, description="知识库状态"),
|
||||
) -> None:
|
||||
self.name = ("like", name) if name else None
|
||||
self.embedding_config_id = embedding_config_id
|
||||
self.kb_status = kb_status
|
||||
|
||||
|
||||
# ============== AI模型配置 ==============
|
||||
|
||||
# AI模型类型常量
|
||||
AI_MODEL_TYPES = {
|
||||
'enterprise_naming': '企业起名',
|
||||
'enterprise_renaming': '企业改名',
|
||||
'enterprise_scoring': '企业测名',
|
||||
'enterprise_scoring_trial': '企业测名试用',
|
||||
'personal_naming': '个人起名',
|
||||
'personal_renaming': '个人改名',
|
||||
'personal_scoring': '个人测名',
|
||||
'personal_scoring_trial': '个人测名试用',
|
||||
'yuanfen_hepan': '缘分合盘',
|
||||
'bazi_zeji': '八字择吉',
|
||||
'caiyun_jiexi': '财运解析',
|
||||
'caiyun_jiexi_qiye': '企业财运解析'
|
||||
}
|
||||
|
||||
|
||||
class AIProviderRefSchema(BaseModel):
|
||||
"""供应商引用"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
name: str
|
||||
provider_type: str
|
||||
base_url: str
|
||||
|
||||
|
||||
class AIModelConfigCreateSchema(BaseModel):
|
||||
"""创建 AI模型配置参数"""
|
||||
model_type: str = Field(..., max_length=50, description='模型类型(naming/renaming/scoring/report)')
|
||||
model_name: str | None = Field(None, max_length=100, description='使用的模型名称')
|
||||
provider_id: int | None = Field(None, description='AI供应商ID')
|
||||
system_prompt: str | None = Field(None, description='系统提示词')
|
||||
temperature: float = Field(1.0, ge=0, le=2, description='模型温度(0-2)')
|
||||
knowledge_base_ids: list[int] | None = Field(None, description='关联的知识库ID列表')
|
||||
|
||||
|
||||
class AIModelConfigUpdateSchema(BaseModel):
|
||||
"""更新 AI模型配置参数"""
|
||||
model_name: str | None = Field(None, max_length=100, description='使用的模型名称')
|
||||
provider_id: int | None = Field(None, description='AI供应商ID')
|
||||
system_prompt: str | None = Field(None, description='系统提示词')
|
||||
temperature: float | None = Field(None, ge=0, le=2, description='模型温度(0-2)')
|
||||
knowledge_base_ids: list[int] | None = Field(None, description='关联的知识库ID列表')
|
||||
|
||||
|
||||
class AIModelConfigOutSchema(BaseSchema, UserBySchema):
|
||||
"""模型配置详情"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
model_type: str
|
||||
model_name: str | None = None
|
||||
provider_id: int | None = None
|
||||
system_prompt: str | None = None
|
||||
temperature: float = 1.0
|
||||
knowledge_base_ids: list[int] | None = None
|
||||
provider: AIProviderRefSchema | None = None
|
||||
|
||||
|
||||
class AIModelConfigQueryParam:
|
||||
"""AI模型配置查询参数"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_type: str | None = Query(None, description="模型类型"),
|
||||
) -> None:
|
||||
self.model_type = model_type
|
||||
|
||||
|
||||
# ============== AI模型训练对话 ==============
|
||||
|
||||
class AIModelTrainingMessageCreateSchema(BaseModel):
|
||||
"""创建训练对话消息参数"""
|
||||
model_config_id: int = Field(..., description='模型配置ID')
|
||||
role: str = Field(..., max_length=20, description='角色(user/assistant)')
|
||||
content: str = Field(..., description='消息内容')
|
||||
|
||||
|
||||
class AIModelTrainingMessageOutSchema(BaseSchema, UserBySchema):
|
||||
"""训练对话消息详情"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
model_config_id: int
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
||||
class AIModelTrainingChatSchema(BaseModel):
|
||||
"""训练对话请求参数"""
|
||||
model_type: str = Field(..., description='模型类型')
|
||||
message: str = Field(..., min_length=1, description='用户消息')
|
||||
# 如果配置有变动,先保存配置
|
||||
config_changed: bool = Field(False, description='配置是否有变动')
|
||||
config_data: AIModelConfigUpdateSchema | None = Field(None, description='变动的配置数据')
|
||||
|
||||
|
||||
class AIModelTestSchema(BaseModel):
|
||||
"""起名测试请求参数(用于小程序/外部调用)"""
|
||||
model_type: str = Field(..., description='模型类型(enterprise_naming/personal_naming等)')
|
||||
text: str = Field(..., min_length=1, max_length=4000, description='用户输入的文本')
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,2 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
import httpx
|
||||
|
||||
from app.config.setting import settings
|
||||
from app.core.logger import log
|
||||
|
||||
|
||||
class AIClient:
|
||||
"""
|
||||
AI客户端类,用于与OpenAI API交互。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.model = settings.OPENAI_MODEL
|
||||
# 创建一个不带冲突参数的httpx客户端
|
||||
self.http_client = httpx.AsyncClient(
|
||||
timeout=30.0,
|
||||
follow_redirects=True
|
||||
)
|
||||
|
||||
# 使用自定义的http客户端
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=settings.OPENAI_API_KEY,
|
||||
base_url=settings.OPENAI_BASE_URL,
|
||||
http_client=self.http_client
|
||||
)
|
||||
|
||||
def _friendly_error_message(self, e: Exception) -> str:
|
||||
"""将 OpenAI 或网络异常转换为友好的中文提示。"""
|
||||
# 尝试获取状态码与错误体
|
||||
status_code = getattr(e, "status_code", None)
|
||||
body = getattr(e, "body", None)
|
||||
message = None
|
||||
error_type = None
|
||||
error_code = None
|
||||
try:
|
||||
if isinstance(body, dict) and "error" in body:
|
||||
err = body.get("error") or {}
|
||||
error_type = err.get("type")
|
||||
error_code = err.get("code")
|
||||
message = err.get("message")
|
||||
except Exception:
|
||||
# 忽略解析失败
|
||||
pass
|
||||
|
||||
text = str(e)
|
||||
msg = message or text
|
||||
|
||||
# 特定错误映射
|
||||
# 欠费/账户状态异常
|
||||
if (error_code == "Arrearage") or (error_type == "Arrearage") or ("in good standing" in (msg or "")):
|
||||
return "账户欠费或结算异常,访问被拒绝。请检查账号状态或更换有效的 API Key。"
|
||||
# 鉴权失败
|
||||
if status_code == 401 or "invalid api key" in msg.lower():
|
||||
return "鉴权失败,API Key 无效或已过期。请检查系统配置中的 API Key。"
|
||||
# 权限不足或被拒绝
|
||||
if status_code == 403 or error_type in {"PermissionDenied", "permission_denied"}:
|
||||
return "访问被拒绝,权限不足或账号受限。请检查账户权限设置。"
|
||||
# 配额不足或限流
|
||||
if status_code == 429 or error_type in {"insufficient_quota", "rate_limit_exceeded"}:
|
||||
return "请求过于频繁或配额已用尽。请稍后重试或提升账户配额。"
|
||||
# 客户端错误
|
||||
if status_code == 400:
|
||||
return f"请求参数错误或服务拒绝:{message or '请检查输入内容。'}"
|
||||
# 服务端错误
|
||||
if status_code in {500, 502, 503, 504}:
|
||||
return "服务暂时不可用,请稍后重试。"
|
||||
|
||||
# 默认兜底
|
||||
return f"处理您的请求时出现错误:{msg}"
|
||||
|
||||
async def process(self, query: str) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
处理查询并返回流式响应
|
||||
|
||||
参数:
|
||||
- query (str): 用户查询。
|
||||
|
||||
返回:
|
||||
- AsyncGenerator[str, None]: 流式响应内容。
|
||||
"""
|
||||
system_prompt = """你是一个有用的AI助手,可以帮助用户回答问题和提供帮助。请用中文回答用户的问题。"""
|
||||
|
||||
try:
|
||||
# 使用 await 调用异步客户端
|
||||
response = await self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": query}
|
||||
],
|
||||
stream=True
|
||||
)
|
||||
|
||||
# 流式返回响应
|
||||
async for chunk in response:
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
yield chunk.choices[0].delta.content
|
||||
|
||||
except Exception as e:
|
||||
# 记录详细错误,返回友好提示
|
||||
log.error(f"AI处理查询失败: {str(e)}")
|
||||
yield self._friendly_error_message(e)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
关闭客户端连接
|
||||
"""
|
||||
if hasattr(self, 'client'):
|
||||
await self.client.close()
|
||||
if hasattr(self, 'http_client'):
|
||||
await self.http_client.aclose()
|
||||
@@ -0,0 +1,252 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
文档处理工具类
|
||||
支持 txt、pdf、doc、docx、md 格式解析
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, BinaryIO
|
||||
|
||||
from langchain.schema import Document
|
||||
from fastapi import UploadFile
|
||||
|
||||
from app.core.logger import log
|
||||
|
||||
|
||||
class DocumentProcessor:
|
||||
"""文档处理器"""
|
||||
|
||||
# 支持的文件扩展名
|
||||
SUPPORTED_EXTENSIONS = {'.txt', '.pdf', '.doc', '.docx', '.md'}
|
||||
|
||||
@classmethod
|
||||
def is_supported(cls, filename: str) -> bool:
|
||||
"""检查文件是否支持"""
|
||||
ext = Path(filename).suffix.lower()
|
||||
return ext in cls.SUPPORTED_EXTENSIONS
|
||||
|
||||
@classmethod
|
||||
async def process_upload_file(cls, file: UploadFile) -> List[Document]:
|
||||
"""
|
||||
处理上传的文件
|
||||
|
||||
参数:
|
||||
- file: FastAPI UploadFile 对象
|
||||
|
||||
返回:
|
||||
- 文档列表
|
||||
"""
|
||||
if not file.filename:
|
||||
log.warning("文件名为空,跳过处理")
|
||||
return []
|
||||
|
||||
ext = Path(file.filename).suffix.lower()
|
||||
|
||||
if not cls.is_supported(file.filename):
|
||||
log.warning(f"不支持的文件类型: {ext}")
|
||||
return []
|
||||
|
||||
# 读取文件内容
|
||||
content = await file.read()
|
||||
await file.seek(0) # 重置文件指针
|
||||
|
||||
# 根据文件类型处理
|
||||
try:
|
||||
if ext == '.txt':
|
||||
return cls._process_txt(content, file.filename)
|
||||
elif ext == '.md':
|
||||
return cls._process_markdown(content, file.filename)
|
||||
elif ext == '.pdf':
|
||||
return await cls._process_pdf(content, file.filename)
|
||||
elif ext in {'.doc', '.docx'}:
|
||||
return await cls._process_word(content, file.filename, ext)
|
||||
else:
|
||||
log.warning(f"未知的文件类型: {ext}")
|
||||
return []
|
||||
except Exception as e:
|
||||
log.error(f"处理文件失败: {file.filename}, 错误: {e}")
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def _process_txt(cls, content: bytes, filename: str) -> List[Document]:
|
||||
"""处理 TXT 文件"""
|
||||
try:
|
||||
# 尝试不同编码
|
||||
text = None
|
||||
for encoding in ['utf-8', 'gbk', 'gb2312', 'latin-1']:
|
||||
try:
|
||||
text = content.decode(encoding)
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
|
||||
if text is None:
|
||||
log.error(f"无法解码文件: {filename}")
|
||||
return []
|
||||
|
||||
return [Document(
|
||||
page_content=text,
|
||||
metadata={"source": filename, "type": "txt"}
|
||||
)]
|
||||
except Exception as e:
|
||||
log.error(f"处理 TXT 文件失败: {filename}, 错误: {e}")
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def _process_markdown(cls, content: bytes, filename: str) -> List[Document]:
|
||||
"""处理 Markdown 文件"""
|
||||
try:
|
||||
text = content.decode('utf-8')
|
||||
return [Document(
|
||||
page_content=text,
|
||||
metadata={"source": filename, "type": "markdown"}
|
||||
)]
|
||||
except Exception as e:
|
||||
log.error(f"处理 Markdown 文件失败: {filename}, 错误: {e}")
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
async def _process_pdf(cls, content: bytes, filename: str) -> List[Document]:
|
||||
"""处理 PDF 文件"""
|
||||
try:
|
||||
# 使用 pypdf 或 pdfplumber 处理 PDF
|
||||
import pypdf
|
||||
from io import BytesIO
|
||||
|
||||
pdf_file = BytesIO(content)
|
||||
reader = pypdf.PdfReader(pdf_file)
|
||||
|
||||
documents = []
|
||||
for page_num, page in enumerate(reader.pages):
|
||||
text = page.extract_text()
|
||||
if text and text.strip():
|
||||
documents.append(Document(
|
||||
page_content=text,
|
||||
metadata={
|
||||
"source": filename,
|
||||
"type": "pdf",
|
||||
"page": page_num + 1
|
||||
}
|
||||
))
|
||||
|
||||
log.info(f"PDF 文件处理完成: {filename}, 共 {len(documents)} 页")
|
||||
return documents
|
||||
|
||||
except ImportError:
|
||||
log.error("未安装 pypdf 库,请运行: pip install pypdf")
|
||||
return []
|
||||
except Exception as e:
|
||||
log.error(f"处理 PDF 文件失败: {filename}, 错误: {e}")
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
async def _process_word(cls, content: bytes, filename: str, ext: str) -> List[Document]:
|
||||
"""处理 Word 文件 (doc/docx)"""
|
||||
try:
|
||||
if ext == '.docx':
|
||||
return cls._process_docx(content, filename)
|
||||
else:
|
||||
# .doc 格式需要特殊处理
|
||||
return cls._process_doc(content, filename)
|
||||
except Exception as e:
|
||||
log.error(f"处理 Word 文件失败: {filename}, 错误: {e}")
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def _process_docx(cls, content: bytes, filename: str) -> List[Document]:
|
||||
"""处理 DOCX 文件"""
|
||||
try:
|
||||
from docx import Document as DocxDocument
|
||||
from io import BytesIO
|
||||
|
||||
docx_file = BytesIO(content)
|
||||
doc = DocxDocument(docx_file)
|
||||
|
||||
# 提取所有段落文本
|
||||
paragraphs = []
|
||||
for para in doc.paragraphs:
|
||||
if para.text.strip():
|
||||
paragraphs.append(para.text)
|
||||
|
||||
text = '\n'.join(paragraphs)
|
||||
|
||||
if text.strip():
|
||||
return [Document(
|
||||
page_content=text,
|
||||
metadata={"source": filename, "type": "docx"}
|
||||
)]
|
||||
return []
|
||||
|
||||
except ImportError:
|
||||
log.error("未安装 python-docx 库,请运行: pip install python-docx")
|
||||
return []
|
||||
except Exception as e:
|
||||
log.error(f"处理 DOCX 文件失败: {filename}, 错误: {e}")
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def _process_doc(cls, content: bytes, filename: str) -> List[Document]:
|
||||
"""
|
||||
处理 DOC 文件 (旧版 Word 格式)
|
||||
注意: .doc 格式处理需要额外依赖,这里做简单提示
|
||||
"""
|
||||
try:
|
||||
# 尝试使用 antiword 或 textract
|
||||
# 如果没有安装,建议用户转换为 docx 格式
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
# 创建临时文件
|
||||
with tempfile.NamedTemporaryFile(suffix='.doc', delete=False) as tmp:
|
||||
tmp.write(content)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
# 尝试使用 antiword
|
||||
result = subprocess.run(
|
||||
['antiword', tmp_path],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30
|
||||
)
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
return [Document(
|
||||
page_content=result.stdout,
|
||||
metadata={"source": filename, "type": "doc"}
|
||||
)]
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
log.warning(f"无法处理 .doc 文件: {filename},建议转换为 .docx 格式")
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"处理 DOC 文件失败: {filename}, 错误: {e}")
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
async def process_files(cls, files: List[UploadFile]) -> List[Document]:
|
||||
"""
|
||||
批量处理上传的文件
|
||||
|
||||
参数:
|
||||
- files: UploadFile 列表
|
||||
|
||||
返回:
|
||||
- 所有文档列表
|
||||
"""
|
||||
all_documents = []
|
||||
|
||||
for file in files:
|
||||
if file.filename and cls.is_supported(file.filename):
|
||||
docs = await cls.process_upload_file(file)
|
||||
all_documents.extend(docs)
|
||||
log.info(f"文件处理完成: {file.filename}, 提取 {len(docs)} 个文档")
|
||||
else:
|
||||
log.warning(f"跳过不支持的文件: {file.filename}")
|
||||
|
||||
log.info(f"批量处理完成,共 {len(all_documents)} 个文档")
|
||||
return all_documents
|
||||
@@ -0,0 +1,239 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
向量化工具类
|
||||
支持本地和远程 embedding 模型
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import chromadb
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from langchain_chroma import Chroma
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain.schema import Document
|
||||
|
||||
from app.core.logger import log
|
||||
from app.config.path_conf import BASE_DIR
|
||||
|
||||
|
||||
# ChromaDB 持久化目录
|
||||
CHROMA_PERSIST_DIR = BASE_DIR / "data" / "chroma_db"
|
||||
|
||||
# 全局 ChromaDB 客户端实例(单例模式)
|
||||
_chroma_client = None
|
||||
|
||||
|
||||
def get_chroma_client():
|
||||
"""获取全局 ChromaDB 客户端实例(单例模式)"""
|
||||
global _chroma_client
|
||||
if _chroma_client is None:
|
||||
# 确保持久化目录存在
|
||||
CHROMA_PERSIST_DIR.mkdir(parents=True, exist_ok=True)
|
||||
_chroma_client = chromadb.PersistentClient(
|
||||
path=str(CHROMA_PERSIST_DIR)
|
||||
)
|
||||
return _chroma_client
|
||||
|
||||
|
||||
class EmbeddingUtil:
|
||||
"""向量化工具类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_type: int = 0,
|
||||
model_name: str = "text-embedding-ada-002",
|
||||
base_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
初始化向量化工具
|
||||
|
||||
参数:
|
||||
- embedding_type: 0=本地, 1=远程
|
||||
- model_name: Embedding模型名称
|
||||
- base_url: 远程接口地址(远程模式必填)
|
||||
- api_key: 远程API Key(远程模式必填)
|
||||
"""
|
||||
self.embedding_type = embedding_type
|
||||
self.model_name = model_name
|
||||
self.base_url = base_url
|
||||
self.api_key = api_key
|
||||
|
||||
# 初始化 embedding 模型
|
||||
self._embeddings = None
|
||||
|
||||
@property
|
||||
def embeddings(self):
|
||||
"""延迟加载 embedding 模型"""
|
||||
if self._embeddings is None:
|
||||
self._embeddings = self._create_embeddings()
|
||||
return self._embeddings
|
||||
|
||||
def _create_embeddings(self):
|
||||
"""创建 embedding 模型实例"""
|
||||
if self.embedding_type == 0:
|
||||
# 本地模式 - 使用 sentence-transformers
|
||||
# sentence-transformers==3.3.1
|
||||
# 本地Embedding模型(可选)手动安装 pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
try:
|
||||
from langchain_community.embeddings import HuggingFaceEmbeddings
|
||||
log.info(f"使用本地 Embedding 模型: {self.model_name}")
|
||||
return HuggingFaceEmbeddings(
|
||||
model_name=self.model_name,
|
||||
model_kwargs={'device': 'cpu'},
|
||||
encode_kwargs={'normalize_embeddings': True}
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"加载本地 Embedding 模型失败: {e}")
|
||||
raise
|
||||
else:
|
||||
# 远程模式 - 使用 OpenAI 兼容接口
|
||||
if not self.base_url or not self.api_key:
|
||||
raise ValueError("远程模式必须提供 base_url 和 api_key")
|
||||
|
||||
# 自动拼接 /v1 路径(如果未包含)
|
||||
api_base = self.base_url.rstrip('/')
|
||||
if not api_base.endswith('/v1'):
|
||||
api_base = f"{api_base}/v1"
|
||||
|
||||
log.info(f"使用远程 Embedding 模型: {self.model_name}, URL: {api_base}")
|
||||
return OpenAIEmbeddings(
|
||||
model=self.model_name,
|
||||
base_url=api_base,
|
||||
api_key=self.api_key,
|
||||
)
|
||||
|
||||
def get_vector_store(self, collection_name: str) -> Chroma:
|
||||
"""
|
||||
获取或创建向量存储
|
||||
|
||||
参数:
|
||||
- collection_name: 集合名称
|
||||
|
||||
返回:
|
||||
- Chroma 向量存储实例
|
||||
"""
|
||||
# 确保目录存在
|
||||
CHROMA_PERSIST_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return Chroma(
|
||||
collection_name=collection_name,
|
||||
embedding_function=self.embeddings,
|
||||
persist_directory=str(CHROMA_PERSIST_DIR),
|
||||
client=get_chroma_client(),
|
||||
)
|
||||
|
||||
def split_documents(
|
||||
self,
|
||||
documents: List[Document],
|
||||
chunk_size: int = 1000,
|
||||
chunk_overlap: int = 200
|
||||
) -> List[Document]:
|
||||
"""
|
||||
分割文档
|
||||
|
||||
参数:
|
||||
- documents: 文档列表
|
||||
- chunk_size: 分块大小
|
||||
- chunk_overlap: 分块重叠大小
|
||||
|
||||
返回:
|
||||
- 分割后的文档列表
|
||||
"""
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
length_function=len,
|
||||
separators=["\n\n", "\n", "。", "!", "?", ".", "!", "?", " ", ""]
|
||||
)
|
||||
return text_splitter.split_documents(documents)
|
||||
|
||||
def add_documents(
|
||||
self,
|
||||
collection_name: str,
|
||||
documents: List[Document]
|
||||
) -> int:
|
||||
"""
|
||||
添加文档到向量存储
|
||||
|
||||
参数:
|
||||
- collection_name: 集合名称
|
||||
- documents: 文档列表
|
||||
|
||||
返回:
|
||||
- 添加的向量数量
|
||||
"""
|
||||
if not documents:
|
||||
return 0
|
||||
|
||||
# 分割文档
|
||||
split_docs = self.split_documents(documents)
|
||||
log.info(f"文档分割完成,共 {len(split_docs)} 个片段")
|
||||
|
||||
# 获取向量存储
|
||||
vector_store = self.get_vector_store(collection_name)
|
||||
|
||||
# 添加文档
|
||||
vector_store.add_documents(split_docs)
|
||||
log.info(f"向量存储完成,集合: {collection_name}, 向量数: {len(split_docs)}")
|
||||
|
||||
return len(split_docs)
|
||||
|
||||
def delete_collection(self, collection_name: str) -> bool:
|
||||
"""
|
||||
删除集合
|
||||
|
||||
参数:
|
||||
- collection_name: 集合名称
|
||||
|
||||
返回:
|
||||
- 是否删除成功
|
||||
"""
|
||||
try:
|
||||
client = get_chroma_client()
|
||||
client.delete_collection(collection_name)
|
||||
log.info(f"删除集合成功: {collection_name}")
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"删除集合失败: {collection_name}, 错误: {e}")
|
||||
return False
|
||||
|
||||
def similarity_search(
|
||||
self,
|
||||
collection_name: str,
|
||||
query: str,
|
||||
k: int = 4
|
||||
) -> List[Document]:
|
||||
"""
|
||||
相似度搜索
|
||||
|
||||
参数:
|
||||
- collection_name: 集合名称
|
||||
- query: 查询文本
|
||||
- k: 返回结果数量
|
||||
|
||||
返回:
|
||||
- 相似文档列表
|
||||
"""
|
||||
vector_store = self.get_vector_store(collection_name)
|
||||
return vector_store.similarity_search(query, k=k)
|
||||
|
||||
def get_collection_count(self, collection_name: str) -> int:
|
||||
"""
|
||||
获取集合中的向量数量
|
||||
|
||||
参数:
|
||||
- collection_name: 集合名称
|
||||
|
||||
返回:
|
||||
- 向量数量
|
||||
"""
|
||||
try:
|
||||
client = get_chroma_client()
|
||||
collection = client.get_collection(collection_name)
|
||||
return collection.count()
|
||||
except Exception as e:
|
||||
log.warning(f"获取集合数量失败: {collection_name}, 错误: {e}")
|
||||
return 0
|
||||
@@ -0,0 +1,2 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
@@ -0,0 +1,330 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Path
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from app.common.response import StreamResponse, SuccessResponse
|
||||
from app.common.request import PaginationService
|
||||
from app.core.router_class import OperationLogRoute
|
||||
from app.utils.common_util import bytes2file_response
|
||||
from app.core.base_params import PaginationQueryParam
|
||||
from app.core.dependencies import AuthPermission
|
||||
from app.core.logger import log
|
||||
|
||||
from app.api.v1.module_system.auth.schema import AuthSchema
|
||||
from .tools.ap_scheduler import SchedulerUtil
|
||||
from .service import JobService, JobLogService
|
||||
from .schema import (
|
||||
JobCreateSchema,
|
||||
JobUpdateSchema,
|
||||
JobQueryParam,
|
||||
JobLogQueryParam
|
||||
)
|
||||
|
||||
|
||||
JobRouter = APIRouter(route_class=OperationLogRoute, prefix="/job", tags=["定时任务"])
|
||||
|
||||
@JobRouter.get("/detail/{id}", summary="获取定时任务详情", description="获取定时任务详情")
|
||||
async def get_obj_detail_controller(
|
||||
id: int = Path(..., description="定时任务ID"),
|
||||
auth: AuthSchema = Depends(AuthPermission(["module_application:job:query"]))
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
获取定时任务详情
|
||||
|
||||
参数:
|
||||
- id (int): 定时任务ID
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含定时任务详情的JSON响应
|
||||
"""
|
||||
result_dict = await JobService.get_job_detail_service(id=id, auth=auth)
|
||||
log.info(f"获取定时任务详情成功 {id}")
|
||||
return SuccessResponse(data=result_dict, msg="获取定时任务详情成功")
|
||||
|
||||
@JobRouter.get("/list", summary="查询定时任务", description="查询定时任务")
|
||||
async def get_obj_list_controller(
|
||||
page: PaginationQueryParam = Depends(),
|
||||
search: JobQueryParam = Depends(),
|
||||
auth: AuthSchema = Depends(AuthPermission(["module_application:job:query"]))
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
查询定时任务
|
||||
|
||||
参数:
|
||||
- page (PaginationQueryParam): 分页查询参数模型
|
||||
- search (JobQueryParam): 查询参数模型
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含分页后的定时任务列表的JSON响应
|
||||
"""
|
||||
result_dict_list = await JobService.get_job_list_service(auth=auth, search=search, order_by=page.order_by)
|
||||
result_dict = await PaginationService.paginate(data_list= result_dict_list, page_no= page.page_no, page_size = page.page_size)
|
||||
log.info(f"查询定时任务列表成功")
|
||||
return SuccessResponse(data=result_dict, msg="查询定时任务列表成功")
|
||||
|
||||
@JobRouter.post("/create", summary="创建定时任务", description="创建定时任务")
|
||||
async def create_obj_controller(
|
||||
data: JobCreateSchema,
|
||||
auth: AuthSchema = Depends(AuthPermission(["module_application:job:create"]))
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
创建定时任务
|
||||
|
||||
参数:
|
||||
- data (JobCreateSchema): 创建参数模型
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含创建定时任务结果的JSON响应
|
||||
"""
|
||||
result_dict = await JobService.create_job_service(auth=auth, data=data)
|
||||
log.info(f"创建定时任务成功: {result_dict}")
|
||||
return SuccessResponse(data=result_dict, msg="创建定时任务成功")
|
||||
|
||||
@JobRouter.put("/update/{id}", summary="修改定时任务", description="修改定时任务")
|
||||
async def update_obj_controller(
|
||||
data: JobUpdateSchema,
|
||||
id: int = Path(..., description="定时任务ID"),
|
||||
auth: AuthSchema = Depends(AuthPermission(["module_application:job:update"]))
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
修改定时任务
|
||||
|
||||
参数:
|
||||
- data (JobUpdateSchema): 更新参数模型
|
||||
- id (int): 定时任务ID
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含修改定时任务结果的JSON响应
|
||||
"""
|
||||
result_dict = await JobService.update_job_service(auth=auth, id=id, data=data)
|
||||
log.info(f"修改定时任务成功: {result_dict}")
|
||||
return SuccessResponse(data=result_dict, msg="修改定时任务成功")
|
||||
|
||||
@JobRouter.delete("/delete", summary="删除定时任务", description="删除定时任务")
|
||||
async def delete_obj_controller(
|
||||
ids: list[int] = Body(..., description="ID列表"),
|
||||
auth: AuthSchema = Depends(AuthPermission(["module_application:job:delete"]))
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
删除定时任务
|
||||
|
||||
参数:
|
||||
- ids (list[int]): ID列表
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含删除定时任务结果的JSON响应
|
||||
"""
|
||||
await JobService.delete_job_service(auth=auth, ids=ids)
|
||||
log.info(f"删除定时任务成功: {ids}")
|
||||
return SuccessResponse(msg="删除定时任务成功")
|
||||
|
||||
@JobRouter.post('/export', summary="导出定时任务", description="导出定时任务")
|
||||
async def export_obj_list_controller(
|
||||
search: JobQueryParam = Depends(),
|
||||
auth: AuthSchema = Depends(AuthPermission(["module_application:job:export"]))
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
导出定时任务
|
||||
|
||||
参数:
|
||||
- search (JobQueryParam): 查询参数模型
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
|
||||
返回:
|
||||
- StreamingResponse: 包含导出定时任务结果的流式响应
|
||||
"""
|
||||
result_dict_list = await JobService.get_job_list_service(search=search, auth=auth)
|
||||
export_result = await JobService.export_job_service(data_list=result_dict_list)
|
||||
log.info('导出定时任务成功')
|
||||
|
||||
return StreamResponse(
|
||||
data=bytes2file_response(export_result),
|
||||
media_type='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
||||
headers = {
|
||||
'Content-Disposition': 'attachment; filename=job.xlsx'
|
||||
}
|
||||
)
|
||||
|
||||
@JobRouter.delete("/clear", summary="清空定时任务日志", description="清空定时任务日志")
|
||||
async def clear_obj_log_controller(
|
||||
auth: AuthSchema = Depends(AuthPermission(["module_application:job:delete"]))
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
清空定时任务日志
|
||||
|
||||
参数:
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含清空定时任务结果的JSON响应
|
||||
"""
|
||||
await JobService.clear_job_service(auth=auth)
|
||||
log.info(f"清空定时任务成功")
|
||||
return SuccessResponse(msg="清空定时任务成功")
|
||||
|
||||
@JobRouter.put("/option", summary="暂停/恢复/重启定时任务", description="暂停/恢复/重启定时任务")
|
||||
async def option_obj_controller(
|
||||
id: int = Body(..., description="定时任务ID"),
|
||||
option: int = Body(..., description="操作类型 1: 暂停 2: 恢复 3: 重启"),
|
||||
auth: AuthSchema = Depends(AuthPermission(["module_application:job:update"]))
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
暂停/恢复/重启定时任务
|
||||
|
||||
参数:
|
||||
- id (int): 定时任务ID
|
||||
- option (int): 操作类型 1: 暂停 2: 恢复 3: 重启
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含操作定时任务结果的JSON响应
|
||||
"""
|
||||
await JobService.option_job_service(auth=auth, id=id, option=option)
|
||||
log.info(f"操作定时任务成功: {id}")
|
||||
return SuccessResponse(msg="操作定时任务成功")
|
||||
|
||||
@JobRouter.get("/log", summary="获取定时任务日志", description="获取定时任务日志", dependencies=[Depends(AuthPermission(["module_application:job:query"]))])
|
||||
async def get_job_log_controller():
|
||||
"""
|
||||
获取定时任务日志
|
||||
|
||||
返回:
|
||||
- JSONResponse: 获取定时任务日志的JSON响应
|
||||
"""
|
||||
data = [
|
||||
{
|
||||
"id": i.id,
|
||||
"name": i.name,
|
||||
"trigger": i.trigger.__class__.__name__,
|
||||
"executor": i.executor,
|
||||
"func": i.func,
|
||||
"func_ref": i.func_ref,
|
||||
"args": i.args,
|
||||
"kwargs": i.kwargs,
|
||||
"misfire_grace_time": i.misfire_grace_time,
|
||||
"coalesce": i.coalesce,
|
||||
"max_instances": i.max_instances,
|
||||
"next_run_time": i.next_run_time,
|
||||
"state": SchedulerUtil.get_single_job_status(job_id=i.id)
|
||||
}
|
||||
for i in SchedulerUtil.get_all_jobs()
|
||||
]
|
||||
|
||||
return SuccessResponse(msg="获取定时任务日志成功", data=data)
|
||||
|
||||
|
||||
# 定时任务日志管理接口
|
||||
@JobRouter.get("/log/detail/{id}", summary="获取定时任务日志详情", description="获取定时任务日志详情")
|
||||
async def get_job_log_detail_controller(
|
||||
id: int = Path(..., description="定时任务日志ID"),
|
||||
auth: AuthSchema = Depends(AuthPermission(["module_application:job:query"]))
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
获取定时任务日志详情
|
||||
|
||||
参数:
|
||||
- id (int): 定时任务日志ID
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
|
||||
返回:
|
||||
- JSONResponse: 获取定时任务日志详情的JSON响应
|
||||
"""
|
||||
result_dict = await JobLogService.get_job_log_detail_service(id=id, auth=auth)
|
||||
log.info(f"获取定时任务日志详情成功 {id}")
|
||||
return SuccessResponse(data=result_dict, msg="获取定时任务日志详情成功")
|
||||
|
||||
|
||||
@JobRouter.get("/log/list", summary="查询定时任务日志", description="查询定时任务日志")
|
||||
async def get_job_log_list_controller(
|
||||
page: PaginationQueryParam = Depends(),
|
||||
search: JobLogQueryParam = Depends(),
|
||||
auth: AuthSchema = Depends(AuthPermission(["module_application:job:query"]))
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
查询定时任务日志
|
||||
|
||||
参数:
|
||||
- page (PaginationQueryParam): 分页查询参数模型
|
||||
- search (JobLogQueryParam): 查询参数模型
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
|
||||
返回:
|
||||
- JSONResponse: 查询定时任务日志列表的JSON响应
|
||||
"""
|
||||
order_by = [{"created_time": "desc"}]
|
||||
result_dict_list = await JobLogService.get_job_log_list_service(auth=auth, search=search, order_by=order_by)
|
||||
result_dict = await PaginationService.paginate(data_list=result_dict_list, page_no=page.page_no, page_size=page.page_size)
|
||||
log.info(f"查询定时任务日志列表成功")
|
||||
return SuccessResponse(data=result_dict, msg="查询定时任务日志列表成功")
|
||||
|
||||
|
||||
@JobRouter.delete("/log/delete", summary="删除定时任务日志", description="删除定时任务日志")
|
||||
async def delete_job_log_controller(
|
||||
ids: list[int] = Body(..., description="ID列表"),
|
||||
auth: AuthSchema = Depends(AuthPermission(["module_application:job:delete"]))
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
删除定时任务日志
|
||||
|
||||
参数:
|
||||
- ids (list[int]): ID列表
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含删除定时任务日志结果的JSON响应
|
||||
"""
|
||||
await JobLogService.delete_job_log_service(auth=auth, ids=ids)
|
||||
log.info(f"删除定时任务日志成功: {ids}")
|
||||
return SuccessResponse(msg="删除定时任务日志成功")
|
||||
|
||||
|
||||
@JobRouter.delete("/log/clear", summary="清空定时任务日志", description="清空定时任务日志")
|
||||
async def clear_job_log_controller(
|
||||
auth: AuthSchema = Depends(AuthPermission(["module_application:job:delete"]))
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
清空定时任务日志
|
||||
|
||||
参数:
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含清空定时任务日志结果的JSON响应
|
||||
"""
|
||||
await JobLogService.clear_job_log_service(auth=auth)
|
||||
log.info(f"清空定时任务日志成功")
|
||||
return SuccessResponse(msg="清空定时任务日志成功")
|
||||
|
||||
|
||||
@JobRouter.post('/log/export', summary="导出定时任务日志", description="导出定时任务日志")
|
||||
async def export_job_log_list_controller(
|
||||
search: JobLogQueryParam = Depends(),
|
||||
auth: AuthSchema = Depends(AuthPermission(["module_application:job:export"]))
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
导出定时任务日志
|
||||
|
||||
参数:
|
||||
- search (JobLogQueryParam): 查询参数模型
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
|
||||
返回:
|
||||
- StreamingResponse: 包含导出定时任务日志结果的流式响应
|
||||
"""
|
||||
result_dict_list = await JobLogService.get_job_log_list_service(search=search, auth=auth)
|
||||
export_result = await JobLogService.export_job_log_service(data_list=result_dict_list)
|
||||
log.info('导出定时任务日志成功')
|
||||
|
||||
return StreamResponse(
|
||||
data=bytes2file_response(export_result),
|
||||
media_type='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
||||
headers={
|
||||
'Content-Disposition': 'attachment; filename=job_log.xlsx'
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,162 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Sequence, Any
|
||||
|
||||
from app.core.base_crud import CRUDBase
|
||||
|
||||
from app.api.v1.module_system.auth.schema import AuthSchema
|
||||
from .model import JobModel, JobLogModel
|
||||
from .schema import JobCreateSchema,JobUpdateSchema,JobLogCreateSchema,JobLogUpdateSchema
|
||||
|
||||
|
||||
class JobCRUD(CRUDBase[JobModel, JobCreateSchema, JobUpdateSchema]):
|
||||
"""定时任务数据层"""
|
||||
|
||||
def __init__(self, auth: AuthSchema) -> None:
|
||||
"""
|
||||
初始化定时任务CRUD
|
||||
|
||||
参数:
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
"""
|
||||
self.auth = auth
|
||||
super().__init__(model=JobModel, auth=auth)
|
||||
|
||||
async def get_obj_by_id_crud(self, id: int, preload: list[str | Any] | None = None) -> JobModel | None:
|
||||
"""
|
||||
获取定时任务详情
|
||||
|
||||
参数:
|
||||
- id (int): 定时任务ID
|
||||
- preload (list[str | Any] | None): 预加载关系,未提供时使用模型默认项
|
||||
|
||||
返回:
|
||||
- JobModel | None: 定时任务模型,如果不存在则为None
|
||||
"""
|
||||
return await self.get(id=id, preload=preload)
|
||||
|
||||
async def get_obj_list_crud(self, search: dict | None = None, order_by: list[dict[str, str]] | None = None, preload: list[str | Any] | None = None) -> Sequence[JobModel]:
|
||||
"""
|
||||
获取定时任务列表
|
||||
|
||||
参数:
|
||||
- search (dict | None): 查询参数字典
|
||||
- order_by (list[dict[str, str]] | None): 排序参数列表
|
||||
- preload (list[str | Any] | None): 预加载关系,未提供时使用模型默认项
|
||||
|
||||
返回:
|
||||
- Sequence[JobModel]: 定时任务模型序列
|
||||
"""
|
||||
return await self.list(search=search, order_by=order_by, preload=preload)
|
||||
|
||||
async def create_obj_crud(self, data: JobCreateSchema) -> JobModel | None:
|
||||
"""
|
||||
创建定时任务
|
||||
|
||||
参数:
|
||||
- data (JobCreateSchema): 创建定时任务模型
|
||||
|
||||
返回:
|
||||
- JobModel | None: 创建的定时任务模型,如果创建失败则为None
|
||||
"""
|
||||
return await self.create(data=data)
|
||||
|
||||
async def update_obj_crud(self, id: int, data: JobUpdateSchema) -> JobModel | None:
|
||||
"""
|
||||
更新定时任务
|
||||
|
||||
参数:
|
||||
- id (int): 定时任务ID
|
||||
- data (JobUpdateSchema): 更新定时任务模型
|
||||
|
||||
返回:
|
||||
- JobModel | None: 更新后的定时任务模型,如果更新失败则为None
|
||||
"""
|
||||
return await self.update(id=id, data=data)
|
||||
|
||||
async def delete_obj_crud(self, ids: list[int]) -> None:
|
||||
"""
|
||||
删除定时任务
|
||||
|
||||
参数:
|
||||
- ids (list[int]): 定时任务ID列表
|
||||
"""
|
||||
return await self.delete(ids=ids)
|
||||
|
||||
async def set_obj_field_crud(self, ids: list[int], **kwargs) -> None:
|
||||
"""
|
||||
设置定时任务的可用状态
|
||||
|
||||
参数:
|
||||
- ids (list[int]): 定时任务ID列表
|
||||
- kwargs: 其他要设置的字段,例如 available=True 或 available=False
|
||||
"""
|
||||
return await self.set(ids=ids, **kwargs)
|
||||
|
||||
async def clear_obj_crud(self) -> None:
|
||||
"""
|
||||
清除定时任务日志
|
||||
|
||||
注意:
|
||||
- 此操作会删除所有定时任务日志,请谨慎操作
|
||||
"""
|
||||
return await self.clear()
|
||||
|
||||
|
||||
class JobLogCRUD(CRUDBase[JobLogModel, JobLogCreateSchema, JobLogUpdateSchema]):
|
||||
"""定时任务日志数据层"""
|
||||
|
||||
def __init__(self, auth: AuthSchema) -> None:
|
||||
"""
|
||||
初始化定时任务日志CRUD
|
||||
|
||||
参数:
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
"""
|
||||
self.auth = auth
|
||||
super().__init__(model=JobLogModel, auth=auth)
|
||||
|
||||
async def get_obj_log_by_id_crud(self, id: int, preload: list[str | Any] | None = None) -> JobLogModel | None:
|
||||
"""
|
||||
获取定时任务日志详情
|
||||
|
||||
参数:
|
||||
- id (int): 定时任务日志ID
|
||||
- preload (list[str | Any] | None): 预加载关系,未提供时使用模型默认项
|
||||
|
||||
返回:
|
||||
- JobLogModel | None: 定时任务日志模型,如果不存在则为None
|
||||
"""
|
||||
return await self.get(id=id, preload=preload)
|
||||
|
||||
async def get_obj_log_list_crud(self, search: dict | None = None, order_by: list[dict[str, str]] | None = None, preload: list[str | Any] | None = None) -> Sequence[JobLogModel]:
|
||||
"""
|
||||
获取定时任务日志列表
|
||||
|
||||
参数:
|
||||
- search (dict | None): 查询参数字典
|
||||
- order_by (list[dict[str, str]] | None): 排序参数列表
|
||||
- preload (list[str | Any] | None): 预加载关系,未提供时使用模型默认项
|
||||
|
||||
返回:
|
||||
- Sequence[JobLogModel]: 定时任务日志模型序列
|
||||
"""
|
||||
return await self.list(search=search, order_by=order_by, preload=preload)
|
||||
|
||||
async def delete_obj_log_crud(self, ids: list[int]) -> None:
|
||||
"""
|
||||
删除定时任务日志
|
||||
|
||||
参数:
|
||||
- ids (list[int]): 定时任务日志ID列表
|
||||
"""
|
||||
return await self.delete(ids=ids)
|
||||
|
||||
async def clear_obj_log_crud(self) -> None:
|
||||
"""
|
||||
清除定时任务日志
|
||||
|
||||
注意:
|
||||
- 此操作会删除所有定时任务日志,请谨慎操作
|
||||
"""
|
||||
return await self.clear()
|
||||
@@ -0,0 +1,2 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.logger import log
|
||||
|
||||
def job(*args, **kwargs) -> None:
|
||||
"""
|
||||
定时任务执行同步函数示例
|
||||
|
||||
参数:
|
||||
- args: 位置参数。
|
||||
- kwargs: 关键字参数。
|
||||
"""
|
||||
try:
|
||||
print(f"开始执行任务: {args}-{kwargs}")
|
||||
time.sleep(3)
|
||||
print(f'{datetime.now()}同步函数执行完成')
|
||||
except Exception as e:
|
||||
log.error(f"同步任务执行失败: {e}")
|
||||
raise
|
||||
|
||||
async def async_job(*args, **kwargs) -> None:
|
||||
"""
|
||||
定时任务执行异步函数示例
|
||||
|
||||
参数:
|
||||
- args: 位置参数。
|
||||
- kwargs: 关键字参数。
|
||||
"""
|
||||
try:
|
||||
print(f"开始执行任务: {args}-{kwargs}")
|
||||
time.sleep(3)
|
||||
print(f'{datetime.now()}异步函数执行完成')
|
||||
except Exception as e:
|
||||
log.error(f"异步任务执行失败: {e}")
|
||||
raise
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from sqlalchemy import Boolean, String, Integer, Text, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.core.base_model import ModelMixin, UserMixin
|
||||
|
||||
|
||||
class JobModel(ModelMixin, UserMixin):
|
||||
"""
|
||||
定时任务调度表
|
||||
- 0: 运行中
|
||||
- 1: 暂停中
|
||||
"""
|
||||
__tablename__: str = 'app_job'
|
||||
__table_args__: dict[str, str] = ({'comment': '定时任务调度表'})
|
||||
__loader_options__: list[str] = ["job_logs", "created_by", "updated_by"]
|
||||
|
||||
name: Mapped[str | None] = mapped_column(String(64), nullable=True, default='', comment='任务名称')
|
||||
jobstore: Mapped[str | None] = mapped_column(String(64), nullable=True, default='default', comment='存储器')
|
||||
executor: Mapped[str | None] = mapped_column(String(64), nullable=True, default='default', comment='执行器:将运行此作业的执行程序的名称')
|
||||
trigger: Mapped[str] = mapped_column(String(64), nullable=False, comment='触发器:控制此作业计划的 trigger 对象')
|
||||
trigger_args: Mapped[str | None] = mapped_column(Text, nullable=True, comment='触发器参数')
|
||||
func: Mapped[str] = mapped_column(Text, nullable=False, comment='任务函数')
|
||||
args: Mapped[str | None] = mapped_column(Text, nullable=True, comment='位置参数')
|
||||
kwargs: Mapped[str | None] = mapped_column(Text, nullable=True, comment='关键字参数')
|
||||
coalesce: Mapped[bool] = mapped_column(Boolean, nullable=True, default=False, comment='是否合并运行:是否在多个运行时间到期时仅运行作业一次')
|
||||
max_instances: Mapped[int] = mapped_column(Integer, nullable=True, default=1, comment='最大实例数:允许的最大并发执行实例数')
|
||||
start_date: Mapped[str | None] = mapped_column(String(64), nullable=True, comment='开始时间')
|
||||
end_date: Mapped[str | None] = mapped_column(String(64), nullable=True, comment='结束时间')
|
||||
|
||||
# 关联关系
|
||||
job_logs: Mapped[list['JobLogModel'] | None] = relationship(
|
||||
back_populates="job",
|
||||
lazy="selectin"
|
||||
)
|
||||
|
||||
|
||||
class JobLogModel(ModelMixin):
|
||||
"""
|
||||
定时任务调度日志表
|
||||
"""
|
||||
__tablename__: str = 'app_job_log'
|
||||
__table_args__: dict[str, str] = ({'comment': '定时任务调度日志表'})
|
||||
__loader_options__: list[str] = ["job"]
|
||||
|
||||
job_name: Mapped[str] = mapped_column(String(64), nullable=False, comment='任务名称')
|
||||
job_group: Mapped[str] = mapped_column(String(64), nullable=False, comment='任务组名')
|
||||
job_executor: Mapped[str] = mapped_column(String(64), nullable=False, comment='任务执行器')
|
||||
invoke_target: Mapped[str] = mapped_column(String(500), nullable=False, comment='调用目标字符串')
|
||||
job_args: Mapped[str | None] = mapped_column(String(255), nullable=True, default='', comment='位置参数')
|
||||
job_kwargs: Mapped[str | None] = mapped_column(String(255), nullable=True, default='', comment='关键字参数')
|
||||
job_trigger: Mapped[str | None] = mapped_column(String(255), nullable=True, default='', comment='任务触发器')
|
||||
job_message: Mapped[str | None] = mapped_column(String(500), nullable=True, default='', comment='日志信息')
|
||||
exception_info: Mapped[str | None] = mapped_column(String(2000), nullable=True, default='', comment='异常信息')
|
||||
|
||||
# 任务关联
|
||||
job_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey('app_job.id', ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
comment='任务ID'
|
||||
)
|
||||
|
||||
job: Mapped["JobModel | None"] = relationship(
|
||||
back_populates="job_logs",
|
||||
lazy="selectin"
|
||||
)
|
||||
@@ -0,0 +1,146 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
from fastapi import Query
|
||||
|
||||
from app.core.base_schema import BaseSchema, UserBySchema
|
||||
from app.core.validator import DateTimeStr, datetime_validator
|
||||
|
||||
|
||||
class JobCreateSchema(BaseModel):
|
||||
"""
|
||||
定时任务调度表对应pydantic模型
|
||||
"""
|
||||
name: str = Field(..., max_length=64, description='任务名称')
|
||||
func: str = Field(..., description='任务函数')
|
||||
trigger: str = Field(..., description='触发器:控制此作业计划的 trigger 对象')
|
||||
args: str | None = Field(default=None, description='位置参数')
|
||||
kwargs: str | None = Field(default=None, description='关键字参数')
|
||||
coalesce: bool | None = Field(..., description='是否合并运行:是否在多个运行时间到期时仅运行作业一次')
|
||||
max_instances: int | None = Field(default=1, ge=1, description='最大实例数:允许的最大并发执行实例数')
|
||||
jobstore: str | None = Field(..., max_length=64, description='任务存储')
|
||||
executor: str | None = Field(..., max_length=64, description='任务执行器:将运行此作业的执行程序的名称')
|
||||
trigger_args: str | None = Field(default=None, description='触发器参数')
|
||||
start_date: str | None = Field(default=None, description='开始时间')
|
||||
end_date: str | None = Field(default=None, description='结束时间')
|
||||
description: str | None = Field(default=None, max_length=255, description='描述')
|
||||
status: str = Field(default='0', description='任务状态:启动,停止')
|
||||
|
||||
@field_validator('trigger')
|
||||
@classmethod
|
||||
def _validate_trigger(cls, v: str) -> str:
|
||||
allowed = {'cron', 'interval', 'date'}
|
||||
v = v.strip()
|
||||
if v not in allowed:
|
||||
raise ValueError('触发器必须为 cron/interval/date')
|
||||
return v
|
||||
|
||||
@model_validator(mode='after')
|
||||
def _validate_dates(self):
|
||||
"""跨字段校验:结束时间不得早于开始时间。"""
|
||||
if self.start_date and self.end_date:
|
||||
try:
|
||||
start = datetime_validator(self.start_date)
|
||||
end = datetime_validator(self.end_date)
|
||||
except Exception:
|
||||
raise ValueError('时间格式必须为 YYYY-MM-DD HH:MM:SS')
|
||||
if end < start:
|
||||
raise ValueError('结束时间不能早于开始时间')
|
||||
return self
|
||||
|
||||
|
||||
class JobUpdateSchema(JobCreateSchema):
|
||||
"""定时任务更新模型"""
|
||||
...
|
||||
|
||||
|
||||
class JobOutSchema(JobCreateSchema, BaseSchema, UserBySchema):
|
||||
"""定时任务响应模型"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
...
|
||||
|
||||
|
||||
class JobLogCreateSchema(BaseModel):
|
||||
"""
|
||||
定时任务调度日志表对应pydantic模型
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
job_name: str = Field(..., description='任务名称')
|
||||
job_group: str | None = Field(default=None, description='任务组名')
|
||||
job_executor: str | None = Field(default=None, description='任务执行器')
|
||||
invoke_target: str | None = Field(default=None, description='调用目标字符串')
|
||||
job_args: str | None = Field(default=None, description='位置参数')
|
||||
job_kwargs: str | None = Field(default=None, description='关键字参数')
|
||||
job_trigger: str | None = Field(default=None, description='任务触发器')
|
||||
job_message: str | None = Field(default=None, description='日志信息')
|
||||
exception_info: str | None = Field(default=None, description='异常信息')
|
||||
status: str = Field(default='0', description='任务状态:正常,失败')
|
||||
description: str | None = Field(default=None, max_length=255, description='描述')
|
||||
created_time: DateTimeStr | None = Field(default=None, description='创建时间')
|
||||
updated_time: DateTimeStr | None = Field(default=None, description='更新时间')
|
||||
|
||||
|
||||
class JobLogUpdateSchema(JobLogCreateSchema):
|
||||
"""定时任务调度日志表更新模型"""
|
||||
...
|
||||
id: int | None = Field(default=None, description='任务日志ID')
|
||||
|
||||
|
||||
class JobLogOutSchema(JobLogUpdateSchema, BaseSchema, UserBySchema):
|
||||
"""定时任务调度日志表响应模型"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
...
|
||||
|
||||
|
||||
class JobQueryParam:
|
||||
"""定时任务查询参数"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str | None = Query(None, description="任务名称"),
|
||||
status: str | None = Query(None, description="状态: 启动,停止"),
|
||||
created_time: list[DateTimeStr] | None = Query(None, description="创建时间范围", examples=["2025-01-01 00:00:00", "2025-12-31 23:59:59"]),
|
||||
updated_time: list[DateTimeStr] | None = Query(None, description="更新时间范围", examples=["2025-01-01 00:00:00", "2025-12-31 23:59:59"]),
|
||||
created_id: int | None = Query(None, description="创建人"),
|
||||
updated_id: int | None = Query(None, description="更新人"),
|
||||
) -> None:
|
||||
|
||||
# 模糊查询字段
|
||||
self.name = ("like", f"%{name}%") if name else None
|
||||
|
||||
# 精确查询字段
|
||||
self.created_id = created_id
|
||||
self.updated_id = updated_id
|
||||
self.status = status
|
||||
|
||||
# 时间范围查询
|
||||
if created_time and len(created_time) == 2:
|
||||
self.created_time = ("between", (created_time[0], created_time[1]))
|
||||
if updated_time and len(updated_time) == 2:
|
||||
self.updated_time = ("between", (updated_time[0], updated_time[1]))
|
||||
|
||||
|
||||
class JobLogQueryParam:
|
||||
"""定时任务查询参数"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
job_id: int | None = Query(None, description="定时任务ID"),
|
||||
job_name: str | None = Query(None, description="任务名称"),
|
||||
status: str | None = Query(None, description="状态: 正常,失败"),
|
||||
created_time: list[DateTimeStr] | None = Query(None, description="创建时间范围", examples=["2025-01-01 00:00:00", "2025-12-31 23:59:59"]),
|
||||
updated_time: list[DateTimeStr] | None = Query(None, description="更新时间范围", examples=["2025-01-01 00:00:00", "2025-12-31 23:59:59"]),
|
||||
) -> None:
|
||||
# 定时任务ID查询
|
||||
self.job_id = job_id
|
||||
# 模糊查询字段
|
||||
self.job_name = ("like", job_name)
|
||||
# 精确查询字段
|
||||
self.status = status
|
||||
# 时间范围查询
|
||||
if created_time and len(created_time) == 2:
|
||||
self.created_time = ("between", (created_time[0], created_time[1]))
|
||||
if updated_time and len(updated_time) == 2:
|
||||
self.updated_time = ("between", (updated_time[0], updated_time[1]))
|
||||
@@ -0,0 +1,307 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from app.core.exceptions import CustomException
|
||||
from app.utils.cron_util import CronUtil
|
||||
from app.utils.excel_util import ExcelUtil
|
||||
from app.api.v1.module_system.auth.schema import AuthSchema
|
||||
from .tools.ap_scheduler import SchedulerUtil
|
||||
from .crud import JobCRUD, JobLogCRUD
|
||||
from .schema import (
|
||||
JobCreateSchema,
|
||||
JobUpdateSchema,
|
||||
JobOutSchema,
|
||||
JobLogOutSchema,
|
||||
JobQueryParam,
|
||||
JobLogQueryParam
|
||||
)
|
||||
|
||||
|
||||
class JobService:
|
||||
"""
|
||||
定时任务管理模块服务层
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
async def get_job_detail_service(cls, auth: AuthSchema, id: int) -> dict:
|
||||
"""
|
||||
获取定时任务详情
|
||||
|
||||
参数:
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
- id (int): 定时任务ID
|
||||
|
||||
返回:
|
||||
- Dict: 定时任务详情字典
|
||||
"""
|
||||
obj = await JobCRUD(auth).get_obj_by_id_crud(id=id)
|
||||
return JobOutSchema.model_validate(obj).model_dump()
|
||||
|
||||
@classmethod
|
||||
async def get_job_list_service(cls, auth: AuthSchema, search: JobQueryParam | None = None, order_by: list[dict[str, str]] | None = None) -> list[dict]:
|
||||
"""
|
||||
获取定时任务列表
|
||||
|
||||
参数:
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
- search (JobQueryParam | None): 查询参数模型
|
||||
- order_by (list[dict[str, str]] | None): 排序参数列表
|
||||
|
||||
返回:
|
||||
- List[Dict]: 定时任务详情字典列表
|
||||
"""
|
||||
obj_list = await JobCRUD(auth).get_obj_list_crud(search=search.__dict__, order_by=order_by)
|
||||
return [JobOutSchema.model_validate(obj).model_dump() for obj in obj_list]
|
||||
|
||||
@classmethod
|
||||
async def create_job_service(cls, auth: AuthSchema, data: JobCreateSchema) -> dict:
|
||||
"""
|
||||
创建定时任务
|
||||
|
||||
参数:
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
- data (JobCreateSchema): 定时任务创建模型
|
||||
|
||||
返回:
|
||||
- Dict: 定时任务详情字典
|
||||
"""
|
||||
exist_obj = await JobCRUD(auth).get(name=data.name)
|
||||
if exist_obj:
|
||||
raise CustomException(msg='创建失败,该定时任务已存在')
|
||||
|
||||
obj = await JobCRUD(auth).create_obj_crud(data=data)
|
||||
if not obj:
|
||||
raise CustomException(msg='创建失败,该数据定时任务不存在')
|
||||
SchedulerUtil().add_job(job_info=obj)
|
||||
return JobOutSchema.model_validate(obj).model_dump()
|
||||
|
||||
@classmethod
|
||||
async def update_job_service(cls, auth: AuthSchema, id:int, data: JobUpdateSchema) -> dict:
|
||||
"""
|
||||
更新定时任务
|
||||
|
||||
参数:
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
- id (int): 定时任务ID
|
||||
- data (JobUpdateSchema): 定时任务更新模型
|
||||
|
||||
返回:
|
||||
- dict: 定时任务详情字典
|
||||
"""
|
||||
exist_obj = await JobCRUD(auth).get_obj_by_id_crud(id=id)
|
||||
if not exist_obj:
|
||||
raise CustomException(msg='更新失败,该定时任务不存在')
|
||||
if data.trigger == 'cron' and data.trigger_args and not CronUtil.validate_cron_expression(data.trigger_args):
|
||||
raise CustomException(msg=f'新增定时任务{data.name}失败, Cron表达式不正确')
|
||||
obj = await JobCRUD(auth).update_obj_crud(id=id, data=data)
|
||||
if not obj:
|
||||
raise CustomException(msg='更新失败,该数据定时任务不存在')
|
||||
SchedulerUtil().modify_job(job_id=obj.id)
|
||||
return JobOutSchema.model_validate(obj).model_dump()
|
||||
|
||||
@classmethod
|
||||
async def delete_job_service(cls, auth: AuthSchema, ids: list[int]) -> None:
|
||||
"""
|
||||
删除定时任务
|
||||
|
||||
参数:
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
- ids (list[int]): 定时任务ID列表
|
||||
"""
|
||||
if len(ids) < 1:
|
||||
raise CustomException(msg='删除失败,删除对象不能为空')
|
||||
for id in ids:
|
||||
exist_obj = await JobCRUD(auth).get_obj_by_id_crud(id=id)
|
||||
if not exist_obj:
|
||||
raise CustomException(msg='删除失败,该数据定时任务不存在')
|
||||
obj = await JobLogCRUD(auth).get(job_id=id)
|
||||
if obj:
|
||||
raise CustomException(msg=f'删除失败,该定时任务存 {exist_obj.name} 在日志记录')
|
||||
|
||||
SchedulerUtil().remove_job(job_id=id)
|
||||
await JobCRUD(auth).delete_obj_crud(ids=ids)
|
||||
|
||||
|
||||
@classmethod
|
||||
async def clear_job_service(cls, auth: AuthSchema) -> None:
|
||||
"""
|
||||
清空所有定时任务
|
||||
|
||||
参数:
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
"""
|
||||
SchedulerUtil().clear_jobs()
|
||||
await JobLogCRUD(auth).clear_obj_log_crud()
|
||||
await JobCRUD(auth).clear_obj_crud()
|
||||
|
||||
@classmethod
|
||||
async def option_job_service(cls, auth: AuthSchema, id: int, option: int) -> None:
|
||||
"""
|
||||
操作定时任务
|
||||
|
||||
参数:
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
- id (int): 定时任务ID
|
||||
- option (int): 操作类型, 1: 暂停 2: 恢复 3: 重启
|
||||
"""
|
||||
# 1: 暂停 2: 恢复 3: 重启
|
||||
obj = await JobCRUD(auth).get_obj_by_id_crud(id=id)
|
||||
if not obj:
|
||||
raise CustomException(msg='操作失败,该数据定时任务不存在')
|
||||
if option == 1:
|
||||
SchedulerUtil().pause_job(job_id=id)
|
||||
await JobCRUD(auth).set_obj_field_crud(ids=[id], status=False)
|
||||
elif option == 2:
|
||||
SchedulerUtil().resume_job(job_id=id)
|
||||
await JobCRUD(auth).set_obj_field_crud(ids=[id], status=True)
|
||||
elif option == 3:
|
||||
# 重启任务:先移除再添加,确保使用最新的任务配置
|
||||
SchedulerUtil().remove_job(job_id=id)
|
||||
# 获取最新的任务配置
|
||||
updated_job = await JobCRUD(auth).get_obj_by_id_crud(id=id)
|
||||
if updated_job:
|
||||
# 重新添加任务
|
||||
SchedulerUtil.add_job(job_info=updated_job)
|
||||
# 设置状态为运行中
|
||||
await JobCRUD(auth).set_obj_field_crud(ids=[id], status=True)
|
||||
|
||||
@classmethod
|
||||
async def export_job_service(cls, data_list: list[dict]) -> bytes:
|
||||
"""
|
||||
导出定时任务列表
|
||||
|
||||
参数:
|
||||
- data_list (list[dict]): 定时任务列表
|
||||
|
||||
返回:
|
||||
- bytes: Excel文件字节流
|
||||
"""
|
||||
mapping_dict = {
|
||||
'id': '编号',
|
||||
'name': '任务名称',
|
||||
'func': '任务函数',
|
||||
'trigger': '触发器',
|
||||
'args': '位置参数',
|
||||
'kwargs': '关键字参数',
|
||||
'coalesce': '是否合并运行',
|
||||
'max_instances': '最大实例数',
|
||||
'jobstore': '任务存储',
|
||||
'executor': '任务执行器',
|
||||
'trigger_args': '触发器参数',
|
||||
'status': '任务状态',
|
||||
'message': '日志信息',
|
||||
'description': '备注',
|
||||
'created_time': '创建时间',
|
||||
'updated_time': '更新时间',
|
||||
'created_id': '创建者ID',
|
||||
'updated_id': '更新者ID',
|
||||
}
|
||||
|
||||
# 复制数据并转换状态
|
||||
data = data_list.copy()
|
||||
for item in data:
|
||||
item['status'] = '已完成' if item['status'] == '0' else '运行中' if item['status'] == '1' else '暂停'
|
||||
|
||||
return ExcelUtil.export_list2excel(list_data=data, mapping_dict=mapping_dict)
|
||||
|
||||
|
||||
class JobLogService:
|
||||
"""
|
||||
定时任务日志管理模块服务层
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
async def get_job_log_detail_service(cls, auth: AuthSchema, id: int) -> dict:
|
||||
"""
|
||||
获取定时任务日志详情
|
||||
|
||||
参数:
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
- id (int): 定时任务日志ID
|
||||
|
||||
返回:
|
||||
- dict: 定时任务日志详情字典
|
||||
"""
|
||||
obj = await JobLogCRUD(auth).get_obj_log_by_id_crud(id=id)
|
||||
return JobLogOutSchema.model_validate(obj).model_dump()
|
||||
|
||||
@classmethod
|
||||
async def get_job_log_list_service(cls, auth: AuthSchema, search: JobLogQueryParam | None = None, order_by: list[dict] | None = None) -> list[dict]:
|
||||
"""
|
||||
获取定时任务日志列表
|
||||
|
||||
参数:
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
- search (JobLogQueryParam | None): 查询参数模型, 包含分页信息和查询条件
|
||||
- order_by (list[dict] | None): 排序参数列表, 每个元素为一个字典, 包含字段名和排序方向
|
||||
|
||||
返回:
|
||||
- list[dict]: 定时任务日志详情字典列表
|
||||
"""
|
||||
obj_list = await JobLogCRUD(auth).get_obj_log_list_crud(search=search.__dict__, order_by=order_by)
|
||||
return [JobLogOutSchema.model_validate(obj).model_dump() for obj in obj_list]
|
||||
|
||||
@classmethod
|
||||
async def delete_job_log_service(cls, auth: AuthSchema, ids: list[int]) -> None:
|
||||
"""
|
||||
删除定时任务日志
|
||||
|
||||
参数:
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
- ids (list[int]): 定时任务日志ID列表
|
||||
"""
|
||||
if len(ids) < 1:
|
||||
raise CustomException(msg='删除失败,删除对象不能为空')
|
||||
for id in ids:
|
||||
exist_obj = await JobLogCRUD(auth).get_obj_log_by_id_crud(id=id)
|
||||
if not exist_obj:
|
||||
raise CustomException(msg=f'删除失败,该定时任务日志ID为{id}的记录不存在')
|
||||
await JobLogCRUD(auth).delete_obj_log_crud(ids=ids)
|
||||
|
||||
@classmethod
|
||||
async def clear_job_log_service(cls, auth: AuthSchema) -> None:
|
||||
"""
|
||||
清空定时任务日志
|
||||
|
||||
参数:
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
"""
|
||||
# 获取所有日志ID并批量删除
|
||||
all_logs = await JobLogCRUD(auth).get_obj_log_list_crud()
|
||||
if all_logs:
|
||||
ids = [log.id for log in all_logs]
|
||||
await JobLogCRUD(auth).delete_obj_log_crud(ids=ids)
|
||||
|
||||
@classmethod
|
||||
async def export_job_log_service(cls, data_list: list[dict]) -> bytes:
|
||||
"""
|
||||
导出定时任务日志列表
|
||||
|
||||
参数:
|
||||
- data_list (List[Dict[str, Any]]): 定时任务日志列表
|
||||
|
||||
返回:
|
||||
- bytes: Excel文件字节流
|
||||
"""
|
||||
mapping_dict = {
|
||||
'id': '编号',
|
||||
'job_name': '任务名称',
|
||||
'job_group': '任务组名',
|
||||
'job_executor': '任务执行器',
|
||||
'invoke_target': '调用目标字符串',
|
||||
'job_args': '位置参数',
|
||||
'job_kwargs': '关键字参数',
|
||||
'job_trigger': '任务触发器',
|
||||
'job_message': '日志信息',
|
||||
'exception_info': '异常信息',
|
||||
'status': '执行状态',
|
||||
'created_time': '创建时间',
|
||||
'updated_time': '更新时间',
|
||||
}
|
||||
|
||||
# 复制数据并转换状态
|
||||
data = data_list.copy()
|
||||
for item in data:
|
||||
item['status'] = '成功' if item.get('status') == '0' else '失败'
|
||||
|
||||
return ExcelUtil.export_list2excel(list_data=data, mapping_dict=mapping_dict)
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
@@ -0,0 +1,589 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import json
|
||||
import importlib
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from asyncio import iscoroutinefunction
|
||||
from apscheduler.job import Job
|
||||
from apscheduler.events import JobExecutionEvent, EVENT_ALL, JobEvent
|
||||
from apscheduler.executors.asyncio import AsyncIOExecutor
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.executors.pool import ProcessPoolExecutor
|
||||
from apscheduler.jobstores.memory import MemoryJobStore
|
||||
from apscheduler.jobstores.redis import RedisJobStore
|
||||
from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from apscheduler.triggers.date import DateTrigger
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from app.config.setting import settings
|
||||
from app.core.database import engine, db_session, async_db_session
|
||||
from app.core.exceptions import CustomException
|
||||
from app.core.logger import log
|
||||
from app.utils.cron_util import CronUtil
|
||||
|
||||
from app.api.v1.module_application.job.model import JobModel
|
||||
|
||||
job_stores = {
|
||||
'default': MemoryJobStore(),
|
||||
'sqlalchemy': SQLAlchemyJobStore(url=settings.DB_URI, engine=engine),
|
||||
'redis': RedisJobStore(
|
||||
host=settings.REDIS_HOST,
|
||||
port=int(settings.REDIS_PORT),
|
||||
username=settings.REDIS_USER,
|
||||
password=settings.REDIS_PASSWORD,
|
||||
db=int(settings.REDIS_DB_NAME),
|
||||
),
|
||||
}
|
||||
# 配置执行器
|
||||
executors = {
|
||||
'default': AsyncIOExecutor(),
|
||||
'processpool': ProcessPoolExecutor(max_workers=1) # 减少进程数量以减少资源消耗
|
||||
}
|
||||
# 配置默认参数
|
||||
job_defaults = {
|
||||
'coalesce': True, # 合并执行错过的任务
|
||||
'max_instances': 1, # 最大实例数
|
||||
}
|
||||
# 配置调度器
|
||||
scheduler = AsyncIOScheduler()
|
||||
scheduler.configure(
|
||||
jobstores=job_stores,
|
||||
executors=executors,
|
||||
job_defaults=job_defaults,
|
||||
timezone='Asia/Shanghai'
|
||||
)
|
||||
|
||||
class SchedulerUtil:
|
||||
"""
|
||||
定时任务相关方法
|
||||
"""
|
||||
@classmethod
|
||||
def scheduler_event_listener(cls, event: JobEvent | JobExecutionEvent) -> None:
|
||||
"""
|
||||
监听任务执行事件。
|
||||
|
||||
参数:
|
||||
- event (JobEvent | JobExecutionEvent): 任务事件对象。
|
||||
|
||||
返回:
|
||||
- None
|
||||
"""
|
||||
# 延迟导入避免循环导入
|
||||
from app.api.v1.module_application.job.model import JobLogModel
|
||||
|
||||
# 获取事件类型和任务ID
|
||||
event_type = event.__class__.__name__
|
||||
# 初始化任务状态
|
||||
status = True
|
||||
exception_info = ''
|
||||
if isinstance(event, JobExecutionEvent) and event.exception:
|
||||
exception_info = str(event.exception)
|
||||
status = False
|
||||
if hasattr(event, 'job_id'):
|
||||
job_id = event.job_id
|
||||
query_job = cls.get_job(job_id=job_id)
|
||||
if query_job:
|
||||
query_job_info = query_job.__getstate__()
|
||||
# 获取任务名称
|
||||
job_name = query_job_info.get('name')
|
||||
# 获取任务组名
|
||||
job_group = query_job._jobstore_alias
|
||||
# # 获取任务执行器
|
||||
job_executor = query_job_info.get('executor')
|
||||
# 获取调用目标字符串
|
||||
invoke_target = query_job_info.get('func')
|
||||
# 获取调用函数位置参数
|
||||
job_args = ','.join(map(str, query_job_info.get('args', [])))
|
||||
# 获取调用函数关键字参数
|
||||
job_kwargs = json.dumps(query_job_info.get('kwargs'))
|
||||
# 获取任务触发器
|
||||
job_trigger = str(query_job_info.get('trigger'))
|
||||
# 构造日志消息
|
||||
job_message = f"事件类型: {event_type}, 任务ID: {job_id}, 任务名称: {job_name}, 状态: {status}, 任务组: {job_group}, 错误详情: {exception_info}, 执行于{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
# 创建ORM对象
|
||||
job_log = JobLogModel(
|
||||
job_name=job_name,
|
||||
job_group=job_group,
|
||||
job_executor=job_executor,
|
||||
invoke_target=invoke_target,
|
||||
job_args=job_args,
|
||||
job_kwargs=job_kwargs,
|
||||
job_trigger=job_trigger,
|
||||
job_message=job_message,
|
||||
status=status,
|
||||
exception_info=exception_info,
|
||||
created_time=datetime.now(),
|
||||
updated_time=datetime.now(),
|
||||
job_id=job_id,
|
||||
)
|
||||
|
||||
# 使用线程池执行操作以避免阻塞调度器和数据库锁定问题
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
executor.submit(cls._save_job_log_async_wrapper, job_log)
|
||||
executor.shutdown(wait=False)
|
||||
|
||||
@classmethod
|
||||
def _save_job_log_async_wrapper(cls, job_log) -> None:
|
||||
"""
|
||||
异步保存任务日志的包装器函数,在独立线程中运行
|
||||
|
||||
参数:
|
||||
- job_log (JobLogModel): 任务日志对象
|
||||
|
||||
返回:
|
||||
- None
|
||||
"""
|
||||
with db_session.begin() as session:
|
||||
try:
|
||||
session.add(job_log)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
log.error(f"保存任务日志失败: {str(e)}")
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
@classmethod
|
||||
async def init_system_scheduler(cls) -> None:
|
||||
"""
|
||||
应用启动时初始化定时任务。
|
||||
|
||||
返回:
|
||||
- None
|
||||
"""
|
||||
# 延迟导入避免循环导入
|
||||
from app.api.v1.module_application.job.crud import JobCRUD
|
||||
from app.api.v1.module_system.auth.schema import AuthSchema
|
||||
log.info('🔎 开始启动定时任务...')
|
||||
|
||||
# 启动调度器
|
||||
scheduler.start()
|
||||
|
||||
# 添加事件监听器
|
||||
scheduler.add_listener(cls.scheduler_event_listener, EVENT_ALL)
|
||||
|
||||
async with async_db_session() as session:
|
||||
async with session.begin():
|
||||
auth = AuthSchema(db=session)
|
||||
job_list = await JobCRUD(auth).get_obj_list_crud()
|
||||
|
||||
# 只在一个实例上初始化任务
|
||||
# 使用Redis锁确保只有一个实例执行任务初始化
|
||||
import redis.asyncio as redis
|
||||
redis_client = redis.Redis(
|
||||
host=settings.REDIS_HOST,
|
||||
port=int(settings.REDIS_PORT),
|
||||
username=settings.REDIS_USER,
|
||||
password=settings.REDIS_PASSWORD,
|
||||
db=int(settings.REDIS_DB_NAME),
|
||||
)
|
||||
|
||||
# 尝试获取锁,过期时间10秒
|
||||
lock_key = "scheduler_init_lock"
|
||||
lock_acquired = await redis_client.set(lock_key, "1", ex=10, nx=True)
|
||||
|
||||
if lock_acquired:
|
||||
try:
|
||||
for item in job_list:
|
||||
# 检查任务是否已经存在
|
||||
existing_job = cls.get_job(job_id=item.id)
|
||||
if existing_job:
|
||||
cls.remove_job(job_id=item.id) # 删除旧任务
|
||||
|
||||
# 添加新任务
|
||||
cls.add_job(item)
|
||||
|
||||
# 根据数据库中保存的状态来设置任务状态
|
||||
if hasattr(item, 'status') and item.status == "1":
|
||||
# 如果任务状态为暂停,则立即暂停刚添加的任务
|
||||
cls.pause_job(job_id=item.id)
|
||||
log.info('✅️ 系统初始定时任务加载成功')
|
||||
finally:
|
||||
# 释放锁
|
||||
await redis_client.delete(lock_key)
|
||||
else:
|
||||
# 等待其他实例完成初始化
|
||||
import asyncio
|
||||
await asyncio.sleep(2)
|
||||
log.info('✅️ 定时任务已由其他实例初始化完成')
|
||||
|
||||
@classmethod
|
||||
async def close_system_scheduler(cls) -> None:
|
||||
"""
|
||||
关闭系统定时任务。
|
||||
|
||||
返回:
|
||||
- None
|
||||
"""
|
||||
try:
|
||||
# 移除所有任务
|
||||
scheduler.remove_all_jobs()
|
||||
# 等待所有任务完成后再关闭
|
||||
scheduler.shutdown(wait=True)
|
||||
log.info('✅️ 关闭定时任务成功')
|
||||
except Exception as e:
|
||||
log.error(f'关闭定时任务失败: {str(e)}')
|
||||
|
||||
@classmethod
|
||||
def get_job(cls, job_id: str | int) -> Job | None:
|
||||
"""
|
||||
根据任务ID获取任务对象。
|
||||
|
||||
参数:
|
||||
- job_id (str | int): 任务ID。
|
||||
|
||||
返回:
|
||||
- Job | None: 任务对象,未找到则为 None。
|
||||
"""
|
||||
return scheduler.get_job(job_id=str(job_id))
|
||||
|
||||
@classmethod
|
||||
def get_all_jobs(cls) -> list[Job]:
|
||||
"""
|
||||
获取全部调度任务列表。
|
||||
|
||||
返回:
|
||||
- list[Job]: 任务列表。
|
||||
"""
|
||||
return scheduler.get_jobs()
|
||||
|
||||
@classmethod
|
||||
async def _task_wrapper(cls, job_id, func, *args, **kwargs):
|
||||
"""
|
||||
任务执行包装器,添加分布式锁防止同一任务被多个实例同时执行。
|
||||
|
||||
参数:
|
||||
- job_id: 任务ID
|
||||
- func: 实际要执行的任务函数
|
||||
- *args: 任务函数位置参数
|
||||
- **kwargs: 任务函数关键字参数
|
||||
|
||||
返回:
|
||||
- 任务函数的返回值
|
||||
"""
|
||||
import redis.asyncio as redis
|
||||
import asyncio
|
||||
from app.config.setting import settings
|
||||
|
||||
# 创建Redis客户端
|
||||
redis_client = redis.Redis(
|
||||
host=settings.REDIS_HOST,
|
||||
port=int(settings.REDIS_PORT),
|
||||
username=settings.REDIS_USER,
|
||||
password=settings.REDIS_PASSWORD,
|
||||
db=int(settings.REDIS_DB_NAME),
|
||||
)
|
||||
|
||||
# 生成锁键
|
||||
lock_key = f"job_lock:{job_id}"
|
||||
|
||||
# 设置锁的过期时间(根据任务类型调整,这里设置为30秒)
|
||||
lock_expire = 30
|
||||
lock_acquired = False
|
||||
|
||||
try:
|
||||
# 尝试获取锁
|
||||
lock_acquired = await redis_client.set(lock_key, "1", ex=lock_expire, nx=True)
|
||||
|
||||
if lock_acquired:
|
||||
log.info(f"任务 {job_id} 获取执行锁成功")
|
||||
# 执行任务
|
||||
if iscoroutinefunction(func):
|
||||
return await func(*args, **kwargs)
|
||||
else:
|
||||
# 对于同步函数,使用线程池执行
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, func, *args, **kwargs)
|
||||
else:
|
||||
# 获取锁失败,记录日志
|
||||
log.info(f"任务 {job_id} 获取执行锁失败,跳过本次执行")
|
||||
return None
|
||||
finally:
|
||||
# 释放锁
|
||||
if lock_acquired:
|
||||
await redis_client.delete(lock_key)
|
||||
log.info(f"任务 {job_id} 释放执行锁")
|
||||
|
||||
@classmethod
|
||||
def add_job(cls, job_info: JobModel) -> Job:
|
||||
"""
|
||||
根据任务配置创建并添加调度任务。
|
||||
|
||||
参数:
|
||||
- job_info (JobModel): 任务对象信息(包含触发器、函数、参数等)。
|
||||
|
||||
返回:
|
||||
- Job: 新增的任务对象。
|
||||
"""
|
||||
# 动态导入模块
|
||||
# 1. 解析调用目标
|
||||
module_path, func_name = str(job_info.func).rsplit('.', 1)
|
||||
module_path = "app.api.v1.module_application.job.function_task." + module_path
|
||||
try:
|
||||
module = importlib.import_module(module_path)
|
||||
job_func = getattr(module, func_name)
|
||||
|
||||
# 2. 确定任务存储器:优先使用redis,确保分布式环境中任务同步
|
||||
if job_info.jobstore is None:
|
||||
job_info.jobstore = 'redis' # 改为默认使用redis存储
|
||||
|
||||
# 3. 确定执行器
|
||||
job_executor = job_info.executor
|
||||
if job_executor is None:
|
||||
job_executor = 'default'
|
||||
|
||||
if job_info.trigger_args is None:
|
||||
raise ValueError("触发器缺少参数")
|
||||
|
||||
# 异步函数必须使用默认执行器
|
||||
if iscoroutinefunction(job_func):
|
||||
job_executor = 'default'
|
||||
|
||||
# 4. 创建触发器
|
||||
if job_info.trigger == 'date':
|
||||
trigger = DateTrigger(run_date=job_info.trigger_args)
|
||||
elif job_info.trigger == 'interval':
|
||||
# 将传入的 interval 表达式拆分为不同的字段
|
||||
fields = job_info.trigger_args.strip().split()
|
||||
if len(fields) != 5:
|
||||
raise ValueError("无效的 interval 表达式")
|
||||
second, minute, hour, day, week = tuple([int(field) if field != '*' else 0 for field in fields])
|
||||
# 秒、分、时、天、周(* * * * 1)
|
||||
trigger = IntervalTrigger(
|
||||
weeks=week,
|
||||
days=day,
|
||||
hours=hour,
|
||||
minutes=minute,
|
||||
seconds=second,
|
||||
start_date=job_info.start_date,
|
||||
end_date=job_info.end_date,
|
||||
timezone='Asia/Shanghai',
|
||||
jitter=None
|
||||
)
|
||||
elif job_info.trigger == 'cron':
|
||||
# 秒、分、时、天、月、星期几、年 ()
|
||||
fields = job_info.trigger_args.strip().split()
|
||||
if len(fields) not in (6, 7):
|
||||
raise ValueError("无效的 Cron 表达式")
|
||||
if not CronUtil.validate_cron_expression(job_info.trigger_args):
|
||||
raise ValueError(f'定时任务{job_info.name}, Cron表达式不正确')
|
||||
|
||||
parsed_fields = [None if field in ('*', '?') else field for field in fields]
|
||||
if len(fields) == 6:
|
||||
parsed_fields.append(None)
|
||||
|
||||
second, minute, hour, day, month, day_of_week, year = tuple(parsed_fields)
|
||||
trigger = CronTrigger(
|
||||
second=second,
|
||||
minute=minute,
|
||||
hour=hour,
|
||||
day=day,
|
||||
month=month,
|
||||
day_of_week=day_of_week,
|
||||
year=year,
|
||||
start_date=job_info.start_date,
|
||||
end_date=job_info.end_date,
|
||||
timezone='Asia/Shanghai'
|
||||
)
|
||||
else:
|
||||
raise ValueError("无效的 trigger 触发器")
|
||||
|
||||
# 5. 添加任务(使用包装器函数)
|
||||
job = scheduler.add_job(
|
||||
func=cls._task_wrapper,
|
||||
trigger=trigger,
|
||||
args=[str(job_info.id), job_func] + (str(job_info.args).split(',') if job_info.args else []),
|
||||
kwargs=json.loads(job_info.kwargs) if job_info.kwargs else {},
|
||||
id=str(job_info.id),
|
||||
name=job_info.name,
|
||||
coalesce=job_info.coalesce,
|
||||
max_instances=1, # 确保只有一个实例执行
|
||||
jobstore=job_info.jobstore,
|
||||
executor=job_executor,
|
||||
)
|
||||
log.info(f"任务 {job_info.id} 添加到 {job_info.jobstore} 存储器成功")
|
||||
return job
|
||||
except ModuleNotFoundError:
|
||||
raise ValueError(f"未找到该模块:{module_path}")
|
||||
except AttributeError:
|
||||
raise ValueError(f"未找到该模块下的方法:{func_name}")
|
||||
except Exception as e:
|
||||
raise CustomException(msg=f"添加任务失败: {str(e)}")
|
||||
|
||||
@classmethod
|
||||
def remove_job(cls, job_id: str | int) -> None:
|
||||
"""
|
||||
根据任务ID删除调度任务。
|
||||
|
||||
参数:
|
||||
- job_id (str | int): 任务ID。
|
||||
|
||||
返回:
|
||||
- None
|
||||
"""
|
||||
query_job = cls.get_job(job_id=str(job_id))
|
||||
if query_job:
|
||||
scheduler.remove_job(job_id=str(job_id))
|
||||
|
||||
@classmethod
|
||||
def clear_jobs(cls) -> None:
|
||||
"""
|
||||
删除所有调度任务。
|
||||
|
||||
返回:
|
||||
- None
|
||||
"""
|
||||
scheduler.remove_all_jobs()
|
||||
|
||||
@classmethod
|
||||
def modify_job(cls, job_id: str | int) -> Job:
|
||||
"""
|
||||
更新指定任务的配置(运行中的任务下次执行生效)。
|
||||
|
||||
参数:
|
||||
- job_id (str | int): 任务ID。
|
||||
|
||||
返回:
|
||||
- Job: 更新后的任务对象。
|
||||
|
||||
异常:
|
||||
- CustomException: 当任务不存在时抛出。
|
||||
"""
|
||||
query_job = cls.get_job(job_id=str(job_id))
|
||||
if not query_job:
|
||||
raise CustomException(msg=f"未找到该任务:{job_id}")
|
||||
return scheduler.modify_job(job_id=str(job_id))
|
||||
|
||||
@classmethod
|
||||
def pause_job(cls, job_id: str | int) -> None:
|
||||
"""
|
||||
暂停指定任务(仅运行中可暂停,已终止不可)。
|
||||
|
||||
参数:
|
||||
- job_id (str | int): 任务ID。
|
||||
|
||||
返回:
|
||||
- None
|
||||
|
||||
异常:
|
||||
- ValueError: 当任务不存在时抛出。
|
||||
"""
|
||||
query_job = cls.get_job(job_id=str(job_id))
|
||||
if not query_job:
|
||||
raise ValueError(f"未找到该任务:{job_id}")
|
||||
scheduler.pause_job(job_id=str(job_id))
|
||||
|
||||
@classmethod
|
||||
def resume_job(cls, job_id: str | int) -> None:
|
||||
"""
|
||||
恢复指定任务(仅暂停中可恢复,已终止不可)。
|
||||
|
||||
参数:
|
||||
- job_id (str | int): 任务ID。
|
||||
|
||||
返回:
|
||||
- None
|
||||
|
||||
异常:
|
||||
- ValueError: 当任务不存在时抛出。
|
||||
"""
|
||||
query_job = cls.get_job(job_id=str(job_id))
|
||||
if not query_job:
|
||||
raise ValueError(f"未找到该任务:{job_id}")
|
||||
scheduler.resume_job(job_id=str(job_id))
|
||||
|
||||
@classmethod
|
||||
def reschedule_job(cls, job_id: str | int, trigger=None, **trigger_args) -> Job | None:
|
||||
"""
|
||||
重启指定任务的触发器。
|
||||
|
||||
参数:
|
||||
- job_id (str | int): 任务ID。
|
||||
- trigger: 触发器类型
|
||||
- **trigger_args: 触发器参数
|
||||
|
||||
返回:
|
||||
- Job: 更新后的任务对象
|
||||
|
||||
异常:
|
||||
- CustomException: 当任务不存在时抛出。
|
||||
"""
|
||||
query_job = cls.get_job(job_id=str(job_id))
|
||||
if not query_job:
|
||||
raise CustomException(msg=f"未找到该任务:{job_id}")
|
||||
|
||||
# 如果没有提供新的触发器,则使用现有触发器
|
||||
if trigger is None:
|
||||
# 获取当前任务的触发器配置
|
||||
current_trigger = query_job.trigger
|
||||
# 重新调度任务,使用当前的触发器
|
||||
return scheduler.reschedule_job(job_id=str(job_id), trigger=current_trigger)
|
||||
else:
|
||||
# 使用新提供的触发器
|
||||
return scheduler.reschedule_job(job_id=str(job_id), trigger=trigger, **trigger_args)
|
||||
|
||||
@classmethod
|
||||
def get_single_job_status(cls, job_id: str | int) -> str:
|
||||
"""
|
||||
获取单个任务的当前状态。
|
||||
|
||||
参数:
|
||||
- job_id (str | int): 任务ID
|
||||
|
||||
返回:
|
||||
- str: 任务状态('running' | 'paused' | 'stopped' | 'unknown')
|
||||
"""
|
||||
job = cls.get_job(job_id=str(job_id))
|
||||
if not job:
|
||||
return 'unknown'
|
||||
|
||||
# 检查任务是否在暂停列表中
|
||||
if job_id in scheduler._jobstores[job._jobstore_alias]._paused_jobs:
|
||||
return 'paused'
|
||||
|
||||
# 检查调度器状态
|
||||
if scheduler.state == 0: # STATE_STOPPED
|
||||
return 'stopped'
|
||||
|
||||
return 'running'
|
||||
|
||||
@classmethod
|
||||
def print_jobs(cls,jobstore: Any | None = None, out: Any | None = None):
|
||||
"""
|
||||
打印调度任务列表。
|
||||
|
||||
参数:
|
||||
- jobstore (Any | None): 任务存储别名。
|
||||
- out (Any | None): 输出目标。
|
||||
|
||||
返回:
|
||||
- None
|
||||
"""
|
||||
scheduler.print_jobs(jobstore=jobstore, out=out)
|
||||
|
||||
@classmethod
|
||||
def get_job_status(cls) -> str:
|
||||
"""
|
||||
获取调度器当前状态。
|
||||
|
||||
返回:
|
||||
- str: 状态字符串('stopped' | 'running' | 'paused' | 'unknown')。
|
||||
"""
|
||||
#: constant indicating a scheduler's stopped state
|
||||
STATE_STOPPED = 0
|
||||
#: constant indicating a scheduler's running state (started and processing jobs)
|
||||
STATE_RUNNING = 1
|
||||
#: constant indicating a scheduler's paused state (started but not processing jobs)
|
||||
STATE_PAUSED = 2
|
||||
if scheduler.state == STATE_STOPPED:
|
||||
return 'stopped'
|
||||
elif scheduler.state == STATE_RUNNING:
|
||||
return 'running'
|
||||
elif scheduler.state == STATE_PAUSED:
|
||||
return 'paused'
|
||||
else:
|
||||
return 'unknown'
|
||||
@@ -0,0 +1,9 @@
|
||||
'''
|
||||
Author: caoziyuan ziyuan.cao@zhuying.com
|
||||
Date: 2025-12-22 17:25:15
|
||||
LastEditors: caoziyuan ziyuan.cao@zhuying.com
|
||||
LastEditTime: 2025-12-22 17:25:48
|
||||
FilePath: \backend\app\api\v1\module_application\miniapp\__init__.py
|
||||
Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE
|
||||
'''
|
||||
# -*- coding: utf-8 -*-
|
||||
@@ -0,0 +1,40 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from redis.asyncio.client import Redis
|
||||
|
||||
from app.common.response import SuccessResponse
|
||||
from app.core.logger import log
|
||||
from app.core.dependencies import db_getter, redis_getter
|
||||
from app.api.v1.module_system.auth.schema import AuthSchema
|
||||
|
||||
from .service import MiniappService
|
||||
from .schema import MiniappLoginSchema, MiniappLoginOutSchema
|
||||
|
||||
|
||||
MiniappRouter = APIRouter(prefix="/miniapp", tags=["小程序"])
|
||||
|
||||
|
||||
@MiniappRouter.post("/login", summary="小程序登录", description="微信小程序用户登录", response_model=MiniappLoginOutSchema)
|
||||
async def miniapp_login_controller(
|
||||
data: MiniappLoginSchema,
|
||||
db: AsyncSession = Depends(db_getter),
|
||||
redis: Redis = Depends(redis_getter),
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
小程序登录接口
|
||||
|
||||
前端调用 wx.login() 获取 code,传给此接口换取 token
|
||||
|
||||
参数:
|
||||
- data (MiniappLoginSchema): 包含微信登录code
|
||||
|
||||
返回:
|
||||
- MiniappLoginOutSchema: 包含access_token和用户信息
|
||||
"""
|
||||
auth = AuthSchema(db=db)
|
||||
result = await MiniappService.login_service(auth=auth, redis=redis, data=data)
|
||||
log.info(f"小程序用户登录成功")
|
||||
return SuccessResponse(data=result, msg="登录成功")
|
||||
@@ -0,0 +1,32 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Sequence
|
||||
|
||||
from app.core.base_crud import CRUDBase
|
||||
from app.api.v1.module_system.auth.schema import AuthSchema
|
||||
from .model import MiniappUserModel
|
||||
from .schema import MiniappUserCreateSchema, MiniappUserUpdateSchema
|
||||
|
||||
|
||||
class MiniappUserCRUD(CRUDBase[MiniappUserModel, MiniappUserCreateSchema, MiniappUserUpdateSchema]):
|
||||
"""小程序用户数据层"""
|
||||
|
||||
def __init__(self, auth: AuthSchema) -> None:
|
||||
super().__init__(model=MiniappUserModel, auth=auth)
|
||||
|
||||
async def get_by_openid(self, openid: str) -> MiniappUserModel | None:
|
||||
"""根据openid获取用户"""
|
||||
return await self.get(openid=openid)
|
||||
|
||||
async def get_by_id_crud(self, id: int) -> MiniappUserModel | None:
|
||||
"""根据ID获取用户"""
|
||||
return await self.get(id=id)
|
||||
|
||||
async def update_last_login(self, id: int) -> MiniappUserModel | None:
|
||||
"""更新最后登录时间"""
|
||||
return await self.update(id=id, data={"last_login": datetime.now()})
|
||||
|
||||
async def update_session_key(self, id: int, session_key: str) -> MiniappUserModel | None:
|
||||
"""更新session_key"""
|
||||
return await self.update(id=id, data={"session_key": session_key})
|
||||
@@ -0,0 +1,25 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from datetime import datetime
|
||||
from sqlalchemy import String, DateTime, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.core.base_model import ModelMixin
|
||||
|
||||
|
||||
class MiniappUserModel(ModelMixin):
|
||||
"""
|
||||
小程序用户模型
|
||||
|
||||
存储微信小程序用户信息
|
||||
"""
|
||||
__tablename__: str = "miniapp_user"
|
||||
__table_args__: dict[str, str] = ({'comment': '小程序用户表'})
|
||||
|
||||
openid: Mapped[str] = mapped_column(String(64), nullable=False, unique=True, index=True, comment="微信openid")
|
||||
unionid: Mapped[str | None] = mapped_column(String(64), nullable=True, unique=True, comment="微信unionid")
|
||||
session_key: Mapped[str | None] = mapped_column(String(64), nullable=True, comment="会话密钥")
|
||||
nickname: Mapped[str | None] = mapped_column(String(64), nullable=True, comment="昵称")
|
||||
avatar: Mapped[str | None] = mapped_column(String(512), nullable=True, comment="头像URL")
|
||||
phone: Mapped[str | None] = mapped_column(String(20), nullable=True, comment="手机号")
|
||||
last_login: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True, comment="最后登录时间")
|
||||
@@ -0,0 +1,36 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from app.core.base_schema import BaseSchema
|
||||
|
||||
|
||||
class MiniappLoginSchema(BaseModel):
|
||||
"""小程序登录请求"""
|
||||
code: str = Field(..., min_length=1, description="微信登录code")
|
||||
|
||||
|
||||
class MiniappUserCreateSchema(BaseModel):
|
||||
"""小程序用户创建"""
|
||||
openid: str = Field(..., max_length=64, description="微信openid")
|
||||
unionid: str | None = Field(default=None, max_length=64, description="微信unionid")
|
||||
session_key: str | None = Field(default=None, max_length=64, description="会话密钥")
|
||||
nickname: str | None = Field(default=None, max_length=64, description="昵称")
|
||||
avatar: str | None = Field(default=None, max_length=512, description="头像URL")
|
||||
|
||||
|
||||
class MiniappUserUpdateSchema(MiniappUserCreateSchema):
|
||||
"""小程序用户更新"""
|
||||
phone: str | None = Field(default=None, max_length=20, description="手机号")
|
||||
|
||||
|
||||
class MiniappUserOutSchema(MiniappUserUpdateSchema, BaseSchema):
|
||||
"""小程序用户响应"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class MiniappLoginOutSchema(BaseModel):
|
||||
"""小程序登录响应"""
|
||||
access_token: str = Field(..., description="访问令牌")
|
||||
token_type: str = Field(default="Bearer", description="令牌类型")
|
||||
expires_in: int = Field(..., description="过期时间(秒)")
|
||||
user: MiniappUserOutSchema = Field(..., description="用户信息")
|
||||
@@ -0,0 +1,153 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import uuid
|
||||
import json
|
||||
import httpx
|
||||
from datetime import datetime, timedelta
|
||||
from redis.asyncio.client import Redis
|
||||
|
||||
from app.core.exceptions import CustomException
|
||||
from app.core.logger import log
|
||||
from app.core.security import create_access_token
|
||||
from app.core.redis_crud import RedisCURD
|
||||
from app.common.enums import RedisInitKeyConfig
|
||||
from app.config.setting import settings
|
||||
from app.api.v1.module_system.auth.schema import AuthSchema, JWTPayloadSchema
|
||||
|
||||
from .crud import MiniappUserCRUD
|
||||
from .schema import (
|
||||
MiniappLoginSchema,
|
||||
MiniappUserCreateSchema,
|
||||
MiniappUserOutSchema,
|
||||
MiniappLoginOutSchema,
|
||||
)
|
||||
|
||||
|
||||
class MiniappService:
|
||||
"""小程序服务层"""
|
||||
|
||||
# 微信登录接口
|
||||
WX_LOGIN_URL = "https://api.weixin.qq.com/sns/jscode2session"
|
||||
|
||||
@classmethod
|
||||
async def login_service(cls, auth: AuthSchema, redis: Redis, data: MiniappLoginSchema) -> dict:
|
||||
"""
|
||||
小程序登录
|
||||
|
||||
流程:
|
||||
1. 用微信code换取openid和session_key
|
||||
2. 查找或创建用户
|
||||
3. 生成JWT token
|
||||
"""
|
||||
# 1. 调用微信接口获取openid
|
||||
wx_result = await cls._get_wx_session(code=data.code)
|
||||
openid = wx_result.get("openid")
|
||||
session_key = wx_result.get("session_key")
|
||||
unionid = wx_result.get("unionid")
|
||||
|
||||
if not openid:
|
||||
raise CustomException(msg="微信登录失败,无法获取openid")
|
||||
|
||||
# 2. 查找或创建用户
|
||||
user = await MiniappUserCRUD(auth).get_by_openid(openid=openid)
|
||||
|
||||
if user:
|
||||
# 更新session_key和登录时间
|
||||
await MiniappUserCRUD(auth).update_session_key(id=user.id, session_key=session_key)
|
||||
await MiniappUserCRUD(auth).update_last_login(id=user.id)
|
||||
log.info(f"小程序用户登录: {openid}")
|
||||
else:
|
||||
# 创建新用户
|
||||
user_data = MiniappUserCreateSchema(
|
||||
openid=openid,
|
||||
unionid=unionid,
|
||||
session_key=session_key,
|
||||
)
|
||||
user = await MiniappUserCRUD(auth).create(data=user_data)
|
||||
log.info(f"小程序新用户注册: {openid}")
|
||||
|
||||
# 3. 生成token
|
||||
token_data = await cls._create_miniapp_token(redis=redis, user_id=user.id, openid=openid)
|
||||
|
||||
return MiniappLoginOutSchema(
|
||||
access_token=token_data["access_token"],
|
||||
token_type="Bearer",
|
||||
expires_in=token_data["expires_in"],
|
||||
user=MiniappUserOutSchema.model_validate(user)
|
||||
).model_dump()
|
||||
|
||||
@classmethod
|
||||
async def get_user_info_service(cls, auth: AuthSchema, user_id: int) -> dict:
|
||||
"""获取用户信息"""
|
||||
user = await MiniappUserCRUD(auth).get_by_id_crud(id=user_id)
|
||||
if not user:
|
||||
raise CustomException(msg="用户不存在")
|
||||
return MiniappUserOutSchema.model_validate(user).model_dump()
|
||||
|
||||
@classmethod
|
||||
async def _get_wx_session(cls, code: str) -> dict:
|
||||
"""
|
||||
调用微信接口获取session信息
|
||||
|
||||
注意: 需要在配置中设置 MINIAPP_APPID 和 MINIAPP_SECRET
|
||||
"""
|
||||
appid = getattr(settings, "MINIAPP_APPID", None)
|
||||
secret = getattr(settings, "MINIAPP_SECRET", None)
|
||||
|
||||
if not appid or not secret:
|
||||
# 开发环境模拟返回
|
||||
log.warning("未配置小程序appid和secret,使用模拟数据")
|
||||
return {
|
||||
"openid": f"mock_openid_{code[:8]}",
|
||||
"session_key": "mock_session_key",
|
||||
"unionid": None
|
||||
}
|
||||
|
||||
params = {
|
||||
"appid": appid,
|
||||
"secret": secret,
|
||||
"js_code": code,
|
||||
"grant_type": "authorization_code"
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(cls.WX_LOGIN_URL, params=params)
|
||||
result = response.json()
|
||||
|
||||
if "errcode" in result and result["errcode"] != 0:
|
||||
log.error(f"微信登录失败: {result}")
|
||||
raise CustomException(msg=f"微信登录失败: {result.get('errmsg', '未知错误')}")
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
async def _create_miniapp_token(cls, redis: Redis, user_id: int, openid: str) -> dict:
|
||||
"""创建小程序用户token"""
|
||||
session_id = str(uuid.uuid4())
|
||||
access_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
now = datetime.now()
|
||||
|
||||
session_info = json.dumps({
|
||||
"session_id": session_id,
|
||||
"user_id": user_id,
|
||||
"openid": openid,
|
||||
"login_type": "miniapp"
|
||||
})
|
||||
|
||||
access_token = create_access_token(payload=JWTPayloadSchema(
|
||||
sub=session_info,
|
||||
is_refresh=False,
|
||||
exp=now + access_expires,
|
||||
))
|
||||
|
||||
# 存储到Redis
|
||||
await RedisCURD(redis).set(
|
||||
key=f'{RedisInitKeyConfig.ACCESS_TOKEN.key}:miniapp:{session_id}',
|
||||
value=access_token,
|
||||
expire=int(access_expires.total_seconds())
|
||||
)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"expires_in": int(access_expires.total_seconds())
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
@@ -0,0 +1,142 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Path
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.common.response import SuccessResponse
|
||||
from app.common.request import PaginationService
|
||||
from app.core.base_params import PaginationQueryParam
|
||||
from app.core.dependencies import AuthPermission
|
||||
from app.core.base_schema import BatchSetAvailable
|
||||
from app.core.logger import log
|
||||
|
||||
from app.api.v1.module_system.auth.schema import AuthSchema
|
||||
from app.core.router_class import OperationLogRoute
|
||||
from .service import ApplicationService
|
||||
from .schema import (
|
||||
ApplicationCreateSchema,
|
||||
ApplicationUpdateSchema,
|
||||
ApplicationQueryParam
|
||||
)
|
||||
|
||||
|
||||
MyAppRouter = APIRouter(route_class=OperationLogRoute, prefix="/myapp", tags=["应用管理"])
|
||||
|
||||
@MyAppRouter.get("/detail/{id}", summary="获取应用详情", description="获取应用详情")
|
||||
async def get_obj_detail_controller(
|
||||
id: int = Path(..., description="应用ID"),
|
||||
auth: AuthSchema = Depends(AuthPermission(["module_application:myapp:query"]))
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
获取应用详情
|
||||
|
||||
参数:
|
||||
- id (int): 应用ID
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含应用详情的JSON响应
|
||||
"""
|
||||
result_dict = await ApplicationService.detail_service(id=id, auth=auth)
|
||||
log.info(f"获取应用详情成功 {id}")
|
||||
return SuccessResponse(data=result_dict, msg="获取应用详情成功")
|
||||
|
||||
@MyAppRouter.get("/list", summary="查询应用列表", description="查询应用列表")
|
||||
async def get_obj_list_controller(
|
||||
page: PaginationQueryParam = Depends(),
|
||||
search: ApplicationQueryParam = Depends(),
|
||||
auth: AuthSchema = Depends(AuthPermission(["module_application:myapp:query"]))
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
查询应用列表
|
||||
|
||||
参数:
|
||||
- page (PaginationQueryParam): 分页参数模型
|
||||
- search (ApplicationQueryParam): 查询参数模型
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含应用列表的JSON响应
|
||||
"""
|
||||
result_dict_list = await ApplicationService.list_service(auth=auth, search=search, order_by=page.order_by)
|
||||
result_dict = await PaginationService.paginate(data_list=result_dict_list, page_no=page.page_no, page_size=page.page_size)
|
||||
log.info(f"查询应用列表成功")
|
||||
return SuccessResponse(data=result_dict, msg="查询应用列表成功")
|
||||
|
||||
@MyAppRouter.post("/create", summary="创建应用", description="创建应用")
|
||||
async def create_obj_controller(
|
||||
data: ApplicationCreateSchema,
|
||||
auth: AuthSchema = Depends(AuthPermission(["module_application:myapp:create"]))
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
创建应用
|
||||
|
||||
参数:
|
||||
- data (ApplicationCreateSchema): 应用创建模型
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含创建应用详情的JSON响应
|
||||
"""
|
||||
result_dict = await ApplicationService.create_service(auth=auth, data=data)
|
||||
log.info(f"创建应用成功: {result_dict}")
|
||||
return SuccessResponse(data=result_dict, msg="创建应用成功")
|
||||
|
||||
@MyAppRouter.put("/update/{id}", summary="修改应用", description="修改应用")
|
||||
async def update_obj_controller(
|
||||
data: ApplicationUpdateSchema,
|
||||
id: int = Path(..., description="应用ID"),
|
||||
auth: AuthSchema = Depends(AuthPermission(["module_application:myapp:update"]))
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
修改应用
|
||||
|
||||
参数:
|
||||
- data (ApplicationUpdateSchema): 应用更新模型
|
||||
- id (int): 应用ID
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含修改应用详情的JSON响应
|
||||
"""
|
||||
result_dict = await ApplicationService.update_service(auth=auth, id=id, data=data)
|
||||
log.info(f"修改应用成功: {result_dict}")
|
||||
return SuccessResponse(data=result_dict, msg="修改应用成功")
|
||||
|
||||
@MyAppRouter.delete("/delete", summary="删除应用", description="删除应用")
|
||||
async def delete_obj_controller(
|
||||
ids: list[int] = Body(..., description="ID列表"),
|
||||
auth: AuthSchema = Depends(AuthPermission(["module_application:myapp:delete"]))
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
删除应用
|
||||
|
||||
参数:
|
||||
- ids (list[int]): 应用ID列表
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含删除应用详情的JSON响应
|
||||
"""
|
||||
await ApplicationService.delete_service(auth=auth, ids=ids)
|
||||
log.info(f"删除应用成功: {ids}")
|
||||
return SuccessResponse(msg="删除应用成功")
|
||||
|
||||
@MyAppRouter.patch("/available/setting", summary="批量修改应用状态", description="批量修改应用状态")
|
||||
async def batch_set_available_obj_controller(
|
||||
data: BatchSetAvailable,
|
||||
auth: AuthSchema = Depends(AuthPermission(["module_application:myapp:patch"]))
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
批量修改应用状态
|
||||
|
||||
参数:
|
||||
- data (BatchSetAvailable): 批量修改应用状态模型
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
|
||||
返回:
|
||||
- JSONResponse: 批量修改应用状态成功
|
||||
"""
|
||||
await ApplicationService.set_available_service(auth=auth, data=data)
|
||||
log.info(f"批量修改应用状态成功: {data.ids}")
|
||||
return SuccessResponse(msg="批量修改应用状态成功")
|
||||
@@ -0,0 +1,101 @@
|
||||
'''
|
||||
Author: caoziyuan ziyuan.cao@zhuying.com
|
||||
Date: 2025-12-15 17:37:50
|
||||
LastEditors: caoziyuan ziyuan.cao@zhuying.com
|
||||
LastEditTime: 2025-12-22 17:26:54
|
||||
FilePath: \backend\app\api\v1\module_application\myapp\crud.py
|
||||
Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE
|
||||
'''
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Sequence, Any
|
||||
|
||||
from app.core.base_crud import CRUDBase
|
||||
from app.api.v1.module_system.auth.schema import AuthSchema
|
||||
from .model import ApplicationModel
|
||||
from .schema import ApplicationCreateSchema, ApplicationUpdateSchema
|
||||
|
||||
|
||||
class ApplicationCRUD(CRUDBase[ApplicationModel, ApplicationCreateSchema, ApplicationUpdateSchema]):
|
||||
"""应用系统数据层"""
|
||||
|
||||
def __init__(self, auth: AuthSchema) -> None:
|
||||
"""
|
||||
初始化应用CRUD
|
||||
|
||||
参数:
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
"""
|
||||
self.auth = auth
|
||||
super().__init__(model=ApplicationModel, auth=auth)
|
||||
|
||||
async def get_by_id_crud(self, id: int, preload: list[str | Any] | None = None) -> ApplicationModel | None:
|
||||
"""
|
||||
根据id获取应用详情
|
||||
|
||||
参数:
|
||||
- id (int): 应用ID
|
||||
- preload (list[str | Any] | None): 预加载关系,未提供时使用模型默认项
|
||||
|
||||
返回:
|
||||
- ApplicationModel | None: 应用详情,如果不存在则为None
|
||||
"""
|
||||
return await self.get(id=id, preload=preload)
|
||||
|
||||
async def list_crud(self, search: dict[str, Any] | None = None, order_by: list[dict[str, str]] | None = None, preload: list[str | Any] | None = None) -> Sequence[ApplicationModel]:
|
||||
"""
|
||||
列表查询应用
|
||||
|
||||
参数:
|
||||
- search (dict[str, Any] | None): 查询参数,默认None
|
||||
- order_by (list[dict[str, str]] | None): 排序参数,默认None
|
||||
- preload (list[str | Any] | None): 预加载关系,未提供时使用模型默认项
|
||||
|
||||
返回:
|
||||
- Sequence[ApplicationModel]: 应用列表
|
||||
"""
|
||||
return await self.list(search=search, order_by=order_by, preload=preload)
|
||||
|
||||
async def create_crud(self, data: ApplicationCreateSchema) -> ApplicationModel | None:
|
||||
"""
|
||||
创建应用
|
||||
|
||||
参数:
|
||||
- data (ApplicationCreateSchema): 应用创建模型
|
||||
|
||||
返回:
|
||||
- ApplicationModel | None: 创建的应用详情,如果创建失败则为None
|
||||
"""
|
||||
return await self.create(data=data)
|
||||
|
||||
async def update_crud(self, id: int, data: ApplicationUpdateSchema) -> ApplicationModel | None:
|
||||
"""
|
||||
更新应用
|
||||
|
||||
参数:
|
||||
- id (int): 应用ID
|
||||
- data (ApplicationUpdateSchema): 应用更新模型
|
||||
|
||||
返回:
|
||||
- ApplicationModel | None: 更新后的应用详情,如果更新失败则为None
|
||||
"""
|
||||
return await self.update(id=id, data=data)
|
||||
|
||||
async def delete_crud(self, ids: list[int]) -> None:
|
||||
"""
|
||||
批量删除应用
|
||||
|
||||
参数:
|
||||
- ids (list[int]): 应用ID列表
|
||||
"""
|
||||
return await self.delete(ids=ids)
|
||||
|
||||
async def set_available_crud(self, ids: list[int], status: str) -> None:
|
||||
"""
|
||||
批量设置可用状态
|
||||
|
||||
参数:
|
||||
- ids (list[int]): 应用ID列表
|
||||
- status (str): 可用状态,True为可用,False为不可用
|
||||
"""
|
||||
return await self.set(ids=ids, status=status)
|
||||
@@ -0,0 +1,19 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.core.base_model import ModelMixin, UserMixin
|
||||
|
||||
|
||||
class ApplicationModel(ModelMixin, UserMixin):
|
||||
"""
|
||||
应用系统表
|
||||
"""
|
||||
__tablename__: str = 'app_myapp'
|
||||
__table_args__: dict[str, str] = ({'comment': '应用系统表'})
|
||||
__loader_options__: list[str] = ["created_by", "updated_by"]
|
||||
|
||||
name: Mapped[str] = mapped_column(String(64), nullable=False, comment='应用名称')
|
||||
access_url: Mapped[str] = mapped_column(String(500), nullable=False, comment='访问地址')
|
||||
icon_url: Mapped[str | None] = mapped_column(String(300), nullable=True, comment='应用图标URL')
|
||||
@@ -0,0 +1,79 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from urllib.parse import urlparse
|
||||
from fastapi import Query
|
||||
|
||||
from app.core.validator import DateTimeStr
|
||||
from app.core.base_schema import BaseSchema, UserBySchema
|
||||
|
||||
|
||||
class ApplicationCreateSchema(BaseModel):
|
||||
"""应用创建模型"""
|
||||
name: str = Field(..., max_length=64, description='应用名称')
|
||||
access_url: str = Field(..., max_length=255, description="访问地址")
|
||||
icon_url: str | None = Field(None, max_length=300, description="应用图标URL")
|
||||
status: str = Field("0", description="是否启用(0:启用 1:禁用)")
|
||||
description: str | None = Field(default=None, max_length=255, description="描述")
|
||||
|
||||
@field_validator('access_url')
|
||||
@classmethod
|
||||
def _validate_access_url(cls, v: str) -> str:
|
||||
v = v.strip()
|
||||
if not v:
|
||||
raise ValueError('访问地址不能为空')
|
||||
parsed = urlparse(v)
|
||||
if parsed.scheme not in ('http', 'https'):
|
||||
raise ValueError('访问地址必须为 http/https URL')
|
||||
return v
|
||||
|
||||
@field_validator('icon_url')
|
||||
@classmethod
|
||||
def _validate_icon_url(cls, v: str | None) -> str | None:
|
||||
if v is None:
|
||||
return v
|
||||
v = v.strip()
|
||||
if v == "":
|
||||
return None
|
||||
parsed = urlparse(v)
|
||||
if parsed.scheme not in ('http', 'https'):
|
||||
raise ValueError('应用图标URL必须为 http/https URL')
|
||||
return v
|
||||
|
||||
|
||||
class ApplicationUpdateSchema(ApplicationCreateSchema):
|
||||
"""应用更新模型"""
|
||||
...
|
||||
|
||||
|
||||
class ApplicationOutSchema(ApplicationCreateSchema, BaseSchema, UserBySchema):
|
||||
"""应用响应模型"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ApplicationQueryParam:
|
||||
"""应用系统查询参数"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str | None = Query(None, description="应用名称"),
|
||||
status: str | None = Query(None, description="是否启用"),
|
||||
created_time: list[DateTimeStr] | None = Query(None, description="创建时间范围", examples=["2025-01-01 00:00:00", "2025-12-31 23:59:59"]),
|
||||
updated_time: list[DateTimeStr] | None = Query(None, description="更新时间范围", examples=["2025-01-01 00:00:00", "2025-12-31 23:59:59"]),
|
||||
created_id: int | None = Query(None, description="创建人"),
|
||||
updated_id: int | None = Query(None, description="更新人"),
|
||||
) -> None:
|
||||
|
||||
# 模糊查询字段
|
||||
self.name = ("like", name) if name else None
|
||||
|
||||
# 精确查询字段
|
||||
self.status = status
|
||||
self.created_id = created_id
|
||||
self.updated_id = updated_id
|
||||
|
||||
# 时间范围查询
|
||||
if created_time and len(created_time) == 2:
|
||||
self.created_time = ("between", (created_time[0], created_time[1]))
|
||||
if updated_time and len(updated_time) == 2:
|
||||
self.updated_time = ("between", (updated_time[0], updated_time[1]))
|
||||
@@ -0,0 +1,133 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from app.core.base_schema import BatchSetAvailable
|
||||
from app.core.exceptions import CustomException
|
||||
|
||||
from app.api.v1.module_system.auth.schema import AuthSchema
|
||||
from .schema import (
|
||||
ApplicationCreateSchema,
|
||||
ApplicationUpdateSchema,
|
||||
ApplicationOutSchema,
|
||||
ApplicationQueryParam
|
||||
)
|
||||
from .crud import ApplicationCRUD
|
||||
|
||||
|
||||
class ApplicationService:
|
||||
"""
|
||||
应用系统管理服务层
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
async def detail_service(cls, auth: AuthSchema, id: int) -> dict:
|
||||
"""
|
||||
获取应用详情
|
||||
|
||||
参数:
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
- id (int): 应用ID
|
||||
|
||||
返回:
|
||||
- dict: 应用详情字典
|
||||
"""
|
||||
obj = await ApplicationCRUD(auth).get_by_id_crud(id=id)
|
||||
if not obj:
|
||||
raise CustomException(msg='应用不存在')
|
||||
return ApplicationOutSchema.model_validate(obj).model_dump()
|
||||
|
||||
@classmethod
|
||||
async def list_service(cls, auth: AuthSchema, search: ApplicationQueryParam | None = None, order_by: list[dict[str, str]] | None = None) -> list[dict]:
|
||||
"""
|
||||
获取应用列表
|
||||
|
||||
参数:
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
- search (ApplicationQueryParam | None): 查询参数模型
|
||||
- order_by (list[dict[str, str]] | None): 排序参数,支持字符串或字典列表
|
||||
|
||||
返回:
|
||||
- list[dict]: 应用详情字典列表
|
||||
"""
|
||||
# 过滤空值
|
||||
search_dict = search.__dict__ if search else None
|
||||
obj_list = await ApplicationCRUD(auth).list_crud(search=search_dict, order_by=order_by)
|
||||
return [ApplicationOutSchema.model_validate(obj).model_dump() for obj in obj_list]
|
||||
|
||||
@classmethod
|
||||
async def create_service(cls, auth: AuthSchema, data: ApplicationCreateSchema) -> dict:
|
||||
"""
|
||||
创建应用
|
||||
|
||||
参数:
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
- data (ApplicationCreateSchema): 应用创建模型
|
||||
|
||||
返回:
|
||||
- Dict: 应用详情字典
|
||||
"""
|
||||
# 检查名称是否重复
|
||||
obj = await ApplicationCRUD(auth).get(name=data.name)
|
||||
if obj:
|
||||
raise CustomException(msg='创建失败,应用名称已存在')
|
||||
|
||||
obj = await ApplicationCRUD(auth).create_crud(data=data)
|
||||
return ApplicationOutSchema.model_validate(obj).model_dump()
|
||||
|
||||
@classmethod
|
||||
async def update_service(cls, auth: AuthSchema, id: int, data: ApplicationUpdateSchema) -> dict:
|
||||
"""
|
||||
更新应用
|
||||
|
||||
参数:
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
- id (int): 应用ID
|
||||
- data (ApplicationUpdateSchema): 应用更新模型
|
||||
|
||||
返回:
|
||||
- Dict: 应用详情字典
|
||||
"""
|
||||
obj = await ApplicationCRUD(auth).get_by_id_crud(id=id)
|
||||
if not obj:
|
||||
raise CustomException(msg='更新失败,该应用不存在')
|
||||
|
||||
# 检查名称重复
|
||||
exist_obj = await ApplicationCRUD(auth).get(name=data.name)
|
||||
if exist_obj and exist_obj.id != id:
|
||||
raise CustomException(msg='更新失败,应用名称重复')
|
||||
|
||||
obj = await ApplicationCRUD(auth).update_crud(id=id, data=data)
|
||||
return ApplicationOutSchema.model_validate(obj).model_dump()
|
||||
|
||||
@classmethod
|
||||
async def delete_service(cls, auth: AuthSchema, ids: list[int]) -> None:
|
||||
"""
|
||||
删除应用
|
||||
|
||||
参数:
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
- ids (list[int]): 应用ID列表
|
||||
|
||||
返回:
|
||||
- None
|
||||
"""
|
||||
if len(ids) < 1:
|
||||
raise CustomException(msg='删除失败,删除对象不能为空')
|
||||
for id in ids:
|
||||
obj = await ApplicationCRUD(auth).get_by_id_crud(id=id)
|
||||
if not obj:
|
||||
raise CustomException(msg=f'删除失败,应用 {id} 不存在')
|
||||
await ApplicationCRUD(auth).delete_crud(ids=ids)
|
||||
|
||||
@classmethod
|
||||
async def set_available_service(cls, auth: AuthSchema, data: BatchSetAvailable) -> None:
|
||||
"""
|
||||
批量设置应用状态
|
||||
|
||||
参数:
|
||||
- auth (AuthSchema): 认证信息模型
|
||||
- data (BatchSetAvailable): 批量设置应用状态模型
|
||||
|
||||
返回:
|
||||
- None
|
||||
"""
|
||||
await ApplicationCRUD(auth).set_available_crud(ids=data.ids, status=data.status)
|
||||
@@ -0,0 +1,2 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# 工作流接口-开发中...
|
||||
Reference in New Issue
Block a user