Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-24 08:22:07 +00:00

Commit 81fd601c44 ("rebase and update"), parent 084de9907c
@@ -91,18 +91,8 @@ Options:
 ## KV_CACHE_DTYPE
 ```shell
       --kv-cache-dtype <KV_CACHE_DTYPE>
-          Data type for kv cache storage. If "auto", will use model data type. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria
-
           [env: KV_CACHE_DTYPE=]
-          [default: auto]
+          [possible values: fp8, fp8_e5m2]
 
-```
-## QUANTIZATION_PARAM_PATH
-```shell
-      --quantization-param-path <QUANTIZATION_PARAM_PATH>
-          Path to the JSON file containing the KV cache scaling factors. This should generally be supplied, when KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria
-
-          [env: QUANTIZATION_PARAM_PATH=]
-
 ```
 ## TRUST_REMOTE_CODE
@@ -244,7 +234,7 @@ Options:
       --hostname <HOSTNAME>
          The IP address to listen on
 
-          [env: HOSTNAME=]
+          [env: HOSTNAME=hf-amd-mi250-dev]
          [default: 0.0.0.0]
 
 ```
@@ -253,7 +243,7 @@ Options:
   -p, --port <PORT>
          The port to listen on
 
-          [env: PORT=]
+          [env: PORT=80]
          [default: 3000]
 
 ```
@@ -289,7 +279,7 @@ Options:
       --huggingface-hub-cache <HUGGINGFACE_HUB_CACHE>
          The location of the huggingface hub cache. Used to override the location if you want to provide a mounted disk for instance
 
-          [env: HUGGINGFACE_HUB_CACHE=]
+          [env: HUGGINGFACE_HUB_CACHE=/data]
 
 ```
 ## WEIGHTS_CACHE_OVERRIDE
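The regenerated launcher docs above now expose `--kv-cache-dtype` with `fp8` and `fp8_e5m2` as the only accepted values, and the separate `QUANTIZATION_PARAM_PATH` option is gone (scales are read from the checkpoint instead). As a quick illustration of how the flag would be passed, here is a minimal Python sketch; the binary is the project's `text-generation-launcher`, while the model id and port are placeholders, not values taken from this commit.

```python
import shlex
import shutil
import subprocess

def fp8_launcher_command(model_id: str, port: int = 3000) -> list[str]:
    """Build a launcher invocation that opts into the fp8 KV cache (illustrative only)."""
    return [
        "text-generation-launcher",
        "--model-id", model_id,
        "--port", str(port),
        # New flag documented above; only "fp8" and "fp8_e5m2" are accepted.
        "--kv-cache-dtype", "fp8",
    ]

if __name__ == "__main__":
    cmd = fp8_launcher_command("meta-llama/Llama-2-7b-chat-hf")  # placeholder model id
    print("Would run:", shlex.join(cmd))
    if shutil.which(cmd[0]):
        subprocess.run(cmd, check=True)
```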
@@ -144,6 +144,28 @@ impl std::fmt::Display for Dtype {
     }
 }
 
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum KvDtype {
+    #[clap(name = "fp8")]
+    Fp8,
+    #[clap(name = "fp8_e5m2")]
+    Fp8e5m2,
+}
+
+impl std::fmt::Display for KvDtype {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // Keep these strings in sync with the values accepted by the `server` CLI.
+        match self {
+            KvDtype::Fp8 => {
+                write!(f, "fp8")
+            },
+            KvDtype::Fp8e5m2 => {
+                write!(f, "fp8_e5m2")
+            },
+        }
+    }
+}
+
 #[derive(Clone, Copy, Debug, ValueEnum)]
 enum RopeScaling {
     Linear,
@@ -214,22 +236,12 @@ struct Args {
     #[clap(long, env, value_enum)]
     dtype: Option<Dtype>,
 
-    /// Data type for kv cache storage. If "auto", will use model
-    /// data type. FP8_E5M2 (without scaling) is only supported on cuda
-    /// version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead
-    /// supported for common inference criteria.
-    #[clap(default_value = "auto", long, env)]
-    kv_cache_dtype: Option<String>,
-
-    /// Path to the JSON file containing the KV cache
-    /// scaling factors. This should generally be supplied, when
-    /// KV cache dtype is FP8. Otherwise, KV cache scaling factors
-    /// default to 1.0, which may cause accuracy issues.
-    /// FP8_E5M2 (without scaling) is only supported on cuda version
-    /// greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead
-    /// supported for common inference criteria.
-    #[clap(long, env)]
-    quantization_param_path: Option<String>,
+    /// Specify the data type for the KV cache. By default, the model's data type is used.
+    /// CUDA 11.8+ supports `fp8` (`fp8_e4m3`) and `fp8_e5m2`, while ROCm (AMD GPU) supports `fp8` (`fp8_e4m3fn`).
+    /// If `fp8` (`fp8_e4m3`) is chosen, a model checkpoint with scales for the KV cache should be provided.
+    /// If it is not provided, the KV cache scaling factors default to 1.0, which may impact accuracy.
+    #[clap(long, env, value_enum)]
+    kv_cache_dtype: Option<KvDtype>,
 
     /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
     /// encouraged when loading a model with custom code to ensure no malicious code has been
@@ -481,8 +493,7 @@ fn shard_manager(
     quantize: Option<Quantization>,
     speculate: Option<usize>,
     dtype: Option<Dtype>,
-    kv_cache_dtype: Option<String>,
-    quantization_param_path: Option<String>,
+    kv_cache_dtype: Option<KvDtype>,
     trust_remote_code: bool,
     uds_path: String,
     rank: usize,
@@ -556,12 +567,7 @@ fn shard_manager(
 
     if let Some(kv_cache_dtype) = kv_cache_dtype {
        shard_args.push("--kv-cache-dtype".to_string());
-       shard_args.push(kv_cache_dtype)
-    }
-
-    if let Some(quantization_param_path) = quantization_param_path {
-       shard_args.push("--quantization-param-path".to_string());
-       shard_args.push(quantization_param_path)
+       shard_args.push(kv_cache_dtype.to_string());
     }
 
     // Model optional revision
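The hunk above makes the launcher forward `--kv-cache-dtype` to each shard only when the user set it, and drops the old `--quantization-param-path` forwarding. A rough Python sketch of that argument plumbing follows; `build_shard_args` and the sample values are illustrative and not part of this commit.

```python
from typing import List, Optional

def build_shard_args(base_args: List[str], kv_cache_dtype: Optional[str]) -> List[str]:
    """Append the KV cache dtype flag only when it was explicitly requested."""
    shard_args = list(base_args)
    if kv_cache_dtype is not None:
        # Mirrors the Rust side: the KvDtype enum renders as "fp8" or "fp8_e5m2".
        shard_args += ["--kv-cache-dtype", kv_cache_dtype]
    return shard_args

print(build_shard_args(["--model-id", "my-model"], "fp8"))
# ['--model-id', 'my-model', '--kv-cache-dtype', 'fp8']
print(build_shard_args(["--model-id", "my-model"], None))
# ['--model-id', 'my-model']
```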
@@ -1067,8 +1073,7 @@ fn spawn_shards(
        let quantize = args.quantize;
        let speculate = args.speculate;
        let dtype = args.dtype;
-       let kv_cache_dtype = args.kv_cache_dtype.clone();
-       let quantization_param_path = args.quantization_param_path.clone();
+       let kv_cache_dtype = args.kv_cache_dtype;
        let trust_remote_code = args.trust_remote_code;
        let master_port = args.master_port;
        let disable_custom_kernels = args.disable_custom_kernels;
@@ -1087,7 +1092,6 @@ fn spawn_shards(
                speculate,
                dtype,
                kv_cache_dtype,
-               quantization_param_path,
                trust_remote_code,
                uds_path,
                rank,
@@ -7,7 +7,7 @@ from loguru import logger
 from typing import Optional
 from enum import Enum
 from huggingface_hub import hf_hub_download
+from text_generation_server.utils.import_utils import SYSTEM
 
 app = typer.Typer()
 
@@ -38,7 +38,6 @@ def serve(
     speculate: Optional[int] = None,
     dtype: Optional[Dtype] = None,
     kv_cache_dtype: str = "auto",
-    quantization_param_path: Optional[str] = None,
     trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
@@ -92,6 +91,13 @@ def serve(
         raise RuntimeError(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
         )
+
+    if kv_cache_dtype in {"fp8", "fp8_e5m2"} and SYSTEM not in {"cuda", "rocm"}:
+        raise RuntimeError(f"{kv_cache_dtype} KV cache is only supported on Nvidia and AMD GPUs.")
+
+    if kv_cache_dtype == "fp8_e5m2" and SYSTEM != "cuda":
+        raise RuntimeError("fp8_e5m2 KV cache is only supported on Nvidia GPUs.")
+
     server.serve(
         model_id,
         revision,
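The checks added above gate the fp8 KV cache on the detected platform: `fp8` (e4m3) is allowed on CUDA and ROCm, while `fp8_e5m2` is CUDA-only. The sketch below restates that logic as a standalone function and exercises a few combinations; `check_kv_cache_dtype` and the literal `system` strings are illustrative stand-ins for the server's `SYSTEM` constant.

```python
def check_kv_cache_dtype(kv_cache_dtype: str, system: str) -> None:
    """Reject fp8 KV cache settings that the current platform cannot serve."""
    if kv_cache_dtype in {"fp8", "fp8_e5m2"} and system not in {"cuda", "rocm"}:
        raise RuntimeError(f"{kv_cache_dtype} KV cache is only supported on Nvidia and AMD GPUs.")
    if kv_cache_dtype == "fp8_e5m2" and system != "cuda":
        raise RuntimeError("fp8_e5m2 KV cache is only supported on Nvidia GPUs.")

for dtype, system in [("auto", "cpu"), ("fp8", "rocm"), ("fp8_e5m2", "rocm"), ("fp8_e5m2", "cuda")]:
    try:
        check_kv_cache_dtype(dtype, system)
        print(f"{dtype} on {system}: accepted")
    except RuntimeError as err:
        print(f"{dtype} on {system}: rejected ({err})")
```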
@@ -100,7 +106,6 @@ def serve(
         speculate,
         dtype,
         kv_cache_dtype,
-        quantization_param_path,
         trust_remote_code,
         uds_path,
         max_input_tokens,
@@ -20,8 +20,10 @@ def reshape_and_cache(
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     slots: torch.Tensor,
+    kv_cache_dtype: str = "auto",
+    kv_scale: float = 1.0,
 ):
-    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
+    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, kv_cache_dtype, kv_scale)
 
 
 def paged_attention(
@@ -34,6 +36,8 @@ def paged_attention(
     block_tables: torch.Tensor,
     input_lengths: torch.Tensor,
     max_s: int,
+    kv_cache_dtype: str = "auto",
+    kv_scale: float = 1.0,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
     # Copyright 2023 The vLLM team. All rights
@@ -78,8 +82,8 @@ def paged_attention(
             block_size,
             max_s,
             None,
-            "auto",
-            1.0,
+            kv_cache_dtype,
+            kv_scale,
         )
     else:
         # Run PagedAttention V2.
@@ -111,8 +115,8 @@ def paged_attention(
             block_size,
             max_s,
             None,
-            "auto",
-            1.0,
+            kv_cache_dtype,
+            kv_scale,
         )
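The kernels above now receive `kv_cache_dtype` and `kv_scale` instead of the hard-coded `"auto"` and `1.0`. For intuition about what the per-tensor scale does, here is a small PyTorch sketch of the fp8 (e4m3) round trip under the convention described elsewhere in this change, `scaling_factor = amax / fp8_max` and `quantized_value * scaling_factor ~= true_value`. It assumes a PyTorch build that ships `torch.float8_e4m3fn` (2.1 or later) and is not code from this repository.

```python
import torch

# Per-tensor fp8 (e4m3) round trip: scale maps the tensor's dynamic range onto fp8.
fp8 = torch.float8_e4m3fn
x = torch.randn(4, 128, dtype=torch.float16) * 100.0   # pretend these are K/V activations

scale = (x.abs().max().float() / torch.finfo(fp8).max).item()  # amax / 448
q = (x.float() / scale).to(fp8)          # stored KV cache entry, one byte per value
x_hat = q.float() * scale                # dequantized on the fly inside the kernel

rel_err = ((x_hat - x.float()).abs().max() / x.abs().max()).item()
print(f"kv_scale={scale:.4f}, max relative error={rel_err:.4%}")
```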
@@ -25,8 +25,10 @@ def reshape_and_cache(
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     slots: torch.Tensor,
+    kv_cache_dtype: str = "auto",
+    kv_scale: float = 1.0,
 ):
-    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
+    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, kv_cache_dtype, kv_scale)
 
 
 def paged_attention(
@@ -39,6 +41,8 @@ def paged_attention(
     block_tables: torch.Tensor,
     input_lengths: torch.Tensor,
     max_s: int,
+    kv_cache_dtype: str = "auto",
+    kv_scale: float = 1.0,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
     # Copyright 2023 The vLLM team. All rights
@@ -83,8 +87,8 @@ def paged_attention(
             block_size,
             max_s,
             None,
-            "auto",
-            1.0,
+            kv_cache_dtype,
+            kv_scale,
         )
     else:
         # Run PagedAttention V2.
@@ -116,8 +120,8 @@ def paged_attention(
             block_size,
             max_s,
             None,
-            "auto",
-            1.0,
+            kv_cache_dtype,
+            kv_scale,
         )
@@ -39,6 +39,8 @@ def reshape_and_cache(
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     slots: torch.Tensor,
+    kv_cache_dtype: str = "auto",
+    kv_scale: float = 1.0,
 ):
     ipex.llm.modules.PagedAttention.reshape_and_cache(
         key, value, key_cache, value_cache, slots
@@ -55,6 +57,8 @@ def paged_attention(
     block_tables: torch.Tensor,
     input_lengths: torch.Tensor,
     max_s: int,
+    kv_cache_dtype: str = "auto",
+    kv_scale: float = 1.0,
 ):
     query = query.contiguous()
     block_size = value_cache.shape[3]
@@ -114,7 +114,6 @@ except ImportError as e:
 if MAMBA_AVAILABLE:
     __all__.append(Mamba)
 
 
 class ModelType(enum.Enum):
     IDEFICS2 = {
         "type": "idefics2",
@@ -245,6 +244,11 @@ class ModelType(enum.Enum):
         "multimodal": True,
     }
 
+FP8_KVCACHE_SUPPORTED_MODELS = {
+    "llama",
+    "baichuan",
+    "phi3",
+}
+
 __GLOBALS = locals()
 for data in ModelType:
@@ -259,7 +263,6 @@ def get_model(
     speculate: Optional[int],
     dtype: Optional[str],
     kv_cache_dtype: Optional[str],
-    quantization_param_path: Optional[str],
     trust_remote_code: bool,
     max_input_tokens: int,
 ) -> Model:
@@ -278,9 +281,6 @@ def get_model(
         dtype = torch.bfloat16
     else:
         raise RuntimeError(f"Unknown dtype {dtype}")
-
-    if kv_cache_dtype not in {"auto", "fp8"}:
-        raise RuntimeError(f"Unknown kv_cache_dtype {kv_cache_dtype}")
 
     if speculate is not None:
         set_speculate(speculate)
@@ -292,6 +292,11 @@ def get_model(
     )
     model_type = config_dict.get("model_type", None)
 
+    if model_type not in FP8_KVCACHE_SUPPORTED_MODELS and kv_cache_dtype != "auto":
+        raise RuntimeError(
+            f"kv_cache_dtype is only supported for these model types: {sorted(FP8_KVCACHE_SUPPORTED_MODELS)}. Got model_type: {model_type}, kv_cache_dtype: {kv_cache_dtype}"
+        )
+
     speculator = None
     if "medusa_num_heads" in config_dict:
         medusa_model_id = model_id
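The new `FP8_KVCACHE_SUPPORTED_MODELS` gate above rejects a non-`auto` KV cache dtype unless the `model_type` read from the model config is in the supported set. A standalone restatement is sketched below; `check_fp8_support` and the sample config dicts are illustrative only, not repository code.

```python
FP8_KVCACHE_SUPPORTED_MODELS = {"llama", "baichuan", "phi3"}

def check_fp8_support(config_dict: dict, kv_cache_dtype: str) -> None:
    """Mirror of the gate above: only whitelisted model types may use an fp8 KV cache."""
    model_type = config_dict.get("model_type", None)
    if kv_cache_dtype != "auto" and model_type not in FP8_KVCACHE_SUPPORTED_MODELS:
        raise RuntimeError(
            f"kv_cache_dtype={kv_cache_dtype} is not supported for model_type={model_type}"
        )

check_fp8_support({"model_type": "llama"}, "fp8")      # passes silently
try:
    check_fp8_support({"model_type": "gpt2"}, "fp8")   # rejected
except RuntimeError as err:
    print(err)
```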
@@ -600,7 +605,6 @@ def get_model(
             speculator=speculator,
             dtype=dtype,
             kv_cache_dtype=kv_cache_dtype,
-            quantization_param_path=quantization_param_path,
             trust_remote_code=trust_remote_code,
         )
     elif sharded:
@@ -25,6 +25,7 @@ import torch.distributed
 
 from torch import nn
 from transformers.activations import ACT2FN
+from typing import Optional, List, Tuple
 
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
@@ -42,9 +43,8 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
-from text_generation_server.utils.weights_utils import kv_cache_scales_loader
-from loguru import logger
 
+from loguru import logger
 if SYSTEM == "rocm":
     try:
         from vllm import _custom_C
@@ -135,15 +135,13 @@ class FlashLlamaAttention(torch.nn.Module):
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)
 
-        # This will be overwritten by model initialization if we are using it.
-        # N.B. currently we only support per tensor scalar scaling factors
-        # & only applicable to ROCm (AMD GPU).
-        # The scaling factor convention we are assuming is
-        # quantized_value * scaling_factor ~= true_value
-        # which is consistent with the practice of setting
-        # scaling_factor = tensor_amax / FPtype_max
-        self.kv_scale = 1.0
-        self.kv_cache_dtype = "auto"
+        self.kv_cache_dtype = config.kv_cache_dtype
+        if self.kv_cache_dtype == "fp8":
+            self.kv_scale = weights.get_kv_cache_scaling_factor(prefix, self.kv_cache_dtype)
+        else:
+            self.kv_scale = 1.0
+        logger.info(f"kv_cache_dtype: {self.kv_cache_dtype}, kv_scale: {self.kv_scale}")
 
     def forward(
         self,
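With this hunk each attention layer takes its KV cache dtype from the config and asks the weight loader for a per-tensor `kv_scale` only in fp8 mode, falling back to 1.0 otherwise. The sketch below mirrors that wiring with made-up stand-ins (`DummyWeights`, `AttentionStub`, and the example prefix); it is not repository code.

```python
class DummyWeights:
    """Stand-in for the real Weights object; holds per-layer scales keyed by prefix."""

    def __init__(self, scales):
        self._scales = scales  # e.g. {"model.layers.0.self_attn": 0.023}

    def get_kv_cache_scaling_factor(self, prefix: str, kv_cache_dtype: str) -> float:
        return self._scales.get(prefix, 1.0)

class AttentionStub:
    """Reads the KV cache dtype from the config and resolves its own kv_scale."""

    def __init__(self, prefix, config, weights):
        self.kv_cache_dtype = config["kv_cache_dtype"]
        if self.kv_cache_dtype == "fp8":
            self.kv_scale = weights.get_kv_cache_scaling_factor(prefix, self.kv_cache_dtype)
        else:
            self.kv_scale = 1.0

weights = DummyWeights({"model.layers.0.self_attn": 0.023})
for dtype in ("auto", "fp8"):
    attn = AttentionStub("model.layers.0.self_attn", {"kv_cache_dtype": dtype}, weights)
    print(dtype, attn.kv_scale)   # auto -> 1.0, fp8 -> 0.023
```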
@@ -170,7 +168,7 @@ class FlashLlamaAttention(torch.nn.Module):
 
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
-        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
+        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots, self.kv_cache_dtype, self.kv_scale)
 
         # output tensor
         attn_output = torch.empty_like(query)
@@ -406,10 +404,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
 
-        process_group = weights.process_group
-        self.tp_rank = process_group.rank()
-        self.tp_world_size = process_group.size()
-
         self.embed_tokens = TensorParallelEmbedding(
             prefix=(
                 "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
@@ -427,10 +421,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             prefix=suffix if not prefix else f"{prefix}.{suffix}",
             weights=weights,
         )
-        self.config = config
-
-        for layer_idx in range(config.num_hidden_layers):
-            self.model.layers[layer_idx].self_attn.kv_cache_dtype = config.kv_cache_dtype
 
     def forward(
         self,
@@ -462,33 +452,3 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             hidden_states = hidden_states[lm_head_indices]
         logits, speculative_logits = self.lm_head(hidden_states)
         return logits, speculative_logits
-
-    # If this function is called, it should always initialize KV cache scale
-    # factors (or else raise an exception). Thus, handled exceptions should
-    # make sure to leave KV cache scale factors in a known good (dummy) state
-    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
-        for layer_idx, scaling_factor in kv_cache_scales_loader(
-            quantization_param_path,
-            self.tp_rank,
-            self.tp_world_size,
-            self.config.num_hidden_layers,
-            self.config.__class__.model_type,
-        ):
-            layer_self_attn = self.model.layers[layer_idx].self_attn
-
-            if SYSTEM == "rocm":
-                # The scaling factor convention we are assuming is
-                # quantized_value * scaling_factor ~= true_value
-                # which is consistent with the practice of setting
-                # scaling_factor = tensor_amax / FPtype_max
-                scaling_factor *= 2
-
-            if hasattr(layer_self_attn, "kv_scale"):
-                layer_self_attn.kv_scale = scaling_factor
-                logger.info(
-                    f"Loaded KV cache scaling factor for layer {layer_idx}: {scaling_factor}"
-                )
-            else:
-                raise RuntimeError(
-                    "Self attention has no KV cache scaling factor attribute!"
-                )
@@ -29,7 +29,6 @@ class FlashLlama(FlashCausalLM):
         speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         kv_cache_dtype: Optional[str] = "auto",
-        quantization_param_path: Optional[str] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
@@ -75,6 +74,7 @@ class FlashLlama(FlashCausalLM):
 
         prefix = ""
         model = FlashLlamaForCausalLM(prefix, config, weights)
 
         torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
             model=model,
@@ -82,10 +82,8 @@ class FlashLlama(FlashCausalLM):
             num_layers=len(model.model.layers),
             num_kv_heads=model.model.num_key_value_heads,
             head_size=model.model.head_size,
-            dtype=dtype,
+            dtype=torch.uint8 if "fp8" in kv_cache_dtype else dtype,
             device=device,
-            kv_cache_dtype=kv_cache_dtype,
-            quantization_param_path=quantization_param_path,
             rank=rank,
             world_size=world_size,
         )
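In the hunk above the KV cache appears to be allocated as `torch.uint8` whenever an fp8 variant is requested, since fp8 entries are stored as raw bytes for the paged-attention kernels, and in the model dtype otherwise. A minimal sketch of that storage-dtype choice, with a hypothetical cache shape, is below; it also shows the halving of cache memory you would expect relative to fp16.

```python
import torch

def kv_cache_storage_dtype(kv_cache_dtype: str, model_dtype: torch.dtype) -> torch.dtype:
    """fp8 variants are held as raw bytes (uint8) for the paged-attention kernels."""
    return torch.uint8 if "fp8" in kv_cache_dtype else model_dtype

# Hypothetical cache shape: (num_blocks, num_kv_heads, head_size, block_size).
shape = (128, 8, 128, 16)
for kv_dtype in ("auto", "fp8", "fp8_e5m2"):
    block = torch.empty(shape, dtype=kv_cache_storage_dtype(kv_dtype, torch.float16))
    size_mb = block.element_size() * block.numel() / 1e6
    print(f"{kv_dtype}: stored as {block.dtype}, {size_mb:.1f} MB per cache tensor")
```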
@@ -198,7 +198,6 @@ def serve(
     speculate: Optional[int],
     dtype: Optional[str],
     kv_cache_dtype: Optional[str],
-    quantization_param_path: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
     max_input_tokens: int,
@@ -211,7 +210,6 @@ def serve(
         speculate: Optional[int] = None,
         dtype: Optional[str] = None,
         kv_cache_dtype: Optional[str] = "auto",
-        quantization_param_path: Optional[str] = None,
        trust_remote_code: bool = False,
     ):
        unix_socket_template = "unix://{}-{}"
@@ -234,7 +232,6 @@ def serve(
            speculate,
            dtype,
            kv_cache_dtype,
-           quantization_param_path,
            trust_remote_code,
            max_input_tokens,
        )
@@ -272,6 +269,6 @@ def serve(
     set_model_id(model_id)
     asyncio.run(
         serve_inner(
-            model_id, revision, sharded, quantize, speculate, dtype, kv_cache_dtype, quantization_param_path, trust_remote_code
+            model_id, revision, sharded, quantize, speculate, dtype, kv_cache_dtype, trust_remote_code
         )
     )
@@ -3,6 +3,7 @@ from dataclasses import dataclass
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union
 from safetensors import safe_open, SafetensorError
+from text_generation_server.utils.import_utils import SYSTEM
 import torch
 from loguru import logger
 from huggingface_hub import hf_hub_download
@@ -88,7 +89,7 @@ class Weights:
         # Special case for gptq which shouldn't convert
         # u4 which are disguised as int32. Exl2 uses int16
         # as well.
-        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
+        if tensor.dtype not in [torch.int16, torch.int32, torch.int64] and not tensor_name.endswith("kv_scale"):
             tensor = tensor.to(dtype=self.dtype)
         if to_device:
             tensor = tensor.to(device=self.device)
@@ -762,6 +763,40 @@ class Weights:
         except Exception:
             pass
 
+    def get_kv_cache_scaling_factor(self, prefix: str, kv_cache_dtype: str):
+        try:
+            kv_scale = self.get_tensor(f"{prefix}.kv_scale").cpu().tolist()
+        except RuntimeError:
+            if kv_cache_dtype == "fp8":
+                log_once(
+                    logger.warning,
+                    "Could not find `kv_scale` in the checkpoint for `fp8_e4m3`. Using scaling factor "
+                    "`1.0`. This may result in accuracy issues. Please ensure the checkpoint includes "
+                    "the correct KV cache scaling factor.",
+                )
+            kv_scale = 1.0
+        else:
+            if kv_cache_dtype == "fp8_e5m2":
+                raise RuntimeError(
+                    "Found `kv_scale` in the checkpoint, but the `fp8_e5m2` KV cache dtype does not support `kv_scale` > 1.0"
+                )
+            if not isinstance(kv_scale, float):
+                raise RuntimeError(
+                    "Only per-tensor scaling factors are supported for the `fp8` (`fp8_e4m3`) KV cache"
+                )
+
+            # ROCm uses the FP8 format fp8_e4m3fn, whereas Nvidia GPUs use fp8_e4m3.
+            # The multiplication by 2 compensates for the different numeric representation
+            # between ROCm and Nvidia GPUs, ensuring consistent effective scaling across platforms.
+            # After this adjustment, the overall effect is equivalent to the scaling applied
+            # without it on Nvidia GPUs.
+            if SYSTEM == "rocm":
+                kv_scale *= 2.0
+
+        return kv_scale
+
+
 def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
     """
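The new `get_kv_cache_scaling_factor` above reads a per-tensor scale stored in the checkpoint under `<prefix>.kv_scale`, warns and falls back to 1.0 when it is missing, and doubles the value on ROCm to compensate for the narrower fp8 representation. Below is a self-contained sketch of that behaviour against a tiny generated safetensors file; the file name, prefix, and `system` argument are illustrative, and the real method additionally rejects `fp8_e5m2` checkpoints that carry a scale.

```python
import torch
from safetensors.torch import save_file, load_file

def load_kv_scale(tensors: dict, prefix: str, system: str = "cuda") -> float:
    """Fetch a per-tensor KV cache scale by prefix, defaulting to 1.0 when absent."""
    key = f"{prefix}.kv_scale"
    if key not in tensors:
        return 1.0  # missing scale: the real code logs a warning, fp8 accuracy may suffer
    kv_scale = float(tensors[key])
    if system == "rocm":
        # ROCm's fp8 format covers roughly half the numeric range of Nvidia's,
        # which is what the factor of 2 in the hunk above compensates for.
        kv_scale *= 2.0
    return kv_scale

save_file({"model.layers.0.self_attn.kv_scale": torch.tensor(0.023)}, "demo_scales.safetensors")
tensors = load_file("demo_scales.safetensors")
print(load_kv_scale(tensors, "model.layers.0.self_attn", system="cuda"))  # 0.023
print(load_kv_scale(tensors, "model.layers.1.self_attn", system="rocm"))  # 1.0 (missing)
```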