From 4e7b1a36556fcda99c39e47c81bda261b6e81196 Mon Sep 17 00:00:00 2001 From: Morgan Wowk Date: Mon, 12 Jan 2026 12:47:34 -0800 Subject: [PATCH] feat: Introduce trace ids to Tangle **Changes:** * Adds logging context helpers * Add request middleware to generate unique request id and set it in the logging context around API requests * Sets the x-tangle-request-id on the response for client consumption --- api_server_main.py | 12 +- cloud_pipelines_backend/api_router.py | 15 +- .../instrumentation/logging_context.py | 111 +++++++ .../instrumentation/request_middleware.py | 64 ++++ cloud_pipelines_backend/orchestrator_sql.py | 125 ++++---- start_local.py | 64 +++- tests/test_instrumentation_logging_context.py | 168 ++++++++++ ...test_instrumentation_request_middleware.py | 303 ++++++++++++++++++ tests/test_request_id_concurrency.py | 234 ++++++++++++++ 9 files changed, 1037 insertions(+), 59 deletions(-) create mode 100644 cloud_pipelines_backend/instrumentation/logging_context.py create mode 100644 cloud_pipelines_backend/instrumentation/request_middleware.py create mode 100644 tests/test_instrumentation_logging_context.py create mode 100644 tests/test_instrumentation_request_middleware.py create mode 100644 tests/test_request_id_concurrency.py diff --git a/api_server_main.py b/api_server_main.py index 18562c6..7011347 100644 --- a/api_server_main.py +++ b/api_server_main.py @@ -5,6 +5,8 @@ from cloud_pipelines_backend import api_router from cloud_pipelines_backend import database_ops +from cloud_pipelines_backend.instrumentation.request_middleware import RequestContextMiddleware +from cloud_pipelines_backend.instrumentation import logging_context app = fastapi.FastAPI( title="Cloud Pipelines API", @@ -12,14 +14,22 @@ separate_input_output_schemas=False, ) +# Add request context middleware for automatic request_id generation +app.add_middleware(RequestContextMiddleware) + @app.exception_handler(Exception) def handle_error(request: fastapi.Request, exc: BaseException): exception_str = traceback.format_exception(type(exc), exc, exc.__traceback__) - return fastapi.responses.JSONResponse( + response = fastapi.responses.JSONResponse( status_code=503, content={"exception": exception_str}, ) + # Add request_id to error responses for traceability + request_id = logging_context.get_context_metadata("request_id") + if request_id: + response.headers["x-tangle-request-id"] = request_id + return response DEFAULT_DATABASE_URI = "sqlite:///db.sqlite" diff --git a/cloud_pipelines_backend/api_router.py b/cloud_pipelines_backend/api_router.py index 6652637..9b49bb4 100644 --- a/cloud_pipelines_backend/api_router.py +++ b/cloud_pipelines_backend/api_router.py @@ -15,6 +15,7 @@ from . import component_library_api_server as components_api from . import database_ops from . 
import errors +from .instrumentation import logging_context if typing.TYPE_CHECKING: from .launchers import interfaces as launcher_interfaces @@ -95,17 +96,27 @@ def _setup_routes_internal( @app.exception_handler(errors.ItemNotFoundError) def handle_not_found_error(request: fastapi.Request, exc: errors.ItemNotFoundError): - return fastapi.responses.JSONResponse( + response = fastapi.responses.JSONResponse( status_code=404, content={"message": str(exc)}, ) + # Add request_id to error responses for traceability + request_id = logging_context.get_context_metadata("request_id") + if request_id: + response.headers["x-tangle-request-id"] = request_id + return response @app.exception_handler(errors.PermissionError) def handle_permission_error(request: fastapi.Request, exc: errors.PermissionError): - return fastapi.responses.JSONResponse( + response = fastapi.responses.JSONResponse( status_code=403, content={"message": str(exc)}, ) + # Add request_id to error responses for traceability + request_id = logging_context.get_context_metadata("request_id") + if request_id: + response.headers["x-tangle-request-id"] = request_id + return response get_user_details_dependency = fastapi.Depends(user_details_getter) diff --git a/cloud_pipelines_backend/instrumentation/logging_context.py b/cloud_pipelines_backend/instrumentation/logging_context.py new file mode 100644 index 0000000..7237734 --- /dev/null +++ b/cloud_pipelines_backend/instrumentation/logging_context.py @@ -0,0 +1,111 @@ +"""Logging context management for distributed tracing and execution tracking. + +This module provides utilities for managing arbitrary metadata in the logging context. +This metadata is automatically added to all log records for better filtering and correlation. + +Common metadata keys: +- request_id: From API requests - groups all logs from a single API call +- pipeline_run_id: From PipelineRun.id - tracks the entire pipeline run +- execution_id: From ExecutionNode.id - tracks individual execution nodes +- container_execution_id: From ContainerExecution.id - tracks running containers +- user_id: User who initiated the operation +- Any other metadata you want to track in logs + +Usage: + # Set metadata in context + with logging_context(request_id="abc123", user_id="user@example.com"): + logger.info("Processing") # Both fields in logs +""" + +import contextvars +from contextlib import contextmanager +from typing import Any, Optional + + +# Single context variable to store all metadata as a dictionary +_context_metadata: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar( + "context_metadata", default={} +) + + +def set_context_metadata(key: str, value: Any) -> None: + """Set a metadata value in the current context. + + Args: + key: The metadata key (e.g., 'execution_id', 'request_id', 'user_id') + value: The value to set + """ + metadata = _context_metadata.get().copy() + metadata[key] = value + _context_metadata.set(metadata) + + +def get_context_metadata(key: str) -> Optional[Any]: + """Get a metadata value from the current context. + + Args: + key: The metadata key to retrieve + + Returns: + The metadata value or None if not set + """ + return _context_metadata.get().get(key) + + +def get_all_context_metadata() -> dict[str, Any]: + """Get all metadata from the current context. 
+ + Returns: + Dictionary of all context metadata + """ + return _context_metadata.get().copy() + + +def clear_context_metadata() -> None: + """Clear all metadata from the current context.""" + _context_metadata.set({}) + + +@contextmanager +def logging_context(**metadata: Any): + """Context manager for setting arbitrary metadata that is automatically cleared. + + This is the recommended way to set logging context. It ensures metadata is + always cleaned up, even if an exception occurs. + + You can pass any keyword arguments, and they will be available in log records. + Common keys include: request_id, pipeline_run_id, execution_id, container_execution_id, user_id + + Args: + **metadata: Arbitrary keyword arguments to add to the context + + Example with IDs: + >>> with logging_context(pipeline_run_id="run123", execution_id="exec456"): + ... logger.info("Processing execution") # Will include both IDs + + Example with custom metadata: + >>> with logging_context( + ... execution_id="exec456", + ... user_id="user@example.com", + ... operation="reprocessing" + ... ): + ... logger.info("Custom operation") # All metadata in logs + + Example for API requests: + >>> request_id = generate_request_id() + >>> with logging_context(request_id=request_id): + ... logger.info("Handling API request") + """ + # Store previous metadata to restore nested contexts + prev_metadata = get_all_context_metadata() + + try: + # Set all provided metadata + for key, value in metadata.items(): + if value is not None: # Only set non-None values + set_context_metadata(key, value) + yield + finally: + # Restore previous metadata + _context_metadata.set(prev_metadata) + diff --git a/cloud_pipelines_backend/instrumentation/request_middleware.py b/cloud_pipelines_backend/instrumentation/request_middleware.py new file mode 100644 index 0000000..ed524d0 --- /dev/null +++ b/cloud_pipelines_backend/instrumentation/request_middleware.py @@ -0,0 +1,64 @@ +"""Request context middleware for FastAPI applications. + +This middleware automatically generates a request_id for each incoming HTTP request, +sets it in the logging context for the duration of the request, and includes it in +the response headers. +""" + +import logging +import secrets + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response + +from . import logging_context + +logger = logging.getLogger(__name__) + + +def generate_request_id() -> str: + """Generate a new request ID compatible with OpenTelemetry format. + + OpenTelemetry trace IDs are 16-byte (128-bit) values represented as + 32 hexadecimal characters (lowercase). We use the same format for + request IDs to maintain compatibility. + + Returns: + A 32-character hexadecimal string representing the request ID + """ + return secrets.token_hex(16) + + +class RequestContextMiddleware(BaseHTTPMiddleware): + """Middleware to manage request_id for each request. + + For each incoming request: + 1. Generates a new request_id (32-character hex string) + 2. Sets it in the logging context (as 'request_id' key) + 3. Adds it to the response headers as 'x-tangle-request-id' + 4. Clears it after the request completes + + This ensures all logs during the request processing include the same request_id. + """ + + async def dispatch(self, request: Request, call_next) -> Response: + """Process each request with a new request_id. 
+ + Args: + request: The incoming HTTP request + call_next: The next middleware or route handler + + Returns: + The HTTP response with request_id in headers + """ + # Generate a new request_id for this request + request_id = generate_request_id() + + # Use generic logging_context to set request_id + with logging_context.logging_context(request_id=request_id): + # Process the request + response = await call_next(request) + # Add request_id to response headers for client reference + response.headers["x-tangle-request-id"] = request_id + return response diff --git a/cloud_pipelines_backend/orchestrator_sql.py b/cloud_pipelines_backend/orchestrator_sql.py index 282d0c4..3c0e3f7 100644 --- a/cloud_pipelines_backend/orchestrator_sql.py +++ b/cloud_pipelines_backend/orchestrator_sql.py @@ -21,6 +21,8 @@ from .launchers import common_annotations from .launchers import interfaces as launcher_interfaces +# Import logging_context for execution ID tracking in logs +from .instrumentation import logging_context _logger = logging.getLogger(__name__) @@ -94,20 +96,26 @@ def internal_process_queued_executions_queue(self, session: orm.Session): queued_execution = session.scalar(query) if queued_execution: self._queued_executions_queue_idle = False - _logger.info(f"Before processing {queued_execution.id=}") - try: - self.internal_process_one_queued_execution( - session=session, execution=queued_execution - ) - except Exception as ex: - _logger.exception(f"Error processing {queued_execution.id=}") - session.rollback() - queued_execution.container_execution_status = ( - bts.ContainerExecutionStatus.SYSTEM_ERROR - ) - record_system_error_exception(execution=queued_execution, exception=ex) - session.commit() - _logger.info(f"After processing {queued_execution.id=}") + + # Set execution context for logging + with logging_context.logging_context( + execution_id=queued_execution.id + ): + _logger.info("Before processing queued execution") + try: + self.internal_process_one_queued_execution( + session=session, execution=queued_execution + ) + except Exception as ex: + _logger.exception("Error processing queued execution") + session.rollback() + queued_execution.container_execution_status = ( + bts.ContainerExecutionStatus.SYSTEM_ERROR + ) + record_system_error_exception(execution=queued_execution, exception=ex) + session.commit() + _logger.info("After processing queued execution") + return True else: if not self._queued_executions_queue_idle: @@ -132,37 +140,46 @@ def internal_process_running_executions_queue(self, session: orm.Session): running_container_execution = session.scalar(query) if running_container_execution: self._running_executions_queue_idle = False - try: - _logger.info(f"Before processing {running_container_execution.id=}") - self.internal_process_one_running_execution( - session=session, container_execution=running_container_execution - ) - _logger.info(f"After processing {running_container_execution.id=}") - except Exception as ex: - _logger.exception(f"Error processing {running_container_execution.id=}") - session.rollback() - running_container_execution.status = ( - bts.ContainerExecutionStatus.SYSTEM_ERROR - ) - # Doing an intermediate commit here because it's most important to mark the problematic execution as SYSTEM_ERROR. 
- session.commit() - # Mark our ExecutionNode as SYSTEM_ERROR - execution_nodes = running_container_execution.execution_nodes - for execution_node in execution_nodes: - execution_node.container_execution_status = ( - bts.ContainerExecutionStatus.SYSTEM_ERROR - ) - record_system_error_exception( - execution=execution_node, exception=ex + + # Set execution context for logging (includes container_execution_id) + # Get first execution_id for context (there may be multiple nodes using same container) + execution_nodes = running_container_execution.execution_nodes + execution_id = execution_nodes[0].id if execution_nodes else None + + with logging_context.logging_context( + execution_id=execution_id, + container_execution_id=running_container_execution.id + ): + _logger.info("Before processing running container execution") + try: + self.internal_process_one_running_execution( + session=session, container_execution=running_container_execution ) - # Doing an intermediate commit here because it's most important to mark the problematic node as SYSTEM_ERROR. - session.commit() - # Skip downstream executions - for execution_node in execution_nodes: - _mark_all_downstream_executions_as_skipped( - session=session, execution=execution_node + except Exception as ex: + _logger.exception("Error processing running container execution") + session.rollback() + running_container_execution.status = ( + bts.ContainerExecutionStatus.SYSTEM_ERROR ) - session.commit() + # Doing an intermediate commit here because it's most important to mark the problematic execution as SYSTEM_ERROR. + session.commit() + # Mark our ExecutionNode as SYSTEM_ERROR + for execution_node in execution_nodes: + execution_node.container_execution_status = ( + bts.ContainerExecutionStatus.SYSTEM_ERROR + ) + record_system_error_exception( + execution=execution_node, exception=ex + ) + # Doing an intermediate commit here because it's most important to mark the problematic node as SYSTEM_ERROR. + session.commit() + # Skip downstream executions + for execution_node in execution_nodes: + _mark_all_downstream_executions_as_skipped( + session=session, execution=execution_node + ) + session.commit() + _logger.info("After processing running container execution") return True else: if not self._running_executions_queue_idle: @@ -276,7 +293,7 @@ def internal_process_one_queued_execution( # There must be at least one SUCCEEDED/RUNNING/PENDING since non_purged_candidates is non-empty. old_execution = non_purged_candidates[-1] _logger.info( - f"Execution {execution.id=} will reuse the {old_execution.id=} with " + f"Reusing execution from cache \"{old_execution.id}\" with " f"{old_execution.container_execution_id=}, {old_execution.container_execution_status=}" ) # Reusing the execution: @@ -577,18 +594,18 @@ def internal_process_one_running_execution( terminated = False if votes_to_not_terminate: _logger.info( - f"Not terminating container execution {container_execution.id=} since some other executions ({[execution_node.id for execution_node in votes_to_not_terminate]}) are still using it." + f"Not terminating container execution since some other executions ({[execution_node.id for execution_node in votes_to_not_terminate]}) are still using it." ) else: _logger.info( - f"Terminating container execution {container_execution.id}." + "Terminating container execution." 
) # We should preserve the logs before terminating/deleting the container try: _retry(lambda: launched_container.upload_log()) except: _logger.exception( - f"Error uploading logs for {container_execution.id=} before termination." + "Error uploading logs before termination." ) # Requesting container termination. # Termination might not happen immediately (e.g. Kubernetes has grace period). @@ -601,7 +618,7 @@ def internal_process_one_running_execution( # Mark the execution nodes as cancelled only after the launched container is successfully terminated (if needed) for execution_node in votes_to_terminate: _logger.info( - f"Cancelling execution {execution_node.id} ({container_execution.id=}) and skipping all downstream executions." + f"Cancelling execution {execution_node.id} and skipping all downstream executions." ) execution_node.container_execution_status = ( bts.ContainerExecutionStatus.CANCELLED @@ -630,23 +647,23 @@ def internal_process_one_running_execution( ) if new_status == previous_status: _logger.info( - f"Container execution {container_execution.id} remains in {new_status} state." + f"Container execution remains in {new_status} state." ) return _logger.info( - f"Container execution {container_execution.id} is now in state {new_status} (was {previous_status})." + f"Container execution is now in state {new_status} (was {previous_status})." ) session.rollback() container_execution.updated_at = current_time execution_nodes = container_execution.execution_nodes if not execution_nodes: raise OrchestratorError( - f"Could not find ExecutionNode associated with ContainerExecution. {container_execution.id=}" + f"Could not find ExecutionNode associated with ContainerExecution." ) if len(execution_nodes) > 1: execution_node_ids = [execution.id for execution in execution_nodes] _logger.warning( - f"ContainerExecution is associated with multiple ExecutionNodes: {container_execution.id=}, {execution_node_ids=}" + f"ContainerExecution is associated with multiple ExecutionNodes: {execution_node_ids=}" ) if new_status == launcher_interfaces.ContainerStatus.RUNNING: @@ -715,7 +732,7 @@ def _maybe_preload_value( if missing_output_names: # Marking the container execution as FAILED (even though the program itself has completed successfully) container_execution.status = bts.ContainerExecutionStatus.FAILED - orchestration_error_message = f"Container execution {container_execution.id} is marked as FAILED due to missing outputs: {missing_output_names}." + orchestration_error_message = f"Container execution is marked as FAILED due to missing outputs: {missing_output_names}." _logger.error(orchestration_error_message) _record_orchestration_error_message( container_execution=container_execution, @@ -816,7 +833,7 @@ def _maybe_preload_value( ) else: _logger.error( - f"Container execution {container_execution.id} is now in unexpected state {new_status}. System error. {container_execution=}" + f"Container execution is now in unexpected state {new_status}. System error. 
{container_execution=}" ) # This SYSTEM_ERROR will be handled by the outer exception handler raise OrchestratorError( diff --git a/start_local.py b/start_local.py index e98ae69..52c8742 100644 --- a/start_local.py +++ b/start_local.py @@ -78,19 +78,70 @@ def get_user_details(request: fastapi.Request): # region: Logging configuration import logging.config +from cloud_pipelines_backend.instrumentation.logging_context import get_all_context_metadata + +class LoggingContextFilter(logging.Filter): + """Logging filter that adds contextual metadata to log records. + + This filter automatically adds metadata like execution_id and container_execution_id + to log records, making it easier to trace logs for specific executions. + """ + + def filter(self, record: logging.LogRecord) -> bool: + """Add contextual metadata to the log record.""" + for key, value in get_all_context_metadata().items(): + if value is not None: + setattr(record, key, value) + return True + + +class ContextAwareFormatter(logging.Formatter): + """Formatter that dynamically includes context fields only when they're set.""" + + def format(self, record: logging.LogRecord) -> str: + """Format log record with dynamic context fields.""" + # Base format + base_format = "%(asctime)s [%(levelname)s] %(name)s" + + # Collect context fields that are present + context_parts = [] + context_metadata = get_all_context_metadata() + for key, value in context_metadata.items(): + if value is not None and hasattr(record, key): + context_parts.append(f"{key}={value}") + + # Add context to format if any exists + if context_parts: + base_format += " [" + " ".join(context_parts) + "]" + + base_format += ": %(message)s" + + # Create formatter with the dynamic format + formatter = logging.Formatter(base_format) + return formatter.format(record) + LOGGING_CONFIG = { "version": 1, "disable_existing_loggers": True, "formatters": { "standard": {"format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s"}, + "with_context": { + "()": ContextAwareFormatter, + }, + }, + "filters": { + "context_filter": { + "()": LoggingContextFilter, + }, }, "handlers": { "default": { "level": "INFO", - "formatter": "standard", + "formatter": "with_context", "class": "logging.StreamHandler", "stream": "ext://sys.stderr", + "filters": ["context_filter"], }, }, "loggers": { @@ -205,6 +256,8 @@ def run_orchestrator( from cloud_pipelines_backend import api_router from cloud_pipelines_backend import database_ops +from cloud_pipelines_backend.instrumentation.request_middleware import RequestContextMiddleware +from cloud_pipelines_backend.instrumentation import logging_context @contextlib.asynccontextmanager @@ -230,14 +283,21 @@ async def lifespan(app: fastapi.FastAPI): lifespan=lifespan, ) +# Add request context middleware for automatic request_id generation +app.add_middleware(RequestContextMiddleware) @app.exception_handler(Exception) def handle_error(request: fastapi.Request, exc: BaseException): exception_str = traceback.format_exception(type(exc), exc, exc.__traceback__) - return fastapi.responses.JSONResponse( + response = fastapi.responses.JSONResponse( status_code=503, content={"exception": exception_str}, ) + # Add request_id to error responses for traceability + request_id = logging_context.get_context_metadata("request_id") + if request_id: + response.headers["x-tangle-request-id"] = request_id + return response api_router.setup_routes( diff --git a/tests/test_instrumentation_logging_context.py b/tests/test_instrumentation_logging_context.py new file mode 100644 index 
0000000..d92b5a9 --- /dev/null +++ b/tests/test_instrumentation_logging_context.py @@ -0,0 +1,168 @@ +"""Tests for the logging_context module in instrumentation.""" + +import pytest +from cloud_pipelines_backend.instrumentation import logging_context +from cloud_pipelines_backend.instrumentation.request_middleware import generate_request_id + + +class TestLoggingContext: + """Tests for logging context management.""" + + def setup_method(self): + """Clear any existing context before each test.""" + logging_context.clear_context_metadata() + + def teardown_method(self): + """Clear context after each test.""" + logging_context.clear_context_metadata() + + def test_set_and_get_context_metadata(self): + """Test setting and getting context metadata.""" + test_id = "abc123def456abc123def456abc12345" + + logging_context.set_context_metadata("request_id", test_id) + + assert logging_context.get_context_metadata("request_id") == test_id + + def test_get_context_metadata_returns_none_when_not_set(self): + """Test that get_context_metadata returns None when key is not set.""" + assert logging_context.get_context_metadata("request_id") is None + + def test_clear_context_metadata(self): + """Test clearing all context metadata.""" + logging_context.set_context_metadata("request_id", "test123") + logging_context.set_context_metadata("execution_id", "exec456") + logging_context.clear_context_metadata() + + assert logging_context.get_context_metadata("request_id") is None + assert logging_context.get_context_metadata("execution_id") is None + + def test_overwrite_context_metadata(self): + """Test that setting a new value overwrites the old one.""" + logging_context.set_context_metadata("request_id", "first_id") + logging_context.set_context_metadata("request_id", "second_id") + + assert logging_context.get_context_metadata("request_id") == "second_id" + + def test_get_all_context_metadata(self): + """Test getting all context metadata at once.""" + logging_context.set_context_metadata("request_id", "req123") + logging_context.set_context_metadata("execution_id", "exec456") + logging_context.set_context_metadata("custom_field", "value789") + + all_metadata = logging_context.get_all_context_metadata() + + assert all_metadata["request_id"] == "req123" + assert all_metadata["execution_id"] == "exec456" + assert all_metadata["custom_field"] == "value789" + + +class TestLoggingContextManager: + """Tests for the logging_context context manager.""" + + def setup_method(self): + """Clear any existing context before each test.""" + logging_context.clear_context_metadata() + + def teardown_method(self): + """Clear context after each test.""" + logging_context.clear_context_metadata() + + def test_context_manager_sets_and_restores_metadata(self): + """Test that context manager sets metadata on enter and restores on exit.""" + test_id = "context_test_123" + + assert logging_context.get_context_metadata("request_id") is None + + with logging_context.logging_context(request_id=test_id): + assert logging_context.get_context_metadata("request_id") == test_id + + assert logging_context.get_context_metadata("request_id") is None + + def test_context_manager_with_multiple_keys(self): + """Test that context manager handles multiple metadata keys.""" + with logging_context.logging_context( + request_id="req123", + execution_id="exec456", + pipeline_run_id="run789" + ): + assert logging_context.get_context_metadata("request_id") == "req123" + assert logging_context.get_context_metadata("execution_id") == "exec456" + assert 
logging_context.get_context_metadata("pipeline_run_id") == "run789" + + assert logging_context.get_context_metadata("request_id") is None + assert logging_context.get_context_metadata("execution_id") is None + assert logging_context.get_context_metadata("pipeline_run_id") is None + + def test_context_manager_with_none_values(self): + """Test that context manager skips None values.""" + with logging_context.logging_context(request_id="req123", execution_id=None): + assert logging_context.get_context_metadata("request_id") == "req123" + assert logging_context.get_context_metadata("execution_id") is None + + assert logging_context.get_context_metadata("request_id") is None + + def test_context_manager_clears_on_exception(self): + """Test that context manager restores metadata even when exception occurs.""" + test_id = "exception_test" + + with pytest.raises(ValueError): + with logging_context.logging_context(request_id=test_id): + assert logging_context.get_context_metadata("request_id") == test_id + raise ValueError("Test exception") + + # Metadata should be cleared even after exception + assert logging_context.get_context_metadata("request_id") is None + + def test_context_manager_nested(self): + """Test nested context managers.""" + outer_id = "outer_id" + inner_id = "inner_id" + + with logging_context.logging_context(request_id=outer_id): + assert logging_context.get_context_metadata("request_id") == outer_id + + with logging_context.logging_context(request_id=inner_id): + assert logging_context.get_context_metadata("request_id") == inner_id + + # After inner context exits, outer context is restored + assert logging_context.get_context_metadata("request_id") == outer_id + + assert logging_context.get_context_metadata("request_id") is None + + def test_context_manager_with_generated_request_id(self): + """Test using context manager with a generated request_id.""" + generated_id = generate_request_id() + + with logging_context.logging_context(request_id=generated_id): + assert logging_context.get_context_metadata("request_id") == generated_id + assert len(logging_context.get_context_metadata("request_id")) == 32 + + assert logging_context.get_context_metadata("request_id") is None + + def test_context_manager_multiple_sequential_uses(self): + """Test using context manager multiple times sequentially.""" + ids = ["id1", "id2", "id3"] + + for test_id in ids: + with logging_context.logging_context(request_id=test_id): + assert logging_context.get_context_metadata("request_id") == test_id + assert logging_context.get_context_metadata("request_id") is None + + def test_context_manager_preserves_existing_metadata(self): + """Test that nested context preserves existing metadata not being overwritten.""" + with logging_context.logging_context(request_id="req123", execution_id="exec456"): + assert logging_context.get_context_metadata("request_id") == "req123" + assert logging_context.get_context_metadata("execution_id") == "exec456" + + # Inner context only sets pipeline_run_id + with logging_context.logging_context(pipeline_run_id="run789"): + # Previous values should still be accessible + assert logging_context.get_context_metadata("request_id") == "req123" + assert logging_context.get_context_metadata("execution_id") == "exec456" + assert logging_context.get_context_metadata("pipeline_run_id") == "run789" + + # After inner exits, pipeline_run_id is gone but others remain + assert logging_context.get_context_metadata("request_id") == "req123" + assert 
logging_context.get_context_metadata("execution_id") == "exec456" + assert logging_context.get_context_metadata("pipeline_run_id") is None diff --git a/tests/test_instrumentation_request_middleware.py b/tests/test_instrumentation_request_middleware.py new file mode 100644 index 0000000..d59e5fc --- /dev/null +++ b/tests/test_instrumentation_request_middleware.py @@ -0,0 +1,303 @@ +"""Tests for the request_middleware module in instrumentation.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock +from starlette.requests import Request +from starlette.responses import Response +from starlette.applications import Starlette +from starlette.testclient import TestClient + +from cloud_pipelines_backend.instrumentation import logging_context +from cloud_pipelines_backend.instrumentation.request_middleware import ( + RequestContextMiddleware, + generate_request_id, +) + + +class TestRequestIdGeneration: + """Tests for request_id generation.""" + + def test_generate_request_id_returns_32_char_hex(self): + """Test that generated request_id is 32 hexadecimal characters.""" + request_id = generate_request_id() + + assert len(request_id) == 32 + assert all(c in "0123456789abcdef" for c in request_id) + + def test_generate_request_id_is_unique(self): + """Test that each generated request_id is unique.""" + request_ids = {generate_request_id() for _ in range(100)} + + # All 100 should be unique + assert len(request_ids) == 100 + + def test_generate_request_id_is_lowercase(self): + """Test that generated request_id uses lowercase hex.""" + request_id = generate_request_id() + + assert request_id == request_id.lower() + + +class TestRequestIdFormatting: + """Tests for request_id format validation.""" + + def test_generated_request_id_format(self): + """Test that generated request_id matches expected format.""" + request_id = generate_request_id() + + # Should be 32 characters + assert len(request_id) == 32 + + # Should be valid hex + try: + int(request_id, 16) + except ValueError: + pytest.fail("request_id is not valid hexadecimal") + + # Should be lowercase + assert request_id.islower() + + def test_request_id_is_128_bits(self): + """Test that request_id represents 128 bits (16 bytes).""" + request_id = generate_request_id() + + # 32 hex characters = 16 bytes = 128 bits + assert len(bytes.fromhex(request_id)) == 16 + + +class TestRequestContextMiddleware: + """Tests for RequestContextMiddleware.""" + + def setup_method(self): + """Clear any existing context before each test.""" + logging_context.clear_context_metadata() + + def teardown_method(self): + """Clear context after each test.""" + logging_context.clear_context_metadata() + + def test_middleware_generates_request_id(self): + """Test that middleware generates a request_id for each request.""" + app = Starlette() + app.add_middleware(RequestContextMiddleware) + + request_ids_seen = [] + + @app.route("/test") + def test_route(request): + # Capture the request_id during request processing + request_ids_seen.append(logging_context.get_context_metadata("request_id")) + return Response("ok") + + client = TestClient(app) + response = client.get("/test") + + assert response.status_code == 200 + assert len(request_ids_seen) == 1 + assert request_ids_seen[0] is not None + assert len(request_ids_seen[0]) == 32 + + def test_middleware_adds_request_id_to_response_headers(self): + """Test that middleware adds request_id to response headers.""" + app = Starlette() + app.add_middleware(RequestContextMiddleware) + + @app.route("/test") + def 
test_route(request): + return Response("ok") + + client = TestClient(app) + response = client.get("/test") + + assert "x-tangle-request-id" in response.headers + request_id = response.headers["x-tangle-request-id"] + assert len(request_id) == 32 + assert all(c in "0123456789abcdef" for c in request_id) + + def test_middleware_clears_request_id_after_request(self): + """Test that middleware clears request_id after request completes.""" + app = Starlette() + app.add_middleware(RequestContextMiddleware) + + @app.route("/test") + def test_route(request): + assert logging_context.get_context_metadata("request_id") is not None + return Response("ok") + + client = TestClient(app) + + # Before request + assert logging_context.get_context_metadata("request_id") is None + + # Make request + response = client.get("/test") + assert response.status_code == 200 + + # After request - Note: in test client, context might not be cleared + # the same way as in production, but the middleware's context manager ensures it + + def test_middleware_generates_unique_request_ids(self): + """Test that middleware generates unique request_ids for each request.""" + app = Starlette() + app.add_middleware(RequestContextMiddleware) + + @app.route("/test") + def test_route(request): + return Response("ok") + + client = TestClient(app) + + # Make multiple requests + request_ids = set() + for _ in range(10): + response = client.get("/test") + request_ids.add(response.headers["x-tangle-request-id"]) + + # All request_ids should be unique + assert len(request_ids) == 10 + + def test_middleware_request_id_available_in_route(self): + """Test that request_id set by middleware is available in route handler.""" + app = Starlette() + app.add_middleware(RequestContextMiddleware) + + captured_request_id = None + + @app.route("/test") + def test_route(request): + nonlocal captured_request_id + captured_request_id = logging_context.get_context_metadata("request_id") + return Response(f"request_id: {captured_request_id}") + + client = TestClient(app) + response = client.get("/test") + + assert captured_request_id is not None + assert captured_request_id == response.headers["x-tangle-request-id"] + assert captured_request_id in response.text + + def test_middleware_handles_exception_in_route(self): + """Test that middleware clears request_id even when route raises exception.""" + app = Starlette() + app.add_middleware(RequestContextMiddleware) + + @app.route("/test") + def test_route(request): + request_id_during_exception = logging_context.get_context_metadata("request_id") + assert request_id_during_exception is not None + raise ValueError("Test exception") + + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/test") + + # Even though route raised exception, response should have request_id header + # (middleware's context manager ensures cleanup) + assert response.status_code == 500 + + def test_middleware_with_multiple_routes(self): + """Test middleware works correctly with multiple routes.""" + app = Starlette() + app.add_middleware(RequestContextMiddleware) + + request_ids_by_route = {} + + @app.route("/route1") + def route1(request): + request_ids_by_route["route1"] = logging_context.get_context_metadata("request_id") + return Response("route1") + + @app.route("/route2") + def route2(request): + request_ids_by_route["route2"] = logging_context.get_context_metadata("request_id") + return Response("route2") + + client = TestClient(app) + + response1 = client.get("/route1") + response2 = 
client.get("/route2") + + # Each route should have gotten a request_id + assert request_ids_by_route["route1"] is not None + assert request_ids_by_route["route2"] is not None + + # They should be different + assert request_ids_by_route["route1"] != request_ids_by_route["route2"] + + # Response headers should match + assert response1.headers["x-tangle-request-id"] == request_ids_by_route["route1"] + assert response2.headers["x-tangle-request-id"] == request_ids_by_route["route2"] + + +class TestRequestContextMiddlewareIntegration: + """Integration tests for RequestContextMiddleware with logging.""" + + def setup_method(self): + """Clear any existing context before each test.""" + logging_context.clear_context_metadata() + + def teardown_method(self): + """Clear context after each test.""" + logging_context.clear_context_metadata() + + def test_middleware_enables_request_id_in_logs(self): + """Test that middleware enables request_id to be used in logging.""" + import logging + + app = Starlette() + app.add_middleware(RequestContextMiddleware) + + logged_request_ids = [] + + # Create a custom handler to capture log records + class TestHandler(logging.Handler): + def emit(self, record): + # In real usage, LoggingContextFilter would add request_id to logs + current_request_id = logging_context.get_context_metadata("request_id") + if current_request_id: + logged_request_ids.append(current_request_id) + + logger = logging.getLogger("test_logger") + handler = TestHandler() + logger.addHandler(handler) + logger.setLevel(logging.INFO) + + @app.route("/test") + def test_route(request): + logger.info("Processing request") + return Response("ok") + + client = TestClient(app) + response = client.get("/test") + + # The request_id logged should match the response header + assert len(logged_request_ids) > 0 + assert response.headers["x-tangle-request-id"] in logged_request_ids + + # Cleanup + logger.removeHandler(handler) + + def test_middleware_request_id_persists_across_function_calls(self): + """Test that request_id persists across function calls within a request.""" + app = Starlette() + app.add_middleware(RequestContextMiddleware) + + request_ids_collected = [] + + def helper_function(): + """Helper function that accesses request_id.""" + request_ids_collected.append(logging_context.get_context_metadata("request_id")) + + @app.route("/test") + def test_route(request): + request_ids_collected.append(logging_context.get_context_metadata("request_id")) + helper_function() + request_ids_collected.append(logging_context.get_context_metadata("request_id")) + return Response("ok") + + client = TestClient(app) + response = client.get("/test") + + # All three captures should have the same request_id + assert len(request_ids_collected) == 3 + assert request_ids_collected[0] == request_ids_collected[1] == request_ids_collected[2] + assert request_ids_collected[0] == response.headers["x-tangle-request-id"] diff --git a/tests/test_request_id_concurrency.py b/tests/test_request_id_concurrency.py new file mode 100644 index 0000000..baf93f2 --- /dev/null +++ b/tests/test_request_id_concurrency.py @@ -0,0 +1,234 @@ +"""Test that request_id works correctly with concurrent requests.""" + +import asyncio +import pytest +from starlette.applications import Starlette +from starlette.responses import JSONResponse +from starlette.testclient import TestClient + +from cloud_pipelines_backend.instrumentation import logging_context +from cloud_pipelines_backend.instrumentation.request_middleware import RequestContextMiddleware + 
+ +def test_request_id_isolation_with_concurrent_requests(): + """Test that each concurrent request gets its own isolated request_id.""" + app = Starlette() + app.add_middleware(RequestContextMiddleware) + + # Store request_ids seen by each endpoint + request_ids_seen = { + "endpoint1": [], + "endpoint2": [], + } + + @app.route("/endpoint1") + async def endpoint1(request): + request_id = logging_context.get_context_metadata("request_id") + request_ids_seen["endpoint1"].append(request_id) + # Simulate some work + await asyncio.sleep(0.1) + # Verify request_id is still the same after async work + assert logging_context.get_context_metadata("request_id") == request_id + return JSONResponse({"request_id": request_id}) + + @app.route("/endpoint2") + async def endpoint2(request): + request_id = logging_context.get_context_metadata("request_id") + request_ids_seen["endpoint2"].append(request_id) + # Simulate some work + await asyncio.sleep(0.1) + # Verify request_id is still the same after async work + assert logging_context.get_context_metadata("request_id") == request_id + return JSONResponse({"request_id": request_id}) + + client = TestClient(app) + + # Make concurrent requests + response1 = client.get("/endpoint1") + response2 = client.get("/endpoint2") + response3 = client.get("/endpoint1") + + # All requests should succeed + assert response1.status_code == 200 + assert response2.status_code == 200 + assert response3.status_code == 200 + + # Each request should have gotten a unique request_id + request_id_1 = response1.headers["x-tangle-request-id"] + request_id_2 = response2.headers["x-tangle-request-id"] + request_id_3 = response3.headers["x-tangle-request-id"] + + # All request_ids should be unique + assert request_id_1 != request_id_2 + assert request_id_1 != request_id_3 + assert request_id_2 != request_id_3 + + # Verify endpoints saw the correct request_ids + assert request_ids_seen["endpoint1"][0] == request_id_1 + assert request_ids_seen["endpoint2"][0] == request_id_2 + assert request_ids_seen["endpoint1"][1] == request_id_3 + + +def test_request_id_isolation_with_nested_async_calls(): + """Test that request_id persists correctly through nested async function calls.""" + app = Starlette() + app.add_middleware(RequestContextMiddleware) + + request_ids_collected = [] + + async def helper_function_1(): + """First level helper.""" + request_ids_collected.append(("helper1", logging_context.get_context_metadata("request_id"))) + await asyncio.sleep(0.01) + await helper_function_2() + request_ids_collected.append(("helper1_after", logging_context.get_context_metadata("request_id"))) + + async def helper_function_2(): + """Second level helper.""" + request_ids_collected.append(("helper2", logging_context.get_context_metadata("request_id"))) + await asyncio.sleep(0.01) + request_ids_collected.append(("helper2_after", logging_context.get_context_metadata("request_id"))) + + @app.route("/test") + async def test_route(request): + request_ids_collected.append(("start", logging_context.get_context_metadata("request_id"))) + await helper_function_1() + request_ids_collected.append(("end", logging_context.get_context_metadata("request_id"))) + return JSONResponse({"ok": True}) + + client = TestClient(app) + response = client.get("/test") + + assert response.status_code == 200 + request_id = response.headers["x-tangle-request-id"] + + # All captured request_ids should be the same + for label, captured_request_id in request_ids_collected: + assert captured_request_id == request_id, f"Mismatch at 
{label}" + + # Should have captured 6 request_ids total + assert len(request_ids_collected) == 6 + + +def test_request_id_does_not_leak_between_requests(): + """Test that request_id from one request doesn't leak into another.""" + app = Starlette() + app.add_middleware(RequestContextMiddleware) + + request_ids_per_request = [] + + @app.route("/test") + async def test_route(request): + # Capture request_id at start + start_request_id = logging_context.get_context_metadata("request_id") + request_ids_per_request.append(start_request_id) + + # Do some async work + await asyncio.sleep(0.05) + + # Verify it hasn't changed + end_request_id = logging_context.get_context_metadata("request_id") + assert start_request_id == end_request_id + + return JSONResponse({"request_id": end_request_id}) + + client = TestClient(app) + + # Make multiple sequential requests + responses = [client.get("/test") for _ in range(5)] + + # All should succeed + assert all(r.status_code == 200 for r in responses) + + # Extract request_ids from responses + response_request_ids = [r.headers["x-tangle-request-id"] for r in responses] + + # All should be unique + assert len(set(response_request_ids)) == 5 + + # Should match what we captured inside the handler + assert response_request_ids == request_ids_per_request + + +@pytest.mark.asyncio +async def test_contextvars_isolation_across_async_tasks(): + """Direct test of contextvars isolation without HTTP layer.""" + + async def task_with_request_id(task_id: str, expected_request_id: str): + """Simulates a task with its own request_id context.""" + # Set request_id for this task + logging_context.set_context_metadata("request_id", expected_request_id) + + # Verify it's set correctly + assert logging_context.get_context_metadata("request_id") == expected_request_id + + # Simulate some work + await asyncio.sleep(0.01) + + # Verify request_id is still correct after async work + assert logging_context.get_context_metadata("request_id") == expected_request_id + + # More work + await asyncio.sleep(0.01) + + # Still correct + assert logging_context.get_context_metadata("request_id") == expected_request_id + + # Clean up + logging_context.clear_context_metadata() + + return task_id + + # Run multiple tasks concurrently with different request_ids + tasks = [ + task_with_request_id("task1", "request_aaa111"), + task_with_request_id("task2", "request_bbb222"), + task_with_request_id("task3", "request_ccc333"), + task_with_request_id("task4", "request_ddd444"), + ] + + results = await asyncio.gather(*tasks) + + # All tasks should complete successfully + assert results == ["task1", "task2", "task3", "task4"] + + # After all tasks complete, there should be no request_id in this context + assert logging_context.get_context_metadata("request_id") is None + + +def test_request_id_with_context_manager_is_thread_safe(): + """Test that the logging_context context manager works with concurrent access.""" + + collected_request_ids = [] + + def simulate_request_processing(request_id: str): + """Simulates processing with a request_id.""" + with logging_context.logging_context(request_id=request_id): + # Verify request_id is set + current = logging_context.get_context_metadata("request_id") + collected_request_ids.append((request_id, current)) + assert current == request_id + + # After context exits, should be cleared in this context + # (though in threads, contexts are separate anyway) + + import threading + + # Create threads that will process with different request_ids + threads = [ + 
threading.Thread(target=simulate_request_processing, args=(f"request_{i:03d}",)) + for i in range(10) + ] + + # Start all threads + for thread in threads: + thread.start() + + # Wait for all to complete + for thread in threads: + thread.join() + + # All threads should have seen their correct request_id + assert len(collected_request_ids) == 10 + for expected, actual in collected_request_ids: + assert expected == actual
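
Usage sketch (illustrative, not part of the diff): how the new logging context and the `x-tangle-request-id` response header are intended to be consumed together. The endpoint path and the `httpx` client below are assumptions for the example only; only the `logging_context` module and the response header come from this patch.

```python
import logging

import httpx  # any HTTP client works; httpx is only an example dependency

from cloud_pipelines_backend.instrumentation import logging_context

logger = logging.getLogger(__name__)


def reprocess_pipeline_run(pipeline_run_id: str) -> None:
    # Background work outside the HTTP request path can attach its own ids;
    # the LoggingContextFilter configured in start_local.py adds them to every record.
    with logging_context.logging_context(
        pipeline_run_id=pipeline_run_id, operation="reprocessing"
    ):
        logger.info("Reprocessing pipeline run")


# Clients can correlate their own logs with server logs via the response header.
# "/api/pipeline_runs/" is a placeholder path, not an endpoint defined in this patch.
response = httpx.get("http://localhost:8000/api/pipeline_runs/")
print("server request id:", response.headers.get("x-tangle-request-id"))
```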