Files

1232 lines
53 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
import re
import asyncio
import aiofiles
from pathlib import Path
from typing import Any, AsyncGenerator, List
from concurrent.futures import ThreadPoolExecutor
from fastapi import UploadFile
from app.core.exceptions import CustomException
from app.core.logger import log
from app.core.database import async_db_session
from app.config.setting import settings
from app.api.v1.module_system.auth.schema import AuthSchema
from .tools.ai_util import AIClient
# 延迟导入向量化工具,避免影响路由加载
# from .tools.document_processor import DocumentProcessor
# from .tools.embedding_util import EmbeddingUtil
from .schema import (
McpCreateSchema, McpUpdateSchema, McpOutSchema, ChatQuerySchema, McpQueryParam,
AIProviderCreateSchema, AIProviderUpdateSchema, AIProviderOutSchema, AIProviderQueryParam,
EmbeddingConfigCreateSchema, EmbeddingConfigUpdateSchema, EmbeddingConfigOutSchema, EmbeddingConfigQueryParam,
KnowledgeBaseCreateSchema, KnowledgeBaseUpdateSchema, KnowledgeBaseOutSchema, KnowledgeBaseQueryParam,
AIModelConfigCreateSchema, AIModelConfigUpdateSchema, AIModelConfigOutSchema, AIModelConfigQueryParam,
AIModelTrainingMessageCreateSchema, AIModelTrainingMessageOutSchema, AIModelTrainingChatSchema,
AI_MODEL_TYPES
)
from .crud import McpCRUD, AIProviderCRUD, EmbeddingConfigCRUD, KnowledgeBaseCRUD, AIModelConfigCRUD, AIModelTrainingMessageCRUD
from .model import EmbeddingConfigModel, AIModelConfigModel
class McpService:
    """Service layer for MCP server records and the MCP-backed chat endpoint."""

    @classmethod
    async def detail_service(cls, auth: AuthSchema, id: int) -> dict[str, Any]:
        """
        Return a single MCP server as a plain dict.

        Args:
            auth (AuthSchema): Authentication context handed to the CRUD layer.
            id (int): Primary key of the MCP server.

        Returns:
            dict[str, Any]: The record serialised through ``McpOutSchema``.

        Raises:
            CustomException: If no record is found for ``id``.
        """
        record = await McpCRUD(auth).get_by_id_crud(id=id)
        if record:
            return McpOutSchema.model_validate(record).model_dump()
        raise CustomException(msg='MCP 服务器不存在')

    @classmethod
    async def list_service(cls, auth: AuthSchema, search: McpQueryParam | None = None, order_by: list[dict[str, str]] | None = None) -> list[dict[str, Any]]:
        """
        List MCP servers.

        Args:
            auth (AuthSchema): Authentication context handed to the CRUD layer.
            search (McpQueryParam | None): Optional query object; its attribute
                dict is forwarded as the filter.
            order_by (list[dict[str, str]] | None): Optional ordering specification.

        Returns:
            list[dict[str, Any]]: One serialised dict per record.
        """
        filters = vars(search) if search else None
        records = await McpCRUD(auth).get_list_crud(search=filters, order_by=order_by)
        serialised: list[dict[str, Any]] = []
        for record in records:
            serialised.append(McpOutSchema.model_validate(record).model_dump())
        return serialised

    @classmethod
    async def create_service(cls, auth: AuthSchema, data: McpCreateSchema) -> dict[str, Any]:
        """
        Create an MCP server after checking that its name is not taken.

        Args:
            auth (AuthSchema): Authentication context handed to the CRUD layer.
            data (McpCreateSchema): Creation payload.

        Returns:
            dict[str, Any]: The created record, serialised.

        Raises:
            CustomException: If a record with the same name already exists.
        """
        duplicate = await McpCRUD(auth).get_by_name_crud(name=data.name)
        if duplicate:
            raise CustomException(msg='创建失败MCP 服务器已存在')
        created = await McpCRUD(auth).create_crud(data=data)
        return McpOutSchema.model_validate(created).model_dump()

    @classmethod
    async def update_service(cls, auth: AuthSchema, id: int, data: McpUpdateSchema) -> dict[str, Any]:
        """
        Update an MCP server.

        Args:
            auth (AuthSchema): Authentication context handed to the CRUD layer.
            id (int): Primary key of the record to update.
            data (McpUpdateSchema): Update payload.

        Returns:
            dict[str, Any]: The updated record, serialised.

        Raises:
            CustomException: If the record does not exist, or another record
                already uses ``data.name``.
        """
        current = await McpCRUD(auth).get_by_id_crud(id=id)
        if not current:
            raise CustomException(msg='更新失败,该数据不存在')

        same_name = await McpCRUD(auth).get_by_name_crud(name=data.name)
        # A name clash only matters when it belongs to a different record.
        if same_name and same_name.id != id:
            raise CustomException(msg='更新失败MCP 服务器名称重复')

        updated = await McpCRUD(auth).update_crud(id=id, data=data)
        return McpOutSchema.model_validate(updated).model_dump()

    @classmethod
    async def delete_service(cls, auth: AuthSchema, ids: list[int]) -> None:
        """
        Delete several MCP servers at once.

        Every id is looked up first; the batch delete only runs when all exist.

        Args:
            auth (AuthSchema): Authentication context handed to the CRUD layer.
            ids (list[int]): Primary keys to delete.

        Raises:
            CustomException: If ``ids`` is empty or any id has no record.
        """
        if len(ids) == 0:
            raise CustomException(msg='删除失败,删除对象不能为空')
        for record_id in ids:
            found = await McpCRUD(auth).get_by_id_crud(id=record_id)
            if not found:
                raise CustomException(msg='删除失败,该数据不存在')
        await McpCRUD(auth).delete_crud(ids=ids)

    @classmethod
    async def chat_query(cls, query: ChatQuerySchema) -> AsyncGenerator[str, Any]:
        """
        Stream the responses produced by ``AIClient`` for one chat message.

        Args:
            query (ChatQuerySchema): Chat request; only ``query.message`` is used.

        Yields:
            str: Each response item produced by ``AIClient.process``.
        """
        client = AIClient()
        try:
            async for piece in client.process(query.message):
                yield piece
        finally:
            # Always release the client, even if the consumer stops early.
            await client.close()
class AIProviderService:
    """Service layer for AI provider records."""

    @classmethod
    async def detail_service(cls, auth: AuthSchema, id: int) -> dict[str, Any]:
        """
        Return one AI provider as a plain dict.

        Raises:
            CustomException: If no provider exists for ``id``.
        """
        provider = await AIProviderCRUD(auth).get_by_id_crud(id=id)
        if provider:
            return AIProviderOutSchema.model_validate(provider).model_dump()
        raise CustomException(msg='AI供应商不存在')

    @classmethod
    async def list_service(cls, auth: AuthSchema, search: AIProviderQueryParam | None = None, order_by: list[dict[str, str]] | None = None) -> list[dict[str, Any]]:
        """
        List AI providers.

        The attribute dict of ``search`` (when given) is forwarded as the filter.
        """
        filters = vars(search) if search else None
        providers = await AIProviderCRUD(auth).get_list_crud(search=filters, order_by=order_by)
        serialised: list[dict[str, Any]] = []
        for provider in providers:
            serialised.append(AIProviderOutSchema.model_validate(provider).model_dump())
        return serialised

    @classmethod
    async def create_service(cls, auth: AuthSchema, data: AIProviderCreateSchema) -> dict[str, Any]:
        """
        Create an AI provider.

        Raises:
            CustomException: If the name is already in use.
        """
        duplicate = await AIProviderCRUD(auth).get_by_name_crud(name=data.name)
        if duplicate:
            raise CustomException(msg='创建失败AI供应商名称已存在')

        # When the new provider is flagged as default, clear the existing
        # default flag(s) before inserting it.
        if data.is_default == 1:
            await AIProviderCRUD(auth).clear_default_crud()

        created = await AIProviderCRUD(auth).create_crud(data=data)
        return AIProviderOutSchema.model_validate(created).model_dump()

    @classmethod
    async def update_service(cls, auth: AuthSchema, id: int, data: AIProviderUpdateSchema) -> dict[str, Any]:
        """
        Update an AI provider.

        Raises:
            CustomException: If the record does not exist, or another record
                already uses ``data.name``.
        """
        current = await AIProviderCRUD(auth).get_by_id_crud(id=id)
        if not current:
            raise CustomException(msg='更新失败,该数据不存在')

        same_name = await AIProviderCRUD(auth).get_by_name_crud(name=data.name)
        if same_name and same_name.id != id:
            raise CustomException(msg='更新失败AI供应商名称重复')

        # When this provider becomes the default, clear the existing default
        # flag(s) before updating it.
        if data.is_default == 1:
            await AIProviderCRUD(auth).clear_default_crud()

        updated = await AIProviderCRUD(auth).update_crud(id=id, data=data)
        return AIProviderOutSchema.model_validate(updated).model_dump()

    @classmethod
    async def delete_service(cls, auth: AuthSchema, ids: list[int]) -> None:
        """
        Delete several AI providers; every id must exist.

        Raises:
            CustomException: If ``ids`` is empty or any id has no record.
        """
        if len(ids) == 0:
            raise CustomException(msg='删除失败,删除对象不能为空')
        for record_id in ids:
            found = await AIProviderCRUD(auth).get_by_id_crud(id=record_id)
            if not found:
                raise CustomException(msg='删除失败,该数据不存在')
        await AIProviderCRUD(auth).delete_crud(ids=ids)
class EmbeddingConfigService:
    """Service layer for embedding (vectorization) configurations."""

    @staticmethod
    def _ensure_remote_credentials(data: Any) -> None:
        """
        Validate a payload whose ``embedding_type`` is 1 (remote mode).

        Remote mode needs both an endpoint URL and an API key.

        Raises:
            CustomException: If either ``base_url`` or ``api_key`` is missing.
        """
        if data.embedding_type == 1 and not (data.base_url and data.api_key):
            raise CustomException(msg='远程模式必须提供接口地址和API Key')

    @classmethod
    async def detail_service(cls, auth: AuthSchema, id: int) -> dict[str, Any]:
        """
        Return one embedding configuration as a plain dict.

        Raises:
            CustomException: If no configuration exists for ``id``.
        """
        config = await EmbeddingConfigCRUD(auth).get_by_id_crud(id=id)
        if config:
            return EmbeddingConfigOutSchema.model_validate(config).model_dump()
        raise CustomException(msg='向量化配置不存在')

    @classmethod
    async def list_service(cls, auth: AuthSchema, search: EmbeddingConfigQueryParam | None = None, order_by: list[dict[str, str]] | None = None) -> list[dict[str, Any]]:
        """
        List embedding configurations.

        The attribute dict of ``search`` (when given) is forwarded as the filter.
        """
        filters = vars(search) if search else None
        configs = await EmbeddingConfigCRUD(auth).get_list_crud(search=filters, order_by=order_by)
        serialised: list[dict[str, Any]] = []
        for config in configs:
            serialised.append(EmbeddingConfigOutSchema.model_validate(config).model_dump())
        return serialised

    @classmethod
    async def create_service(cls, auth: AuthSchema, data: EmbeddingConfigCreateSchema) -> dict[str, Any]:
        """
        Create an embedding configuration.

        Raises:
            CustomException: If the name is taken, or remote mode lacks
                ``base_url`` / ``api_key``.
        """
        duplicate = await EmbeddingConfigCRUD(auth).get_by_name_crud(name=data.name)
        if duplicate:
            raise CustomException(msg='创建失败,向量化配置名称已存在')

        cls._ensure_remote_credentials(data)

        # When the new configuration is flagged as default, clear the existing
        # default flag(s) before inserting it.
        if data.is_default == 1:
            await EmbeddingConfigCRUD(auth).clear_default_crud()

        created = await EmbeddingConfigCRUD(auth).create_crud(data=data)
        return EmbeddingConfigOutSchema.model_validate(created).model_dump()

    @classmethod
    async def update_service(cls, auth: AuthSchema, id: int, data: EmbeddingConfigUpdateSchema) -> dict[str, Any]:
        """
        Update an embedding configuration.

        Raises:
            CustomException: If the record does not exist, another record uses
                ``data.name``, or remote mode lacks ``base_url`` / ``api_key``.
        """
        current = await EmbeddingConfigCRUD(auth).get_by_id_crud(id=id)
        if not current:
            raise CustomException(msg='更新失败,该数据不存在')

        same_name = await EmbeddingConfigCRUD(auth).get_by_name_crud(name=data.name)
        if same_name and same_name.id != id:
            raise CustomException(msg='更新失败,向量化配置名称重复')

        cls._ensure_remote_credentials(data)

        # When this configuration becomes the default, clear the existing
        # default flag(s) before updating it.
        if data.is_default == 1:
            await EmbeddingConfigCRUD(auth).clear_default_crud()

        updated = await EmbeddingConfigCRUD(auth).update_crud(id=id, data=data)
        return EmbeddingConfigOutSchema.model_validate(updated).model_dump()

    @classmethod
    async def delete_service(cls, auth: AuthSchema, ids: list[int]) -> None:
        """
        Delete several embedding configurations; every id must exist.

        Raises:
            CustomException: If ``ids`` is empty or any id has no record.
        """
        if len(ids) == 0:
            raise CustomException(msg='删除失败,删除对象不能为空')
        for record_id in ids:
            found = await EmbeddingConfigCRUD(auth).get_by_id_crud(id=record_id)
            if not found:
                raise CustomException(msg='删除失败,该数据不存在')
        await EmbeddingConfigCRUD(auth).delete_crud(ids=ids)
class KnowledgeBaseService:
    """
    Service layer for knowledge bases.

    Covers CRUD, storage of uploaded source files on disk, and the background
    job that turns those files into vectors via ``EmbeddingUtil``.

    Status codes written to ``kb_status``: 1 = processing, 2 = done, 3 = failed.
    """

    # Strong references to in-flight vectorization tasks. The event loop only
    # keeps weak references to tasks, so a task nobody references may be
    # garbage-collected before it finishes.
    _background_tasks: set[asyncio.Task] = set()

    @staticmethod
    def _generate_collection_name(name: str) -> str:
        """
        Derive a collection name from the knowledge base name.

        Every character outside ``[a-zA-Z0-9_]`` is replaced by ``_`` and the
        current Unix timestamp (seconds) is appended.
        """
        import time

        clean_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
        timestamp = int(time.time())
        return f"kb_{clean_name}_{timestamp}"

    @staticmethod
    def _safe_filename(filename: str) -> str:
        """
        Reduce a client-supplied filename to its final path component.

        Upload filenames are untrusted: a value such as ``../../x`` must not be
        able to escape the knowledge base directory. Returns ``''`` when no
        usable component remains.
        """
        candidate = Path(filename.replace('\\', '/')).name
        return '' if candidate in ('', '.', '..') else candidate

    @staticmethod
    def _embedding_info(embedding_config: Any) -> dict[str, Any] | None:
        """Copy the fields the background job needs out of an embedding config record."""
        if not embedding_config:
            return None
        return {
            'embedding_type': embedding_config.embedding_type,
            'model_name': embedding_config.model_name,
            'base_url': embedding_config.base_url,
            'api_key': embedding_config.api_key,
        }

    @classmethod
    def _schedule_vectorization(
        cls,
        knowledge_base_id: int,
        collection_name: str,
        file_contents: List[dict],
        embedding_info: dict | None
    ) -> None:
        """Start ``_process_vectorization`` as a background task and keep it referenced."""
        task = asyncio.create_task(
            cls._process_vectorization(
                knowledge_base_id=knowledge_base_id,
                collection_name=collection_name,
                file_contents=file_contents,
                embedding_info=embedding_info
            )
        )
        cls._background_tasks.add(task)
        # Drop the reference once the task has finished.
        task.add_done_callback(cls._background_tasks.discard)

    @classmethod
    async def _store_uploads(
        cls,
        collection_name: str,
        files: List[UploadFile] | None
    ) -> tuple[list[str], list[dict]]:
        """
        Write uploaded files below the knowledge base directory.

        Returns:
            tuple: ``(file_paths, file_contents)`` where ``file_contents`` items
            are ``{'filename', 'content', 'content_type'}`` dicts. Uploads
            without a usable filename are skipped.
        """
        file_paths: list[str] = []
        file_contents: list[dict] = []
        if not files:
            return file_paths, file_contents

        kb_dir = settings.UPLOAD_FILE_PATH / "knowledge_base" / collection_name
        kb_dir.mkdir(parents=True, exist_ok=True)
        for file in files:
            if not file.filename:
                continue
            safe_name = cls._safe_filename(file.filename)
            if not safe_name:
                log.warning(f"忽略非法文件名: {file.filename!r}")
                continue
            content = await file.read()
            file_path = kb_dir / safe_name
            async with aiofiles.open(file_path, 'wb') as f:
                await f.write(content)
            file_paths.append(str(file_path))
            file_contents.append({
                'filename': safe_name,
                'content': content,
                'content_type': file.content_type
            })
        return file_paths, file_contents

    @classmethod
    async def detail_service(cls, auth: AuthSchema, id: int) -> dict[str, Any]:
        """
        Return one knowledge base as a plain dict.

        Raises:
            CustomException: If no knowledge base exists for ``id``.
        """
        obj = await KnowledgeBaseCRUD(auth).get_by_id_crud(id=id)
        if not obj:
            raise CustomException(msg='知识库不存在')
        return KnowledgeBaseOutSchema.model_validate(obj).model_dump()

    @classmethod
    async def list_service(cls, auth: AuthSchema, search: KnowledgeBaseQueryParam | None = None, order_by: list[dict[str, str]] | None = None) -> list[dict[str, Any]]:
        """List knowledge bases; the attribute dict of ``search`` is forwarded as the filter."""
        search_dict = search.__dict__ if search else None
        obj_list = await KnowledgeBaseCRUD(auth).get_list_crud(search=search_dict, order_by=order_by)
        return [KnowledgeBaseOutSchema.model_validate(obj).model_dump() for obj in obj_list]

    @classmethod
    async def create_service(cls, auth: AuthSchema, data: KnowledgeBaseCreateSchema, files: List[UploadFile] | None = None) -> dict[str, Any]:
        """
        Create a knowledge base and, if files were uploaded, start vectorizing them.

        Args:
            auth (AuthSchema): Authentication context handed to the CRUD layer.
            data (KnowledgeBaseCreateSchema): Creation payload.
            files (List[UploadFile] | None): Optional source documents.

        Returns:
            dict[str, Any]: The created record. When files were stored its status
            is 1 (processing) and a background task updates it later.

        Raises:
            CustomException: If the name is taken or the referenced embedding
                configuration does not exist.
        """
        obj = await KnowledgeBaseCRUD(auth).get_by_name_crud(name=data.name)
        if obj:
            raise CustomException(msg='创建失败,知识库名称已存在')

        embedding_config = None
        if data.embedding_config_id:
            embedding_config = await EmbeddingConfigCRUD(auth).get_by_id_crud(id=data.embedding_config_id)
            if not embedding_config:
                raise CustomException(msg='向量化配置不存在')

        collection_name = cls._generate_collection_name(data.name)
        file_paths, file_contents = await cls._store_uploads(collection_name, files)

        # Derive count and status from the files actually stored. Using the raw
        # upload list would leave the record in "processing" forever when every
        # upload was skipped, because no background task is started then.
        create_data = {
            'name': data.name,
            'embedding_config_id': data.embedding_config_id,
            'collection_name': collection_name,
            'description': data.description,
            'document_count': len(file_contents),
            'vector_count': 0,
            'kb_status': 1 if file_contents else 2,
            'file_paths': file_paths if file_paths else None
        }
        obj = await KnowledgeBaseCRUD(auth).create_crud(data=create_data)

        if file_contents:
            cls._schedule_vectorization(
                knowledge_base_id=obj.id,
                collection_name=collection_name,
                file_contents=file_contents,
                embedding_info=cls._embedding_info(embedding_config)
            )
        return KnowledgeBaseOutSchema.model_validate(obj).model_dump()

    @staticmethod
    def _load_documents(file_contents: List[dict]) -> list:
        """
        Convert raw file payloads into langchain ``Document`` objects.

        Supported extensions: ``txt`` / ``md`` (decoded with the first encoding
        that works), ``pdf`` (one document per non-empty page) and ``docx`` (one
        document per file). Files that fail to parse are logged and skipped.
        """
        from langchain.schema import Document

        documents = []
        for file_info in file_contents:
            filename = file_info['filename']
            content = file_info['content']
            if not filename:
                continue
            ext = filename.rsplit('.', 1)[-1].lower() if '.' in filename else ''
            try:
                if ext in ('txt', 'md'):
                    text = None
                    # latin-1 accepts any byte sequence, so it is the final fallback.
                    for encoding in ('utf-8', 'gbk', 'gb2312', 'latin-1'):
                        try:
                            text = content.decode(encoding)
                            break
                        except UnicodeDecodeError:
                            continue
                    if text:
                        documents.append(Document(
                            page_content=text,
                            metadata={"source": filename, "type": ext}
                        ))
                elif ext == 'pdf':
                    try:
                        import pypdf
                        from io import BytesIO
                        reader = pypdf.PdfReader(BytesIO(content))
                        for page_num, page in enumerate(reader.pages):
                            text = page.extract_text()
                            if text and text.strip():
                                documents.append(Document(
                                    page_content=text,
                                    metadata={
                                        "source": filename,
                                        "type": "pdf",
                                        "page": page_num + 1
                                    }
                                ))
                    except Exception as e:
                        log.error(f"处理 PDF 文件失败: {filename}, 错误: {e}")
                elif ext == 'docx':
                    try:
                        from docx import Document as DocxDocument
                        from io import BytesIO
                        doc = DocxDocument(BytesIO(content))
                        paragraphs = [para.text for para in doc.paragraphs if para.text.strip()]
                        text = '\n'.join(paragraphs)
                        if text.strip():
                            documents.append(Document(
                                page_content=text,
                                metadata={"source": filename, "type": "docx"}
                            ))
                    except Exception as e:
                        log.error(f"处理 DOCX 文件失败: {filename}, 错误: {e}")
            except Exception as e:
                log.error(f"处理文件失败: {filename}, 错误: {e}")
        return documents

    @classmethod
    async def _process_vectorization(
        cls,
        knowledge_base_id: int,
        collection_name: str,
        file_contents: List[dict],
        embedding_info: dict | None
    ) -> None:
        """
        Background job: parse the files, embed them, and record the outcome.

        Args:
            knowledge_base_id: Primary key of the knowledge base to update.
            collection_name: Target collection name passed to ``EmbeddingUtil``.
            file_contents: ``[{'filename': str, 'content': bytes, 'content_type': str}]``.
            embedding_info: Embedding settings, or ``None`` to use the built-in fallback.

        Any failure, including an import error, marks the knowledge base as
        failed (status 3) with the first 50 characters of the error.
        """
        try:
            # 1. Parse the files.
            all_documents = cls._load_documents(file_contents)
            if not all_documents:
                await cls._update_kb_status(
                    knowledge_base_id=knowledge_base_id,
                    kb_status=3,
                    error_message="没有有效的文档内容"
                )
                return

            # 2. Build the embedding helper (imported lazily, as elsewhere in this file).
            from .tools.embedding_util import EmbeddingUtil
            if embedding_info:
                embedding_util = EmbeddingUtil(
                    embedding_type=embedding_info['embedding_type'],
                    model_name=embedding_info['model_name'],
                    base_url=embedding_info['base_url'],
                    api_key=embedding_info['api_key']
                )
            else:
                # NOTE(review): this fallback carries a placeholder API key and
                # cannot authenticate as written - confirm how the default
                # credentials are meant to be supplied.
                embedding_util = EmbeddingUtil(
                    embedding_type=1,
                    model_name='text-embedding-3-small',
                    base_url='https://api.openai.com/v1',
                    api_key='your-api-key'
                )

            # 3. add_documents is synchronous; run it in a worker thread so the
            #    event loop stays responsive.
            vector_count = await asyncio.to_thread(
                embedding_util.add_documents,
                collection_name,
                all_documents
            )

            # 4. Record success.
            await cls._update_kb_status(
                knowledge_base_id=knowledge_base_id,
                kb_status=2,
                vector_count=vector_count
            )
        except Exception as e:
            log.error(f"向量化处理失败知识库ID: {knowledge_base_id}, 错误: {e}")
            # Keep only the first 50 characters of the error text.
            await cls._update_kb_status(
                knowledge_base_id=knowledge_base_id,
                kb_status=3,
                error_message=str(e)[:50]
            )

    @classmethod
    async def _update_kb_status(
        cls,
        knowledge_base_id: int,
        kb_status: int,
        vector_count: int = 0,
        error_message: str | None = None
    ) -> None:
        """
        Write status fields of a knowledge base using an independent DB session.

        Args:
            knowledge_base_id: Primary key of the knowledge base.
            kb_status: 0 = pending, 1 = processing, 2 = done, 3 = failed.
            vector_count: Value stored in ``vector_count`` (always written).
            error_message: ``None`` leaves the stored message untouched; any
                string, including ``''``, overwrites it. ``retry_service``
                passes ``''`` to clear a previous error.

        Failures are logged and swallowed so a status update never crashes the caller.
        """
        from .model import KnowledgeBaseModel
        from sqlalchemy import update

        try:
            async with async_db_session() as session:
                async with session.begin():
                    update_data = {
                        'kb_status': kb_status,
                        'vector_count': vector_count,
                    }
                    if error_message is not None:
                        update_data['error_message'] = error_message
                    stmt = (
                        update(KnowledgeBaseModel)
                        .where(KnowledgeBaseModel.id == knowledge_base_id)
                        .values(**update_data)
                    )
                    await session.execute(stmt)
                    await session.commit()
        except Exception as e:
            log.error(f"更新知识库状态失败ID: {knowledge_base_id}, 错误: {e}")

    @classmethod
    async def update_service(cls, auth: AuthSchema, id: int, data: KnowledgeBaseUpdateSchema) -> dict[str, Any]:
        """
        Update a knowledge base.

        Raises:
            CustomException: If the record does not exist, the new name is used
                by another record, or the referenced embedding configuration
                does not exist.
        """
        obj = await KnowledgeBaseCRUD(auth).get_by_id_crud(id=id)
        if not obj:
            raise CustomException(msg='更新失败,该数据不存在')
        if data.name:
            exist_obj = await KnowledgeBaseCRUD(auth).get_by_name_crud(name=data.name)
            if exist_obj and exist_obj.id != id:
                raise CustomException(msg='更新失败,知识库名称重复')
        if data.embedding_config_id:
            embedding_config = await EmbeddingConfigCRUD(auth).get_by_id_crud(id=data.embedding_config_id)
            if not embedding_config:
                raise CustomException(msg='向量化配置不存在')
        obj = await KnowledgeBaseCRUD(auth).update_crud(id=id, data=data)
        return KnowledgeBaseOutSchema.model_validate(obj).model_dump()

    @staticmethod
    def _remove_stored_files(obj: Any) -> None:
        """Best-effort removal of a knowledge base's files and its (empty) directory."""
        if obj.file_paths:
            for file_path in obj.file_paths:
                try:
                    path = Path(file_path)
                    if path.exists():
                        path.unlink()
                except Exception as e:
                    log.warning(f"删除文件失败: {file_path}, 错误: {e}")
            # Remove the directory only when nothing is left inside it.
            try:
                kb_dir = settings.UPLOAD_FILE_PATH / "knowledge_base" / obj.collection_name
                if kb_dir.exists() and not any(kb_dir.iterdir()):
                    kb_dir.rmdir()
            except Exception as e:
                log.warning(f"删除知识库目录失败: {obj.collection_name}, 错误: {e}")

    @classmethod
    async def _drop_collection(cls, auth: AuthSchema, obj: Any) -> None:
        """Best-effort removal of the vector collection belonging to a knowledge base."""
        if not obj.collection_name:
            return
        try:
            from .tools.embedding_util import EmbeddingUtil

            embedding_info = None
            if obj.embedding_config_id:
                embedding_config = await EmbeddingConfigCRUD(auth).get_by_id_crud(id=obj.embedding_config_id)
                embedding_info = cls._embedding_info(embedding_config)
            if embedding_info:
                embedding_util = EmbeddingUtil(
                    embedding_type=embedding_info['embedding_type'],
                    model_name=embedding_info['model_name'],
                    base_url=embedding_info['base_url'],
                    api_key=embedding_info['api_key']
                )
            else:
                embedding_util = EmbeddingUtil()
            embedding_util.delete_collection(obj.collection_name)
        except Exception as e:
            log.warning(f"删除 ChromaDB 集合失败: {obj.collection_name}, 错误: {e}")

    @classmethod
    async def delete_service(cls, auth: AuthSchema, ids: list[int]) -> None:
        """
        Delete knowledge bases together with their files and vector collections.

        All ids are validated before anything is removed, so a missing id can
        no longer abort the request after files or collections of earlier ids
        were already destroyed.

        Raises:
            CustomException: If ``ids`` is empty or any id has no record.
        """
        if len(ids) < 1:
            raise CustomException(msg='删除失败,删除对象不能为空')

        targets = []
        for kb_id in ids:
            obj = await KnowledgeBaseCRUD(auth).get_by_id_crud(id=kb_id)
            if not obj:
                raise CustomException(msg='删除失败,该数据不存在')
            targets.append(obj)

        for obj in targets:
            cls._remove_stored_files(obj)
            await cls._drop_collection(auth, obj)

        await KnowledgeBaseCRUD(auth).delete_crud(ids=ids)

    @classmethod
    async def retry_service(cls, auth: AuthSchema, id: int) -> dict[str, Any]:
        """
        Re-run vectorization for a knowledge base whose last run failed.

        Args:
            auth (AuthSchema): Authentication context handed to the CRUD layer.
            id (int): Primary key of the knowledge base.

        Returns:
            dict[str, Any]: The knowledge base re-read after its status was set
            back to processing.

        Raises:
            CustomException: If the knowledge base does not exist, is not in the
                failed state, or none of its stored files can be found.
        """
        obj = await KnowledgeBaseCRUD(auth).get_by_id_crud(id=id)
        if not obj:
            raise CustomException(msg='知识库不存在')
        if obj.kb_status != 3:
            raise CustomException(msg='只有处理失败的知识库才能重试')
        if not obj.file_paths or len(obj.file_paths) == 0:
            raise CustomException(msg='没有找到已保存的文件,请删除后重新创建知识库')

        # Reload the stored source files from disk.
        file_contents = []
        for file_path in obj.file_paths:
            path = Path(file_path)
            if not path.exists():
                log.warning(f"文件不存在: {file_path}")
                continue
            async with aiofiles.open(path, 'rb') as f:
                content = await f.read()
            file_contents.append({
                'filename': path.name,
                'content': content,
                'content_type': None
            })
        if not file_contents:
            raise CustomException(msg='所有文件均已丢失,请删除后重新创建知识库')

        embedding_info = None
        if obj.embedding_config_id:
            embedding_config = await EmbeddingConfigCRUD(auth).get_by_id_crud(id=obj.embedding_config_id)
            embedding_info = cls._embedding_info(embedding_config)

        # Back to "processing"; the empty string clears the previous error text.
        await cls._update_kb_status(
            knowledge_base_id=id,
            kb_status=1,
            vector_count=0,
            error_message=''
        )

        cls._schedule_vectorization(
            knowledge_base_id=id,
            collection_name=obj.collection_name,
            file_contents=file_contents,
            embedding_info=embedding_info
        )

        obj = await KnowledgeBaseCRUD(auth).get_by_id_crud(id=id)
        return KnowledgeBaseOutSchema.model_validate(obj).model_dump()
class AIModelConfigService:
    """Service layer for per-model-type AI configuration."""

    # Optional fields copied from an update payload when they are not None.
    _UPDATABLE_FIELDS = ('model_name', 'provider_id', 'system_prompt', 'temperature', 'knowledge_base_ids')

    @classmethod
    async def detail_service(cls, auth: AuthSchema, model_type: str) -> dict[str, Any]:
        """
        Return the configuration for ``model_type``, creating a default one if absent.

        Raises:
            CustomException: If ``model_type`` is not in ``AI_MODEL_TYPES``.
        """
        if model_type not in AI_MODEL_TYPES:
            raise CustomException(msg=f'无效的模型类型: {model_type}')
        obj = await AIModelConfigCRUD(auth).get_by_model_type_crud(model_type=model_type)
        if not obj:
            # First access for this model type: persist an empty default row.
            create_data = {
                'model_type': model_type,
                'model_name': None,
                'provider_id': None,
                'system_prompt': None,
                'temperature': 1.0,
                'knowledge_base_ids': None
            }
            obj = await AIModelConfigCRUD(auth).create_crud(data=create_data)
        return AIModelConfigOutSchema.model_validate(obj).model_dump()

    @classmethod
    async def list_service(cls, auth: AuthSchema, search: AIModelConfigQueryParam | None = None, order_by: list[dict[str, str]] | None = None) -> list[dict[str, Any]]:
        """List model configurations; the attribute dict of ``search`` is forwarded as the filter."""
        search_dict = search.__dict__ if search else None
        obj_list = await AIModelConfigCRUD(auth).get_list_crud(search=search_dict, order_by=order_by)
        return [AIModelConfigOutSchema.model_validate(obj).model_dump() for obj in obj_list]

    @classmethod
    async def update_service(cls, auth: AuthSchema, model_type: str, data: AIModelConfigUpdateSchema) -> dict[str, Any]:
        """
        Update the configuration for ``model_type`` (creating it when missing).

        On update, only fields whose value is not ``None`` are written.

        Raises:
            CustomException: If ``model_type`` is not in ``AI_MODEL_TYPES``.
        """
        if model_type not in AI_MODEL_TYPES:
            raise CustomException(msg=f'无效的模型类型: {model_type}')
        obj = await AIModelConfigCRUD(auth).get_by_model_type_crud(model_type=model_type)
        if not obj:
            create_data = {
                'model_type': model_type,
                'model_name': data.model_name,
                'provider_id': data.provider_id,
                'system_prompt': data.system_prompt,
                'temperature': data.temperature if data.temperature is not None else 1.0,
                'knowledge_base_ids': data.knowledge_base_ids
            }
            obj = await AIModelConfigCRUD(auth).create_crud(data=create_data)
        else:
            update_data = {
                field: getattr(data, field)
                for field in cls._UPDATABLE_FIELDS
                if getattr(data, field) is not None
            }
            if update_data:
                obj = await AIModelConfigCRUD(auth).update_crud(id=obj.id, data=update_data)
        return AIModelConfigOutSchema.model_validate(obj).model_dump()

    @classmethod
    async def get_available_models(cls, auth: AuthSchema, provider_id: int) -> list[dict[str, Any]]:
        """
        Fetch the model list from a provider's OpenAI-compatible ``/v1/models`` endpoint.

        Args:
            auth (AuthSchema): Authentication context handed to the CRUD layer.
            provider_id (int): Primary key of the AI provider.

        Returns:
            list[dict[str, Any]]: ``{'id', 'name'}`` dicts sorted by ``id``.

        Raises:
            CustomException: If the provider is missing or has no base URL, or
                the remote call fails (HTTP error, network error, bad payload).
        """
        import httpx

        provider = await AIProviderCRUD(auth).get_by_id_crud(id=provider_id)
        if not provider:
            raise CustomException(msg='AI供应商不存在')
        # Without this guard a provider with no base URL fails with an
        # AttributeError on ``None.rstrip`` instead of a readable error.
        if not provider.base_url:
            raise CustomException(msg='供应商配置不完整请检查API Key和Base URL')

        base_url = provider.base_url.rstrip('/')
        if not base_url.endswith('/v1'):
            base_url = f"{base_url}/v1"
        models_url = f"{base_url}/models"

        headers = {'Content-Type': 'application/json'}
        # Only send credentials when a key is configured; formatting None would
        # produce the literal header value "Bearer None".
        if provider.api_key:
            headers['Authorization'] = f'Bearer {provider.api_key}'

        try:
            # NOTE(security): verify=False disables TLS certificate checks while
            # the API key is being transmitted. Kept as-is for compatibility with
            # self-signed endpoints - review whether it should be configurable.
            async with httpx.AsyncClient(timeout=30.0, verify=False) as client:
                response = await client.get(models_url, headers=headers)
                response.raise_for_status()
                data = response.json()

            # OpenAI-style payload: {"data": [{"id": ..., ...}, ...]}
            models = [
                {'id': model.get('id', ''), 'name': model.get('name', model.get('id', ''))}
                for model in data.get('data', [])
                if model.get('id', '')
            ]
            models.sort(key=lambda x: x['id'])
            return models
        except httpx.HTTPStatusError as e:
            log.error(f"获取模型列表失败, HTTP错误: {e.response.status_code}, {e.response.text}")
            if e.response.status_code == 401:
                raise CustomException(msg='API Key无效或已过期') from e
            elif e.response.status_code == 403:
                raise CustomException(msg='无权限访问模型列表') from e
            else:
                raise CustomException(msg=f'获取模型列表失败: HTTP {e.response.status_code}') from e
        except httpx.RequestError as e:
            log.error(f"获取模型列表失败, 网络错误: {str(e)}")
            raise CustomException(msg=f'网络请求失败: {str(e)}') from e
        except Exception as e:
            log.error(f"获取模型列表失败: {str(e)}")
            raise CustomException(msg=f'获取模型列表失败: {str(e)}') from e
class AIModelTrainingService:
    """Service layer for the stored "training" conversation of each model type."""

    @staticmethod
    async def _close_stream_quietly(stream: Any) -> None:
        """Close an AI response stream, ignoring any error raised while closing."""
        if stream:
            try:
                await stream.close()
            except Exception:
                pass

    @classmethod
    async def get_messages_service(cls, auth: AuthSchema, model_type: str) -> list[dict[str, Any]]:
        """
        Return the stored conversation for ``model_type``.

        Returns an empty list when the model type has no configuration yet.

        Raises:
            CustomException: If ``model_type`` is not in ``AI_MODEL_TYPES``.
        """
        if model_type not in AI_MODEL_TYPES:
            raise CustomException(msg=f'无效的模型类型: {model_type}')
        model_config = await AIModelConfigCRUD(auth).get_by_model_type_crud(model_type=model_type)
        if not model_config:
            return []
        stored = await AIModelTrainingMessageCRUD(auth).get_list_by_config_crud(model_config_id=model_config.id)
        serialised: list[dict[str, Any]] = []
        for item in stored:
            serialised.append(AIModelTrainingMessageOutSchema.model_validate(item).model_dump())
        return serialised

    @classmethod
    async def delete_message_service(cls, auth: AuthSchema, message_id: int) -> None:
        """
        Delete a single stored conversation message.

        Raises:
            CustomException: If the message does not exist.
        """
        existing = await AIModelTrainingMessageCRUD(auth).get_by_id_crud(id=message_id)
        if not existing:
            raise CustomException(msg='对话记录不存在')
        await AIModelTrainingMessageCRUD(auth).delete_crud(ids=[message_id])

    @classmethod
    async def clear_messages_service(cls, auth: AuthSchema, model_type: str) -> None:
        """
        Delete every stored message of ``model_type`` (no-op without a configuration).

        Raises:
            CustomException: If ``model_type`` is not in ``AI_MODEL_TYPES``.
        """
        if model_type not in AI_MODEL_TYPES:
            raise CustomException(msg=f'无效的模型类型: {model_type}')
        model_config = await AIModelConfigCRUD(auth).get_by_model_type_crud(model_type=model_type)
        if model_config:
            await AIModelTrainingMessageCRUD(auth).delete_by_config_crud(model_config_id=model_config.id)

    @classmethod
    async def chat_stream(cls, auth: AuthSchema, chat_data: AIModelTrainingChatSchema) -> AsyncGenerator[str, Any]:
        """
        Stream the AI answer for one training-chat message.

        The request-scoped DB session is closed while a StreamingResponse is
        being sent, so this method opens its own sessions to read configuration
        and to persist messages.

        IMPORTANT: never ``yield`` inside an ``async with session.begin()``
        block. Everything needed later is copied into locals first and the
        yields happen only after the session block has been left.

        Yields:
            str: Either a single error text, or the answer fragments as they arrive.
        """
        from app.core.database import async_db_session

        model_type = chat_data.model_type
        if model_type not in AI_MODEL_TYPES:
            yield f'无效的模型类型: {model_type}'
            return

        # Values copied out of the DB phase for use after the session is closed.
        setup_error = None
        config_id = None
        system_prompt = None
        model_name = None
        temperature = None
        api_key = None
        base_url = None
        messages = []

        # 1. All database work happens here - no yield inside this block.
        async with async_db_session() as session:
            async with session.begin():
                stream_auth = AuthSchema(db=session, user=auth.user, check_data_scope=False)

                # Persist pending configuration changes before using the configuration.
                if chat_data.config_changed and chat_data.config_data:
                    await AIModelConfigService.update_service(auth=stream_auth, model_type=model_type, data=chat_data.config_data)

                config = await AIModelConfigCRUD(stream_auth).get_by_model_type_crud(model_type=model_type)
                if not config:
                    setup_error = '模型配置不存在,请先配置模型'
                else:
                    # Store the user's message first.
                    user_msg = await AIModelTrainingMessageCRUD(stream_auth).create_crud(data={
                        'model_config_id': config.id,
                        'role': 'user',
                        'content': chat_data.message
                    })

                    # Configured provider first, then fall back to a default provider.
                    provider = None
                    if config.provider_id:
                        provider = await AIProviderCRUD(stream_auth).get_by_id_crud(id=config.provider_id)
                    if not provider:
                        default_providers = await AIProviderCRUD(stream_auth).get_list_crud(search={'is_default': 1})
                        if default_providers:
                            provider = default_providers[0]

                    if not provider:
                        setup_error = '未配置AI供应商请先在AI配置中添加供应商'
                    else:
                        history = await AIModelTrainingMessageCRUD(stream_auth).get_list_by_config_crud(model_config_id=config.id)

                        config_id = config.id
                        system_prompt = config.system_prompt
                        model_name = config.model_name
                        temperature = config.temperature
                        api_key = provider.api_key
                        base_url = provider.base_url

                        if system_prompt:
                            messages.append({"role": "system", "content": system_prompt})
                        # The just-stored user message is appended explicitly
                        # below, so leave it out of the history.
                        messages.extend(
                            {"role": past.role, "content": past.content}
                            for past in history
                            if past.id != user_msg.id
                        )
                        messages.append({"role": "user", "content": chat_data.message})
                await session.commit()

        # 2. Report set-up problems.
        if setup_error:
            log.warning(f"[训练对话] 配置错误: {setup_error}")
            yield setup_error
            return

        if not (api_key and base_url):
            log.error("[训练对话] 供应商配置不完整")
            yield "供应商配置不完整请检查API Key和Base URL"
            return

        # 3. Call the model and relay the stream.
        import httpx
        from openai import AsyncOpenAI

        # Make sure the base URL ends in /v1.
        trimmed_url = base_url.rstrip('/')
        if not trimmed_url.endswith('/v1'):
            base_url = trimmed_url + '/v1'

        log.info(f"[训练对话] 开始调用AI模型: {model_name}, 消息数: {len(messages)}")
        log.info(f"[训练对话] API配置: base_url={base_url}")

        http_client = httpx.AsyncClient(
            timeout=60.0,
            follow_redirects=True,
            verify=False  # TLS certificate verification is disabled
        )
        client = AsyncOpenAI(
            api_key=api_key,
            base_url=base_url,
            http_client=http_client
        )

        full_response = ""
        response = None
        try:
            response = await client.chat.completions.create(
                model=model_name or 'gpt-3.5-turbo',
                messages=messages,
                temperature=temperature,
                stream=True
            )
            log.info("[训练对话] AI响应开始流式输出")
            async for chunk in response:
                choices = getattr(chunk, 'choices', None)
                if not choices:
                    continue
                piece = getattr(choices[0].delta, 'content', None)
                if not piece:
                    continue
                full_response += piece
                yield piece
            log.info(f"[训练对话] AI响应完成总长度: {len(full_response)}")
        except asyncio.CancelledError:
            # The client disconnected: stop the upstream stream and re-raise so
            # the framework sees the cancellation.
            log.info(f"[训练对话] 客户端断开连接停止AI输出已生成长度: {len(full_response)}")
            await cls._close_stream_quietly(response)
            raise
        except GeneratorExit:
            # The generator was closed, usually also a client disconnect.
            log.info(f"[训练对话] 生成器关闭停止AI输出已生成长度: {len(full_response)}")
            await cls._close_stream_quietly(response)
            raise
        except Exception as e:
            import traceback
            log.error(f"AI训练对话失败: {str(e)}")
            log.error(f"AI训练对话异常详情: {traceback.format_exc()}")
            failure_text = f"对话失败: {str(e)}"
            # The error text replaces any partial answer and is what gets stored below.
            full_response = failure_text
            yield failure_text
        finally:
            await http_client.aclose()

        # 4. Persist the assistant message (skipped on cancellation, which re-raises above).
        if full_response and config_id:
            async with async_db_session() as session:
                async with session.begin():
                    save_auth = AuthSchema(db=session, user=auth.user, check_data_scope=False)
                    await AIModelTrainingMessageCRUD(save_auth).create_crud(data={
                        'model_config_id': config_id,
                        'role': 'assistant',
                        'content': full_response
                    })
                    await session.commit()
class AIModelTestService:
    """Naming-test service for mini-program / external callers (reads only, never writes)."""

    # HTTP status codes treated as temporary failures worth retrying.
    _RETRYABLE_STATUS = frozenset({408, 429, 500, 502, 503, 504})
    _RETRY_TIMES = 3
    _BASE_SLEEP = 0.8

    @staticmethod
    def _extract_status_code(err: Exception) -> int | None:
        """Return the HTTP status code carried by ``err`` or its ``response``, if any."""
        direct = getattr(err, "status_code", None)
        if isinstance(direct, int):
            return direct
        response = getattr(err, "response", None)
        if response is not None and hasattr(response, "status_code"):
            return response.status_code
        return None

    @staticmethod
    def _extract_error_message(err: Exception) -> str:
        """
        Return the most specific error text available on ``err``.

        Prefers ``err.body["error"]["message"]``, also accepts a plain string
        under ``err.body["error"]``, and otherwise falls back to ``str(err)``.
        """
        body = getattr(err, "body", None)
        if isinstance(body, dict):
            inner = body.get("error")
            if isinstance(inner, dict):
                msg = inner.get("message")
                if isinstance(msg, str) and msg.strip():
                    return msg.strip()
            elif isinstance(inner, str) and inner.strip():
                # Without this branch a string error made .get() raise
                # AttributeError inside the except handler.
                return inner.strip()
        return str(err)

    @classmethod
    def _is_transient(cls, err: Exception, status_code: int | None) -> bool:
        """
        Decide whether a failed AI call is worth retrying.

        The OpenAI SDK wraps connection failures and timeouts in
        ``openai.APIConnectionError`` with the httpx error as ``__cause__``.
        Raw httpx errors can still surface while iterating the stream, so both
        forms are checked.
        """
        import httpx
        from openai import APIConnectionError

        if status_code in cls._RETRYABLE_STATUS:
            return True
        network_errors = (APIConnectionError, httpx.TimeoutException, httpx.NetworkError)
        return isinstance(err, network_errors) or isinstance(err.__cause__, network_errors)

    @classmethod
    async def test_naming(cls, model_type: str, text: str) -> str:
        """
        Run one naming request against the configured model and return the full answer.

        Reads the model configuration, provider and stored training conversation
        (used as context), streams the completion internally, and returns the
        concatenated text. Nothing is written to the database.

        Args:
            model_type: Model type key (must be in ``AI_MODEL_TYPES``).
            text: The user's input text.

        Returns:
            str: The complete AI response.

        Raises:
            CustomException: For an invalid model type, missing or incomplete
                configuration, or (status 503) when the AI call fails with a
                non-retryable error or after all retries.
        """
        from app.api.v1.module_application.ai.schema import AI_MODEL_TYPES

        if model_type not in AI_MODEL_TYPES:
            raise CustomException(msg=f'无效的模型类型: {model_type}')

        messages = []

        # 1. Read configuration and history (read-only).
        async with async_db_session() as session:
            async with session.begin():
                read_auth = AuthSchema(db=session, user=None, check_data_scope=False)

                config = await AIModelConfigCRUD(read_auth).get_by_model_type_crud(model_type=model_type)
                if not config:
                    raise CustomException(msg='模型配置不存在,请先配置模型')

                # Configured provider first, then fall back to a default provider.
                provider = None
                if config.provider_id:
                    provider = await AIProviderCRUD(read_auth).get_by_id_crud(id=config.provider_id)
                if not provider:
                    providers = await AIProviderCRUD(read_auth).get_list_crud(search={'is_default': 1})
                    if providers:
                        provider = providers[0]
                if not provider:
                    raise CustomException(msg='未配置AI供应商请先在AI配置中添加供应商')

                history_messages = await AIModelTrainingMessageCRUD(read_auth).get_list_by_config_crud(model_config_id=config.id)

                # Copy plain values out before the session is closed.
                config_system_prompt = config.system_prompt
                config_model_name = config.model_name
                config_temperature = config.temperature
                provider_api_key = provider.api_key
                provider_base_url = provider.base_url

                if config_system_prompt:
                    messages.append({"role": "system", "content": config_system_prompt})
                messages.extend({"role": msg.role, "content": msg.content} for msg in history_messages)
                messages.append({"role": "user", "content": text})

        if not provider_api_key or not provider_base_url:
            raise CustomException(msg="供应商配置不完整请检查API Key和Base URL")

        # 2. Call the model, collecting the stream into one string.
        import httpx
        from openai import AsyncOpenAI

        # Make sure the base URL ends in /v1.
        trimmed_url = provider_base_url.rstrip('/')
        if not trimmed_url.endswith('/v1'):
            provider_base_url = trimmed_url + '/v1'

        log.info(f"[起名测试] 开始调用AI模型类型: {model_type}, 模型: {config_model_name}, 消息数: {len(messages)}")

        retry_times = cls._RETRY_TIMES
        last_err: Exception | None = None
        for attempt in range(1, retry_times + 1):
            # Exponential back-off: 0.8s, 1.6s, ...
            sleep_s = cls._BASE_SLEEP * (2 ** (attempt - 1))
            # NOTE(security): verify=False disables TLS certificate checks; kept
            # for compatibility with the rest of this module.
            async with httpx.AsyncClient(timeout=120.0, follow_redirects=True, verify=False) as http_client:
                client = AsyncOpenAI(
                    api_key=provider_api_key,
                    base_url=provider_base_url,
                    http_client=http_client
                )
                parts: list[str] = []
                try:
                    response = await client.chat.completions.create(
                        model=config_model_name or 'gpt-3.5-turbo',
                        messages=messages,
                        temperature=config_temperature,
                        stream=True
                    )
                    async for chunk in response:
                        choices = getattr(chunk, 'choices', None)
                        if not choices:
                            continue
                        piece = getattr(choices[0].delta, 'content', None)
                        if piece:
                            parts.append(piece)
                    full_response = "".join(parts)
                    log.info(f"[起名测试] AI响应完成总长度: {len(full_response)}")
                    return full_response
                except Exception as e:
                    last_err = e
                    err_msg = cls._extract_error_message(e)
                    transient = cls._is_transient(e, cls._extract_status_code(e))
                    if not transient or attempt >= retry_times:
                        import traceback
                        log.error(f"[起名测试] AI调用失败(终止): {str(e)}")
                        log.error(f"[起名测试] 异常详情: {traceback.format_exc()}")
                        raise CustomException(msg=f"服务暂时不可用,请稍后重试。({err_msg})", status_code=503) from e
                    log.warning(f"[起名测试] AI调用失败(第{attempt}次/{retry_times}次){sleep_s:.2f}s后重试: {err_msg}")
            # Sleep only after the HTTP client is closed, so no connection is
            # held open during the back-off.
            await asyncio.sleep(sleep_s)

        # Defensive fallback: the loop always returns or raises before this point.
        raise CustomException(msg=f"服务暂时不可用,请稍后重试。({cls._extract_error_message(last_err) if last_err else ''})", status_code=503)