text-generation-inference/server/text_generation_server/utils/weights_utils.py

49 lines
1.8 KiB
Python
Raw Normal View History

2024-05-23 16:00:18 +00:00
from typing import Optional, Tuple, Iterable
from loguru import logger
import json
from text_generation_server.layers.schema import QuantParamSchema
def kv_cache_scales_loader(
filename: str,
tp_rank: int,
tp_size: int,
num_hidden_layers: int,
model_type: Optional[str],
) -> Iterable[Tuple[int, float]]:
"""
A simple utility to read in KV cache scaling factors that have been
previously serialized to disk. Used by the model to populate the appropriate
KV cache scaling factors. The serialization should represent a dictionary
whose keys are the TP ranks and values are another dictionary mapping layers
to their KV cache scaling factors.
"""
try:
with open(filename) as f:
context = {
"model_type": model_type,
"num_hidden_layers": num_hidden_layers,
"tp_rank": tp_rank,
"tp_size": tp_size,
}
schema_dct = json.load(f)
schema = QuantParamSchema.model_validate(schema_dct, context=context)
layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
return layer_scales_map.items()
except FileNotFoundError:
logger.error(f"File or directory '{filename}' not found.")
except json.JSONDecodeError:
logger.error(f"Error decoding JSON in file '{filename}'.")
except Exception as e:
logger.error(f"An error occurred while reading '{filename}': {e}")
# This section is reached if and only if any of the excepts are hit
# Return an empty iterable (list) => no KV cache scales are loaded
# which ultimately defaults to 1.0 scales
logger.warning(
"Defaulting to KV cache scaling factors = 1.0 "
f"for all layers in TP rank {tp_rank} "
"as an error occurred during loading."
)
return []