diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 590c0d57..67129ec3 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -105,7 +105,7 @@ class BLOOMSharded(CausalLM): def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ): - outputs = self.model.forward( + outputs, speculative_logits = self.model.forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -114,4 +114,4 @@ class BLOOMSharded(CausalLM): ) logits = outputs.logits - return logits, outputs.past_key_values + return logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index b7f5e9db..d4666229 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -551,7 +551,7 @@ class CausalLM(Model): def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]: # Model Forward kwargs = { "input_ids": input_ids, @@ -564,7 +564,11 @@ class CausalLM(Model): kwargs["position_ids"] = position_ids outputs = self.model.forward(**kwargs) - return outputs.logits, outputs.past_key_values + if isinstance(outputs, tuple): + outputs, speculative_logits = outputs + else: + speculative_logits = None + return outputs.logits, speculative_logits, outputs.past_key_values @tracer.start_as_current_span("generate_token") def generate_token( @@ -574,7 +578,7 @@ class CausalLM(Model): # slice the attention mask to the correct shape attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] - logits, past = self.forward( + logits, speculative_logits, past = self.forward( batch.input_ids, attention_mask, batch.position_ids, diff --git a/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/server/text_generation_server/models/custom_modeling/bloom_modeling.py index 8e3be63f..9381d164 100644 --- a/server/text_generation_server/models/custom_modeling/bloom_modeling.py +++ b/server/text_generation_server/models/custom_modeling/bloom_modeling.py @@ -904,7 +904,7 @@ class BloomForCausalLM(BloomPreTrainedModel): ) hidden_states = transformer_outputs[0] - lm_logits = self.lm_head(hidden_states) + logits, speculative_logits = self.lm_head(hidden_states) loss = None if not return_dict: @@ -913,8 +913,8 @@ class BloomForCausalLM(BloomPreTrainedModel): return CausalLMOutputWithCrossAttentions( loss=loss, - logits=lm_logits, + logits=logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, - ) + ), speculative_logits diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index cae5525c..fae9a2df 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -602,7 +602,7 @@ class Seq2SeqLM(Model): List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], ]: # Model Forward - outputs, speculative_logits = self.model.forward( + outputs = self.model.forward( input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, @@ -611,6 +611,12 @@ class Seq2SeqLM(Model): past_key_values=past_key_values, use_cache=True, ) + if isinstance(outputs, tuple): + # Our custom models + outputs, speculative_logits = outputs + else: + # Generic transformers models + speculative_logits = None return ( outputs.logits, speculative_logits,