|
204 | 204 | 4, 5, 6], device='cuda:0')
|
205 | 205 | """,
|
206 | 206 | )
|

# Attach user-facing documentation to the segment_sum_csr custom op.
# The docstring follows the module's convention: signature line, summary,
# Google-style Args/Returns sections, then a doctest-style example.
add_docs(
    torch.ops.fbgemm.segment_sum_csr,
    """
segment_sum_csr(batch_size, csr_seg, values) -> Tensor

Sum values within each segment on the given CSR data where each row has the
same number of non-zero elements.

Args:
    batch_size (int): The row stride (number of non-zero elements in each row)

    csr_seg (Tensor): The complete cumulative sum of segment lengths. A segment
        length is the number of rows within each segment. The shape of the
        `csr_seg` tensor is `num_segments + 1` where `num_segments` is the
        number of segments.

    values (Tensor): The values tensor to be segment summed. The number of
        elements in the tensor must be a multiple of `batch_size`

Returns:
    A tensor containing the segment sum results. Shape is the number of
    segments.

**Example:**

    >>> batch_size = 2
    >>> # Randomize inputs
    >>> lengths = torch.tensor([3, 4, 1], dtype=torch.int, device="cuda")
    >>> offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
    >>> print(offsets)
    tensor([0, 3, 7, 8], device='cuda:0', dtype=torch.int32)
    >>> values = torch.randn(lengths.sum().item() * batch_size, dtype=torch.float32, device="cuda")
    >>> print(values)
    tensor([-2.8642e-01,  1.6451e+00,  1.1322e-01,  1.7335e+00, -8.4700e-02,
            -1.2756e+00,  1.1206e+00,  9.6385e-01,  6.2122e-02,  1.3104e-03,
             2.2667e-01,  2.3113e+00, -1.1948e+00, -1.5463e-01, -1.0031e+00,
            -3.5531e-01], device='cuda:0')
    >>> # Invoke
    >>> torch.ops.fbgemm.segment_sum_csr(batch_size, offsets, values)
    tensor([ 1.8451,  3.3365, -1.3584], device='cuda:0')
    """,
)

# Attach user-facing documentation to the keyed_jagged_index_select_dim1
# custom op, in the same signature/Args/Returns/Example format used by the
# other add_docs entries in this module.
add_docs(
    torch.ops.fbgemm.keyed_jagged_index_select_dim1,
    """
keyed_jagged_index_select_dim1(values, lengths, offsets, indices, batch_size, weights=None, selected_lengths_sum=None) -> List[Tensor]

Perform an index select operation on the batch dimension (dim 1) of the given
keyed jagged tensor (KJT) input. The same samples in the batch of every key
will be selected. Note that each KJT has 3 dimensions: (`num_keys`, `batch_size`,
jagged dim), where `num_keys` is the number of keys, and `batch_size` is the
batch size. This operator is similar to a permute operator.

Args:
    values (Tensor): The KJT values tensor which contains concatenated data of
        every key

    lengths (Tensor): The KJT lengths tensor which contains the jagged shapes
        of every key (dim 0) and sample (dim 1). Shape is `num_keys *
        batch_size`

    offsets (Tensor): The KJT offsets tensor which is the complete cumulative
        sum of `lengths`. Shape is `num_keys * batch_size + 1`

    indices (Tensor): The indices to select, i.e., samples in the batch to
        select. The values of `indices` must be >= 0 and < `batch_size`

    batch_size (int): The batch size (dim 1 of KJT)

    weights (Optional[Tensor] = None): An optional float tensor which will be
        selected the same way as `values`. Thus, it must have the same shape as
        `values`

    selected_lengths_sum (Optional[int] = None): An optional value that
        represents the total number of elements in the index select data
        (output shape). If not provided, the operator will compute this data
        which may cause a device-host synchronization (if using GPU). Thus, it
        is recommended to supply this value to avoid such synchronization.

Returns:
    The index-selected KJT tensor (as a list of values, lengths, and weights if
    `weights` is not None)

**Example:**

    >>> num_keys = 2
    >>> batch_size = 4
    >>> output_size = 3
    >>> # Randomize inputs
    >>> lengths = torch.randint(low=0, high=10, size=(batch_size * num_keys,), dtype=torch.int64, device="cuda")
    >>> print(lengths)
    tensor([8, 5, 1, 4, 2, 7, 5, 9], device='cuda:0')
    >>> offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
    >>> print(offsets)
    tensor([ 0,  8, 13, 14, 18, 20, 27, 32, 41], device='cuda:0')
    >>> indices = torch.randint(low=0, high=batch_size, size=(output_size,), dtype=torch.int64, device="cuda")
    >>> print(indices)
    tensor([3, 3, 1], device='cuda:0')
    >>> # Use torch.arange instead of torch.randn to simplify the example
    >>> values = torch.arange(lengths.sum().item(), dtype=torch.float32, device="cuda")
    >>> print(values)
    tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
            14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27.,
            28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40.],
           device='cuda:0')
    >>> # Invoke. Output = (output, lengths)
    >>> torch.ops.fbgemm.keyed_jagged_index_select_dim1(values, lengths, offsets, indices, batch_size)
    [tensor([14., 15., 16., 17., 14., 15., 16., 17.,  8.,  9., 10., 11., 12., 32.,
            33., 34., 35., 36., 37., 38., 39., 40., 32., 33., 34., 35., 36., 37.,
            38., 39., 40., 20., 21., 22., 23., 24., 25., 26.], device='cuda:0'),
    tensor([4, 4, 5, 9, 9, 7], device='cuda:0')]
    """,
)
0 commit comments