diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index efe4caaa..cc8ddde8 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -46,6 +46,7 @@ from text_generation_server.utils import (
     StoppingCriteria,
     make_tokenizer_optional,
     is_tokenizer_transparent,
+    pad_next_token_chooser_parameters,
 )
 from text_generation_server.utils.debug import dbg_trace
 from text_generation_server.utils.speculate import get_speculate
@@ -399,10 +400,9 @@ class CausalLMBatch(Batch):
         parameters = [r.data.parameters for r in flat_requests]
         # append the dummy parameters for dummy requests
         batch_size = batches[dst_batch_idx].batch_size
-        parameters.extend(
-            [generate_pb2.NextTokenChooserParameters()] * (batch_size - len(flat_requests))
-        )
-
+        parameters = pad_next_token_chooser_parameters(parameters, batch_size)
+
+        # update past grammar states
         fsm_grammar_states = [0] * batch_size
         for batch in batches:
             for i, req in enumerate(batch.requests):
@@ -465,9 +465,7 @@ class CausalLMBatch(Batch):
         dummy_inputs = ["?"] * missing_inputs
         parameters = [r.parameters for r in pb.requests]
         # append the dummy parameters for dummy request
-        parameters.extend(
-            [generate_pb2.NextTokenChooserParameters()] * missing_inputs
-        )
+        parameters = pad_next_token_chooser_parameters(parameters, new_bs)

         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             pb=parameters,
diff --git a/server/text_generation_server/utils/__init__.py b/server/text_generation_server/utils/__init__.py
index c27dad77..565a7c3c 100644
--- a/server/text_generation_server/utils/__init__.py
+++ b/server/text_generation_server/utils/__init__.py
@@ -22,7 +22,8 @@ from text_generation_server.utils.tokens import (
     Sampling,
     Greedy,
     make_tokenizer_optional,
-    is_tokenizer_transparent
+    is_tokenizer_transparent,
+    pad_next_token_chooser_parameters,
 )

 __all__ = [
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index ef445964..7c5a285e 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -504,6 +504,30 @@ class HeterogeneousNextTokenChooser:
         )


+def pad_next_token_chooser_parameters(
+    parameters: List[generate_pb2.NextTokenChooserParameters],
+    expected_size: int,
+) -> List[generate_pb2.NextTokenChooserParameters]:
+    # disable all logits processors to minimize padding overhead
+    empty_parameters = generate_pb2.NextTokenChooserParameters(
+        temperature=1.0,
+        top_k=0,
+        top_p=1.0,
+        typical_p=1.0,
+        do_sample=False,
+        seed=0,
+        repetition_penalty=1.0,
+        frequency_penalty=0.0,
+        watermark=False,
+        grammar="",
+        grammar_type=0,
+    )
+    parameters.extend(
+        [empty_parameters] * (expected_size - len(parameters))
+    )
+    return parameters
+
+
 class Sampling:
     def __init__(self, seed: int, device: str = "cpu"):
         self.generator = torch.Generator("cpu")
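
A minimal usage sketch of the new helper (not part of the diff), assuming the `text_generation_server` package and its generated `generate_pb2` protobufs are importable; the request parameter values below are illustrative:

```python
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import pad_next_token_chooser_parameters

# Two real requests in a batch that must be padded up to batch size 4.
real = [
    generate_pb2.NextTokenChooserParameters(temperature=0.7, do_sample=True),
    generate_pb2.NextTokenChooserParameters(top_k=50, do_sample=True),
]
padded = pad_next_token_chooser_parameters(real, expected_size=4)

assert len(padded) == 4
# The two padding entries are greedy no-ops: every logits processor is set
# to its neutral value, so the dummy slots add no sampling overhead.
assert padded[2].temperature == 1.0 and not padded[2].do_sample
```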