100 lines
3.0 KiB
Python
100 lines
3.0 KiB
Python
|
|
"""轻量 access token 签发与校验。
|
|||
|
|
|
|||
|
|
说明:
|
|||
|
|
- 当前实现不是标准 JWT(避免引入额外依赖),而是“base64url(payload) + HMAC 签名”的轻量令牌。
|
|||
|
|
- 适用于内部系统的最小化认证需求;如需与第三方兼容,可替换为 JWT。
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import base64
|
|||
|
|
import hashlib
|
|||
|
|
import hmac
|
|||
|
|
import json
|
|||
|
|
import os
|
|||
|
|
import time
|
|||
|
|
from dataclasses import dataclass
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass(frozen=True)
class TokenPayload:
    """Immutable set of claims extracted from a verified access token.

    The declaration order of the fields below is also the positional-argument
    order of the generated ``__init__``.
    """

    # Value of the ``sub`` claim (the user identifier).
    user_id: str
    # Value of the ``username`` claim.
    username: str
    # Value of the ``role_id`` claim.
    role_id: str
    # Expiry time, integer Unix timestamp in seconds.
    exp: int
    # Issue time, integer Unix timestamp in seconds.
    iat: int
|
|||
|
|
|
|||
|
|
|
|||
|
|
def issue_access_token(*, user_id: str, username: str, role_id: str, expires_in_seconds: int = 3600) -> str:
    """Issue a signed access token for a user.

    The result has the shape ``v1.<payload>.<signature>``:

    * ``<payload>`` is the compact UTF-8 JSON of the claims, base64url-encoded
      without padding.
    * ``<signature>`` is the unpadded base64url form of the HMAC that ``_sign``
      computes over the ASCII bytes of ``<payload>``.

    Args:
        user_id: Stored in the ``sub`` claim.
        username: Stored in the ``username`` claim.
        role_id: Stored in the ``role_id`` claim.
        expires_in_seconds: Lifetime added to the issue time to form ``exp``.
            It is coerced with ``int()`` and not range-checked, so a value
            <= 0 produces a token that is already expired.

    Returns:
        The encoded token string.
    """
    issued_at = int(time.time())
    # Keyword order fixes the key order, and therefore the exact JSON bytes.
    claims = dict(
        sub=user_id,
        username=username,
        role_id=role_id,
        iat=issued_at,
        exp=issued_at + int(expires_in_seconds),
        v=1,
    )
    # Compact separators keep the token short; non-ASCII text stays as raw UTF-8.
    claims_json = json.dumps(claims, separators=(",", ":"), ensure_ascii=False)
    encoded_claims = _b64url_encode(claims_json.encode("utf-8"))
    signature = _b64url_encode(_sign(encoded_claims.encode("ascii")))
    return ".".join(("v1", encoded_claims, signature))
|
|||
|
|
|
|||
|
|
|
|||
|
|
def verify_access_token(token: str) -> TokenPayload:
    """Verify an access token and return its decoded payload.

    The expected format is ``v1.<payload_b64>.<sig_b64>``. Here ``sig_b64`` is
    the base64url HMAC (computed by ``_sign``) of the ASCII bytes of
    ``payload_b64``. The signature is checked before the payload is decoded.

    Args:
        token: Token string, as produced by ``issue_access_token``.

    Returns:
        A ``TokenPayload`` built from the ``sub``, ``username``, ``role_id``,
        ``exp`` and ``iat`` claims. The three string claims are passed through
        ``str()``, so a missing one becomes ``"None"``.

    Raises:
        ValueError: ``"invalid token"`` in any of these cases:

            * the token is not a non-empty string or is malformed;
            * it fails signature verification;
            * its payload is not a JSON object with integer-convertible
              ``exp`` and ``iat``.

            ``"token expired"`` once the current time reaches ``exp``.
    """
    if not token or not isinstance(token, str):
        raise ValueError("invalid token")
    parts = token.split(".")
    if len(parts) != 3 or parts[0] != "v1":
        raise ValueError("invalid token")
    payload_b64, sig_b64 = parts[1], parts[2]

    try:
        signed_bytes = payload_b64.encode("ascii")
        actual_sig = _b64url_decode(sig_b64)
    except ValueError:  # includes UnicodeEncodeError and binascii.Error
        raise ValueError("invalid token") from None
    # Constant-time comparison so timing does not reveal how much of the signature matched.
    if not hmac.compare_digest(_sign(signed_bytes), actual_sig):
        raise ValueError("invalid token")

    # Normalize every decode/parse failure to ValueError("invalid token"), as
    # documented. Without this, a non-object payload would raise
    # AttributeError, a missing exp/iat would raise TypeError (int(None)), and
    # an infinite exp would raise OverflowError.
    try:
        payload = json.loads(_b64url_decode(payload_b64).decode("utf-8"))
        if not isinstance(payload, dict):
            raise ValueError("payload is not a JSON object")
        exp = int(payload["exp"])
        iat = int(payload["iat"])
    except (ValueError, TypeError, KeyError, OverflowError):
        raise ValueError("invalid token") from None

    if int(time.time()) >= exp:
        raise ValueError("token expired")

    return TokenPayload(
        user_id=str(payload.get("sub")),
        username=str(payload.get("username")),
        role_id=str(payload.get("role_id")),
        exp=exp,
        iat=iat,
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
def access_token_secret() -> bytes:
    """Return the HMAC key used to sign and verify access tokens.

    The key is read from the ``SMARTEDT_AUTH_SECRET`` environment variable on
    every call, stripped of surrounding whitespace, and encoded as UTF-8.

    If the variable is unset or blank, a hard-coded development key is
    returned and a ``RuntimeWarning`` is emitted. That key is visible in the
    source, so anyone can use it to forge tokens. A deployment that forgot to
    configure the secret should therefore not run silently.
    """
    secret = os.getenv("SMARTEDT_AUTH_SECRET", "").strip()
    if not secret:
        import warnings  # local import: only needed on this fallback path

        warnings.warn(
            "SMARTEDT_AUTH_SECRET is not set; using the built-in development "
            "secret, which must not be used in production",
            RuntimeWarning,
            stacklevel=2,
        )
        secret = "smartedt-dev-secret-change-me"
    return secret.encode("utf-8")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _sign(message: bytes) -> bytes:
    """Return the raw HMAC-SHA256 digest of *message*, keyed by ``access_token_secret()``."""
    key = access_token_secret()
    return hmac.digest(key, message, hashlib.sha256)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _b64url_encode(data: bytes) -> str:
|
|||
|
|
"""编码为不带 padding 的 base64url 字符串。"""
|
|||
|
|
return base64.urlsafe_b64encode(data).decode("ascii").rstrip("=")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _b64url_decode(value: str) -> bytes:
|
|||
|
|
"""解码不带 padding 的 base64url 字符串。"""
|
|||
|
|
padded = value + "=" * (-len(value) % 4)
|
|||
|
|
return base64.urlsafe_b64decode(padded.encode("ascii"))
|