Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
Updating medusa test + speeding up ngram immensely by just doing a simple
search on device instead of on CPU, which had bad worst-case O(n) behavior
This commit is contained in:
parent
3a8b1923db
commit
d2b42f6883
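In short: the old path copied ids to the CPU and ran a Python longest-match scan per sequence; the new path finds, entirely on device, the first earlier occurrence of each sequence's last accepted token and copies the `speculate` tokens that follow it. A minimal standalone sketch of that lookup, using the same tensor ops as the diff below (the wrapper name `ngram_speculate` is ours for illustration, not part of the commit):

import torch

def ngram_speculate(input_ids: torch.Tensor, next_ids: torch.Tensor,
                    accepted_ids: torch.Tensor, speculate: int) -> torch.Tensor:
    # One seed per sequence: the last accepted token of each row.
    B = accepted_ids.shape[0]
    device = input_ids.device
    seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
    # max() over the boolean match tensor returns the FIRST position equal
    # to the seed; +1 points just past that match.
    indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
    # Read `speculate` consecutive tokens from there, clamped to the row end.
    all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(
        speculate, device=device
    )
    all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
    return input_ids.gather(dim=-1, index=all_indices)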
@@ -40,7 +40,7 @@
     },
     {
       "id": 2772,
-      "logprob": 0.0,
+      "logprob": -0.23083496,
       "special": false,
       "text": "De"
     },
@@ -57,42 +57,42 @@
       "text": " learning"
     },
     {
-      "id": 508,
-      "logprob": -1.5087891,
+      "id": 756,
+      "logprob": -0.48095703,
       "special": false,
-      "text": " can"
+      "text": " has"
     },
     {
-      "id": 367,
+      "id": 19479,
       "logprob": 0.0,
       "special": false,
-      "text": " be"
+      "text": " revolution"
     },
     {
-      "id": 2714,
-      "logprob": -0.6538086,
+      "id": 1891,
+      "logprob": 0.0,
       "special": false,
-      "text": " thought"
+      "text": "ized"
     },
     {
+      "id": 278,
+      "logprob": 0.0,
+      "special": false,
+      "text": " the"
+    },
+    {
+      "id": 1746,
+      "logprob": 0.0,
+      "special": false,
+      "text": " field"
+    },
+    {
       "id": 310,
       "logprob": 0.0,
       "special": false,
       "text": " of"
-    },
-    {
-      "id": 408,
-      "logprob": 0.0,
-      "special": false,
-      "text": " as"
-    },
-    {
-      "id": 263,
-      "logprob": 0.0,
-      "special": false,
-      "text": " a"
     }
   ]
 },
-  "generated_text": "What is Deep Learning?\nDeep learning can be thought of as a"
+  "generated_text": "What is Deep Learning?\nDeep learning has revolutionized the field of"
 }
@@ -16,6 +16,8 @@ from text_generation_server.utils.logits_process import (
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
 
+from loguru import logger
+
 class NextTokenChooser:
     def __init__(
         self,
@@ -145,57 +147,20 @@ class StoppingCriteria:
             pb.ignore_eos_token,
         )
 
 
-def longest_match(input_ids: List[int]) -> Optional[int]:
-    longest_match = 0
-    seed = input_ids[-1]
-    final_matches = []
-    current_matches = []
-    for i in range(1, len(input_ids)):
-        index = len(input_ids) - i - 1
-
-        _current_matches = []
-        for (_index, length) in current_matches:
-            if input_ids[index] == input_ids[len(input_ids) - length - 1]:
-                _current_matches.append((_index, length + 1))
-            elif length > longest_match:
-                longest_match = length
-                final_matches.append((_index, length))
-            else:
-                pass
-        current_matches = _current_matches
-
-        if input_ids[index] == seed:
-            current_matches.append((index, 1))
-    if not final_matches:
-        return 0
-    return final_matches[-1][0]
-
-
-def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int):
+def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int, verbose: bool):
+    # import datetime
+    # start = datetime.datetime.now()
     B = accepted_ids.shape[0]
     device = input_ids.device
     dtype = input_ids.dtype
-    speculative_ids = torch.zeros((B, speculate), device=device, dtype=dtype)
-    cpu_input_ids = input_ids.tolist()
-    cpu_next_ids = next_ids.tolist()
+    # speculative_ids = torch.zeros((B, speculate), device=device, dtype=dtype)
+    seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
+    indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
+    all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate, device=device)
+    all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
 
-    index = 0
-    for i, (_input_ids, n_accepted_ids) in enumerate(zip(cpu_input_ids, accepted_ids.tolist())):
-        stop = len(_input_ids)
-        # Remove zero padded end.
-        for j, _id in enumerate(_input_ids):
-            if _id == 0:
-                stop = j
-                break
-        _input_ids = _input_ids[:stop]
-        _input_ids.extend(cpu_next_ids[index: index + n_accepted_ids])
-        index = longest_match(_input_ids) + 1
-        slice_ = input_ids[i, index:index + speculate]
-        # logger.info(f"{slice_.shape} - {speculative_ids.shape}")
-        speculative_ids[i, :len(slice_)] = slice_
-        index += n_accepted_ids
+    # logger.info(f"All indices {all_indices} - {input_ids.shape}")
+    speculative_ids = input_ids.gather(dim=-1, index=all_indices)
     return speculative_ids
 
 
 class HeterogeneousNextTokenChooser:
@@ -266,7 +231,11 @@ class HeterogeneousNextTokenChooser:
         self.dtype = dtype
         self.device = device
 
-    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None):
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None, verbose=False):
+        import datetime
+        # from loguru import logger
+
+        start = datetime.datetime.now()
         if speculated_ids is not None:
             B = scores.shape[0] // (speculated_ids.shape[1] + 1)
             S = speculated_ids.shape[1] + 1
@@ -276,8 +245,11 @@ class HeterogeneousNextTokenChooser:
             S = 1
         scores = scores.view(B, S, -1)
 
-        all_next_ids = []
-        all_scores = []
+        # if verbose:
+        #     logger.info(f"Reshape {datetime.datetime.now() - start}")
+
+        next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
         for j in range(S):
             _scores = scores[:, j]
             if self.watermark_processor is not None:
@@ -289,11 +261,13 @@ class HeterogeneousNextTokenChooser:
                 _scores = warper(input_ids, _scores)
 
 
-            next_ids = self.choice(_scores)
+            _next_ids = self.choice(_scores)
             scores[:, j] = _scores
-            all_next_ids.append(next_ids.unsqueeze(1))
-        next_ids = torch.cat(all_next_ids, dim=1).reshape(B * S)
+            next_ids[:, j] = _next_ids
+        next_ids = next_ids.view(B * S)
         scores = scores.view(B * S, -1)
+        # if verbose:
+        #     logger.info(f"Scores {datetime.datetime.now() - start}")
 
         if speculated_ids is not None:
             accepted_ids = []
@@ -325,20 +299,23 @@ class HeterogeneousNextTokenChooser:
             speculative_scores = speculative_scores[indices + accepted_ids - 1]
         else:
             accepted_ids = torch.ones_like(next_ids)
+        # if verbose:
+        #     logger.info(f"Indices/accepted id {datetime.datetime.now() - start}")
 
         logprobs = torch.log_softmax(scores, -1)
         next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
 
         if speculate > 0:
             if speculative_scores is not None:
                 # TODO This will only speculate the top score
                 # Medusa provided some scores
                 speculative_ids = Greedy()(speculative_scores)
             else:
                 # n-gram
-                speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate)
+                speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate, verbose)
         else:
             speculative_ids = None
+        # if verbose:
+        #     logger.info(f"new speculative ids {datetime.datetime.now() - start}")
 
         return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids
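For intuition, a hypothetical smoke test of the `ngram_speculate` sketch in the commit description above: with a repeating pattern in the context, the lookup proposes the tokens that followed the previous occurrence of the seed.

import torch

input_ids = torch.tensor([[5, 9, 7, 5, 9, 7, 5]])  # batch of 1; the trigram 5 9 7 repeats
next_ids = torch.tensor([5])                       # flattened newly chosen tokens
accepted_ids = torch.tensor([1])                   # one accepted token for this row
print(ngram_speculate(input_ids, next_ids, accepted_ids, speculate=3))
# tensor([[9, 7, 5]]) -- the tokens that followed the first earlier 5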