diff --git a/business-css/pom.xml b/business-css/pom.xml index 9a76a68..1a1f02d 100644 --- a/business-css/pom.xml +++ b/business-css/pom.xml @@ -134,6 +134,79 @@ + + + + com.github.eirslett + frontend-maven-plugin + 1.15.0 + + + frontend + + + + + install node and npm + + install-node-and-npm + + generate-resources + + + v18.17.0 + 9.6.7 + + + + + npm install + + npm + + generate-resources + + install + + + + + npm run build + + npm + + generate-resources + + run build:mvn + + + + + + + + org.apache.maven.plugins + maven-resources-plugin + + + copy-frontend-dist + process-resources + + copy-resources + + + ${project.build.outputDirectory}/static + + + frontend/dist + false + + + + + + + diff --git a/business-css/src/main/java/com/yfd/business/css/config/WebSocketBrokerConfig.java b/business-css/src/main/java/com/yfd/business/css/config/WebSocketBrokerConfig.java new file mode 100644 index 0000000..da7f093 --- /dev/null +++ b/business-css/src/main/java/com/yfd/business/css/config/WebSocketBrokerConfig.java @@ -0,0 +1,28 @@ +package com.yfd.business.css.config; + +import org.springframework.context.annotation.Configuration; +import org.springframework.messaging.simp.config.MessageBrokerRegistry; +import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker; +import org.springframework.web.socket.config.annotation.StompEndpointRegistry; +import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer; + +@Configuration +@EnableWebSocketMessageBroker +public class WebSocketBrokerConfig implements WebSocketMessageBrokerConfigurer { + + @Override + public void registerStompEndpoints(StompEndpointRegistry registry) { + // 前端连接的端点,支持跨域 + registry.addEndpoint("/ws/train") + .setAllowedOriginPatterns("*") + .withSockJS(); + } + + @Override + public void configureMessageBroker(MessageBrokerRegistry registry) { + // 客户端订阅的路径前缀 + registry.enableSimpleBroker("/topic"); + // 客户端发送消息的路径前缀 + registry.setApplicationDestinationPrefixes("/app"); + } +} diff --git a/business-css/src/main/java/com/yfd/business/css/controller/ModelTrainController.java b/business-css/src/main/java/com/yfd/business/css/controller/ModelTrainController.java index 3944edd..0b9645e 100644 --- a/business-css/src/main/java/com/yfd/business/css/controller/ModelTrainController.java +++ b/business-css/src/main/java/com/yfd/business/css/controller/ModelTrainController.java @@ -6,6 +6,7 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.yfd.business.css.domain.ModelTrainTask; import com.yfd.business.css.service.ModelTrainService; +import com.yfd.business.css.service.TrainWebSocketService; import com.yfd.platform.config.ResponseResult; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.*; @@ -21,9 +22,34 @@ public class ModelTrainController { @Autowired private ModelTrainService modelTrainService; + @Autowired + private TrainWebSocketService trainWebSocketService; + @Autowired private ObjectMapper objectMapper; + /** + * 接收 Python 端的训练状态回调 + */ + @PostMapping("/internal/callback") + public ResponseResult handleTrainCallback(@RequestBody Map callbackData) { + System.out.println("====== 收到 Python 端训练回调 ======"); + System.out.println("回调数据: " + callbackData); + + String taskId = (String) callbackData.get("taskId"); + if (taskId == null) { + return ResponseResult.error("taskId不能为空"); + } + + // 1. 更新数据库任务状态 + modelTrainService.updateTaskStatusFromCallback(taskId, callbackData); + + // 2. 触发 WebSocket 推送给前端 + trainWebSocketService.sendTrainStatus(taskId, callbackData); + + return ResponseResult.success(); + } + /** * 上传数据集 */ diff --git a/business-css/src/main/java/com/yfd/business/css/service/ModelTrainService.java b/business-css/src/main/java/com/yfd/business/css/service/ModelTrainService.java index 289d121..4f16b03 100644 --- a/business-css/src/main/java/com/yfd/business/css/service/ModelTrainService.java +++ b/business-css/src/main/java/com/yfd/business/css/service/ModelTrainService.java @@ -26,6 +26,13 @@ public interface ModelTrainService extends IService { */ ModelTrainTask syncTaskStatus(String taskId); + /** + * 从回调中更新任务状态 + * @param taskId 任务ID + * @param callbackData 回调数据 + */ + void updateTaskStatusFromCallback(String taskId, java.util.Map callbackData); + /** * 发布模型 * @param taskId 任务ID diff --git a/business-css/src/main/java/com/yfd/business/css/service/TrainWebSocketService.java b/business-css/src/main/java/com/yfd/business/css/service/TrainWebSocketService.java new file mode 100644 index 0000000..e5a5937 --- /dev/null +++ b/business-css/src/main/java/com/yfd/business/css/service/TrainWebSocketService.java @@ -0,0 +1,32 @@ +package com.yfd.business.css.service; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.messaging.simp.SimpMessagingTemplate; +import org.springframework.stereotype.Service; +import lombok.extern.slf4j.Slf4j; +import java.util.Map; + +@Slf4j +@Service +public class TrainWebSocketService { + + @Autowired + private SimpMessagingTemplate messagingTemplate; + + /** + * 向前端推送模型训练状态 + * + * @param taskId 任务ID + * @param data 状态数据 + */ + public void sendTrainStatus(String taskId, Map data) { + // 1. 细粒度推送(供详情页使用) + String specificDestination = "/topic/train-status/" + taskId; + messagingTemplate.convertAndSend(specificDestination, data); + + // 2. 全局广播推送(供列表页使用) + String globalDestination = "/topic/train-status/all"; + log.info("全局广播训练状态到 {}, 数据: {}", globalDestination, data); + messagingTemplate.convertAndSend(globalDestination, data); + } +} diff --git a/business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java b/business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java index 0963ad6..33f2b56 100644 --- a/business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java +++ b/business-css/src/main/java/com/yfd/business/css/service/impl/ModelTrainServiceImpl.java @@ -206,92 +206,67 @@ public class ModelTrainServiceImpl extends ServiceImpl response = restTemplate.getForEntity(url, Map.class); - - if (response.getStatusCode().is2xxSuccessful() && response.getBody() != null) { - Map body = response.getBody(); - Object codeObj = body.get("code"); - Integer code = null; - if (codeObj instanceof Number) { - code = ((Number) codeObj).intValue(); - } else if (codeObj instanceof String) { - try { - code = Integer.parseInt((String) codeObj); - } catch (Exception ignored) { - } + // 由于改为 WebSocket 异步推送,这里简化为直接查库返回 + return task; + } + + @Override + public void updateTaskStatusFromCallback(String taskId, Map callbackData) { + ModelTrainTask task = this.getById(taskId); + if (task == null) { + log.warn("回调通知的任务不存在, taskId: {}", taskId); + return; + } + + try { + String status = (String) callbackData.get("status"); + if (status != null) { + // 转换状态为首字母大写 + if (status.equalsIgnoreCase("Success")) { + status = "Success"; + } else if (status.equalsIgnoreCase("Failed")) { + status = "Failed"; + } else if (status.equalsIgnoreCase("Training")) { + status = "Training"; + } else if (status.equalsIgnoreCase("Pending")) { + status = "Pending"; + } + + task.setStatus(status); + + if ("Success".equals(status)) { + String modelPathRaw = firstNonBlank( + (String) callbackData.get("model_path"), + (String) callbackData.get("model_relative_path") + ); + task.setModelOutputPath(normalizeModelPath(modelPathRaw)); + + // Map -> JSON String + if (callbackData.get("metrics") != null) { + task.setMetrics(objectMapper.writeValueAsString(callbackData.get("metrics"))); } - if (code != null && code != 0) { - task.setStatus("Failed"); - task.setErrorLog(String.valueOf(body.get("msg"))); - this.updateById(task); - return task; + if (callbackData.get("feature_map") != null) { + task.setFeatureMapSnapshot(objectMapper.writeValueAsString(callbackData.get("feature_map"))); } - Object dataObj = body.get("data"); - if (!(dataObj instanceof Map)) { - return task; - } - Map data = (Map) dataObj; - String status = (String) data.get("status"); - - if (status != null) { - // 转换状态为首字母大写 - if (status.equalsIgnoreCase("Success")) { - status = "Success"; - } else if (status.equalsIgnoreCase("Failed")) { - status = "Failed"; - } else if (status.equalsIgnoreCase("Training")) { - status = "Training"; - } else if (status.equalsIgnoreCase("Pending")) { - status = "Pending"; - } - - task.setStatus(status); - - if ("Success".equals(status)) { - String modelPathRaw = firstNonBlank( - (String) data.get("model_relative_path"), - (String) data.get("model_path_rel_project"), - (String) data.get("model_path") - ); - task.setModelOutputPath(normalizeModelPath(modelPathRaw)); - - // Map -> JSON String - if (data.get("metrics") != null) { - task.setMetrics(objectMapper.writeValueAsString(data.get("metrics"))); - } - if (data.get("feature_map") != null) { - task.setFeatureMapSnapshot(objectMapper.writeValueAsString(data.get("feature_map"))); - } - - String metricsPathRaw = firstNonBlank( - (String) data.get("metrics_image_relative_path"), - (String) data.get("metrics_image_rel_project"), - (String) data.get("metrics_image") - ); - task.setMetricsImagePath(normalizeModelPath(metricsPathRaw)); - } else if ("Failed".equals(status)) { - task.setErrorLog((String) data.get("error")); - } else if ("Training".equals(status)) { - if (data.get("metrics") != null) { - task.setMetrics(objectMapper.writeValueAsString(data.get("metrics"))); - } - } - - this.updateById(task); + String metricsPathRaw = firstNonBlank( + (String) callbackData.get("metrics_image"), + (String) callbackData.get("metrics_image_relative_path") + ); + task.setMetricsImagePath(normalizeModelPath(metricsPathRaw)); + } else if ("Failed".equals(status)) { + task.setErrorLog((String) callbackData.get("message")); + } else if ("Training".equals(status)) { + if (callbackData.get("metrics") != null) { + task.setMetrics(objectMapper.writeValueAsString(callbackData.get("metrics"))); } } - } catch (Exception e) { - log.error("同步任务状态失败: {}", e.getMessage(), e); + + this.updateById(task); } + } catch (Exception e) { + log.error("从回调更新任务状态失败: {}", e.getMessage(), e); } - - return task; } @Override diff --git a/business-css/src/main/java/com/yfd/business/css/service/impl/ProjectServiceImpl.java b/business-css/src/main/java/com/yfd/business/css/service/impl/ProjectServiceImpl.java index d90f254..b618945 100644 --- a/business-css/src/main/java/com/yfd/business/css/service/impl/ProjectServiceImpl.java +++ b/business-css/src/main/java/com/yfd/business/css/service/impl/ProjectServiceImpl.java @@ -81,7 +81,7 @@ public class ProjectServiceImpl Sheet sheet = wb.createSheet("projects"); int r = 0; Row header = sheet.createRow(r++); - String[] cols = {"project_id","code","name","description","topology","created_at","updated_at","modifier"}; + String[] cols = {"项目id","项目编号","项目名称","项目描述","项目建模拓扑","创建时间","修改时间","创建人"}; for (int i = 0; i < cols.length; i++) header.createCell(i).setCellValue(cols[i]); DateTimeFormatter fmt = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"); for (Project p : list) { diff --git a/framework/src/main/java/com/yfd/platform/config/SecurityConfig.java b/framework/src/main/java/com/yfd/platform/config/SecurityConfig.java index 4c9710b..fe6003d 100644 --- a/framework/src/main/java/com/yfd/platform/config/SecurityConfig.java +++ b/framework/src/main/java/com/yfd/platform/config/SecurityConfig.java @@ -64,6 +64,7 @@ public class SecurityConfig { .requestMatchers(HttpMethod.GET, "/*.html", "/webSocket/**", + "/ws/**", "/assets/**", "/icon/**").permitAll() .requestMatchers( @@ -80,6 +81,7 @@ public class SecurityConfig { "/pageimage/**", "/avatar/**", "/systemurl/**", + "/train/internal/callback", "/api/imageserver/upload").permitAll() .anyRequest().authenticated(); }