1232 lines
53 KiB
Python
1232 lines
53 KiB
Python
# -*- coding: utf-8 -*-
|
||
|
||
import re
|
||
import asyncio
|
||
import aiofiles
|
||
from pathlib import Path
|
||
from typing import Any, AsyncGenerator, List
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
|
||
from fastapi import UploadFile
|
||
|
||
from app.core.exceptions import CustomException
|
||
from app.core.logger import log
|
||
from app.core.database import async_db_session
|
||
from app.config.setting import settings
|
||
from app.api.v1.module_system.auth.schema import AuthSchema
|
||
from .tools.ai_util import AIClient
|
||
# 延迟导入向量化工具,避免影响路由加载
|
||
# from .tools.document_processor import DocumentProcessor
|
||
# from .tools.embedding_util import EmbeddingUtil
|
||
from .schema import (
|
||
McpCreateSchema, McpUpdateSchema, McpOutSchema, ChatQuerySchema, McpQueryParam,
|
||
AIProviderCreateSchema, AIProviderUpdateSchema, AIProviderOutSchema, AIProviderQueryParam,
|
||
EmbeddingConfigCreateSchema, EmbeddingConfigUpdateSchema, EmbeddingConfigOutSchema, EmbeddingConfigQueryParam,
|
||
KnowledgeBaseCreateSchema, KnowledgeBaseUpdateSchema, KnowledgeBaseOutSchema, KnowledgeBaseQueryParam,
|
||
AIModelConfigCreateSchema, AIModelConfigUpdateSchema, AIModelConfigOutSchema, AIModelConfigQueryParam,
|
||
AIModelTrainingMessageCreateSchema, AIModelTrainingMessageOutSchema, AIModelTrainingChatSchema,
|
||
AI_MODEL_TYPES
|
||
)
|
||
from .crud import McpCRUD, AIProviderCRUD, EmbeddingConfigCRUD, KnowledgeBaseCRUD, AIModelConfigCRUD, AIModelTrainingMessageCRUD
|
||
from .model import EmbeddingConfigModel, AIModelConfigModel
|
||
|
||
|
||
class McpService:
    """Service layer for MCP server records and MCP-backed chat."""

    @classmethod
    async def detail_service(cls, auth: AuthSchema, id: int) -> dict[str, Any]:
        """
        Fetch one MCP server by primary key.

        Parameters:
        - auth (AuthSchema): Authentication context handed to the CRUD layer.
        - id (int): MCP server ID.

        Returns:
        - dict[str, Any]: The server serialized through McpOutSchema.

        Raises:
        - CustomException: No server with this ID was found.
        """
        record = await McpCRUD(auth).get_by_id_crud(id=id)
        if record:
            return McpOutSchema.model_validate(record).model_dump()
        raise CustomException(msg='MCP 服务器不存在')

    @classmethod
    async def list_service(cls, auth: AuthSchema, search: McpQueryParam | None = None, order_by: list[dict[str, str]] | None = None) -> list[dict[str, Any]]:
        """
        List MCP servers.

        Parameters:
        - auth (AuthSchema): Authentication context handed to the CRUD layer.
        - search (McpQueryParam | None): Optional filter object; its attribute
          dict is forwarded as the search criteria.
        - order_by (list[dict[str, str]] | None): Optional ordering spec.

        Returns:
        - list[dict[str, Any]]: One serialized dict per matching server.
        """
        filters = None
        if search:
            filters = search.__dict__
        records = await McpCRUD(auth).get_list_crud(search=filters, order_by=order_by)
        serialized = []
        for record in records:
            serialized.append(McpOutSchema.model_validate(record).model_dump())
        return serialized

    @classmethod
    async def create_service(cls, auth: AuthSchema, data: McpCreateSchema) -> dict[str, Any]:
        """
        Create an MCP server after checking that its name is not taken.

        Parameters:
        - auth (AuthSchema): Authentication context handed to the CRUD layer.
        - data (McpCreateSchema): Payload describing the new server.

        Returns:
        - dict[str, Any]: The created server serialized through McpOutSchema.

        Raises:
        - CustomException: A server with the same name already exists.
        """
        duplicate = await McpCRUD(auth).get_by_name_crud(name=data.name)
        if duplicate:
            raise CustomException(msg='创建失败,MCP 服务器已存在')
        created = await McpCRUD(auth).create_crud(data=data)
        return McpOutSchema.model_validate(created).model_dump()

    @classmethod
    async def update_service(cls, auth: AuthSchema, id: int, data: McpUpdateSchema) -> dict[str, Any]:
        """
        Update an existing MCP server.

        Parameters:
        - auth (AuthSchema): Authentication context handed to the CRUD layer.
        - id (int): ID of the server to update.
        - data (McpUpdateSchema): New field values.

        Returns:
        - dict[str, Any]: The updated server serialized through McpOutSchema.

        Raises:
        - CustomException: The server does not exist, or another server
          already uses the requested name.
        """
        current = await McpCRUD(auth).get_by_id_crud(id=id)
        if not current:
            raise CustomException(msg='更新失败,该数据不存在')

        same_name = await McpCRUD(auth).get_by_name_crud(name=data.name)
        # A name hit is only a conflict when it belongs to a different row.
        if same_name and same_name.id != id:
            raise CustomException(msg='更新失败,MCP 服务器名称重复')

        updated = await McpCRUD(auth).update_crud(id=id, data=data)
        return McpOutSchema.model_validate(updated).model_dump()

    @classmethod
    async def delete_service(cls, auth: AuthSchema, ids: list[int]) -> None:
        """
        Delete several MCP servers in one call.

        Parameters:
        - auth (AuthSchema): Authentication context handed to the CRUD layer.
        - ids (list[int]): IDs of the servers to delete.

        Returns:
        - None

        Raises:
        - CustomException: `ids` is empty, or one of the IDs does not exist.
        """
        if len(ids) == 0:
            raise CustomException(msg='删除失败,删除对象不能为空')
        # Every ID is looked up before the bulk delete is issued.
        for server_id in ids:
            found = await McpCRUD(auth).get_by_id_crud(id=server_id)
            if not found:
                raise CustomException(msg='删除失败,该数据不存在')
        await McpCRUD(auth).delete_crud(ids=ids)

    @classmethod
    async def chat_query(cls, query: ChatQuerySchema) -> AsyncGenerator[str, Any]:
        """
        Stream the answer to a chat query.

        Parameters:
        - query (ChatQuerySchema): Chat request; only `message` is read here.

        Returns:
        - AsyncGenerator[str, Any]: Yields each chunk produced by the AI client.
        """
        ai_client = AIClient()
        try:
            async for chunk in ai_client.process(query.message):
                yield chunk
        finally:
            # Close the client whether the stream finished, failed or was abandoned.
            await ai_client.close()
|
||
|
||
|
||
class AIProviderService:
    """Service layer for AI provider records."""

    @classmethod
    async def detail_service(cls, auth: AuthSchema, id: int) -> dict[str, Any]:
        """
        Fetch one AI provider by primary key.

        Raises:
        - CustomException: No provider with this ID was found.
        """
        provider = await AIProviderCRUD(auth).get_by_id_crud(id=id)
        if provider:
            return AIProviderOutSchema.model_validate(provider).model_dump()
        raise CustomException(msg='AI供应商不存在')

    @classmethod
    async def list_service(cls, auth: AuthSchema, search: AIProviderQueryParam | None = None, order_by: list[dict[str, str]] | None = None) -> list[dict[str, Any]]:
        """
        List AI providers, optionally filtered and ordered.

        The attribute dict of `search` (when given) is forwarded as the
        search criteria.
        """
        filters = None
        if search:
            filters = search.__dict__
        providers = await AIProviderCRUD(auth).get_list_crud(search=filters, order_by=order_by)
        serialized = []
        for provider in providers:
            serialized.append(AIProviderOutSchema.model_validate(provider).model_dump())
        return serialized

    @classmethod
    async def create_service(cls, auth: AuthSchema, data: AIProviderCreateSchema) -> dict[str, Any]:
        """
        Create an AI provider.

        Raises:
        - CustomException: A provider with the same name already exists.
        """
        duplicate = await AIProviderCRUD(auth).get_by_name_crud(name=data.name)
        if duplicate:
            raise CustomException(msg='创建失败,AI供应商名称已存在')

        # A new default provider replaces the previous one, so clear the flag first.
        wants_default = data.is_default == 1
        if wants_default:
            await AIProviderCRUD(auth).clear_default_crud()

        created = await AIProviderCRUD(auth).create_crud(data=data)
        return AIProviderOutSchema.model_validate(created).model_dump()

    @classmethod
    async def update_service(cls, auth: AuthSchema, id: int, data: AIProviderUpdateSchema) -> dict[str, Any]:
        """
        Update an AI provider.

        Raises:
        - CustomException: The provider does not exist, or another provider
          already uses the requested name.
        """
        current = await AIProviderCRUD(auth).get_by_id_crud(id=id)
        if not current:
            raise CustomException(msg='更新失败,该数据不存在')

        same_name = await AIProviderCRUD(auth).get_by_name_crud(name=data.name)
        # A name hit is only a conflict when it belongs to a different row.
        if same_name and same_name.id != id:
            raise CustomException(msg='更新失败,AI供应商名称重复')

        # Becoming the default means the previous default flag is cleared first.
        wants_default = data.is_default == 1
        if wants_default:
            await AIProviderCRUD(auth).clear_default_crud()

        updated = await AIProviderCRUD(auth).update_crud(id=id, data=data)
        return AIProviderOutSchema.model_validate(updated).model_dump()

    @classmethod
    async def delete_service(cls, auth: AuthSchema, ids: list[int]) -> None:
        """
        Delete several AI providers in one call.

        Raises:
        - CustomException: `ids` is empty, or one of the IDs does not exist.
        """
        if len(ids) == 0:
            raise CustomException(msg='删除失败,删除对象不能为空')
        # Every ID is looked up before the bulk delete is issued.
        for provider_id in ids:
            found = await AIProviderCRUD(auth).get_by_id_crud(id=provider_id)
            if not found:
                raise CustomException(msg='删除失败,该数据不存在')
        await AIProviderCRUD(auth).delete_crud(ids=ids)
|
||
|
||
|
||
class EmbeddingConfigService:
    """Service layer for embedding (vectorization) configurations."""

    @staticmethod
    def _require_remote_credentials(data: Any) -> None:
        """Reject a remote-mode payload (embedding_type == 1) lacking base_url or api_key."""
        if data.embedding_type == 1 and not (data.base_url and data.api_key):
            raise CustomException(msg='远程模式必须提供接口地址和API Key')

    @classmethod
    async def detail_service(cls, auth: AuthSchema, id: int) -> dict[str, Any]:
        """
        Fetch one embedding configuration by primary key.

        Raises:
        - CustomException: No configuration with this ID was found.
        """
        config = await EmbeddingConfigCRUD(auth).get_by_id_crud(id=id)
        if config:
            return EmbeddingConfigOutSchema.model_validate(config).model_dump()
        raise CustomException(msg='向量化配置不存在')

    @classmethod
    async def list_service(cls, auth: AuthSchema, search: EmbeddingConfigQueryParam | None = None, order_by: list[dict[str, str]] | None = None) -> list[dict[str, Any]]:
        """
        List embedding configurations, optionally filtered and ordered.

        The attribute dict of `search` (when given) is forwarded as the
        search criteria.
        """
        filters = None
        if search:
            filters = search.__dict__
        configs = await EmbeddingConfigCRUD(auth).get_list_crud(search=filters, order_by=order_by)
        serialized = []
        for config in configs:
            serialized.append(EmbeddingConfigOutSchema.model_validate(config).model_dump())
        return serialized

    @classmethod
    async def create_service(cls, auth: AuthSchema, data: EmbeddingConfigCreateSchema) -> dict[str, Any]:
        """
        Create an embedding configuration.

        Raises:
        - CustomException: The name is already taken, or a remote-mode
          payload lacks the endpoint URL or API key.
        """
        duplicate = await EmbeddingConfigCRUD(auth).get_by_name_crud(name=data.name)
        if duplicate:
            raise CustomException(msg='创建失败,向量化配置名称已存在')

        cls._require_remote_credentials(data)

        # A new default configuration replaces the previous one.
        wants_default = data.is_default == 1
        if wants_default:
            await EmbeddingConfigCRUD(auth).clear_default_crud()

        created = await EmbeddingConfigCRUD(auth).create_crud(data=data)
        return EmbeddingConfigOutSchema.model_validate(created).model_dump()

    @classmethod
    async def update_service(cls, auth: AuthSchema, id: int, data: EmbeddingConfigUpdateSchema) -> dict[str, Any]:
        """
        Update an embedding configuration.

        Raises:
        - CustomException: The configuration does not exist, the name
          collides with another row, or a remote-mode payload lacks the
          endpoint URL or API key.
        """
        current = await EmbeddingConfigCRUD(auth).get_by_id_crud(id=id)
        if not current:
            raise CustomException(msg='更新失败,该数据不存在')

        same_name = await EmbeddingConfigCRUD(auth).get_by_name_crud(name=data.name)
        # A name hit is only a conflict when it belongs to a different row.
        if same_name and same_name.id != id:
            raise CustomException(msg='更新失败,向量化配置名称重复')

        cls._require_remote_credentials(data)

        # Becoming the default means the previous default flag is cleared first.
        wants_default = data.is_default == 1
        if wants_default:
            await EmbeddingConfigCRUD(auth).clear_default_crud()

        updated = await EmbeddingConfigCRUD(auth).update_crud(id=id, data=data)
        return EmbeddingConfigOutSchema.model_validate(updated).model_dump()

    @classmethod
    async def delete_service(cls, auth: AuthSchema, ids: list[int]) -> None:
        """
        Delete several embedding configurations in one call.

        Raises:
        - CustomException: `ids` is empty, or one of the IDs does not exist.
        """
        if len(ids) == 0:
            raise CustomException(msg='删除失败,删除对象不能为空')
        # Every ID is looked up before the bulk delete is issued.
        for config_id in ids:
            found = await EmbeddingConfigCRUD(auth).get_by_id_crud(id=config_id)
            if not found:
                raise CustomException(msg='删除失败,该数据不存在')
        await EmbeddingConfigCRUD(auth).delete_crud(ids=ids)
|
||
|
||
|
||
class KnowledgeBaseService:
    """
    Service layer for knowledge bases.

    Besides plain CRUD it stores uploaded files on disk and runs the
    vectorization of those files as a background asyncio task.
    Status codes used in `kb_status`: 0=pending, 1=processing, 2=done, 3=failed.
    """

    # Strong references to running vectorization tasks. The event loop keeps
    # only weak references to tasks, so a task nobody references may be
    # garbage collected before it finishes.
    _background_tasks: set[asyncio.Task] = set()

    # Upper bound applied to generated collection names.
    _MAX_COLLECTION_NAME_LEN = 63

    @staticmethod
    def _generate_collection_name(name: str) -> str:
        """
        Build a ChromaDB collection name from a knowledge base name.

        Every character outside [a-zA-Z0-9_] becomes '_' and a Unix timestamp
        (seconds) is appended. The sanitized part is truncated so the whole
        name never exceeds `_MAX_COLLECTION_NAME_LEN` characters.
        """
        import time

        prefix = "kb_"
        suffix = f"_{int(time.time())}"
        clean_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
        room = KnowledgeBaseService._MAX_COLLECTION_NAME_LEN - len(prefix) - len(suffix)
        return f"{prefix}{clean_name[:room]}{suffix}"

    @staticmethod
    def _safe_filename(filename: str | None) -> str | None:
        """
        Reduce a client-supplied filename to its final path component.

        Returns None when nothing usable remains (empty, '.', '..'), so the
        caller can skip the upload instead of writing outside its directory.
        """
        if not filename:
            return None
        # Treat both separators as path separators regardless of the host OS.
        base = filename.replace('\\', '/').rsplit('/', 1)[-1].strip()
        if base in ('', '.', '..'):
            return None
        return base

    @staticmethod
    def _embedding_info_from(embedding_config: Any) -> dict[str, Any] | None:
        """Copy the fields EmbeddingUtil needs out of a config record (None if no record)."""
        if not embedding_config:
            return None
        return {
            'embedding_type': embedding_config.embedding_type,
            'model_name': embedding_config.model_name,
            'base_url': embedding_config.base_url,
            'api_key': embedding_config.api_key,
        }

    @classmethod
    def _spawn_vectorization(
        cls,
        knowledge_base_id: int,
        collection_name: str,
        file_contents: List[dict],
        embedding_info: dict | None
    ) -> None:
        """Start `_process_vectorization` as a background task and keep it referenced until done."""
        task = asyncio.create_task(
            cls._process_vectorization(
                knowledge_base_id=knowledge_base_id,
                collection_name=collection_name,
                file_contents=file_contents,
                embedding_info=embedding_info
            )
        )
        cls._background_tasks.add(task)
        task.add_done_callback(cls._background_tasks.discard)

    @classmethod
    async def detail_service(cls, auth: AuthSchema, id: int) -> dict[str, Any]:
        """
        Fetch one knowledge base by primary key.

        Raises:
        - CustomException: No knowledge base with this ID was found.
        """
        obj = await KnowledgeBaseCRUD(auth).get_by_id_crud(id=id)
        if not obj:
            raise CustomException(msg='知识库不存在')
        return KnowledgeBaseOutSchema.model_validate(obj).model_dump()

    @classmethod
    async def list_service(cls, auth: AuthSchema, search: KnowledgeBaseQueryParam | None = None, order_by: list[dict[str, str]] | None = None) -> list[dict[str, Any]]:
        """List knowledge bases, optionally filtered by `search` and ordered by `order_by`."""
        search_dict = search.__dict__ if search else None
        obj_list = await KnowledgeBaseCRUD(auth).get_list_crud(search=search_dict, order_by=order_by)
        return [KnowledgeBaseOutSchema.model_validate(obj).model_dump() for obj in obj_list]

    @classmethod
    async def create_service(cls, auth: AuthSchema, data: KnowledgeBaseCreateSchema, files: List[UploadFile] | None = None) -> dict[str, Any]:
        """
        Create a knowledge base and, when files are uploaded, start vectorizing them.

        Parameters:
        - auth (AuthSchema): Authentication context handed to the CRUD layer.
        - data (KnowledgeBaseCreateSchema): Name, description and optional embedding config ID.
        - files (List[UploadFile] | None): Uploaded documents; entries without a usable
          filename are skipped.

        Returns:
        - dict[str, Any]: The created knowledge base serialized through KnowledgeBaseOutSchema.

        Raises:
        - CustomException: The name is already taken or the embedding config does not exist.
        """
        if await KnowledgeBaseCRUD(auth).get_by_name_crud(name=data.name):
            raise CustomException(msg='创建失败,知识库名称已存在')

        # The referenced embedding configuration must exist.
        embedding_config = None
        if data.embedding_config_id:
            embedding_config = await EmbeddingConfigCRUD(auth).get_by_id_crud(id=data.embedding_config_id)
            if not embedding_config:
                raise CustomException(msg='向量化配置不存在')

        collection_name = cls._generate_collection_name(data.name)

        # Persist the uploads so a failed vectorization can be retried later.
        file_paths: list[str] = []
        file_contents: list[dict[str, Any]] = []
        if files:
            kb_dir = settings.UPLOAD_FILE_PATH / "knowledge_base" / collection_name
            kb_dir.mkdir(parents=True, exist_ok=True)

            for file in files:
                if not file.filename:
                    continue
                # The filename comes from the client: keep only its last
                # component so it cannot point outside kb_dir.
                safe_name = cls._safe_filename(file.filename)
                if safe_name is None:
                    log.warning(f"跳过非法文件名: {file.filename!r}")
                    continue
                content = await file.read()
                file_path = kb_dir / safe_name
                async with aiofiles.open(file_path, 'wb') as f:
                    await f.write(content)
                file_paths.append(str(file_path))
                file_contents.append({
                    'filename': safe_name,
                    'content': content,
                    'content_type': file.content_type
                })

        # Count and status are derived from what was actually stored: if every
        # upload was skipped no background task runs, so the record must not
        # be left in "processing".
        create_data = {
            'name': data.name,
            'embedding_config_id': data.embedding_config_id,
            'collection_name': collection_name,
            'description': data.description,
            'document_count': len(file_paths),
            'vector_count': 0,
            'kb_status': 1 if file_contents else 2,  # 1=processing, 2=done
            'file_paths': file_paths if file_paths else None
        }
        obj = await KnowledgeBaseCRUD(auth).create_crud(data=create_data)

        if file_contents:
            cls._spawn_vectorization(
                knowledge_base_id=obj.id,
                collection_name=collection_name,
                file_contents=file_contents,
                embedding_info=cls._embedding_info_from(embedding_config)
            )

        return KnowledgeBaseOutSchema.model_validate(obj).model_dump()

    @staticmethod
    def _load_documents(file_contents: List[dict]) -> list:
        """
        Turn raw uploads into langchain `Document` objects.

        Handles txt/md (tries utf-8, gbk, gb2312, latin-1 in that order),
        pdf (one document per non-empty page) and docx (non-empty paragraphs
        joined by newlines). Other extensions are ignored. A file that fails
        to parse is logged and skipped; it does not abort the batch.
        """
        from langchain.schema import Document

        documents = []
        for file_info in file_contents:
            filename = file_info['filename']
            content = file_info['content']
            if not filename:
                continue

            ext = filename.rsplit('.', 1)[-1].lower() if '.' in filename else ''

            try:
                if ext in ('txt', 'md'):
                    text = None
                    for encoding in ('utf-8', 'gbk', 'gb2312', 'latin-1'):
                        try:
                            text = content.decode(encoding)
                            break
                        except UnicodeDecodeError:
                            continue
                    if text:
                        documents.append(Document(
                            page_content=text,
                            metadata={"source": filename, "type": ext}
                        ))

                elif ext == 'pdf':
                    try:
                        import pypdf
                        from io import BytesIO

                        reader = pypdf.PdfReader(BytesIO(content))
                        for page_num, page in enumerate(reader.pages):
                            text = page.extract_text()
                            if text and text.strip():
                                documents.append(Document(
                                    page_content=text,
                                    metadata={
                                        "source": filename,
                                        "type": "pdf",
                                        "page": page_num + 1
                                    }
                                ))
                    except Exception as e:
                        log.error(f"处理 PDF 文件失败: {filename}, 错误: {e}")

                elif ext == 'docx':
                    try:
                        from docx import Document as DocxDocument
                        from io import BytesIO

                        doc = DocxDocument(BytesIO(content))
                        paragraphs = [para.text for para in doc.paragraphs if para.text.strip()]
                        text = '\n'.join(paragraphs)
                        if text.strip():
                            documents.append(Document(
                                page_content=text,
                                metadata={"source": filename, "type": "docx"}
                            ))
                    except Exception as e:
                        log.error(f"处理 DOCX 文件失败: {filename}, 错误: {e}")

            except Exception as e:
                log.error(f"处理文件失败: {filename}, 错误: {e}")

        return documents

    @classmethod
    async def _process_vectorization(
        cls,
        knowledge_base_id: int,
        collection_name: str,
        file_contents: List[dict],
        embedding_info: dict | None
    ) -> None:
        """
        Background task: parse the files, embed them and record the outcome.

        Parameters:
        - knowledge_base_id: ID of the knowledge base being processed.
        - collection_name: ChromaDB collection to add the documents to.
        - file_contents: [{'filename': str, 'content': bytes, 'content_type': str}, ...]
        - embedding_info: Embedding settings (embedding_type, model_name, base_url,
          api_key) or None to use the built-in fallback.

        Never raises for ordinary failures: any exception is logged and stored
        on the record as status 3 (failed) with the message cut to 50 characters.
        """
        try:
            # 1. Parse files. The langchain import happens inside the helper,
            #    i.e. inside this try, so a missing dependency is recorded as
            #    a failure instead of leaving the record in "processing".
            all_documents = cls._load_documents(file_contents)

            if not all_documents:
                await cls._update_kb_status(
                    knowledge_base_id=knowledge_base_id,
                    kb_status=3,  # failed
                    error_message="没有有效的文档内容"
                )
                return

            # 2. Build the embedding helper (imported lazily).
            from .tools.embedding_util import EmbeddingUtil

            if embedding_info:
                embedding_util = EmbeddingUtil(
                    embedding_type=embedding_info['embedding_type'],
                    model_name=embedding_info['model_name'],
                    base_url=embedding_info['base_url'],
                    api_key=embedding_info['api_key']
                )
            else:
                # NOTE(review): this fallback carries a placeholder API key and
                # cannot authenticate as written; knowledge bases created
                # without an embedding config will presumably end up failed.
                embedding_util = EmbeddingUtil(
                    embedding_type=1,  # remote mode
                    model_name='text-embedding-3-small',
                    base_url='https://api.openai.com/v1',
                    api_key='your-api-key'
                )

            # 3. add_documents is synchronous; run it in a worker thread so the
            #    event loop stays responsive.
            vector_count = await asyncio.to_thread(
                embedding_util.add_documents,
                collection_name,
                all_documents
            )

            # 4. Mark as done.
            await cls._update_kb_status(
                knowledge_base_id=knowledge_base_id,
                kb_status=2,  # done
                vector_count=vector_count
            )

        except Exception as e:
            log.error(f"向量化处理失败,知识库ID: {knowledge_base_id}, 错误: {e}")
            await cls._update_kb_status(
                knowledge_base_id=knowledge_base_id,
                kb_status=3,  # failed
                error_message=str(e)[:50]
            )

    @classmethod
    async def _update_kb_status(
        cls,
        knowledge_base_id: int,
        kb_status: int,
        vector_count: int = 0,
        error_message: str | None = None
    ) -> None:
        """
        Write status, vector count and (optionally) error message of a knowledge base.

        Uses its own database session, independent of any request session.

        Parameters:
        - knowledge_base_id: ID of the knowledge base.
        - kb_status: 0=pending, 1=processing, 2=done, 3=failed.
        - vector_count: Number of vectors to store (always written).
        - error_message: None leaves the stored message untouched; any string,
          including '', overwrites it (pass '' to clear a previous error).

        Failures are logged, not raised.
        """
        from .model import KnowledgeBaseModel
        from sqlalchemy import update

        try:
            async with async_db_session() as session:
                async with session.begin():
                    update_data = {
                        'kb_status': kb_status,
                        'vector_count': vector_count,
                    }
                    # `is not None` rather than truthiness: '' is how callers
                    # ask for the old error to be cleared.
                    if error_message is not None:
                        update_data['error_message'] = error_message

                    stmt = (
                        update(KnowledgeBaseModel)
                        .where(KnowledgeBaseModel.id == knowledge_base_id)
                        .values(**update_data)
                    )
                    await session.execute(stmt)
                    await session.commit()
        except Exception as e:
            log.error(f"更新知识库状态失败,ID: {knowledge_base_id}, 错误: {e}")

    @classmethod
    async def update_service(cls, auth: AuthSchema, id: int, data: KnowledgeBaseUpdateSchema) -> dict[str, Any]:
        """
        Update a knowledge base.

        Raises:
        - CustomException: The knowledge base does not exist, the new name collides
          with another row, or the referenced embedding config does not exist.
        """
        obj = await KnowledgeBaseCRUD(auth).get_by_id_crud(id=id)
        if not obj:
            raise CustomException(msg='更新失败,该数据不存在')
        if data.name:
            exist_obj = await KnowledgeBaseCRUD(auth).get_by_name_crud(name=data.name)
            if exist_obj and exist_obj.id != id:
                raise CustomException(msg='更新失败,知识库名称重复')
        if data.embedding_config_id:
            embedding_config = await EmbeddingConfigCRUD(auth).get_by_id_crud(id=data.embedding_config_id)
            if not embedding_config:
                raise CustomException(msg='向量化配置不存在')
        obj = await KnowledgeBaseCRUD(auth).update_crud(id=id, data=data)
        return KnowledgeBaseOutSchema.model_validate(obj).model_dump()

    @staticmethod
    def _remove_stored_files(obj: Any) -> None:
        """Best-effort removal of a knowledge base's files and its (empty) directory."""
        if obj.file_paths:
            for file_path in obj.file_paths:
                try:
                    path = Path(file_path)
                    if path.exists():
                        path.unlink()
                except Exception as e:
                    log.warning(f"删除文件失败: {file_path}, 错误: {e}")

        # Without a collection name the path below would resolve to the shared
        # "knowledge_base" parent directory, so skip the cleanup in that case.
        if not obj.collection_name:
            return
        try:
            kb_dir = settings.UPLOAD_FILE_PATH / "knowledge_base" / obj.collection_name
            if kb_dir.exists() and not any(kb_dir.iterdir()):
                kb_dir.rmdir()
        except Exception as e:
            log.warning(f"删除知识库目录失败: {obj.collection_name}, 错误: {e}")

    @classmethod
    async def _drop_collection(cls, auth: AuthSchema, obj: Any) -> None:
        """Best-effort removal of the ChromaDB collection that belongs to `obj`."""
        if not obj.collection_name:
            return
        try:
            from .tools.embedding_util import EmbeddingUtil

            embedding_config = None
            if obj.embedding_config_id:
                embedding_config = await EmbeddingConfigCRUD(auth).get_by_id_crud(id=obj.embedding_config_id)
            embedding_info = cls._embedding_info_from(embedding_config)

            if embedding_info:
                embedding_util = EmbeddingUtil(
                    embedding_type=embedding_info['embedding_type'],
                    model_name=embedding_info['model_name'],
                    base_url=embedding_info['base_url'],
                    api_key=embedding_info['api_key']
                )
            else:
                embedding_util = EmbeddingUtil()

            embedding_util.delete_collection(obj.collection_name)
        except Exception as e:
            log.warning(f"删除 ChromaDB 集合失败: {obj.collection_name}, 错误: {e}")

    @classmethod
    async def delete_service(cls, auth: AuthSchema, ids: list[int]) -> None:
        """
        Delete knowledge bases together with their files and vector collections.

        All IDs are validated before anything is removed, so an unknown ID
        cannot leave earlier entries with their files gone but their rows intact.
        File and collection cleanup is best effort (failures are logged).

        Raises:
        - CustomException: `ids` is empty, or one of the IDs does not exist.
        """
        if len(ids) < 1:
            raise CustomException(msg='删除失败,删除对象不能为空')

        targets = []
        for kb_id in ids:
            obj = await KnowledgeBaseCRUD(auth).get_by_id_crud(id=kb_id)
            if not obj:
                raise CustomException(msg='删除失败,该数据不存在')
            targets.append(obj)

        for obj in targets:
            cls._remove_stored_files(obj)
            await cls._drop_collection(auth, obj)

        await KnowledgeBaseCRUD(auth).delete_crud(ids=ids)

    @classmethod
    async def retry_service(cls, auth: AuthSchema, id: int) -> dict[str, Any]:
        """
        Re-run vectorization for a knowledge base whose last attempt failed.

        Parameters:
        - auth (AuthSchema): Authentication context handed to the CRUD layer.
        - id (int): Knowledge base ID.

        Returns:
        - dict[str, Any]: The knowledge base as re-read after the status reset.

        Raises:
        - CustomException: The knowledge base does not exist, is not in the
          failed state, or none of its stored files can be found on disk.
        """
        obj = await KnowledgeBaseCRUD(auth).get_by_id_crud(id=id)
        if not obj:
            raise CustomException(msg='知识库不存在')

        if obj.kb_status != 3:
            raise CustomException(msg='只有处理失败的知识库才能重试')

        if not obj.file_paths or len(obj.file_paths) == 0:
            raise CustomException(msg='没有找到已保存的文件,请删除后重新创建知识库')

        # Reload the originals from disk; missing files are skipped with a warning.
        file_contents = []
        for file_path in obj.file_paths:
            path = Path(file_path)
            if not path.exists():
                log.warning(f"文件不存在: {file_path}")
                continue
            async with aiofiles.open(path, 'rb') as f:
                content = await f.read()
            file_contents.append({
                'filename': path.name,
                'content': content,
                'content_type': None
            })

        if not file_contents:
            raise CustomException(msg='所有文件均已丢失,请删除后重新创建知识库')

        embedding_config = None
        if obj.embedding_config_id:
            embedding_config = await EmbeddingConfigCRUD(auth).get_by_id_crud(id=obj.embedding_config_id)
        embedding_info = cls._embedding_info_from(embedding_config)

        # Back to "processing"; the empty string clears the previous error.
        await cls._update_kb_status(
            knowledge_base_id=id,
            kb_status=1,  # processing
            vector_count=0,
            error_message=''
        )

        cls._spawn_vectorization(
            knowledge_base_id=id,
            collection_name=obj.collection_name,
            file_contents=file_contents,
            embedding_info=embedding_info
        )

        obj = await KnowledgeBaseCRUD(auth).get_by_id_crud(id=id)
        return KnowledgeBaseOutSchema.model_validate(obj).model_dump()
|
||
|
||
|
||
class AIModelConfigService:
    """Service layer for per-model-type AI configuration."""

    # Fields copied from an update payload onto an existing configuration.
    _UPDATABLE_FIELDS = ('model_name', 'provider_id', 'system_prompt', 'temperature', 'knowledge_base_ids')

    @staticmethod
    def _models_url(base_url: str) -> str:
        """Return the OpenAI-style `/v1/models` URL for a provider base URL."""
        root = base_url.rstrip('/')
        if not root.endswith('/v1'):
            root = f"{root}/v1"
        return f"{root}/models"

    @staticmethod
    def _parse_models(payload: Any) -> list[dict[str, Any]]:
        """
        Extract `[{'id', 'name'}, ...]` from an OpenAI-style models response.

        Entries without an `id` (or that are not objects) are dropped; `name`
        falls back to the id. The result is sorted by id.
        """
        models = []
        for entry in payload.get('data') or []:
            if not isinstance(entry, dict):
                continue
            model_id = entry.get('id', '')
            if model_id:
                models.append({'id': model_id, 'name': entry.get('name', model_id)})
        models.sort(key=lambda item: item['id'])
        return models

    @classmethod
    async def detail_service(cls, auth: AuthSchema, model_type: str) -> dict[str, Any]:
        """
        Return the configuration for a model type, creating a default one if none exists.

        Raises:
        - CustomException: `model_type` is not one of AI_MODEL_TYPES.
        """
        if model_type not in AI_MODEL_TYPES:
            raise CustomException(msg=f'无效的模型类型: {model_type}')

        obj = await AIModelConfigCRUD(auth).get_by_model_type_crud(model_type=model_type)
        if not obj:
            create_data = {
                'model_type': model_type,
                'model_name': None,
                'provider_id': None,
                'system_prompt': None,
                'temperature': 1.0,
                'knowledge_base_ids': None
            }
            obj = await AIModelConfigCRUD(auth).create_crud(data=create_data)

        return AIModelConfigOutSchema.model_validate(obj).model_dump()

    @classmethod
    async def list_service(cls, auth: AuthSchema, search: AIModelConfigQueryParam | None = None, order_by: list[dict[str, str]] | None = None) -> list[dict[str, Any]]:
        """List model configurations, optionally filtered by `search` and ordered by `order_by`."""
        search_dict = search.__dict__ if search else None
        obj_list = await AIModelConfigCRUD(auth).get_list_crud(search=search_dict, order_by=order_by)
        return [AIModelConfigOutSchema.model_validate(obj).model_dump() for obj in obj_list]

    @classmethod
    async def update_service(cls, auth: AuthSchema, model_type: str, data: AIModelConfigUpdateSchema) -> dict[str, Any]:
        """
        Update the configuration of a model type, creating it when missing.

        On an existing configuration only the fields that are not None in
        `data` are written. On creation a missing temperature defaults to 1.0.

        Raises:
        - CustomException: `model_type` is not one of AI_MODEL_TYPES.
        """
        if model_type not in AI_MODEL_TYPES:
            raise CustomException(msg=f'无效的模型类型: {model_type}')

        obj = await AIModelConfigCRUD(auth).get_by_model_type_crud(model_type=model_type)
        if not obj:
            create_data = {
                'model_type': model_type,
                'model_name': data.model_name,
                'provider_id': data.provider_id,
                'system_prompt': data.system_prompt,
                'temperature': data.temperature if data.temperature is not None else 1.0,
                'knowledge_base_ids': data.knowledge_base_ids
            }
            obj = await AIModelConfigCRUD(auth).create_crud(data=create_data)
        else:
            update_data = {
                field: getattr(data, field)
                for field in cls._UPDATABLE_FIELDS
                if getattr(data, field) is not None
            }
            if update_data:
                obj = await AIModelConfigCRUD(auth).update_crud(id=obj.id, data=update_data)

        return AIModelConfigOutSchema.model_validate(obj).model_dump()

    @classmethod
    async def get_available_models(cls, auth: AuthSchema, provider_id: int) -> list[dict[str, Any]]:
        """
        Ask a provider's OpenAI-compatible `/v1/models` endpoint which models it offers.

        Parameters:
        - auth (AuthSchema): Authentication context handed to the CRUD layer.
        - provider_id (int): ID of the AI provider to query.

        Returns:
        - list[dict[str, Any]]: `{'id', 'name'}` dicts sorted by id.

        Raises:
        - CustomException: The provider does not exist, has no base URL, or the
          remote call fails (HTTP error, network error, unexpected payload).
        """
        import httpx

        provider = await AIProviderCRUD(auth).get_by_id_crud(id=provider_id)
        if not provider:
            raise CustomException(msg='AI供应商不存在')
        # Without this guard a provider lacking a base URL fails with an
        # AttributeError instead of a readable error.
        if not provider.base_url:
            raise CustomException(msg='供应商配置不完整,请检查Base URL')

        models_url = cls._models_url(provider.base_url)

        headers = {'Content-Type': 'application/json'}
        # Only send credentials that exist; a missing key would otherwise be
        # transmitted as the literal "Bearer None".
        if provider.api_key:
            headers['Authorization'] = f'Bearer {provider.api_key}'

        try:
            # NOTE(review): verify=False turns off TLS certificate checks while
            # an API key is being sent. Kept for compatibility with existing
            # deployments; consider making it configurable.
            async with httpx.AsyncClient(timeout=30.0, verify=False) as client:
                response = await client.get(models_url, headers=headers)
                response.raise_for_status()
                return cls._parse_models(response.json())

        except httpx.HTTPStatusError as e:
            log.error(f"获取模型列表失败, HTTP错误: {e.response.status_code}, {e.response.text}")
            if e.response.status_code == 401:
                raise CustomException(msg='API Key无效或已过期') from e
            elif e.response.status_code == 403:
                raise CustomException(msg='无权限访问模型列表') from e
            else:
                raise CustomException(msg=f'获取模型列表失败: HTTP {e.response.status_code}') from e
        except httpx.RequestError as e:
            log.error(f"获取模型列表失败, 网络错误: {str(e)}")
            raise CustomException(msg=f'网络请求失败: {str(e)}') from e
        except Exception as e:
            log.error(f"获取模型列表失败: {str(e)}")
            raise CustomException(msg=f'获取模型列表失败: {str(e)}') from e
|
||
|
||
|
||
class AIModelTrainingService:
|
||
"""AI模型训练对话服务层"""
|
||
|
||
@classmethod
async def get_messages_service(cls, auth: AuthSchema, model_type: str) -> list[dict[str, Any]]:
    """
    Return the stored training conversation of one model type.

    An empty list is returned when the model type has no configuration yet.

    Raises:
    - CustomException: `model_type` is not one of AI_MODEL_TYPES.
    """
    if model_type not in AI_MODEL_TYPES:
        raise CustomException(msg=f'无效的模型类型: {model_type}')

    # Messages are keyed by the configuration row, so resolve that first.
    model_config = await AIModelConfigCRUD(auth).get_by_model_type_crud(model_type=model_type)
    if not model_config:
        return []

    records = await AIModelTrainingMessageCRUD(auth).get_list_by_config_crud(model_config_id=model_config.id)
    serialized = []
    for record in records:
        serialized.append(AIModelTrainingMessageOutSchema.model_validate(record).model_dump())
    return serialized
|
||
|
||
@classmethod
async def delete_message_service(cls, auth: AuthSchema, message_id: int) -> None:
    """
    Delete a single training conversation message.

    Raises:
    - CustomException: No message with this ID was found.
    """
    record = await AIModelTrainingMessageCRUD(auth).get_by_id_crud(id=message_id)
    if record:
        await AIModelTrainingMessageCRUD(auth).delete_crud(ids=[message_id])
        return
    raise CustomException(msg='对话记录不存在')
|
||
|
||
@classmethod
async def clear_messages_service(cls, auth: AuthSchema, model_type: str) -> None:
    """
    Remove every training conversation message of one model type.

    Does nothing when the model type has no configuration.

    Raises:
    - CustomException: `model_type` is not one of AI_MODEL_TYPES.
    """
    if model_type not in AI_MODEL_TYPES:
        raise CustomException(msg=f'无效的模型类型: {model_type}')

    model_config = await AIModelConfigCRUD(auth).get_by_model_type_crud(model_type=model_type)
    if not model_config:
        return
    await AIModelTrainingMessageCRUD(auth).delete_by_config_crud(model_config_id=model_config.id)
|
||
|
||
@classmethod
|
||
async def chat_stream(cls, auth: AuthSchema, chat_data: AIModelTrainingChatSchema) -> AsyncGenerator[str, Any]:
|
||
"""
|
||
训练对话流式输出
|
||
|
||
注意:由于 StreamingResponse 的特性,依赖注入的数据库会话在流式响应期间会被关闭,
|
||
因此需要使用独立的数据库会话来保存消息。
|
||
|
||
重要:yield 不能在 async with session.begin() 块内使用,否则会导致生成器无法正确执行
|
||
"""
|
||
from app.core.database import async_db_session
|
||
|
||
model_type = chat_data.model_type
|
||
if model_type not in AI_MODEL_TYPES:
|
||
yield f'无效的模型类型: {model_type}'
|
||
return
|
||
|
||
# 用于存储错误消息
|
||
error_message = None
|
||
# 用于存储AI调用所需的配置信息
|
||
config_id = None
|
||
config_system_prompt = None
|
||
config_model_name = None
|
||
config_temperature = None
|
||
provider_api_key = None
|
||
provider_base_url = None
|
||
messages = []
|
||
|
||
# 1. 使用独立的数据库会话处理所有数据库操作(不在此块内yield)
|
||
async with async_db_session() as session:
|
||
async with session.begin():
|
||
stream_auth = AuthSchema(db=session, user=auth.user, check_data_scope=False)
|
||
|
||
# 如果配置有变动,先保存配置
|
||
if chat_data.config_changed and chat_data.config_data:
|
||
await AIModelConfigService.update_service(auth=stream_auth, model_type=model_type, data=chat_data.config_data)
|
||
|
||
# 获取模型配置
|
||
config = await AIModelConfigCRUD(stream_auth).get_by_model_type_crud(model_type=model_type)
|
||
if not config:
|
||
error_message = '模型配置不存在,请先配置模型'
|
||
else:
|
||
# 保存用户消息
|
||
user_msg = await AIModelTrainingMessageCRUD(stream_auth).create_crud(data={
|
||
'model_config_id': config.id,
|
||
'role': 'user',
|
||
'content': chat_data.message
|
||
})
|
||
|
||
# 获取供应商配置
|
||
provider = None
|
||
if config.provider_id:
|
||
provider = await AIProviderCRUD(stream_auth).get_by_id_crud(id=config.provider_id)
|
||
|
||
if not provider:
|
||
providers = await AIProviderCRUD(stream_auth).get_list_crud(search={'is_default': 1})
|
||
if providers:
|
||
provider = providers[0]
|
||
|
||
if not provider:
|
||
error_message = '未配置AI供应商,请先在AI配置中添加供应商'
|
||
else:
|
||
# 获取历史对话记录
|
||
history_messages = await AIModelTrainingMessageCRUD(stream_auth).get_list_by_config_crud(model_config_id=config.id)
|
||
|
||
# 提取需要的信息
|
||
config_id = config.id
|
||
config_system_prompt = config.system_prompt
|
||
config_model_name = config.model_name
|
||
config_temperature = config.temperature
|
||
provider_api_key = provider.api_key
|
||
provider_base_url = provider.base_url
|
||
|
||
# 构建消息列表
|
||
if config_system_prompt:
|
||
messages.append({"role": "system", "content": config_system_prompt})
|
||
|
||
for msg in history_messages:
|
||
if msg.id != user_msg.id:
|
||
messages.append({"role": msg.role, "content": msg.content})
|
||
|
||
messages.append({"role": "user", "content": chat_data.message})
|
||
|
||
# 提交事务
|
||
await session.commit()
|
||
|
||
# 2. 如果有错误,yield错误消息并返回
|
||
if error_message:
|
||
log.warning(f"[训练对话] 配置错误: {error_message}")
|
||
yield error_message
|
||
return
|
||
|
||
# 检查配置是否完整
|
||
if not provider_api_key or not provider_base_url:
|
||
log.error("[训练对话] 供应商配置不完整")
|
||
yield "供应商配置不完整,请检查API Key和Base URL"
|
||
return
|
||
|
||
# 3. 创建AI客户端并流式输出
|
||
import httpx
|
||
from openai import AsyncOpenAI
|
||
|
||
# 确保base_url包含/v1路径
|
||
if provider_base_url and not provider_base_url.rstrip('/').endswith('/v1'):
|
||
provider_base_url = provider_base_url.rstrip('/') + '/v1'
|
||
|
||
log.info(f"[训练对话] 开始调用AI,模型: {config_model_name}, 消息数: {len(messages)}")
|
||
log.info(f"[训练对话] API配置: base_url={provider_base_url}")
|
||
|
||
http_client = httpx.AsyncClient(
|
||
timeout=60.0,
|
||
follow_redirects=True,
|
||
verify=False # 禁用 SSL 验证
|
||
)
|
||
client = AsyncOpenAI(
|
||
api_key=provider_api_key,
|
||
base_url=provider_base_url,
|
||
http_client=http_client
|
||
)
|
||
|
||
full_response = ""
|
||
is_cancelled = False
|
||
response = None
|
||
try:
|
||
response = await client.chat.completions.create(
|
||
model=config_model_name or 'gpt-3.5-turbo',
|
||
messages=messages,
|
||
temperature=config_temperature,
|
||
stream=True
|
||
)
|
||
|
||
log.info(f"[训练对话] AI响应开始流式输出")
|
||
|
||
async for chunk in response:
|
||
# 处理不同的响应格式
|
||
content = None
|
||
if hasattr(chunk, 'choices') and chunk.choices:
|
||
delta = chunk.choices[0].delta
|
||
if hasattr(delta, 'content') and delta.content:
|
||
content = delta.content
|
||
|
||
if content:
|
||
full_response += content
|
||
yield content
|
||
|
||
log.info(f"[训练对话] AI响应完成,总长度: {len(full_response)}")
|
||
except asyncio.CancelledError:
|
||
# 客户端断开连接,停止AI输出
|
||
is_cancelled = True
|
||
log.info(f"[训练对话] 客户端断开连接,停止AI输出,已生成长度: {len(full_response)}")
|
||
# 关闭AI流式响应
|
||
if response:
|
||
try:
|
||
await response.close()
|
||
except Exception:
|
||
pass
|
||
raise # 重新抛出以便 FastAPI 知道请求已取消
|
||
except GeneratorExit:
|
||
# 生成器被关闭,通常是客户端断开
|
||
is_cancelled = True
|
||
log.info(f"[训练对话] 生成器关闭,停止AI输出,已生成长度: {len(full_response)}")
|
||
if response:
|
||
try:
|
||
await response.close()
|
||
except Exception:
|
||
pass
|
||
raise
|
||
except Exception as e:
|
||
log.error(f"AI训练对话失败: {str(e)}")
|
||
import traceback
|
||
log.error(f"AI训练对话异常详情: {traceback.format_exc()}")
|
||
error_msg = f"对话失败: {str(e)}"
|
||
full_response = error_msg
|
||
yield error_msg
|
||
finally:
|
||
await http_client.aclose()
|
||
|
||
# 4. 保存AI响应
|
||
if full_response and config_id:
|
||
async with async_db_session() as session:
|
||
async with session.begin():
|
||
save_auth = AuthSchema(db=session, user=auth.user, check_data_scope=False)
|
||
await AIModelTrainingMessageCRUD(save_auth).create_crud(data={
|
||
'model_config_id': config_id,
|
||
'role': 'assistant',
|
||
'content': full_response
|
||
})
|
||
await session.commit()
|
||
|
||
|
||
class AIModelTestService:
|
||
"""起名测试服务层(用于小程序/外部调用,仅读不写)"""
|
||
|
||
@classmethod
|
||
async def test_naming(cls, model_type: str, text: str) -> str:
|
||
"""
|
||
起名测试接口
|
||
|
||
- 读取模型配置和训练信息
|
||
- 内部流式调用AI
|
||
- 组合响应后一次性返回
|
||
- 仅读数据,不保存对话记录
|
||
|
||
参数:
|
||
model_type: 模型类型(enterprise_naming/personal_naming等)
|
||
text: 用户输入的文本
|
||
|
||
返回:
|
||
AI的完整响应文本
|
||
"""
|
||
from app.core.database import async_db_session
|
||
from app.api.v1.module_application.ai.schema import AI_MODEL_TYPES
|
||
|
||
if model_type not in AI_MODEL_TYPES:
|
||
raise CustomException(msg=f'无效的模型类型: {model_type}')
|
||
|
||
# 用于存储AI调用所需的配置信息
|
||
config_system_prompt = None
|
||
config_model_name = None
|
||
config_temperature = None
|
||
provider_api_key = None
|
||
provider_base_url = None
|
||
messages = []
|
||
|
||
# 1. 读取模型配置和训练记录(只读,不保存任何数据)
|
||
async with async_db_session() as session:
|
||
async with session.begin():
|
||
from app.api.v1.module_system.auth.schema import AuthSchema
|
||
read_auth = AuthSchema(db=session, user=None, check_data_scope=False)
|
||
|
||
# 获取模型配置
|
||
config = await AIModelConfigCRUD(read_auth).get_by_model_type_crud(model_type=model_type)
|
||
if not config:
|
||
raise CustomException(msg='模型配置不存在,请先配置模型')
|
||
|
||
# 获取供应商配置
|
||
provider = None
|
||
if config.provider_id:
|
||
provider = await AIProviderCRUD(read_auth).get_by_id_crud(id=config.provider_id)
|
||
|
||
if not provider:
|
||
providers = await AIProviderCRUD(read_auth).get_list_crud(search={'is_default': 1})
|
||
if providers:
|
||
provider = providers[0]
|
||
|
||
if not provider:
|
||
raise CustomException(msg='未配置AI供应商,请先在AI配置中添加供应商')
|
||
|
||
# 获取历史训练对话记录(作为上下文)
|
||
history_messages = await AIModelTrainingMessageCRUD(read_auth).get_list_by_config_crud(model_config_id=config.id)
|
||
|
||
# 提取配置信息
|
||
config_system_prompt = config.system_prompt
|
||
config_model_name = config.model_name
|
||
config_temperature = config.temperature
|
||
provider_api_key = provider.api_key
|
||
provider_base_url = provider.base_url
|
||
|
||
# 构建消息列表
|
||
if config_system_prompt:
|
||
messages.append({"role": "system", "content": config_system_prompt})
|
||
|
||
# 添加历史训练对话作为上下文
|
||
for msg in history_messages:
|
||
messages.append({"role": msg.role, "content": msg.content})
|
||
|
||
# 添加当前用户输入
|
||
messages.append({"role": "user", "content": text})
|
||
|
||
# 检查配置是否完整
|
||
if not provider_api_key or not provider_base_url:
|
||
raise CustomException(msg="供应商配置不完整,请检查API Key和Base URL")
|
||
|
||
# 2. 调用AI(内部流式,收集后一次性返回)
|
||
import httpx
|
||
import asyncio
|
||
from openai import AsyncOpenAI
|
||
|
||
# 确保base_url包含/v1路径
|
||
if provider_base_url and not provider_base_url.rstrip('/').endswith('/v1'):
|
||
provider_base_url = provider_base_url.rstrip('/') + '/v1'
|
||
|
||
log.info(f"[起名测试] 开始调用AI,模型类型: {model_type}, 模型: {config_model_name}, 消息数: {len(messages)}")
|
||
|
||
def _extract_status_code(err: Exception) -> int | None:
|
||
if hasattr(err, "status_code") and isinstance(getattr(err, "status_code"), int):
|
||
return getattr(err, "status_code")
|
||
response = getattr(err, "response", None)
|
||
if response is not None and hasattr(response, "status_code"):
|
||
return response.status_code
|
||
return None
|
||
|
||
def _extract_error_message(err: Exception) -> str:
|
||
body = getattr(err, "body", None)
|
||
if isinstance(body, dict) and "error" in body:
|
||
inner = body.get("error") or {}
|
||
msg = inner.get("message")
|
||
if isinstance(msg, str) and msg.strip():
|
||
return msg.strip()
|
||
return str(err)
|
||
|
||
retry_times = 3
|
||
base_sleep = 0.8
|
||
last_err: Exception | None = None
|
||
|
||
for attempt in range(1, retry_times + 1):
|
||
http_client = httpx.AsyncClient(
|
||
timeout=120.0,
|
||
follow_redirects=True,
|
||
verify=False
|
||
)
|
||
client = AsyncOpenAI(
|
||
api_key=provider_api_key,
|
||
base_url=provider_base_url,
|
||
http_client=http_client
|
||
)
|
||
|
||
full_response = ""
|
||
try:
|
||
response = await client.chat.completions.create(
|
||
model=config_model_name or 'gpt-3.5-turbo',
|
||
messages=messages,
|
||
temperature=config_temperature,
|
||
stream=True
|
||
)
|
||
|
||
async for chunk in response:
|
||
if hasattr(chunk, 'choices') and chunk.choices:
|
||
delta = chunk.choices[0].delta
|
||
if hasattr(delta, 'content') and delta.content:
|
||
full_response += delta.content
|
||
|
||
log.info(f"[起名测试] AI响应完成,总长度: {len(full_response)}")
|
||
return full_response
|
||
except Exception as e:
|
||
last_err = e
|
||
status_code = _extract_status_code(e)
|
||
err_msg = _extract_error_message(e)
|
||
transient = status_code in {408, 429, 500, 502, 503, 504} or isinstance(e, (httpx.TimeoutException, httpx.NetworkError))
|
||
if not transient or attempt >= retry_times:
|
||
log.error(f"[起名测试] AI调用失败(终止): {str(e)}")
|
||
import traceback
|
||
log.error(f"[起名测试] 异常详情: {traceback.format_exc()}")
|
||
raise CustomException(msg=f"服务暂时不可用,请稍后重试。({err_msg})", status_code=503)
|
||
|
||
sleep_s = base_sleep * (2 ** (attempt - 1))
|
||
log.warning(f"[起名测试] AI调用失败(第{attempt}次/{retry_times}次),{sleep_s:.2f}s后重试: {err_msg}")
|
||
await asyncio.sleep(sleep_s)
|
||
finally:
|
||
await http_client.aclose()
|
||
|
||
raise CustomException(msg=f"服务暂时不可用,请稍后重试。({_extract_error_message(last_err) if last_err else ''})", status_code=503)
|