diff --git a/src/dstack/_internal/cli/commands/event.py b/src/dstack/_internal/cli/commands/event.py
index 4fd122d449..709a14d826 100644
--- a/src/dstack/_internal/cli/commands/event.py
+++ b/src/dstack/_internal/cli/commands/event.py
@@ -2,7 +2,12 @@
 from dataclasses import asdict
 
 from dstack._internal.cli.commands import APIBaseCommand
-from dstack._internal.cli.services.events import EventListFilters, EventPaginator, print_event
+from dstack._internal.cli.services.events import (
+    EventListFilters,
+    EventPaginator,
+    EventTracker,
+    print_event,
+)
 from dstack._internal.cli.utils.common import (
     get_start_time,
 )
@@ -29,6 +34,12 @@ def _register(self):
         list_parser.set_defaults(subfunc=self._list)
 
         for parser in [self._parser, list_parser]:
+            parser.add_argument(
+                "-w",
+                "--watch",
+                help="Watch events in real time",
+                action="store_true",
+            )
             parser.add_argument(
                 "--since",
                 help=(
@@ -106,7 +117,11 @@ def _list(self, args: argparse.Namespace):
 
         since = get_start_time(args.since)
         filters = _build_filters(args, self.api)
-        if since is not None:
+        if args.watch:
+            events = EventTracker(
+                client=self.api.client.events, filters=filters, since=since
+            ).stream_forever()
+        elif since is not None:
            events = EventPaginator(self.api.client.events).list(
                filters=filters, since=since, ascending=True
            )
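
For orientation, a minimal sketch of what the new `--watch` branch wires together. `api` (the command's initialized client) and `run_id` are placeholders, and the filters below mirror the ones constructed in the tests rather than the CLI's `_build_filters`; the real command presumably feeds the `stream_forever()` iterator into the same printing loop used by the non-watch branches below this hunk.

# Sketch only, not part of the diff: driving watch mode programmatically.
# `api` and `run_id` are hypothetical placeholders.
from dstack._internal.cli.services.events import EventListFilters, EventTracker, print_event

filters = EventListFilters(target_runs=[run_id])
tracker = EventTracker(client=api.client.events, filters=filters, since=None)
for event in tracker.stream_forever():  # blocks, polling the server on an interval
    print_event(event)
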
diff --git a/src/dstack/_internal/cli/services/events.py b/src/dstack/_internal/cli/services/events.py
index 0f0eb0f4b4..37ce9e1eb0 100644
--- a/src/dstack/_internal/cli/services/events.py
+++ b/src/dstack/_internal/cli/services/events.py
@@ -1,12 +1,13 @@
+import time
 import uuid
 from collections.abc import Iterator
 from dataclasses import asdict, dataclass
-from datetime import datetime
+from datetime import datetime, timedelta
 from typing import Optional
 
 from rich.text import Text
 
-from dstack._internal.cli.utils.common import console
+from dstack._internal.cli.utils.common import LIVE_TABLE_PROVISION_INTERVAL_SECS, console
 from dstack._internal.core.models.events import Event, EventTargetType
 from dstack._internal.server.schemas.events import LIST_EVENTS_DEFAULT_LIMIT
 from dstack.api.server._events import EventsAPIClient
@@ -50,6 +51,82 @@ def list(
             prev_recorded_at = events[-1].recorded_at
 
 
+class EventTracker:
+    """
+    Tracks new events from the server. Implements a sliding window mechanism to avoid
+    missing events that are committed with a delay.
+    """
+
+    def __init__(
+        self,
+        client: EventsAPIClient,
+        filters: EventListFilters,
+        since: Optional[datetime],
+        event_delay_tolerance: timedelta = timedelta(seconds=20),
+    ) -> None:
+        self._client = client
+        self._filters = filters
+        self._since = since
+        self._event_delay_tolerance = event_delay_tolerance
+        self._seen_events: dict[uuid.UUID, _SeenEvent] = {}
+        self._latest_event: Optional[Event] = None
+
+    def poll(self) -> Iterator[Event]:
+        """
+        Fetches the next batch of events from the server.
+        """
+
+        if self._since is None and self._latest_event is None:
+            # First batch without `since` - fetch some recent events
+            event_stream = reversed(self._client.list(ascending=False, **asdict(self._filters)))
+        else:
+            configured_since = self._since or datetime.fromtimestamp(0)
+            latest_event_recorded_at = (
+                self._latest_event.recorded_at
+                if self._latest_event is not None
+                else datetime.fromtimestamp(0)
+            )
+            since = max(
+                configured_since.astimezone(),
+                latest_event_recorded_at.astimezone() - self._event_delay_tolerance,
+            )
+            self._cleanup_seen_events(before=since)
+            event_stream = EventPaginator(self._client).list(self._filters, since, ascending=True)
+
+        for event in event_stream:
+            if event.id not in self._seen_events:
+                self._seen_events[event.id] = _SeenEvent(recorded_at=event.recorded_at)
+                yield event
+            self._latest_event = event
+
+    def stream_forever(
+        self,
+        update_interval: timedelta = timedelta(seconds=LIVE_TABLE_PROVISION_INTERVAL_SECS),
+    ) -> Iterator[Event]:
+        """
+        Yields events as they are received from the server.
+        """
+
+        while True:
+            for event in self.poll():
+                yield event
+            time.sleep(update_interval.total_seconds())
+
+    def _cleanup_seen_events(self, before: datetime) -> None:
+        ids_to_delete = {
+            event_id
+            for event_id, seen_event in self._seen_events.items()
+            if seen_event.recorded_at.astimezone() < before.astimezone()
+        }
+        for event_id in ids_to_delete:
+            del self._seen_events[event_id]
+
+
+@dataclass
+class _SeenEvent:
+    recorded_at: datetime
+
+
 def print_event(event: Event) -> None:
     recorded_at = event.recorded_at.astimezone().strftime("%Y-%m-%d %H:%M:%S")
     targets = ", ".join(f"{target.type} {target.name}" for target in event.targets)
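
The sliding window computed in `poll()` can be pinned down with concrete numbers. The snippet below is a standalone illustration with made-up timestamps; it only reproduces the `max(...)` expression from the method above.

# Standalone illustration of the polling window in EventTracker.poll() (made-up values).
from datetime import datetime, timedelta, timezone

configured_since = datetime(2023, 1, 1, 8, 0, tzinfo=timezone.utc)  # --since
latest_recorded_at = datetime(2023, 1, 1, 10, 0, tzinfo=timezone.utc)  # newest event seen so far
tolerance = timedelta(seconds=20)  # event_delay_tolerance

window_start = max(configured_since, latest_recorded_at - tolerance)
assert window_start == datetime(2023, 1, 1, 9, 59, 40, tzinfo=timezone.utc)
# Each poll re-requests everything recorded after 09:59:40, so an event committed late but
# recorded within the last 20 seconds is still picked up; already-yielded ids are skipped via
# _seen_events, and entries that fall out of the window are dropped by _cleanup_seen_events.
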
""" return CustomORJSONResponse( await events_services.list_events( diff --git a/src/tests/_internal/cli/services/test_events.py b/src/tests/_internal/cli/services/test_events.py new file mode 100644 index 0000000000..95aa4aaec5 --- /dev/null +++ b/src/tests/_internal/cli/services/test_events.py @@ -0,0 +1,349 @@ +import uuid +from dataclasses import asdict +from datetime import datetime, timedelta, timezone +from typing import Optional +from unittest.mock import MagicMock + +from dstack._internal.cli.services.events import EventListFilters, EventTracker +from dstack._internal.core.models.events import Event, EventTarget, EventTargetType +from dstack._internal.server.schemas.events import LIST_EVENTS_DEFAULT_LIMIT + + +class TestEventTracker: + def create_test_event( + self, + event_id: Optional[uuid.UUID] = None, + recorded_at: Optional[datetime] = None, + message: str = "Test event", + ) -> Event: + if event_id is None: + event_id = uuid.uuid4() + if recorded_at is None: + recorded_at = datetime.now(timezone.utc) + + return Event( + id=event_id, + message=message, + recorded_at=recorded_at, + actor_user_id=uuid.uuid4(), + actor_user="test_user", + targets=[ + EventTarget( + type=EventTargetType.RUN, + project_id=uuid.uuid4(), + project_name="test_project", + id=uuid.uuid4(), + name="test_run", + ) + ], + ) + + def test_poll_no_since(self): + mock_client = MagicMock() + filters = EventListFilters(target_runs=[uuid.uuid4()]) + + tracker = EventTracker( + client=mock_client, + filters=filters, + since=None, + event_delay_tolerance=timedelta(seconds=20), + ) + + # First poll - requests latest existing events + + event1 = self.create_test_event( + recorded_at=datetime(2023, 1, 1, 9, 0, tzinfo=timezone.utc) + ) + event2 = self.create_test_event( + recorded_at=datetime(2023, 1, 1, 10, 0, tzinfo=timezone.utc) + ) + mock_client.list.return_value = [event2, event1] # reversed due to ascending=False + + events = list(tracker.poll()) + + assert events == [event1, event2] + mock_client.list.assert_called_once_with( + ascending=False, + **asdict(filters), + ) + + # Second poll - requests events after the latest existing event + + mock_client.list.reset_mock() + mock_client.list.return_value = [] + + events = list(tracker.poll()) + + assert events == [] + mock_client.list.assert_called_once_with( + ascending=True, + **asdict(filters), + prev_recorded_at=event2.recorded_at - timedelta(seconds=20), + prev_id=None, + limit=LIST_EVENTS_DEFAULT_LIMIT, + ) + + def test_poll_with_since(self): + mock_client = MagicMock() + filters = EventListFilters(target_runs=[uuid.uuid4()]) + + tracker = EventTracker( + client=mock_client, + filters=filters, + since=datetime(2023, 1, 1, 8, 0, tzinfo=timezone.utc), + event_delay_tolerance=timedelta(seconds=20), + ) + + # First poll - requests events after `since` + + event1 = self.create_test_event( + recorded_at=datetime(2023, 1, 1, 9, 0, tzinfo=timezone.utc) + ) + event2 = self.create_test_event( + recorded_at=datetime(2023, 1, 1, 10, 0, tzinfo=timezone.utc) + ) + mock_client.list.return_value = [event1, event2] + + events = list(tracker.poll()) + + assert events == [event1, event2] + mock_client.list.assert_called_once_with( + ascending=True, + **asdict(filters), + prev_recorded_at=datetime(2023, 1, 1, 8, 0, tzinfo=timezone.utc), + prev_id=None, + limit=LIST_EVENTS_DEFAULT_LIMIT, + ) + + # Second poll - requests events after the latest event + + mock_client.list.reset_mock() + mock_client.list.return_value = [] + + events = list(tracker.poll()) + + assert events == [] + 
diff --git a/src/tests/_internal/cli/services/test_events.py b/src/tests/_internal/cli/services/test_events.py
new file mode 100644
index 0000000000..95aa4aaec5
--- /dev/null
+++ b/src/tests/_internal/cli/services/test_events.py
@@ -0,0 +1,349 @@
+import uuid
+from dataclasses import asdict
+from datetime import datetime, timedelta, timezone
+from typing import Optional
+from unittest.mock import MagicMock
+
+from dstack._internal.cli.services.events import EventListFilters, EventTracker
+from dstack._internal.core.models.events import Event, EventTarget, EventTargetType
+from dstack._internal.server.schemas.events import LIST_EVENTS_DEFAULT_LIMIT
+
+
+class TestEventTracker:
+    def create_test_event(
+        self,
+        event_id: Optional[uuid.UUID] = None,
+        recorded_at: Optional[datetime] = None,
+        message: str = "Test event",
+    ) -> Event:
+        if event_id is None:
+            event_id = uuid.uuid4()
+        if recorded_at is None:
+            recorded_at = datetime.now(timezone.utc)
+
+        return Event(
+            id=event_id,
+            message=message,
+            recorded_at=recorded_at,
+            actor_user_id=uuid.uuid4(),
+            actor_user="test_user",
+            targets=[
+                EventTarget(
+                    type=EventTargetType.RUN,
+                    project_id=uuid.uuid4(),
+                    project_name="test_project",
+                    id=uuid.uuid4(),
+                    name="test_run",
+                )
+            ],
+        )
+
+    def test_poll_no_since(self):
+        mock_client = MagicMock()
+        filters = EventListFilters(target_runs=[uuid.uuid4()])
+
+        tracker = EventTracker(
+            client=mock_client,
+            filters=filters,
+            since=None,
+            event_delay_tolerance=timedelta(seconds=20),
+        )
+
+        # First poll - requests latest existing events
+
+        event1 = self.create_test_event(
+            recorded_at=datetime(2023, 1, 1, 9, 0, tzinfo=timezone.utc)
+        )
+        event2 = self.create_test_event(
+            recorded_at=datetime(2023, 1, 1, 10, 0, tzinfo=timezone.utc)
+        )
+        mock_client.list.return_value = [event2, event1]  # reversed due to ascending=False
+
+        events = list(tracker.poll())
+
+        assert events == [event1, event2]
+        mock_client.list.assert_called_once_with(
+            ascending=False,
+            **asdict(filters),
+        )
+
+        # Second poll - requests events after the latest existing event
+
+        mock_client.list.reset_mock()
+        mock_client.list.return_value = []
+
+        events = list(tracker.poll())
+
+        assert events == []
+        mock_client.list.assert_called_once_with(
+            ascending=True,
+            **asdict(filters),
+            prev_recorded_at=event2.recorded_at - timedelta(seconds=20),
+            prev_id=None,
+            limit=LIST_EVENTS_DEFAULT_LIMIT,
+        )
+
+    def test_poll_with_since(self):
+        mock_client = MagicMock()
+        filters = EventListFilters(target_runs=[uuid.uuid4()])
+
+        tracker = EventTracker(
+            client=mock_client,
+            filters=filters,
+            since=datetime(2023, 1, 1, 8, 0, tzinfo=timezone.utc),
+            event_delay_tolerance=timedelta(seconds=20),
+        )
+
+        # First poll - requests events after `since`
+
+        event1 = self.create_test_event(
+            recorded_at=datetime(2023, 1, 1, 9, 0, tzinfo=timezone.utc)
+        )
+        event2 = self.create_test_event(
+            recorded_at=datetime(2023, 1, 1, 10, 0, tzinfo=timezone.utc)
+        )
+        mock_client.list.return_value = [event1, event2]
+
+        events = list(tracker.poll())
+
+        assert events == [event1, event2]
+        mock_client.list.assert_called_once_with(
+            ascending=True,
+            **asdict(filters),
+            prev_recorded_at=datetime(2023, 1, 1, 8, 0, tzinfo=timezone.utc),
+            prev_id=None,
+            limit=LIST_EVENTS_DEFAULT_LIMIT,
+        )
+
+        # Second poll - requests events after the latest event
+
+        mock_client.list.reset_mock()
+        mock_client.list.return_value = []
+
+        events = list(tracker.poll())
+
+        assert events == []
+        mock_client.list.assert_called_once_with(
+            ascending=True,
+            **asdict(filters),
+            prev_recorded_at=event2.recorded_at - timedelta(seconds=20),
+            prev_id=None,
+            limit=LIST_EVENTS_DEFAULT_LIMIT,
+        )
+
+    def test_poll_with_since_never_uses_prev_recorded_at_earlier_than_since(self):
+        mock_client = MagicMock()
+        filters = EventListFilters(target_runs=[uuid.uuid4()])
+        since = datetime(2023, 1, 1, 10, 0, tzinfo=timezone.utc)
+
+        tracker = EventTracker(
+            client=mock_client,
+            filters=filters,
+            since=datetime(2023, 1, 1, 10, 0, tzinfo=timezone.utc),
+            event_delay_tolerance=timedelta(seconds=20),
+        )
+
+        # First poll - returns an event that is 5 seconds newer than `since`
+
+        event1 = self.create_test_event(recorded_at=since + timedelta(seconds=5))
+        mock_client.list.return_value = [event1]
+
+        events = list(tracker.poll())
+
+        assert events == [event1]
+        mock_client.list.assert_called_once_with(
+            ascending=True,
+            **asdict(filters),
+            prev_recorded_at=since,
+            prev_id=None,
+            limit=LIST_EVENTS_DEFAULT_LIMIT,
+        )
+
+        # Second poll - prev_recorded_at should still be `since` (not event1 - 20s)
+
+        mock_client.list.reset_mock()
+        mock_client.list.return_value = []
+
+        events = list(tracker.poll())
+
+        assert events == []
+        mock_client.list.assert_called_once_with(
+            ascending=True,
+            **asdict(filters),
+            prev_recorded_at=since,
+            prev_id=None,
+            limit=LIST_EVENTS_DEFAULT_LIMIT,
+        )
+
+    def test_poll_no_since_always_empty_response(self):
+        mock_client = MagicMock()
+        filters = EventListFilters(target_runs=[uuid.uuid4()])
+
+        tracker = EventTracker(
+            client=mock_client,
+            filters=filters,
+            since=None,
+            event_delay_tolerance=timedelta(seconds=20),
+        )
+
+        for _ in range(2):
+            mock_client.list.reset_mock()
+            mock_client.list.return_value = []
+            events = list(tracker.poll())
+            assert events == []
+            mock_client.list.assert_called_once_with(
+                ascending=False,
+                **asdict(filters),
+            )
+
+    def test_poll_with_since_always_empty_response(self):
+        mock_client = MagicMock()
+        filters = EventListFilters(target_runs=[uuid.uuid4()])
+
+        tracker = EventTracker(
+            client=mock_client,
+            filters=filters,
+            since=datetime(2023, 1, 1, 8, 0, tzinfo=timezone.utc),
+            event_delay_tolerance=timedelta(seconds=20),
+        )
+
+        for _ in range(2):
+            mock_client.list.reset_mock()
+            mock_client.list.return_value = []
+            events = list(tracker.poll())
+            assert events == []
+            mock_client.list.assert_called_once_with(
+                ascending=True,
+                **asdict(filters),
+                prev_recorded_at=datetime(2023, 1, 1, 8, 0, tzinfo=timezone.utc),
+                prev_id=None,
+                limit=LIST_EVENTS_DEFAULT_LIMIT,
+            )
+
+    def test_poll_event_deduplication(self):
+        mock_client = MagicMock()
+        filters = EventListFilters(target_runs=[uuid.uuid4()])
+
+        tracker = EventTracker(
+            client=mock_client,
+            filters=filters,
+            since=datetime(2023, 1, 1, 8, 0, tzinfo=timezone.utc),
+            event_delay_tolerance=timedelta(seconds=20),
+        )
+
+        # First poll - returns event1 and event2
+
+        event1 = self.create_test_event(
+            recorded_at=datetime(2023, 1, 1, 9, 0, tzinfo=timezone.utc)
+        )
+        event2 = self.create_test_event(
+            recorded_at=datetime(2023, 1, 1, 10, 0, tzinfo=timezone.utc)
+        )
+        mock_client.list.return_value = [event1, event2]
+
+        events = list(tracker.poll())
+
+        assert events == [event1, event2]
+        mock_client.list.assert_called_once_with(
+            ascending=True,
+            **asdict(filters),
+            prev_recorded_at=datetime(2023, 1, 1, 8, 0, tzinfo=timezone.utc),
+            prev_id=None,
+            limit=LIST_EVENTS_DEFAULT_LIMIT,
+        )
+
+        # Second poll - returns event2 (duplicate) and event3 (new)
+
+        mock_client.list.reset_mock()
+        event3 = self.create_test_event(
+            recorded_at=datetime(2023, 1, 1, 11, 0, tzinfo=timezone.utc)
+        )
+        mock_client.list.return_value = [event2, event3]
+
+        events = list(tracker.poll())
+
+        assert events == [event3]  # does not return duplicate event2
+        mock_client.list.assert_called_once_with(
+            ascending=True,
+            **asdict(filters),
+            prev_recorded_at=event2.recorded_at - timedelta(seconds=20),
+            prev_id=None,
+            limit=LIST_EVENTS_DEFAULT_LIMIT,
+        )
+
+    def test_poll_respects_pagination(self):
+        mock_client = MagicMock()
+        filters = EventListFilters(target_runs=[uuid.uuid4()])
+
+        tracker = EventTracker(
+            client=mock_client,
+            filters=filters,
+            since=datetime(2023, 1, 1, 8, 0, tzinfo=timezone.utc),
+            event_delay_tolerance=timedelta(seconds=20),
+        )
+
+        ###
+        # First poll - create (1.5 * default limit) events
+        ###
+
+        num_events = int(LIST_EVENTS_DEFAULT_LIMIT * 1.5)
+        events = [
+            self.create_test_event(
+                recorded_at=datetime(2023, 1, 1, 9, 0, tzinfo=timezone.utc) + timedelta(seconds=i)
+            )
+            for i in range(num_events)
+        ]
+
+        # Mock pagination: first call returns first batch, second call returns remaining events
+        call_count = 0
+
+        def mock_list(**kwargs):
+            nonlocal call_count
+            call_count += 1
+            if call_count == 1:
+                return events[:LIST_EVENTS_DEFAULT_LIMIT]  # First batch
+            elif call_count == 2:
+                return events[LIST_EVENTS_DEFAULT_LIMIT:]  # Remaining events
+            else:
+                return []
+
+        mock_client.list.side_effect = mock_list
+
+        result_events = list(tracker.poll())
+
+        assert result_events == events
+        assert mock_client.list.call_count == 2
+
+        # Verify first call
+        first_call = mock_client.list.call_args_list[0]
+        assert first_call[1]["ascending"] is True
+        assert first_call[1]["prev_recorded_at"] == datetime(2023, 1, 1, 8, 0, tzinfo=timezone.utc)
+        assert first_call[1]["prev_id"] is None
+        assert first_call[1]["limit"] == LIST_EVENTS_DEFAULT_LIMIT
+
+        # Verify second call (pagination)
+        second_call = mock_client.list.call_args_list[1]
+        assert second_call[1]["ascending"] is True
+        assert (
+            second_call[1]["prev_recorded_at"] == events[LIST_EVENTS_DEFAULT_LIMIT - 1].recorded_at
+        )
+        assert second_call[1]["prev_id"] == events[LIST_EVENTS_DEFAULT_LIMIT - 1].id
+        assert second_call[1]["limit"] == LIST_EVENTS_DEFAULT_LIMIT
+
+        ###
+        # Second poll - should make one call for new events
+        ###
+
+        mock_client.reset_mock()
+        mock_client.list.return_value = []
+
+        result_events = list(tracker.poll())
+
+        assert result_events == []
+        mock_client.list.assert_called_once_with(
+            ascending=True,
+            **asdict(filters),
+            prev_recorded_at=events[-1].recorded_at - timedelta(seconds=20),
+            prev_id=None,
+            limit=LIST_EVENTS_DEFAULT_LIMIT,
+        )