模型训练代码调整

This commit is contained in:
wanxiaoli 2026-04-17 11:13:18 +08:00
parent bfb3d234eb
commit e1f770c795

View File

@ -21,6 +21,7 @@ import org.springframework.http.ResponseEntity;
import org.springframework.scheduling.annotation.Async; import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import org.springframework.web.client.HttpStatusCodeException;
import org.springframework.web.client.RestTemplate; import org.springframework.web.client.RestTemplate;
import org.springframework.web.multipart.MultipartFile; import org.springframework.web.multipart.MultipartFile;
import org.apache.poi.ss.usermodel.DataFormatter; import org.apache.poi.ss.usermodel.DataFormatter;
@ -222,6 +223,7 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
if (fmMap == null || fmMap.isEmpty()) { if (fmMap == null || fmMap.isEmpty()) {
throw new BizException("feature_map_config不能为空"); throw new BizException("feature_map_config不能为空");
} }
normalizeFeatureMapConfig(fmMap);
request.put("feature_map_config", fmMap); request.put("feature_map_config", fmMap);
// System.out.println("request="+request.toString()); // System.out.println("request="+request.toString());
@ -264,6 +266,16 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
this.updateById(task); this.updateById(task);
} }
} catch (HttpStatusCodeException e) {
task.setStatus("Failed");
String body = null;
try {
body = e.getResponseBodyAsString();
} catch (Exception ignored) {
}
task.setErrorLog("调用 Python 服务异常: " + e.getStatusCode() + (body == null || body.isBlank() ? "" : (": " + body)));
this.updateById(task);
log.error("调用 Python 训练服务异常: {} {}", e.getStatusCode(), body, e);
} catch (Exception e) { } catch (Exception e) {
task.setStatus("Failed"); task.setStatus("Failed");
task.setErrorLog("调用 Python 服务异常: " + e.getMessage()); task.setErrorLog("调用 Python 服务异常: " + e.getMessage());
@ -272,6 +284,18 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
} }
} }
private void normalizeFeatureMapConfig(Map<String, Object> fmMap) {
if (fmMap == null || fmMap.isEmpty()) return;
if (!Boolean.TRUE.equals(fmMap.get("derive"))) return;
Object rulesObj = fmMap.get("derived_rules");
if (rulesObj instanceof String) {
List<Map<String, Object>> rules = toMapList(rulesObj);
if (rules != null) {
fmMap.put("derived_rules", rules);
}
}
}
private List<Map<String, Object>> validateUploadColumns(String datasetPath, List<String> columns) { private List<Map<String, Object>> validateUploadColumns(String datasetPath, List<String> columns) {
if (datasetPath == null || datasetPath.isBlank()) { if (datasetPath == null || datasetPath.isBlank()) {
throw new BizException("数据集路径不能为空"); throw new BizException("数据集路径不能为空");
@ -409,6 +433,15 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
private List<Map<String, Object>> toMapList(Object v) { private List<Map<String, Object>> toMapList(Object v) {
if (v == null) return null; if (v == null) return null;
if (v instanceof String s) {
String t = s.trim();
if (t.isEmpty()) return null;
try {
return objectMapper.readValue(t, new TypeReference<List<Map<String, Object>>>() {});
} catch (Exception ignored) {
return null;
}
}
if (v instanceof List<?> list) { if (v instanceof List<?> list) {
List<Map<String, Object>> out = new ArrayList<>(); List<Map<String, Object>> out = new ArrayList<>();
for (Object x : list) { for (Object x : list) {
@ -509,6 +542,7 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
} }
try { try {
normalizeCallbackFeatureMap(task, callbackData);
String status = (String) callbackData.get("status"); String status = (String) callbackData.get("status");
if (status != null) { if (status != null) {
// 转换状态为首字母大写 // 转换状态为首字母大写
@ -559,6 +593,74 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
} }
} }
private void normalizeCallbackFeatureMap(ModelTrainTask task, Map<String, Object> callbackData) {
if (callbackData == null || callbackData.isEmpty()) return;
Object fmObj = callbackData.get("feature_map");
if (fmObj == null) return;
Map<String, Object> featureMap = null;
if (fmObj instanceof String s) {
String t = s.trim();
if (!t.isEmpty()) {
try {
featureMap = objectMapper.readValue(t, new TypeReference<Map<String, Object>>() {});
} catch (Exception ignored) {
}
}
} else if (fmObj instanceof Map<?, ?> m) {
featureMap = new HashMap<>();
for (Map.Entry<?, ?> e : m.entrySet()) {
if (e.getKey() != null) {
featureMap.put(String.valueOf(e.getKey()), e.getValue());
}
}
}
if (featureMap == null || featureMap.isEmpty()) return;
normalizeFeatureMapConfig(featureMap);
boolean derive = Boolean.TRUE.equals(featureMap.get("derive"));
if (derive) {
Object rulesObj = featureMap.get("derived_rules");
List<Map<String, Object>> rules = toMapList(rulesObj);
if (rules == null || rules.isEmpty()) {
Map<String, Object> fmMap = null;
if (task.getFeatureMapConfig() != null && !task.getFeatureMapConfig().isBlank()) {
try {
fmMap = objectMapper.readValue(task.getFeatureMapConfig(), new TypeReference<Map<String, Object>>() {});
} catch (Exception ignored) {
}
}
if (fmMap != null && !fmMap.isEmpty()) {
normalizeFeatureMapConfig(fmMap);
rules = toMapList(fmMap.get("derived_rules"));
if (rules != null && !rules.isEmpty()) {
featureMap.put("derived_rules", rules);
}
}
} else {
featureMap.put("derived_rules", rules);
}
Object derivedColsObj = featureMap.get("derived_cols");
boolean hasDerivedCols = false;
if (derivedColsObj instanceof List<?> list) {
hasDerivedCols = !list.isEmpty();
}
if (!hasDerivedCols && rules != null && !rules.isEmpty()) {
List<String> derivedCols = new ArrayList<>();
for (Map<String, Object> r : rules) {
String name = toNonBlankString(r.get("name"));
if (name != null) derivedCols.add(name);
}
featureMap.put("derived_cols", derivedCols);
}
}
callbackData.put("feature_map", featureMap);
}
@Override @Override
@Transactional @Transactional
public boolean publishModel(String taskId, String versionTag) { public boolean publishModel(String taskId, String versionTag) {