
Commit 3e5dc36

spcyppt authored and facebook-github-bot committed
Enable int32_t support for reshape_vbe_offsets (pytorch#3782)
Summary:
- Enable int32_t support for reshape_vbe_offsets
- Fix setting learning_rate as learning_rate_tensor in ssd

X-link: facebookresearch/FBGEMM#866
Reviewed By: nautsimon
Differential Revision: D70760386
1 parent de35b3c commit 3e5dc36

File tree: 4 files changed, +37 -21 lines
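At the core of the C++ changes, the hard requirement that offsets be int64 (a TORCH_CHECK against at::kLong) is replaced by a templated reshape_vbe_offsets&lt;index_t&gt;, which the CPU wrappers now call through AT_DISPATCH_INDEX_TYPES so that int32 offsets take the same path. A minimal, self-contained sketch of that dispatch pattern follows; the helper name copy_offsets is hypothetical and only illustrates the mechanism, it is not part of FBGEMM.

// Sketch of the AT_DISPATCH_INDEX_TYPES pattern used by this commit.
// `copy_offsets` is a hypothetical helper; AT_DISPATCH_INDEX_TYPES,
// at::empty_like, and Tensor::accessor are real ATen APIs.
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

template <typename index_t>
at::Tensor copy_offsets(const at::Tensor& offsets) {
  // Works for whichever index type the caller dispatched to (int32_t or int64_t).
  auto out = at::empty_like(offsets);
  const auto in_acc = offsets.accessor<index_t, 1>();
  auto out_acc = out.accessor<index_t, 1>();
  for (int64_t i = 0; i < offsets.numel(); ++i) {
    out_acc[i] = in_acc[i];
  }
  return out;
}

at::Tensor copy_offsets_any_index_type(const at::Tensor& offsets) {
  at::Tensor out;
  // Expands to a switch over at::kInt / at::kLong and defines `index_t` inside
  // the lambda, mirroring the reshape_vbe_offsets<index_t> calls in the diff.
  AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "copy_offsets", [&] {
    out = copy_offsets<index_t>(offsets);
  });
  return out;
}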

fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp

+20-8
@@ -26,6 +26,9 @@
 #include "fbgemm_gpu/utils/ops_utils.h"
 #include "fbgemm_gpu/utils/dispatch_macros.h"
 #include "fbgemm_gpu/embedding_common.h"
+// #include <ATen/ATen.h>
+#include <ATen/Dispatch.h>
+#include <ATen/TensorUtils.h>
 {%- if has_vbe_support %}
 #include "fbgemm_gpu/utils/pt2_autograd_utils.h"
 {%- endif %}
@@ -64,12 +67,15 @@ Tensor split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cpu_wrapper(
 {%- endif %}
 ) {
 {%- if vbe %}
-  const auto offsets_ = reshape_vbe_offsets(
-      offsets,
-      vbe_B_offsets_rank_per_feature,
-      max_B,
-      D_offsets.numel() - 1
-  );
+  Tensor offsets_;
+  AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "reshape_vbe_offsets_cpu_grad_indices", [&]() {
+    offsets_ = reshape_vbe_offsets<index_t>(
+        offsets,
+        vbe_B_offsets_rank_per_feature,
+        max_B,
+        D_offsets.numel() - 1
+    );
+  });
   const auto grad_output_ = reshape_vbe_output(
       grad_output,
       max_B,
@@ -126,8 +132,11 @@ Tensor split_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper(
 {%- endif %}
     const bool /*is_experimental = false*/,
     const int64_t output_dtype = static_cast<int64_t>(SparseType::FP32)) {
+  Tensor offsets_;
 {%- if vbe %}
-  const auto offsets_ = reshape_vbe_offsets(offsets, vbe_B_offsets_rank_per_feature, max_B, D_offsets.numel() - 1);
+  AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "reshape_vbe_offsets_cpu_forward", [&]() {
+    offsets_ = reshape_vbe_offsets<index_t>(offsets, vbe_B_offsets_rank_per_feature, max_B, D_offsets.numel() - 1);
+  });
 {%- endif %}
   static auto op =
       torch::Dispatcher::singleton()
@@ -226,7 +235,10 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
 {%- endif %})
 {
 {%- if vbe %}
-  const auto offsets_ = reshape_vbe_offsets(offsets, vbe_B_offsets_rank_per_feature, max_B, D_offsets.numel() - 1);
+  Tensor offsets_;
+  AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "reshape_vbe_offsets_cpu_backward", [&]() {
+    offsets_ = reshape_vbe_offsets<index_t>(offsets, vbe_B_offsets_rank_per_feature, max_B, D_offsets.numel() - 1);
+  });
   const auto grad_output_ = reshape_vbe_output(grad_output, max_B, vbe_B_offsets_rank_per_feature, D_offsets);
 {%- endif %}
 {%- set backward_op = "split_embedding_backward_codegen_{}_cpu".format(

fbgemm_gpu/codegen/training/pt2/pt2_autograd_utils.cpp

+15-6
@@ -113,6 +113,7 @@ void checked_memcpy(
 /// size(1) is number of ranks
 /// @param max_B Maximum batch size
 /// @param T Number of embedding tables (features)
+template <typename index_t>
 Tensor reshape_vbe_offsets(
     const Tensor& offsets,
     const Tensor& B_offsets_rank_per_feature,
@@ -125,12 +126,8 @@ Tensor reshape_vbe_offsets(
       B_offsets_rank_per_feature.accessor<int32_t, 2>();
   auto reshaped_offsets = at::empty({T * max_B + 1}, offsets.options());
   // TODO: support other types
-  TORCH_CHECK(
-      offsets.dtype() == at::kLong,
-      "Expected offsets to be int64 but got ",
-      offsets.dtype());
-  auto reshaped_offsets_acc = reshaped_offsets.accessor<int64_t, 1>();
-  auto offsets_acc = offsets.accessor<int64_t, 1>();
+  auto reshaped_offsets_acc = reshaped_offsets.accessor<index_t, 1>();
+  auto offsets_acc = offsets.accessor<index_t, 1>();
   auto begin = 0;
   for (int32_t t = 0; t < T; t++) {
     const auto batch_size =
@@ -167,4 +164,16 @@ Tensor reshape_vbe_offsets(
   return reshaped_offsets;
 }

+template Tensor reshape_vbe_offsets<int32_t>(
+    const Tensor& offsets,
+    const Tensor& B_offsets_rank_per_feature,
+    const int64_t max_B,
+    const int32_t T);
+
+template Tensor reshape_vbe_offsets<int64_t>(
+    const Tensor& offsets,
+    const Tensor& B_offsets_rank_per_feature,
+    const int64_t max_B,
+    const int32_t T);
+
 } // namespace fbgemm_gpu
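Because the reshape_vbe_offsets&lt;index_t&gt; definition stays in this .cpp while callers only see the declaration in pt2_autograd_utils.h, the two explicit instantiations above are what make the int32_t and int64_t symbols available at link time. A small sketch of that header/source split, using a hypothetical sum_offsets helper rather than the FBGEMM function:

// sum_offsets.h -- declaration only (hypothetical example, not FBGEMM code)
#include <ATen/ATen.h>

template <typename index_t>
int64_t sum_offsets(const at::Tensor& offsets);

// sum_offsets.cpp -- definition plus explicit instantiations for the supported index types
template <typename index_t>
int64_t sum_offsets(const at::Tensor& offsets) {
  const auto acc = offsets.accessor<index_t, 1>();
  int64_t total = 0;
  for (int64_t i = 0; i < offsets.numel(); ++i) {
    total += static_cast<int64_t>(acc[i]);
  }
  return total;
}

// Without these, a caller that includes only the header hits an undefined-symbol error.
template int64_t sum_offsets<int32_t>(const at::Tensor& offsets);
template int64_t sum_offsets<int64_t>(const at::Tensor& offsets);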

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

+1-1
@@ -1826,7 +1826,7 @@ def _set_learning_rate(self, lr: float) -> float:
         Helper function to script `set_learning_rate`.
         Note that returning None does not work.
         """
-        self.optimizer_args = self.optimizer_args._replace(learning_rate=lr)
+        self.optimizer_args.learning_rate_tensor.fill_(lr)
         return 0.0

     def flush(self) -> None:

fbgemm_gpu/include/fbgemm_gpu/utils/pt2_autograd_utils.h

+1-6
@@ -8,12 +8,6 @@

 #include <ATen/ATen.h>
 #include <ATen/TypeDefault.h>
-// #include <ATen/core/op_registration/op_registration.h>
-// #include <torch/script.h>
-// #include "fbgemm_gpu/embedding_common.h"
-// #include "fbgemm_gpu/utils/dispatch_macros.h"
-// #include "fbgemm_gpu/utils/ops_utils.h"
-// #include "fbgemm_gpu/utils/tensor_utils.h"

 using Tensor = at::Tensor;

@@ -29,6 +23,7 @@ Tensor reshape_vbe_output(
     const Tensor& B_offsets_rank_per_feature,
     const Tensor& D_offsets);

+template <typename index_t>
 Tensor reshape_vbe_offsets(
     const Tensor& offsets,
     const Tensor& B_offsets_rank_per_feature,
