mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
push change
This commit is contained in:
parent
490ca0ef6a
commit
f843b62a44
@ -375,7 +375,8 @@ def get_model(
|
||||
)
|
||||
model_type = config_dict.get("model_type", None)
|
||||
|
||||
transformers_causal_lm_class = CausalLM
|
||||
# transformers_causal_lm_class = CausalLM
|
||||
transformers_causal_lm_class = TransformersFlashCausalLM
|
||||
if (
|
||||
not USE_CUSTOM_MODELING
|
||||
and model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||
|
@ -6,6 +6,7 @@ import torch
|
||||
from opentelemetry import trace
|
||||
from loguru import logger
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
|
||||
import transformers.modeling_utils
|
||||
|
||||
from text_generation_server.models.flash_causal_lm import (
|
||||
FlashCausalLMBatch,
|
||||
@ -31,36 +32,12 @@ from text_generation_server.models.metadata_kernels import block_tables_to_ragge
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
def patch_everywhere(
|
||||
attribute_name: str, patch: Any, module_name_prefix: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Finds all occurences of `attribute_name` in the loaded modules and patches them with `patch`.
|
||||
|
||||
Args:
|
||||
attribute_name (`str`):
|
||||
The name of attribute to patch.
|
||||
patch (`Any`):
|
||||
The patch for the attribute.
|
||||
module_name_prefix (`Optional[str]`, defaults to `None`):
|
||||
If set, only module names starting with this prefix will be considered for patching.
|
||||
"""
|
||||
# sys.modules may be updated while being iterated over, hence the list copy.
|
||||
for name in list(sys.modules):
|
||||
module = sys.modules[name]
|
||||
if module_name_prefix is not None and not name.startswith(module_name_prefix):
|
||||
continue
|
||||
if hasattr(module, attribute_name):
|
||||
setattr(module, attribute_name, patch)
|
||||
|
||||
|
||||
def _flash_attention_forward_patched(
|
||||
def tgi_flash_attention_forward(
|
||||
module,
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
query_length: int,
|
||||
is_causal: bool,
|
||||
softmax_scale: Optional[float] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
softcap: Optional[float] = None,
|
||||
@ -71,15 +48,15 @@ def _flash_attention_forward_patched(
|
||||
# This means no scale
|
||||
kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device))
|
||||
|
||||
# Correctly reshape the states
|
||||
_, _, num_heads, head_dim = query_states.size()
|
||||
_, _, num_kv_heads, _ = key_states.size()
|
||||
# query_states = query_states.view(-1, num_heads, head_dim)
|
||||
# key_states = key_states.view(-1, num_kv_heads, head_dim)
|
||||
# value_states = value_states.view(-1, num_kv_heads, head_dim)
|
||||
query_states = query_states.squeeze(dim=0)
|
||||
key_states = key_states.squeeze(dim=0)
|
||||
value_states = value_states.squeeze(dim=0)
|
||||
query_states = query_states.transpose(1, 2).squeeze(dim=0)
|
||||
key_states = key_states.transpose(1, 2).squeeze(dim=0)
|
||||
value_states = value_states.transpose(1, 2).squeeze(dim=0)
|
||||
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
# Take care of updating the cache in-place
|
||||
kv_cache.store(
|
||||
@ -89,6 +66,8 @@ def _flash_attention_forward_patched(
|
||||
kv_scales=kv_scales
|
||||
)
|
||||
|
||||
|
||||
_, num_heads, head_dim = query_states.shape
|
||||
softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
|
||||
sliding_window = -1 if sliding_window is None else sliding_window
|
||||
|
||||
@ -121,7 +100,10 @@ def _flash_attention_forward_patched(
|
||||
# attn_output = attn_output.view(attn_output.shape[0], -1)
|
||||
attn_output = attn_output.view(-1, num_heads * head_dim)
|
||||
|
||||
return attn_output
|
||||
return attn_output, None
|
||||
|
||||
|
||||
transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward
|
||||
|
||||
|
||||
class TransformersFlashCausalLM(FlashCausalLM):
|
||||
@ -174,8 +156,9 @@ class TransformersFlashCausalLM(FlashCausalLM):
|
||||
device_map=("auto" if device_count > 1 else None),
|
||||
load_in_8bit=quantize == "bitsandbytes",
|
||||
trust_remote_code=trust_remote_code,
|
||||
attn_implementation="flash_attention_2"
|
||||
attn_implementation="tgi"
|
||||
)
|
||||
|
||||
if device_count == 1 and quantize != "bitsandbytes":
|
||||
model = model.to(device)
|
||||
|
||||
@ -261,10 +244,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
def warmup(self, batch: FlashCausalLMBatch, max_input_tokens: Optional[int], max_total_tokens: Optional[int],):
|
||||
patch_everywhere("_flash_attention_forward", _flash_attention_forward_patched)
|
||||
return super().warmup(batch, max_input_tokens, max_total_tokens)
|
||||
|
||||
def forward(
|
||||
self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
@ -318,11 +297,13 @@ class TransformersFlashCausalLM(FlashCausalLM):
|
||||
max_q=batch.max_input_length,
|
||||
max_k=batch.max_current_length,
|
||||
)
|
||||
logits = self.model.forward(
|
||||
input_ids=input_ids[None, ...],
|
||||
# Use only the Model, not ModelForCausalLM
|
||||
hidden_states = self.model.model.forward(
|
||||
input_ids=input_ids[None, ...], # expand dim to easily fit transformers
|
||||
position_ids=position_ids[None, ...],
|
||||
past_key_values=None,
|
||||
use_cache=False, # we use self.kv_cache instead of transformers cache object
|
||||
return_dict=True,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
@ -330,10 +311,10 @@ class TransformersFlashCausalLM(FlashCausalLM):
|
||||
seqlen=seqlen,
|
||||
max_s=max_s,
|
||||
prefill_cache_indices=batch.prefill_cache_indices,
|
||||
lm_head_indices=lm_head_indices,
|
||||
kv_head_mapping=self.kv_head_mapping,
|
||||
).logits[0, ...]
|
||||
print("SUCCESSFUL FORWARD")
|
||||
)[0].squeeze(dim=0)
|
||||
# And compute logits from the lm_head, slicing correctly the indices
|
||||
logits = self.model.lm_head.forward(hidden_states[lm_head_indices])
|
||||
if batch.prefill_cache_indices is not None:
|
||||
batch.prefill_cache_indices = None
|
||||
return logits, None
|
||||
|
Loading…
Reference in New Issue
Block a user