Support GPTQ models with column-packed up/gate tensor

The GPTQ code path for column-packed packed tensors assumed that this is
always a QKV matrix. However, models (e.g. Phi-3) can also have
column-packed MLP up/gate matrices.
This commit is contained in:
Daniël de Kok 2024-06-04 15:16:15 +00:00
parent df71aafdcc
commit b5f7f98dd8

View File

@ -121,24 +121,30 @@ class Weights:
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
return self.get_partial_sharded(tensor_name, dim)
def _get_qweight(self, name: str):
def _get_qweight(self, name: str, blocks: int):
slice_ = self._get_slice(name)
total_size = slice_.get_shape()[1]
assert total_size % 3 == 0, "Prepacked quantized qkv is not divisible by 3"
single_size = total_size // 3
assert (
total_size % blocks == 0
), f"Prepacked quantized matrix is not divisible by {blocks}"
single_size = total_size // blocks
world_size = self.process_group.size()
rank = self.process_group.rank()
assert (
single_size % world_size == 0
), f"Prepacked quantized qkv cannot be sharded across {world_size} shards"
), f"Prepacked quantized matrix cannot be sharded across {world_size} shards"
block_size = single_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
q = slice_[:, start:stop]
k = slice_[:, start + single_size : stop + single_size]
v = slice_[:, start + 2 * single_size : stop + 2 * single_size]
weight = torch.cat([q, k, v], dim=1)
weights = []
for block in range(blocks):
weights.append(
slice_[:, start + block * single_size : stop + block * single_size]
)
weight = torch.cat(weights, dim=1)
weight = weight.to(device=self.device)
return weight
@ -157,7 +163,7 @@ class Weights:
from text_generation_server.layers.gptq import GPTQWeight
try:
qweight = self._get_qweight(f"{prefix}.qweight")
qweight = self._get_qweight(f"{prefix}.qweight", blocks)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight, make sure the model is already quantized."
@ -165,8 +171,8 @@ class Weights:
bits, groupsize, _, quant_method = self._get_gptq_params()
qzeros = self._get_qweight(f"{prefix}.qzeros")
scales = self._get_qweight(f"{prefix}.scales")
qzeros = self._get_qweight(f"{prefix}.qzeros", blocks)
scales = self._get_qweight(f"{prefix}.scales", blocks)
scales = scales.to(dtype=self.dtype)
if quantize == "gptq" and quant_method == "gptq":