diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.mm b/backends/apple/metal/runtime/shims/et_metal_ops.mm
index 313dc321d37..12b6ec1ae2b 100644
--- a/backends/apple/metal/runtime/shims/et_metal_ops.mm
+++ b/backends/apple/metal/runtime/shims/et_metal_ops.mm
@@ -696,6 +696,112 @@ inline U load_vector(constant T* x, thread U* x_thread) {
   return sum;
 }
 
+template <typename T, typename U, int values_per_thread, int bits>
+inline U load_vector_safe(constant T* x, thread U* x_thread, int N) {
+  static_assert(
+      1 <= bits && bits <= 7,
+      "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 7}");
+
+  U sum = 0;
+
+  if (bits == 1) {
+    for (int i = 0; i < N; i += 8) {
+      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
+          x[i + 6] + x[i + 7];
+      x_thread[i] = x[i];
+      x_thread[i + 1] = x[i + 1] / 2.0f;
+      x_thread[i + 2] = x[i + 2] / 4.0f;
+      x_thread[i + 3] = x[i + 3] / 8.0f;
+      x_thread[i + 4] = x[i + 4] / 16.0f;
+      x_thread[i + 5] = x[i + 5] / 32.0f;
+      x_thread[i + 6] = x[i + 6] / 64.0f;
+      x_thread[i + 7] = x[i + 7] / 128.0f;
+    }
+  }
+
+  else if (bits == 2) {
+    for (int i = 0; i < N; i += 4) {
+      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
+      x_thread[i] = x[i];
+      x_thread[i + 1] = x[i + 1] / 4.0f;
+      x_thread[i + 2] = x[i + 2] / 16.0f;
+      x_thread[i + 3] = x[i + 3] / 64.0f;
+    }
+  }
+
+  else if (bits == 3) {
+    for (int i = 0; i < N; i += 8) {
+      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
+          x[i + 6] + x[i + 7];
+
+      x_thread[i] = x[i];
+      x_thread[i + 1] = x[i + 1] / 8.0f;
+      x_thread[i + 2] = x[i + 2] / 64.0f;
+      x_thread[i + 3] = x[i + 3] / 2.0f;
+      x_thread[i + 4] = x[i + 4] / 16.0f;
+      x_thread[i + 5] = x[i + 5] / 128.0f;
+      x_thread[i + 6] = x[i + 6] / 4.0f;
+      x_thread[i + 7] = x[i + 7] / 32.0f;
+    }
+  }
+
+  else if (bits == 4) {
+    for (int i = 0; i < N; i += 4) {
+      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
+      x_thread[i] = x[i];
+      x_thread[i + 1] = x[i + 1] / 16.0f;
+      x_thread[i + 2] = x[i + 2] / 256.0f;
+      x_thread[i + 3] = x[i + 3] / 4096.0f;
+    }
+  }
+
+  else if (bits == 5) {
+    for (int i = 0; i < N; i += 8) {
+      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
+          x[i + 6] + x[i + 7];
+      x_thread[i] = x[i];
+      x_thread[i + 1] = x[i + 1] / 32.0f;
+      x_thread[i + 2] = x[i + 2] / 4.0f;
+      x_thread[i + 3] = x[i + 3] / 128.0f;
+      x_thread[i + 4] = x[i + 4] / 16.0f;
+      x_thread[i + 5] = x[i + 5] / 2.0f;
+      x_thread[i + 6] = x[i + 6] / 64.0f;
+      x_thread[i + 7] = x[i + 7] / 8.0f;
+    }
+  }
+
+  else if (bits == 6) {
+    for (int i = 0; i < N; i += 4) {
+      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
+      x_thread[i] = x[i];
+      x_thread[i + 1] = x[i + 1] / 64.0f;
+      x_thread[i + 2] = x[i + 2] / 16.0f;
+      x_thread[i + 3] = x[i + 3] / 4.0f;
+    }
+  }
+
+  else if (bits == 7) {
+    for (int i = 0; i < N; i += 8) {
+      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
+          x[i + 6] + x[i + 7];
+      x_thread[i] = x[i];
+      x_thread[i + 1] = x[i + 1] / 128.0f;
+      x_thread[i + 2] = x[i + 2] / 64.0f;
+      x_thread[i + 3] = x[i + 3] / 32.0f;
+      x_thread[i + 4] = x[i + 4] / 16.0f;
+      x_thread[i + 5] = x[i + 5] / 8.0f;
+      x_thread[i + 6] = x[i + 6] / 4.0f;
+      x_thread[i + 7] = x[i + 7] / 2.0f;
+    }
+  }
+
+  // Zero-pad the tail so callers can run the full-width qdot over x_thread.
+  for (int i = N; i < values_per_thread; i++) {
+    x_thread[i] = 0;
+  }
+
+  return sum;
+}
+
 template <typename U, int values_per_thread, int bits>
 inline U qdot(
     constant uint8_t* w,
@@ -838,6 +944,149 @@ inline U qdot(
   return scale * accum + sum * bias;
 }
 
+template <typename U, int values_per_thread, int bits>
+inline U qdot_safe(
+    constant uint8_t* w,
+    const thread U* x_thread,
+    U scale,
+    U bias,
+    U sum,
+    int N) {
+  static_assert(
+      1 <= bits && bits <= 7,
+      "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 7}");
+
+  U accum = 0;
+
+  if (bits == 1) {
+    for (int i = 0; i < (N / 8); i++) {
+      x_thread += 8 * i;
+
+      accum +=
+          (x_thread[0] * (w[i] & 0x01) +
+           x_thread[1] * (w[i] & 0x02) +
+           x_thread[2] * (w[i] & 0x04) +
+           x_thread[3] * (w[i] & 0x08) +
+           x_thread[4] * (w[i] & 0x10) +
+           x_thread[5] * (w[i] & 0x20) +
+           x_thread[6] * (w[i] & 0x40) +
+           x_thread[7] * (w[i] & 0x80));
+    }
+  }
+
+  else if (bits == 2) {
+    for (int i = 0; i < (N / 4); i++) {
+      accum +=
+          (x_thread[4 * i] * (w[i] & 0x03) +
+           x_thread[4 * i + 1] * (w[i] & 0x0c) +
+           x_thread[4 * i + 2] * (w[i] & 0x30) +
+           x_thread[4 * i + 3] * (w[i] & 0xc0));
+    }
+  }
+
+  else if (bits == 3) {
+    for (int i = 0; i < (N / 8); i++) {
+      x_thread += 8 * i;
+      w += 3 * i;
+
+      accum += (w[0] & 0x07) * x_thread[0];
+      accum += (w[0] & 0x38) * x_thread[1];
+      accum += (w[0] & 0xc0) * x_thread[2];
+      accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);
+
+      accum += (w[1] & 0x0e) * x_thread[3];
+      accum += (w[1] & 0x70) * x_thread[4];
+      accum += (w[1] & 0x80) * x_thread[5];
+      accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);
+
+      accum += (w[2] & 0x1c) * x_thread[6];
+      accum += (w[2] & 0xe0) * x_thread[7];
+    }
+  }
+
+  else if (bits == 4) {
+    constant uint16_t* ws = (constant uint16_t*)w;
+    for (int i = 0; i < (N / 4); i++) {
+      accum +=
+          (x_thread[4 * i] * (ws[i] & 0x000f) +
+           x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
+           x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
+           x_thread[4 * i + 3] * (ws[i] & 0xf000));
+    }
+  }
+
+  else if (bits == 5) {
+    for (int i = 0; i < (N / 8); i++) {
+      x_thread += 8 * i;
+      w += 5 * i;
+
+      accum += (w[0] & 0x1f) * x_thread[0];
+      accum += (w[0] & 0xe0) * x_thread[1];
+
+      accum += (w[1] & 0x03) * (x_thread[1] * 256.0f);
+      accum += (w[1] & 0x7c) * x_thread[2];
+      accum += (w[1] & 0x80) * x_thread[3];
+
+      accum += (w[2] & 0x0f) * (x_thread[3] * 256.0f);
+      accum += (w[2] & 0xf0) * x_thread[4];
+
+      accum += (w[3] & 0x01) * (x_thread[4] * 256.0f);
+      accum += (w[3] & 0x3e) * x_thread[5];
+      accum += (w[3] & 0xc0) * x_thread[6];
+
+      accum += (w[4] & 0x07) * (x_thread[6] * 256.0f);
+      accum += (w[4] & 0xf8) * x_thread[7];
+    }
+  }
+
+  else if (bits == 6) {
+    for (int i = 0; i < (N / 4); i++) {
+      x_thread += 4 * i;
+      w += 3 * i;
+
+      accum += (w[0] & 0x3f) * x_thread[0];
+
+      accum += (w[0] & 0xc0) * x_thread[1];
+      accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);
+
+      accum += (w[1] & 0xf0) * x_thread[2];
+      accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);
+
+      accum += (w[2] & 0xfc) * x_thread[3];
+    }
+  }
+
+  else if (bits == 7) {
+    for (int i = 0; i < (N / 8); i++) {
+      x_thread += 8 * i;
+      w += 7 * i;
+
+      accum += (w[0] & 0x7f) * x_thread[0];
+      accum += (w[0] & 0x80) * x_thread[1];
+
+      accum += (w[1] & 0x3f) * (x_thread[1] * 256.0f);
+      accum += (w[1] & 0xc0) * x_thread[2];
+
+      accum += (w[2] & 0x1f) * (x_thread[2] * 256.0f);
+      accum += (w[2] & 0xe0) * x_thread[3];
+
+      accum += (w[3] & 0x0f) * (x_thread[3] * 256.0f);
+      accum += (w[3] & 0xf0) * x_thread[4];
+
+      accum += (w[4] & 0x07) * (x_thread[4] * 256.0f);
+      accum += (w[4] & 0xf8) * x_thread[5];
+
+      accum += (w[5] & 0x03) * (x_thread[5] * 256.0f);
+      accum += (w[5] & 0xfc) * x_thread[6];
+
+      accum += (w[6] & 0x01) * (x_thread[6] * 256.0f);
+      accum += (w[6] & 0xfe) * x_thread[7];
+    }
+  }
+
+  return scale * accum + sum * bias;
+}
+
 template <typename T, int group_size, int bits>
 [[kernel]] void qmv_fast(
     constant T* x [[buffer(0)]],
@@ -942,6 +1191,202 @@ inline U qdot(
 INSTANTIATE_QMV_FAST_DTYPE(bfloat);
 #endif
 
+/**
+ * qmv_impl.metal - handles generic N (any even N, not just N % 8 == 0)
+ */
+
+template <typename T, int group_size, int bits>
+[[kernel]] void qmv_impl(
+    constant T* x [[buffer(0)]],
+    constant uchar* w [[buffer(1)]],
+    constant T* scales [[buffer(2)]],
+    constant T* biases [[buffer(3)]],
+    device T* y [[buffer(4)]],
+    constant uint3 &sizes [[buffer(5)]], // M, K, N
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  const int in_vec_size = static_cast<int>(sizes.y); // K
+  const int out_vec_size = static_cast<int>(sizes.z); // N
+
+  constexpr int num_simdgroups = 2;
+  constexpr int results_per_simdgroup = 4;
+  constexpr int packs_per_thread = (bits == 1 || bits == 2) ? 1 : 2;
+  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
+  constexpr int pack_factor = bits == 1 ? 16 : power_of_2_bits ? 32 / bits : bits == 6 ? 4 : 8;
+  constexpr int bytes_per_pack = bits == 1 ? 2 : power_of_2_bits ? 4 : bits == 6 ? 3 : bits;
+
+  constexpr int values_per_thread = pack_factor * packs_per_thread;
+  constexpr int block_size = values_per_thread * SIMD_SIZE;
+  constexpr int scale_step_per_thread = group_size / values_per_thread;
+
+  constant uint8_t* ws = (constant uint8_t*)w;
+
+  typedef float U;
+
+  thread U x_thread[values_per_thread];
+  thread U result[results_per_simdgroup] = {0};
+
+  // Adjust positions
+  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
+  const int in_vec_size_g = in_vec_size / group_size;
+  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
+      simd_gid * results_per_simdgroup;
+  const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);
+
+  if (out_row >= out_vec_size) {
+    return;
+  }
+
+  // In this case we need to properly guard all our reads because there isn't
+  // even 1 tile in the matrix
+  if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
+    ws +=
+        out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
+    scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
+    biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
+    x += tid.x * in_vec_size + simd_lid * values_per_thread;
+    y += tid.x * out_vec_size + out_row;
+
+    int k = 0;
+    for (; k < in_vec_size - block_size; k += block_size) {
+      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
+
+      for (int row = 0; out_row + row < out_vec_size; row++) {
+        auto wl = (constant uint8_t*)(ws + row * in_vec_size_w);
+        constant T* sl = scales + row * in_vec_size_g;
+        constant T* bl = biases + row * in_vec_size_g;
+
+        U s = sl[0];
+        U b = bl[0];
+        result[row] +=
+            qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
+      }
+
+      ws += block_size * bytes_per_pack / pack_factor;
+      scales += block_size / group_size;
+      biases += block_size / group_size;
+      x += block_size;
+    }
+    const int remaining = clamp(
+        static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
+        0,
+        values_per_thread);
+    if (remaining > 0) {
+      U sum = load_vector_safe<T, U, values_per_thread, bits>(
+          x, x_thread, remaining);
+
+      for (int row = 0; out_row + row < out_vec_size; row++) {
+        auto wl = (constant uint8_t*)(ws + row * in_vec_size_w);
+        constant T* sl = scales + row * in_vec_size_g;
+        constant T* bl = biases + row * in_vec_size_g;
+
+        U s = sl[0];
+        U b = bl[0];
+        result[row] +=
+            qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
+      }
+    }
+
+    for (int row = 0; out_row + row < out_vec_size; row++) {
+      result[row] = simd_sum(result[row]);
+      if (simd_lid == 0) {
+        y[row] = static_cast<T>(result[row]);
+      }
+    }
+  }
+
+  // In this case the last tile is moved back to redo some output values
+  else {
+    ws += used_out_row * in_vec_size_w +
+        simd_lid * packs_per_thread * bytes_per_pack;
+    scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
+    biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
+    x += tid.x * in_vec_size + simd_lid * values_per_thread;
+    y += tid.x * out_vec_size + used_out_row;
+
+    int k = 0;
+    for (; k < in_vec_size - block_size; k += block_size) {
+      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
+
+      for (int row = 0; row < results_per_simdgroup; row++) {
+        auto wl = (constant uint8_t*)(ws + row * in_vec_size_w);
+        constant T* sl = scales + row * in_vec_size_g;
+        constant T* bl = biases + row * in_vec_size_g;
+
+        U s = sl[0];
+        U b = bl[0];
+        result[row] +=
+            qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
+      }
+
+      ws += block_size * bytes_per_pack / pack_factor;
+      scales += block_size / group_size;
+      biases += block_size / group_size;
+      x += block_size;
+    }
+    const int remaining = clamp(
+        static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
+        0,
+        values_per_thread);
+    if (remaining > 0) {
+      U sum = load_vector_safe<T, U, values_per_thread, bits>(
+          x, x_thread, remaining);
+
+      for (int row = 0; row < results_per_simdgroup; row++) {
+        auto wl = (constant uint8_t*)(ws + row * in_vec_size_w);
+        constant T* sl = scales + row * in_vec_size_g;
+        constant T* bl = biases + row * in_vec_size_g;
+
+        U s = sl[0];
+        U b = bl[0];
+        result[row] += qdot_safe<U, values_per_thread, bits>(
+            wl, x_thread, s, b, sum, remaining);
+      }
+    }
+    for (int row = 0; row < results_per_simdgroup; row++) {
+      result[row] = simd_sum(result[row]);
+      if (simd_lid == 0) {
+        y[row] = static_cast<T>(result[row]);
+      }
+    }
+  }
+}
+
+#define INSTANTIATE_QMV_IMPL(DTYPE, GSIZE, NBIT)                                 \
+  template [[host_name("qmv_impl_" #NBIT "bit_" #GSIZE "_" #DTYPE)]] kernel void \
+  qmv_impl<DTYPE, GSIZE, NBIT>(                                                  \
+      constant DTYPE * A [[buffer(0)]],                                          \
+      constant uchar * B [[buffer(1)]],                                          \
+      constant DTYPE * scales_ptr [[buffer(2)]],                                 \
+      constant DTYPE * zeros_ptr [[buffer(3)]],                                  \
+      device DTYPE * output_data [[buffer(4)]],                                  \
+      constant uint3 & sizes [[buffer(5)]],                                      \
+      uint3 thread_index [[thread_position_in_grid]],                            \
+      uint simd_gid [[simdgroup_index_in_threadgroup]],                          \
+      uint tid_in_simdgroup [[thread_index_in_simdgroup]])
+
+#define INSTANTIATE_QMV_IMPL_DTYPE_GSIZE(DTYPE, GSIZE) \
+  INSTANTIATE_QMV_IMPL(DTYPE, GSIZE, 1);               \
+  INSTANTIATE_QMV_IMPL(DTYPE, GSIZE, 2);               \
+  INSTANTIATE_QMV_IMPL(DTYPE, GSIZE, 3);               \
+  INSTANTIATE_QMV_IMPL(DTYPE, GSIZE, 4);               \
+  INSTANTIATE_QMV_IMPL(DTYPE, GSIZE, 5);               \
+  INSTANTIATE_QMV_IMPL(DTYPE, GSIZE, 6);               \
+  INSTANTIATE_QMV_IMPL(DTYPE, GSIZE, 7);
+
+#define INSTANTIATE_QMV_IMPL_DTYPE(DTYPE)       \
+  INSTANTIATE_QMV_IMPL_DTYPE_GSIZE(DTYPE, 32);  \
+  INSTANTIATE_QMV_IMPL_DTYPE_GSIZE(DTYPE, 64);  \
+  INSTANTIATE_QMV_IMPL_DTYPE_GSIZE(DTYPE, 128); \
+  INSTANTIATE_QMV_IMPL_DTYPE_GSIZE(DTYPE, 256);
+
+INSTANTIATE_QMV_IMPL_DTYPE(float);
+INSTANTIATE_QMV_IMPL_DTYPE(half);
+#if __METAL_VERSION__ >= 310
+INSTANTIATE_QMV_IMPL_DTYPE(bfloat);
+#endif
+
 )";
 }
 
@@ -2550,8 +2995,8 @@ AOTITorchError aoti_torch_mps__linear_fp_act_4bit_weight(
   }
 
   // Validate N alignment
-  if (N % 4 != 0) {
-    ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect N to be multiple of 4, got M=%d, N=%d", M, N);
+  if (N % 4 != 0 && M != 1) {
+    ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect N to be multiple of 4 when M != 1, got M=%d, N=%d", M, N);
     return Error::InvalidArgument;
   }
 
@@ -2629,11 +3074,16 @@ AOTITorchError aoti_torch_mps__linear_fp_act_4bit_weight(
   // Select kernel based on dimensions (matching torchao's get_shader_func_and_dispatch)
   std::string kernel_name;
   bool use_qmv_fast = (M == 1 && N % 8 == 0 && K % 512 == 0);
+  bool use_qmv_impl = (M == 1 && !use_qmv_fast);
 
   if (use_qmv_fast) {
-    // Use optimized qmv_fast kernel for M=1 case
+    // Use optimized qmv_fast kernel for M=1 case with aligned dimensions
    kernel_name = "qmv_fast_4bit_" + std::to_string(group_size) + "_" + type_str;
     ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Using qmv_fast kernel: %s", kernel_name.c_str());
+  } else if (use_qmv_impl) {
+    // Use qmv_impl kernel for M=1 case with generic N (handles any even N)
+    kernel_name = "qmv_impl_4bit_" + std::to_string(group_size) + "_" + type_str;
+    ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Using qmv_impl kernel: %s", kernel_name.c_str());
   } else {
     // Use general int4pack_mm kernel
     kernel_name = "int4pack_mm_" + std::to_string(group_size) + "_" + type_str;
@@ -2717,8 +3167,8 @@ AOTITorchError aoti_torch_mps__linear_fp_act_4bit_weight(
   ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: All arguments set, dispatching");
 
   // Dispatch based on kernel type (matching torchao dispatch patterns)
-  if (use_qmv_fast) {
-    // dispatch_qmv_fast: dispatchThreadgroups with grid (M, (N+7)/8, 1), group (32, 2, 1)
+  if (use_qmv_fast || use_qmv_impl) {
+    // dispatch_qmv: dispatchThreadgroups with grid (M, (N+7)/8, 1), group (32, 2, 1)
     ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Dispatching kernel: %s", kernel_name.c_str());
     kernel_func->dispatchThreadgroups(
         M, // gridX
diff --git a/backends/apple/metal/tests/test_modules.py b/backends/apple/metal/tests/test_modules.py
index 9d934bb352b..52e65f8c15d 100644
--- a/backends/apple/metal/tests/test_modules.py
+++ b/backends/apple/metal/tests/test_modules.py
@@ -274,6 +274,32 @@ def forward(self, x: torch.Tensor):
     "skip": not TORCHAO_AVAILABLE,
 }
 
+
+# -------------------------------------------------------------------------
+class LinearInt4_QMV_IMPL(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(640, 8198, bias=True)
+
+    def forward(self, x: torch.Tensor):
+        return self.linear(x)
+
+
+MODULE_REGISTRY["linear_int4_qmv_impl"] = {
+    "model_class": LinearInt4_QMV_IMPL,
+    "input_shapes": [(1, 640)],
+    "description": "Linear int4 quantization dispatching to qmv_impl",
+    "qlinear": "fpa4w",
+    "qlinear_group_size": 32,
+    "compare_to_unquantized": False,
+    "atol_float32": 5e-2,
+    "rtol_float32": 5e-2,
+    "atol_bfloat16": 1e-1,
+    "rtol_bfloat16": 1e-1,
+    "skip": not TORCHAO_AVAILABLE,
+}
+
+
 # -------------------------------------------------------------------------
 # Convolution Modules
 # -------------------------------------------------------------------------
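Note (illustrative, not part of the patch): the qmv_impl path relies on the same trick as qmv_fast -- load_vector/load_vector_safe pre-divide each activation by the power of two that its packed weight field is shifted by, so qdot/qdot_safe can multiply by the masked but unshifted weight bits directly. A minimal CPU-side C++ sketch of the 4-bit case follows; all names are local to this sketch, and the kernel's per-group scale and bias (scale * accum + sum * bias) are omitted here.

// Emulates the 4-bit load_vector / qdot pair on the CPU: dividing the
// activations by 16, 256, 4096 cancels the nibble shifts, so the dot product
// can use the masked 16-bit word (ws & 0x00f0, ws & 0x0f00, ...) as-is.
#include <cstdint>
#include <cstdio>

int main() {
  // Four activations and four 4-bit weights (3, 7, 1, 15) packed
  // little-nibble-first into two bytes.
  float x[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  uint8_t packed[2] = {static_cast<uint8_t>(3 | (7 << 4)),
                       static_cast<uint8_t>(1 | (15 << 4))};

  // load_vector-style preprocessing of the activations.
  float x_thread[4] = {x[0], x[1] / 16.0f, x[2] / 256.0f, x[3] / 4096.0f};

  // qdot-style accumulation over the packed pair read as a 16-bit word.
  uint16_t ws = static_cast<uint16_t>(packed[0]) |
      static_cast<uint16_t>(packed[1] << 8);
  float accum = x_thread[0] * (ws & 0x000f) + x_thread[1] * (ws & 0x00f0) +
      x_thread[2] * (ws & 0x0f00) + x_thread[3] * (ws & 0xf000);

  // Reference: plain dot product with the unpacked 4-bit weights.
  float reference = 1.0f * 3 + 2.0f * 7 + 3.0f * 1 + 4.0f * 15;
  std::printf("accum=%.1f reference=%.1f\n", accum, reference); // both 80.0
  return 0;
}

For the new linear_int4_qmv_impl test shape (M=1, K=640, N=8198), use_qmv_fast is false (N % 8 != 0 and K % 512 != 0) but M == 1, so the host selects qmv_impl and dispatches a (1, (8198 + 7) / 8 = 1025, 1) threadgroup grid with 32x2 threads per group; each simdgroup produces four output rows, and the out_row >= out_vec_size guard drops the two rows of overshoot.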