Include a few fixes

Bad degradation on pad tokens for n-gram speculation (0 is not necessarily the pad token).
Invalid batch discards: the batch-level `stopped` value could be incorrect.
Nicolas Patry 2023-12-06 09:45:42 +00:00
parent 7b34445457
commit f6958ea6d4
5 changed files with 27 additions and 29 deletions

View File

@@ -527,11 +527,6 @@ fn send_responses(
     let tokens_ = generation.tokens.expect("Non empty tokens in generation");
     let n = tokens_.ids.len();
     metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
-    assert_eq!(n, tokens_.logprobs.len());
-    assert_eq!(n, tokens_.texts.len());
-    assert_eq!(n, tokens_.is_special.len());
     let mut iterator = tokens_
         .ids
         .into_iter()

View File

@@ -229,7 +229,7 @@ impl State {
             }
             if self.requires_padding {
-                decode_tokens += entry.request.stopping_parameters.max_new_tokens + self.speculate;
+                decode_tokens += entry.request.stopping_parameters.max_new_tokens;
             } else {
                 let max_new_tokens = match self.window_size {
                     None => entry.request.stopping_parameters.max_new_tokens,
@@ -237,7 +237,7 @@ impl State {
                         window_size.saturating_sub(entry.request.input_length),
                         entry.request.stopping_parameters.max_new_tokens,
                     ),
-                } + self.speculate;
+                };
                 // pad to block size
                 decode_tokens +=
@@ -245,7 +245,7 @@ impl State {
             }
             if prefill_tokens > prefill_token_budget
-                || (prefill_tokens + decode_tokens) > token_budget
+                || (prefill_tokens + decode_tokens + self.speculate) > token_budget
             {
                 // Entry is over budget
                 // Add it back to the front
@@ -543,7 +543,15 @@ mod tests {
         queue.append(entry1);
         queue.append(entry2);
+        // Budget of 1 is not enough
         assert!(queue.next_batch(None, 1, 1).await.is_none());
+        let (entries, batch, _) = queue.next_batch(None, 6, 6).await.unwrap();
+        assert_eq!(entries.len(), 2);
+        assert!(entries.contains_key(&0));
+        assert!(entries.contains_key(&1));
+        assert_eq!(batch.id, 0);
+        assert_eq!(batch.size, 2);
     }
     #[tokio::test]
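The three `queue.rs` hunks above move the speculation allowance out of each entry's per-token decode estimate and into the final budget comparison, so speculative tokens are only charged once per admission check. A minimal Python sketch of that arithmetic, with an invented helper name and standalone parameters (the real check lives in the Rust `impl State` shown above):

```python
def fits_in_budget(
    prefill_tokens: int,
    decode_tokens: int,
    speculate: int,
    prefill_token_budget: int,
    token_budget: int,
) -> bool:
    """Sketch of the admission check after this commit (names are illustrative)."""
    if prefill_tokens > prefill_token_budget:
        return False
    # Speculative tokens are now added once here, instead of being folded
    # into every entry's max_new_tokens estimate.
    return prefill_tokens + decode_tokens + speculate <= token_budget


# A tiny budget rejects the entry, a larger one accepts it
# (loosely mirrors the new unit test above; the numbers are made up).
assert not fits_in_budget(1, 1, 2, prefill_token_budget=1, token_budget=1)
assert fits_in_budget(1, 1, 2, prefill_token_budget=6, token_budget=6)
```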

View File

@@ -162,9 +162,6 @@ class FlashCausalLMBatch(Batch):
                 tokenized_input = tokenized_input[-r.truncate :]
-            speculate_ids = []
             input_length = len(tokenized_input)
             input_lengths.append(input_length)
@@ -806,10 +803,9 @@ class FlashCausalLM(Model):
                 del batch
                 raise e
-        try:
-            out, speculative_logits = out.logits, out.speculative_logits
-        except Exception:
-            out = out
+        if isinstance(out, tuple):
+            out, speculative_logits = out
+        else:
             speculative_logits = None
@@ -829,9 +825,6 @@ class FlashCausalLM(Model):
                 batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits
             )
-            from loguru import logger
-            logger.info(f"Accepted id {accepted_ids}")
             batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
                 batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
             )
@@ -843,8 +836,7 @@ class FlashCausalLM(Model):
                 # When batch == 1, we will just use the batch.input_ids values directly
                 prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
-            length = len(batch)
-            next_position_ids = batch.position_ids.new_empty(length)
+            next_position_ids = batch.position_ids.new_empty(len(batch))
             batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
             # We do not need cu_seqlen_prefill anymore
             batch.cu_seqlen_prefill = None
@@ -965,6 +957,9 @@ class FlashCausalLM(Model):
             # Append next token to all tokens
             next_token_texts = []
             left = 0
+            before = stopping_criteria.current_tokens
+            current_stopped = False
             for j in range(index, index + n_accepted_ids):
                 # Generated token
                 next_token_id = next_token_ids[j]
@@ -982,11 +977,12 @@ class FlashCausalLM(Model):
                 )
                 if stop:
-                    stopped = True
                     left = index + n_accepted_ids - j - 1
+                    current_stopped = True
                     break
                 else:
-                    stopped = False
+                    current_stopped = False
+            stopped = stopped and current_stopped
             _next_token_ids = next_token_ids[index: index+n_accepted_ids - left]
             _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids - left]
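For context on the `current_stopped` change just above: `stopped` is a batch-level flag ("every request in this batch has finished"), while the inner `for j` loop walks the several speculatively accepted ids of a single request, so assigning `stopped` directly inside that loop let one request overwrite the verdict for the others. A standalone sketch of the corrected and-aggregation (the request and token data are made up):

```python
# Hypothetical per-request data: each inner list holds the accepted
# speculative ids of one request; True marks an id that triggers a stop.
batch_accepted = [
    [False, True],          # request 0 stops on its second accepted id
    [False, False, False],  # request 1 keeps generating
]

# Batch-level flag: stays True only if *every* request stopped.
stopped = True
for accepted in batch_accepted:
    current_stopped = False
    for hit_stop in accepted:
        if hit_stop:
            current_stopped = True
            break
        else:
            current_stopped = False
    stopped = stopped and current_stopped

assert stopped is False  # request 1 is still running, so the batch is not done
```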
@@ -997,15 +993,12 @@ class FlashCausalLM(Model):
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
-                    # Remove potentially accepted ids that do not respect
-                    # the stopping_criteria
-                    _ids = all_input_ids
                     output_text, _, _ = self.decode_token(
-                        _ids,
-                        prefix_offset=len(_ids)
+                        all_input_ids,
+                        prefix_offset=len(all_input_ids)
                         - stopping_criteria.current_tokens
                         - 1,
-                        read_offset=len(_ids)
+                        read_offset=len(all_input_ids)
                         - stopping_criteria.current_tokens,
                         skip_special_tokens=True,
                     )

View File

@@ -34,7 +34,7 @@ class MedusaModel(torch.nn.Module):
     def forward(self, x):
         logits = self.lm_head(x)
         speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
-        return Output(logits=logits, speculative_logits=speculative_logits)
+        return logits, speculative_logits
 class MedusaHead(torch.nn.Module):
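The `MedusaModel` change ties in with the `isinstance(out, tuple)` check in `flash_causal_lm.py` above: speculative heads now return a plain `(logits, speculative_logits)` tuple, while regular heads keep returning bare logits, so the caller no longer needs the `try`/`except` probe. A self-contained sketch of that calling convention; `PlainHead` and `SpeculativeHead` are hypothetical stand-ins, not classes from the repository:

```python
import torch


class PlainHead(torch.nn.Module):
    """Regular LM head: returns bare logits."""

    def __init__(self, hidden: int, vocab: int):
        super().__init__()
        self.lm_head = torch.nn.Linear(hidden, vocab, bias=False)

    def forward(self, x):
        return self.lm_head(x)


class SpeculativeHead(torch.nn.Module):
    """Medusa-style head: returns (logits, speculative_logits)."""

    def __init__(self, hidden: int, vocab: int, n_heads: int = 2):
        super().__init__()
        self.lm_head = torch.nn.Linear(hidden, vocab, bias=False)
        self.heads = torch.nn.ModuleList(
            torch.nn.Linear(hidden, vocab, bias=False) for _ in range(n_heads)
        )

    def forward(self, x):
        logits = self.lm_head(x)
        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
        return logits, speculative_logits


x = torch.randn(4, 16)
for model in (PlainHead(16, 32), SpeculativeHead(16, 32)):
    out = model(x)
    # Same unpacking as in the diff above: a tuple means a speculative head.
    if isinstance(out, tuple):
        out, speculative_logits = out
    else:
        speculative_logits = None
    print(out.shape, None if speculative_logits is None else speculative_logits.shape)
```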

View File

@@ -186,6 +186,8 @@ def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, a
     index = 0
     for i, (_input_ids, n_accepted_ids) in enumerate(zip(cpu_input_ids, accepted_ids.tolist())):
         stop = len(_input_ids)
+        # TODO 0 is not necessarily the pad token.
+        # Remove zero padded end.
         for j, _id in enumerate(_input_ids):
             if _id == 0:
                 stop = j
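The new comment flags that the padded tail is detected by looking for id 0, which is not guaranteed to be the pad token for every model. A minimal sketch of that trimming step; `trim_padded_tail` and its `pad_token_id` parameter are illustrative, not part of the repository:

```python
from typing import List


def trim_padded_tail(ids: List[int], pad_token_id: int = 0) -> List[int]:
    """Drop the padded tail before building n-gram speculation.

    NOTE: assumes `pad_token_id` really is the padding id; the current
    code hard-codes 0, which the TODO above flags as potentially wrong.
    """
    stop = len(ids)
    for j, _id in enumerate(ids):
        if _id == pad_token_id:
            stop = j
            break
    return ids[:stop]


# Example: the trailing zeros are removed, interior ids are kept.
assert trim_padded_tail([5, 7, 9, 0, 0]) == [5, 7, 9]
```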