Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
re-create slots
commit  b49978ff67
parent  4db5e7dde6
@@ -237,7 +237,11 @@ impl Client {
         &mut self,
         batches: Vec<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
-        let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
+        let request = tonic::Request::new(DecodeRequest {
+            batch: None,
+            batches,
+        })
+        .inject_context();
         let response = self.stub.decode(request).await?.into_inner();
         Ok((
             response.generations,
@@ -200,6 +200,8 @@ pub(crate) async fn batching_task(
                (min_size, max_size, max_batch_prefill_tokens)
            };

+           let mut additional_batch = None;
+
            // Try to get a new batch
            if let Some((mut new_entries, new_batch, span)) = queue
                .next_batch(min_size, max_size, prefill_token_budget, token_budget)
@@ -209,11 +211,19 @@ pub(crate) async fn batching_task(
                if min_size.is_some() {
                    metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
                        .increment(1);
                } else {
+                   let counter = if support_chunking {
+                       metrics::counter!("tgi_batch_concat", "reason" => "chunking")
+                   } else {
                    metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
-                       .increment(1);
+                   };
+                   counter.increment(1);
                }

+               if support_chunking {
+                   entries.extend(new_entries);
+                   additional_batch = Some(new_batch);
+               } else {
                entries.iter_mut().for_each(|(_, entry)| {
                    // Create a new span to add the info that this entry is waiting
                    // because a new batch is being computed
@@ -237,6 +247,7 @@ pub(crate) async fn batching_task(
                        batches.push(new_cached_batch);
                    }
+                   }
                }
            }

            // Create span for this batch to add context to inference calls
            let next_batch_size = entries.len();
@@ -252,7 +263,7 @@ pub(crate) async fn batching_task(
                entry.temp_span = Some(entry_batch_span);
            });

-           cached_batch = decode(&mut client, batches, &mut entries)
+           cached_batch = decode(&mut client, additional_batch, batches, &mut entries)
                .instrument(next_batch_span)
                .await;
            waiting_tokens += 1;
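Taken together, the batching_task hunks above amount to the control flow sketched below. This is only an illustrative sketch, not code from the commit: the prefill path, span handling, and metrics are elided, and identifiers such as queue, entries, batches, support_chunking, and next_batch_span come from the surrounding router code.

    // Illustrative sketch only; elided parts are marked with comments.
    let mut additional_batch = None;

    if let Some((new_entries, new_batch, _span)) = queue
        .next_batch(min_size, max_size, prefill_token_budget, token_budget)
        .await
    {
        if support_chunking {
            // With chunking support, the fresh batch is not prefilled separately:
            // its entries join the running ones and the raw batch is kept so it
            // can be handed to the decode call below.
            entries.extend(new_entries);
            additional_batch = Some(new_batch);
        } else {
            // Without chunking, the pre-existing behaviour applies: the new batch
            // is prefilled on its own and its CachedBatch is pushed into `batches`
            // (elided here).
        }
    }

    // The optional fresh batch now travels alongside the cached batches.
    cached_batch = decode(&mut client, additional_batch, batches, &mut entries)
        .instrument(next_batch_span)
        .await;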
@@ -306,6 +317,7 @@ async fn prefill(
 #[instrument(skip_all)]
 async fn decode(
     client: &mut ShardedClient,
+    batch: Option<Batch>,
     batches: Vec<CachedBatch>,
     entries: &mut IntMap<u64, Entry>,
 ) -> Option<CachedBatch> {
@@ -313,7 +325,7 @@ async fn decode(
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
     metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);

-    match client.decode(batches).await {
+    match client.decode(batch, batches).await {
         Ok((generations, next_batch, timings)) => {
             let start_filtering_time = Instant::now();
             // Send generated tokens and filter stopped entries
@@ -235,9 +235,10 @@ impl Client {
     #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
     pub async fn decode(
         &mut self,
+        batch: Option<Batch>,
         batches: Vec<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
-        let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
+        let request = tonic::Request::new(DecodeRequest { batches, batch }).inject_context();
         let response = self.stub.decode(request).await?.into_inner();
         Ok((
             response.generations,
@@ -167,12 +167,13 @@ impl ShardedClient {
     #[instrument(skip_all, fields(size = batches.iter().map(|batch| {batch.size}).sum::<u32>()))]
     pub async fn decode(
         &mut self,
+        batch: Option<Batch>,
         batches: Vec<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.decode(batches.clone())))
+            .map(|client| Box::pin(client.decode(batch.clone(), batches.clone())))
             .collect();
         #[allow(clippy::type_complexity)]
         let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
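For orientation, a call into the extended ShardedClient::decode might look as follows. This is a hypothetical usage sketch (sharded_client, fresh_batch, and cached_batches are made-up names), not part of the commit; the optional batch is cloned once per shard, mirroring how batches is already cloned for each client.

    // Hypothetical call site; callers with no new batch to prefill during
    // decode simply pass `None` as the first argument.
    let (generations, next_batch, timings) = sharded_client
        .decode(Some(fresh_batch), cached_batches)
        .await?;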
@@ -243,6 +243,8 @@ message PrefillResponse {
 message DecodeRequest {
     /// Cached batches
     repeated CachedBatch batches = 1;
+    /// Optional Batch
+    optional Batch batch = 2;
 }

 message DecodeResponse {
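On the Rust side, the regenerated type for DecodeRequest would then carry the new field roughly as sketched below. This is an assumption about the default prost codegen for an optional message field (with type paths abbreviated), not generated output shipped in the commit.

    // Rough shape of the regenerated message (assumed prost output).
    #[derive(Clone, PartialEq, ::prost::Message)]
    pub struct DecodeRequest {
        /// Cached batches
        #[prost(message, repeated, tag = "1")]
        pub batches: Vec<CachedBatch>,
        /// Optional Batch
        #[prost(message, optional, tag = "2")]
        pub batch: Option<Batch>,
    }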
@@ -179,6 +179,16 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         if len(batches) == 0:
             raise ValueError("All batches are empty")

+        if self.model.support_chunking:
+            if request.HasField("batch"):
+                batch = self.model.batch_type.from_pb(
+                    request.batch,
+                    self.model.tokenizer,
+                    self.model.dtype,
+                    self.model.device,
+                )
+                batches.append(batch)
+
         if len(batches) > 1:
             start_concat = time.time_ns()
             batch = self.model.batch_type.concatenate(batches)