add kvcache fp8 support

mohit@huggingface.co 2024-05-23 16:00:18 +00:00
parent efb73fcb59
commit 8c437a80bc
11 changed files with 317 additions and 12 deletions

View File

@@ -85,6 +85,23 @@ Options:
[env: DTYPE=]
[possible values: float16, bfloat16]
```
## KV_CACHE_DTYPE
```shell
--kv-cache-dtype <KV_CACHE_DTYPE>
Data type for KV cache storage. If "auto", the model data type is used. FP8_E5M2 (without scaling) is only supported on CUDA versions greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is supported instead for common inference criteria
[env: KV_CACHE_DTYPE=]
[default: auto]
```
## QUANTIZATION_PARAM_PATH
```shell
--quantization-param-path <QUANTIZATION_PARAM_PATH>
Path to the JSON file containing the KV cache scaling factors. This should generally be supplied when the KV cache dtype is FP8; otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. FP8_E5M2 (without scaling) is only supported on CUDA versions greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is supported instead for common inference criteria
[env: QUANTIZATION_PARAM_PATH=]
```
## TRUST_REMOTE_CODE
```shell

View File

@@ -184,6 +184,23 @@ struct Args {
#[clap(long, env, value_enum)]
dtype: Option<Dtype>,
/// Data type for KV cache storage. If "auto", the model data type
/// is used. FP8_E5M2 (without scaling) is only supported on CUDA
/// versions greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is
/// supported instead for common inference criteria.
#[clap(default_value = "auto", long, env)]
kv_cache_dtype: Option<String>,
/// Path to the JSON file containing the KV cache
/// scaling factors. This should generally be supplied when the
/// KV cache dtype is FP8; otherwise, KV cache scaling factors
/// default to 1.0, which may cause accuracy issues.
/// FP8_E5M2 (without scaling) is only supported on CUDA versions
/// greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is
/// supported instead for common inference criteria.
#[clap(long, env)]
quantization_param_path: Option<String>,
/// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
/// encouraged when loading a model with custom code to ensure no malicious code has been
/// contributed in a newer revision.
@@ -434,6 +451,8 @@ fn shard_manager(
quantize: Option<Quantization>,
speculate: Option<usize>,
dtype: Option<Dtype>,
kv_cache_dtype: Option<String>,
quantization_param_path: Option<String>,
trust_remote_code: bool,
uds_path: String,
rank: usize,
@@ -503,6 +522,16 @@ fn shard_manager(
shard_args.push(dtype.to_string())
}
if let Some(kv_cache_dtype) = kv_cache_dtype {
shard_args.push("--kv-cache-dtype".to_string());
shard_args.push(kv_cache_dtype)
}
if let Some(quantization_param_path) = quantization_param_path {
shard_args.push("--quantization-param-path".to_string());
shard_args.push(quantization_param_path)
}
// Model optional revision
if let Some(revision) = revision {
shard_args.push("--revision".to_string());
@@ -1000,6 +1029,8 @@ fn spawn_shards(
let quantize = args.quantize;
let speculate = args.speculate;
let dtype = args.dtype;
let kv_cache_dtype = args.kv_cache_dtype.clone();
let quantization_param_path = args.quantization_param_path.clone();
let trust_remote_code = args.trust_remote_code;
let master_port = args.master_port;
let disable_custom_kernels = args.disable_custom_kernels;
@@ -1017,6 +1048,8 @@ fn spawn_shards(
quantize,
speculate,
dtype,
kv_cache_dtype,
quantization_param_path,
trust_remote_code,
uds_path,
rank,

View File

@@ -35,6 +35,8 @@ def serve(
quantize: Optional[Quantization] = None,
speculate: Optional[int] = None,
dtype: Optional[Dtype] = None,
kv_cache_dtype: str = "auto",
quantization_param_path: Optional[str] = None,
trust_remote_code: bool = False,
uds_path: Path = "/tmp/text-generation-server",
logger_level: str = "INFO",
@@ -94,6 +96,8 @@ def serve(
quantize,
speculate,
dtype,
kv_cache_dtype,
quantization_param_path,
trust_remote_code,
uds_path,
)

View File

@@ -0,0 +1,90 @@
"""
This file contains the Pydantic schemas for various quantization-related
parameters. When a relevant quantization technique is specified, these
parameters are loaded in the form of a JSON alongside the model weights
and augment the model with additional information needed for use of that
technique. The format of this JSON should be specified by one or more
schemas contained here.

For example, when the KV cache is quantized to FP8-E4M3 (currently only
possible on ROCm), the model can be optionally augmented with KV cache
scaling factors.
"""

from typing import Dict, Optional

from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator


class KVCacheQuantSchema(BaseModel):
    dtype: str
    # Each key is a TP rank. Each value is a dictionary mapping a TP rank's
    # layer indices to their per-tensor KV cache scaling factor.
    # TODO: Consider pulling this and its validation methods out into its
    # own schema class (tricky as its members are variable)
    scaling_factor: Dict[int, Dict[int, float]]

    @model_validator(mode="after")
    def check_is_fp8(self) -> "KVCacheQuantSchema":
        assert self.dtype == "float8_e4m3fn", (
            "Loaded scaling factors intended for KV cache dtype = "
            f"{self.dtype} rather than float8_e4m3fn!"
        )
        return self

    @model_validator(mode="after")
    def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema":
        context = info.context
        if context:
            tp_size = context["tp_size"]
            num_hidden_layers = context["num_hidden_layers"]
            assert len(self.scaling_factor) == tp_size, (
                f"Loaded dictionary has TP size {len(self.scaling_factor)} "
                f"but LLM engine is currently running with TP size {tp_size}."
            )
            for tp_rank, layer_maps in self.scaling_factor.items():
                assert len(layer_maps) == num_hidden_layers, (
                    f"KV cache scales map for TP rank {tp_rank} is malformed. "
                    f"Expected {num_hidden_layers} layers, got "
                    f"{len(layer_maps)}."
                )
            for i in range(tp_size):
                assert (
                    i in self.scaling_factor
                ), f"KV cache scales map for TP rank {i} not found."
        return self

    @model_validator(mode="after")
    def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema":
        context = info.context
        if context:
            tp_rank = context["tp_rank"]
            num_hidden_layers = context["num_hidden_layers"]
            layer_scales_map = self.scaling_factor[tp_rank]
            for i in range(num_hidden_layers):
                assert i in layer_scales_map, (
                    f"Could not find KV cache scales for layer {i} in "
                    f"TP rank {tp_rank}."
                )
        return self


class QuantParamSchema(BaseModel):
    # TODO: Generalize and extend with more fields
    # (e.g. weights/activations params) once functionality is enabled
    model_config = ConfigDict(protected_namespaces=())
    model_type: Optional[str]
    kv_cache: KVCacheQuantSchema

    @model_validator(mode="after")
    def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema":
        context = info.context
        if context:
            model_type = context.get("model_type", None)
            if model_type is not None:
                assert model_type == self.model_type, (
                    f"Model type is {model_type} but loaded "
                    f"scaling factors belonging to different "
                    f"model type {self.model_type}!"
                )
        return self
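For reference, a minimal sketch of a scaling-factors JSON that satisfies this schema, validated the same way the loader added later in this commit does. The model shape and scale values are made up; `json.load` produces string keys, which pydantic coerces to ints.

```python
from text_generation_server.layers.schema import QuantParamSchema

# Hypothetical payload: TP size 1, a 2-layer model, per-layer per-tensor scales.
schema_dct = {
    "model_type": "llama",
    "kv_cache": {
        "dtype": "float8_e4m3fn",
        "scaling_factor": {"0": {"0": 0.0152, "1": 0.0174}},
    },
}

context = {
    "model_type": "llama",
    "num_hidden_layers": 2,
    "tp_rank": 0,
    "tp_size": 1,
}

schema = QuantParamSchema.model_validate(schema_dct, context=context)
print(schema.kv_cache.scaling_factor[0])  # {0: 0.0152, 1: 0.0174}
```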

View File

@@ -260,6 +260,8 @@ def get_model(
quantize: Optional[str],
speculate: Optional[int],
dtype: Optional[str],
kv_cache_dtype: Optional[str],
quantization_param_path: Optional[str],
trust_remote_code: bool,
) -> Model:
if dtype is None:
@@ -273,6 +275,9 @@ def get_model(
else:
raise RuntimeError(f"Unknown dtype {dtype}")
if kv_cache_dtype not in {"auto", "fp8"}:
raise RuntimeError(f"Unknown kv_cache_dtype {kv_cache_dtype}")
if speculate is not None:
set_speculate(speculate)
else:
@@ -563,6 +568,8 @@ def get_model(
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
quantization_param_path=quantization_param_path,
trust_remote_code=trust_remote_code,
)
elif sharded:

View File

@@ -25,7 +25,6 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils import paged_attention, flash_attn
@@ -39,6 +38,8 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.utils.weights_utils import kv_cache_scales_loader
from loguru import logger
if SYSTEM == "rocm": if SYSTEM == "rocm":
try: try:
@ -126,6 +127,16 @@ class FlashLlamaAttention(torch.nn.Module):
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups) ).repeat_interleave(self.num_groups)
# This will be overwritten by model initialization if we are using it.
# N.B. currently we only support per tensor scalar scaling factors
# & only applicable to ROCm (AMD GPU).
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
self.kv_scale = 1.0
self.kv_cache_dtype = "auto"
def forward(
self,
hidden_states,
@@ -152,7 +163,13 @@ class FlashLlamaAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
paged_attention.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots kv[:, 0],
kv[:, 1],
kv_cache[0],
kv_cache[1],
slots,
self.kv_cache_dtype,
self.kv_scale,
)
# output tensor
@@ -182,6 +199,8 @@ class FlashLlamaAttention(torch.nn.Module):
block_tables,
input_lengths,
max_s,
self.kv_cache_dtype,
self.kv_scale,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -387,6 +406,10 @@ class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix=(
"model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
@@ -404,6 +427,10 @@ class FlashLlamaForCausalLM(torch.nn.Module):
prefix=suffix if not prefix else f"{prefix}.suffix",
weights=weights,
)
self.config = config
for layer_idx in range(config.num_hidden_layers):
self.model.layers[layer_idx].self_attn.kv_cache_dtype = config.kv_cache_dtype
def forward(
self,
@@ -435,3 +462,33 @@ class FlashLlamaForCausalLM(torch.nn.Module):
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
for layer_idx, scaling_factor in kv_cache_scales_loader(
quantization_param_path,
self.tp_rank,
self.tp_world_size,
self.config.num_hidden_layers,
self.config.__class__.model_type,
):
layer_self_attn = self.model.layers[layer_idx].self_attn
if SYSTEM == "rocm":
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn, "kv_scale"):
layer_self_attn.kv_scale = scaling_factor
logger.info(
f"Loaded KV cache scaling factor for layer {layer_idx}: {scaling_factor}"
)
else:
raise RuntimeError(
"Self attention has no KV cache scaling " "factor attribute!"
)
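The scaling convention spelled out in the comments above (quantized_value * scaling_factor ~= true_value, with scaling_factor = tensor_amax / FPtype_max) can be illustrated in isolation. A rough sketch with random data, assuming a PyTorch build that exposes `torch.float8_e4m3fn`; this is not how the attention kernels above perform the quantization:

```python
import torch

# Per-tensor scale: scaling_factor = tensor_amax / FPtype_max
kv = torch.randn(4, 8, 128, dtype=torch.float16)
fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
kv_scale = (kv.abs().max().float() / fp8_max).item()

# quantized_value * scaling_factor ~= true_value
quantized = (kv.float() / kv_scale).to(torch.float8_e4m3fn)
dequantized = quantized.float() * kv_scale
print(kv_scale, (dequantized - kv.float()).abs().max().item())  # small fp8 rounding error
```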

View File

@@ -692,6 +692,8 @@ class FlashCausalLM(Model):
head_size: int,
dtype: torch.dtype,
device: torch.device,
kv_cache_dtype: str = "auto",
quantization_param_path: Optional[str] = None,
rank: int = 0,
world_size: int = 1,
sliding_window: Optional[int] = None,
@@ -713,6 +715,37 @@ class FlashCausalLM(Model):
sliding_window=sliding_window,
)
if kv_cache_dtype == "fp8":
self.kv_cache_dtype = torch.uint8
else:
self.kv_cache_dtype = self.dtype
if kv_cache_dtype == "fp8" and SYSTEM == "rocm":
logger.info(f"Using KV cache data type: {kv_cache_dtype}")
# Currently scaled KV cache is only enabled on ROCm
if quantization_param_path is not None:
if callable(getattr(self.model, "load_kv_cache_scales", None)):
self.model.load_kv_cache_scales(quantization_param_path)
else:
raise RuntimeError(
"Using FP8 KV cache and scaling "
"factors provided but model "
f"{self.model.__class__} does not "
"support loading scaling factors."
)
else:
logger.info(
"Using FP8 KV cache but no scaling factors "
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!"
)
elif quantization_param_path is not None:
logger.info(
"KV cache scaling factors provided, "
"but the KV cache data type is not FP8. "
"KV cache scaling factors will not be used."
)
@property
def batch_type(self) -> Type[FlashCausalLMBatch]:
return FlashCausalLMBatch
@@ -782,7 +815,7 @@ class FlashCausalLM(Model):
self.num_kv_heads,
self.head_size,
self.sliding_window is not None,
self.dtype, self.kv_cache_dtype,
self.device,
)
max_bt = batch.max_blocks
@@ -801,7 +834,7 @@ class FlashCausalLM(Model):
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the free memory
dtype_size = torch.tensor([], dtype=self.dtype).element_size() dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
@@ -823,7 +856,7 @@ class FlashCausalLM(Model):
self.num_kv_heads,
self.head_size,
self.sliding_window is not None,
self.dtype, self.kv_cache_dtype,
self.device,
)
@@ -878,7 +911,7 @@ class FlashCausalLM(Model):
if self.speculate is None or self.speculate + 1 <= bs:
self.cuda_graph_warmup(bs, max_s, max_bt)
except torch.cuda.OutOfMemoryError:
logger.exception(f"Decode cuda graph warmup failed") logger.exception("Decode cuda graph warmup failed")
else:
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")

View File

@@ -29,6 +29,8 @@ class FlashLlama(FlashCausalLM):
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
kv_cache_dtype: Optional[str] = "auto",
quantization_param_path: Optional[str] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
@@ -72,6 +74,7 @@ class FlashLlama(FlashCausalLM):
)
config.quantize = quantize
config.speculator = speculator
config.kv_cache_dtype = kv_cache_dtype
torch.distributed.barrier(group=self.process_group)
@@ -91,6 +94,8 @@ class FlashLlama(FlashCausalLM):
head_size=model.model.head_size,
dtype=dtype,
device=device,
kv_cache_dtype=kv_cache_dtype,
quantization_param_path=quantization_param_path,
rank=rank,
world_size=world_size,
)

View File

@@ -193,6 +193,8 @@ def serve(
quantize: Optional[str],
speculate: Optional[int],
dtype: Optional[str],
kv_cache_dtype: Optional[str],
quantization_param_path: Optional[str],
trust_remote_code: bool,
uds_path: Path,
):
@@ -203,6 +205,8 @@ def serve(
quantize: Optional[str] = None,
speculate: Optional[int] = None,
dtype: Optional[str] = None,
kv_cache_dtype: Optional[str] = "auto",
quantization_param_path: Optional[str] = None,
trust_remote_code: bool = False,
):
unix_socket_template = "unix://{}-{}"
@@ -224,6 +228,8 @@ def serve(
quantize,
speculate,
dtype,
kv_cache_dtype,
quantization_param_path,
trust_remote_code,
)
except Exception:
@@ -256,6 +262,6 @@ def serve(
set_model_id(model_id)
asyncio.run(
serve_inner(
model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code model_id, revision, sharded, quantize, speculate, dtype, kv_cache_dtype, quantization_param_path, trust_remote_code
)
)

View File

@@ -1,3 +1,4 @@
from loguru import logger
import torch
from text_generation_server.utils.import_utils import SYSTEM
@@ -21,6 +22,8 @@ def reshape_and_cache(
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
kv_cache_dtype: str = "auto",
kv_scale: float = 1.0,
):
if SYSTEM == "xpu":
ipex.llm.modules.PagedAttention.reshape_and_cache(
@@ -28,7 +31,7 @@ def reshape_and_cache(
)
else:
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0 key, value, key_cache, value_cache, slots, kv_cache_dtype, kv_scale
)
@@ -42,6 +45,8 @@ def attention(
block_tables: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
kv_cache_dtype: str = "auto",
kv_scale: float = 1.0,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights
@@ -99,8 +104,8 @@ def attention(
block_size,
max_s,
None,
"auto", kv_cache_dtype,
1.0, kv_scale,
)
else:
# Run PagedAttention V2.
@@ -132,6 +137,6 @@ def attention(
block_size,
max_s,
None,
"auto", kv_cache_dtype,
1.0, kv_scale,
)

View File

@@ -0,0 +1,48 @@
from typing import Optional, Tuple, Iterable

from loguru import logger

import json

from text_generation_server.layers.schema import QuantParamSchema


def kv_cache_scales_loader(
    filename: str,
    tp_rank: int,
    tp_size: int,
    num_hidden_layers: int,
    model_type: Optional[str],
) -> Iterable[Tuple[int, float]]:
    """
    A simple utility to read in KV cache scaling factors that have been
    previously serialized to disk. Used by the model to populate the appropriate
    KV cache scaling factors. The serialization should represent a dictionary
    whose keys are the TP ranks and values are another dictionary mapping layers
    to their KV cache scaling factors.
    """
    try:
        with open(filename) as f:
            context = {
                "model_type": model_type,
                "num_hidden_layers": num_hidden_layers,
                "tp_rank": tp_rank,
                "tp_size": tp_size,
            }
            schema_dct = json.load(f)
            schema = QuantParamSchema.model_validate(schema_dct, context=context)
            layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
            return layer_scales_map.items()
    except FileNotFoundError:
        logger.error(f"File or directory '{filename}' not found.")
    except json.JSONDecodeError:
        logger.error(f"Error decoding JSON in file '{filename}'.")
    except Exception as e:
        logger.error(f"An error occurred while reading '{filename}': {e}")
    # This section is reached if and only if any of the excepts are hit
    # Return an empty iterable (list) => no KV cache scales are loaded
    # which ultimately defaults to 1.0 scales
    logger.warning(
        "Defaulting to KV cache scaling factors = 1.0 "
        f"for all layers in TP rank {tp_rank} "
        "as an error occurred during loading."
    )
    return []
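A minimal usage sketch, mirroring the loop in FlashLlamaForCausalLM.load_kv_cache_scales above; the file path and model shape are hypothetical:

```python
from text_generation_server.utils.weights_utils import kv_cache_scales_loader

for layer_idx, scaling_factor in kv_cache_scales_loader(
    "/data/kv_cache_scales.json",  # hypothetical path
    tp_rank=0,
    tp_size=1,
    num_hidden_layers=32,
    model_type="llama",
):
    print(f"layer {layer_idx}: kv_scale = {scaling_factor}")
```

On any read or validation error the loader logs the problem and returns an empty list, so the attention layers keep their default scale of 1.0.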