"""Simulation management service.

Responsibilities:

- Simulation lifecycle (start/stop)
- Device access (currently MockVehicleDevice)
- Signal sampling and broadcasting (WebSocket)
- Signal persistence (TimescaleDB hypertable)
"""

from __future__ import annotations

import asyncio
import logging
import secrets
from dataclasses import dataclass
from datetime import datetime
from typing import Any

from sqlalchemy import insert
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from backend.database.schema import SimulationTask, vehicle_signals
from backend.device.mock_vehicle import MockVehicleDevice
from backend.services.broadcaster import Broadcaster
from backend.services.unity_socket_client import UnitySocketClient
from backend.utils import utc_now


# Module logger. The name is fixed to "backend.simulation" rather than
# derived from __name__.
logger: logging.Logger = logging.getLogger("backend.simulation")


@dataclass
class SimulationRuntime:
    """In-memory state of the simulation currently owned by the manager.

    Attributes:
        simulation_id: Identifier of the simulation; SimulationManager uses the
            same value as the persisted SimulationTask ``task_id``.
        status: Lifecycle state string; SimulationManager sets "running" on
            start and "stopping" while tearing down.
        task: Background asyncio task running the sampling loop, once created.
    """

    simulation_id: str
    status: str
    # Populated right after construction by SimulationManager.start().
    task: asyncio.Task[None] | None = None


class SimulationManager:
    """Simulation lifecycle manager.

    Holds at most one in-memory ``SimulationRuntime``, persists
    ``SimulationTask`` rows and sampled signals through ``session_factory``,
    forwards init/command messages to Unity when a client is configured, and
    publishes status/signal events through the broadcaster.
    """

    def __init__(
        self,
        session_factory: async_sessionmaker[AsyncSession],
        broadcaster: Broadcaster,
        unity_client: UnitySocketClient | None = None,
    ) -> None:
        self._session_factory = session_factory
        self._broadcaster = broadcaster
        self._unity_client = unity_client
        self._runtime: SimulationRuntime | None = None
        self._device = MockVehicleDevice()
        # Signal sequence counter. It is never reset, so it keeps increasing
        # across consecutive simulations.
        self._seq = 0
        # Auto-assigned seqId for commands sent without one.
        self._command_seq = 0

    @staticmethod
    def _as_dict(value: Any) -> dict[str, Any]:
        """Return *value* when it is a dict, otherwise an empty dict.

        This normalises optional payload sections. A missing, null or
        malformed section then reads as "no fields" instead of raising
        AttributeError on ``.get``.
        """
        return value if isinstance(value, dict) else {}

    @staticmethod
    def _new_simulation_id(now: datetime) -> str:
        """Build a ``SIM<YYYYmmddHHMMSS><4 uppercase hex chars>`` identifier."""
        return "SIM" + now.strftime("%Y%m%d%H%M%S") + secrets.token_hex(2).upper()

    def current(self) -> SimulationRuntime | None:
        """Return the active simulation runtime, or None when idle."""
        return self._runtime

    async def register_device(self, device: MockVehicleDevice) -> None:
        """Replace the device that the sampling loop reads from."""
        self._device = device

    async def init_config(self, init_config: dict[str, Any]) -> str:
        """Create or update a SimulationTask from an init payload.

        The payload is stored, forwarded to Unity (when configured) and
        broadcast as a ``simulation.init_config`` event.

        Args:
            init_config: Init payload. It reads the ``session`` section
                (taskId, taskName, syncTimestamp), the ``driver`` section
                (name) and the ``scene`` section (sceneId, sceneName).
                Sections that are missing or not dicts are treated as empty.

        Returns:
            The simulation id. This is ``session.taskId`` when it is non-blank
            and at most 64 characters; otherwise it is a generated ``SIM...``
            id.

        Raises:
            ValueError/TypeError: if ``session.syncTimestamp`` is present but
                cannot be converted with ``int()``.
        """
        session_info = self._as_dict(init_config.get("session"))
        driver_info = self._as_dict(init_config.get("driver"))
        scene_info = self._as_dict(init_config.get("scene"))

        now = utc_now()

        # Reuse the client's taskId as the primary key when it is usable.
        # NOTE(review): the 64-char cap presumably mirrors the task_id column
        # width; confirm against backend.database.schema.
        task_id = str(session_info.get("taskId") or "").strip()
        if task_id and len(task_id) <= 64:
            simulation_id = task_id
        else:
            simulation_id = self._new_simulation_id(now)

        raw_sync_ts = session_info.get("syncTimestamp")
        sync_timestamp = int(raw_sync_ts) if raw_sync_ts is not None else None
        task_name = session_info.get("taskName") or None
        scene_id = scene_info.get("sceneId") or None
        scene_name = scene_info.get("sceneName") or None
        operator = driver_info.get("name") or None

        async with self._session_factory() as session:
            sim = await session.get(SimulationTask, simulation_id)
            if sim is None:
                session.add(
                    SimulationTask(
                        task_id=simulation_id,
                        task_name=task_name,
                        scene_id=scene_id,
                        scene_name=scene_name,
                        scene_config=scene_info,
                        config_created_at=now,
                        started_at=now,
                        ended_at=None,
                        status="wait",
                        operator=operator,
                        unity_host=self._unity_client.host if self._unity_client is not None else None,
                        unity_port=self._unity_client.port if self._unity_client is not None else None,
                        sync_timestamp=sync_timestamp,
                        init_config=init_config,
                        init_sent_at=None,
                    )
                )
            else:
                # Re-initialising an existing task refreshes its configuration
                # but leaves status, started_at and ended_at untouched.
                sim.task_name = task_name
                sim.scene_id = scene_id
                sim.scene_name = scene_name
                sim.scene_config = scene_info
                sim.config_created_at = now
                sim.operator = operator
                sim.sync_timestamp = sync_timestamp
                sim.init_config = init_config
                if self._unity_client is not None:
                    sim.unity_host = self._unity_client.host
                    sim.unity_port = self._unity_client.port
            await session.commit()

        if self._unity_client is not None:
            payload = dict(init_config)
            payload.setdefault("msgType", "init")
            await self._unity_client.send_json(payload)
            # Record the hand-off in a separate transaction, but only after
            # send_json returned without raising.
            async with self._session_factory() as session:
                sim = await session.get(SimulationTask, simulation_id)
                if sim is not None:
                    sim.init_sent_at = utc_now()
                    await session.commit()

        await self._broadcaster.broadcast_json(
            {
                "type": "simulation.init_config",
                "ts": now.timestamp(),
                "simulation_id": simulation_id,
                "payload": init_config,
            }
        )
        return simulation_id

    async def send_command(self, command: dict[str, Any]) -> None:
        """Send a control command to Unity and broadcast it.

        Missing fields are filled in: ``msgType`` ("command"), ``timestamp``
        (milliseconds, from ``utc_now().timestamp()``) and ``seqId`` (from an
        internal counter). Caller-supplied values are kept. The caller's dict
        is shallow-copied, not mutated.

        Raises:
            RuntimeError: if no Unity client was configured.
        """
        if self._unity_client is None:
            raise RuntimeError("unity client not configured")
        payload = dict(command)
        payload.setdefault("msgType", "command")
        payload.setdefault("timestamp", int(utc_now().timestamp() * 1000))
        if "seqId" not in payload:
            self._command_seq += 1
            payload["seqId"] = self._command_seq
        await self._unity_client.send_json(payload)
        await self._broadcaster.broadcast_json(
            {"type": "simulation.command", "ts": utc_now().timestamp(), "payload": payload}
        )

    async def start(self, scenario_config: dict[str, Any]) -> str:
        """Start a new simulation and return its simulation_id.

        The call is idempotent while a simulation is running: the current id
        is returned and nothing else happens. Otherwise the method:

        1. inserts a SimulationTask row with status "running";
        2. connects the device;
        3. schedules the sampling loop;
        4. broadcasts a status event;
        5. sends ``{"action": "start"}`` to Unity, if a client is configured.

        Args:
            scenario_config: Stored as ``scene_config``. Its ``scenario`` key
                becomes the task name, and ``driver`` (falling back to
                ``operator``) becomes the operator.
        """
        runtime = self._runtime
        if runtime is not None and runtime.status == "running":
            return runtime.simulation_id

        started_at = utc_now()
        simulation_id = self._new_simulation_id(started_at)
        task_name = scenario_config.get("scenario")
        operator = scenario_config.get("driver") or scenario_config.get("operator")

        async with self._session_factory() as session:
            session.add(
                SimulationTask(
                    task_id=simulation_id,
                    task_name=task_name,
                    scene_id=None,
                    scene_name=None,
                    scene_config=scenario_config,
                    config_created_at=started_at,
                    started_at=started_at,
                    ended_at=None,
                    status="running",
                    operator=operator,
                    unity_host=self._unity_client.host if self._unity_client is not None else None,
                    unity_port=self._unity_client.port if self._unity_client is not None else None,
                    sync_timestamp=None,
                    init_config=None,
                    init_sent_at=None,
                )
            )
            await session.commit()

        # NOTE(review): the row is already committed as "running" here. If
        # connect() raises, the row stays "running" with no runtime to stop it.
        await self._device.connect()
        self._runtime = SimulationRuntime(simulation_id=simulation_id, status="running")
        self._runtime.task = asyncio.create_task(self._run_loop(simulation_id))
        await self._broadcaster.broadcast_json(
            {
                "type": "simulation.status",
                "ts": started_at.timestamp(),
                "simulation_id": simulation_id,
                "payload": {"status": "running"},
            }
        )
        if self._unity_client is not None:
            await self.send_command({"action": "start"})
        return simulation_id

    async def stop(self, simulation_id: str) -> None:
        """Stop the active simulation if its id matches *simulation_id*.

        Does nothing when no simulation is active or the id does not match.
        Otherwise the method:

        1. cancels the sampling loop;
        2. disconnects the device;
        3. marks the SimulationTask row "stopped" with ``ended_at``;
        4. broadcasts the status change;
        5. sends ``{"action": "stop"}`` to Unity, if a client is configured.
        """
        runtime = self._runtime
        if runtime is None or runtime.simulation_id != simulation_id:
            return

        runtime.status = "stopping"
        if runtime.task is not None:
            runtime.task.cancel()
            try:
                await runtime.task
            except asyncio.CancelledError:
                # Expected: the loop was cancelled just above.
                pass

        await self._device.disconnect()
        ended_at = utc_now()

        async with self._session_factory() as session:
            sim = await session.get(SimulationTask, simulation_id)
            if sim:
                sim.status = "stopped"
                sim.ended_at = ended_at
                await session.commit()

        await self._broadcaster.broadcast_json(
            {
                "type": "simulation.status",
                "ts": ended_at.timestamp(),
                "simulation_id": simulation_id,
                "payload": {"status": "stopped"},
            }
        )
        if self._unity_client is not None:
            await self.send_command({"action": "stop"})

        # Every await above yields to the event loop, and start() proceeds
        # whenever the status is not "running". A new simulation may
        # therefore have been installed meanwhile. Only clear the slot if it
        # still holds the runtime stopped here; otherwise the new simulation's
        # task would be orphaned.
        if self._runtime is runtime:
            self._runtime = None

    async def _run_loop(self, simulation_id: str) -> None:
        """Sample the device, broadcast each sample, then persist it.

        Each iteration sleeps 50 ms and skips sampling while the device
        reports not connected. The loop runs until cancelled. Any other
        exception is logged and ends the loop; the runtime status is not
        updated in that case.
        """
        try:
            while True:
                await asyncio.sleep(0.05)
                if not await self._device.is_connected():
                    continue

                self._seq += 1
                ts = utc_now()
                payload = self._device.sample().to_dict()
                message = {
                    "type": "vehicle.signal",
                    "ts": ts.timestamp(),
                    "simulation_id": simulation_id,
                    "device_id": self._device.device_id,
                    "seq": self._seq,
                    "payload": payload,
                }
                await self._broadcaster.broadcast_json(message)
                await self._persist_signal(ts, simulation_id, self._device.device_id, self._seq, payload)
        except asyncio.CancelledError:
            # Re-raise explicitly. Before Python 3.8, CancelledError derived
            # from Exception and would otherwise be logged as a crash below.
            raise
        except Exception:
            logger.exception("simulation loop crashed")

    async def _persist_signal(
        self,
        ts: datetime,
        simulation_id: str,
        device_id: str,
        seq: int,
        signals: dict[str, Any],
    ) -> None:
        """Insert one signal row into ``vehicle_signals`` and commit it.

        The target table is described as sim_vehicle_signals (TimescaleDB).
        Opens a fresh session, and so a fresh transaction, for each sample.
        """
        async with self._session_factory() as session:
            await session.execute(
                insert(vehicle_signals).values(
                    ts=ts,
                    simulation_id=simulation_id,
                    device_id=device_id,
                    seq=seq,
                    signals=signals,
                )
            )
            await session.commit()