@@ -597,7 +597,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
     """

     embedding_specs: List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]
-    optimizer_args: invokers.lookup_args.OptimizerArgsPT2
+    optimizer_args: invokers.lookup_args.OptimizerArgs
     lxu_cache_locations_list: List[Tensor]
     lxu_cache_locations_empty: Tensor
     timesteps_prefetched: List[int]
@@ -926,13 +926,6 @@ def __init__(  # noqa C901
             "feature_dims",
             torch.tensor(feature_dims, device="cpu", dtype=torch.int64),
         )
-        (_info_B_num_bits, _info_B_mask) = torch.ops.fbgemm.get_infos_metadata(
-            self.D_offsets,  # unused tensor
-            1,  # max_B
-            T,  # T
-        )
-        self.info_B_num_bits: int = _info_B_num_bits
-        self.info_B_mask: int = _info_B_mask

         # A flag for indicating whether all embedding tables are placed in the
         # same locations
@@ -1077,9 +1070,6 @@ def __init__(  # noqa C901
             # which should not be effective when CounterBasedRegularizationDefinition
             # and CowClipDefinition are not used
             counter_halflife = -1
-            learning_rate_tensor = torch.tensor(
-                learning_rate, device=torch.device("cpu"), dtype=torch.float
-            )

         # TO DO: Enable this on the new interface
         # learning_rate_tensor = torch.tensor(
@@ -1095,12 +1085,12 @@ def __init__(  # noqa C901
                 "`use_rowwise_bias_correction` is only supported for OptimType.ADAM",
             )

-        self.optimizer_args = invokers.lookup_args.OptimizerArgsPT2(
+        self.optimizer_args = invokers.lookup_args.OptimizerArgs(
             stochastic_rounding=stochastic_rounding,
             gradient_clipping=gradient_clipping,
             max_gradient=max_gradient,
             max_norm=max_norm,
-            learning_rate_tensor=learning_rate_tensor,
+            learning_rate=learning_rate,
             eps=eps,
             beta1=beta1,
             beta2=beta2,
@@ -1885,7 +1875,7 @@ def forward(  # noqa: C901
             if len(self.lxu_cache_locations_list) == 0
             else self.lxu_cache_locations_list.pop(0)
         )
-        common_args = invokers.lookup_args.CommonArgsPT2(
+        common_args = invokers.lookup_args.CommonArgs(
             placeholder_autograd_tensor=self.placeholder_autograd_tensor,
             # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
             #  `Union[Module, Tensor]`.
@@ -1932,8 +1922,6 @@ def forward(  # noqa: C901
             is_experimental=self.is_experimental,
             use_uniq_cache_locations_bwd=self.use_uniq_cache_locations_bwd,
             use_homogeneous_placements=self.use_homogeneous_placements,
-            info_B_num_bits=self.info_B_num_bits,
-            info_B_mask=self.info_B_mask,
         )

         if self.optimizer == OptimType.NONE:
@@ -2046,6 +2034,7 @@ def forward(  # noqa: C901
                 momentum1,
                 momentum2,
                 iter_int,
+                self.use_rowwise_bias_correction,
                 row_counter=(
                     row_counter if self.use_rowwise_bias_correction else None
                 ),
@@ -2931,7 +2920,7 @@ def _set_learning_rate(self, lr: float) -> float:
         Helper function to script `set_learning_rate`.
         Note that returning None does not work.
         """
-        self.optimizer_args.learning_rate_tensor.fill_(lr)
+        self.optimizer_args = self.optimizer_args._replace(learning_rate=lr)
         return 0.0

     @torch.jit.ignore
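
The replacement in `_set_learning_rate` above relies on the standard `NamedTuple._replace` method, which returns a new tuple with one field changed instead of mutating a tensor in place. A minimal sketch of that pattern, using a hypothetical, reduced `OptimizerArgs` as a stand-in for `invokers.lookup_args.OptimizerArgs`:

from typing import NamedTuple

# Hypothetical stand-in with only two of the real fields.
class OptimizerArgs(NamedTuple):
    learning_rate: float
    eps: float

args = OptimizerArgs(learning_rate=0.01, eps=1e-8)
# NamedTuple instances are immutable, so the field cannot be assigned directly;
# _replace builds a new tuple with learning_rate changed and the other fields copied.
args = args._replace(learning_rate=0.05)
print(args)  # OptimizerArgs(learning_rate=0.05, eps=1e-08)

Because `_replace` rebinds `self.optimizer_args` rather than filling a tensor, the learning rate stays a plain Python float, matching the `learning_rate=learning_rate` field used when the args are first constructed.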
@@ -3446,22 +3435,6 @@ def prepare_inputs(
             offsets, batch_size_per_feature_per_rank
         )

-        vbe = vbe_metadata.B_offsets is not None
-        # TODO: assert vbe_metadata.B_offsets.numel() - 1 == T
-        # T = self.D_offsets.numel() - 1
-        # vbe_metadata.B_offsets causes jit to fail for cogwheel forward compatibility test
-        # max_B = int(vbe_metadata.max_B) if vbe else int(offsets.numel() - 1 / T)
-
-        # TODO: max_B <= self.info_B_mask
-        # cannot use assert as it breaks pt2 compile for dynamic shape
-        # Need to use torch._check for dynamic shape and cannot construct fstring, use constant string.
-        # cannot use lambda as it fails jit script.
-        # torch._check is not supported in jitscript
-        # torch._check(
-        #     max_B <= self.info_B_mask,
-        #     "Not enough infos bits to accommodate T and B.",
-        # )
-
         # TODO: remove this and add an assert after updating
         # bounds_check_indices to support different indices type and offset
         # type
@@ -3489,6 +3462,7 @@ def prepare_inputs(
             per_sample_weights = per_sample_weights.float()

         if self.bounds_check_mode_int != BoundsCheckMode.NONE.value:
+            vbe = vbe_metadata.B_offsets is not None
             # Compute B info and VBE metadata for bounds_check_indices only if
             # VBE and bounds check indices v2 are used
             if vbe and self.use_bounds_check_v2:
@@ -3502,7 +3476,11 @@ def prepare_inputs(
                 assert isinstance(
                     output_offsets_feature_rank, Tensor
                 ), "output_offsets_feature_rank must be tensor"
-
+                info_B_num_bits, info_B_mask = torch.ops.fbgemm.get_infos_metadata(
+                    B_offsets,  # unused tensor
+                    vbe_metadata.max_B,
+                    B_offsets.numel() - 1,  # T
+                )
                 row_output_offsets, b_t_map = torch.ops.fbgemm.generate_vbe_metadata(
                     B_offsets,
                     B_offsets_rank_per_feature,
@@ -3511,11 +3489,13 @@ def prepare_inputs(
                     self.max_D,
                     self.is_nobag,
                     vbe_metadata.max_B_feature_rank,
-                    self.info_B_num_bits,
+                    info_B_num_bits,
                     offsets.numel() - 1,  # total_B
                 )
             else:
                 b_t_map = None
+                info_B_num_bits = -1
+                info_B_mask = -1

             torch.ops.fbgemm.bounds_check_indices(
                 self.rows_per_table,
@@ -3527,8 +3507,8 @@ def prepare_inputs(
                 B_offsets=vbe_metadata.B_offsets,
                 max_B=vbe_metadata.max_B,
                 b_t_map=b_t_map,
-                info_B_num_bits=self.info_B_num_bits,
-                info_B_mask=self.info_B_mask,
+                info_B_num_bits=info_B_num_bits,
+                info_B_mask=info_B_mask,
                 bounds_check_version=self.bounds_check_version,
             )

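
For context on the `info_B_num_bits` / `info_B_mask` pair that `prepare_inputs` now derives on demand from `torch.ops.fbgemm.get_infos_metadata` (instead of caching it in `__init__`): the pair presumably describes how many low bits of a packed `b_t` value hold the batch index, with the table index stored in the bits above. A rough, illustrative Python sketch of that packing idea, assuming this interpretation of the metadata (the actual FBGEMM layout may differ):

# Illustrative only: assumes b_t packs the batch index b in the low
# info_B_num_bits bits and the table index t in the bits above.
def pack_b_t(t: int, b: int, info_B_num_bits: int) -> int:
    return (t << info_B_num_bits) | b

def unpack_b_t(b_t: int, info_B_num_bits: int, info_B_mask: int) -> "tuple[int, int]":
    b = b_t & info_B_mask            # info_B_mask == (1 << info_B_num_bits) - 1
    t = b_t >> info_B_num_bits
    return t, b

info_B_num_bits = 4                  # enough bits to represent max_B - 1
info_B_mask = (1 << info_B_num_bits) - 1
b_t = pack_b_t(t=3, b=5, info_B_num_bits=info_B_num_bits)
assert unpack_b_t(b_t, info_B_num_bits, info_B_mask) == (3, 5)

Under that reading, the `-1` values assigned in the non-VBE branch act as sentinels marking the metadata as unused before `bounds_check_indices` is called.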