raise error if needed

Cyril Vallez 2025-01-20 11:29:51 +01:00
parent f01014de37
commit 2659b5998b
2 changed files with 3 additions and 1 deletion


@@ -380,6 +380,8 @@ def get_model(
     transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
     if transformers_model_class.is_backend_compatible():
         transformers_causal_lm_class = TransformersFlashCausalLM
+        if not FLASH_ATTENTION and lora_adapter_ids is not None and len(lora_adapter_ids) > 0:
+            raise ValueError("Transformers backend AutoModel do not support `lora_adapter_ids`.")
 
     quantization_config = config_dict.get("quantization_config", None)
     if quantization_config is None:
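For reference, the added guard simply rejects LoRA adapters when the Transformers AutoModel fallback is selected without flash attention. Below is a minimal, self-contained sketch of the same check; FLASH_ATTENTION and lora_adapter_ids are illustrative stand-ins for the values get_model already has in scope, not TGI's actual wiring.

# Hypothetical values standing in for what get_model receives.
FLASH_ATTENTION = False            # flash-attn kernels unavailable on this machine
lora_adapter_ids = ["my-adapter"]  # LoRA adapters requested at launch

# Same condition as the added lines above: LoRA requires the flash-attention path.
if not FLASH_ATTENTION and lora_adapter_ids is not None and len(lora_adapter_ids) > 0:
    raise ValueError("Transformers backend AutoModel do not support `lora_adapter_ids`.")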


@@ -48,7 +48,7 @@ def tgi_flash_attention_forward(
     softmax_scale: Optional[float] = None,
     sliding_window: Optional[int] = None,
     softcap: Optional[float] = None,
-    **kwargs,  # This is needed to "absorb" other args passed by Transformers modeling
+    **_kwargs,  # This is needed to "absorb" other args passed by Transformers modeling
 ):
     kv_cache = kv_cache[module.layer_idx]
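The second change only renames the catch-all parameter: the leading underscore marks the absorbed keyword arguments as intentionally unused. The sketch below illustrates the mechanism with a made-up stub (attention_stub and its arguments are not TGI's API), showing why the signature needs a ** parameter at all.

from typing import Optional

def attention_stub(query, key, value, softcap: Optional[float] = None, **_kwargs):
    # **_kwargs silently absorbs any extra keyword arguments the caller sends,
    # e.g. flags from the Transformers modeling code that this backend ignores.
    return query

# Extra keywords the stub never reads do not raise a TypeError:
attention_stub(query=1, key=2, value=3, dropout=0.0, output_attentions=False)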