JavaProjectRepo/模型训练面板实现过程.md

193 lines
8.3 KiB
Markdown
Raw Normal View History

2026-03-20 19:00:16 +08:00
# 模型训练面板实现过程
根据《变更需求.md》中的第四点我们需要为系统新增一个“模型训练面板”。该面板旨在支持用户通过简单的交互界面上传数据、配置参数并启动模型训练最终将训练好的模型发布到系统中。
## 1. 核心流程分析
1. **数据准备**用户上传训练数据集Excel/CSV或从系统历史数据中选择后续扩展
2. **任务配置**:选择目标设备类型、算法类型,并输入模型超参数。
3. **任务提交**:后端接收请求,创建训练任务记录,并异步调用 Python 算法服务。
4. **过程监控**:前端轮询任务状态,展示实时训练进度和评估指标(如 Loss, Accuracy
5. **模型入库**:训练成功后,自动解析模型元数据(特征映射、评估指标图表),并将其注册到 `algorithm_model` 表中。
## 2. 详细实现步骤
### 2.1 数据库设计
虽然已有 `algorithm_model` 表用于存储正式模型,但训练过程是一个长耗时且可能失败的操作,直接写入正式表会导致垃圾数据。因此,建议新增一张 **训练任务表 (`model_train_task`)** 来管理训练生命周期。
```sql
CREATE TABLE model_train_task (
task_id VARCHAR(32) PRIMARY KEY COMMENT '任务ID',
task_name VARCHAR(100) COMMENT '任务名称',
algorithm_type VARCHAR(50) COMMENT '算法类型 (GPR, MLP...)',
device_type VARCHAR(50) COMMENT '设备类型 (CylindricalTank...)',
dataset_path VARCHAR(255) COMMENT '数据集文件路径',
train_params JSON COMMENT '训练超参数 ({"epochs": 100, "lr": 0.01...})',
status VARCHAR(20) COMMENT '状态 (PENDING, TRAINING, SUCCESS, FAILED)',
metrics JSON COMMENT '训练指标 ({"rmse": 0.002, "r2": 0.98})',
model_output_path VARCHAR(255) COMMENT '训练生成的临时模型路径',
feature_map_snapshot JSON COMMENT '特征映射快照',
metrics_image_path VARCHAR(255) COMMENT '误差散点图路径',
error_log TEXT COMMENT '错误日志',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
);
```
### 2.2 后端接口设计 (`ModelTrainController`)
| 接口路径 | 方法 | 描述 | 参数示例 |
| :--- | :--- | :--- | :--- |
| `/train/upload` | POST | 上传训练数据集 | File |
| `/train/submit` | POST | 提交训练任务 | `{name, algoType, deviceType, datasetPath, params}` |
| `/train/status/{taskId}` | GET | 查询训练状态与指标 | - |
| `/train/publish/{taskId}` | POST | 发布模型(入库) | `{versionTag, description}` |
| `/train/list` | GET | 查询历史训练任务 | `page, size` |
### 2.3 业务逻辑实现 (`ModelTrainService`)
1. **数据上传 (`uploadDataset`)**:
* 接收 Excel/CSV 文件。
* 校验文件格式(必需列:`input_features`, `target_label`)。
* 保存到服务器临时目录 `/data/uploads/{date}/{uuid}.csv`
2. **任务提交 (`createTask`)**:
* 创建 `model_train_task` 记录,状态设为 `PENDING`
* 构建请求参数,异步调用 Python 服务的 `/v1/train` 接口。
* 若调用成功,更新状态为 `TRAINING`;否则标记为 `FAILED`
3. **状态同步 (`syncTaskStatus`)**:
* 定时任务或前端轮询触发:调用 Python 服务的 `/v1/train/status/{taskId}` 接口。
* 更新数据库中的 `metrics``status`
* 如果状态为 `SUCCESS`,记录 `model_output_path`, `feature_map_snapshot`, `metrics_image_path`
4. **模型发布 (`publishModel`)**:
* 校验任务状态是否为 `SUCCESS`
* 将模型文件从临时目录移动到正式模型目录 `/models/{algo}/{device}/`
*`algorithm_model` 表中插入新记录:
* `algorithm_model_id`: 生成 UUID
* `version_tag`: 用户输入(如 v1.2
* `feature_map_snapshot`: 从任务表复制
* `metrics`: 从任务表复制
* `metrics_image_path`: 从任务表复制
* `is_current`: 默认为 0需人工激活
### 2.4 Python 算法服务扩展
需要在 Python 服务中新增以下 API
* `POST /v1/train`:
* 接收:`dataset_path`, `algorithm_type`, `device_type`, `hyperparameters`
* 动作:启动后台线程/进程进行训练。
* 返回:`task_id`。
* `GET /v1/train/status/{task_id}`:
* 返回:`status` (TRAINING/SUCCESS/FAILED), `progress` (0-100%), `metrics`, `model_path`, `feature_map`, `metrics_image`
### 2.5 前端界面开发
1. **新建训练任务向导**
* **Step 1: 基础配置**
* 任务名称:文本框。
* 设备类型:下拉框(从字典加载)。
* 算法类型下拉框GPR, MLP...)。
* **Step 2: 数据准备**
* 数据源选择:[上传文件] / [选择历史数据]。
* 上传组件:支持 .csv, .xlsx。
* **Step 3: 参数配置**
* 文本框输入 JSON 或 Key-Value 对。
* 提供“加载默认参数”按钮。
2. **训练监控面板**
* 展示任务列表。
* 点击“查看详情”进入监控页。
* **实时图表**:展示 Loss 曲线。
* **评估结果**:展示 RMSE, R2 指标及误差散点图(`metrics_image_path`)。
3. **模型发布弹窗**
* 输入版本号。
* 点击确认后,将训练好的模型正式注册到系统中。
## 3. 关键问题与解决方案
1. **训练耗时问题**
* **方案**:后端采用异步调用,前端使用轮询(每 3-5 秒。Python 端必须使用非阻塞方式执行训练(如 `threading``Celery`)。
2. **数据格式校验**
* **方案**:后端在上传阶段需解析 Excel 表头,确保包含必要的特征列。若缺失关键列,直接拒绝上传。
3. **模型文件管理**
* **方案**:区分“临时模型”和“正式模型”。训练生成的模型先放在临时目录,只有用户确认发布后才移动到正式目录,避免无效模型占用空间。
4. **特征映射一致性**
* **方案**:训练时必须生成 `feature_map_snapshot`(记录输入特征与 Java 实体属性的对应关系),并将其保存到 `algorithm_model` 表中。推理服务加载模型时,必须读取该快照以正确组装输入数据。
## 4. 接口改进
1. list接口加上查询条件
* `algoType`: 算法类型GPR, MLP...)。
* `deviceType`: 设备类型CPU, GPU...)。
* `status`: 任务状态PENDING, TRAINING, SUCCESS, FAILED
* `name`: 任务名称(模糊查询)。
* `page`: 分页页码。
* `size`: 每页数量(默认 10
2. submit 接口加上MultipartFile file 非必须参数
* `datasetPath`: 数据集路径(上传时返回的路径)。--这个需要根据是否上传文件来判断是否需要
3.超参数面板
GPR根据类型获取超参数面板。
[
{
"name": "RBF 平滑尺度",
"key": "rbf_length_scale",
"description": "控制模型整体平滑程度。值越小模型越敏感,值越大预测曲线越平滑。",
"default": 1.0,
"range": [0.01, 100],
"category": "基础参数"
},
{
"name": "RQ 平滑尺度",
"key": "rq_length_scale",
"description": "控制多尺度变化的平滑程度,用于描述局部变化特征。",
"default": 1.0,
"range": [0.01, 100],
"category": "基础参数"
},
{
"name": "RQ 多尺度系数",
"key": "rq_alpha",
"description": "控制多尺度变化强度。值越小多尺度变化越明显值越大模型趋近于RBF。",
"default": 1.0,
"range": [0.1, 10]
},
{
"name": "噪声水平",
"key": "noise_level",
"description": "数据观测噪声大小,用于提升模型鲁棒性。",
"default": 1e-5,
"range": [1e-8, 1e-1],
"category": "基础参数"
},
{
"name": "优化重启次数",
"key": "optimizer_restarts",
"description": "模型超参数优化时的随机重启次数,用于避免陷入局部最优。",
"default": 5,
"range": [0, 20],
"category": "训练参数"
}
]
生成这样的数据:{
"rbf_length_scale": 2.0,
"rq_length_scale": 1.5,
"rq_alpha": 0.8,
"noise_level": 1e-4,
"optimizer_restarts": 10
}