From 87edc8cd947b69c871baf222cc9e9c2973df6b77 Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Wed, 4 Feb 2026 16:07:22 +0800 Subject: [PATCH 1/2] fix python api compat issues --- flashinfer/comm/cuda_ipc.py | 3 ++- flashinfer/comm/trtllm_ar.py | 19 ++++++++++++++++--- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/flashinfer/comm/cuda_ipc.py b/flashinfer/comm/cuda_ipc.py index 46c7b2bd6b..24b3c22688 100644 --- a/flashinfer/comm/cuda_ipc.py +++ b/flashinfer/comm/cuda_ipc.py @@ -203,7 +203,8 @@ def create_shared_buffer( pointer = cudart.cudaMalloc(size_in_bytes) handle = cudart.cudaIpcGetMemHandle(pointer) if group is None: - group = dist.group.WORLD + # group = dist.group.WORLD + group = dist.get_group() # world_size = dist.get_world_size(group=group) rank = dist.get_rank(group=group) # handles = [None] * world_size diff --git a/flashinfer/comm/trtllm_ar.py b/flashinfer/comm/trtllm_ar.py index 82f43a515c..d5a3ba5cc9 100644 --- a/flashinfer/comm/trtllm_ar.py +++ b/flashinfer/comm/trtllm_ar.py @@ -767,9 +767,22 @@ def trtllm_custom_all_reduce( def _should_use_oneshot( token_num: int, hidden_dim: int, dtype: torch.dtype, world_size: int ) -> bool: - comm_size_mb = ( - token_num * hidden_dim * 2 * world_size * dtype.itemsize / 1024 / 1024 - ) + DTYPE_SIZE_MAP = { + torch.float16: 2, + torch.bfloat16: 2, + torch.float32: 4, + torch.float64: 8, + torch.int8: 1, + torch.int16: 2, + torch.int32: 4, + torch.int64: 8, + torch.uint8: 1, + torch.bool: 1, + torch.complex64: 8, + torch.complex128: 16, + } + itemsize = DTYPE_SIZE_MAP[dtype] + comm_size_mb = token_num * hidden_dim * 2 * world_size * itemsize / 1024 / 1024 return comm_size_mb <= _use_oneshot_heuristics[world_size] From c8d31bdb1540b6e3749078528c2a456a4db8fcbf Mon Sep 17 00:00:00 2001 From: Bingoo <33573610+BingooYang@users.noreply.github.com> Date: Wed, 4 Feb 2026 16:26:31 +0800 Subject: [PATCH 2/2] Update cuda_ipc.py --- flashinfer/comm/cuda_ipc.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flashinfer/comm/cuda_ipc.py b/flashinfer/comm/cuda_ipc.py index 24b3c22688..be41ab9d55 100644 --- a/flashinfer/comm/cuda_ipc.py +++ b/flashinfer/comm/cuda_ipc.py @@ -203,7 +203,6 @@ def create_shared_buffer( pointer = cudart.cudaMalloc(size_in_bytes) handle = cudart.cudaIpcGetMemHandle(pointer) if group is None: - # group = dist.group.WORLD group = dist.get_group() # world_size = dist.get_world_size(group=group) rank = dist.get_rank(group=group)