mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 14:52:20 +00:00
49 lines
1.8 KiB
Python
49 lines
1.8 KiB
Python
|
from typing import Optional, Tuple, Iterable
|
||
|
from loguru import logger
|
||
|
import json
|
||
|
from text_generation_server.layers.schema import QuantParamSchema
|
||
|
|
||
|
|
||
|
def kv_cache_scales_loader(
    filename: str,
    tp_rank: int,
    tp_size: int,
    num_hidden_layers: int,
    model_type: Optional[str],
) -> Iterable[Tuple[int, float]]:
    """
    A simple utility to read in KV cache scaling factors that have been
    previously serialized to disk. Used by the model to populate the appropriate
    KV cache scaling factors. The serialization should represent a dictionary
    whose keys are the TP ranks and values are another dictionary mapping layers
    to their KV cache scaling factors.

    Args:
        filename: Path to the JSON file holding the serialized scales.
        tp_rank: Tensor-parallel rank whose scales are returned; also passed
            to schema validation.
        tp_size: Tensor-parallel world size; passed to schema validation.
        num_hidden_layers: Number of hidden layers; passed to schema validation.
        model_type: Model type (may be None); passed to schema validation.

    Returns:
        The ``(layer, scaling_factor)`` pairs stored for ``tp_rank`` on
        success. If the file cannot be found, parsed, or validated, the error
        is logged and an empty list is returned so that no scales are loaded
        (which ultimately defaults to 1.0 scales).
    """
    # Extra information handed to the schema validator so it can check the
    # file against the running model / tensor-parallel configuration.
    context = {
        "model_type": model_type,
        "num_hidden_layers": num_hidden_layers,
        "tp_rank": tp_rank,
        "tp_size": tp_size,
    }
    try:
        # JSON is UTF-8 by specification; don't depend on the locale's default
        # encoding. Keep the file open only as long as parsing needs it.
        with open(filename, encoding="utf-8") as f:
            schema_dct = json.load(f)
        schema = QuantParamSchema.model_validate(schema_dct, context=context)
        layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
        return layer_scales_map.items()
    except FileNotFoundError:
        logger.error(f"File or directory '{filename}' not found.")
    except json.JSONDecodeError:
        logger.error(f"Error decoding JSON in file '{filename}'.")
    except Exception as e:
        # Deliberate best-effort: any other failure (I/O, decoding, schema
        # validation, missing rank) is logged and falls through to defaults.
        logger.error(f"An error occurred while reading '{filename}': {e}")

    # This section is reached if and only if any of the excepts are hit
    # Return an empty iterable (list) => no KV cache scales are loaded
    # which ultimately defaults to 1.0 scales
    logger.warning(
        "Defaulting to KV cache scaling factors = 1.0 "
        f"for all layers in TP rank {tp_rank} "
        "as an error occurred during loading."
    )
    return []
|