
Commit e573610

sryap authored and facebook-github-bot committed

Add docstrings for sparse ops (2) (pytorch#3185)

Summary:
X-link: facebookresearch/FBGEMM#280

Add a docstring for
- torch.ops.fbgemm.segment_sum_csr
- torch.ops.fbgemm.keyed_jagged_index_select_dim1

Pull Request resolved: pytorch#3185
Reviewed By: shintaro-iwasaki
Differential Revision: D63520029
Pulled By: sryap
fbshipit-source-id: 737cd62de83d5c31992aba8898231803545c393a
1 parent 8f6d96d commit e573610

File tree

2 files changed: +119 −0 lines changed


fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst (+4)

@@ -12,3 +12,7 @@ Sparse Operators
 .. autofunction:: torch.ops.fbgemm.asynchronous_complete_cumsum
 
 .. autofunction:: torch.ops.fbgemm.offsets_range
+
+.. autofunction:: torch.ops.fbgemm.segment_sum_csr
+
+.. autofunction:: torch.ops.fbgemm.keyed_jagged_index_select_dim1

fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py (+115)

@@ -204,3 +204,118 @@
     4, 5, 6], device='cuda:0')
 """,
 )
+
+add_docs(
+    torch.ops.fbgemm.segment_sum_csr,
+    """
+segment_sum_csr(batch_size, csr_seg, values) -> Tensor
+
+Sum values within each segment of the given CSR data, where each row has the
+same number of non-zero elements.
+
+Args:
+    batch_size (int): The row stride (number of non-zero elements in each row)
+
+    csr_seg (Tensor): The complete cumulative sum of segment lengths. A
+        segment length is the number of rows within each segment. The shape
+        of the `csr_seg` tensor is `num_segments + 1`, where `num_segments`
+        is the number of segments.
+
+    values (Tensor): The values tensor to be segment summed. The number of
+        elements in the tensor must be a multiple of `batch_size`
+
+Returns:
+    A tensor containing the segment sum results. Its shape is the number of
+    segments.
+
+**Example:**
+
+    >>> batch_size = 2
+    >>> # Randomize inputs
+    >>> lengths = torch.tensor([3, 4, 1], dtype=torch.int, device="cuda")
+    >>> offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
+    >>> print(offsets)
+    tensor([0, 3, 7, 8], device='cuda:0', dtype=torch.int32)
+    >>> values = torch.randn(lengths.sum().item() * batch_size, dtype=torch.float32, device="cuda")
+    >>> print(values)
+    tensor([-2.8642e-01,  1.6451e+00,  1.1322e-01,  1.7335e+00, -8.4700e-02,
+            -1.2756e+00,  1.1206e+00,  9.6385e-01,  6.2122e-02,  1.3104e-03,
+             2.2667e-01,  2.3113e+00, -1.1948e+00, -1.5463e-01, -1.0031e+00,
+            -3.5531e-01], device='cuda:0')
+    >>> # Invoke
+    >>> torch.ops.fbgemm.segment_sum_csr(batch_size, offsets, values)
+    tensor([ 1.8451,  3.3365, -1.3584], device='cuda:0')
+""",
+)
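As a sanity check on the semantics documented above (not part of this commit), the operator can be mimicked in plain PyTorch on CPU: segment `i` covers rows `[csr_seg[i], csr_seg[i+1])`, and each row contributes `batch_size` consecutive elements of `values`. The helper name `segment_sum_csr_ref` is hypothetical.

```python
import torch

def segment_sum_csr_ref(batch_size, csr_seg, values):
    # Hypothetical reference: segment i spans rows [csr_seg[i], csr_seg[i+1]),
    # and each row holds `batch_size` consecutive elements of `values`.
    out = []
    for i in range(csr_seg.numel() - 1):
        start = int(csr_seg[i]) * batch_size
        end = int(csr_seg[i + 1]) * batch_size
        out.append(values[start:end].sum())
    return torch.stack(out)

# Same offsets as the docstring example, but with arange values for determinism
csr_seg = torch.tensor([0, 3, 7, 8])
values = torch.arange(16, dtype=torch.float32)
print(segment_sum_csr_ref(2, csr_seg, values))  # tensor([15., 76., 29.])
```

Segment 0 sums `values[0:6]`, segment 1 sums `values[6:14]`, and segment 2 sums `values[14:16]`, matching the shape contract in the docstring.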
+
+add_docs(
+    torch.ops.fbgemm.keyed_jagged_index_select_dim1,
+    """
+keyed_jagged_index_select_dim1(values, lengths, offsets, indices, batch_size, weights=None, selected_lengths_sum=None) -> List[Tensor]
+
+Perform an index-select operation on the batch dimension (dim 1) of the given
+keyed jagged tensor (KJT) input. The same samples in the batch of every key
+will be selected. Note that each KJT has 3 dimensions: (`num_keys`,
+`batch_size`, jagged dim), where `num_keys` is the number of keys and
+`batch_size` is the batch size. This operator is similar to a permute
+operator.
+
+Args:
+    values (Tensor): The KJT values tensor, which contains the concatenated
+        data of every key
+
+    lengths (Tensor): The KJT lengths tensor, which contains the jagged
+        shapes of every key (dim 0) and sample (dim 1). Shape is
+        `num_keys * batch_size`
+
+    offsets (Tensor): The KJT offsets tensor, which is the complete
+        cumulative sum of `lengths`. Shape is `num_keys * batch_size + 1`
+
+    indices (Tensor): The indices to select, i.e., the samples in the batch
+        to select. The values of `indices` must be >= 0 and < `batch_size`
+
+    batch_size (int): The batch size (dim 1 of the KJT)
+
+    weights (Optional[Tensor] = None): An optional float tensor that will be
+        selected the same way as `values`. Thus, it must have the same shape
+        as `values`
+
+    selected_lengths_sum (Optional[int] = None): An optional value that
+        represents the total number of elements in the index-select output.
+        If not provided, the operator will compute this value, which may
+        cause a device-host synchronization (if using a GPU). Thus, it is
+        recommended to supply this value to avoid the synchronization.
+
+Returns:
+    The index-selected KJT (as a list of values, lengths, and weights if
+    `weights` is not None)
+
+**Example:**
+
+    >>> num_keys = 2
+    >>> batch_size = 4
+    >>> output_size = 3
+    >>> # Randomize inputs
+    >>> lengths = torch.randint(low=0, high=10, size=(batch_size * num_keys,), dtype=torch.int64, device="cuda")
+    >>> print(lengths)
+    tensor([8, 5, 1, 4, 2, 7, 5, 9], device='cuda:0')
+    >>> offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
+    >>> print(offsets)
+    tensor([ 0,  8, 13, 14, 18, 20, 27, 32, 41], device='cuda:0')
+    >>> indices = torch.randint(low=0, high=batch_size, size=(output_size,), dtype=torch.int64, device="cuda")
+    >>> print(indices)
+    tensor([3, 3, 1], device='cuda:0')
+    >>> # Use torch.arange instead of torch.randn to simplify the example
+    >>> values = torch.arange(lengths.sum().item(), dtype=torch.float32, device="cuda")
+    >>> print(values)
+    tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
+            14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27.,
+            28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40.],
+           device='cuda:0')
+    >>> # Invoke. Output = (output, lengths)
+    >>> torch.ops.fbgemm.keyed_jagged_index_select_dim1(values, lengths, offsets, indices, batch_size)
+    [tensor([14., 15., 16., 17., 14., 15., 16., 17.,  8.,  9., 10., 11., 12., 32.,
+            33., 34., 35., 36., 37., 38., 39., 40., 32., 33., 34., 35., 36., 37.,
+            38., 39., 40., 20., 21., 22., 23., 24., 25., 26.], device='cuda:0'),
+     tensor([4, 4, 5, 9, 9, 7], device='cuda:0')]
+""",
+)
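The selection rule above (the same batch samples taken from every key's sub-batch, key-major) can be cross-checked with a short pure-PyTorch sketch on CPU. This is not part of the commit, and `keyed_jagged_index_select_dim1_ref` is a hypothetical name.

```python
import torch

def keyed_jagged_index_select_dim1_ref(values, lengths, offsets, indices, batch_size):
    # Hypothetical reference: for each key, gather the jagged rows of the
    # selected batch samples in the order given by `indices`.
    num_keys = lengths.numel() // batch_size
    out_values, out_lengths = [], []
    for k in range(num_keys):
        for idx in indices.tolist():
            row = k * batch_size + idx
            out_lengths.append(int(lengths[row]))
            out_values.append(values[int(offsets[row]) : int(offsets[row + 1])])
    return [torch.cat(out_values), torch.tensor(out_lengths)]

# Inputs matching the docstring example (deterministic, CPU)
lengths = torch.tensor([8, 5, 1, 4, 2, 7, 5, 9])
offsets = torch.cat([torch.zeros(1, dtype=torch.int64), lengths.cumsum(0)])
values = torch.arange(lengths.sum().item(), dtype=torch.float32)
indices = torch.tensor([3, 3, 1])
out = keyed_jagged_index_select_dim1_ref(values, lengths, offsets, indices, batch_size=4)
print(out[1])  # tensor([4, 4, 5, 9, 9, 7])
```

For key 0 the sketch picks rows 3, 3, and 1 (values `14..17`, `14..17`, `8..12`); for key 1 it picks rows 7, 7, and 5, reproducing the output lengths shown in the docstring example.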
