Spring Boot + GPT:我做了一个能自己写 SQL 的后端系统
随着大语言模型技术的快速发展,AI在软件开发领域的应用越来越广泛。本文将详细介绍如何构建一个基于Spring Boot和GPT的智能后端系统,该系统能够根据自然语言描述自动生成SQL查询,实现智能化的数据访问层。
开始
对于大模型的选择不一定非得选择ChatGPT我们还可以选择国内的大模型
- 通义千问
- DeepSeek
系统架构设计
整体架构
智能SQL生成系统采用分层架构,主要包括:
- 自然语言输入层:处理用户自然语言请求
- AI模型层:GPT模型进行语义理解和SQL生成
- SQL解析层:验证和优化生成的SQL
- 数据访问层:执行SQL并返回结果
- 业务逻辑层:处理业务逻辑和数据转换
核心组件关系
| 组件 | 职责 | 技术栈 |
|---|---|---|
| 自然语言处理器 | 解析用户请求 | Spring Boot |
| SQL生成器 | GPT模型调用 | OpenAI API |
| SQL验证器 | 验证SQL安全性 | 自定义规则引擎 |
| 数据访问器 | 执行SQL查询 | Spring Data JPA |
Spring Boot后端实现
项目依赖配置
Spring Boot项目依赖
<dependencies> <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-web</artifactId> </dependency> <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-data-jpa</artifactId> </dependency> <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-validation</artifactId> </dependency> <dependency> <groupId>com.theokanning.openai-gpt3-java</groupId> <artifactId>service</artifactId> <version>0.12.0</version> </dependency> <dependency> <groupId>mysql</groupId> <artifactId>mysql-connector-java</artifactId> </dependency> </dependencies>核心实体类
// 数据库表实体 @Entity @Table(name = "products") public class Product { @Id @GeneratedValue(strategy = GenerationType.IDENTITY) private Long id; @Column(name = "name", nullable = false) private String name; @Column(name = "price") private BigDecimal price; @Column(name = "category") private String category; @Column(name = "description") private String description; @Column(name = "created_at") private LocalDateTime createdAt; // 构造函数、getter和setter public Product() { } public Product(String name, BigDecimal price, String category, String description) { this.name = name; this.price = price; this.category = category; this.description = description; this.createdAt = LocalDateTime.now(); } } // SQL生成请求DTO public class SqlGenerationRequest { @NotBlank(message = "查询描述不能为空") private String queryDescription; @NotBlank(message = "数据库类型不能为空") private String databaseType; private List<String> tableSchemas; // 构造函数、getter和setter } // SQL生成响应DTO public class SqlGenerationResponse { private String generatedSql; private String explanation; private boolean isValid; private List<String> suggestions; // 构造函数、getter和setter }AI服务层实现
// GPT服务类 @Service public class GptSqlService { private final OpenAIService openAIService; private final SqlValidator sqlValidator; public GptSqlService(@Value("${openai.api.key}") String apiKey) { this.openAIService = new OpenAIService(apiKey); this.sqlValidator = new SqlValidator(); } public SqlGenerationResponse generateSql(SqlGenerationRequest request) { // 构建GPT提示词 String prompt = buildPrompt(request); // 调用GPT生成SQL CompletionRequest completionRequest = CompletionRequest.builder() .model("text-davinci-003") .prompt(prompt) .maxTokens(500) .temperature(0.3) .build(); CompletionResult result = openAIService.createCompletion(completionRequest); String generatedSql = result.getChoices().get(0).getText().trim(); // 验证生成的SQL SqlValidationResult validationResult = sqlValidator.validate(generatedSql); SqlGenerationResponse response = new SqlGenerationResponse(); response.setGeneratedSql(generatedSql); response.setIsValid(validationResult.isValid()); response.setSuggestions(validationResult.getSuggestions()); response.setExplanation("根据您的描述生成的SQL查询"); return response; } private String buildPrompt(SqlGenerationRequest request) { StringBuilder prompt = new StringBuilder(); prompt.append("你是一个专业的SQL专家。请根据以下自然语言描述生成相应的SQL查询语句。\n\n"); prompt.append("数据库类型: ").append(request.getDatabaseType()).append("\n"); prompt.append("表结构信息:\n"); for (String schema : request.getTableSchemas()) { prompt.append(schema).append("\n"); } prompt.append("\n查询需求: ").append(request.getQueryDescription()); prompt.append("\n\n请生成对应的SQL查询语句,只返回SQL代码,不要包含其他解释。"); return prompt.toString(); } }SQL验证器
// SQL安全验证器 @Component public class SqlValidator { private static final Set<String> DANGEROUS_KEYWORDS = Set.of( "DROP", "DELETE", "UPDATE", "INSERT", "CREATE", "ALTER", "TRUNCATE", "GRANT", "REVOKE", "EXEC", "EXECUTE" ); public SqlValidationResult validate(String sql) { SqlValidationResult result = new SqlValidationResult(); if (sql == null || sql.trim().isEmpty()) { result.setValid(false); result.addSuggestion("生成的SQL为空"); return result; } // 检查危险关键字 String upperSql = sql.toUpperCase(); for (String keyword : DANGEROUS_KEYWORDS) { if (upperSql.contains(keyword)) { result.setValid(false); result.addSuggestion("检测到危险SQL关键字: " + keyword); return result; } } // 检查基本语法 if (!isValidSelectStatement(sql)) { result.setValid(false); result.addSuggestion("SQL语句不符合SELECT语法规范"); return result; } result.setValid(true); return result; } private boolean isValidSelectStatement(String sql) { String trimmedSql = sql.trim().toUpperCase(); return trimmedSql.startsWith("SELECT") && (trimmedSql.contains("FROM") || trimmedSql.contains("JOIN")); } } // SQL验证结果类 public class SqlValidationResult { private boolean valid; private List<String> suggestions; public SqlValidationResult() { this.suggestions = new ArrayList<>(); this.valid = true; } public void addSuggestion(String suggestion) { this.suggestions.add(suggestion); } }控制器层
// 主要控制器 @RestController @RequestMapping("/api/sql") @CrossOrigin(origins = "*") public class SqlGenerationController { private final GptSqlService gptSqlService; private final DatabaseService databaseService; public SqlGenerationController(GptSqlService gptSqlService, DatabaseService databaseService) { this.gptSqlService = gptSqlService; this.databaseService = databaseService; } @PostMapping("/generate") public ResponseEntity<SqlGenerationResponse> generateSql(@Valid @RequestBody SqlGenerationRequest request) { try { SqlGenerationResponse response = gptSqlService.generateSql(request); return ResponseEntity.ok(response); } catch (Exception e) { return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR) .body(createErrorResponse("SQL生成失败: " + e.getMessage())); } } @PostMapping("/execute") public ResponseEntity<Object> executeSql(@Valid @RequestBody SqlGenerationRequest request) { try { // 首先生成SQL SqlGenerationResponse sqlResponse = gptSqlService.generateSql(request); if (!sqlResponse.isValid()) { return ResponseEntity.badRequest().body(sqlResponse); } // 执行SQL查询 List<Map<String, Object>> results = databaseService.executeQuery(sqlResponse.getGeneratedSql()); Map<String, Object> response = new HashMap<>(); response.put("sql", sqlResponse.getGeneratedSql()); response.put("results", results); response.put("count", results.size()); return ResponseEntity.ok(response); } catch (Exception e) { return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR) .body(Map.of("error", "SQL执行失败: " + e.getMessage())); } } private SqlGenerationResponse createErrorResponse(String message) { SqlGenerationResponse response = new SqlGenerationResponse(); response.setGeneratedSql(""); response.setExplanation(message); response.setIsValid(false); return response; } }数据库服务层
// 数据库操作服务 @Service public class DatabaseService { @Autowired private JdbcTemplate jdbcTemplate; public List<Map<String, Object>> executeQuery(String sql) { // 验证SQL安全性(再次验证) if (!isSafeQuery(sql)) { throw new IllegalArgumentException("不安全的SQL查询"); } return jdbcTemplate.queryForList(sql); } private boolean isSafeQuery(String sql) { String upperSql = sql.toUpperCase().trim(); // 只允许SELECT查询 if (!upperSql.startsWith("SELECT")) { return false; } // 检查是否包含危险操作 return !upperSql.contains("DROP") && !upperSql.contains("DELETE") && !upperSql.contains("UPDATE") && !upperSql.contains("INSERT"); } // 获取数据库表结构 public List<String> getTableSchemas() { List<String> schemas = new ArrayList<>(); // 获取所有表名 List<String> tableNames = jdbcTemplate.queryForList( "SELECT TABLE_NAME FROM information_schema.tables WHERE table_schema = DATABASE()", String.class ); for (String tableName : tableNames) { String schema = getTableSchema(tableName); schemas.add(schema); } return schemas; } private String getTableSchema(String tableName) { List<Map<String, Object>> columns = jdbcTemplate.queryForList( "SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE FROM information_schema.columns WHERE table_name = ?", tableName ); StringBuilder schema = new StringBuilder(); schema.append("表: ").append(tableName).append("\n"); schema.append("字段:\n"); for (Map<String, Object> column : columns) { schema.append(" ") .append(column.get("COLUMN_NAME")) .append(" (") .append(column.get("DATA_TYPE")) .append(", ") .append(column.get("IS_NULLABLE")) .append(")\n"); } return schema.toString(); } }前端集成示例
简单的前端界面
<!DOCTYPE html> <html> <head> <title>智能SQL生成器</title> <style> body { font-family: Arial, sans-serif; margin: 20px; } .container { max-width: 800px; margin: 0 auto; } textarea { width: 100%; height: 100px; } .result { margin-top: 20px; padding: 10px; background: #f5f5f5; } button { padding: 10px 20px; margin: 5px; } </style> </head> <body> <div class="container"> <h1>智能SQL生成器</h1> <div> <label>查询描述:</label> <textarea id="queryInput" placeholder="请输入查询需求,例如:查询价格大于100的产品"></textarea> </div> <div> <button onclick="generateSql()">生成SQL</button> <button onclick="executeSql()">执行查询</button> </div> <div class="result" id="result"></div> </div> <script> async function generateSql() { const query = document.getElementById('queryInput').value; if (!query) { alert('请输入查询描述'); return; } try { const response = await fetch('/api/sql/generate', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ queryDescription: query, databaseType: 'MySQL', tableSchemas: [] // 可以从后端获取表结构 }) }); const result = await response.json(); document.getElementById('result').innerHTML = `<h3>生成的SQL:</h3><pre>${ result.generatedSql}</pre> <h3>说明:</h3><p>${ result.explanation}</p>`; } catch (error) { console.error('Error:', error); document.getElementById('result').innerHTML = `<p>错误: ${ error.message}</p>`; } } async function executeSql() { const query = document.getElementById('queryInput').value; if (!query) { alert('请输入查询描述'); return; } try { const response = await fetch('/api/sql/execute', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ queryDescription: query, databaseType: 'MySQL', tableSchemas: [] }) }); const result = await response.json(); if (response.ok) { document.getElementById('result').innerHTML = `<h3>执行结果:</h3> <p>查询到 ${ result.count} 条记录</p> <pre>${ JSON.stringify(result.results, null, 2)}</pre>`; } else { document.getElementById('result').innerHTML = `<h3>错误:</h3><p>${ result.error}</p>`; } } catch (error) { console.error('Error:', error); document.getElementById('result').innerHTML = `<p>错误: ${ error.message}</p>`; } } </script> </body> </html>安全性考虑
SQL注入防护
系统采用了多层防护机制:
- 输入验证:验证用户输入的合法性
- SQL验证:检查生成SQL的安全性
- 执行限制:只允许SELECT查询
- 参数化查询:使用参数化查询防止注入
访问控制
// 安全配置 @Configuration @EnableWebSecurity public class SecurityConfig { @Bean public SecurityFilterChain filterChain(HttpSecurity http) throws Exception { http .authorizeHttpRequests(authz -> authz .requestMatchers("/api/sql/**").authenticated() .anyRequest().permitAll() ) .httpBasic(withDefaults()) .csrf(csrf -> csrf.disable()); return http.build(); } }性能优化
缓存机制
// 查询结果缓存 @Service public class CachedSqlService { private final Cache<String, List<Map<String, Object>>> queryCache; public CachedSqlService() { this.queryCache = Caffeine.newBuilder() .maximumSize(1000) .expireAfterWrite(Duration.ofMinutes(10)) .build(); } public List<Map<String, Object>> executeWithCache(String sql) { return queryCache.get(sql, key -> { // 执行实际查询 return executeQuery(sql); }); } }部署和监控
应用配置
application.yml
server: port: 8080 spring: datasource: url: jdbc:mysql://localhost:3306/your_database username: your_username password: your_password driver-class-name: com.mysql.cj.jdbc.Driver jpa: hibernate: ddl-auto: validate show-sql: false properties: hibernate: format_sql: true openai: api: key: your_openai_api_key logging: level: com.yourpackage: DEBUG监控指标
// 自定义监控指标 @Component public class SqlMetricsService { private final MeterRegistry meterRegistry; public SqlMetricsService(MeterRegistry meterRegistry) { this.meterRegistry = meterRegistry; } public void recordSqlGeneration(String queryType, long duration) { Counter.builder("sql.generation.count") .tag("type", queryType) .register(meterRegistry) .increment(); Timer.builder("sql.generation.duration") .register(meterRegistry) .record(duration, TimeUnit.MILLISECONDS); } }总结
通过Spring Boot和GPT的结合,我们成功构建了一个能够自动生成SQL的智能后端系统。该系统不仅提高了开发效率,还为非技术用户提供了友好的数据查询界面。
主要优势:
- 智能化:自然语言转SQL,降低使用门槛
- 安全性:多层防护机制确保数据安全
- 可扩展:模块化设计便于功能扩展
- 易用性:简洁的API接口便于集成
未来发展方向:
- 支持更多数据库类型
- 增强SQL优化功能
- 集成机器学习进行查询优化
- 提供可视化查询构建器
这个系统展示了AI技术在软件开发中的巨大潜力,为构建更智能、更高效的应用系统提供了新的思路。