-
Notifications
You must be signed in to change notification settings - Fork 13k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[flang][OpenMP] Support target enter|update|exit .. nowait
#113305
Conversation
@llvm/pr-subscribers-clang @llvm/pr-subscribers-mlir
Author: Kareem Ergawy (ergawy)
Changes: Extends …
Patch is 26.66 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/113305.diff
5 Files Affected:
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 8834c3b1f50115..d71712a677078c 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2264,6 +2264,9 @@ class OpenMPIRBuilder {
bool EmitDebug = false;
+ /// Whether the `target ... data` directive has a `nowait` clause.
+ bool HasNoWait = false;
+
explicit TargetDataInfo() {}
explicit TargetDataInfo(bool RequiresDevicePointerInfo,
bool SeparateBeginEndCalls)
@@ -2342,7 +2345,6 @@ class OpenMPIRBuilder {
/// Generate a target region entry call and host fallback call.
///
/// \param Loc The location at which the request originated and is fulfilled.
- /// \param OutlinedFn The outlined kernel function.
/// \param OutlinedFnID The ooulined function ID.
/// \param EmitTargetCallFallbackCB Call back function to generate host
/// fallback code.
@@ -2350,18 +2352,27 @@ class OpenMPIRBuilder {
/// \param DeviceID Identifier for the device via the 'device' clause.
/// \param RTLoc Source location identifier
/// \param AllocaIP The insertion point to be used for alloca instructions.
- InsertPointTy emitKernelLaunch(
- const LocationDescription &Loc, Function *OutlinedFn, Value *OutlinedFnID,
- EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
- Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP);
+ InsertPointTy
+ emitKernelLaunch(const LocationDescription &Loc, Value *OutlinedFnID,
+ EmitFallbackCallbackTy EmitTargetCallFallbackCB,
+ TargetKernelArgs &Args, Value *DeviceID, Value *RTLoc,
+ InsertPointTy AllocaIP);
+
+ /// Callback type for generating the bodies of device directives that require
+ /// outer tasks (e.g. in case of having `nowait` or `depend` clauses).
+ ///
+ /// \param DeviceID The ID of the device on which the target region will
+ /// execute.
+ /// \param RTLoc Source location identifier
+ /// \param TargetTaskAllocaIP Insertion point for the alloca block of the
+ /// generated task.
+ using TaskBodyCallbackTy =
+ function_ref<void(Value *DeviceID, Value *RTLoc,
+ IRBuilderBase::InsertPoint TargetTaskAllocaIP)>;
/// Generate a target-task for the target construct
///
- /// \param OutlinedFn The outlined device/target kernel function.
- /// \param OutlinedFnID The ooulined function ID.
- /// \param EmitTargetCallFallbackCB Call back function to generate host
- /// fallback code.
- /// \param Args Data structure holding information about the kernel arguments.
+ /// \param TaskBodyCB Callback to generate the actual body of the target task.
/// \param DeviceID Identifier for the device via the 'device' clause.
/// \param RTLoc Source location identifier
/// \param AllocaIP The insertion point to be used for alloca instructions.
@@ -2370,10 +2381,10 @@ class OpenMPIRBuilder {
/// \param HasNoWait True if the target construct had 'nowait' on it, false
/// otherwise
InsertPointTy emitTargetTask(
- Function *OutlinedFn, Value *OutlinedFnID,
- EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
- Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP,
- SmallVector<OpenMPIRBuilder::DependData> &Dependencies, bool HasNoWait);
+ TaskBodyCallbackTy TaskBodyCB, Value *DeviceID, Value *RTLoc,
+ OpenMPIRBuilder::InsertPointTy AllocaIP,
+ const SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
+ bool HasNoWait);
/// Emit the arguments to be passed to the runtime library based on the
/// arrays of base pointers, pointers, sizes, map types, and mappers. If
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 172812a3802d33..809ba3aad2982e 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -1080,8 +1080,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetKernel(
}
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitKernelLaunch(
- const LocationDescription &Loc, Function *OutlinedFn, Value *OutlinedFnID,
- EmitFallbackCallbackTy emitTargetCallFallbackCB, TargetKernelArgs &Args,
+ const LocationDescription &Loc, Value *OutlinedFnID,
+ EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP) {
if (!updateToLocation(Loc))
@@ -1134,7 +1134,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitKernelLaunch(
auto CurFn = Builder.GetInsertBlock()->getParent();
emitBlock(OffloadFailedBlock, CurFn);
- Builder.restoreIP(emitTargetCallFallbackCB(Builder.saveIP()));
+ Builder.restoreIP(EmitTargetCallFallbackCB(Builder.saveIP()));
emitBranch(OffloadContBlock);
emitBlock(OffloadContBlock, CurFn, /*IsFinished=*/true);
return Builder.saveIP();
@@ -1736,7 +1736,7 @@ void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
// - All code is inserted in the entry block of the current function.
static Value *emitTaskDependencies(
OpenMPIRBuilder &OMPBuilder,
- SmallVectorImpl<OpenMPIRBuilder::DependData> &Dependencies) {
+ const SmallVectorImpl<OpenMPIRBuilder::DependData> &Dependencies) {
// Early return if we have no dependencies to process
if (Dependencies.empty())
return nullptr;
@@ -6386,12 +6386,10 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
// closing of the region.
auto BeginThenGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
MapInfo = &GenMapInfoCB(Builder.saveIP());
- emitOffloadingArrays(AllocaIP, Builder.saveIP(), *MapInfo, Info,
- /*IsNonContiguous=*/true, DeviceAddrCB,
- CustomMapperCB);
-
TargetDataRTArgs RTArgs;
- emitOffloadingArraysArgument(Builder, RTArgs, Info);
+ emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info, RTArgs,
+ *MapInfo, /*IsNonContiguous=*/true,
+ /*ForEndCall=*/false);
// Emit the number of elements in the offloading arrays.
Value *PointerNum = Builder.getInt32(Info.NumberOfPtrs);
@@ -6403,16 +6401,45 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
}
- Value *OffloadingArgs[] = {SrcLocInfo, DeviceID,
- PointerNum, RTArgs.BasePointersArray,
- RTArgs.PointersArray, RTArgs.SizesArray,
- RTArgs.MapTypesArray, RTArgs.MapNamesArray,
- RTArgs.MappersArray};
+ SmallVector<llvm::Value *, 13> OffloadingArgs = {
+ SrcLocInfo, DeviceID,
+ PointerNum, RTArgs.BasePointersArray,
+ RTArgs.PointersArray, RTArgs.SizesArray,
+ RTArgs.MapTypesArray, RTArgs.MapNamesArray,
+ RTArgs.MappersArray};
if (IsStandAlone) {
assert(MapperFunc && "MapperFunc missing for standalone target data");
- Builder.CreateCall(getOrCreateRuntimeFunctionPtr(*MapperFunc),
- OffloadingArgs);
+
+ auto TaskBodyCB = [&](Value *, Value *, IRBuilderBase::InsertPoint) {
+ if (Info.HasNoWait) {
+ OffloadingArgs.push_back(llvm::Constant::getNullValue(Int32));
+ OffloadingArgs.push_back(llvm::Constant::getNullValue(VoidPtr));
+ OffloadingArgs.push_back(llvm::Constant::getNullValue(Int32));
+ OffloadingArgs.push_back(llvm::Constant::getNullValue(VoidPtr));
+ }
+
+ Builder.CreateCall(getOrCreateRuntimeFunctionPtr(*MapperFunc),
+ OffloadingArgs);
+
+ if (Info.HasNoWait) {
+ BasicBlock *OffloadContBlock =
+ BasicBlock::Create(Builder.getContext(), "omp_offload.cont");
+ auto *CurFn = Builder.GetInsertBlock()->getParent();
+ emitBranch(OffloadContBlock);
+ emitBlock(OffloadContBlock, CurFn, /*IsFinished=*/true);
+ Builder.restoreIP(Builder.saveIP());
+ }
+ };
+
+ bool RequiresOuterTargetTask = Info.HasNoWait;
+
+ if (!RequiresOuterTargetTask)
+ TaskBodyCB(/*DeviceID=*/nullptr, /*RTLoc=*/nullptr,
+ /*TargetTaskAllocaIP=*/{});
+ else
+ emitTargetTask(TaskBodyCB, DeviceID, SrcLocInfo, AllocaIP,
+ /*Dependencies=*/{}, Info.HasNoWait);
} else {
Function *BeginMapperFunc = getOrCreateRuntimeFunctionPtr(
omp::OMPRTL___tgt_target_data_begin_mapper);
@@ -6836,13 +6863,18 @@ static void emitTargetOutlinedFunction(
OMPBuilder.emitTargetRegionFunction(EntryInfo, GenerateOutlinedFunction,
IsOffloadEntry, OutlinedFn, OutlinedFnID);
}
+
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
- Function *OutlinedFn, Value *OutlinedFnID,
- EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
- Value *DeviceID, Value *RTLoc, OpenMPIRBuilder::InsertPointTy AllocaIP,
- SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
+ TaskBodyCallbackTy TaskBodyCB, Value *DeviceID, Value *RTLoc,
+ OpenMPIRBuilder::InsertPointTy AllocaIP,
+ const SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
bool HasNoWait) {
+ // The following explains the code-gen scenario for the `target` directive. A
+ // similar scenario is followed for other device-related directives (e.g.
+ // `target enter data`) but in a simpler fashion since we only need to emit a
+ // task that encapsulates the proper runtime call.
+ //
// When we arrive at this function, the target region itself has been
// outlined into the function OutlinedFn.
// So at ths point, for
@@ -6950,22 +6982,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
Builder.restoreIP(TargetTaskBodyIP);
- if (OutlinedFnID) {
- // emitKernelLaunch makes the necessary runtime call to offload the kernel.
- // We then outline all that code into a separate function
- // ('kernel_launch_function' in the pseudo code above). This function is
- // then called by the target task proxy function (see
- // '@.omp_target_task_proxy_func' in the pseudo code above)
- // "@.omp_target_task_proxy_func' is generated by
- // emitTargetTaskProxyFunction.
- Builder.restoreIP(emitKernelLaunch(Builder, OutlinedFn, OutlinedFnID,
- EmitTargetCallFallbackCB, Args, DeviceID,
- RTLoc, TargetTaskAllocaIP));
- } else {
- // When OutlinedFnID is set to nullptr, then it's not an offloading call. In
- // this case, we execute the host implementation directly.
- Builder.restoreIP(EmitTargetCallFallbackCB(Builder.saveIP()));
- }
+ TaskBodyCB(DeviceID, RTLoc, TargetTaskAllocaIP);
OI.ExitBB = Builder.saveIP().getBlock();
OI.PostOutlineCB = [this, ToBeDeleted, Dependencies, HasNoWait,
@@ -7153,6 +7170,29 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
bool HasDependencies = Dependencies.size() > 0;
bool RequiresOuterTargetTask = HasNoWait || HasDependencies;
+ OpenMPIRBuilder::TargetKernelArgs KArgs;
+
+ auto TaskBodyCB = [&](Value *DeviceID, Value *RTLoc,
+ IRBuilderBase::InsertPoint TargetTaskAllocaIP) {
+ if (OutlinedFnID) {
+ // emitKernelLaunch makes the necessary runtime call to offload the
+ // kernel. We then outline all that code into a separate function
+ // ('kernel_launch_function' in the pseudo code above). This function is
+ // then called by the target task proxy function (see
+ // '@.omp_target_task_proxy_func' in the pseudo code above)
+ // "@.omp_target_task_proxy_func' is generated by
+ // emitTargetTaskProxyFunction.
+ Builder.restoreIP(OMPBuilder.emitKernelLaunch(
+ Builder, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID,
+ RTLoc, TargetTaskAllocaIP));
+ } else {
+ // When OutlinedFnID is set to nullptr, then it's not an offloading
+ // call. In this case, we execute the host implementation directly.
+ OMPBuilder.Builder.restoreIP(
+ EmitTargetCallFallbackCB(OMPBuilder.Builder.saveIP()));
+ }
+ };
+
// If we don't have an ID for the target region, it means an offload entry
// wasn't created. In this case we just run the host fallback directly.
if (!OutlinedFnID) {
@@ -7160,14 +7200,14 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
// Arguments that are intended to be directly forwarded to an
// emitKernelLaunch call are pased as nullptr, since OutlinedFnID=nullptr
// results in that call not being done.
- OpenMPIRBuilder::TargetKernelArgs KArgs;
- Builder.restoreIP(OMPBuilder.emitTargetTask(
- OutlinedFn, /*OutlinedFnID=*/nullptr, EmitTargetCallFallbackCB, KArgs,
- /*DeviceID=*/nullptr, /*RTLoc=*/nullptr, AllocaIP, Dependencies,
- HasNoWait));
+ Builder.restoreIP(OMPBuilder.emitTargetTask(TaskBodyCB,
+ /*DeviceID=*/nullptr,
+ /*RTLoc=*/nullptr, AllocaIP,
+ Dependencies, HasNoWait));
} else {
Builder.restoreIP(EmitTargetCallFallbackCB(Builder.saveIP()));
}
+
return;
}
@@ -7201,20 +7241,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
// TODO: Use correct DynCGGroupMem
Value *DynCGGroupMem = Builder.getInt32(0);
- OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, NumIterations,
- NumTeamsC, NumThreadsC, DynCGGroupMem,
- HasNoWait);
+ KArgs = OpenMPIRBuilder::TargetKernelArgs(
+ NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
+ DynCGGroupMem, HasNoWait);
// The presence of certain clauses on the target directive require the
// explicit generation of the target task.
if (RequiresOuterTargetTask) {
Builder.restoreIP(OMPBuilder.emitTargetTask(
- OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID,
- RTLoc, AllocaIP, Dependencies, HasNoWait));
+ TaskBodyCB, DeviceID, RTLoc, AllocaIP, Dependencies, HasNoWait));
} else {
Builder.restoreIP(OMPBuilder.emitKernelLaunch(
- Builder, OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs,
- DeviceID, RTLoc, AllocaIP));
+ Builder, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID, RTLoc,
+ AllocaIP));
}
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 7c45e89cd8ac4b..27cd38dc3c62d9 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2886,6 +2886,8 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
DataLayout DL = DataLayout(op->getParentOfType<ModuleOp>());
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+ llvm::OpenMPIRBuilder::TargetDataInfo info(/*RequiresDevicePointerInfo=*/true,
+ /*SeparateBeginEndCalls=*/true);
LogicalResult result =
llvm::TypeSwitch<Operation *, LogicalResult>(op)
@@ -2905,9 +2907,9 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
return success();
})
.Case([&](omp::TargetEnterDataOp enterDataOp) {
- if (enterDataOp.getNowait())
+ if (!enterDataOp.getDependVars().empty())
return (LogicalResult)(enterDataOp.emitError(
- "`nowait` is not supported yet"));
+ "`depend` is not supported yet"));
if (auto ifVar = enterDataOp.getIfExpr())
ifCond = moduleTranslation.lookupValue(ifVar);
@@ -2917,14 +2919,18 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();
- RTLFn = llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
+ RTLFn =
+ enterDataOp.getNowait()
+ ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
+ : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
mapVars = enterDataOp.getMapVars();
+ info.HasNoWait = enterDataOp.getNowait();
return success();
})
.Case([&](omp::TargetExitDataOp exitDataOp) {
- if (exitDataOp.getNowait())
+ if (!exitDataOp.getDependVars().empty())
return (LogicalResult)(exitDataOp.emitError(
- "`nowait` is not supported yet"));
+ "`depend` is not supported yet"));
if (auto ifVar = exitDataOp.getIfExpr())
ifCond = moduleTranslation.lookupValue(ifVar);
@@ -2935,14 +2941,17 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();
- RTLFn = llvm::omp::OMPRTL___tgt_target_data_end_mapper;
+ RTLFn = exitDataOp.getNowait()
+ ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
+ : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
mapVars = exitDataOp.getMapVars();
+ info.HasNoWait = exitDataOp.getNowait();
return success();
})
.Case([&](omp::TargetUpdateOp updateDataOp) {
- if (updateDataOp.getNowait())
+ if (!updateDataOp.getDependVars().empty())
return (LogicalResult)(updateDataOp.emitError(
- "`nowait` is not supported yet"));
+ "`depend` is not supported yet"));
if (auto ifVar = updateDataOp.getIfExpr())
ifCond = moduleTranslation.lookupValue(ifVar);
@@ -2953,8 +2962,12 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();
- RTLFn = llvm::omp::OMPRTL___tgt_target_data_update_mapper;
+ RTLFn =
+ updateDataOp.getNowait()
+ ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
+ : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
mapVars = updateDataOp.getMapVars();
+ info.HasNoWait = updateDataOp.getNowait();
return success();
})
.Default([&](Operation *op) {
@@ -3005,9 +3018,6 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
: basePointer);
};
- llvm::OpenMPIRBuilder::TargetDataInfo info(/*RequiresDevicePointerInfo=*/true,
- /*SeparateBeginEndCalls=*/true);
-
using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
LogicalResult bodyGenStatus = success();
auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType) {
diff --git a/mlir/test/Target/LLVMIR/omptarget-nowait-unsupported-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-nowait-unsupported-llvm.mlir
deleted file mode 100644
index 1e2fbe86d13c47..00000000000000
--- a/mlir/test/Target/LLVMIR/omptarget-nowait-unsupported-llvm.mlir
+++ /dev/null
@@ -1,39 +0,0 @@
-// RUN: not mlir-translate -mlir-to-llvmir -split-input-file %s 2>&1 | FileCheck %s
-
-llvm.func @_QPopenmp_target_data_update() {
- %0 = llvm.mlir.constant(1 : i64) : i64
- %1 = llvm.alloca %0 x i32 {bindc_name = "i", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFopenmp_target_dataEi"} : (i64) -> !llvm.ptr
- %2 = o...
[truncated]
|
@llvm/pr-subscribers-flang-openmp
Author: Kareem Ergawy (ergawy)
Changes: Extends …
Patch is 26.66 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/113305.diff
5 Files Affected:
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 8834c3b1f50115..d71712a677078c 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2264,6 +2264,9 @@ class OpenMPIRBuilder {
bool EmitDebug = false;
+ /// Whether the `target ... data` directive has a `nowait` clause.
+ bool HasNoWait = false;
+
explicit TargetDataInfo() {}
explicit TargetDataInfo(bool RequiresDevicePointerInfo,
bool SeparateBeginEndCalls)
@@ -2342,7 +2345,6 @@ class OpenMPIRBuilder {
/// Generate a target region entry call and host fallback call.
///
/// \param Loc The location at which the request originated and is fulfilled.
- /// \param OutlinedFn The outlined kernel function.
/// \param OutlinedFnID The ooulined function ID.
/// \param EmitTargetCallFallbackCB Call back function to generate host
/// fallback code.
@@ -2350,18 +2352,27 @@ class OpenMPIRBuilder {
/// \param DeviceID Identifier for the device via the 'device' clause.
/// \param RTLoc Source location identifier
/// \param AllocaIP The insertion point to be used for alloca instructions.
- InsertPointTy emitKernelLaunch(
- const LocationDescription &Loc, Function *OutlinedFn, Value *OutlinedFnID,
- EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
- Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP);
+ InsertPointTy
+ emitKernelLaunch(const LocationDescription &Loc, Value *OutlinedFnID,
+ EmitFallbackCallbackTy EmitTargetCallFallbackCB,
+ TargetKernelArgs &Args, Value *DeviceID, Value *RTLoc,
+ InsertPointTy AllocaIP);
+
+ /// Callback type for generating the bodies of device directives that require
+ /// outer tasks (e.g. in case of having `nowait` or `depend` clauses).
+ ///
+ /// \param DeviceID The ID of the device on which the target region will
+ /// execute.
+ /// \param RTLoc Source location identifier
+ /// \param TargetTaskAllocaIP Insertion point for the alloca block of the
+ /// generated task.
+ using TaskBodyCallbackTy =
+ function_ref<void(Value *DeviceID, Value *RTLoc,
+ IRBuilderBase::InsertPoint TargetTaskAllocaIP)>;
/// Generate a target-task for the target construct
///
- /// \param OutlinedFn The outlined device/target kernel function.
- /// \param OutlinedFnID The ooulined function ID.
- /// \param EmitTargetCallFallbackCB Call back function to generate host
- /// fallback code.
- /// \param Args Data structure holding information about the kernel arguments.
+ /// \param TaskBodyCB Callback to generate the actual body of the target task.
/// \param DeviceID Identifier for the device via the 'device' clause.
/// \param RTLoc Source location identifier
/// \param AllocaIP The insertion point to be used for alloca instructions.
@@ -2370,10 +2381,10 @@ class OpenMPIRBuilder {
/// \param HasNoWait True if the target construct had 'nowait' on it, false
/// otherwise
InsertPointTy emitTargetTask(
- Function *OutlinedFn, Value *OutlinedFnID,
- EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
- Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP,
- SmallVector<OpenMPIRBuilder::DependData> &Dependencies, bool HasNoWait);
+ TaskBodyCallbackTy TaskBodyCB, Value *DeviceID, Value *RTLoc,
+ OpenMPIRBuilder::InsertPointTy AllocaIP,
+ const SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
+ bool HasNoWait);
/// Emit the arguments to be passed to the runtime library based on the
/// arrays of base pointers, pointers, sizes, map types, and mappers. If
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 172812a3802d33..809ba3aad2982e 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -1080,8 +1080,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetKernel(
}
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitKernelLaunch(
- const LocationDescription &Loc, Function *OutlinedFn, Value *OutlinedFnID,
- EmitFallbackCallbackTy emitTargetCallFallbackCB, TargetKernelArgs &Args,
+ const LocationDescription &Loc, Value *OutlinedFnID,
+ EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP) {
if (!updateToLocation(Loc))
@@ -1134,7 +1134,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitKernelLaunch(
auto CurFn = Builder.GetInsertBlock()->getParent();
emitBlock(OffloadFailedBlock, CurFn);
- Builder.restoreIP(emitTargetCallFallbackCB(Builder.saveIP()));
+ Builder.restoreIP(EmitTargetCallFallbackCB(Builder.saveIP()));
emitBranch(OffloadContBlock);
emitBlock(OffloadContBlock, CurFn, /*IsFinished=*/true);
return Builder.saveIP();
@@ -1736,7 +1736,7 @@ void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
// - All code is inserted in the entry block of the current function.
static Value *emitTaskDependencies(
OpenMPIRBuilder &OMPBuilder,
- SmallVectorImpl<OpenMPIRBuilder::DependData> &Dependencies) {
+ const SmallVectorImpl<OpenMPIRBuilder::DependData> &Dependencies) {
// Early return if we have no dependencies to process
if (Dependencies.empty())
return nullptr;
@@ -6386,12 +6386,10 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
// closing of the region.
auto BeginThenGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
MapInfo = &GenMapInfoCB(Builder.saveIP());
- emitOffloadingArrays(AllocaIP, Builder.saveIP(), *MapInfo, Info,
- /*IsNonContiguous=*/true, DeviceAddrCB,
- CustomMapperCB);
-
TargetDataRTArgs RTArgs;
- emitOffloadingArraysArgument(Builder, RTArgs, Info);
+ emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info, RTArgs,
+ *MapInfo, /*IsNonContiguous=*/true,
+ /*ForEndCall=*/false);
// Emit the number of elements in the offloading arrays.
Value *PointerNum = Builder.getInt32(Info.NumberOfPtrs);
@@ -6403,16 +6401,45 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
}
- Value *OffloadingArgs[] = {SrcLocInfo, DeviceID,
- PointerNum, RTArgs.BasePointersArray,
- RTArgs.PointersArray, RTArgs.SizesArray,
- RTArgs.MapTypesArray, RTArgs.MapNamesArray,
- RTArgs.MappersArray};
+ SmallVector<llvm::Value *, 13> OffloadingArgs = {
+ SrcLocInfo, DeviceID,
+ PointerNum, RTArgs.BasePointersArray,
+ RTArgs.PointersArray, RTArgs.SizesArray,
+ RTArgs.MapTypesArray, RTArgs.MapNamesArray,
+ RTArgs.MappersArray};
if (IsStandAlone) {
assert(MapperFunc && "MapperFunc missing for standalone target data");
- Builder.CreateCall(getOrCreateRuntimeFunctionPtr(*MapperFunc),
- OffloadingArgs);
+
+ auto TaskBodyCB = [&](Value *, Value *, IRBuilderBase::InsertPoint) {
+ if (Info.HasNoWait) {
+ OffloadingArgs.push_back(llvm::Constant::getNullValue(Int32));
+ OffloadingArgs.push_back(llvm::Constant::getNullValue(VoidPtr));
+ OffloadingArgs.push_back(llvm::Constant::getNullValue(Int32));
+ OffloadingArgs.push_back(llvm::Constant::getNullValue(VoidPtr));
+ }
+
+ Builder.CreateCall(getOrCreateRuntimeFunctionPtr(*MapperFunc),
+ OffloadingArgs);
+
+ if (Info.HasNoWait) {
+ BasicBlock *OffloadContBlock =
+ BasicBlock::Create(Builder.getContext(), "omp_offload.cont");
+ auto *CurFn = Builder.GetInsertBlock()->getParent();
+ emitBranch(OffloadContBlock);
+ emitBlock(OffloadContBlock, CurFn, /*IsFinished=*/true);
+ Builder.restoreIP(Builder.saveIP());
+ }
+ };
+
+ bool RequiresOuterTargetTask = Info.HasNoWait;
+
+ if (!RequiresOuterTargetTask)
+ TaskBodyCB(/*DeviceID=*/nullptr, /*RTLoc=*/nullptr,
+ /*TargetTaskAllocaIP=*/{});
+ else
+ emitTargetTask(TaskBodyCB, DeviceID, SrcLocInfo, AllocaIP,
+ /*Dependencies=*/{}, Info.HasNoWait);
} else {
Function *BeginMapperFunc = getOrCreateRuntimeFunctionPtr(
omp::OMPRTL___tgt_target_data_begin_mapper);
@@ -6836,13 +6863,18 @@ static void emitTargetOutlinedFunction(
OMPBuilder.emitTargetRegionFunction(EntryInfo, GenerateOutlinedFunction,
IsOffloadEntry, OutlinedFn, OutlinedFnID);
}
+
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
- Function *OutlinedFn, Value *OutlinedFnID,
- EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
- Value *DeviceID, Value *RTLoc, OpenMPIRBuilder::InsertPointTy AllocaIP,
- SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
+ TaskBodyCallbackTy TaskBodyCB, Value *DeviceID, Value *RTLoc,
+ OpenMPIRBuilder::InsertPointTy AllocaIP,
+ const SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
bool HasNoWait) {
+ // The following explains the code-gen scenario for the `target` directive. A
+ // similar scenario is followed for other device-related directives (e.g.
+ // `target enter data`) but in a simpler fashion since we only need to emit a
+ // task that encapsulates the proper runtime call.
+ //
// When we arrive at this function, the target region itself has been
// outlined into the function OutlinedFn.
// So at ths point, for
@@ -6950,22 +6982,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
Builder.restoreIP(TargetTaskBodyIP);
- if (OutlinedFnID) {
- // emitKernelLaunch makes the necessary runtime call to offload the kernel.
- // We then outline all that code into a separate function
- // ('kernel_launch_function' in the pseudo code above). This function is
- // then called by the target task proxy function (see
- // '@.omp_target_task_proxy_func' in the pseudo code above)
- // "@.omp_target_task_proxy_func' is generated by
- // emitTargetTaskProxyFunction.
- Builder.restoreIP(emitKernelLaunch(Builder, OutlinedFn, OutlinedFnID,
- EmitTargetCallFallbackCB, Args, DeviceID,
- RTLoc, TargetTaskAllocaIP));
- } else {
- // When OutlinedFnID is set to nullptr, then it's not an offloading call. In
- // this case, we execute the host implementation directly.
- Builder.restoreIP(EmitTargetCallFallbackCB(Builder.saveIP()));
- }
+ TaskBodyCB(DeviceID, RTLoc, TargetTaskAllocaIP);
OI.ExitBB = Builder.saveIP().getBlock();
OI.PostOutlineCB = [this, ToBeDeleted, Dependencies, HasNoWait,
@@ -7153,6 +7170,29 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
bool HasDependencies = Dependencies.size() > 0;
bool RequiresOuterTargetTask = HasNoWait || HasDependencies;
+ OpenMPIRBuilder::TargetKernelArgs KArgs;
+
+ auto TaskBodyCB = [&](Value *DeviceID, Value *RTLoc,
+ IRBuilderBase::InsertPoint TargetTaskAllocaIP) {
+ if (OutlinedFnID) {
+ // emitKernelLaunch makes the necessary runtime call to offload the
+ // kernel. We then outline all that code into a separate function
+ // ('kernel_launch_function' in the pseudo code above). This function is
+ // then called by the target task proxy function (see
+ // '@.omp_target_task_proxy_func' in the pseudo code above)
+ // "@.omp_target_task_proxy_func' is generated by
+ // emitTargetTaskProxyFunction.
+ Builder.restoreIP(OMPBuilder.emitKernelLaunch(
+ Builder, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID,
+ RTLoc, TargetTaskAllocaIP));
+ } else {
+ // When OutlinedFnID is set to nullptr, then it's not an offloading
+ // call. In this case, we execute the host implementation directly.
+ OMPBuilder.Builder.restoreIP(
+ EmitTargetCallFallbackCB(OMPBuilder.Builder.saveIP()));
+ }
+ };
+
// If we don't have an ID for the target region, it means an offload entry
// wasn't created. In this case we just run the host fallback directly.
if (!OutlinedFnID) {
@@ -7160,14 +7200,14 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
// Arguments that are intended to be directly forwarded to an
// emitKernelLaunch call are pased as nullptr, since OutlinedFnID=nullptr
// results in that call not being done.
- OpenMPIRBuilder::TargetKernelArgs KArgs;
- Builder.restoreIP(OMPBuilder.emitTargetTask(
- OutlinedFn, /*OutlinedFnID=*/nullptr, EmitTargetCallFallbackCB, KArgs,
- /*DeviceID=*/nullptr, /*RTLoc=*/nullptr, AllocaIP, Dependencies,
- HasNoWait));
+ Builder.restoreIP(OMPBuilder.emitTargetTask(TaskBodyCB,
+ /*DeviceID=*/nullptr,
+ /*RTLoc=*/nullptr, AllocaIP,
+ Dependencies, HasNoWait));
} else {
Builder.restoreIP(EmitTargetCallFallbackCB(Builder.saveIP()));
}
+
return;
}
@@ -7201,20 +7241,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
// TODO: Use correct DynCGGroupMem
Value *DynCGGroupMem = Builder.getInt32(0);
- OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, NumIterations,
- NumTeamsC, NumThreadsC, DynCGGroupMem,
- HasNoWait);
+ KArgs = OpenMPIRBuilder::TargetKernelArgs(
+ NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
+ DynCGGroupMem, HasNoWait);
// The presence of certain clauses on the target directive require the
// explicit generation of the target task.
if (RequiresOuterTargetTask) {
Builder.restoreIP(OMPBuilder.emitTargetTask(
- OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID,
- RTLoc, AllocaIP, Dependencies, HasNoWait));
+ TaskBodyCB, DeviceID, RTLoc, AllocaIP, Dependencies, HasNoWait));
} else {
Builder.restoreIP(OMPBuilder.emitKernelLaunch(
- Builder, OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs,
- DeviceID, RTLoc, AllocaIP));
+ Builder, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID, RTLoc,
+ AllocaIP));
}
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 7c45e89cd8ac4b..27cd38dc3c62d9 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2886,6 +2886,8 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
DataLayout DL = DataLayout(op->getParentOfType<ModuleOp>());
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+ llvm::OpenMPIRBuilder::TargetDataInfo info(/*RequiresDevicePointerInfo=*/true,
+ /*SeparateBeginEndCalls=*/true);
LogicalResult result =
llvm::TypeSwitch<Operation *, LogicalResult>(op)
@@ -2905,9 +2907,9 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
return success();
})
.Case([&](omp::TargetEnterDataOp enterDataOp) {
- if (enterDataOp.getNowait())
+ if (!enterDataOp.getDependVars().empty())
return (LogicalResult)(enterDataOp.emitError(
- "`nowait` is not supported yet"));
+ "`depend` is not supported yet"));
if (auto ifVar = enterDataOp.getIfExpr())
ifCond = moduleTranslation.lookupValue(ifVar);
@@ -2917,14 +2919,18 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();
- RTLFn = llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
+ RTLFn =
+ enterDataOp.getNowait()
+ ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
+ : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
mapVars = enterDataOp.getMapVars();
+ info.HasNoWait = enterDataOp.getNowait();
return success();
})
.Case([&](omp::TargetExitDataOp exitDataOp) {
- if (exitDataOp.getNowait())
+ if (!exitDataOp.getDependVars().empty())
return (LogicalResult)(exitDataOp.emitError(
- "`nowait` is not supported yet"));
+ "`depend` is not supported yet"));
if (auto ifVar = exitDataOp.getIfExpr())
ifCond = moduleTranslation.lookupValue(ifVar);
@@ -2935,14 +2941,17 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();
- RTLFn = llvm::omp::OMPRTL___tgt_target_data_end_mapper;
+ RTLFn = exitDataOp.getNowait()
+ ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
+ : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
mapVars = exitDataOp.getMapVars();
+ info.HasNoWait = exitDataOp.getNowait();
return success();
})
.Case([&](omp::TargetUpdateOp updateDataOp) {
- if (updateDataOp.getNowait())
+ if (!updateDataOp.getDependVars().empty())
return (LogicalResult)(updateDataOp.emitError(
- "`nowait` is not supported yet"));
+ "`depend` is not supported yet"));
if (auto ifVar = updateDataOp.getIfExpr())
ifCond = moduleTranslation.lookupValue(ifVar);
@@ -2953,8 +2962,12 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();
- RTLFn = llvm::omp::OMPRTL___tgt_target_data_update_mapper;
+ RTLFn =
+ updateDataOp.getNowait()
+ ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
+ : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
mapVars = updateDataOp.getMapVars();
+ info.HasNoWait = updateDataOp.getNowait();
return success();
})
.Default([&](Operation *op) {
@@ -3005,9 +3018,6 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
: basePointer);
};
- llvm::OpenMPIRBuilder::TargetDataInfo info(/*RequiresDevicePointerInfo=*/true,
- /*SeparateBeginEndCalls=*/true);
-
using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
LogicalResult bodyGenStatus = success();
auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType) {
diff --git a/mlir/test/Target/LLVMIR/omptarget-nowait-unsupported-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-nowait-unsupported-llvm.mlir
deleted file mode 100644
index 1e2fbe86d13c47..00000000000000
--- a/mlir/test/Target/LLVMIR/omptarget-nowait-unsupported-llvm.mlir
+++ /dev/null
@@ -1,39 +0,0 @@
-// RUN: not mlir-translate -mlir-to-llvmir -split-input-file %s 2>&1 | FileCheck %s
-
-llvm.func @_QPopenmp_target_data_update() {
- %0 = llvm.mlir.constant(1 : i64) : i64
- %1 = llvm.alloca %0 x i32 {bindc_name = "i", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFopenmp_target_dataEi"} : (i64) -> !llvm.ptr
- %2 = o...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@@ -2905,9 +2907,9 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, | |||
return success(); | |||
}) | |||
.Case([&](omp::TargetEnterDataOp enterDataOp) { | |||
if (enterDataOp.getNowait()) | |||
if (!enterDataOp.getDependVars().empty()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Handling of `depend` looks unrelated to handling `nowait`. Since it reuses code, I don't mind in this case.
acd617a
to
fddc36e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A couple of minor nits, that's all.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks for this.
83088c0
to
70a0c97
Compare
Extends `nowait` support for other device directives. This PR refactors the task generation utils used for the `target` directive so that they are general enough to be reused for other device directives as well.
70a0c97
to
52b5966
Compare
|
Extends `nowait` support for other device directives. This PR refactors the task generation utils used for the `target` directive so that they are general enough to be reused for other device directives as well.