mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
Fixing GPTQ device santacoder.
This commit is contained in: parent 7faef69015, commit 900ac49454
@@ -49,6 +49,7 @@ def _load_multi_mqa_gptq(
         q_tensor = slice_[:, start:stop]
         kv_tensor = slice_[:, -2 * head_size :]
         qweight = torch.cat([q_tensor, kv_tensor], dim=1)
+        qweight = qweight.to(device=weights.device)
 
         slice_ = weights._get_slice(f"{prefix}.c_attn.scales")
         shape = slice_.get_shape()
@@ -59,6 +60,7 @@ def _load_multi_mqa_gptq(
         q_tensor = slice_[:, start:stop]
         kv_tensor = slice_[:, -2 * head_size :]
         scales = torch.cat([q_tensor, kv_tensor], dim=1)
+        scales = scales.to(device=weights.device)
 
         slice_ = weights._get_slice(f"{prefix}.c_attn.qzeros")
         shape = slice_.get_shape()
@@ -69,8 +71,10 @@ def _load_multi_mqa_gptq(
         q_tensor = slice_[:, start:stop]
         kv_tensor = slice_[:, -2 * head_size * 4 // 32 :]
         qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
+        qzeros = qzeros.to(device=weights.device)
 
         g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
+        g_idx = g_idx.to(device=weights.device)
         bits, groupsize = weights._get_gptq_qparams()
 
         from text_generation_server.utils.layers import HAS_EXLLAMA
@@ -88,6 +92,7 @@ def _load_multi_mqa_gptq(
             q_tensor = slice_[start:stop]
             kv_tensor = slice_[-2 * head_size :]
             bias = torch.cat([q_tensor, kv_tensor], dim=0)
+            bias = bias.to(device=weights.device)
 
         return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
     else:
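All four hunks make the same one-line fix: each GPTQ component of the fused c_attn projection (`qweight`, `scales`, `qzeros`, `g_idx`, and the optional `bias`) is assembled by slicing and concatenating checkpoint tensors, then explicitly moved to `weights.device` before the quantized linear layer is built. Below is a minimal standalone sketch of that pattern; the helper name `move_gptq_to_device` and the plain-tensor arguments are illustrative, not part of the repository's API.

import torch

def move_gptq_to_device(qweight, qzeros, scales, g_idx, bias, device):
    # Illustrative helper (not TGI API). torch.cat over CPU-resident
    # checkpoint slices produces CPU tensors, so every GPTQ component must
    # be moved to the model's device; leaving any one behind makes the
    # quantized matmul fail with a "tensors on different devices" error.
    qweight = qweight.to(device=device)
    qzeros = qzeros.to(device=device)
    scales = scales.to(device=device)
    g_idx = g_idx.to(device=device)
    if bias is not None:
        bias = bias.to(device=device)
    return qweight, qzeros, scales, g_idx, bias

# Shape arithmetic behind the qzeros slice in the third hunk: assuming
# 4-bit quantization, eight 4-bit zero points pack into each int32 column,
# so the two KV heads occupy 2 * head_size * 4 // 32 packed columns.
head_size = 128                      # e.g. a santacoder-style head size
kv_cols = 2 * head_size * 4 // 32    # 32 packed int32 columns
assert kv_cols == 32

The qweight and scales slices, by contrast, index output columns directly, which is why their KV slice is simply `[:, -2 * head_size :]`; only qzeros is packed along the column axis.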