diff --git a/business-css/pom.xml b/business-css/pom.xml
index 9a76a68..1a1f02d 100644
--- a/business-css/pom.xml
+++ b/business-css/pom.xml
@@ -134,6 +134,79 @@
+
+
+
+ com.github.eirslett
+ frontend-maven-plugin
+ 1.15.0
+
+
+ frontend
+
+
+
+
+ install node and npm
+
+ install-node-and-npm
+
+ generate-resources
+
+
+ v18.17.0
+ 9.6.7
+
+
+
+
+ npm install
+
+ npm
+
+ generate-resources
+
+ install
+
+
+
+
+ npm run build
+
+ npm
+
+ generate-resources
+
+ run build:mvn
+
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-resources-plugin
+
+
+ copy-frontend-dist
+ process-resources
+
+ copy-resources
+
+
+ ${project.build.outputDirectory}/static
+
+
+ frontend/dist
+ false
+
+
+
+
+
+
+
diff --git a/business-css/src/main/java/com/yfd/business/css/config/WebSocketBrokerConfig.java b/business-css/src/main/java/com/yfd/business/css/config/WebSocketBrokerConfig.java
new file mode 100644
index 0000000..da7f093
--- /dev/null
+++ b/business-css/src/main/java/com/yfd/business/css/config/WebSocketBrokerConfig.java
@@ -0,0 +1,28 @@
+package com.yfd.business.css.config;
+
+import org.springframework.context.annotation.Configuration;
+import org.springframework.messaging.simp.config.MessageBrokerRegistry;
+import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
+import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
+import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
+
+@Configuration
+@EnableWebSocketMessageBroker
+public class WebSocketBrokerConfig implements WebSocketMessageBrokerConfigurer {
+
+ @Override
+ public void registerStompEndpoints(StompEndpointRegistry registry) {
+ // 前端连接的端点,支持跨域
+ registry.addEndpoint("/ws/train")
+ .setAllowedOriginPatterns("*")
+ .withSockJS();
+ }
+
+ @Override
+ public void configureMessageBroker(MessageBrokerRegistry registry) {
+ // 客户端订阅的路径前缀
+ registry.enableSimpleBroker("/topic");
+ // 客户端发送消息的路径前缀
+ registry.setApplicationDestinationPrefixes("/app");
+ }
+}
diff --git a/business-css/src/main/java/com/yfd/business/css/controller/ModelTrainController.java b/business-css/src/main/java/com/yfd/business/css/controller/ModelTrainController.java
index 3944edd..0b9645e 100644
--- a/business-css/src/main/java/com/yfd/business/css/controller/ModelTrainController.java
+++ b/business-css/src/main/java/com/yfd/business/css/controller/ModelTrainController.java
@@ -6,6 +6,7 @@ import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.yfd.business.css.domain.ModelTrainTask;
import com.yfd.business.css.service.ModelTrainService;
+import com.yfd.business.css.service.TrainWebSocketService;
import com.yfd.platform.config.ResponseResult;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
@@ -21,9 +22,34 @@ public class ModelTrainController {
@Autowired
private ModelTrainService modelTrainService;
+ @Autowired
+ private TrainWebSocketService trainWebSocketService;
+
@Autowired
private ObjectMapper objectMapper;
+ /**
+ * 接收 Python 端的训练状态回调
+ */
+ @PostMapping("/internal/callback")
+ public ResponseResult handleTrainCallback(@RequestBody Map callbackData) {
+ System.out.println("====== 收到 Python 端训练回调 ======");
+ System.out.println("回调数据: " + callbackData);
+
+ String taskId = (String) callbackData.get("taskId");
+ if (taskId == null) {
+ return ResponseResult.error("taskId不能为空");
+ }
+
+ // 1. 更新数据库任务状态
+ modelTrainService.updateTaskStatusFromCallback(taskId, callbackData);
+
+ // 2. 触发 WebSocket 推送给前端
+ trainWebSocketService.sendTrainStatus(taskId, callbackData);
+
+ return ResponseResult.success();
+ }
+
/**
* 上传数据集
*/
diff --git a/business-css/src/main/java/com/yfd/business/css/service/ModelTrainService.java b/business-css/src/main/java/com/yfd/business/css/service/ModelTrainService.java
index 289d121..4f16b03 100644
--- a/business-css/src/main/java/com/yfd/business/css/service/ModelTrainService.java
+++ b/business-css/src/main/java/com/yfd/business/css/service/ModelTrainService.java
@@ -26,6 +26,13 @@ public interface ModelTrainService extends IService {
*/
ModelTrainTask syncTaskStatus(String taskId);
+ /**
+ * 从回调中更新任务状态
+ * @param taskId 任务ID
+ * @param callbackData 回调数据
+ */
+ void updateTaskStatusFromCallback(String taskId, java.util.Map callbackData);
+
/**
* 发布模型
* @param taskId 任务ID
diff --git a/business-css/src/main/java/com/yfd/business/css/service/TrainWebSocketService.java b/business-css/src/main/java/com/yfd/business/css/service/TrainWebSocketService.java
new file mode 100644
index 0000000..e5a5937
--- /dev/null
+++ b/business-css/src/main/java/com/yfd/business/css/service/TrainWebSocketService.java
@@ -0,0 +1,32 @@
+package com.yfd.business.css.service;
+
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.messaging.simp.SimpMessagingTemplate;
+import org.springframework.stereotype.Service;
+import lombok.extern.slf4j.Slf4j;
+import java.util.Map;
+
+@Slf4j
+@Service
+public class TrainWebSocketService {
+
+ @Autowired
+ private SimpMessagingTemplate messagingTemplate;
+
+ /**
+ * 向前端推送模型训练状态
+ *
+ * @param taskId 任务ID
+ * @param data 状态数据
+ */
+ public void sendTrainStatus(String taskId, Map data) {
+ // 1. 细粒度推送(供详情页使用)
+ String specificDestination = "/topic/train-status/" + taskId;
+ messagingTemplate.convertAndSend(specificDestination, data);
+
+ // 2. 全局广播推送(供列表页使用)
+ String globalDestination = "/topic/train-status/all";
+ log.info("全局广播训练状态到 {}, 数据: {}", globalDestination, data);
+ messagingTemplate.convertAndSend(globalDestination, data);
+ }
+}
diff --git a/business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java b/business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java
index 0963ad6..33f2b56 100644
--- a/business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java
+++ b/business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java
@@ -206,92 +206,67 @@ public class ModelTrainServiceImpl extends ServiceImpl response = restTemplate.getForEntity(url, Map.class);
-
- if (response.getStatusCode().is2xxSuccessful() && response.getBody() != null) {
- Map body = response.getBody();
- Object codeObj = body.get("code");
- Integer code = null;
- if (codeObj instanceof Number) {
- code = ((Number) codeObj).intValue();
- } else if (codeObj instanceof String) {
- try {
- code = Integer.parseInt((String) codeObj);
- } catch (Exception ignored) {
- }
+ // 由于改为 WebSocket 异步推送,这里简化为直接查库返回
+ return task;
+ }
+
+ @Override
+ public void updateTaskStatusFromCallback(String taskId, Map callbackData) {
+ ModelTrainTask task = this.getById(taskId);
+ if (task == null) {
+ log.warn("回调通知的任务不存在, taskId: {}", taskId);
+ return;
+ }
+
+ try {
+ String status = (String) callbackData.get("status");
+ if (status != null) {
+ // 转换状态为首字母大写
+ if (status.equalsIgnoreCase("Success")) {
+ status = "Success";
+ } else if (status.equalsIgnoreCase("Failed")) {
+ status = "Failed";
+ } else if (status.equalsIgnoreCase("Training")) {
+ status = "Training";
+ } else if (status.equalsIgnoreCase("Pending")) {
+ status = "Pending";
+ }
+
+ task.setStatus(status);
+
+ if ("Success".equals(status)) {
+ String modelPathRaw = firstNonBlank(
+ (String) callbackData.get("model_path"),
+ (String) callbackData.get("model_relative_path")
+ );
+ task.setModelOutputPath(normalizeModelPath(modelPathRaw));
+
+ // Map -> JSON String
+ if (callbackData.get("metrics") != null) {
+ task.setMetrics(objectMapper.writeValueAsString(callbackData.get("metrics")));
}
- if (code != null && code != 0) {
- task.setStatus("Failed");
- task.setErrorLog(String.valueOf(body.get("msg")));
- this.updateById(task);
- return task;
+ if (callbackData.get("feature_map") != null) {
+ task.setFeatureMapSnapshot(objectMapper.writeValueAsString(callbackData.get("feature_map")));
}
- Object dataObj = body.get("data");
- if (!(dataObj instanceof Map)) {
- return task;
- }
- Map data = (Map) dataObj;
- String status = (String) data.get("status");
-
- if (status != null) {
- // 转换状态为首字母大写
- if (status.equalsIgnoreCase("Success")) {
- status = "Success";
- } else if (status.equalsIgnoreCase("Failed")) {
- status = "Failed";
- } else if (status.equalsIgnoreCase("Training")) {
- status = "Training";
- } else if (status.equalsIgnoreCase("Pending")) {
- status = "Pending";
- }
-
- task.setStatus(status);
-
- if ("Success".equals(status)) {
- String modelPathRaw = firstNonBlank(
- (String) data.get("model_relative_path"),
- (String) data.get("model_path_rel_project"),
- (String) data.get("model_path")
- );
- task.setModelOutputPath(normalizeModelPath(modelPathRaw));
-
- // Map -> JSON String
- if (data.get("metrics") != null) {
- task.setMetrics(objectMapper.writeValueAsString(data.get("metrics")));
- }
- if (data.get("feature_map") != null) {
- task.setFeatureMapSnapshot(objectMapper.writeValueAsString(data.get("feature_map")));
- }
-
- String metricsPathRaw = firstNonBlank(
- (String) data.get("metrics_image_relative_path"),
- (String) data.get("metrics_image_rel_project"),
- (String) data.get("metrics_image")
- );
- task.setMetricsImagePath(normalizeModelPath(metricsPathRaw));
- } else if ("Failed".equals(status)) {
- task.setErrorLog((String) data.get("error"));
- } else if ("Training".equals(status)) {
- if (data.get("metrics") != null) {
- task.setMetrics(objectMapper.writeValueAsString(data.get("metrics")));
- }
- }
-
- this.updateById(task);
+ String metricsPathRaw = firstNonBlank(
+ (String) callbackData.get("metrics_image"),
+ (String) callbackData.get("metrics_image_relative_path")
+ );
+ task.setMetricsImagePath(normalizeModelPath(metricsPathRaw));
+ } else if ("Failed".equals(status)) {
+ task.setErrorLog((String) callbackData.get("message"));
+ } else if ("Training".equals(status)) {
+ if (callbackData.get("metrics") != null) {
+ task.setMetrics(objectMapper.writeValueAsString(callbackData.get("metrics")));
}
}
- } catch (Exception e) {
- log.error("同步任务状态失败: {}", e.getMessage(), e);
+
+ this.updateById(task);
}
+ } catch (Exception e) {
+ log.error("从回调更新任务状态失败: {}", e.getMessage(), e);
}
-
- return task;
}
@Override
diff --git a/business-css/src/main/java/com/yfd/business/css/service/impl/ProjectServiceImpl.java b/business-css/src/main/java/com/yfd/business/css/service/impl/ProjectServiceImpl.java
index d90f254..b618945 100644
--- a/business-css/src/main/java/com/yfd/business/css/service/impl/ProjectServiceImpl.java
+++ b/business-css/src/main/java/com/yfd/business/css/service/impl/ProjectServiceImpl.java
@@ -81,7 +81,7 @@ public class ProjectServiceImpl
Sheet sheet = wb.createSheet("projects");
int r = 0;
Row header = sheet.createRow(r++);
- String[] cols = {"project_id","code","name","description","topology","created_at","updated_at","modifier"};
+ String[] cols = {"项目id","项目编号","项目名称","项目描述","项目建模拓扑","创建时间","修改时间","创建人"};
for (int i = 0; i < cols.length; i++) header.createCell(i).setCellValue(cols[i]);
DateTimeFormatter fmt = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");
for (Project p : list) {
diff --git a/framework/src/main/java/com/yfd/platform/config/SecurityConfig.java b/framework/src/main/java/com/yfd/platform/config/SecurityConfig.java
index 4c9710b..fe6003d 100644
--- a/framework/src/main/java/com/yfd/platform/config/SecurityConfig.java
+++ b/framework/src/main/java/com/yfd/platform/config/SecurityConfig.java
@@ -64,6 +64,7 @@ public class SecurityConfig {
.requestMatchers(HttpMethod.GET,
"/*.html",
"/webSocket/**",
+ "/ws/**",
"/assets/**",
"/icon/**").permitAll()
.requestMatchers(
@@ -80,6 +81,7 @@ public class SecurityConfig {
"/pageimage/**",
"/avatar/**",
"/systemurl/**",
+ "/train/internal/callback",
"/api/imageserver/upload").permitAll()
.anyRequest().authenticated();
}