Skip to content

Commit d4d378e

Browse files
committed
Better solution for sandialabs#380 and also sandialabs#368. The function isn't changed but the instructions are.
1 parent e60a2a7 commit d4d378e

File tree

1 file changed

+21
-19
lines changed

1 file changed

+21
-19
lines changed

pyttb/tensor.py

+21-19
Original file line numberDiff line numberDiff line change
@@ -217,15 +217,17 @@ def from_function(
217217
function_handle: Callable[[Tuple[int, ...]], np.ndarray],
218218
shape: Shape,
219219
) -> tensor:
220-
"""Construct a :class:`pyttb.tensor` with data generated by function.
220+
"""Construct :class:`pyttb.tensor` with data generated by given function.
221221
222222
Parameters
223223
----------
224224
function_handle:
225-
A function that can accept an integer length and
226-
return a :class:`numpy.ndarray` vector of that length.
225+
A function that takes a tuple of integers and returns a
226+
:class:`numpy.ndarray`. The array should be in Fortran order to avoid
227+
warnings of data being copied. The data will be reshaped to the shape,
228+
so returning a vector of length equal to the product of the shape is fine.
227229
shape:
228-
Shape of the resulting tensor.
230+
Shape of the resulting tensor; e.g., a tuple of integers.
229231
230232
Returns
231233
-------
@@ -236,46 +238,46 @@ def from_function(
236238
Create a :class:`pyttb.tensor` with entries drawn from a normal distribution
237239
using :func:`numpy.random.randn`::
238240
239-
>>> np.random.seed(0)
240-
>>> T = ttb.tensor.from_function(np.random.randn, (4, 3, 2))
241+
>>> randn = lambda s : np.random.randn(np.prod(s))
242+
>>> np.random.seed(0) # reproducibility
243+
>>> T = ttb.tensor.from_function(randn, (4, 3, 2))
241244
>>> print(T)
242245
tensor of shape (4, 3, 2) with order F
243246
data[:, :, 0] =
244247
[[ 1.76405235 1.86755799 -0.10321885]
245-
[ 0.40015721 -0.97727788 0.4105985 ]
246-
[ 0.97873798 0.95008842 0.14404357]
247-
[ 2.2408932 -0.15135721 1.45427351]]
248+
[ 0.40015721 -0.97727788 0.4105985 ]
249+
[ 0.97873798 0.95008842 0.14404357]
250+
[ 2.2408932 -0.15135721 1.45427351]]
248251
data[:, :, 1] =
249252
[[ 0.76103773 1.49407907 -2.55298982]
250-
[ 0.12167502 -0.20515826 0.6536186 ]
251-
[ 0.44386323 0.3130677 0.8644362 ]
252-
[ 0.33367433 -0.85409574 -0.74216502]]
253+
[ 0.12167502 -0.20515826 0.6536186 ]
254+
[ 0.44386323 0.3130677 0.8644362 ]
255+
[ 0.33367433 -0.85409574 -0.74216502]]
253256
254257
Create a :class:`pyttb.tensor` with all entries equal to 1
255258
using :func:`numpy.ones`::
256259
257-
>>> T = ttb.tensor.from_function(np.ones, (2, 3, 4))
260+
>>> T = ttb.tensor.from_function(lambda s: np.ones(s,order='F'), (2, 3, 4))
258261
>>> print(T)
259262
tensor of shape (2, 3, 4) with order F
260263
data[:, :, 0] =
261264
[[1. 1. 1.]
262-
[1. 1. 1.]]
265+
[1. 1. 1.]]
263266
data[:, :, 1] =
264267
[[1. 1. 1.]
265-
[1. 1. 1.]]
268+
[1. 1. 1.]]
266269
data[:, :, 2] =
267270
[[1. 1. 1.]
268-
[1. 1. 1.]]
271+
[1. 1. 1.]]
269272
data[:, :, 3] =
270273
[[1. 1. 1.]
271-
[1. 1. 1.]]
274+
[1. 1. 1.]]
272275
"""
273276
# Check size
274277
shape = parse_shape(shape)
275278

276279
# Generate data
277-
totalsize = prod(shape)
278-
data = function_handle(totalsize)
280+
data = function_handle(shape)
279281

280282
# Create the tensor
281283
return cls(data, shape, copy=False)

0 commit comments

Comments
 (0)