模型训练添加特征映射配置 环节

This commit is contained in:
wanxiaoli 2026-04-09 11:06:53 +08:00
parent 4bad0bc52d
commit 52bb4f69c5
5 changed files with 32 additions and 42 deletions

View File

@ -74,6 +74,7 @@ public class ModelTrainController {
@RequestPart(value = "file", required = false) MultipartFile file) {
try {
ModelTrainTask task = objectMapper.readValue(taskJson, ModelTrainTask.class);
// System.out.println("提交任务参数: " + task.toString());
// 如果上传了文件优先使用文件路径
if (file != null && !file.isEmpty()) {
@ -87,7 +88,7 @@ public class ModelTrainController {
}
String taskId = modelTrainService.submitTask(task);
System.out.println("提交任务成功任务ID: " + taskId);
// System.out.println("提交任务成功任务ID: " + taskId);
return ResponseResult.successData(taskId);
} catch (JsonProcessingException e) {
return ResponseResult.error("参数解析失败: " + e.getMessage());

View File

@ -1,13 +1,12 @@
package com.yfd.business.css.domain;
import com.baomidou.mybatisplus.annotation.*;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonAlias;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
import java.io.Serializable;
import java.time.LocalDateTime;
import java.util.Map;
@Data
@TableName(value = "model_train_task", autoResultMap = true)
@ -28,9 +27,11 @@ public class ModelTrainTask implements Serializable {
private String deviceType;
@TableField("dataset_path")
@JsonAlias("dataset_path")
private String datasetPath;
@TableField(value = "train_params")
@JsonAlias("train_params")
private String trainParams; // JSON String
@TableField("status")
@ -52,12 +53,9 @@ public class ModelTrainTask implements Serializable {
private String errorLog;
@TableField(value = "feature_map_config")
@JsonIgnore
private String featureMapConfigJson;
@TableField(exist = false)
@JsonProperty("featureMapConfig")
private Map<String, Object> featureMapConfig;
@JsonAlias({"feature_map_config","featureMapConfig"})
@JsonProperty("feature_map_config")
private String featureMapConfig;
@TableField(value = "created_at")
private LocalDateTime createdAt;

View File

@ -20,6 +20,7 @@ public class TrainWebSocketService {
* @param data 状态数据
*/
public void sendTrainStatus(String taskId, Map<String, Object> data) {
// System.out.println("sendTrainStatus, taskId="+taskId+", data="+data.toString());
// 1. 细粒度推送供详情页使用
String specificDestination = "/topic/train-status/" + taskId;
messagingTemplate.convertAndSend(specificDestination, data);

View File

@ -146,19 +146,11 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
if (task.getTaskId() == null) {
task.setTaskId(UUID.randomUUID().toString());
}
if (task.getFeatureMapConfig() != null && !task.getFeatureMapConfig().isEmpty()) {
Map<String, Object> params = new HashMap<>();
if (task.getTrainParams() != null && !task.getTrainParams().isBlank()) {
try {
params.putAll(objectMapper.readValue(task.getTrainParams(), new TypeReference<Map<String, Object>>() {}));
} catch (Exception ignored) {
}
}
params.put("feature_map", task.getFeatureMapConfig());
if (task.getFeatureMapConfig() != null && !task.getFeatureMapConfig().isBlank()) {
try {
task.setTrainParams(objectMapper.writeValueAsString(params));
task.setFeatureMapConfigJson(objectMapper.writeValueAsString(task.getFeatureMapConfig()));
} catch (JsonProcessingException ignored) {
objectMapper.readValue(task.getFeatureMapConfig(), new TypeReference<Map<String, Object>>() {});
} catch (Exception ignored) {
throw new BizException("feature_map_config JSON解析失败");
}
}
this.save(task);
@ -185,17 +177,19 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
request.put("model_dir", modelPath);
Map<String, Object> fmMap = null;
if (task.getFeatureMapConfig() != null && !task.getFeatureMapConfig().isBlank()) {
try {
fmMap = objectMapper.readValue(task.getFeatureMapConfig(), new TypeReference<Map<String, Object>>() {});
} catch (Exception ignored) {
}
}
// 解析 hyperparameters (String -> Map)
if (task.getTrainParams() != null && !task.getTrainParams().isBlank()) {
try {
Map<String, Object> params = objectMapper.readValue(task.getTrainParams(), new TypeReference<Map<String, Object>>() {});
request.put("hyperparameters", params);
if (task.getFeatureMapConfig() == null) {
Object fm = params.get("feature_map");
if (fm instanceof Map) {
request.put("feature_map", fm);
}
}
} catch (Exception e) {
log.error("解析训练参数失败,将作为原始字符串发送: {}", e.getMessage());
request.put("hyperparameters", task.getTrainParams());
@ -203,9 +197,12 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
} else {
request.put("hyperparameters", new HashMap<>());
}
if (task.getFeatureMapConfig() != null && !task.getFeatureMapConfig().isEmpty()) {
request.put("feature_map", task.getFeatureMapConfig());
if (fmMap == null || fmMap.isEmpty()) {
throw new BizException("feature_map_config不能为空");
}
request.put("feature_map_config", fmMap);
// System.out.println("request="+request.toString());
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
@ -323,15 +320,6 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
if (task == null) {
throw new BizException("任务不存在");
}
if ((task.getFeatureMapConfig() == null || task.getFeatureMapConfig().isEmpty())
&& task.getFeatureMapConfigJson() != null
&& !task.getFeatureMapConfigJson().isBlank()) {
try {
Map<String, Object> fm = objectMapper.readValue(task.getFeatureMapConfigJson(), new TypeReference<Map<String, Object>>() {});
task.setFeatureMapConfig(fm);
} catch (Exception ignored) {
}
}
// 由于改为 WebSocket 异步推送这里简化为直接查库返回
return task;
}

View File

@ -1202,7 +1202,7 @@ public class ProjectServiceImpl
.orderByAsc("step"));
Sheet s1 = wb.createSheet("projects");
String[] h1 = {"project_id","code","name","description","topology","created_at","updated_at","modifier"};
String[] h1 = {"project_id","code","name","description","topology","visibility","creator","created_at","updated_at","modifier"};
int r = 0; Row rh1 = s1.createRow(r++); for (int i=0;i<h1.length;i++) rh1.createCell(i).setCellValue(h1[i]);
for (Project p : projects) {
Row row = s1.createRow(r++);
@ -1211,9 +1211,11 @@ public class ProjectServiceImpl
row.createCell(2).setCellValue(p.getName()==null?"":p.getName());
row.createCell(3).setCellValue(p.getDescription()==null?"":p.getDescription());
row.createCell(4).setCellValue(p.getTopology()==null?"":p.getTopology());
row.createCell(5).setCellValue(p.getCreatedAt()==null?"":fmt.format(p.getCreatedAt()));
row.createCell(6).setCellValue(p.getUpdatedAt()==null?"":fmt.format(p.getUpdatedAt()));
row.createCell(7).setCellValue(p.getModifier()==null?"":p.getModifier());
row.createCell(5).setCellValue(p.getVisibility()==null?"":p.getVisibility());
row.createCell(6).setCellValue(p.getCreator()==null?"":p.getCreator());
row.createCell(7).setCellValue(p.getCreatedAt()==null?"":fmt.format(p.getCreatedAt()));
row.createCell(8).setCellValue(p.getUpdatedAt()==null?"":fmt.format(p.getUpdatedAt()));
row.createCell(9).setCellValue(p.getModifier()==null?"":p.getModifier());
}
for (int i=0;i<h1.length;i++) s1.autoSizeColumn(i);