From dc381d491c66e147a0b784eab4231548985b5a52 Mon Sep 17 00:00:00 2001 From: Emmanuel Ogbizi Date: Fri, 3 Apr 2020 09:53:00 -0400 Subject: [PATCH 1/2] test: extension is not cleared when not overridden --- tests/test_extension_image.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_extension_image.py b/tests/test_extension_image.py index af67e103..f538103c 100644 --- a/tests/test_extension_image.py +++ b/tests/test_extension_image.py @@ -51,5 +51,6 @@ def test_multiple_snapshot_extensions(snapshot): """ assert actual_svg == snapshot(extension_class=SVGImageSnapshotExtension) assert actual_svg == snapshot # uses initial extension class + assert snapshot._extension is not None assert actual_png == snapshot(extension_class=PNGImageSnapshotExtension) assert actual_svg == snapshot(extension_class=SVGImageSnapshotExtension) From 1548265ad462a873b850905a6c0ca3d66611d018 Mon Sep 17 00:00:00 2001 From: Emmanuel Ogbizi Date: Fri, 3 Apr 2020 09:53:41 -0400 Subject: [PATCH 2/2] perf: only clear extension when overridden --- src/syrupy/assertion.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/syrupy/assertion.py b/src/syrupy/assertion.py index d05cb38c..a5d1ef25 100644 --- a/src/syrupy/assertion.py +++ b/src/syrupy/assertion.py @@ -1,6 +1,7 @@ from gettext import gettext from typing import ( TYPE_CHECKING, + Callable, Dict, List, Optional, @@ -44,10 +45,9 @@ class SnapshotAssertion: _test_location: "TestLocation" = attr.ib(kw_only=True) _update_snapshots: bool = attr.ib(kw_only=True) _extension: Optional["AbstractSyrupyExtension"] = attr.ib(init=False, default=None) - _executions: int = attr.ib(init=False, default=0, kw_only=True) - _execution_results: Dict[int, "AssertionResult"] = attr.ib( - init=False, factory=dict, kw_only=True - ) + _executions: int = attr.ib(init=False, default=0) + _execution_results: Dict[int, "AssertionResult"] = attr.ib(init=False, factory=dict) + _post_assert_actions: List[Callable[..., None]] = attr.ib(init=False, factory=list) def __attrs_post_init__(self) -> None: self._session.register_request(self) @@ -108,6 +108,11 @@ def __call__( """ if extension_class: self._extension = self.__init_extension(extension_class) + + def clear_extension() -> None: + self._extension = None + + self._post_assert_actions.append(clear_extension) return self def __repr__(self) -> str: @@ -155,7 +160,8 @@ def _post_assert(self) -> None: """ Restores assertion instance options """ - self._extension = None + while self._post_assert_actions: + self._post_assert_actions.pop()() def _recall_data(self, index: int) -> Optional["SerializableData"]: try: