JavaProjectRepo/python-ml/app.py
2026-01-05 15:18:21 +08:00

73 lines
2.8 KiB
Python

from flask import Flask, request, jsonify
import os
try:
from .service import train_one_type, infer_one, infer_batch, train_one_type_from_samples
except ImportError:
import sys
sys.path.append(os.path.dirname(__file__))
from service import train_one_type, infer_one, infer_batch, train_one_type_from_samples
# The single WSGI application object; every route in this module is
# registered on it via the ``@app.post`` decorators below.
app = Flask(import_name=__name__)
def ensure_model_dir(device_type, model_dir):
    """Resolve the directory that holds the model for ``device_type``.

    Despite its name, this only computes a path: it does not create the
    directory or check that it exists.

    Args:
        device_type: Device-type identifier, used as the sub-directory name
            under ``models`` when no explicit directory is given.
        model_dir: Caller-supplied directory. Any truthy value is returned
            unchanged.

    Returns:
        ``model_dir`` when it is truthy, otherwise
        ``os.path.join("models", device_type)`` (a path relative to the
        process working directory).
    """
    if model_dir:
        return model_dir
    return os.path.join("models", device_type)
@app.post("/v1/train/<device_type>")
def train(device_type):
    """Train a model for ``device_type`` from a dataset file on this host.

    Request body (JSON object):
        dataset_path (str, required): path to an existing dataset file.
        model_dir (str, optional): output directory; when absent or empty,
            ``ensure_model_dir`` falls back to ``models/<device_type>``.

    Returns:
        200 with ``{"code": 0, "msg": "ok", "data": <train_one_type result>}``.
        400 with ``{"code": 1, "msg": ...}`` when the body is not a JSON
        object, or ``dataset_path`` is missing, not a string, or does not exist.
    """
    # force=True parses the body regardless of Content-Type; malformed JSON
    # is rejected by Flask itself with a 400.
    data = request.get_json(force=True)
    # A valid JSON body that is not an object (list, null, number) would
    # otherwise crash on ``data.get`` with a 500.
    if not isinstance(data, dict):
        return jsonify({"code": 1, "msg": "JSON object body required"}), 400
    dataset_path = data.get("dataset_path")
    model_dir = data.get("model_dir")
    # NOTE(review): dataset_path and model_dir come straight from the client
    # and may name any location on the filesystem; restrict them to an
    # allowed root if untrusted callers can reach this service.
    # The str check matters: os.path.exists() treats an int as a file
    # descriptor, and a list raises TypeError.
    if (
        not dataset_path
        or not isinstance(dataset_path, str)
        or not os.path.exists(dataset_path)
    ):
        return jsonify({"code": 1, "msg": "dataset_path not found"}), 400
    out_dir = ensure_model_dir(device_type, model_dir)
    res = train_one_type(device_type, dataset_path, out_dir)
    return jsonify({"code": 0, "msg": "ok", "data": res})
@app.post("/v1/train/<device_type>/from-samples")
def train_from_samples(device_type):
    """Train a model for ``device_type`` from samples sent in the request.

    Request body (JSON object):
        samples (list, required): non-empty list of objects, each carrying
            a non-null ``features`` and a non-null ``label``.
        model_dir (str, optional): output directory; when absent or empty,
            ``ensure_model_dir`` falls back to ``models/<device_type>``.

    Returns:
        200 with ``{"code": 0, "msg": "ok", "data": <training result>}``.
        400 with ``{"code": 1, "msg": ...}`` when the body is not a JSON
        object or ``samples`` is missing, empty, or malformed.
    """
    data = request.get_json(force=True)
    # A valid JSON body that is not an object would crash on ``data.get``.
    if not isinstance(data, dict):
        return jsonify({"code": 1, "msg": "JSON object body required"}), 400
    samples = data.get("samples")
    model_dir = data.get("model_dir")
    if not samples:
        return jsonify({"code": 1, "msg": "samples required"}), 400
    # A dict or string would be iterated by key/character and then crash
    # on ``.get``; reject it explicitly instead.
    if not isinstance(samples, list):
        return jsonify({"code": 1, "msg": "samples must be a list"}), 400
    for i, sample in enumerate(samples):
        # Previously a missing field was silently passed to training as None.
        if (
            not isinstance(sample, dict)
            or sample.get("features") is None
            or sample.get("label") is None
        ):
            msg = f"samples[{i}] must be an object with features and label"
            return jsonify({"code": 1, "msg": msg}), 400
    feats = [sample["features"] for sample in samples]
    labels = [sample["label"] for sample in samples]
    # NOTE(review): model_dir is client-controlled and may point anywhere on
    # the filesystem; restrict it if untrusted callers can reach this route.
    out_dir = ensure_model_dir(device_type, model_dir)
    res = train_one_type_from_samples(device_type, feats, labels, out_dir)
    return jsonify({"code": 0, "msg": "ok", "data": res})
@app.post("/v1/infer/<device_type>/keff")
def infer(device_type):
    """Predict keff for ``device_type`` for a single sample or a batch.

    Request body (JSON object):
        model_dir (str, optional): model directory; when absent or empty,
            ``ensure_model_dir`` falls back to ``models/<device_type>``.
        batch (list, optional): when non-empty, selects batch mode. Each
            entry is either ``{"features": ..., "meta": {...}}`` or a bare
            features value (its meta is then ``{}``).
        features (required when ``batch`` is absent or empty): features for
            a single prediction.
        meta (optional): echoed back in single mode; defaults to ``{}``.

    Returns:
        Batch mode: ``{"code": 0, "msg": "ok", "data": {"items": [...]}}``,
        one ``{"meta", "features", "keff"}`` item per entry in input order.
        Single mode: ``{"code": 0, "msg": "ok", "data": {"meta", "features",
        "keff"}}``.
        400 with ``{"code": 1, "msg": ...}`` when the body is not a JSON
        object, ``batch`` is not a list, or ``features`` is missing in
        single mode.
    """
    data = request.get_json(force=True)
    # A valid JSON body that is not an object would crash on ``data.get``.
    if not isinstance(data, dict):
        return jsonify({"code": 1, "msg": "JSON object body required"}), 400
    out_dir = ensure_model_dir(device_type, data.get("model_dir"))
    batch = data.get("batch")
    features = data.get("features")
    meta = data.get("meta") or {}

    if batch:
        # A truthy non-list (e.g. a dict) would otherwise be iterated by key
        # and each key treated as a feature vector.
        if not isinstance(batch, list):
            return jsonify({"code": 1, "msg": "batch must be a list"}), 400
        feats_list = []
        metas = []
        for entry in batch:
            # Entries may be wrapped ({"features": ..., "meta": ...}) or raw.
            if isinstance(entry, dict) and "features" in entry:
                feats_list.append(entry["features"])
                metas.append(entry.get("meta") or {})
            else:
                feats_list.append(entry)
                metas.append({})
        # Presumably infer_batch returns one prediction per input, in input
        # order -- confirm against service.infer_batch.
        ys = infer_batch(device_type, feats_list, out_dir)
        items = [
            {"meta": entry_meta, "features": entry_feats, "keff": y}
            for entry_meta, entry_feats, y in zip(metas, feats_list, ys)
        ]
        return jsonify({"code": 0, "msg": "ok", "data": {"items": items}})

    if not features:
        return jsonify({"code": 1, "msg": "features required"}), 400
    y = infer_one(device_type, features, out_dir)
    return jsonify({"code": 0, "msg": "ok", "data": {"meta": meta, "features": features, "keff": y}})
if __name__ == "__main__":
    # Debug mode enables the Werkzeug interactive debugger, which allows
    # arbitrary code execution; combined with host="0.0.0.0" that is exposed
    # to the whole network. Keep it off unless explicitly requested, e.g.
    # FLASK_DEBUG=1 for local development.
    debug = os.environ.get("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes", "on")
    # Port stays 8000 by default; PORT allows overriding it without edits.
    port = int(os.environ.get("PORT", "8000"))
    app.run(host="0.0.0.0", port=port, debug=debug)