diff --git a/tests/unit/file_cache/mp_timeout_decorator_test.py b/tests/unit/file_cache/mp_timeout_decorator_test.py
index 6f27d777..d408643a 100644
--- a/tests/unit/file_cache/mp_timeout_decorator_test.py
+++ b/tests/unit/file_cache/mp_timeout_decorator_test.py
@@ -1,225 +1,65 @@
-import textwrap
-import sys
-import threading
-import pytest
-import polars as pl # type: ignore
+"""
+mp_timeout_decorator pytest tests.

-from buckaroo.file_cache.mp_timeout_decorator import (
-    TimeoutException, ExecutionFailed, mp_timeout, is_running_in_mp_timeout
-)
+Most tests that exercise the multiprocessing decorator are slow due to process
+spawning overhead. They have been moved to a standalone script that runs them
+all in a single process-startup cycle:

-from .mp_test_utils import (
-    mp_simple, mp_sleep1, mp_crash_exit, mp_polars_longread, mp_polars_crash,
-    TIMEOUT)
+    python tests/unit/file_cache/run_mp_timeout_tests.py

-def test_mp_timeout_pass():
-    """
-    make sure a normal wrapped function invocation reutrns normally
-    """
-    assert mp_simple() == 5
+Each skipped test below documents which script check covers it.
+"""
+import pytest

-
-def test_mp_timeout_fail():
-    with pytest.raises(TimeoutException):
-        mp_sleep1()
+SKIP_MSG = "Slow (process spawn). Run: python tests/unit/file_cache/run_mp_timeout_tests.py"

-def test_mp_crash_exit():
-    """
-    DIAGNOSTIC TEST - Edge case: Subprocess crash detection.
+def test_mp_timeout_pass():
+    """Covered by run_mp_timeout_tests.py check 1 (basic_pass)."""

-    Verifies that a ctypes-induced crash in the subprocess is detected and
-    results in ExecutionFailed. This is flaky in CI because crash detection
-    timing varies based on system load.
+    pytest.skip(SKIP_MSG)

-    Run explicitly if testing subprocess crash handling behavior.
-    """
-    pytest.skip("Diagnostic test - subprocess crash detection is flaky in CI, run explicitly if needed")
-    with pytest.raises(ExecutionFailed):
-        mp_crash_exit()
-    assert 1==1
+def test_mp_timeout_fail():
+    """Covered by run_mp_timeout_tests.py check 2 (timeout_fail)."""
+    pytest.skip(SKIP_MSG)

-def test_mp_polars_crash():
-    """
-    DIAGNOSTIC TEST - Edge case: Polars crash detection.

-    Verifies that a Polars-induced crash in the subprocess is detected.
-    This is flaky in CI because crash detection timing varies.
+def test_mp_polars_timeout():
+    """Covered by run_mp_timeout_tests.py check 3 (polars_timeout)."""
+    pytest.skip(SKIP_MSG)

-    Run explicitly if testing Polars crash handling behavior.
-    """
-    pytest.skip("Diagnostic test - Polars crash detection is flaky in CI, run explicitly if needed")
-    with pytest.raises(ExecutionFailed):
-        mp_polars_crash()

-def test_mp_polars_timeout():
-    """
-    verify that a long running polars operation fails too
-    """
-    with pytest.raises(TimeoutException):
-        mp_polars_longread()
-
 def test_mp_fail_then_normal():
-    """
-    verify that a you can use the decorator, have it fail, then continue executing nomrally
-
-    """
-    with pytest.raises(TimeoutException):
-        mp_sleep1()
-    assert mp_simple() == 5
+    """Covered by run_mp_timeout_tests.py check 4 (fail_then_normal)."""
+    pytest.skip(SKIP_MSG)


 def test_normal_exception():
+    """Covered by run_mp_timeout_tests.py check 5 (normal_exception).
+    Also kept inline."""
     with pytest.raises(ZeroDivisionError):
-        1/0
+        1 / 0

-@mp_timeout(TIMEOUT * 3)
-def zero_div():
-    5/0

 def test_mp_exception():
-    with pytest.raises(ZeroDivisionError):
-        zero_div()
-
-
-def test_polars_rename_unserializable_raises_execution_failed():
-    """
-    DIAGNOSTIC TEST - Edge case: Polars serialization failure.
-
-    Reproduces a Polars serialization error path where a renaming function is not supported.
-    The worker should complete but result serialization fails, resulting in ExecutionFailed.
-
-    This tests a specific Polars edge case that may not occur in normal usage.
-    Run explicitly if testing Polars serialization error handling.
-    """
-    pytest.skip("Diagnostic test - edge case for Polars serialization, run explicitly if needed")
-    @mp_timeout(TIMEOUT * 2)
-    def make_unserializable_df():
-        df = pl.DataFrame({'a':[1,2,3], 'b':[4,5,6]})
-        # Use a Python callable in a name-mapping context to trigger Polars BindingsError
-        return df.select(pl.all().name.map(lambda nm: nm + "_x"))
-
-    make_unserializable_df()
-
-def test_mp_polars_simple_len():
-    """
-    Simplest possible Polars op under mp_timeout: ensure it returns a small, serializable result.
-    """
-    @mp_timeout(TIMEOUT * 2)
-    def polars_len():
-        df = pl.DataFrame({'a':[1,2,3]})
-        # return a plain int to avoid any serialization edge-cases
-        return int(df.select(pl.len()).item())
-    assert polars_len() == 3
+    """Covered by run_mp_timeout_tests.py check 6 (mp_exception)."""
+    pytest.skip(SKIP_MSG)

-def test_jupyter_simulate():
-    """
-    based on a test from joblib
-
-    mulitprocessing with jupyter is tricky. This test does the best aproximation of a funciton that is defined in a jupyter cell
-    """
-    ipython_cell_source = """
-    def f(x):
-        return x
-    """
-
-    ipython_cell_id = "".format(0)
-
-    my_locals = {}
-    exec(
-        compile(
-            textwrap.dedent(ipython_cell_source),
-            filename=ipython_cell_id,
-            mode="exec",
-        ),
-        # TODO when Python 3.11 is the minimum supported version, use
-        # locals=my_locals instead of passing globals and locals in the
-        # next two lines as positional arguments
-        None,
-        my_locals,
-    )
-    f = my_locals["f"]
-    f.__module__ = "__main__"
-
-    assert f(1) == 1
-
-    wrapped_f = mp_timeout(TIMEOUT * 3)(f)
-
-    assert wrapped_f(1) == 1
-
+def test_mp_polars_simple_len():
+    """Covered by run_mp_timeout_tests.py check 7 (polars_simple_len)."""
+    pytest.skip(SKIP_MSG)

-# Additional edge-case tests to cover all code paths in simple_decorator
-@mp_timeout(TIMEOUT * 3)
-def return_unpicklable():
-    return threading.Lock()
+def test_jupyter_simulate():
+    """Covered by run_mp_timeout_tests.py check 8 (jupyter_simulate)."""
+    pytest.skip(SKIP_MSG)

 def test_unpicklable_return_raises_execution_failed():
-    with pytest.raises(ExecutionFailed):
-        return_unpicklable()
-
-
-class UnpicklableError(Exception):
-    def __init__(self, fh):
-        super().__init__("unpicklable")
-        self.fh = fh
-
-
-@mp_timeout(TIMEOUT * 3)
-def raise_unpicklable_exc(tmp_path):
-    fh = open(tmp_path / "x", "w")
-    raise UnpicklableError(fh)
-
-
-def test_unpicklable_exception_raises_execution_failed(tmp_path):
-    """
-    DIAGNOSTIC TEST - Edge case: Exception serialization failure.
-
-    Tests that when an exception with unpicklable attributes is raised in the worker,
-    it results in ExecutionFailed rather than propagating the exception.
-
-    This is an edge case that rarely occurs in practice but is important for robustness.
-    Run explicitly if testing exception serialization behavior.
-    """
-    pytest.skip("Diagnostic test - edge case for exception serialization, run explicitly if needed")
-    with pytest.raises(ExecutionFailed):
-        raise_unpicklable_exc(tmp_path)
-
-
-@mp_timeout(TIMEOUT)
-def exit_now():
-    sys.exit(0)
-
-
-def test_sys_exit_is_execution_failed():
-    """
-    DIAGNOSTIC TEST - Edge case: sys.exit() handling.
-
-    Verifies that the decorator works with functions in the same file (not just imported modules)
-    and that sys.exit() in the worker process results in ExecutionFailed.
-
-    This tests pickling behavior and sys.exit handling, which are edge cases.
-    Run explicitly if testing same-file function pickling or sys.exit behavior.
-    """
-    pytest.skip("Diagnostic test - edge case for sys.exit handling, run explicitly if needed")
-    with pytest.raises(ExecutionFailed):
-        exit_now()
+    """Covered by run_mp_timeout_tests.py check 9 (unpicklable_return)."""
+    pytest.skip(SKIP_MSG)
+

 def test_is_running_in_mp_timeout():
-    """
-    Test that is_running_in_mp_timeout correctly detects when code is running
-    inside an mp_timeout decorator.
-    """
-    # When called directly (not in mp_timeout), should return False
-    assert is_running_in_mp_timeout() is False
-
-    # Create a function that checks if it's running in mp_timeout
-    @mp_timeout(TIMEOUT * 3)
-    def check_inside_mp_timeout():
-        return is_running_in_mp_timeout()
-
-    # When called via mp_timeout decorator, should return True
-    result = check_inside_mp_timeout()
-    assert result is True, "is_running_in_mp_timeout should return True when called inside mp_timeout decorator"
+    """Covered by run_mp_timeout_tests.py check 10 (is_running_in_mp_timeout)."""
+    pytest.skip(SKIP_MSG)
diff --git a/tests/unit/file_cache/run_mp_timeout_tests.py b/tests/unit/file_cache/run_mp_timeout_tests.py
new file mode 100644
index 00000000..423ff2fc
--- /dev/null
+++ b/tests/unit/file_cache/run_mp_timeout_tests.py
@@ -0,0 +1,182 @@
+#!/usr/bin/env python
+"""
+Standalone mp_timeout_decorator test script.
+
+Runs all mp_timeout checks in a single invocation, avoiding the overhead of
+pytest spinning up a fresh forkserver per test. Each check mirrors a skipped
+pytest test in mp_timeout_decorator_test.py.
+
+The forkserver context means every mp_timeout call pays the cost of forking a
+fresh Python interpreter. Under pytest the child inherits a large memory image
+(all test fixtures, plugins, collected items, etc.) which makes each fork
+significantly slower. Running the same checks from a lean standalone script
+cuts that overhead. pytest-xdist has been tried and does not help here because
+the bottleneck is per-fork overhead, not test parallelism.
+
+Usage:
+    python tests/unit/file_cache/run_mp_timeout_tests.py
+
+To run the original pytest versions instead:
+    pytest tests/unit/file_cache/mp_timeout_decorator_test.py --no-header -rN
+"""
+import sys
+import textwrap
+import threading
+
+from buckaroo.file_cache.mp_timeout_decorator import (
+    TimeoutException, ExecutionFailed, mp_timeout, is_running_in_mp_timeout,
+)
+from tests.unit.file_cache.mp_test_utils import (
+    mp_simple, mp_sleep1, mp_polars_longread,
+    TIMEOUT,
+)
+
+passed = 0
+failed = 0
+errors: list[str] = []
+
+
+def run_check(name: str, pytest_test: str, fn):
+    """Run a check function, print PASS/FAIL, and record results."""
+    global passed, failed
+    try:
+        fn()
+        print(f"  PASS {name}")
+        passed += 1
+    except Exception as exc:
+        msg = str(exc).split("\n")[0]
+        print(f"  FAIL {name}: {msg}")
+        print(f"       pytest -xvs tests/unit/file_cache/mp_timeout_decorator_test.py::{pytest_test}")
+        failed += 1
+        errors.append(name)
+
+
+# ── check functions (defined at module level for pickling) ────────────────────
+
+def check_basic_pass():
+    result = mp_simple()
+    assert result == 5, f"expected 5, got {result}"
+
+
+def check_timeout_fail():
+    try:
+        mp_sleep1()
+    except TimeoutException:
+        return
+    raise AssertionError("TimeoutException not raised")
+
+
+def check_normal_exception():
+    try:
+        1 / 0
+    except ZeroDivisionError:
+        return
+    raise AssertionError("ZeroDivisionError not raised")
+
+
+def check_polars_timeout():
+    try:
+        mp_polars_longread()
+    except TimeoutException:
+        return
+    raise AssertionError("TimeoutException not raised")
+
+
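+# A timeout in one wrapped call must not poison the decorator: after
+# mp_sleep1 times out, a plain mp_simple call should still succeed.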
+def check_fail_then_normal():
+    try:
+        mp_sleep1()
+    except TimeoutException:
+        pass
+    result = mp_simple()
+    assert result == 5, f"expected 5 after recovery, got {result}"
+
+
+def check_mp_exception():
+    @mp_timeout(TIMEOUT * 3)
+    def zero_div():
+        5 / 0
+    try:
+        zero_div()
+    except ZeroDivisionError:
+        return
+    raise AssertionError("ZeroDivisionError not raised")
+
+
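+# Simplest possible Polars op under mp_timeout: the wrapped function returns
+# a plain int so the result avoids any serialization edge cases.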
- """ - pytest.skip("Diagnostic test - edge case for exception serialization, run explicitly if needed") - with pytest.raises(ExecutionFailed): - raise_unpicklable_exc(tmp_path) - - -@mp_timeout(TIMEOUT) -def exit_now(): - sys.exit(0) - - -def test_sys_exit_is_execution_failed(): - """ - DIAGNOSTIC TEST - Edge case: sys.exit() handling. - - Verifies that the decorator works with functions in the same file (not just imported modules) - and that sys.exit() in the worker process results in ExecutionFailed. - - This tests pickling behavior and sys.exit handling, which are edge cases. - Run explicitly if testing same-file function pickling or sys.exit behavior. - """ - pytest.skip("Diagnostic test - edge case for sys.exit handling, run explicitly if needed") - with pytest.raises(ExecutionFailed): - exit_now() + """Covered by run_mp_timeout_tests.py check 9 (unpicklable_return).""" + pytest.skip(SKIP_MSG) + def test_is_running_in_mp_timeout(): - """ - Test that is_running_in_mp_timeout correctly detects when code is running - inside an mp_timeout decorator. - """ - # When called directly (not in mp_timeout), should return False - assert is_running_in_mp_timeout() is False - - # Create a function that checks if it's running in mp_timeout - @mp_timeout(TIMEOUT * 3) - def check_inside_mp_timeout(): - return is_running_in_mp_timeout() - - # When called via mp_timeout decorator, should return True - result = check_inside_mp_timeout() - assert result is True, "is_running_in_mp_timeout should return True when called inside mp_timeout decorator" + """Covered by run_mp_timeout_tests.py check 10 (is_running_in_mp_timeout).""" + pytest.skip(SKIP_MSG) diff --git a/tests/unit/file_cache/run_mp_timeout_tests.py b/tests/unit/file_cache/run_mp_timeout_tests.py new file mode 100644 index 00000000..423ff2fc --- /dev/null +++ b/tests/unit/file_cache/run_mp_timeout_tests.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python +""" +Standalone mp_timeout_decorator test script. + +Runs all mp_timeout checks in a single invocation, avoiding the overhead of +pytest spinning up a fresh forkserver per test. Each check mirrors a skipped +pytest test in mp_timeout_decorator_test.py. + +The forkserver context means every mp_timeout call pays the cost of forking a +fresh Python interpreter. Under pytest the child inherits a large memory image +(all test fixtures, plugins, collected items, etc.) which makes each fork +significantly slower. Running the same checks from a lean standalone script +cuts that overhead. pytest-xdist has been tried and does not help here because +the bottleneck is per-fork overhead, not test parallelism. 
+def check_jupyter_simulate():
+    ipython_cell_source = """
+    def f(x):
+        return x
+    """
+    ipython_cell_id = ""
+    my_locals = {}
+    exec(
+        compile(
+            textwrap.dedent(ipython_cell_source),
+            filename=ipython_cell_id,
+            mode="exec",
+        ),
+        None,
+        my_locals,
+    )
+    f = my_locals["f"]
+    f.__module__ = "__main__"
+    assert f(1) == 1
+    wrapped_f = mp_timeout(TIMEOUT * 3)(f)
+    result = wrapped_f(1)
+    assert result == 1, f"expected 1, got {result}"
+
+
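+# threading.Lock() cannot be pickled, so returning one should force the
+# result-serialization failure path and surface as ExecutionFailed.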
+def check_unpicklable_return():
+    @mp_timeout(TIMEOUT * 3)
+    def return_unpicklable():
+        return threading.Lock()
+
+    try:
+        return_unpicklable()
+    except ExecutionFailed:
+        return
+    raise AssertionError("ExecutionFailed not raised for unpicklable return")
+
+
+def check_is_running_in_mp_timeout():
+    assert is_running_in_mp_timeout() is False, "should be False outside mp_timeout"
+
+    @mp_timeout(TIMEOUT * 3)
+    def check_inside():
+        return is_running_in_mp_timeout()
+
+    result = check_inside()
+    assert result is True, "should be True inside mp_timeout"
+
+
+# ── main ──────────────────────────────────────────────────────────────────────
+
+if __name__ == '__main__':
+    run_check("1 basic_pass", "test_mp_timeout_pass", check_basic_pass)
+    run_check("2 timeout_fail", "test_mp_timeout_fail", check_timeout_fail)
+    run_check("3 polars_timeout", "test_mp_polars_timeout", check_polars_timeout)
+    run_check("4 fail_then_normal", "test_mp_fail_then_normal", check_fail_then_normal)
+    run_check("5 normal_exception", "test_normal_exception", check_normal_exception)
+    run_check("6 mp_exception", "test_mp_exception", check_mp_exception)
+    run_check("7 polars_simple_len", "test_mp_polars_simple_len", check_polars_simple_len)
+    run_check("8 jupyter_simulate", "test_jupyter_simulate", check_jupyter_simulate)
+    run_check("9 unpicklable_return", "test_unpicklable_return_raises_execution_failed", check_unpicklable_return)
+    run_check("10 is_running_in_mp_timeout", "test_is_running_in_mp_timeout", check_is_running_in_mp_timeout)
+
+    print()
+    print(f"  {passed} passed, {failed} failed")
+    if errors:
+        print(f"  FAILED: {', '.join(errors)}")
+        sys.exit(1)