
Conversation


@wujingyue wujingyue commented Jan 17, 2026

The biggest performance win has come from limiting the number of streams (cb4d494). The wall time went down from 3.31ms to 3.09ms.

On top of that, using symmetric memory gives a slight speedup -- from 3.09ms (screenshot 1) to 3.07ms (screenshot 2). As a sanity check, I do see ncclSymk in the allreduce kernel name (screenshot 3), and it uses only 5 thread blocks instead of 24.

Repro:

me @ viking-prod-232 : dev | /opt/pytorch/nvfuser (wjy/symm)
$ mpirun -np 2 pytest tests/python/multidevice/test_overlap.py::'test_row_parallel_linear_forward_reference_benchmark' --only-mpi -vs

Screenshot 1: [image]

Screenshot 2: [image]

Screenshot 3: [image]

@wujingyue changed the title from "Use symmetric memory" to "Use symmetric memory in reference model" on Jan 17, 2026
@github-actions

Description

  • Introduce symmetric memory usage in multi-device reference model

  • Add NCCL process group configuration with zero CTA policy

  • Replace torch.empty with symm_mem.empty for output tensor allocation

  • Add symm_mem.rendezvous() for cross-device memory synchronization

  • Update stream pool access pattern to use modulo operation

Changes walkthrough

Relevant files

Enhancement

conftest.py (tests/python/multidevice/conftest.py): Configure symmetric memory in test setup (+11/-2)

  • Add symmetric memory import
  • Configure NCCL process group with zero CTA policy
  • Initialize symmetric memory backend and enable it for the world group
  • Simplify the device_id parameter by removing the torch.device wrapper

test_overlap.py (tests/python/multidevice/test_overlap.py): Integrate symmetric memory in reference model (+21/-10)

  • Add symmetric memory import
  • Modify row_parallel_linear_forward_reference to accept an output tensor parameter
  • Replace torch.empty with symm_mem.empty for output allocation
  • Add a symm_mem.rendezvous() call for memory synchronization
  • Update the stream pool access pattern to use a modulo operation

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review
Stream Pool Logic Change

The stream selection logic was changed from stream_pool.get(i) to stream_pool.get(i % 2), which caps the number of streams used at two instead of potentially many. This could affect performance and should be validated to ensure it doesn't cause regressions.

    worker_stream = stream_pool.get(i % 2)
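For illustration, here is a minimal, hypothetical sketch of the pattern under review. The StreamPool class below is made up for this note and is not the test's actual stream_pool helper; the point is that indexing with i % 2 round-robins the loop iterations over just two worker streams, no matter how many iterations run.

    import torch

    class StreamPool:
        """Hypothetical stand-in for the test's stream pool: lazily creates
        CUDA streams so that get(i) always returns the i-th stream."""

        def __init__(self):
            self._streams = []

        def get(self, i):
            while len(self._streams) <= i:
                self._streams.append(torch.cuda.Stream())
            return self._streams[i]

    stream_pool = StreamPool()
    for i in range(8):
        # Only streams 0 and 1 are ever created or used.
        worker_stream = stream_pool.get(i % 2)
        with torch.cuda.stream(worker_stream):
            pass  # enqueue this iteration's work on the worker stream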
Symmetric Memory API Usage

The PR introduces symmetric memory APIs (symm_mem.empty(), symm_mem.rendezvous()) that are relatively new in PyTorch. Their correctness and performance implications should be thoroughly tested, especially in multi-GPU scenarios.

    out = symm_mem.empty(
        inp_shard.size(0),
        weight_shard.size(0),
        device="cuda",
        dtype=inp_shard.dtype,
    )
    symm_mem.rendezvous(out, group=dist.group.WORLD)
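To make the data flow concrete, here is a hedged sketch of how these calls fit together in a row-parallel linear forward with an all-reduce. It assumes the process group is initialized with NCCL, the symmetric-memory backend is enabled for the group (as conftest.py does in this PR), and that symm_mem is the usual alias for torch.distributed._symmetric_memory; the function name is illustrative rather than the PR's exact code.

    import torch
    import torch.distributed as dist
    import torch.distributed._symmetric_memory as symm_mem

    def row_parallel_linear_reference(inp_shard, weight_shard):
        # Allocate the matmul output from the symmetric-memory allocator so the
        # subsequent NCCL all-reduce can use its symmetric kernels, consistent
        # with the ncclSymk kernel name seen in screenshot 3 above.
        out = symm_mem.empty(
            inp_shard.size(0),
            weight_shard.size(0),
            device="cuda",
            dtype=inp_shard.dtype,
        )
        # One-time handle exchange across ranks before the buffer is used.
        symm_mem.rendezvous(out, group=dist.group.WORLD)
        # Each rank computes a partial result from its input/weight shards ...
        torch.matmul(inp_shard, weight_shard.t(), out=out)
        # ... and the all-reduce sums the partials into the full output.
        dist.all_reduce(out)
        return out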
Process Group Configuration

The PR modifies the process group initialization to include NCCL options and a device_id parameter. The impact of these changes on existing functionality should be validated, particularly for single-device scenarios.

    opts = dist.ProcessGroupNCCL.Options()
    opts.config.cta_policy = dist.ProcessGroupNCCL.NCCL_CTA_POLICY_ZERO
    dist.init_process_group(
        backend="nccl",
        pg_options=opts,
        # The default port as used by https://github.com/pytorch/pytorch/blob/45a8b5682eb69d865cbf68c7f2f689b56b4efd53/torch/csrc/distributed/c10d/TCPStore.hpp#L51.
        init_method="tcp://localhost:29500",
        world_size=world_size,
        rank=rank,
        device_id=local_rank,
    )

@wujingyue changed the title from "Use symmetric memory in reference model" to "Use symmetric memory in matmul+allreduce reference implementation" on Jan 17, 2026
@wujingyue (Collaborator, Author) commented:

cc @kwen2501, let me know if you have any suggestions! In case you'd like to run the microbenchmark yourself, the reference implementation above doesn't depend on nvFuser at all -- you should be able to simply git clone and mpirun.


@kwen2501 commented Jan 17, 2026

Hi @wujingyue, I was thinking of all-gather because it can use the copy engine (CE), while all-reduce still needs SMs (the reduction still requires computation).

Also, matmul + all-gather overlap could be more common these days due to FSDP. In Tensor Parallel, the all-reduce is more likely to be sequential with the matmul due to the data dependency.

@wujingyue (Collaborator, Author) commented:

Sure -- I'll update you when I have an allgather example.

@kwen2501 commented:

Oops, sorry, I didn't see your comment.

I pushed a benchmark here:
https://github.com/pytorch/pytorch/pull/172714
(I don't have access to Fuser, so I pushed it to a PyTorch branch.)

I ran it in three modes, on 8 x H100s:

  • Sequential: 2.96 ms
  • Overlap, w/o CE: 2.02 ms
  • Overlap, with CE: 1.77 ms

I used the default command generated by Claude in the test file:

    torchrun --nproc_per_node=8 benchmarks/distributed/bench_overlapped_matmul_allgather.py \
        --m 8192 --n 8192 --k 8192 --ag-mb 64 --dtype fp16 --iters 200 --warmup 50

(i.e. the all-gather is 64 MiB)

To enable CE, we can add this option: --nccl-cta-policy-zero
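To make the overlapped modes concrete, here is a hedged sketch of the pattern being measured; it is illustrative and not the code in pytorch/pytorch#172714, and the function and variable names are made up. Issuing the all-gather as an async op puts it on NCCL's internal stream, so the matmul enqueued right after it on the current stream can run concurrently; per the note above, the zero-CTA policy (--nccl-cta-policy-zero) is what lets the all-gather go through the copy engine instead of tying up SMs.

    import torch
    import torch.distributed as dist

    def overlapped_matmul_allgather(a, b, shard):
        # Output buffer for the all-gather: world_size copies of `shard` along dim 0.
        world_size = dist.get_world_size()
        gathered = torch.empty(
            world_size * shard.size(0), *shard.shape[1:],
            device=shard.device, dtype=shard.dtype,
        )
        # Launch the collective asynchronously; it runs on NCCL's own stream.
        work = dist.all_gather_into_tensor(gathered, shard, async_op=True)
        # The matmul is enqueued on the current stream and overlaps with the
        # in-flight all-gather.
        c = a @ b
        # Make the current stream wait for the collective before `gathered`
        # is consumed downstream.
        work.wait()
        return c, gathered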
