# 模型训练面板实现过程 根据《变更需求.md》中的第四点,我们需要为系统新增一个“模型训练面板”。该面板旨在支持用户通过简单的交互界面,上传数据、配置参数并启动模型训练,最终将训练好的模型发布到系统中。 ## 1. 核心流程分析 1. **数据准备**:用户上传训练数据集(Excel/CSV),或从系统历史数据中选择(后续扩展)。 2. **任务配置**:选择目标设备类型、算法类型,并输入模型超参数。 3. **任务提交**:后端接收请求,创建训练任务记录,并异步调用 Python 算法服务。 4. **过程监控**:前端轮询任务状态,展示实时训练进度和评估指标(如 Loss, Accuracy)。 5. **模型入库**:训练成功后,自动解析模型元数据(特征映射、评估指标图表),并将其注册到 `algorithm_model` 表中。 ## 2. 详细实现步骤 ### 2.1 数据库设计 虽然已有 `algorithm_model` 表用于存储正式模型,但训练过程是一个长耗时且可能失败的操作,直接写入正式表会导致垃圾数据。因此,建议新增一张 **训练任务表 (`model_train_task`)** 来管理训练生命周期。 ```sql CREATE TABLE model_train_task ( task_id VARCHAR(32) PRIMARY KEY COMMENT '任务ID', task_name VARCHAR(100) COMMENT '任务名称', algorithm_type VARCHAR(50) COMMENT '算法类型 (GPR, MLP...)', device_type VARCHAR(50) COMMENT '设备类型 (CylindricalTank...)', dataset_path VARCHAR(255) COMMENT '数据集文件路径', train_params JSON COMMENT '训练超参数 ({"epochs": 100, "lr": 0.01...})', status VARCHAR(20) COMMENT '状态 (PENDING, TRAINING, SUCCESS, FAILED)', metrics JSON COMMENT '训练指标 ({"rmse": 0.002, "r2": 0.98})', model_output_path VARCHAR(255) COMMENT '训练生成的临时模型路径', feature_map_snapshot JSON COMMENT '特征映射快照', metrics_image_path VARCHAR(255) COMMENT '误差散点图路径', error_log TEXT COMMENT '错误日志', created_at DATETIME DEFAULT CURRENT_TIMESTAMP, updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP ); ``` ### 2.2 后端接口设计 (`ModelTrainController`) | 接口路径 | 方法 | 描述 | 参数示例 | | :--- | :--- | :--- | :--- | | `/train/upload` | POST | 上传训练数据集 | File | | `/train/submit` | POST | 提交训练任务 | `{name, algoType, deviceType, datasetPath, params}` | | `/train/status/{taskId}` | GET | 查询训练状态与指标 | - | | `/train/publish/{taskId}` | POST | 发布模型(入库) | `{versionTag, description}` | | `/train/list` | GET | 查询历史训练任务 | `page, size` | ### 2.3 业务逻辑实现 (`ModelTrainService`) 1. **数据上传 (`uploadDataset`)**: * 接收 Excel/CSV 文件。 * 校验文件格式(必需列:`input_features`, `target_label`)。 * 保存到服务器临时目录 `/data/uploads/{date}/{uuid}.csv`。 2. **任务提交 (`createTask`)**: * 创建 `model_train_task` 记录,状态设为 `PENDING`。 * 构建请求参数,异步调用 Python 服务的 `/v1/train` 接口。 * 若调用成功,更新状态为 `TRAINING`;否则标记为 `FAILED`。 3. **状态同步 (`syncTaskStatus`)**: * 定时任务或前端轮询触发:调用 Python 服务的 `/v1/train/status/{taskId}` 接口。 * 更新数据库中的 `metrics` 和 `status`。 * 如果状态为 `SUCCESS`,记录 `model_output_path`, `feature_map_snapshot`, `metrics_image_path`。 4. **模型发布 (`publishModel`)**: * 校验任务状态是否为 `SUCCESS`。 * 将模型文件从临时目录移动到正式模型目录 `/models/{algo}/{device}/`。 * 在 `algorithm_model` 表中插入新记录: * `algorithm_model_id`: 生成 UUID * `version_tag`: 用户输入(如 v1.2) * `feature_map_snapshot`: 从任务表复制 * `metrics`: 从任务表复制 * `metrics_image_path`: 从任务表复制 * `is_current`: 默认为 0(需人工激活) ### 2.4 Python 算法服务扩展 需要在 Python 服务中新增以下 API: * `POST /v1/train`: * 接收:`dataset_path`, `algorithm_type`, `device_type`, `hyperparameters`。 * 动作:启动后台线程/进程进行训练。 * 返回:`task_id`。 * `GET /v1/train/status/{task_id}`: * 返回:`status` (TRAINING/SUCCESS/FAILED), `progress` (0-100%), `metrics`, `model_path`, `feature_map`, `metrics_image`。 ### 2.5 前端界面开发 1. **新建训练任务向导**: * **Step 1: 基础配置** * 任务名称:文本框。 * 设备类型:下拉框(从字典加载)。 * 算法类型:下拉框(GPR, MLP...)。 * **Step 2: 数据准备** * 数据源选择:[上传文件] / [选择历史数据]。 * 上传组件:支持 .csv, .xlsx。 * **Step 3: 参数配置** * 文本框输入 JSON 或 Key-Value 对。 * 提供“加载默认参数”按钮。 2. **训练监控面板**: * 展示任务列表。 * 点击“查看详情”进入监控页。 * **实时图表**:展示 Loss 曲线。 * **评估结果**:展示 RMSE, R2 指标及误差散点图(`metrics_image_path`)。 3. **模型发布弹窗**: * 输入版本号。 * 点击确认后,将训练好的模型正式注册到系统中。 ## 3. 关键问题与解决方案 1. **训练耗时问题**: * **方案**:后端采用异步调用,前端使用轮询(每 3-5 秒)。Python 端必须使用非阻塞方式执行训练(如 `threading` 或 `Celery`)。 2. **数据格式校验**: * **方案**:后端在上传阶段需解析 Excel 表头,确保包含必要的特征列。若缺失关键列,直接拒绝上传。 3. **模型文件管理**: * **方案**:区分“临时模型”和“正式模型”。训练生成的模型先放在临时目录,只有用户确认发布后才移动到正式目录,避免无效模型占用空间。 4. **特征映射一致性**: * **方案**:训练时必须生成 `feature_map_snapshot`(记录输入特征与 Java 实体属性的对应关系),并将其保存到 `algorithm_model` 表中。推理服务加载模型时,必须读取该快照以正确组装输入数据。 ## 4. 接口改进 1. list接口加上查询条件: * `algoType`: 算法类型(GPR, MLP...)。 * `deviceType`: 设备类型(CPU, GPU...)。 * `status`: 任务状态(PENDING, TRAINING, SUCCESS, FAILED)。 * `name`: 任务名称(模糊查询)。 * `page`: 分页页码。 * `size`: 每页数量(默认 10)。 2. submit 接口加上MultipartFile file 非必须参数 * `datasetPath`: 数据集路径(上传时返回的路径)。--这个需要根据是否上传文件来判断是否需要 、 3.超参数面板 GPR根据类型获取超参数面板。 [ { "name": "RBF 平滑尺度", "key": "rbf_length_scale", "description": "控制模型整体平滑程度。值越小模型越敏感,值越大预测曲线越平滑。", "default": 1.0, "range": [0.01, 100], "category": "基础参数" }, { "name": "RQ 平滑尺度", "key": "rq_length_scale", "description": "控制多尺度变化的平滑程度,用于描述局部变化特征。", "default": 1.0, "range": [0.01, 100], "category": "基础参数" }, { "name": "RQ 多尺度系数", "key": "rq_alpha", "description": "控制多尺度变化强度。值越小多尺度变化越明显,值越大模型趋近于RBF。", "default": 1.0, "range": [0.1, 10] }, { "name": "噪声水平", "key": "noise_level", "description": "数据观测噪声大小,用于提升模型鲁棒性。", "default": 1e-5, "range": [1e-8, 1e-1], "category": "基础参数" }, { "name": "优化重启次数", "key": "optimizer_restarts", "description": "模型超参数优化时的随机重启次数,用于避免陷入局部最优。", "default": 5, "range": [0, 20], "category": "训练参数" } ] 生成这样的数据:{ "rbf_length_scale": 2.0, "rq_length_scale": 1.5, "rq_alpha": 0.8, "noise_level": 1e-4, "optimizer_restarts": 10 }