@@ -305,7 +305,10 @@ bool IsHLWaveSensitive(Function *F) {
305
305
return attrSet.hasAttribute (AttributeSet::FunctionIndex, HLWaveSensitive);
306
306
}
307
307
308
- std::string GetHLFullName (HLOpcodeGroup op, unsigned opcode) {
308
+ static std::string GetHLFunctionAttributeMangling (const AttributeSet &attribs);
309
+
310
+ std::string GetHLFullName (HLOpcodeGroup op, unsigned opcode,
311
+ const AttributeSet &attribs = AttributeSet()) {
309
312
assert (op != HLOpcodeGroup::HLExtIntrinsic && " else table name should be used" );
310
313
std::string opName = GetHLOpcodeGroupFullName (op).str () + " ." ;
311
314
@@ -321,22 +324,26 @@ std::string GetHLFullName(HLOpcodeGroup op, unsigned opcode) {
321
324
case HLOpcodeGroup::HLIntrinsic: {
322
325
// intrinsic with same signature will share the funciton now
323
326
// The opcode is in arg0.
324
- return opName;
327
+ return opName + GetHLFunctionAttributeMangling (attribs) ;
325
328
}
326
329
case HLOpcodeGroup::HLMatLoadStore: {
327
330
HLMatLoadStoreOpcode matOp = static_cast <HLMatLoadStoreOpcode>(opcode);
328
331
return opName + GetHLOpcodeName (matOp).str ();
329
332
}
330
333
case HLOpcodeGroup::HLSubscript: {
331
334
HLSubscriptOpcode subOp = static_cast <HLSubscriptOpcode>(opcode);
332
- return opName + GetHLOpcodeName (subOp).str ();
335
+ return opName + GetHLOpcodeName (subOp).str () + " ." +
336
+ GetHLFunctionAttributeMangling (attribs);
333
337
}
334
338
case HLOpcodeGroup::HLCast: {
335
339
HLCastOpcode castOp = static_cast <HLCastOpcode>(opcode);
336
340
return opName + GetHLOpcodeName (castOp).str ();
337
341
}
338
- default :
342
+ case HLOpcodeGroup::HLCreateHandle:
343
+ case HLOpcodeGroup::HLAnnotateHandle:
339
344
return opName;
345
+ default :
346
+ return opName + GetHLFunctionAttributeMangling (attribs);
340
347
}
341
348
}
342
349
@@ -417,38 +424,59 @@ HLBinaryOpcode GetUnsignedOpcode(HLBinaryOpcode opcode) {
417
424
}
418
425
}
419
426
420
- static void SetHLFunctionAttribute (Function *F, HLOpcodeGroup group,
421
- unsigned opcode) {
422
- F->addFnAttr (Attribute::NoUnwind);
427
+ static AttributeSet
428
+ GetHLFunctionAttributes (LLVMContext &C, FunctionType *funcTy,
429
+ const AttributeSet &origAttribs,
430
+ HLOpcodeGroup group, unsigned opcode) {
431
+ // Always add nounwind
432
+ AttributeSet attribs =
433
+ AttributeSet::get (C, AttributeSet::FunctionIndex,
434
+ ArrayRef<Attribute::AttrKind>({Attribute::NoUnwind}));
435
+
436
+ auto addAttr = [&](Attribute::AttrKind Attr) {
437
+ if (!attribs.hasAttribute (AttributeSet::FunctionIndex, Attr))
438
+ attribs = attribs.addAttribute (C, AttributeSet::FunctionIndex, Attr);
439
+ };
440
+ auto copyAttr = [&](Attribute::AttrKind Attr) {
441
+ if (origAttribs.hasAttribute (AttributeSet::FunctionIndex, Attr))
442
+ addAttr (Attr);
443
+ };
444
+ auto copyStrAttr = [&](StringRef Kind) {
445
+ if (origAttribs.hasAttribute (AttributeSet::FunctionIndex, Kind))
446
+ attribs = attribs.addAttribute (
447
+ C, AttributeSet::FunctionIndex, Kind,
448
+ origAttribs.getAttribute (AttributeSet::FunctionIndex, Kind)
449
+ .getValueAsString ());
450
+ };
451
+
452
+ // Copy attributes we preserve from the original function.
453
+ copyAttr (Attribute::ReadOnly);
454
+ copyAttr (Attribute::ReadNone);
455
+ copyStrAttr (HLWaveSensitive);
423
456
424
457
switch (group) {
425
458
case HLOpcodeGroup::HLUnOp:
426
459
case HLOpcodeGroup::HLBinOp:
427
460
case HLOpcodeGroup::HLCast:
428
461
case HLOpcodeGroup::HLSubscript:
429
- if (!F->hasFnAttribute (Attribute::ReadNone)) {
430
- F->addFnAttr (Attribute::ReadNone);
431
- }
462
+ addAttr (Attribute::ReadNone);
432
463
break ;
433
464
case HLOpcodeGroup::HLInit:
434
- if (!F->hasFnAttribute (Attribute::ReadNone))
435
- if (!F->getReturnType ()->isVoidTy ()) {
436
- F->addFnAttr (Attribute::ReadNone);
437
- }
465
+ if (!funcTy->getReturnType ()->isVoidTy ()) {
466
+ addAttr (Attribute::ReadNone);
467
+ }
438
468
break ;
439
469
case HLOpcodeGroup::HLMatLoadStore: {
440
470
HLMatLoadStoreOpcode matOp = static_cast <HLMatLoadStoreOpcode>(opcode);
441
471
if (matOp == HLMatLoadStoreOpcode::ColMatLoad ||
442
472
matOp == HLMatLoadStoreOpcode::RowMatLoad)
443
- if (!F->hasFnAttribute (Attribute::ReadOnly)) {
444
- F->addFnAttr (Attribute::ReadOnly);
445
- }
473
+ addAttr (Attribute::ReadOnly);
446
474
} break ;
447
475
case HLOpcodeGroup::HLCreateHandle: {
448
- F-> addFnAttr (Attribute::ReadNone);
476
+ addAttr (Attribute::ReadNone);
449
477
} break ;
450
478
case HLOpcodeGroup::HLAnnotateHandle: {
451
- F-> addFnAttr (Attribute::ReadNone);
479
+ addAttr (Attribute::ReadNone);
452
480
} break ;
453
481
case HLOpcodeGroup::HLIntrinsic: {
454
482
IntrinsicOp intrinsicOp = static_cast <IntrinsicOp>(opcode);
@@ -461,7 +489,7 @@ static void SetHLFunctionAttribute(Function *F, HLOpcodeGroup group,
461
489
case IntrinsicOp::IOP_GroupMemoryBarrier:
462
490
case IntrinsicOp::IOP_AllMemoryBarrierWithGroupSync:
463
491
case IntrinsicOp::IOP_AllMemoryBarrier:
464
- F-> addFnAttr (Attribute::NoDuplicate);
492
+ addAttr (Attribute::NoDuplicate);
465
493
break ;
466
494
}
467
495
} break ;
@@ -472,6 +500,75 @@ static void SetHLFunctionAttribute(Function *F, HLOpcodeGroup group,
472
500
// No default attributes for these opcodes.
473
501
break ;
474
502
}
503
+ assert (!(attribs.hasAttribute (AttributeSet::FunctionIndex,
504
+ Attribute::ReadNone) &&
505
+ attribs.hasAttribute (AttributeSet::FunctionIndex,
506
+ Attribute::ReadOnly)) &&
507
+ " conflicting ReadNone and ReadOnly attributes" );
508
+ return attribs;
509
+ }
510
+
511
+ static std::string GetHLFunctionAttributeMangling (const AttributeSet &attribs) {
512
+ std::string mangledName;
513
+ raw_string_ostream mangledNameStr (mangledName);
514
+
515
+ // Capture for adding in canonical order later.
516
+ bool ReadNone = false ;
517
+ bool ReadOnly = false ;
518
+ bool NoDuplicate = false ;
519
+ bool WaveSensitive = false ;
520
+
521
+ // Ensure every function attribute is recognized.
522
+ for (unsigned Slot = 0 ; Slot < attribs.getNumSlots (); Slot++) {
523
+ if (attribs.getSlotIndex (Slot) == AttributeSet::FunctionIndex) {
524
+ for (auto it = attribs.begin (Slot), e = attribs.end (Slot); it != e;
525
+ it++) {
526
+ if (it->isEnumAttribute ()) {
527
+ switch (it->getKindAsEnum ()) {
528
+ case Attribute::ReadNone:
529
+ ReadNone = true ;
530
+ break ;
531
+ case Attribute::ReadOnly:
532
+ ReadOnly = true ;
533
+ break ;
534
+ case Attribute::NoDuplicate:
535
+ NoDuplicate = true ;
536
+ break ;
537
+ case Attribute::NoUnwind:
538
+ // All intrinsics have this attribute, so mangling is unaffected.
539
+ break ;
540
+ default :
541
+ assert (false && " unexpected attribute for HLOperation" );
542
+ }
543
+ } else if (it->isStringAttribute ()) {
544
+ StringRef Kind = it->getKindAsString ();
545
+ if (Kind == HLWaveSensitive) {
546
+ assert (it->getValueAsString () == " y" &&
547
+ " otherwise, unexpected value for WaveSensitive attribute" );
548
+ WaveSensitive = true ;
549
+ } else {
550
+ assert (false &&
551
+ " unexpected string function attribute for HLOperation" );
552
+ }
553
+ }
554
+ }
555
+ }
556
+ }
557
+
558
+ // Validate attribute combinations.
559
+ assert (!(ReadNone && ReadOnly) &&
560
+ " ReadNone and ReadOnly are mutually exclusive" );
561
+
562
+ // Add mangling in canonical order
563
+ if (NoDuplicate)
564
+ mangledNameStr << " nd" ;
565
+ if (ReadNone)
566
+ mangledNameStr << " rn" ;
567
+ if (ReadOnly)
568
+ mangledNameStr << " ro" ;
569
+ if (WaveSensitive)
570
+ mangledNameStr << " wave" ;
571
+ return mangledName;
475
572
}
476
573
477
574
@@ -497,7 +594,11 @@ Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
497
594
Function *GetOrCreateHLFunction (Module &M, FunctionType *funcTy,
498
595
HLOpcodeGroup group, StringRef *groupName,
499
596
StringRef *fnName, unsigned opcode,
500
- const AttributeSet &attribs) {
597
+ const AttributeSet &origAttribs) {
598
+ // Set/transfer all common attributes
599
+ AttributeSet attribs = GetHLFunctionAttributes (
600
+ M.getContext (), funcTy, origAttribs, group, opcode);
601
+
501
602
std::string mangledName;
502
603
raw_string_ostream mangledNameStr (mangledName);
503
604
if (group == HLOpcodeGroup::HLExtIntrinsic) {
@@ -506,33 +607,31 @@ Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
506
607
mangledNameStr << *groupName;
507
608
mangledNameStr << ' .' ;
508
609
mangledNameStr << *fnName;
610
+ attribs = attribs.addAttribute (M.getContext (), AttributeSet::FunctionIndex,
611
+ hlsl::HLPrefix, *groupName);
509
612
}
510
613
else {
511
- mangledNameStr << GetHLFullName (group, opcode);
512
- // Need to add wave sensitivity to name to prevent clashes with non-wave intrinsic
513
- if (attribs.hasAttribute (AttributeSet::FunctionIndex, HLWaveSensitive))
514
- mangledNameStr << " wave" ;
614
+ mangledNameStr << GetHLFullName (group, opcode, attribs);
515
615
mangledNameStr << ' .' ;
516
616
funcTy->print (mangledNameStr);
517
617
}
518
618
519
619
mangledNameStr.flush ();
520
620
521
- Function *F = cast<Function>(M.getOrInsertFunction (mangledName, funcTy));
522
- if (group == HLOpcodeGroup::HLExtIntrinsic) {
523
- F->addFnAttr (hlsl::HLPrefix, *groupName);
621
+ // Avoid getOrInsertFunction to verify attributes and type without casting.
622
+ Function *F = cast_or_null<Function>(M.getNamedValue (mangledName));
623
+ if (F) {
624
+ assert (F->getFunctionType () == funcTy &&
625
+ " otherwise, function type mismatch not captured by mangling" );
626
+ // Compare attribute mangling to ensure function attributes are as expected.
627
+ assert (
628
+ GetHLFunctionAttributeMangling (F->getAttributes ().getFnAttributes ()) ==
629
+ GetHLFunctionAttributeMangling (attribs) &&
630
+ " otherwise, function attribute mismatch not captured by mangling" );
631
+ } else {
632
+ F = cast<Function>(M.getOrInsertFunction (mangledName, funcTy, attribs));
524
633
}
525
634
526
- SetHLFunctionAttribute (F, group, opcode);
527
-
528
- // Copy attributes
529
- if (attribs.hasAttribute (AttributeSet::FunctionIndex, Attribute::ReadNone))
530
- F->addFnAttr (Attribute::ReadNone);
531
- if (attribs.hasAttribute (AttributeSet::FunctionIndex, Attribute::ReadOnly))
532
- F->addFnAttr (Attribute::ReadOnly);
533
- if (attribs.hasAttribute (AttributeSet::FunctionIndex, HLWaveSensitive))
534
- F->addFnAttr (HLWaveSensitive, " y" );
535
-
536
635
return F;
537
636
}
538
637
@@ -541,15 +640,17 @@ Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
541
640
Function *GetOrCreateHLFunctionWithBody (Module &M, FunctionType *funcTy,
542
641
HLOpcodeGroup group, unsigned opcode,
543
642
StringRef name) {
544
- std::string operatorName = GetHLFullName (group, opcode);
643
+ // Set/transfer all common attributes
644
+ AttributeSet attribs = GetHLFunctionAttributes (
645
+ M.getContext (), funcTy, AttributeSet (), group, opcode);
646
+
647
+ std::string operatorName = GetHLFullName (group, opcode, attribs);
545
648
std::string mangledName = operatorName + " ." + name.str ();
546
649
raw_string_ostream mangledNameStr (mangledName);
547
650
funcTy->print (mangledNameStr);
548
651
mangledNameStr.flush ();
549
652
550
- Function *F = cast<Function>(M.getOrInsertFunction (mangledName, funcTy));
551
-
552
- SetHLFunctionAttribute (F, group, opcode);
653
+ Function *F = cast<Function>(M.getOrInsertFunction (mangledName, funcTy, attribs));
553
654
554
655
F->setLinkage (llvm::GlobalValue::LinkageTypes::InternalLinkage);
555
656
0 commit comments