diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md
index 1e5b6fd23..c364b9a6f 100644
--- a/docs/source/basic_tutorials/launcher.md
+++ b/docs/source/basic_tutorials/launcher.md
@@ -85,6 +85,23 @@ Options:
           [env: DTYPE=]
           [possible values: float16, bfloat16]

+```
+## KV_CACHE_DTYPE
+```shell
+      --kv-cache-dtype <KV_CACHE_DTYPE>
+          Data type for kv cache storage. If "auto", will use model data type. FP8_E5M2 (without scaling) is only supported on CUDA versions greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria
+
+          [env: KV_CACHE_DTYPE=]
+          [default: auto]
+
+```
+## QUANTIZATION_PARAM_PATH
+```shell
+      --quantization-param-path <QUANTIZATION_PARAM_PATH>
+          Path to the JSON file containing the KV cache scaling factors. This should generally be supplied when the KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. FP8_E5M2 (without scaling) is only supported on CUDA versions greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria
+
+          [env: QUANTIZATION_PARAM_PATH=]
+
 ```
 ## TRUST_REMOTE_CODE
 ```shell
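For context, a minimal sketch of how the two documented flags might be used together; the model id and JSON path below are purely illustrative, not taken from this patch:

```shell
# Hypothetical invocation: FP8 KV cache with a scaling-factor file (ROCm).
text-generation-launcher \
    --model-id meta-llama/Llama-2-7b-chat-hf \
    --kv-cache-dtype fp8 \
    --quantization-param-path /data/kv_cache_scales.json
```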
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index f2f5a99b9..f80eced30 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -184,6 +184,23 @@ struct Args {
     #[clap(long, env, value_enum)]
     dtype: Option<Dtype>,

+    /// Data type for kv cache storage. If "auto", will use model
+    /// data type. FP8_E5M2 (without scaling) is only supported on CUDA
+    /// versions greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead
+    /// supported for common inference criteria.
+    #[clap(default_value = "auto", long, env)]
+    kv_cache_dtype: Option<String>,
+
+    /// Path to the JSON file containing the KV cache
+    /// scaling factors. This should generally be supplied when
+    /// the KV cache dtype is FP8. Otherwise, KV cache scaling factors
+    /// default to 1.0, which may cause accuracy issues.
+    /// FP8_E5M2 (without scaling) is only supported on CUDA versions
+    /// greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead
+    /// supported for common inference criteria.
+    #[clap(long, env)]
+    quantization_param_path: Option<String>,
+
     /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
     /// encouraged when loading a model with custom code to ensure no malicious code has been
     /// contributed in a newer revision.
@@ -434,6 +451,8 @@ fn shard_manager(
     quantize: Option<Quantization>,
     speculate: Option<usize>,
     dtype: Option<Dtype>,
+    kv_cache_dtype: Option<String>,
+    quantization_param_path: Option<String>,
     trust_remote_code: bool,
     uds_path: String,
     rank: usize,
@@ -503,6 +522,16 @@ fn shard_manager(
         shard_args.push(dtype.to_string())
     }

+    if let Some(kv_cache_dtype) = kv_cache_dtype {
+        shard_args.push("--kv-cache-dtype".to_string());
+        shard_args.push(kv_cache_dtype)
+    }
+
+    if let Some(quantization_param_path) = quantization_param_path {
+        shard_args.push("--quantization-param-path".to_string());
+        shard_args.push(quantization_param_path)
+    }
+
     // Model optional revision
     if let Some(revision) = revision {
         shard_args.push("--revision".to_string());
@@ -1000,6 +1029,8 @@ fn spawn_shards(
     let quantize = args.quantize;
     let speculate = args.speculate;
     let dtype = args.dtype;
+    let kv_cache_dtype = args.kv_cache_dtype.clone();
+    let quantization_param_path = args.quantization_param_path.clone();
     let trust_remote_code = args.trust_remote_code;
     let master_port = args.master_port;
     let disable_custom_kernels = args.disable_custom_kernels;
@@ -1017,6 +1048,8 @@ fn spawn_shards(
                 quantize,
                 speculate,
                 dtype,
+                kv_cache_dtype,
+                quantization_param_path,
                 trust_remote_code,
                 uds_path,
                 rank,
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index ad623ccc8..89bfacbc7 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -35,6 +35,8 @@ def serve(
     quantize: Optional[Quantization] = None,
     speculate: Optional[int] = None,
     dtype: Optional[Dtype] = None,
+    kv_cache_dtype: str = "auto",
+    quantization_param_path: Optional[str] = None,
     trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
@@ -94,6 +96,8 @@ def serve(
         quantize,
         speculate,
         dtype,
+        kv_cache_dtype,
+        quantization_param_path,
         trust_remote_code,
         uds_path,
     )
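With the plumbing above, each shard process would be spawned with a command along these lines; only `--kv-cache-dtype` and `--quantization-param-path` come from this patch, the rest of the invocation is an illustrative sketch:

```shell
# Hypothetical shard command assembled by shard_manager (model id and paths illustrative).
text-generation-server serve meta-llama/Llama-2-7b-chat-hf \
    --kv-cache-dtype fp8 \
    --quantization-param-path /data/kv_cache_scales.json \
    --uds-path /tmp/text-generation-server
```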
diff --git a/server/text_generation_server/layers/schema.py b/server/text_generation_server/layers/schema.py
new file mode 100644
index 000000000..ca7d81a3d
--- /dev/null
+++ b/server/text_generation_server/layers/schema.py
@@ -0,0 +1,90 @@
+"""
+This file contains the Pydantic schemas for various quantization-related
+parameters. When a relevant quantization technique is specified, these
+parameters are loaded in the form of a JSON alongside the model weights
+and augment the model with additional information needed for use of that
+technique. The format of this JSON should be specified by one or more
+schemas contained here.
+
+For example, when the KV cache is quantized to FP8-E4M3 (currently only
+possible on ROCm), the model can be optionally augmented with KV cache
+scaling factors.
+"""
+
+from typing import Dict, Optional
+
+from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
+
+
+class KVCacheQuantSchema(BaseModel):
+    dtype: str
+    # Each key is a TP rank. Each value is a dictionary mapping a TP rank's
+    # layer indices to their per-tensor KV cache scaling factor.
+    # TODO: Consider pulling this and its validation methods out into its
+    # own schema class (tricky as its members are variable)
+    scaling_factor: Dict[int, Dict[int, float]]
+
+    @model_validator(mode="after")
+    def check_is_fp8(self) -> "KVCacheQuantSchema":
+        assert self.dtype == "float8_e4m3fn", (
+            "Loaded scaling factors intended for KV cache dtype = "
+            f"{self.dtype} rather than float8_e4m3fn!"
+        )
+        return self
+
+    @model_validator(mode="after")
+    def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema":
+        context = info.context
+        if context:
+            tp_size = context["tp_size"]
+            num_hidden_layers = context["num_hidden_layers"]
+            assert len(self.scaling_factor) == tp_size, (
+                f"Loaded dictionary has TP size {len(self.scaling_factor)} "
+                f"but LLM engine is currently running with TP size {tp_size}."
+            )
+            for tp_rank, layer_maps in self.scaling_factor.items():
+                assert len(layer_maps) == num_hidden_layers, (
+                    f"KV cache scales map for TP rank {tp_rank} is malformed. "
+                    f"Expected {num_hidden_layers} layers, got "
+                    f"{len(layer_maps)}."
+                )
+            for i in range(tp_size):
+                assert (
+                    i in self.scaling_factor
+                ), f"KV cache scales map for TP rank {i} not found."
+        return self
+
+    @model_validator(mode="after")
+    def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema":
+        context = info.context
+        if context:
+            tp_rank = context["tp_rank"]
+            num_hidden_layers = context["num_hidden_layers"]
+            layer_scales_map = self.scaling_factor[tp_rank]
+            for i in range(num_hidden_layers):
+                assert i in layer_scales_map, (
+                    f"Could not find KV cache scales for layer {i} in "
+                    f"TP rank {tp_rank}."
+                )
+        return self
+
+
+class QuantParamSchema(BaseModel):
+    # TODO: Generalize and extend with more fields
+    # (e.g. weights/activations params) once functionality is enabled
+    model_config = ConfigDict(protected_namespaces=())
+    model_type: Optional[str]
+    kv_cache: KVCacheQuantSchema
+
+    @model_validator(mode="after")
+    def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema":
+        context = info.context
+        if context:
+            model_type = context.get("model_type", None)
+            if model_type is not None:
+                assert model_type == self.model_type, (
+                    f"Model type is {model_type} but loaded "
+                    f"scaling factors belonging to different "
+                    f"model type {self.model_type}!"
+                )
+        return self
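To make the JSON format implied by this schema concrete, here is a small sketch that validates a hand-written payload against `QuantParamSchema`; the scale values, model type, and layer count are invented for illustration:

```python
# Sketch only: a 2-layer, single-rank "llama" checkpoint with made-up scales.
from text_generation_server.layers.schema import QuantParamSchema

payload = {
    "model_type": "llama",
    "kv_cache": {
        "dtype": "float8_e4m3fn",
        # TP rank -> {layer index -> per-tensor KV cache scale}
        "scaling_factor": {0: {0: 0.0408, 1: 0.0526}},
    },
}

schema = QuantParamSchema.model_validate(
    payload,
    context={
        "model_type": "llama",
        "num_hidden_layers": 2,
        "tp_rank": 0,
        "tp_size": 1,
    },
)
print(schema.kv_cache.scaling_factor[0])  # {0: 0.0408, 1: 0.0526}
```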
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index b319ab5d1..9820acf58 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -260,6 +260,8 @@ def get_model(
     quantize: Optional[str],
     speculate: Optional[int],
     dtype: Optional[str],
+    kv_cache_dtype: Optional[str],
+    quantization_param_path: Optional[str],
     trust_remote_code: bool,
 ) -> Model:
     if dtype is None:
@@ -272,6 +274,9 @@ def get_model(
         dtype = torch.bfloat16
     else:
         raise RuntimeError(f"Unknown dtype {dtype}")
+
+    if kv_cache_dtype not in {"auto", "fp8"}:
+        raise RuntimeError(f"Unknown kv_cache_dtype {kv_cache_dtype}")

     if speculate is not None:
         set_speculate(speculate)
@@ -563,6 +568,8 @@ def get_model(
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
+            quantization_param_path=quantization_param_path,
             trust_remote_code=trust_remote_code,
         )
     elif sharded:
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 6e23aa2bd..6e8bc56fc 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -25,7 +25,6 @@ import torch.distributed

 from torch import nn
 from transformers.activations import ACT2FN
-from typing import Optional, List, Tuple

 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils import paged_attention, flash_attn
@@ -39,6 +38,8 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.utils.weights_utils import kv_cache_scales_loader
+from loguru import logger

 if SYSTEM == "rocm":
     try:
@@ -126,6 +127,16 @@ class FlashLlamaAttention(torch.nn.Module):
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)

+        # This will be overwritten by model initialization if we are using it.
+        # N.B. currently we only support per tensor scalar scaling factors
+        # & only applicable to ROCm (AMD GPU).
+        # The scaling factor convention we are assuming is
+        # quantized_value * scaling_factor ~= true_value
+        # which is consistent with the practice of setting
+        # scaling_factor = tensor_amax / FPtype_max
+        self.kv_scale = 1.0
+        self.kv_cache_dtype = "auto"
+
     def forward(
         self,
         hidden_states,
@@ -152,7 +163,13 @@ class FlashLlamaAttention(torch.nn.Module):
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

         paged_attention.reshape_and_cache(
-            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
+            kv[:, 0],
+            kv[:, 1],
+            kv_cache[0],
+            kv_cache[1],
+            slots,
+            self.kv_cache_dtype,
+            self.kv_scale,
         )

         # output tensor
@@ -182,6 +199,8 @@ class FlashLlamaAttention(torch.nn.Module):
                 block_tables,
                 input_lengths,
                 max_s,
+                self.kv_cache_dtype,
+                self.kv_scale,
             )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -387,6 +406,10 @@ class FlashLlamaForCausalLM(torch.nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()

+        process_group = weights.process_group
+        self.tp_rank = process_group.rank()
+        self.tp_world_size = process_group.size()
+
         self.embed_tokens = TensorParallelEmbedding(
             prefix=(
                 "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
@@ -404,6 +427,10 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             prefix=suffix if not prefix else f"{prefix}.suffix",
             weights=weights,
         )
+        self.config = config
+
+        for layer_idx in range(config.num_hidden_layers):
+            self.model.layers[layer_idx].self_attn.kv_cache_dtype = config.kv_cache_dtype

     def forward(
         self,
@@ -435,3 +462,33 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             hidden_states = hidden_states[lm_head_indices]
         logits, speculative_logits = self.lm_head(hidden_states)
         return logits, speculative_logits
+
+    # If this function is called, it should always initialize KV cache scale
+    # factors (or else raise an exception). Thus, handled exceptions should
+    # make sure to leave KV cache scale factors in a known good (dummy) state
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            self.tp_rank,
+            self.tp_world_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            layer_self_attn = self.model.layers[layer_idx].self_attn
+
+            if SYSTEM == "rocm":
+                # The scaling factor convention we are assuming is
+                # quantized_value * scaling_factor ~= true_value
+                # which is consistent with the practice of setting
+                # scaling_factor = tensor_amax / FPtype_max
+                scaling_factor *= 2
+
+            if hasattr(layer_self_attn, "kv_scale"):
+                layer_self_attn.kv_scale = scaling_factor
+                logger.info(
+                    f"Loaded KV cache scaling factor for layer {layer_idx}: {scaling_factor}"
+                )
+            else:
+                raise RuntimeError(
+                    "Self attention has no KV cache scaling factor attribute!"
+                )
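The comments in this module describe the per-tensor convention `quantized_value * scaling_factor ~= true_value`, with `scaling_factor = tensor_amax / FPtype_max`; the extra factor of two applied on ROCm appears to account for the narrower FP8 format used there (E4M3FNUZ rather than E4M3FN). A standalone sketch of that convention, not code from this patch and assuming a torch build that exposes `float8_e4m3fn`:

```python
# Standalone sketch of the per-tensor FP8 convention described above;
# not part of this patch, function names are illustrative.
import torch

def compute_kv_scale(tensor: torch.Tensor) -> float:
    # scaling_factor = tensor_amax / FPtype_max
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    return (tensor.abs().max() / fp8_max).item()

def fake_quantize(tensor: torch.Tensor, scale: float) -> torch.Tensor:
    # quantized_value * scaling_factor ~= true_value
    quantized = (tensor / scale).to(torch.float8_e4m3fn)
    return quantized.to(tensor.dtype) * scale

x = torch.randn(16, 128, dtype=torch.float16)
scale = compute_kv_scale(x)
print(scale, (x - fake_quantize(x, scale)).abs().max())
```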
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 45ddd8569..fff2821b3 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -692,6 +692,8 @@ class FlashCausalLM(Model):
         head_size: int,
         dtype: torch.dtype,
         device: torch.device,
+        kv_cache_dtype: str = "auto",
+        quantization_param_path: Optional[str] = None,
         rank: int = 0,
         world_size: int = 1,
         sliding_window: Optional[int] = None,
@@ -713,6 +715,37 @@ class FlashCausalLM(Model):
             sliding_window=sliding_window,
         )

+        if kv_cache_dtype == "fp8":
+            self.kv_cache_dtype = torch.uint8
+        else:
+            self.kv_cache_dtype = self.dtype
+
+        if kv_cache_dtype == "fp8" and SYSTEM == "rocm":
+            logger.info(f"Using KV cache data type: {kv_cache_dtype}")
+            # Currently scaled KV cache is only enabled on ROCm
+            if quantization_param_path is not None:
+                if callable(getattr(self.model, "load_kv_cache_scales", None)):
+                    self.model.load_kv_cache_scales(quantization_param_path)
+                else:
+                    raise RuntimeError(
+                        "Using FP8 KV cache and scaling "
+                        "factors provided but model "
+                        f"{self.model.__class__} does not "
+                        "support loading scaling factors."
+                    )
+            else:
+                logger.info(
+                    "Using FP8 KV cache but no scaling factors "
+                    "provided. Defaulting to scaling factors of 1.0. "
+                    "This may lead to less accurate results!"
+                )
+        elif quantization_param_path is not None:
+            logger.info(
+                "KV cache scaling factors provided, "
+                "but the KV cache data type is not FP8. "
+                "KV cache scaling factors will not be used."
+            )
+
     @property
     def batch_type(self) -> Type[FlashCausalLMBatch]:
         return FlashCausalLMBatch
@@ -782,7 +815,7 @@ class FlashCausalLM(Model):
                 self.num_kv_heads,
                 self.head_size,
                 self.sliding_window is not None,
-                self.dtype,
+                self.kv_cache_dtype,
                 self.device,
             )
             max_bt = batch.max_blocks
@@ -801,7 +834,7 @@ class FlashCausalLM(Model):

         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
-        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
+        dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size

@@ -823,7 +856,7 @@ class FlashCausalLM(Model):
             self.num_kv_heads,
             self.head_size,
             self.sliding_window is not None,
-            self.dtype,
+            self.kv_cache_dtype,
             self.device,
         )

@@ -878,7 +911,7 @@ class FlashCausalLM(Model):
                     if self.speculate is None or self.speculate + 1 <= bs:
                         self.cuda_graph_warmup(bs, max_s, max_bt)
             except torch.cuda.OutOfMemoryError:
-                logger.exception(f"Decode cuda graph warmup failed")
+                logger.exception("Decode cuda graph warmup failed")
         else:
             logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
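Since `dtype_size` above is derived from the KV cache dtype, storing the cache as FP8 (uint8) halves the per-block footprint and roughly doubles the number of allocatable blocks. A back-of-the-envelope sketch with invented Llama-7B-like dimensions:

```python
# Back-of-the-envelope sketch of the block math above; the model dimensions
# and free-memory figure are illustrative, not measured.
import torch

BLOCK_SIZE = 16          # tokens per paged-KV-cache block
num_kv_heads = 32        # illustrative Llama-2-7B-like values
head_size = 128
num_layers = 32
free_memory = 20 * 1024**3  # pretend 20 GiB remain after loading weights

for kv_cache_dtype in (torch.float16, torch.uint8):  # fp16 vs fp8-as-uint8
    dtype_size = torch.tensor([], dtype=kv_cache_dtype).element_size()
    cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
    total_cache_size = num_layers * cache_block_size * 2 * dtype_size
    num_blocks = int(free_memory // total_cache_size)
    print(kv_cache_dtype, dtype_size, num_blocks)
```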
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 796fbd475..10d70a834 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -29,6 +29,8 @@ class FlashLlama(FlashCausalLM):
         quantize: Optional[str] = None,
         speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
+        kv_cache_dtype: Optional[str] = "auto",
+        quantization_param_path: Optional[str] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
@@ -72,6 +74,7 @@ class FlashLlama(FlashCausalLM):
         )
         config.quantize = quantize
         config.speculator = speculator
+        config.kv_cache_dtype = kv_cache_dtype

         torch.distributed.barrier(group=self.process_group)

@@ -91,6 +94,8 @@ class FlashLlama(FlashCausalLM):
             head_size=model.model.head_size,
             dtype=dtype,
             device=device,
+            kv_cache_dtype=kv_cache_dtype,
+            quantization_param_path=quantization_param_path,
             rank=rank,
             world_size=world_size,
         )
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index e549b7cbe..2d40746a0 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -193,6 +193,8 @@ def serve(
     quantize: Optional[str],
     speculate: Optional[int],
     dtype: Optional[str],
+    kv_cache_dtype: Optional[str],
+    quantization_param_path: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
 ):
@@ -203,6 +205,8 @@ def serve(
         quantize: Optional[str] = None,
         speculate: Optional[int] = None,
         dtype: Optional[str] = None,
+        kv_cache_dtype: Optional[str] = "auto",
+        quantization_param_path: Optional[str] = None,
         trust_remote_code: bool = False,
     ):
         unix_socket_template = "unix://{}-{}"
@@ -224,6 +228,8 @@ def serve(
                 quantize,
                 speculate,
                 dtype,
+                kv_cache_dtype,
+                quantization_param_path,
                 trust_remote_code,
             )
         except Exception:
@@ -256,6 +262,6 @@ def serve(
     set_model_id(model_id)
     asyncio.run(
         serve_inner(
-            model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
+            model_id, revision, sharded, quantize, speculate, dtype, kv_cache_dtype, quantization_param_path, trust_remote_code
         )
     )
diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py
index 6cc30e6d5..af1ff016e 100644
--- a/server/text_generation_server/utils/paged_attention.py
+++ b/server/text_generation_server/utils/paged_attention.py
@@ -1,3 +1,4 @@
+from loguru import logger
 import torch

 from text_generation_server.utils.import_utils import SYSTEM
@@ -21,6 +22,8 @@ def reshape_and_cache(
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     slots: torch.Tensor,
+    kv_cache_dtype: str = "auto",
+    kv_scale: float = 1.0,
 ):
     if SYSTEM == "xpu":
         ipex.llm.modules.PagedAttention.reshape_and_cache(
@@ -28,7 +31,7 @@ def reshape_and_cache(
         )
     else:
         cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0
+            key, value, key_cache, value_cache, slots, kv_cache_dtype, kv_scale
         )


@@ -42,6 +45,8 @@ def attention(
     block_tables: torch.Tensor,
     input_lengths: torch.Tensor,
     max_s: int,
+    kv_cache_dtype: str = "auto",
+    kv_scale: float = 1.0,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
     # Copyright 2023 The vLLM team. All rights
@@ -99,8 +104,8 @@ attention(
             block_size,
             max_s,
             None,
-            "auto",
-            1.0,
+            kv_cache_dtype,
+            kv_scale,
         )
     else:
         # Run PagedAttention V2.
@@ -132,6 +137,6 @@ attention(
             block_size,
             max_s,
             None,
-            "auto",
-            1.0,
+            kv_cache_dtype,
+            kv_scale,
         )
diff --git a/server/text_generation_server/utils/weights_utils.py b/server/text_generation_server/utils/weights_utils.py
new file mode 100644
index 000000000..f96d1a6d0
--- /dev/null
+++ b/server/text_generation_server/utils/weights_utils.py
@@ -0,0 +1,48 @@
+from typing import Optional, Tuple, Iterable
+from loguru import logger
+import json
+from text_generation_server.layers.schema import QuantParamSchema
+
+
+def kv_cache_scales_loader(
+    filename: str,
+    tp_rank: int,
+    tp_size: int,
+    num_hidden_layers: int,
+    model_type: Optional[str],
+) -> Iterable[Tuple[int, float]]:
+    """
+    A simple utility to read in KV cache scaling factors that have been
+    previously serialized to disk. Used by the model to populate the appropriate
+    KV cache scaling factors. The serialization should represent a dictionary
+    whose keys are the TP ranks and values are another dictionary mapping layers
+    to their KV cache scaling factors.
+    """
+    try:
+        with open(filename) as f:
+            context = {
+                "model_type": model_type,
+                "num_hidden_layers": num_hidden_layers,
+                "tp_rank": tp_rank,
+                "tp_size": tp_size,
+            }
+            schema_dct = json.load(f)
+            schema = QuantParamSchema.model_validate(schema_dct, context=context)
+            layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
+            return layer_scales_map.items()
+
+    except FileNotFoundError:
+        logger.error(f"File or directory '{filename}' not found.")
+    except json.JSONDecodeError:
+        logger.error(f"Error decoding JSON in file '{filename}'.")
+    except Exception as e:
+        logger.error(f"An error occurred while reading '{filename}': {e}")
+    # This section is reached if and only if any of the excepts are hit
+    # Return an empty iterable (list) => no KV cache scales are loaded
+    # which ultimately defaults to 1.0 scales
+    logger.warning(
+        "Defaulting to KV cache scaling factors = 1.0 "
+        f"for all layers in TP rank {tp_rank} "
+        "as an error occurred during loading."
+    )
+    return []
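Putting it together, the loader can also be exercised on its own; the path and dimensions below are illustrative, and on any failure it returns an empty list so every scale falls back to 1.0:

```python
# Illustrative use of the new loader; the file path and shapes are made up.
from text_generation_server.utils.weights_utils import kv_cache_scales_loader

for layer_idx, scale in kv_cache_scales_loader(
    "/data/kv_cache_scales.json",  # hypothetical path
    tp_rank=0,
    tp_size=1,
    num_hidden_layers=32,
    model_type="llama",
):
    print(f"layer {layer_idx}: kv_scale={scale}")
```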