From 599c17b0873a9957012a6b2fb0171b662bdf66bc Mon Sep 17 00:00:00 2001 From: dillondesilva Date: Sun, 13 Apr 2025 21:24:58 +1000 Subject: [PATCH 1/5] Updated quantize.py to include experimental EmbeddingQuantizer and SharedEmbeddingQuantizer --- torchchat/utils/quantize.py | 196 ++++++++++++++++++++++++++---------- 1 file changed, 142 insertions(+), 54 deletions(-) diff --git a/torchchat/utils/quantize.py b/torchchat/utils/quantize.py index 6246f1c05..1ab85f038 100644 --- a/torchchat/utils/quantize.py +++ b/torchchat/utils/quantize.py @@ -20,20 +20,35 @@ # torchao Quantizer: # * Int8DynActInt4WeightQuantizer: dynamic quantization for int8 acitvation and int4 weight. Using torchao API. # + + from __future__ import annotations import json # from functools import reduce # from math import gcd -from typing import Dict, Optional, Callable, Any, List + +from typing import Any, Callable, Dict, List, Optional import torch import torch.nn as nn import torch.nn.functional as F # AttributeError: '_OpNamespace' 'quantized_decomposed' object has no attribute 'quantize_per_channel_group' + from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa +from torchao.dtypes import PlainLayout +from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import ( + PackedLinearInt8DynamicActivationIntxWeightLayout, +) +from torchao.experimental.quant_api import ( + EmbeddingQuantizer, + int8_dynamic_activation_intx_weight, + IntxWeightEmbeddingQuantizer, + SharedEmbeddingQuantizer, +) +from torchao.quantization.granularity import PerGroup, PerRow from torchao.quantization.quant_api import ( int4_weight_only, Int4WeightOnlyQuantizer, @@ -45,51 +60,52 @@ find_multiple, get_device_str, get_precision, - set_precision, name_to_dtype, + set_precision, state_dict_device, use_et_backend, ) -from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import ( - PackedLinearInt8DynamicActivationIntxWeightLayout, -) -from torchao.experimental.quant_api import ( - int8_dynamic_activation_intx_weight, - IntxWeightEmbeddingQuantizer, -) -from torchao.quantization.granularity import ( - PerGroup, - PerRow, -) -from torchao.dtypes import PlainLayout # Flag for whether the a8wxdq quantizer is available. + torchao_experimental_load_error: Optional[Exception] = None ######################################################################### ### handle arg validation ### + import inspect + def get_named_parameters(func: Callable) -> List[str]: # Get the signature of the function + signature = inspect.signature(func) # Extract the parameters from the signature + parameters = signature.parameters # Filter and return named parameters + named_params = [ - name for name, param in parameters.items() - if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY) + name + for name, param in parameters.items() + if param.kind + in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY) ] return named_params -def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None) -> Dict[str, Any]: + +def validate_args( + named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None +) -> Dict[str, Any]: for key in q_kwargs.keys(): if key not in named_params: - print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.") + print( + f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring." 
+ ) del q_kwargs[key] return q_kwargs @@ -117,32 +133,32 @@ def quantize_model( if isinstance(quantize_options, str): quantize_options = json.loads(quantize_options) - for quantizer, q_kwargs in quantize_options.items(): if quantizer not in quantizer_class_dict: raise RuntimeError(f"unknown quantizer {quantizer} specified") else: # Use tensor subclass API for int4 weight only. + if (device == "cuda" or device == "xpu") and quantizer == "linear:int4": quantize_(model, int4_weight_only(q_kwargs["groupsize"])) if not support_tensor_subclass: unwrap_tensor_subclass(model) continue - if quantizer == "linear:a8wxdq": if get_precision() != torch.float32: - print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.") + print( + f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32." + ) set_precision(torch.float32) - group_size = q_kwargs["groupsize"] bit_width = q_kwargs["bitwidth"] has_weight_zeros = q_kwargs["has_weight_zeros"] - granularity = PerRow() if group_size == -1 else PerGroup(group_size) + granularity = PerRow() if group_size == -1 else PerGroup(group_size) weight_dtype = getattr(torch, f"int{bit_width}") try: quantize_( - model, + model, int8_dynamic_activation_intx_weight( weight_dtype=weight_dtype, granularity=granularity, @@ -154,7 +170,7 @@ def quantize_model( print("Encountered error during quantization: {e}") print("Trying with PlainLayout") quantize_( - model, + model, int8_dynamic_activation_intx_weight( weight_dtype=weight_dtype, granularity=granularity, @@ -162,23 +178,63 @@ def quantize_model( layout=PlainLayout(), ), ) - if not support_tensor_subclass: unwrap_tensor_subclass(model) continue - if quantizer == "embedding:wx": # These quantizers require float32 input weights. Note that after quantization, # the weights will no longer be float32, but lowbit integers + if get_precision() != torch.float32: - print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.") + print( + f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32." + ) set_precision(torch.float32) - if quantizer == "linear:afpwx" and device != "mps": - raise RuntimeError("linear:afpwx quantization can only run on mps device!") + raise RuntimeError( + "linear:afpwx quantization can only run on mps device!" 
+ ) + if quantizer == "experimental:embedding": + has_weight_zeros = q_kwargs["has_weight_zeros"] + granularity = PerRow() if group_size == -1 else PerGroup(group_size) + weight_dtype = getattr(torch, f"int{bit_width}") + try: + quantize_( + model, + EmbeddingQuantizer( + weight_dtype=weight_dtype, + granularity=granularity, + has_weight_zeros=has_weight_zeros, + use_fallback=False, + ), + ) + except Exception as e: + print( + "Encountered error during quantization with experimental EmbeddingQuantization: {e}" + ) + if quantizer == "experimental:shared": + has_weight_zeros = q_kwargs["has_weight_zeros"] + granularity = PerRow() if group_size == -1 else PerGroup(group_size) + weight_dtype = getattr(torch, f"int{bit_width}") + + try: + quantize_( + model, + SharedEmbeddingQuantizer( + weight_dtype=weight_dtype, + granularity=granularity, + has_weight_zeros=has_weight_zeros, + use_fallback=False, + ), + ) + except Exception as e: + print( + "Encountered error during quantization with experimental SharedEmbeddingQuantization: {e}" + ) # We set global precision from quantize options if it is specified at cli.py:485 # so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat + precision = get_precision() q = quantizer_class_dict[quantizer] @@ -186,13 +242,14 @@ def quantize_model( q_kwargs = validate_args(named_params, q_kwargs, quantizer) # Handle tokenizer for scenarios where the quantizer needs to tokenizer sample inputs + if "tokenizer" in named_params: q_kwargs["tokenizer"] = tokenizer quant_handler = q(device=device, precision=precision, **q_kwargs) # quantize model - model = quant_handler.quantize(model) + model = quant_handler.quantize(model) ######################################################################### @@ -201,7 +258,13 @@ def quantize_model( class QuantHandler: - def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None): + def __init__( + self, + model: Optional[nn.Module] = None, + device="cpu", + precision=None, + tokenizer=None, + ): self.model_ = model self.device = device self.tokenizer = tokenizer @@ -219,6 +282,7 @@ def quantized_model(self) -> nn.Module: return self.model_ # fallback for TC QuantHandlers that do not implement the method .quantize() + def quantize(self, model: nn.Module) -> nn.Module: self.model_ = model return self.quantized_model() @@ -229,7 +293,15 @@ def quantize(self, model: nn.Module) -> nn.Module: class PrecisionHandler(QuantHandler): - def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None, tokenizer=None, *, dtype): + def __init__( + self, + model: Optional[nn.Module] = None, + device="cpu", + precision=None, + tokenizer=None, + *, + dtype, + ): self.model_ = model self.device = device self.tokenizer = tokenizer @@ -258,7 +330,15 @@ def quantized_model(self) -> nn.Module: class ExecutorHandler(QuantHandler): - def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None, tokenizer=None, *, accelerator): + def __init__( + self, + model: Optional[nn.Module] = None, + device="cpu", + precision=None, + tokenizer=None, + *, + accelerator, + ): self.model_ = model if isinstance(accelerator, str): @@ -335,31 +415,36 @@ def dynamically_quantize_per_channel( padding = groupsize - (x_shape_1 % groupsize) x = F.pad(x, (0, padding)) items = groupsize - # default setup for affine quantization of activations + eps = torch.finfo(torch.float32).eps x = x.view(x.shape[0], x.shape[1] // items, items) # get min and max + min_val, 
max_val = torch.aminmax(x, dim=2) # print(f"min_val {min_val}") # print(f"max_val {max_val}") # calculate scales and zero_points based on min and max # reference: https://fburl.com/code/srbiybme + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) device = min_val_neg.device # reference: https://fburl.com/code/4wll53rk + max_val_pos = torch.max(-min_val_neg, max_val_pos) scales = max_val_pos / (float(quant_max - quant_min) / 2) # ensure scales is the same dtype as the original tensor + scales = torch.clamp(scales, min=eps).to(x.dtype) zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) # quantize based on qmin/qmax/scales/zp # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63 + x_div = x / scales.unsqueeze(-1) x_round = torch.round(x_div) x_zp = x_round + zero_points.unsqueeze(-1) @@ -375,6 +460,7 @@ def dynamically_quantize_per_channel( def get_group_qparams(w, n_bit=4, groupsize=128, *, scales_dtype=torch.float): # needed for GPTQ with padding + if groupsize > w.shape[-1]: groupsize = w.shape[-1] assert groupsize > 1 @@ -420,9 +506,9 @@ def unpack_scales_and_zeros(scales_and_zeros): def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128): assert groupsize > 1 # needed for GPTQ single column quantize + if groupsize > w.shape[-1] and scales.shape[-1] == 1: groupsize = w.shape[-1] - assert w.shape[-1] % groupsize == 0 assert w.dim() == 2 @@ -458,6 +544,7 @@ def group_dequantize_tensor_from_qparams( ): assert groupsize > 1 # needed for GPTQ single column dequantize + if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1: groupsize = w_int32.shape[-1] assert w_int32.shape[-1] % groupsize == 0 @@ -489,6 +576,7 @@ def linear_int8_aoti(input, weight, scales): n_groups = scales.numel() // scales.shape[0] # we special-case channel-wise, because we know how to make that fast + if n_groups == 1: scales = scales.view(-1) if ( @@ -498,14 +586,15 @@ def linear_int8_aoti(input, weight, scales): ): lin = F.linear(input, weight.to(dtype=input.dtype)) # print(f"linear shape {lin.shape}, scales shape {scales.shape}") + return lin * scales # Use int8pack_mm for CPU eager + return torch.ops.aten._weight_int8pack_mm( input.reshape(-1, input.shape[-1]), weight, scales, ).reshape(input.shape[:-1] + (weight.shape[0],)) - return F.linear( input, ( @@ -552,14 +641,15 @@ def linear_int8_et(input, weight, scales): n_groups = scales.numel() // scales.shape[0] # we special-case channel-wise, because we know how to make that fast + if n_groups == 1: scales = scales.view(-1) if True: lin = F.linear(input, weight.to(dtype=input.dtype)) # print(f"linear shape {lin.shape}, scales shape {scales.shape}") - return lin * scales + return lin * scales return _qdq_dynamic_quantized_linear( x_fp32=input.float(), x_quant_min=-128, @@ -572,7 +662,6 @@ def linear_int8_et(input, weight, scales): weight_quant_max=127, bias_fp32=None, ).to(dtype=input.dtype) - return F.linear( input, ( @@ -604,10 +693,8 @@ def __init__( super().__init__() if dtype is None: dtype = torch.get_default_dtype() - if device is None: device = "cpu" - assert not bias, "Bias is not supported by LinearInt8" self.in_features = in_features self.out_features = out_features @@ -626,7 +713,6 @@ def __init__( else: n_groups = (in_features + groupsize - 1) // groupsize scales = torch.empty(out_features, n_groups, dtype=dtype, device=device) - 
self.register_buffer("weight", weight.to(device)) self.register_buffer("scales", scales.to(device)) @@ -646,7 +732,7 @@ class WeightOnlyInt8QuantHandler(QuantHandler): def __init__( self, model: Optional[nn.Module] = None, - device = None, + device=None, precision=None, tokenizer=None, *, @@ -676,9 +762,9 @@ def quantize(self, module): range_max = 127 else: raise ValueError(f"Unsupported bitwidth {self.bitwidth}") - for name, child in module.named_children(): # print(f"name: {name}") + if isinstance(child, nn.Linear): if ( (self.node_type == "*") @@ -686,12 +772,14 @@ def quantize(self, module): or (self.node_type == "!output" and name != "output") ): # print(f"{name, child}") + input_weight = child.weight.float() # print(f"{name, child}") # print(f"in_features: {child.in_features}") # print(f"out_features: {child.out_features}") # print(f"expanded weight shape {input_weight.shape}") + weight, scales, _ = dynamically_quantize_per_channel( input_weight, range_min, @@ -716,7 +804,6 @@ def quantize(self, module): ) else: self.quantize(child) - return module def quantized_model(self) -> nn.Module: @@ -758,7 +845,6 @@ def __init__( raise RuntimeError( f"QUantized embedding does not support bitwidth={bitwidth}" ) - if weight is None: groups_per_row = (embedding_dim + groupsize - 1) // groupsize weight = torch.empty( @@ -774,7 +860,6 @@ def __init__( dtype=dtype, device=device, ).squeeze(dim=-1) - self.register_buffer( "weight", weight, @@ -813,7 +898,6 @@ def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor: weight = weight.to(torch.int8).add(-8) else: weight = self.weight - scales = self.scales.view(weight.shape[0], -1) result_weights = F.embedding(indices, weight) @@ -871,9 +955,9 @@ def quantize(self, module): range_max = 127 else: raise ValueError(f"Unsupported bitwidth {self.bitwidth}") - for name, child in module.named_children(): # print(f"name: {name}") + if isinstance(child, nn.Embedding): # print(f"Embedding identified: {fqn, mod}") # print(f"weights size: {child.weight.size()}") @@ -882,6 +966,7 @@ def quantize(self, module): # print( # f"quantize {fqn, mod} with groupsize {self.groupsize}, bitwidth {self.bitwidth}" # ) + weight, scales, _ = dynamically_quantize_per_channel( child.weight.float(), range_min, @@ -894,7 +979,6 @@ def quantize(self, module): if self.bitwidth == 4: if weight.shape[-1] % 2 != 0: raise RuntimeError("automatic padding not implemented yet") - weight_range_shifted = weight.add(8).view(torch.uint8) weight_view = weight_range_shifted.view( weight.shape[0], weight.shape[1] // 2, 2 @@ -903,12 +987,12 @@ def quantize(self, module): weight_odd = weight_view[:, :, 1] weight_packed = weight_even + weight_odd weight = weight_packed - weight = weight scales = scales.squeeze(dim=-1) # print(f"{name, child}") # print(f"weights size: {child.weight.size()}") + setattr( module, name, @@ -924,7 +1008,6 @@ def quantize(self, module): ) else: self.quantize(child) - return module def quantized_model(self) -> nn.Module: @@ -937,6 +1020,7 @@ def quantized_model(self) -> nn.Module: # Map each quantizer configuration to a class implementing that quantizer # Must come last because __future__ annotations don't work for naked # class references + quantizer_class_dict = { "embedding": EmbeddingOnlyQuantHandler, "embedding:wx": IntxWeightEmbeddingQuantizer, @@ -944,8 +1028,10 @@ def quantized_model(self) -> nn.Module: "precision": PrecisionHandler, "executor": ExecutorHandler, "linear:int4": Int4WeightOnlyQuantizer, - "linear:a8wxdq": None, # uses quantize_ API + "linear:a8wxdq": 
None, # uses quantize_ API "linear:a8w4dq": Int8DynActInt4WeightQuantizer, + "experimental:embedding": EmbeddingQuantizer, + "experimental:shared": SharedEmbeddingQuantizer, } try: @@ -956,6 +1042,7 @@ def quantized_model(self) -> nn.Module: torchao_build_path = f"{os.getcwd()}/torchao-build" # Try loading quantizer + torchao_experimental_quant_api_spec = importlib.util.spec_from_file_location( "torchao_experimental_quant_api", f"{torchao_build_path}/src/ao/torchao/experimental/quant_api.py", @@ -968,9 +1055,11 @@ def quantized_model(self) -> nn.Module: torchao_experimental_quant_api ) from torchao_experimental_quant_api import UIntxWeightOnlyLinearQuantizer + quantizer_class_dict["linear:afpwx"] = UIntxWeightOnlyLinearQuantizer # Try loading custom op + try: libname = "libtorchao_ops_mps_aten.dylib" libpath = f"{torchao_build_path}/cmake-out/lib/{libname}" @@ -978,6 +1067,5 @@ def quantized_model(self) -> nn.Module: print("Loaded torchao mps ops.") except Exception as e: print("Unable to load torchao mps ops library.") - except Exception as e: print("Unable to import torchao experimental quant_api with error: ", e) From 33d8d9ef4a730d02ecca5207577a38b7fa01c340 Mon Sep 17 00:00:00 2001 From: dillondesilva Date: Fri, 18 Apr 2025 12:48:22 +1000 Subject: [PATCH 2/5] Bumped pinned torchao version and modified quantize based on first round of PR comments. Fixes to usage of EmbeddingQuantizer and SharedEmbeddingQuantizer --- install/.pins/torchao-pin.txt | 2 +- torchchat/utils/quantize.py | 116 ++++++++++++++++++---------------- 2 files changed, 62 insertions(+), 56 deletions(-) diff --git a/install/.pins/torchao-pin.txt b/install/.pins/torchao-pin.txt index c1b84754c..856d40b55 100644 --- a/install/.pins/torchao-pin.txt +++ b/install/.pins/torchao-pin.txt @@ -1 +1 @@ -711fa0809f06fc97febd0c3fe72563c3fe227e51 +a96eeb1c7d7ba24cf0ccfc105141729acfed22bf diff --git a/torchchat/utils/quantize.py b/torchchat/utils/quantize.py index 1ab85f038..20cfa74f2 100644 --- a/torchchat/utils/quantize.py +++ b/torchchat/utils/quantize.py @@ -38,21 +38,15 @@ # AttributeError: '_OpNamespace' 'quantized_decomposed' object has no attribute 'quantize_per_channel_group' from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa -from torchao.dtypes import PlainLayout -from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import ( - PackedLinearInt8DynamicActivationIntxWeightLayout, -) -from torchao.experimental.quant_api import ( - EmbeddingQuantizer, - int8_dynamic_activation_intx_weight, - IntxWeightEmbeddingQuantizer, - SharedEmbeddingQuantizer, -) -from torchao.quantization.granularity import PerGroup, PerRow +from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout, QDQLayout +from torchao.experimental.quant_api import EmbeddingQuantizer, SharedEmbeddingQuantizer +from torchao.quantization.granularity import PerAxis, PerGroup from torchao.quantization.quant_api import ( int4_weight_only, Int4WeightOnlyQuantizer, Int8DynActInt4WeightQuantizer, + Int8DynamicActivationIntxWeightConfig, + MappingType, quantize_, ) from torchao.utils import unwrap_tensor_subclass @@ -138,6 +132,56 @@ def quantize_model( raise RuntimeError(f"unknown quantizer {quantizer} specified") else: # Use tensor subclass API for int4 weight only. 
+ if quantizer == "experimental:embedding": + group_size = q_kwargs["groupsize"] + bit_width = q_kwargs["bitwidth"] + has_weight_zeros = q_kwargs["has_weight_zeros"] + weight_granularity = ( + PerAxis() if group_size == -1 else PerGroup(group_size) + ) + weight_dtype = getattr(torch, f"int{bit_width}") + weight_mapping_type = ( + MappingType.ASYMMETRIC + if has_weight_zeros + else MappingType.SYMMETRIC + ) + + try: + model = EmbeddingQuantizer( + weight_dtype=weight_dtype, + granularity=weight_granularity, + mapping_type=weight_mapping_type, + use_fallback=False, + ).quantize(model) + except Exception as e: + print( + "Encountered error during quantization with experimental EmbeddingQuantization: {e}" + ) + if quantizer == "experimental:shared": + group_size = q_kwargs["groupsize"] + bit_width = q_kwargs["bitwidth"] + has_weight_zeros = q_kwargs["has_weight_zeros"] + weight_granularity = ( + PerAxis() if group_size == -1 else PerGroup(group_size) + ) + weight_dtype = getattr(torch, f"int{bit_width}") + weight_mapping_type = ( + MappingType.ASYMMETRIC + if has_weight_zeros + else MappingType.SYMMETRIC + ) + + try: + model = SharedEmbeddingQuantizer( + weight_dtype=weight_dtype, + granularity=weight_granularity, + mapping_type=weight_mapping_type, + use_fallback=False, + ).quantize(model) + except Exception as e: + print( + "Encountered error during quantization with experimental SharedEmbeddingQuantization: {e}" + ) if (device == "cuda" or device == "xpu") and quantizer == "linear:int4": quantize_(model, int4_weight_only(q_kwargs["groupsize"])) @@ -153,13 +197,13 @@ def quantize_model( group_size = q_kwargs["groupsize"] bit_width = q_kwargs["bitwidth"] has_weight_zeros = q_kwargs["has_weight_zeros"] - granularity = PerRow() if group_size == -1 else PerGroup(group_size) + granularity = PerAxis() if group_size == -1 else PerGroup(group_size) weight_dtype = getattr(torch, f"int{bit_width}") try: quantize_( model, - int8_dynamic_activation_intx_weight( + Int8DynamicActivationIntxWeightConfig( weight_dtype=weight_dtype, granularity=granularity, has_weight_zeros=has_weight_zeros, @@ -168,14 +212,14 @@ def quantize_model( ) except Exception as e: print("Encountered error during quantization: {e}") - print("Trying with PlainLayout") + print("Trying with QDQLayout") quantize_( model, - int8_dynamic_activation_intx_weight( + Int8DynamicActivationIntxWeightConfig( weight_dtype=weight_dtype, granularity=granularity, has_weight_zeros=has_weight_zeros, - layout=PlainLayout(), + layout=QDQLayout(), ), ) if not support_tensor_subclass: @@ -194,44 +238,6 @@ def quantize_model( raise RuntimeError( "linear:afpwx quantization can only run on mps device!" 
) - if quantizer == "experimental:embedding": - has_weight_zeros = q_kwargs["has_weight_zeros"] - granularity = PerRow() if group_size == -1 else PerGroup(group_size) - weight_dtype = getattr(torch, f"int{bit_width}") - - try: - quantize_( - model, - EmbeddingQuantizer( - weight_dtype=weight_dtype, - granularity=granularity, - has_weight_zeros=has_weight_zeros, - use_fallback=False, - ), - ) - except Exception as e: - print( - "Encountered error during quantization with experimental EmbeddingQuantization: {e}" - ) - if quantizer == "experimental:shared": - has_weight_zeros = q_kwargs["has_weight_zeros"] - granularity = PerRow() if group_size == -1 else PerGroup(group_size) - weight_dtype = getattr(torch, f"int{bit_width}") - - try: - quantize_( - model, - SharedEmbeddingQuantizer( - weight_dtype=weight_dtype, - granularity=granularity, - has_weight_zeros=has_weight_zeros, - use_fallback=False, - ), - ) - except Exception as e: - print( - "Encountered error during quantization with experimental SharedEmbeddingQuantization: {e}" - ) # We set global precision from quantize options if it is specified at cli.py:485 # so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat @@ -1023,7 +1029,7 @@ def quantized_model(self) -> nn.Module: quantizer_class_dict = { "embedding": EmbeddingOnlyQuantHandler, - "embedding:wx": IntxWeightEmbeddingQuantizer, + "embedding:wx": None, "linear:int8": WeightOnlyInt8QuantHandler, "precision": PrecisionHandler, "executor": ExecutorHandler, From 89ff85f2b1d6a491021646717f3a521dfee553cd Mon Sep 17 00:00:00 2001 From: dillondesilva Date: Mon, 12 May 2025 22:57:58 +1000 Subject: [PATCH 3/5] addressing style nits + quant order --- torchchat/utils/quantize.py | 311 +++++++++++++++++------------------- 1 file changed, 144 insertions(+), 167 deletions(-) diff --git a/torchchat/utils/quantize.py b/torchchat/utils/quantize.py index 7f7f0e06d..713a5e6d1 100644 --- a/torchchat/utils/quantize.py +++ b/torchchat/utils/quantize.py @@ -68,21 +68,16 @@ ######################################################################### ### handle arg validation ### - import inspect - def get_named_parameters(func: Callable) -> List[str]: # Get the signature of the function - signature = inspect.signature(func) # Extract the parameters from the signature - parameters = signature.parameters # Filter and return named parameters - named_params = [ name for name, param in parameters.items() @@ -125,160 +120,166 @@ def quantize_model( if isinstance(quantize_options, str): quantize_options = json.loads(quantize_options) + + ordered_quantize_options = { } + for quantizer in quantizer_class_dict.keys(): + if quantizer in quantize_options: + ordered_quantize_options |= { quantizer : quantize_options.pop(quantizer) } + + if len(quantize_options) != 0: + raise RuntimeError(f"unknown quantizer(s) {quantize_options.keys()} specified") + + quantize_options = ordered_quantize_options for quantizer, q_kwargs in quantize_options.items(): - if quantizer not in quantizer_class_dict: - raise RuntimeError(f"unknown quantizer {quantizer} specified") - else: - # Use tensor subclass API for int4 weight only. 
- if quantizer == "experimental:embedding": - group_size = q_kwargs["groupsize"] - bit_width = q_kwargs["bitwidth"] - has_weight_zeros = q_kwargs["has_weight_zeros"] - weight_granularity = ( - PerAxis() if group_size == -1 else PerGroup(group_size) - ) - weight_dtype = getattr(torch, f"int{bit_width}") - weight_mapping_type = ( - MappingType.ASYMMETRIC - if has_weight_zeros - else MappingType.SYMMETRIC + if quantizer == "experimental:embedding": + group_size = q_kwargs["groupsize"] + bit_width = q_kwargs["bitwidth"] + has_weight_zeros = q_kwargs["has_weight_zeros"] + weight_granularity = ( + PerAxis() if group_size == -1 else PerGroup(group_size) + ) + weight_dtype = getattr(torch, f"int{bit_width}") + weight_mapping_type = ( + MappingType.ASYMMETRIC + if has_weight_zeros + else MappingType.SYMMETRIC + ) + + try: + model = EmbeddingQuantizer( + weight_dtype=weight_dtype, + granularity=weight_granularity, + mapping_type=weight_mapping_type, + use_fallback=False, + ).quantize(model) + except Exception as e: + print( + "Encountered error during quantization with experimental EmbeddingQuantization: {e}" ) + if quantizer == "experimental:shared": + group_size = q_kwargs["groupsize"] + bit_width = q_kwargs["bitwidth"] + has_weight_zeros = q_kwargs["has_weight_zeros"] + weight_granularity = ( + PerAxis() if group_size == -1 else PerGroup(group_size) + ) + weight_dtype = getattr(torch, f"int{bit_width}") + weight_mapping_type = ( + MappingType.ASYMMETRIC + if has_weight_zeros + else MappingType.SYMMETRIC + ) - try: - model = EmbeddingQuantizer( - weight_dtype=weight_dtype, - granularity=weight_granularity, - mapping_type=weight_mapping_type, - use_fallback=False, - ).quantize(model) - except Exception as e: - print( - "Encountered error during quantization with experimental EmbeddingQuantization: {e}" - ) - if quantizer == "experimental:shared": - group_size = q_kwargs["groupsize"] - bit_width = q_kwargs["bitwidth"] - has_weight_zeros = q_kwargs["has_weight_zeros"] - weight_granularity = ( - PerAxis() if group_size == -1 else PerGroup(group_size) + try: + model = SharedEmbeddingQuantizer( + weight_dtype=weight_dtype, + granularity=weight_granularity, + mapping_type=weight_mapping_type, + use_fallback=False, + ).quantize(model) + except Exception as e: + print( + "Encountered error during quantization with experimental SharedEmbeddingQuantization: {e}" ) - weight_dtype = getattr(torch, f"int{bit_width}") - weight_mapping_type = ( - MappingType.ASYMMETRIC - if has_weight_zeros - else MappingType.SYMMETRIC + # Use tensor subclass API for int4 weight only. + if (device in ["cuda", "xpu", "npu"]) and quantizer == "linear:int4": + quantize_(model, int4_weight_only(q_kwargs["groupsize"])) + if not support_tensor_subclass: + unwrap_tensor_subclass(model) + continue + if quantizer == "linear:a8wxdq": + if get_precision() != torch.float32: + print( + f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32." 
) + set_precision(torch.float32) + group_size = q_kwargs["groupsize"] + bit_width = q_kwargs["bitwidth"] + has_weight_zeros = q_kwargs["has_weight_zeros"] + granularity = PerAxis() if group_size == -1 else PerGroup(group_size) + weight_dtype = getattr(torch, f"int{bit_width}") + weight_mapping_type = ( + MappingType.ASYMMETRIC + if has_weight_zeros + else MappingType.SYMMETRIC + ) - try: - model = SharedEmbeddingQuantizer( + try: + quantize_( + model, + Int8DynamicActivationIntxWeightConfig( weight_dtype=weight_dtype, - granularity=weight_granularity, - mapping_type=weight_mapping_type, - use_fallback=False, - ).quantize(model) - except Exception as e: - print( - "Encountered error during quantization with experimental SharedEmbeddingQuantization: {e}" - ) - - if (device in ["cuda", "xpu", "npu"]) and quantizer == "linear:int4": - quantize_(model, int4_weight_only(q_kwargs["groupsize"])) - if not support_tensor_subclass: - unwrap_tensor_subclass(model) - continue - if quantizer == "linear:a8wxdq": - if get_precision() != torch.float32: - print( - f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32." - ) - set_precision(torch.float32) - group_size = q_kwargs["groupsize"] - bit_width = q_kwargs["bitwidth"] - has_weight_zeros = q_kwargs["has_weight_zeros"] - granularity = PerAxis() if group_size == -1 else PerGroup(group_size) - weight_dtype = getattr(torch, f"int{bit_width}") - weight_mapping_type = ( - MappingType.ASYMMETRIC - if has_weight_zeros - else MappingType.SYMMETRIC + weight_granularity=granularity, + weight_mapping_type=weight_mapping_type, + layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), + ), ) - - try: - quantize_( - model, - Int8DynamicActivationIntxWeightConfig( - weight_dtype=weight_dtype, - weight_granularity=granularity, - weight_mapping_type=weight_mapping_type, - layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), - ), - ) - except Exception as e: - print("Encountered error during quantization: {e}") - print("Trying with QDQLayout") - quantize_( - model, - Int8DynamicActivationIntxWeightConfig( - weight_dtype=weight_dtype, - weight_granularity=granularity, - weight_mapping_type=weight_mapping_type, - layout=QDQLayout(), - ), - ) - if not support_tensor_subclass: - unwrap_tensor_subclass(model) - continue - if quantizer == "embedding:wx": - # These quantizers require float32 input weights. Note that after quantization, - # the weights will no longer be float32, but lowbit integers - - if get_precision() != torch.float32: - print( - f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32." 
- ) - set_precision(torch.float32) - - group_size = q_kwargs["groupsize"] - bit_width = q_kwargs["bitwidth"] - has_weight_zeros = q_kwargs.get("has_weight_zeros", True) - q_kwargs["granularity"] = ( - PerAxis() if group_size == -1 else PerGroup(group_size) + except Exception as e: + print("Encountered error during quantization: {e}") + print("Trying with QDQLayout") + quantize_( + model, + Int8DynamicActivationIntxWeightConfig( + weight_dtype=weight_dtype, + weight_granularity=granularity, + weight_mapping_type=weight_mapping_type, + layout=QDQLayout(), + ), ) - q_kwargs["weight_dtype"] = getattr(torch, f"int{bit_width}") - q_kwargs["mapping_type"] = ( - MappingType.ASYMMETRIC - if has_weight_zeros - else MappingType.SYMMETRIC + if not support_tensor_subclass: + unwrap_tensor_subclass(model) + continue + if quantizer == "embedding:wx": + # These quantizers require float32 input weights. Note that after quantization, + # the weights will no longer be float32, but lowbit integers + + if get_precision() != torch.float32: + print( + f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32." ) - q_kwargs["use_fallback"] = False - del q_kwargs["groupsize"] - del q_kwargs["bitwidth"] + set_precision(torch.float32) - if quantizer == "linear:afpwx" and device != "mps": - raise RuntimeError( - "linear:afpwx quantization can only run on mps device!" - ) - # We set global precision from quantize options if it is specified at cli.py:485 - # so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat + group_size = q_kwargs["groupsize"] + bit_width = q_kwargs["bitwidth"] + has_weight_zeros = q_kwargs.get("has_weight_zeros", True) + q_kwargs["granularity"] = ( + PerAxis() if group_size == -1 else PerGroup(group_size) + ) + q_kwargs["weight_dtype"] = getattr(torch, f"int{bit_width}") + q_kwargs["mapping_type"] = ( + MappingType.ASYMMETRIC + if has_weight_zeros + else MappingType.SYMMETRIC + ) + q_kwargs["use_fallback"] = False + del q_kwargs["groupsize"] + del q_kwargs["bitwidth"] - precision = get_precision() + if quantizer == "linear:afpwx" and device != "mps": + raise RuntimeError( + "linear:afpwx quantization can only run on mps device!" 
+ ) + # We set global precision from quantize options if it is specified at cli.py:485 + # so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat - q = quantizer_class_dict[quantizer] - named_params = get_named_parameters(q.__init__) - q_kwargs = validate_args(named_params, q_kwargs, quantizer) + precision = get_precision() - # Handle tokenizer for scenarios where the quantizer needs to tokenizer sample inputs + q = quantizer_class_dict[quantizer] + named_params = get_named_parameters(q.__init__) + q_kwargs = validate_args(named_params, q_kwargs, quantizer) - if "tokenizer" in named_params: - q_kwargs["tokenizer"] = tokenizer - if quantizer == "embedding:wx": - quant_handler = q(**q_kwargs) - else: - quant_handler = q(device=device, precision=precision, **q_kwargs) + # Handle tokenizer for scenarios where the quantizer needs to tokenizer sample inputs - # quantize model + if "tokenizer" in named_params: + q_kwargs["tokenizer"] = tokenizer + if quantizer == "embedding:wx": + quant_handler = q(**q_kwargs) + else: + quant_handler = q(device=device, precision=precision, **q_kwargs) + + # quantize model - model = quant_handler.quantize(model) + model = quant_handler.quantize(model) ######################################################################### @@ -445,35 +446,29 @@ def dynamically_quantize_per_channel( x = F.pad(x, (0, padding)) items = groupsize # default setup for affine quantization of activations - eps = torch.finfo(torch.float32).eps x = x.view(x.shape[0], x.shape[1] // items, items) # get min and max - min_val, max_val = torch.aminmax(x, dim=2) # print(f"min_val {min_val}") # print(f"max_val {max_val}") # calculate scales and zero_points based on min and max # reference: https://fburl.com/code/srbiybme - min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) device = min_val_neg.device # reference: https://fburl.com/code/4wll53rk - max_val_pos = torch.max(-min_val_neg, max_val_pos) scales = max_val_pos / (float(quant_max - quant_min) / 2) # ensure scales is the same dtype as the original tensor scales = torch.clamp(scales, min=eps).to(x.dtype) zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) - # quantize based on qmin/qmax/scales/zp # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63 - x_div = x / scales.unsqueeze(-1) x_round = torch.round(x_div) x_zp = x_round + zero_points.unsqueeze(-1) @@ -489,7 +484,6 @@ def dynamically_quantize_per_channel( def get_group_qparams(w, n_bit=4, groupsize=128, *, scales_dtype=torch.float): # needed for GPTQ with padding - if groupsize > w.shape[-1]: groupsize = w.shape[-1] assert groupsize > 1 @@ -535,7 +529,6 @@ def unpack_scales_and_zeros(scales_and_zeros): def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128): assert groupsize > 1 # needed for GPTQ single column quantize - if groupsize > w.shape[-1] and scales.shape[-1] == 1: groupsize = w.shape[-1] assert w.shape[-1] % groupsize == 0 @@ -573,7 +566,6 @@ def group_dequantize_tensor_from_qparams( ): assert groupsize > 1 # needed for GPTQ single column dequantize - if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1: groupsize = w_int32.shape[-1] assert w_int32.shape[-1] % groupsize == 0 @@ -605,7 +597,6 @@ def linear_int8_aoti(input, weight, scales): n_groups = scales.numel() // scales.shape[0] # we special-case channel-wise, because we 
know how to make that fast - if n_groups == 1: scales = scales.view(-1) if ( @@ -615,10 +606,8 @@ def linear_int8_aoti(input, weight, scales): ): lin = F.linear(input, weight.to(dtype=input.dtype)) # print(f"linear shape {lin.shape}, scales shape {scales.shape}") - return lin * scales # Use int8pack_mm for CPU eager - return torch.ops.aten._weight_int8pack_mm( input.reshape(-1, input.shape[-1]), weight, @@ -670,14 +659,12 @@ def linear_int8_et(input, weight, scales): n_groups = scales.numel() // scales.shape[0] # we special-case channel-wise, because we know how to make that fast - if n_groups == 1: scales = scales.view(-1) if True: lin = F.linear(input, weight.to(dtype=input.dtype)) # print(f"linear shape {lin.shape}, scales shape {scales.shape}") - return lin * scales return _qdq_dynamic_quantized_linear( x_fp32=input.float(), @@ -793,7 +780,6 @@ def quantize(self, module): raise ValueError(f"Unsupported bitwidth {self.bitwidth}") for name, child in module.named_children(): # print(f"name: {name}") - if isinstance(child, nn.Linear): if ( (self.node_type == "*") @@ -801,14 +787,12 @@ def quantize(self, module): or (self.node_type == "!output" and name != "output") ): # print(f"{name, child}") - input_weight = child.weight.float() # print(f"{name, child}") # print(f"in_features: {child.in_features}") # print(f"out_features: {child.out_features}") # print(f"expanded weight shape {input_weight.shape}") - weight, scales, _ = dynamically_quantize_per_channel( input_weight, range_min, @@ -986,7 +970,6 @@ def quantize(self, module): raise ValueError(f"Unsupported bitwidth {self.bitwidth}") for name, child in module.named_children(): # print(f"name: {name}") - if isinstance(child, nn.Embedding): # print(f"Embedding identified: {fqn, mod}") # print(f"weights size: {child.weight.size()}") @@ -995,7 +978,6 @@ def quantize(self, module): # print( # f"quantize {fqn, mod} with groupsize {self.groupsize}, bitwidth {self.bitwidth}" # ) - weight, scales, _ = dynamically_quantize_per_channel( child.weight.float(), range_min, @@ -1021,7 +1003,6 @@ def quantize(self, module): # print(f"{name, child}") # print(f"weights size: {child.weight.size()}") - setattr( module, name, @@ -1049,7 +1030,6 @@ def quantized_model(self) -> nn.Module: # Map each quantizer configuration to a class implementing that quantizer # Must come last because __future__ annotations don't work for naked # class references - quantizer_class_dict = { "embedding": EmbeddingOnlyQuantHandler, "embedding:wx": EmbeddingQuantizer, @@ -1059,8 +1039,8 @@ def quantized_model(self) -> nn.Module: "linear:int4": Int4WeightOnlyQuantizer, "linear:a8wxdq": None, # uses quantize_ API "linear:a8w4dq": Int8DynActInt4WeightQuantizer, - "experimental:embedding": EmbeddingQuantizer, "experimental:shared": SharedEmbeddingQuantizer, + "experimental:embedding": EmbeddingQuantizer, } try: @@ -1071,7 +1051,6 @@ def quantized_model(self) -> nn.Module: torchao_build_path = f"{os.getcwd()}/torchao-build" # Try loading quantizer - torchao_experimental_quant_api_spec = importlib.util.spec_from_file_location( "torchao_experimental_quant_api", f"{torchao_build_path}/src/ao/torchao/experimental/quant_api.py", @@ -1084,11 +1063,9 @@ def quantized_model(self) -> nn.Module: torchao_experimental_quant_api ) from torchao_experimental_quant_api import UIntxWeightOnlyLinearQuantizer - quantizer_class_dict["linear:afpwx"] = UIntxWeightOnlyLinearQuantizer # Try loading custom op - try: libname = "libtorchao_ops_mps_aten.dylib" libpath = 
f"{torchao_build_path}/cmake-out/lib/{libname}" From 30afac596ddc49e97135e81521cb61b264808e61 Mon Sep 17 00:00:00 2001 From: dillondesilva Date: Mon, 12 May 2025 23:08:01 +1000 Subject: [PATCH 4/5] couple other fixes to address style nits missing in earlier commit --- torchchat/utils/quantize.py | 47 +++++-------------------------------- 1 file changed, 6 insertions(+), 41 deletions(-) diff --git a/torchchat/utils/quantize.py b/torchchat/utils/quantize.py index 713a5e6d1..f3a196d2f 100644 --- a/torchchat/utils/quantize.py +++ b/torchchat/utils/quantize.py @@ -20,8 +20,6 @@ # torchao Quantizer: # * Int8DynActInt4WeightQuantizer: dynamic quantization for int8 acitvation and int4 weight. Using torchao API. # - - from __future__ import annotations import json @@ -30,7 +28,6 @@ # from math import gcd from typing import Any, Callable, Dict, List, Optional - import torch import torch.nn as nn import torch.nn.functional as F @@ -62,7 +59,6 @@ # Flag for whether the a8wxdq quantizer is available. - torchao_experimental_load_error: Optional[Exception] = None ######################################################################### @@ -79,10 +75,8 @@ def get_named_parameters(func: Callable) -> List[str]: # Filter and return named parameters named_params = [ - name - for name, param in parameters.items() - if param.kind - in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY) + name for name, param in parameters.items() + if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY) ] return named_params @@ -90,9 +84,7 @@ def get_named_parameters(func: Callable) -> List[str]: def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None) -> Dict[str, Any]: for key in list(q_kwargs.keys()): if key not in named_params: - print( - f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring." - ) + print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.") del q_kwargs[key] return q_kwargs @@ -232,7 +224,6 @@ def quantize_model( if quantizer == "embedding:wx": # These quantizers require float32 input weights. Note that after quantization, # the weights will no longer be float32, but lowbit integers - if get_precision() != torch.float32: print( f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32." 
@@ -261,7 +252,6 @@ def quantize_model( ) # We set global precision from quantize options if it is specified at cli.py:485 # so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat - precision = get_precision() q = quantizer_class_dict[quantizer] @@ -269,7 +259,6 @@ def quantize_model( q_kwargs = validate_args(named_params, q_kwargs, quantizer) # Handle tokenizer for scenarios where the quantizer needs to tokenizer sample inputs - if "tokenizer" in named_params: q_kwargs["tokenizer"] = tokenizer if quantizer == "embedding:wx": @@ -278,7 +267,6 @@ def quantize_model( quant_handler = q(device=device, precision=precision, **q_kwargs) # quantize model - model = quant_handler.quantize(model) @@ -288,13 +276,7 @@ def quantize_model( class QuantHandler: - def __init__( - self, - model: Optional[nn.Module] = None, - device="cpu", - precision=None, - tokenizer=None, - ): + def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None): self.model_ = model self.device = device self.tokenizer = tokenizer @@ -312,7 +294,6 @@ def quantized_model(self) -> nn.Module: return self.model_ # fallback for TC QuantHandlers that do not implement the method .quantize() - def quantize(self, model: nn.Module) -> nn.Module: self.model_ = model return self.quantized_model() @@ -323,15 +304,7 @@ def quantize(self, model: nn.Module) -> nn.Module: class PrecisionHandler(QuantHandler): - def __init__( - self, - model: Optional[nn.Module] = None, - device="cpu", - precision=None, - tokenizer=None, - *, - dtype, - ): + def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None, tokenizer=None, *, dtype): self.model_ = model self.device = device self.tokenizer = tokenizer @@ -360,15 +333,7 @@ def quantized_model(self) -> nn.Module: class ExecutorHandler(QuantHandler): - def __init__( - self, - model: Optional[nn.Module] = None, - device="cpu", - precision=None, - tokenizer=None, - *, - accelerator, - ): + def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None, tokenizer=None, *, accelerator): self.model_ = model if isinstance(accelerator, str): From 0c1155f73e38dc348633f835a3da5f6e11948318 Mon Sep 17 00:00:00 2001 From: dillondesilva Date: Mon, 12 May 2025 23:11:16 +1000 Subject: [PATCH 5/5] minor changes for nits --- torchchat/utils/quantize.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchchat/utils/quantize.py b/torchchat/utils/quantize.py index f3a196d2f..7686344db 100644 --- a/torchchat/utils/quantize.py +++ b/torchchat/utils/quantize.py @@ -429,7 +429,6 @@ def dynamically_quantize_per_channel( max_val_pos = torch.max(-min_val_neg, max_val_pos) scales = max_val_pos / (float(quant_max - quant_min) / 2) # ensure scales is the same dtype as the original tensor - scales = torch.clamp(scales, min=eps).to(x.dtype) zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) # quantize based on qmin/qmax/scales/zp @@ -713,7 +712,7 @@ class WeightOnlyInt8QuantHandler(QuantHandler): def __init__( self, model: Optional[nn.Module] = None, - device=None, + device = None, precision=None, tokenizer=None, *,
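
---
Note (not part of the patches above): the sketch below mirrors how the new "experimental:embedding" and "experimental:shared" branches added in quantize_model() drive torchao after this series, using the keyword names the final patch settles on (weight_dtype, granularity, mapping_type, use_fallback) and the torchao pin bumped in PATCH 2/5. The helper name quantize_embeddings and its default argument values are illustrative only and do not exist in torchchat; import paths and parameter names may differ in other torchao revisions.

    # Mirrors the "experimental:embedding" / "experimental:shared" branches in
    # quantize_model(): translate torchchat's q_kwargs ("bitwidth", "groupsize",
    # "has_weight_zeros") into the torchao quantizer arguments the patch uses.
    import torch
    from torchao.experimental.quant_api import (
        EmbeddingQuantizer,
        SharedEmbeddingQuantizer,
    )
    from torchao.quantization.granularity import PerAxis, PerGroup
    from torchao.quantization.quant_api import MappingType

    def quantize_embeddings(model, *, bitwidth=4, groupsize=32,
                            has_weight_zeros=True, shared=False):
        # weight dtype is resolved the same way the patch does it
        weight_dtype = getattr(torch, f"int{bitwidth}")
        # groupsize == -1 means per-axis (row-wise) granularity, otherwise per-group
        granularity = PerAxis() if groupsize == -1 else PerGroup(groupsize)
        # zero points imply an asymmetric mapping, otherwise symmetric
        mapping_type = (
            MappingType.ASYMMETRIC if has_weight_zeros else MappingType.SYMMETRIC
        )
        quantizer_cls = SharedEmbeddingQuantizer if shared else EmbeddingQuantizer
        return quantizer_cls(
            weight_dtype=weight_dtype,
            granularity=granularity,
            mapping_type=mapping_type,
            use_fallback=False,
        ).quantize(model)

Since quantize_model() accepts quantize_options as a JSON string (it is parsed with json.loads), the same code path should be reachable with a config along the lines of {"experimental:embedding": {"bitwidth": 4, "groupsize": 32, "has_weight_zeros": true}}; "experimental:shared" takes the same keys and routes through SharedEmbeddingQuantizer, which torchao intends for checkpoints whose input embedding and output projection share weights.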