Files

229 lines
7.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
import json
from redis.asyncio.client import Redis
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from typing import Any, AsyncGenerator
from fastapi import Depends, Request
from fastapi import Depends
from app.common.enums import RedisInitKeyConfig
from app.core.exceptions import CustomException
from app.core.database import async_db_session
from app.core.redis_crud import RedisCURD
from app.core.security import OAuth2Schema, decode_access_token
from app.core.logger import log
from app.api.v1.module_system.user.model import UserModel
from app.api.v1.module_system.user.crud import UserCRUD
from app.api.v1.module_system.auth.schema import AuthSchema
async def db_getter() -> AsyncGenerator[AsyncSession, None]:
    """Yield a database session running inside a transaction.

    The session comes from ``async_db_session()``, and a transaction is opened
    with ``session.begin()``. Under SQLAlchemy's ``AsyncSession.begin()``
    context-manager semantics, the transaction commits on normal exit and
    rolls back if an exception propagates into it.

    Yields:
        AsyncSession: the database session for the current dependency scope.
    """
    # Both context managers are entered left to right, so ``session`` is
    # already bound when ``session.begin()`` is evaluated.
    async with async_db_session() as session, session.begin():
        yield session
async def redis_getter(request: Request) -> Redis:
    """Return the Redis client stored on the application state.

    Args:
        request (Request): the incoming request; ``request.app.state.redis``
            is read.

    Returns:
        Redis: the client object attached to ``app.state`` (presumably set
        during application startup — verify against the app factory).
    """
    app_state = request.app.state
    return app_state.redis
async def _session_user_info_from_token(
    request: Request,
    db: AsyncSession,
    redis: Redis,
    token: str,
) -> tuple[AuthSchema, dict[str, Any]]:
    """Validate an access token and its online session.

    Steps:
      1. Strip an optional ``Bearer`` prefix.
      2. Decode the JWT and reject missing payloads or refresh tokens.
      3. Parse the JSON-encoded ``sub`` claim into a session-info dict.
      4. Require a ``session_id`` whose access-token key still exists in Redis.
      5. Require ``user_id`` or ``user_name`` so the user can be located.

    Args:
        request (Request): current request (unused here; kept for interface
            stability).
        db (AsyncSession): database session attached to the returned
            ``AuthSchema``.
        redis (Redis): Redis client used for the online-session check.
        token (str): raw access token, optionally prefixed with ``Bearer``.

    Returns:
        tuple[AuthSchema, dict[str, Any]]: an ``AuthSchema`` built with
        ``check_data_scope=False``, and the session-info dict decoded from the
        JWT ``sub`` claim.

    Raises:
        CustomException: 401 (code 10401) for a missing, malformed or refresh
        token, or for a session that is no longer online.
    """
    if not token:
        raise CustomException(msg="认证已失效", code=10401, status_code=401)
    if token.startswith("Bearer"):
        # partition() never raises. The previous split(" ", 1)[1] raised
        # IndexError (-> HTTP 500) for "Bearer" alone or "Bearer" without a space.
        token = token.partition(" ")[2].strip()
        if not token:
            raise CustomException(msg="认证已失效", code=10401, status_code=401)

    payload = decode_access_token(token)
    if not payload or not hasattr(payload, "is_refresh") or payload.is_refresh:
        raise CustomException(msg="非法凭证", code=10401, status_code=401)

    # The subject claim carries the session info as JSON. Anything that does
    # not decode to a JSON object is an invalid credential, not a server error.
    try:
        user_info = json.loads(payload.sub)
    except (TypeError, ValueError):
        raise CustomException(msg="非法凭证", code=10401, status_code=401) from None
    if not isinstance(user_info, dict):
        raise CustomException(msg="非法凭证", code=10401, status_code=401)

    session_id = user_info.get("session_id")
    if not session_id:
        raise CustomException(msg="认证已失效", code=10401, status_code=401)

    # The session is online only while its access-token key exists in Redis.
    online_ok = await RedisCURD(redis).exists(
        key=f"{RedisInitKeyConfig.ACCESS_TOKEN.key}:{session_id}"
    )
    if not online_ok:
        raise CustomException(msg="认证已失效", code=10401, status_code=401)

    # The user must be uniquely locatable. Prefer user_id, which stays valid
    # after a username change; fall back to user_name for older tokens that
    # only carry it.
    if user_info.get("user_id") is None and not user_info.get("user_name"):
        raise CustomException(msg="认证已失效", code=10401, status_code=401)

    auth = AuthSchema(db=db, check_data_scope=False)
    return auth, user_info
async def _load_user_from_session_info(
    auth: AuthSchema,
    user_info: dict[str, Any],
    preload: list[str | Any],
):
    """Load the user referenced by decoded session info.

    The lookup uses ``user_id`` when it is present, so tokens issued before a
    username change still resolve. Otherwise it falls back to ``user_name``
    (older tokens that only carry the name).

    Args:
        auth (AuthSchema): auth context passed to ``UserCRUD``.
        user_info (dict[str, Any]): session info decoded from the JWT.
        preload (list[str | Any]): relationship names / loader options
            forwarded to the CRUD call.

    Returns:
        The result of ``get_by_id_crud`` or ``get_by_username_crud``. Callers
        treat a falsy result as "user not found".
    """
    crud = UserCRUD(auth)
    user_id = user_info.get("user_id")
    if user_id is None:
        return await crud.get_by_username_crud(
            username=user_info.get("user_name"), preload=preload
        )
    return await crud.get_by_id_crud(id=int(user_id), preload=preload)
async def get_current_user(
    request: Request,
    db: AsyncSession = Depends(db_getter),
    redis: Redis = Depends(redis_getter),
    token: str = Depends(OAuth2Schema),
) -> AuthSchema:
    """Resolve the authenticated user for the current request.

    This dependency:
      - validates the token and online session;
      - loads the user with dept, roles, positions and created_by preloaded;
      - publishes the user on ``request.state`` and ``request.scope``;
      - keeps only enabled roles and positions on the in-memory user object.

    Args:
        request (Request): current request; user info is written to its
            state and scope.
        db (AsyncSession): database session from ``db_getter``.
        redis (Redis): Redis client used for the online-session check.
        token (str): access token supplied by ``OAuth2Schema``.

    Returns:
        AuthSchema: auth context with ``user`` set.

    Raises:
        CustomException: 401 if the token or session is invalid, the user does
        not exist, or the user is disabled.
    """
    # Local import: used only to filter loaded collections without the
    # change being tracked (and flushed) by the session.
    from sqlalchemy.orm.attributes import set_committed_value

    auth, user_info = await _session_user_info_from_token(request, db, redis, token)

    # Preload the relations the rest of the request touches.
    # NOTE(review): AuthPermission iterates ``role.menus``. Whether menus are
    # loaded eagerly depends on the RoleModel mapping — confirm, because an
    # implicit lazy load fails under AsyncSession.
    user = await _load_user_from_session_info(
        auth,
        user_info,
        preload=[
            "dept",
            selectinload(UserModel.roles),
            "positions",
            "created_by",
        ],
    )
    if not user:
        raise CustomException(msg="用户不存在", code=10401, status_code=401)
    if not user.status:
        raise CustomException(msg="用户已被停用", code=10401, status_code=401)

    # Expose the user to the rest of the request pipeline.
    request.state.user = user
    request.state.user_id = user.id
    request.state.user_username = user.username
    request.scope["user_id"] = user.id
    request.scope["user_username"] = user.username

    # Keep only enabled roles and positions for this request.
    # set_committed_value replaces the loaded collection with no change
    # history. A plain ``user.roles = [...]`` is tracked by the session and
    # would be flushed by the transaction opened in db_getter, persisting the
    # removal (for a many-to-many this deletes the user's association rows).
    if hasattr(user, "roles"):
        set_committed_value(
            user, "roles", [role for role in user.roles if role and role.status]
        )
    if hasattr(user, "positions"):
        set_committed_value(
            user, "positions", [pos for pos in user.positions if pos and pos.status]
        )

    auth.user = user
    return auth
async def get_current_user_lite(
    request: Request,
    db: AsyncSession = Depends(db_getter),
    redis: Redis = Depends(redis_getter),
    token: str = Depends(OAuth2Schema),
) -> AuthSchema:
    """Lightweight variant of ``get_current_user``.

    Loads only the user's own row, with no dept/role/position preloading. It
    is meant for low-frequency profile-update endpoints, where it keeps
    latency down.

    Args:
        request (Request): current request; user info is written to its
            state and scope.
        db (AsyncSession): database session from ``db_getter``.
        redis (Redis): Redis client used for the online-session check.
        token (str): access token supplied by ``OAuth2Schema``.

    Returns:
        AuthSchema: auth context with ``user`` set.

    Raises:
        CustomException: 401 if the token or session is invalid, the user does
        not exist, or the user is disabled.
    """
    auth, user_info = await _session_user_info_from_token(request, db, redis, token)
    user = await _load_user_from_session_info(auth, user_info, preload=[])

    if not user:
        raise CustomException(msg="用户不存在", code=10401, status_code=401)
    if not user.status:
        raise CustomException(msg="用户已被停用", code=10401, status_code=401)

    state = request.state
    state.user = user
    state.user_id = user.id
    state.user_username = user.username
    request.scope.update(user_id=user.id, user_username=user.username)

    # Nothing beyond the user's own columns was preloaded, so no relationship
    # attribute is touched here. This avoids implicit lazy loads under async IO.
    auth.user = user
    return auth
class AuthPermission:
    """Callable dependency that checks the current user's permission strings.

    The check passes if the user holds ANY of the configured permissions.
    """

    def __init__(self, permissions: list[str] | None = None, check_data_scope: bool = True) -> None:
        """Configure the permission check.

        Args:
            permissions (list[str] | None): permission identifiers. ``None``
                or an empty list disables the permission check.
            check_data_scope (bool): copied onto ``auth.check_data_scope`` on
                every call (presumably enables strict data-scope filtering
                downstream — verify).
        """
        self.permissions = permissions if permissions else []
        self.check_data_scope = check_data_scope

    async def __call__(self, auth: AuthSchema = Depends(get_current_user)) -> AuthSchema:
        """Run the permission check for the current request.

        Args:
            auth (AuthSchema): auth context produced by ``get_current_user``.

        Returns:
            AuthSchema: the same auth object, with ``check_data_scope`` updated.

        Raises:
            CustomException: 403 (code 10403) if the user has no roles or holds
            none of the required permissions.
        """
        auth.check_data_scope = self.check_data_scope
        user = auth.user

        # Superusers always pass.
        if user and user.is_superuser:
            return auth

        # No requirement, or a wildcard requirement, also passes.
        required = self.permissions
        if not required or "*" in required or "*:*:*" in required:
            return auth

        if not user or not user.roles:
            raise CustomException(msg="无权限操作", code=10403, status_code=403)

        # Collect permission strings from enabled menus of enabled roles.
        granted = set()
        for role in user.roles:
            for menu in role.menus:
                if role.status and menu.permission and menu.status:
                    granted.add(menu.permission)

        if any(perm in granted for perm in required):
            return auth

        log.error(f"用户缺少任何所需的权限: {self.permissions}")
        raise CustomException(msg="无权限操作", code=10403, status_code=403)