From ed4143d21c2f816a5312253ba3649b4641404ff3 Mon Sep 17 00:00:00 2001
From: Akash Kaothalkar <0052v2@linux.vnet.ibm.com>
Date: Wed, 29 Jan 2025 05:05:45 -0500
Subject: [PATCH] fix: add generic pytorch distributed all_reduce support

Signed-off-by: Akash Kaothalkar <0052v2@linux.vnet.ibm.com>
---
 vllm/distributed/parallel_state.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 7fe9b68d4b9e8..2cb0ad3d173ff 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -327,9 +327,17 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
             return input_
 
         if input_.is_cpu:
-            import intel_extension_for_pytorch as ipex
-            ipex.distributed.all_reduce(input_, group=self.device_group)
-            return input_
+            try:
+                import intel_extension_for_pytorch as ipex
+                ipex.distributed.all_reduce(input_, group=self.device_group)
+                return input_
+            except ImportError:
+                """
+                Intel IPEX not found. Falling back to PyTorch native
+                all_reduce for CPU
+                """
+                torch.distributed.all_reduce(input_, group=self.device_group)
+                return input_
 
         if self.tpu_communicator is not None and \
                 not self.tpu_communicator.disabled:
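
For illustration, the pattern introduced by this patch can be exercised outside vLLM: attempt the Intel Extension for PyTorch (IPEX) collective first, and fall back to PyTorch's native all_reduce on the same process group when the import fails. The sketch below is not part of the patch; the cpu_all_reduce helper and the single-process gloo setup are hypothetical stand-ins for GroupCoordinator.all_reduce and vLLM's device_group.

import os

import torch
import torch.distributed as dist


def cpu_all_reduce(input_: torch.Tensor, group=None) -> torch.Tensor:
    # Hypothetical helper mirroring the patched vLLM logic: prefer the
    # IPEX collective for CPU tensors and fall back to torch.distributed
    # when intel_extension_for_pytorch is not installed.
    try:
        import intel_extension_for_pytorch as ipex
        ipex.distributed.all_reduce(input_, group=group)
    except ImportError:
        dist.all_reduce(input_, group=group)
    return input_


if __name__ == "__main__":
    # Single-process gloo group, only for demonstration.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    t = torch.ones(4)
    cpu_all_reduce(t)   # with world_size == 1 the summed tensor is unchanged
    print(t)            # tensor([1., 1., 1., 1.])
    dist.destroy_process_group()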