diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
index 7d7bf23d..ee4cdb08 100644
--- a/server/text_generation_server/models/custom_modeling/idefics_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
@@ -281,11 +281,11 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):
         )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        output = self.fc(input)
+        output, speculative_logits = self.fc(input)
         additional_features = self.additional_fc(input)
         output = torch.cat((output, additional_features), -1)
 
-        return output
+        return output, speculative_logits
 
     def extra_repr(self) -> str:
         """Overwriting `nn.Linear.extra_repr` to include new parameters."""
@@ -1501,17 +1501,20 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
         )
 
         hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states)
+        logits, speculative_logits = self.lm_head(hidden_states)
 
         loss = None
 
-        return CausalLMOutputWithPastImage(
-            loss=loss,
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-            image_hidden_states=outputs.image_hidden_states,
+        return (
+            CausalLMOutputWithPastImage(
+                loss=loss,
+                logits=logits,
+                past_key_values=outputs.past_key_values,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+                image_hidden_states=outputs.image_hidden_states,
+            ),
+            speculative_logits,
         )
 
     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
diff --git a/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/server/text_generation_server/models/custom_modeling/mpt_modeling.py
index 0e755e47..9b0f8b92 100644
--- a/server/text_generation_server/models/custom_modeling/mpt_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/mpt_modeling.py
@@ -1133,7 +1133,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
             output_hidden_states=output_hidden_states,
             use_cache=use_cache,
         )
-        logits = self.lm_head(outputs.last_hidden_state)
+        logits, speculative_logits = self.lm_head(outputs.last_hidden_state)
         if self.logit_scale is not None:
             if self.logit_scale == 0:
                 warnings.warn(
@@ -1147,12 +1147,15 @@ class MPTForCausalLM(MPTPreTrainedModel):
             loss = F.cross_entropy(
                 logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)
             )
-        return CausalLMOutputWithPast(
-            loss=loss,
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
+        return (
+            CausalLMOutputWithPast(
+                loss=loss,
+                logits=logits,
+                past_key_values=outputs.past_key_values,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            ),
+            speculative_logits,
         )
 
     def prepare_inputs_for_generation(
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index a6df2ebe..c96e8152 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -662,8 +662,13 @@ class IdeficsCausalLM(Model):
         if self.has_position_ids:
             kwargs["position_ids"] = position_ids
 
-        outputs = self.model.forward(**kwargs)
-        return outputs.logits, outputs.past_key_values, outputs.image_hidden_states
+        outputs, speculative_logits = self.model.forward(**kwargs)
+        return (
+            outputs.logits,
+            speculative_logits,
+            outputs.past_key_values,
+            outputs.image_hidden_states,
+        )
 
     @tracer.start_as_current_span("generate_token")
     def generate_token(
@@ -686,7 +691,7 @@ class IdeficsCausalLM(Model):
             :, : -batch.padding_right_offset
         ]
 
-        logits, past, image_hidden_states = self.forward(
+        logits, speculative_logits, past, image_hidden_states = self.forward(
             input_ids=batch.input_ids,
             attention_mask=attention_mask,
             position_ids=batch.position_ids,
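The common thread across these hunks is that the LM head now returns a (logits, speculative_logits) tuple instead of a bare tensor, and every caller unpacks two values even when speculation is disabled. Below is a minimal sketch of that head contract; the class and parameter names (TupleSpeculativeHead, n_speculative) are illustrative assumptions, not the repository's actual definitions.

# Sketch of the tuple-returning LM head contract the diff threads through.
# Names here are hypothetical; only the (logits, speculative_logits) return
# shape mirrors what the patch assumes of self.lm_head / self.fc.

from typing import Optional, Tuple

import torch
import torch.nn as nn


class TupleSpeculativeHead(nn.Module):
    """LM head that optionally carries extra speculative projections."""

    def __init__(self, hidden_size: int, vocab_size: int, n_speculative: int = 0):
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        # One extra projection per speculated position; empty when speculation is off.
        self.speculative_heads = nn.ModuleList(
            nn.Linear(hidden_size, vocab_size, bias=False)
            for _ in range(n_speculative)
        )

    def forward(
        self, hidden: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(hidden)
        if len(self.speculative_heads) == 0:
            # Callers still unpack a 2-tuple; None signals "no speculation".
            return logits, None
        # Stack per-position speculative logits: [batch, seq, n_speculative, vocab]
        speculative_logits = torch.stack(
            [head(hidden) for head in self.speculative_heads], dim=-2
        )
        return logits, speculative_logits

A model forward then mirrors the patched call sites: logits, speculative_logits = self.lm_head(hidden_states), with speculative_logits passed up alongside the usual outputs.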