情景配置支持多模型版本

This commit is contained in:
wanxiaoli 2026-05-21 11:59:19 +08:00
parent 75bf0af5cd
commit e191934a53
3 changed files with 252 additions and 95 deletions

View File

@ -0,0 +1,41 @@
package com.yfd.business.css.model;
public class DeviceAlgoConfigItem {
private String deviceId;
private String algorithmType;
private String algorithmModelId;
public DeviceAlgoConfigItem() {
}
public DeviceAlgoConfigItem(String deviceId, String algorithmType, String algorithmModelId) {
this.deviceId = deviceId;
this.algorithmType = algorithmType;
this.algorithmModelId = algorithmModelId;
}
public String getDeviceId() {
return deviceId;
}
public void setDeviceId(String deviceId) {
this.deviceId = deviceId;
}
public String getAlgorithmType() {
return algorithmType;
}
public void setAlgorithmType(String algorithmType) {
this.algorithmType = algorithmType;
}
public String getAlgorithmModelId() {
return algorithmModelId;
}
public void setAlgorithmModelId(String algorithmModelId) {
this.algorithmModelId = algorithmModelId;
}
}

View File

@ -8,6 +8,7 @@ import com.yfd.business.css.domain.ScenarioResult;
import com.yfd.business.css.domain.AlgorithmModel; import com.yfd.business.css.domain.AlgorithmModel;
import com.yfd.business.css.common.exception.ScenarioInferException; import com.yfd.business.css.common.exception.ScenarioInferException;
import com.yfd.business.css.model.DeviceStepInfo; import com.yfd.business.css.model.DeviceStepInfo;
import com.yfd.business.css.model.DeviceAlgoConfigItem;
import com.yfd.business.css.model.InferRequest; import com.yfd.business.css.model.InferRequest;
import com.yfd.business.css.model.InferResponse; import com.yfd.business.css.model.InferResponse;
@ -69,15 +70,7 @@ public class DeviceInferService {
globalAlgorithmType = "GPR"; globalAlgorithmType = "GPR";
} }
// 解析设备级算法配置 Map<String, DeviceAlgoConfigItem> deviceAlgoConfig = parseDeviceAlgoConfig(scenario.getDeviceAlgoConfig(), scenarioId);
Map<String, String> deviceAlgoConfig = new HashMap<>();
if (scenario.getDeviceAlgoConfig() != null && !scenario.getDeviceAlgoConfig().isBlank()) {
try {
deviceAlgoConfig = objectMapper.readValue(scenario.getDeviceAlgoConfig(), new TypeReference<Map<String, String>>() {});
} catch (Exception e) {
log.error("解析设备算法配置失败: {}", e.getMessage(), e);
}
}
// 2. 遍历每个设备类型组 // 2. 遍历每个设备类型组
for (Map.Entry<String, List<DeviceStepInfo>> entry : groupedDevices.entrySet()) { for (Map.Entry<String, List<DeviceStepInfo>> entry : groupedDevices.entrySet()) {
@ -86,12 +79,11 @@ public class DeviceInferService {
if (devices == null || devices.isEmpty()) continue; if (devices == null || devices.isEmpty()) continue;
// 3. 将同一类型的设备按算法类型进行二级分组
Map<String, List<DeviceStepInfo>> algoGroup = new HashMap<>(); Map<String, List<DeviceStepInfo>> algoGroup = new HashMap<>();
for (DeviceStepInfo device : devices) { for (DeviceStepInfo device : devices) {
// 优先使用设备特定配置否则使用全局配置 DeviceAlgoConfigItem cfg = deviceAlgoConfig.get(device.getDeviceId());
String algoType = deviceAlgoConfig.getOrDefault(device.getDeviceId(), globalAlgorithmType); String algoType = (cfg != null && trimToNull(cfg.getAlgorithmType()) != null) ? trimToNull(cfg.getAlgorithmType()) : globalAlgorithmType;
algoGroup.computeIfAbsent(algoType, k -> new ArrayList<>()).add(device); algoGroup.computeIfAbsent(algoType, k -> new ArrayList<>()).add(device);
} }
@ -113,29 +105,50 @@ public class DeviceInferService {
String currentMaterialType = matEntry.getKey(); String currentMaterialType = matEntry.getKey();
List<DeviceStepInfo> batchDevices = matEntry.getValue(); List<DeviceStepInfo> batchDevices = matEntry.getValue();
// 获取模型对象根据算法类型设备类型材料类型 Map<String, List<DeviceStepInfo>> modelGroup = new HashMap<>();
log.info("Processing inference for algorithmType: {}, deviceType: {}, materialType: {}", currentAlgoType, deviceType, currentMaterialType); for (DeviceStepInfo d : batchDevices) {
AlgorithmModel model = algorithmModelService.getCurrentModel(currentAlgoType, deviceType, currentMaterialType); DeviceAlgoConfigItem cfg = deviceAlgoConfig.get(d.getDeviceId());
String modelId = (cfg != null) ? trimToNull(cfg.getAlgorithmModelId()) : null;
String k = (modelId == null) ? "" : modelId;
modelGroup.computeIfAbsent(k, x -> new ArrayList<>()).add(d);
}
if (model == null || model.getModelPath() == null) { for (Map.Entry<String, List<DeviceStepInfo>> mg : modelGroup.entrySet()) {
log.error("Model path not found for algorithmType: {}, deviceType: {}, materialType: {}", currentAlgoType, deviceType, currentMaterialType); String algorithmModelId = mg.getKey().isEmpty() ? null : mg.getKey();
List<DeviceStepInfo> modelBatchDevices = mg.getValue();
log.info("Processing inference for algorithmType: {}, deviceType: {}, materialType: {}, algorithmModelId: {}", currentAlgoType, deviceType, currentMaterialType, algorithmModelId);
AlgorithmModel model;
if (algorithmModelId != null) {
try {
model = resolveModelByIdChecked(algorithmModelId, currentAlgoType, deviceType, currentMaterialType);
} catch (IllegalArgumentException ex) {
hasAnyError = true; hasAnyError = true;
missingModels.add(Map.of( if (ex.getMessage() != null && !ex.getMessage().isBlank()) {
"algorithmType", currentAlgoType, errorMessages.add(ex.getMessage());
"deviceType", deviceType, }
"materialType", currentMaterialType continue;
)); }
} else {
model = algorithmModelService.getCurrentModel(currentAlgoType, deviceType, currentMaterialType);
}
if (model == null || trimToNull(model.getModelPath()) == null) {
hasAnyError = true;
Map<String, Object> miss = new LinkedHashMap<>();
miss.put("algorithmType", currentAlgoType);
miss.put("deviceType", deviceType);
miss.put("materialType", currentMaterialType);
if (algorithmModelId != null) miss.put("algorithmModelId", algorithmModelId);
missingModels.add(miss);
continue; continue;
} }
String modelRelPath = model.getModelPath(); String modelRelPath = model.getModelPath();
log.debug("modelRelPath={}", modelRelPath);
// 解析模型的特征映射feature_map_snapshot优先以 input_cols 为准进行特征过滤
List<String> requiredFeatures = new ArrayList<>(); List<String> requiredFeatures = new ArrayList<>();
if (model.getFeatureMapSnapshot() != null && !model.getFeatureMapSnapshot().isBlank()) { if (model.getFeatureMapSnapshot() != null && !model.getFeatureMapSnapshot().isBlank()) {
try { try {
// 由于数据库里存的可能是转义的 JSON 字符串我们需要多次解析直到它变成真正的 Object
JsonNode fNode = objectMapper.readTree(model.getFeatureMapSnapshot()); JsonNode fNode = objectMapper.readTree(model.getFeatureMapSnapshot());
if (fNode.isTextual()) { if (fNode.isTextual()) {
fNode = objectMapper.readTree(fNode.asText()); fNode = objectMapper.readTree(fNode.asText());
@ -146,7 +159,6 @@ public class DeviceInferService {
} else if (fNode.has("input_cols") && fNode.get("input_cols").isArray()) { } else if (fNode.has("input_cols") && fNode.get("input_cols").isArray()) {
for (JsonNode node : fNode.get("input_cols")) requiredFeatures.add(node.asText()); for (JsonNode node : fNode.get("input_cols")) requiredFeatures.add(node.asText());
} else if (fNode.has("features") && fNode.get("features").isArray()) { } else if (fNode.has("features") && fNode.get("features").isArray()) {
// 兼容旧版本
for (JsonNode node : fNode.get("features")) requiredFeatures.add(node.asText()); for (JsonNode node : fNode.get("features")) requiredFeatures.add(node.asText());
} }
log.info("Parsed requiredFeatures from feature_map_snapshot: {}", requiredFeatures); log.info("Parsed requiredFeatures from feature_map_snapshot: {}", requiredFeatures);
@ -157,22 +169,14 @@ public class DeviceInferService {
log.warn("模型 feature_map_snapshot 为空,将不进行特征过滤"); log.warn("模型 feature_map_snapshot 为空,将不进行特征过滤");
} }
// 将相对路径转换为绝对路径
String absoluteModelPath = Paths.get(modelRootPath).resolve(modelRelPath).toAbsolutePath().normalize().toString(); String absoluteModelPath = Paths.get(modelRootPath).resolve(modelRelPath).toAbsolutePath().normalize().toString();
log.debug("Absolute modelPath={}", absoluteModelPath); InferRequest request = buildInferenceRequest(deviceType, modelBatchDevices, absoluteModelPath, requiredFeatures);
// 封装推理请求
InferRequest request = buildInferenceRequest(deviceType, batchDevices, absoluteModelPath, requiredFeatures);
log.debug("request={}", request);
try { try {
// 调用Python推理服务
InferResponse response = infer(request); InferResponse response = infer(request);
log.info("推理服务返回结果: code={}", (response != null ? response.getCode() : "null")); log.info("推理服务返回结果: code={}", (response != null ? response.getCode() : "null"));
// 处理推理结果
if (response != null && response.getCode() == 0) { if (response != null && response.getCode() == 0) {
// 重新构建InferResponse对象示例 (为了兼容性保留原有逻辑)
InferResponse.InferData originalData = response.getData(); InferResponse.InferData originalData = response.getData();
InferResponse.InferData newData = new InferResponse.InferData(); InferResponse.InferData newData = new InferResponse.InferData();
newData.setItems(originalData.getItems()); newData.setItems(originalData.getItems());
@ -185,7 +189,7 @@ public class DeviceInferService {
reconstructedResponse.setMsg(response.getMsg()); reconstructedResponse.setMsg(response.getMsg());
reconstructedResponse.setData(newData); reconstructedResponse.setData(newData);
processInferenceResults(projectId, scenarioId, deviceType, batchDevices, reconstructedResponse); processInferenceResults(projectId, scenarioId, deviceType, modelBatchDevices, reconstructedResponse);
hasAnySuccess = true; hasAnySuccess = true;
} else { } else {
log.error("推理服务调用失败: {}", (response != null ? response.getMsg() : "未知错误")); log.error("推理服务调用失败: {}", (response != null ? response.getMsg() : "未知错误"));
@ -201,6 +205,7 @@ public class DeviceInferService {
} }
} }
} }
}
// 5. 最终检查如果没有任何成功的推理且发生过错误抛出异常以通知上层 // 5. 最终检查如果没有任何成功的推理且发生过错误抛出异常以通知上层
if (!hasAnySuccess && hasAnyError) { if (!hasAnySuccess && hasAnyError) {
@ -236,12 +241,98 @@ public class DeviceInferService {
private List<Map<String, Object>> dedupeMissingModels(List<Map<String, Object>> missingModels) { private List<Map<String, Object>> dedupeMissingModels(List<Map<String, Object>> missingModels) {
LinkedHashMap<String, Map<String, Object>> m = new LinkedHashMap<>(); LinkedHashMap<String, Map<String, Object>> m = new LinkedHashMap<>();
for (Map<String, Object> x : missingModels) { for (Map<String, Object> x : missingModels) {
String k = String.valueOf(x.get("algorithmType")) + "|" + String.valueOf(x.get("deviceType")) + "|" + String.valueOf(x.get("materialType")); String k = String.valueOf(x.get("algorithmType")) + "|" + String.valueOf(x.get("deviceType")) + "|" + String.valueOf(x.get("materialType")) + "|" + String.valueOf(x.get("algorithmModelId"));
m.putIfAbsent(k, x); m.putIfAbsent(k, x);
} }
return new ArrayList<>(m.values()); return new ArrayList<>(m.values());
} }
private Map<String, DeviceAlgoConfigItem> parseDeviceAlgoConfig(String raw, String scenarioId) {
Map<String, DeviceAlgoConfigItem> out = new HashMap<>();
if (raw == null || raw.isBlank()) return out;
try {
List<DeviceAlgoConfigItem> list = objectMapper.readValue(raw, new TypeReference<List<DeviceAlgoConfigItem>>() {});
if (list != null) {
for (DeviceAlgoConfigItem x : list) {
if (x == null) continue;
String did = trimToNull(x.getDeviceId());
if (did == null) continue;
out.put(did, x);
}
return out;
}
} catch (Exception ignored) {
}
try {
Map<String, String> legacy = objectMapper.readValue(raw, new TypeReference<Map<String, String>>() {});
if (legacy != null) {
for (Map.Entry<String, String> e : legacy.entrySet()) {
String did = trimToNull(e.getKey());
if (did == null) continue;
out.put(did, new DeviceAlgoConfigItem(did, e.getValue(), null));
}
}
return out;
} catch (Exception e) {
log.error("解析设备算法配置失败 scenarioId={}: {}", scenarioId, e.getMessage(), e);
return out;
}
}
private String trimToNull(String s) {
if (s == null) return null;
String x = s.trim();
return x.isEmpty() ? null : x;
}
private AlgorithmModel resolveModelByIdChecked(String algorithmModelId, String algorithmType, String deviceType, String materialType) {
AlgorithmModel m = algorithmModelService.getById(algorithmModelId);
if (m == null) return null;
if (trimToNull(m.getModelPath()) == null) return null;
String at = trimToNull(m.getAlgorithmType());
String dt = trimToNull(m.getDeviceType());
if (at != null && algorithmType != null && !at.equals(algorithmType)) {
throw new IllegalArgumentException("指定模型版本与算法类型不匹配 algorithmModelId=" + algorithmModelId + " algorithmType=" + algorithmType + " model.algorithmType=" + at);
}
if (dt != null && deviceType != null && !dt.equals(deviceType)) {
throw new IllegalArgumentException("指定模型版本与设备类型不匹配 algorithmModelId=" + algorithmModelId + " deviceType=" + deviceType + " model.deviceType=" + dt);
}
String mtModel = normalizeMaterialType(m.getMaterialType());
String mtReq = normalizeMaterialType(materialType);
if (mtReq == null) {
if (mtModel != null) {
throw new IllegalArgumentException("指定模型版本与材料类型不匹配 algorithmModelId=" + algorithmModelId + " materialType=null model.materialType=" + mtModel);
}
} else {
if (mtModel != null) {
if ("Mixed".equals(mtReq)) {
if (!"Mixed".equals(mtModel)) {
throw new IllegalArgumentException("指定模型版本与材料类型不匹配 algorithmModelId=" + algorithmModelId + " materialType=" + mtReq + " model.materialType=" + mtModel);
}
} else {
if (!mtReq.equals(mtModel)) {
throw new IllegalArgumentException("指定模型版本与材料类型不匹配 algorithmModelId=" + algorithmModelId + " materialType=" + mtReq + " model.materialType=" + mtModel);
}
}
}
}
return m;
}
private String normalizeMaterialType(String raw) {
if (raw == null) return null;
String s = raw.trim();
if (s.isEmpty()) return null;
if ("U".equalsIgnoreCase(s)) return "U";
if ("Pu".equalsIgnoreCase(s)) return "Pu";
if ("Mixed".equalsIgnoreCase(s) || "MIX".equalsIgnoreCase(s)) return "Mixed";
return s;
}
private InferRequest buildInferenceRequest(String deviceType,List<DeviceStepInfo> devices,String modelPath, List<String> requiredFeatures) { private InferRequest buildInferenceRequest(String deviceType,List<DeviceStepInfo> devices,String modelPath, List<String> requiredFeatures) {
InferRequest request = new InferRequest(); InferRequest request = new InferRequest();
request.setModelDir(modelPath); // 设置模型路径 request.setModelDir(modelPath); // 设置模型路径

View File

@ -564,7 +564,7 @@ public class ProjectServiceImpl
if (mnode != null) { if (mnode != null) {
String topoMid = optText(mnode, "materialId"); String topoMid = optText(mnode, "materialId");
if (topoMid != null && !topoMid.isEmpty()) { if (topoMid != null && !topoMid.isEmpty()) {
Map<String, Object> info = new HashMap<>(); Map<String, Object> info = new java.util.LinkedHashMap<>();
String mname = optText(mnode, "name"); String mname = optText(mnode, "name");
if (mname != null) info.put("materialName", mname); if (mname != null) info.put("materialName", mname);
JsonNode mstatic = mnode.path("static"); JsonNode mstatic = mnode.path("static");
@ -572,7 +572,12 @@ public class ProjectServiceImpl
if (mstatic.path("u_concentration").isNumber()) info.put("u_concentration", mstatic.path("u_concentration").numberValue()); if (mstatic.path("u_concentration").isNumber()) info.put("u_concentration", mstatic.path("u_concentration").numberValue());
if (mstatic.path("u_enrichment").isNumber()) info.put("u_enrichment", mstatic.path("u_enrichment").numberValue()); if (mstatic.path("u_enrichment").isNumber()) info.put("u_enrichment", mstatic.path("u_enrichment").numberValue());
if (mstatic.path("pu_concentration").isNumber()) info.put("pu_concentration", mstatic.path("pu_concentration").numberValue()); if (mstatic.path("pu_concentration").isNumber()) info.put("pu_concentration", mstatic.path("pu_concentration").numberValue());
if (mstatic.path("pu_isotope").isNumber()) info.put("pu_isotope", mstatic.path("pu_isotope").numberValue()); // if (mstatic.path("pu_isotope").isNumber()) info.put("pu_isotope", mstatic.path("pu_isotope").numberValue());
if (mstatic.path("e_pu238").isNumber()) info.put("e_pu238", mstatic.path("e_pu238").numberValue());
if (mstatic.path("e_pu239").isNumber()) info.put("e_pu239", mstatic.path("e_pu239").numberValue());
if (mstatic.path("e_pu240").isNumber()) info.put("e_pu240", mstatic.path("e_pu240").numberValue());
if (mstatic.path("e_pu241").isNumber()) info.put("e_pu241", mstatic.path("e_pu241").numberValue());
if (mstatic.path("e_pu242").isNumber()) info.put("e_pu242", mstatic.path("e_pu242").numberValue());
} }
matTopo.put(topoMid, info); matTopo.put(topoMid, info);
} }
@ -589,7 +594,7 @@ public class ProjectServiceImpl
for (String did : deviceIds) { for (String did : deviceIds) {
Device d = devMap.get(did); Device d = devMap.get(did);
if (d == null) continue; if (d == null) continue;
Map<String, Object> row = new HashMap<>(); Map<String, Object> row = new java.util.LinkedHashMap<>();
row.put("deviceId", d.getDeviceId()); row.put("deviceId", d.getDeviceId());
row.put("deviceName", d.getName()); row.put("deviceName", d.getName());
row.put("deviceType", d.getType()); row.put("deviceType", d.getType());
@ -602,12 +607,27 @@ public class ProjectServiceImpl
if (mid != null) { if (mid != null) {
Material m = matMap.get(mid); Material m = matMap.get(mid);
row.put("materialId", mid); row.put("materialId", mid);
row.put("materialName", null);
row.put("u_concentration", null);
row.put("u_enrichment", null);
row.put("pu_concentration", null);
// row.put("pu_isotope", null);
row.put("e_pu238", null);
row.put("e_pu239", null);
row.put("e_pu240", null);
row.put("e_pu241", null);
row.put("e_pu242", null);
if (m != null) { if (m != null) {
row.put("materialName", m.getName()); row.put("materialName", m.getName());
row.put("u_concentration", m.getUConcentration()); row.put("u_concentration", m.getUConcentration());
row.put("u_enrichment", m.getUEnrichment()); row.put("u_enrichment", m.getUEnrichment());
row.put("pu_concentration", m.getPuConcentration()); row.put("pu_concentration", m.getPuConcentration());
row.put("pu_isotope", m.getPuIsotope()); // row.put("pu_isotope", m.getPuIsotope());
row.put("e_pu238", m.getEPu238());
row.put("e_pu239", m.getEPu239());
row.put("e_pu240", m.getEPu240());
row.put("e_pu241", m.getEPu241());
row.put("e_pu242", m.getEPu242());
} else { } else {
Map<String, Object> info = matTopo.get(mid); Map<String, Object> info = matTopo.get(mid);
if (info != null) { if (info != null) {
@ -615,7 +635,12 @@ public class ProjectServiceImpl
if (info.get("u_concentration") != null) row.put("u_concentration", info.get("u_concentration")); if (info.get("u_concentration") != null) row.put("u_concentration", info.get("u_concentration"));
if (info.get("u_enrichment") != null) row.put("u_enrichment", info.get("u_enrichment")); if (info.get("u_enrichment") != null) row.put("u_enrichment", info.get("u_enrichment"));
if (info.get("pu_concentration") != null) row.put("pu_concentration", info.get("pu_concentration")); if (info.get("pu_concentration") != null) row.put("pu_concentration", info.get("pu_concentration"));
if (info.get("pu_isotope") != null) row.put("pu_isotope", info.get("pu_isotope")); // if (info.get("pu_isotope") != null) row.put("pu_isotope", info.get("pu_isotope"));
if (info.get("e_pu238") != null) row.put("e_pu238", info.get("e_pu238"));
if (info.get("e_pu239") != null) row.put("e_pu239", info.get("e_pu239"));
if (info.get("e_pu240") != null) row.put("e_pu240", info.get("e_pu240"));
if (info.get("e_pu241") != null) row.put("e_pu241", info.get("e_pu241"));
if (info.get("e_pu242") != null) row.put("e_pu242", info.get("e_pu242"));
} }
} }
} }