@@ -261,10 +261,8 @@ loopScheduling(scf::ForOp forOp,
261
261
return 1 ;
262
262
};
263
263
264
- std::optional<int64_t > ubConstant =
265
- getConstantIntValue (forOp.getUpperBound ());
266
- std::optional<int64_t > lbConstant =
267
- getConstantIntValue (forOp.getLowerBound ());
264
+ std::optional<int64_t > ubConstant = getConstantIntValue (forOp.getUpperBound ());
265
+ std::optional<int64_t > lbConstant = getConstantIntValue (forOp.getLowerBound ());
268
266
DenseMap<Operation *, unsigned > opCycles;
269
267
std::map<unsigned , std::vector<Operation *>> wrappedSchedule;
270
268
for (Operation &op : forOp.getBody ()->getOperations ()) {
@@ -449,6 +447,113 @@ void transform::TakeAssumedBranchOp::getEffects(
449
447
// LoopFuseSiblingOp
450
448
// ===----------------------------------------------------------------------===//
451
449
450
+ // / Check if `target` and `source` are siblings, in the context that `target`
451
+ // / is being fused into `source`.
452
+ // /
453
+ // / This is a simple check that just checks if both operations are in the same
454
+ // / block and some checks to ensure that the fused IR does not violate
455
+ // / dominance.
456
+ static DiagnosedSilenceableFailure isOpSibling (Operation *target,
457
+ Operation *source) {
458
+ // Check if both operations are same.
459
+ if (target == source)
460
+ return emitSilenceableFailure (source)
461
+ << " target and source need to be different loops" ;
462
+
463
+ // Check if both operations are in the same block.
464
+ if (target->getBlock () != source->getBlock ())
465
+ return emitSilenceableFailure (source)
466
+ << " target and source are not in the same block" ;
467
+
468
+ // Check if fusion will violate dominance.
469
+ DominanceInfo domInfo (source);
470
+ if (target->isBeforeInBlock (source)) {
471
+ // Since `target` is before `source`, all users of results of `target`
472
+ // need to be dominated by `source`.
473
+ for (Operation *user : target->getUsers ()) {
474
+ if (!domInfo.properlyDominates (source, user, /* enclosingOpOk=*/ false )) {
475
+ return emitSilenceableFailure (target)
476
+ << " user of results of target should be properly dominated by "
477
+ " source" ;
478
+ }
479
+ }
480
+ } else {
481
+ // Since `target` is after `source`, all values used by `target` need
482
+ // to dominate `source`.
483
+
484
+ // Check if operands of `target` are dominated by `source`.
485
+ for (Value operand : target->getOperands ()) {
486
+ Operation *operandOp = operand.getDefiningOp ();
487
+ // Operands without defining operations are block arguments. When `target`
488
+ // and `source` occur in the same block, these operands dominate `source`.
489
+ if (!operandOp)
490
+ continue ;
491
+
492
+ // Operand's defining operation should properly dominate `source`.
493
+ if (!domInfo.properlyDominates (operandOp, source,
494
+ /* enclosingOpOk=*/ false ))
495
+ return emitSilenceableFailure (target)
496
+ << " operands of target should be properly dominated by source" ;
497
+ }
498
+
499
+ // Check if values used by `target` are dominated by `source`.
500
+ bool failed = false ;
501
+ OpOperand *failedValue = nullptr ;
502
+ visitUsedValuesDefinedAbove (target->getRegions (), [&](OpOperand *operand) {
503
+ Operation *operandOp = operand->get ().getDefiningOp ();
504
+ if (operandOp && !domInfo.properlyDominates (operandOp, source,
505
+ /* enclosingOpOk=*/ false )) {
506
+ // `operand` is not an argument of an enclosing block and the defining
507
+ // op of `operand` is outside `target` but does not dominate `source`.
508
+ failed = true ;
509
+ failedValue = operand;
510
+ }
511
+ });
512
+
513
+ if (failed)
514
+ return emitSilenceableFailure (failedValue->getOwner ())
515
+ << " values used inside regions of target should be properly "
516
+ " dominated by source" ;
517
+ }
518
+
519
+ return DiagnosedSilenceableFailure::success ();
520
+ }
521
+
522
+ // / Check if `target` scf.forall can be fused into `source` scf.forall.
523
+ // /
524
+ // / This simply checks if both loops have the same bounds, steps and mapping.
525
+ // / No attempt is made at checking that the side effects of `target` and
526
+ // / `source` are independent of each other.
527
+ static bool isForallWithIdenticalConfiguration (Operation *target,
528
+ Operation *source) {
529
+ auto targetOp = dyn_cast<scf::ForallOp>(target);
530
+ auto sourceOp = dyn_cast<scf::ForallOp>(source);
531
+ if (!targetOp || !sourceOp)
532
+ return false ;
533
+
534
+ return targetOp.getMixedLowerBound () == sourceOp.getMixedLowerBound () &&
535
+ targetOp.getMixedUpperBound () == sourceOp.getMixedUpperBound () &&
536
+ targetOp.getMixedStep () == sourceOp.getMixedStep () &&
537
+ targetOp.getMapping () == sourceOp.getMapping ();
538
+ }
539
+
540
+ // / Check if `target` scf.for can be fused into `source` scf.for.
541
+ // /
542
+ // / This simply checks if both loops have the same bounds and steps. No attempt
543
+ // / is made at checking that the side effects of `target` and `source` are
544
+ // / independent of each other.
545
+ static bool isForWithIdenticalConfiguration (Operation *target,
546
+ Operation *source) {
547
+ auto targetOp = dyn_cast<scf::ForOp>(target);
548
+ auto sourceOp = dyn_cast<scf::ForOp>(source);
549
+ if (!targetOp || !sourceOp)
550
+ return false ;
551
+
552
+ return targetOp.getLowerBound () == sourceOp.getLowerBound () &&
553
+ targetOp.getUpperBound () == sourceOp.getUpperBound () &&
554
+ targetOp.getStep () == sourceOp.getStep ();
555
+ }
556
+
452
557
DiagnosedSilenceableFailure
453
558
transform::LoopFuseSiblingOp::apply (transform::TransformRewriter &rewriter,
454
559
transform::TransformResults &results,
@@ -464,32 +569,25 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
464
569
<< " source handle (got " << llvm::range_size (sourceOps) << " )" ;
465
570
}
466
571
467
- auto target = dyn_cast<LoopLikeOpInterface>(*targetOps.begin ());
468
- auto source = dyn_cast<LoopLikeOpInterface>(*sourceOps.begin ());
469
- if (!target || !source)
470
- return emitSilenceableFailure (target->getLoc ())
471
- << " target or source is not a loop op" ;
572
+ Operation *target = *targetOps.begin ();
573
+ Operation *source = *sourceOps.begin ();
472
574
473
- // Check if loops can be fused
474
- Diagnostic diag (target. getLoc (), DiagnosticSeverity::Error );
475
- if (!mlir::checkFusionStructuralLegality (target, source, diag))
476
- return DiagnosedSilenceableFailure::silenceableFailure ( std::move ( diag)) ;
575
+ // Check if the target and source are siblings.
576
+ DiagnosedSilenceableFailure diag = isOpSibling (target, source );
577
+ if (!diag. succeeded ( ))
578
+ return diag;
477
579
478
580
Operation *fusedLoop;
479
- // TODO: Support fusion for loop-like ops besides scf.for, scf.forall
480
- // and scf.parallel.
481
- if (isa<scf::ForOp>(target) && isa<scf::ForOp>(source)) {
581
+ // / TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
582
+ if (isForWithIdenticalConfiguration (target, source)) {
482
583
fusedLoop = fuseIndependentSiblingForLoops (
483
584
cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
484
- } else if (isa<scf::ForallOp> (target) && isa<scf::ForallOp>( source)) {
585
+ } else if (isForallWithIdenticalConfiguration (target, source)) {
485
586
fusedLoop = fuseIndependentSiblingForallLoops (
486
587
cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
487
- } else if (isa<scf::ParallelOp>(target) && isa<scf::ParallelOp>(source)) {
488
- fusedLoop = fuseIndependentSiblingParallelLoops (
489
- cast<scf::ParallelOp>(target), cast<scf::ParallelOp>(source), rewriter);
490
588
} else
491
589
return emitSilenceableFailure (target->getLoc ())
492
- << " unsupported loop type for fusion " ;
590
+ << " operations cannot be fused " ;
493
591
494
592
assert (fusedLoop && " failed to fuse operations" );
495
593
0 commit comments