通用AI聊天平台实现3:RAG知识库扩展
2025-07-26 15:18:29
一、实现多格式文档解析器
1. 通用文档解析接口
我们先在 ai 文件夹下创建一个名为 parsers 的 Python 包,用来实现多格式文档解析器。
接着定义一个解析器接口:
# app/ai/parsers/base.py
from abc import ABC, abstractmethod
from typing import List, Dict, Any
from pydantic import BaseModel
class DocumentChunk(BaseModel):
"""文档块"""
content: str
metadata: Dict[str, Any]
chunk_index: int
class DocumentParser(ABC):
"""文档解析器抽象基类"""
@abstractmethod
def parse(self, file_path: str, file_content: bytes) -> List[DocumentChunk]:
"""解析文档并返回文档块列表"""
pass
- DocumentChunk是文档解析器返回类型。
- DocumentParser是抽象基类,通过继承ABC来定义。
2. 实现文本解析器
# app/ai/parsers/text_parser.py
from typing import List
from langchain_text_splitters import RecursiveCharacterTextSplitter
from app.ai.parsers.base import DocumentParser, DocumentChunk
class TextParser(DocumentParser):
"""文本文件解析器"""
def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200):
"""
初始化文本解析器
Args:
chunk_size: 文档块大小
chunk_overlap: 文档块重叠大小
"""
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
length_function=len
)
def parse(self, file_path: str, file_content: bytes) -> List[DocumentChunk]:
"""解析文本文件"""
try:
# 解码文件内容
text_content = file_content.decode('utf-8')
except UnicodeDecodeError:
# 如果UTF-8解码失败,尝试其他编码
try:
text_content = file_content.decode('gbk')
except UnicodeDecodeError:
text_content = file_content.decode('utf-8', errors='ignore')
# 分割文本
chunks = self.text_splitter.split_text(text_content)
# 转换为DocumentChunk对象
document_chunks = []
for i, chunk in enumerate(chunks):
document_chunks.append(DocumentChunk(
content=chunk,
metadata={
"file_path": file_path,
"file_type": "text",
"chunk_size": len(chunk)
},
chunk_index=i
))
return document_chunks
- TextParser继承抽象基类DocumentParser,并实现parse函数。
- 解析文本时会先按 utf-8 解码,失败后再尝试 gbk;如果仍然失败,则改用 errors='ignore' 的 utf-8 解码,忽略无法识别的字节。
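下面给出一个最小的使用示例(示意代码,其中 sample.txt 为假设的本地文件路径,请按需替换):
# 使用示例(示意代码,假设本地存在 sample.txt)
from app.ai.parsers.text_parser import TextParser
parser = TextParser(chunk_size=500, chunk_overlap=50)
with open("sample.txt", "rb") as f:
    file_content = f.read()
chunks = parser.parse("sample.txt", file_content)
for chunk in chunks:
    # 每个块都带有内容、元数据和块索引
    print(chunk.chunk_index, len(chunk.content), chunk.metadata["file_type"])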
3. 实现Markdown解析器
# app/ai/parsers/markdown_parser.py
from typing import List, Dict, Any
from langchain_text_splitters import MarkdownHeaderTextSplitter
from app.ai.parsers.base import DocumentParser, DocumentChunk
class MarkdownParser(DocumentParser):
"""Markdown文档解析器"""
def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200):
"""
初始化Markdown解析器
Args:
chunk_size: 文档块大小
chunk_overlap: 文档块重叠大小
"""
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
# 定义标题层级
self.headers_to_split_on = [
("#", "Header 1"),
("##", "Header 2"),
("###", "Header 3"),
("####", "Header 4"),
("#####", "Header 5"),
("######", "Header 6"),
]
# 初始化MarkdownHeaderTextSplitter
self.markdown_splitter = MarkdownHeaderTextSplitter(
headers_to_split_on=self.headers_to_split_on,
strip_headers=False
)
def parse(self, file_path: str, file_content: bytes) -> List[DocumentChunk]:
"""
解析Markdown文档
Args:
file_path: 文件路径
file_content: 文件内容(字节)
Returns:
List[DocumentChunk]: 文档块列表
"""
try:
# 解码文件内容
content = file_content.decode('utf-8')
# 分割文档
md_header_splits = self.markdown_splitter.split_text(content)
chunks = []
chunk_index = 0
for split in md_header_splits:
# 获取分割后的内容和元数据
split_content = split.page_content
split_metadata = split.metadata
if not split_content.strip():
continue
# 如果内容太长,进一步分割
if len(split_content) > self.chunk_size:
sub_chunks = self._split_long_content(split_content, split_metadata)
for sub_chunk in sub_chunks:
chunks.append(DocumentChunk(
content=sub_chunk['content'],
metadata={
**split_metadata,
**sub_chunk['metadata'],
'file_path': file_path,
'chunk_index': chunk_index,
'parser_type': 'markdown'
},
chunk_index=chunk_index
))
chunk_index += 1
else:
chunks.append(DocumentChunk(
content=split_content,
metadata={
**split_metadata,
'file_path': file_path,
'chunk_index': chunk_index,
'parser_type': 'markdown'
},
chunk_index=chunk_index
))
chunk_index += 1
return chunks
except Exception as e:
raise Exception(f"解析Markdown文档失败: {str(e)}")
def _split_long_content(self, content: str, base_metadata: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
分割过长的内容
Args:
content: 内容
base_metadata: 基础元数据
Returns:
List[Dict]: 分割后的内容块
"""
chunks = []
# 按段落分割
paragraphs = content.split('\n\n')
current_chunk = ''
chunk_count = 0
for paragraph in paragraphs:
# 如果添加这个段落会超过大小限制
if len(current_chunk) + len(paragraph) + 2 > self.chunk_size and current_chunk:
# 保存当前块
chunks.append({
'content': current_chunk.strip(),
'metadata': {
**base_metadata,
'sub_chunk_index': chunk_count,
'is_sub_chunk': True
}
})
chunk_count += 1
current_chunk = paragraph
else:
if current_chunk:
current_chunk += '\n\n' + paragraph
else:
current_chunk = paragraph
# 添加最后一个块
if current_chunk.strip():
chunks.append({
'content': current_chunk.strip(),
'metadata': {
**base_metadata,
'sub_chunk_index': chunk_count,
'is_sub_chunk': True
}
})
return chunks
- 这里解码文件内容时直接使用了 utf-8,可以参考文本解析器的做法补充编码回退逻辑。
- strip_headers=False 表示在分块内容中保留标题文本。
- 代码比较长,但逻辑很简单:当单个分块内容超过 chunk_size 时,会调用 _split_long_content 按段落进一步拆分。
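为了直观感受标题元数据的效果,下面用一小段内嵌的 Markdown 字符串做演示(示意代码):
# 使用示例(示意代码)
from app.ai.parsers.markdown_parser import MarkdownParser
md_text = "# 第一章\n正文内容A\n\n## 1.1 小节\n正文内容B"
parser = MarkdownParser(chunk_size=500, chunk_overlap=50)
chunks = parser.parse("demo.md", md_text.encode("utf-8"))
for chunk in chunks:
    # metadata 中会携带 Header 1 / Header 2 等标题层级信息
    print(chunk.chunk_index, chunk.metadata.get("Header 1"), chunk.metadata.get("Header 2"))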
4. 实现HTML解析器
# app/ai/parsers/html_parser.py
from typing import List
from bs4 import BeautifulSoup
from langchain_text_splitters import HTMLHeaderTextSplitter, RecursiveCharacterTextSplitter
from app.ai.parsers.base import DocumentParser, DocumentChunk
class HTMLParser(DocumentParser):
"""HTML文件解析器"""
def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200):
"""
初始化HTML解析器
"""
# 配置HTML标题分割器
self.html_splitter = HTMLHeaderTextSplitter(
headers_to_split_on=[
("h1", "Header 1"),
("h2", "Header 2"),
("h3", "Header 3"),
("h4", "Header 4"),
("h5", "Header 5"),
("h6", "Header 6"),
]
)
# 用于进一步分割长内容的文本分割器
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
length_function=len
)
def parse(self, file_path: str, file_content: bytes) -> List[DocumentChunk]:
"""解析HTML文件"""
try:
# 解码文件内容
html_content = file_content.decode('utf-8')
# 使用BeautifulSoup解析HTML获取标题信息
soup = BeautifulSoup(html_content, 'html.parser')
title = soup.title.string if soup.title else None
# 使用HTMLHeaderTextSplitter分割文档
header_splits = self.html_splitter.split_text(html_content)
# 转换为DocumentChunk对象
document_chunks = []
for i, split in enumerate(header_splits):
# 如果内容太长,进一步分割
if len(split.page_content) > self.text_splitter._chunk_size:
sub_chunks = self.text_splitter.split_text(split.page_content)
for j, sub_chunk in enumerate(sub_chunks):
# 合并元数据
metadata = {
"file_path": file_path,
"file_type": "html",
"chunk_size": len(sub_chunk),
"title": title,
"chunk_index": i,
"sub_chunk_index": j,
"parser_type": "html"
}
# 添加标题元数据
metadata.update(split.metadata)
document_chunks.append(DocumentChunk(
content=sub_chunk,
metadata=metadata,
chunk_index=len(document_chunks)
))
else:
# 合并元数据
metadata = {
"file_path": file_path,
"file_type": "html",
"chunk_size": len(split.page_content),
"title": title,
"chunk_index": i,
"parser_type": "html"
}
# 添加标题元数据
metadata.update(split.metadata)
document_chunks.append(DocumentChunk(
content=split.page_content,
metadata=metadata,
chunk_index=i
))
return document_chunks
except Exception as e:
raise ValueError(f"HTML解析失败: {str(e)}")
5. 实现PDF解析器
# app/ai/parsers/pdf_parser.py
from typing import List
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from app.ai.parsers.base import DocumentParser, DocumentChunk
class PDFParser(DocumentParser):
"""PDF文件解析器"""
def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200):
"""
初始化PDF解析器
"""
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
length_function=len
)
def parse(self, file_path: str, file_content: bytes) -> List[DocumentChunk]:
"""解析PDF文件"""
try:
# 使用PyPDFLoader加载PDF
loader = PyPDFLoader(file_path)
documents = loader.load()
# 合并所有页面的内容
text_content = ""
total_pages = len(documents)
for doc in documents:
text_content += doc.page_content + "\n"
# 分割文本
chunks = self.text_splitter.split_text(text_content)
# 转换为DocumentChunk对象
document_chunks = []
for i, chunk in enumerate(chunks):
document_chunks.append(DocumentChunk(
content=chunk,
metadata={
"file_path": file_path,
"file_type": "pdf",
"chunk_size": len(chunk),
"total_pages": total_pages,
"parser_type": "pdf"
},
chunk_index=i
))
return document_chunks
except Exception as e:
raise ValueError(f"PDF解析失败: {str(e)}")
6. 完成工厂调用
# app/ai/parsers/factory.py
from typing import List
from app.ai.parsers.base import DocumentParser
from app.ai.parsers.text_parser import TextParser
from app.ai.parsers.pdf_parser import PDFParser
from app.ai.parsers.html_parser import HTMLParser
from app.ai.parsers.markdown_parser import MarkdownParser
class DocumentParserFactory:
"""文档解析器工厂类"""
_parsers = {
'text': TextParser,
'pdf': PDFParser,
'html': HTMLParser,
'markdown': MarkdownParser,
}
@staticmethod
def create_parser(parser_type: str, **kwargs) -> DocumentParser:
"""
创建文档解析器实例
Args:
parser_type: 解析器类型,支持 'text', 'pdf', 'html', 'markdown'
**kwargs: 其他参数,chunk_size=500, chunk_overlap=50
Returns:
DocumentParser: 文档解析器实例
Raises:
ValueError: 不支持的parser_type
"""
if parser_type not in DocumentParserFactory._parsers:
raise ValueError(f"不支持的文档解析器类型: {parser_type}")
return DocumentParserFactory._parsers[parser_type](**kwargs)
@staticmethod
def get_parser_by_extension(file_extension: str, **kwargs) -> DocumentParser:
"""
根据文件扩展名获取解析器
Args:
file_extension: 文件扩展名(包含点号,如'.pdf')
**kwargs: 其他参数
Returns:
DocumentParser: 文档解析器实例
Raises:
ValueError: 不支持的文件扩展名
"""
extension_mapping = {
'.txt': 'text',
'.md': 'markdown',
'.markdown': 'markdown',
'.pdf': 'pdf',
'.html': 'html',
'.htm': 'html',
}
parser_type = extension_mapping.get(file_extension.lower())
if not parser_type:
raise ValueError(f"不支持的文件扩展名: {file_extension}")
return DocumentParserFactory.create_parser(parser_type, **kwargs)
@staticmethod
def get_supported_extensions() -> List[str]:
"""获取所有支持的文件扩展名"""
return ['.txt', '.md', '.markdown', '.pdf', '.html', '.htm']
@staticmethod
def get_available_parsers() -> List[str]:
"""获取可用的解析器类型"""
return list(DocumentParserFactory._parsers.keys())
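有了工厂之后,上层代码只需根据文件扩展名拿到对应的解析器即可。下面是一个简单示意(文件路径为假设值):
# 使用示例(示意代码,文件路径为假设值)
from pathlib import Path
from app.ai.parsers.factory import DocumentParserFactory
filename = "docs/计算机原理.md"
with open(filename, "rb") as f:
    file_content = f.read()
# 根据扩展名自动选择 MarkdownParser
parser = DocumentParserFactory.get_parser_by_extension(
    Path(filename).suffix,
    chunk_size=1000,
    chunk_overlap=200
)
chunks = parser.parse(filename, file_content)
print(DocumentParserFactory.get_supported_extensions())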
二、实现向量嵌入服务
1. 通用向量计算接口
我们在 ai 文件夹下继续创建一个名为 embeddings 的包,用来实现向量嵌入服务。
接着定义一个向量计算接口:
# app/ai/embeddings/base.py
from abc import ABC, abstractmethod
from langchain_core.embeddings import Embeddings
from typing import List
class EmbeddingProvider(ABC):
"""向量计算服务抽象基类"""
@abstractmethod
async def embed_text(self, text: str) -> List[float]:
"""将单个文本转换为向量"""
pass
@abstractmethod
async def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""将多个文本批量转换为向量"""
pass
@abstractmethod
def get_embedding_dimension(self) -> int:
"""获取向量维度"""
pass
@abstractmethod
def get_model_name(self) -> str:
"""获取模型名称"""
pass
    @abstractmethod
    def get_model(self) -> Embeddings:
        """获取 LangChain 嵌入模型实例"""
        pass
2. 实现云端嵌入服务
uv add dashscope
# app/ai/embeddings/dashscope.py
from typing import List, Optional
from langchain_core.embeddings import Embeddings
from langchain_community.embeddings import DashScopeEmbeddings
from app.ai.embeddings.base import EmbeddingProvider
class DashScopeEmbedding(EmbeddingProvider):
"""使用LangChain DashScopeEmbeddings的云端向量计算服务"""
def __init__(
self,
model: str = "text-embedding-v4",
dashscope_api_key: Optional[str] = None # DASHSCOPE_API_KEY
):
"""
初始化DashScope嵌入模型
"""
self.dimension = 1024
self.model_name = model
self.dashscope_api_key = dashscope_api_key
# 初始化DashScope嵌入模型
self.model = DashScopeEmbeddings(
model=model,
dashscope_api_key=dashscope_api_key
)
async def embed_text(self, text: str) -> List[float]:
"""将单个文本转换为向量"""
embedding = await self.model.aembed_query(text)
return embedding
async def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""将多个文本批量转换为向量"""
embedding = await self.model.aembed_documents(texts)
return embedding
def get_embedding_dimension(self) -> int:
"""获取向量维度"""
return self.dimension
def get_model_name(self) -> str:
"""获取模型名称"""
return self.model_name
def get_model(self) -> Embeddings:
"""获取 LangChain 嵌入模型实例"""
return self.model
3. 本地嵌入服务
本地嵌入我们通过 HuggingFace 接口加载阿里开源的 Qwen3-Embedding 离线模型,需要安装以下依赖:
uv add langchain_huggingface
uv add sentence-transformers
# app/ai/embeddings/huggingface.py
from typing import List, Dict, Any, Optional
from langchain_core.embeddings import Embeddings
from langchain_huggingface import HuggingFaceEmbeddings
from app.ai.embeddings.base import EmbeddingProvider
class HuggingFaceEmbedding(EmbeddingProvider):
"""使用LangChain HuggingFaceEmbeddings的本地向量计算服务"""
def __init__(
self,
model_name: str,
model_kwargs: Optional[Dict[str, Any]] = None,
encode_kwargs: Optional[Dict[str, Any]] = None,
dimension: int = 1024
):
"""
初始化HuggingFace嵌入模型
"""
self.model_name = model_name
self.model_kwargs = model_kwargs or {}
self.encode_kwargs = encode_kwargs or {}
self.dimension = dimension
# 初始化HuggingFace嵌入模型
self.model = HuggingFaceEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs
)
async def embed_text(self, text: str) -> List[float]:
"""将单个文本转换为向量"""
embedding = await self.model.aembed_query(text)
return embedding
async def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""将多个文本批量转换为向量"""
embedding = await self.model.aembed_documents(texts)
return embedding
def get_embedding_dimension(self) -> int:
"""获取向量维度"""
return self.dimension
def get_model_name(self) -> str:
"""获取模型名称"""
return self.model_name
def get_model(self) -> Embeddings:
"""获取 LangChain 嵌入模型实例"""
return self.model
4. 完成工厂调用
# app/ai/embeddings/factory.py
from typing import Dict, Any, Optional
from app.ai.embeddings.base import EmbeddingProvider
from app.ai.embeddings.huggingface import HuggingFaceEmbedding
from app.ai.embeddings.dashscope import DashScopeEmbedding
class EmbeddingFactory:
"""向量计算服务工厂类"""
@staticmethod
def create_embedding_provider(
provider_type: str = "huggingface",
model_name: str = None,
dimension: int = 1024,
hf_model_kwargs: Optional[Dict[str, Any]] = None,
hf_encode_kwargs: Optional[Dict[str, Any]] = None,
dashscope_api_key: Optional[str] = None
) -> EmbeddingProvider:
"""
创建向量计算服务实例
"""
if model_name is None:
raise ValueError("模型名称不能为空")
if provider_type == "huggingface":
return HuggingFaceEmbedding(
model_name=model_name,
model_kwargs=hf_model_kwargs or {},
encode_kwargs=hf_encode_kwargs or {},
dimension=dimension
)
elif provider_type == "dashscope":
if not dashscope_api_key:
raise ValueError("DashScope API key 不能为空")
return DashScopeEmbedding(
model=model_name,
dashscope_api_key=dashscope_api_key,
)
else:
raise ValueError(f"未知类型: {provider_type}")
@staticmethod
def get_available_providers() -> list:
"""获取可用的向量计算服务类型"""
return ["huggingface", "dashscope"]
三、实现向量存储服务
1. 通用向量存储接口
# app/ai/vector_store/base.py
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
from pydantic import BaseModel
class VectorSearchResult(BaseModel):
"""向量搜索结果"""
id: str
content: str
metadata: Dict[str, Any]
score: float
class VectorStoreProvider(ABC):
"""向量存储服务抽象基类"""
@abstractmethod
async def add_documents(
self,
documents: List[str],
metadatas: List[Dict[str, Any]],
ids: List[str]
) -> List[str]:
"""添加文档到向量存储"""
pass
@abstractmethod
async def search(
self,
query_vector: List[float],
k: int = 5,
filter: Optional[Dict[str, Any]] = None
) -> List[VectorSearchResult]:
"""搜索相似向量"""
pass
@abstractmethod
async def delete_documents(self, ids: List[str]) -> bool:
"""删除文档"""
pass
@abstractmethod
async def get_collection_info(self) -> Dict[str, Any]:
"""获取集合信息"""
pass
@abstractmethod
async def clear_collection(self) -> bool:
"""清空集合"""
pass
2. 实现 Chroma 存储
uv add langchain_chroma
# app/ai/vector_store/chroma.py
import uuid
from typing import List, Dict, Any, Optional
from langchain_chroma import Chroma
from langchain_core.documents import Document
from app.ai.vector_store.base import VectorStoreProvider, VectorSearchResult
class ChromaVectorStore(VectorStoreProvider):
"""使用ChromaDB的向量存储服务"""
def __init__(
self,
collection_name: str = "documents",
persist_directory: str = "./chroma_db",
embedding_function=None
):
"""
初始化ChromaDB向量存储
"""
self.collection_name = collection_name
self.persist_directory = persist_directory
# 初始化LangChain Chroma向量存储
self.vector_store = Chroma(
collection_name=collection_name,
persist_directory=persist_directory,
embedding_function=embedding_function
)
async def add_documents(
self,
documents: List[str],
metadatas: List[Dict[str, Any]],
ids: List[str]
) -> List[str]:
"""添加文档到向量存储"""
# 如果没有提供ID,生成UUID
if not ids:
ids = [str(uuid.uuid4()) for _ in documents]
# 创建LangChain Document对象
langchain_docs = []
for i, doc in enumerate(documents):
langchain_docs.append(Document(
page_content=doc,
metadata=metadatas[i] if i < len(metadatas) else {}
))
# 添加文档到向量存储
self.vector_store.add_documents(langchain_docs, ids=ids)
return ids
async def search(
self,
query_vector: List[float],
k: int = 5,
filter: Optional[Dict[str, Any]] = None
) -> List[VectorSearchResult]:
"""搜索相似向量"""
# 使用LangChain的相似性搜索
# 注意:LangChain Chroma的search方法需要查询文本,而不是向量
# 这里我们需要使用similarity_search_by_vector方法
try:
# 使用similarity_search_by_vector进行向量搜索
results = self.vector_store.similarity_search_by_vector(
embedding=query_vector,
k=k,
filter=filter
)
# 转换结果格式
search_results = []
for doc in results:
# 从文档中提取信息
search_results.append(VectorSearchResult(
id=getattr(doc, 'id', str(uuid.uuid4())),
content=doc.page_content,
metadata=doc.metadata,
score=getattr(doc, 'score', 0.0)
))
return search_results
except Exception as e:
# 如果similarity_search_by_vector不可用,使用文本搜索作为备选
print(f"向量搜索失败,使用文本搜索: {e}")
return []
async def delete_documents(self, ids: List[str]) -> bool:
"""删除文档"""
try:
self.vector_store.delete(ids=ids)
return True
except Exception as e:
print(f"删除文档失败: {e}")
return False
async def get_collection_info(self) -> Dict[str, Any]:
"""获取集合信息"""
try:
# 获取集合中的文档数量
# 注意:LangChain Chroma可能没有直接的count方法
# 我们可以通过搜索所有文档来获取数量
all_docs = self.vector_store.similarity_search("", k=10000) # 获取大量文档
count = len(all_docs)
return {
"collection_name": self.collection_name,
"document_count": count,
"persist_directory": self.persist_directory
}
except Exception as e:
return {
"collection_name": self.collection_name,
"document_count": 0,
"error": str(e)
}
async def clear_collection(self) -> bool:
"""清空集合"""
try:
self.vector_store.delete_collection()
return True
except Exception as e:
print(f"清空集合失败: {e}")
return False
3. 完成工厂调用
# app/ai/vector_store/factory.py
from app.ai.vector_store.base import VectorStoreProvider
from app.ai.vector_store.chroma import ChromaVectorStore
class VectorStoreFactory:
"""向量存储服务工厂类"""
@staticmethod
def create_vector_store(
store_type: str,
**kwargs
) -> VectorStoreProvider:
"""
创建向量存储服务实例
"""
if store_type == "chroma":
return ChromaVectorStore(**kwargs)
else:
raise ValueError(f"不支持的向量存储类型: {store_type}")
@staticmethod
def get_available_stores() -> list:
"""获取可用的向量存储类型"""
return ["chroma"]
四、实现文档向量化服务
1. 文档向量化的功能描述
- 编排整个文档处理流程
- 封装复杂性
- 实现依赖注入
2. 文档向量化服务
# app/ai/vectorization/service.py
import uuid
from typing import List, Dict, Any, Optional
from pathlib import Path
from app.ai.embeddings.base import EmbeddingProvider
from app.ai.vector_store.base import VectorStoreProvider
from app.ai.parsers.factory import DocumentParserFactory
class DocumentVectorizationService:
"""文档向量化服务"""
def __init__(
self,
embedding_provider: EmbeddingProvider,
vector_store: VectorStoreProvider,
):
"""
初始化文档向量化服务
"""
self.embedding_provider = embedding_provider
self.vector_store = vector_store
async def vectorize_document(
self,
file_path: str,
file_content: bytes,
document_id: int,
knowledge_base_id: int,
filename: str,
chunk_size: int = 1000,
chunk_overlap: int = 200
) -> Dict[str, Any]:
"""
向量化文档
"""
try:
# 1. 解析文档
file_extension = Path(filename).suffix.lower()
parser = DocumentParserFactory.get_parser_by_extension(
file_extension,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap
)
document_chunks = parser.parse(file_path, file_content)
if not document_chunks:
raise ValueError("文档解析后没有内容")
# 2. 提取文本内容
texts = [chunk.content for chunk in document_chunks]
# 准备元数据
metadatas = []
ids = []
for i, chunk in enumerate(document_chunks):
metadata = {
**chunk.metadata,
"document_id": document_id,
"knowledge_base_id": knowledge_base_id,
"filename": filename,
"chunk_index": chunk.chunk_index,
"total_chunks": len(document_chunks)
}
metadatas.append(metadata)
# 生成唯一的块ID
chunk_id = f"{document_id}_{i}_{uuid.uuid4().hex[:8]}"
ids.append(chunk_id)
# 3. 存储到向量数据库
stored_ids = await self.vector_store.add_documents(
documents=texts,
metadatas=metadatas,
ids=ids
)
return {
"success": True,
"document_id": document_id,
"knowledge_base_id": knowledge_base_id,
"chunks_count": len(document_chunks),
"stored_ids": stored_ids,
"embedding_dimension": self.embedding_provider.get_embedding_dimension(),
"model_name": self.embedding_provider.get_model_name()
}
except Exception as e:
return {
"success": False,
"error": str(e),
"document_id": document_id,
"knowledge_base_id": knowledge_base_id
}
async def search_similar_documents(
self,
query: str,
knowledge_base_ids: List[int],
k: int = 5,
filter_metadata: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
"""
搜索相似文档
Args:
query: 查询文本
knowledge_base_ids: 知识库ID列表
k: 返回结果数量
filter_metadata: 额外的过滤条件
Returns:
List[Dict[str, Any]]: 搜索结果
"""
try:
# 1. 计算查询向量
query_embedding = await self.embedding_provider.embed_text(query)
# 2. 构建过滤条件
# 支持按多个知识库ID过滤
filter_condition = {
"knowledge_base_id": {"$in": knowledge_base_ids}
} if knowledge_base_ids else {}
if filter_metadata:
filter_condition.update(filter_metadata)
# 3. 搜索相似向量
search_results = await self.vector_store.search(
query_vector=query_embedding,
k=k,
filter=filter_condition
)
# 4. 格式化结果
results = []
for result in search_results:
results.append({
"chunk_id": result.id,
"content": result.content,
"score": result.score,
"metadata": result.metadata,
"document_id": result.metadata.get("document_id"),
"chunk_index": result.metadata.get("chunk_index")
})
return results
except Exception as e:
raise ValueError(f"搜索相似文档失败: {str(e)}")
async def delete_document_vectors(
self,
document_id: int
) -> bool:
"""
删除文档的所有向量
Args:
document_id: 文档ID
Returns:
bool: 是否删除成功
"""
try:
# 搜索该文档的所有向量
# 这里需要先搜索,然后删除,因为ChromaDB没有直接按metadata删除的功能
# 在实际应用中,可能需要维护一个文档ID到向量ID的映射
# 暂时返回True,实际实现需要根据具体的向量存储服务调整
return True
except Exception as e:
print(f"删除文档向量失败: {e}")
return False
async def get_vector_store_info(self) -> Dict[str, Any]:
"""获取向量存储信息"""
try:
collection_info = await self.vector_store.get_collection_info()
return {
"embedding_provider": {
"model_name": self.embedding_provider.get_model_name(),
"dimension": self.embedding_provider.get_embedding_dimension()
},
"vector_store": collection_info
}
except Exception as e:
return {
"error": str(e)
}
3. 实现文档向量化服务工厂
# app/ai/vectorization/factory.py
from app.ai.embeddings.factory import EmbeddingFactory
from app.ai.vector_store.factory import VectorStoreFactory
from app.ai.vectorization.service import DocumentVectorizationService
class VectorizationServiceFactory:
"""向量化服务工厂类"""
@staticmethod
def create_vectorization_service(
embedding_provider: str,
embedding_model_name: str,
hf_model_kwargs: dict,
hf_encode_kwargs: dict,
dashscope_api_key: str,
vector_store_type: str,
chroma_persist_directory: str,
) -> DocumentVectorizationService:
"""创建文档向量化服务
"""
# 创建嵌入模型
embedding_provider_instance = EmbeddingFactory.create_embedding_provider(
provider_type=embedding_provider,
model_name=embedding_model_name,
hf_model_kwargs=hf_model_kwargs,
hf_encode_kwargs=hf_encode_kwargs,
dashscope_api_key=dashscope_api_key,
)
# 创建向量存储
vector_store = VectorStoreFactory.create_vector_store(
store_type=vector_store_type,
collection_name="documents",
persist_directory=chroma_persist_directory,
embedding_function=embedding_provider_instance.get_model()
)
# 创建向量化服务
return DocumentVectorizationService(
embedding_provider=embedding_provider_instance,
vector_store=vector_store
)
4. 添加配置项
# app/.env
# 嵌入模型配置
# 提供商类型:huggingface, dashscope
EMBEDDING_PROVIDER=huggingface
# 模型名称(根据提供商类型选择)
EMBEDDING_MODEL_NAME='D:\Repos\Tools\Qwen3-Embedding-0.6B'
# HuggingFace 配置
# 模型加载参数(JSON格式)
HF_MODEL_KWARGS='{"device": "cuda"}'
# 编码参数(JSON格式)
HF_ENCODE_KWARGS='{"normalize_embeddings": true}'
# DashScope 配置
# DASHSCOPE_API_KEY=your_dashscope_api_key_here
# 向量存储配置
VECTOR_STORE=chroma
CHROMA_PERSIST_DIR=./chroma_db
5. 读取配置
# app/ai/config.py
# 若配置文件顶部尚未导入,需补充:
import os
import json
# 嵌入模型配置
EMBEDDING_PROVIDER = os.getenv("EMBEDDING_PROVIDER", "huggingface") # huggingface, dashscope
EMBEDDING_MODEL_NAME = os.getenv("EMBEDDING_MODEL_NAME", "./Qwen3-Embedding-0.6B")
# HuggingFace 配置
HF_MODEL_KWARGS = json.loads(os.getenv("HF_MODEL_KWARGS", "{}"))
HF_ENCODE_KWARGS = json.loads(os.getenv("HF_ENCODE_KWARGS", "{}"))
# DashScope 配置
DASHSCOPE_API_KEY = os.getenv("DASHSCOPE_API_KEY")
# 向量存储配置
VECTOR_STORE = os.getenv("VECTOR_STORE", "chroma")
CHROMA_PERSIST_DIR = os.getenv("CHROMA_PERSIST_DIR", "./chroma_db")
6. 实现缓存
- 提供标准化方法
- 提供一个极其方便的函数,获取唯一实例
在工厂类文件的最后添加如下代码(注意不要有缩进,这是一个模块级的一级函数):
# app/ai/vectorization/factory.py
from functools import lru_cache
from app.ai.config import (
EMBEDDING_PROVIDER, EMBEDDING_MODEL_NAME,
HF_MODEL_KWARGS, HF_ENCODE_KWARGS,
DASHSCOPE_API_KEY,
VECTOR_STORE, CHROMA_PERSIST_DIR
)
# Least Recently Used Cache
@lru_cache()
def get_vectorization_service() -> DocumentVectorizationService:
"""获取文档向量化服务"""
return VectorizationServiceFactory.create_vectorization_service(
embedding_provider=EMBEDDING_PROVIDER,
embedding_model_name=EMBEDDING_MODEL_NAME,
hf_model_kwargs=HF_MODEL_KWARGS,
hf_encode_kwargs=HF_ENCODE_KWARGS,
dashscope_api_key=DASHSCOPE_API_KEY,
vector_store_type=VECTOR_STORE,
chroma_persist_directory=CHROMA_PERSIST_DIR
)
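配置就绪后,任何地方调用 get_vectorization_service() 拿到的都是同一个服务实例(得益于 lru_cache)。下面是一个端到端的向量化与检索示意(文件路径与各类 ID 均为假设值):
# 使用示例(示意代码,路径与 ID 为假设值)
import asyncio
from app.ai.vectorization.factory import get_vectorization_service
async def main():
    service = get_vectorization_service()
    file_path = "./uploads/计算机原理.md"
    with open(file_path, "rb") as f:
        content = f.read()
    # 向量化文档
    result = await service.vectorize_document(
        file_path=file_path,
        file_content=content,
        document_id=1,
        knowledge_base_id=1,
        filename="计算机原理.md"
    )
    print(result["success"], result.get("chunks_count"))
    # 语义检索
    hits = await service.search_similar_documents(
        query="存储体系有几大模块",
        knowledge_base_ids=[1],
        k=5
    )
    for hit in hits:
        print(hit["score"], hit["content"][:50])
asyncio.run(main())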
五、实现向量服务API
1. 添加业务层向量服务DTO
# app/schemas/vector.py
from typing import Optional
from datetime import datetime
from pydantic import BaseModel, Field
class VectorSearchRequest(BaseModel):
"""向量搜索请求DTO模型"""
query: str = Field(..., description="搜索查询")
knowledge_base_ids: list[int] = Field(..., description="知识库ID列表")
k: int = Field(5, description="返回结果数量", ge=1, le=200)
class VectorSearchResult(BaseModel):
"""向量搜索结果DTO模型"""
chunk_id: str = Field(..., description="文档块ID")
content: str = Field(..., description="文档块内容")
score: float = Field(..., description="相似度分数")
document_id: int = Field(..., description="文档ID")
chunk_index: int = Field(..., description="文档块索引")
metadata: dict = Field(..., description="元数据")
class VectorSearchResponse(BaseModel):
"""向量搜索响应DTO模型"""
query: str = Field(..., description="搜索查询")
knowledge_base_ids: list[int] = Field(..., description="知识库ID列表")
results: list[VectorSearchResult] = Field(..., description="搜索结果")
total_results: int = Field(..., description="总结果数量")
class VectorizationStatusResponse(BaseModel):
"""向量化状态响应DTO模型"""
document_id: int = Field(..., description="文档ID")
is_vectorized: bool = Field(..., description="是否已向量化")
vectorization_status: str = Field(..., description="向量化状态")
chunks_count: Optional[int] = Field(None, description="文档块数量")
embedding_model: Optional[str] = Field(None, description="使用的嵌入模型")
vectorization_completed_at: Optional[datetime] = Field(None, description="向量化完成时间")
vectorization_error: Optional[str] = Field(None, description="向量化错误信息")
2. 改造文档数据模型
# app/data/models/document.py
from datetime import datetime, timezone
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, func, Boolean, Text
from sqlalchemy.orm import relationship
from app.data.models.base import Base
class Document(Base):
__tablename__ = "documents"
id = Column(Integer, primary_key=True)
title = Column(String(200), nullable=False) # 文档标题
source = Column(String(255), nullable=True) # 文档来源或路径
created_at = Column(DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc),
server_default=func.now()) # 创建时间
# 向量化相关字段
is_vectorized = Column(Boolean, default=False, nullable=False) # 是否已向量化
vectorization_status = Column(String(50), default="pending") # 向量化状态: pending, processing, completed, failed
vectorization_error = Column(Text, nullable=True) # 向量化错误信息
chunks_count = Column(Integer, default=0) # 文档块数量
embedding_model = Column(String(100), nullable=True) # 使用的嵌入模型
vectorization_completed_at = Column(DateTime(timezone=True), nullable=True) # 向量化完成时间
# 关联知识库ID
knowledge_base_id = Column(Integer, ForeignKey("knowledge_bases.id"), nullable=False)
# 定义关系:文档 -> 知识库(多对一)
knowledge_base = relationship("KnowledgeBase", back_populates="documents")
数据库迁移
alembic revision --autogenerate -m "RAG"
alembic upgrade head
3. 实现业务层向量服务
# app/service/vector.py
from datetime import timezone, datetime
from pathlib import Path
from typing import Optional, List
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.ai.vectorization.service import DocumentVectorizationService
from app.data.models import Document
from app.schemas.vector import (
VectorSearchRequest,
VectorSearchResult,
VectorSearchResponse,
VectorizationStatusResponse
)
class VectorService:
"""向量服务类,处理文档向量化相关操作"""
def __init__(self, db: AsyncSession, vectorization_service: Optional[DocumentVectorizationService] = None):
self.db = db
self.vectorization_service = vectorization_service
async def vectorize_document(
self, document_id: int,
        chunk_size: Optional[int] = 1000,
        chunk_overlap: Optional[int] = 200) -> bool:
"""向量化文档"""
if not self.vectorization_service:
raise ValueError("向量化服务未配置")
# 获取文档信息
result = await self.db.execute(
select(Document).where(Document.id == document_id)
)
doc = result.scalar_one_or_none()
if not doc:
raise ValueError(f"文档 ID {document_id} 不存在")
if doc.is_vectorized:
return True # 已经向量化过了
try:
# 更新状态为处理中
await self.db.execute(
update(Document)
.where(Document.id == document_id)
.values(
vectorization_status="processing",
vectorization_error=None
)
)
await self.db.commit()
# 读取文件内容
if not doc.source or not Path(doc.source).exists():
raise ValueError(f"文档文件不存在: {doc.source}")
with open(doc.source, "rb") as f:
file_content = f.read()
# 执行向量化
vectorization_result = await self.vectorization_service.vectorize_document(
file_path=doc.source,
file_content=file_content,
document_id=document_id,
knowledge_base_id=doc.knowledge_base_id,
filename=doc.title,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap
)
if vectorization_result["success"]:
# 更新文档状态为已完成
await self.db.execute(
update(Document)
.where(Document.id == document_id)
.values(
is_vectorized=True,
vectorization_status="completed",
chunks_count=vectorization_result["chunks_count"],
embedding_model=vectorization_result["model_name"],
vectorization_completed_at=datetime.now(timezone.utc),
vectorization_error=None
)
)
await self.db.commit()
return True
else:
# 更新文档状态为失败
await self.db.execute(
update(Document)
.where(Document.id == document_id)
.values(
vectorization_status="failed",
vectorization_error=vectorization_result["error"]
)
)
await self.db.commit()
return False
except Exception as e:
# 更新文档状态为失败
await self.db.execute(
update(Document)
.where(Document.id == document_id)
.values(
vectorization_status="failed",
vectorization_error=str(e)
)
)
await self.db.commit()
raise e
async def search_similar_documents(
self,
request: VectorSearchRequest
) -> VectorSearchResponse:
"""搜索相似文档"""
if not self.vectorization_service:
raise ValueError("向量化服务未配置")
# 调用向量化服务搜索
results = await self.vectorization_service.search_similar_documents(
query=request.query,
knowledge_base_ids=request.knowledge_base_ids,
k=request.k
)
# 转换结果为schema格式
search_results = []
for result in results:
search_results.append(VectorSearchResult(
chunk_id=result.get("chunk_id", ""),
content=result.get("content", ""),
score=result.get("score", 0.0),
document_id=result.get("document_id", 0),
chunk_index=result.get("chunk_index", 0),
metadata=result.get("metadata", {})
))
return VectorSearchResponse(
query=request.query,
knowledge_base_ids=request.knowledge_base_ids,
results=search_results,
total_results=len(search_results)
)
async def get_vectorization_status(self, document_id: int) -> VectorizationStatusResponse:
"""获取文档向量化状态"""
result = await self.db.execute(
select(Document).where(Document.id == document_id)
)
doc = result.scalar_one_or_none()
if not doc:
raise ValueError(f"文档 ID {document_id} 不存在")
return VectorizationStatusResponse(
document_id=document_id,
is_vectorized=doc.is_vectorized,
vectorization_status=doc.vectorization_status,
chunks_count=doc.chunks_count,
embedding_model=doc.embedding_model,
vectorization_completed_at=doc.vectorization_completed_at,
vectorization_error=doc.vectorization_error
)
async def get_vectorization_status_by_knowledge_base(self, knowledge_base_id: int) -> List[VectorizationStatusResponse]:
"""获取知识库中所有文档的向量化状态"""
result = await self.db.execute(
select(Document).where(Document.knowledge_base_id == knowledge_base_id)
)
docs = result.scalars().all()
status_list = []
for doc in docs:
status_list.append(VectorizationStatusResponse(
document_id=doc.id,
is_vectorized=doc.is_vectorized,
vectorization_status=doc.vectorization_status,
chunks_count=doc.chunks_count,
embedding_model=doc.embedding_model,
vectorization_completed_at=doc.vectorization_completed_at,
vectorization_error=doc.vectorization_error
))
return status_list
async def batch_vectorize_documents(self, document_ids: List[int]) -> List[dict]:
"""批量向量化文档"""
results = []
for doc_id in document_ids:
try:
success = await self.vectorize_document(doc_id)
results.append({
"document_id": doc_id,
"success": success,
"message": "向量化成功" if success else "向量化失败"
})
except Exception as e:
results.append({
"document_id": doc_id,
"success": False,
"message": f"向量化失败: {str(e)}"
})
return results
4. 实现文档向量服务API
# app/api/vector.py
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from app.data.db import get_db
from app.schemas.vector import (
VectorSearchRequest,
VectorSearchResponse,
VectorizationStatusResponse
)
from app.service.vector import VectorService
from app.ai.vectorization.factory import get_vectorization_service
router = APIRouter(prefix="/vectors", tags=["向量搜索与向量化"])
@router.post("/search", response_model=VectorSearchResponse)
async def vector_search(
search_request: VectorSearchRequest,
db: AsyncSession = Depends(get_db)
):
"""向量搜索接口"""
vectorization_service = get_vectorization_service()
vector_service = VectorService(db, vectorization_service)
try:
return await vector_service.search_similar_documents(search_request)
except ValueError as e:
raise HTTPException(
status_code=400,
detail=str(e)
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"向量搜索失败: {str(e)}"
)
@router.post("/documents/{doc_id}/vectorize")
async def vectorize_document(
doc_id: int,
chunk_size: Optional[int] = 1000,
chunk_overlap: Optional[int] = 200,
db: AsyncSession = Depends(get_db)
):
"""向量化文档"""
vectorization_service = get_vectorization_service()
vector_service = VectorService(db, vectorization_service)
try:
success = await vector_service.vectorize_document(doc_id, chunk_size, chunk_overlap)
if success:
return {"message": "文档向量化成功", "document_id": doc_id}
else:
raise HTTPException(
status_code=500,
detail="文档向量化失败"
)
except ValueError as e:
raise HTTPException(
status_code=400,
detail=str(e)
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"文档向量化失败: {str(e)}"
)
@router.get("/documents/{doc_id}/vectorization-status", response_model=VectorizationStatusResponse)
async def get_document_vectorization_status(
doc_id: int,
db: AsyncSession = Depends(get_db)
):
"""获取文档向量化状态"""
vectorization_service = get_vectorization_service()
vector_service = VectorService(db, vectorization_service)
try:
status = await vector_service.get_vectorization_status(doc_id)
return status
except ValueError as e:
raise HTTPException(
status_code=404,
detail=str(e)
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"获取向量化状态失败: {str(e)}"
)
@router.post("/batch-vectorize")
async def batch_vectorize_documents(
document_ids: List[int],
db: AsyncSession = Depends(get_db)
):
"""批量向量化文档"""
vectorization_service = get_vectorization_service()
vector_service = VectorService(db, vectorization_service)
try:
results = await vector_service.batch_vectorize_documents(document_ids)
return {
"message": "批量向量化完成",
"results": results,
"total": len(document_ids),
"success_count": sum(1 for r in results if r["success"])
}
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"批量向量化失败: {str(e)}"
)
@router.get("/knowledge-bases/{kb_id}/vectorization-status")
async def get_knowledge_base_vectorization_status(
kb_id: int,
db: AsyncSession = Depends(get_db)
):
"""获取知识库中所有文档的向量化状态"""
vectorization_service = get_vectorization_service()
vector_service = VectorService(db, vectorization_service)
try:
status_list = await vector_service.get_vectorization_status_by_knowledge_base(kb_id)
return {
"knowledge_base_id": kb_id,
"documents": status_list,
"total_documents": len(status_list),
"vectorized_count": sum(1 for doc in status_list if doc.is_vectorized),
"processing_count": sum(1 for doc in status_list if doc.vectorization_status == "processing"),
"failed_count": sum(1 for doc in status_list if doc.vectorization_status == "failed")
}
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"获取知识库向量化状态失败: {str(e)}"
)
添加API路由配置
# app/main.py
# 注册Vector路由
app.include_router(
vector.router,
prefix="/api"
)
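路由注册完成后,可以先用下面的脚本快速验证向量搜索接口(示意代码,服务地址 http://localhost:8000 为假设值,请按实际部署调整):
# 接口快速验证(示意代码,服务地址为假设值)
import json
import urllib.request
payload = json.dumps({
    "query": "存储体系有几大模块",
    "knowledge_base_ids": [1],
    "k": 5
}).encode("utf-8")
req = urllib.request.Request(
    "http://localhost:8000/api/vectors/search",
    data=payload,
    headers={"Content-Type": "application/json"},
    method="POST"
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read().decode("utf-8")))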
至此,RAG服务就完成了。
六、AI聊天服务集成RAG生成链
RAG服务有了,我们如何将RAG服务与AI聊天结合起来呢?当用户提出的问题需要从知识库中检索数据时,AI服务应能自动调用RAG服务,把检索到的信息交给大模型,再生成回答。
我们先来梳理一下AI聊天平台如何集成RAG服务都要做什么。
1. 需求分析
- 可控性:是否启用知识库功能
- 知识范围动态化:角色、会话、聚合
- 上下文感知检索
- 答案有据可依
- 可溯源性
- 优雅降级
- 流式响应
2. 完整流程
- 请求(API层):{ "message": "A", "session_id": 123, "use_knowledge_base": true }
- 路由(API层):调用聊天方法
- 编排(服务层):确定知识库范围 [KB A、KB B、KB C],调用 RAG 聊天方法
- 检索(AI层):在向量库中检索相关片段
- 决策(AI层):
  - 无结果:降级到标准聊天方法
  - 有结果:取 n 个片段,格式化为 JSON
- 增强(AI层):将检索结果注入系统提示词
- 生成(LLM):流式生成回答
3. 创建RAG聊天方法
前面我们已经实现了一个普通的聊天方法,接下来再实现一个RAG聊天方法,代码如下:
# app/ai/chat.py
class AIChatService:
    async def _search_knowledge_base(self, query: str, knowledge_base_ids: List[int]) -> Optional[List[Dict[str, Any]]]:
"""搜索知识库并返回带有元数据的JSON格式结果"""
if not self.vectorization_service or not knowledge_base_ids:
return None
# 直接调用DocumentVectorizationService搜索相关文档
search_results = await self.vectorization_service.search_similar_documents(
query=query,
knowledge_base_ids=knowledge_base_ids,
k=20
)
if not search_results:
return None
# 将搜索结果转换为带有元数据的JSON格式
formatted_results = []
for i, result in enumerate(search_results, 1):
metadata = result.get('metadata', {})
formatted_result = {
"index": i,
"content": result['content'],
"metadata": {
"filename": metadata.get('filename', '未知文件'),
"chunk_index": metadata.get('chunk_index', 0),
"total_chunks": metadata.get('total_chunks', 1)
}
}
formatted_results.append(formatted_result)
return formatted_results
    def _create_rag_chain(self, knowledge_data: List[Dict[str, Any]]) -> Runnable:
"""创建RAG链,包含知识库上下文"""
# 将知识库数据转换为JSON字符串
knowledge_json = json.dumps(knowledge_data, ensure_ascii=False, indent=2)
# 构建增强的系统提示词
enhanced_system_prompt = f"""{self.role.system_prompt}
基于以下知识库信息回答用户问题:
知识库搜索结果:
{knowledge_json}
重要要求:
1. 请根据上述知识库信息回答用户的问题
2. 在回答中必须明确指出信息来源,使用以下格式引用:
- 当引用某个文档片段时,使用 [1]、[2] 等格式标注
- 在回答末尾提供完整的来源列表
3. 如果知识库中没有相关信息,请明确说明"根据知识库信息,未找到相关内容"
4. 回答要准确、详细,并确保每个关键信息都有明确的来源标注
5. 充分利用JSON中的元数据信息,包括文件名、片段位置等
文档来源格式说明:
- [1] 表示第1个引用
- 每个结果包含:文件名、片段位置等元数据
- 在回答末尾提供完整的来源列表"""
# 构建Prompt模板
prompt = ChatPromptTemplate.from_messages([
SystemMessage(content=enhanced_system_prompt),
MessagesPlaceholder(variable_name="history"),
("human", "{input}"),
])
# 创建RAG链
return (
RunnablePassthrough.assign(
history=RunnableLambda(lambda x: self.trimmer.invoke(x["history"]))
)
| prompt
| self.llm
| self.parser
)
    async def rag_chat_stream(self, history: List[BaseMessage], user_input: str,
                              knowledge_base_ids: List[int]) -> AsyncGenerator[str, None]:
"""
基于知识库的RAG流式聊天
"""
# 搜索知识库
knowledge_data = await self._search_knowledge_base(user_input, knowledge_base_ids)
if not knowledge_data:
# 如果没有找到相关文档,回退到普通聊天
async for chunk in self.chat_stream(history, user_input):
yield chunk
return
# 创建RAG链
rag_chain = self._create_rag_chain(knowledge_data)
# 流式生成
async for chunk in rag_chain.astream({"history": history, "input": user_input}):
yield chunk
4. 业务方法集成RAG生成
接下来我们在会话业务中集成RAG生成。
首先,我们修改消息服务 MessageService:
- 添加获取会话和角色关联的知识库 ID 列表的方法。
- 修改 chat 方法,根据用户选择决定使用 RAG 聊天还是普通聊天。
# app/service/message.py
class MessageService:
async def _get_knowledge_base_ids(self, session_id: int) -> List[int]:
"""获取会话和角色关联的知识库ID列表"""
session = await self.db.scalar(
select(Session)
.options(
selectinload(Session.knowledge_bases),
selectinload(Session.role).selectinload(Role.knowledge_bases)
)
.where(Session.id == session_id)
)
# 收集知识库ID
knowledge_base_ids = set()
# 添加会话关联的知识库ID
for kb in session.knowledge_bases:
knowledge_base_ids.add(kb.id)
# 添加角色关联的知识库ID
for kb in session.role.knowledge_bases:
knowledge_base_ids.add(kb.id)
return list(knowledge_base_ids)
    async def chat(self, chat_request: ChatRequest) -> AsyncGenerator[str, None]:
"""处理聊天请求"""
# 验证session是否存在
session = await self.db.get(Session, chat_request.session_id)
if not session:
raise ValueError(f"会话 ID {chat_request.session_id} 不存在")
# 获取会话绑定的角色与模型信息
role = await self.db.scalar(
select(Role)
.options(selectinload(Role.provider)) # 贪婪加载 provider
.where(Role.id == session.role_id)
)
# 创建聊天历史管理器
history_mgr = ChatHistoryManager(self.db, chat_request.session_id)
history = await history_mgr.load_history()
assistant_content_buffer = ""
# 根据用户选择决定使用RAG还是普通聊天
if chat_request.use_knowledge_base:
# 检查是否有知识库可用
knowledge_base_ids = await self._get_knowledge_base_ids(chat_request.session_id)
ai = AIChatService(role, vectorization_service=self.vectorization_service)
async for chunk in ai.rag_chat_stream(history, chat_request.message, knowledge_base_ids):
assistant_content_buffer += chunk
payload = json.dumps({
"content": chunk
}, ensure_ascii=False)
yield payload
else:
ai = AIChatService(role)
async for chunk in ai.chat_stream(history, chat_request.message):
assistant_content_buffer += chunk
payload = json.dumps({
"content": chunk
}, ensure_ascii=False)
yield payload
# 保存消息到 DB + Redis
await history_mgr.save_message(MessageBase(
role=MessageRole.User,
content=chat_request.message,
session_id=chat_request.session_id
))
await history_mgr.save_message(MessageBase(
role=MessageRole.Assistant,
content=assistant_content_buffer,
session_id=chat_request.session_id
))
5. 修改API方法
在创建 service 对象时传入向量化服务实例。
# app/api/session.py
@router.post("/chat")
async def chat(chat_request: ChatRequest, db: AsyncSession = Depends(get_db)):
"""处理聊天请求"""
# service = MessageService(db)
# 创建向量化服务实例
vectorization_service = get_vectorization_service()
# 创建消息服务实例,注入向量化服务
service = MessageService(db, vectorization_service)
# SSE
async def event_generator():
try:
async for chunk in service.chat(chat_request):
# SSE 格式,每条数据前面加 "data: ",末尾两个换行
event = f"data: {chunk}\n\n"
yield event
except ValueError as e:
yield f"\n[Error] {str(e)}"
except Exception as e:
yield f"\n[Error] 处理聊天请求失败: {str(e)}"
return StreamingResponse(event_generator(), media_type="text/event-stream")
七、测试
- 创建一个知识库
{
"name": "我的知识库",
"description": "测试用"
}
上传文档:这里我上传了一个markdown文件《计算机原理.md》
文档向量化:将上传的《计算机原理.md》向量化
查询文档向量化状态
测试语义搜索
{
"query": "存储体系有几大模块",
"knowledge_base_ids": [
1
],
"k": 5
}
- 创建一个新的会话,并绑定知识库
{
"title": "知识库聊天",
"role_id": 1,
"knowledge_base_ids": [1]
}
- 直接聊天提问
{
"message": "存储体系有几大模块",
"session_id": 3,
"use_knowledge_base": true
}
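最后,也可以直接以流式方式调用聊天接口验证 RAG 效果。下面的脚本为示意代码:聊天接口的完整路径取决于 session 路由在 app/main.py 中注册的前缀,这里假设为 /api/sessions/chat,服务地址同样为假设值:
# SSE 聊天验证(示意代码,接口路径与服务地址为假设值)
import json
import urllib.request
payload = json.dumps({
    "message": "存储体系有几大模块",
    "session_id": 3,
    "use_knowledge_base": True
}).encode("utf-8")
req = urllib.request.Request(
    "http://localhost:8000/api/sessions/chat",
    data=payload,
    headers={"Content-Type": "application/json"},
    method="POST"
)
with urllib.request.urlopen(req) as resp:
    # 逐行读取 SSE 流,每条消息以 "data: " 开头
    for raw_line in resp:
        line = raw_line.decode("utf-8").strip()
        if line.startswith("data: "):
            print(json.loads(line[len("data: "):]).get("content", ""), end="")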