后端更改

This commit is contained in:
wanxiaoli 2026-03-26 16:55:28 +08:00
parent 2800d02243
commit b38ea4ad65
8 changed files with 223 additions and 80 deletions

View File

@ -134,6 +134,79 @@
</annotationProcessorPaths> </annotationProcessorPaths>
</configuration> </configuration>
</plugin> </plugin>
<!-- frontend-maven-plugin 插件:用于在 Maven 构建时自动编译前端并拷贝到 static 目录 -->
<plugin>
<groupId>com.github.eirslett</groupId>
<artifactId>frontend-maven-plugin</artifactId>
<version>1.15.0</version>
<!-- 前端代码所在的目录 -->
<configuration>
<workingDirectory>frontend</workingDirectory>
</configuration>
<executions>
<!-- 1. 安装 Node 和 npm -->
<execution>
<id>install node and npm</id>
<goals>
<goal>install-node-and-npm</goal>
</goals>
<phase>generate-resources</phase>
<configuration>
<!-- 可以根据您机器上的实际 node 版本修改 -->
<nodeVersion>v18.17.0</nodeVersion>
<npmVersion>9.6.7</npmVersion>
</configuration>
</execution>
<!-- 2. 安装依赖 (npm install) -->
<execution>
<id>npm install</id>
<goals>
<goal>npm</goal>
</goals>
<phase>generate-resources</phase>
<configuration>
<arguments>install</arguments>
</configuration>
</execution>
<!-- 3. 执行前端构建命令 (npm run build:mvn) -->
<execution>
<id>npm run build</id>
<goals>
<goal>npm</goal>
</goals>
<phase>generate-resources</phase>
<configuration>
<arguments>run build:mvn</arguments>
</configuration>
</execution>
</executions>
</plugin>
<!-- maven-resources-plugin将前端构建的 dist 目录内容拷贝到 target/classes/static -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-resources-plugin</artifactId>
<executions>
<execution>
<id>copy-frontend-dist</id>
<phase>process-resources</phase>
<goals>
<goal>copy-resources</goal>
</goals>
<configuration>
<outputDirectory>${project.build.outputDirectory}/static</outputDirectory>
<resources>
<resource>
<directory>frontend/dist</directory>
<filtering>false</filtering>
</resource>
</resources>
</configuration>
</execution>
</executions>
</plugin>
</plugins> </plugins>
</build> </build>

View File

@ -0,0 +1,28 @@
package com.yfd.business.css.config;
import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
@Configuration
@EnableWebSocketMessageBroker
public class WebSocketBrokerConfig implements WebSocketMessageBrokerConfigurer {
@Override
public void registerStompEndpoints(StompEndpointRegistry registry) {
// 前端连接的端点支持跨域
registry.addEndpoint("/ws/train")
.setAllowedOriginPatterns("*")
.withSockJS();
}
@Override
public void configureMessageBroker(MessageBrokerRegistry registry) {
// 客户端订阅的路径前缀
registry.enableSimpleBroker("/topic");
// 客户端发送消息的路径前缀
registry.setApplicationDestinationPrefixes("/app");
}
}

View File

@ -6,6 +6,7 @@ import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.yfd.business.css.domain.ModelTrainTask; import com.yfd.business.css.domain.ModelTrainTask;
import com.yfd.business.css.service.ModelTrainService; import com.yfd.business.css.service.ModelTrainService;
import com.yfd.business.css.service.TrainWebSocketService;
import com.yfd.platform.config.ResponseResult; import com.yfd.platform.config.ResponseResult;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
@ -21,9 +22,34 @@ public class ModelTrainController {
@Autowired @Autowired
private ModelTrainService modelTrainService; private ModelTrainService modelTrainService;
@Autowired
private TrainWebSocketService trainWebSocketService;
@Autowired @Autowired
private ObjectMapper objectMapper; private ObjectMapper objectMapper;
/**
* 接收 Python 端的训练状态回调
*/
@PostMapping("/internal/callback")
public ResponseResult handleTrainCallback(@RequestBody Map<String, Object> callbackData) {
System.out.println("====== 收到 Python 端训练回调 ======");
System.out.println("回调数据: " + callbackData);
String taskId = (String) callbackData.get("taskId");
if (taskId == null) {
return ResponseResult.error("taskId不能为空");
}
// 1. 更新数据库任务状态
modelTrainService.updateTaskStatusFromCallback(taskId, callbackData);
// 2. 触发 WebSocket 推送给前端
trainWebSocketService.sendTrainStatus(taskId, callbackData);
return ResponseResult.success();
}
/** /**
* 上传数据集 * 上传数据集
*/ */

View File

@ -26,6 +26,13 @@ public interface ModelTrainService extends IService<ModelTrainTask> {
*/ */
ModelTrainTask syncTaskStatus(String taskId); ModelTrainTask syncTaskStatus(String taskId);
/**
* 从回调中更新任务状态
* @param taskId 任务ID
* @param callbackData 回调数据
*/
void updateTaskStatusFromCallback(String taskId, java.util.Map<String, Object> callbackData);
/** /**
* 发布模型 * 发布模型
* @param taskId 任务ID * @param taskId 任务ID

View File

@ -0,0 +1,32 @@
package com.yfd.business.css.service;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.stereotype.Service;
import lombok.extern.slf4j.Slf4j;
import java.util.Map;
@Slf4j
@Service
public class TrainWebSocketService {
@Autowired
private SimpMessagingTemplate messagingTemplate;
/**
* 向前端推送模型训练状态
*
* @param taskId 任务ID
* @param data 状态数据
*/
public void sendTrainStatus(String taskId, Map<String, Object> data) {
// 1. 细粒度推送供详情页使用
String specificDestination = "/topic/train-status/" + taskId;
messagingTemplate.convertAndSend(specificDestination, data);
// 2. 全局广播推送供列表页使用
String globalDestination = "/topic/train-status/all";
log.info("全局广播训练状态到 {}, 数据: {}", globalDestination, data);
messagingTemplate.convertAndSend(globalDestination, data);
}
}

View File

@ -206,92 +206,67 @@ public class ModelTrainServiceImpl extends ServiceImpl<ModelTrainTaskMapper, Mod
if (task == null) { if (task == null) {
throw new BizException("任务不存在"); throw new BizException("任务不存在");
} }
// 由于改为 WebSocket 异步推送这里简化为直接查库返回
// 只有在 Training Pending 状态才去查询 Python 服务 return task;
if ("Training".equalsIgnoreCase(task.getStatus()) || "Pending".equalsIgnoreCase(task.getStatus())) { }
try {
String url = pythonApiUrl + "/v1/train/status/" + taskId; @Override
ResponseEntity<Map> response = restTemplate.getForEntity(url, Map.class); public void updateTaskStatusFromCallback(String taskId, Map<String, Object> callbackData) {
ModelTrainTask task = this.getById(taskId);
if (response.getStatusCode().is2xxSuccessful() && response.getBody() != null) { if (task == null) {
Map<String, Object> body = response.getBody(); log.warn("回调通知的任务不存在, taskId: {}", taskId);
Object codeObj = body.get("code"); return;
Integer code = null; }
if (codeObj instanceof Number) {
code = ((Number) codeObj).intValue(); try {
} else if (codeObj instanceof String) { String status = (String) callbackData.get("status");
try { if (status != null) {
code = Integer.parseInt((String) codeObj); // 转换状态为首字母大写
} catch (Exception ignored) { if (status.equalsIgnoreCase("Success")) {
} status = "Success";
} else if (status.equalsIgnoreCase("Failed")) {
status = "Failed";
} else if (status.equalsIgnoreCase("Training")) {
status = "Training";
} else if (status.equalsIgnoreCase("Pending")) {
status = "Pending";
}
task.setStatus(status);
if ("Success".equals(status)) {
String modelPathRaw = firstNonBlank(
(String) callbackData.get("model_path"),
(String) callbackData.get("model_relative_path")
);
task.setModelOutputPath(normalizeModelPath(modelPathRaw));
// Map -> JSON String
if (callbackData.get("metrics") != null) {
task.setMetrics(objectMapper.writeValueAsString(callbackData.get("metrics")));
} }
if (code != null && code != 0) { if (callbackData.get("feature_map") != null) {
task.setStatus("Failed"); task.setFeatureMapSnapshot(objectMapper.writeValueAsString(callbackData.get("feature_map")));
task.setErrorLog(String.valueOf(body.get("msg")));
this.updateById(task);
return task;
} }
Object dataObj = body.get("data"); String metricsPathRaw = firstNonBlank(
if (!(dataObj instanceof Map)) { (String) callbackData.get("metrics_image"),
return task; (String) callbackData.get("metrics_image_relative_path")
} );
Map<String, Object> data = (Map<String, Object>) dataObj; task.setMetricsImagePath(normalizeModelPath(metricsPathRaw));
String status = (String) data.get("status"); } else if ("Failed".equals(status)) {
task.setErrorLog((String) callbackData.get("message"));
if (status != null) { } else if ("Training".equals(status)) {
// 转换状态为首字母大写 if (callbackData.get("metrics") != null) {
if (status.equalsIgnoreCase("Success")) { task.setMetrics(objectMapper.writeValueAsString(callbackData.get("metrics")));
status = "Success";
} else if (status.equalsIgnoreCase("Failed")) {
status = "Failed";
} else if (status.equalsIgnoreCase("Training")) {
status = "Training";
} else if (status.equalsIgnoreCase("Pending")) {
status = "Pending";
}
task.setStatus(status);
if ("Success".equals(status)) {
String modelPathRaw = firstNonBlank(
(String) data.get("model_relative_path"),
(String) data.get("model_path_rel_project"),
(String) data.get("model_path")
);
task.setModelOutputPath(normalizeModelPath(modelPathRaw));
// Map -> JSON String
if (data.get("metrics") != null) {
task.setMetrics(objectMapper.writeValueAsString(data.get("metrics")));
}
if (data.get("feature_map") != null) {
task.setFeatureMapSnapshot(objectMapper.writeValueAsString(data.get("feature_map")));
}
String metricsPathRaw = firstNonBlank(
(String) data.get("metrics_image_relative_path"),
(String) data.get("metrics_image_rel_project"),
(String) data.get("metrics_image")
);
task.setMetricsImagePath(normalizeModelPath(metricsPathRaw));
} else if ("Failed".equals(status)) {
task.setErrorLog((String) data.get("error"));
} else if ("Training".equals(status)) {
if (data.get("metrics") != null) {
task.setMetrics(objectMapper.writeValueAsString(data.get("metrics")));
}
}
this.updateById(task);
} }
} }
} catch (Exception e) {
log.error("同步任务状态失败: {}", e.getMessage(), e); this.updateById(task);
} }
} catch (Exception e) {
log.error("从回调更新任务状态失败: {}", e.getMessage(), e);
} }
return task;
} }
@Override @Override

View File

@ -81,7 +81,7 @@ public class ProjectServiceImpl
Sheet sheet = wb.createSheet("projects"); Sheet sheet = wb.createSheet("projects");
int r = 0; int r = 0;
Row header = sheet.createRow(r++); Row header = sheet.createRow(r++);
String[] cols = {"project_id","code","name","description","topology","created_at","updated_at","modifier"}; String[] cols = {"项目id","项目编号","项目名称","项目描述","项目建模拓扑","创建时间","修改时间","创建人"};
for (int i = 0; i < cols.length; i++) header.createCell(i).setCellValue(cols[i]); for (int i = 0; i < cols.length; i++) header.createCell(i).setCellValue(cols[i]);
DateTimeFormatter fmt = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"); DateTimeFormatter fmt = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");
for (Project p : list) { for (Project p : list) {

View File

@ -64,6 +64,7 @@ public class SecurityConfig {
.requestMatchers(HttpMethod.GET, .requestMatchers(HttpMethod.GET,
"/*.html", "/*.html",
"/webSocket/**", "/webSocket/**",
"/ws/**",
"/assets/**", "/assets/**",
"/icon/**").permitAll() "/icon/**").permitAll()
.requestMatchers( .requestMatchers(
@ -80,6 +81,7 @@ public class SecurityConfig {
"/pageimage/**", "/pageimage/**",
"/avatar/**", "/avatar/**",
"/systemurl/**", "/systemurl/**",
"/train/internal/callback",
"/api/imageserver/upload").permitAll() "/api/imageserver/upload").permitAll()
.anyRequest().authenticated(); .anyRequest().authenticated();
} }