From c16dc590c40f6895b5bf1810fa768c85b5d9ed18 Mon Sep 17 00:00:00 2001
From: Manuel Candales
Date: Tue, 3 Feb 2026 19:03:02 -0500
Subject: [PATCH] Update [ghstack-poisoned]

---
 .../apple/metal/runtime/shims/et_metal_ops.mm | 105 ++++++++++++++----
 1 file changed, 82 insertions(+), 23 deletions(-)

diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.mm b/backends/apple/metal/runtime/shims/et_metal_ops.mm
index ec884b50776..c484a03c433 100644
--- a/backends/apple/metal/runtime/shims/et_metal_ops.mm
+++ b/backends/apple/metal/runtime/shims/et_metal_ops.mm
@@ -2492,30 +2492,38 @@ AOTITorchError aoti_torch_mps__linear_fp_act_4bit_weight(

   ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Converted tensor handles to ET tensors");

-  // Log tensor shapes for debugging
-  ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: A shape: [%d, %d], strides: [%d, %d]",
-         a_tensor->dim() > 0 ? (int)a_tensor->sizes()[0] : 0,
-         a_tensor->dim() > 1 ? (int)a_tensor->sizes()[1] : 0,
-         a_tensor->dim() > 0 ? (int)a_tensor->strides()[0] : 0,
-         a_tensor->dim() > 1 ? (int)a_tensor->strides()[1] : 0);
-
-  ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: B shape: [%d, %d]",
-         b_tensor->dim() > 0 ? (int)b_tensor->sizes()[0] : 0,
-         b_tensor->dim() > 1 ? (int)b_tensor->sizes()[1] : 0);
-
-  ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: S shape: [%d, %d], Z shape: [%d, %d]",
-         s_tensor->dim() > 0 ? (int)s_tensor->sizes()[0] : 0,
-         s_tensor->dim() > 1 ? (int)s_tensor->sizes()[1] : 0,
-         z_tensor->dim() > 0 ? (int)z_tensor->sizes()[0] : 0,
-         z_tensor->dim() > 1 ? (int)z_tensor->sizes()[1] : 0);
-
-  // Validate tensor dimensions
+  // Validate A tensor: ndim, dtype, contiguity
   if (a_tensor->dim() != 2) {
-    ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: A tensor must be 2-D, got %d", (int)a_tensor->dim());
+    ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect A to be 2D tensor, got %d", (int)a_tensor->dim());
+    return Error::InvalidArgument;
+  }
+  auto a_dtype = a_tensor->scalar_type();
+  if (a_dtype != exec_aten::ScalarType::Float &&
+      a_dtype != exec_aten::ScalarType::BFloat16) {
+    ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect A to be 32-bit or 16-bit float tensor, got dtype %d", (int)a_dtype);
+    return Error::InvalidArgument;
+  }
+  // Check A is contiguous (stride[1] == 1 and stride[0] == size[1])
+  if (a_tensor->strides()[1] != 1 || a_tensor->strides()[0] != a_tensor->sizes()[1]) {
+    ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect A to be contiguous, strides=[%lld, %lld]",
+           (long long)a_tensor->strides()[0], (long long)a_tensor->strides()[1]);
     return Error::InvalidArgument;
   }
+
+
+  // Validate B tensor: ndim, dtype (uint8), contiguity
   if (b_tensor->dim() != 2) {
-    ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: B tensor must be 2-D, got %d", (int)b_tensor->dim());
+    ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect B to be 2D tensor, got %d", (int)b_tensor->dim());
+    return Error::InvalidArgument;
+  }
+  if (b_tensor->scalar_type() != exec_aten::ScalarType::Byte) {
+    ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect B to be uint8 tensor, got dtype %d", (int)b_tensor->scalar_type());
+    return Error::InvalidArgument;
+  }
+  // Check B is contiguous
+  if (b_tensor->strides()[1] != 1 || b_tensor->strides()[0] != b_tensor->sizes()[1]) {
+    ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect B to be contiguous, strides=[%lld, %lld]",
+           (long long)b_tensor->strides()[0], (long long)b_tensor->strides()[1]);
     return Error::InvalidArgument;
   }
@@ -2523,18 +2531,67 @@ AOTITorchError aoti_torch_mps__linear_fp_act_4bit_weight(
   int32_t M = static_cast<int32_t>(a_tensor->sizes()[0]);
   int32_t K = static_cast<int32_t>(a_tensor->sizes()[1]);
   int32_t N = static_cast<int32_t>(b_tensor->sizes()[0]);
+  constexpr int nbit = 4;

   ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: M=%d, K=%d, N=%d, group_size=%lld", M, K, N, group_size);

-  // Validate alignment requirements
+  // B.size(1) should be (K / 8) * nbit for 4-bit packing
+  int64_t expected_b_size1 = (K / 8) * nbit;
+  if (b_tensor->sizes()[1] != expected_b_size1) {
+    ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect B.size(1) == %lld, got %lld",
+           (long long)expected_b_size1, (long long)b_tensor->sizes()[1]);
+    return Error::InvalidArgument;
+  }
+
+  // Validate K alignment
   if (K % 8 != 0) {
-    ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: K (%d) must be divisible by 8", K);
+    ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect K to be multiple of 8, got %d", K);
     return Error::InvalidArgument;
   }
+
+  // Validate N alignment
   if (N % 4 != 0) {
-    ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: N (%d) must be divisible by 4", N);
+    ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect N to be multiple of 4, got M=%d, N=%d", M, N);
+    return Error::InvalidArgument;
+  }
+
+  // Validate S tensor: 2D with S.size(0) == N, contiguous
+  if (s_tensor->dim() != 2 || s_tensor->sizes()[0] != N) {
+    ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect S to be 2D tensor with shape [%d, :], got dim=%d, size[0]=%lld",
+           N, (int)s_tensor->dim(), (long long)s_tensor->sizes()[0]);
+    return Error::InvalidArgument;
+  }
+  if (s_tensor->strides()[1] != 1 || s_tensor->strides()[0] != s_tensor->sizes()[1]) {
+    ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect S to be contiguous, strides=[%lld, %lld]",
+           (long long)s_tensor->strides()[0], (long long)s_tensor->strides()[1]);
+    return Error::InvalidArgument;
+  }
+
+  // Validate Z tensor: 2D with Z.size(0) == N, contiguous
+  if (z_tensor->dim() != 2 || z_tensor->sizes()[0] != N) {
+    ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect Z to be 2D tensor with shape [%d, :], got dim=%d, size[0]=%lld",
+           N, (int)z_tensor->dim(), (long long)z_tensor->sizes()[0]);
     return Error::InvalidArgument;
   }
+  if (z_tensor->strides()[1] != 1 || z_tensor->strides()[0] != z_tensor->sizes()[1]) {
+    ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect Z to be contiguous, strides=[%lld, %lld]",
+           (long long)z_tensor->strides()[0], (long long)z_tensor->strides()[1]);
+    return Error::InvalidArgument;
+  }
+
+  // Log shapes and strides for all tensors
+  ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: A tensor shape=[%lld, %lld], strides=[%lld, %lld]",
+         (long long)a_tensor->sizes()[0], (long long)a_tensor->sizes()[1],
+         (long long)a_tensor->strides()[0], (long long)a_tensor->strides()[1]);
+  ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: B tensor shape=[%lld, %lld], strides=[%lld, %lld]",
+         (long long)b_tensor->sizes()[0], (long long)b_tensor->sizes()[1],
+         (long long)b_tensor->strides()[0], (long long)b_tensor->strides()[1]);
+  ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: S tensor shape=[%lld, %lld], strides=[%lld, %lld]",
+         (long long)s_tensor->sizes()[0], (long long)s_tensor->sizes()[1],
+         (long long)s_tensor->strides()[0], (long long)s_tensor->strides()[1]);
+  ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Z tensor shape=[%lld, %lld], strides=[%lld, %lld]",
+         (long long)z_tensor->sizes()[0], (long long)z_tensor->sizes()[1],
+         (long long)z_tensor->strides()[0], (long long)z_tensor->strides()[1]);

   // Determine data type
   int32_t dtype = static_cast<int32_t>(a_tensor->scalar_type());
@@ -2652,6 +2709,7 @@ AOTITorchError aoti_torch_mps__linear_fp_act_4bit_weight(
   // Dispatch based on kernel type (matching torchao dispatch patterns)
   if (use_qmv_fast) {
     // dispatch_qmv_fast: dispatchThreadgroups with grid (M, (N+7)/8, 1), group (32, 2, 1)
+    ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Dispatching kernel: %s", kernel_name.c_str());
     kernel_func->dispatchThreadgroups(
         M,            // gridX
         (N + 7) / 8,  // gridY
@@ -2661,6 +2719,7 @@ AOTITorchError aoti_torch_mps__linear_fp_act_4bit_weight(
         1,            // gridZ
         32,           // threadsX
         2,            // threadsY
         1);           // threadsZ
   } else {
     // dispatch_mm_Mr1xNr4_per_TG: dispatchThreads with grid (N/4 * 32, 1, M), group (32, 1, 1)
+    ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Dispatching kernel: %s", kernel_name.c_str());
     uint64_t grid_dims[3] = {static_cast<uint64_t>(N / 4 * 32), 1, static_cast<uint64_t>(M)};
     uint64_t group_dims[3] = {32, 1, 1};
     kernel_func->dispatchArrayWithGroupSize(grid_dims, 3, group_dims, 3);