mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00

commit b03d7ae951 (parent ac62bd1572)

    Update based on transformers PR
@@ -378,7 +378,7 @@ def get_model(
     # Fast transformers path
     transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
-    if transformers_model_class._supports_flex_attn:
+    if transformers_model_class.is_backend_compatible():
         transformers_causal_lm_class = TransformersFlashCausalLM
 
     quantization_config = config_dict.get("quantization_config", None)
 
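Note: the hunk above swaps the private `_supports_flex_attn` flag for the public `is_backend_compatible()` classmethod that the upstream transformers PR this commit tracks introduces. A minimal sketch of the same selection logic follows; `pick_causal_lm_class` and the `default_cls`/`flash_cls` parameters are illustrative names, not repository code.

# Illustrative sketch only, not text-generation-inference code.
# Assumes `transformers_model_class` was resolved via MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
# as in the hunk above; the getattr guard covers transformers versions that
# predate `is_backend_compatible`.
def pick_causal_lm_class(transformers_model_class, default_cls, flash_cls):
    check = getattr(transformers_model_class, "is_backend_compatible", None)
    if callable(check) and check():
        return flash_cls   # e.g. TransformersFlashCausalLM
    return default_cls     # fall back to the non-flash path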
@@ -250,14 +250,19 @@ class TransformersFlashCausalLM(FlashCausalLM):
         slots: torch.Tensor,
         seqlen: Seqlen,
         max_s: int,
-        lm_head_indices: torch.Tensor,
+        lm_head_indices: Optional[torch.Tensor],
     ):
-        hidden_states = self.model.model.forward(
+        # Transformers does not support None as a default
+        if lm_head_indices is None:
+            lm_head_indices = 0
+
+        logits = self.model.forward(
             input_ids=input_ids.unsqueeze(0),  # expand dim to easily fit transformers
             position_ids=position_ids.unsqueeze(0),  # expand dim to easily fit transformers
             past_key_values=None,  # we use self.kv_cache instead of transformers cache object
             use_cache=False,  # we use self.kv_cache instead of transformers cache object
             return_dict=True,
+            num_logits_to_keep=lm_head_indices,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=kv_cache,
             block_tables=block_tables,
@@ -265,11 +270,7 @@ class TransformersFlashCausalLM(FlashCausalLM):
             seqlen=seqlen,
             max_s=max_s,
             kv_head_mapping=self.kv_head_mapping,
-        )[0].squeeze(dim=0)
-        # And compute logits from the lm_head, slicing correctly the indices
-        if lm_head_indices is not None:
-            hidden_states = hidden_states[lm_head_indices]
-        logits = self.model.lm_head.forward(hidden_states)
+        ).logits.squeeze(dim=0)
 
         return logits
 
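Net effect of the two hunks above, as a hedged sketch: instead of running the bare transformer, slicing hidden states, and applying `lm_head` by hand, the forward call now returns logits directly, with `lm_head_indices` forwarded as `num_logits_to_keep` (coerced to 0 when there is nothing to slice, since transformers does not accept None). The helper names below are illustrative, not repository code.

# Illustrative before/after sketch, not repository code.
def old_logits_path(model, lm_head_indices, **forward_kwargs):
    # Before: bare transformer -> manual slice -> manual lm_head call.
    hidden_states = model.model.forward(**forward_kwargs)[0].squeeze(dim=0)
    if lm_head_indices is not None:
        hidden_states = hidden_states[lm_head_indices]
    return model.lm_head.forward(hidden_states)

def new_logits_path(model, lm_head_indices, **forward_kwargs):
    # After: the full model computes logits; num_logits_to_keep carries the
    # slice indices (0 = keep everything, since None is not supported).
    if lm_head_indices is None:
        lm_head_indices = 0
    return model.forward(
        return_dict=True,
        num_logits_to_keep=lm_head_indices,
        **forward_kwargs,
    ).logits.squeeze(dim=0)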