# -*- coding: utf-8 -*-
"""
Service layer of the AI application module.

Covers MCP servers, AI providers, embedding configurations, knowledge bases
(with background vectorization), per-type AI model configuration, training
conversations and the read-only naming test endpoint.
"""

import re
import asyncio
import aiofiles
from pathlib import Path
from typing import Any, AsyncGenerator, ClassVar, List
from concurrent.futures import ThreadPoolExecutor
from fastapi import UploadFile

from app.core.exceptions import CustomException
from app.core.logger import log
from app.core.database import async_db_session
from app.config.setting import settings
from app.api.v1.module_system.auth.schema import AuthSchema

from .tools.ai_util import AIClient
# Vectorization tools are imported lazily inside the methods that need them,
# so that they do not affect route loading.
# from .tools.document_processor import DocumentProcessor
# from .tools.embedding_util import EmbeddingUtil
from .schema import (
    McpCreateSchema,
    McpUpdateSchema,
    McpOutSchema,
    ChatQuerySchema,
    McpQueryParam,
    AIProviderCreateSchema,
    AIProviderUpdateSchema,
    AIProviderOutSchema,
    AIProviderQueryParam,
    EmbeddingConfigCreateSchema,
    EmbeddingConfigUpdateSchema,
    EmbeddingConfigOutSchema,
    EmbeddingConfigQueryParam,
    KnowledgeBaseCreateSchema,
    KnowledgeBaseUpdateSchema,
    KnowledgeBaseOutSchema,
    KnowledgeBaseQueryParam,
    AIModelConfigCreateSchema,
    AIModelConfigUpdateSchema,
    AIModelConfigOutSchema,
    AIModelConfigQueryParam,
    AIModelTrainingMessageCreateSchema,
    AIModelTrainingMessageOutSchema,
    AIModelTrainingChatSchema,
    AI_MODEL_TYPES,
)
from .crud import (
    McpCRUD,
    AIProviderCRUD,
    EmbeddingConfigCRUD,
    KnowledgeBaseCRUD,
    AIModelConfigCRUD,
    AIModelTrainingMessageCRUD,
)
from .model import EmbeddingConfigModel, AIModelConfigModel


# Values written to the knowledge base ``kb_status`` column
# (0=pending, 1=processing, 2=completed, 3=failed).
_KB_STATUS_PROCESSING = 1
_KB_STATUS_DONE = 2
_KB_STATUS_FAILED = 3

# Maximum length of the error text stored on a failed knowledge base.
_KB_ERROR_MAX_LEN = 50


def _ensure_v1_base_url(base_url: str) -> str:
    """
    Normalize an OpenAI-compatible base URL so that it ends with ``/v1``.

    Trailing slashes are stripped first; ``/v1`` is appended when missing.
    """
    base_url = base_url.rstrip('/')
    if not base_url.endswith('/v1'):
        base_url = f"{base_url}/v1"
    return base_url


def _embedding_info_from_config(embedding_config: Any) -> dict[str, Any]:
    """Collect the ``EmbeddingUtil`` constructor arguments from an embedding config row."""
    return {
        'embedding_type': embedding_config.embedding_type,
        'model_name': embedding_config.model_name,
        'base_url': embedding_config.base_url,
        'api_key': embedding_config.api_key,
    }


async def _resolve_provider(auth: AuthSchema, provider_id: int | None) -> Any:
    """
    Return the AI provider a model configuration should use.

    Uses ``provider_id`` when it points to an existing provider; otherwise falls
    back to the first provider with ``is_default == 1``. Returns None when
    neither exists.
    """
    provider = None
    if provider_id:
        provider = await AIProviderCRUD(auth).get_by_id_crud(id=provider_id)
    if not provider:
        providers = await AIProviderCRUD(auth).get_list_crud(search={'is_default': 1})
        if providers:
            provider = providers[0]
    return provider


def _chunk_text(chunk: Any) -> str | None:
    """Return the text delta of a streamed chat-completion chunk, or None if it carries none."""
    if hasattr(chunk, 'choices') and chunk.choices:
        delta = chunk.choices[0].delta
        if hasattr(delta, 'content') and delta.content:
            return delta.content
    return None


def _extract_status_code(err: Exception) -> int | None:
    """Best-effort HTTP status code of an SDK/HTTP exception; None when unknown."""
    status_code = getattr(err, "status_code", None)
    if isinstance(status_code, int):
        return status_code
    response = getattr(err, "response", None)
    if response is not None and hasattr(response, "status_code"):
        return response.status_code
    return None


def _extract_error_message(err: Exception) -> str:
    """
    Best-effort human-readable message of an SDK/HTTP exception.

    Prefers ``err.body['error']['message']`` (or ``err.body['error']`` when the
    provider returns it as a plain string) and falls back to ``str(err)``.
    """
    body = getattr(err, "body", None)
    if isinstance(body, dict):
        inner = body.get("error")
        # Some providers return {"error": "text"} instead of {"error": {"message": ...}};
        # guard the type so this helper never raises while handling another error.
        if isinstance(inner, dict):
            msg = inner.get("message")
            if isinstance(msg, str) and msg.strip():
                return msg.strip()
        elif isinstance(inner, str) and inner.strip():
            return inner.strip()
    return str(err)


class McpService:
    """Service layer for MCP servers."""

    @classmethod
    async def detail_service(cls, auth: AuthSchema, id: int) -> dict[str, Any]:
        """
        Get the details of an MCP server.

        Parameters:
        - auth (AuthSchema): authentication context
        - id (int): MCP server ID

        Returns:
        - dict[str, Any]: the serialized MCP server

        Raises:
        - CustomException: when the server does not exist
        """
        obj = await McpCRUD(auth).get_by_id_crud(id=id)
        if not obj:
            raise CustomException(msg='MCP 服务器不存在')
        return McpOutSchema.model_validate(obj).model_dump()

    @classmethod
    async def list_service(
        cls,
        auth: AuthSchema,
        search: McpQueryParam | None = None,
        order_by: list[dict[str, str]] | None = None,
    ) -> list[dict[str, Any]]:
        """
        List MCP servers.

        Parameters:
        - auth (AuthSchema): authentication context
        - search (McpQueryParam | None): query filters, passed to the CRUD layer as ``search.__dict__``
        - order_by (list[dict[str, str]] | None): sort specification

        Returns:
        - list[dict[str, Any]]: serialized MCP servers
        """
        search_dict = search.__dict__ if search else None
        obj_list = await McpCRUD(auth).get_list_crud(search=search_dict, order_by=order_by)
        return [McpOutSchema.model_validate(obj).model_dump() for obj in obj_list]

    @classmethod
    async def create_service(cls, auth: AuthSchema, data: McpCreateSchema) -> dict[str, Any]:
        """
        Create an MCP server.

        Parameters:
        - auth (AuthSchema): authentication context
        - data (McpCreateSchema): creation payload

        Returns:
        - dict[str, Any]: the created MCP server

        Raises:
        - CustomException: when a server with the same name already exists
        """
        crud = McpCRUD(auth)
        if await crud.get_by_name_crud(name=data.name):
            raise CustomException(msg='创建失败,MCP 服务器已存在')
        obj = await crud.create_crud(data=data)
        return McpOutSchema.model_validate(obj).model_dump()

    @classmethod
    async def update_service(cls, auth: AuthSchema, id: int, data: McpUpdateSchema) -> dict[str, Any]:
        """
        Update an MCP server.

        Parameters:
        - auth (AuthSchema): authentication context
        - id (int): MCP server ID
        - data (McpUpdateSchema): update payload

        Returns:
        - dict[str, Any]: the updated MCP server

        Raises:
        - CustomException: when the server does not exist or the new name is taken by another server
        """
        crud = McpCRUD(auth)
        if not await crud.get_by_id_crud(id=id):
            raise CustomException(msg='更新失败,该数据不存在')
        exist_obj = await crud.get_by_name_crud(name=data.name)
        if exist_obj and exist_obj.id != id:
            raise CustomException(msg='更新失败,MCP 服务器名称重复')
        obj = await crud.update_crud(id=id, data=data)
        return McpOutSchema.model_validate(obj).model_dump()

    @classmethod
    async def delete_service(cls, auth: AuthSchema, ids: list[int]) -> None:
        """
        Delete MCP servers in batch. Every ID is checked before anything is deleted.

        Parameters:
        - auth (AuthSchema): authentication context
        - ids (list[int]): MCP server IDs

        Raises:
        - CustomException: when ``ids`` is empty or any ID does not exist
        """
        if not ids:
            raise CustomException(msg='删除失败,删除对象不能为空')
        crud = McpCRUD(auth)
        for item_id in ids:
            if not await crud.get_by_id_crud(id=item_id):
                raise CustomException(msg='删除失败,该数据不存在')
        await crud.delete_crud(ids=ids)

    @classmethod
    async def chat_query(cls, query: ChatQuerySchema) -> AsyncGenerator[str, Any]:
        """
        Stream the answer to a chat query through ``AIClient``.

        Parameters:
        - query (ChatQuerySchema): chat query; ``query.message`` is forwarded to the client

        Yields:
        - str: response chunks produced by ``AIClient.process``

        The client is always closed, including when the consumer stops early.
        """
        mcp_client = AIClient()
        try:
            async for response in mcp_client.process(query.message):
                yield response
        finally:
            await mcp_client.close()


class AIProviderService:
    """Service layer for AI providers."""

    @classmethod
    async def detail_service(cls, auth: AuthSchema, id: int) -> dict[str, Any]:
        """
        Get the details of an AI provider.

        Raises:
        - CustomException: when the provider does not exist
        """
        obj = await AIProviderCRUD(auth).get_by_id_crud(id=id)
        if not obj:
            raise CustomException(msg='AI供应商不存在')
        return AIProviderOutSchema.model_validate(obj).model_dump()

    @classmethod
    async def list_service(
        cls,
        auth: AuthSchema,
        search: AIProviderQueryParam | None = None,
        order_by: list[dict[str, str]] | None = None,
    ) -> list[dict[str, Any]]:
        """List AI providers matching ``search`` (passed as ``search.__dict__``), ordered by ``order_by``."""
        search_dict = search.__dict__ if search else None
        obj_list = await AIProviderCRUD(auth).get_list_crud(search=search_dict, order_by=order_by)
        return [AIProviderOutSchema.model_validate(obj).model_dump() for obj in obj_list]

    @classmethod
    async def create_service(cls, auth: AuthSchema, data: AIProviderCreateSchema) -> dict[str, Any]:
        """
        Create an AI provider.

        When ``data.is_default == 1`` the default flag is first cleared on the
        other providers, so the new one becomes the single default.

        Raises:
        - CustomException: when the name is already taken
        """
        crud = AIProviderCRUD(auth)
        if await crud.get_by_name_crud(name=data.name):
            raise CustomException(msg='创建失败,AI供应商名称已存在')
        if data.is_default == 1:
            await crud.clear_default_crud()
        obj = await crud.create_crud(data=data)
        return AIProviderOutSchema.model_validate(obj).model_dump()

    @classmethod
    async def update_service(cls, auth: AuthSchema, id: int, data: AIProviderUpdateSchema) -> dict[str, Any]:
        """
        Update an AI provider.

        When ``data.is_default == 1`` the default flag is cleared everywhere
        before the update sets it on this provider.

        Raises:
        - CustomException: when the provider does not exist or the name is taken by another provider
        """
        crud = AIProviderCRUD(auth)
        if not await crud.get_by_id_crud(id=id):
            raise CustomException(msg='更新失败,该数据不存在')
        exist_obj = await crud.get_by_name_crud(name=data.name)
        if exist_obj and exist_obj.id != id:
            raise CustomException(msg='更新失败,AI供应商名称重复')
        if data.is_default == 1:
            await crud.clear_default_crud()
        obj = await crud.update_crud(id=id, data=data)
        return AIProviderOutSchema.model_validate(obj).model_dump()

    @classmethod
    async def delete_service(cls, auth: AuthSchema, ids: list[int]) -> None:
        """
        Delete AI providers in batch. Every ID is checked before anything is deleted.

        Raises:
        - CustomException: when ``ids`` is empty or any ID does not exist
        """
        if not ids:
            raise CustomException(msg='删除失败,删除对象不能为空')
        crud = AIProviderCRUD(auth)
        for item_id in ids:
            if not await crud.get_by_id_crud(id=item_id):
                raise CustomException(msg='删除失败,该数据不存在')
        await crud.delete_crud(ids=ids)


class EmbeddingConfigService:
    """Service layer for embedding (vectorization) configurations."""

    @staticmethod
    def _check_remote_settings(data: EmbeddingConfigCreateSchema | EmbeddingConfigUpdateSchema) -> None:
        """Remote mode (``embedding_type == 1``) requires both a base URL and an API key."""
        if data.embedding_type == 1 and (not data.base_url or not data.api_key):
            raise CustomException(msg='远程模式必须提供接口地址和API Key')

    @classmethod
    async def detail_service(cls, auth: AuthSchema, id: int) -> dict[str, Any]:
        """
        Get the details of an embedding configuration.

        Raises:
        - CustomException: when the configuration does not exist
        """
        obj = await EmbeddingConfigCRUD(auth).get_by_id_crud(id=id)
        if not obj:
            raise CustomException(msg='向量化配置不存在')
        return EmbeddingConfigOutSchema.model_validate(obj).model_dump()

    @classmethod
    async def list_service(
        cls,
        auth: AuthSchema,
        search: EmbeddingConfigQueryParam | None = None,
        order_by: list[dict[str, str]] | None = None,
    ) -> list[dict[str, Any]]:
        """List embedding configurations matching ``search`` (passed as ``search.__dict__``)."""
        search_dict = search.__dict__ if search else None
        obj_list = await EmbeddingConfigCRUD(auth).get_list_crud(search=search_dict, order_by=order_by)
        return [EmbeddingConfigOutSchema.model_validate(obj).model_dump() for obj in obj_list]

    @classmethod
    async def create_service(cls, auth: AuthSchema, data: EmbeddingConfigCreateSchema) -> dict[str, Any]:
        """
        Create an embedding configuration.

        Raises:
        - CustomException: when the name is taken, or remote mode lacks base URL / API key
        """
        crud = EmbeddingConfigCRUD(auth)
        if await crud.get_by_name_crud(name=data.name):
            raise CustomException(msg='创建失败,向量化配置名称已存在')
        cls._check_remote_settings(data)
        # Keep a single default configuration.
        if data.is_default == 1:
            await crud.clear_default_crud()
        obj = await crud.create_crud(data=data)
        return EmbeddingConfigOutSchema.model_validate(obj).model_dump()

    @classmethod
    async def update_service(cls, auth: AuthSchema, id: int, data: EmbeddingConfigUpdateSchema) -> dict[str, Any]:
        """
        Update an embedding configuration.

        Raises:
        - CustomException: when it does not exist, the name is taken by another
          configuration, or remote mode lacks base URL / API key
        """
        crud = EmbeddingConfigCRUD(auth)
        if not await crud.get_by_id_crud(id=id):
            raise CustomException(msg='更新失败,该数据不存在')
        exist_obj = await crud.get_by_name_crud(name=data.name)
        if exist_obj and exist_obj.id != id:
            raise CustomException(msg='更新失败,向量化配置名称重复')
        cls._check_remote_settings(data)
        # Keep a single default configuration.
        if data.is_default == 1:
            await crud.clear_default_crud()
        obj = await crud.update_crud(id=id, data=data)
        return EmbeddingConfigOutSchema.model_validate(obj).model_dump()

    @classmethod
    async def delete_service(cls, auth: AuthSchema, ids: list[int]) -> None:
        """
        Delete embedding configurations in batch. Every ID is checked before anything is deleted.

        Raises:
        - CustomException: when ``ids`` is empty or any ID does not exist
        """
        if not ids:
            raise CustomException(msg='删除失败,删除对象不能为空')
        crud = EmbeddingConfigCRUD(auth)
        for item_id in ids:
            if not await crud.get_by_id_crud(id=item_id):
                raise CustomException(msg='删除失败,该数据不存在')
        await crud.delete_crud(ids=ids)


class KnowledgeBaseService:
    """Service layer for knowledge bases."""

    # The event loop only keeps weak references to tasks, so fire-and-forget
    # vectorization tasks are held here until they finish; otherwise they may
    # be garbage-collected mid-execution.
    _background_tasks: ClassVar[set[asyncio.Task]] = set()

    @staticmethod
    def _generate_collection_name(name: str) -> str:
        """
        Build the ChromaDB collection name (also used as the upload directory name).

        Characters outside ``[a-zA-Z0-9_]`` become ``_``. The cleaned name is
        truncated so the result stays within 63 characters (the collection-name
        limit enforced by at least older ChromaDB releases). A timestamp plus a
        short random suffix keep names unique even when different names clean
        to the same string (e.g. CJK names of equal length) within one second.
        """
        import secrets
        import time

        clean_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)[:40]
        return f"kb_{clean_name}_{int(time.time())}_{secrets.token_hex(3)}"

    @staticmethod
    def _safe_filename(filename: str | None) -> str | None:
        """
        Reduce a client-supplied upload filename to its final path component.

        Prevents writes outside the knowledge-base directory (e.g. ``../../x``).
        Returns None when nothing usable remains.
        """
        if not filename:
            return None
        name = Path(filename.replace('\\', '/')).name
        if name in ('', '.', '..'):
            return None
        return name

    @classmethod
    def _spawn_vectorization(
        cls,
        knowledge_base_id: int,
        collection_name: str,
        file_contents: List[dict],
        embedding_info: dict | None,
    ) -> None:
        """Start ``_process_vectorization`` as a background task and keep a strong reference to it."""
        task = asyncio.create_task(
            cls._process_vectorization(
                knowledge_base_id=knowledge_base_id,
                collection_name=collection_name,
                file_contents=file_contents,
                embedding_info=embedding_info,
            )
        )
        cls._background_tasks.add(task)
        task.add_done_callback(cls._background_tasks.discard)

    @classmethod
    async def detail_service(cls, auth: AuthSchema, id: int) -> dict[str, Any]:
        """
        Get the details of a knowledge base.

        Raises:
        - CustomException: when the knowledge base does not exist
        """
        obj = await KnowledgeBaseCRUD(auth).get_by_id_crud(id=id)
        if not obj:
            raise CustomException(msg='知识库不存在')
        return KnowledgeBaseOutSchema.model_validate(obj).model_dump()

    @classmethod
    async def list_service(
        cls,
        auth: AuthSchema,
        search: KnowledgeBaseQueryParam | None = None,
        order_by: list[dict[str, str]] | None = None,
    ) -> list[dict[str, Any]]:
        """List knowledge bases matching ``search`` (passed as ``search.__dict__``)."""
        search_dict = search.__dict__ if search else None
        obj_list = await KnowledgeBaseCRUD(auth).get_list_crud(search=search_dict, order_by=order_by)
        return [KnowledgeBaseOutSchema.model_validate(obj).model_dump() for obj in obj_list]

    @classmethod
    async def create_service(
        cls,
        auth: AuthSchema,
        data: KnowledgeBaseCreateSchema,
        files: List[UploadFile] | None = None,
    ) -> dict[str, Any]:
        """
        Create a knowledge base and optionally vectorize uploaded files in the background.

        Uploaded files are saved under
        ``settings.UPLOAD_FILE_PATH / "knowledge_base" / <collection_name>``.
        Files without a usable name are skipped. When at least one file was
        saved, the record starts as "processing" and a background task
        vectorizes it; otherwise it is created as "completed".

        Parameters:
        - auth (AuthSchema): authentication context
        - data (KnowledgeBaseCreateSchema): creation payload
        - files (List[UploadFile] | None): documents to index

        Returns:
        - dict[str, Any]: the created knowledge base

        Raises:
        - CustomException: when the name is taken or the embedding config does not exist
        """
        if await KnowledgeBaseCRUD(auth).get_by_name_crud(name=data.name):
            raise CustomException(msg='创建失败,知识库名称已存在')

        embedding_config = None
        if data.embedding_config_id:
            embedding_config = await EmbeddingConfigCRUD(auth).get_by_id_crud(id=data.embedding_config_id)
            if not embedding_config:
                raise CustomException(msg='向量化配置不存在')

        collection_name = cls._generate_collection_name(data.name)

        # Persist the uploads so a failed vectorization can be retried later.
        file_paths: list[str] = []
        file_contents: list[dict[str, Any]] = []
        if files:
            kb_dir = settings.UPLOAD_FILE_PATH / "knowledge_base" / collection_name
            kb_dir.mkdir(parents=True, exist_ok=True)
            for file in files:
                safe_name = cls._safe_filename(file.filename)
                if not safe_name:
                    continue
                content = await file.read()
                file_path = kb_dir / safe_name
                async with aiofiles.open(file_path, 'wb') as f:
                    await f.write(content)
                file_paths.append(str(file_path))
                file_contents.append({
                    'filename': safe_name,
                    'content': content,
                    'content_type': file.content_type,
                })

        # Derive status/count from what was actually saved, so a record with
        # nothing to vectorize never stays "processing" forever.
        initial_status = _KB_STATUS_PROCESSING if file_contents else _KB_STATUS_DONE
        create_data = {
            'name': data.name,
            'embedding_config_id': data.embedding_config_id,
            'collection_name': collection_name,
            'description': data.description,
            'document_count': len(file_contents),
            'vector_count': 0,
            'kb_status': initial_status,
            'file_paths': file_paths or None,
        }
        obj = await KnowledgeBaseCRUD(auth).create_crud(data=create_data)

        if file_contents:
            embedding_info = _embedding_info_from_config(embedding_config) if embedding_config else None
            cls._spawn_vectorization(
                knowledge_base_id=obj.id,
                collection_name=collection_name,
                file_contents=file_contents,
                embedding_info=embedding_info,
            )

        return KnowledgeBaseOutSchema.model_validate(obj).model_dump()

    @staticmethod
    def _build_documents(file_contents: List[dict]) -> list:
        """
        Convert raw files into LangChain ``Document`` objects.

        - txt / md: decoded with the first of utf-8, gbk, gb2312, latin-1 that succeeds
        - pdf: one document per page with non-blank extracted text (``page`` is 1-based)
        - docx: non-blank paragraphs joined with newlines
        Other extensions are ignored; files that fail to parse are logged and skipped.

        Parameters:
        - file_contents: ``[{'filename': str, 'content': bytes, 'content_type': str | None}]``
        """
        from io import BytesIO
        from langchain.schema import Document

        documents = []
        for file_info in file_contents:
            filename = file_info['filename']
            content = file_info['content']
            if not filename:
                continue

            ext = filename.rsplit('.', 1)[-1].lower() if '.' in filename else ''
            try:
                if ext in ('txt', 'md'):
                    text = None
                    for encoding in ('utf-8', 'gbk', 'gb2312', 'latin-1'):
                        try:
                            text = content.decode(encoding)
                            break
                        except UnicodeDecodeError:
                            continue
                    if text:
                        documents.append(Document(
                            page_content=text,
                            metadata={"source": filename, "type": ext},
                        ))
                elif ext == 'pdf':
                    try:
                        import pypdf

                        reader = pypdf.PdfReader(BytesIO(content))
                        for page_num, page in enumerate(reader.pages, start=1):
                            text = page.extract_text()
                            if text and text.strip():
                                documents.append(Document(
                                    page_content=text,
                                    metadata={"source": filename, "type": "pdf", "page": page_num},
                                ))
                    except Exception as e:
                        log.error(f"处理 PDF 文件失败: {filename}, 错误: {e}")
                elif ext == 'docx':
                    try:
                        from docx import Document as DocxDocument

                        doc = DocxDocument(BytesIO(content))
                        text = '\n'.join(para.text for para in doc.paragraphs if para.text.strip())
                        if text.strip():
                            documents.append(Document(
                                page_content=text,
                                metadata={"source": filename, "type": "docx"},
                            ))
                    except Exception as e:
                        log.error(f"处理 DOCX 文件失败: {filename}, 错误: {e}")
            except Exception as e:
                log.error(f"处理文件失败: {filename}, 错误: {e}")
        return documents

    @classmethod
    async def _process_vectorization(
        cls,
        knowledge_base_id: int,
        collection_name: str,
        file_contents: List[dict],
        embedding_info: dict | None,
    ) -> None:
        """
        Background vectorization task.

        Parses the files, embeds them into the ChromaDB collection and records
        the outcome on the knowledge base (completed with ``vector_count``, or
        failed with a truncated error message). Every failure, including a
        missing parsing dependency, is recorded instead of leaving the status
        at "processing".

        Parameters:
        - knowledge_base_id: knowledge base ID
        - collection_name: ChromaDB collection name
        - file_contents: ``[{'filename': str, 'content': bytes, 'content_type': str | None}]``
        - embedding_info: ``EmbeddingUtil`` arguments, or None to use the built-in default
        """
        try:
            # 1. Parse documents off the event loop (PDF/DOCX parsing is synchronous).
            all_documents = await asyncio.to_thread(cls._build_documents, file_contents)
            if not all_documents:
                await cls._update_kb_status(
                    knowledge_base_id=knowledge_base_id,
                    kb_status=_KB_STATUS_FAILED,
                    error_message="没有有效的文档内容",
                )
                return

            # 2. Create the embedding tool (lazy import).
            from .tools.embedding_util import EmbeddingUtil

            if embedding_info:
                embedding_util = EmbeddingUtil(
                    embedding_type=embedding_info['embedding_type'],
                    model_name=embedding_info['model_name'],
                    base_url=embedding_info['base_url'],
                    api_key=embedding_info['api_key'],
                )
            else:
                # NOTE(review): placeholder credentials; without an embedding config the
                # remote call will presumably be rejected until this default is wired
                # to real settings.
                embedding_util = EmbeddingUtil(
                    embedding_type=1,  # remote mode
                    model_name='text-embedding-3-small',
                    base_url='https://api.openai.com/v1',
                    api_key='your-api-key',
                )

            # 3. ChromaDB operations are synchronous: run them in a worker thread.
            loop = asyncio.get_running_loop()
            with ThreadPoolExecutor(max_workers=1) as executor:
                vector_count = await loop.run_in_executor(
                    executor,
                    embedding_util.add_documents,
                    collection_name,
                    all_documents,
                )

            # 4. Mark the knowledge base as completed.
            await cls._update_kb_status(
                knowledge_base_id=knowledge_base_id,
                kb_status=_KB_STATUS_DONE,
                vector_count=vector_count,
            )
        except Exception as e:
            log.error(f"向量化处理失败,知识库ID: {knowledge_base_id}, 错误: {e}")
            await cls._update_kb_status(
                knowledge_base_id=knowledge_base_id,
                kb_status=_KB_STATUS_FAILED,
                error_message=str(e)[:_KB_ERROR_MAX_LEN],
            )

    @classmethod
    async def _update_kb_status(
        cls,
        knowledge_base_id: int,
        kb_status: int,
        vector_count: int = 0,
        error_message: str | None = None,
    ) -> None:
        """
        Update a knowledge base's status in an independent database session.

        Failures are logged, not raised, because this runs from background tasks.

        Parameters:
        - knowledge_base_id: knowledge base ID
        - kb_status: status (0=pending, 1=processing, 2=completed, 3=failed)
        - vector_count: number of vectors
        - error_message: error text; None leaves the column unchanged, ``''`` clears it
        """
        from .model import KnowledgeBaseModel
        from sqlalchemy import update

        update_data: dict[str, Any] = {
            'kb_status': kb_status,
            'vector_count': vector_count,
        }
        # ``is not None`` (not truthiness) so that '' actually clears a previous error.
        if error_message is not None:
            update_data['error_message'] = error_message

        try:
            async with async_db_session() as session:
                async with session.begin():
                    stmt = (
                        update(KnowledgeBaseModel)
                        .where(KnowledgeBaseModel.id == knowledge_base_id)
                        .values(**update_data)
                    )
                    await session.execute(stmt)
                    await session.commit()
        except Exception as e:
            log.error(f"更新知识库状态失败,ID: {knowledge_base_id}, 错误: {e}")

    @classmethod
    async def update_service(cls, auth: AuthSchema, id: int, data: KnowledgeBaseUpdateSchema) -> dict[str, Any]:
        """
        Update a knowledge base's metadata (documents are not re-vectorized).

        Raises:
        - CustomException: when it does not exist, the new name is taken by
          another knowledge base, or the embedding config does not exist
        """
        crud = KnowledgeBaseCRUD(auth)
        if not await crud.get_by_id_crud(id=id):
            raise CustomException(msg='更新失败,该数据不存在')

        if data.name:
            exist_obj = await crud.get_by_name_crud(name=data.name)
            if exist_obj and exist_obj.id != id:
                raise CustomException(msg='更新失败,知识库名称重复')

        if data.embedding_config_id:
            embedding_config = await EmbeddingConfigCRUD(auth).get_by_id_crud(id=data.embedding_config_id)
            if not embedding_config:
                raise CustomException(msg='向量化配置不存在')

        obj = await crud.update_crud(id=id, data=data)
        return KnowledgeBaseOutSchema.model_validate(obj).model_dump()

    @staticmethod
    def _remove_kb_files(obj: Any) -> None:
        """Best-effort removal of a knowledge base's saved files and, if then empty, its directory."""
        for file_path in obj.file_paths or []:
            try:
                path = Path(file_path)
                if path.exists():
                    path.unlink()
            except Exception as e:
                log.warning(f"删除文件失败: {file_path}, 错误: {e}")

        if not obj.collection_name:
            return
        try:
            kb_dir = settings.UPLOAD_FILE_PATH / "knowledge_base" / obj.collection_name
            if kb_dir.exists() and not any(kb_dir.iterdir()):
                kb_dir.rmdir()
        except Exception as e:
            log.warning(f"删除知识库目录失败: {obj.collection_name}, 错误: {e}")

    @classmethod
    async def _drop_collection(cls, auth: AuthSchema, obj: Any) -> None:
        """Best-effort deletion of a knowledge base's ChromaDB collection (errors are logged)."""
        if not obj.collection_name:
            return
        try:
            from .tools.embedding_util import EmbeddingUtil

            embedding_info = None
            if obj.embedding_config_id:
                embedding_config = await EmbeddingConfigCRUD(auth).get_by_id_crud(id=obj.embedding_config_id)
                if embedding_config:
                    embedding_info = _embedding_info_from_config(embedding_config)

            embedding_util = EmbeddingUtil(**embedding_info) if embedding_info else EmbeddingUtil()
            embedding_util.delete_collection(obj.collection_name)
        except Exception as e:
            log.warning(f"删除 ChromaDB 集合失败: {obj.collection_name}, 错误: {e}")

    @classmethod
    async def delete_service(cls, auth: AuthSchema, ids: list[int]) -> None:
        """
        Delete knowledge bases together with their saved files and ChromaDB collections.

        All IDs are validated before any file or collection is touched, so an
        invalid ID cannot leave earlier knowledge bases half-deleted.

        Raises:
        - CustomException: when ``ids`` is empty or any ID does not exist
        """
        if not ids:
            raise CustomException(msg='删除失败,删除对象不能为空')

        crud = KnowledgeBaseCRUD(auth)
        objs = []
        for item_id in ids:
            obj = await crud.get_by_id_crud(id=item_id)
            if not obj:
                raise CustomException(msg='删除失败,该数据不存在')
            objs.append(obj)

        for obj in objs:
            cls._remove_kb_files(obj)
            await cls._drop_collection(auth, obj)

        await crud.delete_crud(ids=ids)

    @classmethod
    async def retry_service(cls, auth: AuthSchema, id: int) -> dict[str, Any]:
        """
        Retry vectorization of a failed knowledge base from its saved files.

        Parameters:
        - auth (AuthSchema): authentication context
        - id (int): knowledge base ID

        Returns:
        - dict[str, Any]: knowledge base details, reflecting the "processing" state just written

        Raises:
        - CustomException: when it does not exist, is not in the failed state,
          or none of its saved files can be found
        """
        obj = await KnowledgeBaseCRUD(auth).get_by_id_crud(id=id)
        if not obj:
            raise CustomException(msg='知识库不存在')
        if obj.kb_status != _KB_STATUS_FAILED:
            raise CustomException(msg='只有处理失败的知识库才能重试')
        if not obj.file_paths:
            raise CustomException(msg='没有找到已保存的文件,请删除后重新创建知识库')

        file_contents = []
        for file_path in obj.file_paths:
            path = Path(file_path)
            if not path.exists():
                log.warning(f"文件不存在: {file_path}")
                continue
            async with aiofiles.open(path, 'rb') as f:
                content = await f.read()
            file_contents.append({
                'filename': path.name,
                'content': content,
                'content_type': None,
            })
        if not file_contents:
            raise CustomException(msg='所有文件均已丢失,请删除后重新创建知识库')

        embedding_info = None
        if obj.embedding_config_id:
            embedding_config = await EmbeddingConfigCRUD(auth).get_by_id_crud(id=obj.embedding_config_id)
            if embedding_config:
                embedding_info = _embedding_info_from_config(embedding_config)

        # Mark as processing and clear the previous error message.
        await cls._update_kb_status(
            knowledge_base_id=id,
            kb_status=_KB_STATUS_PROCESSING,
            vector_count=0,
            error_message='',
        )

        cls._spawn_vectorization(
            knowledge_base_id=id,
            collection_name=obj.collection_name,
            file_contents=file_contents,
            embedding_info=embedding_info,
        )

        obj = await KnowledgeBaseCRUD(auth).get_by_id_crud(id=id)
        result = KnowledgeBaseOutSchema.model_validate(obj).model_dump()
        # The status was written through a separate session; this session may
        # still hold the pre-retry row in its identity map, so report what was
        # just written.
        for key, value in (('kb_status', _KB_STATUS_PROCESSING), ('vector_count', 0), ('error_message', '')):
            if key in result:
                result[key] = value
        return result


class AIModelConfigService:
    """Service layer for per-type AI model configuration."""

    @classmethod
    async def detail_service(cls, auth: AuthSchema, model_type: str) -> dict[str, Any]:
        """
        Get the configuration of a model type, creating a default one if missing.

        Raises:
        - CustomException: when ``model_type`` is not in ``AI_MODEL_TYPES``
        """
        if model_type not in AI_MODEL_TYPES:
            raise CustomException(msg=f'无效的模型类型: {model_type}')

        crud = AIModelConfigCRUD(auth)
        obj = await crud.get_by_model_type_crud(model_type=model_type)
        if not obj:
            obj = await crud.create_crud(data={
                'model_type': model_type,
                'model_name': None,
                'provider_id': None,
                'system_prompt': None,
                'temperature': 1.0,
                'knowledge_base_ids': None,
            })
        return AIModelConfigOutSchema.model_validate(obj).model_dump()

    @classmethod
    async def list_service(
        cls,
        auth: AuthSchema,
        search: AIModelConfigQueryParam | None = None,
        order_by: list[dict[str, str]] | None = None,
    ) -> list[dict[str, Any]]:
        """List model configurations matching ``search`` (passed as ``search.__dict__``)."""
        search_dict = search.__dict__ if search else None
        obj_list = await AIModelConfigCRUD(auth).get_list_crud(search=search_dict, order_by=order_by)
        return [AIModelConfigOutSchema.model_validate(obj).model_dump() for obj in obj_list]

    @classmethod
    async def update_service(cls, auth: AuthSchema, model_type: str, data: AIModelConfigUpdateSchema) -> dict[str, Any]:
        """
        Update (or create) the configuration of a model type.

        On update, only fields that are not None in ``data`` are written, so a
        None value means "keep the current value". On creation, a missing
        temperature defaults to 1.0.

        Raises:
        - CustomException: when ``model_type`` is not in ``AI_MODEL_TYPES``
        """
        if model_type not in AI_MODEL_TYPES:
            raise CustomException(msg=f'无效的模型类型: {model_type}')

        crud = AIModelConfigCRUD(auth)
        obj = await crud.get_by_model_type_crud(model_type=model_type)
        if not obj:
            obj = await crud.create_crud(data={
                'model_type': model_type,
                'model_name': data.model_name,
                'provider_id': data.provider_id,
                'system_prompt': data.system_prompt,
                'temperature': data.temperature if data.temperature is not None else 1.0,
                'knowledge_base_ids': data.knowledge_base_ids,
            })
        else:
            update_data: dict[str, Any] = {}
            for field in ('model_name', 'provider_id', 'system_prompt', 'temperature', 'knowledge_base_ids'):
                value = getattr(data, field)
                if value is not None:
                    update_data[field] = value
            if update_data:
                obj = await crud.update_crud(id=obj.id, data=update_data)
        return AIModelConfigOutSchema.model_validate(obj).model_dump()

    @classmethod
    async def get_available_models(cls, auth: AuthSchema, provider_id: int) -> list[dict[str, Any]]:
        """
        Fetch the model list from a provider's OpenAI-compatible ``GET {base_url}/v1/models``.

        Returns:
        - list[dict[str, Any]]: ``[{'id': ..., 'name': ...}]`` sorted by id;
          ``name`` falls back to ``id`` when the provider sends none

        Raises:
        - CustomException: when the provider does not exist or has no base URL,
          or on HTTP / network / parsing errors
        """
        import httpx

        provider = await AIProviderCRUD(auth).get_by_id_crud(id=provider_id)
        if not provider:
            raise CustomException(msg='AI供应商不存在')
        if not provider.base_url:
            raise CustomException(msg='供应商配置不完整,请检查API Key和Base URL')

        models_url = f"{_ensure_v1_base_url(provider.base_url)}/models"
        headers = {
            'Authorization': f'Bearer {provider.api_key}',
            'Content-Type': 'application/json',
        }

        try:
            # NOTE(review): TLS verification is disabled (verify=False), presumably to
            # tolerate self-signed provider endpoints; consider making it configurable.
            async with httpx.AsyncClient(timeout=30.0, verify=False) as client:
                response = await client.get(models_url, headers=headers)
                response.raise_for_status()
                payload = response.json()

            models = []
            for model in payload.get('data', []):
                model_id = model.get('id', '')
                if model_id:
                    models.append({'id': model_id, 'name': model.get('name', model_id)})
            models.sort(key=lambda item: item['id'])
            return models
        except httpx.HTTPStatusError as e:
            log.error(f"获取模型列表失败, HTTP错误: {e.response.status_code}, {e.response.text}")
            if e.response.status_code == 401:
                raise CustomException(msg='API Key无效或已过期') from e
            if e.response.status_code == 403:
                raise CustomException(msg='无权限访问模型列表') from e
            raise CustomException(msg=f'获取模型列表失败: HTTP {e.response.status_code}') from e
        except httpx.RequestError as e:
            log.error(f"获取模型列表失败, 网络错误: {str(e)}")
            raise CustomException(msg=f'网络请求失败: {str(e)}') from e
        except Exception as e:
            log.error(f"获取模型列表失败: {str(e)}")
            raise CustomException(msg=f'获取模型列表失败: {str(e)}') from e


class AIModelTrainingService:
    """Service layer for AI model training conversations."""

    @classmethod
    async def get_messages_service(cls, auth: AuthSchema, model_type: str) -> list[dict[str, Any]]:
        """
        Get the training conversation of a model type (empty if it has no configuration).

        Raises:
        - CustomException: when ``model_type`` is not in ``AI_MODEL_TYPES``
        """
        if model_type not in AI_MODEL_TYPES:
            raise CustomException(msg=f'无效的模型类型: {model_type}')

        config = await AIModelConfigCRUD(auth).get_by_model_type_crud(model_type=model_type)
        if not config:
            return []

        messages = await AIModelTrainingMessageCRUD(auth).get_list_by_config_crud(model_config_id=config.id)
        return [AIModelTrainingMessageOutSchema.model_validate(msg).model_dump() for msg in messages]

    @classmethod
    async def delete_message_service(cls, auth: AuthSchema, message_id: int) -> None:
        """
        Delete a single training message.

        Raises:
        - CustomException: when the message does not exist
        """
        crud = AIModelTrainingMessageCRUD(auth)
        if not await crud.get_by_id_crud(id=message_id):
            raise CustomException(msg='对话记录不存在')
        await crud.delete_crud(ids=[message_id])

    @classmethod
    async def clear_messages_service(cls, auth: AuthSchema, model_type: str) -> None:
        """
        Delete every training message of a model type (no-op if it has no configuration).

        Raises:
        - CustomException: when ``model_type`` is not in ``AI_MODEL_TYPES``
        """
        if model_type not in AI_MODEL_TYPES:
            raise CustomException(msg=f'无效的模型类型: {model_type}')

        config = await AIModelConfigCRUD(auth).get_by_model_type_crud(model_type=model_type)
        if config:
            await AIModelTrainingMessageCRUD(auth).delete_by_config_crud(model_config_id=config.id)

    @classmethod
    async def chat_stream(cls, auth: AuthSchema, chat_data: AIModelTrainingChatSchema) -> AsyncGenerator[str, Any]:
        """
        Stream a training-conversation reply.

        Note: because of how StreamingResponse works, the dependency-injected
        database session is closed while the response is streaming, so
        independent database sessions are used to read configuration and save
        messages.

        Important: ``yield`` must not be used inside an ``async with session.begin()``
        block, otherwise the generator does not execute correctly.

        Flow: (1) optionally save changed config, store the user message and
        build the prompt in one session; (2) yield configuration errors as
        text; (3) stream the model's reply; (4) store the assistant reply in a
        new session.
        """
        model_type = chat_data.model_type
        if model_type not in AI_MODEL_TYPES:
            yield f'无效的模型类型: {model_type}'
            return

        error_message = None
        config_id = None
        config_model_name = None
        config_temperature = None
        provider_api_key = None
        provider_base_url = None
        messages: list[dict[str, str]] = []

        # 1. All database work happens in an independent session (no yield inside).
        async with async_db_session() as session:
            async with session.begin():
                stream_auth = AuthSchema(db=session, user=auth.user, check_data_scope=False)

                if chat_data.config_changed and chat_data.config_data:
                    await AIModelConfigService.update_service(
                        auth=stream_auth, model_type=model_type, data=chat_data.config_data
                    )

                config = await AIModelConfigCRUD(stream_auth).get_by_model_type_crud(model_type=model_type)
                if not config:
                    error_message = '模型配置不存在,请先配置模型'
                else:
                    user_msg = await AIModelTrainingMessageCRUD(stream_auth).create_crud(data={
                        'model_config_id': config.id,
                        'role': 'user',
                        'content': chat_data.message,
                    })

                    provider = await _resolve_provider(stream_auth, config.provider_id)
                    if not provider:
                        error_message = '未配置AI供应商,请先在AI配置中添加供应商'
                    else:
                        history_messages = await AIModelTrainingMessageCRUD(stream_auth).get_list_by_config_crud(
                            model_config_id=config.id
                        )

                        # Copy plain values out: ORM objects are unusable once the session closes.
                        config_id = config.id
                        config_model_name = config.model_name
                        config_temperature = config.temperature
                        provider_api_key = provider.api_key
                        provider_base_url = provider.base_url

                        if config.system_prompt:
                            messages.append({"role": "system", "content": config.system_prompt})
                        # The just-saved user message is appended last explicitly, so skip it here.
                        messages.extend(
                            {"role": msg.role, "content": msg.content}
                            for msg in history_messages
                            if msg.id != user_msg.id
                        )
                        messages.append({"role": "user", "content": chat_data.message})

                await session.commit()

        # 2. Report configuration problems as the reply text.
        if error_message:
            log.warning(f"[训练对话] 配置错误: {error_message}")
            yield error_message
            return

        if not provider_api_key or not provider_base_url:
            log.error("[训练对话] 供应商配置不完整")
            yield "供应商配置不完整,请检查API Key和Base URL"
            return

        # 3. Stream the reply from the OpenAI-compatible API.
        import httpx
        from openai import AsyncOpenAI

        provider_base_url = _ensure_v1_base_url(provider_base_url)

        log.info(f"[训练对话] 开始调用AI,模型: {config_model_name}, 消息数: {len(messages)}")
        log.info(f"[训练对话] API配置: base_url={provider_base_url}")

        # NOTE(review): TLS verification is disabled (verify=False), presumably to
        # tolerate self-signed provider endpoints; consider making it configurable.
        http_client = httpx.AsyncClient(timeout=60.0, follow_redirects=True, verify=False)
        client = AsyncOpenAI(api_key=provider_api_key, base_url=provider_base_url, http_client=http_client)

        full_response = ""
        response = None

        async def _close_response() -> None:
            """Close the upstream stream so the provider stops generating; errors are ignored."""
            if response is not None:
                try:
                    await response.close()
                except Exception:
                    pass

        try:
            response = await client.chat.completions.create(
                model=config_model_name or 'gpt-3.5-turbo',
                messages=messages,
                temperature=config_temperature,
                stream=True,
            )
            log.info("[训练对话] AI响应开始流式输出")

            async for chunk in response:
                content = _chunk_text(chunk)
                if content:
                    full_response += content
                    yield content

            log.info(f"[训练对话] AI响应完成,总长度: {len(full_response)}")
        except asyncio.CancelledError:
            # Client disconnected: stop the upstream stream and re-raise so FastAPI sees the cancellation.
            log.info(f"[训练对话] 客户端断开连接,停止AI输出,已生成长度: {len(full_response)}")
            await _close_response()
            raise
        except GeneratorExit:
            # Generator closed by the consumer (usually a client disconnect).
            log.info(f"[训练对话] 生成器关闭,停止AI输出,已生成长度: {len(full_response)}")
            await _close_response()
            raise
        except Exception as e:
            import traceback

            log.error(f"AI训练对话失败: {str(e)}")
            log.error(f"AI训练对话异常详情: {traceback.format_exc()}")
            error_msg = f"对话失败: {str(e)}"
            # NOTE(review): the error text is saved below as the assistant reply and
            # will be replayed as context in later turns; confirm this is intended.
            full_response = error_msg
            yield error_msg
        finally:
            await http_client.aclose()

        # 4. Save the assistant reply in a fresh session (the first one is closed).
        if full_response and config_id:
            async with async_db_session() as session:
                async with session.begin():
                    save_auth = AuthSchema(db=session, user=auth.user, check_data_scope=False)
                    await AIModelTrainingMessageCRUD(save_auth).create_crud(data={
                        'model_config_id': config_id,
                        'role': 'assistant',
                        'content': full_response,
                    })
                    await session.commit()


class AIModelTestService:
    """Naming test service (for mini-program / external callers; read-only)."""

    @classmethod
    async def test_naming(cls, model_type: str, text: str) -> str:
        """
        Naming test endpoint.

        - Reads the model configuration and its training conversation
        - Calls the AI with streaming internally
        - Joins the streamed chunks and returns the full text at once
        - Read-only: no conversation records are saved

        Transient failures (HTTP 408/429/500/502/503/504, timeouts, network or
        connection errors) are retried up to 3 attempts with exponential
        backoff starting at 0.8s.

        Parameters:
        - model_type: model type (e.g. enterprise_naming / personal_naming)
        - text: user input

        Returns:
        - The complete AI reply text

        Raises:
        - CustomException: for an invalid model type or missing configuration,
          or with ``status_code=503`` when the AI call ultimately fails
        """
        if model_type not in AI_MODEL_TYPES:
            raise CustomException(msg=f'无效的模型类型: {model_type}')

        messages: list[dict[str, str]] = []

        # 1. Read the configuration and training history (read-only).
        async with async_db_session() as session:
            async with session.begin():
                read_auth = AuthSchema(db=session, user=None, check_data_scope=False)

                config = await AIModelConfigCRUD(read_auth).get_by_model_type_crud(model_type=model_type)
                if not config:
                    raise CustomException(msg='模型配置不存在,请先配置模型')

                provider = await _resolve_provider(read_auth, config.provider_id)
                if not provider:
                    raise CustomException(msg='未配置AI供应商,请先在AI配置中添加供应商')

                history_messages = await AIModelTrainingMessageCRUD(read_auth).get_list_by_config_crud(
                    model_config_id=config.id
                )

                # Copy plain values out: ORM objects are unusable once the session closes.
                config_model_name = config.model_name
                config_temperature = config.temperature
                provider_api_key = provider.api_key
                provider_base_url = provider.base_url

                if config.system_prompt:
                    messages.append({"role": "system", "content": config.system_prompt})
                # The training conversation serves as context for the request.
                messages.extend({"role": msg.role, "content": msg.content} for msg in history_messages)
                messages.append({"role": "user", "content": text})

        if not provider_api_key or not provider_base_url:
            raise CustomException(msg="供应商配置不完整,请检查API Key和Base URL")

        # 2. Call the AI (streamed internally, returned as one string).
        import httpx
        from openai import APIConnectionError, AsyncOpenAI

        provider_base_url = _ensure_v1_base_url(provider_base_url)

        log.info(f"[起名测试] 开始调用AI,模型类型: {model_type}, 模型: {config_model_name}, 消息数: {len(messages)}")

        retry_times = 3
        base_sleep = 0.8
        transient_status_codes = {408, 429, 500, 502, 503, 504}
        # The OpenAI SDK wraps transport failures in APIConnectionError (APITimeoutError
        # subclasses it); raw httpx errors can still surface while iterating the stream.
        transient_error_types = (httpx.TimeoutException, httpx.NetworkError, APIConnectionError)
        last_err: Exception | None = None

        for attempt in range(1, retry_times + 1):
            # NOTE(review): TLS verification is disabled (verify=False), presumably to
            # tolerate self-signed provider endpoints; consider making it configurable.
            http_client = httpx.AsyncClient(timeout=120.0, follow_redirects=True, verify=False)
            client = AsyncOpenAI(api_key=provider_api_key, base_url=provider_base_url, http_client=http_client)
            full_response = ""
            try:
                response = await client.chat.completions.create(
                    model=config_model_name or 'gpt-3.5-turbo',
                    messages=messages,
                    temperature=config_temperature,
                    stream=True,
                )
                async for chunk in response:
                    content = _chunk_text(chunk)
                    if content:
                        full_response += content

                log.info(f"[起名测试] AI响应完成,总长度: {len(full_response)}")
                return full_response
            except Exception as e:
                last_err = e
                status_code = _extract_status_code(e)
                err_msg = _extract_error_message(e)
                transient = status_code in transient_status_codes or isinstance(e, transient_error_types)

                if not transient or attempt >= retry_times:
                    import traceback

                    log.error(f"[起名测试] AI调用失败(终止): {str(e)}")
                    log.error(f"[起名测试] 异常详情: {traceback.format_exc()}")
                    raise CustomException(msg=f"服务暂时不可用,请稍后重试。({err_msg})", status_code=503) from e

                sleep_s = base_sleep * (2 ** (attempt - 1))
                log.warning(f"[起名测试] AI调用失败(第{attempt}次/{retry_times}次),{sleep_s:.2f}s后重试: {err_msg}")
                await asyncio.sleep(sleep_s)
            finally:
                await http_client.aclose()

        # Defensive: the loop always returns or raises above.
        raise CustomException(
            msg=f"服务暂时不可用,请稍后重试。({_extract_error_message(last_err) if last_err else ''})",
            status_code=503,
        )