150 lines
6.3 KiB
Python
150 lines
6.3 KiB
Python
import os
|
||
import time
|
||
import pickle
|
||
import numpy as np
|
||
from sklearn.pipeline import Pipeline
|
||
from sklearn.preprocessing import StandardScaler
|
||
from sklearn.gaussian_process import GaussianProcessRegressor
|
||
from sklearn.gaussian_process.kernels import RBF, RationalQuadratic, ConstantKernel, WhiteKernel
|
||
from sklearn.model_selection import train_test_split
|
||
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score, max_error
|
||
try:
|
||
from .type_config import load_from_excel, features_to_vector
|
||
except ImportError:
|
||
from type_config import load_from_excel, features_to_vector
|
||
|
||
"""
|
||
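Training and inference service helpers
- build_gpr_model: build the GPR model
- build_pipeline: build the Pipeline (input standardization + model)
- compute_metrics: compute the evaluation metrics
- train_one_type: train and save a model for one device type
- train_one_type_from_samples: train and save a model from in-memory samples
- infer_one: single-sample inference
- infer_batch: batch inference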
|
||
"""
|
||
|
||
# Fraction of the samples that train_test_split holds out from fitting.
TEST_SIZE = 0.2

# Seed shared by the train/test split and the GPR optimizer restarts so that
# repeated runs are reproducible.
RANDOM_STATE = 42
|
||
|
||
def build_gpr_model():
    """Build the Gaussian-process regressor with a composite kernel.

    The kernel is a scaled blend of an RBF and a RationalQuadratic component
    plus a WhiteKernel noise term. Targets are normalized by the regressor
    (``normalize_y=True``).

    Returns:
        An unfitted ``GaussianProcessRegressor``.
    """
    rbf_part = 0.7 * RBF(length_scale=1.0)
    rq_part = 0.3 * RationalQuadratic(length_scale=1.0, alpha=1.0)
    signal = ConstantKernel(constant_value=1.0) * (rbf_part + rq_part)
    noise = WhiteKernel(noise_level=1e-5)

    return GaussianProcessRegressor(
        kernel=signal + noise,
        # No fixed diagonal term is added; noise is left to the WhiteKernel.
        alpha=0.0,
        n_restarts_optimizer=5,
        normalize_y=True,
        random_state=RANDOM_STATE,
    )
|
||
|
||
def build_pipeline():
    """Assemble the train/inference pipeline: input standardization, then the GPR model.

    Returns:
        An unfitted ``Pipeline`` with the steps ``"scaler"`` and ``"model"``.
    """
    steps = [
        ("scaler", StandardScaler()),
        ("model", build_gpr_model()),
    ]
    return Pipeline(steps)
|
||
|
||
def compute_metrics(y_true, y_pred):
    """Compute the regression error metrics RMSE, MAE, R2 and max error.

    Args:
        y_true: Reference target values.
        y_pred: Predicted values, aligned with ``y_true``.

    Returns:
        dict with the keys ``"rmse"``, ``"mae"``, ``"r2"`` and ``"maxe"``,
        each converted to a plain Python ``float``.
    """
    mse = mean_squared_error(y_true, y_pred)
    raw_scores = {
        "rmse": np.sqrt(mse),
        "mae": mean_absolute_error(y_true, y_pred),
        "r2": r2_score(y_true, y_pred),
        "maxe": max_error(y_true, y_pred),
    }
    # Cast away NumPy scalar types so the result contains only builtin floats.
    return {name: float(score) for name, score in raw_scores.items()}
|
||
|
||
def train_one_type(device_type, dataset_path, model_dir):
    """Train a pipeline for one device type, save it, and report metrics and timing.

    Args:
        device_type: Device-type key forwarded to ``load_from_excel``.
        dataset_path: Dataset location forwarded to ``load_from_excel``.
        model_dir: Directory that receives ``pipeline.pkl``; created if missing.

    Returns:
        dict with the keys
        - ``"metrics"``: scores over the training and held-out rows together,
        - ``"test_metrics"``: scores over the held-out rows only,
        - ``"infer_batch_time_sec"``: seconds spent in one ``predict`` call
          over all rows,
        - ``"model_path"``: path of the pickled pipeline.
    """
    X, y = load_from_excel(device_type, dataset_path)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=TEST_SIZE, random_state=RANDOM_STATE
    )

    pipe = build_pipeline()
    pipe.fit(X_train, y_train)

    # Training rows come first and held-out rows last, so the tail of the
    # prediction vector lines up with y_test.
    X_all = np.vstack([X_train, X_test])
    y_all = np.concatenate([y_train, y_test])
    n_train = len(y_train)

    # perf_counter is monotonic and high-resolution; a wall-clock timer can
    # report 0.0 (or jump) for a call this short.
    t0 = time.perf_counter()
    pred_all = pipe.predict(X_all)
    elapsed = time.perf_counter() - t0

    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, "pipeline.pkl")
    with open(model_path, "wb") as f:
        pickle.dump(pipe, f)

    return {
        "metrics": compute_metrics(y_all, pred_all),
        # The combined score above includes rows the model was fitted on;
        # this one reflects only the rows held out from fitting.
        "test_metrics": compute_metrics(y_test, pred_all[n_train:]),
        "infer_batch_time_sec": float(elapsed),
        "model_path": model_path,
    }
|
||
|
||
def infer_one(device_type, features, model_dir):
    """Load the saved pipeline and predict for a single feature record.

    Args:
        device_type: Device-type key forwarded to ``features_to_vector``.
        features: Feature record forwarded to ``features_to_vector``.
        model_dir: Directory containing ``pipeline.pkl``.

    Returns:
        The first predicted value as a plain ``float``.

    Raises:
        FileNotFoundError: If ``pipeline.pkl`` does not exist in ``model_dir``.
    """
    model_path = os.path.join(model_dir, "pipeline.pkl")
    # NOTE: unpickling can execute arbitrary code; model_dir must be trusted.
    with open(model_path, "rb") as f:
        pipe = pickle.load(f)

    X = features_to_vector(device_type, features)
    # predict() needs a 2-D (n_samples, n_features) input. A flat feature
    # vector is turned into a single row; 2-D input is passed through as-is.
    if np.ndim(X) == 1:
        X = np.reshape(X, (1, -1))

    y = pipe.predict(X)
    return float(y[0])
|
||
|
||
def infer_batch(device_type, features_list, model_dir):
    """Load the saved pipeline and predict for a batch of feature records.

    Args:
        device_type: Device-type key forwarded to ``features_to_vector``.
        features_list: Iterable of feature records, one per sample.
        model_dir: Directory containing ``pipeline.pkl``.

    Returns:
        A list of plain ``float`` predictions, one per input record, in input
        order. An empty ``features_list`` yields an empty list.

    Raises:
        FileNotFoundError: If ``pipeline.pkl`` does not exist in ``model_dir``.
    """
    model_path = os.path.join(model_dir, "pipeline.pkl")
    # NOTE: unpickling can execute arbitrary code; model_dir must be trusted.
    with open(model_path, "rb") as f:
        pipe = pickle.load(f)

    rows = [features_to_vector(device_type, feat) for feat in features_list]
    if not rows:
        # np.vstack rejects an empty sequence; nothing to predict anyway.
        return []

    ys = pipe.predict(np.vstack(rows))
    return [float(v) for v in ys]
|
||
|
||
def train_one_type_from_samples(device_type, features_list, labels, model_dir):
    """Train a pipeline from in-memory samples, save it, and report metrics and timing.

    Args:
        device_type: Device-type key forwarded to ``features_to_vector``.
        features_list: Iterable of feature records, one per sample.
        labels: Target values aligned with ``features_list``; converted to a
            float array.
        model_dir: Directory that receives ``pipeline.pkl``; created if missing.

    Returns:
        dict with the keys
        - ``"metrics"``: scores over the training and held-out rows together,
        - ``"test_metrics"``: scores over the held-out rows only,
        - ``"infer_batch_time_sec"``: seconds spent in one ``predict`` call
          over all rows,
        - ``"model_path"``: path of the pickled pipeline.
    """
    X = np.vstack([features_to_vector(device_type, feat) for feat in features_list])
    y = np.array(labels, dtype=float)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=TEST_SIZE, random_state=RANDOM_STATE
    )

    pipe = build_pipeline()
    pipe.fit(X_train, y_train)

    # Training rows come first and held-out rows last, so the tail of the
    # prediction vector lines up with y_test.
    X_all = np.vstack([X_train, X_test])
    y_all = np.concatenate([y_train, y_test])
    n_train = len(y_train)

    # perf_counter is monotonic and high-resolution; a wall-clock timer can
    # report 0.0 (or jump) for a call this short.
    t0 = time.perf_counter()
    pred_all = pipe.predict(X_all)
    elapsed = time.perf_counter() - t0

    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, "pipeline.pkl")
    with open(model_path, "wb") as f:
        pickle.dump(pipe, f)

    return {
        "metrics": compute_metrics(y_all, pred_all),
        # The combined score above includes rows the model was fitted on;
        # this one reflects only the rows held out from fitting.
        "test_metrics": compute_metrics(y_test, pred_all[n_train:]),
        "infer_batch_time_sec": float(elapsed),
        "model_path": model_path,
    }
|
||
|
||
if __name__ == "__main__":
    # Ad-hoc script run: train each device type whose dataset is present in
    # the current working directory, then run one sample prediction for each
    # model that exists on disk.
    # NOTE(review): the file ends with a second ``__main__`` guard that calls
    # main(), which repeats the training below - confirm which entry point
    # is meant to stay.
    root = os.getcwd()
    models_root = os.path.join(root, "models")
    cyl_dir = os.path.join(models_root, "cylindrical_tank")
    ring_dir = os.path.join(models_root, "ring_tank")

    tasks = [
        ("cylindrical_tank", os.path.join(root, "circle.xlsx"), cyl_dir),
        ("ring_tank", os.path.join(root, "ring.xlsx"), ring_dir),
    ]

    results = {}
    for dt, ds, md in tasks:
        if os.path.exists(ds):
            res = train_one_type(dt, ds, md)
            results[dt] = res
            print(f"[train] {dt} -> {res}")
        else:
            print(f"[skip] dataset not found: {ds}")

    # Sample predictions are attempted only when a saved pipeline is available.
    sample_cyl = {"直径": 160, "高度": 160, "铀浓度": 20, "铀富集度": 0.01}
    if os.path.exists(os.path.join(cyl_dir, "pipeline.pkl")):
        y_cyl = infer_one("cylindrical_tank", sample_cyl, cyl_dir)
        print(f"[infer] cylindrical_tank keff={y_cyl}")

    sample_ring = {"外径": 70, "高度": 70, "Pu浓度": 40, "Pu240占比": 0.05}
    if os.path.exists(os.path.join(ring_dir, "pipeline.pkl")):
        y_ring = infer_one("ring_tank", sample_ring, ring_dir)
        print(f"[infer] ring_tank keff={y_ring}")
|
||
|
||
def main():
    """Train both device types from the working directory and print sample predictions.

    Datasets and output directories are resolved relative to the current
    working directory. A device type whose dataset file is missing is skipped
    with a notice instead of aborting the run, and a sample prediction is only
    attempted when the corresponding ``pipeline.pkl`` exists.

    Prints one dict with the training results and one dict with the sample
    predictions.
    """
    tasks = [
        ("cylindrical_tank", "circle.xlsx", os.path.join("models", "cylindrical_tank")),
        ("ring_tank", "ring.xlsx", os.path.join("models", "ring_tank")),
    ]
    samples = {
        "cylindrical_tank": {"直径": 160, "高度": 160, "铀浓度": 20, "铀富集度": 0.01},
        "ring_tank": {"外径": 70, "高度": 70, "Pu浓度": 40, "Pu240占比": 0.05},
    }

    results = {}
    for t, path, out in tasks:
        if not os.path.exists(path):
            print(f"[skip] dataset not found: {path}")
            continue
        results[t] = train_one_type(t, path, out)
    print({"train_results": results})

    examples = {}
    for t, _, out in tasks:
        model_path = os.path.join(out, "pipeline.pkl")
        if not os.path.exists(model_path):
            print(f"[skip] model not found: {model_path}")
            continue
        features = samples[t]
        examples[t] = {"features": features, "keff": infer_one(t, features, out)}
    print({"infer_examples": examples})


if __name__ == "__main__":
    main()
|