mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
Merge branch 'main' into quantization_docs
This commit is contained in:
commit
6473cf852e
@ -23,6 +23,8 @@
|
||||
title: Streaming
|
||||
- local: conceptual/quantization
|
||||
title: Quantization
|
||||
- local: conceptual/tensor_parallelism
|
||||
title: Tensor Parallelism
|
||||
- local: conceptual/paged_attention
|
||||
title: PagedAttention
|
||||
- local: conceptual/safetensors
|
||||
|
14
docs/source/conceptual/tensor_parallelism.md
Normal file
14
docs/source/conceptual/tensor_parallelism.md
Normal file
@ -0,0 +1,14 @@
|
||||
# Tensor Parallelism
|
||||
|
||||
Tensor parallelism is a technique used to fit a large model in multiple GPUs. For example, when multiplying the input tensors with the first weight tensor, the matrix multiplication is equivalent to splitting the weight tensor column-wise, multiplying each column with the input separately, and then concatenating the separate outputs. These outputs are then transferred from the GPUs and concatenated together to get the final result, like below 👇
|
||||
|
||||

|
||||
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Tensor Parallelism only works for [models officially supported](../supported_models), it will not work when falling back to `transformers`. You can get more information about unsupported models [here](../basic_tutorials/non_core_models).
|
||||
|
||||
</Tip>
|
||||
|
||||
You can learn a lot more details about tensor-parallelism from [the `transformers` docs](https://huggingface.co/docs/transformers/main/en/perf_train_gpu_many#tensor-parallelism).
|
@ -182,7 +182,7 @@ def get_model(
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
elif model_type == "llama":
|
||||
elif model_type == "llama" or model_type == "baichuan":
|
||||
if FLASH_ATTENTION:
|
||||
return FlashLlama(
|
||||
model_id,
|
||||
|
@ -149,6 +149,26 @@ class LlamaRMSNorm(nn.Module):
|
||||
return normed_hidden_states, res
|
||||
|
||||
|
||||
def load_attention(config, prefix, weights):
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
return _load_gqa(config, prefix, weights)
|
||||
else:
|
||||
if config.model_type == "baichuan":
|
||||
return TensorParallelColumnLinear.load_qkv(
|
||||
config,
|
||||
prefix=f"{prefix}.W_pack",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
else:
|
||||
return TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def _load_gqa(config, prefix: str, weights):
|
||||
assert config.hidden_size % config.num_attention_heads == 0
|
||||
assert config.num_attention_heads % weights.process_group.size() == 0
|
||||
@ -205,16 +225,9 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||
self.num_key_value_heads = (
|
||||
config.num_key_value_heads // weights.process_group.size()
|
||||
)
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
self.query_key_value = _load_gqa(config, prefix, weights)
|
||||
else:
|
||||
self.query_key_value = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.query_key_value = load_attention(config, prefix, weights)
|
||||
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
|
@ -29,9 +29,15 @@ def _remove_duplicate_names(
|
||||
[name for name in shared if _is_complete(state_dict[name])]
|
||||
)
|
||||
if not complete_names:
|
||||
raise RuntimeError(
|
||||
f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
|
||||
)
|
||||
if len(shared) == 1:
|
||||
# Force contiguous
|
||||
name = list(shared)[0]
|
||||
state_dict[name] = state_dict[name].clone()
|
||||
complete_names = {name}
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
|
||||
)
|
||||
|
||||
keep_name = sorted(list(complete_names))[0]
|
||||
|
||||
|
@ -331,6 +331,19 @@ class TensorParallelHead(SuperLayer):
|
||||
|
||||
|
||||
class TensorParallelColumnLinear(SuperLayer):
|
||||
@classmethod
|
||||
def load_qkv(cls, config, prefix: str, weights, bias: bool):
|
||||
"""Specific method when the QKV was joined after the fact"""
|
||||
weight = weights.get_weights_col_packed_qkv(
|
||||
prefix, quantize=config.quantize
|
||||
)
|
||||
if bias:
|
||||
raise NotImplementedError("packed_qkv only implemented for baichuan")
|
||||
else:
|
||||
bias = None
|
||||
linear = get_linear(weight, bias, config.quantize)
|
||||
return cls(linear)
|
||||
|
||||
@classmethod
|
||||
def load(cls, config, prefix: str, weights, bias: bool):
|
||||
return cls.load_multi(config, [prefix], weights, bias, dim=0)
|
||||
|
@ -62,7 +62,7 @@ class Weights:
|
||||
def get_shape(self, tensor_name: str):
|
||||
return self._get_slice(tensor_name).get_shape()
|
||||
|
||||
def get_tensor(self, tensor_name: str):
|
||||
def get_tensor(self, tensor_name: str, to_device = True):
|
||||
filename, tensor_name = self.get_filename(tensor_name)
|
||||
f = self._get_handle(filename)
|
||||
tensor = f.get_tensor(tensor_name)
|
||||
@ -70,16 +70,17 @@ class Weights:
|
||||
# u4 which are disguised as int32
|
||||
if tensor.dtype not in [torch.int32, torch.int64]:
|
||||
tensor = tensor.to(dtype=self.dtype)
|
||||
tensor = tensor.to(device=self.device)
|
||||
if to_device:
|
||||
tensor = tensor.to(device=self.device)
|
||||
return tensor
|
||||
|
||||
def get_partial_sharded(self, tensor_name: str, dim: int):
|
||||
filename, tensor_name = self.get_filename(tensor_name)
|
||||
f = self._get_handle(filename)
|
||||
slice_ = f.get_slice(tensor_name)
|
||||
world_size = self.process_group.size()
|
||||
rank = self.process_group.rank()
|
||||
|
||||
f = self._get_handle(filename)
|
||||
slice_ = f.get_slice(tensor_name)
|
||||
size = slice_.get_shape()[dim]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
@ -109,6 +110,66 @@ class Weights:
|
||||
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
|
||||
return self.get_partial_sharded(tensor_name, dim)
|
||||
|
||||
|
||||
def _get_qweight(self, name: str):
|
||||
slice_ = self._get_slice(name)
|
||||
total_size = slice_.get_shape()[1]
|
||||
assert total_size % 3 == 0, "Prepacked quantized qkv is not divisible by 3"
|
||||
single_size = total_size // 3
|
||||
world_size = self.process_group.size()
|
||||
rank = self.process_group.rank()
|
||||
|
||||
assert single_size % world_size == 0, f"Prepacked quantized qkv cannot be sharded across {world_size} shards"
|
||||
block_size = single_size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
q = slice_[:, start:stop]
|
||||
k = slice_[:, start+single_size:stop+single_size]
|
||||
v = slice_[:, start+2*single_size:stop+2*single_size]
|
||||
weight = torch.cat([q,k,v], dim=1)
|
||||
weight = weight.to(device=self.device)
|
||||
return weight
|
||||
|
||||
def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
|
||||
"""
|
||||
Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
|
||||
already alternating Q,K,V within the main tensor
|
||||
"""
|
||||
if quantize == "gptq":
|
||||
try:
|
||||
qweight = self._get_qweight(f"{prefix}.qweight")
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
||||
)
|
||||
|
||||
qzeros = self._get_qweight(f"{prefix}.qzeros")
|
||||
scales = self._get_qweight(f"{prefix}.scales")
|
||||
scales = scales.to(dtype=self.dtype)
|
||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||
|
||||
bits, groupsize = self._get_gptq_params()
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
|
||||
else:
|
||||
slice_ = self._get_slice(f"{prefix}.weight")
|
||||
total_size = slice_.get_shape()[0]
|
||||
assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3"
|
||||
single_size = total_size // 3
|
||||
world_size = self.process_group.size()
|
||||
rank = self.process_group.rank()
|
||||
|
||||
assert single_size % world_size == 0, f"Prepacked qkv cannot be sharded across {world_size} shards"
|
||||
block_size = single_size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
q = slice_[start:stop]
|
||||
k = slice_[start+single_size:stop+single_size]
|
||||
v = slice_[start+2*single_size:stop+2*single_size]
|
||||
weight = torch.cat([q,k,v], dim=0)
|
||||
weight = weight.to(device=self.device)
|
||||
weight = weight.to(dtype=self.dtype)
|
||||
return weight
|
||||
|
||||
def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
|
||||
if quantize == "gptq":
|
||||
try:
|
||||
@ -137,6 +198,22 @@ class Weights:
|
||||
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
||||
weight = torch.cat(w, dim=dim)
|
||||
return weight
|
||||
|
||||
def get_tensor_shard(self, var, dim):
|
||||
world_size = self.process_group.size()
|
||||
rank = self.process_group.rank()
|
||||
block_size = var.size()[dim] // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
if dim == 0:
|
||||
tensor = var[start:stop]
|
||||
elif dim == 1:
|
||||
tensor = var[:, start:stop]
|
||||
else:
|
||||
raise NotImplementedError("Let's make that generic when needed")
|
||||
tensor = tensor.to(dtype=self.dtype)
|
||||
tensor = tensor.to(device=self.device)
|
||||
return tensor
|
||||
|
||||
def get_multi_weights_row(self, prefix: str, quantize: str):
|
||||
if quantize == "gptq":
|
||||
|
Loading…
Reference in New Issue
Block a user