mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 00:12:08 +00:00
parent 29a0893b67
commit 18c4607d46
@@ -391,21 +391,6 @@ def get_model(
     )
     model_type = config_dict.get("model_type", None)

-    transformers_causal_lm_class = CausalLM
-
-    # Fast transformers path
-    transformers_model_class = getattr(
-        transformers,
-        modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, ""),
-        None,
-    )
-    if (
-        FLASH_TRANSFORMERS_BACKEND
-        and transformers_model_class is not None
-        and transformers_model_class._supports_flex_attn
-    ):
-        transformers_causal_lm_class = TransformersFlashCausalLM
-
     quantization_config = config_dict.get("quantization_config", None)
     if quantization_config is None:
         quantization_config = config_dict.get("compression_config", None)
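The block removed above is the module-level probe that silently switched the fallback class to TransformersFlashCausalLM whenever the auto-mapped transformers class advertises flex-attention support; a later hunk reintroduces the same probe, scoped to the AutoModel catch-all. A minimal standalone sketch of that check, with FLASH_TRANSFORMERS_BACKEND hard-coded as a stand-in for TGI's real backend-availability flag:

import transformers
from transformers.models.auto import modeling_auto

# Stand-in for TGI's backend flag, which is normally resolved at import time.
FLASH_TRANSFORMERS_BACKEND = True


def supports_transformers_flash(model_type: str) -> bool:
    # Resolve the auto-mapped causal-LM class name for this model_type
    # (e.g. "llama" -> "LlamaForCausalLM") on the transformers namespace.
    class_name = modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, "")
    model_class = getattr(transformers, class_name, None)
    # Usable only when the class exists and declares flex-attention support.
    return (
        FLASH_TRANSFORMERS_BACKEND
        and model_class is not None
        and getattr(model_class, "_supports_flex_attn", False)
    )


print(supports_transformers_flash("llama"))  # expected True on recent transformers releases
print(supports_transformers_flash("bloom"))  # expected False (no flex attention)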
@@ -649,7 +634,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
             )
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -756,7 +741,7 @@ def get_model(
             except RuntimeError as e:
                 # Lots of legacy models with various weight names.
                 log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
-                return transformers_causal_lm_class.fallback(
+                return CausalLM.fallback(
                     model_id,
                     revision,
                     quantize=quantize,
@@ -767,7 +752,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -792,7 +777,7 @@ def get_model(
             except RuntimeError as e:
                 # Lots of legacy models with various weight names.
                 log_master(logger.warning, f"Couldn't load flash gptj variant: {e}")
-                return transformers_causal_lm_class.fallback(
+                return CausalLM.fallback(
                     model_id,
                     revision,
                     quantize=quantize,
@@ -803,7 +788,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -840,7 +825,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -863,7 +848,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         else:
-            return transformers_causal_lm_class.fallback(
+            return TransformersFlashCausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -887,7 +872,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -913,12 +898,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )

-    elif (
-        model_type == LLAMA
-        or model_type == BAICHUAN
-        or model_type == PHI3
-        or model_type == GRANITE
-    ):
+    elif model_type == LLAMA or model_type == PHI3 or model_type == GRANITE:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -931,12 +911,8 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
-        # elif sharded:
-        #     raise NotImplementedError(
-        #         FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
-        #     )
-        else:
-            return transformers_causal_lm_class.fallback(
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -944,6 +920,47 @@ def get_model(
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
+        elif sharded:
+            raise NotImplementedError(
+                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
+            )
+        else:
+            return CausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+
+    elif model_type == BAICHUAN:
+        if FLASH_ATTENTION:
+            return FlashCausalLM(
+                model_id=model_id,
+                model_class=FlashLlamaForCausalLM,
+                revision=revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                kv_cache_dtype=kv_cache_dtype,
+                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
+            )
+        elif sharded:
+            raise NotImplementedError(
+                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
+            )
+        else:
+            return CausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+
     if model_type == GEMMA:
         if FLASH_ATTENTION:
             return FlashCausalLM(
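Taken together with the previous hunk, each architecture now spells out its own fallback ladder instead of leaning on the removed module-level transformers_causal_lm_class: FlashCausalLM when flash attention is available, then the transformers flash backend where the branch offers one, then an explicit error for sharded setups, then plain CausalLM; Baichuan is split into its own branch with no transformers-flash step. A rough, self-contained sketch of that ladder with stand-in flags and stub classes (every Stub* name is invented for illustration):

# Stand-in flags; in TGI these are resolved at import time from the environment
# and from which attention kernels could actually be loaded.
FLASH_ATTENTION = False
FLASH_TRANSFORMERS_BACKEND = True
SHARDED = False


class StubFlashCausalLM:  # stands in for FlashCausalLM
    def __init__(self, model_id: str):
        self.model_id = model_id


class StubTransformersFlashCausalLM:  # stands in for TransformersFlashCausalLM
    @classmethod
    def fallback(cls, model_id: str):
        return cls()


class StubCausalLM:  # stands in for CausalLM
    @classmethod
    def fallback(cls, model_id: str):
        return cls()


def dispatch(model_id: str):
    """Mirrors the per-architecture ladder get_model uses after this change."""
    if FLASH_ATTENTION:
        return StubFlashCausalLM(model_id)
    elif FLASH_TRANSFORMERS_BACKEND:
        return StubTransformersFlashCausalLM.fallback(model_id)
    elif SHARDED:
        raise NotImplementedError("sharding requires the flash attention backend")
    else:
        return StubCausalLM.fallback(model_id)


print(type(dispatch("some/model-id")).__name__)  # StubTransformersFlashCausalLM with the flags above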
@@ -959,10 +976,19 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -988,7 +1014,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1010,10 +1036,19 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1041,7 +1076,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1091,7 +1126,7 @@ def get_model(
                     config_class=RWConfig,
                 )
             else:
-                return transformers_causal_lm_class.fallback(
+                return CausalLM.fallback(
                     model_id,
                     revision,
                     quantize=quantize,
@@ -1113,10 +1148,19 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1138,10 +1182,19 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1163,12 +1216,21 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
             )
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1190,10 +1252,19 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1339,8 +1410,6 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

-    if sharded:
-        raise NotImplementedError("sharded is not supported for AutoModel")
     if quantize == "gptq":
         raise NotImplementedError(
             "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
@@ -1353,8 +1422,19 @@ def get_model(
         raise NotImplementedError("Eetq quantization is not supported for AutoModel")
     elif quantize == "exl2":
         raise NotImplementedError("exl2 quantization is not supported for AutoModel")
-    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        return transformers_causal_lm_class.fallback(
+
+    # Fast transformers if available
+    transformers_model_class = getattr(
+        transformers,
+        modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, ""),
+        None,
+    )
+    if (
+        FLASH_TRANSFORMERS_BACKEND
+        and transformers_model_class is not None
+        and transformers_model_class._supports_flex_attn
+    ):
+        return TransformersFlashCausalLM.fallback(
             model_id,
             revision,
             quantize=quantize,
@@ -1362,6 +1442,10 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
+
+    if sharded:
+        raise NotImplementedError("sharded is not supported for AutoModel")
+
     if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
         return Seq2SeqLM.fallback(
             model_id,
@@ -1375,7 +1459,7 @@ def get_model(
     auto_map = config_dict.get("auto_map", None)
     if trust_remote_code and auto_map is not None:
         if "AutoModelForCausalLM" in auto_map.keys():
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -260,4 +260,11 @@ class TransformersFlashCausalLM(FlashCausalLM):
             hidden_states = hidden_states[lm_head_indices]
         logits = self.model.lm_head(hidden_states)

+        # For Granite while next transformers version is released and we can use `lm_head_indices` natively
+        if hasattr(self.model.config, "logits_scaling"):
+            logits = logits / self.model.config.logits_scaling
+        # For Cohere for similar reasons
+        elif hasattr(self.model, "logit_scale"):
+            logits = logits * self.model.logit_scale
+
         return logits, None
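Per the comments above, this hand-rolled lm_head call has to re-apply scaling that the upstream transformers models would otherwise handle themselves: Granite-style configs carry a logits_scaling divisor, while Cohere-style models carry a logit_scale multiplier. A small self-contained sketch of the same adjustment outside the class (the Fake* values are invented for illustration):

import torch


def scale_logits(logits: torch.Tensor, config, model) -> torch.Tensor:
    # Granite-style models expose a divisor on the config ...
    if hasattr(config, "logits_scaling"):
        return logits / config.logits_scaling
    # ... while Cohere-style models expose a multiplier on the model itself.
    if hasattr(model, "logit_scale"):
        return logits * model.logit_scale
    return logits


class FakeGraniteConfig:  # invented stand-in
    logits_scaling = 16.0


class FakeCohereModel:  # invented stand-in
    logit_scale = 0.0625


raw = torch.randn(2, 32000)
granite_like = scale_logits(raw, FakeGraniteConfig(), object())  # divided by 16.0
cohere_like = scale_logits(raw, object(), FakeCohereModel())     # multiplied by 0.0625
print(granite_like.shape, cohere_like.shape)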