Fixing tests

This commit is contained in:
Nicolas Patry 2024-01-26 18:36:51 +00:00
parent a5600c23af
commit 9d3190179e
3 changed files with 43 additions and 32 deletions

View File

@ -581,7 +581,7 @@ class CausalLM(Model):
stopped = True stopped = True
# Speculation is not active for causal # Speculation is not active for causal
accepted_ids = torch.ones_like(batch.input_ids) accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens,
batch.top_n_tokens_tensor, batch.top_n_tokens_tensor,
@ -695,20 +695,24 @@ class CausalLM(Model):
prefill_tokens = None prefill_tokens = None
if top_n_tokens > 0: if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode( all_top_tokens = []
top_token_ids, for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
clean_up_tokenization_spaces=False, toptoken_texts = self.tokenizer.batch_decode(
skip_special_tokens=False, top_token_ids,
) clean_up_tokenization_spaces=False,
special_toptokens = [ skip_special_tokens=False,
token_id in self.all_special_ids for token_id in top_token_ids )
] special_toptokens = [
top_tokens = Tokens( token_id in self.all_special_ids for token_id in top_token_ids
top_token_ids, ]
top_token_logprobs, top_tokens = Tokens(
toptoken_texts, top_token_ids,
special_toptokens, top_token_logprobs,
) toptoken_texts,
special_toptokens,
)
all_top_tokens.append(top_tokens)
top_tokens = all_top_tokens
else: else:
top_tokens = None top_tokens = None

View File

@ -641,7 +641,7 @@ class Seq2SeqLM(Model):
) )
# Speculation is not active for seq2seq # Speculation is not active for seq2seq
accepted_ids = torch.ones_like(batch.decoder_input_ids) accepted_ids = torch.ones_like(batch.decoder_input_ids)[:, 0]
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens,
batch.top_n_tokens_tensor, batch.top_n_tokens_tensor,
@ -749,20 +749,24 @@ class Seq2SeqLM(Model):
prefill_tokens = None prefill_tokens = None
if top_n_tokens > 0: if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode( all_top_tokens = []
top_token_ids, for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
clean_up_tokenization_spaces=False, toptoken_texts = self.tokenizer.batch_decode(
skip_special_tokens=False, top_token_ids,
) clean_up_tokenization_spaces=False,
special_toptokens = [ skip_special_tokens=False,
token_id in self.all_special_ids for token_id in top_token_ids )
] special_toptokens = [
top_tokens = Tokens( token_id in self.all_special_ids for token_id in top_token_ids
top_token_ids, ]
top_token_logprobs, top_tokens = Tokens(
toptoken_texts, top_token_ids,
special_toptokens, top_token_logprobs,
) toptoken_texts,
special_toptokens,
)
all_top_tokens.append(top_tokens)
top_tokens = all_top_tokens
else: else:
top_tokens = None top_tokens = None

View File

@ -306,13 +306,15 @@ class HeterogeneousNextTokenChooser:
accepted_ids, device=input_ids.device, dtype=input_ids.dtype accepted_ids, device=input_ids.device, dtype=input_ids.dtype
) )
next_ids = next_ids[indices] next_ids = next_ids[indices]
logprobs = alllogprobs[indices]
indices = torch.arange(B, device=input_ids.device) * S indices = torch.arange(B, device=input_ids.device) * S
if speculative_scores is not None: if speculative_scores is not None:
speculative_scores = speculative_scores[indices + accepted_ids - 1] speculative_scores = speculative_scores[indices + accepted_ids - 1]
else: else:
accepted_ids = torch.ones_like(next_ids) accepted_ids = torch.ones_like(next_ids)
logprobs = alllogprobs
next_logprobs = torch.gather(alllogprobs, 1, next_ids.view(-1, 1)).view(-1) next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
if speculate > 0: if speculate > 0:
@ -436,7 +438,7 @@ class HeterogeneousSampling:
def batch_top_tokens( def batch_top_tokens(
top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor
) -> Tuple[List[List[List[int]]], List[List[List[float]]]]: ) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
"""Find the top n most likely tokens for a batch of generations. """Find the top n most likely tokens for a batch of generations.
@ -486,6 +488,7 @@ def batch_top_tokens(
_top_values = top_values[start: stop] _top_values = top_values[start: stop]
_top_n_ishes = top_n_ishes[start: stop] _top_n_ishes = top_n_ishes[start: stop]
_top_n_tokens = top_n_tokens[start: stop] _top_n_tokens = top_n_tokens[start: stop]
_top_indices = _top_indices[:n_accepted_ids] _top_indices = _top_indices[:n_accepted_ids]
_top_values = _top_values[:n_accepted_ids] _top_values = _top_values[:n_accepted_ids]
_top_n_ishes = _top_n_ishes[:n_accepted_ids] _top_n_ishes = _top_n_ishes[:n_accepted_ids]