模型训练接口修改
This commit is contained in:
parent
15d442d9ec
commit
6ca6516b12
@ -1,25 +1,32 @@
|
||||
package com.yfd.business.css.config;
|
||||
|
||||
import com.baomidou.mybatisplus.annotation.DbType;
|
||||
import com.baomidou.mybatisplus.extension.spring.MybatisSqlSessionFactoryBean;
|
||||
import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;
|
||||
import com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor;
|
||||
import org.apache.ibatis.session.SqlSessionFactory;
|
||||
import org.apache.ibatis.plugin.Interceptor;
|
||||
import org.mybatis.spring.SqlSessionTemplate;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
|
||||
import jakarta.annotation.Resource;
|
||||
|
||||
import javax.sql.DataSource;
|
||||
|
||||
@Configuration
|
||||
public class MybatisConfig {
|
||||
|
||||
@Resource
|
||||
private MybatisPlusInterceptor mybatisPlusInterceptor;
|
||||
@Bean
|
||||
@ConditionalOnMissingBean(MybatisPlusInterceptor.class)
|
||||
public MybatisPlusInterceptor mybatisPlusInterceptor() {
|
||||
MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();
|
||||
interceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.MYSQL));
|
||||
return interceptor;
|
||||
}
|
||||
|
||||
@Bean
|
||||
public SqlSessionFactory sqlSessionFactory(DataSource dataSource) throws Exception {
|
||||
public SqlSessionFactory sqlSessionFactory(DataSource dataSource, MybatisPlusInterceptor mybatisPlusInterceptor) throws Exception {
|
||||
MybatisSqlSessionFactoryBean factoryBean = new MybatisSqlSessionFactoryBean();
|
||||
factoryBean.setDataSource(dataSource);
|
||||
factoryBean.setMapperLocations(new PathMatchingResourcePatternResolver()
|
||||
@ -34,27 +41,3 @@ public class MybatisConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// @Configuration
|
||||
// public class MybatisConfig {
|
||||
|
||||
// @Bean
|
||||
// public MybatisPlusInterceptor mybatisPlusInterceptor() {
|
||||
// return new MybatisPlusInterceptor();
|
||||
// }
|
||||
|
||||
// @Bean
|
||||
// public SqlSessionFactory sqlSessionFactory(DataSource dataSource) throws Exception {
|
||||
// MybatisSqlSessionFactoryBean factoryBean = new MybatisSqlSessionFactoryBean();
|
||||
// factoryBean.setDataSource(dataSource);
|
||||
// factoryBean.setMapperLocations(
|
||||
// new PathMatchingResourcePatternResolver()
|
||||
// .getResources("classpath*:/mapper/**/*.xml"));
|
||||
// return factoryBean.getObject();
|
||||
// }
|
||||
|
||||
// @Bean
|
||||
// public SqlSessionTemplate sqlSessionTemplate(SqlSessionFactory sqlSessionFactory) {
|
||||
// return new SqlSessionTemplate(sqlSessionFactory);
|
||||
// }
|
||||
// }
|
||||
|
||||
|
||||
@ -53,6 +53,7 @@ public class ModelTrainController {
|
||||
}
|
||||
|
||||
String taskId = modelTrainService.submitTask(task);
|
||||
System.out.println("提交任务成功,任务ID: " + taskId);
|
||||
return ResponseResult.successData(taskId);
|
||||
} catch (JsonProcessingException e) {
|
||||
return ResponseResult.error("参数解析失败: " + e.getMessage());
|
||||
|
||||
@ -1,15 +1,18 @@
|
||||
package com.yfd.business.css.controller;
|
||||
|
||||
import com.yfd.business.css.build.SimBuilder;
|
||||
import com.yfd.business.css.domain.Scenario;
|
||||
import com.yfd.business.css.facade.SimDataFacade;
|
||||
import com.yfd.business.css.model.*;
|
||||
import com.yfd.business.css.service.MaterialService;
|
||||
import com.yfd.business.css.service.ScenarioService;
|
||||
import com.yfd.business.css.service.SimService;
|
||||
import com.yfd.business.css.service.SimInferService;
|
||||
import com.yfd.platform.config.ResponseResult;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@ -26,6 +29,7 @@ public class SimController {
|
||||
@Autowired private SimService simService;
|
||||
@Autowired private MaterialService materialService;
|
||||
@Autowired private SimInferService simInferService;
|
||||
@Autowired private ScenarioService scenarioService;
|
||||
|
||||
/**
|
||||
* 执行仿真计算
|
||||
@ -39,6 +43,13 @@ public class SimController {
|
||||
String scenarioId = (String) req.get("scenarioId");
|
||||
int steps = req.containsKey("steps") ? (int) req.get("steps") : 10;
|
||||
|
||||
// 0. Update Status: 更新情景状态为进行中
|
||||
Scenario startScenario = new Scenario();
|
||||
startScenario.setScenarioId(scenarioId);
|
||||
startScenario.setStatus("1"); // 1: 进行中
|
||||
startScenario.setUpdatedAt(LocalDateTime.now());
|
||||
scenarioService.updateById(startScenario);
|
||||
|
||||
// 1. Load Data: 获取项目、设备和事件数据
|
||||
SimDataPackage data = simDataFacade.loadSimulationData(projectId, scenarioId);
|
||||
|
||||
@ -50,10 +61,10 @@ public class SimController {
|
||||
// 3. Run Engine: 执行核心仿真计算,返回上下文结果
|
||||
SimContext ctx = simService.runSimulation(units, events, nodes, steps);
|
||||
|
||||
// 4. Async Infer: 异步执行推理并保存结果
|
||||
// 4. Async Infer: 异步执行推理并保存结果,完成后更新情景状态
|
||||
simInferService.asyncInferAndSave(projectId, scenarioId, ctx, units);
|
||||
|
||||
// 5. Convert Result: 将仿真结果转换为前端友好的格式,包含静态属性补全和元数据
|
||||
// 6. Convert Result: 将仿真结果转换为前端友好的格式,包含静态属性补全和元数据
|
||||
Map<String, Object> resultData = SimResultConverter.toFrames(ctx, units, projectId, scenarioId);
|
||||
|
||||
return ResponseResult.successData(resultData);
|
||||
|
||||
@ -31,7 +31,7 @@ public class ModelTrainTask implements Serializable {
|
||||
private String trainParams; // JSON String
|
||||
|
||||
@TableField("status")
|
||||
private String status; // PENDING, TRAINING, SUCCESS, FAILED
|
||||
private String status; // Pending, Training, Success, Failed
|
||||
|
||||
@TableField(value = "metrics")
|
||||
private String metrics; // JSON String
|
||||
@ -48,9 +48,9 @@ public class ModelTrainTask implements Serializable {
|
||||
@TableField("error_log")
|
||||
private String errorLog;
|
||||
|
||||
@TableField(value = "created_at", fill = FieldFill.INSERT)
|
||||
@TableField(value = "created_at")
|
||||
private LocalDateTime createdAt;
|
||||
|
||||
@TableField(value = "updated_at", fill = FieldFill.INSERT_UPDATE)
|
||||
@TableField(value = "updated_at")
|
||||
private LocalDateTime updatedAt;
|
||||
}
|
||||
|
||||
@ -1,11 +1,15 @@
|
||||
package com.yfd.business.css.model;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class InferRequest {
|
||||
private String modelDir;
|
||||
|
||||
@JsonProperty("device_type")
|
||||
private String deviceType;
|
||||
|
||||
private List<Map<String, Object>> batch;
|
||||
private Map<String, Object> features;
|
||||
private Map<String, Object> meta;
|
||||
|
||||
@ -196,7 +196,7 @@ public class DeviceInferService {
|
||||
|
||||
// 合并的Python推理调用方法
|
||||
public InferResponse infer(InferRequest request) {
|
||||
String url = pythonInferUrl + "/v1/infer" ;
|
||||
String url = pythonInferUrl + "/v1/infer";
|
||||
RestTemplate restTemplate = new RestTemplate();
|
||||
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
package com.yfd.business.css.service;
|
||||
|
||||
import com.yfd.business.css.domain.Scenario;
|
||||
import com.yfd.business.css.model.*;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@ -18,6 +20,9 @@ public class SimInferService {
|
||||
@Autowired
|
||||
private DeviceInferService deviceInferService;
|
||||
|
||||
@Autowired
|
||||
private ScenarioService scenarioService;
|
||||
|
||||
@Async
|
||||
public void asyncInferAndSave(String projectId, String scenarioId, SimContext context, List<SimUnit> units) {
|
||||
try {
|
||||
@ -57,7 +62,6 @@ public class SimInferService {
|
||||
if (deviceType == null || deviceType.isEmpty()) continue;
|
||||
|
||||
// 确保静态属性也包含在内(如果 SimContext 中没有,从 SimUnit 补充)
|
||||
// SimService 应该已经初始化了静态属性到 Context,但为了保险起见:
|
||||
for (Map.Entry<String, Double> staticProp : unit.staticProperties().entrySet()) {
|
||||
properties.putIfAbsent(staticProp.getKey(), staticProp.getValue());
|
||||
}
|
||||
@ -75,6 +79,8 @@ public class SimInferService {
|
||||
|
||||
if (groupedDevices.isEmpty()) {
|
||||
System.out.println("No device data found for inference.");
|
||||
// 即使没有数据,也应该更新状态为完成,或者视为正常结束
|
||||
updateScenarioStatus(scenarioId, "2");
|
||||
return;
|
||||
}
|
||||
|
||||
@ -82,9 +88,26 @@ public class SimInferService {
|
||||
// 复用现有的 processDeviceInference 方法,它处理了按类型分组的数据
|
||||
deviceInferService.processDeviceInference(projectId, scenarioId, groupedDevices);
|
||||
|
||||
// 4. 更新状态为已完成
|
||||
updateScenarioStatus(scenarioId, "2");
|
||||
|
||||
} catch (Exception e) {
|
||||
System.err.println("Async inference failed: " + e.getMessage());
|
||||
e.printStackTrace();
|
||||
// 5. 更新状态为失败 (假设 3 代表失败)
|
||||
updateScenarioStatus(scenarioId, "3");
|
||||
}
|
||||
}
|
||||
|
||||
private void updateScenarioStatus(String scenarioId, String status) {
|
||||
try {
|
||||
Scenario scenario = new Scenario();
|
||||
scenario.setScenarioId(scenarioId);
|
||||
scenario.setStatus(status);
|
||||
scenario.setUpdatedAt(LocalDateTime.now());
|
||||
scenarioService.updateById(scenario);
|
||||
} catch (Exception e) {
|
||||
System.err.println("Failed to update scenario status to " + status + ": " + e.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -25,11 +25,16 @@ import org.springframework.web.multipart.MultipartFile;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
import java.time.LocalDateTime;
|
||||
import java.time.format.DateTimeFormatter;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
@Service
|
||||
public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, ModelTrainTask> implements ModelTrainService {
|
||||
@ -37,6 +42,9 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
||||
@Value("${file-space.upload-path:./data/uploads/}")
|
||||
private String uploadPath;
|
||||
|
||||
@Value("${file-space.model-path:E:/python_coding/keffCenter/models/}")
|
||||
private String modelPath;
|
||||
|
||||
@Value("${python.api.url:http://localhost:8000}")
|
||||
private String pythonApiUrl;
|
||||
|
||||
@ -49,6 +57,8 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
||||
@Autowired
|
||||
private ObjectMapper objectMapper;
|
||||
|
||||
private static final Pattern VERSION_TAG_PATTERN = Pattern.compile("^[a-zA-Z0-9][a-zA-Z0-9._-]{0,63}$");
|
||||
|
||||
@Override
|
||||
public String uploadDataset(MultipartFile file) {
|
||||
if (file.isEmpty()) {
|
||||
@ -100,7 +110,7 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
||||
@Transactional
|
||||
public String submitTask(ModelTrainTask task) {
|
||||
// 1. 初始化状态
|
||||
task.setStatus("PENDING");
|
||||
task.setStatus("Pending");
|
||||
if (task.getTaskId() == null) {
|
||||
task.setTaskId(UUID.randomUUID().toString());
|
||||
}
|
||||
@ -115,8 +125,8 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
||||
@Async
|
||||
public void asyncCallTrain(ModelTrainTask task) {
|
||||
try {
|
||||
// 更新状态为 TRAINING
|
||||
task.setStatus("TRAINING");
|
||||
// 更新状态为 Training
|
||||
task.setStatus("Training");
|
||||
this.updateById(task);
|
||||
|
||||
// 构建请求参数
|
||||
@ -125,6 +135,8 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
||||
request.put("algorithm_type", task.getAlgorithmType());
|
||||
request.put("device_type", task.getDeviceType());
|
||||
request.put("dataset_path", task.getDatasetPath());
|
||||
request.put("model_dir", modelPath);
|
||||
|
||||
|
||||
// 解析 hyperparameters (String -> Map)
|
||||
if (task.getTrainParams() != null && !task.getTrainParams().isBlank()) {
|
||||
@ -147,15 +159,39 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
||||
ResponseEntity<Map> response = restTemplate.postForEntity(url, entity, Map.class);
|
||||
|
||||
if (response.getStatusCode().is2xxSuccessful()) {
|
||||
Map body = response.getBody();
|
||||
if (body != null) {
|
||||
try {
|
||||
System.out.println("训练服务响应: " + objectMapper.writeValueAsString(body));
|
||||
} catch (JsonProcessingException ignored) {
|
||||
System.out.println("训练服务响应: " + body);
|
||||
}
|
||||
Object codeObj = body.get("code");
|
||||
Integer code = null;
|
||||
if (codeObj instanceof Number) {
|
||||
code = ((Number) codeObj).intValue();
|
||||
} else if (codeObj instanceof String) {
|
||||
try {
|
||||
code = Integer.parseInt((String) codeObj);
|
||||
} catch (Exception ignored) {
|
||||
}
|
||||
}
|
||||
if (code != null && code != 0) {
|
||||
task.setStatus("Failed");
|
||||
task.setErrorLog("训练服务返回失败: " + body);
|
||||
this.updateById(task);
|
||||
return;
|
||||
}
|
||||
}
|
||||
System.out.println("训练任务提交成功: " + task.getTaskId());
|
||||
} else {
|
||||
task.setStatus("FAILED");
|
||||
task.setStatus("Failed");
|
||||
task.setErrorLog("提交训练任务失败: " + response.getStatusCode());
|
||||
this.updateById(task);
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
task.setStatus("FAILED");
|
||||
task.setStatus("Failed");
|
||||
task.setErrorLog("调用 Python 服务异常: " + e.getMessage());
|
||||
this.updateById(task);
|
||||
e.printStackTrace();
|
||||
@ -169,36 +205,79 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
||||
throw new BizException("任务不存在");
|
||||
}
|
||||
|
||||
// 只有在 TRAINING 状态才去查询 Python 服务
|
||||
if ("TRAINING".equals(task.getStatus()) || "PENDING".equals(task.getStatus())) {
|
||||
// 只有在 Training 或 Pending 状态才去查询 Python 服务
|
||||
if ("Training".equalsIgnoreCase(task.getStatus()) || "Pending".equalsIgnoreCase(task.getStatus())) {
|
||||
try {
|
||||
String url = pythonApiUrl + "/v1/train/status/" + taskId;
|
||||
ResponseEntity<Map> response = restTemplate.getForEntity(url, Map.class);
|
||||
|
||||
if (response.getStatusCode().is2xxSuccessful() && response.getBody() != null) {
|
||||
Map<String, Object> body = response.getBody();
|
||||
String status = (String) body.get("status"); // TRAINING, SUCCESS, FAILED
|
||||
Object codeObj = body.get("code");
|
||||
Integer code = null;
|
||||
if (codeObj instanceof Number) {
|
||||
code = ((Number) codeObj).intValue();
|
||||
} else if (codeObj instanceof String) {
|
||||
try {
|
||||
code = Integer.parseInt((String) codeObj);
|
||||
} catch (Exception ignored) {
|
||||
}
|
||||
}
|
||||
if (code != null && code != 0) {
|
||||
task.setStatus("Failed");
|
||||
task.setErrorLog(String.valueOf(body.get("msg")));
|
||||
this.updateById(task);
|
||||
return task;
|
||||
}
|
||||
|
||||
Object dataObj = body.get("data");
|
||||
if (!(dataObj instanceof Map)) {
|
||||
return task;
|
||||
}
|
||||
Map<String, Object> data = (Map<String, Object>) dataObj;
|
||||
String status = (String) data.get("status");
|
||||
|
||||
if (status != null) {
|
||||
// 转换状态为首字母大写
|
||||
if (status.equalsIgnoreCase("Success")) {
|
||||
status = "Success";
|
||||
} else if (status.equalsIgnoreCase("Failed")) {
|
||||
status = "Failed";
|
||||
} else if (status.equalsIgnoreCase("Training")) {
|
||||
status = "Training";
|
||||
} else if (status.equalsIgnoreCase("Pending")) {
|
||||
status = "Pending";
|
||||
}
|
||||
|
||||
task.setStatus(status);
|
||||
|
||||
if ("SUCCESS".equals(status)) {
|
||||
task.setModelOutputPath((String) body.get("model_path"));
|
||||
if ("Success".equals(status)) {
|
||||
String modelPathRaw = firstNonBlank(
|
||||
(String) data.get("model_relative_path"),
|
||||
(String) data.get("model_path_rel_project"),
|
||||
(String) data.get("model_path")
|
||||
);
|
||||
task.setModelOutputPath(normalizeModelPath(modelPathRaw));
|
||||
|
||||
// Map -> JSON String
|
||||
if (body.get("metrics") != null) {
|
||||
task.setMetrics(objectMapper.writeValueAsString(body.get("metrics")));
|
||||
if (data.get("metrics") != null) {
|
||||
task.setMetrics(objectMapper.writeValueAsString(data.get("metrics")));
|
||||
}
|
||||
if (body.get("feature_map") != null) {
|
||||
task.setFeatureMapSnapshot(objectMapper.writeValueAsString(body.get("feature_map")));
|
||||
if (data.get("feature_map") != null) {
|
||||
task.setFeatureMapSnapshot(objectMapper.writeValueAsString(data.get("feature_map")));
|
||||
}
|
||||
|
||||
task.setMetricsImagePath((String) body.get("metrics_image"));
|
||||
} else if ("FAILED".equals(status)) {
|
||||
task.setErrorLog((String) body.get("error"));
|
||||
} else if ("TRAINING".equals(status)) {
|
||||
if (body.get("metrics") != null) {
|
||||
task.setMetrics(objectMapper.writeValueAsString(body.get("metrics")));
|
||||
String metricsPathRaw = firstNonBlank(
|
||||
(String) data.get("metrics_image_relative_path"),
|
||||
(String) data.get("metrics_image_rel_project"),
|
||||
(String) data.get("metrics_image")
|
||||
);
|
||||
task.setMetricsImagePath(normalizeModelPath(metricsPathRaw));
|
||||
} else if ("Failed".equals(status)) {
|
||||
task.setErrorLog((String) data.get("error"));
|
||||
} else if ("Training".equals(status)) {
|
||||
if (data.get("metrics") != null) {
|
||||
task.setMetrics(objectMapper.writeValueAsString(data.get("metrics")));
|
||||
}
|
||||
}
|
||||
|
||||
@ -221,10 +300,14 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
||||
throw new BizException("任务不存在");
|
||||
}
|
||||
|
||||
if (!"SUCCESS".equals(task.getStatus())) {
|
||||
if (!"Success".equalsIgnoreCase(task.getStatus())) {
|
||||
throw new BizException("任务未完成或失败,无法发布");
|
||||
}
|
||||
|
||||
if (versionTag == null || !VERSION_TAG_PATTERN.matcher(versionTag).matches()) {
|
||||
throw new BizException("versionTag 格式不合法");
|
||||
}
|
||||
|
||||
// 检查版本号唯一性
|
||||
long count = algorithmModelService.count(new QueryWrapper<AlgorithmModel>()
|
||||
.eq("algorithm_type", task.getAlgorithmType())
|
||||
@ -234,13 +317,36 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
||||
throw new BizException("版本号已存在");
|
||||
}
|
||||
|
||||
String algorithmType = task.getAlgorithmType();
|
||||
String deviceType = task.getDeviceType();
|
||||
if (algorithmType == null || algorithmType.isBlank() || deviceType == null || deviceType.isBlank()) {
|
||||
throw new BizException("算法类型或设备类型不能为空");
|
||||
}
|
||||
|
||||
Path root = Paths.get(modelPath).toAbsolutePath().normalize();
|
||||
Path versionDir = root.resolve(Paths.get(algorithmType, deviceType, versionTag)).normalize();
|
||||
if (!versionDir.startsWith(root)) {
|
||||
throw new BizException("发布目录非法");
|
||||
}
|
||||
try {
|
||||
Files.createDirectories(versionDir);
|
||||
} catch (IOException e) {
|
||||
throw new BizException("创建发布目录失败: " + e.getMessage());
|
||||
}
|
||||
|
||||
String publishedModelRelPath = copyToVersionDir(root, versionDir, task.getModelOutputPath());
|
||||
String publishedMetricsRelPath = null;
|
||||
if (task.getMetricsImagePath() != null && !task.getMetricsImagePath().isBlank()) {
|
||||
publishedMetricsRelPath = copyToVersionDir(root, versionDir, task.getMetricsImagePath());
|
||||
}
|
||||
|
||||
// 创建正式模型记录
|
||||
AlgorithmModel model = new AlgorithmModel();
|
||||
model.setAlgorithmType(task.getAlgorithmType());
|
||||
model.setDeviceType(task.getDeviceType());
|
||||
model.setVersionTag(versionTag);
|
||||
model.setModelPath(task.getModelOutputPath());
|
||||
model.setMetricsImagePath(task.getMetricsImagePath());
|
||||
model.setModelPath(publishedModelRelPath);
|
||||
model.setMetricsImagePath(publishedMetricsRelPath);
|
||||
model.setTrainedAt(LocalDateTime.now());
|
||||
model.setIsCurrent(0); // 默认不激活
|
||||
|
||||
@ -250,4 +356,76 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
||||
|
||||
return algorithmModelService.save(model);
|
||||
}
|
||||
|
||||
private String firstNonBlank(String... values) {
|
||||
if (values == null) return null;
|
||||
for (String v : values) {
|
||||
if (v != null && !v.isBlank()) return v;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private String normalizeModelPath(String rawPath) {
|
||||
if (rawPath == null || rawPath.isBlank()) return null;
|
||||
String s = rawPath.trim().replace("\\", "/");
|
||||
while (s.contains("//")) {
|
||||
s = s.replace("//", "/");
|
||||
}
|
||||
if (s.startsWith("./")) {
|
||||
s = s.substring(2);
|
||||
}
|
||||
if (s.startsWith("models/")) {
|
||||
s = s.substring("models/".length());
|
||||
}
|
||||
Path root = Paths.get(modelPath).toAbsolutePath().normalize();
|
||||
Path p;
|
||||
try {
|
||||
p = Paths.get(s.replace("/", File.separator));
|
||||
} catch (Exception e) {
|
||||
return s;
|
||||
}
|
||||
if (p.isAbsolute()) {
|
||||
Path abs = p.normalize();
|
||||
if (abs.startsWith(root)) {
|
||||
return root.relativize(abs).toString().replace("\\", "/");
|
||||
}
|
||||
return abs.toString().replace("\\", "/");
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
private String copyToVersionDir(Path root, Path versionDir, String sourcePath) {
|
||||
if (sourcePath == null || sourcePath.isBlank()) {
|
||||
throw new BizException("训练产物路径为空");
|
||||
}
|
||||
|
||||
String normalized = sourcePath.replace("\\", "/");
|
||||
if (normalized.startsWith("models/")) {
|
||||
normalized = normalized.substring("models/".length());
|
||||
}
|
||||
Path srcPath = Paths.get(normalized);
|
||||
Path src = srcPath.isAbsolute() ? srcPath.normalize() : root.resolve(normalized).normalize();
|
||||
if (!src.startsWith(root)) {
|
||||
throw new BizException("训练产物路径非法");
|
||||
}
|
||||
if (!Files.exists(src)) {
|
||||
throw new BizException("训练产物文件不存在: " + src);
|
||||
}
|
||||
|
||||
Path dest = versionDir.resolve(src.getFileName()).normalize();
|
||||
if (!dest.startsWith(versionDir)) {
|
||||
throw new BizException("目标路径非法");
|
||||
}
|
||||
|
||||
try {
|
||||
if (Files.exists(dest)) {
|
||||
throw new BizException("目标文件已存在: " + dest.getFileName());
|
||||
}
|
||||
Files.copy(src, dest, StandardCopyOption.COPY_ATTRIBUTES);
|
||||
} catch (IOException e) {
|
||||
throw new BizException("复制文件失败: " + e.getMessage());
|
||||
}
|
||||
|
||||
return root.relativize(dest).toString().replace("\\", "/");
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user