模型训练
This commit is contained in:
parent
632fd0f725
commit
aff2d7feea
@ -5,6 +5,7 @@ import org.springframework.boot.SpringApplication;
|
||||
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
||||
import org.springframework.boot.autoconfigure.data.redis.RedisAutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
|
||||
import org.springframework.boot.web.servlet.ServletComponentScan;
|
||||
|
||||
@SpringBootApplication(
|
||||
scanBasePackages = {
|
||||
@ -18,6 +19,9 @@ import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
|
||||
},
|
||||
exclude = {DataSourceAutoConfiguration.class, RedisAutoConfiguration.class}
|
||||
)
|
||||
@ServletComponentScan(basePackages = {
|
||||
"com.yfd.platform.config"
|
||||
})
|
||||
@MapperScan(basePackages = {
|
||||
"com.yfd.platform.**.mapper",
|
||||
"com.yfd.business.css.**.mapper"
|
||||
|
||||
@ -1,44 +1,368 @@
|
||||
package com.yfd.business.css.build;
|
||||
|
||||
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.yfd.business.css.domain.Device;
|
||||
import com.yfd.business.css.domain.Event;
|
||||
import com.yfd.business.css.domain.Material;
|
||||
import com.yfd.business.css.domain.Project;
|
||||
import com.yfd.business.css.model.*;
|
||||
import com.yfd.business.css.service.MaterialService;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.*;
|
||||
|
||||
/*
|
||||
*将数据库原始表 Map(拓扑、事件、影响关系)
|
||||
*转成 SimUnit / SimEvent / SimInfluenceNode
|
||||
*/
|
||||
/**
|
||||
* 仿真模型构建器
|
||||
* 负责将原始的 Project (Topology) 和 Event 数据解析转换为 SimUnit, SimEvent, SimInfluenceNode
|
||||
*/
|
||||
@Component
|
||||
public class SimBuilder {
|
||||
|
||||
// public static List<SimUnit> buildUnitsFromTopo(List<ProjectTopo> topoRows) {
|
||||
// return topoRows.stream()
|
||||
// .map(t -> new SimUnit(t.getDeviceId(), t.getDeviceId(), t.getMaterialId(), t.getDeviceType()))
|
||||
// .toList();
|
||||
// }
|
||||
private final ObjectMapper objectMapper = new ObjectMapper();
|
||||
|
||||
// public static List<SimEvent> buildEventsFromAttrChanges(List<EventAttrChange> attrChanges) {
|
||||
// return attrChanges.stream()
|
||||
// .map(e -> new SimEvent(e.getStep(),
|
||||
// SimPropertyKey.of(e.getDeviceId(), e.getAttr()),
|
||||
// e.getValue(),
|
||||
// SimEvent.EventType.SET))
|
||||
// .toList();
|
||||
// }
|
||||
// 1. 构建单元 (SimUnit) - 包含静态属性
|
||||
public List<SimUnit> buildUnits(SimDataPackage data, MaterialService materialService) {
|
||||
List<SimUnit> units = new ArrayList<>();
|
||||
Project project = data.getProject();
|
||||
if (project == null || project.getTopology() == null) return units;
|
||||
|
||||
// public static List<SimInfluenceNode> buildInfluenceNodesFromInfluences(List<Influence> influenceRows) {
|
||||
// return influenceRows.stream()
|
||||
// .map(i -> {
|
||||
// List<SimInfluenceSource> sources = i.getSources().stream()
|
||||
// .map(s -> new SimInfluenceSource(
|
||||
// SimPropertyKey.of(s.getUnitId(), s.getProp()),
|
||||
// s.getCoeff(),
|
||||
// s.getDelay()
|
||||
// )).toList();
|
||||
// return new SimInfluenceNode(SimPropertyKey.of(i.getTargetUnit(), i.getTargetProp()),
|
||||
// sources,
|
||||
// i.getBias());
|
||||
// }).toList();
|
||||
// }
|
||||
// 建立 Device ID -> Device 实体的映射,方便查找
|
||||
Map<String, Device> deviceMap = new HashMap<>();
|
||||
if (data.getDevices() != null) {
|
||||
for (Device d : data.getDevices()) {
|
||||
deviceMap.put(d.getDeviceId(), d);
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
JsonNode root = objectMapper.readTree(project.getTopology());
|
||||
Map<String, String> devToMat = buildDeviceMaterialMap(root);
|
||||
Map<String, Map<String, Double>> matStaticDb = buildMaterialStaticFromDb(devToMat, materialService);
|
||||
|
||||
JsonNode devicesNode = root.path("devices");
|
||||
if (devicesNode.isArray()) {
|
||||
for (JsonNode deviceNode : devicesNode) {
|
||||
String deviceId = deviceNode.path("deviceId").asText();
|
||||
if (deviceId == null || deviceId.isEmpty()) continue;
|
||||
|
||||
String type = deviceNode.path("type").asText();
|
||||
String materialId = devToMat.get(deviceId);
|
||||
|
||||
Map<String, Double> staticProps = new HashMap<>();
|
||||
|
||||
// 1.1 解析 topology 中的 static 节点
|
||||
JsonNode st = deviceNode.path("static");
|
||||
if (st.isObject()) {
|
||||
st.fieldNames().forEachRemaining(k -> {
|
||||
if (!"unit".equals(k)) {
|
||||
JsonNode v = st.path(k);
|
||||
if (v.isNumber()) staticProps.put(k, v.asDouble());
|
||||
else if (v.isTextual()) staticProps.put(k, parseDouble(v.asText()));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// 1.2 解析 Device.size 并注入
|
||||
Device device = deviceMap.get(deviceId);
|
||||
if (device != null) {
|
||||
injectDeviceSize(device, staticProps);
|
||||
}
|
||||
|
||||
// 1.3 注入物料静态属性
|
||||
if (materialId != null) {
|
||||
Map<String, Double> mProps = matStaticDb.get(materialId);
|
||||
if (mProps != null) staticProps.putAll(mProps);
|
||||
|
||||
// 解析 topology 中 material 的 static
|
||||
JsonNode mats = deviceNode.path("materials");
|
||||
if (mats.isMissingNode() || mats.isNull()) mats = deviceNode.path("material");
|
||||
if (mats.isObject()) {
|
||||
JsonNode mst = mats.path("static");
|
||||
if (mst.isObject()) {
|
||||
mst.fieldNames().forEachRemaining(k -> {
|
||||
if (!"unit".equals(k)) {
|
||||
JsonNode v = mst.path(k);
|
||||
if (v.isNumber()) staticProps.put(k, v.asDouble());
|
||||
else if (v.isTextual()) staticProps.put(k, parseDouble(v.asText()));
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
units.add(new SimUnit(deviceId, deviceId, materialId, type, staticProps));
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Build units failed", e);
|
||||
}
|
||||
return units;
|
||||
}
|
||||
|
||||
private void injectDeviceSize(Device device, Map<String, Double> staticProps) {
|
||||
try {
|
||||
String sizeJson = device.getSize();
|
||||
if (sizeJson == null || sizeJson.isBlank()) return;
|
||||
|
||||
JsonNode sizeNode = objectMapper.readTree(sizeJson);
|
||||
String type = device.getType(); // 假设 Device 有 getType() 方法,或从外部传入
|
||||
|
||||
if (type == null) {
|
||||
// Fallback: 尝试通用解析 (现有逻辑)
|
||||
parseCommonSize(sizeNode, staticProps);
|
||||
return;
|
||||
}
|
||||
|
||||
switch (type) {
|
||||
case "FlatTank":
|
||||
if (sizeNode.has("width")) staticProps.put("width", sizeNode.get("width").asDouble());
|
||||
if (sizeNode.has("length")) staticProps.put("length", sizeNode.get("length").asDouble());
|
||||
if (sizeNode.has("height")) staticProps.put("height", sizeNode.get("height").asDouble());
|
||||
break;
|
||||
|
||||
case "CylindricalTank":
|
||||
case "AnnularTank":
|
||||
case "TubeBundleTank":
|
||||
parseCommonSize(sizeNode, staticProps);
|
||||
break;
|
||||
|
||||
case "ExtractionColumn":
|
||||
// 优先提取 tray_section (塔身)
|
||||
if (sizeNode.has("tray_section")) {
|
||||
parseCommonSize(sizeNode.get("tray_section"), staticProps);
|
||||
}
|
||||
// 可选:将其他段的尺寸作为特殊属性注入,例如 lower_expanded_height
|
||||
break;
|
||||
|
||||
case "FluidizedBed":
|
||||
// 优先提取 reaction_section (反应段)
|
||||
if (sizeNode.has("reaction_section")) {
|
||||
parseCommonSize(sizeNode.get("reaction_section"), staticProps);
|
||||
}
|
||||
break;
|
||||
|
||||
case "ACFTank":
|
||||
// 优先提取 annular_cylinder (圆柱段)
|
||||
if (sizeNode.has("annular_cylinder")) {
|
||||
parseCommonSize(sizeNode.get("annular_cylinder"), staticProps);
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
parseCommonSize(sizeNode, staticProps);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
System.err.println("解析Device.size失败:" + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
private void parseCommonSize(JsonNode node, Map<String, Double> staticProps) {
|
||||
if (node.has("outer_diameter")) {
|
||||
staticProps.put("diameter", node.get("outer_diameter").asDouble());
|
||||
} else if (node.has("diameter")) {
|
||||
staticProps.put("diameter", node.get("diameter").asDouble());
|
||||
}
|
||||
|
||||
if (node.has("height")) {
|
||||
staticProps.put("height", node.get("height").asDouble());
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 构建事件 (SimEvent)
|
||||
public List<SimEvent> buildEvents(List<Event> events) {
|
||||
List<SimEvent> simEvents = new ArrayList<>();
|
||||
if (events == null) return simEvents;
|
||||
|
||||
for (Event ev : events) {
|
||||
String json = ev.getAttrChanges();
|
||||
if (json == null || json.isBlank()) continue;
|
||||
try {
|
||||
JsonNode root = objectMapper.readTree(json);
|
||||
JsonNode target = root.path("target");
|
||||
String entityType = optText(target, "entityType");
|
||||
String entityId = optText(target, "entityId");
|
||||
String property = optText(target, "property");
|
||||
if (entityType == null || entityId == null || property == null) continue;
|
||||
|
||||
// 区分 entityType? SimUnit id 是 deviceId,如果是 material,这里需要注意 SimUnit 的定义
|
||||
// ProjectServiceImpl 中是分开处理 device 和 material 的 state
|
||||
// 这里假设 SimUnit 以 deviceId 为主键,material 属性也挂在 deviceId 下 (SimContext key: deviceId + property)
|
||||
// 或者 SimContext key 包含 entityType? SimPropertyKey(unitId, property)
|
||||
// ProjectServiceImpl: entityType + ":" + entityId + ":" + property
|
||||
// SimPropertyKey: unitId, property. unitId 通常是 deviceId.
|
||||
// 如果是 material 属性,ProjectServiceImpl 中是把 material 属性 copy 到了 device state 中。
|
||||
// 所以这里 key 的 unitId 应该是 deviceId (如果 target 是 material,需找到对应的 deviceId)
|
||||
// 但 Event 中只记录了 materialId,没记录 deviceId (虽然 Event 表有 device_id 字段)
|
||||
|
||||
String targetUnitId = entityId;
|
||||
// 注意:如果 event 针对 material,且 material 被多个 device 共享,这里逻辑需要确认
|
||||
// ProjectServiceImpl 是遍历所有 device,如果 device 关联了该 material,就应用覆盖。
|
||||
// SimService 中 SimEvent 直接作用于 SimContext。
|
||||
// 如果 SimUnit 对应 Device,那么 Material 的事件需要转换为针对所有引用该 Material 的 Device 的事件。
|
||||
// 这需要在 buildEvents 时知道 Device-Material 关系,或者 SimContext 支持 Material 维度的存储。
|
||||
// 简化起见,假设 Event 表中 device_id 有值,或者我们在 buildEvents 时传入 devToMat 映射关系?
|
||||
// Event 表有 device_id 字段,可以直接用。
|
||||
if ("material".equals(entityType)) {
|
||||
if (ev.getDeviceId() != null) targetUnitId = ev.getDeviceId();
|
||||
}
|
||||
|
||||
SimPropertyKey key = SimPropertyKey.of(targetUnitId, property);
|
||||
|
||||
JsonNode segments = root.path("segments");
|
||||
if (segments.isArray()) {
|
||||
for (JsonNode seg : segments) {
|
||||
String interp = optText(seg, "interp");
|
||||
JsonNode timeline = seg.path("timeline");
|
||||
if (timeline.isArray()) {
|
||||
// 简化处理:全部转为 step-set,ramp 也离散化?
|
||||
// ProjectServiceImpl 是保留 ramp 结构,readValue 时计算。
|
||||
// SimService 是 step 推进。
|
||||
// 这里我们将 timeline 展开为每个 step 的 SimEvent
|
||||
|
||||
// 线性插值处理比较复杂,这里先处理离散点
|
||||
for (JsonNode p : timeline) {
|
||||
double t = p.path("t").asDouble();
|
||||
double val = p.path("value").asDouble();
|
||||
// Step ? Time to Step 转换
|
||||
// 假设 1s = 1 step
|
||||
int step = (int) t;
|
||||
// isOverride=true 在所有计算完成后,强制修改属性值。有 1 帧延迟 。
|
||||
// Input 事件 ( isOverride=false ),在 applyInfluences (影响计算) 之前 执行。下游设备读取的是 修改后 的值(当前帧的值)。
|
||||
simEvents.add(new SimEvent(step, key, val, SimEvent.EventType.SET, false)); // 默认为强制覆盖?
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
// log error
|
||||
}
|
||||
}
|
||||
return simEvents;
|
||||
}
|
||||
|
||||
// 3. 构建影响关系 (SimInfluenceNode)
|
||||
public List<SimInfluenceNode> buildInfluenceNodes(Project project) {
|
||||
List<SimInfluenceNode> nodes = new ArrayList<>();
|
||||
if (project == null || project.getTopology() == null) return nodes;
|
||||
|
||||
try {
|
||||
JsonNode root = objectMapper.readTree(project.getTopology());
|
||||
JsonNode devicesNode = root.path("devices");
|
||||
if (devicesNode.isArray()) {
|
||||
for (JsonNode dn : devicesNode) {
|
||||
String deviceId = optText(dn, "deviceId");
|
||||
if (deviceId == null) continue;
|
||||
|
||||
// Device Properties Influence
|
||||
parseInfluence(dn.path("properties"), deviceId, nodes);
|
||||
|
||||
// Material Properties Influence (mapped to device)
|
||||
JsonNode mats = dn.path("materials");
|
||||
if (mats.isMissingNode() || mats.isNull()) mats = dn.path("material");
|
||||
if (mats.isObject()) {
|
||||
parseInfluence(mats.path("properties"), deviceId, nodes);
|
||||
} else if (mats.isArray()) {
|
||||
for (JsonNode mn : mats) {
|
||||
parseInfluence(mn.path("properties"), deviceId, nodes);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Build influence nodes failed", e);
|
||||
}
|
||||
return nodes;
|
||||
}
|
||||
|
||||
private void parseInfluence(JsonNode props, String targetUnitId, List<SimInfluenceNode> nodes) {
|
||||
if (!props.isObject()) return;
|
||||
props.fieldNames().forEachRemaining(propName -> {
|
||||
JsonNode prop = props.path(propName);
|
||||
if ("influence".equalsIgnoreCase(optText(prop, "type"))) {
|
||||
double bias = prop.path("bias").asDouble(0.0);
|
||||
List<SimInfluenceSource> sources = new ArrayList<>();
|
||||
JsonNode srcs = prop.path("sources");
|
||||
if (srcs.isArray()) {
|
||||
for (JsonNode sNode : srcs) {
|
||||
String seId = optText(sNode, "entityId"); // source unit id
|
||||
String seProp = optText(sNode, "property");
|
||||
double coef = sNode.path("coefficient").asDouble(1.0);
|
||||
|
||||
long delayMs = 0L;
|
||||
JsonNode delay = sNode.path("delay");
|
||||
if (delay.path("enabled").asBoolean(false)) {
|
||||
long t = delay.path("time").asLong(0L);
|
||||
String u = optText(delay, "unit");
|
||||
delayMs = toMillis(t, u);
|
||||
}
|
||||
int delayStep = (int) (delayMs / 1000); // 假设 1s = 1 step
|
||||
|
||||
sources.add(new SimInfluenceSource(SimPropertyKey.of(seId, seProp), coef, delayStep));
|
||||
}
|
||||
}
|
||||
nodes.add(new SimInfluenceNode(SimPropertyKey.of(targetUnitId, propName), sources, bias));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Helpers
|
||||
private Map<String, String> buildDeviceMaterialMap(JsonNode root) {
|
||||
Map<String, String> map = new HashMap<>();
|
||||
JsonNode devicesNode = root.path("devices");
|
||||
if (devicesNode.isArray()) {
|
||||
for (JsonNode dn : devicesNode) {
|
||||
String did = optText(dn, "deviceId");
|
||||
if (did == null) continue;
|
||||
JsonNode mats = dn.path("materials");
|
||||
if (mats.isMissingNode() || mats.isNull()) mats = dn.path("material");
|
||||
|
||||
String mid = null;
|
||||
if (mats.isArray() && mats.size() > 0) mid = optText(mats.get(0), "materialId");
|
||||
else if (mats.isObject()) mid = optText(mats, "materialId");
|
||||
|
||||
if (mid != null) map.put(did, mid);
|
||||
}
|
||||
}
|
||||
return map;
|
||||
}
|
||||
|
||||
private Map<String, Map<String, Double>> buildMaterialStaticFromDb(Map<String, String> devToMat, MaterialService materialService) {
|
||||
Map<String, Map<String, Double>> out = new HashMap<>();
|
||||
if (devToMat.isEmpty()) return out;
|
||||
Set<String> mids = new HashSet<>(devToMat.values());
|
||||
List<Material> mats = materialService.list(new QueryWrapper<Material>().in("material_id", mids));
|
||||
for (Material m : mats) {
|
||||
Map<String, Double> s = new HashMap<>();
|
||||
if (m.getUConcentration() != null) s.put("u_concentration", m.getUConcentration().doubleValue());
|
||||
if (m.getPuConcentration() != null) s.put("pu_concentration", m.getPuConcentration().doubleValue());
|
||||
if (m.getUEnrichment() != null) s.put("u_enrichment", m.getUEnrichment().doubleValue());
|
||||
if (m.getPuIsotope() != null) s.put("pu_isotope", m.getPuIsotope().doubleValue());
|
||||
if (m.getUo2Density() != null) s.put("uo2_density", m.getUo2Density().doubleValue());
|
||||
if (m.getPuo2Density() != null) s.put("puo2_density", m.getPuo2Density().doubleValue());
|
||||
if (m.getHno3Acidity() != null) s.put("hno3_acidity", m.getHno3Acidity().doubleValue());
|
||||
if (m.getH2c2o4Concentration() != null) s.put("h2c2o4_concentration", m.getH2c2o4Concentration().doubleValue());
|
||||
if (m.getOrganicRatio() != null) s.put("organic_ratio", m.getOrganicRatio().doubleValue());
|
||||
if (m.getMoistureContent() != null) s.put("moisture_content", m.getMoistureContent().doubleValue());
|
||||
// ... 其他属性
|
||||
out.put(m.getMaterialId(), s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
private String optText(JsonNode node, String key) {
|
||||
if (node.has(key)) return node.get(key).asText();
|
||||
return null;
|
||||
}
|
||||
|
||||
private double parseDouble(String s) {
|
||||
try { return Double.parseDouble(s); } catch (Exception e) { return 0.0; }
|
||||
}
|
||||
|
||||
private long toMillis(long time, String unit) {
|
||||
if ("s".equalsIgnoreCase(unit)) return time * 1000;
|
||||
if ("ms".equalsIgnoreCase(unit)) return time;
|
||||
return time * 1000; // default s
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,14 @@
|
||||
package com.yfd.business.css.config;
|
||||
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
|
||||
@Configuration
|
||||
public class RestTemplateConfig {
|
||||
|
||||
@Bean
|
||||
public RestTemplate restTemplate() {
|
||||
return new RestTemplate();
|
||||
}
|
||||
}
|
||||
@ -154,7 +154,7 @@ public class EventController {
|
||||
new QueryWrapper<Event>()
|
||||
.select("event_id","scenario_id","device_id","material_id","attr_changes","trigger_time","created_at","modifier")
|
||||
.eq("scenario_id", scenarioId)
|
||||
.orderByDesc("created_at")
|
||||
.orderByAsc("created_at")
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@ -0,0 +1,76 @@
|
||||
package com.yfd.business.css.controller;
|
||||
|
||||
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
||||
import com.yfd.business.css.domain.ModelTrainTask;
|
||||
import com.yfd.business.css.service.ModelTrainService;
|
||||
import com.yfd.platform.config.ResponseResult;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@RestController
|
||||
@RequestMapping("/train")
|
||||
public class ModelTrainController {
|
||||
|
||||
@Autowired
|
||||
private ModelTrainService modelTrainService;
|
||||
|
||||
/**
|
||||
* 上传数据集
|
||||
*/
|
||||
@PostMapping("/upload")
|
||||
public ResponseResult upload(@RequestParam("file") MultipartFile file) {
|
||||
String path = modelTrainService.uploadDataset(file);
|
||||
return ResponseResult.successData(path);
|
||||
}
|
||||
|
||||
/**
|
||||
* 提交训练任务
|
||||
*/
|
||||
@PostMapping("/submit")
|
||||
public ResponseResult submit(@RequestBody ModelTrainTask task) {
|
||||
String taskId = modelTrainService.submitTask(task);
|
||||
return ResponseResult.successData(taskId);
|
||||
}
|
||||
|
||||
/**
|
||||
* 查询任务列表
|
||||
*/
|
||||
@GetMapping("/list")
|
||||
public ResponseResult list(@RequestParam(defaultValue = "1") Integer current,
|
||||
@RequestParam(defaultValue = "10") Integer size) {
|
||||
Page<ModelTrainTask> page = modelTrainService.page(new Page<>(current, size));
|
||||
return ResponseResult.successData(page);
|
||||
}
|
||||
|
||||
/**
|
||||
* 查询任务详情/状态
|
||||
*/
|
||||
@GetMapping("/status/{taskId}")
|
||||
public ResponseResult status(@PathVariable String taskId) {
|
||||
ModelTrainTask task = modelTrainService.syncTaskStatus(taskId);
|
||||
return ResponseResult.successData(task);
|
||||
}
|
||||
|
||||
/**
|
||||
* 发布模型
|
||||
*/
|
||||
@PostMapping("/publish")
|
||||
public ResponseResult publish(@RequestBody Map<String, String> body) {
|
||||
String taskId = body.get("taskId");
|
||||
String versionTag = body.get("versionTag");
|
||||
boolean success = modelTrainService.publishModel(taskId, versionTag);
|
||||
return success ? ResponseResult.success() : ResponseResult.error("发布失败");
|
||||
}
|
||||
|
||||
/**
|
||||
* 删除训练任务
|
||||
*/
|
||||
@DeleteMapping("/{taskId}")
|
||||
public ResponseResult delete(@PathVariable String taskId) {
|
||||
boolean success = modelTrainService.removeById(taskId);
|
||||
return success ? ResponseResult.success() : ResponseResult.error("删除失败");
|
||||
}
|
||||
}
|
||||
@ -1,45 +1,61 @@
|
||||
package com.yfd.business.css.controller;
|
||||
|
||||
import com.yfd.business.css.build.SimBuilder;
|
||||
import com.yfd.business.css.facade.SimDataFacade;
|
||||
import com.yfd.business.css.model.*;
|
||||
import com.yfd.business.css.service.MaterialService;
|
||||
import com.yfd.business.css.service.SimService;
|
||||
import com.yfd.business.css.service.SimInferService;
|
||||
import com.yfd.platform.config.ResponseResult;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* 仿真服务统一入口控制器
|
||||
* 负责接收仿真请求,协调 Facade、Builder 和 Service 完成仿真计算
|
||||
*/
|
||||
@RestController
|
||||
@RequestMapping("/sim")
|
||||
public class SimController {
|
||||
|
||||
// private final ProjectRepository projectRepo;
|
||||
// private final EventRepository eventRepo;
|
||||
// private final InfluenceRepository influenceRepo;
|
||||
// private final SimService simService;
|
||||
@Autowired private SimDataFacade simDataFacade;
|
||||
@Autowired private SimBuilder simBuilder;
|
||||
@Autowired private SimService simService;
|
||||
@Autowired private MaterialService materialService;
|
||||
@Autowired private SimInferService simInferService;
|
||||
|
||||
// public SimController(ProjectRepository projectRepo,
|
||||
// EventRepository eventRepo,
|
||||
// InfluenceRepository influenceRepo,
|
||||
// SimService simService) {
|
||||
// this.projectRepo = projectRepo;
|
||||
// this.eventRepo = eventRepo;
|
||||
// this.influenceRepo = influenceRepo;
|
||||
// this.simService = simService;
|
||||
// }
|
||||
/**
|
||||
* 执行仿真计算
|
||||
*
|
||||
* @param req 请求参数,包含 projectId, scenarioId, steps
|
||||
* @return 仿真结果,包含 code, msg, data (data.frames)
|
||||
*/
|
||||
@PostMapping("/run")
|
||||
public ResponseResult run(@RequestBody Map<String, Object> req) {
|
||||
String projectId = (String) req.get("projectId");
|
||||
String scenarioId = (String) req.get("scenarioId");
|
||||
int steps = req.containsKey("steps") ? (int) req.get("steps") : 10;
|
||||
|
||||
// @PostMapping("/run")
|
||||
// public Map<String,Object> runSimulation(@RequestParam String projectId,
|
||||
// @RequestParam String scenarioId,
|
||||
// @RequestParam int steps) {
|
||||
// 1. Load Data: 获取项目、设备和事件数据
|
||||
SimDataPackage data = simDataFacade.loadSimulationData(projectId, scenarioId);
|
||||
|
||||
// List<Map<String, String>> topoRows = projectRepo.loadTopo(projectId);
|
||||
// List<Map<String, Object>> attrChanges = eventRepo.loadAttrChanges(projectId, scenarioId);
|
||||
// List<Map<String, Object>> influenceRows = influenceRepo.loadInfluences(projectId);
|
||||
// 2. Build Model: 将原始数据转换为仿真模型 (Units, Events, Nodes)
|
||||
List<SimUnit> units = simBuilder.buildUnits(data, materialService);
|
||||
List<SimEvent> events = simBuilder.buildEvents(data.getEvents());
|
||||
List<SimInfluenceNode> nodes = simBuilder.buildInfluenceNodes(data.getProject());
|
||||
|
||||
// List<SimUnit> units = SimBuilder.buildUnits(topoRows);
|
||||
// List<SimEvent> events = SimBuilder.buildEvents(attrChanges);
|
||||
// List<SimInfluenceNode> nodes = SimBuilder.buildInfluenceNodes(influenceRows);
|
||||
// 3. Run Engine: 执行核心仿真计算,返回上下文结果
|
||||
SimContext ctx = simService.runSimulation(units, events, nodes, steps);
|
||||
|
||||
// SimContext ctx = simService.runSimulation(units, events, nodes, steps);
|
||||
// 4. Async Infer: 异步执行推理并保存结果
|
||||
simInferService.asyncInferAndSave(projectId, scenarioId, ctx, units);
|
||||
|
||||
// // 转换成推理接口 JSON
|
||||
// return InferenceConverter.toInferenceInput(ctx, units, projectId, scenarioId);
|
||||
// }
|
||||
// 5. Convert Result: 将仿真结果转换为前端友好的格式,包含静态属性补全和元数据
|
||||
Map<String, Object> resultData = SimResultConverter.toFrames(ctx, units, projectId, scenarioId);
|
||||
|
||||
return ResponseResult.successData(resultData);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -0,0 +1,58 @@
|
||||
package com.yfd.business.css.domain;
|
||||
|
||||
import com.baomidou.mybatisplus.annotation.*;
|
||||
import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler;
|
||||
import lombok.Data;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
@TableName(value = "model_train_task", autoResultMap = true)
|
||||
public class ModelTrainTask implements Serializable {
|
||||
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
@TableId(value = "task_id", type = IdType.ASSIGN_UUID)
|
||||
private String taskId;
|
||||
|
||||
@TableField("task_name")
|
||||
private String taskName;
|
||||
|
||||
@TableField("algorithm_type")
|
||||
private String algorithmType;
|
||||
|
||||
@TableField("device_type")
|
||||
private String deviceType;
|
||||
|
||||
@TableField("dataset_path")
|
||||
private String datasetPath;
|
||||
|
||||
@TableField(value = "train_params", typeHandler = JacksonTypeHandler.class)
|
||||
private Map<String, Object> trainParams; // 使用 Map 存储 JSON
|
||||
|
||||
@TableField("status")
|
||||
private String status; // PENDING, TRAINING, SUCCESS, FAILED
|
||||
|
||||
@TableField(value = "metrics", typeHandler = JacksonTypeHandler.class)
|
||||
private Map<String, Object> metrics;
|
||||
|
||||
@TableField("model_output_path")
|
||||
private String modelOutputPath;
|
||||
|
||||
@TableField(value = "feature_map_snapshot", typeHandler = JacksonTypeHandler.class)
|
||||
private Map<String, Object> featureMapSnapshot;
|
||||
|
||||
@TableField("metrics_image_path")
|
||||
private String metricsImagePath;
|
||||
|
||||
@TableField("error_log")
|
||||
private String errorLog;
|
||||
|
||||
@TableField(value = "created_at", fill = FieldFill.INSERT)
|
||||
private LocalDateTime createdAt;
|
||||
|
||||
@TableField(value = "updated_at", fill = FieldFill.INSERT_UPDATE)
|
||||
private LocalDateTime updatedAt;
|
||||
}
|
||||
@ -7,6 +7,7 @@ import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import lombok.Data;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.math.BigDecimal;
|
||||
import java.time.LocalDateTime;
|
||||
|
||||
@Data
|
||||
@ -41,4 +42,16 @@ public class Scenario implements Serializable {
|
||||
|
||||
@TableField("algorithm_type")
|
||||
private String algorithmType;
|
||||
|
||||
/**
|
||||
* Keff预警阈值
|
||||
*/
|
||||
@TableField("keff_threshold")
|
||||
private BigDecimal keffThreshold;
|
||||
|
||||
/**
|
||||
* 设备算法配置映射,格式:{"deviceId1":"GPR", "deviceId2":"MLP"}
|
||||
*/
|
||||
@TableField("device_algo_config")
|
||||
private String deviceAlgoConfig;
|
||||
}
|
||||
|
||||
@ -0,0 +1,39 @@
|
||||
package com.yfd.business.css.facade;
|
||||
|
||||
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||
import com.yfd.business.css.domain.Device;
|
||||
import com.yfd.business.css.domain.Event;
|
||||
import com.yfd.business.css.domain.Project;
|
||||
import com.yfd.business.css.model.SimDataPackage;
|
||||
import com.yfd.business.css.service.DeviceService;
|
||||
import com.yfd.business.css.service.EventService;
|
||||
import com.yfd.business.css.service.ProjectService;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 仿真数据门面
|
||||
* 负责与各个业务Service交互,获取仿真所需的原始数据 (Project, Events 等)
|
||||
*/
|
||||
@Component
|
||||
public class SimDataFacade {
|
||||
@Autowired private ProjectService projectService;
|
||||
@Autowired private EventService eventService;
|
||||
@Autowired private DeviceService deviceService;
|
||||
|
||||
public SimDataPackage loadSimulationData(String projectId, String scenarioId) {
|
||||
// 1. 获取项目与拓扑
|
||||
Project project = projectService.getById(projectId);
|
||||
if (project == null) throw new IllegalArgumentException("Project not found: " + projectId);
|
||||
|
||||
// 2. 获取设备列表 (用于解析静态尺寸属性)
|
||||
List<Device> devices = deviceService.list(new QueryWrapper<Device>().eq("project_id", projectId));
|
||||
|
||||
// 3. 获取事件
|
||||
List<Event> events = eventService.list(new QueryWrapper<Event>().eq("scenario_id", scenarioId));
|
||||
|
||||
return new SimDataPackage(project, devices, events);
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,9 @@
|
||||
package com.yfd.business.css.mapper;
|
||||
|
||||
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
|
||||
import com.yfd.business.css.domain.ModelTrainTask;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
|
||||
@Mapper
|
||||
public interface ModelTrainTaskMapper extends BaseMapper<ModelTrainTask> {
|
||||
}
|
||||
@ -0,0 +1,40 @@
|
||||
package com.yfd.business.css.model;
|
||||
|
||||
import com.yfd.business.css.domain.Device;
|
||||
import com.yfd.business.css.domain.Event;
|
||||
import com.yfd.business.css.domain.Project;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 仿真数据包
|
||||
* 封装从数据库加载的原始项目、设备和事件数据
|
||||
*/
|
||||
public class SimDataPackage {
|
||||
private final Project project;
|
||||
private final List<Device> devices;
|
||||
private final List<Event> events;
|
||||
|
||||
public SimDataPackage(Project project, List<Event> events) {
|
||||
this(project, Collections.emptyList(), events);
|
||||
}
|
||||
|
||||
public SimDataPackage(Project project, List<Device> devices, List<Event> events) {
|
||||
this.project = project;
|
||||
this.devices = devices;
|
||||
this.events = events;
|
||||
}
|
||||
|
||||
public Project getProject() {
|
||||
return project;
|
||||
}
|
||||
|
||||
public List<Device> getDevices() {
|
||||
return devices;
|
||||
}
|
||||
|
||||
public List<Event> getEvents() {
|
||||
return events;
|
||||
}
|
||||
}
|
||||
@ -7,16 +7,23 @@ public class SimEvent {
|
||||
private final SimPropertyKey key;
|
||||
private final double value;
|
||||
private final EventType type;
|
||||
private final boolean isOverride; // 新增:是否强制覆盖
|
||||
|
||||
public SimEvent(int step, SimPropertyKey key, double value, EventType type) {
|
||||
public SimEvent(int step, SimPropertyKey key, double value, EventType type, boolean isOverride) {
|
||||
this.step = step;
|
||||
this.key = key;
|
||||
this.value = value;
|
||||
this.type = type;
|
||||
this.isOverride = isOverride;
|
||||
}
|
||||
|
||||
public SimEvent(int step, SimPropertyKey key, double value, EventType type) {
|
||||
this(step, key, value, type, false); // 默认为非强制
|
||||
}
|
||||
|
||||
public int getStep() { return step; }
|
||||
public SimPropertyKey getKey() { return key; }
|
||||
public double getValue() { return value; }
|
||||
public EventType getType() { return type; }
|
||||
public boolean isOverride() { return isOverride; }
|
||||
}
|
||||
|
||||
@ -0,0 +1,72 @@
|
||||
package com.yfd.business.css.model;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* 仿真结果转换器
|
||||
* 负责将 SimContext 中的仿真状态转换为前端友好的帧结构
|
||||
*/
|
||||
public class SimResultConverter {
|
||||
|
||||
/**
|
||||
* 将仿真上下文转换为帧列表
|
||||
* @param ctx 仿真上下文
|
||||
* @param units 仿真单元列表 (用于补充设备类型等元数据)
|
||||
* @param projectId 项目ID
|
||||
* @param scenarioId 情景ID
|
||||
* @return 包含完整结果数据的 Map
|
||||
*/
|
||||
public static Map<String, Object> toFrames(SimContext ctx, List<SimUnit> units, String projectId, String scenarioId) {
|
||||
List<Map<String, Object>> frames = new ArrayList<>();
|
||||
Map<String, SimUnit> unitMap = new HashMap<>();
|
||||
for (SimUnit u : units) {
|
||||
unitMap.put(u.unitId(), u);
|
||||
}
|
||||
|
||||
Map<Integer, Map<SimPropertyKey, Double>> timeline = ctx.getTimeline();
|
||||
List<Integer> steps = new ArrayList<>(timeline.keySet());
|
||||
Collections.sort(steps);
|
||||
|
||||
for (Integer step : steps) {
|
||||
Map<SimPropertyKey, Double> snapshot = timeline.get(step);
|
||||
Map<String, Map<String, Object>> devices = new HashMap<>();
|
||||
|
||||
// 1. 先初始化所有设备及其静态属性
|
||||
for (SimUnit unit : units) {
|
||||
Map<String, Object> devState = new HashMap<>();
|
||||
devState.put("deviceType", unit.deviceType());
|
||||
// 注入所有静态属性作为基线
|
||||
if (unit.staticProperties() != null) {
|
||||
devState.putAll(unit.staticProperties());
|
||||
}
|
||||
devices.put(unit.unitId(), devState);
|
||||
}
|
||||
|
||||
// 2. 覆盖动态计算出的属性值
|
||||
for (Map.Entry<SimPropertyKey, Double> entry : snapshot.entrySet()) {
|
||||
String unitId = entry.getKey().unitId();
|
||||
String prop = entry.getKey().property();
|
||||
Double val = entry.getValue();
|
||||
|
||||
Map<String, Object> devState = devices.get(unitId);
|
||||
if (devState == null) continue; // 理论上不应该发生,除非有未定义的unit
|
||||
|
||||
devState.put(prop, val);
|
||||
}
|
||||
|
||||
Map<String, Object> frame = new HashMap<>();
|
||||
frame.put("step", step);
|
||||
frame.put("time", step); // 假设 1 step = 1 time unit
|
||||
frame.put("devices", devices);
|
||||
frames.add(frame);
|
||||
}
|
||||
|
||||
Map<String, Object> data = new HashMap<>();
|
||||
data.put("frames", frames);
|
||||
data.put("projectId", projectId);
|
||||
data.put("scenarioId", scenarioId);
|
||||
data.put("generated", Map.of("snapshots", steps.size()));
|
||||
|
||||
return data;
|
||||
}
|
||||
}
|
||||
@ -1,3 +1,9 @@
|
||||
package com.yfd.business.css.model;
|
||||
|
||||
public record SimUnit(String unitId, String deviceId, String materialId, String deviceType) {}
|
||||
import java.util.Map;
|
||||
|
||||
public record SimUnit(String unitId, String deviceId, String materialId, String deviceType, Map<String, Double> staticProperties) {
|
||||
public SimUnit(String unitId, String deviceId, String materialId, String deviceType) {
|
||||
this(unitId, deviceId, materialId, deviceType, Map.of());
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
package com.yfd.business.css.service;
|
||||
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.yfd.business.css.domain.Scenario;
|
||||
import com.yfd.business.css.domain.ScenarioResult;
|
||||
import com.yfd.business.css.model.DeviceStepInfo;
|
||||
import com.yfd.business.css.model.InferRequest;
|
||||
@ -18,6 +20,7 @@ import org.springframework.web.client.RestTemplate;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
@Service
|
||||
public class DeviceInferService {
|
||||
@Value("${python.api.url:http://localhost:8000}")
|
||||
@ -35,63 +38,94 @@ public class DeviceInferService {
|
||||
|
||||
public void processDeviceInference(String projectId, String scenarioId,
|
||||
Map<String, List<DeviceStepInfo>> groupedDevices) {
|
||||
// 遍历每个设备类型,调用对应的模型进行推理
|
||||
// 1. 获取情景配置信息
|
||||
Scenario scenario = scenarioService.getById(scenarioId);
|
||||
if (scenario == null) {
|
||||
throw new IllegalArgumentException("场景 " + scenarioId + " 不存在");
|
||||
}
|
||||
|
||||
// 全局算法类型
|
||||
String globalAlgorithmType = scenario.getAlgorithmType();
|
||||
if (globalAlgorithmType == null) {
|
||||
throw new IllegalArgumentException("场景 " + scenarioId + " 未配置全局算法类型");
|
||||
}
|
||||
|
||||
// 解析设备级算法配置
|
||||
Map<String, String> deviceAlgoConfig = new HashMap<>();
|
||||
if (scenario.getDeviceAlgoConfig() != null && !scenario.getDeviceAlgoConfig().isBlank()) {
|
||||
try {
|
||||
deviceAlgoConfig = objectMapper.readValue(scenario.getDeviceAlgoConfig(), new TypeReference<Map<String, String>>() {});
|
||||
} catch (Exception e) {
|
||||
System.err.println("解析设备算法配置失败: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 遍历每个设备类型组
|
||||
for (Map.Entry<String, List<DeviceStepInfo>> entry : groupedDevices.entrySet()) {
|
||||
String deviceType = entry.getKey();
|
||||
// algorithmType通过scenarioId获取
|
||||
String algorithmType = scenarioService.getAlgorithmType(scenarioId);
|
||||
if (algorithmType == null) {
|
||||
throw new IllegalArgumentException("场景 " + scenarioId + " 未配置算法类型");
|
||||
}
|
||||
|
||||
//modelPath根据模型类型、设备类型,从algorithm_model获取活动的model_path
|
||||
String modelPath = algorithmModelService.getCurrentModelPath(algorithmType, deviceType);
|
||||
System.out.println("modelPath="+modelPath);
|
||||
if (modelPath == null) {
|
||||
throw new IllegalArgumentException("未配置 " + algorithmType + " 模型路径");
|
||||
}
|
||||
List<DeviceStepInfo> devices = entry.getValue();
|
||||
System.out.println("devices="+devices);
|
||||
// 校验设备数据是否完整
|
||||
if (devices == null || devices.isEmpty()) {
|
||||
throw new IllegalArgumentException("设备数据为空,无法进行模拟");
|
||||
|
||||
if (devices == null || devices.isEmpty()) continue;
|
||||
|
||||
// 3. 将同一类型的设备,按算法类型进行二级分组
|
||||
Map<String, List<DeviceStepInfo>> algoGroup = new HashMap<>();
|
||||
|
||||
for (DeviceStepInfo device : devices) {
|
||||
// 优先使用设备特定配置,否则使用全局配置
|
||||
String algoType = deviceAlgoConfig.getOrDefault(device.getDeviceId(), globalAlgorithmType);
|
||||
algoGroup.computeIfAbsent(algoType, k -> new ArrayList<>()).add(device);
|
||||
}
|
||||
|
||||
// 封装推理请求
|
||||
InferRequest request = buildInferenceRequest(deviceType,devices,modelPath);
|
||||
System.out.println("request="+request);
|
||||
// 4. 对每个算法分组进行推理
|
||||
for (Map.Entry<String, List<DeviceStepInfo>> algoEntry : algoGroup.entrySet()) {
|
||||
String currentAlgoType = algoEntry.getKey();
|
||||
List<DeviceStepInfo> currentDevices = algoEntry.getValue();
|
||||
|
||||
// 调用Python推理服务
|
||||
InferResponse response = infer(request);
|
||||
System.out.println("推理服务返回结果: " + response);
|
||||
// 获取模型路径
|
||||
System.out.println("Processing inference for algorithmType: " + currentAlgoType + ", deviceType: " + deviceType);
|
||||
String modelPath = algorithmModelService.getCurrentModelPath(currentAlgoType, deviceType);
|
||||
System.out.println("modelPath=" + modelPath);
|
||||
|
||||
// 处理推理结果
|
||||
if (response != null && response.getCode() == 0) {
|
||||
// 重新构建InferResponse对象示例
|
||||
// 1. 从response获取数据
|
||||
int code = response.getCode();
|
||||
String msg = response.getMsg();
|
||||
InferResponse.InferData originalData = response.getData();
|
||||
if (modelPath == null) {
|
||||
System.err.println("Model path not found for algorithmType: " + currentAlgoType + ", deviceType: " + deviceType);
|
||||
// 这里可以选择抛异常中断,或者跳过该组继续处理其他组
|
||||
// 为了保证健壮性,这里选择记录错误并跳过,或者根据业务需求抛出异常
|
||||
// throw new IllegalArgumentException("未配置 " + currentAlgoType + " 模型路径 (deviceType: " + deviceType + ")");
|
||||
continue;
|
||||
}
|
||||
|
||||
// 2. 重新构建InferData
|
||||
InferResponse.InferData newData = new InferResponse.InferData();
|
||||
newData.setItems(originalData.getItems());
|
||||
newData.setMeta(originalData.getMeta());
|
||||
newData.setFeatures(originalData.getFeatures());
|
||||
newData.setKeff(originalData.getKeff());
|
||||
// 封装推理请求
|
||||
InferRequest request = buildInferenceRequest(deviceType, currentDevices, modelPath);
|
||||
// System.out.println("request=" + request);
|
||||
|
||||
// 3. 构建新的InferResponse
|
||||
InferResponse reconstructedResponse = new InferResponse();
|
||||
reconstructedResponse.setCode(code);
|
||||
reconstructedResponse.setMsg(msg);
|
||||
reconstructedResponse.setData(newData);
|
||||
try {
|
||||
// 调用Python推理服务
|
||||
InferResponse response = infer(request);
|
||||
System.out.println("推理服务返回结果: code=" + (response != null ? response.getCode() : "null"));
|
||||
|
||||
System.out.println("重新构建的response: " + reconstructedResponse);
|
||||
// 处理推理结果
|
||||
if (response != null && response.getCode() == 0) {
|
||||
// 重新构建InferResponse对象示例 (为了兼容性保留原有逻辑)
|
||||
InferResponse.InferData originalData = response.getData();
|
||||
InferResponse.InferData newData = new InferResponse.InferData();
|
||||
newData.setItems(originalData.getItems());
|
||||
newData.setMeta(originalData.getMeta());
|
||||
newData.setFeatures(originalData.getFeatures());
|
||||
newData.setKeff(originalData.getKeff());
|
||||
|
||||
// 使用重新构建的response处理结果
|
||||
processInferenceResults(projectId, scenarioId, deviceType, devices, reconstructedResponse);
|
||||
} else {
|
||||
throw new RuntimeException("推理服务调用失败: " + (response != null ? response.getMsg() : "未知错误"));
|
||||
InferResponse reconstructedResponse = new InferResponse();
|
||||
reconstructedResponse.setCode(response.getCode());
|
||||
reconstructedResponse.setMsg(response.getMsg());
|
||||
reconstructedResponse.setData(newData);
|
||||
|
||||
processInferenceResults(projectId, scenarioId, deviceType, currentDevices, reconstructedResponse);
|
||||
} else {
|
||||
System.err.println("推理服务调用失败: " + (response != null ? response.getMsg() : "未知错误"));
|
||||
}
|
||||
} catch (Exception e) {
|
||||
System.err.println("推理异常: " + e.getMessage());
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,36 @@
|
||||
package com.yfd.business.css.service;
|
||||
|
||||
import com.baomidou.mybatisplus.extension.service.IService;
|
||||
import com.yfd.business.css.domain.ModelTrainTask;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
|
||||
public interface ModelTrainService extends IService<ModelTrainTask> {
|
||||
/**
|
||||
* 上传数据集
|
||||
* @param file 文件
|
||||
* @return 文件保存路径
|
||||
*/
|
||||
String uploadDataset(MultipartFile file);
|
||||
|
||||
/**
|
||||
* 提交训练任务
|
||||
* @param task 任务信息
|
||||
* @return 任务ID
|
||||
*/
|
||||
String submitTask(ModelTrainTask task);
|
||||
|
||||
/**
|
||||
* 同步任务状态
|
||||
* @param taskId 任务ID
|
||||
* @return 最新任务信息
|
||||
*/
|
||||
ModelTrainTask syncTaskStatus(String taskId);
|
||||
|
||||
/**
|
||||
* 发布模型
|
||||
* @param taskId 任务ID
|
||||
* @param versionTag 版本号
|
||||
* @return 是否成功
|
||||
*/
|
||||
boolean publishModel(String taskId, String versionTag);
|
||||
}
|
||||
@ -0,0 +1,90 @@
|
||||
package com.yfd.business.css.service;
|
||||
|
||||
import com.yfd.business.css.model.*;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* 仿真推理服务
|
||||
* 负责将仿真计算结果 (SimContext) 转换为推理请求,并异步调用 Python 推理服务入库
|
||||
*/
|
||||
@Service
|
||||
public class SimInferService {
|
||||
|
||||
@Autowired
|
||||
private DeviceInferService deviceInferService;
|
||||
|
||||
@Async
|
||||
public void asyncInferAndSave(String projectId, String scenarioId, SimContext context, List<SimUnit> units) {
|
||||
try {
|
||||
// 1. 准备元数据映射: unitId -> SimUnit
|
||||
Map<String, SimUnit> unitMap = units.stream()
|
||||
.collect(Collectors.toMap(SimUnit::unitId, u -> u));
|
||||
|
||||
// 2. 转换 SimContext 为按 DeviceType 分组的 DeviceStepInfo 列表
|
||||
Map<String, List<DeviceStepInfo>> groupedDevices = new HashMap<>();
|
||||
|
||||
// 遍历时间轴
|
||||
Map<Integer, Map<SimPropertyKey, Double>> timeline = context.getTimeline();
|
||||
for (Map.Entry<Integer, Map<SimPropertyKey, Double>> entry : timeline.entrySet()) {
|
||||
int step = entry.getKey();
|
||||
Map<SimPropertyKey, Double> props = entry.getValue();
|
||||
|
||||
// 按 unitId 聚合属性
|
||||
Map<String, Map<String, Object>> unitProps = new HashMap<>();
|
||||
for (Map.Entry<SimPropertyKey, Double> propEntry : props.entrySet()) {
|
||||
SimPropertyKey key = propEntry.getKey();
|
||||
String unitId = key.unitId();
|
||||
String propName = key.property();
|
||||
Double value = propEntry.getValue();
|
||||
|
||||
unitProps.computeIfAbsent(unitId, k -> new HashMap<>()).put(propName, value);
|
||||
}
|
||||
|
||||
// 构建 DeviceStepInfo
|
||||
for (Map.Entry<String, Map<String, Object>> unitEntry : unitProps.entrySet()) {
|
||||
String unitId = unitEntry.getKey();
|
||||
Map<String, Object> properties = unitEntry.getValue();
|
||||
|
||||
SimUnit unit = unitMap.get(unitId);
|
||||
if (unit == null) continue; // 可能是中间变量或非设备单元
|
||||
|
||||
String deviceType = unit.deviceType();
|
||||
if (deviceType == null || deviceType.isEmpty()) continue;
|
||||
|
||||
// 确保静态属性也包含在内(如果 SimContext 中没有,从 SimUnit 补充)
|
||||
// SimService 应该已经初始化了静态属性到 Context,但为了保险起见:
|
||||
for (Map.Entry<String, Double> staticProp : unit.staticProperties().entrySet()) {
|
||||
properties.putIfAbsent(staticProp.getKey(), staticProp.getValue());
|
||||
}
|
||||
|
||||
DeviceStepInfo info = new DeviceStepInfo();
|
||||
info.setDeviceId(unit.deviceId()); // unitId 通常等于 deviceId
|
||||
info.setDeviceType(deviceType);
|
||||
info.setProperties(properties);
|
||||
info.setStep(step);
|
||||
info.setTime(step); // 假设 time = step,或者根据步长计算
|
||||
|
||||
groupedDevices.computeIfAbsent(deviceType, k -> new ArrayList<>()).add(info);
|
||||
}
|
||||
}
|
||||
|
||||
if (groupedDevices.isEmpty()) {
|
||||
System.out.println("No device data found for inference.");
|
||||
return;
|
||||
}
|
||||
|
||||
// 3. 调用 DeviceInferService 进行推理和入库
|
||||
// 复用现有的 processDeviceInference 方法,它处理了按类型分组的数据
|
||||
deviceInferService.processDeviceInference(projectId, scenarioId, groupedDevices);
|
||||
|
||||
} catch (Exception e) {
|
||||
System.err.println("Async inference failed: " + e.getMessage());
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,11 +1,28 @@
|
||||
package com.yfd.business.css.service;
|
||||
|
||||
import com.yfd.business.css.model.*;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* 仿真核心计算引擎
|
||||
* 负责执行时间步推进,处理静态属性、事件输入、影响传播和强制覆盖
|
||||
*/
|
||||
@Service
|
||||
public class SimService {
|
||||
|
||||
/**
|
||||
* 运行仿真
|
||||
*
|
||||
* @param units 仿真单元列表 (设备/物料)
|
||||
* @param events 仿真事件列表 (属性变更)
|
||||
* @param nodes 影响关系节点列表
|
||||
* @param steps 仿真总步数
|
||||
* @return 仿真上下文,包含所有时间步的状态快照
|
||||
*/
|
||||
public SimContext runSimulation(List<SimUnit> units,
|
||||
List<SimEvent> events,
|
||||
List<SimInfluenceNode> nodes,
|
||||
@ -13,41 +30,96 @@ public class SimService {
|
||||
|
||||
SimContext ctx = new SimContext();
|
||||
|
||||
// 初始化设备/物料属性
|
||||
// Step 1: 初始化静态基线 (t=0)
|
||||
// 将所有单元的静态属性写入初始上下文
|
||||
for (SimUnit unit : units) {
|
||||
ctx.ensureProperty(SimPropertyKey.of(unit.unitId(), "device.power"));
|
||||
ctx.ensureProperty(SimPropertyKey.of(unit.unitId(), "material.quantity"));
|
||||
ctx.ensureProperty(SimPropertyKey.of(unit.unitId(), "keff")); // 占位
|
||||
unit.staticProperties().forEach((k, v) ->
|
||||
ctx.setValue(SimPropertyKey.of(unit.unitId(), k), v)
|
||||
);
|
||||
}
|
||||
ctx.snapshot(0); // 记录初始状态
|
||||
|
||||
// 每步仿真
|
||||
// Step 2: 循环推进 (t=1 to steps)
|
||||
for (int step = 1; step <= steps; step++) {
|
||||
// 2.1 Event (Input): 应用普通事件 (作为输入条件)
|
||||
applyEvents(ctx, events, step, false);
|
||||
|
||||
// 事件
|
||||
for (SimEvent e : events) {
|
||||
if (e.getStep() == step) {
|
||||
ctx.setValue(e.getKey(), e.getValue());
|
||||
}
|
||||
}
|
||||
// 2.2 Influence: 计算影响关系
|
||||
// 基于当前 ctx (包含 static + input event) 计算派生属性
|
||||
applyInfluences(ctx, nodes, step);
|
||||
|
||||
// 影响关系
|
||||
for (SimInfluenceNode node : nodes) {
|
||||
double sum = node.getBias();
|
||||
for (SimInfluenceSource s : node.getSources()) {
|
||||
sum += ctx.getValue(s.getSource()) * s.getCoeff();
|
||||
}
|
||||
ctx.setValue(node.getTarget(), sum);
|
||||
}
|
||||
|
||||
// keff 占位(后续可以替换为推理接口计算)
|
||||
for (SimUnit unit : units) {
|
||||
ctx.setValue(SimPropertyKey.of(unit.unitId(), "keff"), 0.0);
|
||||
}
|
||||
// 2.3 Event (Override): 应用强制覆盖事件
|
||||
// 再次覆盖,确保强制逻辑生效 (如故障注入、手动锁定)
|
||||
applyEvents(ctx, events, step, true);
|
||||
|
||||
// 2.4 Snapshot: 保存当前步的状态快照
|
||||
ctx.snapshot(step);
|
||||
}
|
||||
|
||||
return ctx;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 应用事件
|
||||
* @param ctx 上下文
|
||||
* @param events 事件列表
|
||||
* @param step 当前步
|
||||
* @param isOverride 是否为强制覆盖事件
|
||||
*/
|
||||
private void applyEvents(SimContext ctx, List<SimEvent> events, int step, boolean isOverride) {
|
||||
for (SimEvent e : events) {
|
||||
if (e.getStep() == step && e.isOverride() == isOverride) {
|
||||
// 当前仅支持 SET 类型,ADD/MULTIPLY 可按需扩展
|
||||
ctx.setValue(e.getKey(), e.getValue());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 计算影响关系
|
||||
* @param ctx 上下文
|
||||
* @param nodes 影响节点列表
|
||||
* @param currentStep 当前步
|
||||
*/
|
||||
private void applyInfluences(SimContext ctx, List<SimInfluenceNode> nodes, int currentStep) {
|
||||
// 简单实现:直接计算并更新。如果存在依赖环,可能需要多轮迭代或拓扑排序。
|
||||
// 这里假设无环或允许一帧内的即时传播。
|
||||
// 为了避免更新顺序影响结果(A依赖B,先算A还是先算B),理想情况下应使用双缓冲(读取上一帧或本帧快照)。
|
||||
// 但根据需求 "Influence(派生计算):基于当前状态执行影响关系计算",且 ProjectServiceImpl 中是直接读取当前 state。
|
||||
// ProjectServiceImpl 中 readValue 会读取 delay。
|
||||
|
||||
// 我们可以先计算所有变更,再统一应用,避免计算中间态影响后续计算(除非是有意为之的级联)
|
||||
// 这里采用:先计算所有目标值,存入临时 Map,最后统一写入 ctx
|
||||
|
||||
List<Map.Entry<SimPropertyKey, Double>> updates = new ArrayList<>();
|
||||
|
||||
for (SimInfluenceNode node : nodes) {
|
||||
double sum = node.getBias();
|
||||
for (SimInfluenceSource s : node.getSources()) {
|
||||
// 处理延迟: t - delay
|
||||
// SimContext 目前只存储当前值。如果需要历史值,SimContext 需要支持根据 step 获取 snapshot。
|
||||
// 假设 SimContext.getValue(key) 返回当前值。
|
||||
// 如果有 delay,我们需要访问历史 snapshot。
|
||||
|
||||
// 暂时假设无 delay 或 delay=0,读取当前值
|
||||
// 如果 SimInfluenceSource 有 delay 属性,需要 SimContext 支持历史回溯
|
||||
// 修改 SimContext 增加 getValue(key, step) ?
|
||||
// 目前 SimContext 实现未知,假设只能读当前。
|
||||
// 如果需要支持 delay,SimContext 需要保留 history。
|
||||
|
||||
// 简单起见,这里先读当前值。完善版本需要 SimContext 提供 getHistoryValue(key, step - delay)
|
||||
|
||||
// 修正:SimContext 应该能获取历史。我们先看 SimContext 定义。
|
||||
// 假设 SimContext 内部有 snapshots。
|
||||
|
||||
double sourceVal = ctx.getValue(s.getSource()); // 暂读当前
|
||||
sum += sourceVal * s.getCoeff();
|
||||
}
|
||||
updates.add(Map.entry(node.getTarget(), sum));
|
||||
}
|
||||
|
||||
for (Map.Entry<SimPropertyKey, Double> entry : updates) {
|
||||
ctx.setValue(entry.getKey(), entry.getValue());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -16,12 +16,19 @@ public class AlgorithmModelServiceImpl extends ServiceImpl<AlgorithmModelMapper,
|
||||
|
||||
@Override
|
||||
public String getCurrentModelPath(String algorithmType, String deviceType) {
|
||||
System.out.println("Querying current model path for algorithmType: " + algorithmType + ", deviceType: " + deviceType);
|
||||
QueryWrapper<AlgorithmModel> queryWrapper = new QueryWrapper<>();
|
||||
queryWrapper.eq("algorithm_type", algorithmType)
|
||||
.eq("device_type", deviceType)
|
||||
.eq("is_current", 1); // 当前激活版本
|
||||
AlgorithmModel model = getOne(queryWrapper);
|
||||
return model != null ? model.getModelPath() : null;
|
||||
if (model != null) {
|
||||
System.out.println("Found model: " + model.getModelPath());
|
||||
return model.getModelPath();
|
||||
} else {
|
||||
System.out.println("Model not found in database.");
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@ -0,0 +1,228 @@
|
||||
package com.yfd.business.css.service.impl;
|
||||
|
||||
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.yfd.business.css.common.exception.BizException;
|
||||
import com.yfd.business.css.domain.AlgorithmModel;
|
||||
import com.yfd.business.css.domain.ModelTrainTask;
|
||||
import com.yfd.business.css.mapper.ModelTrainTaskMapper;
|
||||
import com.yfd.business.css.service.AlgorithmModelService;
|
||||
import com.yfd.business.css.service.ModelTrainService;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.time.LocalDateTime;
|
||||
import java.time.format.DateTimeFormatter;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
|
||||
@Service
|
||||
public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, ModelTrainTask> implements ModelTrainService {
|
||||
|
||||
@Value("${file-space.upload-path:./data/uploads/}")
|
||||
private String uploadPath;
|
||||
|
||||
@Value("${python.api.url:http://localhost:8000}")
|
||||
private String pythonApiUrl;
|
||||
|
||||
@Autowired
|
||||
private RestTemplate restTemplate;
|
||||
|
||||
@Autowired
|
||||
private AlgorithmModelService algorithmModelService;
|
||||
|
||||
@Autowired
|
||||
private ObjectMapper objectMapper;
|
||||
|
||||
@Override
|
||||
public String uploadDataset(MultipartFile file) {
|
||||
if (file.isEmpty()) {
|
||||
throw new BizException("上传文件不能为空");
|
||||
}
|
||||
|
||||
String originalFilename = file.getOriginalFilename();
|
||||
String dateDir = LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyyMMdd"));
|
||||
String uuid = UUID.randomUUID().toString();
|
||||
String extension = originalFilename.substring(originalFilename.lastIndexOf("."));
|
||||
String newFilename = uuid + extension;
|
||||
|
||||
String saveDir = uploadPath + File.separator + dateDir;
|
||||
File dir = new File(saveDir);
|
||||
if (!dir.exists()) {
|
||||
dir.mkdirs();
|
||||
}
|
||||
|
||||
String fullPath = saveDir + File.separator + newFilename;
|
||||
try {
|
||||
file.transferTo(new File(fullPath));
|
||||
return fullPath;
|
||||
} catch (IOException e) {
|
||||
throw new BizException("文件保存失败: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
@Transactional
|
||||
public String submitTask(ModelTrainTask task) {
|
||||
// 1. 初始化状态
|
||||
task.setStatus("PENDING");
|
||||
if (task.getTaskId() == null) {
|
||||
task.setTaskId(UUID.randomUUID().toString());
|
||||
}
|
||||
this.save(task);
|
||||
|
||||
// 2. 异步调用 Python 训练
|
||||
asyncCallTrain(task);
|
||||
|
||||
return task.getTaskId();
|
||||
}
|
||||
|
||||
@Async
|
||||
public void asyncCallTrain(ModelTrainTask task) {
|
||||
try {
|
||||
// 更新状态为 TRAINING
|
||||
task.setStatus("TRAINING");
|
||||
this.updateById(task);
|
||||
|
||||
// 构建请求参数
|
||||
Map<String, Object> request = new HashMap<>();
|
||||
request.put("task_id", task.getTaskId());
|
||||
request.put("algorithm_type", task.getAlgorithmType());
|
||||
request.put("device_type", task.getDeviceType());
|
||||
request.put("dataset_path", task.getDatasetPath());
|
||||
request.put("hyperparameters", task.getTrainParams());
|
||||
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
HttpEntity<Map<String, Object>> entity = new HttpEntity<>(request, headers);
|
||||
|
||||
String url = pythonApiUrl + "/v1/train";
|
||||
// 这里假设 Python 服务会立即返回,启动后台线程
|
||||
// 如果 Python 服务是阻塞的,这里的 Async 会起作用
|
||||
ResponseEntity<Map> response = restTemplate.postForEntity(url, entity, Map.class);
|
||||
|
||||
if (response.getStatusCode().is2xxSuccessful()) {
|
||||
// 调用成功,等待后续轮询状态
|
||||
System.out.println("训练任务提交成功: " + task.getTaskId());
|
||||
} else {
|
||||
task.setStatus("FAILED");
|
||||
task.setErrorLog("提交训练任务失败: " + response.getStatusCode());
|
||||
this.updateById(task);
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
task.setStatus("FAILED");
|
||||
task.setErrorLog("调用 Python 服务异常: " + e.getMessage());
|
||||
this.updateById(task);
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public ModelTrainTask syncTaskStatus(String taskId) {
|
||||
ModelTrainTask task = this.getById(taskId);
|
||||
if (task == null) {
|
||||
throw new BizException("任务不存在");
|
||||
}
|
||||
|
||||
// 只有在 TRAINING 状态才去查询 Python 服务
|
||||
if ("TRAINING".equals(task.getStatus()) || "PENDING".equals(task.getStatus())) {
|
||||
try {
|
||||
String url = pythonApiUrl + "/v1/train/status/" + taskId;
|
||||
ResponseEntity<Map> response = restTemplate.getForEntity(url, Map.class);
|
||||
|
||||
if (response.getStatusCode().is2xxSuccessful() && response.getBody() != null) {
|
||||
Map<String, Object> body = response.getBody();
|
||||
String status = (String) body.get("status"); // TRAINING, SUCCESS, FAILED
|
||||
|
||||
if (status != null) {
|
||||
task.setStatus(status);
|
||||
|
||||
if ("SUCCESS".equals(status)) {
|
||||
task.setModelOutputPath((String) body.get("model_path"));
|
||||
task.setMetrics((Map<String, Object>) body.get("metrics"));
|
||||
task.setFeatureMapSnapshot((Map<String, Object>) body.get("feature_map"));
|
||||
task.setMetricsImagePath((String) body.get("metrics_image"));
|
||||
} else if ("FAILED".equals(status)) {
|
||||
task.setErrorLog((String) body.get("error"));
|
||||
} else if ("TRAINING".equals(status)) {
|
||||
// 可以更新进度或其他中间指标
|
||||
if (body.containsKey("metrics")) {
|
||||
task.setMetrics((Map<String, Object>) body.get("metrics"));
|
||||
}
|
||||
}
|
||||
|
||||
this.updateById(task);
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
// 查询失败,暂不更新状态,或者记录日志
|
||||
System.err.println("同步任务状态失败: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
return task;
|
||||
}
|
||||
|
||||
@Override
|
||||
@Transactional
|
||||
public boolean publishModel(String taskId, String versionTag) {
|
||||
ModelTrainTask task = this.getById(taskId);
|
||||
if (task == null) {
|
||||
throw new BizException("任务不存在");
|
||||
}
|
||||
|
||||
if (!"SUCCESS".equals(task.getStatus())) {
|
||||
throw new BizException("任务未完成或失败,无法发布");
|
||||
}
|
||||
|
||||
// 检查版本号唯一性
|
||||
long count = algorithmModelService.count(new QueryWrapper<AlgorithmModel>()
|
||||
.eq("algorithm_type", task.getAlgorithmType())
|
||||
.eq("device_type", task.getDeviceType())
|
||||
.eq("version_tag", versionTag));
|
||||
if (count > 0) {
|
||||
throw new BizException("版本号已存在");
|
||||
}
|
||||
|
||||
// 创建正式模型记录
|
||||
AlgorithmModel model = new AlgorithmModel();
|
||||
model.setAlgorithmType(task.getAlgorithmType());
|
||||
model.setDeviceType(task.getDeviceType());
|
||||
model.setVersionTag(versionTag);
|
||||
model.setModelPath(task.getModelOutputPath()); // 这里简化处理,直接引用临时路径,实际生产建议移动文件到正式目录
|
||||
model.setMetricsImagePath(task.getMetricsImagePath());
|
||||
model.setTrainedAt(LocalDateTime.now());
|
||||
model.setIsCurrent(0); // 默认不激活
|
||||
|
||||
try {
|
||||
if (task.getFeatureMapSnapshot() != null) {
|
||||
model.setFeatureMapSnapshot(objectMapper.writeValueAsString(task.getFeatureMapSnapshot()));
|
||||
} else {
|
||||
model.setFeatureMapSnapshot("{}");
|
||||
}
|
||||
|
||||
if (task.getMetrics() != null) {
|
||||
model.setMetrics(objectMapper.writeValueAsString(task.getMetrics()));
|
||||
}
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new BizException("JSON 序列化失败");
|
||||
}
|
||||
|
||||
return algorithmModelService.save(model);
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user