Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-22 15:32:08 +00:00)
fix: Handle concurrent grammar requests (#1610)
This PR fixes parallel grammar requests. Currently, grammar states are not concatenated correctly when a new request is added to the batch, which results in incorrect generation. This PR updates the `concatenate` function to correctly include the previous states. Fixes #1601.
This commit is contained in:
parent e9b200369c
commit e259625b8b
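Before the diff, a minimal runnable sketch of the behavior described above: when batches are concatenated, each request's grammar FSM state must be carried over instead of being reset to 0. The `Chooser`/`Batch` stand-in classes below are hypothetical simplifications, not the real TGI types.

from dataclasses import dataclass
from typing import List


@dataclass
class Chooser:
    # One FSM state per request; 0 means "start of the grammar".
    fsm_grammar_states: List[int]


@dataclass
class Batch:
    request_ids: List[int]
    next_token_chooser: Chooser


def concatenate(batches: List[Batch]) -> Batch:
    request_ids: List[int] = []
    fsm_grammar_states: List[int] = []
    for batch in batches:
        request_ids.extend(batch.request_ids)
        # Before the fix, the merged chooser was rebuilt with every state at 0,
        # resetting in-flight grammars; the fix carries the previous states over.
        fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
    return Batch(request_ids, Chooser(fsm_grammar_states=fsm_grammar_states))


# A request already 7 steps into its grammar keeps that position when a new
# request joins the batch.
merged = concatenate([
    Batch([0], Chooser(fsm_grammar_states=[7])),
    Batch([1], Chooser(fsm_grammar_states=[0])),
])
assert merged.next_token_chooser.fsm_grammar_states == [7, 0]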
@@ -61,7 +61,7 @@
       },
       {
         "id": 29906,
-        "logprob": -0.2376709,
+        "logprob": -0.33666992,
         "special": false,
         "text": "2"
       },
@@ -180,7 +180,7 @@
       },
       {
         "id": 29906,
-        "logprob": -0.23840332,
+        "logprob": -0.33740234,
         "special": false,
         "text": "2"
       },
@@ -299,7 +299,7 @@
       },
       {
         "id": 29906,
-        "logprob": -0.23840332,
+        "logprob": -0.33740234,
         "special": false,
         "text": "2"
       },
@@ -418,7 +418,7 @@
       },
       {
         "id": 29906,
-        "logprob": -0.23840332,
+        "logprob": -0.33740234,
         "special": false,
         "text": "2"
       },
@@ -408,7 +408,7 @@ class CausalLMBatch(Batch):
             batches[dst_batch_idx].next_token_chooser.dtype,
             batches[dst_batch_idx].next_token_chooser.device,
             batches[dst_batch_idx].next_token_chooser.tokenizer,
-            hq_env.is_quantization_enabled
+            quantization_enabled=hq_env.is_quantization_enabled,
         )
 
         input_ids = batches[dst_batch_idx].input_ids
@@ -463,7 +463,11 @@ class CausalLMBatch(Batch):
             parameters.append(parameters[0])
 
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            parameters, dtype, device, tokenizer, hq_env.is_quantization_enabled
+            pb=parameters,
+            dtype=dtype,
+            device=device,
+            tokenizer=tokenizer,
+            quantization_enabled=hq_env.is_quantization_enabled,
         )
         tokenized_inputs = tokenizer(
             [r.data.inputs for r in requests] + dummy_inputs,
@@ -530,6 +530,7 @@ class FlashCausalLMBatch(Batch):
         read_offsets = []
 
         next_token_chooser_parameters = []
+        fsm_grammar_states = []
         stopping_criterias = []
         top_n_tokens = []
 
@@ -578,6 +579,7 @@ class FlashCausalLMBatch(Batch):
             read_offsets.extend(batch.read_offsets)
 
             next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
+            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
             stopping_criterias.extend(batch.stopping_criterias)
 
             top_n_tokens.extend(batch.top_n_tokens)
@@ -593,6 +595,7 @@ class FlashCausalLMBatch(Batch):
             dtype=batches[0].next_token_chooser.dtype,
             device=batches[0].next_token_chooser.device,
             tokenizer=batches[0].next_token_chooser.tokenizer,
+            fsm_grammar_states=fsm_grammar_states,
         )
 
         speculative_ids = (
@@ -467,7 +467,8 @@ class HeterogeneousNextTokenChooser:
         dtype: torch.dtype,
         device: torch.device,
         tokenizer: PreTrainedTokenizerBase,
-        quantization_enabled: bool,
+        fsm_grammar_states: Optional[List[int]] = None,
+        quantization_enabled: bool = False,
     ) -> "HeterogeneousNextTokenChooser":
         return HeterogeneousNextTokenChooser(
             watermark=[pb_.watermark for pb_ in pb],
@@ -484,7 +485,9 @@ class HeterogeneousNextTokenChooser:
             tokenizer=tokenizer,
             grammars=[pb_.grammar for pb_ in pb],
             grammar_types=[pb_.grammar_type for pb_ in pb],
-            fsm_grammar_states=[0] * len(pb),
+            fsm_grammar_states=(
+                fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
+            ),
             quantization_enabled=quantization_enabled,
         )
 
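For illustration only, a small sketch of the defaulting logic the updated `from_pb` applies to `fsm_grammar_states`. The helper name below is hypothetical; the real classmethod also receives the protobuf parameters, dtype, device, and tokenizer.

from typing import List, Optional


def resolve_fsm_grammar_states(
    num_requests: int,
    fsm_grammar_states: Optional[List[int]] = None,
) -> List[int]:
    # A fresh prefill starts every grammar FSM at state 0; a concatenated batch
    # passes in the states it has already accumulated so generation resumes
    # from the right position in each grammar.
    return fsm_grammar_states if fsm_grammar_states else [0] * num_requests


assert resolve_fsm_grammar_states(3) == [0, 0, 0]          # new batch
assert resolve_fsm_grammar_states(2, [5, 2]) == [5, 2]     # after concatenate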