Skip to content

Commit 4eaa5ef

Browse files
authored
Handle maxpd (rust-lang#690)
1 parent c68257e commit 4eaa5ef

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

enzyme/Enzyme/AdjointGenerator.h

+37
Original file line numberDiff line numberDiff line change
@@ -10494,6 +10494,43 @@ class AdjointGenerator
1049410494
}
1049510495
}
1049610496
}
10497+
#if LLVM_VERSION_MAJOR >= 11
10498+
if (auto assembly = dyn_cast<InlineAsm>(orig->getCalledOperand()))
10499+
#else
10500+
if (auto assembly = dyn_cast<InlineAsm>(orig->getCalledValue()))
10501+
#endif
10502+
{
10503+
if (assembly->getAsmString() == "maxpd $1, $0") {
10504+
if (Mode == DerivativeMode::ReverseModePrimal ||
10505+
gutils->isConstantInstruction(orig)) {
10506+
10507+
if (gutils->knownRecomputeHeuristic.find(orig) !=
10508+
gutils->knownRecomputeHeuristic.end()) {
10509+
if (!gutils->knownRecomputeHeuristic[orig]) {
10510+
gutils->cacheForReverse(BuilderZ, newCall,
10511+
getIndex(orig, CacheType::Self));
10512+
}
10513+
}
10514+
eraseIfUnused(*orig);
10515+
return;
10516+
}
10517+
10518+
SmallVector<Value *, 2> orig_ops(orig->getNumOperands());
10519+
for (unsigned i = 0; i < orig->getNumOperands(); ++i) {
10520+
orig_ops[i] = orig->getOperand(i);
10521+
}
10522+
handleAdjointForIntrinsic(Intrinsic::maxnum, *orig, orig_ops);
10523+
if (gutils->knownRecomputeHeuristic.find(orig) !=
10524+
gutils->knownRecomputeHeuristic.end()) {
10525+
if (!gutils->knownRecomputeHeuristic[orig]) {
10526+
gutils->cacheForReverse(BuilderZ, newCall,
10527+
getIndex(orig, CacheType::Self));
10528+
}
10529+
}
10530+
eraseIfUnused(*orig);
10531+
return;
10532+
}
10533+
}
1049710534

1049810535
if (called && isAllocationFunction(*called, gutils->TLI)) {
1049910536

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -early-cse -simplifycfg -S | FileCheck %s
2+
3+
define <2 x double> @pmax(<2 x double> %a, <2 x double> %b) {
4+
%r = call <2 x double> asm "maxpd $1, $0", "=x,x,0,~{dirflag},~{fpsr},~{flags}"(<2 x double> %a, <2 x double> %b)
5+
ret <2 x double> %r
6+
}
7+
8+
declare { <2 x double>, <2 x double> } @__enzyme_autodiff(...)
9+
10+
define { <2 x double>, <2 x double> } @test_derivative(<2 x double> %x, <2 x double> %y) {
11+
entry:
12+
%0 = tail call { <2 x double>, <2 x double> } (...) @__enzyme_autodiff(<2 x double> (<2 x double>, <2 x double>)* @pmax, <2 x double> %x, <2 x double> %y)
13+
ret { <2 x double>, <2 x double> } %0
14+
}
15+
16+
; CHECK: define internal { <2 x double>, <2 x double> } @diffepmax(<2 x double> %a, <2 x double> %b, <2 x double> %differeturn)
17+
; CHECK: %r = call <2 x double> asm "maxpd $1, $0", "=x,x,0,~{dirflag},~{fpsr},~{flags}"(<2 x double> %a, <2 x double> %b)
18+
; CHECK-NEXT: %[[i0:.+]] = fcmp fast olt <2 x double> %a, %b
19+
; CHECK-NEXT: %[[i1:.+]] = select {{(fast )?}}<2 x i1> %[[i0]], <2 x double> zeroinitializer, <2 x double> %differeturn
20+
; CHECK-NEXT: %[[i2:.+]] = select {{(fast )?}}<2 x i1> %[[i0]], <2 x double> %differeturn, <2 x double> zeroinitializer
21+
; CHECK-NEXT: %[[i3:.+]] = insertvalue { <2 x double>, <2 x double> } undef, <2 x double> %[[i1]], 0
22+
; CHECK-NEXT: %[[i4:.+]] = insertvalue { <2 x double>, <2 x double> } %[[i3]], <2 x double> %[[i2]], 1
23+
; CHECK-NEXT: ret { <2 x double>, <2 x double> } %[[i4]]

0 commit comments

Comments
 (0)