SmartEDT/backend/database/test_db.py

167 lines
6.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""数据库性能与功能测试脚本(开发/压测用途)。
该脚本会:
- 初始化 schema 与 TimescaleDB(若可用)
- 批量写入模拟车辆信号(JSONB)
- 运行几类常见查询并输出耗时
注意:该脚本会写入大量数据,请不要在生产库中执行。
"""
import asyncio
import os
import time
import json
import random
from datetime import datetime, timezone
from sqlalchemy import insert, select, text
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.engine.url import make_url
from backend.database.schema import SimulationTask, unity_frames, vehicle_signals, init_schema, init_timescaledb
from backend.config.settings import load_settings
# 模拟数据生成
def generate_payload():
    """Build one synthetic vehicle-signal payload for the JSONB ``signals`` column.

    All values come from the module-level ``random`` generator, so seeding
    ``random`` beforehand makes the payload reproducible. Continuous values
    are rounded to one decimal place. Wheel speeds are integers in
    [0, 2000], and on/off flags are 0 or 1.
    """

    def rounded(low, high):
        # Uniform float between low and high, rounded to 1 decimal place.
        return round(random.uniform(low, high), 1)

    def flag():
        # Binary on/off indicator.
        return random.choice([0, 1])

    # Entries are listed (and therefore drawn from the RNG) in a fixed order,
    # so a given seed always yields the same payload.
    return {
        "steering_wheel_angle_deg": rounded(-450, 450),
        "brake_pedal_travel_mm": rounded(0, 100),
        "throttle_pedal_travel_mm": rounded(0, 100),
        "gear": random.choice(["P", "N", "D", "R"]),
        "handbrake": flag(),
        "vehicle_speed_kmh": rounded(0, 180),
        "wheel_speed_rpm": {
            wheel: random.randint(0, 2000) for wheel in ("FL", "FR", "RL", "RR")
        },
        "lights": {
            lamp: flag() for lamp in ("left_turn", "right_turn", "hazard", "brake")
        },
        "soc_percent": rounded(0, 100),
        "voltage_v": rounded(300, 400),
        "current_a": rounded(-50, 200),
        "temperature_c": rounded(20, 80),
    }
def _redact_url(url: str) -> str:
"""隐藏数据库 URL 中的密码,避免误打印敏感信息。"""
try:
parsed = make_url(url)
if parsed.password:
parsed = parsed.set(password="***")
return str(parsed)
except Exception:
return url
async def _benchmark_insert(
    engine, simulation_id, device_id, total_records, batch_size, progress_every=50000
):
    """Insert synthetic rows into ``vehicle_signals`` and print insert throughput.

    Args:
        engine: async engine to write through.
        simulation_id: value stored in every row's ``simulation_id`` column.
        device_id: value stored in every row's ``device_id`` column.
        total_records: number of rows to insert; ``seq`` runs from 0 to
            ``total_records - 1``.
        batch_size: rows passed to each ``insert(...)`` execute call; must be
            positive.
        progress_every: print a progress line each time another multiple of
            this many rows has been inserted.

    All batches run inside a single ``engine.begin()`` block. The measured
    time includes payload generation, not only database round-trips.
    """
    started = time.perf_counter()
    next_report = progress_every
    async with engine.begin() as conn:
        for base_seq in range(0, total_records, batch_size):
            end_seq = min(base_seq + batch_size, total_records)
            batch = [
                {
                    "ts": datetime.now(timezone.utc),
                    "simulation_id": simulation_id,
                    "device_id": device_id,
                    "seq": seq,
                    "signals": generate_payload(),
                }
                for seq in range(base_seq, end_seq)
            ]
            await conn.execute(insert(vehicle_signals), batch)
            # A ">=" threshold (instead of "end_seq % progress_every == 0")
            # still reports progress when batch_size does not divide
            # progress_every evenly.
            if end_seq >= next_report:
                print(f"Inserted {end_seq} records...")
                next_report = (end_seq // progress_every + 1) * progress_every
    duration = time.perf_counter() - started

    print("\n✅ Insertion Test Complete:")
    print(f"Total Records: {total_records}")
    print(f"Time Taken: {duration:.4f} seconds")
    # Avoid ZeroDivisionError when the measured duration is zero.
    throughput = f"{total_records / duration:.2f}" if duration > 0 else "n/a"
    print(f"Throughput: {throughput} records/sec")


async def _benchmark_queries(engine, simulation_id):
    """Run three read queries scoped to *simulation_id* and print each latency.

    Each query opens its own connection. The reported time therefore covers
    acquiring and releasing that connection as well as executing the query.
    """
    # Query 1: plain row count for this simulation.
    started = time.perf_counter()
    async with engine.connect() as conn:
        result = await conn.execute(
            select(text("count(*)"))
            .select_from(vehicle_signals)
            .where(vehicle_signals.c.simulation_id == simulation_id)
        )
        count = result.scalar()
    elapsed = time.perf_counter() - started
    print(f"Query 1 (Count): Found {count} records in {elapsed:.4f} seconds")

    # Query 2: filter on a JSONB field (vehicle speed > 100).
    # Raw SQL is used for the ->> extraction and ::float cast (PostgreSQL
    # syntax). The simulation id is still sent as a bound parameter.
    started = time.perf_counter()
    async with engine.connect() as conn:
        stmt = text(
            "SELECT count(*) FROM vehicle_signals "
            "WHERE simulation_id = :sim_id "
            "AND (signals->>'vehicle_speed_kmh')::float > 100"
        )
        result = await conn.execute(stmt, {"sim_id": simulation_id})
        high_speed_count = result.scalar()
    elapsed = time.perf_counter() - started
    print(f"Query 2 (JSONB Filter): Found {high_speed_count} records with speed > 100 in {elapsed:.4f} seconds")

    # Query 3: the 1000 most recent rows, newest first.
    started = time.perf_counter()
    async with engine.connect() as conn:
        stmt = (
            select(vehicle_signals)
            .where(vehicle_signals.c.simulation_id == simulation_id)
            .order_by(vehicle_signals.c.ts.desc())
            .limit(1000)
        )
        result = await conn.execute(stmt)
        rows = result.fetchall()
    elapsed = time.perf_counter() - started
    print(f"Query 3 (Time Range Limit): Retrieved {len(rows)} records in {elapsed:.4f} seconds")


async def run_test():
    """Benchmark bulk inserts and a few typical queries against the test database.

    Configuration (environment variables):
        SMARTEDT_TEST_DATABASE_URL: database URL. Falls back to
            ``load_settings().database.url`` when unset or empty.
        SMARTEDT_TEST_RECORDS: number of rows to insert (default 300000).
        SMARTEDT_TEST_BATCH_SIZE: rows per insert batch (default 1000).

    Inserted rows are not deleted afterwards. They are tagged with a
    ``TEST_SIM_<unix time>`` simulation id. Do not point this at production.

    Raises:
        ValueError: if a count variable is not an integer, the record count
            is negative, or the batch size is not positive.
    """
    settings = load_settings()
    # An empty env var falls back to the configured URL instead of being
    # handed to the engine as "".
    db_url = (os.getenv("SMARTEDT_TEST_DATABASE_URL") or settings.database.url).strip()

    # Validate the workload size before touching the database so a bad value
    # fails fast. A zero batch size would otherwise crash inside range(), and
    # a negative one would silently insert nothing.
    total_records = int(os.getenv("SMARTEDT_TEST_RECORDS", "300000"))
    batch_size = int(os.getenv("SMARTEDT_TEST_BATCH_SIZE", "1000"))
    if total_records < 0:
        raise ValueError(f"SMARTEDT_TEST_RECORDS must be >= 0, got {total_records}")
    if batch_size <= 0:
        raise ValueError(f"SMARTEDT_TEST_BATCH_SIZE must be > 0, got {batch_size}")

    print(f"Connecting to DB: {_redact_url(db_url)}")
    engine = create_async_engine(db_url, echo=False)
    try:
        # 0. Create the schema (if missing), then try TimescaleDB setup.
        print("Initializing schema...")
        await init_schema(engine)
        try:
            await init_timescaledb(engine)
            print("Schema and TimescaleDB initialized.")
        except Exception as e:
            # Deliberately best-effort: the benchmark still runs on plain
            # tables when TimescaleDB setup fails.
            print(f"TimescaleDB init warning (might already exist): {e}")

        # 1. Test data identity.
        simulation_id = f"TEST_SIM_{int(time.time())}"
        device_id = "test_device_01"
        print(f"Generating {total_records} records for simulation {simulation_id}...")

        # 2. Insert throughput.
        print("Starting insertion test...")
        await _benchmark_insert(engine, simulation_id, device_id, total_records, batch_size)

        # 3. Query latency.
        print("\nStarting query performance test...")
        await _benchmark_queries(engine, simulation_id)
    finally:
        # Release pooled connections even if any step above raised.
        await engine.dispose()
if __name__ == "__main__":
    # WindowsSelectorEventLoopPolicy only exists on Windows builds of asyncio,
    # so this switches from the default Proactor loop to a selector loop on
    # Windows only. NOTE(review): presumably required because the async DB
    # driver in use does not support the Proactor loop — confirm.
    selector_policy_cls = getattr(asyncio, "WindowsSelectorEventLoopPolicy", None)
    if selector_policy_cls is not None:
        asyncio.set_event_loop_policy(selector_policy_cls())
    asyncio.run(run_test())