模型训练接口修改

This commit is contained in:
wanxiaoli 2026-03-16 11:00:27 +08:00
parent 15d442d9ec
commit 6ca6516b12
8 changed files with 258 additions and 58 deletions

View File

@ -1,25 +1,32 @@
package com.yfd.business.css.config;
import com.baomidou.mybatisplus.annotation.DbType;
import com.baomidou.mybatisplus.extension.spring.MybatisSqlSessionFactoryBean;
import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;
import com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor;
import org.apache.ibatis.session.SqlSessionFactory;
import org.apache.ibatis.plugin.Interceptor;
import org.mybatis.spring.SqlSessionTemplate;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
import jakarta.annotation.Resource;
import javax.sql.DataSource;
@Configuration
public class MybatisConfig {
@Resource
private MybatisPlusInterceptor mybatisPlusInterceptor;
@Bean
@ConditionalOnMissingBean(MybatisPlusInterceptor.class)
public MybatisPlusInterceptor mybatisPlusInterceptor() {
MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();
interceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.MYSQL));
return interceptor;
}
@Bean
public SqlSessionFactory sqlSessionFactory(DataSource dataSource) throws Exception {
public SqlSessionFactory sqlSessionFactory(DataSource dataSource, MybatisPlusInterceptor mybatisPlusInterceptor) throws Exception {
MybatisSqlSessionFactoryBean factoryBean = new MybatisSqlSessionFactoryBean();
factoryBean.setDataSource(dataSource);
factoryBean.setMapperLocations(new PathMatchingResourcePatternResolver()
@ -34,27 +41,3 @@ public class MybatisConfig {
}
}
// @Configuration
// public class MybatisConfig {
// @Bean
// public MybatisPlusInterceptor mybatisPlusInterceptor() {
// return new MybatisPlusInterceptor();
// }
// @Bean
// public SqlSessionFactory sqlSessionFactory(DataSource dataSource) throws Exception {
// MybatisSqlSessionFactoryBean factoryBean = new MybatisSqlSessionFactoryBean();
// factoryBean.setDataSource(dataSource);
// factoryBean.setMapperLocations(
// new PathMatchingResourcePatternResolver()
// .getResources("classpath*:/mapper/**/*.xml"));
// return factoryBean.getObject();
// }
// @Bean
// public SqlSessionTemplate sqlSessionTemplate(SqlSessionFactory sqlSessionFactory) {
// return new SqlSessionTemplate(sqlSessionFactory);
// }
// }

View File

@ -53,6 +53,7 @@ public class ModelTrainController {
}
String taskId = modelTrainService.submitTask(task);
System.out.println("提交任务成功任务ID: " + taskId);
return ResponseResult.successData(taskId);
} catch (JsonProcessingException e) {
return ResponseResult.error("参数解析失败: " + e.getMessage());

View File

@ -1,15 +1,18 @@
package com.yfd.business.css.controller;
import com.yfd.business.css.build.SimBuilder;
import com.yfd.business.css.domain.Scenario;
import com.yfd.business.css.facade.SimDataFacade;
import com.yfd.business.css.model.*;
import com.yfd.business.css.service.MaterialService;
import com.yfd.business.css.service.ScenarioService;
import com.yfd.business.css.service.SimService;
import com.yfd.business.css.service.SimInferService;
import com.yfd.platform.config.ResponseResult;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
import java.time.LocalDateTime;
import java.util.List;
import java.util.Map;
@ -26,6 +29,7 @@ public class SimController {
@Autowired private SimService simService;
@Autowired private MaterialService materialService;
@Autowired private SimInferService simInferService;
@Autowired private ScenarioService scenarioService;
/**
* 执行仿真计算
@ -39,6 +43,13 @@ public class SimController {
String scenarioId = (String) req.get("scenarioId");
int steps = req.containsKey("steps") ? (int) req.get("steps") : 10;
// 0. Update Status: 更新情景状态为进行中
Scenario startScenario = new Scenario();
startScenario.setScenarioId(scenarioId);
startScenario.setStatus("1"); // 1: 进行中
startScenario.setUpdatedAt(LocalDateTime.now());
scenarioService.updateById(startScenario);
// 1. Load Data: 获取项目设备和事件数据
SimDataPackage data = simDataFacade.loadSimulationData(projectId, scenarioId);
@ -50,10 +61,10 @@ public class SimController {
// 3. Run Engine: 执行核心仿真计算返回上下文结果
SimContext ctx = simService.runSimulation(units, events, nodes, steps);
// 4. Async Infer: 异步执行推理并保存结果
// 4. Async Infer: 异步执行推理并保存结果完成后更新情景状态
simInferService.asyncInferAndSave(projectId, scenarioId, ctx, units);
// 5. Convert Result: 将仿真结果转换为前端友好的格式包含静态属性补全和元数据
// 6. Convert Result: 将仿真结果转换为前端友好的格式包含静态属性补全和元数据
Map<String, Object> resultData = SimResultConverter.toFrames(ctx, units, projectId, scenarioId);
return ResponseResult.successData(resultData);

View File

@ -31,7 +31,7 @@ public class ModelTrainTask implements Serializable {
private String trainParams; // JSON String
@TableField("status")
private String status; // PENDING, TRAINING, SUCCESS, FAILED
private String status; // Pending, Training, Success, Failed
@TableField(value = "metrics")
private String metrics; // JSON String
@ -48,9 +48,9 @@ public class ModelTrainTask implements Serializable {
@TableField("error_log")
private String errorLog;
@TableField(value = "created_at", fill = FieldFill.INSERT)
@TableField(value = "created_at")
private LocalDateTime createdAt;
@TableField(value = "updated_at", fill = FieldFill.INSERT_UPDATE)
@TableField(value = "updated_at")
private LocalDateTime updatedAt;
}

View File

@ -1,11 +1,15 @@
package com.yfd.business.css.model;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List;
import java.util.Map;
public class InferRequest {
private String modelDir;
@JsonProperty("device_type")
private String deviceType;
private List<Map<String, Object>> batch;
private Map<String, Object> features;
private Map<String, Object> meta;

View File

@ -196,7 +196,7 @@ public class DeviceInferService {
// 合并的Python推理调用方法
public InferResponse infer(InferRequest request) {
String url = pythonInferUrl + "/v1/infer" ;
String url = pythonInferUrl + "/v1/infer";
RestTemplate restTemplate = new RestTemplate();
HttpHeaders headers = new HttpHeaders();

View File

@ -1,10 +1,12 @@
package com.yfd.business.css.service;
import com.yfd.business.css.domain.Scenario;
import com.yfd.business.css.model.*;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import java.time.LocalDateTime;
import java.util.*;
import java.util.stream.Collectors;
@ -17,6 +19,9 @@ public class SimInferService {
@Autowired
private DeviceInferService deviceInferService;
@Autowired
private ScenarioService scenarioService;
@Async
public void asyncInferAndSave(String projectId, String scenarioId, SimContext context, List<SimUnit> units) {
@ -57,7 +62,6 @@ public class SimInferService {
if (deviceType == null || deviceType.isEmpty()) continue;
// 确保静态属性也包含在内如果 SimContext 中没有 SimUnit 补充
// SimService 应该已经初始化了静态属性到 Context但为了保险起见
for (Map.Entry<String, Double> staticProp : unit.staticProperties().entrySet()) {
properties.putIfAbsent(staticProp.getKey(), staticProp.getValue());
}
@ -75,16 +79,35 @@ public class SimInferService {
if (groupedDevices.isEmpty()) {
System.out.println("No device data found for inference.");
// 即使没有数据也应该更新状态为完成或者视为正常结束
updateScenarioStatus(scenarioId, "2");
return;
}
// 3. 调用 DeviceInferService 进行推理和入库
// 复用现有的 processDeviceInference 方法它处理了按类型分组的数据
deviceInferService.processDeviceInference(projectId, scenarioId, groupedDevices);
// 4. 更新状态为已完成
updateScenarioStatus(scenarioId, "2");
} catch (Exception e) {
System.err.println("Async inference failed: " + e.getMessage());
e.printStackTrace();
// 5. 更新状态为失败 (假设 3 代表失败)
updateScenarioStatus(scenarioId, "3");
}
}
private void updateScenarioStatus(String scenarioId, String status) {
try {
Scenario scenario = new Scenario();
scenario.setScenarioId(scenarioId);
scenario.setStatus(status);
scenario.setUpdatedAt(LocalDateTime.now());
scenarioService.updateById(scenario);
} catch (Exception e) {
System.err.println("Failed to update scenario status to " + status + ": " + e.getMessage());
}
}
}

View File

@ -25,17 +25,25 @@ import org.springframework.web.multipart.MultipartFile;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import java.util.regex.Pattern;
@Service
public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, ModelTrainTask> implements ModelTrainService {
@Value("${file-space.upload-path:./data/uploads/}")
private String uploadPath;
@Value("${file-space.model-path:E:/python_coding/keffCenter/models/}")
private String modelPath;
@Value("${python.api.url:http://localhost:8000}")
private String pythonApiUrl;
@ -48,6 +56,8 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
@Autowired
private ObjectMapper objectMapper;
private static final Pattern VERSION_TAG_PATTERN = Pattern.compile("^[a-zA-Z0-9][a-zA-Z0-9._-]{0,63}$");
@Override
public String uploadDataset(MultipartFile file) {
@ -100,7 +110,7 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
@Transactional
public String submitTask(ModelTrainTask task) {
// 1. 初始化状态
task.setStatus("PENDING");
task.setStatus("Pending");
if (task.getTaskId() == null) {
task.setTaskId(UUID.randomUUID().toString());
}
@ -115,8 +125,8 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
@Async
public void asyncCallTrain(ModelTrainTask task) {
try {
// 更新状态为 TRAINING
task.setStatus("TRAINING");
// 更新状态为 Training
task.setStatus("Training");
this.updateById(task);
// 构建请求参数
@ -125,6 +135,8 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
request.put("algorithm_type", task.getAlgorithmType());
request.put("device_type", task.getDeviceType());
request.put("dataset_path", task.getDatasetPath());
request.put("model_dir", modelPath);
// 解析 hyperparameters (String -> Map)
if (task.getTrainParams() != null && !task.getTrainParams().isBlank()) {
@ -147,15 +159,39 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
ResponseEntity<Map> response = restTemplate.postForEntity(url, entity, Map.class);
if (response.getStatusCode().is2xxSuccessful()) {
Map body = response.getBody();
if (body != null) {
try {
System.out.println("训练服务响应: " + objectMapper.writeValueAsString(body));
} catch (JsonProcessingException ignored) {
System.out.println("训练服务响应: " + body);
}
Object codeObj = body.get("code");
Integer code = null;
if (codeObj instanceof Number) {
code = ((Number) codeObj).intValue();
} else if (codeObj instanceof String) {
try {
code = Integer.parseInt((String) codeObj);
} catch (Exception ignored) {
}
}
if (code != null && code != 0) {
task.setStatus("Failed");
task.setErrorLog("训练服务返回失败: " + body);
this.updateById(task);
return;
}
}
System.out.println("训练任务提交成功: " + task.getTaskId());
} else {
task.setStatus("FAILED");
task.setStatus("Failed");
task.setErrorLog("提交训练任务失败: " + response.getStatusCode());
this.updateById(task);
}
} catch (Exception e) {
task.setStatus("FAILED");
task.setStatus("Failed");
task.setErrorLog("调用 Python 服务异常: " + e.getMessage());
this.updateById(task);
e.printStackTrace();
@ -169,36 +205,79 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
throw new BizException("任务不存在");
}
// 只有在 TRAINING 状态才去查询 Python 服务
if ("TRAINING".equals(task.getStatus()) || "PENDING".equals(task.getStatus())) {
// 只有在 Training Pending 状态才去查询 Python 服务
if ("Training".equalsIgnoreCase(task.getStatus()) || "Pending".equalsIgnoreCase(task.getStatus())) {
try {
String url = pythonApiUrl + "/v1/train/status/" + taskId;
ResponseEntity<Map> response = restTemplate.getForEntity(url, Map.class);
if (response.getStatusCode().is2xxSuccessful() && response.getBody() != null) {
Map<String, Object> body = response.getBody();
String status = (String) body.get("status"); // TRAINING, SUCCESS, FAILED
Object codeObj = body.get("code");
Integer code = null;
if (codeObj instanceof Number) {
code = ((Number) codeObj).intValue();
} else if (codeObj instanceof String) {
try {
code = Integer.parseInt((String) codeObj);
} catch (Exception ignored) {
}
}
if (code != null && code != 0) {
task.setStatus("Failed");
task.setErrorLog(String.valueOf(body.get("msg")));
this.updateById(task);
return task;
}
Object dataObj = body.get("data");
if (!(dataObj instanceof Map)) {
return task;
}
Map<String, Object> data = (Map<String, Object>) dataObj;
String status = (String) data.get("status");
if (status != null) {
// 转换状态为首字母大写
if (status.equalsIgnoreCase("Success")) {
status = "Success";
} else if (status.equalsIgnoreCase("Failed")) {
status = "Failed";
} else if (status.equalsIgnoreCase("Training")) {
status = "Training";
} else if (status.equalsIgnoreCase("Pending")) {
status = "Pending";
}
task.setStatus(status);
if ("SUCCESS".equals(status)) {
task.setModelOutputPath((String) body.get("model_path"));
if ("Success".equals(status)) {
String modelPathRaw = firstNonBlank(
(String) data.get("model_relative_path"),
(String) data.get("model_path_rel_project"),
(String) data.get("model_path")
);
task.setModelOutputPath(normalizeModelPath(modelPathRaw));
// Map -> JSON String
if (body.get("metrics") != null) {
task.setMetrics(objectMapper.writeValueAsString(body.get("metrics")));
if (data.get("metrics") != null) {
task.setMetrics(objectMapper.writeValueAsString(data.get("metrics")));
}
if (body.get("feature_map") != null) {
task.setFeatureMapSnapshot(objectMapper.writeValueAsString(body.get("feature_map")));
if (data.get("feature_map") != null) {
task.setFeatureMapSnapshot(objectMapper.writeValueAsString(data.get("feature_map")));
}
task.setMetricsImagePath((String) body.get("metrics_image"));
} else if ("FAILED".equals(status)) {
task.setErrorLog((String) body.get("error"));
} else if ("TRAINING".equals(status)) {
if (body.get("metrics") != null) {
task.setMetrics(objectMapper.writeValueAsString(body.get("metrics")));
String metricsPathRaw = firstNonBlank(
(String) data.get("metrics_image_relative_path"),
(String) data.get("metrics_image_rel_project"),
(String) data.get("metrics_image")
);
task.setMetricsImagePath(normalizeModelPath(metricsPathRaw));
} else if ("Failed".equals(status)) {
task.setErrorLog((String) data.get("error"));
} else if ("Training".equals(status)) {
if (data.get("metrics") != null) {
task.setMetrics(objectMapper.writeValueAsString(data.get("metrics")));
}
}
@ -221,9 +300,13 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
throw new BizException("任务不存在");
}
if (!"SUCCESS".equals(task.getStatus())) {
if (!"Success".equalsIgnoreCase(task.getStatus())) {
throw new BizException("任务未完成或失败,无法发布");
}
if (versionTag == null || !VERSION_TAG_PATTERN.matcher(versionTag).matches()) {
throw new BizException("versionTag 格式不合法");
}
// 检查版本号唯一性
long count = algorithmModelService.count(new QueryWrapper<AlgorithmModel>()
@ -233,14 +316,37 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
if (count > 0) {
throw new BizException("版本号已存在");
}
String algorithmType = task.getAlgorithmType();
String deviceType = task.getDeviceType();
if (algorithmType == null || algorithmType.isBlank() || deviceType == null || deviceType.isBlank()) {
throw new BizException("算法类型或设备类型不能为空");
}
Path root = Paths.get(modelPath).toAbsolutePath().normalize();
Path versionDir = root.resolve(Paths.get(algorithmType, deviceType, versionTag)).normalize();
if (!versionDir.startsWith(root)) {
throw new BizException("发布目录非法");
}
try {
Files.createDirectories(versionDir);
} catch (IOException e) {
throw new BizException("创建发布目录失败: " + e.getMessage());
}
String publishedModelRelPath = copyToVersionDir(root, versionDir, task.getModelOutputPath());
String publishedMetricsRelPath = null;
if (task.getMetricsImagePath() != null && !task.getMetricsImagePath().isBlank()) {
publishedMetricsRelPath = copyToVersionDir(root, versionDir, task.getMetricsImagePath());
}
// 创建正式模型记录
AlgorithmModel model = new AlgorithmModel();
model.setAlgorithmType(task.getAlgorithmType());
model.setDeviceType(task.getDeviceType());
model.setVersionTag(versionTag);
model.setModelPath(task.getModelOutputPath());
model.setMetricsImagePath(task.getMetricsImagePath());
model.setModelPath(publishedModelRelPath);
model.setMetricsImagePath(publishedMetricsRelPath);
model.setTrainedAt(LocalDateTime.now());
model.setIsCurrent(0); // 默认不激活
@ -250,4 +356,76 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
return algorithmModelService.save(model);
}
private String firstNonBlank(String... values) {
if (values == null) return null;
for (String v : values) {
if (v != null && !v.isBlank()) return v;
}
return null;
}
private String normalizeModelPath(String rawPath) {
if (rawPath == null || rawPath.isBlank()) return null;
String s = rawPath.trim().replace("\\", "/");
while (s.contains("//")) {
s = s.replace("//", "/");
}
if (s.startsWith("./")) {
s = s.substring(2);
}
if (s.startsWith("models/")) {
s = s.substring("models/".length());
}
Path root = Paths.get(modelPath).toAbsolutePath().normalize();
Path p;
try {
p = Paths.get(s.replace("/", File.separator));
} catch (Exception e) {
return s;
}
if (p.isAbsolute()) {
Path abs = p.normalize();
if (abs.startsWith(root)) {
return root.relativize(abs).toString().replace("\\", "/");
}
return abs.toString().replace("\\", "/");
}
return s;
}
private String copyToVersionDir(Path root, Path versionDir, String sourcePath) {
if (sourcePath == null || sourcePath.isBlank()) {
throw new BizException("训练产物路径为空");
}
String normalized = sourcePath.replace("\\", "/");
if (normalized.startsWith("models/")) {
normalized = normalized.substring("models/".length());
}
Path srcPath = Paths.get(normalized);
Path src = srcPath.isAbsolute() ? srcPath.normalize() : root.resolve(normalized).normalize();
if (!src.startsWith(root)) {
throw new BizException("训练产物路径非法");
}
if (!Files.exists(src)) {
throw new BizException("训练产物文件不存在: " + src);
}
Path dest = versionDir.resolve(src.getFileName()).normalize();
if (!dest.startsWith(versionDir)) {
throw new BizException("目标路径非法");
}
try {
if (Files.exists(dest)) {
throw new BizException("目标文件已存在: " + dest.getFileName());
}
Files.copy(src, dest, StandardCopyOption.COPY_ATTRIBUTES);
} catch (IOException e) {
throw new BizException("复制文件失败: " + e.getMessage());
}
return root.relativize(dest).toString().replace("\\", "/");
}
}