Skip to content

Commit 0db643a

Browse files
authored
Update pt2_autograd_utils.cpp
1 parent 54e83db commit 0db643a

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

fbgemm_gpu/codegen/training/pt2/pt2_autograd_utils.cpp

+18
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,22 @@ Tensor reshape_vbe_output(
5959
}
6060
return grad_output_;
6161
}
62+
Tensor reshape_offsets(
63+
const Tensor& offsets,
64+
const Tensor& B_offsets,
65+
const c10::SymInt max_B,
66+
const int32_t T) {
67+
auto offsets_ = at::empty({T * max_B + 1}, offsets.options());
68+
for (int32_t t = 0; t < T; t++) {
69+
const auto b_begin = B_offsets[t];
70+
const auto b_end = B_offsets[t+1];
71+
const auto values = offsets.slice(0, begin, end);
72+
offsets_.index_put_({t * max_B, t * max_B + end - begin}, values);
73+
offsets_[t * max_B + end - begin : (t + 1) * max_B] = offsets_[end];
74+
}
75+
offsets_[offsets.numel()-1] = offsets[offsets.numel()-1];
76+
return offsets_;
77+
}
78+
79+
6280
} // namespace fbgemm_gpu

0 commit comments

Comments
 (0)