From 6ca6516b12ba08f52a049ea54dbb0bd58f180c23 Mon Sep 17 00:00:00 2001 From: wanxiaoli Date: Mon, 16 Mar 2026 11:00:27 +0800 Subject: [PATCH] =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E8=AE=AD=E7=BB=83=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../business/css/config/MybatisConfig.java | 39 +-- .../css/controller/ModelTrainController.java | 1 + .../css/controller/SimController.java | 15 +- .../business/css/domain/ModelTrainTask.java | 6 +- .../yfd/business/css/model/InferRequest.java | 4 + .../css/service/DeviceInferService.java | 2 +- .../business/css/service/SimInferService.java | 25 +- .../service/impl/ModelTrainServiceImpl.java | 224 ++++++++++++++++-- 8 files changed, 258 insertions(+), 58 deletions(-) diff --git a/business-css/src/main/java/com/yfd/business/css/config/MybatisConfig.java b/business-css/src/main/java/com/yfd/business/css/config/MybatisConfig.java index 89f05e6..d1e2e96 100644 --- a/business-css/src/main/java/com/yfd/business/css/config/MybatisConfig.java +++ b/business-css/src/main/java/com/yfd/business/css/config/MybatisConfig.java @@ -1,25 +1,32 @@ package com.yfd.business.css.config; +import com.baomidou.mybatisplus.annotation.DbType; import com.baomidou.mybatisplus.extension.spring.MybatisSqlSessionFactoryBean; import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor; +import com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor; import org.apache.ibatis.session.SqlSessionFactory; import org.apache.ibatis.plugin.Interceptor; import org.mybatis.spring.SqlSessionTemplate; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.io.support.PathMatchingResourcePatternResolver; -import jakarta.annotation.Resource; import javax.sql.DataSource; @Configuration public class MybatisConfig { - @Resource - private MybatisPlusInterceptor mybatisPlusInterceptor; + @Bean + @ConditionalOnMissingBean(MybatisPlusInterceptor.class) + public MybatisPlusInterceptor mybatisPlusInterceptor() { + MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor(); + interceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.MYSQL)); + return interceptor; + } @Bean - public SqlSessionFactory sqlSessionFactory(DataSource dataSource) throws Exception { + public SqlSessionFactory sqlSessionFactory(DataSource dataSource, MybatisPlusInterceptor mybatisPlusInterceptor) throws Exception { MybatisSqlSessionFactoryBean factoryBean = new MybatisSqlSessionFactoryBean(); factoryBean.setDataSource(dataSource); factoryBean.setMapperLocations(new PathMatchingResourcePatternResolver() @@ -34,27 +41,3 @@ public class MybatisConfig { } } -// @Configuration -// public class MybatisConfig { - -// @Bean -// public MybatisPlusInterceptor mybatisPlusInterceptor() { -// return new MybatisPlusInterceptor(); -// } - -// @Bean -// public SqlSessionFactory sqlSessionFactory(DataSource dataSource) throws Exception { -// MybatisSqlSessionFactoryBean factoryBean = new MybatisSqlSessionFactoryBean(); -// factoryBean.setDataSource(dataSource); -// factoryBean.setMapperLocations( -// new PathMatchingResourcePatternResolver() -// .getResources("classpath*:/mapper/**/*.xml")); -// return factoryBean.getObject(); -// } - -// @Bean -// public SqlSessionTemplate sqlSessionTemplate(SqlSessionFactory sqlSessionFactory) { -// return new SqlSessionTemplate(sqlSessionFactory); -// } -// } - diff --git a/business-css/src/main/java/com/yfd/business/css/controller/ModelTrainController.java b/business-css/src/main/java/com/yfd/business/css/controller/ModelTrainController.java index 6f7ecb8..dc4e24e 100644 --- a/business-css/src/main/java/com/yfd/business/css/controller/ModelTrainController.java +++ b/business-css/src/main/java/com/yfd/business/css/controller/ModelTrainController.java @@ -53,6 +53,7 @@ public class ModelTrainController { } String taskId = modelTrainService.submitTask(task); + System.out.println("提交任务成功,任务ID: " + taskId); return ResponseResult.successData(taskId); } catch (JsonProcessingException e) { return ResponseResult.error("参数解析失败: " + e.getMessage()); diff --git a/business-css/src/main/java/com/yfd/business/css/controller/SimController.java b/business-css/src/main/java/com/yfd/business/css/controller/SimController.java index 02dd625..8d96ee2 100644 --- a/business-css/src/main/java/com/yfd/business/css/controller/SimController.java +++ b/business-css/src/main/java/com/yfd/business/css/controller/SimController.java @@ -1,15 +1,18 @@ package com.yfd.business.css.controller; import com.yfd.business.css.build.SimBuilder; +import com.yfd.business.css.domain.Scenario; import com.yfd.business.css.facade.SimDataFacade; import com.yfd.business.css.model.*; import com.yfd.business.css.service.MaterialService; +import com.yfd.business.css.service.ScenarioService; import com.yfd.business.css.service.SimService; import com.yfd.business.css.service.SimInferService; import com.yfd.platform.config.ResponseResult; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.*; +import java.time.LocalDateTime; import java.util.List; import java.util.Map; @@ -26,6 +29,7 @@ public class SimController { @Autowired private SimService simService; @Autowired private MaterialService materialService; @Autowired private SimInferService simInferService; + @Autowired private ScenarioService scenarioService; /** * 执行仿真计算 @@ -39,6 +43,13 @@ public class SimController { String scenarioId = (String) req.get("scenarioId"); int steps = req.containsKey("steps") ? (int) req.get("steps") : 10; + // 0. Update Status: 更新情景状态为进行中 + Scenario startScenario = new Scenario(); + startScenario.setScenarioId(scenarioId); + startScenario.setStatus("1"); // 1: 进行中 + startScenario.setUpdatedAt(LocalDateTime.now()); + scenarioService.updateById(startScenario); + // 1. Load Data: 获取项目、设备和事件数据 SimDataPackage data = simDataFacade.loadSimulationData(projectId, scenarioId); @@ -50,10 +61,10 @@ public class SimController { // 3. Run Engine: 执行核心仿真计算,返回上下文结果 SimContext ctx = simService.runSimulation(units, events, nodes, steps); - // 4. Async Infer: 异步执行推理并保存结果 + // 4. Async Infer: 异步执行推理并保存结果,完成后更新情景状态 simInferService.asyncInferAndSave(projectId, scenarioId, ctx, units); - // 5. Convert Result: 将仿真结果转换为前端友好的格式,包含静态属性补全和元数据 + // 6. Convert Result: 将仿真结果转换为前端友好的格式,包含静态属性补全和元数据 Map resultData = SimResultConverter.toFrames(ctx, units, projectId, scenarioId); return ResponseResult.successData(resultData); diff --git a/business-css/src/main/java/com/yfd/business/css/domain/ModelTrainTask.java b/business-css/src/main/java/com/yfd/business/css/domain/ModelTrainTask.java index f138e84..6bf72e1 100644 --- a/business-css/src/main/java/com/yfd/business/css/domain/ModelTrainTask.java +++ b/business-css/src/main/java/com/yfd/business/css/domain/ModelTrainTask.java @@ -31,7 +31,7 @@ public class ModelTrainTask implements Serializable { private String trainParams; // JSON String @TableField("status") - private String status; // PENDING, TRAINING, SUCCESS, FAILED + private String status; // Pending, Training, Success, Failed @TableField(value = "metrics") private String metrics; // JSON String @@ -48,9 +48,9 @@ public class ModelTrainTask implements Serializable { @TableField("error_log") private String errorLog; - @TableField(value = "created_at", fill = FieldFill.INSERT) + @TableField(value = "created_at") private LocalDateTime createdAt; - @TableField(value = "updated_at", fill = FieldFill.INSERT_UPDATE) + @TableField(value = "updated_at") private LocalDateTime updatedAt; } diff --git a/business-css/src/main/java/com/yfd/business/css/model/InferRequest.java b/business-css/src/main/java/com/yfd/business/css/model/InferRequest.java index 08cc4ba..a8d2e52 100644 --- a/business-css/src/main/java/com/yfd/business/css/model/InferRequest.java +++ b/business-css/src/main/java/com/yfd/business/css/model/InferRequest.java @@ -1,11 +1,15 @@ package com.yfd.business.css.model; +import com.fasterxml.jackson.annotation.JsonProperty; import java.util.List; import java.util.Map; public class InferRequest { private String modelDir; + + @JsonProperty("device_type") private String deviceType; + private List> batch; private Map features; private Map meta; diff --git a/business-css/src/main/java/com/yfd/business/css/service/DeviceInferService.java b/business-css/src/main/java/com/yfd/business/css/service/DeviceInferService.java index 846f60a..a24413e 100644 --- a/business-css/src/main/java/com/yfd/business/css/service/DeviceInferService.java +++ b/business-css/src/main/java/com/yfd/business/css/service/DeviceInferService.java @@ -196,7 +196,7 @@ public class DeviceInferService { // 合并的Python推理调用方法 public InferResponse infer(InferRequest request) { - String url = pythonInferUrl + "/v1/infer" ; + String url = pythonInferUrl + "/v1/infer"; RestTemplate restTemplate = new RestTemplate(); HttpHeaders headers = new HttpHeaders(); diff --git a/business-css/src/main/java/com/yfd/business/css/service/SimInferService.java b/business-css/src/main/java/com/yfd/business/css/service/SimInferService.java index d31c426..11bfb18 100644 --- a/business-css/src/main/java/com/yfd/business/css/service/SimInferService.java +++ b/business-css/src/main/java/com/yfd/business/css/service/SimInferService.java @@ -1,10 +1,12 @@ package com.yfd.business.css.service; +import com.yfd.business.css.domain.Scenario; import com.yfd.business.css.model.*; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Service; +import java.time.LocalDateTime; import java.util.*; import java.util.stream.Collectors; @@ -17,6 +19,9 @@ public class SimInferService { @Autowired private DeviceInferService deviceInferService; + + @Autowired + private ScenarioService scenarioService; @Async public void asyncInferAndSave(String projectId, String scenarioId, SimContext context, List units) { @@ -57,7 +62,6 @@ public class SimInferService { if (deviceType == null || deviceType.isEmpty()) continue; // 确保静态属性也包含在内(如果 SimContext 中没有,从 SimUnit 补充) - // SimService 应该已经初始化了静态属性到 Context,但为了保险起见: for (Map.Entry staticProp : unit.staticProperties().entrySet()) { properties.putIfAbsent(staticProp.getKey(), staticProp.getValue()); } @@ -75,16 +79,35 @@ public class SimInferService { if (groupedDevices.isEmpty()) { System.out.println("No device data found for inference."); + // 即使没有数据,也应该更新状态为完成,或者视为正常结束 + updateScenarioStatus(scenarioId, "2"); return; } // 3. 调用 DeviceInferService 进行推理和入库 // 复用现有的 processDeviceInference 方法,它处理了按类型分组的数据 deviceInferService.processDeviceInference(projectId, scenarioId, groupedDevices); + + // 4. 更新状态为已完成 + updateScenarioStatus(scenarioId, "2"); } catch (Exception e) { System.err.println("Async inference failed: " + e.getMessage()); e.printStackTrace(); + // 5. 更新状态为失败 (假设 3 代表失败) + updateScenarioStatus(scenarioId, "3"); + } + } + + private void updateScenarioStatus(String scenarioId, String status) { + try { + Scenario scenario = new Scenario(); + scenario.setScenarioId(scenarioId); + scenario.setStatus(status); + scenario.setUpdatedAt(LocalDateTime.now()); + scenarioService.updateById(scenario); + } catch (Exception e) { + System.err.println("Failed to update scenario status to " + status + ": " + e.getMessage()); } } } diff --git a/business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java b/business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java index d93de48..3b9602e 100644 --- a/business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java +++ b/business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java @@ -25,17 +25,25 @@ import org.springframework.web.multipart.MultipartFile; import java.io.File; import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.nio.file.StandardCopyOption; import java.time.LocalDateTime; import java.time.format.DateTimeFormatter; import java.util.HashMap; import java.util.Map; import java.util.UUID; +import java.util.regex.Pattern; @Service public class ModelTrainServiceImpl extends ServiceImpl implements ModelTrainService { @Value("${file-space.upload-path:./data/uploads/}") private String uploadPath; + + @Value("${file-space.model-path:E:/python_coding/keffCenter/models/}") + private String modelPath; @Value("${python.api.url:http://localhost:8000}") private String pythonApiUrl; @@ -48,6 +56,8 @@ public class ModelTrainServiceImpl extends ServiceImpl Map) if (task.getTrainParams() != null && !task.getTrainParams().isBlank()) { @@ -147,15 +159,39 @@ public class ModelTrainServiceImpl extends ServiceImpl response = restTemplate.postForEntity(url, entity, Map.class); if (response.getStatusCode().is2xxSuccessful()) { + Map body = response.getBody(); + if (body != null) { + try { + System.out.println("训练服务响应: " + objectMapper.writeValueAsString(body)); + } catch (JsonProcessingException ignored) { + System.out.println("训练服务响应: " + body); + } + Object codeObj = body.get("code"); + Integer code = null; + if (codeObj instanceof Number) { + code = ((Number) codeObj).intValue(); + } else if (codeObj instanceof String) { + try { + code = Integer.parseInt((String) codeObj); + } catch (Exception ignored) { + } + } + if (code != null && code != 0) { + task.setStatus("Failed"); + task.setErrorLog("训练服务返回失败: " + body); + this.updateById(task); + return; + } + } System.out.println("训练任务提交成功: " + task.getTaskId()); } else { - task.setStatus("FAILED"); + task.setStatus("Failed"); task.setErrorLog("提交训练任务失败: " + response.getStatusCode()); this.updateById(task); } } catch (Exception e) { - task.setStatus("FAILED"); + task.setStatus("Failed"); task.setErrorLog("调用 Python 服务异常: " + e.getMessage()); this.updateById(task); e.printStackTrace(); @@ -169,36 +205,79 @@ public class ModelTrainServiceImpl extends ServiceImpl response = restTemplate.getForEntity(url, Map.class); if (response.getStatusCode().is2xxSuccessful() && response.getBody() != null) { Map body = response.getBody(); - String status = (String) body.get("status"); // TRAINING, SUCCESS, FAILED + Object codeObj = body.get("code"); + Integer code = null; + if (codeObj instanceof Number) { + code = ((Number) codeObj).intValue(); + } else if (codeObj instanceof String) { + try { + code = Integer.parseInt((String) codeObj); + } catch (Exception ignored) { + } + } + if (code != null && code != 0) { + task.setStatus("Failed"); + task.setErrorLog(String.valueOf(body.get("msg"))); + this.updateById(task); + return task; + } + + Object dataObj = body.get("data"); + if (!(dataObj instanceof Map)) { + return task; + } + Map data = (Map) dataObj; + String status = (String) data.get("status"); if (status != null) { + // 转换状态为首字母大写 + if (status.equalsIgnoreCase("Success")) { + status = "Success"; + } else if (status.equalsIgnoreCase("Failed")) { + status = "Failed"; + } else if (status.equalsIgnoreCase("Training")) { + status = "Training"; + } else if (status.equalsIgnoreCase("Pending")) { + status = "Pending"; + } + task.setStatus(status); - if ("SUCCESS".equals(status)) { - task.setModelOutputPath((String) body.get("model_path")); + if ("Success".equals(status)) { + String modelPathRaw = firstNonBlank( + (String) data.get("model_relative_path"), + (String) data.get("model_path_rel_project"), + (String) data.get("model_path") + ); + task.setModelOutputPath(normalizeModelPath(modelPathRaw)); // Map -> JSON String - if (body.get("metrics") != null) { - task.setMetrics(objectMapper.writeValueAsString(body.get("metrics"))); + if (data.get("metrics") != null) { + task.setMetrics(objectMapper.writeValueAsString(data.get("metrics"))); } - if (body.get("feature_map") != null) { - task.setFeatureMapSnapshot(objectMapper.writeValueAsString(body.get("feature_map"))); + if (data.get("feature_map") != null) { + task.setFeatureMapSnapshot(objectMapper.writeValueAsString(data.get("feature_map"))); } - task.setMetricsImagePath((String) body.get("metrics_image")); - } else if ("FAILED".equals(status)) { - task.setErrorLog((String) body.get("error")); - } else if ("TRAINING".equals(status)) { - if (body.get("metrics") != null) { - task.setMetrics(objectMapper.writeValueAsString(body.get("metrics"))); + String metricsPathRaw = firstNonBlank( + (String) data.get("metrics_image_relative_path"), + (String) data.get("metrics_image_rel_project"), + (String) data.get("metrics_image") + ); + task.setMetricsImagePath(normalizeModelPath(metricsPathRaw)); + } else if ("Failed".equals(status)) { + task.setErrorLog((String) data.get("error")); + } else if ("Training".equals(status)) { + if (data.get("metrics") != null) { + task.setMetrics(objectMapper.writeValueAsString(data.get("metrics"))); } } @@ -221,9 +300,13 @@ public class ModelTrainServiceImpl extends ServiceImpl() @@ -233,14 +316,37 @@ public class ModelTrainServiceImpl extends ServiceImpl 0) { throw new BizException("版本号已存在"); } + + String algorithmType = task.getAlgorithmType(); + String deviceType = task.getDeviceType(); + if (algorithmType == null || algorithmType.isBlank() || deviceType == null || deviceType.isBlank()) { + throw new BizException("算法类型或设备类型不能为空"); + } + + Path root = Paths.get(modelPath).toAbsolutePath().normalize(); + Path versionDir = root.resolve(Paths.get(algorithmType, deviceType, versionTag)).normalize(); + if (!versionDir.startsWith(root)) { + throw new BizException("发布目录非法"); + } + try { + Files.createDirectories(versionDir); + } catch (IOException e) { + throw new BizException("创建发布目录失败: " + e.getMessage()); + } + + String publishedModelRelPath = copyToVersionDir(root, versionDir, task.getModelOutputPath()); + String publishedMetricsRelPath = null; + if (task.getMetricsImagePath() != null && !task.getMetricsImagePath().isBlank()) { + publishedMetricsRelPath = copyToVersionDir(root, versionDir, task.getMetricsImagePath()); + } // 创建正式模型记录 AlgorithmModel model = new AlgorithmModel(); model.setAlgorithmType(task.getAlgorithmType()); model.setDeviceType(task.getDeviceType()); model.setVersionTag(versionTag); - model.setModelPath(task.getModelOutputPath()); - model.setMetricsImagePath(task.getMetricsImagePath()); + model.setModelPath(publishedModelRelPath); + model.setMetricsImagePath(publishedMetricsRelPath); model.setTrainedAt(LocalDateTime.now()); model.setIsCurrent(0); // 默认不激活 @@ -250,4 +356,76 @@ public class ModelTrainServiceImpl extends ServiceImpl