Commit d61df8d

add tutel.examples.helloworld_switch (#199)
1 parent 1456b49 commit d61df8d

6 files changed: +179 -35 lines changed


README.md (+4)

@@ -48,6 +48,10 @@ How to setup Tutel MoE for Pytorch and [run examples](tutel/examples), or [enabl
     $ python3 ./tutel/examples/helloworld.py --batch_size=16
     ..
 
+* Switch Test using single-node 8 GPUs:
+
+    $ python3 -m torch.distributed.launch --nproc_per_node=8 -m tutel.examples.helloworld_switch --batch_size=16
+
 * Run Tutel MoE in Distributed Mode:
 
 (Method A - Torch launcher for `Multi-Node x Multi-GPU`:)
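Note: on PyTorch releases where `python3 -m torch.distributed.launch` is deprecated in favor of `torchrun`, the equivalent single-node 8-GPU invocation should be (an assumed equivalent, not part of this commit):

    $ torchrun --nproc_per_node=8 -m tutel.examples.helloworld_switch --batch_size=16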

setup.py (+1 -1)

@@ -74,7 +74,7 @@ def install(use_cuda, use_nccl):
 
 setup(
     name='tutel',
-    version='0.1',
+    version='0.2',
     description='An Optimized Mixture-of-Experts Implementation.',
     url='https://github.com/microsoft/Tutel',
     author='Microsoft',

tutel/examples/helloworld.py (+1 -1)

@@ -29,7 +29,7 @@
 parser.add_argument('--a2a_ffn_overlap_degree', type=int, default=1)
 parser.add_argument('--allreduce_degree', type=int, default=1)
 parser.add_argument('--num_steps', type=int, default=100)
-parser.add_argument('--parallel_type', type=str, default='auto')
+parser.add_argument('--parallel_type', type=str, default='adaptive:1')
 parser.add_argument('--checkpoint_path', type=str, default='')
 parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
 parser.add_argument('--use_2dh', default=False, action='store_true')
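For reference, the other `--parallel_type` values accepted by `moe_layer` (see the tutel/impls/moe_layer.py diff below) can still be requested explicitly. Illustrative invocations, assuming a distributed run whose `sharded_count` is divisible by 2 (untested here):

    $ python3 ./tutel/examples/helloworld.py --batch_size=16 --parallel_type=data
    $ python3 ./tutel/examples/helloworld.py --batch_size=16 --parallel_type=model
    $ python3 ./tutel/examples/helloworld.py --batch_size=16 --parallel_type=adaptive:2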

tutel/examples/helloworld_switch.py (new file, +152 lines)

#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch import nn
import argparse

from tutel import system
from tutel import moe as tutel_moe
from tutel import net

parser = argparse.ArgumentParser()

parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--num_tokens', type=int, default=512)
parser.add_argument('--model_dim', type=int, default=2048)
parser.add_argument('--hidden_size', type=int, default=2048)
parser.add_argument('--num_local_experts', type=int, default=2)
parser.add_argument('--dtype', type=str, default='float32')
parser.add_argument('--fp32_gate', default=False, action='store_true')
parser.add_argument('--top', type=int, default=2)
parser.add_argument('--l_aux_wt', type=float, default=0.0)
parser.add_argument('--a2a_ffn_overlap_degree', type=int, default=1)
parser.add_argument('--allreduce_degree', type=int, default=1)
parser.add_argument('--num_steps', type=int, default=100)
parser.add_argument('--cap_factor', type=float, default=1.0)
parser.add_argument('--checkpoint_path', type=str, default='')
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--use_2dh', default=False, action='store_true')
parser.add_argument('--eval', default=False, action='store_true')
args = parser.parse_args()

parallel_env = system.init_data_model_parallel(backend='nccl' if args.device == 'cuda' else 'gloo')
dist_rank, dist_world_size, dist_print = parallel_env.global_rank, parallel_env.global_size, parallel_env.dist_print
args.local_rank = parallel_env.local_device.index

batch_size = args.batch_size
num_tokens = args.num_tokens
model_dim = args.model_dim
hidden_size = args.hidden_size
num_local_experts = args.num_local_experts
top_value = args.top
a2a_ffn_overlap_degree = args.a2a_ffn_overlap_degree
device = parallel_env.local_device

if args.dtype == 'float32':
    torch.set_default_dtype(torch.float32)
elif args.dtype == 'float64':
    torch.set_default_dtype(torch.float64)
elif args.dtype == 'float16':
    torch.set_default_dtype(torch.float16)
elif args.dtype == 'bfloat16':
    torch.set_default_dtype(torch.bfloat16)
else:
    raise Exception('Unrecognized data type specified: %s' % args.dtype)


class ExampleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self._moe_layer = tutel_moe.moe_layer(
            gate_type = {'type': 'top', 'k': top_value, 'fp32_gate': args.fp32_gate},
            experts = {'type': 'ffn', 'count_per_node': num_local_experts, 'hidden_size_per_expert': hidden_size, 'activation_fn': lambda x: F.relu(x)},
            model_dim = model_dim,
            scan_expert_func = lambda name, param: setattr(param, 'skip_allreduce', True),
            seeds = (1, dist_rank + 1, 1),
            a2a_ffn_overlap_degree = a2a_ffn_overlap_degree,
            use_2dh=args.use_2dh,
        )

        # Summary of different parameter types: gate, local_experts
        local_count = sum([torch.numel(param) for name, param in self._moe_layer.get_parameter_iterator(param_type='local_experts')])
        shared_count = sum([torch.numel(param) for name, param in self._moe_layer.get_parameter_iterator(param_type='gate')])
        dist_print('[Statistics] param count for MoE local_experts = %s, param count for MoE gate = %s.\n' % (local_count, shared_count))
        self.r_index = -1

    def forward(self, input):
        r, o = self._moe_layer.valid_rs[(self.r_index // 8) % len(self._moe_layer.valid_rs)], self.r_index % 8 + 1
        self.r_index += 1

        result = self._moe_layer(input, capacity_factor=args.cap_factor, adaptive_r=r, a2a_ffn_overlap_degree=o)
        result = F.log_softmax(torch.sum(result, dim=2), dim=1)
        return result

model = ExampleModel().to(device)
dist_print(model)

if args.checkpoint_path:
    checkpoint_path = system.apply_rank_size_from_pattern(args.checkpoint_path, rank=parallel_env.global_rank, size=parallel_env.global_size)
    if os.path.exists(checkpoint_path):
        model.load_state_dict(torch.load(checkpoint_path))
    else:
        print('Checkpoint not loaded: file `%s` is not found. Will train the model from start.' % checkpoint_path)

optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)

torch.manual_seed(0)
x = torch.tensor(torch.randn([batch_size, num_tokens, model_dim], dtype=torch.float32, device='cpu').detach().numpy(), dtype=torch.get_default_dtype(), requires_grad=False, device=device)
y = torch.LongTensor(batch_size).random_(1).to(device)

tuples = (dist_world_size, args.dtype, model_dim, hidden_size, batch_size * num_tokens, num_local_experts, top_value, a2a_ffn_overlap_degree, device)
dist_print('[Benchmark] world_size = %s, dtype = %s, model_dim = %s, hidden_size = %s, samples = %s, num_local_experts = %s, topK = %s, a2a_ffn_overlap_degree = %s, device = `%s`' % tuples)

average_time, num_steps = 0, args.num_steps

if args.allreduce_degree == -1:
    params_for_all_reduce = []
else:
    params_for_all_reduce = [p for p in model.parameters() if not hasattr(p, 'skip_allreduce') and getattr(p, 'requires_grad', False)]

for i in range(num_steps):
    t_start = system.record_time()

    if not args.eval:
        optimizer.zero_grad()
        output = model(x)
        loss = F.nll_loss(output, y)
        if args.l_aux_wt:
            loss += args.l_aux_wt * model._moe_layer.l_aux
        loss.backward()
        if dist_world_size > 1:
            for p in params_for_all_reduce:
                p.grad /= dist_world_size
                p.grad = net.simple_all_reduce(p.grad)
        optimizer.step()
    else:
        with torch.no_grad():
            output = model(x)
            loss = F.nll_loss(output, y)

    t_stop = system.record_time()

    num_global_experts = tutel_moe.moe_layer.global_expert_count(num_local_experts, group=system.get_local_session().model_group)
    mm_ceof, cap_ceof = 1 if args.eval else 3, min(args.top, num_global_experts)
    tflops = (batch_size * num_tokens * model_dim * hidden_size) * 4 * mm_ceof * cap_ceof * 1e-12 / (t_stop - t_start)
    dist_print('STEP-%s: loss = %.5f, step_time = %.6f sec, perf = %.2f tflops. (f = %.1f, r = %d, o = %d)' % (i, float(loss.data), t_stop - t_start, tflops, args.cap_factor, model._moe_layer.adaptive_degree, model._moe_layer.a2a_ffn_overlap_degree))

    if i + 10 >= num_steps:
        average_time += t_stop - t_start

average_time /= 10
dist_print('\n[Summary] Average synchronized step_time = %s sec.' % average_time)

if args.checkpoint_path:
    torch.save(model.state_dict(), checkpoint_path)
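What makes this example a "switch test" is the schedule in `ExampleModel.forward()`: every step it reconfigures the layer on the fly via `adaptive_r` and `a2a_ffn_overlap_degree`. A minimal standalone sketch of the (r, o) sequence it walks through, assuming `valid_rs = [0, 1, 2, 4]` (i.e. `sharded_count = 4`; this snippet is not part of the commit):

valid_rs = [0, 1, 2, 4]  # assumed; the real list comes from model._moe_layer.valid_rs
for r_index in range(-1, 15):  # the example initializes self.r_index = -1
    r = valid_rs[(r_index // 8) % len(valid_rs)]
    o = r_index % 8 + 1
    print('r_index = %2d -> adaptive_r = %d, a2a_ffn_overlap_degree = %d' % (r_index, r, o))

# Each adaptive degree r is held for 8 steps while the overlap degree o sweeps 1..8.
# Because Python floor-divides negatives, the very first step (r_index = -1) wraps
# around to the last entry of valid_rs with o = 8.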

tutel/experts/ffn.py (+3 -17)

@@ -57,12 +57,12 @@ def forward(self, x, ctx):
         batched_fc1_bias = self.batched_fc1_bias.unsqueeze(1)
         batched_fc2_bias = self.batched_fc2_bias.unsqueeze(1)
 
-        if ctx.force_data_parallel:
+        if ctx.adaptive_degree == 0:
             batched_fc1_w = net.zero_gather(batched_fc1_w, group=ctx.group).view(ctx.num_global_experts, -1, batched_fc1_w.size(2))
             batched_fc2_w = net.zero_gather(batched_fc2_w, group=ctx.group).view(ctx.num_global_experts, -1, batched_fc2_w.size(2))
             batched_fc1_bias = net.zero_gather(batched_fc1_bias, group=ctx.group).view(ctx.num_global_experts, 1, -1)
             batched_fc2_bias = net.zero_gather(batched_fc2_bias, group=ctx.group).view(ctx.num_global_experts, 1, -1)
-        elif ctx.force_adaptive:
+        else:
             if ctx.sharded_count > 1:
                 group_size = ctx.sharded_count // ctx.adaptive_degree
                 if group_size > 1:
@@ -71,25 +71,11 @@ def forward(self, x, ctx):
                     batched_fc2_w = net.zero_gather(batched_fc2_w, group=ffn_zero_group).view(1, -1, self.output_dim)
                     batched_fc1_bias = net.zero_gather(batched_fc1_bias, group=ffn_zero_group).view(1, 1, -1)
 
-                ffn_zero_group2 = net.create_groups_from_world(group_count=ctx.num_global_experts).model_group
-                batched_fc2_bias = net.zero_gather(batched_fc2_bias, group=ffn_zero_group2)
+                batched_fc2_bias = net.zero_gather(batched_fc2_bias, group=net.create_groups_from_world(group_count=ctx.num_global_experts).model_group)
                 batched_fc2_bias = batched_fc2_bias.view(1, 1, -1)
 
                 if ctx.adaptive_degree > 1:
                     batched_fc2_bias = torch.mul(batched_fc2_bias, 1.0 / ctx.adaptive_degree)
-        else:
-            if ctx.sharded_count > 1:
-                ffn_zero_group = net.create_groups_from_world(group_count=ctx.num_global_experts).model_group
-                if not ctx.use_model_parallel:
-                    batched_fc1_w = net.zero_gather(batched_fc1_w, group=ffn_zero_group).view(1, -1, ctx.model_dim)
-                    batched_fc2_w = net.zero_gather(batched_fc2_w, group=ffn_zero_group).view(1, -1, self.output_dim)
-                    batched_fc1_bias = net.zero_gather(batched_fc1_bias, group=ffn_zero_group).view(1, 1, -1)
-
-                batched_fc2_bias = net.zero_gather(batched_fc2_bias, group=ffn_zero_group)
-                batched_fc2_bias = batched_fc2_bias.view(self.batched_fc2_bias.size(0), 1, -1)
-
-            if ctx.use_model_parallel:
-                batched_fc2_bias = torch.mul(batched_fc2_bias, 1.0 / ctx.sharded_count)
 
         if batched_fc2_bias.size(-1) != self.output_dim:
             batched_fc2_bias = batched_fc2_bias[:, :, :self.output_dim]
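To make the surviving branch easier to follow: with r = `ctx.adaptive_degree` and `ctx.sharded_count` shards per expert, each parameter `zero_gather` now runs over groups of `sharded_count // r` ranks, leaving r replicas of every expert after the gather. A toy sketch of that grouping arithmetic (hypothetical helper, contiguous rank grouping assumed; not Tutel's API):

def shard_groups(sharded_count, adaptive_degree):
    # adaptive_degree == 0 (fully data-parallel weights) is handled by the first branch above
    assert adaptive_degree >= 1 and sharded_count % adaptive_degree == 0
    group_size = sharded_count // adaptive_degree
    return [list(range(g * group_size, (g + 1) * group_size)) for g in range(adaptive_degree)]

print(shard_groups(8, 1))  # [[0, 1, 2, 3, 4, 5, 6, 7]]  -> one full-width gather
print(shard_groups(8, 2))  # [[0, 1, 2, 3], [4, 5, 6, 7]] -> two half-size gathers
print(shard_groups(8, 8))  # [[0], [1], ..., [7]]  -> group_size == 1, the `if group_size > 1:` gather is skipped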

tutel/impls/moe_layer.py (+18 -16)

@@ -77,7 +77,7 @@ def __init__(
         batch_prioritized_routing=False,
         normalize_gate=True,
         is_gshard_loss=True,
-        parallel_type='auto',
+        parallel_type='adaptive:1',
         use_2dh=False,
         **kwargs
     ):
@@ -105,23 +105,20 @@ def __init__(
         else:
             self.sharded_count = 1
 
-        self.force_data_parallel, self.force_adaptive, self.adaptive_degree = False, False, self.sharded_count
+        self.auto_parallel, self.adaptive_degree, self.use_model_parallel = False, self.sharded_count, True
+        self.valid_rs = [0] + [i for i in range(1, self.sharded_count + 1) if self.sharded_count % i == 0]
+
         if parallel_type.startswith('adaptive:'):
             self.adaptive_degree = int(parallel_type[parallel_type.index(':') + 1:])
-            if self.adaptive_degree == 0:
-                self.force_data_parallel = True
-            else:
-                if self.adaptive_degree < 0 or self.sharded_count % self.adaptive_degree != 0:
-                    valids = [i for i in range(1, self.sharded_count + 1) if self.sharded_count % i == 0]
-                    raise Exception("Unexpected value of adaptive_degree: %d, expecting a candidate within %s." % (self.adaptive_degree, valids))
-                self.force_adaptive = True
-            self.auto_parallel, self.use_model_parallel = False, True
+            self.adaptive_degree = min(max(self.adaptive_degree, 0), self.sharded_count)
+            if self.adaptive_degree not in self.valid_rs:
+                raise Exception("Unexpected value of adaptive_degree: %d, expecting a candidate within %s." % (self.adaptive_degree, self.valid_rs))
         elif self.sharded_count == 1:
-            self.auto_parallel, self.use_model_parallel = False, False
+            pass
         elif parallel_type in ('data', 'model'):
-            self.auto_parallel, self.use_model_parallel = False, (parallel_type == 'model')
+            self.adaptive_degree = 1 if parallel_type == 'data' else self.sharded_count
        elif parallel_type == 'auto':
-            self.auto_parallel, self.use_model_parallel = True, False
+            self.adaptive_degree = 1
         else:
             raise Exception('Unrecognized parallel type specified: %s' % parallel_type)
 
@@ -219,7 +216,7 @@ def expert_local(self, x, reserve_shape):
         self.protected_shape = y.shape
         return y.reshape(y.size(0), y.size(1), -1)
 
-    def forward(self, input: Tensor, gate_index=0, capacity_factor=None, top_k=None, a2a_ffn_overlap_degree=None, reserve_dims=1, inequivalent_tokens=False):
+    def forward(self, input: Tensor, gate_index=0, capacity_factor=None, top_k=None, a2a_ffn_overlap_degree=None, reserve_dims=1, inequivalent_tokens=False, adaptive_r=None):
         if self.skip_moe:
             result_output = input
             result_output.l_aux = None
@@ -233,7 +230,9 @@ def forward(self, input: Tensor, gate_index=0, capacity_factor=None, top_k=None,
                 x = x.to(p.dtype)
                 break
         gctx = self.gates[gate_index]
-        a2a_ffn_overlap_degree = a2a_ffn_overlap_degree if a2a_ffn_overlap_degree is not None else self.a2a_ffn_overlap_degree
+        if a2a_ffn_overlap_degree is not None:
+            self.a2a_ffn_overlap_degree = a2a_ffn_overlap_degree
+        a2a_ffn_overlap_degree = self.a2a_ffn_overlap_degree
 
         def routing():
             logits = gctx(x)
@@ -270,7 +269,10 @@ def routing():
 
         y = fast_encode(x.to(logits_dtype), crit, self.is_postscore).to(x.dtype)
 
-        if self.force_data_parallel:
+        if adaptive_r is not None:
+            self.adaptive_degree = adaptive_r
+
+        if self.adaptive_degree == 0:
             y = self.expert_local(y, original_shape[-reserve_dims:])
         else:
             if self.auto_parallel:
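For readers tracing the control flow, here is a condensed standalone sketch of how the rewritten constructor resolves `parallel_type` into `adaptive_degree` (the function name `resolve_adaptive_degree` is hypothetical and the `use_model_parallel` bookkeeping is omitted; this is not Tutel's API):

def resolve_adaptive_degree(parallel_type, sharded_count):
    # Valid adaptive degrees: 0 (data-parallel expert weights) plus every divisor of sharded_count.
    valid_rs = [0] + [i for i in range(1, sharded_count + 1) if sharded_count % i == 0]
    adaptive_degree = sharded_count  # default: fully model-parallel sharding
    if parallel_type.startswith('adaptive:'):
        adaptive_degree = int(parallel_type.split(':', 1)[1])
        adaptive_degree = min(max(adaptive_degree, 0), sharded_count)  # clamp into [0, sharded_count]
        if adaptive_degree not in valid_rs:
            raise Exception('Unexpected value of adaptive_degree: %d, expecting a candidate within %s.' % (adaptive_degree, valid_rs))
    elif sharded_count == 1:
        pass  # nothing is sharded; keep the default (which equals 1 here)
    elif parallel_type in ('data', 'model'):
        adaptive_degree = 1 if parallel_type == 'data' else sharded_count
    elif parallel_type == 'auto':
        adaptive_degree = 1
    else:
        raise Exception('Unrecognized parallel type specified: %s' % parallel_type)
    return adaptive_degree, valid_rs

# With sharded_count = 4 (so valid_rs == [0, 1, 2, 4]):
#   'adaptive:1' -> 1, 'data' -> 1, 'model' -> 4, 'adaptive:0' -> 0, 'adaptive:3' -> raises.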
