Fixing tests

Nicolas Patry 2024-01-26 18:36:51 +00:00
parent a5600c23af
commit 9d3190179e
3 changed files with 43 additions and 32 deletions

@@ -581,7 +581,7 @@ class CausalLM(Model):
         stopped = True
 
         # Speculation is not active for causal
-        accepted_ids = torch.ones_like(batch.input_ids)
+        accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
             batch.top_n_tokens_tensor,
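
Note: the change above narrows accepted_ids from a per-position tensor to a per-sequence one. A minimal standalone sketch of the shape difference, assuming batch.input_ids is a padded [batch_size, seq_len] tensor as in CausalLM (the snippet is illustrative, not part of the commit):

import torch

# 3 sequences padded to 7 positions -- stand-in for batch.input_ids
input_ids = torch.zeros(3, 7, dtype=torch.long)

# Before: one entry per position, shape [3, 7]
before = torch.ones_like(input_ids)

# After: one accepted token per sequence, shape [3], which is the granularity
# batch_top_tokens expects when speculation is inactive
after = torch.ones_like(input_ids)[:, 0]

print(before.shape, after.shape)  # torch.Size([3, 7]) torch.Size([3])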
@@ -695,20 +695,24 @@ class CausalLM(Model):
                 prefill_tokens = None
 
             if top_n_tokens > 0:
-                toptoken_texts = self.tokenizer.batch_decode(
-                    top_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
-                special_toptokens = [
-                    token_id in self.all_special_ids for token_id in top_token_ids
-                ]
-                top_tokens = Tokens(
-                    top_token_ids,
-                    top_token_logprobs,
-                    toptoken_texts,
-                    special_toptokens,
-                )
+                all_top_tokens = []
+                for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
+                    toptoken_texts = self.tokenizer.batch_decode(
+                        top_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+                    special_toptokens = [
+                        token_id in self.all_special_ids for token_id in top_token_ids
+                    ]
+                    top_tokens = Tokens(
+                        top_token_ids,
+                        top_token_logprobs,
+                        toptoken_texts,
+                        special_toptokens,
+                    )
+                    all_top_tokens.append(top_tokens)
+                top_tokens = all_top_tokens
             else:
                 top_tokens = None
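
Note: with speculation, each request can produce several accepted tokens per step, so top_tokens becomes a list with one Tokens entry per accepted id instead of a single Tokens object. A rough standalone sketch of that grouping (the Tokens namedtuple and the decode stand-in below are illustrative, not the server's actual classes):

from collections import namedtuple

Tokens = namedtuple("Tokens", ["token_ids", "logprobs", "texts", "is_special"])

# One inner list per accepted token: here 2 accepted ids, top-2 candidates each.
top_token_ids = [[5, 9], [7, 3]]
top_token_logprobs = [[-0.1, -2.3], [-0.4, -1.8]]
all_special_ids = {0, 1, 2}

all_top_tokens = []
for ids, logprobs in zip(top_token_ids, top_token_logprobs):
    texts = [f"<tok{i}>" for i in ids]             # stand-in for tokenizer.batch_decode
    special = [i in all_special_ids for i in ids]
    all_top_tokens.append(Tokens(ids, logprobs, texts, special))

print(len(all_top_tokens))  # 2 -- one Tokens entry per accepted id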

@@ -641,7 +641,7 @@ class Seq2SeqLM(Model):
         )
 
         # Speculation is not active for seq2seq
-        accepted_ids = torch.ones_like(batch.decoder_input_ids)
+        accepted_ids = torch.ones_like(batch.decoder_input_ids)[:, 0]
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
             batch.top_n_tokens_tensor,
@@ -749,20 +749,24 @@ class Seq2SeqLM(Model):
                 prefill_tokens = None
 
             if top_n_tokens > 0:
-                toptoken_texts = self.tokenizer.batch_decode(
-                    top_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
-                special_toptokens = [
-                    token_id in self.all_special_ids for token_id in top_token_ids
-                ]
-                top_tokens = Tokens(
-                    top_token_ids,
-                    top_token_logprobs,
-                    toptoken_texts,
-                    special_toptokens,
-                )
+                all_top_tokens = []
+                for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
+                    toptoken_texts = self.tokenizer.batch_decode(
+                        top_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+                    special_toptokens = [
+                        token_id in self.all_special_ids for token_id in top_token_ids
+                    ]
+                    top_tokens = Tokens(
+                        top_token_ids,
+                        top_token_logprobs,
+                        toptoken_texts,
+                        special_toptokens,
+                    )
+                    all_top_tokens.append(top_tokens)
+                top_tokens = all_top_tokens
             else:
                 top_tokens = None

@@ -306,13 +306,15 @@ class HeterogeneousNextTokenChooser:
                 accepted_ids, device=input_ids.device, dtype=input_ids.dtype
             )
             next_ids = next_ids[indices]
+            logprobs = alllogprobs[indices]
             indices = torch.arange(B, device=input_ids.device) * S
             if speculative_scores is not None:
                 speculative_scores = speculative_scores[indices + accepted_ids - 1]
         else:
             accepted_ids = torch.ones_like(next_ids)
+            logprobs = alllogprobs
 
-        next_logprobs = torch.gather(alllogprobs, 1, next_ids.view(-1, 1)).view(-1)
+        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
 
         if speculate > 0:
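
Note: the fix above keeps the logprob rows aligned with next_ids: when speculation keeps only a subset of rows, the same indices must be applied to alllogprobs before the gather. A small standalone sketch with made-up shapes (not the server's actual code):

import torch

# One row of vocab logprobs per generated position (4 positions, vocab of 10).
alllogprobs = torch.log_softmax(torch.randn(4, 10), dim=-1)
indices = torch.tensor([0, 2])    # rows kept after filtering to accepted tokens
next_ids = torch.tensor([3, 7])   # chosen token id for each kept row

logprobs = alllogprobs[indices]   # keep logprob rows in sync with next_ids
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)

print(next_logprobs.shape)  # torch.Size([2]) -- one logprob per kept token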
@@ -436,7 +438,7 @@ class HeterogeneousSampling:
 def batch_top_tokens(
-    top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor
+    top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor
 ) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
     """Find the top n most likely tokens for a batch of generations.
@@ -486,6 +488,7 @@ def batch_top_tokens(
         _top_values = top_values[start: stop]
         _top_n_ishes = top_n_ishes[start: stop]
         _top_n_tokens = top_n_tokens[start: stop]
+        _top_indices = _top_indices[:n_accepted_ids]
         _top_values = _top_values[:n_accepted_ids]
         _top_n_ishes = _top_n_ishes[:n_accepted_ids]
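
Note: the added line truncates each request's slice of the flattened top-k results to its number of accepted ids, so a request that accepted fewer speculative tokens does not report top tokens for positions that were rolled back. A standalone sketch with invented numbers (not the function's real inputs):

import torch

top_indices = torch.arange(40).view(8, 5)  # 8 flattened rows, top-5 candidates each
start, stop = 2, 6                         # rows reserved for one request
n_accepted_ids = 3                         # this request accepted 3 tokens

_top_indices = top_indices[start:stop]        # shape [4, 5]: the request's rows
_top_indices = _top_indices[:n_accepted_ids]  # shape [3, 5]: accepted rows only

print(_top_indices.shape)  # torch.Size([3, 5])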