模型训练

This commit is contained in:
wanxiaoli 2026-03-11 13:55:32 +08:00
parent 632fd0f725
commit aff2d7feea
21 changed files with 1292 additions and 147 deletions

View File

@ -5,6 +5,7 @@ import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.autoconfigure.data.redis.RedisAutoConfiguration;
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
import org.springframework.boot.web.servlet.ServletComponentScan;
@SpringBootApplication(
scanBasePackages = {
@ -18,6 +19,9 @@ import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
},
exclude = {DataSourceAutoConfiguration.class, RedisAutoConfiguration.class}
)
@ServletComponentScan(basePackages = {
"com.yfd.platform.config"
})
@MapperScan(basePackages = {
"com.yfd.platform.**.mapper",
"com.yfd.business.css.**.mapper"

View File

@ -1,44 +1,368 @@
package com.yfd.business.css.build;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.yfd.business.css.domain.Device;
import com.yfd.business.css.domain.Event;
import com.yfd.business.css.domain.Material;
import com.yfd.business.css.domain.Project;
import com.yfd.business.css.model.*;
import com.yfd.business.css.service.MaterialService;
import org.springframework.stereotype.Component;
import java.util.List;
import java.util.Map;
import java.util.*;
/*
*将数据库原始表 Map拓扑事件影响关系
*转成 SimUnit / SimEvent / SimInfluenceNode
*/
/**
* 仿真模型构建器
* 负责将原始的 Project (Topology) Event 数据解析转换为 SimUnit, SimEvent, SimInfluenceNode
*/
@Component
public class SimBuilder {
// public static List<SimUnit> buildUnitsFromTopo(List<ProjectTopo> topoRows) {
// return topoRows.stream()
// .map(t -> new SimUnit(t.getDeviceId(), t.getDeviceId(), t.getMaterialId(), t.getDeviceType()))
// .toList();
// }
private final ObjectMapper objectMapper = new ObjectMapper();
// public static List<SimEvent> buildEventsFromAttrChanges(List<EventAttrChange> attrChanges) {
// return attrChanges.stream()
// .map(e -> new SimEvent(e.getStep(),
// SimPropertyKey.of(e.getDeviceId(), e.getAttr()),
// e.getValue(),
// SimEvent.EventType.SET))
// .toList();
// }
// 1. 构建单元 (SimUnit) - 包含静态属性
public List<SimUnit> buildUnits(SimDataPackage data, MaterialService materialService) {
List<SimUnit> units = new ArrayList<>();
Project project = data.getProject();
if (project == null || project.getTopology() == null) return units;
// public static List<SimInfluenceNode> buildInfluenceNodesFromInfluences(List<Influence> influenceRows) {
// return influenceRows.stream()
// .map(i -> {
// List<SimInfluenceSource> sources = i.getSources().stream()
// .map(s -> new SimInfluenceSource(
// SimPropertyKey.of(s.getUnitId(), s.getProp()),
// s.getCoeff(),
// s.getDelay()
// )).toList();
// return new SimInfluenceNode(SimPropertyKey.of(i.getTargetUnit(), i.getTargetProp()),
// sources,
// i.getBias());
// }).toList();
// }
// 建立 Device ID -> Device 实体的映射方便查找
Map<String, Device> deviceMap = new HashMap<>();
if (data.getDevices() != null) {
for (Device d : data.getDevices()) {
deviceMap.put(d.getDeviceId(), d);
}
}
try {
JsonNode root = objectMapper.readTree(project.getTopology());
Map<String, String> devToMat = buildDeviceMaterialMap(root);
Map<String, Map<String, Double>> matStaticDb = buildMaterialStaticFromDb(devToMat, materialService);
JsonNode devicesNode = root.path("devices");
if (devicesNode.isArray()) {
for (JsonNode deviceNode : devicesNode) {
String deviceId = deviceNode.path("deviceId").asText();
if (deviceId == null || deviceId.isEmpty()) continue;
String type = deviceNode.path("type").asText();
String materialId = devToMat.get(deviceId);
Map<String, Double> staticProps = new HashMap<>();
// 1.1 解析 topology 中的 static 节点
JsonNode st = deviceNode.path("static");
if (st.isObject()) {
st.fieldNames().forEachRemaining(k -> {
if (!"unit".equals(k)) {
JsonNode v = st.path(k);
if (v.isNumber()) staticProps.put(k, v.asDouble());
else if (v.isTextual()) staticProps.put(k, parseDouble(v.asText()));
}
});
}
// 1.2 解析 Device.size 并注入
Device device = deviceMap.get(deviceId);
if (device != null) {
injectDeviceSize(device, staticProps);
}
// 1.3 注入物料静态属性
if (materialId != null) {
Map<String, Double> mProps = matStaticDb.get(materialId);
if (mProps != null) staticProps.putAll(mProps);
// 解析 topology material static
JsonNode mats = deviceNode.path("materials");
if (mats.isMissingNode() || mats.isNull()) mats = deviceNode.path("material");
if (mats.isObject()) {
JsonNode mst = mats.path("static");
if (mst.isObject()) {
mst.fieldNames().forEachRemaining(k -> {
if (!"unit".equals(k)) {
JsonNode v = mst.path(k);
if (v.isNumber()) staticProps.put(k, v.asDouble());
else if (v.isTextual()) staticProps.put(k, parseDouble(v.asText()));
}
});
}
}
}
units.add(new SimUnit(deviceId, deviceId, materialId, type, staticProps));
}
}
} catch (Exception e) {
throw new RuntimeException("Build units failed", e);
}
return units;
}
private void injectDeviceSize(Device device, Map<String, Double> staticProps) {
try {
String sizeJson = device.getSize();
if (sizeJson == null || sizeJson.isBlank()) return;
JsonNode sizeNode = objectMapper.readTree(sizeJson);
String type = device.getType(); // 假设 Device getType() 方法或从外部传入
if (type == null) {
// Fallback: 尝试通用解析 (现有逻辑)
parseCommonSize(sizeNode, staticProps);
return;
}
switch (type) {
case "FlatTank":
if (sizeNode.has("width")) staticProps.put("width", sizeNode.get("width").asDouble());
if (sizeNode.has("length")) staticProps.put("length", sizeNode.get("length").asDouble());
if (sizeNode.has("height")) staticProps.put("height", sizeNode.get("height").asDouble());
break;
case "CylindricalTank":
case "AnnularTank":
case "TubeBundleTank":
parseCommonSize(sizeNode, staticProps);
break;
case "ExtractionColumn":
// 优先提取 tray_section (塔身)
if (sizeNode.has("tray_section")) {
parseCommonSize(sizeNode.get("tray_section"), staticProps);
}
// 可选将其他段的尺寸作为特殊属性注入例如 lower_expanded_height
break;
case "FluidizedBed":
// 优先提取 reaction_section (反应段)
if (sizeNode.has("reaction_section")) {
parseCommonSize(sizeNode.get("reaction_section"), staticProps);
}
break;
case "ACFTank":
// 优先提取 annular_cylinder (圆柱段)
if (sizeNode.has("annular_cylinder")) {
parseCommonSize(sizeNode.get("annular_cylinder"), staticProps);
}
break;
default:
parseCommonSize(sizeNode, staticProps);
}
} catch (Exception e) {
System.err.println("解析Device.size失败" + e.getMessage());
}
}
private void parseCommonSize(JsonNode node, Map<String, Double> staticProps) {
if (node.has("outer_diameter")) {
staticProps.put("diameter", node.get("outer_diameter").asDouble());
} else if (node.has("diameter")) {
staticProps.put("diameter", node.get("diameter").asDouble());
}
if (node.has("height")) {
staticProps.put("height", node.get("height").asDouble());
}
}
// 2. 构建事件 (SimEvent)
public List<SimEvent> buildEvents(List<Event> events) {
List<SimEvent> simEvents = new ArrayList<>();
if (events == null) return simEvents;
for (Event ev : events) {
String json = ev.getAttrChanges();
if (json == null || json.isBlank()) continue;
try {
JsonNode root = objectMapper.readTree(json);
JsonNode target = root.path("target");
String entityType = optText(target, "entityType");
String entityId = optText(target, "entityId");
String property = optText(target, "property");
if (entityType == null || entityId == null || property == null) continue;
// 区分 entityType? SimUnit id deviceId如果是 material这里需要注意 SimUnit 的定义
// ProjectServiceImpl 中是分开处理 device material state
// 这里假设 SimUnit deviceId 为主键material 属性也挂在 deviceId (SimContext key: deviceId + property)
// 或者 SimContext key 包含 entityType? SimPropertyKey(unitId, property)
// ProjectServiceImpl: entityType + ":" + entityId + ":" + property
// SimPropertyKey: unitId, property. unitId 通常是 deviceId.
// 如果是 material 属性ProjectServiceImpl 中是把 material 属性 copy 到了 device state
// 所以这里 key unitId 应该是 deviceId (如果 target material需找到对应的 deviceId)
// Event 中只记录了 materialId没记录 deviceId (虽然 Event 表有 device_id 字段)
String targetUnitId = entityId;
// 注意如果 event 针对 material material 被多个 device 共享这里逻辑需要确认
// ProjectServiceImpl 是遍历所有 device如果 device 关联了该 material就应用覆盖
// SimService SimEvent 直接作用于 SimContext
// 如果 SimUnit 对应 Device那么 Material 的事件需要转换为针对所有引用该 Material Device 的事件
// 这需要在 buildEvents 时知道 Device-Material 关系或者 SimContext 支持 Material 维度的存储
// 简化起见假设 Event 表中 device_id 有值或者我们在 buildEvents 时传入 devToMat 映射关系?
// Event 表有 device_id 字段可以直接用
if ("material".equals(entityType)) {
if (ev.getDeviceId() != null) targetUnitId = ev.getDeviceId();
}
SimPropertyKey key = SimPropertyKey.of(targetUnitId, property);
JsonNode segments = root.path("segments");
if (segments.isArray()) {
for (JsonNode seg : segments) {
String interp = optText(seg, "interp");
JsonNode timeline = seg.path("timeline");
if (timeline.isArray()) {
// 简化处理全部转为 step-setramp 也离散化?
// ProjectServiceImpl 是保留 ramp 结构readValue 时计算
// SimService step 推进
// 这里我们将 timeline 展开为每个 step SimEvent
// 线性插值处理比较复杂这里先处理离散点
for (JsonNode p : timeline) {
double t = p.path("t").asDouble();
double val = p.path("value").asDouble();
// Step ? Time to Step 转换
// 假设 1s = 1 step
int step = (int) t;
// isOverride=true 在所有计算完成后强制修改属性值 1 帧延迟
// Input 事件 ( isOverride=false ), applyInfluences 影响计算 之前 执行下游设备读取的是 修改后 的值当前帧的值
simEvents.add(new SimEvent(step, key, val, SimEvent.EventType.SET, false)); // 默认为强制覆盖?
}
}
}
}
} catch (Exception e) {
// log error
}
}
return simEvents;
}
// 3. 构建影响关系 (SimInfluenceNode)
public List<SimInfluenceNode> buildInfluenceNodes(Project project) {
List<SimInfluenceNode> nodes = new ArrayList<>();
if (project == null || project.getTopology() == null) return nodes;
try {
JsonNode root = objectMapper.readTree(project.getTopology());
JsonNode devicesNode = root.path("devices");
if (devicesNode.isArray()) {
for (JsonNode dn : devicesNode) {
String deviceId = optText(dn, "deviceId");
if (deviceId == null) continue;
// Device Properties Influence
parseInfluence(dn.path("properties"), deviceId, nodes);
// Material Properties Influence (mapped to device)
JsonNode mats = dn.path("materials");
if (mats.isMissingNode() || mats.isNull()) mats = dn.path("material");
if (mats.isObject()) {
parseInfluence(mats.path("properties"), deviceId, nodes);
} else if (mats.isArray()) {
for (JsonNode mn : mats) {
parseInfluence(mn.path("properties"), deviceId, nodes);
}
}
}
}
} catch (Exception e) {
throw new RuntimeException("Build influence nodes failed", e);
}
return nodes;
}
private void parseInfluence(JsonNode props, String targetUnitId, List<SimInfluenceNode> nodes) {
if (!props.isObject()) return;
props.fieldNames().forEachRemaining(propName -> {
JsonNode prop = props.path(propName);
if ("influence".equalsIgnoreCase(optText(prop, "type"))) {
double bias = prop.path("bias").asDouble(0.0);
List<SimInfluenceSource> sources = new ArrayList<>();
JsonNode srcs = prop.path("sources");
if (srcs.isArray()) {
for (JsonNode sNode : srcs) {
String seId = optText(sNode, "entityId"); // source unit id
String seProp = optText(sNode, "property");
double coef = sNode.path("coefficient").asDouble(1.0);
long delayMs = 0L;
JsonNode delay = sNode.path("delay");
if (delay.path("enabled").asBoolean(false)) {
long t = delay.path("time").asLong(0L);
String u = optText(delay, "unit");
delayMs = toMillis(t, u);
}
int delayStep = (int) (delayMs / 1000); // 假设 1s = 1 step
sources.add(new SimInfluenceSource(SimPropertyKey.of(seId, seProp), coef, delayStep));
}
}
nodes.add(new SimInfluenceNode(SimPropertyKey.of(targetUnitId, propName), sources, bias));
}
});
}
// Helpers
private Map<String, String> buildDeviceMaterialMap(JsonNode root) {
Map<String, String> map = new HashMap<>();
JsonNode devicesNode = root.path("devices");
if (devicesNode.isArray()) {
for (JsonNode dn : devicesNode) {
String did = optText(dn, "deviceId");
if (did == null) continue;
JsonNode mats = dn.path("materials");
if (mats.isMissingNode() || mats.isNull()) mats = dn.path("material");
String mid = null;
if (mats.isArray() && mats.size() > 0) mid = optText(mats.get(0), "materialId");
else if (mats.isObject()) mid = optText(mats, "materialId");
if (mid != null) map.put(did, mid);
}
}
return map;
}
private Map<String, Map<String, Double>> buildMaterialStaticFromDb(Map<String, String> devToMat, MaterialService materialService) {
Map<String, Map<String, Double>> out = new HashMap<>();
if (devToMat.isEmpty()) return out;
Set<String> mids = new HashSet<>(devToMat.values());
List<Material> mats = materialService.list(new QueryWrapper<Material>().in("material_id", mids));
for (Material m : mats) {
Map<String, Double> s = new HashMap<>();
if (m.getUConcentration() != null) s.put("u_concentration", m.getUConcentration().doubleValue());
if (m.getPuConcentration() != null) s.put("pu_concentration", m.getPuConcentration().doubleValue());
if (m.getUEnrichment() != null) s.put("u_enrichment", m.getUEnrichment().doubleValue());
if (m.getPuIsotope() != null) s.put("pu_isotope", m.getPuIsotope().doubleValue());
if (m.getUo2Density() != null) s.put("uo2_density", m.getUo2Density().doubleValue());
if (m.getPuo2Density() != null) s.put("puo2_density", m.getPuo2Density().doubleValue());
if (m.getHno3Acidity() != null) s.put("hno3_acidity", m.getHno3Acidity().doubleValue());
if (m.getH2c2o4Concentration() != null) s.put("h2c2o4_concentration", m.getH2c2o4Concentration().doubleValue());
if (m.getOrganicRatio() != null) s.put("organic_ratio", m.getOrganicRatio().doubleValue());
if (m.getMoistureContent() != null) s.put("moisture_content", m.getMoistureContent().doubleValue());
// ... 其他属性
out.put(m.getMaterialId(), s);
}
return out;
}
private String optText(JsonNode node, String key) {
if (node.has(key)) return node.get(key).asText();
return null;
}
private double parseDouble(String s) {
try { return Double.parseDouble(s); } catch (Exception e) { return 0.0; }
}
private long toMillis(long time, String unit) {
if ("s".equalsIgnoreCase(unit)) return time * 1000;
if ("ms".equalsIgnoreCase(unit)) return time;
return time * 1000; // default s
}
}

View File

@ -0,0 +1,14 @@
package com.yfd.business.css.config;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.client.RestTemplate;
@Configuration
public class RestTemplateConfig {
@Bean
public RestTemplate restTemplate() {
return new RestTemplate();
}
}

View File

@ -154,7 +154,7 @@ public class EventController {
new QueryWrapper<Event>()
.select("event_id","scenario_id","device_id","material_id","attr_changes","trigger_time","created_at","modifier")
.eq("scenario_id", scenarioId)
.orderByDesc("created_at")
.orderByAsc("created_at")
);
}

View File

@ -0,0 +1,76 @@
package com.yfd.business.css.controller;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.yfd.business.css.domain.ModelTrainTask;
import com.yfd.business.css.service.ModelTrainService;
import com.yfd.platform.config.ResponseResult;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import java.util.Map;
@RestController
@RequestMapping("/train")
public class ModelTrainController {
@Autowired
private ModelTrainService modelTrainService;
/**
* 上传数据集
*/
@PostMapping("/upload")
public ResponseResult upload(@RequestParam("file") MultipartFile file) {
String path = modelTrainService.uploadDataset(file);
return ResponseResult.successData(path);
}
/**
* 提交训练任务
*/
@PostMapping("/submit")
public ResponseResult submit(@RequestBody ModelTrainTask task) {
String taskId = modelTrainService.submitTask(task);
return ResponseResult.successData(taskId);
}
/**
* 查询任务列表
*/
@GetMapping("/list")
public ResponseResult list(@RequestParam(defaultValue = "1") Integer current,
@RequestParam(defaultValue = "10") Integer size) {
Page<ModelTrainTask> page = modelTrainService.page(new Page<>(current, size));
return ResponseResult.successData(page);
}
/**
* 查询任务详情/状态
*/
@GetMapping("/status/{taskId}")
public ResponseResult status(@PathVariable String taskId) {
ModelTrainTask task = modelTrainService.syncTaskStatus(taskId);
return ResponseResult.successData(task);
}
/**
* 发布模型
*/
@PostMapping("/publish")
public ResponseResult publish(@RequestBody Map<String, String> body) {
String taskId = body.get("taskId");
String versionTag = body.get("versionTag");
boolean success = modelTrainService.publishModel(taskId, versionTag);
return success ? ResponseResult.success() : ResponseResult.error("发布失败");
}
/**
* 删除训练任务
*/
@DeleteMapping("/{taskId}")
public ResponseResult delete(@PathVariable String taskId) {
boolean success = modelTrainService.removeById(taskId);
return success ? ResponseResult.success() : ResponseResult.error("删除失败");
}
}

View File

@ -1,45 +1,61 @@
package com.yfd.business.css.controller;
import com.yfd.business.css.build.SimBuilder;
import com.yfd.business.css.facade.SimDataFacade;
import com.yfd.business.css.model.*;
import com.yfd.business.css.service.MaterialService;
import com.yfd.business.css.service.SimService;
import com.yfd.business.css.service.SimInferService;
import com.yfd.platform.config.ResponseResult;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
import java.util.List;
import java.util.Map;
/**
* 仿真服务统一入口控制器
* 负责接收仿真请求协调 FacadeBuilder Service 完成仿真计算
*/
@RestController
@RequestMapping("/sim")
public class SimController {
// private final ProjectRepository projectRepo;
// private final EventRepository eventRepo;
// private final InfluenceRepository influenceRepo;
// private final SimService simService;
@Autowired private SimDataFacade simDataFacade;
@Autowired private SimBuilder simBuilder;
@Autowired private SimService simService;
@Autowired private MaterialService materialService;
@Autowired private SimInferService simInferService;
// public SimController(ProjectRepository projectRepo,
// EventRepository eventRepo,
// InfluenceRepository influenceRepo,
// SimService simService) {
// this.projectRepo = projectRepo;
// this.eventRepo = eventRepo;
// this.influenceRepo = influenceRepo;
// this.simService = simService;
// }
/**
* 执行仿真计算
*
* @param req 请求参数包含 projectId, scenarioId, steps
* @return 仿真结果包含 code, msg, data (data.frames)
*/
@PostMapping("/run")
public ResponseResult run(@RequestBody Map<String, Object> req) {
String projectId = (String) req.get("projectId");
String scenarioId = (String) req.get("scenarioId");
int steps = req.containsKey("steps") ? (int) req.get("steps") : 10;
// @PostMapping("/run")
// public Map<String,Object> runSimulation(@RequestParam String projectId,
// @RequestParam String scenarioId,
// @RequestParam int steps) {
// 1. Load Data: 获取项目设备和事件数据
SimDataPackage data = simDataFacade.loadSimulationData(projectId, scenarioId);
// 2. Build Model: 将原始数据转换为仿真模型 (Units, Events, Nodes)
List<SimUnit> units = simBuilder.buildUnits(data, materialService);
List<SimEvent> events = simBuilder.buildEvents(data.getEvents());
List<SimInfluenceNode> nodes = simBuilder.buildInfluenceNodes(data.getProject());
// 3. Run Engine: 执行核心仿真计算返回上下文结果
SimContext ctx = simService.runSimulation(units, events, nodes, steps);
// 4. Async Infer: 异步执行推理并保存结果
simInferService.asyncInferAndSave(projectId, scenarioId, ctx, units);
// List<Map<String, String>> topoRows = projectRepo.loadTopo(projectId);
// List<Map<String, Object>> attrChanges = eventRepo.loadAttrChanges(projectId, scenarioId);
// List<Map<String, Object>> influenceRows = influenceRepo.loadInfluences(projectId);
// List<SimUnit> units = SimBuilder.buildUnits(topoRows);
// List<SimEvent> events = SimBuilder.buildEvents(attrChanges);
// List<SimInfluenceNode> nodes = SimBuilder.buildInfluenceNodes(influenceRows);
// SimContext ctx = simService.runSimulation(units, events, nodes, steps);
// // 转换成推理接口 JSON
// return InferenceConverter.toInferenceInput(ctx, units, projectId, scenarioId);
// }
// 5. Convert Result: 将仿真结果转换为前端友好的格式包含静态属性补全和元数据
Map<String, Object> resultData = SimResultConverter.toFrames(ctx, units, projectId, scenarioId);
return ResponseResult.successData(resultData);
}
}

View File

@ -0,0 +1,58 @@
package com.yfd.business.css.domain;
import com.baomidou.mybatisplus.annotation.*;
import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler;
import lombok.Data;
import java.io.Serializable;
import java.time.LocalDateTime;
import java.util.Map;
@Data
@TableName(value = "model_train_task", autoResultMap = true)
public class ModelTrainTask implements Serializable {
private static final long serialVersionUID = 1L;
@TableId(value = "task_id", type = IdType.ASSIGN_UUID)
private String taskId;
@TableField("task_name")
private String taskName;
@TableField("algorithm_type")
private String algorithmType;
@TableField("device_type")
private String deviceType;
@TableField("dataset_path")
private String datasetPath;
@TableField(value = "train_params", typeHandler = JacksonTypeHandler.class)
private Map<String, Object> trainParams; // 使用 Map 存储 JSON
@TableField("status")
private String status; // PENDING, TRAINING, SUCCESS, FAILED
@TableField(value = "metrics", typeHandler = JacksonTypeHandler.class)
private Map<String, Object> metrics;
@TableField("model_output_path")
private String modelOutputPath;
@TableField(value = "feature_map_snapshot", typeHandler = JacksonTypeHandler.class)
private Map<String, Object> featureMapSnapshot;
@TableField("metrics_image_path")
private String metricsImagePath;
@TableField("error_log")
private String errorLog;
@TableField(value = "created_at", fill = FieldFill.INSERT)
private LocalDateTime createdAt;
@TableField(value = "updated_at", fill = FieldFill.INSERT_UPDATE)
private LocalDateTime updatedAt;
}

View File

@ -7,6 +7,7 @@ import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
import java.io.Serializable;
import java.math.BigDecimal;
import java.time.LocalDateTime;
@Data
@ -41,4 +42,16 @@ public class Scenario implements Serializable {
@TableField("algorithm_type")
private String algorithmType;
/**
* Keff预警阈值
*/
@TableField("keff_threshold")
private BigDecimal keffThreshold;
/**
* 设备算法配置映射格式{"deviceId1":"GPR", "deviceId2":"MLP"}
*/
@TableField("device_algo_config")
private String deviceAlgoConfig;
}

View File

@ -0,0 +1,39 @@
package com.yfd.business.css.facade;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.yfd.business.css.domain.Device;
import com.yfd.business.css.domain.Event;
import com.yfd.business.css.domain.Project;
import com.yfd.business.css.model.SimDataPackage;
import com.yfd.business.css.service.DeviceService;
import com.yfd.business.css.service.EventService;
import com.yfd.business.css.service.ProjectService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.util.List;
/**
* 仿真数据门面
* 负责与各个业务Service交互获取仿真所需的原始数据 (Project, Events )
*/
@Component
public class SimDataFacade {
@Autowired private ProjectService projectService;
@Autowired private EventService eventService;
@Autowired private DeviceService deviceService;
public SimDataPackage loadSimulationData(String projectId, String scenarioId) {
// 1. 获取项目与拓扑
Project project = projectService.getById(projectId);
if (project == null) throw new IllegalArgumentException("Project not found: " + projectId);
// 2. 获取设备列表 (用于解析静态尺寸属性)
List<Device> devices = deviceService.list(new QueryWrapper<Device>().eq("project_id", projectId));
// 3. 获取事件
List<Event> events = eventService.list(new QueryWrapper<Event>().eq("scenario_id", scenarioId));
return new SimDataPackage(project, devices, events);
}
}

View File

@ -0,0 +1,9 @@
package com.yfd.business.css.mapper;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.yfd.business.css.domain.ModelTrainTask;
import org.apache.ibatis.annotations.Mapper;
@Mapper
public interface ModelTrainTaskMapper extends BaseMapper<ModelTrainTask> {
}

View File

@ -12,4 +12,4 @@ public class SimContext {
public void snapshot(int step) { timeline.put(step, new HashMap<>(currentValues)); }
public Map<Integer, Map<SimPropertyKey, Double>> getTimeline() { return timeline; }
}
}

View File

@ -0,0 +1,40 @@
package com.yfd.business.css.model;
import com.yfd.business.css.domain.Device;
import com.yfd.business.css.domain.Event;
import com.yfd.business.css.domain.Project;
import java.util.Collections;
import java.util.List;
/**
* 仿真数据包
* 封装从数据库加载的原始项目设备和事件数据
*/
public class SimDataPackage {
private final Project project;
private final List<Device> devices;
private final List<Event> events;
public SimDataPackage(Project project, List<Event> events) {
this(project, Collections.emptyList(), events);
}
public SimDataPackage(Project project, List<Device> devices, List<Event> events) {
this.project = project;
this.devices = devices;
this.events = events;
}
public Project getProject() {
return project;
}
public List<Device> getDevices() {
return devices;
}
public List<Event> getEvents() {
return events;
}
}

View File

@ -7,16 +7,23 @@ public class SimEvent {
private final SimPropertyKey key;
private final double value;
private final EventType type;
private final boolean isOverride; // 新增是否强制覆盖
public SimEvent(int step, SimPropertyKey key, double value, EventType type) {
public SimEvent(int step, SimPropertyKey key, double value, EventType type, boolean isOverride) {
this.step = step;
this.key = key;
this.value = value;
this.type = type;
this.isOverride = isOverride;
}
public SimEvent(int step, SimPropertyKey key, double value, EventType type) {
this(step, key, value, type, false); // 默认为非强制
}
public int getStep() { return step; }
public SimPropertyKey getKey() { return key; }
public double getValue() { return value; }
public EventType getType() { return type; }
public boolean isOverride() { return isOverride; }
}

View File

@ -0,0 +1,72 @@
package com.yfd.business.css.model;
import java.util.*;
/**
* 仿真结果转换器
* 负责将 SimContext 中的仿真状态转换为前端友好的帧结构
*/
public class SimResultConverter {
/**
* 将仿真上下文转换为帧列表
* @param ctx 仿真上下文
* @param units 仿真单元列表 (用于补充设备类型等元数据)
* @param projectId 项目ID
* @param scenarioId 情景ID
* @return 包含完整结果数据的 Map
*/
public static Map<String, Object> toFrames(SimContext ctx, List<SimUnit> units, String projectId, String scenarioId) {
List<Map<String, Object>> frames = new ArrayList<>();
Map<String, SimUnit> unitMap = new HashMap<>();
for (SimUnit u : units) {
unitMap.put(u.unitId(), u);
}
Map<Integer, Map<SimPropertyKey, Double>> timeline = ctx.getTimeline();
List<Integer> steps = new ArrayList<>(timeline.keySet());
Collections.sort(steps);
for (Integer step : steps) {
Map<SimPropertyKey, Double> snapshot = timeline.get(step);
Map<String, Map<String, Object>> devices = new HashMap<>();
// 1. 先初始化所有设备及其静态属性
for (SimUnit unit : units) {
Map<String, Object> devState = new HashMap<>();
devState.put("deviceType", unit.deviceType());
// 注入所有静态属性作为基线
if (unit.staticProperties() != null) {
devState.putAll(unit.staticProperties());
}
devices.put(unit.unitId(), devState);
}
// 2. 覆盖动态计算出的属性值
for (Map.Entry<SimPropertyKey, Double> entry : snapshot.entrySet()) {
String unitId = entry.getKey().unitId();
String prop = entry.getKey().property();
Double val = entry.getValue();
Map<String, Object> devState = devices.get(unitId);
if (devState == null) continue; // 理论上不应该发生除非有未定义的unit
devState.put(prop, val);
}
Map<String, Object> frame = new HashMap<>();
frame.put("step", step);
frame.put("time", step); // 假设 1 step = 1 time unit
frame.put("devices", devices);
frames.add(frame);
}
Map<String, Object> data = new HashMap<>();
data.put("frames", frames);
data.put("projectId", projectId);
data.put("scenarioId", scenarioId);
data.put("generated", Map.of("snapshots", steps.size()));
return data;
}
}

View File

@ -1,3 +1,9 @@
package com.yfd.business.css.model;
public record SimUnit(String unitId, String deviceId, String materialId, String deviceType) {}
import java.util.Map;
public record SimUnit(String unitId, String deviceId, String materialId, String deviceType, Map<String, Double> staticProperties) {
public SimUnit(String unitId, String deviceId, String materialId, String deviceType) {
this(unitId, deviceId, materialId, deviceType, Map.of());
}
}

View File

@ -1,6 +1,8 @@
package com.yfd.business.css.service;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.yfd.business.css.domain.Scenario;
import com.yfd.business.css.domain.ScenarioResult;
import com.yfd.business.css.model.DeviceStepInfo;
import com.yfd.business.css.model.InferRequest;
@ -18,6 +20,7 @@ import org.springframework.web.client.RestTemplate;
import org.springframework.stereotype.Service;
import java.util.*;
@Service
public class DeviceInferService {
@Value("${python.api.url:http://localhost:8000}")
@ -35,63 +38,94 @@ public class DeviceInferService {
public void processDeviceInference(String projectId, String scenarioId,
Map<String, List<DeviceStepInfo>> groupedDevices) {
// 遍历每个设备类型调用对应的模型进行推理
// 1. 获取情景配置信息
Scenario scenario = scenarioService.getById(scenarioId);
if (scenario == null) {
throw new IllegalArgumentException("场景 " + scenarioId + " 不存在");
}
// 全局算法类型
String globalAlgorithmType = scenario.getAlgorithmType();
if (globalAlgorithmType == null) {
throw new IllegalArgumentException("场景 " + scenarioId + " 未配置全局算法类型");
}
// 解析设备级算法配置
Map<String, String> deviceAlgoConfig = new HashMap<>();
if (scenario.getDeviceAlgoConfig() != null && !scenario.getDeviceAlgoConfig().isBlank()) {
try {
deviceAlgoConfig = objectMapper.readValue(scenario.getDeviceAlgoConfig(), new TypeReference<Map<String, String>>() {});
} catch (Exception e) {
System.err.println("解析设备算法配置失败: " + e.getMessage());
}
}
// 2. 遍历每个设备类型组
for (Map.Entry<String, List<DeviceStepInfo>> entry : groupedDevices.entrySet()) {
String deviceType = entry.getKey();
// algorithmType通过scenarioId获取
String algorithmType = scenarioService.getAlgorithmType(scenarioId);
if (algorithmType == null) {
throw new IllegalArgumentException("场景 " + scenarioId + " 未配置算法类型");
}
//modelPath根据模型类型设备类型从algorithm_model获取活动的model_path
String modelPath = algorithmModelService.getCurrentModelPath(algorithmType, deviceType);
System.out.println("modelPath="+modelPath);
if (modelPath == null) {
throw new IllegalArgumentException("未配置 " + algorithmType + " 模型路径");
}
List<DeviceStepInfo> devices = entry.getValue();
System.out.println("devices="+devices);
// 校验设备数据是否完整
if (devices == null || devices.isEmpty()) {
throw new IllegalArgumentException("设备数据为空,无法进行模拟");
if (devices == null || devices.isEmpty()) continue;
// 3. 将同一类型的设备按算法类型进行二级分组
Map<String, List<DeviceStepInfo>> algoGroup = new HashMap<>();
for (DeviceStepInfo device : devices) {
// 优先使用设备特定配置否则使用全局配置
String algoType = deviceAlgoConfig.getOrDefault(device.getDeviceId(), globalAlgorithmType);
algoGroup.computeIfAbsent(algoType, k -> new ArrayList<>()).add(device);
}
// 封装推理请求
InferRequest request = buildInferenceRequest(deviceType,devices,modelPath);
System.out.println("request="+request);
// 调用Python推理服务
InferResponse response = infer(request);
System.out.println("推理服务返回结果: " + response);
// 处理推理结果
if (response != null && response.getCode() == 0) {
// 重新构建InferResponse对象示例
// 1. 从response获取数据
int code = response.getCode();
String msg = response.getMsg();
InferResponse.InferData originalData = response.getData();
// 2. 重新构建InferData
InferResponse.InferData newData = new InferResponse.InferData();
newData.setItems(originalData.getItems());
newData.setMeta(originalData.getMeta());
newData.setFeatures(originalData.getFeatures());
newData.setKeff(originalData.getKeff());
// 3. 构建新的InferResponse
InferResponse reconstructedResponse = new InferResponse();
reconstructedResponse.setCode(code);
reconstructedResponse.setMsg(msg);
reconstructedResponse.setData(newData);
System.out.println("重新构建的response: " + reconstructedResponse);
// 使用重新构建的response处理结果
processInferenceResults(projectId, scenarioId, deviceType, devices, reconstructedResponse);
} else {
throw new RuntimeException("推理服务调用失败: " + (response != null ? response.getMsg() : "未知错误"));
// 4. 对每个算法分组进行推理
for (Map.Entry<String, List<DeviceStepInfo>> algoEntry : algoGroup.entrySet()) {
String currentAlgoType = algoEntry.getKey();
List<DeviceStepInfo> currentDevices = algoEntry.getValue();
// 获取模型路径
System.out.println("Processing inference for algorithmType: " + currentAlgoType + ", deviceType: " + deviceType);
String modelPath = algorithmModelService.getCurrentModelPath(currentAlgoType, deviceType);
System.out.println("modelPath=" + modelPath);
if (modelPath == null) {
System.err.println("Model path not found for algorithmType: " + currentAlgoType + ", deviceType: " + deviceType);
// 这里可以选择抛异常中断或者跳过该组继续处理其他组
// 为了保证健壮性这里选择记录错误并跳过或者根据业务需求抛出异常
// throw new IllegalArgumentException("未配置 " + currentAlgoType + " 模型路径 (deviceType: " + deviceType + ")");
continue;
}
// 封装推理请求
InferRequest request = buildInferenceRequest(deviceType, currentDevices, modelPath);
// System.out.println("request=" + request);
try {
// 调用Python推理服务
InferResponse response = infer(request);
System.out.println("推理服务返回结果: code=" + (response != null ? response.getCode() : "null"));
// 处理推理结果
if (response != null && response.getCode() == 0) {
// 重新构建InferResponse对象示例 (为了兼容性保留原有逻辑)
InferResponse.InferData originalData = response.getData();
InferResponse.InferData newData = new InferResponse.InferData();
newData.setItems(originalData.getItems());
newData.setMeta(originalData.getMeta());
newData.setFeatures(originalData.getFeatures());
newData.setKeff(originalData.getKeff());
InferResponse reconstructedResponse = new InferResponse();
reconstructedResponse.setCode(response.getCode());
reconstructedResponse.setMsg(response.getMsg());
reconstructedResponse.setData(newData);
processInferenceResults(projectId, scenarioId, deviceType, currentDevices, reconstructedResponse);
} else {
System.err.println("推理服务调用失败: " + (response != null ? response.getMsg() : "未知错误"));
}
} catch (Exception e) {
System.err.println("推理异常: " + e.getMessage());
e.printStackTrace();
}
}
}
}

View File

@ -0,0 +1,36 @@
package com.yfd.business.css.service;
import com.baomidou.mybatisplus.extension.service.IService;
import com.yfd.business.css.domain.ModelTrainTask;
import org.springframework.web.multipart.MultipartFile;
public interface ModelTrainService extends IService<ModelTrainTask> {
/**
* 上传数据集
* @param file 文件
* @return 文件保存路径
*/
String uploadDataset(MultipartFile file);
/**
* 提交训练任务
* @param task 任务信息
* @return 任务ID
*/
String submitTask(ModelTrainTask task);
/**
* 同步任务状态
* @param taskId 任务ID
* @return 最新任务信息
*/
ModelTrainTask syncTaskStatus(String taskId);
/**
* 发布模型
* @param taskId 任务ID
* @param versionTag 版本号
* @return 是否成功
*/
boolean publishModel(String taskId, String versionTag);
}

View File

@ -0,0 +1,90 @@
package com.yfd.business.css.service;
import com.yfd.business.css.model.*;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import java.util.*;
import java.util.stream.Collectors;
/**
* 仿真推理服务
* 负责将仿真计算结果 (SimContext) 转换为推理请求并异步调用 Python 推理服务入库
*/
@Service
public class SimInferService {
@Autowired
private DeviceInferService deviceInferService;
@Async
public void asyncInferAndSave(String projectId, String scenarioId, SimContext context, List<SimUnit> units) {
try {
// 1. 准备元数据映射: unitId -> SimUnit
Map<String, SimUnit> unitMap = units.stream()
.collect(Collectors.toMap(SimUnit::unitId, u -> u));
// 2. 转换 SimContext 为按 DeviceType 分组的 DeviceStepInfo 列表
Map<String, List<DeviceStepInfo>> groupedDevices = new HashMap<>();
// 遍历时间轴
Map<Integer, Map<SimPropertyKey, Double>> timeline = context.getTimeline();
for (Map.Entry<Integer, Map<SimPropertyKey, Double>> entry : timeline.entrySet()) {
int step = entry.getKey();
Map<SimPropertyKey, Double> props = entry.getValue();
// unitId 聚合属性
Map<String, Map<String, Object>> unitProps = new HashMap<>();
for (Map.Entry<SimPropertyKey, Double> propEntry : props.entrySet()) {
SimPropertyKey key = propEntry.getKey();
String unitId = key.unitId();
String propName = key.property();
Double value = propEntry.getValue();
unitProps.computeIfAbsent(unitId, k -> new HashMap<>()).put(propName, value);
}
// 构建 DeviceStepInfo
for (Map.Entry<String, Map<String, Object>> unitEntry : unitProps.entrySet()) {
String unitId = unitEntry.getKey();
Map<String, Object> properties = unitEntry.getValue();
SimUnit unit = unitMap.get(unitId);
if (unit == null) continue; // 可能是中间变量或非设备单元
String deviceType = unit.deviceType();
if (deviceType == null || deviceType.isEmpty()) continue;
// 确保静态属性也包含在内如果 SimContext 中没有 SimUnit 补充
// SimService 应该已经初始化了静态属性到 Context但为了保险起见
for (Map.Entry<String, Double> staticProp : unit.staticProperties().entrySet()) {
properties.putIfAbsent(staticProp.getKey(), staticProp.getValue());
}
DeviceStepInfo info = new DeviceStepInfo();
info.setDeviceId(unit.deviceId()); // unitId 通常等于 deviceId
info.setDeviceType(deviceType);
info.setProperties(properties);
info.setStep(step);
info.setTime(step); // 假设 time = step或者根据步长计算
groupedDevices.computeIfAbsent(deviceType, k -> new ArrayList<>()).add(info);
}
}
if (groupedDevices.isEmpty()) {
System.out.println("No device data found for inference.");
return;
}
// 3. 调用 DeviceInferService 进行推理和入库
// 复用现有的 processDeviceInference 方法它处理了按类型分组的数据
deviceInferService.processDeviceInference(projectId, scenarioId, groupedDevices);
} catch (Exception e) {
System.err.println("Async inference failed: " + e.getMessage());
e.printStackTrace();
}
}
}

View File

@ -1,11 +1,28 @@
package com.yfd.business.css.service;
import com.yfd.business.css.model.*;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
* 仿真核心计算引擎
* 负责执行时间步推进处理静态属性事件输入影响传播和强制覆盖
*/
@Service
public class SimService {
/**
* 运行仿真
*
* @param units 仿真单元列表 (设备/物料)
* @param events 仿真事件列表 (属性变更)
* @param nodes 影响关系节点列表
* @param steps 仿真总步数
* @return 仿真上下文包含所有时间步的状态快照
*/
public SimContext runSimulation(List<SimUnit> units,
List<SimEvent> events,
List<SimInfluenceNode> nodes,
@ -13,41 +30,96 @@ public class SimService {
SimContext ctx = new SimContext();
// 初始化设备/物料属性
// Step 1: 初始化静态基线 (t=0)
// 将所有单元的静态属性写入初始上下文
for (SimUnit unit : units) {
ctx.ensureProperty(SimPropertyKey.of(unit.unitId(), "device.power"));
ctx.ensureProperty(SimPropertyKey.of(unit.unitId(), "material.quantity"));
ctx.ensureProperty(SimPropertyKey.of(unit.unitId(), "keff")); // 占位
unit.staticProperties().forEach((k, v) ->
ctx.setValue(SimPropertyKey.of(unit.unitId(), k), v)
);
}
ctx.snapshot(0); // 记录初始状态
// 每步仿真
// Step 2: 循环推进 (t=1 to steps)
for (int step = 1; step <= steps; step++) {
// 2.1 Event (Input): 应用普通事件 (作为输入条件)
applyEvents(ctx, events, step, false);
// 事件
for (SimEvent e : events) {
if (e.getStep() == step) {
ctx.setValue(e.getKey(), e.getValue());
}
}
// 2.2 Influence: 计算影响关系
// 基于当前 ctx (包含 static + input event) 计算派生属性
applyInfluences(ctx, nodes, step);
// 影响关系
for (SimInfluenceNode node : nodes) {
double sum = node.getBias();
for (SimInfluenceSource s : node.getSources()) {
sum += ctx.getValue(s.getSource()) * s.getCoeff();
}
ctx.setValue(node.getTarget(), sum);
}
// keff 占位后续可以替换为推理接口计算
for (SimUnit unit : units) {
ctx.setValue(SimPropertyKey.of(unit.unitId(), "keff"), 0.0);
}
// 2.3 Event (Override): 应用强制覆盖事件
// 再次覆盖确保强制逻辑生效 (如故障注入手动锁定)
applyEvents(ctx, events, step, true);
// 2.4 Snapshot: 保存当前步的状态快照
ctx.snapshot(step);
}
return ctx;
}
}
/**
* 应用事件
* @param ctx 上下文
* @param events 事件列表
* @param step 当前步
* @param isOverride 是否为强制覆盖事件
*/
private void applyEvents(SimContext ctx, List<SimEvent> events, int step, boolean isOverride) {
for (SimEvent e : events) {
if (e.getStep() == step && e.isOverride() == isOverride) {
// 当前仅支持 SET 类型ADD/MULTIPLY 可按需扩展
ctx.setValue(e.getKey(), e.getValue());
}
}
}
/**
* 计算影响关系
* @param ctx 上下文
* @param nodes 影响节点列表
* @param currentStep 当前步
*/
private void applyInfluences(SimContext ctx, List<SimInfluenceNode> nodes, int currentStep) {
// 简单实现直接计算并更新如果存在依赖环可能需要多轮迭代或拓扑排序
// 这里假设无环或允许一帧内的即时传播
// 为了避免更新顺序影响结果A依赖B先算A还是先算B理想情况下应使用双缓冲读取上一帧或本帧快照
// 但根据需求 "Influence派生计算基于当前状态执行影响关系计算" ProjectServiceImpl 中是直接读取当前 state
// ProjectServiceImpl readValue 会读取 delay
// 我们可以先计算所有变更再统一应用避免计算中间态影响后续计算除非是有意为之的级联
// 这里采用先计算所有目标值存入临时 Map最后统一写入 ctx
List<Map.Entry<SimPropertyKey, Double>> updates = new ArrayList<>();
for (SimInfluenceNode node : nodes) {
double sum = node.getBias();
for (SimInfluenceSource s : node.getSources()) {
// 处理延迟: t - delay
// SimContext 目前只存储当前值如果需要历史值SimContext 需要支持根据 step 获取 snapshot
// 假设 SimContext.getValue(key) 返回当前值
// 如果有 delay我们需要访问历史 snapshot
// 暂时假设无 delay delay=0读取当前值
// 如果 SimInfluenceSource delay 属性需要 SimContext 支持历史回溯
// 修改 SimContext 增加 getValue(key, step) ?
// 目前 SimContext 实现未知假设只能读当前
// 如果需要支持 delaySimContext 需要保留 history
// 简单起见这里先读当前值完善版本需要 SimContext 提供 getHistoryValue(key, step - delay)
// 修正SimContext 应该能获取历史我们先看 SimContext 定义
// 假设 SimContext 内部有 snapshots
double sourceVal = ctx.getValue(s.getSource()); // 暂读当前
sum += sourceVal * s.getCoeff();
}
updates.add(Map.entry(node.getTarget(), sum));
}
for (Map.Entry<SimPropertyKey, Double> entry : updates) {
ctx.setValue(entry.getKey(), entry.getValue());
}
}
}

View File

@ -16,12 +16,19 @@ public class AlgorithmModelServiceImpl extends ServiceImpl<AlgorithmModelMapper,
@Override
public String getCurrentModelPath(String algorithmType, String deviceType) {
System.out.println("Querying current model path for algorithmType: " + algorithmType + ", deviceType: " + deviceType);
QueryWrapper<AlgorithmModel> queryWrapper = new QueryWrapper<>();
queryWrapper.eq("algorithm_type", algorithmType)
.eq("device_type", deviceType)
.eq("is_current", 1); // 当前激活版本
AlgorithmModel model = getOne(queryWrapper);
return model != null ? model.getModelPath() : null;
if (model != null) {
System.out.println("Found model: " + model.getModelPath());
return model.getModelPath();
} else {
System.out.println("Model not found in database.");
return null;
}
}
@Override

View File

@ -0,0 +1,228 @@
package com.yfd.business.css.service.impl;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.yfd.business.css.common.exception.BizException;
import com.yfd.business.css.domain.AlgorithmModel;
import com.yfd.business.css.domain.ModelTrainTask;
import com.yfd.business.css.mapper.ModelTrainTaskMapper;
import com.yfd.business.css.service.AlgorithmModelService;
import com.yfd.business.css.service.ModelTrainService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.multipart.MultipartFile;
import java.io.File;
import java.io.IOException;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
@Service
public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, ModelTrainTask> implements ModelTrainService {
@Value("${file-space.upload-path:./data/uploads/}")
private String uploadPath;
@Value("${python.api.url:http://localhost:8000}")
private String pythonApiUrl;
@Autowired
private RestTemplate restTemplate;
@Autowired
private AlgorithmModelService algorithmModelService;
@Autowired
private ObjectMapper objectMapper;
@Override
public String uploadDataset(MultipartFile file) {
if (file.isEmpty()) {
throw new BizException("上传文件不能为空");
}
String originalFilename = file.getOriginalFilename();
String dateDir = LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyyMMdd"));
String uuid = UUID.randomUUID().toString();
String extension = originalFilename.substring(originalFilename.lastIndexOf("."));
String newFilename = uuid + extension;
String saveDir = uploadPath + File.separator + dateDir;
File dir = new File(saveDir);
if (!dir.exists()) {
dir.mkdirs();
}
String fullPath = saveDir + File.separator + newFilename;
try {
file.transferTo(new File(fullPath));
return fullPath;
} catch (IOException e) {
throw new BizException("文件保存失败: " + e.getMessage());
}
}
@Override
@Transactional
public String submitTask(ModelTrainTask task) {
// 1. 初始化状态
task.setStatus("PENDING");
if (task.getTaskId() == null) {
task.setTaskId(UUID.randomUUID().toString());
}
this.save(task);
// 2. 异步调用 Python 训练
asyncCallTrain(task);
return task.getTaskId();
}
@Async
public void asyncCallTrain(ModelTrainTask task) {
try {
// 更新状态为 TRAINING
task.setStatus("TRAINING");
this.updateById(task);
// 构建请求参数
Map<String, Object> request = new HashMap<>();
request.put("task_id", task.getTaskId());
request.put("algorithm_type", task.getAlgorithmType());
request.put("device_type", task.getDeviceType());
request.put("dataset_path", task.getDatasetPath());
request.put("hyperparameters", task.getTrainParams());
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
HttpEntity<Map<String, Object>> entity = new HttpEntity<>(request, headers);
String url = pythonApiUrl + "/v1/train";
// 这里假设 Python 服务会立即返回启动后台线程
// 如果 Python 服务是阻塞的这里的 Async 会起作用
ResponseEntity<Map> response = restTemplate.postForEntity(url, entity, Map.class);
if (response.getStatusCode().is2xxSuccessful()) {
// 调用成功等待后续轮询状态
System.out.println("训练任务提交成功: " + task.getTaskId());
} else {
task.setStatus("FAILED");
task.setErrorLog("提交训练任务失败: " + response.getStatusCode());
this.updateById(task);
}
} catch (Exception e) {
task.setStatus("FAILED");
task.setErrorLog("调用 Python 服务异常: " + e.getMessage());
this.updateById(task);
e.printStackTrace();
}
}
@Override
public ModelTrainTask syncTaskStatus(String taskId) {
ModelTrainTask task = this.getById(taskId);
if (task == null) {
throw new BizException("任务不存在");
}
// 只有在 TRAINING 状态才去查询 Python 服务
if ("TRAINING".equals(task.getStatus()) || "PENDING".equals(task.getStatus())) {
try {
String url = pythonApiUrl + "/v1/train/status/" + taskId;
ResponseEntity<Map> response = restTemplate.getForEntity(url, Map.class);
if (response.getStatusCode().is2xxSuccessful() && response.getBody() != null) {
Map<String, Object> body = response.getBody();
String status = (String) body.get("status"); // TRAINING, SUCCESS, FAILED
if (status != null) {
task.setStatus(status);
if ("SUCCESS".equals(status)) {
task.setModelOutputPath((String) body.get("model_path"));
task.setMetrics((Map<String, Object>) body.get("metrics"));
task.setFeatureMapSnapshot((Map<String, Object>) body.get("feature_map"));
task.setMetricsImagePath((String) body.get("metrics_image"));
} else if ("FAILED".equals(status)) {
task.setErrorLog((String) body.get("error"));
} else if ("TRAINING".equals(status)) {
// 可以更新进度或其他中间指标
if (body.containsKey("metrics")) {
task.setMetrics((Map<String, Object>) body.get("metrics"));
}
}
this.updateById(task);
}
}
} catch (Exception e) {
// 查询失败暂不更新状态或者记录日志
System.err.println("同步任务状态失败: " + e.getMessage());
}
}
return task;
}
@Override
@Transactional
public boolean publishModel(String taskId, String versionTag) {
ModelTrainTask task = this.getById(taskId);
if (task == null) {
throw new BizException("任务不存在");
}
if (!"SUCCESS".equals(task.getStatus())) {
throw new BizException("任务未完成或失败,无法发布");
}
// 检查版本号唯一性
long count = algorithmModelService.count(new QueryWrapper<AlgorithmModel>()
.eq("algorithm_type", task.getAlgorithmType())
.eq("device_type", task.getDeviceType())
.eq("version_tag", versionTag));
if (count > 0) {
throw new BizException("版本号已存在");
}
// 创建正式模型记录
AlgorithmModel model = new AlgorithmModel();
model.setAlgorithmType(task.getAlgorithmType());
model.setDeviceType(task.getDeviceType());
model.setVersionTag(versionTag);
model.setModelPath(task.getModelOutputPath()); // 这里简化处理直接引用临时路径实际生产建议移动文件到正式目录
model.setMetricsImagePath(task.getMetricsImagePath());
model.setTrainedAt(LocalDateTime.now());
model.setIsCurrent(0); // 默认不激活
try {
if (task.getFeatureMapSnapshot() != null) {
model.setFeatureMapSnapshot(objectMapper.writeValueAsString(task.getFeatureMapSnapshot()));
} else {
model.setFeatureMapSnapshot("{}");
}
if (task.getMetrics() != null) {
model.setMetrics(objectMapper.writeValueAsString(task.getMetrics()));
}
} catch (JsonProcessingException e) {
throw new BizException("JSON 序列化失败");
}
return algorithmModelService.save(model);
}
}