diff --git a/router/src/infer.rs b/router/src/infer.rs
index 2e199ce2..d4057f1f 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -527,6 +527,11 @@ fn send_responses(
     let tokens_ = generation.tokens.expect("Non empty tokens in generation");
     let n = tokens_.ids.len();
     metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
+
+    assert_eq!(n, tokens_.logprobs.len());
+    assert_eq!(n, tokens_.texts.len());
+    assert_eq!(n, tokens_.is_special.len());
+
     let mut iterator = tokens_
         .ids
         .into_iter()
diff --git a/router/src/queue.rs b/router/src/queue.rs
index 0436b8f2..87ee285a 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -229,7 +229,7 @@ impl State {
             }
 
             if self.requires_padding {
-                decode_tokens += entry.request.stopping_parameters.max_new_tokens;
+                decode_tokens += entry.request.stopping_parameters.max_new_tokens + self.speculate;
             } else {
                 let max_new_tokens = match self.window_size {
                     None => entry.request.stopping_parameters.max_new_tokens,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 3f9c21b2..d0708e11 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -829,6 +829,9 @@ class FlashCausalLM(Model):
             batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits
         )
 
+        from loguru import logger
+        logger.info(f"Accepted id {accepted_ids}")
+
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
         )
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index d5a703c8..0f6c7ce7 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -174,18 +174,28 @@ def longest_match(input_ids: List[int]) -> Optional[int]:
 
 
 def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int):
+    from loguru import logger
     B = accepted_ids.shape[0]
     device = input_ids.device
     dtype = input_ids.dtype
 
     speculative_ids = torch.zeros((B, speculate), device=device, dtype=dtype)
-    input_ids = input_ids.tolist()
+    logger.info(f"{input_ids.shape} ")
+    cpu_input_ids = input_ids.tolist()
+    cpu_next_ids = next_ids.tolist()
     index = 0
-    for i, (_input_ids, n_accepted_ids) in enumerate(zip(input_ids, accepted_ids.tolist())):
-        _input_ids.extend(next_ids[index: index + n_accepted_ids].tolist())
+    for i, (_input_ids, n_accepted_ids) in enumerate(zip(cpu_input_ids, accepted_ids.tolist())):
+        stop = len(_input_ids)
+        for j, _id in enumerate(_input_ids):
+            if _id == 0:
+                stop = j
+                break
+        _input_ids = _input_ids[:stop]
+        _input_ids.extend(cpu_next_ids[index: index+n_accepted_ids])
         index = longest_match(_input_ids) + 1
-        ids = _input_ids[index:index+speculate]
-        speculative_ids[i, :len(ids)] = torch.tensor(ids, device=device, dtype=dtype)
+        slice_ = input_ids[i, index:index+speculate]
+        # logger.info(f"{slice_.shape} - {speculative_ids.shape}")
+        speculative_ids[i, :len(slice_)] = slice_
         index += n_accepted_ids
     return speculative_ids
 
@@ -327,7 +337,13 @@ class HeterogeneousNextTokenChooser:
                 speculative_ids = Greedy()(speculative_scores)
             else:
                 # n-gram
+                import datetime
+                start = datetime.datetime.now()
+                # if input_ids.shape[0] > 4:
+                #     import ipdb;ipdb.set_trace()
                 speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate)
+                from loguru import logger
+                logger.info(f"n gram took {datetime.datetime.now() - start}")
         else:
             speculative_ids = None
 
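Note (not part of the patch): below is a self-contained sketch of the n-gram speculation flow that the tokens.py hunk modifies, for reviewers who want to run the logic in isolation. It mirrors the intent of the change (drop right padding before matching, extend with the accepted tokens, continue from the longest match), but the simplified longest_match stand-in, the padding-id-of-0 assumption, and all names here are illustrative only and may not match the actual TGI helpers.

# Illustrative sketch only -- not part of the patch above.
# Assumes padding token id 0 and a simplified 1-gram longest_match.
from typing import List, Optional

import torch


def longest_match(input_ids: List[int]) -> Optional[int]:
    # Simplified stand-in: index of the most recent earlier occurrence of the
    # last token; returns None if the last token never appeared before.
    last = input_ids[-1]
    for j in range(len(input_ids) - 2, -1, -1):
        if input_ids[j] == last:
            return j
    return None


def create_n_gram_speculation(
    input_ids: torch.Tensor,    # (B, seq), right-padded with 0s
    next_ids: torch.Tensor,     # flat tensor of newly accepted tokens
    accepted_ids: torch.Tensor, # (B,), how many of next_ids belong to each row
    speculate: int,
) -> torch.Tensor:
    B = accepted_ids.shape[0]
    speculative_ids = torch.zeros(
        (B, speculate), device=input_ids.device, dtype=input_ids.dtype
    )
    cpu_input_ids = input_ids.tolist()
    cpu_next_ids = next_ids.tolist()

    index = 0
    for i, (row, n_accepted) in enumerate(zip(cpu_input_ids, accepted_ids.tolist())):
        # Drop right padding (assumed to be token id 0), then append the
        # tokens accepted for this row.
        stop = row.index(0) if 0 in row else len(row)
        row = row[:stop] + cpu_next_ids[index : index + n_accepted]

        # Continue the sequence from just after the longest matching n-gram.
        match = longest_match(row)
        if match is not None:
            candidate = row[match + 1 : match + 1 + speculate]
            speculative_ids[i, : len(candidate)] = torch.tensor(
                candidate, device=input_ids.device, dtype=input_ids.dtype
            )
        index += n_accepted
    return speculative_ids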