Skip to content

Commit 6820b08

Browse files
authored
Refactor LoopFuseSiblingOp and support parallel fusion (llvm#94391)
This patch refactors code related to `LoopFuseSiblingOp` transform in attempt to reduce duplicate common code. The aim is to refactor as much as possible to a functions on `LoopLikeOpInterface`s, but this is still a work in progress. A full refactor will require more additions to the `LoopLikeOpInterface`. In addition, `scf.parallel` fusion support has been added.
1 parent 6c3897d commit 6820b08

File tree

9 files changed

+586
-283
lines changed

9 files changed

+586
-283
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

+2-1
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,8 @@ def ForallOp : SCF_Op<"forall", [
303303
DeclareOpInterfaceMethods<LoopLikeOpInterface,
304304
["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
305305
"getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps",
306-
"promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
306+
"replaceWithAdditionalYields", "promoteIfSingleIteration",
307+
"yieldTiledValuesAndReplace"]>,
307308
RecursiveMemoryEffects,
308309
SingleBlockImplicitTerminator<"scf::InParallelOp">,
309310
DeclareOpInterfaceMethods<RegionBranchOpInterface>,

mlir/include/mlir/Dialect/SCF/Utils/Utils.h

+20
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,16 @@ Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes);
181181
void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
182182
scf::ForOp root);
183183

184+
//===----------------------------------------------------------------------===//
185+
// Fusion related helpers
186+
//===----------------------------------------------------------------------===//
187+
188+
/// Check structural compatibility between two loops such as iteration space
189+
/// and dominance.
190+
bool checkFusionStructuralLegality(LoopLikeOpInterface target,
191+
LoopLikeOpInterface source,
192+
Diagnostic &diag);
193+
184194
/// Given two scf.forall loops, `target` and `source`, fuses `target` into
185195
/// `source`. Assumes that the given loops are siblings and are independent of
186196
/// each other.
@@ -202,6 +212,16 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
202212
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
203213
RewriterBase &rewriter);
204214

215+
/// Given two scf.parallel loops, `target` and `source`, fuses `target` into
216+
/// `source`. Assumes that the given loops are siblings and are independent of
217+
/// each other.
218+
///
219+
/// This function does not perform any legality checks and simply fuses the
220+
/// loops. The caller is responsible for ensuring that the loops are legal to
221+
/// fuse.
222+
scf::ParallelOp fuseIndependentSiblingParallelLoops(scf::ParallelOp target,
223+
scf::ParallelOp source,
224+
RewriterBase &rewriter);
205225
} // namespace mlir
206226

207227
#endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_

mlir/include/mlir/Interfaces/LoopLikeInterface.h

+20
Original file line numberDiff line numberDiff line change
@@ -90,4 +90,24 @@ struct JamBlockGatherer {
9090
/// Include the generated interface declarations.
9191
#include "mlir/Interfaces/LoopLikeInterface.h.inc"
9292

93+
namespace mlir {
94+
/// A function that rewrites `target`'s terminator as a teminator obtained by
95+
/// fusing `source` into `target`.
96+
using FuseTerminatorFn =
97+
function_ref<void(RewriterBase &rewriter, LoopLikeOpInterface source,
98+
LoopLikeOpInterface &target, IRMapping mapping)>;
99+
100+
/// Returns a fused `LoopLikeOpInterface` created by fusing `source` to
101+
/// `target`. The `NewYieldValuesFn` callback is used to pass to the
102+
/// `replaceWithAdditionalYields` interface method to replace the loop with a
103+
/// new loop with (possibly) additional yields, while the `FuseTerminatorFn`
104+
/// callback is repsonsible for updating the fused loop terminator.
105+
LoopLikeOpInterface createFused(LoopLikeOpInterface target,
106+
LoopLikeOpInterface source,
107+
RewriterBase &rewriter,
108+
NewYieldValuesFn newYieldValuesFn,
109+
FuseTerminatorFn fuseTerminatorFn);
110+
111+
} // namespace mlir
112+
93113
#endif // MLIR_INTERFACES_LOOPLIKEINTERFACE_H_

mlir/lib/Dialect/SCF/IR/SCF.cpp

+38
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,44 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point,
618618

619619
SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
620620

621+
FailureOr<LoopLikeOpInterface> ForallOp::replaceWithAdditionalYields(
622+
RewriterBase &rewriter, ValueRange newInitOperands,
623+
bool replaceInitOperandUsesInLoop,
624+
const NewYieldValuesFn &newYieldValuesFn) {
625+
// Create a new loop before the existing one, with the extra operands.
626+
OpBuilder::InsertionGuard g(rewriter);
627+
rewriter.setInsertionPoint(getOperation());
628+
SmallVector<Value> inits(getOutputs());
629+
llvm::append_range(inits, newInitOperands);
630+
scf::ForallOp newLoop = rewriter.create<scf::ForallOp>(
631+
getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(),
632+
inits, getMapping(),
633+
/*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
634+
635+
// Move the loop body to the new op.
636+
rewriter.mergeBlocks(getBody(), newLoop.getBody(),
637+
newLoop.getBody()->getArguments().take_front(
638+
getBody()->getNumArguments()));
639+
640+
if (replaceInitOperandUsesInLoop) {
641+
// Replace all uses of `newInitOperands` with the corresponding basic block
642+
// arguments.
643+
for (auto &&[newOperand, oldOperand] :
644+
llvm::zip(newInitOperands, newLoop.getBody()->getArguments().take_back(
645+
newInitOperands.size()))) {
646+
rewriter.replaceUsesWithIf(newOperand, oldOperand, [&](OpOperand &use) {
647+
Operation *user = use.getOwner();
648+
return newLoop->isProperAncestor(user);
649+
});
650+
}
651+
}
652+
653+
// Replace the old loop.
654+
rewriter.replaceOp(getOperation(),
655+
newLoop->getResults().take_front(getNumResults()));
656+
return cast<LoopLikeOpInterface>(newLoop.getOperation());
657+
}
658+
621659
/// Promotes the loop body of a forallOp to its containing block if it can be
622660
/// determined that the loop has a single iteration.
623661
LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

+21-119
Original file line numberDiff line numberDiff line change
@@ -261,8 +261,10 @@ loopScheduling(scf::ForOp forOp,
261261
return 1;
262262
};
263263

264-
std::optional<int64_t> ubConstant = getConstantIntValue(forOp.getUpperBound());
265-
std::optional<int64_t> lbConstant = getConstantIntValue(forOp.getLowerBound());
264+
std::optional<int64_t> ubConstant =
265+
getConstantIntValue(forOp.getUpperBound());
266+
std::optional<int64_t> lbConstant =
267+
getConstantIntValue(forOp.getLowerBound());
266268
DenseMap<Operation *, unsigned> opCycles;
267269
std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
268270
for (Operation &op : forOp.getBody()->getOperations()) {
@@ -447,113 +449,6 @@ void transform::TakeAssumedBranchOp::getEffects(
447449
// LoopFuseSiblingOp
448450
//===----------------------------------------------------------------------===//
449451

450-
/// Check if `target` and `source` are siblings, in the context that `target`
451-
/// is being fused into `source`.
452-
///
453-
/// This is a simple check that just checks if both operations are in the same
454-
/// block and some checks to ensure that the fused IR does not violate
455-
/// dominance.
456-
static DiagnosedSilenceableFailure isOpSibling(Operation *target,
457-
Operation *source) {
458-
// Check if both operations are same.
459-
if (target == source)
460-
return emitSilenceableFailure(source)
461-
<< "target and source need to be different loops";
462-
463-
// Check if both operations are in the same block.
464-
if (target->getBlock() != source->getBlock())
465-
return emitSilenceableFailure(source)
466-
<< "target and source are not in the same block";
467-
468-
// Check if fusion will violate dominance.
469-
DominanceInfo domInfo(source);
470-
if (target->isBeforeInBlock(source)) {
471-
// Since `target` is before `source`, all users of results of `target`
472-
// need to be dominated by `source`.
473-
for (Operation *user : target->getUsers()) {
474-
if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
475-
return emitSilenceableFailure(target)
476-
<< "user of results of target should be properly dominated by "
477-
"source";
478-
}
479-
}
480-
} else {
481-
// Since `target` is after `source`, all values used by `target` need
482-
// to dominate `source`.
483-
484-
// Check if operands of `target` are dominated by `source`.
485-
for (Value operand : target->getOperands()) {
486-
Operation *operandOp = operand.getDefiningOp();
487-
// Operands without defining operations are block arguments. When `target`
488-
// and `source` occur in the same block, these operands dominate `source`.
489-
if (!operandOp)
490-
continue;
491-
492-
// Operand's defining operation should properly dominate `source`.
493-
if (!domInfo.properlyDominates(operandOp, source,
494-
/*enclosingOpOk=*/false))
495-
return emitSilenceableFailure(target)
496-
<< "operands of target should be properly dominated by source";
497-
}
498-
499-
// Check if values used by `target` are dominated by `source`.
500-
bool failed = false;
501-
OpOperand *failedValue = nullptr;
502-
visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
503-
Operation *operandOp = operand->get().getDefiningOp();
504-
if (operandOp && !domInfo.properlyDominates(operandOp, source,
505-
/*enclosingOpOk=*/false)) {
506-
// `operand` is not an argument of an enclosing block and the defining
507-
// op of `operand` is outside `target` but does not dominate `source`.
508-
failed = true;
509-
failedValue = operand;
510-
}
511-
});
512-
513-
if (failed)
514-
return emitSilenceableFailure(failedValue->getOwner())
515-
<< "values used inside regions of target should be properly "
516-
"dominated by source";
517-
}
518-
519-
return DiagnosedSilenceableFailure::success();
520-
}
521-
522-
/// Check if `target` scf.forall can be fused into `source` scf.forall.
523-
///
524-
/// This simply checks if both loops have the same bounds, steps and mapping.
525-
/// No attempt is made at checking that the side effects of `target` and
526-
/// `source` are independent of each other.
527-
static bool isForallWithIdenticalConfiguration(Operation *target,
528-
Operation *source) {
529-
auto targetOp = dyn_cast<scf::ForallOp>(target);
530-
auto sourceOp = dyn_cast<scf::ForallOp>(source);
531-
if (!targetOp || !sourceOp)
532-
return false;
533-
534-
return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
535-
targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
536-
targetOp.getMixedStep() == sourceOp.getMixedStep() &&
537-
targetOp.getMapping() == sourceOp.getMapping();
538-
}
539-
540-
/// Check if `target` scf.for can be fused into `source` scf.for.
541-
///
542-
/// This simply checks if both loops have the same bounds and steps. No attempt
543-
/// is made at checking that the side effects of `target` and `source` are
544-
/// independent of each other.
545-
static bool isForWithIdenticalConfiguration(Operation *target,
546-
Operation *source) {
547-
auto targetOp = dyn_cast<scf::ForOp>(target);
548-
auto sourceOp = dyn_cast<scf::ForOp>(source);
549-
if (!targetOp || !sourceOp)
550-
return false;
551-
552-
return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
553-
targetOp.getUpperBound() == sourceOp.getUpperBound() &&
554-
targetOp.getStep() == sourceOp.getStep();
555-
}
556-
557452
DiagnosedSilenceableFailure
558453
transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
559454
transform::TransformResults &results,
@@ -569,25 +464,32 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
569464
<< "source handle (got " << llvm::range_size(sourceOps) << ")";
570465
}
571466

572-
Operation *target = *targetOps.begin();
573-
Operation *source = *sourceOps.begin();
467+
auto target = dyn_cast<LoopLikeOpInterface>(*targetOps.begin());
468+
auto source = dyn_cast<LoopLikeOpInterface>(*sourceOps.begin());
469+
if (!target || !source)
470+
return emitSilenceableFailure(target->getLoc())
471+
<< "target or source is not a loop op";
574472

575-
// Check if the target and source are siblings.
576-
DiagnosedSilenceableFailure diag = isOpSibling(target, source);
577-
if (!diag.succeeded())
578-
return diag;
473+
// Check if loops can be fused
474+
Diagnostic diag(target.getLoc(), DiagnosticSeverity::Error);
475+
if (!mlir::checkFusionStructuralLegality(target, source, diag))
476+
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
579477

580478
Operation *fusedLoop;
581-
/// TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
582-
if (isForWithIdenticalConfiguration(target, source)) {
479+
// TODO: Support fusion for loop-like ops besides scf.for, scf.forall
480+
// and scf.parallel.
481+
if (isa<scf::ForOp>(target) && isa<scf::ForOp>(source)) {
583482
fusedLoop = fuseIndependentSiblingForLoops(
584483
cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
585-
} else if (isForallWithIdenticalConfiguration(target, source)) {
484+
} else if (isa<scf::ForallOp>(target) && isa<scf::ForallOp>(source)) {
586485
fusedLoop = fuseIndependentSiblingForallLoops(
587486
cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
487+
} else if (isa<scf::ParallelOp>(target) && isa<scf::ParallelOp>(source)) {
488+
fusedLoop = fuseIndependentSiblingParallelLoops(
489+
cast<scf::ParallelOp>(target), cast<scf::ParallelOp>(source), rewriter);
588490
} else
589491
return emitSilenceableFailure(target->getLoc())
590-
<< "operations cannot be fused";
492+
<< "unsupported loop type for fusion";
591493

592494
assert(fusedLoop && "failed to fuse operations");
593495

mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp

+6-74
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1717
#include "mlir/Dialect/SCF/IR/SCF.h"
1818
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
19+
#include "mlir/Dialect/SCF/Utils/Utils.h"
1920
#include "mlir/IR/Builders.h"
2021
#include "mlir/IR/IRMapping.h"
2122
#include "mlir/IR/OpDefinition.h"
@@ -37,24 +38,6 @@ static bool hasNestedParallelOp(ParallelOp ploop) {
3738
return walkResult.wasInterrupted();
3839
}
3940

40-
/// Verify equal iteration spaces.
41-
static bool equalIterationSpaces(ParallelOp firstPloop,
42-
ParallelOp secondPloop) {
43-
if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
44-
return false;
45-
46-
auto matchOperands = [&](const OperandRange &lhs,
47-
const OperandRange &rhs) -> bool {
48-
// TODO: Extend this to support aliases and equal constants.
49-
return std::equal(lhs.begin(), lhs.end(), rhs.begin());
50-
};
51-
return matchOperands(firstPloop.getLowerBound(),
52-
secondPloop.getLowerBound()) &&
53-
matchOperands(firstPloop.getUpperBound(),
54-
secondPloop.getUpperBound()) &&
55-
matchOperands(firstPloop.getStep(), secondPloop.getStep());
56-
}
57-
5841
/// Checks if the parallel loops have mixed access to the same buffers. Returns
5942
/// `true` if the first parallel loop writes to the same indices that the second
6043
/// loop reads.
@@ -153,9 +136,10 @@ verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
153136
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
154137
const IRMapping &firstToSecondPloopIndices,
155138
llvm::function_ref<bool(Value, Value)> mayAlias) {
139+
Diagnostic diag(firstPloop.getLoc(), DiagnosticSeverity::Remark);
156140
return !hasNestedParallelOp(firstPloop) &&
157141
!hasNestedParallelOp(secondPloop) &&
158-
equalIterationSpaces(firstPloop, secondPloop) &&
142+
checkFusionStructuralLegality(firstPloop, secondPloop, diag) &&
159143
succeeded(verifyDependencies(firstPloop, secondPloop,
160144
firstToSecondPloopIndices, mayAlias));
161145
}
@@ -174,61 +158,9 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
174158
mayAlias))
175159
return;
176160

177-
DominanceInfo dom;
178-
// We are fusing first loop into second, make sure there are no users of the
179-
// first loop results between loops.
180-
for (Operation *user : firstPloop->getUsers())
181-
if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
182-
return;
183-
184-
ValueRange inits1 = firstPloop.getInitVals();
185-
ValueRange inits2 = secondPloop.getInitVals();
186-
187-
SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
188-
newInitVars.append(inits2.begin(), inits2.end());
189-
190-
IRRewriter b(builder);
191-
b.setInsertionPoint(secondPloop);
192-
auto newSecondPloop = b.create<ParallelOp>(
193-
secondPloop.getLoc(), secondPloop.getLowerBound(),
194-
secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
195-
196-
Block *newBlock = newSecondPloop.getBody();
197-
auto term1 = cast<ReduceOp>(block1->getTerminator());
198-
auto term2 = cast<ReduceOp>(block2->getTerminator());
199-
200-
b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
201-
newBlock->getArguments());
202-
b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
203-
newBlock->getArguments());
204-
205-
ValueRange results = newSecondPloop.getResults();
206-
if (!results.empty()) {
207-
b.setInsertionPointToEnd(newBlock);
208-
209-
ValueRange reduceArgs1 = term1.getOperands();
210-
ValueRange reduceArgs2 = term2.getOperands();
211-
SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
212-
newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
213-
214-
auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
215-
216-
for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
217-
term1.getReductions(), term2.getReductions()))) {
218-
Block &oldRedBlock = reg.front();
219-
Block &newRedBlock = newReduceOp.getReductions()[i].front();
220-
b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
221-
newRedBlock.getArguments());
222-
}
223-
224-
firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
225-
secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
226-
}
227-
term1->erase();
228-
term2->erase();
229-
firstPloop.erase();
230-
secondPloop.erase();
231-
secondPloop = newSecondPloop;
161+
IRRewriter rewriter(builder);
162+
secondPloop = mlir::fuseIndependentSiblingParallelLoops(
163+
firstPloop, secondPloop, rewriter);
232164
}
233165

234166
void mlir::scf::naivelyFuseParallelOps(

0 commit comments

Comments
 (0)