Update based on transformers PR

Cyril Vallez 2025-01-17 15:34:08 +00:00
parent ac62bd1572
commit b03d7ae951
2 changed files with 9 additions and 8 deletions


@@ -378,7 +378,7 @@ def get_model(
     # Fast transformers path
     transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
-    if transformers_model_class._supports_flex_attn:
+    if transformers_model_class.is_backend_compatible():
         transformers_causal_lm_class = TransformersFlashCausalLM
     quantization_config = config_dict.get("quantization_config", None)
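This hunk swaps the private `_supports_flex_attn` flag for the public `is_backend_compatible()` classmethod introduced by the referenced transformers PR. If a deployment has to run against transformers versions from both before and after that PR, a hypothetical shim could bridge the two checks; `supports_fast_path` and the example `model_type` below are illustrative names, not part of TGI or transformers.

import transformers
from transformers.models.auto import modeling_auto


def supports_fast_path(model_class) -> bool:
    # Hypothetical helper: prefer the public classmethod from newer transformers,
    # fall back to the older private flag for versions that predate the PR.
    if hasattr(model_class, "is_backend_compatible"):
        return model_class.is_backend_compatible()
    return getattr(model_class, "_supports_flex_attn", False)


# Illustrative usage, with the model type normally read from the model's config:
model_type = "llama"
model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
print(supports_fast_path(model_class))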


@@ -250,14 +250,19 @@ class TransformersFlashCausalLM(FlashCausalLM):
         slots: torch.Tensor,
         seqlen: Seqlen,
         max_s: int,
-        lm_head_indices: torch.Tensor,
+        lm_head_indices: Optional[torch.Tensor],
     ):
-        hidden_states = self.model.model.forward(
+        # Transformers does not support None as a default
+        if lm_head_indices is None:
+            lm_head_indices = 0
+        logits = self.model.forward(
             input_ids=input_ids.unsqueeze(0),  # expand dim to easily fit transformers
             position_ids=position_ids.unsqueeze(0),  # expand dim to easily fit transformers
             past_key_values=None,  # we use self.kv_cache instead of transformers cache object
             use_cache=False,  # we use self.kv_cache instead of transformers cache object
             return_dict=True,
+            num_logits_to_keep=lm_head_indices,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=kv_cache,
             block_tables=block_tables,
@@ -265,11 +270,7 @@ class TransformersFlashCausalLM(FlashCausalLM):
             seqlen=seqlen,
             max_s=max_s,
             kv_head_mapping=self.kv_head_mapping,
-        )[0].squeeze(dim=0)
-        # And compute logits from the lm_head, slicing correctly the indices
-        if lm_head_indices is not None:
-            hidden_states = hidden_states[lm_head_indices]
-        logits = self.model.lm_head.forward(hidden_states)
+        ).logits.squeeze(dim=0)
         return logits
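The removed lines selected `hidden_states[lm_head_indices]` and ran `lm_head` by hand; the new call delegates that selection to transformers via `num_logits_to_keep`, with `0` standing in for "all positions" because `None` is not accepted. Below is a minimal sketch of the slicing semantics assumed here; the helper name is hypothetical, and the 1-D tensor case is inferred from this commit's usage rather than from transformers documentation.

import torch


def select_hidden_for_lm_head(hidden_states: torch.Tensor, logits_to_keep) -> torch.Tensor:
    # Hypothetical helper mirroring the assumed num_logits_to_keep semantics:
    # 0 keeps every position, an int n keeps the last n, a 1-D tensor keeps those indices.
    if isinstance(logits_to_keep, int):
        return hidden_states if logits_to_keep == 0 else hidden_states[:, -logits_to_keep:, :]
    return hidden_states[:, logits_to_keep, :]


# Decode step: lm_head_indices is None, so 0 is passed and every position keeps its logits.
hidden = torch.randn(1, 10, 16)  # [batch=1, seq_len=10, hidden_size=16]
print(select_hidden_for_lm_head(hidden, 0).shape)  # torch.Size([1, 10, 16])

# Prefill: only the last token of each packed sequence needs logits.
lm_head_indices = torch.tensor([3, 9])
print(select_hidden_for_lm_head(hidden, lm_head_indices).shape)  # torch.Size([1, 2, 16])

Under that assumption, prefill still materializes logits only for each sequence's last token, exactly as the removed manual slicing did, just inside the transformers forward instead of in TGI.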