#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch import nn
import argparse

from tutel import system
from tutel import moe as tutel_moe
from tutel import net

parser = argparse.ArgumentParser()

parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--num_tokens', type=int, default=512)
parser.add_argument('--model_dim', type=int, default=2048)
parser.add_argument('--hidden_size', type=int, default=2048)
parser.add_argument('--num_local_experts', type=int, default=2)
parser.add_argument('--dtype', type=str, default='float32')
parser.add_argument('--fp32_gate', default=False, action='store_true')
parser.add_argument('--top', type=int, default=2)
parser.add_argument('--l_aux_wt', type=float, default=0.0)
parser.add_argument('--a2a_ffn_overlap_degree', type=int, default=1)
parser.add_argument('--allreduce_degree', type=int, default=1)
parser.add_argument('--num_steps', type=int, default=100)
parser.add_argument('--cap_factor', type=float, default=1.0)
parser.add_argument('--checkpoint_path', type=str, default='')
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--use_2dh', default=False, action='store_true')
parser.add_argument('--eval', default=False, action='store_true')
args = parser.parse_args()

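# A typical multi-GPU launch uses one process per device, e.g. (assuming a standard
# PyTorch distributed launcher; adjust the process count to the local machine):
#
#   torchrun --nproc_per_node=8 this_script.py --num_local_experts=2 --dtype=float16
#
# `init_data_model_parallel` below is expected to pick up the rank / world-size
# information that such a launcher provides.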
parallel_env = system.init_data_model_parallel(backend='nccl' if args.device == 'cuda' else 'gloo')
dist_rank, dist_world_size, dist_print = parallel_env.global_rank, parallel_env.global_size, parallel_env.dist_print
args.local_rank = parallel_env.local_device.index

batch_size = args.batch_size
num_tokens = args.num_tokens
model_dim = args.model_dim
hidden_size = args.hidden_size
num_local_experts = args.num_local_experts
top_value = args.top
a2a_ffn_overlap_degree = args.a2a_ffn_overlap_degree
device = parallel_env.local_device

if args.dtype == 'float32':
    torch.set_default_dtype(torch.float32)
elif args.dtype == 'float64':
    torch.set_default_dtype(torch.float64)
elif args.dtype == 'float16':
    torch.set_default_dtype(torch.float16)
elif args.dtype == 'bfloat16':
    torch.set_default_dtype(torch.bfloat16)
else:
    raise Exception('Unrecognized data type specified: %s (expected one of: float32, float64, float16, bfloat16)' % args.dtype)


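# A minimal example model for benchmarking: a single Tutel MoE layer whose output is
# summed over the model dimension and passed through log-softmax, so it can be paired
# with nll_loss on the synthetic labels created below.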
class ExampleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

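        # Notes on the construction below, based on how its pieces are used later in this
        # script: `scan_expert_func` tags every expert parameter with `skip_allreduce`, so
        # the training loop can exclude expert weights from the data-parallel gradient
        # all-reduce; the per-rank entry in `seeds` (dist_rank + 1) presumably makes each
        # rank initialize its local experts differently while the gate stays identical
        # across ranks.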
        self._moe_layer = tutel_moe.moe_layer(
            gate_type={'type': 'top', 'k': top_value, 'fp32_gate': args.fp32_gate},
            experts={'type': 'ffn', 'count_per_node': num_local_experts, 'hidden_size_per_expert': hidden_size, 'activation_fn': lambda x: F.relu(x)},
            model_dim=model_dim,
            scan_expert_func=lambda name, param: setattr(param, 'skip_allreduce', True),
            seeds=(1, dist_rank + 1, 1),
            a2a_ffn_overlap_degree=a2a_ffn_overlap_degree,
            use_2dh=args.use_2dh,
        )

        # Summary of different parameter types: gate, local_experts
        local_count = sum(param.numel() for name, param in self._moe_layer.get_parameter_iterator(param_type='local_experts'))
        shared_count = sum(param.numel() for name, param in self._moe_layer.get_parameter_iterator(param_type='gate'))
        dist_print('[Statistics] param count for MoE local_experts = %s, param count for MoE gate = %s.\n' % (local_count, shared_count))
        self.r_index = -1

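    # Each forward pass exercises a different (adaptive_r, a2a_ffn_overlap_degree) pair:
    # `o` cycles through overlap degrees 1..8 on every step, and `r` advances to the next
    # entry of `valid_rs` once per 8 steps, so a long enough run sweeps the layer's
    # candidate adaptive-parallelism settings.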
    def forward(self, input):
        r, o = self._moe_layer.valid_rs[(self.r_index // 8) % len(self._moe_layer.valid_rs)], self.r_index % 8 + 1
        self.r_index += 1

        result = self._moe_layer(input, capacity_factor=args.cap_factor, adaptive_r=r, a2a_ffn_overlap_degree=o)
        result = F.log_softmax(torch.sum(result, dim=2), dim=1)
        return result


model = ExampleModel().to(device)
dist_print(model)

if args.checkpoint_path:
    checkpoint_path = system.apply_rank_size_from_pattern(args.checkpoint_path, rank=parallel_env.global_rank, size=parallel_env.global_size)
    if os.path.exists(checkpoint_path):
        model.load_state_dict(torch.load(checkpoint_path))
    else:
        print('Checkpoint not loaded: file `%s` was not found. Training the model from scratch.' % checkpoint_path)

optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)

torch.manual_seed(0)
x = torch.tensor(torch.randn([batch_size, num_tokens, model_dim], dtype=torch.float32, device='cpu').detach().numpy(), dtype=torch.get_default_dtype(), requires_grad=False, device=device)
y = torch.LongTensor(batch_size).random_(1).to(device)
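
# Synthetic benchmark data: `x` holds random token embeddings drawn on the CPU in
# float32 under a fixed seed and then cast to the benchmark dtype, and `y` is a dummy
# label tensor of all zeros (`random_(1)` samples from {0}), so the loss value only
# serves to drive a realistic backward pass.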

tuples = (dist_world_size, args.dtype, model_dim, hidden_size, batch_size * num_tokens, num_local_experts, top_value, a2a_ffn_overlap_degree, device)
dist_print('[Benchmark] world_size = %s, dtype = %s, model_dim = %s, hidden_size = %s, samples = %s, num_local_experts = %s, topK = %s, a2a_ffn_overlap_degree = %s, device = `%s`' % tuples)

average_time, num_steps = 0, args.num_steps

if args.allreduce_degree == -1:
    params_for_all_reduce = []
else:
    params_for_all_reduce = [p for p in model.parameters() if not hasattr(p, 'skip_allreduce') and getattr(p, 'requires_grad', False)]
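
# Data-parallel gradient synchronization is done by hand rather than through DDP:
# expert parameters were tagged with `skip_allreduce` at construction time and stay
# local to each rank, while the remaining (gate) gradients are averaged across ranks
# after every backward pass via `net.simple_all_reduce`. Passing --allreduce_degree=-1
# disables this synchronization entirely.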

for i in range(num_steps):
    t_start = system.record_time()

    if not args.eval:
        optimizer.zero_grad()
        output = model(x)
        loss = F.nll_loss(output, y)
        if args.l_aux_wt:
            # Optionally add the MoE auxiliary load-balancing loss reported by the layer.
            loss += args.l_aux_wt * model._moe_layer.l_aux
        loss.backward()
        if dist_world_size > 1:
            # Average non-expert gradients across ranks (expert gradients stay local).
            for p in params_for_all_reduce:
                p.grad /= dist_world_size
                p.grad = net.simple_all_reduce(p.grad)
        optimizer.step()
    else:
        with torch.no_grad():
            output = model(x)
            loss = F.nll_loss(output, y)

    t_stop = system.record_time()

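    # Rough throughput estimate: the expert FFN is two matmuls over (tokens x model_dim
    # x hidden_size), i.e. about 4*T*M*H FLOPs per forward pass; the factor of 3 during
    # training accounts for the backward pass, and min(top, num_global_experts) for each
    # token being dispatched to up to top-k experts. Treat the printed TFLOPS as an
    # approximation, not a measured FLOP count.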
    num_global_experts = tutel_moe.moe_layer.global_expert_count(num_local_experts, group=system.get_local_session().model_group)
    mm_coef, cap_coef = 1 if args.eval else 3, min(args.top, num_global_experts)
    tflops = (batch_size * num_tokens * model_dim * hidden_size) * 4 * mm_coef * cap_coef * 1e-12 / (t_stop - t_start)
    dist_print('STEP-%s: loss = %.5f, step_time = %.6f sec, perf = %.2f tflops. (f = %.1f, r = %d, o = %d)' % (i, float(loss.data), t_stop - t_start, tflops, args.cap_factor, model._moe_layer.adaptive_degree, model._moe_layer.a2a_ffn_overlap_degree))

    # Only the last 10 steps contribute to the reported average, so warm-up steps are excluded.
    if i + 10 >= num_steps:
        average_time += t_stop - t_start

average_time /= 10
dist_print('\n[Summary] Average synchronized step_time = %s sec.' % average_time)

if args.checkpoint_path:
    # Each rank saves its own state_dict, using the same rank/size path pattern used when loading.
    torch.save(model.state_dict(), checkpoint_path)