Fix MPT, not sure about idefics.

Nicolas Patry 2024-02-22 16:08:15 +00:00
parent 64d38afa9f
commit f592df5234
3 changed files with 31 additions and 20 deletions
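
Context for the change: the language-model head (lm_head, and the decoupled linear in Idefics) now returns a (logits, speculative_logits) pair instead of bare logits, so every call site has to unpack two values. The head that produces the second element is not part of this commit; the sketch below only illustrates the assumed contract with hypothetical names (a Medusa-style wrapper), where speculative_logits is assumed to be None whenever speculation is disabled.

# Minimal sketch of the assumed lm_head contract; names here are illustrative,
# not the actual TGI classes touched by this commit.
from typing import Optional, Tuple

import torch
import torch.nn as nn


class SpeculativeLMHead(nn.Module):  # hypothetical wrapper around the real lm_head
    def __init__(self, lm_head: nn.Module, speculative_head: Optional[nn.Module] = None):
        super().__init__()
        self.lm_head = lm_head
        self.speculative_head = speculative_head

    def forward(
        self, hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(hidden_states)
        # Extra heads (e.g. Medusa-style) predict a few tokens ahead of the main head.
        speculative_logits = (
            self.speculative_head(hidden_states)
            if self.speculative_head is not None
            else None
        )
        return logits, speculative_logits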


@@ -281,11 +281,11 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):
         )
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        output = self.fc(input)
+        output, speculative_logits = self.fc(input)
         additional_features = self.additional_fc(input)
         output = torch.cat((output, additional_features), -1)
-        return output
+        return output, speculative_logits
     def extra_repr(self) -> str:
         """Overwriting `nn.Linear.extra_repr` to include new parameters."""
@@ -1501,17 +1501,20 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
         )
         hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states)
+        logits, speculative_logits = self.lm_head(hidden_states)
         loss = None
-        return CausalLMOutputWithPastImage(
-            loss=loss,
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-            image_hidden_states=outputs.image_hidden_states,
+        return (
+            CausalLMOutputWithPastImage(
+                loss=loss,
+                logits=logits,
+                past_key_values=outputs.past_key_values,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+                image_hidden_states=outputs.image_hidden_states,
+            ),
+            speculative_logits,
         )
     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
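
With this change IdeficsForVisionText2Text.forward returns a (CausalLMOutputWithPastImage, speculative_logits) tuple rather than the output object itself, so any caller that used the result directly has to unpack it first. A minimal usage sketch, assuming a loaded model and prepared input_ids/attention_mask (all hypothetical here); speculative_logits is expected to be None when speculation is off:

# Hypothetical caller of the patched forward.
outputs, speculative_logits = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
)
logits = outputs.logits  # the CausalLMOutputWithPastImage is the first tuple element
if speculative_logits is not None:
    # feed the extra heads' logits into the speculative decoding path
    ...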


@@ -1133,7 +1133,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
             output_hidden_states=output_hidden_states,
             use_cache=use_cache,
         )
-        logits = self.lm_head(outputs.last_hidden_state)
+        logits, speculative_logits = self.lm_head(outputs.last_hidden_state)
         if self.logit_scale is not None:
             if self.logit_scale == 0:
                 warnings.warn(
@@ -1147,12 +1147,15 @@ class MPTForCausalLM(MPTPreTrainedModel):
             loss = F.cross_entropy(
                 logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)
             )
-        return CausalLMOutputWithPast(
-            loss=loss,
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
+        return (
+            CausalLMOutputWithPast(
+                loss=loss,
+                logits=logits,
+                past_key_values=outputs.past_key_values,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            ),
+            speculative_logits,
         )
     def prepare_inputs_for_generation(
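
MPTForCausalLM.forward now likewise returns a (CausalLMOutputWithPast, speculative_logits) tuple instead of the output object alone. Code that still expects the old single return value would fail on unpacking; below is a purely illustrative compatibility sketch (not part of this commit, names assumed) for call sites that might see either shape:

# Hypothetical shim: accept either the old single-output shape or the new tuple.
result = model(input_ids=input_ids, attention_mask=attention_mask)
if isinstance(result, tuple):
    outputs, speculative_logits = result
else:
    outputs, speculative_logits = result, None
logits = outputs.logits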


@@ -662,8 +662,13 @@ class IdeficsCausalLM(Model):
         if self.has_position_ids:
             kwargs["position_ids"] = position_ids
-        outputs = self.model.forward(**kwargs)
-        return outputs.logits, outputs.past_key_values, outputs.image_hidden_states
+        outputs, speculative_logits = self.model.forward(**kwargs)
+        return (
+            outputs.logits,
+            speculative_logits,
+            outputs.past_key_values,
+            outputs.image_hidden_states,
+        )
     @tracer.start_as_current_span("generate_token")
     def generate_token(
@@ -686,7 +691,7 @@ class IdeficsCausalLM(Model):
                 :, : -batch.padding_right_offset
             ]
-        logits, past, image_hidden_states = self.forward(
+        logits, speculative_logits, past, image_hidden_states = self.forward(
             input_ids=batch.input_ids,
             attention_mask=attention_mask,
             position_ids=batch.position_ids,