Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
Reduce create_n_gram_speculation degradation.
Padded zeros were the worst-case scenario.
This commit is contained in:
parent a3cc5a94c6
commit 7b34445457
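Judging from the commit message and the timing log added further down, the problem is that rows of a padded batch reach the n-gram lookup with their runs of padding zeros still attached: a long run of identical pad tokens is the worst case for the suffix-match scan and produces useless speculative ids. The sketch below is a simplified, self-contained illustration of that failure mode and of trimming the zeros first; naive_longest_match is a hypothetical stand-in for the repository's longest_match, not its actual implementation.

# Simplified illustration only; `naive_longest_match` is a made-up stand-in
# for `longest_match`: it looks for an earlier occurrence of the last bigram
# and returns the index of its last matched token, or None.
from typing import List, Optional


def naive_longest_match(input_ids: List[int]) -> Optional[int]:
    suffix = input_ids[-2:]
    for start in range(len(input_ids) - 3, -1, -1):
        if input_ids[start : start + 2] == suffix:
            return start + 1
    return None


# A right-padded row from a batch: real context followed by padding zeros.
padded = [15, 27, 31, 27, 31, 0, 0, 0, 0]

# Matching on the raw row latches onto the padding run: the suffix is [0, 0],
# so the proposed "continuation" is just more zeros.
print(naive_longest_match(padded))         # index inside the zero padding

# Trimming the padding first (what this commit adds) lets the match land on
# the real repeated n-gram [27, 31] instead.
stop = padded.index(0) if 0 in padded else len(padded)
print(naive_longest_match(padded[:stop]))  # index of the real match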
@@ -527,6 +527,11 @@ fn send_responses(
     let tokens_ = generation.tokens.expect("Non empty tokens in generation");
+    let n = tokens_.ids.len();
+    metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);

+    assert_eq!(n, tokens_.logprobs.len());
+    assert_eq!(n, tokens_.texts.len());
+    assert_eq!(n, tokens_.is_special.len());

     let mut iterator = tokens_
         .ids
         .into_iter()
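For intuition about the new histogram: with speculative decoding a single decode step can return several tokens in one `generation.tokens` message, and the recorded value `(n - 1)` counts the tokens obtained beyond the first one. A rough sketch with made-up numbers (only the metric name comes from the diff; everything else is illustrative):

# Illustrative only: the per-step quantity fed into `tgi_request_skipped_tokens`.
tokens_per_step = [1, 3, 1, 2, 4]                   # tokens returned by each decode step (made up)
skipped_per_step = [n - 1 for n in tokens_per_step]
print(skipped_per_step)                             # [0, 2, 0, 1, 3]
print(sum(skipped_per_step), "tokens produced without their own decode step")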
@@ -229,7 +229,7 @@ impl State {
             }

             if self.requires_padding {
-                decode_tokens += entry.request.stopping_parameters.max_new_tokens;
+                decode_tokens += entry.request.stopping_parameters.max_new_tokens + self.speculate;
             } else {
                 let max_new_tokens = match self.window_size {
                     None => entry.request.stopping_parameters.max_new_tokens,
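On the queue change above: in the padded branch the per-entry decode-token budget now also reserves room for the speculative tokens a step may emit. Rough arithmetic with made-up numbers (the `speculate` value and the entries are hypothetical):

# Illustrative budget arithmetic for the padded branch; numbers are made up.
speculate = 2
entries_max_new_tokens = [64, 128, 32]

old_decode_tokens = sum(entries_max_new_tokens)                         # before the change
new_decode_tokens = sum(m + speculate for m in entries_max_new_tokens)  # after the change
print(old_decode_tokens, new_decode_tokens)                             # 224 230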
@@ -829,6 +829,9 @@ class FlashCausalLM(Model):
             batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits
         )

+        from loguru import logger
+        logger.info(f"Accepted id {accepted_ids}")
+
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
         )
@@ -174,18 +174,28 @@ def longest_match(input_ids: List[int]) -> Optional[int]:


 def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int):
     from loguru import logger
     B = accepted_ids.shape[0]
     device = input_ids.device
     dtype = input_ids.dtype
     speculative_ids = torch.zeros((B, speculate), device=device, dtype=dtype)
-    input_ids = input_ids.tolist()
-    logger.info(f"{input_ids.shape} ")
+    cpu_input_ids = input_ids.tolist()
+    cpu_next_ids = next_ids.tolist()

     index = 0
-    for i, (_input_ids, n_accepted_ids) in enumerate(zip(input_ids, accepted_ids.tolist())):
-        _input_ids.extend(next_ids[index: index + n_accepted_ids].tolist())
+    for i, (_input_ids, n_accepted_ids) in enumerate(zip(cpu_input_ids, accepted_ids.tolist())):
+        stop = len(_input_ids)
+        for j, _id in enumerate(_input_ids):
+            if _id == 0:
+                stop = j
+                break
+        _input_ids = _input_ids[:stop]
+        _input_ids.extend(cpu_next_ids[index: index + n_accepted_ids])
         index = longest_match(_input_ids) + 1
-        slice_ = input_ids[i, index:index+speculate]
-        # logger.info(f"{slice_.shape} - {speculative_ids.shape}")
-        speculative_ids[i, :len(slice_)] = slice_
+        ids = _input_ids[index:index+speculate]
+        speculative_ids[i, :len(ids)] = torch.tensor(ids, device=device, dtype=dtype)
         index += n_accepted_ids
     return speculative_ids
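To make the rewritten loop concrete, here is a self-contained walk-through for one right-padded row, using plain lists instead of tensors and the same hypothetical naive_longest_match from the first sketch in place of longest_match; the numbers are made up.

# Stand-in matcher (same simplification as above), so the example runs alone.
def naive_longest_match(input_ids):
    suffix = input_ids[-2:]
    for start in range(len(input_ids) - 3, -1, -1):
        if input_ids[start : start + 2] == suffix:
            return start + 1
    return None


speculate = 3
row = [15, 27, 31, 42, 27, 31, 0, 0, 0, 0]  # prompt ids followed by padding zeros
accepted = [42]                             # ids accepted for this row in this step

# 1) Trim the zero padding, mirroring the new `stop` loop.
stop = next((j for j, t in enumerate(row) if t == 0), len(row))
trimmed = row[:stop]

# 2) Append the freshly accepted ids, mirroring `_input_ids.extend(...)`.
trimmed.extend(accepted)

# 3) Find the longest earlier match and take the `speculate` ids that follow it.
match = naive_longest_match(trimmed)
index = (match + 1) if match is not None else 0
ids = trimmed[index : index + speculate]
print(ids)  # candidate speculative ids for this row, here [27, 31, 42]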
@@ -327,7 +337,13 @@ class HeterogeneousNextTokenChooser:
                 speculative_ids = Greedy()(speculative_scores)
             else:
                 # n-gram
+                import datetime
+                start = datetime.datetime.now()
+                # if input_ids.shape[0] > 4:
+                #     import ipdb;ipdb.set_trace()
                 speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate)
+                from loguru import logger
+                logger.info(f"n gram took {datetime.datetime.now() - start}")
         else:
             speculative_ids = None