Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-09 19:34:53 +00:00

Commit e9441a1ea2: use only last logits
Parent: 189fba28d1
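What the commit changes, in brief: before, the next-token chooser picked one token per input position and the model code kept only the last one (`tokens[-1]`); after, the chooser applies its choice (greedy or sampling) to the last row of scores only and returns a single (1, 1) token id. The sketch below is a simplified stand-in for the classes touched in the diff, not the repository implementation; the class and variable names only mirror the diff for readability.

# Hedged sketch (not the repository code): what "use only last logits" amounts to.
import torch


class Greedy:
    def __call__(self, logits):
        # The chooser now hands over a single row of scores, so a plain argmax
        # over that row is enough.
        return logits.argmax()


class NextTokenChooser:
    def __init__(self):
        self.choice = Greedy()

    def __call__(self, input_ids, scores):
        logprobs = torch.log_softmax(scores, -1)
        # Before: self.choice(scores) returned one id per position and callers
        # kept the last one. Now only the last row of scores is used.
        next_id = self.choice(scores[-1])
        return next_id.view(1, 1), logprobs


if __name__ == "__main__":
    scores = torch.randn(5, 32)  # (sequence_length, vocab_size), both assumed
    next_id, logprobs = NextTokenChooser()(None, scores)
    print(next_id.shape)  # torch.Size([1, 1])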
@@ -38,7 +38,8 @@ function sample_example(inputs, max_new_tokens, name) {
     parameters: {
       max_new_tokens: max_new_tokens,
       do_sample: true,
-      top_p: 0.9
+      top_p: 0.9,
+      seed: 0
     }
   });
   let params = {

@@ -66,7 +66,7 @@ impl Client {
     ///
     /// Returns Generation for each request in batch
     /// and the next cached batch
-    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
+    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size, max_sequence_length = batch.requests.iter().map(|request| request.input_length).max()))]
     pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
         let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
         let response = self.stub.prefill(request).await?.into_inner();

@@ -77,7 +77,7 @@ impl Client {
     ///
     /// Returns Generation for each request in batches
     /// and the next cached batch
-    #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
+    #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>(), max_sequence_length = batches.iter().map(|batch|{batch.requests.iter().map(|request| request.input_length).max()}).max()))]
     pub async fn decode(
         &mut self,
         batches: Vec<Batch>,

@@ -53,7 +53,7 @@ impl ShardedClient {
     ///
     /// Returns Generation for each request in batch
     /// and the next cached batch
-    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
+    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size, max_sequence_length = batch.requests.iter().map(|request| request.input_length).max()))]
     pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
         let futures: Vec<_> = self
             .clients

@@ -69,7 +69,7 @@ impl ShardedClient {
     ///
     /// Returns Generation for each request in batches
     /// and the next cached batch
-    #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
+    #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>(), max_sequence_length = batches.iter().map(|batch|{batch.requests.iter().map(|request| request.input_length).max()}).max()))]
     pub async fn decode(
         &mut self,
         batches: Vec<Batch>,

@@ -96,7 +96,7 @@ impl Infer {
             request: valid_request,
             response_tx,
             span: Span::current(),
-            batch_span: None,
+            temp_span: None,
             queue_time: Instant::now(),
             batch_time: None,
             _permit: permit,

@@ -228,6 +228,18 @@ async fn batching_task(
                     .next_batch(min_size, max_batch_size - batch_size as usize)
                     .await
                 {
+                    let new_batch_size = new_batch.size;
+                    entries.iter_mut().for_each(|(_, entry)| {
+                        // Create a new span to add the info that this entry is waiting
+                        // because a new batch is being computed
+                        let entry_waiting_span =
+                            info_span!(parent: &entry.span, "waiting", batch_size = new_batch_size);
+                        // Add relationship
+                        entry_waiting_span.follows_from(&span);
+                        // Update entry
+                        entry.temp_span = Some(entry_waiting_span);
+                    });
+
                     // Generate one token for this new batch to have the attention past in cache
                     let new_cached_batch =
                         wrap_future(client.prefill(new_batch), &mut new_entries)

@@ -253,7 +265,7 @@ async fn batching_task(
                     // Add relationship
                     entry_batch_span.follows_from(&next_batch_span);
                     // Update entry
-                    entry.batch_span = Some(entry_batch_span);
+                    entry.temp_span = Some(entry_batch_span);
                 });

                 cached_batch = wrap_future(client.decode(batches), &mut entries)

@@ -289,7 +301,7 @@ async fn wrap_future(
 fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
     entries.drain().for_each(|(_, entry)| {
         // Create and enter a span to link this function back to the entry
-        let _send_error_span = info_span!(parent: entry.batch_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
+        let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
         let err = InferError::GenerationError(error.to_string());
         tracing::error!("{err}");

@@ -312,7 +324,7 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
             .expect("ID not found in entries. This is a bug.");

         // Create and enter a span to link this function back to the entry
-        let _generation_span = info_span!(parent: entry.batch_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
+        let _generation_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();

         if let Some(prefill_tokens) = generation.prefill_tokens {
             // Send message

@@ -18,9 +18,8 @@ pub(crate) struct Entry {
     pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
     /// Span that will live as long as entry
     pub span: Span,
-    /// Span for every inference batch
-    /// This span will only live as long as one prefill/decode
-    pub batch_span: Option<Span>,
+    /// Temporary span used as a guard when logging inference, wait times...
+    pub temp_span: Option<Span>,
     /// Instant when this entry was queued
     pub queue_time: Instant,
     /// Instant when this entry was added to a batch

@@ -125,7 +124,12 @@ impl State {
     }

     /// Append an entry to the queue
-    fn append(&mut self, entry: Entry) {
+    fn append(&mut self, mut entry: Entry) {
+        // Create a span that will live as long as the entry is in the queue waiting to be batched
+        let queue_span = info_span!(parent: &entry.span, "queued");
+        entry.temp_span = Some(queue_span);
+
+        // Push entry in the queue
         self.entries.push((self.next_id, entry));
         self.next_id += 1;
     }

@@ -163,7 +167,7 @@ impl State {
             // Add relationship
             entry_batch_span.follows_from(&next_batch_span);
             // Update entry
-            entry.batch_span = Some(entry_batch_span);
+            entry.temp_span = Some(entry_batch_span);

             batch_requests.push(Request {
                 id,

@@ -235,7 +239,7 @@ mod tests {
             },
             response_tx,
             span: info_span!("entry"),
-            batch_span: None,
+            temp_span: None,
             queue_time: Instant::now(),
             batch_time: None,
             _permit: permit,

@@ -76,7 +76,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
         queue_time,
         inference_time,
         time_per_token,
-        seed
+        seed,
     )
 )]
 async fn generate(

@@ -179,7 +179,8 @@ async fn generate(
         validation_time,
         queue_time,
         inference_time,
-        time_per_token
+        time_per_token,
+        seed,
     )
 )]
 async fn generate_stream(

@@ -241,13 +242,11 @@ async fn generate_stream(

                             // Tracing metadata
                             span.record("total_time", format!("{:?}", total_time));
-                            span
-                                .record("validation_time", format!("{:?}", validation_time));
+                            span.record("validation_time", format!("{:?}", validation_time));
                             span.record("queue_time", format!("{:?}", queue_time));
-                            span
-                                .record("inference_time", format!("{:?}", inference_time));
-                            span
-                                .record("time_per_token", format!("{:?}", time_per_token));
+                            span.record("inference_time", format!("{:?}", inference_time));
+                            span.record("time_per_token", format!("{:?}", time_per_token));
+                            span.record("seed", format!("{:?}", generated_text.seed));
                             tracing::info!(parent: &span, "Output: {}", generated_text.text);

                             // StreamResponse

@@ -277,7 +277,6 @@ class CausalLM(Model):
             generated_ids, skip_special_tokens=True, cleanup_tokenization_spaces=False
         )

-    @tracer.start_as_current_span("forward")
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:

@@ -327,93 +326,93 @@ class CausalLM(Model):
             batch.all_input_ids,
         )

-        with tracer.start_as_current_span("post_processing"):
-            # For each member of the batch
-            for i, (
-                request,
-                input_length,
-                logits,
-                next_token_chooser,
-                stopping_criteria,
-                all_input_ids,
-            ) in enumerate(iterator):
-                # Select next token
-                tokens, logprobs = next_token_chooser(all_input_ids.view(1, -1), logits)
-                next_token_id = tokens[-1].view(1, 1)
-
-                # Append next token to all tokens
-                all_input_ids = torch.cat([all_input_ids, next_token_id])
-                new_input_length = input_length + 1
-
-                # Generated token
-                next_token_logprob = logprobs[-1, next_token_id]
-                next_token_id_squeezed = next_token_id.squeeze()
-                next_token_text = self.tokenizer.decode(
-                    next_token_id_squeezed,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
-
-                # Evaluate stopping criteria
-                stop, reason = stopping_criteria(
-                    next_token_id_squeezed,
-                    next_token_text,
-                )
-
-                if stop:
-                    # Decode generated tokens
-                    output_text = self.decode(
-                        all_input_ids[-stopping_criteria.current_tokens :, 0]
-                    )
-                    # Get seed
-                    if isinstance(next_token_chooser.choice, Sampling):
-                        seed = next_token_chooser.choice.seed
-                    else:
-                        seed = None
-
-                    generated_text = GeneratedText(
-                        output_text, stopping_criteria.current_tokens, reason, seed
-                    )
-                else:
-                    # Keep request in the batch
-                    generated_text = None
-                    next_batch_keep_indices.append(i)
-                    next_batch_input_ids.append(next_token_id)
-                    next_batch_all_input_ids.append(all_input_ids)
-                    next_batch_size += 1
-                    next_batch_input_lengths.append(new_input_length)
-                    next_batch_max_sequence_length = max(
-                        next_batch_max_sequence_length, new_input_length
-                    )
-
-                # Prefill
-                if stopping_criteria.current_tokens == 1:
-                    # Remove generated token to only have prefill and add nan for first prompt token
-                    prefill_logprobs = [float("nan")] + logprobs.gather(
-                        1, all_input_ids[1:]
-                    ).squeeze(1)[-new_input_length:-1].tolist()
-                    prefill_token_ids = all_input_ids[-new_input_length:-1]
-                    prefill_texts = self.tokenizer.batch_decode(
-                        prefill_token_ids,
-                        clean_up_tokenization_spaces=False,
-                        skip_special_tokens=False,
-                    )
-                    prefill_tokens = PrefillTokens(
-                        prefill_token_ids, prefill_logprobs, prefill_texts
-                    )
-                else:
-                    prefill_tokens = None
-
-                generation = Generation(
-                    request.id,
-                    prefill_tokens,
-                    next_token_id_squeezed,
-                    next_token_logprob,
-                    next_token_text,
-                    generated_text,
-                )
-
-                generations.append(generation)
+        # For each member of the batch
+        for i, (
+            request,
+            input_length,
+            logits,
+            next_token_chooser,
+            stopping_criteria,
+            all_input_ids,
+        ) in enumerate(iterator):
+            # Select next token
+            next_token_id, logprobs = next_token_chooser(
+                all_input_ids.view(1, -1), logits
+            )
+
+            # Append next token to all tokens
+            all_input_ids = torch.cat([all_input_ids, next_token_id])
+            new_input_length = input_length + 1
+
+            # Generated token
+            next_token_logprob = logprobs[-1, next_token_id]
+            next_token_id_squeezed = next_token_id.squeeze()
+            next_token_text = self.tokenizer.decode(
+                next_token_id_squeezed,
+                clean_up_tokenization_spaces=False,
+                skip_special_tokens=False,
+            )
+
+            # Evaluate stopping criteria
+            stop, reason = stopping_criteria(
+                next_token_id_squeezed,
+                next_token_text,
+            )
+
+            if stop:
+                # Decode generated tokens
+                output_text = self.decode(
+                    all_input_ids[-stopping_criteria.current_tokens :, 0]
+                )
+                # Get seed
+                if isinstance(next_token_chooser.choice, Sampling):
+                    seed = next_token_chooser.choice.seed
+                else:
+                    seed = None
+
+                generated_text = GeneratedText(
+                    output_text, stopping_criteria.current_tokens, reason, seed
+                )
+            else:
+                # Keep request in the batch
+                generated_text = None
+                next_batch_keep_indices.append(i)
+                next_batch_input_ids.append(next_token_id)
+                next_batch_all_input_ids.append(all_input_ids)
+                next_batch_size += 1
+                next_batch_input_lengths.append(new_input_length)
+                next_batch_max_sequence_length = max(
+                    next_batch_max_sequence_length, new_input_length
+                )
+
+            # Prefill
+            if stopping_criteria.current_tokens == 1:
+                # Remove generated token to only have prefill and add nan for first prompt token
+                prefill_logprobs = [float("nan")] + logprobs.gather(
+                    1, all_input_ids[1:]
+                ).squeeze(1)[-new_input_length:-1].tolist()
+                prefill_token_ids = all_input_ids[-new_input_length:-1]
+                prefill_texts = self.tokenizer.batch_decode(
+                    prefill_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                prefill_tokens = PrefillTokens(
+                    prefill_token_ids, prefill_logprobs, prefill_texts
+                )
+            else:
+                prefill_tokens = None
+
+            generation = Generation(
+                request.id,
+                prefill_tokens,
+                next_token_id_squeezed,
+                next_token_logprob,
+                next_token_text,
+                generated_text,
+            )
+
+            generations.append(generation)

         # We finished all generations in the batch; there is no next batch
         if not next_batch_keep_indices:

@@ -328,7 +328,6 @@ class Seq2SeqLM(Model):
     def decode(self, decoder_ids: List[int]) -> str:
         return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)

-    @tracer.start_as_current_span("forward")
     def forward(
         self,
         input_ids,

@@ -407,94 +406,91 @@ class Seq2SeqLM(Model):
             batch.decoder_input_ids,
         )

-        with tracer.start_as_current_span("post_processing"):
-            # For each member of the batch
-            for i, (
-                request,
-                input_length,
-                decoder_input_length,
-                logits,
-                next_token_chooser,
-                stopping_criteria,
-                input_tokens,
-                decoder_input_ids,
-            ) in enumerate(iterator):
-                # Select next token
-                next_token_id, logprobs = next_token_chooser(
-                    decoder_input_ids.view(1, -1), logits
-                )
-
-                # Append next token to decoder tokens
-                decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
-                new_decoder_input_length = decoder_input_length + 1
-
-                # Generated token
-                next_token_logprob = logprobs[-1, next_token_id]
-                next_token_id_squeezed = next_token_id.squeeze()
-                next_token_text = self.tokenizer.decode(
-                    next_token_id_squeezed,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
-
-                # Evaluate stopping criteria
-                stop, reason = stopping_criteria(next_token_id, next_token_text)
-
-                if stop:
-                    # Slice with decoder_input_length to remove padding
-                    # Decode all tokens
-                    output_text = self.decode(
-                        decoder_input_ids[-new_decoder_input_length:]
-                    )
-
-                    # Get seed
-                    if isinstance(next_token_chooser.choice, Sampling):
-                        seed = next_token_chooser.choice.seed
-                    else:
-                        seed = None
-
-                    generated_text = GeneratedText(
-                        output_text, stopping_criteria.current_tokens, reason, seed
-                    )
-                else:
-                    # Keep request in the batch
-                    generated_text = None
-                    next_batch_keep_indices.append(i)
-                    next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
-                    next_batch_size += 1
-                    next_batch_input_lengths.append(input_length)
-                    next_batch_decoder_input_lengths.append(new_decoder_input_length)
-                    next_batch_max_input_length = max(
-                        next_batch_max_input_length, input_length
-                    )
-                    next_batch_max_decoder_input_length = max(
-                        next_batch_max_decoder_input_length, new_decoder_input_length
-                    )
-
-                # Prefill
-                if stopping_criteria.current_tokens == 1:
-                    prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
-                    prefill_texts = self.tokenizer.batch_decode(
-                        prefill_token_ids,
-                        clean_up_tokenization_spaces=False,
-                        skip_special_tokens=False,
-                    )
-                    prefill_tokens = PrefillTokens(
-                        prefill_token_ids, [float("nan")], prefill_texts
-                    )
-                else:
-                    prefill_tokens = None
-
-                generation = Generation(
-                    request.id,
-                    prefill_tokens,
-                    next_token_id_squeezed,
-                    next_token_logprob,
-                    next_token_text,
-                    generated_text,
-                )
-
-                generations.append(generation)
+        # For each member of the batch
+        for i, (
+            request,
+            input_length,
+            decoder_input_length,
+            logits,
+            next_token_chooser,
+            stopping_criteria,
+            input_tokens,
+            decoder_input_ids,
+        ) in enumerate(iterator):
+            # Select next token
+            next_token_id, logprobs = next_token_chooser(
+                decoder_input_ids.view(1, -1), logits
+            )
+
+            # Append next token to decoder tokens
+            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
+            new_decoder_input_length = decoder_input_length + 1
+
+            # Generated token
+            next_token_logprob = logprobs[-1, next_token_id]
+            next_token_id_squeezed = next_token_id.squeeze()
+            next_token_text = self.tokenizer.decode(
+                next_token_id_squeezed,
+                clean_up_tokenization_spaces=False,
+                skip_special_tokens=False,
+            )
+
+            # Evaluate stopping criteria
+            stop, reason = stopping_criteria(next_token_id, next_token_text)
+
+            if stop:
+                # Slice with decoder_input_length to remove padding
+                # Decode all tokens
+                output_text = self.decode(decoder_input_ids[-new_decoder_input_length:])
+
+                # Get seed
+                if isinstance(next_token_chooser.choice, Sampling):
+                    seed = next_token_chooser.choice.seed
+                else:
+                    seed = None
+
+                generated_text = GeneratedText(
+                    output_text, stopping_criteria.current_tokens, reason, seed
+                )
+            else:
+                # Keep request in the batch
+                generated_text = None
+                next_batch_keep_indices.append(i)
+                next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
+                next_batch_size += 1
+                next_batch_input_lengths.append(input_length)
+                next_batch_decoder_input_lengths.append(new_decoder_input_length)
+                next_batch_max_input_length = max(
+                    next_batch_max_input_length, input_length
+                )
+                next_batch_max_decoder_input_length = max(
+                    next_batch_max_decoder_input_length, new_decoder_input_length
+                )
+
+            # Prefill
+            if stopping_criteria.current_tokens == 1:
+                prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
+                prefill_texts = self.tokenizer.batch_decode(
+                    prefill_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                prefill_tokens = PrefillTokens(
+                    prefill_token_ids, [float("nan")], prefill_texts
+                )
+            else:
+                prefill_tokens = None
+
+            generation = Generation(
+                request.id,
+                prefill_tokens,
+                next_token_id_squeezed,
+                next_token_logprob,
+                next_token_text,
+                generated_text,
+            )
+
+            generations.append(generation)

         # We finished all generations in the batch; there is no next batch
         if not next_batch_keep_indices:

@@ -36,16 +36,14 @@ class Sampling:
         self.seed = seed

     def __call__(self, logits):
-        probs = torch.nn.functional.softmax(logits, dim=-1)
-        next_tokens = torch.multinomial(
-            probs, num_samples=1, generator=self.generator
-        ).squeeze(1)
+        probs = torch.nn.functional.softmax(logits)
+        next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
         return next_tokens


 class Greedy:
     def __call__(self, logits):
-        return logits.argmax(dim=-1)
+        return logits.argmax()


 class NextTokenChooser:

@@ -87,8 +85,9 @@ class NextTokenChooser:
         logprobs = torch.log_softmax(scores, -1)

         # Choose tokens
-        next_ids = self.choice(scores)
-        return next_ids, logprobs
+        next_id = self.choice(scores[-1])
+
+        return next_id.view(1, 1), logprobs

     @classmethod
     def from_pb(
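For completeness, a hedged sketch of the sampling path once it receives only the last row of logits, as in the Sampling hunk above. The class below is a simplified stand-in with an assumed vocabulary size, not the code in this repository, and it keeps an explicit dim=-1 where the diff relies on the softmax default.

# Hedged sketch (assumed shapes, not the repository code): sampling from a
# single row of logits, so no dim juggling or squeeze is needed.
import torch


class Sampling:
    def __init__(self, seed: int = 0):
        self.generator = torch.Generator()
        self.generator.manual_seed(seed)
        self.seed = seed

    def __call__(self, logits):
        # logits is 1-D (vocab_size,) after "use only last logits".
        probs = torch.nn.functional.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1, generator=self.generator)


if __name__ == "__main__":
    last_logits = torch.randn(32)  # scores[-1] for an assumed 32-token vocabulary
    print(Sampling(seed=0)(last_logits))  # tensor with a single sampled token id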