use only last logits

OlivierDehaene 2023-02-10 18:36:34 +01:00
parent 189fba28d1
commit e9441a1ea2
9 changed files with 202 additions and 192 deletions

View File

@@ -38,7 +38,8 @@ function sample_example(inputs, max_new_tokens, name) {
     parameters: {
       max_new_tokens: max_new_tokens,
       do_sample: true,
-      top_p: 0.9
+      top_p: 0.9,
+      seed: 0
     }
   });
   let params = {

View File

@@ -66,7 +66,7 @@ impl Client {
     ///
     /// Returns Generation for each request in batch
     /// and the next cached batch
-    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
+    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size, max_sequence_length = batch.requests.iter().map(|request| request.input_length).max()))]
     pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
         let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
         let response = self.stub.prefill(request).await?.into_inner();
@@ -77,7 +77,7 @@ impl Client {
     ///
     /// Returns Generation for each request in batches
     /// and the next cached batch
-    #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
+    #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>(), max_sequence_length = batches.iter().map(|batch|{batch.requests.iter().map(|request| request.input_length).max()}).max()))]
     pub async fn decode(
         &mut self,
         batches: Vec<Batch>,

View File

@@ -53,7 +53,7 @@ impl ShardedClient {
     ///
     /// Returns Generation for each request in batch
    /// and the next cached batch
-    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
+    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size, max_sequence_length = batch.requests.iter().map(|request| request.input_length).max()))]
     pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
         let futures: Vec<_> = self
             .clients
@@ -69,7 +69,7 @@ impl ShardedClient {
     ///
     /// Returns Generation for each request in batches
     /// and the next cached batch
-    #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
+    #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>(), max_sequence_length = batches.iter().map(|batch|{batch.requests.iter().map(|request| request.input_length).max()}).max()))]
     pub async fn decode(
         &mut self,
         batches: Vec<Batch>,

View File

@@ -96,7 +96,7 @@ impl Infer {
             request: valid_request,
             response_tx,
             span: Span::current(),
-            batch_span: None,
+            temp_span: None,
            queue_time: Instant::now(),
            batch_time: None,
            _permit: permit,
@@ -228,6 +228,18 @@ async fn batching_task(
                 .next_batch(min_size, max_batch_size - batch_size as usize)
                 .await
             {
+                let new_batch_size = new_batch.size;
+                entries.iter_mut().for_each(|(_, entry)| {
+                    // Create a new span to add the info that this entry is waiting
+                    // because a new batch is being computed
+                    let entry_waiting_span =
+                        info_span!(parent: &entry.span, "waiting", batch_size = new_batch_size);
+                    // Add relationship
+                    entry_waiting_span.follows_from(&span);
+                    // Update entry
+                    entry.temp_span = Some(entry_waiting_span);
+                });
+
                 // Generate one token for this new batch to have the attention past in cache
                 let new_cached_batch =
                     wrap_future(client.prefill(new_batch), &mut new_entries)
@@ -253,7 +265,7 @@ async fn batching_task(
                     // Add relationship
                     entry_batch_span.follows_from(&next_batch_span);
                     // Update entry
-                    entry.batch_span = Some(entry_batch_span);
+                    entry.temp_span = Some(entry_batch_span);
                 });
 
                 cached_batch = wrap_future(client.decode(batches), &mut entries)
@@ -289,7 +301,7 @@ async fn wrap_future(
 fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
     entries.drain().for_each(|(_, entry)| {
         // Create and enter a span to link this function back to the entry
-        let _send_error_span = info_span!(parent: entry.batch_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
+        let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
 
         let err = InferError::GenerationError(error.to_string());
         tracing::error!("{err}");
@@ -312,7 +324,7 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
             .expect("ID not found in entries. This is a bug.");
 
         // Create and enter a span to link this function back to the entry
-        let _generation_span = info_span!(parent: entry.batch_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
+        let _generation_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
 
         if let Some(prefill_tokens) = generation.prefill_tokens {
             // Send message

View File

@@ -18,9 +18,8 @@ pub(crate) struct Entry {
     pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
     /// Span that will live as long as entry
     pub span: Span,
-    /// Span for every inference batch
-    /// This span will only live as long as one prefill/decode
-    pub batch_span: Option<Span>,
+    /// Temporary span used as a guard when logging inference, wait times...
+    pub temp_span: Option<Span>,
     /// Instant when this entry was queued
     pub queue_time: Instant,
     /// Instant when this entry was added to a batch
@@ -125,7 +124,12 @@ impl State {
     }
 
     /// Append an entry to the queue
-    fn append(&mut self, entry: Entry) {
+    fn append(&mut self, mut entry: Entry) {
+        // Create a span that will live as long as the entry is in the queue waiting to be batched
+        let queue_span = info_span!(parent: &entry.span, "queued");
+        entry.temp_span = Some(queue_span);
+
+        // Push entry in the queue
         self.entries.push((self.next_id, entry));
         self.next_id += 1;
     }
@@ -163,7 +167,7 @@ impl State {
                 // Add relationship
                 entry_batch_span.follows_from(&next_batch_span);
                 // Update entry
-                entry.batch_span = Some(entry_batch_span);
+                entry.temp_span = Some(entry_batch_span);
 
                 batch_requests.push(Request {
                     id,
@@ -235,7 +239,7 @@ mod tests {
            },
            response_tx,
            span: info_span!("entry"),
-            batch_span: None,
+            temp_span: None,
            queue_time: Instant::now(),
            batch_time: None,
            _permit: permit,

View File

@@ -76,7 +76,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
         queue_time,
         inference_time,
         time_per_token,
-        seed
+        seed,
     )
 )]
 async fn generate(
@@ -179,7 +179,8 @@ async fn generate(
         validation_time,
         queue_time,
         inference_time,
-        time_per_token
+        time_per_token,
+        seed,
     )
 )]
 async fn generate_stream(
@@ -241,13 +242,11 @@ async fn generate_stream(
                             // Tracing metadata
                             span.record("total_time", format!("{:?}", total_time));
-                            span
-                                .record("validation_time", format!("{:?}", validation_time));
+                            span.record("validation_time", format!("{:?}", validation_time));
                             span.record("queue_time", format!("{:?}", queue_time));
-                            span
-                                .record("inference_time", format!("{:?}", inference_time));
-                            span
-                                .record("time_per_token", format!("{:?}", time_per_token));
+                            span.record("inference_time", format!("{:?}", inference_time));
+                            span.record("time_per_token", format!("{:?}", time_per_token));
+                            span.record("seed", format!("{:?}", generated_text.seed));
                             tracing::info!(parent: &span, "Output: {}", generated_text.text);
 
                             // StreamResponse

View File

@@ -277,7 +277,6 @@ class CausalLM(Model):
             generated_ids, skip_special_tokens=True, cleanup_tokenization_spaces=False
         )
 
-    @tracer.start_as_current_span("forward")
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
@@ -327,93 +326,93 @@ class CausalLM(Model):
             batch.all_input_ids,
         )
 
-        with tracer.start_as_current_span("post_processing"):
-            # For each member of the batch
-            for i, (
-                request,
-                input_length,
-                logits,
-                next_token_chooser,
-                stopping_criteria,
-                all_input_ids,
-            ) in enumerate(iterator):
-                # Select next token
-                tokens, logprobs = next_token_chooser(all_input_ids.view(1, -1), logits)
-                next_token_id = tokens[-1].view(1, 1)
-
-                # Append next token to all tokens
-                all_input_ids = torch.cat([all_input_ids, next_token_id])
-                new_input_length = input_length + 1
-
-                # Generated token
-                next_token_logprob = logprobs[-1, next_token_id]
-                next_token_id_squeezed = next_token_id.squeeze()
-                next_token_text = self.tokenizer.decode(
-                    next_token_id_squeezed,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
-
-                # Evaluate stopping criteria
-                stop, reason = stopping_criteria(
-                    next_token_id_squeezed,
-                    next_token_text,
-                )
-
-                if stop:
-                    # Decode generated tokens
-                    output_text = self.decode(
-                        all_input_ids[-stopping_criteria.current_tokens :, 0]
-                    )
-                    # Get seed
-                    if isinstance(next_token_chooser.choice, Sampling):
-                        seed = next_token_chooser.choice.seed
-                    else:
-                        seed = None
-
-                    generated_text = GeneratedText(
-                        output_text, stopping_criteria.current_tokens, reason, seed
-                    )
-                else:
-                    # Keep request in the batch
-                    generated_text = None
-                    next_batch_keep_indices.append(i)
-                    next_batch_input_ids.append(next_token_id)
-                    next_batch_all_input_ids.append(all_input_ids)
-                    next_batch_size += 1
-                    next_batch_input_lengths.append(new_input_length)
-                    next_batch_max_sequence_length = max(
-                        next_batch_max_sequence_length, new_input_length
-                    )
-
-                # Prefill
-                if stopping_criteria.current_tokens == 1:
-                    # Remove generated token to only have prefill and add nan for first prompt token
-                    prefill_logprobs = [float("nan")] + logprobs.gather(
-                        1, all_input_ids[1:]
-                    ).squeeze(1)[-new_input_length:-1].tolist()
-                    prefill_token_ids = all_input_ids[-new_input_length:-1]
-                    prefill_texts = self.tokenizer.batch_decode(
-                        prefill_token_ids,
-                        clean_up_tokenization_spaces=False,
-                        skip_special_tokens=False,
-                    )
-                    prefill_tokens = PrefillTokens(
-                        prefill_token_ids, prefill_logprobs, prefill_texts
-                    )
-                else:
-                    prefill_tokens = None
-
-                generation = Generation(
-                    request.id,
-                    prefill_tokens,
-                    next_token_id_squeezed,
-                    next_token_logprob,
-                    next_token_text,
-                    generated_text,
-                )
-
-                generations.append(generation)
+        # For each member of the batch
+        for i, (
+            request,
+            input_length,
+            logits,
+            next_token_chooser,
+            stopping_criteria,
+            all_input_ids,
+        ) in enumerate(iterator):
+            # Select next token
+            next_token_id, logprobs = next_token_chooser(
+                all_input_ids.view(1, -1), logits
+            )
+
+            # Append next token to all tokens
+            all_input_ids = torch.cat([all_input_ids, next_token_id])
+            new_input_length = input_length + 1
+
+            # Generated token
+            next_token_logprob = logprobs[-1, next_token_id]
+            next_token_id_squeezed = next_token_id.squeeze()
+            next_token_text = self.tokenizer.decode(
+                next_token_id_squeezed,
+                clean_up_tokenization_spaces=False,
+                skip_special_tokens=False,
+            )
+
+            # Evaluate stopping criteria
+            stop, reason = stopping_criteria(
+                next_token_id_squeezed,
+                next_token_text,
+            )
+
+            if stop:
+                # Decode generated tokens
+                output_text = self.decode(
+                    all_input_ids[-stopping_criteria.current_tokens :, 0]
+                )
+                # Get seed
+                if isinstance(next_token_chooser.choice, Sampling):
+                    seed = next_token_chooser.choice.seed
+                else:
+                    seed = None
+
+                generated_text = GeneratedText(
+                    output_text, stopping_criteria.current_tokens, reason, seed
+                )
+            else:
+                # Keep request in the batch
+                generated_text = None
+                next_batch_keep_indices.append(i)
+                next_batch_input_ids.append(next_token_id)
+                next_batch_all_input_ids.append(all_input_ids)
+                next_batch_size += 1
+                next_batch_input_lengths.append(new_input_length)
+                next_batch_max_sequence_length = max(
+                    next_batch_max_sequence_length, new_input_length
+                )
+
+            # Prefill
+            if stopping_criteria.current_tokens == 1:
+                # Remove generated token to only have prefill and add nan for first prompt token
+                prefill_logprobs = [float("nan")] + logprobs.gather(
+                    1, all_input_ids[1:]
+                ).squeeze(1)[-new_input_length:-1].tolist()
+                prefill_token_ids = all_input_ids[-new_input_length:-1]
+                prefill_texts = self.tokenizer.batch_decode(
+                    prefill_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                prefill_tokens = PrefillTokens(
+                    prefill_token_ids, prefill_logprobs, prefill_texts
+                )
+            else:
+                prefill_tokens = None
+
+            generation = Generation(
+                request.id,
+                prefill_tokens,
+                next_token_id_squeezed,
+                next_token_logprob,
+                next_token_text,
+                generated_text,
+            )
+
+            generations.append(generation)
 
         # We finished all generations in the batch; there is no next batch
         if not next_batch_keep_indices:
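
The loop above now expects the chooser to hand back a single id for the last position, already shaped (1, 1) so it can be concatenated onto all_input_ids, instead of a vector of ids that the caller slices with tokens[-1]. A minimal sketch of that contract, using a hypothetical toy_chooser in place of the repository's NextTokenChooser:

```python
import torch

def toy_chooser(input_ids: torch.Tensor, scores: torch.Tensor):
    """Illustrative stand-in for the chooser contract used by the decode loop above."""
    # scores: [seq_len, vocab_size] logits for one request
    logprobs = torch.log_softmax(scores, -1)
    # Pick from the last position only, shaped [1, 1] so it can be concatenated
    next_id = scores[-1].argmax()
    return next_id.view(1, 1), logprobs

# Shapes mirroring the decode loop above
vocab_size, seq_len = 32, 4
all_input_ids = torch.randint(vocab_size, (seq_len, 1))    # column of token ids
logits = torch.randn(seq_len, vocab_size)
next_token_id, logprobs = toy_chooser(all_input_ids.view(1, -1), logits)
all_input_ids = torch.cat([all_input_ids, next_token_id])  # now seq_len + 1 ids
next_token_logprob = logprobs[-1, next_token_id]           # logprob of the chosen token
```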

View File

@@ -328,7 +328,6 @@ class Seq2SeqLM(Model):
     def decode(self, decoder_ids: List[int]) -> str:
         return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)
 
-    @tracer.start_as_current_span("forward")
     def forward(
         self,
         input_ids,
@@ -407,94 +406,91 @@ class Seq2SeqLM(Model):
             batch.decoder_input_ids,
         )
 
-        with tracer.start_as_current_span("post_processing"):
-            # For each member of the batch
-            for i, (
-                request,
-                input_length,
-                decoder_input_length,
-                logits,
-                next_token_chooser,
-                stopping_criteria,
-                input_tokens,
-                decoder_input_ids,
-            ) in enumerate(iterator):
-                # Select next token
-                next_token_id, logprobs = next_token_chooser(
-                    decoder_input_ids.view(1, -1), logits
-                )
-
-                # Append next token to decoder tokens
-                decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
-                new_decoder_input_length = decoder_input_length + 1
-
-                # Generated token
-                next_token_logprob = logprobs[-1, next_token_id]
-                next_token_id_squeezed = next_token_id.squeeze()
-                next_token_text = self.tokenizer.decode(
-                    next_token_id_squeezed,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
-
-                # Evaluate stopping criteria
-                stop, reason = stopping_criteria(next_token_id, next_token_text)
-
-                if stop:
-                    # Slice with decoder_input_length to remove padding
-                    # Decode all tokens
-                    output_text = self.decode(
-                        decoder_input_ids[-new_decoder_input_length:]
-                    )
-                    # Get seed
-                    if isinstance(next_token_chooser.choice, Sampling):
-                        seed = next_token_chooser.choice.seed
-                    else:
-                        seed = None
-
-                    generated_text = GeneratedText(
-                        output_text, stopping_criteria.current_tokens, reason, seed
-                    )
-                else:
-                    # Keep request in the batch
-                    generated_text = None
-                    next_batch_keep_indices.append(i)
-                    next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
-                    next_batch_size += 1
-                    next_batch_input_lengths.append(input_length)
-                    next_batch_decoder_input_lengths.append(new_decoder_input_length)
-                    next_batch_max_input_length = max(
-                        next_batch_max_input_length, input_length
-                    )
-                    next_batch_max_decoder_input_length = max(
-                        next_batch_max_decoder_input_length, new_decoder_input_length
-                    )
-
-                # Prefill
-                if stopping_criteria.current_tokens == 1:
-                    prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
-                    prefill_texts = self.tokenizer.batch_decode(
-                        prefill_token_ids,
-                        clean_up_tokenization_spaces=False,
-                        skip_special_tokens=False,
-                    )
-                    prefill_tokens = PrefillTokens(
-                        prefill_token_ids, [float("nan")], prefill_texts
-                    )
-                else:
-                    prefill_tokens = None
-
-                generation = Generation(
-                    request.id,
-                    prefill_tokens,
-                    next_token_id_squeezed,
-                    next_token_logprob,
-                    next_token_text,
-                    generated_text,
-                )
-
-                generations.append(generation)
+        # For each member of the batch
+        for i, (
+            request,
+            input_length,
+            decoder_input_length,
+            logits,
+            next_token_chooser,
+            stopping_criteria,
+            input_tokens,
+            decoder_input_ids,
+        ) in enumerate(iterator):
+            # Select next token
+            next_token_id, logprobs = next_token_chooser(
+                decoder_input_ids.view(1, -1), logits
+            )
+
+            # Append next token to decoder tokens
+            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
+            new_decoder_input_length = decoder_input_length + 1
+
+            # Generated token
+            next_token_logprob = logprobs[-1, next_token_id]
+            next_token_id_squeezed = next_token_id.squeeze()
+            next_token_text = self.tokenizer.decode(
+                next_token_id_squeezed,
+                clean_up_tokenization_spaces=False,
+                skip_special_tokens=False,
+            )
+
+            # Evaluate stopping criteria
+            stop, reason = stopping_criteria(next_token_id, next_token_text)
+
+            if stop:
+                # Slice with decoder_input_length to remove padding
+                # Decode all tokens
+                output_text = self.decode(decoder_input_ids[-new_decoder_input_length:])
+                # Get seed
+                if isinstance(next_token_chooser.choice, Sampling):
+                    seed = next_token_chooser.choice.seed
+                else:
+                    seed = None
+
+                generated_text = GeneratedText(
+                    output_text, stopping_criteria.current_tokens, reason, seed
+                )
+            else:
+                # Keep request in the batch
+                generated_text = None
+                next_batch_keep_indices.append(i)
+                next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
+                next_batch_size += 1
+                next_batch_input_lengths.append(input_length)
+                next_batch_decoder_input_lengths.append(new_decoder_input_length)
+                next_batch_max_input_length = max(
+                    next_batch_max_input_length, input_length
+                )
+                next_batch_max_decoder_input_length = max(
+                    next_batch_max_decoder_input_length, new_decoder_input_length
+                )
+
+            # Prefill
+            if stopping_criteria.current_tokens == 1:
+                prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
+                prefill_texts = self.tokenizer.batch_decode(
+                    prefill_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                prefill_tokens = PrefillTokens(
+                    prefill_token_ids, [float("nan")], prefill_texts
+                )
+            else:
+                prefill_tokens = None
+
+            generation = Generation(
+                request.id,
+                prefill_tokens,
+                next_token_id_squeezed,
+                next_token_logprob,
+                next_token_text,
+                generated_text,
+            )
+
+            generations.append(generation)
 
         # We finished all generations in the batch; there is no next batch
         if not next_batch_keep_indices:

View File

@@ -36,16 +36,14 @@ class Sampling:
         self.seed = seed
 
     def __call__(self, logits):
-        probs = torch.nn.functional.softmax(logits, dim=-1)
-        next_tokens = torch.multinomial(
-            probs, num_samples=1, generator=self.generator
-        ).squeeze(1)
+        probs = torch.nn.functional.softmax(logits)
+        next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
         return next_tokens
 
 
 class Greedy:
     def __call__(self, logits):
-        return logits.argmax(dim=-1)
+        return logits.argmax()
 
 
 class NextTokenChooser:
@@ -87,8 +85,9 @@ class NextTokenChooser:
         logprobs = torch.log_softmax(scores, -1)
 
         # Choose tokens
-        next_ids = self.choice(scores)
-        return next_ids, logprobs
+        next_id = self.choice(scores[-1])
+
+        return next_id.view(1, 1), logprobs
 
     @classmethod
     def from_pb(
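
With this change the chooser passes only the last position's logits (scores[-1]) to self.choice, so Sampling sees a 1-D vector and the multinomial draw needs no squeeze, while Greedy reduces to a plain argmax. A rough, self-contained sketch of the sampling path; ToySampling and its seeded constructor are illustrative stand-ins, not the class from this diff:

```python
import torch

class ToySampling:
    """Illustrative stand-in: seeded generator + multinomial draw on last-position logits."""

    def __init__(self, seed: int = 0):
        self.generator = torch.Generator()
        self.generator.manual_seed(seed)
        self.seed = seed

    def __call__(self, logits):
        # logits is the last position only: shape [vocab_size]
        probs = torch.nn.functional.softmax(logits, dim=-1)
        # 1-D input -> multinomial returns a one-element tensor, no squeeze needed
        return torch.multinomial(probs, num_samples=1, generator=self.generator)

scores = torch.randn(5, 32)                # [seq_len, vocab_size] logits for one request
logprobs = torch.log_softmax(scores, -1)   # kept for the full sequence (prefill logprobs)
next_id = ToySampling(seed=0)(scores[-1])  # draw from the last position only
print(next_id.view(1, 1).shape, logprobs[-1, next_id])
```

Because the draw only ever looks at the final row of logits, greedy and sampled decoding both produce exactly one id per step, and a fixed seed makes the sampled id reproducible.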