diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index 08e1c443..faa94516 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -16,20 +16,6 @@ from transformers import (
 mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 
 
-class CUDAGraphWrapper:
-    def __init__(self):
-        self.cuda_graph = None
-        self.static_tensors = {}
-
-
-# We want the LRU to be as big as possible as creating cuda graphs is expensive. However, each graph holds a tiny
-# bit of GPU memory, so we still need to be careful
-@lru_cache(512)
-def get_cuda_graph_wrapper(warper_name, batch_size):
-    """warper_name and batch_size are only used as keys"""
-    return CUDAGraphWrapper()
-
-
 class StaticWarper:
     def __init__(
         self,
@@ -102,20 +88,28 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
     """
 
     def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
-        self.penalty = torch.tensor(penalty, dtype=dtype, device=device).unsqueeze(1)
+        self.penalty = penalty
+        self.penalty_tensor = torch.tensor(
+            penalty, dtype=dtype, device=device
+        ).unsqueeze(1)
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
         score = torch.gather(scores, 1, input_ids)
         # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
-        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
+        score = torch.where(
+            score < 0, score * self.penalty_tensor, score / self.penalty_tensor
+        )
         scores.scatter_(1, input_ids, score)
 
         return scores
 
     def filter(self, indices):
-        self.penalty = self.penalty[indices]
-        return self
+        self.penalty = [self.penalty[i] for i in indices]
+        if any([x != 1.0 for x in self.penalty]):
+            self.penalty_tensor = self.penalty_tensor[indices]
+            return self
+        return None
 
 
 class HeterogeneousTemperatureLogitsWarper:
@@ -132,17 +126,21 @@ class HeterogeneousTemperatureLogitsWarper:
     def __init__(
         self, temperature: List[float], dtype: torch.dtype, device: torch.device
     ):
-        self.temperature = torch.tensor(
+        self.temperature = temperature
+        self.temperature_tensor = torch.tensor(
             temperature, dtype=dtype, device=device
         ).unsqueeze(1)
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        scores.div_(self.temperature)
+        scores.div_(self.temperature_tensor)
         return scores
 
     def filter(self, indices):
-        self.temperature = self.temperature[indices]
-        return self
+        self.temperature = [self.temperature[i] for i in indices]
+        if any([x != 1.0 for x in self.temperature]):
+            self.temperature_tensor = self.temperature_tensor[indices]
+            return self
+        return None
 
 
 class HeterogeneousTopPLogitsWarper(LogitsWarper):
@@ -169,62 +167,40 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
         filter_value: float = -math.inf,
         min_tokens_to_keep: int = 1,
     ):
+        self.top_p = top_p
         self.top_p_opposite = 1 - torch.tensor(
             top_p, dtype=dtype, device=device
         ).unsqueeze(1)
         self.filter_value = filter_value
         self.min_tokens_to_keep = min_tokens_to_keep
-        self.cuda_graph_wrapper = get_cuda_graph_wrapper("top_p_warper", len(top_p))
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        if self.cuda_graph_wrapper.cuda_graph is None:
-            self.cuda_graph_wrapper.static_tensors["scores"] = scores
-            self.cuda_graph_wrapper.static_tensors[
-                "top_p_opposite"
-            ] = self.top_p_opposite
+        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
+        probs = sorted_logits.softmax(dim=-1)
+        # This is way faster for some reason
+        for i in range(probs.shape[0]):
+            probs[i] = probs[i].cumsum(dim=-1)
 
-            self.cuda_graph_wrapper.cuda_graph = torch.cuda.CUDAGraph()
+        # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
+        sorted_indices_to_remove = probs <= self.top_p_opposite
+        if self.min_tokens_to_keep > 1:
+            # Keep at least min_tokens_to_keep
+            sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
 
-            with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph, pool=mempool):
-                local_scores = self.cuda_graph_wrapper.static_tensors["scores"]
-
-                sorted_logits, sorted_indices = torch.sort(
-                    local_scores, descending=False
-                )
-                probs = sorted_logits.softmax(dim=-1)
-                # This is way faster for some reason
-                for i in range(probs.shape[0]):
-                    probs[i] = probs[i].cumsum(dim=-1)
-
-                # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
-                sorted_indices_to_remove = probs <= self.top_p_opposite
-                if self.min_tokens_to_keep > 1:
-                    # Keep at least min_tokens_to_keep
-                    sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
-
-                # scatter sorted tensors to original indexing
-                indices_to_remove = sorted_indices_to_remove.scatter(
-                    1, sorted_indices, sorted_indices_to_remove
-                )
-                local_scores = local_scores.masked_fill_(
-                    indices_to_remove, self.filter_value
-                )
-                self.cuda_graph_wrapper.static_tensors["warped_scores"] = local_scores
-
-        self.cuda_graph_wrapper.static_tensors["scores"].copy_(scores)
-        self.cuda_graph_wrapper.static_tensors["top_p_opposite"].copy_(
-            self.top_p_opposite
+        # scatter sorted tensors to original indexing
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            1, sorted_indices, sorted_indices_to_remove
         )
-        self.cuda_graph_wrapper.cuda_graph.replay()
+        warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)
 
-        return self.cuda_graph_wrapper.static_tensors["warped_scores"]
+        return warped_scores
 
     def filter(self, indices):
-        self.top_p_opposite = self.top_p_opposite[indices]
-        self.cuda_graph_wrapper = get_cuda_graph_wrapper(
-            "top_p_warper", len(self.top_p_opposite)
-        )
-        return self
+        self.top_p = [self.top_p[i] for i in indices]
+        if any([x < 1.0 for x in self.top_p]):
+            self.top_p_opposite = self.top_p_opposite[indices]
+            return self
+        return None
 
 
 class HeterogeneousTopKLogitsWarper(LogitsWarper):
@@ -264,7 +240,7 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
         if any(disabled):
             self.top_k_disabled_mask = torch.tensor(
                 disabled, dtype=torch.bool, device=device
-            )
+            ).view(-1, 1)
         else:
             self.top_k_disabled_mask = None
 
@@ -292,10 +268,20 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
         return scores
 
     def filter(self, indices):
-        self.top_k_tensor = self.top_k_tensor[indices]
         self.top_k = [self.top_k[i] for i in indices]
-        self.max_top_k = max(self.top_k)
-        return self
+        disabled = [x == 0 for x in self.top_k]
+
+        if not all(disabled):
+            self.top_k_tensor = self.top_k_tensor[indices]
+            self.max_top_k = max(self.top_k)
+
+            if self.top_k_disabled_mask is not None:
+                self.top_k_disabled_mask = (
+                    self.top_k_disabled_mask[indices] if any(disabled) else None
+                )
+
+            return self
+        return None
 
 
 class HeterogeneousTypicalLogitsWarper(LogitsWarper):
@@ -322,67 +308,70 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
         filter_value: float = -math.inf,
         min_tokens_to_keep: int = 1,
    ):
+        self.mass = mass
+        self.mass_tensor = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1)
+
+        # 1 is a special value that disables typical_p warping for this member of the batch
+        disabled = [x == 1.0 for x in mass]
+
+        if any(disabled):
+            self.disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device)
+        else:
+            self.disabled_mask = None
+
         self.filter_value = filter_value
-        self.mass = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1)
         self.min_tokens_to_keep = min_tokens_to_keep
-        self.cuda_graph_wrapper = get_cuda_graph_wrapper("typical_p_warper", len(mass))
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        if self.cuda_graph_wrapper.cuda_graph is None:
-            self.cuda_graph_wrapper.static_tensors["scores"] = scores
-            self.cuda_graph_wrapper.static_tensors["mass"] = self.mass
+        # calculate entropy
+        normalized = torch.nn.functional.log_softmax(scores, dim=-1)
+        p = torch.exp(normalized)
+        ent = -(normalized * p).nansum(-1, keepdim=True)
 
-            self.cuda_graph_wrapper.cuda_graph = torch.cuda.CUDAGraph()
+        # shift and sort
+        shifted_scores = torch.abs((-normalized) - ent)
+        sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
+        sorted_logits = scores.gather(-1, sorted_indices)
+        probs = sorted_logits.softmax(dim=-1)
+        # This is way faster for some reason
+        for i in range(probs.shape[0]):
+            probs[i] = probs[i].cumsum(dim=-1)
 
-            with torch.cuda.graph(self.cuda_graph_wrapper.cuda_graph, pool=mempool):
-                local_scores = self.cuda_graph_wrapper.static_tensors["scores"]
+        # Remove tokens with cumulative mass above the threshold
+        last_ind = (probs < self.mass_tensor).sum(dim=1)
+        last_ind[last_ind < 0] = 0
 
-                # calculate entropy
-                normalized = torch.nn.functional.log_softmax(local_scores, dim=-1)
-                p = torch.exp(normalized)
-                ent = -(normalized * p).nansum(-1, keepdim=True)
+        if self.disabled_mask is not None:
+            last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1)
 
-                # shift and sort
-                shifted_scores = torch.abs((-normalized) - ent)
-                sorted_scores, sorted_indices = torch.sort(
-                    shifted_scores, descending=False
-                )
-                sorted_logits = local_scores.gather(-1, sorted_indices)
-                probs = sorted_logits.softmax(dim=-1)
-                # This is way faster for some reason
-                for i in range(probs.shape[0]):
-                    probs[i] = probs[i].cumsum(dim=-1)
+        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
+            1, last_ind.view(-1, 1)
+        )
+        if self.min_tokens_to_keep > 1:
+            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            1, sorted_indices, sorted_indices_to_remove
+        )
 
-                # Remove tokens with cumulative mass above the threshold
-                last_ind = (probs < self.mass).sum(dim=1)
-                last_ind[last_ind < 0] = 0
-                sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
-                    1, last_ind.view(-1, 1)
-                )
-                if self.min_tokens_to_keep > 1:
-                    # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
-                    sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
-                indices_to_remove = sorted_indices_to_remove.scatter(
-                    1, sorted_indices, sorted_indices_to_remove
-                )
+        warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)
 
-                local_scores = local_scores.masked_fill_(
-                    indices_to_remove, self.filter_value
-                )
-                self.cuda_graph_wrapper.static_tensors["warped_scores"] = local_scores
-
-        self.cuda_graph_wrapper.static_tensors["scores"].copy_(scores)
-        self.cuda_graph_wrapper.static_tensors["mass"].copy_(self.mass)
-        self.cuda_graph_wrapper.cuda_graph.replay()
-
-        return self.cuda_graph_wrapper.static_tensors["warped_scores"]
+        return warped_scores
 
     def filter(self, indices):
-        self.mass = self.mass[indices]
-        self.cuda_graph_wrapper = get_cuda_graph_wrapper(
-            "typical_p_warper", len(self.mass)
-        )
-        return self
+        self.mass = [self.mass[i] for i in indices]
+        disabled = [x == 1.0 for x in self.mass]
+
+        if not all(disabled):
+            self.mass_tensor = self.mass_tensor[indices]
+
+            if self.disabled_mask is not None:
+                self.disabled_mask = (
+                    self.disabled_mask[indices] if any(disabled) else None
+                )
+
+            return self
+        return None
 
 
 class HeterogeneousProcessorWrapper(LogitsProcessor):
@@ -410,5 +399,7 @@ class HeterogeneousProcessorWrapper(LogitsProcessor):
             if idx in self.processors:
                 new_processors[i] = self.processors[idx]
 
-        self.processors = new_processors
-        return self
+        if new_processors:
+            self.processors = new_processors
+            return self
+        return None
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 99eba93b..e6e512bc 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -60,9 +60,9 @@ class NextTokenChooser:
         self.choice = Sampling(seed, device) if sampling else Greedy()
 
     def __call__(self, input_ids, scores):
-        if self.watermark_processor:
+        if self.watermark_processor is not None:
             scores = self.watermark_processor(input_ids, scores)
-        if self.repetition_processor:
+        if self.repetition_processor is not None:
             scores = self.repetition_processor(input_ids, scores)
 
         if self.static_warper is None:
@@ -209,19 +209,18 @@ class HeterogeneousNextTokenChooser:
 
         self.warpers = warpers
 
-        num_do_sample = sum(do_sample)
-        if num_do_sample == 0:
-            self.choice = Greedy()
-        else:
+        if any(do_sample):
             self.choice = HeterogeneousSampling(do_sample, seeds, device)
+        else:
+            self.choice = Greedy()
 
         self.seeds = seeds
         self.do_sample = do_sample
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
-        if self.watermark_processor:
+        if self.watermark_processor is not None:
             scores = self.watermark_processor(input_ids, scores)
-        if self.repetition_processor:
+        if self.repetition_processor is not None:
             scores = self.repetition_processor(input_ids, scores)
 
         for warper in self.warpers:
@@ -235,12 +234,27 @@ class HeterogeneousNextTokenChooser:
         return next_ids, next_logprobs
 
     def filter(self, indices):
+        if self.watermark_processor is not None:
+            self.watermark_processor = self.watermark_processor.filter(indices)
+
+        if self.repetition_processor is not None:
+            self.repetition_processor = self.repetition_processor.filter(indices)
+
+        filtered_warpers = []
        for warper in self.warpers:
-            warper.filter(indices)
-        if isinstance(self.choice, HeterogeneousSampling):
-            self.choice.filter(indices)
+            filtered_warper = warper.filter(indices)
+            if filtered_warper is not None:
+                filtered_warpers.append(filtered_warper)
+        self.warpers = filtered_warpers
+
         self.seeds = [self.seeds[i] for i in indices]
         self.do_sample = [self.do_sample[i] for i in indices]
+
+        if any(self.do_sample):
+            self.choice.filter(indices)
+        else:
+            self.choice = Greedy()
+
         return self
 
     @classmethod
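For context on how the two files fit together after this patch: each heterogeneous processor's filter(indices) now returns self while at least one surviving batch member still needs it, and None once every survivor uses the neutral value (penalty, temperature or top_p of 1.0, top_k of 0, typical mass of 1.0), which lets HeterogeneousNextTokenChooser.filter drop it from self.warpers. A minimal sketch of that contract follows; ToyTemperatureWarper and the sample values are illustrative stand-ins, not part of the patch:

import torch


class ToyTemperatureWarper:
    # Illustrative stand-in for the heterogeneous warpers above.
    def __init__(self, temperature, dtype=torch.float32, device="cpu"):
        self.temperature = temperature
        self.temperature_tensor = torch.tensor(
            temperature, dtype=dtype, device=device
        ).unsqueeze(1)

    def filter(self, indices):
        # Keep only the requested batch members.
        self.temperature = [self.temperature[i] for i in indices]
        if any(x != 1.0 for x in self.temperature):
            self.temperature_tensor = self.temperature_tensor[indices]
            return self
        # Every remaining member uses the neutral value: drop this warper.
        return None


warpers = [ToyTemperatureWarper([0.7, 1.0, 1.0])]

# Keep batch members 1 and 2 (both temperature == 1.0); the warper filters itself
# out, mirroring the loop in HeterogeneousNextTokenChooser.filter.
filtered_warpers = []
for warper in warpers:
    filtered_warper = warper.filter([1, 2])
    if filtered_warper is not None:
        filtered_warpers.append(filtered_warper)
assert filtered_warpers == []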