From 8b8a9d6ff1d14c1cd05e1e4ab1cce88f3055afc9 Mon Sep 17 00:00:00 2001
From: Hardik Sharma
Date: Wed, 4 Feb 2026 15:42:34 -0800
Subject: [PATCH] Add custom slice scatter op that is guaranteed to be
 in-place. (#17189)

Summary:
Adds an in-place version of slice scatter that updates the `self` arg with
the `src` tensor and returns it. A brief usage sketch follows the diff.

Reviewed By: nitish2112

Differential Revision: D92197496
---
 backends/cadence/aot/ops_registrations.py   | 16 ++++
 backends/cadence/aot/ref_implementations.py | 18 ++++-
 .../aot/tests/test_ref_implementations.py   | 77 +++++++++++++++++++
 3 files changed, 108 insertions(+), 3 deletions(-)

diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index b965082b58f..49732da4ce8 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -667,6 +667,10 @@ def register_fake(
     "quantized_w8a32_gru.out(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_i_scale, Tensor bias_hidden, float b_h_scale, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
+lib.define(
+    "slice_scatter_(Tensor(a!) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a!)"
+)
+
 # Custom ops with aten namespace. Need to specify the lib var as FRAGMENT type as aten library is already defined
 aten_lib = Library("aten", "FRAGMENT")
 
@@ -2857,6 +2861,18 @@ def quantized_w8a32_gru_meta(
     return hidden.new_empty((2, hidden.shape[-1]), dtype=torch.float32)
 
 
+@register_fake("cadence::slice_scatter_")
+def slice_scatter_meta(
+    self: torch.Tensor,
+    src: torch.Tensor,
+    dim: int = 0,
+    start: Optional[int] = None,
+    end: Optional[int] = None,
+    step: int = 1,
+) -> torch.Tensor:
+    return self.new_empty(self.shape, dtype=self.dtype)
+
+
 # Validate that all meta kernels have reference implementations
 # This is called at module import time to catch missing implementations early
 _validate_ref_impl_exists()
diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index 4faeebfa3eb..8a533c80db1 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -6,7 +6,8 @@
 
 # pyre-strict
 
-from typing import Callable, Protocol, TypeVar
+from pathlib import Path
+from typing import Callable, Optional, Protocol, TypeVar
 
 import torch
 import torch.nn as nn
@@ -19,8 +20,6 @@
 try:
     torch.ops.load_library("//executorch/kernels/quantized:custom_ops_generated_lib")
 except (OSError, RuntimeError):
-    # Fall back to path-based loading for CMake/OSS builds
-    from pathlib import Path
 
     custom_libs: list[Path] = list(
         Path(__file__)
@@ -2290,3 +2289,16 @@ def sdpa_bitwise_mask_gen(mask: torch.Tensor, threshold: float) -> torch.Tensor:
     packed_last = last_dim // 8
     # Reshape packed to match mask shape, with last dim packed to bytes
     return packed.view(*original_shape[:-1], packed_last)
+
+
+@impl_tracked(m, "slice_scatter_")
+def slice_scatter_impl(
+    self: torch.Tensor,
+    src: torch.Tensor,
+    dim: int = 0,
+    start: Optional[int] = None,
+    end: Optional[int] = None,
+    step: int = 1,
+) -> torch.Tensor:
+    self[:] = torch.ops.aten.slice_scatter.default(self, src, dim, start, end, step)
+    return self
diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
index ccee27f47a5..e0960522c32 100644
--- a/backends/cadence/aot/tests/test_ref_implementations.py
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -3149,3 +3149,80 @@ def test_quantized_softmax(self) -> None:
             input_tensor.shape,
             "Output shape should match input shape",
         )
+
+    @expand(
+        [
+            # Basic 1D slice_scatter tests
+            ("1d_basic", (10,), (3,), 0, 2, 5, 1),
+            ("1d_with_step", (10,), (2,), 0, 0, 6, 3),
+            ("1d_end_slice", (10,), (3,), 0, 7, 10, 1),
+            # 2D slice_scatter tests
+            ("2d_dim0", (4, 5), (2, 5), 0, 1, 3, 1),
+            ("2d_dim1", (4, 5), (4, 2), 1, 2, 4, 1),
+            ("2d_dim1_with_step", (4, 6), (4, 2), 1, 0, 6, 3),
+            # 3D slice_scatter tests
+            ("3d_dim0", (3, 4, 5), (1, 4, 5), 0, 1, 2, 1),
+            ("3d_dim1", (3, 4, 5), (3, 2, 5), 1, 1, 3, 1),
+            ("3d_dim2", (3, 4, 5), (3, 4, 2), 2, 2, 4, 1),
+        ]
+    )
+    def test_slice_scatter_(
+        self,
+        name: str,
+        self_shape: typing.Tuple[int, ...],
+        src_shape: typing.Tuple[int, ...],
+        dim: int,
+        start: int,
+        end: int,
+        step: int,
+    ) -> None:
+        self_tensor = torch.randn(self_shape)
+        src_tensor = torch.randn(src_shape)
+        self_tensor_copy = self_tensor.clone()
+
+        # Call the in-place slice_scatter_ op
+        torch.ops.cadence.slice_scatter_(self_tensor, src_tensor, dim, start, end, step)
+
+        # Compute expected result using aten slice_scatter
+        expected = torch.ops.aten.slice_scatter.default(
+            self_tensor_copy, src_tensor, dim, start, end, step
+        )
+
+        self.assertEqual(
+            self_tensor.shape,
+            expected.shape,
+            f"Shape mismatch in {name}",
+        )
+        self.assertTrue(
+            torch.allclose(self_tensor, expected, rtol=1e-5, atol=1e-5),
+            f"Values don't match in {name}: got {self_tensor}, expected {expected}",
+        )
+
+    def test_slice_scatter_inplace_mutation(self) -> None:
+        self_tensor = torch.zeros(10)
+        src_tensor = torch.ones(3)
+
+        ref = self_tensor
+
+        torch.ops.cadence.slice_scatter_(self_tensor, src_tensor, 0, 2, 5, 1)
+
+        self.assertTrue(ref is self_tensor)
+
+        expected = torch.tensor([0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0])
+        self.assertTrue(
+            torch.equal(self_tensor, expected),
+            f"Values don't match: got {self_tensor}, expected {expected}",
+        )
+
+    def test_slice_scatter_with_none_start_end(self) -> None:
+        self_tensor = torch.zeros(10)
+        src_tensor = torch.ones(10)
+
+        # When start=None and end=None, the entire slice should be replaced
+        torch.ops.cadence.slice_scatter_(self_tensor, src_tensor, 0, None, None, 1)
+
+        expected = torch.ones(10)
+        self.assertTrue(
+            torch.equal(self_tensor, expected),
+            f"Values don't match: got {self_tensor}, expected {expected}",
+        )
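
For reviewers, a minimal usage sketch of the new op (not part of the patch).
It assumes the module that registers the reference implementations has been
imported so that `torch.ops.cadence.slice_scatter_` is available, and that the
reference kernel above is what gets dispatched in eager mode; the import path
and tensor values below are illustrative assumptions, not taken from the diff.

    import torch

    # Assumed import path; any import that executes
    # backends/cadence/aot/ref_implementations.py registers the op.
    import executorch.backends.cadence.aot.ref_implementations  # noqa: F401

    self_tensor = torch.zeros(6)
    src = torch.ones(2)

    # aten.slice_scatter is functional: it returns a new tensor and
    # leaves `self_tensor` untouched.
    out = torch.ops.aten.slice_scatter.default(self_tensor, src, 0, 1, 3, 1)
    assert torch.equal(out, torch.tensor([0.0, 1.0, 1.0, 0.0, 0.0, 0.0]))
    assert torch.equal(self_tensor, torch.zeros(6))

    # cadence.slice_scatter_ writes the same result into `self_tensor`
    # and returns that very same tensor object.
    ret = torch.ops.cadence.slice_scatter_(self_tensor, src, 0, 1, 3, 1)
    assert ret is self_tensor
    assert torch.equal(self_tensor, torch.tensor([0.0, 1.0, 1.0, 0.0, 0.0, 0.0]))

The behavioral difference the sketch checks (new tensor vs. mutated `self`) is
exactly what the `Tensor(a!)` annotations in the registered schema promise.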