Include a few fixes

Bad degradation on pad tokens for ngram
Invalid batch discard `stopped` value could be incorrect.
This commit is contained in:
Nicolas Patry 2023-12-06 09:45:42 +00:00
parent 7b34445457
commit f6958ea6d4
5 changed files with 27 additions and 29 deletions

View File

@ -527,11 +527,6 @@ fn send_responses(
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
let n = tokens_.ids.len();
metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
assert_eq!(n, tokens_.logprobs.len());
assert_eq!(n, tokens_.texts.len());
assert_eq!(n, tokens_.is_special.len());
let mut iterator = tokens_
.ids
.into_iter()

View File

@ -229,7 +229,7 @@ impl State {
}
if self.requires_padding {
decode_tokens += entry.request.stopping_parameters.max_new_tokens + self.speculate;
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
} else {
let max_new_tokens = match self.window_size {
None => entry.request.stopping_parameters.max_new_tokens,
@ -237,7 +237,7 @@ impl State {
window_size.saturating_sub(entry.request.input_length),
entry.request.stopping_parameters.max_new_tokens,
),
} + self.speculate;
};
// pad to block size
decode_tokens +=
@ -245,7 +245,7 @@ impl State {
}
if prefill_tokens > prefill_token_budget
|| (prefill_tokens + decode_tokens) > token_budget
|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
{
// Entry is over budget
// Add it back to the front
@ -543,7 +543,15 @@ mod tests {
queue.append(entry1);
queue.append(entry2);
// Budget of 1 is not enough
assert!(queue.next_batch(None, 1, 1).await.is_none());
let (entries, batch, _) = queue.next_batch(None, 6, 6).await.unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1));
assert_eq!(batch.id, 0);
assert_eq!(batch.size, 2);
}
#[tokio::test]

View File

@ -162,9 +162,6 @@ class FlashCausalLMBatch(Batch):
tokenized_input = tokenized_input[-r.truncate :]
speculate_ids = []
input_length = len(tokenized_input)
input_lengths.append(input_length)
@ -806,10 +803,9 @@ class FlashCausalLM(Model):
del batch
raise e
try:
out, speculative_logits = out.logits, out.speculative_logits
except Exception:
out = out
if isinstance(out, tuple):
out, speculative_logits = out
else:
speculative_logits = None
@ -829,9 +825,6 @@ class FlashCausalLM(Model):
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits
)
from loguru import logger
logger.info(f"Accepted id {accepted_ids}")
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
)
@ -843,8 +836,7 @@ class FlashCausalLM(Model):
# When batch == 1, we will just use the batch.input_ids values directly
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
length = len(batch)
next_position_ids = batch.position_ids.new_empty(length)
next_position_ids = batch.position_ids.new_empty(len(batch))
batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
# We do not need cu_seqlen_prefill anymore
batch.cu_seqlen_prefill = None
@ -965,6 +957,9 @@ class FlashCausalLM(Model):
# Append next token to all tokens
next_token_texts = []
left = 0
before = stopping_criteria.current_tokens
current_stopped = False
for j in range(index, index + n_accepted_ids):
# Generated token
next_token_id = next_token_ids[j]
@ -982,11 +977,12 @@ class FlashCausalLM(Model):
)
if stop:
stopped = True
left = index + n_accepted_ids - j - 1
current_stopped = True
break
else:
stopped = False
current_stopped = False
stopped = stopped and current_stopped
_next_token_ids = next_token_ids[index: index+n_accepted_ids - left]
_next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids - left]
@ -997,15 +993,12 @@ class FlashCausalLM(Model):
if i % self.world_size == self.rank:
if stop:
# Decode generated tokens
# Remove potentially accepted ids that do not respect
# the stopping_criteria
_ids = all_input_ids
output_text, _, _ = self.decode_token(
_ids,
prefix_offset=len(_ids)
all_input_ids,
prefix_offset=len(all_input_ids)
- stopping_criteria.current_tokens
- 1,
read_offset=len(_ids)
read_offset=len(all_input_ids)
- stopping_criteria.current_tokens,
skip_special_tokens=True,
)

View File

@ -34,7 +34,7 @@ class MedusaModel(torch.nn.Module):
def forward(self, x):
logits = self.lm_head(x)
speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
return Output(logits=logits, speculative_logits=speculative_logits)
return logits, speculative_logits
class MedusaHead(torch.nn.Module):

View File

@ -186,6 +186,8 @@ def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, a
index = 0
for i, (_input_ids, n_accepted_ids) in enumerate(zip(cpu_input_ids, accepted_ids.tolist())):
stop = len(_input_ids)
# TODO 0 is not necessarily the pad token.
# Remove zero padded end.
for j, _id in enumerate(_input_ids):
if _id == 0:
stop = j