Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-27 13:02:12 +00:00)
Pad next token chooser parameters with empty logits processors (#151)
Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
parent 1023de8048
commit 7b879fd1d8
@@ -46,6 +46,7 @@ from text_generation_server.utils import (
     StoppingCriteria,
     make_tokenizer_optional,
     is_tokenizer_transparent,
+    pad_next_token_chooser_parameters,
 )
 from text_generation_server.utils.debug import dbg_trace
 from text_generation_server.utils.speculate import get_speculate
@@ -399,10 +400,9 @@ class CausalLMBatch(Batch):
         parameters = [r.data.parameters for r in flat_requests]
         # append the dummy parameters for dummy requests
         batch_size = batches[dst_batch_idx].batch_size
-        parameters.extend(
-            [generate_pb2.NextTokenChooserParameters()] * (batch_size - len(flat_requests))
-        )
+        parameters = pad_next_token_chooser_parameters(parameters, batch_size)

         # update past grammar states
         fsm_grammar_states = [0] * batch_size
         for batch in batches:
             for i, req in enumerate(batch.requests):
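Below is a minimal, self-contained sketch of what this call site does, using a hypothetical Params dataclass as a stand-in for generate_pb2.NextTokenChooserParameters (the real protobuf type is not importable outside the server): when the merged batches hold fewer real requests than the fixed batch_size, the parameter list is padded with neutral entries until its length matches the batch.

from dataclasses import dataclass
from typing import List

@dataclass
class Params:
    # Hypothetical stand-in for generate_pb2.NextTokenChooserParameters,
    # defaulting to the same neutral values the commit pads with.
    temperature: float = 1.0
    top_p: float = 1.0
    do_sample: bool = False

def pad_params(parameters: List[Params], expected_size: int) -> List[Params]:
    # Mirrors pad_next_token_chooser_parameters: extend the list in place
    # with neutral entries until it covers every slot in the batch.
    parameters.extend([Params()] * (expected_size - len(parameters)))
    return parameters

# Two real requests merged into a fixed batch of four slots:
params = [Params(temperature=0.7, do_sample=True), Params(temperature=0.9, do_sample=True)]
padded = pad_params(params, expected_size=4)
assert len(padded) == 4
assert padded[2].temperature == 1.0 and not padded[2].do_sample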
@@ -465,9 +465,7 @@ class CausalLMBatch(Batch):
         dummy_inputs = ["?"] * missing_inputs
         parameters = [r.parameters for r in pb.requests]
         # append the dummy parameters for dummy request
-        parameters.extend(
-            [generate_pb2.NextTokenChooserParameters()] * missing_inputs
-        )
+        parameters = pad_next_token_chooser_parameters(parameters, new_bs)

         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             pb=parameters,
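The prefill path takes the same shortcut. A small sketch of how the quantities above relate, under the assumption (implied but not shown in the hunk) that missing_inputs is the gap between the padded batch size new_bs and the number of real requests:

new_bs = 8
requests = ["cat", "dog", "bird"]         # stand-ins for pb.requests
missing_inputs = new_bs - len(requests)
dummy_inputs = ["?"] * missing_inputs     # as in the hunk above
assert len(requests) + len(dummy_inputs) == new_bs
# Both the dummy inputs and the padded parameters now line up one-to-one
# with the new_bs slots handed to HeterogeneousNextTokenChooser.from_pb.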
@@ -22,7 +22,8 @@ from text_generation_server.utils.tokens import (
     Sampling,
     Greedy,
     make_tokenizer_optional,
-    is_tokenizer_transparent
+    is_tokenizer_transparent,
+    pad_next_token_chooser_parameters,
 )

 __all__ = [
@@ -504,6 +504,30 @@ class HeterogeneousNextTokenChooser:
         )


+def pad_next_token_chooser_parameters(
+    parameters: List[generate_pb2.NextTokenChooserParameters],
+    expected_size: int,
+) -> List[generate_pb2.NextTokenChooserParameters]:
+    # disable all logits processors to minimize padding overhead
+    empty_parameters = generate_pb2.NextTokenChooserParameters(
+        temperature=1.0,
+        top_k=0,
+        top_p=1.0,
+        typical_p=1.0,
+        do_sample=False,
+        seed=0,
+        repetition_penalty=1.0,
+        frequency_penalty=0.0,
+        watermark=False,
+        grammar="",
+        grammar_type=0,
+    )
+    parameters.extend(
+        [empty_parameters] * (expected_size - len(parameters))
+    )
+    return parameters
+
+
 class Sampling:
     def __init__(self, seed: int, device: str = "cpu"):
         self.generator = torch.Generator("cpu")
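Why these particular values count as "empty": each one is the identity for its logits processor, so the padded slots trigger no extra sampling work. A minimal sketch with plain Python lists standing in for logit tensors (these helpers are illustrations, not the repo's warpers):

def apply_temperature(logits, temperature=1.0):
    # temperature=1.0 divides every logit by 1, a no-op
    return [l / temperature for l in logits]

def apply_repetition_penalty(logits, penalty=1.0):
    # the usual rule: divide positive logits, multiply negative ones;
    # penalty=1.0 leaves both cases unchanged
    return [l / penalty if l > 0 else l * penalty for l in logits]

logits = [2.0, -1.5, 0.25]
assert apply_temperature(logits) == logits
assert apply_repetition_penalty(logits) == logits

Likewise top_k=0, top_p=1.0, and typical_p=1.0 disable their filters, frequency_penalty=0.0 adds nothing, do_sample=False picks the argmax so seed=0 is never consulted, watermark=False skips watermarking, and an empty grammar applies no constraint.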