From a08b76c39f69b75993d35c7a7e1eef9430f268c6 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Thu, 25 Dec 2025 19:12:32 +0200 Subject: [PATCH 1/4] Optimize column_encryption_policy checks in recv_results_rows There's no point in checking a global policy for every single value decoding, nor for every row decoded. Adjusted the code to only check it once per recv_results_rows() call - decode_row() should be defined either as it is today with column_encryption_policy enabled, or much simpler without all those extra checks. Added a unit test from Copilot. Fixes: https://github.com/scylladb/python-driver/issues/582 Signed-off-by: Yaniv Kaul --- cassandra/protocol.py | 31 +++- .../unit/test_protocol_decode_optimization.py | 155 ++++++++++++++++++ 2 files changed, 177 insertions(+), 9 deletions(-) create mode 100644 tests/unit/test_protocol_decode_optimization.py diff --git a/cassandra/protocol.py b/cassandra/protocol.py index e574965de8..5f77818c70 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -719,24 +719,37 @@ def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, rows = [self.recv_row(f, len(column_metadata)) for _ in range(rowcount)] self.column_names = [c[2] for c in column_metadata] self.column_types = [c[3] for c in column_metadata] - col_descs = [ColDesc(md[0], md[1], md[2]) for md in column_metadata] - def decode_val(val, col_md, col_desc): - uses_ce = column_encryption_policy and column_encryption_policy.contains_column(col_desc) - col_type = column_encryption_policy.column_type(col_desc) if uses_ce else col_md[3] - raw_bytes = column_encryption_policy.decrypt(col_desc, val) if uses_ce else val - return col_type.from_binary(raw_bytes, protocol_version) + if column_encryption_policy: + col_descs = [ColDesc(md[0], md[1], md[2]) for md in column_metadata] - def decode_row(row): - return tuple(decode_val(val, col_md, col_desc) for val, col_md, col_desc in zip(row, column_metadata, col_descs)) + def decode_val(val, 
col_md, col_desc): + uses_ce = column_encryption_policy.contains_column(col_desc) + if uses_ce: + col_type = column_encryption_policy.column_type(col_desc) + raw_bytes = column_encryption_policy.decrypt(col_desc, val) + return col_type.from_binary(raw_bytes, protocol_version) + else: + return col_md[3].from_binary(val, protocol_version) + + def decode_row(row): + return tuple(decode_val(val, col_md, col_desc) for val, col_md, col_desc in zip(row, column_metadata, col_descs)) + else: + def decode_row(row): + return tuple(col_md[3].from_binary(val, protocol_version) for val, col_md in zip(row, column_metadata)) try: self.parsed_rows = [decode_row(row) for row in rows] except Exception: + if not column_encryption_policy: + col_descs = [ColDesc(md[0], md[1], md[2]) for md in column_metadata] for row in rows: for val, col_md, col_desc in zip(row, column_metadata, col_descs): try: - decode_val(val, col_md, col_desc) + if column_encryption_policy: + decode_val(val, col_md, col_desc) + else: + col_md[3].from_binary(val, protocol_version) except Exception as e: raise DriverException('Failed decoding result column "%s" of type %s: %s' % (col_md[2], col_md[3].cql_parameterized_type(), diff --git a/tests/unit/test_protocol_decode_optimization.py b/tests/unit/test_protocol_decode_optimization.py new file mode 100644 index 0000000000..e0fd81fe3e --- /dev/null +++ b/tests/unit/test_protocol_decode_optimization.py @@ -0,0 +1,155 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import unittest +from unittest.mock import Mock + +from cassandra import ProtocolVersion +from cassandra.cqltypes import Int32Type, UTF8Type +from cassandra.marshal import int32_pack +from cassandra.policies import ColDesc +from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS + + +class DecodeOptimizationTest(unittest.TestCase): + """ + Tests to verify the optimization of column_encryption_policy checks + in recv_results_rows. The optimization checks if the policy exists once + per result message, avoiding the redundant 'column_encryption_policy and ...' + check for every value. + """ + + def _create_mock_result_metadata(self): + """Create mock result metadata for testing""" + return [ + ('keyspace1', 'table1', 'col1', Int32Type), + ('keyspace1', 'table1', 'col2', UTF8Type), + ] + + def _create_mock_result_message(self): + """Create a mock result message with data""" + msg = ResultMessage(kind=RESULT_KIND_ROWS) + msg.column_metadata = self._create_mock_result_metadata() + msg.recv_results_metadata = Mock() + msg.recv_row = Mock(side_effect=[ + [int32_pack(42), b'hello'], + [int32_pack(100), b'world'], + ]) + return msg + + def _create_mock_stream(self): + """Create a mock stream for reading rows""" + # Pack rowcount (2 rows) + data = int32_pack(2) + return io.BytesIO(data) + + def test_decode_without_encryption_policy(self): + """ + Test that decoding works correctly without column encryption policy. + This should use the optimized simple path. 
+ """ + msg = self._create_mock_result_message() + f = self._create_mock_stream() + + msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, None) + + # Verify results + self.assertEqual(len(msg.parsed_rows), 2) + self.assertEqual(msg.parsed_rows[0][0], 42) + self.assertEqual(msg.parsed_rows[0][1], 'hello') + self.assertEqual(msg.parsed_rows[1][0], 100) + self.assertEqual(msg.parsed_rows[1][1], 'world') + + def test_decode_with_encryption_policy_no_encrypted_columns(self): + """ + Test that decoding works with encryption policy when no columns are encrypted. + """ + msg = self._create_mock_result_message() + f = self._create_mock_stream() + + # Create mock encryption policy that has no encrypted columns + mock_policy = Mock() + mock_policy.contains_column = Mock(return_value=False) + + msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy) + + # Verify results + self.assertEqual(len(msg.parsed_rows), 2) + self.assertEqual(msg.parsed_rows[0][0], 42) + self.assertEqual(msg.parsed_rows[0][1], 'hello') + + # Verify contains_column was called for each value (but policy existence check happens once) + # Should be called 4 times (2 rows × 2 columns) + self.assertEqual(mock_policy.contains_column.call_count, 4) + + def test_decode_with_encryption_policy_with_encrypted_column(self): + """ + Test that decoding works with encryption policy when one column is encrypted. 
+ """ + msg = self._create_mock_result_message() + f = self._create_mock_stream() + + # Create mock encryption policy where first column is encrypted + mock_policy = Mock() + def contains_column_side_effect(col_desc): + return col_desc.col == 'col1' + mock_policy.contains_column = Mock(side_effect=contains_column_side_effect) + mock_policy.column_type = Mock(return_value=Int32Type) + mock_policy.decrypt = Mock(side_effect=lambda col_desc, val: val) + + msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy) + + # Verify results + self.assertEqual(len(msg.parsed_rows), 2) + self.assertEqual(msg.parsed_rows[0][0], 42) + self.assertEqual(msg.parsed_rows[0][1], 'hello') + + # Verify contains_column was called for each value (but policy existence check happens once) + # Should be called 4 times (2 rows × 2 columns) + self.assertEqual(mock_policy.contains_column.call_count, 4) + + # Verify decrypt was called for each encrypted value (2 rows * 1 encrypted column) + self.assertEqual(mock_policy.decrypt.call_count, 2) + + def test_optimization_efficiency(self): + """ + Verify that the optimization checks policy existence once per result message. + The key optimization is checking 'if column_encryption_policy:' once, + rather than 'column_encryption_policy and ...' for every value. + """ + msg = self._create_mock_result_message() + + # Create more rows to make the check pattern clear + msg.recv_row = Mock(side_effect=[ + [int32_pack(i), f'text{i}'.encode()] for i in range(100) + ]) + + # Create mock stream with 100 rows + f = io.BytesIO(int32_pack(100)) + + mock_policy = Mock() + mock_policy.contains_column = Mock(return_value=False) + + msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy) + + # With optimization: policy existence checked once, contains_column called per value + # = 100 rows * 2 columns = 200 calls to contains_column + # The key is we avoid checking 'column_encryption_policy and ...' 
200 times + self.assertEqual(mock_policy.contains_column.call_count, 200, + "contains_column should be called for each value when policy exists") + + +if __name__ == '__main__': + unittest.main() From 7325c7d00287bc688bf478755f8e9aa3ca995807 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Sun, 4 Jan 2026 17:59:21 +0200 Subject: [PATCH 2/4] (improvement)Optimize column_encryption_policy checks in Cython's unpack_row() function Very similar to the native Python code, separate the two cases, if column encryption (CE) policy is not enabled, the code is substantially simplified. If it is, it's slightly more elaborate. Decided to have two loops in two functions, one for each case, for performance reasons, even if readability-wise it's not as great. AI agreed with me: Recommendation: Keep it as is. In high-performance Cython code like this, duplicating a small block of code Fixes: https://github.com/scylladb/python-driver/issues/639 Signed-off-by: Yaniv Kaul --- cassandra/obj_parser.pyx | 47 ++++++++++++++++++++++++++++++++++------ cassandra/row_parser.pyx | 8 +++++-- 2 files changed, 46 insertions(+), 9 deletions(-) diff --git a/cassandra/obj_parser.pyx b/cassandra/obj_parser.pyx index cf43771dd7..2d366fc5bb 100644 --- a/cassandra/obj_parser.pyx +++ b/cassandra/obj_parser.pyx @@ -31,7 +31,10 @@ cdef class ListParser(ColumnParser): cdef Py_ssize_t i, rowcount rowcount = read_int(reader) cdef RowParser rowparser = TupleRowParser() - return [rowparser.unpack_row(reader, desc) for i in range(rowcount)] + if desc.column_encryption_policy: + return [rowparser.unpack_ce_row(reader, desc) for i in range(rowcount)] + else: + return [rowparser.unpack_row(reader, desc) for i in range(rowcount)] cdef class LazyParser(ColumnParser): @@ -47,7 +50,10 @@ def parse_rows_lazy(BytesIOReader reader, ParseDesc desc): cdef Py_ssize_t i, rowcount rowcount = read_int(reader) cdef RowParser rowparser = TupleRowParser() - return (rowparser.unpack_row(reader, desc) for i in 
range(rowcount)) + if desc.column_encryption_policy: + return (rowparser.unpack_ce_row(reader, desc) for i in range(rowcount)) + else: + return (rowparser.unpack_row(reader, desc) for i in range(rowcount)) cdef class TupleRowParser(RowParser): @@ -55,9 +61,11 @@ cdef class TupleRowParser(RowParser): Parse a single returned row into a tuple of objects: (obj1, ..., objN) + If CE (Column encryption) policy is enabled - use unpack_ce_row(), + otherwise use unpack_row() """ - cpdef unpack_row(self, BytesIOReader reader, ParseDesc desc): + cpdef unpack_ce_row(self, BytesIOReader reader, ParseDesc desc): assert desc.rowsize >= 0 cdef Buffer buf @@ -73,9 +81,9 @@ cdef class TupleRowParser(RowParser): # Deserialize bytes to python object deserializer = desc.deserializers[i] - coldesc = desc.coldescs[i] - uses_ce = ce_policy and ce_policy.contains_column(coldesc) try: + coldesc = desc.coldescs[i] + uses_ce = ce_policy.contains_column(coldesc) if uses_ce: col_type = ce_policy.column_type(coldesc) decrypted_bytes = ce_policy.decrypt(coldesc, to_bytes(&buf)) @@ -84,11 +92,36 @@ cdef class TupleRowParser(RowParser): val = from_binary(deserializer, &newbuf, desc.protocol_version) else: val = from_binary(deserializer, &buf, desc.protocol_version) + # Insert new object into tuple + tuple_set(res, i, val) + except Exception as e: + raise DriverException('Failed decoding result column "%s" of type %s: %s' % (desc.colnames[i], + desc.coltypes[i].cql_parameterized_type(), + str(e))) + + return res + + cpdef unpack_row(self, BytesIOReader reader, ParseDesc desc): + assert desc.rowsize >= 0 + + cdef Buffer buf + cdef Py_ssize_t i, rowsize = desc.rowsize + cdef Deserializer deserializer + cdef tuple res = tuple_new(desc.rowsize) + + for i in range(rowsize): + # Read the next few bytes + get_buf(reader, &buf) + + # Deserialize bytes to python object + deserializer = desc.deserializers[i] + try: + val = from_binary(deserializer, &buf, desc.protocol_version) + # Insert new object into tuple 
+ tuple_set(res, i, val) except Exception as e: raise DriverException('Failed decoding result column "%s" of type %s: %s' % (desc.colnames[i], desc.coltypes[i].cql_parameterized_type(), str(e))) - # Insert new object into tuple - tuple_set(res, i, val) return res diff --git a/cassandra/row_parser.pyx b/cassandra/row_parser.pyx index 88277a4593..1308f5b2ce 100644 --- a/cassandra/row_parser.pyx +++ b/cassandra/row_parser.pyx @@ -44,7 +44,11 @@ def make_recv_results_rows(ColumnParser colparser): reader.buf_ptr = reader.buf reader.pos = 0 rowcount = read_int(reader) - for i in range(rowcount): - rowparser.unpack_row(reader, desc) + if desc.column_encryption_policy: + for i in range(rowcount): + rowparser.unpack_ce_row(reader, desc) + else: + for i in range(rowcount): + rowparser.unpack_row(reader, desc) return recv_results_rows From 49b8047d13c230edfb522e1c2c465be44be871ef Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Mon, 5 Jan 2026 11:07:44 +0200 Subject: [PATCH 3/4] (improvement)Optimize column_encryption_policy checks: tests Add tests, respond to review feedback on added tests. Signed-off-by: Yaniv Kaul --- tests/unit/test_protocol.py | 135 ++++++++++++++- .../unit/test_protocol_decode_optimization.py | 155 ------------------ 2 files changed, 133 insertions(+), 157 deletions(-) delete mode 100644 tests/unit/test_protocol_decode_optimization.py diff --git a/tests/unit/test_protocol.py b/tests/unit/test_protocol.py index 9704811239..ea12fa7b5a 100644 --- a/tests/unit/test_protocol.py +++ b/tests/unit/test_protocol.py @@ -12,22 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import io import unittest from unittest.mock import Mock from cassandra import ProtocolVersion, UnsupportedOperation +from cassandra.cqltypes import Int32Type, UTF8Type from cassandra.protocol import ( PrepareMessage, QueryMessage, ExecuteMessage, UnsupportedOperation, _PAGING_OPTIONS_FLAG, _WITH_SERIAL_CONSISTENCY_FLAG, _PAGE_SIZE_FLAG, _WITH_PAGING_STATE_FLAG, - BatchMessage + BatchMessage, + ResultMessage, RESULT_KIND_ROWS ) from cassandra.query import BatchType -from cassandra.marshal import uint32_unpack +from cassandra.marshal import uint32_unpack, int32_pack from cassandra.cluster import ContinuousPagingOptions import pytest +from cassandra.policies import ColDesc class MessageTest(unittest.TestCase): @@ -189,3 +193,130 @@ def test_batch_message_with_keyspace(self): (b'\x00\x03',), (b'\x00\x00\x00\x80',), (b'\x00\x02',), (b'ks',)) ) + +class ResultTest(unittest.TestCase): + """ + Tests to verify the optimization of column_encryption_policy checks + in recv_results_rows. The optimization checks if the policy exists once + per result message, avoiding the redundant 'column_encryption_policy and ...' + check for every value. + """ + + def _create_mock_result_metadata(self): + """Create mock result metadata for testing""" + return [ + ('keyspace1', 'table1', 'col1', Int32Type), + ('keyspace1', 'table1', 'col2', UTF8Type), + ] + + def _create_mock_result_message(self): + """Create a mock result message with data""" + msg = ResultMessage(kind=RESULT_KIND_ROWS) + msg.column_metadata = self._create_mock_result_metadata() + msg.recv_results_metadata = Mock() + msg.recv_row = Mock(side_effect=[ + [int32_pack(42), b'hello'], + [int32_pack(100), b'world'], + ]) + return msg + + def _create_mock_stream(self): + """Create a mock stream for reading rows""" + # Pack rowcount (2 rows) + data = int32_pack(2) + return io.BytesIO(data) + + def test_decode_without_encryption_policy(self): + """ + Test that decoding works correctly without column encryption policy. 
+ This should use the optimized simple path. + """ + msg = self._create_mock_result_message() + f = self._create_mock_stream() + + msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, None) + + # Verify results + self.assertEqual(len(msg.parsed_rows), 2) + self.assertEqual(msg.parsed_rows[0][0], 42) + self.assertEqual(msg.parsed_rows[0][1], 'hello') + self.assertEqual(msg.parsed_rows[1][0], 100) + self.assertEqual(msg.parsed_rows[1][1], 'world') + + def test_decode_with_encryption_policy_no_encrypted_columns(self): + """ + Test that decoding works with encryption policy when no columns are encrypted. + """ + msg = self._create_mock_result_message() + f = self._create_mock_stream() + + # Create mock encryption policy that has no encrypted columns + mock_policy = Mock() + mock_policy.contains_column = Mock(return_value=False) + + msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy) + + # Verify results + self.assertEqual(len(msg.parsed_rows), 2) + self.assertEqual(msg.parsed_rows[0][0], 42) + self.assertEqual(msg.parsed_rows[0][1], 'hello') + + # Verify contains_column was called for each value (but policy existence check happens once) + # Should be called 4 times (2 rows × 2 columns) + self.assertEqual(mock_policy.contains_column.call_count, 4) + + def test_decode_with_encryption_policy_with_encrypted_column(self): + """ + Test that decoding works with encryption policy when one column is encrypted. 
+ """ + msg = self._create_mock_result_message() + f = self._create_mock_stream() + + # Create mock encryption policy where first column is encrypted + mock_policy = Mock() + def contains_column_side_effect(col_desc): + return col_desc.col == 'col1' + mock_policy.contains_column = Mock(side_effect=contains_column_side_effect) + mock_policy.column_type = Mock(return_value=Int32Type) + mock_policy.decrypt = Mock(side_effect=lambda col_desc, val: val) + + msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy) + + # Verify results + self.assertEqual(len(msg.parsed_rows), 2) + self.assertEqual(msg.parsed_rows[0][0], 42) + self.assertEqual(msg.parsed_rows[0][1], 'hello') + + # Verify contains_column was called for each value (but policy existence check happens once) + # Should be called 4 times (2 rows × 2 columns) + self.assertEqual(mock_policy.contains_column.call_count, 4) + + # Verify decrypt was called for each encrypted value (2 rows * 1 encrypted column) + self.assertEqual(mock_policy.decrypt.call_count, 2) + + def test_optimization_efficiency(self): + """ + Verify that the optimization checks policy existence once per result message. + The key optimization is checking 'if column_encryption_policy:' once, + rather than 'column_encryption_policy and ...' for every value. + """ + msg = self._create_mock_result_message() + + # Create more rows to make the check pattern clear + msg.recv_row = Mock(side_effect=[ + [int32_pack(i), f'text{i}'.encode()] for i in range(100) + ]) + + # Create mock stream with 100 rows + f = io.BytesIO(int32_pack(100)) + + mock_policy = Mock() + mock_policy.contains_column = Mock(return_value=False) + + msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy) + + # With optimization: policy existence checked once, contains_column called per value + # = 100 rows * 2 columns = 200 calls to contains_column + # The key is we avoid checking 'column_encryption_policy and ...' 
200 times + self.assertEqual(mock_policy.contains_column.call_count, 200, + "contains_column should be called for each value when policy exists") diff --git a/tests/unit/test_protocol_decode_optimization.py b/tests/unit/test_protocol_decode_optimization.py deleted file mode 100644 index e0fd81fe3e..0000000000 --- a/tests/unit/test_protocol_decode_optimization.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright DataStax, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import io -import unittest -from unittest.mock import Mock - -from cassandra import ProtocolVersion -from cassandra.cqltypes import Int32Type, UTF8Type -from cassandra.marshal import int32_pack -from cassandra.policies import ColDesc -from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS - - -class DecodeOptimizationTest(unittest.TestCase): - """ - Tests to verify the optimization of column_encryption_policy checks - in recv_results_rows. The optimization checks if the policy exists once - per result message, avoiding the redundant 'column_encryption_policy and ...' - check for every value. 
- """ - - def _create_mock_result_metadata(self): - """Create mock result metadata for testing""" - return [ - ('keyspace1', 'table1', 'col1', Int32Type), - ('keyspace1', 'table1', 'col2', UTF8Type), - ] - - def _create_mock_result_message(self): - """Create a mock result message with data""" - msg = ResultMessage(kind=RESULT_KIND_ROWS) - msg.column_metadata = self._create_mock_result_metadata() - msg.recv_results_metadata = Mock() - msg.recv_row = Mock(side_effect=[ - [int32_pack(42), b'hello'], - [int32_pack(100), b'world'], - ]) - return msg - - def _create_mock_stream(self): - """Create a mock stream for reading rows""" - # Pack rowcount (2 rows) - data = int32_pack(2) - return io.BytesIO(data) - - def test_decode_without_encryption_policy(self): - """ - Test that decoding works correctly without column encryption policy. - This should use the optimized simple path. - """ - msg = self._create_mock_result_message() - f = self._create_mock_stream() - - msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, None) - - # Verify results - self.assertEqual(len(msg.parsed_rows), 2) - self.assertEqual(msg.parsed_rows[0][0], 42) - self.assertEqual(msg.parsed_rows[0][1], 'hello') - self.assertEqual(msg.parsed_rows[1][0], 100) - self.assertEqual(msg.parsed_rows[1][1], 'world') - - def test_decode_with_encryption_policy_no_encrypted_columns(self): - """ - Test that decoding works with encryption policy when no columns are encrypted. 
- """ - msg = self._create_mock_result_message() - f = self._create_mock_stream() - - # Create mock encryption policy that has no encrypted columns - mock_policy = Mock() - mock_policy.contains_column = Mock(return_value=False) - - msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy) - - # Verify results - self.assertEqual(len(msg.parsed_rows), 2) - self.assertEqual(msg.parsed_rows[0][0], 42) - self.assertEqual(msg.parsed_rows[0][1], 'hello') - - # Verify contains_column was called for each value (but policy existence check happens once) - # Should be called 4 times (2 rows × 2 columns) - self.assertEqual(mock_policy.contains_column.call_count, 4) - - def test_decode_with_encryption_policy_with_encrypted_column(self): - """ - Test that decoding works with encryption policy when one column is encrypted. - """ - msg = self._create_mock_result_message() - f = self._create_mock_stream() - - # Create mock encryption policy where first column is encrypted - mock_policy = Mock() - def contains_column_side_effect(col_desc): - return col_desc.col == 'col1' - mock_policy.contains_column = Mock(side_effect=contains_column_side_effect) - mock_policy.column_type = Mock(return_value=Int32Type) - mock_policy.decrypt = Mock(side_effect=lambda col_desc, val: val) - - msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy) - - # Verify results - self.assertEqual(len(msg.parsed_rows), 2) - self.assertEqual(msg.parsed_rows[0][0], 42) - self.assertEqual(msg.parsed_rows[0][1], 'hello') - - # Verify contains_column was called for each value (but policy existence check happens once) - # Should be called 4 times (2 rows × 2 columns) - self.assertEqual(mock_policy.contains_column.call_count, 4) - - # Verify decrypt was called for each encrypted value (2 rows * 1 encrypted column) - self.assertEqual(mock_policy.decrypt.call_count, 2) - - def test_optimization_efficiency(self): - """ - Verify that the optimization checks policy existence once per result message. 
- The key optimization is checking 'if column_encryption_policy:' once, - rather than 'column_encryption_policy and ...' for every value. - """ - msg = self._create_mock_result_message() - - # Create more rows to make the check pattern clear - msg.recv_row = Mock(side_effect=[ - [int32_pack(i), f'text{i}'.encode()] for i in range(100) - ]) - - # Create mock stream with 100 rows - f = io.BytesIO(int32_pack(100)) - - mock_policy = Mock() - mock_policy.contains_column = Mock(return_value=False) - - msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy) - - # With optimization: policy existence checked once, contains_column called per value - # = 100 rows * 2 columns = 200 calls to contains_column - # The key is we avoid checking 'column_encryption_policy and ...' 200 times - self.assertEqual(mock_policy.contains_column.call_count, 200, - "contains_column should be called for each value when policy exists") - - -if __name__ == '__main__': - unittest.main() From 132859f30d3f060374f8e55c402ee1836843bb8f Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Sun, 18 Jan 2026 14:47:18 +0200 Subject: [PATCH 4/4] query: split CE bind path Split BoundStatement.bind() into CE and non-CE loops to avoid per-value CE checks when no policy is configured. In the CE loop, use a single uses_ce branch to select type serialization and optional encryption for each column. 
Signed-off-by: Yaniv Kaul --- cassandra/query.py | 65 +++++++++++++++++++++++++++++++--------------- 1 file changed, 44 insertions(+), 21 deletions(-) diff --git a/cassandra/query.py b/cassandra/query.py index 6c6878fdb4..fd165469a2 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -636,28 +636,51 @@ def bind(self, values): self.raw_values = values self.values = [] - for value, col_spec in zip(values, col_meta): - if value is None: - self.values.append(None) - elif value is UNSET_VALUE: - if proto_version >= 4: - self._append_unset_value() + if ce_policy: + for value, col_spec in zip(values, col_meta): + if value is None: + self.values.append(None) + elif value is UNSET_VALUE: + if proto_version >= 4: + self._append_unset_value() + else: + raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version) else: - raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version) - else: - try: - col_desc = ColDesc(col_spec.keyspace_name, col_spec.table_name, col_spec.name) - uses_ce = ce_policy and ce_policy.contains_column(col_desc) - col_type = ce_policy.column_type(col_desc) if uses_ce else col_spec.type - col_bytes = col_type.serialize(value, proto_version) - if uses_ce: - col_bytes = ce_policy.encrypt(col_desc, col_bytes) - self.values.append(col_bytes) - except (TypeError, struct.error) as exc: - actual_type = type(value) - message = ('Received an argument of invalid type for column "%s". 
' - 'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc)) - raise TypeError(message) + try: + col_desc = ColDesc(col_spec.keyspace_name, col_spec.table_name, col_spec.name) + uses_ce = ce_policy.contains_column(col_desc) + if uses_ce: + col_type = ce_policy.column_type(col_desc) + col_bytes = col_type.serialize(value, proto_version) + col_bytes = ce_policy.encrypt(col_desc, col_bytes) + else: + col_type = col_spec.type + col_bytes = col_type.serialize(value, proto_version) + self.values.append(col_bytes) + except (TypeError, struct.error) as exc: + actual_type = type(value) + message = ('Received an argument of invalid type for column "%s". ' + 'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc)) + raise TypeError(message) + else: + for value, col_spec in zip(values, col_meta): + if value is None: + self.values.append(None) + elif value is UNSET_VALUE: + if proto_version >= 4: + self._append_unset_value() + else: + raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version) + else: + try: + col_type = col_spec.type + col_bytes = col_type.serialize(value, proto_version) + self.values.append(col_bytes) + except (TypeError, struct.error) as exc: + actual_type = type(value) + message = ('Received an argument of invalid type for column "%s". ' + 'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc)) + raise TypeError(message) if proto_version >= 4: diff = col_meta_len - len(self.values)