模型训练接口
This commit is contained in:
parent
aff2d7feea
commit
f7d113b71c
@ -1,6 +1,9 @@
|
||||
package com.yfd.business.css.controller;
|
||||
|
||||
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.yfd.business.css.domain.ModelTrainTask;
|
||||
import com.yfd.business.css.service.ModelTrainService;
|
||||
import com.yfd.platform.config.ResponseResult;
|
||||
@ -17,6 +20,9 @@ public class ModelTrainController {
|
||||
@Autowired
|
||||
private ModelTrainService modelTrainService;
|
||||
|
||||
@Autowired
|
||||
private ObjectMapper objectMapper;
|
||||
|
||||
/**
|
||||
* 上传数据集
|
||||
*/
|
||||
@ -27,22 +33,65 @@ public class ModelTrainController {
|
||||
}
|
||||
|
||||
/**
|
||||
* 提交训练任务
|
||||
* 提交训练任务 (支持文件上传和 JSON 参数)
|
||||
*/
|
||||
@PostMapping("/submit")
|
||||
public ResponseResult submit(@RequestBody ModelTrainTask task) {
|
||||
public ResponseResult submit(@RequestPart("task") String taskJson,
|
||||
@RequestPart(value = "file", required = false) MultipartFile file) {
|
||||
try {
|
||||
ModelTrainTask task = objectMapper.readValue(taskJson, ModelTrainTask.class);
|
||||
|
||||
// 如果上传了文件,优先使用文件路径
|
||||
if (file != null && !file.isEmpty()) {
|
||||
String path = modelTrainService.uploadDataset(file);
|
||||
task.setDatasetPath(path);
|
||||
}
|
||||
|
||||
// 校验 datasetPath
|
||||
if (task.getDatasetPath() == null || task.getDatasetPath().isBlank()) {
|
||||
return ResponseResult.error("数据集路径不能为空,请上传文件或指定路径");
|
||||
}
|
||||
|
||||
String taskId = modelTrainService.submitTask(task);
|
||||
return ResponseResult.successData(taskId);
|
||||
} catch (JsonProcessingException e) {
|
||||
return ResponseResult.error("参数解析失败: " + e.getMessage());
|
||||
} catch (Exception e) {
|
||||
return ResponseResult.error("提交任务失败: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 查询任务列表
|
||||
* 查询任务列表 (支持条件查询)
|
||||
*/
|
||||
@GetMapping("/list")
|
||||
public ResponseResult list(@RequestParam(defaultValue = "1") Integer current,
|
||||
@RequestParam(defaultValue = "10") Integer size) {
|
||||
Page<ModelTrainTask> page = modelTrainService.page(new Page<>(current, size));
|
||||
return ResponseResult.successData(page);
|
||||
@RequestParam(defaultValue = "10") Integer size,
|
||||
@RequestParam(required = false) String algoType,
|
||||
@RequestParam(required = false) String deviceType,
|
||||
@RequestParam(required = false) String status,
|
||||
@RequestParam(required = false) String name) {
|
||||
|
||||
Page<ModelTrainTask> page = new Page<>(current, size);
|
||||
QueryWrapper<ModelTrainTask> query = new QueryWrapper<>();
|
||||
|
||||
if (algoType != null && !algoType.isBlank()) {
|
||||
query.eq("algorithm_type", algoType);
|
||||
}
|
||||
if (deviceType != null && !deviceType.isBlank()) {
|
||||
query.eq("device_type", deviceType);
|
||||
}
|
||||
if (status != null && !status.isBlank()) {
|
||||
query.eq("status", status);
|
||||
}
|
||||
if (name != null && !name.isBlank()) {
|
||||
query.like("task_name", name);
|
||||
}
|
||||
|
||||
query.orderByDesc("created_at");
|
||||
|
||||
Page<ModelTrainTask> result = modelTrainService.page(page, query);
|
||||
return ResponseResult.successData(result);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -1,12 +1,10 @@
|
||||
package com.yfd.business.css.domain;
|
||||
|
||||
import com.baomidou.mybatisplus.annotation.*;
|
||||
import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler;
|
||||
import lombok.Data;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
@TableName(value = "model_train_task", autoResultMap = true)
|
||||
@ -29,20 +27,20 @@ public class ModelTrainTask implements Serializable {
|
||||
@TableField("dataset_path")
|
||||
private String datasetPath;
|
||||
|
||||
@TableField(value = "train_params", typeHandler = JacksonTypeHandler.class)
|
||||
private Map<String, Object> trainParams; // 使用 Map 存储 JSON
|
||||
@TableField(value = "train_params")
|
||||
private String trainParams; // JSON String
|
||||
|
||||
@TableField("status")
|
||||
private String status; // PENDING, TRAINING, SUCCESS, FAILED
|
||||
|
||||
@TableField(value = "metrics", typeHandler = JacksonTypeHandler.class)
|
||||
private Map<String, Object> metrics;
|
||||
@TableField(value = "metrics")
|
||||
private String metrics; // JSON String
|
||||
|
||||
@TableField("model_output_path")
|
||||
private String modelOutputPath;
|
||||
|
||||
@TableField(value = "feature_map_snapshot", typeHandler = JacksonTypeHandler.class)
|
||||
private Map<String, Object> featureMapSnapshot;
|
||||
@TableField(value = "feature_map_snapshot")
|
||||
private String featureMapSnapshot; // JSON String
|
||||
|
||||
@TableField("metrics_image_path")
|
||||
private String metricsImagePath;
|
||||
|
||||
@ -3,6 +3,7 @@ package com.yfd.business.css.service.impl;
|
||||
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.yfd.business.css.common.exception.BizException;
|
||||
import com.yfd.business.css.domain.AlgorithmModel;
|
||||
@ -104,19 +105,28 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
||||
request.put("algorithm_type", task.getAlgorithmType());
|
||||
request.put("device_type", task.getDeviceType());
|
||||
request.put("dataset_path", task.getDatasetPath());
|
||||
|
||||
// 解析 hyperparameters (String -> Map)
|
||||
if (task.getTrainParams() != null && !task.getTrainParams().isBlank()) {
|
||||
try {
|
||||
Map<String, Object> params = objectMapper.readValue(task.getTrainParams(), new TypeReference<Map<String, Object>>() {});
|
||||
request.put("hyperparameters", params);
|
||||
} catch (Exception e) {
|
||||
System.err.println("解析训练参数失败,将作为原始字符串发送: " + e.getMessage());
|
||||
request.put("hyperparameters", task.getTrainParams());
|
||||
}
|
||||
} else {
|
||||
request.put("hyperparameters", new HashMap<>());
|
||||
}
|
||||
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
HttpEntity<Map<String, Object>> entity = new HttpEntity<>(request, headers);
|
||||
|
||||
String url = pythonApiUrl + "/v1/train";
|
||||
// 这里假设 Python 服务会立即返回,启动后台线程
|
||||
// 如果 Python 服务是阻塞的,这里的 Async 会起作用
|
||||
ResponseEntity<Map> response = restTemplate.postForEntity(url, entity, Map.class);
|
||||
|
||||
if (response.getStatusCode().is2xxSuccessful()) {
|
||||
// 调用成功,等待后续轮询状态
|
||||
System.out.println("训练任务提交成功: " + task.getTaskId());
|
||||
} else {
|
||||
task.setStatus("FAILED");
|
||||
@ -154,15 +164,21 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
||||
|
||||
if ("SUCCESS".equals(status)) {
|
||||
task.setModelOutputPath((String) body.get("model_path"));
|
||||
task.setMetrics((Map<String, Object>) body.get("metrics"));
|
||||
task.setFeatureMapSnapshot((Map<String, Object>) body.get("feature_map"));
|
||||
|
||||
// Map -> JSON String
|
||||
if (body.get("metrics") != null) {
|
||||
task.setMetrics(objectMapper.writeValueAsString(body.get("metrics")));
|
||||
}
|
||||
if (body.get("feature_map") != null) {
|
||||
task.setFeatureMapSnapshot(objectMapper.writeValueAsString(body.get("feature_map")));
|
||||
}
|
||||
|
||||
task.setMetricsImagePath((String) body.get("metrics_image"));
|
||||
} else if ("FAILED".equals(status)) {
|
||||
task.setErrorLog((String) body.get("error"));
|
||||
} else if ("TRAINING".equals(status)) {
|
||||
// 可以更新进度或其他中间指标
|
||||
if (body.containsKey("metrics")) {
|
||||
task.setMetrics((Map<String, Object>) body.get("metrics"));
|
||||
if (body.get("metrics") != null) {
|
||||
task.setMetrics(objectMapper.writeValueAsString(body.get("metrics")));
|
||||
}
|
||||
}
|
||||
|
||||
@ -170,7 +186,6 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
// 查询失败,暂不更新状态,或者记录日志
|
||||
System.err.println("同步任务状态失败: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
@ -204,24 +219,14 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
||||
model.setAlgorithmType(task.getAlgorithmType());
|
||||
model.setDeviceType(task.getDeviceType());
|
||||
model.setVersionTag(versionTag);
|
||||
model.setModelPath(task.getModelOutputPath()); // 这里简化处理,直接引用临时路径,实际生产建议移动文件到正式目录
|
||||
model.setModelPath(task.getModelOutputPath());
|
||||
model.setMetricsImagePath(task.getMetricsImagePath());
|
||||
model.setTrainedAt(LocalDateTime.now());
|
||||
model.setIsCurrent(0); // 默认不激活
|
||||
|
||||
try {
|
||||
if (task.getFeatureMapSnapshot() != null) {
|
||||
model.setFeatureMapSnapshot(objectMapper.writeValueAsString(task.getFeatureMapSnapshot()));
|
||||
} else {
|
||||
model.setFeatureMapSnapshot("{}");
|
||||
}
|
||||
|
||||
if (task.getMetrics() != null) {
|
||||
model.setMetrics(objectMapper.writeValueAsString(task.getMetrics()));
|
||||
}
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new BizException("JSON 序列化失败");
|
||||
}
|
||||
// 直接赋值 String,不再需要序列化
|
||||
model.setFeatureMapSnapshot(task.getFeatureMapSnapshot() != null ? task.getFeatureMapSnapshot() : "{}");
|
||||
model.setMetrics(task.getMetrics());
|
||||
|
||||
return algorithmModelService.save(model);
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user