Transformers backend TP fix (#2945)

* init dispatch

* cohere fix
Cyril Vallez, 2025-01-23 18:09:57 +01:00, committed by GitHub
parent 29a0893b67
commit 18c4607d46
2 changed files with 140 additions and 49 deletions
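
The first file reworks the dispatch in get_model so that the per-architecture branches share the same fallback order: the native flash path when FLASH_ATTENTION is available, then the new transformers flash backend, then a hard error for sharded (tensor-parallel) runs, and only then the plain CausalLM fallback. Below is a minimal sketch of that priority; pick_backend and its string return values are hypothetical illustrations, not part of the server, and only the flag names come from the diff.

def pick_backend(flash_attention: bool, flash_transformers_backend: bool, sharded: bool) -> str:
    # Order mirrors the per-architecture branches added in this commit.
    if flash_attention:
        return "FlashCausalLM"  # custom flash kernels
    elif flash_transformers_backend:
        return "TransformersFlashCausalLM.fallback"  # transformers modeling code on the flash runtime
    elif sharded:
        # Without either flash path, tensor-parallel sharding is unsupported.
        raise NotImplementedError("sharded is not supported without a flash backend")
    return "CausalLM.fallback"  # plain transformers AutoModel path

print(pick_backend(False, True, True))  # TransformersFlashCausalLM.fallback

Because the transformers flash branch is checked before the sharded error, a sharded run can now fall back to the transformers backend instead of failing, which appears to be the "TP fix" named in the title.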


@@ -391,21 +391,6 @@ def get_model(
     )
     model_type = config_dict.get("model_type", None)
-    transformers_causal_lm_class = CausalLM
-
-    # Fast transformers path
-    transformers_model_class = getattr(
-        transformers,
-        modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, ""),
-        None,
-    )
-    if (
-        FLASH_TRANSFORMERS_BACKEND
-        and transformers_model_class is not None
-        and transformers_model_class._supports_flex_attn
-    ):
-        transformers_causal_lm_class = TransformersFlashCausalLM
-
     quantization_config = config_dict.get("quantization_config", None)
     if quantization_config is None:
         quantization_config = config_dict.get("compression_config", None)
@@ -649,7 +634,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
             )
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -756,7 +741,7 @@ def get_model(
             except RuntimeError as e:
                 # Lots of legacy models with various weight names.
                 log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
-                return transformers_causal_lm_class.fallback(
+                return CausalLM.fallback(
                     model_id,
                     revision,
                     quantize=quantize,
@@ -767,7 +752,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -792,7 +777,7 @@ def get_model(
             except RuntimeError as e:
                 # Lots of legacy models with various weight names.
                 log_master(logger.warning, f"Couldn't load flash gptj variant: {e}")
-                return transformers_causal_lm_class.fallback(
+                return CausalLM.fallback(
                     model_id,
                     revision,
                     quantize=quantize,
@@ -803,7 +788,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -840,7 +825,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -863,7 +848,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         else:
-            return transformers_causal_lm_class.fallback(
+            return TransformersFlashCausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -887,7 +872,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -913,12 +898,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
-    elif (
-        model_type == LLAMA
-        or model_type == BAICHUAN
-        or model_type == PHI3
-        or model_type == GRANITE
-    ):
+    elif model_type == LLAMA or model_type == PHI3 or model_type == GRANITE:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -931,12 +911,8 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
-        # elif sharded:
-        #     raise NotImplementedError(
-        #         FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
-        #     )
-        else:
-            return transformers_causal_lm_class.fallback(
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -944,6 +920,47 @@ def get_model(
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
+        elif sharded:
+            raise NotImplementedError(
+                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
+            )
+        else:
+            return CausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+    elif model_type == BAICHUAN:
+        if FLASH_ATTENTION:
+            return FlashCausalLM(
+                model_id=model_id,
+                model_class=FlashLlamaForCausalLM,
+                revision=revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                kv_cache_dtype=kv_cache_dtype,
+                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
+            )
+        elif sharded:
+            raise NotImplementedError(
+                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
+            )
+        else:
+            return CausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
     if model_type == GEMMA:
         if FLASH_ATTENTION:
             return FlashCausalLM(
@@ -959,10 +976,19 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -988,7 +1014,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1010,10 +1036,19 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1041,7 +1076,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1091,7 +1126,7 @@ def get_model(
                 config_class=RWConfig,
             )
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1113,10 +1148,19 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1138,10 +1182,19 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1163,12 +1216,21 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
             )
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1190,10 +1252,19 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1339,8 +1410,6 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))
 
-    if sharded:
-        raise NotImplementedError("sharded is not supported for AutoModel")
     if quantize == "gptq":
         raise NotImplementedError(
             "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
@@ -1353,8 +1422,19 @@ def get_model(
         raise NotImplementedError("Eetq quantization is not supported for AutoModel")
     elif quantize == "exl2":
         raise NotImplementedError("exl2 quantization is not supported for AutoModel")
-    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        return transformers_causal_lm_class.fallback(
+
+    # Fast transformers if available
+    transformers_model_class = getattr(
+        transformers,
+        modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, ""),
+        None,
+    )
+    if (
+        FLASH_TRANSFORMERS_BACKEND
+        and transformers_model_class is not None
+        and transformers_model_class._supports_flex_attn
+    ):
+        return TransformersFlashCausalLM.fallback(
             model_id,
             revision,
             quantize=quantize,
@@ -1362,6 +1442,10 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
+
+    if sharded:
+        raise NotImplementedError("sharded is not supported for AutoModel")
+
     if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
         return Seq2SeqLM.fallback(
             model_id,
@@ -1375,7 +1459,7 @@ def get_model(
     auto_map = config_dict.get("auto_map", None)
     if trust_remote_code and auto_map is not None:
         if "AutoModelForCausalLM" in auto_map.keys():
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
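
Before the diff moves to the second file, here is a hedged, standalone illustration of the `_supports_flex_attn` gate added near the end of get_model above: it resolves the transformers Auto class name registered for a model_type and checks the (private) flag on that class. The "llama" value is only an example input, and the attribute lookup is guarded so the snippet also runs on transformers versions that do not define the flag.

import transformers
from transformers.models.auto import modeling_auto

model_type = "llama"  # example architecture, not taken from the diff
class_name = modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, "")
model_class = getattr(transformers, class_name, None)
# Mirrors the condition in the diff: a known causal-LM class that supports flex attention.
supports_flex = model_class is not None and getattr(model_class, "_supports_flex_attn", False)
print(model_type, class_name, bool(supports_flex))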


@@ -260,4 +260,11 @@ class TransformersFlashCausalLM(FlashCausalLM):
             hidden_states = hidden_states[lm_head_indices]
         logits = self.model.lm_head(hidden_states)
 
+        # For Granite while next transformers version is released and we can use `lm_head_indices` natively
+        if hasattr(self.model.config, "logits_scaling"):
+            logits = logits / self.model.config.logits_scaling
+        # For Cohere for similar reasons
+        elif hasattr(self.model, "logit_scale"):
+            logits = logits * self.model.logit_scale
+
         return logits, None
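
The new lines in TransformersFlashCausalLM rescale the logits because this forward calls self.model.lm_head directly: Granite configs carry a logits_scaling divisor and Cohere models expose a logit_scale multiplier, as the hasattr branches above show. A tiny numeric illustration with made-up scale values (neither number comes from a real config):

import torch

logits = torch.tensor([[2.0, -4.0, 8.0]])

granite_logits_scaling = 16.0           # hypothetical example value
print(logits / granite_logits_scaling)  # tensor([[ 0.1250, -0.2500,  0.5000]])

cohere_logit_scale = 0.0625             # hypothetical example value
print(logits * cohere_logit_scale)      # tensor([[ 0.1250, -0.2500,  0.5000]])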