Pad alignment for TPU perf #2305
Description
TPU Performance Analysis: ConcatBitcast Bottleneck
Executive Summary
The significant performance degradation observed on TPU (compared to GPU) is primarily due to a Custom Call operation named ConcatBitcast.
In the trace, this operation consumes the vast majority of execution time (on the order of seconds per call), resulting in extremely low FLOPS utilization (~0.08%).
Root Cause Analysis
Origin of ConcatBitcast: The ConcatBitcast Custom Call is introduced late in the XLA compilation pipeline by the MemorySpaceAssignment (MSA) pass. MSA inserts this operation to aggregate multiple buffer allocations or slice results (e.g., from Async communications or Async Halo exchanges) into a single contiguous buffer.
The Dynamic/Unaligned Shape Bottleneck: The tensor shape involved is [760, 1528] with a layout of f32[760,1528]{1,0:T(8,128)S(1)}.
1528 is not aligned to the TPU's standard 128-element tile boundary (1528 % 128 = 120).
760 is likewise unaligned to 128 (760 % 128 = 120).
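The alignment math above can be sketched with a small helper. This is an illustration of the round-up-to-tile arithmetic, not code from the XLA pipeline; `round_up` is a hypothetical helper name.

```python
def round_up(n, multiple):
    """Round n up to the nearest multiple (ceiling division)."""
    return -(-n // multiple) * multiple

shape = (760, 1528)
tile = 128  # lane tile width from the f32[760,1528]{1,0:T(8,128)} layout

for dim in shape:
    # Each dim leaves a 120-element remainder; padding to the next
    # multiple of 128 gives 768 and 1536 respectively.
    print(dim, dim % tile, round_up(dim, tile))
```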
Because the dimensions are unaligned with respect to the tile grid (lane/sublane boundaries), the lowering algorithm for ConcatBitcast (likely generated by Mosaic or native fallbacks) cannot utilize Bulk DMA or parallel full-tile transfers. Instead, it falls back to a slow, unaligned transfer implementation (e.g., elementwise copies or highly strided copies), which takes 2.8s-4.2s to complete in the trace.
Can we potentially write an MLIR pass that recursively adds padding to maximize alignment?
We should only attempt to pad the last two dimensions to multiples of 128 (and skip an axis whose size is < 64). We should insert the operations so that as few intermediate pads/slices as possible are generated. Ideally we can pad a few operations at the start and slice a few at the end.
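A minimal sketch of the pad-at-start / slice-at-end strategy, in plain numpy so the shape logic is easy to check. The helper names (`pad_last_two`, `slice_back`) and the `MIN_AXIS = 64` threshold are taken from the rule above; how the actual MLIR pass would thread this through a computation graph is left open.

```python
import numpy as np

TILE = 128      # TPU lane tile width
MIN_AXIS = 64   # per the rule above: skip axes smaller than this

def pad_last_two(x):
    """Zero-pad the last two axes up to multiples of TILE.

    Returns the padded array and the original shape, so the
    result can be sliced back after the aligned computation.
    """
    pads = [(0, 0)] * x.ndim
    for ax in (-2, -1):
        size = x.shape[ax]
        if size >= MIN_AXIS and size % TILE:
            pads[ax] = (0, TILE - size % TILE)
    return np.pad(x, pads), x.shape

def slice_back(y, orig_shape):
    """Slice a padded result back down to the original shape."""
    return y[tuple(slice(0, s) for s in orig_shape)]

x = np.ones((760, 1528), np.float32)
xp, orig = pad_last_two(x)   # xp has shape (768, 1536), fully tile-aligned
# ... run the aligned computation on xp here ...
y = slice_back(xp, orig)     # back to (760, 1528)
```

Keeping the pad and slice at the boundaries, rather than around each op, is what avoids re-introducing unaligned intermediate buffers in the middle of the program.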