From bfb3d234ebd14aa14ad84f5c7b459d9eec5d730f Mon Sep 17 00:00:00 2001 From: wanxiaoli Date: Wed, 15 Apr 2026 14:55:56 +0800 Subject: [PATCH] =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E8=AE=AD=E7=BB=83=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E6=95=B0=E6=8D=AE=E9=9B=86=E6=A0=A1=E9=AA=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../css/controller/ModelTrainController.java | 8 +- .../css/service/ModelTrainService.java | 3 + .../service/impl/ModelTrainServiceImpl.java | 188 +++++++++++++++++- 3 files changed, 188 insertions(+), 11 deletions(-) diff --git a/business-css/src/main/java/com/yfd/business/css/controller/ModelTrainController.java b/business-css/src/main/java/com/yfd/business/css/controller/ModelTrainController.java index 0bb16ff..9c704e0 100644 --- a/business-css/src/main/java/com/yfd/business/css/controller/ModelTrainController.java +++ b/business-css/src/main/java/com/yfd/business/css/controller/ModelTrainController.java @@ -57,11 +57,7 @@ public class ModelTrainController { @Log(value = "上传训练数据集", module = "模型训练") @PostMapping("/upload") public ResponseResult upload(@RequestParam("file") MultipartFile file) { - String path = modelTrainService.uploadDataset(file); - return ResponseResult.successData(Map.of( - "path", path, - "columns", modelTrainService.parseDatasetColumns(path) - )); + return ResponseResult.successData(modelTrainService.uploadAndInspectDataset(file)); } /** @@ -90,6 +86,8 @@ public class ModelTrainController { String taskId = modelTrainService.submitTask(task); // System.out.println("提交任务成功,任务ID: " + taskId); return ResponseResult.successData(taskId); + } catch (com.yfd.business.css.common.exception.BizException e) { + return ResponseResult.error(e.getMessage()); } catch (JsonProcessingException e) { return ResponseResult.error("参数解析失败: " + e.getMessage()); } catch (Exception e) { diff --git a/business-css/src/main/java/com/yfd/business/css/service/ModelTrainService.java b/business-css/src/main/java/com/yfd/business/css/service/ModelTrainService.java index f742b5c..131db13 100644 --- a/business-css/src/main/java/com/yfd/business/css/service/ModelTrainService.java +++ b/business-css/src/main/java/com/yfd/business/css/service/ModelTrainService.java @@ -5,6 +5,7 @@ import com.yfd.business.css.domain.ModelTrainTask; import org.springframework.web.multipart.MultipartFile; import java.util.List; +import java.util.Map; public interface ModelTrainService extends IService { /** @@ -16,6 +17,8 @@ public interface ModelTrainService extends IService { List parseDatasetColumns(String datasetPath); + Map uploadAndInspectDataset(MultipartFile file); + /** * 提交训练任务 * @param task 任务信息 diff --git a/business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java b/business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java index c2fdd5e..2bedfcf 100644 --- a/business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java +++ b/business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java @@ -41,12 +41,14 @@ import java.nio.file.StandardCopyOption; import java.time.LocalDateTime; import java.time.format.DateTimeFormatter; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.Map; import java.util.UUID; import java.util.regex.Pattern; +import java.util.regex.Matcher; @Slf4j @Service @@ -71,6 +73,8 @@ public class ModelTrainServiceImpl extends ServiceImpl uploadAndInspectDataset(MultipartFile file) { + String path = uploadDataset(file); + List columns = parseDatasetColumns(path); + List> warnings = validateUploadColumns(path, columns); + return Map.of( + "path", path, + "columns", columns, + "warnings", warnings + ); + } + @Override @Transactional public String submitTask(ModelTrainTask task) { @@ -146,13 +162,18 @@ public class ModelTrainServiceImpl extends ServiceImpl>() {}); - } catch (Exception ignored) { - throw new BizException("feature_map_config JSON解析失败"); - } + if (task.getFeatureMapConfig() == null || task.getFeatureMapConfig().isBlank()) { + throw new BizException("feature_map_config不能为空"); } + + Map fmMap; + try { + fmMap = objectMapper.readValue(task.getFeatureMapConfig(), new TypeReference>() {}); + } catch (Exception e) { + throw new BizException("feature_map_config JSON解析失败"); + } + validateDatasetForSubmit(task.getDatasetPath(), fmMap); + this.save(task); // 2. 异步调用 Python 训练 @@ -251,6 +272,161 @@ public class ModelTrainServiceImpl extends ServiceImpl> validateUploadColumns(String datasetPath, List columns) { + if (datasetPath == null || datasetPath.isBlank()) { + throw new BizException("数据集路径不能为空"); + } + Path p = Paths.get(datasetPath); + String lower = p.getFileName().toString().toLowerCase(); + if (!(lower.endsWith(".xlsx") || lower.endsWith(".xls") || lower.endsWith(".csv"))) { + throw new BizException("不支持的数据集格式: " + p.getFileName()); + } + if (columns == null || columns.isEmpty()) { + throw new BizException("表头为空,请检查数据集第一行是否为表头"); + } + + Map seen = new LinkedHashMap<>(); + List duplicates = new ArrayList<>(); + for (String c : columns) { + String k = c == null ? "" : c.trim(); + if (k.isEmpty()) continue; + int cnt = seen.getOrDefault(k, 0) + 1; + seen.put(k, cnt); + if (cnt == 2) duplicates.add(k); + } + if (!duplicates.isEmpty()) { + throw new BizException("表头列名重复: " + String.join(", ", duplicates)); + } + + List> warnings = new ArrayList<>(); + if (lower.endsWith(".csv") && columns.size() == 1) { + warnings.add(Map.of( + "code", "SUSPICIOUS_CSV", + "message", "CSV 可能不是逗号分隔,表头仅 1 列" + )); + } + return warnings; + } + + private void validateDatasetForSubmit(String datasetPath, Map featureMapConfig) { + if (datasetPath == null || datasetPath.isBlank()) { + throw new BizException("数据集路径不能为空,请上传文件或指定路径"); + } + List columns = parseDatasetColumns(datasetPath); + validateUploadColumns(datasetPath, columns); + + Object inputColsObj = featureMapConfig.get("input_cols"); + List inputCols = toStringList(inputColsObj); + if (inputCols == null || inputCols.isEmpty()) { + throw new BizException("feature_map_config.input_cols不能为空"); + } + + String targetCol = toNonBlankString(featureMapConfig.get("target_col")); + if (targetCol == null) { + throw new BizException("feature_map_config.target_col不能为空"); + } + + List missing = new ArrayList<>(); + for (String c : inputCols) { + if (c == null || c.isBlank()) continue; + if (!columns.contains(c)) missing.add(c); + } + if (!columns.contains(targetCol)) missing.add(targetCol); + if (!missing.isEmpty()) { + throw new BizException("数据集缺少列: " + String.join(", ", missing)); + } + if (inputCols.contains(targetCol)) { + throw new BizException("target_col不能与input_cols重复: " + targetCol); + } + + boolean derive = Boolean.TRUE.equals(featureMapConfig.get("derive")); + if (derive) { + Object rulesObj = featureMapConfig.get("derived_rules"); + List> rules = toMapList(rulesObj); + if (rules == null || rules.isEmpty()) { + throw new BizException("derive=true时,derived_rules不能为空"); + } + + List reservedNames = new ArrayList<>(inputCols); + reservedNames.add(targetCol); + Map seen = new LinkedHashMap<>(); + + for (Map r : rules) { + String name = toNonBlankString(r.get("name")); + String expr = toNonBlankString(r.get("expr")); + if (name == null || expr == null) { + throw new BizException("derived_rules中每条规则的name/expr不能为空"); + } + if (reservedNames.contains(name)) { + throw new BizException("derived_rules.name不能与input_cols/target_col重名: " + name); + } + int cnt = seen.getOrDefault(name, 0) + 1; + seen.put(name, cnt); + if (cnt > 1) { + throw new BizException("derived_rules.name重复: " + name); + } + if (!DERIVED_EXPR_ALLOWED_PATTERN.matcher(expr).matches()) { + throw new BizException("derived_rules.expr包含非法字符: " + name); + } + + Matcher m = IDENTIFIER_PATTERN.matcher(expr); + List vars = new ArrayList<>(); + while (m.find()) { + vars.add(m.group()); + } + List unknown = new ArrayList<>(); + for (String v : vars) { + if (!inputCols.contains(v)) { + unknown.add(v); + } + } + if (!unknown.isEmpty()) { + throw new BizException("derived_rules.expr包含未知变量(必须来自input_cols): " + name + " -> " + String.join(", ", unknown)); + } + } + } + } + + private String toNonBlankString(Object v) { + if (v == null) return null; + String s = String.valueOf(v).trim(); + return s.isEmpty() ? null : s; + } + + private List toStringList(Object v) { + if (v == null) return null; + if (v instanceof List list) { + List out = new ArrayList<>(); + for (Object x : list) { + if (x == null) continue; + String s = String.valueOf(x).trim(); + if (!s.isEmpty()) out.add(s); + } + return out; + } + return null; + } + + private List> toMapList(Object v) { + if (v == null) return null; + if (v instanceof List list) { + List> out = new ArrayList<>(); + for (Object x : list) { + if (x instanceof Map m) { + Map mm = new HashMap<>(); + for (Map.Entry e : m.entrySet()) { + if (e.getKey() != null) { + mm.put(String.valueOf(e.getKey()), e.getValue()); + } + } + out.add(mm); + } + } + return out; + } + return null; + } + private List parseExcelHeader(File file) { DataFormatter df = new DataFormatter(); try (Workbook wb = WorkbookFactory.create(file)) {