Spiga

A General-Purpose AI Chat Platform, Part 2: Implementing the AI Chat Service

2025-07-19 15:02:32

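In the previous part we implemented the project's non-AI foundation services, such as providers, roles, and sessions. This part starts on the AI side of the service.

We first implement an AI chat service, and then build on it so that the chat service can work with uploaded knowledge bases.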

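I. Implementing the AI Chat Service

The AI chat service does one simple thing: it takes the AI role's configuration, the conversation history, and the latest message from the user, sends them to a large language model, and returns the generated reply.

The AI chat service is built on LangChain.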

1. Install LangChain

uv add langchain langchain-openai

2. Implement the AI chat service class

# app/ai/chat.py

from typing import List, AsyncGenerator
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from app.data.models import Role
from app.schemas.message import MessageBase, MessageRole

class AIChatService:
    def __init__(self, role: Role):
        # Create the ChatOpenAI client from the role's provider settings
        llm = ChatOpenAI(
            base_url=role.provider.endpoint,
            api_key=role.provider.api_key,
            model=role.provider.model,
            temperature=role.temperature,
            streaming=True
        )

        # Build the prompt template
        prompt = ChatPromptTemplate.from_messages([
            # System prompt; a SystemMessage is used verbatim, so "{" or "}" inside
            # the prompt text are not mistaken for template variables
            SystemMessage(content=role.system_prompt or ""),
            MessagesPlaceholder(variable_name="history"),   # conversation history
            ("human", "{input}"),                           # latest user message
        ])

        # Output parser: extracts just the reply text from the model's response message
        parser = StrOutputParser()

        # Compose the chain: prompt -> model -> parser
        self.chain = prompt | llm | parser

    async def chat(self, history: List[MessageBase], user_input: str) -> str:
        # Convert stored messages into LangChain messages, keeping only user and
        # assistant turns (list comprehension). Comparing enum members directly
        # works whether or not MessageRole is a str-based Enum.
        chat_history = [
            HumanMessage(content=msg.content) if msg.role == MessageRole.User else AIMessage(content=msg.content)
            for msg in history if msg.role in (MessageRole.User, MessageRole.Assistant)
        ]

        # Invoke the chain
        result = await self.chain.ainvoke({"history": chat_history, "input": user_input})
        return result
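How a role check such as msg.role.value == MessageRole.User behaves depends on how MessageRole is declared in the previous part, which is not shown here. The standalone check below uses two hypothetical enums, PlainRole (a plain Enum) and StrRole (a str, Enum mix-in), to show that comparing the raw .value with a member is only true for the str mix-in, while comparing members directly works for both. It is an illustration, not the project's actual enum.

```python
from enum import Enum

class PlainRole(Enum):
    """Hypothetical role enum without a str mix-in."""
    User = "user"
    Assistant = "assistant"

class StrRole(str, Enum):
    """Hypothetical role enum whose members are also str instances."""
    User = "user"
    Assistant = "assistant"

def is_user_by_value(role) -> bool:
    # Compare the raw value ("user") with the enum member itself
    return role.value == type(role).User

def is_user_by_member(role) -> bool:
    # Compare enum members directly
    return role == type(role).User

print(is_user_by_value(PlainRole.User))   # False: "user" is not equal to PlainRole.User
print(is_user_by_value(StrRole.User))     # True: StrRole.User is also the string "user"
print(is_user_by_member(PlainRole.User))  # True
print(is_user_by_member(StrRole.User))    # True
```

With a plain Enum, a value-based filter would silently drop every message from the history, so comparing members is the safer default.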

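3. Use the AI chat service in the message service

When we built the message service in the previous part, it returned a hard-coded mock reply. Now we can replace that code with a call to the AI chat service:

# app/service/message.py

# Old code
# This is where the AI service should be called to generate the reply;
# for now, return a mock reply
assistant_content = (f"I am '{role.name}'. We have exchanged {len(conversation_history)} messages, and I received your message: '{chat_request.message}'. "
                     "This is a mock reply; in the real application an AI model would generate it.")

# New code (requires: from app.ai.chat import AIChatService)
ai = AIChatService(role)
# Call the AI chat service to generate the reply
assistant_content = await ai.chat(conversation_history, chat_request.message)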

4. Run a test

  • Create a provider
{
  "name": "通义千问",
  "endpoint": "https://dashscope.aliyuncs.com/compatible-mode/v1",
  "model": "qwen-plus",
  "api_key": "sk-xxxxxx"
}
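  • Create a role
{
  "name": "Test Assistant",
  "description": "An AI assistant for testing",
  "system_prompt": "Your name is Xiaosuo. You are an AI learning assistant who is good at guiding users to learn all kinds of knowledge and technologies.",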
  "temperature": 0.8,
  "provider_id": 1
}
  • Create a session
{
  "title": "Test session",
  "role_id": 1
}
  • Chat
{
  "message": "Who are you?",
  "session_id": 1
}

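// AI reply
{
  "role": "assistant",
  "content": "Hi, I'm Xiaosuo! Your AI study buddy~ I love exploring new knowledge together with everyone. Whether it's programming, math, or any other fun tech topic, we can dig into it together. I may be small, but I'm full of enthusiasm for learning! Come on, tell me what you'd like to learn. Let's start our learning journey together! (•̀ᴗ•́)و",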
  "session_id": 1,
  "id": 2,
  "created_at": "2025-10-02T10:29:49.349687"
}

// Follow-up question
{
  "message": "What is LangChain?",
  "session_id": 1
}

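// AI reply
{
  "role": "assistant",
  "content": "Hi, I'm Xiaosuo! Let me tell you about LangChain~\n\nYou know how we build things out of blocks? LangChain is a toolbox of \"building blocks\" made specifically for large language models! It works like a super connector that links all kinds of tools and data sources together.\n\nImagine you want an AI assistant to check the weather, book a restaurant, and write an email for you. That takes lots of different skills, right? In the past you had to write separate code for each of them, which was a real hassle. With LangChain, it's like getting a set of smart LEGO: you snap these features together as easily as building blocks!\n\nIts greatest strengths are:\n1. It remembers the conversation history, so the chat never loses track of what was said\n2. It can look up information, like giving the AI a little secretary\n3. It can call all kinds of tools, such as calculators and search engines\n\nWant to try it hands-on together? We could start with a small project, like a smart Q&A bot. What do you think? (•̀ᴗ•́)و",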
  "session_id": 1,
  "id": 4,
  "created_at": "2025-10-02T10:29:49.349687"
}

// Third question
{
  "message": "What features does it have?",
  "session_id": 1
}

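// AI reply
{
  "role": "assistant",
  "content": "Oh, let me count LangChain's skills on my fingers! It really is a little all-rounder~\n\nFirst, and most impressive, is its \"memory magic\"! Just like we remember what was said earlier in a chat, LangChain lets the AI remember the conversation history. For example, if you ask \"What's the weather like in Beijing?\" and then \"What about Shanghai?\", it won't get confused.\n\nThen there's the super cool \"toolbox\"! Here's an example: suppose we're building a travel-planning assistant:\n- Need to look something up? The Retrieval feature helps you find answers in huge amounts of information\n- Need to do math? You can connect a calculator tool\n- Want to analyze a web page? No problem, just call a browser tool\n\nThe most interesting part is its \"chain of thought\" ability! Just as we solve math problems step by step, LangChain lets the AI first think \"what do I need to do?\" and then solve the problem in steps. For example, to work out \"3 people stay 5 nights at 300 yuan per night, how much per person?\", it first computes the total and then splits it evenly.\n\nOh, and it can also be a \"data translator\"! It can convert data between formats, for example turning information from a database into a nice report.\n\nHey, you look interested. Want to build a little smart assistant together for fun? We can start simple~ Which feature would you like to try first? (•̀ᴗ•́)و",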
  "session_id": 1,
  "id": 6,
  "created_at": "2025-10-02T10:29:49.349687"
}

In the example above, when the third question asks "What features does it have?", the AI understands that "it" refers to LangChain.
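The model itself keeps no memory between requests; the follow-up works because the prompt template resends the system prompt and the whole conversation history with every new question. The plain-Python sketch below approximates the message list that the system / MessagesPlaceholder("history") / human template produces, written in the OpenAI chat-message format. build_chat_messages is a hypothetical helper, and the real request LangChain sends carries more fields than this.

```python
from typing import List, Tuple

def build_chat_messages(system_prompt: str,
                        history: List[Tuple[str, str]],
                        user_input: str) -> List[dict]:
    """Flatten system prompt + history + latest input into one message list.

    history holds (role, content) pairs where role is "user" or "assistant".
    """
    messages = [{"role": "system", "content": system_prompt}]
    # The full history is resent on every request; this is what lets the
    # model resolve a pronoun like "it" in a follow-up question
    messages += [{"role": role, "content": content} for role, content in history]
    messages.append({"role": "user", "content": user_input})
    return messages

msgs = build_chat_messages(
    "You are Xiaosuo, an AI learning assistant.",
    [("user", "Who are you?"), ("assistant", "I'm Xiaosuo!"),
     ("user", "What is LangChain?"), ("assistant", "LangChain is a framework...")],
    "What features does it have?",
)
for m in msgs:
    print(m["role"], "->", m["content"])
```

Because the history grows with every turn, so does each request; the trimming step later in this part exists to cap that growth.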

II. Implementing Streaming Output

1. Add an async streaming method

# app/ai/chat.py  (a new method of the AIChatService class)

async def chat_stream(self, history: List[MessageBase], user_input: str) -> AsyncGenerator[str, None]:
    """
    Stream the reply, yielding each piece as soon as it is generated
    """
    chat_history = [
        HumanMessage(content=msg.content) if msg.role == MessageRole.User else AIMessage(content=msg.content)
        for msg in history if msg.role in (MessageRole.User, MessageRole.Assistant)
    ]

    # Stream through the whole chain (prompt -> LLM -> parser); each chunk is a piece of text
    async for chunk in self.chain.astream({"history": chat_history, "input": user_input}):
        yield chunk

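2. Update the chat service

  • Call the AI streaming method instead of the one-shot method
  • Yield results asynchronously instead of returning them
  • Consequently, the function no longer needs return
  • Take assistant_content from assistant_content_buffer

The complete function: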

# app/service/message.py

import json
from typing import AsyncGenerator
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from app.ai.chat import AIChatService

async def chat(self, chat_request: ChatRequest) -> AsyncGenerator[str, None]:
    """Handle a chat request"""
    # Make sure the session exists
    session = await self.db.get(Session, chat_request.session_id)
    if not session:
        raise ValueError(f"Session ID {chat_request.session_id} does not exist")

    # Load the role bound to the session, together with its model provider
    role = await self.db.scalar(
        select(Role)
        .options(selectinload(Role.provider))  # eager-load provider
        .where(Role.id == session.role_id)
    )
    if not role:
        # role is None here, so report the ID stored on the session instead of role.name
        raise ValueError(f"Role ID {session.role_id} does not exist")

    # Load the message history
    conversation_history = await self.get_conversation_history(chat_request.session_id)

    ai = AIChatService(role)
    # Call the AI service's streaming method
    assistant_content_buffer = ""
    async for chunk in ai.chat_stream(conversation_history, chat_request.message):
        assistant_content_buffer += chunk
        payload = json.dumps({
            "content": chunk
        }, ensure_ascii=False)
        yield payload  # send each piece to the caller as soon as it is generated

    # Save the user message
    user_message = MessageBase(
        role=MessageRole.User,
        content=chat_request.message,
        session_id=chat_request.session_id
    )
    await self.create(user_message)

    # Save the assistant reply
    assistant_message = MessageBase(
        role=MessageRole.Assistant,
        content=assistant_content_buffer,
        session_id=chat_request.session_id
    )
    await self.create(assistant_message)
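The buffering pattern in chat can be tried without an LLM or a database. In the standalone sketch below, fake_llm_stream is a hypothetical stand-in for AIChatService.chat_stream and a plain list stands in for the database; relay forwards each chunk as a JSON payload and saves the accumulated text only after the stream has finished.

```python
import asyncio
import json
from typing import AsyncGenerator, List

async def fake_llm_stream(pieces: List[str]) -> AsyncGenerator[str, None]:
    """Hypothetical stand-in for AIChatService.chat_stream."""
    for piece in pieces:
        await asyncio.sleep(0)  # yield control, as a real network stream would
        yield piece

async def relay(pieces: List[str], saved: List[str]) -> AsyncGenerator[str, None]:
    """Forward each chunk as JSON and persist the full text once the stream ends."""
    buffer = ""
    async for chunk in fake_llm_stream(pieces):
        buffer += chunk
        yield json.dumps({"content": chunk}, ensure_ascii=False)
    saved.append(buffer)  # reached only if the caller consumed every chunk

async def consume_all() -> List[str]:
    saved: List[str] = []
    payloads = [p async for p in relay(["Hi", ", I'm", " Xiaosuo!"], saved)]
    print(payloads)
    return saved

async def consume_first_only() -> List[str]:
    saved: List[str] = []
    stream = relay(["Hi", ", I'm", " Xiaosuo!"], saved)
    print(await stream.__anext__())
    await stream.aclose()  # the consumer stops early
    return saved

full = asyncio.run(consume_all())
partial = asyncio.run(consume_first_only())
print(full)     # ["Hi, I'm Xiaosuo!"]
print(partial)  # []
```

The second run shows a consequence of this design: if the generator is closed before the stream ends, the code after the async for loop never runs, so neither the question nor the partial answer is saved.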

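3. Update the chat API

Use the SSE (Server-Sent Events) protocol, with an inner generator function producing the streaming HTTP response:

import json
from fastapi import Depends
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
# router, get_db, ChatRequest and MessageService come from the previous part

@router.post("/chat")
async def chat(chat_request: ChatRequest, db: AsyncSession = Depends(get_db)):
    """Handle a chat request"""
    service = MessageService(db)

    # SSE event stream
    async def event_generator():
        try:
            async for chunk in service.chat(chat_request):
                # SSE format: prefix each payload with "data: " and end the event with two newlines
                event = f"data: {chunk}\n\n"
                yield event
        except ValueError as e:
            # Errors are framed as SSE events too, otherwise SSE clients ignore them
            error = {"error": str(e)}
            yield f"data: {json.dumps(error, ensure_ascii=False)}\n\n"
        except Exception as e:
            error = {"error": f"Failed to process chat request: {e}"}
            yield f"data: {json.dumps(error, ensure_ascii=False)}\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")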
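Each SSE event is one or more data: lines followed by a blank line, and a client joins the data values of one event. The sketch below frames payloads the same way as event_generator and adds a minimal parser, roughly what a client does with the response body. to_sse and parse_sse are hypothetical helpers; the parser reads only data fields, handles only "\n" line endings, and ignores event:, id:, retry: and comment lines. The last call shows why errors should be framed as data: events: a bare "[Error] ..." line is not a recognized field, so it is skipped.

```python
import json
from typing import List

def to_sse(payload: dict) -> str:
    """Frame one JSON payload as an SSE event: a data line plus a blank line."""
    return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

def parse_sse(raw: str) -> List[dict]:
    """Minimal parser: events are separated by blank lines; only data fields are read."""
    events = []
    for block in raw.split("\n\n"):
        data_lines = []
        for line in block.split("\n"):
            if line.startswith("data:"):
                value = line[len("data:"):]
                # The format allows one optional space after the colon
                data_lines.append(value[1:] if value.startswith(" ") else value)
        if data_lines:
            # Several data lines in one event are joined with newlines
            events.append(json.loads("\n".join(data_lines)))
    return events

stream = "".join(to_sse({"content": c}) for c in ["Hi", ", I'm", " Xiaosuo!"])
print("".join(e["content"] for e in parse_sse(stream)))   # Hi, I'm Xiaosuo!

# A bare "[Error] ..." line carries no data field, so it produces no event
print(len(parse_sse(stream + "\n[Error] session not found")))  # 3
```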

4. Run a test

Create a new session and start a chat.

This time the reply is streamed back:

{
  "message": "Who are you?",
  "session_id": 2
}

// AI reply
data: {"content": ""}

data: {"content": "Hi"}

data: {"content": ", I'm"}

data: {"content": " Xiaosuo!"}

data: {"content": " *blinks"}

data: {"content": "* \n\nI'm"}

data: {"content": " a little assistant who loves learning"}

data: {"content": ", and I love nothing more than"}

data: {"content": " exploring new knowledge with everyone"}

data: {"content": ". Like right now,"}

data: {"content": " I'm really"}

data: {"content": " happy to see you! Whether you're"}

data: {"content": " learning to code, doing"}

data: {"content": " research, or just want to"}

data: {"content": " pick up everyday knowledge"}

data: {"content": ", I'd be glad to"}

data: {"content": " explore it with you.\n\n*tilts"}

data: {"content": " head and looks at you curiously* "}

data: {"content": "Are you learning anything interesting"}

data: {"content": " lately? Maybe"}

data: {"content": " we can talk about it together!"}

data: {"content": ""}

III. Caching the Conversation History

1. Install and configure Redis

uv add redis langchain_community
app/config.py

import os

REDIS_URL = os.getenv("REDIS_URL")
app/.env

REDIS_URL=redis://127.0.0.1:6379/0

2. A chat history manager class

# app/ai/history.py

from langchain_core.messages import BaseMessage
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from langchain_community.chat_message_histories import RedisChatMessageHistory
from app.data.models import Message, Session
from app.schemas.message import MessageBase, MessageResponse, MessageRole
from app.config import REDIS_URL

class ChatHistoryManager:
    def __init__(self, db: AsyncSession, session_id: int, cache_ttl: int = 600):
        self.db = db
        self.session_id = session_id
        self.redis = Redis.from_url(REDIS_URL, decode_responses=True)  # async client; not used by the methods below
        # RedisChatMessageHistory uses a synchronous Redis client,
        # so its reads and writes briefly block the event loop
        self.redis_history = RedisChatMessageHistory(
            session_id=f"chat:{session_id}",
            url=REDIS_URL,
            ttl=cache_ttl     # cache for 10 minutes by default
        )

    async def load_history(self) -> list[BaseMessage]:
        """Load the history: try Redis first and backfill from the DB on a miss"""
        if len(self.redis_history.messages) == 0:
            # Backfill from the DB: fetch the 100 most recent messages...
            result = await self.db.execute(
                select(Message)
                .where(Message.session_id == self.session_id)
                .order_by(Message.id.desc())
                .limit(100)
            )
            # ...and put them back in chronological order (oldest first)
            messages = reversed(result.scalars().all())
            for msg in messages:
                if msg.role == MessageRole.User:
                    self.redis_history.add_user_message(msg.content)
                elif msg.role == MessageRole.Assistant:
                    self.redis_history.add_ai_message(msg.content)

        return self.redis_history.messages

    async def save_message(self, message: MessageBase) -> MessageResponse:
        """Save a message to the DB and to Redis"""
        # Save to the DB
        session = await self.db.get(Session, message.session_id)
        if not session:
            raise ValueError(f"Session ID {message.session_id} does not exist")

        message_data = message.model_dump()
        db_message = Message(**message_data)
        self.db.add(db_message)
        await self.db.commit()
        await self.db.refresh(db_message)

        # Save to Redis
        if message.role == MessageRole.User:
            self.redis_history.add_user_message(message.content)
        elif message.role == MessageRole.Assistant:
            self.redis_history.add_ai_message(message.content)

        return MessageResponse.model_validate(db_message)
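load_history follows the cache-aside pattern: read from the cache, and on a miss fill the cache from the database. The standalone sketch below replaces Redis with a hypothetical in-memory TTLCache and the database with a callback, and uses a fake clock so that expiry can be observed. It assumes, as with the ttl option used above, that every write refreshes the key's expiry.

```python
import time
from typing import Callable, Dict, List, Tuple

Message = Tuple[str, str]  # (role, content)

class TTLCache:
    """Hypothetical in-memory stand-in for Redis: key -> (expires_at, messages)."""

    def __init__(self, ttl: float, clock: Callable[[], float] = time.monotonic):
        self.ttl = ttl
        self.clock = clock
        self.store: Dict[str, Tuple[float, List[Message]]] = {}

    def get(self, key: str) -> List[Message]:
        entry = self.store.get(key)
        if entry is None or entry[0] <= self.clock():
            self.store.pop(key, None)  # an expired key behaves like a missing one
            return []
        return list(entry[1])

    def append(self, key: str, msg: Message) -> None:
        msgs = self.get(key) + [msg]
        # Every write refreshes the expiry, like EXPIRE after each push
        self.store[key] = (self.clock() + self.ttl, msgs)

def load_history(cache: TTLCache, key: str,
                 load_from_db: Callable[[], List[Message]]) -> List[Message]:
    """Cache-aside read: serve from the cache, backfill from the database on a miss."""
    msgs = cache.get(key)
    if not msgs:
        for msg in load_from_db():
            cache.append(key, msg)
        msgs = cache.get(key)
    return msgs

now = [0.0]                                   # fake clock, in seconds
cache = TTLCache(ttl=600, clock=lambda: now[0])
db_calls: List[int] = []

def fake_db() -> List[Message]:
    db_calls.append(1)
    return [("user", "Who are you?"), ("assistant", "I'm Xiaosuo!")]

load_history(cache, "chat:1", fake_db)   # miss: reads the database
load_history(cache, "chat:1", fake_db)   # hit: no database read
now[0] += 601                            # the 600-second TTL has passed
load_history(cache, "chat:1", fake_db)   # miss again: reads the database
print(len(db_calls))                     # 2
```

One side effect applies to the real class as well: an empty history cannot be told apart from a cache miss, so a brand-new session queries the database on every request until its first messages are saved.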

3. Trim the chat history

# app/ai/chat.py

from typing import List, AsyncGenerator
from langchain_core.messages import BaseMessage, SystemMessage, trim_messages
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_openai import ChatOpenAI
from app.data.models import Role

class AIChatService:
    def __init__(self, role: Role, max_tokens: int = 20):
        # Create the ChatOpenAI client
        llm = ChatOpenAI(
            base_url=role.provider.endpoint,
            api_key=role.provider.api_key,
            model=role.provider.model,
            temperature=role.temperature,
            streaming=True
        )

        # Build the chat history trimmer.
        # With token_counter=len every message counts as one "token", so
        # max_tokens=20 keeps at most the 20 most recent history messages
        trimmer = trim_messages(
            token_counter=len,
            max_tokens=max_tokens,
            strategy="last",     # keep the newest messages
            start_on="human"     # make the kept window begin with a user message
        )

        # Build the prompt template
        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content=role.system_prompt or ""),  # used verbatim, no templating
            MessagesPlaceholder(variable_name="history"),
            ("human", "{input}"),
        ])

        # Initialize the output parser
        parser = StrOutputParser()

        # Compose the chain; the assign step replaces "history" with its trimmed version:
        # in:  x = {"history": [<a long list of messages>], "input": "How are you?"}
        # out:     {"history": [<a shorter, trimmed list>], "input": "How are you?"}
        self.chain = (
                RunnablePassthrough.assign(
                    history=RunnableLambda(lambda x: trimmer.invoke(x["history"]))
                )
                | prompt
                | llm
                | parser
        )

    async def chat_stream(self, history: List[BaseMessage], user_input: str) -> AsyncGenerator[str, None]:
        """
        Stream the reply, yielding each piece as soon as it is generated
        """
        # Stream through the whole chain (trim -> prompt -> LLM -> parser)
        async for chunk in self.chain.astream({"history": history, "input": user_input}):
            yield chunk

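Besides trimming the chat history, the code above changes the type of chat_stream's history parameter to LangChain's base class BaseMessage, because ChatHistoryManager from the previous step already returns LangChain BaseMessage objects.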
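With token_counter=len, trim_messages counts each message as one unit, so max_tokens=20 means "at most the last 20 history messages" rather than 20 tokens; the system prompt and the current input are not part of the trimmed history. The function below is a simplified plain-Python re-implementation of that idea for illustration only: it is not LangChain's code and ignores options such as allow_partial, end_on and include_system. trim_last is a hypothetical name.

```python
from typing import List, Sequence, Tuple

Message = Tuple[str, str]  # (type, content); type is "human" or "ai"

def trim_last(messages: Sequence[Message], max_messages: int,
              start_on: str = "human") -> List[Message]:
    """Keep the newest max_messages messages, then drop leading messages
    until the window starts with a `start_on` message.

    Simplified illustration of trim_messages(strategy="last",
    token_counter=len, start_on=...); not LangChain's implementation.
    """
    if max_messages <= 0:
        return []
    window = list(messages[-max_messages:])
    while window and window[0][0] != start_on:
        window.pop(0)
    return window

history: List[Message] = []
for i in range(12):                      # 12 question/answer turns = 24 messages
    history += [("human", f"question {i}"), ("ai", f"answer {i}")]

trimmed = trim_last(history, max_messages=5)
print(len(trimmed), trimmed[0])          # 4 ('human', 'question 10')
```

Because the start_on rule is applied after the cut, the kept window can be one message shorter than the budget, as in this example.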

4. Update the chat function in the message service

Replace the call that loads the message history with ChatHistoryManager:

# app/service/message.py
from app.ai.history import ChatHistoryManager

async def chat(self, chat_request: ChatRequest) -> AsyncGenerator[str, None]:
    """Handle a chat request"""
    # Make sure the session exists
    session = await self.db.get(Session, chat_request.session_id)
    if not session:
        raise ValueError(f"Session ID {chat_request.session_id} does not exist")

    # Load the role bound to the session, together with its model provider
    role = await self.db.scalar(
        select(Role)
        .options(selectinload(Role.provider))  # eager-load provider
        .where(Role.id == session.role_id)
    )
    if not role:
        # role is None here, so report the ID stored on the session instead of role.name
        raise ValueError(f"Role ID {session.role_id} does not exist")

    # Create the chat history manager and load the history (Redis first, then DB)
    history_mgr = ChatHistoryManager(self.db, chat_request.session_id)
    history = await history_mgr.load_history()

    # Call the AI service's streaming method
    ai = AIChatService(role)
    assistant_content_buffer = ""
    async for chunk in ai.chat_stream(history, chat_request.message):
        assistant_content_buffer += chunk
        payload = json.dumps({
            "content": chunk
        }, ensure_ascii=False)
        yield payload  # send each piece to the caller as soon as it is generated

    # Save both messages to the DB and Redis
    await history_mgr.save_message(MessageBase(
        role=MessageRole.User,
        content=chat_request.message,
        session_id=chat_request.session_id
    ))
    await history_mgr.save_message(MessageBase(
        role=MessageRole.Assistant,
        content=assistant_content_buffer,
        session_id=chat_request.session_id
    ))

5. Run a test

Send any message, then connect to Redis and check that the conversation has been cached.
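Assuming RedisChatMessageHistory from langchain_community with its default key_prefix ("message_store:"), the history of session 1 should live in a Redis list named message_store:chat:1, where each element is a JSON-serialized message and new messages are pushed to the head of the list (LPUSH). In redis-cli, LRANGE message_store:chat:1 0 -1 lists the cached entries and TTL message_store:chat:1 shows the remaining lifetime in seconds. The hypothetical helpers below build that key and decode LRANGE output into chronological order; the sample entries are hand-written and simplified (real entries carry more fields under "data"), not captured from a real run.

```python
import json
from typing import List, Tuple

def history_key(session_id: int, key_prefix: str = "message_store:") -> str:
    """Redis key for a session, assuming the chat:{id} session naming used above."""
    return f"{key_prefix}chat:{session_id}"

def decode_lrange(raw_items: List[str]) -> List[Tuple[str, str]]:
    """LRANGE returns newest-first because of LPUSH; reverse into chronological order."""
    decoded = []
    for item in reversed(raw_items):
        msg = json.loads(item)
        decoded.append((msg["type"], msg["data"]["content"]))
    return decoded

raw = [  # newest first, as LRANGE key 0 -1 would return it (sample data)
    json.dumps({"type": "ai", "data": {"content": "Hi, I'm Xiaosuo!"}}),
    json.dumps({"type": "human", "data": {"content": "Who are you?"}}),
]
print(history_key(1))       # message_store:chat:1
print(decode_lrange(raw))   # [('human', 'Who are you?'), ('ai', "Hi, I'm Xiaosuo!")]
```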

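IV. Knowledge Base Management

Knowledge base management is implemented much like the foundation services introduced in the first part, following the same steps:

1. Define the DTOs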

  • Knowledge base DTOs
# app/schemas/knowledge_base.py

from typing import Optional
from pydantic import BaseModel, Field

class KnowledgeBaseBase(BaseModel):
    """Base knowledge base DTO"""
    name: str = Field(..., max_length=100, description="Knowledge base name")
    description: Optional[str] = Field(None, description="Knowledge base description")

class KnowledgeBaseCreate(KnowledgeBaseBase):
    """DTO for creating a knowledge base"""
    pass

class KnowledgeBaseResponse(KnowledgeBaseBase):
    """Knowledge base response DTO"""
    id: int = Field(..., description="Knowledge base ID")
    document_count: int = Field(0, description="Number of documents")

    class Config:
        from_attributes = True

class KnowledgeBaseUpdate(BaseModel):
    """DTO for updating a knowledge base"""
    name: Optional[str] = Field(None, max_length=100, description="Knowledge base name")
    description: Optional[str] = Field(None, description="Knowledge base description")

class KnowledgeBaseListResponse(BaseModel):
    """Knowledge base list item DTO"""
    id: int = Field(..., description="Knowledge base ID")
    name: str = Field(..., description="Knowledge base name")
    description: Optional[str] = Field(None, description="Knowledge base description")
    document_count: int = Field(0, description="Number of documents")

    class Config:
        from_attributes = True
  • Role DTOs with knowledge base bindings
# app/schemas/role.py

from typing import Optional, List
from pydantic import BaseModel, Field

class RoleBase(BaseModel):
    """Base Role DTO"""
    name: str = Field(..., max_length=100, description="Role name")
    description: Optional[str] = Field(None, max_length=500, description="Role description")
    system_prompt: Optional[str] = Field(None, description="System prompt")
    temperature: float = Field(0.7, ge=0.0, le=2.0, description="Model temperature")
    provider_id: int = Field(..., description="ID of the associated model provider")

class RoleCreate(RoleBase):
    """DTO for creating a Role"""
    knowledge_base_ids: Optional[List[int]] = Field(default=[], description="IDs of the bound knowledge bases")

class RoleResponse(RoleBase):
    """Role response DTO"""
    id: int = Field(..., description="Role ID")

    class Config:
        from_attributes = True

class RoleUpdate(BaseModel):
    """DTO for updating a Role"""
    name: Optional[str] = Field(None, max_length=100, description="Role name")
    description: Optional[str] = Field(None, max_length=500, description="Role description")
    system_prompt: Optional[str] = Field(None, description="System prompt")
    temperature: Optional[float] = Field(None, ge=0.0, le=2.0, description="Model temperature")
    provider_id: Optional[int] = Field(None, description="ID of the associated model provider")
    # RoleService.update reads role_data.knowledge_base_ids:
    # None (field omitted) keeps the current bindings, [] removes all bindings
    knowledge_base_ids: Optional[List[int]] = Field(None, description="IDs of the knowledge bases to bind")

class RoleWithKnowledgeBases(RoleBase):
    """Role response DTO including knowledge base details"""
    id: int = Field(..., description="Role ID")
    knowledge_base_ids: Optional[List[int]] = Field(default=[], description="IDs of the bound knowledge bases")
    knowledge_bases: Optional[List[dict]] = Field(default=[], description="Details of the bound knowledge bases")

    class Config:
        from_attributes = True
  • Session DTOs, including knowledge base details
# app/schemas/session.py

from typing import Optional, List
from datetime import datetime
from pydantic import BaseModel, Field

class SessionBase(BaseModel):
    """Base Session DTO"""
    title: str = Field(..., max_length=100, description="Session title")
    role_id: int = Field(..., description="ID of the associated role")

class SessionCreate(SessionBase):
    """DTO for creating a Session"""
    knowledge_base_ids: Optional[List[int]] = Field(default=[], description="IDs of the bound knowledge bases")

class SessionResponse(SessionBase):
    """Session response DTO"""
    id: int = Field(..., description="Session ID")
    created_at: datetime = Field(..., description="Creation time")
    message_count: Optional[int] = Field(None, description="Number of messages")

    class Config:
        from_attributes = True

class SessionUpdate(BaseModel):
    """DTO for updating a Session"""
    title: Optional[str] = Field(None, max_length=100, description="Session title")
    role_id: Optional[int] = Field(None, description="ID of the associated role")
    # None (field omitted) keeps the current bindings, [] removes all bindings
    knowledge_base_ids: Optional[List[int]] = Field(None, description="IDs of the knowledge bases to bind")

class SessionListResponse(BaseModel):
    """Session list item DTO"""
    id: int = Field(..., description="Session ID")
    title: str = Field(..., description="Session title")
    created_at: datetime = Field(..., description="Creation time")
    role_name: str = Field(..., description="Role name")

    class Config:
        from_attributes = True

class SessionWithKnowledgeBases(SessionBase):
    """Session response DTO including knowledge base details"""
    id: int = Field(..., description="Session ID")
    created_at: datetime = Field(..., description="Creation time")
    message_count: Optional[int] = Field(None, description="Number of messages")
    knowledge_base_ids: Optional[List[int]] = Field(default=[], description="IDs of the bound knowledge bases")
    knowledge_bases: Optional[List[dict]] = Field(default=[], description="Details of the bound knowledge bases")

    class Config:
        from_attributes = True
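RoleService.update treats knowledge_base_ids differently from the other fields: it only touches the bindings when role_data.knowledge_base_ids is not None, so omitting the field keeps the current bindings while sending an empty list removes them all. The plain-dict sketch below mimics that split: the column part corresponds to model_dump(exclude_unset=True), and the relationship part is carried separately. plan_role_update is a hypothetical helper, not part of the project.

```python
from typing import Any, Dict, List, Optional

RELATIONSHIP_FIELDS = {"knowledge_base_ids"}

def plan_role_update(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Split a partial-update body into column updates and a binding update.

    Only keys present in the payload become column updates, which is what
    model_dump(exclude_unset=True) achieves for a Pydantic model.
    knowledge_base_ids is kept apart because it maps to a many-to-many
    relationship, not to a column of the roles table.
    """
    columns = {k: v for k, v in payload.items() if k not in RELATIONSHIP_FIELDS}
    # None means "not sent": leave the bindings untouched
    kb_ids: Optional[List[int]] = payload.get("knowledge_base_ids")
    return {"columns": columns, "knowledge_base_ids": kb_ids}

print(plan_role_update({"temperature": 0.5}))
# {'columns': {'temperature': 0.5}, 'knowledge_base_ids': None}  -> bindings kept
print(plan_role_update({"knowledge_base_ids": []}))
# {'columns': {}, 'knowledge_base_ids': []}  -> all bindings removed
```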

2. Implement the business logic

  • Knowledge base service
# app/service/knowledge_base.py

from typing import Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, delete, update, func
from app.data.models import KnowledgeBase, Document
from app.schemas.knowledge_base import KnowledgeBaseCreate, KnowledgeBaseResponse, KnowledgeBaseUpdate, KnowledgeBaseListResponse

class KnowledgeBaseService:
    def __init__(self, db: AsyncSession):
        self.db = db

    async def create(self, kb_create: KnowledgeBaseCreate) -> KnowledgeBaseResponse:
        """Create a knowledge base"""
        # Check whether the name is already taken
        existing_kb = await self.get_by_name(kb_create.name)
        if existing_kb:
            raise ValueError(f"Knowledge base name '{kb_create.name}' already exists")

        kb_data = kb_create.model_dump()
        kb = KnowledgeBase(**kb_data)
        self.db.add(kb)
        await self.db.commit()
        await self.db.refresh(kb)

        return KnowledgeBaseResponse(
            id=kb.id,
            name=kb.name,
            description=kb.description,
            document_count=0
        )

    async def get_by_id(self, kb_id: int) -> Optional[KnowledgeBaseResponse]:
        """Get a knowledge base by ID"""
        # First load the knowledge base itself
        result = await self.db.execute(
            select(KnowledgeBase).where(KnowledgeBase.id == kb_id)
        )
        kb = result.scalar_one_or_none()

        if kb:
            # Count documents with a separate query instead of eager-loading them
            doc_count = await self.get_document_count(kb_id)
            return KnowledgeBaseResponse(
                id=kb.id,
                name=kb.name,
                description=kb.description,
                document_count=doc_count
            )
        return None

    async def get_by_name(self, name: str) -> Optional[KnowledgeBaseResponse]:
        """Get a knowledge base by name"""
        # First load the knowledge base itself
        result = await self.db.execute(
            select(KnowledgeBase).where(KnowledgeBase.name == name)
        )
        kb = result.scalar_one_or_none()

        if kb:
            # Count documents with a separate query instead of eager-loading them
            doc_count = await self.get_document_count(kb.id)
            return KnowledgeBaseResponse(
                id=kb.id,
                name=kb.name,
                description=kb.description,
                document_count=doc_count
            )
        return None

    async def get_all(self) -> List[KnowledgeBaseListResponse]:
        """Get all knowledge bases"""
        # First load all knowledge bases
        result = await self.db.execute(
            select(KnowledgeBase).order_by(KnowledgeBase.id.desc())
        )
        kbs = result.scalars().all()

        # Fetch the document counts of all knowledge bases in one batch
        kb_ids = [kb.id for kb in kbs]
        doc_counts = {}

        if kb_ids:
            # A single GROUP BY query returns the counts for every knowledge base
            count_result = await self.db.execute(
                select(
                    Document.knowledge_base_id,
                    func.count(Document.id).label('doc_count')
                )
                .where(Document.knowledge_base_id.in_(kb_ids))
                .group_by(Document.knowledge_base_id)
            )

            # Build an id -> document count map
            for row in count_result.all():
                doc_counts[row.knowledge_base_id] = row.doc_count

        # Build the response list (knowledge bases without documents get 0)
        return [
            KnowledgeBaseListResponse(
                id=kb.id,
                name=kb.name,
                description=kb.description,
                document_count=doc_counts.get(kb.id, 0)
            )
            for kb in kbs
        ]

    async def update(self, kb_id: int, kb_data: KnowledgeBaseUpdate) -> Optional[KnowledgeBaseResponse]:
        """Update a knowledge base"""
        existing_kb = await self.get_by_id(kb_id)
        if not existing_kb:
            return None

        # If the name changes, make sure the new name is not already taken
        if kb_data.name and kb_data.name != existing_kb.name:
            name_exists = await self.get_by_name(kb_data.name)
            if name_exists:
                raise ValueError(f"Knowledge base name '{kb_data.name}' already exists")

        # Build the update data (only the fields the client actually sent)
        update_data = kb_data.model_dump(exclude_unset=True)

        if update_data:
            await self.db.execute(
                update(KnowledgeBase)
                .where(KnowledgeBase.id == kb_id)
                .values(**update_data)
            )
            await self.db.commit()

        return await self.get_by_id(kb_id)

    async def delete(self, kb_id: int) -> bool:
        """Delete a knowledge base (refused while it still contains documents)"""
        existing_kb = await self.get_by_id(kb_id)
        if not existing_kb:
            return False

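        # Check whether the knowledge base still contains documents
        doc_count = await self.get_document_count(kb_id)
        if doc_count > 0:
            raise ValueError(f"Knowledge base '{existing_kb.name}' still contains {doc_count} document(s); delete all documents before deleting the knowledge base")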

        await self.db.execute(
            delete(KnowledgeBase).where(KnowledgeBase.id == kb_id)
        )
        await self.db.commit()

        return True

    async def get_document_count(self, kb_id: int) -> int:
        """Get the number of documents in a knowledge base"""
        result = await self.db.execute(
            select(func.count(Document.id))
            .where(Document.knowledge_base_id == kb_id)
        )
        return result.scalar() or 0
  • Role service with knowledge base bindings
# app/service/role.py

from typing import Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, delete, update
from sqlalchemy.orm import selectinload
from app.data.models import Role, Provider, KnowledgeBase
from app.schemas.role import RoleCreate, RoleResponse, RoleUpdate, RoleWithKnowledgeBases

class RoleService:
    def __init__(self, db: AsyncSession):
        self.db = db

    async def create(self, role_create: RoleCreate) -> RoleResponse:
        """Create a Role"""
        # Make sure provider_id exists
        provider = await self.db.get(Provider, role_create.provider_id)
        if not provider:
            raise ValueError(f"Provider ID {role_create.provider_id} does not exist")

        # Make sure every knowledge base ID exists
        if role_create.knowledge_base_ids:
            for kb_id in role_create.knowledge_base_ids:
                kb = await self.db.get(KnowledgeBase, kb_id)
                if not kb:
                    raise ValueError(f"Knowledge base ID {kb_id} does not exist")

        role_data = role_create.model_dump(exclude={'knowledge_base_ids'})
        role = Role(**role_data)

        # If knowledge bases are bound, load them and attach them to the role
        if role_create.knowledge_base_ids:
            knowledge_bases = await self.db.execute(
                select(KnowledgeBase).where(KnowledgeBase.id.in_(role_create.knowledge_base_ids))
            )
            role.knowledge_bases = knowledge_bases.scalars().all()

        self.db.add(role)
        await self.db.commit()
        await self.db.refresh(role)

        return RoleResponse.model_validate(role)

    async def get_by_id(self, role_id: int) -> Optional[RoleResponse]:
        """Get a Role by ID"""
        role = await self.db.get(Role, role_id)

        if role:
            return RoleResponse.model_validate(role)
        return None

    async def get_by_name(self, name: str) -> Optional[RoleResponse]:
        """Get a Role by name"""
        result = await self.db.execute(
            select(Role).where(Role.name == name)
        )
        role = result.scalar_one_or_none()

        if role:
            return RoleResponse.model_validate(role)
        return None

    async def get_all(self) -> List[RoleResponse]:
        """Get all Roles"""
        result = await self.db.execute(
            select(Role)
            .order_by(Role.id.desc())
        )
        roles = result.scalars().all()

        role_responses = [RoleResponse.model_validate(r) for r in roles]

        return role_responses

    async def update(self, role_id: int, role_data: RoleUpdate) -> Optional[RoleResponse]:
        """Update a Role"""
        existing_role = await self.get_by_id(role_id)
        if not existing_role:
            return None

        if role_data.provider_id is not None:
            provider = await self.db.get(Provider, role_data.provider_id)
            if not provider:
                raise ValueError(f"Provider ID {role_data.provider_id} does not exist")

        # Make sure every knowledge base ID exists
        if role_data.knowledge_base_ids is not None:
            for kb_id in role_data.knowledge_base_ids:
                kb = await self.db.get(KnowledgeBase, kb_id)
                if not kb:
                    raise ValueError(f"Knowledge base ID {kb_id} does not exist")

        # Build the update data (the knowledge base bindings are handled separately below)
        update_data = role_data.model_dump(exclude_unset=True, exclude={'knowledge_base_ids'})

        if update_data:
            await self.db.execute(
                update(Role)
                .where(Role.id == role_id)
                .values(**update_data)
            )

        # Update the knowledge base bindings
        if role_data.knowledge_base_ids is not None:
            # Load the role together with its current bindings
            result = await self.db.execute(
                select(Role)
                .options(selectinload(Role.knowledge_bases))
                .where(Role.id == role_id)
            )
            role = result.scalar_one_or_none()

            if role:
                # Remove the existing bindings
                role.knowledge_bases.clear()

                # Add the new bindings
                if role_data.knowledge_base_ids:
                    knowledge_bases = await self.db.execute(
                        select(KnowledgeBase).where(KnowledgeBase.id.in_(role_data.knowledge_base_ids))
                    )
                    role.knowledge_bases = knowledge_bases.scalars().all()

        await self.db.commit()
        return await self.get_by_id(role_id)

    async def delete(self, role_id: int) -> bool:
        """Delete a Role"""
        existing_role = await self.get_by_id(role_id)
        if not existing_role:
            return False

        await self.db.execute(
            delete(Role).where(Role.id == role_id)
        )
        await self.db.commit()

        return True

    async def get_by_id_with_knowledge_bases(self, role_id: int) -> Optional[RoleWithKnowledgeBases]:
        """Get a Role by ID, including its knowledge bases"""
        result = await self.db.execute(
            select(Role)
            .options(selectinload(Role.knowledge_bases))
            .where(Role.id == role_id)
        )
        role = result.scalar_one_or_none()

        if role:
            # Assemble the response in one step, adding the bound knowledge base IDs and details
            return RoleWithKnowledgeBases(
                **RoleResponse.model_validate(role).model_dump(),
                knowledge_base_ids=[kb.id for kb in role.knowledge_bases],
                knowledge_bases=[
                    {"id": kb.id, "name": kb.name, "description": kb.description}
                    for kb in role.knowledge_bases
                ],
            )
        return None
  • Session service with knowledge base bindings
# app/service/session.py

from typing import List, Optional
from sqlalchemy import select, delete
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.data.models import Session, Role, Provider, Message, KnowledgeBase
from app.schemas.session import SessionCreate, SessionResponse, SessionListResponse, SessionWithKnowledgeBases

class SessionService:
    def __init__(self, db: AsyncSession):
        self.db = db

    async def create(self, session_create: SessionCreate) -> SessionResponse:
        """Create a Session"""
        role = await self.db.get(Role, session_create.role_id)
        if not role:
            raise ValueError(f"Role ID {session_create.role_id} does not exist")

        # Make sure every knowledge base ID exists
        if session_create.knowledge_base_ids:
            for kb_id in session_create.knowledge_base_ids:
                kb = await self.db.get(KnowledgeBase, kb_id)
                if not kb:
                    raise ValueError(f"Knowledge base ID {kb_id} does not exist")

        session_data = session_create.model_dump(exclude={'knowledge_base_ids'})
        session = Session(**session_data)

        # If knowledge bases are bound, load them and attach them to the session
        if session_create.knowledge_base_ids:
            knowledge_bases = await self.db.execute(
                select(KnowledgeBase).where(KnowledgeBase.id.in_(session_create.knowledge_base_ids))
            )
            session.knowledge_bases = knowledge_bases.scalars().all()

        self.db.add(session)
        await self.db.commit()
        await self.db.refresh(session)

        return SessionResponse.model_validate(session)

    async def get_all(self) -> List[SessionListResponse]:
        """Get all Sessions, including the role name"""
        result = await self.db.execute(
            select(
                Session.id,
                Session.title,
                Session.created_at,
                Role.name.label('role_name')
            )
            .join(Role, Session.role_id == Role.id)
            .order_by(Session.created_at.desc())
        )

        sessions = result.all()

        return [
            SessionListResponse(
                id=session.id,
                title=session.title,
                created_at=session.created_at,
                role_name=session.role_name
            )
            for session in sessions
        ]

    async def delete(self, session_id: int) -> bool:
        """Delete a Session"""
        await self.db.execute(
            delete(Session).where(Session.id == session_id)
        )
        await self.db.commit()

        return True

    async def get_by_id_with_knowledge_bases(self, session_id: int) -> Optional[SessionWithKnowledgeBases]:
        """Get a Session by ID, including its knowledge bases"""
        result = await self.db.execute(
            select(Session)
            .options(selectinload(Session.knowledge_bases))
            .where(Session.id == session_id)
        )
        session = result.scalar_one_or_none()

        if session:
            # Assemble the response in one step, adding the bound knowledge base IDs and details
            return SessionWithKnowledgeBases(
                **SessionResponse.model_validate(session).model_dump(),
                knowledge_base_ids=[kb.id for kb in session.knowledge_bases],
                knowledge_bases=[
                    {"id": kb.id, "name": kb.name, "description": kb.description}
                    for kb in session.knowledge_bases
                ],
            )
        return None
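
For reference, listing sessions with their role name only needs one inner join between sessions and roles; because nothing is aggregated, no GROUP BY is involved. The sketch below runs that query shape against an in-memory SQLite database using only the standard library. The table and column names (`sessions`, `roles`, `role_id`, and so on) are assumptions meant to mirror the Session and Role models, not the project's actual schema.

```python
import sqlite3


def build_demo_db() -> sqlite3.Connection:
    """Create an in-memory DB with a simplified, assumed sessions/roles schema."""
    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE roles (id INTEGER PRIMARY KEY, name TEXT NOT NULL);
        CREATE TABLE sessions (
            id INTEGER PRIMARY KEY,
            title TEXT NOT NULL,
            created_at TEXT NOT NULL,
            role_id INTEGER NOT NULL REFERENCES roles(id)
        );
        INSERT INTO roles VALUES (1, 'Translator'), (2, 'Coder');
        INSERT INTO sessions VALUES
            (1, 'first chat', '2025-07-01 10:00:00', 1),
            (2, 'second chat', '2025-07-02 09:30:00', 2),
            (3, 'third chat', '2025-07-03 08:15:00', 1);
        """
    )
    return conn


def list_sessions_with_role(conn: sqlite3.Connection) -> list:
    """Return (id, title, created_at, role_name) rows, newest session first."""
    # One inner join connects every table in the query; since nothing is
    # aggregated, no GROUP BY is needed and each session appears exactly once.
    return conn.execute(
        """
        SELECT s.id, s.title, s.created_at, r.name AS role_name
        FROM sessions AS s
        JOIN roles AS r ON s.role_id = r.id
        ORDER BY s.created_at DESC
        """
    ).fetchall()


if __name__ == "__main__":
    for row in list_sessions_with_role(build_demo_db()):
        print(row)
```

Every table a query mentions must be tied in through a join condition; a table referenced without one either makes the SQL invalid or turns the result into a cross product.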

3. Implement the APIs

  • Knowledge base API
# app/api/knowledge_base.py

from typing import List
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from app.data.db import get_db
from app.schemas.knowledge_base import (
    KnowledgeBaseResponse,
    KnowledgeBaseCreate,
    KnowledgeBaseUpdate,
    KnowledgeBaseListResponse
)
from app.service.knowledge_base import KnowledgeBaseService

router = APIRouter(prefix="/knowledge-bases", tags=["Knowledge bases & documents"])

@router.post("/", response_model=KnowledgeBaseResponse, status_code=201)
async def create_knowledge_base(
    kb_data: KnowledgeBaseCreate,
    db: AsyncSession = Depends(get_db)
):
    """Create a new knowledge base"""
    service = KnowledgeBaseService(db)

    try:
        kb = await service.create(kb_data)
        return kb
    except ValueError as e:
        raise HTTPException(
            status_code=400,
            detail=str(e)
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to create knowledge base: {str(e)}"
        )

@router.get("/", response_model=List[KnowledgeBaseListResponse])
async def get_knowledge_bases(db: AsyncSession = Depends(get_db)):
    """Get all knowledge bases"""
    service = KnowledgeBaseService(db)

    try:
        kbs = await service.get_all()
        return kbs
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to list knowledge bases: {str(e)}"
        )

@router.get("/{kb_id}", response_model=KnowledgeBaseResponse)
async def get_knowledge_base(
    kb_id: int,
    db: AsyncSession = Depends(get_db)
):
    """Get a knowledge base by ID"""
    service = KnowledgeBaseService(db)

    kb = await service.get_by_id(kb_id)
    if not kb:
        raise HTTPException(
            status_code=404,
            detail=f"Knowledge base ID {kb_id} does not exist"
        )

    return kb

@router.put("/{kb_id}", response_model=KnowledgeBaseResponse)
async def update_knowledge_base(
    kb_id: int,
    kb_data: KnowledgeBaseUpdate,
    db: AsyncSession = Depends(get_db)
):
    """Update a knowledge base"""
    service = KnowledgeBaseService(db)

    try:
        kb = await service.update(kb_id, kb_data)
        if not kb:
            raise HTTPException(
                status_code=404,
                detail=f"Knowledge base ID {kb_id} does not exist"
            )
        return kb
    except ValueError as e:
        raise HTTPException(
            status_code=400,
            detail=str(e)
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to update knowledge base: {str(e)}"
        )

@router.delete("/{kb_id}", status_code=204)
async def delete_knowledge_base(
    kb_id: int,
    db: AsyncSession = Depends(get_db)
):
    """Delete a knowledge base"""
    service = KnowledgeBaseService(db)

    try:
        success = await service.delete(kb_id)
        if not success:
            raise HTTPException(
                status_code=404,
                detail=f"Knowledge base ID {kb_id} does not exist"
            )
    except ValueError as e:
        raise HTTPException(
            status_code=400,
            detail=str(e)
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to delete knowledge base: {str(e)}"
        )
  • Role API: add an endpoint that returns a role together with its knowledge bases
# app/api/role.py

@router.get("/{role_id}/knowledge-bases", response_model=RoleWithKnowledgeBases)
async def get_role_with_knowledge_bases(
        role_id: int,
        db: AsyncSession = Depends(get_db)
):
    """Get a role by ID, including its knowledge bases"""
    service = RoleService(db)

    role = await service.get_by_id_with_knowledge_bases(role_id)
    if not role:
        raise HTTPException(
            status_code=404,
            detail=f"Role ID {role_id} does not exist"
        )

    return role
  • Session API: add an endpoint that returns a session together with its knowledge bases
# app/api/session.py

@router.get("/{session_id}/knowledge-bases", response_model=SessionWithKnowledgeBases)
async def get_session_with_knowledge_bases(
        session_id: int,
        db: AsyncSession = Depends(get_db)
):
    """Get a session by ID, including its knowledge bases"""
    service = SessionService(db)

    session = await service.get_by_id_with_knowledge_bases(session_id)
    if not session:
        raise HTTPException(
            status_code=404,
            detail=f"Session ID {session_id} does not exist"
        )

    return session

4. Register the router

# app/main.py

from app.api import knowledge_base

# Register the KnowledgeBase router
app.include_router(
    knowledge_base.router,
    prefix="/api"
)

V. Uploading Knowledge Base Documents

1. Define the DTOs

# app/schemas/document.py

from typing import Optional
from datetime import datetime
from pydantic import BaseModel, Field

class DocumentBase(BaseModel):
    """Base DTO for a document"""
    title: str = Field(..., max_length=200, description="Document title")
    source: Optional[str] = Field(None, max_length=255, description="Document source or file path")

class DocumentResponse(DocumentBase):
    """Response DTO for a single document"""
    id: int = Field(..., description="Document ID")
    knowledge_base_id: int = Field(..., description="ID of the owning knowledge base")
    created_at: datetime = Field(..., description="Creation time")

    class Config:
        from_attributes = True

class DocumentListResponse(BaseModel):
    """Response DTO for one item in a document list"""
    id: int = Field(..., description="Document ID")
    title: str = Field(..., description="Document title")
    source: Optional[str] = Field(None, description="Document source or file path")
    knowledge_base_id: int = Field(..., description="ID of the owning knowledge base")
    created_at: datetime = Field(..., description="Creation time")

    class Config:
        from_attributes = True

class DocumentUploadResponse(BaseModel):
    """Response DTO returned after a document upload"""
    id: int = Field(..., description="Document ID")
    title: str = Field(..., description="Document title")
    source: str = Field(..., description="Path of the stored document file")
    knowledge_base_id: int = Field(..., description="ID of the owning knowledge base")
    created_at: datetime = Field(..., description="Creation time")
    file_size: int = Field(..., description="File size in bytes")
    file_type: str = Field(..., description="File type (extension)")

    class Config:
        from_attributes = True

2. Implement the service methods

  • Install python-multipart (FastAPI needs it to parse the multipart/form-data requests used for file uploads)
uv add python-multipart
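
python-multipart is the server-side parser for multipart/form-data bodies. To see what it receives, or to test the upload endpoint without a browser or extra client library, a single-file body can be built by hand with the standard library. This is a minimal sketch: the helper name `build_multipart_body` and the server address `127.0.0.1:8000` are assumptions, while the form field name `file` matches the `file: UploadFile` parameter of the upload endpoint and the path follows the "/api" and "/knowledge-bases" prefixes used in this project.

```python
import urllib.request
import uuid
from typing import Tuple


def build_multipart_body(field_name: str, filename: str, content: bytes,
                         content_type: str = "application/octet-stream") -> Tuple[bytes, str]:
    """Build a multipart/form-data body carrying one file; returns (body, Content-Type value)."""
    boundary = uuid.uuid4().hex  # random delimiter, very unlikely to occur inside the file
    crlf = b"\r\n"
    body = b"".join([
        b"--" + boundary.encode() + crlf,
        f'Content-Disposition: form-data; name="{field_name}"; filename="{filename}"'.encode() + crlf,
        f"Content-Type: {content_type}".encode() + crlf,
        crlf,  # blank line separates the part headers from the part content
        content + crlf,
        b"--" + boundary.encode() + b"--" + crlf,  # closing delimiter
    ])
    return body, f"multipart/form-data; boundary={boundary}"


if __name__ == "__main__":
    body, content_type = build_multipart_body("file", "notes.md", b"# Hello", "text/markdown")
    # Assumed local dev server; knowledge base ID 1 is just an example.
    request = urllib.request.Request(
        "http://127.0.0.1:8000/api/knowledge-bases/1/documents",
        data=body,
        headers={"Content-Type": content_type},
        method="POST",
    )
    # urllib.request.urlopen(request)  # uncomment with the API server running
    print(request.get_method(), request.full_url, len(body), "bytes")
```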
# app/service/document.py

import uuid
from pathlib import Path
from typing import Optional, List
from sqlalchemy import select, delete, exists
from sqlalchemy.ext.asyncio import AsyncSession
from app.data.models import Document, KnowledgeBase
from app.schemas.document import DocumentResponse, DocumentListResponse, DocumentUploadResponse

class DocumentService:
    def __init__(self, db: AsyncSession):
        self.db = db

    async def _knowledge_base_exists(self, kb_id: int) -> bool:
        """Check whether a knowledge base exists"""
        result = await self.db.execute(
            select(exists().where(KnowledgeBase.id == kb_id))
        )
        return result.scalar()

    async def upload_document(self,
            file_content: bytes,
            filename: str,
            knowledge_base_id: int,
            upload_dir: str = "uploads/documents"
    ) -> DocumentUploadResponse:
        """Upload a document file"""
        # Make sure the knowledge base exists
        if not await self._knowledge_base_exists(knowledge_base_id):
            raise ValueError(f"Knowledge base ID {knowledge_base_id} does not exist")

        # Reject a document whose title already exists in the same knowledge base
        title_exists_result = await self.db.execute(
            select(exists().where(
                Document.title == filename,
                Document.knowledge_base_id == knowledge_base_id
            ))
        )
        if title_exists_result.scalar():
            raise ValueError(f"The knowledge base already contains a document titled '{filename}'")

        # Create the upload directory
        upload_path = Path(upload_dir)
        upload_path.mkdir(parents=True, exist_ok=True)

        # Store the file under a UUID-based name to avoid name collisions
        file_extension = Path(filename).suffix
        unique_filename = f"{uuid.uuid4()}{file_extension}"
        file_path = upload_path / unique_filename

        # Save the file
        try:
            with open(file_path, "wb") as f:
                f.write(file_content)
        except Exception as e:
            raise ValueError(f"Failed to save file: {str(e)}")

        # Collect file metadata
        file_size = len(file_content)
        file_type = file_extension.lstrip('.').lower() if file_extension else 'unknown'

        # Create the document record
        doc = Document(
            title=filename,
            source=str(file_path),
            knowledge_base_id=knowledge_base_id
        )
        self.db.add(doc)
        await self.db.commit()
        await self.db.refresh(doc)

        return DocumentUploadResponse(
            id=doc.id,
            title=doc.title,
            source=doc.source,
            knowledge_base_id=doc.knowledge_base_id,
            created_at=doc.created_at,
            file_size=file_size,
            file_type=file_type
        )

    async def get_by_id(self, doc_id: int) -> Optional[DocumentResponse]:
        """Get a document by ID"""
        result = await self.db.execute(
            select(Document).where(Document.id == doc_id)
        )
        doc = result.scalar_one_or_none()

        if doc:
            return DocumentResponse(
                id=doc.id,
                title=doc.title,
                source=doc.source,
                knowledge_base_id=doc.knowledge_base_id,
                created_at=doc.created_at
            )
        return None

    async def get_by_knowledge_base(self, knowledge_base_id: int) -> List[DocumentListResponse]:
        """Get the documents of a knowledge base"""
        # Make sure the knowledge base exists first
        if not await self._knowledge_base_exists(knowledge_base_id):
            raise ValueError(f"Knowledge base ID {knowledge_base_id} does not exist")

        result = await self.db.execute(
            select(Document)
            .where(Document.knowledge_base_id == knowledge_base_id)
            .order_by(Document.created_at.desc())
        )
        docs = result.scalars().all()

        return [
            DocumentListResponse(
                id=doc.id,
                title=doc.title,
                source=doc.source,
                knowledge_base_id=doc.knowledge_base_id,
                created_at=doc.created_at
            )
            for doc in docs
        ]

    async def delete(self, doc_id: int) -> bool:
        """Delete a document (both the database record and the stored file)"""
        # Load the document first to get its file path
        result = await self.db.execute(
            select(Document).where(Document.id == doc_id)
        )
        doc = result.scalar_one_or_none()

        if not doc:
            return False

        # Delete the file if it exists
        if doc.source:
            try:
                file_path = Path(doc.source)
                if file_path.exists():
                    file_path.unlink()  # remove the file
            except Exception as e:
                # A failed file deletion should not block deleting the database record
                print(f"Warning: failed to delete file {doc.source}: {str(e)}")

        # Delete the database record
        delete_result = await self.db.execute(
            delete(Document).where(Document.id == doc_id)
        )
        await self.db.commit()

        # rowcount is an attribute, not a method: the number of rows matched by the DELETE
        return delete_result.rowcount > 0

    async def delete_by_knowledge_base(self, knowledge_base_id: int) -> int:
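        """Delete all documents of a knowledge base (database records and files); returns the number deleted"""
        # Make sure the knowledge base exists first
        if not await self._knowledge_base_exists(knowledge_base_id):
            raise ValueError(f"Knowledge base ID {knowledge_base_id} does not exist")

        # Load the documents to delete
        result = await self.db.execute(
            select(Document).where(Document.knowledge_base_id == knowledge_base_id)
        )
        docs = result.scalars().all()

        # Delete the files; as in delete(), a failed file deletion only logs a warning
        # (raising ValueError here would be reported by the API as "knowledge base not found")
        for doc in docs:
            if doc.source:
                try:
                    file_path = Path(doc.source)
                    if file_path.exists():
                        file_path.unlink()  # remove the file
                except Exception as e:
                    print(f"Warning: failed to delete file {doc.source}: {str(e)}")

        # Delete the database records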
        await self.db.execute(
            delete(Document).where(Document.knowledge_base_id == knowledge_base_id)
        )
        await self.db.commit()

        return len(docs)
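
upload_document writes the file to disk before inserting the Document row, so if the commit fails the file stays behind with no record pointing to it. One way to guard against that is to remove the file when the commit raises. The sketch below shows the pattern with the standard library only; `store_upload` and `derive_file_type` are hypothetical helper names, and `commit` is a plain callable standing in for the async `db.add(...)` / `await db.commit()` calls, not the project's actual code.

```python
import tempfile
import uuid
from pathlib import Path
from typing import Callable


def store_upload(file_content: bytes, filename: str, upload_dir: Path,
                 commit: Callable[[Path], None]) -> Path:
    """Save bytes under a UUID-based name, then call `commit`; remove the file if commit fails."""
    upload_dir.mkdir(parents=True, exist_ok=True)
    # Only the original extension survives in the stored name; the original
    # filename itself is kept in the database (the Document.title column).
    file_path = upload_dir / f"{uuid.uuid4()}{Path(filename).suffix}"
    file_path.write_bytes(file_content)
    try:
        commit(file_path)  # hypothetical stand-in for db.add(doc) + await db.commit()
    except Exception:
        # Undo the write so a failed commit does not leave an orphaned file behind.
        file_path.unlink(missing_ok=True)
        raise
    return file_path


def derive_file_type(filename: str) -> str:
    """Same rule as upload_document: lowercase extension without the dot, or 'unknown'."""
    suffix = Path(filename).suffix
    return suffix.lstrip(".").lower() if suffix else "unknown"


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        saved = store_upload(b"hello", "Report.PDF", Path(tmp) / "documents", lambda path: None)
        print(saved.name, derive_file_type("Report.PDF"))
```

In the async service the same idea would wrap `await self.db.commit()` in a try/except that unlinks `file_path` before re-raising.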

3. Implement the APIs

The document endpoints are implemented directly in the knowledge base API module and share its router:

# app/api/knowledge_base.py

from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
from app.schemas.document import (
    DocumentResponse,
    DocumentListResponse,
    DocumentUploadResponse
)
from app.service.document import DocumentService

# ==================== Document management routes ====================

@router.get("/{kb_id}/documents", response_model=List[DocumentListResponse])
async def get_knowledge_base_documents(
    kb_id: int,
    db: AsyncSession = Depends(get_db)
):
    """Get all documents of a knowledge base"""
    doc_service = DocumentService(db)
    try:
        docs = await doc_service.get_by_knowledge_base(kb_id)
        return docs
    except ValueError as e:
        raise HTTPException(
            status_code=404,
            detail=str(e)
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to list knowledge base documents: {str(e)}"
        )

@router.post("/{kb_id}/documents", response_model=DocumentUploadResponse, status_code=201)
async def upload_document_to_knowledge_base(
    kb_id: int,
    file: UploadFile = File(..., description="Document file to upload"),
    db: AsyncSession = Depends(get_db)
):
    """Upload a document to a knowledge base"""
    # Make sure the knowledge base exists first
    kb_service = KnowledgeBaseService(db)
    kb = await kb_service.get_by_id(kb_id)
    if not kb:
        raise HTTPException(
            status_code=404,
            detail=f"Knowledge base ID {kb_id} does not exist"
        )

    doc_service = DocumentService(db)

    # Validate the file
    if not file.filename:
        raise HTTPException(
            status_code=400,
            detail="Filename must not be empty"
        )

    # Check the file size (limit: 10MB)
    file_content = await file.read()
    if len(file_content) > 10 * 1024 * 1024:  # 10MB
        raise HTTPException(
            status_code=400,
            detail="File size must not exceed 10MB"
        )

    # Check the file type (common document formats only); every entry needs its leading dot
    allowed_extensions = {'.txt', '.pdf', '.doc', '.docx', '.md', '.html'}
    file_extension = '.' + file.filename.split('.')[-1].lower() if '.' in file.filename else ''
    if file_extension not in allowed_extensions:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file type. Supported types: {', '.join(sorted(allowed_extensions))}"
        )

    try:
        result = await doc_service.upload_document(
            file_content=file_content,
            filename=file.filename,
            knowledge_base_id=kb_id
        )
        return result
    except ValueError as e:
        raise HTTPException(
            status_code=400,
            detail=str(e)
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Document upload failed: {str(e)}"
        )

@router.get("/documents/{doc_id}", response_model=DocumentResponse)
async def get_document_from_knowledge_base(
    doc_id: int,
    db: AsyncSession = Depends(get_db)
):
    """Get a specific document"""
    doc_service = DocumentService(db)
    doc = await doc_service.get_by_id(doc_id)
    if not doc:
        raise HTTPException(
            status_code=404,
            detail=f"Document ID {doc_id} does not exist"
        )

    return doc

@router.delete("/documents/{doc_id}", status_code=204)
async def delete_document_from_knowledge_base(
    doc_id: int,
    db: AsyncSession = Depends(get_db)
):
    """Delete a specific document"""
    doc_service = DocumentService(db)
    doc = await doc_service.get_by_id(doc_id)
    if not doc:
        raise HTTPException(
            status_code=404,
            detail=f"Document ID {doc_id} does not exist"
        )

    try:
        success = await doc_service.delete(doc_id)
        if not success:
            raise HTTPException(
                status_code=404,
                detail=f"Document ID {doc_id} does not exist"
            )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to delete document: {str(e)}"
        )

@router.delete("/{kb_id}/documents", status_code=200)
async def delete_all_documents_from_knowledge_base(
    kb_id: int,
    db: AsyncSession = Depends(get_db)
):
    """Delete all documents of a knowledge base"""
    doc_service = DocumentService(db)
    try:
        deleted_count = await doc_service.delete_by_knowledge_base(kb_id)
        return {
            "knowledge_base_id": kb_id,
            "deleted_count": deleted_count,
            "message": f"Deleted {deleted_count} document(s)"
        }
    except ValueError as e:
        raise HTTPException(
            status_code=404,
            detail=str(e)
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to delete knowledge base documents: {str(e)}"
        )

@router.get("/{kb_id}/documents/count")
async def get_knowledge_base_document_count(
    kb_id: int,
    db: AsyncSession = Depends(get_db)
):
    """Get the number of documents in a knowledge base"""
    service = KnowledgeBaseService(db)

    try:
        count = await service.get_document_count(kb_id)
        return {"knowledge_base_id": kb_id, "document_count": count}
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get document count: {str(e)}"
        )
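
The checks performed by upload_document_to_knowledge_base (non-empty filename, 10MB cap, extension allow-list) can also be pulled out into a pure function, which makes them easy to unit-test without starting FastAPI. This is a sketch under that assumption; `validate_upload`, `MAX_UPLOAD_BYTES` and `ALLOWED_EXTENSIONS` are hypothetical names, and it uses `Path.suffix` in place of the manual `split('.')` (both yield the last extension, e.g. `.gz` for `archive.tar.gz`).

```python
from pathlib import Path
from typing import Optional

MAX_UPLOAD_BYTES = 10 * 1024 * 1024  # the 10MB limit used by the endpoint
ALLOWED_EXTENSIONS = frozenset({".txt", ".pdf", ".doc", ".docx", ".md", ".html"})


def validate_upload(filename: Optional[str], size: int) -> Optional[str]:
    """Return an error message for an invalid upload, or None if it passes every check."""
    if not filename:
        return "Filename must not be empty"
    if size > MAX_UPLOAD_BYTES:
        return "File size must not exceed 10MB"
    # Path.suffix is '' for names without an extension, so those fail the check below.
    extension = Path(filename).suffix.lower()
    if extension not in ALLOWED_EXTENSIONS:
        # sorted() keeps the message stable; iterating a set has no fixed order.
        return "Unsupported file type. Supported types: " + ", ".join(sorted(ALLOWED_EXTENSIONS))
    return None


if __name__ == "__main__":
    for name, size in [("notes.md", 120), ("setup.exe", 120), ("big.pdf", MAX_UPLOAD_BYTES + 1)]:
        print(name, validate_upload(name, size))
```

Inside the endpoint, `error = validate_upload(file.filename, len(file_content))` followed by `raise HTTPException(status_code=400, detail=error)` whenever `error` is not None would replace the three inline checks.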