73 lines
2.8 KiB
Python
73 lines
2.8 KiB
Python
from flask import Flask, request, jsonify
|
|
import os
|
|
try:
|
|
from .service import train_one_type, infer_one, infer_batch, train_one_type_from_samples
|
|
except ImportError:
|
|
import sys
|
|
sys.path.append(os.path.dirname(__file__))
|
|
from service import train_one_type, infer_one, infer_batch, train_one_type_from_samples
|
|
|
|
# WSGI application object; every HTTP endpoint in this module is registered
# on it through the ``@app.post`` route decorators.
app = Flask(import_name=__name__)
|
|
|
|
def ensure_model_dir(device_type, model_dir):
    """Resolve the model directory to use for *device_type*.

    A truthy *model_dir* is returned unchanged. Otherwise the path falls back
    to ``models/<device_type>``, relative to the current working directory.

    Despite the name, nothing is created on disk: this only computes a path.

    NOTE(review): the callers in this module pass a URL segment as
    *device_type* and a JSON body field as *model_dir*. Both are
    client-controlled, so consider confining the result to an allowed root.
    """
    if model_dir:
        return model_dir
    return os.path.join("models", device_type)
|
|
|
|
@app.post("/v1/train/<device_type>")
def train(device_type):
    """Train a model for *device_type* from a dataset file on this host.

    Request body (JSON object):
        dataset_path: required; path to an existing dataset on the server.
        model_dir: optional; output directory, defaulting to
            ``models/<device_type>`` (see ``ensure_model_dir``).

    Returns:
        ``{"code": 0, "msg": "ok", "data": <train_one_type result>}`` on
        success, or ``{"code": 1, "msg": ...}`` with HTTP 400 on bad input.
    """
    data = request.get_json(force=True)
    # get_json(force=True) returns whatever valid JSON was sent (list, string,
    # number, or None for ``null``). Reject non-objects up front instead of
    # failing with an AttributeError (HTTP 500) on ``data.get``.
    if not isinstance(data, dict):
        return jsonify({"code": 1, "msg": "JSON object body required"}), 400

    dataset_path = data.get("dataset_path")
    model_dir = data.get("model_dir")

    # Require a real string: os.path.exists() also accepts integer file
    # descriptors, so e.g. ``0`` would otherwise pass this check.
    # NOTE(review): dataset_path is a client-controlled server path; consider
    # restricting it to an allowed data root.
    if (
        not isinstance(dataset_path, str)
        or not dataset_path
        or not os.path.exists(dataset_path)
    ):
        return jsonify({"code": 1, "msg": "dataset_path not found"}), 400

    out_dir = ensure_model_dir(device_type, model_dir)
    res = train_one_type(device_type, dataset_path, out_dir)
    return jsonify({"code": 0, "msg": "ok", "data": res})
|
|
|
|
@app.post("/v1/train/<device_type>/from-samples")
def train_from_samples(device_type):
    """Train a model for *device_type* from samples sent in the request body.

    Request body (JSON object):
        samples: required non-empty list of objects, each carrying
            ``features`` and ``label``.
        model_dir: optional output directory, defaulting to
            ``models/<device_type>`` (see ``ensure_model_dir``).

    Returns:
        ``{"code": 0, "msg": "ok", "data": <trainer result>}`` on success,
        or ``{"code": 1, "msg": ...}`` with HTTP 400 on bad input.
    """
    data = request.get_json(force=True)
    # Valid non-object JSON (``null``, ``[]``, ...) would crash ``data.get``.
    if not isinstance(data, dict):
        return jsonify({"code": 1, "msg": "JSON object body required"}), 400

    samples = data.get("samples")
    model_dir = data.get("model_dir")

    # A dict or string here would be iterated key-by-key / char-by-char.
    if not samples or not isinstance(samples, list):
        return jsonify({"code": 1, "msg": "samples required"}), 400

    # Validate every entry before training so a malformed sample yields a
    # 400 naming the offending index instead of a 500 from inside the trainer.
    for idx, sample in enumerate(samples):
        if not isinstance(sample, dict):
            return jsonify({"code": 1, "msg": f"samples[{idx}] must be an object"}), 400
        if sample.get("features") is None or sample.get("label") is None:
            return jsonify(
                {"code": 1, "msg": f"samples[{idx}] requires 'features' and 'label'"}
            ), 400

    feats = [sample["features"] for sample in samples]
    labels = [sample["label"] for sample in samples]

    out_dir = ensure_model_dir(device_type, model_dir)
    res = train_one_type_from_samples(device_type, feats, labels, out_dir)
    return jsonify({"code": 0, "msg": "ok", "data": res})
|
|
|
|
def _split_batch(batch):
    """Split batch entries into parallel ``(features_list, metas)`` lists.

    An entry is either an object with a ``features`` key (and an optional
    ``meta``) or a bare features value. Bare entries get an empty meta dict.
    """
    feats_list = []
    metas = []
    for entry in batch:
        if isinstance(entry, dict) and "features" in entry:
            feats_list.append(entry["features"])
            metas.append(entry.get("meta") or {})
        else:
            feats_list.append(entry)
            metas.append({})
    return feats_list, metas


@app.post("/v1/infer/<device_type>/keff")
def infer(device_type):
    """Predict keff for one feature set or a batch of them.

    Request body (JSON object):
        batch: optional non-empty list; when present it takes precedence,
            and each entry is handled as described in ``_split_batch``.
        features: required when ``batch`` is absent or empty.
        meta: optional; echoed back in the single-item response.
        model_dir: optional model directory (see ``ensure_model_dir``).

    Returns:
        Batch mode: ``{"code": 0, "msg": "ok", "data": {"items": [...]}}``,
        one item (meta, features, keff) per input entry.
        Single mode: ``{"code": 0, "msg": "ok", "data": {meta, features, keff}}``.
        On bad input: ``{"code": 1, "msg": ...}`` with HTTP 400.
    """
    data = request.get_json(force=True)
    # Valid non-object JSON (``null``, ``[]``, ...) would crash ``data.get``.
    if not isinstance(data, dict):
        return jsonify({"code": 1, "msg": "JSON object body required"}), 400

    out_dir = ensure_model_dir(device_type, data.get("model_dir"))
    batch = data.get("batch")
    features = data.get("features")
    meta = data.get("meta") or {}

    if batch:
        # A truthy dict/string would be iterated key-by-key / char-by-char,
        # and a number would fail on len(); only lists are meaningful here.
        if not isinstance(batch, list):
            return jsonify({"code": 1, "msg": "batch must be a list"}), 400
        feats_list, metas = _split_batch(batch)
        ys = infer_batch(device_type, feats_list, out_dir)
        # Assumes infer_batch returns one prediction per input, in input order.
        items = [
            {"meta": item_meta, "features": item_feats, "keff": y}
            for item_meta, item_feats, y in zip(metas, feats_list, ys)
        ]
        return jsonify({"code": 0, "msg": "ok", "data": {"items": items}})

    if not features:
        return jsonify({"code": 1, "msg": "features required"}), 400

    y = infer_one(device_type, features, out_dir)
    return jsonify(
        {"code": 0, "msg": "ok", "data": {"meta": meta, "features": features, "keff": y}}
    )
|
|
|
|
if __name__ == "__main__":
    # Never enable the Werkzeug interactive debugger by default. Combined with
    # binding to all interfaces, it lets any network client execute arbitrary
    # code. Opt in for local development with FLASK_DEBUG=1.
    # HOST/PORT default to the previous hard-coded values.
    app.run(
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8000")),
        debug=os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true", "yes"),
    )
|