Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 12:24:53 +00:00
raise error if needed
This commit is contained in:
parent f01014de37
commit 2659b5998b
@@ -380,6 +380,8 @@ def get_model(
     transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
     if transformers_model_class.is_backend_compatible():
         transformers_causal_lm_class = TransformersFlashCausalLM
+        if not FLASH_ATTENTION and lora_adapter_ids is not None and len(lora_adapter_ids) > 0:
+            raise ValueError("Transformers backend AutoModel do not support `lora_adapter_ids`.")
 
     quantization_config = config_dict.get("quantization_config", None)
     if quantization_config is None:
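The two added lines make `get_model` fail fast: when the Transformers AutoModel fallback is selected but flash attention is unavailable, requested LoRA adapters cannot be served, so an error is raised instead of silently ignoring them. A minimal sketch of that guard, assuming `FLASH_ATTENTION` is a module-level boolean and using a hypothetical `check_lora_support` helper for illustration:

FLASH_ATTENTION = False  # assumption: flash-attn kernels unavailable in this environment

def check_lora_support(lora_adapter_ids):
    # Mirror of the added guard: reject LoRA adapters when the Transformers
    # AutoModel path has to run without flash attention.
    if not FLASH_ATTENTION and lora_adapter_ids is not None and len(lora_adapter_ids) > 0:
        raise ValueError("Transformers backend AutoModel do not support `lora_adapter_ids`.")

check_lora_support([])  # passes: no adapters requested
try:
    check_lora_support(["my-adapter"])  # raises: adapters requested without flash attention
except ValueError as err:
    print(err)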
@@ -48,7 +48,7 @@ def tgi_flash_attention_forward(
     softmax_scale: Optional[float] = None,
     sliding_window: Optional[int] = None,
     softcap: Optional[float] = None,
-    **kwargs, # This is needed to "absorb" other args passed by Transformers modeling
+    **_kwargs, # This is needed to "absorb" other args passed by Transformers modeling
 ):
 
     kv_cache = kv_cache[module.layer_idx]
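The second hunk only renames the catch-all parameter from `**kwargs` to `**_kwargs`; the leading underscore signals that the extra keyword arguments forwarded by Transformers modeling code are absorbed and never read. A toy sketch of that pattern (not the TGI function; the names below are illustrative only):

from typing import Optional

def forward(
    softmax_scale: Optional[float] = None,
    sliding_window: Optional[int] = None,
    softcap: Optional[float] = None,
    **_kwargs,  # absorbs any extra keyword arguments the caller forwards
):
    # Only the named parameters are used; _kwargs is deliberately ignored.
    return softmax_scale, sliding_window, softcap

# Extra kwargs such as position_ids or use_cache are accepted without a TypeError.
print(forward(softmax_scale=1.0, position_ids=None, use_cache=False))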