Skip to content

Enable stencil ghost cell widening for sharded compilations#2732

Closed
gbaraldi wants to merge 7 commits intoEnzymeAD:mainfrom
gbaraldi:stencil-ghost-cell-widening
Closed

Enable stencil ghost cell widening for sharded compilations#2732
gbaraldi wants to merge 7 commits intoEnzymeAD:mainfrom
gbaraldi:stencil-ghost-cell-widening

Conversation

@gbaraldi
Copy link
Copy Markdown
Collaborator

Summary

Enables the stencil-ghost-cell-widening pass in the compilation pipeline for sharded computations. The pass replaces per-operator halo exchanges with a single wide ghost cell exchange, following the overlapped tiling with redundant computation approach.

Pipeline placement

Runs before the optimization passes (which include recognize_extend, recognize_rotate, recognize_wrap, and pad merging patterns). This ensures stencil pads are eliminated before they get recognized as communication patterns or merged into MultiPadOp.

stencil-ghost-cell-widening → canonicalize → cse → [existing optimization passes]

Only runs when is_sharded=true and optimization_passes=:all.

Dependencies

Requires EnzymeAD/Enzyme-JAX#2326 (the pass implementation).

Test plan

  • CI passes with the new pass in the pipeline
  • GB-25 sharded simulation produces correct results
  • Verify collective-permute reduction in XLA output

🤖 Generated with Claude Code

gbaraldi and others added 7 commits March 25, 2026 13:34
Adds `stencil-ghost-cell-widening` to the compilation pipeline for
sharded computations. Runs before optimization passes so stencil
pads are eliminated before they get recognized as communication
patterns (recognize_extend/rotate/wrap) or merged into MultiPadOp.

Requires EnzymeAD/Enzyme-JAX#2326.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
Runs as a function pass inside func.func(...) pipeline, before the
transform-dialect patterns (which include multi-pad recognition).
Only enabled when is_sharded=true.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
…ring

The stencil slice→pad patterns are cleaner after canonicalize/CSE and
the pad optimization patterns have run. Moving it after the transform
patterns but before lower_comms gives better detection.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
The pass is a func::FuncOp pass, so it needs func.func() wrapper when
used in a module-level pipeline.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
Only keep the func.func()-wrapped version in func_passes.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
The transform_passes include recognize_extend which converts slice→pad
into enzymexla.extend before our pass could see them. Move the ghost
cell widening to run after canonicalize/CSE but before the transform
dialect patterns.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
Pad optimization patterns (slice_pad etc.) fold our widened slices back
into stencil pads if they run after us. By running last in func_passes
(after all transform patterns and comm lowering), nothing undoes our work.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
@gbaraldi gbaraldi closed this Mar 26, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant