diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 35ab8ede..54665083 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -375,7 +375,8 @@ def get_model(
     )
     model_type = config_dict.get("model_type", None)
 
-    transformers_causal_lm_class = CausalLM
+    # transformers_causal_lm_class = CausalLM
+    transformers_causal_lm_class = TransformersFlashCausalLM
     if (
         not USE_CUSTOM_MODELING
         and model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py
index abfaa06e..d71e75b4 100644
--- a/server/text_generation_server/models/transformers_flash_causal_lm.py
+++ b/server/text_generation_server/models/transformers_flash_causal_lm.py
@@ -6,6 +6,7 @@ import torch
 from opentelemetry import trace
 from loguru import logger
 from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+import transformers.modeling_utils
 
 from text_generation_server.models.flash_causal_lm import (
     FlashCausalLMBatch,
@@ -31,36 +32,12 @@ from text_generation_server.models.metadata_kernels import block_tables_to_ragge
 
 tracer = trace.get_tracer(__name__)
 
-def patch_everywhere(
-    attribute_name: str, patch: Any, module_name_prefix: Optional[str] = None
-):
-    """
-    Finds all occurences of `attribute_name` in the loaded modules and patches them with `patch`.
-
-    Args:
-        attribute_name (`str`):
-            The name of attribute to patch.
-        patch (`Any`):
-            The patch for the attribute.
-        module_name_prefix (`Optional[str]`, defaults to `None`):
-            If set, only module names starting with this prefix will be considered for patching.
-    """
-    # sys.modules may be updated while being iterated over, hence the list copy.
-    for name in list(sys.modules):
-        module = sys.modules[name]
-        if module_name_prefix is not None and not name.startswith(module_name_prefix):
-            continue
-        if hasattr(module, attribute_name):
-            setattr(module, attribute_name, patch)
-
-
-def _flash_attention_forward_patched(
+def tgi_flash_attention_forward(
+    module,
     query_states: torch.Tensor,
     key_states: torch.Tensor,
     value_states: torch.Tensor,
     attention_mask: torch.Tensor,
-    query_length: int,
-    is_causal: bool,
     softmax_scale: Optional[float] = None,
     sliding_window: Optional[int] = None,
     softcap: Optional[float] = None,
@@ -71,15 +48,15 @@ def _flash_attention_forward_patched(
     # This means no scale
     kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device))
 
-    # Correctly reshape the states
-    _, _, num_heads, head_dim = query_states.size()
-    _, _, num_kv_heads, _ = key_states.size()
-    # query_states = query_states.view(-1, num_heads, head_dim)
-    # key_states = key_states.view(-1, num_kv_heads, head_dim)
-    # value_states = value_states.view(-1, num_kv_heads, head_dim)
-    query_states = query_states.squeeze(dim=0)
-    key_states = key_states.squeeze(dim=0)
-    value_states = value_states.squeeze(dim=0)
+    query_states = query_states.transpose(1, 2).squeeze(dim=0)
+    key_states = key_states.transpose(1, 2).squeeze(dim=0)
+    value_states = value_states.transpose(1, 2).squeeze(dim=0)
+
+    input_dtype = query_states.dtype
+    if input_dtype == torch.float32:
+        query_states = query_states.to(target_dtype)
+        key_states = key_states.to(target_dtype)
+        value_states = value_states.to(target_dtype)
 
     # Take care of updating the cache in-place
     kv_cache.store(
@@ -89,6 +66,8 @@ def _flash_attention_forward_patched(
         kv_scales=kv_scales
     )
 
+
+    _, num_heads, head_dim = query_states.shape
     softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
     sliding_window = -1 if sliding_window is None else sliding_window
 
@@ -121,7 +100,10 @@ def _flash_attention_forward_patched(
     # attn_output = attn_output.view(attn_output.shape[0], -1)
     attn_output = attn_output.view(-1, num_heads * head_dim)
 
-    return attn_output
+    return attn_output, None
+
+
+transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward
 
 
 class TransformersFlashCausalLM(FlashCausalLM):
@@ -174,8 +156,9 @@ class TransformersFlashCausalLM(FlashCausalLM):
             device_map=("auto" if device_count > 1 else None),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
-            attn_implementation="flash_attention_2"
+            attn_implementation="tgi"
         )
+
         if device_count == 1 and quantize != "bitsandbytes":
             model = model.to(device)
 
@@ -261,10 +244,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
             trust_remote_code=trust_remote_code,
         )
 
-    def warmup(self, batch: FlashCausalLMBatch, max_input_tokens: Optional[int], max_total_tokens: Optional[int],):
-        patch_everywhere("_flash_attention_forward", _flash_attention_forward_patched)
-        return super().warmup(batch, max_input_tokens, max_total_tokens)
-
     def forward(
         self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -318,11 +297,13 @@ class TransformersFlashCausalLM(FlashCausalLM):
             max_q=batch.max_input_length,
             max_k=batch.max_current_length,
         )
-        logits = self.model.forward(
-            input_ids=input_ids[None, ...],
+        # Use only the Model, not ModelForCausalLM
+        hidden_states = self.model.model.forward(
+            input_ids=input_ids[None, ...],  # expand dim to easily fit transformers
            position_ids=position_ids[None, ...],
             past_key_values=None,
             use_cache=False,  # we use self.kv_cache instead of transformers cache object
+            return_dict=True,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=kv_cache,
             block_tables=block_tables,
@@ -330,10 +311,10 @@ class TransformersFlashCausalLM(FlashCausalLM):
             seqlen=seqlen,
             max_s=max_s,
             prefill_cache_indices=batch.prefill_cache_indices,
-            lm_head_indices=lm_head_indices,
             kv_head_mapping=self.kv_head_mapping,
-        ).logits[0, ...]
-        print("SUCCESSFUL FORWARD")
+        )[0].squeeze(dim=0)
+        # And compute logits from the lm_head, slicing correctly the indices
+        logits = self.model.lm_head.forward(hidden_states[lm_head_indices])
         if batch.prefill_cache_indices is not None:
             batch.prefill_cache_indices = None
         return logits, None
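
For context on the registration hook the patch adds: recent transformers releases (the attention-refactor line, roughly 4.48+) look up the string passed as `attn_implementation` in `transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS`, so adding a `"tgi"` entry is what lets `from_pretrained(..., attn_implementation="tgi")` route every attention call through `tgi_flash_attention_forward`. Below is a minimal sketch of that dispatch contract, outside the diff; the wrapper name `sdpa_logged` and the model id are illustrative placeholders, not anything defined in this PR.

import torch
from transformers import AutoModelForCausalLM
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def sdpa_logged(module, query, key, value, attention_mask, **kwargs):
    # Custom implementations receive (module, q, k, v, mask, **kwargs) with q/k/v shaped
    # (batch, num_heads, seq_len, head_dim) and must return (attn_output, attn_weights);
    # the patched function above returns (attn_output, None) the same way, after
    # reshaping to TGI's ragged (total_tokens, num_heads, head_dim) layout.
    print(f"{module.__class__.__name__}: q={tuple(query.shape)} kv={tuple(key.shape)}")
    # Delegate the actual computation to the stock SDPA implementation from the registry.
    return ALL_ATTENTION_FUNCTIONS["sdpa"](module, query, key, value, attention_mask, **kwargs)


# Same registration pattern as ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward.
ALL_ATTENTION_FUNCTIONS["sdpa_logged"] = sdpa_logged

# Placeholder model id: any model whose attention layers dispatch through the registry works.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    torch_dtype=torch.float16,
    attn_implementation="sdpa_logged",
)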