@@ -68,30 +68,48 @@ def backward(ctx: Any, grad_output: Tensor):
 
 class PreAllreduceSum(torch.autograd.Function):
     @staticmethod
-    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor):
+    def forward(ctx, group, input):
         ctx.group = group
-        return input
-
+        ctx.num_nodes = get_world_size(ctx.group)
+        if ctx.num_nodes <= 1:
+            return input
+        ctx.input_shape = input.shape
+        output = torch.empty([ctx.num_nodes, input.numel()], device=input.device, dtype=input.dtype)
+        tensor_list = [x.contiguous() for x in torch.chunk(output, chunks=ctx.num_nodes, dim=0)]
+        dist.all_gather(tensor_list=tensor_list, tensor=input.contiguous())
+        output = output.view(list(input.shape[:0]) + [input.shape[0] * ctx.num_nodes] + list(input.shape[1:]))
+        return output
     @staticmethod
-    def backward(ctx: Any, grad_output: Tensor):
+    def backward(ctx, doutput):
         if get_world_size(ctx.group) <= 1:
-            return (None, grad_output)
-        dinput = torch.clone(grad_output).contiguous()
-        dist.all_reduce(dinput, op=torch.distributed.ReduceOp.SUM)
+            return (None, doutput)
+        dinput = torch.empty(ctx.input_shape, device=doutput.device, dtype=doutput.dtype)
+        chunks = [x.contiguous() for x in torch.chunk(doutput.view(ctx.num_nodes, -1), chunks=ctx.num_nodes, dim=0)]
+        dist.reduce_scatter(output=dinput, input_list=chunks)
         return (None, dinput)
 
 class PostAllreduceSum(torch.autograd.Function):
     @staticmethod
-    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor):
-        if get_world_size(group) <= 1:
+    def forward(ctx, group, input):
+        ctx.group = group
+        ctx.num_nodes = get_world_size(ctx.group)
+        if ctx.num_nodes <= 1:
             return input
-        output = torch.clone(input).contiguous()
-        dist.all_reduce(output, op=torch.distributed.ReduceOp.SUM)
+        ctx.input_shape = input.shape
+        ctx.leading_dim = 0
+        chunks = [x.contiguous() for x in torch.chunk(input, chunks=ctx.num_nodes, dim=ctx.leading_dim)]
+        assert len(chunks) == ctx.num_nodes
+        output = torch.empty_like(chunks[0])
+        dist.reduce_scatter(output=output, input_list=list(chunks))
         return output
-
     @staticmethod
-    def backward(ctx: Any, grad_output: Tensor):
-        return (None, grad_output)
+    def backward(ctx, doutput):
+        if ctx.num_nodes <= 1:
+            return (None, doutput)
+        dinput = torch.empty(ctx.input_shape, device=doutput.device, dtype=doutput.dtype)
+        tensor_list = [x.contiguous() for x in torch.chunk(dinput, chunks=ctx.num_nodes, dim=ctx.leading_dim)]
+        dist.all_gather(tensor_list=tensor_list, tensor=doutput)
+        return (None, dinput)
 
 
 # A2A_TYPE: 0 for skip AllToAll, 1 for standard Pytorch AllToAll, 9 for standard Pytorch AllToAll with Timing
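
Review note: after this change, PreAllreduceSum all-gathers shards along the leading dimension in forward and reduce-scatters the incoming gradient in backward, while PostAllreduceSum does the reverse, so the two collectives stay adjoint to each other under autograd. The sketch below shows how the pair might be exercised; it is not part of this commit and assumes an NCCL process group launched via torchrun, one GPU per rank, a leading dimension divisible by the world size, and that the two classes are importable from the module touched here (the import path and script structure are illustrative only).

import os
import torch
import torch.distributed as dist

# Illustrative import path; in the real tree these classes live in the module
# modified by this diff.
from communicate import PreAllreduceSum, PostAllreduceSum

def main():
    # Assumes launch via `torchrun --nproc_per_node=<gpus> this_script.py`.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
    group = dist.group.WORLD
    rank, world = dist.get_rank(), dist.get_world_size()

    # Each rank holds a [2, 4] shard; forward all-gathers shards along dim 0.
    local = torch.full((2, 4), float(rank), device="cuda", requires_grad=True)
    gathered = PreAllreduceSum.apply(group, local)    # shape [2 * world, 4]

    # PostAllreduceSum reduce-scatters along dim 0, returning a [2, 4] shard.
    reduced = PostAllreduceSum.apply(group, gathered)

    # Backward runs the mirrored collectives: all_gather for PostAllreduceSum,
    # reduce_scatter for PreAllreduceSum.
    reduced.sum().backward()
    print(rank, reduced.shape, local.grad.shape)

if __name__ == "__main__":
    main()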