From caa504673602b67b491f85cd0eab27a08eb4ccc9 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Fri, 6 Feb 2026 13:55:41 +0100 Subject: [PATCH 1/2] Support Cast node in pushdown logic --- duckdb/polars_io.py | 7 +++++++ tests/fast/arrow/test_polars.py | 37 +++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/duckdb/polars_io.py b/duckdb/polars_io.py index ad74b848..23e0062f 100644 --- a/duckdb/polars_io.py +++ b/duckdb/polars_io.py @@ -159,6 +159,13 @@ def _pl_tree_to_sql(tree: _ExpressionTree) -> str: msg = f"Unsupported function type: {func_dict}" raise NotImplementedError(msg) + if node_type == "Cast": + cast_tree = tree[node_type] + assert isinstance(cast_tree, dict), f"A {node_type} should be a dict but got {type(cast_tree)}" + cast_expr = cast_tree["expr"] + assert isinstance(cast_expr, dict), f"A {node_type} should be a dict but got {type(cast_expr)}" + return _pl_tree_to_sql(cast_expr) + if node_type == "Scalar": # Detect format: old style (dtype/value) or new style (direct type key) scalar_tree = tree[node_type] diff --git a/tests/fast/arrow/test_polars.py b/tests/fast/arrow/test_polars.py index 0eba5eeb..3d83e1b1 100644 --- a/tests/fast/arrow/test_polars.py +++ b/tests/fast/arrow/test_polars.py @@ -702,3 +702,40 @@ def test_decimal_scale(self): } } """ assert _pl_tree_to_sql(json.loads(scalar_decimal_scale)) == "1" + + def test_cast_node_unwraps_inner_expression(self): + """Cast nodes should be unwrapped to process the inner expression.""" + # A Cast wrapping a Column reference + cast_column = json.loads( + '{"Cast": {"expr": {"Column": "a"}, "dtype": {"Decimal": [20, 0]}, "options": "NonStrict"}}' + ) + assert _pl_tree_to_sql(cast_column) == '"a"' + + # A Cast wrapping a full binary expression + cast_expr = json.loads(""" + { + "BinaryExpr": { + "left": {"Cast": {"expr": {"Column": "a"}, "dtype": {"Decimal": [20, 0]}, "options": "NonStrict"}}, + "op": "Eq", + "right": {"Literal": {"Int": 1}} + } + } + """) + assert _pl_tree_to_sql(cast_expr) == '("a" = 1)' + + def test_cast_node_predicate_pushdown(self): + """Predicates with Cast nodes should be successfully pushed down.""" + # A decimal with non-38 precision produces a Cast node in Polars + expr = pl.col("a") == pl.lit(1, dtype=pl.Decimal(precision=20, scale=0)) + valid_filter(expr) + + def test_polars_lazy_pushdown_decimal_with_cast(self): + """End-to-end test: decimal columns with non-38 precision should push down filters.""" + con = duckdb.connect() + con.execute("CREATE TABLE test_cast (a DECIMAL(20,0))") + con.execute("INSERT INTO test_cast VALUES (1), (10), (100), (NULL)") + rel = con.sql("FROM test_cast") + lazy_df = rel.pl(lazy=True) + + assert lazy_df.filter(pl.col("a") == 1).collect().to_dicts() == [{"a": 1}] + assert lazy_df.filter(pl.col("a") > 1).collect().to_dicts() == [{"a": 10}, {"a": 100}] From e93d591cc36cee4fa245e1af76c8cc6012075064 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Fri, 6 Feb 2026 14:27:39 +0100 Subject: [PATCH 2/2] Only push down unstrict casts --- duckdb/polars_io.py | 3 +++ tests/fast/arrow/test_polars.py | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/duckdb/polars_io.py b/duckdb/polars_io.py index 23e0062f..85675697 100644 --- a/duckdb/polars_io.py +++ b/duckdb/polars_io.py @@ -162,6 +162,9 @@ def _pl_tree_to_sql(tree: _ExpressionTree) -> str: if node_type == "Cast": cast_tree = tree[node_type] assert isinstance(cast_tree, dict), f"A {node_type} should be a dict but got {type(cast_tree)}" + if cast_tree.get("options") != "NonStrict": + msg = f"Only NonStrict casts can be safely unwrapped, got {cast_tree.get('options')!r}" + raise NotImplementedError(msg) cast_expr = cast_tree["expr"] assert isinstance(cast_expr, dict), f"A {node_type} should be a dict but got {type(cast_expr)}" return _pl_tree_to_sql(cast_expr) diff --git a/tests/fast/arrow/test_polars.py b/tests/fast/arrow/test_polars.py index 3d83e1b1..8e6040ae 100644 --- a/tests/fast/arrow/test_polars.py +++ b/tests/fast/arrow/test_polars.py @@ -739,3 +739,9 @@ def test_polars_lazy_pushdown_decimal_with_cast(self): assert lazy_df.filter(pl.col("a") == 1).collect().to_dicts() == [{"a": 1}] assert lazy_df.filter(pl.col("a") > 1).collect().to_dicts() == [{"a": 10}, {"a": 100}] + + def test_explicit_cast_not_pushed_down(self): + """Explicit user .cast() (Strict) should not be pushed down - falls back to Polars.""" + # pl.col("a").cast(pl.Int64) produces a Strict Cast node + expr = pl.col("a").cast(pl.Int64) > 5 + invalid_filter(expr)