BodyBalanceEvaluation/backend/detection_engine.py
2025-07-28 11:59:56 +08:00

615 lines
25 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
检测引擎模块
负责实时数据处理、姿态分析和平衡评估
"""
import numpy as np
import cv2
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Tuple
import logging
import threading
from collections import deque
import json
logger = logging.getLogger(__name__)
class DetectionEngine:
    """Detection engine: buffers per-session sensor data and runs the analyzers.

    Per-session state lives in ``self.session_data`` (session_id -> state dict
    created by ``start_session``). Every lookup of a session is done while
    ``self.data_lock`` is held, so a concurrent ``end_session`` cannot make a
    method fail half-way through with ``KeyError``.
    """

    def __init__(self):
        self.session_data = {}  # session_id -> per-session state dict
        self.data_lock = threading.Lock()
        self.analysis_algorithms = {
            'balance_analysis': BalanceAnalyzer(),
            'posture_analysis': PostureAnalyzer(),
            'movement_analysis': MovementAnalyzer()
        }
        logger.info('检测引擎初始化完成')

    def start_session(self, session_id: str, settings: Dict[str, Any]):
        """Create (or replace) the state of a detection session.

        Args:
            session_id: identifier of the session.
            settings: stored as-is and later passed to the full-session
                analyzers by ``analyze_session``.
        """
        with self.data_lock:
            self.session_data[session_id] = {
                'settings': settings,
                'start_time': datetime.now(),
                'data_buffer': deque(maxlen=1000),  # keep only the most recent 1000 samples
                'analysis_results': {},
                'real_time_metrics': {}
            }
        logger.info(f'检测会话开始: {session_id}')

    def process_data(self, session_id: str, raw_data: Dict[str, Any]) -> Dict[str, Any]:
        """Preprocess one raw sample, buffer it and run the real-time analysis.

        Args:
            session_id: target session.
            raw_data: raw sample; its 'timestamp', 'camera', 'imu' and
                'pressure' entries are read (see ``_preprocess_data``).

        Returns:
            The real-time analysis results (also merged into the session's
            ``real_time_metrics``), or ``{}`` when the session does not exist
            or processing fails (the failure is logged).
        """
        if session_id not in self.session_data:
            logger.warning(f'会话不存在: {session_id}')
            return {}
        try:
            processed_data = self._preprocess_data(raw_data)
            with self.data_lock:
                session = self.session_data.get(session_id)
                if session is None:
                    # The session was ended after the existence check above.
                    logger.warning(f'会话不存在: {session_id}')
                    return {}
                session['data_buffer'].append(processed_data)
            real_time_results = self._real_time_analysis(session_id, processed_data)
            with self.data_lock:
                session = self.session_data.get(session_id)
                if session is not None:  # it may have been ended during the analysis
                    session['real_time_metrics'].update(real_time_results)
            return real_time_results
        except Exception as e:
            logger.error(f'数据处理失败: {e}')
            return {}

    def _preprocess_data(self, raw_data: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize a raw sample into the processed-frame format.

        The result always has a 'timestamp' (the sample's own, otherwise the
        current local time in ISO format) plus a 'pose', 'imu' and/or
        'pressure' entry for each sensor present in ``raw_data``.
        """
        processed = {
            'timestamp': raw_data.get('timestamp', datetime.now().isoformat())
        }
        # Camera data: only the embedded pose estimate is used.
        if 'camera' in raw_data:
            camera_data = raw_data['camera']
            if 'pose_data' in camera_data:
                processed['pose'] = self._process_pose_data(camera_data['pose_data'])
        if 'imu' in raw_data:
            processed['imu'] = self._process_imu_data(raw_data['imu'])
        if 'pressure' in raw_data:
            processed['pressure'] = self._process_pressure_data(raw_data['pressure'])
        return processed

    def _process_pose_data(self, pose_data: Dict[str, Any]) -> Dict[str, Any]:
        """Extract centre of gravity, body angles and confidence (zero defaults)."""
        return {
            'center_of_gravity': pose_data.get('center_of_gravity', {'x': 0, 'y': 0}),
            'body_angle': pose_data.get('body_angle', {'pitch': 0, 'roll': 0, 'yaw': 0}),
            'confidence': pose_data.get('confidence', 0.0)
        }

    def _process_imu_data(self, imu_data: Dict[str, Any]) -> Dict[str, Any]:
        """Derive acceleration magnitude and tilt angles from an IMU sample.

        ``imu_data['accel']`` (if present) must provide 'x', 'y' and 'z'.
        Pitch and roll are computed from the accelerometer vector and
        returned in degrees.
        """
        accel = imu_data.get('accel', {'x': 0, 'y': 0, 'z': 0})
        total_accel = np.sqrt(accel['x']**2 + accel['y']**2 + accel['z']**2)
        pitch = np.arctan2(accel['y'], np.sqrt(accel['x']**2 + accel['z']**2)) * 180 / np.pi
        roll = np.arctan2(-accel['x'], accel['z']) * 180 / np.pi
        return {
            'accel': accel,
            'gyro': imu_data.get('gyro', {'x': 0, 'y': 0, 'z': 0}),
            'total_accel': total_accel,
            'pitch': pitch,
            'roll': roll,
            'temperature': imu_data.get('temperature', 0)
        }

    def _process_pressure_data(self, pressure_data: Dict[str, Any]) -> Dict[str, Any]:
        """Compute left/right load distribution from foot-pressure readings.

        ``balance_index`` is ``|left_ratio - right_ratio|`` (smaller is more
        balanced). With no load at all, both ratios are reported as 0.5 and
        the balance index as 0.
        """
        left_foot = pressure_data.get('left_foot', 0)
        right_foot = pressure_data.get('right_foot', 0)
        total_pressure = left_foot + right_foot
        if total_pressure > 0:
            left_ratio = left_foot / total_pressure
            right_ratio = right_foot / total_pressure
            balance_index = abs(left_ratio - right_ratio)
        else:
            left_ratio = right_ratio = 0.5
            balance_index = 0
        return {
            'left_foot': left_foot,
            'right_foot': right_foot,
            'total_pressure': total_pressure,
            'left_ratio': left_ratio,
            'right_ratio': right_ratio,
            'balance_index': balance_index,
            'center_of_pressure': pressure_data.get('center_of_pressure', {'x': 0, 'y': 0})
        }

    def _real_time_analysis(self, session_id: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the real-time analyzers on the newest frame.

        Returns a dict with 'balance' (when the frame has pressure data),
        'posture' (when it has pose or IMU data) and 'movement' results.
        Returns ``{}`` when the session is unknown or has fewer than two
        buffered frames. Analyzer errors are logged and whatever was computed
        so far is returned.
        """
        results = {}
        try:
            # Snapshot the history so the analyzers run without holding the lock.
            with self.data_lock:
                session = self.session_data.get(session_id)
                data_buffer = list(session['data_buffer']) if session is not None else []
            if len(data_buffer) < 2:
                return results
            if 'pressure' in data:
                results['balance'] = self.analysis_algorithms['balance_analysis'].analyze_real_time(
                    data['pressure'], data_buffer
                )
            if 'pose' in data or 'imu' in data:
                results['posture'] = self.analysis_algorithms['posture_analysis'].analyze_real_time(
                    data, data_buffer
                )
            results['movement'] = self.analysis_algorithms['movement_analysis'].analyze_real_time(
                data, data_buffer
            )
        except Exception as e:
            logger.error(f'实时分析失败: {e}')
        return results

    def get_latest_data(self, session_id: str) -> Dict[str, Any]:
        """Return the newest buffered frame, a copy of the real-time metrics and the buffer size.

        Returns ``{}`` when the session is unknown or has no data yet.
        """
        with self.data_lock:
            session = self.session_data.get(session_id)
            if session is None or not session['data_buffer']:
                return {}
            return {
                'latest_data': session['data_buffer'][-1],
                'real_time_metrics': session['real_time_metrics'].copy(),
                'data_count': len(session['data_buffer'])
            }

    def analyze_session(self, session_id: str) -> Dict[str, Any]:
        """Run the full-session analyzers over all buffered frames.

        Returns:
            A dict with 'balance', 'posture', 'movement' and 'overall'
            entries. The result is also stored as the session's
            ``analysis_results`` if the session still exists. Returns ``{}``
            for an unknown session and ``{'error': ...}`` when there is no
            data or the analysis fails.
        """
        with self.data_lock:
            session = self.session_data.get(session_id)
            if session is not None:
                data_buffer = list(session['data_buffer'])
                settings = session['settings']
        if session is None:
            logger.warning(f'会话不存在: {session_id}')
            return {}
        try:
            if not data_buffer:
                return {'error': '没有数据可分析'}
            analysis_results = {
                'balance': self.analysis_algorithms['balance_analysis'].analyze_full_session(
                    data_buffer, settings
                ),
                'posture': self.analysis_algorithms['posture_analysis'].analyze_full_session(
                    data_buffer, settings
                ),
                'movement': self.analysis_algorithms['movement_analysis'].analyze_full_session(
                    data_buffer, settings
                ),
            }
            analysis_results['overall'] = self._generate_overall_assessment(analysis_results)
            with self.data_lock:
                # Keep the results even if the session was ended meanwhile; only
                # store them when there is still a session to attach them to.
                session = self.session_data.get(session_id)
                if session is not None:
                    session['analysis_results'] = analysis_results
            logger.info(f'会话分析完成: {session_id}')
            return analysis_results
        except Exception as e:
            logger.error(f'会话分析失败: {e}')
            return {'error': str(e)}

    def _generate_overall_assessment(self, analysis_results: Dict[str, Any]) -> Dict[str, Any]:
        """Combine component scores into an overall score, grade and advice.

        The overall score is a weighted mean (balance 0.4, posture 0.3,
        movement 0.3) of each component's 'score' (missing scores count as
        0). Grades: A >= 90, B >= 80, C >= 70, D >= 60, otherwise E.
        """
        try:
            balance_score = analysis_results.get('balance', {}).get('score', 0)
            posture_score = analysis_results.get('posture', {}).get('score', 0)
            movement_score = analysis_results.get('movement', {}).get('score', 0)
            weights = {'balance': 0.4, 'posture': 0.3, 'movement': 0.3}
            overall_score = (
                balance_score * weights['balance'] +
                posture_score * weights['posture'] +
                movement_score * weights['movement']
            )
            if overall_score >= 90:
                grade, description = 'A', '优秀'
            elif overall_score >= 80:
                grade, description = 'B', '良好'
            elif overall_score >= 70:
                grade, description = 'C', '一般'
            elif overall_score >= 60:
                grade, description = 'D', '较差'
            else:
                # Lowest grade; it needs a description like every other grade.
                grade, description = 'E', '差'
            recommendations = self._generate_recommendations(analysis_results)
            return {
                'score': round(overall_score, 1),
                'grade': grade,
                'description': description,
                'recommendations': recommendations,
                'component_scores': {
                    'balance': balance_score,
                    'posture': posture_score,
                    'movement': movement_score
                }
            }
        except Exception as e:
            logger.error(f'综合评估生成失败: {e}')
            return {'score': 0, 'grade': 'E', 'description': '评估失败'}

    def _generate_recommendations(self, analysis_results: Dict[str, Any]) -> List[str]:
        """Build improvement advice for every component scoring below 80.

        Falls back to a generic "keep training" message when nothing
        specific applies, and to a "consult a specialist" message on error.
        """
        recommendations = []
        try:
            balance_data = analysis_results.get('balance', {})
            if balance_data.get('score', 0) < 80:
                if balance_data.get('left_right_imbalance', 0) > 0.2:
                    recommendations.append('注意左右脚压力分布,建议进行单脚站立练习')
                # NOTE(review): BalanceAnalyzer reports stability_index as the std of
                # balance indices; for indices in [0, 1] (non-negative foot pressures)
                # that std cannot exceed 0.5, so this branch never fires. Confirm the
                # intended threshold.
                if balance_data.get('stability_index', 0) > 0.5:
                    recommendations.append('重心摆动较大,建议进行静态平衡训练')
            posture_data = analysis_results.get('posture', {})
            if posture_data.get('score', 0) < 80:
                if abs(posture_data.get('avg_pitch', 0)) > 5:
                    recommendations.append('身体前后倾斜较明显,注意保持直立姿态')
                if abs(posture_data.get('avg_roll', 0)) > 5:
                    recommendations.append('身体左右倾斜较明显,注意身体对称性')
            movement_data = analysis_results.get('movement', {})
            if movement_data.get('score', 0) < 80:
                if movement_data.get('movement_variability', 0) > 0.8:
                    recommendations.append('身体摆动过大,建议进行核心稳定性训练')
                if movement_data.get('movement_frequency', 0) > 2:
                    recommendations.append('身体摆动频率较高,建议放松并专注于静态平衡')
            if not recommendations:
                recommendations.append('整体表现良好,继续保持规律的平衡训练')
        except Exception as e:
            logger.error(f'建议生成失败: {e}')
            recommendations = ['建议咨询专业医师进行详细评估']
        return recommendations

    def end_session(self, session_id: str):
        """Discard a session's state; unknown session ids are ignored."""
        with self.data_lock:
            removed = self.session_data.pop(session_id, None)
        if removed is not None:
            logger.info(f'检测会话结束: {session_id}')
class BalanceAnalyzer:
    """Balance analyzer working on processed foot-pressure samples.

    Expects pressure dicts with keys such as 'balance_index', 'left_ratio',
    'right_ratio' and 'center_of_pressure' (the shape produced by the
    engine's pressure preprocessing in this module).
    """

    def analyze_real_time(self, pressure_data: Dict[str, Any], data_buffer: List[Dict]) -> Dict[str, Any]:
        """Real-time balance snapshot for the current pressure sample.

        Args:
            pressure_data: the current processed pressure sample.
            data_buffer: recent processed frames (most recent last); only the
                last 10 are inspected.

        Returns:
            The current 'balance_index' and 'center_of_pressure'. Also
            'stability' in [0, 1]: 1 minus the std of the balance index over
            the recent frames that carry pressure data, or 0.5 when fewer than
            two such frames exist. Also 'status': 'stable' when the balance
            index is below 0.2, else 'unstable'. On error, a fallback dict
            with status 'unknown' is returned.
        """
        try:
            balance_index = pressure_data.get('balance_index', 0)
            cop = pressure_data.get('center_of_pressure', {'x': 0, 'y': 0})
            # Only frames that actually carry pressure data count; frames from
            # other sensors are skipped instead of disabling the metric.
            recent_indices = [
                d['pressure'].get('balance_index', 0)
                for d in data_buffer[-10:]
                if 'pressure' in d
            ]
            if len(recent_indices) >= 2:
                # A smaller spread of the balance index means higher stability.
                stability = 1.0 - float(np.std(recent_indices))
            else:
                stability = 0.5  # not enough samples for a meaningful spread
            return {
                'balance_index': balance_index,
                'stability': max(0, min(1, stability)),
                'center_of_pressure': cop,
                'status': 'stable' if balance_index < 0.2 else 'unstable'
            }
        except Exception as e:
            logger.error(f'实时平衡分析失败: {e}')
            return {'balance_index': 0, 'stability': 0, 'status': 'unknown'}

    def analyze_full_session(self, data_buffer: List[Dict], settings: Dict) -> Dict[str, Any]:
        """Score balance (0-100) over every pressure sample of a session.

        The score is the mean of three sub-scores, each clipped at 0:
        balance (100 - 200 * mean balance index), stability
        (100 - 500 * std of the balance index) and symmetry
        (100 - 200 * |mean left ratio - mean right ratio|).
        ``settings`` is accepted for interface compatibility and not used.

        NOTE(review): as preprocessed by the engine, samples with no load at
        all (total pressure 0) are reported as perfectly balanced and
        therefore raise the score; consider filtering them.
        """
        try:
            pressure_data = [d['pressure'] for d in data_buffer if 'pressure' in d]
            if not pressure_data:
                return {'score': 0, 'error': '没有压力数据'}
            balance_indices = [d.get('balance_index', 0) for d in pressure_data]
            left_ratios = [d.get('left_ratio', 0.5) for d in pressure_data]
            right_ratios = [d.get('right_ratio', 0.5) for d in pressure_data]
            # float() turns NumPy scalars into built-in floats. np.max of an
            # all-int list is np.int64, which the json module cannot serialize.
            avg_balance_index = float(np.mean(balance_indices))
            stability_index = float(np.std(balance_indices))
            max_balance_index = float(np.max(balance_indices))
            left_right_imbalance = abs(float(np.mean(left_ratios)) - float(np.mean(right_ratios)))
            balance_score = max(0, 100 - avg_balance_index * 200)
            stability_score = max(0, 100 - stability_index * 500)
            symmetry_score = max(0, 100 - left_right_imbalance * 200)
            overall_score = (balance_score + stability_score + symmetry_score) / 3
            return {
                'score': round(overall_score, 1),
                'avg_balance_index': round(avg_balance_index, 3),
                'stability_index': round(stability_index, 3),
                'left_right_imbalance': round(left_right_imbalance, 3),
                'max_imbalance': round(max_balance_index, 3),
                'data_points': len(pressure_data),
                'component_scores': {
                    'balance': round(balance_score, 1),
                    'stability': round(stability_score, 1),
                    'symmetry': round(symmetry_score, 1)
                }
            }
        except Exception as e:
            logger.error(f'全会话平衡分析失败: {e}')
            return {'score': 0, 'error': str(e)}
class PostureAnalyzer:
    """Posture analyzer based on the processed IMU tilt angles (plus pose data in real time)."""

    def analyze_real_time(self, data: Dict[str, Any], data_buffer: List[Dict]) -> Dict[str, Any]:
        """Summarize the current frame's posture and its short-term stability.

        Args:
            data: the current processed frame; its 'imu' and 'pose' entries are read.
            data_buffer: recent processed frames (most recent last); at most
                the last five are used for the stability estimate.

        Returns:
            'pitch', 'roll' and 'total_accel' copied from the IMU entry, and
            'body_angle' and 'confidence' copied from the pose entry. When the
            current frame has IMU data and the window holds IMU samples, also
            'stability': the mean over pitch and roll of ``1 - min(1, std / 10)``.
            Returns ``{}`` on any error (the error is logged).
        """
        try:
            snapshot = {}
            current_has_imu = 'imu' in data
            if current_has_imu:
                imu_now = data['imu']
                for field in ('pitch', 'roll', 'total_accel'):
                    snapshot[field] = imu_now.get(field, 0)
            if 'pose' in data:
                pose_now = data['pose']
                snapshot['body_angle'] = pose_now.get('body_angle', {})
                snapshot['confidence'] = pose_now.get('confidence', 0)

            window = data_buffer if len(data_buffer) < 5 else data_buffer[-5:]
            if window and current_has_imu:
                imu_window = [frame.get('imu', {}) for frame in window if 'imu' in frame]
                pitch_series = [sample.get('pitch', 0) for sample in imu_window]
                roll_series = [sample.get('roll', 0) for sample in imu_window]
                if pitch_series and roll_series:
                    def steadiness(series):
                        # 1.0 for a constant angle, dropping to 0 once the std reaches 10
                        return 1.0 - min(1.0, np.std(series) / 10)
                    snapshot['stability'] = (steadiness(pitch_series) + steadiness(roll_series)) / 2
            return snapshot
        except Exception as e:
            logger.error(f'实时姿态分析失败: {e}')
            return {}

    def analyze_full_session(self, data_buffer: List[Dict], settings: Dict) -> Dict[str, Any]:
        """Score posture (0-100) from the pitch/roll of every IMU sample in a session.

        The score is the mean of three sub-scores, each clipped at 0:
        pitch (100 - 5 * |mean pitch|), roll (100 - 5 * |mean roll|) and
        stability (100 - 10 * (std pitch + std roll)). ``settings`` is
        accepted for interface compatibility and not used.
        """
        try:
            imu_samples = [frame['imu'] for frame in data_buffer if 'imu' in frame]
            if not imu_samples:
                return {'score': 0, 'error': '没有IMU数据'}

            pitch_values = [sample.get('pitch', 0) for sample in imu_samples]
            roll_values = [sample.get('roll', 0) for sample in imu_samples]

            mean_pitch, mean_roll = np.mean(pitch_values), np.mean(roll_values)
            spread_pitch, spread_roll = np.std(pitch_values), np.std(roll_values)
            peak_pitch = np.max(np.abs(pitch_values))
            peak_roll = np.max(np.abs(roll_values))

            pitch_part = max(0, 100 - abs(mean_pitch) * 5)
            roll_part = max(0, 100 - abs(mean_roll) * 5)
            steadiness_part = max(0, 100 - (spread_pitch + spread_roll) * 10)
            combined = (pitch_part + roll_part + steadiness_part) / 3

            summary = {
                'score': round(combined, 1),
                'avg_pitch': round(mean_pitch, 2),
                'avg_roll': round(mean_roll, 2),
                'std_pitch': round(spread_pitch, 2),
                'std_roll': round(spread_roll, 2),
                'max_pitch': round(peak_pitch, 2),
                'max_roll': round(peak_roll, 2),
                'data_points': len(imu_samples),
            }
            summary['component_scores'] = {
                'pitch': round(pitch_part, 1),
                'roll': round(roll_part, 1),
                'stability': round(steadiness_part, 1)
            }
            return summary
        except Exception as e:
            logger.error(f'全会话姿态分析失败: {e}')
            return {'score': 0, 'error': str(e)}
class MovementAnalyzer:
    """Movement (sway) analyzer based on the centre-of-pressure (CoP) trajectory.

    Numeric outputs are converted to built-in ``float``/``bool`` values. With
    integer CoP coordinates (such as the ``{'x': 0, 'y': 0}`` default), NumPy
    reductions would otherwise return ``np.int64``/``np.bool_`` scalars, which
    the standard ``json`` module cannot serialize.
    """

    @staticmethod
    def _cop_points(frames: List[Dict]) -> List[Tuple[Any, Any]]:
        """Return the (x, y) CoP of every frame that carries pressure data."""
        points = []
        for frame in frames:
            if 'pressure' in frame:
                cop = frame['pressure'].get('center_of_pressure', {'x': 0, 'y': 0})
                points.append((cop['x'], cop['y']))
        return points

    def analyze_real_time(self, data: Dict[str, Any], data_buffer: List[Dict]) -> Dict[str, Any]:
        """Detect CoP movement over the last 10 buffered frames.

        Returns ``{'movement_detected': False}`` in any of these cases:
        the buffer holds fewer than five frames, the current frame has no
        pressure data, fewer than two recent frames carry a CoP, or an error
        occurs (logged). Otherwise returns the x/y ranges of the recent CoP
        positions, their Euclidean combination 'total_movement', and
        'movement_detected' (True when either range exceeds 5).
        """
        try:
            if len(data_buffer) < 5:
                return {'movement_detected': False}
            if 'pressure' in data:
                positions = self._cop_points(data_buffer[-10:])
                if len(positions) >= 2:
                    xs = [p[0] for p in positions]
                    ys = [p[1] for p in positions]
                    movement_range_x = float(max(xs) - min(xs))
                    movement_range_y = float(max(ys) - min(ys))
                    return {
                        'movement_detected': movement_range_x > 5 or movement_range_y > 5,
                        'movement_range_x': movement_range_x,
                        'movement_range_y': movement_range_y,
                        'total_movement': float(np.hypot(movement_range_x, movement_range_y))
                    }
            return {'movement_detected': False}
        except Exception as e:
            logger.error(f'实时运动分析失败: {e}')
            return {'movement_detected': False}

    def analyze_full_session(self, data_buffer: List[Dict], settings: Dict) -> Dict[str, Any]:
        """Score sway (0-100) from the CoP trajectory of a whole session.

        Requires at least 10 CoP samples. The score is the mean of three
        sub-scores, each clipped at 0: range (100 - 2 * diagonal of the x/y
        ranges), variability (100 - 10 * (std x + std y)) and "frequency"
        (100 - 20 * movement_frequency). ``settings`` is accepted for
        interface compatibility and not used.
        """
        try:
            positions = self._cop_points(data_buffer)
            if len(positions) < 10:
                return {'score': 0, 'error': '数据不足'}
            xs = [p[0] for p in positions]
            ys = [p[1] for p in positions]

            movement_range_x = float(max(xs) - min(xs))
            movement_range_y = float(max(ys) - min(ys))
            total_range = float(np.hypot(movement_range_x, movement_range_y))
            # Variability: sum of the per-axis standard deviations.
            movement_variability = float(np.std(xs) + np.std(ys))
            # Path length: summed distance between consecutive CoP positions.
            path_length = float(np.hypot(np.diff(xs), np.diff(ys)).sum())
            # NOTE: despite the key name, this is the mean path length per
            # sample (a simplified proxy), not a frequency in Hz.
            movement_frequency = path_length / len(positions)

            range_score = max(0, 100 - total_range * 2)
            variability_score = max(0, 100 - movement_variability * 10)
            frequency_score = max(0, 100 - movement_frequency * 20)
            overall_score = (range_score + variability_score + frequency_score) / 3
            return {
                'score': round(overall_score, 1),
                'movement_range_x': round(movement_range_x, 2),
                'movement_range_y': round(movement_range_y, 2),
                'total_range': round(total_range, 2),
                'movement_variability': round(movement_variability, 2),
                'path_length': round(path_length, 2),
                'movement_frequency': round(movement_frequency, 2),
                'data_points': len(positions),
                'component_scores': {
                    'range': round(range_score, 1),
                    'variability': round(variability_score, 1),
                    'frequency': round(frequency_score, 1)
                }
            }
        except Exception as e:
            logger.error(f'全会话运动分析失败: {e}')
            return {'score': 0, 'error': str(e)}