Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 04:14:52 +00:00
Fixing other types of models + tests + Damn you python scoping.
This commit is contained in:
parent 6e629add98
commit 0452d590d0
@@ -50,19 +50,39 @@ def test_batch_top_tokens():
     top_n_tokens = [0, 2, 3, 4, 5]
     top_n_tokens_tensor = torch.tensor(top_n_tokens)
     inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5)
+    accepted_ids = torch.ones_like(top_n_tokens_tensor)
 
     topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
-        top_n_tokens, top_n_tokens_tensor, inp_logprobs
+        top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
     )
 
-    assert topn_tok_ids[0] == []
-    assert topn_tok_ids[1] == [0, 3]
-    assert topn_tok_ids[2] == [0, 3, 1, 4]
-    assert topn_tok_ids[3] == [0, 3, 1, 4]
-    assert topn_tok_ids[4] == [0, 3, 1, 4, 2]
+    assert topn_tok_ids[0] == [[]]
+    assert topn_tok_ids[1] == [[0, 3]]
+    assert topn_tok_ids[2] == [[0, 3, 1, 4]]
+    assert topn_tok_ids[3] == [[0, 3, 1, 4]]
+    assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]]
 
-    assert topn_tok_logprobs[0] == []
-    assert topn_tok_logprobs[1] == [-1, -2]
-    assert topn_tok_logprobs[2] == [-1, -2, -3, -3]
-    assert topn_tok_logprobs[3] == [-1, -2, -3, -3]
-    assert topn_tok_logprobs[4] == [-1, -2, -3, -3, -4]
+    assert topn_tok_logprobs[0] == [[]]
+    assert topn_tok_logprobs[1] == [[-1, -2]]
+    assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]]
+    assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]]
+    assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]]
+
+    # Now let's make second member of the batch be speculated
+    inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5 * 2)
+    accepted_ids[1] = 2
+    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
+        top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
+    )
+
+    assert topn_tok_ids[0] == [[]]
+    assert topn_tok_ids[1] == [[0, 3], [0, 3]]
+    assert topn_tok_ids[2] == [[0, 3, 1, 4]]
+    assert topn_tok_ids[3] == [[0, 3, 1, 4]]
+    assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]]
+
+    assert topn_tok_logprobs[0] == [[]]
+    assert topn_tok_logprobs[1] == [[-1, -2], [-1, -2]]
+    assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]]
+    assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]]
+    assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]]
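For orientation, a minimal self-contained sketch (not the library function) of the nesting the updated assertions expect: batch_top_tokens now returns one inner list per accepted speculative token for each request, so a request with accepted_ids[i] == 2 yields two inner lists. The sketch uses plain topk and skips the tie-expansion the real implementation performs, so it only illustrates the shape.

import torch

# Shape-only sketch (assumed layout): logprobs holds batch_size * speculate_size
# rows, request i owns rows [i * speculate_size, (i + 1) * speculate_size), and
# only its first accepted_ids[i] rows contribute an inner list.
def nested_shape_sketch(top_n_tokens, logprobs, accepted_ids):
    batch_size = len(top_n_tokens)
    speculate_size = logprobs.shape[0] // batch_size
    out = []
    for i, (top_n, n_accepted) in enumerate(zip(top_n_tokens, accepted_ids.tolist())):
        if top_n == 0:
            out.append([[]])
            continue
        rows = logprobs[i * speculate_size : i * speculate_size + n_accepted]
        out.append([row.topk(min(top_n, row.numel())).indices.tolist() for row in rows])
    return out

inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5 * 2)
accepted_ids = torch.tensor([1, 2, 1, 1, 1])
ids = nested_shape_sketch([0, 2, 3, 4, 5], inp_logprobs, accepted_ids)
assert ids[0] == [[]]              # top_n == 0 -> single empty inner list
assert ids[1] == [[0, 3], [0, 3]]  # two accepted tokens -> two inner lists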
@@ -580,10 +580,13 @@ class CausalLM(Model):
         generations: List[Generation] = []
         stopped = True
 
+        # Speculation is not active for causal
+        accepted_ids = torch.ones_like(batch.input_ids)
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
             batch.top_n_tokens_tensor,
             torch.log_softmax(logits[:, -1], -1),
+            accepted_ids,
         )
 
         start_decode = time.time_ns()
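Causal and seq2seq models don't speculate, so they simply mark one accepted token per request. A hedged illustration of that pattern (the input_ids shape below is a stand-in for the example, not taken from the batch classes):

import torch

# With speculation off, accepted_ids is just ones: one accepted token per request.
input_ids = torch.tensor([[11], [42], [7]])   # hypothetical decode-step input ids
accepted_ids = torch.ones_like(input_ids)
assert accepted_ids.tolist() == [[1], [1], [1]]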
@@ -858,9 +858,8 @@ class FlashCausalLM(Model):
             speculative_logits,
         )
 
-        speculated_length = batch.speculative_ids.shape[-1] if batch.speculative_ids is not None else 0
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids, speculated_length
+            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
        )
 
         if prefill:
@@ -640,10 +640,13 @@ class Seq2SeqLM(Model):
             batch.past_key_values,
         )
 
+        # Speculation is not active for seq2seq
+        accepted_ids = torch.ones_like(batch.input_ids)
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
             batch.top_n_tokens_tensor,
             torch.log_softmax(logits[:, -1], -1),
+            accepted_ids,
         )
 
         start_decode = time.time_ns()
@@ -436,8 +436,8 @@ class HeterogeneousSampling:
 
 
 def batch_top_tokens(
-    top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor, speculative_length: int
-) -> Tuple[List[List[int]], List[List[float]]]:
+    top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor
+) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
     """Find the top n most likely tokens for a batch of generations.
 
     When multiple tokens have equal probabilities and they don't all fit, the
@@ -446,13 +446,14 @@ def batch_top_tokens(
     max_top_n = max(top_n_tokens)
     # Early exit when top_n_tokens is not used
     if max_top_n == 0:
-        return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
+        return [[[]]] * len(top_n_tokens), [[[]]] * len(top_n_tokens)
 
 
-    n = speculative_length + 1
-    top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(n)
+    batch_size = accepted_ids.shape[0]
+    speculate_size = logprobs.shape[0] // batch_size
+    top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size)
     # Ensure top_n doesn't exceed vocab size
-    top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens for _ in range(speculative_length + 1)]
+    top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens for _ in range(speculate_size)]
 
     # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
     # Sorted topk is faster than torch.sort() since we only need a small subset
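A reading of this hunk (an assumption, not stated in the patch): logprobs arrives flattened to batch_size * speculate_size rows, so the per-request top_n values are repeated to line up with those rows. A small standalone check of that repeat_interleave alignment:

import torch

# Each request's top_n value is expanded to one entry per flattened logprobs row.
batch_size, vocab = 2, 5
speculate_size = 3   # assumed speculation depth + 1, chosen for the example
logprobs = torch.randn(batch_size * speculate_size, vocab)

top_n_tokens_tensor = torch.tensor([2, 4])
expanded = top_n_tokens_tensor.repeat_interleave(logprobs.shape[0] // batch_size)
assert expanded.tolist() == [2, 2, 2, 4, 4, 4]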
@@ -477,12 +478,14 @@ def batch_top_tokens(
 
     batch_top_token_ids = []
     batch_top_token_logprobs = []
-    accepted_ids = accepted_ids.tolist()
-    for i, n_accepted_ids in enumerate(accepted_ids):
-        _top_indices = top_indices[n * i: n * (i + 1)]
-        _top_values = top_values[n * i: n * (i + 1)]
-        _top_n_ishes = top_n_ishes[n * i: n * (i + 1)]
-        _top_n_tokens = top_n_tokens[n * i: n * (i + 1)]
+    accepted_ids_list = accepted_ids.tolist()
+    for i, n_accepted_ids in enumerate(accepted_ids_list):
+        start = speculate_size * i
+        stop = speculate_size * (i + 1)
+        _top_indices = top_indices[start: stop]
+        _top_values = top_values[start: stop]
+        _top_n_ishes = top_n_ishes[start: stop]
+        _top_n_tokens = top_n_tokens[start: stop]
         _top_indices = _top_indices[:n_accepted_ids]
         _top_values = _top_values[:n_accepted_ids]
         _top_n_ishes = _top_n_ishes[:n_accepted_ids]
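The accepted_ids to accepted_ids_list rename reads like the "Damn you python scoping" part of the commit message, though the patch doesn't say so explicitly; the hazard it presumably avoids is rebinding a tensor parameter's name to a plain Python list and then hitting the wrong type later in the same function. A toy illustration of that hazard:

import torch

def rebinds_parameter(accepted_ids: torch.Tensor):
    # After this line the name refers to a list, so any later tensor-only use
    # (e.g. accepted_ids.shape) in the same function would fail.
    accepted_ids = accepted_ids.tolist()
    return accepted_ids

def keeps_both(accepted_ids: torch.Tensor):
    # A distinct name keeps the tensor binding available for later use.
    accepted_ids_list = accepted_ids.tolist()
    return accepted_ids.shape[0], accepted_ids_list

assert rebinds_parameter(torch.ones(2)) == [1.0, 1.0]
assert keeps_both(torch.ones(3)) == (3, [1.0, 1.0, 1.0])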