213 lines
7.7 KiB
Python
213 lines
7.7 KiB
Python
# -*- coding: utf-8 -*-
|
||
|
||
import json
|
||
import time
|
||
from starlette.middleware.cors import CORSMiddleware
|
||
from starlette.types import ASGIApp
|
||
from starlette.requests import Request
|
||
from starlette.middleware.gzip import GZipMiddleware
|
||
from starlette.responses import Response
|
||
from starlette.middleware.base import BaseHTTPMiddleware
|
||
|
||
from app.common.response import ErrorResponse
|
||
from app.config.setting import settings
|
||
from app.core.logger import log
|
||
from app.core.exceptions import CustomException
|
||
from app.core.security import decode_access_token
|
||
from app.api.v1.module_system.params.service import ParamsService
|
||
|
||
|
||
class CustomCORSMiddleware(CORSMiddleware):
    """Cross-origin resource sharing (CORS) middleware.

    All options are read from ``settings`` when the middleware is constructed
    and forwarded unchanged to Starlette's ``CORSMiddleware``.
    """

    def __init__(self, app: ASGIApp) -> None:
        # Settings are read in the same order as the keyword arguments below.
        cors_options = {
            "allow_origins": settings.ALLOW_ORIGINS,
            "allow_methods": settings.ALLOW_METHODS,
            "allow_headers": settings.ALLOW_HEADERS,
            "allow_credentials": settings.ALLOW_CREDENTIALS,
            "expose_headers": settings.CORS_EXPOSE_HEADERS,
        }
        super().__init__(app, **cors_options)
|
||
|
||
|
||
class RequestLogMiddleware:
    """
    Request-logging middleware.

    For every HTTP request it:

    - logs the client address, method and path;
    - rejects requests from black-listed IPs;
    - in demo mode, rejects non-GET requests unless the IP or path is
      white-listed;
    - adds an ``x-process-time`` header (seconds until the response headers
      are sent);
    - logs the response status and total processing time after the final
      body chunk.

    NOTE: this is a pure ASGI middleware on purpose. BaseHTTPMiddleware waits
    for the whole response body before returning, which breaks streaming
    responses.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    @staticmethod
    def _extract_session_id_from_headers(headers: list) -> str | None:
        """
        Extract the session id from the request's Authorization header.

        Args:
            headers (list): ASGI request headers as ``(name, value)`` byte pairs.

        Returns:
            str | None: The ``session_id`` field of the JSON document stored in
            the token's ``sub`` claim, or None if it cannot be extracted.
        """
        try:
            authorization = RequestLogMiddleware._get_header_value(headers, b'authorization')
            if not authorization:
                return None

            # Accept both "Bearer <token>" and a bare token. Only strip a leading
            # scheme (case-insensitively); "Bearer " elsewhere in the value is
            # left untouched.
            value = authorization.strip()
            scheme, _, credentials = value.partition(' ')
            token = credentials.strip() if scheme.lower() == 'bearer' else value
            if not token:
                return None

            payload = decode_access_token(token)
            if not payload or not hasattr(payload, 'sub'):
                return None

            # The "sub" claim holds a JSON document with user/session info.
            user_info = json.loads(payload.sub)
            return user_info.get("session_id")
        except Exception:
            # Best-effort: the session id is only used for logging, so any
            # failure (token rejected by decode_access_token, malformed claim)
            # is treated as "not authenticated".
            return None

    @staticmethod
    def _get_header_value(headers: list, key: bytes) -> str | None:
        """
        Return the value of the first header named ``key`` (case-insensitive).

        Values are decoded as latin-1, the encoding the ASGI spec uses for
        header bytes. Unlike UTF-8 this never raises, so a malformed header
        sent by a client cannot crash the middleware.
        """
        wanted = key.lower()
        for name, value in headers:
            if name.lower() == wanted:
                return value.decode('latin-1')
        return None

    @staticmethod
    def _resolve_client_ip(headers: list, fallback: str) -> str:
        """Return the first ``X-Forwarded-For`` hop, or ``fallback`` when absent/empty."""
        # NOTE(review): X-Forwarded-For is client-controlled unless a trusted
        # proxy overwrites it, so the black/white-list checks based on it can
        # be spoofed. Consider honouring it only behind a known proxy.
        forwarded_for = RequestLogMiddleware._get_header_value(headers, b'x-forwarded-for')
        if forwarded_for:
            return forwarded_for.split(',')[0].strip()
        return fallback

    @staticmethod
    async def _load_system_config(request: Request) -> dict:
        """
        Load the demo-mode flag and IP/path lists used for request blocking.

        Falls back to permissive defaults (demo mode off, empty lists) when no
        Redis connection is available. If the lookup fails, the error is logged
        and any values already read are kept.
        """
        config = {
            "demo_enable": False,
            "ip_white_list": [],
            "white_api_list_path": [],
            "ip_black_list": [],
        }
        try:
            # Redis connection stored on the application state.
            redis = request.app.state.redis
            if redis:
                system_config = await ParamsService.get_system_config_for_middleware(redis)
                for key in config:
                    config[key] = system_config[key]
        except Exception as e:
            log.error(f"获取系统配置失败: {e}")
        return config

    @staticmethod
    def _get_block_reason(request_ip: str, method: str, path: str | None, config: dict) -> str | None:
        """
        Decide whether the request must be rejected.

        Returns:
            str | None: A human-readable reason if the request is blocked,
            otherwise None.
        """
        # 1. Black-listed IPs are always rejected.
        if request_ip and request_ip in config["ip_black_list"]:
            return f"IP地址 {request_ip} 在黑名单中"

        # 2. Demo mode: reject non-GET requests unless the IP or the path is
        #    white-listed. The flag may arrive as a string or a bool, so it is
        #    normalised before comparing.
        demo_enabled = str(config["demo_enable"]).strip().lower() == "true"
        if demo_enabled and method != "GET":
            is_ip_whitelisted = request_ip in config["ip_white_list"]
            is_path_whitelisted = path in config["white_api_list_path"]
            if not is_ip_whitelisted and not is_path_whitelisted:
                return f"演示模式下拦截非GET请求,IP: {request_ip}, 路径: {path}"
        return None

    async def __call__(self, scope, receive, send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # perf_counter is monotonic; time.time() can jump with clock changes.
        start_time = time.perf_counter()
        request = Request(scope, receive)
        headers = scope.get('headers', [])

        # Only used in the "request blocked" log entry.
        session_id = self._extract_session_id_from_headers(headers)

        client = scope.get('client')
        client_host = client[0] if client else '未知'
        log.info([
            f"请求来源: {client_host}",
            f"请求方法: {scope.get('method', 'UNKNOWN')}",
            f"请求路径: {scope.get('path', '/')}",
        ])

        path = scope.get("path")
        method = scope.get('method', '')
        request_ip = self._resolve_client_ip(headers, client_host)

        config = await self._load_system_config(request)
        block_reason = self._get_block_reason(request_ip, method, path, config)
        if block_reason is not None:
            log.warning([
                f"会话ID: {session_id or '未认证'}",
                f"请求被拦截: {block_reason}",
                f"请求来源: {request_ip}",
                f"请求方法: {method}",
                f"请求路径: {path}",
                f"演示模式: {config['demo_enable']}"
            ])
            response = ErrorResponse(msg="演示环境,禁止操作")
            await response(scope, receive, send)
            return

        # Tracks response progress for logging and error handling.
        response_started = False
        response_status = 0

        async def send_wrapper(message):
            nonlocal response_started, response_status

            if message["type"] == "http.response.start":
                response_started = True
                response_status = message.get("status", 0)

                # Time until the response headers are sent.
                process_time = round(time.perf_counter() - start_time, 5)
                response_headers = list(message.get("headers", []))
                response_headers.append((b"x-process-time", str(process_time).encode()))
                message = {**message, "headers": response_headers}

            elif message["type"] == "http.response.body" and not message.get("more_body", False):
                # Final body chunk: log the total processing time.
                process_time = round(time.perf_counter() - start_time, 5)
                log.info(
                    f"响应状态: {response_status}, "
                    f"处理时间: {round(process_time * 1000, 3)}ms"
                )

            await send(message)

        try:
            await self.app(scope, receive, send_wrapper)
        except Exception as e:
            log.error(f"中间件处理异常: {str(e)}")
            if response_started:
                # The headers are already on the wire, so no error response can
                # be sent. Re-raise so the server aborts the connection rather
                # than treating a truncated response as a success.
                raise
            # NOTE(review): str(e) is returned to the client and may expose
            # internal details; consider omitting it outside debug builds.
            response = ErrorResponse(msg="系统异常,请联系管理员", data=str(e))
            await response(scope, receive, send)
|
||
|
||
|
||
class CustomGZipMiddleware(GZipMiddleware):
    """GZip response-compression middleware.

    The minimum response size and the compression level are taken from
    ``settings`` and passed to Starlette's ``GZipMiddleware``.
    """

    def __init__(self, app: ASGIApp) -> None:
        min_size = settings.GZIP_MIN_SIZE
        level = settings.GZIP_COMPRESS_LEVEL
        super().__init__(app, minimum_size=min_size, compresslevel=level)