Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 11:24:53 +00:00)
Remove unnecessary reinitialization of HeterogeneousNextTokenChooser to make sampling output correct
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
commit cf564ec0e2
parent 429dcd9c64
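Why the guard fixes sampling: both call sites below previously rebuilt the batch's HeterogeneousNextTokenChooser from the request parameters on every call, even when the padded batch size had not changed. Rebuilding recreates per-request sampling state (for example, seeded generators), so a sampling request could keep replaying the same random stream instead of advancing it. The following is a minimal standalone sketch of that failure mode; ToyChooser and decode_step are hypothetical stand-ins for illustration, not TGI's actual API.

import random

class ToyChooser:
    """Stand-in for a chooser that holds per-request seeded RNG state;
    rebuilding the object resets that state."""

    def __init__(self, seeds):
        self.do_sample = [True] * len(seeds)
        self.seeds = list(seeds)
        self.generators = [random.Random(s) for s in seeds]

    def sample(self):
        # stand-in for drawing one token per request
        return [g.randrange(1000) for g in self.generators]

def decode_step(chooser, padded_bs, guarded):
    # Old behaviour: rebuild unconditionally, re-seeding the generators
    # every step. New behaviour: rebuild only when the width changed.
    if not guarded or len(chooser.do_sample) != padded_bs:
        chooser = ToyChooser(chooser.seeds)  # per-request state lost here
    return chooser, chooser.sample()

buggy = ToyChooser(seeds=[11, 22])
fixed = ToyChooser(seeds=[11, 22])
for step in range(3):
    buggy, t_old = decode_step(buggy, padded_bs=2, guarded=False)
    fixed, t_new = decode_step(fixed, padded_bs=2, guarded=True)
    print(step, "unguarded:", t_old, "guarded:", t_new)
# The unguarded run prints identical tokens at every step; the guarded
# run keeps advancing its generators, as sampling should.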
@@ -1076,22 +1076,23 @@ class FlashCausalLMBatch(Batch):
             (0, padded_bs - self.cache_lengths_tensor.shape[0]),
             value=0,
         )
-        next_token_chooser_parameters = []
-        next_token_chooser_parameters.extend([r.parameters for r in self.requests])
-        pad_next_token_chooser_parameters(next_token_chooser_parameters, padded_bs)
-        # update past grammar states
-        fsm_grammar_states = [0] * padded_bs
+        if len(self.next_token_chooser.do_sample) != padded_bs:
+            next_token_chooser_parameters = []
+            next_token_chooser_parameters.extend([r.parameters for r in self.requests])
+            pad_next_token_chooser_parameters(next_token_chooser_parameters, padded_bs)
+            # update past grammar states
+            fsm_grammar_states = [0] * padded_bs
 
-        for i, req in enumerate(self.requests):
-            fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]
+            for i, req in enumerate(self.requests):
+                fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]
 
-        self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters,
-            self.next_token_chooser.dtype,
-            self.next_token_chooser.device,
-            self.next_token_chooser.tokenizer,
-            fsm_grammar_states,
-        )
+            self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+                next_token_chooser_parameters,
+                self.next_token_chooser.dtype,
+                self.next_token_chooser.device,
+                self.next_token_chooser.tokenizer,
+                fsm_grammar_states,
+            )
 
     def prepare_for_prefill(
         self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
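For context on the helper used above: pad_next_token_chooser_parameters extends the per-request parameter list to the padded batch size, so that from_pb builds a chooser with one slot of state per padded position. A rough sketch of what that padding accomplishes, assuming greedy defaults for the filler slots (a simplification for illustration, not the real helper):

from dataclasses import dataclass

@dataclass
class Params:                 # stand-in for NextTokenChooserParameters
    temperature: float = 1.0
    do_sample: bool = False   # filler slots sample greedily, so they are inert

def pad_params(parameters, padded_bs):
    # Extend the list in place with default entries for the dummy slots.
    parameters.extend(Params() for _ in range(padded_bs - len(parameters)))

params = [Params(temperature=0.7, do_sample=True)]   # one live request
pad_params(params, padded_bs=4)
print(len(params), [p.do_sample for p in params])    # 4 [True, False, False, False]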
@@ -1379,23 +1380,25 @@ class FlashCausalLMBatch(Batch):
                 self.all_input_ids_tensor[i]
             )
         self.all_input_ids_tensor = all_input_ids_tensor
-
-        next_token_chooser_parameters = []
-        next_token_chooser_parameters.extend([r.parameters for r in self.requests])
-        pad_next_token_chooser_parameters(next_token_chooser_parameters, max_padded_bs)
-        # update past grammar states
-        fsm_grammar_states = [0] * max_padded_bs
+        if len(self.next_token_chooser.do_sample) != max_padded_bs:
+            next_token_chooser_parameters = []
+            next_token_chooser_parameters.extend([r.parameters for r in self.requests])
+            pad_next_token_chooser_parameters(
+                next_token_chooser_parameters, max_padded_bs
+            )
+            # update past grammar states
+            fsm_grammar_states = [0] * max_padded_bs
 
-        for i, req in enumerate(self.requests):
-            fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]
+            for i, req in enumerate(self.requests):
+                fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]
 
-        self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters,
-            self.next_token_chooser.dtype,
-            self.next_token_chooser.device,
-            self.next_token_chooser.tokenizer,
-            fsm_grammar_states,
-        )
+            self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+                next_token_chooser_parameters,
+                self.next_token_chooser.dtype,
+                self.next_token_chooser.device,
+                self.next_token_chooser.tokenizer,
+                fsm_grammar_states,
+            )
 
         if ADAPTER_TO_INDEX:
             if adapter_set:
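The prefill-path hunk is the same fix with max_padded_bs in place of padded_bs; in both places len(self.next_token_chooser.do_sample) serves as the chooser's current batch width, since do_sample holds one flag per padded slot. A hedged sketch of when the guard fires, assuming simple power-of-two batch bucketing (the bucketing rule here is an illustration, not necessarily TGI's):

def round_up_to_bucket(bs):
    # hypothetical power-of-two bucketing for padded batch sizes
    p = 1
    while p < bs:
        p *= 2
    return p

chooser_width = 4          # chooser was last built for a padded bs of 4
for live_bs in (3, 4, 5):  # live requests in the batch
    padded_bs = round_up_to_bucket(live_bs)
    rebuild = chooser_width != padded_bs   # the new guard
    print(f"live={live_bs} padded={padded_bs} rebuild={rebuild}")
    chooser_width = padded_bs
# live=3 padded=4 rebuild=False  -> sampling state preserved
# live=4 padded=4 rebuild=False
# live=5 padded=8 rebuild=True   -> chooser must grow, rebuild is needed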