From 939bec07a3ab3cd10a0de1ac00fa8d529c90643e Mon Sep 17 00:00:00 2001 From: Peter Sprygada Date: Fri, 12 Dec 2025 07:00:20 -0500 Subject: [PATCH] refactor: Enhance thread safety in logging module - Add threading.RLock for thread-safe sensitive data filtering - Protect filtering check with reentrant lock in log() function - Create snapshot of logger names in _get_loggers() to prevent race conditions - Fix fatal() to write to stderr instead of stdout - Accept "NONE" string in set_level() for disabling logging - Move initial logger level setting to after NONE constant definition - Update docstrings to document thread-safety guarantees - Update tests to verify stderr output in fatal function --- src/ipsdk/logging.py | 74 ++++++++++++++++++++++++++++--------------- tests/test_logging.py | 5 +-- 2 files changed, 51 insertions(+), 28 deletions(-) diff --git a/src/ipsdk/logging.py b/src/ipsdk/logging.py index 4c135da..ed5280d 100644 --- a/src/ipsdk/logging.py +++ b/src/ipsdk/logging.py @@ -105,6 +105,7 @@ def process_data(data): import inspect import logging import sys +import threading import time import traceback @@ -120,8 +121,6 @@ def process_data(data): logging_message_format = "%(asctime)s: [%(name)s] %(levelname)s: %(message)s" -logging.getLogger(metadata.name).setLevel(100) - # Add the FATAL logging level logging.FATAL = 90 # type: ignore[misc] logging.addLevelName(logging.FATAL, "FATAL") @@ -143,7 +142,11 @@ def process_data(data): FATAL = logging.FATAL NONE = logging.NONE -# Global flag for sensitive data filtering +# Set initial log level to NONE (disabled) +logging.getLogger(metadata.name).setLevel(NONE) + +# Thread-safe configuration for sensitive data filtering +_filtering_lock = threading.RLock() _sensitive_data_filtering_enabled = False @@ -159,6 +162,9 @@ def log(lvl: int, msg: str) -> None: functions (debug, info, warning, error, critical, fatal) to send a log message with a given level. + This function is thread-safe. The sensitive data filtering check is + protected by a reentrant lock to prevent race conditions. + Args: lvl (int): The logging level of the message. msg (str): The message to write to the logger. @@ -169,9 +175,10 @@ def log(lvl: int, msg: str) -> None: Raises: None """ - # Apply sensitive data filtering if enabled - if _sensitive_data_filtering_enabled is True: - msg = heuristics.scan_and_redact(msg) + # Apply sensitive data filtering if enabled (thread-safe) + with _filtering_lock: + if _sensitive_data_filtering_enabled: + msg = heuristics.scan_and_redact(msg) logging.getLogger(metadata.name).log(lvl, msg) @@ -279,7 +286,7 @@ def fatal(msg: str) -> None: """Log a fatal error and exit the application. A fatal error will log the message using level 90 (FATAL) and print - an error message to stdout. It will then exit the application with + an error message to stderr. It will then exit the application with return code 1. Args: @@ -292,7 +299,7 @@ def fatal(msg: str) -> None: SystemExit: Always raised with exit code 1 after logging the fatal error. """ log(logging.FATAL, msg) - print(f"ERROR: {msg}") + print(f"ERROR: {msg}", file=sys.stderr) sys.exit(1) @@ -304,6 +311,10 @@ def _get_loggers() -> set[logging.Logger]: dependencies (ipsdk, FastMCP). Results are cached to improve performance on subsequent calls. + This function is thread-safe. It creates a snapshot of logger names before + iteration to prevent issues if the logger dictionary is modified by other + threads during iteration. 
+ Note: The cached result may not immediately reflect loggers created after the first call. Call _get_loggers.cache_clear() to force a refresh @@ -313,7 +324,9 @@ def _get_loggers() -> set[logging.Logger]: set[logging.Logger]: Set of logger instances for the application and dependencies. """ loggers = set() - for name in logging.Logger.manager.loggerDict: + # Create a snapshot of logger names to prevent race conditions during iteration + logger_names = list(logging.Logger.manager.loggerDict.keys()) + for name in logger_names: if name.startswith((metadata.name, "httpx")): loggers.add(logging.getLogger(name)) return loggers @@ -334,12 +347,13 @@ def get_logger() -> logging.Logger: return logging.getLogger(metadata.name) -def set_level(lvl: int, *, propagate: bool = False) -> None: +def set_level(lvl: int | str, *, propagate: bool = False) -> None: """Set logging level for all loggers in the current Python process. Args: - lvl (int): Logging level (e.g., logging.INFO, logging.DEBUG). This - is a required argument. + lvl (int | str): Logging level (e.g., logging.INFO, logging.DEBUG, or "NONE"). + This is a required argument. Can be an integer level or the string "NONE" + to disable all logging. propagate (bool): Setting this value to True will also turn on logging for httpx and httpcore. Defaults to False. @@ -347,12 +361,17 @@ def set_level(lvl: int, *, propagate: bool = False) -> None: None Raises: - None + TypeError: If lvl is a string other than "NONE". """ logger = get_logger() - if lvl == "NONE": - lvl = NONE + # Convert string "NONE" to NONE constant + if isinstance(lvl, str): + if lvl == "NONE": + lvl = NONE + else: + msg = f"Invalid level string: {lvl}. Only 'NONE' is supported as a string." + raise TypeError(msg) logger.setLevel(lvl) logger.propagate = False @@ -374,8 +393,7 @@ def enable_sensitive_data_filtering() -> None: information (such as passwords, tokens, API keys) and redacted before being written to the log output. - Args: - None + This function is thread-safe. Returns: None @@ -384,7 +402,8 @@ def enable_sensitive_data_filtering() -> None: None """ global _sensitive_data_filtering_enabled # noqa: PLW0603 - _sensitive_data_filtering_enabled = True + with _filtering_lock: + _sensitive_data_filtering_enabled = True def disable_sensitive_data_filtering() -> None: @@ -394,8 +413,7 @@ def disable_sensitive_data_filtering() -> None: for sensitive information. Use with caution in production environments as this may expose sensitive data in log files. - Args: - None + This function is thread-safe. Returns: None @@ -404,7 +422,8 @@ def disable_sensitive_data_filtering() -> None: None """ global _sensitive_data_filtering_enabled # noqa: PLW0603 - _sensitive_data_filtering_enabled = False + with _filtering_lock: + _sensitive_data_filtering_enabled = False def is_sensitive_data_filtering_enabled() -> bool: @@ -413,8 +432,7 @@ def is_sensitive_data_filtering_enabled() -> bool: Returns the current state of sensitive data filtering to determine if log messages are being scanned and redacted. - Args: - None + This function is thread-safe. Returns: bool: True if filtering is enabled, False otherwise @@ -422,7 +440,8 @@ def is_sensitive_data_filtering_enabled() -> bool: Raises: None """ - return _sensitive_data_filtering_enabled + with _filtering_lock: + return _sensitive_data_filtering_enabled def configure_sensitive_data_patterns( @@ -516,8 +535,11 @@ def initialize() -> None: replacing them with a standard StreamHandler that writes to stderr. 
This ensures consistent logging configuration across all related loggers. - Args: - None + Warning: + This function is NOT thread-safe. It should only be called during + application startup before any logging activity begins. Calling this + function while other threads are actively logging may result in lost + log messages or exceptions. Returns: None diff --git a/tests/test_logging.py b/tests/test_logging.py index 60b911c..0630bf2 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -265,7 +265,7 @@ def test_fatal_function_logs_and_exits(self): ipsdk_logging.fatal("fatal error") mock_log.assert_called_once_with(ipsdk_logging.FATAL, "fatal error") - mock_print.assert_called_once_with("ERROR: fatal error") + mock_print.assert_called_once_with("ERROR: fatal error", file=sys.stderr) mock_exit.assert_called_once_with(1) def test_fatal_function_different_messages(self): @@ -279,7 +279,8 @@ def test_fatal_function_different_messages(self): ipsdk_logging.fatal(message) mock_log.assert_called_once_with(ipsdk_logging.FATAL, message) - mock_print.assert_called_once_with(f"ERROR: {message}") + expected_msg = f"ERROR: {message}" + mock_print.assert_called_once_with(expected_msg, file=sys.stderr) mock_exit.assert_called_once_with(1)
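
For reference, below is a minimal standalone sketch of the RLock-protected filtering toggle this patch introduces. Names mirror the diff (_filtering_lock, _sensitive_data_filtering_enabled, log, enable_sensitive_data_filtering), but the heuristics.scan_and_redact() call is replaced with a hypothetical redact() stub so the snippet runs on its own, outside the ipsdk package.

import logging
import threading

_filtering_lock = threading.RLock()
_sensitive_data_filtering_enabled = False


def redact(msg: str) -> str:
    # Hypothetical stand-in for heuristics.scan_and_redact().
    return msg.replace("hunter2", "<redacted>")


def enable_sensitive_data_filtering() -> None:
    global _sensitive_data_filtering_enabled
    # Writers take the lock so toggling the flag never interleaves
    # with a reader's redaction decision.
    with _filtering_lock:
        _sensitive_data_filtering_enabled = True


def log(lvl: int, msg: str) -> None:
    # Read the flag under the same lock; the redaction (if enabled)
    # happens before the message reaches the underlying logger.
    with _filtering_lock:
        if _sensitive_data_filtering_enabled:
            msg = redact(msg)
    logging.getLogger("ipsdk").log(lvl, msg)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    enable_sensitive_data_filtering()
    log(logging.INFO, "user password is hunter2")  # emits the redacted form

Using a reentrant lock (RLock) rather than a plain Lock means a thread that already holds the lock can call back into these helpers without deadlocking, which matches the guarantee documented in the updated docstrings.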