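"""System user service.

Provides create, query, update and disable (soft-delete) operations on users
stored in the sys_user table, password setting (stored as a hash, never as
plain text), and maintenance of the last-login timestamp.
"""
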
from __future__ import annotations

import secrets
from typing import Any

from sqlalchemy import select, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from backend.auth.passwords import hash_password
from backend.database.schema import sys_role, sys_user
from backend.utils import utc_now


class UserService:
    """System user management service built on SQLAlchemy Core.

    Every public method opens its own short-lived session from the injected
    session factory; methods that write commit before returning.
    """

    def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
        """Store the async session factory used by every method."""
        self._session_factory = session_factory

    @staticmethod
    def _user_with_role_select():
        """Build the base SELECT of user columns joined with the role name.

        Shared by ``list_users`` and ``get_user`` so both return exactly the
        same column set. ``password_hash`` is not included.
        """
        # Inner join: a user whose role_id has no matching sys_role row is
        # NOT returned by list_users / get_user.
        # NOTE(review): confirm a foreign key (or other invariant) guarantees
        # every user references an existing role.
        return select(
            sys_user.c.user_id,
            sys_user.c.username,
            sys_user.c.display_name,
            sys_user.c.role_id,
            sys_role.c.role_name,
            sys_user.c.is_active,
            sys_user.c.last_login_at,
            sys_user.c.created_at,
            sys_user.c.updated_at,
            sys_user.c.extra,
        ).select_from(sys_user.join(sys_role, sys_user.c.role_id == sys_role.c.role_id))

    async def list_users(self) -> list[dict[str, Any]]:
        """Return all users (with role name), newest ``created_at`` first.

        Returns:
            One dict per user, keyed by column name (see ``_user_with_role_select``).
        """
        q = self._user_with_role_select().order_by(sys_user.c.created_at.desc())
        async with self._session_factory() as session:
            rows = (await session.execute(q)).mappings().all()
            return [dict(r) for r in rows]

    async def get_user(self, user_id: str) -> dict[str, Any] | None:
        """Return one user (with role name) by ``user_id``, or None if not found."""
        q = self._user_with_role_select().where(sys_user.c.user_id == user_id).limit(1)
        async with self._session_factory() as session:
            row = (await session.execute(q)).mappings().first()
            return dict(row) if row else None

    async def get_user_by_username(self, username: str) -> dict[str, Any] | None:
        """Return the full sys_user row for ``username`` (used by login), or None.

        Unlike ``get_user``, this selects every sys_user column, including
        ``password_hash``, and does not join the role table.
        """
        q = select(sys_user).where(sys_user.c.username == username).limit(1)
        async with self._session_factory() as session:
            row = (await session.execute(q)).mappings().first()
            return dict(row) if row else None

    async def create_user(
        self,
        *,
        user_id: str | None,
        username: str,
        password: str,
        role_id: str,
        display_name: str | None = None,
        is_active: bool = True,
        extra: dict | None = None,
    ) -> dict[str, Any]:
        """Insert a new user, storing only the hash of ``password``.

        Args:
            user_id: Explicit id; when None (or empty) an id of the form
                ``"user_" + 16 hex chars`` is generated.
            username: Login name.
            password: Plain-text password; hashed with ``hash_password``.
            role_id: Role to assign.
            display_name: Optional display name.
            is_active: Whether the account is enabled.
            extra: Optional extra data stored in the ``extra`` column.

        Returns:
            The created user as returned by ``get_user``.

        Raises:
            IntegrityError: If the insert violates a database constraint.
            RuntimeError: If the committed row cannot be read back via
                ``get_user`` (e.g. ``role_id`` matches no sys_role row, because
                of the inner join).
        """
        uid = user_id or ("user_" + secrets.token_hex(8))
        password_hash = hash_password(password)
        async with self._session_factory() as session:
            try:
                await session.execute(
                    sys_user.insert().values(
                        user_id=uid,
                        username=username,
                        display_name=display_name,
                        password_hash=password_hash,
                        role_id=role_id,
                        is_active=is_active,
                        updated_at=utc_now(),
                        extra=extra,
                    )
                )
                await session.commit()
            except IntegrityError:
                await session.rollback()
                raise
        created = await self.get_user(uid)
        if not created:
            raise RuntimeError("failed to create user")
        return created

    async def update_user(
        self,
        user_id: str,
        *,
        display_name: str | None = None,
        role_id: str | None = None,
        is_active: bool | None = None,
        extra: dict | None = None,
    ) -> dict[str, Any] | None:
        """Update only the fields that were passed (non-None).

        ``updated_at`` is always refreshed. Because None means "leave
        unchanged", a column cannot be cleared to NULL through this method.

        Returns:
            The updated user (via ``get_user``), or None if no row matched.

        Raises:
            IntegrityError: If the update violates a database constraint.
        """
        patch: dict[str, Any] = {"updated_at": utc_now()}
        if display_name is not None:
            patch["display_name"] = display_name
        if role_id is not None:
            patch["role_id"] = role_id
        if is_active is not None:
            patch["is_active"] = is_active
        if extra is not None:
            patch["extra"] = extra
        async with self._session_factory() as session:
            try:
                res = await session.execute(
                    update(sys_user).where(sys_user.c.user_id == user_id).values(**patch)
                )
                await session.commit()
            except IntegrityError:
                await session.rollback()
                raise
        if res.rowcount == 0:
            return None
        return await self.get_user(user_id)

    async def disable_user(self, user_id: str) -> bool:
        """Disable a user (soft delete); return True if a row was updated."""
        async with self._session_factory() as session:
            res = await session.execute(
                update(sys_user)
                .where(sys_user.c.user_id == user_id)
                .values(is_active=False, updated_at=utc_now())
            )
            await session.commit()
            return bool(res.rowcount)

    async def set_password(self, user_id: str, new_password: str) -> bool:
        """Replace the user's password hash; return True if a row was updated.

        Only the hash produced by ``hash_password`` is stored, never the plain text.
        """
        password_hash = hash_password(new_password)
        async with self._session_factory() as session:
            res = await session.execute(
                update(sys_user)
                .where(sys_user.c.user_id == user_id)
                .values(password_hash=password_hash, updated_at=utc_now())
            )
            await session.commit()
            return bool(res.rowcount)

    async def touch_last_login(self, user_id: str) -> None:
        """Set the user's last-login time (and ``updated_at``) to now."""
        # One timestamp for both columns so they are exactly equal.
        now = utc_now()
        async with self._session_factory() as session:
            await session.execute(
                update(sys_user)
                .where(sys_user.c.user_id == user_id)
                .values(last_login_at=now, updated_at=now)
            )
            await session.commit()