Fix compositional logging with lightning (#1761)
* tests

* tests

* fix implementation

* changelog

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
3 people authored Jun 15, 2023
1 parent d396cda commit 181e112
Showing 3 changed files with 123 additions and 21 deletions.
CHANGELOG.md (3 additions, 0 deletions)
```diff
@@ -206,6 +206,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed lookup for punkt sources being downloaded in `RougeScore` ([#1789](https://github.com/Lightning-AI/torchmetrics/pull/1789))
 
 
+- Fixed integration with lightning for `CompositionalMetric` ([#1761](https://github.com/Lightning-AI/torchmetrics/pull/1761))
+
+
 - Fixed several bugs in `SpectralDistortionIndex` metric ([#1808](https://github.com/Lightning-AI/torchmetrics/pull/1808))
```
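For readers unfamiliar with the feature: a `CompositionalMetric` is what torchmetrics builds when two metrics are combined with a Python operator, as the tests below do with `SumMetric() + SumMetric()`. A minimal sketch of the behavior (the input value is illustrative; the doubling is the same `2 *` factor the new test asserts):

```python
import torch
from torchmetrics.aggregation import SumMetric

# `+` on two metrics builds a CompositionalMetric that applies torch.add
# to whatever its two operands compute
compo = SumMetric() + SumMetric()

out = compo(torch.tensor(3.0))  # forward() runs both operands on the input
print(out)                      # tensor(6.) -- each SumMetric saw 3.0
```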
src/torchmetrics/metric.py (8 additions, 4 deletions)
```diff
@@ -1086,17 +1086,21 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
         )
 
         if val_a is None:
-            return None
+            self._forward_cache = None
+            return self._forward_cache
 
         if val_b is None:
             if isinstance(self.metric_b, Metric):
-                return None
+                self._forward_cache = None
+                return self._forward_cache
 
             # Unary op
-            return self.op(val_a)
+            self._forward_cache = self.op(val_a)
+            return self._forward_cache
 
         # Binary op
-        return self.op(val_a, val_b)
+        self._forward_cache = self.op(val_a, val_b)
+        return self._forward_cache
 
     def reset(self) -> None:
         """Redirect the call to the input which the composition was formed from."""
```
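The fix mirrors the base `Metric.forward` convention of stashing the step result in `self._forward_cache` rather than returning it bare. As I read Lightning's result collection, a metric object logged with `on_step=True` has its step value taken from `_forward_cache`, so a composition that never set that attribute could not be logged correctly at step level. A quick sanity check of the new behavior (my own example values, not from the PR):

```python
import torch
from torchmetrics.aggregation import SumMetric

compo = SumMetric() + SumMetric()
assert compo._forward_cache is None     # inherited default from Metric.__init__

out = compo(torch.tensor(2.0))          # forward() now also populates the cache
assert out == torch.tensor(4.0)
assert compo._forward_cache == out      # the value Lightning reads for step logging
```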
tests/integrations/test_lightning.py (112 additions, 17 deletions)
```diff
@@ -20,8 +20,10 @@
 
 if module_available("lightning"):
     from lightning import LightningModule, Trainer
+    from lightning.pytorch.loggers import CSVLogger
 else:
     from pytorch_lightning import LightningModule, Trainer
+    from pytorch_lightning.loggers import CSVLogger
 
 from torchmetrics import MetricCollection
 from torchmetrics.aggregation import SumMetric
```
```diff
@@ -180,44 +182,137 @@ def test_metric_lightning_log(tmpdir):
     class TestModel(BoringModel):
         def __init__(self) -> None:
             super().__init__()
-            self.metric_step = SumMetric()
-            self.metric_epoch = SumMetric()
-            self.register_buffer("sum", torch.tensor(0.0))
-            self.outs = []
-
-        def on_train_epoch_start(self):
-            self.sum = torch.tensor(0.0, device=self.sum.device)
+            # initialize one metric for every combination of `on_step` and `on_epoch` and `forward` and `update`
+            self.metric_update = SumMetric()
+            self.metric_update_step = SumMetric()
+            self.metric_update_epoch = SumMetric()
+
+            self.metric_forward = SumMetric()
+            self.metric_forward_step = SumMetric()
+            self.metric_forward_epoch = SumMetric()
+
+            self.compo_update = SumMetric() + SumMetric()
+            self.compo_update_step = SumMetric() + SumMetric()
+            self.compo_update_epoch = SumMetric() + SumMetric()
+
+            self.compo_forward = SumMetric() + SumMetric()
+            self.compo_forward_step = SumMetric() + SumMetric()
+            self.compo_forward_epoch = SumMetric() + SumMetric()
+
+            self.sum = []
 
         def training_step(self, batch, batch_idx):
             x = batch
-            self.metric_step(x.sum())
-            self.sum += x.sum()
-            self.log("sum_step", self.metric_step, on_epoch=True, on_step=False)
-            self.outs.append(x)
-            return self.step(x)
+            s = x.sum()
 
-        def on_train_epoch_end(self):
-            self.log("sum_epoch", self.metric_epoch(torch.stack(self.outs)))
-            self.outs = []
+            for metric in [self.metric_update, self.metric_update_step, self.metric_update_epoch]:
+                metric.update(s)
+            for metric in [self.metric_forward, self.metric_forward_step, self.metric_forward_epoch]:
+                _ = metric(s)
+            for metric in [self.compo_update, self.compo_update_step, self.compo_update_epoch]:
+                metric.update(s)
+            for metric in [self.compo_forward, self.compo_forward_step, self.compo_forward_epoch]:
+                _ = metric(s)
+
+            self.sum.append(s)
+
+            self.log("metric_update", self.metric_update)
+            self.log("metric_update_step", self.metric_update_step, on_epoch=False, on_step=True)
+            self.log("metric_update_epoch", self.metric_update_epoch, on_epoch=True, on_step=False)
+
+            self.log("metric_forward", self.metric_forward)
+            self.log("metric_forward_step", self.metric_forward_step, on_epoch=False, on_step=True)
+            self.log("metric_forward_epoch", self.metric_forward_epoch, on_epoch=True, on_step=False)
+
+            self.log("compo_update", self.compo_update)
+            self.log("compo_update_step", self.compo_update_step, on_epoch=False, on_step=True)
+            self.log("compo_update_epoch", self.compo_update_epoch, on_epoch=True, on_step=False)
+
+            self.log("compo_forward", self.compo_forward)
+            self.log("compo_forward_step", self.compo_forward_step, on_epoch=False, on_step=True)
+            self.log("compo_forward_epoch", self.compo_forward_epoch, on_epoch=True, on_step=False)
+
+            return self.step(x)
 
     model = TestModel()
 
+    logger = CSVLogger("tmpdir/logs")
     trainer = Trainer(
         default_root_dir=tmpdir,
         limit_train_batches=2,
         limit_val_batches=0,
         max_epochs=2,
         log_every_n_steps=1,
+        logger=logger,
     )
     with no_warning_call(
         UserWarning,
         match="Torchmetrics v0.9 introduced a new argument class property called.*",
     ):
         trainer.fit(model)
 
-    logged = trainer.logged_metrics
-    assert torch.allclose(tensor(logged["sum_step"]), model.sum, atol=2e-4)
-    assert torch.allclose(tensor(logged["sum_epoch"]), model.sum, atol=2e-4)
+    logged_metrics = logger._experiment.metrics
+
+    epoch_0_step_0 = logged_metrics[0]
+    assert "metric_forward" in epoch_0_step_0
+    assert epoch_0_step_0["metric_forward"] == model.sum[0]
+    assert "metric_forward_step" in epoch_0_step_0
+    assert epoch_0_step_0["metric_forward_step"] == model.sum[0]
+    assert "compo_forward" in epoch_0_step_0
+    assert epoch_0_step_0["compo_forward"] == 2 * model.sum[0]
+    assert "compo_forward_step" in epoch_0_step_0
+    assert epoch_0_step_0["compo_forward_step"] == 2 * model.sum[0]
+
+    epoch_0_step_1 = logged_metrics[1]
+    assert "metric_forward" in epoch_0_step_1
+    assert epoch_0_step_1["metric_forward"] == model.sum[1]
+    assert "metric_forward_step" in epoch_0_step_1
+    assert epoch_0_step_1["metric_forward_step"] == model.sum[1]
+    assert "compo_forward" in epoch_0_step_1
+    assert epoch_0_step_1["compo_forward"] == 2 * model.sum[1]
+    assert "compo_forward_step" in epoch_0_step_1
+    assert epoch_0_step_1["compo_forward_step"] == 2 * model.sum[1]
+
+    epoch_0 = logged_metrics[2]
+    assert "metric_update_epoch" in epoch_0
+    assert epoch_0["metric_update_epoch"] == sum([model.sum[0], model.sum[1]])
+    assert "metric_forward_epoch" in epoch_0
+    assert epoch_0["metric_forward_epoch"] == sum([model.sum[0], model.sum[1]])
+    assert "compo_update_epoch" in epoch_0
+    assert epoch_0["compo_update_epoch"] == 2 * sum([model.sum[0], model.sum[1]])
+    assert "compo_forward_epoch" in epoch_0
+    assert epoch_0["compo_forward_epoch"] == 2 * sum([model.sum[0], model.sum[1]])
+
+    epoch_1_step_0 = logged_metrics[3]
+    assert "metric_forward" in epoch_1_step_0
+    assert epoch_1_step_0["metric_forward"] == model.sum[2]
+    assert "metric_forward_step" in epoch_1_step_0
+    assert epoch_1_step_0["metric_forward_step"] == model.sum[2]
+    assert "compo_forward" in epoch_1_step_0
+    assert epoch_1_step_0["compo_forward"] == 2 * model.sum[2]
+    assert "compo_forward_step" in epoch_1_step_0
+    assert epoch_1_step_0["compo_forward_step"] == 2 * model.sum[2]
+
+    epoch_1_step_1 = logged_metrics[4]
+    assert "metric_forward" in epoch_1_step_1
+    assert epoch_1_step_1["metric_forward"] == model.sum[3]
+    assert "metric_forward_step" in epoch_1_step_1
+    assert epoch_1_step_1["metric_forward_step"] == model.sum[3]
+    assert "compo_forward" in epoch_1_step_1
+    assert epoch_1_step_1["compo_forward"] == 2 * model.sum[3]
+    assert "compo_forward_step" in epoch_1_step_1
+    assert epoch_1_step_1["compo_forward_step"] == 2 * model.sum[3]
+
+    epoch_1 = logged_metrics[5]
+    assert "metric_update_epoch" in epoch_1
+    assert epoch_1["metric_update_epoch"] == sum([model.sum[2], model.sum[3]])
+    assert "metric_forward_epoch" in epoch_1
+    assert epoch_1["metric_forward_epoch"] == sum([model.sum[2], model.sum[3]])
+    assert "compo_update_epoch" in epoch_1
+    assert epoch_1["compo_update_epoch"] == 2 * sum([model.sum[2], model.sum[3]])
+    assert "compo_forward_epoch" in epoch_1
+    assert epoch_1["compo_forward_epoch"] == 2 * sum([model.sum[2], model.sum[3]])
 
 
 def test_metric_collection_lightning_log(tmpdir):
```
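A side note on the assertions: `logger._experiment.metrics` is a private, in-memory list of row dicts kept by `CSVLogger`'s experiment writer, which is convenient inside a test. Outside a test you would more likely read the `metrics.csv` the logger flushes to disk. A hedged sketch, assuming the logger's default `lightning_logs/version_0` layout under the save dir used above:

```python
import pandas as pd

# one row per logging event; metrics absent from a row show up as NaN
df = pd.read_csv("tmpdir/logs/lightning_logs/version_0/metrics.csv")
step_rows = df[df["compo_forward_step"].notna()]
print(step_rows[["step", "metric_forward_step", "compo_forward_step"]])
```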
