@@ -473,7 +473,11 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
473
473
const auto iter_delta = prev_iter[idx] == 0 ? 1.0 : iter * 1.0 - prev_iter[idx];
474
474
prev_iter[idx] = iter * 1.0;
475
475
const auto counter_log_rho = logf(2.0) / counter_halflife;
476
- row_counter[idx] = 1.0 + expf(-iter_delta * counter_log_rho) * row_counter[idx];
476
+ if (regularization_mode == 3 && weight_decay_mode == 3) {
477
+ tail_id_threshold_val = iter_delta;
478
+ } else {
479
+ row_counter[idx] = 1.0 + expf(-iter_delta * counter_log_rho) * row_counter[idx];
480
+ }
477
481
} else if (counter_halflife == 0) { // count only 1 (appear or not)
478
482
row_counter[idx] = 1.0;
479
483
} else { // count raw appearance without decaying
@@ -552,7 +556,13 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
552
556
exp_reg_correction = 1.0;
553
557
if (regularization_mode == 3) { // counter-based regularization (regularization_mode=3)
554
558
if (adjustment_enabled) {
555
- if (weight_decay_mode == 2) { // Decoupled weight decay (weight_decay_mode=2)
559
+ if (weight_decay_mode == 3) { // AdagradW (weight_decay_mode=3)
560
+ freq = min(tail_id_threshold_val, iter*1.0 - adjustment_iter);
561
+ exp_reg_correction = 1.0 - weight_decay * learning_rate / sqrtf(iter*1.0);
562
+ freq = expf(- weight_decay * learning_rate * 2.0 * (sqrtf(iter*1.0) - sqrtf(iter*1.0 - freq + 1.0)));
563
+ adjusted_multiplier *= freq; // lazy update
564
+ exp_reg_correction *= freq;
565
+ } else if (weight_decay_mode == 2) { // Decoupled weight decay (weight_decay_mode=2)
556
566
exp_reg_correction = 1.0 - freq * weight_decay * learning_rate;
557
567
} else if (weight_decay_mode == 1) { // L2 regularization (coupled wd)
558
568
exp_reg_correction = 1.0 - freq * weight_decay * multiplier;
0 commit comments