diff --git a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json
index 05b9a365..b2068446 100644
--- a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json
+++ b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json
@@ -40,7 +40,7 @@
       },
       {
         "id": 2772,
-        "logprob": 0.0,
+        "logprob": -0.23083496,
         "special": false,
         "text": "De"
       },
@@ -57,42 +57,42 @@
         "text": " learning"
       },
       {
-        "id": 508,
-        "logprob": -1.5087891,
+        "id": 756,
+        "logprob": -0.48095703,
         "special": false,
-        "text": " can"
+        "text": " has"
       },
       {
-        "id": 367,
+        "id": 19479,
         "logprob": 0.0,
         "special": false,
-        "text": " be"
+        "text": " revolution"
       },
       {
-        "id": 2714,
-        "logprob": -0.6538086,
+        "id": 1891,
+        "logprob": 0.0,
         "special": false,
-        "text": " thought"
+        "text": "ized"
+      },
+      {
+        "id": 278,
+        "logprob": 0.0,
+        "special": false,
+        "text": " the"
+      },
+      {
+        "id": 1746,
+        "logprob": 0.0,
+        "special": false,
+        "text": " field"
       },
       {
         "id": 310,
         "logprob": 0.0,
         "special": false,
         "text": " of"
-      },
-      {
-        "id": 408,
-        "logprob": 0.0,
-        "special": false,
-        "text": " as"
-      },
-      {
-        "id": 263,
-        "logprob": 0.0,
-        "special": false,
-        "text": " a"
       }
     ]
   },
-  "generated_text": "What is Deep Learning?\nDeep learning can be thought of as a"
+  "generated_text": "What is Deep Learning?\nDeep learning has revolutionized the field of"
 }
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index cde0f228..e3457483 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -16,6 +16,8 @@ from text_generation_server.utils.logits_process import (
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
 
+from loguru import logger
+
 class NextTokenChooser:
     def __init__(
         self,
@@ -145,57 +147,20 @@ class StoppingCriteria:
             pb.ignore_eos_token,
         )
 
-
-def longest_match(input_ids: List[int]) -> Optional[int]:
-    longest_match = 0
-    seed = input_ids[-1]
-    final_matches = []
-    current_matches = []
-    for i in range(1, len(input_ids)):
-        index = len(input_ids) - i - 1
-
-        _current_matches = []
-        for (_index, length) in current_matches:
-            if input_ids[index] == input_ids[len(input_ids) - length - 1]:
-                _current_matches.append((_index, length + 1))
-            elif length > longest_match:
-                longest_match = length
-                final_matches.append((_index, length))
-            else:
-                pass
-        current_matches = _current_matches
-
-        if input_ids[index] == seed:
-            current_matches.append((index, 1))
-    if not final_matches:
-        return 0
-    return final_matches[-1][0]
-
-
-
-def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int):
+def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int, verbose: bool):
+    # import datetime
+    # start = datetime.datetime.now()
     B = accepted_ids.shape[0]
     device = input_ids.device
     dtype = input_ids.dtype
-    speculative_ids = torch.zeros((B, speculate), device=device, dtype=dtype)
-    cpu_input_ids = input_ids.tolist()
-    cpu_next_ids = next_ids.tolist()
+    # speculative_ids = torch.zeros((B, speculate), device=device, dtype=dtype)
+    seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
+    indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
+    all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate, device=device)
+    all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
 
-    index = 0
-    for i, (_input_ids, n_accepted_ids) in enumerate(zip(cpu_input_ids, accepted_ids.tolist())):
-        stop = len(_input_ids)
-        # Remove zero padded end.
-        for j, _id in enumerate(_input_ids):
-            if _id == 0:
-                stop = j
-                break
-        _input_ids = _input_ids[:stop]
-        _input_ids.extend(cpu_next_ids[index: index+n_accepted_ids])
-        index = longest_match(_input_ids) + 1
-        slice_ = input_ids[i, index:index+speculate]
-        # logger.info(f"{slice_.shape} - {speculative_ids.shape}")
-        speculative_ids[i, :len(slice_)] = slice_
-        index += n_accepted_ids
+    # logger.info(f"All indices {all_indices} - {input_ids.shape}")
+    speculative_ids = input_ids.gather(dim=-1, index=all_indices)
     return speculative_ids
 
 class HeterogeneousNextTokenChooser:
@@ -266,7 +231,11 @@ class HeterogeneousNextTokenChooser:
         self.dtype = dtype
         self.device = device
 
-    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None):
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None, verbose=False):
+        import datetime
+        # from loguru import logger
+
+        start = datetime.datetime.now()
         if speculated_ids is not None:
             B = scores.shape[0] // (speculated_ids.shape[1] + 1)
             S = speculated_ids.shape[1] + 1
@@ -276,8 +245,11 @@ class HeterogeneousNextTokenChooser:
             S = 1
         scores = scores.view(B, S, -1)
 
+        # if verbose:
+        #     logger.info(f"Reshape {datetime.datetime.now() - start}")
+
         all_next_ids = []
-        all_scores = []
+        next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
         for j in range(S):
             _scores = scores[:, j]
             if self.watermark_processor is not None:
@@ -289,11 +261,13 @@ class HeterogeneousNextTokenChooser:
 
                 _scores = warper(input_ids, _scores)
 
-            next_ids = self.choice(_scores)
+            _next_ids = self.choice(_scores)
             scores[:, j] = _scores
-            all_next_ids.append(next_ids.unsqueeze(1))
-        next_ids = torch.cat(all_next_ids, dim=1).reshape(B*S)
+            next_ids[:, j] = _next_ids
+        next_ids = next_ids.view(B*S)
         scores = scores.view(B*S, -1)
+        # if verbose:
+        #     logger.info(f"Scores {datetime.datetime.now() - start}")
 
         if speculated_ids is not None:
             accepted_ids = []
@@ -325,20 +299,23 @@ class HeterogeneousNextTokenChooser:
             speculative_scores = speculative_scores[indices + accepted_ids - 1]
         else:
             accepted_ids = torch.ones_like(next_ids)
+        # if verbose:
+        #     logger.info(f"Indices/accepted id {datetime.datetime.now() - start}")
 
         logprobs = torch.log_softmax(scores, -1)
         next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
 
         if speculate > 0:
             if speculative_scores is not None:
-                # TODO This will only speculate the top score
                 # Medusa provided some scores
                 speculative_ids = Greedy()(speculative_scores)
             else:
                 # n-gram
-                speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate)
+                speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate, verbose)
         else:
             speculative_ids = None
+        # if verbose:
+        #     logger.info(f"new speculative ids {datetime.datetime.now() - start}")
 
         return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids
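Note on the new `create_n_gram_speculation`: it replaces the per-sequence Python `longest_match` scan (which round-tripped tensors through `tolist()`) with a fully vectorized lookup. For each sequence in the batch, the last accepted token acts as the n-gram seed; a broadcast comparison finds the first position in `input_ids` where that seed already occurred, and the `speculate` tokens following the match are gathered as the speculative continuation, with the window clamped at the end of the history. Below is a minimal self-contained sketch of the same idea; the helper name `ngram_speculate` and the toy tensors are illustrative assumptions (it presumes `input_ids` is a `(B, L)` batch of token histories and `next_ids` a flat vector of sampled ids), not the exact production code path:

```python
import torch


def ngram_speculate(
    input_ids: torch.Tensor,     # (B, L) token history per sequence (assumed shape)
    next_ids: torch.Tensor,      # flat vector of sampled ids, one per slot
    accepted_ids: torch.Tensor,  # (B,) number of accepted tokens per sequence
    speculate: int,              # how many tokens to speculate ahead
) -> torch.Tensor:
    B = accepted_ids.shape[0]
    device = input_ids.device

    # The last *accepted* token of each sequence is the n-gram seed.
    seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]

    # First position where the seed already appears in each history;
    # speculation starts one past that match.
    indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1

    # Window of `speculate` consecutive positions after the match, clamped
    # so the gather never reads past the end of the history.
    all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(
        speculate, device=device
    )
    all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
    return input_ids.gather(dim=-1, index=all_indices)


# Toy check: the history repeats "5 8 9" and the accepted token is 5, so the
# speculation copies the tokens that followed 5 earlier: [8, 9].
history = torch.tensor([[5, 8, 9, 5, 8, 9, 5, 0]])
print(ngram_speculate(history, torch.tensor([5]), torch.tensor([1]), speculate=2))
# tensor([[8, 9]])
```

The trade-off: `.max(dim=1).indices` keeps only a single, earliest match per row, so match quality is coarser than the removed `longest_match` heuristic, but the whole operation stays on-device and avoids the Python loop and `tolist()` copies entirely.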