diff --git a/business-css/src/main/java/com/yfd/business/css/model/DeviceAlgoConfigItem.java b/business-css/src/main/java/com/yfd/business/css/model/DeviceAlgoConfigItem.java new file mode 100644 index 0000000..56211da --- /dev/null +++ b/business-css/src/main/java/com/yfd/business/css/model/DeviceAlgoConfigItem.java @@ -0,0 +1,41 @@ +package com.yfd.business.css.model; + +public class DeviceAlgoConfigItem { + private String deviceId; + private String algorithmType; + private String algorithmModelId; + + public DeviceAlgoConfigItem() { + } + + public DeviceAlgoConfigItem(String deviceId, String algorithmType, String algorithmModelId) { + this.deviceId = deviceId; + this.algorithmType = algorithmType; + this.algorithmModelId = algorithmModelId; + } + + public String getDeviceId() { + return deviceId; + } + + public void setDeviceId(String deviceId) { + this.deviceId = deviceId; + } + + public String getAlgorithmType() { + return algorithmType; + } + + public void setAlgorithmType(String algorithmType) { + this.algorithmType = algorithmType; + } + + public String getAlgorithmModelId() { + return algorithmModelId; + } + + public void setAlgorithmModelId(String algorithmModelId) { + this.algorithmModelId = algorithmModelId; + } +} + diff --git a/business-css/src/main/java/com/yfd/business/css/service/DeviceInferService.java b/business-css/src/main/java/com/yfd/business/css/service/DeviceInferService.java index 5b9ceb7..89002d1 100644 --- a/business-css/src/main/java/com/yfd/business/css/service/DeviceInferService.java +++ b/business-css/src/main/java/com/yfd/business/css/service/DeviceInferService.java @@ -8,6 +8,7 @@ import com.yfd.business.css.domain.ScenarioResult; import com.yfd.business.css.domain.AlgorithmModel; import com.yfd.business.css.common.exception.ScenarioInferException; import com.yfd.business.css.model.DeviceStepInfo; +import com.yfd.business.css.model.DeviceAlgoConfigItem; import com.yfd.business.css.model.InferRequest; import com.yfd.business.css.model.InferResponse; @@ -69,15 +70,7 @@ public class DeviceInferService { globalAlgorithmType = "GPR"; } - // 解析设备级算法配置 - Map deviceAlgoConfig = new HashMap<>(); - if (scenario.getDeviceAlgoConfig() != null && !scenario.getDeviceAlgoConfig().isBlank()) { - try { - deviceAlgoConfig = objectMapper.readValue(scenario.getDeviceAlgoConfig(), new TypeReference>() {}); - } catch (Exception e) { - log.error("解析设备算法配置失败: {}", e.getMessage(), e); - } - } + Map deviceAlgoConfig = parseDeviceAlgoConfig(scenario.getDeviceAlgoConfig(), scenarioId); // 2. 遍历每个设备类型组 for (Map.Entry> entry : groupedDevices.entrySet()) { @@ -86,12 +79,11 @@ public class DeviceInferService { if (devices == null || devices.isEmpty()) continue; - // 3. 将同一类型的设备,按算法类型进行二级分组 Map> algoGroup = new HashMap<>(); for (DeviceStepInfo device : devices) { - // 优先使用设备特定配置,否则使用全局配置 - String algoType = deviceAlgoConfig.getOrDefault(device.getDeviceId(), globalAlgorithmType); + DeviceAlgoConfigItem cfg = deviceAlgoConfig.get(device.getDeviceId()); + String algoType = (cfg != null && trimToNull(cfg.getAlgorithmType()) != null) ? trimToNull(cfg.getAlgorithmType()) : globalAlgorithmType; algoGroup.computeIfAbsent(algoType, k -> new ArrayList<>()).add(device); } @@ -113,89 +105,102 @@ public class DeviceInferService { String currentMaterialType = matEntry.getKey(); List batchDevices = matEntry.getValue(); - // 获取模型对象(根据算法类型、设备类型、材料类型) - log.info("Processing inference for algorithmType: {}, deviceType: {}, materialType: {}", currentAlgoType, deviceType, currentMaterialType); - AlgorithmModel model = algorithmModelService.getCurrentModel(currentAlgoType, deviceType, currentMaterialType); - - if (model == null || model.getModelPath() == null) { - log.error("Model path not found for algorithmType: {}, deviceType: {}, materialType: {}", currentAlgoType, deviceType, currentMaterialType); - hasAnyError = true; - missingModels.add(Map.of( - "algorithmType", currentAlgoType, - "deviceType", deviceType, - "materialType", currentMaterialType - )); - continue; - } - - String modelRelPath = model.getModelPath(); - log.debug("modelRelPath={}", modelRelPath); - - // 解析模型的特征映射(feature_map_snapshot),优先以 input_cols 为准进行特征过滤 - List requiredFeatures = new ArrayList<>(); - if (model.getFeatureMapSnapshot() != null && !model.getFeatureMapSnapshot().isBlank()) { - try { - // 由于数据库里存的可能是转义的 JSON 字符串,我们需要多次解析直到它变成真正的 Object - JsonNode fNode = objectMapper.readTree(model.getFeatureMapSnapshot()); - if (fNode.isTextual()) { - fNode = objectMapper.readTree(fNode.asText()); - } - - if (fNode.isArray()) { - for (JsonNode node : fNode) requiredFeatures.add(node.asText()); - } else if (fNode.has("input_cols") && fNode.get("input_cols").isArray()) { - for (JsonNode node : fNode.get("input_cols")) requiredFeatures.add(node.asText()); - } else if (fNode.has("features") && fNode.get("features").isArray()) { - // 兼容旧版本 - for (JsonNode node : fNode.get("features")) requiredFeatures.add(node.asText()); - } - log.info("Parsed requiredFeatures from feature_map_snapshot: {}", requiredFeatures); - } catch (Exception e) { - log.warn("解析特征映射快照失败: {}", e.getMessage()); - } - } else { - log.warn("模型 feature_map_snapshot 为空,将不进行特征过滤"); + Map> modelGroup = new HashMap<>(); + for (DeviceStepInfo d : batchDevices) { + DeviceAlgoConfigItem cfg = deviceAlgoConfig.get(d.getDeviceId()); + String modelId = (cfg != null) ? trimToNull(cfg.getAlgorithmModelId()) : null; + String k = (modelId == null) ? "" : modelId; + modelGroup.computeIfAbsent(k, x -> new ArrayList<>()).add(d); } - // 将相对路径转换为绝对路径 - String absoluteModelPath = Paths.get(modelRootPath).resolve(modelRelPath).toAbsolutePath().normalize().toString(); - log.debug("Absolute modelPath={}", absoluteModelPath); + for (Map.Entry> mg : modelGroup.entrySet()) { + String algorithmModelId = mg.getKey().isEmpty() ? null : mg.getKey(); + List modelBatchDevices = mg.getValue(); - // 封装推理请求 - InferRequest request = buildInferenceRequest(deviceType, batchDevices, absoluteModelPath, requiredFeatures); - log.debug("request={}", request); + log.info("Processing inference for algorithmType: {}, deviceType: {}, materialType: {}, algorithmModelId: {}", currentAlgoType, deviceType, currentMaterialType, algorithmModelId); - try { - // 调用Python推理服务 - InferResponse response = infer(request); - log.info("推理服务返回结果: code={}", (response != null ? response.getCode() : "null")); - - // 处理推理结果 - if (response != null && response.getCode() == 0) { - // 重新构建InferResponse对象示例 (为了兼容性保留原有逻辑) - InferResponse.InferData originalData = response.getData(); - InferResponse.InferData newData = new InferResponse.InferData(); - newData.setItems(originalData.getItems()); - newData.setMeta(originalData.getMeta()); - newData.setFeatures(originalData.getFeatures()); - newData.setKeff(originalData.getKeff()); - - InferResponse reconstructedResponse = new InferResponse(); - reconstructedResponse.setCode(response.getCode()); - reconstructedResponse.setMsg(response.getMsg()); - reconstructedResponse.setData(newData); - - processInferenceResults(projectId, scenarioId, deviceType, batchDevices, reconstructedResponse); - hasAnySuccess = true; + AlgorithmModel model; + if (algorithmModelId != null) { + try { + model = resolveModelByIdChecked(algorithmModelId, currentAlgoType, deviceType, currentMaterialType); + } catch (IllegalArgumentException ex) { + hasAnyError = true; + if (ex.getMessage() != null && !ex.getMessage().isBlank()) { + errorMessages.add(ex.getMessage()); + } + continue; + } } else { - log.error("推理服务调用失败: {}", (response != null ? response.getMsg() : "未知错误")); - hasAnyError = true; + model = algorithmModelService.getCurrentModel(currentAlgoType, deviceType, currentMaterialType); } - } catch (Exception e) { - log.error("推理异常: {}", e.getMessage(), e); - hasAnyError = true; - if (e.getMessage() != null && !e.getMessage().isBlank()) { - errorMessages.add(e.getMessage()); + + if (model == null || trimToNull(model.getModelPath()) == null) { + hasAnyError = true; + Map miss = new LinkedHashMap<>(); + miss.put("algorithmType", currentAlgoType); + miss.put("deviceType", deviceType); + miss.put("materialType", currentMaterialType); + if (algorithmModelId != null) miss.put("algorithmModelId", algorithmModelId); + missingModels.add(miss); + continue; + } + + String modelRelPath = model.getModelPath(); + List requiredFeatures = new ArrayList<>(); + if (model.getFeatureMapSnapshot() != null && !model.getFeatureMapSnapshot().isBlank()) { + try { + JsonNode fNode = objectMapper.readTree(model.getFeatureMapSnapshot()); + if (fNode.isTextual()) { + fNode = objectMapper.readTree(fNode.asText()); + } + + if (fNode.isArray()) { + for (JsonNode node : fNode) requiredFeatures.add(node.asText()); + } else if (fNode.has("input_cols") && fNode.get("input_cols").isArray()) { + for (JsonNode node : fNode.get("input_cols")) requiredFeatures.add(node.asText()); + } else if (fNode.has("features") && fNode.get("features").isArray()) { + for (JsonNode node : fNode.get("features")) requiredFeatures.add(node.asText()); + } + log.info("Parsed requiredFeatures from feature_map_snapshot: {}", requiredFeatures); + } catch (Exception e) { + log.warn("解析特征映射快照失败: {}", e.getMessage()); + } + } else { + log.warn("模型 feature_map_snapshot 为空,将不进行特征过滤"); + } + + String absoluteModelPath = Paths.get(modelRootPath).resolve(modelRelPath).toAbsolutePath().normalize().toString(); + InferRequest request = buildInferenceRequest(deviceType, modelBatchDevices, absoluteModelPath, requiredFeatures); + + try { + InferResponse response = infer(request); + log.info("推理服务返回结果: code={}", (response != null ? response.getCode() : "null")); + + if (response != null && response.getCode() == 0) { + InferResponse.InferData originalData = response.getData(); + InferResponse.InferData newData = new InferResponse.InferData(); + newData.setItems(originalData.getItems()); + newData.setMeta(originalData.getMeta()); + newData.setFeatures(originalData.getFeatures()); + newData.setKeff(originalData.getKeff()); + + InferResponse reconstructedResponse = new InferResponse(); + reconstructedResponse.setCode(response.getCode()); + reconstructedResponse.setMsg(response.getMsg()); + reconstructedResponse.setData(newData); + + processInferenceResults(projectId, scenarioId, deviceType, modelBatchDevices, reconstructedResponse); + hasAnySuccess = true; + } else { + log.error("推理服务调用失败: {}", (response != null ? response.getMsg() : "未知错误")); + hasAnyError = true; + } + } catch (Exception e) { + log.error("推理异常: {}", e.getMessage(), e); + hasAnyError = true; + if (e.getMessage() != null && !e.getMessage().isBlank()) { + errorMessages.add(e.getMessage()); + } } } } @@ -236,11 +241,97 @@ public class DeviceInferService { private List> dedupeMissingModels(List> missingModels) { LinkedHashMap> m = new LinkedHashMap<>(); for (Map x : missingModels) { - String k = String.valueOf(x.get("algorithmType")) + "|" + String.valueOf(x.get("deviceType")) + "|" + String.valueOf(x.get("materialType")); + String k = String.valueOf(x.get("algorithmType")) + "|" + String.valueOf(x.get("deviceType")) + "|" + String.valueOf(x.get("materialType")) + "|" + String.valueOf(x.get("algorithmModelId")); m.putIfAbsent(k, x); } return new ArrayList<>(m.values()); } + + private Map parseDeviceAlgoConfig(String raw, String scenarioId) { + Map out = new HashMap<>(); + if (raw == null || raw.isBlank()) return out; + + try { + List list = objectMapper.readValue(raw, new TypeReference>() {}); + if (list != null) { + for (DeviceAlgoConfigItem x : list) { + if (x == null) continue; + String did = trimToNull(x.getDeviceId()); + if (did == null) continue; + out.put(did, x); + } + return out; + } + } catch (Exception ignored) { + } + + try { + Map legacy = objectMapper.readValue(raw, new TypeReference>() {}); + if (legacy != null) { + for (Map.Entry e : legacy.entrySet()) { + String did = trimToNull(e.getKey()); + if (did == null) continue; + out.put(did, new DeviceAlgoConfigItem(did, e.getValue(), null)); + } + } + return out; + } catch (Exception e) { + log.error("解析设备算法配置失败 scenarioId={}: {}", scenarioId, e.getMessage(), e); + return out; + } + } + + private String trimToNull(String s) { + if (s == null) return null; + String x = s.trim(); + return x.isEmpty() ? null : x; + } + + private AlgorithmModel resolveModelByIdChecked(String algorithmModelId, String algorithmType, String deviceType, String materialType) { + AlgorithmModel m = algorithmModelService.getById(algorithmModelId); + if (m == null) return null; + if (trimToNull(m.getModelPath()) == null) return null; + + String at = trimToNull(m.getAlgorithmType()); + String dt = trimToNull(m.getDeviceType()); + if (at != null && algorithmType != null && !at.equals(algorithmType)) { + throw new IllegalArgumentException("指定模型版本与算法类型不匹配 algorithmModelId=" + algorithmModelId + " algorithmType=" + algorithmType + " model.algorithmType=" + at); + } + if (dt != null && deviceType != null && !dt.equals(deviceType)) { + throw new IllegalArgumentException("指定模型版本与设备类型不匹配 algorithmModelId=" + algorithmModelId + " deviceType=" + deviceType + " model.deviceType=" + dt); + } + + String mtModel = normalizeMaterialType(m.getMaterialType()); + String mtReq = normalizeMaterialType(materialType); + if (mtReq == null) { + if (mtModel != null) { + throw new IllegalArgumentException("指定模型版本与材料类型不匹配 algorithmModelId=" + algorithmModelId + " materialType=null model.materialType=" + mtModel); + } + } else { + if (mtModel != null) { + if ("Mixed".equals(mtReq)) { + if (!"Mixed".equals(mtModel)) { + throw new IllegalArgumentException("指定模型版本与材料类型不匹配 algorithmModelId=" + algorithmModelId + " materialType=" + mtReq + " model.materialType=" + mtModel); + } + } else { + if (!mtReq.equals(mtModel)) { + throw new IllegalArgumentException("指定模型版本与材料类型不匹配 algorithmModelId=" + algorithmModelId + " materialType=" + mtReq + " model.materialType=" + mtModel); + } + } + } + } + return m; + } + + private String normalizeMaterialType(String raw) { + if (raw == null) return null; + String s = raw.trim(); + if (s.isEmpty()) return null; + if ("U".equalsIgnoreCase(s)) return "U"; + if ("Pu".equalsIgnoreCase(s)) return "Pu"; + if ("Mixed".equalsIgnoreCase(s) || "MIX".equalsIgnoreCase(s)) return "Mixed"; + return s; + } private InferRequest buildInferenceRequest(String deviceType,List devices,String modelPath, List requiredFeatures) { InferRequest request = new InferRequest(); diff --git a/business-css/src/main/java/com/yfd/business/css/service/impl/ProjectServiceImpl.java b/business-css/src/main/java/com/yfd/business/css/service/impl/ProjectServiceImpl.java index e260377..dc42a9b 100644 --- a/business-css/src/main/java/com/yfd/business/css/service/impl/ProjectServiceImpl.java +++ b/business-css/src/main/java/com/yfd/business/css/service/impl/ProjectServiceImpl.java @@ -564,7 +564,7 @@ public class ProjectServiceImpl if (mnode != null) { String topoMid = optText(mnode, "materialId"); if (topoMid != null && !topoMid.isEmpty()) { - Map info = new HashMap<>(); + Map info = new java.util.LinkedHashMap<>(); String mname = optText(mnode, "name"); if (mname != null) info.put("materialName", mname); JsonNode mstatic = mnode.path("static"); @@ -572,7 +572,12 @@ public class ProjectServiceImpl if (mstatic.path("u_concentration").isNumber()) info.put("u_concentration", mstatic.path("u_concentration").numberValue()); if (mstatic.path("u_enrichment").isNumber()) info.put("u_enrichment", mstatic.path("u_enrichment").numberValue()); if (mstatic.path("pu_concentration").isNumber()) info.put("pu_concentration", mstatic.path("pu_concentration").numberValue()); - if (mstatic.path("pu_isotope").isNumber()) info.put("pu_isotope", mstatic.path("pu_isotope").numberValue()); + // if (mstatic.path("pu_isotope").isNumber()) info.put("pu_isotope", mstatic.path("pu_isotope").numberValue()); + if (mstatic.path("e_pu238").isNumber()) info.put("e_pu238", mstatic.path("e_pu238").numberValue()); + if (mstatic.path("e_pu239").isNumber()) info.put("e_pu239", mstatic.path("e_pu239").numberValue()); + if (mstatic.path("e_pu240").isNumber()) info.put("e_pu240", mstatic.path("e_pu240").numberValue()); + if (mstatic.path("e_pu241").isNumber()) info.put("e_pu241", mstatic.path("e_pu241").numberValue()); + if (mstatic.path("e_pu242").isNumber()) info.put("e_pu242", mstatic.path("e_pu242").numberValue()); } matTopo.put(topoMid, info); } @@ -589,7 +594,7 @@ public class ProjectServiceImpl for (String did : deviceIds) { Device d = devMap.get(did); if (d == null) continue; - Map row = new HashMap<>(); + Map row = new java.util.LinkedHashMap<>(); row.put("deviceId", d.getDeviceId()); row.put("deviceName", d.getName()); row.put("deviceType", d.getType()); @@ -602,12 +607,27 @@ public class ProjectServiceImpl if (mid != null) { Material m = matMap.get(mid); row.put("materialId", mid); + row.put("materialName", null); + row.put("u_concentration", null); + row.put("u_enrichment", null); + row.put("pu_concentration", null); + // row.put("pu_isotope", null); + row.put("e_pu238", null); + row.put("e_pu239", null); + row.put("e_pu240", null); + row.put("e_pu241", null); + row.put("e_pu242", null); if (m != null) { row.put("materialName", m.getName()); row.put("u_concentration", m.getUConcentration()); row.put("u_enrichment", m.getUEnrichment()); row.put("pu_concentration", m.getPuConcentration()); - row.put("pu_isotope", m.getPuIsotope()); + // row.put("pu_isotope", m.getPuIsotope()); + row.put("e_pu238", m.getEPu238()); + row.put("e_pu239", m.getEPu239()); + row.put("e_pu240", m.getEPu240()); + row.put("e_pu241", m.getEPu241()); + row.put("e_pu242", m.getEPu242()); } else { Map info = matTopo.get(mid); if (info != null) { @@ -615,7 +635,12 @@ public class ProjectServiceImpl if (info.get("u_concentration") != null) row.put("u_concentration", info.get("u_concentration")); if (info.get("u_enrichment") != null) row.put("u_enrichment", info.get("u_enrichment")); if (info.get("pu_concentration") != null) row.put("pu_concentration", info.get("pu_concentration")); - if (info.get("pu_isotope") != null) row.put("pu_isotope", info.get("pu_isotope")); + // if (info.get("pu_isotope") != null) row.put("pu_isotope", info.get("pu_isotope")); + if (info.get("e_pu238") != null) row.put("e_pu238", info.get("e_pu238")); + if (info.get("e_pu239") != null) row.put("e_pu239", info.get("e_pu239")); + if (info.get("e_pu240") != null) row.put("e_pu240", info.get("e_pu240")); + if (info.get("e_pu241") != null) row.put("e_pu241", info.get("e_pu241")); + if (info.get("e_pu242") != null) row.put("e_pu242", info.get("e_pu242")); } } }