Skip to content

Commit 5ab0b9b

Browse files
spcyppt authored and facebook-github-bot committed
Implement generate_vbe_metadata cpu (pytorch#3715)
Summary: X-link: facebookresearch/FBGEMM#796 This diff implements `generate_vbe_metadata` for cpu, such that the function returns the same output for CPU, CUDA and MTIA. To support VBE on CPU with existing fixed-batch-size CPU kernel, we need to recompute offsets, which is previously done in python. This diff implements offsets recomputation in C++ such that all manipulations are done in C++. Note that reshaping offsets and grad_input to work with existing fixed-batch-size CPU kernels are done in Autograd instead of wrapper to avoid multiple computations. VBE CPU tests are in the next diff. Reviewed By: sryap Differential Revision: D69162870
1 parent d523692 commit 5ab0b9b

9 files changed

+270
-75
lines changed

fbgemm_gpu/cmake/TbeTraining.cmake

+2
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,8 @@ gpu_cpp_library(
158158
${gen_gpu_files_forward_split}
159159
NVCC_FLAGS
160160
${TORCH_CUDA_OPTIONS}
161+
DEPS
162+
fbgemm_gpu_tbe_common
161163
DESTINATION
162164
fbgemm_gpu)
163165

fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp

+26-19
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,6 @@
4242
#include "torch/csrc/autograd/record_function_ops.h"
4343
#include "torch/csrc/autograd/record_function_ops.h"
4444

45-
{%- if has_vbe_support %}
46-
#include "fbgemm_gpu/utils/pt2_autograd_utils.h"
47-
{%- endif %}
48-
4945
#include "pt2_arg_utils.h"
5046

5147
using Tensor = at::Tensor;
@@ -124,6 +120,9 @@ enum SSDTensor {
124120
const c10::SymInt /*vbe_output_size*/,
125121
const int64_t /*info_B_num_bits*/,
126122
const int64_t /*info_B_mask_int64*/,
123+
const Tensor& /*vbe_B_offsets_rank_per_feature*/, // for reshaping vbe cpu offsets and output
124+
const Tensor& /*vbe_output_offsets_feature_rank*/, // for reshaping vbe cpu output
125+
const int64_t /*max_B_int*/, // for reshaping vbe cpu offsets
127126
{%- endif %}
128127
{%- if is_gwd %}
129128
const Tensor& /*prev_iter_dev*/,
@@ -168,6 +167,9 @@ enum SSDTensor {
168167
vbe_output_size,
169168
info_B_num_bits,
170169
info_B_mask_int64,
170+
vbe_B_offsets_rank_per_feature_, // for reshaping vbe cpu offsets and output
171+
vbe_output_offsets_feature_rank_, // for reshaping vbe cpu output
172+
max_B_int, // for reshaping vbe cpu offsets
171173
{%- endif %} {# /* if vbe */ #}
172174
{%- if is_gwd %}
173175
prev_iter_dev_,
@@ -244,6 +246,8 @@ enum SSDTensor {
244246
const Tensor& /*B_offsets*/,
245247
const Tensor& /*vbe_row_output_offsets*/,
246248
const Tensor& /*vbe_b_t_map*/,
249+
const Tensor& /*vbe_B_offsets_rank_per_feature*/, // for reshaping vbe cpu offsets and grad output
250+
const int64_t /*max_B*/, // for reshaping vbe cpu offsets
247251
{%- endif %}
248252
const bool /*use_uniq_cache_locations_bwd*/,
249253
const bool /*use_homogeneous_placements*/,
@@ -309,6 +313,8 @@ enum SSDTensor {
309313
B_offsets,
310314
vbe_row_output_offsets,
311315
vbe_b_t_map,
316+
vbe_B_offsets_rank_per_feature, // for reshaping vbe cpu offsets and grad output
317+
max_B, // for reshaping vbe cpu offsets
312318
{%- endif %} {# /* if vbe */ #}
313319
{%- if not dense %}
314320
use_uniq_cache_locations_bwd,
@@ -689,6 +695,7 @@ class {{ autograd_func }} :
689695
const auto info_B_mask = static_cast<uint32_t>(aux_int[IDX_INFO_B_MASK]);
690696

691697
{%- if vbe %}
698+
const int64_t max_B_int = max_B_.guard_int(__FILE__, __LINE__); // for reshaping vbe cpu offsets and grad_output
692699
static auto generate_vbe_metadata_op =
693700
torch::Dispatcher::singleton()
694701
.findSchemaOrThrow("fbgemm::generate_vbe_metadata", "")
@@ -766,6 +773,7 @@ class {{ autograd_func }} :
766773
B_offsets_,
767774
vbe_row_output_offsets,
768775
vbe_b_t_map,
776+
vbe_B_offsets_rank_per_feature_, // for reshaping vbe cpu grad_output
769777
{%- endif %}
770778
{%- if is_gwd and "prev_iter_dev" not in args_pt2.split_function_arg_names %}
771779
prev_iter_dev_,
@@ -808,6 +816,9 @@ class {{ autograd_func }} :
808816
{%- if not nobag %}
809817
ctx->saved_data["output_dtype"] = output_dtype;
810818
{%- endif %}
819+
{%- if vbe %}
820+
ctx->saved_data["max_B"] = max_B_int; // for reshaping vbe cpu offsets and grad_output
821+
{%- endif %}
811822

812823
{%- if not dense %}
813824
// unpack optim args
@@ -894,6 +905,7 @@ static torch::autograd::variable_list backward(
894905
auto B_offsets = *savedItr++;
895906
auto vbe_row_output_offsets = *savedItr++;
896907
auto vbe_b_t_map = *savedItr++;
908+
auto vbe_B_offsets_rank_per_feature = *savedItr++; // for reshaping vbe cpu grad_output
897909
{%- endif %}
898910
{%- if is_gwd and "prev_iter_dev" not in args_pt2.split_function_arg_names %}
899911
auto prev_iter_dev = *savedItr++;
@@ -939,6 +951,10 @@ static torch::autograd::variable_list backward(
939951
auto output_dtype = ctx->saved_data["output_dtype"].toInt();
940952
{%- endif %}
941953
{%- if not dense %}
954+
{%- if vbe %}
955+
auto max_B = ctx->saved_data["max_B"].toInt(); // for reshaping vbe cpu offsets and grad_output
956+
{%- endif %}
957+
942958
{%- for (var, _ , ivalue_cast, type) in args_pt2.unified_pt2.split_saved_data %}
943959
auto {{ var }} = ctx->saved_data["{{ var }}"].{{ ivalue_cast }}();
944960
{%- endfor %}
@@ -976,19 +992,6 @@ static torch::autograd::variable_list backward(
976992
// {{ fwd_mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cuda)
977993
weights_dev = weights_dev.flatten();
978994
{%- endif %}
979-
{%- if vbe %}
980-
// TODO: remove this once vbe_metadata for cpu is implemented
981-
// MTIA kernel uses weights_host but follows CUDA implementation,
982-
// so grad_output is already in a correct shape and must not be reshaped
983-
// Reshaping on weights_host here causes MTIA kernel to fail.
984-
// As a hotfix to unblock MTIA, we add condition check dimension so that reshpaing would skip on MTIA
985-
// CUDA and MTIA vbe_b_t_map is size of {total_B} - should be 1 dim
986-
// CPU vbe_b_t_map is B_offsets_rank_per_feature, so shape should be {num_features, batch_offsets}
987-
// This will be removed totally once vbe_metadata for cpu is implemented
988-
if (weights_host.numel() > 1 && vbe_b_t_map.dim() > 1){
989-
grad_output = reshape_vbe_output(grad_output, B_offsets, vbe_b_t_map, D_offsets);
990-
}
991-
{%- endif %}
992995

993996
{%- set grad_indice_weights_op =
994997
"{}_embedding_codegen_grad_indice_weights{}_pt2_wrapper".format(fwd_mdesc, vdesc)
@@ -1023,7 +1026,9 @@ static torch::autograd::variable_list backward(
10231026
const Tensor& /*vbe_row_output_offsets*/,
10241027
const Tensor& /*vbe_b_t_map*/,
10251028
const int64_t /*info_B_num_bits*/,
1026-
const int64_t /*info_B_mask_int64*/
1029+
const int64_t /*info_B_mask_int64*/,
1030+
const Tensor& /*vbe_B_offsets_rank_per_feature*/, // for reshaping vbe cpu grad_output
1031+
const int64_t /*max_B*/ // for reshaping vbe cpu offsets and grad_output
10271032
{%- else %}
10281033
const Tensor& /*feature_requires_grad*/
10291034
{%- endif %}
@@ -1053,7 +1058,9 @@ static torch::autograd::variable_list backward(
10531058
vbe_row_output_offsets,
10541059
vbe_b_t_map,
10551060
info_B_num_bits,
1056-
info_B_mask_int64
1061+
info_B_mask_int64,
1062+
vbe_B_offsets_rank_per_feature,
1063+
max_B
10571064
{%- else %}
10581065
feature_requires_grad
10591066
{%- endif %}

fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp

+48-12
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
#include "fbgemm_gpu/utils/ops_utils.h"
2727
#include "fbgemm_gpu/utils/dispatch_macros.h"
2828
#include "fbgemm_gpu/embedding_common.h"
29+
{%- if has_vbe_support %}
30+
#include "fbgemm_gpu/utils/pt2_autograd_utils.h"
31+
{%- endif %}
2932

3033
using Tensor = at::Tensor;
3134
using namespace fbgemm_gpu;
@@ -53,23 +56,39 @@ Tensor split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cpu_wrapper(
5356
const Tensor& vbe_row_output_offsets,
5457
const Tensor& vbe_b_t_map,
5558
const int64_t info_B_num_bits,
56-
const int64_t info_B_mask_int64
59+
const int64_t info_B_mask_int64,
60+
const Tensor& vbe_B_offsets_rank_per_feature,
61+
const int64_t max_B
5762
{%- else %}
5863
const Tensor& feature_requires_grad
5964
{%- endif %}
6065
) {
66+
{%- if vbe %}
67+
const auto offsets_ = reshape_vbe_offsets(
68+
offsets,
69+
vbe_B_offsets_rank_per_feature,
70+
max_B,
71+
D_offsets.numel() - 1
72+
);
73+
const auto grad_output_ = reshape_vbe_output(
74+
grad_output,
75+
max_B,
76+
vbe_B_offsets_rank_per_feature,
77+
D_offsets
78+
);
79+
{%- endif %}
6180
static auto op =
6281
torch::Dispatcher::singleton()
6382
.findSchemaOrThrow(
6483
"fbgemm::split_embedding_codegen_grad_indice_weights_cpu", "")
6584
.typed<Tensor(Tensor,Tensor,Tensor,Tensor,Tensor,Tensor,Tensor)>();
6685
return op.call(
67-
grad_output,
86+
{{ "grad_output_" if vbe else "grad_output" }},
6887
host_weights,
6988
weights_offsets,
7089
D_offsets,
7190
indices,
72-
offsets,
91+
{{ "offsets_" if vbe else "offsets" }},
7392
feature_requires_grad);
7493
}
7594
{%- endif %}
@@ -96,14 +115,20 @@ Tensor split_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper(
96115
const Tensor& /*lxu_cache_locations*/,
97116
const Tensor& /*uvm_cache_stats*/,
98117
{%- if vbe %}
99-
const Tensor& vbe_row_output_offsets, /*vbe_output_offsets_feature_rank*/
100-
const Tensor& vbe_b_t_map, /*vbe_B_offsets_rank_per_feature*/
118+
const Tensor& vbe_row_output_offsets,
119+
const Tensor& vbe_b_t_map,
101120
const c10::SymInt vbe_output_size,
102121
const int64_t info_B_num_bits,
103122
const int64_t info_B_mask_int64,
123+
const Tensor& vbe_B_offsets_rank_per_feature,
124+
const Tensor& vbe_output_offsets_feature_rank,
125+
const int64_t max_B,
104126
{%- endif %}
105127
const bool /*is_experimental = false*/,
106128
const int64_t output_dtype = static_cast<int64_t>(SparseType::FP32)) {
129+
{%- if vbe %}
130+
const auto offsets_ = reshape_vbe_offsets(offsets, vbe_B_offsets_rank_per_feature, max_B, D_offsets.numel() - 1);
131+
{%- endif %}
107132
static auto op =
108133
torch::Dispatcher::singleton()
109134
.findSchemaOrThrow("fbgemm::split_embedding_codegen_forward_cpu", "")
@@ -112,16 +137,14 @@ Tensor split_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper(
112137
)>();
113138
{%- if vbe %}
114139
// TODO: remove this after vbe is implemented for CPU kernel
115-
Tensor vbe_B_offsets_rank_per_feature = vbe_b_t_map;
116-
Tensor vbe_output_offsets_feature_rank = vbe_row_output_offsets;
117140
const auto output = op.call(
118141
host_weights,
119142
weights_offsets,
120143
D_offsets,
121144
total_D,
122145
hash_size_cumsum,
123146
indices,
124-
offsets,
147+
offsets_,
125148
pooling_mode,
126149
indice_weights,
127150
output_dtype);
@@ -192,6 +215,8 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
192215
const Tensor& B_offsets,
193216
const Tensor& vbe_row_output_offsets,
194217
const Tensor& vbe_b_t_map,
218+
const Tensor& vbe_B_offsets_rank_per_feature,
219+
const int64_t max_B,
195220
{%- endif %}
196221
const bool /*use_uniq_cache_locations*/,
197222
const bool /*use_homogeneous_placements*/,
@@ -200,6 +225,10 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
200225
, const int64_t output_dtype = static_cast<int64_t>(SparseType::FP32)
201226
{%- endif %})
202227
{
228+
{%- if vbe %}
229+
const auto offsets_ = reshape_vbe_offsets(offsets, vbe_B_offsets_rank_per_feature, max_B, D_offsets.numel() - 1);
230+
const auto grad_output_ = reshape_vbe_output(grad_output, max_B, vbe_B_offsets_rank_per_feature, D_offsets);
231+
{%- endif %}
203232
{%- set backward_op = "split_embedding_backward_codegen_{}_cpu".format(
204233
optimizer
205234
)
@@ -230,7 +259,7 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
230259
)>();
231260

232261
op.call(
233-
grad_output,
262+
{{ "grad_output_" if vbe else "grad_output" }},
234263
host_weights,
235264
weights_placements,
236265
weights_offsets,
@@ -239,7 +268,7 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
239268
hash_size_cumsum,
240269
total_hash_size_bits,
241270
indices,
242-
offsets,
271+
{{ "offsets_" if vbe else "offsets" }},
243272
pooling_mode,
244273
indice_weights,
245274
stochastic_rounding,
@@ -248,7 +277,7 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
248277
, output_dtype
249278
{%- endif %}
250279
);
251-
return grad_output;
280+
return Tensor();
252281
}
253282
{% endif %}
254283
{%- endfor %} {#-/*for weighted*/#}
@@ -293,6 +322,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
293322
" SymInt vbe_output_size, "
294323
" int info_B_num_bits, "
295324
" int info_B_mask_int64, "
325+
" Tensor vbe_B_offsets_rank_per_feature, "
326+
" Tensor vbe_output_offsets_feature_rank, "
327+
" int max_B, "
296328
{%- endif %}
297329
" bool is_experimental, "
298330
" int output_dtype "
@@ -345,6 +377,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
345377
" Tensor B_offsets, "
346378
" Tensor vbe_row_output_offsets, "
347379
" Tensor vbe_b_t_map, "
380+
" Tensor vbe_B_offsets_rank_per_feature, "
381+
" int max_B, "
348382
{%- endif %}
349383
" bool use_uniq_cache_locations, "
350384
" bool use_homogeneous_placements,"
@@ -381,7 +415,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
381415
" Tensor vbe_row_output_offsets, "
382416
" Tensor vbe_b_t_map, "
383417
" int info_B_num_bits, "
384-
" int info_B_mask_int64"
418+
" int info_B_mask_int64, "
419+
" Tensor vbe_B_offsets_rank_per_feature, "
420+
" int max_B "
385421
{%- else %}
386422
" Tensor feature_requires_grad"
387423
{%- endif %}

fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp

+16-2
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ Tensor {{ fwd_mdesc }}_embedding{{ ndesc }}_codegen_forward_{{ desc_suffix }}_pt
9393
const c10::SymInt vbe_output_size,
9494
const int64_t info_B_num_bits,
9595
const int64_t info_B_mask_int64,
96+
const Tensor& vbe_B_offsets_rank_per_feature,
97+
const Tensor& vbe_output_offsets_feature_rank,
98+
const int64_t max_B,
9699
{%- endif %}
97100
{%- if is_gwd %}
98101
const Tensor& prev_iter_dev,
@@ -241,6 +244,8 @@ Tensor {{ bwd_mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{
241244
const Tensor& B_offsets,
242245
const Tensor& vbe_row_output_offsets,
243246
const Tensor& vbe_b_t_map,
247+
const Tensor& vbe_B_offsets_rank_per_feature,
248+
const int64_t max_B,
244249
{%- endif %}
245250
const bool use_uniq_cache_locations,
246251
const bool use_homogeneous_placements,
@@ -403,7 +408,9 @@ Tensor {{ fwd_mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_{{ d
403408
const Tensor& vbe_row_output_offsets,
404409
const Tensor& vbe_b_t_map,
405410
const int64_t info_B_num_bits,
406-
const int64_t info_B_mask_int64
411+
const int64_t info_B_mask_int64,
412+
const Tensor& vbe_B_offsets_rank_per_feature,
413+
const int64_t max_B
407414
{%- else %}
408415
const Tensor& feature_requires_grad
409416
{%- endif %}
@@ -529,6 +536,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
529536
" SymInt vbe_output_size, "
530537
" int info_B_num_bits, "
531538
" int info_B_mask_int64, "
539+
" Tensor vbe_B_offsets_rank_per_feature, "
540+
" Tensor vbe_output_offsets_feature_rank, "
541+
" int max_B, "
532542
{%- endif %}
533543
{%- if is_gwd %}
534544
" Tensor{{ schema_annotation['prev_iter_dev'] }} prev_iter_dev, "
@@ -599,6 +609,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
599609
" Tensor B_offsets, "
600610
" Tensor vbe_row_output_offsets, "
601611
" Tensor vbe_b_t_map, "
612+
" Tensor vbe_B_offsets_rank_per_feature, "
613+
" int max_B, "
602614
{%- endif %}
603615
" bool use_uniq_cache_locations, "
604616
" bool use_homogeneous_placements,"
@@ -656,7 +668,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
656668
" Tensor vbe_row_output_offsets, "
657669
" Tensor vbe_b_t_map, "
658670
" int info_B_num_bits, "
659-
" int info_B_mask_int64"
671+
" int info_B_mask_int64, "
672+
" Tensor vbe_B_offsets_rank_per_feature, "
673+
" int max_B "
660674
{%- else %}
661675
" Tensor feature_requires_grad"
662676
{%- endif %}

0 commit comments

Comments (0)