feedback loop

OlivierDehaene 2024-10-07 12:02:25 +02:00
parent ff4155dfea
commit c8a033b636
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C
16 changed files with 153 additions and 163 deletions

View File

@ -218,8 +218,13 @@ impl Client {
pub async fn prefill(
&mut self,
batch: Batch,
cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
let request = tonic::Request::new(PrefillRequest {
batch: Some(batch),
cached_batch,
})
.inject_context();
let response = self.stub.prefill(request).await?.into_inner();
Ok((
response.generations,
@ -237,11 +242,7 @@ impl Client {
&mut self,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let request = tonic::Request::new(DecodeRequest {
batch: None,
batches,
})
.inject_context();
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
let response = self.stub.decode(request).await?.into_inner();
Ok((
response.generations,
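Illustrative sketch (not part of the diff): the prefill request now optionally carries the cached batch from the previous chunk so the server can concatenate it before running the next prefill step. The structs below are simplified stand-ins for the generated tonic/protobuf types, assumed here for illustration only.

// Simplified stand-ins for the generated protobuf messages.
#[derive(Clone, Debug)]
struct Batch { id: u64, size: u32 }
#[derive(Clone, Debug)]
struct CachedBatch { id: u64 }

#[derive(Debug)]
struct PrefillRequest {
    batch: Option<Batch>,
    // New field: the previously cached batch, if any, is sent along so the
    // server can concatenate it with the incoming chunk before prefilling.
    cached_batch: Option<CachedBatch>,
}

fn main() {
    // First prefill chunk of a request: nothing cached yet.
    let first = PrefillRequest { batch: Some(Batch { id: 0, size: 8 }), cached_batch: None };
    // Later chunk: feed back the cached batch returned by the previous prefill.
    let next = PrefillRequest {
        batch: Some(Batch { id: 1, size: 8 }),
        cached_batch: Some(CachedBatch { id: 0 }),
    };
    println!("{first:?}\n{next:?}");
}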

View File

@ -134,11 +134,12 @@ impl ShardedClient {
pub async fn prefill(
&mut self,
batch: Batch,
cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone())))
.map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
@ -256,7 +257,7 @@ impl Health for ShardedClient {
max_tokens: 2,
max_blocks: 1,
};
self.clone().prefill(batch).await?;
self.clone().prefill(batch, None).await?;
Ok(())
}
}
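A minimal, synchronous sketch of the fan-out above (the Shard type and its prefill signature are assumptions; the real code uses async clients and Box::pin futures): every shard receives a clone of the same batch and the same optional cached batch, and the health check simply passes None.

#[derive(Clone)]
struct Batch { id: u64 }
#[derive(Clone)]
struct CachedBatch { id: u64 }

struct Shard { rank: usize }

impl Shard {
    // Stand-in for client.prefill(batch, cached_batch).
    fn prefill(&mut self, batch: Batch, cached: Option<CachedBatch>) -> (usize, u64, Option<u64>) {
        (self.rank, batch.id, cached.map(|c| c.id))
    }
}

fn main() {
    let mut shards: Vec<Shard> = (0..4).map(|rank| Shard { rank }).collect();
    let batch = Batch { id: 1 };
    let cached = Some(CachedBatch { id: 0 });
    // The same (batch, cached_batch) pair is cloned to every shard.
    let results: Vec<_> = shards
        .iter_mut()
        .map(|s| s.prefill(batch.clone(), cached.clone()))
        .collect();
    // Only the first shard's result is kept, as in ShardedClient::prefill.
    println!("{:?}", results.into_iter().next());
}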

View File

@ -6,7 +6,7 @@ use nohash_hasher::IntMap;
use std::sync::Arc;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
use text_generation_router::{FinishReason, PrefillToken, Token};
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify};
use tokio::time::Instant;
@ -36,18 +36,14 @@ impl BackendV2 {
speculate: u32,
) -> Self {
// Infer shared state
let attention = if let Ok(attention) = std::env::var("ATTENTION") {
attention
.parse()
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
} else {
Attention::Paged
};
let block_size = if attention == Attention::FlashDecoding {
256
} else {
16
let attention = std::env::var("ATTENTION").unwrap_or("paged".to_string());
let block_size = match attention.as_str() {
"flashinfer" => 1,
"flashdecoding" => 256,
"paged" => 16,
_ => unreachable!(),
};
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
let batching_task_notifier = Arc::new(Notify::new());
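Standalone sketch of the block-size selection above. The environment variable name and the attention/block-size pairs come from the diff; the helper function and the panic message are illustrative assumptions (the backend itself uses unreachable!() because the value is validated elsewhere).

use std::env;

// Map the ATTENTION implementation to its KV-cache block size.
fn block_size_for(attention: &str) -> u32 {
    match attention {
        "flashinfer" => 1,
        "flashdecoding" => 256,
        "paged" => 16,
        other => panic!("Invalid attention was specified: `{other}`"),
    }
}

fn main() {
    let attention = env::var("ATTENTION").unwrap_or_else(|_| "paged".to_string());
    println!("block_size = {}", block_size_for(&attention));
}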

View File

@ -1,12 +1,14 @@
use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
/// Batching and inference logic
use crate::client::{
Batch, CachedBatch, ClientError, Generation, Health, InfoResponse, ShardedClient,
};
use crate::queue::{Entry, Queue};
use async_trait::async_trait;
use nohash_hasher::IntMap;
use std::sync::Arc;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
use text_generation_router::{FinishReason, PrefillToken, Token};
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify};
use tokio::time::Instant;
@ -31,32 +33,22 @@ impl BackendV3 {
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
requires_padding: bool,
window_size: Option<u32>,
speculate: u32,
support_chunking: bool,
shard_info: InfoResponse,
) -> Self {
if support_chunking {
if shard_info.support_chunking {
tracing::warn!("Model supports prefill chunking. `waiting_served_ratio` and `max_waiting_tokens` will be ignored.");
}
let prefix_caching = std::env::var("USE_PREFIX_CACHING").unwrap_or("1".to_string());
let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
let attention: String = std::env::var("ATTENTION").unwrap_or("flashinfer".to_string());
let attention: Attention = attention
.parse()
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
let block_size = attention.block_size();
let block_size = shard_info.block_size;
let queue = Queue::new(
requires_padding,
shard_info.requires_padding,
block_size,
prefix_caching,
window_size,
speculate,
shard_info.use_prefix_caching,
shard_info.window_size,
shard_info.speculate,
max_batch_total_tokens,
support_chunking,
shard_info.support_chunking,
);
let batching_task_notifier = Arc::new(Notify::new());
@ -68,7 +60,7 @@ impl BackendV3 {
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
support_chunking,
shard_info.support_chunking,
queue.clone(),
batching_task_notifier.clone(),
));
@ -154,7 +146,7 @@ pub(crate) async fn batching_task(
)
.await
{
let mut cached_batch = prefill(&mut client, batch, &mut entries)
let mut cached_batch = prefill(&mut client, batch, None, &mut entries)
.instrument(span)
.await;
let mut waiting_tokens = 1;
@ -175,7 +167,8 @@ pub(crate) async fn batching_task(
let (min_size, max_size, prefill_token_budget) = if support_chunking {
// Since the next batch will be concatenated with the current batch,
// the current batch tokens must be subtracted from the prefill budget
let prefill_token_budget = max_batch_prefill_tokens - current_tokens;
let prefill_token_budget =
max_batch_prefill_tokens.saturating_sub(current_tokens);
// We can ignore min_size and max_size
// Models that rely on max_size cannot support chunking
// Regarding min_size, chunking allows us to consistently run at the compute
@ -199,10 +192,8 @@ pub(crate) async fn batching_task(
(min_size, max_size, max_batch_prefill_tokens)
};
let mut additional_batch = None;
// Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = queue
if let Some((new_entries, new_batch, span)) = queue
.next_batch(min_size, max_size, prefill_token_budget, token_budget)
.await
{
@ -218,11 +209,11 @@ pub(crate) async fn batching_task(
};
counter.increment(1);
}
if support_chunking {
entries.extend(new_entries);
additional_batch = Some(new_batch);
let cached_batch = if support_chunking {
// Concat current batch to the new one
batches.pop()
} else {
// Requests are waiting only if we don't support chunking
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to add the info that this entry is waiting
// because a new batch is being computed
@ -233,18 +224,23 @@ pub(crate) async fn batching_task(
// Update entry
entry.temp_span = Some(entry_waiting_span);
});
None
};
entries.extend(new_entries);
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
let new_cached_batch =
prefill(&mut client, new_batch, cached_batch, &mut entries)
.instrument(span)
.await;
// Reset waiting counter
waiting_tokens = 1;
// Extend current batch with the new batch
if let Some(new_cached_batch) = new_cached_batch {
entries.extend(new_entries);
batches.push(new_cached_batch);
}
} else if support_chunking {
// New cached batch is empty, no work left
break;
}
}
@ -262,7 +258,7 @@ pub(crate) async fn batching_task(
entry.temp_span = Some(entry_batch_span);
});
cached_batch = decode(&mut client, additional_batch, batches, &mut entries)
cached_batch = decode(&mut client, batches, &mut entries)
.instrument(next_batch_span)
.await;
waiting_tokens += 1;
@ -277,13 +273,14 @@ pub(crate) async fn batching_task(
async fn prefill(
client: &mut ShardedClient,
batch: Batch,
cached_batch: Option<CachedBatch>,
entries: &mut IntMap<u64, Entry>,
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_id = batch.id;
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
match client.prefill(batch).await {
match client.prefill(batch, cached_batch).await {
Ok((generations, next_batch, timings)) => {
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
@ -292,6 +289,10 @@ async fn prefill(
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
.record(concat_duration.as_secs_f64());
}
metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
@ -316,7 +317,6 @@ async fn prefill(
#[instrument(skip_all)]
async fn decode(
client: &mut ShardedClient,
batch: Option<Batch>,
batches: Vec<CachedBatch>,
entries: &mut IntMap<u64, Entry>,
) -> Option<CachedBatch> {
@ -324,7 +324,7 @@ async fn decode(
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
match client.decode(batch, batches).await {
match client.decode(batches).await {
Ok((generations, next_batch, timings)) => {
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
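Illustration of the chunking budget logic introduced above, as a self-contained sketch with plain integers standing in for batches (the helper name is an assumption, not a real BackendV3 function):

// When chunking is supported, the next prefill chunk will be concatenated with
// the running batch, so its tokens are deducted from the prefill budget;
// saturating_sub clamps to zero instead of underflowing.
fn prefill_budget(max_batch_prefill_tokens: u32, current_tokens: u32, support_chunking: bool) -> u32 {
    if support_chunking {
        max_batch_prefill_tokens.saturating_sub(current_tokens)
    } else {
        max_batch_prefill_tokens
    }
}

fn main() {
    assert_eq!(prefill_budget(4096, 1024, true), 3072);
    // Running batch already exceeds the budget: clamp to 0 rather than panic.
    assert_eq!(prefill_budget(4096, 8192, true), 0);
    // Without chunking the budget is unchanged.
    assert_eq!(prefill_budget(4096, 8192, false), 4096);
    println!("ok");
}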

View File

@ -218,13 +218,23 @@ impl Client {
pub async fn prefill(
&mut self,
batch: Batch,
cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
let request = tonic::Request::new(PrefillRequest {
batch: Some(batch),
cached_batch,
})
.inject_context();
let response = self.stub.prefill(request).await?.into_inner();
Ok((
response.generations,
response.batch,
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
PrefillTimings::new(
response.concat_ns,
response.forward_ns,
response.decode_ns,
response.total_ns,
),
))
}
@ -235,10 +245,9 @@ impl Client {
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
pub async fn decode(
&mut self,
batch: Option<Batch>,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let request = tonic::Request::new(DecodeRequest { batches, batch }).inject_context();
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
let response = self.stub.decode(request).await?.into_inner();
Ok((
response.generations,
@ -254,14 +263,16 @@ impl Client {
}
pub struct PrefillTimings {
pub concat: Option<Duration>,
pub forward: Duration,
pub decode: Duration,
pub total: Duration,
}
impl PrefillTimings {
fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
Self {
concat: concat_ns.map(Duration::from_nanos),
forward: Duration::from_nanos(forward_ns),
decode: Duration::from_nanos(decode_ns),
total: Duration::from_nanos(total_ns),
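Self-contained sketch of how the optional concatenation time flows from the protobuf response (nanoseconds) into Durations; a simplified copy of the struct above, exercised with illustrative numbers:

use std::time::Duration;

struct PrefillTimings {
    concat: Option<Duration>,
    forward: Duration,
    decode: Duration,
    total: Duration,
}

impl PrefillTimings {
    fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
        Self {
            concat: concat_ns.map(Duration::from_nanos),
            forward: Duration::from_nanos(forward_ns),
            decode: Duration::from_nanos(decode_ns),
            total: Duration::from_nanos(total_ns),
        }
    }
}

fn main() {
    // First chunk of a request: nothing to concatenate yet.
    let t = PrefillTimings::new(None, 1_000_000, 500_000, 1_600_000);
    assert!(t.concat.is_none());
    // Later chunk: the server reports how long the concatenation took.
    let t = PrefillTimings::new(Some(200_000), 1_000_000, 500_000, 1_800_000);
    assert_eq!(t.concat, Some(Duration::from_nanos(200_000)));
    println!("forward={:?} decode={:?} total={:?}", t.forward, t.decode, t.total);
}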

View File

@ -135,11 +135,12 @@ impl ShardedClient {
pub async fn prefill(
&mut self,
batch: Batch,
cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone())))
.map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
@ -167,13 +168,12 @@ impl ShardedClient {
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
pub async fn decode(
&mut self,
batch: Option<Batch>,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.decode(batch.clone(), batches.clone())))
.map(|client| Box::pin(client.decode(batches.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
@ -246,7 +246,7 @@ impl Health for ShardedClient {
max_tokens: 2,
max_blocks: 1,
};
self.clone().prefill(batch).await?;
self.clone().prefill(batch, None).await?;
Ok(())
}
}

View File

@ -31,6 +31,12 @@ pub struct BackendInfo {
pub max_batch_size: Option<usize>,
#[schema(example = "false")]
pub support_chunking: bool,
#[schema(example = "false")]
pub prefix_caching: bool,
#[schema(example = "flashinfer")]
pub attention_impl: String,
#[schema(example = "1")]
pub block_size: u32,
}
#[allow(clippy::too_many_arguments)]
@ -113,6 +119,9 @@ pub async fn connect_backend(
model_dtype: shard_info.dtype.clone(),
speculate: shard_info.speculate as usize,
support_chunking: shard_info.support_chunking,
prefix_caching: shard_info.use_prefix_caching,
attention_impl: shard_info.attention_impl.clone(),
block_size: shard_info.block_size,
};
let backend = BackendV3::new(
@ -122,10 +131,7 @@ pub async fn connect_backend(
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
shard_info.requires_padding,
shard_info.window_size,
shard_info.speculate,
shard_info.support_chunking,
shard_info,
);
tracing::info!("Using backend V3");
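A minimal sketch of the new configuration flow: the router no longer re-derives attention, prefix caching, or block size from environment variables and instead copies them from the shard's InfoResponse. The trimmed-down structs below are stand-ins, keeping only the fields touched by this diff:

struct InfoResponse {
    support_chunking: bool,
    use_prefix_caching: bool,
    attention_impl: String,
    block_size: u32,
}

#[derive(Debug)]
struct BackendInfo {
    support_chunking: bool,
    prefix_caching: bool,
    attention_impl: String,
    block_size: u32,
}

fn backend_info(shard_info: &InfoResponse) -> BackendInfo {
    BackendInfo {
        support_chunking: shard_info.support_chunking,
        prefix_caching: shard_info.use_prefix_caching,
        attention_impl: shard_info.attention_impl.clone(),
        block_size: shard_info.block_size,
    }
}

fn main() {
    let shard_info = InfoResponse {
        support_chunking: true,
        use_prefix_caching: true,
        attention_impl: "flashinfer".to_string(),
        block_size: 1,
    };
    println!("{:?}", backend_info(&shard_info));
}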

View File

@ -89,6 +89,10 @@ impl Queue {
prefill_token_budget: u32,
token_budget: u32,
) -> Option<NextBatch> {
if prefill_token_budget == 0 || token_budget == 0 {
return None;
};
// Create response channel
let (response_sender, response_receiver) = oneshot::channel();
// Send next batch command to the background task managing the state
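Minimal sketch of the early-return guard added above (illustrative types): with a zero prefill or total token budget nothing can be scheduled, so the round-trip to the background state task is skipped entirely.

struct NextBatch;

fn next_batch(prefill_token_budget: u32, token_budget: u32) -> Option<NextBatch> {
    if prefill_token_budget == 0 || token_budget == 0 {
        return None;
    }
    // ... otherwise ask the background task for the next batch ...
    Some(NextBatch)
}

fn main() {
    assert!(next_batch(0, 4096).is_none());
    assert!(next_batch(4096, 0).is_none());
    assert!(next_batch(4096, 4096).is_some());
    println!("ok");
}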

View File

@ -174,7 +174,7 @@ async fn prefill(
// Run prefill
let start_time = Instant::now();
let (_, decode_batch, _) = client.prefill(batch.clone()).await?;
let (_, decode_batch, _) = client.prefill(batch.clone(), None).await?;
// Get latency
let latency = start_time.elapsed();

View File

@ -35,6 +35,9 @@ message InfoResponse {
optional uint32 window_size = 4;
uint32 speculate = 5;
bool support_chunking = 6;
bool use_prefix_caching = 7;
string attention_impl = 8;
uint32 block_size = 9;
}
/// Empty request
@ -225,6 +228,8 @@ message FilterBatchResponse {
message PrefillRequest {
/// Batch
Batch batch = 1;
/// Optional cached batch
CachedBatch cached_batch = 2;
}
message PrefillResponse {
@ -238,13 +243,13 @@ message PrefillResponse {
uint64 decode_ns = 4;
/// Total elapsed time in nanoseconds
uint64 total_ns = 5;
/// Concatenation elapsed time in nanoseconds
optional uint64 concat_ns = 6;
}
message DecodeRequest {
/// Cached batches
repeated CachedBatch batches = 1;
/// Optional Batch
optional Batch batch = 2;
}
message DecodeResponse {

View File

@ -18,45 +18,6 @@ use tracing::warn;
use utoipa::ToSchema;
use validation::Validation;
#[derive(PartialEq)]
pub enum Attention {
Paged,
FlashDecoding,
FlashInfer,
}
impl Attention {
pub fn block_size(&self) -> u32 {
match self {
Attention::FlashDecoding => 256,
Attention::FlashInfer => 1,
Attention::Paged => 16,
}
}
}
#[derive(Debug)]
pub struct ParseError;
impl std::fmt::Display for ParseError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Cannot parse attention value")
}
}
impl std::error::Error for ParseError {}
impl std::str::FromStr for Attention {
type Err = ParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"paged" => Ok(Attention::Paged),
"flashdecoding" => Ok(Attention::FlashDecoding),
"flashinfer" => Ok(Attention::FlashInfer),
_ => Err(ParseError),
}
}
}
/// Hub type
#[derive(Clone, Debug, Deserialize)]
pub struct HubModelInfo {

View File

@ -76,7 +76,7 @@ class CausalLMBatch(Batch):
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
current_tokens=len(self),
current_tokens=len(self.input_ids),
)
@classmethod

View File

@ -171,7 +171,7 @@ class FlashCausalLMBatch(Batch):
# Will be set by `generate_token` and reset after each prefill forward
prefill_cu_outlens: Optional[List[int]]
# Will be set by `generate_token` and reset after each prefill forward
prefill_tokens: List[Optional[Tokens]]
prefill_logprob_tokens: List[Optional[Tokens]]
# Prefixes
prefix_ids: List[List[int]]
@ -290,8 +290,7 @@ class FlashCausalLMBatch(Batch):
prefix_length <= prompt_length
), f"Prefix {prefix_length} vs input {prompt_length}"
if prefix_length == prompt_length:
assert prefix_length > 0
prefix_length -= 1
assert False, "unreachable"
if prefix_length + postfix_length < prompt_length:
# FIXME: speculate is not supported for context chunking at the moment
assert speculate == 0
@ -303,7 +302,9 @@ class FlashCausalLMBatch(Batch):
prefix_length : prefix_length + postfix_length
]
postfix_length = len(postfix_ids)
assert (
len(postfix_ids) == postfix_length
), "Rust and Python tokenizers are not aligned"
postfix_lengths.append(postfix_length)
prefix_offsets.append(prompt_length - 5)
@ -394,7 +395,7 @@ class FlashCausalLMBatch(Batch):
max_current_length=max_current_length,
prefilling=True,
prefilling_mask=[True] * len(pb.requests),
prefill_tokens=[None] * len(pb.requests),
prefill_logprob_tokens=[None] * len(pb.requests),
postfix_lengths=postfix_lengths,
prompt_lengths=prompt_lengths,
prefix_offsets=prefix_offsets,
@ -475,7 +476,7 @@ class FlashCausalLMBatch(Batch):
read_offsets = []
prefilling_mask = []
prefill_tokens = []
prefill_logprob_tokens = []
stopping_criterias = []
top_n_tokens = []
@ -518,7 +519,7 @@ class FlashCausalLMBatch(Batch):
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(self.top_n_tokens[idx])
prefill_tokens.append(self.prefill_tokens[idx])
prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx])
ADAPTER_TO_INDEX = get_adapter_to_index()
adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
@ -611,7 +612,7 @@ class FlashCausalLMBatch(Batch):
prefill_head_indices=None,
prefill_next_token_indices=None,
prefill_cu_outlens=None,
prefill_tokens=prefill_tokens,
prefill_logprob_tokens=prefill_logprob_tokens,
prompt_lengths=prompt_lengths,
prompt_lengths_tensor=prompt_lengths_tensor,
postfix_lengths=postfix_lengths,
@ -726,7 +727,7 @@ class FlashCausalLMBatch(Batch):
prefix_offsets = []
read_offsets = []
prefill_tokens = []
prefill_logprob_tokens = []
next_token_chooser_parameters = []
fsm_grammar_states = []
@ -814,7 +815,7 @@ class FlashCausalLMBatch(Batch):
prefix_offsets.extend(batch.prefix_offsets)
read_offsets.extend(batch.read_offsets)
prefill_tokens.extend(batch.prefill_tokens)
prefill_logprob_tokens.extend(batch.prefill_logprob_tokens)
next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
@ -869,7 +870,7 @@ class FlashCausalLMBatch(Batch):
prefill_head_indices=None,
prefill_next_token_indices=None,
prefill_cu_outlens=None,
prefill_tokens=prefill_tokens,
prefill_logprob_tokens=prefill_logprob_tokens,
prompt_lengths=prompt_lengths,
prompt_lengths_tensor=prompt_lengths_tensor,
postfix_lengths=postfix_lengths,
@ -1769,9 +1770,10 @@ class FlashCausalLM(Model):
if get_support_chunking():
next_prefilling_mask = []
# Budget in tokens for the next batch
# We remove len(batch) to always have enough space for at least a single decode
# for the remaining requests
batch_budget = get_max_prefill_tokens() - len(batch)
# We remove (len(batch) - 1) to always have enough space for at least a single decode
# for the remaining requests. The -1 is because the first request does not need to be
# removed from the budget (e.g. with a single request in the batch, it should get the
# full budget, not budget - 1)
batch_budget = get_max_prefill_tokens() - (len(batch) - 1)
# We reverse to prioritize older requests
# zip() is not reversible so reverse the underlying lists instead
for prefix_length, postfix_length, prompt_length in zip(
@ -1790,6 +1792,7 @@ class FlashCausalLM(Model):
finished_prefilling = False
next_prefilling_mask.append(True)
else:
# FIXME: use true number of accepted tokens instead of 1
# Since speculation will be turned off, this is always true
next_chunk_length = 1
next_prefilling_mask.append(False)
@ -1807,14 +1810,7 @@ class FlashCausalLM(Model):
batch.prefilling = not finished_prefilling
batch.prefilling_mask = next_prefilling_mask
# Turn off speculative if some requests are still prefilling
# It makes the logic easier to follow
if prefill and not finished_prefilling:
speculate = 0
speculative_logits = None
else:
speculate = get_speculate()
(
next_input_ids,
next_token_logprobs,
@ -2045,18 +2041,18 @@ class FlashCausalLM(Model):
# this state to be stable
if request.id % self.world_size == self.rank:
# Prefill
if prefill and request.prefill_logprobs:
if request_prefilling and request.prefill_logprobs:
out_start_index = batch.prefill_cu_outlens[i]
out_end_index = batch.prefill_cu_outlens[i + 1]
request_prefill_tokens = batch.prefill_tokens[i]
request_prefill_logprobs = prefill_logprobs[
out_start_index : out_end_index - 1
]
prefill_token_ids = all_input_ids[:-1]
if request_prefill_tokens is None:
past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i]
if past_prefill_logprob_tokens is None:
# Remove generated token to only have prefill and add nan for first prompt token
request_prefill_logprobs = [float("nan")] * (
len(prefix_ids) + 1
@ -2069,18 +2065,20 @@ class FlashCausalLM(Model):
skip_special_tokens=False,
)
prefill_tokens = Tokens(
prefill_logprob_tokens = Tokens(
prefill_token_ids,
request_prefill_logprobs,
prefill_texts,
is_special=[],
)
if request_prefill_tokens is not None:
prefill_tokens = request_prefill_tokens + prefill_tokens
if past_prefill_logprob_tokens is not None:
prefill_logprob_tokens = (
past_prefill_logprob_tokens + prefill_logprob_tokens
)
batch.prefill_tokens[i] = prefill_tokens
batch.prefill_logprob_tokens[i] = prefill_logprob_tokens
else:
batch.prefill_tokens[i] = None
batch.prefill_logprob_tokens[i] = None
# If it is, the tokens we decoded should be ignored
if request_prefilling:
@ -2178,7 +2176,7 @@ class FlashCausalLM(Model):
generation = Generation(
request.id,
batch.prefill_tokens[i],
batch.prefill_logprob_tokens[i],
Tokens(
_next_token_ids,
_next_token_logprobs,
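Worked example of the budget arithmetic described in the comment above, sketched in Rust with illustrative numbers (the real code is the Python line using get_max_prefill_tokens(); the helper name here is an assumption):

// Reserve one decode token for every request except the first, so a lone
// request keeps the full prefill budget. Assumes batch_len >= 1.
fn batch_budget(max_prefill_tokens: usize, batch_len: usize) -> usize {
    max_prefill_tokens - (batch_len - 1)
}

fn main() {
    // A single request gets the whole budget, not budget - 1.
    assert_eq!(batch_budget(4096, 1), 4096);
    // With 8 requests, 7 slots are reserved for the other requests' decodes.
    assert_eq!(batch_budget(4096, 8), 4089);
    println!("ok");
}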

View File

@ -7,6 +7,7 @@ from collections import defaultdict
from transformers import PreTrainedTokenizerBase
from loguru import logger
from text_generation_server.models.globals import ATTENTION, PREFIX_CACHING, BLOCK_SIZE
from text_generation_server.models.types import Batch, Generation
from text_generation_server.utils.log import log_master
from text_generation_server.utils.prefill_chunking import set_support_chunking
@ -94,6 +95,9 @@ class Model(ABC):
window_size=self.sliding_window,
speculate=self.speculate,
support_chunking=self.support_chunking,
use_prefix_caching=PREFIX_CACHING,
attention_impl=ATTENTION,
block_size=BLOCK_SIZE,
)
@property

View File

@ -80,7 +80,7 @@ class Seq2SeqLMBatch(Batch):
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
current_tokens=len(self),
current_tokens=len(self.input_ids),
)
@classmethod

View File

@ -153,6 +153,18 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
)
concat_ns = None
if self.model.support_chunking:
if request.HasField("cached_batch"):
cached_batch = self.cache.pop(request.cached_batch.id)
if cached_batch is None:
raise ValueError(
f"Batch ID {request.cached_batch.id} not found in cache."
)
start_concat = time.time_ns()
batch = self.model.batch_type.concatenate([batch, cached_batch])
concat_ns = time.time_ns() - start_concat
generations, next_batch, timings = self.model.generate_token(batch)
self.cache.set(next_batch)
@ -162,6 +174,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
forward_ns=timings[0],
decode_ns=timings[1],
total_ns=time.time_ns() - start,
concat_ns=concat_ns,
)
async def Decode(self, request, context):
@ -179,16 +192,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
if len(batches) == 0:
raise ValueError("All batches are empty")
if self.model.support_chunking:
if request.HasField("batch"):
batch = self.model.batch_type.from_pb(
request.batch,
self.model.tokenizer,
self.model.dtype,
self.model.device,
)
batches.append(batch)
if len(batches) > 1:
start_concat = time.time_ns()
batch = self.model.batch_type.concatenate(batches)
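The Prefill handler above pops the cached batch from the cache, concatenates it with the incoming chunk, and reports the elapsed time as concat_ns. The same control flow, sketched in Rust with deliberately simplified types (a HashMap cache and Vec<u32> batches, both assumptions rather than the real server types):

use std::collections::HashMap;
use std::time::Instant;

// Illustration only: a batch is just its token ids, the cache maps batch id -> batch.
type Batch = Vec<u32>;

fn prefill_with_chunking(
    cache: &mut HashMap<u64, Batch>,
    mut batch: Batch,
    cached_batch_id: Option<u64>,
) -> (Batch, Option<u128>) {
    let mut concat_ns = None;
    if let Some(id) = cached_batch_id {
        // Pop the previously cached batch; a missing id is an error upstream.
        let cached = cache.remove(&id).expect("cached batch not found");
        // Concatenate the cached tokens in front of the new chunk and time it,
        // mirroring the concat_ns field returned in PrefillResponse.
        let start = Instant::now();
        let mut merged = cached;
        merged.extend(batch);
        batch = merged;
        concat_ns = Some(start.elapsed().as_nanos());
    }
    (batch, concat_ns)
}

fn main() {
    let mut cache = HashMap::new();
    cache.insert(7u64, vec![1, 2, 3]);
    let (batch, concat_ns) = prefill_with_chunking(&mut cache, vec![4, 5], Some(7));
    assert_eq!(batch, vec![1, 2, 3, 4, 5]);
    assert!(concat_ns.is_some());
    println!("merged batch: {batch:?}");
}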