Reduce create_n_gram performance degradation.

Padded zeros were the worst-case scenario.
This commit is contained in:
Nicolas Patry 2023-12-06 06:31:57 +00:00
parent a3cc5a94c6
commit 7b34445457
4 changed files with 30 additions and 6 deletions

View File

@ -527,6 +527,11 @@ fn send_responses(
let tokens_ = generation.tokens.expect("Non empty tokens in generation"); let tokens_ = generation.tokens.expect("Non empty tokens in generation");
let n = tokens_.ids.len(); let n = tokens_.ids.len();
metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64); metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
assert_eq!(n, tokens_.logprobs.len());
assert_eq!(n, tokens_.texts.len());
assert_eq!(n, tokens_.is_special.len());
let mut iterator = tokens_ let mut iterator = tokens_
.ids .ids
.into_iter() .into_iter()

View File

@ -229,7 +229,7 @@ impl State {
} }
if self.requires_padding { if self.requires_padding {
decode_tokens += entry.request.stopping_parameters.max_new_tokens; decode_tokens += entry.request.stopping_parameters.max_new_tokens + self.speculate;
} else { } else {
let max_new_tokens = match self.window_size { let max_new_tokens = match self.window_size {
None => entry.request.stopping_parameters.max_new_tokens, None => entry.request.stopping_parameters.max_new_tokens,

View File

@ -829,6 +829,9 @@ class FlashCausalLM(Model):
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits
) )
from loguru import logger
logger.info(f"Accepted id {accepted_ids}")
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
) )

View File

@ -174,18 +174,28 @@ def longest_match(input_ids: List[int]) -> Optional[int]:
def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int): def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int):
from loguru import logger
B = accepted_ids.shape[0] B = accepted_ids.shape[0]
device = input_ids.device device = input_ids.device
dtype = input_ids.dtype dtype = input_ids.dtype
speculative_ids = torch.zeros((B, speculate), device=device, dtype=dtype) speculative_ids = torch.zeros((B, speculate), device=device, dtype=dtype)
input_ids = input_ids.tolist() logger.info(f"{input_ids.shape} ")
cpu_input_ids = input_ids.tolist()
cpu_next_ids = next_ids.tolist()
index = 0 index = 0
for i, (_input_ids, n_accepted_ids) in enumerate(zip(input_ids, accepted_ids.tolist())): for i, (_input_ids, n_accepted_ids) in enumerate(zip(cpu_input_ids, accepted_ids.tolist())):
_input_ids.extend(next_ids[index: index + n_accepted_ids].tolist()) stop = len(_input_ids)
for j, _id in enumerate(_input_ids):
if _id == 0:
stop = j
break
_input_ids = _input_ids[:stop]
_input_ids.extend(cpu_next_ids[index: index+n_accepted_ids])
index = longest_match(_input_ids) + 1 index = longest_match(_input_ids) + 1
ids = _input_ids[index:index+speculate] slice_ = input_ids[i, index:index+speculate]
speculative_ids[i, :len(ids)] = torch.tensor(ids, device=device, dtype=dtype) # logger.info(f"{slice_.shape} - {speculative_ids.shape}")
speculative_ids[i, :len(slice_)] = slice_
index += n_accepted_ids index += n_accepted_ids
return speculative_ids return speculative_ids
@ -327,7 +337,13 @@ class HeterogeneousNextTokenChooser:
speculative_ids = Greedy()(speculative_scores) speculative_ids = Greedy()(speculative_scores)
else: else:
# n-gram # n-gram
import datetime
start = datetime.datetime.now()
# if input_ids.shape[0] > 4:
# import ipdb;ipdb.set_trace()
speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate) speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate)
from loguru import logger
logger.info(f"n gram took {datetime.datetime.now() - start}")
else: else:
speculative_ids = None speculative_ids = None