feat: Allow for users / kv cache to add aliased I/O for inplace operations #4251

Open

narendasan wants to merge 1 commit into main from narendasan/aliased_io

Conversation

@narendasan (Collaborator)

Description

Adds support for in-place ATen operators by extending the Torch-TensorRT compile pipeline and C++ runtime with aliased input/output bindings. The motivating case is streaming inference with a key/value cache (e.g. autoregressive decoders, ZoomASR): each step writes a single timestep into the cache, and without aliasing every step pays a full cache-size copy at the engine boundary. With aliased I/O the TensorRT engine writes directly into the user's (or module-held) cache storage; no fresh allocation, no post-engine copy.
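The update pattern the PR targets can be illustrated with a small PyTorch sketch (shapes here are hypothetical; the actual cache layout depends on the model). Each decode step writes a single timestep into a static cache along the sequence dimension; with aliased I/O the engine performs this write directly into the caller's storage instead of copying the full cache out:

```python
import torch

# Hypothetical static KV cache: [batch, heads, s_max, head_dim].
# One decode step writes a single timestep at position `step`.
cache = torch.zeros(1, 8, 128, 64)   # batch=1, heads=8, s_max=128, head_dim=64
new_kv = torch.randn(1, 8, 1, 64)    # one decoded timestep

step = 5  # current write position
# In-place update along the sequence dimension (dim=2); without aliasing,
# the engine boundary would instead copy all 1*8*128*64 elements per step.
cache[:, :, step : step + 1, :] = new_kv

assert torch.equal(cache[:, :, step, :], new_kv.squeeze(2))
```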

Two aliased_io "kinds" are tracked so the runtime can reason about provenance:

  • kv_cache_update — TensorRT-enforced via IKVCacheUpdateLayer; reported through ICudaEngine::getAliasedInputTensor.
  • user — Torch-TensorRT-declared; reserved for future expansion if TRT exposes a public non-KV aliasing API.

What this PR does

Pipeline (Python)

  • New slice_scatter and index_copy converters that recognize KV-cache-update patterns (4-D static cache, dim=2, batch=1) and emit IKVCacheUpdateLayer with the output aliased to the cache input. Non-eligible cases fall back to scatter in TRT — no graph break.
  • For index_copy, two disjoint converters (validator-gated KV fast path at ConverterPriority.HIGH + scatter fallback at standard priority) cleanly split the cases.
  • aliased_io plumbed through TRTInterpreter → TRTInterpreterResult → SerializedInterpreterResult → TorchTensorRTModule. The output() step automatically promotes layer outputs that need to be network outputs (KVCacheUpdate requires it) and appends them after user outputs. The user/side-effect boundary is derived at runtime, not stored.
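The KV fast-path eligibility test described above (4-D static cache, dim=2, batch=1) can be sketched as a standalone predicate. The function name and signature are illustrative only, not the converter's actual validator API:

```python
def is_kv_cache_update_candidate(cache_shape, dim: int, batch: int) -> bool:
    """Illustrative check mirroring the validator-gated KV fast path:
    a 4-D, fully static cache updated along the sequence dim with batch=1.
    Non-eligible cases fall through to the scatter converter instead."""
    return (
        len(cache_shape) == 4   # 4-D static cache
        and dim == 2            # update along the sequence dimension
        and batch == 1          # single-batch decode
        # every dim must be a concrete positive int (no dynamic shapes)
        and all(isinstance(s, int) and s > 0 for s in cache_shape)
    )

assert is_kv_cache_update_candidate((1, 8, 128, 64), dim=2, batch=1)
assert not is_kv_cache_update_candidate((1, 8, 128, 64), dim=1, batch=1)
```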

Buffer-style support

  • lift_mutated_buffers lowering pass detects BUFFER_MUTATION patterns (the trailing aten.copy_(get_attr_buffer, ...) that ep.module() emits) and lifts each mutated buffer from get_attr to placeholder so the engine sees it as an input binding.
  • inline_lifted_buffers_into_gm post-compile transform registers the buffers as nn.Module state on the compiled GraphModule and rewrites the lifted placeholders to get_attr reads. The result is a plain fx.GraphModule (no custom wrapper class) that serializes cleanly through torch_tensorrt.save / torch.export.
  • Low-level entry point convert_exported_program_to_serialized_trt_engine gains lift_mutable_buffers: bool = False for power users who want to manage the resulting bindings themselves.
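A minimal module exhibiting the pattern lift_mutated_buffers targets might look like the sketch below (module name and shapes are illustrative). A registered buffer is mutated in place; after torch.export, ep.module() re-emits that mutation as a trailing aten.copy_ into a get_attr buffer, which the lowering pass lifts to a placeholder so the engine sees an input binding:

```python
import torch

class StreamingCache(torch.nn.Module):
    """Illustrative buffer-mutation model: the kind of module whose
    exported graph ends in aten.copy_(get_attr_buffer, ...) and is
    therefore a candidate for the lift_mutated_buffers pass."""

    def __init__(self):
        super().__init__()
        # Module-owned static cache: [batch, heads, s_max, head_dim]
        self.register_buffer("cache", torch.zeros(1, 8, 128, 64))

    def forward(self, new_kv, step: int):
        # In-place buffer mutation along the sequence dimension
        self.cache[:, :, step : step + 1, :] = new_kv
        return self.cache.sum()

m = StreamingCache()
new_kv = torch.ones(1, 8, 1, 64)
m(new_kv, step=3)
assert torch.equal(m.cache[:, :, 3, :], new_kv.squeeze(2))  # state persists
```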

C++ runtime (ABI v9 → v10)

  • Bumped ABI_VERSION to "10"; added ALIASED_IO_IDX to SerializedInfoIndex.
  • serialize_aliased_io / deserialize_aliased_io helpers (wire format: output@input@kind records joined by BINDING_DELIM). Helpers live in runtime_utils.cpp alongside serialize_bindings.
  • TRTEngine constructor reconciles the build-time map against ICudaEngine::getAliasedInputTensor — the engine API is the source of truth for KV-style aliasing.
  • execute_engine records bound input tensors by binding name; for each output binding in aliased_io, binds the same data_ptr as the source input and skips fresh allocation. Pre-allocated outputs are disabled when aliased I/O is present.
  • CUDA Graphs integration: aliased input bindings bypass the persistent-clone path so the engine writes through to user storage; aliased outputs are skipped in the post-exec copy-back loop. Capture + replay both correctly mutate the user's tensor.
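The wire format above (output@input@kind records joined by BINDING_DELIM) can be sketched in Python. The real helpers live in runtime_utils.cpp; the delimiter value used here is a stand-in, not the actual BINDING_DELIM constant:

```python
BINDING_DELIM = "%"  # hypothetical stand-in for the runtime's delimiter

def serialize_aliased_io(aliased_io: dict) -> str:
    """Map of output binding -> (input binding, kind) to a wire string
    of '@'-separated records joined by the delimiter."""
    return BINDING_DELIM.join(
        f"{out}@{inp}@{kind}" for out, (inp, kind) in aliased_io.items()
    )

def deserialize_aliased_io(wire: str) -> dict:
    """Inverse of serialize_aliased_io; an empty wire yields an empty map."""
    result = {}
    if not wire:
        return result
    for record in wire.split(BINDING_DELIM):
        out, inp, kind = record.split("@")
        result[out] = (inp, kind)
    return result

aliased = {"cache_out": ("cache_in", "kv_cache_update")}
wire = serialize_aliased_io(aliased)
assert wire == "cache_out@cache_in@kv_cache_update"
assert deserialize_aliased_io(wire) == aliased
```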

Docs + examples

  • docsrc/contributors/inplace_operations.rst — full design doc covering motivation, primitives, pipeline, runtime, serialization format, and known limitations.
  • Three examples under examples/dynamo/:
    • aliased_io_user_inputs.py — caller-owned cache (simplest case)
    • aliased_io_buffers.py — module-owned cache via register_buffer
    • aliased_io_kv_attention.py — realistic single-layer transformer attention block with static KV cache

Partially fixes #4240 (in-place custom plugins / multiple outputs): this addresses the in-place-operator side; plugin-side aliased I/O is explicitly out of scope here.

Type of change

  • New feature (non-breaking for callers who don't opt in to aliased I/O; ABI-breaking for existing engine binaries — older serialized engines fail the version check and need to be rebuilt, consistent with prior ABI bumps).
  • This change requires a documentation update (included).

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR

Test summary

38 new tests across 8 files, all passing:

File | Cases | Covers
tests/py/dynamo/conversion/test_slice_scatter_aten.py | 8 | scatter-fallback path (numerical correctness via the standard converter harness)
tests/py/dynamo/runtime/test_aliased_io.py | 8 | end-to-end aliased I/O (user-input single/paired/streaming + buffer-style + regressions)
tests/py/dynamo/runtime/test_index_copy_kv.py | 4 | KV fast path + 3 fallback shapes for aten.index_copy
tests/py/dynamo/runtime/test_lift_mutable_buffers_api.py | 4 | low-level lift_mutable_buffers=True flag round-trip (introspect engine, construct module, run)
tests/py/dynamo/runtime/test_aliased_io_serialization.py | 3 | torch_tensorrt.save / load round-trip for user-input + buffer-backed + streaming buffer
tests/py/dynamo/runtime/test_aliased_io_cudagraphs.py | 3 | CUDA Graph capture + replay for both user-input and buffer-backed KV; cudagraphs vs non-cudagraphs parity
tests/py/dynamo/runtime/test_hf_static_cache_xfail.py | 1 | xfail documenting the current HF + StaticCache gap (asserts known failure mode)
tests/py/dynamo/lowering/test_buffer_lifting.py | 9 | lift_mutated_buffers + inline_lifted_buffers_into_gm unit tests

Known gaps (documented)

  • Stock HuggingFace decoder LMs with StaticCache don't compile end-to-end yet: torch.export's run_decompositions raises internally on the EP that convert_and_export_with_cache produces. The xfail test asserts the known failure so a future upstream fix surfaces as a test failure. Path forward documented in the design doc.
  • IKVCacheUpdateLayer requires static s_max. Dynamic-sequence-length cache shapes fall through to the scatter path (still correct, no aliasing).

@narendasan narendasan requested a review from apbose May 12, 2026 00:19
@meta-cla meta-cla Bot added the cla signed label May 12, 2026
@github-actions github-actions Bot added documentation Improvements or additions to documentation component: tests Issues re: Tests component: lowering Issues re: The lowering / preprocessing passes component: conversion Issues re: Conversion stage component: core Issues re: The core compiler component: converters Issues re: Specific op converters component: api [Python] Issues re: Python API component: runtime component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels May 12, 2026
@github-actions github-actions Bot requested a review from cehongwang May 12, 2026 00:19
@narendasan narendasan force-pushed the narendasan/aliased_io branch from 354674d to 813e753 Compare May 12, 2026 00:26

@github-actions github-actions Bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_slice_scatter_aten.py	2026-05-12 00:26:56.728308+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_slice_scatter_aten.py	2026-05-12 00:27:18.993037+00:00
@@ -15,10 +15,11 @@
This file covers the fallback path. To force the fallback regardless of
shape we add a small no-op (``+ 0``) to the cache so it isn't a direct
network input — the converter's "input is a placeholder" check fails and
falls through to scatter.
"""
+
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_buffer_lifting.py	2026-05-12 00:26:56.731194+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_buffer_lifting.py	2026-05-12 00:27:19.634834+00:00
@@ -21,10 +21,11 @@
* ``inline_lifted_buffers_into_gm`` rewrites the lifted-buffer
  placeholders into ``get_attr`` reads and registers the buffers as
  module state. The result is a plain ``fx.GraphModule`` that
  serializes via ``torch_tensorrt.save`` without an external wrapper.
"""
+
import inspect

import torch
from torch.export import export
from torch.testing._internal.common_utils import TestCase, run_tests
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io_cudagraphs.py	2026-05-12 00:26:56.733500+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io_cudagraphs.py	2026-05-12 00:27:21.127481+00:00
@@ -17,10 +17,11 @@
  is already visible on the user's input).

These tests cover capture + replay correctness for both KV-cache patterns
(user-input and buffer-style).
"""
+
import unittest

import torch
import torch_tensorrt
from torch.export import export
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io_serialization.py	2026-05-12 00:26:56.733500+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io_serialization.py	2026-05-12 00:27:21.138494+00:00
@@ -15,10 +15,11 @@
  ``torch.export``. The ``inline_lifted_buffers_into_gm`` post-compile
  transform replaces what used to be an external ``BufferThreadingModule``
  wrapper — making the result a plain ``fx.GraphModule`` that exports
  naturally without a custom wrapper class.
"""
+
import tempfile

import torch
import torch_tensorrt
from torch.export import export
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_hf_static_cache_xfail.py	2026-05-12 00:26:56.733500+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_hf_static_cache_xfail.py	2026-05-12 00:27:21.208126+00:00
@@ -36,10 +36,11 @@
  workaround that skips ``run_decompositions`` for already-decomposed EPs.

When the upstream issues are resolved or those features land, this
xfail test should start passing — flip it to a real test then.
"""
+
import unittest

import torch
import torch_tensorrt

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io.py	2026-05-12 00:26:56.733500+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io.py	2026-05-12 00:27:21.275543+00:00
@@ -19,10 +19,11 @@
* ``TorchTensorRTModule.forward`` filters aliased outputs from the user
  return tuple.
* For buffer-style models, ``lift_mutated_buffers`` rewrites the EP and
  ``BufferThreadingModule`` threads buffers through each call.
"""
+
import torch
import torch_tensorrt
from torch.export import export
from torch.testing._internal.common_utils import TestCase, run_tests

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_index_copy_kv.py	2026-05-12 00:26:56.733500+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_index_copy_kv.py	2026-05-12 00:27:21.304553+00:00
@@ -13,10 +13,11 @@

These tests verify both paths end-to-end via the C++ runtime: the
fast path mutates in place, the fallback produces correct numerical
results without aliasing.
"""
+
import torch
import torch_tensorrt
from torch.export import export
from torch.testing._internal.common_utils import TestCase, run_tests

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_lift_mutable_buffers_api.py	2026-05-12 00:26:56.733500+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_lift_mutable_buffers_api.py	2026-05-12 00:27:21.380261+00:00
@@ -14,10 +14,11 @@
3. Construct a ``TorchTensorRTModule`` (C++ runtime — required for
   aliased I/O) with the discovered bindings.
4. Thread the buffer values in on each call and verify in-place
   mutation works (cache state persists across calls).
"""
+
import torch
import torch_tensorrt
from torch.export import export
from torch.testing._internal.common_utils import TestCase, run_tests
from torch_tensorrt.dynamo import convert_exported_program_to_serialized_trt_engine

@narendasan narendasan force-pushed the narendasan/aliased_io branch from 813e753 to bcaf725 Compare May 12, 2026 20:26
@github-actions github-actions Bot left a comment with the same Python style-guideline diffs as above.