Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)

Commit 21b3072288: "Small updates."
Parent: ac419f5e46

This commit threads speculative_logits through the forward and generate_token paths of BLOOMSharded, CausalLM, BloomForCausalLM, and Seq2SeqLM.
@@ -105,7 +105,7 @@ class BLOOMSharded(CausalLM):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
-        outputs = self.model.forward(
+        outputs, speculative_logits = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -114,4 +114,4 @@ class BLOOMSharded(CausalLM):
         )

         logits = outputs.logits
-        return logits, outputs.past_key_values
+        return logits, speculative_logits, outputs.past_key_values
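The net effect for BLOOMSharded is a new three-value contract: forward now yields (logits, speculative_logits, past_key_values) instead of two values. Below is a minimal sketch of that contract, assuming a speculation-aware model whose forward returns an (outputs, speculative_logits) pair; DummyOutput and dummy_forward are hypothetical stand-ins, not TGI code.

from typing import Optional, Tuple

import torch


class DummyOutput:
    # Hypothetical stand-in for a transformers CausalLMOutput.
    def __init__(self, logits, past_key_values=None):
        self.logits = logits
        self.past_key_values = past_key_values


def dummy_forward(input_ids: torch.Tensor):
    # A speculation-aware model returns (outputs, speculative_logits).
    vocab = 8
    outputs = DummyOutput(torch.zeros(1, input_ids.shape[1], vocab))
    speculative_logits = torch.zeros(1, input_ids.shape[1], vocab)
    return outputs, speculative_logits


# Mirrors the patched BLOOMSharded.forward: unpack the pair, then
# surface all three values to the caller.
outputs, speculative_logits = dummy_forward(torch.ones(1, 4, dtype=torch.long))
logits = outputs.logits
result = (logits, speculative_logits, outputs.past_key_values)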
@@ -551,7 +551,7 @@ class CausalLM(Model):

     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
-    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
         kwargs = {
             "input_ids": input_ids,
@@ -564,7 +564,11 @@ class CausalLM(Model):
             kwargs["position_ids"] = position_ids

         outputs = self.model.forward(**kwargs)
-        return outputs.logits, outputs.past_key_values
+        if isinstance(outputs, tuple):
+            outputs, speculative_logits = outputs
+        else:
+            speculative_logits = None
+        return outputs.logits, speculative_logits, outputs.past_key_values

     @tracer.start_as_current_span("generate_token")
     def generate_token(
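The isinstance check is what lets one code path serve both kinds of models: TGI's custom modeling code returns a plain (outputs, speculative_logits) tuple, while a stock transformers model returns a single output object. Here is a self-contained sketch of the same dispatch; FakeOutput, custom_forward, and generic_forward are hypothetical stand-ins.

from typing import Optional, Tuple

import torch


class FakeOutput:
    # Hypothetical stand-in for a transformers ModelOutput.
    def __init__(self, logits, past_key_values=None):
        self.logits = logits
        self.past_key_values = past_key_values


def custom_forward():
    # Speculation-aware custom model: returns a tuple.
    return FakeOutput(torch.zeros(1, 1, 8)), torch.zeros(1, 1, 8)


def generic_forward():
    # Generic transformers model: returns the output object alone.
    return FakeOutput(torch.zeros(1, 1, 8))


def dispatch(forward) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    outputs = forward()
    if isinstance(outputs, tuple):
        # Our custom models
        outputs, speculative_logits = outputs
    else:
        # Generic transformers models
        speculative_logits = None
    return outputs.logits, speculative_logits


assert dispatch(custom_forward)[1] is not None
assert dispatch(generic_forward)[1] is None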
@@ -574,7 +578,7 @@ class CausalLM(Model):
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]

-        logits, past = self.forward(
+        logits, speculative_logits, past = self.forward(
             batch.input_ids,
             attention_mask,
             batch.position_ids,
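generate_token has to move in lockstep with the widened return type: it now unpacks three values. A quick sketch of the caller-side change, with fake_forward as a hypothetical stand-in matching the new signature:

import torch


def fake_forward(input_ids, attention_mask, position_ids):
    # Hypothetical stand-in for the new CausalLM.forward contract:
    # (logits, speculative_logits, past_key_values).
    logits = torch.zeros(1, input_ids.shape[1], 8)
    speculative_logits = None  # generic-model path
    past = []
    return logits, speculative_logits, past


# New three-way unpack, mirroring the patched generate_token; a caller
# still on the old two-value form would raise ValueError here.
logits, speculative_logits, past = fake_forward(
    torch.ones(1, 4, dtype=torch.long), None, None
)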
@@ -904,7 +904,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
         )
         hidden_states = transformer_outputs[0]

-        lm_logits = self.lm_head(hidden_states)
+        logits, speculative_logits = self.lm_head(hidden_states)
         loss = None

         if not return_dict:
@@ -913,8 +913,8 @@ class BloomForCausalLM(BloomPreTrainedModel):

         return CausalLMOutputWithCrossAttentions(
             loss=loss,
-            logits=lm_logits,
+            logits=logits,
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        ), speculative_logits
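In this hunk, self.lm_head evidently no longer returns a bare tensor but a (logits, speculative_logits) pair. A sketch of what such a head could look like, assuming a design where an optional speculator module sits alongside the base projection; SketchSpeculativeHead is hypothetical and not the actual TGI head:

from typing import Optional, Tuple

import torch
import torch.nn as nn


class SketchSpeculativeHead(nn.Module):
    # Hypothetical: base lm_head plus an optional speculator module.
    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.speculator = None  # e.g. set when a speculator is configured

    def forward(
        self, hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(hidden_states)
        speculative_logits = (
            self.speculator(hidden_states) if self.speculator is not None else None
        )
        return logits, speculative_logits


head = SketchSpeculativeHead(hidden_size=16, vocab_size=32)
logits, speculative_logits = head(torch.randn(1, 4, 16))
assert speculative_logits is None  # no speculator configured in this sketch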
@@ -602,7 +602,7 @@ class Seq2SeqLM(Model):
         List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
     ]:
         # Model Forward
-        outputs, speculative_logits = self.model.forward(
+        outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
@@ -611,6 +611,12 @@ class Seq2SeqLM(Model):
             past_key_values=past_key_values,
             use_cache=True,
         )
+        if isinstance(outputs, tuple):
+            # Our custom models
+            outputs, speculative_logits = outputs
+        else:
+            # Generic transformers models
+            speculative_logits = None
         return (
             outputs.logits,
             speculative_logits,
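Seq2SeqLM gets the reverse edit: the unconditional tuple unpack is dropped in favor of the same isinstance probe used in CausalLM, because a generic transformers model returns a single output object rather than a tuple. The sketch below shows why the old direct unpack cannot survive that case; Seq2SeqOut is a deliberately simplified, hypothetical stand-in.

class Seq2SeqOut:
    # Hypothetical, simplified stand-in for a seq2seq model output.
    def __init__(self, logits):
        self.logits = logits


generic = Seq2SeqOut(logits=[[0.0]])

try:
    outputs, speculative_logits = generic  # old pattern: assumes a tuple
except TypeError:
    # A plain output object is not unpackable, so the new code probes
    # the type instead and fills in None for the speculative logits.
    outputs, speculative_logits = generic, None

assert speculative_logits is None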