mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
feat(server): support vectorized warpers in flash causal lm
This commit is contained in:
parent
218c9adaa5
commit
f9e3a3bb91
@ -39,10 +39,11 @@ class BloomCausalLMBatch(CausalLMBatch):
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "CausalLMBatch":
|
||||
batch = super(BloomCausalLMBatch, cls).from_pb(
|
||||
pb=pb, tokenizer=tokenizer, device=device
|
||||
pb=pb, tokenizer=tokenizer, dtype=dtype, device=device
|
||||
)
|
||||
batch.keys_head_dim_last = False
|
||||
return batch
|
||||
|
@ -66,6 +66,7 @@ class CausalLMBatch(Batch):
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "CausalLMBatch":
|
||||
inputs = []
|
||||
|
@ -19,9 +19,8 @@ from text_generation_server.models.types import (
|
||||
)
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import (
|
||||
NextTokenChooser,
|
||||
StoppingCriteria,
|
||||
Sampling,
|
||||
HeterogeneousNextTokenChooser
|
||||
)
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
@ -48,7 +47,7 @@ class FlashCausalLMBatch(Batch):
|
||||
|
||||
# All tokens
|
||||
all_input_ids: List[List[int]]
|
||||
all_input_ids_tensor: List[torch.Tensor]
|
||||
all_input_ids_tensor: torch.Tensor
|
||||
|
||||
# Lengths of all generations present in the batch
|
||||
input_lengths: List[int]
|
||||
@ -56,7 +55,7 @@ class FlashCausalLMBatch(Batch):
|
||||
read_offsets: List[Optional[int]]
|
||||
|
||||
# Generation helpers
|
||||
next_token_choosers: List[NextTokenChooser]
|
||||
next_token_chooser: HeterogeneousNextTokenChooser
|
||||
stopping_criterias: List[StoppingCriteria]
|
||||
|
||||
# Maximum number of tokens this batch will grow to
|
||||
@ -75,6 +74,7 @@ class FlashCausalLMBatch(Batch):
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "FlashCausalLMBatch":
|
||||
position_ids = []
|
||||
@ -87,13 +87,14 @@ class FlashCausalLMBatch(Batch):
|
||||
all_input_ids = []
|
||||
requests_idx_mapping = {}
|
||||
|
||||
next_token_choosers = []
|
||||
next_token_chooser_parameters = []
|
||||
stopping_criterias = []
|
||||
|
||||
# Cumulative length
|
||||
cumulative_length = 0
|
||||
|
||||
max_tokens = 0
|
||||
max_length = 0
|
||||
|
||||
# Parse batch
|
||||
for i, r in enumerate(pb.requests):
|
||||
@ -119,7 +120,7 @@ class FlashCausalLMBatch(Batch):
|
||||
# Add cumulative lengths of all previous inputs
|
||||
cu_seqlens.append(cumulative_length + input_length)
|
||||
|
||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
||||
next_token_chooser_parameters.append(r.parameters)
|
||||
|
||||
stopping_criteria = StoppingCriteria.from_pb(
|
||||
r.stopping_parameters, tokenizer
|
||||
@ -130,11 +131,26 @@ class FlashCausalLMBatch(Batch):
|
||||
# Update
|
||||
cumulative_length += input_length
|
||||
max_tokens += input_length + max_new_tokens
|
||||
max_length = max(max_length, input_length + max_new_tokens)
|
||||
|
||||
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||
next_token_chooser_parameters, dtype, device
|
||||
)
|
||||
|
||||
# Padded all_input_ids_tensor
|
||||
all_input_ids_tensor = np.zeros(
|
||||
(len(all_input_ids), max_length), dtype=np.int64
|
||||
)
|
||||
for i, input_ids in enumerate(all_input_ids):
|
||||
all_input_ids_tensor[i, : len(input_ids)] = input_ids
|
||||
|
||||
# Create tensors on device
|
||||
input_ids = torch.tensor(
|
||||
np.concatenate(all_input_ids), dtype=torch.int64, device=device
|
||||
)
|
||||
all_input_ids_tensor = torch.tensor(
|
||||
all_input_ids_tensor, dtype=torch.int64, device=device
|
||||
)
|
||||
position_ids = torch.tensor(
|
||||
np.concatenate(position_ids), dtype=torch.int32, device=device
|
||||
)
|
||||
@ -154,8 +170,8 @@ class FlashCausalLMBatch(Batch):
|
||||
prefix_offsets=prefix_offsets,
|
||||
read_offsets=read_offsets,
|
||||
all_input_ids=all_input_ids,
|
||||
all_input_ids_tensor=[],
|
||||
next_token_choosers=next_token_choosers,
|
||||
all_input_ids_tensor=all_input_ids_tensor,
|
||||
next_token_chooser=next_token_chooser,
|
||||
stopping_criterias=stopping_criterias,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
@ -176,31 +192,29 @@ class FlashCausalLMBatch(Batch):
|
||||
# New values after filtering
|
||||
requests_idx_mapping = {}
|
||||
|
||||
input_ids = self.input_ids.new_empty(len(request_ids))
|
||||
position_ids = self.position_ids.new_empty(len(request_ids))
|
||||
# Used to index into tensors
|
||||
indices = []
|
||||
|
||||
# Create on CPU to only move to GPU once instead of at every copy
|
||||
cu_seqlens = torch.zeros(len(request_ids) + 1, dtype=torch.int32)
|
||||
cu_seqlens_q = torch.arange(
|
||||
0, len(request_ids) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
|
||||
)
|
||||
cu_seqlens_q = self.cu_seqlens_q[: len(request_ids) + 1]
|
||||
max_seqlen = 0
|
||||
past_key_values = []
|
||||
|
||||
requests = []
|
||||
all_input_ids = []
|
||||
all_input_ids_tensor = []
|
||||
|
||||
input_lengths = []
|
||||
prefix_offsets = []
|
||||
read_offsets = []
|
||||
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
|
||||
max_tokens = 0
|
||||
|
||||
for i, request_id in enumerate(request_ids):
|
||||
idx = self.requests_idx_mapping[request_id]
|
||||
indices.append(idx)
|
||||
requests_idx_mapping[request_id] = i
|
||||
|
||||
requests.append(self.requests[idx])
|
||||
@ -208,28 +222,21 @@ class FlashCausalLMBatch(Batch):
|
||||
# Get length
|
||||
request_input_length = self.input_lengths[idx]
|
||||
|
||||
# Copy tensors (GPU)
|
||||
input_ids[i] = self.input_ids[idx]
|
||||
position_ids[i] = self.position_ids[idx]
|
||||
|
||||
# Copy to tensor (CPU)
|
||||
cu_seqlens[i + 1] = cumulative_length + request_input_length
|
||||
max_seqlen = max(max_seqlen, request_input_length)
|
||||
|
||||
# Slice from past
|
||||
past_key_values.append(
|
||||
self.past_key_values[:, self.cu_seqlens[idx] : self.cu_seqlens[idx + 1]]
|
||||
self.past_key_values[:, self.cu_seqlens[idx]: self.cu_seqlens[idx + 1]]
|
||||
)
|
||||
|
||||
all_input_ids.append(self.all_input_ids[idx])
|
||||
all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
|
||||
|
||||
input_lengths.append(request_input_length)
|
||||
prefix_offsets.append(self.prefix_offsets[idx])
|
||||
read_offsets.append(self.read_offsets[idx])
|
||||
|
||||
next_token_choosers.append(self.next_token_choosers[idx])
|
||||
|
||||
stopping_criteria = self.stopping_criterias[idx]
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
|
||||
@ -258,6 +265,12 @@ class FlashCausalLMBatch(Batch):
|
||||
# Cat all past
|
||||
past_key_values = torch.cat(past_key_values, dim=1)
|
||||
|
||||
# Index into tensors
|
||||
input_ids = self.input_ids[indices]
|
||||
position_ids = self.position_ids[indices]
|
||||
all_input_ids_tensor = self.all_input_ids_tensor[indices]
|
||||
next_token_chooser = self.next_token_chooser.filter(indices)
|
||||
|
||||
# Move to GPU now that we have the whole tensor
|
||||
cu_seqlens = cu_seqlens.to(self.cu_seqlens.device)
|
||||
|
||||
@ -276,7 +289,7 @@ class FlashCausalLMBatch(Batch):
|
||||
read_offsets=read_offsets,
|
||||
all_input_ids=all_input_ids,
|
||||
all_input_ids_tensor=all_input_ids_tensor,
|
||||
next_token_choosers=next_token_choosers,
|
||||
next_token_chooser=next_token_chooser,
|
||||
stopping_criterias=stopping_criterias,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
@ -290,6 +303,7 @@ class FlashCausalLMBatch(Batch):
|
||||
|
||||
total_batch_size = sum([len(b) for b in batches])
|
||||
|
||||
dtype = batches[0].past_key_values.dtype
|
||||
device = batches[0].input_ids.device
|
||||
|
||||
input_ids = batches[0].input_ids.new_empty(total_batch_size)
|
||||
@ -302,19 +316,19 @@ class FlashCausalLMBatch(Batch):
|
||||
past_key_values = []
|
||||
|
||||
all_input_ids = []
|
||||
all_input_ids_tensor = []
|
||||
|
||||
input_lengths = []
|
||||
prefix_offsets = []
|
||||
read_offsets = []
|
||||
|
||||
next_token_choosers = []
|
||||
next_token_chooser_parameters = []
|
||||
stopping_criterias = []
|
||||
|
||||
# Cumulative length
|
||||
cumulative_batch_size = 0
|
||||
cumulative_length = 0
|
||||
max_tokens = 0
|
||||
max_length = 0
|
||||
|
||||
for i, batch in enumerate(batches):
|
||||
requests.extend(batch.requests)
|
||||
@ -347,25 +361,54 @@ class FlashCausalLMBatch(Batch):
|
||||
)
|
||||
|
||||
all_input_ids.extend(batch.all_input_ids)
|
||||
all_input_ids_tensor.extend(batch.all_input_ids_tensor)
|
||||
|
||||
input_lengths.extend(batch.input_lengths)
|
||||
prefix_offsets.extend(batch.prefix_offsets)
|
||||
read_offsets.extend(batch.read_offsets)
|
||||
|
||||
next_token_choosers.extend(batch.next_token_choosers)
|
||||
next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
|
||||
stopping_criterias.extend(batch.stopping_criterias)
|
||||
|
||||
# Update
|
||||
cumulative_length += batch.cu_seqlens[-1]
|
||||
cumulative_batch_size += len(batch)
|
||||
max_tokens += batch.max_tokens
|
||||
max_length = max(
|
||||
max_length,
|
||||
max(
|
||||
input_length
|
||||
+ stopping_criteria.max_new_tokens
|
||||
- stopping_criteria.current_tokens
|
||||
for input_length, stopping_criteria in zip(
|
||||
batch.input_lengths, batch.stopping_criterias
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
all_input_ids_tensor = torch.zeros(
|
||||
(total_batch_size, max_length), dtype=torch.int64, device=device
|
||||
)
|
||||
|
||||
cumulative_batch_size = 0
|
||||
for i, batch in enumerate(batches):
|
||||
start_index = cumulative_batch_size
|
||||
end_index = cumulative_batch_size + len(batch)
|
||||
|
||||
all_input_ids_tensor[
|
||||
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
|
||||
] = batch.all_input_ids_tensor
|
||||
|
||||
cumulative_batch_size += len(batch)
|
||||
|
||||
# Cat past
|
||||
past_key_values = torch.cat(past_key_values, dim=1)
|
||||
# Create final tensor on GPU
|
||||
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
|
||||
|
||||
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||
next_token_chooser_parameters, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
return FlashCausalLMBatch(
|
||||
batch_id=batches[0].batch_id,
|
||||
requests=requests,
|
||||
@ -381,7 +424,7 @@ class FlashCausalLMBatch(Batch):
|
||||
read_offsets=read_offsets,
|
||||
all_input_ids=all_input_ids,
|
||||
all_input_ids_tensor=all_input_ids_tensor,
|
||||
next_token_choosers=next_token_choosers,
|
||||
next_token_chooser=next_token_chooser,
|
||||
stopping_criterias=stopping_criterias,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
@ -463,6 +506,7 @@ class FlashCausalLM(Model):
|
||||
self, batch: FlashCausalLMBatch
|
||||
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
|
||||
prefill = batch.past_key_values is None
|
||||
single_request = len(batch) == 1
|
||||
|
||||
if prefill and len(batch) == 1:
|
||||
# Ask to pre-allocate kv to its max size
|
||||
@ -483,6 +527,17 @@ class FlashCausalLM(Model):
|
||||
pre_allocate_past_size,
|
||||
)
|
||||
|
||||
if prefill:
|
||||
next_token_logits = (
|
||||
out[-1:] if single_request else out[batch.cu_seqlens[1:] - 1]
|
||||
)
|
||||
else:
|
||||
next_token_logits = out
|
||||
|
||||
next_input_ids, next_token_logprobs = batch.next_token_chooser(
|
||||
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
|
||||
)
|
||||
|
||||
if prefill:
|
||||
if len(batch) > 1:
|
||||
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
|
||||
@ -493,15 +548,11 @@ class FlashCausalLM(Model):
|
||||
batch.cu_seqlens_q = torch.arange(
|
||||
0, len(batch) + 1, device=self.device, dtype=torch.int32
|
||||
)
|
||||
next_input_ids = batch.input_ids.new_empty(len(batch))
|
||||
next_position_ids = batch.position_ids.new_empty(len(batch))
|
||||
else:
|
||||
prefill_logprobs = None
|
||||
next_input_ids = batch.input_ids
|
||||
next_position_ids = batch.position_ids
|
||||
|
||||
next_token_logprobs = out.new_empty(len(batch))
|
||||
|
||||
# Prepare past for next decode
|
||||
if len(batch) > 1:
|
||||
# Used to slice next batch past
|
||||
@ -552,7 +603,6 @@ class FlashCausalLM(Model):
|
||||
# Zipped iterator
|
||||
iterator = zip(
|
||||
batch.input_lengths,
|
||||
batch.next_token_choosers,
|
||||
batch.stopping_criterias,
|
||||
batch.all_input_ids,
|
||||
)
|
||||
@ -564,7 +614,6 @@ class FlashCausalLM(Model):
|
||||
# For each member of the batch
|
||||
for i, (
|
||||
input_length,
|
||||
next_token_chooser,
|
||||
stopping_criteria,
|
||||
all_input_ids,
|
||||
) in enumerate(iterator):
|
||||
@ -573,21 +622,6 @@ class FlashCausalLM(Model):
|
||||
end_index = cumulative_length + input_length
|
||||
|
||||
if prefill:
|
||||
# Prefill mode
|
||||
# out is of shape [cumulative_sequence_lengths, vocab_size]
|
||||
# only take last token logit
|
||||
logits = out[end_index - 1 : end_index]
|
||||
|
||||
# Create all_input_ids_tensor that will be used by token warpers (for example, RepetitionPenalty)
|
||||
all_input_ids_tensor = batch.input_ids.new_empty(
|
||||
input_length + stopping_criteria.max_new_tokens
|
||||
)
|
||||
# Copy from batch.input_ids to all_input_ids_tensor
|
||||
all_input_ids_tensor[:input_length] = batch.input_ids[
|
||||
start_index:end_index
|
||||
]
|
||||
batch.all_input_ids_tensor.append(all_input_ids_tensor)
|
||||
|
||||
# Initialize position_ids
|
||||
# In decode, we do not need this as we can just increment position ids
|
||||
next_position_ids[i] = batch.position_ids[end_index - 1]
|
||||
@ -596,32 +630,13 @@ class FlashCausalLM(Model):
|
||||
# Copy batch.input_ids to prefill_token_indices
|
||||
if len(batch) > 1:
|
||||
prefill_tokens_indices[
|
||||
start_index : end_index - 1
|
||||
] = batch.input_ids[start_index + 1 : end_index]
|
||||
start_index: end_index - 1
|
||||
] = batch.input_ids[start_index + 1: end_index]
|
||||
else:
|
||||
# Set prefill_tokens_indices to the correct slice
|
||||
prefill_tokens_indices = batch.input_ids[
|
||||
start_index + 1 : end_index
|
||||
]
|
||||
else:
|
||||
# Decode mode
|
||||
# out is of shape [batch_size, vocab_size]
|
||||
logits = out[i].view(1, -1)
|
||||
prefill_tokens_indices = batch.input_ids
|
||||
|
||||
all_input_ids_tensor = batch.all_input_ids_tensor[i]
|
||||
|
||||
# Select next token
|
||||
next_token_id, logprob = next_token_chooser(
|
||||
all_input_ids_tensor[None, :input_length], logits
|
||||
)
|
||||
|
||||
# Add to all_input_ids_tensor
|
||||
next_token_id_squeezed = next_token_id.view(1)
|
||||
all_input_ids_tensor[input_length] = next_token_id_squeezed
|
||||
|
||||
# Set values
|
||||
next_input_ids[i] = next_token_id_squeezed
|
||||
next_token_logprobs[i] = logprob[-1, next_token_id].view(1)
|
||||
batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]
|
||||
|
||||
cumulative_length += input_length
|
||||
|
||||
@ -651,10 +666,11 @@ class FlashCausalLM(Model):
|
||||
batch.input_lengths,
|
||||
batch.prefix_offsets,
|
||||
batch.read_offsets,
|
||||
batch.next_token_choosers,
|
||||
batch.stopping_criterias,
|
||||
batch.all_input_ids,
|
||||
batch.all_input_ids_tensor,
|
||||
batch.next_token_chooser.do_sample,
|
||||
batch.next_token_chooser.seeds,
|
||||
next_token_ids,
|
||||
next_token_logprobs,
|
||||
)
|
||||
@ -665,10 +681,11 @@ class FlashCausalLM(Model):
|
||||
input_length,
|
||||
prefix_offset,
|
||||
read_offset,
|
||||
next_token_chooser,
|
||||
stopping_criteria,
|
||||
all_input_ids,
|
||||
all_input_ids_tensor,
|
||||
do_sample,
|
||||
seed,
|
||||
next_token_id,
|
||||
next_token_logprob,
|
||||
) in enumerate(iterator):
|
||||
@ -700,16 +717,13 @@ class FlashCausalLM(Model):
|
||||
if stop:
|
||||
# Decode generated tokens
|
||||
output_text = self.decode(
|
||||
all_input_ids[-stopping_criteria.current_tokens :]
|
||||
all_input_ids[-stopping_criteria.current_tokens:]
|
||||
)
|
||||
# Get seed
|
||||
if isinstance(next_token_chooser.choice, Sampling):
|
||||
seed = next_token_chooser.choice.seed
|
||||
else:
|
||||
seed = None
|
||||
|
||||
generated_text = GeneratedText(
|
||||
output_text, stopping_criteria.current_tokens, reason, seed
|
||||
output_text,
|
||||
stopping_criteria.current_tokens,
|
||||
reason,
|
||||
seed if do_sample else None,
|
||||
)
|
||||
else:
|
||||
generated_text = None
|
||||
@ -718,7 +732,7 @@ class FlashCausalLM(Model):
|
||||
if prefill:
|
||||
# Remove generated token to only have prefill and add nan for first prompt token
|
||||
request_prefill_logprobs = [float("nan")] + prefill_logprobs[
|
||||
start_index : end_index - 1
|
||||
start_index: end_index - 1
|
||||
]
|
||||
prefill_token_ids = all_input_ids[:-1]
|
||||
prefill_texts = self.tokenizer.batch_decode(
|
||||
@ -751,8 +765,9 @@ class FlashCausalLM(Model):
|
||||
batch.prefix_offsets[i] = prefix_offset
|
||||
batch.read_offsets[i] = read_offset
|
||||
batch.all_input_ids[i] = all_input_ids
|
||||
batch.max_seqlen = batch.max_seqlen + 1
|
||||
cumulative_length += input_length
|
||||
|
||||
batch.max_seqlen = batch.max_seqlen + 1
|
||||
|
||||
# No need to return a batch if we know that all requests stopped
|
||||
return generations, batch if not stopped else None
|
||||
|
@ -231,6 +231,7 @@ class FlashSantacoderSharded(FlashSantacoder):
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
decode_buffer=1,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
@ -89,6 +89,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "GalacticaCausalLMBatch":
|
||||
inputs = []
|
||||
|
@ -71,6 +71,7 @@ class Seq2SeqLMBatch(Batch):
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "Seq2SeqLMBatch":
|
||||
"""Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
|
||||
|
@ -21,6 +21,7 @@ class Batch(ABC):
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "Batch":
|
||||
raise NotImplementedError
|
||||
|
@ -55,7 +55,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
|
||||
async def Prefill(self, request, context):
|
||||
batch = self.model.batch_type.from_pb(
|
||||
request.batch, self.model.tokenizer, self.model.device
|
||||
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
|
||||
)
|
||||
|
||||
generations, next_batch = self.model.generate_token(batch)
|
||||
|
@ -9,13 +9,13 @@ from text_generation_server.utils.hub import (
|
||||
RevisionNotFoundError,
|
||||
)
|
||||
from text_generation_server.utils.tokens import (
|
||||
Greedy,
|
||||
NextTokenChooser,
|
||||
Sampling,
|
||||
HeterogeneousNextTokenChooser,
|
||||
StoppingCriteria,
|
||||
StopSequenceCriteria,
|
||||
FinishReason,
|
||||
)
|
||||
from text_generation_server.utils.logits_process import Sampling, Greedy
|
||||
|
||||
__all__ = [
|
||||
"convert_file",
|
||||
@ -25,6 +25,7 @@ __all__ = [
|
||||
"weight_hub_files",
|
||||
"download_weights",
|
||||
"EntryNotFoundError",
|
||||
"HeterogeneousNextTokenChooser",
|
||||
"LocalEntryNotFoundError",
|
||||
"RevisionNotFoundError",
|
||||
"Greedy",
|
||||
|
374
server/text_generation_server/utils/logits_process.py
Normal file
374
server/text_generation_server/utils/logits_process.py
Normal file
@ -0,0 +1,374 @@
|
||||
import math
|
||||
import torch
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import Optional, List
|
||||
|
||||
from transformers import (
|
||||
LogitsWarper,
|
||||
LogitsProcessor,
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
)
|
||||
|
||||
|
||||
class Sampling:
|
||||
def __init__(self, seed: int, device: str = "cpu"):
|
||||
self.generator = torch.Generator(device)
|
||||
self.generator.manual_seed(seed)
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, logits):
|
||||
probs = torch.nn.functional.softmax(logits, -1)
|
||||
# Avoid GPU<->CPU sync done by torch multinomial
|
||||
# See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
|
||||
q = torch.empty_like(probs).exponential_(1, generator=self.generator)
|
||||
return probs.div_(q).argmax()
|
||||
|
||||
|
||||
class Greedy:
|
||||
def __call__(self, logits):
|
||||
return logits.argmax(dim=-1)
|
||||
|
||||
|
||||
class StaticWarper:
|
||||
def __init__(
|
||||
self,
|
||||
temperature=1.0,
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
typical_p=None,
|
||||
):
|
||||
self.warpers = []
|
||||
|
||||
if temperature is not None and temperature != 1.0:
|
||||
temperature = float(temperature)
|
||||
self.warpers.append(TemperatureLogitsWarper(temperature))
|
||||
if top_k is not None and top_k != 0:
|
||||
self.warpers.append(TopKLogitsWarper(top_k=top_k))
|
||||
if top_p is not None and top_p < 1.0:
|
||||
self.warpers.append(TopPLogitsWarper(top_p=top_p))
|
||||
if typical_p is not None and typical_p < 1.0:
|
||||
self.warpers.append(TypicalLogitsWarper(mass=typical_p))
|
||||
|
||||
self.cuda_graph = None
|
||||
self.static_scores = None
|
||||
self.static_warped_scores = None
|
||||
self.static_next_logprob = None
|
||||
|
||||
def __call__(self, scores):
|
||||
if self.cuda_graph is None:
|
||||
self.static_scores = scores
|
||||
self.cuda_graph = torch.cuda.CUDAGraph()
|
||||
|
||||
with torch.cuda.graph(self.cuda_graph):
|
||||
for warper in self.warpers:
|
||||
self.static_warped_scores = warper(None, self.static_scores)
|
||||
|
||||
# Compute logprobs
|
||||
self.static_next_logprob = torch.log_softmax(
|
||||
self.static_warped_scores, -1
|
||||
)
|
||||
|
||||
self.static_scores.copy_(scores)
|
||||
self.cuda_graph.replay()
|
||||
|
||||
return self.static_warped_scores, self.static_next_logprob
|
||||
|
||||
|
||||
@lru_cache(10)
|
||||
def static_warper(
|
||||
temperature: Optional[float],
|
||||
top_k: Optional[int],
|
||||
top_p: Optional[float],
|
||||
typical_p: Optional[float],
|
||||
) -> StaticWarper:
|
||||
return StaticWarper(
|
||||
temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
|
||||
)
|
||||
|
||||
|
||||
class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
|
||||
r"""
|
||||
[`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.
|
||||
This version allows for a separate value for each sample and runs inplace when possible.
|
||||
It doesn't validate inputs.
|
||||
|
||||
Args:
|
||||
repetition_penalty (`List[float]`):
|
||||
The parameter for repetition penalty. 1.0 means no penalty. See [this
|
||||
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||
"""
|
||||
|
||||
def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
|
||||
self.penalty = torch.tensor(penalty, dtype=dtype, device=device).unsqueeze(1)
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
score = torch.gather(scores, 1, input_ids)
|
||||
|
||||
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
||||
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
|
||||
|
||||
scores.scatter_(1, input_ids, score)
|
||||
return scores
|
||||
|
||||
def filter(self, indices):
|
||||
self.penalty = self.penalty[indices]
|
||||
return self
|
||||
|
||||
|
||||
class HeterogeneousTemperatureLogitsWarper:
|
||||
r"""
|
||||
[`LogitsWarper`] for temperature (exponential scaling output probability distribution).
|
||||
This version allows for a separate value for each sample and runs inplace when possible.
|
||||
It doesn't validate inputs.
|
||||
|
||||
Args:
|
||||
temperature (`float`):
|
||||
The value used to module the logits distribution.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, temperature: List[float], dtype: torch.dtype, device: torch.device
|
||||
):
|
||||
self.temperature = torch.tensor(
|
||||
temperature, dtype=dtype, device=device
|
||||
).unsqueeze(1)
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
scores.div_(self.temperature)
|
||||
return scores
|
||||
|
||||
def filter(self, indices):
|
||||
self.temperature = self.temperature[indices]
|
||||
return self
|
||||
|
||||
|
||||
class HeterogeneousTopPLogitsWarper(LogitsWarper):
|
||||
"""
|
||||
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
|
||||
This version allows for a separate value for each sample and runs inplace when possible.
|
||||
It doesn't validate inputs.
|
||||
|
||||
Args:
|
||||
top_p (`float`):
|
||||
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||
higher are kept for generation.
|
||||
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
|
||||
All filtered values will be set to this float value.
|
||||
min_tokens_to_keep (`int`, *optional*, defaults to 1):
|
||||
Minimum number of tokens that cannot be filtered.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
top_p: List[float],
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
filter_value: float = -math.inf,
|
||||
min_tokens_to_keep: int = 1,
|
||||
):
|
||||
self.top_p_opposite = 1 - torch.tensor(
|
||||
top_p, dtype=dtype, device=device
|
||||
).unsqueeze(1)
|
||||
self.filter_value = filter_value
|
||||
self.min_tokens_to_keep = min_tokens_to_keep
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
sorted_logits, sorted_indices = torch.sort(scores, descending=False)
|
||||
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
||||
|
||||
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
||||
sorted_indices_to_remove = cumulative_probs <= self.top_p_opposite
|
||||
if self.min_tokens_to_keep > 1:
|
||||
# Keep at least min_tokens_to_keep
|
||||
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
|
||||
|
||||
# scatter sorted tensors to original indexing
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||
1, sorted_indices, sorted_indices_to_remove
|
||||
)
|
||||
scores.masked_fill_(indices_to_remove, self.filter_value)
|
||||
return scores
|
||||
|
||||
def filter(self, indices):
|
||||
self.top_p_opposite = self.top_p_opposite[indices]
|
||||
return self
|
||||
|
||||
|
||||
class HeterogeneousTopKLogitsWarper(LogitsWarper):
|
||||
r"""
|
||||
[`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
|
||||
This version allows for a separate value for each sample and runs inplace when possible.
|
||||
It doesn't validate inputs.
|
||||
|
||||
Args:
|
||||
top_k (`int`):
|
||||
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
||||
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
|
||||
All filtered values will be set to this float value.
|
||||
min_tokens_to_keep (`int`, *optional*, defaults to 1):
|
||||
Minimum number of tokens that cannot be filtered.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
top_k: List[int],
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
filter_value: float = -math.inf,
|
||||
min_tokens_to_keep: int = 1,
|
||||
):
|
||||
self.top_k = top_k
|
||||
self.max_top_k = max(top_k)
|
||||
# value - 1 as we will use top_k to index and python uses 0 based numbering
|
||||
self.top_k_tensor = torch.tensor(
|
||||
[max(x - 1, min_tokens_to_keep - 1) for x in top_k],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
).unsqueeze(1)
|
||||
|
||||
# 0 is a special value that disables top_k warping for this member of the batch
|
||||
disabled = [x == 0 for x in top_k]
|
||||
|
||||
if any(disabled):
|
||||
self.top_k_disabled_mask = torch.tensor(
|
||||
disabled, dtype=torch.bool, device=device
|
||||
)
|
||||
else:
|
||||
self.top_k_disabled_mask = None
|
||||
|
||||
self.filter_value = filter_value
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
# If max_top_k is superior to the vocab, we need to clamp or the warper will fail
|
||||
if scores.size(-1) < self.max_top_k:
|
||||
max_top_k = scores.size(-1)
|
||||
top_k = torch.clamp_max(self.top_k_tensor, max_top_k)
|
||||
else:
|
||||
max_top_k = self.max_top_k
|
||||
top_k = self.top_k_tensor
|
||||
|
||||
# Get the kth score for each member of the batch
|
||||
kth_scores = torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k)
|
||||
|
||||
# Mask member of kth_scores that do not want to use top_k warping
|
||||
if self.top_k_disabled_mask is not None:
|
||||
kth_scores.masked_fill_(self.top_k_disabled_mask, self.filter_value)
|
||||
|
||||
# Remove all tokens with a probability less than the last token of the top-k
|
||||
indices_to_remove = scores < kth_scores
|
||||
scores.masked_fill_(indices_to_remove, self.filter_value)
|
||||
return scores
|
||||
|
||||
def filter(self, indices):
|
||||
self.top_k_tensor = self.top_k_tensor[indices]
|
||||
self.top_k = [self.top_k[i] for i in indices]
|
||||
self.max_top_k = max(self.top_k)
|
||||
return self
|
||||
|
||||
|
||||
class HeterogeneousTypicalLogitsWarper(LogitsWarper):
|
||||
r"""
|
||||
[`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
|
||||
Generation](https://arxiv.org/abs/2202.00666) for more information.
|
||||
This version allows for a separate value for each sample and runs inplace when possible.
|
||||
It doesn't validate inputs.
|
||||
|
||||
Args:
|
||||
mass (`float`):
|
||||
Value of typical_p between 0 and 1 inclusive, defaults to 0.9.
|
||||
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
|
||||
All filtered values will be set to this float value.
|
||||
min_tokens_to_keep (`int`, *optional*, defaults to 1):
|
||||
Minimum number of tokens that cannot be filtered.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mass: List[float],
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
filter_value: float = -math.inf,
|
||||
min_tokens_to_keep: int = 1,
|
||||
):
|
||||
self.filter_value = filter_value
|
||||
self.mass = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1)
|
||||
self.min_tokens_to_keep = min_tokens_to_keep
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
# calculate entropy
|
||||
normalized = torch.nn.functional.log_softmax(scores, dim=-1)
|
||||
p = torch.exp(normalized)
|
||||
ent = -(normalized * p).nansum(-1, keepdim=True)
|
||||
|
||||
# shift and sort
|
||||
shifted_scores = torch.abs((-normalized) - ent)
|
||||
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
|
||||
sorted_logits = scores.gather(-1, sorted_indices)
|
||||
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
||||
|
||||
# Remove tokens with cumulative mass above the threshold
|
||||
last_ind = (cumulative_probs < self.mass).sum(dim=1)
|
||||
last_ind[last_ind < 0] = 0
|
||||
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
|
||||
1, last_ind.view(-1, 1)
|
||||
)
|
||||
if self.min_tokens_to_keep > 1:
|
||||
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
|
||||
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||
1, sorted_indices, sorted_indices_to_remove
|
||||
)
|
||||
|
||||
scores = scores.masked_fill_(indices_to_remove, self.filter_value)
|
||||
return scores
|
||||
|
||||
def filter(self, indices):
|
||||
self.mass = self.mass[indices]
|
||||
return self
|
||||
|
||||
|
||||
class HeterogeneousSampling:
|
||||
r"""
|
||||
Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
|
||||
"""
|
||||
|
||||
def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device):
|
||||
self.seeds = seeds
|
||||
|
||||
self.greedy_indices = []
|
||||
self.sampling_mapping = {}
|
||||
for i, (sample, seed) in enumerate(zip(do_sample, seeds)):
|
||||
if sample:
|
||||
self.sampling_mapping[i] = Sampling(seed, device)
|
||||
else:
|
||||
self.greedy_indices.append(i)
|
||||
|
||||
self.greedy = Greedy()
|
||||
|
||||
def __call__(self, logits):
|
||||
out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
|
||||
if self.greedy_indices:
|
||||
out[self.greedy_indices] = torch.argmax(logits[self.greedy_indices], -1)
|
||||
|
||||
for i, sampling in self.sampling_mapping.items():
|
||||
out[i] = sampling(logits[i])
|
||||
return out
|
||||
|
||||
def filter(self, indices):
|
||||
new_greedy_indices = []
|
||||
new_sampling_mapping = {}
|
||||
for i, idx in enumerate(indices):
|
||||
if idx in self.sampling_mapping:
|
||||
new_sampling_mapping[i] = self.sampling_mapping[idx]
|
||||
else:
|
||||
new_greedy_indices.append(i)
|
||||
|
||||
self.greedy_indices = new_greedy_indices
|
||||
self.sampling_mapping = new_sampling_mapping
|
||||
return self
|
||||
|
||||
|
@ -1,96 +1,19 @@
|
||||
import re
|
||||
import torch
|
||||
|
||||
from functools import lru_cache
|
||||
from transformers import (
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
RepetitionPenaltyLogitsProcessor,
|
||||
PreTrainedTokenizerBase,
|
||||
PreTrainedTokenizerBase, LogitsProcessorList,
|
||||
)
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.pb.generate_pb2 import FinishReason
|
||||
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
|
||||
|
||||
|
||||
class Sampling:
|
||||
def __init__(self, seed: int, device: str = "cpu"):
|
||||
self.generator = torch.Generator(device)
|
||||
self.generator.manual_seed(seed)
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, logits):
|
||||
probs = torch.nn.functional.softmax(logits, -1)
|
||||
# Avoid GPU<->CPU sync done by torch multinomial
|
||||
# See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
|
||||
q = torch.empty_like(probs).exponential_(1, generator=self.generator)
|
||||
return probs.div_(q).argmax()
|
||||
|
||||
|
||||
class Greedy:
|
||||
def __call__(self, logits):
|
||||
return logits.argmax()
|
||||
|
||||
|
||||
class StaticWarper:
|
||||
def __init__(
|
||||
self,
|
||||
temperature=1.0,
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
typical_p=None,
|
||||
):
|
||||
self.warpers = []
|
||||
|
||||
if temperature is not None and temperature != 1.0:
|
||||
temperature = float(temperature)
|
||||
self.warpers.append(TemperatureLogitsWarper(temperature))
|
||||
if top_k is not None and top_k != 0:
|
||||
self.warpers.append(TopKLogitsWarper(top_k=top_k))
|
||||
if top_p is not None and top_p < 1.0:
|
||||
self.warpers.append(TopPLogitsWarper(top_p=top_p))
|
||||
if typical_p is not None and typical_p < 1.0:
|
||||
self.warpers.append(TypicalLogitsWarper(mass=typical_p))
|
||||
|
||||
self.cuda_graph = None
|
||||
self.static_scores = None
|
||||
self.static_warped_scores = None
|
||||
self.static_next_logprob = None
|
||||
|
||||
def __call__(self, scores):
|
||||
if self.cuda_graph is None:
|
||||
self.static_scores = scores
|
||||
self.cuda_graph = torch.cuda.CUDAGraph()
|
||||
|
||||
with torch.cuda.graph(self.cuda_graph):
|
||||
for warper in self.warpers:
|
||||
self.static_warped_scores = warper(None, self.static_scores)
|
||||
|
||||
# Compute logprobs
|
||||
self.static_next_logprob = torch.log_softmax(
|
||||
self.static_warped_scores, -1
|
||||
)
|
||||
|
||||
self.static_scores.copy_(scores)
|
||||
self.cuda_graph.replay()
|
||||
|
||||
return self.static_warped_scores, self.static_next_logprob
|
||||
|
||||
|
||||
@lru_cache(10)
|
||||
def static_warper(
|
||||
temperature: Optional[float],
|
||||
top_k: Optional[int],
|
||||
top_p: Optional[float],
|
||||
typical_p: Optional[float],
|
||||
) -> StaticWarper:
|
||||
return StaticWarper(
|
||||
temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
|
||||
)
|
||||
from text_generation_server.utils import Sampling, Greedy
|
||||
from text_generation_server.utils.logits_process import static_warper, HeterogeneousRepetitionPenaltyLogitsProcessor, \
|
||||
HeterogeneousTemperatureLogitsWarper, HeterogeneousTopKLogitsWarper, HeterogeneousTopPLogitsWarper, \
|
||||
HeterogeneousTypicalLogitsWarper, HeterogeneousSampling
|
||||
|
||||
|
||||
class NextTokenChooser:
|
||||
@ -221,3 +144,99 @@ class StoppingCriteria:
|
||||
pb.max_new_tokens,
|
||||
pb.ignore_eos_token,
|
||||
)
|
||||
|
||||
|
||||
class HeterogeneousNextTokenChooser:
|
||||
def __init__(
|
||||
self,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
watermark: List[bool],
|
||||
temperature: List[float],
|
||||
repetition_penalty: List[float],
|
||||
top_k: List[int],
|
||||
top_p: List[float],
|
||||
typical_p: List[float],
|
||||
do_sample: List[bool],
|
||||
seeds: List[int],
|
||||
):
|
||||
warpers = LogitsProcessorList()
|
||||
|
||||
if any(watermark):
|
||||
raise NotImplementedError("Watermarking not implemented")
|
||||
|
||||
if any([x != 1.0 for x in repetition_penalty]):
|
||||
warpers.append(
|
||||
HeterogeneousRepetitionPenaltyLogitsProcessor(
|
||||
repetition_penalty, dtype, device
|
||||
)
|
||||
)
|
||||
|
||||
if any([x != 1.0 for x in temperature]):
|
||||
do_sample = [
|
||||
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
|
||||
]
|
||||
warpers.append(
|
||||
HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
|
||||
)
|
||||
|
||||
if any([x != 0 for x in top_k]):
|
||||
do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
|
||||
warpers.append(HeterogeneousTopKLogitsWarper(top_k, dtype, device))
|
||||
|
||||
if any([x < 1.0 for x in top_p]):
|
||||
do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
|
||||
warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))
|
||||
|
||||
if any([x < 1.0 for x in typical_p]):
|
||||
do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
|
||||
warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))
|
||||
|
||||
self.warpers = warpers
|
||||
|
||||
num_do_sample = sum(do_sample)
|
||||
if num_do_sample == 0:
|
||||
self.choice = Greedy()
|
||||
else:
|
||||
self.choice = HeterogeneousSampling(do_sample, seeds, device)
|
||||
|
||||
self.seeds = seeds
|
||||
self.do_sample = do_sample
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
|
||||
last_token_scores = self.warpers(input_ids, scores)
|
||||
next_ids = self.choice(last_token_scores)
|
||||
next_logprobs = torch.gather(
|
||||
torch.log_softmax(last_token_scores, -1), 1, next_ids.view(-1, 1)
|
||||
).view(-1)
|
||||
|
||||
return next_ids, next_logprobs
|
||||
|
||||
def filter(self, indices):
|
||||
for warper in self.warpers:
|
||||
warper.filter(indices)
|
||||
if isinstance(self.choice, HeterogeneousSampling):
|
||||
self.choice.filter(indices)
|
||||
self.seeds = [self.seeds[i] for i in indices]
|
||||
self.do_sample = [self.do_sample[i] for i in indices]
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
pb: List[generate_pb2.NextTokenChooserParameters],
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "HeterogeneousNextTokenChooser":
|
||||
return HeterogeneousNextTokenChooser(
|
||||
watermark=[pb_.watermark for pb_ in pb],
|
||||
temperature=[pb_.temperature for pb_ in pb],
|
||||
repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
|
||||
top_k=[pb_.top_k for pb_ in pb],
|
||||
top_p=[pb_.top_p for pb_ in pb],
|
||||
typical_p=[pb_.typical_p for pb_ in pb],
|
||||
do_sample=[pb_.do_sample for pb_ in pb],
|
||||
seeds=[pb_.seed for pb_ in pb],
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user