mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Remove ngram debug code
This commit is contained in:
parent
b3c1492be1
commit
3a8b1923db
@@ -174,19 +174,16 @@ def longest_match(input_ids: List[int]) -> Optional[int]:
|
||||
|
||||
|
||||
def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int):
|
||||
from loguru import logger
|
||||
B = accepted_ids.shape[0]
|
||||
device = input_ids.device
|
||||
dtype = input_ids.dtype
|
||||
speculative_ids = torch.zeros((B, speculate), device=device, dtype=dtype)
|
||||
logger.info(f"{input_ids.shape} ")
|
||||
cpu_input_ids = input_ids.tolist()
|
||||
cpu_next_ids = next_ids.tolist()
|
||||
|
||||
index = 0
|
||||
for i, (_input_ids, n_accepted_ids) in enumerate(zip(cpu_input_ids, accepted_ids.tolist())):
|
||||
stop = len(_input_ids)
|
||||
# TODO 0 is not necessarily the pad token.
|
||||
# Remove zero padded end.
|
||||
for j, _id in enumerate(_input_ids):
|
||||
if _id == 0:
|
||||
@@ -339,13 +336,7 @@ class HeterogeneousNextTokenChooser:
|
||||
speculative_ids = Greedy()(speculative_scores)
|
||||
else:
|
||||
# n-gram
|
||||
import datetime
|
||||
start = datetime.datetime.now()
|
||||
# if input_ids.shape[0] > 4:
|
||||
# import ipdb;ipdb.set_trace()
|
||||
speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate)
|
||||
from loguru import logger
|
||||
logger.info(f"n gram took {datetime.datetime.now() - start}")
|
||||
else:
|
||||
speculative_ids = None
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user