Commit 5a6a018

Merge pull request #1682 from iretes/IL2M
Added IL2M strategy
2 parents 625e46d + 62c8fd3 commit 5a6a018

File tree: 5 files changed, +300 -0 lines changed

avalanche/training/plugins/__init__.py (+1)

@@ -26,3 +26,4 @@
 from .update_ncm import *
 from .update_fecam import *
 from .feature_distillation import *
+from .il2m import IL2MPlugin

avalanche/training/plugins/il2m.py (new file, +191 lines)

from typing import Optional

from packaging.version import parse
import torch
import numpy as np

from avalanche.training.templates import SupervisedTemplate
from avalanche.training.plugins.strategy_plugin import SupervisedPlugin
from avalanche.training.storage_policy import ExemplarsBuffer, ExperienceBalancedBuffer
from avalanche.benchmarks.utils.data_loader import ReplayDataLoader


class IL2MPlugin(SupervisedPlugin):
    """
    Class Incremental Learning With Dual Memory (IL2M) plugin.

    Technique introduced in:
    Belouadah, E. and Popescu, A. "IL2M: Class Incremental Learning With Dual
    Memory." Proceedings of the IEEE/CVF Conference on Computer Vision and
    Pattern Recognition. 2019.

    Implementation based on FACIL, as in:
    https://github.com/mmasana/FACIL/blob/master/src/approach/il2m.py
    """

    def __init__(
        self,
        mem_size: int = 2000,
        batch_size: Optional[int] = None,
        batch_size_mem: Optional[int] = None,
        storage_policy: Optional[ExemplarsBuffer] = None,
    ):
        """
        :param mem_size: replay buffer size.
        :param batch_size: the size of the data batch. If set to `None`, it
            will be set equal to the strategy's batch size.
        :param batch_size_mem: the size of the memory batch. If its value is
            set to `None` (the default value), it will be automatically set
            equal to the data batch size.
        :param storage_policy: The policy that controls how to add new
            exemplars in memory.
        """

        super().__init__()
        self.mem_size = mem_size
        self.batch_size = batch_size
        self.batch_size_mem = batch_size_mem

        if storage_policy is not None:  # Use other storage policy
            self.storage_policy = storage_policy
            assert storage_policy.max_size == self.mem_size
        else:  # Default
            self.storage_policy = ExperienceBalancedBuffer(
                max_size=self.mem_size, adaptive_size=True
            )

        # to store statistics for the classes as learned in the current incremental state
        self.current_classes_means = []
        # to store statistics for past classes as learned in the incremental state in which they were first seen
        self.init_classes_means = []
        # to store statistics for model confidence in different states (i.e. avg top-1 pred scores)
        self.models_confidence = []
        # to store the mapping between classes and the incremental state in which they were first seen
        self.classes2exp = []
        # total number of classes that will be seen
        self.n_classes = 0

    def before_training_exp(
        self,
        strategy: SupervisedTemplate,
        num_workers: int = 0,
        shuffle: bool = True,
        drop_last: bool = False,
        **kwargs
    ):

        if len(self.init_classes_means) == 0:
            self.n_classes = len(strategy.experience.classes_seen_so_far) + len(
                strategy.experience.future_classes
            )
            self.init_classes_means = [0 for _ in range(self.n_classes)]
            self.classes2exp = [-1 for _ in range(self.n_classes)]

        if len(self.storage_policy.buffer) == 0:
            # first experience. We don't use the buffer, no need to change
            # the dataloader.
            return

        batch_size = self.batch_size
        if batch_size is None:
            batch_size = strategy.train_mb_size

        batch_size_mem = self.batch_size_mem
        if batch_size_mem is None:
            batch_size_mem = strategy.train_mb_size

        assert strategy.adapted_dataset is not None

        other_dataloader_args = dict()

        if "ffcv_args" in kwargs:
            other_dataloader_args["ffcv_args"] = kwargs["ffcv_args"]

        if "persistent_workers" in kwargs:
            if parse(torch.__version__) >= parse("1.7.0"):
                other_dataloader_args["persistent_workers"] = kwargs[
                    "persistent_workers"
                ]

        strategy.dataloader = ReplayDataLoader(
            strategy.adapted_dataset,
            self.storage_policy.buffer,
            oversample_small_tasks=True,
            batch_size=batch_size,
            batch_size_mem=batch_size_mem,
            num_workers=num_workers,
            shuffle=shuffle,
            drop_last=drop_last,
            **other_dataloader_args
        )

    def after_training_exp(self, strategy: SupervisedTemplate, **kwargs):
        experience = strategy.experience
        self.current_classes_means = [0 for _ in range(self.n_classes)]
        classes_counts = [0 for _ in range(self.n_classes)]
        self.models_confidence.append(0)
        models_counts = 0

        # compute the mean prediction scores that will be used to rectify scores in subsequent incremental states
        with torch.no_grad():
            strategy.model.eval()
            for inputs, targets, _ in strategy.dataloader:
                inputs, targets = inputs.to(strategy.device), targets.to(
                    strategy.device
                )
                outputs = strategy.model(inputs.to(strategy.device))
                scores = outputs.data.cpu().numpy()
                for i in range(len(targets)):
                    target = targets[i].item()
                    classes_counts[target] += 1
                    if target in experience.previous_classes:
                        # compute the mean prediction scores for past classes of the current state
                        self.current_classes_means[target] += scores[i, target]
                    else:
                        # compute the mean prediction scores for the new classes of the current state
                        self.init_classes_means[target] += scores[i, target]
                        # compute the mean top scores for the new classes of the current state
                        self.models_confidence[-1] += np.max(scores[i,])
                        models_counts += 1

        # normalize by corresponding number of samples
        for cls in experience.previous_classes:
            self.current_classes_means[cls] /= classes_counts[cls]
        for cls in experience.classes_in_this_experience:
            self.init_classes_means[cls] /= classes_counts[cls]
        self.models_confidence[-1] /= models_counts
        # store the mapping between classes and the incremental state in which they are first seen
        for cls in experience.classes_in_this_experience:
            self.classes2exp[cls] = experience.current_experience

        # update the buffer of exemplars
        self.storage_policy.post_adapt(strategy, strategy.experience)

    def after_eval_forward(self, strategy: SupervisedTemplate, **kwargs):
        old_classes = strategy.experience.previous_classes
        new_classes = strategy.experience.classes_in_this_experience
        if not old_classes:
            return

        outputs = strategy.mb_output
        targets = strategy.mbatch[1]

        # rectify predicted scores (Eq. 1 in the paper)
        for i in range(len(targets)):
            # if the top-1 class predicted by the network is a new one, rectify the score
            if outputs[i].argmax().item() in new_classes:
                for cls in old_classes:
                    o_exp = self.classes2exp[cls]
                    if (
                        self.current_classes_means[cls] == 0
                    ):  # when evaluation is done before training
                        continue
                    outputs[i, cls] *= (
                        self.init_classes_means[cls] / self.current_classes_means[cls]
                    ) * (self.models_confidence[-1] / self.models_confidence[o_exp])
            # otherwise, rectification is not done because an old class is directly predicted


__all__ = [
    "IL2MPlugin",
]
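In short, the plugin pairs a replay buffer with per-class statistics (the "dual memory"): at evaluation time, whenever the top-1 prediction falls on a new class, the score of each old class c is multiplied by (init_mean_c / current_mean_c) * (confidence_current / confidence_init_c), as in the rectification loop above. A minimal usage sketch follows; it is not part of this commit, and the benchmark/model choices (SplitMNIST, SimpleMLP) are illustrative only, since the plugin can be attached to any SupervisedTemplate-based strategy such as Naive.

# Usage sketch (illustrative, not part of this commit): IL2MPlugin attached
# to a plain Naive strategy. SplitMNIST/SimpleMLP are placeholder choices.
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD

from avalanche.benchmarks.classic import SplitMNIST
from avalanche.models import SimpleMLP
from avalanche.training.plugins import IL2MPlugin
from avalanche.training.supervised import Naive

benchmark = SplitMNIST(n_experiences=5)
model = SimpleMLP(num_classes=10)  # SplitMNIST has 10 classes in total

strategy = Naive(
    model=model,
    optimizer=SGD(model.parameters(), lr=0.01),
    criterion=CrossEntropyLoss(),
    train_mb_size=64,
    train_epochs=1,
    eval_mb_size=64,
    device="cuda" if torch.cuda.is_available() else "cpu",
    # the plugin adds replay (ExperienceBalancedBuffer by default) and
    # rectifies old-class scores during evaluation
    plugins=[IL2MPlugin(mem_size=2000)],
)

for experience in benchmark.train_stream:
    strategy.train(experience)            # updates class means / model confidence
    strategy.eval(benchmark.test_stream)  # after_eval_forward rectifies scores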

avalanche/training/supervised/strategy_wrappers.py (+89)

@@ -22,6 +22,7 @@
     default_evaluator,
     default_loggers,
 )
+from avalanche.training.storage_policy import ExemplarsBuffer
 from avalanche.training.plugins import (
     SupervisedPlugin,
     CWRStarPlugin,

@@ -42,6 +43,7 @@
     BiCPlugin,
     MIRPlugin,
     FromScratchTrainingPlugin,
+    IL2MPlugin,
 )
 from avalanche.training.templates.base import BaseTemplate
 from avalanche.training.templates import SupervisedTemplate

@@ -1676,6 +1678,92 @@ def __init__(
 )
 
 
+class IL2M(SupervisedTemplate):
+    """Class Incremental Learning With Dual Memory (IL2M) strategy.
+
+    See IL2M plugin for details.
+    This strategy does not use task identities.
+    """
+
+    def __init__(
+        self,
+        *,
+        model: Module,
+        optimizer: Optimizer,
+        criterion: CriterionType,
+        mem_size: int = 2000,
+        mem_mb_size: Optional[int] = None,
+        train_mb_size: int = 1,
+        train_epochs: int = 1,
+        eval_mb_size: Optional[int] = None,
+        storage_policy: Optional[ExemplarsBuffer] = None,
+        device: Union[str, torch.device] = "cpu",
+        plugins: Optional[List[SupervisedPlugin]] = None,
+        evaluator: Union[
+            EvaluationPlugin, Callable[[], EvaluationPlugin]
+        ] = default_evaluator,
+        eval_every=-1,
+        peval_mode="epoch",
+        **base_kwargs
+    ):
+        """Init.
+
+        :param model: The model.
+        :param optimizer: The optimizer to use.
+        :param criterion: The loss criterion to use.
+        :param mem_size: Replay buffer size. Defaults to 2000.
+        :param mem_mb_size: The size of the memory batch. Defaults to None.
+        :param train_mb_size: The train minibatch size. Defaults to 1.
+        :param train_epochs: The number of training epochs. Defaults to 1.
+        :param eval_mb_size: The eval minibatch size. Defaults to None.
+        :param storage_policy: The policy that controls how to add new
+            exemplars in memory. Defaults to None.
+        :param device: The device to use. Defaults to 'cpu'.
+        :param plugins: Plugins to be added. Defaults to None.
+        :param evaluator: (optional) Instance of EvaluationPlugin for logging
+            and metric computations.
+        :param eval_every: The frequency of the calls to `eval` inside the
+            training loop. -1 disables the evaluation. 0 means `eval` is
+            called only at the end of the learning experience. Values >0 mean
+            that `eval` is called every `eval_every` epochs and at the end of
+            the learning experience. Defaults to -1.
+        :param peval_mode: one of {'epoch', 'iteration'}. Decides whether the
+            periodic evaluation during training should execute every
+            `eval_every` epochs or iterations. Defaults to 'epoch'.
+        :param **base_kwargs: any additional
+            :class:`~avalanche.training.BaseTemplate` constructor arguments.
+        """
+
+        # Instantiate plugin
+        il2m = IL2MPlugin(
+            mem_size=mem_size,
+            batch_size=train_mb_size,
+            batch_size_mem=mem_mb_size,
+            storage_policy=storage_policy,
+        )
+
+        # Add plugin to the strategy
+        if plugins is None:
+            plugins = [il2m]
+        else:
+            plugins.append(il2m)
+
+        super().__init__(
+            model=model,
+            optimizer=optimizer,
+            criterion=criterion,
+            train_mb_size=train_mb_size,
+            train_epochs=train_epochs,
+            eval_mb_size=eval_mb_size,
+            device=device,
+            plugins=plugins,
+            evaluator=evaluator,
+            eval_every=eval_every,
+            peval_mode=peval_mode,
+            **base_kwargs
+        )
+
+
 __all__ = [
     "Naive",
     "PNNStrategy",

@@ -1698,4 +1786,5 @@ def __init__(
     "MIR",
     "PackNet",
     "FromScratchTraining",
+    "IL2M",
 ]
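The wrapper is thin: it builds an IL2MPlugin from mem_size / mem_mb_size / storage_policy, appends it to the plugin list, and delegates the rest to SupervisedTemplate. A construction sketch follows (not part of this commit; arguments are keyword-only because of the `*` in the signature, and the benchmark/model are placeholder choices):

# Sketch (illustrative): constructing the new IL2M strategy directly;
# equivalent to the Naive + IL2MPlugin combination shown earlier.
from torch.nn import CrossEntropyLoss
from torch.optim import SGD

from avalanche.benchmarks.classic import SplitMNIST
from avalanche.models import SimpleMLP
from avalanche.training.supervised import IL2M

benchmark = SplitMNIST(n_experiences=5)
model = SimpleMLP(num_classes=10)

strategy = IL2M(
    model=model,
    optimizer=SGD(model.parameters(), lr=0.01),
    criterion=CrossEntropyLoss(),
    mem_size=2000,      # forwarded to IL2MPlugin(mem_size=...)
    mem_mb_size=32,     # forwarded as the plugin's batch_size_mem
    train_mb_size=64,
    train_epochs=1,
    eval_mb_size=64,
    device="cpu",
)

for experience in benchmark.train_stream:
    strategy.train(experience)
    strategy.eval(benchmark.test_stream)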

docs/training.rst (+2)

@@ -91,6 +91,7 @@ Ready-to-use continual learning strategies.
     FeatureReplay
     supervised.lamaml.LaMAML
     supervised.lamaml_v2.LaMAML
+    IL2M
 
 Replay Buffers and Selection Strategies
 ----------------------------------------

@@ -196,5 +197,6 @@ Strategy implemented as plugins in `avalanche.training.plugins`.
     MemoryNCMUpdate
     NCMOracle
     CurrentDataNCMUpdate
+    IL2MPlugin
tests/training/test_strategies.py (+17)

@@ -53,6 +53,7 @@
     ExpertGateStrategy,
     MER,
     FeatureReplay,
+    IL2M,
 )
 from avalanche.training.supervised.cumulative import Cumulative
 from avalanche.training.supervised.icarl import ICaRL

@@ -1162,6 +1163,22 @@ def test_feature_replay(self):
         )
         run_strategy(benchmark, strategy)
 
+    def test_il2m(self):
+        # SIT scenario
+        model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False)
+        with self.assertWarns(PositionalArgumentsDeprecatedWarning):
+            strategy = IL2M(
+                model,
+                optimizer,
+                criterion,
+                mem_size=50,
+                train_mb_size=10,
+                device=self.device,
+                eval_mb_size=50,
+                train_epochs=2,
+            )
+        run_strategy(benchmark, strategy)
+
     def load_benchmark(
         self,
         use_task_labels=False,