Unifying TBE API using List (Frontend) (pytorch#3711)
Summary:
X-link: pytorch/torchrec#2751
X-link: facebookresearch/FBGEMM#793
**Backend**: D68054868
---
As the number of arguments in TBE keeps growing, some of the optimizers run into the argument-count limit (i.e., 64) during PyTorch operator registration.
**For long-term growth and maintenance, we hence redesign the TBE API by packing some of the arguments into lists. Note that not all arguments are packed.**
We pack the arguments into one list per type.
For **common** arguments (see the sketch after this list), we pack
- weights and arguments of type `Momentum` into a `TensorList`
- other tensors and optional tensors into a list of optional tensors, `aux_tensor`
- `int` arguments into `aux_int`
- `float` arguments into `aux_float`
- `bool` arguments into `aux_bool`.
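A rough Python sketch of the common-argument packing (the individual argument names such as `indice_weights`, `iter_`, `max_gradient`, and `stochastic_rounding` are hypothetical placeholders; only the list names come from the design above, and the real ordering of each list is defined by the codegen and the design doc):

```python
from typing import List, Optional
import torch

def pack_common_args(
    weights_host: torch.Tensor,               # weight tensors -> TensorList
    weights_dev: torch.Tensor,
    indice_weights: Optional[torch.Tensor],   # hypothetical optional tensor
    iter_: int,                               # hypothetical int argument
    max_gradient: float,                      # hypothetical float argument
    stochastic_rounding: bool,                # hypothetical bool argument
):
    # Weights (and Momentum-typed args) are grouped into a single TensorList.
    weights: List[torch.Tensor] = [weights_host, weights_dev]
    # Other tensors and optional tensors become one list of optional tensors.
    aux_tensor: List[Optional[torch.Tensor]] = [indice_weights]
    # Scalar arguments are grouped by type.
    aux_int: List[int] = [iter_]
    aux_float: List[float] = [max_gradient]
    aux_bool: List[bool] = [stochastic_rounding]
    return weights, aux_tensor, aux_int, aux_float, aux_bool
```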
Similarly, for **optimizer-specific** arguments (see the second sketch after this list), we pack
- arguments of type `Momentum` that are *__not__ optional* into a `TensorList`
- *optional* tensors into a list of optional tensors, `optim_tensor`
- `int` arguments into `optim_int`
- `float` arguments into `optim_float`
- `bool` arguments into `optim_bool`.
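And a hypothetical sketch for an optimizer such as rowwise Adagrad with counter (again, the individual argument names are illustrative only; the actual contents and ordering of each list come from the codegen and the design doc):

```python
from typing import List, Optional
import torch

def pack_optimizer_args(
    momentum1_dev: torch.Tensor,         # non-optional Momentum state -> TensorList
    momentum1_uvm: torch.Tensor,
    prev_iter: Optional[torch.Tensor],   # optional tensors -> optim_tensor
    row_counter: Optional[torch.Tensor],
    counter_halflife: int,               # ints -> optim_int
    max_counter: float,                  # floats -> optim_float
    use_bias_correction: bool,           # bools -> optim_bool
):
    momentum1: List[torch.Tensor] = [momentum1_dev, momentum1_uvm]
    optim_tensor: List[Optional[torch.Tensor]] = [prev_iter, row_counter]
    optim_int: List[int] = [counter_halflife]
    optim_float: List[float] = [max_counter]
    optim_bool: List[bool] = [use_bias_correction]
    # Note: SymInt arguments are not packed; they are still passed to the op
    # one by one (see the note on SymInt below).
    return momentum1, optim_tensor, optim_int, optim_float, optim_bool
```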
We saw issues with PyTorch registration when packing `SymInt` arguments into a list across the Python-C++ boundary, so we unroll them and pass each `SymInt` argument individually.
**This significantly reduces the number of arguments.** For example, `split_embedding_codegen_lookup_rowwise_adagrad_with_counter_function`, which currently has 61 arguments, has only 26 arguments with this API design.
Please refer to the design doc for details on which arguments are packed and on the resulting signatures.
Design doc:
https://docs.google.com/document/d/1dCBg7dcf7Yq9FHVrvXsAmFtBxkDi9o6u0r-Ptd4UDPE/edit?tab=t.0#heading=h.6bip5pwqq8xb
The full signature for each optimizer lookup function will be provided shortly.
Reviewed By: sryap, nautsimon
Differential Revision: D68055168
{%- if"learning_rate" in args.split_kernel_arg_names %}
603
603
// convert `learning rate` to float since `learning rate` is float in kernels
604
+
TORCH_CHECK(learning_rate_tensor.is_cpu(), "learning_rate_tensor tensor needs to be on CPU. Ensure learning_rate_tensor is on CPU or contact FBGEMM team if you get this error.")
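A minimal Python-side sketch of the assumption this check guards, assuming the frontend passes the learning rate as a 0-dim CPU float tensor (variable names here are illustrative, not the actual frontend code):

```python
import torch

# The learning rate travels as a scalar CPU tensor and is read back as a plain
# float on the host before the kernels (which take a float) are launched.
learning_rate_tensor = torch.tensor(0.01, dtype=torch.float32)  # stays on CPU

assert learning_rate_tensor.is_cpu           # mirrors the C++ TORCH_CHECK above
learning_rate = learning_rate_tensor.item()  # float value handed to the kernels
```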