Skip to content

Commit 86d6804

Browse files
authored
Merge pull request #3313 from tadeu/approx-array-scalar
Add support for `pytest.approx` comparisons between array and scalar
2 parents cbb2c55 + a754f00 commit 86d6804

File tree

3 files changed

+57
-18
lines changed

3 files changed

+57
-18
lines changed

_pytest/python_api.py

+34-18
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ class ApproxBase(object):
3131
or sequences of numbers.
3232
"""
3333

34+
# Tell numpy to use our `__eq__` operator instead of its
35+
__array_ufunc__ = None
36+
__array_priority__ = 100
37+
3438
def __init__(self, expected, rel=None, abs=None, nan_ok=False):
3539
self.expected = expected
3640
self.abs = abs
@@ -69,39 +73,46 @@ class ApproxNumpy(ApproxBase):
6973
Perform approximate comparisons for numpy arrays.
7074
"""
7175

72-
# Tell numpy to use our `__eq__` operator instead of its.
73-
__array_priority__ = 100
74-
7576
def __repr__(self):
7677
# It might be nice to rewrite this function to account for the
7778
# shape of the array...
79+
import numpy as np
80+
7881
return "approx({0!r})".format(list(
79-
self._approx_scalar(x) for x in self.expected))
82+
self._approx_scalar(x) for x in np.asarray(self.expected)))
8083

8184
if sys.version_info[0] == 2:
8285
__cmp__ = _cmp_raises_type_error
8386

8487
def __eq__(self, actual):
8588
import numpy as np
8689

87-
try:
88-
actual = np.asarray(actual)
89-
except: # noqa
90-
raise TypeError("cannot compare '{0}' to numpy.ndarray".format(actual))
90+
# self.expected is supposed to always be an array here
9191

92-
if actual.shape != self.expected.shape:
92+
if not np.isscalar(actual):
93+
try:
94+
actual = np.asarray(actual)
95+
except: # noqa
96+
raise TypeError("cannot compare '{0}' to numpy.ndarray".format(actual))
97+
98+
if not np.isscalar(actual) and actual.shape != self.expected.shape:
9399
return False
94100

95101
return ApproxBase.__eq__(self, actual)
96102

97103
def _yield_comparisons(self, actual):
98104
import numpy as np
99105

100-
# We can be sure that `actual` is a numpy array, because it's
101-
# casted in `__eq__` before being passed to `ApproxBase.__eq__`,
102-
# which is the only method that calls this one.
103-
for i in np.ndindex(self.expected.shape):
104-
yield actual[i], self.expected[i]
106+
# `actual` can either be a numpy array or a scalar, it is treated in
107+
# `__eq__` before being passed to `ApproxBase.__eq__`, which is the
108+
# only method that calls this one.
109+
110+
if np.isscalar(actual):
111+
for i in np.ndindex(self.expected.shape):
112+
yield actual, np.asscalar(self.expected[i])
113+
else:
114+
for i in np.ndindex(self.expected.shape):
115+
yield np.asscalar(actual[i]), np.asscalar(self.expected[i])
105116

106117

107118
class ApproxMapping(ApproxBase):
@@ -131,9 +142,6 @@ class ApproxSequence(ApproxBase):
131142
Perform approximate comparisons for sequences of numbers.
132143
"""
133144

134-
# Tell numpy to use our `__eq__` operator instead of its.
135-
__array_priority__ = 100
136-
137145
def __repr__(self):
138146
seq_type = type(self.expected)
139147
if seq_type not in (tuple, list, set):
@@ -189,6 +197,8 @@ def __eq__(self, actual):
189197
Return true if the given value is equal to the expected value within
190198
the pre-specified tolerance.
191199
"""
200+
if _is_numpy_array(actual):
201+
return ApproxNumpy(actual, self.abs, self.rel, self.nan_ok) == self.expected
192202

193203
# Short-circuit exact equality.
194204
if actual == self.expected:
@@ -308,12 +318,18 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
308318
>>> {'a': 0.1 + 0.2, 'b': 0.2 + 0.4} == approx({'a': 0.3, 'b': 0.6})
309319
True
310320
311-
And ``numpy`` arrays::
321+
``numpy`` arrays::
312322
313323
>>> import numpy as np # doctest: +SKIP
314324
>>> np.array([0.1, 0.2]) + np.array([0.2, 0.4]) == approx(np.array([0.3, 0.6])) # doctest: +SKIP
315325
True
316326
327+
And for a ``numpy`` array against a scalar::
328+
329+
>>> import numpy as np # doctest: +SKIP
330+
>>> np.array([0.1, 0.2]) + np.array([0.2, 0.1]) == approx(0.3) # doctest: +SKIP
331+
True
332+
317333
By default, ``approx`` considers numbers within a relative tolerance of
318334
``1e-6`` (i.e. one part in a million) of its expected value to be equal.
319335
This treatment would lead to surprising results if the expected value was

changelog/3312.feature

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
``pytest.approx`` now accepts comparing a numpy array with a scalar.

testing/python/approx.py

+22
Original file line numberDiff line numberDiff line change
@@ -391,3 +391,25 @@ def test_comparison_operator_type_error(self, op):
391391
"""
392392
with pytest.raises(TypeError):
393393
op(1, approx(1, rel=1e-6, abs=1e-12))
394+
395+
def test_numpy_array_with_scalar(self):
396+
np = pytest.importorskip('numpy')
397+
398+
actual = np.array([1 + 1e-7, 1 - 1e-8])
399+
expected = 1.0
400+
401+
assert actual == approx(expected, rel=5e-7, abs=0)
402+
assert actual != approx(expected, rel=5e-8, abs=0)
403+
assert approx(expected, rel=5e-7, abs=0) == actual
404+
assert approx(expected, rel=5e-8, abs=0) != actual
405+
406+
def test_numpy_scalar_with_array(self):
407+
np = pytest.importorskip('numpy')
408+
409+
actual = 1.0
410+
expected = np.array([1 + 1e-7, 1 - 1e-8])
411+
412+
assert actual == approx(expected, rel=5e-7, abs=0)
413+
assert actual != approx(expected, rel=5e-8, abs=0)
414+
assert approx(expected, rel=5e-7, abs=0) == actual
415+
assert approx(expected, rel=5e-8, abs=0) != actual

0 commit comments

Comments
 (0)