upload project source code

This commit is contained in:
2026-04-30 18:49:43 +08:00
commit 9b394ba682
2277 changed files with 660945 additions and 0 deletions

View File

@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-

View File

@@ -0,0 +1,473 @@
# -*- coding: utf-8 -*-
from pydantic import BaseModel
from typing import TypeVar, Sequence, Generic, Dict, Any, List, Optional, Type, Union
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.orm import selectinload
from sqlalchemy.engine import Result
from sqlalchemy import asc, func, select, delete, Select, desc, update
from sqlalchemy import inspect as sa_inspect
from app.core.base_model import MappedBase
from app.core.exceptions import CustomException
from app.core.permission import Permission
from app.api.v1.module_system.auth.schema import AuthSchema
ModelType = TypeVar("ModelType", bound=MappedBase)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
OutSchemaType = TypeVar("OutSchemaType", bound=BaseModel)
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
"""基础数据层"""
def __init__(self, model: Type[ModelType], auth: AuthSchema) -> None:
"""
初始化CRUDBase类
参数:
- model (Type[ModelType]): 数据模型类。
- auth (AuthSchema): 认证信息。
返回:
- None
"""
self.model = model
self.auth = auth
async def get(self, preload: Optional[List[Union[str, Any]]] = None, **kwargs) -> Optional[ModelType]:
"""
根据条件获取单个对象
参数:
- preload (Optional[List[Union[str, Any]]]): 预加载关系支持关系名字符串或SQLAlchemy loader option
- **kwargs: 查询条件
返回:
- Optional[ModelType]: 对象实例
异常:
- CustomException: 查询失败时抛出异常
"""
try:
conditions = await self.__build_conditions(**kwargs)
sql = select(self.model).where(*conditions)
# 应用可配置的预加载选项
for opt in self.__loader_options(preload):
sql = sql.options(opt)
sql = await self.__filter_permissions(sql)
result: Result = await self.auth.db.execute(sql)
obj = result.scalars().first()
return obj
except Exception as e:
raise CustomException(msg=f"获取查询失败: {str(e)}")
async def list(self, search: Optional[Dict] = None, order_by: Optional[List[Dict[str, str]]] = None, preload: Optional[List[Union[str, Any]]] = None) -> Sequence[ModelType]:
"""
根据条件获取对象列表
参数:
- search (Optional[Dict]): 查询条件,格式为 {'id': value, 'name': value}
- order_by (Optional[List[Dict[str, str]]]): 排序字段,格式为 [{'id': 'asc'}, {'name': 'desc'}]
- preload (Optional[List[Union[str, Any]]]): 预加载关系支持关系名字符串或SQLAlchemy loader option
返回:
- Sequence[ModelType]: 对象列表
异常:
- CustomException: 查询失败时抛出异常
"""
try:
conditions = await self.__build_conditions(**search) if search else []
order = order_by or [{'id': 'asc'}]
sql = select(self.model).where(*conditions).order_by(*self.__order_by(order))
# 应用可配置的预加载选项
for opt in self.__loader_options(preload):
sql = sql.options(opt)
sql = await self.__filter_permissions(sql)
result: Result = await self.auth.db.execute(sql)
return result.scalars().all()
except Exception as e:
raise CustomException(msg=f"列表查询失败: {str(e)}")
async def tree_list(self, search: Optional[Dict] = None, order_by: Optional[List[Dict[str, str]]] = None, children_attr: str = 'children', preload: Optional[List[Union[str, Any]]] = None) -> Sequence[ModelType]:
"""
获取树形结构数据列表
参数:
- search (Optional[Dict]): 查询条件
- order_by (Optional[List[Dict[str, str]]]): 排序字段
- children_attr (str): 子节点属性名
- preload (Optional[List[Union[str, Any]]]): 额外预加载关系若为None则默认包含children_attr
返回:
- Sequence[ModelType]: 树形结构数据列表
异常:
- CustomException: 查询失败时抛出异常
"""
try:
conditions = await self.__build_conditions(**search) if search else []
order = order_by or [{'id': 'asc'}]
sql = select(self.model).where(*conditions).order_by(*self.__order_by(order))
# 处理预加载选项
final_preload = preload
# 如果没有提供preload且children_attr存在则添加到预加载选项中
if preload is None and children_attr and hasattr(self.model, children_attr):
# 获取模型默认预加载选项
model_defaults = getattr(self.model, "__loader_options__", [])
# 将children_attr添加到默认预加载选项中
final_preload = list(model_defaults) + [children_attr]
# 应用预加载选项
for opt in self.__loader_options(final_preload):
sql = sql.options(opt)
sql = await self.__filter_permissions(sql)
result: Result = await self.auth.db.execute(sql)
return result.scalars().all()
except Exception as e:
raise CustomException(msg=f"树形列表查询失败: {str(e)}")
async def page(self, offset: int, limit: int, order_by: List[Dict[str, str]], search: Dict, out_schema: Type[OutSchemaType], preload: Optional[List[Union[str, Any]]] = None) -> Dict:
"""
获取分页数据
参数:
- offset (int): 偏移量
- limit (int): 每页数量
- order_by (List[Dict[str, str]]): 排序字段
- search (Dict): 查询条件
- out_schema (Type[OutSchemaType]): 输出数据模型
- preload (Optional[List[Union[str, Any]]]): 预加载关系
返回:
- Dict: 分页数据
异常:
- CustomException: 查询失败时抛出异常
"""
try:
conditions = await self.__build_conditions(**search) if search else []
order = order_by or [{'id': 'asc'}]
sql = select(self.model).where(*conditions).order_by(*self.__order_by(order))
# 应用预加载选项
for opt in self.__loader_options(preload):
sql = sql.options(opt)
sql = await self.__filter_permissions(sql)
# 优化count查询使用主键计数而非全表扫描
mapper = sa_inspect(self.model)
pk_cols = list(getattr(mapper, "primary_key", []))
if pk_cols:
# 使用主键的第一列进行计数主键必定非NULL性能更好
count_sql = select(func.count(pk_cols[0])).select_from(self.model)
else:
# 降级方案使用count(*)
count_sql = select(func.count()).select_from(self.model)
if conditions:
count_sql = count_sql.where(*conditions)
count_sql = await self.__filter_permissions(count_sql)
total_result = await self.auth.db.execute(count_sql)
total = total_result.scalar() or 0
result: Result = await self.auth.db.execute(sql.offset(offset).limit(limit))
objs = result.scalars().all()
return {
"page_no": offset // limit + 1 if limit else 1,
"page_size": limit if limit else 10,
"total": total,
"has_next": offset + limit < total,
"items": [out_schema.model_validate(obj).model_dump() for obj in objs]
}
except Exception as e:
raise CustomException(msg=f"分页查询失败: {str(e)}")
async def create(self, data: Union[CreateSchemaType, Dict]) -> ModelType:
"""
创建新对象
参数:
- data (Union[CreateSchemaType, Dict]): 对象属性
返回:
- ModelType: 新创建的对象实例
异常:
- CustomException: 创建失败时抛出异常
"""
try:
obj_dict = data if isinstance(data, dict) else data.model_dump()
obj = self.model(**obj_dict)
# 设置字段值只检查一次current_user
if self.auth.user:
if hasattr(obj, "created_id"):
setattr(obj, "created_id", self.auth.user.id)
if hasattr(obj, "updated_id"):
setattr(obj, "updated_id", self.auth.user.id)
self.auth.db.add(obj)
await self.auth.db.flush()
await self.auth.db.refresh(obj)
return obj
except Exception as e:
raise CustomException(msg=f"创建失败: {str(e)}")
async def update(self, id: int, data: Union[UpdateSchemaType, Dict]) -> ModelType:
"""
更新对象
参数:
- id (int): 对象ID
- data (Union[UpdateSchemaType, Dict]): 更新的属性及值
返回:
- ModelType: 更新后的对象实例
异常:
- CustomException: 更新失败时抛出异常
"""
try:
obj_dict = data if isinstance(data, dict) else data.model_dump(exclude_unset=True, exclude={"id"})
obj = await self.get(id=id)
if not obj:
raise CustomException(msg="更新对象不存在")
# 设置字段值只检查一次current_user
if self.auth.user:
if hasattr(obj, "updated_id"):
setattr(obj, "updated_id", self.auth.user.id)
for key, value in obj_dict.items():
if hasattr(obj, key):
setattr(obj, key, value)
await self.auth.db.flush()
await self.auth.db.refresh(obj)
# 权限二次确认flush后再次验证对象仍在权限范围内
# 防止并发修改导致的权限逃逸如其他事务修改了created_id
verify_obj = await self.get(id=id)
if not verify_obj:
# 对象已被删除或权限已失效
raise CustomException(msg="更新失败,对象不存在或无权限访问")
return obj
except Exception as e:
raise CustomException(msg=f"更新失败: {str(e)}")
async def delete(self, ids: List[int]) -> None:
"""
删除对象
参数:
- ids (List[int]): 对象ID列表
异常:
- CustomException: 删除失败时抛出异常
"""
try:
# 先查询确认权限,避免删除无权限的数据
objs = await self.list(search={"id": ("in", ids)})
accessible_ids = [obj.id for obj in objs]
# 检查是否所有ID都有权限访问
inaccessible_count = len(ids) - len(accessible_ids)
if inaccessible_count > 0:
raise CustomException(msg=f"无权限删除{inaccessible_count}条数据")
if not accessible_ids:
return # 没有可删除的数据
mapper = sa_inspect(self.model)
pk_cols = list(getattr(mapper, "primary_key", []))
if not pk_cols:
raise CustomException(msg="模型缺少主键,无法删除")
if len(pk_cols) > 1:
raise CustomException(msg="暂不支持复合主键的批量删除")
# 只删除有权限的数据
sql = delete(self.model).where(pk_cols[0].in_(accessible_ids))
await self.auth.db.execute(sql)
await self.auth.db.flush()
except Exception as e:
raise CustomException(msg=f"删除失败: {str(e)}")
async def clear(self) -> None:
"""
清空对象表
异常:
- CustomException: 清空失败时抛出异常
"""
try:
sql = delete(self.model)
await self.auth.db.execute(sql)
await self.auth.db.flush()
except Exception as e:
raise CustomException(msg=f"清空失败: {str(e)}")
async def set(self, ids: List[int], **kwargs) -> None:
"""
批量更新对象
参数:
- ids (List[int]): 对象ID列表
- **kwargs: 更新的属性及值
异常:
- CustomException: 更新失败时抛出异常
"""
try:
# 先查询确认权限,避免更新无权限的数据
objs = await self.list(search={"id": ("in", ids)})
accessible_ids = [obj.id for obj in objs]
# 检查是否所有ID都有权限访问
inaccessible_count = len(ids) - len(accessible_ids)
if inaccessible_count > 0:
raise CustomException(msg=f"无权限更新{inaccessible_count}条数据")
if not accessible_ids:
return # 没有可更新的数据
mapper = sa_inspect(self.model)
pk_cols = list(getattr(mapper, "primary_key", []))
if not pk_cols:
raise CustomException(msg="模型缺少主键,无法更新")
if len(pk_cols) > 1:
raise CustomException(msg="暂不支持复合主键的批量更新")
# 只更新有权限的数据
sql = update(self.model).where(pk_cols[0].in_(accessible_ids)).values(**kwargs)
await self.auth.db.execute(sql)
await self.auth.db.flush()
except CustomException:
raise
except Exception as e:
raise CustomException(msg=f"批量更新失败: {str(e)}")
async def __filter_permissions(self, sql: Select) -> Select:
"""
过滤数据权限仅用于Select
"""
filter = Permission(
model=self.model,
auth=self.auth
)
return await filter.filter_query(sql)
async def __build_conditions(self, **kwargs) -> List[ColumnElement]:
"""
构建查询条件
参数:
- **kwargs: 查询参数
返回:
- List[ColumnElement]: SQL条件表达式列表
异常:
- CustomException: 查询参数不存在时抛出异常
"""
conditions = []
for key, value in kwargs.items():
if value is None or value == "":
continue
attr = getattr(self.model, key)
if isinstance(value, tuple):
seq, val = value
if seq == "None":
conditions.append(attr.is_(None))
elif seq == "not None":
conditions.append(attr.isnot(None))
elif seq == "date" and val:
conditions.append(func.date_format(attr, "%Y-%m-%d") == val)
elif seq == "month" and val:
conditions.append(func.date_format(attr, "%Y-%m") == val)
elif seq == "like" and val:
conditions.append(attr.like(f"%{val}%"))
elif seq == "in" and val:
conditions.append(attr.in_(val))
elif seq == "between" and isinstance(val, (list, tuple)) and len(val) == 2:
conditions.append(attr.between(val[0], val[1]))
elif seq == "!=" and val:
conditions.append(attr != val)
elif seq == ">" and val:
conditions.append(attr > val)
elif seq == ">=" and val:
conditions.append(attr >= val)
elif seq == "<" and val:
conditions.append(attr < val)
elif seq == "<=" and val:
conditions.append(attr <= val)
elif seq == "==" and val:
conditions.append(attr == val)
else:
conditions.append(attr == value)
return conditions
def __order_by(self, order_by: List[Dict[str, str]]) -> List[ColumnElement]:
"""
获取排序字段
参数:
- order_by (List[Dict[str, str]]): 排序字段列表,格式为 [{'id': 'asc'}, {'name': 'desc'}]
返回:
- List[ColumnElement]: 排序字段列表
异常:
- CustomException: 排序字段不存在时抛出异常
"""
columns = []
for order in order_by:
for field, direction in order.items():
column = getattr(self.model, field)
columns.append(desc(column) if direction.lower() == 'desc' else asc(column))
return columns
def __loader_options(self, preload: Optional[List[Union[str, Any]]] = None) -> List[Any]:
"""
构建预加载选项
参数:
- preload (Optional[List[Union[str, Any]]]): 预加载关系支持关系名字符串或SQLAlchemy loader option
返回:
- List[Any]: 预加载选项列表
"""
options = []
# 获取模型定义的默认加载选项
model_loader_options = getattr(self.model, '__loader_options__', [])
# 合并所有需要预加载的选项
all_preloads = set(model_loader_options)
if preload:
for opt in preload:
if isinstance(opt, str):
all_preloads.add(opt)
elif preload == []:
# 如果明确指定空列表,则不使用任何预加载
all_preloads = set()
# 处理所有预加载选项
for opt in all_preloads:
if isinstance(opt, str):
# 使用selectinload来避免在异步环境中的MissingGreenlet错误
if hasattr(self.model, opt):
options.append(selectinload(getattr(self.model, opt)))
else:
# 直接使用非字符串的加载选项
options.append(opt)
return options

View File

@@ -0,0 +1,118 @@
# -*- coding: utf-8 -*-
from typing import Optional
from datetime import datetime
from sqlalchemy import DateTime, String, Integer, Text, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, DeclarativeBase, declared_attr, relationship
from sqlalchemy.ext.asyncio import AsyncAttrs
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from app.api.v1.module_system.user.model import UserModel
from app.utils.common_util import uuid4_str
class MappedBase(AsyncAttrs, DeclarativeBase):
"""
声明式基类
`AsyncAttrs <https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#sqlalchemy.ext.asyncio.AsyncAttrs>`__
`DeclarativeBase <https://docs.sqlalchemy.org/en/20/orm/declarative_config.html>`__
`mapped_column() <https://docs.sqlalchemy.org/en/20/orm/mapping_api.html#sqlalchemy.orm.mapped_column>`__
兼容 SQLite、MySQL 和 PostgreSQL
"""
__abstract__: bool = True
class ModelMixin(MappedBase):
"""
模型混入类 - 提供通用字段和功能
基础模型混合类 Mixin: 一种面向对象编程概念, 使结构变得更加清晰
数据隔离设计原则:
==================
数据权限 (created_id/updated_id):
- 配合角色的data_scope字段实现精细化权限控制
- 1:仅本人
- 2:本部门
- 3:本部门及以下
- 4:全部数据
- 5:自定义
SQLAlchemy加载策略说明:
- select(默认): 延迟加载,访问时单独查询
- joined: 使用LEFT JOIN预加载
- selectin: 使用IN查询批量预加载(推荐用于一对多)
- subquery: 使用子查询预加载
- raise/raise_on_sql: 禁止加载
- noload: 不加载,返回None
- immediate: 立即加载
- write_only: 只写不读
- dynamic: 返回查询对象,支持进一步过滤
"""
__abstract__: bool = True
# 基础字段
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True, comment='主键ID')
uuid: Mapped[str] = mapped_column(String(64), default=uuid4_str, nullable=False, unique=True, comment='UUID全局唯一标识')
status: Mapped[str] = mapped_column(String(10), default='0', nullable=False, comment="是否启用(0:启用 1:禁用)")
description: Mapped[str | None] = mapped_column(Text, default=None, nullable=True, comment="备注/描述")
created_time: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, nullable=False, comment='创建时间')
updated_time: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False, comment='更新时间')
class UserMixin(MappedBase):
"""
用户审计字段 Mixin
用于记录数据的创建者和更新者
用于实现数据权限中的"仅本人数据权限"
"""
__abstract__: bool = True
created_id: Mapped[int | None] = mapped_column(
Integer,
ForeignKey('sys_user.id', ondelete="SET NULL", onupdate="CASCADE"),
default=None,
nullable=True,
index=True,
comment="创建人ID"
)
updated_id: Mapped[int | None] = mapped_column(
Integer,
ForeignKey('sys_user.id', ondelete="SET NULL", onupdate="CASCADE"),
default=None,
nullable=True,
index=True,
comment="更新人ID"
)
@declared_attr
def created_by(cls) -> Mapped[Optional["UserModel"]]:
"""
创建人关联关系(延迟加载,避免循环依赖)
"""
return relationship(
"UserModel",
lazy="selectin",
foreign_keys=lambda: cls.created_id,
uselist=False
)
@declared_attr
def updated_by(cls) -> Mapped[Optional["UserModel"]]:
"""
更新人关联关系(延迟加载,避免循环依赖)
"""
return relationship(
"UserModel",
lazy="selectin",
foreign_keys=lambda: cls.updated_id,
uselist=False
)

View File

@@ -0,0 +1,41 @@
# -*- coding: utf-8 -*-
from fastapi import Query
class PaginationQueryParam:
"""分页查询参数基类"""
def __init__(
self,
page_no: int = Query(default=1, description="当前页码", ge=1),
page_size: int = Query(default=10, description="每页数量", ge=1, le=100),
order_by: str | None = Query(default=None, description="排序字段,格式:field1,asc;field2,desc"),
) -> None:
"""
初始化分页查询参数。
参数:
- page_no (int | None): 当前页码,默认 None。
- page_size (int | None): 每页数量,默认 None最大 100。
- order_by (str | None): 排序字段,格式 'field,asc;field2,desc'
返回:
- None
"""
self.page_no = page_no
self.page_size = page_size
# 将字符串格式的order_by转换为服务层需要的List[Dict[str, str]]格式
if order_by:
try:
self.order_by = []
for item in order_by.split(';'):
if item.strip():
field, direction = item.split(',', 1)
self.order_by.append({field.strip(): direction.strip().lower()})
except ValueError:
# 如果解析失败,使用默认排序
self.order_by = [{'updated_time': 'desc'}]
else:
self.order_by = [{'updated_time': 'desc'}]

View File

@@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-
from pydantic import BaseModel, ConfigDict, Field
from app.core.validator import DateTimeStr
class UserInfoSchema(BaseModel):
"""用户信息模型"""
model_config = ConfigDict(from_attributes=True)
id: int | None = Field(default=None, description="用户ID")
name: str | None = Field(default=None, description="用户姓名")
username: str | None = Field(default=None, description="用户名")
class CommonSchema(BaseModel):
"""通用信息模型"""
model_config = ConfigDict(from_attributes=True)
id: int = Field(description="编号ID")
name: str = Field(description="名称")
class BaseSchema(BaseModel):
"""通用输出模型,包含基础字段和审计字段"""
model_config = ConfigDict(from_attributes=True, coerce_numbers_to_str=True)
id: int | None = Field(default=None, description="主键ID")
uuid: str | None = Field(default=None, description="UUID")
status: str = Field(default="0", description="状态")
description: str | None = Field(default=None, description="描述")
created_time: DateTimeStr | None = Field(default=None, description="创建时间")
updated_time: DateTimeStr | None = Field(default=None, description="更新时间")
class UserBySchema(BaseModel):
"""通用创建模型,包含基础字段和审计字段"""
model_config = ConfigDict(from_attributes=True)
created_id: int | None = Field(default=None, description="创建人ID")
created_by: UserInfoSchema | None = Field(default=None, description="创建人信息")
updated_id: int | None = Field(default=None, description="更新人ID")
updated_by: UserInfoSchema | None = Field(default=None, description="更新人信息")
class BatchSetAvailable(BaseModel):
"""批量设置可用状态的请求模型"""
ids: list[int] = Field(default_factory=list, description="ID列表")
status: str = Field(default="0", description="是否可用")
class UploadResponseSchema(BaseModel):
"""上传响应模型"""
model_config = ConfigDict(from_attributes=True)
file_path: str | None = Field(default=None, description='新文件映射路径')
file_name: str | None = Field(default=None, description='新文件名称')
origin_name: str | None = Field(default=None, description='原文件名称')
file_url: str | None = Field(default=None, description='新文件访问地址')
class DownloadFileSchema(BaseModel):
"""下载文件模型"""
file_path: str = Field(..., description='新文件映射路径')
file_name: str = Field(..., description='新文件名称')

View File

@@ -0,0 +1,126 @@
# -*- coding: utf-8 -*-
from redis.asyncio import Redis
from redis import exceptions
from fastapi import FastAPI
from sqlalchemy import create_engine, Engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession, AsyncEngine
from app.core.logger import log
from app.config.setting import settings
from app.core.exceptions import CustomException
def create_engine_and_session(
db_url: str = settings.DB_URI
) -> tuple[Engine, sessionmaker]:
"""
创建同步数据库引擎和会话工厂。
参数:
- db_url (str): 数据库连接URL,默认从配置中获取。
返回:
- tuple[Engine, sessionmaker]: 同步数据库引擎和会话工厂。
"""
try:
if not settings.SQL_DB_ENABLE:
raise CustomException(msg="请先开启数据库连接", data="请启用 app/config/setting.py: SQL_DB_ENABLE")
# 同步数据库引擎
engine: Engine = create_engine(
url=db_url,
echo=settings.DATABASE_ECHO,
pool_pre_ping=settings.POOL_PRE_PING,
pool_recycle=settings.POOL_RECYCLE,
)
except Exception as e:
log.error(f'❌ 数据库连接失败 {e}')
raise
else:
# 同步数据库会话工厂
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
return engine, SessionLocal
def create_async_engine_and_session(
db_url: str = settings.ASYNC_DB_URI
) -> tuple[AsyncEngine, async_sessionmaker[AsyncSession]]:
"""
获取异步数据库会话连接。
返回:
- tuple[AsyncEngine, async_sessionmaker[AsyncSession]]: 异步数据库引擎和会话工厂。
"""
try:
if not settings.SQL_DB_ENABLE:
raise CustomException(msg="请先开启数据库连接", data="请启用 app/config/setting.py: SQL_DB_ENABLE")
# 异步数据库引擎
async_engine: AsyncEngine = create_async_engine(
url=db_url,
echo=settings.DATABASE_ECHO,
echo_pool=settings.ECHO_POOL,
pool_pre_ping=settings.POOL_PRE_PING,
future=settings.FUTURE,
pool_recycle=settings.POOL_RECYCLE,
pool_size=settings.POOL_SIZE,
max_overflow=settings.MAX_OVERFLOW,
pool_timeout=settings.POOL_TIMEOUT,
pool_use_lifo=settings.POOL_USE_LIFO,
)
except Exception as e:
log.error(f'❌ 数据库连接失败 {e}')
raise
else:
# 异步数据库会话工厂
AsyncSessionLocal = async_sessionmaker(
bind=async_engine,
autocommit=settings.AUTOCOMMIT,
autoflush=settings.AUTOFETCH,
expire_on_commit=settings.EXPIRE_ON_COMMIT,
class_=AsyncSession
)
return async_engine, AsyncSessionLocal
engine, db_session = create_engine_and_session(settings.DB_URI)
async_engine, async_db_session = create_async_engine_and_session(settings.ASYNC_DB_URI)
async def redis_connect(app: FastAPI, status: str) -> Redis | None:
"""
创建或关闭Redis连接。
参数:
- app (FastAPI): FastAPI应用实例。
- status (bool): 连接状态,True为创建连接,False为关闭连接。
返回:
- Redis | None: Redis连接实例,如果连接失败则返回None。
"""
if not settings.REDIS_ENABLE:
raise CustomException(msg="请先开启Redis连接", data="请启用 app/core/config.py: REDIS_ENABLE")
if status:
try:
rd = await Redis.from_url(
url=settings.REDIS_URI,
encoding='utf-8',
decode_responses=True,
health_check_interval=20,
max_connections=settings.POOL_SIZE,
socket_timeout=settings.POOL_TIMEOUT
)
app.state.redis = rd
if await rd.ping():
log.info("✅️ Redis连接成功...")
return rd
except exceptions.AuthenticationError as e:
log.error(f"❌ 数据库 Redis 认证失败: {e}")
raise
except exceptions.TimeoutError as e:
log.error(f"❌ 数据库 Redis 连接超时: {e}")
raise
except exceptions.RedisError as e:
log.error(f"❌ 数据库 Redis 连接错误: {e}")
raise
else:
await app.state.redis.aclose()
log.info('✅️ Redis连接已关闭')

View File

@@ -0,0 +1,178 @@
# -*- coding: utf-8 -*-
import json
from redis.asyncio.client import Redis
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from typing import AsyncGenerator
from fastapi import Depends, Request
from fastapi import Depends
from app.common.enums import RedisInitKeyConfig
from app.core.exceptions import CustomException
from app.core.database import async_db_session
from app.core.redis_crud import RedisCURD
from app.core.security import OAuth2Schema, decode_access_token
from app.core.logger import log
from app.api.v1.module_system.user.model import UserModel
from app.api.v1.module_system.role.model import RoleModel
from app.api.v1.module_system.user.crud import UserCRUD
from app.api.v1.module_system.auth.schema import AuthSchema
async def db_getter() -> AsyncGenerator[AsyncSession, None]:
"""获取数据库会话连接
返回:
- AsyncSession: 数据库会话连接
"""
async with async_db_session() as session:
async with session.begin():
yield session
async def redis_getter(request: Request) -> Redis:
"""获取Redis连接
参数:
- request (Request): 请求对象
返回:
- Redis: Redis连接
"""
return request.app.state.redis
async def get_current_user(
request: Request,
db: AsyncSession = Depends(db_getter),
redis: Redis = Depends(redis_getter),
token: str = Depends(OAuth2Schema),
) -> AuthSchema:
"""获取当前用户
参数:
- request (Request): 请求对象
- db (AsyncSession): 数据库会话
- redis (Redis): Redis连接
- token (str): 访问令牌
返回:
- AuthSchema: 认证信息模型
"""
if not token:
raise CustomException(msg="认证已失效", code=10401, status_code=401)
# 处理Bearer token
if token.startswith('Bearer'):
token = token.split(' ')[1]
payload = decode_access_token(token)
if not payload or not hasattr(payload, 'is_refresh') or payload.is_refresh:
raise CustomException(msg="非法凭证", code=10401, status_code=401)
online_user_info = payload.sub
# 从Redis中获取用户信息
user_info = json.loads(online_user_info) # 确保是字典类型
session_id = user_info.get("session_id")
if not session_id:
raise CustomException(msg="认证已失效", code=10401, status_code=401)
# 检查用户是否在线
online_ok = await RedisCURD(redis).exists(key=f'{RedisInitKeyConfig.ACCESS_TOKEN.key}:{session_id}')
if not online_ok:
raise CustomException(msg="认证已失效", code=10401, status_code=401)
# 关闭数据权限过滤,避免当前用户查询被拦截
auth = AuthSchema(db=db, check_data_scope=False)
username = user_info.get("user_name")
if not username:
raise CustomException(msg="认证已失效", code=10401, status_code=401)
# 获取用户信息使用深层预加载确保RoleModel.creator被正确加载
user = await UserCRUD(auth).get_by_username_crud(
username=username,
preload=[
"dept",
selectinload(UserModel.roles),
"positions",
"created_by"
]
)
if not user:
raise CustomException(msg="用户不存在", code=10401, status_code=401)
if not user.status:
raise CustomException(msg="用户已被停用", code=10401, status_code=401)
# 设置请求上下文
request.state.user = user
request.state.user_id = user.id
request.state.user_username = user.username
request.scope["user_id"] = user.id
request.scope["user_username"] = user.username
# 过滤可用的角色和职位
if hasattr(user, 'roles'):
user.roles = [role for role in user.roles if role and role.status]
if hasattr(user, 'positions'):
user.positions = [pos for pos in user.positions if pos and pos.status]
auth.user = user
return auth
class AuthPermission:
"""权限验证类"""
def __init__(self, permissions: list[str] | None = None, check_data_scope: bool = True) -> None:
"""
初始化权限验证
参数:
- permissions (list[str] | None): 权限标识列表。
- check_data_scope (bool): 是否启用严格模式校验。
"""
self.permissions = permissions or []
self.check_data_scope = check_data_scope
async def __call__(self, auth: AuthSchema = Depends(get_current_user)) -> AuthSchema:
"""
调用权限验证
参数:
- auth (AuthSchema): 认证信息对象。
返回:
- AuthSchema: 认证信息对象。
"""
auth.check_data_scope = self.check_data_scope
# 超级管理员直接通过
if auth.user and auth.user.is_superuser:
return auth
# 无需验证权限
if not self.permissions:
return auth
# 超级管理员权限标识
if "*" in self.permissions or "*:*:*" in self.permissions:
return auth
# 检查用户是否有角色
if not auth.user or not auth.user.roles:
raise CustomException(msg="无权限操作", code=10403, status_code=403)
# 获取用户权限集合
user_permissions = {
menu.permission
for role in auth.user.roles
for menu in role.menus
if role.status and menu.permission and menu.status
}
# 权限验证 - 满足任一权限即可
if not any(perm in user_permissions for perm in self.permissions):
log.error(f"用户缺少任何所需的权限: {self.permissions}")
raise CustomException(msg="无权限操作", code=10403, status_code=403)
return auth

View File

@@ -0,0 +1,356 @@
# -*- coding: utf-8 -*-
"""
集中式路由发现与注册
约定:
- 仅扫描 `app.api.v1` 包内,顶级目录以 `module_` 开头的模块。
- 在各模块任意子目录下的 `controller.py` 中定义的 `APIRouter` 实例会自动被注册。
- 顶级目录 `module_xxx` 会映射为容器路由前缀 `/<xxx>`。
设计目标:
- 稳定、可预测:有序扫描与注册,确定性日志输出。
- 简洁、易维护:职责拆分成小函数,类型提示与清晰注释。
- 安全、可控:去重处理、异常分层记录、可配置的前缀映射与忽略规则。
- 灵活、可扩展:基于类的设计,支持配置自定义和实例化多套路由系统。
"""
from __future__ import annotations
import importlib
from enum import Enum
from pathlib import Path
from typing import Callable, Iterable, Any
from functools import wraps
from fastapi import APIRouter
from app.core.logger import log
def _log_error_handling(func: Callable) -> Callable:
"""错误处理装饰器,用于统一捕获和记录方法执行过程中的异常"""
@wraps(func)
def wrapper(self: 'DiscoverRouter', *args: Any, **kwargs: Any) -> Any:
method_name = func.__name__
try:
return func(self, *args, **kwargs)
except ModuleNotFoundError as e:
log.error(f"❌️ 模块未找到 [{method_name}]: {str(e)}")
raise
except ImportError as e:
log.error(f"❌️ 导入错误 [{method_name}]: {str(e)}")
raise
except AttributeError as e:
log.error(f"❌️ 属性错误 [{method_name}]: {str(e)}")
raise
except Exception as e:
log.error(f"❌️ 未知错误 [{method_name}]: {str(e)}")
# 在调试模式下打印完整堆栈信息
if getattr(self, 'debug', False):
import traceback
log.error(traceback.format_exc())
raise
return wrapper
class DiscoverRouter:
"""
路由自动发现与注册器
提供基于约定的路由自动发现与注册功能,支持自定义配置和灵活扩展。
"""
def __init__(
self,
module_prefix: str = "module_",
base_package: str = "app.api.v1",
prefix_map: dict[str, str] | None = None,
exclude_dirs: set[str] | None = None,
exclude_files: set[str] | None = None,
auto_discover: bool = True,
debug: bool = False
) -> None:
"""
初始化路由发现注册器
参数:
- module_prefix: 模块目录前缀,默认为 "module_"
- base_package: 基础包名,默认为 "app.api.v1"
- prefix_map: 前缀映射字典,用于自定义路由前缀
- exclude_dirs: 排除的目录集合
- exclude_files: 排除的文件集合
- auto_discover: 是否在初始化时自动执行发现和注册,默认为 True
- debug: 是否启用调试模式,在调试模式下会输出更详细的错误信息,默认为 False
"""
self.module_prefix = module_prefix
self.base_package = base_package
self.prefix_map = prefix_map or {}
self.exclude_dirs = exclude_dirs or set()
self.exclude_files = exclude_files or set()
self.debug = debug
self._router = APIRouter()
self._seen_router_ids: set[int] = set()
self._discovery_stats: dict[str, int] = {
"scanned_files": 0,
"imported_modules": 0,
"included_routers": 0,
"container_count": 0
}
# 自动执行发现和注册
if auto_discover:
self.discover_and_register()
@property
def router(self) -> APIRouter:
"""获取根路由实例"""
return self._router
@property
def discovery_stats(self) -> dict[str, int]:
"""获取路由发现统计信息"""
return self._discovery_stats.copy()
@_log_error_handling
def _get_base_dir_and_pkg(self) -> tuple[Path, str]:
"""定位基础包的文件系统路径与包名。
返回:
- (Path, str): (包的路径, 包名)
"""
base_pkg = importlib.import_module(self.base_package)
base_dir = Path(next(iter(base_pkg.__path__)))
log.info(f"📁 基础包路径: {base_dir}, 包名: {base_pkg.__name__}")
return base_dir, base_pkg.__name__
def _iter_controller_files(self, base_dir: Path) -> Iterable[Path]:
"""递归查找并返回所有 `controller.py` 文件,按路径排序保证确定性。"""
try:
files = sorted(base_dir.rglob("controller.py"), key=lambda p: p.as_posix())
log.info(f"🔍 发现 {len(files)} 个控制器文件")
return files
except PermissionError as e:
log.error(f"❌️ 权限错误: 无法访问目录 {base_dir}: {str(e)}")
return []
except Exception as e:
log.error(f"❌️ 查找控制器文件失败: {str(e)}")
return []
def _resolve_prefix(self, top_module: str) -> str | None:
"""将顶级模块目录名解析为容器前缀。"""
if top_module in self.exclude_dirs:
if self.debug:
log.warning(f"⚠️ 目录 {top_module} 被排除")
return None
if not top_module.startswith(self.module_prefix):
if self.debug:
log.warning(f"⚠️ 目录 {top_module} 不符合前缀约定 {self.module_prefix}")
return None
mapped = self.prefix_map.get(top_module)
if mapped:
log.info(f"🔄 模块 {top_module} 映射到前缀 {mapped}")
return mapped
prefix = f"/{top_module[len(self.module_prefix):]}"
if self.debug:
log.debug(f"📋 模块 {top_module} 使用默认前缀 {prefix}")
return prefix
@_log_error_handling
def _include_module_routers(self, mod: object, container: APIRouter) -> int:
"""将模块中的所有 `APIRouter` 实例包含到指定容器路由中。
返回:
- int: 新增注册的路由数量
"""
from fastapi import APIRouter as _APIRouter
added = 0
mod_name = getattr(mod, "__name__", "<unknown>")
router_count = 0
for attr_name in dir(mod):
attr = getattr(mod, attr_name, None)
if isinstance(attr, _APIRouter):
router_count += 1
rid = id(attr)
if rid in self._seen_router_ids:
log.warning(f"⚠️ 路由 {attr_name} 在模块 {mod_name} 中已注册,跳过重复注册")
continue
self._seen_router_ids.add(rid)
container.include_router(attr)
added += 1
log.info(f" 注册路由 {attr_name} 到容器")
if router_count == 0:
log.warning(f"⚠️ 模块 {mod_name} 中未发现 APIRouter 实例")
return added
@_log_error_handling
def discover_and_register(self) -> dict[str, int]:
"""
执行路由发现与注册
返回:
- dict[str, int]: 包含发现统计信息的字典
- scanned_files: 扫描的文件数量
- imported_modules: 导入的模块数量
- included_routers: 注册的路由数量
- container_count: 容器数量
"""
log.info("🚀 开始路由发现与注册...")
base_dir, base_pkg = self._get_base_dir_and_pkg()
containers: dict[str, APIRouter] = {}
container_counts: dict[str, int] = {}
scanned_files = 0
imported_modules = 0
included_routers = 0
try:
for file in self._iter_controller_files(base_dir):
rel_path = file.relative_to(base_dir).as_posix()
scanned_files += 1
if rel_path in self.exclude_files:
log.warning(f"⚠️ 文件 {rel_path} 被排除")
continue
parts = file.relative_to(base_dir).parts
if len(parts) < 2:
log.warning(f"⚠️ 文件路径不完整: {rel_path},跳过")
continue
top_module = parts[0]
prefix = self._resolve_prefix(top_module)
if not prefix:
continue
# 拼接模块导入路径
mod_path = ".".join((base_pkg,) + tuple(parts[:-1]) + ("controller",))
try:
mod = importlib.import_module(mod_path)
imported_modules += 1
log.info(f"📥 导入模块: {mod_path}")
except ModuleNotFoundError as e:
log.error(f"❌️ 未找到控制器模块: {mod_path} -> {str(e)}")
continue
except ImportError as e:
log.error(f"❌️ 导入控制器失败: {mod_path} -> {str(e)}")
continue
container = containers.setdefault(prefix, APIRouter(prefix=prefix))
try:
added = self._include_module_routers(mod, container)
included_routers += added
container_counts[prefix] = container_counts.get(prefix, 0) + added
except Exception as e:
log.error(f"❌️ 注册控制器路由失败: {mod_path} -> {str(e)}")
# 将容器路由按前缀名称排序后注册到根路由,保证顺序稳定
for prefix in sorted(containers.keys()):
container = containers[prefix]
rid = id(container)
if rid in self._seen_router_ids:
continue
self._seen_router_ids.add(rid)
self._router.include_router(container)
# 更丰富的注册日志(含路由数量)
count = container_counts.get(prefix, 0)
log.info(f"✅️ 已注册模块容器: {prefix} (路由数: {count})")
# 更新统计信息
stats = {
"scanned_files": scanned_files,
"imported_modules": imported_modules,
"included_routers": included_routers,
"container_count": len(containers)
}
self._discovery_stats = stats
# 生成总结日志
log.info(
(
f"✅️ 路由发现完成: 扫描文件 {scanned_files}, "
f"导入模块 {imported_modules}, 注册路由 {included_routers}, "
f"容器 {len(containers)}"
)
)
return stats
except Exception as e:
log.error(f"❌️ 路由发现与注册过程失败: {str(e)}")
# 确保返回统计信息,即使过程中出错
return self._discovery_stats
def set_debug(self, debug: bool) -> 'DiscoverRouter':
"""设置调试模式
参数:
- debug: 是否开启调试模式
返回:
- self: 支持链式调用
"""
self.debug = debug
log_level = "DEBUG" if debug else "INFO"
log.info(f"⚙️ 调试模式已{'开启' if debug else '关闭'},日志级别: {log_level}")
return self
def add_exclude_dir(self, dir_name: str) -> 'DiscoverRouter':
"""添加排除的目录
参数:
- dir_name: 要排除的目录名称
返回:
- self: 支持链式调用
"""
self.exclude_dirs.add(dir_name)
log.info(f"📝 添加排除目录: {dir_name}")
return self
def add_prefix_map(self, module_name: str, prefix: str) -> 'DiscoverRouter':
"""添加前缀映射
参数:
- module_name: 模块名称
- prefix: 对应的路由前缀
返回:
- self: 支持链式调用
"""
self.prefix_map[module_name] = prefix
log.info(f"📝 添加前缀映射: {module_name} -> {prefix}")
return self
@_log_error_handling
def register_router(self, router: APIRouter, tags: list[str | Enum] | None = None) -> None:
"""手动注册一个路由实例
参数:
- router: 要注册的 APIRouter 实例
- tags: 路由标签,用于 API 文档分组
"""
rid = id(router)
if rid not in self._seen_router_ids:
self._seen_router_ids.add(rid)
self._router.include_router(router, tags=tags)
log.info(f"📌 手动注册路由,标签: {tags}")
else:
log.warning(f"⚠️ 路由已存在,跳过重复注册")
# 创建默认实例并执行自动发现注册
_discoverer = DiscoverRouter()
# 保持向后兼容,导出原始的 router 变量
router = _discoverer.router
# 导出 DiscoverRouter 类供外部使用
__all__ = ["DiscoverRouter", "router"]
# 执行自动发现注册(已由 DiscoverRouter 实例内部处理)

View File

@@ -0,0 +1,211 @@
# -*- coding: utf-8 -*-
from typing import Any
from fastapi import FastAPI, Request, status
from fastapi.exceptions import RequestValidationError, ResponseValidationError
from pydantic_validation_decorator import FieldValidationError
from starlette.responses import JSONResponse
from starlette.exceptions import HTTPException
from sqlalchemy.exc import SQLAlchemyError
from app.common.constant import RET
from app.common.response import ErrorResponse
from app.core.logger import log
class CustomException(Exception):
"""
自定义异常基类
"""
def __init__(
self,
msg: str = RET.EXCEPTION.msg,
code: int = RET.EXCEPTION.code,
status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR,
data: Any | None = None,
success: bool = False
) -> None:
"""
初始化异常对象。
参数:
- msg (str): 错误消息。
- code (int): 业务状态码。
- status_code (int): HTTP 状态码。
- data (Any | None): 附加数据。
- success (bool): 是否成功标记,默认 False。
返回:
- None
"""
super().__init__(msg) # 调用父类初始化方法
self.status_code = status_code
self.code = code
self.msg = msg
self.data = data
self.success = success
def __str__(self) -> str:
"""返回异常消息
返回:
- str: 异常消息
"""
return self.msg
def handle_exception(app: FastAPI):
"""
注册全局异常处理器。
参数:
- app (FastAPI): 应用实例。
返回:
- None
"""
@app.exception_handler(CustomException)
async def CustomExceptionHandler(request: Request, exc: CustomException) -> JSONResponse:
"""
自定义异常处理器
参数:
- request (Request): 请求对象。
- exc (CustomException): 自定义异常实例。
返回:
- JSONResponse: 包含错误信息的 JSON 响应。
"""
log.error(f"[自定义异常] {request.method} {request.url.path} | 错误码: {exc.code} | 错误信息: {exc.msg} | 详情: {exc.data}")
return ErrorResponse(msg=exc.msg, code=exc.code, status_code=exc.status_code, data=exc.data)
@app.exception_handler(HTTPException)
async def HttpExceptionHandler(request: Request, exc: HTTPException) -> JSONResponse:
"""
HTTP异常处理器
参数:
- request (Request): 请求对象。
- exc (HTTPException): HTTP异常实例。
返回:
- JSONResponse: 包含错误信息的 JSON 响应。
"""
log.error(f"[HTTP异常] {request.method} {request.url.path} | 状态码: {exc.status_code} | 错误信息: {exc.detail}")
return ErrorResponse(msg=exc.detail, status_code=exc.status_code)
@app.exception_handler(RequestValidationError)
async def ValidationExceptionHandler(request: Request, exc: RequestValidationError) -> JSONResponse:
"""
请求参数验证异常处理器
参数:
- request (Request): 请求对象。
- exc (RequestValidationError): 请求参数验证异常实例。
返回:
- JSONResponse: 包含错误信息的 JSON 响应。
"""
error_mapping = {
"Field required": "请求失败,缺少必填项!",
"value is not a valid list": "类型错误,提交参数应该为列表!",
"value is not a valid int": "类型错误,提交参数应该为整数!",
"value could not be parsed to a boolean": "类型错误,提交参数应该为布尔值!",
"Input should be a valid list": "类型错误,输入应该是一个有效的列表!"
}
raw_msg = exc.errors()[0].get('msg')
msg = error_mapping.get(raw_msg, raw_msg)
# 去掉Pydantic默认的前缀“Value error”, 仅保留具体提示内容
if isinstance(msg, str) and msg.startswith("Value error"):
if "," in msg:
msg = msg.split(",", 1)[1].strip()
else:
msg = msg.replace("Value error", "").strip()
log.error(f"[参数验证异常] {request.method} {request.url.path} | 错误信息: {msg} | 原始错误: {exc.errors()}")
# 如果是bytes类型(如文件上传)不返回原始数据避免JSON序列化失败
response_data = None if isinstance(exc.body, bytes) else exc.body
return ErrorResponse(msg=str(msg), status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, data=response_data)
@app.exception_handler(ResponseValidationError)
async def ResponseValidationHandle(request: Request, exc: ResponseValidationError) -> JSONResponse:
"""
响应参数验证异常处理器
参数:
- request (Request): 请求对象。
- exc (ResponseValidationError): 响应参数验证异常实例。
返回:
- JSONResponse: 包含错误信息的 JSON 响应。
"""
log.error(f"[响应验证异常] {request.method} {request.url.path} | 错误信息: 响应数据格式错误 | 详情: {exc.errors()}")
# 如果是bytes类型(如文件上传)不返回原始数据避免JSON序列化失败
response_data = None if isinstance(exc.body, bytes) else exc.body
return ErrorResponse(msg="服务器响应格式错误", status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, data=response_data)
@app.exception_handler(SQLAlchemyError)
async def SQLAlchemyExceptionHandler(request: Request, exc: SQLAlchemyError) -> JSONResponse:
"""
数据库异常处理器
参数:
- request (Request): 请求对象。
- exc (SQLAlchemyError): 数据库异常实例。
返回:
- JSONResponse: 包含错误信息的 JSON 响应。
"""
error_msg = '数据库操作失败'
exc_type = type(exc).__name__
# 对于生产环境,返回通用错误消息
log.error(f"[数据库异常] {request.method} {request.url.path} | 错误类型: {exc_type} | 错误详情: {str(exc)}")
return ErrorResponse(msg=f'{error_msg}: {exc_type}', status_code=status.HTTP_400_BAD_REQUEST, data=str(exc))
@app.exception_handler(ValueError)
async def ValueExceptionHandler(request: Request, exc: ValueError) -> JSONResponse:
"""
值异常处理器
参数:
- request (Request): 请求对象。
- exc (ValueError): 值异常实例。
返回:
- JSONResponse: 包含错误信息的 JSON 响应。
"""
log.exception(f"[值异常] {request.method} {request.url.path} | 错误信息: {str(exc)}")
return ErrorResponse(msg=str(exc), status_code=status.HTTP_400_BAD_REQUEST)
@app.exception_handler(FieldValidationError)
async def FieldValidationExceptionHandler(request: Request, exc: FieldValidationError) -> JSONResponse:
"""
字段验证异常处理器
参数:
- request (Request): 请求对象。
- exc (FieldValidationError): 字段验证异常实例。
返回:
- JSONResponse: 包含错误信息的 JSON 响应。
"""
log.error(f"[字段验证异常] {request.method} {request.url.path} | 错误信息: {exc.message}")
return ErrorResponse(msg=exc.message, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
@app.exception_handler(Exception)
async def AllExceptionHandler(request: Request, exc: Exception) -> JSONResponse:
"""
全局异常处理器
参数:
- request (Request): 请求对象。
- exc (Exception): 异常实例。
返回:
- JSONResponse: 包含错误信息的 JSON 响应。
"""
exc_type = type(exc).__name__
log.error(f"[未捕获异常] {request.method} {request.url.path} | 错误类型: {exc_type} | 错误详情: {str(exc)}")
# 对于未捕获的异常,返回通用错误信息
return ErrorResponse(msg='服务器内部错误', status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, data=None)

View File

@@ -0,0 +1,164 @@
# -*- coding: utf-8 -*-
import logging
import sys
import atexit
from typing_extensions import override
from loguru import logger
from app.config.path_conf import LOG_DIR
from app.config.setting import settings
# 全局变量记录日志处理器ID
_logger_handlers = []
class InterceptHandler(logging.Handler):
"""
日志拦截处理器:将所有 Python 标准日志重定向到 Loguru
工作原理:
1. 继承自 logging.Handler
2. 重写 emit 方法处理日志记录
3. 将标准库日志转换为 Loguru 格式
"""
@override
def emit(self, record: logging.LogRecord) -> None:
# 尝试获取日志级别名称
try:
level = logger.level(record.levelname).name
except ValueError:
level = record.levelno
# 获取调用帧信息增加None检查
frame, depth = logging.currentframe(), 2
while frame and frame.f_code.co_filename == logging.__file__:
frame = frame.f_back
depth += 1
# 使用 Loguru 记录日志
logger.opt(depth=depth, exception=record.exc_info).log(
level,
record.getMessage()
)
def cleanup_logging():
"""
清理日志资源
在程序退出时调用,确保所有日志处理器被正确关闭
"""
global _logger_handlers
for handler_id in _logger_handlers:
try:
logger.remove(handler_id)
except Exception:
pass
_logger_handlers.clear()
def setup_logging():
"""
配置日志系统
功能:
1. 控制台彩色输出
2. 文件日志轮转
3. 错误日志单独存储
4. 智能异步策略:开发环境同步(避免reload资源泄漏),生产环境异步(高性能)
"""
global _logger_handlers
if _logger_handlers:
return
if hasattr(sys.stdout, "reconfigure"):
sys.stdout.reconfigure(encoding="utf-8")
if hasattr(sys.stderr, "reconfigure"):
sys.stderr.reconfigure(encoding="utf-8")
# 添加上下文信息
_ = logger.configure(extra={"app_name": "FastapiAdmin"})
# 步骤1移除默认处理器
logger.remove()
# 步骤2定义日志格式
log_format = (
# 时间信息
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
# 日志级别,居中对齐
"<level>{level: <8}</level> | "
# 文件、函数和行号
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
# 日志消息
"<level>{message}</level>"
)
# 智能选择异步策略:开发环境禁用异步(避免reload时资源泄漏),生产环境启用异步(提升性能)
use_async = not settings.DEBUG
# 步骤3配置控制台输出
handler_id = logger.add(
sys.stdout,
format=log_format,
level="DEBUG" if settings.DEBUG else "INFO",
enqueue=use_async, # 开发同步,生产异步
backtrace=True, # 显示完整的异常回溯
diagnose=True, # 显示变量值等诊断信息
colorize=True # 启用彩色输出
)
_logger_handlers.append(handler_id)
# 步骤4创建日志目录
log_dir = LOG_DIR
# 确保日志目录存在,如果不存在则创建
log_dir.mkdir(parents=True, exist_ok=True)
# 步骤5配置常规日志文件
handler_id = logger.add(
str(log_dir / "info.log"),
format=log_format,
level="INFO",
rotation="100 MB", # 按文件大小轮转,避免 Windows 文件锁问题
retention=30, # 日志保留天数,超过此天数的日志文件将被自动清理
compression="gz",
encoding="utf-8",
enqueue=use_async, # 开发同步,生产异步
catch=True # 捕获日志处理过程中的异常,避免程序崩溃
)
_logger_handlers.append(handler_id)
# 步骤6配置错误日志文件
handler_id = logger.add(
str(log_dir / "error.log"),
format=log_format,
level="ERROR",
rotation="100 MB", # 按文件大小轮转,避免 Windows 文件锁问题
retention=30, # 日志保留天数,超过此天数的日志文件将被自动清理
compression="gz",
encoding="utf-8",
enqueue=use_async, # 开发同步,生产异步
backtrace=True,
diagnose=True,
catch=True # 捕获日志处理过程中的异常,避免程序崩溃
)
_logger_handlers.append(handler_id)
# 步骤7配置标准库日志
logging.basicConfig(handlers=[InterceptHandler()], level="DEBUG" if settings.DEBUG else "INFO", force=True)
logger_name_list = [name for name in logging.root.manager.loggerDict]
# 步骤8配置第三方库日志
for logger_name in logger_name_list:
_logger = logging.getLogger(logger_name)
_logger.handlers = [InterceptHandler()]
_logger.propagate = False
# 注册退出清理函数
atexit.register(cleanup_logging)
setup_logging()
log = logger

View File

@@ -0,0 +1,213 @@
# -*- coding: utf-8 -*-
import json
import time
from starlette.middleware.cors import CORSMiddleware
from starlette.types import ASGIApp
from starlette.requests import Request
from starlette.middleware.gzip import GZipMiddleware
from starlette.responses import Response
from starlette.middleware.base import BaseHTTPMiddleware
from app.common.response import ErrorResponse
from app.config.setting import settings
from app.core.logger import log
from app.core.exceptions import CustomException
from app.core.security import decode_access_token
from app.api.v1.module_system.params.service import ParamsService
class CustomCORSMiddleware(CORSMiddleware):
"""CORS跨域中间件"""
def __init__(self, app: ASGIApp) -> None:
super().__init__(
app,
allow_origins=settings.ALLOW_ORIGINS,
allow_methods=settings.ALLOW_METHODS,
allow_headers=settings.ALLOW_HEADERS,
allow_credentials=settings.ALLOW_CREDENTIALS,
expose_headers=settings.CORS_EXPOSE_HEADERS,
)
class RequestLogMiddleware:
"""
记录请求日志中间件
注意:使用纯 ASGI 中间件实现,避免 BaseHTTPMiddleware 缓冲流式响应的问题。
BaseHTTPMiddleware 会等待整个响应体完成后才返回,这会破坏流式响应的功能。
"""
def __init__(self, app: ASGIApp) -> None:
self.app = app
@staticmethod
def _extract_session_id_from_headers(headers: list) -> str | None:
"""
从请求头中提取session_id
参数:
- headers (list): ASGI 格式的请求头列表
返回:
- str | None: 会话ID如果无法提取则返回None
"""
try:
authorization = None
for key, value in headers:
if key == b'authorization':
authorization = value.decode('utf-8')
break
if not authorization:
return None
# 处理Bearer token
token = authorization.replace('Bearer ', '').strip()
# 解码token
payload = decode_access_token(token)
if not payload or not hasattr(payload, 'sub'):
return None
# 从payload中提取session_id
user_info = json.loads(payload.sub)
return user_info.get("session_id")
except Exception:
return None
@staticmethod
def _get_header_value(headers: list, key: bytes) -> str | None:
"""从头列表中获取指定的头值"""
for k, v in headers:
if k.lower() == key.lower():
return v.decode('utf-8')
return None
async def __call__(self, scope, receive, send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
start_time = time.time()
request = Request(scope, receive)
# 尝试提取session_id
session_id = self._extract_session_id_from_headers(scope.get('headers', []))
# 组装请求日志字段
client_host = scope.get('client', ['未知'])[0] if scope.get('client') else '未知'
log_fields = [
f"请求来源: {client_host}",
f"请求方法: {scope.get('method', 'UNKNOWN')}",
f"请求路径: {scope.get('path', '/')}",
]
log.info(log_fields)
# 获取请求路径
path = scope.get("path")
# 尝试获取客户端真实IP
headers = scope.get('headers', [])
x_forwarded_for = self._get_header_value(headers, b'x-forwarded-for')
if x_forwarded_for:
request_ip = x_forwarded_for.split(',')[0].strip()
else:
request_ip = client_host
# 检查是否启用演示模式
demo_enable = False
ip_white_list = []
white_api_list_path = []
ip_black_list = []
try:
# 从应用实例获取Redis连接
redis = request.app.state.redis
if redis:
# 使用ParamsService获取系统配置
system_config = await ParamsService.get_system_config_for_middleware(redis)
demo_enable = system_config["demo_enable"]
ip_white_list = system_config["ip_white_list"]
white_api_list_path = system_config["white_api_list_path"]
ip_black_list = system_config["ip_black_list"]
except Exception as e:
log.error(f"获取系统配置失败: {e}")
# 检查是否需要拦截请求
should_block = False
block_reason = ""
method = scope.get('method', '')
# 1. 首先检查IP是否在黑名单中
if request_ip and request_ip in ip_black_list:
should_block = True
block_reason = f"IP地址 {request_ip} 在黑名单中"
# 2. 如果不在黑名单中,检查是否在演示模式下需要拦截
elif demo_enable in ["true", "True"] and method != "GET":
is_ip_whitelisted = request_ip in ip_white_list
is_path_whitelisted = path in white_api_list_path
if not is_ip_whitelisted and not is_path_whitelisted:
should_block = True
block_reason = f"演示模式下拦截非GET请求IP: {request_ip}, 路径: {path}"
if should_block:
log.warning([
f"会话ID: {session_id or '未认证'}",
f"请求被拦截: {block_reason}",
f"请求来源: {request_ip}",
f"请求方法: {method}",
f"请求路径: {path}",
f"演示模式: {demo_enable}"
])
# 返回错误响应
response = ErrorResponse(msg="演示环境,禁止操作")
await response(scope, receive, send)
return
# 用于追踪响应状态
response_started = False
response_status = 0
async def send_wrapper(message):
nonlocal response_started, response_status
if message["type"] == "http.response.start":
response_started = True
response_status = message.get("status", 0)
# 计算处理时间并添加到响应头
process_time = round(time.time() - start_time, 5)
headers = list(message.get("headers", []))
headers.append((b"x-process-time", str(process_time).encode()))
message = {**message, "headers": headers}
elif message["type"] == "http.response.body":
# 如果是最后一个body chunk记录日志
if not message.get("more_body", False):
process_time = round(time.time() - start_time, 5)
log.info(
f"响应状态: {response_status}, "
f"处理时间: {round(process_time * 1000, 3)}ms"
)
await send(message)
try:
await self.app(scope, receive, send_wrapper)
except Exception as e:
log.error(f"中间件处理异常: {str(e)}")
if not response_started:
response = ErrorResponse(msg=f"系统异常,请联系管理员", data=str(e))
await response(scope, receive, send)
class CustomGZipMiddleware(GZipMiddleware):
"""GZip压缩中间件"""
def __init__(self, app: ASGIApp) -> None:
super().__init__(
app,
minimum_size=settings.GZIP_MIN_SIZE,
compresslevel=settings.GZIP_COMPRESS_LEVEL
)

View File

@@ -0,0 +1,159 @@
# -*- coding: utf-8 -*-
from typing import Any
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy import select
from app.api.v1.module_system.user.model import UserModel
from app.api.v1.module_system.dept.model import DeptModel
from app.api.v1.module_system.auth.schema import AuthSchema
from app.utils.common_util import get_child_id_map, get_child_recursion
class Permission:
"""
为业务模型提供数据权限过滤功能
"""
# 数据权限常量定义,提高代码可读性
DATA_SCOPE_SELF = 1 # 仅本人数据
DATA_SCOPE_DEPT = 2 # 本部门数据
DATA_SCOPE_DEPT_AND_CHILD = 3 # 本部门及以下数据
DATA_SCOPE_ALL = 4 # 全部数据
DATA_SCOPE_CUSTOM = 5 # 自定义数据
def __init__(self, model: Any, auth: AuthSchema):
"""
初始化权限过滤器实例
Args:
db: 数据库会话
model: 数据模型类
current_user: 当前用户对象
auth: 认证信息对象
"""
self.model = model
self.auth = auth
self.conditions: list[ColumnElement] = [] # 权限条件列表
async def filter_query(self, query: Any) -> Any:
"""
异步过滤查询对象
Args:
query: SQLAlchemy查询对象
Returns:
过滤后的查询对象
"""
condition = await self.__permission_condition()
return query.where(condition) if condition is not None else query
async def __permission_condition(self) -> ColumnElement | None:
"""
应用数据范围权限隔离
基于角色的五种数据权限范围过滤
支持五种权限类型:
1. 仅本人数据权限 - 只能查看自己创建的数据
2. 本部门数据权限 - 只能查看同部门的数据
3. 本部门及以下数据权限 - 可以查看本部门及所有子部门的数据
4. 全部数据权限 - 可以查看所有数据
5. 自定义数据权限 - 通过role_dept_relation表定义可访问的部门列表
权限处理原则:
- 多个角色的权限取并集(最宽松原则)
- 优先级:全部数据 > 部门权限2、3、5的并集> 仅本人
- 构造权限过滤表达式返回None表示不限制
"""
# 如果不需要检查数据权限,则不限制
if not self.auth.user:
return None
# 如果检查数据权限为False,则不限制
if not self.auth.check_data_scope:
return None
# 如果模型没有创建人created_id字段,则不限制
if not hasattr(self.model, "created_id"):
return None
# 超级管理员可以查看所有数据
if self.auth.user.is_superuser:
return None
# 如果用户没有角色,则只能查看自己的数据
roles = getattr(self.auth.user, "roles", []) or []
if not roles:
created_id_attr = getattr(self.model, "created_id", None)
if created_id_attr is not None:
return created_id_attr == self.auth.user.id
return None
# 获取用户所有角色的权限范围
data_scopes = set()
custom_dept_ids = set() # 自定义权限data_scope=5关联的部门ID集合
for role in roles:
data_scopes.add(role.data_scope)
# 收集自定义权限data_scope=5关联的部门ID
if role.data_scope == self.DATA_SCOPE_CUSTOM and hasattr(role, 'depts') and role.depts:
for dept in role.depts:
custom_dept_ids.add(dept.id)
# 权限优先级处理:全部数据权限最高优先级
if self.DATA_SCOPE_ALL in data_scopes:
return None
# 收集所有可访问的部门ID2、3、5权限的并集
accessible_dept_ids = set()
user_dept_id = getattr(self.auth.user, "dept_id", None)
# 处理自定义数据权限5
if self.DATA_SCOPE_CUSTOM in data_scopes:
accessible_dept_ids.update(custom_dept_ids)
# 处理本部门数据权限2
if self.DATA_SCOPE_DEPT in data_scopes:
if user_dept_id is not None:
accessible_dept_ids.add(user_dept_id)
# 处理本部门及以下数据权限3
if self.DATA_SCOPE_DEPT_AND_CHILD in data_scopes:
if user_dept_id is not None:
try:
# 查询所有部门并递归获取子部门
dept_sql = select(DeptModel)
dept_result = await self.auth.db.execute(dept_sql)
dept_objs = dept_result.scalars().all()
id_map = get_child_id_map(dept_objs)
# get_child_recursion返回的结果已包含自身ID和所有子部门ID
dept_with_children_ids = get_child_recursion(id=user_dept_id, id_map=id_map)
accessible_dept_ids.update(dept_with_children_ids)
except Exception:
# 查询失败时降级到本部门
accessible_dept_ids.add(user_dept_id)
# 如果有部门权限2、3、5任一使用部门过滤
if accessible_dept_ids:
creator_rel = getattr(self.model, "created_by", None)
# 优先使用关系过滤(性能更好)
if creator_rel is not None and hasattr(UserModel, 'dept_id'):
return creator_rel.has(getattr(UserModel, 'dept_id').in_(list(accessible_dept_ids)))
# 降级方案如果模型没有created_by关系但有created_id则只能查看自己的数据
else:
created_id_attr = getattr(self.model, "created_id", None)
if created_id_attr is not None:
return created_id_attr == self.auth.user.id
return None
# 处理仅本人数据权限1
if self.DATA_SCOPE_SELF in data_scopes:
created_id_attr = getattr(self.model, "created_id", None)
if created_id_attr is not None:
return created_id_attr == self.auth.user.id
return None
# 默认情况:如果用户有角色但没有任何有效权限范围,只能查看自己的数据
created_id_attr = getattr(self.model, "created_id", None)
if created_id_attr is not None:
return created_id_attr == self.auth.user.id
return None

View File

@@ -0,0 +1,252 @@
# -*- coding: utf-8 -*-
import pickle
from typing import Any, Awaitable
from redis.asyncio.client import Redis
from app.core.logger import log
class RedisCURD:
"""缓存工具类"""
def __init__(self, redis: Redis) -> None:
"""初始化"""
self.redis = redis
async def mget(self, keys: list) -> list:
"""批量获取缓存
参数:
- keys (list): 键名列表
返回:
- list: 返回缓存值列表,如果获取失败则返回空列表
"""
try:
data = await self.redis.mget(*[str(key) for key in keys])
return data
except Exception as e:
log.error(f"批量获取缓存失败: {str(e)}")
return []
async def get_keys(self, pattern: str = "*") -> list:
"""获取缓存键名
参数:
- pattern (str, optional): 匹配模式,默认值为"*"
返回:
- list: 返回匹配的缓存键名列表,如果获取失败则返回空列表
"""
try:
keys = await self.redis.keys(f"{pattern}")
return keys
except Exception as e:
log.error(f"获取缓存键名失败: {str(e)}")
return []
async def get(self, key: str) -> Any:
"""获取缓存
参数:
- key (str): 缓存键名
返回:
- Any: 返回缓存值,如果缓存不存在则返回None
"""
try:
data = await self.redis.get(f"{key}")
if data is None:
return None
return data
except Exception as e:
log.error(f"获取缓存失败: {str(e)}")
return None
async def set(self, key: str, value: Any, expire: int | None = None) -> bool:
"""设置缓存
参数:
- key (str): 缓存键名
- value (Any): 缓存值
- expire (int | None, optional): 过期时间,单位为秒,默认值为None。
返回:
- bool: 如果设置缓存成功则返回True,否则返回False
"""
try:
# 根据数据类型选择序列化方式
if isinstance(value, (int, float, str)):
data = str(value).encode('utf-8')
else:
try:
data = pickle.dumps(value)
except Exception as e:
log.error(f"序列化数据失败: {str(e)}")
return False
await self.redis.set(
name = key,
value = data,
ex=expire
)
return True
except Exception as e:
log.error(f"设置缓存失败: {str(e)}")
return False
async def delete(self, *keys: str) -> bool:
"""删除缓存
参数:
- keys (str): 缓存键名
返回:
- bool: 如果删除缓存成功则返回True,否则返回False
"""
try:
await self.redis.delete(*keys)
return True
except Exception as e:
log.error(f"删除缓存失败: {str(e)}")
return False
async def clear(self, pattern: str = "*") -> bool:
"""清空缓存
参数:
- pattern (str, optional): 匹配模式,默认值为"*"
返回:
- bool: 如果清空缓存成功则返回True,否则返回False
"""
try:
keys = await self.redis.keys(f"{pattern}")
if keys:
await self.redis.delete(*keys)
return True
except Exception as e:
log.error(f"清空缓存失败: {str(e)}")
return False
async def exists(self, key: str) -> bool:
"""判断缓存是否存在
参数:
- key (str): 缓存键名
返回:
- bool: 如果缓存存在则返回True,否则返回False
"""
try:
return await self.redis.exists(f"{key}")
except Exception as e:
log.error(f"判断缓存是否存在失败: {str(e)}")
return False
async def ttl(self, key: str) -> int:
"""获取缓存过期时间
参数:
- key (str): 缓存键名
返回:
- int: 返回缓存过期时间,单位为秒,如果缓存没有设置过期时间则返回-1
"""
try:
return await self.redis.ttl(f"{key}")
except Exception as e:
log.error(f"获取缓存过期时间失败: {str(e)}")
return -1
async def expire(self, key: str, expire: int) -> bool:
"""设置缓存过期时间
参数:
- key (str): 缓存键名
- expire (int): 过期时间,单位为秒
返回:
- bool: 如果设置缓存过期时间成功则返回True,否则返回False
"""
try:
return await self.redis.expire(f"{key}", expire)
except Exception as e:
log.error(f"设置缓存过期时间失败: {str(e)}")
return False
async def info(self) -> dict:
"""获取缓存信息
返回:
- dict: 返回缓存信息字典,如果获取失败则返回空字典
"""
try:
return await self.redis.info()
except Exception as e:
log.error(f"获取缓存信息失败: {str(e)}")
return {}
async def db_size(self) -> int:
"""获取数据库大小
返回:
- int: 返回数据库大小,如果获取失败则返回0
"""
try:
return await self.redis.dbsize()
except Exception as e:
log.error(f"获取数据库大小失败: {str(e)}")
return 0
async def commandstats(self) -> dict:
"""获取命令统计信息
返回:
- dict: 返回命令统计信息字典,如果获取失败则返回空字典
"""
try:
return await self.redis.info("commandstats")
except Exception as e:
log.error(f"获取命令统计信息失败: {str(e)}")
return {}
async def hash_set(self, name: str, key: str, value: Any) -> bool:
"""设置哈希缓存
参数:
- name (str): 哈希缓存名称
- key (str): 哈希缓存键名
- value (Any): 哈希缓存值
返回:
- bool: 如果设置哈希缓存成功则返回True,否则返回False
"""
try:
self.redis.hset(name=name, key=key, value=value)
return True
except Exception as e:
log.error(f"设置哈希缓存失败: {str(e)}")
return False
async def hash_get(self, name: str, keys: list[str]) -> Awaitable[list[Any]] | list[Any]:
"""获取哈希缓存
参数:
- name (str): 哈希缓存名称
- keys (list[str]): 哈希缓存键名列表
返回:
- Awaitable[list[Any]] | list[Any]: 返回哈希缓存值列表,如果获取失败则返回空列表
"""
try:
data = self.redis.hmget(name=name, keys=keys)
return data
except Exception as e:
log.error(f"获取哈希缓存失败: {str(e)}")
return []

View File

@@ -0,0 +1,148 @@
# -*- coding: utf-8 -*-
import time
import json
from typing import Any, Callable, Coroutine
from fastapi import Request, Response
from fastapi.routing import APIRoute
from user_agents import parse
from app.core.database import async_db_session
from app.config.setting import settings
from app.utils.ip_local_util import IpLocalUtil
from app.api.v1.module_system.auth.schema import AuthSchema
from app.api.v1.module_system.log.schema import OperationLogCreateSchema
from app.api.v1.module_system.log.service import OperationLogService
"""
在 FastAPI 中route_class 参数用于自定义路由的行为。
通过设置 route_class你可以定义一个自定义的路由类从而在每个路由处理之前或之后执行特定的操作。
这对于日志记录、权限验证、性能监控等场景非常有用。
"""
class OperationLogRoute(APIRoute):
"""操作日志路由装饰器"""
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
"""
自定义路由处理程序,在每个路由处理之前或之后执行特定的操作。
参数:
- request (Request): FastAPI请求对象。
返回:
- Response: FastAPI响应对象。
"""
original_route_handler = super().get_route_handler()
async def custom_route_handler(request: Request) -> Response:
"""
自定义路由处理程序,在每个路由处理之前或之后执行特定的操作。
参数:
- request (Request): FastAPI请求对象。
描述:
- 该方法在每个路由处理之前被调用,用于记录操作日志。
返回:
- Response: FastAPI响应对象。
"""
start_time = time.time()
# 请求前的处理
response: Response = await original_route_handler(request)
# 请求后的处理
if not settings.OPERATION_LOG_RECORD:
return response
if request.method not in settings.OPERATION_RECORD_METHOD:
return response
route: APIRoute = request.scope.get("route", None)
if route.name in settings.IGNORE_OPERATION_FUNCTION:
return response
user_agent = parse(request.headers.get("user-agent"))
payload = b"{}"
req_content_type = request.headers.get("Content-Type", "")
if req_content_type and (
req_content_type.startswith('multipart/form-data') or req_content_type.startswith('application/x-www-form-urlencoded')
):
form_data = await request.form()
oper_param = '\n'.join([f'{k}: {v}' for k, v in form_data.items()])
payload = oper_param # 直接使用字符串格式的参数
else:
payload = await request.body()
path_params = request.path_params
oper_param = {}
# 处理请求体数据
if payload:
try:
oper_param['body'] = json.loads(payload.decode())
except (json.JSONDecodeError, UnicodeDecodeError):
oper_param['body'] = payload.decode('utf-8', errors='ignore')
# 处理路径参数
if path_params:
oper_param['path_params'] = dict(path_params)
payload = json.dumps(oper_param, ensure_ascii=False)
# 日志表请求参数字段长度最大为2000因此在此处判断长度
if len(payload) > 2000:
payload = '请求参数过长'
response_data = response.body if "application/json" in response.headers.get("Content-Type", "") else b"{}"
process_time = f"{(time.time() - start_time):.2f}s"
# 获取当前用户ID,如果是登录接口则为空
log_type = 1 # 1:登录日志 2:操作日志
current_user_id = None
# 优化只在操作日志场景下获取current_user_id
if "user_id" in request.scope:
current_user_id = request.scope.get("user_id")
log_type = 2
request_ip = None
x_forwarded_for = request.headers.get('X-Forwarded-For')
if x_forwarded_for:
# 取第一个 IP 地址,通常为客户端真实 IP
request_ip = x_forwarded_for.split(',')[0].strip()
else:
# 若没有 X-Forwarded-For 头,则使用 request.client.host
if request.client:
request_ip = request.client.host
login_location = await IpLocalUtil.get_ip_location(request_ip) if request_ip else None
# 判断请求是否来自api文档
referer = request.headers.get('referer')
request_from_swagger = referer and referer.endswith('docs')
request_from_redoc = referer and referer.endswith('redoc')
if request_from_swagger or request_from_redoc:
# 如果请求来自api文档则不记录日志
pass
else:
async with async_db_session() as session:
async with session.begin():
auth = AuthSchema(db=session)
await OperationLogService.create_log_service(data=OperationLogCreateSchema(
type = log_type,
request_path = request.url.path,
request_method = request.method,
request_payload = payload,
request_ip = request_ip,
login_location=login_location,
request_os = user_agent.os.family,
request_browser = user_agent.browser.family,
response_code = response.status_code,
response_json = response_data.decode() if isinstance(response_data, (bytes, bytearray)) else str(response_data),
process_time = process_time,
description = route.summary,
created_id = current_user_id,
updated_id = current_user_id,
), auth = auth)
return response
return custom_route_handler

View File

@@ -0,0 +1,157 @@
# -*- coding: utf-8 -*-
import jwt
from fastapi import Form, Request
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi.security.utils import get_authorization_scheme_param
from app.core.exceptions import CustomException
from app.config.setting import settings
from app.api.v1.module_system.auth.schema import JWTPayloadSchema
class CustomOAuth2PasswordBearer(OAuth2PasswordBearer):
"""自定义OAuth2认证类,继承自OAuth2PasswordBearer"""
def __init__(
self,
token_url: str,
scheme_name: str | None = None,
scopes: dict[str, str] | None = None,
description: str | None = None,
auto_error: bool = True
) -> None:
super().__init__(
tokenUrl=token_url,
scheme_name=scheme_name,
scopes=scopes,
description=description,
auto_error=auto_error
)
async def __call__(self, request: Request) -> str | None:
"""
重写认证方法,校验token
参数:
- request (Request): FastAPI请求对象。
返回:
- str | None: 校验通过的token,如果校验失败则返回None。
异常:
- CustomException: 认证失败时抛出,状态码为401。
"""
authorization = request.headers.get("Authorization")
scheme, token = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != settings.TOKEN_TYPE:
if self.auto_error:
raise CustomException(msg="认证失败,请登录后再试", code=10401, status_code=401)
return None
return token
class CustomOAuth2PasswordRequestForm(OAuth2PasswordRequestForm):
"""
自定义登录表单,扩展验证码等字段
参数:
- grant_type (str | None): 授权类型,默认值为None,正则表达式为'password'
- scope (str): 作用域,默认值为空字符串。
- client_id (str | None): 客户端ID,默认值为None。
- client_secret (str | None): 客户端密钥,默认值为None。
- username (str): 用户名。
- password (str): 密码。
- captcha_key (str | None): 验证码键,默认值为空字符串。
- captcha (str | None): 验证码值,默认值为空字符串。
- login_type (str | None): 登录类型,默认值为"PC端",描述为"PC端 | 移动端"
"""
def __init__(
self,
grant_type: str | None = Form(default=None, pattern='password'),
scope: str = Form(default=''),
client_id: str | None = Form(default=None),
client_secret: str | None = Form(default=None),
username: str = Form(),
password: str = Form(),
captcha_key: str | None = Form(default=""),
captcha: str | None = Form(default=""),
login_type: str | None = Form(default="PC端", description="PC端 | 移动端")
):
super().__init__(
grant_type=grant_type,
scope=scope,
client_id=client_id,
client_secret=client_secret,
username=username,
password=password,
)
self.captcha_key = captcha_key
self.captcha = captcha
self.login_type = login_type
# OAuth2认证配置
OAuth2Schema = CustomOAuth2PasswordBearer(
token_url="system/auth/login",
description="认证"
)
def create_access_token(payload: JWTPayloadSchema) -> str:
"""
生成JWT访问令牌
参数:
- payload (JWTPayloadSchema): JWT有效载荷,包含用户信息等。
返回:
- str: 生成的JWT访问令牌。
"""
payload_dict = payload.model_dump()
return jwt.encode(
payload=payload_dict,
key=settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
def decode_access_token(token: str) -> JWTPayloadSchema:
"""
解析JWT访问令牌
参数:
- token (str): JWT访问令牌字符串。
返回:
- JWTPayloadSchema: 解析后的JWT有效载荷,包含用户信息等。
异常:
- CustomException: 解析失败时抛出,状态码为401。
"""
if not token:
raise CustomException(msg="认证不存在,请重新登录", code=10401, status_code=401)
try:
payload = jwt.decode(
jwt=token,
key=settings.SECRET_KEY,
algorithms=[settings.ALGORITHM]
)
online_user_info = payload.get("sub")
if not online_user_info:
raise CustomException(msg="无效认证,请重新登录", code=10401, status_code=401)
return JWTPayloadSchema(**payload)
except (jwt.InvalidSignatureError, jwt.DecodeError):
raise CustomException(msg="无效认证,请重新登录", code=10401, status_code=401)
except jwt.ExpiredSignatureError:
raise CustomException(msg="认证已过期,请重新登录", code=10401, status_code=401)
except jwt.InvalidTokenError:
raise CustomException(msg="token已失效,请重新登录", code=10401, status_code=401)

View File

@@ -0,0 +1,55 @@
# -*- coding: utf-8 -*-
from pydantic import BaseModel
from typing import Any, TypeVar, Type, Generic
from sqlalchemy.orm import DeclarativeBase
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
SchemaType = TypeVar("SchemaType", bound=BaseModel)
class Serialize(Generic[ModelType, SchemaType]):
"""
序列化工具类提供模型、Schema 和字典之间的转换功能
"""
@classmethod
def schema_to_model(cls,schema: Type[SchemaType], model: Type[ModelType]) -> ModelType:
"""
将 Pydantic Schema 转换为 SQLAlchemy 模型
参数:
- schema (Type[SchemaType]): Pydantic Schema 实例。
- model (Type[ModelType]): SQLAlchemy 模型类。
返回:
- ModelType: SQLAlchemy 模型实例。
异常:
- ValueError: 转换过程中可能抛出的异常。
"""
try:
return model(**cls.model_to_dict(model, schema))
except Exception as e:
raise ValueError(f"序列化失败: {str(e)}")
@classmethod
def model_to_dict(cls, model: Type[ModelType], schema: Type[SchemaType]) -> dict[str, Any]:
"""
将 SQLAlchemy 模型转换为 Pydantic Schema
参数:
- model (Type[ModelType]): SQLAlchemy 模型实例。
- schema (Type[SchemaType]): Pydantic Schema 类。
返回:
- dict[str, Any]: 包含模型数据的字典。
异常:
- ValueError: 转换过程中可能抛出的异常。
"""
try:
return schema.model_validate(model).model_dump()
except Exception as e:
raise ValueError(f"反序列化失败: {str(e)}")

View File

@@ -0,0 +1,167 @@
# -*- coding: utf-8 -*-
import re
from datetime import datetime
from typing import Annotated
from pydantic import AfterValidator, PlainSerializer, WithJsonSchema
from app.common.constant import RET
from app.core.exceptions import CustomException
# 自定义日期时间字符串类型
DateTimeStr = Annotated[
datetime,
AfterValidator(lambda x: datetime_validator(x)),
PlainSerializer(lambda x: x.strftime('%Y-%m-%d %H:%M:%S') if isinstance(x, datetime) else str(x), return_type=str),
WithJsonSchema({'type': 'string'}, mode='serialization')
]
# 自定义手机号类型
Telephone = Annotated[
str,
AfterValidator(lambda x: mobile_validator(x)),
PlainSerializer(lambda x: x, return_type=str),
WithJsonSchema({'type': 'string'}, mode='serialization')
]
# 自定义邮箱类型
Email = Annotated[
str,
AfterValidator(lambda x: email_validator(x)),
PlainSerializer(lambda x: x, return_type=str),
WithJsonSchema({'type': 'string'}, mode='serialization')
]
def datetime_validator(value: str | datetime) -> datetime:
"""
日期格式验证器。
参数:
- value (str | datetime): 日期值。
返回:
- datetime: 格式化后的日期。
异常:
- CustomException: 日期格式无效时抛出。
"""
pattern = "%Y-%m-%d %H:%M:%S"
try:
if isinstance(value, str):
return datetime.strptime(value, pattern)
elif isinstance(value, datetime):
return value
except Exception:
raise CustomException(code=RET.ERROR.code, msg="无效的日期格式")
def email_validator(value: str) -> str:
"""
邮箱地址验证器。
参数:
- value (str): 邮箱地址。
返回:
- str: 验证后的邮箱地址。
异常:
- CustomException: 邮箱格式无效时抛出。
"""
if not value:
raise CustomException(code=RET.ERROR.code, msg="邮箱地址不能为空")
regex = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
if not re.match(regex, value):
raise CustomException(code=RET.ERROR.code, msg="邮箱地址格式不正确")
return value
def mobile_validator(value: str | None) -> str | None:
"""
手机号验证器。
参数:
- value (str | None): 手机号。
返回:
- str | None: 验证后的手机号。
异常:
- CustomException: 手机号格式无效时抛出。
"""
if not value:
return value
if len(value) != 11 or not value.isdigit():
raise CustomException(code=RET.ERROR.code, msg="手机号格式不正确")
regex = r'^1(3\d|4[4-9]|5[0-35-9]|6[67]|7[013-8]|8[0-9]|9[0-9])\d{8}$'
if not re.match(regex, value):
raise CustomException(code=RET.ERROR.code, msg="手机号格式不正确")
return value
def menu_request_validator(data):
"""
菜单请求数据验证器。
参数:
- data (Any): 请求数据。
返回:
- Any: 验证后的请求数据。
异常:
- CustomException: 请求数据无效时抛出。
"""
menu_types = {1: "目录", 2: "功能", 3: "权限", 4: "外链"}
if data.type not in menu_types:
raise CustomException(code=RET.ERROR.code, msg=f"菜单类型必须为: {','.join(map(str, menu_types.keys()))}")
if data.type in [1, 2]:
if not data.route_name:
raise CustomException(code=RET.ERROR.code, msg="路由名称不能为空")
if not data.route_path:
raise CustomException(code=RET.ERROR.code, msg="路由路径不能为空")
if data.type == 2 and not data.component_path:
raise CustomException(code=RET.ERROR.code, msg="组件路径不能为空")
return data
def role_permission_request_validator(data):
"""
角色权限设置数据验证器。
参数:
- data (Any): 请求数据。
返回:
- Any: 验证后的请求数据。
异常:
- CustomException: 请求数据无效时抛出。
"""
data_scopes = {
1: "仅本人数据权限",
2: "本部门数据权限",
3: "本部门及以下数据权限",
4: "全部数据权限",
5: "自定义数据权限"
}
if data.data_scope not in data_scopes:
raise CustomException(code=RET.ERROR.code, msg=f"数据权限范围必须为: {','.join(map(str, data_scopes.keys()))}")
if not data.role_ids:
raise CustomException(code=RET.ERROR.code, msg="角色不能为空")
return data