模型训练添加数据集校验

This commit is contained in:
wanxiaoli 2026-04-15 14:55:56 +08:00
parent 8395224f42
commit bfb3d234eb
3 changed files with 188 additions and 11 deletions

View File

@ -57,11 +57,7 @@ public class ModelTrainController {
@Log(value = "上传训练数据集", module = "模型训练")
@PostMapping("/upload")
public ResponseResult upload(@RequestParam("file") MultipartFile file) {
String path = modelTrainService.uploadDataset(file);
return ResponseResult.successData(Map.of(
"path", path,
"columns", modelTrainService.parseDatasetColumns(path)
));
return ResponseResult.successData(modelTrainService.uploadAndInspectDataset(file));
}
/**
@ -90,6 +86,8 @@ public class ModelTrainController {
String taskId = modelTrainService.submitTask(task);
// System.out.println("提交任务成功任务ID: " + taskId);
return ResponseResult.successData(taskId);
} catch (com.yfd.business.css.common.exception.BizException e) {
return ResponseResult.error(e.getMessage());
} catch (JsonProcessingException e) {
return ResponseResult.error("参数解析失败: " + e.getMessage());
} catch (Exception e) {

View File

@ -5,6 +5,7 @@ import com.yfd.business.css.domain.ModelTrainTask;
import org.springframework.web.multipart.MultipartFile;
import java.util.List;
import java.util.Map;
public interface ModelTrainService extends IService<ModelTrainTask> {
/**
@ -16,6 +17,8 @@ public interface ModelTrainService extends IService<ModelTrainTask> {
List<String> parseDatasetColumns(String datasetPath);
Map<String, Object> uploadAndInspectDataset(MultipartFile file);
/**
* 提交训练任务
* @param task 任务信息

View File

@ -41,12 +41,14 @@ import java.nio.file.StandardCopyOption;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Map;
import java.util.UUID;
import java.util.regex.Pattern;
import java.util.regex.Matcher;
@Slf4j
@Service
@ -71,6 +73,8 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
private ObjectMapper objectMapper;
private static final Pattern VERSION_TAG_PATTERN = Pattern.compile("^[a-zA-Z0-9][a-zA-Z0-9._-]{0,63}$");
private static final Pattern IDENTIFIER_PATTERN = Pattern.compile("[A-Za-z_][A-Za-z0-9_]*");
private static final Pattern DERIVED_EXPR_ALLOWED_PATTERN = Pattern.compile("^[A-Za-z0-9_+\\-*/()\\s]+$");
@Override
public String uploadDataset(MultipartFile file) {
@ -138,6 +142,18 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
throw new BizException("不支持的数据集格式: " + p.getFileName());
}
/**
 * Uploads a dataset and reports what was found in it.
 *
 * <p>Delegates to {@code uploadDataset} for the dataset path, to
 * {@code parseDatasetColumns} for the header columns, and to
 * {@code validateUploadColumns} for the upload-time checks.
 *
 * @param file the uploaded dataset file
 * @return an immutable map with the keys {@code "path"}, {@code "columns"} and {@code "warnings"}
 */
@Override
public Map<String, Object> uploadAndInspectDataset(MultipartFile file) {
    final String storedPath = uploadDataset(file);
    final List<String> headerColumns = parseDatasetColumns(storedPath);
    return Map.of(
            "path", storedPath,
            "columns", headerColumns,
            "warnings", validateUploadColumns(storedPath, headerColumns)
    );
}
@Override
@Transactional
public String submitTask(ModelTrainTask task) {
@ -146,13 +162,18 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
if (task.getTaskId() == null) {
task.setTaskId(UUID.randomUUID().toString());
}
if (task.getFeatureMapConfig() != null && !task.getFeatureMapConfig().isBlank()) {
if (task.getFeatureMapConfig() == null || task.getFeatureMapConfig().isBlank()) {
throw new BizException("feature_map_config不能为空");
}
Map<String, Object> fmMap;
try {
objectMapper.readValue(task.getFeatureMapConfig(), new TypeReference<Map<String, Object>>() {});
} catch (Exception ignored) {
fmMap = objectMapper.readValue(task.getFeatureMapConfig(), new TypeReference<Map<String, Object>>() {});
} catch (Exception e) {
throw new BizException("feature_map_config JSON解析失败");
}
}
validateDatasetForSubmit(task.getDatasetPath(), fmMap);
this.save(task);
// 2. 异步调用 Python 训练
@ -251,6 +272,161 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
}
}
/**
 * Checks a dataset's file name and header columns and collects non-fatal warnings.
 *
 * <p>Fatal problems (blank path, unsupported extension, empty header, duplicate
 * column names after trimming) are reported by throwing {@code BizException}.
 * Blank or null column names are ignored by the duplicate check.
 *
 * @param datasetPath path of the dataset file; must not be blank
 * @param columns     header column names read from the dataset
 * @return a mutable list of warning maps (keys {@code "code"} and {@code "message"});
 *         empty when nothing suspicious was found
 * @throws BizException if the path, the file extension, or the header is invalid
 */
private List<Map<String, Object>> validateUploadColumns(String datasetPath, List<String> columns) {
    if (datasetPath == null || datasetPath.isBlank()) {
        throw new BizException("数据集路径不能为空");
    }
    Path fileName = Paths.get(datasetPath).getFileName();
    // getFileName() is null for a path with no name elements (e.g. a filesystem root);
    // report it as an unsupported dataset instead of failing with an NPE.
    if (fileName == null) {
        throw new BizException("不支持的数据集格式: " + datasetPath);
    }
    // Locale.ROOT keeps the extension comparison independent of the server's default locale.
    String lower = fileName.toString().toLowerCase(java.util.Locale.ROOT);
    boolean csv = lower.endsWith(".csv");
    if (!(csv || lower.endsWith(".xlsx") || lower.endsWith(".xls"))) {
        throw new BizException("不支持的数据集格式: " + fileName);
    }
    if (columns == null || columns.isEmpty()) {
        throw new BizException("表头为空,请检查数据集第一行是否为表头");
    }

    Map<String, Integer> occurrences = new LinkedHashMap<>();
    List<String> duplicates = new ArrayList<>();
    for (String column : columns) {
        String key = column == null ? "" : column.trim();
        if (key.isEmpty()) {
            continue;
        }
        // Record each duplicated name once, on its second occurrence.
        if (occurrences.merge(key, 1, Integer::sum) == 2) {
            duplicates.add(key);
        }
    }
    if (!duplicates.isEmpty()) {
        throw new BizException("表头列名重复: " + String.join(", ", duplicates));
    }

    List<Map<String, Object>> warnings = new ArrayList<>();
    // A single-column CSV header usually means the file uses another delimiter.
    if (csv && columns.size() == 1) {
        warnings.add(Map.of(
                "code", "SUSPICIOUS_CSV",
                "message", "CSV 可能不是逗号分隔,表头仅 1 列"
        ));
    }
    return warnings;
}
/**
 * Validates a dataset against the parsed feature map configuration before a task is submitted.
 *
 * <p>Checks, in order: the dataset path is present, the configuration is present,
 * the dataset header passes {@code validateUploadColumns}, {@code input_cols} and
 * {@code target_col} are set and exist in the header, {@code target_col} is not
 * also an input column, and (when {@code derive} is {@code true}) every derived
 * rule is well formed.
 *
 * @param datasetPath      path of the dataset file; must not be blank
 * @param featureMapConfig parsed feature map configuration; must not be null
 * @throws BizException if any of the checks fails
 */
private void validateDatasetForSubmit(String datasetPath, Map<String, Object> featureMapConfig) {
    if (datasetPath == null || datasetPath.isBlank()) {
        throw new BizException("数据集路径不能为空,请上传文件或指定路径");
    }
    // The caller may hand over null (for example when the configuration text parsed
    // to a JSON null); fail with a business error rather than an NPE further down.
    if (featureMapConfig == null) {
        throw new BizException("feature_map_config不能为空");
    }
    List<String> columns = parseDatasetColumns(datasetPath);
    validateUploadColumns(datasetPath, columns);

    List<String> inputCols = toStringList(featureMapConfig.get("input_cols"));
    if (inputCols == null || inputCols.isEmpty()) {
        throw new BizException("feature_map_config.input_cols不能为空");
    }
    String targetCol = toNonBlankString(featureMapConfig.get("target_col"));
    if (targetCol == null) {
        throw new BizException("feature_map_config.target_col不能为空");
    }

    // Collect each missing column once so the message does not repeat names.
    List<String> missing = new ArrayList<>();
    for (String c : inputCols) {
        if (c == null || c.isBlank()) {
            continue;
        }
        if (!columns.contains(c) && !missing.contains(c)) {
            missing.add(c);
        }
    }
    if (!columns.contains(targetCol) && !missing.contains(targetCol)) {
        missing.add(targetCol);
    }
    if (!missing.isEmpty()) {
        throw new BizException("数据集缺少列: " + String.join(", ", missing));
    }
    if (inputCols.contains(targetCol)) {
        throw new BizException("target_col不能与input_cols重复: " + targetCol);
    }

    // Only a JSON boolean true enables derived-rule validation.
    if (Boolean.TRUE.equals(featureMapConfig.get("derive"))) {
        List<Map<String, Object>> rules = toMapList(featureMapConfig.get("derived_rules"));
        if (rules == null || rules.isEmpty()) {
            throw new BizException("derive=true时derived_rules不能为空");
        }
        validateDerivedRulesForSubmit(rules, inputCols, targetCol);
    }
}

/**
 * Validates each derived rule: name and expr are present, the name is unique and does
 * not clash with an input/target column, the expr uses only allowed characters, and
 * every identifier in the expr is one of the input columns.
 *
 * @param rules     derived rule maps, each expected to carry {@code name} and {@code expr}
 * @param inputCols configured input column names
 * @param targetCol configured target column name
 * @throws BizException on the first rule that violates a check
 */
private void validateDerivedRulesForSubmit(List<Map<String, Object>> rules,
                                           List<String> inputCols,
                                           String targetCol) {
    List<String> reservedNames = new ArrayList<>(inputCols);
    reservedNames.add(targetCol);
    Map<String, Boolean> seenNames = new HashMap<>();
    for (Map<String, Object> rule : rules) {
        String name = toNonBlankString(rule.get("name"));
        String expr = toNonBlankString(rule.get("expr"));
        if (name == null || expr == null) {
            throw new BizException("derived_rules中每条规则的name/expr不能为空");
        }
        if (reservedNames.contains(name)) {
            throw new BizException("derived_rules.name不能与input_cols/target_col重名: " + name);
        }
        if (seenNames.put(name, Boolean.TRUE) != null) {
            throw new BizException("derived_rules.name重复: " + name);
        }
        if (!DERIVED_EXPR_ALLOWED_PATTERN.matcher(expr).matches()) {
            throw new BizException("derived_rules.expr包含非法字符: " + name);
        }
        // Every identifier in the expression must be an input column; list each offender once.
        List<String> unknown = new ArrayList<>();
        Matcher identifiers = IDENTIFIER_PATTERN.matcher(expr);
        while (identifiers.find()) {
            String variable = identifiers.group();
            if (!inputCols.contains(variable) && !unknown.contains(variable)) {
                unknown.add(variable);
            }
        }
        if (!unknown.isEmpty()) {
            throw new BizException("derived_rules.expr包含未知变量必须来自input_cols: " + name + " -> " + String.join(", ", unknown));
        }
    }
}
/**
 * Converts a value to its trimmed string form.
 *
 * @param v any value, may be null
 * @return the trimmed {@code String.valueOf(v)}, or null when {@code v} is null
 *         or the trimmed text is empty
 */
private String toNonBlankString(Object v) {
    String text = (v == null) ? "" : String.valueOf(v).trim();
    return text.isEmpty() ? null : text;
}
/**
 * Converts a loosely typed value into a list of trimmed, non-empty strings.
 *
 * @param v expected to be a {@code List}; anything else (including null) yields null
 * @return a new mutable list holding the trimmed string form of every non-null
 *         element whose trimmed text is not empty, or null when {@code v} is not a list
 */
private List<String> toStringList(Object v) {
    if (!(v instanceof List<?> items)) {
        return null;
    }
    List<String> result = new ArrayList<>();
    for (Object item : items) {
        String text = (item == null) ? "" : String.valueOf(item).trim();
        if (!text.isEmpty()) {
            result.add(text);
        }
    }
    return result;
}
/**
 * Converts a loosely typed value into a list of string-keyed maps.
 *
 * <p>Elements that are not maps are skipped. Within each map, entries with a
 * null key are dropped and every other key is converted with {@code String.valueOf}.
 *
 * @param v expected to be a {@code List}; anything else (including null) yields null
 * @return a new mutable list of copied maps, or null when {@code v} is not a list
 */
private List<Map<String, Object>> toMapList(Object v) {
    if (!(v instanceof List<?> items)) {
        return null;
    }
    List<Map<String, Object>> result = new ArrayList<>();
    for (Object item : items) {
        if (!(item instanceof Map<?, ?> raw)) {
            continue;
        }
        Map<String, Object> copy = new HashMap<>();
        raw.forEach((key, value) -> {
            if (key != null) {
                copy.put(String.valueOf(key), value);
            }
        });
        result.add(copy);
    }
    return result;
}
private List<String> parseExcelHeader(File file) {
DataFormatter df = new DataFormatter();
try (Workbook wb = WorkbookFactory.create(file)) {