Add deepseek-r1 gating & mla for AMD MI300 #261

Merged: 1 commit, Feb 20, 2025
119 changes: 78 additions & 41 deletions README.md
@@ -6,36 +6,58 @@
- Supported GPUs: CUDA(fp64/fp32/fp16/bfp16), ROCm(fp64/fp32/fp16)
- Supported CPU: fp64/fp32

### ***Support Full Precision Inference of MoE-based Deepseek R1 671B on AMD MI300:***

We compare three solutions that support <ins>Full-Precision Inference (PPL = 0) of Deepseek R1 671B</ins>. PPL = 0 means that any quantization or unofficial sparsity technique that may lower the model's scores is prohibited.

![benchmarking](doc/DeepSeekR1-tutel-accel.png)

-----------

## What's New:

- Tutel v0.4.0: Accelerate full-precision Deepseek R1 chat on AMD MI300x8 (support for more platforms will be added in later versions):
```sh
>> Example:

# Step-1: Download Deepseek R1 671B Model
huggingface-cli download deepseek-ai/DeepSeek-R1 --local-dir ./deepseek-ai/DeepSeek-R1

# Step-2: Using 8 MI300 GPUs to Run Deepseek R1 Chat with Full Precision (PPL = 0)
docker run -it --rm --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --privileged \
-v /:/host -w /host$(pwd) tutelgroup/deepseek-671b:mi300x8-fp16xfp8 \
--model_path ./deepseek-ai/DeepSeek-R1 \
--prompt "Calculate the result of: 1 / (sqrt(5) - sqrt(3))"

```
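
For background, this PR's title refers to DeepSeek-R1 gating and MLA kernels for MI300. The snippet below sketches a generic top-k softmax gate in plain PyTorch as a point of reference only: DeepSeek-R1's production gate uses its own scoring and normalization scheme, and the names here (`topk_gate`, `gate_weight`) are made up for illustration.

```py
import torch

def topk_gate(x, gate_weight, k=2):
    """Generic top-k softmax gating (illustrative only, not DeepSeek-R1's exact rule).
    x: [num_tokens, model_dim], gate_weight: [model_dim, num_experts]."""
    logits = x @ gate_weight                      # routing score of every expert per token
    scores = torch.softmax(logits, dim=-1)
    topk_scores, topk_idx = torch.topk(scores, k, dim=-1)
    # Renormalize the selected scores so each token's expert weights sum to 1.
    topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True)
    return topk_scores, topk_idx

# e.g. route 4 tokens of width 16 across 32 experts with top-2 selection:
weights, experts = topk_gate(torch.randn(4, 16), torch.randn(16, 32), k=2)
```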

- Tutel v0.3.3: Add all-to-all benchmark:
```sh
>> Example:

python3 -m torch.distributed.run --nproc_per_node=8 -m tutel.examples.bandwidth_test --size_mb=256
```

- Tutel v0.3.2: Add tensorcore option for extra benchmarks / Extend the example for custom experts / Allow NCCL timeout settings:
```sh
>> Example of using tensorcore:

python3 -m tutel.examples.helloworld --dtype=float32
python3 -m tutel.examples.helloworld --dtype=float32 --use_tensorcore

python3 -m tutel.examples.helloworld --dtype=float16
python3 -m tutel.examples.helloworld --dtype=float16 --use_tensorcore

>> Example of custom gates/experts:
python3 -m tutel.examples.helloworld_custom_gate_expert --batch_size=16

>> Example of NCCL timeout settings:
TUTEL_GLOBAL_TIMEOUT_SEC=60 python3 -m torch.distributed.run --nproc_per_node=8 -m tutel.examples.helloworld --use_tensorcore

```
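
The `--use_tensorcore` switch above presumably routes the fp32 benchmark onto tensor cores (TF32 on recent NVIDIA GPUs). The snippet below shows a common way to enable that globally in PyTorch; it is an assumption for illustration, not necessarily how Tutel implements the flag.

```py
import torch

# Let fp32 matmuls and convolutions use TF32 tensor cores on Ampere-or-newer GPUs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Equivalent higher-level knob (available since PyTorch 1.12):
torch.set_float32_matmul_precision('high')
```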

- Tutel v0.3.1: Add NCCL all_to_all_v and all_gather_v for arbitrary-length message transfers:
```sh
>> Example:
# All_to_All_v:
python3 -m torch.distributed.run --nproc_per_node=2 --master_port=7340 -m tutel.examples.nccl_all_to_all_v
...
```
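
To illustrate what an arbitrary-length (all_to_all_v-style) exchange means, here is a minimal sketch built on `torch.distributed.all_to_all_single` with explicit split sizes. Tutel ships its own NCCL implementation, so this is only a conceptual reference, not the code path used by `tutel.examples.nccl_all_to_all_v`.

```py
# Launch with: python3 -m torch.distributed.run --nproc_per_node=2 this_sketch.py
import torch
import torch.distributed as dist

dist.init_process_group(backend='nccl')
rank, world = dist.get_rank(), dist.get_world_size()
device = torch.device('cuda', rank % torch.cuda.device_count())
torch.cuda.set_device(device)

# Each rank sends a different number of elements to every peer.
send_splits = [rank + peer + 1 for peer in range(world)]
send_buf = torch.arange(sum(send_splits), dtype=torch.float32, device=device)

# Exchange split sizes first so every rank knows how much it will receive.
recv_splits_t = torch.zeros(world, dtype=torch.int64, device=device)
dist.all_to_all_single(recv_splits_t, torch.tensor(send_splits, dtype=torch.int64, device=device))
recv_splits = recv_splits_t.tolist()

# Variable-length exchange: split sizes differ per (sender, receiver) pair.
recv_buf = torch.empty(sum(recv_splits), dtype=torch.float32, device=device)
dist.all_to_all_single(recv_buf, send_buf, output_split_sizes=recv_splits, input_split_sizes=send_splits)
print(f'rank {rank} received {recv_buf.numel()} elements')
```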

- Tutel v0.3: Add the Megablocks solution to improve decoder inference on a single GPU with num_local_experts >= 2:
```sh
>> Example (capacity_factor=0 required by dropless-MoE):
# Using BatchMatmul:
python3 -m tutel.examples.helloworld --megablocks_size=0 --batch_size=1 --num_tokens=32 --top=1 --eval --num_local_experts=128 --capacity_factor=0
# Using Megablocks with block_size = 1:
...
```
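
For context on the `--capacity_factor` knob used above: each expert only accepts a bounded number of tokens per step, and `0` removes the bound entirely, which is the dropless case Megablocks targets. The helper below mirrors the common Switch/GShard-style formulation as an illustration of the knob; Tutel's internal rounding rules may differ.

```py
import math

def expert_capacity(num_tokens, num_experts, top_k, capacity_factor):
    """Illustrative per-expert token budget; 0 means uncapped (dropless MoE).
    This mirrors the common Switch/GShard formulation, not Tutel's exact code."""
    if capacity_factor == 0:
        return num_tokens * top_k  # worst case: every routed token fits
    return int(math.ceil(capacity_factor * top_k * num_tokens / num_experts))

# e.g. 32 tokens, 128 experts, top-1 routing (matching the commands above):
print(expert_capacity(32, 128, 1, 0))     # dropless -> 32
print(expert_capacity(32, 128, 1, 1.25))  # capped   -> 1
```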

- Tutel v0.2: Allow most configurations to be switched dynamically at no extra cost:
```sh
>> Example:
python3 -m torch.distributed.run --nproc_per_node=8 -m tutel.examples.helloworld_switch --batch_size=16

...
```

- Tutel v0.1: Optimize the einsum complexity of data dispatch encoding and decoding; add a 2DH option to handle All-to-All at scale:
```sh
>> Example (suggest enabling 2DH only at scale; note that --nproc_per_node MUST equal the total physical GPU count per node, e.g. 8 for A100x8):
python3 -m torch.distributed.run --nproc_per_node=8 -m tutel.examples.helloworld --batch_size=16 --use_2dh
```

-----------
## Getting Started

### 1. Prepare Pytorch (if applicable):
```
* Prepare Recommended Pytorch >= 2.0.0:
# Windows/Linux Pytorch for NVIDIA CUDA >= 11.7:
python3 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# Linux Pytorch for AMD ROCm >= 6.2.2:
python3 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2.2
# Windows/Linux Pytorch for CPU:
python3 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
```

### 2. Tutel Installation:
```
* Option-1: Install Tutel Online:

$ python3 -m pip uninstall tutel -y
$ python3 -m pip install setuptools wheel
$ python3 -m pip install -v -U --no-build-isolation git+https://github.com/microsoft/tutel@main

* Option-2: Build Tutel from Source:

$ git clone https://github.com/microsoft/tutel --branch main

$ python3 -m pip uninstall tutel -y
$ python3 ./tutel/setup.py install --user
```

### 3. Quick Test for Single Device / CPU:
```
* Quick Test on Single-GPU:

$ python3 -m tutel.examples.helloworld --batch_size=16 # Test Tutel-optimized MoE + manual distribution
Expand All @@ -117,32 +145,41 @@ Tutel MoE: An Optimized Mixture-of-Experts Implementation, also the first parall
(If building from source, the following method also works:)
$ python3 ./tutel/examples/helloworld.py --batch_size=16
..
```

### 4. Quick Test for 8 GPUs within 1 Machine:
```
$ python3 -m torch.distributed.run --nproc_per_node=8 -m tutel.examples.helloworld --batch_size=16
```

### 5. Quick Test for Multiple GPUs across Machines:
```
* Run Tutel MoE in Distributed Mode:

(Option A - Torch launcher for `Multi-Node x Multi-GPU`:)
$ ssh <node-ip-0> python3 -m torch.distributed.run --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=<node-ip-0> -m tutel.examples.helloworld --batch_size=16
$ ssh <node-ip-1> python3 -m torch.distributed.run --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=<node-ip-0> -m tutel.examples.helloworld --batch_size=16

(Option B - Tutel launcher for `Multi-Node x Multi-GPU`, requiring package `openmpi-bin`:)
# << Single Node >>
$ mpiexec -bind-to none -host localhost -x LOCAL_SIZE=8 python3 -m tutel.launcher.run -m tutel.examples.helloworld_ddp_tutel --batch_size=16
$ mpiexec -bind-to none -host localhost -x LOCAL_SIZE=8 python3 -m tutel.launcher.run -m tutel.examples.moe_mnist
$ mpiexec -bind-to none -host localhost -x LOCAL_SIZE=8 python3 -m tutel.launcher.run -m tutel.examples.moe_cifar10
...

# << MPI-based launch for GPU backend >>
$ mpiexec -bind-to none -host <node-ip-0>,<node-ip-1>,.. -x MASTER_ADDR=<node-ip-0> -x LOCAL_SIZE=8 python3 -m tutel.launcher.run -m tutel.examples.helloworld --batch_size=16

# << MPI-based launch for CPU backend >>
$ mpiexec -bind-to none -host localhost -x LOCAL_SIZE=1 -x OMP_NUM_THREADS=1024 python3 -m tutel.launcher.run -m tutel.examples.helloworld --batch_size=16 --device cpu

```

-----------

### Advanced: Convert Checkpoint Files for Different World Sizes:
Documentation for checkpoint conversion has been moved [here](doc/CHECKPOINT.md).

### Examples: How to import Tutel-optimized MoE in Pytorch:
```
# Input Example:
import torch
...
y = moe_layer(x)
print(y)
```
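
The block above is abridged in this diff view. For completeness, here is a minimal construction sketch that follows the argument names used in the repo's helloworld example (`gate_type`, `model_dim`, `experts`); treat them as assumptions and verify against `tutel/examples/helloworld.py` for the version you install.

```py
# Minimal single-process sketch (assumed API, based on the repo's examples).
import torch
from tutel import moe as tutel_moe

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
x = torch.randn([6, 1024], device=device)

moe_layer = tutel_moe.moe_layer(
    gate_type={'type': 'top', 'k': 2},            # top-2 gating
    model_dim=x.shape[-1],
    experts={
        'type': 'ffn',
        'count_per_node': 2,
        'hidden_size_per_expert': 2048,
        'activation_fn': lambda t: torch.nn.functional.relu(t),
    },
).to(device)

y = moe_layer(x)   # output keeps the input shape: [6, 1024]
print(y.shape)

# Multi-GPU runs are launched through torch.distributed.run as shown in Getting Started.
```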

### Reference
You can consult the [paper](https://arxiv.org/pdf/2206.03382.pdf) below for more technical details about Tutel:
```
@article{tutel,
  author  = {Changho Hwang and Wei Cui and Yifan Xiong and Ziyue Yang and Ze Liu and Han Hu and Zilong Wang and Rafael Salas and Jithin Jose and Prabhat Ram and Joe Chau and Peng Cheng and Fan Yang and Mao Yang and Yongqiang Xiong},
  title   = {Tutel: Adaptive Mixture-of-Experts at Scale},
  year    = {2022},
  month   = jun,
  journal = {CoRR},
  volume  = {abs/2206.03382},
  url     = {https://arxiv.org/pdf/2206.03382.pdf},
}
```

### Usage of MOELayer:
```
* Usage of MOELayer Args:
Expand Down Expand Up @@ -205,20 +256,6 @@ print(y)
has_fc2_bias : If set to False, the expert bias parameter `batched_fc2_bias` is disabled. Default: True
```

### Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a
...
Binary file added doc/DeepSeekR1-tutel-accel.png
21 changes: 10 additions & 11 deletions setup.py
@@ -16,6 +16,7 @@
from typing import List, Tuple

from setuptools import setup, find_packages, Command
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension

try:
@@ -57,7 +58,7 @@ def run(self):
def install(use_cuda, use_nccl):
    ext_libs = []
    if pf.system() == 'Linux':
        ext_args = ['-w']
    elif pf.system() == 'Darwin':
        ext_args = ['-mmacosx-version-min=10.13']
    else:
@@ -80,7 +81,7 @@ def install(use_cuda, use_nccl):

setup(
    name='tutel',
    version='0.4',
    description='An Optimized Mixture-of-Experts Implementation.',
    url='https://github.com/microsoft/Tutel',
    author='Microsoft',
@@ -138,16 +139,14 @@
},
)

if (torch.version.cuda or torch.version.hip) and int(os.environ.get('NO_CUDA', 0)) == 0:
    try:
        print('Try installing with NCCL extension..')
        install(use_cuda=True, use_nccl=True)
    except:
        print('Try installing without NCCL extension..')
        install(use_cuda=True, use_nccl=False)
else:
    print('Installing without CUDA extension..')
    install(use_cuda=False, use_nccl=False)
