521 lines
17 KiB
Python
521 lines
17 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
工具模块
|
||
包含常用的工具函数、配置管理和辅助功能
|
||
"""
|
||
|
||
import os
|
||
import json
|
||
import logging
|
||
import hashlib
|
||
import secrets
|
||
from datetime import datetime, timedelta
|
||
from typing import Dict, List, Optional, Any, Union
|
||
from pathlib import Path
|
||
import configparser
|
||
from functools import wraps
|
||
import time
|
||
|
||
# Module-level logger used by every helper in this module; named after the module (__name__).
logger = logging.getLogger(__name__)
|
||
|
||
class Config:
    """INI-file configuration manager built on ``configparser``.

    The file is read on construction. If it does not exist, a default
    configuration (including a freshly generated secret key) is created and
    saved. The typed getters never raise: failures are logged as warnings and
    the supplied fallback is returned.
    """

    def __init__(self, config_file: str = 'config.ini'):
        # Relative paths resolve against the current working directory.
        self.config_file = Path(config_file)
        self.config = configparser.ConfigParser()
        self._load_config()

    def _load_config(self):
        """Load the config file, or create and save defaults when it is missing.

        If the file exists but cannot be read or parsed, defaults are used in
        memory only. The existing file is left untouched rather than being
        overwritten, so a hand-edited config (and its secret key) is not
        silently lost.
        """
        try:
            if self.config_file.exists():
                self.config.read(self.config_file, encoding='utf-8')
                logger.info(f'配置文件已加载: {self.config_file}')
                return
        except Exception as e:
            logger.error(f'配置文件加载失败: {e}')
            # A failed parse can leave the parser partially populated; start clean.
            self.config = configparser.ConfigParser()
            self._create_default_config(save=False)
            return

        self._create_default_config()
        logger.info('已创建默认配置文件')

    def _create_default_config(self, save: bool = True):
        """Populate the parser with the built-in default sections.

        Args:
            save: When True (the default), the defaults are also written to
                ``self.config_file``. Pass False to keep them in memory only.
        """
        try:
            # Application
            self.config['APP'] = {
                'name': 'Body Balance Evaluation System',
                'version': '1.0.0',
                'debug': 'false',
                'log_level': 'INFO'
            }

            # Server
            self.config['SERVER'] = {
                'host': '127.0.0.1',
                'port': '5000',
                'cors_origins': '*'
            }

            # Database
            self.config['DATABASE'] = {
                'path': 'data/balance_system.db',
                'backup_interval': '24',  # hours
                'max_backups': '7'
            }

            # Devices
            self.config['DEVICES'] = {
                'camera_index': '0',
                'camera_width': '640',
                'camera_height': '480',
                'camera_fps': '30',
                'imu_port': 'COM3',
                'pressure_port': 'COM4'
            }

            # Detection
            self.config['DETECTION'] = {
                'default_duration': '60',  # seconds
                'sampling_rate': '30',  # Hz
                'balance_threshold': '0.2',
                'posture_threshold': '5.0'  # degrees
            }

            # Data processing
            self.config['DATA_PROCESSING'] = {
                'filter_window': '5',
                'outlier_threshold': '2.0',
                'chart_dpi': '300',
                'export_format': 'csv'
            }

            # Security: a new random key is generated each time defaults are built.
            self.config['SECURITY'] = {
                'secret_key': secrets.token_hex(32),
                'session_timeout': '3600',  # seconds
                'max_login_attempts': '5'
            }

            if save:
                self.save_config()

        except Exception as e:
            logger.error(f'默认配置创建失败: {e}')

    def get(self, section: str, key: str, fallback: Any = None) -> Any:
        """Return the raw string value of ``[section] key``, or ``fallback`` if missing or on error."""
        try:
            return self.config.get(section, key, fallback=fallback)
        except Exception as e:
            logger.warning(f'配置获取失败 [{section}][{key}]: {e}')
            return fallback

    def getint(self, section: str, key: str, fallback: int = 0) -> int:
        """Return ``[section] key`` as an int, or ``fallback`` if missing or not an integer."""
        try:
            return self.config.getint(section, key, fallback=fallback)
        except Exception as e:
            logger.warning(f'整数配置获取失败 [{section}][{key}]: {e}')
            return fallback

    def getfloat(self, section: str, key: str, fallback: float = 0.0) -> float:
        """Return ``[section] key`` as a float, or ``fallback`` if missing or not a number."""
        try:
            return self.config.getfloat(section, key, fallback=fallback)
        except Exception as e:
            logger.warning(f'浮点数配置获取失败 [{section}][{key}]: {e}')
            return fallback

    def getboolean(self, section: str, key: str, fallback: bool = False) -> bool:
        """Return ``[section] key`` as a bool (configparser rules), or ``fallback`` on error."""
        try:
            return self.config.getboolean(section, key, fallback=fallback)
        except Exception as e:
            logger.warning(f'布尔配置获取失败 [{section}][{key}]: {e}')
            return fallback

    def set(self, section: str, key: str, value: Any):
        """Set ``[section] key`` to ``str(value)`` in memory, creating the section if needed.

        Call ``save_config`` to persist the change. Errors are logged, not raised.
        """
        try:
            if not self.config.has_section(section):
                self.config.add_section(section)
            self.config.set(section, key, str(value))
        except Exception as e:
            logger.error(f'配置设置失败 [{section}][{key}]: {e}')

    def save_config(self):
        """Write the current configuration to ``self.config_file`` (UTF-8), creating parent dirs."""
        try:
            self.config_file.parent.mkdir(parents=True, exist_ok=True)
            with open(self.config_file, 'w', encoding='utf-8') as f:
                self.config.write(f)
            logger.info(f'配置文件已保存: {self.config_file}')
        except Exception as e:
            logger.error(f'配置文件保存失败: {e}')
|
||
|
||
class Logger:
    """Helper that (re)configures the process-wide root logger."""

    @staticmethod
    def setup_logging(log_level: str = 'INFO', log_file: Optional[str] = None):
        """Configure the root logger with a console handler and an optional file handler.

        Any handlers already attached to the root logger are removed *and closed*,
        so calling this repeatedly does not leak open log-file handles.

        Args:
            log_level: Level name such as 'DEBUG' or 'info' (case-insensitive).
                Unknown names fall back to INFO.
            log_file: Optional path of a UTF-8 log file. Its parent directory is
                created if needed.

        Failures are reported with print() instead of being raised, because the
        logging system itself may be what failed.
        """
        try:
            if log_file:
                Path(log_file).parent.mkdir(parents=True, exist_ok=True)

            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                datefmt='%Y-%m-%d %H:%M:%S'
            )

            root_logger = logging.getLogger()
            level = getattr(logging, log_level.upper(), None)
            # getattr can also hit non-level attributes (e.g. BASIC_FORMAT); accept ints only.
            if not isinstance(level, int):
                level = logging.INFO
            root_logger.setLevel(level)

            # Detach and close previous handlers; closing releases file handles.
            for handler in root_logger.handlers[:]:
                root_logger.removeHandler(handler)
                handler.close()

            console_handler = logging.StreamHandler()
            console_handler.setFormatter(formatter)
            root_logger.addHandler(console_handler)

            if log_file:
                file_handler = logging.FileHandler(log_file, encoding='utf-8')
                file_handler.setFormatter(formatter)
                root_logger.addHandler(file_handler)

            logger.info(f'日志系统已初始化,级别: {log_level}')

        except Exception as e:
            print(f'日志系统初始化失败: {e}')
|
||
|
||
class DataValidator:
    """Validators for incoming request payloads.

    Each validator normalizes the payload in place (strips strings, converts
    numeric fields) and returns a dict holding ``valid`` (bool), ``errors``
    (list of user-facing messages) and the normalized payload.
    """

    @staticmethod
    def _check_float_range(data: Dict[str, Any], key: str, low: float, high: float,
                           range_msg: str, format_msg: str, errors: List[str]) -> None:
        """Convert ``data[key]`` to float in place and append an error message if invalid."""
        try:
            value = float(data[key])
        except (ValueError, TypeError):
            errors.append(format_msg)
            return
        # The chained comparison is False for NaN, so NaN is rejected as out of range.
        if not low <= value <= high:
            errors.append(range_msg)
        data[key] = value

    @staticmethod
    def validate_patient_data(data: Dict[str, Any]) -> Dict[str, Any]:
        """Validate and normalize a patient record.

        Required fields are ``name``, ``gender`` and ``birth_date``. The optional
        fields ``height`` (cm), ``weight`` (kg) and ``phone`` are checked only
        when present and truthy.

        ``data`` is mutated in place: ``name``/``phone`` become stripped strings
        and ``height``/``weight`` become floats when convertible.

        Returns:
            ``{'valid': bool, 'errors': List[str], 'data': data}``.
        """
        errors: List[str] = []

        for field in ('name', 'gender', 'birth_date'):
            if not data.get(field):
                errors.append(f'缺少必填字段: {field}')

        if data.get('name'):
            # str() so a non-string value is validated rather than crashing on .strip().
            name = str(data['name']).strip()
            if len(name) < 2 or len(name) > 50:
                errors.append('姓名长度应在2-50个字符之间')
            data['name'] = name

        if data.get('gender'):
            if data['gender'] not in ['male', 'female', 'other']:
                errors.append('性别值无效')

        if data.get('birth_date'):
            try:
                birth_date = datetime.fromisoformat(str(data['birth_date']).replace('Z', '+00:00'))
            except ValueError:
                errors.append('出生日期格式无效')
            else:
                # Inputs with a UTC offset ('...Z', '+08:00') parse as timezone-aware.
                # Compare against bounds of the same awareness; mixing naive and aware
                # datetimes raises TypeError.
                tz = birth_date.tzinfo
                if birth_date > datetime.now(tz):
                    errors.append('出生日期不能是未来时间')
                if birth_date < datetime(1900, 1, 1, tzinfo=tz):
                    errors.append('出生日期过早')

        if data.get('height'):
            DataValidator._check_float_range(
                data, 'height', 50, 250, '身高应在50-250cm之间', '身高格式无效', errors)

        if data.get('weight'):
            DataValidator._check_float_range(
                data, 'weight', 10, 300, '体重应在10-300kg之间', '体重格式无效', errors)

        if data.get('phone'):
            # str() so numeric phone values (common in JSON payloads) are accepted.
            phone = str(data['phone']).strip()
            if phone and not phone.replace('-', '').replace(' ', '').isdigit():
                errors.append('电话号码格式无效')
            data['phone'] = phone

        return {
            'valid': len(errors) == 0,
            'errors': errors,
            'data': data
        }

    @staticmethod
    def validate_detection_config(config: Dict[str, Any]) -> Dict[str, Any]:
        """Validate and normalize a detection configuration.

        Checks ``duration`` (10–600 s), ``sampling_rate`` (1–100 Hz) and
        ``record_video`` (must be a bool) when present. Numeric fields are
        converted to int in place.

        Returns:
            ``{'valid': bool, 'errors': List[str], 'config': config}``.
        """
        errors: List[str] = []

        if 'duration' in config:
            try:
                duration = int(config['duration'])
                if duration < 10 or duration > 600:
                    errors.append('检测时长应在10-600秒之间')
                config['duration'] = duration
            # OverflowError: int(float('inf')).
            except (ValueError, TypeError, OverflowError):
                errors.append('检测时长格式无效')

        if 'sampling_rate' in config:
            try:
                rate = int(config['sampling_rate'])
                if rate < 1 or rate > 100:
                    errors.append('采样频率应在1-100Hz之间')
                config['sampling_rate'] = rate
            except (ValueError, TypeError, OverflowError):
                errors.append('采样频率格式无效')

        if 'record_video' in config:
            if not isinstance(config['record_video'], bool):
                errors.append('录像设置应为布尔值')

        return {
            'valid': len(errors) == 0,
            'errors': errors,
            'config': config
        }
|
||
|
||
class SecurityUtils:
    """Security helpers: session ids, password hashing and filename sanitizing."""

    @staticmethod
    def generate_session_id() -> str:
        """Return a random URL-safe session id generated from 32 random bytes."""
        return secrets.token_urlsafe(32)

    @staticmethod
    def hash_password(password: str) -> str:
        """Hash ``password`` with PBKDF2-HMAC-SHA256 (100000 iterations, random salt).

        Returns:
            ``"<salt_hex>:<hash_hex>"``. The salt is fed to PBKDF2 as the bytes of
            its hex string; ``verify_password`` relies on the same encoding.
        """
        salt = secrets.token_hex(16)
        password_hash = hashlib.pbkdf2_hmac('sha256', password.encode(), salt.encode(), 100000)
        return f"{salt}:{password_hash.hex()}"

    @staticmethod
    def verify_password(password: str, hashed: str) -> bool:
        """Return True if ``password`` matches a ``hash_password`` result.

        Returns False for any malformed ``hashed`` value instead of raising.
        """
        try:
            salt, password_hash = hashed.split(':')
            candidate = hashlib.pbkdf2_hmac('sha256', password.encode(), salt.encode(), 100000).hex()
            # Constant-time comparison: do not leak how many leading characters match.
            return secrets.compare_digest(candidate, password_hash)
        except Exception:
            return False

    @staticmethod
    def sanitize_filename(filename: str) -> str:
        """Make ``filename`` safe to use as a single path component.

        - Characters invalid on Windows (``<>:"/\\|?*``) and ASCII control
          characters are replaced with ``_``.
        - The special names ``.`` and ``..`` (current/parent directory) have
          their dots replaced with ``_``.
        - Names longer than 255 characters are truncated, keeping the extension
          when it fits.
        """
        import re

        filename = re.sub(r'[<>:"/\\|?*\x00-\x1f]', '_', filename)
        if filename in ('.', '..'):
            filename = '_' * len(filename)
        if len(filename) > 255:
            name, ext = os.path.splitext(filename)
            if len(ext) < 255:
                filename = name[:255 - len(ext)] + ext
            else:
                # The "extension" alone is too long to keep; hard-truncate.
                filename = filename[:255]
        return filename
|
||
|
||
class FileUtils:
    """Filesystem helpers."""

    @staticmethod
    def ensure_directory(path: Union[str, Path]):
        """Create ``path`` and any missing parents; no-op if it already exists."""
        Path(path).mkdir(parents=True, exist_ok=True)

    @staticmethod
    def get_file_size(path: Union[str, Path]) -> int:
        """Return the size of ``path`` in bytes, or 0 if it cannot be stat'ed."""
        try:
            return Path(path).stat().st_size
        except Exception:
            return 0

    @staticmethod
    def clean_old_files(directory: Union[str, Path], max_age_days: int = 30):
        """Delete regular files directly inside ``directory`` older than ``max_age_days``.

        Age is judged by modification time. Subdirectories are not descended into.
        A failure on one file is logged and the remaining files are still processed.
        """
        try:
            directory = Path(directory)
            if not directory.exists():
                return
            cutoff_time = datetime.now() - timedelta(days=max_age_days)
            # Snapshot the listing before deleting anything from the directory.
            entries = list(directory.iterdir())
        except Exception as e:
            logger.error(f'清理旧文件失败: {e}')
            return

        for file_path in entries:
            try:
                if not file_path.is_file():
                    continue
                file_time = datetime.fromtimestamp(file_path.stat().st_mtime)
                if file_time < cutoff_time:
                    file_path.unlink()
                    logger.info(f'已删除旧文件: {file_path}')
            except Exception as e:
                logger.error(f'清理旧文件失败: {e}')

    @staticmethod
    def backup_file(source: Union[str, Path], backup_dir: Union[str, Path] = None) -> Optional[Path]:
        """Copy ``source`` (with metadata) to a timestamped file in ``backup_dir``.

        Args:
            source: File to back up.
            backup_dir: Destination directory. Defaults to ``<source dir>/backups``
                and is created if missing.

        Returns:
            The backup path, or None if ``source`` does not exist or copying fails.
        """
        try:
            source = Path(source)
            if not source.exists():
                return None

            backup_dir = source.parent / 'backups' if backup_dir is None else Path(backup_dir)
            backup_dir.mkdir(parents=True, exist_ok=True)

            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            backup_path = backup_dir / f"{source.stem}_{timestamp}{source.suffix}"
            # The timestamp has 1-second resolution; never overwrite a backup
            # taken earlier within the same second.
            counter = 1
            while backup_path.exists():
                backup_path = backup_dir / f"{source.stem}_{timestamp}_{counter}{source.suffix}"
                counter += 1

            import shutil
            shutil.copy2(source, backup_path)

            logger.info(f'文件已备份: {source} -> {backup_path}')
            return backup_path

        except Exception as e:
            logger.error(f'文件备份失败: {e}')
            return None
|
||
|
||
class PerformanceMonitor:
    """Named wall-clock timers.

    ``metrics`` maps a timer name to a dict holding ``start_time`` (seconds
    since the epoch, from ``time.time()``) and, once stopped, ``duration``
    (seconds).
    """

    def __init__(self):
        self.metrics: Dict[str, Dict[str, float]] = {}

    def start_timer(self, name: str):
        """Start (or restart) timer ``name``, discarding any previous result for it."""
        self.metrics[name] = {'start_time': time.time()}

    def end_timer(self, name: str) -> float:
        """Stop timer ``name`` and return the elapsed seconds; 0.0 if it was never started."""
        entry = self.metrics.get(name)
        if entry is None or 'start_time' not in entry:
            return 0.0
        duration = time.time() - entry['start_time']
        entry['duration'] = duration
        return duration

    def get_metrics(self) -> Dict[str, Any]:
        """Return a snapshot of all timers.

        Each per-timer dict is copied too, so callers that mutate the result
        cannot alter the monitor's internal state.
        """
        return {name: dict(entry) for name, entry in self.metrics.items()}

    def reset_metrics(self):
        """Discard all timers."""
        self.metrics.clear()
|
||
|
||
def timing_decorator(func):
    """Decorator that logs how long each call to ``func`` takes.

    Successful calls are logged at DEBUG. Calls that raise are logged at ERROR
    with the elapsed time, and the exception is re-raised unchanged.

    Durations use ``time.perf_counter()``, a monotonic high-resolution clock,
    so system clock adjustments cannot produce negative or skewed timings.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        try:
            result = func(*args, **kwargs)
        except Exception as e:
            duration = time.perf_counter() - start_time
            logger.error('%s 执行失败 (耗时: %.3f秒): %s', func.__name__, duration, e)
            raise
        duration = time.perf_counter() - start_time
        # Lazy %-args: the message is only formatted if DEBUG is enabled.
        logger.debug('%s 执行时间: %.3f秒', func.__name__, duration)
        return result
    return wrapper
|
||
|
||
def retry_decorator(max_retries: int = 3, delay: float = 1.0):
    """Decorator factory that retries a function when it raises.

    Args:
        max_retries: Number of retries after the first attempt (total calls =
            max_retries + 1). Negative values are treated as 0, so the function
            always runs at least once.
        delay: Seconds to sleep between attempts. Negative values are treated
            as no delay.

    Each failed attempt except the last is logged as a warning. After the final
    failure an error is logged and the last exception is re-raised.
    """
    attempts = max(0, max_retries) + 1
    wait = max(0.0, delay)

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, attempts + 1):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    if attempt == attempts:
                        logger.error(f'{func.__name__} 所有重试均失败')
                        raise
                    logger.warning(f'{func.__name__} 第{attempt}次尝试失败: {e},{delay}秒后重试')
                    time.sleep(wait)
        return wrapper
    return decorator
|
||
|
||
class ResponseFormatter:
    """Builders for the JSON-serializable response envelopes used by the API."""

    @staticmethod
    def success(data: Any = None, message: str = 'Success') -> Dict[str, Any]:
        """Build a success envelope. ``data`` is included only when it is not None."""
        response = {
            'success': True,
            'message': message,
            'timestamp': datetime.now().isoformat()
        }
        if data is not None:
            response['data'] = data
        return response

    @staticmethod
    def error(message: str, error_code: str = None, details: Any = None) -> Dict[str, Any]:
        """Build an error envelope. ``error_code`` and ``details`` are included only when truthy."""
        response = {
            'success': False,
            'message': message,
            'timestamp': datetime.now().isoformat()
        }
        if error_code:
            response['error_code'] = error_code
        if details:
            response['details'] = details
        return response

    @staticmethod
    def paginated(data: List[Any], page: int, page_size: int, total: int) -> Dict[str, Any]:
        """Build a paginated success envelope.

        ``total_pages`` is ``ceil(total / page_size)``. It is 0 when ``page_size``
        is not positive, which would otherwise divide by zero.
        """
        total_pages = (total + page_size - 1) // page_size if page_size > 0 else 0
        return {
            'success': True,
            'data': data,
            'pagination': {
                'page': page,
                'page_size': page_size,
                'total': total,
                'total_pages': total_pages
            },
            'timestamp': datetime.now().isoformat()
        }
|
||
|
||
# Global configuration instance. Import-time side effect: Config() reads
# 'config.ini' (relative to the current working directory), or creates it with
# defaults if it does not exist.
config = Config()

# Global performance monitor instance shared by importers of this module.
performance_monitor = PerformanceMonitor()