Files

270 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
from typing import Sequence, Any
from app.core.base_crud import CRUDBase
from app.api.v1.module_system.auth.schema import AuthSchema
from .model import McpModel, AIProviderModel, EmbeddingConfigModel, KnowledgeBaseModel, AIModelConfigModel, AIModelTrainingMessageModel
from .schema import (
McpCreateSchema, McpUpdateSchema,
AIProviderCreateSchema, AIProviderUpdateSchema,
EmbeddingConfigCreateSchema, EmbeddingConfigUpdateSchema,
KnowledgeBaseCreateSchema, KnowledgeBaseUpdateSchema,
AIModelConfigCreateSchema, AIModelConfigUpdateSchema,
AIModelTrainingMessageCreateSchema
)
class McpCRUD(CRUDBase[McpModel, McpCreateSchema, McpUpdateSchema]):
    """
    Data-access layer for MCP servers.

    Every ``*_crud`` method is a thin wrapper that forwards to the matching
    ``CRUDBase`` operation, with ``McpModel`` as the bound model.
    """

    def __init__(self, auth: AuthSchema) -> None:
        """
        Build the CRUD helper for MCP servers.

        Args:
            - auth (AuthSchema): authentication context; kept on ``self.auth``
              and passed through to ``CRUDBase``.
        """
        self.auth = auth
        super().__init__(model=McpModel, auth=auth)

    async def get_by_id_crud(self, id: int, preload: list[str | Any] | None = None) -> McpModel | None:
        """
        Look up one MCP server by its ID.

        Args:
            - id (int): MCP server ID.
            - preload (list[str | Any] | None): relationships to eager-load;
              ``None`` is forwarded unchanged (the model's defaults then apply).

        Returns:
            - McpModel | None: the matching MCP server, or ``None`` if absent.
        """
        server = await self.get(id=id, preload=preload)
        return server

    async def get_by_name_crud(self, name: str, preload: list[str | Any] | None = None) -> McpModel | None:
        """
        Look up one MCP server by its name.

        Args:
            - name (str): MCP server name.
            - preload (list[str | Any] | None): relationships to eager-load;
              ``None`` is forwarded unchanged (the model's defaults then apply).

        Returns:
            - McpModel | None: the matching MCP server, or ``None`` if absent.
        """
        server = await self.get(name=name, preload=preload)
        return server

    async def get_list_crud(self, search: dict | None = None, order_by: list[dict[str, str]] | None = None, preload: list[str | Any] | None = None) -> Sequence[McpModel]:
        """
        List MCP servers.

        Args:
            - search (dict | None): filter dict; an empty/missing value means
              no filtering (``{}`` is passed down).
            - order_by (list[dict[str, str]] | None): sort spec; defaults to
              ascending ID when empty/missing.
            - preload (list[str | Any] | None): relationships to eager-load;
              ``None`` is forwarded unchanged (the model's defaults then apply).

        Returns:
            - Sequence[McpModel]: the matching MCP servers.
        """
        filters = search if search else {}
        ordering = order_by if order_by else [{'id': 'asc'}]
        return await self.list(search=filters, order_by=ordering, preload=preload)

    async def create_crud(self, data: McpCreateSchema) -> McpModel | None:
        """
        Create an MCP server.

        Args:
            - data (McpCreateSchema): payload for the new MCP server.

        Returns:
            - McpModel | None: whatever ``CRUDBase.create`` returns for the new row.
        """
        created = await self.create(data=data)
        return created

    async def update_crud(self, id: int, data: McpUpdateSchema) -> McpModel | None:
        """
        Update an MCP server.

        Args:
            - id (int): ID of the MCP server to update.
            - data (McpUpdateSchema): update payload.

        Returns:
            - McpModel | None: whatever ``CRUDBase.update`` returns for the row.
        """
        updated = await self.update(id=id, data=data)
        return updated

    async def delete_crud(self, ids: list[int]) -> None:
        """
        Delete several MCP servers at once.

        Args:
            - ids (list[int]): IDs of the MCP servers to delete.

        Returns:
            - None: the result of ``CRUDBase.delete`` is passed straight back.
        """
        outcome = await self.delete(ids=ids)
        return outcome
class AIProviderCRUD(CRUDBase[AIProviderModel, AIProviderCreateSchema, AIProviderUpdateSchema]):
    """
    Data-access layer for AI providers.

    Thin wrappers around ``CRUDBase`` bound to ``AIProviderModel``, plus a
    helper to clear the "default provider" flag.
    """

    def __init__(self, auth: AuthSchema) -> None:
        """
        Build the CRUD helper for AI providers.

        Args:
            - auth (AuthSchema): authentication context; kept on ``self.auth``
              and passed through to ``CRUDBase``.
        """
        self.auth = auth
        super().__init__(model=AIProviderModel, auth=auth)

    async def get_by_id_crud(self, id: int, preload: list[str | Any] | None = None) -> AIProviderModel | None:
        """
        Look up one AI provider by its ID.

        Args:
            - id (int): provider ID.
            - preload (list[str | Any] | None): relationships to eager-load.

        Returns:
            - AIProviderModel | None: the matching provider, or ``None``.
        """
        return await self.get(id=id, preload=preload)

    async def get_by_name_crud(self, name: str, preload: list[str | Any] | None = None) -> AIProviderModel | None:
        """
        Look up one AI provider by its name.

        Args:
            - name (str): provider name.
            - preload (list[str | Any] | None): relationships to eager-load.

        Returns:
            - AIProviderModel | None: the matching provider, or ``None``.
        """
        return await self.get(name=name, preload=preload)

    async def get_list_crud(self, search: dict | None = None, order_by: list[dict[str, str]] | None = None, preload: list[str | Any] | None = None) -> Sequence[AIProviderModel]:
        """
        List AI providers.

        Args:
            - search (dict | None): filter dict; empty/missing means ``{}``.
            - order_by (list[dict[str, str]] | None): sort spec; defaults to
              ascending ID when empty/missing.
            - preload (list[str | Any] | None): relationships to eager-load.

        Returns:
            - Sequence[AIProviderModel]: the matching providers.
        """
        return await self.list(search=search or {}, order_by=order_by or [{'id': 'asc'}], preload=preload)

    async def create_crud(self, data: AIProviderCreateSchema) -> AIProviderModel | None:
        """
        Create an AI provider.

        Args:
            - data (AIProviderCreateSchema): payload for the new provider.

        Returns:
            - AIProviderModel | None: whatever ``CRUDBase.create`` returns.
        """
        return await self.create(data=data)

    async def update_crud(self, id: int, data: AIProviderUpdateSchema) -> AIProviderModel | None:
        """
        Update an AI provider.

        Args:
            - id (int): ID of the provider to update.
            - data (AIProviderUpdateSchema): update payload.

        Returns:
            - AIProviderModel | None: whatever ``CRUDBase.update`` returns.
        """
        return await self.update(id=id, data=data)

    async def delete_crud(self, ids: list[int]) -> None:
        """
        Delete several AI providers at once.

        Args:
            - ids (list[int]): IDs of the providers to delete.
        """
        return await self.delete(ids=ids)

    async def clear_default_crud(self) -> None:
        """
        Reset ``is_default`` to 0 on every provider currently marked default.

        Only the flag is written; all other columns of each row are left as
        they are.
        """
        default_providers = await self.list(search={'is_default': 1})
        for provider in default_providers:
            # Partial dict update: ``update`` accepts plain dicts (other CRUDs in
            # this module pass dicts to it). Rebuilding a full update schema
            # from the row would re-validate and rewrite unrelated columns such
            # as ``api_key``, and would break whenever the schema gains fields.
            await self.update(id=provider.id, data={'is_default': 0})
class EmbeddingConfigCRUD(CRUDBase[EmbeddingConfigModel, EmbeddingConfigCreateSchema, EmbeddingConfigUpdateSchema]):
    """
    Data-access layer for embedding (vectorization) configurations.

    Thin wrappers around ``CRUDBase`` bound to ``EmbeddingConfigModel``, plus a
    helper to clear the "default configuration" flag.
    """

    def __init__(self, auth: AuthSchema) -> None:
        """
        Build the CRUD helper for embedding configurations.

        Args:
            - auth (AuthSchema): authentication context; kept on ``self.auth``
              and passed through to ``CRUDBase``.
        """
        self.auth = auth
        super().__init__(model=EmbeddingConfigModel, auth=auth)

    async def get_by_id_crud(self, id: int, preload: list[str | Any] | None = None) -> EmbeddingConfigModel | None:
        """
        Look up one embedding configuration by its ID.

        Args:
            - id (int): configuration ID.
            - preload (list[str | Any] | None): relationships to eager-load.

        Returns:
            - EmbeddingConfigModel | None: the matching configuration, or ``None``.
        """
        return await self.get(id=id, preload=preload)

    async def get_by_name_crud(self, name: str, preload: list[str | Any] | None = None) -> EmbeddingConfigModel | None:
        """
        Look up one embedding configuration by its name.

        Args:
            - name (str): configuration name.
            - preload (list[str | Any] | None): relationships to eager-load.

        Returns:
            - EmbeddingConfigModel | None: the matching configuration, or ``None``.
        """
        return await self.get(name=name, preload=preload)

    async def get_list_crud(self, search: dict | None = None, order_by: list[dict[str, str]] | None = None, preload: list[str | Any] | None = None) -> Sequence[EmbeddingConfigModel]:
        """
        List embedding configurations.

        Args:
            - search (dict | None): filter dict; empty/missing means ``{}``.
            - order_by (list[dict[str, str]] | None): sort spec; defaults to
              ascending ID when empty/missing.
            - preload (list[str | Any] | None): relationships to eager-load.

        Returns:
            - Sequence[EmbeddingConfigModel]: the matching configurations.
        """
        return await self.list(search=search or {}, order_by=order_by or [{'id': 'asc'}], preload=preload)

    async def create_crud(self, data: EmbeddingConfigCreateSchema) -> EmbeddingConfigModel | None:
        """
        Create an embedding configuration.

        Args:
            - data (EmbeddingConfigCreateSchema): payload for the new configuration.

        Returns:
            - EmbeddingConfigModel | None: whatever ``CRUDBase.create`` returns.
        """
        return await self.create(data=data)

    async def update_crud(self, id: int, data: EmbeddingConfigUpdateSchema) -> EmbeddingConfigModel | None:
        """
        Update an embedding configuration.

        Args:
            - id (int): ID of the configuration to update.
            - data (EmbeddingConfigUpdateSchema): update payload.

        Returns:
            - EmbeddingConfigModel | None: whatever ``CRUDBase.update`` returns.
        """
        return await self.update(id=id, data=data)

    async def delete_crud(self, ids: list[int]) -> None:
        """
        Delete several embedding configurations at once.

        Args:
            - ids (list[int]): IDs of the configurations to delete.
        """
        return await self.delete(ids=ids)

    async def clear_default_crud(self) -> None:
        """
        Reset ``is_default`` to 0 on every configuration currently marked default.

        Only the flag is written; all other columns of each row are left as
        they are.
        """
        default_configs = await self.list(search={'is_default': 1})
        for config in default_configs:
            # Partial dict update: ``update`` accepts plain dicts (other CRUDs in
            # this module pass dicts to it). Rebuilding a full update schema
            # from the row would re-validate and rewrite unrelated columns such
            # as ``api_key``/``model_name``, and would break whenever the schema
            # gains fields.
            await self.update(id=config.id, data={'is_default': 0})
class KnowledgeBaseCRUD(CRUDBase[KnowledgeBaseModel, KnowledgeBaseCreateSchema, KnowledgeBaseUpdateSchema]):
    """
    Data-access layer for knowledge bases.

    Thin wrappers around ``CRUDBase`` bound to ``KnowledgeBaseModel``; lists
    default to newest-first (descending ID).
    """

    def __init__(self, auth: AuthSchema) -> None:
        """
        Build the CRUD helper for knowledge bases.

        Args:
            - auth (AuthSchema): authentication context; kept on ``self.auth``
              and passed through to ``CRUDBase``.
        """
        self.auth = auth
        super().__init__(model=KnowledgeBaseModel, auth=auth)

    async def get_by_id_crud(self, id: int, preload: list[str | Any] | None = None) -> KnowledgeBaseModel | None:
        """Return the knowledge base with the given ID, or ``None``."""
        kb = await self.get(id=id, preload=preload)
        return kb

    async def get_by_name_crud(self, name: str, preload: list[str | Any] | None = None) -> KnowledgeBaseModel | None:
        """Return the knowledge base with the given name, or ``None``."""
        kb = await self.get(name=name, preload=preload)
        return kb

    async def get_by_collection_crud(self, collection_name: str, preload: list[str | Any] | None = None) -> KnowledgeBaseModel | None:
        """Return the knowledge base bound to ``collection_name``, or ``None``."""
        kb = await self.get(collection_name=collection_name, preload=preload)
        return kb

    async def get_list_crud(self, search: dict | None = None, order_by: list[dict[str, str]] | None = None, preload: list[str | Any] | None = None) -> Sequence[KnowledgeBaseModel]:
        """
        List knowledge bases.

        Args:
            - search (dict | None): filter dict; empty/missing means ``{}``.
            - order_by (list[dict[str, str]] | None): sort spec; defaults to
              descending ID when empty/missing.
            - preload (list[str | Any] | None): relationships to eager-load.
        """
        filters = search if search else {}
        ordering = order_by if order_by else [{'id': 'desc'}]
        return await self.list(search=filters, order_by=ordering, preload=preload)

    async def create_crud(self, data: dict) -> KnowledgeBaseModel | None:
        """
        Create a knowledge base.

        Takes a plain dict rather than the create schema because the caller
        needs to supply extra fields that the schema does not carry.
        """
        created = await self.create(data=data)
        return created

    async def update_crud(self, id: int, data: KnowledgeBaseUpdateSchema | dict) -> KnowledgeBaseModel | None:
        """Update the knowledge base ``id`` with a schema or a plain dict."""
        updated = await self.update(id=id, data=data)
        return updated

    async def delete_crud(self, ids: list[int]) -> None:
        """Delete every knowledge base whose ID is in ``ids``."""
        outcome = await self.delete(ids=ids)
        return outcome
class AIModelConfigCRUD(CRUDBase[AIModelConfigModel, AIModelConfigCreateSchema, AIModelConfigUpdateSchema]):
    """
    Data-access layer for AI model configurations.

    Thin wrappers around ``CRUDBase`` bound to ``AIModelConfigModel``.
    """

    def __init__(self, auth: AuthSchema) -> None:
        """
        Build the CRUD helper for AI model configurations.

        Args:
            - auth (AuthSchema): authentication context; kept on ``self.auth``
              and passed through to ``CRUDBase``.
        """
        self.auth = auth
        super().__init__(model=AIModelConfigModel, auth=auth)

    async def get_by_id_crud(self, id: int, preload: list[str | Any] | None = None) -> AIModelConfigModel | None:
        """Return the model configuration with the given ID, or ``None``."""
        config = await self.get(id=id, preload=preload)
        return config

    async def get_by_model_type_crud(self, model_type: str, preload: list[str | Any] | None = None) -> AIModelConfigModel | None:
        """Return the model configuration whose ``model_type`` matches, or ``None``."""
        config = await self.get(model_type=model_type, preload=preload)
        return config

    async def get_list_crud(self, search: dict | None = None, order_by: list[dict[str, str]] | None = None, preload: list[str | Any] | None = None) -> Sequence[AIModelConfigModel]:
        """
        List model configurations.

        ``search`` falls back to ``{}`` and ``order_by`` to ascending ID when
        either is empty or missing.
        """
        filters = search if search else {}
        ordering = order_by if order_by else [{'id': 'asc'}]
        return await self.list(search=filters, order_by=ordering, preload=preload)

    async def create_crud(self, data: AIModelConfigCreateSchema | dict) -> AIModelConfigModel | None:
        """Create a model configuration from a schema or a plain dict."""
        created = await self.create(data=data)
        return created

    async def update_crud(self, id: int, data: AIModelConfigUpdateSchema | dict) -> AIModelConfigModel | None:
        """Update the model configuration ``id`` with a schema or a plain dict."""
        updated = await self.update(id=id, data=data)
        return updated

    async def delete_crud(self, ids: list[int]) -> None:
        """Delete every model configuration whose ID is in ``ids``."""
        outcome = await self.delete(ids=ids)
        return outcome
class AIModelTrainingMessageCRUD(CRUDBase[AIModelTrainingMessageModel, AIModelTrainingMessageCreateSchema, AIModelTrainingMessageCreateSchema]):
    """
    Data-access layer for AI model training conversation messages.

    The create schema is also used as the update-schema type parameter; this
    class exposes no update wrapper.
    """

    def __init__(self, auth: AuthSchema) -> None:
        """
        Build the CRUD helper for training messages.

        Args:
            - auth (AuthSchema): authentication context; kept on ``self.auth``
              and passed through to ``CRUDBase``.
        """
        self.auth = auth
        super().__init__(model=AIModelTrainingMessageModel, auth=auth)

    async def get_by_id_crud(self, id: int, preload: list[str | Any] | None = None) -> AIModelTrainingMessageModel | None:
        """Return the training message with the given ID, or ``None``."""
        message = await self.get(id=id, preload=preload)
        return message

    async def get_list_by_config_crud(self, model_config_id: int, preload: list[str | Any] | None = None) -> Sequence[AIModelTrainingMessageModel]:
        """Return all training messages of one model configuration, oldest ID first."""
        criteria = {'model_config_id': model_config_id}
        return await self.list(search=criteria, order_by=[{'id': 'asc'}], preload=preload)

    async def create_crud(self, data: AIModelTrainingMessageCreateSchema | dict) -> AIModelTrainingMessageModel | None:
        """Create a training message from a schema or a plain dict."""
        created = await self.create(data=data)
        return created

    async def delete_crud(self, ids: list[int]) -> None:
        """Delete every training message whose ID is in ``ids``."""
        outcome = await self.delete(ids=ids)
        return outcome

    async def delete_by_config_crud(self, model_config_id: int) -> None:
        """
        Delete all training messages that belong to one model configuration.

        Does nothing (no delete call) when the configuration has no messages.
        """
        existing = await self.list(search={'model_config_id': model_config_id})
        if not existing:
            return
        await self.delete(ids=[msg.id for msg in existing])