diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 5b1c6e21..ff73e24f 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -20,12 +20,12 @@ from text_generation_server.utils.layers import (
     FastLayerNorm,
     get_linear,
 )
+from safetensors import SafetensorError
 
 
 def load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
-
     if config.quantize == "gptq":
         return _load_multi_mqa_gptq(
             config, prefix, weights, bias, head_size, num_heads, hidden_size
@@ -74,8 +74,17 @@ def _load_multi_mqa_gptq(
         qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
 
         g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
-        bits = weights.get_tensor("gptq_bits").item()
-        groupsize = weights.get_tensor("gptq_groupsize").item()
+        try:
+            bits = weights.get_tensor("gptq_bits").item()
+            groupsize = weights.get_tensor("gptq_groupsize").item()
+        except SafetensorError as e:
+            try:
+                import os
+
+                bits = int(os.getenv("GPTQ_BITS"))
+                groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
+            except Exception:
+                raise e
 
         weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
 
@@ -99,7 +108,6 @@ def _load_multi_mqa_gptq(
 def _load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
-
     if any("c_attn" in k for k in weights.routing.keys()):
         slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
         shape = slice_.get_shape()
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 83d9df68..39f66862 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -1,6 +1,6 @@
 from pathlib import Path
 from typing import List, Dict, Optional
-from safetensors import safe_open
+from safetensors import safe_open, SafetensorError
 
 import torch
 
@@ -120,8 +120,17 @@ class Weights:
             torch.testing.assert_close(w2, w[0])
             g_idx = w[0]
-            bits = self.get_tensor("gptq_bits").item()
-            groupsize = self.get_tensor("gptq_groupsize").item()
+            try:
+                bits = self.get_tensor("gptq_bits").item()
+                groupsize = self.get_tensor("gptq_groupsize").item()
+            except SafetensorError as e:
+                try:
+                    import os
+
+                    bits = int(os.getenv("GPTQ_BITS"))
+                    groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
+                except Exception:
+                    raise e
 
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]