Source code for structum_lab.plugins.auth.jwt

# JWT Auth Provider
# SPDX-License-Identifier: Apache-2.0

"""
JWT Authentication Provider.

Implements AuthInterface using JSON Web Tokens (access + refresh).
"""

from __future__ import annotations

from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING, Any, cast

import jwt  # type: ignore[import-not-found]
from structum_lab.auth.interfaces import (
    AuthInterface,
    TokenPair,
    UserInterface,
    UserRepositoryInterface,
)

from .base import BaseAuthProvider
from .password import Argon2PasswordHasher

if TYPE_CHECKING:
    from structum_lab.config import ConfigInterface


class JWTAuthProvider(BaseAuthProvider, AuthInterface):
    """JWT-based authentication provider.

    Manages:
        - User authentication (verifying password)
        - Token generation (Access + Refresh)
        - Token verification
        - Password hashing (via Argon2)

    Example:
        >>> auth = JWTAuthProvider.from_config()
        >>> tokens = auth.authenticate("user", "pass", user_repo)
    """

    def __init__(
        self,
        secret_key: str,
        refresh_secret_key: str | None = None,
        algorithm: str = "HS256",
        access_token_expire_minutes: int = 15,
        refresh_token_expire_days: int = 7,
    ) -> None:
        """Initialize JWT provider.

        Args:
            secret_key: Secret for signing access tokens. Must be non-empty.
            refresh_secret_key: Secret for refresh tokens (defaults to secret_key).
            algorithm: Signing algorithm (HS256, RS256, etc.).
            access_token_expire_minutes: Access token TTL in minutes.
            refresh_token_expire_days: Refresh token TTL in days.

        Raises:
            ValueError: If ``secret_key`` is empty.
        """
        if not secret_key:
            # Signing with an empty HMAC key would let anyone forge tokens.
            raise ValueError("secret_key must be a non-empty string")
        self.secret_key = secret_key
        # Empty/None refresh secret falls back to the access-token secret.
        self.refresh_secret_key = refresh_secret_key or secret_key
        self.algorithm = algorithm
        self.access_token_expire_minutes = access_token_expire_minutes
        self.refresh_token_expire_days = refresh_token_expire_days
        # Internal hasher
        self._hasher = Argon2PasswordHasher()
        # Lazily computed hash used only to equalize authenticate() timing
        # for unknown usernames (see _get_dummy_hash).
        self._dummy_hash: str | None = None

    @classmethod
    def from_config(
        cls,
        config_key: str | None = None,
        *,
        config: ConfigInterface | None = None,
    ) -> JWTAuthProvider:
        """Factory method to create provider from config.

        Reads ``<key>.secret_key`` (required), ``<key>.refresh_secret_key``,
        ``<key>.algorithm``, ``<key>.access_token_expire_minutes`` and
        ``<key>.refresh_token_expire_days``, where ``<key>`` is ``config_key``
        or ``cls.DEFAULT_CONFIG_KEY``.

        Args:
            config_key: Config namespace to read from.
            config: Config object; the global config is used when omitted.

        Returns:
            A configured provider.

        Raises:
            ValueError: If ``<key>.secret_key`` is missing or empty.
        """
        if config is None:
            from structum_lab.config import get_config

            config = get_config()

        key = config_key or cls.DEFAULT_CONFIG_KEY

        # Secrets should normally come from a secrets provider whose values
        # have been merged into the config before this point.
        secret_key_val = config.get(f"{key}.secret_key")
        if not secret_key_val:
            raise ValueError(f"Missing config: {key}.secret_key")

        def _get_secret(v: Any) -> str:
            # Unwrap secret containers exposing get_secret_value(); otherwise stringify.
            if hasattr(v, "get_secret_value"):
                return str(v.get_secret_value())
            return str(v)

        secret_key = _get_secret(secret_key_val)

        refresh_key_val = config.get(f"{key}.refresh_secret_key")
        refresh_key = _get_secret(refresh_key_val) if refresh_key_val else None

        return cls(
            secret_key=secret_key,
            refresh_secret_key=refresh_key,
            algorithm=config.get(f"{key}.algorithm", "HS256"),
            access_token_expire_minutes=config.get(f"{key}.access_token_expire_minutes", 15),
            refresh_token_expire_days=config.get(f"{key}.refresh_token_expire_days", 7),
        )

    def authenticate(
        self,
        username: str,
        password: str,
        user_repo: UserRepositoryInterface,
    ) -> TokenPair | None:
        """Authenticate user and return tokens.

        Args:
            username: Username to look up via ``user_repo.find_by_username``.
            password: Plain-text password to verify.
            user_repo: Repository used to load the user.

        Returns:
            A new token pair, or ``None`` when the user does not exist or the
            password does not match.
        """
        user = user_repo.find_by_username(username)
        if not user:
            # Run a throwaway hash verification so the unknown-user path also
            # pays the cost of password verification; otherwise response
            # timing reveals which usernames exist.
            self.verify_password(password, self._get_dummy_hash())
            return None
        if not self.verify_password(password, user.hashed_password):
            return None
        return self._create_tokens(user)

    def refresh(
        self,
        refresh_token: str,
        user_repo: UserRepositoryInterface,
    ) -> TokenPair | None:
        """Refresh tokens using refresh token.

        Args:
            refresh_token: Encoded refresh token previously issued by this provider.
            user_repo: Repository used to reload the user by the token's ``sub``.

        Returns:
            A new token pair, or ``None`` if the token is invalid, expired,
            missing required claims, not of type ``"refresh"``, or refers to a
            user that no longer exists.
        """
        try:
            payload = jwt.decode(
                refresh_token,
                self.refresh_secret_key,
                algorithms=[self.algorithm],
                # Every token issued by _create_tokens carries these claims;
                # reject tokens that omit them (e.g. a token without "exp"
                # would otherwise never expire).
                options={"require": ["exp", "iat", "sub"]},
            )
        except jwt.PyJWTError:
            return None

        if payload.get("type") != "refresh":
            return None

        user_id = payload.get("sub")
        if not user_id:
            return None

        user = user_repo.find_by_id(user_id)
        if not user:
            return None

        return self._create_tokens(user)

    def verify_access_token(self, token: str) -> dict[str, Any] | None:
        """Verify access token and return payload.

        Args:
            token: Encoded access token.

        Returns:
            The decoded claims as a plain dict, or ``None`` if the token is
            invalid, expired, missing required claims, or not of type ``"access"``.
        """
        try:
            payload = jwt.decode(
                token,
                self.secret_key,
                algorithms=[self.algorithm],
                options={"require": ["exp", "iat", "sub"]},
            )
        except jwt.PyJWTError:
            return None

        if payload.get("type") != "access":
            return None
        return dict(payload)

    def hash_password(self, password: str) -> str:
        """Hash password using Argon2."""
        return self._hasher.hash(password)

    def verify_password(self, password: str, hashed: str) -> bool:
        """Verify password against hash."""
        return self._hasher.verify(password, hashed)

    def _get_dummy_hash(self) -> str:
        """Return a cached hash of a throwaway value, used only for timing parity."""
        dummy = getattr(self, "_dummy_hash", None)
        if dummy is None:
            # The plaintext is irrelevant: the verification result is discarded.
            dummy = self.hash_password("structum-lab-timing-dummy")
            self._dummy_hash = dummy
        return dummy

    def _create_tokens(self, user: UserInterface) -> TokenPair:
        """Generate access and refresh tokens.

        Args:
            user: User the tokens are issued for.

        Returns:
            A ``TokenPair`` whose ``expires_at`` is the access token's expiry.
        """
        now = datetime.now(UTC)
        access_expires = now + timedelta(minutes=self.access_token_expire_minutes)
        refresh_expires = now + timedelta(days=self.refresh_token_expire_days)

        # NOTE(review): RFC 7519 defines "sub" as a string and recent PyJWT
        # releases reject a non-string "sub" on decode; confirm user.id is a str.

        # Access Token
        access_payload: dict[str, Any] = {
            "sub": user.id,
            "username": user.username,
            "roles": user.roles,
            "type": "access",
            "exp": access_expires,
            "iat": now,
        }
        access_token = jwt.encode(access_payload, self.secret_key, algorithm=self.algorithm)

        # Refresh Token
        refresh_payload: dict[str, Any] = {
            "sub": user.id,
            "type": "refresh",
            "exp": refresh_expires,
            "iat": now,
        }
        refresh_token = jwt.encode(
            refresh_payload, self.refresh_secret_key, algorithm=self.algorithm
        )

        return TokenPair(
            access_token=access_token,
            refresh_token=refresh_token,
            expires_at=access_expires,
        )