Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
Commit: 9d3190179e
Parent: a5600c23af
Message: Fixing tests
@@ -581,7 +581,7 @@ class CausalLM(Model):
         stopped = True

         # Speculation is not active for causal
-        accepted_ids = torch.ones_like(batch.input_ids)
+        accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
             batch.top_n_tokens_tensor,
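Note (annotation, not part of the commit): `batch.input_ids` in `CausalLM` is a padded 2-D tensor of shape `[batch_size, sequence_length]`, so `torch.ones_like(batch.input_ids)` is also 2-D, while `batch_top_tokens` expects one accepted-token count per request. Slicing `[:, 0]` keeps a 1-D tensor of ones, i.e. exactly one accepted token per sequence when speculation is off. A minimal shape sketch with made-up sizes:

import torch

# Hypothetical stand-in for batch.input_ids: 3 sequences padded to length 5.
input_ids = torch.zeros((3, 5), dtype=torch.long)

before = torch.ones_like(input_ids)         # shape [3, 5]: one value per token position
after = torch.ones_like(input_ids)[:, 0]    # shape [3]: one accepted token per sequence

print(before.shape, after.shape)  # torch.Size([3, 5]) torch.Size([3])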
@@ -695,20 +695,24 @@ class CausalLM(Model):
                 prefill_tokens = None

             if top_n_tokens > 0:
-                toptoken_texts = self.tokenizer.batch_decode(
-                    top_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
-                special_toptokens = [
-                    token_id in self.all_special_ids for token_id in top_token_ids
-                ]
-                top_tokens = Tokens(
-                    top_token_ids,
-                    top_token_logprobs,
-                    toptoken_texts,
-                    special_toptokens,
-                )
+                all_top_tokens = []
+                for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
+                    toptoken_texts = self.tokenizer.batch_decode(
+                        top_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+                    special_toptokens = [
+                        token_id in self.all_special_ids for token_id in top_token_ids
+                    ]
+                    top_tokens = Tokens(
+                        top_token_ids,
+                        top_token_logprobs,
+                        toptoken_texts,
+                        special_toptokens,
+                    )
+                    all_top_tokens.append(top_tokens)
+                top_tokens = all_top_tokens
             else:
                 top_tokens = None

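Note (annotation, not part of the commit): with the new `batch_top_tokens` contract, each request gets a list of top-token lists, one per accepted position, so the model now builds one `Tokens` object per position and returns the list. A simplified, self-contained sketch of the loop; the `Tokens` dataclass, the fake decode, and the sample ids below are stand-ins, not the library's code:

from dataclasses import dataclass
from typing import List

# Minimal stand-in for text_generation_server's Tokens container (illustrative only).
@dataclass
class Tokens:
    token_ids: List[int]
    logprobs: List[float]
    texts: List[str]
    is_special: List[bool]

# Per-request data as returned by the new batch_top_tokens: one entry per accepted position.
position_top_ids = [[5, 9], [7, 2]]                   # 2 accepted positions, top-2 ids each
position_top_logprobs = [[-0.1, -2.3], [-0.4, -1.8]]
all_special_ids = {0, 1, 2}

all_top_tokens = []
for top_token_ids, top_token_logprobs in zip(position_top_ids, position_top_logprobs):
    toptoken_texts = [f"<tok{i}>" for i in top_token_ids]  # stand-in for tokenizer.batch_decode
    special = [i in all_special_ids for i in top_token_ids]
    all_top_tokens.append(Tokens(top_token_ids, top_token_logprobs, toptoken_texts, special))

top_tokens = all_top_tokens  # the response now carries a list of Tokens, one per accepted position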
@@ -641,7 +641,7 @@ class Seq2SeqLM(Model):
         )

         # Speculation is not active for seq2seq
-        accepted_ids = torch.ones_like(batch.decoder_input_ids)
+        accepted_ids = torch.ones_like(batch.decoder_input_ids)[:, 0]
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
             batch.top_n_tokens_tensor,
@@ -749,20 +749,24 @@ class Seq2SeqLM(Model):
                 prefill_tokens = None

             if top_n_tokens > 0:
-                toptoken_texts = self.tokenizer.batch_decode(
-                    top_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
-                special_toptokens = [
-                    token_id in self.all_special_ids for token_id in top_token_ids
-                ]
-                top_tokens = Tokens(
-                    top_token_ids,
-                    top_token_logprobs,
-                    toptoken_texts,
-                    special_toptokens,
-                )
+                all_top_tokens = []
+                for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
+                    toptoken_texts = self.tokenizer.batch_decode(
+                        top_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+                    special_toptokens = [
+                        token_id in self.all_special_ids for token_id in top_token_ids
+                    ]
+                    top_tokens = Tokens(
+                        top_token_ids,
+                        top_token_logprobs,
+                        toptoken_texts,
+                        special_toptokens,
+                    )
+                    all_top_tokens.append(top_tokens)
+                top_tokens = all_top_tokens
             else:
                 top_tokens = None

@@ -306,13 +306,15 @@ class HeterogeneousNextTokenChooser:
                 accepted_ids, device=input_ids.device, dtype=input_ids.dtype
             )
             next_ids = next_ids[indices]
+            logprobs = alllogprobs[indices]
             indices = torch.arange(B, device=input_ids.device) * S
             if speculative_scores is not None:
                 speculative_scores = speculative_scores[indices + accepted_ids - 1]
         else:
             accepted_ids = torch.ones_like(next_ids)
+            logprobs = alllogprobs

-        next_logprobs = torch.gather(alllogprobs, 1, next_ids.view(-1, 1)).view(-1)
+        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)


         if speculate > 0:
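Note (annotation, not part of the commit): my reading of this hunk is that `next_ids` has already been filtered down to the accepted positions via `indices`, while `alllogprobs` still had one row per speculated position, so gathering from the unfiltered tensor paired tokens with the wrong rows. Filtering `logprobs = alllogprobs[indices]` (and aliasing `logprobs = alllogprobs` in the non-speculative branch) keeps the gather aligned. A toy example of the gather pattern, with illustrative shapes only:

import torch

logprobs = torch.log_softmax(torch.randn(4, 10), dim=-1)  # [kept_rows, vocab]
next_ids = torch.randint(0, 10, (4,))                      # one chosen token id per kept row

# Pick the log-probability of the chosen token in each row.
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)  # shape [4]
print(next_logprobs)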
@@ -436,7 +438,7 @@ class HeterogeneousSampling:


 def batch_top_tokens(
-    top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor
+    top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor
 ) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
     """Find the top n most likely tokens for a batch of generations.

@@ -486,6 +488,7 @@ def batch_top_tokens(
         _top_values = top_values[start: stop]
         _top_n_ishes = top_n_ishes[start: stop]
         _top_n_tokens = top_n_tokens[start: stop]

         _top_indices = _top_indices[:n_accepted_ids]
         _top_values = _top_values[:n_accepted_ids]
         _top_n_ishes = _top_n_ishes[:n_accepted_ids]
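Note (annotation, not part of the commit): the `accepted_ids` parameter and the triple-nested return type mean `batch_top_tokens` now yields, per request, one top-token list per accepted position. The sketch below is a simplified, stand-alone illustration of that shape contract, not the library's implementation: it takes `accepted_ids` as a plain list of ints and omits `top_n_tokens_tensor` for brevity.

from typing import List, Tuple
import torch

def toy_batch_top_tokens(
    top_n_tokens: List[int],   # per request: how many top tokens were requested
    logprobs: torch.Tensor,    # [total_accepted_positions, vocab] log-probabilities
    accepted_ids: List[int],   # per request: number of accepted positions
) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
    batch_ids, batch_logprobs = [], []
    row = 0
    for top_n, n_accepted in zip(top_n_tokens, accepted_ids):
        req_ids, req_logprobs = [], []
        for _ in range(n_accepted):
            if top_n > 0:
                values, indices = logprobs[row].topk(top_n)
                req_ids.append(indices.tolist())
                req_logprobs.append(values.tolist())
            else:
                # Requests that did not ask for top tokens get empty lists.
                req_ids.append([])
                req_logprobs.append([])
            row += 1
        batch_ids.append(req_ids)
        batch_logprobs.append(req_logprobs)
    return batch_ids, batch_logprobs

# Example: 2 requests, the first asking for top-2 tokens, with 2 and 1 accepted positions.
lp = torch.log_softmax(torch.randn(3, 10), dim=-1)
ids, lps = toy_batch_top_tokens([2, 0], lp, [2, 1])
print(ids)  # e.g. [[[4, 7], [1, 9]], [[]]]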