情景配置支持多模型版本
This commit is contained in:
parent
75bf0af5cd
commit
e191934a53
@ -0,0 +1,41 @@
|
||||
package com.yfd.business.css.model;
|
||||
|
||||
public class DeviceAlgoConfigItem {
|
||||
private String deviceId;
|
||||
private String algorithmType;
|
||||
private String algorithmModelId;
|
||||
|
||||
public DeviceAlgoConfigItem() {
|
||||
}
|
||||
|
||||
public DeviceAlgoConfigItem(String deviceId, String algorithmType, String algorithmModelId) {
|
||||
this.deviceId = deviceId;
|
||||
this.algorithmType = algorithmType;
|
||||
this.algorithmModelId = algorithmModelId;
|
||||
}
|
||||
|
||||
public String getDeviceId() {
|
||||
return deviceId;
|
||||
}
|
||||
|
||||
public void setDeviceId(String deviceId) {
|
||||
this.deviceId = deviceId;
|
||||
}
|
||||
|
||||
public String getAlgorithmType() {
|
||||
return algorithmType;
|
||||
}
|
||||
|
||||
public void setAlgorithmType(String algorithmType) {
|
||||
this.algorithmType = algorithmType;
|
||||
}
|
||||
|
||||
public String getAlgorithmModelId() {
|
||||
return algorithmModelId;
|
||||
}
|
||||
|
||||
public void setAlgorithmModelId(String algorithmModelId) {
|
||||
this.algorithmModelId = algorithmModelId;
|
||||
}
|
||||
}
|
||||
|
||||
@ -8,6 +8,7 @@ import com.yfd.business.css.domain.ScenarioResult;
|
||||
import com.yfd.business.css.domain.AlgorithmModel;
|
||||
import com.yfd.business.css.common.exception.ScenarioInferException;
|
||||
import com.yfd.business.css.model.DeviceStepInfo;
|
||||
import com.yfd.business.css.model.DeviceAlgoConfigItem;
|
||||
import com.yfd.business.css.model.InferRequest;
|
||||
import com.yfd.business.css.model.InferResponse;
|
||||
|
||||
@ -69,15 +70,7 @@ public class DeviceInferService {
|
||||
globalAlgorithmType = "GPR";
|
||||
}
|
||||
|
||||
// 解析设备级算法配置
|
||||
Map<String, String> deviceAlgoConfig = new HashMap<>();
|
||||
if (scenario.getDeviceAlgoConfig() != null && !scenario.getDeviceAlgoConfig().isBlank()) {
|
||||
try {
|
||||
deviceAlgoConfig = objectMapper.readValue(scenario.getDeviceAlgoConfig(), new TypeReference<Map<String, String>>() {});
|
||||
} catch (Exception e) {
|
||||
log.error("解析设备算法配置失败: {}", e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
Map<String, DeviceAlgoConfigItem> deviceAlgoConfig = parseDeviceAlgoConfig(scenario.getDeviceAlgoConfig(), scenarioId);
|
||||
|
||||
// 2. 遍历每个设备类型组
|
||||
for (Map.Entry<String, List<DeviceStepInfo>> entry : groupedDevices.entrySet()) {
|
||||
@ -86,12 +79,11 @@ public class DeviceInferService {
|
||||
|
||||
if (devices == null || devices.isEmpty()) continue;
|
||||
|
||||
// 3. 将同一类型的设备,按算法类型进行二级分组
|
||||
Map<String, List<DeviceStepInfo>> algoGroup = new HashMap<>();
|
||||
|
||||
for (DeviceStepInfo device : devices) {
|
||||
// 优先使用设备特定配置,否则使用全局配置
|
||||
String algoType = deviceAlgoConfig.getOrDefault(device.getDeviceId(), globalAlgorithmType);
|
||||
DeviceAlgoConfigItem cfg = deviceAlgoConfig.get(device.getDeviceId());
|
||||
String algoType = (cfg != null && trimToNull(cfg.getAlgorithmType()) != null) ? trimToNull(cfg.getAlgorithmType()) : globalAlgorithmType;
|
||||
algoGroup.computeIfAbsent(algoType, k -> new ArrayList<>()).add(device);
|
||||
}
|
||||
|
||||
@ -113,29 +105,50 @@ public class DeviceInferService {
|
||||
String currentMaterialType = matEntry.getKey();
|
||||
List<DeviceStepInfo> batchDevices = matEntry.getValue();
|
||||
|
||||
// 获取模型对象(根据算法类型、设备类型、材料类型)
|
||||
log.info("Processing inference for algorithmType: {}, deviceType: {}, materialType: {}", currentAlgoType, deviceType, currentMaterialType);
|
||||
AlgorithmModel model = algorithmModelService.getCurrentModel(currentAlgoType, deviceType, currentMaterialType);
|
||||
Map<String, List<DeviceStepInfo>> modelGroup = new HashMap<>();
|
||||
for (DeviceStepInfo d : batchDevices) {
|
||||
DeviceAlgoConfigItem cfg = deviceAlgoConfig.get(d.getDeviceId());
|
||||
String modelId = (cfg != null) ? trimToNull(cfg.getAlgorithmModelId()) : null;
|
||||
String k = (modelId == null) ? "" : modelId;
|
||||
modelGroup.computeIfAbsent(k, x -> new ArrayList<>()).add(d);
|
||||
}
|
||||
|
||||
if (model == null || model.getModelPath() == null) {
|
||||
log.error("Model path not found for algorithmType: {}, deviceType: {}, materialType: {}", currentAlgoType, deviceType, currentMaterialType);
|
||||
for (Map.Entry<String, List<DeviceStepInfo>> mg : modelGroup.entrySet()) {
|
||||
String algorithmModelId = mg.getKey().isEmpty() ? null : mg.getKey();
|
||||
List<DeviceStepInfo> modelBatchDevices = mg.getValue();
|
||||
|
||||
log.info("Processing inference for algorithmType: {}, deviceType: {}, materialType: {}, algorithmModelId: {}", currentAlgoType, deviceType, currentMaterialType, algorithmModelId);
|
||||
|
||||
AlgorithmModel model;
|
||||
if (algorithmModelId != null) {
|
||||
try {
|
||||
model = resolveModelByIdChecked(algorithmModelId, currentAlgoType, deviceType, currentMaterialType);
|
||||
} catch (IllegalArgumentException ex) {
|
||||
hasAnyError = true;
|
||||
missingModels.add(Map.of(
|
||||
"algorithmType", currentAlgoType,
|
||||
"deviceType", deviceType,
|
||||
"materialType", currentMaterialType
|
||||
));
|
||||
if (ex.getMessage() != null && !ex.getMessage().isBlank()) {
|
||||
errorMessages.add(ex.getMessage());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
model = algorithmModelService.getCurrentModel(currentAlgoType, deviceType, currentMaterialType);
|
||||
}
|
||||
|
||||
if (model == null || trimToNull(model.getModelPath()) == null) {
|
||||
hasAnyError = true;
|
||||
Map<String, Object> miss = new LinkedHashMap<>();
|
||||
miss.put("algorithmType", currentAlgoType);
|
||||
miss.put("deviceType", deviceType);
|
||||
miss.put("materialType", currentMaterialType);
|
||||
if (algorithmModelId != null) miss.put("algorithmModelId", algorithmModelId);
|
||||
missingModels.add(miss);
|
||||
continue;
|
||||
}
|
||||
|
||||
String modelRelPath = model.getModelPath();
|
||||
log.debug("modelRelPath={}", modelRelPath);
|
||||
|
||||
// 解析模型的特征映射(feature_map_snapshot),优先以 input_cols 为准进行特征过滤
|
||||
List<String> requiredFeatures = new ArrayList<>();
|
||||
if (model.getFeatureMapSnapshot() != null && !model.getFeatureMapSnapshot().isBlank()) {
|
||||
try {
|
||||
// 由于数据库里存的可能是转义的 JSON 字符串,我们需要多次解析直到它变成真正的 Object
|
||||
JsonNode fNode = objectMapper.readTree(model.getFeatureMapSnapshot());
|
||||
if (fNode.isTextual()) {
|
||||
fNode = objectMapper.readTree(fNode.asText());
|
||||
@ -146,7 +159,6 @@ public class DeviceInferService {
|
||||
} else if (fNode.has("input_cols") && fNode.get("input_cols").isArray()) {
|
||||
for (JsonNode node : fNode.get("input_cols")) requiredFeatures.add(node.asText());
|
||||
} else if (fNode.has("features") && fNode.get("features").isArray()) {
|
||||
// 兼容旧版本
|
||||
for (JsonNode node : fNode.get("features")) requiredFeatures.add(node.asText());
|
||||
}
|
||||
log.info("Parsed requiredFeatures from feature_map_snapshot: {}", requiredFeatures);
|
||||
@ -157,22 +169,14 @@ public class DeviceInferService {
|
||||
log.warn("模型 feature_map_snapshot 为空,将不进行特征过滤");
|
||||
}
|
||||
|
||||
// 将相对路径转换为绝对路径
|
||||
String absoluteModelPath = Paths.get(modelRootPath).resolve(modelRelPath).toAbsolutePath().normalize().toString();
|
||||
log.debug("Absolute modelPath={}", absoluteModelPath);
|
||||
|
||||
// 封装推理请求
|
||||
InferRequest request = buildInferenceRequest(deviceType, batchDevices, absoluteModelPath, requiredFeatures);
|
||||
log.debug("request={}", request);
|
||||
InferRequest request = buildInferenceRequest(deviceType, modelBatchDevices, absoluteModelPath, requiredFeatures);
|
||||
|
||||
try {
|
||||
// 调用Python推理服务
|
||||
InferResponse response = infer(request);
|
||||
log.info("推理服务返回结果: code={}", (response != null ? response.getCode() : "null"));
|
||||
|
||||
// 处理推理结果
|
||||
if (response != null && response.getCode() == 0) {
|
||||
// 重新构建InferResponse对象示例 (为了兼容性保留原有逻辑)
|
||||
InferResponse.InferData originalData = response.getData();
|
||||
InferResponse.InferData newData = new InferResponse.InferData();
|
||||
newData.setItems(originalData.getItems());
|
||||
@ -185,7 +189,7 @@ public class DeviceInferService {
|
||||
reconstructedResponse.setMsg(response.getMsg());
|
||||
reconstructedResponse.setData(newData);
|
||||
|
||||
processInferenceResults(projectId, scenarioId, deviceType, batchDevices, reconstructedResponse);
|
||||
processInferenceResults(projectId, scenarioId, deviceType, modelBatchDevices, reconstructedResponse);
|
||||
hasAnySuccess = true;
|
||||
} else {
|
||||
log.error("推理服务调用失败: {}", (response != null ? response.getMsg() : "未知错误"));
|
||||
@ -201,6 +205,7 @@ public class DeviceInferService {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 最终检查:如果没有任何成功的推理,且发生过错误,抛出异常以通知上层
|
||||
if (!hasAnySuccess && hasAnyError) {
|
||||
@ -236,12 +241,98 @@ public class DeviceInferService {
|
||||
private List<Map<String, Object>> dedupeMissingModels(List<Map<String, Object>> missingModels) {
|
||||
LinkedHashMap<String, Map<String, Object>> m = new LinkedHashMap<>();
|
||||
for (Map<String, Object> x : missingModels) {
|
||||
String k = String.valueOf(x.get("algorithmType")) + "|" + String.valueOf(x.get("deviceType")) + "|" + String.valueOf(x.get("materialType"));
|
||||
String k = String.valueOf(x.get("algorithmType")) + "|" + String.valueOf(x.get("deviceType")) + "|" + String.valueOf(x.get("materialType")) + "|" + String.valueOf(x.get("algorithmModelId"));
|
||||
m.putIfAbsent(k, x);
|
||||
}
|
||||
return new ArrayList<>(m.values());
|
||||
}
|
||||
|
||||
private Map<String, DeviceAlgoConfigItem> parseDeviceAlgoConfig(String raw, String scenarioId) {
|
||||
Map<String, DeviceAlgoConfigItem> out = new HashMap<>();
|
||||
if (raw == null || raw.isBlank()) return out;
|
||||
|
||||
try {
|
||||
List<DeviceAlgoConfigItem> list = objectMapper.readValue(raw, new TypeReference<List<DeviceAlgoConfigItem>>() {});
|
||||
if (list != null) {
|
||||
for (DeviceAlgoConfigItem x : list) {
|
||||
if (x == null) continue;
|
||||
String did = trimToNull(x.getDeviceId());
|
||||
if (did == null) continue;
|
||||
out.put(did, x);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
} catch (Exception ignored) {
|
||||
}
|
||||
|
||||
try {
|
||||
Map<String, String> legacy = objectMapper.readValue(raw, new TypeReference<Map<String, String>>() {});
|
||||
if (legacy != null) {
|
||||
for (Map.Entry<String, String> e : legacy.entrySet()) {
|
||||
String did = trimToNull(e.getKey());
|
||||
if (did == null) continue;
|
||||
out.put(did, new DeviceAlgoConfigItem(did, e.getValue(), null));
|
||||
}
|
||||
}
|
||||
return out;
|
||||
} catch (Exception e) {
|
||||
log.error("解析设备算法配置失败 scenarioId={}: {}", scenarioId, e.getMessage(), e);
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
private String trimToNull(String s) {
|
||||
if (s == null) return null;
|
||||
String x = s.trim();
|
||||
return x.isEmpty() ? null : x;
|
||||
}
|
||||
|
||||
private AlgorithmModel resolveModelByIdChecked(String algorithmModelId, String algorithmType, String deviceType, String materialType) {
|
||||
AlgorithmModel m = algorithmModelService.getById(algorithmModelId);
|
||||
if (m == null) return null;
|
||||
if (trimToNull(m.getModelPath()) == null) return null;
|
||||
|
||||
String at = trimToNull(m.getAlgorithmType());
|
||||
String dt = trimToNull(m.getDeviceType());
|
||||
if (at != null && algorithmType != null && !at.equals(algorithmType)) {
|
||||
throw new IllegalArgumentException("指定模型版本与算法类型不匹配 algorithmModelId=" + algorithmModelId + " algorithmType=" + algorithmType + " model.algorithmType=" + at);
|
||||
}
|
||||
if (dt != null && deviceType != null && !dt.equals(deviceType)) {
|
||||
throw new IllegalArgumentException("指定模型版本与设备类型不匹配 algorithmModelId=" + algorithmModelId + " deviceType=" + deviceType + " model.deviceType=" + dt);
|
||||
}
|
||||
|
||||
String mtModel = normalizeMaterialType(m.getMaterialType());
|
||||
String mtReq = normalizeMaterialType(materialType);
|
||||
if (mtReq == null) {
|
||||
if (mtModel != null) {
|
||||
throw new IllegalArgumentException("指定模型版本与材料类型不匹配 algorithmModelId=" + algorithmModelId + " materialType=null model.materialType=" + mtModel);
|
||||
}
|
||||
} else {
|
||||
if (mtModel != null) {
|
||||
if ("Mixed".equals(mtReq)) {
|
||||
if (!"Mixed".equals(mtModel)) {
|
||||
throw new IllegalArgumentException("指定模型版本与材料类型不匹配 algorithmModelId=" + algorithmModelId + " materialType=" + mtReq + " model.materialType=" + mtModel);
|
||||
}
|
||||
} else {
|
||||
if (!mtReq.equals(mtModel)) {
|
||||
throw new IllegalArgumentException("指定模型版本与材料类型不匹配 algorithmModelId=" + algorithmModelId + " materialType=" + mtReq + " model.materialType=" + mtModel);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return m;
|
||||
}
|
||||
|
||||
private String normalizeMaterialType(String raw) {
|
||||
if (raw == null) return null;
|
||||
String s = raw.trim();
|
||||
if (s.isEmpty()) return null;
|
||||
if ("U".equalsIgnoreCase(s)) return "U";
|
||||
if ("Pu".equalsIgnoreCase(s)) return "Pu";
|
||||
if ("Mixed".equalsIgnoreCase(s) || "MIX".equalsIgnoreCase(s)) return "Mixed";
|
||||
return s;
|
||||
}
|
||||
|
||||
private InferRequest buildInferenceRequest(String deviceType,List<DeviceStepInfo> devices,String modelPath, List<String> requiredFeatures) {
|
||||
InferRequest request = new InferRequest();
|
||||
request.setModelDir(modelPath); // 设置模型路径
|
||||
|
||||
@ -564,7 +564,7 @@ public class ProjectServiceImpl
|
||||
if (mnode != null) {
|
||||
String topoMid = optText(mnode, "materialId");
|
||||
if (topoMid != null && !topoMid.isEmpty()) {
|
||||
Map<String, Object> info = new HashMap<>();
|
||||
Map<String, Object> info = new java.util.LinkedHashMap<>();
|
||||
String mname = optText(mnode, "name");
|
||||
if (mname != null) info.put("materialName", mname);
|
||||
JsonNode mstatic = mnode.path("static");
|
||||
@ -572,7 +572,12 @@ public class ProjectServiceImpl
|
||||
if (mstatic.path("u_concentration").isNumber()) info.put("u_concentration", mstatic.path("u_concentration").numberValue());
|
||||
if (mstatic.path("u_enrichment").isNumber()) info.put("u_enrichment", mstatic.path("u_enrichment").numberValue());
|
||||
if (mstatic.path("pu_concentration").isNumber()) info.put("pu_concentration", mstatic.path("pu_concentration").numberValue());
|
||||
if (mstatic.path("pu_isotope").isNumber()) info.put("pu_isotope", mstatic.path("pu_isotope").numberValue());
|
||||
// if (mstatic.path("pu_isotope").isNumber()) info.put("pu_isotope", mstatic.path("pu_isotope").numberValue());
|
||||
if (mstatic.path("e_pu238").isNumber()) info.put("e_pu238", mstatic.path("e_pu238").numberValue());
|
||||
if (mstatic.path("e_pu239").isNumber()) info.put("e_pu239", mstatic.path("e_pu239").numberValue());
|
||||
if (mstatic.path("e_pu240").isNumber()) info.put("e_pu240", mstatic.path("e_pu240").numberValue());
|
||||
if (mstatic.path("e_pu241").isNumber()) info.put("e_pu241", mstatic.path("e_pu241").numberValue());
|
||||
if (mstatic.path("e_pu242").isNumber()) info.put("e_pu242", mstatic.path("e_pu242").numberValue());
|
||||
}
|
||||
matTopo.put(topoMid, info);
|
||||
}
|
||||
@ -589,7 +594,7 @@ public class ProjectServiceImpl
|
||||
for (String did : deviceIds) {
|
||||
Device d = devMap.get(did);
|
||||
if (d == null) continue;
|
||||
Map<String, Object> row = new HashMap<>();
|
||||
Map<String, Object> row = new java.util.LinkedHashMap<>();
|
||||
row.put("deviceId", d.getDeviceId());
|
||||
row.put("deviceName", d.getName());
|
||||
row.put("deviceType", d.getType());
|
||||
@ -602,12 +607,27 @@ public class ProjectServiceImpl
|
||||
if (mid != null) {
|
||||
Material m = matMap.get(mid);
|
||||
row.put("materialId", mid);
|
||||
row.put("materialName", null);
|
||||
row.put("u_concentration", null);
|
||||
row.put("u_enrichment", null);
|
||||
row.put("pu_concentration", null);
|
||||
// row.put("pu_isotope", null);
|
||||
row.put("e_pu238", null);
|
||||
row.put("e_pu239", null);
|
||||
row.put("e_pu240", null);
|
||||
row.put("e_pu241", null);
|
||||
row.put("e_pu242", null);
|
||||
if (m != null) {
|
||||
row.put("materialName", m.getName());
|
||||
row.put("u_concentration", m.getUConcentration());
|
||||
row.put("u_enrichment", m.getUEnrichment());
|
||||
row.put("pu_concentration", m.getPuConcentration());
|
||||
row.put("pu_isotope", m.getPuIsotope());
|
||||
// row.put("pu_isotope", m.getPuIsotope());
|
||||
row.put("e_pu238", m.getEPu238());
|
||||
row.put("e_pu239", m.getEPu239());
|
||||
row.put("e_pu240", m.getEPu240());
|
||||
row.put("e_pu241", m.getEPu241());
|
||||
row.put("e_pu242", m.getEPu242());
|
||||
} else {
|
||||
Map<String, Object> info = matTopo.get(mid);
|
||||
if (info != null) {
|
||||
@ -615,7 +635,12 @@ public class ProjectServiceImpl
|
||||
if (info.get("u_concentration") != null) row.put("u_concentration", info.get("u_concentration"));
|
||||
if (info.get("u_enrichment") != null) row.put("u_enrichment", info.get("u_enrichment"));
|
||||
if (info.get("pu_concentration") != null) row.put("pu_concentration", info.get("pu_concentration"));
|
||||
if (info.get("pu_isotope") != null) row.put("pu_isotope", info.get("pu_isotope"));
|
||||
// if (info.get("pu_isotope") != null) row.put("pu_isotope", info.get("pu_isotope"));
|
||||
if (info.get("e_pu238") != null) row.put("e_pu238", info.get("e_pu238"));
|
||||
if (info.get("e_pu239") != null) row.put("e_pu239", info.get("e_pu239"));
|
||||
if (info.get("e_pu240") != null) row.put("e_pu240", info.get("e_pu240"));
|
||||
if (info.get("e_pu241") != null) row.put("e_pu241", info.get("e_pu241"));
|
||||
if (info.get("e_pu242") != null) row.put("e_pu242", info.get("e_pu242"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user