Fix bug in macro average for a number of classification metrics #1821

Merged: 13 commits, Jun 15, 2023
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -218,6 +218,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed support for half precision in `PearsonCorrCoef` ([#1819](https://github.com/Lightning-AI/torchmetrics/pull/1819))


+ - Fixed a number of bugs related to `average="macro"` in classification metrics ([#1821](https://github.com/Lightning-AI/torchmetrics/pull/1821))


## [0.11.4] - 2023-03-10

### Fixed
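For context (not part of the diff): the user-visible effect of the fix is that, with `average="macro"`, a class that appears in neither `preds` nor `target` no longer drags the score down. A minimal sketch, assuming a torchmetrics build that includes this PR:

```python
import torch
from torchmetrics.functional.classification import multiclass_recall

preds = torch.tensor([0, 1, 0, 1])
target = torch.tensor([0, 1, 0, 1])

# num_classes=3, but class 2 never occurs in preds or target.
# Previously the empty class contributed a 0, giving (1 + 1 + 0) / 3 ≈ 0.67;
# with this fix it is dropped from the average and the result is 1.0.
print(multiclass_recall(preds, target, num_classes=3, average="macro"))
```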
4 changes: 3 additions & 1 deletion src/torchmetrics/classification/f_beta.py
@@ -488,7 +488,9 @@ def __init__(
def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
- return _fbeta_reduce(tp, fp, tn, fn, self.beta, average=self.average, multidim_average=self.multidim_average)
+ return _fbeta_reduce(
+     tp, fp, tn, fn, self.beta, average=self.average, multidim_average=self.multidim_average, multilabel=True
+ )

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
4 changes: 2 additions & 2 deletions src/torchmetrics/classification/precision_recall.py
@@ -405,7 +405,7 @@ def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
return _precision_recall_reduce(
"precision", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average
"precision", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True
)

def plot(
@@ -819,7 +819,7 @@ def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
return _precision_recall_reduce(
"recall", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average
"recall", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True
)

def plot(
4 changes: 3 additions & 1 deletion src/torchmetrics/classification/specificity.py
@@ -383,7 +383,9 @@ class MultilabelSpecificity(MultilabelStatScores):
def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
- return _specificity_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average)
+ return _specificity_reduce(
+     tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True
+ )

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
7 changes: 2 additions & 5 deletions src/torchmetrics/functional/classification/accuracy.py
@@ -31,7 +31,7 @@
_multilabel_stat_scores_tensor_validation,
_multilabel_stat_scores_update,
)
- from torchmetrics.utilities.compute import _safe_divide
+ from torchmetrics.utilities.compute import _adjust_weights_safe_divide, _safe_divide
from torchmetrics.utilities.enums import ClassificationTask


@@ -83,10 +83,7 @@ def _accuracy_reduce(
return _safe_divide(tp, tp + fn)

score = _safe_divide(tp + tn, tp + tn + fp + fn) if multilabel else _safe_divide(tp, tp + fn)
- if average is None or average == "none":
-     return score
- weights = tp + fn if average == "weighted" else torch.ones_like(score)
- return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1)
+ return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn)


def binary_accuracy(
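`_accuracy_reduce` now delegates the macro/weighted handling to the shared `_adjust_weights_safe_divide` helper. A small sketch of the effect on the user-facing function (my own example, assuming a build with this fix):

```python
import torch
from torchmetrics.functional.classification import multiclass_accuracy

preds = torch.tensor([0, 0, 1, 1])
target = torch.tensor([0, 1, 1, 1])

# Class 2 has no samples at all. Per-class scores are [1.0, 2/3, 0.0];
# macro now averages only classes 0 and 1: (1.0 + 2/3) / 2 ≈ 0.83,
# instead of dividing by 3 as before.
print(multiclass_accuracy(preds, target, num_classes=3, average="macro"))
```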
10 changes: 4 additions & 6 deletions src/torchmetrics/functional/classification/f_beta.py
@@ -31,7 +31,7 @@
_multilabel_stat_scores_tensor_validation,
_multilabel_stat_scores_update,
)
- from torchmetrics.utilities.compute import _safe_divide
+ from torchmetrics.utilities.compute import _adjust_weights_safe_divide, _safe_divide
from torchmetrics.utilities.enums import ClassificationTask


@@ -43,6 +43,7 @@ def _fbeta_reduce(
beta: float,
average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]],
multidim_average: Literal["global", "samplewise"] = "global",
+ multilabel: bool = False,
) -> Tensor:
beta2 = beta**2
if average == "binary":
@@ -54,10 +55,7 @@ def _fbeta_reduce(
return _safe_divide((1 + beta2) * tp, (1 + beta2) * tp + beta2 * fn + fp)

fbeta_score = _safe_divide((1 + beta2) * tp, (1 + beta2) * tp + beta2 * fn + fp)
- if average is None or average == "none":
-     return fbeta_score
- weights = tp + fn if average == "weighted" else torch.ones_like(fbeta_score)
- return _safe_divide(weights * fbeta_score, weights.sum(-1, keepdim=True)).sum(-1)
+ return _adjust_weights_safe_divide(fbeta_score, average, multilabel, tp, fp, fn)


def _binary_fbeta_score_arg_validation(
@@ -375,7 +373,7 @@ def multilabel_fbeta_score(
_multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index)
preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index)
tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average)
- return _fbeta_reduce(tp, fp, tn, fn, beta, average=average, multidim_average=multidim_average)
+ return _fbeta_reduce(tp, fp, tn, fn, beta, average=average, multidim_average=multidim_average, multilabel=True)


def binary_f1_score(
7 changes: 2 additions & 5 deletions src/torchmetrics/functional/classification/hamming.py
@@ -31,7 +31,7 @@
_multilabel_stat_scores_tensor_validation,
_multilabel_stat_scores_update,
)
- from torchmetrics.utilities.compute import _safe_divide
+ from torchmetrics.utilities.compute import _adjust_weights_safe_divide, _safe_divide
from torchmetrics.utilities.enums import ClassificationTask


@@ -80,10 +80,7 @@ def _hamming_distance_reduce(
return 1 - _safe_divide(tp, tp + fn)

score = 1 - _safe_divide(tp + tn, tp + tn + fp + fn) if multilabel else 1 - _safe_divide(tp, tp + fn)
- if average is None or average == "none":
-     return score
- weights = tp + fn if average == "weighted" else torch.ones_like(score)
- return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1)
+ return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn)


def binary_hamming_distance(
5 changes: 4 additions & 1 deletion src/torchmetrics/functional/classification/jaccard.py
@@ -66,7 +66,8 @@ def _jaccard_index_reduce(
return confmat[1, 1] / (confmat[0, 1] + confmat[1, 0] + confmat[1, 1])

ignore_index_cond = ignore_index is not None and 0 <= ignore_index <= confmat.shape[0]
- if confmat.ndim == 3: # multilabel
+ multilabel = confmat.ndim == 3
+ if multilabel:
num = confmat[:, 1, 1]
denom = confmat[:, 1, 1] + confmat[:, 0, 1] + confmat[:, 1, 0]
else: # multiclass
@@ -87,6 +88,8 @@ def _jaccard_index_reduce(
weights = torch.ones_like(jaccard)
if ignore_index_cond:
weights[ignore_index] = 0.0
+ if not multilabel:
+     weights[confmat.sum(1) + confmat.sum(0) == 0] = 0.0
return ((weights * jaccard) / weights.sum()).sum()


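The Jaccard path applies the same idea through the confusion matrix: in the multiclass case, classes whose row and column are entirely zero are masked out of the average. A hedged sketch, assuming a build with this change:

```python
import torch
from torchmetrics.functional.classification import multiclass_jaccard_index

preds = torch.tensor([0, 1, 0, 1])
target = torch.tensor([0, 1, 0, 1])

# Class 2 never appears, so its row and column in the confusion matrix are zero.
# Previously macro gave (1 + 1 + 0) / 3 ≈ 0.67; with this fix it gives 1.0.
print(multiclass_jaccard_index(preds, target, num_classes=3, average="macro"))
```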
16 changes: 9 additions & 7 deletions src/torchmetrics/functional/classification/precision_recall.py
@@ -31,7 +31,7 @@
_multilabel_stat_scores_tensor_validation,
_multilabel_stat_scores_update,
)
- from torchmetrics.utilities.compute import _safe_divide
+ from torchmetrics.utilities.compute import _adjust_weights_safe_divide, _safe_divide
from torchmetrics.utilities.enums import ClassificationTask


@@ -43,6 +43,7 @@ def _precision_recall_reduce(
fn: Tensor,
average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]],
multidim_average: Literal["global", "samplewise"] = "global",
+ multilabel: bool = False,
) -> Tensor:
different_stat = fp if stat == "precision" else fn # this is what differs between the two scores
if average == "binary":
@@ -54,10 +55,7 @@ def _precision_recall_reduce(
return _safe_divide(tp, tp + different_stat)

score = _safe_divide(tp, tp + different_stat)
- if average is None or average == "none":
-     return score
- weights = tp + fn if average == "weighted" else torch.ones_like(score)
- return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1)
+ return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn)


def binary_precision(
@@ -336,7 +334,7 @@ def multilabel_precision(
_multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index)
preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index)
tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average)
return _precision_recall_reduce("precision", tp, fp, tn, fn, average=average, multidim_average=multidim_average)
return _precision_recall_reduce(
"precision", tp, fp, tn, fn, average=average, multidim_average=multidim_average, multilabel=True
)


def binary_recall(
@@ -615,7 +615,9 @@ def multilabel_recall(
_multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index)
preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index)
tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average)
return _precision_recall_reduce("recall", tp, fp, tn, fn, average=average, multidim_average=multidim_average)
return _precision_recall_reduce(
"recall", tp, fp, tn, fn, average=average, multidim_average=multidim_average, multilabel=True
)


def precision(
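The new `multilabel` flag is there so that multilabel metrics keep their previous macro behaviour: the zero-support masking only kicks in on the multiclass path. A sketch of what that means for `multilabel_recall` (my own example, based on my reading of the new helper):

```python
import torch
from torchmetrics.functional.classification import multilabel_recall

preds = torch.tensor([[1, 0, 0], [1, 0, 0]])
target = torch.tensor([[1, 0, 0], [1, 0, 0]])

# Label 0 is always positive and correctly predicted; labels 1 and 2 have no
# positives at all. Because the reduce is called with multilabel=True, all
# three labels stay in the macro average: (1.0 + 0.0 + 0.0) / 3 ≈ 0.33.
print(multilabel_recall(preds, target, num_labels=3, average="macro"))
```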
10 changes: 4 additions & 6 deletions src/torchmetrics/functional/classification/specificity.py
@@ -31,7 +31,7 @@
_multilabel_stat_scores_tensor_validation,
_multilabel_stat_scores_update,
)
- from torchmetrics.utilities.compute import _safe_divide
+ from torchmetrics.utilities.compute import _adjust_weights_safe_divide, _safe_divide
from torchmetrics.utilities.enums import ClassificationTask


@@ -42,6 +42,7 @@ def _specificity_reduce(
fn: Tensor,
average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]],
multidim_average: Literal["global", "samplewise"] = "global",
+ multilabel: bool = False,
) -> Tensor:
if average == "binary":
return _safe_divide(tn, tn + fp)
@@ -51,10 +52,7 @@ def _specificity_reduce(
return _safe_divide(tn, tn + fp)

specificity_score = _safe_divide(tn, tn + fp)
- if average is None or average == "none":
-     return specificity_score
- weights = tp + fn if average == "weighted" else torch.ones_like(specificity_score)
- return _safe_divide(weights * specificity_score, weights.sum(-1, keepdim=True)).sum(-1)
+ return _adjust_weights_safe_divide(specificity_score, average, multilabel, tp, fp, fn)


def binary_specificity(
@@ -333,7 +331,7 @@ def multilabel_specificity(
_multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index)
preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index)
tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average)
- return _specificity_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average)
+ return _specificity_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average, multilabel=True)


def specificity(
16 changes: 15 additions & 1 deletion src/torchmetrics/utilities/compute.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- from typing import Tuple
+ from typing import Optional, Tuple

import torch
from torch import Tensor
@@ -55,6 +55,20 @@ def _safe_divide(num: Tensor, denom: Tensor) -> Tensor:
return num / denom


+ def _adjust_weights_safe_divide(
+     score: Tensor, average: Optional[str], multilabel: bool, tp: Tensor, fp: Tensor, fn: Tensor
+ ) -> Tensor:
+     if average is None or average == "none":
+         return score
+     if average == "weighted":
+         weights = tp + fn
+     else:
+         weights = torch.ones_like(score)
+     if not multilabel:
+         weights[tp + fp + fn == 0] = 0.0
+     return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1)


def _auc_format_inputs(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
"""Check that auc input is correct."""
x = x.squeeze() if x.ndim > 1 else x
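To make the helper's behaviour concrete, a quick sketch exercising it directly (my own example; `_adjust_weights_safe_divide` is a private utility, so the import is for illustration only):

```python
import torch
from torchmetrics.utilities.compute import _adjust_weights_safe_divide

# Per-class scores where class 2 has no support (tp + fp + fn == 0).
score = torch.tensor([1.0, 1.0, 0.0])
tp = torch.tensor([2.0, 2.0, 0.0])
fp = torch.tensor([0.0, 0.0, 0.0])
fn = torch.tensor([0.0, 0.0, 0.0])

# Multiclass path: the unsupported class gets weight 0 -> (1 + 1) / 2 = 1.0
print(_adjust_weights_safe_divide(score, "macro", False, tp, fp, fn))
# Multilabel path: all classes keep weight 1 -> (1 + 1 + 0) / 3 ≈ 0.67
print(_adjust_weights_safe_divide(score, "macro", True, tp, fp, fn))
```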