SmartEDT/backend/services/user_service.py

170 lines
6.5 KiB
Python
Raw Normal View History

"""系统用户服务。
围绕 sys_user 表提供用户的增删改查、密码设置（写入哈希）、登录时间维护等能力。
"""
from __future__ import annotations
import secrets
from typing import Any
from sqlalchemy import select, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from backend.auth.passwords import hash_password
from backend.database.schema import sys_role, sys_user
from backend.utils import utc_now
class UserService:
    """System user management service built on SQLAlchemy Core.

    Each public method opens its own session from the injected session
    factory via ``async with`` and finishes its work inside that session.
    """

    def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
        """Store the factory used to open a session per call.

        Args:
            session_factory: Factory producing ``AsyncSession`` objects.
        """
        self._session_factory = session_factory

    @staticmethod
    def _user_with_role_select():
        """Build the shared SELECT of user columns plus the role name.

        Used by ``list_users`` and ``get_user`` so both return the same shape.
        ``password_hash`` is not among the selected columns. Because this is an
        inner join on ``role_id``, a user whose ``role_id`` has no matching
        ``sys_role`` row is not returned.
        """
        return select(
            sys_user.c.user_id,
            sys_user.c.username,
            sys_user.c.display_name,
            sys_user.c.role_id,
            sys_role.c.role_name,
            sys_user.c.is_active,
            sys_user.c.last_login_at,
            sys_user.c.created_at,
            sys_user.c.updated_at,
            sys_user.c.extra,
        ).select_from(sys_user.join(sys_role, sys_user.c.role_id == sys_role.c.role_id))

    async def list_users(self) -> list[dict[str, Any]]:
        """Return all users with their role name, newest ``created_at`` first.

        Returns:
            One plain dict per user, keyed by the columns selected in
            ``_user_with_role_select``.
        """
        q = self._user_with_role_select().order_by(sys_user.c.created_at.desc())
        async with self._session_factory() as session:
            rows = (await session.execute(q)).mappings().all()
            return [dict(r) for r in rows]

    async def get_user(self, user_id: str) -> dict[str, Any] | None:
        """Look up one user (with role name) by ``user_id``.

        Args:
            user_id: Primary identifier of the user.

        Returns:
            The user as a plain dict, or ``None`` if no row matched.
        """
        q = (
            self._user_with_role_select()
            .where(sys_user.c.user_id == user_id)
            .limit(1)
        )
        async with self._session_factory() as session:
            row = (await session.execute(q)).mappings().first()
            return dict(row) if row else None

    async def get_user_by_username(self, username: str) -> dict[str, Any] | None:
        """Look up a user by ``username`` (used for login).

        Unlike ``get_user``, this selects every ``sys_user`` column, including
        ``password_hash``, and does not join ``sys_role``.

        Args:
            username: Login name to match exactly.

        Returns:
            The raw ``sys_user`` row as a dict, or ``None`` if not found.
        """
        q = select(sys_user).where(sys_user.c.username == username).limit(1)
        async with self._session_factory() as session:
            row = (await session.execute(q)).mappings().first()
            return dict(row) if row else None

    async def create_user(
        self,
        *,
        user_id: str | None,
        username: str,
        password: str,
        role_id: str,
        display_name: str | None = None,
        is_active: bool = True,
        extra: dict | None = None,
    ) -> dict[str, Any]:
        """Create a user, storing only the hash of ``password``.

        Args:
            user_id: Explicit id. When falsy, an id of the form
                ``"user_" + 16 hex chars`` is generated with ``secrets``.
            username: Login name.
            password: Plain-text password; it is passed through
                ``hash_password`` and only the hash is written.
            role_id: Role to assign to the user.
            display_name: Optional display name.
            is_active: Whether the account starts enabled.
            extra: Optional value stored in the ``extra`` column.

        Returns:
            The created user, as returned by ``get_user``.

        Raises:
            IntegrityError: If the insert violates a database constraint
                (presumably a duplicate id/username or unknown role_id;
                this depends on the schema).
            RuntimeError: If the committed row cannot be read back through
                ``get_user``. For example, this happens if ``role_id`` has no
                ``sys_role`` row, because of the inner join.
        """
        uid = user_id or ("user_" + secrets.token_hex(8))
        password_hash = hash_password(password)
        async with self._session_factory() as session:
            try:
                await session.execute(
                    sys_user.insert().values(
                        user_id=uid,
                        username=username,
                        display_name=display_name,
                        password_hash=password_hash,
                        role_id=role_id,
                        is_active=is_active,
                        updated_at=utc_now(),
                        extra=extra,
                    )
                )
                await session.commit()
            except IntegrityError:
                # Roll back explicitly, then let the caller decide how to report
                # the conflict.
                await session.rollback()
                raise
        created = await self.get_user(uid)
        if not created:
            raise RuntimeError("failed to create user")
        return created

    async def update_user(
        self,
        user_id: str,
        *,
        display_name: str | None = None,
        role_id: str | None = None,
        is_active: bool | None = None,
        extra: dict | None = None,
    ) -> dict[str, Any] | None:
        """Partially update a user; only arguments that are not ``None`` are written.

        ``updated_at`` is always refreshed. Because ``None`` means "leave
        unchanged", this method cannot set a column to NULL.

        Args:
            user_id: Identifier of the user to update.
            display_name: New display name, or ``None`` to keep the current one.
            role_id: New role id, or ``None`` to keep the current one.
            is_active: New active flag, or ``None`` to keep the current one.
            extra: New ``extra`` value, or ``None`` to keep the current one.

        Returns:
            The updated user (via ``get_user``), or ``None`` if the UPDATE
            matched no row.

        Raises:
            IntegrityError: Re-raised after rollback if the update violates a
                database constraint.
        """
        patch: dict[str, Any] = {"updated_at": utc_now()}
        if display_name is not None:
            patch["display_name"] = display_name
        if role_id is not None:
            patch["role_id"] = role_id
        if is_active is not None:
            patch["is_active"] = is_active
        if extra is not None:
            patch["extra"] = extra
        async with self._session_factory() as session:
            try:
                res = await session.execute(
                    update(sys_user).where(sys_user.c.user_id == user_id).values(**patch)
                )
                await session.commit()
            except IntegrityError:
                await session.rollback()
                raise
        if res.rowcount == 0:
            return None
        return await self.get_user(user_id)

    async def disable_user(self, user_id: str) -> bool:
        """Disable a user by setting ``is_active`` to False (soft delete).

        Args:
            user_id: Identifier of the user to disable.

        Returns:
            True if the UPDATE reported a non-zero row count, otherwise False.
        """
        async with self._session_factory() as session:
            res = await session.execute(
                update(sys_user)
                .where(sys_user.c.user_id == user_id)
                .values(is_active=False, updated_at=utc_now())
            )
            await session.commit()
            return bool(res.rowcount)

    async def set_password(self, user_id: str, new_password: str) -> bool:
        """Replace a user's password, storing only its hash (never plain text).

        Args:
            user_id: Identifier of the user.
            new_password: Plain-text password passed through ``hash_password``.

        Returns:
            True if the UPDATE reported a non-zero row count, otherwise False.
        """
        password_hash = hash_password(new_password)
        async with self._session_factory() as session:
            res = await session.execute(
                update(sys_user)
                .where(sys_user.c.user_id == user_id)
                .values(password_hash=password_hash, updated_at=utc_now())
            )
            await session.commit()
            return bool(res.rowcount)

    async def touch_last_login(self, user_id: str) -> None:
        """Record a login by setting ``last_login_at`` (and ``updated_at``) to now.

        Args:
            user_id: Identifier of the user who logged in. A non-existent id
                silently updates nothing.
        """
        # Take one timestamp so both columns record the same instant.
        now = utc_now()
        async with self._session_factory() as session:
            await session.execute(
                update(sys_user)
                .where(sys_user.c.user_id == user_id)
                .values(last_login_at=now, updated_at=now)
            )
            await session.commit()