Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-22 15:32:08 +00:00)
Parent: 29a0893b67
Commit: 18c4607d46
@@ -391,21 +391,6 @@ def get_model(
     )
     model_type = config_dict.get("model_type", None)
 
-    transformers_causal_lm_class = CausalLM
-
-    # Fast transformers path
-    transformers_model_class = getattr(
-        transformers,
-        modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, ""),
-        None,
-    )
-    if (
-        FLASH_TRANSFORMERS_BACKEND
-        and transformers_model_class is not None
-        and transformers_model_class._supports_flex_attn
-    ):
-        transformers_causal_lm_class = TransformersFlashCausalLM
-
     quantization_config = config_dict.get("quantization_config", None)
     if quantization_config is None:
         quantization_config = config_dict.get("compression_config", None)
@@ -649,7 +634,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
             )
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -756,7 +741,7 @@ def get_model(
             except RuntimeError as e:
                 # Lots of legacy models with various weight names.
                 log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
-                return transformers_causal_lm_class.fallback(
+                return CausalLM.fallback(
                     model_id,
                     revision,
                     quantize=quantize,
@@ -767,7 +752,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -792,7 +777,7 @@ def get_model(
             except RuntimeError as e:
                 # Lots of legacy models with various weight names.
                 log_master(logger.warning, f"Couldn't load flash gptj variant: {e}")
-                return transformers_causal_lm_class.fallback(
+                return CausalLM.fallback(
                     model_id,
                     revision,
                     quantize=quantize,
@@ -803,7 +788,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -840,7 +825,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -863,7 +848,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         else:
-            return transformers_causal_lm_class.fallback(
+            return TransformersFlashCausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -887,7 +872,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -913,12 +898,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
 
-    elif (
-        model_type == LLAMA
-        or model_type == BAICHUAN
-        or model_type == PHI3
-        or model_type == GRANITE
-    ):
+    elif model_type == LLAMA or model_type == PHI3 or model_type == GRANITE:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -931,12 +911,8 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
-        # elif sharded:
-        #     raise NotImplementedError(
-        #         FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
-        #     )
-        else:
-            return transformers_causal_lm_class.fallback(
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -944,6 +920,47 @@ def get_model(
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
+        elif sharded:
+            raise NotImplementedError(
+                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
+            )
+        else:
+            return CausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+
+    elif model_type == BAICHUAN:
+        if FLASH_ATTENTION:
+            return FlashCausalLM(
+                model_id=model_id,
+                model_class=FlashLlamaForCausalLM,
+                revision=revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                kv_cache_dtype=kv_cache_dtype,
+                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
+            )
+        elif sharded:
+            raise NotImplementedError(
+                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
+            )
+        else:
+            return CausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+
     if model_type == GEMMA:
         if FLASH_ATTENTION:
             return FlashCausalLM(
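
The Llama/Baichuan hunks above, and the per-model hunks that follow, repeat one backend-selection order. A minimal, self-contained sketch of that order, assuming the flag names from this diff (the returned strings below are placeholders for the real FlashCausalLM, TransformersFlashCausalLM, and CausalLM classes):

# Hypothetical, simplified sketch: flash kernels first, then the transformers
# flex-attention backend, then a hard error for sharded runs, then the eager
# CausalLM fallback. Flag names mirror the diff; strings stand in for classes.
FLASH_ATTENTION = True
FLASH_TRANSFORMERS_BACKEND = True


def select_backend(sharded: bool) -> str:
    if FLASH_ATTENTION:
        return "FlashCausalLM"
    elif FLASH_TRANSFORMERS_BACKEND:
        return "TransformersFlashCausalLM.fallback"
    elif sharded:
        raise NotImplementedError("sharded serving requires flash attention")
    else:
        return "CausalLM.fallback"


print(select_backend(sharded=False))  # -> FlashCausalLM
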
@@ -959,10 +976,19 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -988,7 +1014,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1010,10 +1036,19 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1041,7 +1076,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1091,7 +1126,7 @@ def get_model(
                 config_class=RWConfig,
             )
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1113,10 +1148,19 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1138,10 +1182,19 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1163,12 +1216,21 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
             )
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1190,10 +1252,19 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1339,8 +1410,6 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))
 
-    if sharded:
-        raise NotImplementedError("sharded is not supported for AutoModel")
     if quantize == "gptq":
         raise NotImplementedError(
             "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
@@ -1353,8 +1422,19 @@ def get_model(
         raise NotImplementedError("Eetq quantization is not supported for AutoModel")
     elif quantize == "exl2":
         raise NotImplementedError("exl2 quantization is not supported for AutoModel")
-    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        return transformers_causal_lm_class.fallback(
+
+    # Fast transformers if available
+    transformers_model_class = getattr(
+        transformers,
+        modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, ""),
+        None,
+    )
+    if (
+        FLASH_TRANSFORMERS_BACKEND
+        and transformers_model_class is not None
+        and transformers_model_class._supports_flex_attn
+    ):
+        return TransformersFlashCausalLM.fallback(
             model_id,
             revision,
             quantize=quantize,
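
The block added above only picks the transformers backend when the auto-mapped model class advertises flex-attention support. A small, hypothetical demo of that lookup, assuming a recent transformers install (the private `_supports_flex_attn` attribute is an implementation detail and may move or disappear):

# Hypothetical demo: map a config "model_type" to its transformers class name,
# resolve the class, and check the private _supports_flex_attn flag.
import transformers
from transformers.models.auto import modeling_auto

for model_type in ("llama", "gemma", "not-a-real-type"):
    class_name = modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, "")
    model_class = getattr(transformers, class_name, None)
    supports_flex = getattr(model_class, "_supports_flex_attn", False)
    print(f"{model_type}: class={class_name or None}, supports_flex_attn={supports_flex}")
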
@@ -1362,6 +1442,10 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
+
+    if sharded:
+        raise NotImplementedError("sharded is not supported for AutoModel")
+
     if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
         return Seq2SeqLM.fallback(
             model_id,
@@ -1375,7 +1459,7 @@ def get_model(
     auto_map = config_dict.get("auto_map", None)
     if trust_remote_code and auto_map is not None:
         if "AutoModelForCausalLM" in auto_map.keys():
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -260,4 +260,11 @@ class TransformersFlashCausalLM(FlashCausalLM):
             hidden_states = hidden_states[lm_head_indices]
         logits = self.model.lm_head(hidden_states)
 
+        # For Granite while next transformers version is released and we can use `lm_head_indices` natively
+        if hasattr(self.model.config, "logits_scaling"):
+            logits = logits / self.model.config.logits_scaling
+        # For Cohere for similar reasons
+        elif hasattr(self.model, "logit_scale"):
+            logits = logits * self.model.logit_scale
+
         return logits, None
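
The scaling added above can be illustrated in isolation. A tiny, hypothetical sketch with stand-in objects instead of real models: a Granite-style config carries `logits_scaling` (logits are divided by it), while a Cohere-style model carries `logit_scale` (logits are multiplied by it):

# Hypothetical illustration of the scaling rule, using SimpleNamespace stand-ins.
from types import SimpleNamespace

import torch


def scale_logits(model, logits):
    if hasattr(model.config, "logits_scaling"):
        return logits / model.config.logits_scaling
    elif hasattr(model, "logit_scale"):
        return logits * model.logit_scale
    return logits


logits = torch.tensor([2.0, 4.0, 8.0])
granite_like = SimpleNamespace(config=SimpleNamespace(logits_scaling=16.0))
cohere_like = SimpleNamespace(config=SimpleNamespace(), logit_scale=0.0625)

print(scale_logits(granite_like, logits))  # tensor([0.1250, 0.2500, 0.5000])
print(scale_logits(cohere_like, logits))   # tensor([0.1250, 0.2500, 0.5000])
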