feat(server): support vectorized warpers in flash causal lm

OlivierDehaene 2023-05-12 15:41:36 +02:00
parent 218c9adaa5
commit f9e3a3bb91
11 changed files with 630 additions and 215 deletions
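For orientation: the commit replaces the per-request `NextTokenChooser` list with a single `HeterogeneousNextTokenChooser` whose parameters are packed into tensors, so one batched operation warps every request at once. A minimal sketch of the idea, with made-up shapes and values:

import torch

# Hypothetical batch of 3 requests with different temperatures.
logits = torch.randn(3, 32000)                     # [batch, vocab]
temperature = torch.tensor([[0.7], [1.0], [1.3]])  # [batch, 1]

# Vectorized: one in-place op, broadcasting per-request values over the vocab.
logits.div_(temperature)

# Old approach: one warper call per request in a Python loop.
# for i, t in enumerate([0.7, 1.0, 1.3]):
#     logits[i] /= t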

server/text_generation_server/models/bloom.py View File

@@ -39,10 +39,11 @@ class BloomCausalLMBatch(CausalLMBatch):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
batch = super(BloomCausalLMBatch, cls).from_pb(
pb=pb, tokenizer=tokenizer, device=device
pb=pb, tokenizer=tokenizer, dtype=dtype, device=device
)
batch.keys_head_dim_last = False
return batch

server/text_generation_server/models/causal_lm.py View File

@@ -66,6 +66,7 @@ class CausalLMBatch(Batch):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
inputs = []

server/text_generation_server/models/flash_causal_lm.py View File

@@ -19,9 +19,8 @@ from text_generation_server.models.types import (
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
NextTokenChooser,
StoppingCriteria,
Sampling,
HeterogeneousNextTokenChooser
)
tracer = trace.get_tracer(__name__)
@@ -48,7 +47,7 @@ class FlashCausalLMBatch(Batch):
# All tokens
all_input_ids: List[List[int]]
all_input_ids_tensor: List[torch.Tensor]
all_input_ids_tensor: torch.Tensor
# Lengths of all generations present in the batch
input_lengths: List[int]
@@ -56,7 +55,7 @@ class FlashCausalLMBatch(Batch):
read_offsets: List[Optional[int]]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
next_token_chooser: HeterogeneousNextTokenChooser
stopping_criterias: List[StoppingCriteria]
# Maximum number of tokens this batch will grow to
@@ -72,10 +71,11 @@ class FlashCausalLMBatch(Batch):
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
device: torch.device,
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
position_ids = []
cu_seqlens = [0]
@@ -87,13 +87,14 @@ class FlashCausalLMBatch(Batch):
all_input_ids = []
requests_idx_mapping = {}
next_token_choosers = []
next_token_chooser_parameters = []
stopping_criterias = []
# Cumulative length
cumulative_length = 0
max_tokens = 0
max_length = 0
# Parse batch
for i, r in enumerate(pb.requests):
@@ -119,7 +120,7 @@ class FlashCausalLMBatch(Batch):
# Add cumulative lengths of all previous inputs
cu_seqlens.append(cumulative_length + input_length)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
next_token_chooser_parameters.append(r.parameters)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
@@ -130,11 +131,26 @@ class FlashCausalLMBatch(Batch):
# Update
cumulative_length += input_length
max_tokens += input_length + max_new_tokens
max_length = max(max_length, input_length + max_new_tokens)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device
)
# Padded all_input_ids_tensor
all_input_ids_tensor = np.zeros(
(len(all_input_ids), max_length), dtype=np.int64
)
for i, input_ids in enumerate(all_input_ids):
all_input_ids_tensor[i, : len(input_ids)] = input_ids
# Create tensors on device
input_ids = torch.tensor(
np.concatenate(all_input_ids), dtype=torch.int64, device=device
)
all_input_ids_tensor = torch.tensor(
all_input_ids_tensor, dtype=torch.int64, device=device
)
position_ids = torch.tensor(
np.concatenate(position_ids), dtype=torch.int32, device=device
)
@@ -154,8 +170,8 @@ class FlashCausalLMBatch(Batch):
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=[],
next_token_choosers=next_token_choosers,
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
max_tokens=max_tokens,
)
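A tiny sketch (made-up values) of the padding above: the ragged per-request token lists are written into one right-padded [batch, max_length] matrix so warpers such as repetition penalty can index the whole batch at once.

import numpy as np

all_input_ids = [[5, 6, 7], [8, 9]]  # two requests, hypothetical ids
max_length = 6                       # max(input_length + max_new_tokens)
padded = np.zeros((2, max_length), dtype=np.int64)
for i, ids in enumerate(all_input_ids):
    padded[i, : len(ids)] = ids
# padded == [[5, 6, 7, 0, 0, 0],
#            [8, 9, 0, 0, 0, 0]]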
@@ -176,31 +192,29 @@ class FlashCausalLMBatch(Batch):
# New values after filtering
requests_idx_mapping = {}
input_ids = self.input_ids.new_empty(len(request_ids))
position_ids = self.position_ids.new_empty(len(request_ids))
# Used to index into tensors
indices = []
# Create on CPU to only move to GPU once instead of at every copy
cu_seqlens = torch.zeros(len(request_ids) + 1, dtype=torch.int32)
cu_seqlens_q = torch.arange(
0, len(request_ids) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
)
cu_seqlens_q = self.cu_seqlens_q[: len(request_ids) + 1]
max_seqlen = 0
past_key_values = []
requests = []
all_input_ids = []
all_input_ids_tensor = []
input_lengths = []
prefix_offsets = []
read_offsets = []
next_token_choosers = []
stopping_criterias = []
max_tokens = 0
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
indices.append(idx)
requests_idx_mapping[request_id] = i
requests.append(self.requests[idx])
@@ -208,34 +222,27 @@ class FlashCausalLMBatch(Batch):
# Get length
request_input_length = self.input_lengths[idx]
# Copy tensors (GPU)
input_ids[i] = self.input_ids[idx]
position_ids[i] = self.position_ids[idx]
# Copy to tensor (CPU)
cu_seqlens[i + 1] = cumulative_length + request_input_length
max_seqlen = max(max_seqlen, request_input_length)
# Slice from past
past_key_values.append(
self.past_key_values[:, self.cu_seqlens[idx] : self.cu_seqlens[idx + 1]]
self.past_key_values[:, self.cu_seqlens[idx]: self.cu_seqlens[idx + 1]]
)
all_input_ids.append(self.all_input_ids[idx])
all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
input_lengths.append(request_input_length)
prefix_offsets.append(self.prefix_offsets[idx])
read_offsets.append(self.read_offsets[idx])
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
cumulative_length += request_input_length
max_tokens += request_input_length + (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
if single_request:
@@ -258,6 +265,12 @@ class FlashCausalLMBatch(Batch):
# Cat all past
past_key_values = torch.cat(past_key_values, dim=1)
# Index into tensors
input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices]
all_input_ids_tensor = self.all_input_ids_tensor[indices]
next_token_chooser = self.next_token_chooser.filter(indices)
# Move to GPU now that we have the whole tensor
cu_seqlens = cu_seqlens.to(self.cu_seqlens.device)
@@ -276,7 +289,7 @@ class FlashCausalLMBatch(Batch):
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_choosers=next_token_choosers,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
max_tokens=max_tokens,
)
@@ -290,6 +303,7 @@ class FlashCausalLMBatch(Batch):
total_batch_size = sum([len(b) for b in batches])
dtype = batches[0].past_key_values.dtype
device = batches[0].input_ids.device
input_ids = batches[0].input_ids.new_empty(total_batch_size)
@@ -302,19 +316,19 @@ class FlashCausalLMBatch(Batch):
past_key_values = []
all_input_ids = []
all_input_ids_tensor = []
input_lengths = []
prefix_offsets = []
read_offsets = []
next_token_choosers = []
next_token_chooser_parameters = []
stopping_criterias = []
# Cumulative length
cumulative_batch_size = 0
cumulative_length = 0
max_tokens = 0
max_length = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
@@ -347,25 +361,54 @@ class FlashCausalLMBatch(Batch):
)
all_input_ids.extend(batch.all_input_ids)
all_input_ids_tensor.extend(batch.all_input_ids_tensor)
input_lengths.extend(batch.input_lengths)
prefix_offsets.extend(batch.prefix_offsets)
read_offsets.extend(batch.read_offsets)
next_token_choosers.extend(batch.next_token_choosers)
next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
stopping_criterias.extend(batch.stopping_criterias)
# Update
cumulative_length += batch.cu_seqlens[-1]
cumulative_batch_size += len(batch)
max_tokens += batch.max_tokens
max_length = max(
max_length,
max(
input_length
+ stopping_criteria.max_new_tokens
- stopping_criteria.current_tokens
for input_length, stopping_criteria in zip(
batch.input_lengths, batch.stopping_criterias
)
),
)
all_input_ids_tensor = torch.zeros(
(total_batch_size, max_length), dtype=torch.int64, device=device
)
cumulative_batch_size = 0
for i, batch in enumerate(batches):
start_index = cumulative_batch_size
end_index = cumulative_batch_size + len(batch)
all_input_ids_tensor[
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
] = batch.all_input_ids_tensor
cumulative_batch_size += len(batch)
# Cat past
past_key_values = torch.cat(past_key_values, dim=1)
# Create final tensor on GPU
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype=dtype, device=device
)
return FlashCausalLMBatch(
batch_id=batches[0].batch_id,
requests=requests,
@@ -381,7 +424,7 @@ class FlashCausalLMBatch(Batch):
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_choosers=next_token_choosers,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
max_tokens=max_tokens,
)
@@ -438,14 +481,14 @@ class FlashCausalLM(Model):
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlens: torch.Tensor,
cu_seqlens_q: Optional[torch.Tensor],
max_s: int,
past_key_values: Optional = None,
pre_allocate_past_size: Optional[int] = None,
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlens: torch.Tensor,
cu_seqlens_q: Optional[torch.Tensor],
max_s: int,
past_key_values: Optional = None,
pre_allocate_past_size: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Model Forward
return self.model.forward(
@@ -460,15 +503,16 @@ class FlashCausalLM(Model):
@tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: FlashCausalLMBatch
self, batch: FlashCausalLMBatch
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
prefill = batch.past_key_values is None
single_request = len(batch) == 1
if prefill and len(batch) == 1:
# Ask to pre-allocate kv to its max size
# == number of tokens + max_new_tokens
pre_allocate_past_size = (
batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
)
else:
pre_allocate_past_size = None
@@ -483,6 +527,17 @@ class FlashCausalLM(Model):
pre_allocate_past_size,
)
if prefill:
next_token_logits = (
out[-1:] if single_request else out[batch.cu_seqlens[1:] - 1]
)
else:
next_token_logits = out
next_input_ids, next_token_logprobs = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
)
if prefill:
if len(batch) > 1:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@@ -493,15 +548,11 @@ class FlashCausalLM(Model):
batch.cu_seqlens_q = torch.arange(
0, len(batch) + 1, device=self.device, dtype=torch.int32
)
next_input_ids = batch.input_ids.new_empty(len(batch))
next_position_ids = batch.position_ids.new_empty(len(batch))
else:
prefill_logprobs = None
next_input_ids = batch.input_ids
next_position_ids = batch.position_ids
next_token_logprobs = out.new_empty(len(batch))
# Prepare past for next decode
if len(batch) > 1:
# Used to slice next batch past
@@ -552,7 +603,6 @@ class FlashCausalLM(Model):
# Zipped iterator
iterator = zip(
batch.input_lengths,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
)
@@ -563,31 +613,15 @@ class FlashCausalLM(Model):
# For each member of the batch
for i, (
input_length,
next_token_chooser,
stopping_criteria,
all_input_ids,
input_length,
stopping_criteria,
all_input_ids,
) in enumerate(iterator):
# Indexing metadata
start_index = cumulative_length
end_index = cumulative_length + input_length
if prefill:
# Prefill mode
# out is of shape [cumulative_sequence_lengths, vocab_size]
# only take last token logit
logits = out[end_index - 1 : end_index]
# Create all_input_ids_tensor that will be used by token warpers (for example, RepetitionPenalty)
all_input_ids_tensor = batch.input_ids.new_empty(
input_length + stopping_criteria.max_new_tokens
)
# Copy from batch.input_ids to all_input_ids_tensor
all_input_ids_tensor[:input_length] = batch.input_ids[
start_index:end_index
]
batch.all_input_ids_tensor.append(all_input_ids_tensor)
# Initialize position_ids
# In decode, we do not need this as we can just increment position ids
next_position_ids[i] = batch.position_ids[end_index - 1]
@@ -596,32 +630,13 @@ class FlashCausalLM(Model):
# Copy batch.input_ids to prefill_token_indices
if len(batch) > 1:
prefill_tokens_indices[
start_index : end_index - 1
] = batch.input_ids[start_index + 1 : end_index]
start_index: end_index - 1
] = batch.input_ids[start_index + 1: end_index]
else:
# Set prefill_tokens_indices to the correct slice
prefill_tokens_indices = batch.input_ids[
start_index + 1 : end_index
]
else:
# Decode mode
# out is of shape [batch_size, vocab_size]
logits = out[i].view(1, -1)
prefill_tokens_indices = batch.input_ids
all_input_ids_tensor = batch.all_input_ids_tensor[i]
# Select next token
next_token_id, logprob = next_token_chooser(
all_input_ids_tensor[None, :input_length], logits
)
# Add to all_input_ids_tensor
next_token_id_squeezed = next_token_id.view(1)
all_input_ids_tensor[input_length] = next_token_id_squeezed
# Set values
next_input_ids[i] = next_token_id_squeezed
next_token_logprobs[i] = logprob[-1, next_token_id].view(1)
batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]
cumulative_length += input_length
@@ -651,10 +666,11 @@ class FlashCausalLM(Model):
batch.input_lengths,
batch.prefix_offsets,
batch.read_offsets,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
batch.all_input_ids_tensor,
batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds,
next_token_ids,
next_token_logprobs,
)
@@ -665,10 +681,11 @@ class FlashCausalLM(Model):
input_length,
prefix_offset,
read_offset,
next_token_chooser,
stopping_criteria,
all_input_ids,
all_input_ids_tensor,
do_sample,
seed,
next_token_id,
next_token_logprob,
) in enumerate(iterator):
@@ -700,16 +717,13 @@ class FlashCausalLM(Model):
if stop:
# Decode generated tokens
output_text = self.decode(
all_input_ids[-stopping_criteria.current_tokens :]
all_input_ids[-stopping_criteria.current_tokens:]
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
output_text,
stopping_criteria.current_tokens,
reason,
seed if do_sample else None,
)
else:
generated_text = None
@@ -718,8 +732,8 @@ class FlashCausalLM(Model):
if prefill:
# Remove the generated token so only prefill tokens remain, and add nan for the first prompt token
request_prefill_logprobs = [float("nan")] + prefill_logprobs[
start_index : end_index - 1
]
start_index: end_index - 1
]
prefill_token_ids = all_input_ids[:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
@@ -751,8 +765,9 @@ class FlashCausalLM(Model):
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.all_input_ids[i] = all_input_ids
batch.max_seqlen = batch.max_seqlen + 1
cumulative_length += input_length
batch.max_seqlen = batch.max_seqlen + 1
# No need to return a batch if we know that all requests stopped
return generations, batch if not stopped else None
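A small sketch (made-up lengths) of the prefill gather `out[batch.cu_seqlens[1:] - 1]` used above: flash attention packs all sequences into one token dimension and `cu_seqlens` holds the cumulative boundaries, so `cu_seqlens[1:] - 1` picks the last position of every request in a single indexing op.

import torch

# Two packed sequences of lengths 3 and 2 -> 5 rows of logits.
out = torch.randn(5, 4)                      # [total_tokens, vocab]
cu_seqlens = torch.tensor([0, 3, 5])
last_token_logits = out[cu_seqlens[1:] - 1]  # rows 2 and 4, shape [2, 4]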

server/text_generation_server/models/flash_santacoder.py View File

@@ -231,6 +231,7 @@ class FlashSantacoderSharded(FlashSantacoder):
device=device,
rank=rank,
world_size=world_size,
decode_buffer=1,
)
@staticmethod

server/text_generation_server/models/galactica.py View File

@@ -89,6 +89,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "GalacticaCausalLMBatch":
inputs = []

server/text_generation_server/models/seq2seq_lm.py View File

@@ -71,6 +71,7 @@ class Seq2SeqLMBatch(Batch):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "Seq2SeqLMBatch":
"""Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""

server/text_generation_server/models/types.py View File

@@ -21,6 +21,7 @@ class Batch(ABC):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "Batch":
raise NotImplementedError

server/text_generation_server/server.py View File

@@ -55,7 +55,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
async def Prefill(self, request, context):
batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.device
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
)
generations, next_batch = self.model.generate_token(batch)

server/text_generation_server/utils/__init__.py View File

@@ -9,13 +9,13 @@ from text_generation_server.utils.hub import (
RevisionNotFoundError,
)
from text_generation_server.utils.tokens import (
Greedy,
NextTokenChooser,
Sampling,
HeterogeneousNextTokenChooser,
StoppingCriteria,
StopSequenceCriteria,
FinishReason,
)
from text_generation_server.utils.logits_process import Sampling, Greedy
__all__ = [
"convert_file",
@@ -25,6 +25,7 @@ __all__ = [
"weight_hub_files",
"download_weights",
"EntryNotFoundError",
"HeterogeneousNextTokenChooser",
"LocalEntryNotFoundError",
"RevisionNotFoundError",
"Greedy",

server/text_generation_server/utils/logits_process.py View File

@@ -0,0 +1,374 @@
import math
import torch
from functools import lru_cache
from typing import Optional, List
from transformers import (
LogitsWarper,
LogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
)
class Sampling:
def __init__(self, seed: int, device: str = "cpu"):
self.generator = torch.Generator(device)
self.generator.manual_seed(seed)
self.seed = seed
def __call__(self, logits):
probs = torch.nn.functional.softmax(logits, -1)
# Avoid GPU<->CPU sync done by torch multinomial
# See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
q = torch.empty_like(probs).exponential_(1, generator=self.generator)
return probs.div_(q).argmax()
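The `probs.div_(q).argmax()` line is the exponential-race trick: for q_i ~ Exp(1), argmax_i(p_i / q_i) is distributed according to p, so sampling never leaves the GPU. A quick empirical check, not part of the commit:

import torch
from collections import Counter

probs = torch.tensor([0.1, 0.6, 0.3])
gen = torch.Generator().manual_seed(0)
counts = Counter()
for _ in range(10_000):
    q = torch.empty_like(probs).exponential_(1, generator=gen)
    counts[int(probs.div(q).argmax())] += 1
# counts is roughly proportional to [0.1, 0.6, 0.3]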
class Greedy:
def __call__(self, logits):
return logits.argmax(dim=-1)
class StaticWarper:
def __init__(
self,
temperature=1.0,
top_k=None,
top_p=None,
typical_p=None,
):
self.warpers = []
if temperature is not None and temperature != 1.0:
temperature = float(temperature)
self.warpers.append(TemperatureLogitsWarper(temperature))
if top_k is not None and top_k != 0:
self.warpers.append(TopKLogitsWarper(top_k=top_k))
if top_p is not None and top_p < 1.0:
self.warpers.append(TopPLogitsWarper(top_p=top_p))
if typical_p is not None and typical_p < 1.0:
self.warpers.append(TypicalLogitsWarper(mass=typical_p))
self.cuda_graph = None
self.static_scores = None
self.static_warped_scores = None
self.static_next_logprob = None
def __call__(self, scores):
if self.cuda_graph is None:
self.static_scores = scores
self.cuda_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.cuda_graph):
for warper in self.warpers:
self.static_warped_scores = warper(None, self.static_scores)
# Compute logprobs
self.static_next_logprob = torch.log_softmax(
self.static_warped_scores, -1
)
self.static_scores.copy_(scores)
self.cuda_graph.replay()
return self.static_warped_scores, self.static_next_logprob
@lru_cache(10)
def static_warper(
temperature: Optional[float],
top_k: Optional[int],
top_p: Optional[float],
typical_p: Optional[float],
) -> StaticWarper:
return StaticWarper(
temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
)
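A usage sketch for the cached factory above (requires CUDA, since StaticWarper records a CUDA graph on its first call; the parameter values are illustrative). Because of `lru_cache`, requests sharing the same (temperature, top_k, top_p, typical_p) tuple reuse one warper and pay the graph-capture cost only once.

import torch

if torch.cuda.is_available():
    warper = static_warper(temperature=0.8, top_k=50, top_p=0.95, typical_p=None)
    scores = torch.randn(1, 32000, device="cuda")
    warped, logprobs = warper(scores)        # first call records the graph
    warped, logprobs = warper(scores * 2.0)  # later calls only replay it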
class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.
This version allows for a separate value for each sample and runs inplace when possible.
It doesn't validate inputs.
Args:
penalty (`List[float]`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
"""
def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
self.penalty = torch.tensor(penalty, dtype=dtype, device=device).unsqueeze(1)
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
score = torch.gather(scores, 1, input_ids)
# if score < 0, multiply by the penalty (rather than divide) so the previous token's probability is still reduced
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
scores.scatter_(1, input_ids, score)
return scores
def filter(self, indices):
self.penalty = self.penalty[indices]
return self
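A worked example of the rule above, using the processor just defined (numbers are made up): with penalty 2.0, a previously seen token with score 1.0 drops to 0.5 and one with score -1.0 drops to -2.0, so both become less likely.

import torch

proc = HeterogeneousRepetitionPenaltyLogitsProcessor(
    penalty=[2.0], dtype=torch.float32, device=torch.device("cpu")
)
input_ids = torch.tensor([[0, 1]])         # tokens already generated
scores = torch.tensor([[1.0, -1.0, 3.0]])
proc(input_ids, scores)
# scores -> [[0.5, -2.0, 3.0]]; token 2 was never generated, so it is untouched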
class HeterogeneousTemperatureLogitsWarper:
r"""
[`LogitsWarper`] for temperature (exponential scaling of the output probability distribution).
This version allows for a separate value for each sample and runs inplace when possible.
It doesn't validate inputs.
Args:
temperature (`List[float]`):
The values used to modulate the logits distribution, one per sample.
"""
def __init__(
self, temperature: List[float], dtype: torch.dtype, device: torch.device
):
self.temperature = torch.tensor(
temperature, dtype=dtype, device=device
).unsqueeze(1)
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
scores.div_(self.temperature)
return scores
def filter(self, indices):
self.temperature = self.temperature[indices]
return self
class HeterogeneousTopPLogitsWarper(LogitsWarper):
"""
[`LogitsWarper`] that performs top-p filtering, i.e. restricting to the smallest set of most probable tokens whose cumulative probability is at least top_p.
This version allows for a separate value for each sample and runs inplace when possible.
It doesn't validate inputs.
Args:
top_p (`List[float]`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation.
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.
"""
def __init__(
self,
top_p: List[float],
dtype: torch.dtype,
device: torch.device,
filter_value: float = -math.inf,
min_tokens_to_keep: int = 1,
):
self.top_p_opposite = 1 - torch.tensor(
top_p, dtype=dtype, device=device
).unsqueeze(1)
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
sorted_logits, sorted_indices = torch.sort(scores, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens whose ascending cumulative probability is <= 1 - top_p, keeping the top-p nucleus
sorted_indices_to_remove = cumulative_probs <= self.top_p_opposite
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
scores.masked_fill_(indices_to_remove, self.filter_value)
return scores
def filter(self, indices):
self.top_p_opposite = self.top_p_opposite[indices]
return self
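A sketch (made-up distribution) of the ascending-sort trick above: tokens whose ascending cumulative probability is <= 1 - top_p are exactly the ones outside the nucleus.

import torch

warper = HeterogeneousTopPLogitsWarper(
    top_p=[0.8], dtype=torch.float32, device=torch.device("cpu")
)
scores = torch.log(torch.tensor([[0.1, 0.2, 0.3, 0.4]]))  # softmax == probs
warper(None, scores)  # input_ids is unused here
# Ascending cumsum is [0.1, 0.3, 0.6, 1.0]; only 0.1 <= 1 - 0.8,
# so only the lowest-probability token is filtered to -inf.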
class HeterogeneousTopKLogitsWarper(LogitsWarper):
r"""
[`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
This version allows for a separate value for each sample and runs inplace when possible.
It doesn't validate inputs.
Args:
top_k (`List[int]`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.
"""
def __init__(
self,
top_k: List[int],
dtype: torch.dtype,
device: torch.device,
filter_value: float = -math.inf,
min_tokens_to_keep: int = 1,
):
self.top_k = top_k
self.max_top_k = max(top_k)
# value - 1 because top_k is used as an index and Python indexing is 0-based
self.top_k_tensor = torch.tensor(
[max(x - 1, min_tokens_to_keep - 1) for x in top_k],
dtype=torch.int64,
device=device,
).unsqueeze(1)
# 0 is a special value that disables top_k warping for this member of the batch
disabled = [x == 0 for x in top_k]
if any(disabled):
# shape [batch, 1] so it broadcasts against kth_scores in masked_fill_ below
self.top_k_disabled_mask = torch.tensor(
disabled, dtype=torch.bool, device=device
).view(-1, 1)
else:
self.top_k_disabled_mask = None
self.filter_value = filter_value
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
# If max_top_k is larger than the vocab size, clamp it or topk will fail
if scores.size(-1) < self.max_top_k:
max_top_k = scores.size(-1)
top_k = torch.clamp_max(self.top_k_tensor, max_top_k)
else:
max_top_k = self.max_top_k
top_k = self.top_k_tensor
# Get the kth score for each member of the batch
kth_scores = torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k)
# Mask the rows of kth_scores whose requests do not use top_k warping
if self.top_k_disabled_mask is not None:
kth_scores.masked_fill_(self.top_k_disabled_mask, self.filter_value)
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = scores < kth_scores
scores.masked_fill_(indices_to_remove, self.filter_value)
return scores
def filter(self, indices):
self.top_k_tensor = self.top_k_tensor[indices]
self.top_k = [self.top_k[i] for i in indices]
self.max_top_k = max(self.top_k)
# Keep the disabled mask aligned with the filtered batch
if self.top_k_disabled_mask is not None:
self.top_k_disabled_mask = self.top_k_disabled_mask[indices]
return self
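A sketch (made-up values) of the heterogeneous top-k: each row keeps its own k, and k == 0 disables filtering for that row via the mask.

import torch

warper = HeterogeneousTopKLogitsWarper(
    top_k=[2, 0], dtype=torch.float32, device=torch.device("cpu")
)
scores = torch.tensor([[1.0, 3.0, 2.0, 0.0],
                       [1.0, 3.0, 2.0, 0.0]])
warper(None, scores)
# Row 0 (k=2): scores below the 2nd-best (2.0) are filtered to -inf.
# Row 1 (k=0): the disabled mask leaves the row unchanged.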
class HeterogeneousTypicalLogitsWarper(LogitsWarper):
r"""
[`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
Generation](https://arxiv.org/abs/2202.00666) for more information.
This version allows for a separate value for each sample and runs inplace when possible.
It doesn't validate inputs.
Args:
mass (`List[float]`):
Values of typical_p between 0 and 1 inclusive, one per sample; defaults to 0.9.
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.
"""
def __init__(
self,
mass: List[float],
dtype: torch.dtype,
device: torch.device,
filter_value: float = -math.inf,
min_tokens_to_keep: int = 1,
):
self.filter_value = filter_value
self.mass = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1)
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
# calculate entropy
normalized = torch.nn.functional.log_softmax(scores, dim=-1)
p = torch.exp(normalized)
ent = -(normalized * p).nansum(-1, keepdim=True)
# shift and sort
shifted_scores = torch.abs((-normalized) - ent)
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
sorted_logits = scores.gather(-1, sorted_indices)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens with cumulative mass above the threshold
last_ind = (cumulative_probs < self.mass).sum(dim=1)
last_ind[last_ind < 0] = 0
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
1, last_ind.view(-1, 1)
)
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep tokens
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
scores = scores.masked_fill_(indices_to_remove, self.filter_value)
return scores
def filter(self, indices):
self.mass = self.mass[indices]
return self
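Usage mirrors the other heterogeneous warpers; the per-row masses are kept as a [batch, 1] tensor so the entropy-based cutoff is computed independently per request (values below are illustrative).

import torch

warper = HeterogeneousTypicalLogitsWarper(
    mass=[0.9, 0.5], dtype=torch.float32, device=torch.device("cpu")
)
scores = torch.randn(2, 100)
warped = warper(None, scores)  # each row filtered with its own mass
warper = warper.filter([1])    # keep only the second request's state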
class HeterogeneousSampling:
r"""
Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
"""
def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device):
self.seeds = seeds
self.greedy_indices = []
self.sampling_mapping = {}
for i, (sample, seed) in enumerate(zip(do_sample, seeds)):
if sample:
self.sampling_mapping[i] = Sampling(seed, device)
else:
self.greedy_indices.append(i)
self.greedy = Greedy()
def __call__(self, logits):
out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
if self.greedy_indices:
out[self.greedy_indices] = torch.argmax(logits[self.greedy_indices], -1)
for i, sampling in self.sampling_mapping.items():
out[i] = sampling(logits[i])
return out
def filter(self, indices):
new_greedy_indices = []
new_sampling_mapping = {}
for i, idx in enumerate(indices):
if idx in self.sampling_mapping:
new_sampling_mapping[i] = self.sampling_mapping[idx]
else:
new_greedy_indices.append(i)
self.greedy_indices = new_greedy_indices
self.sampling_mapping = new_sampling_mapping
return self
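A sketch of the mixed chooser (hypothetical batch): greedy rows are resolved with a single argmax, and only sampled rows pay the per-row generator cost.

import torch

chooser = HeterogeneousSampling(
    do_sample=[False, True, False],
    seeds=[0, 42, 0],
    device=torch.device("cpu"),
)
logits = torch.randn(3, 32000)
next_ids = chooser(logits)        # rows 0 and 2 via argmax, row 1 via Sampling(42)
chooser = chooser.filter([1, 2])  # keep requests 1 and 2 after request 0 finishes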

server/text_generation_server/utils/tokens.py View File

@@ -1,110 +1,33 @@
import re
import torch
from functools import lru_cache
from transformers import (
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
RepetitionPenaltyLogitsProcessor,
PreTrainedTokenizerBase,
PreTrainedTokenizerBase, LogitsProcessorList,
)
from typing import List, Tuple, Optional
from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
class Sampling:
def __init__(self, seed: int, device: str = "cpu"):
self.generator = torch.Generator(device)
self.generator.manual_seed(seed)
self.seed = seed
def __call__(self, logits):
probs = torch.nn.functional.softmax(logits, -1)
# Avoid GPU<->CPU sync done by torch multinomial
# See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
q = torch.empty_like(probs).exponential_(1, generator=self.generator)
return probs.div_(q).argmax()
class Greedy:
def __call__(self, logits):
return logits.argmax()
class StaticWarper:
def __init__(
self,
temperature=1.0,
top_k=None,
top_p=None,
typical_p=None,
):
self.warpers = []
if temperature is not None and temperature != 1.0:
temperature = float(temperature)
self.warpers.append(TemperatureLogitsWarper(temperature))
if top_k is not None and top_k != 0:
self.warpers.append(TopKLogitsWarper(top_k=top_k))
if top_p is not None and top_p < 1.0:
self.warpers.append(TopPLogitsWarper(top_p=top_p))
if typical_p is not None and typical_p < 1.0:
self.warpers.append(TypicalLogitsWarper(mass=typical_p))
self.cuda_graph = None
self.static_scores = None
self.static_warped_scores = None
self.static_next_logprob = None
def __call__(self, scores):
if self.cuda_graph is None:
self.static_scores = scores
self.cuda_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.cuda_graph):
for warper in self.warpers:
self.static_warped_scores = warper(None, self.static_scores)
# Compute logprobs
self.static_next_logprob = torch.log_softmax(
self.static_warped_scores, -1
)
self.static_scores.copy_(scores)
self.cuda_graph.replay()
return self.static_warped_scores, self.static_next_logprob
@lru_cache(10)
def static_warper(
temperature: Optional[float],
top_k: Optional[int],
top_p: Optional[float],
typical_p: Optional[float],
) -> StaticWarper:
return StaticWarper(
temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
)
from text_generation_server.utils import Sampling, Greedy
from text_generation_server.utils.logits_process import (
    static_warper,
    HeterogeneousRepetitionPenaltyLogitsProcessor,
    HeterogeneousTemperatureLogitsWarper,
    HeterogeneousTopKLogitsWarper,
    HeterogeneousTopPLogitsWarper,
    HeterogeneousTypicalLogitsWarper,
    HeterogeneousSampling,
)
class NextTokenChooser:
def __init__(
self,
watermark=False,
temperature=1.0,
repetition_penalty=1.0,
top_k=None,
top_p=None,
typical_p=None,
do_sample=False,
seed=0,
device="cpu",
self,
watermark=False,
temperature=1.0,
repetition_penalty=1.0,
top_k=None,
top_p=None,
typical_p=None,
do_sample=False,
seed=0,
device="cpu",
):
self.watermark_processor = (
WatermarkLogitsProcessor(device=device) if watermark else None
@@ -116,10 +39,10 @@ class NextTokenChooser:
)
has_warpers = (
(temperature is not None and temperature != 1.0)
or (top_k is not None and top_k != 0)
or (top_p is not None and top_p < 1.0)
or (typical_p is not None and typical_p < 1.0)
(temperature is not None and temperature != 1.0)
or (top_k is not None and top_k != 0)
or (top_p is not None and top_p < 1.0)
or (typical_p is not None and typical_p < 1.0)
)
if has_warpers:
self.static_warper = static_warper(
@@ -148,9 +71,9 @@ class NextTokenChooser:
@classmethod
def from_pb(
cls,
pb: generate_pb2.NextTokenChooserParameters,
device: torch.device,
cls,
pb: generate_pb2.NextTokenChooserParameters,
device: torch.device,
) -> "NextTokenChooser":
return NextTokenChooser(
watermark=pb.watermark,
@@ -178,11 +101,11 @@ class StopSequenceCriteria:
class StoppingCriteria:
def __init__(
self,
eos_token_id: int,
stop_sequence_criterias: List[StopSequenceCriteria],
max_new_tokens: int = 20,
ignore_eos_token: bool = False,
self,
eos_token_id: int,
stop_sequence_criterias: List[StopSequenceCriteria],
max_new_tokens: int = 20,
ignore_eos_token: bool = False,
):
self.eos_token_id = eos_token_id
self.stop_sequence_criterias = stop_sequence_criterias
@@ -208,9 +131,9 @@ class StoppingCriteria:
@classmethod
def from_pb(
cls,
pb: generate_pb2.StoppingCriteriaParameters,
tokenizer: PreTrainedTokenizerBase,
cls,
pb: generate_pb2.StoppingCriteriaParameters,
tokenizer: PreTrainedTokenizerBase,
) -> "StoppingCriteria":
stop_sequence_criterias = [
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
@@ -221,3 +144,99 @@ class StoppingCriteria:
pb.max_new_tokens,
pb.ignore_eos_token,
)
class HeterogeneousNextTokenChooser:
def __init__(
self,
dtype: torch.dtype,
device: torch.device,
watermark: List[bool],
temperature: List[float],
repetition_penalty: List[float],
top_k: List[int],
top_p: List[float],
typical_p: List[float],
do_sample: List[bool],
seeds: List[int],
):
warpers = LogitsProcessorList()
if any(watermark):
raise NotImplementedError("Watermarking not implemented")
if any([x != 1.0 for x in repetition_penalty]):
warpers.append(
HeterogeneousRepetitionPenaltyLogitsProcessor(
repetition_penalty, dtype, device
)
)
if any([x != 1.0 for x in temperature]):
do_sample = [
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
]
warpers.append(
HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
)
if any([x != 0 for x in top_k]):
do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
warpers.append(HeterogeneousTopKLogitsWarper(top_k, dtype, device))
if any([x < 1.0 for x in top_p]):
do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))
if any([x < 1.0 for x in typical_p]):
do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))
self.warpers = warpers
num_do_sample = sum(do_sample)
if num_do_sample == 0:
self.choice = Greedy()
else:
self.choice = HeterogeneousSampling(do_sample, seeds, device)
self.seeds = seeds
self.do_sample = do_sample
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
last_token_scores = self.warpers(input_ids, scores)
next_ids = self.choice(last_token_scores)
next_logprobs = torch.gather(
torch.log_softmax(last_token_scores, -1), 1, next_ids.view(-1, 1)
).view(-1)
return next_ids, next_logprobs
def filter(self, indices):
for warper in self.warpers:
warper.filter(indices)
if isinstance(self.choice, HeterogeneousSampling):
self.choice.filter(indices)
self.seeds = [self.seeds[i] for i in indices]
self.do_sample = [self.do_sample[i] for i in indices]
return self
@classmethod
def from_pb(
cls,
pb: List[generate_pb2.NextTokenChooserParameters],
dtype: torch.dtype,
device: torch.device,
) -> "HeterogeneousNextTokenChooser":
return HeterogeneousNextTokenChooser(
watermark=[pb_.watermark for pb_ in pb],
temperature=[pb_.temperature for pb_ in pb],
repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
top_k=[pb_.top_k for pb_ in pb],
top_p=[pb_.top_p for pb_ in pb],
typical_p=[pb_.typical_p for pb_ in pb],
do_sample=[pb_.do_sample for pb_ in pb],
seeds=[pb_.seed for pb_ in pb],
device=device,
dtype=dtype,
)
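An end-to-end sketch of the new chooser (constructor arguments are illustrative; in the server it is built through `from_pb` and called with the padded `all_input_ids_tensor` and the last-token logits):

import torch

chooser = HeterogeneousNextTokenChooser(
    dtype=torch.float32,
    device=torch.device("cpu"),
    watermark=[False, False],
    temperature=[0.7, 1.0],
    repetition_penalty=[1.2, 1.0],
    top_k=[50, 0],
    top_p=[0.9, 1.0],
    typical_p=[1.0, 1.0],
    do_sample=[True, False],
    seeds=[42, 0],
)
input_ids = torch.tensor([[5, 6, 7], [8, 9, 0]])  # padded [batch, seq]
logits = torch.randn(2, 32000)
next_ids, next_logprobs = chooser(input_ids, logits)
chooser = chooser.filter([0])                     # drop the finished request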