JavaProjectRepo/模型训练面板编码实现规划.md
2026-03-20 19:00:16 +08:00

135 lines
5.7 KiB
Markdown
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 模型训练面板编码实现规划
本规划基于《模型训练面板实现过程.md》旨在提供具体的编码指导。**核心原则:简化开发,直接使用 Domain 实体,避免繁琐的 DTO 转换。**
## 1. 数据库与实体层
### 1.1 数据库表 (`model_train_task`)
请确保数据库已执行以下 SQL
```sql
CREATE TABLE IF NOT EXISTS `model_train_task` (
`task_id` char(36) NOT NULL COMMENT '任务ID',
`task_name` varchar(100) DEFAULT NULL COMMENT '任务名称',
`algorithm_type` varchar(50) DEFAULT NULL COMMENT '算法类型',
`device_type` varchar(50) DEFAULT NULL COMMENT '设备类型',
`dataset_path` varchar(255) DEFAULT NULL COMMENT '数据集路径',
`train_params` json DEFAULT NULL COMMENT '训练参数',
`status` varchar(20) DEFAULT 'PENDING' COMMENT '状态',
`metrics` json DEFAULT NULL COMMENT '训练指标',
`model_output_path` varchar(255) DEFAULT NULL COMMENT '临时模型路径',
`feature_map_snapshot` json DEFAULT NULL COMMENT '特征映射',
`metrics_image_path` varchar(255) DEFAULT NULL COMMENT '指标图路径',
`error_log` text DEFAULT NULL COMMENT '错误日志',
`created_at` datetime DEFAULT CURRENT_TIMESTAMP,
`updated_at` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY (`task_id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci;
```
### 1.2 Domain 实体 (`ModelTrainTask.java`)
* **位置**: `com.yfd.business.css.domain`
* **注解**: `@TableName("model_train_task")`, `@TableId(type = IdType.ASSIGN_UUID)`
* **JSON 字段处理**: 使用 MyBatis-Plus 的 `@TableField(typeHandler = JacksonTypeHandler.class)` 处理 `trainParams`, `metrics`, `featureMapSnapshot` 等 JSON 字段。建议直接映射为 `Map<String, Object>``JsonNode`,或者简单的 `String`(前端解析)。为了简化,使用 `String` 存储 JSON 字符串。
```java
@Data
@TableName(value = "model_train_task", autoResultMap = true)
public class ModelTrainTask implements Serializable {
@TableId(type = IdType.ASSIGN_UUID)
private String taskId;
private String taskName;
private String algorithmType;
private String deviceType;
private String datasetPath;
@TableField(typeHandler = JacksonTypeHandler.class)
private Map<String, Object> trainParams; // 或 String
private String status; // PENDING, TRAINING, SUCCESS, FAILED
@TableField(typeHandler = JacksonTypeHandler.class)
private Map<String, Object> metrics;
private String modelOutputPath;
@TableField(typeHandler = JacksonTypeHandler.class)
private Map<String, Object> featureMapSnapshot;
private String metricsImagePath;
private String errorLog;
@TableField(fill = FieldFill.INSERT)
private LocalDateTime createdAt;
@TableField(fill = FieldFill.INSERT_UPDATE)
private LocalDateTime updatedAt;
}
```
## 2. Mapper 层 (`ModelTrainTaskMapper.java`)
* **位置**: `com.yfd.business.css.mapper`
* **继承**: `BaseMapper<ModelTrainTask>`
## 3. Service 层
### 3.1 接口 (`ModelTrainService.java`)
* **位置**: `com.yfd.business.css.service`
* **继承**: `IService<ModelTrainTask>`
* **核心方法**:
* `String uploadDataset(MultipartFile file)`: 上传文件,返回路径。
* `String submitTask(ModelTrainTask task)`: 提交任务(接收前端传来的实体,补充初始状态后保存,并异步调用 Python
* `ModelTrainTask syncTaskStatus(String taskId)`: 同步并返回最新状态。
* `boolean publishModel(String taskId, String versionTag, String description)`: 发布模型。
### 3.2 实现 (`ModelTrainServiceImpl.java`)
* **位置**: `com.yfd.business.css.service.impl`
* **依赖**: `ModelTrainTaskMapper`, `RestTemplate` (调用 Python), `AlgorithmModelService` (发布模型)。
## 4. Controller 层 (`ModelTrainController.java`)
* **位置**: `com.yfd.business.css.controller`
* **原则**: 直接接收和返回 `ModelTrainTask` 实体或 `Map`
### 4.1 接口定义
1. **上传数据集**
* `POST /train/upload`
* 参数: `@RequestParam("file") MultipartFile file`
* 返回: `Result<String>` (文件路径)
2. **提交任务**
* `POST /train/submit`
* 参数: `@RequestBody ModelTrainTask task`
* 逻辑: 前端构建好 `task` 对象(包含 name, algoType, deviceType, datasetPath, trainParams传给后端。后端设置 status=PENDING, save, async call python。
3. **查询任务列表**
* `GET /train/list`
* 参数: `PageQuery`, `ModelTrainTask query` (查询条件)
* 返回: `Result<Page<ModelTrainTask>>`
4. **查询任务详情/状态**
* `GET /train/status/{taskId}`
* 逻辑: 调用 Service 的 `syncTaskStatus`,触发一次 Python 状态查询,更新数据库,然后返回最新实体。
5. **发布模型**
* `POST /train/publish`
* 参数: `Map<String, String> body` (包含 taskId, versionTag, description)
* 逻辑: 检查状态 -> 移动文件 -> 插入 `algorithm_model` 表。
## 5. Python 服务交互
* **调用方式**: `RestTemplate`
* **API 对应**:
* Java `submitTask` -> Python `POST /v1/train`
* Java `syncTaskStatus` -> Python `GET /v1/train/status/{taskId}`
## 6. 开发步骤建议
1. **创建表**:执行 SQL。
2. **生成代码**:创建 Domain, Mapper, Service, Controller。
3. **实现上传接口**:确保文件能保存到指定目录(如 `data/upload/`)。
4. **实现提交接口**
* 先实现保存到数据库。
* 再实现 `RestTemplate` 调用 PythonMock 阶段可先跳过 Python 调用,直接设为 SUCCESS
5. **实现状态查询**:对接 Python 查询接口。
6. **实现发布接口**:操作 `algorithm_model` 表。