remove unnecessary reinitialization of HeterogeneousNextTokenChooser to make sampling output correct

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Author: Wang, Yi A
Date:   2025-07-01 19:34:47 -07:00
parent 429dcd9c64
commit cf564ec0e2

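Why the guard matters (editor's note, not part of the commit): `HeterogeneousNextTokenChooser.from_pb` rebuilds all per-request sampling state, which includes the seeded `torch.Generator` each sampling request uses. Rebuilding the chooser on every padded-batch update therefore reseeds those generators mid-generation, and a freshly reseeded generator replays its first draw instead of advancing its random stream. A minimal, self-contained sketch of that failure mode, using toy logits and a toy seed rather than TGI code:

import torch

def sample_next(logits: torch.Tensor, generator: torch.Generator) -> int:
    # One multinomial draw from the softmax distribution.
    probs = torch.softmax(logits, dim=-1)
    return int(torch.multinomial(probs, 1, generator=generator).item())

logits = torch.randn(50)

# Buggy pattern: rebuilding the chooser every step reseeds the generator,
# so each "random" draw replays the first one.
buggy = [sample_next(logits, torch.Generator().manual_seed(42)) for _ in range(4)]

# Fixed pattern: one generator kept alive across steps; the RNG stream advances.
gen = torch.Generator().manual_seed(42)
fixed = [sample_next(logits, gen) for _ in range(4)]

print(buggy)  # four identical token ids
print(fixed)  # token ids differ across steps

The diff below applies exactly this reasoning: keep the existing chooser (and its live generators) whenever the padded batch size has not changed.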
@@ -1076,22 +1076,23 @@ class FlashCausalLMBatch(Batch):
             (0, padded_bs - self.cache_lengths_tensor.shape[0]),
             value=0,
         )
-        next_token_chooser_parameters = []
-        next_token_chooser_parameters.extend([r.parameters for r in self.requests])
-        pad_next_token_chooser_parameters(next_token_chooser_parameters, padded_bs)
-        # update past grammar states
-        fsm_grammar_states = [0] * padded_bs
-        for i, req in enumerate(self.requests):
-            fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]
-
-        self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters,
-            self.next_token_chooser.dtype,
-            self.next_token_chooser.device,
-            self.next_token_chooser.tokenizer,
-            fsm_grammar_states,
-        )
+        if len(self.next_token_chooser.do_sample) != padded_bs:
+            next_token_chooser_parameters = []
+            next_token_chooser_parameters.extend([r.parameters for r in self.requests])
+            pad_next_token_chooser_parameters(next_token_chooser_parameters, padded_bs)
+            # update past grammar states
+            fsm_grammar_states = [0] * padded_bs
+            for i, req in enumerate(self.requests):
+                fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]
+
+            self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+                next_token_chooser_parameters,
+                self.next_token_chooser.dtype,
+                self.next_token_chooser.device,
+                self.next_token_chooser.tokenizer,
+                fsm_grammar_states,
+            )

     def prepare_for_prefill(
         self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
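A note on the guard condition: `do_sample` holds one flag per (padded) slot, so `len(self.next_token_chooser.do_sample)` tracks the batch size the chooser was last built for, and the rebuild now runs only when that size and `padded_bs` disagree. A hedged sketch of the same pattern, with `ToyChooser` and `ensure_size` as illustrative stand-ins rather than TGI APIs:

import torch

class ToyChooser:
    # Stand-in for HeterogeneousNextTokenChooser: one do_sample flag and one
    # (optionally seeded) generator per slot.
    def __init__(self, seeds):
        self.do_sample = [s is not None for s in seeds]
        self.generators = [
            torch.Generator().manual_seed(s) if s is not None else None
            for s in seeds
        ]

def ensure_size(chooser, seeds, padded_bs):
    # Rebuild only when the padded batch size actually changed; otherwise the
    # existing generators (and their advanced RNG streams) stay alive.
    if len(chooser.do_sample) != padded_bs:
        chooser = ToyChooser(seeds + [None] * (padded_bs - len(seeds)))
    return chooser

seeds = [1, None, 7]
chooser = ToyChooser(seeds)
chooser = ensure_size(chooser, seeds, padded_bs=4)  # rebuilds: 3 != 4
chooser = ensure_size(chooser, seeds, padded_bs=4)  # no-op: size already matches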
@@ -1379,23 +1380,25 @@ class FlashCausalLMBatch(Batch):
                 self.all_input_ids_tensor[i]
             )
         self.all_input_ids_tensor = all_input_ids_tensor

-        next_token_chooser_parameters = []
-        next_token_chooser_parameters.extend([r.parameters for r in self.requests])
-        pad_next_token_chooser_parameters(next_token_chooser_parameters, max_padded_bs)
-        # update past grammar states
-        fsm_grammar_states = [0] * max_padded_bs
-        for i, req in enumerate(self.requests):
-            fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]
-
-        self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters,
-            self.next_token_chooser.dtype,
-            self.next_token_chooser.device,
-            self.next_token_chooser.tokenizer,
-            fsm_grammar_states,
-        )
+        if len(self.next_token_chooser.do_sample) != max_padded_bs:
+            next_token_chooser_parameters = []
+            next_token_chooser_parameters.extend([r.parameters for r in self.requests])
+            pad_next_token_chooser_parameters(
+                next_token_chooser_parameters, max_padded_bs
+            )
+            # update past grammar states
+            fsm_grammar_states = [0] * max_padded_bs
+            for i, req in enumerate(self.requests):
+                fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]
+
+            self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+                next_token_chooser_parameters,
+                self.next_token_chooser.dtype,
+                self.next_token_chooser.device,
+                self.next_token_chooser.tokenizer,
+                fsm_grammar_states,
+            )

     if ADAPTER_TO_INDEX:
         if adapter_set:
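When a rebuild is unavoidable, both hunks copy each live request's grammar FSM state into the replacement chooser, so grammar-constrained decoding resumes where it stopped instead of restarting from state 0. A small sketch of that carry-over, assuming (as the diff does) that chooser index i corresponds to request i and that padding slots start from state 0:

def carry_over_grammar_states(old_states, num_requests, max_padded_bs):
    # Padding slots get the initial FSM state 0; live requests keep theirs.
    fsm_grammar_states = [0] * max_padded_bs
    for i in range(num_requests):
        fsm_grammar_states[i] = old_states[i]
    return fsm_grammar_states

print(carry_over_grammar_states([5, 2, 9], num_requests=3, max_padded_bs=8))
# -> [5, 2, 9, 0, 0, 0, 0, 0]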