upload project source code
This commit is contained in:
213
后端源码/yifan.action-ai.cn/app/core/middlewares.py
Normal file
213
后端源码/yifan.action-ai.cn/app/core/middlewares.py
Normal file
@@ -0,0 +1,213 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import json
|
||||
import time
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.types import ASGIApp
|
||||
from starlette.requests import Request
|
||||
from starlette.middleware.gzip import GZipMiddleware
|
||||
from starlette.responses import Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.common.response import ErrorResponse
|
||||
from app.config.setting import settings
|
||||
from app.core.logger import log
|
||||
from app.core.exceptions import CustomException
|
||||
from app.core.security import decode_access_token
|
||||
from app.api.v1.module_system.params.service import ParamsService
|
||||
|
||||
|
||||
class CustomCORSMiddleware(CORSMiddleware):
|
||||
"""CORS跨域中间件"""
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
super().__init__(
|
||||
app,
|
||||
allow_origins=settings.ALLOW_ORIGINS,
|
||||
allow_methods=settings.ALLOW_METHODS,
|
||||
allow_headers=settings.ALLOW_HEADERS,
|
||||
allow_credentials=settings.ALLOW_CREDENTIALS,
|
||||
expose_headers=settings.CORS_EXPOSE_HEADERS,
|
||||
)
|
||||
|
||||
|
||||
class RequestLogMiddleware:
|
||||
"""
|
||||
记录请求日志中间件
|
||||
|
||||
注意:使用纯 ASGI 中间件实现,避免 BaseHTTPMiddleware 缓冲流式响应的问题。
|
||||
BaseHTTPMiddleware 会等待整个响应体完成后才返回,这会破坏流式响应的功能。
|
||||
"""
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
|
||||
@staticmethod
|
||||
def _extract_session_id_from_headers(headers: list) -> str | None:
|
||||
"""
|
||||
从请求头中提取session_id
|
||||
|
||||
参数:
|
||||
- headers (list): ASGI 格式的请求头列表
|
||||
|
||||
返回:
|
||||
- str | None: 会话ID,如果无法提取则返回None
|
||||
"""
|
||||
try:
|
||||
authorization = None
|
||||
for key, value in headers:
|
||||
if key == b'authorization':
|
||||
authorization = value.decode('utf-8')
|
||||
break
|
||||
|
||||
if not authorization:
|
||||
return None
|
||||
|
||||
# 处理Bearer token
|
||||
token = authorization.replace('Bearer ', '').strip()
|
||||
|
||||
# 解码token
|
||||
payload = decode_access_token(token)
|
||||
if not payload or not hasattr(payload, 'sub'):
|
||||
return None
|
||||
|
||||
# 从payload中提取session_id
|
||||
user_info = json.loads(payload.sub)
|
||||
return user_info.get("session_id")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_header_value(headers: list, key: bytes) -> str | None:
|
||||
"""从头列表中获取指定的头值"""
|
||||
for k, v in headers:
|
||||
if k.lower() == key.lower():
|
||||
return v.decode('utf-8')
|
||||
return None
|
||||
|
||||
async def __call__(self, scope, receive, send) -> None:
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
start_time = time.time()
|
||||
request = Request(scope, receive)
|
||||
|
||||
# 尝试提取session_id
|
||||
session_id = self._extract_session_id_from_headers(scope.get('headers', []))
|
||||
|
||||
# 组装请求日志字段
|
||||
client_host = scope.get('client', ['未知'])[0] if scope.get('client') else '未知'
|
||||
log_fields = [
|
||||
f"请求来源: {client_host}",
|
||||
f"请求方法: {scope.get('method', 'UNKNOWN')}",
|
||||
f"请求路径: {scope.get('path', '/')}",
|
||||
]
|
||||
log.info(log_fields)
|
||||
|
||||
# 获取请求路径
|
||||
path = scope.get("path")
|
||||
|
||||
# 尝试获取客户端真实IP
|
||||
headers = scope.get('headers', [])
|
||||
x_forwarded_for = self._get_header_value(headers, b'x-forwarded-for')
|
||||
if x_forwarded_for:
|
||||
request_ip = x_forwarded_for.split(',')[0].strip()
|
||||
else:
|
||||
request_ip = client_host
|
||||
|
||||
# 检查是否启用演示模式
|
||||
demo_enable = False
|
||||
ip_white_list = []
|
||||
white_api_list_path = []
|
||||
ip_black_list = []
|
||||
|
||||
try:
|
||||
# 从应用实例获取Redis连接
|
||||
redis = request.app.state.redis
|
||||
if redis:
|
||||
# 使用ParamsService获取系统配置
|
||||
system_config = await ParamsService.get_system_config_for_middleware(redis)
|
||||
demo_enable = system_config["demo_enable"]
|
||||
ip_white_list = system_config["ip_white_list"]
|
||||
white_api_list_path = system_config["white_api_list_path"]
|
||||
ip_black_list = system_config["ip_black_list"]
|
||||
except Exception as e:
|
||||
log.error(f"获取系统配置失败: {e}")
|
||||
|
||||
# 检查是否需要拦截请求
|
||||
should_block = False
|
||||
block_reason = ""
|
||||
method = scope.get('method', '')
|
||||
|
||||
# 1. 首先检查IP是否在黑名单中
|
||||
if request_ip and request_ip in ip_black_list:
|
||||
should_block = True
|
||||
block_reason = f"IP地址 {request_ip} 在黑名单中"
|
||||
|
||||
# 2. 如果不在黑名单中,检查是否在演示模式下需要拦截
|
||||
elif demo_enable in ["true", "True"] and method != "GET":
|
||||
is_ip_whitelisted = request_ip in ip_white_list
|
||||
is_path_whitelisted = path in white_api_list_path
|
||||
|
||||
if not is_ip_whitelisted and not is_path_whitelisted:
|
||||
should_block = True
|
||||
block_reason = f"演示模式下拦截非GET请求,IP: {request_ip}, 路径: {path}"
|
||||
|
||||
if should_block:
|
||||
log.warning([
|
||||
f"会话ID: {session_id or '未认证'}",
|
||||
f"请求被拦截: {block_reason}",
|
||||
f"请求来源: {request_ip}",
|
||||
f"请求方法: {method}",
|
||||
f"请求路径: {path}",
|
||||
f"演示模式: {demo_enable}"
|
||||
])
|
||||
# 返回错误响应
|
||||
response = ErrorResponse(msg="演示环境,禁止操作")
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
# 用于追踪响应状态
|
||||
response_started = False
|
||||
response_status = 0
|
||||
|
||||
async def send_wrapper(message):
|
||||
nonlocal response_started, response_status
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
response_started = True
|
||||
response_status = message.get("status", 0)
|
||||
|
||||
# 计算处理时间并添加到响应头
|
||||
process_time = round(time.time() - start_time, 5)
|
||||
headers = list(message.get("headers", []))
|
||||
headers.append((b"x-process-time", str(process_time).encode()))
|
||||
message = {**message, "headers": headers}
|
||||
|
||||
elif message["type"] == "http.response.body":
|
||||
# 如果是最后一个body chunk,记录日志
|
||||
if not message.get("more_body", False):
|
||||
process_time = round(time.time() - start_time, 5)
|
||||
log.info(
|
||||
f"响应状态: {response_status}, "
|
||||
f"处理时间: {round(process_time * 1000, 3)}ms"
|
||||
)
|
||||
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
await self.app(scope, receive, send_wrapper)
|
||||
except Exception as e:
|
||||
log.error(f"中间件处理异常: {str(e)}")
|
||||
if not response_started:
|
||||
response = ErrorResponse(msg=f"系统异常,请联系管理员", data=str(e))
|
||||
await response(scope, receive, send)
|
||||
|
||||
|
||||
class CustomGZipMiddleware(GZipMiddleware):
|
||||
"""GZip压缩中间件"""
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
super().__init__(
|
||||
app,
|
||||
minimum_size=settings.GZIP_MIN_SIZE,
|
||||
compresslevel=settings.GZIP_COMPRESS_LEVEL
|
||||
)
|
||||
Reference in New Issue
Block a user