Pad next token chooser parameters with empty logits processors (#151)

Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
Karol Damaszke 2024-05-29 22:43:56 +02:00 committed by GitHub
parent 1023de8048
commit 7b879fd1d8
3 changed files with 30 additions and 7 deletions
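
Context for the change: batches on this backend are padded with dummy requests up to a fixed bucket size, and every dummy slot still needs a NextTokenChooserParameters entry. The old code padded with default-constructed protobufs, whose proto3 zero defaults (temperature=0.0, top_p=0.0, ...) do not match the processors' neutral settings, so the heterogeneous chooser could end up enabling warpers for the whole padded batch. The new helper pads with explicitly neutral values instead. A minimal sketch of the difference, assuming the chooser enables a logits processor whenever any row deviates from the neutral value (an assumption about its internals, not code from this diff):

    # Sketch: why proto3 defaults make a poor padding value.
    from text_generation_server.pb import generate_pb2

    old_pad = generate_pb2.NextTokenChooserParameters()
    print(old_pad.temperature)  # 0.0, looks like an active temperature warper

    new_pad = generate_pb2.NextTokenChooserParameters(
        temperature=1.0,  # neutral: temperature warper stays off
        top_k=0,          # neutral: top-k warper stays off
        top_p=1.0,        # neutral: top-p warper stays off
        typical_p=1.0,    # neutral: typical-p warper stays off
        do_sample=False,  # dummy slots decode greedily
        repetition_penalty=1.0,
        frequency_penalty=0.0,
        watermark=False,
    )
    print(new_pad.temperature)  # 1.0, recognized as a no-op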

server/text_generation_server/models/causal_lm.py

@@ -46,6 +46,7 @@ from text_generation_server.utils import (
     StoppingCriteria,
     make_tokenizer_optional,
     is_tokenizer_transparent,
+    pad_next_token_chooser_parameters,
 )
 from text_generation_server.utils.debug import dbg_trace
 from text_generation_server.utils.speculate import get_speculate
@@ -399,10 +400,9 @@ class CausalLMBatch(Batch):
         parameters = [r.data.parameters for r in flat_requests]
         # append the dummy parameters for dummy requests
         batch_size = batches[dst_batch_idx].batch_size
-        parameters.extend(
-            [generate_pb2.NextTokenChooserParameters()] * (batch_size - len(flat_requests))
-        )
+        parameters = pad_next_token_chooser_parameters(parameters, batch_size)
         # update past grammar states
         fsm_grammar_states = [0] * batch_size
         for batch in batches:
             for i, req in enumerate(batch.requests):
@@ -465,9 +465,7 @@ class CausalLMBatch(Batch):
         dummy_inputs = ["?"] * missing_inputs
         parameters = [r.parameters for r in pb.requests]
         # append the dummy parameters for dummy request
-        parameters.extend(
-            [generate_pb2.NextTokenChooserParameters()] * missing_inputs
-        )
+        parameters = pad_next_token_chooser_parameters(parameters, new_bs)
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             pb=parameters,

server/text_generation_server/utils/__init__.py

@@ -22,7 +22,8 @@ from text_generation_server.utils.tokens import (
     Sampling,
     Greedy,
     make_tokenizer_optional,
-    is_tokenizer_transparent
+    is_tokenizer_transparent,
+    pad_next_token_chooser_parameters,
 )

 __all__ = [

server/text_generation_server/utils/tokens.py

@@ -504,6 +504,30 @@ class HeterogeneousNextTokenChooser:
         )
+
+
+def pad_next_token_chooser_parameters(
+    parameters: List[generate_pb2.NextTokenChooserParameters],
+    expected_size: int,
+) -> List[generate_pb2.NextTokenChooserParameters]:
+    # disable all logits processors to minimize padding overhead
+    empty_parameters = generate_pb2.NextTokenChooserParameters(
+        temperature=1.0,
+        top_k=0,
+        top_p=1.0,
+        typical_p=1.0,
+        do_sample=False,
+        seed=0,
+        repetition_penalty=1.0,
+        frequency_penalty=0.0,
+        watermark=False,
+        grammar="",
+        grammar_type=0,
+    )
+    parameters.extend(
+        [empty_parameters] * (expected_size - len(parameters))
+    )
+    return parameters
+
+
 class Sampling:
     def __init__(self, seed: int, device: str = "cpu"):
         self.generator = torch.Generator("cpu")
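
And a hedged usage sketch of the new helper (the batch size and sampling values are illustrative, not taken from the diff): padding three live requests up to a bucket size of eight leaves the live rows untouched and fills the remaining slots with neutral parameters, so only the live rows contribute logits-processing work.

    from typing import List
    from text_generation_server.pb import generate_pb2
    from text_generation_server.utils import pad_next_token_chooser_parameters

    # Three live requests with real sampling settings (illustrative values)
    parameters: List[generate_pb2.NextTokenChooserParameters] = [
        generate_pb2.NextTokenChooserParameters(
            temperature=0.7, top_p=0.9, do_sample=True
        )
        for _ in range(3)
    ]

    # Pad in place up to the bucketed batch size of 8
    parameters = pad_next_token_chooser_parameters(parameters, 8)

    assert len(parameters) == 8
    assert parameters[0].do_sample        # live row keeps its settings
    assert not parameters[-1].do_sample   # dummy row stays greedy and neutral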