diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index a6bec102..cde0f228 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -174,19 +174,16 @@ def longest_match(input_ids: List[int]) -> Optional[int]: def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int): - from loguru import logger B = accepted_ids.shape[0] device = input_ids.device dtype = input_ids.dtype speculative_ids = torch.zeros((B, speculate), device=device, dtype=dtype) - logger.info(f"{input_ids.shape} ") cpu_input_ids = input_ids.tolist() cpu_next_ids = next_ids.tolist() index = 0 for i, (_input_ids, n_accepted_ids) in enumerate(zip(cpu_input_ids, accepted_ids.tolist())): stop = len(_input_ids) - # TODO 0 is not necessarily the pad token. # Remove zero padded end. for j, _id in enumerate(_input_ids): if _id == 0: @@ -339,13 +336,7 @@ class HeterogeneousNextTokenChooser: speculative_ids = Greedy()(speculative_scores) else: # n-gram - import datetime - start = datetime.datetime.now() - # if input_ids.shape[0] > 4: - # import ipdb;ipdb.set_trace() speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate) - from loguru import logger - logger.info(f"n gram took {datetime.datetime.now() - start}") else: speculative_ids = None