"""RBAC(角色/权限)服务。 该模块围绕 sys_role / sys_permission / sys_role_permission 三张表提供基本的增删改查与绑定关系维护。 """ from __future__ import annotations import secrets from typing import Any from sqlalchemy import delete, insert, select, update from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from backend.database.schema import sys_permission, sys_role, sys_role_permission from backend.utils import utc_now class RbacService: """角色与权限点管理服务(SQLAlchemy Core)。""" def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None: self._session_factory = session_factory async def list_roles(self) -> list[dict[str, Any]]: """查询角色列表。""" async with self._session_factory() as session: q = ( select( sys_role.c.role_id, sys_role.c.role_name, sys_role.c.role_desc, sys_role.c.is_active, sys_role.c.created_at, sys_role.c.updated_at, sys_role.c.extra, ) .order_by(sys_role.c.role_name.asc()) ) return [dict(r) for r in (await session.execute(q)).mappings().all()] async def get_role(self, role_id: str) -> dict[str, Any] | None: """按 role_id 查询角色。""" async with self._session_factory() as session: q = select(sys_role).where(sys_role.c.role_id == role_id).limit(1) row = (await session.execute(q)).mappings().first() return dict(row) if row else None async def create_role( self, *, role_id: str | None, role_name: str, role_desc: str | None = None, is_active: bool = True, extra: dict | None = None, ) -> dict[str, Any]: """创建角色。 说明:role_id 不传时会自动生成。 """ rid = role_id or ("role_" + secrets.token_hex(8)) values: dict[str, Any] = { "role_id": rid, "role_name": role_name, "role_desc": role_desc, "is_active": is_active, "updated_at": utc_now(), "extra": extra, } async with self._session_factory() as session: try: await session.execute(insert(sys_role).values(**values)) await session.commit() except IntegrityError: await session.rollback() raise created = await self.get_role(rid) if not created: raise RuntimeError("failed to create role") return created async def update_role( self, role_id: str, *, role_name: str | None = None, role_desc: str | None = None, is_active: bool | None = None, extra: dict | None = None, ) -> dict[str, Any] | None: """更新角色字段(仅更新传入的字段)。""" patch: dict[str, Any] = {"updated_at": utc_now()} if role_name is not None: patch["role_name"] = role_name if role_desc is not None: patch["role_desc"] = role_desc if is_active is not None: patch["is_active"] = is_active if extra is not None: patch["extra"] = extra async with self._session_factory() as session: try: res = await session.execute(update(sys_role).where(sys_role.c.role_id == role_id).values(**patch)) await session.commit() except IntegrityError: await session.rollback() raise if res.rowcount == 0: return None return await self.get_role(role_id) async def disable_role(self, role_id: str) -> bool: """禁用角色(软删除)。""" async with self._session_factory() as session: res = await session.execute( update(sys_role).where(sys_role.c.role_id == role_id).values(is_active=False, updated_at=utc_now()) ) await session.commit() return bool(res.rowcount) async def list_permissions(self) -> list[dict[str, Any]]: """查询权限点列表。""" async with self._session_factory() as session: q = select(sys_permission).order_by(sys_permission.c.perm_group.asc().nulls_last(), sys_permission.c.perm_code.asc()) return [dict(r) for r in (await session.execute(q)).mappings().all()] async def create_permission( self, *, perm_code: str, perm_name: str, perm_group: str | None = None, perm_desc: str | None = None ) -> dict[str, Any]: """创建权限点。""" async with self._session_factory() as session: try: await session.execute( insert(sys_permission).values( perm_code=perm_code, perm_name=perm_name, perm_group=perm_group, perm_desc=perm_desc ) ) await session.commit() except IntegrityError: await session.rollback() raise q = select(sys_permission).where(sys_permission.c.perm_code == perm_code).limit(1) row = (await session.execute(q)).mappings().first() if not row: raise RuntimeError("failed to create permission") return dict(row) async def delete_permission(self, perm_code: str) -> bool: """删除权限点,并清理与角色的关联。""" async with self._session_factory() as session: await session.execute(delete(sys_role_permission).where(sys_role_permission.c.perm_code == perm_code)) res = await session.execute(delete(sys_permission).where(sys_permission.c.perm_code == perm_code)) await session.commit() return bool(res.rowcount) async def get_role_permissions(self, role_id: str) -> list[str]: """查询指定角色拥有的权限点编码列表。""" async with self._session_factory() as session: q = select(sys_role_permission.c.perm_code).where(sys_role_permission.c.role_id == role_id) rows = (await session.execute(q)).scalars().all() return list(rows) async def set_role_permissions(self, *, role_id: str, perm_codes: list[str]) -> list[str]: """覆盖设置角色权限点集合(先删后插)。""" unique = list(dict.fromkeys(perm_codes)) async with self._session_factory() as session: if unique: q = select(sys_permission.c.perm_code).where(sys_permission.c.perm_code.in_(unique)) existing = set((await session.execute(q)).scalars().all()) missing = [c for c in unique if c not in existing] if missing: raise ValueError(f"missing permissions: {', '.join(missing)}") try: await session.execute(delete(sys_role_permission).where(sys_role_permission.c.role_id == role_id)) if unique: await session.execute( insert(sys_role_permission), [{"role_id": role_id, "perm_code": code} for code in unique], ) await session.commit() except IntegrityError: await session.rollback() raise return await self.get_role_permissions(role_id)