mirror of https://github.com/huggingface/text-generation-inference.git
Fix MPT, not sure about idefics.
parent 64d38afa9f
commit f592df5234
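
Every hunk below follows the same convention: the language-model head no longer returns a bare logits tensor but a (logits, speculative_logits) pair, and each call site is widened to unpack it. A minimal sketch of that head-side contract, assuming a plain PyTorch module; the PairedLMHead name and the optional second projection are illustrative, not code from this commit:

import torch
import torch.nn as nn
from typing import Optional, Tuple


class PairedLMHead(nn.Module):
    """Illustrative head that always returns a (logits, speculative_logits) pair."""

    def __init__(self, hidden_size: int, vocab_size: int, speculative: bool = False):
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        # Optional extra projection standing in for a Medusa-style speculative head.
        self.spec_head = nn.Linear(hidden_size, vocab_size, bias=False) if speculative else None

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(hidden_states)
        # The second element is None when speculation is off, so callers can always
        # unpack two values, exactly as the patched forward passes below do.
        speculative_logits = self.spec_head(hidden_states) if self.spec_head is not None else None
        return logits, speculative_logits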
@@ -281,11 +281,11 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):
         )

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        output = self.fc(input)
+        output, speculative_logits = self.fc(input)
         additional_features = self.additional_fc(input)
         output = torch.cat((output, additional_features), -1)

-        return output
+        return output, speculative_logits

     def extra_repr(self) -> str:
         """Overwriting `nn.Linear.extra_repr` to include new parameters."""
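
The hunk above is the only one in this diff where the pair has to be threaded through extra computation: the rows from additional_fc are concatenated onto the main logits while speculative_logits passes through untouched. A runnable sketch of that shape with made-up sizes; DecoupledLinearSketch and the (logits, None) stand-in for self.fc are illustrative only:

import torch
import torch.nn as nn
from typing import Optional, Tuple


class DecoupledLinearSketch(nn.Module):
    """Illustrative stand-in for IdeficsDecoupledTensorParallelLinear after the change."""

    def __init__(self, hidden_size: int, vocab_size: int, additional_vocab: int):
        super().__init__()
        self.base_fc = nn.Linear(hidden_size, vocab_size, bias=False)
        self.additional_fc = nn.Linear(hidden_size, additional_vocab, bias=False)

    def fc(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Stand-in for the real (tensor-parallel, possibly speculative) head.
        return self.base_fc(x), None

    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        output, speculative_logits = self.fc(input)
        additional_features = self.additional_fc(input)
        # Only the main logits grow by the additional vocabulary columns;
        # speculative_logits is returned unchanged.
        output = torch.cat((output, additional_features), -1)
        return output, speculative_logits


# Usage with made-up sizes: batch 2, sequence 3, hidden 16, vocab 10 (+4 extra tokens).
layer = DecoupledLinearSketch(hidden_size=16, vocab_size=10, additional_vocab=4)
logits, speculative_logits = layer(torch.randn(2, 3, 16))
print(logits.shape)        # torch.Size([2, 3, 14])
print(speculative_logits)  # None in this sketch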
@@ -1501,17 +1501,20 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
         )

         hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states)
+        logits, speculative_logits = self.lm_head(hidden_states)

         loss = None

-        return CausalLMOutputWithPastImage(
-            loss=loss,
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-            image_hidden_states=outputs.image_hidden_states,
+        return (
+            CausalLMOutputWithPastImage(
+                loss=loss,
+                logits=logits,
+                past_key_values=outputs.past_key_values,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+                image_hidden_states=outputs.image_hidden_states,
+            ),
+            speculative_logits,
         )

     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
@@ -1133,7 +1133,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
             output_hidden_states=output_hidden_states,
             use_cache=use_cache,
         )
-        logits = self.lm_head(outputs.last_hidden_state)
+        logits, speculative_logits = self.lm_head(outputs.last_hidden_state)
         if self.logit_scale is not None:
             if self.logit_scale == 0:
                 warnings.warn(
@@ -1147,12 +1147,15 @@ class MPTForCausalLM(MPTPreTrainedModel):
             loss = F.cross_entropy(
                 logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)
             )
-        return CausalLMOutputWithPast(
-            loss=loss,
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
+        return (
+            CausalLMOutputWithPast(
+                loss=loss,
+                logits=logits,
+                past_key_values=outputs.past_key_values,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            ),
+            speculative_logits,
         )

     def prepare_inputs_for_generation(
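
On the modeling side, MPTForCausalLM.forward (and its Idefics counterpart above) now hands back a two-element tuple instead of a bare output object. A hedged caller-side sketch of how the pair is consumed; logits_from_forward and forward_fn are illustrative names, not code from this commit:

from typing import Any, Callable, Optional, Tuple

import torch


def logits_from_forward(
    forward_fn: Callable[..., Tuple[Any, Optional[torch.Tensor]]],
    **model_inputs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # The first element keeps the usual output object (loss, logits, KV cache, ...);
    # the second is the extra speculative_logits tensor, or None when speculation
    # is disabled, so call sites can unpack two values unconditionally.
    outputs, speculative_logits = forward_fn(**model_inputs)
    return outputs.logits, speculative_logits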
@@ -662,8 +662,13 @@ class IdeficsCausalLM(Model):
         if self.has_position_ids:
             kwargs["position_ids"] = position_ids

-        outputs = self.model.forward(**kwargs)
-        return outputs.logits, outputs.past_key_values, outputs.image_hidden_states
+        outputs, speculative_logits = self.model.forward(**kwargs)
+        return (
+            outputs.logits,
+            speculative_logits,
+            outputs.past_key_values,
+            outputs.image_hidden_states,
+        )

     @tracer.start_as_current_span("generate_token")
     def generate_token(
@@ -686,7 +691,7 @@ class IdeficsCausalLM(Model):
             :, : -batch.padding_right_offset
         ]

-        logits, past, image_hidden_states = self.forward(
+        logits, speculative_logits, past, image_hidden_states = self.forward(
             input_ids=batch.input_ids,
             attention_mask=attention_mask,
             position_ids=batch.position_ids,
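
Downstream, generate_token has to unpack the widened return value of IdeficsCausalLM.forward. A minimal sketch of the new call-site shape, using only the arguments visible in the hunk above (the real call also passes image inputs and the KV cache); run_decode_step and the untyped model placeholder are illustrative:

from typing import Any, Optional, Tuple

import torch


def run_decode_step(
    model: Any, input_ids: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.Tensor
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Any, Optional[torch.Tensor]]:
    # forward now returns four values; speculative_logits sits in second position
    # and is None whenever no speculative head is loaded.
    logits, speculative_logits, past, image_hidden_states = model.forward(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
    )
    return logits, speculative_logits, past, image_hidden_states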