diff --git a/router/src/infer.rs b/router/src/infer.rs
index d4057f1f..2e199ce2 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -527,11 +527,6 @@ fn send_responses(
     let tokens_ = generation.tokens.expect("Non empty tokens in generation");
     let n = tokens_.ids.len();
     metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
-
-    assert_eq!(n, tokens_.logprobs.len());
-    assert_eq!(n, tokens_.texts.len());
-    assert_eq!(n, tokens_.is_special.len());
-
     let mut iterator = tokens_
         .ids
         .into_iter()
diff --git a/router/src/queue.rs b/router/src/queue.rs
index 87ee285a..106cacc4 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -229,7 +229,7 @@ impl State {
             }
 
             if self.requires_padding {
-                decode_tokens += entry.request.stopping_parameters.max_new_tokens + self.speculate;
+                decode_tokens += entry.request.stopping_parameters.max_new_tokens;
             } else {
                 let max_new_tokens = match self.window_size {
                     None => entry.request.stopping_parameters.max_new_tokens,
@@ -237,7 +237,7 @@
                         window_size.saturating_sub(entry.request.input_length),
                         entry.request.stopping_parameters.max_new_tokens,
                     ),
-                } + self.speculate;
+                };
 
                 // pad to block size
                 decode_tokens +=
@@ -245,7 +245,7 @@
             }
 
             if prefill_tokens > prefill_token_budget
-                || (prefill_tokens + decode_tokens) > token_budget
+                || (prefill_tokens + decode_tokens + self.speculate) > token_budget
             {
                 // Entry is over budget
                 // Add it back to the front
@@ -543,7 +543,15 @@ mod tests {
         queue.append(entry1);
         queue.append(entry2);
 
+        // Budget of 1 is not enough
         assert!(queue.next_batch(None, 1, 1).await.is_none());
+
+        let (entries, batch, _) = queue.next_batch(None, 6, 6).await.unwrap();
+        assert_eq!(entries.len(), 2);
+        assert!(entries.contains_key(&0));
+        assert!(entries.contains_key(&1));
+        assert_eq!(batch.id, 0);
+        assert_eq!(batch.size, 2);
     }
 
     #[tokio::test]
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index d0708e11..63e024ac 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -162,9 +162,6 @@ class FlashCausalLMBatch(Batch):
             tokenized_input = tokenized_input[-r.truncate :]
 
-            speculate_ids = []
-
-
             input_length = len(tokenized_input)
             input_lengths.append(input_length)
@@ -806,10 +803,9 @@ class FlashCausalLM(Model):
             del batch
             raise e
 
-        try:
-            out, speculative_logits = out.logits, out.speculative_logits
-        except Exception:
-            out = out
+        if isinstance(out, tuple):
+            out, speculative_logits = out
+        else:
             speculative_logits = None
@@ -829,9 +825,6 @@ class FlashCausalLM(Model):
             batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits
         )
 
-        from loguru import logger
-        logger.info(f"Accepted id {accepted_ids}")
-
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
         )
@@ -843,8 +836,7 @@ class FlashCausalLM(Model):
                 # When batch == 1, we will just use the batch.input_ids values directly
                 prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
 
-            length = len(batch)
-            next_position_ids = batch.position_ids.new_empty(length)
+            next_position_ids = batch.position_ids.new_empty(len(batch))
             batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
             # We do not need cu_seqlen_prefill anymore
             batch.cu_seqlen_prefill = None
@@ -965,6 +957,9 @@ class FlashCausalLM(Model):
             # Append next token to all tokens
             next_token_texts = []
             left = 0
+            before = stopping_criteria.current_tokens
+
+            current_stopped = False
             for j in range(index, index + n_accepted_ids):
                 # Generated token
                 next_token_id = next_token_ids[j]
@@ -982,11 +977,12 @@ class FlashCausalLM(Model):
                 )
 
                 if stop:
-                    stopped = True
                     left = index + n_accepted_ids - j - 1
+                    current_stopped = True
                     break
                 else:
-                    stopped = False
+                    current_stopped = False
+            stopped = stopped and current_stopped
 
             _next_token_ids = next_token_ids[index: index+n_accepted_ids - left]
             _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids - left]
@@ -997,15 +993,12 @@ class FlashCausalLM(Model):
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
-                    # Remove potentially accepted ids that do not respect
-                    # the stopping_criteria
-                    _ids = all_input_ids
                     output_text, _, _ = self.decode_token(
-                        _ids,
-                        prefix_offset=len(_ids)
+                        all_input_ids,
+                        prefix_offset=len(all_input_ids)
                         - stopping_criteria.current_tokens
                         - 1,
-                        read_offset=len(_ids)
+                        read_offset=len(all_input_ids)
                         - stopping_criteria.current_tokens,
                         skip_special_tokens=True,
                     )
diff --git a/server/text_generation_server/utils/medusa.py b/server/text_generation_server/utils/medusa.py
index ce908333..afa9bfc4 100644
--- a/server/text_generation_server/utils/medusa.py
+++ b/server/text_generation_server/utils/medusa.py
@@ -34,7 +34,7 @@ class MedusaModel(torch.nn.Module):
     def forward(self, x):
         logits = self.lm_head(x)
         speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
-        return Output(logits=logits, speculative_logits=speculative_logits)
+        return logits, speculative_logits
 
 
 class MedusaHead(torch.nn.Module):
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 0f6c7ce7..a6bec102 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -186,6 +186,8 @@ def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, a
     index = 0
     for i, (_input_ids, n_accepted_ids) in enumerate(zip(cpu_input_ids, accepted_ids.tolist())):
         stop = len(_input_ids)
+        # TODO 0 is not necessarily the pad token.
+        # Remove zero padded end.
         for j, _id in enumerate(_input_ids):
             if _id == 0:
                 stop = j
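
For context, the sketch below (not part of the patch) illustrates the return convention the patch standardizes on: a speculative head such as MedusaModel returns a (logits, speculative_logits) tuple, and the caller in flash_causal_lm.py unpacks it with an isinstance check, falling back to speculative_logits = None for plain models. The ToyMedusaModel and unpack names, layer shapes, and sizes are illustrative assumptions, not code from the repository.

import torch

class ToyMedusaModel(torch.nn.Module):
    # Hypothetical stand-in for MedusaModel: one lm_head plus extra speculative heads.
    def __init__(self, hidden_size: int, vocab_size: int, n_heads: int):
        super().__init__()
        self.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
        self.heads = torch.nn.ModuleList(
            [torch.nn.Linear(hidden_size, vocab_size, bias=False) for _ in range(n_heads)]
        )

    def forward(self, x):
        logits = self.lm_head(x)
        # One set of logits per speculative head, stacked along dim=1.
        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
        return logits, speculative_logits

def unpack(out):
    # Mirrors the isinstance-based unpacking introduced in flash_causal_lm.py:
    # tuple -> (logits, speculative_logits), bare tensor -> no speculation.
    if isinstance(out, tuple):
        logits, speculative_logits = out
    else:
        logits, speculative_logits = out, None
    return logits, speculative_logits

if __name__ == "__main__":
    model = ToyMedusaModel(hidden_size=16, vocab_size=32, n_heads=3)
    hidden = torch.randn(4, 16)  # [num_tokens, hidden_size]
    logits, speculative_logits = unpack(model(hidden))
    print(logits.shape, speculative_logits.shape)  # [4, 32] and [4, 3, 32]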