模型训练接口

This commit is contained in:
wanxiaoli 2026-03-11 14:52:43 +08:00
parent aff2d7feea
commit f7d113b71c
3 changed files with 92 additions and 40 deletions

View File

@ -1,6 +1,9 @@
package com.yfd.business.css.controller; package com.yfd.business.css.controller;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.yfd.business.css.domain.ModelTrainTask; import com.yfd.business.css.domain.ModelTrainTask;
import com.yfd.business.css.service.ModelTrainService; import com.yfd.business.css.service.ModelTrainService;
import com.yfd.platform.config.ResponseResult; import com.yfd.platform.config.ResponseResult;
@ -16,6 +19,9 @@ public class ModelTrainController {
@Autowired @Autowired
private ModelTrainService modelTrainService; private ModelTrainService modelTrainService;
@Autowired
private ObjectMapper objectMapper;
/** /**
* 上传数据集 * 上传数据集
@ -27,22 +33,65 @@ public class ModelTrainController {
} }
/** /**
* 提交训练任务 * 提交训练任务 (支持文件上传和 JSON 参数)
*/ */
@PostMapping("/submit") @PostMapping("/submit")
public ResponseResult submit(@RequestBody ModelTrainTask task) { public ResponseResult submit(@RequestPart("task") String taskJson,
String taskId = modelTrainService.submitTask(task); @RequestPart(value = "file", required = false) MultipartFile file) {
return ResponseResult.successData(taskId); try {
ModelTrainTask task = objectMapper.readValue(taskJson, ModelTrainTask.class);
// 如果上传了文件优先使用文件路径
if (file != null && !file.isEmpty()) {
String path = modelTrainService.uploadDataset(file);
task.setDatasetPath(path);
}
// 校验 datasetPath
if (task.getDatasetPath() == null || task.getDatasetPath().isBlank()) {
return ResponseResult.error("数据集路径不能为空,请上传文件或指定路径");
}
String taskId = modelTrainService.submitTask(task);
return ResponseResult.successData(taskId);
} catch (JsonProcessingException e) {
return ResponseResult.error("参数解析失败: " + e.getMessage());
} catch (Exception e) {
return ResponseResult.error("提交任务失败: " + e.getMessage());
}
} }
/** /**
* 查询任务列表 * 查询任务列表 (支持条件查询)
*/ */
@GetMapping("/list") @GetMapping("/list")
public ResponseResult list(@RequestParam(defaultValue = "1") Integer current, public ResponseResult list(@RequestParam(defaultValue = "1") Integer current,
@RequestParam(defaultValue = "10") Integer size) { @RequestParam(defaultValue = "10") Integer size,
Page<ModelTrainTask> page = modelTrainService.page(new Page<>(current, size)); @RequestParam(required = false) String algoType,
return ResponseResult.successData(page); @RequestParam(required = false) String deviceType,
@RequestParam(required = false) String status,
@RequestParam(required = false) String name) {
Page<ModelTrainTask> page = new Page<>(current, size);
QueryWrapper<ModelTrainTask> query = new QueryWrapper<>();
if (algoType != null && !algoType.isBlank()) {
query.eq("algorithm_type", algoType);
}
if (deviceType != null && !deviceType.isBlank()) {
query.eq("device_type", deviceType);
}
if (status != null && !status.isBlank()) {
query.eq("status", status);
}
if (name != null && !name.isBlank()) {
query.like("task_name", name);
}
query.orderByDesc("created_at");
Page<ModelTrainTask> result = modelTrainService.page(page, query);
return ResponseResult.successData(result);
} }
/** /**

View File

@ -1,12 +1,10 @@
package com.yfd.business.css.domain; package com.yfd.business.css.domain;
import com.baomidou.mybatisplus.annotation.*; import com.baomidou.mybatisplus.annotation.*;
import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler;
import lombok.Data; import lombok.Data;
import java.io.Serializable; import java.io.Serializable;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.util.Map;
@Data @Data
@TableName(value = "model_train_task", autoResultMap = true) @TableName(value = "model_train_task", autoResultMap = true)
@ -29,20 +27,20 @@ public class ModelTrainTask implements Serializable {
@TableField("dataset_path") @TableField("dataset_path")
private String datasetPath; private String datasetPath;
@TableField(value = "train_params", typeHandler = JacksonTypeHandler.class) @TableField(value = "train_params")
private Map<String, Object> trainParams; // 使用 Map 存储 JSON private String trainParams; // JSON String
@TableField("status") @TableField("status")
private String status; // PENDING, TRAINING, SUCCESS, FAILED private String status; // PENDING, TRAINING, SUCCESS, FAILED
@TableField(value = "metrics", typeHandler = JacksonTypeHandler.class) @TableField(value = "metrics")
private Map<String, Object> metrics; private String metrics; // JSON String
@TableField("model_output_path") @TableField("model_output_path")
private String modelOutputPath; private String modelOutputPath;
@TableField(value = "feature_map_snapshot", typeHandler = JacksonTypeHandler.class) @TableField(value = "feature_map_snapshot")
private Map<String, Object> featureMapSnapshot; private String featureMapSnapshot; // JSON String
@TableField("metrics_image_path") @TableField("metrics_image_path")
private String metricsImagePath; private String metricsImagePath;

View File

@ -3,6 +3,7 @@ package com.yfd.business.css.service.impl;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.yfd.business.css.common.exception.BizException; import com.yfd.business.css.common.exception.BizException;
import com.yfd.business.css.domain.AlgorithmModel; import com.yfd.business.css.domain.AlgorithmModel;
@ -104,19 +105,28 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
request.put("algorithm_type", task.getAlgorithmType()); request.put("algorithm_type", task.getAlgorithmType());
request.put("device_type", task.getDeviceType()); request.put("device_type", task.getDeviceType());
request.put("dataset_path", task.getDatasetPath()); request.put("dataset_path", task.getDatasetPath());
request.put("hyperparameters", task.getTrainParams());
// 解析 hyperparameters (String -> Map)
if (task.getTrainParams() != null && !task.getTrainParams().isBlank()) {
try {
Map<String, Object> params = objectMapper.readValue(task.getTrainParams(), new TypeReference<Map<String, Object>>() {});
request.put("hyperparameters", params);
} catch (Exception e) {
System.err.println("解析训练参数失败,将作为原始字符串发送: " + e.getMessage());
request.put("hyperparameters", task.getTrainParams());
}
} else {
request.put("hyperparameters", new HashMap<>());
}
HttpHeaders headers = new HttpHeaders(); HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON); headers.setContentType(MediaType.APPLICATION_JSON);
HttpEntity<Map<String, Object>> entity = new HttpEntity<>(request, headers); HttpEntity<Map<String, Object>> entity = new HttpEntity<>(request, headers);
String url = pythonApiUrl + "/v1/train"; String url = pythonApiUrl + "/v1/train";
// 这里假设 Python 服务会立即返回启动后台线程
// 如果 Python 服务是阻塞的这里的 Async 会起作用
ResponseEntity<Map> response = restTemplate.postForEntity(url, entity, Map.class); ResponseEntity<Map> response = restTemplate.postForEntity(url, entity, Map.class);
if (response.getStatusCode().is2xxSuccessful()) { if (response.getStatusCode().is2xxSuccessful()) {
// 调用成功等待后续轮询状态
System.out.println("训练任务提交成功: " + task.getTaskId()); System.out.println("训练任务提交成功: " + task.getTaskId());
} else { } else {
task.setStatus("FAILED"); task.setStatus("FAILED");
@ -154,15 +164,21 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
if ("SUCCESS".equals(status)) { if ("SUCCESS".equals(status)) {
task.setModelOutputPath((String) body.get("model_path")); task.setModelOutputPath((String) body.get("model_path"));
task.setMetrics((Map<String, Object>) body.get("metrics"));
task.setFeatureMapSnapshot((Map<String, Object>) body.get("feature_map")); // Map -> JSON String
if (body.get("metrics") != null) {
task.setMetrics(objectMapper.writeValueAsString(body.get("metrics")));
}
if (body.get("feature_map") != null) {
task.setFeatureMapSnapshot(objectMapper.writeValueAsString(body.get("feature_map")));
}
task.setMetricsImagePath((String) body.get("metrics_image")); task.setMetricsImagePath((String) body.get("metrics_image"));
} else if ("FAILED".equals(status)) { } else if ("FAILED".equals(status)) {
task.setErrorLog((String) body.get("error")); task.setErrorLog((String) body.get("error"));
} else if ("TRAINING".equals(status)) { } else if ("TRAINING".equals(status)) {
// 可以更新进度或其他中间指标 if (body.get("metrics") != null) {
if (body.containsKey("metrics")) { task.setMetrics(objectMapper.writeValueAsString(body.get("metrics")));
task.setMetrics((Map<String, Object>) body.get("metrics"));
} }
} }
@ -170,7 +186,6 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
} }
} }
} catch (Exception e) { } catch (Exception e) {
// 查询失败暂不更新状态或者记录日志
System.err.println("同步任务状态失败: " + e.getMessage()); System.err.println("同步任务状态失败: " + e.getMessage());
} }
} }
@ -204,24 +219,14 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
model.setAlgorithmType(task.getAlgorithmType()); model.setAlgorithmType(task.getAlgorithmType());
model.setDeviceType(task.getDeviceType()); model.setDeviceType(task.getDeviceType());
model.setVersionTag(versionTag); model.setVersionTag(versionTag);
model.setModelPath(task.getModelOutputPath()); // 这里简化处理直接引用临时路径实际生产建议移动文件到正式目录 model.setModelPath(task.getModelOutputPath());
model.setMetricsImagePath(task.getMetricsImagePath()); model.setMetricsImagePath(task.getMetricsImagePath());
model.setTrainedAt(LocalDateTime.now()); model.setTrainedAt(LocalDateTime.now());
model.setIsCurrent(0); // 默认不激活 model.setIsCurrent(0); // 默认不激活
try { // 直接赋值 String不再需要序列化
if (task.getFeatureMapSnapshot() != null) { model.setFeatureMapSnapshot(task.getFeatureMapSnapshot() != null ? task.getFeatureMapSnapshot() : "{}");
model.setFeatureMapSnapshot(objectMapper.writeValueAsString(task.getFeatureMapSnapshot())); model.setMetrics(task.getMetrics());
} else {
model.setFeatureMapSnapshot("{}");
}
if (task.getMetrics() != null) {
model.setMetrics(objectMapper.writeValueAsString(task.getMetrics()));
}
} catch (JsonProcessingException e) {
throw new BizException("JSON 序列化失败");
}
return algorithmModelService.save(model); return algorithmModelService.save(model);
} }