
Commit ce34aac

spcyppt authored and facebook-github-bot committed
Unifying TBE API using List (Frontend) (pytorch#3711)
Summary:
X-link: pytorch/torchrec#2751
X-link: facebookresearch/FBGEMM#793

**Backend**: D68054868

---

As the number of arguments in TBE keeps growing, some of the optimizers run into the argument-count limit (64) during PyTorch operator registration.

**For long-term growth and maintenance, we hence redesign the TBE API by packing some of the arguments into lists. Note that not all arguments are packed.**

We pack the arguments into one list per type. For **common** arguments, we pack
- weights and arguments of type `Momentum` into a TensorList
- other tensors and optional tensors into a list of optional tensors, `aux_tensor`
- `int` arguments into `aux_int`
- `float` arguments into `aux_float`
- `bool` arguments into `aux_bool`.

Similarly, for **optimizer-specific** arguments, we pack
- arguments of type `Momentum` that are *__not__ optional* into a TensorList
- *optional* tensors into a list of optional tensors, `optim_tensor`
- `int` arguments into `optim_int`
- `float` arguments into `optim_float`
- `bool` arguments into `optim_bool`.

We saw issues with PyTorch registration when packing SymInt arguments across the Python/C++ boundary, so SymInt arguments are unrolled and passed individually.

**This significantly reduces the number of arguments.** For example, `split_embedding_codegen_lookup_rowwise_adagrad_with_counter_function`, which currently has 61 arguments, will have only 26 arguments with this API design.

Please refer to the design doc for which arguments are packed and for the full signatures.

Design doc: https://docs.google.com/document/d/1dCBg7dcf7Yq9FHVrvXsAmFtBxkDi9o6u0r-Ptd4UDPE/edit?tab=t.0#heading=h.6bip5pwqq8xb

The full signature for each optimizer lookup function will be provided shortly.

Reviewed By: sryap, nautsimon

Differential Revision: D68055168
1 parent 428e671 commit ce34aac

9 files changed (+343 −318 lines)
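To make the packing concrete, here is a small, self-contained sketch of the convention described in the summary. The list names (`weights`, `aux_tensor`, `aux_int`, `aux_float`, `aux_bool`) follow the summary, but the functions and the placement of individual values inside each list are hypothetical; the real generated signatures are defined by the codegen and the linked design doc.

```python
# Hypothetical sketch of the "pack arguments by type" idea from the commit
# summary; the function names and list contents are illustrative only.
from typing import List, Optional

import torch


# Before: every scalar/tensor is its own argument, so the operator signature
# grows with each new option and eventually hits the 64-argument limit.
def lookup_before_sketch(
    weights: torch.Tensor,
    momentum1: torch.Tensor,
    prev_iter: Optional[torch.Tensor],
    row_counter: Optional[torch.Tensor],
    iter_int: int,
    info_B_num_bits: int,
    eps: float,
    weight_decay: float,
    stochastic_rounding: bool,
    use_homogeneous_placements: bool,
) -> torch.Tensor:
    return weights  # placeholder body


# After: arguments are grouped into one list per type, so adding a new option
# appends to a list instead of adding another operator argument.
def lookup_after_sketch(
    weights: List[torch.Tensor],               # weights + Momentum as a TensorList
    aux_tensor: List[Optional[torch.Tensor]],  # other (optional) tensors
    aux_int: List[int],
    aux_float: List[float],
    aux_bool: List[bool],
) -> torch.Tensor:
    return weights[0]  # placeholder body


# Example: a ten-argument call collapses into five packed arguments.
out = lookup_after_sketch(
    weights=[torch.zeros(4, 8), torch.zeros(4)],
    aux_tensor=[None, None],
    aux_int=[0, 26],
    aux_float=[1e-8, 0.0],
    aux_bool=[True, False],
)
```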

fbgemm_gpu/codegen/genscript/generate_backward_split.py

+1 −1

@@ -447,7 +447,7 @@ def generate() -> None:
         ssd_optimizers.append(optim)
 
         BackwardSplitGenerator.generate_backward_split(
-            ssd_tensors=ssd_tensors, **optimizer
+            ssd_tensors=ssd_tensors, aux_args=aux_args, **optimizer
         )
         BackwardSplitGenerator.generate_rocm_backward_split()

fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu

+1
@@ -601,6 +601,7 @@ Tensor {{ embedding_cuda_op }}(
 
     {%- if "learning_rate" in args.split_kernel_arg_names %}
     // convert `learning rate` to float since `learning rate` is float in kernels
+    TORCH_CHECK(learning_rate_tensor.is_cpu(), "learning_rate_tensor tensor needs to be on CPU. Ensure learning_rate_tensor is on CPU or contact FBGEMM team if you get this error.")
     const float learning_rate = learning_rate_tensor.item<float>();
     {%- endif %}
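For context on the `TORCH_CHECK` added above: the generated operator reads the learning rate on the host via `.item<float>()`, so the frontend keeps the learning rate as a CPU float tensor (see the `learning_rate_tensor` changes below). A minimal frontend-side sketch of the same invariant, assuming only standard PyTorch:

```python
# Sketch of the invariant enforced by the TORCH_CHECK above: keep the learning
# rate as a CPU float tensor so reading it with .item() stays a cheap host-side
# access (calling .item() on a GPU tensor would force a device sync).
import torch

learning_rate_tensor = torch.tensor(0.01, device="cpu", dtype=torch.float)
assert learning_rate_tensor.is_cpu  # mirrors the C++-side check
learning_rate = learning_rate_tensor.item()  # plain Python float for the kernels
```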

fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template

+283 −293
Large diffs are not rendered by default.

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

+35 −17

@@ -597,7 +597,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
     """
 
     embedding_specs: List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]
-    optimizer_args: invokers.lookup_args.OptimizerArgs
+    optimizer_args: invokers.lookup_args.OptimizerArgsPT2
     lxu_cache_locations_list: List[Tensor]
     lxu_cache_locations_empty: Tensor
     timesteps_prefetched: List[int]

@@ -926,6 +926,11 @@ def __init__( # noqa C901
             "feature_dims",
             torch.tensor(feature_dims, device="cpu", dtype=torch.int64),
         )
+        (self.info_B_num_bits, self.info_B_mask) = torch.ops.fbgemm.get_infos_metadata(
+            self.D_offsets,  # unused tensor
+            1,  # max_B
+            T,  # T
+        )
 
         # A flag for indicating whether all embedding tables are placed in the
         # same locations

@@ -1070,6 +1075,9 @@ def __init__( # noqa C901
         # which should not be effective when CounterBasedRegularizationDefinition
         # and CowClipDefinition are not used
         counter_halflife = -1
+        learning_rate_tensor = torch.tensor(
+            learning_rate, device=torch.device("cpu"), dtype=torch.float
+        )
 
         # TO DO: Enable this on the new interface
         # learning_rate_tensor = torch.tensor(

@@ -1085,12 +1093,12 @@ def __init__( # noqa C901
                 "`use_rowwise_bias_correction` is only supported for OptimType.ADAM",
             )
 
-        self.optimizer_args = invokers.lookup_args.OptimizerArgs(
+        self.optimizer_args = invokers.lookup_args.OptimizerArgsPT2(
             stochastic_rounding=stochastic_rounding,
             gradient_clipping=gradient_clipping,
             max_gradient=max_gradient,
             max_norm=max_norm,
-            learning_rate=learning_rate,
+            learning_rate_tensor=learning_rate_tensor,
             eps=eps,
             beta1=beta1,
             beta2=beta2,
@@ -1873,7 +1881,7 @@ def forward( # noqa: C901
             if len(self.lxu_cache_locations_list) == 0
             else self.lxu_cache_locations_list.pop(0)
         )
-        common_args = invokers.lookup_args.CommonArgs(
+        common_args = invokers.lookup_args.CommonArgsPT2(
             placeholder_autograd_tensor=self.placeholder_autograd_tensor,
             # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
             # `Union[Module, Tensor]`.

@@ -1920,6 +1928,8 @@ def forward( # noqa: C901
             is_experimental=self.is_experimental,
             use_uniq_cache_locations_bwd=self.use_uniq_cache_locations_bwd,
             use_homogeneous_placements=self.use_homogeneous_placements,
+            info_B_num_bits=self.info_B_num_bits,
+            info_B_mask=self.info_B_mask,
         )
 
         if self.optimizer == OptimType.NONE:

@@ -2032,7 +2042,6 @@ def forward( # noqa: C901
                 momentum1,
                 momentum2,
                 iter_int,
-                self.use_rowwise_bias_correction,
                 row_counter=(
                     row_counter if self.use_rowwise_bias_correction else None
                 ),
@@ -2918,7 +2927,7 @@ def _set_learning_rate(self, lr: float) -> float:
         Helper function to script `set_learning_rate`.
         Note that returning None does not work.
         """
-        self.optimizer_args = self.optimizer_args._replace(learning_rate=lr)
+        self.optimizer_args.learning_rate_tensor.fill_(lr)
         return 0.0
 
     @torch.jit.ignore
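Because the learning rate is now stored as a tensor inside `optimizer_args`, `_set_learning_rate` can update it in place with `fill_` instead of rebuilding the namedtuple with `_replace`. A small standalone sketch of why the in-place update is visible everywhere (the dict below is only a stand-in for `optimizer_args`):

```python
# Sketch: mutating the CPU tensor in place updates the value seen by every
# holder of a reference, without replacing the surrounding container.
import torch

learning_rate_tensor = torch.tensor(0.01, dtype=torch.float)
optimizer_args_stub = {"learning_rate_tensor": learning_rate_tensor}  # stand-in

learning_rate_tensor.fill_(0.25)  # in-place update, as in _set_learning_rate
assert optimizer_args_stub["learning_rate_tensor"].item() == 0.25
```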
@@ -3433,6 +3442,22 @@ def prepare_inputs(
             offsets, batch_size_per_feature_per_rank
         )
 
+        vbe = vbe_metadata.B_offsets is not None
+        # TODO: assert vbe_metadata.B_offsets.numel() - 1 == T
+        #       T = self.D_offsets.numel() - 1
+        #       vbe_metadata.B_offsets causes jit to fail for cogwheel forward compatibility test
+        # max_B = int(vbe_metadata.max_B) if vbe else int(offsets.numel() - 1 / T)
+
+        # TODO: max_B <= self.info_B_mask
+        # cannot use assert as it breaks pt2 compile for dynamic shape
+        # Need to use torch._check for dynamic shape and cannot construct fstring, use constant string.
+        # cannot use lambda as it fails jit script.
+        # torch._check is not supported in jitscript
+        # torch._check(
+        #     max_B <= self.info_B_mask,
+        #     "Not enough infos bits to accommodate T and B.",
+        # )
+
         # TODO: remove this and add an assert after updating
         # bounds_check_indices to support different indices type and offset
         # type

@@ -3460,7 +3485,6 @@ def prepare_inputs(
             per_sample_weights = per_sample_weights.float()
 
         if self.bounds_check_mode_int != BoundsCheckMode.NONE.value:
-            vbe = vbe_metadata.B_offsets is not None
             # Compute B info and VBE metadata for bounds_check_indices only if
             # VBE and bounds check indices v2 are used
             if vbe and self.use_bounds_check_v2:

@@ -3474,11 +3498,7 @@ def prepare_inputs(
                 assert isinstance(
                     output_offsets_feature_rank, Tensor
                 ), "output_offsets_feature_rank must be tensor"
-                info_B_num_bits, info_B_mask = torch.ops.fbgemm.get_infos_metadata(
-                    B_offsets,  # unused tensor
-                    vbe_metadata.max_B,
-                    B_offsets.numel() - 1,  # T
-                )
+
                 row_output_offsets, b_t_map = torch.ops.fbgemm.generate_vbe_metadata(
                     B_offsets,
                     B_offsets_rank_per_feature,

@@ -3487,13 +3507,11 @@ def prepare_inputs(
                     self.max_D,
                     self.is_nobag,
                     vbe_metadata.max_B_feature_rank,
-                    info_B_num_bits,
+                    self.info_B_num_bits,
                     offsets.numel() - 1,  # total_B
                 )
             else:
                 b_t_map = None
-                info_B_num_bits = -1
-                info_B_mask = -1
 
             torch.ops.fbgemm.bounds_check_indices(
                 self.rows_per_table,

@@ -3505,8 +3523,8 @@ def prepare_inputs(
                 B_offsets=vbe_metadata.B_offsets,
                 max_B=vbe_metadata.max_B,
                 b_t_map=b_t_map,
-                info_B_num_bits=info_B_num_bits,
-                info_B_mask=info_B_mask,
+                info_B_num_bits=self.info_B_num_bits,
+                info_B_mask=self.info_B_mask,
                 bounds_check_version=self.bounds_check_version,
             )
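`info_B_num_bits` and `info_B_mask` are now obtained once from `get_infos_metadata` in `__init__` and reused here and in `forward`. The commented-out check above ("Not enough infos bits to accommodate T and B.") reflects the usual bit-packing constraint: if batch indices occupy the low `info_B_num_bits` bits of a packed info word, every batch index must fit under `info_B_mask`. A generic sketch of that constraint follows; the layout is illustrative and not necessarily FBGEMM's exact encoding.

```python
# Generic bit-packing sketch: pack a table index t and a batch index b into one
# integer using num_bits low bits for b. This only illustrates why max_B must
# fit under the mask; the real FBGEMM encoding may differ.
from typing import Tuple

info_B_num_bits = 26
info_B_mask = (1 << info_B_num_bits) - 1  # low 26 bits reserved for b


def pack_info(t: int, b: int) -> int:
    assert b <= info_B_mask, "Not enough info bits to accommodate T and B."
    return (t << info_B_num_bits) | b


def unpack_info(info: int) -> Tuple[int, int]:
    return info >> info_B_num_bits, info & info_B_mask


packed = pack_info(t=3, b=17)
assert unpack_info(packed) == (3, 17)
```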

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

+14 −3

@@ -228,6 +228,12 @@ def __init__(
             torch.tensor(feature_dims, device="cpu", dtype=torch.int64),
         )
 
+        (self.info_B_num_bits, self.info_B_mask) = torch.ops.fbgemm.get_infos_metadata(
+            self.D_offsets,  # unused tensor
+            1,  # max_B
+            T,  # T
+        )
+
         assert cache_sets > 0
         element_size = weights_precision.bit_rate() // 8
         assert (

@@ -544,12 +550,15 @@ def __init__(
             )
             cowclip_regularization = CowClipDefinition()
 
-        self.optimizer_args = invokers.lookup_args_ssd.OptimizerArgs(
+        learning_rate_tensor = torch.tensor(
+            learning_rate, device=torch.device("cpu"), dtype=torch.float
+        )
+        self.optimizer_args = invokers.lookup_args_ssd.OptimizerArgsPT2(
             stochastic_rounding=stochastic_rounding,
             gradient_clipping=gradient_clipping,
             max_gradient=max_gradient,
             max_norm=max_norm,
-            learning_rate=learning_rate,
+            learning_rate_tensor=learning_rate_tensor,
             eps=eps,
             beta1=beta1,
             beta2=beta2,

@@ -1630,7 +1639,7 @@ def forward(
             vbe_metadata.max_B,
         )
 
-        common_args = invokers.lookup_args_ssd.CommonArgs(
+        common_args = invokers.lookup_args_ssd.CommonArgsPT2(
             placeholder_autograd_tensor=self.placeholder_autograd_tensor,
             output_dtype=self.output_dtype,
             dev_weights=self.weights_dev,

@@ -1665,6 +1674,8 @@ def forward(
             },
             # pyre-fixme[6]: Expected `lookup_args_ssd.VBEMetadata` but got `lookup_args.VBEMetadata`
             vbe_metadata=vbe_metadata,
+            info_B_num_bits=self.info_B_num_bits,
+            info_B_mask=self.info_B_mask,
         )
 
         self.timesteps_prefetched.pop(0)

fbgemm_gpu/test/tbe/training/backward_adagrad_common.py

+4 −2

@@ -82,8 +82,10 @@
     "use_cache": st.booleans(),
     "cache_algorithm": st.sampled_from(CacheAlgorithm),
     "use_cpu": use_cpu_strategy(),
-    "output_dtype": st.sampled_from(
-        [SparseType.FP32, SparseType.FP16, SparseType.BF16]
+    "output_dtype": (
+        st.sampled_from([SparseType.FP32, SparseType.FP16, SparseType.BF16])
+        if gpu_available
+        else st.sampled_from([SparseType.FP32, SparseType.FP16])
     ),
 }
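The strategy change above limits BF16 outputs to runs where a GPU is available. The same pattern shown standalone with `hypothesis` (here `gpu_available` is derived from `torch.cuda.is_available()` as a stand-in for the test utility's flag, and plain strings stand in for `SparseType`):

```python
# Standalone sketch of choosing a hypothesis strategy based on the environment,
# mirroring the output_dtype change above.
import torch
from hypothesis import given, strategies as st

gpu_available = torch.cuda.is_available()  # stand-in for the test utility flag

output_dtype_strategy = (
    st.sampled_from(["FP32", "FP16", "BF16"])
    if gpu_available
    else st.sampled_from(["FP32", "FP16"])
)


@given(output_dtype=output_dtype_strategy)
def test_output_dtype_choice(output_dtype: str) -> None:
    assert output_dtype in {"FP32", "FP16", "BF16"}
```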

fbgemm_gpu/test/tbe/training/backward_adagrad_global_weight_decay_test.py

+1 −1

@@ -190,7 +190,7 @@ def apply_gwd(
         apply_gwd_per_table(
             prev_iter_values,
             weights_values,
-            emb.optimizer_args.learning_rate,
+            emb.optimizer_args.learning_rate_tensor.item(),
             emb.optimizer_args.weight_decay,
             step,
             emb.current_device,

fbgemm_gpu/test/tbe/training/forward_test.py

+3

@@ -76,6 +76,9 @@
     "test_faketensor__test_forward_gpu_uvm_cache_int8": [
         unittest.skip("Operator not implemented for Meta tensors"),
     ],
+    "test_faketensor__test_forward_cpu_fp32": [
+        unittest.skip("Operator not implemented for Meta tensors"),
+    ],
     # TODO: Make it compatible with opcheck tests
     "test_faketensor__test_forward_gpu_uvm_cache_fp16": [
         unittest.skip(

fbgemm_gpu/test/tbe/utils/split_embeddings_test.py

+1 −1

@@ -594,7 +594,7 @@ def test_update_hyper_parameters(self) -> None:
         } | {"lr": 1.0, "lower_bound": 2.0}
         cc.update_hyper_parameters(updated_parameters)
         self.assertAlmostEqual(
-            cc.optimizer_args.learning_rate, updated_parameters["lr"]
+            cc.optimizer_args.learning_rate_tensor.item(), updated_parameters["lr"]
         )
         self.assertAlmostEqual(cc.optimizer_args.eps, updated_parameters["eps"])
         self.assertAlmostEqual(cc.optimizer_args.beta1, updated_parameters["beta1"])
