模型训练接口修改
This commit is contained in:
parent
15d442d9ec
commit
6ca6516b12
@ -1,25 +1,32 @@
|
|||||||
package com.yfd.business.css.config;
|
package com.yfd.business.css.config;
|
||||||
|
|
||||||
|
import com.baomidou.mybatisplus.annotation.DbType;
|
||||||
import com.baomidou.mybatisplus.extension.spring.MybatisSqlSessionFactoryBean;
|
import com.baomidou.mybatisplus.extension.spring.MybatisSqlSessionFactoryBean;
|
||||||
import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;
|
import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;
|
||||||
|
import com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor;
|
||||||
import org.apache.ibatis.session.SqlSessionFactory;
|
import org.apache.ibatis.session.SqlSessionFactory;
|
||||||
import org.apache.ibatis.plugin.Interceptor;
|
import org.apache.ibatis.plugin.Interceptor;
|
||||||
import org.mybatis.spring.SqlSessionTemplate;
|
import org.mybatis.spring.SqlSessionTemplate;
|
||||||
|
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
|
||||||
import org.springframework.context.annotation.Bean;
|
import org.springframework.context.annotation.Bean;
|
||||||
import org.springframework.context.annotation.Configuration;
|
import org.springframework.context.annotation.Configuration;
|
||||||
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
|
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
|
||||||
import jakarta.annotation.Resource;
|
|
||||||
|
|
||||||
import javax.sql.DataSource;
|
import javax.sql.DataSource;
|
||||||
|
|
||||||
@Configuration
|
@Configuration
|
||||||
public class MybatisConfig {
|
public class MybatisConfig {
|
||||||
|
|
||||||
@Resource
|
@Bean
|
||||||
private MybatisPlusInterceptor mybatisPlusInterceptor;
|
@ConditionalOnMissingBean(MybatisPlusInterceptor.class)
|
||||||
|
public MybatisPlusInterceptor mybatisPlusInterceptor() {
|
||||||
|
MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();
|
||||||
|
interceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.MYSQL));
|
||||||
|
return interceptor;
|
||||||
|
}
|
||||||
|
|
||||||
@Bean
|
@Bean
|
||||||
public SqlSessionFactory sqlSessionFactory(DataSource dataSource) throws Exception {
|
public SqlSessionFactory sqlSessionFactory(DataSource dataSource, MybatisPlusInterceptor mybatisPlusInterceptor) throws Exception {
|
||||||
MybatisSqlSessionFactoryBean factoryBean = new MybatisSqlSessionFactoryBean();
|
MybatisSqlSessionFactoryBean factoryBean = new MybatisSqlSessionFactoryBean();
|
||||||
factoryBean.setDataSource(dataSource);
|
factoryBean.setDataSource(dataSource);
|
||||||
factoryBean.setMapperLocations(new PathMatchingResourcePatternResolver()
|
factoryBean.setMapperLocations(new PathMatchingResourcePatternResolver()
|
||||||
@ -34,27 +41,3 @@ public class MybatisConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Configuration
|
|
||||||
// public class MybatisConfig {
|
|
||||||
|
|
||||||
// @Bean
|
|
||||||
// public MybatisPlusInterceptor mybatisPlusInterceptor() {
|
|
||||||
// return new MybatisPlusInterceptor();
|
|
||||||
// }
|
|
||||||
|
|
||||||
// @Bean
|
|
||||||
// public SqlSessionFactory sqlSessionFactory(DataSource dataSource) throws Exception {
|
|
||||||
// MybatisSqlSessionFactoryBean factoryBean = new MybatisSqlSessionFactoryBean();
|
|
||||||
// factoryBean.setDataSource(dataSource);
|
|
||||||
// factoryBean.setMapperLocations(
|
|
||||||
// new PathMatchingResourcePatternResolver()
|
|
||||||
// .getResources("classpath*:/mapper/**/*.xml"));
|
|
||||||
// return factoryBean.getObject();
|
|
||||||
// }
|
|
||||||
|
|
||||||
// @Bean
|
|
||||||
// public SqlSessionTemplate sqlSessionTemplate(SqlSessionFactory sqlSessionFactory) {
|
|
||||||
// return new SqlSessionTemplate(sqlSessionFactory);
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
|
|||||||
@ -53,6 +53,7 @@ public class ModelTrainController {
|
|||||||
}
|
}
|
||||||
|
|
||||||
String taskId = modelTrainService.submitTask(task);
|
String taskId = modelTrainService.submitTask(task);
|
||||||
|
System.out.println("提交任务成功,任务ID: " + taskId);
|
||||||
return ResponseResult.successData(taskId);
|
return ResponseResult.successData(taskId);
|
||||||
} catch (JsonProcessingException e) {
|
} catch (JsonProcessingException e) {
|
||||||
return ResponseResult.error("参数解析失败: " + e.getMessage());
|
return ResponseResult.error("参数解析失败: " + e.getMessage());
|
||||||
|
|||||||
@ -1,15 +1,18 @@
|
|||||||
package com.yfd.business.css.controller;
|
package com.yfd.business.css.controller;
|
||||||
|
|
||||||
import com.yfd.business.css.build.SimBuilder;
|
import com.yfd.business.css.build.SimBuilder;
|
||||||
|
import com.yfd.business.css.domain.Scenario;
|
||||||
import com.yfd.business.css.facade.SimDataFacade;
|
import com.yfd.business.css.facade.SimDataFacade;
|
||||||
import com.yfd.business.css.model.*;
|
import com.yfd.business.css.model.*;
|
||||||
import com.yfd.business.css.service.MaterialService;
|
import com.yfd.business.css.service.MaterialService;
|
||||||
|
import com.yfd.business.css.service.ScenarioService;
|
||||||
import com.yfd.business.css.service.SimService;
|
import com.yfd.business.css.service.SimService;
|
||||||
import com.yfd.business.css.service.SimInferService;
|
import com.yfd.business.css.service.SimInferService;
|
||||||
import com.yfd.platform.config.ResponseResult;
|
import com.yfd.platform.config.ResponseResult;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.web.bind.annotation.*;
|
import org.springframework.web.bind.annotation.*;
|
||||||
|
|
||||||
|
import java.time.LocalDateTime;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
@ -26,6 +29,7 @@ public class SimController {
|
|||||||
@Autowired private SimService simService;
|
@Autowired private SimService simService;
|
||||||
@Autowired private MaterialService materialService;
|
@Autowired private MaterialService materialService;
|
||||||
@Autowired private SimInferService simInferService;
|
@Autowired private SimInferService simInferService;
|
||||||
|
@Autowired private ScenarioService scenarioService;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 执行仿真计算
|
* 执行仿真计算
|
||||||
@ -39,6 +43,13 @@ public class SimController {
|
|||||||
String scenarioId = (String) req.get("scenarioId");
|
String scenarioId = (String) req.get("scenarioId");
|
||||||
int steps = req.containsKey("steps") ? (int) req.get("steps") : 10;
|
int steps = req.containsKey("steps") ? (int) req.get("steps") : 10;
|
||||||
|
|
||||||
|
// 0. Update Status: 更新情景状态为进行中
|
||||||
|
Scenario startScenario = new Scenario();
|
||||||
|
startScenario.setScenarioId(scenarioId);
|
||||||
|
startScenario.setStatus("1"); // 1: 进行中
|
||||||
|
startScenario.setUpdatedAt(LocalDateTime.now());
|
||||||
|
scenarioService.updateById(startScenario);
|
||||||
|
|
||||||
// 1. Load Data: 获取项目、设备和事件数据
|
// 1. Load Data: 获取项目、设备和事件数据
|
||||||
SimDataPackage data = simDataFacade.loadSimulationData(projectId, scenarioId);
|
SimDataPackage data = simDataFacade.loadSimulationData(projectId, scenarioId);
|
||||||
|
|
||||||
@ -50,10 +61,10 @@ public class SimController {
|
|||||||
// 3. Run Engine: 执行核心仿真计算,返回上下文结果
|
// 3. Run Engine: 执行核心仿真计算,返回上下文结果
|
||||||
SimContext ctx = simService.runSimulation(units, events, nodes, steps);
|
SimContext ctx = simService.runSimulation(units, events, nodes, steps);
|
||||||
|
|
||||||
// 4. Async Infer: 异步执行推理并保存结果
|
// 4. Async Infer: 异步执行推理并保存结果,完成后更新情景状态
|
||||||
simInferService.asyncInferAndSave(projectId, scenarioId, ctx, units);
|
simInferService.asyncInferAndSave(projectId, scenarioId, ctx, units);
|
||||||
|
|
||||||
// 5. Convert Result: 将仿真结果转换为前端友好的格式,包含静态属性补全和元数据
|
// 6. Convert Result: 将仿真结果转换为前端友好的格式,包含静态属性补全和元数据
|
||||||
Map<String, Object> resultData = SimResultConverter.toFrames(ctx, units, projectId, scenarioId);
|
Map<String, Object> resultData = SimResultConverter.toFrames(ctx, units, projectId, scenarioId);
|
||||||
|
|
||||||
return ResponseResult.successData(resultData);
|
return ResponseResult.successData(resultData);
|
||||||
|
|||||||
@ -31,7 +31,7 @@ public class ModelTrainTask implements Serializable {
|
|||||||
private String trainParams; // JSON String
|
private String trainParams; // JSON String
|
||||||
|
|
||||||
@TableField("status")
|
@TableField("status")
|
||||||
private String status; // PENDING, TRAINING, SUCCESS, FAILED
|
private String status; // Pending, Training, Success, Failed
|
||||||
|
|
||||||
@TableField(value = "metrics")
|
@TableField(value = "metrics")
|
||||||
private String metrics; // JSON String
|
private String metrics; // JSON String
|
||||||
@ -48,9 +48,9 @@ public class ModelTrainTask implements Serializable {
|
|||||||
@TableField("error_log")
|
@TableField("error_log")
|
||||||
private String errorLog;
|
private String errorLog;
|
||||||
|
|
||||||
@TableField(value = "created_at", fill = FieldFill.INSERT)
|
@TableField(value = "created_at")
|
||||||
private LocalDateTime createdAt;
|
private LocalDateTime createdAt;
|
||||||
|
|
||||||
@TableField(value = "updated_at", fill = FieldFill.INSERT_UPDATE)
|
@TableField(value = "updated_at")
|
||||||
private LocalDateTime updatedAt;
|
private LocalDateTime updatedAt;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,11 +1,15 @@
|
|||||||
package com.yfd.business.css.model;
|
package com.yfd.business.css.model;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
public class InferRequest {
|
public class InferRequest {
|
||||||
private String modelDir;
|
private String modelDir;
|
||||||
|
|
||||||
|
@JsonProperty("device_type")
|
||||||
private String deviceType;
|
private String deviceType;
|
||||||
|
|
||||||
private List<Map<String, Object>> batch;
|
private List<Map<String, Object>> batch;
|
||||||
private Map<String, Object> features;
|
private Map<String, Object> features;
|
||||||
private Map<String, Object> meta;
|
private Map<String, Object> meta;
|
||||||
|
|||||||
@ -196,7 +196,7 @@ public class DeviceInferService {
|
|||||||
|
|
||||||
// 合并的Python推理调用方法
|
// 合并的Python推理调用方法
|
||||||
public InferResponse infer(InferRequest request) {
|
public InferResponse infer(InferRequest request) {
|
||||||
String url = pythonInferUrl + "/v1/infer" ;
|
String url = pythonInferUrl + "/v1/infer";
|
||||||
RestTemplate restTemplate = new RestTemplate();
|
RestTemplate restTemplate = new RestTemplate();
|
||||||
|
|
||||||
HttpHeaders headers = new HttpHeaders();
|
HttpHeaders headers = new HttpHeaders();
|
||||||
|
|||||||
@ -1,10 +1,12 @@
|
|||||||
package com.yfd.business.css.service;
|
package com.yfd.business.css.service;
|
||||||
|
|
||||||
|
import com.yfd.business.css.domain.Scenario;
|
||||||
import com.yfd.business.css.model.*;
|
import com.yfd.business.css.model.*;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.scheduling.annotation.Async;
|
import org.springframework.scheduling.annotation.Async;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.time.LocalDateTime;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@ -17,6 +19,9 @@ public class SimInferService {
|
|||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private DeviceInferService deviceInferService;
|
private DeviceInferService deviceInferService;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private ScenarioService scenarioService;
|
||||||
|
|
||||||
@Async
|
@Async
|
||||||
public void asyncInferAndSave(String projectId, String scenarioId, SimContext context, List<SimUnit> units) {
|
public void asyncInferAndSave(String projectId, String scenarioId, SimContext context, List<SimUnit> units) {
|
||||||
@ -57,7 +62,6 @@ public class SimInferService {
|
|||||||
if (deviceType == null || deviceType.isEmpty()) continue;
|
if (deviceType == null || deviceType.isEmpty()) continue;
|
||||||
|
|
||||||
// 确保静态属性也包含在内(如果 SimContext 中没有,从 SimUnit 补充)
|
// 确保静态属性也包含在内(如果 SimContext 中没有,从 SimUnit 补充)
|
||||||
// SimService 应该已经初始化了静态属性到 Context,但为了保险起见:
|
|
||||||
for (Map.Entry<String, Double> staticProp : unit.staticProperties().entrySet()) {
|
for (Map.Entry<String, Double> staticProp : unit.staticProperties().entrySet()) {
|
||||||
properties.putIfAbsent(staticProp.getKey(), staticProp.getValue());
|
properties.putIfAbsent(staticProp.getKey(), staticProp.getValue());
|
||||||
}
|
}
|
||||||
@ -75,16 +79,35 @@ public class SimInferService {
|
|||||||
|
|
||||||
if (groupedDevices.isEmpty()) {
|
if (groupedDevices.isEmpty()) {
|
||||||
System.out.println("No device data found for inference.");
|
System.out.println("No device data found for inference.");
|
||||||
|
// 即使没有数据,也应该更新状态为完成,或者视为正常结束
|
||||||
|
updateScenarioStatus(scenarioId, "2");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 调用 DeviceInferService 进行推理和入库
|
// 3. 调用 DeviceInferService 进行推理和入库
|
||||||
// 复用现有的 processDeviceInference 方法,它处理了按类型分组的数据
|
// 复用现有的 processDeviceInference 方法,它处理了按类型分组的数据
|
||||||
deviceInferService.processDeviceInference(projectId, scenarioId, groupedDevices);
|
deviceInferService.processDeviceInference(projectId, scenarioId, groupedDevices);
|
||||||
|
|
||||||
|
// 4. 更新状态为已完成
|
||||||
|
updateScenarioStatus(scenarioId, "2");
|
||||||
|
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
System.err.println("Async inference failed: " + e.getMessage());
|
System.err.println("Async inference failed: " + e.getMessage());
|
||||||
e.printStackTrace();
|
e.printStackTrace();
|
||||||
|
// 5. 更新状态为失败 (假设 3 代表失败)
|
||||||
|
updateScenarioStatus(scenarioId, "3");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void updateScenarioStatus(String scenarioId, String status) {
|
||||||
|
try {
|
||||||
|
Scenario scenario = new Scenario();
|
||||||
|
scenario.setScenarioId(scenarioId);
|
||||||
|
scenario.setStatus(status);
|
||||||
|
scenario.setUpdatedAt(LocalDateTime.now());
|
||||||
|
scenarioService.updateById(scenario);
|
||||||
|
} catch (Exception e) {
|
||||||
|
System.err.println("Failed to update scenario status to " + status + ": " + e.getMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -25,17 +25,25 @@ import org.springframework.web.multipart.MultipartFile;
|
|||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.nio.file.Files;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import java.nio.file.Paths;
|
||||||
|
import java.nio.file.StandardCopyOption;
|
||||||
import java.time.LocalDateTime;
|
import java.time.LocalDateTime;
|
||||||
import java.time.format.DateTimeFormatter;
|
import java.time.format.DateTimeFormatter;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
|
import java.util.regex.Pattern;
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, ModelTrainTask> implements ModelTrainService {
|
public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, ModelTrainTask> implements ModelTrainService {
|
||||||
|
|
||||||
@Value("${file-space.upload-path:./data/uploads/}")
|
@Value("${file-space.upload-path:./data/uploads/}")
|
||||||
private String uploadPath;
|
private String uploadPath;
|
||||||
|
|
||||||
|
@Value("${file-space.model-path:E:/python_coding/keffCenter/models/}")
|
||||||
|
private String modelPath;
|
||||||
|
|
||||||
@Value("${python.api.url:http://localhost:8000}")
|
@Value("${python.api.url:http://localhost:8000}")
|
||||||
private String pythonApiUrl;
|
private String pythonApiUrl;
|
||||||
@ -48,6 +56,8 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
|||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private ObjectMapper objectMapper;
|
private ObjectMapper objectMapper;
|
||||||
|
|
||||||
|
private static final Pattern VERSION_TAG_PATTERN = Pattern.compile("^[a-zA-Z0-9][a-zA-Z0-9._-]{0,63}$");
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String uploadDataset(MultipartFile file) {
|
public String uploadDataset(MultipartFile file) {
|
||||||
@ -100,7 +110,7 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
|||||||
@Transactional
|
@Transactional
|
||||||
public String submitTask(ModelTrainTask task) {
|
public String submitTask(ModelTrainTask task) {
|
||||||
// 1. 初始化状态
|
// 1. 初始化状态
|
||||||
task.setStatus("PENDING");
|
task.setStatus("Pending");
|
||||||
if (task.getTaskId() == null) {
|
if (task.getTaskId() == null) {
|
||||||
task.setTaskId(UUID.randomUUID().toString());
|
task.setTaskId(UUID.randomUUID().toString());
|
||||||
}
|
}
|
||||||
@ -115,8 +125,8 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
|||||||
@Async
|
@Async
|
||||||
public void asyncCallTrain(ModelTrainTask task) {
|
public void asyncCallTrain(ModelTrainTask task) {
|
||||||
try {
|
try {
|
||||||
// 更新状态为 TRAINING
|
// 更新状态为 Training
|
||||||
task.setStatus("TRAINING");
|
task.setStatus("Training");
|
||||||
this.updateById(task);
|
this.updateById(task);
|
||||||
|
|
||||||
// 构建请求参数
|
// 构建请求参数
|
||||||
@ -125,6 +135,8 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
|||||||
request.put("algorithm_type", task.getAlgorithmType());
|
request.put("algorithm_type", task.getAlgorithmType());
|
||||||
request.put("device_type", task.getDeviceType());
|
request.put("device_type", task.getDeviceType());
|
||||||
request.put("dataset_path", task.getDatasetPath());
|
request.put("dataset_path", task.getDatasetPath());
|
||||||
|
request.put("model_dir", modelPath);
|
||||||
|
|
||||||
|
|
||||||
// 解析 hyperparameters (String -> Map)
|
// 解析 hyperparameters (String -> Map)
|
||||||
if (task.getTrainParams() != null && !task.getTrainParams().isBlank()) {
|
if (task.getTrainParams() != null && !task.getTrainParams().isBlank()) {
|
||||||
@ -147,15 +159,39 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
|||||||
ResponseEntity<Map> response = restTemplate.postForEntity(url, entity, Map.class);
|
ResponseEntity<Map> response = restTemplate.postForEntity(url, entity, Map.class);
|
||||||
|
|
||||||
if (response.getStatusCode().is2xxSuccessful()) {
|
if (response.getStatusCode().is2xxSuccessful()) {
|
||||||
|
Map body = response.getBody();
|
||||||
|
if (body != null) {
|
||||||
|
try {
|
||||||
|
System.out.println("训练服务响应: " + objectMapper.writeValueAsString(body));
|
||||||
|
} catch (JsonProcessingException ignored) {
|
||||||
|
System.out.println("训练服务响应: " + body);
|
||||||
|
}
|
||||||
|
Object codeObj = body.get("code");
|
||||||
|
Integer code = null;
|
||||||
|
if (codeObj instanceof Number) {
|
||||||
|
code = ((Number) codeObj).intValue();
|
||||||
|
} else if (codeObj instanceof String) {
|
||||||
|
try {
|
||||||
|
code = Integer.parseInt((String) codeObj);
|
||||||
|
} catch (Exception ignored) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (code != null && code != 0) {
|
||||||
|
task.setStatus("Failed");
|
||||||
|
task.setErrorLog("训练服务返回失败: " + body);
|
||||||
|
this.updateById(task);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
System.out.println("训练任务提交成功: " + task.getTaskId());
|
System.out.println("训练任务提交成功: " + task.getTaskId());
|
||||||
} else {
|
} else {
|
||||||
task.setStatus("FAILED");
|
task.setStatus("Failed");
|
||||||
task.setErrorLog("提交训练任务失败: " + response.getStatusCode());
|
task.setErrorLog("提交训练任务失败: " + response.getStatusCode());
|
||||||
this.updateById(task);
|
this.updateById(task);
|
||||||
}
|
}
|
||||||
|
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
task.setStatus("FAILED");
|
task.setStatus("Failed");
|
||||||
task.setErrorLog("调用 Python 服务异常: " + e.getMessage());
|
task.setErrorLog("调用 Python 服务异常: " + e.getMessage());
|
||||||
this.updateById(task);
|
this.updateById(task);
|
||||||
e.printStackTrace();
|
e.printStackTrace();
|
||||||
@ -169,36 +205,79 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
|||||||
throw new BizException("任务不存在");
|
throw new BizException("任务不存在");
|
||||||
}
|
}
|
||||||
|
|
||||||
// 只有在 TRAINING 状态才去查询 Python 服务
|
// 只有在 Training 或 Pending 状态才去查询 Python 服务
|
||||||
if ("TRAINING".equals(task.getStatus()) || "PENDING".equals(task.getStatus())) {
|
if ("Training".equalsIgnoreCase(task.getStatus()) || "Pending".equalsIgnoreCase(task.getStatus())) {
|
||||||
try {
|
try {
|
||||||
String url = pythonApiUrl + "/v1/train/status/" + taskId;
|
String url = pythonApiUrl + "/v1/train/status/" + taskId;
|
||||||
ResponseEntity<Map> response = restTemplate.getForEntity(url, Map.class);
|
ResponseEntity<Map> response = restTemplate.getForEntity(url, Map.class);
|
||||||
|
|
||||||
if (response.getStatusCode().is2xxSuccessful() && response.getBody() != null) {
|
if (response.getStatusCode().is2xxSuccessful() && response.getBody() != null) {
|
||||||
Map<String, Object> body = response.getBody();
|
Map<String, Object> body = response.getBody();
|
||||||
String status = (String) body.get("status"); // TRAINING, SUCCESS, FAILED
|
Object codeObj = body.get("code");
|
||||||
|
Integer code = null;
|
||||||
|
if (codeObj instanceof Number) {
|
||||||
|
code = ((Number) codeObj).intValue();
|
||||||
|
} else if (codeObj instanceof String) {
|
||||||
|
try {
|
||||||
|
code = Integer.parseInt((String) codeObj);
|
||||||
|
} catch (Exception ignored) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (code != null && code != 0) {
|
||||||
|
task.setStatus("Failed");
|
||||||
|
task.setErrorLog(String.valueOf(body.get("msg")));
|
||||||
|
this.updateById(task);
|
||||||
|
return task;
|
||||||
|
}
|
||||||
|
|
||||||
|
Object dataObj = body.get("data");
|
||||||
|
if (!(dataObj instanceof Map)) {
|
||||||
|
return task;
|
||||||
|
}
|
||||||
|
Map<String, Object> data = (Map<String, Object>) dataObj;
|
||||||
|
String status = (String) data.get("status");
|
||||||
|
|
||||||
if (status != null) {
|
if (status != null) {
|
||||||
|
// 转换状态为首字母大写
|
||||||
|
if (status.equalsIgnoreCase("Success")) {
|
||||||
|
status = "Success";
|
||||||
|
} else if (status.equalsIgnoreCase("Failed")) {
|
||||||
|
status = "Failed";
|
||||||
|
} else if (status.equalsIgnoreCase("Training")) {
|
||||||
|
status = "Training";
|
||||||
|
} else if (status.equalsIgnoreCase("Pending")) {
|
||||||
|
status = "Pending";
|
||||||
|
}
|
||||||
|
|
||||||
task.setStatus(status);
|
task.setStatus(status);
|
||||||
|
|
||||||
if ("SUCCESS".equals(status)) {
|
if ("Success".equals(status)) {
|
||||||
task.setModelOutputPath((String) body.get("model_path"));
|
String modelPathRaw = firstNonBlank(
|
||||||
|
(String) data.get("model_relative_path"),
|
||||||
|
(String) data.get("model_path_rel_project"),
|
||||||
|
(String) data.get("model_path")
|
||||||
|
);
|
||||||
|
task.setModelOutputPath(normalizeModelPath(modelPathRaw));
|
||||||
|
|
||||||
// Map -> JSON String
|
// Map -> JSON String
|
||||||
if (body.get("metrics") != null) {
|
if (data.get("metrics") != null) {
|
||||||
task.setMetrics(objectMapper.writeValueAsString(body.get("metrics")));
|
task.setMetrics(objectMapper.writeValueAsString(data.get("metrics")));
|
||||||
}
|
}
|
||||||
if (body.get("feature_map") != null) {
|
if (data.get("feature_map") != null) {
|
||||||
task.setFeatureMapSnapshot(objectMapper.writeValueAsString(body.get("feature_map")));
|
task.setFeatureMapSnapshot(objectMapper.writeValueAsString(data.get("feature_map")));
|
||||||
}
|
}
|
||||||
|
|
||||||
task.setMetricsImagePath((String) body.get("metrics_image"));
|
String metricsPathRaw = firstNonBlank(
|
||||||
} else if ("FAILED".equals(status)) {
|
(String) data.get("metrics_image_relative_path"),
|
||||||
task.setErrorLog((String) body.get("error"));
|
(String) data.get("metrics_image_rel_project"),
|
||||||
} else if ("TRAINING".equals(status)) {
|
(String) data.get("metrics_image")
|
||||||
if (body.get("metrics") != null) {
|
);
|
||||||
task.setMetrics(objectMapper.writeValueAsString(body.get("metrics")));
|
task.setMetricsImagePath(normalizeModelPath(metricsPathRaw));
|
||||||
|
} else if ("Failed".equals(status)) {
|
||||||
|
task.setErrorLog((String) data.get("error"));
|
||||||
|
} else if ("Training".equals(status)) {
|
||||||
|
if (data.get("metrics") != null) {
|
||||||
|
task.setMetrics(objectMapper.writeValueAsString(data.get("metrics")));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -221,9 +300,13 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
|||||||
throw new BizException("任务不存在");
|
throw new BizException("任务不存在");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!"SUCCESS".equals(task.getStatus())) {
|
if (!"Success".equalsIgnoreCase(task.getStatus())) {
|
||||||
throw new BizException("任务未完成或失败,无法发布");
|
throw new BizException("任务未完成或失败,无法发布");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (versionTag == null || !VERSION_TAG_PATTERN.matcher(versionTag).matches()) {
|
||||||
|
throw new BizException("versionTag 格式不合法");
|
||||||
|
}
|
||||||
|
|
||||||
// 检查版本号唯一性
|
// 检查版本号唯一性
|
||||||
long count = algorithmModelService.count(new QueryWrapper<AlgorithmModel>()
|
long count = algorithmModelService.count(new QueryWrapper<AlgorithmModel>()
|
||||||
@ -233,14 +316,37 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
|||||||
if (count > 0) {
|
if (count > 0) {
|
||||||
throw new BizException("版本号已存在");
|
throw new BizException("版本号已存在");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
String algorithmType = task.getAlgorithmType();
|
||||||
|
String deviceType = task.getDeviceType();
|
||||||
|
if (algorithmType == null || algorithmType.isBlank() || deviceType == null || deviceType.isBlank()) {
|
||||||
|
throw new BizException("算法类型或设备类型不能为空");
|
||||||
|
}
|
||||||
|
|
||||||
|
Path root = Paths.get(modelPath).toAbsolutePath().normalize();
|
||||||
|
Path versionDir = root.resolve(Paths.get(algorithmType, deviceType, versionTag)).normalize();
|
||||||
|
if (!versionDir.startsWith(root)) {
|
||||||
|
throw new BizException("发布目录非法");
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
Files.createDirectories(versionDir);
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new BizException("创建发布目录失败: " + e.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
String publishedModelRelPath = copyToVersionDir(root, versionDir, task.getModelOutputPath());
|
||||||
|
String publishedMetricsRelPath = null;
|
||||||
|
if (task.getMetricsImagePath() != null && !task.getMetricsImagePath().isBlank()) {
|
||||||
|
publishedMetricsRelPath = copyToVersionDir(root, versionDir, task.getMetricsImagePath());
|
||||||
|
}
|
||||||
|
|
||||||
// 创建正式模型记录
|
// 创建正式模型记录
|
||||||
AlgorithmModel model = new AlgorithmModel();
|
AlgorithmModel model = new AlgorithmModel();
|
||||||
model.setAlgorithmType(task.getAlgorithmType());
|
model.setAlgorithmType(task.getAlgorithmType());
|
||||||
model.setDeviceType(task.getDeviceType());
|
model.setDeviceType(task.getDeviceType());
|
||||||
model.setVersionTag(versionTag);
|
model.setVersionTag(versionTag);
|
||||||
model.setModelPath(task.getModelOutputPath());
|
model.setModelPath(publishedModelRelPath);
|
||||||
model.setMetricsImagePath(task.getMetricsImagePath());
|
model.setMetricsImagePath(publishedMetricsRelPath);
|
||||||
model.setTrainedAt(LocalDateTime.now());
|
model.setTrainedAt(LocalDateTime.now());
|
||||||
model.setIsCurrent(0); // 默认不激活
|
model.setIsCurrent(0); // 默认不激活
|
||||||
|
|
||||||
@ -250,4 +356,76 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
|||||||
|
|
||||||
return algorithmModelService.save(model);
|
return algorithmModelService.save(model);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private String firstNonBlank(String... values) {
|
||||||
|
if (values == null) return null;
|
||||||
|
for (String v : values) {
|
||||||
|
if (v != null && !v.isBlank()) return v;
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
private String normalizeModelPath(String rawPath) {
|
||||||
|
if (rawPath == null || rawPath.isBlank()) return null;
|
||||||
|
String s = rawPath.trim().replace("\\", "/");
|
||||||
|
while (s.contains("//")) {
|
||||||
|
s = s.replace("//", "/");
|
||||||
|
}
|
||||||
|
if (s.startsWith("./")) {
|
||||||
|
s = s.substring(2);
|
||||||
|
}
|
||||||
|
if (s.startsWith("models/")) {
|
||||||
|
s = s.substring("models/".length());
|
||||||
|
}
|
||||||
|
Path root = Paths.get(modelPath).toAbsolutePath().normalize();
|
||||||
|
Path p;
|
||||||
|
try {
|
||||||
|
p = Paths.get(s.replace("/", File.separator));
|
||||||
|
} catch (Exception e) {
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
if (p.isAbsolute()) {
|
||||||
|
Path abs = p.normalize();
|
||||||
|
if (abs.startsWith(root)) {
|
||||||
|
return root.relativize(abs).toString().replace("\\", "/");
|
||||||
|
}
|
||||||
|
return abs.toString().replace("\\", "/");
|
||||||
|
}
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
|
private String copyToVersionDir(Path root, Path versionDir, String sourcePath) {
|
||||||
|
if (sourcePath == null || sourcePath.isBlank()) {
|
||||||
|
throw new BizException("训练产物路径为空");
|
||||||
|
}
|
||||||
|
|
||||||
|
String normalized = sourcePath.replace("\\", "/");
|
||||||
|
if (normalized.startsWith("models/")) {
|
||||||
|
normalized = normalized.substring("models/".length());
|
||||||
|
}
|
||||||
|
Path srcPath = Paths.get(normalized);
|
||||||
|
Path src = srcPath.isAbsolute() ? srcPath.normalize() : root.resolve(normalized).normalize();
|
||||||
|
if (!src.startsWith(root)) {
|
||||||
|
throw new BizException("训练产物路径非法");
|
||||||
|
}
|
||||||
|
if (!Files.exists(src)) {
|
||||||
|
throw new BizException("训练产物文件不存在: " + src);
|
||||||
|
}
|
||||||
|
|
||||||
|
Path dest = versionDir.resolve(src.getFileName()).normalize();
|
||||||
|
if (!dest.startsWith(versionDir)) {
|
||||||
|
throw new BizException("目标路径非法");
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
if (Files.exists(dest)) {
|
||||||
|
throw new BizException("目标文件已存在: " + dest.getFileName());
|
||||||
|
}
|
||||||
|
Files.copy(src, dest, StandardCopyOption.COPY_ATTRIBUTES);
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new BizException("复制文件失败: " + e.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
return root.relativize(dest).toString().replace("\\", "/");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user