When using a split-k value above 1, it is not sufficient to count the total number of accumulations, since K / split_k accumulations are performed in float and split_k accumulations are performed in the output data type.
When the output data type is less precise than float, this introduces additional rounding errors, so the tolerance must be relaxed accordingly. The CPU-verification path shows how this calculation should be performed. The same calculation can be performed in the GPU-verification path by:
- Calling `gpu_reduce_max` to get the maximum tensor value
- Computing `rtol`, `atol`, `rtol_split_k`, and `atol_split_k` as in the CPU path
- Calling `gpu_verify` with explicit tolerance values
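To illustrate why the tolerance depends on both `split_k` and the output type, here is a minimal host-side sketch of the tolerance step. It assumes a simple worst-case model in which n sequential roundings in a type with unit roundoff u contribute about n·u relative error; the CPU path may use a different formula, and `accumulation_tolerance` and `UNIT_ROUNDOFF` are hypothetical names.

```python
# Unit roundoff u = 2**-p, where p is the significand precision in bits
# (including the implicit bit). Standard values; the table is illustrative.
UNIT_ROUNDOFF = {
    "fp32": 2.0 ** -24,
    "fp16": 2.0 ** -11,
    "bf16": 2.0 ** -8,
}


def accumulation_tolerance(k, split_k, out_type, max_abs, acc_type="fp32"):
    """Return (rtol, atol) under an assumed linear error model (hypothetical)."""
    if split_k < 1 or k % split_k != 0:
        raise ValueError("this sketch assumes split_k >= 1 divides k")
    # Each partial sum performs K / split_k accumulations in float.
    rtol = (k // split_k) * UNIT_ROUNDOFF[acc_type]
    # Each partial sum is rounded to the output type and the split_k partials
    # are then combined in that type: about split_k roundings in total.
    # With split_k == 1 this is just the final conversion to the output type.
    rtol += split_k * UNIT_ROUNDOFF[out_type]
    # Absolute tolerance scaled by the largest magnitude (the value that
    # gpu_reduce_max would provide in the GPU path).
    atol = rtol * max_abs
    return rtol, atol


rtol_1, atol_1 = accumulation_tolerance(4096, 1, "fp16", max_abs=2.0)
rtol_4, atol_4 = accumulation_tolerance(4096, 4, "fp16", max_abs=2.0)
```

With K = 4096 and an fp16 output, this model gives rtol = 3·2^-12 ≈ 7.3e-4 without split-k, but 2^-14 + 2^-9 ≈ 2.0e-3 with split_k = 4, because the four partial sums are combined in fp16 rather than float.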
However, it would be simpler if `gpu_verify` performed these calculations itself, with the user passing in the extra information it needs (`split_k`, data types, etc.).
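One possible shape for such an interface, sketched on the host in Python: the caller passes the reduction length, `split_k`, and the data types, and the verifier derives its own tolerances before comparing. `AccumulationInfo`, `verify`, and the error model used here are hypothetical; this is not the existing `gpu_verify` signature.

```python
from dataclasses import dataclass

# Unit roundoff u = 2**-p (p = significand precision in bits).
UNIT_ROUNDOFF = {"fp32": 2.0 ** -24, "fp16": 2.0 ** -11, "bf16": 2.0 ** -8}


@dataclass
class AccumulationInfo:
    """Hypothetical extra arguments a caller would hand to the verifier."""
    k: int
    split_k: int = 1
    acc_type: str = "fp32"
    out_type: str = "fp32"


def verify(result, reference, info):
    """Hypothetical verify that derives its own tolerances from `info`."""
    if len(result) != len(reference):
        raise ValueError("result and reference must have the same length")
    # Largest reference magnitude (what gpu_reduce_max provides on the GPU).
    max_abs = max((abs(x) for x in reference), default=0.0)
    # Assumed linear error model: K / split_k roundings in the accumulation
    # type plus split_k roundings in the output type.
    rtol = (info.k // info.split_k) * UNIT_ROUNDOFF[info.acc_type]
    rtol += info.split_k * UNIT_ROUNDOFF[info.out_type]
    atol = rtol * max_abs
    # Element-wise comparison with the derived tolerances.
    return all(abs(r - e) <= atol + rtol * abs(e) for r, e in zip(result, reference))


ref = [1.0, -2.0, 0.5]
close = [1.003, -2.003, 0.501]
ok_split = verify(close, ref, AccumulationInfo(k=4096, split_k=4, out_type="fp16"))
ok_no_split = verify(close, ref, AccumulationInfo(k=4096, out_type="fp16"))
```

Here the same result passes with `split_k=4` but fails without split-k, since only the split-k configuration incurs the extra fp16 roundings. Keeping this derivation inside the verifier would mean callers no longer have to reproduce it themselves.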