Skip to content

Commit 3db244b

Browse files
spikerheado1234 authored and vimarsh6739 committed
scf.for fix when iter-args and result args have shadow mismatch in fwd
mode ad.
1 parent 9b78a1c commit 3db244b

File tree

3 files changed

+107
-4
lines changed

3 files changed

+107
-4
lines changed

enzyme/.bazelrc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,3 @@ build --define=use_fast_cpp_protos=true
2121
build --define=allow_oversize_protos=true
2222

2323
build -c opt
24-

enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -349,13 +349,31 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
349349
: regionBranchOp.getSuccessorInputs(successor);
350350

351351
// Need to know which of the arguments are being forwarded to from
352-
// operands.
352+
// operands. An operand needs a shadow — and the ForOp needs a matching
353+
// shadow result — whenever EITHER its iter arg OR its corresponding op
354+
// result is active. Using only iter arg activity misses the
355+
// constant-accumulator case (constant init arg that produces an active
356+
// result because the loop body accumulates active values into it).
357+
// Using only result activity misses the case where an iter arg is active
358+
// but its result is not (e.g. pointer-typed iter args used for address
359+
// arithmetic whose final values are unused downstream).
360+
// forceAugmentedReturns uses only iter arg activity, so for positions
361+
// where the result is active but the iter arg is constant, the second
362+
// overload inserts the missing shadow block arg after takeBody.
353363
for (auto &&[i, regionValue, operand] :
354364
llvm::enumerate(targetValues, operandRange)) {
355-
if (gutils->isConstantValue(regionValue))
365+
bool iterArgActive = !gutils->isConstantValue(regionValue);
366+
bool resultActive = i < op->getNumResults() &&
367+
!gutils->isConstantValue(op->getResult(i));
368+
if (!iterArgActive && !resultActive)
356369
continue;
357370
operandPositionsToShadow.insert(operandRange.getBeginOperandIndex() + i);
358-
if (successor.isParent())
371+
// Add the corresponding result to resultPositionsToShadow if the iter
372+
// arg is active: forceAugmentedReturns will have inserted a shadow
373+
// block arg for it, so the ForOp needs a matching shadow result.
374+
// Active results (regardless of iter arg activity) are covered by the
375+
// loop below.
376+
if (successor.isParent() || (iterArgActive && i < op->getNumResults()))
359377
resultPositionsToShadow.insert(i);
360378
}
361379
}
@@ -423,6 +441,47 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
423441
replacementRegion.takeBody(region);
424442
}
425443

444+
// forceAugmentedReturns inserts shadow block args only for iter args that
445+
// are themselves active. When an iter arg is constant but its corresponding
446+
// op result is active (e.g. a zero accumulator that accumulates active
447+
// values across iterations), the first overload still adds that position to
448+
// both operandPositionsToShadow and resultPositionsToShadow (union
449+
// criterion), so replacement has the right number of results. However, the
450+
// body block is missing the shadow block arg that the replacement's
451+
// iter_arg slot expects. Insert it here, after takeBody has placed the
452+
// cloned body into replacement.
453+
//
454+
// We also register the mapping in invertedPointers so that invertPointerM,
455+
// which checks invertedPointers before isConstantValue, returns the shadow
456+
// block arg instead of zero when body ops reference this iter arg.
457+
if (auto rbIface = dyn_cast<RegionBranchOpInterface>(op)) {
458+
SmallVector<RegionSuccessor> entrySuccessors;
459+
rbIface.getEntrySuccessorRegions(
460+
SmallVector<Attribute>(op->getNumOperands(), Attribute()),
461+
entrySuccessors);
462+
for (const RegionSuccessor &successor : entrySuccessors) {
463+
if (successor.isParent())
464+
continue;
465+
ValueRange successorInputs = rbIface.getSuccessorInputs(successor);
466+
for (auto [i, iterArg] : llvm::enumerate(successorInputs)) {
467+
if (!resultPositionsToShadow.count(i))
468+
continue;
469+
if (!gutils->isConstantValue(iterArg))
470+
continue;
471+
// iterArg is constant but position i needs a shadow result.
472+
// Insert the missing shadow block arg right after iterArg's clone.
473+
auto clonedIterArg =
474+
cast<BlockArgument>(gutils->getNewFromOriginal(iterArg));
475+
Block *block = clonedIterArg.getParentBlock();
476+
Value shadowArg = block->insertArgument(
477+
clonedIterArg.getArgNumber() + 1,
478+
gutils->getShadowType(clonedIterArg.getType()),
479+
clonedIterArg.getLoc());
480+
gutils->invertedPointers.map(iterArg, shadowArg);
481+
}
482+
}
483+
}
484+
426485
// Inject the mapping for the new results into GradientUtil's shadow
427486
// table.
428487
SmallVector<Value> reps;
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// RUN: %eopt --enzyme %s | FileCheck %s
2+
3+
// Test that a constant iter arg whose corresponding ForOp result is active
4+
// (the "constant accumulator" pattern) is correctly differentiated.
5+
// The iter arg %acc is initialized from a constant zero and is therefore
6+
// marked constant by activity analysis, but the ForOp result is active
7+
// because active values (%x) are accumulated into it through the body.
8+
// The differentiated ForOp must have a shadow iter arg (also zero-initialized)
9+
// that accumulates the tangent dx on each iteration.
10+
11+
module {
12+
func.func @square(%x : f64) -> f64 {
13+
%zero = arith.constant 0.0 : f64
14+
%c0 = arith.constant 0 : index
15+
%c1 = arith.constant 1 : index
16+
%c10 = arith.constant 10 : index
17+
%r = scf.for %i = %c0 to %c10 step %c1 iter_args(%acc = %zero) -> (f64) {
18+
%n = arith.addf %acc, %x : f64
19+
scf.yield %n : f64
20+
}
21+
return %r : f64
22+
}
23+
func.func @dsq(%x : f64, %dx : f64) -> f64 {
24+
%r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (f64, f64) -> (f64)
25+
return %r : f64
26+
}
27+
}
28+
29+
// The differentiated ForOp must have TWO iter args: the primal accumulator
30+
// (init = 0.0) and its shadow (init = 0.0, since the original init is a
31+
// constant). On each iteration the shadow accumulates dx (= %arg1).
32+
33+
// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 {
34+
// CHECK-DAG: %[[cst:.+]] = arith.constant 0.000000e+00 : f64
35+
// CHECK-DAG: %[[cst_0:.+]] = arith.constant 0.000000e+00 : f64
36+
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
37+
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
38+
// CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index
39+
// CHECK-NEXT: %[[r:.+]]:2 = scf.for %{{.+}} = %[[c0]] to %[[c10]] step %[[c1]] iter_args(%[[acc:.+]] = %[[cst_0]], %[[sacc:.+]] = %[[cst]]) -> (f64, f64) {
40+
// CHECK-NEXT: %[[sn:.+]] = arith.addf %[[sacc]], %[[arg1]] : f64
41+
// CHECK-NEXT: %[[n:.+]] = arith.addf %[[acc]], %[[arg0]] : f64
42+
// CHECK-NEXT: scf.yield %[[n]], %[[sn]] : f64, f64
43+
// CHECK-NEXT: }
44+
// CHECK-NEXT: return %[[r]]#1 : f64
45+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)