# Mirror of https://github.com/huggingface/text-generation-inference.git
# Synced 2025-04-21 14:52:20 +00:00 · 145 lines · 6.0 KiB · Python
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
from safetensors import safe_open


class Weights:
    """Name-routed access to tensors stored across several safetensors files.

    At construction every file in ``filenames`` is scanned once to build a
    ``tensor name -> file`` routing table. File handles are then opened on
    first use and cached in ``self._handles`` (never closed by this class).
    Loaded tensors are cast to ``dtype`` (int32/int64 tensors excepted) and
    moved to ``device``. ``get_sharded`` and the ``get_multi_weights_*``
    helpers split tensors across ranks using ``process_group``, which must
    expose ``size()`` and ``rank()``.
    """

    # Message raised when GPTQ tensors (``{prefix}.qweight``) cannot be
    # loaded; the text itself suggests the checkpoint is not quantized.
    _GPTQ_MISSING_MESSAGE = (
        "Cannot load `gptq` weight, make sure the model is already quantized, "
        "or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
    )

    def __init__(
        self,
        filenames: List[Path],
        device,
        dtype,
        process_group,
        aliases: Optional[Dict[str, List[str]]] = None,
    ):
        """Index every tensor key found in ``filenames``.

        Args:
            filenames: safetensors files that together hold the checkpoint.
            device: target passed to ``Tensor.to(device=...)`` on load.
            dtype: target passed to ``Tensor.to(dtype=...)`` on load; int32
                and int64 tensors are left in their stored dtype.
            process_group: object exposing ``size()`` and ``rank()``, used
                to pick this rank's shard.
            aliases: optional map from a requested tensor name to fallback
                names, tried in order when the requested name is not found
                in any file.

        Raises:
            RuntimeError: if the same key appears in more than one file.
        """
        routing = {}
        for filename in filenames:
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename
        self.aliases = aliases if aliases is not None else {}
        self.routing = routing
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        # filename -> open safe_open handle, filled lazily by _get_handle.
        self._handles = {}

    def _get_handle(self, filename):
        """Return the cached ``safe_open`` handle for ``filename``, opening it on first use."""
        if filename not in self._handles:
            self._handles[filename] = safe_open(filename, framework="pytorch")
        return self._handles[filename]

    def get_filename(self, tensor_name: str) -> Tuple[str, str]:
        """Resolve ``tensor_name`` to ``(filename, stored_name)``.

        If ``tensor_name`` itself is not routed, each alias registered for it
        is tried in order, and the first routed alias is returned as
        ``stored_name``.

        Raises:
            RuntimeError: if neither the name nor any of its aliases exist.
        """
        filename = self.routing.get(tensor_name, None)
        if filename is not None:
            return str(filename), tensor_name
        for alias in self.aliases.get(tensor_name, []):
            filename = self.routing.get(alias, None)
            if filename is not None:
                return str(filename), alias
        raise RuntimeError(f"weight {tensor_name} does not exist")

    def _get_slice(self, tensor_name: str):
        """Return the safetensors slice object for ``tensor_name`` (aliases honoured)."""
        filename, tensor_name = self.get_filename(tensor_name)
        return self._get_handle(filename).get_slice(tensor_name)

    def get_shape(self, tensor_name: str):
        """Return the shape reported by the tensor's safetensors slice."""
        return self._get_slice(tensor_name).get_shape()

    def _to_target(self, tensor: torch.Tensor) -> torch.Tensor:
        """Cast ``tensor`` to ``self.dtype`` (int32/int64 excepted) and move it to ``self.device``."""
        # Special case for gptq which shouldn't convert
        # u4 which are disguised as int32; int64 tensors are likewise kept
        # as-is so integer indices are never rounded through a float dtype.
        if tensor.dtype not in (torch.int32, torch.int64):
            tensor = tensor.to(dtype=self.dtype)
        return tensor.to(device=self.device)

    def get_tensor(self, tensor_name: str):
        """Load the whole tensor ``tensor_name`` (aliases honoured), converted by ``_to_target``."""
        filename, tensor_name = self.get_filename(tensor_name)
        tensor = self._get_handle(filename).get_tensor(tensor_name)
        return self._to_target(tensor)

    def get_sharded(self, tensor_name: str, dim: int):
        """Load this rank's contiguous shard of ``tensor_name`` along ``dim``.

        The size along ``dim`` is split into ``world_size`` equal blocks and
        block number ``rank`` is returned, converted exactly like
        ``get_tensor`` (int32/int64 keep their dtype).

        Args:
            tensor_name: name (or alias) of the tensor.
            dim: axis to split; only 0 and 1 are supported.

        Raises:
            RuntimeError: if the tensor cannot be resolved.
            AssertionError: if the size along ``dim`` is not divisible by the
                world size.
            NotImplementedError: for ``dim`` other than 0 or 1.
        """
        slice_ = self._get_slice(tensor_name)
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        size = slice_.get_shape()[dim]
        assert (
            size % world_size == 0
        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
        block_size = size // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size

        if dim == 0:
            tensor = slice_[start:stop]
        elif dim == 1:
            tensor = slice_[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        return self._to_target(tensor)

    def _get_gptq_params(self):
        """Return ``(bits, groupsize)`` from the single-element ``gptq_bits`` / ``gptq_groupsize`` tensors."""
        bits = self.get_tensor("gptq_bits").item()
        groupsize = self.get_tensor("gptq_groupsize").item()
        return bits, groupsize

    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
        """Load column-parallel weights for several prefixes and fuse them.

        Unquantized: each ``{prefix}.weight`` is sharded on dim 0 and the
        shards are concatenated along ``dim``.

        GPTQ (``quantize`` is ``"gptq"`` or ``"gptq-cuda"``): ``qweight``,
        ``qzeros`` and ``scales`` are sharded on dim 1 and concatenated along
        dim 1 (``dim`` is not used on this path). ``g_idx`` is loaded whole
        per prefix and must match across prefixes (``torch.testing.assert_close``).
        Returns ``[qweight, qzeros, scales, g_idx, bits, groupsize]``.

        Raises:
            RuntimeError: if any ``qweight`` cannot be loaded (chained to the
                underlying error), or another tensor cannot be resolved.
        """
        if quantize in ["gptq", "gptq-cuda"]:
            try:
                qweight = torch.cat(
                    [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
                )
            except RuntimeError as err:
                raise RuntimeError(self._GPTQ_MISSING_MESSAGE) from err

            qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1)
            scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1)

            g_idxs = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
            for other in g_idxs[1:]:
                torch.testing.assert_close(other, g_idxs[0])
            g_idx = g_idxs[0]

            bits, groupsize = self._get_gptq_params()
            return [qweight, qzeros, scales, g_idx, bits, groupsize]

        shards = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
        return torch.cat(shards, dim=dim)

    def get_multi_weights_row(self, prefix: str, quantize: str):
        """Load the row-parallel weight for ``prefix``.

        Unquantized: ``{prefix}.weight`` sharded on dim 1.

        ``"gptq"``: ``qweight`` and ``g_idx`` are sharded on dim 0, while
        ``qzeros`` and ``scales`` are loaded whole.
        ``"gptq-cuda"``: ``qweight``, ``qzeros`` and ``scales`` are sharded on
        dim 0; ``g_idx`` is loaded only when the world size is 1, else None.
        Both GPTQ variants return ``[qweight, qzeros, scales, g_idx, bits, groupsize]``.

        Raises:
            RuntimeError: if ``qweight`` cannot be loaded (chained to the
                underlying error), or another tensor cannot be resolved.
        """
        if quantize in ["gptq", "gptq-cuda"]:
            try:
                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
            except RuntimeError as err:
                raise RuntimeError(self._GPTQ_MISSING_MESSAGE) from err

            if quantize == "gptq":
                qzeros = self.get_tensor(f"{prefix}.qzeros")
                scales = self.get_tensor(f"{prefix}.scales")
                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
            else:
                # Exllama reorders the weights in advance and the activations on the fly, thus
                # the scales and zero-points do not need to be reordered
                qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
                scales = self.get_sharded(f"{prefix}.scales", dim=0)

                # For tp > 1, at this point we know we do not use act-order
                if self.process_group.size() == 1:
                    g_idx = self.get_tensor(f"{prefix}.g_idx")
                else:
                    g_idx = None

            bits, groupsize = self._get_gptq_params()
            return [qweight, qzeros, scales, g_idx, bits, groupsize]

        return self.get_sharded(f"{prefix}.weight", dim=1)