SmartEDT/backend/services/rbac_service.py

187 lines
7.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""RBAC角色/权限)服务。
该模块围绕 sys_role / sys_permission / sys_role_permission 三张表提供基本的增删改查与绑定关系维护。
"""
from __future__ import annotations
import secrets
from typing import Any
from sqlalchemy import delete, insert, select, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from backend.database.schema import sys_permission, sys_role, sys_role_permission
from backend.utils import utc_now
class RbacService:
    """Role and permission management service built on SQLAlchemy Core.

    Operates on the ``sys_role``, ``sys_permission`` and ``sys_role_permission``
    tables. Every public method opens its own session from the injected factory,
    so each call is an independent unit of work.
    """

    def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
        """Store the async session factory used to open one session per call."""
        self._session_factory = session_factory

    async def list_roles(self) -> list[dict[str, Any]]:
        """Return all roles (a fixed subset of columns), ordered by ``role_name`` ascending."""
        async with self._session_factory() as session:
            q = (
                select(
                    sys_role.c.role_id,
                    sys_role.c.role_name,
                    sys_role.c.role_desc,
                    sys_role.c.is_active,
                    sys_role.c.created_at,
                    sys_role.c.updated_at,
                    sys_role.c.extra,
                )
                .order_by(sys_role.c.role_name.asc())
            )
            return [dict(r) for r in (await session.execute(q)).mappings().all()]

    async def get_role(self, role_id: str) -> dict[str, Any] | None:
        """Return the full ``sys_role`` row for ``role_id`` as a dict, or None if absent."""
        async with self._session_factory() as session:
            q = select(sys_role).where(sys_role.c.role_id == role_id).limit(1)
            row = (await session.execute(q)).mappings().first()
            return dict(row) if row else None

    async def create_role(
        self,
        *,
        role_id: str | None,
        role_name: str,
        role_desc: str | None = None,
        is_active: bool = True,
        extra: dict | None = None,
    ) -> dict[str, Any]:
        """Insert a new role and return the stored row.

        Args:
            role_id: Explicit id. If falsy (None or ""), an id of the form
                ``"role_" + 16 hex chars`` is generated with ``secrets``.
            role_name: Role name.
            role_desc: Optional description.
            is_active: Initial active flag.
            extra: Optional extra payload stored in the ``extra`` column.

        Returns:
            The freshly inserted row as a dict (read back via ``get_role``).

        Raises:
            IntegrityError: If the insert violates a database constraint
                (presumably a duplicate id/name; depends on the schema).
            RuntimeError: If the row cannot be read back after commit.
        """
        rid = role_id or ("role_" + secrets.token_hex(8))
        values: dict[str, Any] = {
            "role_id": rid,
            "role_name": role_name,
            "role_desc": role_desc,
            "is_active": is_active,
            "updated_at": utc_now(),
            "extra": extra,
        }
        async with self._session_factory() as session:
            try:
                await session.execute(insert(sys_role).values(**values))
                await session.commit()
            except IntegrityError:
                await session.rollback()
                raise
        # Read back in a fresh session so server-side defaults are included.
        created = await self.get_role(rid)
        if not created:
            raise RuntimeError("failed to create role")
        return created

    async def update_role(
        self,
        role_id: str,
        *,
        role_name: str | None = None,
        role_desc: str | None = None,
        is_active: bool | None = None,
        extra: dict | None = None,
    ) -> dict[str, Any] | None:
        """Partially update a role and return the updated row.

        Only arguments that are not None are written; None means "leave
        unchanged", so ``role_desc``/``extra`` cannot be cleared to NULL through
        this method. ``updated_at`` is always refreshed.

        Returns:
            The updated role as a dict, or None if no row matched ``role_id``.

        Raises:
            IntegrityError: If the update violates a database constraint.
        """
        patch: dict[str, Any] = {"updated_at": utc_now()}
        if role_name is not None:
            patch["role_name"] = role_name
        if role_desc is not None:
            patch["role_desc"] = role_desc
        if is_active is not None:
            patch["is_active"] = is_active
        if extra is not None:
            patch["extra"] = extra
        async with self._session_factory() as session:
            try:
                res = await session.execute(
                    update(sys_role).where(sys_role.c.role_id == role_id).values(**patch)
                )
                await session.commit()
            except IntegrityError:
                await session.rollback()
                raise
        # updated_at always changes, so rowcount == 0 means the role does not exist.
        if res.rowcount == 0:
            return None
        return await self.get_role(role_id)

    async def disable_role(self, role_id: str) -> bool:
        """Soft-delete a role by setting ``is_active`` to False.

        Returns:
            True if a row matched ``role_id``, False otherwise.
        """
        async with self._session_factory() as session:
            res = await session.execute(
                update(sys_role).where(sys_role.c.role_id == role_id).values(is_active=False, updated_at=utc_now())
            )
            await session.commit()
            return bool(res.rowcount)

    async def list_permissions(self) -> list[dict[str, Any]]:
        """Return all permissions ordered by ``perm_group`` (NULLs last), then ``perm_code``."""
        async with self._session_factory() as session:
            q = select(sys_permission).order_by(
                sys_permission.c.perm_group.asc().nulls_last(),
                sys_permission.c.perm_code.asc(),
            )
            return [dict(r) for r in (await session.execute(q)).mappings().all()]

    async def create_permission(
        self, *, perm_code: str, perm_name: str, perm_group: str | None = None, perm_desc: str | None = None
    ) -> dict[str, Any]:
        """Insert a permission and return the stored row.

        Raises:
            IntegrityError: If the insert violates a database constraint
                (presumably a duplicate ``perm_code``; depends on the schema).
            RuntimeError: If the row cannot be read back after commit.
        """
        async with self._session_factory() as session:
            try:
                await session.execute(
                    insert(sys_permission).values(
                        perm_code=perm_code, perm_name=perm_name, perm_group=perm_group, perm_desc=perm_desc
                    )
                )
                await session.commit()
            except IntegrityError:
                await session.rollback()
                raise
            q = select(sys_permission).where(sys_permission.c.perm_code == perm_code).limit(1)
            row = (await session.execute(q)).mappings().first()
            if not row:
                raise RuntimeError("failed to create permission")
            return dict(row)

    async def delete_permission(self, perm_code: str) -> bool:
        """Hard-delete a permission together with all of its role bindings.

        Both deletes run in the same session and are committed together.

        Returns:
            True if the permission row existed and was deleted.
        """
        async with self._session_factory() as session:
            # Remove bindings first so no role_permission row references a missing permission.
            await session.execute(delete(sys_role_permission).where(sys_role_permission.c.perm_code == perm_code))
            res = await session.execute(delete(sys_permission).where(sys_permission.c.perm_code == perm_code))
            await session.commit()
            return bool(res.rowcount)

    async def get_role_permissions(self, role_id: str) -> list[str]:
        """Return the permission codes bound to ``role_id``, sorted by ``perm_code``.

        The explicit ordering makes the result deterministic (and consistent
        with the other list methods) instead of depending on database row order.
        """
        async with self._session_factory() as session:
            q = (
                select(sys_role_permission.c.perm_code)
                .where(sys_role_permission.c.role_id == role_id)
                .order_by(sys_role_permission.c.perm_code.asc())
            )
            rows = (await session.execute(q)).scalars().all()
            return list(rows)

    async def set_role_permissions(self, *, role_id: str, perm_codes: list[str]) -> list[str]:
        """Replace a role's permission set (delete all bindings, then insert).

        Duplicate codes are dropped, keeping first-occurrence order. An empty
        ``perm_codes`` clears all of the role's permissions.

        NOTE(review): the existence of ``role_id`` itself is not checked here;
        whether an unknown role is rejected depends on schema constraints.

        Returns:
            The role's permission codes after the update, sorted by ``perm_code``.

        Raises:
            ValueError: If any code is not in ``sys_permission``; raised before
                any binding is modified.
            IntegrityError: If the delete/insert violates a database constraint.
        """
        unique = list(dict.fromkeys(perm_codes))
        async with self._session_factory() as session:
            if unique:
                # Validate every requested code up front so nothing is written on failure.
                q = select(sys_permission.c.perm_code).where(sys_permission.c.perm_code.in_(unique))
                existing = set((await session.execute(q)).scalars().all())
                missing = [c for c in unique if c not in existing]
                if missing:
                    raise ValueError(f"missing permissions: {', '.join(missing)}")
            try:
                await session.execute(delete(sys_role_permission).where(sys_role_permission.c.role_id == role_id))
                if unique:
                    # executemany-style bulk insert of the new bindings.
                    await session.execute(
                        insert(sys_role_permission),
                        [{"role_id": role_id, "perm_code": code} for code in unique],
                    )
                await session.commit()
            except IntegrityError:
                await session.rollback()
                raise
        return await self.get_role_permissions(role_id)