Add KV cache FP8 support
commit 8c437a80bc (parent efb73fcb59)
@@ -85,6 +85,23 @@ Options:
           [env: DTYPE=]
           [possible values: float16, bfloat16]
 
 ```
+## KV_CACHE_DTYPE
+```shell
+      --kv-cache-dtype <KV_CACHE_DTYPE>
+          Data type for KV cache storage. If "auto", the model data type is used. FP8_E5M2 (without scaling) is only supported on CUDA versions greater than 11.8. On ROCm (AMD GPUs), FP8_E4M3 is supported instead for common inference criteria
+
+          [env: KV_CACHE_DTYPE=]
+          [default: auto]
+
+```
+## QUANTIZATION_PARAM_PATH
+```shell
+      --quantization-param-path <QUANTIZATION_PARAM_PATH>
+          Path to the JSON file containing the KV cache scaling factors. This should generally be supplied when the KV cache dtype is FP8; otherwise, the scaling factors default to 1.0, which may cause accuracy issues. FP8_E5M2 (without scaling) is only supported on CUDA versions greater than 11.8. On ROCm (AMD GPUs), FP8_E4M3 is supported instead for common inference criteria
+
+          [env: QUANTIZATION_PARAM_PATH=]
+
+```
 ## TRUST_REMOTE_CODE
 ```shell
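For reference, the file expected by `--quantization-param-path` follows the `QuantParamSchema` introduced later in this diff: an optional `model_type` plus a `kv_cache` block whose `scaling_factor` maps each TP rank to per-layer, per-tensor scales. A minimal sketch of producing such a file (not part of this diff; the path, layer count, and scale value are made up):

```python
# Hypothetical example: write a scaling-factors file for a single-rank,
# 32-layer model. The constant 0.0542 is a placeholder scale.
import json

scales = {
    "model_type": "llama",
    "kv_cache": {
        "dtype": "float8_e4m3fn",
        # TP rank -> {layer index -> per-tensor KV cache scaling factor}
        "scaling_factor": {0: {layer: 0.0542 for layer in range(32)}},
    },
}

with open("kv_cache_scales.json", "w") as f:
    json.dump(scales, f, indent=2)
```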
@@ -184,6 +184,23 @@ struct Args {
     #[clap(long, env, value_enum)]
     dtype: Option<Dtype>,
 
+    /// Data type for KV cache storage. If "auto", the model
+    /// data type is used. FP8_E5M2 (without scaling) is only supported on
+    /// CUDA versions greater than 11.8. On ROCm (AMD GPUs), FP8_E4M3 is
+    /// supported instead for common inference criteria.
+    #[clap(default_value = "auto", long, env)]
+    kv_cache_dtype: Option<String>,
+
+    /// Path to the JSON file containing the KV cache
+    /// scaling factors. This should generally be supplied when the
+    /// KV cache dtype is FP8; otherwise, the scaling factors
+    /// default to 1.0, which may cause accuracy issues.
+    /// FP8_E5M2 (without scaling) is only supported on CUDA versions
+    /// greater than 11.8. On ROCm (AMD GPUs), FP8_E4M3 is supported
+    /// instead for common inference criteria.
+    #[clap(long, env)]
+    quantization_param_path: Option<String>,
+
     /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
     /// encouraged when loading a model with custom code to ensure no malicious code has been
     /// contributed in a newer revision.
@@ -434,6 +451,8 @@ fn shard_manager(
     quantize: Option<Quantization>,
     speculate: Option<usize>,
     dtype: Option<Dtype>,
+    kv_cache_dtype: Option<String>,
+    quantization_param_path: Option<String>,
     trust_remote_code: bool,
     uds_path: String,
     rank: usize,
@@ -503,6 +522,16 @@ fn shard_manager(
         shard_args.push(dtype.to_string())
     }
 
+    if let Some(kv_cache_dtype) = kv_cache_dtype {
+        shard_args.push("--kv-cache-dtype".to_string());
+        shard_args.push(kv_cache_dtype)
+    }
+
+    if let Some(quantization_param_path) = quantization_param_path {
+        shard_args.push("--quantization-param-path".to_string());
+        shard_args.push(quantization_param_path)
+    }
+
     // Model optional revision
     if let Some(revision) = revision {
         shard_args.push("--revision".to_string());
@@ -1000,6 +1029,8 @@ fn spawn_shards(
     let quantize = args.quantize;
     let speculate = args.speculate;
     let dtype = args.dtype;
+    let kv_cache_dtype = args.kv_cache_dtype.clone();
+    let quantization_param_path = args.quantization_param_path.clone();
     let trust_remote_code = args.trust_remote_code;
     let master_port = args.master_port;
     let disable_custom_kernels = args.disable_custom_kernels;
@@ -1017,6 +1048,8 @@ fn spawn_shards(
             quantize,
             speculate,
             dtype,
+            kv_cache_dtype,
+            quantization_param_path,
             trust_remote_code,
             uds_path,
             rank,
@@ -35,6 +35,8 @@ def serve(
     quantize: Optional[Quantization] = None,
     speculate: Optional[int] = None,
     dtype: Optional[Dtype] = None,
+    kv_cache_dtype: str = "auto",
+    quantization_param_path: Optional[str] = None,
     trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
@@ -94,6 +96,8 @@ def serve(
         quantize,
         speculate,
         dtype,
+        kv_cache_dtype,
+        quantization_param_path,
         trust_remote_code,
         uds_path,
     )
new file server/text_generation_server/layers/schema.py (90 lines)
@@ -0,0 +1,90 @@
+"""
+This file contains the Pydantic schemas for various quantization-related
+parameters. When a relevant quantization technique is specified, these
+parameters are loaded in the form of a JSON alongside the model weights
+and augment the model with additional information needed for use of that
+technique. The format of this JSON should be specified by one or more
+schemas contained here.
+
+For example, when the KV cache is quantized to FP8-E4M3 (currently only
+possible on ROCm), the model can be optionally augmented with KV cache
+scaling factors.
+"""
+
+from typing import Dict, Optional
+
+from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
+
+
+class KVCacheQuantSchema(BaseModel):
+    dtype: str
+    # Each key is a TP rank. Each value is a dictionary mapping a TP rank's
+    # layer indices to their per-tensor KV cache scaling factor.
+    # TODO: Consider pulling this and its validation methods out into its
+    # own schema class (tricky as its members are variable)
+    scaling_factor: Dict[int, Dict[int, float]]
+
+    @model_validator(mode="after")
+    def check_is_fp8(self) -> "KVCacheQuantSchema":
+        assert self.dtype == "float8_e4m3fn", (
+            "Loaded scaling factors intended for KV cache dtype = "
+            f"{self.dtype} rather than float8_e4m3fn!"
+        )
+        return self
+
+    @model_validator(mode="after")
+    def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema":
+        context = info.context
+        if context:
+            tp_size = context["tp_size"]
+            num_hidden_layers = context["num_hidden_layers"]
+            assert len(self.scaling_factor) == tp_size, (
+                f"Loaded dictionary has TP size {len(self.scaling_factor)} "
+                f"but LLM engine is currently running with TP size {tp_size}."
+            )
+            for tp_rank, layer_maps in self.scaling_factor.items():
+                assert len(layer_maps) == num_hidden_layers, (
+                    f"KV cache scales map for TP rank {tp_rank} is malformed. "
+                    f"Expected {num_hidden_layers} layers, got "
+                    f"{len(layer_maps)}."
+                )
+            for i in range(tp_size):
+                assert (
+                    i in self.scaling_factor
+                ), f"KV cache scales map for TP rank {i} not found."
+        return self
+
+    @model_validator(mode="after")
+    def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema":
+        context = info.context
+        if context:
+            tp_rank = context["tp_rank"]
+            num_hidden_layers = context["num_hidden_layers"]
+            layer_scales_map = self.scaling_factor[tp_rank]
+            for i in range(num_hidden_layers):
+                assert i in layer_scales_map, (
+                    f"Could not find KV cache scales for layer {i} in "
+                    f"TP rank {tp_rank}."
+                )
+        return self
+
+
+class QuantParamSchema(BaseModel):
+    # TODO: Generalize and extend with more fields
+    # (e.g. weights/activations params) once functionality is enabled
+    model_config = ConfigDict(protected_namespaces=())
+    model_type: Optional[str]
+    kv_cache: KVCacheQuantSchema
+
+    @model_validator(mode="after")
+    def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema":
+        context = info.context
+        if context:
+            model_type = context.get("model_type", None)
+            if model_type is not None:
+                assert model_type == self.model_type, (
+                    f"Model type is {model_type} but loaded "
+                    "scaling factors belonging to different "
+                    f"model type {self.model_type}!"
+                )
+        return self
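To illustrate how the validators above fit together, here is a small sketch (not part of this diff) that validates a hand-written payload with the same context keys (`tp_size`, `tp_rank`, `num_hidden_layers`, `model_type`) that `kv_cache_scales_loader` passes further down; the two-layer, single-rank dimensions and scale values are made up:

```python
# Illustrative only: validate a toy scaling-factors payload against the
# new schema introduced above.
from text_generation_server.layers.schema import QuantParamSchema

payload = {
    "model_type": "llama",
    "kv_cache": {
        "dtype": "float8_e4m3fn",
        "scaling_factor": {0: {0: 0.05, 1: 0.07}},
    },
}

schema = QuantParamSchema.model_validate(
    payload,
    context={
        "model_type": "llama",
        "num_hidden_layers": 2,
        "tp_rank": 0,
        "tp_size": 1,
    },
)
print(schema.kv_cache.scaling_factor[0])  # {0: 0.05, 1: 0.07}
```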
@@ -260,6 +260,8 @@ def get_model(
     quantize: Optional[str],
     speculate: Optional[int],
     dtype: Optional[str],
+    kv_cache_dtype: Optional[str],
+    quantization_param_path: Optional[str],
     trust_remote_code: bool,
 ) -> Model:
     if dtype is None:
@@ -273,6 +275,9 @@ def get_model(
     else:
         raise RuntimeError(f"Unknown dtype {dtype}")
 
+    if kv_cache_dtype not in {"auto", "fp8"}:
+        raise RuntimeError(f"Unknown kv_cache_dtype {kv_cache_dtype}")
+
     if speculate is not None:
         set_speculate(speculate)
     else:
@@ -563,6 +568,8 @@ def get_model(
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
+            quantization_param_path=quantization_param_path,
             trust_remote_code=trust_remote_code,
         )
     elif sharded:
@@ -25,7 +25,6 @@ import torch.distributed
 
 from torch import nn
 from transformers.activations import ACT2FN
-from typing import Optional, List, Tuple
 
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils import paged_attention, flash_attn
@@ -39,6 +38,8 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.utils.weights_utils import kv_cache_scales_loader
+from loguru import logger
 
 if SYSTEM == "rocm":
     try:
@@ -126,6 +127,16 @@ class FlashLlamaAttention(torch.nn.Module):
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)
 
+        # This will be overwritten by model initialization if we are using it.
+        # N.B. currently we only support per-tensor scalar scaling factors
+        # & only applicable to ROCm (AMD GPU).
+        # The scaling factor convention we are assuming is
+        #   quantized_value * scaling_factor ~= true_value
+        # which is consistent with the practice of setting
+        #   scaling_factor = tensor_amax / FPtype_max
+        self.kv_scale = 1.0
+        self.kv_cache_dtype = "auto"
+
     def forward(
         self,
         hidden_states,
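The convention described in the comment above can be made concrete with a short sketch (not part of this diff; it assumes a PyTorch build with `torch.float8_e4m3fn` support, and the tensor shape and values are arbitrary):

```python
# Illustrative only: compute a per-tensor KV cache scale under the convention
#   quantized_value * scaling_factor ~= true_value
# i.e. scaling_factor = tensor_amax / FPtype_max.
import torch

kv = torch.randn(16, 8, 128, dtype=torch.float16)  # a fake K (or V) tensor
fp8_max = torch.finfo(torch.float8_e4m3fn).max     # 448.0 for E4M3

scaling_factor = kv.abs().max().float() / fp8_max
quantized = (kv.float() / scaling_factor).to(torch.float8_e4m3fn)

# Dequantizing recovers an approximation of the original values.
dequantized = quantized.float() * scaling_factor
print(scaling_factor.item(), (kv.float() - dequantized).abs().max().item())
```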
@@ -152,7 +163,13 @@ class FlashLlamaAttention(torch.nn.Module):
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
         paged_attention.reshape_and_cache(
-            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
+            kv[:, 0],
+            kv[:, 1],
+            kv_cache[0],
+            kv_cache[1],
+            slots,
+            self.kv_cache_dtype,
+            self.kv_scale,
         )
 
         # output tensor
@@ -182,6 +199,8 @@ class FlashLlamaAttention(torch.nn.Module):
                 block_tables,
                 input_lengths,
                 max_s,
+                self.kv_cache_dtype,
+                self.kv_scale,
             )
 
         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -387,6 +406,10 @@ class FlashLlamaForCausalLM(torch.nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
 
+        process_group = weights.process_group
+        self.tp_rank = process_group.rank()
+        self.tp_world_size = process_group.size()
+
         self.embed_tokens = TensorParallelEmbedding(
             prefix=(
                 "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
@@ -404,6 +427,10 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             prefix=suffix if not prefix else f"{prefix}.suffix",
             weights=weights,
         )
+        self.config = config
+
+        for layer_idx in range(config.num_hidden_layers):
+            self.model.layers[layer_idx].self_attn.kv_cache_dtype = config.kv_cache_dtype
 
     def forward(
         self,
@@ -435,3 +462,33 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             hidden_states = hidden_states[lm_head_indices]
         logits, speculative_logits = self.lm_head(hidden_states)
         return logits, speculative_logits
+
+    # If this function is called, it should always initialize KV cache scale
+    # factors (or else raise an exception). Thus, handled exceptions should
+    # make sure to leave KV cache scale factors in a known good (dummy) state
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            self.tp_rank,
+            self.tp_world_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            layer_self_attn = self.model.layers[layer_idx].self_attn
+
+            if SYSTEM == "rocm":
+                # The scaling factor convention we are assuming is
+                #   quantized_value * scaling_factor ~= true_value
+                # which is consistent with the practice of setting
+                #   scaling_factor = tensor_amax / FPtype_max
+                scaling_factor *= 2
+
+            if hasattr(layer_self_attn, "kv_scale"):
+                layer_self_attn.kv_scale = scaling_factor
+                logger.info(
+                    f"Loaded KV cache scaling factor for layer {layer_idx}: {scaling_factor}"
+                )
+            else:
+                raise RuntimeError(
+                    "Self attention has no KV cache scaling factor attribute!"
+                )
@@ -692,6 +692,8 @@ class FlashCausalLM(Model):
         head_size: int,
         dtype: torch.dtype,
         device: torch.device,
+        kv_cache_dtype: str = "auto",
+        quantization_param_path: Optional[str] = None,
         rank: int = 0,
         world_size: int = 1,
         sliding_window: Optional[int] = None,
@@ -713,6 +715,37 @@ class FlashCausalLM(Model):
             sliding_window=sliding_window,
         )
 
+        if kv_cache_dtype == "fp8":
+            self.kv_cache_dtype = torch.uint8
+        else:
+            self.kv_cache_dtype = self.dtype
+
+        if kv_cache_dtype == "fp8" and SYSTEM == "rocm":
+            logger.info(f"Using KV cache data type: {kv_cache_dtype}")
+            # Currently scaled KV cache is only enabled on ROCm
+            if quantization_param_path is not None:
+                if callable(getattr(self.model, "load_kv_cache_scales", None)):
+                    self.model.load_kv_cache_scales(quantization_param_path)
+                else:
+                    raise RuntimeError(
+                        "Using FP8 KV cache and scaling "
+                        "factors provided but model "
+                        f"{self.model.__class__} does not "
+                        "support loading scaling factors."
+                    )
+            else:
+                logger.info(
+                    "Using FP8 KV cache but no scaling factors "
+                    "provided. Defaulting to scaling factors of 1.0. "
+                    "This may lead to less accurate results!"
+                )
+        elif quantization_param_path is not None:
+            logger.info(
+                "KV cache scaling factors provided, "
+                "but the KV cache data type is not FP8. "
+                "KV cache scaling factors will not be used."
+            )
+
     @property
     def batch_type(self) -> Type[FlashCausalLMBatch]:
         return FlashCausalLMBatch
@@ -782,7 +815,7 @@ class FlashCausalLM(Model):
             self.num_kv_heads,
             self.head_size,
             self.sliding_window is not None,
-            self.dtype,
+            self.kv_cache_dtype,
             self.device,
         )
         max_bt = batch.max_blocks
@@ -801,7 +834,7 @@ class FlashCausalLM(Model):
 
         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory
-        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
+        dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
 
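Because the FP8 path stores the cache as `torch.uint8`, `dtype_size` drops from 2 bytes (float16) to 1, roughly halving the KV cache footprint in the block-count calculation above. A rough sketch of that arithmetic (not part of this diff; the block size and model dimensions are assumed, not taken from the PR):

```python
# Illustrative only: KV cache sizing per the formula in the hunk above.
import torch

BLOCK_SIZE = 16          # tokens per cache block (assumed)
num_kv_heads = 8
head_size = 128
num_layers = 32

for kv_cache_dtype in (torch.float16, torch.uint8):  # uint8 backs the fp8 cache
    dtype_size = torch.tensor([], dtype=kv_cache_dtype).element_size()
    cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
    total_cache_size = num_layers * cache_block_size * 2 * dtype_size  # K and V
    print(kv_cache_dtype, f"{total_cache_size / 1024:.0f} KiB per block")
```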
@@ -823,7 +856,7 @@ class FlashCausalLM(Model):
             self.num_kv_heads,
             self.head_size,
             self.sliding_window is not None,
-            self.dtype,
+            self.kv_cache_dtype,
             self.device,
         )
 
@@ -878,7 +911,7 @@ class FlashCausalLM(Model):
                     if self.speculate is None or self.speculate + 1 <= bs:
                         self.cuda_graph_warmup(bs, max_s, max_bt)
             except torch.cuda.OutOfMemoryError:
-                logger.exception(f"Decode cuda graph warmup failed")
+                logger.exception("Decode cuda graph warmup failed")
         else:
             logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
 
@@ -29,6 +29,8 @@ class FlashLlama(FlashCausalLM):
         quantize: Optional[str] = None,
         speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
+        kv_cache_dtype: Optional[str] = "auto",
+        quantization_param_path: Optional[str] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
@@ -72,6 +74,7 @@ class FlashLlama(FlashCausalLM):
         )
         config.quantize = quantize
         config.speculator = speculator
+        config.kv_cache_dtype = kv_cache_dtype
 
         torch.distributed.barrier(group=self.process_group)
 
@@ -91,6 +94,8 @@ class FlashLlama(FlashCausalLM):
             head_size=model.model.head_size,
             dtype=dtype,
             device=device,
+            kv_cache_dtype=kv_cache_dtype,
+            quantization_param_path=quantization_param_path,
             rank=rank,
             world_size=world_size,
         )
@@ -193,6 +193,8 @@ def serve(
     quantize: Optional[str],
     speculate: Optional[int],
     dtype: Optional[str],
+    kv_cache_dtype: Optional[str],
+    quantization_param_path: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
 ):
@@ -203,6 +205,8 @@ def serve(
         quantize: Optional[str] = None,
         speculate: Optional[int] = None,
         dtype: Optional[str] = None,
+        kv_cache_dtype: Optional[str] = "auto",
+        quantization_param_path: Optional[str] = None,
         trust_remote_code: bool = False,
     ):
         unix_socket_template = "unix://{}-{}"
@@ -224,6 +228,8 @@ def serve(
                 quantize,
                 speculate,
                 dtype,
+                kv_cache_dtype,
+                quantization_param_path,
                 trust_remote_code,
             )
         except Exception:
@@ -256,6 +262,6 @@ def serve(
     set_model_id(model_id)
     asyncio.run(
         serve_inner(
-            model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
+            model_id, revision, sharded, quantize, speculate, dtype, kv_cache_dtype, quantization_param_path, trust_remote_code
         )
     )
@@ -1,3 +1,4 @@
+from loguru import logger
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
 
@@ -21,6 +22,8 @@ def reshape_and_cache(
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     slots: torch.Tensor,
+    kv_cache_dtype: str = "auto",
+    kv_scale: float = 1.0,
 ):
     if SYSTEM == "xpu":
         ipex.llm.modules.PagedAttention.reshape_and_cache(
@@ -28,7 +31,7 @@ def reshape_and_cache(
         )
     else:
         cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0
+            key, value, key_cache, value_cache, slots, kv_cache_dtype, kv_scale
         )
 
 
@@ -42,6 +45,8 @@ def attention(
     block_tables: torch.Tensor,
     input_lengths: torch.Tensor,
     max_s: int,
+    kv_cache_dtype: str = "auto",
+    kv_scale: float = 1.0,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
     # Copyright 2023 The vLLM team. All rights
@@ -99,8 +104,8 @@ def attention(
             block_size,
             max_s,
             None,
-            "auto",
-            1.0,
+            kv_cache_dtype,
+            kv_scale,
         )
     else:
         # Run PagedAttention V2.
@@ -132,6 +137,6 @@ def attention(
             block_size,
            max_s,
             None,
-            "auto",
-            1.0,
+            kv_cache_dtype,
+            kv_scale,
         )
new file server/text_generation_server/utils/weights_utils.py (48 lines)
@@ -0,0 +1,48 @@
+from typing import Optional, Tuple, Iterable
+from loguru import logger
+import json
+from text_generation_server.layers.schema import QuantParamSchema
+
+
+def kv_cache_scales_loader(
+    filename: str,
+    tp_rank: int,
+    tp_size: int,
+    num_hidden_layers: int,
+    model_type: Optional[str],
+) -> Iterable[Tuple[int, float]]:
+    """
+    A simple utility to read in KV cache scaling factors that have been
+    previously serialized to disk. Used by the model to populate the appropriate
+    KV cache scaling factors. The serialization should represent a dictionary
+    whose keys are the TP ranks and values are another dictionary mapping layers
+    to their KV cache scaling factors.
+    """
+    try:
+        with open(filename) as f:
+            context = {
+                "model_type": model_type,
+                "num_hidden_layers": num_hidden_layers,
+                "tp_rank": tp_rank,
+                "tp_size": tp_size,
+            }
+            schema_dct = json.load(f)
+            schema = QuantParamSchema.model_validate(schema_dct, context=context)
+            layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
+            return layer_scales_map.items()
+
+    except FileNotFoundError:
+        logger.error(f"File or directory '{filename}' not found.")
+    except json.JSONDecodeError:
+        logger.error(f"Error decoding JSON in file '{filename}'.")
+    except Exception as e:
+        logger.error(f"An error occurred while reading '{filename}': {e}")
+    # This section is reached if and only if any of the excepts are hit
+    # Return an empty iterable (list) => no KV cache scales are loaded
+    # which ultimately defaults to 1.0 scales
+    logger.warning(
+        "Defaulting to KV cache scaling factors = 1.0 "
+        f"for all layers in TP rank {tp_rank} "
+        "as an error occurred during loading."
+    )
+    return []
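A small usage sketch for the loader (not part of this diff; the file path, layer count, and model type are placeholders), mirroring how `load_kv_cache_scales` consumes it:

```python
# Illustrative only: read back the scaling factors for one TP rank.
# If the file is missing or malformed, the loader logs and returns [],
# which leaves all scales at the 1.0 default.
from text_generation_server.utils.weights_utils import kv_cache_scales_loader

scales = dict(
    kv_cache_scales_loader(
        "kv_cache_scales.json",
        tp_rank=0,
        tp_size=1,
        num_hidden_layers=2,
        model_type="llama",
    )
)
for layer_idx, scaling_factor in scales.items():
    print(f"layer {layer_idx}: kv_scale = {scaling_factor}")
```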