From aff2d7feeafc54d55b439f076b47fa440d649a0a Mon Sep 17 00:00:00 2001 From: wanxiaoli Date: Wed, 11 Mar 2026 13:55:32 +0800 Subject: [PATCH] =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E8=AE=AD=E7=BB=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../css/CriticalScenarioApplication.java | 4 + .../yfd/business/css/build/SimBuilder.java | 390 ++++++++++++++++-- .../css/config/RestTemplateConfig.java | 14 + .../css/controller/EventController.java | 2 +- .../css/controller/ModelTrainController.java | 76 ++++ .../css/controller/SimController.java | 78 ++-- .../business/css/domain/ModelTrainTask.java | 58 +++ .../com/yfd/business/css/domain/Scenario.java | 13 + .../business/css/facade/SimDataFacade.java | 39 ++ .../css/mapper/ModelTrainTaskMapper.java | 9 + .../yfd/business/css/model/SimContext.java | 2 +- .../business/css/model/SimDataPackage.java | 40 ++ .../com/yfd/business/css/model/SimEvent.java | 9 +- .../css/model/SimResultConverter.java | 72 ++++ .../com/yfd/business/css/model/SimUnit.java | 8 +- .../css/service/DeviceInferService.java | 140 ++++--- .../css/service/ModelTrainService.java | 36 ++ .../business/css/service/SimInferService.java | 90 ++++ .../yfd/business/css/service/SimService.java | 122 ++++-- .../impl/AlgorithmModelServiceImpl.java | 9 +- .../service/impl/ModelTrainServiceImpl.java | 228 ++++++++++ 21 files changed, 1292 insertions(+), 147 deletions(-) create mode 100644 business-css/src/main/java/com/yfd/business/css/config/RestTemplateConfig.java create mode 100644 business-css/src/main/java/com/yfd/business/css/controller/ModelTrainController.java create mode 100644 business-css/src/main/java/com/yfd/business/css/domain/ModelTrainTask.java create mode 100644 business-css/src/main/java/com/yfd/business/css/facade/SimDataFacade.java create mode 100644 business-css/src/main/java/com/yfd/business/css/mapper/ModelTrainTaskMapper.java create mode 100644 business-css/src/main/java/com/yfd/business/css/model/SimDataPackage.java create mode 100644 business-css/src/main/java/com/yfd/business/css/model/SimResultConverter.java create mode 100644 business-css/src/main/java/com/yfd/business/css/service/ModelTrainService.java create mode 100644 business-css/src/main/java/com/yfd/business/css/service/SimInferService.java create mode 100644 business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java diff --git a/business-css/src/main/java/com/yfd/business/css/CriticalScenarioApplication.java b/business-css/src/main/java/com/yfd/business/css/CriticalScenarioApplication.java index b557784..b536a10 100644 --- a/business-css/src/main/java/com/yfd/business/css/CriticalScenarioApplication.java +++ b/business-css/src/main/java/com/yfd/business/css/CriticalScenarioApplication.java @@ -5,6 +5,7 @@ import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.autoconfigure.data.redis.RedisAutoConfiguration; import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.web.servlet.ServletComponentScan; @SpringBootApplication( scanBasePackages = { @@ -18,6 +19,9 @@ import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; }, exclude = {DataSourceAutoConfiguration.class, RedisAutoConfiguration.class} ) +@ServletComponentScan(basePackages = { + "com.yfd.platform.config" +}) @MapperScan(basePackages = { "com.yfd.platform.**.mapper", "com.yfd.business.css.**.mapper" diff --git a/business-css/src/main/java/com/yfd/business/css/build/SimBuilder.java b/business-css/src/main/java/com/yfd/business/css/build/SimBuilder.java index cb697e3..689b94d 100644 --- a/business-css/src/main/java/com/yfd/business/css/build/SimBuilder.java +++ b/business-css/src/main/java/com/yfd/business/css/build/SimBuilder.java @@ -1,44 +1,368 @@ package com.yfd.business.css.build; +import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.yfd.business.css.domain.Device; +import com.yfd.business.css.domain.Event; +import com.yfd.business.css.domain.Material; +import com.yfd.business.css.domain.Project; import com.yfd.business.css.model.*; +import com.yfd.business.css.service.MaterialService; +import org.springframework.stereotype.Component; -import java.util.List; -import java.util.Map; +import java.util.*; -/* -*将数据库原始表 Map(拓扑、事件、影响关系) -*转成 SimUnit / SimEvent / SimInfluenceNode -*/ +/** + * 仿真模型构建器 + * 负责将原始的 Project (Topology) 和 Event 数据解析转换为 SimUnit, SimEvent, SimInfluenceNode + */ +@Component public class SimBuilder { - // public static List buildUnitsFromTopo(List topoRows) { - // return topoRows.stream() - // .map(t -> new SimUnit(t.getDeviceId(), t.getDeviceId(), t.getMaterialId(), t.getDeviceType())) - // .toList(); - // } + private final ObjectMapper objectMapper = new ObjectMapper(); - // public static List buildEventsFromAttrChanges(List attrChanges) { - // return attrChanges.stream() - // .map(e -> new SimEvent(e.getStep(), - // SimPropertyKey.of(e.getDeviceId(), e.getAttr()), - // e.getValue(), - // SimEvent.EventType.SET)) - // .toList(); - // } + // 1. 构建单元 (SimUnit) - 包含静态属性 + public List buildUnits(SimDataPackage data, MaterialService materialService) { + List units = new ArrayList<>(); + Project project = data.getProject(); + if (project == null || project.getTopology() == null) return units; - // public static List buildInfluenceNodesFromInfluences(List influenceRows) { - // return influenceRows.stream() - // .map(i -> { - // List sources = i.getSources().stream() - // .map(s -> new SimInfluenceSource( - // SimPropertyKey.of(s.getUnitId(), s.getProp()), - // s.getCoeff(), - // s.getDelay() - // )).toList(); - // return new SimInfluenceNode(SimPropertyKey.of(i.getTargetUnit(), i.getTargetProp()), - // sources, - // i.getBias()); - // }).toList(); - // } + // 建立 Device ID -> Device 实体的映射,方便查找 + Map deviceMap = new HashMap<>(); + if (data.getDevices() != null) { + for (Device d : data.getDevices()) { + deviceMap.put(d.getDeviceId(), d); + } + } + try { + JsonNode root = objectMapper.readTree(project.getTopology()); + Map devToMat = buildDeviceMaterialMap(root); + Map> matStaticDb = buildMaterialStaticFromDb(devToMat, materialService); + + JsonNode devicesNode = root.path("devices"); + if (devicesNode.isArray()) { + for (JsonNode deviceNode : devicesNode) { + String deviceId = deviceNode.path("deviceId").asText(); + if (deviceId == null || deviceId.isEmpty()) continue; + + String type = deviceNode.path("type").asText(); + String materialId = devToMat.get(deviceId); + + Map staticProps = new HashMap<>(); + + // 1.1 解析 topology 中的 static 节点 + JsonNode st = deviceNode.path("static"); + if (st.isObject()) { + st.fieldNames().forEachRemaining(k -> { + if (!"unit".equals(k)) { + JsonNode v = st.path(k); + if (v.isNumber()) staticProps.put(k, v.asDouble()); + else if (v.isTextual()) staticProps.put(k, parseDouble(v.asText())); + } + }); + } + + // 1.2 解析 Device.size 并注入 + Device device = deviceMap.get(deviceId); + if (device != null) { + injectDeviceSize(device, staticProps); + } + + // 1.3 注入物料静态属性 + if (materialId != null) { + Map mProps = matStaticDb.get(materialId); + if (mProps != null) staticProps.putAll(mProps); + + // 解析 topology 中 material 的 static + JsonNode mats = deviceNode.path("materials"); + if (mats.isMissingNode() || mats.isNull()) mats = deviceNode.path("material"); + if (mats.isObject()) { + JsonNode mst = mats.path("static"); + if (mst.isObject()) { + mst.fieldNames().forEachRemaining(k -> { + if (!"unit".equals(k)) { + JsonNode v = mst.path(k); + if (v.isNumber()) staticProps.put(k, v.asDouble()); + else if (v.isTextual()) staticProps.put(k, parseDouble(v.asText())); + } + }); + } + } + } + + units.add(new SimUnit(deviceId, deviceId, materialId, type, staticProps)); + } + } + } catch (Exception e) { + throw new RuntimeException("Build units failed", e); + } + return units; + } + + private void injectDeviceSize(Device device, Map staticProps) { + try { + String sizeJson = device.getSize(); + if (sizeJson == null || sizeJson.isBlank()) return; + + JsonNode sizeNode = objectMapper.readTree(sizeJson); + String type = device.getType(); // 假设 Device 有 getType() 方法,或从外部传入 + + if (type == null) { + // Fallback: 尝试通用解析 (现有逻辑) + parseCommonSize(sizeNode, staticProps); + return; + } + + switch (type) { + case "FlatTank": + if (sizeNode.has("width")) staticProps.put("width", sizeNode.get("width").asDouble()); + if (sizeNode.has("length")) staticProps.put("length", sizeNode.get("length").asDouble()); + if (sizeNode.has("height")) staticProps.put("height", sizeNode.get("height").asDouble()); + break; + + case "CylindricalTank": + case "AnnularTank": + case "TubeBundleTank": + parseCommonSize(sizeNode, staticProps); + break; + + case "ExtractionColumn": + // 优先提取 tray_section (塔身) + if (sizeNode.has("tray_section")) { + parseCommonSize(sizeNode.get("tray_section"), staticProps); + } + // 可选:将其他段的尺寸作为特殊属性注入,例如 lower_expanded_height + break; + + case "FluidizedBed": + // 优先提取 reaction_section (反应段) + if (sizeNode.has("reaction_section")) { + parseCommonSize(sizeNode.get("reaction_section"), staticProps); + } + break; + + case "ACFTank": + // 优先提取 annular_cylinder (圆柱段) + if (sizeNode.has("annular_cylinder")) { + parseCommonSize(sizeNode.get("annular_cylinder"), staticProps); + } + break; + + default: + parseCommonSize(sizeNode, staticProps); + } + } catch (Exception e) { + System.err.println("解析Device.size失败:" + e.getMessage()); + } + } + + private void parseCommonSize(JsonNode node, Map staticProps) { + if (node.has("outer_diameter")) { + staticProps.put("diameter", node.get("outer_diameter").asDouble()); + } else if (node.has("diameter")) { + staticProps.put("diameter", node.get("diameter").asDouble()); + } + + if (node.has("height")) { + staticProps.put("height", node.get("height").asDouble()); + } + } + + // 2. 构建事件 (SimEvent) + public List buildEvents(List events) { + List simEvents = new ArrayList<>(); + if (events == null) return simEvents; + + for (Event ev : events) { + String json = ev.getAttrChanges(); + if (json == null || json.isBlank()) continue; + try { + JsonNode root = objectMapper.readTree(json); + JsonNode target = root.path("target"); + String entityType = optText(target, "entityType"); + String entityId = optText(target, "entityId"); + String property = optText(target, "property"); + if (entityType == null || entityId == null || property == null) continue; + + // 区分 entityType? SimUnit id 是 deviceId,如果是 material,这里需要注意 SimUnit 的定义 + // ProjectServiceImpl 中是分开处理 device 和 material 的 state + // 这里假设 SimUnit 以 deviceId 为主键,material 属性也挂在 deviceId 下 (SimContext key: deviceId + property) + // 或者 SimContext key 包含 entityType? SimPropertyKey(unitId, property) + // ProjectServiceImpl: entityType + ":" + entityId + ":" + property + // SimPropertyKey: unitId, property. unitId 通常是 deviceId. + // 如果是 material 属性,ProjectServiceImpl 中是把 material 属性 copy 到了 device state 中。 + // 所以这里 key 的 unitId 应该是 deviceId (如果 target 是 material,需找到对应的 deviceId) + // 但 Event 中只记录了 materialId,没记录 deviceId (虽然 Event 表有 device_id 字段) + + String targetUnitId = entityId; + // 注意:如果 event 针对 material,且 material 被多个 device 共享,这里逻辑需要确认 + // ProjectServiceImpl 是遍历所有 device,如果 device 关联了该 material,就应用覆盖。 + // SimService 中 SimEvent 直接作用于 SimContext。 + // 如果 SimUnit 对应 Device,那么 Material 的事件需要转换为针对所有引用该 Material 的 Device 的事件。 + // 这需要在 buildEvents 时知道 Device-Material 关系,或者 SimContext 支持 Material 维度的存储。 + // 简化起见,假设 Event 表中 device_id 有值,或者我们在 buildEvents 时传入 devToMat 映射关系? + // Event 表有 device_id 字段,可以直接用。 + if ("material".equals(entityType)) { + if (ev.getDeviceId() != null) targetUnitId = ev.getDeviceId(); + } + + SimPropertyKey key = SimPropertyKey.of(targetUnitId, property); + + JsonNode segments = root.path("segments"); + if (segments.isArray()) { + for (JsonNode seg : segments) { + String interp = optText(seg, "interp"); + JsonNode timeline = seg.path("timeline"); + if (timeline.isArray()) { + // 简化处理:全部转为 step-set,ramp 也离散化? + // ProjectServiceImpl 是保留 ramp 结构,readValue 时计算。 + // SimService 是 step 推进。 + // 这里我们将 timeline 展开为每个 step 的 SimEvent + + // 线性插值处理比较复杂,这里先处理离散点 + for (JsonNode p : timeline) { + double t = p.path("t").asDouble(); + double val = p.path("value").asDouble(); + // Step ? Time to Step 转换 + // 假设 1s = 1 step + int step = (int) t; + // isOverride=true 在所有计算完成后,强制修改属性值。有 1 帧延迟 。 + // Input 事件 ( isOverride=false ),在 applyInfluences (影响计算) 之前 执行。下游设备读取的是 修改后 的值(当前帧的值)。 + simEvents.add(new SimEvent(step, key, val, SimEvent.EventType.SET, false)); // 默认为强制覆盖? + } + } + } + } + } catch (Exception e) { + // log error + } + } + return simEvents; + } + + // 3. 构建影响关系 (SimInfluenceNode) + public List buildInfluenceNodes(Project project) { + List nodes = new ArrayList<>(); + if (project == null || project.getTopology() == null) return nodes; + + try { + JsonNode root = objectMapper.readTree(project.getTopology()); + JsonNode devicesNode = root.path("devices"); + if (devicesNode.isArray()) { + for (JsonNode dn : devicesNode) { + String deviceId = optText(dn, "deviceId"); + if (deviceId == null) continue; + + // Device Properties Influence + parseInfluence(dn.path("properties"), deviceId, nodes); + + // Material Properties Influence (mapped to device) + JsonNode mats = dn.path("materials"); + if (mats.isMissingNode() || mats.isNull()) mats = dn.path("material"); + if (mats.isObject()) { + parseInfluence(mats.path("properties"), deviceId, nodes); + } else if (mats.isArray()) { + for (JsonNode mn : mats) { + parseInfluence(mn.path("properties"), deviceId, nodes); + } + } + } + } + } catch (Exception e) { + throw new RuntimeException("Build influence nodes failed", e); + } + return nodes; + } + + private void parseInfluence(JsonNode props, String targetUnitId, List nodes) { + if (!props.isObject()) return; + props.fieldNames().forEachRemaining(propName -> { + JsonNode prop = props.path(propName); + if ("influence".equalsIgnoreCase(optText(prop, "type"))) { + double bias = prop.path("bias").asDouble(0.0); + List sources = new ArrayList<>(); + JsonNode srcs = prop.path("sources"); + if (srcs.isArray()) { + for (JsonNode sNode : srcs) { + String seId = optText(sNode, "entityId"); // source unit id + String seProp = optText(sNode, "property"); + double coef = sNode.path("coefficient").asDouble(1.0); + + long delayMs = 0L; + JsonNode delay = sNode.path("delay"); + if (delay.path("enabled").asBoolean(false)) { + long t = delay.path("time").asLong(0L); + String u = optText(delay, "unit"); + delayMs = toMillis(t, u); + } + int delayStep = (int) (delayMs / 1000); // 假设 1s = 1 step + + sources.add(new SimInfluenceSource(SimPropertyKey.of(seId, seProp), coef, delayStep)); + } + } + nodes.add(new SimInfluenceNode(SimPropertyKey.of(targetUnitId, propName), sources, bias)); + } + }); + } + + // Helpers + private Map buildDeviceMaterialMap(JsonNode root) { + Map map = new HashMap<>(); + JsonNode devicesNode = root.path("devices"); + if (devicesNode.isArray()) { + for (JsonNode dn : devicesNode) { + String did = optText(dn, "deviceId"); + if (did == null) continue; + JsonNode mats = dn.path("materials"); + if (mats.isMissingNode() || mats.isNull()) mats = dn.path("material"); + + String mid = null; + if (mats.isArray() && mats.size() > 0) mid = optText(mats.get(0), "materialId"); + else if (mats.isObject()) mid = optText(mats, "materialId"); + + if (mid != null) map.put(did, mid); + } + } + return map; + } + + private Map> buildMaterialStaticFromDb(Map devToMat, MaterialService materialService) { + Map> out = new HashMap<>(); + if (devToMat.isEmpty()) return out; + Set mids = new HashSet<>(devToMat.values()); + List mats = materialService.list(new QueryWrapper().in("material_id", mids)); + for (Material m : mats) { + Map s = new HashMap<>(); + if (m.getUConcentration() != null) s.put("u_concentration", m.getUConcentration().doubleValue()); + if (m.getPuConcentration() != null) s.put("pu_concentration", m.getPuConcentration().doubleValue()); + if (m.getUEnrichment() != null) s.put("u_enrichment", m.getUEnrichment().doubleValue()); + if (m.getPuIsotope() != null) s.put("pu_isotope", m.getPuIsotope().doubleValue()); + if (m.getUo2Density() != null) s.put("uo2_density", m.getUo2Density().doubleValue()); + if (m.getPuo2Density() != null) s.put("puo2_density", m.getPuo2Density().doubleValue()); + if (m.getHno3Acidity() != null) s.put("hno3_acidity", m.getHno3Acidity().doubleValue()); + if (m.getH2c2o4Concentration() != null) s.put("h2c2o4_concentration", m.getH2c2o4Concentration().doubleValue()); + if (m.getOrganicRatio() != null) s.put("organic_ratio", m.getOrganicRatio().doubleValue()); + if (m.getMoistureContent() != null) s.put("moisture_content", m.getMoistureContent().doubleValue()); + // ... 其他属性 + out.put(m.getMaterialId(), s); + } + return out; + } + + private String optText(JsonNode node, String key) { + if (node.has(key)) return node.get(key).asText(); + return null; + } + + private double parseDouble(String s) { + try { return Double.parseDouble(s); } catch (Exception e) { return 0.0; } + } + + private long toMillis(long time, String unit) { + if ("s".equalsIgnoreCase(unit)) return time * 1000; + if ("ms".equalsIgnoreCase(unit)) return time; + return time * 1000; // default s + } } diff --git a/business-css/src/main/java/com/yfd/business/css/config/RestTemplateConfig.java b/business-css/src/main/java/com/yfd/business/css/config/RestTemplateConfig.java new file mode 100644 index 0000000..3f16f5b --- /dev/null +++ b/business-css/src/main/java/com/yfd/business/css/config/RestTemplateConfig.java @@ -0,0 +1,14 @@ +package com.yfd.business.css.config; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.client.RestTemplate; + +@Configuration +public class RestTemplateConfig { + + @Bean + public RestTemplate restTemplate() { + return new RestTemplate(); + } +} diff --git a/business-css/src/main/java/com/yfd/business/css/controller/EventController.java b/business-css/src/main/java/com/yfd/business/css/controller/EventController.java index 30e62ee..7f4188f 100644 --- a/business-css/src/main/java/com/yfd/business/css/controller/EventController.java +++ b/business-css/src/main/java/com/yfd/business/css/controller/EventController.java @@ -154,7 +154,7 @@ public class EventController { new QueryWrapper() .select("event_id","scenario_id","device_id","material_id","attr_changes","trigger_time","created_at","modifier") .eq("scenario_id", scenarioId) - .orderByDesc("created_at") + .orderByAsc("created_at") ); } diff --git a/business-css/src/main/java/com/yfd/business/css/controller/ModelTrainController.java b/business-css/src/main/java/com/yfd/business/css/controller/ModelTrainController.java new file mode 100644 index 0000000..32de20e --- /dev/null +++ b/business-css/src/main/java/com/yfd/business/css/controller/ModelTrainController.java @@ -0,0 +1,76 @@ +package com.yfd.business.css.controller; + +import com.baomidou.mybatisplus.extension.plugins.pagination.Page; +import com.yfd.business.css.domain.ModelTrainTask; +import com.yfd.business.css.service.ModelTrainService; +import com.yfd.platform.config.ResponseResult; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.web.bind.annotation.*; +import org.springframework.web.multipart.MultipartFile; + +import java.util.Map; + +@RestController +@RequestMapping("/train") +public class ModelTrainController { + + @Autowired + private ModelTrainService modelTrainService; + + /** + * 上传数据集 + */ + @PostMapping("/upload") + public ResponseResult upload(@RequestParam("file") MultipartFile file) { + String path = modelTrainService.uploadDataset(file); + return ResponseResult.successData(path); + } + + /** + * 提交训练任务 + */ + @PostMapping("/submit") + public ResponseResult submit(@RequestBody ModelTrainTask task) { + String taskId = modelTrainService.submitTask(task); + return ResponseResult.successData(taskId); + } + + /** + * 查询任务列表 + */ + @GetMapping("/list") + public ResponseResult list(@RequestParam(defaultValue = "1") Integer current, + @RequestParam(defaultValue = "10") Integer size) { + Page page = modelTrainService.page(new Page<>(current, size)); + return ResponseResult.successData(page); + } + + /** + * 查询任务详情/状态 + */ + @GetMapping("/status/{taskId}") + public ResponseResult status(@PathVariable String taskId) { + ModelTrainTask task = modelTrainService.syncTaskStatus(taskId); + return ResponseResult.successData(task); + } + + /** + * 发布模型 + */ + @PostMapping("/publish") + public ResponseResult publish(@RequestBody Map body) { + String taskId = body.get("taskId"); + String versionTag = body.get("versionTag"); + boolean success = modelTrainService.publishModel(taskId, versionTag); + return success ? ResponseResult.success() : ResponseResult.error("发布失败"); + } + + /** + * 删除训练任务 + */ + @DeleteMapping("/{taskId}") + public ResponseResult delete(@PathVariable String taskId) { + boolean success = modelTrainService.removeById(taskId); + return success ? ResponseResult.success() : ResponseResult.error("删除失败"); + } +} diff --git a/business-css/src/main/java/com/yfd/business/css/controller/SimController.java b/business-css/src/main/java/com/yfd/business/css/controller/SimController.java index 62dee12..02dd625 100644 --- a/business-css/src/main/java/com/yfd/business/css/controller/SimController.java +++ b/business-css/src/main/java/com/yfd/business/css/controller/SimController.java @@ -1,45 +1,61 @@ package com.yfd.business.css.controller; +import com.yfd.business.css.build.SimBuilder; +import com.yfd.business.css.facade.SimDataFacade; +import com.yfd.business.css.model.*; +import com.yfd.business.css.service.MaterialService; +import com.yfd.business.css.service.SimService; +import com.yfd.business.css.service.SimInferService; +import com.yfd.platform.config.ResponseResult; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.*; + import java.util.List; import java.util.Map; +/** + * 仿真服务统一入口控制器 + * 负责接收仿真请求,协调 Facade、Builder 和 Service 完成仿真计算 + */ @RestController @RequestMapping("/sim") public class SimController { - // private final ProjectRepository projectRepo; - // private final EventRepository eventRepo; - // private final InfluenceRepository influenceRepo; - // private final SimService simService; + @Autowired private SimDataFacade simDataFacade; + @Autowired private SimBuilder simBuilder; + @Autowired private SimService simService; + @Autowired private MaterialService materialService; + @Autowired private SimInferService simInferService; - // public SimController(ProjectRepository projectRepo, - // EventRepository eventRepo, - // InfluenceRepository influenceRepo, - // SimService simService) { - // this.projectRepo = projectRepo; - // this.eventRepo = eventRepo; - // this.influenceRepo = influenceRepo; - // this.simService = simService; - // } + /** + * 执行仿真计算 + * + * @param req 请求参数,包含 projectId, scenarioId, steps + * @return 仿真结果,包含 code, msg, data (data.frames) + */ + @PostMapping("/run") + public ResponseResult run(@RequestBody Map req) { + String projectId = (String) req.get("projectId"); + String scenarioId = (String) req.get("scenarioId"); + int steps = req.containsKey("steps") ? (int) req.get("steps") : 10; - // @PostMapping("/run") - // public Map runSimulation(@RequestParam String projectId, - // @RequestParam String scenarioId, - // @RequestParam int steps) { + // 1. Load Data: 获取项目、设备和事件数据 + SimDataPackage data = simDataFacade.loadSimulationData(projectId, scenarioId); + + // 2. Build Model: 将原始数据转换为仿真模型 (Units, Events, Nodes) + List units = simBuilder.buildUnits(data, materialService); + List events = simBuilder.buildEvents(data.getEvents()); + List nodes = simBuilder.buildInfluenceNodes(data.getProject()); + + // 3. Run Engine: 执行核心仿真计算,返回上下文结果 + SimContext ctx = simService.runSimulation(units, events, nodes, steps); + + // 4. Async Infer: 异步执行推理并保存结果 + simInferService.asyncInferAndSave(projectId, scenarioId, ctx, units); - // List> topoRows = projectRepo.loadTopo(projectId); - // List> attrChanges = eventRepo.loadAttrChanges(projectId, scenarioId); - // List> influenceRows = influenceRepo.loadInfluences(projectId); - - // List units = SimBuilder.buildUnits(topoRows); - // List events = SimBuilder.buildEvents(attrChanges); - // List nodes = SimBuilder.buildInfluenceNodes(influenceRows); - - // SimContext ctx = simService.runSimulation(units, events, nodes, steps); - - // // 转换成推理接口 JSON - // return InferenceConverter.toInferenceInput(ctx, units, projectId, scenarioId); - // } + // 5. Convert Result: 将仿真结果转换为前端友好的格式,包含静态属性补全和元数据 + Map resultData = SimResultConverter.toFrames(ctx, units, projectId, scenarioId); + + return ResponseResult.successData(resultData); + } } - diff --git a/business-css/src/main/java/com/yfd/business/css/domain/ModelTrainTask.java b/business-css/src/main/java/com/yfd/business/css/domain/ModelTrainTask.java new file mode 100644 index 0000000..92b4bdc --- /dev/null +++ b/business-css/src/main/java/com/yfd/business/css/domain/ModelTrainTask.java @@ -0,0 +1,58 @@ +package com.yfd.business.css.domain; + +import com.baomidou.mybatisplus.annotation.*; +import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler; +import lombok.Data; + +import java.io.Serializable; +import java.time.LocalDateTime; +import java.util.Map; + +@Data +@TableName(value = "model_train_task", autoResultMap = true) +public class ModelTrainTask implements Serializable { + + private static final long serialVersionUID = 1L; + + @TableId(value = "task_id", type = IdType.ASSIGN_UUID) + private String taskId; + + @TableField("task_name") + private String taskName; + + @TableField("algorithm_type") + private String algorithmType; + + @TableField("device_type") + private String deviceType; + + @TableField("dataset_path") + private String datasetPath; + + @TableField(value = "train_params", typeHandler = JacksonTypeHandler.class) + private Map trainParams; // 使用 Map 存储 JSON + + @TableField("status") + private String status; // PENDING, TRAINING, SUCCESS, FAILED + + @TableField(value = "metrics", typeHandler = JacksonTypeHandler.class) + private Map metrics; + + @TableField("model_output_path") + private String modelOutputPath; + + @TableField(value = "feature_map_snapshot", typeHandler = JacksonTypeHandler.class) + private Map featureMapSnapshot; + + @TableField("metrics_image_path") + private String metricsImagePath; + + @TableField("error_log") + private String errorLog; + + @TableField(value = "created_at", fill = FieldFill.INSERT) + private LocalDateTime createdAt; + + @TableField(value = "updated_at", fill = FieldFill.INSERT_UPDATE) + private LocalDateTime updatedAt; +} diff --git a/business-css/src/main/java/com/yfd/business/css/domain/Scenario.java b/business-css/src/main/java/com/yfd/business/css/domain/Scenario.java index d9740ba..26685d8 100644 --- a/business-css/src/main/java/com/yfd/business/css/domain/Scenario.java +++ b/business-css/src/main/java/com/yfd/business/css/domain/Scenario.java @@ -7,6 +7,7 @@ import com.baomidou.mybatisplus.annotation.TableName; import lombok.Data; import java.io.Serializable; +import java.math.BigDecimal; import java.time.LocalDateTime; @Data @@ -41,4 +42,16 @@ public class Scenario implements Serializable { @TableField("algorithm_type") private String algorithmType; + + /** + * Keff预警阈值 + */ + @TableField("keff_threshold") + private BigDecimal keffThreshold; + + /** + * 设备算法配置映射,格式:{"deviceId1":"GPR", "deviceId2":"MLP"} + */ + @TableField("device_algo_config") + private String deviceAlgoConfig; } diff --git a/business-css/src/main/java/com/yfd/business/css/facade/SimDataFacade.java b/business-css/src/main/java/com/yfd/business/css/facade/SimDataFacade.java new file mode 100644 index 0000000..9686da1 --- /dev/null +++ b/business-css/src/main/java/com/yfd/business/css/facade/SimDataFacade.java @@ -0,0 +1,39 @@ +package com.yfd.business.css.facade; + +import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; +import com.yfd.business.css.domain.Device; +import com.yfd.business.css.domain.Event; +import com.yfd.business.css.domain.Project; +import com.yfd.business.css.model.SimDataPackage; +import com.yfd.business.css.service.DeviceService; +import com.yfd.business.css.service.EventService; +import com.yfd.business.css.service.ProjectService; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; + +import java.util.List; + +/** + * 仿真数据门面 + * 负责与各个业务Service交互,获取仿真所需的原始数据 (Project, Events 等) + */ +@Component +public class SimDataFacade { + @Autowired private ProjectService projectService; + @Autowired private EventService eventService; + @Autowired private DeviceService deviceService; + + public SimDataPackage loadSimulationData(String projectId, String scenarioId) { + // 1. 获取项目与拓扑 + Project project = projectService.getById(projectId); + if (project == null) throw new IllegalArgumentException("Project not found: " + projectId); + + // 2. 获取设备列表 (用于解析静态尺寸属性) + List devices = deviceService.list(new QueryWrapper().eq("project_id", projectId)); + + // 3. 获取事件 + List events = eventService.list(new QueryWrapper().eq("scenario_id", scenarioId)); + + return new SimDataPackage(project, devices, events); + } +} diff --git a/business-css/src/main/java/com/yfd/business/css/mapper/ModelTrainTaskMapper.java b/business-css/src/main/java/com/yfd/business/css/mapper/ModelTrainTaskMapper.java new file mode 100644 index 0000000..43425ac --- /dev/null +++ b/business-css/src/main/java/com/yfd/business/css/mapper/ModelTrainTaskMapper.java @@ -0,0 +1,9 @@ +package com.yfd.business.css.mapper; + +import com.baomidou.mybatisplus.core.mapper.BaseMapper; +import com.yfd.business.css.domain.ModelTrainTask; +import org.apache.ibatis.annotations.Mapper; + +@Mapper +public interface ModelTrainTaskMapper extends BaseMapper { +} diff --git a/business-css/src/main/java/com/yfd/business/css/model/SimContext.java b/business-css/src/main/java/com/yfd/business/css/model/SimContext.java index d104c53..b11c884 100644 --- a/business-css/src/main/java/com/yfd/business/css/model/SimContext.java +++ b/business-css/src/main/java/com/yfd/business/css/model/SimContext.java @@ -12,4 +12,4 @@ public class SimContext { public void snapshot(int step) { timeline.put(step, new HashMap<>(currentValues)); } public Map> getTimeline() { return timeline; } -} \ No newline at end of file +} diff --git a/business-css/src/main/java/com/yfd/business/css/model/SimDataPackage.java b/business-css/src/main/java/com/yfd/business/css/model/SimDataPackage.java new file mode 100644 index 0000000..1d6b3a2 --- /dev/null +++ b/business-css/src/main/java/com/yfd/business/css/model/SimDataPackage.java @@ -0,0 +1,40 @@ +package com.yfd.business.css.model; + +import com.yfd.business.css.domain.Device; +import com.yfd.business.css.domain.Event; +import com.yfd.business.css.domain.Project; + +import java.util.Collections; +import java.util.List; + +/** + * 仿真数据包 + * 封装从数据库加载的原始项目、设备和事件数据 + */ +public class SimDataPackage { + private final Project project; + private final List devices; + private final List events; + + public SimDataPackage(Project project, List events) { + this(project, Collections.emptyList(), events); + } + + public SimDataPackage(Project project, List devices, List events) { + this.project = project; + this.devices = devices; + this.events = events; + } + + public Project getProject() { + return project; + } + + public List getDevices() { + return devices; + } + + public List getEvents() { + return events; + } +} diff --git a/business-css/src/main/java/com/yfd/business/css/model/SimEvent.java b/business-css/src/main/java/com/yfd/business/css/model/SimEvent.java index bc4a9c7..e56185d 100644 --- a/business-css/src/main/java/com/yfd/business/css/model/SimEvent.java +++ b/business-css/src/main/java/com/yfd/business/css/model/SimEvent.java @@ -7,16 +7,23 @@ public class SimEvent { private final SimPropertyKey key; private final double value; private final EventType type; + private final boolean isOverride; // 新增:是否强制覆盖 - public SimEvent(int step, SimPropertyKey key, double value, EventType type) { + public SimEvent(int step, SimPropertyKey key, double value, EventType type, boolean isOverride) { this.step = step; this.key = key; this.value = value; this.type = type; + this.isOverride = isOverride; + } + + public SimEvent(int step, SimPropertyKey key, double value, EventType type) { + this(step, key, value, type, false); // 默认为非强制 } public int getStep() { return step; } public SimPropertyKey getKey() { return key; } public double getValue() { return value; } public EventType getType() { return type; } + public boolean isOverride() { return isOverride; } } diff --git a/business-css/src/main/java/com/yfd/business/css/model/SimResultConverter.java b/business-css/src/main/java/com/yfd/business/css/model/SimResultConverter.java new file mode 100644 index 0000000..9e27a95 --- /dev/null +++ b/business-css/src/main/java/com/yfd/business/css/model/SimResultConverter.java @@ -0,0 +1,72 @@ +package com.yfd.business.css.model; + +import java.util.*; + +/** + * 仿真结果转换器 + * 负责将 SimContext 中的仿真状态转换为前端友好的帧结构 + */ +public class SimResultConverter { + + /** + * 将仿真上下文转换为帧列表 + * @param ctx 仿真上下文 + * @param units 仿真单元列表 (用于补充设备类型等元数据) + * @param projectId 项目ID + * @param scenarioId 情景ID + * @return 包含完整结果数据的 Map + */ + public static Map toFrames(SimContext ctx, List units, String projectId, String scenarioId) { + List> frames = new ArrayList<>(); + Map unitMap = new HashMap<>(); + for (SimUnit u : units) { + unitMap.put(u.unitId(), u); + } + + Map> timeline = ctx.getTimeline(); + List steps = new ArrayList<>(timeline.keySet()); + Collections.sort(steps); + + for (Integer step : steps) { + Map snapshot = timeline.get(step); + Map> devices = new HashMap<>(); + + // 1. 先初始化所有设备及其静态属性 + for (SimUnit unit : units) { + Map devState = new HashMap<>(); + devState.put("deviceType", unit.deviceType()); + // 注入所有静态属性作为基线 + if (unit.staticProperties() != null) { + devState.putAll(unit.staticProperties()); + } + devices.put(unit.unitId(), devState); + } + + // 2. 覆盖动态计算出的属性值 + for (Map.Entry entry : snapshot.entrySet()) { + String unitId = entry.getKey().unitId(); + String prop = entry.getKey().property(); + Double val = entry.getValue(); + + Map devState = devices.get(unitId); + if (devState == null) continue; // 理论上不应该发生,除非有未定义的unit + + devState.put(prop, val); + } + + Map frame = new HashMap<>(); + frame.put("step", step); + frame.put("time", step); // 假设 1 step = 1 time unit + frame.put("devices", devices); + frames.add(frame); + } + + Map data = new HashMap<>(); + data.put("frames", frames); + data.put("projectId", projectId); + data.put("scenarioId", scenarioId); + data.put("generated", Map.of("snapshots", steps.size())); + + return data; + } +} diff --git a/business-css/src/main/java/com/yfd/business/css/model/SimUnit.java b/business-css/src/main/java/com/yfd/business/css/model/SimUnit.java index b8fb3ec..10e8a8e 100644 --- a/business-css/src/main/java/com/yfd/business/css/model/SimUnit.java +++ b/business-css/src/main/java/com/yfd/business/css/model/SimUnit.java @@ -1,3 +1,9 @@ package com.yfd.business.css.model; -public record SimUnit(String unitId, String deviceId, String materialId, String deviceType) {} +import java.util.Map; + +public record SimUnit(String unitId, String deviceId, String materialId, String deviceType, Map staticProperties) { + public SimUnit(String unitId, String deviceId, String materialId, String deviceType) { + this(unitId, deviceId, materialId, deviceType, Map.of()); + } +} diff --git a/business-css/src/main/java/com/yfd/business/css/service/DeviceInferService.java b/business-css/src/main/java/com/yfd/business/css/service/DeviceInferService.java index 83dc302..e301fb0 100644 --- a/business-css/src/main/java/com/yfd/business/css/service/DeviceInferService.java +++ b/business-css/src/main/java/com/yfd/business/css/service/DeviceInferService.java @@ -1,6 +1,8 @@ package com.yfd.business.css.service; +import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import com.yfd.business.css.domain.Scenario; import com.yfd.business.css.domain.ScenarioResult; import com.yfd.business.css.model.DeviceStepInfo; import com.yfd.business.css.model.InferRequest; @@ -18,6 +20,7 @@ import org.springframework.web.client.RestTemplate; import org.springframework.stereotype.Service; import java.util.*; + @Service public class DeviceInferService { @Value("${python.api.url:http://localhost:8000}") @@ -35,63 +38,94 @@ public class DeviceInferService { public void processDeviceInference(String projectId, String scenarioId, Map> groupedDevices) { - // 遍历每个设备类型,调用对应的模型进行推理 + // 1. 获取情景配置信息 + Scenario scenario = scenarioService.getById(scenarioId); + if (scenario == null) { + throw new IllegalArgumentException("场景 " + scenarioId + " 不存在"); + } + + // 全局算法类型 + String globalAlgorithmType = scenario.getAlgorithmType(); + if (globalAlgorithmType == null) { + throw new IllegalArgumentException("场景 " + scenarioId + " 未配置全局算法类型"); + } + + // 解析设备级算法配置 + Map deviceAlgoConfig = new HashMap<>(); + if (scenario.getDeviceAlgoConfig() != null && !scenario.getDeviceAlgoConfig().isBlank()) { + try { + deviceAlgoConfig = objectMapper.readValue(scenario.getDeviceAlgoConfig(), new TypeReference>() {}); + } catch (Exception e) { + System.err.println("解析设备算法配置失败: " + e.getMessage()); + } + } + + // 2. 遍历每个设备类型组 for (Map.Entry> entry : groupedDevices.entrySet()) { String deviceType = entry.getKey(); - // algorithmType通过scenarioId获取 - String algorithmType = scenarioService.getAlgorithmType(scenarioId); - if (algorithmType == null) { - throw new IllegalArgumentException("场景 " + scenarioId + " 未配置算法类型"); - } - - //modelPath根据模型类型、设备类型,从algorithm_model获取活动的model_path - String modelPath = algorithmModelService.getCurrentModelPath(algorithmType, deviceType); - System.out.println("modelPath="+modelPath); - if (modelPath == null) { - throw new IllegalArgumentException("未配置 " + algorithmType + " 模型路径"); - } List devices = entry.getValue(); - System.out.println("devices="+devices); - // 校验设备数据是否完整 - if (devices == null || devices.isEmpty()) { - throw new IllegalArgumentException("设备数据为空,无法进行模拟"); + + if (devices == null || devices.isEmpty()) continue; + + // 3. 将同一类型的设备,按算法类型进行二级分组 + Map> algoGroup = new HashMap<>(); + + for (DeviceStepInfo device : devices) { + // 优先使用设备特定配置,否则使用全局配置 + String algoType = deviceAlgoConfig.getOrDefault(device.getDeviceId(), globalAlgorithmType); + algoGroup.computeIfAbsent(algoType, k -> new ArrayList<>()).add(device); } - - // 封装推理请求 - InferRequest request = buildInferenceRequest(deviceType,devices,modelPath); - System.out.println("request="+request); - - // 调用Python推理服务 - InferResponse response = infer(request); - System.out.println("推理服务返回结果: " + response); - - // 处理推理结果 - if (response != null && response.getCode() == 0) { - // 重新构建InferResponse对象示例 - // 1. 从response获取数据 - int code = response.getCode(); - String msg = response.getMsg(); - InferResponse.InferData originalData = response.getData(); - - // 2. 重新构建InferData - InferResponse.InferData newData = new InferResponse.InferData(); - newData.setItems(originalData.getItems()); - newData.setMeta(originalData.getMeta()); - newData.setFeatures(originalData.getFeatures()); - newData.setKeff(originalData.getKeff()); - - // 3. 构建新的InferResponse - InferResponse reconstructedResponse = new InferResponse(); - reconstructedResponse.setCode(code); - reconstructedResponse.setMsg(msg); - reconstructedResponse.setData(newData); - - System.out.println("重新构建的response: " + reconstructedResponse); - - // 使用重新构建的response处理结果 - processInferenceResults(projectId, scenarioId, deviceType, devices, reconstructedResponse); - } else { - throw new RuntimeException("推理服务调用失败: " + (response != null ? response.getMsg() : "未知错误")); + + // 4. 对每个算法分组进行推理 + for (Map.Entry> algoEntry : algoGroup.entrySet()) { + String currentAlgoType = algoEntry.getKey(); + List currentDevices = algoEntry.getValue(); + + // 获取模型路径 + System.out.println("Processing inference for algorithmType: " + currentAlgoType + ", deviceType: " + deviceType); + String modelPath = algorithmModelService.getCurrentModelPath(currentAlgoType, deviceType); + System.out.println("modelPath=" + modelPath); + + if (modelPath == null) { + System.err.println("Model path not found for algorithmType: " + currentAlgoType + ", deviceType: " + deviceType); + // 这里可以选择抛异常中断,或者跳过该组继续处理其他组 + // 为了保证健壮性,这里选择记录错误并跳过,或者根据业务需求抛出异常 + // throw new IllegalArgumentException("未配置 " + currentAlgoType + " 模型路径 (deviceType: " + deviceType + ")"); + continue; + } + + // 封装推理请求 + InferRequest request = buildInferenceRequest(deviceType, currentDevices, modelPath); + // System.out.println("request=" + request); + + try { + // 调用Python推理服务 + InferResponse response = infer(request); + System.out.println("推理服务返回结果: code=" + (response != null ? response.getCode() : "null")); + + // 处理推理结果 + if (response != null && response.getCode() == 0) { + // 重新构建InferResponse对象示例 (为了兼容性保留原有逻辑) + InferResponse.InferData originalData = response.getData(); + InferResponse.InferData newData = new InferResponse.InferData(); + newData.setItems(originalData.getItems()); + newData.setMeta(originalData.getMeta()); + newData.setFeatures(originalData.getFeatures()); + newData.setKeff(originalData.getKeff()); + + InferResponse reconstructedResponse = new InferResponse(); + reconstructedResponse.setCode(response.getCode()); + reconstructedResponse.setMsg(response.getMsg()); + reconstructedResponse.setData(newData); + + processInferenceResults(projectId, scenarioId, deviceType, currentDevices, reconstructedResponse); + } else { + System.err.println("推理服务调用失败: " + (response != null ? response.getMsg() : "未知错误")); + } + } catch (Exception e) { + System.err.println("推理异常: " + e.getMessage()); + e.printStackTrace(); + } } } } diff --git a/business-css/src/main/java/com/yfd/business/css/service/ModelTrainService.java b/business-css/src/main/java/com/yfd/business/css/service/ModelTrainService.java new file mode 100644 index 0000000..289d121 --- /dev/null +++ b/business-css/src/main/java/com/yfd/business/css/service/ModelTrainService.java @@ -0,0 +1,36 @@ +package com.yfd.business.css.service; + +import com.baomidou.mybatisplus.extension.service.IService; +import com.yfd.business.css.domain.ModelTrainTask; +import org.springframework.web.multipart.MultipartFile; + +public interface ModelTrainService extends IService { + /** + * 上传数据集 + * @param file 文件 + * @return 文件保存路径 + */ + String uploadDataset(MultipartFile file); + + /** + * 提交训练任务 + * @param task 任务信息 + * @return 任务ID + */ + String submitTask(ModelTrainTask task); + + /** + * 同步任务状态 + * @param taskId 任务ID + * @return 最新任务信息 + */ + ModelTrainTask syncTaskStatus(String taskId); + + /** + * 发布模型 + * @param taskId 任务ID + * @param versionTag 版本号 + * @return 是否成功 + */ + boolean publishModel(String taskId, String versionTag); +} diff --git a/business-css/src/main/java/com/yfd/business/css/service/SimInferService.java b/business-css/src/main/java/com/yfd/business/css/service/SimInferService.java new file mode 100644 index 0000000..d31c426 --- /dev/null +++ b/business-css/src/main/java/com/yfd/business/css/service/SimInferService.java @@ -0,0 +1,90 @@ +package com.yfd.business.css.service; + +import com.yfd.business.css.model.*; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.scheduling.annotation.Async; +import org.springframework.stereotype.Service; + +import java.util.*; +import java.util.stream.Collectors; + +/** + * 仿真推理服务 + * 负责将仿真计算结果 (SimContext) 转换为推理请求,并异步调用 Python 推理服务入库 + */ +@Service +public class SimInferService { + + @Autowired + private DeviceInferService deviceInferService; + + @Async + public void asyncInferAndSave(String projectId, String scenarioId, SimContext context, List units) { + try { + // 1. 准备元数据映射: unitId -> SimUnit + Map unitMap = units.stream() + .collect(Collectors.toMap(SimUnit::unitId, u -> u)); + + // 2. 转换 SimContext 为按 DeviceType 分组的 DeviceStepInfo 列表 + Map> groupedDevices = new HashMap<>(); + + // 遍历时间轴 + Map> timeline = context.getTimeline(); + for (Map.Entry> entry : timeline.entrySet()) { + int step = entry.getKey(); + Map props = entry.getValue(); + + // 按 unitId 聚合属性 + Map> unitProps = new HashMap<>(); + for (Map.Entry propEntry : props.entrySet()) { + SimPropertyKey key = propEntry.getKey(); + String unitId = key.unitId(); + String propName = key.property(); + Double value = propEntry.getValue(); + + unitProps.computeIfAbsent(unitId, k -> new HashMap<>()).put(propName, value); + } + + // 构建 DeviceStepInfo + for (Map.Entry> unitEntry : unitProps.entrySet()) { + String unitId = unitEntry.getKey(); + Map properties = unitEntry.getValue(); + + SimUnit unit = unitMap.get(unitId); + if (unit == null) continue; // 可能是中间变量或非设备单元 + + String deviceType = unit.deviceType(); + if (deviceType == null || deviceType.isEmpty()) continue; + + // 确保静态属性也包含在内(如果 SimContext 中没有,从 SimUnit 补充) + // SimService 应该已经初始化了静态属性到 Context,但为了保险起见: + for (Map.Entry staticProp : unit.staticProperties().entrySet()) { + properties.putIfAbsent(staticProp.getKey(), staticProp.getValue()); + } + + DeviceStepInfo info = new DeviceStepInfo(); + info.setDeviceId(unit.deviceId()); // unitId 通常等于 deviceId + info.setDeviceType(deviceType); + info.setProperties(properties); + info.setStep(step); + info.setTime(step); // 假设 time = step,或者根据步长计算 + + groupedDevices.computeIfAbsent(deviceType, k -> new ArrayList<>()).add(info); + } + } + + if (groupedDevices.isEmpty()) { + System.out.println("No device data found for inference."); + return; + } + + // 3. 调用 DeviceInferService 进行推理和入库 + // 复用现有的 processDeviceInference 方法,它处理了按类型分组的数据 + deviceInferService.processDeviceInference(projectId, scenarioId, groupedDevices); + + } catch (Exception e) { + System.err.println("Async inference failed: " + e.getMessage()); + e.printStackTrace(); + } + } +} diff --git a/business-css/src/main/java/com/yfd/business/css/service/SimService.java b/business-css/src/main/java/com/yfd/business/css/service/SimService.java index b5c5b1a..0a25c6c 100644 --- a/business-css/src/main/java/com/yfd/business/css/service/SimService.java +++ b/business-css/src/main/java/com/yfd/business/css/service/SimService.java @@ -1,11 +1,28 @@ package com.yfd.business.css.service; import com.yfd.business.css.model.*; +import org.springframework.stereotype.Service; +import java.util.ArrayList; import java.util.List; +import java.util.Map; +/** + * 仿真核心计算引擎 + * 负责执行时间步推进,处理静态属性、事件输入、影响传播和强制覆盖 + */ +@Service public class SimService { + /** + * 运行仿真 + * + * @param units 仿真单元列表 (设备/物料) + * @param events 仿真事件列表 (属性变更) + * @param nodes 影响关系节点列表 + * @param steps 仿真总步数 + * @return 仿真上下文,包含所有时间步的状态快照 + */ public SimContext runSimulation(List units, List events, List nodes, @@ -13,41 +30,96 @@ public class SimService { SimContext ctx = new SimContext(); - // 初始化设备/物料属性 + // Step 1: 初始化静态基线 (t=0) + // 将所有单元的静态属性写入初始上下文 for (SimUnit unit : units) { - ctx.ensureProperty(SimPropertyKey.of(unit.unitId(), "device.power")); - ctx.ensureProperty(SimPropertyKey.of(unit.unitId(), "material.quantity")); - ctx.ensureProperty(SimPropertyKey.of(unit.unitId(), "keff")); // 占位 + unit.staticProperties().forEach((k, v) -> + ctx.setValue(SimPropertyKey.of(unit.unitId(), k), v) + ); } + ctx.snapshot(0); // 记录初始状态 - // 每步仿真 + // Step 2: 循环推进 (t=1 to steps) for (int step = 1; step <= steps; step++) { + // 2.1 Event (Input): 应用普通事件 (作为输入条件) + applyEvents(ctx, events, step, false); - // 事件 - for (SimEvent e : events) { - if (e.getStep() == step) { - ctx.setValue(e.getKey(), e.getValue()); - } - } + // 2.2 Influence: 计算影响关系 + // 基于当前 ctx (包含 static + input event) 计算派生属性 + applyInfluences(ctx, nodes, step); - // 影响关系 - for (SimInfluenceNode node : nodes) { - double sum = node.getBias(); - for (SimInfluenceSource s : node.getSources()) { - sum += ctx.getValue(s.getSource()) * s.getCoeff(); - } - ctx.setValue(node.getTarget(), sum); - } - - // keff 占位(后续可以替换为推理接口计算) - for (SimUnit unit : units) { - ctx.setValue(SimPropertyKey.of(unit.unitId(), "keff"), 0.0); - } + // 2.3 Event (Override): 应用强制覆盖事件 + // 再次覆盖,确保强制逻辑生效 (如故障注入、手动锁定) + applyEvents(ctx, events, step, true); + // 2.4 Snapshot: 保存当前步的状态快照 ctx.snapshot(step); } return ctx; } -} + /** + * 应用事件 + * @param ctx 上下文 + * @param events 事件列表 + * @param step 当前步 + * @param isOverride 是否为强制覆盖事件 + */ + private void applyEvents(SimContext ctx, List events, int step, boolean isOverride) { + for (SimEvent e : events) { + if (e.getStep() == step && e.isOverride() == isOverride) { + // 当前仅支持 SET 类型,ADD/MULTIPLY 可按需扩展 + ctx.setValue(e.getKey(), e.getValue()); + } + } + } + + /** + * 计算影响关系 + * @param ctx 上下文 + * @param nodes 影响节点列表 + * @param currentStep 当前步 + */ + private void applyInfluences(SimContext ctx, List nodes, int currentStep) { + // 简单实现:直接计算并更新。如果存在依赖环,可能需要多轮迭代或拓扑排序。 + // 这里假设无环或允许一帧内的即时传播。 + // 为了避免更新顺序影响结果(A依赖B,先算A还是先算B),理想情况下应使用双缓冲(读取上一帧或本帧快照)。 + // 但根据需求 "Influence(派生计算):基于当前状态执行影响关系计算",且 ProjectServiceImpl 中是直接读取当前 state。 + // ProjectServiceImpl 中 readValue 会读取 delay。 + + // 我们可以先计算所有变更,再统一应用,避免计算中间态影响后续计算(除非是有意为之的级联) + // 这里采用:先计算所有目标值,存入临时 Map,最后统一写入 ctx + + List> updates = new ArrayList<>(); + + for (SimInfluenceNode node : nodes) { + double sum = node.getBias(); + for (SimInfluenceSource s : node.getSources()) { + // 处理延迟: t - delay + // SimContext 目前只存储当前值。如果需要历史值,SimContext 需要支持根据 step 获取 snapshot。 + // 假设 SimContext.getValue(key) 返回当前值。 + // 如果有 delay,我们需要访问历史 snapshot。 + + // 暂时假设无 delay 或 delay=0,读取当前值 + // 如果 SimInfluenceSource 有 delay 属性,需要 SimContext 支持历史回溯 + // 修改 SimContext 增加 getValue(key, step) ? + // 目前 SimContext 实现未知,假设只能读当前。 + // 如果需要支持 delay,SimContext 需要保留 history。 + + // 简单起见,这里先读当前值。完善版本需要 SimContext 提供 getHistoryValue(key, step - delay) + + // 修正:SimContext 应该能获取历史。我们先看 SimContext 定义。 + // 假设 SimContext 内部有 snapshots。 + + double sourceVal = ctx.getValue(s.getSource()); // 暂读当前 + sum += sourceVal * s.getCoeff(); + } + updates.add(Map.entry(node.getTarget(), sum)); + } + + for (Map.Entry entry : updates) { + ctx.setValue(entry.getKey(), entry.getValue()); + } + } +} diff --git a/business-css/src/main/java/com/yfd/business/css/service/impl/AlgorithmModelServiceImpl.java b/business-css/src/main/java/com/yfd/business/css/service/impl/AlgorithmModelServiceImpl.java index 7d251a2..2075f6f 100644 --- a/business-css/src/main/java/com/yfd/business/css/service/impl/AlgorithmModelServiceImpl.java +++ b/business-css/src/main/java/com/yfd/business/css/service/impl/AlgorithmModelServiceImpl.java @@ -16,12 +16,19 @@ public class AlgorithmModelServiceImpl extends ServiceImpl queryWrapper = new QueryWrapper<>(); queryWrapper.eq("algorithm_type", algorithmType) .eq("device_type", deviceType) .eq("is_current", 1); // 当前激活版本 AlgorithmModel model = getOne(queryWrapper); - return model != null ? model.getModelPath() : null; + if (model != null) { + System.out.println("Found model: " + model.getModelPath()); + return model.getModelPath(); + } else { + System.out.println("Model not found in database."); + return null; + } } @Override diff --git a/business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java b/business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java new file mode 100644 index 0000000..f03d541 --- /dev/null +++ b/business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java @@ -0,0 +1,228 @@ +package com.yfd.business.css.service.impl; + +import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; +import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.yfd.business.css.common.exception.BizException; +import com.yfd.business.css.domain.AlgorithmModel; +import com.yfd.business.css.domain.ModelTrainTask; +import com.yfd.business.css.mapper.ModelTrainTaskMapper; +import com.yfd.business.css.service.AlgorithmModelService; +import com.yfd.business.css.service.ModelTrainService; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.scheduling.annotation.Async; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.multipart.MultipartFile; + +import java.io.File; +import java.io.IOException; +import java.time.LocalDateTime; +import java.time.format.DateTimeFormatter; +import java.util.HashMap; +import java.util.Map; +import java.util.UUID; + +@Service +public class ModelTrainServiceImpl extends ServiceImpl implements ModelTrainService { + + @Value("${file-space.upload-path:./data/uploads/}") + private String uploadPath; + + @Value("${python.api.url:http://localhost:8000}") + private String pythonApiUrl; + + @Autowired + private RestTemplate restTemplate; + + @Autowired + private AlgorithmModelService algorithmModelService; + + @Autowired + private ObjectMapper objectMapper; + + @Override + public String uploadDataset(MultipartFile file) { + if (file.isEmpty()) { + throw new BizException("上传文件不能为空"); + } + + String originalFilename = file.getOriginalFilename(); + String dateDir = LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyyMMdd")); + String uuid = UUID.randomUUID().toString(); + String extension = originalFilename.substring(originalFilename.lastIndexOf(".")); + String newFilename = uuid + extension; + + String saveDir = uploadPath + File.separator + dateDir; + File dir = new File(saveDir); + if (!dir.exists()) { + dir.mkdirs(); + } + + String fullPath = saveDir + File.separator + newFilename; + try { + file.transferTo(new File(fullPath)); + return fullPath; + } catch (IOException e) { + throw new BizException("文件保存失败: " + e.getMessage()); + } + } + + @Override + @Transactional + public String submitTask(ModelTrainTask task) { + // 1. 初始化状态 + task.setStatus("PENDING"); + if (task.getTaskId() == null) { + task.setTaskId(UUID.randomUUID().toString()); + } + this.save(task); + + // 2. 异步调用 Python 训练 + asyncCallTrain(task); + + return task.getTaskId(); + } + + @Async + public void asyncCallTrain(ModelTrainTask task) { + try { + // 更新状态为 TRAINING + task.setStatus("TRAINING"); + this.updateById(task); + + // 构建请求参数 + Map request = new HashMap<>(); + request.put("task_id", task.getTaskId()); + request.put("algorithm_type", task.getAlgorithmType()); + request.put("device_type", task.getDeviceType()); + request.put("dataset_path", task.getDatasetPath()); + request.put("hyperparameters", task.getTrainParams()); + + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + HttpEntity> entity = new HttpEntity<>(request, headers); + + String url = pythonApiUrl + "/v1/train"; + // 这里假设 Python 服务会立即返回,启动后台线程 + // 如果 Python 服务是阻塞的,这里的 Async 会起作用 + ResponseEntity response = restTemplate.postForEntity(url, entity, Map.class); + + if (response.getStatusCode().is2xxSuccessful()) { + // 调用成功,等待后续轮询状态 + System.out.println("训练任务提交成功: " + task.getTaskId()); + } else { + task.setStatus("FAILED"); + task.setErrorLog("提交训练任务失败: " + response.getStatusCode()); + this.updateById(task); + } + + } catch (Exception e) { + task.setStatus("FAILED"); + task.setErrorLog("调用 Python 服务异常: " + e.getMessage()); + this.updateById(task); + e.printStackTrace(); + } + } + + @Override + public ModelTrainTask syncTaskStatus(String taskId) { + ModelTrainTask task = this.getById(taskId); + if (task == null) { + throw new BizException("任务不存在"); + } + + // 只有在 TRAINING 状态才去查询 Python 服务 + if ("TRAINING".equals(task.getStatus()) || "PENDING".equals(task.getStatus())) { + try { + String url = pythonApiUrl + "/v1/train/status/" + taskId; + ResponseEntity response = restTemplate.getForEntity(url, Map.class); + + if (response.getStatusCode().is2xxSuccessful() && response.getBody() != null) { + Map body = response.getBody(); + String status = (String) body.get("status"); // TRAINING, SUCCESS, FAILED + + if (status != null) { + task.setStatus(status); + + if ("SUCCESS".equals(status)) { + task.setModelOutputPath((String) body.get("model_path")); + task.setMetrics((Map) body.get("metrics")); + task.setFeatureMapSnapshot((Map) body.get("feature_map")); + task.setMetricsImagePath((String) body.get("metrics_image")); + } else if ("FAILED".equals(status)) { + task.setErrorLog((String) body.get("error")); + } else if ("TRAINING".equals(status)) { + // 可以更新进度或其他中间指标 + if (body.containsKey("metrics")) { + task.setMetrics((Map) body.get("metrics")); + } + } + + this.updateById(task); + } + } + } catch (Exception e) { + // 查询失败,暂不更新状态,或者记录日志 + System.err.println("同步任务状态失败: " + e.getMessage()); + } + } + + return task; + } + + @Override + @Transactional + public boolean publishModel(String taskId, String versionTag) { + ModelTrainTask task = this.getById(taskId); + if (task == null) { + throw new BizException("任务不存在"); + } + + if (!"SUCCESS".equals(task.getStatus())) { + throw new BizException("任务未完成或失败,无法发布"); + } + + // 检查版本号唯一性 + long count = algorithmModelService.count(new QueryWrapper() + .eq("algorithm_type", task.getAlgorithmType()) + .eq("device_type", task.getDeviceType()) + .eq("version_tag", versionTag)); + if (count > 0) { + throw new BizException("版本号已存在"); + } + + // 创建正式模型记录 + AlgorithmModel model = new AlgorithmModel(); + model.setAlgorithmType(task.getAlgorithmType()); + model.setDeviceType(task.getDeviceType()); + model.setVersionTag(versionTag); + model.setModelPath(task.getModelOutputPath()); // 这里简化处理,直接引用临时路径,实际生产建议移动文件到正式目录 + model.setMetricsImagePath(task.getMetricsImagePath()); + model.setTrainedAt(LocalDateTime.now()); + model.setIsCurrent(0); // 默认不激活 + + try { + if (task.getFeatureMapSnapshot() != null) { + model.setFeatureMapSnapshot(objectMapper.writeValueAsString(task.getFeatureMapSnapshot())); + } else { + model.setFeatureMapSnapshot("{}"); + } + + if (task.getMetrics() != null) { + model.setMetrics(objectMapper.writeValueAsString(task.getMetrics())); + } + } catch (JsonProcessingException e) { + throw new BizException("JSON 序列化失败"); + } + + return algorithmModelService.save(model); + } +}