Small updates.

Author: Nicolas Patry
Date: 2024-02-22 12:06:36 +00:00
Parent: ac419f5e46
Commit: 21b3072288
4 changed files with 19 additions and 9 deletions

@@ -105,7 +105,7 @@ class BLOOMSharded(CausalLM):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
-        outputs = self.model.forward(
+        outputs, speculative_logits = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -114,4 +114,4 @@ class BLOOMSharded(CausalLM):
         )
 
         logits = outputs.logits
-        return logits, outputs.past_key_values
+        return logits, speculative_logits, outputs.past_key_values
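
The hunk above changes what BLOOMSharded.forward hands back to its caller: three values instead of two. A minimal sketch of the new contract follows; the stub class, tensor shapes, and the two-head speculative output are illustrative assumptions, not code from this commit.

# Toy stand-in for the new (logits, speculative_logits, past_key_values) contract.
# StubShardedModel and all shapes are assumptions for illustration only.
from typing import Optional, Tuple

import torch


class StubShardedModel:
    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Tuple]:
        batch_size, seq_len, vocab = input_ids.shape[0], input_ids.shape[1], 8
        logits = torch.zeros(batch_size, seq_len, vocab)
        # A speculative head would add extra logits; None is also a valid value.
        speculative_logits = torch.zeros(batch_size, seq_len, 2, vocab)
        return logits, speculative_logits, ()


input_ids = torch.ones(1, 4, dtype=torch.long)
logits, speculative_logits, past = StubShardedModel().forward(
    input_ids, torch.ones_like(input_ids), torch.arange(4)[None, :]
)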

@@ -551,7 +551,7 @@ class CausalLM(Model):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
-    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
         kwargs = {
             "input_ids": input_ids,
@@ -564,7 +564,11 @@ class CausalLM(Model):
             kwargs["position_ids"] = position_ids
 
         outputs = self.model.forward(**kwargs)
-        return outputs.logits, outputs.past_key_values
+        if isinstance(outputs, tuple):
+            outputs, speculative_logits = outputs
+        else:
+            speculative_logits = None
+        return outputs.logits, speculative_logits, outputs.past_key_values
 
     @tracer.start_as_current_span("generate_token")
     def generate_token(
@@ -574,7 +578,7 @@ class CausalLM(Model):
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
 
-        logits, past = self.forward(
+        logits, speculative_logits, past = self.forward(
             batch.input_ids,
             attention_mask,
             batch.position_ids,
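
The isinstance branch added to CausalLM.forward is a small compatibility shim: a custom model returns an (output, speculative_logits) tuple, while a generic transformers model returns a bare output object. A self-contained sketch of the same pattern, with FakeOutput and unpack as made-up stand-ins:

# Sketch of the compatibility shim; FakeOutput is a made-up stand-in for a model output.
from dataclasses import dataclass


@dataclass
class FakeOutput:
    logits: list
    past_key_values: tuple = ()


def unpack(outputs):
    if isinstance(outputs, tuple):
        # Custom models: (output, speculative_logits)
        outputs, speculative_logits = outputs
    else:
        # Generic transformers models: no speculative head
        speculative_logits = None
    return outputs.logits, speculative_logits, outputs.past_key_values


print(unpack(FakeOutput(logits=[0.1])))            # ([0.1], None, ())
print(unpack((FakeOutput(logits=[0.1]), "spec")))  # ([0.1], 'spec', ())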

@@ -904,7 +904,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
         )
         hidden_states = transformer_outputs[0]
 
-        lm_logits = self.lm_head(hidden_states)
+        logits, speculative_logits = self.lm_head(hidden_states)
 
         loss = None
         if not return_dict:
@@ -913,8 +913,8 @@ class BloomForCausalLM(BloomPreTrainedModel):
 
         return CausalLMOutputWithCrossAttentions(
             loss=loss,
-            logits=lm_logits,
+            logits=logits,
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        ), speculative_logits
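
The call site above now expects self.lm_head to return a pair. As a hedged sketch of what such a head could look like, the DualHead class and its extra speculative projections below are assumptions for illustration and may differ from the head actually used in the repository:

# Illustrative head returning (logits, speculative_logits); not the repository's implementation.
from typing import Optional, Tuple

import torch
from torch import nn


class DualHead(nn.Module):
    def __init__(self, hidden_size: int, vocab_size: int, n_speculative: int = 0):
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.spec_heads = nn.ModuleList(
            nn.Linear(hidden_size, vocab_size, bias=False) for _ in range(n_speculative)
        )

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(hidden_states)
        if len(self.spec_heads) == 0:
            return logits, None
        # Stack per-head logits along a dedicated speculative dimension.
        speculative_logits = torch.stack([h(hidden_states) for h in self.spec_heads], dim=-2)
        return logits, speculative_logits


logits, speculative_logits = DualHead(16, 32, n_speculative=2)(torch.randn(1, 4, 16))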

@@ -602,7 +602,7 @@ class Seq2SeqLM(Model):
         List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
     ]:
         # Model Forward
-        outputs, speculative_logits = self.model.forward(
+        outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
@@ -611,6 +611,12 @@ class Seq2SeqLM(Model):
             past_key_values=past_key_values,
             use_cache=True,
         )
+        if isinstance(outputs, tuple):
+            # Our custom models
+            outputs, speculative_logits = outputs
+        else:
+            # Generic transformers models
+            speculative_logits = None
         return (
             outputs.logits,
             speculative_logits,
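
Since speculative_logits can now be None for generic transformers models, downstream code has to treat it as optional. A minimal sketch of that handling; pick_next_tokens and the tensor shapes are illustrative assumptions, not part of this commit.

# Illustrative only: greedy selection that degrades gracefully when no speculation is available.
from typing import Optional, Tuple

import torch


def pick_next_tokens(
    logits: torch.Tensor, speculative_logits: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    next_token = logits[:, -1, :].argmax(dim=-1)
    if speculative_logits is None:
        return next_token, None  # plain decoding step
    # One candidate token per speculative head at the last position.
    speculative_tokens = speculative_logits[:, -1, :, :].argmax(dim=-1)
    return next_token, speculative_tokens


pick_next_tokens(torch.randn(1, 4, 32), None)
pick_next_tokens(torch.randn(1, 4, 32), torch.randn(1, 4, 2, 32))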