From eb349e5bd379bb0940c020c3c938febff55006a4 Mon Sep 17 00:00:00 2001 From: Justin Ngo Date: Fri, 31 Jan 2025 20:54:32 +0000 Subject: [PATCH] [TOSA] Update tosa.slice's start and size to !tosa.shape type * In TOSA 1.0, tosa.slice's `start` and `size` are !tosa.shape types. Update tosa.slice in Torch to TOSA in alignment with that. * Update LIT tests. Signed-off-by: Justin Ngo Change-Id: Icf878ea4dc43ec1af3bd498b5ae96f514fe0f04a --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 32 ++--- test/Conversion/TorchToTosa/basic.mlir | 149 +++++++++++++-------- 2 files changed, 109 insertions(+), 72 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 0863ea49e64f..ae493470f439 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -4010,8 +4010,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), - rewriter.getDenseI64ArrayAttr(startSlice), - rewriter.getDenseI64ArrayAttr(sizeSlice)); + tosa::getTosaConstShape(rewriter, op->getLoc(), startSlice), + tosa::getTosaConstShape(rewriter, op->getLoc(), sizeSlice)); return success(); } @@ -7143,8 +7143,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( startSlice[targetDim1] = std::abs(offset); diagonalTensor = rewriter.create( op->getLoc(), transposedInputType, diagonalTensor, - rewriter.getDenseI64ArrayAttr(startSlice), - rewriter.getDenseI64ArrayAttr(sizeSlice)); + tosa::getTosaConstShape(rewriter, op->getLoc(), startSlice), + tosa::getTosaConstShape(rewriter, op->getLoc(), sizeSlice)); } // Apply Reduce Sum to get the result @@ -7669,8 +7669,8 @@ Value reflectionPadAlongAxis(Value input, ArrayRef unpaddedShape, auto leftPadType = RankedTensorType::get(leftPadShape, inputElemTy); auto leftPadSlice = rewriter.create( - loc, leftPadType, input, rewriter.getDenseI64ArrayAttr(leftStartSlice), - rewriter.getDenseI64ArrayAttr(leftSizeSlice)); + loc, leftPadType, input, tosa::getTosaConstShape(rewriter, loc, leftStartSlice), + tosa::getTosaConstShape(rewriter, loc, leftSizeSlice)); auto leftPad = rewriter.create( loc, leftPadType, leftPadSlice.getResult(), static_cast(axis)); @@ -7702,8 +7702,8 @@ Value reflectionPadAlongAxis(Value input, ArrayRef unpaddedShape, auto rightPadSlice = rewriter.create( loc, rightPadType, input, - rewriter.getDenseI64ArrayAttr(rightStartSlice), - rewriter.getDenseI64ArrayAttr(rightSizeSlice)); + tosa::getTosaConstShape(rewriter, loc, rightStartSlice), + tosa::getTosaConstShape(rewriter, loc, rightSizeSlice)); auto rightPad = rewriter.create( loc, rightPadType, rightPadSlice.getResult(), @@ -7949,8 +7949,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto leftPadSlice = rewriter.create( op->getLoc(), leftPadSliceType, self, - rewriter.getDenseI64ArrayAttr(leftStartSlice), - rewriter.getDenseI64ArrayAttr(leftSizeSlice)); + tosa::getTosaConstShape(rewriter, op->getLoc(), leftStartSlice), + tosa::getTosaConstShape(rewriter, op->getLoc(), leftSizeSlice)); for (int64_t i = 0; i < paddingLeft; i++) sideTensors.push_back(leftPadSlice.getResult()); @@ -7974,8 +7974,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto rightPadSlice = rewriter.create( op->getLoc(), rightPadSliceType, self, - rewriter.getDenseI64ArrayAttr(rightStartSlice), - rewriter.getDenseI64ArrayAttr(rightSizeSlice)); + tosa::getTosaConstShape(rewriter, op->getLoc(), rightStartSlice), + tosa::getTosaConstShape(rewriter, op->getLoc(), rightSizeSlice)); for (int64_t i = 0; i < paddingRight; i++) sideTensors.push_back(rightPadSlice.getResult()); @@ -8009,8 +8009,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto topPadSlice = rewriter.create( op->getLoc(), topPadSliceType, selfSidePadded, - rewriter.getDenseI64ArrayAttr(topStartSlice), - rewriter.getDenseI64ArrayAttr(topSizeSlice)); + tosa::getTosaConstShape(rewriter, op->getLoc(), topStartSlice), + tosa::getTosaConstShape(rewriter, op->getLoc(), topSizeSlice)); for (int64_t i = 0; i < paddingTop; i++) resultTensors.push_back(topPadSlice.getResult()); @@ -8037,8 +8037,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto bottomPadSlice = rewriter.create( op->getLoc(), bottomPadSliceType, selfSidePadded, - rewriter.getDenseI64ArrayAttr(bottomStartSlice), - rewriter.getDenseI64ArrayAttr(bottomSizeSlice)); + tosa::getTosaConstShape(rewriter, op->getLoc(), bottomStartSlice), + tosa::getTosaConstShape(rewriter, op->getLoc(), bottomSizeSlice)); for (int64_t i = 0; i < paddingBottom; i++) resultTensors.push_back(bottomPadSlice.getResult()); diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index cdf6c79eb4f5..267b05b93aa5 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -1193,15 +1193,17 @@ func.func @torch.aten.Scalar$basic(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> // ----- // CHECK-LABEL: func.func @torch.aten.slice.negative_start( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,16,256],f32> { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,16,256],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 0 // CHECK: %[[VAL_3:.*]] = torch.constant.int 1 // CHECK: %[[VAL_4:.*]] = torch.constant.int 100 // CHECK: %[[VAL_5:.*]] = torch.constant.int -16 -// CHECK: %[[VAL_4:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<4x65x256xf32>) -> tensor<4x16x256xf32> -// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x16x256xf32> -> !torch.vtensor<[4,16,256],f32> -// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,16,256],f32> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[0, 49, 0]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_7:.*]] = tosa.const_shape {value = dense<[4, 16, 256]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_8:.*]] = tosa.slice %[[VAL_1]], %[[VAL_6]], %[[VAL_7]] : (tensor<4x65x256xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x16x256xf32> +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<4x16x256xf32> -> !torch.vtensor<[4,16,256],f32> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[4,16,256],f32> // CHECK: } func.func @torch.aten.slice.negative_start(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,16,256],f32> { %int0 = torch.constant.int 0 @@ -2001,11 +2003,13 @@ func.func @torch.aten.bitwise_right_shift.Tensor$basic(%arg0: !torch.vtensor<[?, // CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0, 0, 0], [0, 0, 0], [1, 0, 0], [0, 1, 0]]]]> : tensor<1x1x4x3xi32>}> : () -> tensor<1x1x4x3xi32> // CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> // CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_6]], %[[VAL_7]], %[[VAL_8]] : (tensor<5x6x4x3xi32>, tensor<1x1x4x3xi32>, tensor<1xi8>) -> tensor<5x6x4x3xi32> -// CHECK: %[[VAL_10:.*]] = tosa.slice %[[VAL_9]] {size = array, start = array} : (tensor<5x6x4x3xi32>) -> tensor<5x6x2x3xi32> -// CHECK: %[[VAL_11:.*]] = tosa.reduce_sum %[[VAL_10]] {axis = 3 : i32} : (tensor<5x6x2x3xi32>) -> tensor<5x6x2x1xi32> -// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array} : (tensor<5x6x2x1xi32>) -> tensor<5x6x2xi32> -// CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor<5x6x2xi32> -> !torch.vtensor<[5,6,2],si32> -// CHECK: return %[[VAL_13]] : !torch.vtensor<[5,6,2],si32> +// CHECK: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[0, 0, 2, 0]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_11:.*]] = tosa.const_shape {value = dense<[5, 6, 2, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_12:.*]] = tosa.slice %[[VAL_9]], %[[VAL_10]], %[[VAL_11]] : (tensor<5x6x4x3xi32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<5x6x2x3xi32> +// CHECK: %[[VAL_13:.*]] = tosa.reduce_sum %[[VAL_12]] {axis = 3 : i32} : (tensor<5x6x2x3xi32>) -> tensor<5x6x2x1xi32> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_13]] {new_shape = array} : (tensor<5x6x2x1xi32>) -> tensor<5x6x2xi32> +// CHECK: %[[VAL_15:.*]] = torch_c.from_builtin_tensor %[[VAL_14]] : tensor<5x6x2xi32> -> !torch.vtensor<[5,6,2],si32> +// CHECK: return %[[VAL_15]] : !torch.vtensor<[5,6,2],si32> // CHECK: } func.func @torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) -> !torch.vtensor<[5,6,2], si32> { %dim1 = torch.constant.int 1 @@ -2604,13 +2608,17 @@ func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value(%arg0: !tor // CHECK: %[[VAL_2:.*]] = torch.constant.int 3 // CHECK: %[[VAL_3:.*]] = torch.constant.int 1 // CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_5:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x2x4xf32>) -> tensor<1x2x3xf32> -// CHECK: %[[VAL_6:.*]] = tosa.reverse %[[VAL_5]] {axis = 2 : i32} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> -// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x2x4xf32>) -> tensor<1x2x1xf32> -// CHECK: %[[VAL_8:.*]] = tosa.reverse %[[VAL_7]] {axis = 2 : i32} : (tensor<1x2x1xf32>) -> tensor<1x2x1xf32> -// CHECK: %[[VAL_9:.*]] = tosa.concat %[[VAL_6]], %[[VAL_1]], %[[VAL_8]] {axis = 2 : i32} : (tensor<1x2x3xf32>, tensor<1x2x4xf32>, tensor<1x2x1xf32>) -> tensor<1x2x8xf32> -// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<1x2x8xf32> -> !torch.vtensor<[1,2,8],f32> -// CHECK: return %[[VAL_10]] : !torch.vtensor<[1,2,8],f32> +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[0, 0, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_1]], %[[VAL_5]], %[[VAL_6]] : (tensor<1x2x4xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x2x3xf32> +// CHECK: %[[VAL_8:.*]] = tosa.reverse %[[VAL_7]] {axis = 2 : i32} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> +// CHECK: %[[VAL_9:.*]] = tosa.const_shape {value = dense<[0, 0, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[1, 2, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_11:.*]] = tosa.slice %[[VAL_1]], %[[VAL_9]], %[[VAL_10]] : (tensor<1x2x4xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x2x1xf32> +// CHECK: %[[VAL_12:.*]] = tosa.reverse %[[VAL_11]] {axis = 2 : i32} : (tensor<1x2x1xf32>) -> tensor<1x2x1xf32> +// CHECK: %[[VAL_13:.*]] = tosa.concat %[[VAL_8]], %[[VAL_1]], %[[VAL_12]] {axis = 2 : i32} : (tensor<1x2x3xf32>, tensor<1x2x4xf32>, tensor<1x2x1xf32>) -> tensor<1x2x8xf32> +// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_13]] : tensor<1x2x8xf32> -> !torch.vtensor<[1,2,8],f32> +// CHECK: return %[[VAL_14]] : !torch.vtensor<[1,2,8],f32> // CHECK: } func.func @torch.aten.reflection_pad1d$basic(%arg0: !torch.vtensor<[1,2,4],f32>) -> !torch.vtensor<[1,2,8],f32> { %int3 = torch.constant.int 3 @@ -2627,18 +2635,26 @@ func.func @torch.aten.reflection_pad1d$basic(%arg0: !torch.vtensor<[1,2,4],f32>) // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,20,20],f32> -> tensor<1x20x20xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 10 // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_4:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x20x20xf32>) -> tensor<1x20x10xf32> -// CHECK: %[[VAL_5:.*]] = tosa.reverse %[[VAL_4]] {axis = 2 : i32} : (tensor<1x20x10xf32>) -> tensor<1x20x10xf32> -// CHECK: %[[VAL_6:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x20x20xf32>) -> tensor<1x20x10xf32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[0, 0, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[1, 20, 10]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_6:.*]] = tosa.slice %[[VAL_1]], %[[VAL_4]], %[[VAL_5]] : (tensor<1x20x20xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x20x10xf32> // CHECK: %[[VAL_7:.*]] = tosa.reverse %[[VAL_6]] {axis = 2 : i32} : (tensor<1x20x10xf32>) -> tensor<1x20x10xf32> -// CHECK: %[[VAL_8:.*]] = tosa.concat %[[VAL_5]], %[[VAL_1]], %[[VAL_7]] {axis = 2 : i32} : (tensor<1x20x10xf32>, tensor<1x20x20xf32>, tensor<1x20x10xf32>) -> tensor<1x20x40xf32> -// CHECK: %[[VAL_9:.*]] = tosa.slice %[[VAL_8]] {size = array, start = array} : (tensor<1x20x40xf32>) -> tensor<1x10x40xf32> -// CHECK: %[[VAL_10:.*]] = tosa.reverse %[[VAL_9]] {axis = 1 : i32} : (tensor<1x10x40xf32>) -> tensor<1x10x40xf32> -// CHECK: %[[VAL_11:.*]] = tosa.slice %[[VAL_8]] {size = array, start = array} : (tensor<1x20x40xf32>) -> tensor<1x10x40xf32> -// CHECK: %[[VAL_12:.*]] = tosa.reverse %[[VAL_11]] {axis = 1 : i32} : (tensor<1x10x40xf32>) -> tensor<1x10x40xf32> -// CHECK: %[[VAL_13:.*]] = tosa.concat %[[VAL_10]], %[[VAL_8]], %[[VAL_12]] {axis = 1 : i32} : (tensor<1x10x40xf32>, tensor<1x20x40xf32>, tensor<1x10x40xf32>) -> tensor<1x40x40xf32> -// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_13]] : tensor<1x40x40xf32> -> !torch.vtensor<[1,40,40],f32> -// CHECK: return %[[VAL_14]] : !torch.vtensor<[1,40,40],f32> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<[0, 0, 9]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_9:.*]] = tosa.const_shape {value = dense<[1, 20, 10]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_10:.*]] = tosa.slice %[[VAL_1]], %[[VAL_8]], %[[VAL_9]] : (tensor<1x20x20xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x20x10xf32> +// CHECK: %[[VAL_11:.*]] = tosa.reverse %[[VAL_10]] {axis = 2 : i32} : (tensor<1x20x10xf32>) -> tensor<1x20x10xf32> +// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_7]], %[[VAL_1]], %[[VAL_11]] {axis = 2 : i32} : (tensor<1x20x10xf32>, tensor<1x20x20xf32>, tensor<1x20x10xf32>) -> tensor<1x20x40xf32> +// CHECK: %[[VAL_13:.*]] = tosa.const_shape {value = dense<[0, 1, 0]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_14:.*]] = tosa.const_shape {value = dense<[1, 10, 40]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_15:.*]] = tosa.slice %[[VAL_12]], %[[VAL_13]], %[[VAL_14]] : (tensor<1x20x40xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x10x40xf32> +// CHECK: %[[VAL_16:.*]] = tosa.reverse %[[VAL_15]] {axis = 1 : i32} : (tensor<1x10x40xf32>) -> tensor<1x10x40xf32> +// CHECK: %[[VAL_17:.*]] = tosa.const_shape {value = dense<[0, 9, 0]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_18:.*]] = tosa.const_shape {value = dense<[1, 10, 40]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_12]], %[[VAL_17]], %[[VAL_18]] : (tensor<1x20x40xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x10x40xf32> +// CHECK: %[[VAL_20:.*]] = tosa.reverse %[[VAL_19]] {axis = 1 : i32} : (tensor<1x10x40xf32>) -> tensor<1x10x40xf32> +// CHECK: %[[VAL_21:.*]] = tosa.concat %[[VAL_16]], %[[VAL_12]], %[[VAL_20]] {axis = 1 : i32} : (tensor<1x10x40xf32>, tensor<1x20x40xf32>, tensor<1x10x40xf32>) -> tensor<1x40x40xf32> +// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<1x40x40xf32> -> !torch.vtensor<[1,40,40],f32> +// CHECK: return %[[VAL_22]] : !torch.vtensor<[1,40,40],f32> // CHECK: } func.func @torch.aten.reflection_pad2d$basic(%arg0: !torch.vtensor<[1,20,20],f32>) -> !torch.vtensor<[1,40,40],f32> { %int10 = torch.constant.int 10 @@ -2650,27 +2666,40 @@ func.func @torch.aten.reflection_pad2d$basic(%arg0: !torch.vtensor<[1,20,20],f32 // ----- // CHECK-LABEL: func.func @torch.aten.reflection_pad3d$basic( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,5,7,3,4],f32>) -> !torch.vtensor<[4,5,11,7,8],f32> { -// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,5,7,3,4],f32> -> tensor<4x5x7x3x4xf32> -// CHECK: %[[VAL_1:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list -// CHECK: %[[SLICE_L:.*]] = tosa.slice %[[VAL_0]] {size = array, start = array} : (tensor<4x5x7x3x4xf32>) -> tensor<4x5x7x3x2xf32> -// CHECK: %[[REVERSE_L:.*]] = tosa.reverse %[[SLICE_L]] {axis = 4 : i32} : (tensor<4x5x7x3x2xf32>) -> tensor<4x5x7x3x2xf32> -// CHECK: %[[SLICE_R:.*]] = tosa.slice %[[VAL_0]] {size = array, start = array} : (tensor<4x5x7x3x4xf32>) -> tensor<4x5x7x3x2xf32> -// CHECK: %[[REVERSE_R:.*]] = tosa.reverse %[[SLICE_R]] {axis = 4 : i32} : (tensor<4x5x7x3x2xf32>) -> tensor<4x5x7x3x2xf32> -// CHECK: %[[CONCAT_LR:.*]] = tosa.concat %[[REVERSE_L]], %[[VAL_0]], %[[REVERSE_R]] {axis = 4 : i32} : (tensor<4x5x7x3x2xf32>, tensor<4x5x7x3x4xf32>, tensor<4x5x7x3x2xf32>) -> tensor<4x5x7x3x8xf32> -// CHECK: %[[SLICE_T:.*]] = tosa.slice %[[CONCAT_LR]] {size = array, start = array} : (tensor<4x5x7x3x8xf32>) -> tensor<4x5x7x2x8xf32> -// CHECK: %[[REVERSE_T:.*]] = tosa.reverse %[[SLICE_T]] {axis = 3 : i32} : (tensor<4x5x7x2x8xf32>) -> tensor<4x5x7x2x8xf32> -// CHECK: %[[SLICE_B:.*]] = tosa.slice %[[CONCAT_LR]] {size = array, start = array} : (tensor<4x5x7x3x8xf32>) -> tensor<4x5x7x2x8xf32> -// CHECK: %[[REVERSE_B:.*]] = tosa.reverse %[[SLICE_B]] {axis = 3 : i32} : (tensor<4x5x7x2x8xf32>) -> tensor<4x5x7x2x8xf32> -// CHECK: %[[CONCAT_TB:.*]] = tosa.concat %[[REVERSE_T]], %[[CONCAT_LR]], %[[REVERSE_B]] {axis = 3 : i32} : (tensor<4x5x7x2x8xf32>, tensor<4x5x7x3x8xf32>, tensor<4x5x7x2x8xf32>) -> tensor<4x5x7x7x8xf32> -// CHECK: %[[SLICE_F:.*]] = tosa.slice %[[CONCAT_TB]] {size = array, start = array} : (tensor<4x5x7x7x8xf32>) -> tensor<4x5x2x7x8xf32> -// CHECK: %[[REVERSE_F:.*]] = tosa.reverse %[[SLICE_F]] {axis = 2 : i32} : (tensor<4x5x2x7x8xf32>) -> tensor<4x5x2x7x8xf32> -// CHECK: %[[SLICE_BACK:.*]] = tosa.slice %[[CONCAT_TB]] {size = array, start = array} : (tensor<4x5x7x7x8xf32>) -> tensor<4x5x2x7x8xf32> -// CHECK: %[[REVERSE_BACK:.*]] = tosa.reverse %[[SLICE_BACK]] {axis = 2 : i32} : (tensor<4x5x2x7x8xf32>) -> tensor<4x5x2x7x8xf32> -// CHECK: %[[CONCAT_FB:.*]] = tosa.concat %[[REVERSE_F]], %[[CONCAT_TB]], %[[REVERSE_BACK]] {axis = 2 : i32} : (tensor<4x5x2x7x8xf32>, tensor<4x5x7x7x8xf32>, tensor<4x5x2x7x8xf32>) -> tensor<4x5x11x7x8xf32> -// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CONCAT_FB]] : tensor<4x5x11x7x8xf32> -> !torch.vtensor<[4,5,11,7,8],f32> -// CHECK: return %[[RESULT]] +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,5,7,3,4],f32>) -> !torch.vtensor<[4,5,11,7,8],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5,7,3,4],f32> -> tensor<4x5x7x3x4xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 0, 1]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[4, 5, 7, 3, 2]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[VAL_6:.*]] = tosa.slice %[[VAL_1]], %[[VAL_4]], %[[VAL_5]] : (tensor<4x5x7x3x4xf32>, !tosa.shape<5>, !tosa.shape<5>) -> tensor<4x5x7x3x2xf32> +// CHECK: %[[VAL_7:.*]] = tosa.reverse %[[VAL_6]] {axis = 4 : i32} : (tensor<4x5x7x3x2xf32>) -> tensor<4x5x7x3x2xf32> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 0, 1]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[VAL_9:.*]] = tosa.const_shape {value = dense<[4, 5, 7, 3, 2]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[VAL_10:.*]] = tosa.slice %[[VAL_1]], %[[VAL_8]], %[[VAL_9]] : (tensor<4x5x7x3x4xf32>, !tosa.shape<5>, !tosa.shape<5>) -> tensor<4x5x7x3x2xf32> +// CHECK: %[[VAL_11:.*]] = tosa.reverse %[[VAL_10]] {axis = 4 : i32} : (tensor<4x5x7x3x2xf32>) -> tensor<4x5x7x3x2xf32> +// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_7]], %[[VAL_1]], %[[VAL_11]] {axis = 4 : i32} : (tensor<4x5x7x3x2xf32>, tensor<4x5x7x3x4xf32>, tensor<4x5x7x3x2xf32>) -> tensor<4x5x7x3x8xf32> +// CHECK: %[[VAL_13:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 1, 0]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[VAL_14:.*]] = tosa.const_shape {value = dense<[4, 5, 7, 2, 8]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[VAL_15:.*]] = tosa.slice %[[VAL_12]], %[[VAL_13]], %[[VAL_14]] : (tensor<4x5x7x3x8xf32>, !tosa.shape<5>, !tosa.shape<5>) -> tensor<4x5x7x2x8xf32> +// CHECK: %[[VAL_16:.*]] = tosa.reverse %[[VAL_15]] {axis = 3 : i32} : (tensor<4x5x7x2x8xf32>) -> tensor<4x5x7x2x8xf32> +// CHECK: %[[VAL_17:.*]] = tosa.const_shape {value = dense<0> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[VAL_18:.*]] = tosa.const_shape {value = dense<[4, 5, 7, 2, 8]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_12]], %[[VAL_17]], %[[VAL_18]] : (tensor<4x5x7x3x8xf32>, !tosa.shape<5>, !tosa.shape<5>) -> tensor<4x5x7x2x8xf32> +// CHECK: %[[VAL_20:.*]] = tosa.reverse %[[VAL_19]] {axis = 3 : i32} : (tensor<4x5x7x2x8xf32>) -> tensor<4x5x7x2x8xf32> +// CHECK: %[[VAL_21:.*]] = tosa.concat %[[VAL_16]], %[[VAL_12]], %[[VAL_20]] {axis = 3 : i32} : (tensor<4x5x7x2x8xf32>, tensor<4x5x7x3x8xf32>, tensor<4x5x7x2x8xf32>) -> tensor<4x5x7x7x8xf32> +// CHECK: %[[VAL_22:.*]] = tosa.const_shape {value = dense<[0, 0, 1, 0, 0]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[VAL_23:.*]] = tosa.const_shape {value = dense<[4, 5, 2, 7, 8]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[VAL_24:.*]] = tosa.slice %[[VAL_21]], %[[VAL_22]], %[[VAL_23]] : (tensor<4x5x7x7x8xf32>, !tosa.shape<5>, !tosa.shape<5>) -> tensor<4x5x2x7x8xf32> +// CHECK: %[[VAL_25:.*]] = tosa.reverse %[[VAL_24]] {axis = 2 : i32} : (tensor<4x5x2x7x8xf32>) -> tensor<4x5x2x7x8xf32> +// CHECK: %[[VAL_26:.*]] = tosa.const_shape {value = dense<[0, 0, 4, 0, 0]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[VAL_27:.*]] = tosa.const_shape {value = dense<[4, 5, 2, 7, 8]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[VAL_28:.*]] = tosa.slice %[[VAL_21]], %[[VAL_26]], %[[VAL_27]] : (tensor<4x5x7x7x8xf32>, !tosa.shape<5>, !tosa.shape<5>) -> tensor<4x5x2x7x8xf32> +// CHECK: %[[VAL_29:.*]] = tosa.reverse %[[VAL_28]] {axis = 2 : i32} : (tensor<4x5x2x7x8xf32>) -> tensor<4x5x2x7x8xf32> +// CHECK: %[[VAL_30:.*]] = tosa.concat %[[VAL_25]], %[[VAL_21]], %[[VAL_29]] {axis = 2 : i32} : (tensor<4x5x2x7x8xf32>, tensor<4x5x7x7x8xf32>, tensor<4x5x2x7x8xf32>) -> tensor<4x5x11x7x8xf32> +// CHECK: %[[VAL_31:.*]] = torch_c.from_builtin_tensor %[[VAL_30]] : tensor<4x5x11x7x8xf32> -> !torch.vtensor<[4,5,11,7,8],f32> +// CHECK: return %[[VAL_31]] : !torch.vtensor<[4,5,11,7,8],f32> +// CHECK: } func.func @torch.aten.reflection_pad3d$basic(%arg0: !torch.vtensor<[4,5,7,3,4],f32>) -> !torch.vtensor<[4,5,11,7,8],f32> { %int2 = torch.constant.int 2 %0 = torch.prim.ListConstruct %int2, %int2, %int2, %int2, %int2, %int2 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list @@ -2688,14 +2717,22 @@ func.func @torch.aten.reflection_pad3d$basic(%arg0: !torch.vtensor<[4,5,7,3,4],f // CHECK: %[[VAL_4:.*]] = torch.constant.int 3 // CHECK: %[[VAL_5:.*]] = torch.constant.int 4 // CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x1x3x3xf32>) -> tensor<1x1x3x1xf32> -// CHECK: %[[VAL_8:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x1x3x3xf32>) -> tensor<1x1x3x1xf32> -// CHECK: %[[VAL_9:.*]] = tosa.concat %[[VAL_7]], %[[VAL_1]], %[[VAL_8]], %[[VAL_8]] {axis = 3 : i32} : (tensor<1x1x3x1xf32>, tensor<1x1x3x3xf32>, tensor<1x1x3x1xf32>, tensor<1x1x3x1xf32>) -> tensor<1x1x3x6xf32> -// CHECK: %[[VAL_10:.*]] = tosa.slice %[[VAL_9]] {size = array, start = array} : (tensor<1x1x3x6xf32>) -> tensor<1x1x1x6xf32> -// CHECK: %[[VAL_11:.*]] = tosa.slice %[[VAL_9]] {size = array, start = array} : (tensor<1x1x3x6xf32>) -> tensor<1x1x1x6xf32> -// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_10]], %[[VAL_10]], %[[VAL_10]], %[[VAL_9]], %[[VAL_11]], %[[VAL_11]], %[[VAL_11]], %[[VAL_11]] {axis = 2 : i32} : (tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x3x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>) -> tensor<1x1x10x6xf32> -// CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor<1x1x10x6xf32> -> !torch.vtensor<[1,1,10,6],f32> -// CHECK: return %[[VAL_13]] : !torch.vtensor<[1,1,10,6],f32> +// CHECK: %[[VAL_7:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<[1, 1, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_9:.*]] = tosa.slice %[[VAL_1]], %[[VAL_7]], %[[VAL_8]] : (tensor<1x1x3x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x1x3x1xf32> +// CHECK: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 2]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_11:.*]] = tosa.const_shape {value = dense<[1, 1, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_12:.*]] = tosa.slice %[[VAL_1]], %[[VAL_10]], %[[VAL_11]] : (tensor<1x1x3x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x1x3x1xf32> +// CHECK: %[[VAL_13:.*]] = tosa.concat %[[VAL_9]], %[[VAL_1]], %[[VAL_12]], %[[VAL_12]] {axis = 3 : i32} : (tensor<1x1x3x1xf32>, tensor<1x1x3x3xf32>, tensor<1x1x3x1xf32>, tensor<1x1x3x1xf32>) -> tensor<1x1x3x6xf32> +// CHECK: %[[VAL_14:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_15:.*]] = tosa.const_shape {value = dense<[1, 1, 1, 6]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_13]], %[[VAL_14]], %[[VAL_15]] : (tensor<1x1x3x6xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x1x1x6xf32> +// CHECK: %[[VAL_17:.*]] = tosa.const_shape {value = dense<[0, 0, 2, 0]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_18:.*]] = tosa.const_shape {value = dense<[1, 1, 1, 6]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_13]], %[[VAL_17]], %[[VAL_18]] : (tensor<1x1x3x6xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x1x1x6xf32> +// CHECK: %[[VAL_20:.*]] = tosa.concat %[[VAL_16]], %[[VAL_16]], %[[VAL_16]], %[[VAL_13]], %[[VAL_19]], %[[VAL_19]], %[[VAL_19]], %[[VAL_19]] {axis = 2 : i32} : (tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x3x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>) -> tensor<1x1x10x6xf32> +// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<1x1x10x6xf32> -> !torch.vtensor<[1,1,10,6],f32> +// CHECK: return %[[VAL_21]] : !torch.vtensor<[1,1,10,6],f32> // CHECK: } func.func @torch.aten.replication_pad2d$basic(%arg0: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,10,6],f32> { %int1 = torch.constant.int 1