diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py
index 65f9b4dd..590ba557 100644
--- a/server/tests/models/test_bloom.py
+++ b/server/tests/models/test_bloom.py
@@ -286,7 +286,9 @@ def test_batch_concatenate(
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
+    next_batch = next_batch.filter(
+        [next_batch.requests[0].id, next_batch.requests[1].id]
+    )
 
     for _ in range(
         default_bloom_batch.stopping_criterias[0].max_new_tokens
diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py
index 43676ea2..3f28f5b3 100644
--- a/server/tests/models/test_causal_lm.py
+++ b/server/tests/models/test_causal_lm.py
@@ -289,7 +289,9 @@ def test_batch_concatenate(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
+    next_batch = next_batch.filter(
+        [next_batch.requests[0].id, next_batch.requests[1].id]
+    )
 
     for _ in range(
         default_causal_lm_batch.stopping_criterias[0].max_new_tokens
diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py
index e043a5e4..a3199d02 100644
--- a/server/tests/models/test_seq2seq_lm.py
+++ b/server/tests/models/test_seq2seq_lm.py
@@ -325,7 +325,9 @@ def test_batch_concatenate(
     )
     assert generations[2].generated_text.generated_tokens == 5
 
-    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
+    next_batch = next_batch.filter(
+        [next_batch.requests[0].id, next_batch.requests[1].id]
+    )
 
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None
diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index b04e77b2..bf61322b 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -129,6 +129,18 @@ class HeterogeneousTemperatureLogitsWarper:
         return self
 
 
+class CUDAGraphWrapper:
+    def __init__(self):
+        self.cuda_graph = None
+        self.static_tensors = {}
+
+
+@lru_cache(512)
+def get_cuda_graph_wrapper(warper_name, batch_size):
+    """warper_name and batch_size are only used as keys"""
+    return CUDAGraphWrapper()
+
+
 class HeterogeneousTopPLogitsWarper(LogitsWarper):
     """
     [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
@@ -158,26 +170,56 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
         ).unsqueeze(1)
         self.filter_value = filter_value
         self.min_tokens_to_keep = min_tokens_to_keep
+        self.cuda_graph_wrapper = get_cuda_graph_wrapper("top_p_warper", len(top_p))
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
-        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+        if self.cuda_graph_wrapper.cuda_graph is None:
+            self.cuda_graph_wrapper.static_tensors["scores"] = scores
+            self.cuda_graph_wrapper.static_tensors[
+                "top_p_opposite"
+            ] = self.top_p_opposite
 
-        # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
-        sorted_indices_to_remove = cumulative_probs <= self.top_p_opposite
-        if self.min_tokens_to_keep > 1:
-            # Keep at least min_tokens_to_keep
-            sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
+            self.cuda_graph_wrapper.cuda_graph = torch.cuda.CUDAGraph()
 
-        # scatter sorted tensors to original indexing
-        indices_to_remove = sorted_indices_to_remove.scatter(
-            1, sorted_indices, sorted_indices_to_remove
+            with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph):
+                local_scores = self.cuda_graph_wrapper.static_tensors["scores"]
+
+                sorted_logits, sorted_indices = torch.sort(
+                    local_scores, descending=False
+                )
+                probs = sorted_logits.softmax(dim=-1)
+                # This is way faster for some reason
+                for i in range(probs.shape[0]):
+                    probs[i] = probs[i].cumsum(dim=-1)
+
+                # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
+                sorted_indices_to_remove = probs <= self.top_p_opposite
+                if self.min_tokens_to_keep > 1:
+                    # Keep at least min_tokens_to_keep
+                    sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
+
+                # scatter sorted tensors to original indexing
+                indices_to_remove = sorted_indices_to_remove.scatter(
+                    1, sorted_indices, sorted_indices_to_remove
+                )
+                local_scores = local_scores.masked_fill_(
+                    indices_to_remove, self.filter_value
+                )
+                self.cuda_graph_wrapper.static_tensors["warped_scores"] = local_scores
+
+        self.cuda_graph_wrapper.static_tensors["scores"].copy_(scores)
+        self.cuda_graph_wrapper.static_tensors["top_p_opposite"].copy_(
+            self.top_p_opposite
         )
-        scores.masked_fill_(indices_to_remove, self.filter_value)
-        return scores
+        self.cuda_graph_wrapper.cuda_graph.replay()
+
+        return self.cuda_graph_wrapper.static_tensors["warped_scores"]
 
     def filter(self, indices):
         self.top_p_opposite = self.top_p_opposite[indices]
+        self.cuda_graph_wrapper = get_cuda_graph_wrapper(
+            "top_p_warper", len(self.top_p_opposite)
+        )
         return self
 
 
@@ -279,37 +321,63 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
         self.filter_value = filter_value
         self.mass = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1)
         self.min_tokens_to_keep = min_tokens_to_keep
+        self.cuda_graph_wrapper = get_cuda_graph_wrapper("typical_p_warper", len(mass))
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        # calculate entropy
-        normalized = torch.nn.functional.log_softmax(scores, dim=-1)
-        p = torch.exp(normalized)
-        ent = -(normalized * p).nansum(-1, keepdim=True)
+        if self.cuda_graph_wrapper.cuda_graph is None:
+            self.cuda_graph_wrapper.static_tensors["scores"] = scores
+            self.cuda_graph_wrapper.static_tensors["mass"] = self.mass
 
-        # shift and sort
-        shifted_scores = torch.abs((-normalized) - ent)
-        sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
-        sorted_logits = scores.gather(-1, sorted_indices)
-        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+            self.cuda_graph_wrapper.cuda_graph = torch.cuda.CUDAGraph()
 
-        # Remove tokens with cumulative mass above the threshold
-        last_ind = (cumulative_probs < self.mass).sum(dim=1)
-        last_ind[last_ind < 0] = 0
-        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
-            1, last_ind.view(-1, 1)
-        )
-        if self.min_tokens_to_keep > 1:
-            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
-            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
-        indices_to_remove = sorted_indices_to_remove.scatter(
-            1, sorted_indices, sorted_indices_to_remove
-        )
+            with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph):
+                local_scores = self.cuda_graph_wrapper.static_tensors["scores"]
 
-        scores = scores.masked_fill_(indices_to_remove, self.filter_value)
-        return scores
+                # calculate entropy
+                normalized = torch.nn.functional.log_softmax(local_scores, dim=-1)
+                p = torch.exp(normalized)
+                ent = -(normalized * p).nansum(-1, keepdim=True)
+
+                # shift and sort
+                shifted_scores = torch.abs((-normalized) - ent)
+                sorted_scores, sorted_indices = torch.sort(
+                    shifted_scores, descending=False
+                )
+                sorted_logits = local_scores.gather(-1, sorted_indices)
+                probs = sorted_logits.softmax(dim=-1)
+                # This is way faster for some reason
+                for i in range(probs.shape[0]):
+                    probs[i] = probs[i].cumsum(dim=-1)
+
+                # Remove tokens with cumulative mass above the threshold
+                last_ind = (probs < self.mass).sum(dim=1)
+                last_ind[last_ind < 0] = 0
+                sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
+                    1, last_ind.view(-1, 1)
+                )
+                if self.min_tokens_to_keep > 1:
+                    # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+                    sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
+                indices_to_remove = sorted_indices_to_remove.scatter(
+                    1, sorted_indices, sorted_indices_to_remove
+                )
+
+                local_scores = local_scores.masked_fill_(
+                    indices_to_remove, self.filter_value
+                )
+                self.cuda_graph_wrapper.static_tensors["warped_scores"] = local_scores
+
+        self.cuda_graph_wrapper.static_tensors["scores"].copy_(scores)
+        self.cuda_graph_wrapper.static_tensors["mass"].copy_(self.mass)
+        self.cuda_graph_wrapper.cuda_graph.replay()
+
+        return self.cuda_graph_wrapper.static_tensors["warped_scores"]
 
     def filter(self, indices):
         self.mass = self.mass[indices]
+        self.cuda_graph_wrapper = get_cuda_graph_wrapper(
+            "typical_p_warper", len(self.mass)
+        )
         return self
 
 
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 371a9ee4..7f2be8c6 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -4,7 +4,6 @@ import torch
 from transformers import (
     RepetitionPenaltyLogitsProcessor,
     PreTrainedTokenizerBase,
-    LogitsProcessorList,
 )
 from typing import List, Tuple, Optional
 
@@ -166,25 +165,27 @@ class HeterogeneousNextTokenChooser:
         do_sample: List[bool],
         seeds: List[int],
     ):
-        warpers = LogitsProcessorList()
+        warpers = []
 
-        if any(watermark):
-            warpers.append(
-                HeterogeneousProcessorWrapper(
-                    {
-                        i: WatermarkLogitsProcessor(device=device)
-                        for i, do_watermark in enumerate(watermark)
-                        if do_watermark
-                    }
-                )
+        self.watermark_processor = (
+            HeterogeneousProcessorWrapper(
+                {
+                    i: WatermarkLogitsProcessor(device=device)
+                    for i, do_watermark in enumerate(watermark)
+                    if do_watermark
+                }
             )
+            if any(watermark)
+            else None
+        )
 
-        if any([x != 1.0 for x in repetition_penalty]):
-            warpers.append(
-                HeterogeneousRepetitionPenaltyLogitsProcessor(
-                    repetition_penalty, dtype, device
-                )
+        self.repetition_processor = (
+            HeterogeneousRepetitionPenaltyLogitsProcessor(
+                repetition_penalty, dtype, device
             )
+            if any([x != 1.0 for x in repetition_penalty])
+            else None
+        )
 
         if any([x != 1.0 for x in temperature]):
             do_sample = [
@@ -217,11 +218,22 @@ class HeterogeneousNextTokenChooser:
         self.seeds = seeds
         self.do_sample = do_sample
 
+        self.cuda_graph = None
+        self.static_scores = None
+        self.static_warped_scores = None
+
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
-        last_token_scores = self.warpers(input_ids, scores)
-        next_ids = self.choice(last_token_scores)
+        if self.watermark_processor:
+            scores = self.watermark_processor(input_ids, scores)
+        if self.repetition_processor:
+            scores = self.repetition_processor(input_ids, scores)
+
+        for warper in self.warpers:
+            scores = warper(input_ids, scores)
+
+        next_ids = self.choice(scores)
         next_logprobs = torch.gather(
-            torch.log_softmax(last_token_scores, -1), 1, next_ids.view(-1, 1)
+            torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1)
         ).view(-1)
 
         return next_ids, next_logprobs
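
Note on the pattern above (illustrative, not part of the patch): each heterogeneous warper now records its filtering kernels into a torch.cuda.CUDAGraph on the first call, then on every later call copies the fresh logits into the captured static buffers and replays the graph. The sketch below shows that capture-once / copy-and-replay flow in isolation; the names GraphedWarper and warp are hypothetical, and it assumes a CUDA-capable PyTorch build.

# Illustrative sketch only (not part of the patch): the capture-once / replay
# pattern the diff applies inside the heterogeneous warpers. `GraphedWarper`
# and `warp` are hypothetical names; assumes a CUDA-capable PyTorch build.
import torch


def warp(scores: torch.Tensor) -> torch.Tensor:
    # Stand-in for the warping math (top-p / typical filtering in the patch).
    return scores.softmax(dim=-1)


class GraphedWarper:
    def __init__(self):
        self.graph = None
        self.static = {}

    def __call__(self, scores: torch.Tensor) -> torch.Tensor:
        if self.graph is None:
            # First call: keep the input tensor as the static buffer and record
            # the kernels once. (PyTorch's CUDA graph docs also recommend a few
            # warm-up iterations on a side stream before capture.)
            self.static["scores"] = scores
            self.graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self.graph):
                self.static["warped"] = warp(self.static["scores"])
        # Every call: refresh the captured input buffer, replay the recorded
        # kernels, and read the result out of the static output tensor.
        self.static["scores"].copy_(scores)
        self.graph.replay()
        return self.static["warped"]


if __name__ == "__main__":
    warper = GraphedWarper()
    logits = torch.randn(4, 32000, device="cuda")
    out = warper(logits)          # captures the graph on the first call
    out = warper(logits * 2.0)    # later calls only copy and replay

A replayed graph only touches the tensor addresses recorded at capture time, which is why the patch keys get_cuda_graph_wrapper on (warper_name, batch_size) via lru_cache and why filter() re-fetches a wrapper for the new length: a warper whose batch shrinks must use a graph captured for that shape rather than replay one recorded for the old one.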