# JWT Auth Provider
# SPDX-License-Identifier: Apache-2.0
"""
JWT Authentication Provider.
Implements AuthInterface using JSON Web Tokens (access + refresh).
"""
from __future__ import annotations

from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING, Any, cast

import jwt  # type: ignore[import-not-found]

from structum_lab.auth.interfaces import (
    AuthInterface,
    TokenPair,
    UserInterface,
    UserRepositoryInterface,
)

from .base import BaseAuthProvider
from .password import Argon2PasswordHasher

if TYPE_CHECKING:
    from structum_lab.config import ConfigInterface
class JWTAuthProvider(BaseAuthProvider, AuthInterface):
    """JWT-based authentication provider.

    Manages:
    - User authentication (verifying password)
    - Token generation (Access + Refresh)
    - Token verification
    - Password hashing (via Argon2)

    Access and refresh tokens are signed with ``secret_key`` and
    ``refresh_secret_key`` respectively. When no separate refresh secret is
    configured, both use ``secret_key``. In that case the ``"type"`` claim
    is what stops one kind of token from being accepted as the other.

    Example:
        >>> auth = JWTAuthProvider.from_config()  # doctest: +SKIP
        >>> tokens = auth.authenticate("user", "pass", user_repo)  # doctest: +SKIP
    """

    def __init__(
        self,
        secret_key: str,
        refresh_secret_key: str | None = None,
        algorithm: str = "HS256",
        access_token_expire_minutes: int = 15,
        refresh_token_expire_days: int = 7,
    ) -> None:
        """Initialize JWT provider.

        Args:
            secret_key: Secret for signing access tokens. Must be non-empty.
            refresh_secret_key: Secret for refresh tokens. Falls back to
                ``secret_key`` when ``None`` or empty.
            algorithm: Signing algorithm passed to PyJWT (HS256, RS256, etc.).
            access_token_expire_minutes: Access token TTL in minutes.
            refresh_token_expire_days: Refresh token TTL in days.

        Raises:
            ValueError: If ``secret_key`` is empty.
        """
        if not secret_key:
            # An empty HMAC key would let anyone mint tokens that verify.
            raise ValueError("secret_key must be a non-empty string")
        self.secret_key = secret_key
        self.refresh_secret_key = refresh_secret_key or secret_key
        self.algorithm = algorithm
        self.access_token_expire_minutes = access_token_expire_minutes
        self.refresh_token_expire_days = refresh_token_expire_days
        # Internal hasher used by hash_password / verify_password.
        self._hasher = Argon2PasswordHasher()

    @classmethod
    def from_config(
        cls,
        config_key: str | None = None,
        *,
        config: ConfigInterface | None = None,
    ) -> JWTAuthProvider:
        """Create a provider from configuration.

        Keys are read under the prefix ``config_key``, which defaults to
        ``cls.DEFAULT_CONFIG_KEY``:

        - ``secret_key`` (required)
        - ``refresh_secret_key``
        - ``algorithm`` (default ``"HS256"``)
        - ``access_token_expire_minutes`` (default 15)
        - ``refresh_token_expire_days`` (default 7)

        Secret values exposing ``get_secret_value()`` are unwrapped.

        Args:
            config_key: Configuration prefix. Falls back to
                ``DEFAULT_CONFIG_KEY`` when ``None`` or empty.
            config: Config source. When omitted, the global config from
                ``get_config()`` is used.

        Raises:
            ValueError: If ``<key>.secret_key`` is missing or empty.
            ValueError / TypeError: If a TTL value cannot be converted
                with ``int()``.
        """
        if config is None:
            from structum_lab.config import get_config

            config = get_config()

        key = config_key or cls.DEFAULT_CONFIG_KEY

        # Secrets should actually come from a secrets provider that is
        # merged into this config.
        secret_key_val = config.get(f"{key}.secret_key")
        if not secret_key_val:
            raise ValueError(f"Missing config: {key}.secret_key")
        secret_key = cls._unwrap_secret(secret_key_val)

        refresh_key_val = config.get(f"{key}.refresh_secret_key")
        refresh_key = cls._unwrap_secret(refresh_key_val) if refresh_key_val else None

        # int() lets TTLs delivered as strings (e.g. from environment
        # variables) work with timedelta.
        access_minutes = int(config.get(f"{key}.access_token_expire_minutes", 15))
        refresh_days = int(config.get(f"{key}.refresh_token_expire_days", 7))

        return cls(
            secret_key=secret_key,
            refresh_secret_key=refresh_key,
            algorithm=config.get(f"{key}.algorithm", "HS256"),
            access_token_expire_minutes=access_minutes,
            refresh_token_expire_days=refresh_days,
        )

    @staticmethod
    def _unwrap_secret(value: Any) -> str:
        """Return ``value`` as ``str``, unwrapping objects that expose ``get_secret_value()``."""
        if hasattr(value, "get_secret_value"):
            return str(value.get_secret_value())
        return str(value)

    def authenticate(
        self,
        username: str,
        password: str,
        user_repo: UserRepositoryInterface,
    ) -> TokenPair | None:
        """Authenticate a user by username and password.

        Returns:
            A new access/refresh ``TokenPair``, or ``None`` if no user matches
            ``username`` or the password does not verify.
        """
        user = user_repo.find_by_username(username)
        # NOTE(review): unknown usernames return before the password hasher
        # runs, so response timing may reveal whether a username exists.
        if not user:
            return None
        if not self.verify_password(password, user.hashed_password):
            return None
        return self._create_tokens(user)

    def refresh(
        self,
        refresh_token: str,
        user_repo: UserRepositoryInterface,
    ) -> TokenPair | None:
        """Exchange a valid refresh token for a new token pair.

        The presented refresh token is not revoked here, so it stays usable
        until it expires.

        Returns:
            A new ``TokenPair``, or ``None`` in any of these cases:

            - the token fails signature, expiry or required-claim checks
            - the token is not a refresh token
            - the token has no ``sub`` claim
            - no user matches that ``sub``
        """
        payload = self._decode(refresh_token, self.refresh_secret_key, "refresh")
        if payload is None:
            return None
        user_id = payload.get("sub")
        if not user_id:
            return None
        user = user_repo.find_by_id(user_id)
        if not user:
            return None
        return self._create_tokens(user)

    def verify_access_token(self, token: str) -> dict[str, Any] | None:
        """Verify an access token.

        Returns:
            The decoded claims, or ``None`` in either of these cases:

            - the token fails signature, expiry or required-claim checks
            - the token is not an access token
        """
        return self._decode(token, self.secret_key, "access")

    def _decode(
        self, token: str, key: str, expected_type: str
    ) -> dict[str, Any] | None:
        """Decode ``token`` with ``key``.

        Returns:
            The claims, or ``None`` if the token fails any of these checks:

            - signature, expiry or required-claim validation
            - the ``"type"`` claim differs from ``expected_type``
        """
        try:
            payload = jwt.decode(
                token,
                key,
                algorithms=[self.algorithm],
                # Without "require", PyJWT skips the expiry check for tokens
                # that lack "exp". Every token minted here carries exp and iat.
                options={"require": ["exp", "iat"]},
            )
        except jwt.PyJWTError:
            return None
        if payload.get("type") != expected_type:
            return None
        return dict(payload)

    def hash_password(self, password: str) -> str:
        """Hash ``password`` using the internal Argon2 hasher."""
        return self._hasher.hash(password)

    def verify_password(self, password: str, hashed: str) -> bool:
        """Return the Argon2 hasher's verdict on ``password`` against ``hashed``."""
        return self._hasher.verify(password, hashed)

    def _create_tokens(self, user: UserInterface) -> TokenPair:
        """Mint a signed access/refresh token pair for ``user``.

        Both tokens share one ``iat`` timestamp.
        ``TokenPair.expires_at`` is the access token's expiry.
        """
        now = datetime.now(UTC)
        access_expires_at = now + timedelta(minutes=self.access_token_expire_minutes)
        refresh_expires_at = now + timedelta(days=self.refresh_token_expire_days)

        # NOTE(review): PyJWT >= 2.10 rejects a non-string "sub" claim on
        # decode. If user.id can be an int, refresh/verify would return None.
        # Confirm the id type.
        access_payload: dict[str, Any] = {
            "sub": user.id,
            "username": user.username,
            "roles": user.roles,
            "type": "access",
            "exp": access_expires_at,
            "iat": now,
        }
        access_token = jwt.encode(access_payload, self.secret_key, algorithm=self.algorithm)

        refresh_payload: dict[str, Any] = {
            "sub": user.id,
            "type": "refresh",
            "exp": refresh_expires_at,
            "iat": now,
        }
        refresh_token = jwt.encode(
            refresh_payload, self.refresh_secret_key, algorithm=self.algorithm
        )

        return TokenPair(
            access_token=access_token,
            refresh_token=refresh_token,
            expires_at=access_expires_at,
        )