Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
Update solution to account for GPTQ.
parent 2a1f306e26
commit e349f57d10
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -149,6 +149,26 @@ class LlamaRMSNorm(nn.Module):
         return normed_hidden_states, res
 
 
+def load_attention(config, prefix, weights):
+    if config.num_attention_heads != config.num_key_value_heads:
+        return _load_gqa(config, prefix, weights)
+    else:
+        if config.model_type == "baichuan":
+            return TensorParallelColumnLinear.load_qkv(
+                config,
+                prefix=f"{prefix}.W_pack",
+                weights=weights,
+                bias=False,
+            )
+        else:
+            return TensorParallelColumnLinear.load_multi(
+                config,
+                prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+                dim=0,
+                weights=weights,
+                bias=False,
+            )
+
+
 def _load_gqa(config, prefix: str, weights):
     assert config.hidden_size % config.num_attention_heads == 0
     assert config.num_attention_heads % weights.process_group.size() == 0
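
For context, Baichuan checkpoints ship the attention projections as a single `W_pack` tensor instead of separate `q_proj`/`k_proj`/`v_proj` weights, which is why `load_attention` branches on `config.model_type`. A minimal plain-PyTorch sketch of that layout (toy sizes, no TGI code involved):

import torch

# Toy sizes; real checkpoints use e.g. hidden = 4096.
hidden = 8
q = torch.randn(hidden, hidden)
k = torch.randn(hidden, hidden)
v = torch.randn(hidden, hidden)

# Baichuan's W_pack is the row-wise concatenation of the three projections.
w_pack = torch.cat([q, k, v], dim=0)  # shape [3 * hidden, hidden]

# Slicing thirds along dim 0 recovers them exactly.
assert torch.equal(w_pack[:hidden], q)
assert torch.equal(w_pack[hidden:2 * hidden], k)
assert torch.equal(w_pack[2 * hidden:], v)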
@@ -205,39 +225,8 @@ class FlashLlamaAttention(torch.nn.Module):
         self.num_key_value_heads = (
             config.num_key_value_heads // weights.process_group.size()
         )
-        if config.num_attention_heads != config.num_key_value_heads:
-            self.query_key_value = _load_gqa(config, prefix, weights)
-        else:
-            try:
-                def load_packed(cls, config, prefix: str, weights, bias: bool):
-                    packed_tensor = weights.get_tensor(prefix, to_device=False)
-                    #QKV
-                    total_size = packed_tensor.size()[0]
-                    single_size = total_size // 3
-                    q_tensor = packed_tensor[0: single_size, :]
-                    k_tensor = packed_tensor[single_size: single_size * 2, :]
-                    v_tensor = packed_tensor[single_size * 2 : total_size, :]
-                    q_weight = weights.get_tensor_shard(q_tensor, dim = 0)
-                    k_weight = weights.get_tensor_shard(k_tensor, dim = 0)
-                    v_weight = weights.get_tensor_shard(v_tensor, dim = 0)
-                    weight = torch.concat([q_weight, k_weight, v_weight], dim = 0)
-                    return cls(get_linear(weight, None, config.quantize))
-
-                self.query_key_value = load_packed(TensorParallelColumnLinear,
-                    config,
-                    prefix=f"{prefix}.W_pack.weight",
-                    weights=weights,
-                    bias=False,
-                )
-
-            except:
-                self.query_key_value = TensorParallelColumnLinear.load_multi(
-                    config,
-                    prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-                    dim=0,
-                    weights=weights,
-                    bias=False,
-                )
-
+        self.query_key_value = load_attention(config, prefix, weights)
 
         self.o_proj = TensorParallelRowLinear.load(
             config,
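
The deleted `try`/`except` path sliced the raw `W_pack` tensor into thirds along dim 0, which only works for unquantized weights. A GPTQ checkpoint stores `qweight` with the packed input features on dim 0 and the output features on dim 1, so the Q/K/V thirds live along dim 1 instead; that is the case this commit moves into `Weights.get_weights_col_packed_qkv`. A small sketch of the shape mismatch, assuming the usual 4-bit-into-int32 GPTQ packing:

import torch

hidden = 8
pack_factor = 32 // 4  # 8 four-bit values per int32 (assumed GPTQ packing)

# Unquantized packed QKV: output features on dim 0.
w = torch.randn(3 * hidden, hidden)
print(w[:hidden].shape)  # torch.Size([8, 8]) -> a correct Q slice

# GPTQ-style qweight: output features on dim 1.
qweight = torch.zeros(hidden // pack_factor, 3 * hidden, dtype=torch.int32)
print(qweight[:hidden].shape)     # torch.Size([1, 24]) -> NOT a Q slice
print(qweight[:, :hidden].shape)  # torch.Size([1, 8])  -> Q must be cut on dim 1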
server/text_generation_server/utils/convert.py
@@ -29,9 +29,15 @@ def _remove_duplicate_names(
         [name for name in shared if _is_complete(state_dict[name])]
     )
     if not complete_names:
-        raise RuntimeError(
-            f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage. Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
-        )
+        if len(shared) == 1:
+            # Force contiguous
+            name = list(shared)[0]
+            state_dict[name] = state_dict[name].clone()
+            complete_names = {name}
+        else:
+            raise RuntimeError(
+                f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage. Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
+            )
 
     keep_name = sorted(list(complete_names))[0]
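
This hunk handles checkpoints where a single name points at a view that does not cover its whole storage and no alias is complete: cloning the tensor gives it its own exactly-sized storage so it can be saved. A standalone sketch of that force-contiguous step (assumes a recent PyTorch with `untyped_storage`):

import torch

big = torch.zeros(4, 4)   # 16 float32 elements of backing storage
view = big[:2]            # a view covering only half of that storage
state_dict = {"weight": view}

# The view shares the 16-element storage; clone so the saved tensor
# owns exactly the 8 elements it needs.
name = "weight"
state_dict[name] = state_dict[name].clone()

assert state_dict[name].untyped_storage().nbytes() == state_dict[name].numel() * 4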
server/text_generation_server/utils/layers.py
@@ -324,6 +324,19 @@ class TensorParallelHead(SuperLayer):
 
 
 class TensorParallelColumnLinear(SuperLayer):
+    @classmethod
+    def load_qkv(cls, config, prefix: str, weights, bias: bool):
+        """Specific method when the QKV was joined after the fact"""
+        weight = weights.get_weights_col_packed_qkv(
+            prefix, quantize=config.quantize
+        )
+        if bias:
+            raise NotImplementedError("packed_qkv only implemented for baichuan")
+        else:
+            bias = None
+        linear = get_linear(weight, bias, config.quantize)
+        return cls(linear)
+
     @classmethod
     def load(cls, config, prefix: str, weights, bias: bool):
         return cls.load_multi(config, [prefix], weights, bias, dim=0)
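
Functionally, the layer `load_qkv` constructs is one column-parallel linear whose output is split into Q, K and V. A toy single-process equivalent in plain PyTorch (hypothetical sizes):

import torch
import torch.nn as nn

hidden, seq = 8, 3
qkv = nn.Linear(hidden, 3 * hidden, bias=False)  # stand-in for the W_pack linear

x = torch.randn(seq, hidden)
q, k, v = qkv(x).split(hidden, dim=-1)  # one matmul yields all three projections
print(q.shape, k.shape, v.shape)        # each: torch.Size([3, 8])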
server/text_generation_server/utils/weights.py
@@ -76,11 +76,11 @@ class Weights:
 
     def get_partial_sharded(self, tensor_name: str, dim: int):
         filename, tensor_name = self.get_filename(tensor_name)
-        f = self._get_handle(filename)
-        slice_ = f.get_slice(tensor_name)
         world_size = self.process_group.size()
         rank = self.process_group.rank()
 
+        f = self._get_handle(filename)
+        slice_ = f.get_slice(tensor_name)
         size = slice_.get_shape()[dim]
         block_size = size // world_size
         start = rank * block_size
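
This hunk only moves the handle and slice lookup below the rank query; the sharding itself is plain block slicing by rank. A self-contained sketch of the arithmetic:

import torch

world_size = 4
weight = torch.arange(16.0)  # dim 0 must divide evenly by world_size

block_size = weight.shape[0] // world_size
for rank in range(world_size):
    start, stop = rank * block_size, (rank + 1) * block_size
    print(rank, weight[start:stop].tolist())  # one contiguous block per rank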
@@ -110,6 +110,66 @@ class Weights:
         ), f"The chosen size {size} is not compatible with sharding on {world_size} shards"
         return self.get_partial_sharded(tensor_name, dim)
 
+    def _get_qweight(self, name: str):
+        slice_ = self._get_slice(name)
+        total_size = slice_.get_shape()[1]
+        assert total_size % 3 == 0, "Prepacked quantized qkv is not divisible by 3"
+        single_size = total_size // 3
+        world_size = self.process_group.size()
+        rank = self.process_group.rank()
+
+        assert (
+            single_size % world_size == 0
+        ), f"Prepacked quantized qkv cannot be sharded across {world_size} shards"
+        block_size = single_size // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        q = slice_[:, start:stop]
+        k = slice_[:, start + single_size : stop + single_size]
+        v = slice_[:, start + 2 * single_size : stop + 2 * single_size]
+        weight = torch.cat([q, k, v], dim=1)
+        weight = weight.to(device=self.device)
+        return weight
+
+    def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
+        """
+        Highly specific when the underlying tensor is a simple cat of Q,K,V instead of
+        being already alternating Q,K,V within the main tensor.
+        """
+        if quantize == "gptq":
+            try:
+                qweight = self._get_qweight(f"{prefix}.qweight")
+            except RuntimeError:
+                raise RuntimeError(
+                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+                )
+
+            qzeros = self._get_qweight(f"{prefix}.qzeros")
+            scales = self._get_qweight(f"{prefix}.scales")
+            scales = scales.to(dtype=self.dtype)
+            g_idx = self.get_tensor(f"{prefix}.g_idx")
+
+            bits, groupsize = self._get_gptq_params()
+            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
+        else:
+            slice_ = self._get_slice(f"{prefix}.weight")
+            total_size = slice_.get_shape()[0]
+            assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3"
+            single_size = total_size // 3
+            world_size = self.process_group.size()
+            rank = self.process_group.rank()
+
+            assert (
+                single_size % world_size == 0
+            ), f"Prepacked qkv cannot be sharded across {world_size} shards"
+            block_size = single_size // world_size
+            start = rank * block_size
+            stop = (rank + 1) * block_size
+            q = slice_[start:stop]
+            k = slice_[start + single_size : stop + single_size]
+            v = slice_[start + 2 * single_size : stop + 2 * single_size]
+            weight = torch.cat([q, k, v], dim=0)
+            weight = weight.to(device=self.device)
+            weight = weight.to(dtype=self.dtype)
+        return weight
+
     def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
         if quantize == "gptq":
             try:
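
The important detail in `_get_qweight` is that each packed third is sharded separately and then re-concatenated, so every rank keeps matching Q, K and V columns rather than one arbitrary contiguous chunk of the packed tensor. A numeric sketch with toy shapes:

import torch

in_packed, out_per_proj, world_size = 2, 8, 4

# GPTQ-style packed QKV: output features on dim 1, laid out as [ Q | K | V ].
qweight = torch.arange(in_packed * 3 * out_per_proj).reshape(in_packed, 3 * out_per_proj)

single_size = qweight.shape[1] // 3
block_size = single_size // world_size
for rank in range(world_size):
    start, stop = rank * block_size, (rank + 1) * block_size
    q = qweight[:, start:stop]
    k = qweight[:, start + single_size : stop + single_size]
    v = qweight[:, start + 2 * single_size : stop + 2 * single_size]
    shard = torch.cat([q, k, v], dim=1)
    print(rank, shard.shape)  # torch.Size([2, 6]) on every rank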