
Commit f1fe624

spcyppt authored and facebook-github-bot committed
Backout
Summary:
X-link: facebookresearch/FBGEMM#887

Backout D68055168 as it seems to break pyper and causes S498612.

Differential Revision: D70996903
1 parent 147e35c commit f1fe624

9 files changed: +319 −350 lines

fbgemm_gpu/codegen/genscript/generate_backward_split.py (+1 −1)

```diff
@@ -447,7 +447,7 @@ def generate() -> None:
             ssd_optimizers.append(optim)

         BackwardSplitGenerator.generate_backward_split(
-            ssd_tensors=ssd_tensors, aux_args=aux_args, **optimizer
+            ssd_tensors=ssd_tensors, **optimizer
         )
     BackwardSplitGenerator.generate_rocm_backward_split()
```

fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu (−1)

```diff
@@ -601,7 +601,6 @@ Tensor {{ embedding_cuda_op }}(

    {%- if "learning_rate" in args.split_kernel_arg_names %}
    // convert `learning rate` to float since `learning rate` is float in kernels
-   TORCH_CHECK(learning_rate_tensor.is_cpu(), "learning_rate_tensor tensor needs to be on CPU. Ensure learning_rate_tensor is on CPU or contact FBGEMM team if you get this error.")
    const float learning_rate = learning_rate_tensor.item<float>();
    {%- endif %}
```
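The removed `TORCH_CHECK` guarded the host-side scalar read `learning_rate_tensor.item<float>()`, which is only a cheap read when the tensor already lives on the CPU; calling `.item()` on a GPU tensor forces a device-to-host copy and a sync. A tiny Python-side analogue of the same pattern, with stand-in names rather than the real template variables:

```python
# Illustrative analogue of the removed check; names are stand-ins, the real
# guard lived in the CUDA template above.
import torch

learning_rate_tensor = torch.tensor(0.01, device="cpu", dtype=torch.float)

# Equivalent of TORCH_CHECK(learning_rate_tensor.is_cpu(), ...):
assert learning_rate_tensor.is_cpu, "learning_rate_tensor needs to be on CPU"

# Equivalent of learning_rate_tensor.item<float>() in the kernel wrapper:
learning_rate = float(learning_rate_tensor.item())
```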

fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template (+293 −287)

Large diffs are not rendered by default.

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py (+17 −37)

```diff
@@ -597,7 +597,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
     """

     embedding_specs: List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]
-    optimizer_args: invokers.lookup_args.OptimizerArgsPT2
+    optimizer_args: invokers.lookup_args.OptimizerArgs
     lxu_cache_locations_list: List[Tensor]
     lxu_cache_locations_empty: Tensor
     timesteps_prefetched: List[int]
@@ -926,13 +926,6 @@ def __init__(  # noqa C901
             "feature_dims",
             torch.tensor(feature_dims, device="cpu", dtype=torch.int64),
         )
-        (_info_B_num_bits, _info_B_mask) = torch.ops.fbgemm.get_infos_metadata(
-            self.D_offsets,  # unused tensor
-            1,  # max_B
-            T,  # T
-        )
-        self.info_B_num_bits: int = _info_B_num_bits
-        self.info_B_mask: int = _info_B_mask

         # A flag for indicating whether all embedding tables are placed in the
         # same locations
@@ -1077,9 +1070,6 @@ def __init__(  # noqa C901
        # which should not be effective when CounterBasedRegularizationDefinition
        # and CowClipDefinition are not used
        counter_halflife = -1
-        learning_rate_tensor = torch.tensor(
-            learning_rate, device=torch.device("cpu"), dtype=torch.float
-        )

        # TO DO: Enable this on the new interface
        # learning_rate_tensor = torch.tensor(
@@ -1095,12 +1085,12 @@ def __init__(  # noqa C901
             "`use_rowwise_bias_correction` is only supported for OptimType.ADAM",
         )

-        self.optimizer_args = invokers.lookup_args.OptimizerArgsPT2(
+        self.optimizer_args = invokers.lookup_args.OptimizerArgs(
             stochastic_rounding=stochastic_rounding,
             gradient_clipping=gradient_clipping,
             max_gradient=max_gradient,
             max_norm=max_norm,
-            learning_rate_tensor=learning_rate_tensor,
+            learning_rate=learning_rate,
             eps=eps,
             beta1=beta1,
             beta2=beta2,
```
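The hunks above revert the optimizer-argument container from `OptimizerArgsPT2`, which carried the learning rate as a CPU scalar tensor, back to `OptimizerArgs`, which carries it as a plain float. A minimal sketch of the two shapes, using made-up, simplified field lists rather than the real generated classes in `invokers.lookup_args`:

```python
# Illustrative sketch only -- not the real FBGEMM definitions. It assumes the
# args containers are NamedTuples and shows only a few fields.
from typing import NamedTuple

import torch


class OptimizerArgsSketch(NamedTuple):
    """Post-backout shape: the learning rate is a plain Python float."""
    stochastic_rounding: bool
    learning_rate: float
    eps: float


class OptimizerArgsPT2Sketch(NamedTuple):
    """Backed-out shape: the learning rate travels as a CPU scalar tensor."""
    stochastic_rounding: bool
    learning_rate_tensor: torch.Tensor
    eps: float


# After the revert, construction uses the float directly:
args = OptimizerArgsSketch(stochastic_rounding=True, learning_rate=0.01, eps=1e-8)

# The backed-out variant wrapped it in a CPU tensor first:
lr_tensor = torch.tensor(0.01, device="cpu", dtype=torch.float)
args_pt2 = OptimizerArgsPT2Sketch(
    stochastic_rounding=True, learning_rate_tensor=lr_tensor, eps=1e-8
)
```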
```diff
@@ -1885,7 +1875,7 @@ def forward(  # noqa: C901
             if len(self.lxu_cache_locations_list) == 0
             else self.lxu_cache_locations_list.pop(0)
         )
-        common_args = invokers.lookup_args.CommonArgsPT2(
+        common_args = invokers.lookup_args.CommonArgs(
             placeholder_autograd_tensor=self.placeholder_autograd_tensor,
             # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
             # `Union[Module, Tensor]`.
@@ -1932,8 +1922,6 @@ def forward(  # noqa: C901
             is_experimental=self.is_experimental,
             use_uniq_cache_locations_bwd=self.use_uniq_cache_locations_bwd,
             use_homogeneous_placements=self.use_homogeneous_placements,
-            info_B_num_bits=self.info_B_num_bits,
-            info_B_mask=self.info_B_mask,
         )

         if self.optimizer == OptimType.NONE:
@@ -2046,6 +2034,7 @@ def forward(  # noqa: C901
                 momentum1,
                 momentum2,
                 iter_int,
+                self.use_rowwise_bias_correction,
                 row_counter=(
                     row_counter if self.use_rowwise_bias_correction else None
                 ),
@@ -2931,7 +2920,7 @@ def _set_learning_rate(self, lr: float) -> float:
        Helper function to script `set_learning_rate`.
        Note that returning None does not work.
        """
-        self.optimizer_args.learning_rate_tensor.fill_(lr)
+        self.optimizer_args = self.optimizer_args._replace(learning_rate=lr)
        return 0.0

    @torch.jit.ignore
```
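`_set_learning_rate` switches from mutating the shared CPU tensor in place (`fill_`) back to rebuilding the args tuple with `_replace`. Assuming the args container is an immutable NamedTuple (the stand-ins below are illustrative, not the generated FBGEMM classes), the two update styles look like this:

```python
# Minimal sketch of the two learning-rate update styles involved in this
# revert; the containers are stand-ins, not FBGEMM's OptimizerArgs classes.
from typing import NamedTuple

import torch


class ArgsWithFloatLR(NamedTuple):
    learning_rate: float


class ArgsWithTensorLR(NamedTuple):
    learning_rate_tensor: torch.Tensor


# Post-backout: NamedTuples are immutable, so updating the learning rate means
# rebuilding the tuple, which is what _set_learning_rate now does via _replace().
args = ArgsWithFloatLR(learning_rate=0.01)
args = args._replace(learning_rate=0.02)
assert args.learning_rate == 0.02

# Backed-out variant: the tensor field was mutated in place with fill_(), so
# every consumer holding a reference to the same tensor saw the new value
# without the tuple itself ever being rebuilt.
args_pt2 = ArgsWithTensorLR(learning_rate_tensor=torch.tensor(0.01))
args_pt2.learning_rate_tensor.fill_(0.02)
assert abs(args_pt2.learning_rate_tensor.item() - 0.02) < 1e-6
```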
```diff
@@ -3446,22 +3435,6 @@ def prepare_inputs(
             offsets, batch_size_per_feature_per_rank
         )

-        vbe = vbe_metadata.B_offsets is not None
-        # TODO: assert vbe_metadata.B_offsets.numel() - 1 == T
-        # T = self.D_offsets.numel() - 1
-        # vbe_metadata.B_offsets causes jit to fail for cogwheel forward compatibility test
-        # max_B = int(vbe_metadata.max_B) if vbe else int(offsets.numel() - 1 / T)
-
-        # TODO: max_B <= self.info_B_mask
-        # cannot use assert as it breaks pt2 compile for dynamic shape
-        # Need to use torch._check for dynamic shape and cannot construct fstring, use constant string.
-        # cannot use lambda as it fails jit script.
-        # torch._check is not supported in jitscript
-        # torch._check(
-        #     max_B <= self.info_B_mask,
-        #     "Not enough infos bits to accommodate T and B.",
-        # )
-
         # TODO: remove this and add an assert after updating
         # bounds_check_indices to support different indices type and offset
         # type
@@ -3489,6 +3462,7 @@ def prepare_inputs(
             per_sample_weights = per_sample_weights.float()

         if self.bounds_check_mode_int != BoundsCheckMode.NONE.value:
+            vbe = vbe_metadata.B_offsets is not None
             # Compute B info and VBE metadata for bounds_check_indices only if
             # VBE and bounds check indices v2 are used
             if vbe and self.use_bounds_check_v2:
@@ -3502,7 +3476,11 @@ def prepare_inputs(
                 assert isinstance(
                     output_offsets_feature_rank, Tensor
                 ), "output_offsets_feature_rank must be tensor"
-
+                info_B_num_bits, info_B_mask = torch.ops.fbgemm.get_infos_metadata(
+                    B_offsets,  # unused tensor
+                    vbe_metadata.max_B,
+                    B_offsets.numel() - 1,  # T
+                )
                 row_output_offsets, b_t_map = torch.ops.fbgemm.generate_vbe_metadata(
                     B_offsets,
                     B_offsets_rank_per_feature,
@@ -3511,11 +3489,13 @@ def prepare_inputs(
                     self.max_D,
                     self.is_nobag,
                     vbe_metadata.max_B_feature_rank,
-                    self.info_B_num_bits,
+                    info_B_num_bits,
                     offsets.numel() - 1,  # total_B
                 )
             else:
                 b_t_map = None
+                info_B_num_bits = -1
+                info_B_mask = -1

             torch.ops.fbgemm.bounds_check_indices(
                 self.rows_per_table,
@@ -3527,8 +3507,8 @@ def prepare_inputs(
                 B_offsets=vbe_metadata.B_offsets,
                 max_B=vbe_metadata.max_B,
                 b_t_map=b_t_map,
-                info_B_num_bits=self.info_B_num_bits,
-                info_B_mask=self.info_B_mask,
+                info_B_num_bits=info_B_num_bits,
+                info_B_mask=info_B_mask,
                 bounds_check_version=self.bounds_check_version,
             )

```
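The `prepare_inputs` hunks restore deriving `info_B_num_bits`/`info_B_mask` per call instead of caching them in `__init__`: they are computed only when VBE is active and bounds-check v2 is enabled, and fall back to -1 otherwise before being passed to `torch.ops.fbgemm.bounds_check_indices`. A hedged sketch of just that branch, with everything outside the diff stubbed out; it assumes `fbgemm_gpu` is installed so the operator shown in the diff is registered:

```python
# Illustrative helper mirroring the reverted branch structure in
# prepare_inputs(); it is not FBGEMM code, only the control flow from the diff.
from typing import Optional, Tuple

import torch


def resolve_info_b(
    B_offsets: Optional[torch.Tensor],
    max_B: int,
    use_bounds_check_v2: bool,
) -> Tuple[int, int]:
    """Return (info_B_num_bits, info_B_mask) for bounds_check_indices.

    Post-backout, the values are derived per call: only when VBE is in use and
    bounds-check v2 is enabled, and -1/-1 otherwise.
    """
    vbe = B_offsets is not None
    if vbe and use_bounds_check_v2:
        info_B_num_bits, info_B_mask = torch.ops.fbgemm.get_infos_metadata(
            B_offsets,  # unused tensor
            max_B,
            B_offsets.numel() - 1,  # T
        )
        return info_B_num_bits, info_B_mask
    return -1, -1
```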

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py (+4 −15)

```diff
@@ -228,12 +228,6 @@ def __init__(
             torch.tensor(feature_dims, device="cpu", dtype=torch.int64),
         )

-        (self.info_B_num_bits, self.info_B_mask) = torch.ops.fbgemm.get_infos_metadata(
-            self.D_offsets,  # unused tensor
-            1,  # max_B
-            T,  # T
-        )
-
         assert cache_sets > 0
         element_size = weights_precision.bit_rate() // 8
         assert (
@@ -550,15 +544,12 @@ def __init__(
         )
         cowclip_regularization = CowClipDefinition()

-        learning_rate_tensor = torch.tensor(
-            learning_rate, device=torch.device("cpu"), dtype=torch.float
-        )
-        self.optimizer_args = invokers.lookup_args_ssd.OptimizerArgsPT2(
+        self.optimizer_args = invokers.lookup_args_ssd.OptimizerArgs(
             stochastic_rounding=stochastic_rounding,
             gradient_clipping=gradient_clipping,
             max_gradient=max_gradient,
             max_norm=max_norm,
-            learning_rate_tensor=learning_rate_tensor,
+            learning_rate=learning_rate,
             eps=eps,
             beta1=beta1,
             beta2=beta2,
@@ -1639,7 +1630,7 @@ def forward(
             vbe_metadata.max_B,
         )

-        common_args = invokers.lookup_args_ssd.CommonArgsPT2(
+        common_args = invokers.lookup_args_ssd.CommonArgs(
             placeholder_autograd_tensor=self.placeholder_autograd_tensor,
             output_dtype=self.output_dtype,
             dev_weights=self.weights_dev,
@@ -1674,8 +1665,6 @@ def forward(
             },
             # pyre-fixme[6]: Expected `lookup_args_ssd.VBEMetadata` but got `lookup_args.VBEMetadata`
             vbe_metadata=vbe_metadata,
-            info_B_num_bits=self.info_B_num_bits,
-            info_B_mask=self.info_B_mask,
         )

         self.timesteps_prefetched.pop(0)
@@ -1826,7 +1815,7 @@ def _set_learning_rate(self, lr: float) -> float:
        Helper function to script `set_learning_rate`.
        Note that returning None does not work.
        """
-        self.optimizer_args.learning_rate_tensor.fill_(lr)
+        self.optimizer_args = self.optimizer_args._replace(learning_rate=lr)
        return 0.0

    def flush(self) -> None:
```

fbgemm_gpu/test/tbe/training/backward_adagrad_common.py (+2 −4)

```diff
@@ -82,10 +82,8 @@
     "use_cache": st.booleans(),
     "cache_algorithm": st.sampled_from(CacheAlgorithm),
     "use_cpu": use_cpu_strategy(),
-    "output_dtype": (
-        st.sampled_from([SparseType.FP32, SparseType.FP16, SparseType.BF16])
-        if gpu_available
-        else st.sampled_from([SparseType.FP32, SparseType.FP16])
+    "output_dtype": st.sampled_from(
+        [SparseType.FP32, SparseType.FP16, SparseType.BF16]
     ),
 }
```
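This test's search space is a dict of hypothesis strategies, and the backout restores unconditional sampling of BF16 for `output_dtype`. A small, self-contained sketch of how such a strategy dict is typically unpacked into `@given`, with stand-in enum values instead of FBGEMM's `SparseType`:

```python
# Illustrative only: the enum and strategy values below are stand-ins, not
# FBGEMM's SparseType/CacheAlgorithm; it shows the @given(**dict) pattern used
# by backward_adagrad_common.py.
import enum
import unittest

from hypothesis import given, settings, strategies as st


class FakeDtype(enum.Enum):
    FP32 = "fp32"
    FP16 = "fp16"
    BF16 = "bf16"


common_strategy = {
    "use_cache": st.booleans(),
    # Post-backout style: BF16 is always part of the sampled space.
    "output_dtype": st.sampled_from([FakeDtype.FP32, FakeDtype.FP16, FakeDtype.BF16]),
}


class StrategyDictTest(unittest.TestCase):
    @given(**common_strategy)
    @settings(deadline=None, max_examples=10)
    def test_draws(self, use_cache: bool, output_dtype: FakeDtype) -> None:
        self.assertIsInstance(use_cache, bool)
        self.assertIn(output_dtype, FakeDtype)


if __name__ == "__main__":
    unittest.main()
```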

fbgemm_gpu/test/tbe/training/backward_adagrad_global_weight_decay_test.py (+1 −1)

```diff
@@ -190,7 +190,7 @@ def apply_gwd(
         apply_gwd_per_table(
             prev_iter_values,
             weights_values,
-            emb.optimizer_args.learning_rate_tensor.item(),
+            emb.optimizer_args.learning_rate,
             emb.optimizer_args.weight_decay,
             step,
             emb.current_device,
```

fbgemm_gpu/test/tbe/training/forward_test.py (−3)

```diff
@@ -76,9 +76,6 @@
     "test_faketensor__test_forward_gpu_uvm_cache_int8": [
         unittest.skip("Operator not implemented for Meta tensors"),
     ],
-    "test_faketensor__test_forward_cpu_fp32": [
-        unittest.skip("Operator not implemented for Meta tensors"),
-    ],
     # TODO: Make it compatible with opcheck tests
     "test_faketensor__test_forward_gpu_uvm_cache_fp16": [
         unittest.skip(
```

fbgemm_gpu/test/tbe/utils/split_embeddings_test.py (+1 −1)

```diff
@@ -594,7 +594,7 @@ def test_update_hyper_parameters(self) -> None:
         } | {"lr": 1.0, "lower_bound": 2.0}
         cc.update_hyper_parameters(updated_parameters)
         self.assertAlmostEqual(
-            cc.optimizer_args.learning_rate_tensor.item(), updated_parameters["lr"]
+            cc.optimizer_args.learning_rate, updated_parameters["lr"]
         )
         self.assertAlmostEqual(cc.optimizer_args.eps, updated_parameters["eps"])
         self.assertAlmostEqual(cc.optimizer_args.beta1, updated_parameters["beta1"])
```
