193 lines
8.3 KiB
Markdown
193 lines
8.3 KiB
Markdown
# 模型训练面板实现过程
|
||
|
||
根据《变更需求.md》中的第四点,我们需要为系统新增一个“模型训练面板”。该面板旨在支持用户通过简单的交互界面,上传数据、配置参数并启动模型训练,最终将训练好的模型发布到系统中。
|
||
|
||
## 1. 核心流程分析
|
||
|
||
1. **数据准备**:用户上传训练数据集(Excel/CSV),或从系统历史数据中选择(后续扩展)。
|
||
2. **任务配置**:选择目标设备类型、算法类型,并输入模型超参数。
|
||
3. **任务提交**:后端接收请求,创建训练任务记录,并异步调用 Python 算法服务。
|
||
4. **过程监控**:前端轮询任务状态,展示实时训练进度和评估指标(如 Loss, Accuracy)。
|
||
5. **模型入库**:训练成功后,自动解析模型元数据(特征映射、评估指标图表),并将其注册到 `algorithm_model` 表中。
|
||
|
||
## 2. 详细实现步骤
|
||
|
||
### 2.1 数据库设计
|
||
|
||
虽然已有 `algorithm_model` 表用于存储正式模型,但训练过程是一个长耗时且可能失败的操作,直接写入正式表会导致垃圾数据。因此,建议新增一张 **训练任务表 (`model_train_task`)** 来管理训练生命周期。
|
||
|
||
```sql
|
||
CREATE TABLE model_train_task (
|
||
task_id VARCHAR(32) PRIMARY KEY COMMENT '任务ID',
|
||
task_name VARCHAR(100) COMMENT '任务名称',
|
||
algorithm_type VARCHAR(50) COMMENT '算法类型 (GPR, MLP...)',
|
||
device_type VARCHAR(50) COMMENT '设备类型 (CylindricalTank...)',
|
||
dataset_path VARCHAR(255) COMMENT '数据集文件路径',
|
||
train_params JSON COMMENT '训练超参数 ({"epochs": 100, "lr": 0.01...})',
|
||
status VARCHAR(20) COMMENT '状态 (PENDING, TRAINING, SUCCESS, FAILED)',
|
||
metrics JSON COMMENT '训练指标 ({"rmse": 0.002, "r2": 0.98})',
|
||
model_output_path VARCHAR(255) COMMENT '训练生成的临时模型路径',
|
||
feature_map_snapshot JSON COMMENT '特征映射快照',
|
||
metrics_image_path VARCHAR(255) COMMENT '误差散点图路径',
|
||
error_log TEXT COMMENT '错误日志',
|
||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
||
);
|
||
```
|
||
|
||
### 2.2 后端接口设计 (`ModelTrainController`)
|
||
|
||
| 接口路径 | 方法 | 描述 | 参数示例 |
|
||
| :--- | :--- | :--- | :--- |
|
||
| `/train/upload` | POST | 上传训练数据集 | File |
|
||
| `/train/submit` | POST | 提交训练任务 | `{name, algoType, deviceType, datasetPath, params}` |
|
||
| `/train/status/{taskId}` | GET | 查询训练状态与指标 | - |
|
||
| `/train/publish/{taskId}` | POST | 发布模型(入库) | `{versionTag, description}` |
|
||
| `/train/list` | GET | 查询历史训练任务 | `page, size` |
|
||
|
||
### 2.3 业务逻辑实现 (`ModelTrainService`)
|
||
|
||
1. **数据上传 (`uploadDataset`)**:
|
||
* 接收 Excel/CSV 文件。
|
||
* 校验文件格式(必需列:`input_features`, `target_label`)。
|
||
* 保存到服务器临时目录 `/data/uploads/{date}/{uuid}.csv`。
|
||
|
||
2. **任务提交 (`createTask`)**:
|
||
* 创建 `model_train_task` 记录,状态设为 `PENDING`。
|
||
* 构建请求参数,异步调用 Python 服务的 `/v1/train` 接口。
|
||
* 若调用成功,更新状态为 `TRAINING`;否则标记为 `FAILED`。
|
||
|
||
3. **状态同步 (`syncTaskStatus`)**:
|
||
* 定时任务或前端轮询触发:调用 Python 服务的 `/v1/train/status/{taskId}` 接口。
|
||
* 更新数据库中的 `metrics` 和 `status`。
|
||
* 如果状态为 `SUCCESS`,记录 `model_output_path`, `feature_map_snapshot`, `metrics_image_path`。
|
||
|
||
4. **模型发布 (`publishModel`)**:
|
||
* 校验任务状态是否为 `SUCCESS`。
|
||
* 将模型文件从临时目录移动到正式模型目录 `/models/{algo}/{device}/`。
|
||
* 在 `algorithm_model` 表中插入新记录:
|
||
* `algorithm_model_id`: 生成 UUID
|
||
* `version_tag`: 用户输入(如 v1.2)
|
||
* `feature_map_snapshot`: 从任务表复制
|
||
* `metrics`: 从任务表复制
|
||
* `metrics_image_path`: 从任务表复制
|
||
* `is_current`: 默认为 0(需人工激活)
|
||
|
||
### 2.4 Python 算法服务扩展
|
||
|
||
需要在 Python 服务中新增以下 API:
|
||
|
||
* `POST /v1/train`:
|
||
* 接收:`dataset_path`, `algorithm_type`, `device_type`, `hyperparameters`。
|
||
* 动作:启动后台线程/进程进行训练。
|
||
* 返回:`task_id`。
|
||
* `GET /v1/train/status/{task_id}`:
|
||
* 返回:`status` (TRAINING/SUCCESS/FAILED), `progress` (0-100%), `metrics`, `model_path`, `feature_map`, `metrics_image`。
|
||
|
||
### 2.5 前端界面开发
|
||
|
||
1. **新建训练任务向导**:
|
||
* **Step 1: 基础配置**
|
||
* 任务名称:文本框。
|
||
* 设备类型:下拉框(从字典加载)。
|
||
* 算法类型:下拉框(GPR, MLP...)。
|
||
* **Step 2: 数据准备**
|
||
* 数据源选择:[上传文件] / [选择历史数据]。
|
||
* 上传组件:支持 .csv, .xlsx。
|
||
* **Step 3: 参数配置**
|
||
* 文本框输入 JSON 或 Key-Value 对。
|
||
* 提供“加载默认参数”按钮。
|
||
|
||
2. **训练监控面板**:
|
||
* 展示任务列表。
|
||
* 点击“查看详情”进入监控页。
|
||
* **实时图表**:展示 Loss 曲线。
|
||
* **评估结果**:展示 RMSE, R2 指标及误差散点图(`metrics_image_path`)。
|
||
|
||
3. **模型发布弹窗**:
|
||
* 输入版本号。
|
||
* 点击确认后,将训练好的模型正式注册到系统中。
|
||
|
||
## 3. 关键问题与解决方案
|
||
|
||
1. **训练耗时问题**:
|
||
* **方案**:后端采用异步调用,前端使用轮询(每 3-5 秒)。Python 端必须使用非阻塞方式执行训练(如 `threading` 或 `Celery`)。
|
||
|
||
2. **数据格式校验**:
|
||
* **方案**:后端在上传阶段需解析 Excel 表头,确保包含必要的特征列。若缺失关键列,直接拒绝上传。
|
||
|
||
3. **模型文件管理**:
|
||
* **方案**:区分“临时模型”和“正式模型”。训练生成的模型先放在临时目录,只有用户确认发布后才移动到正式目录,避免无效模型占用空间。
|
||
|
||
4. **特征映射一致性**:
|
||
* **方案**:训练时必须生成 `feature_map_snapshot`(记录输入特征与 Java 实体属性的对应关系),并将其保存到 `algorithm_model` 表中。推理服务加载模型时,必须读取该快照以正确组装输入数据。
|
||
|
||
|
||
|
||
## 4. 接口改进
|
||
1. list接口加上查询条件:
|
||
* `algoType`: 算法类型(GPR, MLP...)。
|
||
* `deviceType`: 设备类型(CPU, GPU...)。
|
||
* `status`: 任务状态(PENDING, TRAINING, SUCCESS, FAILED)。
|
||
* `name`: 任务名称(模糊查询)。
|
||
* `page`: 分页页码。
|
||
* `size`: 每页数量(默认 10)。
|
||
2. submit 接口加上MultipartFile file 非必须参数
|
||
* `datasetPath`: 数据集路径(上传时返回的路径)。--这个需要根据是否上传文件来判断是否需要
|
||
、
|
||
3.超参数面板
|
||
GPR根据类型获取超参数面板。
|
||
[
|
||
{
|
||
"name": "RBF 平滑尺度",
|
||
"key": "rbf_length_scale",
|
||
"description": "控制模型整体平滑程度。值越小模型越敏感,值越大预测曲线越平滑。",
|
||
"default": 1.0,
|
||
"range": [0.01, 100],
|
||
"category": "基础参数"
|
||
},
|
||
{
|
||
"name": "RQ 平滑尺度",
|
||
"key": "rq_length_scale",
|
||
"description": "控制多尺度变化的平滑程度,用于描述局部变化特征。",
|
||
"default": 1.0,
|
||
"range": [0.01, 100],
|
||
"category": "基础参数"
|
||
},
|
||
{
|
||
"name": "RQ 多尺度系数",
|
||
"key": "rq_alpha",
|
||
"description": "控制多尺度变化强度。值越小多尺度变化越明显,值越大模型趋近于RBF。",
|
||
"default": 1.0,
|
||
"range": [0.1, 10]
|
||
},
|
||
{
|
||
"name": "噪声水平",
|
||
"key": "noise_level",
|
||
"description": "数据观测噪声大小,用于提升模型鲁棒性。",
|
||
"default": 1e-5,
|
||
"range": [1e-8, 1e-1],
|
||
"category": "基础参数"
|
||
},
|
||
{
|
||
"name": "优化重启次数",
|
||
"key": "optimizer_restarts",
|
||
"description": "模型超参数优化时的随机重启次数,用于避免陷入局部最优。",
|
||
"default": 5,
|
||
"range": [0, 20],
|
||
"category": "训练参数"
|
||
}
|
||
]
|
||
|
||
生成这样的数据:{
|
||
"rbf_length_scale": 2.0,
|
||
"rq_length_scale": 1.5,
|
||
"rq_alpha": 0.8,
|
||
"noise_level": 1e-4,
|
||
"optimizer_restarts": 10
|
||
}
|
||
|
||
|
||
|
||
|