Commit bddc915

rename is_postnorm to is_postscore (#107)
1 parent 712bf2e commit bddc915

5 files changed: +52, -28 lines


README.md (+4, -1)

@@ -99,11 +99,14 @@ Usage of MOELayer:
         or a list of dict-type gate descriptions, e.g. [{'type': 'top', 'k', 2}, {'type': 'top', 'k', 2}],
         the value of k in top-gating can be also negative, like -2, which indicates one GPU will hold 1/(-k) parameters of an expert
     model_dim : the number of channels for MOE's input tensor
-    experts : a dict-type config for builtin expert network, or a torch.nn.Module-type custom expert network
+    experts : a dict-type config for builtin expert network
     scan_expert_func : allow users to specify a lambda function to iterate each experts param, e.g. `scan_expert_func = lambda name, param: setattr(param, 'expert', True)`
     result_func : allow users to specify a lambda function to format the MoE output and aux_loss, e.g. `result_func = lambda output: (output, output.l_aux)`
     group : specify the explicit communication group of all_to_all
     seeds : a tuple containing a tripple of int to specify manual seed of (shared params, local params, others params after MoE's)
+    a2a_ffn_overlap_degree : the value to control a2a overlap depth, 1 by default for no overlap, 2 for overlap a2a with half gemm, ..
+    parallel_type : the parallel method to compute MoE, valid types: 'auto', 'data', 'model'
+    pad_samples : whether do auto padding on newly-coming input data to maximum data size in history
 
 * Usage of dict-type Experts Config:
 
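
The three newly documented arguments are plain constructor keywords of MOELayer. A minimal construction sketch, assuming the `tutel.moe` entry point; the expert config keys and sizes here are illustrative assumptions, not prescriptive values:

    # Hedged usage sketch: module path, expert config keys and sizes are assumptions.
    from tutel import moe as tutel_moe

    moe = tutel_moe.moe_layer(
        gate_type={'type': 'top', 'k': 2},
        model_dim=1024,
        experts={'type': 'ffn', 'count_per_node': 2, 'hidden_size_per_expert': 4096},
        a2a_ffn_overlap_degree=1,   # 1 = no all-to-all/FFN overlap, 2 = overlap a2a with half gemm
        parallel_type='auto',       # or 'data' / 'model'
        pad_samples=False,          # do not pad inputs up to the largest size seen so far
    )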

tutel/impls/communicate.py (+4)

@@ -150,8 +150,12 @@ def init(group: dist.ProcessGroup, num_split: int, split_dim: int) -> None:
 
 
 class AllToAll(torch.autograd.Function):
+    _use_builtins = False
+
     @staticmethod
     def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor):
+        AllToAll._use_builtins = True
+
         ctx.group = group
         world_size = get_world_size(group)
         if world_size <= 1:
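
The new `_use_builtins` attribute is a process-wide flag on the autograd function: it starts out False and is set the first time `AllToAll.forward` runs, so shutdown code (see tutel/system_init.py below) can tell whether the builtin all-to-all path was ever exercised. A stripped-down sketch of that pattern, with the actual communication elided:

    import torch

    class AllToAllSketch(torch.autograd.Function):
        # Class-level flag shared by every call in this process.
        _use_builtins = False

        @staticmethod
        def forward(ctx, group, input):
            AllToAllSketch._use_builtins = True  # record that this code path ran at least once
            ctx.group = group
            return input  # the real AllToAll performs the distributed all-to-all here

        @staticmethod
        def backward(ctx, grad_output):
            # No gradient for `group`; pass the gradient of `input` straight through.
            return None, grad_output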

tutel/impls/fast_dispatch.py (+4, -4)

@@ -89,13 +89,13 @@ def __init__(self, num_global_experts, capacity, model_dim, dispatch_dtype):
         self.original_dtype = dispatch_dtype
         self.aligned_dim = model_dim // (2 if self.dtype == torch.float16 else 1)
 
-    def update(self, indices_, locations_, gates_, capacity=None, is_postnorm=True):
+    def update(self, indices_, locations_, gates_, capacity=None, is_postscore=True):
         self.indices_ = [x.to(torch.int32).view(-1) for x in indices_]
         self.locations_ = [x.to(torch.int32) for x in locations_]
         self.gates_ = [x.to(self.dtype) for x in gates_]
         sample_size = int(self.indices_[0].size(0))
         capacity = int(capacity) or self.capacity
-        self.is_postnorm = is_postnorm
+        self.is_postscore = is_postscore
 
         if sample_size != self.expected_sample_size or capacity != self.capacity:
             self.expected_sample_size, self.capacity = sample_size, capacity
@@ -109,13 +109,13 @@ def update(self, indices_, locations_, gates_, capacity=None, is_postnorm=True):
             self.func_fwd, self.func_bwd_data, self.func_bwd_gate, self.ones_helper = self.kernel_pool[tuple((sample_size, capacity))]
 
     def encode(self, data):
-        if self.is_postnorm:
+        if self.is_postscore:
             return GatingEncoder.apply(self, data.to(self.dtype)).to(self.original_dtype)
         else:
             return GatingEncoder.apply(self, data.to(self.dtype), *self.gates_).to(self.original_dtype)
 
     def decode(self, data):
-        if self.is_postnorm:
+        if self.is_postscore:
             return GatingDecoder.apply(self, data.to(self.dtype), *self.gates_).to(self.original_dtype)
         else:
             return GatingDecoder.apply(self, data.to(self.dtype)).to(self.original_dtype)
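
The rename from `is_postnorm` to `is_postscore` does not change behavior; it only names the intent: with `is_postscore=True` the gate values are kept out of `encode` (dispatch) and applied in `decode` (combine), i.e. expert outputs are scaled by the routing scores, while `is_postscore=False` scales tokens by their scores before dispatch. A dense, kernel-free illustration of the two orderings for a single expert (function and variable names here are illustrative only):

    import torch

    def combine_single_expert(x, gates, expert, is_postscore=True):
        # x: [tokens, model_dim]; gates: [tokens] routing score of the chosen expert per token.
        if is_postscore:
            return expert(x) * gates.unsqueeze(-1)   # score applied after the expert (decode side)
        return expert(x * gates.unsqueeze(-1))       # score applied before dispatch (encode side)

    x, gates = torch.randn(4, 8), torch.rand(4)
    expert = torch.nn.Linear(8, 8)
    y_post = combine_single_expert(x, gates, expert, is_postscore=True)
    y_pre = combine_single_expert(x, gates, expert, is_postscore=False)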

tutel/impls/moe_layer.py (+27, -17)

@@ -48,20 +48,22 @@ def __init__(
         num_global_experts,
         a2a_ffn_overlap_degree=1,
         capacity_factor=1.0,
-        top_k=2,
+        k=2,
         batch_prioritized_routing=False,
-        **kwargs,
+        fp32_gate=False,
+        is_postscore=True,
+        input_dropout_p=0,
     ):
         super().__init__()
-        top_k = min(top_k, num_global_experts)
-        self.top_k = top_k
+        k = min(k, num_global_experts)
+        self.top_k = k
         assert self.top_k > 0, "Top-k value %d is not valid." % self.top_k
 
         self.wg = torch.nn.Linear(model_dim, num_global_experts, bias=False)
 
-        self.fp32_gate = kwargs.get('fp32_gate', False)
+        self.fp32_gate = fp32_gate
         if self.fp32_gate:
-            self.wg = self.wg.float()
+            self.wg = self.wg.float()
 
         self.capacity_factor = float(os.environ.get('CAP_FACTOR', capacity_factor))
         self.is_ones_gate = (int(os.environ.get('ONES_GATE', 0)) == 1)
@@ -71,8 +73,7 @@ def __init__(
         if int(os.environ.get('BATCH_PRIO', 0)) != 0:
             self.batch_prioritized_routing = True
 
-        self.is_postnorm = kwargs.get('is_postnorm', True)
-        input_dropout_p = kwargs.get('input_dropout_p', 0)
+        self.is_postscore = is_postscore
         self.input_dropout = torch.nn.Dropout(p=input_dropout_p) if input_dropout_p else None
 
         self.a2a_ffn_overlap_degree = a2a_ffn_overlap_degree
@@ -134,7 +135,7 @@ def apply_on_expert_fn(self, input, ctx):
 
         if self.is_ones_gate:
             gates_s = [torch.ones_like(x) for x in gates_s]
-        self._fdr.update(indices_s, locations_s, gates_s, capacity=capacity, is_postnorm=self.is_postnorm)
+        self._fdr.update(indices_s, locations_s, gates_s, capacity=capacity, is_postscore=self.is_postscore)
 
         dispatched_input = self._fdr.encode(input)
 
@@ -223,7 +224,19 @@ class MOELayer(torch.nn.Module):
     """Tutel optimized MOELayer
     """
 
-    def __init__(self, gate_type, model_dim: int, experts = None, scan_expert_func = None, result_func = None, group: Optional[Any] = None, seeds = None, a2a_ffn_overlap_degree = 1, **kwargs):
+    def __init__(
+        self,
+        gate_type,
+        model_dim: int,
+        experts=None,
+        scan_expert_func=None,
+        result_func=None,
+        group=None,
+        seeds=None,
+        a2a_ffn_overlap_degree=1,
+        parallel_type='auto',
+        pad_samples=False,
+    ):
         super().__init__()
         assert model_dim % 2 == 0, "Model_dim (%s) must be even value, while this Model_dim mod 2 > 0." % model_dim
         group = group or dist.group.WORLD
@@ -257,7 +270,6 @@ def __init__(self, gate_type, model_dim: int, experts = None, scan_expert_func =
         self.num_global_experts = num_devices * self.num_local_experts
         sharded_count = 1
 
-        parallel_type = kwargs.get('parallel_type', 'auto')
         if sharded_count == 1 or not self.is_builtin_experts:
             self.auto_parallel, self.use_model_parallel = False, False
         elif parallel_type == 'auto':
@@ -413,11 +425,9 @@ def to(self, *args, **kwargs):
             if single_gate_type['type'] == 'top':
                 if seeds is not None and seeds[0] is not None:
                     torch.manual_seed(seeds[0] + gi)
-                if "fp32_gate" in kwargs:
-                    logging.warning(f'`fp32_gate` option in tutel.moe_layer has been deprecated, please move this option to gate_type = {{.., "fp32_gate": {kwargs["fp32_gate"]}}} instead.')
-                    single_gate_type["fp32_gate"] = kwargs["fp32_gate"]
 
-                self.gates += [TopKGate(model_dim=model_dim, top_k=single_gate_type['k'], num_global_experts=self.num_global_experts, a2a_ffn_overlap_degree=a2a_ffn_overlap_degree, **single_gate_type)]
+                single_gate_type.pop('type')
+                self.gates += [TopKGate(model_dim=model_dim, num_global_experts=self.num_global_experts, a2a_ffn_overlap_degree=a2a_ffn_overlap_degree, **single_gate_type)]
             else:
                 raise Exception("Unrecognized gate_type: %s" % single_gate_type)
 
@@ -435,7 +445,7 @@ def expert_fn(dispatched_input):
             return expert_output
 
         self.expert_fn = expert_fn
-        self.expected_sample_size = 0 if kwargs.get('scale_samples', False) else -1
+        self.expected_sample_size = 0 if pad_samples else -1
 
     def get_parameter_iterator(self, param_type):
         if param_type == 'gate':
@@ -445,7 +455,7 @@ def get_parameter_iterator(self, param_type):
         else:
            raise Exception("Specified parameter type is not recognized: %s. Valid `param_type` includes: gate, local_experts." % param_type)
 
-    def forward(self, input: Tensor, gate_index=0, **kwargs: Any):
+    def forward(self, input: Tensor, gate_index=0):
         if self.skip_moe:
             result_output = input
             result_output.l_aux = None
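
With the `**kwargs` catch-alls removed, gate options such as `fp32_gate` and `is_postscore` now travel inside the gate description itself: the layer pops `'type'` and forwards every remaining key into `TopKGate` (whose top-k argument is now spelled `k`). A hedged sketch of what a caller would pass; the expert config keys and sizes are assumptions, the gate keys follow this commit:

    from tutel import moe as tutel_moe

    moe = tutel_moe.moe_layer(
        gate_type={'type': 'top', 'k': 2, 'fp32_gate': True, 'is_postscore': True},
        model_dim=1024,
        experts={'type': 'ffn', 'count_per_node': 2, 'hidden_size_per_expert': 4096},
    )
    # Internally this now reduces to roughly:
    #   single_gate_type.pop('type')
    #   TopKGate(model_dim=1024, num_global_experts=..., k=2, fp32_gate=True, is_postscore=True)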

tutel/system_init.py (+13, -6)

@@ -1,7 +1,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
-import os
+import os, sys
 import re
 import logging
 
@@ -25,12 +25,19 @@ def init_affinity_at_program_beginning():
         logging.warning('Failed to set NUMA status: %s' % ex)
 
 def init_data_model_parallel(group_count=1, backend='nccl'):
-    from tutel.impls.communicate import create_groups_from_world
-    result = create_groups_from_world(group_count=group_count, include_init=backend)
+    from tutel.impls import communicate as C
+    result = C.create_groups_from_world(group_count=group_count, include_init=backend)
+    logging.critical(f'Registering device global rank {result.global_rank}: data_rank = {result.data_rank}, model_rank = {result.model_rank}')
+
+    def on_quit():
+        sys.stdout.flush()
+        sys.stderr.flush()
+        # Builtin dist.all_to_all_single in torch is unstable in some versions.
+        # Temp work around: https://github.com/pytorch/pytorch/issues/56390
+        if C.AllToAll._use_builtins:
+            os._exit(0)
 
-    # Temp work around for: https://github.com/pytorch/pytorch/issues/56390
     import atexit
-    atexit.register(lambda *args: os._exit(0))
+    atexit.register(lambda *args: on_quit())
 
-    logging.critical(f'Registering device global rank {result.global_rank}: data_rank = {result.data_rank}, model_rank = {result.model_rank}')
     return result
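
The net effect here is that the unconditional `os._exit(0)` at interpreter shutdown is now gated on whether the builtin all-to-all was actually used, and stdio is flushed first so the hard exit does not drop buffered output. A self-contained sketch of the same atexit pattern, detached from torch (`work_was_risky` stands in for `AllToAll._use_builtins`):

    import atexit, os, sys

    work_was_risky = False  # stands in for AllToAll._use_builtins

    def on_quit():
        sys.stdout.flush()
        sys.stderr.flush()
        # Hard-exit only when the unstable code path actually ran,
        # so a clean run still gets a normal interpreter shutdown.
        if work_was_risky:
            os._exit(0)

    atexit.register(lambda *args: on_quit())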
