diff --git a/.lintrunner.toml b/.lintrunner.toml index 2cffdf9c053..a1eec7204bb 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -510,6 +510,7 @@ include_patterns = [ 'backends/arm/vgf/**/*.py', 'backends/arm/tosa/**/*.py', 'backends/arm/ethosu/**/*.py', + 'backends/arm/operators/**/*.py', ] exclude_patterns = ['third-party/**', '**/third-party/**'] command = [ diff --git a/backends/arm/operators/node_visitor.py b/backends/arm/operators/node_visitor.py index 43ef4cd1793..2a8cad85e31 100644 --- a/backends/arm/operators/node_visitor.py +++ b/backends/arm/operators/node_visitor.py @@ -2,7 +2,6 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - """Provide utilities to register and apply TOSA node visitors. Use this module to construct and serialize TOSA operators from FX nodes. diff --git a/backends/arm/operators/op_index_tensor.py b/backends/arm/operators/op_index_tensor.py index 5bcbfb00ca6..e28f332e10a 100644 --- a/backends/arm/operators/op_index_tensor.py +++ b/backends/arm/operators/op_index_tensor.py @@ -30,8 +30,7 @@ def __init__(self, *args): super().__init__(*args) def _get_tensor_info(self, tensor: Node): - """ - Consolidates obtaining name, dtype and shape into a common function + """Consolidates obtaining name, dtype and shape into a common function reconciling access based on the type of the input. Args: @@ -103,22 +102,25 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - """ + """Flatten index tensors into a single index for value lookup. + This approach uses the fact that all indexing tensors are incremented - simultaneously and they essentially act as a map along the corresponding - dimensions of the values tensor. - Note: that this does not hold true when slicing or ellipsis ops - are involved as such they are not currently not supported. + simultaneously and act as a map along the corresponding dimensions of + the values tensor. + + Note: this does not hold when slicing or ellipsis ops are involved, so + those cases are not currently supported. + + As such, this approach flattens out the values tensor and constructs a + flattened index obtained by flattening the index tensors, multiplying + them by the relevant stride, and accumulating them. - As such this approach flattens out the values tensor and - constructs a flattened out index obtained by flattening out the - index tensors, multiplying them by the relevant stride and accumulating them. + This approach suffers from the fact that we are taking a number of index + tensors of type int32 and applying multiplications and additions. - This approach suffers from the fact that we are taking a number of index tensors of - type int32 and applying multiplications and additions. + If the number of total elements in the values tensor exceeds int32 + limits, then this approach falls apart. - If the number of total elements in the values tensor exceeds int32 limits - then this approach falls apart. """ validate_same_dtype(self.target, [inputs[0], output]) diff --git a/backends/arm/operators/op_neg.py b/backends/arm/operators/op_neg.py index 53dc7444837..46d9d987e19 100644 --- a/backends/arm/operators/op_neg.py +++ b/backends/arm/operators/op_neg.py @@ -26,9 +26,10 @@ def get_negate_zero_points(node: torch.fx.Node, is_int8: bool) -> tuple[int, int]: - """ - Returns (input1_zp, output_zp) for TOSA NEGATE. + """Returns (input1_zp, output_zp) for TOSA NEGATE. + Must be zero for non-int8 types. + """ if is_int8: return ( diff --git a/backends/arm/operators/op_permute.py b/backends/arm/operators/op_permute.py index 90e2b9d2dc4..fe6ca686c52 100644 --- a/backends/arm/operators/op_permute.py +++ b/backends/arm/operators/op_permute.py @@ -23,14 +23,14 @@ def permutation_vector_to_matrix(permutation_vector: list[int]) -> torch.Tensor: - """ - Converts a permutation vector of length N to a NxN matrix that describes the same permutation. - for example: - (1,0,2) - -> - [0 1 0] - |1 0 0| - [0 0 1] + """Convert a permutation vector of length N to an N x N matrix. + + Example: + (1, 0, 2) -> + [0 1 0] + [1 0 0] + [0 0 1] + """ N = len(permutation_vector) P = torch.zeros(N, N) @@ -40,13 +40,14 @@ def permutation_vector_to_matrix(permutation_vector: list[int]) -> torch.Tensor: def permutation_matrix_to_vector(permutation_matrix: torch.Tensor) -> list[int]: - """ - Converts a NxN permutation matrix to a permutation vector of length N that describes the same permutation. - [0 1 0] - |1 0 0| - [0 0 1] - -> - (1,0,2) + """Convert an N x N permutation matrix to a permutation vector of length N. + + Example: + [0 1 0] + [1 0 0] + [0 0 1] + -> (1, 0, 2) + """ N = len(permutation_matrix) if N != len(permutation_matrix[0]): diff --git a/backends/arm/operators/op_to_dim_order_copy.py b/backends/arm/operators/op_to_dim_order_copy.py index 1b3153a9715..8e9a8f6e8f3 100644 --- a/backends/arm/operators/op_to_dim_order_copy.py +++ b/backends/arm/operators/op_to_dim_order_copy.py @@ -21,13 +21,13 @@ @register_node_visitor class ToDimOrderCopyVisitor(NodeVisitor): - """ - Implement the type cast functionality of _to_dim_order_copy. + """Implement the type cast functionality of _to_dim_order_copy. Other features like setting of the dim_order or moving a tensor to a different device are not supported. Also note that the node should not be quantized. + """ target = "dim_order_ops._to_dim_order_copy.default" diff --git a/backends/arm/operators/op_tosa_conv3d.py b/backends/arm/operators/op_tosa_conv3d.py index e0a8d2ef6ac..c033314f9a7 100644 --- a/backends/arm/operators/op_tosa_conv3d.py +++ b/backends/arm/operators/op_tosa_conv3d.py @@ -1,8 +1,7 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. +# Copyright 2023-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - """Provide a visitor for lowering 3D convolution to TOSA (INT/FP).""" from executorch.backends.arm.operators.node_visitor import register_node_visitor diff --git a/backends/arm/operators/op_tosa_depthwise_conv2d.py b/backends/arm/operators/op_tosa_depthwise_conv2d.py index 7df58cc31b3..8b6fbf70e50 100644 --- a/backends/arm/operators/op_tosa_depthwise_conv2d.py +++ b/backends/arm/operators/op_tosa_depthwise_conv2d.py @@ -2,7 +2,6 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - """Provide a visitor for lowering 2D depthwise convolution to TOSA (INT/FP).""" import tosa_serializer as ts diff --git a/backends/arm/operators/op_tosa_gather.py b/backends/arm/operators/op_tosa_gather.py index acd20d43f4b..c242d351c06 100644 --- a/backends/arm/operators/op_tosa_gather.py +++ b/backends/arm/operators/op_tosa_gather.py @@ -22,13 +22,13 @@ @register_node_visitor class GatherVisitor(NodeVisitor): - """ - Lowers backend TOSA dialect `tosa.GATHER.default`. + """Lowers backend TOSA dialect `tosa.GATHER.default`. Expected signature (per TOSA): values: [N, K, C] (rank 3) indices: [N, W] (rank 2, int32) output: [N, W, C] (rank 3) + """ target = "tosa.GATHER.default" diff --git a/backends/arm/operators/op_tosa_matmul.py b/backends/arm/operators/op_tosa_matmul.py index 18b090283fe..a64f8c48032 100644 --- a/backends/arm/operators/op_tosa_matmul.py +++ b/backends/arm/operators/op_tosa_matmul.py @@ -3,7 +3,6 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - """Provide a visitor for lowering batched matmul (BMM) to TOSA.""" from typing import Any, List diff --git a/backends/arm/operators/op_tosa_transpose.py b/backends/arm/operators/op_tosa_transpose.py index 4585f02c613..8f768005477 100644 --- a/backends/arm/operators/op_tosa_transpose.py +++ b/backends/arm/operators/op_tosa_transpose.py @@ -24,10 +24,11 @@ @register_node_visitor class TransposeVisitor(NodeVisitor): - """ - This node visitor targets the tosa::TRANSPOSE op defined in the - TOSA backend dialect. Used when switching between tosa_dim_orders. - Inserts a TOSA TRANSPOSE. + """Lower the TOSA TRANSPOSE op when switching dim orders. + + Targets the tosa::TRANSPOSE op in the TOSA backend dialect and inserts a + TOSA TRANSPOSE. + """ target = "tosa.TRANSPOSE.default" diff --git a/backends/arm/operators/operator_validation_utils.py b/backends/arm/operators/operator_validation_utils.py index 6b3271ee8e4..e71bbe7b286 100644 --- a/backends/arm/operators/operator_validation_utils.py +++ b/backends/arm/operators/operator_validation_utils.py @@ -153,7 +153,9 @@ def validate_valid_dtype( def validate_cf_extension(op_name: str, tosa_spec: TosaSpecification) -> None: - """Ensure that the requested control-flow operator is supported by the active TOSA spec.""" + """Ensure that the requested control-flow operator is supported by the + active TOSA spec. + """ if not isinstance(tosa_spec, Tosa_1_00): raise ValueError( f"Got TOSA version {tosa_spec.version}, that does not support extensions." diff --git a/backends/arm/operators/ops_binary.py b/backends/arm/operators/ops_binary.py index 95d9f6d1cb6..37a8fd226ff 100644 --- a/backends/arm/operators/ops_binary.py +++ b/backends/arm/operators/ops_binary.py @@ -26,7 +26,9 @@ def binary_operator_factory( bw_target: str, tosa_op, attr_builder: Callable[[Any], None] ): - """Creates and registers NodeVisitors for operators that have two inputs and map directly to a TOSA op.""" + """Creates and registers NodeVisitors for operators that have two inputs and + map directly to a TOSA op. + """ class BinaryOperator(NodeVisitor): target = bw_target diff --git a/backends/arm/operators/ops_identity.py b/backends/arm/operators/ops_identity.py index 27bfaddf901..41c0f4c6bff 100644 --- a/backends/arm/operators/ops_identity.py +++ b/backends/arm/operators/ops_identity.py @@ -24,9 +24,8 @@ def identity_operator_factory(identity_target: str): - """ - Creates and registers NodeVisitors for operators that map directly - to a TOSA IDENTITY op. + """Creates and registers NodeVisitors for operators that map directly to a + TOSA IDENTITY op. """ class IdentityOperatorVisitor(NodeVisitor):