diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index bf481e29..a1359212 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -12,7 +12,7 @@
 import os
 from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
-from transformers.models.auto import modeling_auto, modeling_task
+from transformers.models.auto import modeling_auto
 from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional, List, Dict
 from pathlib import Path
@@ -380,12 +380,14 @@ def get_model(
         logger.info(
             "TGI's flash enabled models could either not be loaded or are disabled, using Transformers fallback."
         )
 
-        try:
-            transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
-        except KeyError:
-            transformers_model_class = modeling_task.AutoForCausalLM
-        if transformers_model_class._supports_flash_attn_2:
+        transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
+        # Ugly check but works in the meantime
+        model_path = os.path.join(os.path.dirname(transformers.__file__), "models", model_type, f"modeling_{model_type}.py")
+        with open(model_path) as file:
+            has_fa2_class = f"FlashAttention2(" in file.read()
+
+        if transformers_model_class._supports_flash_attn_2 and not has_fa2_class:
             logger.info(
                 f"Transformers' {model_type} implementation supports ragged tensors format (single dimension for "
                 "batch and sequence length). All TGI's batching/caching optimizations are enabled."
diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py
index 7bcad8aa..49dcac62 100644
--- a/server/text_generation_server/models/transformers_flash_causal_lm.py
+++ b/server/text_generation_server/models/transformers_flash_causal_lm.py
@@ -52,12 +52,6 @@ def tgi_flash_attention_forward(
     key_states = key_states.transpose(1, 2).squeeze(dim=0)
     value_states = value_states.transpose(1, 2).squeeze(dim=0)
 
-    input_dtype = query_states.dtype
-    if input_dtype == torch.float32:
-        query_states = query_states.to(target_dtype)
-        key_states = key_states.to(target_dtype)
-        value_states = value_states.to(target_dtype)
-
     # Take care of updating the cache in-place
     kv_cache.store(
         key=key_states,
@@ -66,7 +60,6 @@
         kv_scales=kv_scales
     )
 
-    _, num_heads, head_dim = query_states.shape
     softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
     sliding_window = -1 if sliding_window is None else sliding_window
 
@@ -155,7 +148,8 @@ class TransformersFlashCausalLM(FlashCausalLM):
             device_map=("auto" if device_count > 1 else None),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
-            attn_implementation="tgi"
+            attn_implementation="tgi",
+            tp_plan="auto" if world_size > 1 else None,
         )
 
         if device_count == 1 and quantize != "bitsandbytes":
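
The "Ugly check but works in the meantime" in `get_model` decides whether the installed Transformers implementation has been ported to the new attention-interface style: older implementations still define a dedicated `*FlashAttention2` class inside their `modeling_<model_type>.py`, so the code scans that file for the substring `FlashAttention2(` and only enables the ragged-tensor fallback when it is absent. Below is a minimal standalone sketch of that heuristic; the helper name `uses_new_attention_interface` and the `"llama"` example are illustrative and not part of the PR, and the `_supports_flash_attn_2` attribute is assumed to exist as in the Transformers version this PR targets.

```python
import os

import transformers
from transformers.models.auto import modeling_auto


def uses_new_attention_interface(model_type: str) -> bool:
    """Heuristic mirroring the PR: if modeling_<model_type>.py still defines a
    dedicated ``*FlashAttention2`` class, the model has not been refactored to
    the new attention-interface style and TGI cannot plug in its own kernel."""
    model_path = os.path.join(
        os.path.dirname(transformers.__file__),
        "models",
        model_type,
        f"modeling_{model_type}.py",
    )
    with open(model_path) as file:
        return "FlashAttention2(" not in file.read()


# Usage sketch, reproducing the condition used in get_model for one model type.
model_type = "llama"  # illustrative choice
model_class = getattr(
    transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
)
if model_class._supports_flash_attn_2 and uses_new_attention_interface(model_type):
    print(f"{model_type}: ragged-tensor Transformers fallback can be used")
else:
    print(f"{model_type}: falling back to the plain CausalLM path")
```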