Skip to content

Commit d19f8c6

Browse files
author
ibriquem
committed
Make dataclasses/attrs comparison recursive, fixes pytest-dev#4675
1 parent f77d606 commit d19f8c6

File tree

4 files changed

+66
-21
lines changed

4 files changed

+66
-21
lines changed

changelog/4675.bugfix.rst

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Make dataclasses/attrs comparison recursive.

src/_pytest/assertion/util.py

+25-21
Original file line numberDiff line numberDiff line change
@@ -148,26 +148,7 @@ def assertrepr_compare(config, op: str, left: Any, right: Any) -> Optional[List[
148148
explanation = None
149149
try:
150150
if op == "==":
151-
if istext(left) and istext(right):
152-
explanation = _diff_text(left, right, verbose)
153-
else:
154-
if issequence(left) and issequence(right):
155-
explanation = _compare_eq_sequence(left, right, verbose)
156-
elif isset(left) and isset(right):
157-
explanation = _compare_eq_set(left, right, verbose)
158-
elif isdict(left) and isdict(right):
159-
explanation = _compare_eq_dict(left, right, verbose)
160-
elif type(left) == type(right) and (isdatacls(left) or isattrs(left)):
161-
type_fn = (isdatacls, isattrs)
162-
explanation = _compare_eq_cls(left, right, verbose, type_fn)
163-
elif verbose > 0:
164-
explanation = _compare_eq_verbose(left, right)
165-
if isiterable(left) and isiterable(right):
166-
expl = _compare_eq_iterable(left, right, verbose)
167-
if explanation is not None:
168-
explanation.extend(expl)
169-
else:
170-
explanation = expl
151+
explanation = _compare_eq_any(left, right, verbose)
171152
elif op == "not in":
172153
if istext(left) and istext(right):
173154
explanation = _notin_text(left, right, verbose)
@@ -187,6 +168,28 @@ def assertrepr_compare(config, op: str, left: Any, right: Any) -> Optional[List[
187168
return [summary] + explanation
188169

189170

171+
def _compare_eq_any(left: Any, right: Any, verbose: int = 0) -> List[str]:
172+
explanation = [] # type: List[str]
173+
if istext(left) and istext(right):
174+
explanation = _diff_text(left, right, verbose)
175+
else:
176+
if issequence(left) and issequence(right):
177+
explanation = _compare_eq_sequence(left, right, verbose)
178+
elif isset(left) and isset(right):
179+
explanation = _compare_eq_set(left, right, verbose)
180+
elif isdict(left) and isdict(right):
181+
explanation = _compare_eq_dict(left, right, verbose)
182+
elif type(left) == type(right) and (isdatacls(left) or isattrs(left)):
183+
type_fn = (isdatacls, isattrs)
184+
explanation = _compare_eq_cls(left, right, verbose, type_fn)
185+
elif verbose > 0:
186+
explanation = _compare_eq_verbose(left, right)
187+
if isiterable(left) and isiterable(right):
188+
expl = _compare_eq_iterable(left, right, verbose)
189+
explanation.extend(expl)
190+
return explanation
191+
192+
190193
def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
191194
"""Return the explanation for the diff between text.
192195
@@ -439,7 +442,8 @@ def _compare_eq_cls(
439442
explanation += ["Differing attributes:"]
440443
for field in diff:
441444
explanation += [
442-
("%s: %r != %r") % (field, getattr(left, field), getattr(right, field))
445+
("%s: %r != %r") % (field, getattr(left, field), getattr(right, field)),
446+
*_compare_eq_any(getattr(left, field), getattr(right, field), verbose),
443447
]
444448
return explanation
445449

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from dataclasses import dataclass
2+
from dataclasses import field
3+
4+
5+
@dataclass
6+
class SimpleDataObject:
7+
field_a: int = field()
8+
field_b: int = field()
9+
10+
11+
@dataclass
12+
class ComplexDataObject:
13+
field_a: SimpleDataObject = field()
14+
field_b: SimpleDataObject = field()
15+
16+
17+
def test_recursive_dataclasses():
18+
19+
left = ComplexDataObject(SimpleDataObject(1, "b"), SimpleDataObject(2, "c"),)
20+
right = ComplexDataObject(SimpleDataObject(1, "b"), SimpleDataObject(3, "c"),)
21+
22+
assert left == right

testing/test_assertion.py

+18
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,24 @@ def test_dataclasses(self, testdir):
752752
"*Omitting 1 identical items, use -vv to show*",
753753
"*Differing attributes:*",
754754
"*field_b: 'b' != 'c'*",
755+
"*- c*",
756+
"*+ b*",
757+
]
758+
)
759+
760+
@pytest.mark.skipif(sys.version_info < (3, 7), reason="Dataclasses in Python3.7+")
761+
def test_recursive_dataclasses(self, testdir):
762+
p = testdir.copy_example("dataclasses/test_compare_recursive_dataclasses.py")
763+
result = testdir.runpytest(p)
764+
result.assert_outcomes(failed=1, passed=0)
765+
result.stdout.fnmatch_lines(
766+
[
767+
"*Omitting 1 identical items, use -vv to show*",
768+
"*Differing attributes:*",
769+
"*field_b: SimpleDataObject(field_a=2, field_b='c') != SimpleDataObject(field_a=3, field_b='c')*",
770+
"*Omitting 1 identical items, use -vv to show*",
771+
"*Differing attributes:*",
772+
"*field_a: 2 != 3*",
755773
]
756774
)
757775

0 commit comments

Comments
 (0)