Commit 712bf2e

update data flow for auto parallel (#105)
* update data flow for auto parallel
* not scale_samples by default for acceleration
1 parent: 14f1ae3 · commit: 712bf2e

8 files changed: +64 −58 lines
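
In user-facing terms, the commit swaps the boolean `--use_model_parallel` switch for a string-valued `--parallel_type` option ('auto' by default, with 'data' and 'model' as explicit settings), and moves device placement out of the MoE layer's constructor onto the assembled model. A minimal sketch of the new flag handling (argument name and default are taken from the diff below; the `choices` constraint is an illustrative assumption, not part of this commit):

    import argparse

    parser = argparse.ArgumentParser()
    # Before: parser.add_argument('--use_model_parallel', default=False, action='store_true')
    parser.add_argument('--parallel_type', type=str, default='auto',
                        choices=['auto', 'data', 'model'])  # choices added here for clarity only
    args = parser.parse_args()

    # The examples forward the value to the MoE layer unchanged, e.g.:
    #   tutel_moe.moe_layer(..., parallel_type=args.parallel_type)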

tests/test_tutel.py

+3 −1

@@ -37,7 +37,9 @@ def run(
     if helloworld_file == 'helloworld':
         command = 'python3 -m torch.distributed.launch --nproc_per_node=' + str(nproc_per_node) + ' tutel/examples/helloworld.py --top ' + str(top) + ' --dtype ' + dtype + ' --num_local_experts ' + str(num_local_experts) + ' --hidden_size ' + str(hidden_size) + ' --batch_size ' + str(batch_size) + ' --a2a_ffn_overlap_degree ' + str(a2a_ffn_overlap_degree) + ' --num_steps ' + str(num_steps)
         if use_model_parallel:
-            command += ' --use_model_parallel'
+            command += ' --parallel_type model'
+        else:
+            command += ' --parallel_type data'

     p = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=new_env)
     losses = []
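
With this change every test invocation pins the data flow explicitly via the new flag. For reference, a launch command produced by the template above now looks like the following (the numeric values here are illustrative, not the test's actual parameters):

    python3 -m torch.distributed.launch --nproc_per_node=2 tutel/examples/helloworld.py --top 2 --dtype float32 --num_local_experts 2 --hidden_size 2048 --batch_size 16 --a2a_ffn_overlap_degree 1 --num_steps 100 --parallel_type model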

tutel/examples/helloworld.py

+6 −6

@@ -32,7 +32,7 @@
 parser.add_argument('--l_aux_wt', type=float, default=0.0)
 parser.add_argument('--a2a_ffn_overlap_degree', type=int, default=1)
 parser.add_argument('--num_steps', type=int, default=100)
-parser.add_argument('--use_model_parallel', default=False, action='store_true')
+parser.add_argument('--parallel_type', type=str, default='auto')
 parser.add_argument('--save_load_checkpoint', default=False, action='store_true')
 args = parser.parse_args()

@@ -72,8 +72,8 @@ def __init__(self):
             scan_expert_func = lambda name, param: setattr(param, 'skip_allreduce', True),
             seeds = (1, dist_rank + 1, 1),
             a2a_ffn_overlap_degree = a2a_ffn_overlap_degree,
-            use_model_parallel = args.use_model_parallel,
-        ).to(device)
+            parallel_type = args.parallel_type,
+        )

         # Summary of different parameter types: gate, local_experts
         local_count = sum([torch.numel(param) for name, param in self._moe_layer.get_parameter_iterator(param_type='local_experts')])
@@ -85,7 +85,7 @@ def forward(self, input):
         result = F.log_softmax(torch.sum(result, dim=2), dim=1)
         return result

-model = ExampleModel()
+model = ExampleModel().to(device)
 dist_print(model)

 if args.save_load_checkpoint:
@@ -101,8 +101,8 @@ def forward(self, input):
 x = torch.tensor(torch.randn([batch_size, num_tokens, model_dim], dtype=torch.float32, device='cpu').detach().numpy(), dtype=torch.get_default_dtype(), requires_grad=True, device=device)
 y = torch.LongTensor(batch_size).random_(1).to(device)

-tuples = (dist_world_size, args.dtype, model_dim, hidden_size, batch_size * num_tokens, num_local_experts, top_value, a2a_ffn_overlap_degree, args.use_model_parallel)
-dist_print('[Benchmark] world_size = %s, dtype = %s, model_dim = %s, hidden_size = %s, samples = %s, num_local_experts = %s, topK = %s, a2a_ffn_overlap_degree = %s, use_model_parallel = %s' % tuples)
+tuples = (dist_world_size, args.dtype, model_dim, hidden_size, batch_size * num_tokens, num_local_experts, top_value, a2a_ffn_overlap_degree, args.parallel_type)
+dist_print('[Benchmark] world_size = %s, dtype = %s, model_dim = %s, hidden_size = %s, samples = %s, num_local_experts = %s, topK = %s, a2a_ffn_overlap_degree = %s, parallel_type = `%s`' % tuples)

 average_time, num_steps = 0, args.num_steps
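
The other change that repeats across every example file below: the MoE layer is no longer moved to the device inside `__init__`; instead the fully constructed model is moved once. A self-contained PyTorch sketch of the pattern (a plain `nn.Linear` stands in for `tutel_moe.moe_layer` here):

    import torch
    import torch.nn as nn

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    class ExampleModel(nn.Module):
        def __init__(self):
            super().__init__()
            # Submodule is built on CPU; no .to(device) at construction time anymore.
            self._moe_layer = nn.Linear(1024, 1024)

        def forward(self, x):
            return self._moe_layer(x)

    # Device placement happens once, on the whole model, as in the hunks above.
    model = ExampleModel().to(device)

This keeps construction device-agnostic and moves every registered parameter, gate and experts alike, to the same device in a single step.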

tutel/examples/helloworld_amp.py

+2 −2

@@ -71,7 +71,7 @@ def __init__(self):
             scan_expert_func = lambda name, param: setattr(param, 'skip_allreduce', True),
             seeds = (1, dist_rank + 1, 1),
             a2a_ffn_overlap_degree = a2a_ffn_overlap_degree,
-        ).to(device)
+        )

         # Summary of different parameter types: gate, local_experts
         local_count = sum([torch.numel(param) for name, param in self._moe_layer.get_parameter_iterator(param_type='local_experts')])
@@ -83,7 +83,7 @@ def forward(self, input):
         result = F.log_softmax(torch.sum(result, dim=2), dim=1)
         return result

-model = ExampleModel()
+model = ExampleModel().to(device)
 dist_print(model)

 optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)

tutel/examples/helloworld_ddp.py

+2 −2

@@ -72,7 +72,7 @@ def __init__(self):
             scan_expert_func = lambda name, param: setattr(param, 'skip_allreduce', True),
             seeds = (1, dist_rank + 1, 1),
             a2a_ffn_overlap_degree = a2a_ffn_overlap_degree,
-        ).to(device)
+        )

         # Summary of different parameter types: gate, local_experts
         local_count = sum([torch.numel(param) for name, param in self._moe_layer.get_parameter_iterator(param_type='local_experts')])
@@ -88,7 +88,7 @@ def add_param_to_skip_allreduce(self, param_name):
         self._ddp_params_and_buffers_to_ignore.append(param_name)


-model = ExampleModel()
+model = ExampleModel().to(device)

 for name, param in model.named_parameters():
     if hasattr(param, 'skip_allreduce'):

tutel/examples/helloworld_deepspeed.py

+2 −2

@@ -97,7 +97,7 @@ def __init__(self):
             num_experts = num_local_experts * dist_world_size,
             k = top_value,
             use_tutel = args.use_tutel
-        ).to(device)
+        )

         for name, param in self._moe_layer.named_parameters():
             if '.experts.' in name:
@@ -113,7 +113,7 @@ def forward(self, input):
         result = F.log_softmax(torch.sum(result, dim=2), dim=1)
         return result

-model = ExampleModel()
+model = ExampleModel().to(device)
 dist_print(model)

 optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)

tutel/examples/helloworld_megatron.py

+2 −2

@@ -64,7 +64,7 @@ def __init__(self):
             model_dim = model_dim,
             scan_expert_func = lambda name, param: setattr(param, 'skip_allreduce', True),
             seeds = (1, dist_rank + 1, 1),
-        ).to(device)
+        )

         # Summary of different parameter types: gate, local_experts
         local_count = sum([torch.numel(param) for name, param in self._moe_layer.get_parameter_iterator(param_type='local_experts')])
@@ -76,7 +76,7 @@ def forward(self, input):
         result = F.log_softmax(torch.sum(result, dim=2), dim=1)
         return result

-model = ExampleModel()
+model = ExampleModel().to(device)
 dist_print(model)

 optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)

tutel/examples/helloworld_sharded_experts.py

+2 −2

@@ -66,7 +66,7 @@ def __init__(self):
             model_dim = model_dim,
             scan_expert_func = lambda name, param: setattr(param, 'skip_allreduce', True),
             seeds = (1, dist_rank + 1, 1),
-        ).to(device)
+        )

         # Summary of different parameter types: gate, local_experts
         local_count = sum([torch.numel(param) for name, param in self._moe_layer.get_parameter_iterator(param_type='local_experts')])
@@ -78,7 +78,7 @@ def forward(self, input):
         result = F.log_softmax(torch.sum(result, dim=2), dim=1)
         return result

-model = ExampleModel()
+model = ExampleModel().to(device)
 dist_print(model)

 optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)
