diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 4603d577..04bd422f 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -49,6 +49,7 @@ def _load_multi_mqa_gptq(
         q_tensor = slice_[:, start:stop]
         kv_tensor = slice_[:, -2 * head_size :]
         qweight = torch.cat([q_tensor, kv_tensor], dim=1)
+        qweight = qweight.to(device=weights.device)
 
         slice_ = weights._get_slice(f"{prefix}.c_attn.scales")
         shape = slice_.get_shape()
@@ -59,6 +60,7 @@ def _load_multi_mqa_gptq(
         q_tensor = slice_[:, start:stop]
         kv_tensor = slice_[:, -2 * head_size :]
         scales = torch.cat([q_tensor, kv_tensor], dim=1)
+        scales = scales.to(device=weights.device)
 
         slice_ = weights._get_slice(f"{prefix}.c_attn.qzeros")
         shape = slice_.get_shape()
@@ -69,8 +71,10 @@ def _load_multi_mqa_gptq(
         q_tensor = slice_[:, start:stop]
         kv_tensor = slice_[:, -2 * head_size * 4 // 32 :]
         qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
+        qzeros = qzeros.to(device=weights.device)
 
         g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
+        g_idx = g_idx.to(device=weights.device)
         bits, groupsize = weights._get_gptq_qparams()
 
         from text_generation_server.utils.layers import HAS_EXLLAMA
@@ -88,6 +92,7 @@ def _load_multi_mqa_gptq(
             q_tensor = slice_[start:stop]
             kv_tensor = slice_[-2 * head_size :]
             bias = torch.cat([q_tensor, kv_tensor], dim=0)
+            bias = bias.to(device=weights.device)
 
             return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
     else:
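
For context, the change repeats one pattern for each GPTQ tensor (`qweight`, `scales`, `qzeros`, `g_idx`, and the bias): slices read through `weights._get_slice` materialize on the host, so the concatenated shard must be moved to `weights.device` before it reaches the quantized linear layer, or the kernel would see mixed-device inputs. A minimal sketch of that pattern, where `shard_and_move` is a hypothetical helper name (the slicing arithmetic and the `device=` target are taken from the diff itself):

```python
import torch

def shard_and_move(slice_, start, stop, head_size, device):
    # Take this rank's slice of the q columns, keep the shared kv
    # columns at the tail, concatenate, and move the result onto the
    # target device. Without the final .to(...), the tensor would stay
    # on the host where the safetensors slice was materialized.
    q_tensor = slice_[:, start:stop]
    kv_tensor = slice_[:, -2 * head_size:]
    return torch.cat([q_tensor, kv_tensor], dim=1).to(device=device)
```

The diff keeps the two-step form (`x = torch.cat(...)` then `x = x.to(...)`) rather than fusing it into one expression, which matches the style of the surrounding loader code; both are equivalent.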