-
Notifications
You must be signed in to change notification settings - Fork 341
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[GatherElements]: Implement verification, shape inference, code gen #1375
Conversation
…eneration. Signed-off-by: Ettore Tiotto <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding some notes to help code reviewers understand changes more easily.
@@ -72,12 +72,13 @@ struct OnnxToKrnlBuilder : public OnnxBuilder { | |||
// Common functions used when lowering the ONNX frontend dialect to KRNL. | |||
//===----------------------------------------------------------------------===// | |||
|
|||
/// Check is all dimensions are known at compile time. | |||
bool hasAllConstantDimensions(MemRefType type); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is declaration does not have a corresponding definition, removing.
int64_t axisLit = gatherOp.axis(); | ||
int64_t dataRank = shapeHelper.dataDims.size(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have simplified the shapehelper implementation, to make it consistent with other shape helpers.
@@ -19,17 +19,6 @@ using namespace mlir; | |||
|
|||
namespace onnx_mlir { | |||
|
|||
// Returns true if all the indices are known to be positive and false otherwise. | |||
static bool indicesArePositiveConstants(Value indices) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Placed this utility function in ONNXtoKernlCommon.cpp so that it can be used by Gather, GatherElements, and ScatterElements during code generation (conversion from ONNX to Krnl).
int64_t axisIndex = op->axis(); | ||
// The 'axis' value must be in [-rank, rank-1]. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Simplified to make more consistent with other shape helpers. The auxiliary vector (dataDims) now computed in the ONNXToKrnl conversion code.
@@ -296,21 +298,6 @@ struct ONNXMatMulOpShapeHelper : public ONNXOpShapeHelper<mlir::ONNXMatMulOp> { | |||
bPadDims; // When true, that dim was padded. | |||
}; | |||
|
|||
// Shape for Gather. | |||
struct ONNXGatherOpShapeHelper : public ONNXOpShapeHelper<mlir::ONNXGatherOp> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now generated via DECLARE_SHAPE_HELPER
func @test_gather_scalar(%arg0: tensor<4xi64>, %arg1: tensor<i64>) -> tensor<i64> { | ||
%0 = "onnx.Gather"(%arg0, %arg1) {axis = 0 : si64} : (tensor<4xi64>, tensor<i64>) -> tensor<i64> | ||
return %0 : tensor<i64> | ||
// CHECK-LABEL: @test_gather_scalar | ||
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<4xi64>, [[PARAM_1_:%.+]]: memref<i64>) -> memref<i64> { | ||
// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Codegen essentially the same, some constant moved around.
@@ -1564,7 +1564,7 @@ func @test_gather_negative_axis(%arg0 : tensor<3x3xf32>, %arg1 : tensor<1x2xi64> | |||
"std.return"(%0) : (tensor<*xf32>) -> () | |||
|
|||
// CHECK-LABEL: test_gather_negative_axis | |||
// CHECK: [[RES:%.+]] = "onnx.Gather"(%arg0, %arg1) {axis = 1 : si64} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<3x1x2xf32> | |||
// CHECK: [[RES:%.+]] = "onnx.Gather"(%arg0, %arg1) {axis = -1 : si64} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<3x1x2xf32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shapeHelper used to change the negative axis value. Now we are leaving the axis attribute as it was on the original graph, and dealing with the semantics associated with a negative axis value during codegen (same approach as in ScatterElements, GatherElements, and ScatterND).
Signed-off-by: Ettore Tiotto <[email protected]>
Signed-off-by: Ettore Tiotto <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
Jenkins Linux amd64 Build #5395 [push] [GatherElements]: Implem... started at 08:31 |
Jenkins Linux ppc64le Build #4532 [push] [GatherElements]: Implem... started at 09:35 |
Jenkins Linux s390x Build #5411 [push] [GatherElements]: Implem... started at 09:31 |
Jenkins Linux s390x Build #5411 [push] [GatherElements]: Implem... failed after 40 min |
Jenkins Linux ppc64le Build #4532 [push] [GatherElements]: Implem... passed after 1 hr 0 min |
Jenkins Linux amd64 Build #5395 [push] [GatherElements]: Implem... passed after 1 hr 3 min |
Implement support for the ONNX
GatherElements
operator:Signed-off-by: Ettore Tiotto [email protected]