Use symmetric memory in matmul+allreduce reference implementation #5837
Conversation
Description
Type: Enhancement
PR Reviewer Guide
Here are some key observations to aid the review process:
🧪 PR contains tests
⚡ Recommended focus areas for review
Stream Pool Logic Change
The PR changes `stream_pool.get(i)` to `stream_pool.get(i % 2)`, which reduces the number of streams used from potentially many to just two. This could impact performance and should be validated to ensure it doesn't cause regressions.
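For reviewers who want the pattern at a glance, here is a minimal sketch of the change under review. The `StreamPool` class is a hypothetical stand-in for the benchmark's actual stream pool; only the `i % 2` indexing mirrors the PR.

```python
import torch

class StreamPool:
    """Hypothetical stand-in for the benchmark's CUDA stream pool."""
    def __init__(self, size: int):
        self._streams = [torch.cuda.Stream() for _ in range(size)]

    def get(self, i: int) -> torch.cuda.Stream:
        return self._streams[i % len(self._streams)]

pool = StreamPool(size=8)
x = torch.randn(1024, 1024, device="cuda")
w = torch.randn(1024, 1024, device="cuda")

for i in range(8):
    # Before: pool.get(i) spreads the chunks over up to 8 distinct streams.
    # After: pool.get(i % 2) alternates between just two streams, which is
    # enough to overlap compute with communication while avoiding the
    # scheduling overhead of many concurrent streams.
    with torch.cuda.stream(pool.get(i % 2)):
        y = x @ w  # stand-in for the i-th chunk's work
torch.cuda.synchronize()
```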
cc @kwen2501 let me know if you have any suggestions! In case you'd like to run the microbenchmark yourself, the reference implementation above doesn't depend on nvFuser at all -- you should be able to simply run it with `torchrun`.
Hi @wujingyue I was thinking of all-gather because it can use the copy engines (CE), while all-reduce still needs SMs (the reduction still requires computation). Also, matmul + all-gather overlap may be more common these days due to FSDP. In tensor parallelism, the all-reduce is more likely sequential to the matmul due to the data dependency.
Sure -- I'll update you when I have an all-gather example.
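For readers following the thread, here is a minimal sketch of the matmul + all-gather overlap being discussed, assuming a standard `torch.distributed` NCCL setup; the shapes and initialization are illustrative, not taken from the benchmark.

```python
import torch
import torch.distributed as dist

dist.init_process_group("nccl")  # launched via torchrun
rank = dist.get_rank()
torch.cuda.set_device(rank)
world = dist.get_world_size()

x = torch.randn(1024, 1024, device="cuda")
w = torch.randn(1024, 1024, device="cuda")
shard = torch.randn(1024, 1024, device="cuda")
gathered = torch.empty(1024 * world, 1024, device="cuda")

# Kick off the all-gather asynchronously; with NCCL it runs on the process
# group's internal stream, so the matmul below can overlap with it. When the
# all-gather goes through the copy engines (CE) it uses no SMs at all,
# whereas an all-reduce would still need SMs for the reduction math.
handle = dist.all_gather_into_tensor(gathered, shard, async_op=True)

y = x @ w      # independent compute, overlaps with the communication
handle.wait()  # the current stream waits before `gathered` is consumed
torch.cuda.synchronize()
```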
Oops, sorry, I didn't see your comment. I pushed a benchmark here:
I ran it in three modes, on 8 x H100s:
I used the default command generated by Claude in the test file (i.e., the all-gather is 64 MiB).
To enable CE, we can add this option:
The biggest performance win has come from limiting the number of streams (cb4d494). The wall time went down from 3.31ms to 3.09ms.
On top of that, using symmetric memory gives a slight speedup -- from 3.09ms (screenshot 1) to 3.07ms (screenshot 2). As a sanity check, I do see `ncclSymk` in the allreduce kernel name (screenshot 3), and it uses only 5 thread blocks instead of 24.
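For context, here is a rough sketch of what allocating the collective's buffer from symmetric memory can look like. Caveat: `torch.distributed._symmetric_memory` is a private API that changes between PyTorch releases, and whether NCCL actually selects its symmetric kernels also depends on the NCCL version and registration settings, so treat this as a sketch rather than the exact code in the PR.

```python
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

dist.init_process_group("nccl")  # launched via torchrun
rank = dist.get_rank()
torch.cuda.set_device(rank)

# Allocate the buffer from symmetric memory instead of the regular caching
# allocator, then rendezvous so all ranks exchange handles to it.
out = symm_mem.empty(1024, 1024, dtype=torch.bfloat16, device="cuda")
symm_mem.rendezvous(out, dist.group.WORLD.group_name)

out.copy_(torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda"))

# On a symmetric buffer, NCCL may dispatch its symmetric-memory kernels
# (the ncclSymk* names seen in the profile), which need far fewer thread
# blocks than the generic path.
dist.all_reduce(out)
torch.cuda.synchronize()
```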
Repro:
Screenshot 1: [image]
Screenshot 2: [image]
Screenshot 3: [image]