diff --git a/flashinfer/comm/cuda_ipc.py b/flashinfer/comm/cuda_ipc.py
index 46c7b2bd6b..be41ab9d55 100644
--- a/flashinfer/comm/cuda_ipc.py
+++ b/flashinfer/comm/cuda_ipc.py
@@ -203,7 +203,7 @@ def create_shared_buffer(
     pointer = cudart.cudaMalloc(size_in_bytes)
     handle = cudart.cudaIpcGetMemHandle(pointer)
     if group is None:
-        group = dist.group.WORLD
+        group = dist.get_group()  # NOTE(review): confirm torch.distributed exposes get_group(); dist.group.WORLD is the documented default-group handle
     # world_size = dist.get_world_size(group=group)
     rank = dist.get_rank(group=group)
     # handles = [None] * world_size
diff --git a/flashinfer/comm/trtllm_ar.py b/flashinfer/comm/trtllm_ar.py
index 82f43a515c..d5a3ba5cc9 100644
--- a/flashinfer/comm/trtllm_ar.py
+++ b/flashinfer/comm/trtllm_ar.py
@@ -767,9 +767,29 @@ def trtllm_custom_all_reduce(
 def _should_use_oneshot(
     token_num: int, hidden_dim: int, dtype: torch.dtype, world_size: int
 ) -> bool:
-    comm_size_mb = (
-        token_num * hidden_dim * 2 * world_size * dtype.itemsize / 1024 / 1024
-    )
+    # Static byte-size table, presumably so the heuristic works on torch
+    # builds where torch.dtype.itemsize is unavailable -- TODO confirm.
+    DTYPE_SIZE_MAP = {
+        torch.float16: 2,
+        torch.bfloat16: 2,
+        torch.float32: 4,
+        torch.float64: 8,
+        torch.int8: 1,
+        torch.int16: 2,
+        torch.int32: 4,
+        torch.int64: 8,
+        torch.uint8: 1,
+        torch.bool: 1,
+        torch.complex64: 8,
+        torch.complex128: 16,
+    }
+    # Fall back to an empty tensor's element_size() so dtypes missing from the
+    # table (e.g. float8 variants) do not raise KeyError -- the dtype.itemsize
+    # expression this replaces accepted every dtype.
+    itemsize = DTYPE_SIZE_MAP.get(dtype)
+    if itemsize is None:
+        itemsize = torch.empty(0, dtype=dtype).element_size()
+    comm_size_mb = token_num * hidden_dim * 2 * world_size * itemsize / 1024 / 1024
     return comm_size_mb <= _use_oneshot_heuristics[world_size]
 
 