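"""Extract FP8 KV cache scales from a quantized checkpoint and add them to an
FP16 model.

Every `.kv_scale` tensor found in the FP8 checkpoint is attached to the
corresponding attribute of the FP16 model, and the result is saved alongside
its tokenizer.
"""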
import argparse
from pathlib import Path

import torch
from safetensors import safe_open
from transformers import AutoModelForCausalLM, AutoTokenizer

from text_generation_server.utils.hub import (
    download_weights,
    weight_files,
    weight_hub_files,
)


def load_model(ckpt_path):
    # Load the checkpoint in its stored dtype and shard it across the
    # available devices.
    model_args = {"torch_dtype": "auto"}
    model = AutoModelForCausalLM.from_pretrained(
        ckpt_path, device_map="auto", **model_args, trust_remote_code=True
    )
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)

    return model, tokenizer


def set_nested_attribute(obj, attribute_path, value):
    # Walk a dotted attribute path down to the parent object, then set the
    # final attribute on it.
    keys = attribute_path.split(".")
    current_obj = obj
    for key in keys[:-1]:
        current_obj = getattr(current_obj, key)
    setattr(current_obj, keys[-1], value)

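
# A minimal usage sketch (the attribute path below is hypothetical; real
# paths depend on the model architecture):
#
#     scale = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False)
#     set_nested_attribute(model, "model.layers.0.self_attn.kv_scale", scale)
#
# This resolves model.model -> layers -> "0" -> self_attn, then sets
# its kv_scale attribute.
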
def apply_kv_scales_to_model(model, layer_scales_map):
    # Attach each extracted scale to the model as a frozen parameter.
    for layer_name, scale_value in layer_scales_map.items():
        scale_param = torch.nn.Parameter(
            torch.tensor(scale_value), requires_grad=False
        )
        set_nested_attribute(model, layer_name, scale_param)


def extract_kv_scales(quantized_model):
    def fetch_parameters(filename):
        # Stream (name, tensor) pairs out of a safetensors file.
        with safe_open(filename, framework="pt") as f:
            for name in f.keys():
                param_tensor = f.get_tensor(name)
                yield name, param_tensor

    checkpoint_dir = Path(quantized_model)
    # If the argument is not a local directory, download the weights from
    # the Hub first; weight_files then resolves the local safetensors paths.
    if not checkpoint_dir.is_dir():
        hub_filenames = weight_hub_files(quantized_model)
        download_weights(hub_filenames, quantized_model)
    downloaded_files = weight_files(quantized_model, extension=".safetensors")

    layer_scales_map = {}
    for tensor_file in downloaded_files:
        for name, param in fetch_parameters(tensor_file):
            if ".kv_scale" in name:
                layer_scales_map[name] = param.item()

    return layer_scales_map

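
# The resulting map pairs fully qualified parameter names with scalar scales,
# e.g. (hypothetical names; the actual keys depend on the checkpoint):
#
#     {"model.layers.0.self_attn.kv_scale": 0.0213, ...}
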
def main(quantized_model, model_id, save_path):
    layer_scales_map = extract_kv_scales(quantized_model)

    model, tokenizer = load_model(model_id)

    apply_kv_scales_to_model(model, layer_scales_map)

    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)

    print(f"Model saved to {save_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Extract FP8 KV cache scales and add them to an FP16 model."
    )
    parser.add_argument(
        "--quantized-model",
        type=str,
        required=True,
        help="Path to the FP8 model checkpoint from which to extract the KV cache scales",
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Model ID of the FP16 model to which the KV cache scales are added",
    )
    parser.add_argument(
        "--save-path",
        type=str,
        required=True,
        help="Path at which to save the FP16 model with the KV scales",
    )

    args = parser.parse_args()

    main(args.quantized_model, args.model, args.save_path)
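
# Example invocation (the script name and model IDs below are illustrative
# assumptions, not taken from the repository):
#
#   python extract_fp8_kv_scales.py \
#       --quantized-model neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV \
#       --model meta-llama/Meta-Llama-3-8B-Instruct \
#       --save-path ./Meta-Llama-3-8B-Instruct-kv-scales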