From f7d113b71cc1e44fb19c4336ce12f87370985b97 Mon Sep 17 00:00:00 2001 From: wanxiaoli Date: Wed, 11 Mar 2026 14:52:43 +0800 Subject: [PATCH] =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E8=AE=AD=E7=BB=83=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../css/controller/ModelTrainController.java | 65 ++++++++++++++++--- .../business/css/domain/ModelTrainTask.java | 14 ++-- .../service/impl/ModelTrainServiceImpl.java | 53 ++++++++------- 3 files changed, 92 insertions(+), 40 deletions(-) diff --git a/business-css/src/main/java/com/yfd/business/css/controller/ModelTrainController.java b/business-css/src/main/java/com/yfd/business/css/controller/ModelTrainController.java index 32de20e..6f7ecb8 100644 --- a/business-css/src/main/java/com/yfd/business/css/controller/ModelTrainController.java +++ b/business-css/src/main/java/com/yfd/business/css/controller/ModelTrainController.java @@ -1,6 +1,9 @@ package com.yfd.business.css.controller; +import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.baomidou.mybatisplus.extension.plugins.pagination.Page; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; import com.yfd.business.css.domain.ModelTrainTask; import com.yfd.business.css.service.ModelTrainService; import com.yfd.platform.config.ResponseResult; @@ -16,6 +19,9 @@ public class ModelTrainController { @Autowired private ModelTrainService modelTrainService; + + @Autowired + private ObjectMapper objectMapper; /** * 上传数据集 @@ -27,22 +33,65 @@ public class ModelTrainController { } /** - * 提交训练任务 + * 提交训练任务 (支持文件上传和 JSON 参数) */ @PostMapping("/submit") - public ResponseResult submit(@RequestBody ModelTrainTask task) { - String taskId = modelTrainService.submitTask(task); - return ResponseResult.successData(taskId); + public ResponseResult submit(@RequestPart("task") String taskJson, + @RequestPart(value = "file", required = false) MultipartFile file) { + try { + ModelTrainTask task = objectMapper.readValue(taskJson, ModelTrainTask.class); + + // 如果上传了文件,优先使用文件路径 + if (file != null && !file.isEmpty()) { + String path = modelTrainService.uploadDataset(file); + task.setDatasetPath(path); + } + + // 校验 datasetPath + if (task.getDatasetPath() == null || task.getDatasetPath().isBlank()) { + return ResponseResult.error("数据集路径不能为空,请上传文件或指定路径"); + } + + String taskId = modelTrainService.submitTask(task); + return ResponseResult.successData(taskId); + } catch (JsonProcessingException e) { + return ResponseResult.error("参数解析失败: " + e.getMessage()); + } catch (Exception e) { + return ResponseResult.error("提交任务失败: " + e.getMessage()); + } } /** - * 查询任务列表 + * 查询任务列表 (支持条件查询) */ @GetMapping("/list") public ResponseResult list(@RequestParam(defaultValue = "1") Integer current, - @RequestParam(defaultValue = "10") Integer size) { - Page page = modelTrainService.page(new Page<>(current, size)); - return ResponseResult.successData(page); + @RequestParam(defaultValue = "10") Integer size, + @RequestParam(required = false) String algoType, + @RequestParam(required = false) String deviceType, + @RequestParam(required = false) String status, + @RequestParam(required = false) String name) { + + Page page = new Page<>(current, size); + QueryWrapper query = new QueryWrapper<>(); + + if (algoType != null && !algoType.isBlank()) { + query.eq("algorithm_type", algoType); + } + if (deviceType != null && !deviceType.isBlank()) { + query.eq("device_type", deviceType); + } + if (status != null && !status.isBlank()) { + query.eq("status", status); + } + if (name != null && !name.isBlank()) { + query.like("task_name", name); + } + + query.orderByDesc("created_at"); + + Page result = modelTrainService.page(page, query); + return ResponseResult.successData(result); } /** diff --git a/business-css/src/main/java/com/yfd/business/css/domain/ModelTrainTask.java b/business-css/src/main/java/com/yfd/business/css/domain/ModelTrainTask.java index 92b4bdc..f138e84 100644 --- a/business-css/src/main/java/com/yfd/business/css/domain/ModelTrainTask.java +++ b/business-css/src/main/java/com/yfd/business/css/domain/ModelTrainTask.java @@ -1,12 +1,10 @@ package com.yfd.business.css.domain; import com.baomidou.mybatisplus.annotation.*; -import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler; import lombok.Data; import java.io.Serializable; import java.time.LocalDateTime; -import java.util.Map; @Data @TableName(value = "model_train_task", autoResultMap = true) @@ -29,20 +27,20 @@ public class ModelTrainTask implements Serializable { @TableField("dataset_path") private String datasetPath; - @TableField(value = "train_params", typeHandler = JacksonTypeHandler.class) - private Map trainParams; // 使用 Map 存储 JSON + @TableField(value = "train_params") + private String trainParams; // JSON String @TableField("status") private String status; // PENDING, TRAINING, SUCCESS, FAILED - @TableField(value = "metrics", typeHandler = JacksonTypeHandler.class) - private Map metrics; + @TableField(value = "metrics") + private String metrics; // JSON String @TableField("model_output_path") private String modelOutputPath; - @TableField(value = "feature_map_snapshot", typeHandler = JacksonTypeHandler.class) - private Map featureMapSnapshot; + @TableField(value = "feature_map_snapshot") + private String featureMapSnapshot; // JSON String @TableField("metrics_image_path") private String metricsImagePath; diff --git a/business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java b/business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java index f03d541..961bfd2 100644 --- a/business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java +++ b/business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java @@ -3,6 +3,7 @@ package com.yfd.business.css.service.impl; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.yfd.business.css.common.exception.BizException; import com.yfd.business.css.domain.AlgorithmModel; @@ -104,19 +105,28 @@ public class ModelTrainServiceImpl extends ServiceImpl Map) + if (task.getTrainParams() != null && !task.getTrainParams().isBlank()) { + try { + Map params = objectMapper.readValue(task.getTrainParams(), new TypeReference>() {}); + request.put("hyperparameters", params); + } catch (Exception e) { + System.err.println("解析训练参数失败,将作为原始字符串发送: " + e.getMessage()); + request.put("hyperparameters", task.getTrainParams()); + } + } else { + request.put("hyperparameters", new HashMap<>()); + } HttpHeaders headers = new HttpHeaders(); headers.setContentType(MediaType.APPLICATION_JSON); HttpEntity> entity = new HttpEntity<>(request, headers); String url = pythonApiUrl + "/v1/train"; - // 这里假设 Python 服务会立即返回,启动后台线程 - // 如果 Python 服务是阻塞的,这里的 Async 会起作用 ResponseEntity response = restTemplate.postForEntity(url, entity, Map.class); if (response.getStatusCode().is2xxSuccessful()) { - // 调用成功,等待后续轮询状态 System.out.println("训练任务提交成功: " + task.getTaskId()); } else { task.setStatus("FAILED"); @@ -154,15 +164,21 @@ public class ModelTrainServiceImpl extends ServiceImpl) body.get("metrics")); - task.setFeatureMapSnapshot((Map) body.get("feature_map")); + + // Map -> JSON String + if (body.get("metrics") != null) { + task.setMetrics(objectMapper.writeValueAsString(body.get("metrics"))); + } + if (body.get("feature_map") != null) { + task.setFeatureMapSnapshot(objectMapper.writeValueAsString(body.get("feature_map"))); + } + task.setMetricsImagePath((String) body.get("metrics_image")); } else if ("FAILED".equals(status)) { task.setErrorLog((String) body.get("error")); } else if ("TRAINING".equals(status)) { - // 可以更新进度或其他中间指标 - if (body.containsKey("metrics")) { - task.setMetrics((Map) body.get("metrics")); + if (body.get("metrics") != null) { + task.setMetrics(objectMapper.writeValueAsString(body.get("metrics"))); } } @@ -170,7 +186,6 @@ public class ModelTrainServiceImpl extends ServiceImpl