8.3 KiB
8.3 KiB
模型训练面板实现过程
根据《变更需求.md》中的第四点,我们需要为系统新增一个“模型训练面板”。该面板旨在支持用户通过简单的交互界面,上传数据、配置参数并启动模型训练,最终将训练好的模型发布到系统中。
1. 核心流程分析
- 数据准备:用户上传训练数据集(Excel/CSV),或从系统历史数据中选择(后续扩展)。
- 任务配置:选择目标设备类型、算法类型,并输入模型超参数。
- 任务提交:后端接收请求,创建训练任务记录,并异步调用 Python 算法服务。
- 过程监控:前端轮询任务状态,展示实时训练进度和评估指标(如 Loss, Accuracy)。
- 模型入库:训练成功后,自动解析模型元数据(特征映射、评估指标图表),并将其注册到
algorithm_model表中。
2. 详细实现步骤
2.1 数据库设计
虽然已有 algorithm_model 表用于存储正式模型,但训练过程是一个长耗时且可能失败的操作,直接写入正式表会导致垃圾数据。因此,建议新增一张 训练任务表 (model_train_task) 来管理训练生命周期。
CREATE TABLE model_train_task (
task_id VARCHAR(32) PRIMARY KEY COMMENT '任务ID',
task_name VARCHAR(100) COMMENT '任务名称',
algorithm_type VARCHAR(50) COMMENT '算法类型 (GPR, MLP...)',
device_type VARCHAR(50) COMMENT '设备类型 (CylindricalTank...)',
dataset_path VARCHAR(255) COMMENT '数据集文件路径',
train_params JSON COMMENT '训练超参数 ({"epochs": 100, "lr": 0.01...})',
status VARCHAR(20) COMMENT '状态 (PENDING, TRAINING, SUCCESS, FAILED)',
metrics JSON COMMENT '训练指标 ({"rmse": 0.002, "r2": 0.98})',
model_output_path VARCHAR(255) COMMENT '训练生成的临时模型路径',
feature_map_snapshot JSON COMMENT '特征映射快照',
metrics_image_path VARCHAR(255) COMMENT '误差散点图路径',
error_log TEXT COMMENT '错误日志',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
);
2.2 后端接口设计 (ModelTrainController)
| 接口路径 | 方法 | 描述 | 参数示例 |
|---|---|---|---|
/train/upload |
POST | 上传训练数据集 | File |
/train/submit |
POST | 提交训练任务 | {name, algoType, deviceType, datasetPath, params} |
/train/status/{taskId} |
GET | 查询训练状态与指标 | - |
/train/publish/{taskId} |
POST | 发布模型(入库) | {versionTag, description} |
/train/list |
GET | 查询历史训练任务 | page, size |
2.3 业务逻辑实现 (ModelTrainService)
-
数据上传 (
uploadDataset):- 接收 Excel/CSV 文件。
- 校验文件格式(必需列:
input_features,target_label)。 - 保存到服务器临时目录
/data/uploads/{date}/{uuid}.csv。
-
任务提交 (
createTask):- 创建
model_train_task记录,状态设为PENDING。 - 构建请求参数,异步调用 Python 服务的
/v1/train接口。 - 若调用成功,更新状态为
TRAINING;否则标记为FAILED。
- 创建
-
状态同步 (
syncTaskStatus):- 定时任务或前端轮询触发:调用 Python 服务的
/v1/train/status/{taskId}接口。 - 更新数据库中的
metrics和status。 - 如果状态为
SUCCESS,记录model_output_path,feature_map_snapshot,metrics_image_path。
- 定时任务或前端轮询触发:调用 Python 服务的
-
模型发布 (
publishModel):- 校验任务状态是否为
SUCCESS。 - 将模型文件从临时目录移动到正式模型目录
/models/{algo}/{device}/。 - 在
algorithm_model表中插入新记录:algorithm_model_id: 生成 UUIDversion_tag: 用户输入(如 v1.2)feature_map_snapshot: 从任务表复制metrics: 从任务表复制metrics_image_path: 从任务表复制is_current: 默认为 0(需人工激活)
- 校验任务状态是否为
2.4 Python 算法服务扩展
需要在 Python 服务中新增以下 API:
POST /v1/train:- 接收:
dataset_path,algorithm_type,device_type,hyperparameters。 - 动作:启动后台线程/进程进行训练。
- 返回:
task_id。
- 接收:
GET /v1/train/status/{task_id}:- 返回:
status(TRAINING/SUCCESS/FAILED),progress(0-100%),metrics,model_path,feature_map,metrics_image。
- 返回:
2.5 前端界面开发
-
新建训练任务向导:
- Step 1: 基础配置
- 任务名称:文本框。
- 设备类型:下拉框(从字典加载)。
- 算法类型:下拉框(GPR, MLP...)。
- Step 2: 数据准备
- 数据源选择:[上传文件] / [选择历史数据]。
- 上传组件:支持 .csv, .xlsx。
- Step 3: 参数配置
- 文本框输入 JSON 或 Key-Value 对。
- 提供“加载默认参数”按钮。
- Step 1: 基础配置
-
训练监控面板:
- 展示任务列表。
- 点击“查看详情”进入监控页。
- 实时图表:展示 Loss 曲线。
- 评估结果:展示 RMSE, R2 指标及误差散点图(
metrics_image_path)。
-
模型发布弹窗:
- 输入版本号。
- 点击确认后,将训练好的模型正式注册到系统中。
3. 关键问题与解决方案
-
训练耗时问题:
- 方案:后端采用异步调用,前端使用轮询(每 3-5 秒)。Python 端必须使用非阻塞方式执行训练(如
threading或Celery)。
- 方案:后端采用异步调用,前端使用轮询(每 3-5 秒)。Python 端必须使用非阻塞方式执行训练(如
-
数据格式校验:
- 方案:后端在上传阶段需解析 Excel 表头,确保包含必要的特征列。若缺失关键列,直接拒绝上传。
-
模型文件管理:
- 方案:区分“临时模型”和“正式模型”。训练生成的模型先放在临时目录,只有用户确认发布后才移动到正式目录,避免无效模型占用空间。
-
特征映射一致性:
- 方案:训练时必须生成
feature_map_snapshot(记录输入特征与 Java 实体属性的对应关系),并将其保存到algorithm_model表中。推理服务加载模型时,必须读取该快照以正确组装输入数据。
- 方案:训练时必须生成
4. 接口改进
- list接口加上查询条件:
algoType: 算法类型(GPR, MLP...)。deviceType: 设备类型(CPU, GPU...)。status: 任务状态(PENDING, TRAINING, SUCCESS, FAILED)。name: 任务名称(模糊查询)。page: 分页页码。size: 每页数量(默认 10)。
- submit 接口加上MultipartFile file 非必须参数
datasetPath: 数据集路径(上传时返回的路径)。--这个需要根据是否上传文件来判断是否需要 、 3.超参数面板 GPR根据类型获取超参数面板。 [ { "name": "RBF 平滑尺度", "key": "rbf_length_scale", "description": "控制模型整体平滑程度。值越小模型越敏感,值越大预测曲线越平滑。", "default": 1.0, "range": [0.01, 100], "category": "基础参数" }, { "name": "RQ 平滑尺度", "key": "rq_length_scale", "description": "控制多尺度变化的平滑程度,用于描述局部变化特征。", "default": 1.0, "range": [0.01, 100], "category": "基础参数" }, { "name": "RQ 多尺度系数", "key": "rq_alpha", "description": "控制多尺度变化强度。值越小多尺度变化越明显,值越大模型趋近于RBF。", "default": 1.0, "range": [0.1, 10] }, { "name": "噪声水平", "key": "noise_level", "description": "数据观测噪声大小,用于提升模型鲁棒性。", "default": 1e-5, "range": [1e-8, 1e-1], "category": "基础参数" }, { "name": "优化重启次数", "key": "optimizer_restarts", "description": "模型超参数优化时的随机重启次数,用于避免陷入局部最优。", "default": 5, "range": [0, 20], "category": "训练参数" } ]
生成这样的数据:{ "rbf_length_scale": 2.0, "rq_length_scale": 1.5, "rq_alpha": 0.8, "noise_level": 1e-4, "optimizer_restarts": 10 }