SmartEDT/backend/services/user_service.py

170 lines
6.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""System user service.

Provides create/read/update operations on the ``sys_user`` table, a soft
"disable" instead of hard deletion, password setting (only the password hash
is written), and maintenance of the last-login timestamp.
"""
from __future__ import annotations
import secrets
from typing import Any
from sqlalchemy import select, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from backend.auth.passwords import hash_password
from backend.database.schema import sys_role, sys_user
from backend.utils import utc_now
class UserService:
    """System user management service built on SQLAlchemy Core.

    Every public method opens its own short-lived session from the injected
    factory, so the service keeps no session state between calls. Methods that
    write commit before returning.
    """

    def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
        """Keep the factory used to open one ``AsyncSession`` per operation.

        Args:
            session_factory: Factory producing ``AsyncSession`` instances.
        """
        self._session_factory = session_factory

    @staticmethod
    def _user_with_role_query():
        """Build the user + role-name SELECT shared by ``list_users`` and ``get_user``.

        ``password_hash`` is not among the selected columns. The role table is
        joined with an INNER join, so a user whose ``role_id`` has no matching
        ``sys_role`` row is absent from the result.
        NOTE(review): presumably a foreign key on ``sys_user.role_id`` makes a
        dangling role impossible — confirm against the schema.
        """
        return select(
            sys_user.c.user_id,
            sys_user.c.username,
            sys_user.c.display_name,
            sys_user.c.role_id,
            sys_role.c.role_name,
            sys_user.c.is_active,
            sys_user.c.last_login_at,
            sys_user.c.created_at,
            sys_user.c.updated_at,
            sys_user.c.extra,
        ).select_from(sys_user.join(sys_role, sys_user.c.role_id == sys_role.c.role_id))

    async def list_users(self) -> list[dict[str, Any]]:
        """List all users together with their role name.

        Returns:
            One plain dict per user, ordered by ``created_at`` newest first.
            Disabled users (``is_active`` false) are included.
        """
        query = self._user_with_role_query().order_by(sys_user.c.created_at.desc())
        async with self._session_factory() as session:
            rows = (await session.execute(query)).mappings().all()
            return [dict(r) for r in rows]

    async def get_user(self, user_id: str) -> dict[str, Any] | None:
        """Fetch a single user (with role name) by ``user_id``.

        Args:
            user_id: Primary identifier of the user.

        Returns:
            The user as a plain dict, or ``None`` if no matching row is found
            (including the inner-join case where the user's role is missing).
        """
        query = self._user_with_role_query().where(sys_user.c.user_id == user_id).limit(1)
        async with self._session_factory() as session:
            row = (await session.execute(query)).mappings().first()
            return dict(row) if row else None

    async def get_user_by_username(self, username: str) -> dict[str, Any] | None:
        """Fetch the raw ``sys_user`` row for ``username`` (used by login).

        Unlike ``get_user`` this selects every ``sys_user`` column, including
        ``password_hash``, and does not join the role table. Do not hand the
        result to clients unfiltered.

        Args:
            username: Login name to look up.

        Returns:
            The full row as a plain dict, or ``None`` if the username is unknown.
        """
        query = select(sys_user).where(sys_user.c.username == username).limit(1)
        async with self._session_factory() as session:
            row = (await session.execute(query)).mappings().first()
            return dict(row) if row else None

    async def create_user(
        self,
        *,
        user_id: str | None,
        username: str,
        password: str,
        role_id: str,
        display_name: str | None = None,
        is_active: bool = True,
        extra: dict | None = None,
    ) -> dict[str, Any]:
        """Insert a new user, storing only the hash of ``password``.

        Args:
            user_id: Explicit id; when falsy a random ``"user_" + 16 hex`` id is generated.
            username: Login name.
            password: Plain-text password; only ``hash_password(password)`` is stored.
            role_id: Role to assign.
            display_name: Optional human-readable name.
            is_active: Whether the account starts enabled.
            extra: Optional extra data stored in the ``extra`` column.

        Returns:
            The created user as returned by ``get_user``.

        Raises:
            IntegrityError: Re-raised after rollback when the insert violates a
                constraint (presumably a duplicate username/user_id or unknown
                role, depending on the schema — confirm).
            RuntimeError: If the freshly inserted user cannot be read back.
                Note the row has already been committed at that point.
        """
        uid = user_id or ("user_" + secrets.token_hex(8))
        # Hash before opening the session so no connection is held meanwhile.
        # NOTE(review): if hash_password is a slow KDF it blocks the event loop;
        # consider asyncio.to_thread — confirm its cost.
        password_hash = hash_password(password)
        async with self._session_factory() as session:
            try:
                await session.execute(
                    sys_user.insert().values(
                        user_id=uid,
                        username=username,
                        display_name=display_name,
                        password_hash=password_hash,
                        role_id=role_id,
                        is_active=is_active,
                        updated_at=utc_now(),
                        extra=extra,
                    )
                )
                await session.commit()
            except IntegrityError:
                await session.rollback()
                raise
        created = await self.get_user(uid)
        if not created:
            raise RuntimeError("failed to create user")
        return created

    async def update_user(
        self,
        user_id: str,
        *,
        display_name: str | None = None,
        role_id: str | None = None,
        is_active: bool | None = None,
        extra: dict | None = None,
    ) -> dict[str, Any] | None:
        """Partially update a user; only arguments that are not ``None`` are written.

        Because ``None`` means "leave unchanged", this method cannot set a
        column to NULL. ``updated_at`` is always refreshed, even when no other
        field is supplied.

        Args:
            user_id: Id of the user to update.
            display_name: New display name, or ``None`` to keep the current one.
            role_id: New role id, or ``None`` to keep the current one.
            is_active: New active flag, or ``None`` to keep the current one.
            extra: New extra data, or ``None`` to keep the current value.

        Returns:
            The updated user as returned by ``get_user``, or ``None`` when no
            row matched ``user_id``.

        Raises:
            IntegrityError: Re-raised after rollback on a constraint violation.
        """
        candidates = {
            "display_name": display_name,
            "role_id": role_id,
            "is_active": is_active,
            "extra": extra,
        }
        patch: dict[str, Any] = {"updated_at": utc_now()}
        patch.update({col: val for col, val in candidates.items() if val is not None})
        async with self._session_factory() as session:
            try:
                res = await session.execute(
                    update(sys_user).where(sys_user.c.user_id == user_id).values(**patch)
                )
                await session.commit()
            except IntegrityError:
                await session.rollback()
                raise
        if res.rowcount == 0:
            return None
        return await self.get_user(user_id)

    async def disable_user(self, user_id: str) -> bool:
        """Soft-delete a user by setting ``is_active`` to False.

        Returns:
            True if a row matched ``user_id``, False otherwise.
        """
        async with self._session_factory() as session:
            res = await session.execute(
                update(sys_user)
                .where(sys_user.c.user_id == user_id)
                .values(is_active=False, updated_at=utc_now())
            )
            await session.commit()
            return bool(res.rowcount)

    async def set_password(self, user_id: str, new_password: str) -> bool:
        """Replace a user's password; only its hash is stored, never the plain text.

        Returns:
            True if a row matched ``user_id``, False otherwise.
        """
        # NOTE(review): same event-loop blocking caveat as in create_user.
        password_hash = hash_password(new_password)
        async with self._session_factory() as session:
            res = await session.execute(
                update(sys_user)
                .where(sys_user.c.user_id == user_id)
                .values(password_hash=password_hash, updated_at=utc_now())
            )
            await session.commit()
            return bool(res.rowcount)

    async def touch_last_login(self, user_id: str) -> None:
        """Record "now" as the user's most recent login time.

        ``last_login_at`` and ``updated_at`` receive the same timestamp. An
        unknown ``user_id`` is silently ignored (no row is updated).
        """
        # One timestamp for both columns so they agree exactly for this event.
        now = utc_now()
        async with self._session_factory() as session:
            await session.execute(
                update(sys_user)
                .where(sys_user.c.user_id == user_id)
                .values(last_login_at=now, updated_at=now)
            )
            await session.commit()