Skip to content

Commit 3023352

Browse files
Denis Khalikovantiagainst
Denis Khalikov
authored andcommitted
[mlir][spirv] Simplify scalar type size calculation.
Simplify scalar type size calculation and reject boolean memrefs. Differential Revision: https://reviews.llvm.org/D72999
1 parent a688301 commit 3023352

File tree

3 files changed

+24
-4
lines changed

3 files changed

+24
-4
lines changed

mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
5151
{
5252
for (auto argType : enumerate(funcOp.getType().getInputs())) {
5353
auto convertedType = typeConverter.convertType(argType.value());
54+
if (!convertedType) {
55+
return matchFailure();
56+
}
5457
signatureConverter.addInputs(argType.index(), convertedType);
5558
}
5659
}

mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp

+12-4
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,18 @@ Type SPIRVTypeConverter::getIndexType(MLIRContext *context) {
4141
// TODO(ravishankarm): This is a utility function that should probably be
4242
// exposed by the SPIR-V dialect. Keeping it local till the use case arises.
4343
static Optional<int64_t> getTypeNumBytes(Type t) {
44-
if (auto integerType = t.dyn_cast<IntegerType>()) {
45-
return integerType.getWidth() / 8;
46-
} else if (auto floatType = t.dyn_cast<FloatType>()) {
47-
return floatType.getWidth() / 8;
44+
if (spirv::SPIRVDialect::isValidScalarType(t)) {
45+
auto bitWidth = t.getIntOrFloatBitWidth();
46+
// According to the SPIR-V spec:
47+
// "There is no physical size or bit pattern defined for values with boolean
48+
// type. If they are stored (in conjunction with OpVariable), they can only
49+
// be used with logical addressing operations, not physical, and only with
50+
// non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
51+
// Private, Function, Input, and Output."
52+
if (bitWidth == 1) {
53+
return llvm::None;
54+
}
55+
return bitWidth / 8;
4856
} else if (auto memRefType = t.dyn_cast<MemRefType>()) {
4957
// TODO: Layout should also be controlled by the ABI attributes. For now
5058
// using the layout from MemRef.

mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir

+9
Original file line numberDiff line numberDiff line change
@@ -289,3 +289,12 @@ func @sitofp(%arg0 : i32) {
289289
%0 = std.sitofp %arg0 : i32 to f32
290290
return
291291
}
292+
293+
//===----------------------------------------------------------------------===//
294+
// memref type
295+
//===----------------------------------------------------------------------===//
296+
297+
// CHECK-LABEL: func @memref_type({{%.*}}: memref<3xi1>) {
298+
func @memref_type(%arg0: memref<3xi1>) {
299+
return
300+
}

0 commit comments

Comments
 (0)