mirror of https://github.com/huggingface/text-generation-inference.git
Fix MPT, not sure about idefics.
parent 64d38afa9f
commit f592df5234
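
This commit threads an extra speculative_logits value through MPT and Idefics: the lm_head (and IdeficsDecoupledTensorParallelLinear.forward) now returns a (logits, speculative_logits) pair, and every caller unpacks and forwards it. As a hedged illustration only, assuming a wrapper along the lines of the hypothetical SpeculativeLMHead below (not the repository's actual class), the contract looks roughly like this:

# Minimal sketch of the head contract the diff relies on: the head returns a
# (logits, speculative_logits) pair, where speculative_logits may be None when
# no speculative (e.g. Medusa-style) head is configured. Names are hypothetical.
from typing import Optional, Tuple

import torch
import torch.nn as nn


class SpeculativeLMHead(nn.Module):
    def __init__(self, lm_head: nn.Module, speculative_head: Optional[nn.Module] = None):
        super().__init__()
        self.lm_head = lm_head
        self.speculative_head = speculative_head

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(hidden_states)
        speculative_logits = (
            self.speculative_head(hidden_states)
            if self.speculative_head is not None
            else None
        )
        return logits, speculative_logits
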
@@ -281,11 +281,11 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):
         )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        output = self.fc(input)
+        output, speculative_logits = self.fc(input)
         additional_features = self.additional_fc(input)
         output = torch.cat((output, additional_features), -1)
 
-        return output
+        return output, speculative_logits
 
     def extra_repr(self) -> str:
         """Overwriting `nn.Linear.extra_repr` to include new parameters."""
@@ -1501,17 +1501,20 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
         )
 
         hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states)
+        logits, speculative_logits = self.lm_head(hidden_states)
 
         loss = None
 
-        return CausalLMOutputWithPastImage(
-            loss=loss,
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-            image_hidden_states=outputs.image_hidden_states,
+        return (
+            CausalLMOutputWithPastImage(
+                loss=loss,
+                logits=logits,
+                past_key_values=outputs.past_key_values,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+                image_hidden_states=outputs.image_hidden_states,
+            ),
+            speculative_logits,
         )
 
     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
@@ -1133,7 +1133,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
             output_hidden_states=output_hidden_states,
             use_cache=use_cache,
         )
-        logits = self.lm_head(outputs.last_hidden_state)
+        logits, speculative_logits = self.lm_head(outputs.last_hidden_state)
         if self.logit_scale is not None:
             if self.logit_scale == 0:
                 warnings.warn(
@@ -1147,12 +1147,15 @@ class MPTForCausalLM(MPTPreTrainedModel):
             loss = F.cross_entropy(
                 logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)
             )
-        return CausalLMOutputWithPast(
-            loss=loss,
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
+        return (
+            CausalLMOutputWithPast(
+                loss=loss,
+                logits=logits,
+                past_key_values=outputs.past_key_values,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            ),
+            speculative_logits,
         )
 
     def prepare_inputs_for_generation(
@@ -662,8 +662,13 @@ class IdeficsCausalLM(Model):
         if self.has_position_ids:
             kwargs["position_ids"] = position_ids
 
-        outputs = self.model.forward(**kwargs)
-        return outputs.logits, outputs.past_key_values, outputs.image_hidden_states
+        outputs, speculative_logits = self.model.forward(**kwargs)
+        return (
+            outputs.logits,
+            speculative_logits,
+            outputs.past_key_values,
+            outputs.image_hidden_states,
+        )
 
     @tracer.start_as_current_span("generate_token")
     def generate_token(
@@ -686,7 +691,7 @@ class IdeficsCausalLM(Model):
             :, : -batch.padding_right_offset
         ]
 
-        logits, past, image_hidden_states = self.forward(
+        logits, speculative_logits, past, image_hidden_states = self.forward(
             input_ids=batch.input_ids,
             attention_mask=attention_mask,
             position_ids=batch.position_ids,
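
With these hunks applied, IdeficsCausalLM.forward returns a 4-tuple (logits, speculative_logits, past_key_values, image_hidden_states) instead of the previous 3-tuple. A rough caller-side sketch, with hypothetical helper names and not the repository's actual generate_token code, of how the extra slot is unpacked and simply ignored when no speculative head is present:

# Illustrative only: adapting a caller when forward starts returning
# (logits, speculative_logits, past_key_values, image_hidden_states).
def decode_step(model, input_ids, attention_mask, position_ids):
    logits, speculative_logits, past, image_hidden_states = model.forward(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
    )
    next_token_logits = logits[:, -1, :]  # scores for the next token
    if speculative_logits is not None:
        # A speculative decoding loop would propose extra tokens from these
        # scores and verify them against the base model on the next pass.
        pass
    return next_token_logits, past, image_hidden_states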