JavaProjectRepo/模型训练面板实现过程.md
2026-03-20 19:00:16 +08:00

8.3 KiB
Raw Blame History

模型训练面板实现过程

根据《变更需求.md》中的第四点我们需要为系统新增一个“模型训练面板”。该面板旨在支持用户通过简单的交互界面上传数据、配置参数并启动模型训练最终将训练好的模型发布到系统中。

1. 核心流程分析

  1. 数据准备用户上传训练数据集Excel/CSV或从系统历史数据中选择后续扩展
  2. 任务配置:选择目标设备类型、算法类型,并输入模型超参数。
  3. 任务提交:后端接收请求,创建训练任务记录,并异步调用 Python 算法服务。
  4. 过程监控:前端轮询任务状态,展示实时训练进度和评估指标(如 Loss, Accuracy
  5. 模型入库:训练成功后,自动解析模型元数据(特征映射、评估指标图表),并将其注册到 algorithm_model 表中。

2. 详细实现步骤

2.1 数据库设计

虽然已有 algorithm_model 表用于存储正式模型,但训练过程是一个长耗时且可能失败的操作,直接写入正式表会导致垃圾数据。因此,建议新增一张 训练任务表 (model_train_task) 来管理训练生命周期。

CREATE TABLE model_train_task (
    task_id VARCHAR(32) PRIMARY KEY COMMENT '任务ID',
    task_name VARCHAR(100) COMMENT '任务名称',
    algorithm_type VARCHAR(50) COMMENT '算法类型 (GPR, MLP...)',
    device_type VARCHAR(50) COMMENT '设备类型 (CylindricalTank...)',
    dataset_path VARCHAR(255) COMMENT '数据集文件路径',
    train_params JSON COMMENT '训练超参数 ({"epochs": 100, "lr": 0.01...})',
    status VARCHAR(20) COMMENT '状态 (PENDING, TRAINING, SUCCESS, FAILED)',
    metrics JSON COMMENT '训练指标 ({"rmse": 0.002, "r2": 0.98})',
    model_output_path VARCHAR(255) COMMENT '训练生成的临时模型路径',
    feature_map_snapshot JSON COMMENT '特征映射快照',
    metrics_image_path VARCHAR(255) COMMENT '误差散点图路径',
    error_log TEXT COMMENT '错误日志',
    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
    updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
);

2.2 后端接口设计 (ModelTrainController)

接口路径 方法 描述 参数示例
/train/upload POST 上传训练数据集 File
/train/submit POST 提交训练任务 {name, algoType, deviceType, datasetPath, params}
/train/status/{taskId} GET 查询训练状态与指标 -
/train/publish/{taskId} POST 发布模型(入库) {versionTag, description}
/train/list GET 查询历史训练任务 page, size

2.3 业务逻辑实现 (ModelTrainService)

  1. 数据上传 (uploadDataset):

    • 接收 Excel/CSV 文件。
    • 校验文件格式(必需列:input_features, target_label)。
    • 保存到服务器临时目录 /data/uploads/{date}/{uuid}.csv
  2. 任务提交 (createTask):

    • 创建 model_train_task 记录,状态设为 PENDING
    • 构建请求参数,异步调用 Python 服务的 /v1/train 接口。
    • 若调用成功,更新状态为 TRAINING;否则标记为 FAILED
  3. 状态同步 (syncTaskStatus):

    • 定时任务或前端轮询触发:调用 Python 服务的 /v1/train/status/{taskId} 接口。
    • 更新数据库中的 metricsstatus
    • 如果状态为 SUCCESS,记录 model_output_path, feature_map_snapshot, metrics_image_path
  4. 模型发布 (publishModel):

    • 校验任务状态是否为 SUCCESS
    • 将模型文件从临时目录移动到正式模型目录 /models/{algo}/{device}/
    • algorithm_model 表中插入新记录:
      • algorithm_model_id: 生成 UUID
      • version_tag: 用户输入(如 v1.2
      • feature_map_snapshot: 从任务表复制
      • metrics: 从任务表复制
      • metrics_image_path: 从任务表复制
      • is_current: 默认为 0需人工激活

2.4 Python 算法服务扩展

需要在 Python 服务中新增以下 API

  • POST /v1/train:
    • 接收:dataset_path, algorithm_type, device_type, hyperparameters
    • 动作:启动后台线程/进程进行训练。
    • 返回:task_id
  • GET /v1/train/status/{task_id}:
    • 返回:status (TRAINING/SUCCESS/FAILED), progress (0-100%), metrics, model_path, feature_map, metrics_image

2.5 前端界面开发

  1. 新建训练任务向导

    • Step 1: 基础配置
      • 任务名称:文本框。
      • 设备类型:下拉框(从字典加载)。
      • 算法类型下拉框GPR, MLP...)。
    • Step 2: 数据准备
      • 数据源选择:[上传文件] / [选择历史数据]。
      • 上传组件:支持 .csv, .xlsx。
    • Step 3: 参数配置
      • 文本框输入 JSON 或 Key-Value 对。
      • 提供“加载默认参数”按钮。
  2. 训练监控面板

    • 展示任务列表。
    • 点击“查看详情”进入监控页。
    • 实时图表:展示 Loss 曲线。
    • 评估结果:展示 RMSE, R2 指标及误差散点图(metrics_image_path)。
  3. 模型发布弹窗

    • 输入版本号。
    • 点击确认后,将训练好的模型正式注册到系统中。

3. 关键问题与解决方案

  1. 训练耗时问题

    • 方案:后端采用异步调用,前端使用轮询(每 3-5 秒。Python 端必须使用非阻塞方式执行训练(如 threadingCelery)。
  2. 数据格式校验

    • 方案:后端在上传阶段需解析 Excel 表头,确保包含必要的特征列。若缺失关键列,直接拒绝上传。
  3. 模型文件管理

    • 方案:区分“临时模型”和“正式模型”。训练生成的模型先放在临时目录,只有用户确认发布后才移动到正式目录,避免无效模型占用空间。
  4. 特征映射一致性

    • 方案:训练时必须生成 feature_map_snapshot(记录输入特征与 Java 实体属性的对应关系),并将其保存到 algorithm_model 表中。推理服务加载模型时,必须读取该快照以正确组装输入数据。

4. 接口改进

  1. list接口加上查询条件
    • algoType: 算法类型GPR, MLP...)。
    • deviceType: 设备类型CPU, GPU...)。
    • status: 任务状态PENDING, TRAINING, SUCCESS, FAILED
    • name: 任务名称(模糊查询)。
    • page: 分页页码。
    • size: 每页数量(默认 10
  2. submit 接口加上MultipartFile file 非必须参数
    • datasetPath: 数据集路径(上传时返回的路径)。--这个需要根据是否上传文件来判断是否需要 、 3.超参数面板 GPR根据类型获取超参数面板。 [ { "name": "RBF 平滑尺度", "key": "rbf_length_scale", "description": "控制模型整体平滑程度。值越小模型越敏感,值越大预测曲线越平滑。", "default": 1.0, "range": [0.01, 100], "category": "基础参数" }, { "name": "RQ 平滑尺度", "key": "rq_length_scale", "description": "控制多尺度变化的平滑程度,用于描述局部变化特征。", "default": 1.0, "range": [0.01, 100], "category": "基础参数" }, { "name": "RQ 多尺度系数", "key": "rq_alpha", "description": "控制多尺度变化强度。值越小多尺度变化越明显值越大模型趋近于RBF。", "default": 1.0, "range": [0.1, 10] }, { "name": "噪声水平", "key": "noise_level", "description": "数据观测噪声大小,用于提升模型鲁棒性。", "default": 1e-5, "range": [1e-8, 1e-1], "category": "基础参数" }, { "name": "优化重启次数", "key": "optimizer_restarts", "description": "模型超参数优化时的随机重启次数,用于避免陷入局部最优。", "default": 5, "range": [0, 20], "category": "训练参数" } ]

生成这样的数据:{ "rbf_length_scale": 2.0, "rq_length_scale": 1.5, "rq_alpha": 0.8, "noise_level": 1e-4, "optimizer_restarts": 10 }