Skip to content

Commit

Permalink
Use exact comparison for bool in approx()
Browse files Browse the repository at this point in the history
  • Loading branch information
jvansanten committed Nov 30, 2021
1 parent fa240b0 commit f119657
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 11 deletions.
1 change: 1 addition & 0 deletions changelog/9353.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`approx()` now uses strict equality when `type(expected) == bool`.
33 changes: 23 additions & 10 deletions src/_pytest/python_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,16 @@ def _repr_compare(self, other_side: Mapping[object, float]) -> List[str]:
max_abs_diff = max(
max_abs_diff, abs(approx_value.expected - other_value)
)
max_rel_diff = max(
max_rel_diff,
abs((approx_value.expected - other_value) / approx_value.expected),
)
try:
max_rel_diff = max(
max_rel_diff,
abs(
(approx_value.expected - other_value)
/ approx_value.expected
),
)
except ZeroDivisionError:
pass
different_ids.append(approx_key)

message_data = [
Expand Down Expand Up @@ -395,8 +401,12 @@ def __repr__(self) -> str:
# Don't show a tolerance for values that aren't compared using
# tolerances, i.e. non-numerics and infinities. Need to call abs to
# handle complex numbers, e.g. (inf + 1j).
if (not isinstance(self.expected, (Complex, Decimal))) or math.isinf(
abs(self.expected) # type: ignore[arg-type]
if (
isinstance(self.expected, bool)
or (not isinstance(self.expected, (Complex, Decimal)))
or math.isinf(
abs(self.expected) or isinstance(self.expected, bool) # type: ignore[arg-type]
)
):
return str(self.expected)

Expand Down Expand Up @@ -424,17 +434,20 @@ def __eq__(self, actual) -> bool:
# numpy<1.13. See #3748.
return all(self.__eq__(a) for a in asarray.flat)

# Short-circuit exact equality.
if actual == self.expected:
# Short-circuit exact equality, except for bool
if isinstance(self.expected, bool) and not isinstance(actual, bool):
return False
elif actual == self.expected:
return True

# If either type is non-numeric, fall back to strict equality.
# NB: we need Complex, rather than just Number, to ensure that __abs__,
# __sub__, and __float__ are defined.
# __sub__, and __float__ are defined. Also, consider bool to be
# nonnumeric, even though it has the required arithmetic.
if not (
isinstance(self.expected, (Complex, Decimal))
and isinstance(actual, (Complex, Decimal))
):
) or isinstance(self.expected, bool):
return False

# Allow the user to control whether NaNs are considered equal to each
Expand Down
23 changes: 22 additions & 1 deletion testing/python/approx.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,27 @@ def do_assert(lhs, rhs, expected_message, verbosity_level=0):
return do_assert


SOME_FLOAT = r"[+-]?([0-9]*[.])?[0-9]+\s*"
SOME_FLOAT = r"[+-]?((?:([0-9]*[.])?[0-9]+(e-?[0-9]+)?)|inf|nan)\s*"
SOME_INT = r"[0-9]+\s*"


class TestApprox:
def test_error_messages(self, assert_approx_raises_regex):
np = pytest.importorskip("numpy")

# treat bool exactly
assert_approx_raises_regex(
{"a": 1.0, "b": True},
{"a": 1.0, "b": False},
[
" comparison failed. Mismatched elements: 1 / 2:",
f" Max absolute difference: {SOME_FLOAT}",
f" Max relative difference: {SOME_FLOAT}",
r" Index\s+\| Obtained\s+\| Expected",
r".*(True|False)\s+",
],
)

assert_approx_raises_regex(
2.0,
1.0,
Expand Down Expand Up @@ -546,6 +559,13 @@ def test_complex(self):
assert approx(x, rel=5e-6, abs=0) == a
assert approx(x, rel=5e-7, abs=0) != a

def test_bool(self):
assert True == approx(True)
assert False == approx(False)
assert True != approx(False)
assert True != approx(False, abs=2)
assert 1 != approx(True)

def test_list(self):
actual = [1 + 1e-7, 2 + 1e-8]
expected = [1, 2]
Expand Down Expand Up @@ -611,6 +631,7 @@ def test_dict_wrong_len(self):
def test_dict_nonnumeric(self):
assert {"a": 1.0, "b": None} == pytest.approx({"a": 1.0, "b": None})
assert {"a": 1.0, "b": 1} != pytest.approx({"a": 1.0, "b": None})
assert {"a": 1.0, "b": True} != pytest.approx({"a": 1.0, "b": False}, abs=2)

def test_dict_vs_other(self):
assert 1 != approx({"a": 0})
Expand Down

0 comments on commit f119657

Please sign in to comment.