32
32
parser .add_argument ('--l_aux_wt' , type = float , default = 0.0 )
33
33
parser .add_argument ('--a2a_ffn_overlap_degree' , type = int , default = 1 )
34
34
parser .add_argument ('--num_steps' , type = int , default = 100 )
35
- parser.add_argument('--use_model_parallel', default=False, action='store_true')
35
+ parser.add_argument('--parallel_type', type=str, default='auto')
36
36
parser .add_argument ('--save_load_checkpoint' , default = False , action = 'store_true' )
37
37
args = parser .parse_args ()
38
38
@@ -72,8 +72,8 @@ def __init__(self):
72
72
scan_expert_func = lambda name , param : setattr (param , 'skip_allreduce' , True ),
73
73
seeds = (1 , dist_rank + 1 , 1 ),
74
74
a2a_ffn_overlap_degree = a2a_ffn_overlap_degree ,
75
- use_model_parallel = args .use_model_parallel ,
76
- ). to ( device )
75
+ parallel_type = args .parallel_type ,
76
+ )
77
77
78
78
# Summary of different parameter types: gate, local_experts
79
79
local_count = sum ([torch .numel (param ) for name , param in self ._moe_layer .get_parameter_iterator (param_type = 'local_experts' )])
@@ -85,7 +85,7 @@ def forward(self, input):
85
85
result = F .log_softmax (torch .sum (result , dim = 2 ), dim = 1 )
86
86
return result
87
87
88
- model = ExampleModel ()
88
+ model = ExampleModel (). to ( device )
89
89
dist_print (model )
90
90
91
91
if args .save_load_checkpoint :
@@ -101,8 +101,8 @@ def forward(self, input):
101
101
x = torch .tensor (torch .randn ([batch_size , num_tokens , model_dim ], dtype = torch .float32 , device = 'cpu' ).detach ().numpy (), dtype = torch .get_default_dtype (), requires_grad = True , device = device )
102
102
y = torch .LongTensor (batch_size ).random_ (1 ).to (device )
103
103
104
- tuples = (dist_world_size , args .dtype , model_dim , hidden_size , batch_size * num_tokens , num_local_experts , top_value , a2a_ffn_overlap_degree , args .use_model_parallel )
105
- dist_print('[Benchmark] world_size = %s, dtype = %s, model_dim = %s, hidden_size = %s, samples = %s, num_local_experts = %s, topK = %s, a2a_ffn_overlap_degree = %s, use_model_parallel = %s' % tuples)
104
+ tuples = (dist_world_size , args .dtype , model_dim , hidden_size , batch_size * num_tokens , num_local_experts , top_value , a2a_ffn_overlap_degree , args .parallel_type )
105
+ dist_print('[Benchmark] world_size = %s, dtype = %s, model_dim = %s, hidden_size = %s, samples = %s, num_local_experts = %s, topK = %s, a2a_ffn_overlap_degree = %s, parallel_type = `%s`' % tuples)
106
106
107
107
average_time , num_steps = 0 , args .num_steps
108
108
0 commit comments