Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions backends/cadence/aot/ops_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,10 @@ def register_fake(
"quantized_w8a32_gru.out(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_i_scale, Tensor bias_hidden, float b_h_scale, *, Tensor(a!) out) -> Tensor(a!)"
)

# In-place slice_scatter: writes `src` into the slice of `self` selected by
# (dim, start, end, step); `Tensor(a!)` marks `self` as mutated and aliased
# by the return value.
lib.define(
    "slice_scatter_(Tensor(a!) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a!)"
)


# Custom ops with aten namespace. Need to specify the lib var as FRAGMENT type as aten library is already defined
aten_lib = Library("aten", "FRAGMENT")
Expand Down Expand Up @@ -2857,6 +2861,18 @@ def quantized_w8a32_gru_meta(
return hidden.new_empty((2, hidden.shape[-1]), dtype=torch.float32)


@register_fake("cadence::slice_scatter_")
def slice_scatter_meta(
    self: torch.Tensor,
    src: torch.Tensor,
    dim: int = 0,
    start: Optional[int] = None,
    end: Optional[int] = None,
    step: int = 1,
) -> torch.Tensor:
    """Fake (meta) kernel for cadence::slice_scatter_.

    slice_scatter never changes the shape or dtype of `self`, so shape
    propagation only needs an uninitialized tensor matching `self`.
    """
    result = self.new_empty(self.shape, dtype=self.dtype)
    return result


# Validate that all meta kernels have reference implementations
# This is called at module import time to catch missing implementations early
_validate_ref_impl_exists()
18 changes: 15 additions & 3 deletions backends/cadence/aot/ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

# pyre-strict

from typing import Callable, Protocol, TypeVar
from pathlib import Path
from typing import Callable, Optional, Protocol, TypeVar

import torch
import torch.nn as nn
Expand All @@ -19,8 +20,6 @@
try:
torch.ops.load_library("//executorch/kernels/quantized:custom_ops_generated_lib")
except (OSError, RuntimeError):
# Fall back to path-based loading for CMake/OSS builds
from pathlib import Path

custom_libs: list[Path] = list(
Path(__file__)
Expand Down Expand Up @@ -2290,3 +2289,16 @@ def sdpa_bitwise_mask_gen(mask: torch.Tensor, threshold: float) -> torch.Tensor:
packed_last = last_dim // 8
# Reshape packed to match mask shape, with last dim packed to bytes
return packed.view(*original_shape[:-1], packed_last)


@impl_tracked(m, "slice_scatter_")
def slice_scatter_impl(
    self: torch.Tensor,
    src: torch.Tensor,
    dim: int = 0,
    start: Optional[int] = None,
    end: Optional[int] = None,
    step: int = 1,
) -> torch.Tensor:
    """Reference implementation of the in-place cadence::slice_scatter_ op.

    Computes the out-of-place aten slice_scatter result, then copies it back
    into `self` so the mutation is visible to callers.

    Returns:
        `self`, mutated in place.
    """
    scattered = torch.ops.aten.slice_scatter.default(self, src, dim, start, end, step)
    self.copy_(scattered)
    return self
77 changes: 77 additions & 0 deletions backends/cadence/aot/tests/test_ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3149,3 +3149,80 @@ def test_quantized_softmax(self) -> None:
input_tensor.shape,
"Output shape should match input shape",
)

@expand(
    [
        # 1D cases
        ("1d_basic", (10,), (3,), 0, 2, 5, 1),
        ("1d_with_step", (10,), (2,), 0, 0, 6, 3),
        ("1d_end_slice", (10,), (3,), 0, 7, 10, 1),
        # 2D cases
        ("2d_dim0", (4, 5), (2, 5), 0, 1, 3, 1),
        ("2d_dim1", (4, 5), (4, 2), 1, 2, 4, 1),
        ("2d_dim1_with_step", (4, 6), (4, 2), 1, 0, 6, 3),
        # 3D cases
        ("3d_dim0", (3, 4, 5), (1, 4, 5), 0, 1, 2, 1),
        ("3d_dim1", (3, 4, 5), (3, 2, 5), 1, 1, 3, 1),
        ("3d_dim2", (3, 4, 5), (3, 4, 2), 2, 2, 4, 1),
    ]
)
def test_slice_scatter_(
    self,
    name: str,
    self_shape: typing.Tuple[int, ...],
    src_shape: typing.Tuple[int, ...],
    dim: int,
    start: int,
    end: int,
    step: int,
) -> None:
    """In-place cadence op must match out-of-place aten slice_scatter."""
    target = torch.randn(self_shape)
    source = torch.randn(src_shape)
    baseline = target.clone()

    # Mutate `target` via the in-place cadence op.
    torch.ops.cadence.slice_scatter_(target, source, dim, start, end, step)

    # Reference result from the out-of-place aten op on the untouched copy.
    expected = torch.ops.aten.slice_scatter.default(
        baseline, source, dim, start, end, step
    )

    self.assertEqual(
        target.shape,
        expected.shape,
        f"Shape mismatch in {name}",
    )
    self.assertTrue(
        torch.allclose(target, expected, rtol=1e-5, atol=1e-5),
        f"Values don't match in {name}: got {target}, expected {expected}",
    )

def test_slice_scatter_inplace_mutation(self) -> None:
    """The op must mutate its first argument rather than returning a fresh tensor."""
    buffer = torch.zeros(10)
    patch = torch.ones(3)
    alias = buffer

    torch.ops.cadence.slice_scatter_(buffer, patch, 0, 2, 5, 1)

    # Same object before and after: the mutation happened in place.
    self.assertTrue(alias is buffer)

    expected = torch.tensor([0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0])
    self.assertTrue(
        torch.equal(buffer, expected),
        f"Values don't match: got {buffer}, expected {expected}",
    )

def test_slice_scatter_with_none_start_end(self) -> None:
    """With start=None and end=None the full extent along `dim` is overwritten."""
    buffer = torch.zeros(10)
    patch = torch.ones(10)

    torch.ops.cadence.slice_scatter_(buffer, patch, 0, None, None, 1)

    expected = torch.ones(10)
    self.assertTrue(
        torch.equal(buffer, expected),
        f"Values don't match: got {buffer}, expected {expected}",
    )
Loading