feat: Allow for users / kv cache to add aliased I/O for inplace operations #4251
Open
narendasan wants to merge 1 commit into
Conversation
354674d to 813e753
Compare
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_slice_scatter_aten.py 2026-05-12 00:26:56.728308+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_slice_scatter_aten.py 2026-05-12 00:27:18.993037+00:00
@@ -15,10 +15,11 @@
This file covers the fallback path. To force the fallback regardless of
shape we add a small no-op (``+ 0``) to the cache so it isn't a direct
network input — the converter's "input is a placeholder" check fails and
falls through to scatter.
"""
+
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from .harness import DispatchTestCase
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_buffer_lifting.py 2026-05-12 00:26:56.731194+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_buffer_lifting.py 2026-05-12 00:27:19.634834+00:00
@@ -21,10 +21,11 @@
* ``inline_lifted_buffers_into_gm`` rewrites the lifted-buffer
placeholders into ``get_attr`` reads and registers the buffers as
module state. The result is a plain ``fx.GraphModule`` that
serializes via ``torch_tensorrt.save`` without an external wrapper.
"""
+
import inspect
import torch
from torch.export import export
from torch.testing._internal.common_utils import TestCase, run_tests
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io_cudagraphs.py 2026-05-12 00:26:56.733500+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io_cudagraphs.py 2026-05-12 00:27:21.127481+00:00
@@ -17,10 +17,11 @@
is already visible on the user's input).
These tests cover capture + replay correctness for both KV-cache patterns
(user-input and buffer-style).
"""
+
import unittest
import torch
import torch_tensorrt
from torch.export import export
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io_serialization.py 2026-05-12 00:26:56.733500+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io_serialization.py 2026-05-12 00:27:21.138494+00:00
@@ -15,10 +15,11 @@
``torch.export``. The ``inline_lifted_buffers_into_gm`` post-compile
transform replaces what used to be an external ``BufferThreadingModule``
wrapper — making the result a plain ``fx.GraphModule`` that exports
naturally without a custom wrapper class.
"""
+
import tempfile
import torch
import torch_tensorrt
from torch.export import export
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_hf_static_cache_xfail.py 2026-05-12 00:26:56.733500+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_hf_static_cache_xfail.py 2026-05-12 00:27:21.208126+00:00
@@ -36,10 +36,11 @@
workaround that skips ``run_decompositions`` for already-decomposed EPs.
When the upstream issues are resolved or those features land, this
xfail test should start passing — flip it to a real test then.
"""
+
import unittest
import torch
import torch_tensorrt
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io.py 2026-05-12 00:26:56.733500+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io.py 2026-05-12 00:27:21.275543+00:00
@@ -19,10 +19,11 @@
* ``TorchTensorRTModule.forward`` filters aliased outputs from the user
return tuple.
* For buffer-style models, ``lift_mutated_buffers`` rewrites the EP and
``BufferThreadingModule`` threads buffers through each call.
"""
+
import torch
import torch_tensorrt
from torch.export import export
from torch.testing._internal.common_utils import TestCase, run_tests
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_index_copy_kv.py 2026-05-12 00:26:56.733500+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_index_copy_kv.py 2026-05-12 00:27:21.304553+00:00
@@ -13,10 +13,11 @@
These tests verify both paths end-to-end via the C++ runtime: the
fast path mutates in place, the fallback produces correct numerical
results without aliasing.
"""
+
import torch
import torch_tensorrt
from torch.export import export
from torch.testing._internal.common_utils import TestCase, run_tests
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_lift_mutable_buffers_api.py 2026-05-12 00:26:56.733500+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_lift_mutable_buffers_api.py 2026-05-12 00:27:21.380261+00:00
@@ -14,10 +14,11 @@
3. Construct a ``TorchTensorRTModule`` (C++ runtime — required for
aliased I/O) with the discovered bindings.
4. Thread the buffer values in on each call and verify in-place
mutation works (cache state persists across calls).
"""
+
import torch
import torch_tensorrt
from torch.export import export
from torch.testing._internal.common_utils import TestCase, run_tests
from torch_tensorrt.dynamo import convert_exported_program_to_serialized_trt_engine813e753 to
bcaf725
Compare
Description
Adds support for in-place ATen operators by extending the Torch-TensorRT compile pipeline and C++ runtime with aliased input/output bindings. The motivating case is streaming inference with a key/value cache (e.g. autoregressive decoders, ZoomASR): each step writes a single timestep into the cache, and without aliasing every step pays a full cache-size copy at the engine boundary. With aliased I/O the TensorRT engine writes directly into the user's (or module-held) cache storage; no fresh allocation, no post-engine copy.
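As a pure-PyTorch reference for the motivating pattern (shapes here are illustrative, not from the PR), each decode step writes a single timestep into a preallocated cache, so any copy at the engine boundary scales with the full cache size rather than the one-timestep update:

```python
import torch

# Toy streaming decode over a preallocated KV cache (illustrative shapes).
# Each step writes one timestep at dim=2; with aliased I/O the compiled
# engine performs this write directly in the caller's storage instead of
# returning a fresh cache-sized tensor.
B, H, S_MAX, D = 1, 4, 16, 8
cache = torch.zeros(B, H, S_MAX, D)

def step(cache: torch.Tensor, new_kv: torch.Tensor, pos: int) -> torch.Tensor:
    cache[:, :, pos : pos + 1, :] = new_kv  # one-timestep in-place write
    return cache

for pos in range(3):
    step(cache, torch.randn(B, H, 1, D), pos)
```

Eagerly this is free; the PR's point is making the compiled TRT path behave the same way.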
Two `aliased_io` "kinds" are tracked so the runtime can reason about provenance:

- `kv_cache_update` — TensorRT-enforced via `IKVCacheUpdateLayer`; reported through `ICudaEngine::getAliasedInputTensor`.
- `user` — Torch-TensorRT-declared; reserved for future expansion if TRT exposes a public non-KV aliasing API.

What this PR does
Pipeline (Python)
- `slice_scatter` and `index_copy` converters that recognize KV-cache-update patterns (4-D static cache, `dim=2`, batch=1) and emit `IKVCacheUpdateLayer` with the output aliased to the cache input. Non-eligible cases fall back to scatter in TRT — no graph break.
- For `index_copy`, two disjoint converters (validator-gated KV fast path at `ConverterPriority.HIGH` + scatter fallback at standard priority) cleanly split the cases.
- `aliased_io` plumbed through `TRTInterpreter` → `TRTInterpreterResult` → `SerializedInterpreterResult` → `TorchTensorRTModule`. The `output()` step automatically promotes layer outputs that need to be network outputs (KVCacheUpdate requires it) and appends them after user outputs. The user/side-effect boundary is derived at runtime, not stored.

Buffer-style support
- `lift_mutated_buffers` lowering pass detects `BUFFER_MUTATION` patterns (the trailing `aten.copy_(get_attr_buffer, ...)` that `ep.module()` emits) and lifts each mutated buffer from `get_attr` to `placeholder` so the engine sees it as an input binding.
- `inline_lifted_buffers_into_gm` post-compile transform registers the buffers as `nn.Module` state on the compiled `GraphModule` and rewrites the lifted placeholders to `get_attr` reads. The result is a plain `fx.GraphModule` (no custom wrapper class) that serializes cleanly through `torch_tensorrt.save` / `torch.export`.
- `convert_exported_program_to_serialized_trt_engine` gains `lift_mutable_buffers: bool = False` for power users who want to manage the resulting bindings themselves.

C++ runtime (ABI v9 → v10)
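A Python sketch of the wire format this section describes: the record layout (`output@input@kind`) is from the PR text, but the delimiter's concrete value and these signatures are assumptions, since the real helpers are C++ in `runtime_utils.cpp`.

```python
# Sketch of the aliased-I/O wire format: "output@input@kind" records
# joined by a delimiter. BINDING_DELIM's actual value is an assumption.
BINDING_DELIM = "%"

def serialize_aliased_io(aliased_io: dict[str, tuple[str, str]]) -> str:
    # aliased_io: output binding -> (source input binding, kind)
    return BINDING_DELIM.join(
        f"{out}@{inp}@{kind}" for out, (inp, kind) in aliased_io.items()
    )

def deserialize_aliased_io(blob: str) -> dict[str, tuple[str, str]]:
    result: dict[str, tuple[str, str]] = {}
    for record in blob.split(BINDING_DELIM) if blob else []:
        out, inp, kind = record.split("@")
        result[out] = (inp, kind)
    return result

table = {"cache_out": ("cache_in", "kv_cache_update")}
assert deserialize_aliased_io(serialize_aliased_io(table)) == table
```

Keeping the format delimiter-joined text (like the existing `serialize_bindings`) avoids adding a structured-serialization dependency to the runtime.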
- Bumped `ABI_VERSION` to `"10"`; added `ALIASED_IO_IDX` to `SerializedInfoIndex`.
- `serialize_aliased_io` / `deserialize_aliased_io` helpers (wire format: `output@input@kind` records joined by `BINDING_DELIM`). Helpers live in `runtime_utils.cpp` alongside `serialize_bindings`.
- `TRTEngine` constructor reconciles the build-time map against `ICudaEngine::getAliasedInputTensor` — the engine API is the source of truth for KV-style aliasing.
- `execute_engine` records bound input tensors by binding name; for each output binding in `aliased_io`, binds the same `data_ptr` as the source input and skips fresh allocation. Pre-allocated outputs are disabled when aliased I/O is present.

Docs + examples
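A pure-PyTorch sketch of the caller-owned-cache case covered in this section (the simplest pattern): the cache is an ordinary forward input and the update is functional, so the output binding can be aliased back onto the cache input. `UserCacheModel` and all shapes are illustrative, not the example file's actual code.

```python
import torch

# Illustrative caller-owned KV cache: the cache arrives as a forward
# input, the update is expressed functionally, and the updated cache is
# returned alongside the user-visible result.
class UserCacheModel(torch.nn.Module):
    def forward(self, x, cache, pos):
        new_kv = x.unsqueeze(2)                        # [B, H, 1, D]
        cache = torch.index_copy(cache, 2, pos, new_kv)
        attn = cache.mean(dim=2)                       # stand-in for attention
        return attn, cache

m = UserCacheModel()
cache = torch.zeros(1, 4, 16, 8)
attn, cache_out = m(torch.randn(1, 4, 8), cache, torch.tensor([0]))
```

With aliased I/O, the second output's storage is the caller's `cache` tensor itself rather than a new allocation.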
- `docsrc/contributors/inplace_operations.rst` — full design doc covering motivation, primitives, pipeline, runtime, serialization format, and known limitations.
- `examples/dynamo/`:
  - `aliased_io_user_inputs.py` — caller-owned cache (simplest case)
  - `aliased_io_buffers.py` — module-owned cache via `register_buffer`
  - `aliased_io_kv_attention.py` — realistic single-layer transformer attention block with static KV cache

Partially fixes #4240 (in-place custom plugins / multiple outputs — addresses the in-place-operator side; plugin-side aliased I/O is explicitly out of scope here).
Test summary
38 new tests across 8 files, all passing:
- `tests/py/dynamo/conversion/test_slice_scatter_aten.py`
- `tests/py/dynamo/runtime/test_aliased_io.py`
- `tests/py/dynamo/runtime/test_index_copy_kv.py` — `aten.index_copy`
- `tests/py/dynamo/runtime/test_lift_mutable_buffers_api.py` — `lift_mutable_buffers=True` flag round-trip (introspect engine, construct module, run)
- `tests/py/dynamo/runtime/test_aliased_io_serialization.py` — `torch_tensorrt.save`/`load` round-trip for user-input + buffer-backed + streaming buffer
- `tests/py/dynamo/runtime/test_aliased_io_cudagraphs.py`
- `tests/py/dynamo/runtime/test_hf_static_cache_xfail.py`
- `tests/py/dynamo/lowering/test_buffer_lifting.py` — `lift_mutated_buffers` + `inline_lifted_buffers_into_gm` unit tests
- HF `StaticCache` models don't compile end-to-end yet: torch.export's `run_decompositions` raises internally on the EP that `convert_and_export_with_cache` produces. The xfail test asserts the known failure so a future upstream fix surfaces as a test failure. Path forward documented in the design doc.
- `IKVCacheUpdateLayer` requires a static `s_max`. Dynamic-sequence-length cache shapes fall through to the scatter path (still correct, no aliasing).
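The "still correct, no aliasing" claim for the scatter fallback can be checked in eager PyTorch: functional `slice_scatter` produces the same values as the in-place one-timestep write, but returns a fresh tensor and leaves the input untouched.

```python
import torch

# Eager-mode check of the fallback semantics: slice_scatter matches the
# in-place write numerically, but the input cache is not mutated.
cache = torch.zeros(1, 4, 16, 8)
new_kv = torch.randn(1, 4, 1, 8)
pos = 5

out = torch.slice_scatter(cache, new_kv, dim=2, start=pos, end=pos + 1)

ref = cache.clone()
ref[:, :, pos : pos + 1, :] = new_kv
assert torch.equal(out, ref)                        # numerically identical
assert torch.equal(cache, torch.zeros_like(cache))  # input not mutated
```

The fallback therefore only costs the extra cache-sized output, not correctness.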