Remove ngram debug code

Nicolas Patry 2023-12-06 10:05:11 +00:00
parent b3c1492be1
commit 3a8b1923db


@@ -174,19 +174,16 @@ def longest_match(input_ids: List[int]) -> Optional[int]:
 def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int):
-    from loguru import logger
     B = accepted_ids.shape[0]
     device = input_ids.device
     dtype = input_ids.dtype
     speculative_ids = torch.zeros((B, speculate), device=device, dtype=dtype)
-    logger.info(f"{input_ids.shape} ")
     cpu_input_ids = input_ids.tolist()
     cpu_next_ids = next_ids.tolist()
     index = 0
     for i, (_input_ids, n_accepted_ids) in enumerate(zip(cpu_input_ids, accepted_ids.tolist())):
         stop = len(_input_ids)
-        # TODO 0 is not necessarily the pad token.
         # Remove zero padded end.
         for j, _id in enumerate(_input_ids):
             if _id == 0:
@@ -339,13 +336,7 @@ class HeterogeneousNextTokenChooser:
                 speculative_ids = Greedy()(speculative_scores)
             else:
                 # n-gram
-                import datetime
-                start = datetime.datetime.now()
-                # if input_ids.shape[0] > 4:
-                #     import ipdb;ipdb.set_trace()
                 speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate)
-                from loguru import logger
-                logger.info(f"n gram took {datetime.datetime.now() - start}")
         else:
             speculative_ids = None
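
For context, a minimal sketch of the general idea behind the n-gram speculation path touched here: look for an earlier occurrence of the most recent token in the sequence and propose the tokens that followed it as speculative ids. The ngram_speculate helper below is hypothetical and is not the repository's create_n_gram_speculation implementation.

from typing import List

def ngram_speculate(tokens: List[int], speculate: int) -> List[int]:
    # Hypothetical sketch: propose up to `speculate` tokens by finding an
    # earlier occurrence of the last token and copying what followed it.
    if len(tokens) < 2:
        return []
    last = tokens[-1]
    for i in range(len(tokens) - 2, -1, -1):
        if tokens[i] == last:
            return tokens[i + 1 : i + 1 + speculate]
    return []

# After [10, 20, 30, 40, 10], the earlier occurrence of 10 is followed by 20, 30.
print(ngram_speculate([10, 20, 30, 40, 10], speculate=2))  # [20, 30]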