upload project source code

This commit is contained in:
2026-04-30 18:49:43 +08:00
commit 9b394ba682
2277 changed files with 660945 additions and 0 deletions

View File

@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-

View File

@@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-
from typing import AsyncGenerator
from openai import AsyncOpenAI, OpenAI
from openai.types.chat.chat_completion import ChatCompletion
import httpx
from app.config.setting import settings
from app.core.logger import log
class AIClient:
    """Thin wrapper around an OpenAI-compatible chat endpoint.

    Owns a dedicated httpx async client (so no conflicting transport
    kwargs leak into the SDK), streams chat completions, and maps SDK /
    network exceptions to user-friendly Chinese messages.
    """

    def __init__(self):
        # Model name and credentials come from application settings.
        self.model = settings.OPENAI_MODEL
        # Dedicated httpx client without conflicting parameters.
        self.http_client = httpx.AsyncClient(
            timeout=30.0,
            follow_redirects=True,
        )
        # Hand the custom transport to the SDK client.
        self.client = AsyncOpenAI(
            api_key=settings.OPENAI_API_KEY,
            base_url=settings.OPENAI_BASE_URL,
            http_client=self.http_client,
        )

    def _friendly_error_message(self, e: Exception) -> str:
        """Translate an OpenAI/network exception into a friendly Chinese hint."""
        # Pull the status code and structured error body off the exception,
        # when the SDK attached them.
        code = getattr(e, "status_code", None)
        raw_body = getattr(e, "body", None)
        message = error_type = error_code = None
        try:
            if isinstance(raw_body, dict) and "error" in raw_body:
                detail = raw_body.get("error") or {}
                error_type = detail.get("type")
                error_code = detail.get("code")
                message = detail.get("message")
        except Exception:
            # Body parsing is best-effort; fall back to str(e) below.
            pass
        msg = message or str(e)
        # Account in arrears / billing problem.
        if (error_code == "Arrearage") or (error_type == "Arrearage") or ("in good standing" in (msg or "")):
            return "账户欠费或结算异常,访问被拒绝。请检查账号状态或更换有效的 API Key。"
        # Authentication failure.
        if code == 401 or "invalid api key" in msg.lower():
            return "鉴权失败API Key 无效或已过期。请检查系统配置中的 API Key。"
        # Permission denied / restricted account.
        if code == 403 or error_type in {"PermissionDenied", "permission_denied"}:
            return "访问被拒绝,权限不足或账号受限。请检查账户权限设置。"
        # Quota exhausted or rate limited.
        if code == 429 or error_type in {"insufficient_quota", "rate_limit_exceeded"}:
            return "请求过于频繁或配额已用尽。请稍后重试或提升账户配额。"
        # Client-side request error.
        if code == 400:
            return f"请求参数错误或服务拒绝:{message or '请检查输入内容。'}"
        # Server-side / gateway errors.
        if code in {500, 502, 503, 504}:
            return "服务暂时不可用,请稍后重试。"
        # Generic fallback.
        return f"处理您的请求时出现错误:{msg}"

    async def process(self, query: str) -> AsyncGenerator[str, None]:
        """Stream the model's reply to ``query`` as text fragments.

        Parameters:
        - query: the user's question.

        Yields:
        - content deltas as they arrive; on failure, a single friendly
          error message (the error itself is logged).
        """
        system_prompt = """你是一个有用的AI助手可以帮助用户回答问题和提供帮助。请用中文回答用户的问题。"""
        try:
            stream = await self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": query},
                ],
                stream=True,
            )
            async for part in stream:
                delta = part.choices[0].delta.content if part.choices else None
                if delta:
                    yield delta
        except Exception as e:
            # Log the raw error; surface only the friendly text to the caller.
            log.error(f"AI处理查询失败: {str(e)}")
            yield self._friendly_error_message(e)

    async def close(self) -> None:
        """Release the SDK client and the underlying httpx client."""
        if hasattr(self, 'client'):
            await self.client.close()
        if hasattr(self, 'http_client'):
            await self.http_client.aclose()

View File

@@ -0,0 +1,252 @@
# -*- coding: utf-8 -*-
"""
文档处理工具类
支持 txt、pdf、doc、docx、md 格式解析
"""
import os
import tempfile
from pathlib import Path
from typing import List, Optional, BinaryIO
from langchain.schema import Document
from fastapi import UploadFile
from app.core.logger import log
class DocumentProcessor:
    """Parse uploaded files into LangChain ``Document`` objects.

    Supports txt, pdf, doc, docx and md. Parsing failures are logged and
    yield an empty list instead of raising, so one bad file never aborts
    a batch upload.
    """

    # Supported file extensions (lowercase, dot included).
    SUPPORTED_EXTENSIONS = {'.txt', '.pdf', '.doc', '.docx', '.md'}

    # Candidate encodings for plain-text payloads, tried in order.
    _TEXT_ENCODINGS = ('utf-8', 'gbk', 'gb2312', 'latin-1')

    @classmethod
    def is_supported(cls, filename: str) -> bool:
        """Return True when the filename's extension is supported."""
        ext = Path(filename).suffix.lower()
        return ext in cls.SUPPORTED_EXTENSIONS

    @classmethod
    def _decode_bytes(cls, content: bytes) -> Optional[str]:
        """Decode raw bytes with the first encoding that works, else None."""
        for encoding in cls._TEXT_ENCODINGS:
            try:
                return content.decode(encoding)
            except UnicodeDecodeError:
                continue
        return None

    @classmethod
    async def process_upload_file(cls, file: UploadFile) -> List[Document]:
        """Parse a single uploaded file.

        Parameters:
        - file: FastAPI UploadFile object.

        Returns:
        - List of parsed documents (empty on unsupported type or failure).
        """
        if not file.filename:
            log.warning("文件名为空,跳过处理")
            return []
        ext = Path(file.filename).suffix.lower()
        if not cls.is_supported(file.filename):
            log.warning(f"不支持的文件类型: {ext}")
            return []
        # Read the whole payload, then rewind so callers can re-read.
        content = await file.read()
        await file.seek(0)
        # Dispatch on the file type.
        try:
            if ext == '.txt':
                return cls._process_txt(content, file.filename)
            elif ext == '.md':
                return cls._process_markdown(content, file.filename)
            elif ext == '.pdf':
                return await cls._process_pdf(content, file.filename)
            elif ext in {'.doc', '.docx'}:
                return await cls._process_word(content, file.filename, ext)
            else:
                log.warning(f"未知的文件类型: {ext}")
                return []
        except Exception as e:
            log.error(f"处理文件失败: {file.filename}, 错误: {e}")
            return []

    @classmethod
    def _process_txt(cls, content: bytes, filename: str) -> List[Document]:
        """Parse a TXT file, trying several encodings."""
        try:
            text = cls._decode_bytes(content)
            if text is None:
                # BUGFIX: log the actual filename (was a literal placeholder).
                log.error(f"无法解码文件: {filename}")
                return []
            return [Document(
                page_content=text,
                metadata={"source": filename, "type": "txt"}
            )]
        except Exception as e:
            log.error(f"处理 TXT 文件失败: {filename}, 错误: {e}")
            return []

    @classmethod
    def _process_markdown(cls, content: bytes, filename: str) -> List[Document]:
        """Parse a Markdown file (same encoding fallback as TXT)."""
        try:
            text = cls._decode_bytes(content)
            if text is None:
                log.error(f"无法解码文件: {filename}")
                return []
            return [Document(
                page_content=text,
                metadata={"source": filename, "type": "markdown"}
            )]
        except Exception as e:
            log.error(f"处理 Markdown 文件失败: {filename}, 错误: {e}")
            return []

    @classmethod
    async def _process_pdf(cls, content: bytes, filename: str) -> List[Document]:
        """Parse a PDF into one Document per non-empty page."""
        try:
            # pypdf is an optional dependency; import lazily.
            import pypdf
            from io import BytesIO
            reader = pypdf.PdfReader(BytesIO(content))
            documents = []
            for page_num, page in enumerate(reader.pages):
                text = page.extract_text()
                if text and text.strip():
                    documents.append(Document(
                        page_content=text,
                        metadata={
                            "source": filename,
                            "type": "pdf",
                            "page": page_num + 1  # 1-based page numbering
                        }
                    ))
            log.info(f"PDF 文件处理完成: {filename}, 共 {len(documents)} 页")
            return documents
        except ImportError:
            log.error("未安装 pypdf 库,请运行: pip install pypdf")
            return []
        except Exception as e:
            log.error(f"处理 PDF 文件失败: {filename}, 错误: {e}")
            return []

    @classmethod
    async def _process_word(cls, content: bytes, filename: str, ext: str) -> List[Document]:
        """Dispatch Word parsing by extension (.docx vs legacy .doc)."""
        try:
            if ext == '.docx':
                return cls._process_docx(content, filename)
            # Legacy .doc needs an external converter.
            return cls._process_doc(content, filename)
        except Exception as e:
            log.error(f"处理 Word 文件失败: {filename}, 错误: {e}")
            return []

    @classmethod
    def _process_docx(cls, content: bytes, filename: str) -> List[Document]:
        """Parse a DOCX file via python-docx, joining non-empty paragraphs."""
        try:
            from docx import Document as DocxDocument
            from io import BytesIO
            doc = DocxDocument(BytesIO(content))
            paragraphs = [para.text for para in doc.paragraphs if para.text.strip()]
            text = '\n'.join(paragraphs)
            if text.strip():
                return [Document(
                    page_content=text,
                    metadata={"source": filename, "type": "docx"}
                )]
            return []
        except ImportError:
            log.error("未安装 python-docx 库,请运行: pip install python-docx")
            return []
        except Exception as e:
            log.error(f"处理 DOCX 文件失败: {filename}, 错误: {e}")
            return []

    @classmethod
    def _process_doc(cls, content: bytes, filename: str) -> List[Document]:
        """Parse a legacy .doc file by shelling out to ``antiword``.

        Falls back to an empty result (with a warning suggesting .docx)
        when antiword is missing or times out.
        """
        try:
            import subprocess
            # antiword reads from disk, so spill the payload to a temp file.
            # (tempfile is already imported at module level.)
            with tempfile.NamedTemporaryFile(suffix='.doc', delete=False) as tmp:
                tmp.write(content)
                tmp_path = tmp.name
            try:
                # Argument-list invocation (shell=False) — no injection risk.
                result = subprocess.run(
                    ['antiword', tmp_path],
                    capture_output=True,
                    text=True,
                    timeout=30
                )
                if result.returncode == 0 and result.stdout.strip():
                    return [Document(
                        page_content=result.stdout,
                        metadata={"source": filename, "type": "doc"}
                    )]
            except (subprocess.TimeoutExpired, FileNotFoundError):
                log.warning(f"无法处理 .doc 文件: {filename},建议转换为 .docx 格式")
            finally:
                # Always remove the temp file, even on failure.
                os.unlink(tmp_path)
            return []
        except Exception as e:
            log.error(f"处理 DOC 文件失败: {filename}, 错误: {e}")
            return []

    @classmethod
    async def process_files(cls, files: List[UploadFile]) -> List[Document]:
        """Parse a batch of uploaded files.

        Parameters:
        - files: list of UploadFile objects.

        Returns:
        - All documents extracted across the batch.
        """
        all_documents = []
        for file in files:
            if file.filename and cls.is_supported(file.filename):
                docs = await cls.process_upload_file(file)
                all_documents.extend(docs)
                log.info(f"文件处理完成: {file.filename}, 提取 {len(docs)} 个文档")
            else:
                log.warning(f"跳过不支持的文件: {file.filename}")
        log.info(f"批量处理完成,共 {len(all_documents)} 个文档")
        return all_documents

View File

@@ -0,0 +1,239 @@
# -*- coding: utf-8 -*-
"""
向量化工具类
支持本地和远程 embedding 模型
"""
import asyncio
from typing import List, Optional
from concurrent.futures import ThreadPoolExecutor
import chromadb
from langchain_openai import OpenAIEmbeddings
from langchain_chroma import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from app.core.logger import log
from app.config.path_conf import BASE_DIR
# ChromaDB 持久化目录
CHROMA_PERSIST_DIR = BASE_DIR / "data" / "chroma_db"
# 全局 ChromaDB 客户端实例(单例模式)
_chroma_client = None
def get_chroma_client():
    """Return the process-wide ChromaDB client, creating it on first use.

    Lazy singleton: the persistent client is built only once and then
    reused, so every collection shares one on-disk store.
    """
    global _chroma_client
    if _chroma_client is not None:
        return _chroma_client
    # First call: make sure the persistence directory exists.
    CHROMA_PERSIST_DIR.mkdir(parents=True, exist_ok=True)
    _chroma_client = chromadb.PersistentClient(path=str(CHROMA_PERSIST_DIR))
    return _chroma_client
class EmbeddingUtil:
    """Embedding helper supporting local and remote embedding models.

    Wraps model construction, Chroma vector-store access, document
    splitting and similarity search for a given collection.
    """

    def __init__(
        self,
        embedding_type: int = 0,
        model_name: str = "text-embedding-ada-002",
        base_url: Optional[str] = None,
        api_key: Optional[str] = None
    ):
        """
        Initialize the embedding utility.

        Parameters:
        - embedding_type: 0 = local (HuggingFace), 1 = remote (OpenAI-compatible).
        - model_name: embedding model name.
        - base_url: remote endpoint URL (required in remote mode).
        - api_key: remote API key (required in remote mode).
        """
        self.embedding_type = embedding_type
        self.model_name = model_name
        self.base_url = base_url
        self.api_key = api_key
        # Embedding model is built lazily on first access.
        self._embeddings = None

    @property
    def embeddings(self):
        """Lazily create and cache the embedding model."""
        if self._embeddings is None:
            self._embeddings = self._create_embeddings()
        return self._embeddings

    def _create_embeddings(self):
        """Build the embedding model instance for the configured mode.

        Raises:
        - ValueError: remote mode without base_url/api_key.
        - Exception: local model failed to load (re-raised after logging).
        """
        if self.embedding_type == 0:
            # Local mode — sentence-transformers via HuggingFaceEmbeddings.
            # sentence-transformers==3.3.1; CPU-only torch can be installed with:
            # pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
            try:
                from langchain_community.embeddings import HuggingFaceEmbeddings
                log.info(f"使用本地 Embedding 模型: {self.model_name}")
                return HuggingFaceEmbeddings(
                    model_name=self.model_name,
                    model_kwargs={'device': 'cpu'},
                    encode_kwargs={'normalize_embeddings': True}
                )
            except Exception as e:
                log.error(f"加载本地 Embedding 模型失败: {e}")
                raise
        else:
            # Remote mode — OpenAI-compatible HTTP endpoint.
            if not self.base_url or not self.api_key:
                raise ValueError("远程模式必须提供 base_url 和 api_key")
            # Append the /v1 path automatically when missing.
            api_base = self.base_url.rstrip('/')
            if not api_base.endswith('/v1'):
                api_base = f"{api_base}/v1"
            log.info(f"使用远程 Embedding 模型: {self.model_name}, URL: {api_base}")
            return OpenAIEmbeddings(
                model=self.model_name,
                base_url=api_base,
                api_key=self.api_key,
            )

    def get_vector_store(self, collection_name: str) -> Chroma:
        """
        Get or create the vector store for a collection.

        Parameters:
        - collection_name: collection name.

        Returns:
        - Chroma vector store instance.
        """
        # Make sure the persistence directory exists.
        CHROMA_PERSIST_DIR.mkdir(parents=True, exist_ok=True)
        return Chroma(
            collection_name=collection_name,
            embedding_function=self.embeddings,
            persist_directory=str(CHROMA_PERSIST_DIR),
            client=get_chroma_client(),
        )

    def split_documents(
        self,
        documents: List[Document],
        chunk_size: int = 1000,
        chunk_overlap: int = 200
    ) -> List[Document]:
        """
        Split documents into overlapping chunks.

        Parameters:
        - documents: documents to split.
        - chunk_size: maximum chunk size in characters.
        - chunk_overlap: overlap between adjacent chunks.

        Returns:
        - The split documents.
        """
        # BUGFIX: the separator list contained empty strings where the
        # full-width Chinese sentence terminators 。！？ belonged (lost in
        # an encoding mishap); "" early in the list made the splitter fall
        # back to per-character splits. Keep "" only as the final fallback.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
            separators=["\n\n", "\n", "。", "！", "？", ".", "!", "?", " ", ""]
        )
        return text_splitter.split_documents(documents)

    def add_documents(
        self,
        collection_name: str,
        documents: List[Document]
    ) -> int:
        """
        Split and add documents to the vector store.

        Parameters:
        - collection_name: collection name.
        - documents: documents to index.

        Returns:
        - Number of vectors added.
        """
        if not documents:
            return 0
        split_docs = self.split_documents(documents)
        log.info(f"文档分割完成,共 {len(split_docs)} 个片段")
        vector_store = self.get_vector_store(collection_name)
        vector_store.add_documents(split_docs)
        log.info(f"向量存储完成,集合: {collection_name}, 向量数: {len(split_docs)}")
        return len(split_docs)

    def delete_collection(self, collection_name: str) -> bool:
        """
        Delete a collection.

        Parameters:
        - collection_name: collection name.

        Returns:
        - True on success, False otherwise (failure is logged, not raised).
        """
        try:
            client = get_chroma_client()
            client.delete_collection(collection_name)
            log.info(f"删除集合成功: {collection_name}")
            return True
        except Exception as e:
            log.error(f"删除集合失败: {collection_name}, 错误: {e}")
            return False

    def similarity_search(
        self,
        collection_name: str,
        query: str,
        k: int = 4
    ) -> List[Document]:
        """
        Similarity search against a collection.

        Parameters:
        - collection_name: collection name.
        - query: query text.
        - k: number of results to return.

        Returns:
        - The k most similar documents.
        """
        vector_store = self.get_vector_store(collection_name)
        return vector_store.similarity_search(query, k=k)

    def get_collection_count(self, collection_name: str) -> int:
        """
        Count vectors in a collection.

        Parameters:
        - collection_name: collection name.

        Returns:
        - Vector count, or 0 when the collection is missing/unreadable.
        """
        try:
            client = get_chroma_client()
            collection = client.get_collection(collection_name)
            return collection.count()
        except Exception as e:
            log.warning(f"获取集合数量失败: {collection_name}, 错误: {e}")
            return 0