Update based on transformers PR

Cyril Vallez 2025-01-17 15:34:08 +00:00
parent ac62bd1572
commit b03d7ae951
2 changed files with 9 additions and 8 deletions


@@ -378,7 +378,7 @@ def get_model(
     # Fast transformers path
     transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
-    if transformers_model_class._supports_flex_attn:
+    if transformers_model_class.is_backend_compatible():
         transformers_causal_lm_class = TransformersFlashCausalLM
     quantization_config = config_dict.get("quantization_config", None)
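The condition now relies on the `is_backend_compatible()` classmethod exposed by the transformers PR this commit tracks, instead of the private `_supports_flex_attn` flag. Below is a hedged sketch, not part of the commit, of a helper that would fall back to the old flag on transformers releases that predate the new classmethod (`supports_tgi_backend` is a hypothetical name):

# Sketch only: prefer the new public classmethod, fall back to the private flag.
# Assumption: older transformers releases may lack is_backend_compatible().
def supports_tgi_backend(model_class) -> bool:
    if hasattr(model_class, "is_backend_compatible"):
        return model_class.is_backend_compatible()
    return getattr(model_class, "_supports_flex_attn", False)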


@@ -250,14 +250,19 @@ class TransformersFlashCausalLM(FlashCausalLM):
         slots: torch.Tensor,
         seqlen: Seqlen,
         max_s: int,
-        lm_head_indices: torch.Tensor,
+        lm_head_indices: Optional[torch.Tensor],
     ):
-        hidden_states = self.model.model.forward(
+        # Transformers does not support None as a default
+        if lm_head_indices is None:
+            lm_head_indices = 0
+        logits = self.model.forward(
             input_ids=input_ids.unsqueeze(0),  # expand dim to easily fit transformers
             position_ids=position_ids.unsqueeze(0),  # expand dim to easily fit transformers
             past_key_values=None,  # we use self.kv_cache instead of transformers cache object
             use_cache=False,  # we use self.kv_cache instead of transformers cache object
             return_dict=True,
+            num_logits_to_keep=lm_head_indices,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=kv_cache,
             block_tables=block_tables,
@@ -265,11 +270,7 @@ class TransformersFlashCausalLM(FlashCausalLM):
             seqlen=seqlen,
             max_s=max_s,
             kv_head_mapping=self.kv_head_mapping,
-        )[0].squeeze(dim=0)
-        # And compute logits from the lm_head, slicing correctly the indices
-        if lm_head_indices is not None:
-            hidden_states = hidden_states[lm_head_indices]
-        logits = self.model.lm_head.forward(hidden_states)
+        ).logits.squeeze(dim=0)
         return logits
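The manual lm_head slicing that this hunk removes is delegated to transformers through the new `num_logits_to_keep` argument. A minimal, self-contained sketch of the slicing semantics this relies on (an assumption based on recent `*ForCausalLM.forward` implementations; `slice_hidden_states` is a hypothetical helper, not transformers or TGI code): an int keeps the last N positions, a tensor is used directly as an index into the sequence dimension, and 0 keeps every position because slice(-0, None) equals slice(0, None), which is why `lm_head_indices = 0` stands in for None above.

import torch

# Hypothetical helper mirroring the assumed num_logits_to_keep behaviour.
def slice_hidden_states(hidden_states: torch.Tensor, num_logits_to_keep) -> torch.Tensor:
    # An int selects the last N positions (0 degenerates to all positions,
    # since slice(-0, None) == slice(0, None)); a tensor is used as an index.
    slice_indices = (
        slice(-num_logits_to_keep, None)
        if isinstance(num_logits_to_keep, int)
        else num_logits_to_keep
    )
    return hidden_states[:, slice_indices, :]

hidden_states = torch.randn(1, 6, 8)  # (batch, seq_len, hidden_size)
print(slice_hidden_states(hidden_states, 0).shape)                     # torch.Size([1, 6, 8])
print(slice_hidden_states(hidden_states, 1).shape)                     # torch.Size([1, 1, 8])
print(slice_hidden_states(hidden_states, torch.tensor([2, 5])).shape)  # torch.Size([1, 2, 8])

Because the whole model (rather than `self.model.model`) is now called and it returns an output object, the hunk also switches from indexing `[0]` on the hidden states to reading `.logits` directly.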