diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md
index 7595665d..db088a39 100644
--- a/docs/source/basic_tutorials/launcher.md
+++ b/docs/source/basic_tutorials/launcher.md
@@ -91,18 +91,8 @@ Options:
 ## KV_CACHE_DTYPE
 ```shell
       --kv-cache-dtype
-          Data type for kv cache storage. If "auto", will use model data type. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria
-          [env: KV_CACHE_DTYPE=]
-          [default: auto]
-
-```
-## QUANTIZATION_PARAM_PATH
-```shell
-      --quantization-param-path
-          Path to the JSON file containing the KV cache scaling factors. This should generally be supplied, when KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria
-
-          [env: QUANTIZATION_PARAM_PATH=]
+          Specify the data type for the KV cache. By default, it uses the model's data type. CUDA 11.8+ supports `fp8` (`fp8_e4m3`) and `fp8_e5m2`, while ROCm (AMD GPU) supports `fp8` (`fp8_e4m3fn`). If `fp8` is chosen, a model checkpoint with scales for the KV cache should be provided; otherwise the scaling factors default to 1.0, which may impact accuracy
+
+          [env: KV_CACHE_DTYPE=]
+          [possible values: fp8, fp8_e5m2]
 
 ```
 ## TRUST_REMOTE_CODE
@@ -244,7 +234,7 @@ Options:
       --hostname
           The IP address to listen on
 
-          [env: HOSTNAME=]
+          [env: HOSTNAME=hf-amd-mi250-dev]
           [default: 0.0.0.0]
 
 ```
@@ -253,7 +243,7 @@ Options:
   -p, --port
           The port to listen on
 
-          [env: PORT=]
+          [env: PORT=80]
          [default: 3000]
 
 ```
@@ -289,7 +279,7 @@ Options:
       --huggingface-hub-cache
          The location of the huggingface hub cache. Used to override the location if you want to provide a mounted disk for instance
 
-          [env: HUGGINGFACE_HUB_CACHE=]
+          [env: HUGGINGFACE_HUB_CACHE=/data]
 
 ```
 ## WEIGHTS_CACHE_OVERRIDE
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 0d8ae6c8..6b2a3269 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -144,6 +144,28 @@ impl std::fmt::Display for Dtype {
     }
 }
 
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum KvDtype {
+    #[clap(name = "fp8")]
+    Fp8,
+    #[clap(name = "fp8_e5m2")]
+    Fp8e5m2,
+}
+
+impl std::fmt::Display for KvDtype {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // Keep these strings in sync with the values expected by `server`.
+        match self {
+            KvDtype::Fp8 => {
+                write!(f, "fp8")
+            }
+            KvDtype::Fp8e5m2 => {
+                write!(f, "fp8_e5m2")
+            }
+        }
+    }
+}
+
 #[derive(Clone, Copy, Debug, ValueEnum)]
 enum RopeScaling {
     Linear,
@@ -214,22 +236,12 @@ struct Args {
     #[clap(long, env, value_enum)]
     dtype: Option<Dtype>,
 
-    /// Data type for kv cache storage. If "auto", will use model
-    /// data type. FP8_E5M2 (without scaling) is only supported on cuda
-    /// version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead
-    /// supported for common inference criteria.
-    #[clap(default_value = "auto", long, env)]
-    kv_cache_dtype: Option<String>,
-
-    /// Path to the JSON file containing the KV cache
-    /// scaling factors. This should generally be supplied, when
-    /// KV cache dtype is FP8. Otherwise, KV cache scaling factors
-    /// default to 1.0, which may cause accuracy issues.
-    /// FP8_E5M2 (without scaling) is only supported on cuda version
-    /// greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead
-    /// supported for common inference criteria.
-    #[clap(long, env)]
-    quantization_param_path: Option<String>,
+    /// Specify the data type for the KV cache. By default, it uses the model's data type.
+    /// CUDA 11.8+ supports `fp8` (`fp8_e4m3`) and `fp8_e5m2`, while ROCm (AMD GPU) supports `fp8` (`fp8_e4m3fn`).
+    /// If `fp8` is chosen, a model checkpoint with scales for the KV cache should be provided;
+    /// otherwise the KV cache scaling factors default to 1.0, which may impact accuracy.
+    #[clap(long, env, value_enum)]
+    kv_cache_dtype: Option<KvDtype>,
 
     /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
     /// encouraged when loading a model with custom code to ensure no malicious code has been
@@ -481,8 +493,7 @@ fn shard_manager(
     quantize: Option<Quantize>,
     speculate: Option<usize>,
     dtype: Option<Dtype>,
-    kv_cache_dtype: Option<String>,
-    quantization_param_path: Option<String>,
+    kv_cache_dtype: Option<KvDtype>,
     trust_remote_code: bool,
     uds_path: String,
     rank: usize,
@@ -556,12 +567,7 @@ fn shard_manager(
 
     if let Some(kv_cache_dtype) = kv_cache_dtype {
         shard_args.push("--kv-cache-dtype".to_string());
-        shard_args.push(kv_cache_dtype)
-    }
-
-    if let Some(quantization_param_path) = quantization_param_path {
-        shard_args.push("--quantization-param-path".to_string());
-        shard_args.push(quantization_param_path)
+        shard_args.push(kv_cache_dtype.to_string());
     }
 
     // Model optional revision
@@ -1067,8 +1073,7 @@ fn spawn_shards(
     let quantize = args.quantize;
     let speculate = args.speculate;
     let dtype = args.dtype;
-    let kv_cache_dtype = args.kv_cache_dtype.clone();
-    let quantization_param_path = args.quantization_param_path.clone();
+    let kv_cache_dtype = args.kv_cache_dtype;
     let trust_remote_code = args.trust_remote_code;
     let master_port = args.master_port;
     let disable_custom_kernels = args.disable_custom_kernels;
@@ -1087,7 +1092,6 @@ fn spawn_shards(
                 speculate,
                 dtype,
                 kv_cache_dtype,
-                quantization_param_path,
                 trust_remote_code,
                 uds_path,
                 rank,
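The new flag is wired straight through to each shard: the launcher serializes the selected `KvDtype` via its `Display` impl and appends `--kv-cache-dtype <value>` to the shard arguments. A minimal launch sketch follows (Python is used for illustration only; the model id and port are placeholder assumptions, not part of this change):

```python
import subprocess

# Hypothetical invocation of the launcher with an FP8 (e4m3) KV cache.
# "meta-llama/Meta-Llama-3-8B-Instruct" is only an example model id; any model whose
# model_type is in FP8_KVCACHE_SUPPORTED_MODELS (llama, baichuan, phi3) is expected to work.
subprocess.run(
    [
        "text-generation-launcher",
        "--model-id", "meta-llama/Meta-Llama-3-8B-Instruct",
        "--kv-cache-dtype", "fp8",  # or "fp8_e5m2" on CUDA >= 11.8
        "--port", "3000",
    ],
    check=True,
)
```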
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index cd03a595..14d2df01 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -7,7 +7,7 @@ from loguru import logger
 from typing import Optional
 from enum import Enum
 from huggingface_hub import hf_hub_download
-
+from text_generation_server.utils.import_utils import SYSTEM
 
 app = typer.Typer()
 
@@ -38,7 +38,6 @@ def serve(
     speculate: Optional[int] = None,
     dtype: Optional[Dtype] = None,
     kv_cache_dtype: str = "auto",
-    quantization_param_path: Optional[str] = None,
     trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
@@ -92,6 +91,13 @@ def serve(
         raise RuntimeError(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
         )
+
+    if kv_cache_dtype in {"fp8", "fp8_e5m2"} and SYSTEM not in {"cuda", "rocm"}:
+        raise RuntimeError(f"{kv_cache_dtype} KV cache is only supported on Nvidia and AMD GPUs.")
+
+    if kv_cache_dtype == "fp8_e5m2" and SYSTEM != "cuda":
+        raise RuntimeError("fp8_e5m2 KV cache is only supported on Nvidia GPUs.")
+
     server.serve(
         model_id,
         revision,
@@ -100,7 +106,6 @@ def serve(
         speculate,
         dtype,
         kv_cache_dtype,
-        quantization_param_path,
         trust_remote_code,
         uds_path,
         max_input_tokens,
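The two platform checks added above are easy to read in isolation: `fp8` (e4m3) is accepted on both CUDA and ROCm, while `fp8_e5m2` is CUDA-only. A standalone sketch of the same gating, with the platform passed in explicitly instead of read from `import_utils.SYSTEM`:

```python
def validate_kv_cache_dtype(kv_cache_dtype: str, system: str) -> None:
    # Mirrors the checks added to `cli.serve`: reject FP8 variants on unsupported platforms.
    if kv_cache_dtype in {"fp8", "fp8_e5m2"} and system not in {"cuda", "rocm"}:
        raise RuntimeError(f"{kv_cache_dtype} KV cache is only supported on Nvidia and AMD GPUs.")
    if kv_cache_dtype == "fp8_e5m2" and system != "cuda":
        raise RuntimeError("fp8_e5m2 KV cache is only supported on Nvidia GPUs.")


validate_kv_cache_dtype("fp8", "rocm")   # OK: e4m3 is supported on ROCm
validate_kv_cache_dtype("auto", "xpu")   # OK: no FP8 requested
try:
    validate_kv_cache_dtype("fp8_e5m2", "rocm")
except RuntimeError as err:
    print(err)                           # e5m2 is rejected outside CUDA
```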
diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py
index 583337bd..242b6fa1 100644
--- a/server/text_generation_server/layers/attention/cuda.py
+++ b/server/text_generation_server/layers/attention/cuda.py
@@ -20,8 +20,10 @@ def reshape_and_cache(
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     slots: torch.Tensor,
+    kv_cache_dtype: str = "auto",
+    kv_scale: float = 1.0,
 ):
-    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
+    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, kv_cache_dtype, kv_scale)
 
 
 def paged_attention(
@@ -34,6 +36,8 @@ def paged_attention(
     block_tables: torch.Tensor,
     input_lengths: torch.Tensor,
     max_s: int,
+    kv_cache_dtype: str = "auto",
+    kv_scale: float = 1.0,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
     # Copyright 2023 The vLLM team. All rights
@@ -78,8 +82,8 @@ def paged_attention(
             block_size,
             max_s,
             None,
-            "auto",
-            1.0,
+            kv_cache_dtype,
+            kv_scale,
         )
     else:
         # Run PagedAttention V2.
@@ -111,8 +115,8 @@ def paged_attention(
             block_size,
             max_s,
             None,
-            "auto",
-            1.0,
+            kv_cache_dtype,
+            kv_scale,
         )
diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py
index 91ed5818..9cf3c4f8 100644
--- a/server/text_generation_server/layers/attention/rocm.py
+++ b/server/text_generation_server/layers/attention/rocm.py
@@ -25,8 +25,10 @@ def reshape_and_cache(
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     slots: torch.Tensor,
+    kv_cache_dtype: str = "auto",
+    kv_scale: float = 1.0,
 ):
-    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
+    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, kv_cache_dtype, kv_scale)
 
 
 def paged_attention(
@@ -39,6 +41,8 @@ def paged_attention(
     block_tables: torch.Tensor,
     input_lengths: torch.Tensor,
     max_s: int,
+    kv_cache_dtype: str = "auto",
+    kv_scale: float = 1.0,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
     # Copyright 2023 The vLLM team. All rights
@@ -83,8 +87,8 @@ def paged_attention(
             block_size,
             max_s,
             None,
-            "auto",
-            1.0,
+            kv_cache_dtype,
+            kv_scale,
         )
     else:
         # Run PagedAttention V2.
@@ -116,8 +120,8 @@ def paged_attention(
             block_size,
             max_s,
             None,
-            "auto",
-            1.0,
+            kv_cache_dtype,
+            kv_scale,
         )
diff --git a/server/text_generation_server/layers/attention/xpu.py b/server/text_generation_server/layers/attention/xpu.py
index 8b6cb87b..628e5789 100644
--- a/server/text_generation_server/layers/attention/xpu.py
+++ b/server/text_generation_server/layers/attention/xpu.py
@@ -39,6 +39,8 @@ def reshape_and_cache(
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     slots: torch.Tensor,
+    kv_cache_dtype: str = "auto",
+    kv_scale: float = 1.0,
 ):
     ipex.llm.modules.PagedAttention.reshape_and_cache(
         key, value, key_cache, value_cache, slots
     )
@@ -55,6 +57,8 @@ def paged_attention(
     block_tables: torch.Tensor,
     input_lengths: torch.Tensor,
     max_s: int,
+    kv_cache_dtype: str = "auto",
+    kv_scale: float = 1.0,
 ):
     query = query.contiguous()
     block_size = value_cache.shape[3]
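Across the CUDA and ROCm backends, the new `kv_cache_dtype`/`kv_scale` arguments are forwarded verbatim to the vLLM cache and attention ops; the scale follows the usual per-tensor convention `quantized_value * scaling_factor ~= true_value`, i.e. `scaling_factor = tensor_amax / fp8_max`. A small, self-contained round-trip sketch of that convention (plain PyTorch, independent of the kernels above; `torch.float8_e4m3fn` assumes a recent PyTorch release):

```python
import torch

FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn


def quantize_kv(kv: torch.Tensor) -> tuple[torch.Tensor, float]:
    # Per-tensor scale chosen so the largest magnitude maps onto the FP8 range.
    kv_scale = kv.abs().max().item() / FP8_E4M3_MAX
    kv_fp8 = (kv / kv_scale).to(torch.float8_e4m3fn)
    return kv_fp8, kv_scale


def dequantize_kv(kv_fp8: torch.Tensor, kv_scale: float) -> torch.Tensor:
    # quantized_value * scaling_factor ~= true_value
    return kv_fp8.to(torch.float32) * kv_scale


kv = torch.randn(16, 8, 128)  # (tokens, kv heads, head_dim), toy shapes
kv_fp8, kv_scale = quantize_kv(kv)
error = (dequantize_kv(kv_fp8, kv_scale) - kv).abs().max().item()
print(f"kv_scale={kv_scale:.4f}, max abs error={error:.4f}")
```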
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 1ae9aed0..4a8813a7 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -114,7 +114,6 @@ except ImportError as e:
 if MAMBA_AVAILABLE:
     __all__.append(Mamba)
 
-
 class ModelType(enum.Enum):
     IDEFICS2 = {
         "type": "idefics2",
@@ -245,6 +244,11 @@ class ModelType(enum.Enum):
         "multimodal": True,
     }
 
+FP8_KVCACHE_SUPPORTED_MODELS = {
+    "llama",
+    "baichuan",
+    "phi3",
+}
 
 __GLOBALS = locals()
 for data in ModelType:
@@ -259,7 +263,6 @@ def get_model(
     speculate: Optional[int],
     dtype: Optional[str],
     kv_cache_dtype: Optional[str],
-    quantization_param_path: Optional[str],
     trust_remote_code: bool,
     max_input_tokens: int,
 ) -> Model:
@@ -278,9 +281,6 @@ def get_model(
             dtype = torch.bfloat16
         else:
             raise RuntimeError(f"Unknown dtype {dtype}")
-
-    if kv_cache_dtype not in {"auto", "fp8"}:
-        raise RuntimeError(f"Unknown kv_cache_dtype {kv_cache_dtype}")
 
     if speculate is not None:
         set_speculate(speculate)
@@ -292,6 +292,11 @@ def get_model(
     )
     model_type = config_dict.get("model_type", None)
 
+    if model_type not in FP8_KVCACHE_SUPPORTED_MODELS and kv_cache_dtype != "auto":
+        raise RuntimeError(
+            f"FP8 KV cache (kv_cache_dtype={kv_cache_dtype}) is only supported for Llama, Baichuan and Phi3 models, got model_type: {model_type}"
+        )
+
     speculator = None
     if "medusa_num_heads" in config_dict:
         medusa_model_id = model_id
@@ -600,7 +605,6 @@ def get_model(
             speculator=speculator,
             dtype=dtype,
             kv_cache_dtype=kv_cache_dtype,
-            quantization_param_path=quantization_param_path,
             trust_remote_code=trust_remote_code,
         )
     elif sharded:
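The model-type gate above runs before any weights are loaded and only inspects `config_dict["model_type"]`. A reduced sketch of the same check, using a hand-written config dict in place of the one fetched from the Hub:

```python
FP8_KVCACHE_SUPPORTED_MODELS = {"llama", "baichuan", "phi3"}


def check_fp8_support(config_dict: dict, kv_cache_dtype: str) -> None:
    # Same shape as the guard added to get_model: anything other than "auto"
    # requires a model type with FP8 KV cache support.
    model_type = config_dict.get("model_type", None)
    if model_type not in FP8_KVCACHE_SUPPORTED_MODELS and kv_cache_dtype != "auto":
        raise RuntimeError(
            f"FP8 KV cache (kv_cache_dtype={kv_cache_dtype}) is only supported for "
            f"Llama, Baichuan and Phi3 models, got model_type: {model_type}"
        )


check_fp8_support({"model_type": "llama"}, "fp8")   # passes
check_fp8_support({"model_type": "gemma"}, "auto")  # passes: FP8 not requested
try:
    check_fp8_support({"model_type": "gemma"}, "fp8")
except RuntimeError as err:
    print(err)                                      # rejected before weights are loaded
```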
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index a8da1c1f..75ab4906 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -25,6 +25,7 @@ import torch.distributed
 
 from torch import nn
 from transformers.activations import ACT2FN
+from typing import Optional, List, Tuple
 
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
@@ -42,9 +43,8 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
-from text_generation_server.utils.weights_utils import kv_cache_scales_loader
-from loguru import logger
 
+from loguru import logger
 if SYSTEM == "rocm":
     try:
         from vllm import _custom_C
@@ -135,15 +135,13 @@ class FlashLlamaAttention(torch.nn.Module):
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)
 
-        # This will be overwritten by model initialization if we are using it.
-        # N.B. currently we only support per tensor scalar scaling factors
-        # & only applicable to ROCm (AMD GPU).
-        # The scaling factor convention we are assuming is
-        # quantized_value * scaling_factor ~= true_value
-        # which is consistent with the practice of setting
-        # scaling_factor = tensor_amax / FPtype_max
-        self.kv_scale = 1.0
-        self.kv_cache_dtype = "auto"
+        self.kv_cache_dtype = config.kv_cache_dtype
+
+        if self.kv_cache_dtype == "fp8":
+            self.kv_scale = weights.get_kv_cache_scaling_factor(prefix, self.kv_cache_dtype)
+        else:
+            self.kv_scale = 1.0
+        logger.info(f"kv_cache_dtype: {self.kv_cache_dtype}, kv_scale: {self.kv_scale}")
 
     def forward(
         self,
@@ -170,7 +168,7 @@ class FlashLlamaAttention(torch.nn.Module):
 
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
-        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
+        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots, self.kv_cache_dtype, self.kv_scale)
 
         # output tensor
         attn_output = torch.empty_like(query)
@@ -406,10 +404,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
 
-        process_group = weights.process_group
-        self.tp_rank = process_group.rank()
-        self.tp_world_size = process_group.size()
-
         self.embed_tokens = TensorParallelEmbedding(
             prefix=(
                 "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
@@ -427,10 +421,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             prefix=suffix if not prefix else f"{prefix}.{suffix}",
             weights=weights,
         )
-        self.config = config
-
-        for layer_idx in range(config.num_hidden_layers):
-            self.model.layers[layer_idx].self_attn.kv_cache_dtype = config.kv_cache_dtype
 
     def forward(
         self,
@@ -462,33 +452,3 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             hidden_states = hidden_states[lm_head_indices]
         logits, speculative_logits = self.lm_head(hidden_states)
         return logits, speculative_logits
-
-    # If this function is called, it should always initialize KV cache scale
-    # factors (or else raise an exception). Thus, handled exceptions should
-    # make sure to leave KV cache scale factors in a known good (dummy) state
-    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
-        for layer_idx, scaling_factor in kv_cache_scales_loader(
-            quantization_param_path,
-            self.tp_rank,
-            self.tp_world_size,
-            self.config.num_hidden_layers,
-            self.config.__class__.model_type,
-        ):
-            layer_self_attn = self.model.layers[layer_idx].self_attn
-
-            if SYSTEM == "rocm":
-                # The scaling factor convention we are assuming is
-                # quantized_value * scaling_factor ~= true_value
-                # which is consistent with the practice of setting
-                # scaling_factor = tensor_amax / FPtype_max
-                scaling_factor *= 2
-
-            if hasattr(layer_self_attn, "kv_scale"):
-                layer_self_attn.kv_scale = scaling_factor
-                logger.info(
-                    f"Loaded KV cache scaling factor for layer {layer_idx}: {scaling_factor}"
-                )
-            else:
-                raise RuntimeError(
-                    "Self attention has no KV cache scaling " "factor attribute!"
-                )
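With the old `load_kv_cache_scales` pass removed, each `FlashLlamaAttention` now resolves its own scale at construction time: the checkpoint is consulted only when the cache dtype is `fp8`, otherwise the scale stays at `1.0`. A reduced sketch of that initialization path, with the `Weights` object replaced by a hypothetical stub:

```python
class _StubWeights:
    """Hypothetical stand-in for `Weights`, returning a known per-layer scale."""

    def __init__(self, scales: dict):
        self.scales = scales

    def get_kv_cache_scaling_factor(self, prefix: str, kv_cache_dtype: str) -> float:
        return self.scales.get(f"{prefix}.kv_scale", 1.0)


def resolve_kv_scale(prefix: str, kv_cache_dtype: str, weights: _StubWeights) -> float:
    # Mirrors FlashLlamaAttention.__init__: only `fp8` (e4m3) consults the checkpoint.
    if kv_cache_dtype == "fp8":
        return weights.get_kv_cache_scaling_factor(prefix, kv_cache_dtype)
    return 1.0


weights = _StubWeights({"model.layers.0.self_attn.kv_scale": 0.023})
print(resolve_kv_scale("model.layers.0.self_attn", "fp8", weights))       # 0.023
print(resolve_kv_scale("model.layers.0.self_attn", "fp8_e5m2", weights))  # 1.0
```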
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index a85d85bd..5b1c051d 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -29,7 +29,6 @@ class FlashLlama(FlashCausalLM):
         speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         kv_cache_dtype: Optional[str] = "auto",
-        quantization_param_path: Optional[str] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
@@ -75,6 +74,7 @@ class FlashLlama(FlashCausalLM):
             prefix = ""
 
         model = FlashLlamaForCausalLM(prefix, config, weights)
+
         torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
             model=model,
@@ -82,10 +82,8 @@ class FlashLlama(FlashCausalLM):
             num_layers=len(model.model.layers),
             num_kv_heads=model.model.num_key_value_heads,
             head_size=model.model.head_size,
-            dtype=dtype,
+            dtype=torch.uint8 if "fp8" in kv_cache_dtype else dtype,
             device=device,
-            kv_cache_dtype=kv_cache_dtype,
-            quantization_param_path=quantization_param_path,
             rank=rank,
             world_size=world_size,
         )
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index a64e474b..aca96552 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -198,7 +198,6 @@ def serve(
     speculate: Optional[int],
     dtype: Optional[str],
     kv_cache_dtype: Optional[str],
-    quantization_param_path: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
     max_input_tokens: int,
@@ -211,7 +210,6 @@ def serve(
         speculate: Optional[int] = None,
         dtype: Optional[str] = None,
         kv_cache_dtype: Optional[str] = "auto",
-        quantization_param_path: Optional[str] = None,
         trust_remote_code: bool = False,
     ):
         unix_socket_template = "unix://{}-{}"
@@ -234,7 +232,6 @@ def serve(
             speculate,
             dtype,
             kv_cache_dtype,
-            quantization_param_path,
             trust_remote_code,
             max_input_tokens,
         )
@@ -272,6 +269,6 @@ def serve(
     set_model_id(model_id)
     asyncio.run(
         serve_inner(
-            model_id, revision, sharded, quantize, speculate, dtype, kv_cache_dtype, quantization_param_path, trust_remote_code
+            model_id, revision, sharded, quantize, speculate, dtype, kv_cache_dtype, trust_remote_code
         )
     )
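The `dtype=torch.uint8 if "fp8" in kv_cache_dtype else dtype` change means the paged KV cache is allocated as raw 8-bit storage whenever an FP8 variant is selected, roughly halving KV cache memory versus fp16/bf16. A back-of-the-envelope sketch of that saving (the block shape below is illustrative, not the exact layout `FlashCausalLM` uses):

```python
import torch


def kv_cache_storage_dtype(kv_cache_dtype: str, model_dtype: torch.dtype) -> torch.dtype:
    # Same rule as the FlashLlama constructor: FP8 variants store the cache as uint8.
    return torch.uint8 if "fp8" in kv_cache_dtype else model_dtype


num_blocks, block_size, num_kv_heads, head_size = 1024, 16, 8, 128
for cache_dtype in ("auto", "fp8"):
    storage = kv_cache_storage_dtype(cache_dtype, torch.float16)
    block = torch.empty(num_blocks, block_size, num_kv_heads, head_size, dtype=storage)
    mib = block.element_size() * block.nelement() / 2**20
    print(f"{cache_dtype}: stored as {storage}, {mib:.0f} MiB per K or V cache")
```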
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index e6142525..a12a5de5 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -3,6 +3,7 @@ from dataclasses import dataclass
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union
 from safetensors import safe_open, SafetensorError
+from text_generation_server.utils.import_utils import SYSTEM
 import torch
 from loguru import logger
 from huggingface_hub import hf_hub_download
@@ -88,7 +89,7 @@ class Weights:
         # Special case for gptq which shouldn't convert
         # u4 which are disguised as int32. Exl2 uses int16
         # as well.
-        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
+        if tensor.dtype not in [torch.int16, torch.int32, torch.int64] and not tensor_name.endswith("kv_scale"):
             tensor = tensor.to(dtype=self.dtype)
         if to_device:
             tensor = tensor.to(device=self.device)
@@ -762,6 +763,40 @@ class Weights:
         except Exception:
             pass
 
+    def get_kv_cache_scaling_factor(self, prefix: str, kv_cache_dtype: str):
+        try:
+            kv_scale = self.get_tensor(f"{prefix}.kv_scale").cpu().tolist()
+        except RuntimeError:
+            if kv_cache_dtype == "fp8":
+                log_once(
+                    logger.warning,
+                    "Could not find `kv_scale` in the checkpoint for the `fp8` (`fp8_e4m3`) KV cache. "
+                    "Using a scaling factor of `1.0`, which may cause accuracy issues. "
+                    "Please ensure the checkpoint includes the correct KV cache scaling factors.",
+                )
+
+            kv_scale = 1.0
+        else:
+            if kv_cache_dtype == "fp8_e5m2":
+                raise RuntimeError(
+                    "Found `kv_scale` in the checkpoint, but the `fp8_e5m2` KV cache dtype does not support scaling factors."
+                )
+
+            if not isinstance(kv_scale, float):
+                raise RuntimeError(
+                    "Only per-tensor scaling factors are supported for the `fp8` (`fp8_e4m3`) KV cache."
+                )
+
+            # ROCm uses the fp8_e4m3fn format, whereas Nvidia GPUs use fp8_e4m3.
+            # The multiplication by 2 compensates for the different numeric representation
+            # between ROCm and Nvidia GPUs, ensuring consistent effective scaling across platforms.
+            # After this adjustment, the overall effect is equivalent to the scaling applied without
+            # it on Nvidia GPUs.
+            if SYSTEM == "rocm":
+                kv_scale *= 2.0
+
+        return kv_scale
+
 
 def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
     """
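In summary, the new helper falls back to `1.0` (with a warning) when `fp8` is requested but the checkpoint has no `kv_scale`, rejects checkpoint scales for `fp8_e5m2`, and doubles the loaded scale on ROCm to account for the `fp8_e4m3fn` vs `fp8_e4m3` representation difference. A condensed, standalone sketch of those branches (checkpoint access and platform detection are stubbed out):

```python
def kv_cache_scaling_factor(checkpoint: dict, prefix: str, kv_cache_dtype: str, system: str) -> float:
    # Condensed version of Weights.get_kv_cache_scaling_factor, for illustration only.
    key = f"{prefix}.kv_scale"
    if key not in checkpoint:
        if kv_cache_dtype == "fp8":
            print("warning: no kv_scale in checkpoint, defaulting to 1.0 (may hurt accuracy)")
        return 1.0

    kv_scale = checkpoint[key]
    if kv_cache_dtype == "fp8_e5m2":
        raise RuntimeError("fp8_e5m2 KV cache does not use checkpoint scaling factors")
    if not isinstance(kv_scale, float):
        raise RuntimeError("only per-tensor (scalar) kv_scale values are supported")
    if system == "rocm":
        kv_scale *= 2.0  # fp8_e4m3fn (ROCm) vs fp8_e4m3 (CUDA) representation difference
    return kv_scale


ckpt = {"model.layers.0.self_attn.kv_scale": 0.023}
print(kv_cache_scaling_factor(ckpt, "model.layers.0.self_attn", "fp8", "cuda"))  # 0.023
print(kv_cache_scaling_factor(ckpt, "model.layers.0.self_attn", "fp8", "rocm"))  # 0.046
print(kv_cache_scaling_factor({}, "model.layers.0.self_attn", "fp8", "cuda"))    # 1.0 with warning
```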