
Commit 6b434d9

fix parallel methods (#40)
1 parent ef8bcc8 commit 6b434d9

2 files changed: +34 -15 lines changed

tutel/impls/communicate.py (+32, -14)

@@ -68,30 +68,48 @@ def backward(ctx: Any, grad_output: Tensor):
 
 class PreAllreduceSum(torch.autograd.Function):
     @staticmethod
-    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor):
+    def forward(ctx, group, input):
         ctx.group = group
-        return input
-
+        ctx.num_nodes = get_world_size(ctx.group)
+        if ctx.num_nodes <= 1:
+            return input
+        ctx.input_shape = input.shape
+        output = torch.empty([ctx.num_nodes, input.numel()], device=input.device, dtype=input.dtype)
+        tensor_list = [x.contiguous() for x in torch.chunk(output, chunks=ctx.num_nodes, dim=0)]
+        dist.all_gather(tensor_list=tensor_list, tensor=input.contiguous())
+        output = output.view(list(input.shape[:0]) + [input.shape[0] * ctx.num_nodes] + list(input.shape[1:]))
+        return output
     @staticmethod
-    def backward(ctx: Any, grad_output: Tensor):
+    def backward(ctx, doutput):
         if get_world_size(ctx.group) <= 1:
-            return (None, grad_output)
-        dinput = torch.clone(grad_output).contiguous()
-        dist.all_reduce(dinput, op=torch.distributed.ReduceOp.SUM)
+            return (None, doutput)
+        dinput = torch.empty(ctx.input_shape, device=doutput.device, dtype=doutput.dtype)
+        chunks = [x.contiguous() for x in torch.chunk(doutput.view(ctx.num_nodes, -1), chunks=ctx.num_nodes, dim=0)]
+        dist.reduce_scatter(output=dinput, input_list=chunks)
         return (None, dinput)
 
 class PostAllreduceSum(torch.autograd.Function):
     @staticmethod
-    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor):
-        if get_world_size(group) <= 1:
+    def forward(ctx, group, input):
+        ctx.group = group
+        ctx.num_nodes = get_world_size(ctx.group)
+        if ctx.num_nodes <= 1:
             return input
-        output = torch.clone(input).contiguous()
-        dist.all_reduce(output, op=torch.distributed.ReduceOp.SUM)
+        ctx.input_shape = input.shape
+        ctx.leading_dim = 0
+        chunks = [x.contiguous() for x in torch.chunk(input, chunks=ctx.num_nodes, dim=ctx.leading_dim)]
+        assert len(chunks) == ctx.num_nodes
+        output = torch.empty_like(chunks[0])
+        dist.reduce_scatter(output=output, input_list=list(chunks))
         return output
-
     @staticmethod
-    def backward(ctx: Any, grad_output: Tensor):
-        return (None, grad_output)
+    def backward(ctx, doutput):
+        if ctx.num_nodes <= 1:
+            return (None, doutput)
+        dinput = torch.empty(ctx.input_shape, device=doutput.device, dtype=doutput.dtype)
+        tensor_list = [x.contiguous() for x in torch.chunk(dinput, chunks=ctx.num_nodes, dim=ctx.leading_dim)]
+        dist.all_gather(tensor_list=tensor_list, tensor=doutput)
+        return (None, dinput)
 
 
 # A2A_TYPE: 0 for skip AllToAll, 1 for standard Pytorch AllToAll, 9 for standard Pytorch AllToAll with Timing
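
In the new implementation the two methods form an adjoint pair: PreAllreduceSum all-gathers its input along the leading dimension in forward and reduce-scatters the incoming gradient in backward, while PostAllreduceSum applies the same two collectives in the opposite order. The snippet below is a minimal single-process sketch of that pairing, with plain tensor operations standing in for the collectives; the shapes and the world_size value are illustrative assumptions, and the code is not part of this commit.

# Single-process simulation of the new forward/backward pairing (illustrative
# only, not tutel code): list entries stand in for ranks, torch.cat stands in
# for all_gather, and summing matching chunks stands in for reduce_scatter.
import torch

world_size = 4                                   # assumed group size
local = [torch.randn(2, 3) for _ in range(world_size)]

# PreAllreduceSum.forward ~ all_gather: every rank ends up with the same
# concatenation of all local shards along dim 0.
gathered = torch.cat(local, dim=0)               # shape [world_size * 2, 3]

# PreAllreduceSum.backward ~ reduce_scatter: each rank produces a gradient of
# the gathered tensor; rank r receives the sum over ranks of chunk r, which is
# the gradient of the rank-r shard when the loss sums over ranks.
grad_gathered = [torch.ones_like(gathered) for _ in range(world_size)]
grad_local = [
    sum(torch.chunk(g, world_size, dim=0)[r] for g in grad_gathered)
    for r in range(world_size)
]
assert grad_local[0].shape == local[0].shape     # back to the local shape
assert torch.all(grad_local[0] == world_size)    # every rank contributed ones

# PostAllreduceSum uses the same two collectives in the opposite order:
# reduce_scatter in forward, all_gather in backward.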

tutel/impls/moe_layer.py (+2, -1)

@@ -157,7 +157,8 @@ def named_parameters(self):
     def apply_on_expert_fn(self, input, expert_fn, group):
         if self.l_zero is None:
             self.l_zero = torch.tensor(0, dtype=input.dtype, device=input.device)
-        result_output = expert_fn(PreAllreduceSum.apply(group, input))
+        gathered_input = PreAllreduceSum.apply(group, input)
+        result_output = expert_fn(gathered_input)
         result_output = PostAllreduceSum.apply(group, result_output)
         return result_output, self.l_zero
 
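The moe_layer.py change only splits the nested call so the gathered activation gets its own name; the data flow is unchanged: gather the inputs from the group, run the expert on the concatenated batch, then reduce-scatter the result. Below is a minimal sketch of that flow, again simulated on one process with a stand-in expert; expert_fn, the shapes, and world_size here are assumptions for illustration, not tutel's API.

# Sketch of apply_on_expert_fn's data flow after this commit, simulated on one
# process. expert_fn, the shapes and world_size are illustrative assumptions.
import torch

world_size = 2
expert_fn = torch.nn.Linear(4, 4)                # stand-in expert

local_input = [torch.randn(3, 4) for _ in range(world_size)]

# PreAllreduceSum.apply ~ all_gather: every rank sees the concatenated batch.
gathered_input = torch.cat(local_input, dim=0)   # [world_size * 3, 4]

result_output = expert_fn(gathered_input)        # expert runs on the full batch

# PostAllreduceSum.apply ~ reduce_scatter: rank r keeps the sum over ranks of
# chunk r; with identical inputs on every rank this is world_size * chunks[r].
chunks = torch.chunk(result_output, world_size, dim=0)
rank0_output = world_size * chunks[0]            # what rank 0 would keep

assert rank0_output.shape == local_input[0].shape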