push change

This commit is contained in:
Cyril Vallez 2024-12-12 18:25:11 +00:00
parent 490ca0ef6a
commit f843b62a44
2 changed files with 29 additions and 47 deletions

View File

@@ -375,7 +375,8 @@ def get_model(
         )
     model_type = config_dict.get("model_type", None)
 
-    transformers_causal_lm_class = CausalLM
+    # transformers_causal_lm_class = CausalLM
+    transformers_causal_lm_class = TransformersFlashCausalLM
     if (
         not USE_CUSTOM_MODELING
         and model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
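For context: transformers_causal_lm_class is the class that get_model falls back to when no custom TGI modeling covers the detected model_type, so rebinding it switches that fallback from the eager CausalLM path to the transformers-backed flash path. A minimal, self-contained sketch of the selection pattern, with stand-in classes and an assumed mapping (not the real get_model code):

# Illustrative sketch only; stand-in classes and mapping, not the actual TGI dispatch.
USE_CUSTOM_MODELING = False
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = {"llama", "gemma", "qwen2"}  # assumed contents

class CausalLM:                      # eager, non-flash fallback (stand-in)
    ...

class TransformersFlashCausalLM:     # transformers weights + TGI flash attention (stand-in)
    ...

def pick_fallback(model_type: str):
    transformers_causal_lm_class = TransformersFlashCausalLM  # was CausalLM before this commit
    if (
        not USE_CUSTOM_MODELING
        and model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
    ):
        return transformers_causal_lm_class
    return CausalLM

print(pick_fallback("gemma").__name__)  # prints: TransformersFlashCausalLM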

View File

@@ -6,6 +6,7 @@ import torch
 from opentelemetry import trace
 from loguru import logger
 from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+import transformers.modeling_utils
 
 from text_generation_server.models.flash_causal_lm import (
     FlashCausalLMBatch,
@@ -31,36 +32,12 @@ from text_generation_server.models.metadata_kernels import block_tables_to_ragge
 tracer = trace.get_tracer(__name__)
 
 
-def patch_everywhere(
-    attribute_name: str, patch: Any, module_name_prefix: Optional[str] = None
-):
-    """
-    Finds all occurences of `attribute_name` in the loaded modules and patches them with `patch`.
-
-    Args:
-        attribute_name (`str`):
-            The name of attribute to patch.
-        patch (`Any`):
-            The patch for the attribute.
-        module_name_prefix (`Optional[str]`, defaults to `None`):
-            If set, only module names starting with this prefix will be considered for patching.
-    """
-    # sys.modules may be updated while being iterated over, hence the list copy.
-    for name in list(sys.modules):
-        module = sys.modules[name]
-        if module_name_prefix is not None and not name.startswith(module_name_prefix):
-            continue
-        if hasattr(module, attribute_name):
-            setattr(module, attribute_name, patch)
-
-
-def _flash_attention_forward_patched(
+def tgi_flash_attention_forward(
+    module,
     query_states: torch.Tensor,
     key_states: torch.Tensor,
     value_states: torch.Tensor,
     attention_mask: torch.Tensor,
-    query_length: int,
-    is_causal: bool,
     softmax_scale: Optional[float] = None,
     sliding_window: Optional[int] = None,
     softcap: Optional[float] = None,
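The new signature drops query_length and is_causal and takes the attention module as its first argument, matching the calling convention recent transformers versions use for functions registered in ALL_ATTENTION_FUNCTIONS: (module, query, key, value, attention_mask, **kwargs), returning an (attn_output, attn_weights) pair. A minimal sketch of a function that satisfies that contract, using plain SDPA instead of TGI's kernels (illustration only):

from typing import Optional, Tuple
import torch

def toy_attention_forward(
    module: torch.nn.Module,
    query_states: torch.Tensor,   # [batch, num_heads, q_len, head_dim]
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # Stand-in body: plain scaled-dot-product attention, no paged KV cache.
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query_states, key_states, value_states, attn_mask=attention_mask
    )
    # Hand back [batch, q_len, num_heads, head_dim] plus optional attention weights.
    return attn_output.transpose(1, 2).contiguous(), None

q = k = v = torch.randn(1, 8, 5, 64)
out, _ = toy_attention_forward(torch.nn.Identity(), q, k, v, None)
print(out.shape)  # torch.Size([1, 5, 8, 64])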
@@ -71,15 +48,15 @@ def _flash_attention_forward_patched(
     # This means no scale
     kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device))
 
-    # Correctly reshape the states
-    _, _, num_heads, head_dim = query_states.size()
-    _, _, num_kv_heads, _ = key_states.size()
-    # query_states = query_states.view(-1, num_heads, head_dim)
-    # key_states = key_states.view(-1, num_kv_heads, head_dim)
-    # value_states = value_states.view(-1, num_kv_heads, head_dim)
-    query_states = query_states.squeeze(dim=0)
-    key_states = key_states.squeeze(dim=0)
-    value_states = value_states.squeeze(dim=0)
+    query_states = query_states.transpose(1, 2).squeeze(dim=0)
+    key_states = key_states.transpose(1, 2).squeeze(dim=0)
+    value_states = value_states.transpose(1, 2).squeeze(dim=0)
+
+    input_dtype = query_states.dtype
+    if input_dtype == torch.float32:
+        query_states = query_states.to(target_dtype)
+        key_states = key_states.to(target_dtype)
+        value_states = value_states.to(target_dtype)
 
     # Take care of updating the cache in-place
     kv_cache.store(
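The transpose-and-squeeze replaces the old reshape logic: transformers hands the states over as [batch, num_heads, seq_len, head_dim], while TGI's paged KV cache and flash kernels expect a token-major [total_tokens, num_heads, head_dim] layout, with the batch dimension (always 1, since sequences are concatenated) dropped. The float32 branch casts to target_dtype, which is defined outside this hunk, since the flash kernels only run in half precision. A quick shape illustration with made-up sizes:

import torch

batch, num_heads, seq_len, head_dim = 1, 8, 5, 64   # made-up sizes
query_states = torch.randn(batch, num_heads, seq_len, head_dim)

# [1, 8, 5, 64] -> [1, 5, 8, 64] -> [5, 8, 64]: one row per token, as the
# paged KV cache and the flash kernels expect.
query_states = query_states.transpose(1, 2).squeeze(dim=0)
print(query_states.shape)  # torch.Size([5, 8, 64])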
@@ -89,6 +66,8 @@ def _flash_attention_forward_patched(
         kv_scales=kv_scales
     )
 
+    _, num_heads, head_dim = query_states.shape
+
     softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
     sliding_window = -1 if sliding_window is None else sliding_window
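Because the helper no longer receives explicit head counts, they are read back from the reshaped query, and the usual defaults are applied: a softmax scale of 1 / sqrt(head_dim), and -1 to signal the absence of a sliding window to the kernel. For example, with a made-up tensor:

import math
import torch

query_states = torch.randn(5, 8, 64)      # [tokens, num_heads, head_dim], made up
_, num_heads, head_dim = query_states.shape
softmax_scale = 1 / math.sqrt(head_dim)   # 0.125 for head_dim=64
sliding_window = -1                       # assumed convention: -1 disables the window
print(num_heads, head_dim, softmax_scale)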
@@ -121,7 +100,10 @@ def _flash_attention_forward_patched(
     # attn_output = attn_output.view(attn_output.shape[0], -1)
     attn_output = attn_output.view(-1, num_heads * head_dim)
 
-    return attn_output
+    return attn_output, None
+
+
+transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward
 
 
 class TransformersFlashCausalLM(FlashCausalLM):
@@ -174,8 +156,9 @@ class TransformersFlashCausalLM(FlashCausalLM):
             device_map=("auto" if device_count > 1 else None),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
-            attn_implementation="flash_attention_2"
+            attn_implementation="tgi"
         )
 
         if device_count == 1 and quantize != "bitsandbytes":
            model = model.to(device)
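Together with the ALL_ATTENTION_FUNCTIONS["tgi"] registration in the previous hunk, this is what replaces the old patch_everywhere monkey-patching: the attention function is registered once under the key "tgi", and any model loaded with attn_implementation="tgi" routes its attention calls through it. A hedged end-to-end sketch, assuming a transformers version that supports custom entries registered this way (the model id is illustrative):

import torch
import transformers.modeling_utils
from transformers import AutoModelForCausalLM

def tgi_flash_attention_forward(module, query_states, key_states, value_states,
                                attention_mask, **kwargs):
    # Placeholder body; the real implementation calls TGI's flash/paged kernels.
    out = torch.nn.functional.scaled_dot_product_attention(
        query_states, key_states, value_states, attn_mask=attention_mask
    )
    return out.transpose(1, 2).contiguous(), None

# Register once, then select the implementation by name when loading.
transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",   # illustrative model id; any supported causal LM works
    torch_dtype=torch.float16,
    attn_implementation="tgi",      # selects the function registered above
)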
@@ -261,10 +244,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
             trust_remote_code=trust_remote_code,
         )
 
-    def warmup(self, batch: FlashCausalLMBatch, max_input_tokens: Optional[int], max_total_tokens: Optional[int],):
-        patch_everywhere("_flash_attention_forward", _flash_attention_forward_patched)
-        return super().warmup(batch, max_input_tokens, max_total_tokens)
-
     def forward(
         self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -318,11 +297,13 @@ class TransformersFlashCausalLM(FlashCausalLM):
                 max_q=batch.max_input_length,
                 max_k=batch.max_current_length,
             )
-        logits = self.model.forward(
-            input_ids=input_ids[None, ...],
+        # Use only the Model, not ModelForCausalLM
+        hidden_states = self.model.model.forward(
+            input_ids=input_ids[None, ...],  # expand dim to easily fit transformers
             position_ids=position_ids[None, ...],
             past_key_values=None,
             use_cache=False,  # we use self.kv_cache instead of transformers cache object
+            return_dict=True,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
@@ -330,10 +311,10 @@ class TransformersFlashCausalLM(FlashCausalLM):
             seqlen=seqlen,
             max_s=max_s,
             prefill_cache_indices=batch.prefill_cache_indices,
-            lm_head_indices=lm_head_indices,
             kv_head_mapping=self.kv_head_mapping,
-        ).logits[0, ...]
-        print("SUCCESSFUL FORWARD")
+        )[0].squeeze(dim=0)
+        # And compute logits from the lm_head, slicing correctly the indices
+        logits = self.model.lm_head.forward(hidden_states[lm_head_indices])
         if batch.prefill_cache_indices is not None:
             batch.prefill_cache_indices = None
         return logits, None
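The rewritten forward thus runs the bare transformer (self.model.model) to get hidden states, then applies lm_head only to the rows selected by lm_head_indices, typically the last position of each sequence during prefill, instead of materialising logits for every token. A small illustration of why the slicing pays off, with made-up sizes:

import torch

hidden_size, vocab_size = 4096, 32000            # made-up model dimensions
hidden_states = torch.randn(10, hidden_size)     # [total_tokens, hidden_size]
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

lm_head_indices = torch.tensor([3, 9])           # last token of each of two sequences
logits = lm_head(hidden_states[lm_head_indices]) # [2, 32000] instead of [10, 32000]
print(logits.shape)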