upload project source code
This commit is contained in:
228
后端源码/yifan.action-ai.cn/api/app/core/dependencies.py
Normal file
228
后端源码/yifan.action-ai.cn/api/app/core/dependencies.py
Normal file
@@ -0,0 +1,228 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import json
|
||||
from redis.asyncio.client import Redis
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
from typing import Any, AsyncGenerator
|
||||
from fastapi import Depends, Request
|
||||
from fastapi import Depends
|
||||
|
||||
from app.common.enums import RedisInitKeyConfig
|
||||
from app.core.exceptions import CustomException
|
||||
from app.core.database import async_db_session
|
||||
from app.core.redis_crud import RedisCURD
|
||||
from app.core.security import OAuth2Schema, decode_access_token
|
||||
from app.core.logger import log
|
||||
|
||||
from app.api.v1.module_system.user.model import UserModel
|
||||
from app.api.v1.module_system.user.crud import UserCRUD
|
||||
from app.api.v1.module_system.auth.schema import AuthSchema
|
||||
|
||||
|
||||
async def db_getter() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""获取数据库会话连接
|
||||
|
||||
返回:
|
||||
- AsyncSession: 数据库会话连接
|
||||
"""
|
||||
async with async_db_session() as session:
|
||||
async with session.begin():
|
||||
yield session
|
||||
|
||||
async def redis_getter(request: Request) -> Redis:
|
||||
"""获取Redis连接
|
||||
|
||||
参数:
|
||||
- request (Request): 请求对象
|
||||
|
||||
返回:
|
||||
- Redis: Redis连接
|
||||
"""
|
||||
return request.app.state.redis
|
||||
|
||||
|
||||
async def _session_user_info_from_token(
|
||||
request: Request,
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
token: str,
|
||||
) -> tuple[AuthSchema, dict[str, Any]]:
|
||||
"""校验 token / Redis 在线,返回 (AuthSchema, JWT sub 解析出的会话信息字典)。"""
|
||||
if not token:
|
||||
raise CustomException(msg="认证已失效", code=10401, status_code=401)
|
||||
|
||||
if token.startswith("Bearer"):
|
||||
token = token.split(" ", 1)[1]
|
||||
|
||||
payload = decode_access_token(token)
|
||||
if not payload or not hasattr(payload, "is_refresh") or payload.is_refresh:
|
||||
raise CustomException(msg="非法凭证", code=10401, status_code=401)
|
||||
|
||||
user_info = json.loads(payload.sub)
|
||||
session_id = user_info.get("session_id")
|
||||
if not session_id:
|
||||
raise CustomException(msg="认证已失效", code=10401, status_code=401)
|
||||
|
||||
online_ok = await RedisCURD(redis).exists(
|
||||
key=f"{RedisInitKeyConfig.ACCESS_TOKEN.key}:{session_id}"
|
||||
)
|
||||
if not online_ok:
|
||||
raise CustomException(msg="认证已失效", code=10401, status_code=401)
|
||||
|
||||
# 须能唯一定位用户:优先 user_id(改用户名后仍有效);兼容仅有 user_name 的旧 token
|
||||
if user_info.get("user_id") is None and not user_info.get("user_name"):
|
||||
raise CustomException(msg="认证已失效", code=10401, status_code=401)
|
||||
|
||||
auth = AuthSchema(db=db, check_data_scope=False)
|
||||
return auth, user_info
|
||||
|
||||
|
||||
async def _load_user_from_session_info(
|
||||
auth: AuthSchema,
|
||||
user_info: dict[str, Any],
|
||||
preload: list[str | Any],
|
||||
):
|
||||
"""按会话中的 user_id 优先加载用户,避免修改 username 后 JWT 内旧 user_name 导致 401。"""
|
||||
uid = user_info.get("user_id")
|
||||
username = user_info.get("user_name")
|
||||
if uid is not None:
|
||||
return await UserCRUD(auth).get_by_id_crud(id=int(uid), preload=preload)
|
||||
return await UserCRUD(auth).get_by_username_crud(username=username, preload=preload)
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(db_getter),
|
||||
redis: Redis = Depends(redis_getter),
|
||||
token: str = Depends(OAuth2Schema),
|
||||
) -> AuthSchema:
|
||||
"""获取当前用户
|
||||
|
||||
参数:
|
||||
- request (Request): 请求对象
|
||||
- db (AsyncSession): 数据库会话
|
||||
- redis (Redis): Redis连接
|
||||
- token (str): 访问令牌
|
||||
|
||||
返回:
|
||||
- AuthSchema: 认证信息模型
|
||||
"""
|
||||
auth, user_info = await _session_user_info_from_token(request, db, redis, token)
|
||||
|
||||
# 获取用户信息,使用深层预加载确保RoleModel.creator被正确加载
|
||||
user = await _load_user_from_session_info(
|
||||
auth,
|
||||
user_info,
|
||||
preload=[
|
||||
"dept",
|
||||
selectinload(UserModel.roles),
|
||||
"positions",
|
||||
"created_by",
|
||||
],
|
||||
)
|
||||
if not user:
|
||||
raise CustomException(msg="用户不存在", code=10401, status_code=401)
|
||||
if not user.status:
|
||||
raise CustomException(msg="用户已被停用", code=10401, status_code=401)
|
||||
|
||||
# 设置请求上下文
|
||||
request.state.user = user
|
||||
request.state.user_id = user.id
|
||||
request.state.user_username = user.username
|
||||
request.scope["user_id"] = user.id
|
||||
request.scope["user_username"] = user.username
|
||||
|
||||
# 过滤可用的角色和职位
|
||||
if hasattr(user, 'roles'):
|
||||
user.roles = [role for role in user.roles if role and role.status]
|
||||
if hasattr(user, 'positions'):
|
||||
user.positions = [pos for pos in user.positions if pos and pos.status]
|
||||
|
||||
auth.user = user
|
||||
return auth
|
||||
|
||||
|
||||
async def get_current_user_lite(
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(db_getter),
|
||||
redis: Redis = Depends(redis_getter),
|
||||
token: str = Depends(OAuth2Schema),
|
||||
) -> AuthSchema:
|
||||
"""当前用户(轻量):仅主表字段,无部门/角色等预加载,用于低频写个人资料的接口以压低延迟。"""
|
||||
auth, user_info = await _session_user_info_from_token(request, db, redis, token)
|
||||
|
||||
user = await _load_user_from_session_info(auth, user_info, preload=[])
|
||||
if not user:
|
||||
raise CustomException(msg="用户不存在", code=10401, status_code=401)
|
||||
if not user.status:
|
||||
raise CustomException(msg="用户已被停用", code=10401, status_code=401)
|
||||
|
||||
request.state.user = user
|
||||
request.state.user_id = user.id
|
||||
request.state.user_username = user.username
|
||||
request.scope["user_id"] = user.id
|
||||
request.scope["user_username"] = user.username
|
||||
|
||||
# 无 dept/roles 预加载,不访问关联属性,避免异步环境下的隐式懒加载
|
||||
|
||||
auth.user = user
|
||||
return auth
|
||||
|
||||
|
||||
class AuthPermission:
|
||||
"""权限验证类"""
|
||||
|
||||
def __init__(self, permissions: list[str] | None = None, check_data_scope: bool = True) -> None:
|
||||
"""
|
||||
初始化权限验证
|
||||
|
||||
参数:
|
||||
- permissions (list[str] | None): 权限标识列表。
|
||||
- check_data_scope (bool): 是否启用严格模式校验。
|
||||
"""
|
||||
self.permissions = permissions or []
|
||||
self.check_data_scope = check_data_scope
|
||||
|
||||
async def __call__(self, auth: AuthSchema = Depends(get_current_user)) -> AuthSchema:
|
||||
"""
|
||||
调用权限验证
|
||||
|
||||
参数:
|
||||
- auth (AuthSchema): 认证信息对象。
|
||||
|
||||
返回:
|
||||
- AuthSchema: 认证信息对象。
|
||||
"""
|
||||
auth.check_data_scope = self.check_data_scope
|
||||
|
||||
# 超级管理员直接通过
|
||||
if auth.user and auth.user.is_superuser:
|
||||
return auth
|
||||
|
||||
# 无需验证权限
|
||||
if not self.permissions:
|
||||
return auth
|
||||
|
||||
# 超级管理员权限标识
|
||||
if "*" in self.permissions or "*:*:*" in self.permissions:
|
||||
return auth
|
||||
|
||||
# 检查用户是否有角色
|
||||
if not auth.user or not auth.user.roles:
|
||||
raise CustomException(msg="无权限操作", code=10403, status_code=403)
|
||||
|
||||
# 获取用户权限集合
|
||||
user_permissions = {
|
||||
menu.permission
|
||||
for role in auth.user.roles
|
||||
for menu in role.menus
|
||||
if role.status and menu.permission and menu.status
|
||||
}
|
||||
|
||||
# 权限验证 - 满足任一权限即可
|
||||
if not any(perm in user_permissions for perm in self.permissions):
|
||||
log.error(f"用户缺少任何所需的权限: {self.permissions}")
|
||||
raise CustomException(msg="无权限操作", code=10403, status_code=403)
|
||||
|
||||
return auth
|
||||
Reference in New Issue
Block a user