270 lines
12 KiB
Python
270 lines
12 KiB
Python
# -*- coding: utf-8 -*-
|
||
|
||
from typing import Sequence, Any
|
||
|
||
from app.core.base_crud import CRUDBase
|
||
|
||
from app.api.v1.module_system.auth.schema import AuthSchema
|
||
from .model import McpModel, AIProviderModel, EmbeddingConfigModel, KnowledgeBaseModel, AIModelConfigModel, AIModelTrainingMessageModel
|
||
from .schema import (
|
||
McpCreateSchema, McpUpdateSchema,
|
||
AIProviderCreateSchema, AIProviderUpdateSchema,
|
||
EmbeddingConfigCreateSchema, EmbeddingConfigUpdateSchema,
|
||
KnowledgeBaseCreateSchema, KnowledgeBaseUpdateSchema,
|
||
AIModelConfigCreateSchema, AIModelConfigUpdateSchema,
|
||
AIModelTrainingMessageCreateSchema
|
||
)
|
||
|
||
|
||
class McpCRUD(CRUDBase[McpModel, McpCreateSchema, McpUpdateSchema]):
    """Data-access layer for MCP server records (McpModel)."""

    def __init__(self, auth: AuthSchema) -> None:
        """
        Bind this CRUD helper to McpModel and the caller's auth context.

        Parameters:
        - auth (AuthSchema): authentication context of the current caller
        """
        self.auth = auth
        super().__init__(model=McpModel, auth=auth)

    async def get_by_id_crud(
        self,
        id: int,
        preload: list[str | Any] | None = None,
    ) -> McpModel | None:
        """
        Fetch one MCP server by its primary key.

        Parameters:
        - id (int): MCP server ID
        - preload (list[str | Any] | None): relationships to preload; when
          omitted the base class falls back to the model's defaults

        Returns:
        - McpModel | None: the matching MCP server, or None if absent
        """
        return await self.get(preload=preload, id=id)

    async def get_by_name_crud(
        self,
        name: str,
        preload: list[str | Any] | None = None,
    ) -> McpModel | None:
        """
        Fetch one MCP server by its name.

        Parameters:
        - name (str): MCP server name
        - preload (list[str | Any] | None): relationships to preload; when
          omitted the base class falls back to the model's defaults

        Returns:
        - McpModel | None: the matching MCP server, or None if absent
        """
        return await self.get(preload=preload, name=name)

    async def get_list_crud(
        self,
        search: dict | None = None,
        order_by: list[dict[str, str]] | None = None,
        preload: list[str | Any] | None = None,
    ) -> Sequence[McpModel]:
        """
        List MCP servers matching the given filters.

        Parameters:
        - search (dict | None): filter criteria; a falsy value means no filter
        - order_by (list[dict[str, str]] | None): sort specification; a falsy
          value falls back to ascending id
        - preload (list[str | Any] | None): relationships to preload; when
          omitted the base class falls back to the model's defaults

        Returns:
        - Sequence[McpModel]: the matching MCP servers
        """
        filters = search if search else {}
        ordering = order_by if order_by else [{'id': 'asc'}]
        return await self.list(search=filters, order_by=ordering, preload=preload)

    async def create_crud(self, data: McpCreateSchema) -> McpModel | None:
        """
        Create a new MCP server.

        Parameters:
        - data (McpCreateSchema): payload describing the server to create

        Returns:
        - McpModel | None: the created MCP server, as returned by the base class
        """
        created = await self.create(data=data)
        return created

    async def update_crud(self, id: int, data: McpUpdateSchema) -> McpModel | None:
        """
        Update an existing MCP server.

        Parameters:
        - id (int): MCP server ID
        - data (McpUpdateSchema): payload with the new values

        Returns:
        - McpModel | None: the updated MCP server, as returned by the base class
        """
        updated = await self.update(id=id, data=data)
        return updated

    async def delete_crud(self, ids: list[int]) -> None:
        """
        Delete several MCP servers at once.

        Parameters:
        - ids (list[int]): IDs of the MCP servers to delete

        Returns:
        - None (whatever the base-class delete returns is passed through)
        """
        outcome = await self.delete(ids=ids)
        return outcome
|
||
|
||
|
||
class AIProviderCRUD(CRUDBase[AIProviderModel, AIProviderCreateSchema, AIProviderUpdateSchema]):
    """Data-access layer for AI provider records (AIProviderModel)."""

    def __init__(self, auth: AuthSchema) -> None:
        """
        Bind this CRUD helper to AIProviderModel and the caller's auth context.

        Parameters:
        - auth (AuthSchema): authentication context of the current caller
        """
        self.auth = auth
        super().__init__(model=AIProviderModel, auth=auth)

    async def get_by_id_crud(
        self,
        id: int,
        preload: list[str | Any] | None = None,
    ) -> AIProviderModel | None:
        """Fetch one AI provider by primary key (None if absent)."""
        return await self.get(preload=preload, id=id)

    async def get_by_name_crud(
        self,
        name: str,
        preload: list[str | Any] | None = None,
    ) -> AIProviderModel | None:
        """Fetch one AI provider by name (None if absent)."""
        return await self.get(preload=preload, name=name)

    async def get_list_crud(
        self,
        search: dict | None = None,
        order_by: list[dict[str, str]] | None = None,
        preload: list[str | Any] | None = None,
    ) -> Sequence[AIProviderModel]:
        """
        List AI providers matching the given filters.

        A falsy ``search`` means no filter; a falsy ``order_by`` falls back to
        ascending id.
        """
        filters = search if search else {}
        ordering = order_by if order_by else [{'id': 'asc'}]
        return await self.list(search=filters, order_by=ordering, preload=preload)

    async def create_crud(self, data: AIProviderCreateSchema) -> AIProviderModel | None:
        """Create a new AI provider from the given schema."""
        created = await self.create(data=data)
        return created

    async def update_crud(self, id: int, data: AIProviderUpdateSchema) -> AIProviderModel | None:
        """Update the AI provider identified by ``id`` with ``data``."""
        updated = await self.update(id=id, data=data)
        return updated

    async def delete_crud(self, ids: list[int]) -> None:
        """Delete the AI providers whose IDs are listed in ``ids``."""
        outcome = await self.delete(ids=ids)
        return outcome

    async def clear_default_crud(self) -> None:
        """
        Clear the default flag on every provider currently marked as default.

        Each matching record is updated individually. A full update schema is
        rebuilt from the record's current values with ``is_default`` set to 0,
        so the other copied fields keep their existing values.
        """
        defaults = await self.list(search={'is_default': 1})
        for provider in defaults:
            reset = AIProviderUpdateSchema(
                name=provider.name,
                provider_type=provider.provider_type,
                base_url=provider.base_url,
                api_key=provider.api_key,
                is_default=0,
                description=provider.description,
            )
            await self.update(id=provider.id, data=reset)
|
||
|
||
|
||
class EmbeddingConfigCRUD(CRUDBase[EmbeddingConfigModel, EmbeddingConfigCreateSchema, EmbeddingConfigUpdateSchema]):
    """Data-access layer for embedding (vectorization) configuration records."""

    def __init__(self, auth: AuthSchema) -> None:
        """
        Bind this CRUD helper to EmbeddingConfigModel and the caller's auth context.

        Parameters:
        - auth (AuthSchema): authentication context of the current caller
        """
        self.auth = auth
        super().__init__(model=EmbeddingConfigModel, auth=auth)

    async def get_by_id_crud(
        self,
        id: int,
        preload: list[str | Any] | None = None,
    ) -> EmbeddingConfigModel | None:
        """Fetch one embedding configuration by primary key (None if absent)."""
        return await self.get(preload=preload, id=id)

    async def get_by_name_crud(
        self,
        name: str,
        preload: list[str | Any] | None = None,
    ) -> EmbeddingConfigModel | None:
        """Fetch one embedding configuration by name (None if absent)."""
        return await self.get(preload=preload, name=name)

    async def get_list_crud(
        self,
        search: dict | None = None,
        order_by: list[dict[str, str]] | None = None,
        preload: list[str | Any] | None = None,
    ) -> Sequence[EmbeddingConfigModel]:
        """
        List embedding configurations matching the given filters.

        A falsy ``search`` means no filter; a falsy ``order_by`` falls back to
        ascending id.
        """
        filters = search if search else {}
        ordering = order_by if order_by else [{'id': 'asc'}]
        return await self.list(search=filters, order_by=ordering, preload=preload)

    async def create_crud(self, data: EmbeddingConfigCreateSchema) -> EmbeddingConfigModel | None:
        """Create a new embedding configuration from the given schema."""
        created = await self.create(data=data)
        return created

    async def update_crud(self, id: int, data: EmbeddingConfigUpdateSchema) -> EmbeddingConfigModel | None:
        """Update the embedding configuration identified by ``id`` with ``data``."""
        updated = await self.update(id=id, data=data)
        return updated

    async def delete_crud(self, ids: list[int]) -> None:
        """Delete the embedding configurations whose IDs are listed in ``ids``."""
        outcome = await self.delete(ids=ids)
        return outcome

    async def clear_default_crud(self) -> None:
        """
        Clear the default flag on every embedding configuration marked as default.

        Each matching record is updated individually. A full update schema is
        rebuilt from the record's current values with ``is_default`` set to 0.
        """
        defaults = await self.list(search={'is_default': 1})
        for config in defaults:
            reset = EmbeddingConfigUpdateSchema(
                name=config.name,
                embedding_type=config.embedding_type,
                model_name=config.model_name,
                base_url=config.base_url,
                api_key=config.api_key,
                is_default=0,
                description=config.description,
            )
            await self.update(id=config.id, data=reset)
|
||
|
||
|
||
class KnowledgeBaseCRUD(CRUDBase[KnowledgeBaseModel, KnowledgeBaseCreateSchema, KnowledgeBaseUpdateSchema]):
    """Data-access layer for knowledge base records (KnowledgeBaseModel)."""

    def __init__(self, auth: AuthSchema) -> None:
        """
        Bind this CRUD helper to KnowledgeBaseModel and the caller's auth context.

        Parameters:
        - auth (AuthSchema): authentication context of the current caller
        """
        self.auth = auth
        super().__init__(model=KnowledgeBaseModel, auth=auth)

    async def get_by_id_crud(
        self,
        id: int,
        preload: list[str | Any] | None = None,
    ) -> KnowledgeBaseModel | None:
        """Fetch one knowledge base by primary key (None if absent)."""
        return await self.get(preload=preload, id=id)

    async def get_by_name_crud(
        self,
        name: str,
        preload: list[str | Any] | None = None,
    ) -> KnowledgeBaseModel | None:
        """Fetch one knowledge base by name (None if absent)."""
        return await self.get(preload=preload, name=name)

    async def get_by_collection_crud(
        self,
        collection_name: str,
        preload: list[str | Any] | None = None,
    ) -> KnowledgeBaseModel | None:
        """Fetch one knowledge base by its ``collection_name`` column (None if absent)."""
        return await self.get(preload=preload, collection_name=collection_name)

    async def get_list_crud(
        self,
        search: dict | None = None,
        order_by: list[dict[str, str]] | None = None,
        preload: list[str | Any] | None = None,
    ) -> Sequence[KnowledgeBaseModel]:
        """
        List knowledge bases matching the given filters.

        A falsy ``search`` means no filter. A falsy ``order_by`` falls back to
        DESCENDING id, unlike the other CRUD classes in this module.
        """
        filters = search if search else {}
        ordering = order_by if order_by else [{'id': 'desc'}]
        return await self.list(search=filters, order_by=ordering, preload=preload)

    async def create_crud(self, data: dict) -> KnowledgeBaseModel | None:
        """Create a knowledge base from a plain dict, which lets the caller include fields beyond the create schema."""
        created = await self.create(data=data)
        return created

    async def update_crud(self, id: int, data: KnowledgeBaseUpdateSchema | dict) -> KnowledgeBaseModel | None:
        """Update the knowledge base identified by ``id`` from a schema or a plain dict."""
        updated = await self.update(id=id, data=data)
        return updated

    async def delete_crud(self, ids: list[int]) -> None:
        """Delete the knowledge bases whose IDs are listed in ``ids``."""
        outcome = await self.delete(ids=ids)
        return outcome
|
||
|
||
|
||
class AIModelConfigCRUD(CRUDBase[AIModelConfigModel, AIModelConfigCreateSchema, AIModelConfigUpdateSchema]):
    """Data-access layer for AI model configuration records (AIModelConfigModel)."""

    def __init__(self, auth: AuthSchema) -> None:
        """
        Bind this CRUD helper to AIModelConfigModel and the caller's auth context.

        Parameters:
        - auth (AuthSchema): authentication context of the current caller
        """
        self.auth = auth
        super().__init__(model=AIModelConfigModel, auth=auth)

    async def get_by_id_crud(
        self,
        id: int,
        preload: list[str | Any] | None = None,
    ) -> AIModelConfigModel | None:
        """Fetch one model configuration by primary key (None if absent)."""
        return await self.get(preload=preload, id=id)

    async def get_by_model_type_crud(
        self,
        model_type: str,
        preload: list[str | Any] | None = None,
    ) -> AIModelConfigModel | None:
        """Fetch one model configuration by its ``model_type`` column (None if absent)."""
        return await self.get(preload=preload, model_type=model_type)

    async def get_list_crud(
        self,
        search: dict | None = None,
        order_by: list[dict[str, str]] | None = None,
        preload: list[str | Any] | None = None,
    ) -> Sequence[AIModelConfigModel]:
        """
        List model configurations matching the given filters.

        A falsy ``search`` means no filter; a falsy ``order_by`` falls back to
        ascending id.
        """
        filters = search if search else {}
        ordering = order_by if order_by else [{'id': 'asc'}]
        return await self.list(search=filters, order_by=ordering, preload=preload)

    async def create_crud(self, data: AIModelConfigCreateSchema | dict) -> AIModelConfigModel | None:
        """Create a model configuration from a schema or a plain dict."""
        created = await self.create(data=data)
        return created

    async def update_crud(self, id: int, data: AIModelConfigUpdateSchema | dict) -> AIModelConfigModel | None:
        """Update the model configuration identified by ``id`` from a schema or a plain dict."""
        updated = await self.update(id=id, data=data)
        return updated

    async def delete_crud(self, ids: list[int]) -> None:
        """Delete the model configurations whose IDs are listed in ``ids``."""
        outcome = await self.delete(ids=ids)
        return outcome
|
||
|
||
|
||
class AIModelTrainingMessageCRUD(CRUDBase[AIModelTrainingMessageModel, AIModelTrainingMessageCreateSchema, AIModelTrainingMessageCreateSchema]):
    """
    Data-access layer for AI model training conversation messages.

    The create schema is also used as the update-schema type parameter;
    this class exposes no update wrapper of its own.
    """

    def __init__(self, auth: AuthSchema) -> None:
        """
        Bind this CRUD helper to AIModelTrainingMessageModel and the caller's auth context.

        Parameters:
        - auth (AuthSchema): authentication context of the current caller
        """
        self.auth = auth
        super().__init__(model=AIModelTrainingMessageModel, auth=auth)

    async def get_by_id_crud(
        self,
        id: int,
        preload: list[str | Any] | None = None,
    ) -> AIModelTrainingMessageModel | None:
        """Fetch one training message by primary key (None if absent)."""
        return await self.get(preload=preload, id=id)

    async def get_list_by_config_crud(
        self,
        model_config_id: int,
        preload: list[str | Any] | None = None,
    ) -> Sequence[AIModelTrainingMessageModel]:
        """Return every training message belonging to one model configuration, ordered by ascending id."""
        filters = {'model_config_id': model_config_id}
        ordering = [{'id': 'asc'}]
        return await self.list(search=filters, order_by=ordering, preload=preload)

    async def create_crud(self, data: AIModelTrainingMessageCreateSchema | dict) -> AIModelTrainingMessageModel | None:
        """Create a training message from a schema or a plain dict."""
        created = await self.create(data=data)
        return created

    async def delete_crud(self, ids: list[int]) -> None:
        """Delete the training messages whose IDs are listed in ``ids``."""
        outcome = await self.delete(ids=ids)
        return outcome

    async def delete_by_config_crud(self, model_config_id: int) -> None:
        """
        Delete every training message belonging to one model configuration.

        Messages are listed first and then deleted by ID. If the listing is
        empty, no delete call is made.
        """
        messages = await self.list(search={'model_config_id': model_config_id})
        if not messages:
            return
        await self.delete(ids=[message.id for message in messages])