diff --git a/flang/lib/Semantics/rewrite-directives.cpp b/flang/lib/Semantics/rewrite-directives.cpp index c94d0f3855bee..245bc16763e9e 100644 --- a/flang/lib/Semantics/rewrite-directives.cpp +++ b/flang/lib/Semantics/rewrite-directives.cpp @@ -44,6 +44,7 @@ class OmpRewriteMutator : public DirectiveRewriteMutator { bool Pre(parser::OpenMPAtomicConstruct &); bool Pre(parser::OpenMPRequiresConstruct &); + bool Pre(parser::ExecutableConstruct &); private: bool atomicDirectiveDefaultOrderFound_{false}; @@ -163,6 +164,66 @@ bool OmpRewriteMutator::Pre(parser::OpenMPRequiresConstruct &x) { return false; } +static void convertLoopControl(parser::LoopControl &loop) { + // Extract relevant information from concurrent loop control. + // TODO: Support LocalitySpec, IntegerTypeSpec, ScalarMaskExpr. + auto &concurrent{std::get(loop.u)}; + auto &header{std::get(concurrent.t)}; + // TODO: Support multiple ConcurrentControl: create multiple loops + collapse. + auto &control{ + std::get>(header.t).front()}; + auto &[name, lower, upper, step] = control.t; + + // Create needed information for the new loop control. + parser::ScalarName newName{std::move(name)}; + parser::Scalar newLower{ + common::Indirection{parser::Expr{std::move(lower.thing.thing.value())}}}; + parser::Scalar newUpper{ + common::Indirection{parser::Expr{std::move(upper.thing.thing.value())}}}; + std::optional>> newStep; + if (step) { + newStep = parser::Scalar{common::Indirection{ + parser::Expr{std::move(step->thing.thing.value())}}}; + } + + // Replace loop control. + loop.u = parser::LoopControl::Bounds{std::move(newName), std::move(newLower), + std::move(newUpper), std::move(newStep)}; +} + +// TODO: Investigate crashes after sema. Looks like symbols need updating too. +bool OmpRewriteMutator::Pre(parser::ExecutableConstruct &x) { + // TODO: Only mutate PFT if -fdo-concurrent-parallel is passed. Would have to + // pass that information separately or add it to SemanticsContext. + if (auto *doConstruct{ + std::get_if>(&x.u)}) { + if (doConstruct->value().IsDoConcurrent()) { + // Replace do-concurrent loop control. + auto &doStmt{std::get>( + doConstruct->value().t)}; + auto &loopControl{ + std::get>(doStmt.statement.t)}; + convertLoopControl(*loopControl); + + // Construct OpenMP loop PFT node. + // TODO: Select directive based on -fdo-concurrent-parallel instead. + parser::OmpLoopDirective ompLoopDir{ + llvm::omp::Directive::OMPD_parallel_do}; + std::list clauses; + parser::OmpClauseList ompClauses{std::move(clauses)}; + parser::OmpBeginLoopDirective beginLoopDir{ + std::move(ompLoopDir), std::move(ompClauses)}; + parser::OpenMPLoopConstruct loopConstruct{std::move(beginLoopDir)}; + std::get<1>(loopConstruct.t) = std::move(doConstruct->value()); + + // Replace original loop with constructed OpenMP loop. + x.u = common::Indirection{ + std::move(parser::OpenMPConstruct{std::move(loopConstruct)})}; + } + } + return true; +} + bool RewriteOmpParts(SemanticsContext &context, parser::Program &program) { if (!context.IsEnabled(common::LanguageFeature::OpenMP)) { return true;