Skip to content

Commit 7fd8f96

Browse files
author
ibriquem
committed
Make dataclasses/attrs comparison recursive, fixes pytest-dev#4675
1 parent 9214e63 commit 7fd8f96

File tree

4 files changed

+144
-21
lines changed

4 files changed

+144
-21
lines changed

changelog/4675.bugfix.rst

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Make dataclasses/attrs comparison recursive.

src/_pytest/assertion/util.py

+27-21
Original file line numberDiff line numberDiff line change
@@ -148,26 +148,7 @@ def assertrepr_compare(config, op: str, left: Any, right: Any) -> Optional[List[
148148
explanation = None
149149
try:
150150
if op == "==":
151-
if istext(left) and istext(right):
152-
explanation = _diff_text(left, right, verbose)
153-
else:
154-
if issequence(left) and issequence(right):
155-
explanation = _compare_eq_sequence(left, right, verbose)
156-
elif isset(left) and isset(right):
157-
explanation = _compare_eq_set(left, right, verbose)
158-
elif isdict(left) and isdict(right):
159-
explanation = _compare_eq_dict(left, right, verbose)
160-
elif type(left) == type(right) and (isdatacls(left) or isattrs(left)):
161-
type_fn = (isdatacls, isattrs)
162-
explanation = _compare_eq_cls(left, right, verbose, type_fn)
163-
elif verbose > 0:
164-
explanation = _compare_eq_verbose(left, right)
165-
if isiterable(left) and isiterable(right):
166-
expl = _compare_eq_iterable(left, right, verbose)
167-
if explanation is not None:
168-
explanation.extend(expl)
169-
else:
170-
explanation = expl
151+
explanation = _compare_eq_any(left, right, verbose)
171152
elif op == "not in":
172153
if istext(left) and istext(right):
173154
explanation = _notin_text(left, right, verbose)
@@ -187,6 +168,28 @@ def assertrepr_compare(config, op: str, left: Any, right: Any) -> Optional[List[
187168
return [summary] + explanation
188169

189170

171+
def _compare_eq_any(left: Any, right: Any, verbose: int = 0) -> List[str]:
172+
explanation = [] # type: List[str]
173+
if istext(left) and istext(right):
174+
explanation = _diff_text(left, right, verbose)
175+
else:
176+
if issequence(left) and issequence(right):
177+
explanation = _compare_eq_sequence(left, right, verbose)
178+
elif isset(left) and isset(right):
179+
explanation = _compare_eq_set(left, right, verbose)
180+
elif isdict(left) and isdict(right):
181+
explanation = _compare_eq_dict(left, right, verbose)
182+
elif type(left) == type(right) and (isdatacls(left) or isattrs(left)):
183+
type_fn = (isdatacls, isattrs)
184+
explanation = _compare_eq_cls(left, right, verbose, type_fn)
185+
elif verbose > 0:
186+
explanation = _compare_eq_verbose(left, right)
187+
if isiterable(left) and isiterable(right):
188+
expl = _compare_eq_iterable(left, right, verbose)
189+
explanation.extend(expl)
190+
return explanation
191+
192+
190193
def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
191194
"""Return the explanation for the diff between text.
192195
@@ -439,7 +442,10 @@ def _compare_eq_cls(
439442
explanation += ["Differing attributes:"]
440443
for field in diff:
441444
explanation += [
442-
("%s: %r != %r") % (field, getattr(left, field), getattr(right, field))
445+
("%s: %r != %r") % (field, getattr(left, field), getattr(right, field)),
446+
"",
447+
"Recursive from previous comparison",
448+
*_compare_eq_any(getattr(left, field), getattr(right, field), verbose),
443449
]
444450
return explanation
445451

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from dataclasses import dataclass
2+
from dataclasses import field
3+
4+
5+
@dataclass
6+
class SimpleDataObject:
7+
field_a: int = field()
8+
field_b: int = field()
9+
10+
11+
@dataclass
12+
class ComplexDataObject2:
13+
field_a: SimpleDataObject = field()
14+
field_b: SimpleDataObject = field()
15+
16+
17+
@dataclass
18+
class ComplexDataObject:
19+
field_a: SimpleDataObject = field()
20+
field_b: ComplexDataObject2 = field()
21+
22+
23+
def test_recursive_dataclasses():
24+
25+
left = ComplexDataObject(
26+
SimpleDataObject(1, "b"),
27+
ComplexDataObject2(SimpleDataObject(1, "b"), SimpleDataObject(2, "c"),),
28+
)
29+
right = ComplexDataObject(
30+
SimpleDataObject(1, "b"),
31+
ComplexDataObject2(SimpleDataObject(1, "b"), SimpleDataObject(3, "c"),),
32+
)
33+
34+
assert left == right

testing/test_assertion.py

+82
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,50 @@ def test_dataclasses(self, testdir):
756756
"*Omitting 1 identical items, use -vv to show*",
757757
"*Differing attributes:*",
758758
"*field_b: 'b' != 'c'*",
759+
"*- c*",
760+
"*+ b*",
761+
]
762+
)
763+
764+
@pytest.mark.skipif(sys.version_info < (3, 7), reason="Dataclasses in Python3.7+")
765+
def test_recursive_dataclasses(self, testdir):
766+
p = testdir.copy_example("dataclasses/test_compare_recursive_dataclasses.py")
767+
result = testdir.runpytest(p)
768+
result.assert_outcomes(failed=1, passed=0)
769+
result.stdout.fnmatch_lines(
770+
[
771+
"*Omitting 1 identical items, use -vv to show*",
772+
"*Differing attributes:*",
773+
"*field_b: ComplexDataObject2(*SimpleDataObject(field_a=2, field_b='c')) "
774+
"!= ComplexDataObject2(*SimpleDataObject(field_a=3, field_b='c'))*",
775+
"*Recursive from previous comparison*",
776+
"*Omitting 1 identical items, use -vv to show*",
777+
"*Differing attributes:*",
778+
"*Full output truncated*",
779+
]
780+
)
781+
782+
@pytest.mark.skipif(sys.version_info < (3, 7), reason="Dataclasses in Python3.7+")
783+
def test_recursive_dataclasses_verbose(self, testdir):
784+
p = testdir.copy_example("dataclasses/test_compare_recursive_dataclasses.py")
785+
result = testdir.runpytest(p, "-vv")
786+
result.assert_outcomes(failed=1, passed=0)
787+
result.stdout.fnmatch_lines(
788+
[
789+
"*Matching attributes:*",
790+
"*['field_a']*",
791+
"*Differing attributes:*",
792+
"*field_b: ComplexDataObject2(*SimpleDataObject(field_a=2, field_b='c')) != "
793+
"ComplexDataObject2(*SimpleDataObject(field_a=3, field_b='c'))*",
794+
"*Matching attributes:*",
795+
"*['field_a']*",
796+
"*Differing attributes:*",
797+
"*field_b: SimpleDataObject(field_a=2, field_b='c') "
798+
"!= SimpleDataObject(field_a=3, field_b='c')*",
799+
"*Matching attributes:*",
800+
"*['field_b']*",
801+
"*Differing attributes:*",
802+
"*field_a: 2 != 3",
759803
]
760804
)
761805

@@ -806,6 +850,44 @@ class SimpleDataObject:
806850
for line in lines[1:]:
807851
assert "field_a" not in line
808852

853+
def test_attrs_recursive(self):
854+
@attr.s
855+
class OtherDataObject:
856+
field_c = attr.ib()
857+
field_d = attr.ib()
858+
859+
@attr.s
860+
class SimpleDataObject:
861+
field_a = attr.ib()
862+
field_b = attr.ib()
863+
864+
left = SimpleDataObject(OtherDataObject(1, "a"), "b")
865+
right = SimpleDataObject(OtherDataObject(1, "b"), "b")
866+
867+
lines = callequal(left, right)
868+
assert "Matching attributes" not in lines
869+
for line in lines[1:]:
870+
assert "field_b:" not in line
871+
assert "field_c:" not in line
872+
873+
def test_attrs_recursive_verbose(self):
874+
@attr.s
875+
class OtherDataObject:
876+
field_c = attr.ib()
877+
field_d = attr.ib()
878+
879+
@attr.s
880+
class SimpleDataObject:
881+
field_a = attr.ib()
882+
field_b = attr.ib()
883+
884+
left = SimpleDataObject(OtherDataObject(1, "a"), "b")
885+
right = SimpleDataObject(OtherDataObject(1, "b"), "b")
886+
887+
lines = callequal(left, right)
888+
assert "field_d: 'a' != 'b'" in lines
889+
print("\n".join(lines))
890+
809891
def test_attrs_verbose(self):
810892
@attr.s
811893
class SimpleDataObject:

0 commit comments

Comments
 (0)