mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
Fixing GPTQ device santacoder.
This commit is contained in:
parent
7faef69015
commit
900ac49454
@ -49,6 +49,7 @@ def _load_multi_mqa_gptq(
|
||||
q_tensor = slice_[:, start:stop]
|
||||
kv_tensor = slice_[:, -2 * head_size :]
|
||||
qweight = torch.cat([q_tensor, kv_tensor], dim=1)
|
||||
qweight = qweight.to(device=weights.device)
|
||||
|
||||
slice_ = weights._get_slice(f"{prefix}.c_attn.scales")
|
||||
shape = slice_.get_shape()
|
||||
@ -59,6 +60,7 @@ def _load_multi_mqa_gptq(
|
||||
q_tensor = slice_[:, start:stop]
|
||||
kv_tensor = slice_[:, -2 * head_size :]
|
||||
scales = torch.cat([q_tensor, kv_tensor], dim=1)
|
||||
scales = scales.to(device=weights.device)
|
||||
|
||||
slice_ = weights._get_slice(f"{prefix}.c_attn.qzeros")
|
||||
shape = slice_.get_shape()
|
||||
@ -69,8 +71,10 @@ def _load_multi_mqa_gptq(
|
||||
q_tensor = slice_[:, start:stop]
|
||||
kv_tensor = slice_[:, -2 * head_size * 4 // 32 :]
|
||||
qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
|
||||
qzeros = qzeros.to(device=weights.device)
|
||||
|
||||
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
|
||||
g_idx = g_idx.to(device=weights.device)
|
||||
bits, groupsize = weights._get_gptq_qparams()
|
||||
|
||||
from text_generation_server.utils.layers import HAS_EXLLAMA
|
||||
@ -88,6 +92,7 @@ def _load_multi_mqa_gptq(
|
||||
q_tensor = slice_[start:stop]
|
||||
kv_tensor = slice_[-2 * head_size :]
|
||||
bias = torch.cat([q_tensor, kv_tensor], dim=0)
|
||||
bias = bias.to(device=weights.device)
|
||||
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user