# What does this PR do?

Some models are already quantized but do not carry the `gptq_bits` / `gptq_groupsize` values in their safetensors files; this PR lets users load them with less friction by reading those values from environment variables.

I went for a purely env-based approach because adding CLI flags would end up (imo) very tedious to maintain. There would be a lot of validation to do: those flags would be errors if not used in conjunction with `--quantize gptq`, and they would need to exist in both the launcher and the server, passed along through every function call. This PR is intended as an easy escape hatch, not the de facto method to use gptq in TGI.

Fixes #500
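For reference, a minimal sketch of the escape hatch from the user's side (the concrete values are illustrative assumptions, not defaults):

```python
import os

# Hedged sketch: if the checkpoint does not contain the `gptq_bits` /
# `gptq_groupsize` tensors, the loader falls back to these environment
# variables (see `get_multi_weights_col` below). The values are illustrative.
os.environ["GPTQ_BITS"] = "4"
os.environ["GPTQ_GROUPSIZE"] = "128"
```

Both variables must parse as integers; if either is unset or malformed, the original `SafetensorError` is re-raised.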
159 lines · 5.9 KiB · Python
from pathlib import Path
from typing import List, Dict, Optional, Tuple

from safetensors import safe_open, SafetensorError
import torch


class Weights:
    """Routes tensor names to the safetensors files that contain them and
    loads tensors, optionally sharded across the process group."""

    def __init__(
        self,
        filenames: List[Path],
        device,
        dtype,
        process_group,
        aliases: Optional[Dict[str, List[str]]] = None,
    ):
        # Build a name -> file map, refusing keys duplicated across files.
        routing = {}
        for filename in filenames:
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename
        if aliases is None:
            aliases = {}
        self.aliases = aliases
        self.routing = routing
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        self._handles = {}

    def _get_handle(self, filename):
        if filename not in self._handles:
            f = safe_open(filename, framework="pytorch")
            self._handles[filename] = f

        return self._handles[filename]

    def get_filename(self, tensor_name: str) -> Tuple[str, str]:
        filename = self.routing.get(tensor_name, None)
        if filename is None:
            # Fall back to known aliases for this tensor name.
            aliases = self.aliases.get(tensor_name, [])
            for alias in aliases:
                filename = self.routing.get(alias, None)
                if filename is not None:
                    return str(filename), alias
            raise RuntimeError(f"weight {tensor_name} does not exist")
        return str(filename), tensor_name

    def _get_slice(self, tensor_name: str):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        return slice_

    def get_shape(self, tensor_name: str):
        return self._get_slice(tensor_name).get_shape()

    def get_tensor(self, tensor_name: str):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        tensor = f.get_tensor(tensor_name)
        # Special case for gptq: packed 4-bit weights are disguised as
        # int32 and must not be dtype-converted.
        if tensor.dtype not in [torch.int32, torch.int64]:
            tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor
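
    # Worked example (illustrative shapes, not from this PR): with
    # world_size = 2 and a tensor of shape (4096, 11008),
    # get_sharded(name, dim=1) returns columns [0, 5504) on rank 0 and
    # [5504, 11008) on rank 1; a size that does not divide evenly by
    # world_size trips the assert below.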

    def get_sharded(self, tensor_name: str, dim: int):
        filename, tensor_name = self.get_filename(tensor_name)
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        size = slice_.get_shape()[dim]
        block_size = size // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size

        assert (
            size % world_size == 0
        ), f"The chosen size {size} is not compatible with sharding on {world_size} shards"

        if dim == 0:
            tensor = slice_[start:stop]
        elif dim == 1:
            tensor = slice_[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        # Special case for gptq: packed 4-bit weights are disguised as
        # int32 and must not be dtype-converted.
        if tensor.dtype != torch.int32:
            tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor
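
    # For gptq checkpoints the weight is returned as a tuple
    # (qweight, qzeros, scales, g_idx, bits, groupsize) instead of a plain
    # tensor: qweight/qzeros/scales are sharded along dim=1 because
    # column-parallel layers split the output dimension, while g_idx must be
    # identical across all merged prefixes and is loaded whole.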

    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
        if quantize == "gptq":
            try:
                qweight = torch.cat(
                    [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
                )
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
                )

            qzeros = torch.cat(
                [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
            )
            scales = torch.cat(
                [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
            )
            w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
            for w2 in w[1:]:
                torch.testing.assert_close(w2, w[0])
            g_idx = w[0]

            try:
                bits = self.get_tensor("gptq_bits").item()
                groupsize = self.get_tensor("gptq_groupsize").item()
            except SafetensorError as e:
                # Escape hatch: fall back to env vars for already-converted
                # checkpoints that lack these tensors.
                try:
                    import os

                    bits = int(os.getenv("GPTQ_BITS"))
                    groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
                except Exception:
                    raise e
            weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
        else:
            w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
            weight = torch.cat(w, dim=dim)
        return weight

    def get_multi_weights_row(self, prefix: str, quantize: str):
        if quantize == "gptq":
            try:
                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
                )
            qzeros = self.get_tensor(f"{prefix}.qzeros")
            scales = self.get_tensor(f"{prefix}.scales")
            g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)

            bits = self.get_tensor("gptq_bits").item()
            groupsize = self.get_tensor("gptq_groupsize").item()

            weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
        else:
            weight = self.get_sharded(f"{prefix}.weight", dim=1)
        return weight
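
A hedged usage sketch of the class above (single process; `FakeGroup`, the file name, and the tensor names are illustrative assumptions, not part of this PR):

```python
from pathlib import Path

import torch

class FakeGroup:
    """Hypothetical stand-in for a torch.distributed process group (one shard)."""

    def size(self) -> int:
        return 1

    def rank(self) -> int:
        return 0

# Assumes the Weights class above is importable and the file exists on disk.
weights = Weights(
    filenames=[Path("model.safetensors")],  # assumed checkpoint file
    device=torch.device("cpu"),
    dtype=torch.float16,
    process_group=FakeGroup(),
)

# Full tensor, cast to float16 (int32/int64 gptq buffers are left untouched).
emb = weights.get_tensor("model.embed_tokens.weight")  # assumed tensor name

# Unquantized path: concatenates row-sharded weights; with a single shard
# this is just the full tensor.
q = weights.get_multi_weights_col(
    ["model.layers.0.self_attn.q_proj"], quantize=None, dim=0
)
```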