rebase and update

This commit is contained in:
Mohit Sharma 2024-06-24 08:15:36 +00:00
parent 084de9907c
commit 81fd601c44
11 changed files with 124 additions and 119 deletions

View File

@ -91,18 +91,8 @@ Options:
## KV_CACHE_DTYPE
```shell
--kv-cache-dtype <KV_CACHE_DTYPE>
Data type for kv cache storage. If "auto", will use model data type. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria
[env: KV_CACHE_DTYPE=]
[default: auto]
```
## QUANTIZATION_PARAM_PATH
```shell
--quantization-param-path <QUANTIZATION_PARAM_PATH>
Path to the JSON file containing the KV cache scaling factors. This should generally be supplied, when KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria
[env: QUANTIZATION_PARAM_PATH=]
[possible values: fp8, fp8_e5m2]
```
## TRUST_REMOTE_CODE
@ -244,7 +234,7 @@ Options:
--hostname <HOSTNAME>
The IP address to listen on
[env: HOSTNAME=]
[env: HOSTNAME=hf-amd-mi250-dev]
[default: 0.0.0.0]
```
@ -253,7 +243,7 @@ Options:
-p, --port <PORT>
The port to listen on
[env: PORT=]
[env: PORT=80]
[default: 3000]
```
@ -289,7 +279,7 @@ Options:
--huggingface-hub-cache <HUGGINGFACE_HUB_CACHE>
The location of the huggingface hub cache. Used to override the location if you want to provide a mounted disk for instance
[env: HUGGINGFACE_HUB_CACHE=]
[env: HUGGINGFACE_HUB_CACHE=/data]
```
## WEIGHTS_CACHE_OVERRIDE

View File

@ -144,6 +144,28 @@ impl std::fmt::Display for Dtype {
}
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum KvDtype {
#[clap(name = "fp8")]
Fp8,
#[clap(name = "fp8_e5m2")]
Fp8e5m2,
}
impl std::fmt::Display for KvDtype {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// Keep these strings in sync with the dtype values accepted by `server`.
match self {
KvDtype::Fp8 => {
write!(f, "fp8")
},
KvDtype::Fp8e5m2 => {
write!(f, "fp8_e5m2")
},
}
}
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum RopeScaling {
Linear,
@ -214,22 +236,12 @@ struct Args {
#[clap(long, env, value_enum)]
dtype: Option<Dtype>,
/// Data type for kv cache storage. If "auto", will use model
/// data type. FP8_E5M2 (without scaling) is only supported on cuda
/// version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead
/// supported for common inference criteria.
#[clap(default_value = "auto", long, env)]
kv_cache_dtype: Option<String>,
/// Path to the JSON file containing the KV cache
/// scaling factors. This should generally be supplied, when
/// KV cache dtype is FP8. Otherwise, KV cache scaling factors
/// default to 1.0, which may cause accuracy issues.
/// FP8_E5M2 (without scaling) is only supported on cuda version
/// greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead
/// supported for common inference criteria.
#[clap(long, env)]
quantization_param_path: Option<String>,
/// Specify the data type for the KV cache. By default, it uses the model's data type.
/// CUDA 11.8+ supports `fp8` (`fp8_e4m3`) and `fp8_e5m2`, while ROCm (AMD GPU) supports `fp8` (`fp8_e4m3fnuz`).
/// If `fp8` is chosen, a model checkpoint with scaling factors for the KV cache should be provided.
/// If not provided, the KV cache scaling factors default to 1.0, which may impact accuracy.
#[clap(long, env, value_enum)]
kv_cache_dtype: Option<KvDtype>,
/// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
/// encouraged when loading a model with custom code to ensure no malicious code has been
@ -481,8 +493,7 @@ fn shard_manager(
quantize: Option<Quantization>,
speculate: Option<usize>,
dtype: Option<Dtype>,
kv_cache_dtype: Option<String>,
quantization_param_path: Option<String>,
kv_cache_dtype: Option<KvDtype>,
trust_remote_code: bool,
uds_path: String,
rank: usize,
@ -556,12 +567,7 @@ fn shard_manager(
if let Some(kv_cache_dtype) = kv_cache_dtype {
shard_args.push("--kv-cache-dtype".to_string());
shard_args.push(kv_cache_dtype)
}
if let Some(quantization_param_path) = quantization_param_path {
shard_args.push("--quantization-param-path".to_string());
shard_args.push(quantization_param_path)
shard_args.push(kv_cache_dtype.to_string());
}
// Model optional revision
@ -1067,8 +1073,7 @@ fn spawn_shards(
let quantize = args.quantize;
let speculate = args.speculate;
let dtype = args.dtype;
let kv_cache_dtype = args.kv_cache_dtype.clone();
let quantization_param_path = args.quantization_param_path.clone();
let kv_cache_dtype = args.kv_cache_dtype;
let trust_remote_code = args.trust_remote_code;
let master_port = args.master_port;
let disable_custom_kernels = args.disable_custom_kernels;
@ -1087,7 +1092,6 @@ fn spawn_shards(
speculate,
dtype,
kv_cache_dtype,
quantization_param_path,
trust_remote_code,
uds_path,
rank,

View File

@ -7,7 +7,7 @@ from loguru import logger
from typing import Optional
from enum import Enum
from huggingface_hub import hf_hub_download
from text_generation_server.utils.import_utils import SYSTEM
app = typer.Typer()
@ -38,7 +38,6 @@ def serve(
speculate: Optional[int] = None,
dtype: Optional[Dtype] = None,
kv_cache_dtype: str = "auto",
quantization_param_path: Optional[str] = None,
trust_remote_code: bool = False,
uds_path: Path = "/tmp/text-generation-server",
logger_level: str = "INFO",
@ -92,6 +91,13 @@ def serve(
raise RuntimeError(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
)
if kv_cache_dtype in {"fp8", "fp8_e5m2"} and SYSTEM not in {"cuda", "rocm"}:
raise RuntimeError(f"{kv_cache_dtype} KV cache is only supported on Nvidia and AMD GPUs.")
if kv_cache_dtype == "fp8_e5m2" and SYSTEM != "cuda":
raise RuntimeError(f"fp8_e5m2 KV cache is only supported on Nvidia GPUs.")
server.serve(
model_id,
revision,
@ -100,7 +106,6 @@ def serve(
speculate,
dtype,
kv_cache_dtype,
quantization_param_path,
trust_remote_code,
uds_path,
max_input_tokens,
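For reference, a minimal sketch of the platform check added above (a hypothetical standalone helper, not part of this commit), showing which `kv_cache_dtype` / platform combinations the CLI accepts:

```python
# Hypothetical restatement of the validation in `serve` above; `system` plays the role
# of text_generation_server.utils.import_utils.SYSTEM ("cuda", "rocm", "xpu", "cpu").
def check_kv_cache_dtype(kv_cache_dtype: str, system: str) -> None:
    if kv_cache_dtype in {"fp8", "fp8_e5m2"} and system not in {"cuda", "rocm"}:
        raise RuntimeError(f"{kv_cache_dtype} KV cache is only supported on Nvidia and AMD GPUs.")
    if kv_cache_dtype == "fp8_e5m2" and system != "cuda":
        raise RuntimeError("fp8_e5m2 KV cache is only supported on Nvidia GPUs.")

check_kv_cache_dtype("auto", "cpu")       # ok: "auto" is always accepted
check_kv_cache_dtype("fp8", "rocm")       # ok: fp8 (e4m3) is accepted on CUDA and ROCm
check_kv_cache_dtype("fp8_e5m2", "cuda")  # ok: fp8_e5m2 is CUDA-only
# check_kv_cache_dtype("fp8_e5m2", "rocm")  -> RuntimeError
# check_kv_cache_dtype("fp8", "xpu")        -> RuntimeError
```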

View File

@ -20,8 +20,10 @@ def reshape_and_cache(
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
kv_cache_dtype: str = "auto",
kv_scale: float = 1.0,
):
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, kv_cache_dtype, kv_scale)
def paged_attention(
@ -34,6 +36,8 @@ def paged_attention(
block_tables: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
kv_cache_dtype: str = "auto",
kv_scale: float = 1.0,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights
@ -78,8 +82,8 @@ def paged_attention(
block_size,
max_s,
None,
"auto",
1.0,
kv_cache_dtype,
kv_scale,
)
else:
# Run PagedAttention V2.
@ -111,8 +115,8 @@ def paged_attention(
block_size,
max_s,
None,
"auto",
1.0,
kv_cache_dtype,
kv_scale,
)
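As a rough illustration of what the extra arguments mean here (a CPU-side sketch assuming a PyTorch build with float8 dtypes, not the actual CUDA kernels): with `kv_cache_dtype="fp8"`, values are divided by the per-tensor `kv_scale` and stored in 8 bits on write, then multiplied by the same scale when read back during attention.

```python
# CPU emulation of the fp8 KV cache round trip; the real work happens inside
# cache_ops.reshape_and_cache and the paged attention kernels.
import torch

kv_scale = 0.05                      # hypothetical per-tensor scale from a checkpoint
key = torch.randn(4, 8, 128) * 2.0   # hypothetical [tokens, kv_heads, head_size] values

stored = (key / kv_scale).to(torch.float8_e4m3fn)   # what the cache write would hold
recovered = stored.to(torch.float32) * kv_scale     # what the attention read would see
print((key - recovered).abs().max())                # small error, bounded by fp8 rounding
```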

View File

@ -25,8 +25,10 @@ def reshape_and_cache(
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
kv_cache_dtype: str = "auto",
kv_scale: float = 1.0,
):
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, kv_cache_dtype, kv_scale)
def paged_attention(
@ -39,6 +41,8 @@ def paged_attention(
block_tables: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
kv_cache_dtype: str = "auto",
kv_scale: float = 1.0,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights
@ -83,8 +87,8 @@ def paged_attention(
block_size,
max_s,
None,
"auto",
1.0,
kv_cache_dtype,
kv_scale,
)
else:
# Run PagedAttention V2.
@ -116,8 +120,8 @@ def paged_attention(
block_size,
max_s,
None,
"auto",
1.0,
kv_cache_dtype,
kv_scale,
)

View File

@ -39,6 +39,8 @@ def reshape_and_cache(
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
kv_cache_dtype: str = "auto",
kv_scale: float = 1.0,
):
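# Note: kv_cache_dtype and kv_scale appear to be accepted here only for interface
# parity with the CUDA/ROCm backends; the ipex call below does not pass them through.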
ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slots
@ -55,6 +57,8 @@ def paged_attention(
block_tables: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
kv_cache_dtype: str = "auto",
kv_scale: float = 1.0,
):
query = query.contiguous()
block_size = value_cache.shape[3]

View File

@ -114,7 +114,6 @@ except ImportError as e:
if MAMBA_AVAILABLE:
__all__.append(Mamba)
class ModelType(enum.Enum):
IDEFICS2 = {
"type": "idefics2",
@ -245,6 +244,11 @@ class ModelType(enum.Enum):
"multimodal": True,
}
FP8_KVCACHE_SUPPORTED_MODELS = {
"llama",
"baichun",
"phi3",
}
__GLOBALS = locals()
for data in ModelType:
@ -259,7 +263,6 @@ def get_model(
speculate: Optional[int],
dtype: Optional[str],
kv_cache_dtype: Optional[str],
quantization_param_path: Optional[str],
trust_remote_code: bool,
max_input_tokens: int,
) -> Model:
@ -279,9 +282,6 @@ def get_model(
else:
raise RuntimeError(f"Unknown dtype {dtype}")
if kv_cache_dtype not in {"auto", "fp8"}:
raise RuntimeError(f"Unknown kv_cache_dtype {kv_cache_dtype}")
if speculate is not None:
set_speculate(speculate)
else:
@ -292,6 +292,11 @@ def get_model(
)
model_type = config_dict.get("model_type", None)
if model_type not in FP8_KVCACHE_SUPPORTED_MODELS and kv_cache_dtype != "auto":
raise RuntimeError(
f"kv_cache_dtype is only supported for Llama models. Got model_type: {model_type}, kv_cache_dtype: {kv_cache_dtype}"
)
speculator = None
if "medusa_num_heads" in config_dict:
medusa_model_id = model_id
@ -600,7 +605,6 @@ def get_model(
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
quantization_param_path=quantization_param_path,
trust_remote_code=trust_remote_code,
)
elif sharded:
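To make the gate above concrete, a tiny sketch (hypothetical config contents) of where `model_type` comes from and how the check behaves:

```python
import json

FP8_KVCACHE_SUPPORTED_MODELS = {"llama", "baichuan", "phi3"}

config_dict = json.loads('{"model_type": "mistral", "num_hidden_layers": 32}')  # hypothetical
model_type = config_dict.get("model_type", None)
kv_cache_dtype = "fp8"

if model_type not in FP8_KVCACHE_SUPPORTED_MODELS and kv_cache_dtype != "auto":
    print(f"would raise: {kv_cache_dtype} KV cache not supported for model_type={model_type}")
```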

View File

@ -25,6 +25,7 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
@ -42,9 +43,8 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.utils.weights_utils import kv_cache_scales_loader
from loguru import logger
if SYSTEM == "rocm":
try:
from vllm import _custom_C
@ -135,15 +135,13 @@ class FlashLlamaAttention(torch.nn.Module):
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
# This will be overwritten by model initialization if we are using it.
# N.B. currently we only support per tensor scalar scaling factors
# & only applicable to ROCm (AMD GPU).
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
self.kv_cache_dtype = config.kv_cache_dtype
if self.kv_cache_dtype == "fp8":
self.kv_scale = weights.get_kv_cache_scaling_factor(prefix, self.kv_cache_dtype)
else:
self.kv_scale = 1.0
self.kv_cache_dtype = "auto"
logger.info(f"kv_cache_dtype: {self.kv_cache_dtype}, kv_scale: {self.kv_scale}")
def forward(
self,
@ -170,7 +168,7 @@ class FlashLlamaAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots, self.kv_cache_dtype, self.kv_scale)
# output tensor
attn_output = torch.empty_like(query)
@ -406,10 +404,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix=(
"model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
@ -427,10 +421,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
prefix=suffix if not prefix else f"{prefix}.{suffix}",
weights=weights,
)
self.config = config
for layer_idx in range(config.num_hidden_layers):
self.model.layers[layer_idx].self_attn.kv_cache_dtype = config.kv_cache_dtype
def forward(
self,
@ -462,33 +452,3 @@ class FlashLlamaForCausalLM(torch.nn.Module):
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
for layer_idx, scaling_factor in kv_cache_scales_loader(
quantization_param_path,
self.tp_rank,
self.tp_world_size,
self.config.num_hidden_layers,
self.config.__class__.model_type,
):
layer_self_attn = self.model.layers[layer_idx].self_attn
if SYSTEM == "rocm":
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn, "kv_scale"):
layer_self_attn.kv_scale = scaling_factor
logger.info(
f"Loaded KV cache scaling factor for layer {layer_idx}: {scaling_factor}"
)
else:
raise RuntimeError(
"Self attention has no KV cache scaling " "factor attribute!"
)
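For illustration, a small sketch (hypothetical file name and values, assuming the `safetensors` package) of the checkpoint layout that the `weights.get_kv_cache_scaling_factor(prefix, ...)` call above relies on: a scalar `kv_scale` tensor stored next to each attention layer's weights.

```python
from safetensors.torch import save_file, load_file
import torch

# Hypothetical per-layer scales, keyed the same way the attention prefix is built,
# e.g. "model.layers.<idx>.self_attn.kv_scale".
save_file(
    {
        "model.layers.0.self_attn.kv_scale": torch.tensor(0.021),
        "model.layers.1.self_attn.kv_scale": torch.tensor(0.019),
    },
    "kv_scales.safetensors",
)
scales = load_file("kv_scales.safetensors")
kv_scale = scales["model.layers.0.self_attn.kv_scale"].item()  # one float per layer
```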

View File

@ -29,7 +29,6 @@ class FlashLlama(FlashCausalLM):
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
kv_cache_dtype: Optional[str] = "auto",
quantization_param_path: Optional[str] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
@ -75,6 +74,7 @@ class FlashLlama(FlashCausalLM):
prefix = ""
model = FlashLlamaForCausalLM(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashLlama, self).__init__(
model=model,
@ -82,10 +82,8 @@ class FlashLlama(FlashCausalLM):
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
dtype=torch.uint8 if "fp8" in kv_cache_dtype else dtype,
device=device,
kv_cache_dtype=kv_cache_dtype,
quantization_param_path=quantization_param_path,
rank=rank,
world_size=world_size,
)
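The `dtype=torch.uint8 if "fp8" in kv_cache_dtype else dtype` change above reflects that the fp8 cache is allocated as raw bytes. A quick back-of-the-envelope sketch (hypothetical cache geometry) of the memory effect:

```python
import torch

# Hypothetical paged KV cache geometry.
num_blocks, block_size, num_kv_heads, head_size = 1024, 16, 8, 128
elems = 2 * num_blocks * block_size * num_kv_heads * head_size  # key + value

fp16_bytes = elems * torch.finfo(torch.float16).bits // 8
fp8_bytes = elems * 1  # stored as torch.uint8, one byte per element
print(f"{fp16_bytes / 2**20:.0f} MiB (fp16) vs {fp8_bytes / 2**20:.0f} MiB (fp8-as-uint8)")
```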

View File

@ -198,7 +198,6 @@ def serve(
speculate: Optional[int],
dtype: Optional[str],
kv_cache_dtype: Optional[str],
quantization_param_path: Optional[str],
trust_remote_code: bool,
uds_path: Path,
max_input_tokens: int,
@ -211,7 +210,6 @@ def serve(
speculate: Optional[int] = None,
dtype: Optional[str] = None,
kv_cache_dtype: Optional[str] = "auto",
quantization_param_path: Optional[str] = None,
trust_remote_code: bool = False,
):
unix_socket_template = "unix://{}-{}"
@ -234,7 +232,6 @@ def serve(
speculate,
dtype,
kv_cache_dtype,
quantization_param_path,
trust_remote_code,
max_input_tokens,
)
@ -272,6 +269,6 @@ def serve(
set_model_id(model_id)
asyncio.run(
serve_inner(
model_id, revision, sharded, quantize, speculate, dtype, kv_cache_dtype, quantization_param_path, trust_remote_code
model_id, revision, sharded, quantize, speculate, dtype, kv_cache_dtype, trust_remote_code
)
)

View File

@ -3,6 +3,7 @@ from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from safetensors import safe_open, SafetensorError
from text_generation_server.utils.import_utils import SYSTEM
import torch
from loguru import logger
from huggingface_hub import hf_hub_download
@ -88,7 +89,7 @@ class Weights:
# Special case for gptq which shouldn't convert
# u4 which are disguised as int32. Exl2 uses int16
# as well.
if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
if tensor.dtype not in [torch.int16, torch.int32, torch.int64] and not tensor_name.endswith("kv_scale"):
tensor = tensor.to(dtype=self.dtype)
if to_device:
tensor = tensor.to(device=self.device)
@ -762,6 +763,40 @@ class Weights:
except Exception:
pass
def get_kv_cache_scaling_factor(self, prefix: str, kv_cache_dtype: str):
try:
kv_scale = self.get_tensor(f"{prefix}.kv_scale").cpu().tolist()
except RuntimeError:
if kv_cache_dtype == "fp8":
log_once(
logger.warning,
"Could not find the `kv_scale` in checkpoint for `fp8_e4m3`. Using scaling factor"
"`1.0`. This may result in accuracy issues. Please ensure the checkpoint includes "
"the correct KV cache scaling factor.",
)
kv_scale = 1.0
else:
if kv_cache_dtype == "fp8_e5m2":
raise RuntimeError(
"Found `kv_scale` in the checkpoint, but `fp8_e5m2` KV dtype do not support `kv_scale` > 1.0"
)
if not isinstance(kv_scale, float):
raise RuntimeError(
"Only support per-tensor scaling factor for `fp8 (fp8_e4m3)` KV cache"
)
# ROCm uses the `fp8_e4m3fnuz` format for the KV cache, whereas Nvidia GPUs use the OCP
# `fp8_e4m3fn` format. The exponent bias differs by one, so the same bit pattern decodes
# to a value half as large on ROCm. Multiplying the scaling factor by 2 compensates for
# this, so the effective scaling matches what the checkpoint specifies for Nvidia GPUs.
if SYSTEM == "rocm":
kv_scale *= 2.0
return kv_scale
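A minimal sketch of the scaling convention and the ROCm adjustment described above, assuming a recent PyTorch that exposes `torch.float8_e4m3fn` and `torch.float8_e4m3fnuz`:

```python
import torch

def compute_kv_scale(kv: torch.Tensor, fp8_dtype: torch.dtype) -> float:
    # Convention: quantized_value * scaling_factor ~= true_value,
    # so scaling_factor = tensor_amax / FP8_max.
    return (kv.abs().max() / torch.finfo(fp8_dtype).max).item()

kv = torch.randn(16, 128) * 3.0                           # hypothetical KV values
scale_fn = compute_kv_scale(kv, torch.float8_e4m3fn)      # OCP e4m3fn (Nvidia), max 448
scale_fnuz = compute_kv_scale(kv, torch.float8_e4m3fnuz)  # e4m3fnuz (ROCm), max 240

# A checkpoint scale computed against e4m3fn must grow when used with the narrower
# e4m3fnuz range, which is what the `kv_scale *= 2.0` adjustment approximates.
print(scale_fn, scale_fnuz, scale_fnuz / scale_fn)        # ratio is 448 / 240 ~= 1.87
```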
def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
"""