Skip to content

Commit 864df42

Browse files
minhua-chen and facebook-github-bot
authored and committed
AdagradW (pytorch#3605)
Summary: Adds CounterWeightDecayMode.SQRT (weight_decay_mode == 3, "AdagradW") to the rowwise Adagrad-with-counter optimizer.
Differential Revision: D67625467
1 parent 428e671 commit 864df42

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

fbgemm_gpu/codegen/genscript/optimizers.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,11 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
473473
const auto iter_delta = prev_iter[idx] == 0 ? 1.0 : iter * 1.0 - prev_iter[idx];
474474
prev_iter[idx] = iter * 1.0;
475475
const auto counter_log_rho = logf(2.0) / counter_halflife;
476-
row_counter[idx] = 1.0 + expf(-iter_delta * counter_log_rho) * row_counter[idx];
476+
if (regularization_mode == 3 && weight_decay_mode == 3) {
477+
tail_id_threshold_val = iter_delta;
478+
} else {
479+
row_counter[idx] = 1.0 + expf(-iter_delta * counter_log_rho) * row_counter[idx];
480+
}
477481
} else if (counter_halflife == 0) { // count only 1 (appear or not)
478482
row_counter[idx] = 1.0;
479483
} else { // count raw appearance without decaying
@@ -552,7 +556,13 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
552556
exp_reg_correction = 1.0;
553557
if (regularization_mode == 3) { // counter-based regularization (regularization_mode=3)
554558
if (adjustment_enabled) {
555-
if (weight_decay_mode == 2) { // Decoupled weight decay (weight_decay_mode=2)
559+
if (weight_decay_mode == 3) { // AdagradW (weight_decay_mode=3)
560+
freq = min(tail_id_threshold_val, iter*1.0 - adjustment_iter);
561+
exp_reg_correction = 1.0 - weight_decay * learning_rate / sqrtf(iter*1.0);
562+
freq = expf(- weight_decay * learning_rate * 2.0 * (sqrtf(iter*1.0) - sqrtf(iter*1.0 - freq + 1.0)));
563+
adjusted_multiplier *= freq; // lazy update
564+
exp_reg_correction *= freq;
565+
} else if (weight_decay_mode == 2) { // Decoupled weight decay (weight_decay_mode=2)
556566
exp_reg_correction = 1.0 - freq * weight_decay * learning_rate;
557567
} else if (weight_decay_mode == 1) { // L2 regularization (coupled wd)
558568
exp_reg_correction = 1.0 - freq * weight_decay * multiplier;

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

+1
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ class CounterWeightDecayMode(enum.IntEnum):
9999
NONE = 0
100100
L2 = 1
101101
DECOUPLE = 2
102+
SQRT = 3
102103

103104

104105
class StepMode(enum.IntEnum):

0 commit comments

Comments
 (0)