mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
use only last logits
This commit is contained in:
parent
189fba28d1
commit
e9441a1ea2
@ -38,7 +38,8 @@ function sample_example(inputs, max_new_tokens, name) {
|
|||||||
parameters: {
|
parameters: {
|
||||||
max_new_tokens: max_new_tokens,
|
max_new_tokens: max_new_tokens,
|
||||||
do_sample: true,
|
do_sample: true,
|
||||||
top_p: 0.9
|
top_p: 0.9,
|
||||||
|
seed: 0
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
let params = {
|
let params = {
|
||||||
|
@ -66,7 +66,7 @@ impl Client {
|
|||||||
///
|
///
|
||||||
/// Returns Generation for each request in batch
|
/// Returns Generation for each request in batch
|
||||||
/// and the next cached batch
|
/// and the next cached batch
|
||||||
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
|
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size, max_sequence_length = batch.requests.iter().map(|request| request.input_length).max()))]
|
||||||
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
|
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
|
||||||
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
|
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
|
||||||
let response = self.stub.prefill(request).await?.into_inner();
|
let response = self.stub.prefill(request).await?.into_inner();
|
||||||
@ -77,7 +77,7 @@ impl Client {
|
|||||||
///
|
///
|
||||||
/// Returns Generation for each request in batches
|
/// Returns Generation for each request in batches
|
||||||
/// and the next cached batch
|
/// and the next cached batch
|
||||||
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
|
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>(), max_sequence_length = batches.iter().map(|batch|{batch.requests.iter().map(|request| request.input_length).max()}).max()))]
|
||||||
pub async fn decode(
|
pub async fn decode(
|
||||||
&mut self,
|
&mut self,
|
||||||
batches: Vec<Batch>,
|
batches: Vec<Batch>,
|
||||||
|
@ -53,7 +53,7 @@ impl ShardedClient {
|
|||||||
///
|
///
|
||||||
/// Returns Generation for each request in batch
|
/// Returns Generation for each request in batch
|
||||||
/// and the next cached batch
|
/// and the next cached batch
|
||||||
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
|
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size, max_sequence_length = batch.requests.iter().map(|request| request.input_length).max()))]
|
||||||
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
|
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
|
||||||
let futures: Vec<_> = self
|
let futures: Vec<_> = self
|
||||||
.clients
|
.clients
|
||||||
@ -69,7 +69,7 @@ impl ShardedClient {
|
|||||||
///
|
///
|
||||||
/// Returns Generation for each request in batches
|
/// Returns Generation for each request in batches
|
||||||
/// and the next cached batch
|
/// and the next cached batch
|
||||||
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
|
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>(), max_sequence_length = batches.iter().map(|batch|{batch.requests.iter().map(|request| request.input_length).max()}).max()))]
|
||||||
pub async fn decode(
|
pub async fn decode(
|
||||||
&mut self,
|
&mut self,
|
||||||
batches: Vec<Batch>,
|
batches: Vec<Batch>,
|
||||||
|
@ -96,7 +96,7 @@ impl Infer {
|
|||||||
request: valid_request,
|
request: valid_request,
|
||||||
response_tx,
|
response_tx,
|
||||||
span: Span::current(),
|
span: Span::current(),
|
||||||
batch_span: None,
|
temp_span: None,
|
||||||
queue_time: Instant::now(),
|
queue_time: Instant::now(),
|
||||||
batch_time: None,
|
batch_time: None,
|
||||||
_permit: permit,
|
_permit: permit,
|
||||||
@ -228,6 +228,18 @@ async fn batching_task(
|
|||||||
.next_batch(min_size, max_batch_size - batch_size as usize)
|
.next_batch(min_size, max_batch_size - batch_size as usize)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
|
let new_batch_size = new_batch.size;
|
||||||
|
entries.iter_mut().for_each(|(_, entry)| {
|
||||||
|
// Create a new span to add the info that this entry is waiting
|
||||||
|
// because a new batch is being computed
|
||||||
|
let entry_waiting_span =
|
||||||
|
info_span!(parent: &entry.span, "waiting", batch_size = new_batch_size);
|
||||||
|
// Add relationship
|
||||||
|
entry_waiting_span.follows_from(&span);
|
||||||
|
// Update entry
|
||||||
|
entry.temp_span = Some(entry_waiting_span);
|
||||||
|
});
|
||||||
|
|
||||||
// Generate one token for this new batch to have the attention past in cache
|
// Generate one token for this new batch to have the attention past in cache
|
||||||
let new_cached_batch =
|
let new_cached_batch =
|
||||||
wrap_future(client.prefill(new_batch), &mut new_entries)
|
wrap_future(client.prefill(new_batch), &mut new_entries)
|
||||||
@ -253,7 +265,7 @@ async fn batching_task(
|
|||||||
// Add relationship
|
// Add relationship
|
||||||
entry_batch_span.follows_from(&next_batch_span);
|
entry_batch_span.follows_from(&next_batch_span);
|
||||||
// Update entry
|
// Update entry
|
||||||
entry.batch_span = Some(entry_batch_span);
|
entry.temp_span = Some(entry_batch_span);
|
||||||
});
|
});
|
||||||
|
|
||||||
cached_batch = wrap_future(client.decode(batches), &mut entries)
|
cached_batch = wrap_future(client.decode(batches), &mut entries)
|
||||||
@ -289,7 +301,7 @@ async fn wrap_future(
|
|||||||
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
||||||
entries.drain().for_each(|(_, entry)| {
|
entries.drain().for_each(|(_, entry)| {
|
||||||
// Create and enter a span to link this function back to the entry
|
// Create and enter a span to link this function back to the entry
|
||||||
let _send_error_span = info_span!(parent: entry.batch_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
||||||
let err = InferError::GenerationError(error.to_string());
|
let err = InferError::GenerationError(error.to_string());
|
||||||
tracing::error!("{err}");
|
tracing::error!("{err}");
|
||||||
|
|
||||||
@ -312,7 +324,7 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
|
|||||||
.expect("ID not found in entries. This is a bug.");
|
.expect("ID not found in entries. This is a bug.");
|
||||||
|
|
||||||
// Create and enter a span to link this function back to the entry
|
// Create and enter a span to link this function back to the entry
|
||||||
let _generation_span = info_span!(parent: entry.batch_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
|
let _generation_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
|
||||||
|
|
||||||
if let Some(prefill_tokens) = generation.prefill_tokens {
|
if let Some(prefill_tokens) = generation.prefill_tokens {
|
||||||
// Send message
|
// Send message
|
||||||
|
@ -18,9 +18,8 @@ pub(crate) struct Entry {
|
|||||||
pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
|
pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
|
||||||
/// Span that will live as long as entry
|
/// Span that will live as long as entry
|
||||||
pub span: Span,
|
pub span: Span,
|
||||||
/// Span for every inference batch
|
/// Temporary span used as a guard when logging inference, wait times...
|
||||||
/// This span will only live as long as one prefill/decode
|
pub temp_span: Option<Span>,
|
||||||
pub batch_span: Option<Span>,
|
|
||||||
/// Instant when this entry was queued
|
/// Instant when this entry was queued
|
||||||
pub queue_time: Instant,
|
pub queue_time: Instant,
|
||||||
/// Instant when this entry was added to a batch
|
/// Instant when this entry was added to a batch
|
||||||
@ -125,7 +124,12 @@ impl State {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Append an entry to the queue
|
/// Append an entry to the queue
|
||||||
fn append(&mut self, entry: Entry) {
|
fn append(&mut self, mut entry: Entry) {
|
||||||
|
// Create a span that will live as long as the entry is in the queue waiting to be batched
|
||||||
|
let queue_span = info_span!(parent: &entry.span, "queued");
|
||||||
|
entry.temp_span = Some(queue_span);
|
||||||
|
|
||||||
|
// Push entry in the queue
|
||||||
self.entries.push((self.next_id, entry));
|
self.entries.push((self.next_id, entry));
|
||||||
self.next_id += 1;
|
self.next_id += 1;
|
||||||
}
|
}
|
||||||
@ -163,7 +167,7 @@ impl State {
|
|||||||
// Add relationship
|
// Add relationship
|
||||||
entry_batch_span.follows_from(&next_batch_span);
|
entry_batch_span.follows_from(&next_batch_span);
|
||||||
// Update entry
|
// Update entry
|
||||||
entry.batch_span = Some(entry_batch_span);
|
entry.temp_span = Some(entry_batch_span);
|
||||||
|
|
||||||
batch_requests.push(Request {
|
batch_requests.push(Request {
|
||||||
id,
|
id,
|
||||||
@ -235,7 +239,7 @@ mod tests {
|
|||||||
},
|
},
|
||||||
response_tx,
|
response_tx,
|
||||||
span: info_span!("entry"),
|
span: info_span!("entry"),
|
||||||
batch_span: None,
|
temp_span: None,
|
||||||
queue_time: Instant::now(),
|
queue_time: Instant::now(),
|
||||||
batch_time: None,
|
batch_time: None,
|
||||||
_permit: permit,
|
_permit: permit,
|
||||||
|
@ -76,7 +76,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
|
|||||||
queue_time,
|
queue_time,
|
||||||
inference_time,
|
inference_time,
|
||||||
time_per_token,
|
time_per_token,
|
||||||
seed
|
seed,
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
async fn generate(
|
async fn generate(
|
||||||
@ -179,7 +179,8 @@ async fn generate(
|
|||||||
validation_time,
|
validation_time,
|
||||||
queue_time,
|
queue_time,
|
||||||
inference_time,
|
inference_time,
|
||||||
time_per_token
|
time_per_token,
|
||||||
|
seed,
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
async fn generate_stream(
|
async fn generate_stream(
|
||||||
@ -241,13 +242,11 @@ async fn generate_stream(
|
|||||||
|
|
||||||
// Tracing metadata
|
// Tracing metadata
|
||||||
span.record("total_time", format!("{:?}", total_time));
|
span.record("total_time", format!("{:?}", total_time));
|
||||||
span
|
span.record("validation_time", format!("{:?}", validation_time));
|
||||||
.record("validation_time", format!("{:?}", validation_time));
|
|
||||||
span.record("queue_time", format!("{:?}", queue_time));
|
span.record("queue_time", format!("{:?}", queue_time));
|
||||||
span
|
span.record("inference_time", format!("{:?}", inference_time));
|
||||||
.record("inference_time", format!("{:?}", inference_time));
|
span.record("time_per_token", format!("{:?}", time_per_token));
|
||||||
span
|
span.record("seed", format!("{:?}", generated_text.seed));
|
||||||
.record("time_per_token", format!("{:?}", time_per_token));
|
|
||||||
tracing::info!(parent: &span, "Output: {}", generated_text.text);
|
tracing::info!(parent: &span, "Output: {}", generated_text.text);
|
||||||
|
|
||||||
// StreamResponse
|
// StreamResponse
|
||||||
|
@ -277,7 +277,6 @@ class CausalLM(Model):
|
|||||||
generated_ids, skip_special_tokens=True, cleanup_tokenization_spaces=False
|
generated_ids, skip_special_tokens=True, cleanup_tokenization_spaces=False
|
||||||
)
|
)
|
||||||
|
|
||||||
@tracer.start_as_current_span("forward")
|
|
||||||
def forward(
|
def forward(
|
||||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||||
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||||
@ -327,7 +326,6 @@ class CausalLM(Model):
|
|||||||
batch.all_input_ids,
|
batch.all_input_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
with tracer.start_as_current_span("post_processing"):
|
|
||||||
# For each member of the batch
|
# For each member of the batch
|
||||||
for i, (
|
for i, (
|
||||||
request,
|
request,
|
||||||
@ -338,8 +336,9 @@ class CausalLM(Model):
|
|||||||
all_input_ids,
|
all_input_ids,
|
||||||
) in enumerate(iterator):
|
) in enumerate(iterator):
|
||||||
# Select next token
|
# Select next token
|
||||||
tokens, logprobs = next_token_chooser(all_input_ids.view(1, -1), logits)
|
next_token_id, logprobs = next_token_chooser(
|
||||||
next_token_id = tokens[-1].view(1, 1)
|
all_input_ids.view(1, -1), logits
|
||||||
|
)
|
||||||
|
|
||||||
# Append next token to all tokens
|
# Append next token to all tokens
|
||||||
all_input_ids = torch.cat([all_input_ids, next_token_id])
|
all_input_ids = torch.cat([all_input_ids, next_token_id])
|
||||||
|
@ -328,7 +328,6 @@ class Seq2SeqLM(Model):
|
|||||||
def decode(self, decoder_ids: List[int]) -> str:
|
def decode(self, decoder_ids: List[int]) -> str:
|
||||||
return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)
|
return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
@tracer.start_as_current_span("forward")
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids,
|
input_ids,
|
||||||
@ -407,7 +406,6 @@ class Seq2SeqLM(Model):
|
|||||||
batch.decoder_input_ids,
|
batch.decoder_input_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
with tracer.start_as_current_span("post_processing"):
|
|
||||||
# For each member of the batch
|
# For each member of the batch
|
||||||
for i, (
|
for i, (
|
||||||
request,
|
request,
|
||||||
@ -443,9 +441,7 @@ class Seq2SeqLM(Model):
|
|||||||
if stop:
|
if stop:
|
||||||
# Slice with decoder_input_length to remove padding
|
# Slice with decoder_input_length to remove padding
|
||||||
# Decode all tokens
|
# Decode all tokens
|
||||||
output_text = self.decode(
|
output_text = self.decode(decoder_input_ids[-new_decoder_input_length:])
|
||||||
decoder_input_ids[-new_decoder_input_length:]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get seed
|
# Get seed
|
||||||
if isinstance(next_token_chooser.choice, Sampling):
|
if isinstance(next_token_chooser.choice, Sampling):
|
||||||
|
@ -36,16 +36,14 @@ class Sampling:
|
|||||||
self.seed = seed
|
self.seed = seed
|
||||||
|
|
||||||
def __call__(self, logits):
|
def __call__(self, logits):
|
||||||
probs = torch.nn.functional.softmax(logits, dim=-1)
|
probs = torch.nn.functional.softmax(logits)
|
||||||
next_tokens = torch.multinomial(
|
next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
|
||||||
probs, num_samples=1, generator=self.generator
|
|
||||||
).squeeze(1)
|
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
|
|
||||||
class Greedy:
|
class Greedy:
|
||||||
def __call__(self, logits):
|
def __call__(self, logits):
|
||||||
return logits.argmax(dim=-1)
|
return logits.argmax()
|
||||||
|
|
||||||
|
|
||||||
class NextTokenChooser:
|
class NextTokenChooser:
|
||||||
@ -87,8 +85,9 @@ class NextTokenChooser:
|
|||||||
logprobs = torch.log_softmax(scores, -1)
|
logprobs = torch.log_softmax(scores, -1)
|
||||||
|
|
||||||
# Choose tokens
|
# Choose tokens
|
||||||
next_ids = self.choice(scores)
|
next_id = self.choice(scores[-1])
|
||||||
return next_ids, logprobs
|
|
||||||
|
return next_id.view(1, 1), logprobs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pb(
|
def from_pb(
|
||||||
|
Loading…
Reference in New Issue
Block a user