mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00

latest transformers changes

parent f4c60ca522
commit 2e2631e093
@@ -12,7 +12,7 @@ import os
 
 from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
-from transformers.models.auto import modeling_auto, modeling_task
+from transformers.models.auto import modeling_auto
 from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional, List, Dict
 from pathlib import Path
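For context, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES maps a config's model_type string to the name of its *ForCausalLM class, which can then be resolved on the top-level transformers module. That is the lookup the next hunk keeps once modeling_task is dropped. A minimal sketch, with model_type hard-coded as an example value:

import transformers
from transformers.models.auto import modeling_auto

model_type = "llama"  # example value; get_model reads it from the model's config

# "llama" -> "LlamaForCausalLM", then resolve the class on the transformers package
class_name = modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
transformers_model_class = getattr(transformers, class_name)
print(transformers_model_class.__name__)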
@@ -380,12 +380,14 @@ def get_model(
         logger.info(
             "TGI's flash enabled models could either not be loaded or are disabled, using Transformers fallback."
         )
-        try:
-            transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
-        except KeyError:
-            transformers_model_class = modeling_task.AutoForCausalLM
+        transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
+
+        # Ugly check but works in the meantime
+        model_path = os.path.join(os.path.dirname(transformers.__file__), "models", model_type, f"modeling_{model_type}.py")
+        with open(model_path) as file:
+            has_fa2_class = f"FlashAttention2(" in file.read()
 
-        if transformers_model_class._supports_flash_attn_2:
+        if transformers_model_class._supports_flash_attn_2 and not has_fa2_class:
             logger.info(
                 f"Transformers' {model_type} implementation supports ragged tensors format (single dimension for "
                 "batch and sequence length). All TGI's batching/caching optimizations are enabled."
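The replacement block drops the modeling_task fallback and instead probes the installed transformers source for the model: if the stock modeling file still defines an old-style *FlashAttention2( class, the model presumably has not been ported to the newer attention interface and the fully optimized path is skipped. A standalone sketch of that probe, with model_type again hard-coded as an example:

import os
import transformers

model_type = "llama"  # example value

# Locate the stock modeling file shipped with the installed transformers version
model_path = os.path.join(
    os.path.dirname(transformers.__file__),
    "models",
    model_type,
    f"modeling_{model_type}.py",
)

with open(model_path) as file:
    # True if the file still defines a dedicated *FlashAttention2 attention class
    has_fa2_class = "FlashAttention2(" in file.read()

print(model_path, has_fa2_class)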
@@ -52,12 +52,6 @@ def tgi_flash_attention_forward(
     key_states = key_states.transpose(1, 2).squeeze(dim=0)
     value_states = value_states.transpose(1, 2).squeeze(dim=0)
 
-    input_dtype = query_states.dtype
-    if input_dtype == torch.float32:
-        query_states = query_states.to(target_dtype)
-        key_states = key_states.to(target_dtype)
-        value_states = value_states.to(target_dtype)
-
     # Take care of updating the cache in-place
     kv_cache.store(
         key=key_states,
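The context lines around the removed block flatten the key/value projections from the [1, num_heads, total_tokens, head_dim] layout transformers produces into the ragged [total_tokens, num_heads, head_dim] layout TGI's kernels and KV cache expect; the removed lines were an fp32-to-target_dtype cast that this commit drops. A shape-only sketch of that reshape, with illustrative sizes:

import torch

num_heads, total_tokens, head_dim = 8, 5, 64  # illustrative sizes
key_states = torch.randn(1, num_heads, total_tokens, head_dim)

# Same transform as the context lines above:
# [1, num_heads, total_tokens, head_dim] -> [total_tokens, num_heads, head_dim]
key_states = key_states.transpose(1, 2).squeeze(dim=0)
print(key_states.shape)  # torch.Size([5, 8, 64])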
@@ -66,7 +60,6 @@ def tgi_flash_attention_forward(
         kv_scales=kv_scales
     )
 
-
     _, num_heads, head_dim = query_states.shape
     softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
     sliding_window = -1 if sliding_window is None else sliding_window
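The remaining context shows the defaults applied when the caller does not pass these values: the softmax scale falls back to 1/sqrt(head_dim) and sliding_window to -1 (disabled). A trivial check of the scale default, with an illustrative head_dim:

import math

head_dim = 64  # illustrative
softmax_scale = None  # example: caller did not provide one
softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
print(softmax_scale)  # 0.125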
@@ -155,7 +148,8 @@ class TransformersFlashCausalLM(FlashCausalLM):
             device_map=("auto" if device_count > 1 else None),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
-            attn_implementation="tgi"
+            attn_implementation="tgi",
+            tp_plan="auto" if world_size > 1 else None,
         )
 
         if device_count == 1 and quantize != "bitsandbytes":
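The last hunk adds tp_plan to the from_pretrained call so that, when more than one rank is launched, transformers itself shards the weights with its tensor-parallel plan; attn_implementation="tgi" presumably refers to the tgi_flash_attention_forward shown above, registered with transformers under that name. A hedged sketch of an equivalent call with stock transformers, where the model_id and the use of "sdpa" instead of the registered "tgi" implementation are stand-ins:

import os
from transformers import AutoModelForCausalLM

model_id = "meta-llama/Llama-3.1-8B-Instruct"  # example checkpoint
world_size = int(os.environ.get("WORLD_SIZE", "1"))

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    # TGI passes its registered "tgi" attention implementation here;
    # "sdpa" keeps this sketch runnable with stock transformers.
    attn_implementation="sdpa",
    # Added by this commit: let transformers apply its tensor-parallel plan
    # when the server runs with more than one process (e.g. under torchrun).
    tp_plan="auto" if world_size > 1 else None,
)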