Add support for MPS Backend [without torch.amp.autocast] + CI #3041

Merged: 11 commits, Nov 22, 2023
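This PR adds an mps device path to ignite's serial computation model (torch > 1.12 only), makes the create_supervised_* factories reject amp_mode on MPS devices, temporarily skips the non-distributed tests on machines where MPS is available, and adds a CI workflow that runs the unit tests and the MNIST example on an M1 runner.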
125 changes: 125 additions & 0 deletions .github/workflows/mps-tests.yml
@@ -0,0 +1,125 @@
name: Run unit tests on M1
on:
  push:
    branches:
      - master
      - "*.*.*"
    paths:
      - "ignite/**"
      - "tests/ignite/**"
      - "tests/run_code_style.sh"
      - "examples/**.py"
      - "requirements-dev.txt"
      - ".github/workflows/mps-tests.yml"
  pull_request:
    paths:
      - "ignite/**"
      - "tests/ignite/**"
      - "tests/run_code_style.sh"
      - "examples/**.py"
      - "requirements-dev.txt"
      - ".github/workflows/mps-tests.yml"
  workflow_dispatch:

concurrency:
  # <workflow_name>-<branch_name>-<true || commit_sha (if branch is protected)>
  group: mps-tests-${{ github.ref_name }}-${{ !(github.ref_protected) || github.sha }}
  cancel-in-progress: true

# Cherry-picked from
# - https://github.com/pytorch/vision/main/.github/workflows/tests.yml
# - https://github.com/pytorch/test-infra/blob/main/.github/workflows/macos_job.yml

jobs:
  mps-tests:
    strategy:
      matrix:
        python-version: [3.8]
        pytorch-channel: ["pytorch"]
        skip-distrib-tests: [1]
      fail-fast: false
    runs-on: ["macos-m1-12"]
    timeout-minutes: 60

    steps:
      - name: Clean workspace
        run: |
          echo "::group::Cleanup debug output"
          sudo rm -rfv "${GITHUB_WORKSPACE}"
          mkdir -p "${GITHUB_WORKSPACE}"
          echo "::endgroup::"

      - name: Checkout repository (pytorch/test-infra)
        uses: actions/checkout@v3
        with:
          # Support the use case where we need to checkout someone's fork
          repository: pytorch/test-infra
          path: test-infra

      - name: Checkout repository (${{ github.repository }})
        uses: actions/checkout@v3
        with:
          # Support the use case where we need to checkout someone's fork
          repository: ${{ github.repository }}
          ref: ${{ github.ref }}
          path: ${{ github.repository }}
          fetch-depth: 1

      - name: Setup miniconda
        uses: ./test-infra/.github/actions/setup-miniconda
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install PyTorch
        if: ${{ matrix.pytorch-channel == 'pytorch' }}
        shell: bash -l {0}
        run: pip install torch torchvision

      - name: Install PyTorch (nightly)
        if: ${{ matrix.pytorch-channel == 'pytorch-nightly' }}
        shell: bash -l {0}
        run: pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu

      - name: Install dependencies
        shell: bash -l {0}
        working-directory: ${{ github.repository }}
        run: |
          # TODO: We add set -xe to explicitly fail the CI if one of the commands fails.
          # Somehow the step passes even if a subcommand fails.
          set -xe
          pip install -r requirements-dev.txt
          echo "1 returned code: $?"
          pip install -e .
          echo "2 returned code: $?"
          pip list
          echo "3 returned code: $?"

      # Download MNIST: https://github.com/pytorch/ignite/issues/1737
      # to "/tmp" for unit tests
      - name: Download MNIST
        uses: pytorch-ignite/download-mnist-github-action@master
        with:
          target_dir: /tmp

      # Copy MNIST to "." for the examples
      - name: Copy MNIST
        run: |
          cp -R /tmp/MNIST .

      - name: Run Tests
        shell: bash -l {0}
        working-directory: ${{ github.repository }}
        run: |
          SKIP_DISTRIB_TESTS=${{ matrix.skip-distrib-tests }} bash tests/run_cpu_tests.sh

      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v3
        with:
          file: ${{ github.repository }}/coverage.xml
          flags: mps
          fail_ci_if_error: false

      - name: Run MNIST Examples
        shell: bash -l {0}
        working-directory: ${{ github.repository }}
        run: python examples/mnist/mnist.py --epochs=1
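For reference, the runtime gate these CI jobs ultimately exercise looks roughly like the sketch below. This is illustrative only; mps_device_or_cpu is a hypothetical helper, not code from this PR, and it assumes a torch >= 1.12 install where torch.backends.mps exists.

import torch

def mps_device_or_cpu() -> torch.device:
    # Hypothetical helper (not part of this PR). MPS needs a torch >= 1.12
    # build with MPS compiled in, running on Apple silicon / macOS 12.3+.
    # is_built() separates "wheel lacks MPS" from "this machine lacks MPS".
    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
        return torch.device("mps")
    return torch.device("cpu")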
1 change: 1 addition & 0 deletions .github/workflows/unit-tests.yml
@@ -93,6 +93,7 @@ jobs:
        run: |
          pip install -r requirements-dev.txt
          python setup.py install
          pip list

      - name: Check code formatting
        run: |
5 changes: 5 additions & 0 deletions ignite/distributed/comp_models/base.py
@@ -3,6 +3,9 @@
from typing import Any, Callable, cast, List, Optional, Union

import torch
from packaging.version import Version

_torch_version_gt_112 = Version(torch.__version__) > Version("1.12.0")


class ComputationModel(metaclass=ABCMeta):
@@ -326,6 +329,8 @@ def get_node_rank(self) -> int:
    def device(self) -> torch.device:
        if torch.cuda.is_available():
            return torch.device("cuda")
        if _torch_version_gt_112 and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def backend(self) -> Optional[str]:
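Pulled out of the diff above, the new device-resolution order in ComputationModel.device is: CUDA first, then MPS behind the version gate (the mps backend only exists on torch > 1.12), then CPU. A standalone sketch of the same logic, for illustration:

import torch
from packaging.version import Version

def resolve_device() -> torch.device:
    # Same priority as the patched ComputationModel.device: CUDA wins,
    # MPS is considered only on torch > 1.12.0 where the backend exists,
    # and CPU is the fallback.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if Version(torch.__version__) > Version("1.12.0") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")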
19 changes: 15 additions & 4 deletions ignite/engine/__init__.py
@@ -96,6 +96,8 @@ def supervised_training_step(
        Added `model_transform` to transform model's output
    .. versionchanged:: 0.4.13
        Added `model_fn` to customize model's application on the sample
    .. versionchanged:: 0.4.14
        Added support for ``mps`` device
    """

    if gradient_accumulation_steps <= 0:
@@ -391,9 +393,12 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to


def _check_arg(
    on_tpu: bool, amp_mode: Optional[str], scaler: Optional[Union[bool, "torch.cuda.amp.GradScaler"]]
    on_tpu: bool, on_mps: bool, amp_mode: Optional[str], scaler: Optional[Union[bool, "torch.cuda.amp.GradScaler"]]
) -> Tuple[Optional[str], Optional["torch.cuda.amp.GradScaler"]]:
    """Checking tpu, amp and GradScaler instance combinations."""
    """Checking tpu, mps, amp and GradScaler instance combinations."""
    if on_mps and amp_mode:
        raise ValueError("amp_mode cannot be used with mps device. Consider using amp_mode=None or device='cuda'.")

    if on_tpu and not idist.has_xla_support:
        raise RuntimeError("In order to run on TPU, please install PyTorch XLA")

@@ -546,11 +551,14 @@ def output_transform_fn(x, y, y_pred, loss):
        Added ``model_transform`` to transform model's output
    .. versionchanged:: 0.4.13
        Added `model_fn` to customize model's application on the sample
    .. versionchanged:: 0.4.14
        Added support for ``mps`` device
    """

    device_type = device.type if isinstance(device, torch.device) else device
    on_tpu = "xla" in device_type if device_type is not None else False
    mode, _scaler = _check_arg(on_tpu, amp_mode, scaler)
    on_mps = "mps" in device_type if device_type is not None else False
    mode, _scaler = _check_arg(on_tpu, on_mps, amp_mode, scaler)

    if mode == "amp":
        _update = supervised_training_step_amp(
@@ -791,10 +799,13 @@ def create_supervised_evaluator(
        Added ``model_transform`` to transform model's output
    .. versionchanged:: 0.4.13
        Added `model_fn` to customize model's application on the sample
    .. versionchanged:: 0.4.14
        Added support for ``mps`` device
    """
    device_type = device.type if isinstance(device, torch.device) else device
    on_tpu = "xla" in device_type if device_type is not None else False
    mode, _ = _check_arg(on_tpu, amp_mode, None)
    on_mps = "mps" in device_type if device_type is not None else False
    mode, _ = _check_arg(on_tpu, on_mps, amp_mode, None)

    metrics = metrics or {}
    if mode == "amp":
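The practical effect of the new on_mps argument to _check_arg: requesting AMP together with an MPS device now fails fast with a clear error instead of crashing later inside autocast. A minimal sketch of the new behavior (the model, optimizer, and loss below are throwaway placeholders):

import torch
import torch.nn as nn
from ignite.engine import create_supervised_trainer

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

# _check_arg runs before any device work, so this raises even on
# machines without MPS hardware.
try:
    create_supervised_trainer(model, optimizer, loss_fn, device="mps", amp_mode="amp")
except ValueError as err:
    print(err)  # "amp_mode cannot be used with mps device. ..."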
5 changes: 4 additions & 1 deletion tests/ignite/distributed/comp_models/test_base.py
@@ -1,9 +1,12 @@
import pytest
import torch

from ignite.distributed.comp_models.base import _SerialModel, ComputationModel
from ignite.distributed.comp_models.base import _SerialModel, _torch_version_gt_112, ComputationModel


@pytest.mark.skipif(
    _torch_version_gt_112 and torch.backends.mps.is_available(), reason="Temporary skip if MPS is available"
)
def test_serial_model():
    _SerialModel.create_from_backend()
    model = _SerialModel.create_from_context()
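The same three-line skipif guard recurs in the test modules below (test_auto.py, test_launcher.py, test_serial.py). If it ever graduates from a temporary skip, one possible consolidation is a shared marker; a hypothetical sketch, not part of this PR:

import pytest
import torch

from ignite.distributed.comp_models.base import _torch_version_gt_112

# Hypothetical shared marker; today the guard is repeated inline per module.
skip_if_mps_available = pytest.mark.skipif(
    _torch_version_gt_112 and torch.backends.mps.is_available(),
    reason="Temporary skip if MPS is available",
)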
4 changes: 4 additions & 0 deletions tests/ignite/distributed/test_auto.py
@@ -12,6 +12,7 @@

import ignite.distributed as idist
from ignite.distributed.auto import auto_dataloader, auto_model, auto_optim, DistributedProxySampler
from ignite.distributed.comp_models.base import _torch_version_gt_112


class DummyDS(Dataset):
@@ -179,6 +180,9 @@ def _test_auto_model_optimizer(ws, device):
    assert optimizer.backward_passes_per_step == backward_passes_per_step


@pytest.mark.skipif(
    _torch_version_gt_112 and torch.backends.mps.is_available(), reason="Temporary skip if MPS is available"
)
def test_auto_methods_no_dist():
    _test_auto_dataloader(1, 1, batch_size=1)
    _test_auto_dataloader(1, 1, batch_size=10, num_workers=2)
4 changes: 4 additions & 0 deletions tests/ignite/distributed/test_launcher.py
@@ -8,6 +8,7 @@
from packaging.version import Version

import ignite.distributed as idist
from ignite.distributed.comp_models.base import _torch_version_gt_112
from ignite.distributed.utils import has_hvd_support, has_native_dist_support, has_xla_support


@@ -257,6 +258,9 @@ def test_idist_parallel_n_procs_native(init_method, backend, get_fixed_dirname,


@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
@pytest.mark.skipif(
    _torch_version_gt_112 and torch.backends.mps.is_available(), reason="Temporary skip if MPS is available"
)
def test_idist_parallel_no_dist():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with idist.Parallel(backend=None) as parallel:
5 changes: 5 additions & 0 deletions tests/ignite/distributed/utils/test_serial.py
@@ -1,6 +1,8 @@
import pytest
import torch

import ignite.distributed as idist
from ignite.distributed.comp_models.base import _torch_version_gt_112
from tests.ignite.distributed.utils import (
    _sanity_check,
    _test_distrib__get_max_length,
@@ -13,6 +15,9 @@
)


@pytest.mark.skipif(
    _torch_version_gt_112 and torch.backends.mps.is_available(), reason="Temporary skip if MPS is available"
)
def test_no_distrib(capsys):
    assert idist.backend() is None
    if torch.cuda.is_available():
40 changes: 36 additions & 4 deletions tests/ignite/engine/test_create_supervised.py
@@ -12,6 +12,7 @@
from torch.optim import SGD

import ignite.distributed as idist
from ignite.distributed.comp_models.base import _torch_version_gt_112
from ignite.engine import (
    _check_arg,
    create_supervised_evaluator,
@@ -196,7 +197,8 @@ def _test_create_mocked_supervised_trainer(
    data = [(x, y)]

    on_tpu = "xla" in trainer_device if trainer_device is not None else False
    mode, _ = _check_arg(on_tpu, amp_mode, scaler)
    on_mps = "mps" in trainer_device if trainer_device is not None else False
    mode, _ = _check_arg(on_tpu, on_mps, amp_mode, scaler)

    if model_device == trainer_device or ((model_device == "cpu") ^ (trainer_device == "cpu")):
        trainer.run(data)
@@ -306,7 +308,9 @@ def _test_create_supervised_evaluator(
    else:
        if Version(torch.__version__) >= Version("1.7.0"):
            # This is broken in 1.6.0 but will probably be fixed in 1.7.0
            with pytest.raises(RuntimeError, match=r"Expected all tensors to be on the same device"):
            err_msg_1 = "Expected all tensors to be on the same device"
            err_msg_2 = "Placeholder storage has not been allocated on MPS device"
            with pytest.raises(RuntimeError, match=f"({err_msg_1}|{err_msg_2})"):
                evaluator.run(data)


@@ -358,7 +362,8 @@ def _test_create_evaluation_step_amp(

    device_type = evaluator_device.type if isinstance(evaluator_device, torch.device) else evaluator_device
    on_tpu = "xla" in device_type if device_type is not None else False
    mode, _ = _check_arg(on_tpu, amp_mode, None)
    on_mps = "mps" in device_type if device_type is not None else False
    mode, _ = _check_arg(on_tpu, on_mps, amp_mode, None)

    evaluate_step = supervised_evaluation_step_amp(model, evaluator_device, output_transform=output_transform_mock)

@@ -393,7 +398,8 @@ def _test_create_evaluation_step(

    device_type = evaluator_device.type if isinstance(evaluator_device, torch.device) else evaluator_device
    on_tpu = "xla" in device_type if device_type is not None else False
    mode, _ = _check_arg(on_tpu, amp_mode, None)
    on_mps = "mps" in device_type if device_type is not None else False
    mode, _ = _check_arg(on_tpu, on_mps, amp_mode, None)

    evaluate_step = supervised_evaluation_step(model, evaluator_device, output_transform=output_transform_mock)

@@ -475,6 +481,19 @@ def test_create_supervised_trainer_on_cuda():
    _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device)


@pytest.mark.skipif(not (_torch_version_gt_112 and torch.backends.mps.is_available()), reason="Skip if no MPS")
def test_create_supervised_trainer_on_mps():
    model_device = trainer_device = "mps"
    _test_create_supervised_trainer_wrong_accumulation(model_device=model_device, trainer_device=trainer_device)
    _test_create_supervised_trainer(
        gradient_accumulation_steps=1, model_device=model_device, trainer_device=trainer_device
    )
    _test_create_supervised_trainer(
        gradient_accumulation_steps=3, model_device=model_device, trainer_device=trainer_device
    )
    _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device)


@pytest.mark.skipif(Version(torch.__version__) < Version("1.6.0"), reason="Skip if < 1.6.0")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
def test_create_supervised_trainer_on_cuda_amp():
@@ -643,6 +662,19 @@ def test_create_supervised_evaluator_on_cuda_with_model_on_cpu():
    _test_create_supervised_evaluator(evaluator_device="cuda")
    _test_mocked_supervised_evaluator(evaluator_device="cuda")


@pytest.mark.skipif(not (_torch_version_gt_112 and torch.backends.mps.is_available()), reason="Skip if no MPS")
def test_create_supervised_evaluator_on_mps():
    model_device = evaluator_device = "mps"
    _test_create_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device)
    _test_mocked_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device)


@pytest.mark.skipif(not (_torch_version_gt_112 and torch.backends.mps.is_available()), reason="Skip if no MPS")
def test_create_supervised_evaluator_on_mps_with_model_on_cpu():
    _test_create_supervised_evaluator(evaluator_device="mps")
    _test_mocked_supervised_evaluator(evaluator_device="mps")

@pytest.mark.skipif(Version(torch.__version__) < Version("1.6.0"), reason="Skip if < 1.6.0")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
def test_create_supervised_evaluator_on_cuda_amp():
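Taken together, a minimal end-to-end use of the new device path might look like the sketch below. This assumes a torch > 1.12 install with MPS support and omits data loading; it is an illustration of the factories touched by this PR, not code from the diff.

import torch
import torch.nn as nn
from ignite.engine import create_supervised_evaluator, create_supervised_trainer

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

model = nn.Linear(4, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

# amp_mode is left at None: with this PR, amp_mode + mps raises ValueError.
trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device)
evaluator = create_supervised_evaluator(model, device=device)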