diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 66be0be2..92f3c51c 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -378,7 +378,7 @@ def get_model(
     # Fast transformers path
     transformers_model_class = getattr(
         transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
     )
-    if transformers_model_class._supports_flex_attn:
+    if transformers_model_class.is_backend_compatible():
         transformers_causal_lm_class = TransformersFlashCausalLM
     quantization_config = config_dict.get("quantization_config", None)
diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py
index 18ab27c2..21aa1f8b 100644
--- a/server/text_generation_server/models/transformers_flash_causal_lm.py
+++ b/server/text_generation_server/models/transformers_flash_causal_lm.py
@@ -250,14 +250,19 @@ class TransformersFlashCausalLM(FlashCausalLM):
         slots: torch.Tensor,
         seqlen: Seqlen,
         max_s: int,
-        lm_head_indices: torch.Tensor,
+        lm_head_indices: Optional[torch.Tensor],
     ):
-        hidden_states = self.model.model.forward(
+        # Transformers does not support None as a default
+        if lm_head_indices is None:
+            lm_head_indices = 0
+
+        logits = self.model.forward(
             input_ids=input_ids.unsqueeze(0),  # expand dim to easily fit transformers
             position_ids=position_ids.unsqueeze(0),  # expand dim to easily fit transformers
             past_key_values=None,  # we use self.kv_cache instead of transformers cache object
             use_cache=False,  # we use self.kv_cache instead of transformers cache object
             return_dict=True,
+            num_logits_to_keep=lm_head_indices,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=kv_cache,
             block_tables=block_tables,
@@ -265,11 +270,7 @@ class TransformersFlashCausalLM(FlashCausalLM):
             seqlen=seqlen,
             max_s=max_s,
             kv_head_mapping=self.kv_head_mapping,
-        )[0].squeeze(dim=0)
-        # And compute logits from the lm_head, slicing correctly the indices
-        if lm_head_indices is not None:
-            hidden_states = hidden_states[lm_head_indices]
-        logits = self.model.lm_head.forward(hidden_states)
+        ).logits.squeeze(dim=0)

        return logits
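
For context on why `lm_head_indices` can be passed straight through as `num_logits_to_keep`: inside a transformers `*ForCausalLM.forward`, that argument is used to slice the hidden states before the `lm_head` projection, with `0` meaning "keep every position" and a tensor selecting arbitrary positions (which is what the prefill-time `lm_head_indices` needs, assuming the transformers version in use accepts a tensor here, as the diff implies). The snippet below is only a simplified sketch of that slicing logic, not the actual transformers internals; the function and parameter names are placeholders.

```python
import torch


def slice_then_project(
    hidden_states: torch.Tensor,  # (batch, seq_len, hidden_size)
    lm_head: torch.nn.Module,     # final projection to vocab logits
    num_logits_to_keep,           # int, or tensor of positions to keep
) -> torch.Tensor:
    """Simplified sketch of how num_logits_to_keep selects positions before the lm_head."""
    if isinstance(num_logits_to_keep, int):
        # 0 keeps every position; N > 0 keeps only the last N positions
        indices = slice(-num_logits_to_keep, None) if num_logits_to_keep > 0 else slice(None)
    else:
        # a tensor of indices (like lm_head_indices during prefill) selects arbitrary positions
        indices = num_logits_to_keep
    return lm_head(hidden_states[:, indices, :])


# Tiny usage example: keep all logits vs. only the last token's logits
if __name__ == "__main__":
    hidden = torch.randn(1, 10, 16)
    head = torch.nn.Linear(16, 32, bias=False)
    all_logits = slice_then_project(hidden, head, 0)                 # (1, 10, 32)
    last_only = slice_then_project(hidden, head, torch.tensor([9]))  # (1, 1, 32)
    assert all_logits.shape == (1, 10, 32) and last_only.shape == (1, 1, 32)
```

Letting the modeling code perform this slicing is what removes the need for the manual `hidden_states[lm_head_indices]` indexing and the explicit `self.model.lm_head.forward(...)` call deleted above.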