upload project source code
This commit is contained in:
@@ -0,0 +1,2 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
import httpx
|
||||
|
||||
from app.config.setting import settings
|
||||
from app.core.logger import log
|
||||
|
||||
|
||||
class AIClient:
    """
    AI client for interacting with an OpenAI-compatible chat API.

    Wraps an ``AsyncOpenAI`` client configured from application settings and
    exposes a streaming chat interface plus friendly error translation.
    """

    def __init__(self):
        # Model name comes from application settings.
        self.model = settings.OPENAI_MODEL
        # Create an httpx client without conflicting parameters.
        self.http_client = httpx.AsyncClient(
            timeout=30.0,
            follow_redirects=True
        )

        # Use the custom http client so timeout/redirect policy is ours.
        self.client = AsyncOpenAI(
            api_key=settings.OPENAI_API_KEY,
            base_url=settings.OPENAI_BASE_URL,
            http_client=self.http_client
        )

    def _friendly_error_message(self, e: Exception) -> str:
        """Convert an OpenAI or network exception into a friendly Chinese hint.

        Parameters:
        - e: the caught exception. OpenAI SDK errors expose ``status_code``
          and a parsed ``body`` attribute; both are read defensively via
          ``getattr`` so plain network errors also work.

        Returns:
        - A human-readable (Chinese) error message string.
        """
        # Try to extract the HTTP status code and structured error body.
        status_code = getattr(e, "status_code", None)
        body = getattr(e, "body", None)
        message = None
        error_type = None
        error_code = None
        try:
            if isinstance(body, dict) and "error" in body:
                err = body.get("error") or {}
                error_type = err.get("type")
                error_code = err.get("code")
                message = err.get("message")
        except Exception:
            # Ignore parsing failures; fall back to str(e) below.
            pass

        text = str(e)
        msg = message or text

        # Specific error mappings (checked in order; first match wins).
        # Arrears / abnormal account status.
        if (error_code == "Arrearage") or (error_type == "Arrearage") or ("in good standing" in (msg or "")):
            return "账户欠费或结算异常,访问被拒绝。请检查账号状态或更换有效的 API Key。"
        # Authentication failure.
        if status_code == 401 or "invalid api key" in msg.lower():
            return "鉴权失败,API Key 无效或已过期。请检查系统配置中的 API Key。"
        # Permission denied or restricted account.
        if status_code == 403 or error_type in {"PermissionDenied", "permission_denied"}:
            return "访问被拒绝,权限不足或账号受限。请检查账户权限设置。"
        # Quota exhausted or rate limited.
        if status_code == 429 or error_type in {"insufficient_quota", "rate_limit_exceeded"}:
            return "请求过于频繁或配额已用尽。请稍后重试或提升账户配额。"
        # Client-side request error.
        if status_code == 400:
            return f"请求参数错误或服务拒绝:{message or '请检查输入内容。'}"
        # Server-side error.
        if status_code in {500, 502, 503, 504}:
            return "服务暂时不可用,请稍后重试。"

        # Default fallback.
        return f"处理您的请求时出现错误:{msg}"

    async def process(self, query: str) -> AsyncGenerator[str, None]:
        """
        Process a query and stream back the response.

        Parameters:
        - query (str): the user's query text.

        Yields:
        - str: successive chunks of the streamed answer; on failure a single
          friendly error message is yielded instead of raising.
        """
        system_prompt = """你是一个有用的AI助手,可以帮助用户回答问题和提供帮助。请用中文回答用户的问题。"""

        try:
            # Await the async client call; with stream=True the SDK returns
            # an async iterator of chunks.
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": query}
                ],
                stream=True
            )

            # Stream the response chunks back to the caller, skipping
            # empty deltas (e.g. role-only or terminal chunks).
            async for chunk in response:
                if chunk.choices and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content

        except Exception as e:
            # Log the detailed error; surface only a friendly message.
            log.error(f"AI处理查询失败: {str(e)}")
            yield self._friendly_error_message(e)

    async def close(self) -> None:
        """
        Close the client connections (OpenAI SDK client and httpx client).

        Both closes are guarded with hasattr so a partially-constructed
        instance can still be closed safely.
        """
        if hasattr(self, 'client'):
            await self.client.close()
        if hasattr(self, 'http_client'):
            await self.http_client.aclose()
|
||||
@@ -0,0 +1,252 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
文档处理工具类
|
||||
支持 txt、pdf、doc、docx、md 格式解析
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, BinaryIO
|
||||
|
||||
from langchain.schema import Document
|
||||
from fastapi import UploadFile
|
||||
|
||||
from app.core.logger import log
|
||||
|
||||
|
||||
class DocumentProcessor:
    """Document processor.

    Parses uploaded files of type txt / pdf / doc / docx / md into
    LangChain ``Document`` objects. All methods are classmethods; the
    class keeps no per-instance state.
    """

    # Supported file extensions (lowercase, with the leading dot).
    SUPPORTED_EXTENSIONS = {'.txt', '.pdf', '.doc', '.docx', '.md'}

    @classmethod
    def is_supported(cls, filename: str) -> bool:
        """Return True if the filename's extension is supported."""
        ext = Path(filename).suffix.lower()
        return ext in cls.SUPPORTED_EXTENSIONS

    @classmethod
    async def process_upload_file(cls, file: UploadFile) -> List[Document]:
        """
        Process a single uploaded file.

        Parameters:
        - file: FastAPI UploadFile object.

        Returns:
        - List of parsed documents; empty on unsupported type or failure.
        """
        if not file.filename:
            log.warning("文件名为空,跳过处理")
            return []

        ext = Path(file.filename).suffix.lower()

        if not cls.is_supported(file.filename):
            log.warning(f"不支持的文件类型: {ext}")
            return []

        # Read the whole file into memory.
        content = await file.read()
        await file.seek(0)  # reset the file pointer for any later reader

        # Dispatch on the file extension.
        try:
            if ext == '.txt':
                return cls._process_txt(content, file.filename)
            elif ext == '.md':
                return cls._process_markdown(content, file.filename)
            elif ext == '.pdf':
                return await cls._process_pdf(content, file.filename)
            elif ext in {'.doc', '.docx'}:
                return await cls._process_word(content, file.filename, ext)
            else:
                log.warning(f"未知的文件类型: {ext}")
                return []
        except Exception as e:
            log.error(f"处理文件失败: {file.filename}, 错误: {e}")
            return []

    @classmethod
    def _process_txt(cls, content: bytes, filename: str) -> List[Document]:
        """Parse a TXT file, trying several encodings in order."""
        try:
            # latin-1 never raises, so it acts as a last-resort fallback.
            text = None
            for encoding in ['utf-8', 'gbk', 'gb2312', 'latin-1']:
                try:
                    text = content.decode(encoding)
                    break
                except UnicodeDecodeError:
                    continue

            if text is None:
                # Fix: include the actual filename in the log message
                # (was a literal "(unknown)" placeholder).
                log.error(f"无法解码文件: {filename}")
                return []

            return [Document(
                page_content=text,
                metadata={"source": filename, "type": "txt"}
            )]
        except Exception as e:
            log.error(f"处理 TXT 文件失败: {filename}, 错误: {e}")
            return []

    @classmethod
    def _process_markdown(cls, content: bytes, filename: str) -> List[Document]:
        """Parse a Markdown file (assumed UTF-8)."""
        try:
            text = content.decode('utf-8')
            return [Document(
                page_content=text,
                metadata={"source": filename, "type": "markdown"}
            )]
        except Exception as e:
            log.error(f"处理 Markdown 文件失败: {filename}, 错误: {e}")
            return []

    @classmethod
    async def _process_pdf(cls, content: bytes, filename: str) -> List[Document]:
        """Parse a PDF file, one Document per non-empty page."""
        try:
            # pypdf is an optional dependency; import lazily.
            import pypdf
            from io import BytesIO

            pdf_file = BytesIO(content)
            reader = pypdf.PdfReader(pdf_file)

            documents = []
            for page_num, page in enumerate(reader.pages):
                text = page.extract_text()
                if text and text.strip():
                    documents.append(Document(
                        page_content=text,
                        metadata={
                            "source": filename,
                            "type": "pdf",
                            # 1-based page number for human-facing metadata.
                            "page": page_num + 1
                        }
                    ))

            log.info(f"PDF 文件处理完成: {filename}, 共 {len(documents)} 页")
            return documents

        except ImportError:
            log.error("未安装 pypdf 库,请运行: pip install pypdf")
            return []
        except Exception as e:
            log.error(f"处理 PDF 文件失败: {filename}, 错误: {e}")
            return []

    @classmethod
    async def _process_word(cls, content: bytes, filename: str, ext: str) -> List[Document]:
        """Parse a Word file (dispatches on .doc vs .docx)."""
        try:
            if ext == '.docx':
                return cls._process_docx(content, filename)
            else:
                # Legacy .doc format needs external tooling.
                return cls._process_doc(content, filename)
        except Exception as e:
            log.error(f"处理 Word 文件失败: {filename}, 错误: {e}")
            return []

    @classmethod
    def _process_docx(cls, content: bytes, filename: str) -> List[Document]:
        """Parse a DOCX file by concatenating its non-empty paragraphs."""
        try:
            from docx import Document as DocxDocument
            from io import BytesIO

            docx_file = BytesIO(content)
            doc = DocxDocument(docx_file)

            # Collect all non-empty paragraph texts.
            paragraphs = []
            for para in doc.paragraphs:
                if para.text.strip():
                    paragraphs.append(para.text)

            text = '\n'.join(paragraphs)

            if text.strip():
                return [Document(
                    page_content=text,
                    metadata={"source": filename, "type": "docx"}
                )]
            return []

        except ImportError:
            log.error("未安装 python-docx 库,请运行: pip install python-docx")
            return []
        except Exception as e:
            log.error(f"处理 DOCX 文件失败: {filename}, 错误: {e}")
            return []

    @classmethod
    def _process_doc(cls, content: bytes, filename: str) -> List[Document]:
        """
        Parse a DOC file (legacy Word format).

        Note: .doc parsing needs an external tool (antiword); if it is not
        installed we log a hint and return an empty list.
        """
        try:
            # tempfile is already imported at module level; only subprocess
            # needs a local import here.
            import subprocess

            # Write the payload to a temp file so antiword can read it.
            with tempfile.NamedTemporaryFile(suffix='.doc', delete=False) as tmp:
                tmp.write(content)
                tmp_path = tmp.name

            try:
                # Try extracting text via the antiword CLI.
                result = subprocess.run(
                    ['antiword', tmp_path],
                    capture_output=True,
                    text=True,
                    timeout=30
                )
                if result.returncode == 0 and result.stdout.strip():
                    return [Document(
                        page_content=result.stdout,
                        metadata={"source": filename, "type": "doc"}
                    )]
            except (subprocess.TimeoutExpired, FileNotFoundError):
                log.warning(f"无法处理 .doc 文件: {filename},建议转换为 .docx 格式")
            finally:
                # Always remove the temp file, even on failure.
                os.unlink(tmp_path)

            return []

        except Exception as e:
            log.error(f"处理 DOC 文件失败: {filename}, 错误: {e}")
            return []

    @classmethod
    async def process_files(cls, files: List[UploadFile]) -> List[Document]:
        """
        Process a batch of uploaded files.

        Parameters:
        - files: list of UploadFile objects.

        Returns:
        - Flat list of all parsed documents across files.
        """
        all_documents = []

        for file in files:
            if file.filename and cls.is_supported(file.filename):
                docs = await cls.process_upload_file(file)
                all_documents.extend(docs)
                log.info(f"文件处理完成: {file.filename}, 提取 {len(docs)} 个文档")
            else:
                log.warning(f"跳过不支持的文件: {file.filename}")

        log.info(f"批量处理完成,共 {len(all_documents)} 个文档")
        return all_documents
|
||||
@@ -0,0 +1,239 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
向量化工具类
|
||||
支持本地和远程 embedding 模型
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import chromadb
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from langchain_chroma import Chroma
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain.schema import Document
|
||||
|
||||
from app.core.logger import log
|
||||
from app.config.path_conf import BASE_DIR
|
||||
|
||||
|
||||
# ChromaDB persistence directory (under the project data directory).
CHROMA_PERSIST_DIR = BASE_DIR / "data" / "chroma_db"

# Global ChromaDB client instance (lazily created singleton).
_chroma_client = None
|
||||
|
||||
|
||||
def get_chroma_client():
    """Return the process-wide ChromaDB client, creating it on first use."""
    global _chroma_client
    # Fast path: already initialized.
    if _chroma_client is not None:
        return _chroma_client
    # Make sure the persistence directory exists before opening the store.
    CHROMA_PERSIST_DIR.mkdir(parents=True, exist_ok=True)
    _chroma_client = chromadb.PersistentClient(path=str(CHROMA_PERSIST_DIR))
    return _chroma_client
|
||||
|
||||
|
||||
class EmbeddingUtil:
    """Embedding utility.

    Supports a local (sentence-transformers) or remote (OpenAI-compatible)
    embedding backend, and wraps ChromaDB vector-store operations:
    splitting, adding, searching, counting and deleting collections.
    """

    def __init__(
        self,
        embedding_type: int = 0,
        model_name: str = "text-embedding-ada-002",
        base_url: Optional[str] = None,
        api_key: Optional[str] = None
    ):
        """
        Initialize the embedding utility.

        Parameters:
        - embedding_type: 0 = local, 1 = remote
        - model_name: embedding model name
        - base_url: remote endpoint URL (required in remote mode)
        - api_key: remote API key (required in remote mode)
        """
        self.embedding_type = embedding_type
        self.model_name = model_name
        self.base_url = base_url
        self.api_key = api_key

        # Embedding model instance; created lazily on first access.
        self._embeddings = None

    @property
    def embeddings(self):
        """Lazily create and cache the embedding model instance."""
        if self._embeddings is None:
            self._embeddings = self._create_embeddings()
        return self._embeddings

    def _create_embeddings(self):
        """Create the embedding model instance based on ``embedding_type``."""
        if self.embedding_type == 0:
            # Local mode - sentence-transformers via HuggingFaceEmbeddings.
            # sentence-transformers==3.3.1
            # Optional local embedding support requires a manual install:
            # pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
            try:
                from langchain_community.embeddings import HuggingFaceEmbeddings
                log.info(f"使用本地 Embedding 模型: {self.model_name}")
                return HuggingFaceEmbeddings(
                    model_name=self.model_name,
                    model_kwargs={'device': 'cpu'},
                    encode_kwargs={'normalize_embeddings': True}
                )
            except Exception as e:
                log.error(f"加载本地 Embedding 模型失败: {e}")
                raise
        else:
            # Remote mode - OpenAI-compatible HTTP interface.
            if not self.base_url or not self.api_key:
                raise ValueError("远程模式必须提供 base_url 和 api_key")

            # Automatically append the /v1 path if it is missing.
            api_base = self.base_url.rstrip('/')
            if not api_base.endswith('/v1'):
                api_base = f"{api_base}/v1"

            log.info(f"使用远程 Embedding 模型: {self.model_name}, URL: {api_base}")
            return OpenAIEmbeddings(
                model=self.model_name,
                base_url=api_base,
                api_key=self.api_key,
            )

    def get_vector_store(self, collection_name: str) -> Chroma:
        """
        Get or create a vector store.

        Parameters:
        - collection_name: collection name

        Returns:
        - Chroma vector-store instance backed by the shared client.
        """
        # Ensure the persistence directory exists.
        CHROMA_PERSIST_DIR.mkdir(parents=True, exist_ok=True)

        return Chroma(
            collection_name=collection_name,
            embedding_function=self.embeddings,
            persist_directory=str(CHROMA_PERSIST_DIR),
            client=get_chroma_client(),
        )

    def split_documents(
        self,
        documents: List[Document],
        chunk_size: int = 1000,
        chunk_overlap: int = 200
    ) -> List[Document]:
        """
        Split documents into chunks.

        Parameters:
        - documents: document list
        - chunk_size: chunk size in characters
        - chunk_overlap: overlap between adjacent chunks

        Returns:
        - List of split documents.
        """
        # Separators include both Chinese and ASCII sentence punctuation.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
            separators=["\n\n", "\n", "。", "!", "?", ".", "!", "?", " ", ""]
        )
        return text_splitter.split_documents(documents)

    def add_documents(
        self,
        collection_name: str,
        documents: List[Document]
    ) -> int:
        """
        Add documents to the vector store.

        Parameters:
        - collection_name: collection name
        - documents: document list

        Returns:
        - Number of vectors (chunks) added.
        """
        if not documents:
            return 0

        # Split documents into chunks first.
        split_docs = self.split_documents(documents)
        log.info(f"文档分割完成,共 {len(split_docs)} 个片段")

        # Get the target vector store.
        vector_store = self.get_vector_store(collection_name)

        # Add the chunks.
        vector_store.add_documents(split_docs)
        log.info(f"向量存储完成,集合: {collection_name}, 向量数: {len(split_docs)}")

        return len(split_docs)

    def delete_collection(self, collection_name: str) -> bool:
        """
        Delete a collection.

        Parameters:
        - collection_name: collection name

        Returns:
        - True on success, False on failure (failure is logged, not raised).
        """
        try:
            client = get_chroma_client()
            client.delete_collection(collection_name)
            log.info(f"删除集合成功: {collection_name}")
            return True
        except Exception as e:
            log.error(f"删除集合失败: {collection_name}, 错误: {e}")
            return False

    def similarity_search(
        self,
        collection_name: str,
        query: str,
        k: int = 4
    ) -> List[Document]:
        """
        Run a similarity search.

        Parameters:
        - collection_name: collection name
        - query: query text
        - k: number of results to return

        Returns:
        - List of similar documents.
        """
        vector_store = self.get_vector_store(collection_name)
        return vector_store.similarity_search(query, k=k)

    def get_collection_count(self, collection_name: str) -> int:
        """
        Get the number of vectors in a collection.

        Parameters:
        - collection_name: collection name

        Returns:
        - Vector count; 0 if the collection is missing or lookup fails.
        """
        try:
            client = get_chroma_client()
            collection = client.get_collection(collection_name)
            return collection.count()
        except Exception as e:
            log.warning(f"获取集合数量失败: {collection_name}, 错误: {e}")
            return 0
|
||||
Reference in New Issue
Block a user