-
-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Miscellaneous improvements to approx() #3741
Merged
Merged
Changes from 15 commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
7d8688d
Reflect dimension in approx repr for numpy arrays.
kalekundert ad305e7
Improve docstrings for Approx classes.
kalekundert cd2085e
approx(): Detect type errors earlier.
kalekundert 032db15
Let black reformat the code...
kalekundert d024919
Fix the unused import.
kalekundert 327fe4c
Update the changelog.
kalekundert bf7c188
Improve error message for invalid types passed to pytest.approx()
nicoddemus 8e2ed76
Create appropriate CHANGELOG entries
nicoddemus 098dca3
Use {!r} for a few other messages as well
nicoddemus 611d254
Improve error checking messages: add position and use pprint
nicoddemus 6e32a1f
Use parametrize in repr test for nd arrays
nicoddemus 5003bae
Fix 'at' string for non-numeric messages in approx()
nicoddemus ad5ddaf
Simplify is_numpy_array as suggested in review
nicoddemus 2a2f888
Move recursive_map from local to free function
nicoddemus 43664d7
Use ids for parametrized values in test_expected_value_type_error
nicoddemus a5c0fb7
Rename recursive_map -> _recursive_list_map as requested in review
nicoddemus File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Raise immediately if ``approx()`` is given an expected value of a type it doesn't understand (e.g. strings, nested dicts, etc.). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Correctly represent the dimensions of an numpy array when calling ``repr()`` on ``approx()``. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,8 @@ | ||
import math | ||
import pprint | ||
import sys | ||
from numbers import Number | ||
from decimal import Decimal | ||
|
||
import py | ||
from six.moves import zip, filterfalse | ||
|
@@ -30,6 +33,15 @@ def _cmp_raises_type_error(self, other): | |
) | ||
|
||
|
||
def _non_numeric_type_error(value, at): | ||
at_str = " at {}".format(at) if at else "" | ||
return TypeError( | ||
"cannot make approximate comparisons to non-numeric values: {!r} {}".format( | ||
value, at_str | ||
) | ||
) | ||
|
||
|
||
# builtin pytest.approx helper | ||
|
||
|
||
|
@@ -39,15 +51,17 @@ class ApproxBase(object): | |
or sequences of numbers. | ||
""" | ||
|
||
# Tell numpy to use our `__eq__` operator instead of its | ||
# Tell numpy to use our `__eq__` operator instead of its. | ||
__array_ufunc__ = None | ||
__array_priority__ = 100 | ||
|
||
def __init__(self, expected, rel=None, abs=None, nan_ok=False): | ||
__tracebackhide__ = True | ||
self.expected = expected | ||
self.abs = abs | ||
self.rel = rel | ||
self.nan_ok = nan_ok | ||
self._check_type() | ||
|
||
def __repr__(self): | ||
raise NotImplementedError | ||
|
@@ -75,21 +89,32 @@ def _yield_comparisons(self, actual): | |
""" | ||
raise NotImplementedError | ||
|
||
def _check_type(self): | ||
""" | ||
Raise a TypeError if the expected value is not a valid type. | ||
""" | ||
# This is only a concern if the expected value is a sequence. In every | ||
# other case, the approx() function ensures that the expected value has | ||
# a numeric type. For this reason, the default is to do nothing. The | ||
# classes that deal with sequences should reimplement this method to | ||
# raise if there are any non-numeric elements in the sequence. | ||
pass | ||
|
||
|
||
def recursive_map(f, x): | ||
if isinstance(x, list): | ||
return list(recursive_map(f, xi) for xi in x) | ||
else: | ||
return f(x) | ||
|
||
|
||
class ApproxNumpy(ApproxBase): | ||
""" | ||
Perform approximate comparisons for numpy arrays. | ||
Perform approximate comparisons where the expected value is numpy array. | ||
""" | ||
|
||
def __repr__(self): | ||
# It might be nice to rewrite this function to account for the | ||
# shape of the array... | ||
import numpy as np | ||
|
||
list_scalars = [] | ||
for x in np.ndindex(self.expected.shape): | ||
list_scalars.append(self._approx_scalar(np.asscalar(self.expected[x]))) | ||
|
||
list_scalars = recursive_map(self._approx_scalar, self.expected.tolist()) | ||
return "approx({!r})".format(list_scalars) | ||
|
||
if sys.version_info[0] == 2: | ||
|
@@ -128,8 +153,8 @@ def _yield_comparisons(self, actual): | |
|
||
class ApproxMapping(ApproxBase): | ||
""" | ||
Perform approximate comparisons for mappings where the values are numbers | ||
(the keys can be anything). | ||
Perform approximate comparisons where the expected value is a mapping with | ||
numeric values (the keys can be anything). | ||
""" | ||
|
||
def __repr__(self): | ||
|
@@ -147,10 +172,20 @@ def _yield_comparisons(self, actual): | |
for k in self.expected.keys(): | ||
yield actual[k], self.expected[k] | ||
|
||
def _check_type(self): | ||
__tracebackhide__ = True | ||
for key, value in self.expected.items(): | ||
if isinstance(value, type(self.expected)): | ||
msg = "pytest.approx() does not support nested dictionaries: key={!r} value={!r}\n full mapping={}" | ||
raise TypeError(msg.format(key, value, pprint.pformat(self.expected))) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. lovely 👍 |
||
elif not isinstance(value, Number): | ||
raise _non_numeric_type_error(self.expected, at="key={!r}".format(key)) | ||
|
||
|
||
class ApproxSequence(ApproxBase): | ||
""" | ||
Perform approximate comparisons for sequences of numbers. | ||
Perform approximate comparisons where the expected value is a sequence of | ||
numbers. | ||
""" | ||
|
||
def __repr__(self): | ||
|
@@ -169,10 +204,21 @@ def __eq__(self, actual): | |
def _yield_comparisons(self, actual): | ||
return zip(actual, self.expected) | ||
|
||
def _check_type(self): | ||
__tracebackhide__ = True | ||
for index, x in enumerate(self.expected): | ||
if isinstance(x, type(self.expected)): | ||
msg = "pytest.approx() does not support nested data structures: {!r} at index {}\n full sequence: {}" | ||
raise TypeError(msg.format(x, index, pprint.pformat(self.expected))) | ||
elif not isinstance(x, Number): | ||
raise _non_numeric_type_error( | ||
self.expected, at="index {}".format(index) | ||
) | ||
|
||
|
||
class ApproxScalar(ApproxBase): | ||
""" | ||
Perform approximate comparisons for single numbers only. | ||
Perform approximate comparisons where the expected value is a single number. | ||
""" | ||
|
||
DEFAULT_ABSOLUTE_TOLERANCE = 1e-12 | ||
|
@@ -286,7 +332,9 @@ def set_default(x, default): | |
|
||
|
||
class ApproxDecimal(ApproxScalar): | ||
from decimal import Decimal | ||
""" | ||
Perform approximate comparisons where the expected value is a decimal. | ||
""" | ||
|
||
DEFAULT_ABSOLUTE_TOLERANCE = Decimal("1e-12") | ||
DEFAULT_RELATIVE_TOLERANCE = Decimal("1e-6") | ||
|
@@ -445,32 +493,35 @@ def approx(expected, rel=None, abs=None, nan_ok=False): | |
__ https://docs.python.org/3/reference/datamodel.html#object.__ge__ | ||
""" | ||
|
||
from decimal import Decimal | ||
|
||
# Delegate the comparison to a class that knows how to deal with the type | ||
# of the expected value (e.g. int, float, list, dict, numpy.array, etc). | ||
# | ||
# This architecture is really driven by the need to support numpy arrays. | ||
# The only way to override `==` for arrays without requiring that approx be | ||
# the left operand is to inherit the approx object from `numpy.ndarray`. | ||
# But that can't be a general solution, because it requires (1) numpy to be | ||
# installed and (2) the expected value to be a numpy array. So the general | ||
# solution is to delegate each type of expected value to a different class. | ||
# The primary responsibility of these classes is to implement ``__eq__()`` | ||
# and ``__repr__()``. The former is used to actually check if some | ||
# "actual" value is equivalent to the given expected value within the | ||
# allowed tolerance. The latter is used to show the user the expected | ||
# value and tolerance, in the case that a test failed. | ||
# | ||
# This has the advantage that it made it easy to support mapping types | ||
# (i.e. dict). The old code accepted mapping types, but would only compare | ||
# their keys, which is probably not what most people would expect. | ||
# The actual logic for making approximate comparisons can be found in | ||
# ApproxScalar, which is used to compare individual numbers. All of the | ||
# other Approx classes eventually delegate to this class. The ApproxBase | ||
# class provides some convenient methods and overloads, but isn't really | ||
# essential. | ||
|
||
if _is_numpy_array(expected): | ||
cls = ApproxNumpy | ||
__tracebackhide__ = True | ||
|
||
if isinstance(expected, Decimal): | ||
cls = ApproxDecimal | ||
elif isinstance(expected, Number): | ||
cls = ApproxScalar | ||
elif isinstance(expected, Mapping): | ||
cls = ApproxMapping | ||
elif isinstance(expected, Sequence) and not isinstance(expected, STRING_TYPES): | ||
cls = ApproxSequence | ||
elif isinstance(expected, Decimal): | ||
cls = ApproxDecimal | ||
elif _is_numpy_array(expected): | ||
cls = ApproxNumpy | ||
else: | ||
cls = ApproxScalar | ||
raise _non_numeric_type_error(expected, at=None) | ||
|
||
return cls(expected, rel, abs, nan_ok) | ||
|
||
|
@@ -480,17 +531,11 @@ def _is_numpy_array(obj): | |
Return true if the given object is a numpy array. Make a special effort to | ||
avoid importing numpy unless it's really necessary. | ||
""" | ||
import inspect | ||
|
||
for cls in inspect.getmro(type(obj)): | ||
if cls.__module__ == "numpy": | ||
try: | ||
import numpy as np | ||
|
||
return isinstance(obj, np.ndarray) | ||
except ImportError: | ||
pass | ||
import sys | ||
|
||
np = sys.modules.get("numpy") | ||
if np is not None: | ||
return isinstance(obj, np.ndarray) | ||
return False | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm a little hestiant about this being a module-level function, because it only works for nested lists (e.g. not tuples or any other kind of iterable). That's fine for
ApproxNumpy.__repr__()
, becausenumpy.tolist()
is guaranteed to return nested lists, but this isn't really a generally useful function.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You got a point, but I think it doesn't hurt leaving it there; it is not part of the public API after all. @RonnyPfannschmidt any thoughts?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
name it
recursive_list_map
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, renamed it
_recursive_list_map
as well to make it clear it is an internal API.