模型训练添加特征映射配置 环节
This commit is contained in:
parent
4bad0bc52d
commit
52bb4f69c5
@ -74,6 +74,7 @@ public class ModelTrainController {
|
|||||||
@RequestPart(value = "file", required = false) MultipartFile file) {
|
@RequestPart(value = "file", required = false) MultipartFile file) {
|
||||||
try {
|
try {
|
||||||
ModelTrainTask task = objectMapper.readValue(taskJson, ModelTrainTask.class);
|
ModelTrainTask task = objectMapper.readValue(taskJson, ModelTrainTask.class);
|
||||||
|
// System.out.println("提交任务参数: " + task.toString());
|
||||||
|
|
||||||
// 如果上传了文件,优先使用文件路径
|
// 如果上传了文件,优先使用文件路径
|
||||||
if (file != null && !file.isEmpty()) {
|
if (file != null && !file.isEmpty()) {
|
||||||
@ -87,7 +88,7 @@ public class ModelTrainController {
|
|||||||
}
|
}
|
||||||
|
|
||||||
String taskId = modelTrainService.submitTask(task);
|
String taskId = modelTrainService.submitTask(task);
|
||||||
System.out.println("提交任务成功,任务ID: " + taskId);
|
// System.out.println("提交任务成功,任务ID: " + taskId);
|
||||||
return ResponseResult.successData(taskId);
|
return ResponseResult.successData(taskId);
|
||||||
} catch (JsonProcessingException e) {
|
} catch (JsonProcessingException e) {
|
||||||
return ResponseResult.error("参数解析失败: " + e.getMessage());
|
return ResponseResult.error("参数解析失败: " + e.getMessage());
|
||||||
|
|||||||
@ -1,13 +1,12 @@
|
|||||||
package com.yfd.business.css.domain;
|
package com.yfd.business.css.domain;
|
||||||
|
|
||||||
import com.baomidou.mybatisplus.annotation.*;
|
import com.baomidou.mybatisplus.annotation.*;
|
||||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
import com.fasterxml.jackson.annotation.JsonAlias;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.time.LocalDateTime;
|
import java.time.LocalDateTime;
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@TableName(value = "model_train_task", autoResultMap = true)
|
@TableName(value = "model_train_task", autoResultMap = true)
|
||||||
@ -28,9 +27,11 @@ public class ModelTrainTask implements Serializable {
|
|||||||
private String deviceType;
|
private String deviceType;
|
||||||
|
|
||||||
@TableField("dataset_path")
|
@TableField("dataset_path")
|
||||||
|
@JsonAlias("dataset_path")
|
||||||
private String datasetPath;
|
private String datasetPath;
|
||||||
|
|
||||||
@TableField(value = "train_params")
|
@TableField(value = "train_params")
|
||||||
|
@JsonAlias("train_params")
|
||||||
private String trainParams; // JSON String
|
private String trainParams; // JSON String
|
||||||
|
|
||||||
@TableField("status")
|
@TableField("status")
|
||||||
@ -52,12 +53,9 @@ public class ModelTrainTask implements Serializable {
|
|||||||
private String errorLog;
|
private String errorLog;
|
||||||
|
|
||||||
@TableField(value = "feature_map_config")
|
@TableField(value = "feature_map_config")
|
||||||
@JsonIgnore
|
@JsonAlias({"feature_map_config","featureMapConfig"})
|
||||||
private String featureMapConfigJson;
|
@JsonProperty("feature_map_config")
|
||||||
|
private String featureMapConfig;
|
||||||
@TableField(exist = false)
|
|
||||||
@JsonProperty("featureMapConfig")
|
|
||||||
private Map<String, Object> featureMapConfig;
|
|
||||||
|
|
||||||
@TableField(value = "created_at")
|
@TableField(value = "created_at")
|
||||||
private LocalDateTime createdAt;
|
private LocalDateTime createdAt;
|
||||||
|
|||||||
@ -20,6 +20,7 @@ public class TrainWebSocketService {
|
|||||||
* @param data 状态数据
|
* @param data 状态数据
|
||||||
*/
|
*/
|
||||||
public void sendTrainStatus(String taskId, Map<String, Object> data) {
|
public void sendTrainStatus(String taskId, Map<String, Object> data) {
|
||||||
|
// System.out.println("sendTrainStatus, taskId="+taskId+", data="+data.toString());
|
||||||
// 1. 细粒度推送(供详情页使用)
|
// 1. 细粒度推送(供详情页使用)
|
||||||
String specificDestination = "/topic/train-status/" + taskId;
|
String specificDestination = "/topic/train-status/" + taskId;
|
||||||
messagingTemplate.convertAndSend(specificDestination, data);
|
messagingTemplate.convertAndSend(specificDestination, data);
|
||||||
|
|||||||
@ -146,19 +146,11 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
|||||||
if (task.getTaskId() == null) {
|
if (task.getTaskId() == null) {
|
||||||
task.setTaskId(UUID.randomUUID().toString());
|
task.setTaskId(UUID.randomUUID().toString());
|
||||||
}
|
}
|
||||||
if (task.getFeatureMapConfig() != null && !task.getFeatureMapConfig().isEmpty()) {
|
if (task.getFeatureMapConfig() != null && !task.getFeatureMapConfig().isBlank()) {
|
||||||
Map<String, Object> params = new HashMap<>();
|
|
||||||
if (task.getTrainParams() != null && !task.getTrainParams().isBlank()) {
|
|
||||||
try {
|
|
||||||
params.putAll(objectMapper.readValue(task.getTrainParams(), new TypeReference<Map<String, Object>>() {}));
|
|
||||||
} catch (Exception ignored) {
|
|
||||||
}
|
|
||||||
}
|
|
||||||
params.put("feature_map", task.getFeatureMapConfig());
|
|
||||||
try {
|
try {
|
||||||
task.setTrainParams(objectMapper.writeValueAsString(params));
|
objectMapper.readValue(task.getFeatureMapConfig(), new TypeReference<Map<String, Object>>() {});
|
||||||
task.setFeatureMapConfigJson(objectMapper.writeValueAsString(task.getFeatureMapConfig()));
|
} catch (Exception ignored) {
|
||||||
} catch (JsonProcessingException ignored) {
|
throw new BizException("feature_map_config JSON解析失败");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
this.save(task);
|
this.save(task);
|
||||||
@ -185,17 +177,19 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
|||||||
request.put("model_dir", modelPath);
|
request.put("model_dir", modelPath);
|
||||||
|
|
||||||
|
|
||||||
|
Map<String, Object> fmMap = null;
|
||||||
|
if (task.getFeatureMapConfig() != null && !task.getFeatureMapConfig().isBlank()) {
|
||||||
|
try {
|
||||||
|
fmMap = objectMapper.readValue(task.getFeatureMapConfig(), new TypeReference<Map<String, Object>>() {});
|
||||||
|
} catch (Exception ignored) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 解析 hyperparameters (String -> Map)
|
// 解析 hyperparameters (String -> Map)
|
||||||
if (task.getTrainParams() != null && !task.getTrainParams().isBlank()) {
|
if (task.getTrainParams() != null && !task.getTrainParams().isBlank()) {
|
||||||
try {
|
try {
|
||||||
Map<String, Object> params = objectMapper.readValue(task.getTrainParams(), new TypeReference<Map<String, Object>>() {});
|
Map<String, Object> params = objectMapper.readValue(task.getTrainParams(), new TypeReference<Map<String, Object>>() {});
|
||||||
request.put("hyperparameters", params);
|
request.put("hyperparameters", params);
|
||||||
if (task.getFeatureMapConfig() == null) {
|
|
||||||
Object fm = params.get("feature_map");
|
|
||||||
if (fm instanceof Map) {
|
|
||||||
request.put("feature_map", fm);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("解析训练参数失败,将作为原始字符串发送: {}", e.getMessage());
|
log.error("解析训练参数失败,将作为原始字符串发送: {}", e.getMessage());
|
||||||
request.put("hyperparameters", task.getTrainParams());
|
request.put("hyperparameters", task.getTrainParams());
|
||||||
@ -203,9 +197,12 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
|||||||
} else {
|
} else {
|
||||||
request.put("hyperparameters", new HashMap<>());
|
request.put("hyperparameters", new HashMap<>());
|
||||||
}
|
}
|
||||||
if (task.getFeatureMapConfig() != null && !task.getFeatureMapConfig().isEmpty()) {
|
|
||||||
request.put("feature_map", task.getFeatureMapConfig());
|
if (fmMap == null || fmMap.isEmpty()) {
|
||||||
|
throw new BizException("feature_map_config不能为空");
|
||||||
}
|
}
|
||||||
|
request.put("feature_map_config", fmMap);
|
||||||
|
// System.out.println("request="+request.toString());
|
||||||
|
|
||||||
HttpHeaders headers = new HttpHeaders();
|
HttpHeaders headers = new HttpHeaders();
|
||||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||||
@ -323,15 +320,6 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
|||||||
if (task == null) {
|
if (task == null) {
|
||||||
throw new BizException("任务不存在");
|
throw new BizException("任务不存在");
|
||||||
}
|
}
|
||||||
if ((task.getFeatureMapConfig() == null || task.getFeatureMapConfig().isEmpty())
|
|
||||||
&& task.getFeatureMapConfigJson() != null
|
|
||||||
&& !task.getFeatureMapConfigJson().isBlank()) {
|
|
||||||
try {
|
|
||||||
Map<String, Object> fm = objectMapper.readValue(task.getFeatureMapConfigJson(), new TypeReference<Map<String, Object>>() {});
|
|
||||||
task.setFeatureMapConfig(fm);
|
|
||||||
} catch (Exception ignored) {
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 由于改为 WebSocket 异步推送,这里简化为直接查库返回
|
// 由于改为 WebSocket 异步推送,这里简化为直接查库返回
|
||||||
return task;
|
return task;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1202,7 +1202,7 @@ public class ProjectServiceImpl
|
|||||||
.orderByAsc("step"));
|
.orderByAsc("step"));
|
||||||
|
|
||||||
Sheet s1 = wb.createSheet("projects");
|
Sheet s1 = wb.createSheet("projects");
|
||||||
String[] h1 = {"project_id","code","name","description","topology","created_at","updated_at","modifier"};
|
String[] h1 = {"project_id","code","name","description","topology","visibility","creator","created_at","updated_at","modifier"};
|
||||||
int r = 0; Row rh1 = s1.createRow(r++); for (int i=0;i<h1.length;i++) rh1.createCell(i).setCellValue(h1[i]);
|
int r = 0; Row rh1 = s1.createRow(r++); for (int i=0;i<h1.length;i++) rh1.createCell(i).setCellValue(h1[i]);
|
||||||
for (Project p : projects) {
|
for (Project p : projects) {
|
||||||
Row row = s1.createRow(r++);
|
Row row = s1.createRow(r++);
|
||||||
@ -1211,9 +1211,11 @@ public class ProjectServiceImpl
|
|||||||
row.createCell(2).setCellValue(p.getName()==null?"":p.getName());
|
row.createCell(2).setCellValue(p.getName()==null?"":p.getName());
|
||||||
row.createCell(3).setCellValue(p.getDescription()==null?"":p.getDescription());
|
row.createCell(3).setCellValue(p.getDescription()==null?"":p.getDescription());
|
||||||
row.createCell(4).setCellValue(p.getTopology()==null?"":p.getTopology());
|
row.createCell(4).setCellValue(p.getTopology()==null?"":p.getTopology());
|
||||||
row.createCell(5).setCellValue(p.getCreatedAt()==null?"":fmt.format(p.getCreatedAt()));
|
row.createCell(5).setCellValue(p.getVisibility()==null?"":p.getVisibility());
|
||||||
row.createCell(6).setCellValue(p.getUpdatedAt()==null?"":fmt.format(p.getUpdatedAt()));
|
row.createCell(6).setCellValue(p.getCreator()==null?"":p.getCreator());
|
||||||
row.createCell(7).setCellValue(p.getModifier()==null?"":p.getModifier());
|
row.createCell(7).setCellValue(p.getCreatedAt()==null?"":fmt.format(p.getCreatedAt()));
|
||||||
|
row.createCell(8).setCellValue(p.getUpdatedAt()==null?"":fmt.format(p.getUpdatedAt()));
|
||||||
|
row.createCell(9).setCellValue(p.getModifier()==null?"":p.getModifier());
|
||||||
}
|
}
|
||||||
for (int i=0;i<h1.length;i++) s1.autoSizeColumn(i);
|
for (int i=0;i<h1.length;i++) s1.autoSizeColumn(i);
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user