Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions src/compressed_tensors/offload/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import contextlib
from collections.abc import Iterable
from typing import Literal
from typing import Literal, Optional

import torch
from compressed_tensors.distributed.utils import set_source_process
Expand Down Expand Up @@ -32,6 +32,7 @@
to_meta,
)
from compressed_tensors.utils.helpers import patch_attr
from torch._prims_common import DeviceLikeType


__all__ = [
Expand Down Expand Up @@ -145,8 +146,8 @@ def update_offload_parameter(module: torch.nn.Module, name: str, data: torch.Ten


def get_execution_device(
module: torch.nn.Module, default: torch.device | None = None
) -> torch.device | Literal["disk"]:
module: torch.nn.Module, default: Optional[DeviceLikeType] = None
) -> torch.device:
"""
Get the device which inputs should be moved to before module execution.

Expand All @@ -161,7 +162,7 @@ def get_execution_device(


def get_offloaded_device(
module: torch.nn.Module, default: torch.device | None = None
module: torch.nn.Module, default: Optional[DeviceLikeType] = None
) -> torch.device | Literal["disk"]:
"""
:param module: module to check
Expand Down Expand Up @@ -236,7 +237,7 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
@contextlib.contextmanager
def align_modules(
modules: torch.nn.Module | Iterable[torch.nn.Module],
execution_device: torch.device | None = None,
execution_device: Optional[DeviceLikeType] = None,
):
"""
Context manager for onloading modules to a device, and disabling onload and offload
Expand All @@ -253,7 +254,7 @@ def align_modules(

@contextlib.contextmanager
def align_module_device(
module: torch.nn.Module, execution_device: torch.device | None = None
module: torch.nn.Module, execution_device: Optional[DeviceLikeType] = None
):
"""
Context manager that moves a module's parameters to the specified execution device.
Expand Down Expand Up @@ -286,4 +287,4 @@ def align_module_device(
finally:
for name, param in module.named_parameters(recurse=False):
device = original_device[name]
move_module_tensor(module, name, device)
move_module_tensor(module, name, device)
10 changes: 5 additions & 5 deletions src/compressed_tensors/offload/cache/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch
import torch.distributed as dist
from compressed_tensors.utils import is_accelerator_type
from torch._prims_common import DeviceLikeType


class OffloadCache(MutableMapping, ABC):
Expand All @@ -32,8 +33,8 @@ class OffloadCache(MutableMapping, ABC):
info, see `compressed_tensors.offload::(disable_offloading|disable_onloading)`
"""

onload_device: torch.device | str
offload_device: torch.device | Literal["disk"]
onload_device: DeviceLikeType
offload_device: DeviceLikeType | Literal["disk"]
Comment on lines +36 to +37
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check how onload_device is used after being stored - does any code assume it's a torch.device?
rg -n -A2 'onload_device' src/compressed_tensors/offload/ --type py | head -60

Repository: vllm-project/compressed-tensors

Length of output: 4390


🏁 Script executed:

# Find where DeviceLikeType is defined and check its definition
rg -n "DeviceLikeType" src/compressed_tensors/ --type py -B2 -A2 | head -80

Repository: vllm-project/compressed-tensors

Length of output: 5353


🏁 Script executed:

# Check the send_tensors function signature
rg -n "def send_tensors" src/compressed_tensors/ --type py -A5

Repository: vllm-project/compressed-tensors

Length of output: 543


🏁 Script executed:

# Check the original type annotations in base.py before change
git show HEAD:src/compressed_tensors/offload/cache/base.py | head -60

Repository: vllm-project/compressed-tensors

Length of output: 2316


🏁 Script executed:

# Search for the get_execution_device function mentioned in review comment
rg -n "get_execution_device" src/compressed_tensors/ --type py -B2 -A5

Repository: vllm-project/compressed-tensors

Length of output: 7067


🏁 Script executed:

# Check what __init__.py line 149 actually contains (referenced in review)
rg -n "^" src/compressed_tensors/offload/__init__.py | head -200

Repository: vllm-project/compressed-tensors

Length of output: 7278


🏁 Script executed:

# Check all usages of onload_device to see if any depend on it being torch.device
rg -n "\.onload_device" src/compressed_tensors/ --type py -B1 -A1

Repository: vllm-project/compressed-tensors

Length of output: 2514


🏁 Script executed:

# Check torch.device constructor to confirm it accepts strings
python3 << 'EOF'
import torch
# Verify that torch.device can be called with string
try:
    d = torch.device("cuda:0")
    print(f"torch.device('cuda:0') works: {d}")
except Exception as e:
    print(f"Error: {e}")
EOF

Repository: vllm-project/compressed-tensors

Length of output: 192


🏁 Script executed:

# Check what AlignDevicesHook expects for execution_device
rg -n "AlignDevicesHook" src/compressed_tensors/ --type py -B2 -A10

Repository: vllm-project/compressed-tensors

Length of output: 6674


Type annotation mismatch: get_execution_device() return type is incorrect

The function get_execution_device() declares a return type of torch.device (line 141 in __init__.py), but it directly returns module._parameters.onload_device which is typed as DeviceLikeType and can be a string. In dispatch.py (line 56), strings are assigned directly: module._parameters.onload_device = onload_device where the parameter accepts torch.device | str.

While the code works in practice because callers accept DeviceLikeType, the type contract is violated. Normalizing to torch.device in the constructor resolves this:

🔧 Suggested fix
 def __init__(
     self,
     onload_device: torch.device | str,
     offload_device: torch.device | str | Literal["disk"] | None = None,
 ):
     super().__init__()
-    self.onload_device = onload_device
+    self.onload_device = torch.device(onload_device)
     self.offloaded_values = dict()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/compressed_tensors/offload/cache/base.py` around lines 35 - 36,
get_execution_device's annotated return type (torch.device) doesn't match the
actual value returned (module._parameters.onload_device which can be a
string/DeviceLikeType); to fix, normalize and store a proper torch.device in the
object constructor so get_execution_device can safely return torch.device:
during initialization (where onload_device is accepted/assigned, e.g., the
constructor that sets module._parameters.onload_device and where dispatch.py
assigns module._parameters.onload_device = onload_device), convert the incoming
onload_device (string or torch.device) to torch.device (using
torch.device(onload_device) or equivalent) and assign that normalized
torch.device back to module._parameters.onload_device so get_execution_device
can keep its torch.device return type without type mismatch.


# global flags for disabling
offloading_disabled: ClassVar[bool] = False
Expand All @@ -47,8 +48,7 @@ class OffloadCache(MutableMapping, ABC):

@classmethod
def cls_from_device(
cls,
device: torch.device | str | Literal["disk"] | None = None,
cls, device: DeviceLikeType | Literal["disk"]
) -> type["OffloadCache"]:
"""
Get the subclass which implements offloading for the given `offload_device`.
Expand Down Expand Up @@ -274,4 +274,4 @@ def disable_onloading(cls):
yield
OffloadCache.onloading_disabled = restore_value
else:
yield
yield
5 changes: 4 additions & 1 deletion src/compressed_tensors/offload/cache/cpu.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Literal, Optional

import torch
from compressed_tensors.offload.cache.base import OffloadCache
from compressed_tensors.offload.cache.utils import catch_cpu_mem_error
from compressed_tensors.offload.utils import send_tensors
from torch._prims_common import DeviceLikeType


class CPUCache(OffloadCache):
Expand Down Expand Up @@ -45,4 +48,4 @@ def update_offload(self, offloaded: torch.Tensor, data: torch.Tensor | None):
:param data: new data to copy from
"""
if data is not None:
offloaded.copy_(data)
offloaded.copy_(data)
10 changes: 3 additions & 7 deletions src/compressed_tensors/offload/cache/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,12 @@
import torch
from compressed_tensors.offload.cache.base import OffloadCache
from compressed_tensors.offload.utils import send_tensors


if TYPE_CHECKING:
from torch._prims_common import DeviceLikeType
from torch._prims_common import DeviceLikeType


class DeviceCache(OffloadCache):
"""
Handles offloading and onloading tensors from/to device memory. Onloading
tensors is typically a no-op (except onload device has been modified).
Handles offloading and onloading tensors from/to device memory.
"""

def __init__(
Expand Down Expand Up @@ -57,4 +53,4 @@ def update_offload(self, offloaded: torch.Tensor, data: torch.Tensor | None):
:param data: new data to copy from
"""
if data is not None:
offloaded.copy_(data)
offloaded.copy_(data)
19 changes: 16 additions & 3 deletions src/compressed_tensors/offload/cache/disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from safetensors import safe_open
from safetensors.torch import save_file


if TYPE_CHECKING:
from torch._prims_common import DeviceLikeType

Expand Down Expand Up @@ -128,7 +127,7 @@ def update_offload(self, offloaded: torch.Tensor, data: torch.Tensor | None):
"""
Write new param data to file that already exists.

:param offloaded: meta tensors representating parameter to update
:param offloaded: meta tensor representing the parameter to update
:param data: new data
"""
# get weight info from index
Expand All @@ -153,6 +152,20 @@ def create_checkpoint_symlink(
weight_info: dict,
offload_dir: str | os.PathLike | None,
) -> None:
"""
Create a symlink to a checkpoint safetensors file. This symlink allows
individual tensor data to be modified and deleted without affecting
the original model checkpoint files.

When reading, the symlink redirects the read to the checkpoint file
When updating, the symlink is destroyed and a new file written to the same path
When deleting, the symlink (or new file) is destroyed

:param offloaded: meta tensor representing the parameter in the checkpoint
:param weight_info: info (typically from accelerate) pointing to checkpoint
:param offload_dir: offload directory in which to create the symlink
"""
assert offloaded.device.type == "meta"
assert (
is_source_process()
), "Must call on rank 0 to avoid id collisions between ranks"
Expand Down Expand Up @@ -191,4 +204,4 @@ def _get_safe_open_device(device: "DeviceLikeType") -> str:
index = device.index
return f"{device.type}:{index}"
else:
return device.type
return device.type
5 changes: 3 additions & 2 deletions src/compressed_tensors/offload/convert/from_accelerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from compressed_tensors.offload.dispatch import dispatch_with_map
from compressed_tensors.offload.utils import to_tensor
from loguru import logger
from torch._prims_common import DeviceLikeType


if TYPE_CHECKING:
Expand Down Expand Up @@ -88,7 +89,7 @@ def remove_accelerate(model: torch.nn.Module) -> tuple["DeviceMap", str | None]:

def remove_accelerate_from_module(
module: torch.nn.Module,
) -> tuple[torch.device | None, torch.device | Literal["disk"] | None, str | None]:
) -> tuple[DeviceLikeType | None, DeviceLikeType | Literal["disk"] | None, str | None]:
"""
Remove accelerate offloading from a module, if present.
Absolutely no device movement occurs, and parameters/buffers pointers from state
Expand Down Expand Up @@ -236,4 +237,4 @@ def _unwrap_prefixed_dataset(weights_map, PrefixedDatasetType):
def _set_or_validate_offload(current: str | None, new: str) -> str:
if current not in (None, new):
raise ValueError("Expected all accelerate tensors to share offload")
return new
return new
5 changes: 3 additions & 2 deletions src/compressed_tensors/offload/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from compressed_tensors.utils.binary_search import SearchFailureError, max_binary_search
from compressed_tensors.utils.helpers import deprecated
from loguru import logger
from torch._prims_common import DeviceLikeType
from tqdm import tqdm
from transformers import PreTrainedModel

Expand Down Expand Up @@ -104,7 +105,7 @@ def dispatch_with_map(


def get_device_map(
model: torch.nn.Module, default_device: torch.device = torch.device("cpu")
model: torch.nn.Module, default_device: DeviceLikeType = torch.device("cpu")
) -> DeviceMap:
"""
Get the device map of a CT-offloaded model
Expand Down Expand Up @@ -309,4 +310,4 @@ def _get_greedy_dispatch(
memory_remaining[device] -= size
break

return dispatch, memory_remaining
return dispatch, memory_remaining
11 changes: 6 additions & 5 deletions src/compressed_tensors/offload/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
from collections.abc import Container
from dataclasses import fields, is_dataclass
from itertools import chain
from typing import TypeVar
from typing import Optional, TypeVar

import torch
from compressed_tensors.utils.helpers import patch_attr
from loguru import logger
from torch._prims_common import DeviceLikeType


__all__ = [
Expand Down Expand Up @@ -66,7 +67,7 @@ def send_tensors(value: T, *args, **kwargs) -> T:


def get_module_device(
module: torch.nn.Module, default: torch.device | None = None
module: torch.nn.Module, default: Optional[DeviceLikeType] = None
) -> torch.device:
"""
Infer the device of a module using the first
Expand All @@ -80,7 +81,7 @@ def get_module_device(
if tensor is not None:
return tensor.device
elif default is not None:
return default
return torch.device(default)
else:
logger.warning(
f"Unable to get execution device of {module}, falling back to CPU",
Expand All @@ -92,7 +93,7 @@ def get_module_device(
def move_module_tensor(
module: torch.nn.Module,
name: str,
device: int | str | torch.device,
device: DeviceLikeType,
):
"""
Move a module's tensor to a new device
Expand Down Expand Up @@ -237,4 +238,4 @@ def as_single_threaded():
patch_attr(DistributedCPUCache, "offload", CPUCache.offload),
patch_attr(DistributedDiskCache, "offload", DiskCache.offload),
):
yield
yield
Loading