SmartEDT/backend/services/rbac_service.py

187 lines
7.5 KiB
Python
Raw Normal View History

"""RBAC角色/权限)服务。
该模块围绕 sys_role / sys_permission / sys_role_permission 三张表提供基本的增删改查与绑定关系维护
"""
from __future__ import annotations
import secrets
from typing import Any
from sqlalchemy import delete, insert, select, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from backend.database.schema import sys_permission, sys_role, sys_role_permission
from backend.utils import utc_now
class RbacService:
    """Role and permission-point management service (SQLAlchemy Core).

    Each method opens its own session from the injected factory, runs its
    statements against ``sys_role`` / ``sys_permission`` /
    ``sys_role_permission`` and, for writes, commits before returning.
    """

    def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
        """Store the async session factory used by every operation."""
        self._session_factory = session_factory

    async def list_roles(self) -> list[dict[str, Any]]:
        """Return all roles as dicts, ordered by ``role_name`` ascending."""
        columns = (
            sys_role.c.role_id,
            sys_role.c.role_name,
            sys_role.c.role_desc,
            sys_role.c.is_active,
            sys_role.c.created_at,
            sys_role.c.updated_at,
            sys_role.c.extra,
        )
        async with self._session_factory() as session:
            stmt = select(*columns).order_by(sys_role.c.role_name.asc())
            result = await session.execute(stmt)
            return [dict(record) for record in result.mappings().all()]

    async def get_role(self, role_id: str) -> dict[str, Any] | None:
        """Return the role with the given ``role_id`` as a dict, or ``None``."""
        async with self._session_factory() as session:
            stmt = select(sys_role).where(sys_role.c.role_id == role_id).limit(1)
            result = await session.execute(stmt)
            record = result.mappings().first()
            if not record:
                return None
            return dict(record)

    async def create_role(
        self,
        *,
        role_id: str | None,
        role_name: str,
        role_desc: str | None = None,
        is_active: bool = True,
        extra: dict | None = None,
    ) -> dict[str, Any]:
        """Create a role and return the stored row.

        When ``role_id`` is falsy (``None`` or empty) an identifier of the
        form ``role_<16 hex chars>`` is generated.

        Raises:
            IntegrityError: re-raised after rolling back the failed insert.
            RuntimeError: if the role cannot be read back after the insert.
        """
        new_id = role_id if role_id else "role_" + secrets.token_hex(8)
        row_values: dict[str, Any] = {
            "role_id": new_id,
            "role_name": role_name,
            "role_desc": role_desc,
            "is_active": is_active,
            "updated_at": utc_now(),
            "extra": extra,
        }
        async with self._session_factory() as session:
            try:
                stmt = insert(sys_role).values(**row_values)
                await session.execute(stmt)
                await session.commit()
            except IntegrityError:
                await session.rollback()
                raise
        # Read the row back in a fresh session so the caller gets the stored state.
        stored = await self.get_role(new_id)
        if stored:
            return stored
        raise RuntimeError("failed to create role")

    async def update_role(
        self,
        role_id: str,
        *,
        role_name: str | None = None,
        role_desc: str | None = None,
        is_active: bool | None = None,
        extra: dict | None = None,
    ) -> dict[str, Any] | None:
        """Update the supplied role fields and return the refreshed row.

        Arguments left as ``None`` are not touched; ``updated_at`` is always
        refreshed. Returns ``None`` when no row matched ``role_id``.

        Raises:
            IntegrityError: re-raised after rolling back the failed update.
        """
        changes: dict[str, Any] = {"updated_at": utc_now()}
        optional_fields = {
            "role_name": role_name,
            "role_desc": role_desc,
            "is_active": is_active,
            "extra": extra,
        }
        # Only explicitly provided (non-None) fields make it into the UPDATE.
        changes.update(
            {column: value for column, value in optional_fields.items() if value is not None}
        )
        async with self._session_factory() as session:
            try:
                stmt = update(sys_role).where(sys_role.c.role_id == role_id).values(**changes)
                outcome = await session.execute(stmt)
                await session.commit()
            except IntegrityError:
                await session.rollback()
                raise
            if outcome.rowcount == 0:
                return None
        return await self.get_role(role_id)

    async def disable_role(self, role_id: str) -> bool:
        """Disable a role (soft delete) by setting ``is_active`` to ``False``.

        Returns whether the UPDATE reported any affected row.
        """
        async with self._session_factory() as session:
            stmt = (
                update(sys_role)
                .where(sys_role.c.role_id == role_id)
                .values(is_active=False, updated_at=utc_now())
            )
            outcome = await session.execute(stmt)
            await session.commit()
            return bool(outcome.rowcount)

    async def list_permissions(self) -> list[dict[str, Any]]:
        """Return all permission points as dicts.

        Ordered by ``perm_group`` ascending (NULL groups last), then by
        ``perm_code`` ascending.
        """
        async with self._session_factory() as session:
            ordering = (
                sys_permission.c.perm_group.asc().nulls_last(),
                sys_permission.c.perm_code.asc(),
            )
            stmt = select(sys_permission).order_by(*ordering)
            result = await session.execute(stmt)
            return [dict(record) for record in result.mappings().all()]

    async def create_permission(
        self, *, perm_code: str, perm_name: str, perm_group: str | None = None, perm_desc: str | None = None
    ) -> dict[str, Any]:
        """Create a permission point and return the stored row.

        Raises:
            IntegrityError: re-raised after rolling back the failed insert.
            RuntimeError: if the permission cannot be read back afterwards.
        """
        async with self._session_factory() as session:
            try:
                creation = insert(sys_permission).values(
                    perm_code=perm_code,
                    perm_name=perm_name,
                    perm_group=perm_group,
                    perm_desc=perm_desc,
                )
                await session.execute(creation)
                await session.commit()
            except IntegrityError:
                await session.rollback()
                raise
            # Re-select within the same session to return the persisted row.
            lookup = select(sys_permission).where(sys_permission.c.perm_code == perm_code).limit(1)
            record = (await session.execute(lookup)).mappings().first()
            if record:
                return dict(record)
            raise RuntimeError("failed to create permission")

    async def delete_permission(self, perm_code: str) -> bool:
        """Delete a permission point together with its role bindings.

        Returns whether the delete on ``sys_permission`` reported any
        affected row.
        """
        async with self._session_factory() as session:
            # Remove the role bindings first, then the permission itself.
            unbind = delete(sys_role_permission).where(sys_role_permission.c.perm_code == perm_code)
            await session.execute(unbind)
            removal = delete(sys_permission).where(sys_permission.c.perm_code == perm_code)
            outcome = await session.execute(removal)
            await session.commit()
            return bool(outcome.rowcount)

    async def get_role_permissions(self, role_id: str) -> list[str]:
        """Return the permission codes bound to the given role."""
        async with self._session_factory() as session:
            stmt = select(sys_role_permission.c.perm_code).where(
                sys_role_permission.c.role_id == role_id
            )
            result = await session.execute(stmt)
            codes = result.scalars().all()
            return list(codes)

    async def set_role_permissions(self, *, role_id: str, perm_codes: list[str]) -> list[str]:
        """Replace the role's permission set (delete existing, then insert).

        Duplicate codes are dropped while keeping first-seen order. Returns
        the role's permission codes as read back after the commit.

        Raises:
            ValueError: if any requested code is absent from ``sys_permission``.
            IntegrityError: re-raised after rolling back the failed write.
        """
        deduped = list(dict.fromkeys(perm_codes))
        async with self._session_factory() as session:
            if deduped:
                # Validate up front so nothing is deleted for an invalid request.
                lookup = select(sys_permission.c.perm_code).where(
                    sys_permission.c.perm_code.in_(deduped)
                )
                known = set((await session.execute(lookup)).scalars().all())
                unknown = [code for code in deduped if code not in known]
                if unknown:
                    raise ValueError(f"missing permissions: {', '.join(unknown)}")
            try:
                clear = delete(sys_role_permission).where(sys_role_permission.c.role_id == role_id)
                await session.execute(clear)
                if deduped:
                    bindings = [{"role_id": role_id, "perm_code": code} for code in deduped]
                    await session.execute(insert(sys_role_permission), bindings)
                await session.commit()
            except IntegrityError:
                await session.rollback()
                raise
        return await self.get_role_permissions(role_id)