170 lines
6.5 KiB
Python
170 lines
6.5 KiB
Python
|
|
"""系统用户服务。
|
|||
|
|
|
|||
|
|
围绕 sys_user 表提供用户的增删改查、密码设置(写入哈希)、登录时间维护等能力。
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import secrets
|
|||
|
|
from typing import Any
|
|||
|
|
|
|||
|
|
from sqlalchemy import select, update
|
|||
|
|
from sqlalchemy.exc import IntegrityError
|
|||
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|||
|
|
|
|||
|
|
from backend.auth.passwords import hash_password
|
|||
|
|
from backend.database.schema import sys_role, sys_user
|
|||
|
|
from backend.utils import utc_now
|
|||
|
|
|
|||
|
|
|
|||
|
|
class UserService:
    """System user management service built on SQLAlchemy Core.

    Each public method opens its own session from the injected factory, so
    every call is an independent unit of work.
    """

    def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
        """Keep the factory that is called to open a session per operation."""
        self._session_factory = session_factory

    @staticmethod
    def _select_user_with_role() -> Any:
        """Build the SELECT shared by ``list_users`` and ``get_user``.

        The projection lists user columns plus ``sys_role.role_name``;
        ``password_hash`` is not part of it. The join is an inner join, so a
        user whose ``role_id`` matches no ``sys_role`` row is not returned.
        """
        user_with_role = sys_user.join(
            sys_role, sys_user.c.role_id == sys_role.c.role_id
        )
        return select(
            sys_user.c.user_id,
            sys_user.c.username,
            sys_user.c.display_name,
            sys_user.c.role_id,
            sys_role.c.role_name,
            sys_user.c.is_active,
            sys_user.c.last_login_at,
            sys_user.c.created_at,
            sys_user.c.updated_at,
            sys_user.c.extra,
        ).select_from(user_with_role)

    async def list_users(self) -> list[dict[str, Any]]:
        """Return all users together with their role name.

        Returns:
            One plain dict per row, ordered by ``created_at`` descending.
        """
        stmt = self._select_user_with_role().order_by(sys_user.c.created_at.desc())
        async with self._session_factory() as session:
            result = await session.execute(stmt)
            return [dict(row) for row in result.mappings().all()]

    async def get_user(self, user_id: str) -> dict[str, Any] | None:
        """Look up one user (with role name) by ``user_id``.

        Returns:
            The row as a dict, or ``None`` when no row matches.
        """
        stmt = (
            self._select_user_with_role()
            .where(sys_user.c.user_id == user_id)
            .limit(1)
        )
        async with self._session_factory() as session:
            row = (await session.execute(stmt)).mappings().first()
            return dict(row) if row else None

    async def get_user_by_username(self, username: str) -> dict[str, Any] | None:
        """Look up one user by ``username`` (used for login).

        Unlike ``get_user`` this selects every ``sys_user`` column, which
        includes ``password_hash``, and does not join the role table.

        Returns:
            The row as a dict, or ``None`` when no row matches.
        """
        stmt = select(sys_user).where(sys_user.c.username == username).limit(1)
        async with self._session_factory() as session:
            row = (await session.execute(stmt)).mappings().first()
            return dict(row) if row else None

    async def create_user(
        self,
        *,
        user_id: str | None,
        username: str,
        password: str,
        role_id: str,
        display_name: str | None = None,
        is_active: bool = True,
        extra: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Insert a user row, storing the hashed password.

        Args:
            user_id: Explicit id; when falsy an id of the form
                ``"user_" + 16 hex chars`` is generated.
            username: Value for the ``username`` column.
            password: Plain-text password; only ``hash_password(password)``
                is written.
            role_id: Value for the ``role_id`` column.
            display_name: Optional display name.
            is_active: Initial active flag.
            extra: Optional value for the ``extra`` column.

        Returns:
            The created user as returned by ``get_user``.

        Raises:
            IntegrityError: Re-raised after rollback when the insert violates
                a database constraint.
            RuntimeError: If the row cannot be read back after the commit.
        """
        uid = user_id or ("user_" + secrets.token_hex(8))
        insert_stmt = sys_user.insert().values(
            user_id=uid,
            username=username,
            display_name=display_name,
            password_hash=hash_password(password),
            role_id=role_id,
            is_active=is_active,
            updated_at=utc_now(),
            extra=extra,
        )
        async with self._session_factory() as session:
            try:
                await session.execute(insert_stmt)
                await session.commit()
            except IntegrityError:
                await session.rollback()
                raise

        # Read back through get_user so the result has the same shape
        # (including role_name) as every other read path.
        created = await self.get_user(uid)
        if not created:
            raise RuntimeError("failed to create user")
        return created

    async def update_user(
        self,
        user_id: str,
        *,
        display_name: str | None = None,
        role_id: str | None = None,
        is_active: bool | None = None,
        extra: dict[str, Any] | None = None,
    ) -> dict[str, Any] | None:
        """Update only the fields that were passed.

        An argument left as ``None`` means "leave unchanged", so this method
        cannot set a column to NULL. ``updated_at`` is always refreshed.

        Returns:
            The updated user as returned by ``get_user``, or ``None`` when
            the UPDATE reported zero affected rows.

        Raises:
            IntegrityError: Re-raised after rollback when the update violates
                a database constraint.
        """
        candidates = {
            "display_name": display_name,
            "role_id": role_id,
            "is_active": is_active,
            "extra": extra,
        }
        patch: dict[str, Any] = {"updated_at": utc_now()}
        patch.update(
            {column: value for column, value in candidates.items() if value is not None}
        )
        stmt = update(sys_user).where(sys_user.c.user_id == user_id).values(**patch)

        async with self._session_factory() as session:
            try:
                result = await session.execute(stmt)
                await session.commit()
            except IntegrityError:
                await session.rollback()
                raise
            # Read the count while the session is still open.
            affected = result.rowcount

        if affected == 0:
            return None
        return await self.get_user(user_id)

    async def disable_user(self, user_id: str) -> bool:
        """Disable a user (soft delete) by setting ``is_active`` to False.

        Returns:
            ``bool(rowcount)`` of the UPDATE, i.e. False when no row matched.
        """
        stmt = (
            update(sys_user)
            .where(sys_user.c.user_id == user_id)
            .values(is_active=False, updated_at=utc_now())
        )
        async with self._session_factory() as session:
            result = await session.execute(stmt)
            await session.commit()
            return bool(result.rowcount)

    async def set_password(self, user_id: str, new_password: str) -> bool:
        """Replace the user's password; only the hash is stored.

        Returns:
            ``bool(rowcount)`` of the UPDATE, i.e. False when no row matched.
        """
        stmt = (
            update(sys_user)
            .where(sys_user.c.user_id == user_id)
            .values(password_hash=hash_password(new_password), updated_at=utc_now())
        )
        async with self._session_factory() as session:
            result = await session.execute(stmt)
            await session.commit()
            return bool(result.rowcount)

    async def touch_last_login(self, user_id: str) -> None:
        """Record the current time as the user's most recent login."""
        # One timestamp for both columns so they are exactly equal.
        now = utc_now()
        stmt = (
            update(sys_user)
            .where(sys_user.c.user_id == user_id)
            .values(last_login_at=now, updated_at=now)
        )
        async with self._session_factory() as session:
            await session.execute(stmt)
            await session.commit()
|