From e1f770c79518c36b4efaf95d694684291f663cb7 Mon Sep 17 00:00:00 2001 From: wanxiaoli Date: Fri, 17 Apr 2026 11:13:18 +0800 Subject: [PATCH] =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E8=AE=AD=E7=BB=83=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=E8=B0=83=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../service/impl/ModelTrainServiceImpl.java | 102 ++++++++++++++++++ 1 file changed, 102 insertions(+) diff --git a/business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java b/business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java index 2bedfcf..a62594b 100644 --- a/business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java +++ b/business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java @@ -21,6 +21,7 @@ import org.springframework.http.ResponseEntity; import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; +import org.springframework.web.client.HttpStatusCodeException; import org.springframework.web.client.RestTemplate; import org.springframework.web.multipart.MultipartFile; import org.apache.poi.ss.usermodel.DataFormatter; @@ -222,6 +223,7 @@ public class ModelTrainServiceImpl extends ServiceImpl fmMap) { + if (fmMap == null || fmMap.isEmpty()) return; + if (!Boolean.TRUE.equals(fmMap.get("derive"))) return; + Object rulesObj = fmMap.get("derived_rules"); + if (rulesObj instanceof String) { + List> rules = toMapList(rulesObj); + if (rules != null) { + fmMap.put("derived_rules", rules); + } + } + } + private List> validateUploadColumns(String datasetPath, List columns) { if (datasetPath == null || datasetPath.isBlank()) { throw new BizException("数据集路径不能为空"); @@ -409,6 +433,15 @@ public class ModelTrainServiceImpl extends ServiceImpl> toMapList(Object v) { if (v == null) return null; + if (v instanceof String s) { + String t = s.trim(); + if (t.isEmpty()) return null; + try { + return objectMapper.readValue(t, new TypeReference>>() {}); + } catch (Exception ignored) { + return null; + } + } if (v instanceof List list) { List> out = new ArrayList<>(); for (Object x : list) { @@ -509,6 +542,7 @@ public class ModelTrainServiceImpl extends ServiceImpl callbackData) { + if (callbackData == null || callbackData.isEmpty()) return; + + Object fmObj = callbackData.get("feature_map"); + if (fmObj == null) return; + + Map featureMap = null; + if (fmObj instanceof String s) { + String t = s.trim(); + if (!t.isEmpty()) { + try { + featureMap = objectMapper.readValue(t, new TypeReference>() {}); + } catch (Exception ignored) { + } + } + } else if (fmObj instanceof Map m) { + featureMap = new HashMap<>(); + for (Map.Entry e : m.entrySet()) { + if (e.getKey() != null) { + featureMap.put(String.valueOf(e.getKey()), e.getValue()); + } + } + } + if (featureMap == null || featureMap.isEmpty()) return; + + normalizeFeatureMapConfig(featureMap); + + boolean derive = Boolean.TRUE.equals(featureMap.get("derive")); + if (derive) { + Object rulesObj = featureMap.get("derived_rules"); + List> rules = toMapList(rulesObj); + if (rules == null || rules.isEmpty()) { + Map fmMap = null; + if (task.getFeatureMapConfig() != null && !task.getFeatureMapConfig().isBlank()) { + try { + fmMap = objectMapper.readValue(task.getFeatureMapConfig(), new TypeReference>() {}); + } catch (Exception ignored) { + } + } + if (fmMap != null && !fmMap.isEmpty()) { + normalizeFeatureMapConfig(fmMap); + rules = toMapList(fmMap.get("derived_rules")); + if (rules != null && !rules.isEmpty()) { + featureMap.put("derived_rules", rules); + } + } + } else { + featureMap.put("derived_rules", rules); + } + + Object derivedColsObj = featureMap.get("derived_cols"); + boolean hasDerivedCols = false; + if (derivedColsObj instanceof List list) { + hasDerivedCols = !list.isEmpty(); + } + if (!hasDerivedCols && rules != null && !rules.isEmpty()) { + List derivedCols = new ArrayList<>(); + for (Map r : rules) { + String name = toNonBlankString(r.get("name")); + if (name != null) derivedCols.add(name); + } + featureMap.put("derived_cols", derivedCols); + } + } + + callbackData.put("feature_map", featureMap); + } + @Override @Transactional public boolean publishModel(String taskId, String versionTag) {