105 changes: 82 additions & 23 deletions backends/apple/metal/runtime/shims/et_metal_ops.mm
@@ -2492,49 +2492,106 @@ AOTITorchError aoti_torch_mps__linear_fp_act_4bit_weight(

ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Converted tensor handles to ET tensors");

// Validate A tensor: ndim, dtype, contiguity
if (a_tensor->dim() != 2) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: A tensor must be 2-D, got %d", (int)a_tensor->dim());
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect A to be 2D tensor, got %d", (int)a_tensor->dim());
return Error::InvalidArgument;
}
auto a_dtype = a_tensor->scalar_type();
if (a_dtype != exec_aten::ScalarType::Float &&
a_dtype != exec_aten::ScalarType::BFloat16) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect A to be 32-bit or 16-bit float tensor, got dtype %d", (int)a_dtype);
return Error::InvalidArgument;
}
// Check A is contiguous (stride[1] == 1 and stride[0] == size[1])
if (a_tensor->strides()[1] != 1 || a_tensor->strides()[0] != a_tensor->sizes()[1]) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect A to be contiguous, strides=[%lld, %lld]",
(long long)a_tensor->strides()[0], (long long)a_tensor->strides()[1]);
return Error::InvalidArgument;
}
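// Illustrative (assumed values, not from this PR): a contiguous [M, K] = [4, 64]
// float tensor has strides [64, 1] and passes the check above; a transposed view
// with sizes [4, 64] and strides [1, 4] has strides()[1] != 1 and is rejected.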


// Validate B tensor: ndim, dtype (uint8), contiguity
if (b_tensor->dim() != 2) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: B tensor must be 2-D, got %d", (int)b_tensor->dim());
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect B to be 2D tensor, got %d", (int)b_tensor->dim());
return Error::InvalidArgument;
}
if (b_tensor->scalar_type() != exec_aten::ScalarType::Byte) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect B to be uint8 tensor, got dtype %d", (int)b_tensor->scalar_type());
return Error::InvalidArgument;
}
// Check B is contiguous
if (b_tensor->strides()[1] != 1 || b_tensor->strides()[0] != b_tensor->sizes()[1]) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect B to be contiguous, strides=[%lld, %lld]",
(long long)b_tensor->strides()[0], (long long)b_tensor->strides()[1]);
return Error::InvalidArgument;
}

// Get dimensions: A is [M, K], B is [N, K/2] (4-bit packed, 2 values per byte)
int32_t M = static_cast<int32_t>(a_tensor->sizes()[0]);
int32_t K = static_cast<int32_t>(a_tensor->sizes()[1]);
int32_t N = static_cast<int32_t>(b_tensor->sizes()[0]);
constexpr int nbit = 4;

ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: M=%d, K=%d, N=%d, group_size=%lld", M, K, N, group_size);

// Validate alignment requirements
// B.size(1) should be (K / 8) * nbit for 4-bit packing
int64_t expected_b_size1 = (K / 8) * nbit;
if (b_tensor->sizes()[1] != expected_b_size1) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect B.size(1) == %lld, got %lld",
(long long)expected_b_size1, (long long)b_tensor->sizes()[1]);
return Error::InvalidArgument;
}
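// Worked example (assumed K, for illustration only): with K = 4096,
// expected_b_size1 = (4096 / 8) * 4 = 2048 bytes per row, i.e. K / 2 bytes,
// since two 4-bit values are packed into each byte.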

// Validate K alignment
if (K % 8 != 0) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: K (%d) must be divisible by 8", K);
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect K to be multiple of 8, got %d", K);
return Error::InvalidArgument;
}

// Validate N alignment
if (N % 4 != 0) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: N (%d) must be divisible by 4", N);
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect N to be multiple of 4, got M=%d, N=%d", M, N);
return Error::InvalidArgument;
}
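// Plausible rationale (assumed, not stated in the surrounding code): K % 8 == 0
// matches the packing granularity above (8 4-bit values span exactly 4 bytes),
// and N % 4 == 0 matches the Nr4 output tiling of the fallback kernel dispatched
// below.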

// Validate S tensor: 2D with S.size(0) == N, contiguous
if (s_tensor->dim() != 2 || s_tensor->sizes()[0] != N) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect S to be 2D tensor with shape [%d, :], got dim=%d, size[0]=%lld",
N, (int)s_tensor->dim(), (long long)s_tensor->sizes()[0]);
return Error::InvalidArgument;
}
if (s_tensor->strides()[1] != 1 || s_tensor->strides()[0] != s_tensor->sizes()[1]) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect S to be contiguous, strides=[%lld, %lld]",
(long long)s_tensor->strides()[0], (long long)s_tensor->strides()[1]);
return Error::InvalidArgument;
}

// Validate Z tensor: 2D with Z.size(0) == N, contiguous
if (z_tensor->dim() != 2 || z_tensor->sizes()[0] != N) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect Z to be 2D tensor with shape [%d, :], got dim=%d, size[0]=%lld",
N, (int)z_tensor->dim(), (long long)z_tensor->sizes()[0]);
return Error::InvalidArgument;
}
if (z_tensor->strides()[1] != 1 || z_tensor->strides()[0] != z_tensor->sizes()[1]) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect Z to be contiguous, strides=[%lld, %lld]",
(long long)z_tensor->strides()[0], (long long)z_tensor->strides()[1]);
return Error::InvalidArgument;
}
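// Note: only size(0) of S and Z is validated here. Their second dimension is
// presumably K / group_size (one scale and one zero point per quantization
// group), but that is an assumption; it is not enforced by these checks.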

// Log shapes and strides for all tensors
ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: A tensor shape=[%lld, %lld], strides=[%lld, %lld]",
(long long)a_tensor->sizes()[0], (long long)a_tensor->sizes()[1],
(long long)a_tensor->strides()[0], (long long)a_tensor->strides()[1]);
ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: B tensor shape=[%lld, %lld], strides=[%lld, %lld]",
(long long)b_tensor->sizes()[0], (long long)b_tensor->sizes()[1],
(long long)b_tensor->strides()[0], (long long)b_tensor->strides()[1]);
ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: S tensor shape=[%lld, %lld], strides=[%lld, %lld]",
(long long)s_tensor->sizes()[0], (long long)s_tensor->sizes()[1],
(long long)s_tensor->strides()[0], (long long)s_tensor->strides()[1]);
ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Z tensor shape=[%lld, %lld], strides=[%lld, %lld]",
(long long)z_tensor->sizes()[0], (long long)z_tensor->sizes()[1],
(long long)z_tensor->strides()[0], (long long)z_tensor->strides()[1]);

// Determine data type
int32_t dtype = static_cast<int32_t>(a_tensor->scalar_type());
@@ -2652,6 +2709,7 @@ AOTITorchError aoti_torch_mps__linear_fp_act_4bit_weight(
// Dispatch based on kernel type (matching torchao dispatch patterns)
if (use_qmv_fast) {
// dispatch_qmv_fast: dispatchThreadgroups with grid (M, (N+7)/8, 1), group (32, 2, 1)
ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Dispatching kernel: %s", kernel_name.c_str());
kernel_func->dispatchThreadgroups(
M, // gridX
(N + 7) / 8, // gridY
1, // gridZ
32, // threadsX
2, // threadsY
1); // threadsZ
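// Illustrative arithmetic (assumed values): with M = 1 and N = 4096 this
// launches a (1, (4096 + 7) / 8, 1) = (1, 512, 1) grid of 32x2x1 threadgroups,
// i.e. 512 * 64 = 32768 threads for the whole GEMV.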
} else {
// dispatch_mm_Mr1xNr4_per_TG: dispatchThreads with grid (N/4 * 32, 1, M), group (32, 1, 1)
ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Dispatching kernel: %s", kernel_name.c_str());
uint64_t grid_dims[3] = {static_cast<uint64_t>(N / 4 * 32), 1, static_cast<uint64_t>(M)};
uint64_t group_dims[3] = {32, 1, 1};
kernel_func->dispatchArrayWithGroupSize(grid_dims, 3, group_dims, 3);
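// Illustrative arithmetic (assumed values): with M = 2 and N = 4096,
// grid_dims = {4096 / 4 * 32, 1, 2} = {32768, 1, 2} and group_dims = {32, 1, 1},
// so each output row is covered by N / 4 = 1024 threadgroups of 32 threads.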