Skip to content

Commit 3244024

Browse files
wsmosesCopilot
andauthored
Cleanup use of LI when needed OrigLI (#2755)
* Fix LoopAnalysis crash on incomplete new_func * Fix LoopAnalysis crash on incomplete new_func using isOriginal * Cleanup use of LI when needed OrigLI * Fix LoopAnalysis crash on incomplete new_func and update manual check * Fix LoopAnalysis crash on incomplete new_func and fix inverted condition * Fix LoopAnalysis crash and fix manual addition logic for non-loop values * Fix LoopAnalysis crash and include non-instructions in manual induction check * fix * fix * fix * f * fix * fix * fix * Fix clang-format indentation in GradientUtils.cpp (#2756) * Initial plan * Fix formatting in GradientUtils.cpp Co-authored-by: wsmoses <1260124+wsmoses@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: wsmoses <1260124+wsmoses@users.noreply.github.com> --------- Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> Co-authored-by: wsmoses <1260124+wsmoses@users.noreply.github.com>
1 parent 34e7f80 commit 3244024

File tree

2 files changed

+65
-12
lines changed

2 files changed

+65
-12
lines changed

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1611,16 +1611,19 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
16111611

16121612
// Don't attempt to unroll a loop induction variable in other
16131613
// circumstances
1614-
auto &LLI = Logic.PPC.FAM.getResult<LoopAnalysis>(*parent->getParent());
16151614
std::set<BasicBlock *> prevIteration;
1616-
if (LLI.isLoopHeader(parent)) {
1615+
BasicBlock *origParent = isOriginal(parent);
1616+
assert(origParent);
1617+
if (OrigLI->isLoopHeader(origParent)) {
16171618
if (phi->getNumIncomingValues() != 2) {
16181619
assert(unwrapMode != UnwrapMode::LegalFullUnwrap);
16191620
goto endCheck;
16201621
}
1621-
auto L = LLI.getLoopFor(parent);
1622+
auto OrigL = OrigLI->getLoopFor(origParent);
16221623
for (auto PH : predecessors(parent)) {
1623-
if (L->contains(PH))
1624+
BasicBlock *origPH = isOriginal(PH);
1625+
assert(origPH);
1626+
if (OrigL->contains(origPH))
16241627
prevIteration.insert(PH);
16251628
}
16261629
if (prevIteration.size() && !legalRecompute(phi, available, &BuilderM)) {
@@ -1629,16 +1632,33 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
16291632
}
16301633
}
16311634
for (auto &val : phi->incoming_values()) {
1632-
if (isPotentialLastLoopValue(val, parent, LLI)) {
1633-
if (unwrapMode == UnwrapMode::LegalFullUnwrap) {
1634-
llvm::errs() << " module: " << *newFunc->getParent() << "\n";
1635-
llvm::errs() << " newFunc: " << *newFunc << "\n";
1636-
llvm::errs() << " parent: " << *parent << "\n";
1637-
llvm::errs() << " val: " << *val << "\n";
1635+
auto inst = dyn_cast<Instruction>(val);
1636+
if (!inst)
1637+
continue;
1638+
auto origInstParent = isOriginal(inst->getParent());
1639+
assert(origInstParent);
1640+
const llvm::Loop *InstLoop = OrigLI->getLoopFor(origInstParent);
1641+
if (!InstLoop) {
1642+
continue;
1643+
}
1644+
bool isParentLoop = false;
1645+
for (const llvm::Loop *L = OrigLI->getLoopFor(origParent); L;
1646+
L = L->getParentLoop()) {
1647+
if (L == InstLoop) {
1648+
isParentLoop = true;
1649+
break;
16381650
}
1639-
assert(unwrapMode != UnwrapMode::LegalFullUnwrap);
1640-
goto endCheck;
16411651
}
1652+
if (isParentLoop)
1653+
continue;
1654+
if (unwrapMode == UnwrapMode::LegalFullUnwrap) {
1655+
llvm::errs() << " module: " << *newFunc->getParent() << "\n";
1656+
llvm::errs() << " newFunc: " << *newFunc << "\n";
1657+
llvm::errs() << " parent: " << *parent << "\n";
1658+
llvm::errs() << " val: " << *val << "\n";
1659+
}
1660+
assert(unwrapMode != UnwrapMode::LegalFullUnwrap);
1661+
goto endCheck;
16421662
}
16431663

16441664
if (phi->getNumIncomingValues() == 1) {
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s
2+
3+
define hidden void @_Z10entry_wrapRN6enzyme6tensorIfJLm2ELm3EEEES2_RKS1_(i8* %out_0, i8* %out_1, i8* %in_0) #0 {
4+
entry:
5+
%out_0.addr = alloca i8*, align 8
6+
%out_1.addr = alloca i8*, align 8
7+
%in_0.addr = alloca i8*, align 8
8+
store i8* %out_0, i8** %out_0.addr, align 8
9+
store i8* %out_1, i8** %out_1.addr, align 8
10+
store i8* %in_0, i8** %in_0.addr, align 8
11+
%0 = load i8*, i8** %out_0.addr, align 8
12+
%1 = load i8*, i8** %out_1.addr, align 8
13+
%2 = load i8*, i8** %in_0.addr, align 8
14+
call void @_Z4myfnILm2ELm3EEvRN6enzyme6tensorIfJXT_EXT0_EEEES3_RKS2_(i8* %0, i8* %1, i8* %2)
15+
ret void
16+
}
17+
18+
define hidden void @_Z4myfnILm2ELm3EEvRN6enzyme6tensorIfJXT_EXT0_EEEES3_RKS2_(i8* %0, i8* %1, i8* %2) #0 {
19+
entry:
20+
ret void
21+
}
22+
23+
declare void @__enzyme_autodiff(...)
24+
25+
define void @test_derivative(i8* %out_0, i8* %out_0_d, i8* %out_1, i8* %out_1_d, i8* %in_0, i8* %in_0_d) {
26+
entry:
27+
call void (...) @__enzyme_autodiff(i8* bitcast (void (i8*, i8*, i8*)* @_Z10entry_wrapRN6enzyme6tensorIfJLm2ELm3EEEES2_RKS1_ to i8*), metadata !"enzyme_dup", i8* %out_0, i8* %out_0_d, metadata !"enzyme_dup", i8* %out_1, i8* %out_1_d, metadata !"enzyme_dup", i8* %in_0, i8* %in_0_d)
28+
ret void
29+
}
30+
31+
attributes #0 = { mustprogress nounwind }
32+
33+
; CHECK: define internal void @diffe_Z10entry_wrap

0 commit comments

Comments
 (0)