Updating medusa test + Speeding up ngram immensely by just making a simple search on device instead of on CPU, which had bad O(n) worst cases
Nicolas Patry 2023-12-06 16:31:35 +00:00
parent 3a8b1923db
commit d2b42f6883
2 changed files with 53 additions and 76 deletions
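
For context: n-gram speculation drafts future tokens by finding an earlier occurrence of the newest token in the sequence and proposing the tokens that followed it as the draft continuation. A minimal, hypothetical pure-Python sketch of the idea (illustration only, not code from this commit):

    from typing import List

    def ngram_speculate(tokens: List[int], speculate: int) -> List[int]:
        # Find the first earlier occurrence of the newest token and propose
        # the tokens that followed it as the draft continuation.
        seed = tokens[-1]
        for i, tok in enumerate(tokens[:-1]):
            if tok == seed:
                return tokens[i + 1 : i + 1 + speculate]
        return []

    print(ngram_speculate([5, 9, 5, 7, 2, 9], 2))  # -> [5, 7]

The interesting part of the commit is doing this lookup with batched tensor ops on the device instead of per-sequence Python loops.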


@@ -40,7 +40,7 @@
       },
       {
         "id": 2772,
-        "logprob": 0.0,
+        "logprob": -0.23083496,
         "special": false,
         "text": "De"
       },
@@ -57,42 +57,42 @@
         "text": " learning"
       },
       {
-        "id": 508,
-        "logprob": -1.5087891,
+        "id": 756,
+        "logprob": -0.48095703,
         "special": false,
-        "text": " can"
+        "text": " has"
       },
       {
-        "id": 367,
+        "id": 19479,
         "logprob": 0.0,
         "special": false,
-        "text": " be"
+        "text": " revolution"
       },
       {
-        "id": 2714,
-        "logprob": -0.6538086,
+        "id": 1891,
+        "logprob": 0.0,
         "special": false,
-        "text": " thought"
+        "text": "ized"
       },
       {
+        "id": 278,
+        "logprob": 0.0,
+        "special": false,
+        "text": " the"
+      },
+      {
+        "id": 1746,
+        "logprob": 0.0,
+        "special": false,
+        "text": " field"
+      },
+      {
         "id": 310,
         "logprob": 0.0,
         "special": false,
         "text": " of"
-      },
-      {
-        "id": 408,
-        "logprob": 0.0,
-        "special": false,
-        "text": " as"
-      },
-      {
-        "id": 263,
-        "logprob": 0.0,
-        "special": false,
-        "text": " a"
       }
     ]
   },
-  "generated_text": "What is Deep Learning?\nDeep learning can be thought of as a"
+  "generated_text": "What is Deep Learning?\nDeep learning has revolutionized the field of"
 }


@@ -16,6 +16,8 @@ from text_generation_server.utils.logits_process import (
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
+from loguru import logger
+
 
 class NextTokenChooser:
     def __init__(
         self,
@@ -145,57 +147,20 @@ class StoppingCriteria:
             pb.ignore_eos_token,
         )
 
-
-def longest_match(input_ids: List[int]) -> Optional[int]:
-    longest_match = 0
-    seed = input_ids[-1]
-    final_matches = []
-    current_matches = []
-    for i in range(1, len(input_ids)):
-        index = len(input_ids) - i - 1
-        _current_matches = []
-        for (_index, length) in current_matches:
-            if input_ids[index] == input_ids[len(input_ids) - length - 1]:
-                _current_matches.append((_index, length + 1))
-            elif length > longest_match:
-                longest_match = length
-                final_matches.append((_index, length))
-            else:
-                pass
-        current_matches = _current_matches
-        if input_ids[index] == seed:
-            current_matches.append((index, 1))
-    if not final_matches:
-        return 0
-    return final_matches[-1][0]
-
-def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int):
+def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int, verbose: bool):
+    # import datetime
+    # start = datetime.datetime.now()
     B = accepted_ids.shape[0]
     device = input_ids.device
     dtype = input_ids.dtype
-    speculative_ids = torch.zeros((B, speculate), device=device, dtype=dtype)
-    cpu_input_ids = input_ids.tolist()
-    cpu_next_ids = next_ids.tolist()
+    # speculative_ids = torch.zeros((B, speculate), device=device, dtype=dtype)
+    seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
+    indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
+    all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate, device=device)
+    all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
 
-    index = 0
-    for i, (_input_ids, n_accepted_ids) in enumerate(zip(cpu_input_ids, accepted_ids.tolist())):
-        stop = len(_input_ids)
-        # Remove zero padded end.
-        for j, _id in enumerate(_input_ids):
-            if _id == 0:
-                stop = j
-                break
-        _input_ids = _input_ids[:stop]
-        _input_ids.extend(cpu_next_ids[index : index + n_accepted_ids])
-
-        index = longest_match(_input_ids) + 1
-        slice_ = input_ids[i, index : index + speculate]
-        # logger.info(f"{slice_.shape} - {speculative_ids.shape}")
-        speculative_ids[i, : len(slice_)] = slice_
-        index += n_accepted_ids
+    # logger.info(f"All indices {all_indices} - {input_ids.shape}")
+    speculative_ids = input_ids.gather(dim=-1, index=all_indices)
     return speculative_ids
 
 
 class HeterogeneousNextTokenChooser:
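
A self-contained sketch of what the new device-side path computes, with toy tensors (shapes and values are made up; one accepted token per sequence for simplicity):

    import torch

    B, speculate = 2, 3
    input_ids = torch.tensor([
        [5, 9, 5, 7, 2, 9, 4, 0],
        [1, 2, 3, 1, 2, 3, 1, 2],
    ])
    next_ids = torch.tensor([9, 3])        # flattened over accepted tokens
    accepted_ids = torch.tensor([1, 1])    # one accepted token per sequence

    # Seed: the last accepted token of each sequence.
    seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
    # First position where the seed already occurred; speculate from the token after it.
    indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
    all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate)
    all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
    speculative_ids = input_ids.gather(dim=-1, index=all_indices)
    print(speculative_ids)  # tensor([[5, 7, 2], [1, 2, 3]])

Note that the device version keys only on the first earlier occurrence of the single last token, whereas the removed longest_match approximated a longest-suffix match: the commit trades some match quality for a fixed number of vectorized ops.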
@@ -266,7 +231,11 @@ class HeterogeneousNextTokenChooser:
         self.dtype = dtype
         self.device = device
 
-    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None):
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None, verbose=False):
+        import datetime
+        # from loguru import logger
+        start = datetime.datetime.now()
         if speculated_ids is not None:
             B = scores.shape[0] // (speculated_ids.shape[1] + 1)
             S = speculated_ids.shape[1] + 1
@@ -276,8 +245,11 @@
             S = 1
         scores = scores.view(B, S, -1)
+        # if verbose:
+        #     logger.info(f"Reshape {datetime.datetime.now() - start}")
 
-        all_next_ids = []
-        all_scores = []
+        next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
         for j in range(S):
             _scores = scores[:, j]
             if self.watermark_processor is not None:
@@ -289,11 +261,13 @@
                 _scores = warper(input_ids, _scores)
 
-            next_ids = self.choice(_scores)
-            all_next_ids.append(next_ids.unsqueeze(1))
-        next_ids = torch.cat(all_next_ids, dim=1).reshape(B * S)
+            _next_ids = self.choice(_scores)
+            scores[:, j] = _scores
+            next_ids[:, j] = _next_ids
+        next_ids = next_ids.view(B * S)
+        scores = scores.view(B * S, -1)
+        # if verbose:
+        #     logger.info(f"Scores {datetime.datetime.now() - start}")
 
         if speculated_ids is not None:
             accepted_ids = []
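
These two hunks replace the list-append-then-torch.cat pattern with a preallocated (B, S) tensor filled one speculation step at a time. A toy illustration, with argmax standing in for self.choice:

    import torch

    B, S, V = 2, 3, 11
    scores = torch.randn(B * S, V).view(B, S, -1)

    next_ids = torch.zeros((B, S), dtype=torch.long)
    for j in range(S):
        # argmax stands in for self.choice(scores[:, j]) here.
        next_ids[:, j] = scores[:, j].argmax(dim=-1)
    next_ids = next_ids.view(B * S)  # same flat layout as the old torch.cat path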
@@ -325,20 +299,23 @@
             speculative_scores = speculative_scores[indices + accepted_ids - 1]
         else:
             accepted_ids = torch.ones_like(next_ids)
+        # if verbose:
+        #     logger.info(f"Indices/accepted id {datetime.datetime.now() - start}")
 
         logprobs = torch.log_softmax(scores, -1)
         next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
 
         if speculate > 0:
             if speculative_scores is not None:
                 # TODO This will only speculate the top score
                 # Medusa provided some scores
                 speculative_ids = Greedy()(speculative_scores)
             else:
                 # n-gram
-                speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate)
+                speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate, verbose)
         else:
             speculative_ids = None
+        # if verbose:
+        #     logger.info(f"new speculative ids {datetime.datetime.now() - start}")
 
         return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids
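
At the dispatch point, Medusa speculation is a greedy reduction over the speculative scores (as the TODO notes, only the top token per position is speculated), while the n-gram path calls the rewritten create_n_gram_speculation. A shape-level sketch with toy sizes, using argmax as a stand-in for Greedy():

    import torch

    B, speculate, V = 2, 3, 11
    speculative_scores = torch.randn(B, speculate, V)  # e.g. Medusa head outputs

    # Greedy()(speculative_scores) reduces over the vocab dimension, so only
    # the top-scoring token at each speculative position is proposed.
    speculative_ids = speculative_scores.argmax(dim=-1)  # shape (B, speculate)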