Update solution to account for GPTQ.

Nicolas Patry 2023-09-08 14:36:49 +00:00
parent 2a1f306e26
commit e349f57d10
4 changed files with 106 additions and 38 deletions

View File

@@ -149,6 +149,26 @@ class LlamaRMSNorm(nn.Module):
            return normed_hidden_states, res


def load_attention(config, prefix, weights):
    if config.num_attention_heads != config.num_key_value_heads:
        return _load_gqa(config, prefix, weights)
    else:
        if config.model_type == "baichuan":
            return TensorParallelColumnLinear.load_qkv(
                config,
                prefix=f"{prefix}.W_pack",
                weights=weights,
                bias=False,
            )
        else:
            return TensorParallelColumnLinear.load_multi(
                config,
                prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
                dim=0,
                weights=weights,
                bias=False,
            )


def _load_gqa(config, prefix: str, weights):
    assert config.hidden_size % config.num_attention_heads == 0
    assert config.num_attention_heads % weights.process_group.size() == 0
@@ -205,39 +225,8 @@ class FlashLlamaAttention(torch.nn.Module):
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )
        if config.num_attention_heads != config.num_key_value_heads:
            self.query_key_value = _load_gqa(config, prefix, weights)
        else:
            try:
                def load_packed(cls, config, prefix: str, weights, bias: bool):
                    packed_tensor = weights.get_tensor(prefix, to_device=False)
                    # QKV
                    total_size = packed_tensor.size()[0]
                    single_size = total_size // 3
                    q_tensor = packed_tensor[0:single_size, :]
                    k_tensor = packed_tensor[single_size : single_size * 2, :]
                    v_tensor = packed_tensor[single_size * 2 : total_size, :]
                    q_weight = weights.get_tensor_shard(q_tensor, dim=0)
                    k_weight = weights.get_tensor_shard(k_tensor, dim=0)
                    v_weight = weights.get_tensor_shard(v_tensor, dim=0)
                    weight = torch.concat([q_weight, k_weight, v_weight], dim=0)
                    return cls(get_linear(weight, None, config.quantize))

                self.query_key_value = load_packed(
                    TensorParallelColumnLinear,
                    config,
                    prefix=f"{prefix}.W_pack.weight",
                    weights=weights,
                    bias=False,
                )
            except:
                self.query_key_value = TensorParallelColumnLinear.load_multi(
                    config,
                    prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
                    dim=0,
                    weights=weights,
                    bias=False,
                )
        self.query_key_value = load_attention(config, prefix, weights)

        self.o_proj = TensorParallelRowLinear.load(
            config,

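The removed block probed for a plain packed tensor with a bare try/except around `weights.get_tensor(f"{prefix}.W_pack.weight")`. With a GPTQ checkpoint that key does not exist (the packed layer is stored under `qweight`/`qzeros`/`scales`/`g_idx`, as the weights.py change below shows), so the probe raised and fell through to the per-projection loader, which cannot find Baichuan's fused weights either. The new `load_attention` helper dispatches on `config.model_type` instead of on the checkpoint layout. A minimal sketch of the naming mismatch; the key paths below are illustrative, not taken from a real checkpoint:

fp16_keys = {"model.layers.0.self_attn.W_pack.weight"}
gptq_keys = {
    "model.layers.0.self_attn.W_pack.qweight",
    "model.layers.0.self_attn.W_pack.qzeros",
    "model.layers.0.self_attn.W_pack.scales",
    "model.layers.0.self_attn.W_pack.g_idx",
}

def has_packed_fp16_qkv(keys, prefix="model.layers.0.self_attn"):
    # What the removed try/except effectively tested: does a plain packed weight exist?
    return f"{prefix}.W_pack.weight" in keys

assert has_packed_fp16_qkv(fp16_keys)
assert not has_packed_fp16_qkv(gptq_keys)  # GPTQ layout: the old probe misses the layer entirely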
View File

@@ -29,9 +29,15 @@ def _remove_duplicate_names(
            [name for name in shared if _is_complete(state_dict[name])]
        )
        if not complete_names:
            raise RuntimeError(
                f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage. Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
            )
            if len(shared) == 1:
                # Force contiguous
                name = list(shared)[0]
                state_dict[name] = state_dict[name].clone()
                complete_names = {name}
            else:
                raise RuntimeError(
                    f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage. Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
                )

        keep_name = sorted(list(complete_names))[0]

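The converter change relaxes the shared-storage check: when exactly one name refers to a tensor that does not cover its whole storage (for example a slice view left over from splitting a packed weight), cloning it yields a standalone copy that is safe to serialize, instead of raising. A small illustrative sketch of that idea in plain torch, not the converter's helpers; the name "w_pack" is made up:

import torch

full = torch.randn(24, 8)
state_dict = {"w_pack": full[0:8]}  # a view that covers only part of its storage
shared = {"w_pack"}                 # one name, no complete coverage

if len(shared) == 1:
    name = list(shared)[0]
    state_dict[name] = state_dict[name].clone()  # force a contiguous, self-owned copy

# The clone no longer aliases `full`, so saving it does not drag the whole buffer along.
assert state_dict["w_pack"].data_ptr() != full.data_ptr()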
View File

@@ -324,6 +324,19 @@ class TensorParallelHead(SuperLayer):
class TensorParallelColumnLinear(SuperLayer):
    @classmethod
    def load_qkv(cls, config, prefix: str, weights, bias: bool):
        """Specific method when the QKV was joined after the fact"""
        weight = weights.get_weights_col_packed_qkv(
            prefix, quantize=config.quantize
        )
        if bias:
            raise NotImplementedError("packed_qkv only implemented for baichuan")
        else:
            bias = None
        linear = get_linear(weight, bias, config.quantize)
        return cls(linear)

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        return cls.load_multi(config, [prefix], weights, bias, dim=0)

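For context, `load_qkv` hands the per-rank shard to `get_linear` just like the other column-parallel loaders. Sharding along the output dimension is safe because a column-parallel linear gives each rank a block of output features, and concatenating the per-rank outputs along the feature dimension reproduces the full projection. A minimal sketch of that property with illustrative shapes, not the TGI classes:

import torch

def column_parallel_matmul(x, full_weight, rank, world_size):
    # Each rank multiplies by its own block of the weight's output rows.
    block = full_weight.size(0) // world_size
    shard = full_weight[rank * block:(rank + 1) * block]
    return x @ shard.t()

x = torch.randn(2, 8)
w = torch.randn(24, 8)  # (out_features, in_features)
reference = x @ w.t()
parts = [column_parallel_matmul(x, w, r, world_size=2) for r in range(2)]
assert torch.allclose(torch.cat(parts, dim=-1), reference)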
View File

@@ -76,11 +76,11 @@ class Weights:
    def get_partial_sharded(self, tensor_name: str, dim: int):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        rank = self.process_group.rank()
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        size = slice_.get_shape()[dim]
        block_size = size // world_size
        start = rank * block_size
@@ -110,6 +110,66 @@ class Weights:
        ), f"The chosen size {size} is not compatible with sharding on {world_size} shards"
        return self.get_partial_sharded(tensor_name, dim)

    def _get_qweight(self, name: str):
        slice_ = self._get_slice(name)
        total_size = slice_.get_shape()[1]
        assert total_size % 3 == 0, "Prepacked quantized qkv is not divisible by 3"
        single_size = total_size // 3
        world_size = self.process_group.size()
        rank = self.process_group.rank()
        assert single_size % world_size == 0, f"Prepacked quantized qkv cannot be sharded across {world_size} shards"
        block_size = single_size // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        q = slice_[:, start:stop]
        k = slice_[:, start + single_size : stop + single_size]
        v = slice_[:, start + 2 * single_size : stop + 2 * single_size]
        weight = torch.cat([q, k, v], dim=1)
        weight = weight.to(device=self.device)
        return weight

    def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
        """
        Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
        already alternating Q,K,V within the main tensor
        """
        if quantize == "gptq":
            try:
                qweight = self._get_qweight(f"{prefix}.qweight")
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
                )
            qzeros = self._get_qweight(f"{prefix}.qzeros")
            scales = self._get_qweight(f"{prefix}.scales")
            scales = scales.to(dtype=self.dtype)
            g_idx = self.get_tensor(f"{prefix}.g_idx")
            bits, groupsize = self._get_gptq_params()
            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
        else:
            slice_ = self._get_slice(f"{prefix}.weight")
            total_size = slice_.get_shape()[0]
            assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3"
            single_size = total_size // 3
            world_size = self.process_group.size()
            rank = self.process_group.rank()
            assert single_size % world_size == 0, f"Prepacked qkv cannot be sharded across {world_size} shards"
            block_size = single_size // world_size
            start = rank * block_size
            stop = (rank + 1) * block_size
            q = slice_[start:stop]
            k = slice_[start + single_size : stop + single_size]
            v = slice_[start + 2 * single_size : stop + 2 * single_size]
            weight = torch.cat([q, k, v], dim=0)
            weight = weight.to(device=self.device)
            weight = weight.to(dtype=self.dtype)
        return weight
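Both branches use the same slicing arithmetic: each rank takes its block from the Q, K and V thirds of the packed tensor and re-concatenates them. The only difference is the dimension. A plain float weight is laid out as (out_features, in_features), so the thirds are sliced along dim 0, while the packed GPTQ tensors (`qweight`, `qzeros`, `scales`) keep the output features in their second dimension, which is why `_get_qweight` slices along dim 1. A small sketch of that arithmetic with made-up shapes:

import torch

def shard_packed_qkv(t, dim, rank, world_size):
    # Take this rank's block from each of the Q, K, V thirds along `dim`, then re-concatenate.
    single = t.size(dim) // 3
    block = single // world_size
    start = rank * block
    pieces = [t.narrow(dim, i * single + start, block) for i in range(3)]
    return torch.cat(pieces, dim=dim)

fp_weight = torch.randn(24, 16)            # (3 * out_per_proj, in_features): slice rows
qweight = torch.randint(0, 1000, (4, 24))  # GPTQ-style (packed in_features, 3 * out_per_proj): slice columns
assert shard_packed_qkv(fp_weight, dim=0, rank=0, world_size=2).shape == (12, 16)
assert shard_packed_qkv(qweight, dim=1, rank=1, world_size=2).shape == (4, 12)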
    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
        if quantize == "gptq":
            try: