diff --git a/backends/vulkan/runtime/api/containers/Tensor.h b/backends/vulkan/runtime/api/containers/Tensor.h index 301666f45c6..30c243aedd5 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.h +++ b/backends/vulkan/runtime/api/containers/Tensor.h @@ -523,6 +523,16 @@ class vTensor final { return packed_dim_info_.packed_dim; } + /* + * Returns the WHCN index of the fastest moving dimension: outer_packed_dim + * when packed_dim_info_.block_transposed is set, otherwise packed_dim. + * NOTE(review): presumably the stride-1 buffer dim; dim_order_ is not read. + */ + inline int32_t fastest_whcn_dim() const { + return packed_dim_info_.block_transposed ? packed_dim_info_.outer_packed_dim + : packed_dim_info_.packed_dim; + } + inline const PackedDimInfo& packed_dim_info() const { return packed_dim_info_; } diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 2231b78964a..a7c8cffffd1 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -456,6 +456,10 @@ class ComputeGraph final { return values_.at(idx).toConstTensor().packed_dim(); } + inline int32_t fastest_whcn_dim_of(const ValueRef idx) const { + return values_.at(idx).toConstTensor().fastest_whcn_dim(); + } + inline const api::PackedDimInfo& packed_dim_info_of( const ValueRef idx) const { return values_.at(idx).toConstTensor().packed_dim_info(); diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh b/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh index 24f050b694e..16c4112547c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh @@ -70,6 +70,8 @@ bool is_channels_last(const int hashed_layout) { return layout_id(hashed_layout) == CHANNELS_LAST_BUFFER_LAYOUT_ID; } +#define get_fastest_dim(layout) ((layout) & 0xF) + // Extract packed dim info from hashed_layout (bits 16-31) // These match the format created by 
create_hashed_layout() in Tensor.cpp #define get_packed_dim(layout) (((layout) >> 16) & 0xF)