[WIP] Remove redundant casts in LLVMIR#2202
Conversation
| %7 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> | ||
| llvm.store %7, %6 : vector<4xf16>, !llvm.ptr<5> | ||
| %8 = llvm.getelementptr %2[12] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 | ||
| %9 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> |
There was a problem hiding this comment.
why are there so many repeated llvm.fptrunc %1? I don't understand this.
Wouldn't it be easier to do this earlier? for example handling arith.extf, etc?
There was a problem hiding this comment.
The fptrunc's seem to be the result of loop unrolling. They are all writing to the same buffer. I was doing this earlier and there are quite a few more difficulties with moving this pass somewhere right after GridwiseGemmToBlockwise. The dominance analysis (used for safety) becomes tricky because it doesn't work well when the trunc/ext ops are in different regions, and also having to rewrite the linalg generic makes things more difficult as well.
Both approaches have their pros and cons. We can discuss this more in the team meeting, or elsewhere offline.
| // - If the narrow buffer has no remaining uses, erase the fptrunc stores | ||
| // - These can only be erased if they are not used by any other | ||
| // operations | ||
| // - Erase the narrow alloca if it has no remaining uses |
There was a problem hiding this comment.
do we have tests for these two cases?
|
|
||
| // Look for existing parallel wide store | ||
| for (Operation *wideUser : wideValue.getUsers()) { | ||
| auto wideStore = dyn_cast<StoreOp>(wideUser); |
There was a problem hiding this comment.
should we check it's the source instead of destination?
| info.wideStore = nullptr; | ||
|
|
||
| // Look for existing parallel wide store | ||
| for (Operation *wideUser : wideValue.getUsers()) { |
There was a problem hiding this comment.
nit: this could be done outside of this loop and store the results in a SmallVector?
| << wideStore << "\n"); | ||
| info.wideBuffer = wideBuffer; | ||
| info.wideStore = wideStore; | ||
| break; |
There was a problem hiding this comment.
why do we store only the first one?
Motivation
When processing mixed-precision computations (e.g., attention kernels with f32 intermediate values stored as f16), the generated IR often contains redundant precision conversion patterns:
This pattern causes unnecessary precision loss compared to just keeping the original wide value. This pass eliminates these redundant casts by redirectoing loads to read from a parallel wide buffer when possible.
This implements: https://github.com/ROCm/rocMLIR-internal/issues/1932
Technical Details
This PR introduces the
RemoveRedundantCastspass that operates at the LLVMIR dialect level to optimize fptrunc -> store -> load -> fpext patterns.General Algorithm:
Test Plan
Test Result
Submission Checklist