Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion example/ck_tile/38_block_scale_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@ if(CK_USE_OCP_FP8)
endif()

list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -Wno-global-constructors) # use global constructors to add kernel instances
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS "SHELL: -mllvm -enable-noalias-to-md-conversion=1")
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1")

# When GPU_TARGETS matches "gfx95", append -DCK_TILE_EIGHTWARP_SUP to
# EXAMPLE_GEMM_COMPILE_OPTIONS so the macro is defined for sources built with
# those options.
if(GPU_TARGETS MATCHES "gfx95")
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_EIGHTWARP_SUP)
endif()

if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
set(EXE_NAME tile_example_gemm_quant)
Expand Down
38 changes: 21 additions & 17 deletions example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@

#include "run_gemm_quant_example.inc"

#if defined(CK_TILE_EIGHTWARP_SUP)
template <typename T>
using GemmConfig = GemmConfigEightWarps<T>;
template <typename T>
using GemmConfigPrefill = GemmConfigPreshuffleBEightWarps<T>;
#else
template <typename T>
using GemmConfig = GemmConfigABQuantPrefill<T>;

template <typename T>
using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Prefill<T>;

// template <typename T>
// using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Decode<T>;
using GemmConfigPrefill = GemmConfigPreshuffleB_ABQuant_Prefill<T>;
#endif

static auto _ = []() {
auto& lut = get_kernel_lut();
Expand All @@ -23,7 +26,7 @@ static auto _ = []() {
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
return run_gemm_example_prec_type<GemmConfigABQuantPrefill<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
Expand Down Expand Up @@ -53,7 +56,7 @@ static auto _ = []() {
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
return run_gemm_example_prec_type<GemmConfigABQuantPrefill<ck_tile::bf8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
Expand Down Expand Up @@ -83,7 +86,7 @@ static auto _ = []() {
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::fp8_t>,
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
Expand All @@ -98,7 +101,7 @@ static auto _ = []() {
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::fp8_t>,
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
Expand All @@ -113,7 +116,7 @@ static auto _ = []() {
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::bf8_t>,
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::bf8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
Expand All @@ -128,7 +131,7 @@ static auto _ = []() {
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::bf8_t>,
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::bf8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
Expand Down Expand Up @@ -173,7 +176,7 @@ static auto _ = []() {
ck_tile::pk_fp4_t,
ck_tile::half_t,
float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_fp4_raw_t>,
return run_gemm_example_prec_type<GemmConfigABQuantPrefill<ck_tile::pk_fp4_raw_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
Expand All @@ -188,11 +191,12 @@ static auto _ = []() {
ck_tile::pk_fp4_t,
ck_tile::half_t,
float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::pk_fp4_raw_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
return run_gemm_example_prec_type<
GemmConfigPreshuffleB_ABQuant_Prefill<ck_tile::pk_fp4_raw_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
return 0;
}();
23 changes: 23 additions & 0 deletions example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,29 @@ struct GemmConfigABQuantPrefill : public GemmConfigQuantPrefill<PrecType>
static constexpr bool TransposeC = true;
};

/// ABQuant GEMM example configuration laid out for eight warps per block
/// (M_Warp * N_Warp * K_Warp == 4 * 2 * 1). Any setting not overridden here
/// is inherited from GemmConfigABQuantPrefill<PrecType>.
template <typename PrecType>
struct GemmConfigEightWarps : public GemmConfigABQuantPrefill<PrecType>
{
    // Warp counts along M, N and K.
    static constexpr ck_tile::index_t M_Warp = 4;
    static constexpr ck_tile::index_t N_Warp = 2; // NWarps == 2 for ping-pong!
    static constexpr ck_tile::index_t K_Warp = 1;

    // Block tile extents. N and K are expressed per warp and then scaled by
    // the matching warp count; K additionally depends on the element size.
    static constexpr ck_tile::index_t M_Tile = 192;
    static constexpr ck_tile::index_t N_Tile = N_Warp * 128;
    static constexpr ck_tile::index_t K_Tile = (128 / sizeof(PrecType)) * K_Warp;

    // Remaining overrides of the base configuration.
    static constexpr int kBlockPerCu = 1;
    static constexpr bool kPadK      = false;
    static constexpr bool TransposeC = true;
};

/// Variant of GemmConfigEightWarps<PrecType> that sets both the PreshuffleB
/// and DoubleSmemBuffer flags to true; everything else is inherited.
template <typename PrecType>
struct GemmConfigPreshuffleBEightWarps : public GemmConfigEightWarps<PrecType>
{
    static constexpr bool DoubleSmemBuffer = true;
    static constexpr bool PreshuffleB      = true;
};

template <typename PrecType>
struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill<PrecType>
{
Expand Down
82 changes: 51 additions & 31 deletions example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,20 @@ template <typename GemmConfig,
float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr bool transpose_c =
GemmConfig::TransposeC; // QuantMode == ck_tile::QuantType::ABQuantGrouped;

// Use automatically determined compute type from
using ComputeDataType = void;
constexpr bool IS_FP8BLOCKSCALE =
QuantMode == ck_tile::QuantType::ABQuantGrouped && BQuantGroupSize::kN == 128 &&
(std::is_same_v<typename TypeConfig::ADataType, ck_tile::fp8_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t>) &&
(std::is_same_v<typename TypeConfig::BDataType, ck_tile::fp8_t> ||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::bf8_t>);
constexpr bool transpose_c = GemmConfig::TransposeC;
constexpr bool eight_warps =
IS_FP8BLOCKSCALE && BQuantGroupSize::kN == 128 &&
(GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp == 8) &&
GemmConfig::K_Warp_Tile == 128;

using ComputeDataType =
std::conditional_t<IS_FP8BLOCKSCALE, typename TypeConfig::ADataType, void>;

using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
Expand Down Expand Up @@ -71,19 +80,22 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
ComputeDataType>;

// Base pipeline selection based on quant mode and preshuffle settings
using BaseGemmPipeline = std::conditional_t<
GemmConfig::PreshuffleB == true,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::APreshuffleQuant == true,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
std::conditional_t<
QuantMode == ck_tile::QuantType::ABQuantGrouped,
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>>>>>;
constexpr auto base_gemm_pipeline = []() {
if constexpr(eight_warps)
return ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>{};
else if constexpr(GemmConfig::PreshuffleB)
return ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>{};
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped &&
GemmConfig::APreshuffleQuant)
return ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>{};
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped || IS_FP8BLOCKSCALE)
return ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>{};
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
return ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>{};
else
return ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>{};
}();
using BaseGemmPipeline = std::decay_t<decltype(base_gemm_pipeline)>;

const ck_tile::index_t K_split = ck_tile::integer_least_multiple(args.K, GemmConfig::K_Tile);
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
Expand Down Expand Up @@ -163,10 +175,12 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
ck_tile::MxFp4GemmPipelineAgBgCrCompV3<PipelineProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;

using ABQuantPipeline =
using ABQuantPipeline = std::conditional_t<
eight_warps,
ck_tile::ABQuantGemmPipelineAgBgCrAsync<PipelineProblem>,
std::conditional_t<GemmConfig::DoubleSmemBuffer && GemmConfig::PreshuffleB,
ck_tile::WPABQuantBPipelineAgBgCrV2<PipelineProblem>,
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>;
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;

using GemmPipeline = std::conditional_t<
QuantMode == ck_tile::QuantType::RowColQuant ||
Expand All @@ -185,7 +199,6 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
printf(
"TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, BQuantGroupSize::kN);
}

using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<typename PipelineProblem::ComputeDataType,
typename PipelineProblem::ComputeDataType,
Expand Down Expand Up @@ -1136,20 +1149,27 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
{
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");

if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<GemmConfig,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
QuantMode>(
arg_parser, Row{}, Row{}, Col{}, Col{}, Row{});
if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped &&
!GemmConfig::APreshuffleQuant && BQuantGroupSize::kN == 128 &&
(GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp == 8))
return run_gemm_example_with_layouts<GemmConfig,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
QuantMode>(
arg_parser, Row{}, Col{}, Col{}, Col{}, Row{});
else
return run_gemm_example_with_layouts<GemmConfig,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
QuantMode>(
arg_parser, Row{}, Row{}, Col{}, Col{}, Row{});
}

if constexpr((QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::ABQuantGrouped) &&
!GemmConfig::APreshuffleQuant && !GemmConfig::PreshuffleB)
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
if(a_layout == "R" && b_layout == "R")
{
Expand Down
Loading