 #include "fbgemm_gpu/utils/ops_utils.h"
 #include "fbgemm_gpu/utils/dispatch_macros.h"
 #include "fbgemm_gpu/embedding_common.h"
+{%- if has_vbe_support %}
+#include "fbgemm_gpu/utils/pt2_autograd_utils.h"
+{%- endif %}

 using Tensor = at::Tensor;
 using namespace fbgemm_gpu;
@@ -53,23 +56,39 @@ Tensor split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cpu_wrapper(
     const Tensor& vbe_row_output_offsets,
     const Tensor& vbe_b_t_map,
     const int64_t info_B_num_bits,
-    const int64_t info_B_mask_int64
+    const int64_t info_B_mask_int64,
+    const Tensor& vbe_B_offsets_rank_per_feature,
+    const int64_t max_B
     {%- else %}
     const Tensor& feature_requires_grad
     {%- endif %}
 ) {
+  {%- if vbe %}
+  const auto offsets_ = reshape_vbe_offsets(
+      offsets,
+      vbe_B_offsets_rank_per_feature,
+      max_B,
+      D_offsets.numel() - 1
+  );
+  const auto grad_output_ = reshape_vbe_output(
+      grad_output,
+      max_B,
+      vbe_B_offsets_rank_per_feature,
+      D_offsets
+  );
+  {%- endif %}
   static auto op =
       torch::Dispatcher::singleton()
           .findSchemaOrThrow(
               "fbgemm::split_embedding_codegen_grad_indice_weights_cpu", "")
           .typed<Tensor(Tensor,Tensor,Tensor,Tensor,Tensor,Tensor,Tensor)>();
   return op.call(
-      grad_output,
+      {{ "grad_output_" if vbe else "grad_output" }},
       host_weights,
       weights_offsets,
       D_offsets,
       indices,
-      offsets,
+      {{ "offsets_" if vbe else "offsets" }},
       feature_requires_grad);
 }
 {%- endif %}
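For context: `reshape_vbe_offsets` and `reshape_vbe_output` come from the newly included `fbgemm_gpu/utils/pt2_autograd_utils.h` and, per the TODO later in this diff, exist because the CPU kernels do not yet support variable batch sizes (VBE). They pad the variable-batch inputs out to a fixed batch size `max_B` per feature so the non-VBE CPU op can consume them; `reshape_vbe_output` presumably does the analogous mapping of the 1-D VBE gradient into a dense `[max_B, total_D]` layout. The sketch below is a hypothetical illustration of the offsets padding, not FBGEMM's actual implementation; it assumes int64 `offsets` of length (sum of per-feature batch sizes) + 1, an int32 `[T][R + 1]` tensor of per-rank batch-size cumsums, and that padded bags are made empty by repeating the feature's end offset.

```cpp
#include <ATen/ATen.h>
#include <algorithm>

// Hypothetical sketch of VBE -> fixed-batch offset padding; the real helper
// is declared in fbgemm_gpu/utils/pt2_autograd_utils.h.
at::Tensor reshape_vbe_offsets_sketch(
    const at::Tensor& offsets, // 1-D int64, length sum_t(B_t) + 1
    const at::Tensor& B_offsets_rank_per_feature, // [T][R + 1] int32 cumsums
    const int64_t max_B,
    const int64_t T) {
  auto out = at::empty({T * max_B + 1}, offsets.options());
  const auto in = offsets.accessor<int64_t, 1>();
  auto out_a = out.accessor<int64_t, 1>();
  const auto B_a = B_offsets_rank_per_feature.accessor<int32_t, 2>();
  const auto R = B_offsets_rank_per_feature.size(1) - 1;
  int64_t in_pos = 0;
  for (int64_t t = 0; t < T; ++t) {
    const int64_t B_t = B_a[t][R]; // actual batch size of feature t
    for (int64_t b = 0; b < max_B; ++b) {
      // Rows past B_t repeat the feature's end offset, i.e. empty bags.
      out_a[t * max_B + b] = in[in_pos + std::min(b, B_t)];
    }
    in_pos += B_t;
  }
  out_a[T * max_B] = in[in_pos]; // trailing sentinel offset
  return out;
}
```

Under these assumptions the output stays monotone and every bag beyond a feature's real batch has zero length, which is exactly what a fixed-batch pooled kernel needs to treat padding as a no-op.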
@@ -96,14 +115,20 @@ Tensor split_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper(
     const Tensor& /*lxu_cache_locations*/,
     const Tensor& /*uvm_cache_stats*/,
     {%- if vbe %}
-    const Tensor& vbe_row_output_offsets, /*vbe_output_offsets_feature_rank*/
-    const Tensor& vbe_b_t_map, /*vbe_B_offsets_rank_per_feature*/
+    const Tensor& vbe_row_output_offsets,
+    const Tensor& vbe_b_t_map,
     const c10::SymInt vbe_output_size,
     const int64_t info_B_num_bits,
     const int64_t info_B_mask_int64,
+    const Tensor& vbe_B_offsets_rank_per_feature,
+    const Tensor& vbe_output_offsets_feature_rank,
+    const int64_t max_B,
     {%- endif %}
     const bool /*is_experimental = false*/,
     const int64_t output_dtype = static_cast<int64_t>(SparseType::FP32)) {
+  {%- if vbe %}
+  const auto offsets_ = reshape_vbe_offsets(offsets, vbe_B_offsets_rank_per_feature, max_B, D_offsets.numel() - 1);
+  {%- endif %}
   static auto op =
       torch::Dispatcher::singleton()
           .findSchemaOrThrow("fbgemm::split_embedding_codegen_forward_cpu", "")
@@ -112,16 +137,14 @@ Tensor split_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper(
       )>();
   {%- if vbe %}
   // TODO: remove this after vbe is implemented for CPU kernel
-  Tensor vbe_B_offsets_rank_per_feature = vbe_b_t_map;
-  Tensor vbe_output_offsets_feature_rank = vbe_row_output_offsets;
   const auto output = op.call(
       host_weights,
       weights_offsets,
       D_offsets,
       total_D,
       hash_size_cumsum,
       indices,
-      offsets,
+      offsets_,
       pooling_mode,
       indice_weights,
       output_dtype);
@@ -192,6 +215,8 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
     const Tensor& B_offsets,
     const Tensor& vbe_row_output_offsets,
     const Tensor& vbe_b_t_map,
+    const Tensor& vbe_B_offsets_rank_per_feature,
+    const int64_t max_B,
     {%- endif %}
     const bool /*use_uniq_cache_locations*/,
     const bool /*use_homogeneous_placements*/,
@@ -200,6 +225,10 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
     , const int64_t output_dtype = static_cast<int64_t>(SparseType::FP32)
     {%- endif %})
 {
+  {%- if vbe %}
+  const auto offsets_ = reshape_vbe_offsets(offsets, vbe_B_offsets_rank_per_feature, max_B, D_offsets.numel() - 1);
+  const auto grad_output_ = reshape_vbe_output(grad_output, max_B, vbe_B_offsets_rank_per_feature, D_offsets);
+  {%- endif %}
   {%- set backward_op = "split_embedding_backward_codegen_{}_cpu".format(
       optimizer
   )
@@ -230,7 +259,7 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
       )>();

   op.call(
-      grad_output,
+      {{ "grad_output_" if vbe else "grad_output" }},
       host_weights,
       weights_placements,
       weights_offsets,
@@ -239,7 +268,7 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
       hash_size_cumsum,
       total_hash_size_bits,
       indices,
-      offsets,
+      {{ "offsets_" if vbe else "offsets" }},
       pooling_mode,
       indice_weights,
       stochastic_rounding,
@@ -248,7 +277,7 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
       , output_dtype
       {%- endif %}
   );
-  return grad_output;
+  return Tensor();
 }
 {% endif %}
 {%- endfor %} {#-/* for weighted */#}
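All three wrappers above share the same unboxed-dispatch pattern: resolve the CPU op once by schema name, cache it in a function-local `static`, and call it through a typed handle. A self-contained toy of that pattern follows, using real PyTorch registration and dispatcher APIs; the `example::double_it` op and its signature are made up for illustration and are not FBGEMM ops.

```cpp
#include <ATen/ATen.h>
#include <torch/library.h>

// Placeholder op so the lookup below has something to find.
at::Tensor double_it_cpu(at::Tensor x) {
  return x * 2;
}

TORCH_LIBRARY_FRAGMENT(example, m) {
  m.def("double_it(Tensor x) -> Tensor");
  m.impl("double_it", double_it_cpu);
}

at::Tensor call_via_dispatcher(const at::Tensor& x) {
  // Same pattern as the wrappers: resolve once, reuse on every call.
  static auto op =
      torch::Dispatcher::singleton()
          .findSchemaOrThrow("example::double_it", "")
          .typed<at::Tensor(at::Tensor)>();
  return op.call(x);
}
```

Caching the handle in a `static` avoids a schema lookup per invocation, which matters for ops called once per training step.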
@@ -293,6 +322,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
     " SymInt vbe_output_size, "
     " int info_B_num_bits, "
     " int info_B_mask_int64, "
+    " Tensor vbe_B_offsets_rank_per_feature, "
+    " Tensor vbe_output_offsets_feature_rank, "
+    " int max_B, "
     {%- endif %}
     " bool is_experimental, "
     " int output_dtype "
@@ -345,6 +377,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
     " Tensor B_offsets, "
     " Tensor vbe_row_output_offsets, "
     " Tensor vbe_b_t_map, "
+    " Tensor vbe_B_offsets_rank_per_feature, "
+    " int max_B, "
     {%- endif %}
     " bool use_uniq_cache_locations, "
     " bool use_homogeneous_placements,"
@@ -381,7 +415,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
     " Tensor vbe_row_output_offsets, "
     " Tensor vbe_b_t_map, "
     " int info_B_num_bits, "
-    " int info_B_mask_int64"
+    " int info_B_mask_int64, "
+    " Tensor vbe_B_offsets_rank_per_feature, "
+    " int max_B "
     {%- else %}
     " Tensor feature_requires_grad"
     {%- endif %}
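One detail worth noting in the schema hunks: each argument lives in its own C++ string literal, and adjacent literals are concatenated by the compiler into a single schema string passed to `m.def()`. That is why the last hunk changes `" int info_B_mask_int64"` to `" int info_B_mask_int64, "`: it is no longer the final fragment, so it needs the trailing comma. A small illustration (the fragments are copied from the diff; the variable name `kVbeTail` is ours):

```cpp
// Adjacent string literals concatenate at compile time, so these three
// fragments form one contiguous piece of the registered schema string.
constexpr const char* kVbeTail =
    " int info_B_mask_int64, "
    " Tensor vbe_B_offsets_rank_per_feature, "
    " int max_B ";
```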