Skip to content

Commit 3bdeb05

Browse files
committed
Revert "Refactor LoopFuseSiblingOp and support parallel fusion (llvm#94391)"
This reverts commit 6820b08.
1 parent 0cfd03a commit 3bdeb05

File tree

9 files changed

+283
-586
lines changed

9 files changed

+283
-586
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

+1-2
Original file line numberDiff line numberDiff line change
@@ -303,8 +303,7 @@ def ForallOp : SCF_Op<"forall", [
303303
DeclareOpInterfaceMethods<LoopLikeOpInterface,
304304
["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
305305
"getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps",
306-
"replaceWithAdditionalYields", "promoteIfSingleIteration",
307-
"yieldTiledValuesAndReplace"]>,
306+
"promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
308307
RecursiveMemoryEffects,
309308
SingleBlockImplicitTerminator<"scf::InParallelOp">,
310309
DeclareOpInterfaceMethods<RegionBranchOpInterface>,

mlir/include/mlir/Dialect/SCF/Utils/Utils.h

-20
Original file line numberDiff line numberDiff line change
@@ -181,16 +181,6 @@ Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes);
181181
void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
182182
scf::ForOp root);
183183

184-
//===----------------------------------------------------------------------===//
185-
// Fusion related helpers
186-
//===----------------------------------------------------------------------===//
187-
188-
/// Check structural compatibility between two loops such as iteration space
189-
/// and dominance.
190-
bool checkFusionStructuralLegality(LoopLikeOpInterface target,
191-
LoopLikeOpInterface source,
192-
Diagnostic &diag);
193-
194184
/// Given two scf.forall loops, `target` and `source`, fuses `target` into
195185
/// `source`. Assumes that the given loops are siblings and are independent of
196186
/// each other.
@@ -212,16 +202,6 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
212202
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
213203
RewriterBase &rewriter);
214204

215-
/// Given two scf.parallel loops, `target` and `source`, fuses `target` into
216-
/// `source`. Assumes that the given loops are siblings and are independent of
217-
/// each other.
218-
///
219-
/// This function does not perform any legality checks and simply fuses the
220-
/// loops. The caller is responsible for ensuring that the loops are legal to
221-
/// fuse.
222-
scf::ParallelOp fuseIndependentSiblingParallelLoops(scf::ParallelOp target,
223-
scf::ParallelOp source,
224-
RewriterBase &rewriter);
225205
} // namespace mlir
226206

227207
#endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_

mlir/include/mlir/Interfaces/LoopLikeInterface.h

-20
Original file line numberDiff line numberDiff line change
@@ -90,24 +90,4 @@ struct JamBlockGatherer {
9090
/// Include the generated interface declarations.
9191
#include "mlir/Interfaces/LoopLikeInterface.h.inc"
9292

93-
namespace mlir {
94-
/// A function that rewrites `target`'s terminator as a terminator obtained by
95-
/// fusing `source` into `target`.
96-
using FuseTerminatorFn =
97-
function_ref<void(RewriterBase &rewriter, LoopLikeOpInterface source,
98-
LoopLikeOpInterface &target, IRMapping mapping)>;
99-
100-
/// Returns a fused `LoopLikeOpInterface` created by fusing `source` to
101-
/// `target`. The `NewYieldValuesFn` callback is used to pass to the
102-
/// `replaceWithAdditionalYields` interface method to replace the loop with a
103-
/// new loop with (possibly) additional yields, while the `FuseTerminatorFn`
104-
/// callback is responsible for updating the fused loop terminator.
105-
LoopLikeOpInterface createFused(LoopLikeOpInterface target,
106-
LoopLikeOpInterface source,
107-
RewriterBase &rewriter,
108-
NewYieldValuesFn newYieldValuesFn,
109-
FuseTerminatorFn fuseTerminatorFn);
110-
111-
} // namespace mlir
112-
11393
#endif // MLIR_INTERFACES_LOOPLIKEINTERFACE_H_

mlir/lib/Dialect/SCF/IR/SCF.cpp

-38
Original file line numberDiff line numberDiff line change
@@ -618,44 +618,6 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point,
618618

619619
SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
620620

621-
FailureOr<LoopLikeOpInterface> ForallOp::replaceWithAdditionalYields(
622-
RewriterBase &rewriter, ValueRange newInitOperands,
623-
bool replaceInitOperandUsesInLoop,
624-
const NewYieldValuesFn &newYieldValuesFn) {
625-
// Create a new loop before the existing one, with the extra operands.
626-
OpBuilder::InsertionGuard g(rewriter);
627-
rewriter.setInsertionPoint(getOperation());
628-
SmallVector<Value> inits(getOutputs());
629-
llvm::append_range(inits, newInitOperands);
630-
scf::ForallOp newLoop = rewriter.create<scf::ForallOp>(
631-
getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(),
632-
inits, getMapping(),
633-
/*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
634-
635-
// Move the loop body to the new op.
636-
rewriter.mergeBlocks(getBody(), newLoop.getBody(),
637-
newLoop.getBody()->getArguments().take_front(
638-
getBody()->getNumArguments()));
639-
640-
if (replaceInitOperandUsesInLoop) {
641-
// Replace all uses of `newInitOperands` with the corresponding basic block
642-
// arguments.
643-
for (auto &&[newOperand, oldOperand] :
644-
llvm::zip(newInitOperands, newLoop.getBody()->getArguments().take_back(
645-
newInitOperands.size()))) {
646-
rewriter.replaceUsesWithIf(newOperand, oldOperand, [&](OpOperand &use) {
647-
Operation *user = use.getOwner();
648-
return newLoop->isProperAncestor(user);
649-
});
650-
}
651-
}
652-
653-
// Replace the old loop.
654-
rewriter.replaceOp(getOperation(),
655-
newLoop->getResults().take_front(getNumResults()));
656-
return cast<LoopLikeOpInterface>(newLoop.getOperation());
657-
}
658-
659621
/// Promotes the loop body of a forallOp to its containing block if it can be
660622
/// determined that the loop has a single iteration.
661623
LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

+119-21
Original file line numberDiff line numberDiff line change
@@ -261,10 +261,8 @@ loopScheduling(scf::ForOp forOp,
261261
return 1;
262262
};
263263

264-
std::optional<int64_t> ubConstant =
265-
getConstantIntValue(forOp.getUpperBound());
266-
std::optional<int64_t> lbConstant =
267-
getConstantIntValue(forOp.getLowerBound());
264+
std::optional<int64_t> ubConstant = getConstantIntValue(forOp.getUpperBound());
265+
std::optional<int64_t> lbConstant = getConstantIntValue(forOp.getLowerBound());
268266
DenseMap<Operation *, unsigned> opCycles;
269267
std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
270268
for (Operation &op : forOp.getBody()->getOperations()) {
@@ -449,6 +447,113 @@ void transform::TakeAssumedBranchOp::getEffects(
449447
// LoopFuseSiblingOp
450448
//===----------------------------------------------------------------------===//
451449

450+
/// Check if `target` and `source` are siblings, in the context that `target`
451+
/// is being fused into `source`.
452+
///
453+
/// This is a simple check that just checks if both operations are in the same
454+
/// block and some checks to ensure that the fused IR does not violate
455+
/// dominance.
456+
static DiagnosedSilenceableFailure isOpSibling(Operation *target,
457+
Operation *source) {
458+
// Check if both operations are the same.
459+
if (target == source)
460+
return emitSilenceableFailure(source)
461+
<< "target and source need to be different loops";
462+
463+
// Check if both operations are in the same block.
464+
if (target->getBlock() != source->getBlock())
465+
return emitSilenceableFailure(source)
466+
<< "target and source are not in the same block";
467+
468+
// Check if fusion will violate dominance.
469+
DominanceInfo domInfo(source);
470+
if (target->isBeforeInBlock(source)) {
471+
// Since `target` is before `source`, all users of results of `target`
472+
// need to be dominated by `source`.
473+
for (Operation *user : target->getUsers()) {
474+
if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
475+
return emitSilenceableFailure(target)
476+
<< "user of results of target should be properly dominated by "
477+
"source";
478+
}
479+
}
480+
} else {
481+
// Since `target` is after `source`, all values used by `target` need
482+
// to dominate `source`.
483+
484+
// Check if the defining ops of `target`'s operands properly dominate `source`.
485+
for (Value operand : target->getOperands()) {
486+
Operation *operandOp = operand.getDefiningOp();
487+
// Operands without defining operations are block arguments. When `target`
488+
// and `source` occur in the same block, these operands dominate `source`.
489+
if (!operandOp)
490+
continue;
491+
492+
// Operand's defining operation should properly dominate `source`.
493+
if (!domInfo.properlyDominates(operandOp, source,
494+
/*enclosingOpOk=*/false))
495+
return emitSilenceableFailure(target)
496+
<< "operands of target should be properly dominated by source";
497+
}
498+
499+
// Check if values used inside the regions of `target` properly dominate `source`.
500+
bool failed = false;
501+
OpOperand *failedValue = nullptr;
502+
visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
503+
Operation *operandOp = operand->get().getDefiningOp();
504+
if (operandOp && !domInfo.properlyDominates(operandOp, source,
505+
/*enclosingOpOk=*/false)) {
506+
// `operand` is not an argument of an enclosing block and the defining
507+
// op of `operand` is outside `target` but does not dominate `source`.
508+
failed = true;
509+
failedValue = operand;
510+
}
511+
});
512+
513+
if (failed)
514+
return emitSilenceableFailure(failedValue->getOwner())
515+
<< "values used inside regions of target should be properly "
516+
"dominated by source";
517+
}
518+
519+
return DiagnosedSilenceableFailure::success();
520+
}
521+
522+
/// Check if `target` scf.forall can be fused into `source` scf.forall.
523+
///
524+
/// This simply checks if both loops have the same bounds, steps and mapping.
525+
/// No attempt is made at checking that the side effects of `target` and
526+
/// `source` are independent of each other.
527+
static bool isForallWithIdenticalConfiguration(Operation *target,
528+
Operation *source) {
529+
auto targetOp = dyn_cast<scf::ForallOp>(target);
530+
auto sourceOp = dyn_cast<scf::ForallOp>(source);
531+
if (!targetOp || !sourceOp)
532+
return false;
533+
534+
return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
535+
targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
536+
targetOp.getMixedStep() == sourceOp.getMixedStep() &&
537+
targetOp.getMapping() == sourceOp.getMapping();
538+
}
539+
540+
/// Check if `target` scf.for can be fused into `source` scf.for.
541+
///
542+
/// This simply checks if both loops have the same bounds and steps. No attempt
543+
/// is made at checking that the side effects of `target` and `source` are
544+
/// independent of each other.
545+
static bool isForWithIdenticalConfiguration(Operation *target,
546+
Operation *source) {
547+
auto targetOp = dyn_cast<scf::ForOp>(target);
548+
auto sourceOp = dyn_cast<scf::ForOp>(source);
549+
if (!targetOp || !sourceOp)
550+
return false;
551+
552+
return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
553+
targetOp.getUpperBound() == sourceOp.getUpperBound() &&
554+
targetOp.getStep() == sourceOp.getStep();
555+
}
556+
452557
DiagnosedSilenceableFailure
453558
transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
454559
transform::TransformResults &results,
@@ -464,32 +569,25 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
464569
<< "source handle (got " << llvm::range_size(sourceOps) << ")";
465570
}
466571

467-
auto target = dyn_cast<LoopLikeOpInterface>(*targetOps.begin());
468-
auto source = dyn_cast<LoopLikeOpInterface>(*sourceOps.begin());
469-
if (!target || !source)
470-
return emitSilenceableFailure(target->getLoc())
471-
<< "target or source is not a loop op";
572+
Operation *target = *targetOps.begin();
573+
Operation *source = *sourceOps.begin();
472574

473-
// Check if loops can be fused
474-
Diagnostic diag(target.getLoc(), DiagnosticSeverity::Error);
475-
if (!mlir::checkFusionStructuralLegality(target, source, diag))
476-
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
575+
// Check if the target and source are siblings.
576+
DiagnosedSilenceableFailure diag = isOpSibling(target, source);
577+
if (!diag.succeeded())
578+
return diag;
477579

478580
Operation *fusedLoop;
479-
// TODO: Support fusion for loop-like ops besides scf.for, scf.forall
480-
// and scf.parallel.
481-
if (isa<scf::ForOp>(target) && isa<scf::ForOp>(source)) {
581+
/// TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
582+
if (isForWithIdenticalConfiguration(target, source)) {
482583
fusedLoop = fuseIndependentSiblingForLoops(
483584
cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
484-
} else if (isa<scf::ForallOp>(target) && isa<scf::ForallOp>(source)) {
585+
} else if (isForallWithIdenticalConfiguration(target, source)) {
485586
fusedLoop = fuseIndependentSiblingForallLoops(
486587
cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
487-
} else if (isa<scf::ParallelOp>(target) && isa<scf::ParallelOp>(source)) {
488-
fusedLoop = fuseIndependentSiblingParallelLoops(
489-
cast<scf::ParallelOp>(target), cast<scf::ParallelOp>(source), rewriter);
490588
} else
491589
return emitSilenceableFailure(target->getLoc())
492-
<< "unsupported loop type for fusion";
590+
<< "operations cannot be fused";
493591

494592
assert(fusedLoop && "failed to fuse operations");
495593

mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp

+74-6
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1717
#include "mlir/Dialect/SCF/IR/SCF.h"
1818
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
19-
#include "mlir/Dialect/SCF/Utils/Utils.h"
2019
#include "mlir/IR/Builders.h"
2120
#include "mlir/IR/IRMapping.h"
2221
#include "mlir/IR/OpDefinition.h"
@@ -38,6 +37,24 @@ static bool hasNestedParallelOp(ParallelOp ploop) {
3837
return walkResult.wasInterrupted();
3938
}
4039

40+
/// Verify equal iteration spaces.
41+
static bool equalIterationSpaces(ParallelOp firstPloop,
42+
ParallelOp secondPloop) {
43+
if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
44+
return false;
45+
46+
auto matchOperands = [&](const OperandRange &lhs,
47+
const OperandRange &rhs) -> bool {
48+
// TODO: Extend this to support aliases and equal constants.
49+
return std::equal(lhs.begin(), lhs.end(), rhs.begin());
50+
};
51+
return matchOperands(firstPloop.getLowerBound(),
52+
secondPloop.getLowerBound()) &&
53+
matchOperands(firstPloop.getUpperBound(),
54+
secondPloop.getUpperBound()) &&
55+
matchOperands(firstPloop.getStep(), secondPloop.getStep());
56+
}
57+
4158
/// Checks if the parallel loops have mixed access to the same buffers. Returns
4259
/// `true` if the first parallel loop writes to the same indices that the second
4360
/// loop reads.
@@ -136,10 +153,9 @@ verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
136153
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
137154
const IRMapping &firstToSecondPloopIndices,
138155
llvm::function_ref<bool(Value, Value)> mayAlias) {
139-
Diagnostic diag(firstPloop.getLoc(), DiagnosticSeverity::Remark);
140156
return !hasNestedParallelOp(firstPloop) &&
141157
!hasNestedParallelOp(secondPloop) &&
142-
checkFusionStructuralLegality(firstPloop, secondPloop, diag) &&
158+
equalIterationSpaces(firstPloop, secondPloop) &&
143159
succeeded(verifyDependencies(firstPloop, secondPloop,
144160
firstToSecondPloopIndices, mayAlias));
145161
}
@@ -158,9 +174,61 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
158174
mayAlias))
159175
return;
160176

161-
IRRewriter rewriter(builder);
162-
secondPloop = mlir::fuseIndependentSiblingParallelLoops(
163-
firstPloop, secondPloop, rewriter);
177+
DominanceInfo dom;
178+
// We are fusing first loop into second, make sure there are no users of the
179+
// first loop results between loops.
180+
for (Operation *user : firstPloop->getUsers())
181+
if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
182+
return;
183+
184+
ValueRange inits1 = firstPloop.getInitVals();
185+
ValueRange inits2 = secondPloop.getInitVals();
186+
187+
SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
188+
newInitVars.append(inits2.begin(), inits2.end());
189+
190+
IRRewriter b(builder);
191+
b.setInsertionPoint(secondPloop);
192+
auto newSecondPloop = b.create<ParallelOp>(
193+
secondPloop.getLoc(), secondPloop.getLowerBound(),
194+
secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
195+
196+
Block *newBlock = newSecondPloop.getBody();
197+
auto term1 = cast<ReduceOp>(block1->getTerminator());
198+
auto term2 = cast<ReduceOp>(block2->getTerminator());
199+
200+
b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
201+
newBlock->getArguments());
202+
b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
203+
newBlock->getArguments());
204+
205+
ValueRange results = newSecondPloop.getResults();
206+
if (!results.empty()) {
207+
b.setInsertionPointToEnd(newBlock);
208+
209+
ValueRange reduceArgs1 = term1.getOperands();
210+
ValueRange reduceArgs2 = term2.getOperands();
211+
SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
212+
newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
213+
214+
auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
215+
216+
for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
217+
term1.getReductions(), term2.getReductions()))) {
218+
Block &oldRedBlock = reg.front();
219+
Block &newRedBlock = newReduceOp.getReductions()[i].front();
220+
b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
221+
newRedBlock.getArguments());
222+
}
223+
224+
firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
225+
secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
226+
}
227+
term1->erase();
228+
term2->erase();
229+
firstPloop.erase();
230+
secondPloop.erase();
231+
secondPloop = newSecondPloop;
164232
}
165233

166234
void mlir::scf::naivelyFuseParallelOps(

0 commit comments

Comments
 (0)