"""Fernet encryption utilities for sensitive credential storage.

This module provides symmetric encryption for storing sensitive values such as
OAuth tokens, API keys, and other credentials. It uses Fernet (AES-128-CBC
encryption authenticated with HMAC-SHA256) from the cryptography library, with
keys derived from Django's SECRET_KEY via PBKDF2-HMAC-SHA256.

Key rotation is supported via Django's SECRET_KEY_FALLBACKS setting:
- Encryption always uses the current SECRET_KEY
- Decryption tries current key first, then each fallback in order
- Values are not re-encrypted automatically; re-encrypting a value (e.g. on its
  next write) migrates it to the current key

Usage:
    from core.utils.encryption import encrypt_value, decrypt_value

    # Encrypt a value
    encrypted = encrypt_value("my-secret-token")

    # Decrypt a value (tries current key, then fallbacks)
    decrypted = decrypt_value(encrypted)
"""

from __future__ import annotations

import base64
import hashlib
import logging
import re
from functools import lru_cache

from cryptography.fernet import Fernet, InvalidToken
from django.conf import settings

# Module logger; decrypt_value() reports fallback-key usage through it.
logger: logging.Logger = logging.getLogger(__name__)

# Fixed salt for the PBKDF2 derivation in _derive_fernet_key. Changing it
# changes every derived key, which makes all previously encrypted data
# undecryptable.
_KEY_DERIVATION_SALT: bytes = b"dexxy-credential-encryption-v1"


@lru_cache(maxsize=8)
def _derive_fernet_key(secret_key: str | bytes) -> bytes:
    """Derive a Fernet-compatible key from a Django SECRET_KEY.

    Runs PBKDF2-HMAC-SHA256 (100,000 iterations, using the module-level
    ``_KEY_DERIVATION_SALT``) over the secret to obtain 32 bytes of key
    material, then URL-safe base64-encodes them, which is the key format
    Fernet expects.

    Results are memoized in-process for up to 8 distinct secrets, so the
    deliberately slow derivation normally runs once per key rather than on
    every encrypt/decrypt call.

    NOTE: if the current key plus fallbacks exceed 8 distinct secrets, the
    LRU cache will evict entries while they are iterated in order and every
    lookup may re-run PBKDF2.

    Args:
        secret_key: A SECRET_KEY or SECRET_KEY_FALLBACKS entry. ``str``
            values are UTF-8 encoded; ``bytes`` values are used as-is.

    Returns:
        Base64-encoded 32-byte key suitable for Fernet
    """
    # Anything that is not bytes goes through .encode() exactly as before,
    # so str inputs derive identical keys and other types fail the same way.
    secret_bytes = (
        secret_key if isinstance(secret_key, bytes) else secret_key.encode("utf-8")
    )
    # PBKDF2-HMAC-SHA256 stretches the secret into exactly 32 bytes, the
    # size of a Fernet key (16-byte signing key + 16-byte encryption key).
    key_material = hashlib.pbkdf2_hmac(
        "sha256",
        secret_bytes,
        _KEY_DERIVATION_SALT,
        iterations=100_000,
        dklen=32,
    )
    # Fernet keys are the URL-safe base64 encoding of those 32 bytes.
    return base64.urlsafe_b64encode(key_material)


def _get_fernet_instances() -> list[Fernet]:
    """Build one Fernet cipher per configured secret key.

    The list is ordered by decryption priority: ``settings.SECRET_KEY``
    first, followed by every ``SECRET_KEY_FALLBACKS`` entry in its
    configured order. Keys are derived through ``_derive_fernet_key``.

    Returns:
        Fernet instances, current key first, then fallbacks.
    """
    primary_key = settings.SECRET_KEY
    # SECRET_KEY_FALLBACKS only exists from Django 4.1 onward; absent means none.
    fallback_keys = getattr(settings, "SECRET_KEY_FALLBACKS", [])
    return [
        Fernet(_derive_fernet_key(secret))
        for secret in (primary_key, *fallback_keys)
    ]


def encrypt_value(plaintext: str) -> str:
    """Encrypt *plaintext* with the key derived from the current SECRET_KEY.

    Fallback keys are never used for encryption, so every newly written
    value is bound to the active key. The returned token is URL-safe
    base64 text and can be stored as-is in the database.

    Args:
        plaintext: Non-empty string to protect.

    Returns:
        The Fernet token as a ``str``.

    Raises:
        ValueError: If ``plaintext`` is empty (or otherwise falsy).
    """
    if not plaintext:
        raise ValueError("Cannot encrypt empty value")

    cipher = Fernet(_derive_fernet_key(settings.SECRET_KEY))
    payload = plaintext.encode("utf-8")
    return cipher.encrypt(payload).decode("utf-8")


def decrypt_value(ciphertext: str) -> str:
    """Recover the plaintext from a Fernet token.

    Candidate keys are tried in priority order: the current SECRET_KEY,
    then each SECRET_KEY_FALLBACKS entry as configured. The first key that
    decrypts the token wins. Success with any key other than the current
    one is logged at INFO level as a hint to re-encrypt the value.

    Args:
        ciphertext: Fernet token string to decrypt.

    Returns:
        The decrypted text, decoded as UTF-8.

    Raises:
        ValueError: If ``ciphertext`` is empty, or if no configured key can
            decrypt it.
    """
    if not ciphertext:
        raise ValueError("Cannot decrypt empty value")

    candidates = _get_fernet_instances()
    token = ciphertext.encode("utf-8")

    for position, candidate in enumerate(candidates):
        try:
            plain_bytes = candidate.decrypt(token)
        except InvalidToken:
            # Wrong key or a corrupt/tampered token; try the next candidate.
            continue
        if position:
            # Position 0 is the current key; anything later is a fallback.
            logger.info(
                "Decrypted value using fallback key %d. Consider re-encrypting.",
                position,
            )
        return plain_bytes.decode("utf-8")

    raise ValueError(
        "Failed to decrypt value with current key or any fallback keys. "
        "The value may be corrupted or encrypted with an unknown key."
    )


def is_encrypted(value: str) -> bool:
    """Check if a value appears to be Fernet-encrypted.

    This is a structural heuristic only. It does not verify the HMAC or try
    any key, so a True result does not guarantee the value will decrypt.
    A value qualifies when it:

    - starts with ``gAAAAA``: the 0x80 version byte followed by the leading
      zero bits of the big-endian timestamp (true for any realistic
      timestamp);
    - is at least 100 characters long, the padded base64 length of the
      smallest Fernet token (1 version + 8 timestamp + 16 IV + 16 ciphertext
      + 32 HMAC = 73 bytes); and
    - is padded URL-safe base64: only ``A-Z``, ``a-z``, ``0-9``, ``-`` and
      ``_``, optionally followed by one or two ``=``, with a total length
      that is a multiple of 4.

    Args:
        value: The value to check

    Returns:
        True if the value appears to be Fernet-encrypted; False for empty or
        None values and for anything failing the checks above.
    """
    if not value:
        return False

    # Fernet tokens start with version byte (0x80) which base64-encodes to 'gA'
    # followed by timestamp, resulting in 'gAAAAA' prefix for recent timestamps
    if not value.startswith("gAAAAA") or len(value) < 100:
        return False

    # Fernet emits padded URL-safe base64, so a genuine token is a whole number
    # of 4-character groups drawn from the URL-safe alphabet. fullmatch (not
    # match/search) rejects stray whitespace, newlines, or other characters.
    if len(value) % 4:
        return False
    return re.fullmatch(r"[A-Za-z0-9_-]+={0,2}", value) is not None
