135 lines
5.7 KiB
Markdown
135 lines
5.7 KiB
Markdown
|
|
# 模型训练面板编码实现规划
|
|||
|
|
|
|||
|
|
本规划基于《模型训练面板实现过程.md》,旨在提供具体的编码指导。**核心原则:简化开发,直接使用 Domain 实体,避免繁琐的 DTO 转换。**
|
|||
|
|
|
|||
|
|
## 1. 数据库与实体层
|
|||
|
|
|
|||
|
|
### 1.1 数据库表 (`model_train_task`)
|
|||
|
|
请确保数据库已执行以下 SQL:
|
|||
|
|
```sql
|
|||
|
|
CREATE TABLE IF NOT EXISTS `model_train_task` (
|
|||
|
|
`task_id` char(36) NOT NULL COMMENT '任务ID',
|
|||
|
|
`task_name` varchar(100) DEFAULT NULL COMMENT '任务名称',
|
|||
|
|
`algorithm_type` varchar(50) DEFAULT NULL COMMENT '算法类型',
|
|||
|
|
`device_type` varchar(50) DEFAULT NULL COMMENT '设备类型',
|
|||
|
|
`dataset_path` varchar(255) DEFAULT NULL COMMENT '数据集路径',
|
|||
|
|
`train_params` json DEFAULT NULL COMMENT '训练参数',
|
|||
|
|
`status` varchar(20) DEFAULT 'PENDING' COMMENT '状态',
|
|||
|
|
`metrics` json DEFAULT NULL COMMENT '训练指标',
|
|||
|
|
`model_output_path` varchar(255) DEFAULT NULL COMMENT '临时模型路径',
|
|||
|
|
`feature_map_snapshot` json DEFAULT NULL COMMENT '特征映射',
|
|||
|
|
`metrics_image_path` varchar(255) DEFAULT NULL COMMENT '指标图路径',
|
|||
|
|
`error_log` text DEFAULT NULL COMMENT '错误日志',
|
|||
|
|
`created_at` datetime DEFAULT CURRENT_TIMESTAMP,
|
|||
|
|
`updated_at` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
|||
|
|
PRIMARY KEY (`task_id`)
|
|||
|
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci;
|
|||
|
|
```
|
|||
|
|
|
|||
|
|
### 1.2 Domain 实体 (`ModelTrainTask.java`)
|
|||
|
|
* **位置**: `com.yfd.business.css.domain`
|
|||
|
|
* **注解**: `@TableName("model_train_task")`, `@TableId(type = IdType.ASSIGN_UUID)`
|
|||
|
|
* **JSON 字段处理**: 使用 MyBatis-Plus 的 `@TableField(typeHandler = JacksonTypeHandler.class)` 处理 `trainParams`, `metrics`, `featureMapSnapshot` 等 JSON 字段。建议直接映射为 `Map<String, Object>` 或 `JsonNode`,或者简单的 `String`(前端解析)。为了简化,使用 `String` 存储 JSON 字符串。
|
|||
|
|
|
|||
|
|
```java
|
|||
|
|
@Data
|
|||
|
|
@TableName(value = "model_train_task", autoResultMap = true)
|
|||
|
|
public class ModelTrainTask implements Serializable {
|
|||
|
|
@TableId(type = IdType.ASSIGN_UUID)
|
|||
|
|
private String taskId;
|
|||
|
|
private String taskName;
|
|||
|
|
private String algorithmType;
|
|||
|
|
private String deviceType;
|
|||
|
|
private String datasetPath;
|
|||
|
|
|
|||
|
|
@TableField(typeHandler = JacksonTypeHandler.class)
|
|||
|
|
private Map<String, Object> trainParams; // 或 String
|
|||
|
|
|
|||
|
|
private String status; // PENDING, TRAINING, SUCCESS, FAILED
|
|||
|
|
|
|||
|
|
@TableField(typeHandler = JacksonTypeHandler.class)
|
|||
|
|
private Map<String, Object> metrics;
|
|||
|
|
|
|||
|
|
private String modelOutputPath;
|
|||
|
|
|
|||
|
|
@TableField(typeHandler = JacksonTypeHandler.class)
|
|||
|
|
private Map<String, Object> featureMapSnapshot;
|
|||
|
|
|
|||
|
|
private String metricsImagePath;
|
|||
|
|
private String errorLog;
|
|||
|
|
|
|||
|
|
@TableField(fill = FieldFill.INSERT)
|
|||
|
|
private LocalDateTime createdAt;
|
|||
|
|
|
|||
|
|
@TableField(fill = FieldFill.INSERT_UPDATE)
|
|||
|
|
private LocalDateTime updatedAt;
|
|||
|
|
}
|
|||
|
|
```
|
|||
|
|
|
|||
|
|
## 2. Mapper 层 (`ModelTrainTaskMapper.java`)
|
|||
|
|
* **位置**: `com.yfd.business.css.mapper`
|
|||
|
|
* **继承**: `BaseMapper<ModelTrainTask>`
|
|||
|
|
|
|||
|
|
## 3. Service 层
|
|||
|
|
|
|||
|
|
### 3.1 接口 (`ModelTrainService.java`)
|
|||
|
|
* **位置**: `com.yfd.business.css.service`
|
|||
|
|
* **继承**: `IService<ModelTrainTask>`
|
|||
|
|
* **核心方法**:
|
|||
|
|
* `String uploadDataset(MultipartFile file)`: 上传文件,返回路径。
|
|||
|
|
* `String submitTask(ModelTrainTask task)`: 提交任务(接收前端传来的实体,补充初始状态后保存,并异步调用 Python)。
|
|||
|
|
* `ModelTrainTask syncTaskStatus(String taskId)`: 同步并返回最新状态。
|
|||
|
|
* `boolean publishModel(String taskId, String versionTag, String description)`: 发布模型。
|
|||
|
|
|
|||
|
|
### 3.2 实现 (`ModelTrainServiceImpl.java`)
|
|||
|
|
* **位置**: `com.yfd.business.css.service.impl`
|
|||
|
|
* **依赖**: `ModelTrainTaskMapper`, `RestTemplate` (调用 Python), `AlgorithmModelService` (发布模型)。
|
|||
|
|
|
|||
|
|
## 4. Controller 层 (`ModelTrainController.java`)
|
|||
|
|
|
|||
|
|
* **位置**: `com.yfd.business.css.controller`
|
|||
|
|
* **原则**: 直接接收和返回 `ModelTrainTask` 实体或 `Map`。
|
|||
|
|
|
|||
|
|
### 4.1 接口定义
|
|||
|
|
|
|||
|
|
1. **上传数据集**
|
|||
|
|
* `POST /train/upload`
|
|||
|
|
* 参数: `@RequestParam("file") MultipartFile file`
|
|||
|
|
* 返回: `Result<String>` (文件路径)
|
|||
|
|
|
|||
|
|
2. **提交任务**
|
|||
|
|
* `POST /train/submit`
|
|||
|
|
* 参数: `@RequestBody ModelTrainTask task`
|
|||
|
|
* 逻辑: 前端构建好 `task` 对象(包含 name, algoType, deviceType, datasetPath, trainParams),传给后端。后端设置 status=PENDING, save, async call python。
|
|||
|
|
|
|||
|
|
3. **查询任务列表**
|
|||
|
|
* `GET /train/list`
|
|||
|
|
* 参数: `PageQuery`, `ModelTrainTask query` (查询条件)
|
|||
|
|
* 返回: `Result<Page<ModelTrainTask>>`
|
|||
|
|
|
|||
|
|
4. **查询任务详情/状态**
|
|||
|
|
* `GET /train/status/{taskId}`
|
|||
|
|
* 逻辑: 调用 Service 的 `syncTaskStatus`,触发一次 Python 状态查询,更新数据库,然后返回最新实体。
|
|||
|
|
|
|||
|
|
5. **发布模型**
|
|||
|
|
* `POST /train/publish`
|
|||
|
|
* 参数: `Map<String, String> body` (包含 taskId, versionTag, description)
|
|||
|
|
* 逻辑: 检查状态 -> 移动文件 -> 插入 `algorithm_model` 表。
|
|||
|
|
|
|||
|
|
## 5. Python 服务交互
|
|||
|
|
* **调用方式**: `RestTemplate`
|
|||
|
|
* **API 对应**:
|
|||
|
|
* Java `submitTask` -> Python `POST /v1/train`
|
|||
|
|
* Java `syncTaskStatus` -> Python `GET /v1/train/status/{taskId}`
|
|||
|
|
|
|||
|
|
## 6. 开发步骤建议
|
|||
|
|
|
|||
|
|
1. **创建表**:执行 SQL。
|
|||
|
|
2. **生成代码**:创建 Domain, Mapper, Service, Controller。
|
|||
|
|
3. **实现上传接口**:确保文件能保存到指定目录(如 `data/upload/`)。
|
|||
|
|
4. **实现提交接口**:
|
|||
|
|
* 先实现保存到数据库。
|
|||
|
|
* 再实现 `RestTemplate` 调用 Python(Mock 阶段可先跳过 Python 调用,直接设为 SUCCESS)。
|
|||
|
|
5. **实现状态查询**:对接 Python 查询接口。
|
|||
|
|
6. **实现发布接口**:操作 `algorithm_model` 表。
|