latest transformers changes

Cyril Vallez 2024-12-19 17:37:45 +00:00
parent f4c60ca522
commit 2e2631e093
2 changed files with 10 additions and 14 deletions

@@ -12,7 +12,7 @@ import os
 from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
-from transformers.models.auto import modeling_auto, modeling_task
+from transformers.models.auto import modeling_auto
 from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional, List, Dict
 from pathlib import Path
@@ -380,12 +380,14 @@ def get_model(
         logger.info(
             "TGI's flash enabled models could either not be loaded or are disabled, using Transformers fallback."
         )
-        try:
-            transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
-        except KeyError:
-            transformers_model_class = modeling_task.AutoForCausalLM
-        if transformers_model_class._supports_flash_attn_2:
+        transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
+        # Ugly check but works in the meantime
+        model_path = os.path.join(os.path.dirname(transformers.__file__), "models", model_type, f"modeling_{model_type}.py")
+        with open(model_path) as file:
+            has_fa2_class = f"FlashAttention2(" in file.read()
+        if transformers_model_class._supports_flash_attn_2 and not has_fa2_class:
             logger.info(
                 f"Transformers' {model_type} implementation supports ragged tensors format (single dimension for "
                 "batch and sequence length). All TGI's batching/caching optimizations are enabled."

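The hunk above drops the removed modeling_task.AutoForCausalLM fallback and instead makes a source-level decision: the optimized path is only taken when the architecture's modeling file no longer defines a dedicated *FlashAttention2 class. A minimal standalone sketch of that heuristic (the helper name is ours, not part of the commit):

    import os
    import transformers

    def has_legacy_fa2_class(model_type: str) -> bool:
        # Same string-matching heuristic as the "ugly check" above: if the
        # architecture's modeling file still defines a *FlashAttention2 class,
        # it predates the attention refactor and cannot take TGI's custom kernel.
        model_path = os.path.join(
            os.path.dirname(transformers.__file__),
            "models", model_type, f"modeling_{model_type}.py",
        )
        with open(model_path) as file:
            return "FlashAttention2(" in file.read()

    # Example: enable the fully optimized fallback only when the class is absent.
    use_ragged_attention = not has_legacy_fa2_class("llama")
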
@@ -52,12 +52,6 @@ def tgi_flash_attention_forward(
     key_states = key_states.transpose(1, 2).squeeze(dim=0)
     value_states = value_states.transpose(1, 2).squeeze(dim=0)
-    input_dtype = query_states.dtype
-    if input_dtype == torch.float32:
-        query_states = query_states.to(target_dtype)
-        key_states = key_states.to(target_dtype)
-        value_states = value_states.to(target_dtype)
     # Take care of updating the cache in-place
     kv_cache.store(
         key=key_states,
@@ -66,7 +60,6 @@ def tgi_flash_attention_forward(
         kv_scales=kv_scales
     )
     _, num_heads, head_dim = query_states.shape
     softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
     sliding_window = -1 if sliding_window is None else sliding_window
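
For attn_implementation="tgi" (set in the constructor change below) to resolve to the kernel modified above, the function has to be registered under that name with transformers. The registration itself is not part of this diff; a sketch of the wiring, assuming the installed transformers exposes the post-refactor registry as modeling_utils.ALL_ATTENTION_FUNCTIONS:

    import transformers

    # tgi_flash_attention_forward is the kernel defined in this file (hunks above).
    # Mapping it to the name "tgi" lets from_pretrained(..., attn_implementation="tgi")
    # route every layer's attention call through it.
    transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward
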
@@ -155,7 +148,8 @@ class TransformersFlashCausalLM(FlashCausalLM):
             device_map=("auto" if device_count > 1 else None),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
-            attn_implementation="tgi"
+            attn_implementation="tgi",
+            tp_plan="auto" if world_size > 1 else None,
         )
         if device_count == 1 and quantize != "bitsandbytes":
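
Putting the two new keyword arguments together, the load call amounts to the following; a hedged sketch only, where model_id, the dtype and world_size stand in for the surrounding attributes, and tp_plan is transformers' built-in tensor-parallel switch:

    import torch
    from transformers import AutoModelForCausalLM

    model_id = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder checkpoint
    world_size = 1                                 # placeholder; TGI derives this from the launcher

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        attn_implementation="tgi",                   # assumes the "tgi" kernel was registered as sketched earlier
        tp_plan="auto" if world_size > 1 else None,  # shard weights across ranks only when more than one is present
    )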