BodyBalanceEvaluation/backend/detection_engine.py
2025-07-31 17:23:05 +08:00

912 lines
36 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
检测引擎模块
负责实时数据处理、姿态分析和平衡评估
"""
import numpy as np
import cv2
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Tuple
import logging
import threading
from collections import deque
import json
import base64
import os
from pathlib import Path
from flask import Blueprint, request, jsonify
logger = logging.getLogger(__name__)
class DetectionEngine:
    """Detection engine: real-time data processing, posture analysis and balance assessment.

    Per-session state lives in ``self.session_data`` (session_id -> state dict)
    and is guarded by ``self.data_lock``. Screenshots and recordings are written
    under ``<parent of this module's directory>/data/patients/<patient_id>/<session_id>/``.
    """

    # (minimum overall score, grade, description), checked from the top down.
    _GRADE_BANDS = (
        (90, 'A', '优秀'),
        (80, 'B', '良好'),
        (70, 'C', '一般'),
        (60, 'D', '较差'),
    )
    # Grade used when the score is below every band; '差' ("poor") continues
    # the 优秀/良好/一般/较差 progression.
    _LOWEST_GRADE = ('E', '差')
    # Component keys and their weights in the overall score.
    _COMPONENT_WEIGHTS = {'balance': 0.4, 'posture': 0.3, 'movement': 0.3}

    def __init__(self):
        self.session_data = {}  # session_id -> per-session state (see start_session)
        self.data_lock = threading.Lock()
        self.analysis_algorithms = {
            'balance_analysis': BalanceAnalyzer(),
            'posture_analysis': PostureAnalyzer(),
            'movement_analysis': MovementAnalyzer()
        }
        # Create the base data directories up front.
        self._ensure_directories()
        logger.info('检测引擎初始化完成')

    # ------------------------------------------------------------------
    # Filesystem helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _data_root() -> Path:
        """Return the ``data`` directory in the parent of this module's directory."""
        return Path(os.path.dirname(os.path.dirname(__file__))) / 'data'

    @staticmethod
    def _is_safe_path_component(name: Any) -> bool:
        """Return True if *name* is usable as exactly one directory/file name.

        Rejects non-strings, '', '.', '..', NUL bytes and anything containing a
        directory separator (or, on Windows, a drive), so request-supplied ids
        cannot escape the patient data directory.
        """
        if not isinstance(name, str) or name in ('', '.', '..') or '\x00' in name:
            return False
        # Path(...).name drops any directory/drive part, so equality means
        # *name* is a single plain path segment.
        return Path(name).name == name

    def _resolve_session_file(self, patient_id: Any, session_id: Any,
                              filename: Any) -> Optional[Path]:
        """Return ``data/patients/<patient_id>/<session_id>/<filename>``.

        Returns None if any component is not a single safe path segment.
        """
        parts = (patient_id, session_id, filename)
        if not all(self._is_safe_path_component(part) for part in parts):
            return None
        return self._data_root() / 'patients' / patient_id / session_id / filename

    @staticmethod
    def _decode_base64_payload(payload: str) -> bytes:
        """Decode Base64 text, first stripping a data-URL header such as
        ``data:image/jpeg;base64,`` (everything up to the first comma)."""
        if ',' in payload:
            payload = payload.split(',', 1)[1]
        return base64.b64decode(payload)

    @staticmethod
    def _list_media(session_dir: Path, pattern: str) -> List[Dict[str, str]]:
        """List files in *session_dir* matching *pattern*, newest first.

        ``created_time`` is taken from the file's modification time (st_mtime).
        """
        entries = [
            {
                'filename': file_path.name,
                'path': str(file_path),
                'created_time': datetime.fromtimestamp(file_path.stat().st_mtime).isoformat()
            }
            for file_path in session_dir.glob(pattern)
        ]
        return sorted(entries, key=lambda x: x['created_time'], reverse=True)

    def _ensure_directories(self):
        """Ensure ``data/patients`` exists; failures are logged, not raised."""
        try:
            (self._data_root() / 'patients').mkdir(parents=True, exist_ok=True)
            logger.info('数据目录创建完成')
        except Exception as e:
            logger.error(f'创建数据目录失败: {e}')

    # ------------------------------------------------------------------
    # Session lifecycle and real-time processing
    # ------------------------------------------------------------------

    def start_session(self, session_id: str, settings: Dict[str, Any]):
        """Create (or reset) the state for *session_id* with the given settings."""
        with self.data_lock:
            self.session_data[session_id] = {
                'settings': settings,
                'start_time': datetime.now(),
                'data_buffer': deque(maxlen=1000),  # keeps only the most recent 1000 samples
                'analysis_results': {},
                'real_time_metrics': {}
            }
        logger.info(f'检测会话开始: {session_id}')

    def process_data(self, session_id: str, raw_data: Dict[str, Any]) -> Dict[str, Any]:
        """Preprocess one raw sample, buffer it and run real-time analysis.

        Returns the real-time analysis results, or {} if the session is unknown
        or processing fails.
        """
        # Look the session up under the lock; keeping a reference means a
        # concurrent end_session() cannot turn later accesses into KeyErrors.
        with self.data_lock:
            session = self.session_data.get(session_id)
        if session is None:
            logger.warning(f'会话不存在: {session_id}')
            return {}
        try:
            processed_data = self._preprocess_data(raw_data)
            with self.data_lock:
                session['data_buffer'].append(processed_data)
            real_time_results = self._real_time_analysis(session_id, processed_data)
            with self.data_lock:
                session['real_time_metrics'].update(real_time_results)
            return real_time_results
        except Exception as e:
            logger.error(f'数据处理失败: {e}')
            return {}

    def _preprocess_data(self, raw_data: Dict[str, Any]) -> Dict[str, Any]:
        """Normalise a raw sample into the buffered format.

        Recognised keys: ``timestamp`` (defaults to now, ISO format),
        ``camera['pose_data']``, ``imu`` and ``pressure``; others are dropped.
        """
        processed = {
            'timestamp': raw_data.get('timestamp', datetime.now().isoformat())
        }
        if 'camera' in raw_data:
            camera_data = raw_data['camera']
            if 'pose_data' in camera_data:
                processed['pose'] = self._process_pose_data(camera_data['pose_data'])
        if 'imu' in raw_data:
            processed['imu'] = self._process_imu_data(raw_data['imu'])
        if 'pressure' in raw_data:
            processed['pressure'] = self._process_pressure_data(raw_data['pressure'])
        return processed

    def _process_pose_data(self, pose_data: Dict[str, Any]) -> Dict[str, Any]:
        """Pick centre of gravity, body angles and confidence, with zero defaults."""
        return {
            'center_of_gravity': pose_data.get('center_of_gravity', {'x': 0, 'y': 0}),
            'body_angle': pose_data.get('body_angle', {'pitch': 0, 'roll': 0, 'yaw': 0}),
            'confidence': pose_data.get('confidence', 0.0)
        }

    def _process_imu_data(self, imu_data: Dict[str, Any]) -> Dict[str, Any]:
        """Derive acceleration magnitude and tilt angles (degrees) from IMU accel.

        A missing/null ``accel`` or missing axis is treated as 0.
        """
        accel = imu_data.get('accel') or {'x': 0, 'y': 0, 'z': 0}
        ax, ay, az = (accel.get(axis, 0) for axis in ('x', 'y', 'z'))
        total_accel = np.sqrt(ax**2 + ay**2 + az**2)
        # Tilt angles computed from the accelerometer components, converted to degrees.
        pitch = np.arctan2(ay, np.sqrt(ax**2 + az**2)) * 180 / np.pi
        roll = np.arctan2(-ax, az) * 180 / np.pi
        return {
            'accel': accel,
            'gyro': imu_data.get('gyro', {'x': 0, 'y': 0, 'z': 0}),
            'total_accel': total_accel,
            'pitch': pitch,
            'roll': roll,
            'temperature': imu_data.get('temperature', 0)
        }

    def _process_pressure_data(self, pressure_data: Dict[str, Any]) -> Dict[str, Any]:
        """Compute left/right pressure ratios and a balance index.

        ``balance_index`` = |left_ratio - right_ratio| (0 = perfectly even);
        with zero total pressure the ratios default to 0.5 and the index to 0.
        """
        left_foot = pressure_data.get('left_foot', 0)
        right_foot = pressure_data.get('right_foot', 0)
        total_pressure = left_foot + right_foot
        if total_pressure > 0:
            left_ratio = left_foot / total_pressure
            right_ratio = right_foot / total_pressure
            balance_index = abs(left_ratio - right_ratio)  # smaller = more balanced
        else:
            left_ratio = right_ratio = 0.5
            balance_index = 0
        return {
            'left_foot': left_foot,
            'right_foot': right_foot,
            'total_pressure': total_pressure,
            'left_ratio': left_ratio,
            'right_ratio': right_ratio,
            'balance_index': balance_index,
            'center_of_pressure': pressure_data.get('center_of_pressure', {'x': 0, 'y': 0})
        }

    def _real_time_analysis(self, session_id: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the per-sample analyzers against the session's buffered history.

        Returns a dict with optional 'balance', 'posture' and 'movement' entries;
        it is empty until the session holds at least two samples.
        """
        results = {}
        try:
            # Snapshot the buffer under the lock so analyzers see a stable list.
            with self.data_lock:
                session = self.session_data.get(session_id)
                data_buffer = list(session['data_buffer']) if session is not None else []
            if len(data_buffer) < 2:
                return results
            if 'pressure' in data:
                results['balance'] = self.analysis_algorithms['balance_analysis'].analyze_real_time(
                    data['pressure'], data_buffer
                )
            if 'pose' in data or 'imu' in data:
                results['posture'] = self.analysis_algorithms['posture_analysis'].analyze_real_time(
                    data, data_buffer
                )
            results['movement'] = self.analysis_algorithms['movement_analysis'].analyze_real_time(
                data, data_buffer
            )
        except Exception as e:
            logger.error(f'实时分析失败: {e}')
        return results

    def get_latest_data(self, session_id: str) -> Dict[str, Any]:
        """Return the newest buffered sample, a copy of the real-time metrics and
        the buffer size; {} if the session is unknown or has no data yet."""
        with self.data_lock:
            session = self.session_data.get(session_id)
            if session is None or not session['data_buffer']:
                return {}
            return {
                'latest_data': session['data_buffer'][-1],
                'real_time_metrics': session['real_time_metrics'].copy(),
                'data_count': len(session['data_buffer'])
            }

    # ------------------------------------------------------------------
    # Whole-session analysis
    # ------------------------------------------------------------------

    def analyze_session(self, session_id: str) -> Dict[str, Any]:
        """Run the full-session analyzers and an overall assessment.

        Results are stored in the session's ``analysis_results`` and returned.
        Returns {} for an unknown session, or ``{'error': ...}`` if there is no
        data or analysis fails.
        """
        with self.data_lock:
            session = self.session_data.get(session_id)
            snapshot = None if session is None else (list(session['data_buffer']), session['settings'])
        if snapshot is None:
            logger.warning(f'会话不存在: {session_id}')
            return {}
        data_buffer, settings = snapshot
        try:
            if not data_buffer:
                return {'error': '没有数据可分析'}
            analysis_results = {}
            for key, analyzer_name in (('balance', 'balance_analysis'),
                                       ('posture', 'posture_analysis'),
                                       ('movement', 'movement_analysis')):
                analyzer = self.analysis_algorithms[analyzer_name]
                analysis_results[key] = analyzer.analyze_full_session(data_buffer, settings)
            analysis_results['overall'] = self._generate_overall_assessment(analysis_results)
            with self.data_lock:
                session['analysis_results'] = analysis_results
            logger.info(f'会话分析完成: {session_id}')
            return analysis_results
        except Exception as e:
            logger.error(f'会话分析失败: {e}')
            return {'error': str(e)}

    def _generate_overall_assessment(self, analysis_results: Dict[str, Any]) -> Dict[str, Any]:
        """Combine component scores (weighted 0.4/0.3/0.3) into a graded assessment.

        Missing components count as score 0. Returns score, grade, description,
        recommendations and the component scores.
        """
        try:
            component_scores = {
                key: analysis_results.get(key, {}).get('score', 0)
                for key in self._COMPONENT_WEIGHTS
            }
            overall_score = sum(
                component_scores[key] * weight
                for key, weight in self._COMPONENT_WEIGHTS.items()
            )
            grade, description = self._LOWEST_GRADE
            for threshold, band_grade, band_description in self._GRADE_BANDS:
                if overall_score >= threshold:
                    grade, description = band_grade, band_description
                    break
            return {
                'score': round(overall_score, 1),
                'grade': grade,
                'description': description,
                'recommendations': self._generate_recommendations(analysis_results),
                'component_scores': component_scores
            }
        except Exception as e:
            logger.error(f'综合评估生成失败: {e}')
            return {'score': 0, 'grade': 'E', 'description': '评估失败'}

    def _generate_recommendations(self, analysis_results: Dict[str, Any]) -> List[str]:
        """Build improvement suggestions from component scores and metrics.

        Specific suggestions are only considered for components scoring < 80.
        If none apply, the "keep it up" message is given only when every
        component scored >= 80; otherwise a professional assessment is advised
        (e.g. when analysis produced no data and all scores are 0).
        """
        recommendations = []
        try:
            balance_data = analysis_results.get('balance', {})
            if balance_data.get('score', 0) < 80:
                if balance_data.get('left_right_imbalance', 0) > 0.2:
                    recommendations.append('注意左右脚压力分布,建议进行单脚站立练习')
                if balance_data.get('stability_index', 0) > 0.5:
                    recommendations.append('重心摆动较大,建议进行静态平衡训练')

            posture_data = analysis_results.get('posture', {})
            if posture_data.get('score', 0) < 80:
                if abs(posture_data.get('avg_pitch', 0)) > 5:
                    recommendations.append('身体前后倾斜较明显,注意保持直立姿态')
                if abs(posture_data.get('avg_roll', 0)) > 5:
                    recommendations.append('身体左右倾斜较明显,注意身体对称性')

            movement_data = analysis_results.get('movement', {})
            if movement_data.get('score', 0) < 80:
                if movement_data.get('movement_variability', 0) > 0.8:
                    recommendations.append('身体摆动过大,建议进行核心稳定性训练')
                if movement_data.get('movement_frequency', 0) > 2:
                    recommendations.append('身体摆动频率较高,建议放松并专注于静态平衡')

            if not recommendations:
                all_good = all(
                    analysis_results.get(key, {}).get('score', 0) >= 80
                    for key in ('balance', 'posture', 'movement')
                )
                recommendations.append(
                    '整体表现良好,继续保持规律的平衡训练' if all_good else '建议咨询专业医师进行详细评估'
                )
        except Exception as e:
            logger.error(f'建议生成失败: {e}')
            recommendations = ['建议咨询专业医师进行详细评估']
        return recommendations

    # ------------------------------------------------------------------
    # Media persistence
    # ------------------------------------------------------------------

    def save_screenshot(self, patient_id: str, session_id: str, image_data: str,
                        filename: Optional[str] = None) -> Dict[str, Any]:
        """Decode a Base64 image and save it under the patient's session folder.

        Args:
            patient_id: Patient id; must be a single plain path segment.
            session_id: Session id; must be a single plain path segment.
            image_data: Base64 text, optionally with a data-URL header
                (e.g. ``data:image/jpeg;base64,``). Bytes are written unchanged.
            filename: Optional file name; ``.jpg`` is appended if missing.
                Defaults to ``screenshot_<YYYYmmdd_HHMMSS>.jpg``.

        Returns:
            ``{'success': True, 'file_path', 'filename'}`` or
            ``{'success': False, 'error': <message>}``.
        """
        try:
            if not patient_id or not session_id or not image_data:
                return {'success': False, 'error': '缺少必要参数'}
            try:
                image_bytes = self._decode_base64_payload(image_data)
            except Exception as e:
                logger.error(f'Base64解码失败: {e}')
                return {'success': False, 'error': 'Base64解码失败'}
            if not filename:
                timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
                filename = f'screenshot_{timestamp}.jpg'
            elif not isinstance(filename, str):
                return {'success': False, 'error': '非法的路径参数'}
            if not filename.lower().endswith('.jpg'):
                filename += '.jpg'
            # Validate every request-supplied component before touching the disk.
            file_path = self._resolve_session_file(patient_id, session_id, filename)
            if file_path is None:
                return {'success': False, 'error': '非法的路径参数'}
            file_path.parent.mkdir(parents=True, exist_ok=True)
            file_path.write_bytes(image_bytes)
            logger.info(f'截图保存成功: {file_path}')
            return {
                'success': True,
                'file_path': str(file_path),
                'filename': filename
            }
        except Exception as e:
            logger.error(f'保存截图失败: {e}')
            return {'success': False, 'error': str(e)}

    def save_recording(self, patient_id: str, session_id: str, video_data: str,
                       filename: Optional[str] = None) -> Dict[str, Any]:
        """Decode a Base64 video and save it under the patient's session folder.

        Args:
            patient_id: Patient id; must be a single plain path segment.
            session_id: Session id; must be a single plain path segment.
            video_data: Base64 text, optionally with a data-URL header
                (e.g. ``data:video/webm;base64,``). Bytes are written unchanged.
            filename: Optional file name; ``.mp4`` is appended if missing.
                Defaults to ``recording_<YYYYmmdd_HHMMSS>.mp4``.

        Returns:
            ``{'success': True, 'file_path', 'filename'}`` or
            ``{'success': False, 'error': <message>}``.
        """
        try:
            if not patient_id or not session_id or not video_data:
                return {'success': False, 'error': '缺少必要参数'}
            try:
                video_bytes = self._decode_base64_payload(video_data)
            except Exception as e:
                logger.error(f'Base64解码失败: {e}')
                return {'success': False, 'error': 'Base64解码失败'}
            if not filename:
                timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
                filename = f'recording_{timestamp}.mp4'
            elif not isinstance(filename, str):
                return {'success': False, 'error': '非法的路径参数'}
            # NOTE(review): WebM payloads also get a .mp4 name; get_patient_files
            # only lists recording_*.mp4, so the extension is kept as-is.
            if not filename.lower().endswith('.mp4'):
                filename += '.mp4'
            file_path = self._resolve_session_file(patient_id, session_id, filename)
            if file_path is None:
                return {'success': False, 'error': '非法的路径参数'}
            file_path.parent.mkdir(parents=True, exist_ok=True)
            file_path.write_bytes(video_bytes)
            logger.info(f'录像保存成功: {file_path}')
            return {
                'success': True,
                'file_path': str(file_path),
                'filename': filename
            }
        except Exception as e:
            logger.error(f'保存录像失败: {e}')
            return {'success': False, 'error': str(e)}

    def get_patient_files(self, patient_id: str) -> Dict[str, Any]:
        """List a patient's screenshots and recordings grouped by session id.

        Only ``screenshot_*.jpg`` and ``recording_*.mp4`` files are listed, newest
        first. Returns ``{'sessions': {}}`` for unknown/invalid ids or on error.
        """
        try:
            if not self._is_safe_path_component(patient_id):
                return {'sessions': {}}
            patient_dir = self._data_root() / 'patients' / patient_id
            if not patient_dir.exists():
                return {'sessions': {}}
            sessions = {}
            for session_dir in patient_dir.iterdir():
                if not session_dir.is_dir():
                    continue
                sessions[session_dir.name] = {
                    'screenshots': self._list_media(session_dir, 'screenshot_*.jpg'),
                    'recordings': self._list_media(session_dir, 'recording_*.mp4')
                }
            return {'sessions': sessions}
        except Exception as e:
            logger.error(f'获取患者文件列表失败: {e}')
            return {'sessions': {}}

    def end_session(self, session_id: str):
        """Drop all in-memory state for *session_id*; unknown ids are ignored."""
        with self.data_lock:
            removed = self.session_data.pop(session_id, None)
        if removed is not None:
            logger.info(f'检测会话结束: {session_id}')
class BalanceAnalyzer:
    """Balance analyzer based on the left/right foot pressure distribution.

    Numeric results are converted to plain Python floats so they stay
    JSON-serialisable: e.g. ``np.max`` over all-integer balance indices (frames
    with zero total pressure) yields ``numpy.int64``, which ``json`` rejects.
    """

    def analyze_real_time(self, pressure_data: Dict[str, Any], data_buffer: List[Dict]) -> Dict[str, Any]:
        """Assess balance for the current pressure sample.

        Args:
            pressure_data: Processed pressure sample (reads ``balance_index`` and
                ``center_of_pressure``).
            data_buffer: Buffered processed samples; the last (up to) 10 are used
                to estimate stability.

        Returns:
            balance_index, stability in [0, 1], center_of_pressure and a status of
            'stable' (balance_index < 0.2) / 'unstable'; 'unknown' on error.
        """
        try:
            balance_index = pressure_data.get('balance_index', 0)
            cop = pressure_data.get('center_of_pressure', {'x': 0, 'y': 0})
            recent_data = data_buffer[-10:] if len(data_buffer) >= 10 else data_buffer
            if recent_data and all('pressure' in d for d in recent_data):
                balance_indices = [d['pressure'].get('balance_index', 0) for d in recent_data]
                # Lower spread of the balance index means higher stability.
                stability = 1.0 - float(np.std(balance_indices))
            else:
                stability = 0.5  # neutral value when the pressure history is incomplete
            return {
                'balance_index': balance_index,
                'stability': max(0.0, min(1.0, stability)),
                'center_of_pressure': cop,
                'status': 'stable' if balance_index < 0.2 else 'unstable'
            }
        except Exception as e:
            logger.error(f'实时平衡分析失败: {e}')
            return {
                'balance_index': 0,
                'stability': 0,
                'center_of_pressure': {'x': 0, 'y': 0},
                'status': 'unknown'
            }

    def analyze_full_session(self, data_buffer: List[Dict], settings: Dict) -> Dict[str, Any]:
        """Score balance (0-100) over every sample that carries pressure data.

        The score is the mean of three components: balance (mean balance index),
        stability (std of the balance index) and symmetry (left/right mean ratio gap).
        Returns ``{'score': 0, 'error': ...}`` when there is no pressure data.
        """
        try:
            pressure_data = [d['pressure'] for d in data_buffer if 'pressure' in d]
            if not pressure_data:
                return {'score': 0, 'error': '没有压力数据'}

            balance_indices = [d.get('balance_index', 0) for d in pressure_data]
            left_ratios = [d.get('left_ratio', 0.5) for d in pressure_data]
            right_ratios = [d.get('right_ratio', 0.5) for d in pressure_data]

            avg_balance_index = float(np.mean(balance_indices))
            stability_index = float(np.std(balance_indices))  # spread of the balance index
            max_balance_index = float(np.max(balance_indices))
            left_right_imbalance = abs(float(np.mean(left_ratios)) - float(np.mean(right_ratios)))

            # Component scores on a 0-100 scale; lower indices give higher scores.
            balance_score = max(0.0, 100 - avg_balance_index * 200)
            stability_score = max(0.0, 100 - stability_index * 500)
            symmetry_score = max(0.0, 100 - left_right_imbalance * 200)
            overall_score = (balance_score + stability_score + symmetry_score) / 3

            return {
                'score': round(overall_score, 1),
                'avg_balance_index': round(avg_balance_index, 3),
                'stability_index': round(stability_index, 3),
                'left_right_imbalance': round(left_right_imbalance, 3),
                'max_imbalance': round(max_balance_index, 3),
                'data_points': len(pressure_data),
                'component_scores': {
                    'balance': round(balance_score, 1),
                    'stability': round(stability_score, 1),
                    'symmetry': round(symmetry_score, 1)
                }
            }
        except Exception as e:
            logger.error(f'全会话平衡分析失败: {e}')
            return {'score': 0, 'error': str(e)}
class PostureAnalyzer:
    """Posture analyzer driven by IMU tilt angles (pitch/roll) plus optional pose data."""

    def analyze_real_time(self, data: Dict[str, Any], data_buffer: List[Dict]) -> Dict[str, Any]:
        """Summarise the current sample's posture.

        Copies pitch/roll/total_accel from ``data['imu']`` and body_angle/confidence
        from ``data['pose']`` when present. If the sample has IMU data, a
        'stability' value is derived from the spread of pitch/roll over the last
        (up to) 5 buffered IMU samples. Returns {} on error.
        """
        try:
            summary = {}
            has_imu = 'imu' in data

            if has_imu:
                imu = data['imu']
                summary['pitch'] = imu.get('pitch', 0)
                summary['roll'] = imu.get('roll', 0)
                summary['total_accel'] = imu.get('total_accel', 0)

            if 'pose' in data:
                pose = data['pose']
                summary['body_angle'] = pose.get('body_angle', {})
                summary['confidence'] = pose.get('confidence', 0)

            window = data_buffer if len(data_buffer) < 5 else data_buffer[-5:]
            if window and has_imu:
                imu_frames = [frame.get('imu', {}) for frame in window if 'imu' in frame]
                if imu_frames:
                    pitch_values = [frame.get('pitch', 0) for frame in imu_frames]
                    roll_values = [frame.get('roll', 0) for frame in imu_frames]
                    # A std of 10 degrees or more maps to zero stability.
                    pitch_stability = 1.0 - min(1.0, np.std(pitch_values) / 10)
                    roll_stability = 1.0 - min(1.0, np.std(roll_values) / 10)
                    summary['stability'] = (pitch_stability + roll_stability) / 2

            return summary
        except Exception as e:
            logger.error(f'实时姿态分析失败: {e}')
            return {}

    def analyze_full_session(self, data_buffer: List[Dict], settings: Dict) -> Dict[str, Any]:
        """Score posture (0-100) over every sample that carries IMU data.

        The score averages a pitch and a roll component (mean tilt, 5 points per
        degree) with a stability component (std of pitch + roll, 10 points per
        degree). Returns ``{'score': 0, 'error': ...}`` without IMU data.
        """
        try:
            samples = [frame['imu'] for frame in data_buffer if 'imu' in frame]
            if not samples:
                return {'score': 0, 'error': '没有IMU数据'}

            pitch_values = [sample.get('pitch', 0) for sample in samples]
            roll_values = [sample.get('roll', 0) for sample in samples]

            mean_pitch, mean_roll = np.mean(pitch_values), np.mean(roll_values)
            sd_pitch, sd_roll = np.std(pitch_values), np.std(roll_values)
            peak_pitch = np.max(np.abs(pitch_values))
            peak_roll = np.max(np.abs(roll_values))

            pitch_score = max(0, 100 - abs(mean_pitch) * 5)
            roll_score = max(0, 100 - abs(mean_roll) * 5)
            stability_score = max(0, 100 - (sd_pitch + sd_roll) * 10)
            combined = (pitch_score + roll_score + stability_score) / 3

            components = {
                'pitch': round(pitch_score, 1),
                'roll': round(roll_score, 1),
                'stability': round(stability_score, 1)
            }
            return {
                'score': round(combined, 1),
                'avg_pitch': round(mean_pitch, 2),
                'avg_roll': round(mean_roll, 2),
                'std_pitch': round(sd_pitch, 2),
                'std_roll': round(sd_roll, 2),
                'max_pitch': round(peak_pitch, 2),
                'max_roll': round(peak_roll, 2),
                'data_points': len(samples),
                'component_scores': components
            }
        except Exception as e:
            logger.error(f'全会话姿态分析失败: {e}')
            return {'score': 0, 'error': str(e)}
class MovementAnalyzer:
    """Movement analyzer based on the centre-of-pressure (COP) trajectory.

    Results are plain Python ``float``/``bool`` values: NumPy scalars such as
    ``numpy.bool_`` (from comparisons) or ``numpy.int64`` (ranges of integer
    coordinates) are rejected by the standard ``json`` encoder.
    """

    @staticmethod
    def _cop_points(frames: List[Dict]) -> List[Tuple[Any, Any]]:
        """Collect (x, y) COP coordinates from every frame that has pressure data."""
        points = []
        for frame in frames:
            if 'pressure' in frame:
                cop = frame['pressure'].get('center_of_pressure', {'x': 0, 'y': 0})
                points.append((cop['x'], cop['y']))
        return points

    def analyze_real_time(self, data: Dict[str, Any], data_buffer: List[Dict]) -> Dict[str, Any]:
        """Detect COP movement over the last (up to) 10 buffered samples.

        Requires at least 5 buffered samples, pressure data in the current sample
        and at least 2 COP points; otherwise ``{'movement_detected': False}``.
        Movement is flagged when the x or y range exceeds 5 (COP units).
        """
        try:
            if len(data_buffer) < 5:
                return {'movement_detected': False}
            if 'pressure' in data:
                cop_positions = self._cop_points(data_buffer[-10:])
                if len(cop_positions) >= 2:
                    xs, ys = np.asarray(cop_positions, dtype=float).T
                    range_x = float(xs.max() - xs.min())
                    range_y = float(ys.max() - ys.min())
                    return {
                        'movement_detected': bool(range_x > 5 or range_y > 5),
                        'movement_range_x': range_x,
                        'movement_range_y': range_y,
                        'total_movement': float(np.hypot(range_x, range_y))
                    }
            return {'movement_detected': False}
        except Exception as e:
            logger.error(f'实时运动分析失败: {e}')
            return {'movement_detected': False}

    def analyze_full_session(self, data_buffer: List[Dict], settings: Dict) -> Dict[str, Any]:
        """Score COP movement (0-100) over the whole session.

        Needs at least 10 COP samples, else ``{'score': 0, 'error': '数据不足'}``.
        The score averages range, variability and "frequency" components; smaller
        movement scores higher.
        """
        try:
            cop_data = self._cop_points(data_buffer)
            if len(cop_data) < 10:
                return {'score': 0, 'error': '数据不足'}

            xs, ys = np.asarray(cop_data, dtype=float).T

            movement_range_x = float(xs.max() - xs.min())
            movement_range_y = float(ys.max() - ys.min())
            total_range = float(np.hypot(movement_range_x, movement_range_y))

            # Variability: sum of per-axis standard deviations.
            movement_variability = float(np.std(xs) + np.std(ys))

            # Path length: summed distance between consecutive COP samples.
            path_length = float(np.sum(np.hypot(np.diff(xs), np.diff(ys))))

            # NOTE: despite the name this is the mean path length per sample
            # (a simplified sway proxy), not a frequency in Hz.
            movement_frequency = path_length / len(cop_data)

            range_score = max(0.0, 100 - total_range * 2)
            variability_score = max(0.0, 100 - movement_variability * 10)
            frequency_score = max(0.0, 100 - movement_frequency * 20)
            overall_score = (range_score + variability_score + frequency_score) / 3

            return {
                'score': round(overall_score, 1),
                'movement_range_x': round(movement_range_x, 2),
                'movement_range_y': round(movement_range_y, 2),
                'total_range': round(total_range, 2),
                'movement_variability': round(movement_variability, 2),
                'path_length': round(path_length, 2),
                'movement_frequency': round(movement_frequency, 2),
                'data_points': len(cop_data),
                'component_scores': {
                    'range': round(range_score, 1),
                    'variability': round(variability_score, 1),
                    'frequency': round(frequency_score, 1)
                }
            }
        except Exception as e:
            logger.error(f'全会话运动分析失败: {e}')
            return {'score': 0, 'status': 'error', 'error': str(e)}
# ==================== Flask route definitions ====================

# Blueprint grouping the screenshot, recording and patient-file endpoints below.
detection_bp = Blueprint(name='detection', import_name=__name__)

# Process-wide DetectionEngine; stays None until init_detection_engine() creates it.
detection_engine_instance: "Optional[DetectionEngine]" = None
# Guards the lazy creation of detection_engine_instance across request threads.
_detection_engine_lock = threading.Lock()


def init_detection_engine():
    """Return the process-wide DetectionEngine, creating it on first use.

    Uses double-checked locking so concurrent first requests (e.g. under a
    threaded Flask server) cannot each construct their own engine.
    """
    global detection_engine_instance
    if detection_engine_instance is None:
        with _detection_engine_lock:
            # Re-check: another thread may have created it while we waited.
            if detection_engine_instance is None:
                detection_engine_instance = DetectionEngine()
    return detection_engine_instance
@detection_bp.route('/api/screenshots/save', methods=['POST'])
def save_screenshot():
    """Save a screenshot for a patient session (POST JSON).

    Request JSON: ``patientId``, ``sessionId``, ``imageData`` (a
    ``data:image/...`` Base64 data URL) and an optional ``filename``.
    Responds 200 with ``{success, message, filepath, filename}``; 400 for
    missing/invalid input or when the engine refuses the save; 500 otherwise.
    """
    try:
        engine = init_detection_engine()
        # silent=True yields None for a missing/malformed JSON body instead of
        # raising, so such requests get a 400 below rather than a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}

        for field in ('patientId', 'imageData', 'sessionId'):
            if not data.get(field):
                return jsonify({
                    'success': False,
                    'message': f'缺少必需参数: {field}'
                }), 400

        patient_id = data['patientId']
        image_data = data['imageData']
        session_id = data['sessionId']
        filename = data.get('filename')  # optional

        # Only accept a Base64 image data URL (and reject non-string payloads).
        if not isinstance(image_data, str) or not image_data.startswith('data:image/'):
            return jsonify({
                'success': False,
                'message': '无效的图片数据格式'
            }), 400

        result = engine.save_screenshot(patient_id, session_id, image_data, filename)
        if result['success']:
            return jsonify({
                'success': True,
                'message': '截图保存成功',
                'filepath': result['file_path'],
                'filename': result['filename']
            })
        return jsonify({
            'success': False,
            'message': result['error']
        }), 400
    except Exception as e:
        logger.error(f'保存截图失败: {e}')
        return jsonify({
            'success': False,
            'message': f'保存截图失败: {str(e)}'
        }), 500
@detection_bp.route('/api/recordings/save', methods=['POST'])
def save_recording():
    """Save a recording for a patient session (POST JSON).

    Request JSON: ``patientId``, ``sessionId``, ``videoData`` (a
    ``data:video/mp4`` or ``data:video/webm`` Base64 data URL) and an optional
    ``filename``. Responds 200 with ``{success, message, filepath, filename}``;
    400 for missing/invalid input or when the engine refuses the save; 500 otherwise.
    """
    try:
        engine = init_detection_engine()
        # silent=True yields None for a missing/malformed JSON body instead of
        # raising, so such requests get a 400 below rather than a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}

        for field in ('patientId', 'videoData', 'sessionId'):
            if not data.get(field):
                return jsonify({
                    'success': False,
                    'message': f'缺少必需参数: {field}'
                }), 400

        patient_id = data['patientId']
        video_data = data['videoData']
        session_id = data['sessionId']
        filename = data.get('filename')  # optional

        # Only MP4 / WebM Base64 data URLs are accepted (and only strings).
        if not isinstance(video_data, str) or not video_data.startswith(
                ('data:video/mp4', 'data:video/webm')):
            return jsonify({
                'success': False,
                'message': '无效的视频数据格式仅支持MP4和WebM格式'
            }), 400

        result = engine.save_recording(patient_id, session_id, video_data, filename)
        if result['success']:
            return jsonify({
                'success': True,
                'message': '录像保存成功',
                'filepath': result['file_path'],
                'filename': result['filename']
            })
        return jsonify({
            'success': False,
            'message': result['error']
        }), 400
    except Exception as e:
        logger.error(f'保存录像失败: {e}')
        return jsonify({
            'success': False,
            'message': f'保存录像失败: {str(e)}'
        }), 500
@detection_bp.route('/api/patients/<patient_id>/files', methods=['GET'])
def get_patient_files(patient_id):
    """Return every screenshot/recording stored for *patient_id*, grouped by session.

    Responds with ``{'success': True, 'data': <engine listing>}``, or a 500 JSON
    error payload if anything raises.
    """
    try:
        listing = init_detection_engine().get_patient_files(patient_id)
        payload = {'success': True, 'data': listing}
        return jsonify(payload)
    except Exception as err:
        logger.error(f'获取患者文件列表失败: {err}')
        failure = {
            'success': False,
            'message': f'获取患者文件列表失败: {str(err)}'
        }
        return jsonify(failure), 500