upload project source code

This commit is contained in:
2026-04-30 18:49:43 +08:00
commit 9b394ba682
2277 changed files with 660945 additions and 0 deletions

View File

@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-

View File

@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
from app.config.path_conf import BANNER_FILE
from app.core.logger import log
def worship(env: str) -> None:
"""
获取项目启动Banner优先读取 banner.txt
"""
if BANNER_FILE.exists():
banner = BANNER_FILE.read_text(encoding='utf-8')
banner = f"🚀 当前运行环境: {env}\n{banner}"
log.info(banner)

View File

@@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-
import base64
import random
import string
from io import BytesIO
from typing import Tuple
from PIL import Image, ImageDraw, ImageFont
from app.config.setting import settings
class CaptchaUtil:
"""
验证码工具类
"""
@classmethod
def generate_captcha(cls) -> Tuple[str, str]:
"""
生成带有噪声和干扰的验证码图片4位随机字符
返回:
- Tuple[str, str]: [base64图片字符串, 验证码值]。
"""
# 生成4位随机验证码
chars = string.digits + string.ascii_letters
captcha_value = ''.join(random.sample(chars, 4))
# 创建一张随机颜色背景的图片
width, height = 160, 60
background_color = tuple(random.randint(230, 255) for _ in range(3))
image = Image.new('RGB', (width, height), color=background_color)
draw = ImageDraw.Draw(image)
# 使用指定字体
font = ImageFont.truetype(font=settings.CAPTCHA_FONT_PATH, size=settings.CAPTCHA_FONT_SIZE)
# 计算文本总宽度和高度
total_width = sum(draw.textbbox((0, 0), char, font=font)[2] for char in captcha_value)
text_height = draw.textbbox((0, 0), captcha_value[0], font=font)[3]
# 计算起始位置,使文字居中
x_start = (width - total_width) / 2
y_start = (height - text_height) / 2 - draw.textbbox((0, 0), captcha_value[0], font=font)[1]
# 绘制字符
x = x_start
for char in captcha_value:
# 使用深色文字,增加对比度
text_color = tuple(random.randint(0, 80) for _ in range(3))
# 随机偏移,增加干扰
x_offset = x + random.uniform(-2, 2)
y_offset = y_start + random.uniform(-2, 2)
# 绘制字符
draw.text((x_offset, y_offset), char, font=font, fill=text_color)
# 更新x坐标,增加字符间距的随机性
x += draw.textbbox((0, 0), char, font=font)[2] + random.uniform(1, 5)
# 添加干扰线
for _ in range(4):
line_color = tuple(random.randint(150, 200) for _ in range(3))
points = [(i, int(random.uniform(0, height))) for i in range(0, width, 20)]
draw.line(points, fill=line_color, width=1)
# 添加随机噪点
for _ in range(width * height // 60):
point_color = tuple(random.randint(0, 255) for _ in range(3))
draw.point(
(random.randint(0, width), random.randint(0, height)),
fill=point_color
)
# 将图像数据保存到内存中并转换为base64
buffer = BytesIO()
image.save(buffer, format='PNG', optimize=True)
base64_string = base64.b64encode(buffer.getvalue()).decode()
return base64_string, captcha_value
@classmethod
def captcha_arithmetic(cls) -> Tuple[str, int]:
"""
创建验证码图片(加减乘运算)。
返回:
- Tuple[str, int]: [base64图片字符串, 计算结果]。
"""
# 创建空白图像,使用随机浅色背景
background_color = tuple(random.randint(230, 255) for _ in range(3))
image = Image.new('RGB', (160, 60), color=background_color)
draw = ImageDraw.Draw(image)
# 设置字体
font = ImageFont.truetype(font=settings.CAPTCHA_FONT_PATH, size=settings.CAPTCHA_FONT_SIZE)
# 生成运算数字和运算符
operators = ['+', '-', '*']
operator = random.choice(operators)
# 对于减法,确保num1大于num2
if operator == '-':
num1 = random.randint(6, 10)
num2 = random.randint(1, 5)
else:
num1 = random.randint(1, 9)
num2 = random.randint(1, 9)
# 计算结果
result_map = {
'+': lambda x, y: x + y,
'-': lambda x, y: x - y,
'*': lambda x, y: x * y
}
captcha_value = result_map[operator](num1, num2)
# 绘制文本,使用深色增加对比度
text = f'{num1} {operator} {num2} = ?'
text_bbox = draw.textbbox((0, 0), text, font=font)
text_width = text_bbox[2] - text_bbox[0]
x = (160 - text_width) // 2
draw.text((x, 15), text, fill=(0, 0, 139), font=font)
# 添加干扰线
for _ in range(3):
line_color = tuple(random.randint(150, 200) for _ in range(3))
draw.line([
(random.randint(0, 160), random.randint(0, 60)),
(random.randint(0, 160), random.randint(0, 60))
], fill=line_color, width=1)
# 将图像数据保存到内存中并转换为base64
buffer = BytesIO()
image.save(buffer, format='PNG', optimize=True)
base64_string = base64.b64encode(buffer.getvalue()).decode()
return base64_string, captcha_value

View File

@@ -0,0 +1,399 @@
# -*- coding: utf-8 -*-
import importlib
import re
import uuid
from pathlib import Path
from typing import Any, Literal, Sequence, Generator
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.engine.row import Row
from sqlalchemy.orm.collections import InstrumentedList
from sqlalchemy.sql.expression import TextClause, null
from app.config.setting import settings
from app.core.logger import log
from app.core.exceptions import CustomException
def import_module(module: str, desc: str) -> Any:
"""
动态导入模块
参数:
- module (str): 模块名称。
- desc (str): 模块描述。
返回:
- Any: 模块对象。
"""
try:
module_path, module_class = module.rsplit(".", 1)
module = importlib.import_module(module_path)
return getattr(module, module_class)
except ModuleNotFoundError:
log.error(f"❗️ 导入{desc}失败,未找到模块:{module}")
raise
except AttributeError:
log.error(f"❗ ️导入{desc}失败,未找到模块方法:{module}")
raise
async def import_modules_async(modules: list, desc: str, **kwargs) -> None:
"""
异步导入模块列表
参数:
- modules (list[str]): 模块列表。
- desc (str): 模块描述。
- kwargs: 额外参数。
返回:
- None
"""
for module in modules:
if not module:
continue
try:
module_path = module[0:module.rindex(".")]
module_name = module[module.rindex(".") + 1:]
module_obj = importlib.import_module(module_path)
await getattr(module_obj, module_name)(**kwargs)
except ModuleNotFoundError:
log.error(f"❌️ 导入{desc}失败,未找到模块:{module}")
raise
except AttributeError:
log.error(f"❌️ 导入{desc}失败,未找到模块方法:{module}")
raise
def get_random_character() -> str:
"""
生成随机字符串
返回:
- str: 随机字符串。
"""
return uuid.uuid4().hex
def uuid4_str() -> str:
"""数据库引擎 UUID 类型兼容性解决方案"""
return str(uuid.uuid4())
def get_parent_id_map(model_list: Sequence[DeclarativeBase]) -> dict[int, int]:
"""
获取父级 ID 映射字典
参数:
- model_list (Sequence[DeclarativeBase]): 模型列表。
返回:
- Dict[int, int]: {id: parent_id} 映射字典。
"""
return {item.id: item.parent_id for item in model_list}
def get_parent_recursion(id: int, id_map: dict[int, int], ids: list[int] | None = None) -> list[int]:
"""
递归获取所有父级 ID
参数:
- id (int): 当前 ID。
- id_map (dict[int, int]): ID 映射字典。
- ids (list[int] | None): 已收集的 ID 列表。
返回:
- list[int]: 所有父级 ID 列表。
"""
ids = ids or []
if id in ids:
raise CustomException(msg="递归获取父级ID失败,不可以自引用")
ids.append(id)
parent_id = id_map.get(id)
if parent_id:
get_parent_recursion(parent_id, id_map, ids)
return ids
def get_child_id_map(model_list: Sequence[DeclarativeBase]) -> dict[int, list[int]]:
"""
获取子级 ID 映射字典
参数:
- model_list (Sequence[DeclarativeBase]): 模型列表。
返回:
- Dict[int, List[int]]: {id: [child_ids]} 映射字典。
"""
data_map = {}
for model in model_list:
data_map.setdefault(model.id, [])
if model.parent_id:
data_map.setdefault(model.parent_id, []).append(model.id)
return data_map
def get_child_recursion(id: int, id_map: dict[int, list[int]], ids: list[int] | None= None) -> list[int]:
"""
递归获取所有子级 ID
参数:
- id (int): 当前 ID。
- id_map (dict[int, list[int]]): ID 映射字典。
- ids (list[int] | None): 已收集的 ID 列表。
返回:
- list[int]: 所有子级 ID 列表。
"""
ids = ids or []
ids.append(id)
for child in id_map.get(id, []):
get_child_recursion(child, id_map, ids)
return ids
def traversal_to_tree(nodes: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
通过遍历算法构造树形结构
参数:
- nodes (list[dict[str, Any]]): 树节点列表。
返回:
- list[dict[str, Any]]: 构造后的树形结构列表。
"""
tree: list[dict[str, Any]] = []
node_dict = {node['id']: node for node in nodes}
for node in nodes:
# 确保每个节点都有children字段即使没有子节点也设置为null
if 'children' not in node:
node['children'] = None
parent_id = node['parent_id']
if parent_id is None:
tree.append(node)
else:
parent_node = node_dict.get(parent_id)
if parent_node is not None:
if 'children' not in parent_node or parent_node['children'] is None:
parent_node['children'] = []
if node not in parent_node['children']:
parent_node['children'].append(node)
else:
if node not in tree:
tree.append(node)
# 确保所有节点都有children字段
for node in tree:
if 'children' not in node:
node['children'] = None
return tree
def recursive_to_tree(nodes: list[dict[str, Any]], *, parent_id: int | None = None) -> list[dict[str, Any]]:
"""
通过递归算法构造树形结构(性能影响较大)
参数:
- nodes (list[dict[str, Any]]): 树节点列表。
- parent_id (int | None): 父节点 ID,默认为 None 表示根节点。
返回:
- list[dict[str, Any]]: 构造后的树形结构列表。
"""
tree: list[dict[str, Any]] = []
for node in nodes:
if node['parent_id'] == parent_id:
child_nodes = recursive_to_tree(nodes, parent_id=node['id'])
if child_nodes:
node['children'] = child_nodes
tree.append(node)
return tree
def bytes2human(n: int, format_str: str = '%(value).1f%(symbol)s') -> str:
"""
字节数转人类可读格式
Used by various scripts. See:
http://goo.gl/zeJZl
>>> bytes2human(10000)
'9.8K'
>>> bytes2human(100001221)
'95.4M'
参数:
- n (int): 字节数。
- format_str (str): 格式化字符串,默认 '%(value).1f%(symbol)s'
返回:
- str: 可读的字节字符串,如 '1.5MB'
"""
symbols = ('B', 'KB', 'MB', 'GB', 'TB', 'PB', 'EB', 'ZB', 'YB')
prefix = {s: 1 << (i + 1) * 10 for i, s in enumerate(symbols[1:])}
for symbol in reversed(symbols[1:]):
if n >= prefix[symbol]:
value = float(n) / prefix[symbol]
return format_str % locals()
return format_str % dict(symbol=symbols[0], value=n)
def bytes2file_response(bytes_info: bytes) -> Generator[bytes, Any, None]:
"""生成文件响应"""
yield bytes_info
def get_filepath_from_url(url: str) -> Path:
"""
工具方法:根据请求参数获取文件路径
参数:
- url (str): 请求参数中的 url 参数。
返回:
- Path: 文件路径。
"""
file_info = url.split('?')[1].split('&')
task_id = file_info[0].split('=')[1]
file_name = file_info[1].split('=')[1]
task_path = file_info[2].split('=')[1]
filepath = settings.STATIC_ROOT.joinpath(task_path, task_id, file_name)
return filepath
class SqlalchemyUtil:
"""
sqlalchemy工具类
"""
@classmethod
def base_to_dict(
cls, obj: DeclarativeBase | dict[str, Any], transform_case: Literal['no_case', 'snake_to_camel', 'camel_to_snake'] = 'no_case'
):
"""
将sqlalchemy模型对象转换为字典
:param obj: sqlalchemy模型对象或普通字典
:param transform_case: 转换得到的结果形式,可选的有'no_case'(不转换)、'snake_to_camel'(下划线转小驼峰)、'camel_to_snake'(小驼峰转下划线),默认为'no_case'
:return: 字典结果
"""
if isinstance(obj, DeclarativeBase):
base_dict = obj.__dict__.copy()
base_dict.pop('_sa_instance_state', None)
for name, value in base_dict.items():
if isinstance(value, InstrumentedList):
base_dict[name] = cls.serialize_result(value, 'snake_to_camel')
elif isinstance(obj, dict):
base_dict = obj.copy()
if transform_case == 'snake_to_camel':
return {CamelCaseUtil.snake_to_camel(k): v for k, v in base_dict.items()}
elif transform_case == 'camel_to_snake':
return {SnakeCaseUtil.camel_to_snake(k): v for k, v in base_dict.items()}
return base_dict
@classmethod
def serialize_result(
cls, result: Any, transform_case: Literal['no_case', 'snake_to_camel', 'camel_to_snake'] = 'no_case'
):
"""
将sqlalchemy查询结果序列化
:param result: sqlalchemy查询结果
:param transform_case: 转换得到的结果形式,可选的有'no_case'(不转换)、'snake_to_camel'(下划线转小驼峰)、'camel_to_snake'(小驼峰转下划线),默认为'no_case'
:return: 序列化结果
"""
if isinstance(result, (DeclarativeBase, dict)):
return cls.base_to_dict(result, transform_case)
elif isinstance(result, list):
return [cls.serialize_result(row, transform_case) for row in result]
elif isinstance(result, Row):
if all([isinstance(row, DeclarativeBase) for row in result]):
return [cls.base_to_dict(row, transform_case) for row in result]
elif any([isinstance(row, DeclarativeBase) for row in result]):
return [cls.serialize_result(row, transform_case) for row in result]
else:
result_dict = result._asdict()
if transform_case == 'snake_to_camel':
return {CamelCaseUtil.snake_to_camel(k): v for k, v in result_dict.items()}
elif transform_case == 'camel_to_snake':
return {SnakeCaseUtil.camel_to_snake(k): v for k, v in result_dict.items()}
return result_dict
return result
@classmethod
def get_server_default_null(cls, dialect_name: str, need_explicit_null: bool = True) -> TextClause | None:
"""
根据数据库方言动态返回值为null的server_default
:param dialect_name: 数据库方言名称
:param need_explicit_null: 是否需要显式DEFAULT NULL
:return: 不同数据库方言对应的null_server_default
"""
if need_explicit_null and dialect_name == 'postgres':
return null()
return None
class CamelCaseUtil:
"""
下划线形式(snake_case)转小驼峰形式(camelCase)工具方法
"""
@classmethod
def snake_to_camel(cls, snake_str: str):
"""
下划线形式字符串(snake_case)转换为小驼峰形式字符串(camelCase)
:param snake_str: 下划线形式字符串
:return: 小驼峰形式字符串
"""
# 分割字符串
words = snake_str.split('_')
# 小驼峰命名,第一个词首字母小写,其余词首字母大写
# return words[0] + ''.join(word.capitalize() for word in words[1:])
# 大驼峰命名,所有词首字母大写
return ''.join(word.capitalize() for word in words)
@classmethod
def transform_result(cls, result: Any):
"""
针对不同类型将下划线形式(snake_case)批量转换为小驼峰形式(camelCase)方法
:param result: 输入数据
:return: 小驼峰形式结果
"""
return SqlalchemyUtil.serialize_result(result=result, transform_case='snake_to_camel')
class SnakeCaseUtil:
"""
小驼峰形式(camelCase)转下划线形式(snake_case)工具方法
"""
@classmethod
def camel_to_snake(cls, camel_str: str):
"""
小驼峰形式字符串(camelCase)转换为下划线形式字符串(snake_case)
:param camel_str: 小驼峰形式字符串
:return: 下划线形式字符串
"""
# 在大写字母前添加一个下划线,然后将整个字符串转为小写
words = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', camel_str)
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', words).lower()
@classmethod
def transform_result(cls, result: Any):
"""
针对不同类型将下划线形式(snake_case)批量转换为小驼峰形式(camelCase)方法
:param result: 输入数据
:return: 小驼峰形式结果
"""
return SqlalchemyUtil.serialize_result(result=result, transform_case='camel_to_snake')

View File

@@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from rich import get_console
from rich.panel import Panel
from rich.text import Text
from rich.console import Group
from app.config.setting import settings
console = get_console()
def create_service_panel(
host: str, port: int, reload: bool, *,
redis_ready: Optional[bool] = None,
scheduler_jobs: Optional[int] = None,
scheduler_status: Optional[str] = None,
) -> Panel:
"""创建简洁的服务启动信息面板"""
url = f'http://{host}:{port}'
base_url = f'{url}{settings.ROOT_PATH}'
docs_url = base_url + settings.DOCS_URL
redoc_url = base_url + settings.REDOC_URL
# 核心服务信息
service_info = Text()
service_info.append(f"服务名称 {settings.TITLE} • 优雅 • 简洁 • 高效", style="bold magenta")
service_info.append(f"\n当前版本 v{settings.VERSION}" , style="bold green")
service_info.append(f"\n服务地址 {url}", style="bold blue")
service_info.append(f"\n运行环境 {settings.ENVIRONMENT.value if hasattr(settings.ENVIRONMENT, 'value') else settings.ENVIRONMENT}", style="bold red")
service_info.append(f"\n重载配置: {'✅ 开启' if reload else '❌ 关闭'}", style="bold italic")
service_info.append(f"\n调试模式: {'✅ 开启' if settings.DEBUG else '❌ 关闭'}", style="bold italic")
service_info.append(f"\n数据库类型: {settings.DATABASE_TYPE} 数据库", style="bold italic")
service_info.append(f"\nRedis: {'✅ 已连接' if redis_ready else '❌ 未连接'}", style="bold italic")
service_info.append(f"\n定时任务 {'✅ 运行中' if scheduler_status == 'running' else '⏸️ 暂停'} {scheduler_jobs}", style="bold italic")
docs_info = Text()
docs_info.append("📖 文档", style="bold magenta")
docs_info.append(f"\n🔗 Swagger: {docs_url}", style="blue link")
docs_info.append(f"\n🔗 ReDoc: {redoc_url}", style="blue link")
final_content = Group(
service_info,
"\n" + "" * 40,
docs_info,
)
return Panel(
renderable=final_content,
title="[bold purple]🚀 服务启动完成[/]",
border_style="green",
padding=(1, 2)
)
def run(host: str, port: int, reload: bool, *,
redis_ready: Optional[bool] = None,
scheduler_jobs: Optional[int] = None,
scheduler_status: Optional[str] = None
) -> None:
"""显示启动信息面板"""
# 创建并显示启动面板
service_panel = create_service_panel(
host=host,
port=port,
reload=reload,
redis_ready=redis_ready,
scheduler_jobs=scheduler_jobs,
scheduler_status=scheduler_status,
)
console.print(service_panel)
def display_shutdown_info():
"""显示关闭信息"""
shutdown_content = Text()
shutdown_content.append("🛑 ", style="bold red")
shutdown_content.append("FastapiAdmin 服务关闭")
shutdown_content.append(f"\n{datetime.now().strftime('%H:%M:%S')}")
shutdown_content.append("\n👋 感谢使用!", style="dim")
shutdown_panel = Panel(
shutdown_content,
title="[bold red]服务关闭[/]",
border_style="red",
padding=(1, 2)
)
console.print(shutdown_panel)

View File

@@ -0,0 +1,218 @@
# -*- coding: utf-8 -*-
import re
from datetime import datetime
class CronUtil:
"""
Cron表达式工具类
"""
@classmethod
def __valid_range(cls, search_str: str, start_range: int, end_range: int) -> bool:
"""
校验范围表达式的合法性。
参数:
- search_str (str): 范围表达式。
- start_range (int): 开始范围。
- end_range (int): 结束范围。
返回:
- bool: 校验是否通过。
"""
match = re.match(r'^(\d+)-(\d+)$', search_str)
if match:
start, end = int(match.group(1)), int(match.group(2))
return start_range <= start < end <= end_range
return False
@classmethod
def __valid_sum(cls, search_str: str, start_range_a: int, start_range_b: int, end_range_a: int, end_range_b: int, sum_range: int) -> bool:
"""
校验和表达式的合法性。
参数:
- search_str (str): 和表达式。
- start_range_a (int): 开始范围A。
- start_range_b (int): 开始范围B。
- end_range_a (int): 结束范围A。
- end_range_b (int): 结束范围B。
- sum_range (int): 总和范围。
返回:
- bool: 校验是否通过。
"""
match = re.match(r'^(\d+)/(\d+)$', search_str)
if match:
start, end = int(match.group(1)), int(match.group(2))
return (
start_range_a <= start <= start_range_b
and end_range_a <= end <= end_range_b
and start + end <= sum_range
)
return False
@classmethod
def validate_second_or_minute(cls, second_or_minute: str) -> bool:
"""
校验秒或分钟字段的合法性。
参数:
- second_or_minute (str): 秒或分钟值。
返回:
- bool: 校验是否通过。
"""
if (
second_or_minute == '*'
or ('-' in second_or_minute and cls.__valid_range(second_or_minute, 0, 59))
or ('/' in second_or_minute and cls.__valid_sum(second_or_minute, 0, 58, 1, 59, 59))
or re.match(r'^(?:[0-5]?\d|59)(?:,[0-5]?\d|59)*$', second_or_minute)
):
return True
return False
@classmethod
def validate_hour(cls, hour: str) -> bool:
"""
校验小时字段的合法性。
参数:
- hour (str): 小时值。
返回:
- bool: 校验是否通过。
"""
if (
hour == '*'
or ('-' in hour and cls.__valid_range(hour, 0, 23))
or ('/' in hour and cls.__valid_sum(hour, 0, 22, 1, 23, 23))
or re.match(r'^(?:0|[1-9]|1\d|2[0-3])(?:,(?:0|[1-9]|1\d|2[0-3]))*$', hour)
):
return True
return False
@classmethod
def validate_day(cls, day: str) -> bool:
"""
校验日期字段的合法性。
参数:
- day (str): 日值。
返回:
- bool: 校验是否通过。
"""
if (
day in ['*', '?', 'L']
or ('-' in day and cls.__valid_range(day, 1, 31))
or ('/' in day and cls.__valid_sum(day, 1, 30, 1, 30, 31))
or ('W' in day and re.match(r'^(?:[1-9]|1\d|2\d|3[01])W$', day))
or re.match(r'^(?:0|[1-9]|1\d|2[0-9]|3[0-1])(?:,(?:0|[1-9]|1\d|2[0-9]|3[0-1]))*$', day)
):
return True
return False
@classmethod
def validate_month(cls, month: str) -> bool:
"""
校验月份字段的合法性。
参数:
- month (str): 月值。
返回:
- bool: 校验是否通过。
"""
if (
month == '*'
or ('-' in month and cls.__valid_range(month, 1, 12))
or ('/' in month and cls.__valid_sum(month, 1, 11, 1, 11, 12))
or re.match(r'^(?:0|[1-9]|1[0-2])(?:,(?:0|[1-9]|1[0-2]))*$', month)
):
return True
return False
@classmethod
def validate_week(cls, week: str) -> bool:
"""
校验星期字段的合法性。
参数:
- week (str): 周值。
返回:
- bool: 校验是否通过。
"""
if (
week in ['*', '?']
or ('-' in week and cls.__valid_range(week, 1, 7))
or ('#' in week and re.match(r'^[1-7]#[1-4]$', week))
or ('L' in week and re.match(r'^[1-7]L$', week))
or re.match(r'^[1-7](?:(,[1-7]))*$', week)
):
return True
return False
@classmethod
def validate_year(cls, year: str) -> bool:
"""
校验年份字段的合法性。
参数:
- year (str): 年值。
返回:
- bool: 校验是否通过。
"""
current_year = int(datetime.now().year)
future_years = [current_year + i for i in range(9)]
if (
year == '*'
or ('-' in year and cls.__valid_range(year, current_year, 2099))
or ('/' in year and cls.__valid_sum(year, current_year, 2098, 1, 2099 - current_year, 2099))
or ('#' in year and re.match(r'^[1-7]#[1-4]$', year))
or ('L' in year and re.match(r'^[1-7]L$', year))
or (
(len(year) == 4 or ',' in year)
and all(int(item) in future_years and current_year <= int(item) <= 2099 for item in year.split(','))
)
):
return True
return False
@classmethod
def validate_cron_expression(cls, cron_expression: str) -> bool | None:
"""
校验 Cron 表达式是否正确。
参数:
- cron_expression (str): Cron 表达式。
返回:
- bool | None: 校验是否通过。
"""
values = cron_expression.split()
if len(values) != 6 and len(values) != 7:
return False
second_validation = cls.validate_second_or_minute(values[0])
minute_validation = cls.validate_second_or_minute(values[1])
hour_validation = cls.validate_hour(values[2])
day_validation = cls.validate_day(values[3])
month_validation = cls.validate_month(values[4])
week_validation = cls.validate_week(values[5])
validation = (
second_validation
and minute_validation
and hour_validation
and day_validation
and month_validation
and week_validation
)
if len(values) == 6:
return validation
if len(values) == 7:
year_validation = cls.validate_year(values[6])
return validation and year_validation

View File

@@ -0,0 +1,99 @@
# -*- coding: utf-8 -*-
import io
import pandas as pd
from typing import Any
from openpyxl import Workbook
from openpyxl.utils import get_column_letter
from openpyxl.styles import Alignment, PatternFill
from openpyxl.worksheet.datavalidation import DataValidation
class ExcelUtil:
"""Excel文件处理工具类"""
@classmethod
def __mapping_list(cls, list_data: list[dict[str, Any]], mapping_dict: dict) -> list:
"""
工具方法:将列表数据中的字段名映射为对应的中文字段名。
参数:
- list_data (list[dict[str, Any]]): 数据列表。
- mapping_dict (dict): 字段名映射字典。
返回:
- list: 映射后的数据列表。
"""
mapping_data = [{mapping_dict.get(key): item.get(key) for key in mapping_dict} for item in list_data]
return mapping_data
@classmethod
def get_excel_template(cls, header_list: list[str], selector_header_list: list[str], option_list: list[dict[str, list[str]]]) -> bytes:
"""
生成 Excel 模板文件。
参数:
- header_list (list[str]): 表头列表。
- selector_header_list (list[str]): 需要设置下拉选择的表头列表。
- option_list (list[dict[str, list[str]]]): 下拉选项配置列表。
返回:
- bytes: Excel 文件的二进制数据。
"""
wb = Workbook()
ws = wb.active
if not ws:
raise ValueError("不存在活动工作表")
# 设置表头样式
header_fill = PatternFill(start_color='ababab', end_color='ababab', fill_type='solid')
# 写入表头
for col_num, header in enumerate(header_list, 1):
cell = ws.cell(row=1, column=col_num)
cell.value = header
cell.fill = header_fill
# 设置水平居中对齐
cell.alignment = Alignment(horizontal='center')
# 设置列宽度为16
ws.column_dimensions[get_column_letter(col_num)].width = 12
# 设置下拉选择
for selector_header in selector_header_list:
col_idx = header_list.index(selector_header) + 1
# 获取当前表头的选项列表
header_options = next((opt.get(selector_header) for opt in option_list if selector_header in opt), [])
if header_options:
dv = DataValidation(type='list', formula1=f'"{",".join(header_options)}"')
dv.add(f'{get_column_letter(col_idx)}2:{get_column_letter(col_idx)}1048576')
ws.add_data_validation(dv)
# 导出为二进制数据
buffer = io.BytesIO()
wb.save(buffer)
buffer.seek(0)
# 读取字节数据
excel_data = buffer.getvalue()
return excel_data
@classmethod
def export_list2excel(cls, list_data: list[dict[str, Any]], mapping_dict: dict) -> bytes:
"""
将列表数据导出为 Excel 文件。
参数:
- list_data (list[dict[str, Any]]): 要导出的数据列表。
- mapping_dict (dict): 字段名映射字典。
返回:
- bytes: Excel 文件的二进制数据。
"""
mapping_data = cls.__mapping_list(list_data, mapping_dict)
df = pd.DataFrame(mapping_data)
buffer = io.BytesIO()
df.to_excel(buffer, index=False, engine='openpyxl')
binary_data = buffer.getvalue()
return binary_data

View File

@@ -0,0 +1,210 @@
# -*- coding: utf-8 -*-
import hashlib
import os
from typing import Any
from passlib.context import CryptContext
from cryptography.hazmat.backends.openssl import backend
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from itsdangerous import URLSafeSerializer
from app.core.logger import log
# 密码加密配置
PwdContext = CryptContext(
schemes=["bcrypt"],
deprecated="auto",
bcrypt__rounds=12 # 设置加密轮数,增加安全性
)
class PwdUtil:
"""
密码工具类,提供密码加密和验证功能
"""
@classmethod
def verify_password(cls, plain_password: str, password_hash: str) -> bool:
"""
校验密码是否匹配
参数:
- plain_password (str): 明文密码。
- password_hash (str): 加密后的密码哈希值。
返回:
- bool: 密码是否匹配。
"""
return PwdContext.verify(plain_password, password_hash)
@classmethod
def set_password_hash(cls, password: str) -> str:
"""
对密码进行加密
参数:
- password (str): 明文密码。
返回:
- str: 加密后的密码哈希值。
"""
return PwdContext.hash(password)
@classmethod
def check_password_strength(cls, password: str) -> str | None:
"""
检查密码强度
参数:
- password (str): 明文密码。
返回:
- str | None: 如果密码强度不够返回提示信息,否则返回None。
"""
if len(password) < 6:
return "密码长度至少6位"
if not any(c.isupper() for c in password):
return "密码需要包含大写字母"
if not any(c.islower() for c in password):
return "密码需要包含小写字母"
if not any(c.isdigit() for c in password):
return "密码需要包含数字"
return None
class AESCipher:
"""AES 加密器"""
def __init__(self, key: bytes | str) -> None:
"""
初始化 AES 加密器。
参数:
- key (bytes | str): 密钥16/24/32 bytes 或 16 进制字符串。
返回:
- None
"""
self.key = key if isinstance(key, bytes) else bytes.fromhex(key)
def encrypt(self, plaintext: bytes | str) -> bytes:
"""
AES 加密。
参数:
- plaintext (bytes | str): 加密前的明文。
返回:
- bytes: 加密后的密文前16字节为随机IV
"""
if not isinstance(plaintext, bytes):
plaintext = str(plaintext).encode('utf-8')
iv = os.urandom(16)
cipher = Cipher(algorithms.AES(self.key), modes.CBC(iv), backend=backend)
encryptor = cipher.encryptor()
padder = padding.PKCS7(cipher.algorithm.block_size).padder() # type: ignore
padded_plaintext = padder.update(plaintext) + padder.finalize()
ciphertext = encryptor.update(padded_plaintext) + encryptor.finalize()
return iv + ciphertext
def decrypt(self, ciphertext: bytes | str) -> str:
"""
AES 解密。
参数:
- ciphertext (bytes | str): 解密前的密文bytes 或 16 进制字符串。
返回:
- str: 解密后的明文。
"""
ciphertext = ciphertext if isinstance(ciphertext, bytes) else bytes.fromhex(ciphertext)
iv = ciphertext[:16]
ciphertext = ciphertext[16:]
cipher = Cipher(algorithms.AES(self.key), modes.CBC(iv), backend=backend)
decryptor = cipher.decryptor()
unpadder = padding.PKCS7(cipher.algorithm.block_size).unpadder() # type: ignore
padded_plaintext = decryptor.update(ciphertext) + decryptor.finalize()
plaintext = unpadder.update(padded_plaintext) + unpadder.finalize()
return plaintext.decode('utf-8')
class Md5Cipher:
"""MD5 加密器"""
@staticmethod
def encrypt(plaintext: bytes | str) -> str:
"""
MD5 加密。
参数:
- plaintext (bytes | str): 加密前的明文。
返回:
- str: MD5 十六进制摘要。
"""
md5 = hashlib.md5()
if not isinstance(plaintext, bytes):
plaintext = str(plaintext).encode('utf-8')
md5.update(plaintext)
return md5.hexdigest()
class ItsDCipher:
"""ItsDangerous 加密器"""
def __init__(self, key: bytes | str) -> None:
"""
初始化 ItsDangerous 加密器。
参数:
- key (bytes | str): 密钥16/24/32 bytes 或 16 进制字符串。
返回:
- None
"""
self.key = key if isinstance(key, bytes) else bytes.fromhex(key)
def encrypt(self, plaintext: Any) -> str:
"""
ItsDangerous 加密。
参数:
- plaintext (Any): 加密前的明文。
返回:
- str: 加密后的密文URL安全
异常:
- Exception: 加密失败时使用 MD5 作为降级,错误已记录。
"""
serializer = URLSafeSerializer(self.key)
try:
ciphertext = serializer.dumps(plaintext)
except Exception as e:
log.error(f'ItsDangerous encrypt failed: {e}')
ciphertext = Md5Cipher.encrypt(plaintext)
return ciphertext
def decrypt(self, ciphertext: str) -> Any:
"""
ItsDangerous 解密。
参数:
- ciphertext (str): 解密前的密文。
返回:
- Any: 解密后的明文;失败时返回原密文。
异常:
- Exception: 解密失败时记录错误并返回原密文。
"""
serializer = URLSafeSerializer(self.key)
try:
plaintext = serializer.loads(ciphertext)
except Exception as e:
log.error(f'ItsDangerous decrypt failed: {e}')
plaintext = ciphertext
return plaintext

View File

@@ -0,0 +1,185 @@
# -*- coding: utf-8 -*-
import importlib
import inspect
import os
from pathlib import Path
from functools import lru_cache
from sqlalchemy import inspect as sa_inspect
from typing import Any, Type
from app.config.path_conf import BASE_DIR
class ImportUtil:
@classmethod
def find_project_root(cls) -> Path:
"""
查找项目根目录
:return: 项目根目录路径
"""
return BASE_DIR
@classmethod
def is_valid_model(cls, obj: Any, base_class: Type) -> bool:
"""
验证是否为有效的SQLAlchemy模型类
:param obj: 待验证的对象
:param base_class: SQLAlchemy的基类
:return: 验证结果
"""
# 必须继承自base_class且不是base_class本身
if not (inspect.isclass(obj) and issubclass(obj, base_class) and obj is not base_class):
return False
# 必须有表名定义(排除抽象基类)
if not hasattr(obj, '__tablename__') or obj.__tablename__ is None:
return False
# 必须有至少一个列定义
try:
return len(sa_inspect(obj).columns) > 0
except Exception:
return False
@classmethod
@lru_cache(maxsize=256)
def find_models(cls, base_class: Type) -> list[Any]:
"""
查找并过滤有效的模型类,避免重复和无效定义
:param base_class: SQLAlchemy的Base类用于验证模型类
:return: 有效模型类列表
"""
models = []
# 按类对象去重
seen_models = set()
# 按表名去重(防止同表名冲突)
seen_tables = set()
# 记录已经处理过的model.py文件路径
processed_model_files = set()
project_root = cls.find_project_root()
print(f"⏰️ 开始在项目根目录 {project_root} 中查找模型...")
# 排除目录扩展
exclude_dirs = {
'venv',
'.env',
'.git',
'__pycache__',
'migrations',
'alembic',
'tests',
'test',
'docs',
'examples',
'scripts',
'.venv',
'__pycache__',
'static',
'templates',
'sql',
'env'
}
# 定义要搜索的模型目录模式
model_dir_patterns = [
'model.py',
'models.py'
]
# 使用一个更高效的方法来查找所有model.py文件
model_files = []
for root, dirs, files in os.walk(project_root):
# 过滤排除目录
dirs[:] = [d for d in dirs if d not in exclude_dirs]
for file in files:
if file in model_dir_patterns:
file_path = Path(root) / file
# 构建相对于项目根的模块路径
relative_path = file_path.relative_to(project_root)
model_files.append((file_path, relative_path))
print(f"🔍 找到 {len(model_files)} 个模型文件")
# 按模块路径排序,确保先导入基础模块
model_files.sort(key=lambda x: str(x[1]))
for file_path, relative_path in model_files:
# 确保文件路径没有被处理过
if str(file_path) in processed_model_files:
continue
processed_model_files.add(str(file_path))
# 构建模块名(将路径分隔符转换为点)
module_parts = relative_path.parts[:-1] + (relative_path.stem,)
module_name = '.'.join(module_parts)
try:
# 导入模块
module = importlib.import_module(module_name)
# 获取模块中的所有类
for name, obj in inspect.getmembers(module, inspect.isclass):
# 验证模型有效性
if not cls.is_valid_model(obj, base_class):
continue
# 检查类对象重复
if obj in seen_models:
continue
# 检查表名重复
table_name = obj.__tablename__
if table_name in seen_tables:
continue
# 添加到已处理集合
seen_models.add(obj)
seen_tables.add(table_name)
models.append(obj)
print(f'✅️ 找到有效模型: {obj.__module__}.{obj.__name__} (表: {table_name})')
except ImportError as e:
if 'cannot import name' not in str(e):
print(f'❗️ 警告: 无法导入模块 {module_name}: {e}')
except Exception as e:
print(f'❌️ 处理模块 {module_name} 时出错: {e}')
# 查找apscheduler_jobs表的模型如果存在
cls._find_apscheduler_model(base_class, models, seen_models, seen_tables)
return models
@classmethod
def _find_apscheduler_model(cls, base_class: Type, models: list[Any], seen_models: set[Any], seen_tables: set[str]):
"""
专门查找APScheduler相关的模型
:param base_class: SQLAlchemy的Base类
:param models: 模型列表
:param seen_models: 已处理的模型集合
:param seen_tables: 已处理的表名集合
"""
# 尝试从apscheduler相关模块导入
try:
# 检查是否有自定义的apscheduler模型
for module_name in ['app.core.ap_scheduler', 'app.module_task.scheduler_test']:
try:
module = importlib.import_module(module_name)
for name, obj in inspect.getmembers(module, inspect.isclass):
if cls.is_valid_model(obj, base_class) and hasattr(obj, '__tablename__') and obj.__tablename__ == 'apscheduler_jobs':
if obj not in seen_models and 'apscheduler_jobs' not in seen_tables:
seen_models.add(obj)
seen_tables.add('apscheduler_jobs')
models.append(obj)
print(f'✅️ 找到有效模型: {obj.__module__}.{obj.__name__} (表: apscheduler_jobs)')
except ImportError:
pass
except Exception as e:
print(f'❗️ 查找APScheduler模型时出错: {e}')

View File

@@ -0,0 +1,106 @@
# -*- coding: utf-8 -*-
import re
import httpx
from app.core.logger import log
class IpLocalUtil:
"""
获取IP归属地工具类
"""
@classmethod
def is_valid_ip(cls, ip: str) -> bool:
"""
校验IP格式是否合法。
参数:
- ip (str): IP地址。
返回:
- bool: 是否合法。
"""
ip_pattern = r'^((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$'
return bool(re.match(ip_pattern, ip))
@classmethod
def is_private_ip(cls, ip: str) -> bool:
"""
判断是否为内网IP。
参数:
- ip (str): IP地址。
返回:
- bool: 是否为内网IP。
"""
priv_pattern = r'^(127\.|10\.|172\.(1[6-9]|2[0-9]|3[01])\.|192\.168\.)'
return bool(re.match(priv_pattern, ip))
@classmethod
async def get_ip_location(cls, ip: str) -> str | None:
"""
获取IP归属地信息。
参数:
- ip (str): IP地址。
返回:
- str | None: IP归属地信息失败时返回"未知"或None。
"""
# 校验IP格式
if not cls.is_valid_ip(ip):
log.error(f"IP格式不合法: {ip}")
return "未知"
# 内网IP直接返回
if cls.is_private_ip(ip):
return '内网IP'
try:
# 使用ip-api.com API获取IP归属地信息
async with httpx.AsyncClient(timeout=10.0) as client:
# 尝试使用 ip9.com.cn API
url = f'https://ip9.com.cn/get?ip={ip}'
response = await cls._make_api_request(client, url)
if response and response.json().get('ret') == 200:
result = response.json().get('data', {})
return f"{result.get('country','')}-{result.get('prov','')}-{result.get('city','')}-{result.get('area','')}-{result.get('isp','')}"
# 尝试使用百度 API
url = f'https://qifu-api.baidubce.com/ip/geo/v1/district?ip={ip}'
response = await cls._make_api_request(client, url)
if response and response.json().get('code') == "Success":
data = response.json().get('data', {})
# 修正原代码中的格式错误
return f"{data.get('country','')}-{data.get('prov','')}-{data.get('city','')}-{data.get('district','')}-{data.get('isp','')}"
except Exception as e:
log.error(f"获取IP归属地失败: {e}")
return "未知"
@classmethod
async def _make_api_request(cls, client, url):
"""
单独的 API 请求方法,包含重试机制。
参数:
- client (AsyncClient): httpx 异步客户端。
- url (str): 请求 URL。
返回:
- Response | None: 响应对象失败时返回None。
"""
max_retries = 3
for attempt in range(max_retries):
try:
response = await client.get(url, timeout=10)
if response.status_code == 200:
return response
except Exception as e:
if attempt < max_retries - 1:
continue
log.error(f"请求 {url} 失败: {e}")
return None

View File

@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
import datetime
from cnlunar import Lunar
ZODIAC = ['', '', '', '', '', '', '', '', '', '', '', '']
SHICHEN = ['子时', '丑时', '寅时', '卯时', '辰时', '巳时', '午时', '未时', '申时', '酉时', '戌时', '亥时', '子时']
TIANGAN = ['', '', '', '', '', '', '', '', '', '']
DIZHI = ['', '', '', '', '', '', '', '', '', '', '', '']
WEEKDAYS = ['星期一', '星期二', '星期三', '星期四', '星期五', '星期六', '星期日']
def get_zodiac(year: int) -> str:
"""根据年份计算生肖"""
return ZODIAC[(year - 4) % 12]
def get_ganzhi_year(year: int) -> str:
"""根据年份计算干支年(如:甲辰)"""
idx = (year - 4) % 60
return TIANGAN[idx % 10] + DIZHI[idx % 12]
def get_ganzhi_month(year: int, month: int) -> str:
"""简化算法计算干支月"""
idx = (year * 12 + month + 13) % 60
return TIANGAN[idx % 10] + DIZHI[idx % 12]
def get_ganzhi_day(dt_date: datetime.date) -> str:
"""基于儒略日计算干支日"""
a = (14 - dt_date.month) // 12
y = dt_date.year + 4800 - a
m = dt_date.month + 12 * a - 3
jdn = dt_date.day + (153 * m + 2) // 5 + 365 * y + y // 4 - y // 100 + y // 400 - 32045
idx = (jdn + 49) % 60
return TIANGAN[idx % 10] + DIZHI[idx % 12]
def _has_valid_time(dt: datetime.datetime) -> bool:
"""判断时间部分是否有效(非 00:00:00"""
return dt.hour != 0 or dt.minute != 0 or dt.second != 0
def get_shichen_ke(dt: datetime.datetime) -> tuple[str, int]:
"""根据时间计算时辰和刻数,返回 (时辰, 刻)"""
lunar = Lunar(dt)
shichen = SHICHEN[lunar.twohourNum]
minutes_in_shichen = (dt.hour * 60 + dt.minute) % 120
ke = minutes_in_shichen // 15 + 1
return shichen, ke
def format_lunar_date(dt: datetime.datetime, with_time: bool = True) -> str:
"""
将公历日期格式化为农历字符串。
with_time=True 且时间有效时,追加时辰刻数。
返回示例:(乙巳年 五月初六 申时3刻 或 (乙巳年 五月初六)
"""
lunar = Lunar(dt)
lunar_year = f"{lunar.year8Char}"
lunar_month = lunar.lunarMonthCn.replace('', '').replace('', '')
lunar_day = lunar.lunarDayCn
if with_time and _has_valid_time(dt):
shichen, ke = get_shichen_ke(dt)
return f"{lunar_year} {lunar_month}{lunar_day} {shichen}{ke}刻)"
return f"{lunar_year} {lunar_month}{lunar_day}"

View File

@@ -0,0 +1,210 @@
# -*- coding: utf-8 -*-
import oss2
import random
from datetime import datetime
from fastapi import UploadFile
from pathlib import Path
from urllib.parse import urljoin
from app.config.setting import settings
from app.core.exceptions import CustomException
from app.core.logger import log
class OSSUtil:
"""
阿里云OSS上传工具类
"""
def __init__(self):
"""初始化OSS客户端"""
try:
# 创建Bucket对象所有Object相关的接口都可以通过Bucket对象来进行
auth = oss2.Auth(settings.OSS_ACCESS_KEY_ID, settings.OSS_ACCESS_KEY_SECRET)
self.bucket = oss2.Bucket(auth, settings.OSS_ENDPOINT, settings.OSS_BUCKET_NAME)
log.info("OSS客户端初始化成功")
except Exception as e:
log.error(f"OSS客户端初始化失败: {e}")
raise CustomException(msg="OSS服务初始化失败")
@staticmethod
def generate_random_number() -> str:
"""
生成3位随机数字字符串。
返回:
- str: 三位随机数字字符串。
"""
return f'{random.randint(1, 999):03}'
@staticmethod
def check_file_extension(file: UploadFile) -> bool:
"""
检查文件后缀是否合法。
参数:
- file (UploadFile): 上传的文件对象。
返回:
- bool: 文件后缀是否合法。
异常:
- CustomException: 文件类型不支持时抛出。
"""
if file.content_type and file.filename:
# 优先使用文件名的扩展名
file_extension = '.' + file.filename.rsplit('.', 1)[-1].lower() if '.' in file.filename else None
if file_extension and file_extension in settings.ALLOWED_EXTENSIONS:
return True
raise CustomException(msg="文件类型不支持")
else:
raise CustomException(msg="文件类型不支持")
@staticmethod
def check_file_size(file: UploadFile) -> bool:
"""
校验文件大小是否合法。
参数:
- file (UploadFile): 上传的文件对象。
返回:
- bool: 文件大小是否合法(未提供 size 返回 False
"""
if file.size:
return file.size <= settings.MAX_FILE_SIZE
else:
return False
@classmethod
def generate_file_name(cls, filename: str) -> str:
"""
生成文件名称。
参数:
- filename (str): 原始文件名(包含拓展名)。
返回:
- str: 生成的文件名(包含时间戳、机器码、随机码)。
"""
name, ext = filename.rsplit(".", 1)
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
return f'{name}_{timestamp}{settings.UPLOAD_MACHINE}{cls.generate_random_number()}.{ext}'
@classmethod
def generate_oss_key(cls, filename: str) -> str:
"""
生成OSS对象键路径
参数:
- filename (str): 文件名
返回:
- str: OSS对象键格式如: upload/2026/02/08/filename.jpg
"""
date_path = datetime.now().strftime("%Y/%m/%d")
return f"upload/{date_path}/{filename}"
# 上传文件至OSS方法
async def upload_file(self, file: UploadFile) -> tuple[str, str, str]:
"""
上传文件到阿里云OSS
参数:
- file (UploadFile): 上传的文件对象。
返回:
- tuple[str, str, str]: (文件名, OSS对象键, 文件访问URL)。
异常:
- CustomException: 当文件类型不支持或大小超限时抛出。
"""
# 文件校验(校验文件大小、校验文件后缀)
if not all([self.check_file_extension(file), self.check_file_size(file)]):
raise CustomException(msg='文件类型或大小不合法')
try:
# 生成文件名
if not file.filename:
raise CustomException(msg='文件名不能为空')
# 生成文件名称
filename = self.generate_file_name(file.filename)
# 生成oss文件路径格式如: upload/2026/02/08/filename.jpg
oss_key = self.generate_oss_key(filename)
# 读取文件内容
file_content = await file.read()
# 上传到OSS
result = self.bucket.put_object(oss_key, file_content)
if result.status == 200:
# 生成访问URLOSS访问域名/文件路径信息)
file_url = f"{settings.OSS_DOMAIN}/{oss_key}"
log.info(f"文件上传OSS成功: {oss_key}")
return filename, oss_key, file_url
else:
log.error(f"OSS上传失败状态码: {result.status}")
raise CustomException(msg='文件上传失败')
except oss2.exceptions.OssError as e:
log.error(f"OSS上传异常: {e}")
raise CustomException(msg=f'OSS上传失败: {e}')
except Exception as e:
log.error(f"文件上传失败: {e}")
raise CustomException(msg='文件上传失败')
def delete_file(self, oss_key: str) -> bool:
"""
删除OSS中的文件
参数:
- oss_key (str): OSS对象键
返回:
- bool: 删除是否成功
"""
try:
result = self.bucket.delete_object(oss_key)
if result.status == 204:
log.info(f"OSS文件删除成功: {oss_key}")
return True
else:
log.error(f"OSS文件删除失败状态码: {result.status}")
return False
except oss2.exceptions.OssError as e:
log.error(f"OSS删除异常: {e}")
return False
except Exception as e:
log.error(f"文件删除失败: {e}")
return False
def get_file_url(self, oss_key: str) -> str:
"""
获取文件访问URL
参数:
- oss_key (str): OSS对象键
返回:
- str: 文件访问URL
"""
return f"{settings.OSS_DOMAIN}/{oss_key}"
def file_exists(self, oss_key: str) -> bool:
"""
检查文件是否存在
参数:
- oss_key (str): OSS对象键
返回:
- bool: 文件是否存在
"""
try:
return self.bucket.object_exists(oss_key)
except Exception as e:
log.error(f"检查文件存在性失败: {e}")
return False

View File

@@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-
"""
PDF工具类 - 用于生成和填充PDF文档
"""
import io
import os
from reportlab.lib.pagesizes import A4
from reportlab.lib.units import mm
from reportlab.pdfgen import canvas
from reportlab.pdfbase import pdfmetrics
from reportlab.pdfbase.ttfonts import TTFont
from reportlab.lib.colors import HexColor
from pypdf import PdfReader, PdfWriter
from app.core.logger import log
class PDFUtil:
"""PDF工具类"""
# 字体路径(需要中文字体支持)
FONT_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'static', 'fonts')
@classmethod
def register_chinese_font(cls):
"""注册中文字体"""
try:
# 尝试注册思源黑体或其他中文字体
font_file = os.path.join(cls.FONT_PATH, 'SimHei.ttf')
if os.path.exists(font_file):
pdfmetrics.registerFont(TTFont('SimHei', font_file))
return 'SimHei'
# 尝试系统字体
system_fonts = [
'C:/Windows/Fonts/simhei.ttf', # Windows
'/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc', # Linux
'/System/Library/Fonts/PingFang.ttc', # macOS
]
for font_path in system_fonts:
if os.path.exists(font_path):
pdfmetrics.registerFont(TTFont('ChineseFont', font_path))
return 'ChineseFont'
log.warning("未找到中文字体,将使用默认字体")
return 'Helvetica'
except Exception as e:
log.error(f"注册中文字体失败: {e}")
return 'Helvetica'
@classmethod
def create_overlay_pdf(cls, data: dict, page_width: float, page_height: float) -> io.BytesIO:
"""
创建覆盖层PDF用于在模板上叠加文字
参数:
- data: 要填充的数据字典,格式: {字段名: (x, y, 值, 字体大小, 颜色)}
- page_width: 页面宽度
- page_height: 页面高度
返回:
- BytesIO: PDF字节流
"""
buffer = io.BytesIO()
c = canvas.Canvas(buffer, pagesize=(page_width, page_height))
font_name = cls.register_chinese_font()
for field_name, field_config in data.items():
if isinstance(field_config, tuple) and len(field_config) >= 3:
x, y, value = field_config[:3]
font_size = field_config[3] if len(field_config) > 3 else 12
color = field_config[4] if len(field_config) > 4 else '#000000'
c.setFont(font_name, font_size)
c.setFillColor(HexColor(color))
c.drawString(x * mm, y * mm, str(value))
c.save()
buffer.seek(0)
return buffer
@classmethod
def fill_pdf_template(cls, template_path: str, data: dict, field_positions: dict) -> bytes:
"""
填充PDF模板
参数:
- template_path: 模板文件路径
- data: 数据字典
- field_positions: 字段位置配置,格式: {字段名: (x, y, 字体大小, 颜色)}
返回:
- bytes: 填充后的PDF字节
"""
# 读取模板
reader = PdfReader(template_path)
writer = PdfWriter()
# 获取第一页尺寸
first_page = reader.pages[0]
page_width = float(first_page.mediabox.width)
page_height = float(first_page.mediabox.height)
# 构建覆盖数据
overlay_data = {}
for field_name, position in field_positions.items():
if field_name in data and data[field_name] is not None:
x, y = position[:2]
font_size = position[2] if len(position) > 2 else 12
color = position[3] if len(position) > 3 else '#000000'
overlay_data[field_name] = (x, y, data[field_name], font_size, color)
# 创建覆盖层
overlay_buffer = cls.create_overlay_pdf(overlay_data, page_width, page_height)
overlay_reader = PdfReader(overlay_buffer)
# 合并页面
for i, page in enumerate(reader.pages):
if i == 0 and len(overlay_reader.pages) > 0:
page.merge_page(overlay_reader.pages[0])
writer.add_page(page)
# 输出
output = io.BytesIO()
writer.write(output)
output.seek(0)
return output.read()

View File

@@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-
import re
def search_string(pattern: str, text: str) -> re.Match[str] | None:
"""
全字段正则匹配
参数:
- pattern (str): 正则表达式模式。
- text (str): 待匹配的文本。
返回:
- re.Match[str] | None: 匹配结果。
"""
if not pattern or not text:
return None
result = re.search(pattern, text)
return result
def match_string(pattern: str, text: str) -> re.Match[str] | None:
"""
从字段开头正则匹配
参数:
- pattern (str): 正则表达式模式。
- text (str): 待匹配的文本。
返回:
- re.Match[str] | None: 匹配结果。
"""
if not pattern or not text:
return None
result = re.match(pattern, text)
return result
def is_phone(number: str) -> re.Match[str] | None:
"""
检查手机号码格式
参数:
- number (str): 待检查的手机号码。
返回:
- re.Match[str] | None: 匹配结果。
"""
if not number:
return None
phone_pattern = r'^1[3-9]\d{9}$'
return match_string(phone_pattern, number)
def is_git_url(url: str) -> re.Match[str] | None:
"""
检查 git URL 格式
参数:
- url (str): 待检查的 URL。
返回:
- re.Match[str] | None: 匹配结果。
"""
if not url:
return None
git_pattern = r'^(?!(git\+ssh|ssh)://|git@)(?P<scheme>git|https?|file)://(?P<host>[^/]*)(?P<path>(?:/[^/]*)*/)(?P<repo>[^/]+?)(?:\.git)?$'
return match_string(git_pattern, url)

View File

@@ -0,0 +1,190 @@
# -*- coding: utf-8 -*-
import json
import random
import string
from typing import Dict, Any
from alibabacloud_dysmsapi20170525.client import Client as DysmsapiClient
from alibabacloud_tea_openapi import models as open_api_models
from alibabacloud_dysmsapi20170525 import models as dysmsapi_models
from alibabacloud_tea_util import models as util_models
from app.core.logger import log
from app.core.exceptions import CustomException
class SMSUtil:
"""阿里云短信服务工具类"""
def __init__(self):
"""初始化阿里云短信客户端"""
# 阿里云短信配置信息
self.access_key_id = "LTAI5t7ox6LSot4bTXQiU39R"
self.access_key_secret = "X4Z5K3ZSrZcXzcc5HgWZNmMUmTvK8N"
self.region_id = "cn-hangzhou"
self.sign_name = "山东实战派网络科技"
self.international_sign_name = "Shando Tech"
self.enable_international = True
self.domestic_endpoint = "dysmsapi.aliyuncs.com"
self.international_endpoint = "dysmsapi.aliyuncs.com"
# 短信模板配置
self.templates = {
# 国内短信模板
"register": "SMS_486390015", # 注册验证码模板
"resetpwd": "SMS_486445022", # 重置密码验证码模板
"changepwd": "SMS_486450016", # 修改密码验证码模板
"changemobile": "SMS_487250049", # 修改手机号验证码模板
"mobilelogin": "SMS_487410035", # 手机登录验证码模板
# 国际短信模板
"international_register": "SMS_INTL_486390015",
"international_resetpwd": "SMS_INTL_486445022",
"international_changepwd": "SMS_INTL_486450016",
"international_changemobile": "SMS_INTL_487250049",
"international_mobilelogin": "SMS_INTL_487410035"
}
def _create_client(self, is_international: bool = False) -> DysmsapiClient:
"""创建阿里云短信客户端"""
config = open_api_models.Config(
access_key_id=self.access_key_id,
access_key_secret=self.access_key_secret
)
# 设置访问的域名
if is_international:
config.endpoint = self.international_endpoint
else:
config.endpoint = self.domestic_endpoint
return DysmsapiClient(config)
def _is_international_mobile(self, mobile: str) -> bool:
"""判断是否为国际手机号"""
# 简单判断中国大陆手机号以1开头11位数字
if mobile.startswith('+'):
return True
if len(mobile) == 11 and mobile.startswith('1') and mobile.isdigit():
return False
return True
def generate_verification_code(self, length: int = 6) -> str:
"""生成验证码"""
return ''.join(random.choices(string.digits, k=length))
async def send_sms(
self,
mobile: str,
template_type: str,
template_params: Dict[str, Any] = None
) -> bool:
"""
发送短信
参数:
- mobile: 手机号
- template_type: 模板类型 (register, resetpwd, changepwd, changemobile, mobilelogin)
- template_params: 模板参数,如 {"code": "123456"}
返回:
- bool: 发送是否成功
"""
try:
# 判断是否为国际手机号
is_international = self._is_international_mobile(mobile)
# 获取对应的模板ID和签名
if is_international and self.enable_international:
template_id = self.templates.get(f"international_{template_type}")
sign_name = self.international_sign_name
else:
template_id = self.templates.get(template_type)
sign_name = self.sign_name
if not template_id:
raise CustomException(msg=f"未找到模板类型: {template_type}")
# 创建客户端
client = self._create_client(is_international)
# 构建请求
send_sms_request = dysmsapi_models.SendSmsRequest(
phone_numbers=mobile,
sign_name=sign_name,
template_code=template_id,
template_param=json.dumps(template_params) if template_params else None
)
# 发送短信
runtime = util_models.RuntimeOptions()
response = client.send_sms_with_options(send_sms_request, runtime)
# 检查响应
if response.body.code == "OK":
log.info(f"短信发送成功: {mobile}, 模板: {template_type}")
return True
else:
error_code = response.body.code
error_message = response.body.message
log.error(f"短信发送失败: {mobile}, 错误码: {error_code}, 错误信息: {error_message}")
# 根据错误码抛出不同的异常信息
if error_code == "isv.BUSINESS_LIMIT_CONTROL":
if "小时级流控" in error_message:
raise CustomException(msg="发送过于频繁同一手机号1小时内最多发送5条短信请稍后再试")
elif "天级流控" in error_message:
raise CustomException(msg="今日短信发送次数已达上限,请明天再试")
else:
raise CustomException(msg="短信发送频率超限,请稍后再试")
elif error_code == "isv.SMS_SIGNATURE_ILLEGAL":
raise CustomException(msg="短信签名配置错误,请联系管理员")
elif error_code == "isv.SMS_TEMPLATE_ILLEGAL":
raise CustomException(msg="短信模板配置错误,请联系管理员")
elif error_code == "isv.INVALID_PARAMETERS":
raise CustomException(msg="手机号格式错误")
elif error_code == "isv.MOBILE_NUMBER_ILLEGAL":
raise CustomException(msg="手机号格式不正确或为空号")
elif error_code == "isv.AMOUNT_NOT_ENOUGH":
raise CustomException(msg="短信余额不足,请联系管理员")
else:
raise CustomException(msg=f"短信发送失败: {error_message}")
return False
except CustomException:
# 重新抛出业务异常
raise
except Exception as e:
log.error(f"短信发送异常: {mobile}, 错误: {str(e)}")
raise CustomException(msg="短信服务异常,请稍后重试")
async def send_verification_code(
self,
mobile: str,
code_type: str = "register"
) -> tuple[bool, str]:
"""
发送验证码短信
参数:
- mobile: 手机号
- code_type: 验证码类型 (register, resetpwd, changepwd, changemobile, mobilelogin)
返回:
- tuple[bool, str]: (是否成功, 验证码)
"""
# 生成验证码
verification_code = self.generate_verification_code()
# 发送短信如果发生异常会直接抛出不会返回False
try:
success = await self.send_sms(
mobile=mobile,
template_type=code_type,
template_params={"code": verification_code}
)
return success, verification_code if success else ""
except CustomException:
# 重新抛出业务异常,让上层处理
raise

View File

@@ -0,0 +1,207 @@
# -*- coding: utf-8 -*-
from app.common.constant import CommonConstant
class StringUtil:
"""
字符串工具类
"""
@classmethod
def is_blank(cls, string: str) -> bool:
"""
校验字符串是否为''或全空格
参数:
- string (str): 需要校验的字符串。
返回:
- bool: 校验结果。
"""
if string is None:
return False
str_len = len(string)
if str_len == 0:
return True
else:
for i in range(str_len):
if string[i] != ' ':
return False
return True
@classmethod
def is_empty(cls, string) -> bool:
"""
校验字符串是否为''或None
参数:
- string (str | None): 需要校验的字符串。
返回:
- bool: 校验结果。
"""
return string is None or len(string) == 0
@classmethod
def is_not_empty(cls, string: str) -> bool:
"""
校验字符串是否不是''和None
参数:
- string (str): 需要校验的字符串。
返回:
- bool: 校验结果。
"""
return not cls.is_empty(string)
@classmethod
def is_http(cls, link: str):
"""
判断是否为 http(s):// 开头
参数:
- link (str): 链接。
返回:
- bool: 是否为 http(s):// 开头。
"""
return link.startswith(CommonConstant.HTTP) or link.startswith(CommonConstant.HTTPS)
@classmethod
def contains_ignore_case(cls, search_str: str, compare_str: str):
"""
查找指定字符串是否包含指定字符串同时忽略大小写
参数:
- search_str (str): 查找的字符串。
- compare_str (str): 比对的字符串。
返回:
- bool: 查找结果。
"""
if compare_str and search_str:
return compare_str.lower() in search_str.lower()
return False
@classmethod
def contains_any_ignore_case(cls, search_str: str, compare_str_list: list[str]):
"""
查找指定字符串是否包含列表中的任意一个字符串(忽略大小写)
参数:
- search_str (str): 查找的字符串。
- compare_str_list (list[str]): 比对的字符串列表。
返回:
- bool: 查找结果。
"""
if search_str and compare_str_list:
return any([cls.contains_ignore_case(search_str, compare_str) for compare_str in compare_str_list])
return False
@classmethod
def equals_ignore_case(cls, search_str: str, compare_str: str):
"""
比较两个字符串是否相等(忽略大小写)
参数:
- search_str (str): 查找的字符串。
- compare_str (str): 比对的字符串。
返回:
- bool: 比较结果。
"""
if search_str and compare_str:
return search_str.lower() == compare_str.lower()
return False
@classmethod
def equals_any_ignore_case(cls, search_str: str, compare_str_list: list[str]):
"""
判断指定字符串是否与列表中任意一个字符串相等(忽略大小写)
参数:
- search_str (str): 查找的字符串。
- compare_str_list (list[str]): 比对的字符串列表。
返回:
- bool: 比较结果。
"""
if search_str and compare_str_list:
return any([cls.equals_ignore_case(search_str, compare_str) for compare_str in compare_str_list])
return False
@classmethod
def startswith_case(cls, search_str: str, compare_str: str):
"""
查找指定字符串是否以指定字符串开头
参数:
- search_str (str): 查找的字符串。
- compare_str (str): 比对的字符串。
返回:
- bool: 查找结果。
"""
if compare_str and search_str:
return search_str.startswith(compare_str)
return False
@classmethod
def startswith_any_case(cls, search_str: str, compare_str_list: list[str]):
"""
查找指定字符串是否以列表中任意一个字符串开头
参数:
- search_str (str): 查找的字符串。
- compare_str_list (list[str]): 比对的字符串列表。
返回:
- bool: 查找结果。
"""
if search_str and compare_str_list:
return any([cls.startswith_case(search_str, compare_str) for compare_str in compare_str_list])
return False
@classmethod
def convert_to_camel_case(cls, name: str) -> str:
"""
将下划线大写方式命名的字符串转换为驼峰式;若输入为空则返回空字符串。
参数:
- name (str): 下划线大写方式命名的字符串。
返回:
- str: 转换后的驼峰式命名的字符串。
"""
if not name:
return ''
if '_' not in name:
return name[0].upper() + name[1:]
parts = name.split('_')
result = []
for part in parts:
if not part:
continue
result.append(part[0].upper() + part[1:].lower())
return ''.join(result)
@classmethod
def get_mapping_value_by_key_ignore_case(cls, mapping: dict[str, str], key: str) -> str:
"""
根据忽略大小写的键获取字典中的对应的值
参数:
- mapping (dict[str, str]): 字典。
- key (str): 字典的键。
返回:
- str: 字典键对应的值,未匹配则返回空字符串。
"""
for k, v in mapping.items():
if key.lower() == k.lower():
return v
return ''

View File

@@ -0,0 +1,281 @@
# -*- coding: utf-8 -*-
import re
from datetime import datetime
from typing import Any
class TimeUtil:
"""
时间格式化工具类
"""
DEFAULT_DATETIME_FORMAT = '%Y-%m-%d %H:%M:%S'
@classmethod
def object_format_datetime(cls, obj: Any) -> Any:
"""
格式化对象中的 datetime 属性为默认字符串格式。
参数:
- obj (Any): 输入对象。
返回:
- Any: 格式化后的对象。
"""
for attr in dir(obj):
if not attr.startswith('_'): # 跳过私有属性
value = getattr(obj, attr)
if isinstance(value, datetime):
setattr(obj, attr, value.strftime(cls.DEFAULT_DATETIME_FORMAT))
return obj
@classmethod
def list_format_datetime(cls, lst: list[Any]) -> list[Any]:
"""
格式化列表内每个对象的 datetime 属性。
参数:
- lst (List[Any]): 对象列表。
返回:
- list[Any]: 格式化后的对象列表。
"""
return [cls.object_format_datetime(obj) for obj in lst]
@classmethod
def format_datetime_dict_list(cls, dicts: list[dict]) -> list[dict]:
"""
递归格式化字典列表中的 datetime 值为默认字符串格式。
参数:
- dicts (list[dict]): 字典列表。
返回:
- list[dict]: 格式化后的字典列表。
"""
def _format_value(value: Any) -> Any:
if isinstance(value, dict):
return {k: _format_value(v) for k, v in value.items()}
elif isinstance(value, list):
return [_format_value(item) for item in value]
elif isinstance(value, datetime):
return value.strftime(cls.DEFAULT_DATETIME_FORMAT)
return value
return [_format_value(item) for item in dicts]
@classmethod
def __valid_range(cls, search_str: str, start_range: int, end_range: int) -> bool:
"""
校验范围字符串是否合法。
参数:
- search_str (str): 范围字符串(例如:"1-5")。
- start_range (int): 允许的最小范围值。
- end_range (int): 允许的最大范围值。
返回:
- bool: 校验是否通过。
"""
match = re.match(r'^(\d+)-(\d+)$', search_str)
if match:
start, end = int(match.group(1)), int(match.group(2))
return start_range <= start < end <= end_range
return False
@classmethod
def __valid_sum(cls, search_str: str, start_range_a: int, start_range_b: int, end_range_a: int, end_range_b: int, sum_range: int) -> bool:
"""
校验和字符串是否合法。
参数:
- search_str (str): 和字符串(例如:"1/5")。
- start_range_a (int): 允许的最小范围值A。
- start_range_b (int): 允许的最大范围值A。
- end_range_a (int): 允许的最小范围值B。
- end_range_b (int): 允许的最大范围值B。
- sum_range (int): 允许的最大和值。
返回:
- bool: 校验是否通过。
"""
match = re.match(r'^(\d+)/(\d+)$', search_str)
if match:
start, end = int(match.group(1)), int(match.group(2))
return (
start_range_a <= start <= start_range_b
and end_range_a <= end <= end_range_b
and start + end <= sum_range
)
return False
@classmethod
def validate_second_or_minute(cls, second_or_minute: str):
"""
校验秒或分钟字段的合法性。
参数:
- second_or_minute (str): 秒或分钟值。
返回:
- bool: 校验是否通过。
"""
if (
second_or_minute == '*'
or ('-' in second_or_minute and cls.__valid_range(second_or_minute, 0, 59))
or ('/' in second_or_minute and cls.__valid_sum(second_or_minute, 0, 58, 1, 59, 59))
or re.match(r'^(?:[0-5]?\d|59)(?:,[0-5]?\d|59)*$', second_or_minute)
):
return True
return False
@classmethod
def validate_hour(cls, hour: str):
"""
校验小时字段的合法性。
参数:
- hour (str): 小时值。
返回:
- bool: 校验是否通过。
"""
if (
hour == '*'
or ('-' in hour and cls.__valid_range(hour, 0, 23))
or ('/' in hour and cls.__valid_sum(hour, 0, 22, 1, 23, 23))
or re.match(r'^(?:0|[1-9]|1\d|2[0-3])(?:,(?:0|[1-9]|1\d|2[0-3]))*$', hour)
):
return True
return False
@classmethod
def validate_day(cls, day: str):
"""
校验日期字段的合法性。
参数:
- day (str): 日值。
返回:
- bool: 校验是否通过。
"""
if (
day in ['*', '?', 'L']
or ('-' in day and cls.__valid_range(day, 1, 31))
or ('/' in day and cls.__valid_sum(day, 1, 30, 1, 30, 31))
or ('W' in day and re.match(r'^(?:[1-9]|1\d|2\d|3[01])W$', day))
or re.match(r'^(?:0|[1-9]|1\d|2[0-9]|3[0-1])(?:,(?:0|[1-9]|1\d|2[0-9]|3[0-1]))*$', day)
):
return True
return False
@classmethod
def validate_month(cls, month: str):
"""
校验月份字段的合法性。
参数:
- month (str): 月值。
返回:
- bool: 校验是否通过。
"""
if (
month == '*'
or ('-' in month and cls.__valid_range(month, 1, 12))
or ('/' in month and cls.__valid_sum(month, 1, 11, 1, 11, 12))
or re.match(r'^(?:0|[1-9]|1[0-2])(?:,(?:0|[1-9]|1[0-2]))*$', month)
):
return True
return False
@classmethod
def validate_week(cls, week: str):
"""
校验星期字段的合法性。
参数:
- week (str): 周值。
返回:
- bool: 校验是否通过。
"""
if (
week in ['*', '?']
or ('-' in week and cls.__valid_range(week, 1, 7))
or ('#' in week and re.match(r'^[1-7]#[1-4]$', week))
or ('L' in week and re.match(r'^[1-7]L$', week))
or re.match(r'^[1-7](?:(,[1-7]))*$', week)
):
return True
return False
@classmethod
def validate_year(cls, year: str):
"""
校验年份字段的合法性。
参数:
- year (str): 年值。
返回:
- bool: 校验是否通过。
"""
current_year = int(datetime.now().year)
future_years = [current_year + i for i in range(9)]
if (
year == '*'
or ('-' in year and cls.__valid_range(year, current_year, 2099))
or ('/' in year and cls.__valid_sum(year, current_year, 2098, 1, 2099 - current_year, 2099))
or ('#' in year and re.match(r'^[1-7]#[1-4]$', year))
or ('L' in year and re.match(r'^[1-7]L$', year))
or (
(len(year) == 4 or ',' in year)
and all(int(item) in future_years and current_year <= int(item) <= 2099 for item in year.split(','))
)
):
return True
return False
@classmethod
def validate_cron_expression(cls, cron_expression: str):
"""
校验 Cron 表达式是否正确。
* * * * * *
| | | | | |
| | | | | +--- 星期0-70和7都表示星期日
| | | | +----- 月份1-12
| | | +------- 日期1-31
| | +--------- 小时0-23
| +----------- 分钟0-59
+------------- 秒0-59部分环境不支持秒字段。
参数:
- cron_expression (str): Cron 表达式。
返回:
- bool: 校验是否通过。
"""
values = cron_expression.split()
if len(values) != 6 and len(values) != 7:
return False
second_validation = cls.validate_second_or_minute(values[0])
minute_validation = cls.validate_second_or_minute(values[1])
hour_validation = cls.validate_hour(values[2])
day_validation = cls.validate_day(values[3])
month_validation = cls.validate_month(values[4])
week_validation = cls.validate_week(values[5])
validation = (
second_validation
and minute_validation
and hour_validation
and day_validation
and month_validation
and week_validation
)
if len(values) == 6:
return validation
if len(values) == 7:
year_validation = cls.validate_year(values[6])
return validation and year_validation

View File

@@ -0,0 +1,259 @@
# -*- coding: utf-8 -*-
import random
import mimetypes
from datetime import datetime
import aiofiles
from fastapi import UploadFile
from pathlib import Path
from urllib.parse import urljoin
from app.config.setting import settings
from app.core.exceptions import CustomException
from app.core.logger import log
class UploadUtil:
"""
上传工具类
"""
@staticmethod
def generate_random_number() -> str:
"""
生成3位随机数字字符串。
返回:
- str: 三位随机数字字符串。
"""
return f'{random.randint(1, 999):03}'
@staticmethod
def check_file_exists(filepath: str) -> bool:
"""
检查文件是否存在。
参数:
- filepath (str): 文件路径。
返回:
- bool: 文件是否存在。
"""
return Path(filepath).exists()
@staticmethod
def check_file_extension(file: UploadFile) -> bool:
"""
检查文件后缀是否合法。
参数:
- file (UploadFile): 上传的文件对象。
返回:
- bool: 文件后缀是否合法。
异常:
- CustomException: 文件类型不支持时抛出。
"""
if file.content_type and file.filename:
# 优先使用文件名的扩展名
file_extension = '.' + file.filename.rsplit('.', 1)[-1].lower() if '.' in file.filename else None
if file_extension and file_extension in settings.ALLOWED_EXTENSIONS:
return True
# 备用使用content_type推断
guessed_ext = mimetypes.guess_extension(file.content_type)
if guessed_ext and guessed_ext in settings.ALLOWED_EXTENSIONS:
return True
raise CustomException(msg="文件类型不支持")
else:
raise CustomException(msg="文件类型不支持")
@staticmethod
def check_file_timestamp(filename: str) -> bool:
"""
校验文件时间戳是否合法。
参数:
- filename (str): 文件名(包含时间戳片段)。
返回:
- bool: 时间戳是否合法。
"""
try:
name_parts = filename.rsplit('.', 1)[0].split('_')
timestamp = name_parts[-1].split(settings.UPLOAD_MACHINE)[0]
datetime.strptime(timestamp, '%Y%m%d%H%M%S')
return True
except (ValueError, IndexError):
return False
@staticmethod
def check_file_machine(filename: str) -> bool:
"""
校验文件机器码是否合法。
参数:
- filename (str): 文件名。
返回:
- bool: 机器码是否合法。
"""
try:
name_without_ext = filename.rsplit('.', 1)[0]
return len(name_without_ext) >= 4 and name_without_ext[-4] == settings.UPLOAD_MACHINE
except IndexError:
return False
@staticmethod
def check_file_random_code(filename: str) -> bool:
"""
校验文件随机码是否合法。
参数:
- filename (str): 文件名。
返回:
- bool: 随机码是否合法000999
"""
try:
code = filename.rsplit('.', 1)[0][-3:]
return code.isdigit() and 1 <= int(code) <= 999
except IndexError:
return False
@staticmethod
def check_file_size(file: UploadFile) -> bool:
"""
校验文件大小是否合法。
参数:
- file (UploadFile): 上传的文件对象。
返回:
- bool: 文件大小是否合法(未提供 size 返回 False
"""
if file.size:
return file.size <= settings.MAX_FILE_SIZE
else:
return False
@classmethod
def generate_file_name(cls, filename: str) -> str:
"""
生成文件名称。
参数:
- filename (str): 原始文件名(包含拓展名)。
返回:
- str: 生成的文件名(包含时间戳、机器码、随机码)。
"""
name, ext = filename.rsplit(".", 1)
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
return f'{name}_{timestamp}{settings.UPLOAD_MACHINE}{cls.generate_random_number()}.{ext}'
@staticmethod
def generate_file(filepath: Path, chunk_size: int = 8192):
"""
根据文件生成二进制数据迭代器。
参数:
- filepath (Path): 文件路径。
- chunk_size (int): 分块大小,默认 8192 字节。
返回:
- Iterator[bytes]: 文件二进制数据分块迭代器。
"""
with filepath.open('rb') as f:
while chunk := f.read(chunk_size):
yield chunk
@staticmethod
def delete_file(filepath: Path) -> bool:
"""
删除文件。
参数:
- filepath (Path): 文件路径。
返回:
- bool: 删除是否成功。
"""
try:
filepath.unlink(missing_ok=True)
return True
except OSError:
return False
@classmethod
async def upload_file(cls, file: UploadFile, base_url: str) -> tuple[str, Path, str]:
"""
文件上传。
参数:
- file (UploadFile): 上传的文件对象。
- base_url (str): 基础 URL。
返回:
- tuple[str, Path, str]: (文件名, 文件路径, 文件 URL)。
异常:
- CustomException: 当文件类型不支持或大小超限时抛出。
"""
# 文件校验
if not all([cls.check_file_extension(file), cls.check_file_size(file)]):
raise CustomException(msg='文件类型或大小不合法')
try:
# 构建完整的目录路径
dir_path = settings.UPLOAD_FILE_PATH.joinpath(datetime.now().strftime("%Y/%m/%d"))
log.info(f"上传目录路径 dir_path is {dir_path}")
dir_path.mkdir(parents=True, exist_ok=True)
filename = ""
# 生成文件名并保存
if file.filename:
filename = cls.generate_file_name(file.filename)
filepath = dir_path.joinpath(filename)
file_url = urljoin(base_url, filepath.as_posix())
# filepath.mkdir(parents=True, exist_ok=True)
# 分块写入文件
chunk_size = 8 * 1024 * 1024 # 8MB chunks
async with aiofiles.open(filepath, 'wb') as f:
while chunk := await file.read(chunk_size):
await f.write(chunk)
# 返回相对路径
return filename, filepath, file_url
except Exception as e:
log.error(f"文件上传失败: {e}")
raise CustomException(msg='文件上传失败')
@staticmethod
def get_file_tree(file_path: str) -> list[dict]:
"""
获取文件树结构。
参数:
- file_path (str): 文件路径。
返回:
- list[dict]: 文件树列表。
"""
return [{"name": item.name, "is_dir": item.is_dir()} for item in Path(file_path).iterdir()]
@classmethod
async def download_file(cls, file_path: str) -> str:
"""
下载文件,生成新的文件名。
参数:
- file_path (str): 文件路径。
返回:
- str: 文件下载信息。
"""
filename = cls.generate_file(Path(file_path))
return str(filename)