通用AI聊天平台实现1:基础服务与API
2025-07-12 14:39:31一、项目介绍
1. 项目目标
Python + 面向单一用户 + 私有化部署 + 通用AI聊天平台
- 通过Web界面轻松交互
- 管理多个自定义AI角色
- 拥有会话管理能力
- 自由接入并切换语言模型(兼容OpenAl接口规范)
- 构建私有化知识库
- 扩展外部工具:集成MCP服务器
2. 功能性需求
- 多角色管理功能
- 对话管理功能
- 模型管理功能
- RAG功能
- 前端界面(会话、知识库、模型设置、角色设置)
- MCP 服务器接入
3. 技术选型
- 包与项目管理工具:NuGet+ dotnet cli == uv
- 后端API框架与服务器:ASP.NET Core Web APl(Kestrel) == FastAPl + Uvicorn
- 数据模型与验证:数据注解/FluentValidation == Pydantic
- ORM + 数据迁移:EF Core == SQLAlchemy + Alembic
- 前端UI框架:Blazor Server == Gradio
- 元数据库(关系):PgSQL
- 向量数据库:Chroma
- 嵌入生成服务:云端向量服务 + 本地开源向量模型
- AI组件:SK == LangChain
4. 项目架构
该项目采用分层架构(因为是第一个实例项目,没有采用领域驱动设计)
二、项目初始化
- 安装UV
# https://docs.astral.sh/uv/getting-started/installation/#__tabbed_1_2
pip install uv
# uv的依赖环境恢复可以使用 uv sync
- 创建UV项目

- 项目初始化
创建app和ui文件夹,按下面目录结构分别创建Python软件包和文件:
app
|--ai
|--api
|--data
|--|--models
|--schemas
|--service
|--.env
|--config.py
|--main.py
ui
- 安装 fastapi
uv add fastapi
- 安装 uvicorn
uv add uvicorn
- 创建HelloWorld API
# app/api/hello.py
# 从 fastapi 库中导入 APIRouter 类,用于创建路由模块
from fastapi import APIRouter
# 创建一个 APIRouter 的实例
# 这就像在 ASP.NET Core 中定义一个新的 Controller 类
router = APIRouter()
# 使用装饰器来定义一个 API 路由
# @router.post(...) 表示这是一个处理 HTTP POST 请求的端点
@router.get("/hello")
# 定义一个异步函数来处理这个请求
# 函数名 hello 是我们自己取的,它会显示在自动生成的 API 文档中
async def hello():
return "HelloWorld"
- 配置api端点和入库函数
# app/main.py
# 从 fastapi 库导入 FastAPI 类,这是创建应用的核心
from fastapi import FastAPI
# 从我们创建的 api 子目录中导入 hello 模块(也就是 hello.py 文件)
from api import hello
# 创建一个 FastAPI 应用的实例
# 类似 Program.cs 中调用 WebApplication.CreateBuilder() 和 builder.Build()
app = FastAPI(
# title 是一个可选参数,它会显示在自动生成的 API 文档的标题位置
title="AI 聊天平台",
# version 是版本号,同样会显示在文档中
version="0.1.0",
)
# 使用 app.include_router() 方法来将我们创建的角色路由包含到主应用中
# 这类似于在 ASP.NET Core 中调用 app.MapControllers()
app.include_router(
# 第一个参数是我们想要包含的路由对象
hello.router,
# prefix 参数为这个路由下的所有端点 URL 添加一个统一的前缀
# 这样,我们 hello.py 中的 "/hello" 就会变成 "/api/hello"
prefix="/api",
# tags 参数用于在 API 文档中对端点进行分组
# 所有来自 hello.router 的端点都会被归类到 "测试" 这个标签下
tags=["测试"],
)
# 定义一个根路径的 GET 请求处理函数
# 这通常用于做一个简单的健康检查,确认服务是否正常运行
@app.get("/")
async def read_root():
# 返回一个简单的 JSON 对象
return {"message": "欢迎使用通用 AI 聊天平台 API"}
# 判断当前模块是否为主程序入口模块
if __name__ == "__main__":
# 导入 uvicorn 模块
import uvicorn
# 使用 uvicorn 运行当前应用,指定应用路径、主机地址、端口和热重载功能
uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)
- 运行

三、构建 SQLAlchemy 模型
1. 数据模型
AI角色:AI 配置、名称、提示词、温度、模型
模型提供商:调用端点、名称、APIKey
会话:关联角色、知识库
消息:用户与 AI 的消息记录,关联会话
知识库:技术文档库、员工手册库
文档:文档基本信息(元数据),标题、来源、所属知识库
文档片段:文本分片的源信息
2. 环境配置
- 安装包
# 安装 sqlalchemy
uv add sqlalchemy psycopg2-binary asyncpy
# 安装 dotenv
uv add python-dotenv
SQLAlchemy介绍
- 引擎:与数据库建立连接、发出SQL命令
- 会话:负责跟踪对象的状态、管理事务
- 基类:所有模型类都要继承基类
- 模型类:每个模型类对应数据库中的一张表
- 关系:模型与模型之间的关系
- 模式基类
# app/data/models/base.py
from sqlalchemy.orm import DeclarativeBase
class Base(DeclarativeBase):
"""
SQLAlchemy 声明式基类
这个基类是所有 SQLAlchemy 数据模型的父类。它继承自 SQLAlchemy 的 DeclarativeBase,
为应用程序提供了一个统一的基类来定义数据库表结构。
使用声明式风格可以让我们通过 Python 类来定义数据库表,SQLAlchemy 会自动处理类与数据库表之间的映射关系。
"""
pass
- 配置数据库链接
# docker创建pgsql容器
docker run -d --name postgres -e POSTGRES_PASSWORD=123456 -e POSTGRES_DB=aichat-demo -p 5432:5432 -v pgdata:/var/lib/postgresql/data --restart unless-stopped postgres:latest
# app/.env
DATABASE_URL=postgresql+asyncpg://postgres:123456@localhost:5432/aichat-demo
# app/config.py
import os
from dotenv import load_dotenv
load_dotenv()
DATABASE_URL = os.getenv("DATABASE_URL")
- 数据库上下文对象
# app/data/db.py
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from app.config import DATABASE_URL
# 创建异步 Engine
engine = create_async_engine(DATABASE_URL, echo=True)
# 异步 Session 工厂
SessionLocal = async_sessionmaker(
engine,
expire_on_commit=False
)
# FastAPI 依赖注入工具函数
async def get_db():
async with SessionLocal() as session:
yield session
3. 数据模型代码实现
- 提供商
# app/data/models/provider.py
# -------------------------
# 模型提供商
# -------------------------
from sqlalchemy import Column, Integer, String
from app.data.models.base import Base
class Provider(Base):
__tablename__ = "providers"
id = Column(Integer, primary_key=True) # 主键
name = Column(String, nullable=False) # 提供商名称,例如 OpenAI
endpoint = Column(String, nullable=False) # API 调用地址
model = Column(String, nullable=False) # 模型名称,例如 gpt-4
api_key = Column(String, nullable=False) # API Key
- 知识库
# app/data/models/knowledge_base.py
# ---------------------------
# 知识库
# ---------------------------
from sqlalchemy import Column, Integer, String, Text
from sqlalchemy.orm import relationship
from app.data.models.base import Base
class KnowledgeBase(Base):
__tablename__ = "knowledge_bases"
id = Column(Integer, primary_key=True)
name = Column(String(100), nullable=False) # 知识库名称
description = Column(Text, nullable=True) # 知识库描述
# 定义关系:知识库 -> 角色(多对多)
roles = relationship("Role", secondary="role_knowledge", back_populates="knowledge_bases")
# 定义关系:知识库 -> 会话(多对多)
sessions = relationship("Session", secondary="session_knowledge", back_populates="knowledge_bases")
# 定义关系:知识库 -> 文档(一对多)
documents = relationship("Document", back_populates="knowledge_base", cascade="all, delete-orphan")
- 文档
# app/data/models/document.py
# ---------------------------
# 文档
# ---------------------------
from datetime import datetime
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import relationship
from app.data.models.base import Base
class Document(Base):
__tablename__ = "documents"
id = Column(Integer, primary_key=True)
title = Column(String(200), nullable=False) # 文档标题
source = Column(String(255), nullable=True) # 文档来源或路径
created_at = Column(DateTime, default=datetime.now()) # 创建时间
# 关联知识库ID
knowledge_base_id = Column(Integer, ForeignKey("knowledge_bases.id"), nullable=False)
# 定义关系:文档 -> 知识库(多对一)
knowledge_base = relationship("KnowledgeBase", back_populates="documents")
- 角色
# app/data/models/role.py
# -------------------------
# 角色
# -------------------------
from sqlalchemy import Integer, Column, String, Float, ForeignKey
from sqlalchemy.orm import relationship
from app.data.models.base import Base
class Role(Base):
__tablename__ = "roles"
id = Column(Integer, primary_key=True)
name = Column(String, nullable=False) # 角色名称
description = Column(String, nullable=True) # 描述
system_prompt = Column(String, nullable=True) # 系统提示词
temperature = Column(Float, default=0.7) # 模型温度
# 关联模型提供商ID
provider_id = Column(Integer, ForeignKey("providers.id"), nullable=False)
# 定义模型关系:角色 -> 提供商(多对一)
provider = relationship("Provider")
# 定义模型关系:角色 -> 会话(一对多)
sessions = relationship("Session", back_populates="role", cascade="all, delete-orphan")
# 定义关系:角色 -> 知识库(多对多)
knowledge_bases = relationship("KnowledgeBase", secondary="role_knowledge", back_populates="roles")
# app/data/models/role_knowledge.py
# ---------------------------
# 角色-知识库 关系表(多对多中间表)
# ---------------------------
from sqlalchemy import Table, Column, Integer, ForeignKey
from app.data.models.base import Base
role_knowledge = Table(
"role_knowledge",
Base.metadata,
Column("role_id", Integer, ForeignKey("roles.id"), primary_key=True),
Column("knowledge_base_id", Integer, ForeignKey("knowledge_bases.id"), primary_key=True)
)
- 会话
# app/data/models/session.py
# ---------------------------
# 会话
# ---------------------------
from datetime import datetime
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import relationship
from app.data.models.base import Base
class Session(Base):
__tablename__ = "sessions"
id = Column(Integer, primary_key=True)
title = Column(String(100), nullable=False) # 标题
created_at = Column(DateTime, default=datetime.now()) # 创建时间
# 关联角色ID
role_id = Column(Integer, ForeignKey("roles.id"), nullable=False)
# 定义模型关系:会话 -> 角色(多对一)
role = relationship("Role", back_populates="sessions")
# 定义关系:会话 -> 消息(一对多)
messages = relationship("Message", back_populates="session", cascade="all, delete-orphan")
# 定义关系:会话 -> 知识库(多对多)
knowledge_bases = relationship("KnowledgeBase", secondary="session_knowledge", back_populates="sessions")
# app/data/models/session_knowledge.py
# ---------------------------
# 会话-知识库 关系表(多对多中间表)
# ---------------------------
from sqlalchemy import Table, Column, Integer, ForeignKey
from app.data.models.base import Base
session_knowledge = Table(
"session_knowledge",
Base.metadata,
Column("session_id", Integer, ForeignKey("sessions.id"), primary_key=True),
Column("knowledge_base_id", Integer, ForeignKey("knowledge_bases.id"), primary_key=True)
)
- 消息
# app/data/models/message.py
# ---------------------------
# 消息
# ---------------------------
from datetime import datetime
from sqlalchemy import Column, Integer, String, DateTime, Text, ForeignKey
from sqlalchemy.orm import relationship
from app.data.models.base import Base
class Message(Base):
__tablename__ = "messages"
id = Column(Integer, primary_key=True)
role = Column(String(20), nullable=False) # 消息角色类型:user/assistant
content = Column(Text, nullable=False) # 消息内容
created_at = Column(DateTime, default=datetime.now()) # 创建时间
# 关联会话ID
session_id = Column(Integer, ForeignKey("sessions.id"), nullable=False)
# 定义关系:消息 -> 会话(多对一)
session = relationship("Session", back_populates="messages")
4. 数据库迁移
- 安装包
# 安装 alembic
uv add alembic
# 初始化环境
alembic init alembic
# 在 alembic.ini 文件中修改数据库链接字符串
# 主意:这里的同步方式,项目文件.env中的链接字符串是异步方式
sqlalchemy.url = postgresql+psycopg2://postgres:123456@localhost:5432/aichat-demo
- 设置metadata对象
# app/data/models/__init__.py
# 合并model
from .base import Base
from .provider import Provider
from .role import Role
from .session import Session
from .message import Message
from .knowledge_base import KnowledgeBase
from .document import Document
from .role_knowledge import role_knowledge
from .session_knowledge import session_knowledge
# alembic/.env
# 设置 target_metadata
#21行 target_metadata = None
from app.data import models
target_metadata = models.Base.metadata
- 生成迁移
alembic revision --autogenerate -m "Initial migration"
- 应用迁移
alembic upgrade head
四、基础API服务
1. 定义模型提供商DTO
# app/schemas/provider.py
from typing import Optional
from pydantic import BaseModel, Field
class ProviderBase(BaseModel):
"""Provider基础DTO模型"""
name: str = Field(..., max_length=100, description="提供商名称")
endpoint: str = Field(..., max_length=255, description="API调用地址")
model: str = Field(..., max_length=100, description="模型名称")
api_key: str = Field(..., description="API密钥")
class ProviderCreate(ProviderBase):
"""创建Provider的DTO模型"""
pass
class ProviderResponse(ProviderBase):
"""Provider响应DTO模型"""
id: int = Field(..., description="提供商ID")
class Config:
from_attributes = True
class ProviderUpdate(BaseModel):
"""更新Provider的DTO模型"""
name: Optional[str] = Field(None, max_length=100, description="提供商名称")
endpoint: Optional[str] = Field(None, max_length=255, description="API调用地址")
model: Optional[str] = Field(None, max_length=100, description="模型名称")
api_key: Optional[str] = Field(None , description="API密钥")
2. 实现模型提供商业务方法
# app/service/provider.py
from typing import Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete
from app.data.models import Provider
from app.schemas.provider import ProviderResponse, ProviderCreate, ProviderUpdate
class ProviderService:
def __init__(self, db: AsyncSession):
self.db = db
async def create(self, provider_create: ProviderCreate) -> ProviderResponse:
"""创建Provider"""
# 1. 将 DTO 转换为字典
provider_data = provider_create.model_dump()
# 2. 使用字典解包 (**) 来创建 ORM 模型实例
provider = Provider(**provider_data)
# 3. 将新创建的 ORM 对象添加到数据库会话中
self.db.add(provider)
await self.db.commit()
await self.db.refresh(provider)
# 刷新实例,以获取数据库生成的值 (如 id, created_at)
return ProviderResponse.model_validate(provider)
async def get_by_id(self, provider_id: int) -> Optional[ProviderResponse]:
"""根据ID获取Provider"""
# 1. 使用 session.get() 方法通过主键高效查询
provider = await self.db.get(Provider, provider_id)
# 2. 检查是否找到了记录
if provider:
# 3. 如果找到,将 ORM 模型转换为 Pydantic 响应模型并返回
return ProviderResponse.model_validate(provider)
# 4. 如果没找到,返回 None
return None
async def get_by_name(self, name: str) -> Optional[ProviderResponse]:
"""根据名称获取Provider"""
# 1. 执行一个 SELECT 查询
result = await self.db.execute(
select(Provider).where(Provider.name == name)
)
# 2. 从结果中获取单个标量(Scalar)对象
provider = result.scalar_one_or_none()
# 3. 检查、转换并返回,逻辑同 get_by_id
if provider:
return ProviderResponse.model_validate(provider)
return None
async def get_all(self) -> List[ProviderResponse]:
"""获取所有Provider"""
# 1. 执行一个查询,获取所有 Provider,并按 ID 降序排序
result = await self.db.execute(
select(Provider)
.order_by(Provider.id.desc())
)
# 2. 获取所有结果行中的标量对象
providers = result.scalars().all()
# 3. 使用列表推导式将所有 ORM 对象转换为 Pydantic 响应模型
provider_responses = [ProviderResponse.model_validate(p) for p in providers]
# 4. 返回 Pydantic 模型组成的列表
return provider_responses
async def update(self, provider_id: int, provider_data: ProviderUpdate) -> Optional[ProviderResponse]:
"""更新Provider"""
# 1. 检查要更新的 Provider 是否存在
existing_provider = await self.get_by_id(provider_id)
if not existing_provider:
return None
# 2. 将 Pydantic 输入模型转换为用于更新的字典
update_data = provider_data.model_dump(exclude_unset=True)
# 3. 如果有数据需要更新,则执行数据库操作
if update_data:
await self.db.execute(
update(Provider)
.where(Provider.id == provider_id)
.values(**update_data)
)
await self.db.commit()
# 4. 重新获取并返回更新后的完整 Provider 数据
return await self.get_by_id(provider_id)
async def delete(self, provider_id: int) -> bool:
"""删除Provider"""
# 1. 检查要删除的 Provider 是否存在
existing_provider = await self.get_by_id(provider_id)
if not existing_provider:
return False
# 2. 执行删除操作
await self.db.execute(
delete(Provider).where(Provider.id == provider_id)
)
# 3. 提交事务,使删除生效
await self.db.commit()
# 4. 返回 True 表示成功
return True
**3. 实现模型提供商 API **
# app/api/provider.py
from typing import List
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from app.data.db import get_db
from app.schemas.provider import ProviderResponse, ProviderCreate, ProviderUpdate
from app.service.provider import ProviderService
router = APIRouter(prefix="/providers", tags=["模型提供商管理"])
@router.post("/", response_model=ProviderResponse, status_code=201)
async def create_provider(provider_data: ProviderCreate, db: AsyncSession = Depends(get_db)):
"""创建新的模型提供商"""
service = ProviderService(db)
# 检查名称是否已存在
existing_provider = await service.get_by_name(provider_data.name)
if existing_provider:
raise HTTPException(
status_code=400,
detail=f"提供商名称 '{provider_data.name}' 已存在"
)
try:
provider = await service.create(provider_data)
return provider
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"创建模型提供商失败: {str(e)}"
)
@router.get("/", response_model=List[ProviderResponse])
async def get_providers(db: AsyncSession = Depends(get_db)):
"""获取所有模型提供商"""
service = ProviderService(db)
try:
providers = await service.get_all()
return providers
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"获取提供商列表失败: {str(e)}"
)
@router.get("/{provider_id}", response_model=ProviderResponse)
async def get_provider(
provider_id: int,
db: AsyncSession = Depends(get_db)
):
"""根据ID获取模型提供商"""
service = ProviderService(db)
provider = await service.get_by_id(provider_id)
if not provider:
raise HTTPException(
status_code=404,
detail=f"提供商 ID {provider_id} 不存在"
)
return provider
@router.put("/{provider_id}", response_model=ProviderResponse)
async def update_provider(
provider_id: int,
provider_data: ProviderUpdate,
db: AsyncSession = Depends(get_db)
):
"""更新模型提供商"""
service = ProviderService(db)
# 如果更新名称,检查新名称是否已存在
if provider_data.name:
existing_provider = await service.get_by_name(provider_data.name)
if existing_provider and existing_provider.id != provider_id:
raise HTTPException(
status_code=400,
detail=f"提供商名称 '{provider_data.name}' 已存在"
)
try:
provider = await service.update(provider_id, provider_data)
if not provider:
raise HTTPException(
status_code=404,
detail=f"提供商 ID {provider_id} 不存在"
)
return provider
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"更新提供商失败: {str(e)}"
)
@router.delete("/{provider_id}", status_code=204)
async def delete_provider(
provider_id: int,
db: AsyncSession = Depends(get_db)
):
"""删除模型提供商"""
service = ProviderService(db)
try:
success = await service.delete(provider_id)
if not success:
raise HTTPException(
status_code=404,
detail=f"提供商 ID {provider_id} 不存在"
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"删除提供商失败: {str(e)}"
)
4. 配置模型提供商路由
# app/main.py
from app.api import provider
# 注册Provider路由
app.include_router(
provider.router,
prefix="/api"
)
配置自动启动文件
PyCharm 配置
点击 Run > Edit Configurations
添加 Python 配置
设置 Script path 为你的主文件路径
运行
可能会报找不到 asyncpg,安装一下 uv add asyncpg
5. 实现角色管理服务与API
- 定义DTO
# app/schemas/role.py
from typing import Optional
from pydantic import BaseModel, Field
class RoleBase(BaseModel):
"""Role基础DTO模型"""
name: str = Field(..., max_length=100, description="角色名称")
description: Optional[str] = Field(None, max_length=500, description="角色描述")
system_prompt: Optional[str] = Field(None, description="系统提示词")
temperature: float = Field(0.7, ge=0.0, le=2.0, description="模型温度")
provider_id: int = Field(..., description="关联的模型提供商ID")
class RoleCreate(RoleBase):
"""创建Role的DTO模型"""
pass
class RoleResponse(RoleBase):
"""Role响应DTO模型"""
id: int = Field(..., description="角色ID")
class Config:
from_attributes = True
class RoleUpdate(BaseModel):
"""更新Role的DTO模型"""
name: Optional[str] = Field(None, max_length=100, description="角色名称")
description: Optional[str] = Field(None, max_length=500, description="角色描述")
system_prompt: Optional[str] = Field(None, description="系统提示词")
temperature: Optional[float] = Field(None, ge=0.0, le=2.0, description="模型温度")
provider_id: Optional[int] = Field(None, description="关联的模型提供商ID")
- 实现业务方法
# app/service/role.py
from typing import Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, delete, update
from app.data.models import Role, Provider
from app.schemas.role import RoleCreate, RoleResponse, RoleUpdate
class RoleService:
def __init__(self, db: AsyncSession):
self.db = db
async def create(self, role_create: RoleCreate) -> RoleResponse:
"""创建Role"""
# 验证provider_id是否存在
provider = await self.db.get(Provider, role_create.provider_id)
if not provider:
raise ValueError(f"提供商 ID {role_create.provider_id} 不存在")
role_data = role_create.model_dump()
role = Role(**role_data)
self.db.add(role)
await self.db.commit()
await self.db.refresh(role)
return RoleResponse.model_validate(role)
async def get_by_id(self, role_id: int) -> Optional[RoleResponse]:
"""根据ID获取Role"""
role = await self.db.get(Role, role_id)
if role:
return RoleResponse.model_validate(role)
return None
async def get_by_name(self, name: str) -> Optional[RoleResponse]:
"""根据名称获取Role"""
result = await self.db.execute(
select(Role).where(Role.name == name)
)
role = result.scalar_one_or_none()
if role:
return RoleResponse.model_validate(role)
return None
async def get_all(self) -> List[RoleResponse]:
"""获取所有Role"""
result = await self.db.execute(
select(Role)
.order_by(Role.id.desc())
)
roles = result.scalars().all()
role_responses = [RoleResponse.model_validate(r) for r in roles]
return role_responses
async def update(self, role_id: int, role_data: RoleUpdate) -> Optional[RoleResponse]:
"""更新Role"""
existing_role = await self.get_by_id(role_id)
if not existing_role:
return None
if role_data.provider_id is not None:
provider = await self.db.get(Provider, role_data.provider_id)
if not provider:
raise ValueError(f"提供商 ID {role_data.provider_id} 不存在")
# 构建更新数据
update_data = role_data.model_dump(exclude_unset=True)
if update_data:
await self.db.execute(
update(Role)
.where(Role.id == role_id)
.values(**update_data)
)
await self.db.commit()
return await self.get_by_id(role_id)
async def delete(self, role_id: int) -> bool:
"""删除Role"""
existing_role = await self.get_by_id(role_id)
if not existing_role:
return False
await self.db.execute(
delete(Role).where(Role.id == role_id)
)
await self.db.commit()
return True
- 实现 API
# app/api/role.py
from typing import List
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.data.db import get_db
from app.schemas.role import RoleResponse, RoleCreate, RoleUpdate
from app.service.role import RoleService
router = APIRouter(prefix="/roles", tags=["角色管理"])
@router.post("/", response_model=RoleResponse, status_code=201)
async def create_role(role_data: RoleCreate, db: AsyncSession = Depends(get_db)):
"""创建新的角色"""
service = RoleService(db)
# 检查名称是否已存在
existing_role = await service.get_by_name(role_data.name)
if existing_role:
raise HTTPException(
status_code=400,
detail=f"角色名称 '{role_data.name}' 已存在"
)
try:
role = await service.create(role_data)
return role
except ValueError as e:
raise HTTPException(
status_code=400,
detail=str(e)
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"创建角色失败: {str(e)}"
)
@router.get("/", response_model=List[RoleResponse])
async def get_roles(db: AsyncSession = Depends(get_db)):
"""获取所有角色"""
service = RoleService(db)
try:
roles = await service.get_all()
return roles
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"获取角色列表失败: {str(e)}"
)
@router.get("/{role_id}", response_model=RoleResponse)
async def get_role(
role_id: int,
db: AsyncSession = Depends(get_db)
):
"""根据ID获取角色"""
service = RoleService(db)
role = await service.get_by_id(role_id)
if not role:
raise HTTPException(
status_code=404,
detail=f"角色 ID {role_id} 不存在"
)
return role
@router.put("/{role_id}", response_model=RoleResponse)
async def update_role(
role_id: int,
role_data: RoleUpdate,
db: AsyncSession = Depends(get_db)
):
"""更新角色"""
service = RoleService(db)
# 如果更新名称,检查新名称是否已存在
if role_data.name:
existing_role = await service.get_by_name(role_data.name)
if existing_role and existing_role.id != role_id:
raise HTTPException(
status_code=400,
detail=f"角色名称 '{role_data.name}' 已存在"
)
try:
role = await service.update(role_id, role_data)
if not role:
raise HTTPException(
status_code=404,
detail=f"角色 ID {role_id} 不存在"
)
return role
except ValueError as e:
raise HTTPException(
status_code=400,
detail=str(e)
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"更新角色失败: {str(e)}"
)
@router.delete("/{role_id}", status_code=204)
async def delete_role(
role_id: int,
db: AsyncSession = Depends(get_db)
):
"""删除角色"""
service = RoleService(db)
try:
success = await service.delete(role_id)
if not success:
raise HTTPException(
status_code=404,
detail=f"角色 ID {role_id} 不存在"
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"删除角色失败: {str(e)}"
)
- 配置路由
# app/main.py
from app.api import provider,role
# 注册Role路由
app.include_router(
role.router,
prefix="/api"
)
6. 实现会话与消息管理服务与API
- 定义DTO
# app/schemas/session.py
from typing import Optional, List
from datetime import datetime
from pydantic import BaseModel, Field
class SessionBase(BaseModel):
"""Session基础DTO模型"""
title: str = Field(..., max_length=100, description="会话标题")
role_id: int = Field(..., description="关联的角色ID")
class SessionCreate(SessionBase):
"""创建Session的DTO模型"""
pass
class SessionResponse(SessionBase):
"""Session响应DTO模型"""
id: int = Field(..., description="会话ID")
created_at: datetime = Field(..., description="创建时间")
message_count: Optional[int] = Field(None, description="消息数量")
class Config:
from_attributes = True
class SessionUpdate(BaseModel):
"""更新Session的DTO模型"""
title: Optional[str] = Field(None, max_length=100, description="会话标题")
role_id: Optional[int] = Field(None, description="关联的角色ID")
class SessionListResponse(BaseModel):
"""Session列表响应DTO模型"""
id: int = Field(..., description="会话ID")
title: str = Field(..., description="会话标题")
created_at: datetime = Field(..., description="创建时间")
role_name: str = Field(..., description="角色名称")
class Config:
from_attributes = True
# app/schemas/message.py
from datetime import datetime
from pydantic import BaseModel, Field
from enum import Enum
class MessageRole(str, Enum):
"""消息角色枚举"""
User = "user"
Assistant = "assistant"
System = "system"
class MessageBase(BaseModel):
"""Message基础DTO模型"""
role: MessageRole = Field(..., description="消息角色:user/assistant/system")
content: str = Field(..., description="消息内容")
session_id: int = Field(..., description="关联的会话ID")
class MessageResponse(MessageBase):
"""Message响应DTO模型"""
id: int = Field(..., description="消息ID")
created_at: datetime = Field(..., description="创建时间")
class Config:
from_attributes = True
class ChatRequest(BaseModel):
"""聊天请求DTO模型"""
message: str = Field(..., description="用户消息内容")
session_id: int = Field(..., description="会话ID")
- 实现业务方法
# app/service/session.py
from typing import List
from sqlalchemy import select, delete
from sqlalchemy.ext.asyncio import AsyncSession
from app.data.models import Session, Role, Provider, Message
from app.schemas.session import SessionCreate, SessionResponse, SessionListResponse
class SessionService:
def __init__(self, db: AsyncSession):
self.db = db
async def create(self, session_create: SessionCreate) -> SessionResponse:
"""创建Session"""
role = await self.db.get(Role, session_create.role_id)
if not role:
raise ValueError(f"角色 ID {session_create.role_id} 不存在")
session_data = session_create.model_dump()
session = Session(**session_data)
self.db.add(session)
await self.db.commit()
await self.db.refresh(session)
return SessionResponse.model_validate(session)
async def get_all(self) -> List[SessionListResponse]:
"""获取所有Session,包含角色"""
result = await self.db.execute(
select(
Session.id,
Session.title,
Session.created_at,
Role.name.label('role_name')
)
.join(Role, Session.role_id == Role.id)
.group_by(Session.id, Session.title, Session.created_at, Role.name, Provider.name)
.order_by(Session.created_at.desc())
)
sessions = result.all()
return [
SessionListResponse(
id=session.id,
title=session.title,
created_at=session.created_at,
role_name=session.role_name
)
for session in sessions
]
async def delete(self, session_id: int) -> bool:
"""删除Session"""
await self.db.execute(
delete(Session).where(Session.id == session_id)
)
await self.db.commit()
return True
# app/service/message.py
from typing import List
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, delete
from sqlalchemy.orm import selectinload
from app.data.models import Message, Session, Role
from app.schemas.message import MessageRole, MessageBase, MessageResponse, ChatRequest
class MessageService:
def __init__(self, db: AsyncSession):
self.db = db
async def create(self, message: MessageBase) -> MessageResponse:
"""创建Message"""
session = await self.db.get(Session, message.session_id)
if not session:
raise ValueError(f"会话 ID {message.session_id} 不存在")
message_data = message.model_dump()
message = Message(**message_data)
self.db.add(message)
await self.db.commit()
await self.db.refresh(message)
return MessageResponse.model_validate(message)
async def get_conversation_history(self, session_id: int, limit: int = 50) -> List[MessageResponse]:
"""获取会话历史记录"""
result = await self.db.execute(
select(Message)
.where(Message.session_id == session_id)
.order_by(Message.id)
.limit(limit)
)
messages = result.scalars().all()
# 按时间正序返回
return [MessageResponse.model_validate(msg) for msg in messages]
async def chat(self, chat_request: ChatRequest) -> MessageResponse:
"""处理聊天请求"""
# 验证session是否存在
session = await self.db.get(Session, chat_request.session_id)
if not session:
raise ValueError(f"会话 ID {chat_request.session_id} 不存在")
# 获取会话绑定的角色与模型信息
role = await self.db.scalar(
select(Role)
.options(selectinload(Role.provider)) # 贪婪加载 provider
.where(Role.id == session.role_id)
)
if not role:
raise ValueError(f"角色【{role.name}】不存在")
# 获取消息历史
conversation_history = await self.get_conversation_history(chat_request.session_id)
# 这里应该调用AI服务来生成回复
# 目前先返回一个模拟的回复
assistant_content = (f"我是'{role.name}',我们聊了'{len(conversation_history)}'条消息,收到了您的消息:'{chat_request.message}'。"
f"这是一个模拟回复,实际应用中这里会调用AI模型生成回复。")
# 创建用户消息
user_message = MessageBase(
role=MessageRole.User,
content=chat_request.message,
session_id=chat_request.session_id
)
await self.create(user_message)
# 创建助手回复
assistant_message = MessageBase(
role=MessageRole.Assistant,
content=assistant_content,
session_id=chat_request.session_id
)
response = await self.create(assistant_message)
return response
async def delete(self, message_id: int) -> bool:
"""删除Message"""
await self.db.execute(
delete(Message).where(Message.id == message_id)
)
await self.db.commit()
return True
- 实现 API
# app/api/session.py
from typing import List
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.data.db import get_db
from app.schemas.message import MessageResponse, ChatRequest
from app.schemas.session import SessionResponse, SessionCreate
from app.service.message import MessageService
from app.service.session import SessionService
router = APIRouter(prefix="/sessions", tags=["会话管理"])
@router.post("/", response_model=SessionResponse, status_code=201)
async def create_session(session_data: SessionCreate, db: AsyncSession = Depends(get_db)):
"""创建新的会话"""
service = SessionService(db)
try:
session = await service.create(session_data)
return session
except ValueError as e:
raise HTTPException(
status_code=400,
detail=str(e)
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"创建会话失败: {str(e)}"
)
@router.delete("/{session_id}", status_code=204)
async def delete_session(
session_id: int,
db: AsyncSession = Depends(get_db)
):
"""删除会话"""
service = SessionService(db)
try:
success = await service.delete(session_id)
if not success:
raise HTTPException(
status_code=404,
detail=f"会话 ID {session_id} 不存在"
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"删除会话失败: {str(e)}"
)
@router.get("/{session_id}/messages", response_model=List[MessageResponse])
async def get_messages(
session_id: int,
limit: int = Query(50, ge=1, le=100, description="限制返回的消息数量"),
db: AsyncSession = Depends(get_db)
):
"""根据会话ID获取消息列表"""
service = MessageService(db)
try:
messages = await service.get_conversation_history(session_id, limit)
return messages
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"获取消息列表失败: {str(e)}"
)
@router.delete("/messages/{message_id}", status_code=204)
async def delete_message(
message_id: int,
db: AsyncSession = Depends(get_db)
):
"""删除消息"""
service = MessageService(db)
try:
success = await service.delete(message_id)
if not success:
raise HTTPException(
status_code=404,
detail=f"消息 ID {message_id} 不存在"
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"删除消息失败: {str(e)}"
)
@router.post("/chat", response_model=MessageResponse)
async def chat(chat_request: ChatRequest, db: AsyncSession = Depends(get_db)):
"""处理聊天请求"""
service = MessageService(db)
try:
response = await service.chat(chat_request)
return response
except ValueError as e:
raise HTTPException(
status_code=400,
detail=str(e)
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"处理聊天请求失败: {str(e)}"
)
- 配置路由
# app/main.py
from app.api import provider,role,session
# 注册Session路由
app.include_router(
session.router,
prefix="/api"
)

