"""RBAC(角色/权限)服务。

该模块围绕 sys_role / sys_permission / sys_role_permission 三张表提供基本的增删改查与绑定关系维护。
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import secrets
|
||
from typing import Any
|
||
|
||
from sqlalchemy import delete, insert, select, update
|
||
from sqlalchemy.exc import IntegrityError
|
||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||
|
||
from backend.database.schema import sys_permission, sys_role, sys_role_permission
|
||
from backend.utils import utc_now
|
||
|
||
|
||
class RbacService:
    """Role and permission-point management service (SQLAlchemy Core).

    Operates on the ``sys_role``, ``sys_permission`` and ``sys_role_permission``
    tables. Every public method opens its own short-lived session from the
    injected ``async_sessionmaker``; methods that write commit before returning.
    """

    def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
        """Keep the factory used to open a fresh ``AsyncSession`` per call."""
        self._session_factory = session_factory

    async def list_roles(self) -> list[dict[str, Any]]:
        """List all roles, ordered by ``role_name`` ascending.

        Returns:
            One ``dict`` per role with the keys ``role_id``, ``role_name``,
            ``role_desc``, ``is_active``, ``created_at``, ``updated_at`` and
            ``extra``.
        """
        async with self._session_factory() as session:
            q = (
                select(
                    sys_role.c.role_id,
                    sys_role.c.role_name,
                    sys_role.c.role_desc,
                    sys_role.c.is_active,
                    sys_role.c.created_at,
                    sys_role.c.updated_at,
                    sys_role.c.extra,
                )
                .order_by(sys_role.c.role_name.asc())
            )
            return [dict(r) for r in (await session.execute(q)).mappings().all()]

    async def get_role(self, role_id: str) -> dict[str, Any] | None:
        """Fetch a single role by ``role_id``.

        Returns:
            Every ``sys_role`` column of the matching row as a ``dict``, or
            ``None`` when no row has that ``role_id``.
        """
        async with self._session_factory() as session:
            q = select(sys_role).where(sys_role.c.role_id == role_id).limit(1)
            row = (await session.execute(q)).mappings().first()
            return dict(row) if row else None

    async def create_role(
        self,
        *,
        role_id: str | None = None,
        role_name: str,
        role_desc: str | None = None,
        is_active: bool = True,
        extra: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Insert a new role and return it as stored.

        Args:
            role_id: Explicit role id. When omitted, ``None`` or empty, an id of
                the form ``"role_" + 16 hex characters`` is generated with
                ``secrets.token_hex``.
            role_name: Display name of the role.
            role_desc: Optional description.
            is_active: Initial value of the active flag.
            extra: Optional payload stored in the ``extra`` column.

        Returns:
            The inserted row, re-read through :meth:`get_role`.

        Raises:
            IntegrityError: If the insert violates a database constraint; the
                transaction is rolled back before the error is re-raised.
            RuntimeError: If the row cannot be read back after the commit.
        """
        # ``or`` rather than ``is None``: an empty string also gets a generated id.
        rid = role_id or ("role_" + secrets.token_hex(8))
        # NOTE(review): created_at is not set here; presumably a DB-side default
        # fills it — confirm against the schema.
        values: dict[str, Any] = {
            "role_id": rid,
            "role_name": role_name,
            "role_desc": role_desc,
            "is_active": is_active,
            "updated_at": utc_now(),
            "extra": extra,
        }
        async with self._session_factory() as session:
            try:
                await session.execute(insert(sys_role).values(**values))
                await session.commit()
            except IntegrityError:
                await session.rollback()
                raise

        created = await self.get_role(rid)
        if not created:
            raise RuntimeError("failed to create role")
        return created

    async def update_role(
        self,
        role_id: str,
        *,
        role_name: str | None = None,
        role_desc: str | None = None,
        is_active: bool | None = None,
        extra: dict[str, Any] | None = None,
    ) -> dict[str, Any] | None:
        """Update only the fields that were passed.

        Arguments left as ``None`` are not written, so this method cannot set a
        column to NULL. ``updated_at`` is refreshed on every call.

        Returns:
            The updated row, or ``None`` when no role has ``role_id``.

        Raises:
            IntegrityError: If the update violates a database constraint; the
                transaction is rolled back before the error is re-raised.
        """
        patch: dict[str, Any] = {"updated_at": utc_now()}
        if role_name is not None:
            patch["role_name"] = role_name
        if role_desc is not None:
            patch["role_desc"] = role_desc
        if is_active is not None:
            patch["is_active"] = is_active
        if extra is not None:
            patch["extra"] = extra

        async with self._session_factory() as session:
            try:
                res = await session.execute(
                    update(sys_role).where(sys_role.c.role_id == role_id).values(**patch)
                )
                await session.commit()
            except IntegrityError:
                await session.rollback()
                raise

        if res.rowcount == 0:
            return None
        return await self.get_role(role_id)

    async def disable_role(self, role_id: str) -> bool:
        """Soft-delete a role by setting ``is_active`` to ``False``.

        Returns:
            ``True`` if a row was updated, ``False`` if no role has ``role_id``.
        """
        async with self._session_factory() as session:
            res = await session.execute(
                update(sys_role)
                .where(sys_role.c.role_id == role_id)
                .values(is_active=False, updated_at=utc_now())
            )
            await session.commit()
            return bool(res.rowcount)

    async def list_permissions(self) -> list[dict[str, Any]]:
        """List all permission points.

        Rows are ordered by ``perm_group`` ascending with NULL groups last, then
        by ``perm_code`` ascending.
        """
        async with self._session_factory() as session:
            q = select(sys_permission).order_by(
                sys_permission.c.perm_group.asc().nulls_last(),
                sys_permission.c.perm_code.asc(),
            )
            return [dict(r) for r in (await session.execute(q)).mappings().all()]

    async def create_permission(
        self, *, perm_code: str, perm_name: str, perm_group: str | None = None, perm_desc: str | None = None
    ) -> dict[str, Any]:
        """Insert a permission point and return it as stored.

        Returns:
            The inserted ``sys_permission`` row, re-read in the same session.

        Raises:
            IntegrityError: If the insert violates a database constraint; the
                transaction is rolled back before the error is re-raised.
            RuntimeError: If the row cannot be read back after the commit.
        """
        async with self._session_factory() as session:
            try:
                await session.execute(
                    insert(sys_permission).values(
                        perm_code=perm_code, perm_name=perm_name, perm_group=perm_group, perm_desc=perm_desc
                    )
                )
                await session.commit()
            except IntegrityError:
                await session.rollback()
                raise

            q = select(sys_permission).where(sys_permission.c.perm_code == perm_code).limit(1)
            row = (await session.execute(q)).mappings().first()
            if not row:
                raise RuntimeError("failed to create permission")
            return dict(row)

    async def delete_permission(self, perm_code: str) -> bool:
        """Delete a permission point together with its role bindings.

        Both deletes run in one session and are committed together. The role
        links are removed first so no binding is left pointing at a missing
        permission.

        Returns:
            ``True`` if the permission row itself was deleted.
        """
        async with self._session_factory() as session:
            await session.execute(delete(sys_role_permission).where(sys_role_permission.c.perm_code == perm_code))
            res = await session.execute(delete(sys_permission).where(sys_permission.c.perm_code == perm_code))
            await session.commit()
            return bool(res.rowcount)

    async def get_role_permissions(self, role_id: str) -> list[str]:
        """Return the permission codes bound to ``role_id``.

        The codes are sorted by ``perm_code`` ascending, so the result is
        deterministic. This matches the ``perm_code`` ordering used by
        :meth:`list_permissions`. An unknown role yields an empty list.
        """
        async with self._session_factory() as session:
            q = (
                select(sys_role_permission.c.perm_code)
                .where(sys_role_permission.c.role_id == role_id)
                .order_by(sys_role_permission.c.perm_code.asc())
            )
            rows = (await session.execute(q)).scalars().all()
            return list(rows)

    async def set_role_permissions(self, *, role_id: str, perm_codes: list[str]) -> list[str]:
        """Replace the role's full permission set: delete all links, then insert.

        Args:
            role_id: Role whose bindings are replaced.
            perm_codes: Desired permission codes. Duplicates are ignored; an
                empty list clears every binding of the role.

        Returns:
            The role's permission codes after the write, as returned by
            :meth:`get_role_permissions`.

        Raises:
            ValueError: If any code does not exist in ``sys_permission``. This
                is checked before any write, so nothing is changed.
            IntegrityError: If the delete/insert violates a constraint; the
                transaction is rolled back before the error is re-raised.
        """
        # Order-preserving de-duplication, so the "missing" message follows input order.
        unique = list(dict.fromkeys(perm_codes))
        async with self._session_factory() as session:
            if unique:
                q = select(sys_permission.c.perm_code).where(sys_permission.c.perm_code.in_(unique))
                existing = set((await session.execute(q)).scalars().all())
                missing = [c for c in unique if c not in existing]
                if missing:
                    raise ValueError(f"missing permissions: {', '.join(missing)}")

            try:
                await session.execute(delete(sys_role_permission).where(sys_role_permission.c.role_id == role_id))
                if unique:
                    # executemany-style bulk insert: one parameter dict per binding.
                    await session.execute(
                        insert(sys_role_permission),
                        [{"role_id": role_id, "perm_code": code} for code in unique],
                    )
                await session.commit()
            except IntegrityError:
                await session.rollback()
                raise

        return await self.get_role_permissions(role_id)