
Commit 6b31758

spcyppt authored and facebook-github-bot committed
Unifying TBE API using List (Frontend) (pytorch#3711)
Summary:
X-link: pytorch/torchrec#2751
X-link: facebookresearch/FBGEMM#793

**Backend**: D68054868

---

As the number of arguments in TBE keeps growing, some of the optimizers hit the argument-count limit (64) during PyTorch operator registration. **For long-term growth and maintainability, we hence redesign the TBE API by packing some of the arguments into lists. Note that not all arguments are packed.**

We pack the arguments into one list per type. For **common** arguments, we pack
- weights and arguments of type `Momentum` into a TensorList
- other tensors and optional tensors into a list of optional tensors, `aux_tensor`
- `int` arguments into `aux_int`
- `float` arguments into `aux_float`
- `bool` arguments into `aux_bool`.

Similarly, for **optimizer-specific** arguments, we pack
- arguments of type `Momentum` that are *__not__ optional* into a TensorList
- *optional* tensors into a list of optional tensors, `optim_tensor`
- `int` arguments into `optim_int`
- `float` arguments into `optim_float`
- `bool` arguments into `optim_bool`.

We saw issues with PyTorch registration when packing SymInt arguments across the Python/C++ boundary, so SymInt arguments are unrolled and passed individually.

**This significantly reduces the number of arguments.** For example, `split_embedding_codegen_lookup_rowwise_adagrad_with_counter_function`, which currently has 61 arguments, has only 26 arguments with this API design.

Please refer to the design doc for which arguments are packed and the resulting signatures. Design doc: https://docs.google.com/document/d/1dCBg7dcf7Yq9FHVrvXsAmFtBxkDi9o6u0r-Ptd4UDPE/edit?tab=t.0#heading=h.6bip5pwqq8xb

The full signature of the lookup function for each optimizer will be provided shortly.

Reviewed By: sryap

Differential Revision: D68055168
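As a rough illustration of the packing scheme described in the summary, a packed lookup could look like the sketch below. The group names come from the summary; the function name, the `momentum` list, and the trailing SymInt parameters are placeholders and not the generated operator signature (see the design doc for the real signatures).

```python
# Sketch only: illustrates per-type argument packing, not the actual
# generated operator signature.
from typing import List, Optional

import torch


def tbe_lookup_packed_sketch(  # placeholder name
    weights: List[torch.Tensor],                 # weights / `Momentum`-typed args as a TensorList
    aux_tensor: List[Optional[torch.Tensor]],    # other common (optional) tensors
    aux_int: List[int],                          # common int args
    aux_float: List[float],                      # common float args
    aux_bool: List[bool],                        # common bool args
    momentum: List[torch.Tensor],                # non-optional optimizer `Momentum` args
    optim_tensor: List[Optional[torch.Tensor]],  # optional optimizer tensors
    optim_int: List[int],                        # optimizer int args
    optim_float: List[float],                    # optimizer float args
    optim_bool: List[bool],                      # optimizer bool args
    # SymInt arguments are not packed; they stay as individual parameters
    # (the names below are placeholders).
    total_D: int,
    max_D: int,
) -> torch.Tensor:
    ...
```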
1 parent 744f231 commit 6b31758

9 files changed: +320, -305 lines

fbgemm_gpu/codegen/genscript/generate_backward_split.py (+1, -1)

@@ -447,7 +447,7 @@ def generate() -> None:
             ssd_optimizers.append(optim)

         BackwardSplitGenerator.generate_backward_split(
-            ssd_tensors=ssd_tensors, **optimizer
+            ssd_tensors=ssd_tensors, aux_args=aux_args, **optimizer
         )
         BackwardSplitGenerator.generate_rocm_backward_split()

fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu (+1)

@@ -601,6 +601,7 @@ Tensor {{ embedding_cuda_op }}(

    {%- if "learning_rate" in args.split_kernel_arg_names %}
    // convert `learning rate` to float since `learning rate` is float in kernels
+   TORCH_CHECK(learning_rate_tensor.is_cpu(), "learning_rate_tensor tensor needs to be on CPU. Ensure learning_rate_tensor is on CPU or contact FBGEMM team if you get this error.")
    const float learning_rate = learning_rate_tensor.item<float>();
    {%- endif %}
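The new `TORCH_CHECK` enforces that the learning rate arrives as a CPU tensor before the host code reads it with `item<float>()`. A minimal Python-side sketch of that invariant (standalone illustration, not FBGEMM code):

```python
import torch

# The learning rate now travels as a scalar float tensor kept on the CPU,
# so reading its value on the host does not require a device synchronization.
learning_rate_tensor = torch.tensor(0.01, device="cpu", dtype=torch.float)

assert learning_rate_tensor.is_cpu           # what the TORCH_CHECK above asserts
learning_rate = learning_rate_tensor.item()  # what the template does via item<float>()
```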

fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template (+297, -293)

Large diffs are not rendered by default.

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py (+7, -5)

@@ -597,7 +597,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
    """

    embedding_specs: List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]
-   optimizer_args: invokers.lookup_args.OptimizerArgs
+   optimizer_args: invokers.lookup_args.OptimizerArgsPT2
    lxu_cache_locations_list: List[Tensor]
    lxu_cache_locations_empty: Tensor
    timesteps_prefetched: List[int]

@@ -1070,6 +1070,9 @@ def __init__( # noqa C901
        # which should not be effective when CounterBasedRegularizationDefinition
        # and CowClipDefinition are not used
        counter_halflife = -1
+       learning_rate_tensor = torch.tensor(
+           learning_rate, device=torch.device("cpu"), dtype=torch.float
+       )

        # TO DO: Enable this on the new interface
        # learning_rate_tensor = torch.tensor(

@@ -1085,12 +1088,12 @@ def __init__( # noqa C901
            "`use_rowwise_bias_correction` is only supported for OptimType.ADAM",
        )

-       self.optimizer_args = invokers.lookup_args.OptimizerArgs(
+       self.optimizer_args = invokers.lookup_args.OptimizerArgsPT2(
            stochastic_rounding=stochastic_rounding,
            gradient_clipping=gradient_clipping,
            max_gradient=max_gradient,
            max_norm=max_norm,
-           learning_rate=learning_rate,
+           learning_rate_tensor=learning_rate_tensor,
            eps=eps,
            beta1=beta1,
            beta2=beta2,

@@ -2032,7 +2035,6 @@ def forward( # noqa: C901
            momentum1,
            momentum2,
            iter_int,
-           self.use_rowwise_bias_correction,
            row_counter=(
                row_counter if self.use_rowwise_bias_correction else None
            ),

@@ -2918,7 +2920,7 @@ def _set_learning_rate(self, lr: float) -> float:
        Helper function to script `set_learning_rate`.
        Note that returning None does not work.
        """
-       self.optimizer_args = self.optimizer_args._replace(learning_rate=lr)
+       self.optimizer_args.learning_rate_tensor.fill_(lr)
        return 0.0

    @torch.jit.ignore
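With `OptimizerArgsPT2` holding the learning rate as a tensor, `_set_learning_rate` mutates it in place with `fill_` instead of rebuilding the args container via `_replace`, and readers fetch the value with `.item()` (as the updated tests below do). A minimal sketch of the pattern, independent of the module:

```python
import torch

lr_tensor = torch.tensor(0.01, device="cpu", dtype=torch.float)

# Read the current value, as the tests do via
# optimizer_args.learning_rate_tensor.item().
assert abs(lr_tensor.item() - 0.01) < 1e-6

# Update in place: every holder of a reference to this tensor sees the new
# value, so the surrounding args container never needs to be rebuilt.
lr_tensor.fill_(0.001)
assert abs(lr_tensor.item() - 0.001) < 1e-6
```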

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py (+5, -2)

@@ -544,12 +544,15 @@ def __init__(
            )
            cowclip_regularization = CowClipDefinition()

-       self.optimizer_args = invokers.lookup_args_ssd.OptimizerArgs(
+       learning_rate_tensor = torch.tensor(
+           learning_rate, device=torch.device("cpu"), dtype=torch.float
+       )
+       self.optimizer_args = invokers.lookup_args_ssd.OptimizerArgsPT2(
            stochastic_rounding=stochastic_rounding,
            gradient_clipping=gradient_clipping,
            max_gradient=max_gradient,
            max_norm=max_norm,
-           learning_rate=learning_rate,
+           learning_rate_tensor=learning_rate_tensor,
            eps=eps,
            beta1=beta1,
            beta2=beta2,

fbgemm_gpu/test/tbe/training/backward_adagrad_common.py (+4, -2)

@@ -82,8 +82,10 @@
    "use_cache": st.booleans(),
    "cache_algorithm": st.sampled_from(CacheAlgorithm),
    "use_cpu": use_cpu_strategy(),
-   "output_dtype": st.sampled_from(
-       [SparseType.FP32, SparseType.FP16, SparseType.BF16]
+   "output_dtype": (
+       st.sampled_from([SparseType.FP32, SparseType.FP16, SparseType.BF16])
+       if gpu_available
+       else st.sampled_from([SparseType.FP32, SparseType.FP16])
    ),
}

fbgemm_gpu/test/tbe/training/backward_adagrad_global_weight_decay_test.py (+1, -1)

@@ -190,7 +190,7 @@ def apply_gwd(
        apply_gwd_per_table(
            prev_iter_values,
            weights_values,
-           emb.optimizer_args.learning_rate,
+           emb.optimizer_args.learning_rate_tensor.item(),
            emb.optimizer_args.weight_decay,
            step,
            emb.current_device,

fbgemm_gpu/test/tbe/training/forward_test.py (+3)

@@ -76,6 +76,9 @@
    "test_faketensor__test_forward_gpu_uvm_cache_int8": [
        unittest.skip("Operator not implemented for Meta tensors"),
    ],
+   "test_faketensor__test_forward_cpu_fp32": [
+       unittest.skip("Operator not implemented for Meta tensors"),
+   ],
    # TODO: Make it compatible with opcheck tests
    "test_faketensor__test_forward_gpu_uvm_cache_fp16": [
        unittest.skip(

fbgemm_gpu/test/tbe/utils/split_embeddings_test.py (+1, -1)

@@ -594,7 +594,7 @@ def test_update_hyper_parameters(self) -> None:
        } | {"lr": 1.0, "lower_bound": 2.0}
        cc.update_hyper_parameters(updated_parameters)
        self.assertAlmostEqual(
-           cc.optimizer_args.learning_rate, updated_parameters["lr"]
+           cc.optimizer_args.learning_rate_tensor.item(), updated_parameters["lr"]
        )
        self.assertAlmostEqual(cc.optimizer_args.eps, updated_parameters["eps"])
        self.assertAlmostEqual(cc.optimizer_args.beta1, updated_parameters["beta1"])
