"""Tests for the unified activity feed query service.

These tests verify that:
1. Unified feed correctly merges Thread and Signal items
2. Items are sorted in chronological order (most recent first)
3. Cursor-based pagination works across both tables
4. Queries finish within coarse time bounds on moderately sized datasets
"""

from datetime import timedelta

import pytest
from django.utils import timezone

from accounts.tests.factories import WorkspaceFactory
from messages.services.feed_query import FeedConnection, FeedCursor, get_unified_feed
from messages.tests.factories import SignalFactory, ThreadFactory


@pytest.mark.django_db
class TestFeedCursor:
    """Round-trip and error-handling checks for FeedCursor.encode/decode."""

    def test_encode_decode_roundtrip(self):
        """A cursor with both timestamps set survives encode -> decode unchanged."""
        thread_ts = timezone.now()
        signal_ts = thread_ts - timedelta(hours=1)
        cursor = FeedCursor(thread_time=thread_ts, signal_time=signal_ts)

        restored = FeedCursor.decode(cursor.encode())

        assert restored.thread_time == cursor.thread_time
        assert restored.signal_time == cursor.signal_time

    def test_encode_decode_with_none_values(self):
        """A cursor with no timestamps decodes back to None for both fields."""
        cursor = FeedCursor(thread_time=None, signal_time=None)

        restored = FeedCursor.decode(cursor.encode())

        assert restored.thread_time is None
        assert restored.signal_time is None

    def test_encode_decode_with_partial_values(self):
        """Only the thread timestamp is set; the signal timestamp stays None."""
        thread_ts = timezone.now()
        cursor = FeedCursor(thread_time=thread_ts, signal_time=None)

        restored = FeedCursor.decode(cursor.encode())

        assert restored.thread_time == cursor.thread_time
        assert restored.signal_time is None

    def test_decode_invalid_cursor_raises_error(self):
        """Decoding a string that is not an encoded cursor raises ValueError."""
        malformed = "not-a-valid-cursor"

        with pytest.raises(ValueError, match="Invalid cursor format"):
            FeedCursor.decode(malformed)


@pytest.mark.django_db
class TestGetUnifiedFeed:
    """Tests for get_unified_feed: merging, ordering, pagination and workspace scoping."""

    def test_empty_feed(self):
        """A workspace with no threads or signals yields an empty, final page."""
        workspace = WorkspaceFactory()
        result = get_unified_feed(workspace_id=str(workspace.id))

        assert result.items == []
        assert result.has_next_page is False

    def test_threads_only(self):
        """With only threads present, they come back most recent first."""
        workspace = WorkspaceFactory()
        # Capture the clock once so both timestamps share the same reference.
        now = timezone.now()
        thread1 = ThreadFactory(workspace=workspace, activity_at=now)
        thread2 = ThreadFactory(workspace=workspace, activity_at=now - timedelta(hours=1))

        result = get_unified_feed(workspace_id=str(workspace.id))

        assert len(result.items) == 2
        assert result.items[0] == thread1  # Most recent first
        assert result.items[1] == thread2

    def test_signals_only(self):
        """With only signals present, they come back most recent first."""
        workspace = WorkspaceFactory()
        # Capture the clock once so both timestamps share the same reference.
        now = timezone.now()
        signal1 = SignalFactory(workspace=workspace, occurred_at=now)
        signal2 = SignalFactory(workspace=workspace, occurred_at=now - timedelta(hours=1))

        result = get_unified_feed(workspace_id=str(workspace.id))

        assert len(result.items) == 2
        assert result.items[0] == signal1  # Most recent first
        assert result.items[1] == signal2

    def test_mixed_content_chronological_order(self):
        """Threads and signals are interleaved by timestamp, newest first."""
        workspace = WorkspaceFactory()
        now = timezone.now()

        # Create interleaved items, one hour apart, alternating between tables.
        thread1 = ThreadFactory(workspace=workspace, activity_at=now)
        signal1 = SignalFactory(workspace=workspace, occurred_at=now - timedelta(hours=1))
        thread2 = ThreadFactory(workspace=workspace, activity_at=now - timedelta(hours=2))
        signal2 = SignalFactory(workspace=workspace, occurred_at=now - timedelta(hours=3))

        result = get_unified_feed(workspace_id=str(workspace.id))

        assert len(result.items) == 4
        assert result.items[0] == thread1
        assert result.items[1] == signal1
        assert result.items[2] == thread2
        assert result.items[3] == signal2

    def test_pagination_first_page(self):
        """Requesting fewer items than exist returns a full page plus a cursor."""
        workspace = WorkspaceFactory()
        now = timezone.now()

        # Create 5 items
        for i in range(5):
            ThreadFactory(workspace=workspace, activity_at=now - timedelta(hours=i))

        result = get_unified_feed(workspace_id=str(workspace.id), first=3)

        assert len(result.items) == 3
        assert result.has_next_page is True
        assert result.end_cursor is not None

    def test_pagination_with_cursor(self):
        """Following end_cursor walks through all threads in order, page by page."""
        workspace = WorkspaceFactory()
        now = timezone.now()

        # Create 5 threads; index 0 is the newest, so the list is in feed order.
        threads = [ThreadFactory(workspace=workspace, activity_at=now - timedelta(hours=i)) for i in range(5)]

        # Get first page
        first_page = get_unified_feed(workspace_id=str(workspace.id), first=2)
        assert len(first_page.items) == 2
        assert first_page.items[0] == threads[0]
        assert first_page.items[1] == threads[1]
        assert first_page.has_next_page is True

        # Get second page
        second_page = get_unified_feed(
            workspace_id=str(workspace.id),
            first=2,
            after=first_page.end_cursor,
        )
        assert len(second_page.items) == 2
        assert second_page.items[0] == threads[2]
        assert second_page.items[1] == threads[3]
        assert second_page.has_next_page is True

        # Get last page
        last_page = get_unified_feed(
            workspace_id=str(workspace.id),
            first=2,
            after=second_page.end_cursor,
        )
        assert len(last_page.items) == 1
        assert last_page.items[0] == threads[4]
        assert last_page.has_next_page is False

    def test_pagination_with_mixed_content(self):
        """Paging across both tables returns every item exactly once, in order."""
        workspace = WorkspaceFactory()
        now = timezone.now()

        # Create interleaved items: even offsets are threads, odd offsets are
        # signals. Offsets grow with i, so this list is already newest-first.
        items = []
        for i in range(6):
            if i % 2 == 0:
                items.append(ThreadFactory(workspace=workspace, activity_at=now - timedelta(hours=i)))
            else:
                items.append(SignalFactory(workspace=workspace, occurred_at=now - timedelta(hours=i)))

        # Get first page
        first_page = get_unified_feed(workspace_id=str(workspace.id), first=3)
        assert len(first_page.items) == 3
        assert first_page.has_next_page is True

        # Get second page
        second_page = get_unified_feed(
            workspace_id=str(workspace.id),
            first=3,
            after=first_page.end_cursor,
        )
        assert len(second_page.items) == 3
        assert second_page.has_next_page is False

        # Verify all items were returned (no duplicates, no missing). Comparing
        # against the created items is what actually proves this; a length
        # check alone would still pass if a page repeated or skipped an item.
        all_items = list(first_page.items) + list(second_page.items)
        assert len(all_items) == 6
        assert all_items == items

    def test_workspace_isolation(self):
        """Each workspace's feed contains only that workspace's items."""
        workspace1 = WorkspaceFactory()
        workspace2 = WorkspaceFactory()

        ThreadFactory(workspace=workspace1)
        SignalFactory(workspace=workspace1)
        ThreadFactory(workspace=workspace2)
        SignalFactory(workspace=workspace2)

        result1 = get_unified_feed(workspace_id=str(workspace1.id))
        result2 = get_unified_feed(workspace_id=str(workspace2.id))

        assert len(result1.items) == 2
        assert len(result2.items) == 2
        assert all(item.workspace_id == workspace1.id for item in result1.items)
        assert all(item.workspace_id == workspace2.id for item in result2.items)

    def test_feed_connection_structure(self):
        """The result is a FeedConnection exposing items, has_next_page and end_cursor."""
        workspace = WorkspaceFactory()
        ThreadFactory(workspace=workspace)

        result = get_unified_feed(workspace_id=str(workspace.id))

        assert isinstance(result, FeedConnection)
        assert hasattr(result, "items")
        assert hasattr(result, "has_next_page")
        assert hasattr(result, "end_cursor")


@pytest.mark.django_db
class TestFeedFilters:
    """Tests for the filter arguments of get_unified_feed.

    Where possible each test checks which items came back, not just how many,
    so a filter that returns the wrong rows of the right count still fails.
    """

    def test_filter_threads_only(self):
        """item_types=["thread"] returns Thread instances only."""
        from messages.models import Thread

        workspace = WorkspaceFactory()
        ThreadFactory(workspace=workspace)
        ThreadFactory(workspace=workspace)
        SignalFactory(workspace=workspace)

        result = get_unified_feed(workspace_id=str(workspace.id), item_types=["thread"])

        assert len(result.items) == 2
        assert all(isinstance(item, Thread) for item in result.items)

    def test_filter_signals_only(self):
        """item_types=["signal"] returns Signal instances only."""
        from messages.models import Signal

        workspace = WorkspaceFactory()
        ThreadFactory(workspace=workspace)
        SignalFactory(workspace=workspace)
        SignalFactory(workspace=workspace)

        result = get_unified_feed(workspace_id=str(workspace.id), item_types=["signal"])

        assert len(result.items) == 2
        assert all(isinstance(item, Signal) for item in result.items)

    def test_filter_by_thread_type(self):
        """thread_types narrows threads to the requested thread_type."""
        workspace = WorkspaceFactory()
        ThreadFactory(workspace=workspace, thread_type="conversation")
        ThreadFactory(workspace=workspace, thread_type="post")
        ThreadFactory(workspace=workspace, thread_type="conversation")

        result = get_unified_feed(
            workspace_id=str(workspace.id),
            item_types=["thread"],
            thread_types=["conversation"],
        )

        assert len(result.items) == 2
        assert all(item.thread_type == "conversation" for item in result.items)

    def test_filter_by_signal_type(self):
        """signal_types narrows signals to the requested signal_type."""
        workspace = WorkspaceFactory()
        SignalFactory(workspace=workspace, signal_type="mention")
        SignalFactory(workspace=workspace, signal_type="star")
        SignalFactory(workspace=workspace, signal_type="mention")

        result = get_unified_feed(
            workspace_id=str(workspace.id),
            item_types=["signal"],
            signal_types=["mention"],
        )

        assert len(result.items) == 2
        assert all(item.signal_type == "mention" for item in result.items)

    def test_filter_by_multiple_signal_types(self):
        """Several signal_types act as a union; unlisted types are excluded."""
        workspace = WorkspaceFactory()
        SignalFactory(workspace=workspace, signal_type="mention")
        SignalFactory(workspace=workspace, signal_type="star")
        SignalFactory(workspace=workspace, signal_type="fork")

        result = get_unified_feed(
            workspace_id=str(workspace.id),
            item_types=["signal"],
            signal_types=["mention", "star"],
        )

        assert len(result.items) == 2
        # One signal of each requested type, and the "fork" signal left out.
        assert sorted(item.signal_type for item in result.items) == ["mention", "star"]

    def test_filter_by_source_ids(self):
        """source_ids restricts both threads and signals to the given sources."""
        from sources.tests.factories import SourceFactory

        workspace = WorkspaceFactory()
        source1 = SourceFactory(workspace=workspace)
        source2 = SourceFactory(workspace=workspace)

        ThreadFactory(workspace=workspace, source=source1)
        ThreadFactory(workspace=workspace, source=source2)
        SignalFactory(workspace=workspace, source=source1)
        SignalFactory(workspace=workspace, source=source2)

        result = get_unified_feed(
            workspace_id=str(workspace.id),
            source_ids=[str(source1.id)],
        )

        assert len(result.items) == 2
        assert all(item.source_id == source1.id for item in result.items)

    def test_filter_by_date_range_since(self):
        """since drops items older than the given timestamp."""
        workspace = WorkspaceFactory()
        now = timezone.now()

        stale_thread = ThreadFactory(workspace=workspace, activity_at=now - timedelta(days=5))
        ThreadFactory(workspace=workspace, activity_at=now - timedelta(days=2))
        SignalFactory(workspace=workspace, occurred_at=now - timedelta(days=1))

        result = get_unified_feed(
            workspace_id=str(workspace.id),
            since=now - timedelta(days=3),
        )

        assert len(result.items) == 2  # Only items within last 3 days
        assert stale_thread not in result.items

    def test_filter_by_date_range_until(self):
        """until drops items newer than the given timestamp."""
        workspace = WorkspaceFactory()
        now = timezone.now()

        stale_thread = ThreadFactory(workspace=workspace, activity_at=now - timedelta(days=5))
        ThreadFactory(workspace=workspace, activity_at=now - timedelta(days=2))
        SignalFactory(workspace=workspace, occurred_at=now - timedelta(days=1))

        result = get_unified_feed(
            workspace_id=str(workspace.id),
            until=now - timedelta(days=3),
        )

        assert len(result.items) == 1  # Only items older than 3 days
        assert result.items[0] == stale_thread

    def test_filter_by_date_range_since_and_until(self):
        """since and until together keep only items inside the window."""
        workspace = WorkspaceFactory()
        now = timezone.now()

        before_window = ThreadFactory(workspace=workspace, activity_at=now - timedelta(days=10))
        ThreadFactory(workspace=workspace, activity_at=now - timedelta(days=5))
        SignalFactory(workspace=workspace, occurred_at=now - timedelta(days=3))
        after_window = ThreadFactory(workspace=workspace, activity_at=now - timedelta(days=1))

        result = get_unified_feed(
            workspace_id=str(workspace.id),
            since=now - timedelta(days=7),
            until=now - timedelta(days=2),
        )

        assert len(result.items) == 2  # Only items between 2-7 days ago
        assert before_window not in result.items
        assert after_window not in result.items

    def test_combined_filters(self):
        """item_types, thread_types and source_ids can be applied together."""
        from sources.tests.factories import SourceFactory

        workspace = WorkspaceFactory()
        now = timezone.now()
        source = SourceFactory(workspace=workspace)

        # Create various items
        ThreadFactory(workspace=workspace, source=source, thread_type="conversation", activity_at=now)
        ThreadFactory(workspace=workspace, source=source, thread_type="post", activity_at=now)
        SignalFactory(workspace=workspace, source=source, signal_type="mention", occurred_at=now)

        # Filter to only conversation threads from this source
        result = get_unified_feed(
            workspace_id=str(workspace.id),
            item_types=["thread"],
            thread_types=["conversation"],
            source_ids=[str(source.id)],
        )

        assert len(result.items) == 1
        assert result.items[0].thread_type == "conversation"


@pytest.mark.django_db
class TestFeedQueryPerformance:
    """Coarse wall-clock checks for feed queries.

    The stated target for feed queries is <100ms. The assertions here use
    looser bounds (1000ms and 500ms) on small datasets; production targets
    10k+ items.
    """

    def test_feed_query_with_moderate_dataset(self):
        """A 20-item page over 100 mixed rows is returned in under one second."""
        import time

        workspace = WorkspaceFactory()
        reference_time = timezone.now()

        # 50 threads and 50 signals; each signal sits 30 minutes behind the
        # thread created in the same iteration.
        for minute in range(50):
            thread_at = reference_time - timedelta(minutes=minute)
            signal_at = reference_time - timedelta(minutes=minute + 30)
            ThreadFactory(workspace=workspace, activity_at=thread_at)
            SignalFactory(workspace=workspace, occurred_at=signal_at)

        started = time.perf_counter()
        page = get_unified_feed(workspace_id=str(workspace.id), first=20)
        duration_ms = (time.perf_counter() - started) * 1000

        assert len(page.items) == 20
        # Deliberately generous bound; larger datasets need dedicated benchmarks.
        assert duration_ms < 1000

    def test_filtered_query_uses_indexes(self):
        """A type-filtered signal query over 60 mixed rows finishes within 500ms."""
        import time

        workspace = WorkspaceFactory()
        reference_time = timezone.now()
        thread_kinds = ("conversation", "post")
        signal_kinds = ("mention", "star")

        # Alternate the type on every iteration: even -> first kind, odd -> second.
        for minute in range(30):
            parity = minute % 2
            ThreadFactory(
                workspace=workspace,
                thread_type=thread_kinds[parity],
                activity_at=reference_time - timedelta(minutes=minute),
            )
            SignalFactory(
                workspace=workspace,
                signal_type=signal_kinds[parity],
                occurred_at=reference_time - timedelta(minutes=minute + 30),
            )

        # Time only the filtered query itself, not the fixture setup.
        started = time.perf_counter()
        page = get_unified_feed(
            workspace_id=str(workspace.id),
            item_types=["signal"],
            signal_types=["mention"],
            first=20,
        )
        duration_ms = (time.perf_counter() - started) * 1000

        assert len(page.items) == 15  # Half of the 30 signals are mentions
        assert duration_ms < 500
