Mirror of https://github.com/huggingface/text-generation-inference.git
Updating medusa test + Speeding ngram immensely by just making a simple
search on device instead of on CPU, which had bad O(n) worst cases
This commit is contained in:
parent 3a8b1923db
commit d2b42f6883
@@ -40,7 +40,7 @@
 },
 {
 "id": 2772,
-"logprob": 0.0,
+"logprob": -0.23083496,
 "special": false,
 "text": "De"
 },
@@ -57,42 +57,42 @@
 "text": " learning"
 },
 {
-"id": 508,
-"logprob": -1.5087891,
+"id": 756,
+"logprob": -0.48095703,
 "special": false,
-"text": " can"
+"text": " has"
 },
 {
-"id": 367,
+"id": 19479,
 "logprob": 0.0,
 "special": false,
-"text": " be"
+"text": " revolution"
 },
 {
-"id": 2714,
-"logprob": -0.6538086,
+"id": 1891,
+"logprob": 0.0,
 "special": false,
-"text": " thought"
+"text": "ized"
+},
+{
+"id": 278,
+"logprob": 0.0,
+"special": false,
+"text": " the"
+},
+{
+"id": 1746,
+"logprob": 0.0,
+"special": false,
+"text": " field"
 },
 {
 "id": 310,
 "logprob": 0.0,
 "special": false,
 "text": " of"
-},
-{
-"id": 408,
-"logprob": 0.0,
-"special": false,
-"text": " as"
-},
-{
-"id": 263,
-"logprob": 0.0,
-"special": false,
-"text": " a"
 }
 ]
 },
-"generated_text": "What is Deep Learning?\nDeep learning can be thought of as a"
+"generated_text": "What is Deep Learning?\nDeep learning has revolutionized the field of"
 }
@@ -16,6 +16,8 @@ from text_generation_server.utils.logits_process import (
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
 
+from loguru import logger
+
 
 class NextTokenChooser:
     def __init__(
         self,
@@ -145,57 +147,20 @@ class StoppingCriteria:
             pb.ignore_eos_token,
         )
 
 
-def longest_match(input_ids: List[int]) -> Optional[int]:
-    longest_match = 0
-    seed = input_ids[-1]
-    final_matches = []
-    current_matches = []
-    for i in range(1, len(input_ids)):
-        index = len(input_ids) - i - 1
-
-        _current_matches = []
-        for (_index, length) in current_matches:
-            if input_ids[index] == input_ids[len(input_ids) - length - 1]:
-                _current_matches.append((_index, length + 1))
-            elif length > longest_match:
-                longest_match = length
-                final_matches.append((_index, length))
-            else:
-                pass
-        current_matches = _current_matches
-
-        if input_ids[index] == seed:
-            current_matches.append((index, 1))
-    if not final_matches:
-        return 0
-    return final_matches[-1][0]
-
-
-def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int):
+def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int, verbose: bool):
+    # import datetime
+    # start = datetime.datetime.now()
     B = accepted_ids.shape[0]
     device = input_ids.device
     dtype = input_ids.dtype
-    speculative_ids = torch.zeros((B, speculate), device=device, dtype=dtype)
-    cpu_input_ids = input_ids.tolist()
-    cpu_next_ids = next_ids.tolist()
-
-    index = 0
-    for i, (_input_ids, n_accepted_ids) in enumerate(zip(cpu_input_ids, accepted_ids.tolist())):
-        stop = len(_input_ids)
-        # Remove zero padded end.
-        for j, _id in enumerate(_input_ids):
-            if _id == 0:
-                stop = j
-                break
-        _input_ids = _input_ids[:stop]
-        _input_ids.extend(cpu_next_ids[index : index + n_accepted_ids])
-        index = longest_match(_input_ids) + 1
-        slice_ = input_ids[i, index : index + speculate]
-        # logger.info(f"{slice_.shape} - {speculative_ids.shape}")
-        speculative_ids[i, : len(slice_)] = slice_
-        index += n_accepted_ids
+    # speculative_ids = torch.zeros((B, speculate), device=device, dtype=dtype)
+    seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
+    indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
+    all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate, device=device)
+    all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
+
+    # logger.info(f"All indices {all_indices} - {input_ids.shape}")
+    speculative_ids = input_ids.gather(dim=-1, index=all_indices)
     return speculative_ids
 
 
 class HeterogeneousNextTokenChooser:
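This rewrite is the heart of the commit: instead of the Python longest_match scan over token lists on CPU, the lookup becomes a fixed handful of tensor ops on device: take the last accepted token of each sequence as a seed, find where that seed first occurs in input_ids, and gather the speculate tokens that follow that position. A minimal standalone sketch of that logic, assuming a single decoding step per sequence (so next_ids has one entry per sequence) and made-up toy token values:

import torch

def ngram_speculate(input_ids: torch.Tensor, next_ids: torch.Tensor,
                    accepted_ids: torch.Tensor, speculate: int) -> torch.Tensor:
    B = accepted_ids.shape[0]
    device = input_ids.device
    # Seed: the last accepted token of each sequence (next_ids is flat).
    seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
    # First position where the seed occurs; speculation starts just after it.
    indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
    # A window of `speculate` positions per sequence, clamped into bounds.
    all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(
        speculate, device=device
    )
    all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
    # Copy the tokens that followed the seed's first occurrence.
    return input_ids.gather(dim=-1, index=all_indices)

# Toy example: the sequence repeats "5 8 9" and the model just emitted 5,
# so the n-gram guess for the next two tokens is 8, 9.
input_ids = torch.tensor([[5, 8, 9, 5, 8, 9, 7]])
print(ngram_speculate(input_ids, next_ids=torch.tensor([5]),
                      accepted_ids=torch.tensor([1]), speculate=2))
# tensor([[8, 9]])

Per the diff, this trades match quality for speed: the deleted longest_match searched for the longest matching context, while the device version keys on the first occurrence of a single seed token, at a cost that no longer grows with sequence length in Python.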
@@ -266,7 +231,11 @@ class HeterogeneousNextTokenChooser:
         self.dtype = dtype
         self.device = device
 
-    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None):
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None, verbose=False):
+        import datetime
+        # from loguru import logger
+
+        start = datetime.datetime.now()
         if speculated_ids is not None:
             B = scores.shape[0] // (speculated_ids.shape[1] + 1)
             S = speculated_ids.shape[1] + 1
@@ -276,8 +245,11 @@
             S = 1
         scores = scores.view(B, S, -1)
 
+        # if verbose:
+        # logger.info(f"Reshape {datetime.datetime.now() - start}")
+
         all_next_ids = []
-        all_scores = []
+        next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
         for j in range(S):
             _scores = scores[:, j]
             if self.watermark_processor is not None:
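For orientation on the shapes these two hunks rely on: during a speculative step the model scores B sequences times S positions in one flat batch, where S = speculated_ids.shape[1] + 1 (the speculated tokens plus the real next token), and __call__ recovers B from the flat dimension before viewing the scores per position. A small sketch of that bookkeeping, under assumed toy shapes:

import torch

B, vocab = 2, 11
speculated_ids = torch.zeros(B, 3, dtype=torch.long)  # 3 speculated tokens/seq
S = speculated_ids.shape[1] + 1                       # plus the real next token
scores = torch.randn(B * S, vocab)                    # flat model output

# Recover the batch size from the flat dimension, as in the diff.
assert scores.shape[0] // (speculated_ids.shape[1] + 1) == B
scores = scores.view(B, S, -1)                        # (sequence, position, vocab)
assert scores.shape == (B, S, vocab)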
@@ -289,11 +261,13 @@
                 _scores = warper(input_ids, _scores)
 
-            next_ids = self.choice(_scores)
+            _next_ids = self.choice(_scores)
             scores[:, j] = _scores
-            all_next_ids.append(next_ids.unsqueeze(1))
-        next_ids = torch.cat(all_next_ids, dim=1).reshape(B * S)
+            next_ids[:, j] = _next_ids
+        next_ids = next_ids.view(B * S)
         scores = scores.view(B * S, -1)
+        # if verbose:
+        # logger.info(f"Scores {datetime.datetime.now() - start}")
 
         if speculated_ids is not None:
             accepted_ids = []
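The loop change in this hunk swaps the append-then-torch.cat pattern for writes into the preallocated (B, S) next_ids tensor from the previous hunk, flattened once with view at the end. A quick equivalence sketch, with argmax standing in for self.choice (i.e. assuming greedy choice):

import torch

B, S, vocab = 2, 3, 10
scores = torch.randn(B, S, vocab)

# New pattern: fill a preallocated (B, S) tensor, flatten once.
next_ids = torch.zeros((B, S), dtype=torch.long)
for j in range(S):
    next_ids[:, j] = scores[:, j].argmax(dim=-1)
new = next_ids.view(B * S)

# Old pattern: collect per-step columns, then concatenate.
old = torch.cat(
    [scores[:, j].argmax(dim=-1).unsqueeze(1) for j in range(S)], dim=1
).reshape(B * S)

assert torch.equal(new, old)

Same ids either way, but the new form avoids building S intermediate tensors plus a concatenation on every decoding call.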
@@ -325,20 +299,23 @@
             speculative_scores = speculative_scores[indices + accepted_ids - 1]
         else:
             accepted_ids = torch.ones_like(next_ids)
+        # if verbose:
+        # logger.info(f"Indices/accepted id {datetime.datetime.now() - start}")
 
         logprobs = torch.log_softmax(scores, -1)
         next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
 
         if speculate > 0:
             if speculative_scores is not None:
-                # TODO This will only speculate the top score
                 # Medusa provided some scores
                 speculative_ids = Greedy()(speculative_scores)
             else:
                 # n-gram
-                speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate)
+                speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate, verbose)
         else:
             speculative_ids = None
+        # if verbose:
+        # logger.info(f"new speculative ids {datetime.datetime.now() - start}")
 
         return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids
|
Loading…
Reference in New Issue
Block a user