SmartEDT/backend/database/test_db.py

154 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import os
import time
import json
import random
from datetime import datetime, timezone
from sqlalchemy import insert, select, text
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.engine.url import make_url
from backend.database.schema import vehicle_signals, Simulation, init_schema, init_timescaledb
from backend.config.settings import load_settings
# Mock data generation.
def generate_payload():
    """Build one randomized vehicle-signal sample for the insert benchmark.

    Every value is drawn from the module-level ``random`` generator, so
    seeding ``random`` beforehand makes the result reproducible.

    Returns:
        dict: scalar signals plus two nested dicts, ``wheel_speed_rpm``
        (keys FL/FR/RL/RR, ints in [0, 2000]) and ``lights`` (keys
        left_turn/right_turn/hazard/brake, each 0 or 1). Keys appear in a
        fixed order.
    """

    def uniform_1dp(low, high):
        # Continuous reading in [low, high], rounded to one decimal place.
        return round(random.uniform(low, high), 1)

    def on_off():
        # Binary switch state encoded as 0/1.
        return random.choice([0, 1])

    # Assignment order below mirrors the key order AND the order in which
    # random numbers are consumed; keep it stable for reproducibility.
    payload = {}
    payload["steering_wheel_angle_deg"] = uniform_1dp(-450, 450)
    payload["brake_pedal_travel_mm"] = uniform_1dp(0, 100)
    payload["throttle_pedal_travel_mm"] = uniform_1dp(0, 100)
    payload["gear"] = random.choice(["P", "N", "D", "R"])
    payload["handbrake"] = on_off()
    payload["vehicle_speed_kmh"] = uniform_1dp(0, 180)
    payload["wheel_speed_rpm"] = {
        wheel: random.randint(0, 2000) for wheel in ("FL", "FR", "RL", "RR")
    }
    payload["lights"] = {
        lamp: on_off() for lamp in ("left_turn", "right_turn", "hazard", "brake")
    }
    payload["soc_percent"] = uniform_1dp(0, 100)
    payload["voltage_v"] = uniform_1dp(300, 400)
    payload["current_a"] = uniform_1dp(-50, 200)
    payload["temperature_c"] = uniform_1dp(20, 80)
    return payload
def _redact_url(url: str) -> str:
try:
parsed = make_url(url)
if parsed.password:
parsed = parsed.set(password="***")
return str(parsed)
except Exception:
return url
async def _prepare_schema(engine):
    """Create the benchmark tables, then attempt the TimescaleDB setup.

    Errors from ``init_schema`` propagate. Errors from ``init_timescaledb``
    are printed and ignored so the benchmark can continue (the message
    suggests the objects may already exist).
    """
    print("Initializing schema...")
    await init_schema(engine)
    try:
        await init_timescaledb(engine)
        print("Schema and TimescaleDB initialized.")
    except Exception as e:
        print(f"TimescaleDB init warning (might already exist): {e}")


async def _benchmark_insert(engine, simulation_id, device_id, total_records, batch_size):
    """Insert ``total_records`` generated rows in batches and return elapsed seconds.

    All batches run inside a single ``engine.begin()`` transaction, so the
    measured time includes the commit performed when the block exits.
    """
    progress_every = 50000
    started = time.perf_counter()
    async with engine.begin() as conn:
        for base_seq in range(0, total_records, batch_size):
            end_seq = min(base_seq + batch_size, total_records)
            batch = [
                {
                    "ts": datetime.now(timezone.utc),
                    "simulation_id": simulation_id,
                    "device_id": device_id,
                    "seq": seq,
                    "signals": generate_payload(),
                }
                for seq in range(base_seq, end_seq)
            ]
            await conn.execute(insert(vehicle_signals), batch)
            # Report whenever this batch reaches or passes a multiple of
            # 50000. An exact-modulo check would stay silent for batch sizes
            # that do not divide 50000.
            if end_seq // progress_every > base_seq // progress_every:
                print(f"Inserted {end_seq} records...")
    return time.perf_counter() - started


async def _benchmark_queries(engine, simulation_id):
    """Time three read queries scoped to ``simulation_id`` and print the results."""
    print("\nStarting query performance test...")

    # 3.1 Simple count of this run's rows.
    started = time.perf_counter()
    async with engine.connect() as conn:
        result = await conn.execute(
            select(text("count(*)"))
            .select_from(vehicle_signals)
            .where(vehicle_signals.c.simulation_id == simulation_id)
        )
        count = result.scalar()
    elapsed = time.perf_counter() - started
    print(f"Query 1 (Count): Found {count} records in {elapsed:.4f} seconds")

    # 3.2 JSONB filter: count rows with vehicle speed > 100.
    # Note: JSONB query syntax depends on the database and the SQLAlchemy
    # version, so raw text() is used here for compatibility.
    started = time.perf_counter()
    async with engine.connect() as conn:
        stmt = text(
            "SELECT count(*) FROM vehicle_signals "
            "WHERE simulation_id = :sim_id "
            "AND (signals->>'vehicle_speed_kmh')::float > 100"
        )
        result = await conn.execute(stmt, {"sim_id": simulation_id})
        high_speed_count = result.scalar()
    elapsed = time.perf_counter() - started
    print(f"Query 2 (JSONB Filter): Found {high_speed_count} records with speed > 100 in {elapsed:.4f} seconds")

    # 3.3 Most recent 1000 rows by timestamp.
    started = time.perf_counter()
    async with engine.connect() as conn:
        stmt = (
            select(vehicle_signals)
            .where(vehicle_signals.c.simulation_id == simulation_id)
            .order_by(vehicle_signals.c.ts.desc())
            .limit(1000)
        )
        result = await conn.execute(stmt)
        rows = result.fetchall()
    elapsed = time.perf_counter() - started
    print(f"Query 3 (Time Range Limit): Retrieved {len(rows)} records in {elapsed:.4f} seconds")


async def run_test():
    """Run the vehicle_signals insert and query benchmark end to end.

    Environment overrides:
        SMARTEDT_TEST_DATABASE_URL: database URL (default: ``settings.database.url``).
        SMARTEDT_TEST_RECORDS: number of rows to insert (default 300000, must be >= 0).
        SMARTEDT_TEST_BATCH_SIZE: rows per INSERT batch (default 1000, must be > 0).

    Raises:
        ValueError: if either numeric override is not an integer, the record
            count is negative, or the batch size is not positive.
    """
    settings = load_settings()
    db_url = os.getenv("SMARTEDT_TEST_DATABASE_URL", settings.database.url).strip()

    # Validate before touching the database. A zero batch size would make
    # range() raise an obscure error, and a negative one would silently
    # insert nothing.
    total_records = int(os.getenv("SMARTEDT_TEST_RECORDS", "300000"))
    batch_size = int(os.getenv("SMARTEDT_TEST_BATCH_SIZE", "1000"))
    if total_records < 0:
        raise ValueError(f"SMARTEDT_TEST_RECORDS must be >= 0, got {total_records}")
    if batch_size <= 0:
        raise ValueError(f"SMARTEDT_TEST_BATCH_SIZE must be > 0, got {batch_size}")

    print(f"Connecting to DB: {_redact_url(db_url)}")
    engine = create_async_engine(db_url, echo=False)
    try:
        # 0. Initialize the table structure (if it does not exist yet).
        await _prepare_schema(engine)

        # 1. Prepare test data identifiers.
        simulation_id = f"TEST_SIM_{int(time.time())}"
        device_id = "test_device_01"
        print(f"Generating {total_records} records for simulation {simulation_id}...")
        print("Starting insertion test...")

        # 2. Insert performance test.
        insert_duration = await _benchmark_insert(
            engine, simulation_id, device_id, total_records, batch_size
        )
        print("\n✅ Insertion Test Complete:")
        print(f"Total Records: {total_records}")
        print(f"Time Taken: {insert_duration:.4f} seconds")
        if insert_duration > 0:
            print(f"Throughput: {total_records / insert_duration:.2f} records/sec")
        else:
            # Guards against ZeroDivisionError when the timer resolution is too
            # coarse to register the run (e.g. SMARTEDT_TEST_RECORDS=0).
            print("Throughput: n/a (elapsed time too small to measure)")

        # 3. Query performance tests.
        await _benchmark_queries(engine, simulation_id)
    finally:
        # Dispose even when a step fails, so pooled connections are closed
        # before the event loop shuts down.
        await engine.dispose()
if __name__ == "__main__":
    # Make asyncio run correctly on Windows. WindowsSelectorEventLoopPolicy
    # only exists on Windows builds of Python, so probing for it doubles as
    # the platform check. Presumably the async DB driver in use needs the
    # selector loop rather than the default proactor loop — TODO confirm.
    selector_policy_cls = getattr(asyncio, "WindowsSelectorEventLoopPolicy", None)
    if selector_policy_cls is not None:
        asyncio.set_event_loop_policy(selector_policy_cls())
    asyncio.run(run_test())