# -*- coding: utf-8 -*-

from typing import Sequence, Any

from app.core.base_crud import CRUDBase
from app.api.v1.module_system.auth.schema import AuthSchema

from .model import (
    McpModel,
    AIProviderModel,
    EmbeddingConfigModel,
    KnowledgeBaseModel,
    AIModelConfigModel,
    AIModelTrainingMessageModel,
)
from .schema import (
    McpCreateSchema,
    McpUpdateSchema,
    AIProviderCreateSchema,
    AIProviderUpdateSchema,
    EmbeddingConfigCreateSchema,
    EmbeddingConfigUpdateSchema,
    KnowledgeBaseCreateSchema,
    KnowledgeBaseUpdateSchema,
    AIModelConfigCreateSchema,
    AIModelConfigUpdateSchema,
    AIModelTrainingMessageCreateSchema,
)


class McpCRUD(CRUDBase[McpModel, McpCreateSchema, McpUpdateSchema]):
    """Data-access layer for MCP server records."""

    def __init__(self, auth: AuthSchema) -> None:
        """
        Keep the caller's auth context and initialize the base CRUD for ``McpModel``.

        Parameters:
        - auth (AuthSchema): authentication information, also forwarded to ``CRUDBase``.
        """
        self.auth = auth
        super().__init__(model=McpModel, auth=auth)

    async def get_by_id_crud(
        self,
        id: int,
        preload: list[str | Any] | None = None,
    ) -> McpModel | None:
        """
        Look up one MCP server by its ID.

        Parameters:
        - id (int): MCP server ID.
        - preload (list[str | Any] | None): relationships to preload; the model's
          defaults are used when not provided.

        Returns:
        - McpModel | None: the matching MCP server, if any.
        """
        return await self.get(id=id, preload=preload)

    async def get_by_name_crud(
        self,
        name: str,
        preload: list[str | Any] | None = None,
    ) -> McpModel | None:
        """
        Look up one MCP server by its name.

        Parameters:
        - name (str): MCP server name.
        - preload (list[str | Any] | None): relationships to preload; the model's
          defaults are used when not provided.

        Returns:
        - McpModel | None: the matching MCP server, if any.
        """
        return await self.get(name=name, preload=preload)

    async def get_list_crud(
        self,
        search: dict | None = None,
        order_by: list[dict[str, str]] | None = None,
        preload: list[str | Any] | None = None,
    ) -> Sequence[McpModel]:
        """
        List MCP servers.

        Parameters:
        - search (dict | None): filter dict; a falsy value means "no filter" ({}).
        - order_by (list[dict[str, str]] | None): sort spec; a falsy value falls back
          to ascending ``id``.
        - preload (list[str | Any] | None): relationships to preload; the model's
          defaults are used when not provided.

        Returns:
        - Sequence[McpModel]: the MCP servers that match.
        """
        criteria = search if search else {}
        ordering = order_by if order_by else [{'id': 'asc'}]
        return await self.list(search=criteria, order_by=ordering, preload=preload)

    async def create_crud(self, data: McpCreateSchema) -> McpModel | None:
        """
        Create an MCP server.

        Parameters:
        - data (McpCreateSchema): creation payload.

        Returns:
        - McpModel | None: the created MCP server, if creation succeeded.
        """
        return await self.create(data=data)

    async def update_crud(self, id: int, data: McpUpdateSchema) -> McpModel | None:
        """
        Update an MCP server.

        Parameters:
        - id (int): MCP server ID.
        - data (McpUpdateSchema): update payload.

        Returns:
        - McpModel | None: the updated MCP server, if the update succeeded.
        """
        return await self.update(id=id, data=data)

    async def delete_crud(self, ids: list[int]) -> None:
        """
        Delete several MCP servers at once.

        Parameters:
        - ids (list[int]): IDs of the MCP servers to delete.

        Returns:
        - None
        """
        return await self.delete(ids=ids)


class AIProviderCRUD(CRUDBase[AIProviderModel, AIProviderCreateSchema, AIProviderUpdateSchema]):
    """Data-access layer for AI provider records."""

    def __init__(self, auth: AuthSchema) -> None:
        """Keep the auth context and initialize the base CRUD for ``AIProviderModel``."""
        self.auth = auth
        super().__init__(model=AIProviderModel, auth=auth)

    async def get_by_id_crud(
        self,
        id: int,
        preload: list[str | Any] | None = None,
    ) -> AIProviderModel | None:
        """Look up one AI provider by ID (``None`` when absent)."""
        return await self.get(id=id, preload=preload)

    async def get_by_name_crud(
        self,
        name: str,
        preload: list[str | Any] | None = None,
    ) -> AIProviderModel | None:
        """Look up one AI provider by name (``None`` when absent)."""
        return await self.get(name=name, preload=preload)

    async def get_list_crud(
        self,
        search: dict | None = None,
        order_by: list[dict[str, str]] | None = None,
        preload: list[str | Any] | None = None,
    ) -> Sequence[AIProviderModel]:
        """List AI providers; falsy ``search``/``order_by`` fall back to {} and ascending ``id``."""
        criteria = search if search else {}
        ordering = order_by if order_by else [{'id': 'asc'}]
        return await self.list(search=criteria, order_by=ordering, preload=preload)

    async def create_crud(self, data: AIProviderCreateSchema) -> AIProviderModel | None:
        """Create an AI provider from ``data``."""
        return await self.create(data=data)

    async def update_crud(self, id: int, data: AIProviderUpdateSchema) -> AIProviderModel | None:
        """Update the AI provider with the given ``id`` using ``data``."""
        return await self.update(id=id, data=data)

    async def delete_crud(self, ids: list[int]) -> None:
        """Delete several AI providers by ID."""
        return await self.delete(ids=ids)

    async def clear_default_crud(self) -> None:
        """
        Clear the default flag on every provider currently marked ``is_default == 1``.

        Each matching provider is updated one by one. The update carries its current
        name, provider_type, base_url, api_key and description, with is_default set to 0.
        """
        # NOTE(review): only the fields listed below go into the update schema; whether
        # other provider columns are preserved depends on how CRUDBase.update applies it.
        for provider in await self.list(search={'is_default': 1}):
            reset_payload = AIProviderUpdateSchema(
                name=provider.name,
                provider_type=provider.provider_type,
                base_url=provider.base_url,
                api_key=provider.api_key,
                is_default=0,
                description=provider.description,
            )
            await self.update(id=provider.id, data=reset_payload)


class EmbeddingConfigCRUD(CRUDBase[EmbeddingConfigModel, EmbeddingConfigCreateSchema, EmbeddingConfigUpdateSchema]):
    """Data-access layer for embedding (vectorization) configuration records."""

    def __init__(self, auth: AuthSchema) -> None:
        """Keep the auth context and initialize the base CRUD for ``EmbeddingConfigModel``."""
        self.auth = auth
        super().__init__(model=EmbeddingConfigModel, auth=auth)

    async def get_by_id_crud(
        self,
        id: int,
        preload: list[str | Any] | None = None,
    ) -> EmbeddingConfigModel | None:
        """Look up one embedding config by ID (``None`` when absent)."""
        return await self.get(id=id, preload=preload)

    async def get_by_name_crud(
        self,
        name: str,
        preload: list[str | Any] | None = None,
    ) -> EmbeddingConfigModel | None:
        """Look up one embedding config by name (``None`` when absent)."""
        return await self.get(name=name, preload=preload)

    async def get_list_crud(
        self,
        search: dict | None = None,
        order_by: list[dict[str, str]] | None = None,
        preload: list[str | Any] | None = None,
    ) -> Sequence[EmbeddingConfigModel]:
        """List embedding configs; falsy ``search``/``order_by`` fall back to {} and ascending ``id``."""
        criteria = search if search else {}
        ordering = order_by if order_by else [{'id': 'asc'}]
        return await self.list(search=criteria, order_by=ordering, preload=preload)

    async def create_crud(self, data: EmbeddingConfigCreateSchema) -> EmbeddingConfigModel | None:
        """Create an embedding config from ``data``."""
        return await self.create(data=data)

    async def update_crud(self, id: int, data: EmbeddingConfigUpdateSchema) -> EmbeddingConfigModel | None:
        """Update the embedding config with the given ``id`` using ``data``."""
        return await self.update(id=id, data=data)

    async def delete_crud(self, ids: list[int]) -> None:
        """Delete several embedding configs by ID."""
        return await self.delete(ids=ids)

    async def clear_default_crud(self) -> None:
        """
        Clear the default flag on every embedding config currently marked ``is_default == 1``.

        Each matching config is updated one by one. The update carries its current
        name, embedding_type, model_name, base_url, api_key and description, with
        is_default set to 0.
        """
        # NOTE(review): only the fields listed below go into the update schema; whether
        # other config columns are preserved depends on how CRUDBase.update applies it.
        for config in await self.list(search={'is_default': 1}):
            reset_payload = EmbeddingConfigUpdateSchema(
                name=config.name,
                embedding_type=config.embedding_type,
                model_name=config.model_name,
                base_url=config.base_url,
                api_key=config.api_key,
                is_default=0,
                description=config.description,
            )
            await self.update(id=config.id, data=reset_payload)


class KnowledgeBaseCRUD(CRUDBase[KnowledgeBaseModel, KnowledgeBaseCreateSchema, KnowledgeBaseUpdateSchema]):
    """Data-access layer for knowledge base records."""

    def __init__(self, auth: AuthSchema) -> None:
        """Keep the auth context and initialize the base CRUD for ``KnowledgeBaseModel``."""
        self.auth = auth
        super().__init__(model=KnowledgeBaseModel, auth=auth)

    async def get_by_id_crud(
        self,
        id: int,
        preload: list[str | Any] | None = None,
    ) -> KnowledgeBaseModel | None:
        """Look up one knowledge base by ID (``None`` when absent)."""
        return await self.get(id=id, preload=preload)

    async def get_by_name_crud(
        self,
        name: str,
        preload: list[str | Any] | None = None,
    ) -> KnowledgeBaseModel | None:
        """Look up one knowledge base by name (``None`` when absent)."""
        return await self.get(name=name, preload=preload)

    async def get_by_collection_crud(
        self,
        collection_name: str,
        preload: list[str | Any] | None = None,
    ) -> KnowledgeBaseModel | None:
        """Look up one knowledge base by its ``collection_name`` (``None`` when absent)."""
        return await self.get(collection_name=collection_name, preload=preload)

    async def get_list_crud(
        self,
        search: dict | None = None,
        order_by: list[dict[str, str]] | None = None,
        preload: list[str | Any] | None = None,
    ) -> Sequence[KnowledgeBaseModel]:
        """List knowledge bases; falsy ``search``/``order_by`` fall back to {} and DESCENDING ``id``."""
        criteria = search if search else {}
        # Unlike the other CRUDs in this module, knowledge bases default to newest-id first.
        ordering = order_by if order_by else [{'id': 'desc'}]
        return await self.list(search=criteria, order_by=ordering, preload=preload)

    async def create_crud(self, data: dict) -> KnowledgeBaseModel | None:
        """Create a knowledge base; takes a plain dict because extra fields beyond the create schema are needed."""
        return await self.create(data=data)

    async def update_crud(self, id: int, data: KnowledgeBaseUpdateSchema | dict) -> KnowledgeBaseModel | None:
        """Update the knowledge base with the given ``id`` using a schema or a plain dict."""
        return await self.update(id=id, data=data)

    async def delete_crud(self, ids: list[int]) -> None:
        """Delete several knowledge bases by ID."""
        return await self.delete(ids=ids)


class AIModelConfigCRUD(CRUDBase[AIModelConfigModel, AIModelConfigCreateSchema, AIModelConfigUpdateSchema]):
    """Data-access layer for AI model configuration records."""

    def __init__(self, auth: AuthSchema) -> None:
        """Keep the auth context and initialize the base CRUD for ``AIModelConfigModel``."""
        self.auth = auth
        super().__init__(model=AIModelConfigModel, auth=auth)

    async def get_by_id_crud(
        self,
        id: int,
        preload: list[str | Any] | None = None,
    ) -> AIModelConfigModel | None:
        """Look up one AI model config by ID (``None`` when absent)."""
        return await self.get(id=id, preload=preload)

    async def get_by_model_type_crud(
        self,
        model_type: str,
        preload: list[str | Any] | None = None,
    ) -> AIModelConfigModel | None:
        """Look up one AI model config by ``model_type`` (``None`` when absent)."""
        return await self.get(model_type=model_type, preload=preload)

    async def get_list_crud(
        self,
        search: dict | None = None,
        order_by: list[dict[str, str]] | None = None,
        preload: list[str | Any] | None = None,
    ) -> Sequence[AIModelConfigModel]:
        """List AI model configs; falsy ``search``/``order_by`` fall back to {} and ascending ``id``."""
        criteria = search if search else {}
        ordering = order_by if order_by else [{'id': 'asc'}]
        return await self.list(search=criteria, order_by=ordering, preload=preload)

    async def create_crud(self, data: AIModelConfigCreateSchema | dict) -> AIModelConfigModel | None:
        """Create an AI model config from a schema or a plain dict."""
        return await self.create(data=data)

    async def update_crud(self, id: int, data: AIModelConfigUpdateSchema | dict) -> AIModelConfigModel | None:
        """Update the AI model config with the given ``id`` using a schema or a plain dict."""
        return await self.update(id=id, data=data)

    async def delete_crud(self, ids: list[int]) -> None:
        """Delete several AI model configs by ID."""
        return await self.delete(ids=ids)


class AIModelTrainingMessageCRUD(
    CRUDBase[AIModelTrainingMessageModel, AIModelTrainingMessageCreateSchema, AIModelTrainingMessageCreateSchema]
):
    """Data-access layer for AI model training conversation messages."""

    def __init__(self, auth: AuthSchema) -> None:
        """Keep the auth context and initialize the base CRUD for ``AIModelTrainingMessageModel``."""
        self.auth = auth
        super().__init__(model=AIModelTrainingMessageModel, auth=auth)

    async def get_by_id_crud(
        self,
        id: int,
        preload: list[str | Any] | None = None,
    ) -> AIModelTrainingMessageModel | None:
        """Look up one training message by ID (``None`` when absent)."""
        return await self.get(id=id, preload=preload)

    async def get_list_by_config_crud(
        self,
        model_config_id: int,
        preload: list[str | Any] | None = None,
    ) -> Sequence[AIModelTrainingMessageModel]:
        """Return all training messages of one model config, ordered by ascending ``id``."""
        return await self.list(
            search={'model_config_id': model_config_id},
            order_by=[{'id': 'asc'}],
            preload=preload,
        )

    async def create_crud(
        self,
        data: AIModelTrainingMessageCreateSchema | dict,
    ) -> AIModelTrainingMessageModel | None:
        """Create a training message from a schema or a plain dict."""
        return await self.create(data=data)

    async def delete_crud(self, ids: list[int]) -> None:
        """Delete several training messages by ID."""
        return await self.delete(ids=ids)

    async def delete_by_config_crud(self, model_config_id: int) -> None:
        """
        Delete every training message that belongs to one model config.

        The matching messages are listed first. ``delete`` is only called when at
        least one message exists.
        """
        messages = await self.list(search={'model_config_id': model_config_id})
        if not messages:
            return
        await self.delete(ids=[message.id for message in messages])