模型训练添加数据集校验
This commit is contained in:
parent
8395224f42
commit
bfb3d234eb
@ -57,11 +57,7 @@ public class ModelTrainController {
|
||||
@Log(value = "上传训练数据集", module = "模型训练")
|
||||
@PostMapping("/upload")
|
||||
public ResponseResult upload(@RequestParam("file") MultipartFile file) {
|
||||
String path = modelTrainService.uploadDataset(file);
|
||||
return ResponseResult.successData(Map.of(
|
||||
"path", path,
|
||||
"columns", modelTrainService.parseDatasetColumns(path)
|
||||
));
|
||||
return ResponseResult.successData(modelTrainService.uploadAndInspectDataset(file));
|
||||
}
|
||||
|
||||
/**
|
||||
@ -90,6 +86,8 @@ public class ModelTrainController {
|
||||
String taskId = modelTrainService.submitTask(task);
|
||||
// System.out.println("提交任务成功,任务ID: " + taskId);
|
||||
return ResponseResult.successData(taskId);
|
||||
} catch (com.yfd.business.css.common.exception.BizException e) {
|
||||
return ResponseResult.error(e.getMessage());
|
||||
} catch (JsonProcessingException e) {
|
||||
return ResponseResult.error("参数解析失败: " + e.getMessage());
|
||||
} catch (Exception e) {
|
||||
|
||||
@ -5,6 +5,7 @@ import com.yfd.business.css.domain.ModelTrainTask;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public interface ModelTrainService extends IService<ModelTrainTask> {
|
||||
/**
|
||||
@ -16,6 +17,8 @@ public interface ModelTrainService extends IService<ModelTrainTask> {
|
||||
|
||||
List<String> parseDatasetColumns(String datasetPath);
|
||||
|
||||
Map<String, Object> uploadAndInspectDataset(MultipartFile file);
|
||||
|
||||
/**
|
||||
* 提交训练任务
|
||||
* @param task 任务信息
|
||||
|
||||
@ -41,12 +41,14 @@ import java.nio.file.StandardCopyOption;
|
||||
import java.time.LocalDateTime;
|
||||
import java.time.format.DateTimeFormatter;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import java.util.regex.Pattern;
|
||||
import java.util.regex.Matcher;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
@ -71,6 +73,8 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
||||
private ObjectMapper objectMapper;
|
||||
|
||||
private static final Pattern VERSION_TAG_PATTERN = Pattern.compile("^[a-zA-Z0-9][a-zA-Z0-9._-]{0,63}$");
|
||||
// Matches one identifier-like token (letter or underscore, then letters, digits or underscores).
// Used with Matcher.find() to pull the variable names out of a derived-rule expression.
private static final Pattern IDENTIFIER_PATTERN = Pattern.compile("[A-Za-z_][A-Za-z0-9_]*");
|
||||
// Whitelist for a whole derived-rule expression: ASCII letters, digits, underscore,
// the operators + - * /, parentheses and whitespace. Anything else is rejected.
private static final Pattern DERIVED_EXPR_ALLOWED_PATTERN = Pattern.compile("^[A-Za-z0-9_+\\-*/()\\s]+$");
|
||||
|
||||
@Override
|
||||
public String uploadDataset(MultipartFile file) {
|
||||
@ -138,6 +142,18 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
||||
throw new BizException("不支持的数据集格式: " + p.getFileName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Object> uploadAndInspectDataset(MultipartFile file) {
|
||||
String path = uploadDataset(file);
|
||||
List<String> columns = parseDatasetColumns(path);
|
||||
List<Map<String, Object>> warnings = validateUploadColumns(path, columns);
|
||||
return Map.of(
|
||||
"path", path,
|
||||
"columns", columns,
|
||||
"warnings", warnings
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
@Transactional
|
||||
public String submitTask(ModelTrainTask task) {
|
||||
@ -146,13 +162,18 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
||||
if (task.getTaskId() == null) {
|
||||
task.setTaskId(UUID.randomUUID().toString());
|
||||
}
|
||||
if (task.getFeatureMapConfig() != null && !task.getFeatureMapConfig().isBlank()) {
|
||||
if (task.getFeatureMapConfig() == null || task.getFeatureMapConfig().isBlank()) {
|
||||
throw new BizException("feature_map_config不能为空");
|
||||
}
|
||||
|
||||
Map<String, Object> fmMap;
|
||||
try {
|
||||
objectMapper.readValue(task.getFeatureMapConfig(), new TypeReference<Map<String, Object>>() {});
|
||||
} catch (Exception ignored) {
|
||||
fmMap = objectMapper.readValue(task.getFeatureMapConfig(), new TypeReference<Map<String, Object>>() {});
|
||||
} catch (Exception e) {
|
||||
throw new BizException("feature_map_config JSON解析失败");
|
||||
}
|
||||
}
|
||||
validateDatasetForSubmit(task.getDatasetPath(), fmMap);
|
||||
|
||||
this.save(task);
|
||||
|
||||
// 2. 异步调用 Python 训练
|
||||
@ -251,6 +272,161 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
|
||||
}
|
||||
}
|
||||
|
||||
private List<Map<String, Object>> validateUploadColumns(String datasetPath, List<String> columns) {
|
||||
if (datasetPath == null || datasetPath.isBlank()) {
|
||||
throw new BizException("数据集路径不能为空");
|
||||
}
|
||||
Path p = Paths.get(datasetPath);
|
||||
String lower = p.getFileName().toString().toLowerCase();
|
||||
if (!(lower.endsWith(".xlsx") || lower.endsWith(".xls") || lower.endsWith(".csv"))) {
|
||||
throw new BizException("不支持的数据集格式: " + p.getFileName());
|
||||
}
|
||||
if (columns == null || columns.isEmpty()) {
|
||||
throw new BizException("表头为空,请检查数据集第一行是否为表头");
|
||||
}
|
||||
|
||||
Map<String, Integer> seen = new LinkedHashMap<>();
|
||||
List<String> duplicates = new ArrayList<>();
|
||||
for (String c : columns) {
|
||||
String k = c == null ? "" : c.trim();
|
||||
if (k.isEmpty()) continue;
|
||||
int cnt = seen.getOrDefault(k, 0) + 1;
|
||||
seen.put(k, cnt);
|
||||
if (cnt == 2) duplicates.add(k);
|
||||
}
|
||||
if (!duplicates.isEmpty()) {
|
||||
throw new BizException("表头列名重复: " + String.join(", ", duplicates));
|
||||
}
|
||||
|
||||
List<Map<String, Object>> warnings = new ArrayList<>();
|
||||
if (lower.endsWith(".csv") && columns.size() == 1) {
|
||||
warnings.add(Map.of(
|
||||
"code", "SUSPICIOUS_CSV",
|
||||
"message", "CSV 可能不是逗号分隔,表头仅 1 列"
|
||||
));
|
||||
}
|
||||
return warnings;
|
||||
}
|
||||
|
||||
private void validateDatasetForSubmit(String datasetPath, Map<String, Object> featureMapConfig) {
|
||||
if (datasetPath == null || datasetPath.isBlank()) {
|
||||
throw new BizException("数据集路径不能为空,请上传文件或指定路径");
|
||||
}
|
||||
List<String> columns = parseDatasetColumns(datasetPath);
|
||||
validateUploadColumns(datasetPath, columns);
|
||||
|
||||
Object inputColsObj = featureMapConfig.get("input_cols");
|
||||
List<String> inputCols = toStringList(inputColsObj);
|
||||
if (inputCols == null || inputCols.isEmpty()) {
|
||||
throw new BizException("feature_map_config.input_cols不能为空");
|
||||
}
|
||||
|
||||
String targetCol = toNonBlankString(featureMapConfig.get("target_col"));
|
||||
if (targetCol == null) {
|
||||
throw new BizException("feature_map_config.target_col不能为空");
|
||||
}
|
||||
|
||||
List<String> missing = new ArrayList<>();
|
||||
for (String c : inputCols) {
|
||||
if (c == null || c.isBlank()) continue;
|
||||
if (!columns.contains(c)) missing.add(c);
|
||||
}
|
||||
if (!columns.contains(targetCol)) missing.add(targetCol);
|
||||
if (!missing.isEmpty()) {
|
||||
throw new BizException("数据集缺少列: " + String.join(", ", missing));
|
||||
}
|
||||
if (inputCols.contains(targetCol)) {
|
||||
throw new BizException("target_col不能与input_cols重复: " + targetCol);
|
||||
}
|
||||
|
||||
boolean derive = Boolean.TRUE.equals(featureMapConfig.get("derive"));
|
||||
if (derive) {
|
||||
Object rulesObj = featureMapConfig.get("derived_rules");
|
||||
List<Map<String, Object>> rules = toMapList(rulesObj);
|
||||
if (rules == null || rules.isEmpty()) {
|
||||
throw new BizException("derive=true时,derived_rules不能为空");
|
||||
}
|
||||
|
||||
List<String> reservedNames = new ArrayList<>(inputCols);
|
||||
reservedNames.add(targetCol);
|
||||
Map<String, Integer> seen = new LinkedHashMap<>();
|
||||
|
||||
for (Map<String, Object> r : rules) {
|
||||
String name = toNonBlankString(r.get("name"));
|
||||
String expr = toNonBlankString(r.get("expr"));
|
||||
if (name == null || expr == null) {
|
||||
throw new BizException("derived_rules中每条规则的name/expr不能为空");
|
||||
}
|
||||
if (reservedNames.contains(name)) {
|
||||
throw new BizException("derived_rules.name不能与input_cols/target_col重名: " + name);
|
||||
}
|
||||
int cnt = seen.getOrDefault(name, 0) + 1;
|
||||
seen.put(name, cnt);
|
||||
if (cnt > 1) {
|
||||
throw new BizException("derived_rules.name重复: " + name);
|
||||
}
|
||||
if (!DERIVED_EXPR_ALLOWED_PATTERN.matcher(expr).matches()) {
|
||||
throw new BizException("derived_rules.expr包含非法字符: " + name);
|
||||
}
|
||||
|
||||
Matcher m = IDENTIFIER_PATTERN.matcher(expr);
|
||||
List<String> vars = new ArrayList<>();
|
||||
while (m.find()) {
|
||||
vars.add(m.group());
|
||||
}
|
||||
List<String> unknown = new ArrayList<>();
|
||||
for (String v : vars) {
|
||||
if (!inputCols.contains(v)) {
|
||||
unknown.add(v);
|
||||
}
|
||||
}
|
||||
if (!unknown.isEmpty()) {
|
||||
throw new BizException("derived_rules.expr包含未知变量(必须来自input_cols): " + name + " -> " + String.join(", ", unknown));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private String toNonBlankString(Object v) {
|
||||
if (v == null) return null;
|
||||
String s = String.valueOf(v).trim();
|
||||
return s.isEmpty() ? null : s;
|
||||
}
|
||||
|
||||
private List<String> toStringList(Object v) {
|
||||
if (v == null) return null;
|
||||
if (v instanceof List<?> list) {
|
||||
List<String> out = new ArrayList<>();
|
||||
for (Object x : list) {
|
||||
if (x == null) continue;
|
||||
String s = String.valueOf(x).trim();
|
||||
if (!s.isEmpty()) out.add(s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private List<Map<String, Object>> toMapList(Object v) {
|
||||
if (v == null) return null;
|
||||
if (v instanceof List<?> list) {
|
||||
List<Map<String, Object>> out = new ArrayList<>();
|
||||
for (Object x : list) {
|
||||
if (x instanceof Map<?, ?> m) {
|
||||
Map<String, Object> mm = new HashMap<>();
|
||||
for (Map.Entry<?, ?> e : m.entrySet()) {
|
||||
if (e.getKey() != null) {
|
||||
mm.put(String.valueOf(e.getKey()), e.getValue());
|
||||
}
|
||||
}
|
||||
out.add(mm);
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private List<String> parseExcelHeader(File file) {
|
||||
DataFormatter df = new DataFormatter();
|
||||
try (Workbook wb = WorkbookFactory.create(file)) {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user