Mirror of https://github.com/huggingface/text-generation-inference.git
Fixing other types of models + tests + Damn you python scoping.
This commit is contained in:
parent 6e629add98
commit 0452d590d0
@@ -50,19 +50,39 @@ def test_batch_top_tokens():
     top_n_tokens = [0, 2, 3, 4, 5]
     top_n_tokens_tensor = torch.tensor(top_n_tokens)
     inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5)
+    accepted_ids = torch.ones_like(top_n_tokens_tensor)
 
     topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
-        top_n_tokens, top_n_tokens_tensor, inp_logprobs
+        top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
     )
 
-    assert topn_tok_ids[0] == []
-    assert topn_tok_ids[1] == [0, 3]
-    assert topn_tok_ids[2] == [0, 3, 1, 4]
-    assert topn_tok_ids[3] == [0, 3, 1, 4]
-    assert topn_tok_ids[4] == [0, 3, 1, 4, 2]
+    assert topn_tok_ids[0] == [[]]
+    assert topn_tok_ids[1] == [[0, 3]]
+    assert topn_tok_ids[2] == [[0, 3, 1, 4]]
+    assert topn_tok_ids[3] == [[0, 3, 1, 4]]
+    assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]]
 
-    assert topn_tok_logprobs[0] == []
-    assert topn_tok_logprobs[1] == [-1, -2]
-    assert topn_tok_logprobs[2] == [-1, -2, -3, -3]
-    assert topn_tok_logprobs[3] == [-1, -2, -3, -3]
-    assert topn_tok_logprobs[4] == [-1, -2, -3, -3, -4]
+    assert topn_tok_logprobs[0] == [[]]
+    assert topn_tok_logprobs[1] == [[-1, -2]]
+    assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]]
+    assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]]
+    assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]]
+
+    # Now let's make second member of the batch be speculated
+    inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5 * 2)
+    accepted_ids[1] = 2
+    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
+        top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
+    )
+
+    assert topn_tok_ids[0] == [[]]
+    assert topn_tok_ids[1] == [[0, 3], [0, 3]]
+    assert topn_tok_ids[2] == [[0, 3, 1, 4]]
+    assert topn_tok_ids[3] == [[0, 3, 1, 4]]
+    assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]]
+
+    assert topn_tok_logprobs[0] == [[]]
+    assert topn_tok_logprobs[1] == [[-1, -2], [-1, -2]]
+    assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]]
+    assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]]
+    assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]]
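The rewritten assertions pin down the new return contract: `batch_top_tokens` now yields, for each request, one inner list per accepted token, so a request whose speculated token was also accepted gets two top-token lists instead of one flat list. A minimal, self-contained sketch of that shape, using toy values from the test above rather than the TGI implementation:

```python
import torch

# One request, two accepted tokens -> two rows of logprobs
inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 2)
top_n = 2
values, indices = inp_logprobs.topk(k=top_n, dim=-1, sorted=True)

# Nested shape: one top-token list per accepted token of this request
request_top_ids = [row.tolist() for row in indices]
assert request_top_ids == [[0, 3], [0, 3]]  # matches topn_tok_ids[1] above
```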
@@ -580,10 +580,13 @@ class CausalLM(Model):
         generations: List[Generation] = []
         stopped = True
 
+        # Speculation is not active for causal
+        accepted_ids = torch.ones_like(batch.input_ids)
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
             batch.top_n_tokens_tensor,
             torch.log_softmax(logits[:, -1], -1),
+            accepted_ids,
         )
 
         start_decode = time.time_ns()
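For models without speculation (CausalLM here, and the identical change for Seq2SeqLM below), every request accepts exactly one token per step, so `accepted_ids` is simply a tensor of ones. Inside `batch_top_tokens` this makes the derived group size 1, recovering the old flat-per-request behaviour one nesting level deeper. A hedged sketch of the invariant, with an illustrative batch size:

```python
import torch

batch_size = 4  # illustrative
accepted_ids = torch.ones(batch_size, dtype=torch.long)  # no speculation
logprobs = torch.randn(batch_size, 32).log_softmax(dim=-1)

# One logprobs row per request -> derived group size of 1
speculate_size = logprobs.shape[0] // accepted_ids.shape[0]
assert speculate_size == 1
```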
@@ -858,9 +858,8 @@ class FlashCausalLM(Model):
             speculative_logits,
         )
 
-        speculated_length = batch.speculative_ids.shape[-1] if batch.speculative_ids is not None else 0
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids, speculated_length
+            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
        )
 
         if prefill:
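In the flash path speculation can be active, so `accepted_ids` is not all ones: a request that accepted only k of its speculated tokens keeps only its first k inner lists. A toy illustration of that truncation, with hypothetical values in plain Python:

```python
speculate_size = 3        # hypothetical rows produced per request
n_accepted_ids = 2        # this request accepted 2 of them

# One top-token list per speculated position (toy data)
row_top_ids = [[0, 3], [0, 3], [1, 4]]

# Keep only the positions whose tokens were actually accepted
kept = row_top_ids[:n_accepted_ids]
assert kept == [[0, 3], [0, 3]]
```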
@@ -640,10 +640,13 @@ class Seq2SeqLM(Model):
             batch.past_key_values,
         )
 
+        # Speculation is not active for seq2seq
+        accepted_ids = torch.ones_like(batch.input_ids)
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
             batch.top_n_tokens_tensor,
             torch.log_softmax(logits[:, -1], -1),
+            accepted_ids,
         )
 
         start_decode = time.time_ns()
@@ -436,8 +436,8 @@ class HeterogeneousSampling:
 
 
 def batch_top_tokens(
-    top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor, speculative_length: int
-) -> Tuple[List[List[int]], List[List[float]]]:
+    top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor
+) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
     """Find the top n most likely tokens for a batch of generations.
 
     When multiple tokens have equal probabilities and they don't all fit, the
@@ -446,13 +446,14 @@ def batch_top_tokens(
     max_top_n = max(top_n_tokens)
     # Early exit when top_n_tokens is not used
     if max_top_n == 0:
-        return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
+        return [[[]]] * len(top_n_tokens), [[[]]] * len(top_n_tokens)
 
 
-    n = speculative_length + 1
-    top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(n)
+    batch_size = accepted_ids.shape[0]
+    speculate_size = logprobs.shape[0] // batch_size
+    top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size)
     # Ensure top_n doesn't exceed vocab size
-    top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens for _ in range(speculative_length + 1)]
+    top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens for _ in range(speculate_size)]
 
     # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
     # Sorted topk is faster than torch.sort() since we only need a small subset
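This is the heart of the fix: instead of threading a separate `speculative_length` through every caller, the function derives the per-request row count from shapes it already receives, since `logprobs` carries `batch_size * speculate_size` rows. A sketch of the arithmetic with illustrative sizes, not library code:

```python
import torch

batch_size, speculate_size = 3, 2  # illustrative sizes
logprobs = torch.randn(batch_size * speculate_size, 100).log_softmax(dim=-1)
accepted_ids = torch.tensor([1, 2, 1])  # accepted token count per request

derived = logprobs.shape[0] // accepted_ids.shape[0]
assert derived == speculate_size

# Each request's top_n is then repeated once per row it owns
top_n_tokens_tensor = torch.tensor([2, 3, 0]).repeat_interleave(derived)
assert top_n_tokens_tensor.tolist() == [2, 2, 3, 3, 0, 0]
```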
@@ -477,12 +478,14 @@ def batch_top_tokens(
 
     batch_top_token_ids = []
     batch_top_token_logprobs = []
-    accepted_ids = accepted_ids.tolist()
-    for i, n_accepted_ids in enumerate(accepted_ids):
-        _top_indices = top_indices[n * i: n * (i + 1)]
-        _top_values = top_values[n * i: n * (i + 1)]
-        _top_n_ishes = top_n_ishes[n * i: n * (i + 1)]
-        _top_n_tokens = top_n_tokens[n * i: n * (i + 1)]
+    accepted_ids_list = accepted_ids.tolist()
+    for i, n_accepted_ids in enumerate(accepted_ids_list):
+        start = speculate_size * i
+        stop = speculate_size * (i + 1)
+        _top_indices = top_indices[start: stop]
+        _top_values = top_values[start: stop]
+        _top_n_ishes = top_n_ishes[start: stop]
+        _top_n_tokens = top_n_tokens[start: stop]
         _top_indices = _top_indices[:n_accepted_ids]
         _top_values = _top_values[:n_accepted_ids]
         _top_n_ishes = _top_n_ishes[:n_accepted_ids]
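The commit title's scoping complaint is visible here: the old code rebound `accepted_ids` to a plain Python list via `.tolist()`, so later code in the same scope that expected a tensor would silently receive a list. Renaming the list to `accepted_ids_list` keeps both bindings alive. A toy reproduction of the hazard, not library code:

```python
import torch

accepted_ids = torch.tensor([1, 2, 1])

# Old pattern: rebinding the name changes its type for the rest of the scope
# accepted_ids = accepted_ids.tolist()
# accepted_ids.shape  # would raise AttributeError: lists have no .shape

# New pattern: distinct names for the tensor and its list view
accepted_ids_list = accepted_ids.tolist()
assert accepted_ids.shape[0] == len(accepted_ids_list)
```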