"""Authentication mutations for GraphQL."""

from datetime import UTC, datetime, timedelta

import jwt
import strawberry
from django.conf import settings
from django.contrib.auth import authenticate, get_user_model
from django.contrib.auth.password_validation import validate_password
from django.core.exceptions import ValidationError
from django.db import transaction

from accounts.models import Workspace, WorkspaceMembership

from ..types import AuthPayload, AuthTokens, LoginInput, RefreshTokenInput, RegisterInput, UserType

# Active user model as returned by Django's get_user_model(), resolved once at
# import time and shared by the token helpers and mutations below.
User = get_user_model()


def generate_unique_slug(base_slug: str) -> str:
    """Return the first workspace slug not already in use.

    Tries ``base_slug`` itself, then ``base_slug-1``, ``base_slug-2``, ...,
    issuing one existence query per candidate.

    NOTE(review): this is check-then-use; two concurrent callers can receive
    the same slug before either workspace is saved.
    """
    suffix = 0
    candidate = base_slug
    while True:
        taken = Workspace.objects.filter(slug=candidate).exists()
        if not taken:
            return candidate
        suffix += 1
        candidate = f"{base_slug}-{suffix}"


# Token lifetimes. Access tokens are short-lived; refresh tokens last longer
# and are exchanged for a new pair by the ``refresh_token`` mutation.
ACCESS_TOKEN_EXPIRY_MINUTES = 60
REFRESH_TOKEN_EXPIRY_DAYS = 7


def _create_token(user_id: int, token_type: str, lifetime: timedelta) -> str:
    """Encode an HS256-signed JWT for ``user_id``.

    Shared by the access and refresh token creators so both token kinds
    always carry the same claim set and signing setup.

    Args:
        user_id: Primary key stored in the ``user_id`` claim.
        token_type: Value of the ``type`` claim (``"access"`` or ``"refresh"``).
        lifetime: How long after issuance the token expires (``exp`` claim).

    Returns:
        The encoded JWT string, signed with ``settings.SECRET_KEY``.
    """
    issued_at = datetime.now(UTC)
    payload = {
        "user_id": user_id,
        "exp": issued_at + lifetime,
        "iat": issued_at,
        "type": token_type,
    }
    return jwt.encode(payload, settings.SECRET_KEY, algorithm="HS256")


def create_access_token(user_id: int) -> str:
    """Create a JWT access token (``type="access"``) for a user.

    The token expires ACCESS_TOKEN_EXPIRY_MINUTES after issuance.
    """
    return _create_token(
        user_id,
        "access",
        timedelta(minutes=ACCESS_TOKEN_EXPIRY_MINUTES),
    )


def create_refresh_token(user_id: int) -> str:
    """Create a JWT refresh token (``type="refresh"``) for a user.

    The token expires REFRESH_TOKEN_EXPIRY_DAYS after issuance.
    """
    return _create_token(
        user_id,
        "refresh",
        timedelta(days=REFRESH_TOKEN_EXPIRY_DAYS),
    )


def create_tokens(user: User) -> AuthTokens:
    """Issue a fresh access/refresh token pair for ``user``.

    Both tokens carry ``user.pk`` as their ``user_id`` claim.
    """
    user_id = user.pk
    access_token = create_access_token(user_id)
    refresh_token = create_refresh_token(user_id)
    return AuthTokens(access_token=access_token, refresh_token=refresh_token)


@strawberry.type
class AuthMutations:
    """GraphQL mutations for logging in, registering and refreshing JWTs.

    Failures are raised as plain ``Exception`` instances carrying a
    user-facing message.
    """

    @strawberry.mutation
    def login(self, input: LoginInput) -> AuthPayload:
        """Authenticate by username/password and issue a token pair.

        Raises:
            Exception: "Invalid credentials" if authentication fails, or
                "User account is disabled" if the returned user is inactive.
        """
        user = authenticate(username=input.username, password=input.password)

        if user is None:
            raise Exception("Invalid credentials") from None

        # NOTE: Django's default ModelBackend typically refuses inactive users
        # already (authenticate() returns None); this guard covers backends
        # that do not.
        if not user.is_active:
            raise Exception("User account is disabled") from None

        return AuthPayload(
            user=UserType.from_model(user),
            tokens=create_tokens(user),
        )

    @strawberry.mutation
    def register(self, input: RegisterInput) -> AuthPayload:
        """Create a user and a personal workspace, then issue a token pair.

        The user, the workspace and the owner membership are written in one
        database transaction, so a failure part-way through does not leave a
        user without a workspace.

        Raises:
            Exception: if the passwords differ, the username or email is
                already taken (compared case-insensitively), or the password
                fails the configured password validators.
        """
        if input.password != input.password_confirm:
            raise Exception("Passwords do not match") from None

        # Case-insensitive checks so "Alice"/"alice" or "A@x.com"/"a@x.com"
        # cannot be registered as separate accounts.
        if User.objects.filter(username__iexact=input.username).exists():
            raise Exception("Username already exists") from None

        if User.objects.filter(email__iexact=input.email).exists():
            raise Exception("Email already exists") from None

        first_name = input.first_name or ""
        last_name = input.last_name or ""

        # Validate against an unsaved instance so attribute-based validators
        # (e.g. UserAttributeSimilarityValidator) can compare the password
        # with the username, email and names.
        candidate = User(
            username=input.username,
            email=input.email,
            first_name=first_name,
            last_name=last_name,
        )
        try:
            validate_password(input.password, user=candidate)
        except ValidationError as e:
            raise Exception(", ".join(e.messages)) from None

        # All-or-nothing: a failure creating the workspace or membership
        # rolls back the user row as well.
        with transaction.atomic():
            user = User.objects.create_user(
                username=input.username,
                email=input.email,
                password=input.password,
                first_name=first_name,
                last_name=last_name,
            )
            # NOTE(review): the raw username is used as the slug base; confirm
            # the Workspace slug field accepts every character (and length)
            # a username may have.
            workspace = Workspace.objects.create(
                name=f"{input.username}'s Workspace",
                slug=generate_unique_slug(input.username),
            )
            WorkspaceMembership.objects.create(
                user=user,
                workspace=workspace,
                role="owner",
            )

        return AuthPayload(
            user=UserType.from_model(user),
            tokens=create_tokens(user),
        )

    @strawberry.mutation
    def refresh_token(self, input: RefreshTokenInput) -> AuthTokens:
        """Exchange a valid refresh token for a new access/refresh pair.

        Raises:
            Exception: if the token is expired, malformed, badly signed, lacks
                an ``exp`` claim, is not a refresh token, has no user id, or
                refers to a missing or inactive user.
        """
        # Keep the try narrow: only decoding raises PyJWT errors.
        try:
            payload = jwt.decode(
                input.refresh_token,
                settings.SECRET_KEY,
                algorithms=["HS256"],
                # Reject tokens with no expiry instead of accepting them forever.
                options={"require": ["exp"]},
            )
        except jwt.ExpiredSignatureError:
            raise Exception("Refresh token has expired") from None
        except jwt.InvalidTokenError:
            raise Exception("Invalid refresh token") from None

        # Access tokens are signed with the same key, so the "type" claim is
        # what prevents an access token from being used here.
        if payload.get("type") != "refresh":
            raise Exception("Invalid token type") from None

        user_id = payload.get("user_id")
        if not user_id:
            raise Exception("Invalid token") from None

        # Re-check the account: it may have been deleted or disabled after
        # the refresh token was issued.
        try:
            user = User.objects.get(pk=user_id)
        except User.DoesNotExist:
            raise Exception("User not found") from None

        if not user.is_active:
            raise Exception("User account is disabled") from None

        return create_tokens(user)
