upload project source code
This commit is contained in:
2
后端源码/yifan.action-ai.cn/api-bak/app/core/__init__.py
Normal file
2
后端源码/yifan.action-ai.cn/api-bak/app/core/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
473
后端源码/yifan.action-ai.cn/api-bak/app/core/base_crud.py
Normal file
473
后端源码/yifan.action-ai.cn/api-bak/app/core/base_crud.py
Normal file
@@ -0,0 +1,473 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import TypeVar, Sequence, Generic, Dict, Any, List, Optional, Type, Union
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.engine import Result
|
||||
from sqlalchemy import asc, func, select, delete, Select, desc, update
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
|
||||
from app.core.base_model import MappedBase
|
||||
from app.core.exceptions import CustomException
|
||||
from app.core.permission import Permission
|
||||
from app.api.v1.module_system.auth.schema import AuthSchema
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=MappedBase)
|
||||
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
|
||||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
|
||||
OutSchemaType = TypeVar("OutSchemaType", bound=BaseModel)
|
||||
|
||||
|
||||
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
"""基础数据层"""
|
||||
|
||||
def __init__(self, model: Type[ModelType], auth: AuthSchema) -> None:
|
||||
"""
|
||||
初始化CRUDBase类
|
||||
|
||||
参数:
|
||||
- model (Type[ModelType]): 数据模型类。
|
||||
- auth (AuthSchema): 认证信息。
|
||||
|
||||
返回:
|
||||
- None
|
||||
"""
|
||||
self.model = model
|
||||
self.auth = auth
|
||||
|
||||
async def get(self, preload: Optional[List[Union[str, Any]]] = None, **kwargs) -> Optional[ModelType]:
|
||||
"""
|
||||
根据条件获取单个对象
|
||||
|
||||
参数:
|
||||
- preload (Optional[List[Union[str, Any]]]): 预加载关系,支持关系名字符串或SQLAlchemy loader option
|
||||
- **kwargs: 查询条件
|
||||
|
||||
返回:
|
||||
- Optional[ModelType]: 对象实例
|
||||
|
||||
异常:
|
||||
- CustomException: 查询失败时抛出异常
|
||||
"""
|
||||
try:
|
||||
conditions = await self.__build_conditions(**kwargs)
|
||||
sql = select(self.model).where(*conditions)
|
||||
# 应用可配置的预加载选项
|
||||
for opt in self.__loader_options(preload):
|
||||
sql = sql.options(opt)
|
||||
|
||||
sql = await self.__filter_permissions(sql)
|
||||
|
||||
result: Result = await self.auth.db.execute(sql)
|
||||
obj = result.scalars().first()
|
||||
return obj
|
||||
except Exception as e:
|
||||
raise CustomException(msg=f"获取查询失败: {str(e)}")
|
||||
|
||||
async def list(self, search: Optional[Dict] = None, order_by: Optional[List[Dict[str, str]]] = None, preload: Optional[List[Union[str, Any]]] = None) -> Sequence[ModelType]:
|
||||
"""
|
||||
根据条件获取对象列表
|
||||
|
||||
参数:
|
||||
- search (Optional[Dict]): 查询条件,格式为 {'id': value, 'name': value}
|
||||
- order_by (Optional[List[Dict[str, str]]]): 排序字段,格式为 [{'id': 'asc'}, {'name': 'desc'}]
|
||||
- preload (Optional[List[Union[str, Any]]]): 预加载关系,支持关系名字符串或SQLAlchemy loader option
|
||||
|
||||
返回:
|
||||
- Sequence[ModelType]: 对象列表
|
||||
|
||||
异常:
|
||||
- CustomException: 查询失败时抛出异常
|
||||
"""
|
||||
try:
|
||||
conditions = await self.__build_conditions(**search) if search else []
|
||||
order = order_by or [{'id': 'asc'}]
|
||||
sql = select(self.model).where(*conditions).order_by(*self.__order_by(order))
|
||||
# 应用可配置的预加载选项
|
||||
for opt in self.__loader_options(preload):
|
||||
sql = sql.options(opt)
|
||||
sql = await self.__filter_permissions(sql)
|
||||
result: Result = await self.auth.db.execute(sql)
|
||||
return result.scalars().all()
|
||||
except Exception as e:
|
||||
raise CustomException(msg=f"列表查询失败: {str(e)}")
|
||||
|
||||
async def tree_list(self, search: Optional[Dict] = None, order_by: Optional[List[Dict[str, str]]] = None, children_attr: str = 'children', preload: Optional[List[Union[str, Any]]] = None) -> Sequence[ModelType]:
|
||||
"""
|
||||
获取树形结构数据列表
|
||||
|
||||
参数:
|
||||
- search (Optional[Dict]): 查询条件
|
||||
- order_by (Optional[List[Dict[str, str]]]): 排序字段
|
||||
- children_attr (str): 子节点属性名
|
||||
- preload (Optional[List[Union[str, Any]]]): 额外预加载关系,若为None则默认包含children_attr
|
||||
|
||||
返回:
|
||||
- Sequence[ModelType]: 树形结构数据列表
|
||||
|
||||
异常:
|
||||
- CustomException: 查询失败时抛出异常
|
||||
"""
|
||||
try:
|
||||
conditions = await self.__build_conditions(**search) if search else []
|
||||
order = order_by or [{'id': 'asc'}]
|
||||
sql = select(self.model).where(*conditions).order_by(*self.__order_by(order))
|
||||
|
||||
# 处理预加载选项
|
||||
final_preload = preload
|
||||
# 如果没有提供preload且children_attr存在,则添加到预加载选项中
|
||||
if preload is None and children_attr and hasattr(self.model, children_attr):
|
||||
# 获取模型默认预加载选项
|
||||
model_defaults = getattr(self.model, "__loader_options__", [])
|
||||
# 将children_attr添加到默认预加载选项中
|
||||
final_preload = list(model_defaults) + [children_attr]
|
||||
|
||||
# 应用预加载选项
|
||||
for opt in self.__loader_options(final_preload):
|
||||
sql = sql.options(opt)
|
||||
|
||||
sql = await self.__filter_permissions(sql)
|
||||
result: Result = await self.auth.db.execute(sql)
|
||||
return result.scalars().all()
|
||||
except Exception as e:
|
||||
raise CustomException(msg=f"树形列表查询失败: {str(e)}")
|
||||
|
||||
async def page(self, offset: int, limit: int, order_by: List[Dict[str, str]], search: Dict, out_schema: Type[OutSchemaType], preload: Optional[List[Union[str, Any]]] = None) -> Dict:
|
||||
"""
|
||||
获取分页数据
|
||||
|
||||
参数:
|
||||
- offset (int): 偏移量
|
||||
- limit (int): 每页数量
|
||||
- order_by (List[Dict[str, str]]): 排序字段
|
||||
- search (Dict): 查询条件
|
||||
- out_schema (Type[OutSchemaType]): 输出数据模型
|
||||
- preload (Optional[List[Union[str, Any]]]): 预加载关系
|
||||
|
||||
返回:
|
||||
- Dict: 分页数据
|
||||
|
||||
异常:
|
||||
- CustomException: 查询失败时抛出异常
|
||||
"""
|
||||
try:
|
||||
conditions = await self.__build_conditions(**search) if search else []
|
||||
order = order_by or [{'id': 'asc'}]
|
||||
sql = select(self.model).where(*conditions).order_by(*self.__order_by(order))
|
||||
# 应用预加载选项
|
||||
for opt in self.__loader_options(preload):
|
||||
sql = sql.options(opt)
|
||||
sql = await self.__filter_permissions(sql)
|
||||
|
||||
# 优化count查询:使用主键计数而非全表扫描
|
||||
mapper = sa_inspect(self.model)
|
||||
pk_cols = list(getattr(mapper, "primary_key", []))
|
||||
if pk_cols:
|
||||
# 使用主键的第一列进行计数(主键必定非NULL,性能更好)
|
||||
count_sql = select(func.count(pk_cols[0])).select_from(self.model)
|
||||
else:
|
||||
# 降级方案:使用count(*)
|
||||
count_sql = select(func.count()).select_from(self.model)
|
||||
|
||||
if conditions:
|
||||
count_sql = count_sql.where(*conditions)
|
||||
count_sql = await self.__filter_permissions(count_sql)
|
||||
|
||||
total_result = await self.auth.db.execute(count_sql)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
result: Result = await self.auth.db.execute(sql.offset(offset).limit(limit))
|
||||
objs = result.scalars().all()
|
||||
|
||||
return {
|
||||
"page_no": offset // limit + 1 if limit else 1,
|
||||
"page_size": limit if limit else 10,
|
||||
"total": total,
|
||||
"has_next": offset + limit < total,
|
||||
"items": [out_schema.model_validate(obj).model_dump() for obj in objs]
|
||||
}
|
||||
except Exception as e:
|
||||
raise CustomException(msg=f"分页查询失败: {str(e)}")
|
||||
|
||||
async def create(self, data: Union[CreateSchemaType, Dict]) -> ModelType:
|
||||
"""
|
||||
创建新对象
|
||||
|
||||
参数:
|
||||
- data (Union[CreateSchemaType, Dict]): 对象属性
|
||||
|
||||
返回:
|
||||
- ModelType: 新创建的对象实例
|
||||
|
||||
异常:
|
||||
- CustomException: 创建失败时抛出异常
|
||||
"""
|
||||
try:
|
||||
obj_dict = data if isinstance(data, dict) else data.model_dump()
|
||||
obj = self.model(**obj_dict)
|
||||
|
||||
# 设置字段值(只检查一次current_user)
|
||||
if self.auth.user:
|
||||
if hasattr(obj, "created_id"):
|
||||
setattr(obj, "created_id", self.auth.user.id)
|
||||
if hasattr(obj, "updated_id"):
|
||||
setattr(obj, "updated_id", self.auth.user.id)
|
||||
|
||||
self.auth.db.add(obj)
|
||||
await self.auth.db.flush()
|
||||
await self.auth.db.refresh(obj)
|
||||
return obj
|
||||
except Exception as e:
|
||||
raise CustomException(msg=f"创建失败: {str(e)}")
|
||||
|
||||
async def update(self, id: int, data: Union[UpdateSchemaType, Dict]) -> ModelType:
|
||||
"""
|
||||
更新对象
|
||||
|
||||
参数:
|
||||
- id (int): 对象ID
|
||||
- data (Union[UpdateSchemaType, Dict]): 更新的属性及值
|
||||
|
||||
返回:
|
||||
- ModelType: 更新后的对象实例
|
||||
|
||||
异常:
|
||||
- CustomException: 更新失败时抛出异常
|
||||
"""
|
||||
try:
|
||||
obj_dict = data if isinstance(data, dict) else data.model_dump(exclude_unset=True, exclude={"id"})
|
||||
obj = await self.get(id=id)
|
||||
if not obj:
|
||||
raise CustomException(msg="更新对象不存在")
|
||||
|
||||
# 设置字段值(只检查一次current_user)
|
||||
if self.auth.user:
|
||||
if hasattr(obj, "updated_id"):
|
||||
setattr(obj, "updated_id", self.auth.user.id)
|
||||
|
||||
for key, value in obj_dict.items():
|
||||
if hasattr(obj, key):
|
||||
setattr(obj, key, value)
|
||||
|
||||
await self.auth.db.flush()
|
||||
await self.auth.db.refresh(obj)
|
||||
|
||||
# 权限二次确认:flush后再次验证对象仍在权限范围内
|
||||
# 防止并发修改导致的权限逃逸(如其他事务修改了created_id)
|
||||
verify_obj = await self.get(id=id)
|
||||
if not verify_obj:
|
||||
# 对象已被删除或权限已失效
|
||||
raise CustomException(msg="更新失败,对象不存在或无权限访问")
|
||||
|
||||
return obj
|
||||
except Exception as e:
|
||||
raise CustomException(msg=f"更新失败: {str(e)}")
|
||||
|
||||
async def delete(self, ids: List[int]) -> None:
|
||||
"""
|
||||
删除对象
|
||||
|
||||
参数:
|
||||
- ids (List[int]): 对象ID列表
|
||||
|
||||
异常:
|
||||
- CustomException: 删除失败时抛出异常
|
||||
"""
|
||||
try:
|
||||
# 先查询确认权限,避免删除无权限的数据
|
||||
objs = await self.list(search={"id": ("in", ids)})
|
||||
accessible_ids = [obj.id for obj in objs]
|
||||
|
||||
# 检查是否所有ID都有权限访问
|
||||
inaccessible_count = len(ids) - len(accessible_ids)
|
||||
if inaccessible_count > 0:
|
||||
raise CustomException(msg=f"无权限删除{inaccessible_count}条数据")
|
||||
|
||||
if not accessible_ids:
|
||||
return # 没有可删除的数据
|
||||
|
||||
mapper = sa_inspect(self.model)
|
||||
pk_cols = list(getattr(mapper, "primary_key", []))
|
||||
if not pk_cols:
|
||||
raise CustomException(msg="模型缺少主键,无法删除")
|
||||
if len(pk_cols) > 1:
|
||||
raise CustomException(msg="暂不支持复合主键的批量删除")
|
||||
|
||||
# 只删除有权限的数据
|
||||
sql = delete(self.model).where(pk_cols[0].in_(accessible_ids))
|
||||
await self.auth.db.execute(sql)
|
||||
await self.auth.db.flush()
|
||||
except Exception as e:
|
||||
raise CustomException(msg=f"删除失败: {str(e)}")
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""
|
||||
清空对象表
|
||||
|
||||
异常:
|
||||
- CustomException: 清空失败时抛出异常
|
||||
"""
|
||||
try:
|
||||
sql = delete(self.model)
|
||||
await self.auth.db.execute(sql)
|
||||
await self.auth.db.flush()
|
||||
except Exception as e:
|
||||
raise CustomException(msg=f"清空失败: {str(e)}")
|
||||
|
||||
async def set(self, ids: List[int], **kwargs) -> None:
|
||||
"""
|
||||
批量更新对象
|
||||
|
||||
参数:
|
||||
- ids (List[int]): 对象ID列表
|
||||
- **kwargs: 更新的属性及值
|
||||
|
||||
异常:
|
||||
- CustomException: 更新失败时抛出异常
|
||||
"""
|
||||
try:
|
||||
# 先查询确认权限,避免更新无权限的数据
|
||||
objs = await self.list(search={"id": ("in", ids)})
|
||||
accessible_ids = [obj.id for obj in objs]
|
||||
|
||||
# 检查是否所有ID都有权限访问
|
||||
inaccessible_count = len(ids) - len(accessible_ids)
|
||||
if inaccessible_count > 0:
|
||||
raise CustomException(msg=f"无权限更新{inaccessible_count}条数据")
|
||||
|
||||
if not accessible_ids:
|
||||
return # 没有可更新的数据
|
||||
|
||||
mapper = sa_inspect(self.model)
|
||||
pk_cols = list(getattr(mapper, "primary_key", []))
|
||||
if not pk_cols:
|
||||
raise CustomException(msg="模型缺少主键,无法更新")
|
||||
if len(pk_cols) > 1:
|
||||
raise CustomException(msg="暂不支持复合主键的批量更新")
|
||||
|
||||
# 只更新有权限的数据
|
||||
sql = update(self.model).where(pk_cols[0].in_(accessible_ids)).values(**kwargs)
|
||||
await self.auth.db.execute(sql)
|
||||
await self.auth.db.flush()
|
||||
except CustomException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise CustomException(msg=f"批量更新失败: {str(e)}")
|
||||
|
||||
async def __filter_permissions(self, sql: Select) -> Select:
|
||||
"""
|
||||
过滤数据权限(仅用于Select)。
|
||||
"""
|
||||
filter = Permission(
|
||||
model=self.model,
|
||||
auth=self.auth
|
||||
)
|
||||
return await filter.filter_query(sql)
|
||||
|
||||
async def __build_conditions(self, **kwargs) -> List[ColumnElement]:
|
||||
"""
|
||||
构建查询条件
|
||||
|
||||
参数:
|
||||
- **kwargs: 查询参数
|
||||
|
||||
返回:
|
||||
- List[ColumnElement]: SQL条件表达式列表
|
||||
|
||||
异常:
|
||||
- CustomException: 查询参数不存在时抛出异常
|
||||
"""
|
||||
conditions = []
|
||||
for key, value in kwargs.items():
|
||||
if value is None or value == "":
|
||||
continue
|
||||
|
||||
attr = getattr(self.model, key)
|
||||
if isinstance(value, tuple):
|
||||
seq, val = value
|
||||
if seq == "None":
|
||||
conditions.append(attr.is_(None))
|
||||
elif seq == "not None":
|
||||
conditions.append(attr.isnot(None))
|
||||
elif seq == "date" and val:
|
||||
conditions.append(func.date_format(attr, "%Y-%m-%d") == val)
|
||||
elif seq == "month" and val:
|
||||
conditions.append(func.date_format(attr, "%Y-%m") == val)
|
||||
elif seq == "like" and val:
|
||||
conditions.append(attr.like(f"%{val}%"))
|
||||
elif seq == "in" and val:
|
||||
conditions.append(attr.in_(val))
|
||||
elif seq == "between" and isinstance(val, (list, tuple)) and len(val) == 2:
|
||||
conditions.append(attr.between(val[0], val[1]))
|
||||
elif seq == "!=" and val:
|
||||
conditions.append(attr != val)
|
||||
elif seq == ">" and val:
|
||||
conditions.append(attr > val)
|
||||
elif seq == ">=" and val:
|
||||
conditions.append(attr >= val)
|
||||
elif seq == "<" and val:
|
||||
conditions.append(attr < val)
|
||||
elif seq == "<=" and val:
|
||||
conditions.append(attr <= val)
|
||||
elif seq == "==" and val:
|
||||
conditions.append(attr == val)
|
||||
else:
|
||||
conditions.append(attr == value)
|
||||
return conditions
|
||||
|
||||
def __order_by(self, order_by: List[Dict[str, str]]) -> List[ColumnElement]:
|
||||
"""
|
||||
获取排序字段
|
||||
|
||||
参数:
|
||||
- order_by (List[Dict[str, str]]): 排序字段列表,格式为 [{'id': 'asc'}, {'name': 'desc'}]
|
||||
|
||||
返回:
|
||||
- List[ColumnElement]: 排序字段列表
|
||||
|
||||
异常:
|
||||
- CustomException: 排序字段不存在时抛出异常
|
||||
"""
|
||||
columns = []
|
||||
for order in order_by:
|
||||
for field, direction in order.items():
|
||||
column = getattr(self.model, field)
|
||||
columns.append(desc(column) if direction.lower() == 'desc' else asc(column))
|
||||
return columns
|
||||
|
||||
def __loader_options(self, preload: Optional[List[Union[str, Any]]] = None) -> List[Any]:
|
||||
"""
|
||||
构建预加载选项
|
||||
|
||||
参数:
|
||||
- preload (Optional[List[Union[str, Any]]]): 预加载关系,支持关系名字符串或SQLAlchemy loader option
|
||||
|
||||
返回:
|
||||
- List[Any]: 预加载选项列表
|
||||
"""
|
||||
options = []
|
||||
# 获取模型定义的默认加载选项
|
||||
model_loader_options = getattr(self.model, '__loader_options__', [])
|
||||
|
||||
# 合并所有需要预加载的选项
|
||||
all_preloads = set(model_loader_options)
|
||||
if preload:
|
||||
for opt in preload:
|
||||
if isinstance(opt, str):
|
||||
all_preloads.add(opt)
|
||||
elif preload == []:
|
||||
# 如果明确指定空列表,则不使用任何预加载
|
||||
all_preloads = set()
|
||||
|
||||
# 处理所有预加载选项
|
||||
for opt in all_preloads:
|
||||
if isinstance(opt, str):
|
||||
# 使用selectinload来避免在异步环境中的MissingGreenlet错误
|
||||
if hasattr(self.model, opt):
|
||||
options.append(selectinload(getattr(self.model, opt)))
|
||||
else:
|
||||
# 直接使用非字符串的加载选项
|
||||
options.append(opt)
|
||||
|
||||
return options
|
||||
118
后端源码/yifan.action-ai.cn/api-bak/app/core/base_model.py
Normal file
118
后端源码/yifan.action-ai.cn/api-bak/app/core/base_model.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from sqlalchemy import DateTime, String, Integer, Text, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, DeclarativeBase, declared_attr, relationship
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.api.v1.module_system.user.model import UserModel
|
||||
|
||||
from app.utils.common_util import uuid4_str
|
||||
|
||||
|
||||
class MappedBase(AsyncAttrs, DeclarativeBase):
|
||||
"""
|
||||
声明式基类
|
||||
|
||||
`AsyncAttrs <https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#sqlalchemy.ext.asyncio.AsyncAttrs>`__
|
||||
|
||||
`DeclarativeBase <https://docs.sqlalchemy.org/en/20/orm/declarative_config.html>`__
|
||||
|
||||
`mapped_column() <https://docs.sqlalchemy.org/en/20/orm/mapping_api.html#sqlalchemy.orm.mapped_column>`__
|
||||
|
||||
兼容 SQLite、MySQL 和 PostgreSQL
|
||||
"""
|
||||
|
||||
__abstract__: bool = True
|
||||
|
||||
|
||||
class ModelMixin(MappedBase):
|
||||
"""
|
||||
模型混入类 - 提供通用字段和功能
|
||||
|
||||
基础模型混合类 Mixin: 一种面向对象编程概念, 使结构变得更加清晰
|
||||
|
||||
数据隔离设计原则:
|
||||
==================
|
||||
数据权限 (created_id/updated_id):
|
||||
- 配合角色的data_scope字段实现精细化权限控制
|
||||
- 1:仅本人
|
||||
- 2:本部门
|
||||
- 3:本部门及以下
|
||||
- 4:全部数据
|
||||
- 5:自定义
|
||||
|
||||
SQLAlchemy加载策略说明:
|
||||
- select(默认): 延迟加载,访问时单独查询
|
||||
- joined: 使用LEFT JOIN预加载
|
||||
- selectin: 使用IN查询批量预加载(推荐用于一对多)
|
||||
- subquery: 使用子查询预加载
|
||||
- raise/raise_on_sql: 禁止加载
|
||||
- noload: 不加载,返回None
|
||||
- immediate: 立即加载
|
||||
- write_only: 只写不读
|
||||
- dynamic: 返回查询对象,支持进一步过滤
|
||||
"""
|
||||
__abstract__: bool = True
|
||||
|
||||
# 基础字段
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True, comment='主键ID')
|
||||
uuid: Mapped[str] = mapped_column(String(64), default=uuid4_str, nullable=False, unique=True, comment='UUID全局唯一标识')
|
||||
status: Mapped[str] = mapped_column(String(10), default='0', nullable=False, comment="是否启用(0:启用 1:禁用)")
|
||||
description: Mapped[str | None] = mapped_column(Text, default=None, nullable=True, comment="备注/描述")
|
||||
created_time: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, nullable=False, comment='创建时间')
|
||||
updated_time: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False, comment='更新时间')
|
||||
|
||||
|
||||
class UserMixin(MappedBase):
|
||||
"""
|
||||
用户审计字段 Mixin
|
||||
|
||||
用于记录数据的创建者和更新者
|
||||
用于实现数据权限中的"仅本人数据权限"
|
||||
"""
|
||||
__abstract__: bool = True
|
||||
|
||||
created_id: Mapped[int | None] = mapped_column(
|
||||
Integer,
|
||||
ForeignKey('sys_user.id', ondelete="SET NULL", onupdate="CASCADE"),
|
||||
default=None,
|
||||
nullable=True,
|
||||
index=True,
|
||||
comment="创建人ID"
|
||||
)
|
||||
updated_id: Mapped[int | None] = mapped_column(
|
||||
Integer,
|
||||
ForeignKey('sys_user.id', ondelete="SET NULL", onupdate="CASCADE"),
|
||||
default=None,
|
||||
nullable=True,
|
||||
index=True,
|
||||
comment="更新人ID"
|
||||
)
|
||||
|
||||
@declared_attr
|
||||
def created_by(cls) -> Mapped[Optional["UserModel"]]:
|
||||
"""
|
||||
创建人关联关系(延迟加载,避免循环依赖)
|
||||
"""
|
||||
return relationship(
|
||||
"UserModel",
|
||||
lazy="selectin",
|
||||
foreign_keys=lambda: cls.created_id,
|
||||
uselist=False
|
||||
)
|
||||
|
||||
@declared_attr
|
||||
def updated_by(cls) -> Mapped[Optional["UserModel"]]:
|
||||
"""
|
||||
更新人关联关系(延迟加载,避免循环依赖)
|
||||
"""
|
||||
return relationship(
|
||||
"UserModel",
|
||||
lazy="selectin",
|
||||
foreign_keys=lambda: cls.updated_id,
|
||||
uselist=False
|
||||
)
|
||||
41
后端源码/yifan.action-ai.cn/api-bak/app/core/base_params.py
Normal file
41
后端源码/yifan.action-ai.cn/api-bak/app/core/base_params.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from fastapi import Query
|
||||
|
||||
|
||||
class PaginationQueryParam:
|
||||
"""分页查询参数基类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
page_no: int = Query(default=1, description="当前页码", ge=1),
|
||||
page_size: int = Query(default=10, description="每页数量", ge=1, le=100),
|
||||
order_by: str | None = Query(default=None, description="排序字段,格式:field1,asc;field2,desc"),
|
||||
) -> None:
|
||||
"""
|
||||
初始化分页查询参数。
|
||||
|
||||
参数:
|
||||
- page_no (int | None): 当前页码,默认 None。
|
||||
- page_size (int | None): 每页数量,默认 None,最大 100。
|
||||
- order_by (str | None): 排序字段,格式 'field,asc;field2,desc'。
|
||||
|
||||
返回:
|
||||
- None
|
||||
"""
|
||||
self.page_no = page_no
|
||||
self.page_size = page_size
|
||||
# 将字符串格式的order_by转换为服务层需要的List[Dict[str, str]]格式
|
||||
if order_by:
|
||||
try:
|
||||
self.order_by = []
|
||||
for item in order_by.split(';'):
|
||||
if item.strip():
|
||||
field, direction = item.split(',', 1)
|
||||
self.order_by.append({field.strip(): direction.strip().lower()})
|
||||
except ValueError:
|
||||
# 如果解析失败,使用默认排序
|
||||
self.order_by = [{'updated_time': 'desc'}]
|
||||
else:
|
||||
self.order_by = [{'updated_time': 'desc'}]
|
||||
|
||||
66
后端源码/yifan.action-ai.cn/api-bak/app/core/base_schema.py
Normal file
66
后端源码/yifan.action-ai.cn/api-bak/app/core/base_schema.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.core.validator import DateTimeStr
|
||||
|
||||
|
||||
class UserInfoSchema(BaseModel):
|
||||
"""用户信息模型"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int | None = Field(default=None, description="用户ID")
|
||||
name: str | None = Field(default=None, description="用户姓名")
|
||||
username: str | None = Field(default=None, description="用户名")
|
||||
|
||||
|
||||
class CommonSchema(BaseModel):
|
||||
"""通用信息模型"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int = Field(description="编号ID")
|
||||
name: str = Field(description="名称")
|
||||
|
||||
|
||||
class BaseSchema(BaseModel):
|
||||
"""通用输出模型,包含基础字段和审计字段"""
|
||||
model_config = ConfigDict(from_attributes=True, coerce_numbers_to_str=True)
|
||||
|
||||
id: int | None = Field(default=None, description="主键ID")
|
||||
uuid: str | None = Field(default=None, description="UUID")
|
||||
status: str = Field(default="0", description="状态")
|
||||
description: str | None = Field(default=None, description="描述")
|
||||
created_time: DateTimeStr | None = Field(default=None, description="创建时间")
|
||||
updated_time: DateTimeStr | None = Field(default=None, description="更新时间")
|
||||
|
||||
|
||||
class UserBySchema(BaseModel):
|
||||
"""通用创建模型,包含基础字段和审计字段"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
created_id: int | None = Field(default=None, description="创建人ID")
|
||||
created_by: UserInfoSchema | None = Field(default=None, description="创建人信息")
|
||||
updated_id: int | None = Field(default=None, description="更新人ID")
|
||||
updated_by: UserInfoSchema | None = Field(default=None, description="更新人信息")
|
||||
|
||||
|
||||
class BatchSetAvailable(BaseModel):
|
||||
"""批量设置可用状态的请求模型"""
|
||||
ids: list[int] = Field(default_factory=list, description="ID列表")
|
||||
status: str = Field(default="0", description="是否可用")
|
||||
|
||||
|
||||
class UploadResponseSchema(BaseModel):
|
||||
"""上传响应模型"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
file_path: str | None = Field(default=None, description='新文件映射路径')
|
||||
file_name: str | None = Field(default=None, description='新文件名称')
|
||||
origin_name: str | None = Field(default=None, description='原文件名称')
|
||||
file_url: str | None = Field(default=None, description='新文件访问地址')
|
||||
|
||||
|
||||
class DownloadFileSchema(BaseModel):
|
||||
"""下载文件模型"""
|
||||
file_path: str = Field(..., description='新文件映射路径')
|
||||
file_name: str = Field(..., description='新文件名称')
|
||||
126
后端源码/yifan.action-ai.cn/api-bak/app/core/database.py
Normal file
126
后端源码/yifan.action-ai.cn/api-bak/app/core/database.py
Normal file
@@ -0,0 +1,126 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from redis.asyncio import Redis
|
||||
from redis import exceptions
|
||||
from fastapi import FastAPI
|
||||
from sqlalchemy import create_engine, Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession, AsyncEngine
|
||||
|
||||
from app.core.logger import log
|
||||
from app.config.setting import settings
|
||||
from app.core.exceptions import CustomException
|
||||
|
||||
|
||||
def create_engine_and_session(
|
||||
db_url: str = settings.DB_URI
|
||||
) -> tuple[Engine, sessionmaker]:
|
||||
"""
|
||||
创建同步数据库引擎和会话工厂。
|
||||
|
||||
参数:
|
||||
- db_url (str): 数据库连接URL,默认从配置中获取。
|
||||
|
||||
返回:
|
||||
- tuple[Engine, sessionmaker]: 同步数据库引擎和会话工厂。
|
||||
"""
|
||||
try:
|
||||
if not settings.SQL_DB_ENABLE:
|
||||
raise CustomException(msg="请先开启数据库连接", data="请启用 app/config/setting.py: SQL_DB_ENABLE")
|
||||
# 同步数据库引擎
|
||||
engine: Engine = create_engine(
|
||||
url=db_url,
|
||||
echo=settings.DATABASE_ECHO,
|
||||
pool_pre_ping=settings.POOL_PRE_PING,
|
||||
pool_recycle=settings.POOL_RECYCLE,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f'❌ 数据库连接失败 {e}')
|
||||
raise
|
||||
else:
|
||||
# 同步数据库会话工厂
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
return engine, SessionLocal
|
||||
|
||||
def create_async_engine_and_session(
|
||||
db_url: str = settings.ASYNC_DB_URI
|
||||
) -> tuple[AsyncEngine, async_sessionmaker[AsyncSession]]:
|
||||
"""
|
||||
获取异步数据库会话连接。
|
||||
|
||||
返回:
|
||||
- tuple[AsyncEngine, async_sessionmaker[AsyncSession]]: 异步数据库引擎和会话工厂。
|
||||
"""
|
||||
try:
|
||||
if not settings.SQL_DB_ENABLE:
|
||||
raise CustomException(msg="请先开启数据库连接", data="请启用 app/config/setting.py: SQL_DB_ENABLE")
|
||||
# 异步数据库引擎
|
||||
async_engine: AsyncEngine = create_async_engine(
|
||||
url=db_url,
|
||||
echo=settings.DATABASE_ECHO,
|
||||
echo_pool=settings.ECHO_POOL,
|
||||
pool_pre_ping=settings.POOL_PRE_PING,
|
||||
future=settings.FUTURE,
|
||||
pool_recycle=settings.POOL_RECYCLE,
|
||||
pool_size=settings.POOL_SIZE,
|
||||
max_overflow=settings.MAX_OVERFLOW,
|
||||
pool_timeout=settings.POOL_TIMEOUT,
|
||||
pool_use_lifo=settings.POOL_USE_LIFO,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f'❌ 数据库连接失败 {e}')
|
||||
raise
|
||||
else:
|
||||
# 异步数据库会话工厂
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
bind=async_engine,
|
||||
autocommit=settings.AUTOCOMMIT,
|
||||
autoflush=settings.AUTOFETCH,
|
||||
expire_on_commit=settings.EXPIRE_ON_COMMIT,
|
||||
class_=AsyncSession
|
||||
)
|
||||
return async_engine, AsyncSessionLocal
|
||||
|
||||
engine, db_session = create_engine_and_session(settings.DB_URI)
|
||||
async_engine, async_db_session = create_async_engine_and_session(settings.ASYNC_DB_URI)
|
||||
|
||||
async def redis_connect(app: FastAPI, status: str) -> Redis | None:
|
||||
"""
|
||||
创建或关闭Redis连接。
|
||||
|
||||
参数:
|
||||
- app (FastAPI): FastAPI应用实例。
|
||||
- status (bool): 连接状态,True为创建连接,False为关闭连接。
|
||||
|
||||
返回:
|
||||
- Redis | None: Redis连接实例,如果连接失败则返回None。
|
||||
"""
|
||||
if not settings.REDIS_ENABLE:
|
||||
raise CustomException(msg="请先开启Redis连接", data="请启用 app/core/config.py: REDIS_ENABLE")
|
||||
|
||||
if status:
|
||||
try:
|
||||
rd = await Redis.from_url(
|
||||
url=settings.REDIS_URI,
|
||||
encoding='utf-8',
|
||||
decode_responses=True,
|
||||
health_check_interval=20,
|
||||
max_connections=settings.POOL_SIZE,
|
||||
socket_timeout=settings.POOL_TIMEOUT
|
||||
)
|
||||
app.state.redis = rd
|
||||
if await rd.ping():
|
||||
log.info("✅️ Redis连接成功...")
|
||||
return rd
|
||||
except exceptions.AuthenticationError as e:
|
||||
log.error(f"❌ 数据库 Redis 认证失败: {e}")
|
||||
raise
|
||||
except exceptions.TimeoutError as e:
|
||||
log.error(f"❌ 数据库 Redis 连接超时: {e}")
|
||||
raise
|
||||
except exceptions.RedisError as e:
|
||||
log.error(f"❌ 数据库 Redis 连接错误: {e}")
|
||||
raise
|
||||
else:
|
||||
await app.state.redis.aclose()
|
||||
log.info('✅️ Redis连接已关闭')
|
||||
178
后端源码/yifan.action-ai.cn/api-bak/app/core/dependencies.py
Normal file
178
后端源码/yifan.action-ai.cn/api-bak/app/core/dependencies.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import json
|
||||
from redis.asyncio.client import Redis
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
from typing import AsyncGenerator
|
||||
from fastapi import Depends, Request
|
||||
from fastapi import Depends
|
||||
|
||||
from app.common.enums import RedisInitKeyConfig
|
||||
from app.core.exceptions import CustomException
|
||||
from app.core.database import async_db_session
|
||||
from app.core.redis_crud import RedisCURD
|
||||
from app.core.security import OAuth2Schema, decode_access_token
|
||||
from app.core.logger import log
|
||||
|
||||
from app.api.v1.module_system.user.model import UserModel
|
||||
from app.api.v1.module_system.role.model import RoleModel
|
||||
from app.api.v1.module_system.user.crud import UserCRUD
|
||||
from app.api.v1.module_system.auth.schema import AuthSchema
|
||||
|
||||
|
||||
async def db_getter() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""获取数据库会话连接
|
||||
|
||||
返回:
|
||||
- AsyncSession: 数据库会话连接
|
||||
"""
|
||||
async with async_db_session() as session:
|
||||
async with session.begin():
|
||||
yield session
|
||||
|
||||
async def redis_getter(request: Request) -> Redis:
|
||||
"""获取Redis连接
|
||||
|
||||
参数:
|
||||
- request (Request): 请求对象
|
||||
|
||||
返回:
|
||||
- Redis: Redis连接
|
||||
"""
|
||||
return request.app.state.redis
|
||||
|
||||
async def get_current_user(
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(db_getter),
|
||||
redis: Redis = Depends(redis_getter),
|
||||
token: str = Depends(OAuth2Schema),
|
||||
) -> AuthSchema:
|
||||
"""获取当前用户
|
||||
|
||||
参数:
|
||||
- request (Request): 请求对象
|
||||
- db (AsyncSession): 数据库会话
|
||||
- redis (Redis): Redis连接
|
||||
- token (str): 访问令牌
|
||||
|
||||
返回:
|
||||
- AuthSchema: 认证信息模型
|
||||
"""
|
||||
if not token:
|
||||
raise CustomException(msg="认证已失效", code=10401, status_code=401)
|
||||
|
||||
# 处理Bearer token
|
||||
if token.startswith('Bearer'):
|
||||
token = token.split(' ')[1]
|
||||
|
||||
payload = decode_access_token(token)
|
||||
if not payload or not hasattr(payload, 'is_refresh') or payload.is_refresh:
|
||||
raise CustomException(msg="非法凭证", code=10401, status_code=401)
|
||||
|
||||
online_user_info = payload.sub
|
||||
# 从Redis中获取用户信息
|
||||
user_info = json.loads(online_user_info) # 确保是字典类型
|
||||
|
||||
session_id = user_info.get("session_id")
|
||||
if not session_id:
|
||||
raise CustomException(msg="认证已失效", code=10401, status_code=401)
|
||||
|
||||
# 检查用户是否在线
|
||||
online_ok = await RedisCURD(redis).exists(key=f'{RedisInitKeyConfig.ACCESS_TOKEN.key}:{session_id}')
|
||||
if not online_ok:
|
||||
raise CustomException(msg="认证已失效", code=10401, status_code=401)
|
||||
|
||||
# 关闭数据权限过滤,避免当前用户查询被拦截
|
||||
auth = AuthSchema(db=db, check_data_scope=False)
|
||||
username = user_info.get("user_name")
|
||||
if not username:
|
||||
raise CustomException(msg="认证已失效", code=10401, status_code=401)
|
||||
# 获取用户信息,使用深层预加载确保RoleModel.creator被正确加载
|
||||
user = await UserCRUD(auth).get_by_username_crud(
|
||||
username=username,
|
||||
preload=[
|
||||
"dept",
|
||||
selectinload(UserModel.roles),
|
||||
"positions",
|
||||
"created_by"
|
||||
]
|
||||
)
|
||||
if not user:
|
||||
raise CustomException(msg="用户不存在", code=10401, status_code=401)
|
||||
if not user.status:
|
||||
raise CustomException(msg="用户已被停用", code=10401, status_code=401)
|
||||
|
||||
# 设置请求上下文
|
||||
request.state.user = user
|
||||
request.state.user_id = user.id
|
||||
request.state.user_username = user.username
|
||||
request.scope["user_id"] = user.id
|
||||
request.scope["user_username"] = user.username
|
||||
|
||||
# 过滤可用的角色和职位
|
||||
if hasattr(user, 'roles'):
|
||||
user.roles = [role for role in user.roles if role and role.status]
|
||||
if hasattr(user, 'positions'):
|
||||
user.positions = [pos for pos in user.positions if pos and pos.status]
|
||||
|
||||
auth.user = user
|
||||
return auth
|
||||
|
||||
|
||||
class AuthPermission:
|
||||
"""权限验证类"""
|
||||
|
||||
def __init__(self, permissions: list[str] | None = None, check_data_scope: bool = True) -> None:
|
||||
"""
|
||||
初始化权限验证
|
||||
|
||||
参数:
|
||||
- permissions (list[str] | None): 权限标识列表。
|
||||
- check_data_scope (bool): 是否启用严格模式校验。
|
||||
"""
|
||||
self.permissions = permissions or []
|
||||
self.check_data_scope = check_data_scope
|
||||
|
||||
async def __call__(self, auth: AuthSchema = Depends(get_current_user)) -> AuthSchema:
|
||||
"""
|
||||
调用权限验证
|
||||
|
||||
参数:
|
||||
- auth (AuthSchema): 认证信息对象。
|
||||
|
||||
返回:
|
||||
- AuthSchema: 认证信息对象。
|
||||
"""
|
||||
auth.check_data_scope = self.check_data_scope
|
||||
|
||||
# 超级管理员直接通过
|
||||
if auth.user and auth.user.is_superuser:
|
||||
return auth
|
||||
|
||||
# 无需验证权限
|
||||
if not self.permissions:
|
||||
return auth
|
||||
|
||||
# 超级管理员权限标识
|
||||
if "*" in self.permissions or "*:*:*" in self.permissions:
|
||||
return auth
|
||||
|
||||
# 检查用户是否有角色
|
||||
if not auth.user or not auth.user.roles:
|
||||
raise CustomException(msg="无权限操作", code=10403, status_code=403)
|
||||
|
||||
# 获取用户权限集合
|
||||
user_permissions = {
|
||||
menu.permission
|
||||
for role in auth.user.roles
|
||||
for menu in role.menus
|
||||
if role.status and menu.permission and menu.status
|
||||
}
|
||||
|
||||
# 权限验证 - 满足任一权限即可
|
||||
if not any(perm in user_permissions for perm in self.permissions):
|
||||
log.error(f"用户缺少任何所需的权限: {self.permissions}")
|
||||
raise CustomException(msg="无权限操作", code=10403, status_code=403)
|
||||
|
||||
return auth
|
||||
356
后端源码/yifan.action-ai.cn/api-bak/app/core/discover.py
Normal file
356
后端源码/yifan.action-ai.cn/api-bak/app/core/discover.py
Normal file
@@ -0,0 +1,356 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
集中式路由发现与注册
|
||||
|
||||
约定:
|
||||
- 仅扫描 `app.api.v1` 包内,顶级目录以 `module_` 开头的模块。
|
||||
- 在各模块任意子目录下的 `controller.py` 中定义的 `APIRouter` 实例会自动被注册。
|
||||
- 顶级目录 `module_xxx` 会映射为容器路由前缀 `/<xxx>`。
|
||||
|
||||
设计目标:
|
||||
- 稳定、可预测:有序扫描与注册,确定性日志输出。
|
||||
- 简洁、易维护:职责拆分成小函数,类型提示与清晰注释。
|
||||
- 安全、可控:去重处理、异常分层记录、可配置的前缀映射与忽略规则。
|
||||
- 灵活、可扩展:基于类的设计,支持配置自定义和实例化多套路由系统。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Callable, Iterable, Any
|
||||
from functools import wraps
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.core.logger import log
|
||||
|
||||
|
||||
def _log_error_handling(func: Callable) -> Callable:
|
||||
"""错误处理装饰器,用于统一捕获和记录方法执行过程中的异常"""
|
||||
@wraps(func)
|
||||
def wrapper(self: 'DiscoverRouter', *args: Any, **kwargs: Any) -> Any:
|
||||
method_name = func.__name__
|
||||
try:
|
||||
return func(self, *args, **kwargs)
|
||||
except ModuleNotFoundError as e:
|
||||
log.error(f"❌️ 模块未找到 [{method_name}]: {str(e)}")
|
||||
raise
|
||||
except ImportError as e:
|
||||
log.error(f"❌️ 导入错误 [{method_name}]: {str(e)}")
|
||||
raise
|
||||
except AttributeError as e:
|
||||
log.error(f"❌️ 属性错误 [{method_name}]: {str(e)}")
|
||||
raise
|
||||
except Exception as e:
|
||||
log.error(f"❌️ 未知错误 [{method_name}]: {str(e)}")
|
||||
# 在调试模式下打印完整堆栈信息
|
||||
if getattr(self, 'debug', False):
|
||||
import traceback
|
||||
log.error(traceback.format_exc())
|
||||
raise
|
||||
return wrapper
|
||||
|
||||
|
||||
class DiscoverRouter:
|
||||
"""
|
||||
路由自动发现与注册器
|
||||
|
||||
提供基于约定的路由自动发现与注册功能,支持自定义配置和灵活扩展。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
module_prefix: str = "module_",
|
||||
base_package: str = "app.api.v1",
|
||||
prefix_map: dict[str, str] | None = None,
|
||||
exclude_dirs: set[str] | None = None,
|
||||
exclude_files: set[str] | None = None,
|
||||
auto_discover: bool = True,
|
||||
debug: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
初始化路由发现注册器
|
||||
|
||||
参数:
|
||||
- module_prefix: 模块目录前缀,默认为 "module_"
|
||||
- base_package: 基础包名,默认为 "app.api.v1"
|
||||
- prefix_map: 前缀映射字典,用于自定义路由前缀
|
||||
- exclude_dirs: 排除的目录集合
|
||||
- exclude_files: 排除的文件集合
|
||||
- auto_discover: 是否在初始化时自动执行发现和注册,默认为 True
|
||||
- debug: 是否启用调试模式,在调试模式下会输出更详细的错误信息,默认为 False
|
||||
"""
|
||||
self.module_prefix = module_prefix
|
||||
self.base_package = base_package
|
||||
self.prefix_map = prefix_map or {}
|
||||
self.exclude_dirs = exclude_dirs or set()
|
||||
self.exclude_files = exclude_files or set()
|
||||
self.debug = debug
|
||||
self._router = APIRouter()
|
||||
self._seen_router_ids: set[int] = set()
|
||||
self._discovery_stats: dict[str, int] = {
|
||||
"scanned_files": 0,
|
||||
"imported_modules": 0,
|
||||
"included_routers": 0,
|
||||
"container_count": 0
|
||||
}
|
||||
|
||||
# 自动执行发现和注册
|
||||
if auto_discover:
|
||||
self.discover_and_register()
|
||||
|
||||
@property
|
||||
def router(self) -> APIRouter:
|
||||
"""获取根路由实例"""
|
||||
return self._router
|
||||
|
||||
@property
|
||||
def discovery_stats(self) -> dict[str, int]:
|
||||
"""获取路由发现统计信息"""
|
||||
return self._discovery_stats.copy()
|
||||
|
||||
@_log_error_handling
|
||||
def _get_base_dir_and_pkg(self) -> tuple[Path, str]:
|
||||
"""定位基础包的文件系统路径与包名。
|
||||
|
||||
返回:
|
||||
- (Path, str): (包的路径, 包名)
|
||||
"""
|
||||
base_pkg = importlib.import_module(self.base_package)
|
||||
base_dir = Path(next(iter(base_pkg.__path__)))
|
||||
log.info(f"📁 基础包路径: {base_dir}, 包名: {base_pkg.__name__}")
|
||||
return base_dir, base_pkg.__name__
|
||||
|
||||
def _iter_controller_files(self, base_dir: Path) -> Iterable[Path]:
|
||||
"""递归查找并返回所有 `controller.py` 文件,按路径排序保证确定性。"""
|
||||
try:
|
||||
files = sorted(base_dir.rglob("controller.py"), key=lambda p: p.as_posix())
|
||||
log.info(f"🔍 发现 {len(files)} 个控制器文件")
|
||||
return files
|
||||
except PermissionError as e:
|
||||
log.error(f"❌️ 权限错误: 无法访问目录 {base_dir}: {str(e)}")
|
||||
return []
|
||||
except Exception as e:
|
||||
log.error(f"❌️ 查找控制器文件失败: {str(e)}")
|
||||
return []
|
||||
|
||||
def _resolve_prefix(self, top_module: str) -> str | None:
|
||||
"""将顶级模块目录名解析为容器前缀。"""
|
||||
if top_module in self.exclude_dirs:
|
||||
if self.debug:
|
||||
log.warning(f"⚠️ 目录 {top_module} 被排除")
|
||||
return None
|
||||
if not top_module.startswith(self.module_prefix):
|
||||
if self.debug:
|
||||
log.warning(f"⚠️ 目录 {top_module} 不符合前缀约定 {self.module_prefix}")
|
||||
return None
|
||||
|
||||
mapped = self.prefix_map.get(top_module)
|
||||
if mapped:
|
||||
log.info(f"🔄 模块 {top_module} 映射到前缀 {mapped}")
|
||||
return mapped
|
||||
|
||||
prefix = f"/{top_module[len(self.module_prefix):]}"
|
||||
if self.debug:
|
||||
log.debug(f"📋 模块 {top_module} 使用默认前缀 {prefix}")
|
||||
return prefix
|
||||
|
||||
@_log_error_handling
|
||||
def _include_module_routers(self, mod: object, container: APIRouter) -> int:
|
||||
"""将模块中的所有 `APIRouter` 实例包含到指定容器路由中。
|
||||
|
||||
返回:
|
||||
- int: 新增注册的路由数量
|
||||
"""
|
||||
from fastapi import APIRouter as _APIRouter
|
||||
|
||||
added = 0
|
||||
mod_name = getattr(mod, "__name__", "<unknown>")
|
||||
router_count = 0
|
||||
|
||||
for attr_name in dir(mod):
|
||||
attr = getattr(mod, attr_name, None)
|
||||
if isinstance(attr, _APIRouter):
|
||||
router_count += 1
|
||||
rid = id(attr)
|
||||
if rid in self._seen_router_ids:
|
||||
log.warning(f"⚠️ 路由 {attr_name} 在模块 {mod_name} 中已注册,跳过重复注册")
|
||||
continue
|
||||
|
||||
self._seen_router_ids.add(rid)
|
||||
container.include_router(attr)
|
||||
added += 1
|
||||
log.info(f"➕ 注册路由 {attr_name} 到容器")
|
||||
|
||||
if router_count == 0:
|
||||
log.warning(f"⚠️ 模块 {mod_name} 中未发现 APIRouter 实例")
|
||||
|
||||
return added
|
||||
|
||||
@_log_error_handling
|
||||
def discover_and_register(self) -> dict[str, int]:
|
||||
"""
|
||||
执行路由发现与注册
|
||||
|
||||
返回:
|
||||
- dict[str, int]: 包含发现统计信息的字典
|
||||
- scanned_files: 扫描的文件数量
|
||||
- imported_modules: 导入的模块数量
|
||||
- included_routers: 注册的路由数量
|
||||
- container_count: 容器数量
|
||||
"""
|
||||
log.info("🚀 开始路由发现与注册...")
|
||||
base_dir, base_pkg = self._get_base_dir_and_pkg()
|
||||
containers: dict[str, APIRouter] = {}
|
||||
container_counts: dict[str, int] = {}
|
||||
|
||||
scanned_files = 0
|
||||
imported_modules = 0
|
||||
included_routers = 0
|
||||
|
||||
try:
|
||||
for file in self._iter_controller_files(base_dir):
|
||||
rel_path = file.relative_to(base_dir).as_posix()
|
||||
scanned_files += 1
|
||||
|
||||
if rel_path in self.exclude_files:
|
||||
log.warning(f"⚠️ 文件 {rel_path} 被排除")
|
||||
continue
|
||||
|
||||
parts = file.relative_to(base_dir).parts
|
||||
if len(parts) < 2:
|
||||
log.warning(f"⚠️ 文件路径不完整: {rel_path},跳过")
|
||||
continue
|
||||
|
||||
top_module = parts[0]
|
||||
prefix = self._resolve_prefix(top_module)
|
||||
if not prefix:
|
||||
continue
|
||||
|
||||
# 拼接模块导入路径
|
||||
mod_path = ".".join((base_pkg,) + tuple(parts[:-1]) + ("controller",))
|
||||
try:
|
||||
mod = importlib.import_module(mod_path)
|
||||
imported_modules += 1
|
||||
log.info(f"📥 导入模块: {mod_path}")
|
||||
except ModuleNotFoundError as e:
|
||||
log.error(f"❌️ 未找到控制器模块: {mod_path} -> {str(e)}")
|
||||
continue
|
||||
except ImportError as e:
|
||||
log.error(f"❌️ 导入控制器失败: {mod_path} -> {str(e)}")
|
||||
continue
|
||||
|
||||
container = containers.setdefault(prefix, APIRouter(prefix=prefix))
|
||||
try:
|
||||
added = self._include_module_routers(mod, container)
|
||||
included_routers += added
|
||||
container_counts[prefix] = container_counts.get(prefix, 0) + added
|
||||
except Exception as e:
|
||||
log.error(f"❌️ 注册控制器路由失败: {mod_path} -> {str(e)}")
|
||||
|
||||
# 将容器路由按前缀名称排序后注册到根路由,保证顺序稳定
|
||||
for prefix in sorted(containers.keys()):
|
||||
container = containers[prefix]
|
||||
rid = id(container)
|
||||
if rid in self._seen_router_ids:
|
||||
continue
|
||||
self._seen_router_ids.add(rid)
|
||||
self._router.include_router(container)
|
||||
# 更丰富的注册日志(含路由数量)
|
||||
count = container_counts.get(prefix, 0)
|
||||
log.info(f"✅️ 已注册模块容器: {prefix} (路由数: {count})")
|
||||
|
||||
# 更新统计信息
|
||||
stats = {
|
||||
"scanned_files": scanned_files,
|
||||
"imported_modules": imported_modules,
|
||||
"included_routers": included_routers,
|
||||
"container_count": len(containers)
|
||||
}
|
||||
self._discovery_stats = stats
|
||||
|
||||
# 生成总结日志
|
||||
log.info(
|
||||
(
|
||||
f"✅️ 路由发现完成: 扫描文件 {scanned_files}, "
|
||||
f"导入模块 {imported_modules}, 注册路由 {included_routers}, "
|
||||
f"容器 {len(containers)}"
|
||||
)
|
||||
)
|
||||
|
||||
return stats
|
||||
except Exception as e:
|
||||
log.error(f"❌️ 路由发现与注册过程失败: {str(e)}")
|
||||
# 确保返回统计信息,即使过程中出错
|
||||
return self._discovery_stats
|
||||
|
||||
def set_debug(self, debug: bool) -> 'DiscoverRouter':
|
||||
"""设置调试模式
|
||||
|
||||
参数:
|
||||
- debug: 是否开启调试模式
|
||||
|
||||
返回:
|
||||
- self: 支持链式调用
|
||||
"""
|
||||
self.debug = debug
|
||||
log_level = "DEBUG" if debug else "INFO"
|
||||
log.info(f"⚙️ 调试模式已{'开启' if debug else '关闭'},日志级别: {log_level}")
|
||||
return self
|
||||
|
||||
def add_exclude_dir(self, dir_name: str) -> 'DiscoverRouter':
|
||||
"""添加排除的目录
|
||||
|
||||
参数:
|
||||
- dir_name: 要排除的目录名称
|
||||
|
||||
返回:
|
||||
- self: 支持链式调用
|
||||
"""
|
||||
self.exclude_dirs.add(dir_name)
|
||||
log.info(f"📝 添加排除目录: {dir_name}")
|
||||
return self
|
||||
|
||||
def add_prefix_map(self, module_name: str, prefix: str) -> 'DiscoverRouter':
|
||||
"""添加前缀映射
|
||||
|
||||
参数:
|
||||
- module_name: 模块名称
|
||||
- prefix: 对应的路由前缀
|
||||
|
||||
返回:
|
||||
- self: 支持链式调用
|
||||
"""
|
||||
self.prefix_map[module_name] = prefix
|
||||
log.info(f"📝 添加前缀映射: {module_name} -> {prefix}")
|
||||
return self
|
||||
|
||||
@_log_error_handling
|
||||
def register_router(self, router: APIRouter, tags: list[str | Enum] | None = None) -> None:
|
||||
"""手动注册一个路由实例
|
||||
|
||||
参数:
|
||||
- router: 要注册的 APIRouter 实例
|
||||
- tags: 路由标签,用于 API 文档分组
|
||||
"""
|
||||
rid = id(router)
|
||||
if rid not in self._seen_router_ids:
|
||||
self._seen_router_ids.add(rid)
|
||||
self._router.include_router(router, tags=tags)
|
||||
log.info(f"📌 手动注册路由,标签: {tags}")
|
||||
else:
|
||||
log.warning(f"⚠️ 路由已存在,跳过重复注册")
|
||||
|
||||
|
||||
# 创建默认实例并执行自动发现注册
|
||||
_discoverer = DiscoverRouter()
|
||||
|
||||
# 保持向后兼容,导出原始的 router 变量
|
||||
router = _discoverer.router
|
||||
|
||||
# 导出 DiscoverRouter 类供外部使用
|
||||
__all__ = ["DiscoverRouter", "router"]
|
||||
|
||||
|
||||
# 执行自动发现注册(已由 DiscoverRouter 实例内部处理)
|
||||
211
后端源码/yifan.action-ai.cn/api-bak/app/core/exceptions.py
Normal file
211
后端源码/yifan.action-ai.cn/api-bak/app/core/exceptions.py
Normal file
@@ -0,0 +1,211 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Any
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi.exceptions import RequestValidationError, ResponseValidationError
|
||||
from pydantic_validation_decorator import FieldValidationError
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.exceptions import HTTPException
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from app.common.constant import RET
|
||||
from app.common.response import ErrorResponse
|
||||
from app.core.logger import log
|
||||
|
||||
|
||||
class CustomException(Exception):
|
||||
"""
|
||||
自定义异常基类
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
msg: str = RET.EXCEPTION.msg,
|
||||
code: int = RET.EXCEPTION.code,
|
||||
status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
data: Any | None = None,
|
||||
success: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
初始化异常对象。
|
||||
|
||||
参数:
|
||||
- msg (str): 错误消息。
|
||||
- code (int): 业务状态码。
|
||||
- status_code (int): HTTP 状态码。
|
||||
- data (Any | None): 附加数据。
|
||||
- success (bool): 是否成功标记,默认 False。
|
||||
|
||||
返回:
|
||||
- None
|
||||
"""
|
||||
super().__init__(msg) # 调用父类初始化方法
|
||||
self.status_code = status_code
|
||||
self.code = code
|
||||
self.msg = msg
|
||||
self.data = data
|
||||
self.success = success
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""返回异常消息
|
||||
|
||||
返回:
|
||||
- str: 异常消息
|
||||
"""
|
||||
return self.msg
|
||||
|
||||
|
||||
def handle_exception(app: FastAPI):
|
||||
"""
|
||||
注册全局异常处理器。
|
||||
|
||||
参数:
|
||||
- app (FastAPI): 应用实例。
|
||||
|
||||
返回:
|
||||
- None
|
||||
"""
|
||||
@app.exception_handler(CustomException)
|
||||
async def CustomExceptionHandler(request: Request, exc: CustomException) -> JSONResponse:
|
||||
"""
|
||||
自定义异常处理器
|
||||
|
||||
参数:
|
||||
- request (Request): 请求对象。
|
||||
- exc (CustomException): 自定义异常实例。
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含错误信息的 JSON 响应。
|
||||
"""
|
||||
log.error(f"[自定义异常] {request.method} {request.url.path} | 错误码: {exc.code} | 错误信息: {exc.msg} | 详情: {exc.data}")
|
||||
return ErrorResponse(msg=exc.msg, code=exc.code, status_code=exc.status_code, data=exc.data)
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
async def HttpExceptionHandler(request: Request, exc: HTTPException) -> JSONResponse:
|
||||
"""
|
||||
HTTP异常处理器
|
||||
|
||||
参数:
|
||||
- request (Request): 请求对象。
|
||||
- exc (HTTPException): HTTP异常实例。
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含错误信息的 JSON 响应。
|
||||
"""
|
||||
log.error(f"[HTTP异常] {request.method} {request.url.path} | 状态码: {exc.status_code} | 错误信息: {exc.detail}")
|
||||
return ErrorResponse(msg=exc.detail, status_code=exc.status_code)
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def ValidationExceptionHandler(request: Request, exc: RequestValidationError) -> JSONResponse:
|
||||
"""
|
||||
请求参数验证异常处理器
|
||||
|
||||
参数:
|
||||
- request (Request): 请求对象。
|
||||
- exc (RequestValidationError): 请求参数验证异常实例。
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含错误信息的 JSON 响应。
|
||||
"""
|
||||
error_mapping = {
|
||||
"Field required": "请求失败,缺少必填项!",
|
||||
"value is not a valid list": "类型错误,提交参数应该为列表!",
|
||||
"value is not a valid int": "类型错误,提交参数应该为整数!",
|
||||
"value could not be parsed to a boolean": "类型错误,提交参数应该为布尔值!",
|
||||
"Input should be a valid list": "类型错误,输入应该是一个有效的列表!"
|
||||
}
|
||||
raw_msg = exc.errors()[0].get('msg')
|
||||
msg = error_mapping.get(raw_msg, raw_msg)
|
||||
# 去掉Pydantic默认的前缀“Value error”, 仅保留具体提示内容
|
||||
if isinstance(msg, str) and msg.startswith("Value error"):
|
||||
if "," in msg:
|
||||
msg = msg.split(",", 1)[1].strip()
|
||||
else:
|
||||
msg = msg.replace("Value error", "").strip()
|
||||
log.error(f"[参数验证异常] {request.method} {request.url.path} | 错误信息: {msg} | 原始错误: {exc.errors()}")
|
||||
# 如果是bytes类型(如文件上传),不返回原始数据,避免JSON序列化失败
|
||||
response_data = None if isinstance(exc.body, bytes) else exc.body
|
||||
return ErrorResponse(msg=str(msg), status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, data=response_data)
|
||||
|
||||
@app.exception_handler(ResponseValidationError)
|
||||
async def ResponseValidationHandle(request: Request, exc: ResponseValidationError) -> JSONResponse:
|
||||
"""
|
||||
响应参数验证异常处理器
|
||||
|
||||
参数:
|
||||
- request (Request): 请求对象。
|
||||
- exc (ResponseValidationError): 响应参数验证异常实例。
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含错误信息的 JSON 响应。
|
||||
"""
|
||||
log.error(f"[响应验证异常] {request.method} {request.url.path} | 错误信息: 响应数据格式错误 | 详情: {exc.errors()}")
|
||||
# 如果是bytes类型(如文件上传),不返回原始数据,避免JSON序列化失败
|
||||
response_data = None if isinstance(exc.body, bytes) else exc.body
|
||||
return ErrorResponse(msg="服务器响应格式错误", status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, data=response_data)
|
||||
|
||||
@app.exception_handler(SQLAlchemyError)
|
||||
async def SQLAlchemyExceptionHandler(request: Request, exc: SQLAlchemyError) -> JSONResponse:
|
||||
"""
|
||||
数据库异常处理器
|
||||
|
||||
参数:
|
||||
- request (Request): 请求对象。
|
||||
- exc (SQLAlchemyError): 数据库异常实例。
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含错误信息的 JSON 响应。
|
||||
"""
|
||||
error_msg = '数据库操作失败'
|
||||
exc_type = type(exc).__name__
|
||||
|
||||
# 对于生产环境,返回通用错误消息
|
||||
log.error(f"[数据库异常] {request.method} {request.url.path} | 错误类型: {exc_type} | 错误详情: {str(exc)}")
|
||||
return ErrorResponse(msg=f'{error_msg}: {exc_type}', status_code=status.HTTP_400_BAD_REQUEST, data=str(exc))
|
||||
|
||||
@app.exception_handler(ValueError)
|
||||
async def ValueExceptionHandler(request: Request, exc: ValueError) -> JSONResponse:
|
||||
"""
|
||||
值异常处理器
|
||||
|
||||
参数:
|
||||
- request (Request): 请求对象。
|
||||
- exc (ValueError): 值异常实例。
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含错误信息的 JSON 响应。
|
||||
"""
|
||||
log.exception(f"[值异常] {request.method} {request.url.path} | 错误信息: {str(exc)}")
|
||||
return ErrorResponse(msg=str(exc), status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
@app.exception_handler(FieldValidationError)
|
||||
async def FieldValidationExceptionHandler(request: Request, exc: FieldValidationError) -> JSONResponse:
|
||||
"""
|
||||
字段验证异常处理器
|
||||
|
||||
参数:
|
||||
- request (Request): 请求对象。
|
||||
- exc (FieldValidationError): 字段验证异常实例。
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含错误信息的 JSON 响应。
|
||||
"""
|
||||
log.error(f"[字段验证异常] {request.method} {request.url.path} | 错误信息: {exc.message}")
|
||||
return ErrorResponse(msg=exc.message, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def AllExceptionHandler(request: Request, exc: Exception) -> JSONResponse:
|
||||
"""
|
||||
全局异常处理器
|
||||
|
||||
参数:
|
||||
- request (Request): 请求对象。
|
||||
- exc (Exception): 异常实例。
|
||||
|
||||
返回:
|
||||
- JSONResponse: 包含错误信息的 JSON 响应。
|
||||
"""
|
||||
exc_type = type(exc).__name__
|
||||
log.error(f"[未捕获异常] {request.method} {request.url.path} | 错误类型: {exc_type} | 错误详情: {str(exc)}")
|
||||
# 对于未捕获的异常,返回通用错误信息
|
||||
return ErrorResponse(msg='服务器内部错误', status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, data=None)
|
||||
164
后端源码/yifan.action-ai.cn/api-bak/app/core/logger.py
Normal file
164
后端源码/yifan.action-ai.cn/api-bak/app/core/logger.py
Normal file
@@ -0,0 +1,164 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import atexit
|
||||
from typing_extensions import override
|
||||
from loguru import logger
|
||||
|
||||
from app.config.path_conf import LOG_DIR
|
||||
from app.config.setting import settings
|
||||
|
||||
# 全局变量记录日志处理器ID
|
||||
_logger_handlers = []
|
||||
|
||||
|
||||
class InterceptHandler(logging.Handler):
|
||||
"""
|
||||
日志拦截处理器:将所有 Python 标准日志重定向到 Loguru
|
||||
|
||||
工作原理:
|
||||
1. 继承自 logging.Handler
|
||||
2. 重写 emit 方法处理日志记录
|
||||
3. 将标准库日志转换为 Loguru 格式
|
||||
"""
|
||||
@override
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
# 尝试获取日志级别名称
|
||||
try:
|
||||
level = logger.level(record.levelname).name
|
||||
except ValueError:
|
||||
level = record.levelno
|
||||
|
||||
# 获取调用帧信息,增加None检查
|
||||
frame, depth = logging.currentframe(), 2
|
||||
while frame and frame.f_code.co_filename == logging.__file__:
|
||||
frame = frame.f_back
|
||||
depth += 1
|
||||
|
||||
# 使用 Loguru 记录日志
|
||||
logger.opt(depth=depth, exception=record.exc_info).log(
|
||||
level,
|
||||
record.getMessage()
|
||||
)
|
||||
|
||||
|
||||
def cleanup_logging():
|
||||
"""
|
||||
清理日志资源
|
||||
在程序退出时调用,确保所有日志处理器被正确关闭
|
||||
"""
|
||||
global _logger_handlers
|
||||
|
||||
for handler_id in _logger_handlers:
|
||||
try:
|
||||
logger.remove(handler_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
_logger_handlers.clear()
|
||||
|
||||
|
||||
def setup_logging():
|
||||
"""
|
||||
配置日志系统
|
||||
|
||||
功能:
|
||||
1. 控制台彩色输出
|
||||
2. 文件日志轮转
|
||||
3. 错误日志单独存储
|
||||
4. 智能异步策略:开发环境同步(避免reload资源泄漏),生产环境异步(高性能)
|
||||
"""
|
||||
global _logger_handlers
|
||||
|
||||
if _logger_handlers:
|
||||
return
|
||||
|
||||
if hasattr(sys.stdout, "reconfigure"):
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
if hasattr(sys.stderr, "reconfigure"):
|
||||
sys.stderr.reconfigure(encoding="utf-8")
|
||||
|
||||
# 添加上下文信息
|
||||
_ = logger.configure(extra={"app_name": "FastapiAdmin"})
|
||||
# 步骤1:移除默认处理器
|
||||
logger.remove()
|
||||
|
||||
# 步骤2:定义日志格式
|
||||
log_format = (
|
||||
# 时间信息
|
||||
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
|
||||
# 日志级别,居中对齐
|
||||
"<level>{level: <8}</level> | "
|
||||
# 文件、函数和行号
|
||||
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
|
||||
# 日志消息
|
||||
"<level>{message}</level>"
|
||||
)
|
||||
|
||||
# 智能选择异步策略:开发环境禁用异步(避免reload时资源泄漏),生产环境启用异步(提升性能)
|
||||
use_async = not settings.DEBUG
|
||||
|
||||
# 步骤3:配置控制台输出
|
||||
handler_id = logger.add(
|
||||
sys.stdout,
|
||||
format=log_format,
|
||||
level="DEBUG" if settings.DEBUG else "INFO",
|
||||
enqueue=use_async, # 开发同步,生产异步
|
||||
backtrace=True, # 显示完整的异常回溯
|
||||
diagnose=True, # 显示变量值等诊断信息
|
||||
colorize=True # 启用彩色输出
|
||||
)
|
||||
_logger_handlers.append(handler_id)
|
||||
|
||||
# 步骤4:创建日志目录
|
||||
log_dir = LOG_DIR
|
||||
# 确保日志目录存在,如果不存在则创建
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 步骤5:配置常规日志文件
|
||||
handler_id = logger.add(
|
||||
str(log_dir / "info.log"),
|
||||
format=log_format,
|
||||
level="INFO",
|
||||
rotation="100 MB", # 按文件大小轮转,避免 Windows 文件锁问题
|
||||
retention=30, # 日志保留天数,超过此天数的日志文件将被自动清理
|
||||
compression="gz",
|
||||
encoding="utf-8",
|
||||
enqueue=use_async, # 开发同步,生产异步
|
||||
catch=True # 捕获日志处理过程中的异常,避免程序崩溃
|
||||
)
|
||||
_logger_handlers.append(handler_id)
|
||||
|
||||
# 步骤6:配置错误日志文件
|
||||
handler_id = logger.add(
|
||||
str(log_dir / "error.log"),
|
||||
format=log_format,
|
||||
level="ERROR",
|
||||
rotation="100 MB", # 按文件大小轮转,避免 Windows 文件锁问题
|
||||
retention=30, # 日志保留天数,超过此天数的日志文件将被自动清理
|
||||
compression="gz",
|
||||
encoding="utf-8",
|
||||
enqueue=use_async, # 开发同步,生产异步
|
||||
backtrace=True,
|
||||
diagnose=True,
|
||||
catch=True # 捕获日志处理过程中的异常,避免程序崩溃
|
||||
)
|
||||
_logger_handlers.append(handler_id)
|
||||
|
||||
# 步骤7:配置标准库日志
|
||||
logging.basicConfig(handlers=[InterceptHandler()], level="DEBUG" if settings.DEBUG else "INFO", force=True)
|
||||
logger_name_list = [name for name in logging.root.manager.loggerDict]
|
||||
|
||||
# 步骤8:配置第三方库日志
|
||||
for logger_name in logger_name_list:
|
||||
_logger = logging.getLogger(logger_name)
|
||||
_logger.handlers = [InterceptHandler()]
|
||||
_logger.propagate = False
|
||||
|
||||
# 注册退出清理函数
|
||||
atexit.register(cleanup_logging)
|
||||
|
||||
setup_logging()
|
||||
|
||||
log = logger
|
||||
213
后端源码/yifan.action-ai.cn/api-bak/app/core/middlewares.py
Normal file
213
后端源码/yifan.action-ai.cn/api-bak/app/core/middlewares.py
Normal file
@@ -0,0 +1,213 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import json
|
||||
import time
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.types import ASGIApp
|
||||
from starlette.requests import Request
|
||||
from starlette.middleware.gzip import GZipMiddleware
|
||||
from starlette.responses import Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.common.response import ErrorResponse
|
||||
from app.config.setting import settings
|
||||
from app.core.logger import log
|
||||
from app.core.exceptions import CustomException
|
||||
from app.core.security import decode_access_token
|
||||
from app.api.v1.module_system.params.service import ParamsService
|
||||
|
||||
|
||||
class CustomCORSMiddleware(CORSMiddleware):
|
||||
"""CORS跨域中间件"""
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
super().__init__(
|
||||
app,
|
||||
allow_origins=settings.ALLOW_ORIGINS,
|
||||
allow_methods=settings.ALLOW_METHODS,
|
||||
allow_headers=settings.ALLOW_HEADERS,
|
||||
allow_credentials=settings.ALLOW_CREDENTIALS,
|
||||
expose_headers=settings.CORS_EXPOSE_HEADERS,
|
||||
)
|
||||
|
||||
|
||||
class RequestLogMiddleware:
|
||||
"""
|
||||
记录请求日志中间件
|
||||
|
||||
注意:使用纯 ASGI 中间件实现,避免 BaseHTTPMiddleware 缓冲流式响应的问题。
|
||||
BaseHTTPMiddleware 会等待整个响应体完成后才返回,这会破坏流式响应的功能。
|
||||
"""
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
|
||||
@staticmethod
|
||||
def _extract_session_id_from_headers(headers: list) -> str | None:
|
||||
"""
|
||||
从请求头中提取session_id
|
||||
|
||||
参数:
|
||||
- headers (list): ASGI 格式的请求头列表
|
||||
|
||||
返回:
|
||||
- str | None: 会话ID,如果无法提取则返回None
|
||||
"""
|
||||
try:
|
||||
authorization = None
|
||||
for key, value in headers:
|
||||
if key == b'authorization':
|
||||
authorization = value.decode('utf-8')
|
||||
break
|
||||
|
||||
if not authorization:
|
||||
return None
|
||||
|
||||
# 处理Bearer token
|
||||
token = authorization.replace('Bearer ', '').strip()
|
||||
|
||||
# 解码token
|
||||
payload = decode_access_token(token)
|
||||
if not payload or not hasattr(payload, 'sub'):
|
||||
return None
|
||||
|
||||
# 从payload中提取session_id
|
||||
user_info = json.loads(payload.sub)
|
||||
return user_info.get("session_id")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_header_value(headers: list, key: bytes) -> str | None:
|
||||
"""从头列表中获取指定的头值"""
|
||||
for k, v in headers:
|
||||
if k.lower() == key.lower():
|
||||
return v.decode('utf-8')
|
||||
return None
|
||||
|
||||
async def __call__(self, scope, receive, send) -> None:
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
start_time = time.time()
|
||||
request = Request(scope, receive)
|
||||
|
||||
# 尝试提取session_id
|
||||
session_id = self._extract_session_id_from_headers(scope.get('headers', []))
|
||||
|
||||
# 组装请求日志字段
|
||||
client_host = scope.get('client', ['未知'])[0] if scope.get('client') else '未知'
|
||||
log_fields = [
|
||||
f"请求来源: {client_host}",
|
||||
f"请求方法: {scope.get('method', 'UNKNOWN')}",
|
||||
f"请求路径: {scope.get('path', '/')}",
|
||||
]
|
||||
log.info(log_fields)
|
||||
|
||||
# 获取请求路径
|
||||
path = scope.get("path")
|
||||
|
||||
# 尝试获取客户端真实IP
|
||||
headers = scope.get('headers', [])
|
||||
x_forwarded_for = self._get_header_value(headers, b'x-forwarded-for')
|
||||
if x_forwarded_for:
|
||||
request_ip = x_forwarded_for.split(',')[0].strip()
|
||||
else:
|
||||
request_ip = client_host
|
||||
|
||||
# 检查是否启用演示模式
|
||||
demo_enable = False
|
||||
ip_white_list = []
|
||||
white_api_list_path = []
|
||||
ip_black_list = []
|
||||
|
||||
try:
|
||||
# 从应用实例获取Redis连接
|
||||
redis = request.app.state.redis
|
||||
if redis:
|
||||
# 使用ParamsService获取系统配置
|
||||
system_config = await ParamsService.get_system_config_for_middleware(redis)
|
||||
demo_enable = system_config["demo_enable"]
|
||||
ip_white_list = system_config["ip_white_list"]
|
||||
white_api_list_path = system_config["white_api_list_path"]
|
||||
ip_black_list = system_config["ip_black_list"]
|
||||
except Exception as e:
|
||||
log.error(f"获取系统配置失败: {e}")
|
||||
|
||||
# 检查是否需要拦截请求
|
||||
should_block = False
|
||||
block_reason = ""
|
||||
method = scope.get('method', '')
|
||||
|
||||
# 1. 首先检查IP是否在黑名单中
|
||||
if request_ip and request_ip in ip_black_list:
|
||||
should_block = True
|
||||
block_reason = f"IP地址 {request_ip} 在黑名单中"
|
||||
|
||||
# 2. 如果不在黑名单中,检查是否在演示模式下需要拦截
|
||||
elif demo_enable in ["true", "True"] and method != "GET":
|
||||
is_ip_whitelisted = request_ip in ip_white_list
|
||||
is_path_whitelisted = path in white_api_list_path
|
||||
|
||||
if not is_ip_whitelisted and not is_path_whitelisted:
|
||||
should_block = True
|
||||
block_reason = f"演示模式下拦截非GET请求,IP: {request_ip}, 路径: {path}"
|
||||
|
||||
if should_block:
|
||||
log.warning([
|
||||
f"会话ID: {session_id or '未认证'}",
|
||||
f"请求被拦截: {block_reason}",
|
||||
f"请求来源: {request_ip}",
|
||||
f"请求方法: {method}",
|
||||
f"请求路径: {path}",
|
||||
f"演示模式: {demo_enable}"
|
||||
])
|
||||
# 返回错误响应
|
||||
response = ErrorResponse(msg="演示环境,禁止操作")
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
# 用于追踪响应状态
|
||||
response_started = False
|
||||
response_status = 0
|
||||
|
||||
async def send_wrapper(message):
|
||||
nonlocal response_started, response_status
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
response_started = True
|
||||
response_status = message.get("status", 0)
|
||||
|
||||
# 计算处理时间并添加到响应头
|
||||
process_time = round(time.time() - start_time, 5)
|
||||
headers = list(message.get("headers", []))
|
||||
headers.append((b"x-process-time", str(process_time).encode()))
|
||||
message = {**message, "headers": headers}
|
||||
|
||||
elif message["type"] == "http.response.body":
|
||||
# 如果是最后一个body chunk,记录日志
|
||||
if not message.get("more_body", False):
|
||||
process_time = round(time.time() - start_time, 5)
|
||||
log.info(
|
||||
f"响应状态: {response_status}, "
|
||||
f"处理时间: {round(process_time * 1000, 3)}ms"
|
||||
)
|
||||
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
await self.app(scope, receive, send_wrapper)
|
||||
except Exception as e:
|
||||
log.error(f"中间件处理异常: {str(e)}")
|
||||
if not response_started:
|
||||
response = ErrorResponse(msg=f"系统异常,请联系管理员", data=str(e))
|
||||
await response(scope, receive, send)
|
||||
|
||||
|
||||
class CustomGZipMiddleware(GZipMiddleware):
|
||||
"""GZip压缩中间件"""
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
super().__init__(
|
||||
app,
|
||||
minimum_size=settings.GZIP_MIN_SIZE,
|
||||
compresslevel=settings.GZIP_COMPRESS_LEVEL
|
||||
)
|
||||
159
后端源码/yifan.action-ai.cn/api-bak/app/core/permission.py
Normal file
159
后端源码/yifan.action-ai.cn/api-bak/app/core/permission.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Any
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlalchemy import select
|
||||
from app.api.v1.module_system.user.model import UserModel
|
||||
from app.api.v1.module_system.dept.model import DeptModel
|
||||
from app.api.v1.module_system.auth.schema import AuthSchema
|
||||
from app.utils.common_util import get_child_id_map, get_child_recursion
|
||||
|
||||
|
||||
class Permission:
|
||||
"""
|
||||
为业务模型提供数据权限过滤功能
|
||||
"""
|
||||
|
||||
# 数据权限常量定义,提高代码可读性
|
||||
DATA_SCOPE_SELF = 1 # 仅本人数据
|
||||
DATA_SCOPE_DEPT = 2 # 本部门数据
|
||||
DATA_SCOPE_DEPT_AND_CHILD = 3 # 本部门及以下数据
|
||||
DATA_SCOPE_ALL = 4 # 全部数据
|
||||
DATA_SCOPE_CUSTOM = 5 # 自定义数据
|
||||
|
||||
def __init__(self, model: Any, auth: AuthSchema):
|
||||
"""
|
||||
初始化权限过滤器实例
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
model: 数据模型类
|
||||
current_user: 当前用户对象
|
||||
auth: 认证信息对象
|
||||
"""
|
||||
self.model = model
|
||||
self.auth = auth
|
||||
self.conditions: list[ColumnElement] = [] # 权限条件列表
|
||||
|
||||
async def filter_query(self, query: Any) -> Any:
|
||||
"""
|
||||
异步过滤查询对象
|
||||
|
||||
Args:
|
||||
query: SQLAlchemy查询对象
|
||||
|
||||
Returns:
|
||||
过滤后的查询对象
|
||||
"""
|
||||
condition = await self.__permission_condition()
|
||||
return query.where(condition) if condition is not None else query
|
||||
|
||||
async def __permission_condition(self) -> ColumnElement | None:
|
||||
"""
|
||||
应用数据范围权限隔离
|
||||
基于角色的五种数据权限范围过滤
|
||||
支持五种权限类型:
|
||||
1. 仅本人数据权限 - 只能查看自己创建的数据
|
||||
2. 本部门数据权限 - 只能查看同部门的数据
|
||||
3. 本部门及以下数据权限 - 可以查看本部门及所有子部门的数据
|
||||
4. 全部数据权限 - 可以查看所有数据
|
||||
5. 自定义数据权限 - 通过role_dept_relation表定义可访问的部门列表
|
||||
|
||||
权限处理原则:
|
||||
- 多个角色的权限取并集(最宽松原则)
|
||||
- 优先级:全部数据 > 部门权限(2、3、5的并集)> 仅本人
|
||||
- 构造权限过滤表达式,返回None表示不限制
|
||||
"""
|
||||
# 如果不需要检查数据权限,则不限制
|
||||
if not self.auth.user:
|
||||
return None
|
||||
|
||||
# 如果检查数据权限为False,则不限制
|
||||
if not self.auth.check_data_scope:
|
||||
return None
|
||||
|
||||
# 如果模型没有创建人created_id字段,则不限制
|
||||
if not hasattr(self.model, "created_id"):
|
||||
return None
|
||||
|
||||
# 超级管理员可以查看所有数据
|
||||
if self.auth.user.is_superuser:
|
||||
return None
|
||||
|
||||
# 如果用户没有角色,则只能查看自己的数据
|
||||
roles = getattr(self.auth.user, "roles", []) or []
|
||||
if not roles:
|
||||
created_id_attr = getattr(self.model, "created_id", None)
|
||||
if created_id_attr is not None:
|
||||
return created_id_attr == self.auth.user.id
|
||||
return None
|
||||
|
||||
# 获取用户所有角色的权限范围
|
||||
data_scopes = set()
|
||||
custom_dept_ids = set() # 自定义权限(data_scope=5)关联的部门ID集合
|
||||
|
||||
for role in roles:
|
||||
data_scopes.add(role.data_scope)
|
||||
# 收集自定义权限(data_scope=5)关联的部门ID
|
||||
if role.data_scope == self.DATA_SCOPE_CUSTOM and hasattr(role, 'depts') and role.depts:
|
||||
for dept in role.depts:
|
||||
custom_dept_ids.add(dept.id)
|
||||
|
||||
# 权限优先级处理:全部数据权限最高优先级
|
||||
if self.DATA_SCOPE_ALL in data_scopes:
|
||||
return None
|
||||
|
||||
# 收集所有可访问的部门ID(2、3、5权限的并集)
|
||||
accessible_dept_ids = set()
|
||||
user_dept_id = getattr(self.auth.user, "dept_id", None)
|
||||
|
||||
# 处理自定义数据权限(5)
|
||||
if self.DATA_SCOPE_CUSTOM in data_scopes:
|
||||
accessible_dept_ids.update(custom_dept_ids)
|
||||
|
||||
# 处理本部门数据权限(2)
|
||||
if self.DATA_SCOPE_DEPT in data_scopes:
|
||||
if user_dept_id is not None:
|
||||
accessible_dept_ids.add(user_dept_id)
|
||||
|
||||
# 处理本部门及以下数据权限(3)
|
||||
if self.DATA_SCOPE_DEPT_AND_CHILD in data_scopes:
|
||||
if user_dept_id is not None:
|
||||
try:
|
||||
# 查询所有部门并递归获取子部门
|
||||
dept_sql = select(DeptModel)
|
||||
dept_result = await self.auth.db.execute(dept_sql)
|
||||
dept_objs = dept_result.scalars().all()
|
||||
id_map = get_child_id_map(dept_objs)
|
||||
# get_child_recursion返回的结果已包含自身ID和所有子部门ID
|
||||
dept_with_children_ids = get_child_recursion(id=user_dept_id, id_map=id_map)
|
||||
accessible_dept_ids.update(dept_with_children_ids)
|
||||
except Exception:
|
||||
# 查询失败时降级到本部门
|
||||
accessible_dept_ids.add(user_dept_id)
|
||||
|
||||
# 如果有部门权限(2、3、5任一),使用部门过滤
|
||||
if accessible_dept_ids:
|
||||
creator_rel = getattr(self.model, "created_by", None)
|
||||
# 优先使用关系过滤(性能更好)
|
||||
if creator_rel is not None and hasattr(UserModel, 'dept_id'):
|
||||
return creator_rel.has(getattr(UserModel, 'dept_id').in_(list(accessible_dept_ids)))
|
||||
# 降级方案:如果模型没有created_by关系但有created_id,则只能查看自己的数据
|
||||
else:
|
||||
created_id_attr = getattr(self.model, "created_id", None)
|
||||
if created_id_attr is not None:
|
||||
return created_id_attr == self.auth.user.id
|
||||
return None
|
||||
|
||||
# 处理仅本人数据权限(1)
|
||||
if self.DATA_SCOPE_SELF in data_scopes:
|
||||
created_id_attr = getattr(self.model, "created_id", None)
|
||||
if created_id_attr is not None:
|
||||
return created_id_attr == self.auth.user.id
|
||||
return None
|
||||
|
||||
# 默认情况:如果用户有角色但没有任何有效权限范围,只能查看自己的数据
|
||||
created_id_attr = getattr(self.model, "created_id", None)
|
||||
if created_id_attr is not None:
|
||||
return created_id_attr == self.auth.user.id
|
||||
return None
|
||||
252
后端源码/yifan.action-ai.cn/api-bak/app/core/redis_crud.py
Normal file
252
后端源码/yifan.action-ai.cn/api-bak/app/core/redis_crud.py
Normal file
@@ -0,0 +1,252 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import pickle
|
||||
from typing import Any, Awaitable
|
||||
from redis.asyncio.client import Redis
|
||||
|
||||
from app.core.logger import log
|
||||
|
||||
|
||||
class RedisCURD:
|
||||
"""缓存工具类"""
|
||||
|
||||
def __init__(self, redis: Redis) -> None:
|
||||
"""初始化"""
|
||||
self.redis = redis
|
||||
|
||||
async def mget(self, keys: list) -> list:
|
||||
"""批量获取缓存
|
||||
|
||||
参数:
|
||||
- keys (list): 键名列表
|
||||
|
||||
返回:
|
||||
- list: 返回缓存值列表,如果获取失败则返回空列表
|
||||
"""
|
||||
try:
|
||||
data = await self.redis.mget(*[str(key) for key in keys])
|
||||
return data
|
||||
except Exception as e:
|
||||
log.error(f"批量获取缓存失败: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_keys(self, pattern: str = "*") -> list:
|
||||
"""获取缓存键名
|
||||
|
||||
参数:
|
||||
- pattern (str, optional): 匹配模式,默认值为"*"。
|
||||
|
||||
返回:
|
||||
- list: 返回匹配的缓存键名列表,如果获取失败则返回空列表
|
||||
"""
|
||||
try:
|
||||
keys = await self.redis.keys(f"{pattern}")
|
||||
return keys
|
||||
except Exception as e:
|
||||
log.error(f"获取缓存键名失败: {str(e)}")
|
||||
return []
|
||||
|
||||
|
||||
async def get(self, key: str) -> Any:
|
||||
"""获取缓存
|
||||
|
||||
参数:
|
||||
- key (str): 缓存键名
|
||||
|
||||
返回:
|
||||
- Any: 返回缓存值,如果缓存不存在则返回None
|
||||
"""
|
||||
try:
|
||||
data = await self.redis.get(f"{key}")
|
||||
|
||||
if data is None:
|
||||
return None
|
||||
|
||||
return data
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"获取缓存失败: {str(e)}")
|
||||
return None
|
||||
|
||||
async def set(self, key: str, value: Any, expire: int | None = None) -> bool:
|
||||
"""设置缓存
|
||||
|
||||
参数:
|
||||
- key (str): 缓存键名
|
||||
- value (Any): 缓存值
|
||||
- expire (int | None, optional): 过期时间,单位为秒,默认值为None。
|
||||
|
||||
返回:
|
||||
- bool: 如果设置缓存成功则返回True,否则返回False
|
||||
"""
|
||||
try:
|
||||
# 根据数据类型选择序列化方式
|
||||
if isinstance(value, (int, float, str)):
|
||||
data = str(value).encode('utf-8')
|
||||
else:
|
||||
try:
|
||||
data = pickle.dumps(value)
|
||||
except Exception as e:
|
||||
log.error(f"序列化数据失败: {str(e)}")
|
||||
return False
|
||||
|
||||
await self.redis.set(
|
||||
name = key,
|
||||
value = data,
|
||||
ex=expire
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"设置缓存失败: {str(e)}")
|
||||
return False
|
||||
|
||||
async def delete(self, *keys: str) -> bool:
|
||||
"""删除缓存
|
||||
|
||||
参数:
|
||||
- keys (str): 缓存键名
|
||||
|
||||
返回:
|
||||
- bool: 如果删除缓存成功则返回True,否则返回False
|
||||
"""
|
||||
try:
|
||||
await self.redis.delete(*keys)
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"删除缓存失败: {str(e)}")
|
||||
return False
|
||||
|
||||
async def clear(self, pattern: str = "*") -> bool:
|
||||
"""清空缓存
|
||||
|
||||
参数:
|
||||
- pattern (str, optional): 匹配模式,默认值为"*"。
|
||||
|
||||
返回:
|
||||
- bool: 如果清空缓存成功则返回True,否则返回False
|
||||
"""
|
||||
try:
|
||||
keys = await self.redis.keys(f"{pattern}")
|
||||
if keys:
|
||||
await self.redis.delete(*keys)
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"清空缓存失败: {str(e)}")
|
||||
return False
|
||||
|
||||
async def exists(self, key: str) -> bool:
|
||||
"""判断缓存是否存在
|
||||
|
||||
参数:
|
||||
- key (str): 缓存键名
|
||||
|
||||
返回:
|
||||
- bool: 如果缓存存在则返回True,否则返回False
|
||||
"""
|
||||
try:
|
||||
return await self.redis.exists(f"{key}")
|
||||
except Exception as e:
|
||||
log.error(f"判断缓存是否存在失败: {str(e)}")
|
||||
return False
|
||||
|
||||
async def ttl(self, key: str) -> int:
|
||||
"""获取缓存过期时间
|
||||
|
||||
参数:
|
||||
- key (str): 缓存键名
|
||||
|
||||
返回:
|
||||
- int: 返回缓存过期时间,单位为秒,如果缓存没有设置过期时间则返回-1
|
||||
"""
|
||||
try:
|
||||
return await self.redis.ttl(f"{key}")
|
||||
except Exception as e:
|
||||
log.error(f"获取缓存过期时间失败: {str(e)}")
|
||||
return -1
|
||||
|
||||
async def expire(self, key: str, expire: int) -> bool:
|
||||
"""设置缓存过期时间
|
||||
|
||||
参数:
|
||||
- key (str): 缓存键名
|
||||
- expire (int): 过期时间,单位为秒
|
||||
|
||||
返回:
|
||||
- bool: 如果设置缓存过期时间成功则返回True,否则返回False
|
||||
"""
|
||||
try:
|
||||
return await self.redis.expire(f"{key}", expire)
|
||||
except Exception as e:
|
||||
log.error(f"设置缓存过期时间失败: {str(e)}")
|
||||
return False
|
||||
|
||||
async def info(self) -> dict:
|
||||
"""获取缓存信息
|
||||
|
||||
返回:
|
||||
- dict: 返回缓存信息字典,如果获取失败则返回空字典
|
||||
"""
|
||||
try:
|
||||
return await self.redis.info()
|
||||
except Exception as e:
|
||||
log.error(f"获取缓存信息失败: {str(e)}")
|
||||
return {}
|
||||
|
||||
async def db_size(self) -> int:
|
||||
"""获取数据库大小
|
||||
|
||||
返回:
|
||||
- int: 返回数据库大小,如果获取失败则返回0
|
||||
"""
|
||||
try:
|
||||
return await self.redis.dbsize()
|
||||
except Exception as e:
|
||||
log.error(f"获取数据库大小失败: {str(e)}")
|
||||
return 0
|
||||
async def commandstats(self) -> dict:
|
||||
"""获取命令统计信息
|
||||
|
||||
返回:
|
||||
- dict: 返回命令统计信息字典,如果获取失败则返回空字典
|
||||
"""
|
||||
try:
|
||||
return await self.redis.info("commandstats")
|
||||
except Exception as e:
|
||||
log.error(f"获取命令统计信息失败: {str(e)}")
|
||||
return {}
|
||||
|
||||
async def hash_set(self, name: str, key: str, value: Any) -> bool:
|
||||
"""设置哈希缓存
|
||||
|
||||
参数:
|
||||
- name (str): 哈希缓存名称
|
||||
- key (str): 哈希缓存键名
|
||||
- value (Any): 哈希缓存值
|
||||
|
||||
返回:
|
||||
- bool: 如果设置哈希缓存成功则返回True,否则返回False
|
||||
"""
|
||||
try:
|
||||
self.redis.hset(name=name, key=key, value=value)
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"设置哈希缓存失败: {str(e)}")
|
||||
return False
|
||||
|
||||
async def hash_get(self, name: str, keys: list[str]) -> Awaitable[list[Any]] | list[Any]:
|
||||
"""获取哈希缓存
|
||||
|
||||
参数:
|
||||
- name (str): 哈希缓存名称
|
||||
- keys (list[str]): 哈希缓存键名列表
|
||||
|
||||
返回:
|
||||
- Awaitable[list[Any]] | list[Any]: 返回哈希缓存值列表,如果获取失败则返回空列表
|
||||
"""
|
||||
try:
|
||||
data = self.redis.hmget(name=name, keys=keys)
|
||||
return data
|
||||
except Exception as e:
|
||||
log.error(f"获取哈希缓存失败: {str(e)}")
|
||||
return []
|
||||
148
后端源码/yifan.action-ai.cn/api-bak/app/core/router_class.py
Normal file
148
后端源码/yifan.action-ai.cn/api-bak/app/core/router_class.py
Normal file
@@ -0,0 +1,148 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import time
|
||||
import json
|
||||
from typing import Any, Callable, Coroutine
|
||||
from fastapi import Request, Response
|
||||
from fastapi.routing import APIRoute
|
||||
from user_agents import parse
|
||||
|
||||
from app.core.database import async_db_session
|
||||
from app.config.setting import settings
|
||||
from app.utils.ip_local_util import IpLocalUtil
|
||||
from app.api.v1.module_system.auth.schema import AuthSchema
|
||||
from app.api.v1.module_system.log.schema import OperationLogCreateSchema
|
||||
from app.api.v1.module_system.log.service import OperationLogService
|
||||
|
||||
"""
|
||||
在 FastAPI 中,route_class 参数用于自定义路由的行为。
|
||||
通过设置 route_class,你可以定义一个自定义的路由类,从而在每个路由处理之前或之后执行特定的操作。
|
||||
这对于日志记录、权限验证、性能监控等场景非常有用。
|
||||
"""
|
||||
class OperationLogRoute(APIRoute):
|
||||
"""操作日志路由装饰器"""
|
||||
|
||||
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
|
||||
"""
|
||||
自定义路由处理程序,在每个路由处理之前或之后执行特定的操作。
|
||||
|
||||
参数:
|
||||
- request (Request): FastAPI请求对象。
|
||||
|
||||
返回:
|
||||
- Response: FastAPI响应对象。
|
||||
"""
|
||||
original_route_handler = super().get_route_handler()
|
||||
|
||||
async def custom_route_handler(request: Request) -> Response:
|
||||
"""
|
||||
自定义路由处理程序,在每个路由处理之前或之后执行特定的操作。
|
||||
|
||||
参数:
|
||||
- request (Request): FastAPI请求对象。
|
||||
描述:
|
||||
- 该方法在每个路由处理之前被调用,用于记录操作日志。
|
||||
返回:
|
||||
- Response: FastAPI响应对象。
|
||||
"""
|
||||
start_time = time.time()
|
||||
# 请求前的处理
|
||||
response: Response = await original_route_handler(request)
|
||||
|
||||
# 请求后的处理
|
||||
if not settings.OPERATION_LOG_RECORD:
|
||||
return response
|
||||
if request.method not in settings.OPERATION_RECORD_METHOD:
|
||||
return response
|
||||
route: APIRoute = request.scope.get("route", None)
|
||||
if route.name in settings.IGNORE_OPERATION_FUNCTION:
|
||||
return response
|
||||
|
||||
user_agent = parse(request.headers.get("user-agent"))
|
||||
payload = b"{}"
|
||||
req_content_type = request.headers.get("Content-Type", "")
|
||||
|
||||
if req_content_type and (
|
||||
req_content_type.startswith('multipart/form-data') or req_content_type.startswith('application/x-www-form-urlencoded')
|
||||
):
|
||||
form_data = await request.form()
|
||||
oper_param = '\n'.join([f'{k}: {v}' for k, v in form_data.items()])
|
||||
payload = oper_param # 直接使用字符串格式的参数
|
||||
else:
|
||||
payload = await request.body()
|
||||
path_params = request.path_params
|
||||
oper_param = {}
|
||||
|
||||
# 处理请求体数据
|
||||
if payload:
|
||||
try:
|
||||
oper_param['body'] = json.loads(payload.decode())
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
oper_param['body'] = payload.decode('utf-8', errors='ignore')
|
||||
|
||||
# 处理路径参数
|
||||
if path_params:
|
||||
oper_param['path_params'] = dict(path_params)
|
||||
|
||||
payload = json.dumps(oper_param, ensure_ascii=False)
|
||||
|
||||
# 日志表请求参数字段长度最大为2000,因此在此处判断长度
|
||||
if len(payload) > 2000:
|
||||
payload = '请求参数过长'
|
||||
|
||||
response_data = response.body if "application/json" in response.headers.get("Content-Type", "") else b"{}"
|
||||
process_time = f"{(time.time() - start_time):.2f}s"
|
||||
|
||||
# 获取当前用户ID,如果是登录接口则为空
|
||||
log_type = 1 # 1:登录日志 2:操作日志
|
||||
current_user_id = None
|
||||
|
||||
# 优化:只在操作日志场景下获取current_user_id
|
||||
if "user_id" in request.scope:
|
||||
current_user_id = request.scope.get("user_id")
|
||||
log_type = 2
|
||||
|
||||
request_ip = None
|
||||
x_forwarded_for = request.headers.get('X-Forwarded-For')
|
||||
if x_forwarded_for:
|
||||
# 取第一个 IP 地址,通常为客户端真实 IP
|
||||
request_ip = x_forwarded_for.split(',')[0].strip()
|
||||
else:
|
||||
# 若没有 X-Forwarded-For 头,则使用 request.client.host
|
||||
if request.client:
|
||||
request_ip = request.client.host
|
||||
|
||||
login_location = await IpLocalUtil.get_ip_location(request_ip) if request_ip else None
|
||||
|
||||
# 判断请求是否来自api文档
|
||||
referer = request.headers.get('referer')
|
||||
request_from_swagger = referer and referer.endswith('docs')
|
||||
request_from_redoc = referer and referer.endswith('redoc')
|
||||
|
||||
if request_from_swagger or request_from_redoc:
|
||||
# 如果请求来自api文档,则不记录日志
|
||||
pass
|
||||
else:
|
||||
async with async_db_session() as session:
|
||||
async with session.begin():
|
||||
auth = AuthSchema(db=session)
|
||||
await OperationLogService.create_log_service(data=OperationLogCreateSchema(
|
||||
type = log_type,
|
||||
request_path = request.url.path,
|
||||
request_method = request.method,
|
||||
request_payload = payload,
|
||||
request_ip = request_ip,
|
||||
login_location=login_location,
|
||||
request_os = user_agent.os.family,
|
||||
request_browser = user_agent.browser.family,
|
||||
response_code = response.status_code,
|
||||
response_json = response_data.decode() if isinstance(response_data, (bytes, bytearray)) else str(response_data),
|
||||
process_time = process_time,
|
||||
description = route.summary,
|
||||
created_id = current_user_id,
|
||||
updated_id = current_user_id,
|
||||
), auth = auth)
|
||||
|
||||
return response
|
||||
|
||||
return custom_route_handler
|
||||
157
后端源码/yifan.action-ai.cn/api-bak/app/core/security.py
Normal file
157
后端源码/yifan.action-ai.cn/api-bak/app/core/security.py
Normal file
@@ -0,0 +1,157 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import jwt
|
||||
from fastapi import Form, Request
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from fastapi.security.utils import get_authorization_scheme_param
|
||||
|
||||
from app.core.exceptions import CustomException
|
||||
from app.config.setting import settings
|
||||
from app.api.v1.module_system.auth.schema import JWTPayloadSchema
|
||||
|
||||
|
||||
class CustomOAuth2PasswordBearer(OAuth2PasswordBearer):
|
||||
"""自定义OAuth2认证类,继承自OAuth2PasswordBearer"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
token_url: str,
|
||||
scheme_name: str | None = None,
|
||||
scopes: dict[str, str] | None = None,
|
||||
description: str | None = None,
|
||||
auto_error: bool = True
|
||||
) -> None:
|
||||
super().__init__(
|
||||
tokenUrl=token_url,
|
||||
scheme_name=scheme_name,
|
||||
scopes=scopes,
|
||||
description=description,
|
||||
auto_error=auto_error
|
||||
)
|
||||
|
||||
async def __call__(self, request: Request) -> str | None:
|
||||
"""
|
||||
重写认证方法,校验token
|
||||
|
||||
参数:
|
||||
- request (Request): FastAPI请求对象。
|
||||
|
||||
返回:
|
||||
- str | None: 校验通过的token,如果校验失败则返回None。
|
||||
|
||||
异常:
|
||||
- CustomException: 认证失败时抛出,状态码为401。
|
||||
"""
|
||||
authorization = request.headers.get("Authorization")
|
||||
scheme, token = get_authorization_scheme_param(authorization)
|
||||
|
||||
if not authorization or scheme.lower() != settings.TOKEN_TYPE:
|
||||
if self.auto_error:
|
||||
raise CustomException(msg="认证失败,请登录后再试", code=10401, status_code=401)
|
||||
return None
|
||||
return token
|
||||
|
||||
|
||||
class CustomOAuth2PasswordRequestForm(OAuth2PasswordRequestForm):
|
||||
"""
|
||||
自定义登录表单,扩展验证码等字段
|
||||
|
||||
参数:
|
||||
- grant_type (str | None): 授权类型,默认值为None,正则表达式为'password'。
|
||||
- scope (str): 作用域,默认值为空字符串。
|
||||
- client_id (str | None): 客户端ID,默认值为None。
|
||||
- client_secret (str | None): 客户端密钥,默认值为None。
|
||||
- username (str): 用户名。
|
||||
- password (str): 密码。
|
||||
- captcha_key (str | None): 验证码键,默认值为空字符串。
|
||||
- captcha (str | None): 验证码值,默认值为空字符串。
|
||||
- login_type (str | None): 登录类型,默认值为"PC端",描述为"PC端 | 移动端"。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
grant_type: str | None = Form(default=None, pattern='password'),
|
||||
scope: str = Form(default=''),
|
||||
client_id: str | None = Form(default=None),
|
||||
client_secret: str | None = Form(default=None),
|
||||
username: str = Form(),
|
||||
password: str = Form(),
|
||||
captcha_key: str | None = Form(default=""),
|
||||
captcha: str | None = Form(default=""),
|
||||
login_type: str | None = Form(default="PC端", description="PC端 | 移动端")
|
||||
):
|
||||
super().__init__(
|
||||
grant_type=grant_type,
|
||||
scope=scope,
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
username=username,
|
||||
password=password,
|
||||
)
|
||||
self.captcha_key = captcha_key
|
||||
self.captcha = captcha
|
||||
self.login_type = login_type
|
||||
|
||||
|
||||
# OAuth2认证配置
|
||||
OAuth2Schema = CustomOAuth2PasswordBearer(
|
||||
token_url="system/auth/login",
|
||||
description="认证"
|
||||
)
|
||||
|
||||
|
||||
def create_access_token(payload: JWTPayloadSchema) -> str:
|
||||
"""
|
||||
生成JWT访问令牌
|
||||
|
||||
参数:
|
||||
- payload (JWTPayloadSchema): JWT有效载荷,包含用户信息等。
|
||||
|
||||
返回:
|
||||
- str: 生成的JWT访问令牌。
|
||||
"""
|
||||
payload_dict = payload.model_dump()
|
||||
return jwt.encode(
|
||||
payload=payload_dict,
|
||||
key=settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
|
||||
def decode_access_token(token: str) -> JWTPayloadSchema:
|
||||
"""
|
||||
解析JWT访问令牌
|
||||
|
||||
参数:
|
||||
- token (str): JWT访问令牌字符串。
|
||||
|
||||
返回:
|
||||
- JWTPayloadSchema: 解析后的JWT有效载荷,包含用户信息等。
|
||||
|
||||
异常:
|
||||
- CustomException: 解析失败时抛出,状态码为401。
|
||||
"""
|
||||
if not token:
|
||||
raise CustomException(msg="认证不存在,请重新登录", code=10401, status_code=401)
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
jwt=token,
|
||||
key=settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
|
||||
online_user_info = payload.get("sub")
|
||||
if not online_user_info:
|
||||
raise CustomException(msg="无效认证,请重新登录", code=10401, status_code=401)
|
||||
|
||||
return JWTPayloadSchema(**payload)
|
||||
|
||||
except (jwt.InvalidSignatureError, jwt.DecodeError):
|
||||
raise CustomException(msg="无效认证,请重新登录", code=10401, status_code=401)
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise CustomException(msg="认证已过期,请重新登录", code=10401, status_code=401)
|
||||
|
||||
except jwt.InvalidTokenError:
|
||||
raise CustomException(msg="token已失效,请重新登录", code=10401, status_code=401)
|
||||
55
后端源码/yifan.action-ai.cn/api-bak/app/core/serialize.py
Normal file
55
后端源码/yifan.action-ai.cn/api-bak/app/core/serialize.py
Normal file
@@ -0,0 +1,55 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, TypeVar, Type, Generic
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
||||
SchemaType = TypeVar("SchemaType", bound=BaseModel)
|
||||
|
||||
|
||||
class Serialize(Generic[ModelType, SchemaType]):
|
||||
"""
|
||||
序列化工具类,提供模型、Schema 和字典之间的转换功能
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def schema_to_model(cls,schema: Type[SchemaType], model: Type[ModelType]) -> ModelType:
|
||||
"""
|
||||
将 Pydantic Schema 转换为 SQLAlchemy 模型
|
||||
|
||||
参数:
|
||||
- schema (Type[SchemaType]): Pydantic Schema 实例。
|
||||
- model (Type[ModelType]): SQLAlchemy 模型类。
|
||||
|
||||
返回:
|
||||
- ModelType: SQLAlchemy 模型实例。
|
||||
|
||||
异常:
|
||||
- ValueError: 转换过程中可能抛出的异常。
|
||||
"""
|
||||
try:
|
||||
return model(**cls.model_to_dict(model, schema))
|
||||
except Exception as e:
|
||||
raise ValueError(f"序列化失败: {str(e)}")
|
||||
|
||||
@classmethod
|
||||
def model_to_dict(cls, model: Type[ModelType], schema: Type[SchemaType]) -> dict[str, Any]:
|
||||
"""
|
||||
将 SQLAlchemy 模型转换为 Pydantic Schema
|
||||
|
||||
参数:
|
||||
- model (Type[ModelType]): SQLAlchemy 模型实例。
|
||||
- schema (Type[SchemaType]): Pydantic Schema 类。
|
||||
|
||||
返回:
|
||||
- dict[str, Any]: 包含模型数据的字典。
|
||||
|
||||
异常:
|
||||
- ValueError: 转换过程中可能抛出的异常。
|
||||
"""
|
||||
try:
|
||||
return schema.model_validate(model).model_dump()
|
||||
except Exception as e:
|
||||
raise ValueError(f"反序列化失败: {str(e)}")
|
||||
|
||||
167
后端源码/yifan.action-ai.cn/api-bak/app/core/validator.py
Normal file
167
后端源码/yifan.action-ai.cn/api-bak/app/core/validator.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
from pydantic import AfterValidator, PlainSerializer, WithJsonSchema
|
||||
|
||||
from app.common.constant import RET
|
||||
from app.core.exceptions import CustomException
|
||||
|
||||
|
||||
# 自定义日期时间字符串类型
|
||||
DateTimeStr = Annotated[
|
||||
datetime,
|
||||
AfterValidator(lambda x: datetime_validator(x)),
|
||||
PlainSerializer(lambda x: x.strftime('%Y-%m-%d %H:%M:%S') if isinstance(x, datetime) else str(x), return_type=str),
|
||||
WithJsonSchema({'type': 'string'}, mode='serialization')
|
||||
]
|
||||
|
||||
# 自定义手机号类型
|
||||
Telephone = Annotated[
|
||||
str,
|
||||
AfterValidator(lambda x: mobile_validator(x)),
|
||||
PlainSerializer(lambda x: x, return_type=str),
|
||||
WithJsonSchema({'type': 'string'}, mode='serialization')
|
||||
]
|
||||
|
||||
# 自定义邮箱类型
|
||||
Email = Annotated[
|
||||
str,
|
||||
AfterValidator(lambda x: email_validator(x)),
|
||||
PlainSerializer(lambda x: x, return_type=str),
|
||||
WithJsonSchema({'type': 'string'}, mode='serialization')
|
||||
]
|
||||
|
||||
def datetime_validator(value: str | datetime) -> datetime:
|
||||
"""
|
||||
日期格式验证器。
|
||||
|
||||
参数:
|
||||
- value (str | datetime): 日期值。
|
||||
|
||||
返回:
|
||||
- datetime: 格式化后的日期。
|
||||
|
||||
异常:
|
||||
- CustomException: 日期格式无效时抛出。
|
||||
"""
|
||||
pattern = "%Y-%m-%d %H:%M:%S"
|
||||
try:
|
||||
if isinstance(value, str):
|
||||
return datetime.strptime(value, pattern)
|
||||
elif isinstance(value, datetime):
|
||||
return value
|
||||
except Exception:
|
||||
raise CustomException(code=RET.ERROR.code, msg="无效的日期格式")
|
||||
|
||||
|
||||
def email_validator(value: str) -> str:
|
||||
"""
|
||||
邮箱地址验证器。
|
||||
|
||||
参数:
|
||||
- value (str): 邮箱地址。
|
||||
|
||||
返回:
|
||||
- str: 验证后的邮箱地址。
|
||||
|
||||
异常:
|
||||
- CustomException: 邮箱格式无效时抛出。
|
||||
"""
|
||||
if not value:
|
||||
raise CustomException(code=RET.ERROR.code, msg="邮箱地址不能为空")
|
||||
|
||||
regex = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
|
||||
|
||||
if not re.match(regex, value):
|
||||
raise CustomException(code=RET.ERROR.code, msg="邮箱地址格式不正确")
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def mobile_validator(value: str | None) -> str | None:
|
||||
"""
|
||||
手机号验证器。
|
||||
|
||||
参数:
|
||||
- value (str | None): 手机号。
|
||||
|
||||
返回:
|
||||
- str | None: 验证后的手机号。
|
||||
|
||||
异常:
|
||||
- CustomException: 手机号格式无效时抛出。
|
||||
"""
|
||||
if not value:
|
||||
return value
|
||||
|
||||
if len(value) != 11 or not value.isdigit():
|
||||
raise CustomException(code=RET.ERROR.code, msg="手机号格式不正确")
|
||||
|
||||
regex = r'^1(3\d|4[4-9]|5[0-35-9]|6[67]|7[013-8]|8[0-9]|9[0-9])\d{8}$'
|
||||
|
||||
if not re.match(regex, value):
|
||||
raise CustomException(code=RET.ERROR.code, msg="手机号格式不正确")
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def menu_request_validator(data):
|
||||
"""
|
||||
菜单请求数据验证器。
|
||||
|
||||
参数:
|
||||
- data (Any): 请求数据。
|
||||
|
||||
返回:
|
||||
- Any: 验证后的请求数据。
|
||||
|
||||
异常:
|
||||
- CustomException: 请求数据无效时抛出。
|
||||
"""
|
||||
menu_types = {1: "目录", 2: "功能", 3: "权限", 4: "外链"}
|
||||
|
||||
if data.type not in menu_types:
|
||||
raise CustomException(code=RET.ERROR.code, msg=f"菜单类型必须为: {','.join(map(str, menu_types.keys()))}")
|
||||
|
||||
if data.type in [1, 2]:
|
||||
if not data.route_name:
|
||||
raise CustomException(code=RET.ERROR.code, msg="路由名称不能为空")
|
||||
if not data.route_path:
|
||||
raise CustomException(code=RET.ERROR.code, msg="路由路径不能为空")
|
||||
|
||||
if data.type == 2 and not data.component_path:
|
||||
raise CustomException(code=RET.ERROR.code, msg="组件路径不能为空")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def role_permission_request_validator(data):
|
||||
"""
|
||||
角色权限设置数据验证器。
|
||||
|
||||
参数:
|
||||
- data (Any): 请求数据。
|
||||
|
||||
返回:
|
||||
- Any: 验证后的请求数据。
|
||||
|
||||
异常:
|
||||
- CustomException: 请求数据无效时抛出。
|
||||
"""
|
||||
data_scopes = {
|
||||
1: "仅本人数据权限",
|
||||
2: "本部门数据权限",
|
||||
3: "本部门及以下数据权限",
|
||||
4: "全部数据权限",
|
||||
5: "自定义数据权限"
|
||||
}
|
||||
|
||||
if data.data_scope not in data_scopes:
|
||||
raise CustomException(code=RET.ERROR.code, msg=f"数据权限范围必须为: {','.join(map(str, data_scopes.keys()))}")
|
||||
|
||||
if not data.role_ids:
|
||||
raise CustomException(code=RET.ERROR.code, msg="角色不能为空")
|
||||
|
||||
return data
|
||||
Reference in New Issue
Block a user