From e349f57d101df5d31b13926fab119afe6d0bd975 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 8 Sep 2023 14:36:49 +0000
Subject: [PATCH] Update solution to account for GPTQ.

---
 .../custom_modeling/flash_llama_modeling.py     | 55 +++++++---------
 .../text_generation_server/utils/convert.py    | 12 +++-
 server/text_generation_server/utils/layers.py  | 13 ++++
 .../text_generation_server/utils/weights.py    | 64 ++++++++++++++++++-
 4 files changed, 106 insertions(+), 38 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 41db4fa7..55b1aae9 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -149,6 +149,26 @@ class LlamaRMSNorm(nn.Module):
 
         return normed_hidden_states, res
 
+def load_attention(config, prefix, weights):
+    if config.num_attention_heads != config.num_key_value_heads:
+        return _load_gqa(config, prefix, weights)
+    else:
+        if config.model_type == "baichuan":
+            return TensorParallelColumnLinear.load_qkv(
+                config,
+                prefix=f"{prefix}.W_pack",
+                weights=weights,
+                bias=False,
+            )
+        else:
+            return TensorParallelColumnLinear.load_multi(
+                config,
+                prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+                dim=0,
+                weights=weights,
+                bias=False,
+            )
+
 
 def _load_gqa(config, prefix: str, weights):
     assert config.hidden_size % config.num_attention_heads == 0
     assert config.num_attention_heads % weights.process_group.size() == 0
@@ -205,39 +225,8 @@ class FlashLlamaAttention(torch.nn.Module):
         self.num_key_value_heads = (
             config.num_key_value_heads // weights.process_group.size()
         )
-        if config.num_attention_heads != config.num_key_value_heads:
-            self.query_key_value = _load_gqa(config, prefix, weights)
-        else:
-            try:
-                def load_packed(cls, config, prefix: str, weights, bias: bool):
-                    packed_tensor = weights.get_tensor(prefix, to_device=False)
-                    #QKV
-                    total_size = packed_tensor.size()[0]
-                    single_size = total_size // 3
-                    q_tensor = packed_tensor[0: single_size, :]
-                    k_tensor = packed_tensor[single_size: single_size * 2, :]
-                    v_tensor = packed_tensor[single_size * 2 : total_size, :]
-                    q_weight = weights.get_tensor_shard(q_tensor, dim = 0)
-                    k_weight = weights.get_tensor_shard(k_tensor, dim = 0)
-                    v_weight = weights.get_tensor_shard(v_tensor, dim = 0)
-                    weight = torch.concat([q_weight, k_weight, v_weight], dim = 0)
-                    return cls(get_linear(weight, None, config.quantize))
-
-                self.query_key_value = load_packed(TensorParallelColumnLinear,
-                    config,
-                    prefix=f"{prefix}.W_pack.weight",
-                    weights=weights,
-                    bias=False,
-                )
-
-            except:
-                self.query_key_value = TensorParallelColumnLinear.load_multi(
-                    config,
-                    prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-                    dim=0,
-                    weights=weights,
-                    bias=False,
-                )
+
+        self.query_key_value = load_attention(config, prefix, weights)
 
         self.o_proj = TensorParallelRowLinear.load(
             config,
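Note on the layout assumed above: Baichuan checkpoints store the three attention projections as a single W_pack weight of shape [3 * hidden_size, hidden_size] (Q, then K, then V along dim 0) rather than separate q_proj/k_proj/v_proj tensors, which is why load_attention dispatches on config.model_type. A minimal standalone sketch of that layout (plain PyTorch, sizes illustrative, not code from this patch):

import torch

hidden_size = 8
q = torch.randn(hidden_size, hidden_size)
k = torch.randn(hidden_size, hidden_size)
v = torch.randn(hidden_size, hidden_size)

# Baichuan-style packing: one [3 * hidden, hidden] matrix, Q then K then V.
w_pack = torch.cat([q, k, v], dim=0)

# Recovering the individual projections is a plain slice along dim 0.
q2, k2, v2 = w_pack.chunk(3, dim=0)
assert torch.equal(q, q2) and torch.equal(k, k2) and torch.equal(v, v2)
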
diff --git a/server/text_generation_server/utils/convert.py b/server/text_generation_server/utils/convert.py
index 8d414eca..0b62f520 100644
--- a/server/text_generation_server/utils/convert.py
+++ b/server/text_generation_server/utils/convert.py
@@ -29,9 +29,15 @@ def _remove_duplicate_names(
         [name for name in shared if _is_complete(state_dict[name])]
     )
     if not complete_names:
-        raise RuntimeError(
-            f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
-        )
+        if len(shared) == 1:
+            # Force contiguous
+            name = list(shared)[0]
+            state_dict[name] = state_dict[name].clone()
+            complete_names = {name}
+        else:
+            raise RuntimeError(
+                f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage. Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
+            )
 
     keep_name = sorted(list(complete_names))[0]
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 745c1d2e..83c02459 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -324,6 +324,19 @@ class TensorParallelHead(SuperLayer):
 
 
 class TensorParallelColumnLinear(SuperLayer):
+    @classmethod
+    def load_qkv(cls, config, prefix: str, weights, bias: bool):
+        """Specific method when the QKV was joined after the fact"""
+        weight = weights.get_weights_col_packed_qkv(
+            prefix, quantize=config.quantize
+        )
+        if bias:
+            raise NotImplementedError("packed_qkv only implemented for baichuan")
+        else:
+            bias = None
+        linear = get_linear(weight, bias, config.quantize)
+        return cls(linear)
+
     @classmethod
     def load(cls, config, prefix: str, weights, bias: bool):
         return cls.load_multi(config, [prefix], weights, bias, dim=0)
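The convert.py hunk above handles checkpoints where exactly one name is left sharing a storage and no tensor covers that storage completely (for instance, the named tensor is a view into a larger buffer); cloning materializes an independent, contiguous copy that safetensors will accept. A rough illustration of the failure mode and the fix (standalone PyTorch, not the convert.py code itself):

import torch

big = torch.arange(10.0)
view = big[2:6]  # shares big's storage but only covers part of it

# No name covers the whole storage, so safetensors would refuse to
# serialize `view` as-is. Cloning gives it a fresh buffer of exactly
# its own size, which is what the patch does when `shared` contains
# a single name.
fixed = view.clone()
assert fixed.storage_offset() == 0 and fixed.is_contiguous()
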
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 25d7ff1d..bde23f72 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -76,11 +76,11 @@ class Weights:
 
     def get_partial_sharded(self, tensor_name: str, dim: int):
         filename, tensor_name = self.get_filename(tensor_name)
+        f = self._get_handle(filename)
+        slice_ = f.get_slice(tensor_name)
         world_size = self.process_group.size()
         rank = self.process_group.rank()
 
-        f = self._get_handle(filename)
-        slice_ = f.get_slice(tensor_name)
         size = slice_.get_shape()[dim]
         block_size = size // world_size
         start = rank * block_size
@@ -110,6 +110,66 @@ class Weights:
             ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
         return self.get_partial_sharded(tensor_name, dim)
 
+    def _get_qweight(self, name: str):
+        slice_ = self._get_slice(name)
+        total_size = slice_.get_shape()[1]
+        assert total_size % 3 == 0, "Prepacked quantized qkv is not divisible by 3"
+        single_size = total_size // 3
+        world_size = self.process_group.size()
+        rank = self.process_group.rank()
+
+        assert (
+            single_size % world_size == 0
+        ), f"Prepacked quantized qkv cannot be sharded across {world_size} shards"
+        block_size = single_size // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        q = slice_[:, start:stop]
+        k = slice_[:, start + single_size : stop + single_size]
+        v = slice_[:, start + 2 * single_size : stop + 2 * single_size]
+        weight = torch.cat([q, k, v], dim=1)
+        weight = weight.to(device=self.device)
+        return weight
+
+    def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
+        """
+        Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
+        already alternating Q,K,V within the main tensor
+        """
+        if quantize == "gptq":
+            try:
+                qweight = self._get_qweight(f"{prefix}.qweight")
+            except RuntimeError:
+                raise RuntimeError(
+                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+                )
+
+            qzeros = self._get_qweight(f"{prefix}.qzeros")
+            scales = self._get_qweight(f"{prefix}.scales")
+            scales = scales.to(dtype=self.dtype)
+            g_idx = self.get_tensor(f"{prefix}.g_idx")
+
+            bits, groupsize = self._get_gptq_params()
+            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
+        else:
+            slice_ = self._get_slice(f"{prefix}.weight")
+            total_size = slice_.get_shape()[0]
+            assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3"
+            single_size = total_size // 3
+            world_size = self.process_group.size()
+            rank = self.process_group.rank()
+
+            assert (
+                single_size % world_size == 0
+            ), f"Prepacked qkv cannot be sharded across {world_size} shards"
+            block_size = single_size // world_size
+            start = rank * block_size
+            stop = (rank + 1) * block_size
+            q = slice_[start:stop]
+            k = slice_[start + single_size : stop + single_size]
+            v = slice_[start + 2 * single_size : stop + 2 * single_size]
+            weight = torch.cat([q, k, v], dim=0)
+            weight = weight.to(device=self.device)
+            weight = weight.to(dtype=self.dtype)
+        return weight
+
     def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
         if quantize == "gptq":
             try:
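For reference, the sharding in get_weights_col_packed_qkv amounts to: split the packed dimension into three equal parts, take this rank's contiguous block from each part, and concatenate the three blocks back together, so every tensor-parallel rank ends up with matching Q, K and V rows. A single-process sketch of that arithmetic (plain PyTorch; world_size and rank stand in for the process group):

import torch

hidden_size = 12
w_pack = torch.randn(3 * hidden_size, hidden_size)  # Q, K, V stacked on dim 0

world_size, rank = 2, 0  # stand-ins for self.process_group
single_size = w_pack.shape[0] // 3
assert single_size % world_size == 0
block_size = single_size // world_size
start, stop = rank * block_size, (rank + 1) * block_size

# Take this rank's rows from each third, then re-pack them.
q = w_pack[start:stop]
k = w_pack[start + single_size : stop + single_size]
v = w_pack[start + 2 * single_size : stop + 2 * single_size]
shard = torch.cat([q, k, v], dim=0)
assert shard.shape == (3 * hidden_size // world_size, hidden_size)

For GPTQ checkpoints the same block arithmetic runs along dim 1 instead, because qweight, qzeros and scales store output features as packed columns; that is why _get_qweight slices slice_[:, start:stop] and concatenates on dim=1.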