48 changes: 43 additions & 5 deletions aq_engine.py
@@ -105,6 +105,7 @@ def quantize(self, *, args: Namespace, verbose: bool = True) -> QuantizedWeight:
)
return self.quantized_weight

### _compute_mse modified for Drop-by-Drop
def _compute_mse(self, selection: Union[slice, ellipsis] = ...) -> torch.Tensor:
"""
Compute the activation MSE error = ||X @ quantized_weight - X @ reference_weight||^2
@@ -114,21 +115,58 @@ def _compute_mse(self, selection: Union[slice, ellipsis] = ...) -> torch.Tensor:
The indices / slices must correspond to output channels (if out_group_size==1) or groups (if > 1).
Formally, the indices must be in range [ 0 , self.out_features // self.out_group_size )
"""
# NOTE: the original single-pass MSE over the full reconstruction is replaced below by a weighted sum of per-prefix MSEs.

assert self.quantized_weight is not None, "must be called inside / after AQUtil.quantize"
quantized_weight = self.quantized_weight(selection)

if isinstance(selection, ellipsis):
reference_weight = self.layer.weight.detach().to(self.quantized_weight.codebooks.dtype)
else:
assert isinstance(selection, slice)
out_channel_selection = slice(
selection.start * self.quantized_weight.out_group_size,
selection.stop * self.quantized_weight.out_group_size,
)
reference_weight = self.layer.weight.detach()[out_channel_selection].to(self.quantized_weight.codebooks.dtype)

total_codebooks = self.quantized_weight.num_codebooks

# Example weighting ("35W", cf. the MAT35 run in quantize.sh): only the 3- and 5-codebook
# prefixes contribute, each with weight 0.5. One entry per codebook (length must equal total_codebooks).
codebook_weights = torch.tensor([0, 0, 0.5, 0, 0.5], device=self.device, dtype=self.XTX.dtype)

total_loss = torch.tensor(0.0, device=self.device, dtype=self.XTX.dtype)

# Inspired by Matryoshka Representation Learning
for i in range(1, total_codebooks + 1):

quantized_weight_i = self.quantized_weight(selection, num_codebooks=i)

delta_weight = (quantized_weight_i - reference_weight).to(self.XTX.dtype)
mse_i = (delta_weight @ self.XTX).flatten() @ delta_weight.flatten() / self.quantized_weight.out_features

# Ensure all tensors are on the same device before computation
mse_i = mse_i.to(self.device)
codebook_weight = codebook_weights[i-1].to(self.device)

total_loss = total_loss + codebook_weight * mse_i

return total_loss


def _replace_and_compute_mse(self, params_to_replace: nn.ParameterDict, selection: slice) -> torch.Tensor:
"""Utility for parallelism: replace the specified parameters of self.quantized_weight, then compute MSE"""
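For intuition, here is a minimal self-contained sketch of the weighted multi-prefix (Matryoshka-style) objective computed in `_compute_mse` above. Everything in it is hypothetical and illustrative: `weighted_prefix_mse`, the toy shapes, and the identity stand-in for XTX are not part of the PR's code.

import torch

def weighted_prefix_mse(prefix_weights, reference, XTX, codebook_weights):
    # One MSE term per codebook-prefix reconstruction, combined with fixed weights.
    total = torch.zeros((), dtype=XTX.dtype)
    for w_i, weight_i in zip(prefix_weights, codebook_weights):
        delta = (w_i - reference).to(XTX.dtype)
        # ||X @ delta^T||^2 = tr(delta @ X^T X @ delta^T), evaluated via the cached XTX
        mse_i = (delta @ XTX).flatten() @ delta.flatten() / reference.shape[0]
        total = total + weight_i * mse_i
    return total

reference = torch.randn(4, 8)  # [out_features, in_features]
XTX = torch.eye(8)             # stand-in for the cached X^T X
# coarser prefixes (fewer codebooks) deviate more from the reference
prefixes = [reference + 0.1 * k * torch.randn(4, 8) for k in (3, 2, 1)]
loss = weighted_prefix_mse(prefixes, reference, XTX, [0.0, 0.5, 0.5])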
40 changes: 40 additions & 0 deletions quantize.sh
@@ -0,0 +1,40 @@
#!/bin/bash
#SBATCH --job-name=quantize
#SBATCH --output=slurm-%j-MAT35-gemma2b.out
#SBATCH --error=slurm-%j-MAT35-gemma2b.err

export CUDA_VISIBLE_DEVICES=0,1,2,3
export MODEL_PATH=/model-weights/gemma-2b/
export DATASET_PATH=wikitext2
export SAVE_PATH=ANONYMOUS

python main.py $MODEL_PATH $DATASET_PATH \
--nsamples=1024 \
--val_size=32 \
--num_codebooks=5 \
--nbits_per_codebook=8 \
--in_group_size=8 \
--relative_mse_tolerance=0.01 \
--finetune_batch_size=32 \
--finetune_max_epochs=10 \
--finetune_early_stop=3 \
--finetune_keep_best \
--local_batch_size=1 \
--offload_activations \
--resume \
--save $SAVE_PATH
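
# Rough footprint under these flags (assuming AQLM-style accounting): each group of
# in_group_size=8 weights is encoded by num_codebooks=5 codes of nbits_per_codebook=8 bits,
# i.e. 5 * 8 / 8 = 5 bits per weight, before codebook and scale overhead.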

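# Same configuration, but loading an already-quantized model (--load) instead of quantizing and saving: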
# python main.py $MODEL_PATH $DATASET_PATH \
# --nsamples=1024 \
# --val_size=32 \
# --num_codebooks=5 \
# --nbits_per_codebook=8 \
# --in_group_size=8 \
# --relative_mse_tolerance=0.01 \
# --finetune_batch_size=32 \
# --finetune_max_epochs=10 \
# --finetune_early_stop=3 \
# --finetune_keep_best \
# --local_batch_size=1 \
# --offload_activations \
# --load ANONYMOUS
10 changes: 10 additions & 0 deletions run.sh
@@ -0,0 +1,10 @@
# Run the following in a terminal:
sbatch \
--account=ANONYMOUS \
--nodes=1 \
--gres=gpu:l40s:4 \
--ntasks-per-node=1 \
--mem=120G \
--cpus-per-task=4 \
--time=40:00:00 \
quantize.sh
18 changes: 15 additions & 3 deletions src/aq.py
@@ -197,16 +197,28 @@ def get_scales(self) -> torch.Tensor:
def shape(self) -> Tuple[int, int]:
return self.out_features, self.in_features

def forward(self, selection: Union[slice, ellipsis, torch.Tensor] = ..., num_codebooks: Optional[int] = None):
"""
Differentably reconstruct the weight (or parts thereof) from compressed components
:param selection: By default, reconstruct the entire weight. If selection is specified, this method will instead
reconstruct a portion of weight for the corresponding output dimensions (used for parallelism).
The indices / slices must correspond to output channels (if out_group_size==1) or groups (if > 1).
Formally, the indices must be in range [ 0 , self.out_features // self.out_group_size )

:param num_codebooks: number of codebooks to use for reconstruction. If None, all codebooks
are used; otherwise only the first `num_codebooks` codebooks are used.
"""
# Clamp num_codebooks to the number of codebooks actually available
if num_codebooks is not None:
num_codebooks = min(num_codebooks, self.num_codebooks)

# For Drop-by-Drop inference, set num_codebooks (e.g. num_codebooks = 3) to simulate "dropping"
# codebooks without any additional retraining or finetuning: just load the quantized model
# via $SAVE_PATH in the shell script.
weight = _dequantize_weight(
self.get_codes()[selection],
self.get_codebooks(),
self.get_scales()[selection],
num_codebooks=num_codebooks,
)
return weight

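# Usage sketch (hypothetical instance `qw` of this QuantizedWeight module):
#   full_weight = qw()                   # reconstruct from all codebooks
#   coarse_weight = qw(num_codebooks=3)  # first 3 codebooks only, coarser reconstruction
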
@torch.no_grad()
101 changes: 82 additions & 19 deletions src/utils.py
@@ -61,35 +61,98 @@ def maybe_script(fn: callable) -> callable:
return torch.jit.script(fn) if should_script else fn


@maybe_script
def _dequantize_weight(
codes: torch.Tensor, codebooks: torch.Tensor, scales: Optional[torch.Tensor] = None
codes: torch.Tensor, codebooks: torch.Tensor, scales: Optional[torch.Tensor] = None, num_codebooks: Optional[int] = None
) -> torch.Tensor:
"""
Decode float weights from quantization codes. Differentiable.
:param codes: tensor of integer quantization codes, shape [*dims, num_out_groups, num_in_groups, num_codebooks]
:param codebooks: tensor of vectors for each quantization code, [num_codebooks, codebook_size, out_group_size, in_group_size]
:param scales: weight will be multiplied by this factor, must be broadcastable with [*dims, num_out_groups, num_in_groups, out_group_size, in_group_size]
:param num_codebooks: number of codebooks to use; if None, all available codebooks are used
:return: reconstructed weight tensor of shape [*dims, out_features, in_features]
"""
# Use all available codebooks, or only the first `num_codebooks` of them
available_codebooks = codebooks.shape[0]
effective_num_codebooks = available_codebooks if num_codebooks is None else min(num_codebooks, available_codebooks)

num_out_groups, num_in_groups, _ = codes.shape[-3:]
_, codebook_size, out_group_size, in_group_size = codebooks.shape
out_features = num_out_groups * out_group_size
in_features = num_in_groups * in_group_size

# Keep only the leading codebooks; this slice is a no-op when all codebooks are used
selected_codes = codes[..., :effective_num_codebooks]
selected_codebooks = codebooks[:effective_num_codebooks]

codebook_offsets = torch.arange(
0, effective_num_codebooks * codebook_size, codebook_size, device=codes.device
) # shape: [effective_num_codebooks]
reconstructed_weight_flat = F.embedding_bag(
selected_codes.flatten(0, -2) + codebook_offsets,
selected_codebooks.flatten(0, 1).flatten(-2, -1),
mode="sum",
) # [prod(dims) * num_out_groups * num_in_groups, out_group_size * in_group_size]

reconstructed_weight_groupwise = reconstructed_weight_flat.view(
list(codes.shape[:-3]) + [num_out_groups, num_in_groups, out_group_size, in_group_size]
)
if scales is not None:
reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(scales)
return reconstructed_weight_groupwise.swapaxes(-3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features])

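# A quick self-check sketch of the prefix behaviour (toy shapes and values are hypothetical,
# illustrative only): reconstructing from the first k codebooks should equal the sum of the
# first k per-codebook contributions.
if __name__ == "__main__":
    torch.manual_seed(0)
    codes = torch.randint(16, (6, 5, 4))  # [num_out_groups, num_in_groups, num_codebooks]
    codebooks = torch.randn(4, 16, 1, 8)  # [num_codebooks, codebook_size, out_group_size, in_group_size]
    partial = _dequantize_weight(codes, codebooks, None, 2)
    manual = _dequantize_weight(codes[..., :1], codebooks[:1]) + _dequantize_weight(codes[..., 1:2], codebooks[1:2])
    assert torch.allclose(partial, manual)
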
@contextlib.contextmanager
def using_tf32(enabled: bool):