SmartEDT/backend/services/simulation_manager.py

269 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""仿真管理服务。
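Simulation management service.

Responsibilities:
- Simulation lifecycle (start/stop)
- Device attachment (currently MockVehicleDevice)
- Signal sampling and broadcasting (WebSocket)
- Signal persistence (TimescaleDB hypertable)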
"""
from __future__ import annotations
import asyncio
import logging
import secrets
from dataclasses import dataclass
from typing import Any
from sqlalchemy import insert
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from backend.database.schema import SimulationTask, vehicle_signals
from backend.device.mock_vehicle import MockVehicleDevice
from backend.services.broadcaster import Broadcaster
from backend.services.unity_socket_client import UnitySocketClient
from backend.utils import utc_now
logger = logging.getLogger("backend.simulation")
@dataclass
class SimulationRuntime:
    """In-memory record describing the simulation that is currently active.

    Instances are created by ``SimulationManager.start`` and are never
    persisted; the database row for the same simulation is maintained
    separately.
    """

    # Identifier of the simulation this record describes.
    simulation_id: str
    # Lifecycle state; the manager assigns "running" and "stopping".
    status: str
    # Background sampling task, or None while no task has been attached.
    task: asyncio.Task | None = None
class SimulationManager:
    """Manage the lifecycle of a simulation run.

    The manager owns the in-memory runtime record, samples the registered
    device in a background task, broadcasts every message through the
    broadcaster, persists sampled signals to the database, and optionally
    forwards init/command messages to a Unity socket client.
    """

    # Pause between two iterations of the sampling loop, in seconds.
    SAMPLE_INTERVAL_SECONDS = 0.05
    # Longest client-supplied task id that is accepted as the simulation id.
    MAX_TASK_ID_LENGTH = 64

    def __init__(
        self,
        session_factory: async_sessionmaker[AsyncSession],
        broadcaster: Broadcaster,
        unity_client: UnitySocketClient | None = None,
    ) -> None:
        """Store collaborators and reset all counters.

        Args:
            session_factory: Factory producing async database sessions.
            broadcaster: Sink used to publish JSON messages.
            unity_client: Optional Unity socket client; when None, nothing is
                forwarded to Unity and ``send_command`` raises.
        """
        self._session_factory = session_factory
        self._broadcaster = broadcaster
        self._unity_client = unity_client
        self._runtime: SimulationRuntime | None = None
        self._device = MockVehicleDevice()
        # Sequence number attached to sampled signals.
        self._seq = 0
        # Sequence number attached to commands that carry no "seqId".
        self._command_seq = 0

    @staticmethod
    def _new_simulation_id() -> str:
        """Build an id of the form ``SIM<YYYYmmddHHMMSS><4 hex chars>``."""
        return "SIM" + utc_now().strftime("%Y%m%d%H%M%S") + secrets.token_hex(2).upper()

    @staticmethod
    def _section(config: dict[str, Any], key: str) -> dict[str, Any]:
        """Return ``config[key]`` when it is a dict, otherwise an empty dict."""
        value = config.get(key)
        return value if isinstance(value, dict) else {}

    def _unity_endpoint(self) -> tuple[Any, Any]:
        """Return the Unity client's ``(host, port)``, or ``(None, None)``."""
        if self._unity_client is None:
            return None, None
        return self._unity_client.host, self._unity_client.port

    def current(self) -> SimulationRuntime | None:
        """Return the current runtime record, or None when there is none."""
        return self._runtime

    async def register_device(self, device: MockVehicleDevice) -> None:
        """Replace the device implementation used for sampling."""
        self._device = device

    async def init_config(self, init_config: dict[str, Any]) -> str:
        """Persist an init configuration, forward it to Unity and broadcast it.

        The simulation id is ``session.taskId`` (stripped) when it is non-empty
        and at most ``MAX_TASK_ID_LENGTH`` characters long; otherwise a fresh id
        is generated. An existing task row with that id is updated in place,
        otherwise a new row with status "wait" is inserted.

        Args:
            init_config: Configuration payload. The "session", "driver" and
                "scene" entries are read; an entry that is missing or not a
                dict is treated as empty.

        Returns:
            The simulation id the configuration was stored under.

        Raises:
            ValueError, TypeError: If ``session.syncTimestamp`` is present but
                cannot be converted with ``int()``.
        """
        session_info = self._section(init_config, "session")
        driver_info = self._section(init_config, "driver")
        scene_info = self._section(init_config, "scene")

        task_id = str(session_info.get("taskId") or "").strip()
        if task_id and len(task_id) <= self.MAX_TASK_ID_LENGTH:
            simulation_id = task_id
        else:
            simulation_id = self._new_simulation_id()

        now = utc_now()
        task_name = session_info.get("taskName") or None
        raw_sync_timestamp = session_info.get("syncTimestamp")
        sync_timestamp = int(raw_sync_timestamp) if raw_sync_timestamp is not None else None
        scene_id = scene_info.get("sceneId") or None
        scene_name = scene_info.get("sceneName") or None
        operator = driver_info.get("name") or None
        unity_host, unity_port = self._unity_endpoint()

        async with self._session_factory() as session:
            sim = await session.get(SimulationTask, simulation_id)
            if sim is None:
                session.add(
                    SimulationTask(
                        task_id=simulation_id,
                        task_name=task_name,
                        scene_id=scene_id,
                        scene_name=scene_name,
                        scene_config=scene_info,
                        config_created_at=now,
                        started_at=now,
                        ended_at=None,
                        status="wait",
                        operator=operator,
                        unity_host=unity_host,
                        unity_port=unity_port,
                        sync_timestamp=sync_timestamp,
                        init_config=init_config,
                        init_sent_at=None,
                    )
                )
            else:
                sim.task_name = task_name
                sim.scene_id = scene_id
                sim.scene_name = scene_name
                sim.scene_config = scene_info
                sim.config_created_at = now
                sim.operator = operator
                sim.sync_timestamp = sync_timestamp
                sim.init_config = init_config
                # Keep the stored endpoint when no Unity client is configured.
                if self._unity_client is not None:
                    sim.unity_host = unity_host
                    sim.unity_port = unity_port
            await session.commit()

        if self._unity_client is not None:
            # Copy so the caller's dict is not mutated by the msgType default.
            payload = dict(init_config)
            payload.setdefault("msgType", "init")
            await self._unity_client.send_json(payload)
            # Record the send time only after the send call has returned.
            async with self._session_factory() as session:
                sim = await session.get(SimulationTask, simulation_id)
                if sim is not None:
                    sim.init_sent_at = utc_now()
                    await session.commit()

        await self._broadcaster.broadcast_json(
            {"type": "simulation.init_config", "ts": now.timestamp(), "simulation_id": simulation_id, "payload": init_config}
        )
        return simulation_id

    async def send_command(self, command: dict[str, Any]) -> None:
        """Send a command to Unity and broadcast the payload that was sent.

        Missing "msgType" and "timestamp" (milliseconds) fields are filled in,
        and a "seqId" is assigned from an internal counter when absent. The
        caller's dict is not modified.

        Args:
            command: Command fields to send.

        Raises:
            RuntimeError: If no Unity client is configured.
        """
        if self._unity_client is None:
            raise RuntimeError("unity client not configured")
        payload = dict(command)
        payload.setdefault("msgType", "command")
        payload.setdefault("timestamp", int(utc_now().timestamp() * 1000))
        if "seqId" not in payload:
            self._command_seq += 1
            payload["seqId"] = self._command_seq
        await self._unity_client.send_json(payload)
        await self._broadcaster.broadcast_json(
            {"type": "simulation.command", "ts": utc_now().timestamp(), "payload": payload}
        )

    async def start(self, scenario_config: dict[str, Any]) -> str:
        """Start a simulation and return its simulation id.

        The call is idempotent while a simulation is running: the id of the
        running simulation is returned and nothing else happens.

        Args:
            scenario_config: Scenario settings. "scenario" is stored as the
                task name, and "driver" (falling back to "operator") as the
                operator; the whole dict is stored as the scene config.

        Returns:
            The id of the started (or already running) simulation.
        """
        if self._runtime and self._runtime.status == "running":
            return self._runtime.simulation_id

        simulation_id = self._new_simulation_id()
        started_at = utc_now()
        unity_host, unity_port = self._unity_endpoint()

        async with self._session_factory() as session:
            session.add(
                SimulationTask(
                    task_id=simulation_id,
                    task_name=scenario_config.get("scenario"),
                    scene_id=None,
                    scene_name=None,
                    scene_config=scenario_config,
                    config_created_at=started_at,
                    started_at=started_at,
                    ended_at=None,
                    status="running",
                    operator=scenario_config.get("driver") or scenario_config.get("operator"),
                    unity_host=unity_host,
                    unity_port=unity_port,
                    sync_timestamp=None,
                    init_config=None,
                    init_sent_at=None,
                )
            )
            await session.commit()

        await self._device.connect()
        self._runtime = SimulationRuntime(simulation_id=simulation_id, status="running")
        self._runtime.task = asyncio.create_task(self._run_loop(simulation_id))
        await self._broadcaster.broadcast_json(
            {"type": "simulation.status", "ts": started_at.timestamp(), "simulation_id": simulation_id, "payload": {"status": "running"}}
        )
        if self._unity_client is not None:
            await self.send_command({"action": "start"})
        return simulation_id

    async def stop(self, simulation_id: str) -> None:
        """Stop the simulation; ignored when the id is not the current one.

        The sampling task is cancelled and awaited, the device disconnected,
        the task row marked "stopped", the status broadcast, and finally Unity
        is told to stop.

        Args:
            simulation_id: Id of the simulation to stop.
        """
        runtime = self._runtime
        if runtime is None or runtime.simulation_id != simulation_id:
            return

        runtime.status = "stopping"
        if runtime.task:
            runtime.task.cancel()
            try:
                await runtime.task
            except asyncio.CancelledError:
                pass
        await self._device.disconnect()

        ended_at = utc_now()
        async with self._session_factory() as session:
            sim = await session.get(SimulationTask, simulation_id)
            if sim:
                sim.status = "stopped"
                sim.ended_at = ended_at
                await session.commit()
        await self._broadcaster.broadcast_json(
            {"type": "simulation.status", "ts": ended_at.timestamp(), "simulation_id": simulation_id, "payload": {"status": "stopped"}}
        )

        # Everything local is stopped at this point, so drop the runtime before
        # notifying Unity: a failed send must not leave a stale "stopping"
        # record behind. Only clear it if it is still the one being stopped, so
        # a runtime installed meanwhile by start() is not discarded.
        if self._runtime is runtime:
            self._runtime = None
        if self._unity_client is not None:
            await self.send_command({"action": "stop"})

    async def _run_loop(self, simulation_id: str) -> None:
        """Sample the device, broadcast each sample and write it to the database.

        Runs until cancelled. Any other exception is logged and ends the loop;
        the runtime record is not updated in that case.
        """
        try:
            while True:
                await asyncio.sleep(self.SAMPLE_INTERVAL_SECONDS)
                if not await self._device.is_connected():
                    continue
                self._seq += 1
                ts = utc_now()
                payload = self._device.sample().to_dict()
                message = {
                    "type": "vehicle.signal",
                    "ts": ts.timestamp(),
                    "simulation_id": simulation_id,
                    "device_id": self._device.device_id,
                    "seq": self._seq,
                    "payload": payload,
                }
                await self._broadcaster.broadcast_json(message)
                await self._persist_signal(ts, simulation_id, self._device.device_id, self._seq, payload)
        except asyncio.CancelledError:
            # Cancellation is the normal shutdown path; let it propagate.
            raise
        except Exception:
            logger.exception("simulation loop crashed")

    async def _persist_signal(self, ts, simulation_id: str, device_id: str, seq: int, signals: dict[str, Any]) -> None:
        """Insert one sampled signal row into the ``vehicle_signals`` table.

        Args:
            ts: Sample timestamp as returned by ``utc_now()``.
            simulation_id: Simulation the sample belongs to.
            device_id: Id of the device that produced the sample.
            seq: Sequence number of the sample.
            signals: Signal values stored in the row's ``signals`` column.
        """
        async with self._session_factory() as session:
            await session.execute(
                insert(vehicle_signals).values(
                    ts=ts,
                    simulation_id=simulation_id,
                    device_id=device_id,
                    seq=seq,
                    signals=signals,
                )
            )
            await session.commit()