Commit 591d2c2

Add deepseek-r1 gating & mla for AMD MI300 (#261)

1 parent 3ad9069 commit 591d2c2

7 files changed: +683 -52 lines

README.md

+78 -41
@@ -6,36 +6,58 @@ Tutel MoE: An Optimized Mixture-of-Experts Implementation, also the first parall
 - Supported GPUs: CUDA(fp64/fp32/fp16/bfp16), ROCm(fp64/fp32/fp16)
 - Supported CPU: fp64/fp32

+### ***Support Full Precision Inference of MoE-based Deepseek R1 671B on AMD MI300:***

-### What's New:
+We compare three solutions that support <ins>Full-Precision Inference (PPL = 0) of Deepseek R1 671B</ins>. PPL = 0 means that any quantization or unofficial sparsity technique that may lower the model's scores is prohibited.
+
+![benchmarking](doc/DeepSeekR1-tutel-accel.png)
+
+-----------
+
+## What's New:
+
+- Tutel v0.4.0: Accelerating Deepseek R1 Full-precision-Chat for AMD MI300x8 (more platform support will be added in later versions):
+```sh
+>> Example:
+
+# Step-1: Download Deepseek R1 671B Model
+huggingface-cli download deepseek-ai/DeepSeek-R1 --local-dir ./deepseek-ai/DeepSeek-R1
+
+# Step-2: Using 8 MI300 GPUs to Run Deepseek R1 Chat with Full Precision (PPL = 0)
+docker run -it --rm --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --privileged \
+        -v /:/host -w /host$(pwd) tutelgroup/deepseek-671b:mi300x8-fp16xfp8 \
+        --model_path ./deepseek-ai/DeepSeek-R1 \
+        --prompt "Calculate the result of: 1 / (sqrt(5) - sqrt(3))"
+
+```

 - Tutel v0.3.3: Add all-to-all benchmark:
-```py
+```sh
 >> Example:

 python3 -m torch.distributed.run --nproc_per_node=8 -m tutel.examples.bandwidth_test --size_mb=256
 ```

 - Tutel v0.3.2: Add tensorcore option for extra benchmarks / Extend the example for custom experts / Allow NCCL timeout settings:
-```py
->> Example for using tensorcore:
+```sh
+>> Example of using tensorcore:

 python3 -m tutel.examples.helloworld --dtype=float32
 python3 -m tutel.examples.helloworld --dtype=float32 --use_tensorcore

 python3 -m tutel.examples.helloworld --dtype=float16
 python3 -m tutel.examples.helloworld --dtype=float16 --use_tensorcore

->> Example for custom gates/experts:
+>> Example of custom gates/experts:
 python3 -m tutel.examples.helloworld_custom_gate_expert --batch_size=16

->> Example for NCCL timeout settings:
+>> Example of NCCL timeout settings:
 TUTEL_GLOBAL_TIMEOUT_SEC=60 python3 -m torch.distributed.run --nproc_per_node=8 -m tutel.examples.helloworld --use_tensorcore

 ```

 - Tutel v0.3.1: Add NCCL all_to_all_v and all_gather_v for arbitrary-length message transfers:
-```py
+```sh
 >> Example:
 # All_to_All_v:
 python3 -m torch.distributed.run --nproc_per_node=2 --master_port=7340 -m tutel.examples.nccl_all_to_all_v
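Editor's note on the v0.3.1 entry in the hunk above ("NCCL all_to_all_v ... for arbitrary-length message transfers"): the pattern being benchmarked is an all-to-all in which each rank sends a different number of elements to each peer. Below is a minimal sketch of that pattern using stock `torch.distributed` only (not Tutel's own `tutel.examples.nccl_all_to_all_v`, whose internals are not part of this commit); it assumes a process group has already been initialized.

```python
import torch
import torch.distributed as dist

def all_to_all_v(send: torch.Tensor, send_splits: list) -> torch.Tensor:
    """Variable-length all-to-all sketch: rank r sends send_splits[i] elements to rank i."""
    device = send.device
    world = dist.get_world_size()
    # First exchange the per-peer lengths so every rank knows how much it will receive.
    in_sizes = torch.tensor(send_splits, dtype=torch.int64, device=device)
    out_sizes = torch.empty(world, dtype=torch.int64, device=device)
    dist.all_to_all_single(out_sizes, in_sizes)
    # Then exchange the payload with explicit split sizes.
    recv = send.new_empty(int(out_sizes.sum()))
    dist.all_to_all_single(recv, send,
                           output_split_sizes=out_sizes.tolist(),
                           input_split_sizes=list(send_splits))
    return recv
```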
@@ -48,8 +70,8 @@ Tutel MoE: An Optimized Mixture-of-Experts Implementation, also the first parall
 ```

 - Tutel v0.3: Add Megablocks solution to improve decoder inference on single-GPU with num_local_expert >= 2:
-```py
->> Example (capacity_factor=0 for dropless-MoE):
+```sh
+>> Example (capacity_factor=0 required by dropless-MoE):
 # Using BatchMatmul:
 python3 -m tutel.examples.helloworld --megablocks_size=0 --batch_size=1 --num_tokens=32 --top=1 --eval --num_local_experts=128 --capacity_factor=0
 # Using Megablocks with block_size = 1:
@@ -62,7 +84,7 @@ Tutel MoE: An Optimized Mixture-of-Experts Implementation, also the first parall
 ```

 - Tutel v0.2: Allow most configurations to be dynamic switchable with free cost:
-```py
+```sh
 >> Example:
 python3 -m torch.distributed.run --nproc_per_node=8 -m tutel.examples.helloworld_switch --batch_size=16

@@ -74,35 +96,41 @@ Tutel MoE: An Optimized Mixture-of-Experts Implementation, also the first parall
 ```

 - Tutel v0.1: Optimize the Einsum Complexity of Data Dispatch Encoding and Decoding, add 2DH option to deal with All-to-All at scale:
-```py
+```sh
 >> Example (suggest enabling 2DH only at scale, note that the value of --nproc_per_node MUST equal to total physical GPU counts per node, e.g. 8 for A100x8):
 python3 -m torch.distributed.run --nproc_per_node=8 -m tutel.examples.helloworld --batch_size=16 --use_2dh
 ```

+-----------
+## Getting Started

-### How to setup Tutel MoE for Pytorch 2 and [run examples](tutel/examples), or [enable fairseq with MoE](tutel/examples/fairseq_moe):
+### 1. Prepare Pytorch (if applicable):
 ```
-* Prepare Recommended Pytorch >= 2.0.0 (minimal version == 1.8.0):
+* Prepare Recommended Pytorch >= 2.0.0:
 # Windows/Linux Pytorch for NVIDIA CUDA >= 11.7:
 python3 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
-# Linux Pytorch for AMD ROCm == 5.4.2:
-python3 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.4.2
+# Linux Pytorch for AMD ROCm >= 6.2.2:
+python3 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2.2
 # Windows/Linux Pytorch for CPU:
 python3 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+```

-* Install Tutel Online:
+### 2. Tutel Installation:
+```
+* Option-1: Install Tutel Online:

 $ python3 -m pip uninstall tutel -y
-$ python3 -m pip install setuptools wheel
 $ python3 -m pip install -v -U --no-build-isolation git+https://github.com/microsoft/tutel@main

-* Build Tutel from Source:
+* Option-2: Build Tutel from Source:

 $ git clone https://github.com/microsoft/tutel --branch main
-
 $ python3 -m pip uninstall tutel -y
 $ python3 ./tutel/setup.py install --user
+```

+### 3. Quick Test for Single Device / CPU:
+```
 * Quick Test on Single-GPU:

 $ python3 -m tutel.examples.helloworld --batch_size=16        # Test Tutel-optimized MoE + manual distribution
@@ -117,32 +145,41 @@ Tutel MoE: An Optimized Mixture-of-Experts Implementation, also the first parall
 (If building from source, the following method also works:)
 $ python3 ./tutel/examples/helloworld.py --batch_size=16
 ..
+```
+
+### 4. Quick Test for 8 GPUs within 1 Machine:
+```
+$ python3 -m torch.distributed.run --nproc_per_node=8 -m tutel.examples.helloworld --batch_size=16
+```

+### 5. Quick Test for Multiple GPUs across Machines:
+```
 * Run Tutel MoE in Distributed Mode:

-(Method A - Torch launcher for `Multi-Node x Multi-GPU`:)
+(Option A - Torch launcher for `Multi-Node x Multi-GPU`:)
 $ ssh <node-ip-0> python3 -m torch.distributed.run --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=<node-ip-0> -m tutel.examples.helloworld --batch_size=16
 $ ssh <node-ip-1> python3 -m torch.distributed.run --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=<node-ip-0> -m tutel.examples.helloworld --batch_size=16

-(Method B - Tutel launcher for `Multi-Node x Multi-GPU`, requiring package `openmpi-bin`:)
+(Option B - Tutel launcher for `Multi-Node x Multi-GPU`, requiring package `openmpi-bin`:)
 # << Single Node >>
 $ mpiexec -bind-to none -host localhost -x LOCAL_SIZE=8 python3 -m tutel.launcher.run -m tutel.examples.helloworld_ddp_tutel --batch_size=16
 $ mpiexec -bind-to none -host localhost -x LOCAL_SIZE=8 python3 -m tutel.launcher.run -m tutel.examples.moe_mnist
 $ mpiexec -bind-to none -host localhost -x LOCAL_SIZE=8 python3 -m tutel.launcher.run -m tutel.examples.moe_cifar10
 ...

-# << Cross Nodes >>
+# << MPI-based Launch for GPU backend >>
 $ mpiexec -bind-to none -host <node-ip-0>,<node-ip-1>,.. -x MASTER_ADDR=<node-ip-0> -x LOCAL_SIZE=8 python3 -m tutel.launcher.run -m tutel.examples.helloworld --batch_size=16

-# << For CPU-based Launch>>
+# << MPI-based Launch for CPU backend >>
 $ mpiexec -bind-to none -host localhost -x LOCAL_SIZE=1 -x OMP_NUM_THREADS=1024 python3 -m tutel.launcher.run -m tutel.examples.helloworld --batch_size=16 --device cpu
-
 ```

-### How to convert checkpoint files that adapt to different distributed world sizes:
-Documentation has been moved [here](doc/CHECKPOINT.md).
+-----------
+
+### Advanced: Convert Checkpoint Files for Different World Sizes:
+Documentation for checkpoint conversion has been moved [here](doc/CHECKPOINT.md).

-### How to import Tutel-optimized MoE in Pytorch:
+### Examples: How to import Tutel-optimized MoE in Pytorch:
 ```
 # Input Example:
 import torch
@@ -177,6 +214,20 @@ y = moe_layer(x)
 print(y)
 ```

+### Reference
+You can consult this [paper](https://arxiv.org/pdf/2206.03382.pdf) for more technical details about Tutel:
+```
+@article {tutel,
+author = {Changho Hwang and Wei Cui and Yifan Xiong and Ziyue Yang and Ze Liu and Han Hu and Zilong Wang and Rafael Salas and Jithin Jose and Prabhat Ram and Joe Chau and Peng Cheng and Fan Yang and Mao Yang and Yongqiang Xiong},
+title = {Tutel: Adaptive Mixture-of-Experts at Scale},
+year = {2022},
+month = jun,
+journal = {CoRR},
+volume = {abs/2206.03382},
+url = {https://arxiv.org/pdf/2206.03382.pdf},
+}
+```
+
 ### Usage of MOELayer:
 ```
 * Usage of MOELayer Args:
@@ -205,20 +256,6 @@ print(y)
 has_fc2_bias        : If set to False, the expert bias parameters `batched_fc2_bias` is disabled. Default: True
 ```

-### Reference
-You can consult this [paper](https://arxiv.org/pdf/2206.03382.pdf) below to get to know more technical details about Tutel:
-```
-@article {tutel,
-author = {Changho Hwang and Wei Cui and Yifan Xiong and Ziyue Yang and Ze Liu and Han Hu and Zilong Wang and Rafael Salas and Jithin Jose and Prabhat Ram and Joe Chau and Peng Cheng and Fan Yang and Mao Yang and Yongqiang Xiong},
-title = {Tutel: Adaptive Mixture-of-Experts at Scale},
-year = {2022},
-month = jun,
-journal = {CoRR},
-volume= {abs/2206.03382},
-url = {https://arxiv.org/pdf/2206.03382.pdf},
-}
-```
-
 ### Contributing

 This project welcomes contributions and suggestions. Most contributions require you to agree to a
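Editor's note on the renamed "Examples: How to import Tutel-optimized MoE in Pytorch" section: the hunks above only show the first and last lines of that README example (`import torch` ... `y = moe_layer(x)`; `print(y)`). For orientation, here is a minimal, hypothetical sketch of constructing such a layer with Tutel's `tutel.moe.moe_layer` interface; the argument values are illustrative placeholders, and the actual README example may differ.

```python
import torch
from tutel import moe as tutel_moe

# Hypothetical sketch: a top-2 gated MoE layer with 2 FFN experts per GPU.
# Argument names follow Tutel's moe_layer interface; values are illustrative only.
moe_layer = tutel_moe.moe_layer(
    gate_type={'type': 'top', 'k': 2},
    model_dim=1024,
    experts={'type': 'ffn',
             'count_per_node': 2,
             'hidden_size_per_expert': 4096,
             'activation_fn': lambda x: torch.nn.functional.relu(x)},
).to('cuda:0' if torch.cuda.is_available() else 'cpu')

# Forward pass, mirroring the `y = moe_layer(x)` line shown in the diff context.
x = torch.randn(4, 512, 1024, device=next(moe_layer.parameters()).device)
y = moe_layer(x)
print(y.shape)
```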

doc/DeepSeekR1-tutel-accel.png

New binary file added (184 KB, benchmarking figure; image not shown).
setup.py

+10 -11
@@ -16,6 +16,7 @@
 from typing import List, Tuple

 from setuptools import setup, find_packages, Command
+import torch
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension

 try:
@@ -57,7 +58,7 @@ def run(self):
 def install(use_cuda, use_nccl):
     ext_libs = []
     if pf.system() == 'Linux':
-        ext_args = ['-Wno-sign-compare', '-Wno-unused-but-set-variable', '-Wno-terminate', '-Wno-unused-function', '-Wno-strict-aliasing']
+        ext_args = ['-w']
     elif pf.system() == 'Darwin':
         ext_args = ['-mmacosx-version-min=10.13']
     else:
@@ -80,7 +81,7 @@ def install(use_cuda, use_nccl):

     setup(
         name='tutel',
-        version='0.3',
+        version='0.4',
         description='An Optimized Mixture-of-Experts Implementation.',
         url='https://github.com/microsoft/Tutel',
         author='Microsoft',
@@ -138,16 +139,14 @@ def install(use_cuda, use_nccl):
         },
     )

-if int(os.environ.get('NO_CUDA', 0)) == 1:
-    print('Installing without CUDA extension..')
-    install(use_cuda=False, use_nccl=False)
-else:
+if (torch.version.cuda or torch.version.hip) and int(os.environ.get('NO_CUDA', 0)) == 0:
     try:
+        print('Try installing with NCCL extension..')
         install(use_cuda=True, use_nccl=True)
     except:
         print('Try installing without NCCL extension..')
-        try:
-            install(use_cuda=True, use_nccl=False)
-        except:
-            print('Try installing without CUDA extension..')
-            install(use_cuda=False, use_nccl=False)
+        install(use_cuda=True, use_nccl=False)
+else:
+    print('Installing without CUDA extension..')
+    install(use_cuda=False, use_nccl=False)
+
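Editor's note on the setup.py change above: instead of deciding the build path solely from the `NO_CUDA` environment variable, the install script now also checks whether the installed PyTorch ships a CUDA or ROCm (HIP) toolchain. The short standalone sketch below (not part of the commit) reproduces that decision logic with stock PyTorch attributes so you can check which path your environment would take; `NO_CUDA` is the same variable the script reads.

```python
import os
import torch

# torch.version.cuda / torch.version.hip are None when PyTorch was built without that backend.
has_gpu_toolkit = bool(torch.version.cuda or torch.version.hip)

if has_gpu_toolkit and int(os.environ.get('NO_CUDA', 0)) == 0:
    # Mirrors the new branch: build the GPU extension, preferring NCCL and
    # falling back to a no-NCCL build if that compilation fails.
    print('Would build the CUDA/ROCm extension (NCCL first, then without NCCL on failure).')
else:
    # Mirrors the else branch: CPU-only install.
    print('Would install without the CUDA/ROCm extension.')
```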