
Commit ea17ea6

support handling multi-gate options (#71)
1 parent e7d165f commit ea17ea6

2 files changed: +31 -29 lines


README.md (+3 -1)
@@ -82,7 +82,9 @@ Usage of MOELayer:
 ```
 * Usage of MOELayer Args:

-gate_type : dict-type gate description, e.g. {'type': 'top', 'k': 2, ..}, or {'type': 'megatron'}
+gate_type : dict-type gate description, e.g. {'type': 'top', 'k': 2, ..}, or {'type': 'megatron'},
+            or a list of dict-type gate descriptions, e.g. [{'type': 'top', 'k': 2}, {'type': 'top', 'k': 2}],
+            the value of k in top-gating can also be negative, e.g. -2, which indicates that one GPU will hold 1/(-k) of an expert's parameters
 model_dim : the number of channels for MOE's input tensor
 experts : a dict-type config for builtin expert network, or a torch.nn.Module-type custom expert network
 scan_expert_func : allow users to specify a lambda function to iterate each experts param, e.g. `scan_expert_func = lambda name, param: setattr(param, 'expert', True)`
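
For illustration, a minimal sketch of how a caller might combine the list-form `gate_type` with the new `gate_index` argument of `forward` (see the moe_layer.py diff below). The import path simply follows the file touched by this commit; the expert-config key `hidden_size_per_expert`, the tensor shapes, and the omitted distributed/group setup are assumptions for the sketch, not taken from this diff.

```python
import torch
from tutel.impls.moe_layer import MOELayer  # path of the file changed below

# Two top-2 gates sharing one pool of experts; 'hidden_size_per_expert' is an
# assumed expert-config key; only 'type' and 'count_per_node' appear in this diff.
moe = MOELayer(
    gate_type=[{'type': 'top', 'k': 2}, {'type': 'top', 'k': 2}],
    model_dim=1024,
    experts={'type': 'ffn', 'count_per_node': 1, 'hidden_size_per_expert': 2048},
    seeds=(1, 1, 1),  # gate gi is seeded with seeds[0] + gi, per the diff below
)

x = torch.randn(4, 512, 1024)
y0 = moe(x, gate_index=0)  # route through the first gate
y1 = moe(x, gate_index=1)  # route through the second gate
```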

tutel/impls/moe_layer.py (+28 -28)
@@ -148,7 +148,7 @@ def apply_on_expert_fn(self, input, expert_fn, group, sharded_count):
         return result_output, l_loss


-class MegatronLMGate():
+class MegatronLMGate(torch.nn.Module):
     """Megatron-LM Tensor Parallel over MoE Gate Type
     """

@@ -157,6 +157,9 @@ def __init__(
         **kwargs,
     ):
         self.l_zero = None
+        self._modules = dict()
+        self._parameters = dict()
+        self._buffers = dict()

     def named_parameters(self):
         return []
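
The three dict assignments added above are what let `MegatronLMGate` behave as a `torch.nn.Module` even though its `__init__` never calls `super().__init__()`: attribute assignment and module traversal in PyTorch read `_parameters`, `_buffers` and `_modules` directly, and the `ModuleList` built later in this diff walks them. A standalone sketch of the same pattern, using a hypothetical class name:

```python
import torch

class ParamlessGate(torch.nn.Module):
    """Hypothetical stand-in for the MegatronLMGate pattern shown above."""
    def __init__(self):
        # Intentionally no super().__init__(): seed by hand the bookkeeping
        # dicts that torch.nn.Module machinery (attribute assignment,
        # ModuleList traversal, parameters()) expects on every module.
        self.l_zero = None
        self._modules = dict()
        self._parameters = dict()
        self._buffers = dict()

    def named_parameters(self):
        return []

# Being an nn.Module subclass, the gate can now sit in a ModuleList next to
# TopKGate instances, and parameter iteration over the list still works.
gates = torch.nn.ModuleList([ParamlessGate()])
assert len(list(gates.parameters())) == 0
```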
@@ -173,15 +176,6 @@ def apply_on_expert_fn(self, input, expert_fn, group, sharded_count):

 class MOELayer(torch.nn.Module):
     """Tutel optimized MOELayer
-
-    Args:
-        gate_type : dict-type gate description, e.g. {'type': 'top', 'k': 2, ..}, or {'type': 'megatron'}
-        model_dim : the number of channels for MOE's input tensor
-        experts : a dict-type config for builtin expert network, or a torch.nn.Module-type custom expert network
-        scan_expert_func : allow users to specify a lambda function to iterate each experts param, e.g. `scan_expert_func = lambda name, param: setattr(param, 'expert', True)`
-        result_func : allow users to specify a lambda function to format the MoE output and aux_loss, e.g. `result_func = lambda output: (output, output.l_aux)`
-        group : specify the explicit communication group of all_to_all
-        seeds : a tuple containing a tripple of int to specify manual seed of (shared params, local params, others params after MoE's)
     """

     def __init__(self, gate_type, model_dim: int, experts = None, scan_expert_func = None, result_func = None, group: Optional[Any] = None, seeds = None, **kwargs):
@@ -342,22 +336,28 @@ def to(self, *args, **kwargs):
             logging.warning(f"gate_type value `{gate_type}` in tutel.moe_layer has been deprecated, please use gate_type = {{'type': 'top', 'k': {top_k}}} instead.")
             gate_type = {'type': 'top', 'k': top_k}

-        if gate_type['type'] == 'top':
-            if seeds is not None and seeds[0] is not None:
-                torch.manual_seed(seeds[0])
-
-            if "fp32_gate" in kwargs:
-                logging.warning(f'`fp32_gate` option in tutel.moe_layer has been deprecated, please move this option to gate_type = {{.., "fp32_gate": {kwargs["fp32_gate"]}}} instead.')
-                gate_type["fp32_gate"] = kwargs["fp32_gate"]
+        if not isinstance(gate_type, list):
+            gate_type = [gate_type]
+
+        self.gates = []
+        for gi, single_gate_type in enumerate(gate_type):
+            if single_gate_type['type'] == 'top':
+                if seeds is not None and seeds[0] is not None:
+                    torch.manual_seed(seeds[0] + gi)
+                if "fp32_gate" in kwargs:
+                    logging.warning(f'`fp32_gate` option in tutel.moe_layer has been deprecated, please move this option to gate_type = {{.., "fp32_gate": {kwargs["fp32_gate"]}}} instead.')
+                    single_gate_type["fp32_gate"] = kwargs["fp32_gate"]
+
+                self.gates += [TopKGate(model_dim=model_dim, top_k=single_gate_type['k'], num_global_experts=self.num_global_experts, **single_gate_type)]
+            elif single_gate_type['type'] == 'megatron':
+                self.gates += [MegatronLMGate(**single_gate_type)]
+                assert isinstance(experts, dict), "Gate type `megatron` requires dict-type expert description."
+                assert self.num_local_experts == 1, "Gate type `megatron` requires `count_per_node` == 1 in expert attributions."
+                assert experts['type'] == 'ffn', "Gate type `megatron` requires `type` == `ffn` in expert attributions."
+            else:
+                raise Exception("Unrecognized gate_type: %s" % single_gate_type)

-            self.gate = TopKGate(model_dim=model_dim, top_k=gate_type['k'], num_global_experts=self.num_global_experts, **gate_type)
-        elif gate_type['type'] == 'megatron':
-            self.gate = MegatronLMGate(**gate_type)
-            assert isinstance(experts, dict), "Gate type `megatron` requires dict-type expert description."
-            assert self.num_local_experts == 1, "Gate type `megatron` requires `count_per_node` == 1 in expert attributions."
-            assert experts['type'] == 'ffn', "Gate type `megatron` requires `type` == `ffn` in expert attributions."
-        else:
-            raise Exception("Unrecognized gate_type: %s" % gate_type)
+        self.gates = ModuleList(self.gates)

         if seeds is not None and len(seeds) > 2 and seeds[2] is not None:
             torch.manual_seed(seeds[2])
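
A side effect of the `ModuleList` wrapper introduced above is that every gate's parameters are registered on the `MOELayer` itself, so the updated `get_parameter_iterator` below, which now reads `self.gates`, yields the parameters of all gates at once. A hypothetical inspection, assuming `moe` was constructed as in the earlier sketch:

```python
# Parameter names are prefixed by each gate's index in the ModuleList.
for name, param in moe.get_parameter_iterator(param_type='gate'):
    print(name, tuple(param.shape))

# Because self.gates is a registered submodule, the same tensors also show up
# in moe.parameters(), so one optimizer over moe.parameters() covers every gate.
opt = torch.optim.SGD(moe.parameters(), lr=1e-3)
```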
@@ -375,13 +375,13 @@ def expert_fn(dispatched_input):

     def get_parameter_iterator(self, param_type):
         if param_type == 'gate':
-            return self.gate.named_parameters()
+            return self.gates.named_parameters()
         elif param_type == 'local_experts':
             return self.experts.named_parameters()
         else:
             raise Exception("Specified parameter type is not recognized: %s. Valid `param_type` includes: gate, local_experts." % param_type)

-    def forward(self, input: Tensor, **kwargs: Any):
+    def forward(self, input: Tensor, gate_index=0, **kwargs: Any):
         if self.skip_moe:
             result_output = input
             result_output.l_aux = None
@@ -404,7 +404,7 @@ def forward(self, input: Tensor, **kwargs: Any):
             reshaped_input = pad_input

         reshaped_input = reshaped_input.to(next(iter(self.experts.parameters())).dtype)
-        result_output, l_aux = self.gate.apply_on_expert_fn(reshaped_input, self.expert_fn, self.group, sharded_count=self.sharded_count)
+        result_output, l_aux = self.gates[gate_index].apply_on_expert_fn(reshaped_input, self.expert_fn, self.group, sharded_count=self.sharded_count)

         result_output = result_output[:reshaped_input_samples, :]
         result_output = result_output.view(original_shape).to(original_dtype)
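
Since the gate is now chosen per call, the same layer can be routed through different gates at different points in training. A hedged sketch of one possible policy (alternating gates across steps); the policy, the `batches`/`loss_fn` placeholders, and reading the auxiliary loss off the output are illustrative choices, not requirements of this commit:

```python
def train_steps(moe, batches, loss_fn):
    # Alternate between the layer's gates from step to step; only the
    # forward(input, gate_index=...) signature comes from this commit.
    for step, batch in enumerate(batches):
        out = moe(batch, gate_index=step % len(moe.gates))
        # Elsewhere in this file the output carries `l_aux`, so the selected
        # gate's auxiliary loss can be added to the task loss here.
        loss = loss_fn(out) + out.l_aux
        loss.backward()
```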
