diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index f3f52496..6cfc290e 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -1076,22 +1076,23 @@ class FlashCausalLMBatch(Batch):
             (0, padded_bs - self.cache_lengths_tensor.shape[0]),
             value=0,
         )
-        next_token_chooser_parameters = []
-        next_token_chooser_parameters.extend([r.parameters for r in self.requests])
-        pad_next_token_chooser_parameters(next_token_chooser_parameters, padded_bs)
-        # update past grammar states
-        fsm_grammar_states = [0] * padded_bs
+        if len(self.next_token_chooser.do_sample) != padded_bs:
+            next_token_chooser_parameters = []
+            next_token_chooser_parameters.extend([r.parameters for r in self.requests])
+            pad_next_token_chooser_parameters(next_token_chooser_parameters, padded_bs)
+            # update past grammar states
+            fsm_grammar_states = [0] * padded_bs
 
-        for i, req in enumerate(self.requests):
-            fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]
+            for i, req in enumerate(self.requests):
+                fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]
 
-        self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters,
-            self.next_token_chooser.dtype,
-            self.next_token_chooser.device,
-            self.next_token_chooser.tokenizer,
-            fsm_grammar_states,
-        )
+            self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+                next_token_chooser_parameters,
+                self.next_token_chooser.dtype,
+                self.next_token_chooser.device,
+                self.next_token_chooser.tokenizer,
+                fsm_grammar_states,
+            )
 
     def prepare_for_prefill(
         self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
@@ -1379,23 +1380,25 @@ class FlashCausalLMBatch(Batch):
                         self.all_input_ids_tensor[i]
                     )
             self.all_input_ids_tensor = all_input_ids_tensor
+        if len(self.next_token_chooser.do_sample) != max_padded_bs:
+            next_token_chooser_parameters = []
+            next_token_chooser_parameters.extend([r.parameters for r in self.requests])
+            pad_next_token_chooser_parameters(
+                next_token_chooser_parameters, max_padded_bs
+            )
+            # update past grammar states
+            fsm_grammar_states = [0] * max_padded_bs
 
-        next_token_chooser_parameters = []
-        next_token_chooser_parameters.extend([r.parameters for r in self.requests])
-        pad_next_token_chooser_parameters(next_token_chooser_parameters, max_padded_bs)
-        # update past grammar states
-        fsm_grammar_states = [0] * max_padded_bs
+            for i, req in enumerate(self.requests):
+                fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]
 
-        for i, req in enumerate(self.requests):
-            fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]
-
-        self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters,
-            self.next_token_chooser.dtype,
-            self.next_token_chooser.device,
-            self.next_token_chooser.tokenizer,
-            fsm_grammar_states,
-        )
+            self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+                next_token_chooser_parameters,
+                self.next_token_chooser.dtype,
+                self.next_token_chooser.device,
+                self.next_token_chooser.tokenizer,
+                fsm_grammar_states,
+            )
 
         if ADAPTER_TO_INDEX:
             if adapter_set:
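Both hunks apply the same guard: `HeterogeneousNextTokenChooser` is rebuilt only when `len(self.next_token_chooser.do_sample)` (used here as the chooser's current padded width) no longer matches the target padded batch size, instead of unconditionally on every `prepare_for_decode`/`prepare_for_prefill` call. Below is a minimal, self-contained sketch of that pattern; `Chooser`, `rebuild_if_resized`, and the repeat-last-entry padding are hypothetical simplifications, not TGI APIs.

```python
# Sketch of the "rebuild only when the padded width changed" pattern.
# `Chooser` stands in for HeterogeneousNextTokenChooser; every name here
# is a hypothetical illustration, not part of the TGI codebase.
from dataclasses import dataclass, field
from typing import List


@dataclass
class Chooser:
    do_sample: List[bool]  # one entry per padded batch slot
    fsm_grammar_states: List[int] = field(default_factory=list)


def rebuild_if_resized(chooser: Chooser, params: List[bool], padded_bs: int) -> Chooser:
    """Return `chooser` untouched when its width already equals `padded_bs`;
    otherwise build a fresh one, carrying over per-request grammar states."""
    if len(chooser.do_sample) == padded_bs:
        return chooser  # width unchanged: skip the expensive rebuild

    # Pad per-request parameters out to the bucketed batch size by
    # repeating the last entry (a simplification of the real
    # pad_next_token_chooser_parameters; assumes a non-empty batch).
    padded = params + [params[-1]] * (padded_bs - len(params))

    # Live requests keep their grammar FSM states; padded slots reset to 0.
    states = [0] * padded_bs
    for i in range(len(params)):
        states[i] = chooser.fsm_grammar_states[i]

    return Chooser(do_sample=padded, fsm_grammar_states=states)


if __name__ == "__main__":
    c = Chooser(do_sample=[True, False], fsm_grammar_states=[3, 7])
    c2 = rebuild_if_resized(c, [True, False], padded_bs=4)
    assert c2.fsm_grammar_states == [3, 7, 0, 0]
    # Same padded size again: the guard hands back the existing object.
    assert rebuild_if_resized(c2, [True, False], padded_bs=4) is c2
```

When HPU bucketing keeps a batch at the same padded size across steps, the guard skips the `from_pb` round trip entirely; when the size does change, per-request grammar FSM states are carried over and the padded slots reset to state 0, which is what both hunks of the patch do.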