Mirror of https://github.com/huggingface/text-generation-inference.git
Commit: f843b62a44 ("push change")
Parent: 490ca0ef6a
@@ -375,7 +375,8 @@ def get_model(
     )
     model_type = config_dict.get("model_type", None)

-    transformers_causal_lm_class = CausalLM
+    # transformers_causal_lm_class = CausalLM
+    transformers_causal_lm_class = TransformersFlashCausalLM
     if (
         not USE_CUSTOM_MODELING
         and model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
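Note: the hunk above is the dispatch change in get_model, which now routes generic causal LMs to TransformersFlashCausalLM; the remaining hunks modify the TransformersFlashCausalLM implementation itself. The guard keys off transformers' auto-model registry, which can be inspected directly; a minimal check, with "llama" as an arbitrary example key:

from transformers.models.auto import modeling_auto

# Maps config.model_type -> name of the stock AutoModelForCausalLM class.
print("llama" in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)  # True
print(len(modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES))        # number of supported architectures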
@@ -6,6 +6,7 @@ import torch
 from opentelemetry import trace
 from loguru import logger
 from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+import transformers.modeling_utils

 from text_generation_server.models.flash_causal_lm import (
     FlashCausalLMBatch,
@@ -31,36 +32,12 @@ from text_generation_server.models.metadata_kernels import block_tables_to_ragged
 tracer = trace.get_tracer(__name__)


-def patch_everywhere(
-    attribute_name: str, patch: Any, module_name_prefix: Optional[str] = None
-):
-    """
-    Finds all occurences of `attribute_name` in the loaded modules and patches them with `patch`.
-
-    Args:
-        attribute_name (`str`):
-            The name of attribute to patch.
-        patch (`Any`):
-            The patch for the attribute.
-        module_name_prefix (`Optional[str]`, defaults to `None`):
-            If set, only module names starting with this prefix will be considered for patching.
-    """
-    # sys.modules may be updated while being iterated over, hence the list copy.
-    for name in list(sys.modules):
-        module = sys.modules[name]
-        if module_name_prefix is not None and not name.startswith(module_name_prefix):
-            continue
-        if hasattr(module, attribute_name):
-            setattr(module, attribute_name, patch)
-
-
-def _flash_attention_forward_patched(
+def tgi_flash_attention_forward(
+    module,
     query_states: torch.Tensor,
     key_states: torch.Tensor,
     value_states: torch.Tensor,
     attention_mask: torch.Tensor,
-    query_length: int,
-    is_causal: bool,
     softmax_scale: Optional[float] = None,
     sliding_window: Optional[int] = None,
     softcap: Optional[float] = None,
@@ -71,15 +48,15 @@ def _flash_attention_forward_patched(
     # This means no scale
     kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device))

-    # Correctly reshape the states
-    _, _, num_heads, head_dim = query_states.size()
-    _, _, num_kv_heads, _ = key_states.size()
-    # query_states = query_states.view(-1, num_heads, head_dim)
-    # key_states = key_states.view(-1, num_kv_heads, head_dim)
-    # value_states = value_states.view(-1, num_kv_heads, head_dim)
-    query_states = query_states.squeeze(dim=0)
-    key_states = key_states.squeeze(dim=0)
-    value_states = value_states.squeeze(dim=0)
+    query_states = query_states.transpose(1, 2).squeeze(dim=0)
+    key_states = key_states.transpose(1, 2).squeeze(dim=0)
+    value_states = value_states.transpose(1, 2).squeeze(dim=0)
+
+    input_dtype = query_states.dtype
+    if input_dtype == torch.float32:
+        query_states = query_states.to(target_dtype)
+        key_states = key_states.to(target_dtype)
+        value_states = value_states.to(target_dtype)

     # Take care of updating the cache in-place
     kv_cache.store(
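Note: a quick shape check for the reshapes introduced above. transformers hands the attention function (batch, num_heads, seq_len, head_dim) tensors, while TGI's kernels and KV cache work on a flat (num_tokens, num_heads, head_dim) layout; the dimensions below are arbitrary:

import torch

q = torch.randn(1, 32, 7, 128)             # (batch=1, num_heads, seq_len, head_dim)
q_flat = q.transpose(1, 2).squeeze(dim=0)  # -> (seq_len, num_heads, head_dim)
assert q_flat.shape == (7, 32, 128)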
@@ -89,6 +66,8 @@ def _flash_attention_forward_patched(
         kv_scales=kv_scales
     )

+
+    _, num_heads, head_dim = query_states.shape
     softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
     sliding_window = -1 if sliding_window is None else sliding_window

@@ -121,7 +100,10 @@ def _flash_attention_forward_patched(
     # attn_output = attn_output.view(attn_output.shape[0], -1)
     attn_output = attn_output.view(-1, num_heads * head_dim)

-    return attn_output
+    return attn_output, None
+
+
+transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward


 class TransformersFlashCausalLM(FlashCausalLM):
@@ -174,8 +156,9 @@ class TransformersFlashCausalLM(FlashCausalLM):
             device_map=("auto" if device_count > 1 else None),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
-            attn_implementation="flash_attention_2"
+            attn_implementation="tgi"
         )

         if device_count == 1 and quantize != "bitsandbytes":
             model = model.to(device)
+
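Note: the two pieces above, registering tgi_flash_attention_forward under ALL_ATTENTION_FUNCTIONS["tgi"] and loading the model with attn_implementation="tgi", rely on transformers' pluggable attention interface. A self-contained sketch of that mechanism, assuming a recent transformers release; the "demo" name, the SDPA fallback, and the model id are placeholders, not part of this commit:

import torch
import transformers.modeling_utils
from transformers import AutoModelForCausalLM, AutoTokenizer


def demo_attention_forward(module, query, key, value, attention_mask, **kwargs):
    # query/key/value arrive as (batch, num_heads, seq_len, head_dim).
    # A real backend (like the "tgi" entry) would call its paged/flash kernels
    # here; this sketch just falls back to PyTorch SDPA.
    n_rep = query.shape[1] // key.shape[1]
    if n_rep > 1:  # expand grouped KV heads to match the query head count
        key = key.repeat_interleave(n_rep, dim=1)
        value = value.repeat_interleave(n_rep, dim=1)
    out = torch.nn.functional.scaled_dot_product_attention(
        query, key, value,
        attn_mask=attention_mask,
        scale=kwargs.get("scaling"),
        is_causal=attention_mask is None and query.shape[2] > 1,
    )
    # transformers expects (attn_output, attn_weights); weights may be None.
    return out.transpose(1, 2).contiguous(), None


# Register under a custom name...
transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["demo"] = demo_attention_forward

# ...and every attention layer of the loaded model dispatches to it.
model_id = "HuggingFaceTB/SmolLM2-135M"  # placeholder small model
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="demo")
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello", return_tensors="pt")
print(model(**inputs).logits.shape)

In the commit, the TGI-specific arguments that forward() passes to the model (kv_cache, cu_seqlen_prefill, block_tables, seqlen, ...) presumably reach tgi_flash_attention_forward through **kwargs, which is how it can store into and read from the paged KV cache instead of transformers' own cache object.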
@@ -261,10 +244,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
             trust_remote_code=trust_remote_code,
         )

-    def warmup(self, batch: FlashCausalLMBatch, max_input_tokens: Optional[int], max_total_tokens: Optional[int],):
-        patch_everywhere("_flash_attention_forward", _flash_attention_forward_patched)
-        return super().warmup(batch, max_input_tokens, max_total_tokens)
-
     def forward(
         self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -318,11 +297,13 @@ class TransformersFlashCausalLM(FlashCausalLM):
             max_q=batch.max_input_length,
             max_k=batch.max_current_length,
         )
-        logits = self.model.forward(
-            input_ids=input_ids[None, ...],
+        # Use only the Model, not ModelForCausalLM
+        hidden_states = self.model.model.forward(
+            input_ids=input_ids[None, ...], # expand dim to easily fit transformers
             position_ids=position_ids[None, ...],
             past_key_values=None,
             use_cache=False, # we use self.kv_cache instead of transformers cache object
+            return_dict=True,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=kv_cache,
             block_tables=block_tables,
@@ -330,10 +311,10 @@ class TransformersFlashCausalLM(FlashCausalLM):
             seqlen=seqlen,
             max_s=max_s,
             prefill_cache_indices=batch.prefill_cache_indices,
-            lm_head_indices=lm_head_indices,
             kv_head_mapping=self.kv_head_mapping,
-        ).logits[0, ...]
-        print("SUCCESSFUL FORWARD")
+        )[0].squeeze(dim=0)
+        # And compute logits from the lm_head, slicing correctly the indices
+        logits = self.model.lm_head.forward(hidden_states[lm_head_indices])
         if batch.prefill_cache_indices is not None:
             batch.prefill_cache_indices = None
         return logits, None
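Note: the last two hunks split the forward pass into "backbone first, then LM head on selected rows", so logits are only computed for the positions TGI actually needs (lm_head_indices). A rough standalone illustration of the same pattern, assuming a Llama-style wrapper that exposes .model and .lm_head; the model id is a placeholder and none of TGI's extra kwargs are shown:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "HuggingFaceTB/SmolLM2-135M"  # placeholder small model
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("The quick brown fox", return_tensors="pt")

with torch.no_grad():
    # Run only the decoder stack, not the *ForCausalLM wrapper.
    hidden_states = model.model(
        input_ids=inputs.input_ids, use_cache=False
    ).last_hidden_state.squeeze(dim=0)           # (seq_len, hidden_size)

    # During prefill, only the last position of each sequence needs logits.
    lm_head_indices = torch.tensor([hidden_states.shape[0] - 1])
    logits = model.lm_head(hidden_states[lm_head_indices])

print(logits.shape)  # (1, vocab_size)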