diff --git a/src/syrupy/assertion.py b/src/syrupy/assertion.py index d05cb38c..a5d1ef25 100644 --- a/src/syrupy/assertion.py +++ b/src/syrupy/assertion.py @@ -1,6 +1,7 @@ from gettext import gettext from typing import ( TYPE_CHECKING, + Callable, Dict, List, Optional, @@ -44,10 +45,9 @@ class SnapshotAssertion: _test_location: "TestLocation" = attr.ib(kw_only=True) _update_snapshots: bool = attr.ib(kw_only=True) _extension: Optional["AbstractSyrupyExtension"] = attr.ib(init=False, default=None) - _executions: int = attr.ib(init=False, default=0, kw_only=True) - _execution_results: Dict[int, "AssertionResult"] = attr.ib( - init=False, factory=dict, kw_only=True - ) + _executions: int = attr.ib(init=False, default=0) + _execution_results: Dict[int, "AssertionResult"] = attr.ib(init=False, factory=dict) + _post_assert_actions: List[Callable[..., None]] = attr.ib(init=False, factory=list) def __attrs_post_init__(self) -> None: self._session.register_request(self) @@ -108,6 +108,11 @@ def __call__( """ if extension_class: self._extension = self.__init_extension(extension_class) + + def clear_extension() -> None: + self._extension = None + + self._post_assert_actions.append(clear_extension) return self def __repr__(self) -> str: @@ -155,7 +160,8 @@ def _post_assert(self) -> None: """ Restores assertion instance options """ - self._extension = None + while self._post_assert_actions: + self._post_assert_actions.pop()() def _recall_data(self, index: int) -> Optional["SerializableData"]: try: diff --git a/tests/test_extension_image.py b/tests/test_extension_image.py index af67e103..f538103c 100644 --- a/tests/test_extension_image.py +++ b/tests/test_extension_image.py @@ -51,5 +51,6 @@ def test_multiple_snapshot_extensions(snapshot): """ assert actual_svg == snapshot(extension_class=SVGImageSnapshotExtension) assert actual_svg == snapshot # uses initial extension class + assert snapshot._extension is not None assert actual_png == snapshot(extension_class=PNGImageSnapshotExtension) assert actual_svg == snapshot(extension_class=SVGImageSnapshotExtension)