Mirror of https://github.com/huggingface/text-generation-inference.git

Commit 21b3072288 ("Small updates."), parent ac419f5e46.

The diff below threads speculative logits through the BLOOM, CausalLM, and Seq2SeqLM code paths: model forwards now return an optional speculative_logits tensor alongside the usual logits and past key values, and callers are updated to unpack it.
@@ -105,7 +105,7 @@ class BLOOMSharded(CausalLM):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
-        outputs = self.model.forward(
+        outputs, speculative_logits = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -114,4 +114,4 @@ class BLOOMSharded(CausalLM):
         )

         logits = outputs.logits
-        return logits, outputs.past_key_values
+        return logits, speculative_logits, outputs.past_key_values
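Taken together, the two BLOOMSharded hunks widen the wrapper's forward from a (logits, past_key_values) pair to a (logits, speculative_logits, past_key_values) triple. A minimal sketch of that contract from the caller's side; the class, shapes, and values below are illustrative stand-ins, not the real TGI code.

import torch


class TinyWrapper:
    """Toy stand-in for a BLOOMSharded-like wrapper (illustrative only)."""

    def forward(self, input_ids, attention_mask, position_ids, past_key_values=None):
        batch, seq, vocab = input_ids.shape[0], input_ids.shape[1], 16
        logits = torch.randn(batch, seq, vocab)
        speculative_logits = None  # may stay None when no speculative head is attached
        return logits, speculative_logits, past_key_values or []


wrapper = TinyWrapper()
logits, speculative_logits, past = wrapper.forward(
    input_ids=torch.zeros(1, 4, dtype=torch.long),
    attention_mask=torch.ones(1, 4, dtype=torch.long),
    position_ids=torch.arange(4).unsqueeze(0),
)
# Callers always unpack three values now; speculative_logits may be None.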
@@ -551,7 +551,7 @@ class CausalLM(Model):

     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
-    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
         kwargs = {
             "input_ids": input_ids,
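The only change here is the return annotation, which gains an Optional[torch.Tensor] slot for the speculative logits. Purely for readability, a hypothetical alias spelling out the three components (not part of the diff):

from typing import List, Optional, Tuple

import torch

# Hypothetical aliases; the diff keeps the annotation inline.
PastKeyValues = List[Tuple[torch.Tensor, torch.Tensor]]
CausalLMForwardReturn = Tuple[
    torch.Tensor,            # logits
    Optional[torch.Tensor],  # speculative logits, None when speculation is disabled
    PastKeyValues,           # per-layer (key, value) cache
]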
@@ -564,7 +564,11 @@ class CausalLM(Model):
             kwargs["position_ids"] = position_ids

         outputs = self.model.forward(**kwargs)
-        return outputs.logits, outputs.past_key_values
+        if isinstance(outputs, tuple):
+            outputs, speculative_logits = outputs
+        else:
+            speculative_logits = None
+        return outputs.logits, speculative_logits, outputs.past_key_values

     @tracer.start_as_current_span("generate_token")
     def generate_token(
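The new branch normalizes two kinds of model output: custom models that already return an (outputs, speculative_logits) tuple, and stock models that return a single output object. A self-contained sketch of that dispatch with toy stand-ins for both cases; the names, shapes, and the stand-in output class are assumptions, not TGI internals.

from dataclasses import dataclass, field
from typing import List, Tuple

import torch


@dataclass
class ToyOutput:
    """Stand-in for a transformers-style output with .logits and .past_key_values."""

    logits: torch.Tensor
    past_key_values: List[Tuple[torch.Tensor, torch.Tensor]] = field(default_factory=list)


def forward_custom(**kwargs):
    # Custom models return (outputs, speculative_logits).
    return ToyOutput(logits=torch.randn(1, 4, 16)), torch.randn(1, 4, 3, 16)


def forward_generic(**kwargs):
    # Generic models return only the output object.
    return ToyOutput(logits=torch.randn(1, 4, 16))


def run(model_forward):
    outputs = model_forward(input_ids=torch.zeros(1, 4, dtype=torch.long))
    if isinstance(outputs, tuple):
        outputs, speculative_logits = outputs
    else:
        speculative_logits = None
    return outputs.logits, speculative_logits, outputs.past_key_values


_, spec_a, _ = run(forward_custom)   # tensor of speculative logits
_, spec_b, _ = run(forward_generic)  # None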
@@ -574,7 +578,7 @@ class CausalLM(Model):
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]

-        logits, past = self.forward(
+        logits, speculative_logits, past = self.forward(
             batch.input_ids,
             attention_mask,
             batch.position_ids,
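The attention-mask line is unchanged context, but the slicing is easy to misread: batch.attention_mask is allocated up front for the maximum total length, and [:, : -batch.padding_right_offset] trims the columns still reserved for tokens that have not been generated yet. A toy illustration with a made-up offset:

import torch

# Mask pre-allocated for 6 positions; the last 3 are reserved for future tokens.
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0]])
padding_right_offset = 3  # illustrative value

current_mask = attention_mask[:, :-padding_right_offset]
print(current_mask)  # tensor([[1, 1, 1]]): only positions that exist so far are attended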
@@ -904,7 +904,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
         )
         hidden_states = transformer_outputs[0]

-        lm_logits = self.lm_head(hidden_states)
+        logits, speculative_logits = self.lm_head(hidden_states)
         loss = None

         if not return_dict:
@@ -913,8 +913,8 @@ class BloomForCausalLM(BloomPreTrainedModel):

         return CausalLMOutputWithCrossAttentions(
             loss=loss,
-            logits=lm_logits,
+            logits=logits,
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        ), speculative_logits
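BloomForCausalLM now expects self.lm_head to return a (logits, speculative_logits) pair and hands that pair back alongside the CausalLMOutputWithCrossAttentions object. The diff does not show what such a head looks like; below is one plausible shape for it, a base projection plus optional extra per-position projections in the spirit of Medusa-style speculation. The class name, the [batch, seq, n_speculative, vocab] layout, and the head structure are all assumptions, not the actual TGI head.

import torch
from torch import nn


class SpeculativeLMHead(nn.Module):
    """Hypothetical head: base logits plus optional extra speculative projections."""

    def __init__(self, hidden_size: int, vocab_size: int, n_speculative: int = 0):
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.speculative_heads = nn.ModuleList(
            [nn.Linear(hidden_size, vocab_size, bias=False) for _ in range(n_speculative)]
        )

    def forward(self, hidden_states: torch.Tensor):
        logits = self.lm_head(hidden_states)
        if len(self.speculative_heads) == 0:
            return logits, None
        # Stack per-head logits into [batch, seq, n_speculative, vocab].
        speculative_logits = torch.stack(
            [head(hidden_states) for head in self.speculative_heads], dim=-2
        )
        return logits, speculative_logits


head = SpeculativeLMHead(hidden_size=8, vocab_size=16, n_speculative=2)
logits, speculative_logits = head(torch.randn(1, 4, 8))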
@@ -602,7 +602,7 @@ class Seq2SeqLM(Model):
         List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
     ]:
         # Model Forward
-        outputs, speculative_logits = self.model.forward(
+        outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
@@ -611,6 +611,12 @@ class Seq2SeqLM(Model):
             past_key_values=past_key_values,
             use_cache=True,
         )
+        if isinstance(outputs, tuple):
+            # Our custom models
+            outputs, speculative_logits = outputs
+        else:
+            # Generic transformers models
+            speculative_logits = None
         return (
             outputs.logits,
             speculative_logits,
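The same normalization lands in Seq2SeqLM, with comments naming the two cases. The reason isinstance(outputs, tuple) works as the discriminator: the custom models return a plain Python tuple, while stock transformers output objects are dict-like containers rather than tuples, so they fall through to the speculative_logits = None branch. A tiny check with a stand-in class, so this sketch does not need transformers installed:

from collections import OrderedDict


class FakeSeq2SeqOutput(OrderedDict):
    """Stand-in for a transformers-style ModelOutput: dict-like, not a tuple."""


custom_result = ("outputs", "speculative_logits")  # what the custom models return
generic_result = FakeSeq2SeqOutput(logits=None)    # what generic models return

assert isinstance(custom_result, tuple)        # unpacked into (outputs, speculative_logits)
assert not isinstance(generic_result, tuple)   # takes the speculative_logits = None branch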