mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
* Add support for FP8 KV cache scales. Since FP8 only has limited dynamic range, we can scale keys/values before storing them into the cache (and unscale them in attention). To avoid rescaling the cache as the absmax values change, good scales are usually determined per layer using calibration data and stored in the checkpoint. This change adds support for using key-value scales and loading them from checkpoints in the two most common formats: - Separate per-layer `k_scale` and `v_scale` scalars. - Per-layer `kv_scale` scalar (older format). Currently, scales are only used with a `float8_e4m3fn` cache. Besides adding support for key/value scales, the `fp8_quantize` function is also extended to support quantization with a kernel vendored from vLLM. This is slightly faster than the PyTorch implementation, but also scales in FP32, potentially improving accuracy. * Update FP8 KV cache test to use checkpoint with scales * `can_scale`: check that the attention is flashinfer
440 lines
15 KiB
Python
440 lines
15 KiB
Python
import torch
|
|
|
|
from abc import ABC, abstractmethod
|
|
from contextlib import contextmanager
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Tuple, Type, Union
|
|
from safetensors import safe_open
|
|
from dataclasses import dataclass
|
|
|
|
from text_generation_server.utils.import_utils import SYSTEM
|
|
|
|
|
|
class WeightsLoader(ABC):
    """
    Instances of this type implement higher-level weight loading.

    At a low level, every weight is stored in the Safetensors format.
    The interpretation of weights may differ, however; for instance, they
    could be packed, quantized weights. Loaders are responsible for
    interpreting the raw tensors, sharding tensors in a manner compatible
    with the format, etc.
    """

    @abstractmethod
    def get_weights(self, weights: "Weights", prefix: str):
        """
        Get weights at the given prefix without applying tensor parallelism.
        """
        ...

    @abstractmethod
    def get_weights_col_packed(
        self,
        weights: "Weights",
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        """
        Get the packed weights at the given prefix with column-splitting for
        tensor parallelism. This method should be used when multiple different
        weights are packed into a tensor, for instance, query/key/value
        weights or a gate/up projection.

        The `block_sizes` determines the proportions of the packed tensors.
        The columns are split in equally sized blocks when `block_sizes` is an
        `int`, or in blocks proportional to the given sizes. For instance
        `[2, 1, 1]` will divide an input with dimensionality `1024` in
        `[512, 256, 256]`.
        """
        ...

    def get_weights_col(self, weights: "Weights", prefix: str):
        """
        Get weights at the given prefix and apply column-splitting for tensor
        parallelism.

        This is the single-prefix case of `get_multi_weights_col`. It is
        dispatched to this loader directly rather than through the loader
        that is currently active on `weights`, so the result always uses the
        format of the loader this method is called on.
        """
        return self.get_multi_weights_col(weights, [prefix], 0)

    @abstractmethod
    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        """
        Get the weights at the given prefixes, column-split them for tensor
        parallelism, and then concatenate the weights along the given dimension.
        """
        ...

    @abstractmethod
    def get_weights_row(self, weights: "Weights", prefix: str):
        """
        Get the weights at the given prefix and apply row-splitting for tensor
        parallelism.
        """
        ...
|
|
|
|
|
|
class Weight(ABC):
    """
    Abstract base for a single weight, whether it is unquantized, already
    quantized, or meant to be quantized while it is loaded.
    """

    @abstractmethod
    def get_linear(self, bias: torch.Tensor):
        """Build and return a linear layer backed by this weight and ``bias``."""
|
|
|
|
|
|
@dataclass
class UnquantizedWeight(Weight):
    """A plain, unquantized weight tensor."""

    weight: torch.Tensor

    def get_linear(self, bias: torch.Tensor):
        """
        Wrap the weight (and ``bias``) in a fast linear layer, using the
        ROCm-specific implementation when running on ROCm.
        """
        from text_generation_server.layers.linear import FastLinear, FastLinearROCm

        linear_cls = FastLinearROCm if SYSTEM == "rocm" else FastLinear
        return linear_cls(self.weight, bias)
|
|
|
|
|
|
class DefaultWeightsLoader(WeightsLoader):
    """
    Weight loader that loads (unquantized) Torch tensors.

    Tensors are used as-is, with the exception of applying sharding and/or
    concatenation. The column/row/packed methods wrap their result in the
    `weight_class` passed to the constructor.
    """

    def __init__(self, weight_class: Type[UnquantizedWeight]):
        """
        Create a loader. Weights will be wrapped using the given
        `weight_class`; normally this is `UnquantizedWeight`, but a
        quantizer-specific class such as `Fp8Weight` can be used to quantize
        the weights during loading.
        """
        self.weight_class = weight_class

    def get_weights(self, weights: "Weights", prefix: str):
        """
        Return the full, unsharded `{prefix}.weight` tensor.

        NOTE(review): unlike the other methods of this loader, the tensor is
        returned as-is and is *not* wrapped in `weight_class`. Callers may
        rely on receiving a raw tensor — confirm before changing.
        """
        return weights.get_tensor(f"{prefix}.weight")

    def get_weights_col_packed(
        self,
        weights: "Weights",
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        """
        Column-shard the packed `{prefix}.weight` tensor so that every packed
        sub-tensor (split according to `block_sizes`) is sharded separately.
        """
        return self.weight_class(
            weights.get_packed_sharded(
                f"{prefix}.weight", dim=0, block_sizes=block_sizes
            ),
        )

    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        """
        Shard each `{p}.weight` along dim 0 and concatenate the shards
        along `dim`.
        """
        w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
        return self.weight_class(torch.cat(w, dim=dim))

    def get_weights_row(self, weights: "Weights", prefix: str):
        """Shard `{prefix}.weight` along dim 1 (row parallelism)."""
        return self.weight_class(
            weights.get_sharded(f"{prefix}.weight", dim=1),
        )
|
|
|
|
|
|
class Weights:
    """
    Name-based access to tensors stored in a set of safetensors files.

    Raw tensors can be loaded whole or sharded across `process_group`, and
    are cast to `dtype` and moved to `device` on request. Higher-level weight
    construction is delegated to the configured `weights_loader`.
    """

    # Tensor dtypes that are never cast to `self.dtype`: GPTQ stores u4
    # values disguised as int32, exl2 uses int16 and FP8 checkpoints use
    # float8_e4m3fn; int64 tensors are left untouched as well. Shared by all
    # file-based loading paths so that they agree on what gets cast.
    _NO_CAST_DTYPES = (
        torch.float8_e4m3fn,
        torch.int16,
        torch.int32,
        torch.int64,
    )

    def __init__(
        self,
        filenames: List[Path],
        device,
        dtype,
        process_group,
        weights_loader: WeightsLoader,
        aliases: Optional[Dict[str, List[str]]] = None,
        prefix: Optional[str] = None,
    ):
        """
        Index the tensors stored in `filenames`.

        Args:
            filenames: safetensors files; each tensor name must occur in
                exactly one of them.
            device: device that loaded tensors are moved to.
            dtype: dtype that loaded tensors are cast to (except the dtypes
                in `_NO_CAST_DTYPES`).
            process_group: group whose `size()`/`rank()` drive sharding.
            weights_loader: loader used by the `get_weights*` methods.
            aliases: optional mapping from a tensor name to alternative names
                that are tried when the name itself is not stored.
            prefix: optional prefix; `{prefix}.{name}` is also tried when
                looking up tensor names.

        Raises:
            RuntimeError: if a tensor name occurs in more than one file.
        """
        routing = {}
        for filename in filenames:
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename
        if aliases is None:
            aliases = {}
        self.aliases = aliases
        self.routing = routing
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        self.prefix = prefix
        self.weights_loader = weights_loader
        # Lazily opened safe_open handles, keyed by filename.
        self._handles = {}

    def _get_handle(self, filename):
        """Return the cached `safe_open` handle for `filename`, opening it on first use."""
        if filename not in self._handles:
            self._handles[filename] = safe_open(filename, framework="pytorch")
        return self._handles[filename]

    def _maybe_cast(self, tensor: torch.Tensor, to_dtype: bool) -> torch.Tensor:
        """Cast `tensor` to `self.dtype` when requested and not a quantizer dtype."""
        if to_dtype and tensor.dtype not in self._NO_CAST_DTYPES:
            tensor = tensor.to(dtype=self.dtype)
        return tensor

    def get_filename(self, tensor_name: str) -> Tuple[str, str]:
        """
        Resolve `tensor_name` to the file that stores it.

        The name is tried as-is and, when a prefix is set, as
        `{prefix}.{tensor_name}`. For each candidate, its aliases are tried
        when the candidate itself is not stored.

        Returns:
            `(filename, stored_name)`, where `stored_name` is the name under
            which the tensor was actually found.

        Raises:
            RuntimeError: if none of the candidate names is stored.
        """
        names = [tensor_name]
        if self.prefix is not None:
            names.append(f"{self.prefix}.{tensor_name}")
        for name in names:
            filename = self.routing.get(name, None)
            if filename is not None:
                return str(filename), name

            for alias in self.aliases.get(name, []):
                filename = self.routing.get(alias, None)
                if filename is not None:
                    return str(filename), alias
        raise RuntimeError(f"weight {tensor_name} does not exist")

    def _get_slice(self, tensor_name: str):
        """Return the lazy safetensors slice for `tensor_name`."""
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        return f.get_slice(tensor_name)

    def has_tensor(self, tensor_name: str):
        """Return whether `tensor_name` (or its prefixed/aliased form) is stored."""
        try:
            self.get_filename(tensor_name)
        except RuntimeError:
            return False
        return True

    def get_shape(self, tensor_name: str):
        """Return the stored shape of `tensor_name` without loading it."""
        return self._get_slice(tensor_name).get_shape()

    def get_tensor(
        self, tensor_name: str, to_device: bool = True, to_dtype: bool = True
    ) -> torch.Tensor:
        """
        Load the full tensor `tensor_name`.

        Args:
            to_device: move the tensor to `self.device`.
            to_dtype: cast to `self.dtype`, unless the stored dtype is one of
                `_NO_CAST_DTYPES`.
        """
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        tensor = f.get_tensor(tensor_name)
        tensor = self._maybe_cast(tensor, to_dtype)
        if to_device:
            tensor = tensor.to(device=self.device)
        return tensor

    def get_partial_sharded(
        self, tensor_name: str, dim: int, to_device=True, to_dtype=True
    ):
        """
        Load this rank's shard of `tensor_name` along `dim` (0 or 1).

        The dimension is split into blocks of `ceil(size / world_size)`, so
        when `size` is not divisible by the world size the trailing shard(s)
        cover fewer elements. Use `get_sharded` to require an even split.

        Raises:
            NotImplementedError: if `dim` is not 0 or 1.
        """
        slice_ = self._get_slice(tensor_name)
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        size = slice_.get_shape()[dim]
        block_size = (size + world_size - 1) // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size

        if dim == 0:
            tensor = slice_[start:stop]
        elif dim == 1:
            tensor = slice_[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        tensor = self._maybe_cast(tensor, to_dtype)
        if to_device:
            tensor = tensor.to(device=self.device)
        return tensor

    def get_sharded(self, tensor_name: str, dim: int, to_device=True, to_dtype=True):
        """
        Load this rank's shard of `tensor_name` along `dim`, requiring that
        the dimension divides evenly across the world size.

        Raises:
            AssertionError: if the dimension cannot be split evenly.
        """
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        size = slice_.get_shape()[dim]
        assert (
            size % world_size == 0
        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
        return self.get_partial_sharded(
            tensor_name, dim, to_device=to_device, to_dtype=to_dtype
        )

    def get_packed_sharded(
        self,
        tensor_name: str,
        dim: int,
        block_sizes: Union[int, List[int]],
        to_dtype=True,
    ) -> torch.Tensor:
        """
        Get a shard from a tensor that packs multiple tensors.

        When a tensor packs multiple tensors (such as QKV or an up
        projection + gate projection), sharding with `get_sharded` is not
        safe since it would not split the packed tensors across shards.

        This method shards a tensor, such that the packed tensors are
        split across shards.

        The columns are split in equally sized blocks when blocks is an `int`, or
        in blocks proportional to the given sizes. For instance `[2, 1, 1]` will
        divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
        convenient for e.g. splitting QKV without knowing the storage details of
        quantized weights.

        Raises:
            AssertionError: if a packed block cannot be split evenly.
            NotImplementedError: if `dim` is not 0 or 1.
        """
        slice_ = self._get_slice(tensor_name)
        total_size = slice_.get_shape()[dim]
        block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes)

        world_size = self.process_group.size()
        rank = self.process_group.rank()

        tensors = []
        block_offset = 0
        for block_size in block_sizes:
            assert (
                block_size % world_size == 0
            ), f"Prepacked tensor cannot be sharded across {world_size} shards"
            shard_block_size = block_size // world_size
            # Take this rank's slice from *within* the current packed block.
            start = block_offset + rank * shard_block_size
            stop = block_offset + (rank + 1) * shard_block_size
            if dim == 0:
                tensor = slice_[start:stop]
            elif dim == 1:
                tensor = slice_[:, start:stop]
            else:
                raise NotImplementedError("Currently only dim=0 or dim=1 is supported")
            tensors.append(tensor)
            block_offset += block_size
        tensor = torch.cat(tensors, dim=dim)
        tensor = tensor.to(device=self.device)

        # Avoid casting quantizer dtypes.
        return self._maybe_cast(tensor, to_dtype)

    def get_weights(self, prefix: str):
        """Load the weights at `prefix` through the active loader, without sharding."""
        return self.weights_loader.get_weights(self, prefix)

    def get_weights_col_packed_qkv(
        self,
        prefix: str,
        num_heads: int,
        num_key_value_heads: int,
    ):
        """
        Load a packed QKV weight, split proportionally to
        `[num_heads, num_key_value_heads, num_key_value_heads]`.
        """
        return self.get_weights_col_packed(
            prefix, [num_heads, num_key_value_heads, num_key_value_heads]
        )

    def get_weights_col_packed_gate_up(self, prefix: str):
        """Load a packed gate/up weight made of two equally sized blocks."""
        return self.get_weights_col_packed(prefix, 2)

    def get_weights_col_packed(self, prefix: str, block_sizes: Union[int, List[int]]):
        """
        The columns are split in equally sized blocks when blocks is an `int`, or
        in blocks proportional to the given sizes. For instance `[2, 1, 1]` will
        divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
        convenient for e.g. splitting QKV without knowing the storage details of
        quantized weights.
        """
        return self.weights_loader.get_weights_col_packed(self, prefix, block_sizes)

    def get_weights_col(self, prefix: str):
        """Load the weights at `prefix` column-split for tensor parallelism."""
        return self.weights_loader.get_weights_col(self, prefix)

    def get_multi_weights_col(self, prefixes: List[str], dim: int):
        """Load and column-split the weights at `prefixes`, concatenated along `dim`."""
        return self.weights_loader.get_multi_weights_col(self, prefixes, dim)

    def get_tensor_shard(self, var, dim):
        """
        Return this rank's shard of the in-memory tensor `var` along `dim`.

        Uses `size // world_size` blocks, so trailing elements are dropped
        when the size is not divisible. Unlike the file-based loaders, the
        result is always cast to `self.dtype`.

        Raises:
            NotImplementedError: if `dim` is not 0 or 1.
        """
        world_size = self.process_group.size()
        rank = self.process_group.rank()
        block_size = var.size()[dim] // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        if dim == 0:
            tensor = var[start:stop]
        elif dim == 1:
            tensor = var[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_weights_row(self, prefix: str):
        """Load the weights at `prefix` row-split for tensor parallelism."""
        return self.weights_loader.get_weights_row(self, prefix)

    @contextmanager
    def use_loader(self, weights_loader: WeightsLoader):
        """
        This method is a context manager that can be used to use `Weights` with
        a different loader for the duration of the context. The previous
        loader is restored on exit, including when an exception is raised.
        """
        old_loader = self.weights_loader
        self.weights_loader = weights_loader
        try:
            yield
        finally:
            self.weights_loader = old_loader

    @property
    def loader(self):
        """The currently active weights loader."""
        return self.weights_loader
|
|
|
|
|
|
def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
|
|
"""
|
|
Convert block count or proportions to block sizes.
|
|
|
|
This function accepts
|
|
|
|
- The number of blocks (int), in which case the block size is
|
|
total_size//blocks; or
|
|
- A list of block sizes (List[int]).
|
|
|
|
In the latter case, if sum(blocks) < total_size, the ratios between
|
|
the block sizes will be preserved. For instance, if blocks is
|
|
[2, 1, 1] and total_size is 1024, the returned block sizes are
|
|
[512, 256, 256].
|
|
"""
|
|
if isinstance(blocks, list):
|
|
total_blocks = sum(blocks)
|
|
assert (
|
|
total_size % total_blocks == 0
|
|
), f"Cannot split {total_size} in proportional blocks: {blocks}"
|
|
part_size = total_size // total_blocks
|
|
return [part_size * block for block in blocks]
|
|
else:
|
|
assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
|
|
single_size = total_size // blocks
|
|
return [single_size] * blocks
|