Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
Include a few fixes

- Bad degradation on pad tokens for ngram speculation.
- Invalid batch discard: the `stopped` value could be incorrect.
parent 7b34445457
commit f6958ea6d4
@@ -527,11 +527,6 @@ fn send_responses(
    let tokens_ = generation.tokens.expect("Non empty tokens in generation");
    let n = tokens_.ids.len();
    metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
-
-    assert_eq!(n, tokens_.logprobs.len());
-    assert_eq!(n, tokens_.texts.len());
-    assert_eq!(n, tokens_.is_special.len());
-
    let mut iterator = tokens_
        .ids
        .into_iter()
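Not part of the diff: a small Python sketch of what the parallel per-step token fields and the skipped-tokens count represent, with made-up values (the code above is Rust in the router). With speculation, one decode step can emit n tokens, so n - 1 is the number of forward passes saved.

# Illustrative only: parallel fields, one entry per token emitted in a single decode step.
ids = [464, 3290, 318]
logprobs = [-0.1, -0.7, -0.3]
texts = [" The", " dog", " is"]
is_special = [False, False, False]

n = len(ids)
skipped = n - 1  # tokens accepted via speculation this step
print(f"skipped tokens this step: {skipped}")

for token_id, logprob, text, special in zip(ids, logprobs, texts, is_special):
    print(token_id, repr(text), logprob, special)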
@@ -229,7 +229,7 @@ impl State {
            }

            if self.requires_padding {
-                decode_tokens += entry.request.stopping_parameters.max_new_tokens + self.speculate;
+                decode_tokens += entry.request.stopping_parameters.max_new_tokens;
            } else {
                let max_new_tokens = match self.window_size {
                    None => entry.request.stopping_parameters.max_new_tokens,
@@ -237,7 +237,7 @@ impl State {
                        window_size.saturating_sub(entry.request.input_length),
                        entry.request.stopping_parameters.max_new_tokens,
                    ),
-                } + self.speculate;
+                };

                // pad to block size
                decode_tokens +=
@@ -245,7 +245,7 @@ impl State {
            }

            if prefill_tokens > prefill_token_budget
-                || (prefill_tokens + decode_tokens) > token_budget
+                || (prefill_tokens + decode_tokens + self.speculate) > token_budget
            {
                // Entry is over budget
                // Add it back to the front
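Not part of the diff: a rough Python sketch of the budget check as it stands after this change, under assumed variable names (the real logic is the Rust above in the router's queue). The speculate allowance is now counted once against the total token budget instead of being folded into every entry's decode_tokens.

def fits_budget(prefill_tokens: int, decode_tokens: int, speculate: int,
                prefill_token_budget: int, token_budget: int) -> bool:
    # Speculative tokens count once against the total budget rather than
    # being added to each entry's decode_tokens.
    if prefill_tokens > prefill_token_budget:
        return False
    return prefill_tokens + decode_tokens + speculate <= token_budget

print(fits_budget(1, 1, 2, prefill_token_budget=1, token_budget=1))  # False: over budget
print(fits_budget(1, 1, 2, prefill_token_budget=6, token_budget=6))  # True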
@@ -543,7 +543,15 @@ mod tests {
        queue.append(entry1);
        queue.append(entry2);

        // Budget of 1 is not enough
        assert!(queue.next_batch(None, 1, 1).await.is_none());

        let (entries, batch, _) = queue.next_batch(None, 6, 6).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);
    }

    #[tokio::test]
@@ -162,9 +162,6 @@ class FlashCausalLMBatch(Batch):

            tokenized_input = tokenized_input[-r.truncate :]

-            speculate_ids = []
-
-
            input_length = len(tokenized_input)
            input_lengths.append(input_length)

@@ -806,10 +803,9 @@ class FlashCausalLM(Model):
            del batch
            raise e

-        try:
-            out, speculative_logits = out.logits, out.speculative_logits
-        except Exception:
-            out = out
+        if isinstance(out, tuple):
+            out, speculative_logits = out
+        else:
            speculative_logits = None

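Not part of the diff: a minimal Python sketch of the normalization pattern introduced here, using a hypothetical split_logits helper. The model forward may now return either a bare logits tensor or a (logits, speculative_logits) tuple (see the Medusa hunk below), and the isinstance check replaces the earlier try/except over attribute access.

import torch

def split_logits(out):
    """Illustrative helper: accept a bare logits tensor or a
    (logits, speculative_logits) tuple and always return both."""
    if isinstance(out, tuple):
        logits, speculative_logits = out
    else:
        logits, speculative_logits = out, None
    return logits, speculative_logits

# A plain model returns only logits; a Medusa-style model returns a tuple.
print(split_logits(torch.zeros(1, 4))[1])                               # None
print(split_logits((torch.zeros(1, 4), torch.zeros(1, 3, 4)))[1].shape)  # torch.Size([1, 3, 4])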
@@ -829,9 +825,6 @@ class FlashCausalLM(Model):
            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits
        )

-        from loguru import logger
-        logger.info(f"Accepted id {accepted_ids}")
-
        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
        )
@@ -843,8 +836,7 @@ class FlashCausalLM(Model):
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

-            length = len(batch)
-            next_position_ids = batch.position_ids.new_empty(length)
+            next_position_ids = batch.position_ids.new_empty(len(batch))
            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
            # We do not need cu_seqlen_prefill anymore
            batch.cu_seqlen_prefill = None
@@ -965,6 +957,9 @@ class FlashCausalLM(Model):
            # Append next token to all tokens
            next_token_texts = []
            left = 0
+            before = stopping_criteria.current_tokens
+
+            current_stopped = False
            for j in range(index, index + n_accepted_ids):
                # Generated token
                next_token_id = next_token_ids[j]
@@ -982,11 +977,12 @@ class FlashCausalLM(Model):
                )

                if stop:
-                    stopped = True
                    left = index + n_accepted_ids - j - 1
+                    current_stopped = True
                    break
                else:
-                    stopped = False
+                    current_stopped = False
+            stopped = stopped and current_stopped

            _next_token_ids = next_token_ids[index: index+n_accepted_ids - left]
            _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids - left]
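Not part of the diff: a condensed Python sketch of the `stopped` accumulation this hunk fixes, with hypothetical helper names and without the `left` bookkeeping. Before the change, each request overwrote the batch-level flag, so whichever request was processed last decided whether the whole batch was reported as stopped and discarded; now the flag is only true if every request in the batch stopped.

def batch_stopped(accepted_ids_per_request, stop_at):
    """Illustrative: each inner list holds the token ids accepted for one request
    in this step; `stop_at` is a hypothetical per-request stop predicate."""
    stopped = True  # assume stopped until some request keeps going
    for req_idx, accepted in enumerate(accepted_ids_per_request):
        current_stopped = False
        for token_id in accepted:
            if stop_at(req_idx, token_id):
                current_stopped = True
                break
        # The batch is only finished if *every* request stopped;
        # the old code let the last request overwrite the flag.
        stopped = stopped and current_stopped
    return stopped

# Request 0 stops on token 2, request 1 never stops -> the batch must not be discarded.
print(batch_stopped([[1, 2, 3], [4, 5]], lambda r, t: t == 2))  # False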
@@ -997,15 +993,12 @@ class FlashCausalLM(Model):
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
-                    # Remove potentially accepted ids that do not respect
-                    # the stopping_criteria
-                    _ids = all_input_ids
                    output_text, _, _ = self.decode_token(
-                        _ids,
-                        prefix_offset=len(_ids)
+                        all_input_ids,
+                        prefix_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens
                        - 1,
-                        read_offset=len(_ids)
+                        read_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens,
                        skip_special_tokens=True,
                    )
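Not part of the diff: a tiny Python illustration of the offset arithmetic used in that final decode, with made-up ids. The prefix offset keeps one token of context before the generated region so the boundary detokenizes correctly, while the read offset marks the first generated id.

# Illustrative only: which slice of all_input_ids the final decode looks at.
all_input_ids = [10, 11, 12, 13, 14, 15]  # prompt + generated ids for one request
current_tokens = 3                        # tokens counted by the stopping criteria

prefix_offset = len(all_input_ids) - current_tokens - 1  # one id of context before the output
read_offset = len(all_input_ids) - current_tokens        # first generated id

print(all_input_ids[prefix_offset:])  # [12, 13, 14, 15]: context id + generated ids
print(all_input_ids[read_offset:])    # [13, 14, 15]: the generated ids surfaced as text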
@@ -34,7 +34,7 @@ class MedusaModel(torch.nn.Module):
    def forward(self, x):
        logits = self.lm_head(x)
        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
-        return Output(logits=logits, speculative_logits=speculative_logits)
+        return logits, speculative_logits


class MedusaHead(torch.nn.Module):
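Not part of the diff: a self-contained sketch of a Medusa-style module following the new tuple return convention, with illustrative sizes and plain linear heads standing in for the MedusaHead modules in the real file.

import torch
from torch import nn

class ToyMedusaModel(nn.Module):
    """Illustrative stand-in: a shared lm_head plus n_heads speculative heads."""
    def __init__(self, hidden: int = 16, vocab: int = 32, n_heads: int = 3):
        super().__init__()
        self.lm_head = nn.Linear(hidden, vocab)
        self.heads = nn.ModuleList([nn.Linear(hidden, vocab) for _ in range(n_heads)])

    def forward(self, x):
        logits = self.lm_head(x)
        # One extra set of logits per speculative head, stacked on dim=1.
        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
        return logits, speculative_logits  # plain tuple, as after this commit

model = ToyMedusaModel()
logits, spec = model(torch.randn(2, 16))
print(logits.shape, spec.shape)  # torch.Size([2, 32]) torch.Size([2, 3, 32])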
@@ -186,6 +186,8 @@ def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, a
    index = 0
    for i, (_input_ids, n_accepted_ids) in enumerate(zip(cpu_input_ids, accepted_ids.tolist())):
        stop = len(_input_ids)
+        # TODO 0 is not necessarily the pad token.
+        # Remove zero padded end.
        for j, _id in enumerate(_input_ids):
            if _id == 0:
                stop = j
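Not part of the diff: a Python sketch of the pad-token issue behind the n-gram fix. Trailing zero padding in a row of input ids would otherwise take part in the n-gram match and degrade acceptance, so the padded tail is trimmed before matching (0 is assumed to be the pad id here, which the TODO in the hunk notes is not guaranteed).

import torch

def trim_zero_padding(input_ids: torch.Tensor) -> list:
    """Illustrative: drop the zero-padded tail of one row before n-gram matching."""
    ids = input_ids.tolist()
    stop = len(ids)
    # As in the TODO above, 0 is not necessarily the real pad token.
    for j, _id in enumerate(ids):
        if _id == 0:
            stop = j
            break
    return ids[:stop]

row = torch.tensor([5, 7, 5, 9, 0, 0, 0])  # real ids followed by padding
print(trim_zero_padding(row))  # [5, 7, 5, 9]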