mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Fixing tests
This commit is contained in:
parent
a5600c23af
commit
9d3190179e
@ -581,7 +581,7 @@ class CausalLM(Model):
|
||||
stopped = True
|
||||
|
||||
# Speculation is not active for causal
|
||||
accepted_ids = torch.ones_like(batch.input_ids)
|
||||
accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
|
||||
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
|
||||
batch.top_n_tokens,
|
||||
batch.top_n_tokens_tensor,
|
||||
@ -695,6 +695,8 @@ class CausalLM(Model):
|
||||
prefill_tokens = None
|
||||
|
||||
if top_n_tokens > 0:
|
||||
all_top_tokens = []
|
||||
for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
|
||||
toptoken_texts = self.tokenizer.batch_decode(
|
||||
top_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
@ -709,6 +711,8 @@ class CausalLM(Model):
|
||||
toptoken_texts,
|
||||
special_toptokens,
|
||||
)
|
||||
all_top_tokens.append(top_tokens)
|
||||
top_tokens = all_top_tokens
|
||||
else:
|
||||
top_tokens = None
|
||||
|
||||
|
@ -641,7 +641,7 @@ class Seq2SeqLM(Model):
|
||||
)
|
||||
|
||||
# Speculation is not active for seq2seq
|
||||
accepted_ids = torch.ones_like(batch.decoder_input_ids)
|
||||
accepted_ids = torch.ones_like(batch.decoder_input_ids)[:, 0]
|
||||
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
|
||||
batch.top_n_tokens,
|
||||
batch.top_n_tokens_tensor,
|
||||
@ -749,6 +749,8 @@ class Seq2SeqLM(Model):
|
||||
prefill_tokens = None
|
||||
|
||||
if top_n_tokens > 0:
|
||||
all_top_tokens = []
|
||||
for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
|
||||
toptoken_texts = self.tokenizer.batch_decode(
|
||||
top_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
@ -763,6 +765,8 @@ class Seq2SeqLM(Model):
|
||||
toptoken_texts,
|
||||
special_toptokens,
|
||||
)
|
||||
all_top_tokens.append(top_tokens)
|
||||
top_tokens = all_top_tokens
|
||||
else:
|
||||
top_tokens = None
|
||||
|
||||
|
@ -306,13 +306,15 @@ class HeterogeneousNextTokenChooser:
|
||||
accepted_ids, device=input_ids.device, dtype=input_ids.dtype
|
||||
)
|
||||
next_ids = next_ids[indices]
|
||||
logprobs = alllogprobs[indices]
|
||||
indices = torch.arange(B, device=input_ids.device) * S
|
||||
if speculative_scores is not None:
|
||||
speculative_scores = speculative_scores[indices + accepted_ids - 1]
|
||||
else:
|
||||
accepted_ids = torch.ones_like(next_ids)
|
||||
logprobs = alllogprobs
|
||||
|
||||
next_logprobs = torch.gather(alllogprobs, 1, next_ids.view(-1, 1)).view(-1)
|
||||
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
|
||||
|
||||
|
||||
if speculate > 0:
|
||||
@ -486,6 +488,7 @@ def batch_top_tokens(
|
||||
_top_values = top_values[start: stop]
|
||||
_top_n_ishes = top_n_ishes[start: stop]
|
||||
_top_n_tokens = top_n_tokens[start: stop]
|
||||
|
||||
_top_indices = _top_indices[:n_accepted_ids]
|
||||
_top_values = _top_values[:n_accepted_ids]
|
||||
_top_n_ishes = _top_n_ishes[:n_accepted_ids]
|
||||
|
Loading…
Reference in New Issue
Block a user