Spiga

Net企业级AI项目4:NL2SQL

2026-01-31 17:35:09

NL2SQL(Natural Language to SQL) 是一项将人类自然语言问题自动转换为结构化 SQL 查询语句的技术。简单来说,就是让用户用大白话(如“上个月销售额最高的产品是什么?”)直接查询数据库,无需学习 SQL 语法。

今天我们的任务就是给我们的企业智能助理扩展 NL2SQL 的能力。

一、动态数据源架构

1. 面向多数据库源的架构

一家企业会有很多不同的系统,而各个系统所用的数据库会是各种各样的,它们不仅数据库产品不一样,同产品的数据库可能版本也不一样,查询数据的时候可能存在一些差别。

我们智能助理需要拥有以下这些功能:

  • 动态配置多个业务数据库
  • 通过自然语言,分析全局源数据,定位到需要链接的数据库(如 ERP\HR\OA)
  • 最小权限连接上数据库(只读,禁止使用 DDL、DML)
  • 分析目标库内部的表结构信息
  • 方言适配(不同类型的数据库 SQL 语法差异)

ps:只做只读的实现,并不是模型不能生成插入、更新、删除等 SQL,主要是不建议用自然语言来做数据库的操作,主要是:

  • 模型会有幻觉,操作数据库存在风险
  • 无法保证事务
  • 无法触发一些关联任务,比如生成订单后,需要通知扣减库存。如果是直接操作订单表,库存的扣减事件就无法触发

如果需要实现自然语音操作数据,可以使用工具调用或者 MCP 来完成,我们的项目只实现生成 SQL 查询语句。

2. ReAct 认知模型

ReAct 模型就是一个会动脑筋、会动手、还会从错误中学习的超级助理的工作方法!

它的任务就是听懂你的话,然后从数据库里找到答案。它可聪明了,做事非常有条理,就像一个大侦探,每次破案都遵循三个步骤:

  • 思考(Think - 推理)

    当用户说“帮我查一下仓库里还有多少台 iPhone 15?”

    助理就会想一想:“这个问题是什么意思呢?我需要去“库存表”里看看,要找“iPhone 15”的数量”。

  • 行动(Act - 行动)

    想好之后,它就立刻行动起来!也就是生成 SQL 查询语句去查找。

  • 检查(Observe - 观察):行动之后,它会仔细检查结果。比如,它可能发现“生成的 SQL ”写错了,它会进行修正。

总之,当你告诉它要什么,它就用这个“思考-行动”的魔法,自己想办法从数据库里把答案给你找出来。

3. 多态数据库设计

  • 全局数据源元数据注册表:
    • 身份标识与路由
    • 语义描述
    • 连接字符串管理
    • 方言适配
  • 连接隔离和最小权限原则
    • 控制平面(AI系统自身的数据库、读写权限)
    • 数据平面(外部业务数据库)
  • 动态数据库连接工厂
    • 我们使用 Dapper 来实现,因为要支持多态数据库,强类型的 EF Core 不太好实现

4. 实现动态数据源管理

  • 实现领域模型

创建一个名为 Qjy.AICopilot.Core.DataAnalysis 的类库,并创建 BusinessDatabase 实体,用来添加支持的数据库,以及配置数据库链接字符串。

我们项目支持三种数据库类型:PostgreSql、SqlServer 和 MySql,所以我们先创建一个 DbProviderType 枚举,然后创建 BusinessDatabase 实体。

//Qjy.AICopilot.Core.DataAnalysis/Aggregates/BusinessDatabase/DbProviderType.cs
/// <summary>
/// 数据库提供程序类型
/// </summary>
public enum DbProviderType
{
    /// <summary>
    /// PostgreSQL 数据库
    /// </summary>
    PostgreSql = 1,

    /// <summary>
    /// Microsoft SQL Server
    /// </summary>
    SqlServer = 2,

    /// <summary>
    /// MySQL
    /// </summary>
    MySql = 3
}

//Qjy.AICopilot.Core.DataAnalysis/Aggregates/BusinessDatabase/BusinessDatabase.cs
/// <summary>
/// 业务数据库聚合根
/// 代表一个可被AI Agent访问的外部数据源
/// </summary>
public class BusinessDatabase : IAggregateRoot
{
    protected BusinessDatabase() { }

    public BusinessDatabase(string name, string description, string connectionString, DbProviderType provider)
    {
        Id = Guid.NewGuid();
        Name = name;
        Description = description;
        ConnectionString = connectionString;
        Provider = provider;
        IsEnabled = true;
        CreatedAt = DateTime.UtcNow;
    }

    public Guid Id { get; private set; }

    /// <summary>
    /// 数据库标识名称
    /// 用于在多库路由时作为唯一Key
    /// </summary>
    public string Name { get; private set; } = null!;

    /// <summary>
    /// 数据库业务描述 (如: "包含所有销售订单、客户资料及发货记录")
    /// 该字段将被注入到System Prompt中,辅助LLM进行意图路由判断
    /// </summary>
    public string Description { get; private set; } = null!;

    /// <summary>
    /// 连接字符串
    /// </summary>
    public string ConnectionString { get; private set; } = null!;

    /// <summary>
    /// 数据库类型
    /// </summary>
    public DbProviderType Provider { get; private set; }

    /// <summary>
    /// 是否启用
    /// </summary>
    public bool IsEnabled { get; private set; }

    public DateTime CreatedAt { get; private set; }

    /// <summary>
    /// 更新连接信息
    /// </summary>
    public void UpdateConnection(string connectionString, DbProviderType provider)
    {
        ConnectionString = connectionString;
        Provider = provider;
    }

    /// <summary>
    /// 更新描述信息
    /// </summary>
    public void UpdateInfo(string name, string description)
    {
        Name = name;
        Description = description;
    }
}
  • 配置数据库映射关系,然后重新生成数据库
//Qjy.AICopilot.EntityFrameworkCore/Configuration/DataAnalysis/BusinessDatabaseConfiguration.cs
public class BusinessDatabaseConfiguration : IEntityTypeConfiguration<BusinessDatabase>
{
    public void Configure(EntityTypeBuilder<BusinessDatabase> builder)
    {
        builder.ToTable("business_databases");

        builder.HasKey(b => b.Id);
        builder.Property(b => b.Id).HasColumnName("id");

        builder.Property(b => b.Name)
            .IsRequired()
            .HasMaxLength(100)
            .HasColumnName("name");
        
        // 保证名称唯一,便于路由查找
        builder.HasIndex(b => b.Name).IsUnique();

        builder.Property(b => b.Description)
            .IsRequired()
            .HasMaxLength(500)
            .HasColumnName("description");

        builder.Property(b => b.ConnectionString)
            .IsRequired()
            .HasMaxLength(1000)
            .HasColumnName("connection_string");

        builder.Property(b => b.Provider)
            .IsRequired()
            .HasConversion<string>() // 存储枚举字符串,增强可读性
            .HasMaxLength(50)
            .HasColumnName("provider");

        builder.Property(b => b.IsEnabled)
            .IsRequired()
            .HasColumnName("is_enabled");

        builder.Property(b => b.CreatedAt)
            .IsRequired()
            .HasColumnName("created_at");
    }
}

//Qjy.AICopilot.EntityFrameworkCore/AiCopilotDbContext.cs
public DbSet<BusinessDatabase> BusinessDatabases => Set<BusinessDatabase>();

//Qjy.AICopilot.EntityFrameworkCore/DataQueryService.cs
public IQueryable<BusinessDatabase> BusinessDatabases => dbContext.BusinessDatabases.AsNoTracking();

//Qjy.AICopilot.Services.Common/Contracts/IDataQueryService.cs
public IQueryable<BusinessDatabase> BusinessDatabases { get; }
  • 创建 Dapper 基础设施层
    • 在 Infrastructure,创建 Qjy.AICopilot.Dapper 类库项目
    • 添加 Dapper,3种数据库需要的引用包,以及 Hosting 抽象包(用于配置 SQL 安全服务和数据库连接器)。
<ItemGroup>
  <PackageReference Include="Dapper" Version="2.1.66" />
  <PackageReference Include="Microsoft.Data.SqlClient" Version="6.1.3" />
  <PackageReference Include="Microsoft.Extensions.Hosting.Abstractions" Version="10.0.1" />
  <PackageReference Include="MySql.Data" Version="9.5.0" />
  <PackageReference Include="Npgsql" Version="10.0.1" />
</ItemGroup>
  • 实现 SQL 安全验证器
//Qjy.AICopilot.Dapper/ISqlGuardrail.cs
public interface ISqlGuardrail
{
    /// <summary>
    /// 验证SQL语句是否安全
    /// </summary>
    /// <param name="sql">待执行的SQL</param>
    /// <returns>验证结果,包含是否通过及错误信息</returns>
    (bool IsSafe, string? ErrorMessage) Validate(string sql);
}

//Qjy.AICopilot.Dapper/KeywordSqlGuardrail.cs
/// <summary>
/// 基于关键词黑名单的SQL安全服务
/// 注意:这只是第一道防线,不能完全替代数据库层面的权限控制,数据库提供只读帐号
/// </summary>
public class KeywordSqlGuardrail : ISqlGuardrail
{
    // 定义高危操作关键词
    private static readonly string[] ForbiddenKeywords = 
    [
        "DROP", "TRUNCATE", "DELETE", "UPDATE", "INSERT", 
        "ALTER", "GRANT", "REVOKE", "CREATE", "EXEC", "EXECUTE",
        "MERGE", "REPLACE", "UPSERT"
    ];

    public (bool IsSafe, string? ErrorMessage) Validate(string sql)
    {
        if (string.IsNullOrWhiteSpace(sql))
            return (false, "SQL语句为空");

        // 移除注释,防止通过注释绕过检测 (简单的 -- 或 /* */)
        // 生产环境建议使用更完善的 SQL Parser 库进行 AST 分析
        var cleanSql = RemoveComments(sql).ToUpperInvariant();

        foreach (var keyword in ForbiddenKeywords)
        {
            // 使用单词边界匹配,避免误杀 (例如: "SELECT * FROM UPDATE_LOG" 不应被拦截)
            // \b{keyword}\b 确保匹配的是完整的单词
            var regex = new Regex($@"\b{keyword}\b", RegexOptions.IgnoreCase);
            if (regex.IsMatch(cleanSql))
            {
                return (false, $"安全拦截:检测到禁止的关键字 '{keyword}'。Agent 仅允许执行查询操作。");
            }
        }

        // 检查是否包含分号,防止多语句执行注入 (如: SELECT * FROM Users; DROP TABLE Logs)
        // 大多数 ORM 在单次执行中只允许一条语句,但进行显式检查更为安全
        if (cleanSql.Count(c => c == ';') > 1 || (cleanSql.Contains(';') && !cleanSql.TrimEnd().EndsWith(';')))
        {
            return (false, "安全拦截:禁止在单次调用中执行多条 SQL 语句。");
        }

        return (true, null);
    }

    private static string RemoveComments(string sql)
    {
        // 移除 -- 单行注释
        var noSingleLine = Regex.Replace(sql, "--.*", "");
        // 移除 /* */ 多行注释
        var noComments = Regex.Replace(noSingleLine, @"/\*[\s\S]*?\*/", "");
        return noComments;
    }
}
  • 实现数据库连接器
    • 先定义一个接口,提供3个方法(获取数据库连接、获取数据库架构信息,以及执行查询并返回动态列表)
    • 然后提供 Dappar 的接口实现
//Qjy.AICopilot.Services.Common/Contracts/IDatabaseConnector.cs
public interface IDatabaseConnector
{
    /// <summary>
    /// 获取数据库连接(不打开,仅创建对象)
    /// </summary>
    IDbConnection GetConnection(BusinessDatabase database);

    /// <summary>
    /// 获取数据库架构信息(表名、列名等)
    /// </summary>
    Task<IEnumerable<dynamic>> GetSchemaInfoAsync(BusinessDatabase database, CancellationToken cancellationToken = default);

    /// <summary>
    /// 执行查询并返回动态列表
    /// </summary>
    /// <param name="database">目标数据库配置</param>
    /// <param name="sql">SQL语句</param>
    /// <param name="parameters">参数</param>
    /// <param name="cancellationToken">取消令牌</param>
    /// <returns>动态对象列表(IEnumerable of dynamic)</returns>
    Task<IEnumerable<dynamic>> ExecuteQueryAsync(BusinessDatabase database, string sql, object? parameters = null, CancellationToken cancellationToken = default);
}

//Qjy.AICopilot.Dapper/DapperDatabaseConnector.cs
public class DapperDatabaseConnector(ISqlGuardrail sqlGuardrail, ILogger<DapperDatabaseConnector> logger) : IDatabaseConnector
{
    public IDbConnection GetConnection(BusinessDatabase database)
    {
        var connectionString = database.ConnectionString;

        return database.Provider switch
        {
            DbProviderType.PostgreSql => new NpgsqlConnection(connectionString),
            DbProviderType.SqlServer => new SqlConnection(connectionString),
            DbProviderType.MySql => new MySqlConnection(connectionString),
            _ => throw new NotSupportedException($"不支持的数据库提供程序: {database.Provider}")
        };
    }

    public async Task<IEnumerable<dynamic>> GetSchemaInfoAsync(BusinessDatabase database, CancellationToken cancellationToken = default)
    {
        // 获取所有用户表的元数据SQL
        string sql = database.Provider switch
        {
            DbProviderType.PostgreSql => @"
                SELECT table_name, table_schema 
                FROM information_schema.tables 
                WHERE table_schema = 'public' AND table_type = 'BASE TABLE';",

            _ => throw new NotSupportedException("不支持的数据库类型")
        };

        return await ExecuteQueryAsync(database, sql, cancellationToken: cancellationToken);
    }

    public async Task<IEnumerable<dynamic>> ExecuteQueryAsync(BusinessDatabase database, string sql, object? parameters = null, CancellationToken cancellationToken = default)
    {
        // 1. 安全检查
        var guardResult = sqlGuardrail.Validate(sql);
        if (!guardResult.IsSafe)
        {
            logger.LogWarning("SQL安全拦截: {Reason}. SQL: {Sql}", guardResult.ErrorMessage, sql);
            throw new InvalidOperationException(guardResult.ErrorMessage);
        }

        // 2. 创建连接
        using var connection = GetConnection(database);
        
        try
        {
            // 3. 执行查询
            // 使用 CommandDefinition 支持 CancellationToken
            var command = new CommandDefinition(sql, parameters, cancellationToken: cancellationToken);
            
            // Dapper 的 QueryAsync 返回的是 IEnumerable<dynamic>
            // 这对于无法预知列名的动态查询非常合适
            var result = await connection.QueryAsync(command);
            
            return result;
        }
        catch (Exception ex)
        {
            logger.LogError(ex, "在数据库 {DbName} 上执行 SQL 失败。SQL: {Sql}", database.Name, sql);
            throw; // 抛出异常供上层 Agent 捕获并进行自我修正
        }
    }
}
  • 注入 Dappar
//Qjy.AICopilot.Dapper/DependencyInjection.cs
public static class DependencyInjection
{
    public static void AddDapper(this IHostApplicationBuilder builder)
    {
        // 注册 SQL 安全服务
        builder.Services.AddSingleton<ISqlGuardrail, KeywordSqlGuardrail>();
        
        // 注册 数据库连接器
        builder.Services.AddScoped<IDatabaseConnector, DapperDatabaseConnector>();
    }
}

二、实现数据分析插件

1. 架构设计

我们采用项目中已经提供的插件框架,来实现一个数据分析插件。

  • 认知漏斗模型

    • 先用宽口径扫描(表名+表注释),轻量级的获取表名+注释
    • 然后用窄口径聚焦(3~5候选表的详细结构),重量级的获取表结构(数据库名称、数据表名数组)
    • 最后执行验证
  • Token 阶段策略

    比如一些表存在 nvchar(max) 这样的字段,里面存放的内容可能非常多,如果把需要的数据都加载出来,可能会占用大量的 Token。一方面会挤占上下文空间,另外也会浪费 Token,所以针对这样的内容,读取数据时需要做数据截断。

2. 实现插件

//Qjy.AICopilot.DataAnalysisService/Plugins/DataAnalysisPlugin.cs
// 用于映射元数据查询结果
public record ColumnMetadata
{
    public string ColumnName { get; set; } = "";
    public string DataType { get; set; } = "";
    public bool IsPrimaryKey { get; set; }
    public string? Description { get; set; }
}

/// <summary>
/// 数据分析插件
/// 提供数据库元数据探索和SQL执行能力,是Text-to-SQL的核心组件。
/// </summary>
public class DataAnalysisPlugin(IServiceProvider serviceProvider, IDatabaseConnector dbConnector, ILogger<DataAnalysisPlugin> logger) : AgentPluginBase
{
    public override string Description => "提供数据库结构查询和SQL执行能力,用于回答涉及业务数据的统计分析问题。";

    // 辅助方法:根据名称获取数据库配置
    // 这个方法不暴露给 AI,仅供内部使用
    private async Task<BusinessDatabase> GetDatabaseAsync(string databaseName, CancellationToken ct)
    {
        using var scope = serviceProvider.CreateScope();
        var dataQuery = scope.ServiceProvider.GetRequiredService<IDataQueryService>();
        var queryable = dataQuery.BusinessDatabases.Where(d => d.Name == databaseName);
        var db = await dataQuery.FirstOrDefaultAsync(queryable);

        if (db == null)
        {
            throw new ArgumentException($"未找到名称为 '{databaseName}' 的数据库。请检查名称是否正确。");
        }

        if (!db.IsEnabled)
        {
            throw new InvalidOperationException($"数据库 '{databaseName}' 已被禁用。");
        }

        return db;
    }
    
    [Description("获取指定数据库中所有表的名称和描述。这是探索数据库结构的第一步。")]
    public async Task<string> GetTableNamesAsync([Description("目标数据库的名称")] string databaseName)
    {
        try
        {
            // 获取数据库配置
            var db = await GetDatabaseAsync(databaseName, CancellationToken.None);

            // 根据数据库类型构建查询元数据的 SQL
            var sql = string.Empty;
            switch (db.Provider)
            {
                case DbProviderType.PostgreSql:
                    // PostgreSQL: 从 information_schema 获取表名,关联 pg_description 获取注释
                    sql = @"
                        SELECT 
                            t.table_name AS ""TableName"",
                            obj_description(pgc.oid) AS ""Description""
                        FROM information_schema.tables t
                        INNER JOIN pg_class pgc ON t.table_name = pgc.relname
                        WHERE t.table_schema = 'public' 
                          AND t.table_type = 'BASE TABLE';";
                    break;
                case DbProviderType.SqlServer:
                    // SQL Server
                    break;
                default:
                    return $"错误:不支持的数据库类型 {db.Provider}";
            }

            // 执行查询
            // 这里使用了基础设施层的 ExecuteQueryAsync,它返回 IEnumerable<dynamic>
            var result = await dbConnector.ExecuteQueryAsync(db, sql);

            // 序列化结果
            return result.ToJson();
        }
        catch (Exception ex)
        {
            logger.LogError(ex, "获取表名失败。Database: {DbName}", databaseName);
            return $"获取表名时发生错误: {ex.Message}";
        }
    }
    
    
    // 内部辅助方法:查询单个表的列元数据
    private async Task<List<ColumnMetadata>> GetColumnsAsync(BusinessDatabase db, string tableName)
    {
        var sql = string.Empty;
        switch (db.Provider)
        {
            case DbProviderType.PostgreSql:
                // PostgreSQL 元数据查询
                // 包含列名、类型、是否主键
                // 注意:此处简化了查询,实际生产中可能需要更复杂的关联来获取外键
                sql = @"
            SELECT 
                c.column_name AS ""ColumnName"",
                c.data_type AS ""DataType"",
                CASE WHEN tc.constraint_type = 'PRIMARY KEY' THEN 1 ELSE 0 END AS ""IsPrimaryKey"",
                pg_catalog.col_description(format('%s.%s', c.table_schema, c.table_name)::regclass::oid, c.ordinal_position) AS ""Description""
            FROM information_schema.columns c
            LEFT JOIN information_schema.key_column_usage kcu 
                ON c.table_name = kcu.table_name AND c.column_name = kcu.column_name
            LEFT JOIN information_schema.table_constraints tc 
                ON kcu.constraint_name = tc.constraint_name AND tc.constraint_type = 'PRIMARY KEY'
            WHERE c.table_name = @TableName AND c.table_schema = 'public';";
                break;
            case DbProviderType.SqlServer:
                // SQL Server 元数据查询
                break;
            default:
                return [];
        }

        var result = await dbConnector.ExecuteQueryAsync(db, sql, new { TableName = tableName });

        // Dapper 返回的是 dynamic,需要手动映射到强类型
        var columns = new List<ColumnMetadata>();
        foreach (var row in result)
        {
            var dict = (IDictionary<string, object>)row;
            columns.Add(new ColumnMetadata
            {
                ColumnName = dict["ColumnName"] as string ?? "",
                DataType = dict["DataType"] as string ?? "",
                IsPrimaryKey = Convert.ToInt32(dict["IsPrimaryKey"]) == 1,
                Description = dict["Description"] as string ?? ""
            });
        }

        return columns;
    }
    
    [Description("获取指定表的详细结构定义(DDL),包含列名、数据类型、主键和外键信息。")]
    public async Task<string> GetTableSchemaAsync(
        [Description("目标数据库的名称")] string databaseName,
        [Description("需要查询的表名列表,如 'Orders, Customers'")] string[] tableNames)
    {
        if (tableNames.Length == 0)
        {
            return "错误:请提供至少一个表名。";
        }

        try
        {
            var db = await GetDatabaseAsync(databaseName, CancellationToken.None);
            var ddlBuilder = new StringBuilder();

            foreach (var tableName in tableNames)
            {
                // 1. 查询列信息
                var columns = await GetColumnsAsync(db, tableName);

                if (!columns.Any())
                {
                    ddlBuilder.AppendLine($"-- 警告: 表 '{tableName}' 不存在或没有列。");
                    continue;
                }

                // 2. 构建 CREATE TABLE 语句
                ddlBuilder.AppendLine($"CREATE TABLE {tableName} (");

                var columnDefs = new List<string>();
                foreach (var col in columns)
                {
                    // 格式: ColumnName DataType [PK/FK] [Comment]
                    var colDef = $"  {col.ColumnName} {col.DataType}";

                    if (col.IsPrimaryKey) colDef += " PRIMARY KEY";

                    // 如果有描述,作为注释添加,帮助 AI 理解字段含义
                    if (!string.IsNullOrWhiteSpace(col.Description))
                    {
                        colDef += $" -- {col.Description}";
                    }

                    columnDefs.Add(colDef);
                }

                ddlBuilder.AppendLine(string.Join(",\n", columnDefs));
                ddlBuilder.AppendLine(");");
                ddlBuilder.AppendLine();
            }

            return ddlBuilder.ToString();
        }
        catch (Exception ex)
        {
            logger.LogError(ex, "获取表结构失败。Database: {DbName}", databaseName);
            return $"获取表结构时发生错误: {ex.Message}";
        }
    }
}

internal static class JsonHelper
{
    private static readonly JsonSerializerOptions DefaultOptions = new()
    {
        // 正式环境使用 WriteIndented = false 压缩 JSON,节省 Token
        WriteIndented = true,
        // 不转义中文字符串,避免编码问题
        Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping
    };

    public static string ToJson(this object obj)
    {
        return JsonSerializer.Serialize(obj, DefaultOptions);
    }
}

注入插件

//Qjy.AICopilot.DataAnalysisService/DependencyInjection.cs
public static class DependencyInjection
{
    public static void AddDataAnalysisService(this IHostApplicationBuilder builder)
    {
        // 注册 Dapper 基础服务
        builder.AddDapper();
        // 注册插件加载器
        builder.Services.AddAgentPlugin(registrar =>
        {
            registrar.RegisterPluginFromAssembly(Assembly.GetExecutingAssembly());
        });
    }
}