diff --git a/k6/load_test.js b/k6/load_test.js
index 3f3791af..516b5666 100644
--- a/k6/load_test.js
+++ b/k6/load_test.js
@@ -38,7 +38,8 @@ function sample_example(inputs, max_new_tokens, name) {
        parameters: {
            max_new_tokens: max_new_tokens,
            do_sample: true,
-            top_p: 0.9
+            top_p: 0.9,
+            seed: 0
        }
    });
    let params = {
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 1f0d23f2..9620cf70 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -66,7 +66,7 @@ impl Client {
    ///
    /// Returns Generation for each request in batch
    /// and the next cached batch
-    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
+    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size, max_sequence_length = batch.requests.iter().map(|request| request.input_length).max()))]
    pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
        let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
        let response = self.stub.prefill(request).await?.into_inner();
@@ -77,7 +77,7 @@ impl Client {
    ///
    /// Returns Generation for each request in batches
    /// and the next cached batch
-    #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
+    #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>(), max_sequence_length = batches.iter().map(|batch|{batch.requests.iter().map(|request| request.input_length).max()}).max()))]
    pub async fn decode(
        &mut self,
        batches: Vec<Batch>,
diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs
index 2e662ca3..3f1cd4ab 100644
--- a/router/client/src/sharded_client.rs
+++ b/router/client/src/sharded_client.rs
@@ -53,7 +53,7 @@ impl ShardedClient {
    ///
    /// Returns Generation for each request in batch
    /// and the next cached batch
-    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
+    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size, max_sequence_length = batch.requests.iter().map(|request| request.input_length).max()))]
    pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
        let futures: Vec<_> = self
            .clients
@@ -69,7 +69,7 @@ impl ShardedClient {
    ///
    /// Returns Generation for each request in batches
    /// and the next cached batch
-    #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
+    #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>(), max_sequence_length = batches.iter().map(|batch|{batch.requests.iter().map(|request| request.input_length).max()}).max()))]
    pub async fn decode(
        &mut self,
        batches: Vec<Batch>,
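As an aside (not part of the patch): the new `max_sequence_length` field on the prefill/decode `#[instrument]` macros records the longest input in a batch on the client-side tracing spans. Below is a rough Python sketch of the same idea, using the opentelemetry API that the Python server already relies on; the `Batch`/`Request` classes are made up for illustration.

```python
# Illustrative sketch only, not part of this diff: record batch id, size and the
# longest input length as span attributes, mirroring the #[instrument] fields above.
from dataclasses import dataclass
from typing import List

from opentelemetry import trace

tracer = trace.get_tracer(__name__)


@dataclass
class Request:
    input_length: int


@dataclass
class Batch:
    id: int
    requests: List[Request]


def prefill(batch: Batch) -> None:
    with tracer.start_as_current_span("prefill") as span:
        span.set_attribute("id", batch.id)
        span.set_attribute("size", len(batch.requests))
        span.set_attribute(
            "max_sequence_length", max(r.input_length for r in batch.requests)
        )
        # ... the actual prefill would run here


if __name__ == "__main__":
    prefill(Batch(id=0, requests=[Request(11), Request(42)]))
```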
diff --git a/router/src/infer.rs b/router/src/infer.rs
index 791a5b5e..4e368492 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -96,7 +96,7 @@ impl Infer {
            request: valid_request,
            response_tx,
            span: Span::current(),
-            batch_span: None,
+            temp_span: None,
            queue_time: Instant::now(),
            batch_time: None,
            _permit: permit,
@@ -228,6 +228,18 @@ async fn batching_task(
                    .next_batch(min_size, max_batch_size - batch_size as usize)
                    .await
                {
+                    let new_batch_size = new_batch.size;
+                    entries.iter_mut().for_each(|(_, entry)| {
+                        // Create a new span to add the info that this entry is waiting
+                        // because a new batch is being computed
+                        let entry_waiting_span =
+                            info_span!(parent: &entry.span, "waiting", batch_size = new_batch_size);
+                        // Add relationship
+                        entry_waiting_span.follows_from(&span);
+                        // Update entry
+                        entry.temp_span = Some(entry_waiting_span);
+                    });
+
                    // Generate one token for this new batch to have the attention past in cache
                    let new_cached_batch =
                        wrap_future(client.prefill(new_batch), &mut new_entries)
@@ -253,7 +265,7 @@
                // Add relationship
                entry_batch_span.follows_from(&next_batch_span);
                // Update entry
-                entry.batch_span = Some(entry_batch_span);
+                entry.temp_span = Some(entry_batch_span);
            });

            cached_batch = wrap_future(client.decode(batches), &mut entries)
@@ -289,7 +301,7 @@ async fn wrap_future(
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
    entries.drain().for_each(|(_, entry)| {
        // Create and enter a span to link this function back to the entry
-        let _send_error_span = info_span!(parent: entry.batch_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
+        let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
        let err = InferError::GenerationError(error.to_string());
        tracing::error!("{err}");
@@ -312,7 +324,7 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
    /// Span that will live as long as entry
    pub span: Span,
-    /// Span for every inference batch
-    /// This span will only live as long as one prefill/decode
-    pub batch_span: Option<Span>,
+    /// Temporary span used as a guard when logging inference, wait times...
+    pub temp_span: Option<Span>,
    /// Instant when this entry was queued
    pub queue_time: Instant,
    /// Instant when this entry was added to a batch
@@ -125,7 +124,12 @@ impl State {
    }

    /// Append an entry to the queue
-    fn append(&mut self, entry: Entry) {
+    fn append(&mut self, mut entry: Entry) {
+        // Create a span that will live as long as the entry is in the queue waiting to be batched
+        let queue_span = info_span!(parent: &entry.span, "queued");
+        entry.temp_span = Some(queue_span);
+
+        // Push entry in the queue
        self.entries.push((self.next_id, entry));
        self.next_id += 1;
    }
@@ -163,7 +167,7 @@ impl State {
            // Add relationship
            entry_batch_span.follows_from(&next_batch_span);
            // Update entry
-            entry.batch_span = Some(entry_batch_span);
+            entry.temp_span = Some(entry_batch_span);

            batch_requests.push(Request {
                id,
@@ -235,7 +239,7 @@ mod tests {
            },
            response_tx,
            span: info_span!("entry"),
-            batch_span: None,
+            temp_span: None,
            queue_time: Instant::now(),
            batch_time: None,
            _permit: permit,
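A side note on the span changes above (not part of the patch): each entry now carries a short-lived `temp_span` — "queued" while it sits in the queue, "waiting" while a new batch is prefilled, and the per-batch span during decode — tied to the batch span via `follows_from`. The router uses the Rust `tracing` crate; the sketch below is only a loose Python analogue using opentelemetry span links, with made-up names.

```python
# Illustrative sketch only: a loose opentelemetry analogue of the "waiting" temp
# spans and their follows_from relationship to the batch span.
from opentelemetry import trace

tracer = trace.get_tracer(__name__)


def prefill_new_batch(entries, new_batch_size):
    # Parent span for the whole batch, similar to `next_batch_span` in the router
    with tracer.start_as_current_span(
        "batch", attributes={"batch_size": new_batch_size}
    ) as batch_span:
        waiting_spans = []
        for _entry in entries:
            # One short-lived span per waiting entry; a span link is the closest
            # opentelemetry equivalent of tracing's follows_from. In the router the
            # parent is the entry's own request span, which is simplified away here.
            span = tracer.start_span(
                "waiting",
                links=[trace.Link(batch_span.get_span_context())],
                attributes={"batch_size": new_batch_size},
            )
            waiting_spans.append(span)
        # ... prefill would run here ...
        for span in waiting_spans:
            span.end()


if __name__ == "__main__":
    prefill_new_batch(entries=["entry-0", "entry-1"], new_batch_size=2)
```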
diff --git a/router/src/server.rs b/router/src/server.rs
index 09be5f22..432586bb 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -76,7 +76,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
@@ -327,93 +326,93 @@ class CausalLM(Model):
            batch.all_input_ids,
        )

-        with tracer.start_as_current_span("post_processing"):
-            # For each member of the batch
-            for i, (
-                request,
-                input_length,
-                logits,
-                next_token_chooser,
-                stopping_criteria,
-                all_input_ids,
-            ) in enumerate(iterator):
-                # Select next token
-                tokens, logprobs = next_token_chooser(all_input_ids.view(1, -1), logits)
-                next_token_id = tokens[-1].view(1, 1)
+        # For each member of the batch
+        for i, (
+            request,
+            input_length,
+            logits,
+            next_token_chooser,
+            stopping_criteria,
+            all_input_ids,
+        ) in enumerate(iterator):
+            # Select next token
+            next_token_id, logprobs = next_token_chooser(
+                all_input_ids.view(1, -1), logits
+            )

-                # Append next token to all tokens
-                all_input_ids = torch.cat([all_input_ids, next_token_id])
-                new_input_length = input_length + 1
+            # Append next token to all tokens
+            all_input_ids = torch.cat([all_input_ids, next_token_id])
+            new_input_length = input_length + 1

-                # Generated token
-                next_token_logprob = logprobs[-1, next_token_id]
-                next_token_id_squeezed = next_token_id.squeeze()
-                next_token_text = self.tokenizer.decode(
-                    next_token_id_squeezed,
+            # Generated token
+            next_token_logprob = logprobs[-1, next_token_id]
+            next_token_id_squeezed = next_token_id.squeeze()
+            next_token_text = self.tokenizer.decode(
+                next_token_id_squeezed,
+                clean_up_tokenization_spaces=False,
+                skip_special_tokens=False,
+            )
+
+            # Evaluate stopping criteria
+            stop, reason = stopping_criteria(
+                next_token_id_squeezed,
+                next_token_text,
+            )
+
+            if stop:
+                # Decode generated tokens
+                output_text = self.decode(
+                    all_input_ids[-stopping_criteria.current_tokens :, 0]
+                )
+                # Get seed
+                if isinstance(next_token_chooser.choice, Sampling):
+                    seed = next_token_chooser.choice.seed
+                else:
+                    seed = None
+
+                generated_text = GeneratedText(
+                    output_text, stopping_criteria.current_tokens, reason, seed
+                )
+            else:
+                # Keep request in the batch
+                generated_text = None
+                next_batch_keep_indices.append(i)
+                next_batch_input_ids.append(next_token_id)
+                next_batch_all_input_ids.append(all_input_ids)
+                next_batch_size += 1
+                next_batch_input_lengths.append(new_input_length)
+                next_batch_max_sequence_length = max(
+                    next_batch_max_sequence_length, new_input_length
+                )
+
+            # Prefill
+            if stopping_criteria.current_tokens == 1:
+                # Remove generated token to only have prefill and add nan for first prompt token
+                prefill_logprobs = [float("nan")] + logprobs.gather(
+                    1, all_input_ids[1:]
+                ).squeeze(1)[-new_input_length:-1].tolist()
+                prefill_token_ids = all_input_ids[-new_input_length:-1]
+                prefill_texts = self.tokenizer.batch_decode(
+                    prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
-
-                # Evaluate stopping criteria
-                stop, reason = stopping_criteria(
-                    next_token_id_squeezed,
-                    next_token_text,
+                prefill_tokens = PrefillTokens(
+                    prefill_token_ids, prefill_logprobs, prefill_texts
                )
+            else:
+                prefill_tokens = None

-                if stop:
-                    # Decode generated tokens
-                    output_text = self.decode(
-                        all_input_ids[-stopping_criteria.current_tokens :, 0]
-                    )
-                    # Get seed
-                    if isinstance(next_token_chooser.choice, Sampling):
-                        seed = next_token_chooser.choice.seed
-                    else:
-                        seed = None
+            generation = Generation(
+                request.id,
+                prefill_tokens,
+                next_token_id_squeezed,
+                next_token_logprob,
+                next_token_text,
+                generated_text,
+            )

-                    generated_text = GeneratedText(
-                        output_text, stopping_criteria.current_tokens, reason, seed
-                    )
-                else:
-                    # Keep request in the batch
-                    generated_text = None
-                    next_batch_keep_indices.append(i)
-                    next_batch_input_ids.append(next_token_id)
-                    next_batch_all_input_ids.append(all_input_ids)
-                    next_batch_size += 1
-                    next_batch_input_lengths.append(new_input_length)
-                    next_batch_max_sequence_length = max(
-                        next_batch_max_sequence_length, new_input_length
-                    )
-
-                # Prefill
-                if stopping_criteria.current_tokens == 1:
-                    # Remove generated token to only have prefill and add nan for first prompt token
-                    prefill_logprobs = [float("nan")] + logprobs.gather(
-                        1, all_input_ids[1:]
-                    ).squeeze(1)[-new_input_length:-1].tolist()
-                    prefill_token_ids = all_input_ids[-new_input_length:-1]
-                    prefill_texts = self.tokenizer.batch_decode(
-                        prefill_token_ids,
-                        clean_up_tokenization_spaces=False,
-                        skip_special_tokens=False,
-                    )
-                    prefill_tokens = PrefillTokens(
-                        prefill_token_ids, prefill_logprobs, prefill_texts
-                    )
-                else:
-                    prefill_tokens = None
-
-                generation = Generation(
-                    request.id,
-                    prefill_tokens,
-                    next_token_id_squeezed,
-                    next_token_logprob,
-                    next_token_text,
-                    generated_text,
-                )
-
-                generations.append(generation)
+            generations.append(generation)

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
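For context (not part of the patch): with the `NextTokenChooser` change, each request now picks its token from the last logits row only and reshapes it to `[1, 1]` before appending it to the running id tensor. A minimal standalone sketch of that step with dummy tensors:

```python
# Illustrative sketch only: the per-request "select next token" step in isolation.
import torch

torch.manual_seed(0)

seq_len, vocab_size = 7, 50
logits = torch.randn(seq_len, vocab_size)                # [seq_len, vocab]
all_input_ids = torch.randint(vocab_size, (seq_len, 1))  # running [seq_len, 1] column of ids

logprobs = torch.log_softmax(logits, dim=-1)

# Greedy choice on the *last* position only, reshaped to [1, 1] so it can be
# concatenated onto the running id tensor
next_token_id = logits[-1].argmax().view(1, 1)
all_input_ids = torch.cat([all_input_ids, next_token_id])

next_token_logprob = logprobs[-1, next_token_id]
print(next_token_id.item(), next_token_logprob.item())
```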
diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py
index 96dfeb67..d6cccd44 100644
--- a/server/text_generation/models/seq2seq_lm.py
+++ b/server/text_generation/models/seq2seq_lm.py
@@ -328,7 +328,6 @@ class Seq2SeqLM(Model):
    def decode(self, decoder_ids: List[int]) -> str:
        return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)

-    @tracer.start_as_current_span("forward")
    def forward(
        self,
        input_ids,
@@ -407,94 +406,91 @@ class Seq2SeqLM(Model):
            batch.decoder_input_ids,
        )

-        with tracer.start_as_current_span("post_processing"):
-            # For each member of the batch
-            for i, (
-                request,
-                input_length,
-                decoder_input_length,
-                logits,
-                next_token_chooser,
-                stopping_criteria,
-                input_tokens,
-                decoder_input_ids,
-            ) in enumerate(iterator):
-                # Select next token
-                next_token_id, logprobs = next_token_chooser(
-                    decoder_input_ids.view(1, -1), logits
+        # For each member of the batch
+        for i, (
+            request,
+            input_length,
+            decoder_input_length,
+            logits,
+            next_token_chooser,
+            stopping_criteria,
+            input_tokens,
+            decoder_input_ids,
+        ) in enumerate(iterator):
+            # Select next token
+            next_token_id, logprobs = next_token_chooser(
+                decoder_input_ids.view(1, -1), logits
+            )
+
+            # Append next token to decoder tokens
+            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
+            new_decoder_input_length = decoder_input_length + 1
+
+            # Generated token
+            next_token_logprob = logprobs[-1, next_token_id]
+            next_token_id_squeezed = next_token_id.squeeze()
+            next_token_text = self.tokenizer.decode(
+                next_token_id_squeezed,
+                clean_up_tokenization_spaces=False,
+                skip_special_tokens=False,
+            )
+
+            # Evaluate stopping criteria
+            stop, reason = stopping_criteria(next_token_id, next_token_text)
+
+            if stop:
+                # Slice with decoder_input_length to remove padding
+                # Decode all tokens
+                output_text = self.decode(decoder_input_ids[-new_decoder_input_length:])
+
+                # Get seed
+                if isinstance(next_token_chooser.choice, Sampling):
+                    seed = next_token_chooser.choice.seed
+                else:
+                    seed = None
+
+                generated_text = GeneratedText(
+                    output_text, stopping_criteria.current_tokens, reason, seed
+                )
+            else:
+                # Keep request in the batch
+                generated_text = None
+                next_batch_keep_indices.append(i)
+                next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
+                next_batch_size += 1
+                next_batch_input_lengths.append(input_length)
+                next_batch_decoder_input_lengths.append(new_decoder_input_length)
+                next_batch_max_input_length = max(
+                    next_batch_max_input_length, input_length
+                )
+                next_batch_max_decoder_input_length = max(
+                    next_batch_max_decoder_input_length, new_decoder_input_length
                )

-                # Append next token to decoder tokens
-                decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
-                new_decoder_input_length = decoder_input_length + 1
-
-                # Generated token
-                next_token_logprob = logprobs[-1, next_token_id]
-                next_token_id_squeezed = next_token_id.squeeze()
-                next_token_text = self.tokenizer.decode(
-                    next_token_id_squeezed,
+            # Prefill
+            if stopping_criteria.current_tokens == 1:
+                prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
+                prefill_texts = self.tokenizer.batch_decode(
+                    prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
-
-                # Evaluate stopping criteria
-                stop, reason = stopping_criteria(next_token_id, next_token_text)
-
-                if stop:
-                    # Slice with decoder_input_length to remove padding
-                    # Decode all tokens
-                    output_text = self.decode(
-                        decoder_input_ids[-new_decoder_input_length:]
-                    )
-
-                    # Get seed
-                    if isinstance(next_token_chooser.choice, Sampling):
-                        seed = next_token_chooser.choice.seed
-                    else:
-                        seed = None
-
-                    generated_text = GeneratedText(
-                        output_text, stopping_criteria.current_tokens, reason, seed
-                    )
-                else:
-                    # Keep request in the batch
-                    generated_text = None
-                    next_batch_keep_indices.append(i)
-                    next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
-                    next_batch_size += 1
-                    next_batch_input_lengths.append(input_length)
-                    next_batch_decoder_input_lengths.append(new_decoder_input_length)
-                    next_batch_max_input_length = max(
-                        next_batch_max_input_length, input_length
-                    )
-                    next_batch_max_decoder_input_length = max(
-                        next_batch_max_decoder_input_length, new_decoder_input_length
-                    )
-
-                # Prefill
-                if stopping_criteria.current_tokens == 1:
-                    prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
-                    prefill_texts = self.tokenizer.batch_decode(
-                        prefill_token_ids,
-                        clean_up_tokenization_spaces=False,
-                        skip_special_tokens=False,
-                    )
-                    prefill_tokens = PrefillTokens(
-                        prefill_token_ids, [float("nan")], prefill_texts
-                    )
-                else:
-                    prefill_tokens = None
-
-                generation = Generation(
-                    request.id,
-                    prefill_tokens,
-                    next_token_id_squeezed,
-                    next_token_logprob,
-                    next_token_text,
-                    generated_text,
+                prefill_tokens = PrefillTokens(
+                    prefill_token_ids, [float("nan")], prefill_texts
                )
+            else:
+                prefill_tokens = None

-                generations.append(generation)
+            generation = Generation(
+                request.id,
+                prefill_tokens,
+                next_token_id_squeezed,
+                next_token_logprob,
+                next_token_text,
+                generated_text,
+            )
+
+            generations.append(generation)

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
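One more aside (not part of the patch): the `seed` surfaced in the generations above, and the `seed: 0` added to the k6 payload, come from a per-request `torch.Generator`; re-using the seed makes sampling reproducible. A minimal sketch under that assumption:

```python
# Illustrative sketch only: reproducible sampling from the last logits row with a
# seeded torch.Generator, the mechanism behind the `seed` parameter.
import torch


def sample_last_token(logits: torch.Tensor, seed: int) -> int:
    generator = torch.Generator()
    generator.manual_seed(seed)
    probs = torch.nn.functional.softmax(logits[-1], dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=generator).item()


if __name__ == "__main__":
    logits = torch.randn(5, 50)  # [seq_len, vocab]
    # Same seed -> same draw, which is what makes seeded load-test runs deterministic
    assert sample_last_token(logits, seed=0) == sample_last_token(logits, seed=0)
```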
diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py
index abdee9dc..3b3f08c7 100644
--- a/server/text_generation/utils.py
+++ b/server/text_generation/utils.py
@@ -36,16 +36,14 @@ class Sampling:
        self.seed = seed

    def __call__(self, logits):
-        probs = torch.nn.functional.softmax(logits, dim=-1)
-        next_tokens = torch.multinomial(
-            probs, num_samples=1, generator=self.generator
-        ).squeeze(1)
+        probs = torch.nn.functional.softmax(logits)
+        next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
        return next_tokens


class Greedy:
    def __call__(self, logits):
-        return logits.argmax(dim=-1)
+        return logits.argmax()


class NextTokenChooser:
@@ -87,8 +85,9 @@ class NextTokenChooser:
        logprobs = torch.log_softmax(scores, -1)

        # Choose tokens
-        next_ids = self.choice(scores)
-        return next_ids, logprobs
+        next_id = self.choice(scores[-1])
+
+        return next_id.view(1, 1), logprobs

    @classmethod
    def from_pb(