@@ -261,8 +261,10 @@ loopScheduling(scf::ForOp forOp,
261
261
return 1 ;
262
262
};
263
263
264
- std::optional<int64_t > ubConstant = getConstantIntValue (forOp.getUpperBound ());
265
- std::optional<int64_t > lbConstant = getConstantIntValue (forOp.getLowerBound ());
264
+ std::optional<int64_t > ubConstant =
265
+ getConstantIntValue (forOp.getUpperBound ());
266
+ std::optional<int64_t > lbConstant =
267
+ getConstantIntValue (forOp.getLowerBound ());
266
268
DenseMap<Operation *, unsigned > opCycles;
267
269
std::map<unsigned , std::vector<Operation *>> wrappedSchedule;
268
270
for (Operation &op : forOp.getBody ()->getOperations ()) {
@@ -447,113 +449,6 @@ void transform::TakeAssumedBranchOp::getEffects(
447
449
// LoopFuseSiblingOp
448
450
// ===----------------------------------------------------------------------===//
449
451
450
- // / Check if `target` and `source` are siblings, in the context that `target`
451
- // / is being fused into `source`.
452
- // /
453
- // / This is a simple check that just checks if both operations are in the same
454
- // / block and some checks to ensure that the fused IR does not violate
455
- // / dominance.
456
- static DiagnosedSilenceableFailure isOpSibling (Operation *target,
457
- Operation *source) {
458
- // Check if both operations are same.
459
- if (target == source)
460
- return emitSilenceableFailure (source)
461
- << " target and source need to be different loops" ;
462
-
463
- // Check if both operations are in the same block.
464
- if (target->getBlock () != source->getBlock ())
465
- return emitSilenceableFailure (source)
466
- << " target and source are not in the same block" ;
467
-
468
- // Check if fusion will violate dominance.
469
- DominanceInfo domInfo (source);
470
- if (target->isBeforeInBlock (source)) {
471
- // Since `target` is before `source`, all users of results of `target`
472
- // need to be dominated by `source`.
473
- for (Operation *user : target->getUsers ()) {
474
- if (!domInfo.properlyDominates (source, user, /* enclosingOpOk=*/ false )) {
475
- return emitSilenceableFailure (target)
476
- << " user of results of target should be properly dominated by "
477
- " source" ;
478
- }
479
- }
480
- } else {
481
- // Since `target` is after `source`, all values used by `target` need
482
- // to dominate `source`.
483
-
484
- // Check if operands of `target` are dominated by `source`.
485
- for (Value operand : target->getOperands ()) {
486
- Operation *operandOp = operand.getDefiningOp ();
487
- // Operands without defining operations are block arguments. When `target`
488
- // and `source` occur in the same block, these operands dominate `source`.
489
- if (!operandOp)
490
- continue ;
491
-
492
- // Operand's defining operation should properly dominate `source`.
493
- if (!domInfo.properlyDominates (operandOp, source,
494
- /* enclosingOpOk=*/ false ))
495
- return emitSilenceableFailure (target)
496
- << " operands of target should be properly dominated by source" ;
497
- }
498
-
499
- // Check if values used by `target` are dominated by `source`.
500
- bool failed = false ;
501
- OpOperand *failedValue = nullptr ;
502
- visitUsedValuesDefinedAbove (target->getRegions (), [&](OpOperand *operand) {
503
- Operation *operandOp = operand->get ().getDefiningOp ();
504
- if (operandOp && !domInfo.properlyDominates (operandOp, source,
505
- /* enclosingOpOk=*/ false )) {
506
- // `operand` is not an argument of an enclosing block and the defining
507
- // op of `operand` is outside `target` but does not dominate `source`.
508
- failed = true ;
509
- failedValue = operand;
510
- }
511
- });
512
-
513
- if (failed)
514
- return emitSilenceableFailure (failedValue->getOwner ())
515
- << " values used inside regions of target should be properly "
516
- " dominated by source" ;
517
- }
518
-
519
- return DiagnosedSilenceableFailure::success ();
520
- }
521
-
522
- // / Check if `target` scf.forall can be fused into `source` scf.forall.
523
- // /
524
- // / This simply checks if both loops have the same bounds, steps and mapping.
525
- // / No attempt is made at checking that the side effects of `target` and
526
- // / `source` are independent of each other.
527
- static bool isForallWithIdenticalConfiguration (Operation *target,
528
- Operation *source) {
529
- auto targetOp = dyn_cast<scf::ForallOp>(target);
530
- auto sourceOp = dyn_cast<scf::ForallOp>(source);
531
- if (!targetOp || !sourceOp)
532
- return false ;
533
-
534
- return targetOp.getMixedLowerBound () == sourceOp.getMixedLowerBound () &&
535
- targetOp.getMixedUpperBound () == sourceOp.getMixedUpperBound () &&
536
- targetOp.getMixedStep () == sourceOp.getMixedStep () &&
537
- targetOp.getMapping () == sourceOp.getMapping ();
538
- }
539
-
540
- // / Check if `target` scf.for can be fused into `source` scf.for.
541
- // /
542
- // / This simply checks if both loops have the same bounds and steps. No attempt
543
- // / is made at checking that the side effects of `target` and `source` are
544
- // / independent of each other.
545
- static bool isForWithIdenticalConfiguration (Operation *target,
546
- Operation *source) {
547
- auto targetOp = dyn_cast<scf::ForOp>(target);
548
- auto sourceOp = dyn_cast<scf::ForOp>(source);
549
- if (!targetOp || !sourceOp)
550
- return false ;
551
-
552
- return targetOp.getLowerBound () == sourceOp.getLowerBound () &&
553
- targetOp.getUpperBound () == sourceOp.getUpperBound () &&
554
- targetOp.getStep () == sourceOp.getStep ();
555
- }
556
-
557
452
DiagnosedSilenceableFailure
558
453
transform::LoopFuseSiblingOp::apply (transform::TransformRewriter &rewriter,
559
454
transform::TransformResults &results,
@@ -569,25 +464,32 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
569
464
<< " source handle (got " << llvm::range_size (sourceOps) << " )" ;
570
465
}
571
466
572
- Operation *target = *targetOps.begin ();
573
- Operation *source = *sourceOps.begin ();
467
+ auto target = dyn_cast<LoopLikeOpInterface>(*targetOps.begin ());
468
+ auto source = dyn_cast<LoopLikeOpInterface>(*sourceOps.begin ());
469
+ if (!target || !source)
470
+ return emitSilenceableFailure (target->getLoc ())
471
+ << " target or source is not a loop op" ;
574
472
575
- // Check if the target and source are siblings.
576
- DiagnosedSilenceableFailure diag = isOpSibling (target, source );
577
- if (!diag. succeeded ( ))
578
- return diag;
473
+ // Check if loops can be fused
474
+ Diagnostic diag (target. getLoc (), DiagnosticSeverity::Error );
475
+ if (!mlir::checkFusionStructuralLegality (target, source, diag ))
476
+ return DiagnosedSilenceableFailure::silenceableFailure ( std::move ( diag)) ;
579
477
580
478
Operation *fusedLoop;
581
- // / TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
582
- if (isForWithIdenticalConfiguration (target, source)) {
479
+ // TODO: Support fusion for loop-like ops besides scf.for, scf.forall
480
+ // and scf.parallel.
481
+ if (isa<scf::ForOp>(target) && isa<scf::ForOp>(source)) {
583
482
fusedLoop = fuseIndependentSiblingForLoops (
584
483
cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
585
- } else if (isForallWithIdenticalConfiguration (target, source)) {
484
+ } else if (isa<scf::ForallOp> (target) && isa<scf::ForallOp>( source)) {
586
485
fusedLoop = fuseIndependentSiblingForallLoops (
587
486
cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
487
+ } else if (isa<scf::ParallelOp>(target) && isa<scf::ParallelOp>(source)) {
488
+ fusedLoop = fuseIndependentSiblingParallelLoops (
489
+ cast<scf::ParallelOp>(target), cast<scf::ParallelOp>(source), rewriter);
588
490
} else
589
491
return emitSilenceableFailure (target->getLoc ())
590
- << " operations cannot be fused " ;
492
+ << " unsupported loop type for fusion " ;
591
493
592
494
assert (fusedLoop && " failed to fuse operations" );
593
495
0 commit comments