README.md (+3 −1):
````diff
@@ -82,7 +82,9 @@ Usage of MOELayer:
 ```
 * Usage of MOELayer Args:
 
-        gate_type        : dict-type gate description, e.g. {'type': 'top', 'k': 2, ..}, or {'type': 'megatron'}
+        gate_type        : dict-type gate description, e.g. {'type': 'top', 'k': 2, ..}, or {'type': 'megatron'},
+                           or a list of dict-type gate descriptions, e.g. [{'type': 'top', 'k': 2}, {'type': 'top', 'k': 2}];
+                           the value of k in top-gating can also be negative, e.g. -2, which indicates that one GPU holds 1/(-k) of an expert's parameters
         model_dim        : the number of channels for MOE's input tensor
         experts          : a dict-type config for the builtin expert network, or a torch.nn.Module-type custom expert network
         scan_expert_func : a lambda function applied to each expert parameter, e.g. `scan_expert_func = lambda name, param: setattr(param, 'expert', True)`
````
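As a usage illustration of the new list form, here is a hedged sketch (not part of the diff): the import path and expert config keys follow Tutel's README conventions, the dimensions are illustrative, and a working CUDA/distributed setup is assumed.

```python
import torch
from tutel import moe as tutel_moe

# Two top-2 gates: one gate is built per dict in the list.
moe = tutel_moe.moe_layer(
    gate_type=[{'type': 'top', 'k': 2}, {'type': 'top', 'k': 2}],
    model_dim=1024,
    experts={'type': 'ffn', 'count_per_node': 2, 'hidden_size_per_expert': 4096},
    scan_expert_func=lambda name, param: setattr(param, 'expert', True),
)
output = moe(torch.randn(8, 1024))  # aux loss is read off the output, e.g. via result_func below
```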
In the Python source's docstring, the matching argument list is removed:

```diff
-        gate_type        : dict-type gate description, e.g. {'type': 'top', 'k': 2, ..}, or {'type': 'megatron'}
-        model_dim        : the number of channels for MOE's input tensor
-        experts          : a dict-type config for the builtin expert network, or a torch.nn.Module-type custom expert network
-        scan_expert_func : a lambda function applied to each expert parameter, e.g. `scan_expert_func = lambda name, param: setattr(param, 'expert', True)`
-        result_func      : a lambda function to format the MoE output and aux_loss, e.g. `result_func = lambda output: (output, output.l_aux)`
-        group            : the explicit communication group used for all_to_all
-        seeds            : a tuple of three ints giving the manual seeds for (shared params, local params, params created after the MoE layer)
```
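To make the last three arguments concrete, a hedged sketch of typical values (illustrative only, not from the diff; assumes `torch.distributed.init_process_group` has already been called):

```python
import torch.distributed as dist

# seeds: same value on every rank for shared params, a rank-dependent value for
# local (expert) params, and a fixed value for params created after the MoE layer.
seeds = (1, dist.get_rank() + 1, 1)
# result_func: return the MoE output together with its auxiliary load-balancing loss.
result_func = lambda output: (output, output.l_aux)
# group: the process group used for the all_to_all exchange.
group = dist.group.WORLD
```

These would be passed to `moe_layer` alongside the arguments shown earlier.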
logging.warning(f"gate_type value `{gate_type}` in tutel.moe_layer has been deprecated, please use gate_type = {{'type': 'top', 'k': {top_k}}} instead.")
343
337
gate_type= {'type': 'top', 'k': top_k}
344
338
345
-
ifgate_type['type'] =='top':
346
-
ifseedsisnotNoneandseeds[0] isnotNone:
347
-
torch.manual_seed(seeds[0])
348
-
349
-
if"fp32_gate"inkwargs:
350
-
logging.warning(f'`fp32_gate` option in tutel.moe_layer has been deprecated, please move this option to gate_type = {{.., "fp32_gate": {kwargs["fp32_gate"]}}} instead.')
351
-
gate_type["fp32_gate"] =kwargs["fp32_gate"]
339
+
ifnotisinstance(gate_type, list):
340
+
gate_type= [gate_type]
341
+
342
+
self.gates= []
343
+
forgi, single_gate_typeinenumerate(gate_type):
344
+
ifsingle_gate_type['type'] =='top':
345
+
ifseedsisnotNoneandseeds[0] isnotNone:
346
+
torch.manual_seed(seeds[0] +gi)
347
+
if"fp32_gate"inkwargs:
348
+
logging.warning(f'`fp32_gate` option in tutel.moe_layer has been deprecated, please move this option to gate_type = {{.., "fp32_gate": {kwargs["fp32_gate"]}}} instead.')
0 commit comments