diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index fe839cf4..8ec2a5ae 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -8,6 +8,7 @@ from typing import Optional from enum import Enum from huggingface_hub import hf_hub_download +from text_generation_server.utils.log import log_master app = typer.Typer() @@ -87,15 +88,17 @@ def serve( ) if len(lora_adapter_ids) > 0: - logger.warning( - f"LoRA adapters are enabled. This is an experimental feature and may not work as expected." + log_master( + logger.warning, + f"LoRA adapters are enabled. This is an experimental feature and may not work as expected.", ) # TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled # and warn the user if len(lora_adapter_ids) > 0 and os.getenv("CUDA_GRAPHS", None) is not None: - logger.warning( - f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs." + log_master( + logger.warning, + f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs.", ) global CUDA_GRAPHS CUDA_GRAPHS = None diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 99c490d5..54da63e8 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -3,6 +3,7 @@ import torch from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.layers.attention import Seqlen +from text_generation_server.utils.log import log_master from loguru import logger major, minor = torch.cuda.get_device_capability() @@ -136,7 +137,10 @@ if ENGINE != "triton": try: import flash_attn_2_cuda - logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.") + log_master( + logger.info, + "ROCm: using Flash Attention 2 Composable Kernel implementation.", + ) except ImportError as e: if major >= 8: architecture_suffix = f"-{SYSTEM}" diff --git a/server/text_generation_server/layers/bnb.py b/server/text_generation_server/layers/bnb.py index 925b0b2d..aae2bd1a 100644 --- a/server/text_generation_server/layers/bnb.py +++ b/server/text_generation_server/layers/bnb.py @@ -4,19 +4,11 @@ from functools import lru_cache import bitsandbytes as bnb import torch from bitsandbytes.nn import Int8Params, Params4bit -from loguru import logger -from text_generation_server.utils.weights import Weight - - -@lru_cache(1) -def warn_deprecate_bnb(): - logger.warning( - "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce" - ) +from text_generation_server.utils.weights import UnquantizedWeight @dataclass -class BNBWeight(Weight): +class BNBWeight(UnquantizedWeight): weight: torch.Tensor def get_linear(self, bias: torch.Tensor): @@ -82,7 +74,7 @@ class Linear8bitLt(torch.nn.Module): @dataclass -class BNBFP4Weight(Weight): +class BNBFP4Weight(UnquantizedWeight): weight: torch.Tensor def get_linear(self, bias: torch.Tensor): @@ -90,7 +82,7 @@ class BNBFP4Weight(Weight): @dataclass -class BNBNF4Weight(Weight): +class BNBNF4Weight(UnquantizedWeight): weight: torch.Tensor def get_linear(self, bias: torch.Tensor): diff --git a/server/text_generation_server/layers/eetq.py b/server/text_generation_server/layers/eetq.py index f003f914..b1e5235a 100644 --- a/server/text_generation_server/layers/eetq.py +++ b/server/text_generation_server/layers/eetq.py @@ -2,11 +2,11 @@ from dataclasses import dataclass import torch from EETQ import quant_weights, w8_a16_gemm -from text_generation_server.utils.weights import Weight +from text_generation_server.utils.weights import UnquantizedWeight @dataclass -class EETQWeight(Weight): +class EETQWeight(UnquantizedWeight): weight: torch.Tensor def get_linear(self, bias: torch.Tensor): diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index fe083b68..4568f8a0 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -1,19 +1,29 @@ import torch from dataclasses import dataclass -from typing import Optional +from typing import Optional, Union, List +from loguru import logger from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.weights import ( + Weight, + WeightsLoader, + UnquantizedWeight, + Weights, +) +from text_generation_server.utils.log import log_master, log_once +FBGEMM_MM_AVAILABLE = False +FBGEMM_DYN_AVAILABLE = False try: import fbgemm_gpu.experimental.gen_ai - major, _ = torch.cuda.get_device_capability() - HAS_FBGEMM_MM = major == 9 - HAS_FBGEMM_DYN = major >= 8 + if SYSTEM == "cuda": + major, _ = torch.cuda.get_device_capability() + FBGEMM_MM_AVAILABLE = major == 9 + FBGEMM_DYN_AVAILABLE = major >= 8 except (ImportError, ModuleNotFoundError): - HAS_FBGEMM_MM = False - HAS_FBGEMM_DYN = False + log_master(logger.warning, "FBGEMM fp8 kernels are not installed.") def get_fp8_linear() -> torch.nn.Module: @@ -33,7 +43,7 @@ def get_fp8_linear() -> torch.nn.Module: def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn): - if HAS_FBGEMM_DYN: + if FBGEMM_DYN_AVAILABLE: qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row( weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype ) @@ -54,18 +64,83 @@ def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn): return qweight, scale +class HybridFP8UnquantLoader(WeightsLoader): + """Weight loader that loads FP8 and unquantized Torch tensors.""" + + def __init__(self, activation_scale_ub: Optional[float]): + self.activation_scale_ub = activation_scale_ub + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + w = weights.get_packed_sharded( + f"{prefix}.weight", dim=0, block_sizes=block_sizes + ) + + if w.dtype == torch.float8_e4m3fn: + # FP8 branch + scale = weights.get_packed_sharded( + f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False + ) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + + return UnquantizedWeight(w) + + def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): + w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes] + w = torch.cat(w, dim=dim) + + # FP8 branch + if w.dtype == torch.float8_e4m3fn: + scale = [ + weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False) + for p in prefixes + ] + scale = torch.cat(scale, dim=0) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + + return UnquantizedWeight(w) + + def get_weights_row(self, weights: "Weights", prefix: str): + w = weights.get_sharded(f"{prefix}.weight", dim=1) + # FP8 branch + if w.dtype == torch.float8_e4m3fn: + scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0, to_dtype=False) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + + return UnquantizedWeight(w) + + @dataclass -class Fp8Weight: +class Fp8Weight(Weight): weight: torch.Tensor dtype: torch.dtype weight_scale: Optional[torch.Tensor] = None - input_scale: Optional[torch.Tensor] = None + activation_scale_ub: Optional[float] = None def get_linear(self, bias: torch.Tensor): if self.weight_scale is None: return get_fp8_linear().from_unquant(self.weight, bias, self.dtype) return get_fp8_linear().from_fp8( - self.weight, self.weight_scale, self.input_scale, bias, self.dtype + self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype ) @@ -82,7 +157,13 @@ class Fp8Linear(torch.nn.Module): self.dtype = dtype self.qweight = qweight self.scale = scale - self.scale_upper_bound = scale_upper_bound + self.scale_upper_bound = ( + torch.tensor( + [scale_upper_bound], dtype=torch.float32, device=qweight.device + ) + if scale_upper_bound is not None + else None + ) self.bias = bias if bias is not None else None @@ -104,7 +185,9 @@ class Fp8Linear(torch.nn.Module): ) def forward(self, input: torch.Tensor) -> torch.Tensor: - if HAS_FBGEMM_MM: + if FBGEMM_MM_AVAILABLE: + log_once(logger.info, "Using FBGEMM fp8 kernels") + qinput, scale = fp8_quantize( input, scale_upper_bound=self.scale_upper_bound ) diff --git a/server/text_generation_server/layers/gptq/exllamav2.py b/server/text_generation_server/layers/gptq/exllamav2.py index 4d45822b..dc3b832f 100644 --- a/server/text_generation_server/layers/gptq/exllamav2.py +++ b/server/text_generation_server/layers/gptq/exllamav2.py @@ -9,11 +9,12 @@ from loguru import logger from text_generation_server.layers.exl2 import Exl2Weight from text_generation_server.layers.gptq import GPTQWeight +from text_generation_server.utils.log import log_master try: from exllamav2_kernels import make_q_matrix, gemm_half_q_half except ImportError: - logger.error("exllamav2_kernels not installed.") + log_master(logger.warning, "exllamav2_kernels not installed.") raise # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 725aa544..a43cdfed 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -34,6 +34,7 @@ from text_generation_server.models.custom_modeling.t5_modeling import ( ) from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.log import log_master # The flag below controls whether to allow TF32 on matmul. This flag defaults to False # in PyTorch 1.12 and later. @@ -47,9 +48,7 @@ torch.set_grad_enabled(False) __all__ = [ "Model", - "BLOOMSharded", "CausalLM", - "GalacticaSharded", "Seq2SeqLM", "get_model", ] @@ -125,7 +124,7 @@ try: ) from text_generation_server.layers.attention import SUPPORTS_WINDOWING except ImportError as e: - logger.warning(f"Could not import Flash Attention enabled models: {e}") + log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}") SUPPORTS_WINDOWING = False FLASH_ATTENTION = False @@ -137,7 +136,7 @@ MAMBA_AVAILABLE = True try: from text_generation_server.models.mamba import Mamba except ImportError as e: - logger.warning(f"Could not import Mamba: {e}") + log_master(logger.warning, f"Could not import Mamba: {e}") MAMBA_AVAILABLE = False if MAMBA_AVAILABLE: @@ -312,8 +311,11 @@ def get_model( # These quantizers only work with float16 params. dtype = torch.float16 elif quantize == "fp8": - # gemm kernels are fp8xfp8->bf16 - dtype = torch.bfloat16 + from text_generation_server.layers.fp8 import FBGEMM_MM_AVAILABLE + + if FBGEMM_MM_AVAILABLE: + # fbgemm kernels are fp8xfp8->bf16 + dtype = torch.bfloat16 else: # Keep it as default for now and let # every model resolve their own default dtype. @@ -436,7 +438,9 @@ def get_model( speculate = get_speculate() if speculate > 0: - logger.info(f"Using speculation {method} with {speculate} input ids.") + log_master( + logger.info, f"Using speculation {method} with {speculate} input ids." + ) if model_type is None: # TODO: fix how we determine model type for Mamba @@ -451,10 +455,10 @@ def get_model( if quantization_config is not None and quantize is None: method = quantization_config.get("quant_method", None) if method in {"gptq", "awq", "exl2"}: - logger.info(f"Auto selecting quantization method {method}") + log_master(logger.info, f"Auto selecting quantization method {method}") quantize = method else: - logger.info(f"Unknown quantization method {method}") + log_master(logger.warning, f"Unknown quantization method {method}") if quantize == "exl2" and sharded: raise RuntimeError( @@ -596,7 +600,7 @@ def get_model( ) except RuntimeError as e: # Lots of legacy models with various weight names. - logger.warning(f"Couldn't load flash gpt2 variant: {e}") + log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}") return CausalLM.fallback( model_id, revision, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 5237a484..df635ff2 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -418,7 +418,22 @@ class FlashLlamaModel(torch.nn.Module): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - self.layers = nn.ModuleList( + + # Skip fp8 quant for first and last layers + self.layers = nn.ModuleList() + with no_fp8(weights): + self.layers.append( + FlashLlamaLayer( + index=0, + prefix=( + "model.layers.0" if not prefix else "{prefix}.model.layers.0" + ), + config=config, + weights=weights, + ) + ) + + self.layers.extend( [ FlashLlamaLayer( index=layer_id, @@ -430,9 +445,26 @@ class FlashLlamaModel(torch.nn.Module): config=config, weights=weights, ) - for layer_id in range(config.num_hidden_layers) + # Skip first and last layers + for layer_id in range(1, config.num_hidden_layers - 1) ] ) + + with no_fp8(weights): + last_layer_id = config.num_hidden_layers - 1 + self.layers.append( + FlashLlamaLayer( + index=last_layer_id, + prefix=( + f"model.layers.{last_layer_id}" + if not prefix + else f"{prefix}.model.layers.{last_layer_id}" + ), + config=config, + weights=weights, + ) + ) + self.norm = FastRMSNorm.load( prefix="model.norm" if not prefix else f"{prefix}.model.norm", weights=weights, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 2888f1f7..cfffafa1 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -23,14 +23,13 @@ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models import Model +from text_generation_server.utils.log import log_master from text_generation_server.utils.tokens import batch_top_tokens -from text_generation_server.utils.dist import RANK from text_generation_server.utils.speculate import get_speculate from text_generation_server.utils import ( initialize_torch_distributed, weight_files, Weights, - hub, ) from text_generation_server.models.types import ( Batch, @@ -1156,31 +1155,36 @@ class FlashCausalLM(Model): f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv", ) - logger.info( - f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`." + log_master( + logger.info, + f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.", ) if os.path.isfile(tunableop_filepath): - logger.info( - f"The file {tunableop_filepath} already exists and will be reused." + log_master( + logger.info, + f"The file {tunableop_filepath} already exists and will be reused.", ) torch.cuda.tunable.read_file(tunableop_filepath) os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True) for seqlen in tuning_sequences: - logger.info(f"Warming up TunableOp for seqlen={seqlen}") + log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}") self.tunableop_warmup(seqlen) torch.cuda.tunable.write_file(tunableop_filepath) torch.cuda.tunable.tuning_enable(False) else: - logger.info( - "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp." + log_master( + logger.info, + "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.", ) if CUDA_GRAPHS: try: - logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}") + log_master( + logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}" + ) # Warmup cuda graphs for bs in CUDA_GRAPHS: if self.speculate is None or self.speculate + 1 <= bs: @@ -1188,7 +1192,9 @@ class FlashCausalLM(Model): except torch.cuda.OutOfMemoryError: logger.exception(f"Decode cuda graph warmup failed") else: - logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).") + log_master( + logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})." + ) return int(num_blocks * BLOCK_SIZE) @@ -1540,8 +1546,7 @@ class FlashCausalLM(Model): left = 0 if n_accepted_ids > 1: - if RANK == 0: - logger.debug(f"Speculated ids {n_accepted_ids - 1}") + log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}") current_stopped = False for j in range(index, index + n_accepted_ids): diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 06035ccd..ac42df30 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -1,15 +1,16 @@ import torch import os from loguru import logger -from typing import Dict +from typing import Dict, Optional + +from text_generation_server.utils.log import log_master MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None # This is overridden by the cli FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"} BLOCK_SIZE: int = 256 if FLASH_DECODING else 16 if FLASH_DECODING: - logger.info("Using FLASH_DECODING") - + log_master(logger.info, "Using FLASH_DECODING") cuda_graphs = os.getenv("CUDA_GRAPHS") if cuda_graphs is not None: @@ -26,11 +27,9 @@ else: if cuda_graphs is not None: cuda_graphs.sort(reverse=True) - CUDA_GRAPHS = cuda_graphs # This is overridden at model loading. -global MODEL_ID MODEL_ID = None @@ -41,8 +40,7 @@ def set_model_id(model_id: str): # NOTE: eventually we should move this into the router and pass back the # index in all cases. -global ADAPTER_TO_INDEX -ADAPTER_TO_INDEX: Dict[str, int] = None +ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None def set_adapter_to_index(adapter_to_index: Dict[str, int]): diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 09130b85..e7748bb9 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -15,6 +15,7 @@ from text_generation_server.utils.adapter import ( AdapterParameters, AdapterSource, ) +from text_generation_server.utils.log import log_master from loguru import logger @@ -204,8 +205,9 @@ class Model(ABC): f"order to use the dynamic adapter loading feature." ) - logger.info( - f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}" + log_master( + logger.info, + f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}", ) weight_names = tuple([v[0] for v in self.target_to_layer.values()]) ( @@ -240,8 +242,9 @@ class Model(ABC): layer_weights.add_adapter(adapter_index, adapter_weights) if len(unused_weight_names) > 0: - logger.warning( - f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}" + log_master( + logger.warning, + f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}", ) if adapter_tokenizer is not None: diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index f869f8b5..308d5a3d 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -1,4 +1,3 @@ -from itertools import repeat import torch from PIL import Image from io import BytesIO @@ -13,6 +12,7 @@ from text_generation_server.models.flash_causal_lm import ( FlashCausalLMBatch, FlashCausalLM, ) +from text_generation_server.utils.log import log_master from transformers import AutoProcessor tracer = trace.get_tracer(__name__) @@ -56,8 +56,9 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str num_features = get_number_of_features(height, width, config) from loguru import logger - logger.info( - f"Found {num_features} features in image of resolution {height}x{width}" + log_master( + logger.info, + f"Found {num_features} features in image of resolution {height}x{width}", ) return "" * num_features diff --git a/server/text_generation_server/utils/log.py b/server/text_generation_server/utils/log.py index b1456f1e..4385c71e 100644 --- a/server/text_generation_server/utils/log.py +++ b/server/text_generation_server/utils/log.py @@ -1,6 +1,15 @@ from functools import lru_cache +from text_generation_server.utils.dist import RANK @lru_cache(10) -def log_once(log, msg: str): - log(msg) +def log_once(log, msg: str, master=True): + if master: + log_master(log, msg) + else: + log(msg) + + +def log_master(log, msg: str): + if RANK == 0: + log(msg) diff --git a/server/text_generation_server/utils/quantization.py b/server/text_generation_server/utils/quantization.py index 14dbf58b..8ff6ddf1 100644 --- a/server/text_generation_server/utils/quantization.py +++ b/server/text_generation_server/utils/quantization.py @@ -11,6 +11,7 @@ from text_generation_server.utils.weights import ( ) +# TODO: Split this config to have a single config type per quant method @dataclass class _QuantizerConfig: bits: int @@ -21,6 +22,11 @@ class _QuantizerConfig: sym: bool +@dataclass +class _FP8QuantizerConfig: + activation_scale_ub: float + + # We should probably do this with Pytantic JSON deserialization, # but for now we'll stay close to the old _set_gptq_params. def _get_quantizer_config(model_id, revision): @@ -63,6 +69,17 @@ def _get_quantizer_config(model_id, revision): desc_act = data["desc_act"] if "version" in data and data["version"] == "GEMM": quant_method = "awq" + # FP8 config + except KeyError: + try: + filename = os.path.join(model_id, filename) + with open(filename, "r") as f: + data = json.load(f) + return _FP8QuantizerConfig( + activation_scale_ub=data["activation_scale_ub"] + ) + except: + pass except Exception: filename = "quant_config.json" try: @@ -99,6 +116,12 @@ def get_loader( if quantize in {"awq", "gptq"}: from text_generation_server.layers.gptq import GPTQWeightsLoader + # TODO: improve check once we have one config type per quantize value + if not isinstance(quantizer_config, _QuantizerConfig): + raise ValueError( + f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config." + ) + return GPTQWeightsLoader( bits=quantizer_config.bits, desc_act=quantizer_config.desc_act, @@ -127,18 +150,28 @@ def get_loader( from text_generation_server.layers.exl2 import Exl2WeightsLoader return Exl2WeightsLoader() - elif quantize == "fp8": - from text_generation_server.layers.fp8 import Fp8Weight - - return DefaultWeightsLoader(Fp8Weight) elif quantize == "marlin": from text_generation_server.layers.marlin import MarlinWeightsLoader + # TODO: improve check once we have one config type per quantize value + if not isinstance(quantizer_config, _QuantizerConfig): + raise ValueError( + f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config." + ) + return MarlinWeightsLoader( bits=quantizer_config.bits, is_marlin_24=quantizer_config.checkpoint_format == "marlin_24", ) - elif quantize is None: - return DefaultWeightsLoader() + elif quantize == "fp8" or quantize is None: + from text_generation_server.layers.fp8 import HybridFP8UnquantLoader + + # Since the default for the quantize config is _QuantizerConfig, + # we need to add this check to not get an attribute error + activation_scale_ub = None + if isinstance(quantizer_config, _FP8QuantizerConfig): + activation_scale_ub = quantizer_config.activation_scale_ub + + return HybridFP8UnquantLoader(activation_scale_ub) else: raise ValueError(f"Unknown quantization method: {quantize}") diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 108ba6e7..66bb6051 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -3,7 +3,7 @@ import torch from abc import ABC, abstractmethod from contextlib import contextmanager from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Union, Type from safetensors import safe_open from dataclasses import dataclass @@ -84,9 +84,8 @@ class Weight(ABC): @dataclass -class UnquantizedWeight: +class UnquantizedWeight(Weight): weight: torch.Tensor - dtype: torch.dtype def get_linear(self, bias: torch.Tensor): from text_generation_server.layers.linear import FastLinear, FastLinearROCm @@ -100,7 +99,7 @@ class UnquantizedWeight: class DefaultWeightsLoader(WeightsLoader): """Weight loader that loads (unquantized) Torch tensors.""" - def __init__(self, weight_class: Optional = None): + def __init__(self, weight_class: Type[UnquantizedWeight]): """Create a loader. Weights will be wrapped using the given `weights_class`, normally this will be `UnquantizedWeight`, but a quantizer-specific class such as `Fp8Weight` can be used to quantize the weights during loading. @@ -121,92 +120,21 @@ class DefaultWeightsLoader(WeightsLoader): prefix: str, block_sizes: Union[int, List[int]], ): - w = weights.get_packed_sharded( - f"{prefix}.weight", dim=0, block_sizes=block_sizes + + return self.weight_class( + weights.get_packed_sharded( + f"{prefix}.weight", dim=0, block_sizes=block_sizes + ), ) - if w.dtype == torch.float8_e4m3fn: - # FIXME: here to avoid circular import - from text_generation_server.layers.fp8 import Fp8Weight - - if self.weight_class is not None and self.weight_class != Fp8Weight: - raise RuntimeError( - f"Deserialized quantised fp8 weights but weight class is {self.weight_class}" - ) - # FIXME: here to avoid circular import - from text_generation_server.layers.fp8 import Fp8Weight - - # FP8 branch - scale = weights.get_packed_sharded( - f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, cast=False - ) - input_scale = weights.get_tensor(f"{prefix}.input_scale", cast=False) - return Fp8Weight( - weight=w, - weight_scale=scale, - input_scale=input_scale, - dtype=weights.dtype, - ) - - if self.weight_class is None: - return UnquantizedWeight(w, dtype=weights.dtype) - return self.weight_class(w, dtype=weights.dtype) - def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes] - w = torch.cat(w, dim=dim) - - # FP8 branch - if w.dtype == torch.float8_e4m3fn: - # FIXME: here to avoid circular import - from text_generation_server.layers.fp8 import Fp8Weight - - if self.weight_class is not None and self.weight_class != Fp8Weight: - raise RuntimeError( - f"Deserialized quantised fp8 weights but weight class is {self.weight_class}" - ) - - scale = [ - weights.get_sharded(f"{p}.weight_scale", dim=0, cast=False) - for p in prefixes - ] - scale = torch.cat(scale, dim=0) - input_scale = weights.get_tensor(f"{prefixes[0]}.input_scale", cast=False) - return Fp8Weight( - weight=w, - weight_scale=scale, - input_scale=input_scale, - dtype=weights.dtype, - ) - - if self.weight_class is None: - return UnquantizedWeight(w, dtype=weights.dtype) - return self.weight_class(w, dtype=weights.dtype) + return self.weight_class(torch.cat(w, dim=dim)) def get_weights_row(self, weights: "Weights", prefix: str): - w = weights.get_sharded(f"{prefix}.weight", dim=1) - # FP8 branch - if w.dtype == torch.float8_e4m3fn: - # FIXME: here to avoid circular import - from text_generation_server.layers.fp8 import Fp8Weight - - if self.weight_class is not None and self.weight_class != Fp8Weight: - raise RuntimeError( - f"Deserialized quantised fp8 weights but weight class is {self.weight_class}" - ) - - scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0, cast=False) - input_scale = weights.get_tensor(f"{prefix}.input_scale", cast=False) - return Fp8Weight( - weight=w, - weight_scale=scale, - input_scale=input_scale, - dtype=weights.dtype, - ) - - if self.weight_class is None: - return UnquantizedWeight(w, dtype=weights.dtype) - return self.weight_class(w, dtype=weights.dtype) + return self.weight_class( + weights.get_sharded(f"{prefix}.weight", dim=1), + ) class Weights: @@ -280,7 +208,7 @@ class Weights: def get_shape(self, tensor_name: str): return self._get_slice(tensor_name).get_shape() - def get_tensor(self, tensor_name: str, to_device=True, cast=True): + def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) tensor = f.get_tensor(tensor_name) @@ -295,14 +223,14 @@ class Weights: torch.int32, torch.int64, ] - and cast + and to_dtype ): tensor = tensor.to(dtype=self.dtype) if to_device: tensor = tensor.to(device=self.device) return tensor - def get_partial_sharded(self, tensor_name: str, dim: int, cast=True): + def get_partial_sharded(self, tensor_name: str, dim: int, to_dtype=True): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) @@ -323,12 +251,15 @@ class Weights: # Special case for gptq which shouldn't convert # u4 which are disguised as int32. exl2 uses int16. # FP8 uses torch.float8_e4m3fn. - if tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32) and cast: + if ( + tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32) + and to_dtype + ): tensor = tensor.to(dtype=self.dtype) tensor = tensor.to(device=self.device) return tensor - def get_sharded(self, tensor_name: str, dim: int, cast=True): + def get_sharded(self, tensor_name: str, dim: int, to_dtype=True): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) @@ -337,10 +268,14 @@ class Weights: assert ( size % world_size == 0 ), f"The choosen size {size} is not compatible with sharding on {world_size} shards" - return self.get_partial_sharded(tensor_name, dim, cast=cast) + return self.get_partial_sharded(tensor_name, dim, to_dtype=to_dtype) def get_packed_sharded( - self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]], cast=True + self, + tensor_name: str, + dim: int, + block_sizes: Union[int, List[int]], + to_dtype=True, ) -> torch.Tensor: """ Get a shard from a tensor that packs multiple tensors. @@ -394,7 +329,7 @@ class Weights: torch.int32, torch.int64, ] - and cast + and to_dtype ): tensor = tensor.to(dtype=self.dtype)