Skip to content

Commit 7cc252f

Browse files
committed
[benchmarks] typing
1 parent 1486573 commit 7cc252f

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

avalanche/benchmarks/utils/classification_dataset.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
labels automatically. Concatenation and subsampling operations are optimized
1717
to be used frequently, as is common in replay strategies.
1818
"""
19-
2019
from functools import partial
2120
from typing import (
2221
List,
@@ -29,7 +28,7 @@
2928
Dict,
3029
Tuple,
3130
Mapping,
32-
overload,
31+
overload, Self,
3332
)
3433

3534
import torch
@@ -64,11 +63,11 @@
6463
)
6564

6665
T_co = TypeVar("T_co", covariant=True)
67-
TAvalancheDataset = TypeVar("TAvalancheDataset", bound="AvalancheDataset")
66+
TAvalancheDataset = TypeVar("TAvalancheDataset", bound=AvalancheDataset)
6867
TTargetType = int
6968

7069
TClassificationDataset = TypeVar(
71-
"TClassificationDataset", bound="ClassificationDataset"
70+
"TClassificationDataset", bound=IDatasetWithTargets
7271
)
7372

7473

@@ -114,7 +113,7 @@ def task_pattern_indices(self) -> Dict[int, Sequence[int]]:
114113
return self.targets_task_labels.val_to_idx # type: ignore
115114

116115
@property
117-
def task_set(self: TClassificationDataset) -> TaskSet[TClassificationDataset]:
116+
def task_set(self) -> TaskSet[Self]:
118117
"""Returns the dataset's ``TaskSet``, which is a mapping <task-id,
119118
task-dataset>."""
120119
return TaskSet(self)

avalanche/benchmarks/utils/utils.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -653,13 +653,15 @@ class TaskSet(Mapping[int, TAvalancheDataset], Generic[TAvalancheDataset]):
653653
654654
"""
655655

656+
data: TAvalancheDataset
657+
656658
def __init__(self, data: TAvalancheDataset):
657659
"""Constructor.
658660
659661
:param data: original data
660662
"""
661663
super().__init__()
662-
self.data: TAvalancheDataset = data
664+
self.data = data
663665

664666
def __iter__(self) -> Iterator[int]:
665667
t_labels = self._get_task_labels_field()

0 commit comments

Comments
 (0)