diff --git a/torchchat/utils/quantize.py b/torchchat/utils/quantize.py
index f4339c0c3..7686344db 100644
--- a/torchchat/utils/quantize.py
+++ b/torchchat/utils/quantize.py
@@ -26,16 +26,17 @@
 # from functools import reduce
 # from math import gcd
-from typing import Dict, Optional, Callable, Any, List
+from typing import Any, Callable, Dict, List, Optional

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 # AttributeError: '_OpNamespace' 'quantized_decomposed' object has no attribute 'quantize_per_channel_group'
+
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa

 from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout, QDQLayout
-from torchao.experimental.quant_api import EmbeddingQuantizer
+from torchao.experimental.quant_api import EmbeddingQuantizer, SharedEmbeddingQuantizer
 from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.quant_api import (
     int4_weight_only,
@@ -50,8 +51,8 @@
     find_multiple,
     get_device_str,
     get_precision,
-    set_precision,
     name_to_dtype,
+    set_precision,
     state_dict_device,
     use_et_backend,
 )
@@ -79,6 +80,7 @@ def get_named_parameters(func: Callable) -> List[str]:
     ]
     return named_params

+
 def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None) -> Dict[str, Any]:
     for key in list(q_kwargs.keys()):
         if key not in named_params:
@@ -111,104 +113,163 @@ def quantize_model(
     if isinstance(quantize_options, str):
         quantize_options = json.loads(quantize_options)

+    ordered_quantize_options = {}
+    for quantizer in quantizer_class_dict.keys():
+        if quantizer in quantize_options:
+            ordered_quantize_options |= {quantizer: quantize_options.pop(quantizer)}
+
+    if len(quantize_options) != 0:
+        raise RuntimeError(f"unknown quantizer(s) {list(quantize_options.keys())} specified")
+
+    quantize_options = ordered_quantize_options
     for quantizer, q_kwargs in quantize_options.items():
-        if quantizer not in quantizer_class_dict:
-            raise RuntimeError(f"unknown quantizer {quantizer} specified")
-        else:
-            # Use tensor subclass API for int4 weight only.
-            if (device in ["cuda", "xpu", "npu"]) and quantizer == "linear:int4":
-                quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
-                if not support_tensor_subclass:
-                    unwrap_tensor_subclass(model)
-                continue
-
-            if quantizer == "linear:a8wxdq":
-                if get_precision() != torch.float32:
-                    print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.")
-                    set_precision(torch.float32)
-
-                group_size = q_kwargs["groupsize"]
-                bit_width = q_kwargs["bitwidth"]
-                has_weight_zeros = q_kwargs["has_weight_zeros"]
-                granularity = PerAxis() if group_size == -1 else PerGroup(group_size)
-                weight_dtype = getattr(torch, f"int{bit_width}")
-                weight_mapping_type = (
-                    MappingType.ASYMMETRIC
-                    if has_weight_zeros
-                    else MappingType.SYMMETRIC
-                )
-
-                try:
-                    quantize_(
-                        model,
-                        Int8DynamicActivationIntxWeightConfig(
-                            weight_dtype=weight_dtype,
-                            weight_granularity=granularity,
-                            weight_mapping_type=weight_mapping_type,
-                            layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
-                        ),
-                    )
-                except Exception as e:
-                    print("Encountered error during quantization: {e}")
-                    print("Trying with QDQLayout")
-                    quantize_(
-                        model,
-                        Int8DynamicActivationIntxWeightConfig(
-                            weight_dtype=weight_dtype,
-                            weight_granularity=granularity,
-                            weight_mapping_type=weight_mapping_type,
-                            layout=QDQLayout(),
-                        ),
-                    )
-
-                if not support_tensor_subclass:
-                    unwrap_tensor_subclass(model)
-                continue
-
-            if quantizer == "embedding:wx":
-                # These quantizers require float32 input weights. Note that after quantization,
-                # the weights will no longer be float32, but lowbit integers
-                if get_precision() != torch.float32:
-                    print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.")
-                    set_precision(torch.float32)
-
-                group_size = q_kwargs["groupsize"]
-                bit_width = q_kwargs["bitwidth"]
-                has_weight_zeros = q_kwargs.get("has_weight_zeros", True)
-                q_kwargs["granularity"] = (
-                    PerAxis() if group_size == -1 else PerGroup(group_size)
-                )
-                q_kwargs["weight_dtype"] = getattr(torch, f"int{bit_width}")
-                q_kwargs["mapping_type"] = (
-                    MappingType.ASYMMETRIC
-                    if has_weight_zeros
-                    else MappingType.SYMMETRIC
-                )
-                q_kwargs["use_fallback"] = False
-                del q_kwargs["groupsize"]
-                del q_kwargs["bitwidth"]
-
-            if quantizer == "linear:afpwx" and device != "mps":
-                raise RuntimeError("linear:afpwx quantization can only run on mps device!")
-
-            # We set global precision from quantize options if it is specified at cli.py:485
-            # so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat
-            precision = get_precision()
-
-            q = quantizer_class_dict[quantizer]
-            named_params = get_named_parameters(q.__init__)
-            q_kwargs = validate_args(named_params, q_kwargs, quantizer)
-
-            # Handle tokenizer for scenarios where the quantizer needs to tokenizer sample inputs
-            if "tokenizer" in named_params:
-                q_kwargs["tokenizer"] = tokenizer
-            if quantizer == "embedding:wx":
-                quant_handler = q(**q_kwargs)
-            else:
-                quant_handler = q(device=device, precision=precision, **q_kwargs)
-
-            # quantize model
-            model = quant_handler.quantize(model)
+        if quantizer == "experimental:embedding":
+            group_size = q_kwargs["groupsize"]
+            bit_width = q_kwargs["bitwidth"]
+            has_weight_zeros = q_kwargs["has_weight_zeros"]
+            weight_granularity = (
+                PerAxis() if group_size == -1 else PerGroup(group_size)
+            )
+            weight_dtype = getattr(torch, f"int{bit_width}")
+            weight_mapping_type = (
+                MappingType.ASYMMETRIC
+                if has_weight_zeros
+                else MappingType.SYMMETRIC
+            )
+
+            try:
+                model = EmbeddingQuantizer(
+                    weight_dtype=weight_dtype,
+                    granularity=weight_granularity,
+                    mapping_type=weight_mapping_type,
+                    use_fallback=False,
+                ).quantize(model)
+            except Exception as e:
+                print(
+                    f"Encountered error during quantization with experimental EmbeddingQuantizer: {e}"
+                )
+            continue
+
+        if quantizer == "experimental:shared":
+            group_size = q_kwargs["groupsize"]
+            bit_width = q_kwargs["bitwidth"]
+            has_weight_zeros = q_kwargs["has_weight_zeros"]
+            weight_granularity = (
+                PerAxis() if group_size == -1 else PerGroup(group_size)
+            )
+            weight_dtype = getattr(torch, f"int{bit_width}")
+            weight_mapping_type = (
+                MappingType.ASYMMETRIC
+                if has_weight_zeros
+                else MappingType.SYMMETRIC
+            )
+
+            try:
+                model = SharedEmbeddingQuantizer(
+                    weight_dtype=weight_dtype,
+                    granularity=weight_granularity,
+                    mapping_type=weight_mapping_type,
+                    use_fallback=False,
+                ).quantize(model)
+            except Exception as e:
+                print(
+                    f"Encountered error during quantization with experimental SharedEmbeddingQuantizer: {e}"
+                )
+            continue
+
+        # Use tensor subclass API for int4 weight only.
+        if (device in ["cuda", "xpu", "npu"]) and quantizer == "linear:int4":
+            quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
+            if not support_tensor_subclass:
+                unwrap_tensor_subclass(model)
+            continue
+        if quantizer == "linear:a8wxdq":
+            if get_precision() != torch.float32:
+                print(
+                    f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32."
+                )
+                set_precision(torch.float32)
+            group_size = q_kwargs["groupsize"]
+            bit_width = q_kwargs["bitwidth"]
+            has_weight_zeros = q_kwargs["has_weight_zeros"]
+            granularity = PerAxis() if group_size == -1 else PerGroup(group_size)
+            weight_dtype = getattr(torch, f"int{bit_width}")
+            weight_mapping_type = (
+                MappingType.ASYMMETRIC
+                if has_weight_zeros
+                else MappingType.SYMMETRIC
+            )
+
+            try:
+                quantize_(
+                    model,
+                    Int8DynamicActivationIntxWeightConfig(
+                        weight_dtype=weight_dtype,
+                        weight_granularity=granularity,
+                        weight_mapping_type=weight_mapping_type,
+                        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
+                    ),
+                )
+            except Exception as e:
+                print(f"Encountered error during quantization: {e}")
+                print("Trying with QDQLayout")
+                quantize_(
+                    model,
+                    Int8DynamicActivationIntxWeightConfig(
+                        weight_dtype=weight_dtype,
+                        weight_granularity=granularity,
+                        weight_mapping_type=weight_mapping_type,
+                        layout=QDQLayout(),
+                    ),
+                )
+            if not support_tensor_subclass:
+                unwrap_tensor_subclass(model)
+            continue
+        if quantizer == "embedding:wx":
+            # These quantizers require float32 input weights. Note that after quantization,
+            # the weights will no longer be float32, but lowbit integers
+            if get_precision() != torch.float32:
+                print(
+                    f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32."
+                )
+                set_precision(torch.float32)
+            group_size = q_kwargs["groupsize"]
+            bit_width = q_kwargs["bitwidth"]
+            has_weight_zeros = q_kwargs.get("has_weight_zeros", True)
+            q_kwargs["granularity"] = (
+                PerAxis() if group_size == -1 else PerGroup(group_size)
+            )
+            q_kwargs["weight_dtype"] = getattr(torch, f"int{bit_width}")
+            q_kwargs["mapping_type"] = (
+                MappingType.ASYMMETRIC
+                if has_weight_zeros
+                else MappingType.SYMMETRIC
+            )
+            q_kwargs["use_fallback"] = False
+            del q_kwargs["groupsize"]
+            del q_kwargs["bitwidth"]
+
+        if quantizer == "linear:afpwx" and device != "mps":
+            raise RuntimeError(
+                "linear:afpwx quantization can only run on mps device!"
+            )
+        # We set global precision from quantize options if it is specified at cli.py:485
+        # so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat
+        precision = get_precision()
+
+        q = quantizer_class_dict[quantizer]
+        named_params = get_named_parameters(q.__init__)
+        q_kwargs = validate_args(named_params, q_kwargs, quantizer)
+
+        # Handle tokenizer for scenarios where the quantizer needs to tokenize sample inputs
+        if "tokenizer" in named_params:
+            q_kwargs["tokenizer"] = tokenizer
+        if quantizer == "embedding:wx":
+            quant_handler = q(**q_kwargs)
+        else:
+            quant_handler = q(device=device, precision=precision, **q_kwargs)
+        # quantize model
+        model = quant_handler.quantize(model)


 #########################################################################
@@ -352,7 +410,6 @@ def dynamically_quantize_per_channel(
         padding = groupsize - (x_shape_1 % groupsize)
         x = F.pad(x, (0, padding))
     items = groupsize
-
     # default setup for affine quantization of activations
     eps = torch.finfo(torch.float32).eps

@@ -374,7 +431,6 @@ def dynamically_quantize_per_channel(
     # ensure scales is the same dtype as the original tensor
     scales = torch.clamp(scales, min=eps).to(x.dtype)
     zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
-
     # quantize based on qmin/qmax/scales/zp
     # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
     x_div = x / scales.unsqueeze(-1)
@@ -439,7 +495,6 @@ def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128)
     # needed for GPTQ single column quantize
     if groupsize > w.shape[-1] and scales.shape[-1] == 1:
         groupsize = w.shape[-1]
-
     assert w.shape[-1] % groupsize == 0
     assert w.dim() == 2

@@ -522,7 +577,6 @@ def linear_int8_aoti(input, weight, scales):
             weight,
             scales,
         ).reshape(input.shape[:-1] + (weight.shape[0],))
-
     return F.linear(
         input,
         (
@@ -576,7 +630,6 @@ def linear_int8_et(input, weight, scales):
         lin = F.linear(input, weight.to(dtype=input.dtype))
         # print(f"linear shape {lin.shape}, scales shape {scales.shape}")
         return lin * scales
-
     return _qdq_dynamic_quantized_linear(
         x_fp32=input.float(),
         x_quant_min=-128,
@@ -589,7 +642,6 @@ def linear_int8_et(input, weight, scales):
         weight_quant_max=127,
         bias_fp32=None,
     ).to(dtype=input.dtype)
-
     return F.linear(
         input,
         (
@@ -621,10 +673,8 @@ def __init__(
         super().__init__()
         if dtype is None:
             dtype = torch.get_default_dtype()
-
         if device is None:
             device = "cpu"
-
         assert not bias, "Bias is not supported by LinearInt8"
         self.in_features = in_features
         self.out_features = out_features
@@ -643,7 +693,6 @@ def __init__(
         else:
             n_groups = (in_features + groupsize - 1) // groupsize
             scales = torch.empty(out_features, n_groups, dtype=dtype, device=device)
-
         self.register_buffer("weight", weight.to(device))
         self.register_buffer("scales", scales.to(device))

@@ -693,7 +742,6 @@ def quantize(self, module):
             range_max = 127
         else:
             raise ValueError(f"Unsupported bitwidth {self.bitwidth}")
-
         for name, child in module.named_children():
             # print(f"name: {name}")
             if isinstance(child, nn.Linear):
@@ -733,7 +781,6 @@ def quantize(self, module):
                 )
             else:
                 self.quantize(child)
-
         return module

     def quantized_model(self) -> nn.Module:
@@ -775,7 +822,6 @@ def __init__(
             raise RuntimeError(
                 f"QUantized embedding does not support bitwidth={bitwidth}"
             )
-
         if weight is None:
             groups_per_row = (embedding_dim + groupsize - 1) // groupsize
             weight = torch.empty(
@@ -791,8 +837,7 @@ def __init__(
                 dtype=dtype,
                 device=device,
             ).squeeze(dim=-1)
-
         self.register_buffer(
             "weight",
             weight,
         )
@@ -830,7 +875,6 @@ def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor:
             weight = weight.to(torch.int8).add(-8)
         else:
             weight = self.weight
-
         scales = self.scales.view(weight.shape[0], -1)

         result_weights = F.embedding(indices, weight)
@@ -888,7 +932,6 @@ def quantize(self, module):
             range_max = 127
         else:
             raise ValueError(f"Unsupported bitwidth {self.bitwidth}")
-
         for name, child in module.named_children():
             # print(f"name: {name}")
             if isinstance(child, nn.Embedding):
@@ -911,7 +954,6 @@ def quantize(self, module):
                 if self.bitwidth == 4:
                     if weight.shape[-1] % 2 != 0:
                         raise RuntimeError("automatic padding not implemented yet")
-
                     weight_range_shifted = weight.add(8).view(torch.uint8)
                     weight_view = weight_range_shifted.view(
                         weight.shape[0], weight.shape[1] // 2, 2
@@ -920,7 +962,6 @@ def quantize(self, module):
                     weight_odd = weight_view[:, :, 1]
                     weight_packed = weight_even + weight_odd
                     weight = weight_packed
-
                 weight = weight
                 scales = scales.squeeze(dim=-1)

@@ -941,7 +982,6 @@ def quantize(self, module):
                 )
             else:
                 self.quantize(child)
-
         return module

     def quantized_model(self) -> nn.Module:
@@ -961,8 +1001,10 @@ def quantized_model(self) -> nn.Module:
     "precision": PrecisionHandler,
     "executor": ExecutorHandler,
     "linear:int4": Int4WeightOnlyQuantizer,
-    "linear:a8wxdq": None, # uses quantize_ API
+    "linear:a8wxdq": None,  # uses quantize_ API
     "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
+    "experimental:shared": SharedEmbeddingQuantizer,
+    "experimental:embedding": EmbeddingQuantizer,
 }

 try:
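
Reviewer note (not part of the patch): the option-reordering logic added to `quantize_model` can be exercised in isolation. The sketch below is a minimal reproduction under a toy registry; the `quantizer_class_dict` here is a stand-in whose key order alone drives the behavior, and the option values (`groupsize`, `bitwidth`, `has_weight_zeros`) are illustrative, not a recommended configuration.

```python
from typing import Any, Dict

# Toy stand-in for torchchat's quantizer_class_dict: only key ORDER matters here.
# As in the patch, the experimental entries are appended at the end.
quantizer_class_dict: Dict[str, Any] = {
    "linear:a8wxdq": None,
    "experimental:shared": object,
    "experimental:embedding": object,
}

def order_quantize_options(quantize_options: Dict[str, Any]) -> Dict[str, Any]:
    # Re-key user options into registry order, mirroring the patch.
    ordered: Dict[str, Any] = {}
    for quantizer in quantizer_class_dict.keys():
        if quantizer in quantize_options:
            # dict |= preserves insertion order (Python 3.9+), so quantizers
            # are applied in registry order, not in user-specified order.
            ordered |= {quantizer: quantize_options.pop(quantizer)}
    if len(quantize_options) != 0:
        # Anything left over was never claimed by the registry.
        raise RuntimeError(f"unknown quantizer(s) {list(quantize_options.keys())} specified")
    return ordered

# The user lists the experimental embedding quantizer first, but it is applied
# last because the registry lists it after linear:a8wxdq.
opts = {
    "experimental:embedding": {"groupsize": 32, "bitwidth": 4, "has_weight_zeros": False},
    "linear:a8wxdq": {"groupsize": 256, "bitwidth": 4, "has_weight_zeros": False},
}
print(list(order_quantize_options(opts).keys()))
# -> ['linear:a8wxdq', 'experimental:embedding']
```

One consequence worth flagging in review: because unconsumed keys now raise a `RuntimeError` before the loop starts, a typo in any quantizer name fails fast instead of being detected mid-application.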