Skip to content

Commit 7439a04

Browse files
committed
test: Move WandbLogger test to test_wandb
the trainer tests don't run with wandb installed, so we can't put it there
1 parent 35a62e7 commit 7439a04

File tree

2 files changed

+39
-39
lines changed

2 files changed

+39
-39
lines changed

tests/tests_pytorch/loggers/test_wandb.py

+38-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from lightning.pytorch.callbacks import ModelCheckpoint
2525
from lightning.pytorch.cli import LightningCLI
2626
from lightning.pytorch.demos.boring_classes import BoringModel
27-
from lightning.pytorch.loggers import WandbLogger
27+
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
2828
from lightning.pytorch.utilities.exceptions import MisconfigurationException
2929
from tests_pytorch.test_cli import _xfail_python_ge_3_11_9
3030

@@ -133,6 +133,43 @@ def test_wandb_logger_init_before_spawn(wandb_mock):
133133
assert logger._experiment is not None
134134

135135

136+
def test_wandb_logger_experiment_called_first(wandb_mock, tmp_path):
137+
wandb_experiment_called = False
138+
139+
def tensorboard_experiment_side_effect() -> mock.MagicMock:
140+
nonlocal wandb_experiment_called
141+
assert wandb_experiment_called
142+
return mock.MagicMock()
143+
144+
def wandb_experiment_side_effect() -> mock.MagicMock:
145+
nonlocal wandb_experiment_called
146+
wandb_experiment_called = True
147+
return mock.MagicMock()
148+
149+
with (
150+
mock.patch.object(
151+
TensorBoardLogger,
152+
"experiment",
153+
new_callable=lambda: mock.PropertyMock(side_effect=tensorboard_experiment_side_effect),
154+
),
155+
mock.patch.object(
156+
WandbLogger,
157+
"experiment",
158+
new_callable=lambda: mock.PropertyMock(side_effect=wandb_experiment_side_effect),
159+
),
160+
):
161+
model = BoringModel()
162+
trainer = Trainer(
163+
default_root_dir=tmp_path,
164+
log_every_n_steps=1,
165+
limit_train_batches=0,
166+
limit_val_batches=0,
167+
max_steps=1,
168+
logger=[TensorBoardLogger(tmp_path), WandbLogger(save_dir=tmp_path)],
169+
)
170+
trainer.fit(model)
171+
172+
136173
def test_wandb_pickle(wandb_mock, tmp_path):
137174
"""Verify that pickling trainer with wandb logger works.
138175

tests/tests_pytorch/trainer/test_trainer.py

+1-38
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
RandomIterableDataset,
5050
RandomIterableDatasetWithLen,
5151
)
52-
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
52+
from lightning.pytorch.loggers import TensorBoardLogger
5353
from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSampler, _IndexBatchSamplerWrapper
5454
from lightning.pytorch.strategies import DDPStrategy, SingleDeviceStrategy
5555
from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher, _SubprocessScriptLauncher
@@ -1271,43 +1271,6 @@ def training_step(self, *args, **kwargs):
12711271
log_metrics_mock.assert_has_calls(expected_calls)
12721272

12731273

1274-
def test_wandb_logger_experiment_called_first(tmp_path):
1275-
wandb_experiment_called = False
1276-
1277-
def tensorboard_experiment_side_effect() -> mock.MagicMock:
1278-
nonlocal wandb_experiment_called
1279-
assert wandb_experiment_called
1280-
return mock.MagicMock()
1281-
1282-
def wandb_experiment_side_effect() -> mock.MagicMock:
1283-
nonlocal wandb_experiment_called
1284-
wandb_experiment_called = True
1285-
return mock.MagicMock()
1286-
1287-
with (
1288-
mock.patch.object(
1289-
TensorBoardLogger,
1290-
"experiment",
1291-
new_callable=lambda: mock.PropertyMock(side_effect=tensorboard_experiment_side_effect),
1292-
),
1293-
mock.patch.object(
1294-
WandbLogger,
1295-
"experiment",
1296-
new_callable=lambda: mock.PropertyMock(side_effect=wandb_experiment_side_effect),
1297-
),
1298-
):
1299-
model = BoringModel()
1300-
trainer = Trainer(
1301-
default_root_dir=tmp_path,
1302-
log_every_n_steps=1,
1303-
limit_train_batches=0,
1304-
limit_val_batches=0,
1305-
max_steps=1,
1306-
logger=[TensorBoardLogger(tmp_path), WandbLogger(save_dir=tmp_path)],
1307-
)
1308-
trainer.fit(model)
1309-
1310-
13111274
class TestLightningDataModule(LightningDataModule):
13121275
def __init__(self, dataloaders):
13131276
super().__init__()

0 commit comments

Comments
 (0)