307 lines
11 KiB
Python
307 lines
11 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
from pydantic import ConfigDict, Field, HttpUrl, BaseModel
|
|
from fastapi import Query
|
|
|
|
from app.core.base_schema import BaseSchema
|
|
from app.common.enums import McpLLMProvider, EmbeddingType, KnowledgeBaseStatus
|
|
from app.core.base_schema import BaseSchema, UserBySchema
|
|
from app.common.enums import McpType
|
|
from app.core.validator import DateTimeStr
|
|
|
|
|
|
class ChatQuerySchema(BaseModel):
    """Request body for a single chat query.

    Carries one user message whose length must be between 1 and 4000
    characters.
    """

    message: str = Field(
        default=...,
        description="聊天消息",
        min_length=1,
        max_length=4000,
    )
|
|
|
|
|
|
class McpCreateSchema(BaseModel):
    """Payload for creating an MCP server entry.

    ``type`` defaults to ``McpType.stdio``. ``url`` is validated as an HTTP
    URL (described as the remote SSE address). ``command`` and ``args``
    describe a command invocation, where ``args`` is a single string holding
    comma-separated arguments. ``env`` is an optional string-to-string map
    of environment variables.
    """

    name: str = Field(default=..., description='MCP 名称', max_length=64)
    type: McpType = Field(default=McpType.stdio, description='MCP 类型')
    description: str | None = Field(default=None, description='MCP 描述', max_length=255)
    url: HttpUrl | None = Field(default=None, description='远程 SSE 地址')
    command: str | None = Field(default=None, description='MCP 命令', max_length=255)
    args: str | None = Field(default=None, description='MCP 命令参数,多个参数用英文逗号隔开', max_length=255)
    env: dict[str, str] | None = Field(default=None, description='MCP 环境变量')
|
|
|
|
|
|
class McpUpdateSchema(McpCreateSchema):
    """Payload for updating an MCP server entry.

    Reuses every field and validation rule of ``McpCreateSchema`` unchanged.
    """

    pass
|
|
|
|
|
|
class McpOutSchema(McpCreateSchema, BaseSchema, UserBySchema):
    """MCP server detail response.

    Combines the MCP creation fields with whatever ``BaseSchema`` and
    ``UserBySchema`` contribute. It can be validated directly from objects
    exposing matching attributes (e.g. ORM rows).
    """

    # from_attributes=True enables validation from attribute-bearing objects.
    model_config = ConfigDict(from_attributes=True)
|
|
|
|
|
|
class McpQueryParam:
    """Query-string filters for listing MCP servers.

    Each attribute holds one of:

    * ``None`` — no filter on that column,
    * a plain value — exact match,
    * an ``(operator, value)`` tuple, e.g. ``("like", name)`` or
      ``("between", (start, end))``.

    Every attribute is always assigned, so consumers can rely on all of
    them existing.
    """

    def __init__(
        self,
        name: str | None = Query(None, description="MCP 名称"),
        type: McpType | None = Query(None, description="MCP 类型"),
        created_time: list[DateTimeStr] | None = Query(None, description="创建时间范围", examples=["2025-01-01 00:00:00", "2025-12-31 23:59:59"]),
        updated_time: list[DateTimeStr] | None = Query(None, description="更新时间范围", examples=["2025-01-01 00:00:00", "2025-12-31 23:59:59"]),
        created_id: int | None = Query(None, description="创建人"),
        updated_id: int | None = Query(None, description="更新人"),
    ) -> None:
        # Fuzzy-match field: an empty/missing name means "no filter".
        self.name = ("like", name) if name else None

        # Exact-match fields, stored as given (None when omitted).
        self.type = type
        self.created_id = created_id
        self.updated_id = updated_id

        # Time ranges: only a [start, end] pair becomes a "between" filter.
        # Previously these attributes were left undefined otherwise, which
        # made direct access raise AttributeError; None matches how every
        # other filter above signals "not filtered".
        self.created_time = (
            ("between", (created_time[0], created_time[1]))
            if created_time and len(created_time) == 2
            else None
        )
        self.updated_time = (
            ("between", (updated_time[0], updated_time[1]))
            if updated_time and len(updated_time) == 2
            else None
        )
|
|
|
|
|
|
class McpChatParam(BaseSchema):
    """Request parameters for chatting with an LLM using selected MCP servers.

    ``pk`` lists the MCP server IDs to use. ``provider`` defaults to
    ``McpLLMProvider.openai``. ``base_url`` optionally overrides the LLM API
    endpoint and is described as having to be OpenAI-compatible.
    """

    pk: list[int] = Field(default=..., description='MCP ID 列表')
    provider: McpLLMProvider = Field(default=McpLLMProvider.openai, description='LLM 供应商')
    model: str = Field(default=..., description='LLM 名称')
    key: str = Field(default=..., description='LLM API Key')
    base_url: str | None = Field(default=None, description='自定义 LLM API 地址,必须兼容 openai 供应商')
    prompt: str = Field(default=..., description='用户提示词')
|
|
|
|
|
|
# ============== AI供应商配置 ==============
|
|
|
|
class AIProviderCreateSchema(BaseModel):
    """Payload for registering an AI (LLM) provider.

    ``base_url`` is a plain string with only a length limit (no URL
    validation). ``is_default`` is described as a 0/1 flag, but this schema
    does not range-check it.
    """

    name: str = Field(default=..., description='供应商名称', max_length=50)
    provider_type: str = Field(default=..., description='供应商类型(openai/deepseek/anthropic/gemini/qwen等)', max_length=50)
    base_url: str = Field(default=..., description='接口地址BaseURL', max_length=255)
    api_key: str = Field(default=..., description='API Key', max_length=255)
    is_default: int = Field(default=0, description='是否默认供应商(0:否 1:是)')
    description: str | None = Field(default=None, description='备注', max_length=255)
|
|
|
|
|
|
class AIProviderUpdateSchema(AIProviderCreateSchema):
    """Payload for updating an AI provider.

    Identical to ``AIProviderCreateSchema``: every field is resubmitted with
    the same validation rules.
    """

    pass
|
|
|
|
|
|
class AIProviderOutSchema(AIProviderCreateSchema, BaseSchema, UserBySchema):
    """AI provider detail response, validated from attribute-bearing objects.

    NOTE(review): this inherits ``api_key`` from ``AIProviderCreateSchema``,
    so the plaintext key is part of this response model — confirm that
    returning it to clients is intended.
    """

    model_config = ConfigDict(from_attributes=True)
|
|
|
|
|
|
class AIProviderQueryParam:
    """Query-string filters for listing AI providers.

    A non-empty ``name`` is wrapped as a ``("like", name)`` fuzzy filter and
    is ``None`` otherwise. ``provider_type`` and ``is_default`` are stored
    exactly as received.
    """

    def __init__(
        self,
        name: str | None = Query(default=None, description="供应商名称"),
        provider_type: str | None = Query(default=None, description="供应商类型"),
        is_default: int | None = Query(default=None, description="是否默认"),
    ) -> None:
        if name:
            self.name = ("like", name)
        else:
            self.name = None
        self.provider_type = provider_type
        self.is_default = is_default
|
|
|
|
|
|
# ============== 向量化配置 ==============
|
|
|
|
class EmbeddingConfigCreateSchema(BaseModel):
    """Payload for creating an embedding (vectorization) configuration.

    ``embedding_type`` is described as 0 = local, 1 = remote. ``base_url``
    and ``api_key`` are described as required in remote mode.

    NOTE(review): this schema does not enforce that requirement; any check
    must happen elsewhere.
    """

    # ``model_name`` begins with pydantic's protected ``model_`` prefix.
    # Clearing protected_namespaces stops pydantic 2.0-2.9 from emitting a
    # field-name conflict warning at class creation. Subclasses inherit
    # this setting.
    model_config = ConfigDict(protected_namespaces=())

    name: str = Field(..., max_length=50, description='配置名称')
    embedding_type: int = Field(0, description='向量化类型(0:本地 1:远程)')
    model_name: str = Field(..., max_length=100, description='Embedding模型名称')
    base_url: str | None = Field(None, max_length=255, description='远程接口地址(远程模式必填)')
    api_key: str | None = Field(None, max_length=255, description='远程API Key(远程模式必填)')
    is_default: int = Field(0, description='是否默认配置(0:否 1:是)')
    description: str | None = Field(None, max_length=255, description='备注')
|
|
|
|
|
|
class EmbeddingConfigUpdateSchema(EmbeddingConfigCreateSchema):
    """Payload for updating an embedding configuration.

    Shares every field and rule with ``EmbeddingConfigCreateSchema``.
    """

    pass
|
|
|
|
|
|
class EmbeddingConfigOutSchema(EmbeddingConfigCreateSchema, BaseSchema, UserBySchema):
    """Embedding configuration detail response, validated from attribute-bearing objects.

    NOTE(review): ``api_key`` is inherited from the creation schema, so the
    remote API key is part of this response model — confirm that is intended.
    """

    model_config = ConfigDict(from_attributes=True)
|
|
|
|
|
|
class EmbeddingConfigQueryParam:
    """Query-string filters for listing embedding configurations.

    ``name`` becomes a ``("like", name)`` fuzzy filter when non-empty and is
    ``None`` otherwise. ``embedding_type`` and ``is_default`` are kept as
    plain values.
    """

    def __init__(
        self,
        name: str | None = Query(default=None, description="配置名称"),
        embedding_type: int | None = Query(default=None, description="向量化类型"),
        is_default: int | None = Query(default=None, description="是否默认"),
    ) -> None:
        fuzzy_name = None if not name else ("like", name)
        self.name = fuzzy_name
        self.embedding_type = embedding_type
        self.is_default = is_default
|
|
|
|
|
|
# ============== 知识库 ==============
|
|
|
|
class KnowledgeBaseCreateSchema(BaseModel):
    """Payload for creating a knowledge base.

    Only ``name`` is required. The embedding configuration link and the
    remark are optional.
    """

    name: str = Field(default=..., description='知识库名称', max_length=100)
    embedding_config_id: int | None = Field(default=None, description='向量化配置ID')
    description: str | None = Field(default=None, description='备注', max_length=255)
|
|
|
|
|
|
class KnowledgeBaseUpdateSchema(BaseModel):
    """Payload for updating a knowledge base.

    Every field is optional and defaults to ``None``, so a request may carry
    only the fields being changed.
    """

    name: str | None = Field(default=None, description='知识库名称', max_length=100)
    embedding_config_id: int | None = Field(default=None, description='向量化配置ID')
    description: str | None = Field(default=None, description='备注', max_length=255)
|
|
|
|
|
|
class EmbeddingConfigRefSchema(BaseModel):
    """Compact reference to an embedding configuration.

    Used as a nested summary inside knowledge-base responses and validated
    from attribute-bearing objects (e.g. ORM rows).
    """

    # from_attributes=True: allow validation from ORM-style objects.
    # protected_namespaces=(): ``model_name`` begins with pydantic's
    # protected ``model_`` prefix, which makes pydantic 2.0-2.9 warn at class
    # creation; opting out silences that warning.
    model_config = ConfigDict(from_attributes=True, protected_namespaces=())

    id: int
    name: str
    embedding_type: int
    model_name: str
|
|
|
|
|
|
class KnowledgeBaseOutSchema(BaseSchema, UserBySchema):
    """Knowledge-base detail response, validated from attribute-bearing objects.

    ``name`` and ``collection_name`` are required. Counters and ``kb_status``
    default to 0 and the remaining fields default to ``None``.
    ``embedding_config`` carries the nested summary of the linked embedding
    configuration.
    """

    model_config = ConfigDict(from_attributes=True)

    name: str
    embedding_config_id: int | None = Field(default=None)
    collection_name: str
    document_count: int = Field(default=0)
    vector_count: int = Field(default=0)
    # NOTE(review): presumably holds KnowledgeBaseStatus values (that enum is
    # imported by this module but not referenced here) — confirm.
    kb_status: int = Field(default=0)
    error_message: str | None = Field(default=None)
    description: str | None = Field(default=None)
    file_paths: list[str] | None = Field(default=None)
    embedding_config: EmbeddingConfigRefSchema | None = Field(default=None)
|
|
|
|
|
|
class KnowledgeBaseQueryParam:
    """Query-string filters for listing knowledge bases.

    A non-empty ``name`` turns into a ``("like", name)`` fuzzy filter and is
    ``None`` otherwise. ``embedding_config_id`` and ``kb_status`` are stored
    unchanged.
    """

    def __init__(
        self,
        name: str | None = Query(default=None, description="知识库名称"),
        embedding_config_id: int | None = Query(default=None, description="向量化配置ID"),
        kb_status: int | None = Query(default=None, description="知识库状态"),
    ) -> None:
        self.name = None
        if name:
            self.name = ("like", name)
        self.embedding_config_id = embedding_config_id
        self.kb_status = kb_status
|
|
|
|
|
|
# ============== AI模型配置 ==============
|
|
|
|
# AI model type constants: identifier key -> human-readable (Chinese) label.
# Insertion order is preserved exactly as listed below.
AI_MODEL_TYPES = dict(
    enterprise_naming='企业起名',
    enterprise_renaming='企业改名',
    enterprise_scoring='企业测名',
    enterprise_scoring_trial='企业测名试用',
    personal_naming='个人起名',
    personal_renaming='个人改名',
    personal_scoring='个人测名',
    personal_scoring_trial='个人测名试用',
    yuanfen_hepan='缘分合盘',
    bazi_zeji='八字择吉',
    caiyun_jiexi='财运解析',
    caiyun_jiexi_qiye='企业财运解析',
)
|
|
|
|
|
|
class AIProviderRefSchema(BaseModel):
    """Compact reference to an AI provider: id, name, type and base URL.

    Validated from attribute-bearing objects. It declares no API key field,
    so the key is not part of this nested summary.
    """

    model_config = ConfigDict(from_attributes=True)

    id: int
    name: str
    provider_type: str
    base_url: str
|
|
|
|
|
|
class AIModelConfigCreateSchema(BaseModel):
    """Payload for creating an AI model configuration.

    ``model_type`` identifies the feature being configured. The module's
    AI_MODEL_TYPES keys (``enterprise_naming``, ``personal_naming``, ...)
    look like the intended values; that is not validated here. The
    description previously advertised ``naming/renaming/scoring/report``,
    which matches none of those keys.

    ``temperature`` must lie within [0, 2] and defaults to 1.0.
    """

    # ``model_type``/``model_name`` begin with pydantic's protected
    # ``model_`` prefix. Clearing protected_namespaces stops pydantic 2.0-2.9
    # from emitting a field-name conflict warning at class creation.
    model_config = ConfigDict(protected_namespaces=())

    model_type: str = Field(..., max_length=50, description='模型类型(enterprise_naming/personal_naming等)')
    model_name: str | None = Field(None, max_length=100, description='使用的模型名称')
    provider_id: int | None = Field(None, description='AI供应商ID')
    system_prompt: str | None = Field(None, description='系统提示词')
    temperature: float = Field(1.0, ge=0, le=2, description='模型温度(0-2)')
    knowledge_base_ids: list[int] | None = Field(None, description='关联的知识库ID列表')
|
|
|
|
|
|
class AIModelConfigUpdateSchema(BaseModel):
    """Payload for updating an AI model configuration.

    Every field is optional (default ``None``). ``model_type`` is not
    updatable through this schema. When ``temperature`` is given, it must lie
    within [0, 2].
    """

    # ``model_name`` begins with pydantic's protected ``model_`` prefix.
    # Clearing protected_namespaces stops pydantic 2.0-2.9 from warning
    # about it at class creation.
    model_config = ConfigDict(protected_namespaces=())

    model_name: str | None = Field(None, max_length=100, description='使用的模型名称')
    provider_id: int | None = Field(None, description='AI供应商ID')
    system_prompt: str | None = Field(None, description='系统提示词')
    temperature: float | None = Field(None, ge=0, le=2, description='模型温度(0-2)')
    knowledge_base_ids: list[int] | None = Field(None, description='关联的知识库ID列表')
|
|
|
|
|
|
class AIModelConfigOutSchema(BaseSchema, UserBySchema):
    """AI model configuration detail response.

    Includes the nested ``provider`` summary. ``temperature`` defaults to
    1.0 and the other optional fields default to ``None``.
    """

    # from_attributes=True: allow validation from ORM-style objects.
    # protected_namespaces=(): ``model_type``/``model_name`` begin with
    # pydantic's protected ``model_`` prefix, which makes pydantic 2.0-2.9
    # warn at class creation; opting out silences that warning.
    model_config = ConfigDict(from_attributes=True, protected_namespaces=())

    model_type: str
    model_name: str | None = None
    provider_id: int | None = None
    system_prompt: str | None = None
    temperature: float = 1.0
    knowledge_base_ids: list[int] | None = None
    provider: AIProviderRefSchema | None = None
|
|
|
|
|
|
class AIModelConfigQueryParam:
    """Query-string filter for listing AI model configurations by model type."""

    def __init__(
        self,
        model_type: str | None = Query(default=None, description="模型类型"),
    ) -> None:
        # Stored exactly as received; None when the query parameter is omitted.
        self.model_type = model_type
|
|
|
|
|
|
# ============== AI模型训练对话 ==============
|
|
|
|
class AIModelTrainingMessageCreateSchema(BaseModel):
    """Payload for storing one training-conversation message.

    ``role`` is described as ``user`` or ``assistant``. Only its length
    (at most 20 characters) is enforced here.
    """

    # ``model_config_id`` begins with pydantic's protected ``model_`` prefix.
    # Clearing protected_namespaces stops pydantic 2.0-2.9 from warning
    # about it at class creation.
    model_config = ConfigDict(protected_namespaces=())

    model_config_id: int = Field(..., description='模型配置ID')
    role: str = Field(..., max_length=20, description='角色(user/assistant)')
    content: str = Field(..., description='消息内容')
|
|
|
|
|
|
class AIModelTrainingMessageOutSchema(BaseSchema, UserBySchema):
    """Training-conversation message detail response."""

    # from_attributes=True: allow validation from ORM-style objects.
    # protected_namespaces=(): ``model_config_id`` begins with pydantic's
    # protected ``model_`` prefix, which makes pydantic 2.0-2.9 warn at class
    # creation; opting out silences that warning.
    model_config = ConfigDict(from_attributes=True, protected_namespaces=())

    model_config_id: int
    role: str
    content: str
|
|
|
|
|
|
class AIModelTrainingChatSchema(BaseModel):
    """Request body for one turn of a training conversation.

    ``message`` must be non-empty. ``config_changed`` and ``config_data``
    let the client submit pending configuration edits along with the
    message.
    """

    # ``model_type`` begins with pydantic's protected ``model_`` prefix.
    # Clearing protected_namespaces stops pydantic 2.0-2.9 from warning
    # about it at class creation.
    model_config = ConfigDict(protected_namespaces=())

    model_type: str = Field(..., description='模型类型')
    message: str = Field(..., min_length=1, description='用户消息')
    # Per the original intent: if the configuration changed, it is saved
    # first. That step happens outside this schema.
    config_changed: bool = Field(False, description='配置是否有变动')
    config_data: AIModelConfigUpdateSchema | None = Field(None, description='变动的配置数据')
|
|
|
|
|
|
class AIModelTestSchema(BaseModel):
    """Naming-test request parameters (for the mini-program / external callers).

    ``text`` must be between 1 and 4000 characters.
    """

    # ``model_type`` begins with pydantic's protected ``model_`` prefix.
    # Clearing protected_namespaces stops pydantic 2.0-2.9 from warning
    # about it at class creation.
    model_config = ConfigDict(protected_namespaces=())

    model_type: str = Field(..., description='模型类型(enterprise_naming/personal_naming等)')
    text: str = Field(..., min_length=1, max_length=4000, description='用户输入的文本')