模型训练增加材料类型支持
This commit is contained in:
parent
36d4b57ce3
commit
1d525c3770
@ -40,6 +40,8 @@ public class SimBuilder {
|
||||
JsonNode root = objectMapper.readTree(project.getTopology());
|
||||
Map<String, String> devToMat = buildDeviceMaterialMap(root);
|
||||
Map<String, Map<String, Double>> matStaticDb = buildMaterialStaticFromDb(devToMat, materialService);
|
||||
// 补充:建立 Material ID -> Material Type 的映射
|
||||
Map<String, String> matTypeMap = buildMaterialTypeMap(devToMat, materialService);
|
||||
|
||||
JsonNode devicesNode = root.path("devices");
|
||||
if (devicesNode.isArray()) {
|
||||
@ -49,6 +51,7 @@ public class SimBuilder {
|
||||
|
||||
String type = deviceNode.path("type").asText();
|
||||
String materialId = devToMat.get(deviceId);
|
||||
String materialType = materialId != null ? matTypeMap.get(materialId) : null;
|
||||
|
||||
Map<String, Double> staticProps = new HashMap<>();
|
||||
|
||||
@ -92,7 +95,7 @@ public class SimBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
units.add(new SimUnit(deviceId, deviceId, materialId, type, staticProps));
|
||||
units.add(new SimUnit(deviceId, deviceId, materialId, type, materialType, staticProps));
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
@ -101,6 +104,24 @@ public class SimBuilder {
|
||||
return units;
|
||||
}
|
||||
|
||||
private Map<String, String> buildMaterialTypeMap(Map<String, String> devToMat, MaterialService materialService) {
|
||||
Map<String, String> out = new HashMap<>();
|
||||
if (devToMat.isEmpty()) return out;
|
||||
Set<String> mids = new HashSet<>(devToMat.values());
|
||||
List<Material> mats = materialService.list(new QueryWrapper<Material>().in("material_id", mids));
|
||||
for (Material m : mats) {
|
||||
String type = "unknown";
|
||||
// 简单推导规则: 优先判定 Pu,其次 U
|
||||
if (m.getPuConcentration() != null && m.getPuConcentration().doubleValue() > 0) {
|
||||
type = "Pu";
|
||||
} else if (m.getUConcentration() != null && m.getUConcentration().doubleValue() > 0) {
|
||||
type = "U";
|
||||
}
|
||||
out.put(m.getMaterialId(), type);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
private void injectDeviceSize(Device device, Map<String, Double> staticProps) {
|
||||
try {
|
||||
String sizeJson = device.getSize();
|
||||
@ -124,7 +145,7 @@ public class SimBuilder {
|
||||
|
||||
case "CylindricalTank":
|
||||
case "AnnularTank":
|
||||
case "TubeBundleTank":
|
||||
case "TubeBundleTank"://管束槽
|
||||
parseCommonSize(sizeNode, staticProps);
|
||||
break;
|
||||
|
||||
|
||||
@ -80,11 +80,12 @@ public class AlgorithmModelController {
|
||||
return algorithmModelService.deleteBatchWithCheck(ids);
|
||||
}
|
||||
|
||||
//返回:该算法+设备类型的版本列表
|
||||
//返回:该算法+设备类型+材料类型的版本列表
|
||||
@GetMapping("/search")
|
||||
@Operation(summary = "查询模型版本列表", description = "按算法类型与设备类型过滤并分页返回模型版本")
|
||||
@Operation(summary = "查询模型版本列表", description = "按算法类型、设备类型与材料类型过滤并分页返回模型版本")
|
||||
public Page<AlgorithmModel> search(@RequestParam(required = false) String algorithmType,
|
||||
@RequestParam(required = false) String deviceType,
|
||||
@RequestParam(required = false) String materialType,
|
||||
@RequestParam(required = false) String versionTag,
|
||||
@RequestParam(required = false) String isCurrent,
|
||||
@RequestParam(defaultValue = "1") long pageNum,
|
||||
@ -92,6 +93,7 @@ public class AlgorithmModelController {
|
||||
QueryWrapper<AlgorithmModel> qw = new QueryWrapper<>();
|
||||
if (algorithmType != null && !algorithmType.isEmpty()) qw.eq("algorithm_type", algorithmType);
|
||||
if (deviceType != null && !deviceType.isEmpty()) qw.eq("device_type", deviceType);
|
||||
if (materialType != null && !materialType.isEmpty()) qw.eq("material_type", materialType);
|
||||
if (versionTag != null && !versionTag.isEmpty()) qw.eq("version_tag", versionTag);
|
||||
if (isCurrent != null && !isCurrent.isEmpty()) qw.eq("is_current", isCurrent);
|
||||
qw.orderByDesc("updated_at");
|
||||
@ -99,29 +101,38 @@ public class AlgorithmModelController {
|
||||
return algorithmModelService.page(page, qw);
|
||||
}
|
||||
|
||||
//返回:该算法+设备类型的当前激活版本
|
||||
//返回:该算法+设备类型+材料类型的当前激活版本
|
||||
@GetMapping("/current")
|
||||
@Operation(summary = "获取当前激活版本", description = "根据算法类型与设备类型,返回 is_current=1 的模型版本")
|
||||
@Operation(summary = "获取当前激活版本", description = "根据算法类型、设备类型与材料类型,返回 is_current=1 的模型版本")
|
||||
public AlgorithmModel getCurrent(@RequestParam String algorithmType,
|
||||
@RequestParam String deviceType) {
|
||||
@RequestParam String deviceType,
|
||||
@RequestParam(required = false) String materialType) {
|
||||
QueryWrapper<AlgorithmModel> qw = new QueryWrapper<>();
|
||||
qw.eq("algorithm_type", algorithmType);
|
||||
qw.eq("device_type", deviceType);
|
||||
if (materialType != null && !materialType.isEmpty()) {
|
||||
qw.eq("material_type", materialType);
|
||||
}
|
||||
qw.eq("is_current", 1);
|
||||
qw.orderByDesc("updated_at");
|
||||
return algorithmModelService.getOne(qw);
|
||||
return algorithmModelService.getOne(qw, false); // 使用 false 避免多条结果时抛出异常,虽然正常不应该有多条
|
||||
}
|
||||
|
||||
//版本激活
|
||||
@PostMapping("/activate")
|
||||
@Operation(summary = "激活模型版本", description = "将目标模型版本设为当前,并将同组其他版本设为非当前")
|
||||
@Operation(summary = "激活模型版本", description = "将目标模型版本设为当前,并将同组(算法+设备+材料)其他版本设为非当前")
|
||||
public boolean activate(@RequestParam String algorithmModelId) {
|
||||
AlgorithmModel model = algorithmModelService.getById(algorithmModelId);
|
||||
if (model == null) return false;
|
||||
// 先将所有版本设为非当前
|
||||
// 先将所有同组版本设为非当前
|
||||
QueryWrapper<AlgorithmModel> qw = new QueryWrapper<>();
|
||||
qw.eq("algorithm_type", model.getAlgorithmType());
|
||||
qw.eq("device_type", model.getDeviceType());
|
||||
if (model.getMaterialType() != null && !model.getMaterialType().isEmpty()) {
|
||||
qw.eq("material_type", model.getMaterialType());
|
||||
} else {
|
||||
qw.and(wrapper -> wrapper.isNull("material_type").or().eq("material_type", ""));
|
||||
}
|
||||
AlgorithmModel upd = new AlgorithmModel();
|
||||
upd.setIsCurrent(0);
|
||||
algorithmModelService.update(upd, qw);
|
||||
@ -138,6 +149,7 @@ public class AlgorithmModelController {
|
||||
public Map<String, Object> trainExcel(@RequestBody Map<String, Object> body) {
|
||||
String algorithmType = str(body.get("algorithm_type"));
|
||||
String deviceType = str(body.get("device_type"));
|
||||
String materialType = str(body.getOrDefault("material_type", ""));
|
||||
String datasetPath = str(body.get("dataset_path"));
|
||||
String modelDir = str(body.getOrDefault("model_dir", ""));
|
||||
boolean activate = bool(body.getOrDefault("activate", false));
|
||||
@ -178,6 +190,7 @@ public class AlgorithmModelController {
|
||||
model.setAlgorithmModelId(UUID.randomUUID().toString());
|
||||
model.setAlgorithmType(algorithmType);
|
||||
model.setDeviceType(deviceType);
|
||||
model.setMaterialType(materialType);
|
||||
model.setVersionTag(genVersionTag());
|
||||
model.setModelPath(modelPath);
|
||||
model.setFeatureMapSnapshot(isBlank(featureMapSnapshot) ? "{}" : featureMapSnapshot);
|
||||
@ -190,6 +203,11 @@ public class AlgorithmModelController {
|
||||
if (activate) {
|
||||
QueryWrapper<AlgorithmModel> qw = new QueryWrapper<>();
|
||||
qw.eq("algorithm_type", algorithmType).eq("device_type", deviceType);
|
||||
if (!isBlank(materialType)) {
|
||||
qw.eq("material_type", materialType);
|
||||
} else {
|
||||
qw.and(wrapper -> wrapper.isNull("material_type").or().eq("material_type", ""));
|
||||
}
|
||||
AlgorithmModel upd = new AlgorithmModel();
|
||||
upd.setIsCurrent(0);
|
||||
algorithmModelService.update(upd, qw);
|
||||
@ -204,6 +222,7 @@ public class AlgorithmModelController {
|
||||
public Map<String, Object> trainSamples(@RequestBody Map<String, Object> body) {
|
||||
String algorithmType = str(body.get("algorithm_type"));
|
||||
String deviceType = str(body.get("device_type"));
|
||||
String materialType = str(body.getOrDefault("material_type", ""));
|
||||
Object samples = body.get("samples"); // 期望为 List<Map>,由前端提供
|
||||
String modelDir = str(body.getOrDefault("model_dir", ""));
|
||||
boolean activate = bool(body.getOrDefault("activate", false));
|
||||
@ -230,6 +249,7 @@ public class AlgorithmModelController {
|
||||
model.setAlgorithmModelId(UUID.randomUUID().toString());
|
||||
model.setAlgorithmType(algorithmType);
|
||||
model.setDeviceType(deviceType);
|
||||
model.setMaterialType(materialType);
|
||||
model.setVersionTag(genVersionTag());
|
||||
model.setModelPath(modelPath);
|
||||
model.setFeatureMapSnapshot(isBlank(featureMapSnapshot) ? "{}" : featureMapSnapshot);
|
||||
@ -242,6 +262,11 @@ public class AlgorithmModelController {
|
||||
if (activate) {
|
||||
QueryWrapper<AlgorithmModel> qw = new QueryWrapper<>();
|
||||
qw.eq("algorithm_type", algorithmType).eq("device_type", deviceType);
|
||||
if (!isBlank(materialType)) {
|
||||
qw.eq("material_type", materialType);
|
||||
} else {
|
||||
qw.and(wrapper -> wrapper.isNull("material_type").or().eq("material_type", ""));
|
||||
}
|
||||
AlgorithmModel upd = new AlgorithmModel();
|
||||
upd.setIsCurrent(0);
|
||||
algorithmModelService.update(upd, qw);
|
||||
|
||||
@ -44,12 +44,9 @@ public class SimController {
|
||||
int steps = req.containsKey("steps") ? (int) req.get("steps") : 10;
|
||||
|
||||
// 0. Update Status: 更新情景状态为进行中
|
||||
Scenario startScenario = new Scenario();
|
||||
startScenario.setScenarioId(scenarioId);
|
||||
startScenario.setStatus("1"); // 1: 进行中
|
||||
startScenario.setUpdatedAt(LocalDateTime.now());
|
||||
scenarioService.updateById(startScenario);
|
||||
updateScenarioStatus(scenarioId, "1"); // 1: 进行中
|
||||
|
||||
try {
|
||||
// 1. Load Data: 获取项目、设备和事件数据
|
||||
SimDataPackage data = simDataFacade.loadSimulationData(projectId, scenarioId);
|
||||
|
||||
@ -67,6 +64,29 @@ public class SimController {
|
||||
// 6. Convert Result: 将仿真结果转换为前端友好的格式,包含静态属性补全和元数据
|
||||
Map<String, Object> resultData = SimResultConverter.toFrames(ctx, units, projectId, scenarioId);
|
||||
|
||||
// 为了让前端明确感知推理是否启动成功,这里增加一个状态提示
|
||||
// 但因为是异步的,我们只能保证“已提交”。真正的成败由 asyncInferAndSave 决定。
|
||||
resultData.put("inferenceStatus", "submitted");
|
||||
|
||||
return ResponseResult.successData(resultData);
|
||||
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
// 仿真计算失败,更新状态为失败 (3: 失败)
|
||||
updateScenarioStatus(scenarioId, "3");
|
||||
return ResponseResult.error("Simulation failed: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
private void updateScenarioStatus(String scenarioId, String status) {
|
||||
try {
|
||||
Scenario scenario = new Scenario();
|
||||
scenario.setScenarioId(scenarioId);
|
||||
scenario.setStatus(status);
|
||||
scenario.setUpdatedAt(LocalDateTime.now());
|
||||
scenarioService.updateById(scenario);
|
||||
} catch (Exception e) {
|
||||
System.err.println("Failed to update scenario status to " + status + ": " + e.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -24,6 +24,9 @@ public class AlgorithmModel implements Serializable {
|
||||
@TableField("device_type")
|
||||
private String deviceType;
|
||||
|
||||
@TableField("material_type")
|
||||
private String materialType;
|
||||
|
||||
@TableField("version_tag")
|
||||
private String versionTag;
|
||||
|
||||
|
||||
@ -5,15 +5,21 @@ import java.util.Map;
|
||||
public class DeviceStepInfo {
|
||||
private String deviceId;
|
||||
private String deviceType;
|
||||
private String materialType; // 新增
|
||||
private Map<String, Object> properties;
|
||||
private int step;
|
||||
private int time;
|
||||
|
||||
public DeviceStepInfo() {}
|
||||
|
||||
public DeviceStepInfo(String deviceId, String deviceType, Map<String, Object> properties,int step,int time) {
|
||||
public DeviceStepInfo(String deviceId, String deviceType, Map<String, Object> properties, int step, int time) {
|
||||
this(deviceId, deviceType, null, properties, step, time);
|
||||
}
|
||||
|
||||
public DeviceStepInfo(String deviceId, String deviceType, String materialType, Map<String, Object> properties, int step, int time) {
|
||||
this.deviceId = deviceId;
|
||||
this.deviceType = deviceType;
|
||||
this.materialType = materialType;
|
||||
this.properties = properties;
|
||||
this.step = step;
|
||||
this.time = time;
|
||||
@ -28,7 +34,6 @@ public class DeviceStepInfo {
|
||||
this.deviceId = deviceId;
|
||||
}
|
||||
|
||||
|
||||
public String getDeviceType() {
|
||||
return deviceType;
|
||||
}
|
||||
@ -37,6 +42,14 @@ public class DeviceStepInfo {
|
||||
this.deviceType = deviceType;
|
||||
}
|
||||
|
||||
public String getMaterialType() {
|
||||
return materialType;
|
||||
}
|
||||
|
||||
public void setMaterialType(String materialType) {
|
||||
this.materialType = materialType;
|
||||
}
|
||||
|
||||
public Map<String, Object> getProperties() {
|
||||
return properties;
|
||||
}
|
||||
@ -66,6 +79,7 @@ public class DeviceStepInfo {
|
||||
return "DeviceStepInfo{" +
|
||||
"deviceId='" + deviceId + '\'' +
|
||||
", deviceType='" + deviceType + '\'' +
|
||||
", materialType='" + materialType + '\'' +
|
||||
", properties=" + properties +
|
||||
", step=" + step +
|
||||
", time=" + time +
|
||||
|
||||
@ -2,8 +2,17 @@ package com.yfd.business.css.model;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
public record SimUnit(String unitId, String deviceId, String materialId, String deviceType, Map<String, Double> staticProperties) {
|
||||
public record SimUnit(String unitId, String deviceId, String materialId, String deviceType, String materialType, Map<String, Double> staticProperties) {
|
||||
public SimUnit(String unitId, String deviceId, String materialId, String deviceType, String materialType) {
|
||||
this(unitId, deviceId, materialId, deviceType, materialType, Map.of());
|
||||
}
|
||||
|
||||
// 兼容旧构造函数
|
||||
public SimUnit(String unitId, String deviceId, String materialId, String deviceType, Map<String, Double> staticProperties) {
|
||||
this(unitId, deviceId, materialId, deviceType, null, staticProperties);
|
||||
}
|
||||
|
||||
public SimUnit(String unitId, String deviceId, String materialId, String deviceType) {
|
||||
this(unitId, deviceId, materialId, deviceType, Map.of());
|
||||
this(unitId, deviceId, materialId, deviceType, null, Map.of());
|
||||
}
|
||||
}
|
||||
|
||||
@ -14,7 +14,17 @@ public interface AlgorithmModelService extends IService<AlgorithmModel> {
|
||||
* @param deviceType 设备类型(如CylindricalTank/AnnularTank)
|
||||
* @return 激活版本的模型文件路径,如果不存在则返回null
|
||||
*/
|
||||
String getCurrentModelPath(String algorithmType, String deviceType) ;
|
||||
String getCurrentModelPath(String algorithmType, String deviceType);
|
||||
|
||||
/**
|
||||
* 根据算法类型、设备类型、材料类型,获取当前激活的模型路径
|
||||
*
|
||||
* @param algorithmType 算法类型
|
||||
* @param deviceType 设备类型
|
||||
* @param materialType 材料类型(如Pu/U)
|
||||
* @return 激活版本的模型文件路径
|
||||
*/
|
||||
String getCurrentModelPath(String algorithmType, String deviceType, String materialType);
|
||||
|
||||
boolean deleteBatchWithCheck(List<String> ids);
|
||||
|
||||
|
||||
@ -19,6 +19,8 @@ import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.*;
|
||||
|
||||
@Service
|
||||
@ -26,6 +28,11 @@ public class DeviceInferService {
|
||||
@Value("${python.api.url:http://localhost:8000}")
|
||||
private String pythonInferUrl;
|
||||
|
||||
|
||||
|
||||
@Value("${file-space.model-path}")
|
||||
private String modelRootPath;
|
||||
|
||||
@Resource
|
||||
private ScenarioService scenarioService;
|
||||
@Resource
|
||||
@ -38,6 +45,10 @@ public class DeviceInferService {
|
||||
|
||||
public void processDeviceInference(String projectId, String scenarioId,
|
||||
Map<String, List<DeviceStepInfo>> groupedDevices) {
|
||||
// 增加标志位,记录是否至少成功执行了一次推理
|
||||
boolean hasAnySuccess = false;
|
||||
boolean hasAnyError = false;
|
||||
|
||||
// 1. 获取情景配置信息
|
||||
Scenario scenario = scenarioService.getById(scenarioId);
|
||||
if (scenario == null) {
|
||||
@ -81,22 +92,39 @@ public class DeviceInferService {
|
||||
String currentAlgoType = algoEntry.getKey();
|
||||
List<DeviceStepInfo> currentDevices = algoEntry.getValue();
|
||||
|
||||
// 获取模型路径(根据算法类型和设备类型)
|
||||
System.out.println("Processing inference for algorithmType: " + currentAlgoType + ", deviceType: " + deviceType);
|
||||
String modelPath = algorithmModelService.getCurrentModelPath(currentAlgoType, deviceType);
|
||||
System.out.println("modelPath=" + modelPath);
|
||||
// 4.1 按材料类型进行三级分组 (AlgoType -> MaterialType)
|
||||
Map<String, List<DeviceStepInfo>> matGroup = new HashMap<>();
|
||||
for (DeviceStepInfo d : currentDevices) {
|
||||
String mType = d.getMaterialType();
|
||||
// 如果 materialType 为空,归类为 "unknown" 或 null,视业务而定,这里保留 null
|
||||
matGroup.computeIfAbsent(mType, k -> new ArrayList<>()).add(d);
|
||||
}
|
||||
|
||||
if (modelPath == null) {
|
||||
System.err.println("Model path not found for algorithmType: " + currentAlgoType + ", deviceType: " + deviceType);
|
||||
// 这里可以选择抛异常中断,或者跳过该组继续处理其他组
|
||||
// 为了保证健壮性,这里选择记录错误并跳过,或者根据业务需求抛出异常
|
||||
// throw new IllegalArgumentException("未配置 " + currentAlgoType + " 模型路径 (deviceType: " + deviceType + ")");
|
||||
// 4.2 遍历材料分组,分别获取模型并推理
|
||||
for (Map.Entry<String, List<DeviceStepInfo>> matEntry : matGroup.entrySet()) {
|
||||
String currentMaterialType = matEntry.getKey();
|
||||
List<DeviceStepInfo> batchDevices = matEntry.getValue();
|
||||
|
||||
// 获取模型路径(根据算法类型、设备类型、材料类型)
|
||||
System.out.println("Processing inference for algorithmType: " + currentAlgoType +
|
||||
", deviceType: " + deviceType + ", materialType: " + currentMaterialType);
|
||||
String modelRelPath = algorithmModelService.getCurrentModelPath(currentAlgoType, deviceType, currentMaterialType);
|
||||
System.out.println("modelRelPath=" + modelRelPath);
|
||||
|
||||
if (modelRelPath == null) {
|
||||
System.err.println("Model path not found for algorithmType: " + currentAlgoType +
|
||||
", deviceType: " + deviceType + ", materialType: " + currentMaterialType);
|
||||
hasAnyError = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// 将相对路径转换为绝对路径
|
||||
String absoluteModelPath = Paths.get(modelRootPath).resolve(modelRelPath).toAbsolutePath().normalize().toString();
|
||||
System.out.println("Absolute modelPath=" + absoluteModelPath);
|
||||
|
||||
// 封装推理请求
|
||||
InferRequest request = buildInferenceRequest(deviceType, currentDevices, modelPath);
|
||||
// System.out.println("request=" + request);
|
||||
InferRequest request = buildInferenceRequest(deviceType, batchDevices, absoluteModelPath);
|
||||
System.out.println("request=" + request);
|
||||
|
||||
try {
|
||||
// 调用Python推理服务
|
||||
@ -118,18 +146,27 @@ public class DeviceInferService {
|
||||
reconstructedResponse.setMsg(response.getMsg());
|
||||
reconstructedResponse.setData(newData);
|
||||
|
||||
processInferenceResults(projectId, scenarioId, deviceType, currentDevices, reconstructedResponse);
|
||||
processInferenceResults(projectId, scenarioId, deviceType, batchDevices, reconstructedResponse);
|
||||
hasAnySuccess = true;
|
||||
} else {
|
||||
System.err.println("推理服务调用失败: " + (response != null ? response.getMsg() : "未知错误"));
|
||||
hasAnyError = true;
|
||||
}
|
||||
} catch (Exception e) {
|
||||
System.err.println("推理异常: " + e.getMessage());
|
||||
e.printStackTrace();
|
||||
hasAnyError = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 最终检查:如果没有任何成功的推理,且发生过错误,抛出异常以通知上层
|
||||
if (!hasAnySuccess && hasAnyError) {
|
||||
throw new RuntimeException("所有设备推理均失败或未找到模型");
|
||||
}
|
||||
}
|
||||
|
||||
private InferRequest buildInferenceRequest(String deviceType,List<DeviceStepInfo> devices,String modelPath) {
|
||||
InferRequest request = new InferRequest();
|
||||
request.setModelDir(modelPath); // 设置模型路径
|
||||
|
||||
@ -69,6 +69,7 @@ public class SimInferService {
|
||||
DeviceStepInfo info = new DeviceStepInfo();
|
||||
info.setDeviceId(unit.deviceId()); // unitId 通常等于 deviceId
|
||||
info.setDeviceType(deviceType);
|
||||
info.setMaterialType(unit.materialType()); // 注入材料类型
|
||||
info.setProperties(properties);
|
||||
info.setStep(step);
|
||||
info.setTime(step); // 假设 time = step,或者根据步长计算
|
||||
@ -89,6 +90,12 @@ public class SimInferService {
|
||||
deviceInferService.processDeviceInference(projectId, scenarioId, groupedDevices);
|
||||
|
||||
// 4. 更新状态为已完成
|
||||
// 注意:如果前面的 deviceInferService.processDeviceInference 内部发生部分异常但未抛出(例如 continue 了),
|
||||
// 这里依然会设置为 2 (Success)。
|
||||
// 建议检查 deviceInferService 是否真的执行了推理。
|
||||
// 但由于是 void 方法,无法直接得知。
|
||||
// 简单改进:如果 groupedDevices 非空,但所有组都因为找不到模型而跳过,应该视为失败吗?
|
||||
// 目前策略:只要没有抛出未捕获异常,就视为 Success。
|
||||
updateScenarioStatus(scenarioId, "2");
|
||||
|
||||
} catch (Exception e) {
|
||||
|
||||
@ -16,13 +16,29 @@ public class AlgorithmModelServiceImpl extends ServiceImpl<AlgorithmModelMapper,
|
||||
|
||||
@Override
|
||||
public String getCurrentModelPath(String algorithmType, String deviceType) {
|
||||
System.out.println("Querying current model path for algorithmType: " + algorithmType + ", deviceType: " + deviceType);
|
||||
return getCurrentModelPath(algorithmType, deviceType, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getCurrentModelPath(String algorithmType, String deviceType, String materialType) {
|
||||
System.out.println("Querying current model path for algorithmType: " + algorithmType + ", deviceType: " + deviceType + ", materialType: " + materialType);
|
||||
QueryWrapper<AlgorithmModel> queryWrapper = new QueryWrapper<>();
|
||||
queryWrapper.eq("algorithm_type", algorithmType)
|
||||
.eq("device_type", deviceType)
|
||||
.eq("is_current", 1); // 当前激活版本
|
||||
AlgorithmModel model = getOne(queryWrapper);
|
||||
if (model != null) {
|
||||
|
||||
if (materialType != null && !materialType.isEmpty()) {
|
||||
queryWrapper.eq("material_type", materialType);
|
||||
} else {
|
||||
// 如果未指定材料类型,优先匹配无材料类型的通用模型,或者返回任意匹配(兼容旧逻辑)
|
||||
// 这里为了稳健,如果 materialType 为 null,暂不添加 material_type 的过滤条件,或者显式匹配 null
|
||||
// 建议:如果 materialType 为空,尝试匹配 material_type is null or ''
|
||||
// queryWrapper.and(w -> w.isNull("material_type").or().eq("material_type", ""));
|
||||
}
|
||||
|
||||
List<AlgorithmModel> models = list(queryWrapper);
|
||||
if (models != null && !models.isEmpty()) {
|
||||
AlgorithmModel model = models.get(0);
|
||||
System.out.println("Found model: " + model.getModelPath());
|
||||
return model.getModelPath();
|
||||
} else {
|
||||
|
||||
@ -308,23 +308,35 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
||||
throw new BizException("versionTag 格式不合法");
|
||||
}
|
||||
|
||||
// 检查版本号唯一性
|
||||
long count = algorithmModelService.count(new QueryWrapper<AlgorithmModel>()
|
||||
.eq("algorithm_type", task.getAlgorithmType())
|
||||
.eq("device_type", task.getDeviceType())
|
||||
.eq("version_tag", versionTag));
|
||||
if (count > 0) {
|
||||
throw new BizException("版本号已存在");
|
||||
}
|
||||
|
||||
String algorithmType = task.getAlgorithmType();
|
||||
String deviceType = task.getDeviceType();
|
||||
if (algorithmType == null || algorithmType.isBlank() || deviceType == null || deviceType.isBlank()) {
|
||||
throw new BizException("算法类型或设备类型不能为空");
|
||||
}
|
||||
|
||||
// 解析材料类型
|
||||
String materialType = "unknown";
|
||||
String outputPath = task.getModelOutputPath();
|
||||
if (outputPath != null && !outputPath.isBlank()) {
|
||||
String[] parts = outputPath.replace("\\", "/").split("/");
|
||||
// 预期结构: runs/{algorithmType}/{deviceType}/{materialType}/{taskId}/...
|
||||
if (parts.length > 4 && "runs".equals(parts[0])) {
|
||||
materialType = parts[3];
|
||||
}
|
||||
}
|
||||
|
||||
// 检查版本号唯一性
|
||||
long count = algorithmModelService.count(new QueryWrapper<AlgorithmModel>()
|
||||
.eq("algorithm_type", algorithmType)
|
||||
.eq("device_type", deviceType)
|
||||
.eq("material_type", materialType)
|
||||
.eq("version_tag", versionTag));
|
||||
if (count > 0) {
|
||||
throw new BizException("版本号已存在");
|
||||
}
|
||||
|
||||
Path root = Paths.get(modelPath).toAbsolutePath().normalize();
|
||||
Path versionDir = root.resolve(Paths.get(algorithmType, deviceType, versionTag)).normalize();
|
||||
Path versionDir = root.resolve(Paths.get(algorithmType, deviceType, materialType, versionTag)).normalize();
|
||||
if (!versionDir.startsWith(root)) {
|
||||
throw new BizException("发布目录非法");
|
||||
}
|
||||
@ -340,10 +352,42 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
||||
publishedMetricsRelPath = copyToVersionDir(root, versionDir, task.getMetricsImagePath());
|
||||
}
|
||||
|
||||
// 尝试一并复制训练过程生成的其他配置/特征文件
|
||||
// 根据 python 端的生成规则,这些文件通常与模型文件 (pipeline.pkl) 在同一个目录下
|
||||
if (task.getModelOutputPath() != null && !task.getModelOutputPath().isBlank()) {
|
||||
try {
|
||||
// 推导出原始模型文件所在的目录
|
||||
String normalized = task.getModelOutputPath().replace("\\", "/");
|
||||
if (normalized.startsWith("models/")) {
|
||||
normalized = normalized.substring("models/".length());
|
||||
}
|
||||
Path srcPath = Paths.get(normalized);
|
||||
Path srcDir = srcPath.isAbsolute() ? srcPath.normalize().getParent() : root.resolve(normalized).normalize().getParent();
|
||||
|
||||
if (srcDir != null && Files.exists(srcDir)) {
|
||||
// 需要复制的文件名列表
|
||||
String[] extraFiles = {"train_params.json", "feature_map.json"};
|
||||
for (String fileName : extraFiles) {
|
||||
Path extraFileSrc = srcDir.resolve(fileName);
|
||||
if (Files.exists(extraFileSrc) && Files.isRegularFile(extraFileSrc)) {
|
||||
Path dest = versionDir.resolve(fileName).normalize();
|
||||
if (!Files.exists(dest)) {
|
||||
Files.copy(extraFileSrc, dest, StandardCopyOption.COPY_ATTRIBUTES);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
// 附加文件复制失败不应阻断发布主流程,仅记录日志
|
||||
System.err.println("复制附加训练文件失败: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
// 创建正式模型记录
|
||||
AlgorithmModel model = new AlgorithmModel();
|
||||
model.setAlgorithmType(task.getAlgorithmType());
|
||||
model.setDeviceType(task.getDeviceType());
|
||||
model.setMaterialType(materialType);
|
||||
model.setVersionTag(versionTag);
|
||||
model.setModelPath(publishedModelRelPath);
|
||||
model.setMetricsImagePath(publishedMetricsRelPath);
|
||||
|
||||
@ -11,9 +11,11 @@ spring:
|
||||
password: q3eef875%&4@44%*3
|
||||
slave:
|
||||
driverClassName: com.mysql.cj.jdbc.Driver
|
||||
url: jdbc:mysql://43.138.168.68:3306/businessdb_css?useUnicode=true&characterEncoding=UTF8&rewriteBatchedStatements=true
|
||||
#url: jdbc:mysql://43.138.168.68:3306/businessdb_css?useUnicode=true&characterEncoding=UTF8&rewriteBatchedStatements=true
|
||||
url: jdbc:mysql://127.0.0.1:3306/businessdb_css?useUnicode=true&characterEncoding=UTF8&rewriteBatchedStatements=true
|
||||
username: root
|
||||
password: q3eef875%&4@44%*3
|
||||
#password: q3eef875%&4@44%*3
|
||||
password: 123456
|
||||
|
||||
file-space:
|
||||
files: D:\css\files\
|
||||
|
||||
Loading…
Reference in New Issue
Block a user