JavaProjectRepo/模型训练面板实现过程.md
2026-03-20 19:00:16 +08:00

193 lines
8.3 KiB
Markdown
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 模型训练面板实现过程
根据《变更需求.md》中的第四点我们需要为系统新增一个“模型训练面板”。该面板旨在支持用户通过简单的交互界面上传数据、配置参数并启动模型训练最终将训练好的模型发布到系统中。
## 1. 核心流程分析
1. **数据准备**用户上传训练数据集Excel/CSV或从系统历史数据中选择后续扩展
2. **任务配置**:选择目标设备类型、算法类型,并输入模型超参数。
3. **任务提交**:后端接收请求,创建训练任务记录,并异步调用 Python 算法服务。
4. **过程监控**:前端轮询任务状态,展示实时训练进度和评估指标(如 Loss, Accuracy
5. **模型入库**:训练成功后,自动解析模型元数据(特征映射、评估指标图表),并将其注册到 `algorithm_model` 表中。
## 2. 详细实现步骤
### 2.1 数据库设计
虽然已有 `algorithm_model` 表用于存储正式模型,但训练过程是一个长耗时且可能失败的操作,直接写入正式表会导致垃圾数据。因此,建议新增一张 **训练任务表 (`model_train_task`)** 来管理训练生命周期。
```sql
CREATE TABLE model_train_task (
task_id VARCHAR(32) PRIMARY KEY COMMENT '任务ID',
task_name VARCHAR(100) COMMENT '任务名称',
algorithm_type VARCHAR(50) COMMENT '算法类型 (GPR, MLP...)',
device_type VARCHAR(50) COMMENT '设备类型 (CylindricalTank...)',
dataset_path VARCHAR(255) COMMENT '数据集文件路径',
train_params JSON COMMENT '训练超参数 ({"epochs": 100, "lr": 0.01...})',
status VARCHAR(20) COMMENT '状态 (PENDING, TRAINING, SUCCESS, FAILED)',
metrics JSON COMMENT '训练指标 ({"rmse": 0.002, "r2": 0.98})',
model_output_path VARCHAR(255) COMMENT '训练生成的临时模型路径',
feature_map_snapshot JSON COMMENT '特征映射快照',
metrics_image_path VARCHAR(255) COMMENT '误差散点图路径',
error_log TEXT COMMENT '错误日志',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
);
```
### 2.2 后端接口设计 (`ModelTrainController`)
| 接口路径 | 方法 | 描述 | 参数示例 |
| :--- | :--- | :--- | :--- |
| `/train/upload` | POST | 上传训练数据集 | File |
| `/train/submit` | POST | 提交训练任务 | `{name, algoType, deviceType, datasetPath, params}` |
| `/train/status/{taskId}` | GET | 查询训练状态与指标 | - |
| `/train/publish/{taskId}` | POST | 发布模型(入库) | `{versionTag, description}` |
| `/train/list` | GET | 查询历史训练任务 | `page, size` |
### 2.3 业务逻辑实现 (`ModelTrainService`)
1. **数据上传 (`uploadDataset`)**:
* 接收 Excel/CSV 文件。
* 校验文件格式(必需列:`input_features`, `target_label`)。
* 保存到服务器临时目录 `/data/uploads/{date}/{uuid}.csv`
2. **任务提交 (`createTask`)**:
* 创建 `model_train_task` 记录,状态设为 `PENDING`
* 构建请求参数,异步调用 Python 服务的 `/v1/train` 接口。
* 若调用成功,更新状态为 `TRAINING`;否则标记为 `FAILED`
3. **状态同步 (`syncTaskStatus`)**:
* 定时任务或前端轮询触发:调用 Python 服务的 `/v1/train/status/{taskId}` 接口。
* 更新数据库中的 `metrics``status`
* 如果状态为 `SUCCESS`,记录 `model_output_path`, `feature_map_snapshot`, `metrics_image_path`
4. **模型发布 (`publishModel`)**:
* 校验任务状态是否为 `SUCCESS`
* 将模型文件从临时目录移动到正式模型目录 `/models/{algo}/{device}/`
*`algorithm_model` 表中插入新记录:
* `algorithm_model_id`: 生成 UUID
* `version_tag`: 用户输入(如 v1.2
* `feature_map_snapshot`: 从任务表复制
* `metrics`: 从任务表复制
* `metrics_image_path`: 从任务表复制
* `is_current`: 默认为 0需人工激活
### 2.4 Python 算法服务扩展
需要在 Python 服务中新增以下 API
* `POST /v1/train`:
* 接收:`dataset_path`, `algorithm_type`, `device_type`, `hyperparameters`
* 动作:启动后台线程/进程进行训练。
* 返回:`task_id`。
* `GET /v1/train/status/{task_id}`:
* 返回:`status` (TRAINING/SUCCESS/FAILED), `progress` (0-100%), `metrics`, `model_path`, `feature_map`, `metrics_image`
### 2.5 前端界面开发
1. **新建训练任务向导**
* **Step 1: 基础配置**
* 任务名称:文本框。
* 设备类型:下拉框(从字典加载)。
* 算法类型下拉框GPR, MLP...)。
* **Step 2: 数据准备**
* 数据源选择:[上传文件] / [选择历史数据]。
* 上传组件:支持 .csv, .xlsx。
* **Step 3: 参数配置**
* 文本框输入 JSON 或 Key-Value 对。
* 提供“加载默认参数”按钮。
2. **训练监控面板**
* 展示任务列表。
* 点击“查看详情”进入监控页。
* **实时图表**:展示 Loss 曲线。
* **评估结果**:展示 RMSE, R2 指标及误差散点图(`metrics_image_path`)。
3. **模型发布弹窗**
* 输入版本号。
* 点击确认后,将训练好的模型正式注册到系统中。
## 3. 关键问题与解决方案
1. **训练耗时问题**
* **方案**:后端采用异步调用,前端使用轮询(每 3-5 秒。Python 端必须使用非阻塞方式执行训练(如 `threading``Celery`)。
2. **数据格式校验**
* **方案**:后端在上传阶段需解析 Excel 表头,确保包含必要的特征列。若缺失关键列,直接拒绝上传。
3. **模型文件管理**
* **方案**:区分“临时模型”和“正式模型”。训练生成的模型先放在临时目录,只有用户确认发布后才移动到正式目录,避免无效模型占用空间。
4. **特征映射一致性**
* **方案**:训练时必须生成 `feature_map_snapshot`(记录输入特征与 Java 实体属性的对应关系),并将其保存到 `algorithm_model` 表中。推理服务加载模型时,必须读取该快照以正确组装输入数据。
## 4. 接口改进
1. list接口加上查询条件
* `algoType`: 算法类型GPR, MLP...)。
* `deviceType`: 设备类型CPU, GPU...)。
* `status`: 任务状态PENDING, TRAINING, SUCCESS, FAILED
* `name`: 任务名称(模糊查询)。
* `page`: 分页页码。
* `size`: 每页数量(默认 10
2. submit 接口加上MultipartFile file 非必须参数
* `datasetPath`: 数据集路径(上传时返回的路径)。--这个需要根据是否上传文件来判断是否需要
3.超参数面板
GPR根据类型获取超参数面板。
[
{
"name": "RBF 平滑尺度",
"key": "rbf_length_scale",
"description": "控制模型整体平滑程度。值越小模型越敏感,值越大预测曲线越平滑。",
"default": 1.0,
"range": [0.01, 100],
"category": "基础参数"
},
{
"name": "RQ 平滑尺度",
"key": "rq_length_scale",
"description": "控制多尺度变化的平滑程度,用于描述局部变化特征。",
"default": 1.0,
"range": [0.01, 100],
"category": "基础参数"
},
{
"name": "RQ 多尺度系数",
"key": "rq_alpha",
"description": "控制多尺度变化强度。值越小多尺度变化越明显值越大模型趋近于RBF。",
"default": 1.0,
"range": [0.1, 10]
},
{
"name": "噪声水平",
"key": "noise_level",
"description": "数据观测噪声大小,用于提升模型鲁棒性。",
"default": 1e-5,
"range": [1e-8, 1e-1],
"category": "基础参数"
},
{
"name": "优化重启次数",
"key": "optimizer_restarts",
"description": "模型超参数优化时的随机重启次数,用于避免陷入局部最优。",
"default": 5,
"range": [0, 20],
"category": "训练参数"
}
]
生成这样的数据:{
"rbf_length_scale": 2.0,
"rq_length_scale": 1.5,
"rq_alpha": 0.8,
"noise_level": 1e-4,
"optimizer_restarts": 10
}