Skip to content

Commit

Permalink
[TOSA] Update tosa.slice's start and size to !tosa.shape type
Browse files Browse the repository at this point in the history
* In TOSA 1.0, tosa.slice's `start` and `size` are !tosa.shape types.
  Update tosa.slice in Torch to TOSA in alignment with that.
* Update LIT tests.

Signed-off-by: Justin Ngo <[email protected]>
Change-Id: Icf878ea4dc43ec1af3bd498b5ae96f514fe0f04a
  • Loading branch information
justin-ngo-arm committed Feb 10, 2025
1 parent af2aff6 commit eb349e5
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 72 deletions.
32 changes: 16 additions & 16 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4010,8 +4010,8 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(

rewriter.replaceOpWithNewOp<tosa::SliceOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
rewriter.getDenseI64ArrayAttr(startSlice),
rewriter.getDenseI64ArrayAttr(sizeSlice));
tosa::getTosaConstShape(rewriter, op->getLoc(), startSlice),
tosa::getTosaConstShape(rewriter, op->getLoc(), sizeSlice));

return success();
}
Expand Down Expand Up @@ -7143,8 +7143,8 @@ LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(
startSlice[targetDim1] = std::abs(offset);
diagonalTensor = rewriter.create<tosa::SliceOp>(
op->getLoc(), transposedInputType, diagonalTensor,
rewriter.getDenseI64ArrayAttr(startSlice),
rewriter.getDenseI64ArrayAttr(sizeSlice));
tosa::getTosaConstShape(rewriter, op->getLoc(), startSlice),
tosa::getTosaConstShape(rewriter, op->getLoc(), sizeSlice));
}

// Apply Reduce Sum to get the result
Expand Down Expand Up @@ -7669,8 +7669,8 @@ Value reflectionPadAlongAxis(Value input, ArrayRef<int64_t> unpaddedShape,
auto leftPadType = RankedTensorType::get(leftPadShape, inputElemTy);

auto leftPadSlice = rewriter.create<tosa::SliceOp>(
loc, leftPadType, input, rewriter.getDenseI64ArrayAttr(leftStartSlice),
rewriter.getDenseI64ArrayAttr(leftSizeSlice));
loc, leftPadType, input, tosa::getTosaConstShape(rewriter, loc, leftStartSlice),
tosa::getTosaConstShape(rewriter, loc, leftSizeSlice));

auto leftPad = rewriter.create<tosa::ReverseOp>(
loc, leftPadType, leftPadSlice.getResult(), static_cast<int32_t>(axis));
Expand Down Expand Up @@ -7702,8 +7702,8 @@ Value reflectionPadAlongAxis(Value input, ArrayRef<int64_t> unpaddedShape,

auto rightPadSlice = rewriter.create<tosa::SliceOp>(
loc, rightPadType, input,
rewriter.getDenseI64ArrayAttr(rightStartSlice),
rewriter.getDenseI64ArrayAttr(rightSizeSlice));
tosa::getTosaConstShape(rewriter, loc, rightStartSlice),
tosa::getTosaConstShape(rewriter, loc, rightSizeSlice));

auto rightPad = rewriter.create<tosa::ReverseOp>(
loc, rightPadType, rightPadSlice.getResult(),
Expand Down Expand Up @@ -7949,8 +7949,8 @@ LogicalResult ConvertAtenOp<AtenReplicationPad2dOp>::matchAndRewrite(

auto leftPadSlice = rewriter.create<tosa::SliceOp>(
op->getLoc(), leftPadSliceType, self,
rewriter.getDenseI64ArrayAttr(leftStartSlice),
rewriter.getDenseI64ArrayAttr(leftSizeSlice));
tosa::getTosaConstShape(rewriter, op->getLoc(), leftStartSlice),
tosa::getTosaConstShape(rewriter, op->getLoc(), leftSizeSlice));

for (int64_t i = 0; i < paddingLeft; i++)
sideTensors.push_back(leftPadSlice.getResult());
Expand All @@ -7974,8 +7974,8 @@ LogicalResult ConvertAtenOp<AtenReplicationPad2dOp>::matchAndRewrite(

auto rightPadSlice = rewriter.create<tosa::SliceOp>(
op->getLoc(), rightPadSliceType, self,
rewriter.getDenseI64ArrayAttr(rightStartSlice),
rewriter.getDenseI64ArrayAttr(rightSizeSlice));
tosa::getTosaConstShape(rewriter, op->getLoc(), rightStartSlice),
tosa::getTosaConstShape(rewriter, op->getLoc(), rightSizeSlice));

for (int64_t i = 0; i < paddingRight; i++)
sideTensors.push_back(rightPadSlice.getResult());
Expand Down Expand Up @@ -8009,8 +8009,8 @@ LogicalResult ConvertAtenOp<AtenReplicationPad2dOp>::matchAndRewrite(

auto topPadSlice = rewriter.create<tosa::SliceOp>(
op->getLoc(), topPadSliceType, selfSidePadded,
rewriter.getDenseI64ArrayAttr(topStartSlice),
rewriter.getDenseI64ArrayAttr(topSizeSlice));
tosa::getTosaConstShape(rewriter, op->getLoc(), topStartSlice),
tosa::getTosaConstShape(rewriter, op->getLoc(), topSizeSlice));

for (int64_t i = 0; i < paddingTop; i++)
resultTensors.push_back(topPadSlice.getResult());
Expand All @@ -8037,8 +8037,8 @@ LogicalResult ConvertAtenOp<AtenReplicationPad2dOp>::matchAndRewrite(

auto bottomPadSlice = rewriter.create<tosa::SliceOp>(
op->getLoc(), bottomPadSliceType, selfSidePadded,
rewriter.getDenseI64ArrayAttr(bottomStartSlice),
rewriter.getDenseI64ArrayAttr(bottomSizeSlice));
tosa::getTosaConstShape(rewriter, op->getLoc(), bottomStartSlice),
tosa::getTosaConstShape(rewriter, op->getLoc(), bottomSizeSlice));

for (int64_t i = 0; i < paddingBottom; i++)
resultTensors.push_back(bottomPadSlice.getResult());
Expand Down
Loading

0 comments on commit eb349e5

Please sign in to comment.