diff --git a/aq_engine.py b/aq_engine.py index ea14eec..00c6467 100644 --- a/aq_engine.py +++ b/aq_engine.py @@ -105,6 +105,7 @@ def quantize(self, *, args: Namespace, verbose: bool = True) -> QuantizedWeight: ) return self.quantized_weight + ### modified _compute_mse for DropbyDrop def _compute_mse(self, selection: Union[slice, ellipsis] = ...) -> torch.Tensor: """ Compute the activation MSE error = ||X @ quantized_weight - X @ reference_weight||^2 @@ -114,21 +115,58 @@ def _compute_mse(self, selection: Union[slice, ellipsis] = ...) -> torch.Tensor: The indices / slices must correspond to output channels (if out_group_size==1) or groups (if > 1). Formally, the indices must be in range [ 0 , self.out_features // self.out_group_size ) """ + # assert self.quantized_weight is not None, "must be called inside / after AQUtil.quantize" + # quantized_weight = self.quantized_weight(selection) + + # if isinstance(selection, ellipsis): + # reference_weight = self.layer.weight.detach().to(quantized_weight.dtype) + # else: + # assert isinstance(selection, slice) + # out_channel_selection = slice( + # selection.start * self.quantized_weight.out_group_size, + # selection.stop * self.quantized_weight.out_group_size, + # ) + + # reference_weight = self.layer.weight.detach()[out_channel_selection].to(quantized_weight.dtype) + # delta_weight = (quantized_weight - reference_weight).to(self.XTX.dtype) + # return (delta_weight @ self.XTX).flatten() @ delta_weight.flatten() / self.quantized_weight.out_features + assert self.quantized_weight is not None, "must be called inside / after AQUtil.quantize" - quantized_weight = self.quantized_weight(selection) if isinstance(selection, ellipsis): - reference_weight = self.layer.weight.detach().to(quantized_weight.dtype) + reference_weight = self.layer.weight.detach().to(self.quantized_weight.codebooks.dtype) else: assert isinstance(selection, slice) out_channel_selection = slice( selection.start * self.quantized_weight.out_group_size, selection.stop * self.quantized_weight.out_group_size, ) + reference_weight = self.layer.weight.detach()[out_channel_selection].to(self.quantized_weight.codebooks.dtype) + + total_codebooks = self.quantized_weight.num_codebooks + + # EXAMPLE - 35W + codebook_weights = torch.tensor([0,0, 0.5, 0, 0.5], device=self.device, dtype=self.XTX.dtype) + + total_loss = torch.tensor(0.0, device=self.device, dtype=self.XTX.dtype) + + # Inspired by Matryoshka Representation Learning + for i in range(1, total_codebooks + 1): + + quantized_weight_i = self.quantized_weight(selection, num_codebooks=i) + + delta_weight = (quantized_weight_i - reference_weight).to(self.XTX.dtype) + mse_i = (delta_weight @ self.XTX).flatten() @ delta_weight.flatten() / self.quantized_weight.out_features + + # Ensure all tensors are on the same device before computation + mse_i = mse_i.to(self.device) + codebook_weight = codebook_weights[i-1].to(self.device) + + #total_loss = total_loss + mse_i + total_loss = total_loss + codebook_weight* mse_i + + return total_loss - reference_weight = self.layer.weight.detach()[out_channel_selection].to(quantized_weight.dtype) - delta_weight = (quantized_weight - reference_weight).to(self.XTX.dtype) - return (delta_weight @ self.XTX).flatten() @ delta_weight.flatten() / self.quantized_weight.out_features def _replace_and_compute_mse(self, params_to_replace: nn.ParameterDict, selection: slice) -> torch.Tensor: """Utility for parallelism: replace the specified parameters of self.quantized_weight, then compute MSE""" diff --git a/quantize.sh b/quantize.sh new file mode 100644 index 0000000..8e8eb35 --- /dev/null +++ b/quantize.sh @@ -0,0 +1,40 @@ +#!/bin/bash +#SBATCH --job-name=quantize +#SBATCH --output=slurm-%j-MAT35-gemma2b.out +#SBATCH --error=slurm-%j-MAT35-gemma2b.err + +export CUDA_VISIBLE_DEVICES=0,1,2,3 +export MODEL_PATH=/model-weights/gemma-2b/ +export DATASET_PATH=wikitext2 +export SAVE_PATH=ANONYMOUS + +python main.py $MODEL_PATH $DATASET_PATH \ + --nsamples=1024 \ + --val_size=32 \ + --num_codebooks=5 \ + --nbits_per_codebook=8 \ + --in_group_size=8 \ + --relative_mse_tolerance=0.01 \ + --finetune_batch_size=32 \ + --finetune_max_epochs=10 \ + --finetune_early_stop=3 \ + --finetune_keep_best \ + --local_batch_size=1 \ + --offload_activations \ + --resume \ + --save $SAVE_PATH + +# python main.py $MODEL_PATH $DATASET_PATH \ +# --nsamples=1024 \ +# --val_size=32 \ +# --num_codebooks=5 \ +# --nbits_per_codebook=8 \ +# --in_group_size=8 \ +# --relative_mse_tolerance=0.01 \ +# --finetune_batch_size=32 \ +# --finetune_max_epochs=10 \ +# --finetune_early_stop=3 \ +# --finetune_keep_best \ +# --local_batch_size=1 \ +# --offload_activations \ +# --load ANONYMOUS diff --git a/run.sh b/run.sh new file mode 100644 index 0000000..3c95fdd --- /dev/null +++ b/run.sh @@ -0,0 +1,10 @@ +# run the following in terminal: +sbatch \ + --account=ANONYMOUS \ + --nodes=1 \ + --gres=gpu:l40s:4\ + --ntasks-per-node=1 \ + --mem=120G \ + --cpus-per-task=4 \ + --time=40:00:00 \ + quantize.sh diff --git a/src/aq.py b/src/aq.py index df2f681..1585f3d 100644 --- a/src/aq.py +++ b/src/aq.py @@ -197,16 +197,28 @@ def get_scales(self) -> torch.Tensor: def shape(self) -> Tuple[int, int]: return self.out_features, self.in_features - def forward(self, selection: Union[slice, ellipsis, torch.Tensor] = ...): + def forward(self, selection: Union[slice, ellipsis, torch.Tensor] = ..., num_codebooks: Optional[int] = None): """ Differentably reconstruct the weight (or parts thereof) from compressed components :param selection: By default, reconstruct the entire weight. If selection is specified, this method will instead reconstruct a portion of weight for the corresponding output dimensions (used for parallelism). The indices / slices must correspond to output channels (if out_group_size==1) or groups (if > 1). Formally, the indices must be in range [ 0 , self.out_features // self.out_group_size ) - + :param num_codebooks: Number of codebooks to use for reconstruction. If None, all codebooks are used. + If specified, only the first `num_codebooks` will be used. """ - weight = _dequantize_weight(self.get_codes()[selection], self.get_codebooks(), self.get_scales()[selection]) + # 检查 num_codebooks 参数 + if num_codebooks is not None: + num_codebooks = min(num_codebooks, self.num_codebooks) + + # FOR DROP-BY-DROP's INFERENCE, modify num_codebooks (i.e. num_codebooks = 3) to simulate "dropping" of codebooks without any + # additional retraining or finetuning. Just load the quantized model through $SAVE_PATH in the shell script. + weight = _dequantize_weight( + self.get_codes()[selection], + self.get_codebooks(), + self.get_scales()[selection], + num_codebooks + ) return weight @torch.no_grad() diff --git a/src/utils.py b/src/utils.py index 9dd77ed..e3ee3f4 100644 --- a/src/utils.py +++ b/src/utils.py @@ -61,35 +61,98 @@ def maybe_script(fn: callable) -> callable: return torch.jit.script(fn) if should_script else fn +# @maybe_script +# def _dequantize_weight( +# codes: torch.Tensor, codebooks: torch.Tensor, scales: Optional[torch.Tensor] = None +# ) -> torch.Tensor: +# """ +# Decode float weights from quantization codes. Differentiable. +# :param codes: tensor of integer quantization codes, shape [*dims, num_out_groups, num_in_groups, num_codebooks] +# :param codebooks: tensor of vectors for each quantization code, [num_codebooks, codebook_size, out_group_size, in_group_size] +# :param scales: weight will be multiplied by this factor, must be broadcastble with [*dims, out_groups, num_in_groups, out_group_size, in_group_size] +# :return: reconstructed weight tensor of shape [*dims, num_in_groups*group_size] +# """ +# num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:] +# num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape +# out_features = num_out_groups * out_group_size +# in_features = num_in_groups * in_group_size +# codebook_offsets = torch.arange( +# 0, num_codebooks * codebook_size, codebook_size, device=codes.device +# ) # shape: [num_codebooks] +# reconstructed_weight_flat = F.embedding_bag( +# codes.flatten(0, -2) + codebook_offsets, codebooks.flatten(0, 1).flatten(-2, -1), mode="sum" +# ) # [prod(dims) * num_out_groups * num_in_groups, out_group_size * in_group_size] + +# reconstructed_weight_groupwise = reconstructed_weight_flat.view( +# list(codes.shape[:-3]) + [num_out_groups, num_in_groups, out_group_size, in_group_size] +# ) +# if scales is not None: +# reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(scales) +# return reconstructed_weight_groupwise.swapaxes(-3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features]) + @maybe_script def _dequantize_weight( - codes: torch.Tensor, codebooks: torch.Tensor, scales: Optional[torch.Tensor] = None + codes: torch.Tensor, codebooks: torch.Tensor, scales: Optional[torch.Tensor] = None, num_codebooks: Optional[int] = None ) -> torch.Tensor: """ Decode float weights from quantization codes. Differentiable. :param codes: tensor of integer quantization codes, shape [*dims, num_out_groups, num_in_groups, num_codebooks] :param codebooks: tensor of vectors for each quantization code, [num_codebooks, codebook_size, out_group_size, in_group_size] :param scales: weight will be multiplied by this factor, must be broadcastble with [*dims, out_groups, num_in_groups, out_group_size, in_group_size] + :param num_codebooks: Number of codebooks to use. If None, all available codebooks are used. :return: reconstructed weight tensor of shape [*dims, num_in_groups*group_size] """ - num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:] - num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape - out_features = num_out_groups * out_group_size - in_features = num_in_groups * in_group_size - codebook_offsets = torch.arange( - 0, num_codebooks * codebook_size, codebook_size, device=codes.device - ) # shape: [num_codebooks] - reconstructed_weight_flat = F.embedding_bag( - codes.flatten(0, -2) + codebook_offsets, codebooks.flatten(0, 1).flatten(-2, -1), mode="sum" - ) # [prod(dims) * num_out_groups * num_in_groups, out_group_size * in_group_size] - - reconstructed_weight_groupwise = reconstructed_weight_flat.view( - list(codes.shape[:-3]) + [num_out_groups, num_in_groups, out_group_size, in_group_size] - ) - if scales is not None: - reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(scales) - return reconstructed_weight_groupwise.swapaxes(-3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features]) - + # 使用所有可用的 codebook 或仅使用指定数量的 codebook + available_codebooks = codebooks.shape[0] + effective_num_codebooks = available_codebooks if num_codebooks is None else min(num_codebooks, available_codebooks) + + # 如果要使用所有 codebook,使用原始方法 + if effective_num_codebooks == available_codebooks: + num_out_groups, num_in_groups, _ = codes.shape[-3:] + _, codebook_size, out_group_size, in_group_size = codebooks.shape + out_features = num_out_groups * out_group_size + in_features = num_in_groups * in_group_size + codebook_offsets = torch.arange( + 0, available_codebooks * codebook_size, codebook_size, device=codes.device + ) # shape: [num_codebooks] + reconstructed_weight_flat = F.embedding_bag( + codes.flatten(0, -2) + codebook_offsets, codebooks.flatten(0, 1).flatten(-2, -1), mode="sum" + ) # [prod(dims) * num_out_groups * num_in_groups, out_group_size * in_group_size] + + reconstructed_weight_groupwise = reconstructed_weight_flat.view( + list(codes.shape[:-3]) + [num_out_groups, num_in_groups, out_group_size, in_group_size] + ) + if scales is not None: + reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(scales) + return reconstructed_weight_groupwise.swapaxes(-3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features]) + + # 使用部分 codebook + else: + num_out_groups, num_in_groups, _ = codes.shape[-3:] + _, codebook_size, out_group_size, in_group_size = codebooks.shape + out_features = num_out_groups * out_group_size + in_features = num_in_groups * in_group_size + + # 只使用前 effective_num_codebooks 个 codebook + selected_codes = codes[..., :effective_num_codebooks] + selected_codebooks = codebooks[:effective_num_codebooks] + + codebook_offsets = torch.arange( + 0, effective_num_codebooks * codebook_size, codebook_size, device=codes.device + ) # shape: [effective_num_codebooks] + + reconstructed_weight_flat = F.embedding_bag( + selected_codes.flatten(0, -2) + codebook_offsets, + selected_codebooks.flatten(0, 1).flatten(-2, -1), + mode="sum" + ) # [prod(dims) * num_out_groups * num_in_groups, out_group_size * in_group_size] + + reconstructed_weight_groupwise = reconstructed_weight_flat.view( + list(codes.shape[:-3]) + [num_out_groups, num_in_groups, out_group_size, in_group_size] + ) + if scales is not None: + reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(scales) + return reconstructed_weight_groupwise.swapaxes(-3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features]) @contextlib.contextmanager def using_tf32(enabled: bool):