push change

This commit is contained in:
Cyril Vallez 2024-12-12 18:25:11 +00:00
parent 490ca0ef6a
commit f843b62a44
2 changed files with 29 additions and 47 deletions


@@ -375,7 +375,8 @@ def get_model(
)
model_type = config_dict.get("model_type", None)
transformers_causal_lm_class = CausalLM
# transformers_causal_lm_class = CausalLM
transformers_causal_lm_class = TransformersFlashCausalLM
if (
not USE_CUSTOM_MODELING
and model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
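For context, a minimal sketch of the dispatch this hunk feeds into, with the branch body simplified and a small set of names standing in for modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES (illustrative only, not the full get_model logic):

def pick_causal_lm_class(model_type: str, use_custom_modeling: bool) -> str:
    # Stand-in for MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
    supported_architectures = {"llama", "mistral", "gemma"}
    if not use_custom_modeling and model_type in supported_architectures:
        # New default in this commit: the transformers-backed flash path
        return "TransformersFlashCausalLM"
    return "CausalLM"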


@@ -6,6 +6,7 @@ import torch
from opentelemetry import trace
from loguru import logger
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import transformers.modeling_utils
from text_generation_server.models.flash_causal_lm import (
FlashCausalLMBatch,
@@ -31,36 +32,12 @@ from text_generation_server.models.metadata_kernels import block_tables_to_ragge
tracer = trace.get_tracer(__name__)
def patch_everywhere(
attribute_name: str, patch: Any, module_name_prefix: Optional[str] = None
):
"""
Finds all occurrences of `attribute_name` in the loaded modules and patches them with `patch`.
Args:
attribute_name (`str`):
The name of the attribute to patch.
patch (`Any`):
The patch for the attribute.
module_name_prefix (`Optional[str]`, defaults to `None`):
If set, only module names starting with this prefix will be considered for patching.
"""
# sys.modules may be updated while being iterated over, hence the list copy.
for name in list(sys.modules):
module = sys.modules[name]
if module_name_prefix is not None and not name.startswith(module_name_prefix):
continue
if hasattr(module, attribute_name):
setattr(module, attribute_name, patch)
def _flash_attention_forward_patched(
def tgi_flash_attention_forward(
module,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
attention_mask: torch.Tensor,
query_length: int,
is_causal: bool,
softmax_scale: Optional[float] = None,
sliding_window: Optional[int] = None,
softcap: Optional[float] = None,
@@ -71,15 +48,15 @@ def _flash_attention_forward_patched(
# This means no scale
kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device))
# Correctly reshape the states
_, _, num_heads, head_dim = query_states.size()
_, _, num_kv_heads, _ = key_states.size()
# query_states = query_states.view(-1, num_heads, head_dim)
# key_states = key_states.view(-1, num_kv_heads, head_dim)
# value_states = value_states.view(-1, num_kv_heads, head_dim)
query_states = query_states.squeeze(dim=0)
key_states = key_states.squeeze(dim=0)
value_states = value_states.squeeze(dim=0)
query_states = query_states.transpose(1, 2).squeeze(dim=0)
key_states = key_states.transpose(1, 2).squeeze(dim=0)
value_states = value_states.transpose(1, 2).squeeze(dim=0)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
# Take care of updating the cache in-place
kv_cache.store(
@@ -89,6 +66,8 @@ def _flash_attention_forward_patched(
kv_scales=kv_scales
)
_, num_heads, head_dim = query_states.shape
softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
sliding_window = -1 if sliding_window is None else sliding_window
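For context, a minimal sketch of the layout change and scale default above, assuming transformers hands the states in as [batch=1, num_heads, seq_len, head_dim] while TGI's flash kernels expect [num_tokens, num_heads, head_dim] (shapes are illustrative):

import math
import torch

query_states = torch.randn(1, 32, 128, 64)                  # [1, num_heads, seq_len, head_dim]
query_states = query_states.transpose(1, 2).squeeze(dim=0)  # [128, 32, 64] = [tokens, heads, head_dim]
softmax_scale = 1 / math.sqrt(query_states.shape[-1])       # default scale is 1/sqrt(head_dim)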
@@ -121,7 +100,10 @@ def _flash_attention_forward_patched(
# attn_output = attn_output.view(attn_output.shape[0], -1)
attn_output = attn_output.view(-1, num_heads * head_dim)
return attn_output
return attn_output, None
transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward
class TransformersFlashCausalLM(FlashCausalLM):
@@ -174,8 +156,9 @@ class TransformersFlashCausalLM(FlashCausalLM):
device_map=("auto" if device_count > 1 else None),
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
attn_implementation="flash_attention_2"
attn_implementation="tgi"
)
if device_count == 1 and quantize != "bitsandbytes":
model = model.to(device)
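For context, a minimal sketch of the registration mechanism relied on here: a callable is added to transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS under a custom key, and that key is then passed as attn_implementation when loading the model. The kernel below is a stand-in (plain SDPA); the real tgi_flash_attention_forward dispatches to TGI's paged flash attention and KV cache:

import torch
import transformers.modeling_utils

def my_attention(module, query, key, value, attention_mask, **kwargs):
    # query/key/value arrive as [batch, num_heads, seq_len, head_dim]
    out = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask
    )
    # transformers expects (attn_output, attn_weights), with the output
    # laid out as [batch, seq_len, num_heads, head_dim]
    return out.transpose(1, 2).contiguous(), None

transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["my-attn"] = my_attention
# A model built on the new attention interface can then be loaded with
# AutoModelForCausalLM.from_pretrained(..., attn_implementation="my-attn")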
@@ -261,10 +244,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
trust_remote_code=trust_remote_code,
)
def warmup(self, batch: FlashCausalLMBatch, max_input_tokens: Optional[int], max_total_tokens: Optional[int],):
patch_everywhere("_flash_attention_forward", _flash_attention_forward_patched)
return super().warmup(batch, max_input_tokens, max_total_tokens)
def forward(
self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -318,11 +297,13 @@ class TransformersFlashCausalLM(FlashCausalLM):
max_q=batch.max_input_length,
max_k=batch.max_current_length,
)
logits = self.model.forward(
input_ids=input_ids[None, ...],
# Use only the Model, not ModelForCausalLM
hidden_states = self.model.model.forward(
input_ids=input_ids[None, ...], # expand dim to easily fit transformers
position_ids=position_ids[None, ...],
past_key_values=None,
use_cache=False, # we use self.kv_cache instead of transformers cache object
return_dict=True,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
@@ -330,10 +311,10 @@ class TransformersFlashCausalLM(FlashCausalLM):
seqlen=seqlen,
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
kv_head_mapping=self.kv_head_mapping,
).logits[0, ...]
print("SUCCESSFUL FORWARD")
)[0].squeeze(dim=0)
# And compute logits from the lm_head, correctly slicing the indices
logits = self.model.lm_head.forward(hidden_states[lm_head_indices])
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
return logits, None
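For context, a minimal sketch of the split forward above: run the backbone once over all tokens, then apply the LM head only at the positions whose logits are actually needed (sizes and indices below are illustrative):

import torch

hidden_size, vocab_size = 4096, 32000
hidden_states = torch.randn(256, hidden_size)       # backbone output, [num_tokens, hidden_size]
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
lm_head_indices = torch.tensor([127, 255])          # e.g. last token of each prefill sequence
logits = lm_head(hidden_states[lm_head_indices])    # [len(lm_head_indices), vocab_size]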