BodyBalanceEvaluation/backend/utils.py
2025-08-02 16:52:17 +08:00

521 lines
17 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
工具模块
包含常用的工具函数、配置管理和辅助功能
"""
import os
import json
import logging
import hashlib
import secrets
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Union
from pathlib import Path
import configparser
from functools import wraps
import time
# Module-level logger named after this module; the classes and decorators
# defined below log through it.
logger = logging.getLogger(__name__)
class Config:
    """INI-backed configuration manager.

    Settings are read from ``config_file`` when it exists. When it does not
    exist, the built-in defaults are created and written to that path. The
    typed getters never raise: on any lookup or conversion error they log a
    warning and return ``fallback``.
    """

    def __init__(self, config_file: str = '../config.ini'):
        """Load the configuration, creating defaults if the file is missing.

        Args:
            config_file: Path of the INI file to read (and to write on save).
        """
        self.config_file = Path(config_file)
        self.config = configparser.ConfigParser()
        self._load_config()

    def _load_config(self):
        """Read the config file, falling back to the built-in defaults."""
        try:
            if self.config_file.exists():
                self.config.read(self.config_file, encoding='utf-8')
                logger.info(f'配置文件已加载: {self.config_file}')
            else:
                self._create_default_config()
                logger.info('已创建默认配置文件')
        except Exception as e:
            logger.error(f'配置文件加载失败: {e}')
            # The file exists but could not be read/parsed. Use the defaults
            # in memory only: writing them out here would silently destroy
            # the user's settings (and regenerate the secret key) because of
            # what may be a single typo in the file.
            self._create_default_config(persist=False)

    def _create_default_config(self, persist: bool = True):
        """Populate the parser with the built-in default sections.

        Args:
            persist: When true (the default) the defaults are also written to
                ``self.config_file`` via ``save_config``.
        """
        try:
            # Application settings
            self.config['APP'] = {
                'name': 'Body Balance Evaluation System',
                'version': '1.0.0',
                'debug': 'false',
                'log_level': 'INFO'
            }
            # Server settings
            self.config['SERVER'] = {
                'host': '0.0.0.0',
                'port': '5000',
                'cors_origins': '*'
            }
            # Database settings
            self.config['DATABASE'] = {
                'path': 'backend/data/body_balance.db',
                'backup_interval': '24',  # hours
                'max_backups': '7'
            }
            # Device settings
            self.config['DEVICES'] = {
                'camera_index': '0',
                'camera_width': '640',
                'camera_height': '480',
                'camera_fps': '30',
                'imu_port': 'COM3',
                'pressure_port': 'COM4'
            }
            # Detection settings
            self.config['DETECTION'] = {
                'default_duration': '60',  # seconds
                'sampling_rate': '30',  # Hz
                'balance_threshold': '0.2',
                'posture_threshold': '5.0'  # degrees
            }
            # Data-processing settings
            self.config['DATA_PROCESSING'] = {
                'filter_window': '5',
                'outlier_threshold': '2.0',
                'chart_dpi': '300',
                'export_format': 'csv'
            }
            # Security settings; a fresh random secret key is generated
            # every time the defaults are built.
            self.config['SECURITY'] = {
                'secret_key': secrets.token_hex(32),
                'session_timeout': '3600',  # seconds
                'max_login_attempts': '5'
            }
            if persist:
                self.save_config()
        except Exception as e:
            logger.error(f'默认配置创建失败: {e}')

    def get(self, section: str, key: str, fallback: Any = None) -> Optional[str]:
        """Return the raw string value, or ``fallback`` if it is unavailable."""
        try:
            return self.config.get(section, key, fallback=fallback)
        except Exception as e:
            logger.warning(f'配置获取失败 [{section}][{key}]: {e}')
            return fallback

    def getint(self, section: str, key: str, fallback: int = 0) -> int:
        """Return the value as ``int``; ``fallback`` if missing or not an integer."""
        try:
            return self.config.getint(section, key, fallback=fallback)
        except Exception as e:
            logger.warning(f'整数配置获取失败 [{section}][{key}]: {e}')
            return fallback

    def getfloat(self, section: str, key: str, fallback: float = 0.0) -> float:
        """Return the value as ``float``; ``fallback`` if missing or not a number."""
        try:
            return self.config.getfloat(section, key, fallback=fallback)
        except Exception as e:
            logger.warning(f'浮点数配置获取失败 [{section}][{key}]: {e}')
            return fallback

    def getboolean(self, section: str, key: str, fallback: bool = False) -> bool:
        """Return the value as ``bool``; ``fallback`` if missing or not a boolean."""
        try:
            return self.config.getboolean(section, key, fallback=fallback)
        except Exception as e:
            logger.warning(f'布尔配置获取失败 [{section}][{key}]: {e}')
            return fallback

    def set(self, section: str, key: str, value: Any):
        """Set a value in memory, creating the section if needed.

        The value is stored as ``str(value)``. Nothing is written to disk
        until ``save_config`` is called. Failures are logged, not raised.
        """
        try:
            if not self.config.has_section(section):
                self.config.add_section(section)
            self.config.set(section, key, str(value))
        except Exception as e:
            logger.error(f'配置设置失败 [{section}][{key}]: {e}')

    def save_config(self):
        """Write the current configuration to ``self.config_file``.

        Missing parent directories are created. Failures are logged, not
        raised.
        """
        try:
            # Make sure the target directory exists before opening the file.
            self.config_file.parent.mkdir(parents=True, exist_ok=True)
            with open(self.config_file, 'w', encoding='utf-8') as f:
                self.config.write(f)
            logger.info(f'配置文件已保存: {self.config_file}')
        except Exception as e:
            logger.error(f'配置文件保存失败: {e}')
class Logger:
    """Logging setup helper."""

    @staticmethod
    def setup_logging(log_level: str = 'INFO', log_file: Optional[str] = None):
        """Reconfigure the root logger with a console (and optional file) handler.

        Existing root handlers are removed and closed. Any failure is printed
        instead of raised, because logging itself may be what is broken.

        Args:
            log_level: Level name, case-insensitive. Anything that does not
                name a numeric ``logging`` level falls back to ``INFO``.
            log_file: Optional log file path; missing parent directories are
                created.
        """
        try:
            # Resolve the level first. Only integer attributes are real
            # levels; e.g. ``BASIC_FORMAT`` exists on the module but is a str.
            level = getattr(logging, str(log_level).upper(), None)
            if not isinstance(level, int):
                level = logging.INFO

            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                datefmt='%Y-%m-%d %H:%M:%S'
            )

            # Build every new handler before touching the root logger, so a
            # failure (e.g. an unwritable log file) leaves the previous
            # configuration intact.
            new_handlers = [logging.StreamHandler()]
            if log_file:
                Path(log_file).parent.mkdir(parents=True, exist_ok=True)
                new_handlers.append(logging.FileHandler(log_file, encoding='utf-8'))
            for handler in new_handlers:
                handler.setFormatter(formatter)

            root_logger = logging.getLogger()
            root_logger.setLevel(level)

            # Remove AND close the old handlers; removing alone leaks the
            # underlying file descriptors on every reconfiguration.
            for handler in root_logger.handlers[:]:
                root_logger.removeHandler(handler)
                handler.close()

            for handler in new_handlers:
                root_logger.addHandler(handler)

            logger.info(f'日志系统已初始化,级别: {log_level}')
        except Exception as e:
            print(f'日志系统初始化失败: {e}')
class DataValidator:
    """Validators for incoming patient records and detection settings.

    Both validators normalise the passed dict in place (stripped strings,
    numbers converted to ``float``/``int``) and report problems as a list of
    messages instead of raising.
    """

    @staticmethod
    def _parse_iso_datetime(value: Any) -> Optional[datetime]:
        """Parse an ISO-8601 string (a trailing ``Z`` is accepted); None if invalid."""
        if not isinstance(value, str):
            return None
        try:
            return datetime.fromisoformat(value.replace('Z', '+00:00'))
        except ValueError:
            return None

    @staticmethod
    def _validate_number(record: Dict[str, Any], key: str, cast, low, high,
                         range_error: str, format_error: str,
                         errors: List[str]) -> None:
        """Convert ``record[key]`` with ``cast`` in place and range-check it."""
        try:
            value = cast(record[key])
        except (ValueError, TypeError, OverflowError):
            errors.append(format_error)
            return
        # Negated chained comparison on purpose: NaN fails every comparison,
        # so ``value < low or value > high`` would let it through.
        if not (low <= value <= high):
            errors.append(range_error)
        record[key] = value

    @staticmethod
    def validate_patient_data(data: Dict[str, Any]) -> Dict[str, Any]:
        """Validate a patient record.

        Args:
            data: Patient fields. ``name``, ``gender`` and ``birth_date`` are
                required; ``height``, ``weight`` and ``phone`` are checked
                only when present and truthy. The dict is updated in place.

        Returns:
            ``{'valid': bool, 'errors': [messages], 'data': data}``.
        """
        errors: List[str] = []

        # Required fields
        for field in ('name', 'gender', 'birth_date'):
            if not data.get(field):
                errors.append(f'缺少必填字段: {field}')

        # Name: must be text, 2-50 characters after trimming
        name = data.get('name')
        if name:
            if isinstance(name, str):
                name = name.strip()
                if not 2 <= len(name) <= 50:
                    errors.append('姓名长度应在2-50个字符之间')
                data['name'] = name
            else:
                errors.append('姓名格式无效')

        # Gender
        if data.get('gender') and data['gender'] not in ('male', 'female', 'other'):
            errors.append('性别值无效')

        # Birth date
        if data.get('birth_date'):
            birth_date = DataValidator._parse_iso_datetime(data['birth_date'])
            if birth_date is None:
                errors.append('出生日期格式无效')
            else:
                # Compare like with like: an input carrying an offset ("Z")
                # parses to an aware datetime, and comparing aware with naive
                # values raises TypeError.
                tz = birth_date.tzinfo
                if birth_date > datetime.now(tz):
                    errors.append('出生日期不能是未来时间')
                if birth_date < datetime(1900, 1, 1, tzinfo=tz):
                    errors.append('出生日期过早')

        # Height (cm)
        if data.get('height'):
            DataValidator._validate_number(
                data, 'height', float, 50, 250,
                '身高应在50-250cm之间', '身高格式无效', errors)

        # Weight (kg)
        if data.get('weight'):
            DataValidator._validate_number(
                data, 'weight', float, 10, 300,
                '体重应在10-300kg之间', '体重格式无效', errors)

        # Phone: digits, optionally separated by '-' or spaces
        phone = data.get('phone')
        if phone:
            if isinstance(phone, str):
                phone = phone.strip()
                if phone and not phone.replace('-', '').replace(' ', '').isdigit():
                    errors.append('电话号码格式无效')
                data['phone'] = phone
            else:
                errors.append('电话号码格式无效')

        return {
            'valid': len(errors) == 0,
            'errors': errors,
            'data': data
        }

    @staticmethod
    def validate_detection_config(config: Dict[str, Any]) -> Dict[str, Any]:
        """Validate detection settings.

        Args:
            config: May contain ``duration`` (10-600 s), ``sampling_rate``
                (1-100 Hz) and ``record_video`` (bool). Keys are checked only
                when present; numeric values are converted to ``int`` in place.

        Returns:
            ``{'valid': bool, 'errors': [messages], 'config': config}``.
        """
        errors: List[str] = []

        # Detection duration
        if 'duration' in config:
            DataValidator._validate_number(
                config, 'duration', int, 10, 600,
                '检测时长应在10-600秒之间', '检测时长格式无效', errors)

        # Sampling rate
        if 'sampling_rate' in config:
            DataValidator._validate_number(
                config, 'sampling_rate', int, 1, 100,
                '采样频率应在1-100Hz之间', '采样频率格式无效', errors)

        # Video recording flag
        if 'record_video' in config and not isinstance(config['record_video'], bool):
            errors.append('录像设置应为布尔值')

        return {
            'valid': len(errors) == 0,
            'errors': errors,
            'config': config
        }
class SecurityUtils:
    """Security helpers: session ids, password hashing, filename cleaning."""

    # Shared by hash_password and verify_password. Stored hashes do not
    # record the iteration count, so changing this value invalidates every
    # existing hash.
    _PBKDF2_ITERATIONS = 100000

    @staticmethod
    def generate_session_id() -> str:
        """Return a random URL-safe session id built from 32 random bytes."""
        return secrets.token_urlsafe(32)

    @staticmethod
    def hash_password(password: str) -> str:
        """Hash a password with PBKDF2-HMAC-SHA256 and a fresh random salt.

        Returns:
            ``"<salt>:<digest>"`` where both parts are hex strings.
        """
        salt = secrets.token_hex(16)
        digest = hashlib.pbkdf2_hmac(
            'sha256', password.encode(), salt.encode(),
            SecurityUtils._PBKDF2_ITERATIONS)
        return f"{salt}:{digest.hex()}"

    @staticmethod
    def verify_password(password: str, hashed: str) -> bool:
        """Check a password against a value produced by ``hash_password``.

        Returns False (never raises) for a mismatch or a malformed ``hashed``.
        """
        try:
            salt, expected = hashed.split(':')
            actual = hashlib.pbkdf2_hmac(
                'sha256', password.encode(), salt.encode(),
                SecurityUtils._PBKDF2_ITERATIONS).hex()
            # Constant-time comparison: a plain ``==`` returns at the first
            # differing character and leaks timing information.
            return secrets.compare_digest(actual, expected)
        except Exception:
            return False

    @staticmethod
    def sanitize_filename(filename: str) -> str:
        """Make a file name safe to use as a single path component.

        Reserved punctuation (``<>:"/\\|?*``) and ASCII control characters
        are replaced by ``_``; the special names ``.`` and ``..`` are
        neutralised; the result is limited to 255 characters, keeping the
        extension where possible.
        """
        import re
        filename = re.sub(r'[<>:"/\\|?*\x00-\x1f]', '_', filename)
        # "." and ".." contain no forbidden character but still refer to the
        # current/parent directory when joined onto a path.
        if filename in ('.', '..'):
            filename = '_' * len(filename)
        if len(filename) > 255:
            name, ext = os.path.splitext(filename)
            if len(ext) < 255:
                filename = name[:255 - len(ext)] + ext
            else:
                # The extension alone exceeds the limit; a negative slice
                # bound here would leave the result too long.
                filename = filename[:255]
        return filename
class FileUtils:
    """File-system helpers."""

    @staticmethod
    def ensure_directory(path: Union[str, Path]):
        """Create ``path`` (and any missing parents) if it does not exist."""
        Path(path).mkdir(parents=True, exist_ok=True)

    @staticmethod
    def get_file_size(path: Union[str, Path]) -> int:
        """Return the file size in bytes, or 0 if it cannot be determined."""
        try:
            return Path(path).stat().st_size
        except Exception:
            return 0

    @staticmethod
    def clean_old_files(directory: Union[str, Path], max_age_days: int = 30):
        """Delete regular files in ``directory`` older than ``max_age_days``.

        Only direct children are considered (no recursion); age is taken from
        the modification time. Errors are logged, never raised, and a file
        that cannot be removed does not stop the rest of the cleanup.
        """
        try:
            directory = Path(directory)
            if not directory.exists():
                return
            cutoff_time = datetime.now() - timedelta(days=max_age_days)
            entries = list(directory.iterdir())
        except Exception as e:
            logger.error(f'清理旧文件失败: {e}')
            return

        for file_path in entries:
            # Handle failures per file so one locked or vanished entry does
            # not abort the whole pass.
            try:
                if not file_path.is_file():
                    continue
                modified = datetime.fromtimestamp(file_path.stat().st_mtime)
                if modified < cutoff_time:
                    file_path.unlink()
                    logger.info(f'已删除旧文件: {file_path}')
            except (OSError, OverflowError, ValueError) as e:
                logger.error(f'清理旧文件失败: {e}')

    @staticmethod
    def backup_file(source: Union[str, Path],
                    backup_dir: Optional[Union[str, Path]] = None) -> Optional[Path]:
        """Copy ``source`` to a timestamped file in ``backup_dir``.

        Args:
            source: File to back up.
            backup_dir: Target directory; defaults to a ``backups`` directory
                next to ``source``. Created if missing.

        Returns:
            Path of the backup, or None if ``source`` does not exist or the
            copy failed (the failure is logged).
        """
        try:
            source = Path(source)
            if not source.exists():
                return None
            if backup_dir is None:
                backup_dir = source.parent / 'backups'
            else:
                backup_dir = Path(backup_dir)
            backup_dir.mkdir(parents=True, exist_ok=True)

            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            backup_path = backup_dir / f"{source.stem}_{timestamp}{source.suffix}"
            # The timestamp has one-second resolution; add a counter rather
            # than overwrite a backup taken within the same second.
            counter = 1
            while backup_path.exists():
                backup_path = backup_dir / f"{source.stem}_{timestamp}_{counter}{source.suffix}"
                counter += 1

            import shutil
            shutil.copy2(source, backup_path)
            logger.info(f'文件已备份: {source} -> {backup_path}')
            return backup_path
        except Exception as e:
            logger.error(f'文件备份失败: {e}')
            return None
class PerformanceMonitor:
    """Collects named wall-clock timings.

    ``metrics`` maps a timer name to a dict holding ``start_time`` (from
    ``time.time()``) and, once the timer has been ended, ``duration`` in
    seconds.
    """

    def __init__(self):
        self.metrics = {}

    def start_timer(self, name: str):
        """Start (or restart) the timer ``name``, discarding any earlier result."""
        self.metrics[name] = {'start_time': time.time()}

    def end_timer(self, name: str) -> float:
        """Stop the timer ``name`` and return the elapsed seconds.

        Returns 0.0 if the timer was never started. The start time is kept,
        so calling this again measures from the same start.
        """
        entry = self.metrics.get(name)
        if entry is None or 'start_time' not in entry:
            return 0.0
        duration = time.time() - entry['start_time']
        entry['duration'] = duration
        return duration

    def get_metrics(self) -> Dict[str, Any]:
        """Return a snapshot of all metrics.

        The per-timer dicts are copied as well, so changing the snapshot
        cannot alter the monitor's own state.
        """
        return {name: dict(entry) for name, entry in self.metrics.items()}

    def reset_metrics(self):
        """Discard all recorded timers."""
        self.metrics.clear()
def timing_decorator(func):
    """Log how long each call of ``func`` takes.

    A successful call is logged at DEBUG level. A call that raises is logged
    at ERROR level with the elapsed time, and the exception is re-raised
    unchanged. Durations are in seconds.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic, so the measurement is not distorted by
        # system clock adjustments the way time.time() differences are.
        started = time.perf_counter()
        try:
            result = func(*args, **kwargs)
        except Exception as e:
            elapsed = time.perf_counter() - started
            logger.error(f'{func.__name__} 执行失败 (耗时: {elapsed:.3f}秒): {e}')
            raise
        elapsed = time.perf_counter() - started
        logger.debug(f'{func.__name__} 执行时间: {elapsed:.3f}秒')
        return result
    return wrapper
def retry_decorator(max_retries: int = 3, delay: float = 1.0):
    """Build a decorator that retries a function when it raises.

    Args:
        max_retries: Number of retries after the first attempt. Values below
            zero are treated as zero, so the function always runs at least
            once.
        delay: Seconds to wait between attempts; no waiting when <= 0.

    The decorated function returns the first successful result. If every
    attempt fails, the exception from the last attempt is re-raised.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Clamp so a negative max_retries cannot skip the call entirely
            # (which previously ended in ``raise None`` -> TypeError).
            total_attempts = max(0, max_retries) + 1
            for attempt in range(1, total_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    if attempt == total_attempts:
                        logger.error(f'{func.__name__} 所有重试均失败')
                        raise
                    logger.warning(
                        f'{func.__name__} 第{attempt}次尝试失败: {e},{delay}秒后重试')
                    # time.sleep rejects negative values.
                    if delay > 0:
                        time.sleep(delay)
        return wrapper
    return decorator
class ResponseFormatter:
    """Builders for the dict payloads returned to API clients.

    Every payload carries ``success`` and an ISO-8601 ``timestamp`` taken
    from ``datetime.now()``.
    """

    @staticmethod
    def success(data: Any = None, message: str = 'Success') -> Dict[str, Any]:
        """Build a success payload; ``data`` is included unless it is None."""
        response = {
            'success': True,
            'message': message,
            'timestamp': datetime.now().isoformat()
        }
        if data is not None:
            response['data'] = data
        return response

    @staticmethod
    def error(message: str, error_code: Optional[str] = None,
              details: Any = None) -> Dict[str, Any]:
        """Build an error payload.

        ``error_code`` and ``details`` are included only when truthy.
        """
        response = {
            'success': False,
            'message': message,
            'timestamp': datetime.now().isoformat()
        }
        if error_code:
            response['error_code'] = error_code
        if details:
            response['details'] = details
        return response

    @staticmethod
    def paginated(data: List[Any], page: int, page_size: int, total: int) -> Dict[str, Any]:
        """Build a paginated success payload.

        ``total_pages`` is ``ceil(total / page_size)``. A non-positive
        ``page_size`` yields 0 pages instead of raising ZeroDivisionError.
        """
        if page_size > 0:
            total_pages = (total + page_size - 1) // page_size
        else:
            total_pages = 0
        return {
            'success': True,
            'data': data,
            'pagination': {
                'page': page,
                'page_size': page_size,
                'total': total,
                'total_pages': total_pages
            },
            'timestamp': datetime.now().isoformat()
        }
# Global configuration instance, constructed with Config's default
# config_file argument when this module is imported.
config = Config()
# Global performance monitor instance.
performance_monitor = PerformanceMonitor()