diff --git a/duckdb/polars_io.py b/duckdb/polars_io.py index ad74b848..abe7e0cb 100644 --- a/duckdb/polars_io.py +++ b/duckdb/polars_io.py @@ -197,10 +197,12 @@ def _pl_tree_to_sql(tree: _ExpressionTree) -> str: "Int16", "Int32", "Int64", + "Int128", "UInt8", "UInt16", "UInt32", "UInt64", + "UInt128", "Float32", "Float64", "Boolean", diff --git a/tests/fast/arrow/test_polars.py b/tests/fast/arrow/test_polars.py index 0eba5eeb..b8a5434a 100644 --- a/tests/fast/arrow/test_polars.py +++ b/tests/fast/arrow/test_polars.py @@ -681,6 +681,30 @@ def test_invalid_expr_json(self): with pytest.raises(AssertionError, match="The col name of a Column should be a str but got"): _pl_tree_to_sql(json.loads(bad_type_expr)) + @pytest.mark.parametrize( + ("dtype", "test_value"), + [ + (pl.Int8, 1), + (pl.Int16, 1), + (pl.Int32, 1), + (pl.Int64, 1), + (pl.Int128, 1), + (pl.UInt8, 1), + (pl.UInt16, 1), + (pl.UInt32, 1), + (pl.UInt64, 1), + (pl.UInt128, 1), + (pl.Float32, 1.0), + (pl.Float64, 1.0), + (pl.Boolean, True), + ], + ) + def test_scalar_type_pushdown(self, dtype, test_value): + """Verify that literals of each scalar type can be pushed down.""" + expr = pl.col("a") == pl.lit(test_value, dtype=dtype) + sql_expression = _predicate_to_expression(expr) + assert sql_expression is not None, f"Pushdown failed for {dtype}" + def test_decimal_scale(self): scalar_decimal_no_scale = """ { "Scalar": {