@@ -31,6 +31,10 @@ class ApproxBase(object):
31
31
or sequences of numbers.
32
32
"""
33
33
34
+ # Tell numpy to use our `__eq__` operator instead of its
35
+ __array_ufunc__ = None
36
+ __array_priority__ = 100
37
+
34
38
def __init__ (self , expected , rel = None , abs = None , nan_ok = False ):
35
39
self .expected = expected
36
40
self .abs = abs
@@ -69,39 +73,46 @@ class ApproxNumpy(ApproxBase):
69
73
Perform approximate comparisons for numpy arrays.
70
74
"""
71
75
72
- # Tell numpy to use our `__eq__` operator instead of its.
73
- __array_priority__ = 100
74
-
75
76
def __repr__ (self ):
76
77
# It might be nice to rewrite this function to account for the
77
78
# shape of the array...
79
+ import numpy as np
80
+
78
81
return "approx({0!r})" .format (list (
79
- self ._approx_scalar (x ) for x in self .expected ))
82
+ self ._approx_scalar (x ) for x in np . asarray ( self .expected ) ))
80
83
81
84
if sys .version_info [0 ] == 2 :
82
85
__cmp__ = _cmp_raises_type_error
83
86
84
87
def __eq__ (self , actual ):
85
88
import numpy as np
86
89
87
- try :
88
- actual = np .asarray (actual )
89
- except : # noqa
90
- raise TypeError ("cannot compare '{0}' to numpy.ndarray" .format (actual ))
90
+ # self.expected is supposed to always be an array here
91
91
92
- if actual .shape != self .expected .shape :
92
+ if not np .isscalar (actual ):
93
+ try :
94
+ actual = np .asarray (actual )
95
+ except : # noqa
96
+ raise TypeError ("cannot compare '{0}' to numpy.ndarray" .format (actual ))
97
+
98
+ if not np .isscalar (actual ) and actual .shape != self .expected .shape :
93
99
return False
94
100
95
101
return ApproxBase .__eq__ (self , actual )
96
102
97
103
def _yield_comparisons (self , actual ):
98
104
import numpy as np
99
105
100
- # We can be sure that `actual` is a numpy array, because it's
101
- # casted in `__eq__` before being passed to `ApproxBase.__eq__`,
102
- # which is the only method that calls this one.
103
- for i in np .ndindex (self .expected .shape ):
104
- yield actual [i ], self .expected [i ]
106
+ # `actual` can either be a numpy array or a scalar, it is treated in
107
+ # `__eq__` before being passed to `ApproxBase.__eq__`, which is the
108
+ # only method that calls this one.
109
+
110
+ if np .isscalar (actual ):
111
+ for i in np .ndindex (self .expected .shape ):
112
+ yield actual , np .asscalar (self .expected [i ])
113
+ else :
114
+ for i in np .ndindex (self .expected .shape ):
115
+ yield np .asscalar (actual [i ]), np .asscalar (self .expected [i ])
105
116
106
117
107
118
class ApproxMapping (ApproxBase ):
@@ -131,9 +142,6 @@ class ApproxSequence(ApproxBase):
131
142
Perform approximate comparisons for sequences of numbers.
132
143
"""
133
144
134
- # Tell numpy to use our `__eq__` operator instead of its.
135
- __array_priority__ = 100
136
-
137
145
def __repr__ (self ):
138
146
seq_type = type (self .expected )
139
147
if seq_type not in (tuple , list , set ):
@@ -189,6 +197,8 @@ def __eq__(self, actual):
189
197
Return true if the given value is equal to the expected value within
190
198
the pre-specified tolerance.
191
199
"""
200
+ if _is_numpy_array (actual ):
201
+ return ApproxNumpy (actual , self .abs , self .rel , self .nan_ok ) == self .expected
192
202
193
203
# Short-circuit exact equality.
194
204
if actual == self .expected :
@@ -308,12 +318,18 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
308
318
>>> {'a': 0.1 + 0.2, 'b': 0.2 + 0.4} == approx({'a': 0.3, 'b': 0.6})
309
319
True
310
320
311
- And ``numpy`` arrays::
321
+ ``numpy`` arrays::
312
322
313
323
>>> import numpy as np # doctest: +SKIP
314
324
>>> np.array([0.1, 0.2]) + np.array([0.2, 0.4]) == approx(np.array([0.3, 0.6])) # doctest: +SKIP
315
325
True
316
326
327
+ And for a ``numpy`` array against a scalar::
328
+
329
+ >>> import numpy as np # doctest: +SKIP
330
+ >>> np.array([0.1, 0.2]) + np.array([0.2, 0.1]) == approx(0.3) # doctest: +SKIP
331
+ True
332
+
317
333
By default, ``approx`` considers numbers within a relative tolerance of
318
334
``1e-6`` (i.e. one part in a million) of its expected value to be equal.
319
335
This treatment would lead to surprising results if the expected value was
0 commit comments