Files
----/后端源码/yifan.action-ai.cn/api-bak/app/core/middlewares.py

213 lines
7.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
import json
import time
from starlette.middleware.cors import CORSMiddleware
from starlette.types import ASGIApp
from starlette.requests import Request
from starlette.middleware.gzip import GZipMiddleware
from starlette.responses import Response
from starlette.middleware.base import BaseHTTPMiddleware
from app.common.response import ErrorResponse
from app.config.setting import settings
from app.core.logger import log
from app.core.exceptions import CustomException
from app.core.security import decode_access_token
from app.api.v1.module_system.params.service import ParamsService
class CustomCORSMiddleware(CORSMiddleware):
    """Cross-origin (CORS) middleware whose whole policy comes from ``settings``."""

    def __init__(self, app: ASGIApp) -> None:
        # Gather the CORS policy from the application settings, then hand it
        # to Starlette's CORSMiddleware in a single call.
        cors_options = dict(
            allow_origins=settings.ALLOW_ORIGINS,
            allow_methods=settings.ALLOW_METHODS,
            allow_headers=settings.ALLOW_HEADERS,
            allow_credentials=settings.ALLOW_CREDENTIALS,
            expose_headers=settings.CORS_EXPOSE_HEADERS,
        )
        super().__init__(app, **cors_options)
class RequestLogMiddleware:
"""
记录请求日志中间件
注意:使用纯 ASGI 中间件实现,避免 BaseHTTPMiddleware 缓冲流式响应的问题。
BaseHTTPMiddleware 会等待整个响应体完成后才返回,这会破坏流式响应的功能。
"""
def __init__(self, app: ASGIApp) -> None:
self.app = app
@staticmethod
def _extract_session_id_from_headers(headers: list) -> str | None:
"""
从请求头中提取session_id
参数:
- headers (list): ASGI 格式的请求头列表
返回:
- str | None: 会话ID如果无法提取则返回None
"""
try:
authorization = None
for key, value in headers:
if key == b'authorization':
authorization = value.decode('utf-8')
break
if not authorization:
return None
# 处理Bearer token
token = authorization.replace('Bearer ', '').strip()
# 解码token
payload = decode_access_token(token)
if not payload or not hasattr(payload, 'sub'):
return None
# 从payload中提取session_id
user_info = json.loads(payload.sub)
return user_info.get("session_id")
except Exception:
return None
@staticmethod
def _get_header_value(headers: list, key: bytes) -> str | None:
"""从头列表中获取指定的头值"""
for k, v in headers:
if k.lower() == key.lower():
return v.decode('utf-8')
return None
async def __call__(self, scope, receive, send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
start_time = time.time()
request = Request(scope, receive)
# 尝试提取session_id
session_id = self._extract_session_id_from_headers(scope.get('headers', []))
# 组装请求日志字段
client_host = scope.get('client', ['未知'])[0] if scope.get('client') else '未知'
log_fields = [
f"请求来源: {client_host}",
f"请求方法: {scope.get('method', 'UNKNOWN')}",
f"请求路径: {scope.get('path', '/')}",
]
log.info(log_fields)
# 获取请求路径
path = scope.get("path")
# 尝试获取客户端真实IP
headers = scope.get('headers', [])
x_forwarded_for = self._get_header_value(headers, b'x-forwarded-for')
if x_forwarded_for:
request_ip = x_forwarded_for.split(',')[0].strip()
else:
request_ip = client_host
# 检查是否启用演示模式
demo_enable = False
ip_white_list = []
white_api_list_path = []
ip_black_list = []
try:
# 从应用实例获取Redis连接
redis = request.app.state.redis
if redis:
# 使用ParamsService获取系统配置
system_config = await ParamsService.get_system_config_for_middleware(redis)
demo_enable = system_config["demo_enable"]
ip_white_list = system_config["ip_white_list"]
white_api_list_path = system_config["white_api_list_path"]
ip_black_list = system_config["ip_black_list"]
except Exception as e:
log.error(f"获取系统配置失败: {e}")
# 检查是否需要拦截请求
should_block = False
block_reason = ""
method = scope.get('method', '')
# 1. 首先检查IP是否在黑名单中
if request_ip and request_ip in ip_black_list:
should_block = True
block_reason = f"IP地址 {request_ip} 在黑名单中"
# 2. 如果不在黑名单中,检查是否在演示模式下需要拦截
elif demo_enable in ["true", "True"] and method != "GET":
is_ip_whitelisted = request_ip in ip_white_list
is_path_whitelisted = path in white_api_list_path
if not is_ip_whitelisted and not is_path_whitelisted:
should_block = True
block_reason = f"演示模式下拦截非GET请求IP: {request_ip}, 路径: {path}"
if should_block:
log.warning([
f"会话ID: {session_id or '未认证'}",
f"请求被拦截: {block_reason}",
f"请求来源: {request_ip}",
f"请求方法: {method}",
f"请求路径: {path}",
f"演示模式: {demo_enable}"
])
# 返回错误响应
response = ErrorResponse(msg="演示环境,禁止操作")
await response(scope, receive, send)
return
# 用于追踪响应状态
response_started = False
response_status = 0
async def send_wrapper(message):
nonlocal response_started, response_status
if message["type"] == "http.response.start":
response_started = True
response_status = message.get("status", 0)
# 计算处理时间并添加到响应头
process_time = round(time.time() - start_time, 5)
headers = list(message.get("headers", []))
headers.append((b"x-process-time", str(process_time).encode()))
message = {**message, "headers": headers}
elif message["type"] == "http.response.body":
# 如果是最后一个body chunk记录日志
if not message.get("more_body", False):
process_time = round(time.time() - start_time, 5)
log.info(
f"响应状态: {response_status}, "
f"处理时间: {round(process_time * 1000, 3)}ms"
)
await send(message)
try:
await self.app(scope, receive, send_wrapper)
except Exception as e:
log.error(f"中间件处理异常: {str(e)}")
if not response_started:
response = ErrorResponse(msg=f"系统异常,请联系管理员", data=str(e))
await response(scope, receive, send)
class CustomGZipMiddleware(GZipMiddleware):
    """GZip response-compression middleware configured from ``settings``."""

    def __init__(self, app: ASGIApp) -> None:
        # Responses smaller than the minimum size are left uncompressed by
        # Starlette's GZipMiddleware; both knobs come from the settings.
        min_size = settings.GZIP_MIN_SIZE
        level = settings.GZIP_COMPRESS_LEVEL
        super().__init__(app, minimum_size=min_size, compresslevel=level)