模型训练添加特征映射配置 环节

This commit is contained in:
wanxiaoli 2026-04-09 11:06:53 +08:00
parent 4bad0bc52d
commit 52bb4f69c5
5 changed files with 32 additions and 42 deletions

View File

@ -74,6 +74,7 @@ public class ModelTrainController {
@RequestPart(value = "file", required = false) MultipartFile file) { @RequestPart(value = "file", required = false) MultipartFile file) {
try { try {
ModelTrainTask task = objectMapper.readValue(taskJson, ModelTrainTask.class); ModelTrainTask task = objectMapper.readValue(taskJson, ModelTrainTask.class);
// System.out.println("提交任务参数: " + task.toString());
// 如果上传了文件优先使用文件路径 // 如果上传了文件优先使用文件路径
if (file != null && !file.isEmpty()) { if (file != null && !file.isEmpty()) {
@ -87,7 +88,7 @@ public class ModelTrainController {
} }
String taskId = modelTrainService.submitTask(task); String taskId = modelTrainService.submitTask(task);
System.out.println("提交任务成功任务ID: " + taskId); // System.out.println("提交任务成功任务ID: " + taskId);
return ResponseResult.successData(taskId); return ResponseResult.successData(taskId);
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {
return ResponseResult.error("参数解析失败: " + e.getMessage()); return ResponseResult.error("参数解析失败: " + e.getMessage());

View File

@ -1,13 +1,12 @@
package com.yfd.business.css.domain; package com.yfd.business.css.domain;
import com.baomidou.mybatisplus.annotation.*; import com.baomidou.mybatisplus.annotation.*;
import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonAlias;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data; import lombok.Data;
import java.io.Serializable; import java.io.Serializable;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.util.Map;
@Data @Data
@TableName(value = "model_train_task", autoResultMap = true) @TableName(value = "model_train_task", autoResultMap = true)
@ -28,9 +27,11 @@ public class ModelTrainTask implements Serializable {
private String deviceType; private String deviceType;
@TableField("dataset_path") @TableField("dataset_path")
@JsonAlias("dataset_path")
private String datasetPath; private String datasetPath;
@TableField(value = "train_params") @TableField(value = "train_params")
@JsonAlias("train_params")
private String trainParams; // JSON String private String trainParams; // JSON String
@TableField("status") @TableField("status")
@ -52,12 +53,9 @@ public class ModelTrainTask implements Serializable {
private String errorLog; private String errorLog;
@TableField(value = "feature_map_config") @TableField(value = "feature_map_config")
@JsonIgnore @JsonAlias({"feature_map_config","featureMapConfig"})
private String featureMapConfigJson; @JsonProperty("feature_map_config")
private String featureMapConfig;
@TableField(exist = false)
@JsonProperty("featureMapConfig")
private Map<String, Object> featureMapConfig;
@TableField(value = "created_at") @TableField(value = "created_at")
private LocalDateTime createdAt; private LocalDateTime createdAt;

View File

@ -20,6 +20,7 @@ public class TrainWebSocketService {
* @param data 状态数据 * @param data 状态数据
*/ */
public void sendTrainStatus(String taskId, Map<String, Object> data) { public void sendTrainStatus(String taskId, Map<String, Object> data) {
// System.out.println("sendTrainStatus, taskId="+taskId+", data="+data.toString());
// 1. 细粒度推送供详情页使用 // 1. 细粒度推送供详情页使用
String specificDestination = "/topic/train-status/" + taskId; String specificDestination = "/topic/train-status/" + taskId;
messagingTemplate.convertAndSend(specificDestination, data); messagingTemplate.convertAndSend(specificDestination, data);

View File

@ -146,19 +146,11 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
if (task.getTaskId() == null) { if (task.getTaskId() == null) {
task.setTaskId(UUID.randomUUID().toString()); task.setTaskId(UUID.randomUUID().toString());
} }
if (task.getFeatureMapConfig() != null && !task.getFeatureMapConfig().isEmpty()) { if (task.getFeatureMapConfig() != null && !task.getFeatureMapConfig().isBlank()) {
Map<String, Object> params = new HashMap<>();
if (task.getTrainParams() != null && !task.getTrainParams().isBlank()) {
try {
params.putAll(objectMapper.readValue(task.getTrainParams(), new TypeReference<Map<String, Object>>() {}));
} catch (Exception ignored) {
}
}
params.put("feature_map", task.getFeatureMapConfig());
try { try {
task.setTrainParams(objectMapper.writeValueAsString(params)); objectMapper.readValue(task.getFeatureMapConfig(), new TypeReference<Map<String, Object>>() {});
task.setFeatureMapConfigJson(objectMapper.writeValueAsString(task.getFeatureMapConfig())); } catch (Exception ignored) {
} catch (JsonProcessingException ignored) { throw new BizException("feature_map_config JSON解析失败");
} }
} }
this.save(task); this.save(task);
@ -185,17 +177,19 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
request.put("model_dir", modelPath); request.put("model_dir", modelPath);
Map<String, Object> fmMap = null;
if (task.getFeatureMapConfig() != null && !task.getFeatureMapConfig().isBlank()) {
try {
fmMap = objectMapper.readValue(task.getFeatureMapConfig(), new TypeReference<Map<String, Object>>() {});
} catch (Exception ignored) {
}
}
// 解析 hyperparameters (String -> Map) // 解析 hyperparameters (String -> Map)
if (task.getTrainParams() != null && !task.getTrainParams().isBlank()) { if (task.getTrainParams() != null && !task.getTrainParams().isBlank()) {
try { try {
Map<String, Object> params = objectMapper.readValue(task.getTrainParams(), new TypeReference<Map<String, Object>>() {}); Map<String, Object> params = objectMapper.readValue(task.getTrainParams(), new TypeReference<Map<String, Object>>() {});
request.put("hyperparameters", params); request.put("hyperparameters", params);
if (task.getFeatureMapConfig() == null) {
Object fm = params.get("feature_map");
if (fm instanceof Map) {
request.put("feature_map", fm);
}
}
} catch (Exception e) { } catch (Exception e) {
log.error("解析训练参数失败,将作为原始字符串发送: {}", e.getMessage()); log.error("解析训练参数失败,将作为原始字符串发送: {}", e.getMessage());
request.put("hyperparameters", task.getTrainParams()); request.put("hyperparameters", task.getTrainParams());
@ -203,9 +197,12 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
} else { } else {
request.put("hyperparameters", new HashMap<>()); request.put("hyperparameters", new HashMap<>());
} }
if (task.getFeatureMapConfig() != null && !task.getFeatureMapConfig().isEmpty()) {
request.put("feature_map", task.getFeatureMapConfig()); if (fmMap == null || fmMap.isEmpty()) {
throw new BizException("feature_map_config不能为空");
} }
request.put("feature_map_config", fmMap);
// System.out.println("request="+request.toString());
HttpHeaders headers = new HttpHeaders(); HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON); headers.setContentType(MediaType.APPLICATION_JSON);
@ -323,15 +320,6 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
if (task == null) { if (task == null) {
throw new BizException("任务不存在"); throw new BizException("任务不存在");
} }
if ((task.getFeatureMapConfig() == null || task.getFeatureMapConfig().isEmpty())
&& task.getFeatureMapConfigJson() != null
&& !task.getFeatureMapConfigJson().isBlank()) {
try {
Map<String, Object> fm = objectMapper.readValue(task.getFeatureMapConfigJson(), new TypeReference<Map<String, Object>>() {});
task.setFeatureMapConfig(fm);
} catch (Exception ignored) {
}
}
// 由于改为 WebSocket 异步推送这里简化为直接查库返回 // 由于改为 WebSocket 异步推送这里简化为直接查库返回
return task; return task;
} }

View File

@ -1202,7 +1202,7 @@ public class ProjectServiceImpl
.orderByAsc("step")); .orderByAsc("step"));
Sheet s1 = wb.createSheet("projects"); Sheet s1 = wb.createSheet("projects");
String[] h1 = {"project_id","code","name","description","topology","created_at","updated_at","modifier"}; String[] h1 = {"project_id","code","name","description","topology","visibility","creator","created_at","updated_at","modifier"};
int r = 0; Row rh1 = s1.createRow(r++); for (int i=0;i<h1.length;i++) rh1.createCell(i).setCellValue(h1[i]); int r = 0; Row rh1 = s1.createRow(r++); for (int i=0;i<h1.length;i++) rh1.createCell(i).setCellValue(h1[i]);
for (Project p : projects) { for (Project p : projects) {
Row row = s1.createRow(r++); Row row = s1.createRow(r++);
@ -1211,9 +1211,11 @@ public class ProjectServiceImpl
row.createCell(2).setCellValue(p.getName()==null?"":p.getName()); row.createCell(2).setCellValue(p.getName()==null?"":p.getName());
row.createCell(3).setCellValue(p.getDescription()==null?"":p.getDescription()); row.createCell(3).setCellValue(p.getDescription()==null?"":p.getDescription());
row.createCell(4).setCellValue(p.getTopology()==null?"":p.getTopology()); row.createCell(4).setCellValue(p.getTopology()==null?"":p.getTopology());
row.createCell(5).setCellValue(p.getCreatedAt()==null?"":fmt.format(p.getCreatedAt())); row.createCell(5).setCellValue(p.getVisibility()==null?"":p.getVisibility());
row.createCell(6).setCellValue(p.getUpdatedAt()==null?"":fmt.format(p.getUpdatedAt())); row.createCell(6).setCellValue(p.getCreator()==null?"":p.getCreator());
row.createCell(7).setCellValue(p.getModifier()==null?"":p.getModifier()); row.createCell(7).setCellValue(p.getCreatedAt()==null?"":fmt.format(p.getCreatedAt()));
row.createCell(8).setCellValue(p.getUpdatedAt()==null?"":fmt.format(p.getUpdatedAt()));
row.createCell(9).setCellValue(p.getModifier()==null?"":p.getModifier());
} }
for (int i=0;i<h1.length;i++) s1.autoSizeColumn(i); for (int i=0;i<h1.length;i++) s1.autoSizeColumn(i);