Fixing tests

Nicolas Patry 2024-01-26 18:36:51 +00:00
parent a5600c23af
commit 9d3190179e
3 changed files with 43 additions and 32 deletions
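In short: the non-speculative decode paths built accepted_ids with one entry per input position ([batch, seq_len]) instead of one per sequence ([batch]), top_tokens becomes a list holding one Tokens entry per accepted id, and HeterogeneousNextTokenChooser must gather next-token logprobs from the same sliced tensor as next_ids.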

server/text_generation_server/models/causal_lm.py

@@ -581,7 +581,7 @@ class CausalLM(Model):
         stopped = True
         # Speculation is not active for causal
-        accepted_ids = torch.ones_like(batch.input_ids)
+        accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
             batch.top_n_tokens_tensor,
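A minimal sketch of why the [:, 0] matters (made-up shapes; batch.input_ids is assumed here to be 2-D, [batch, seq_len]): without speculation exactly one token is accepted per sequence, so accepted_ids must carry one count per sequence rather than one per input position.

import torch

input_ids = torch.zeros(3, 5, dtype=torch.long)   # hypothetical [batch=3, seq_len=5]

before = torch.ones_like(input_ids)        # [3, 5]: one count per token position
after = torch.ones_like(input_ids)[:, 0]   # [3]:    one count per sequence

assert before.shape == (3, 5)
assert after.shape == (3,)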
@@ -695,6 +695,8 @@ class CausalLM(Model):
             prefill_tokens = None

             if top_n_tokens > 0:
+                all_top_tokens = []
+                for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
                     toptoken_texts = self.tokenizer.batch_decode(
                         top_token_ids,
                         clean_up_tokenization_spaces=False,
@@ -709,6 +711,8 @@ class CausalLM(Model):
                         toptoken_texts,
                         special_toptokens,
                     )
+                    all_top_tokens.append(top_tokens)
+                top_tokens = all_top_tokens
             else:
                 top_tokens = None
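A sketch of the new control flow, with stand-ins for anything repo-specific (a fake decode replaces self.tokenizer.batch_decode, and plain tuples replace the Tokens type): top_token_ids / top_token_logprobs now arrive as one row per accepted id, and the loop turns top_tokens into a list with one entry per row.

# Stand-ins for illustration only; not the repo's real decode or Tokens type.
def fake_decode(ids):
    return [f"<tok{i}>" for i in ids]

# One row of top-k candidates per accepted id (two accepted ids here).
top_token_ids = [[5, 9], [7, 2]]
top_token_logprobs = [[-0.1, -2.3], [-0.4, -1.9]]

all_top_tokens = []
for ids, logprobs in zip(top_token_ids, top_token_logprobs):
    all_top_tokens.append((ids, logprobs, fake_decode(ids)))

top_tokens = all_top_tokens  # a list now: one entry per accepted id
assert len(top_tokens) == 2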

server/text_generation_server/models/seq2seq_lm.py

@@ -641,7 +641,7 @@ class Seq2SeqLM(Model):
         )
         # Speculation is not active for seq2seq
-        accepted_ids = torch.ones_like(batch.decoder_input_ids)
+        accepted_ids = torch.ones_like(batch.decoder_input_ids)[:, 0]
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
             batch.top_n_tokens_tensor,
@@ -749,6 +749,8 @@ class Seq2SeqLM(Model):
             prefill_tokens = None

             if top_n_tokens > 0:
+                all_top_tokens = []
+                for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
                     toptoken_texts = self.tokenizer.batch_decode(
                         top_token_ids,
                         clean_up_tokenization_spaces=False,
@@ -763,6 +765,8 @@ class Seq2SeqLM(Model):
                         toptoken_texts,
                         special_toptokens,
                     )
+                    all_top_tokens.append(top_tokens)
+                top_tokens = all_top_tokens
             else:
                 top_tokens = None
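Seq2SeqLM receives the same two fixes, only on decoder_input_ids. A hypothetical shape check in the spirit of the fixed tests:

import torch

decoder_input_ids = torch.zeros(4, 7, dtype=torch.long)  # made-up [batch=4, seq_len=7]
accepted_ids = torch.ones_like(decoder_input_ids)[:, 0]
assert accepted_ids.shape == (4,)  # one accepted token per sequence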

server/text_generation_server/utils/tokens.py

@@ -306,13 +306,15 @@ class HeterogeneousNextTokenChooser:
                 accepted_ids, device=input_ids.device, dtype=input_ids.dtype
             )
             next_ids = next_ids[indices]
+            logprobs = alllogprobs[indices]
             indices = torch.arange(B, device=input_ids.device) * S
             if speculative_scores is not None:
                 speculative_scores = speculative_scores[indices + accepted_ids - 1]
         else:
             accepted_ids = torch.ones_like(next_ids)
+            logprobs = alllogprobs

-        next_logprobs = torch.gather(alllogprobs, 1, next_ids.view(-1, 1)).view(-1)
+        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
         if speculate > 0:
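The bug this hunk fixes: when speculated ids were present, next_ids was filtered down to the accepted rows, but the gather still read from the unfiltered alllogprobs, so rows no longer lined up. Keeping a logprobs alias that is sliced (or not) in lockstep with next_ids restores the pairing. A toy reproduction with made-up shapes:

import torch

B, S, V = 2, 3, 10                  # batch, speculated steps, vocab (made up)
alllogprobs = torch.randn(B * S, V)
next_ids = torch.randint(V, (B * S,))
indices = torch.tensor([0, 2, 3])   # rows whose tokens were accepted

# Slice ids and logprobs with the *same* indices...
next_ids = next_ids[indices]
logprobs = alllogprobs[indices]

# ...so each gathered logprob matches its token.
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
assert next_logprobs.shape == (3,)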
@@ -486,6 +488,7 @@ def batch_top_tokens(
         _top_values = top_values[start: stop]
         _top_n_ishes = top_n_ishes[start: stop]
         _top_n_tokens = top_n_tokens[start: stop]
+        _top_indices = _top_indices[:n_accepted_ids]
         _top_values = _top_values[:n_accepted_ids]
         _top_n_ishes = _top_n_ishes[:n_accepted_ids]
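The added line gives _top_indices the same [:n_accepted_ids] truncation as its sibling buffers; without it, the indices kept more rows than the values they are paired with downstream. A minimal illustration with hypothetical lists:

# Hypothetical per-request buffers for one request in the batch.
n_accepted_ids = 1
_top_indices = [[11, 12], [13, 14]]   # candidate token ids per step
_top_values = [[0.9, 0.1], [0.8, 0.2]]

_top_indices = _top_indices[:n_accepted_ids]  # now truncated like the rest
_top_values = _top_values[:n_accepted_ids]
assert len(_top_indices) == len(_top_values) == 1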