"""Unified activity feed query service.

This module provides ``get_unified_feed`` for querying the unified activity feed
that combines Thread and Signal items, ordered most recent first.

Usage:
    from messages.services.feed_query import get_unified_feed, FeedCursor

    result = get_unified_feed(
        workspace_id=str(workspace.id),
        first=20,
        after=cursor,
    )

See also:
    - research.md for pagination design decisions
    - data-model.md for model definitions
"""

from __future__ import annotations

import base64
import heapq
import itertools
import json
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, cast
from uuid import UUID

from messages.models import Signal, Thread

if TYPE_CHECKING:
    from collections.abc import Sequence


@dataclass
class FeedCursor:
    """Cursor for paginating across Thread and Signal tables.

    Encodes the last seen timestamp for both item types to enable
    efficient cursor-based pagination across two tables.

    Attributes:
        thread_time: Last Thread.activity_at seen, or None if no threads yet
        signal_time: Last Signal.occurred_at seen, or None if no signals yet
    """

    thread_time: datetime | None = None
    signal_time: datetime | None = None

    def encode(self) -> str:
        """Encode cursor to base64 string for API transport.

        The payload is a JSON object with keys ``"t"`` (thread time) and
        ``"s"`` (signal time), each an ISO-8601 string or ``null``.
        """
        data = {
            "t": self.thread_time.isoformat() if self.thread_time else None,
            "s": self.signal_time.isoformat() if self.signal_time else None,
        }
        return base64.b64encode(json.dumps(data).encode()).decode()

    @classmethod
    def decode(cls, cursor: str) -> FeedCursor:
        """Decode cursor from base64 string.

        Args:
            cursor: A string previously produced by ``encode``.

        Returns:
            The decoded FeedCursor. Missing or falsy ``"t"`` / ``"s"`` entries
            decode to None.

        Raises:
            ValueError: If the string is not valid base64, not valid JSON,
                not a JSON object, or holds a timestamp that is not an
                ISO-8601 string.
        """
        try:
            # binascii.Error, UnicodeDecodeError and json.JSONDecodeError are
            # all ValueError subclasses, so one clause covers every way the
            # transport encoding can be malformed.
            data = json.loads(base64.b64decode(cursor))
        except ValueError as e:
            raise ValueError(f"Invalid cursor format: {e}") from e

        # A syntactically valid JSON payload can still be a list, string or
        # number; reject it here instead of failing later with AttributeError.
        if not isinstance(data, dict):
            raise ValueError(
                f"Invalid cursor format: expected a JSON object, got {type(data).__name__}"
            )

        try:
            return cls(
                thread_time=datetime.fromisoformat(data["t"]) if data.get("t") else None,
                signal_time=datetime.fromisoformat(data["s"]) if data.get("s") else None,
            )
        except (TypeError, ValueError) as e:
            # TypeError: the timestamp value is not a string (e.g. a number).
            # ValueError: the string is not a valid ISO-8601 timestamp.
            raise ValueError(f"Invalid cursor format: {e}") from e


@dataclass
class FeedConnection:
    """One page of unified feed results together with its pagination state.

    Attributes:
        items: The Thread and Signal objects that make up this page.
        has_next_page: True when further items exist beyond this page.
        end_cursor: Cursor to pass as ``after`` when requesting the following
            page. ``get_unified_feed`` sets this to None only when the page
            contains no items.
    """

    # Page contents; may mix Thread and Signal instances.
    items: Sequence[Thread | Signal]
    # Whether another page can be requested after this one.
    has_next_page: bool
    # Position marker for the end of this page, if any.
    end_cursor: FeedCursor | None


def _get_item_timestamp(item: Thread | Signal) -> datetime:
    """Return the timestamp that positions ``item`` within the feed.

    Anything that is not a Thread is read as a Signal and uses
    ``occurred_at``. Threads use ``activity_at``, falling back to
    ``created_at`` when ``activity_at`` is falsy.
    """
    if not isinstance(item, Thread):
        return item.occurred_at
    return item.activity_at or item.created_at


def get_unified_feed(
    *,
    workspace_id: str | UUID,
    first: int = 20,
    after: FeedCursor | None = None,
    item_types: list[str] | None = None,
    thread_types: list[str] | None = None,
    signal_types: list[str] | None = None,
    source_ids: list[str | UUID] | None = None,
    since: datetime | None = None,
    until: datetime | None = None,
) -> FeedConnection:
    """Query the unified activity feed combining Threads and Signals.

    Performs application-level merge of Thread and Signal querysets,
    sorted by timestamp in descending order (most recent first). Each
    table is queried for at most ``first + 1`` rows; the extra row is used
    only to detect whether another page exists.

    Args:
        workspace_id: The workspace to query (required, keyword-only)
        first: Maximum number of items to return (default 20)
        after: Cursor for pagination (optional)
        item_types: Filter by item type - ["thread"], ["signal"], or both (optional).
            None means both; an empty list selects neither type.
        thread_types: Filter threads by thread_type - ["conversation", "post"] (optional)
        signal_types: Filter signals by signal_type - ["mention", "star", etc.] (optional)
        source_ids: Filter by source IDs (optional)
        since: Only include items with timestamp >= since (optional)
        until: Only include items with timestamp <= until (optional)

    Returns:
        FeedConnection with items, pagination info, and next cursor.
        ``end_cursor`` is None only when the page contains no items.

    Note:
        The cursor filters use a strict less-than on ``Thread.activity_at``
        and ``Signal.occurred_at``. Items that share the exact timestamp of
        the last item of their type on a page, but did not fit on that page,
        are therefore not returned by later pages.

    Example:
        # Get all feed items
        result = get_unified_feed(workspace_id=str(workspace.id), first=20)

        # Get only threads
        result = get_unified_feed(
            workspace_id=str(workspace.id),
            item_types=["thread"],
        )

        # Get only mentions
        result = get_unified_feed(
            workspace_id=str(workspace.id),
            item_types=["signal"],
            signal_types=["mention"],
        )
    """
    # Determine which item types to query
    include_threads = item_types is None or "thread" in item_types
    include_signals = item_types is None or "signal" in item_types

    thread_list: list[Thread] = []
    signal_list: list[Signal] = []

    # Query threads if not filtered out
    if include_threads:
        threads = Thread.objects.filter(workspace_id=str(workspace_id))

        # Apply cursor pagination
        if after and after.thread_time:
            threads = threads.filter(activity_at__lt=after.thread_time)

        # Apply thread_type filter
        if thread_types:
            threads = threads.filter(thread_type__in=thread_types)

        # Apply source filter
        if source_ids:
            threads = threads.filter(source_id__in=source_ids)

        # Apply date range filters
        if since:
            threads = threads.filter(activity_at__gte=since)
        if until:
            threads = threads.filter(activity_at__lte=until)

        # One extra row so has_next_page can be detected after the merge.
        thread_list = list(threads.order_by("-activity_at")[: first + 1])

    # Query signals if not filtered out
    if include_signals:
        signals = Signal.objects.filter(workspace_id=str(workspace_id))

        # Apply cursor pagination
        if after and after.signal_time:
            signals = signals.filter(occurred_at__lt=after.signal_time)

        # Apply signal_type filter
        if signal_types:
            signals = signals.filter(signal_type__in=signal_types)

        # Apply source filter
        if source_ids:
            signals = signals.filter(source_id__in=source_ids)

        # Apply date range filters
        if since:
            signals = signals.filter(occurred_at__gte=since)
        if until:
            signals = signals.filter(occurred_at__lte=until)

        # One extra row so has_next_page can be detected after the merge.
        signal_list = list(signals.order_by("-occurred_at")[: first + 1])

    # Merge the two result lists into one descending stream. Both inputs are
    # already ordered newest-first by the database, which is what
    # heapq.merge requires when reverse=True.
    merged = heapq.merge(
        cast(list[Thread | Signal], thread_list),
        cast(list[Thread | Signal], signal_list),
        key=_get_item_timestamp,
        reverse=True,
    )
    page = list(itertools.islice(merged, first + 1))

    # Check if there are more items, then trim the look-ahead row
    has_next_page = len(page) > first
    items: list[Thread | Signal] = page[:first]

    # Build end cursor from last items of each type in this page
    # The cursor tracks where we left off for each item type
    end_cursor = None
    if items:
        last_thread_time: datetime | None = None
        last_signal_time: datetime | None = None

        # Find the last item of each type in the returned items
        for item in reversed(items):
            if isinstance(item, Thread) and last_thread_time is None:
                last_thread_time = _get_item_timestamp(item)
            elif isinstance(item, Signal) and last_signal_time is None:
                last_signal_time = _get_item_timestamp(item)
            if last_thread_time is not None and last_signal_time is not None:
                break

        # If a type wasn't in this page, use the previous cursor value
        # (we haven't moved past any items of that type)
        if after:
            if last_thread_time is None:
                last_thread_time = after.thread_time
            if last_signal_time is None:
                last_signal_time = after.signal_time

        end_cursor = FeedCursor(
            thread_time=last_thread_time,
            signal_time=last_signal_time,
        )

    return FeedConnection(
        items=items,
        has_next_page=has_next_page,
        end_cursor=end_cursor,
    )
