99from functools import cached_property
1010import numpy as np
1111
12+ # Mapping from ml_dtypes (non-native numpy) types to their torch equivalents.
13+ # Native numpy dtypes (float32, int32, …) are handled directly by torch.from_numpy
14+ # and do not need an entry here.
15+ # Populated lazily at first use to avoid importing torch/ml_dtypes at module load.
16+ _ML_DTYPE_TO_TORCH : dict | None = None
17+
18+
19+ def _ml_dtype_to_torch_map ():
20+ global _ML_DTYPE_TO_TORCH
21+ if _ML_DTYPE_TO_TORCH is None :
22+ import torch
23+ import ml_dtypes
24+
25+ _candidates = {
26+ ml_dtypes .bfloat16 : torch .bfloat16 ,
27+ }
28+ for attr in (
29+ "float8_e4m3fn" ,
30+ "float8_e5m2" ,
31+ "float8_e4m3fnuz" ,
32+ "float8_e5m2fnuz" ,
33+ ):
34+ ml_dt = getattr (ml_dtypes , attr , None )
35+ torch_dt = getattr (torch , attr , None )
36+ if ml_dt is not None and torch_dt is not None :
37+ _candidates [ml_dt ] = torch_dt
38+ _ML_DTYPE_TO_TORCH = {
39+ np .dtype (ml_dt ): torch_dt for ml_dt , torch_dt in _candidates .items ()
40+ }
41+ return _ML_DTYPE_TO_TORCH
42+
43+
# Same-width unsigned integer dtype for the ND reinterpret-view trick,
# keyed by itemsize in bytes.
_UINT_VIEW_DTYPE = {
    1: np.uint8,
    2: np.uint16,
    4: np.uint32,
    8: np.uint64,
}
51+
52+
def _array_to_torch(array: np.ndarray):
    """
    Convert a numpy array to a torch tensor, zero-copy.

    For native numpy dtypes (float32, float16, int32, …) torch.from_numpy is used
    directly (fastest path for these types).

    For ml_dtypes types (bfloat16, float8_*) that torch cannot consume via
    from_numpy: reinterpret as a same-width unsigned integer numpy view, wrap with
    from_numpy, then view as the target torch dtype. This is guaranteed zero-copy
    for all ranks.

    Args:
        array: Source array; its memory is shared with the returned tensor.

    Returns:
        torch.Tensor: A zero-copy tensor over the array's buffer.

    Raises:
        ImportError: If torch is not installed.
    """
    # Import torch up front so a missing installation raises the documented,
    # actionable message instead of a bare "No module named 'torch'".
    try:
        import torch
    except ImportError:
        raise ImportError(
            "torch is not installed. Please install it with 'pip install torch'"
        )

    torch_dtype = _ml_dtype_to_torch_map().get(array.dtype)
    if torch_dtype is None:
        # Native numpy dtype: torch.from_numpy handles it directly and fastest.
        return torch.from_numpy(array)

    # ml_dtype: reinterpret memory as a same-width uint, then view as the torch dtype.
    uint_dtype = _UINT_VIEW_DTYPE[array.dtype.itemsize]
    return torch.from_numpy(array.view(uint_dtype)).view(torch_dtype)
79+
1280
1381class Tensor (ABC ):
1482 """
1583 Tensor object backed by NPU or CPU memory.
1684
    The class provides common tensor operations such as creation,
1886 filling with values, and accessing data.
1987
2088 """
@@ -258,28 +326,33 @@ def to_torch(self):
258326 """
259327 Returns a torch tensor sharing the data in this tensor if possible.
260328
329+ Syncs from device first if the tensor is on the NPU.
330+
261331 Returns:
262332 torch.Tensor: A torch tensor containing the data.
263333
264334 Raises:
265335 ImportError: If torch is not installed.
266336 """
267- try :
268- import torch
269- from ml_dtypes import bfloat16
270- except ImportError :
271- raise ImportError (
272- "torch is not installed. Please install it with 'pip install torch'"
273- )
337+ return _array_to_torch (self .numpy ())
274338
275- array = self .numpy ()
339+ def torch_view (self ):
340+ """
341+ Returns a torch tensor sharing this buffer's host memory without syncing from device.
276342
277- if array . dtype == bfloat16 :
278- # reinterpret the same memory as int16, then view as torch.bfloat16
279- t_u16 = torch . from_numpy ( array . view ( np . uint16 ))
280- return t_u16 . view ( torch . bfloat16 )
343+ Unlike to_torch(), this does NOT sync from the NPU first. Marks the buffer as
344+ CPU-resident so that a subsequent .to("npu") call (or the NPU operator's implicit
345+ sync) will push the written data to device. Use this on write paths where the
346+ caller is about to overwrite the buffer contents.
281347
282- return torch .from_numpy (array )
348+ Returns:
349+ torch.Tensor: A zero-copy torch tensor view of the host-side buffer.
350+
351+ Raises:
352+ ImportError: If torch is not installed.
353+ """
354+ self .device = "cpu" # mark dirty so next to("npu") will actually sync
355+ return _array_to_torch (self .data )
283356
284357 @classmethod
285358 def from_torch (cls , torch_tensor , device = None , ** kwargs ):
@@ -297,13 +370,8 @@ def from_torch(cls, torch_tensor, device=None, **kwargs):
297370 Raises:
298371 ImportError: If torch is not installed.
299372 """
300- try :
301- import torch
302- from ml_dtypes import bfloat16
303- except ImportError :
304- raise ImportError (
305- "torch is not installed. Please install it with 'pip install torch'"
306- )
373+ import torch
374+ from ml_dtypes import bfloat16
307375
308376 # Detach (to drop grad) and ensure on CPU
309377 t = torch_tensor .detach ()
0 commit comments