mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Remove ngram debug code
This commit is contained in:
parent
b3c1492be1
commit
3a8b1923db
@ -174,19 +174,16 @@ def longest_match(input_ids: List[int]) -> Optional[int]:
|
|||||||
|
|
||||||
|
|
||||||
def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int):
|
def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int):
|
||||||
from loguru import logger
|
|
||||||
B = accepted_ids.shape[0]
|
B = accepted_ids.shape[0]
|
||||||
device = input_ids.device
|
device = input_ids.device
|
||||||
dtype = input_ids.dtype
|
dtype = input_ids.dtype
|
||||||
speculative_ids = torch.zeros((B, speculate), device=device, dtype=dtype)
|
speculative_ids = torch.zeros((B, speculate), device=device, dtype=dtype)
|
||||||
logger.info(f"{input_ids.shape} ")
|
|
||||||
cpu_input_ids = input_ids.tolist()
|
cpu_input_ids = input_ids.tolist()
|
||||||
cpu_next_ids = next_ids.tolist()
|
cpu_next_ids = next_ids.tolist()
|
||||||
|
|
||||||
index = 0
|
index = 0
|
||||||
for i, (_input_ids, n_accepted_ids) in enumerate(zip(cpu_input_ids, accepted_ids.tolist())):
|
for i, (_input_ids, n_accepted_ids) in enumerate(zip(cpu_input_ids, accepted_ids.tolist())):
|
||||||
stop = len(_input_ids)
|
stop = len(_input_ids)
|
||||||
# TODO 0 is not necessarily the pad token.
|
|
||||||
# Remove zero padded end.
|
# Remove zero padded end.
|
||||||
for j, _id in enumerate(_input_ids):
|
for j, _id in enumerate(_input_ids):
|
||||||
if _id == 0:
|
if _id == 0:
|
||||||
@ -339,13 +336,7 @@ class HeterogeneousNextTokenChooser:
|
|||||||
speculative_ids = Greedy()(speculative_scores)
|
speculative_ids = Greedy()(speculative_scores)
|
||||||
else:
|
else:
|
||||||
# n-gram
|
# n-gram
|
||||||
import datetime
|
|
||||||
start = datetime.datetime.now()
|
|
||||||
# if input_ids.shape[0] > 4:
|
|
||||||
# import ipdb;ipdb.set_trace()
|
|
||||||
speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate)
|
speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate)
|
||||||
from loguru import logger
|
|
||||||
logger.info(f"n gram took {datetime.datetime.now() - start}")
|
|
||||||
else:
|
else:
|
||||||
speculative_ids = None
|
speculative_ids = None
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user