@@ -349,13 +349,31 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
349349 : regionBranchOp.getSuccessorInputs (successor);
350350
351351 // Need to know which of the arguments are being forwarded to from
352- // operands.
352+ // operands. An operand needs a shadow — and the ForOp needs a matching
353+ // shadow result — whenever EITHER its iter arg OR its corresponding op
354+ // result is active. Using only iter arg activity misses the
355+ // constant-accumulator case (constant init arg that produces an active
356+ // result because the loop body accumulates active values into it).
357+ // Using only result activity misses the case where an iter arg is active
358+ // but its result is not (e.g. pointer-typed iter args used for address
359+ // arithmetic whose final values are unused downstream).
360+ // forceAugmentedReturns uses only iter arg activity, so for positions
361+ // where the result is active but the iter arg is constant, the second
362+ // overload inserts the missing shadow block arg after takeBody.
353363 for (auto &&[i, regionValue, operand] :
354364 llvm::enumerate (targetValues, operandRange)) {
355- if (gutils->isConstantValue (regionValue))
365+ bool iterArgActive = !gutils->isConstantValue (regionValue);
366+ bool resultActive = i < op->getNumResults () &&
367+ !gutils->isConstantValue (op->getResult (i));
368+ if (!iterArgActive && !resultActive)
356369 continue ;
357370 operandPositionsToShadow.insert (operandRange.getBeginOperandIndex () + i);
358- if (successor.isParent ())
371+ // Add the corresponding result to resultPositionsToShadow if the
372+ // successor is the parent op, or if the iter arg is active. In the
373+ // latter case forceAugmentedReturns will have inserted a shadow block
374+ // arg for it, so the ForOp needs a matching shadow result. Active results
375+ // (regardless of iter arg activity) are covered by the loop below.
376+ if (successor.isParent () || (iterArgActive && i < op->getNumResults ()))
359377 resultPositionsToShadow.insert (i);
360378 }
361379 }
@@ -423,6 +441,47 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
423441 replacementRegion.takeBody (region);
424442 }
425443
444+ // forceAugmentedReturns inserts shadow block args only for iter args that
445+ // are themselves active. When an iter arg is constant but its corresponding
446+ // op result is active (e.g. a zero accumulator that accumulates active
447+ // values across iterations), the first overload still adds that position to
448+ // both operandPositionsToShadow and resultPositionsToShadow (union
449+ // criterion), so replacement has the right number of results. However, the
450+ // body block is missing the shadow block arg that the replacement's
451+ // iter_arg slot expects. Insert it here, after takeBody has placed the
452+ // cloned body into replacement.
453+ //
454+ // We also register the mapping in invertedPointers so that invertPointerM,
455+ // which checks invertedPointers before isConstantValue, returns the shadow
456+ // block arg instead of zero when body ops reference this iter arg.
457+ if (auto rbIface = dyn_cast<RegionBranchOpInterface>(op)) {
458+ SmallVector<RegionSuccessor> entrySuccessors;
459+ rbIface.getEntrySuccessorRegions (
460+ SmallVector<Attribute>(op->getNumOperands (), Attribute ()),
461+ entrySuccessors);
462+ for (const RegionSuccessor &successor : entrySuccessors) {
463+ if (successor.isParent ())
464+ continue ;
465+ ValueRange successorInputs = rbIface.getSuccessorInputs (successor);
466+ for (auto [i, iterArg] : llvm::enumerate (successorInputs)) {
467+ if (!resultPositionsToShadow.count (i))
468+ continue ;
469+ if (!gutils->isConstantValue (iterArg))
470+ continue ;
471+ // iterArg is constant but position i needs a shadow result.
472+ // Insert the missing shadow block arg right after iterArg's clone.
473+ auto clonedIterArg =
474+ cast<BlockArgument>(gutils->getNewFromOriginal (iterArg));
475+ Block *block = clonedIterArg.getParentBlock ();
476+ Value shadowArg = block->insertArgument (
477+ clonedIterArg.getArgNumber () + 1 ,
478+ gutils->getShadowType (clonedIterArg.getType ()),
479+ clonedIterArg.getLoc ());
480+ gutils->invertedPointers .map (iterArg, shadowArg);
481+ }
482+ }
483+ }
484+
426485 // Inject the mapping for the new results into GradientUtil's shadow
427486 // table.
428487 SmallVector<Value> reps;
0 commit comments