Small updates.

Nicolas Patry 2024-02-22 12:06:36 +00:00
parent ac419f5e46
commit 21b3072288
4 changed files with 19 additions and 9 deletions

View File

@@ -105,7 +105,7 @@ class BLOOMSharded(CausalLM):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
-        outputs = self.model.forward(
+        outputs, speculative_logits = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
@@ -114,4 +114,4 @@ class BLOOMSharded(CausalLM):
        )
        logits = outputs.logits
-        return logits, outputs.past_key_values
+        return logits, speculative_logits, outputs.past_key_values
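Note: BLOOMSharded.forward now returns a three-element tuple (logits, speculative_logits, past_key_values) instead of two, and the middle element may be None when no speculative head is configured. A minimal sketch of the new calling convention, using a hypothetical stand-in model rather than code from this commit:

import torch

class FakeShardedModel:
    """Hypothetical stand-in; only the return shape of forward() matters here."""

    def forward(self, input_ids, attention_mask, position_ids, past_key_values=None):
        batch, seq, vocab = input_ids.shape[0], input_ids.shape[1], 32
        logits = torch.zeros(batch, seq, vocab)
        speculative_logits = None  # populated only when a speculative head exists
        return logits, speculative_logits, past_key_values

model = FakeShardedModel()
ids = torch.zeros(1, 4, dtype=torch.long)
# Old call sites unpacked two values; they now need to unpack three.
logits, speculative_logits, past = model.forward(ids, torch.ones_like(ids), None)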

View File

@@ -551,7 +551,7 @@ class CausalLM(Model):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
-    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
         kwargs = {
             "input_ids": input_ids,
@@ -564,7 +564,11 @@ class CausalLM(Model):
             kwargs["position_ids"] = position_ids
         outputs = self.model.forward(**kwargs)
-        return outputs.logits, outputs.past_key_values
+        if isinstance(outputs, tuple):
+            outputs, speculative_logits = outputs
+        else:
+            speculative_logits = None
+        return outputs.logits, speculative_logits, outputs.past_key_values

     @tracer.start_as_current_span("generate_token")
     def generate_token(
@@ -574,7 +578,7 @@ class CausalLM(Model):
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
-        logits, past = self.forward(
+        logits, speculative_logits, past = self.forward(
            batch.input_ids,
            attention_mask,
            batch.position_ids,
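The isinstance branch above normalizes the two return shapes in play: custom models emit a (model_output, speculative_logits) tuple, while generic transformers models emit only the model output. A small illustrative helper capturing the same pattern (the function name and usage are hypothetical, not part of the repo):

from typing import Any, Optional, Tuple

import torch

def split_speculative(output: Any) -> Tuple[Any, Optional[torch.Tensor]]:
    """Mirror the isinstance() check: custom models return a pair,
    generic transformers models return only the model output."""
    if isinstance(output, tuple):
        model_output, speculative_logits = output
    else:
        model_output, speculative_logits = output, None
    return model_output, speculative_logits

# Both shapes normalize to the same pair.
assert split_speculative(("out", torch.zeros(1)))[1] is not None
assert split_speculative("out")[1] is None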

View File

@@ -904,7 +904,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
         )
         hidden_states = transformer_outputs[0]

-        lm_logits = self.lm_head(hidden_states)
+        logits, speculative_logits = self.lm_head(hidden_states)

         loss = None
         if not return_dict:
@@ -913,8 +913,8 @@ class BloomForCausalLM(BloomPreTrainedModel):
         return CausalLMOutputWithCrossAttentions(
             loss=loss,
-            logits=lm_logits,
+            logits=logits,
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        ), speculative_logits
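This change assumes self.lm_head returns a (logits, speculative_logits) pair rather than a bare tensor. A rough, hypothetical sketch of a head honoring that contract, not the repo's actual speculative head implementation:

from typing import Optional, Tuple

import torch
from torch import nn

class PairedLMHead(nn.Module):
    """Illustrative only: mimics the (logits, speculative_logits) return
    contract expected by the modified BloomForCausalLM.forward."""

    def __init__(self, hidden_size: int, vocab_size: int, speculative: bool = False):
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        # Optional second projection standing in for a speculative-style head.
        self.speculative_head = (
            nn.Linear(hidden_size, vocab_size, bias=False) if speculative else None
        )

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(hidden_states)
        speculative_logits = (
            self.speculative_head(hidden_states) if self.speculative_head is not None else None
        )
        return logits, speculative_logits

head = PairedLMHead(hidden_size=16, vocab_size=32, speculative=True)
logits, speculative_logits = head(torch.randn(1, 4, 16))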

View File

@@ -602,7 +602,7 @@ class Seq2SeqLM(Model):
         List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
     ]:
         # Model Forward
-        outputs, speculative_logits = self.model.forward(
+        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
@@ -611,6 +611,12 @@ class Seq2SeqLM(Model):
            past_key_values=past_key_values,
            use_cache=True,
        )
+        if isinstance(outputs, tuple):
+            # Our custom models
+            outputs, speculative_logits = outputs
+        else:
+            # Generic transformers models
+            speculative_logits = None
        return (
            outputs.logits,
            speculative_logits,