Skip to content

Commit 190bd98

Browse files
hunhoffe and claude authored
Runtime and tensor improvements (#3008)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 262b751 commit 190bd98

9 files changed

Lines changed: 718 additions & 76 deletions

File tree

CMakeLists.txt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,20 @@ if (AIE_ENABLE_XRT_PYTHON_BINDINGS)
206206
endif()
207207
endif()
208208

209+
# Detect whether PyTorch is importable by the configured Python interpreter.
210+
# Used to enable/skip lit tests that require torch (# REQUIRES: pytorch).
211+
execute_process(
212+
COMMAND ${Python3_EXECUTABLE} -c "import torch"
213+
RESULT_VARIABLE _torch_result
214+
OUTPUT_QUIET ERROR_QUIET
215+
)
216+
if(_torch_result EQUAL 0)
217+
set(AIE_ENABLE_PYTORCH ON)
218+
else()
219+
set(AIE_ENABLE_PYTORCH OFF)
220+
endif()
221+
message(STATUS "PyTorch available for testing: ${AIE_ENABLE_PYTORCH}")
222+
209223
cmake_dependent_option(AIECC_COMPILE
210224
"Set aiecc to compile." ON "NOT AIE_COMPILER STREQUAL NONE" OFF)
211225

programming_examples/lit.cfg.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,9 @@
7070
config.opencv_include_dir, config.opencv_lib_dir, config.opencv_libs
7171
)
7272

73-
try:
74-
import torch
75-
73+
if config.pytorch:
7674
config.available_features.add("torch")
77-
except ImportError:
78-
print("torch not found", file=sys.stderr)
79-
pass
75+
config.available_features.add("pytorch")
8076

8177
# Setup host target triplet and sysroot
8278
triplet, sysroot_flag = LitConfigHelper.setup_host_target_triplet(

programming_examples/lit.site.cfg.py.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ config.test_exec_root = r"""@CMAKE_CURRENT_BINARY_DIR@"""
5353

5454
config.hsa_dir = r"""@hsa-runtime64_DIR@"""
5555
config.hsa_found = lit.util.pythonize_bool(r"""@hsa-runtime64_FOUND@""")
56+
config.pytorch = lit.util.pythonize_bool(r"""@AIE_ENABLE_PYTORCH@""")
5657

5758
# pass on vitis settings
5859
config.enable_chess_tests = @CONFIG_ENABLE_CHESS_TESTS@

python/utils/hostruntime/tensor_class.py

Lines changed: 89 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,80 @@
99
from functools import cached_property
1010
import numpy as np
1111

12+
# Mapping from ml_dtypes (non-native numpy) types to their torch equivalents.
13+
# Native numpy dtypes (float32, int32, …) are handled directly by torch.from_numpy
14+
# and do not need an entry here.
15+
# Populated lazily at first use to avoid importing torch/ml_dtypes at module load.
16+
_ML_DTYPE_TO_TORCH: dict | None = None
17+
18+
19+
def _ml_dtype_to_torch_map():
20+
global _ML_DTYPE_TO_TORCH
21+
if _ML_DTYPE_TO_TORCH is None:
22+
import torch
23+
import ml_dtypes
24+
25+
_candidates = {
26+
ml_dtypes.bfloat16: torch.bfloat16,
27+
}
28+
for attr in (
29+
"float8_e4m3fn",
30+
"float8_e5m2",
31+
"float8_e4m3fnuz",
32+
"float8_e5m2fnuz",
33+
):
34+
ml_dt = getattr(ml_dtypes, attr, None)
35+
torch_dt = getattr(torch, attr, None)
36+
if ml_dt is not None and torch_dt is not None:
37+
_candidates[ml_dt] = torch_dt
38+
_ML_DTYPE_TO_TORCH = {
39+
np.dtype(ml_dt): torch_dt for ml_dt, torch_dt in _candidates.items()
40+
}
41+
return _ML_DTYPE_TO_TORCH
42+
43+
44+
# Same-width unsigned integer dtype for the ND reinterpret-view trick.
45+
_UINT_VIEW_DTYPE = {
46+
1: np.uint8,
47+
2: np.uint16,
48+
4: np.uint32,
49+
8: np.uint64,
50+
}
51+
52+
53+
def _array_to_torch(array: np.ndarray):
54+
"""
55+
Convert a numpy array to a torch tensor, zero-copy.
56+
57+
For native numpy dtypes (float32, float16, int32, …) torch.from_numpy is used directly
58+
(fastest path for these types).
59+
60+
For ml_dtypes types (bfloat16, float8_*) that torch cannot consume via from_numpy:
61+
reinterpret as a same-width unsigned integer numpy view, wrap with from_numpy,
62+
then view as the target torch dtype. This is guaranteed zero-copy for all ranks.
63+
64+
Raises:
65+
ImportError: If torch is not installed.
66+
"""
67+
# _ml_dtype_to_torch_map() imports torch (raising ImportError with a helpful message
68+
# if absent) and returns the ml_dtype -> torch dtype mapping.
69+
torch_dtype = _ml_dtype_to_torch_map().get(array.dtype)
70+
import torch # already imported by _ml_dtype_to_torch_map(); cached by Python
71+
72+
if torch_dtype is None:
73+
# Native numpy dtype: torch.from_numpy handles it directly and fastest.
74+
return torch.from_numpy(array)
75+
76+
# ml_dtype: reinterpret memory as a same-width uint, then view as the torch dtype.
77+
uint_dtype = _UINT_VIEW_DTYPE[array.dtype.itemsize]
78+
return torch.from_numpy(array.view(uint_dtype)).view(torch_dtype)
79+
1280

1381
class Tensor(ABC):
1482
"""
1583
Tensor object backed by NPU or CPU memory.
1684
17-
The class provides commom tensor operations such as creation,
85+
The class provides common tensor operations such as creation,
1886
filling with values, and accessing data.
1987
2088
"""
@@ -258,28 +326,33 @@ def to_torch(self):
258326
"""
259327
Returns a torch tensor sharing the data in this tensor if possible.
260328
329+
Syncs from device first if the tensor is on the NPU.
330+
261331
Returns:
262332
torch.Tensor: A torch tensor containing the data.
263333
264334
Raises:
265335
ImportError: If torch is not installed.
266336
"""
267-
try:
268-
import torch
269-
from ml_dtypes import bfloat16
270-
except ImportError:
271-
raise ImportError(
272-
"torch is not installed. Please install it with 'pip install torch'"
273-
)
337+
return _array_to_torch(self.numpy())
274338

275-
array = self.numpy()
339+
def torch_view(self):
340+
"""
341+
Returns a torch tensor sharing this buffer's host memory without syncing from device.
276342
277-
if array.dtype == bfloat16:
278-
# reinterpret the same memory as int16, then view as torch.bfloat16
279-
t_u16 = torch.from_numpy(array.view(np.uint16))
280-
return t_u16.view(torch.bfloat16)
343+
Unlike to_torch(), this does NOT sync from the NPU first. Marks the buffer as
344+
CPU-resident so that a subsequent .to("npu") call (or the NPU operator's implicit
345+
sync) will push the written data to device. Use this on write paths where the
346+
caller is about to overwrite the buffer contents.
281347
282-
return torch.from_numpy(array)
348+
Returns:
349+
torch.Tensor: A zero-copy torch tensor view of the host-side buffer.
350+
351+
Raises:
352+
ImportError: If torch is not installed.
353+
"""
354+
self.device = "cpu" # mark dirty so next to("npu") will actually sync
355+
return _array_to_torch(self.data)
283356

284357
@classmethod
285358
def from_torch(cls, torch_tensor, device=None, **kwargs):
@@ -297,13 +370,8 @@ def from_torch(cls, torch_tensor, device=None, **kwargs):
297370
Raises:
298371
ImportError: If torch is not installed.
299372
"""
300-
try:
301-
import torch
302-
from ml_dtypes import bfloat16
303-
except ImportError:
304-
raise ImportError(
305-
"torch is not installed. Please install it with 'pip install torch'"
306-
)
373+
import torch
374+
from ml_dtypes import bfloat16
307375

308376
# Detach (to drop grad) and ensure on CPU
309377
t = torch_tensor.detach()

python/utils/hostruntime/xrtruntime/hostruntime.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,10 @@ def load(
204204
raise RuntimeError("No kernels found in xclbin")
205205
kernel_name = kernels[0].get_name()
206206
else:
207-
if not kernel_name in [k.get_name() for k in xclbin.get_kernels()]:
207+
available_kernels = [k.get_name() for k in xclbin.get_kernels()]
208+
if kernel_name not in available_kernels:
208209
raise HostRuntimeError(
209-
f"Kernel {kernel_name} not found in xclbin (kernels found: {[k.get_name() for k in xclbin.get_kernels()]})"
210+
f"Kernel {kernel_name} not found in xclbin (kernels found: {available_kernels})"
210211
)
211212

212213
insts = self.read_insts(insts_path)
@@ -399,7 +400,6 @@ def cleanup(self):
399400
gc.collect() # Make sure contexts are garbage collected.
400401

401402
def _cleanup_entry(self, entry):
402-
context = entry["context"]
403403
handles = entry["handles"]
404404

405405
# Invalidate all handles
@@ -408,17 +408,24 @@ def _cleanup_entry(self, entry):
408408
if handle:
409409
handle.invalidate()
410410

411-
# Explicitly delete context
412-
del context
411+
# Clear kernel cache so pyxrt.kernel objects are released with the context
412+
entry["kernels"].clear()
413+
414+
# Release the hw_context by removing its strong reference from the entry dict.
415+
# Simply assigning a local `context = entry["context"]` and then `del context`
416+
# only removes the local name — entry["context"] would keep the object alive for
417+
# as long as any caller holds a reference to the entry dict (e.g. tests, or the
418+
# exception handler). Deleting the key guarantees the refcount drops here.
419+
del entry["context"]
413420

414421
def _evict(self):
415422
# Pop the oldest item
416423
key, entry = self._context_cache.popitem(last=False)
417424
self._cleanup_entry(entry)
418425

419426
def _cleanup_insts_entry(self, entry):
420-
insts_bo = entry["insts_bo"]
421-
del insts_bo
427+
# Delete the key (not a local copy) so the refcount drops here.
428+
del entry["insts_bo"]
422429

423430
def _evict_insts(self):
424431
key, entry = self._insts_cache.popitem(last=False)
@@ -542,6 +549,7 @@ def load(
542549
entry = {
543550
"context": context,
544551
"xclbin": xclbin,
552+
"kernels": {}, # kernel_name -> pyxrt.kernel (strong ref, tied to context)
545553
"handles": [],
546554
"uuid": xclbin_uuid,
547555
}
@@ -554,17 +562,25 @@ def load(
554562
raise RuntimeError("No kernels found in xclbin")
555563
kernel_name = kernels[0].get_name()
556564
else:
557-
if not kernel_name in [k.get_name() for k in xclbin.get_kernels()]:
565+
available_kernels = [k.get_name() for k in xclbin.get_kernels()]
566+
if kernel_name not in available_kernels:
558567
raise HostRuntimeError(
559-
f"Kernel {kernel_name} not found in xclbin (kernels found: {[k.get_name() for k in xclbin.get_kernels()]})"
568+
f"Kernel {kernel_name} not found in xclbin (kernels found: {available_kernels})"
560569
)
561570

562571
insts = self.read_insts(insts_path)
563572
insts_bo = None
564573
if hasattr(pyxrt, "module") and isinstance(insts, pyxrt.module):
565-
kernel = pyxrt.ext.kernel(context, insts, kernel_name)
574+
ext_kernel_key = (kernel_name, str(insts_path), insts_mtime)
575+
if ext_kernel_key not in entry["kernels"]:
576+
entry["kernels"][ext_kernel_key] = pyxrt.ext.kernel(
577+
context, insts, kernel_name
578+
)
579+
kernel = entry["kernels"][ext_kernel_key]
566580
else:
567-
kernel = pyxrt.kernel(context, kernel_name)
581+
if kernel_name not in entry["kernels"]:
582+
entry["kernels"][kernel_name] = pyxrt.kernel(context, kernel_name)
583+
kernel = entry["kernels"][kernel_name]
568584

569585
# Magic number for RyzenAI group id that will be fixed in the future. See same code at XRT:
570586
# https://github.com/Xilinx/XRT/blob/56222ed5cfd119dff0d5bd920735b87024e8c829/src/runtime_src/core/common/api/xrt_module.cpp#L1621

test/lit.cfg.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,8 @@
188188
if config.has_mlir_runtime_libraries:
189189
config.available_features.add("has_mlir_runtime_libraries")
190190

191+
if config.pytorch:
192+
config.available_features.add("pytorch")
191193

192194
if "LIT_AVAILABLE_FEATURES" in os.environ:
193195
for feature in os.environ["LIT_AVAILABLE_FEATURES"].split():

test/lit.site.cfg.py.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ config.enable_python_tests = lit.util.pythonize_bool(r"""@ENABLE_PYTHON_TESTS@""
5959
config.python_passes = lit.util.pythonize_bool(r"""@AIE_ENABLE_PYTHON_PASSES@""")
6060
config.xrt_python_bindings = lit.util.pythonize_bool(r"""@AIE_ENABLE_XRT_PYTHON_BINDINGS@""")
6161
config.has_mlir_runtime_libraries = lit.util.pythonize_bool(r"""@HAS_MLIR_RUNTIME_LIBRARIES@""")
62+
config.pytorch = lit.util.pythonize_bool(r"""@AIE_ENABLE_PYTORCH@""")
6263

6364
# pass on vitis settings
6465
config.vitis_root = r"""@VITIS_ROOT@"""

0 commit comments

Comments
 (0)