use only last logits

This commit is contained in:
OlivierDehaene 2023-02-10 18:36:34 +01:00
parent 189fba28d1
commit e9441a1ea2
9 changed files with 202 additions and 192 deletions

View File

@ -38,7 +38,8 @@ function sample_example(inputs, max_new_tokens, name) {
parameters: {
max_new_tokens: max_new_tokens,
do_sample: true,
top_p: 0.9
top_p: 0.9,
seed: 0
}
});
let params = {

View File

@ -66,7 +66,7 @@ impl Client {
///
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size, max_sequence_length = batch.requests.iter().map(|request| request.input_length).max()))]
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
let response = self.stub.prefill(request).await?.into_inner();
@ -77,7 +77,7 @@ impl Client {
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>(), max_sequence_length = batches.iter().map(|batch|{batch.requests.iter().map(|request| request.input_length).max()}).max()))]
pub async fn decode(
&mut self,
batches: Vec<Batch>,

View File

@ -53,7 +53,7 @@ impl ShardedClient {
///
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size, max_sequence_length = batch.requests.iter().map(|request| request.input_length).max()))]
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
let futures: Vec<_> = self
.clients
@ -69,7 +69,7 @@ impl ShardedClient {
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>(), max_sequence_length = batches.iter().map(|batch|{batch.requests.iter().map(|request| request.input_length).max()}).max()))]
pub async fn decode(
&mut self,
batches: Vec<Batch>,

View File

@ -96,7 +96,7 @@ impl Infer {
request: valid_request,
response_tx,
span: Span::current(),
batch_span: None,
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
_permit: permit,
@ -228,6 +228,18 @@ async fn batching_task(
.next_batch(min_size, max_batch_size - batch_size as usize)
.await
{
let new_batch_size = new_batch.size;
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to add the info that this entry is waiting
// because a new batch is being computed
let entry_waiting_span =
info_span!(parent: &entry.span, "waiting", batch_size = new_batch_size);
// Add relationship
entry_waiting_span.follows_from(&span);
// Update entry
entry.temp_span = Some(entry_waiting_span);
});
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch =
wrap_future(client.prefill(new_batch), &mut new_entries)
@ -253,7 +265,7 @@ async fn batching_task(
// Add relationship
entry_batch_span.follows_from(&next_batch_span);
// Update entry
entry.batch_span = Some(entry_batch_span);
entry.temp_span = Some(entry_batch_span);
});
cached_batch = wrap_future(client.decode(batches), &mut entries)
@ -289,7 +301,7 @@ async fn wrap_future(
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
entries.drain().for_each(|(_, entry)| {
// Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.batch_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::GenerationError(error.to_string());
tracing::error!("{err}");
@ -312,7 +324,7 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
.expect("ID not found in entries. This is a bug.");
// Create and enter a span to link this function back to the entry
let _generation_span = info_span!(parent: entry.batch_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
let _generation_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
if let Some(prefill_tokens) = generation.prefill_tokens {
// Send message

View File

@ -18,9 +18,8 @@ pub(crate) struct Entry {
pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
/// Span that will live as long as entry
pub span: Span,
/// Span for every inference batch
/// This span will only live as long as one prefill/decode
pub batch_span: Option<Span>,
/// Temporary span used as a guard when logging inference, wait times...
pub temp_span: Option<Span>,
/// Instant when this entry was queued
pub queue_time: Instant,
/// Instant when this entry was added to a batch
@ -125,7 +124,12 @@ impl State {
}
/// Append an entry to the queue
fn append(&mut self, entry: Entry) {
fn append(&mut self, mut entry: Entry) {
// Create a span that will live as long as the entry is in the queue waiting to be batched
let queue_span = info_span!(parent: &entry.span, "queued");
entry.temp_span = Some(queue_span);
// Push entry in the queue
self.entries.push((self.next_id, entry));
self.next_id += 1;
}
@ -163,7 +167,7 @@ impl State {
// Add relationship
entry_batch_span.follows_from(&next_batch_span);
// Update entry
entry.batch_span = Some(entry_batch_span);
entry.temp_span = Some(entry_batch_span);
batch_requests.push(Request {
id,
@ -235,7 +239,7 @@ mod tests {
},
response_tx,
span: info_span!("entry"),
batch_span: None,
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
_permit: permit,

View File

@ -76,7 +76,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
queue_time,
inference_time,
time_per_token,
seed
seed,
)
)]
async fn generate(
@ -179,7 +179,8 @@ async fn generate(
validation_time,
queue_time,
inference_time,
time_per_token
time_per_token,
seed,
)
)]
async fn generate_stream(
@ -241,13 +242,11 @@ async fn generate_stream(
// Tracing metadata
span.record("total_time", format!("{:?}", total_time));
span
.record("validation_time", format!("{:?}", validation_time));
span.record("validation_time", format!("{:?}", validation_time));
span.record("queue_time", format!("{:?}", queue_time));
span
.record("inference_time", format!("{:?}", inference_time));
span
.record("time_per_token", format!("{:?}", time_per_token));
span.record("inference_time", format!("{:?}", inference_time));
span.record("time_per_token", format!("{:?}", time_per_token));
span.record("seed", format!("{:?}", generated_text.seed));
tracing::info!(parent: &span, "Output: {}", generated_text.text);
// StreamResponse

View File

@ -277,7 +277,6 @@ class CausalLM(Model):
generated_ids, skip_special_tokens=True, cleanup_tokenization_spaces=False
)
@tracer.start_as_current_span("forward")
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
@ -327,93 +326,93 @@ class CausalLM(Model):
batch.all_input_ids,
)
with tracer.start_as_current_span("post_processing"):
# For each member of the batch
for i, (
request,
input_length,
logits,
next_token_chooser,
stopping_criteria,
all_input_ids,
) in enumerate(iterator):
# Select next token
tokens, logprobs = next_token_chooser(all_input_ids.view(1, -1), logits)
next_token_id = tokens[-1].view(1, 1)
# For each member of the batch
for i, (
request,
input_length,
logits,
next_token_chooser,
stopping_criteria,
all_input_ids,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
all_input_ids.view(1, -1), logits
)
# Append next token to all tokens
all_input_ids = torch.cat([all_input_ids, next_token_id])
new_input_length = input_length + 1
# Append next token to all tokens
all_input_ids = torch.cat([all_input_ids, next_token_id])
new_input_length = input_length + 1
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text = self.tokenizer.decode(
next_token_id_squeezed,
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text = self.tokenizer.decode(
next_token_id_squeezed,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
# Evaluate stopping criteria
stop, reason = stopping_criteria(
next_token_id_squeezed,
next_token_text,
)
if stop:
# Decode generated tokens
output_text = self.decode(
all_input_ids[-stopping_criteria.current_tokens :, 0]
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
else:
# Keep request in the batch
generated_text = None
next_batch_keep_indices.append(i)
next_batch_input_ids.append(next_token_id)
next_batch_all_input_ids.append(all_input_ids)
next_batch_size += 1
next_batch_input_lengths.append(new_input_length)
next_batch_max_sequence_length = max(
next_batch_max_sequence_length, new_input_length
)
# Prefill
if stopping_criteria.current_tokens == 1:
# Remove generated token to only have prefill and add nan for first prompt token
prefill_logprobs = [float("nan")] + logprobs.gather(
1, all_input_ids[1:]
).squeeze(1)[-new_input_length:-1].tolist()
prefill_token_ids = all_input_ids[-new_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
# Evaluate stopping criteria
stop, reason = stopping_criteria(
next_token_id_squeezed,
next_token_text,
prefill_tokens = PrefillTokens(
prefill_token_ids, prefill_logprobs, prefill_texts
)
else:
prefill_tokens = None
if stop:
# Decode generated tokens
output_text = self.decode(
all_input_ids[-stopping_criteria.current_tokens :, 0]
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
generation = Generation(
request.id,
prefill_tokens,
next_token_id_squeezed,
next_token_logprob,
next_token_text,
generated_text,
)
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
else:
# Keep request in the batch
generated_text = None
next_batch_keep_indices.append(i)
next_batch_input_ids.append(next_token_id)
next_batch_all_input_ids.append(all_input_ids)
next_batch_size += 1
next_batch_input_lengths.append(new_input_length)
next_batch_max_sequence_length = max(
next_batch_max_sequence_length, new_input_length
)
# Prefill
if stopping_criteria.current_tokens == 1:
# Remove generated token to only have prefill and add nan for first prompt token
prefill_logprobs = [float("nan")] + logprobs.gather(
1, all_input_ids[1:]
).squeeze(1)[-new_input_length:-1].tolist()
prefill_token_ids = all_input_ids[-new_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, prefill_logprobs, prefill_texts
)
else:
prefill_tokens = None
generation = Generation(
request.id,
prefill_tokens,
next_token_id_squeezed,
next_token_logprob,
next_token_text,
generated_text,
)
generations.append(generation)
generations.append(generation)
# We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices:

View File

@ -328,7 +328,6 @@ class Seq2SeqLM(Model):
def decode(self, decoder_ids: List[int]) -> str:
return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)
@tracer.start_as_current_span("forward")
def forward(
self,
input_ids,
@ -407,94 +406,91 @@ class Seq2SeqLM(Model):
batch.decoder_input_ids,
)
with tracer.start_as_current_span("post_processing"):
# For each member of the batch
for i, (
request,
input_length,
decoder_input_length,
logits,
next_token_chooser,
stopping_criteria,
input_tokens,
decoder_input_ids,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
decoder_input_ids.view(1, -1), logits
# For each member of the batch
for i, (
request,
input_length,
decoder_input_length,
logits,
next_token_chooser,
stopping_criteria,
input_tokens,
decoder_input_ids,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
decoder_input_ids.view(1, -1), logits
)
# Append next token to decoder tokens
decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
new_decoder_input_length = decoder_input_length + 1
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text = self.tokenizer.decode(
next_token_id_squeezed,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
# Evaluate stopping criteria
stop, reason = stopping_criteria(next_token_id, next_token_text)
if stop:
# Slice with decoder_input_length to remove padding
# Decode all tokens
output_text = self.decode(decoder_input_ids[-new_decoder_input_length:])
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
else:
# Keep request in the batch
generated_text = None
next_batch_keep_indices.append(i)
next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
next_batch_size += 1
next_batch_input_lengths.append(input_length)
next_batch_decoder_input_lengths.append(new_decoder_input_length)
next_batch_max_input_length = max(
next_batch_max_input_length, input_length
)
next_batch_max_decoder_input_length = max(
next_batch_max_decoder_input_length, new_decoder_input_length
)
# Append next token to decoder tokens
decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
new_decoder_input_length = decoder_input_length + 1
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text = self.tokenizer.decode(
next_token_id_squeezed,
# Prefill
if stopping_criteria.current_tokens == 1:
prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
# Evaluate stopping criteria
stop, reason = stopping_criteria(next_token_id, next_token_text)
if stop:
# Slice with decoder_input_length to remove padding
# Decode all tokens
output_text = self.decode(
decoder_input_ids[-new_decoder_input_length:]
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
else:
# Keep request in the batch
generated_text = None
next_batch_keep_indices.append(i)
next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
next_batch_size += 1
next_batch_input_lengths.append(input_length)
next_batch_decoder_input_lengths.append(new_decoder_input_length)
next_batch_max_input_length = max(
next_batch_max_input_length, input_length
)
next_batch_max_decoder_input_length = max(
next_batch_max_decoder_input_length, new_decoder_input_length
)
# Prefill
if stopping_criteria.current_tokens == 1:
prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, [float("nan")], prefill_texts
)
else:
prefill_tokens = None
generation = Generation(
request.id,
prefill_tokens,
next_token_id_squeezed,
next_token_logprob,
next_token_text,
generated_text,
prefill_tokens = PrefillTokens(
prefill_token_ids, [float("nan")], prefill_texts
)
else:
prefill_tokens = None
generations.append(generation)
generation = Generation(
request.id,
prefill_tokens,
next_token_id_squeezed,
next_token_logprob,
next_token_text,
generated_text,
)
generations.append(generation)
# We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices:

View File

@ -36,16 +36,14 @@ class Sampling:
self.seed = seed
def __call__(self, logits):
probs = torch.nn.functional.softmax(logits, dim=-1)
next_tokens = torch.multinomial(
probs, num_samples=1, generator=self.generator
).squeeze(1)
probs = torch.nn.functional.softmax(logits)
next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
return next_tokens
class Greedy:
def __call__(self, logits):
return logits.argmax(dim=-1)
return logits.argmax()
class NextTokenChooser:
@ -87,8 +85,9 @@ class NextTokenChooser:
logprobs = torch.log_softmax(scores, -1)
# Choose tokens
next_ids = self.choice(scores)
return next_ids, logprobs
next_id = self.choice(scores[-1])
return next_id.view(1, 1), logprobs
@classmethod
def from_pb(