模型训练接口
This commit is contained in:
parent
aff2d7feea
commit
f7d113b71c
@ -1,6 +1,9 @@
|
|||||||
package com.yfd.business.css.controller;
|
package com.yfd.business.css.controller;
|
||||||
|
|
||||||
|
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||||
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
import com.yfd.business.css.domain.ModelTrainTask;
|
import com.yfd.business.css.domain.ModelTrainTask;
|
||||||
import com.yfd.business.css.service.ModelTrainService;
|
import com.yfd.business.css.service.ModelTrainService;
|
||||||
import com.yfd.platform.config.ResponseResult;
|
import com.yfd.platform.config.ResponseResult;
|
||||||
@ -16,6 +19,9 @@ public class ModelTrainController {
|
|||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private ModelTrainService modelTrainService;
|
private ModelTrainService modelTrainService;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private ObjectMapper objectMapper;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 上传数据集
|
* 上传数据集
|
||||||
@ -27,22 +33,65 @@ public class ModelTrainController {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 提交训练任务
|
* 提交训练任务 (支持文件上传和 JSON 参数)
|
||||||
*/
|
*/
|
||||||
@PostMapping("/submit")
|
@PostMapping("/submit")
|
||||||
public ResponseResult submit(@RequestBody ModelTrainTask task) {
|
public ResponseResult submit(@RequestPart("task") String taskJson,
|
||||||
String taskId = modelTrainService.submitTask(task);
|
@RequestPart(value = "file", required = false) MultipartFile file) {
|
||||||
return ResponseResult.successData(taskId);
|
try {
|
||||||
|
ModelTrainTask task = objectMapper.readValue(taskJson, ModelTrainTask.class);
|
||||||
|
|
||||||
|
// 如果上传了文件,优先使用文件路径
|
||||||
|
if (file != null && !file.isEmpty()) {
|
||||||
|
String path = modelTrainService.uploadDataset(file);
|
||||||
|
task.setDatasetPath(path);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 校验 datasetPath
|
||||||
|
if (task.getDatasetPath() == null || task.getDatasetPath().isBlank()) {
|
||||||
|
return ResponseResult.error("数据集路径不能为空,请上传文件或指定路径");
|
||||||
|
}
|
||||||
|
|
||||||
|
String taskId = modelTrainService.submitTask(task);
|
||||||
|
return ResponseResult.successData(taskId);
|
||||||
|
} catch (JsonProcessingException e) {
|
||||||
|
return ResponseResult.error("参数解析失败: " + e.getMessage());
|
||||||
|
} catch (Exception e) {
|
||||||
|
return ResponseResult.error("提交任务失败: " + e.getMessage());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 查询任务列表
|
* 查询任务列表 (支持条件查询)
|
||||||
*/
|
*/
|
||||||
@GetMapping("/list")
|
@GetMapping("/list")
|
||||||
public ResponseResult list(@RequestParam(defaultValue = "1") Integer current,
|
public ResponseResult list(@RequestParam(defaultValue = "1") Integer current,
|
||||||
@RequestParam(defaultValue = "10") Integer size) {
|
@RequestParam(defaultValue = "10") Integer size,
|
||||||
Page<ModelTrainTask> page = modelTrainService.page(new Page<>(current, size));
|
@RequestParam(required = false) String algoType,
|
||||||
return ResponseResult.successData(page);
|
@RequestParam(required = false) String deviceType,
|
||||||
|
@RequestParam(required = false) String status,
|
||||||
|
@RequestParam(required = false) String name) {
|
||||||
|
|
||||||
|
Page<ModelTrainTask> page = new Page<>(current, size);
|
||||||
|
QueryWrapper<ModelTrainTask> query = new QueryWrapper<>();
|
||||||
|
|
||||||
|
if (algoType != null && !algoType.isBlank()) {
|
||||||
|
query.eq("algorithm_type", algoType);
|
||||||
|
}
|
||||||
|
if (deviceType != null && !deviceType.isBlank()) {
|
||||||
|
query.eq("device_type", deviceType);
|
||||||
|
}
|
||||||
|
if (status != null && !status.isBlank()) {
|
||||||
|
query.eq("status", status);
|
||||||
|
}
|
||||||
|
if (name != null && !name.isBlank()) {
|
||||||
|
query.like("task_name", name);
|
||||||
|
}
|
||||||
|
|
||||||
|
query.orderByDesc("created_at");
|
||||||
|
|
||||||
|
Page<ModelTrainTask> result = modelTrainService.page(page, query);
|
||||||
|
return ResponseResult.successData(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -1,12 +1,10 @@
|
|||||||
package com.yfd.business.css.domain;
|
package com.yfd.business.css.domain;
|
||||||
|
|
||||||
import com.baomidou.mybatisplus.annotation.*;
|
import com.baomidou.mybatisplus.annotation.*;
|
||||||
import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler;
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.time.LocalDateTime;
|
import java.time.LocalDateTime;
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@TableName(value = "model_train_task", autoResultMap = true)
|
@TableName(value = "model_train_task", autoResultMap = true)
|
||||||
@ -29,20 +27,20 @@ public class ModelTrainTask implements Serializable {
|
|||||||
@TableField("dataset_path")
|
@TableField("dataset_path")
|
||||||
private String datasetPath;
|
private String datasetPath;
|
||||||
|
|
||||||
@TableField(value = "train_params", typeHandler = JacksonTypeHandler.class)
|
@TableField(value = "train_params")
|
||||||
private Map<String, Object> trainParams; // 使用 Map 存储 JSON
|
private String trainParams; // JSON String
|
||||||
|
|
||||||
@TableField("status")
|
@TableField("status")
|
||||||
private String status; // PENDING, TRAINING, SUCCESS, FAILED
|
private String status; // PENDING, TRAINING, SUCCESS, FAILED
|
||||||
|
|
||||||
@TableField(value = "metrics", typeHandler = JacksonTypeHandler.class)
|
@TableField(value = "metrics")
|
||||||
private Map<String, Object> metrics;
|
private String metrics; // JSON String
|
||||||
|
|
||||||
@TableField("model_output_path")
|
@TableField("model_output_path")
|
||||||
private String modelOutputPath;
|
private String modelOutputPath;
|
||||||
|
|
||||||
@TableField(value = "feature_map_snapshot", typeHandler = JacksonTypeHandler.class)
|
@TableField(value = "feature_map_snapshot")
|
||||||
private Map<String, Object> featureMapSnapshot;
|
private String featureMapSnapshot; // JSON String
|
||||||
|
|
||||||
@TableField("metrics_image_path")
|
@TableField("metrics_image_path")
|
||||||
private String metricsImagePath;
|
private String metricsImagePath;
|
||||||
|
|||||||
@ -3,6 +3,7 @@ package com.yfd.business.css.service.impl;
|
|||||||
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||||
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
||||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import com.fasterxml.jackson.core.type.TypeReference;
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
import com.yfd.business.css.common.exception.BizException;
|
import com.yfd.business.css.common.exception.BizException;
|
||||||
import com.yfd.business.css.domain.AlgorithmModel;
|
import com.yfd.business.css.domain.AlgorithmModel;
|
||||||
@ -104,19 +105,28 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
|||||||
request.put("algorithm_type", task.getAlgorithmType());
|
request.put("algorithm_type", task.getAlgorithmType());
|
||||||
request.put("device_type", task.getDeviceType());
|
request.put("device_type", task.getDeviceType());
|
||||||
request.put("dataset_path", task.getDatasetPath());
|
request.put("dataset_path", task.getDatasetPath());
|
||||||
request.put("hyperparameters", task.getTrainParams());
|
|
||||||
|
// 解析 hyperparameters (String -> Map)
|
||||||
|
if (task.getTrainParams() != null && !task.getTrainParams().isBlank()) {
|
||||||
|
try {
|
||||||
|
Map<String, Object> params = objectMapper.readValue(task.getTrainParams(), new TypeReference<Map<String, Object>>() {});
|
||||||
|
request.put("hyperparameters", params);
|
||||||
|
} catch (Exception e) {
|
||||||
|
System.err.println("解析训练参数失败,将作为原始字符串发送: " + e.getMessage());
|
||||||
|
request.put("hyperparameters", task.getTrainParams());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
request.put("hyperparameters", new HashMap<>());
|
||||||
|
}
|
||||||
|
|
||||||
HttpHeaders headers = new HttpHeaders();
|
HttpHeaders headers = new HttpHeaders();
|
||||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||||
HttpEntity<Map<String, Object>> entity = new HttpEntity<>(request, headers);
|
HttpEntity<Map<String, Object>> entity = new HttpEntity<>(request, headers);
|
||||||
|
|
||||||
String url = pythonApiUrl + "/v1/train";
|
String url = pythonApiUrl + "/v1/train";
|
||||||
// 这里假设 Python 服务会立即返回,启动后台线程
|
|
||||||
// 如果 Python 服务是阻塞的,这里的 Async 会起作用
|
|
||||||
ResponseEntity<Map> response = restTemplate.postForEntity(url, entity, Map.class);
|
ResponseEntity<Map> response = restTemplate.postForEntity(url, entity, Map.class);
|
||||||
|
|
||||||
if (response.getStatusCode().is2xxSuccessful()) {
|
if (response.getStatusCode().is2xxSuccessful()) {
|
||||||
// 调用成功,等待后续轮询状态
|
|
||||||
System.out.println("训练任务提交成功: " + task.getTaskId());
|
System.out.println("训练任务提交成功: " + task.getTaskId());
|
||||||
} else {
|
} else {
|
||||||
task.setStatus("FAILED");
|
task.setStatus("FAILED");
|
||||||
@ -154,15 +164,21 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
|||||||
|
|
||||||
if ("SUCCESS".equals(status)) {
|
if ("SUCCESS".equals(status)) {
|
||||||
task.setModelOutputPath((String) body.get("model_path"));
|
task.setModelOutputPath((String) body.get("model_path"));
|
||||||
task.setMetrics((Map<String, Object>) body.get("metrics"));
|
|
||||||
task.setFeatureMapSnapshot((Map<String, Object>) body.get("feature_map"));
|
// Map -> JSON String
|
||||||
|
if (body.get("metrics") != null) {
|
||||||
|
task.setMetrics(objectMapper.writeValueAsString(body.get("metrics")));
|
||||||
|
}
|
||||||
|
if (body.get("feature_map") != null) {
|
||||||
|
task.setFeatureMapSnapshot(objectMapper.writeValueAsString(body.get("feature_map")));
|
||||||
|
}
|
||||||
|
|
||||||
task.setMetricsImagePath((String) body.get("metrics_image"));
|
task.setMetricsImagePath((String) body.get("metrics_image"));
|
||||||
} else if ("FAILED".equals(status)) {
|
} else if ("FAILED".equals(status)) {
|
||||||
task.setErrorLog((String) body.get("error"));
|
task.setErrorLog((String) body.get("error"));
|
||||||
} else if ("TRAINING".equals(status)) {
|
} else if ("TRAINING".equals(status)) {
|
||||||
// 可以更新进度或其他中间指标
|
if (body.get("metrics") != null) {
|
||||||
if (body.containsKey("metrics")) {
|
task.setMetrics(objectMapper.writeValueAsString(body.get("metrics")));
|
||||||
task.setMetrics((Map<String, Object>) body.get("metrics"));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -170,7 +186,6 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
// 查询失败,暂不更新状态,或者记录日志
|
|
||||||
System.err.println("同步任务状态失败: " + e.getMessage());
|
System.err.println("同步任务状态失败: " + e.getMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -204,24 +219,14 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
|||||||
model.setAlgorithmType(task.getAlgorithmType());
|
model.setAlgorithmType(task.getAlgorithmType());
|
||||||
model.setDeviceType(task.getDeviceType());
|
model.setDeviceType(task.getDeviceType());
|
||||||
model.setVersionTag(versionTag);
|
model.setVersionTag(versionTag);
|
||||||
model.setModelPath(task.getModelOutputPath()); // 这里简化处理,直接引用临时路径,实际生产建议移动文件到正式目录
|
model.setModelPath(task.getModelOutputPath());
|
||||||
model.setMetricsImagePath(task.getMetricsImagePath());
|
model.setMetricsImagePath(task.getMetricsImagePath());
|
||||||
model.setTrainedAt(LocalDateTime.now());
|
model.setTrainedAt(LocalDateTime.now());
|
||||||
model.setIsCurrent(0); // 默认不激活
|
model.setIsCurrent(0); // 默认不激活
|
||||||
|
|
||||||
try {
|
// 直接赋值 String,不再需要序列化
|
||||||
if (task.getFeatureMapSnapshot() != null) {
|
model.setFeatureMapSnapshot(task.getFeatureMapSnapshot() != null ? task.getFeatureMapSnapshot() : "{}");
|
||||||
model.setFeatureMapSnapshot(objectMapper.writeValueAsString(task.getFeatureMapSnapshot()));
|
model.setMetrics(task.getMetrics());
|
||||||
} else {
|
|
||||||
model.setFeatureMapSnapshot("{}");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (task.getMetrics() != null) {
|
|
||||||
model.setMetrics(objectMapper.writeValueAsString(task.getMetrics()));
|
|
||||||
}
|
|
||||||
} catch (JsonProcessingException e) {
|
|
||||||
throw new BizException("JSON 序列化失败");
|
|
||||||
}
|
|
||||||
|
|
||||||
return algorithmModelService.save(model);
|
return algorithmModelService.save(model);
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user