Fixing GTPQ device santacoder.

This commit is contained in:
Nicolas Patry 2023-07-20 19:08:33 +00:00
parent 7faef69015
commit 900ac49454

View File

@ -49,6 +49,7 @@ def _load_multi_mqa_gptq(
q_tensor = slice_[:, start:stop]
kv_tensor = slice_[:, -2 * head_size :]
qweight = torch.cat([q_tensor, kv_tensor], dim=1)
qweight = qweight.to(device=weights.device)
slice_ = weights._get_slice(f"{prefix}.c_attn.scales")
shape = slice_.get_shape()
@ -59,6 +60,7 @@ def _load_multi_mqa_gptq(
q_tensor = slice_[:, start:stop]
kv_tensor = slice_[:, -2 * head_size :]
scales = torch.cat([q_tensor, kv_tensor], dim=1)
scales = scales.to(device=weights.device)
slice_ = weights._get_slice(f"{prefix}.c_attn.qzeros")
shape = slice_.get_shape()
@ -69,8 +71,10 @@ def _load_multi_mqa_gptq(
q_tensor = slice_[:, start:stop]
kv_tensor = slice_[:, -2 * head_size * 4 // 32 :]
qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
qzeros = qzeros.to(device=weights.device)
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
g_idx = g_idx.to(device=weights.device)
bits, groupsize = weights._get_gptq_qparams()
from text_generation_server.utils.layers import HAS_EXLLAMA
@ -88,6 +92,7 @@ def _load_multi_mqa_gptq(
q_tensor = slice_[start:stop]
kv_tensor = slice_[-2 * head_size :]
bias = torch.cat([q_tensor, kv_tensor], dim=0)
bias = bias.to(device=weights.device)
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
else: