mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-27 13:02:12 +00:00)

Commit 18e77a5cc7 ("wip"), parent dfca1dfc5e
@@ -17,6 +17,8 @@ service TextGenerationService {
     rpc Prefill (PrefillRequest) returns (PrefillResponse);
     /// Decode token for a list of prefilled batches
     rpc Decode (DecodeRequest) returns (DecodeResponse);
+    /// Update batch
+    rpc Update(UpdateRequest) returns (UpdateResponse);
     /// Health check
     rpc Health (HealthRequest) returns (HealthResponse);
 }
@@ -198,6 +200,8 @@ message Generation {
     optional GeneratedText generated_text = 4;
     /// Top tokens
     repeated Tokens top_tokens = 5;
+    /// Current length of the request: prompt tokens + number of generated tokens until this point
+    uint32 current_length = 6;
 }

 message FilterBatchRequest {
@@ -251,6 +255,26 @@ message DecodeResponse {
     optional uint64 concat_ns = 6;
 }

+message ExtendedRequest {
+    /// Request ID
+    uint64 request_id = 1;
+    /// Paged attention blocks to add
+    repeated uint32 blocks = 2;
+    /// Paged attention slots to add
+    repeated uint32 slots = 3;
+}
+
+message UpdateRequest {
+    /// Batch ID
+    uint64 batch_id = 1;
+    /// Requests to update
+    repeated ExtendedRequest extend_requests = 2;
+    /// Requests to terminate
+    repeated uint64 terminated_request_ids = 3;
+}
+
+message UpdateResponse {}
+
 message WarmupRequest {
     /// Batch to warmup on
     Batch batch = 1;
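The three proto hunks above give the router a way to resize a running batch: a new Update RPC, a per-request ExtendedRequest carrying extra paged-attention blocks and slots, an UpdateRequest/UpdateResponse pair, and a current_length field on Generation so the router can see how long each request has grown. A minimal sketch of how such a request could be assembled from Python, assuming the stubs generated from this proto are importable as generate_pb2 (the module name the server uses elsewhere); the field names come from the messages above, everything else is illustrative:

# Sketch only: assumes the protoc output for this .proto is importable as
# `generate_pb2`; field names are taken from the messages added above.
from text_generation_server.pb import generate_pb2  # assumed import path

def build_update_request(batch_id, extensions, terminated_ids):
    """Build an UpdateRequest: grow some requests' paged-attention
    allocations and terminate the ones that ran out of pages."""
    return generate_pb2.UpdateRequest(
        batch_id=batch_id,
        extend_requests=[
            generate_pb2.ExtendedRequest(
                request_id=request_id,
                blocks=new_blocks,   # paged attention blocks to add
                slots=new_slots,     # paged attention slots to add
            )
            for request_id, (new_blocks, new_slots) in extensions.items()
        ],
        terminated_request_ids=list(terminated_ids),
    )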
@@ -33,6 +33,8 @@ pub(crate) struct Entry {
    pub batch_time: Option<Instant>,
    /// Block Allocation
    pub block_allocation: Option<BlockAllocation>,
+    /// Current length (in tokens) of the request (prompt tokens + generated_tokens)
+    pub current_length: u32
 }

 /// Request Queue
@@ -498,6 +500,7 @@ mod tests {
            queue_time: Instant::now(),
            batch_time: None,
            block_allocation: None,
+            current_length: 0,
        };
        (entry, receiver_tx)
    }
@@ -88,6 +88,7 @@ impl Scheduler for SchedulerV3 {
            queue_time: Instant::now(),
            batch_time: None,
            block_allocation: None,
+            current_length: input_length,
        });

        // Notify the background task that we have a new entry in the queue that needs
@@ -287,6 +288,8 @@ async fn decode(
    // Send generated tokens and filter stopped entries
    filter_send_generations(generations, entries);

+    filter_update_allocations(client, entries).await;
+
    // Filter next batch and remove requests that were stopped
    let next_batch = filter_batch(client, next_batch, entries).await;

@@ -355,8 +358,9 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
        // Get entry
        // We can `expect` here as the request id should always be in the entries
        let entry = entries
-            .get(&id)
+            .get_mut(&id)
            .expect("ID not found in entries. This is a bug.");
+        entry.current_length = generation.current_length;

        // Create and enter a span to link this function back to the entry
        let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
@@ -374,6 +378,35 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
    });
 }

+/// Check if block allocations need to be extended
+/// If we don't have enough blocks, request will be filtered with an OutOfPages finish reason
+#[instrument(skip_all)]
+async fn filter_update_allocations(client: &mut ShardedClient, entries: &mut IntMap<u64, Entry>) {
+    // let mut extend_entries = Vec::with_capacity(entries.len());
+    // let mut finish_entries = Vec::with_capacity(entries.len());
+
+    // for (request_id, entry) in entries.into_iter() {
+    //     tracing::info!("Allocation {}; Current Length: {}", entry.block_allocation.as_ref().unwrap().allocated_tokens, entry.current_length);
+    //
+    //     if let Some(block_allocation) = &mut entry.block_allocation {
+    //         tracing::info!("Allocation {:?}", block_allocation);
+    //
+    //         if entry.current_length > block_allocation.allocated_tokens {
+    //             // We need to add new blocks to this entry
+    //             let remaining_tokens = block_allocation.total_tokens - entry.current_length;
+    //             match block_allocation.extend(remaining_tokens).await {
+    //                 true => {
+    //
+    //                 },
+    //                 false => {
+    //
+    //                 }
+    //             }
+    //         }
+    //     }
+    // }
+}

 /// Send responses through the `entry` response channel
 fn send_responses(
    generation: Generation,
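In this WIP commit the body of filter_update_allocations is still commented out; the comments sketch the intended flow: compare each entry's current_length against its allocated tokens, then either extend the allocation or finish the request with the new OutOfPages reason. A rough Python rendering of that intent, for illustration only (the names mirror the commented Rust; none of this is the actual router code):

# Illustration of the intended control flow, not the real implementation.
def update_allocations(entries):
    out_of_pages = []
    for request_id, entry in entries.items():
        allocation = entry.block_allocation
        if allocation is None:
            continue
        # The request grew past what was allocated at prefill time:
        # try to extend the allocation by the tokens it still needs.
        if entry.current_length > allocation.allocated_tokens:
            remaining_tokens = allocation.total_tokens - entry.current_length
            if not allocation.extend(remaining_tokens):
                # No free pages left: this request would be finished with
                # the new OutOfPages reason and filtered from the batch.
                out_of_pages.append(request_id)
    return out_of_pages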
@@ -1085,6 +1085,8 @@ pub(crate) enum FinishReason {
    EndOfSequenceToken,
    #[schema(rename = "stop_sequence")]
    StopSequence,
+    #[schema(rename = "out_of_pages")]
+    OutOfPages
 }

 impl std::fmt::Display for FinishReason {
@@ -1093,6 +1095,7 @@ impl std::fmt::Display for FinishReason {
            FinishReason::Length => write!(f, "length"),
            FinishReason::EndOfSequenceToken => write!(f, "eos_token"),
            FinishReason::StopSequence => write!(f, "stop_sequence"),
+            FinishReason::OutOfPages => write!(f, "out_of_pages"),
        }
    }
 }
@@ -746,6 +746,7 @@ class CausalLM(Model):
                ),
                generated_text,
                top_tokens,
+                new_input_length
            )

            generations.append(generation)
@@ -79,8 +79,6 @@ class FlashCausalLMBatch(Batch):
    # Paged Attention values

    # Set when creating the batch
-    # CPU tensor of length b indicating the start of each sequence in slots
-    start_slots: torch.Tensor
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    slot_indices: torch.Tensor

@@ -88,8 +86,10 @@ class FlashCausalLMBatch(Batch):
    block_tables: List[List[int]]
    # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
    block_tables_tensor: torch.Tensor
+    # list of length b of list of length s_i
+    slots: List[List[int]]
    # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
-    slots: torch.Tensor
+    slots_tensor: torch.Tensor

    max_seqlen: int

@@ -154,7 +154,6 @@ class FlashCausalLMBatch(Batch):
        sliding_window = get_sliding_windows()
        position_ids = []
        cu_seqlen_prefill = [0]
-        start_slots = []
        slot_indices = []
        prefill_cache_indices = []

@@ -176,7 +175,6 @@ class FlashCausalLMBatch(Batch):

        # Cumulative length
        cumulative_length = 0
-        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        num_blocks = 0
@@ -186,6 +184,7 @@ class FlashCausalLMBatch(Batch):

        block_tables = []
        slots = []
+        flat_slots = []

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
@@ -204,6 +203,9 @@ class FlashCausalLMBatch(Batch):
            input_length = len(tokenized_input)
            input_lengths.append(input_length)

+            speculative_length = get_speculate()
+            speculative_length = 0 if speculative_length is None else speculative_length
+
            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

@@ -226,13 +228,10 @@ class FlashCausalLMBatch(Batch):
            top_n_tokens.append(r.top_n_tokens)

            # Paged attention
-            # Remove one as the first token des not have a past
-            speculative_length = get_speculate()
-            speculative_length = 0 if speculative_length is None else speculative_length
-            total_tokens = input_length + max_new_tokens - 1 + speculative_length

            # blocks and slots can be empty (for example in warmup)
            if not r.blocks:
+                # Remove one as the first token des not have a past
+                total_tokens = input_length + max_new_tokens - 1 + speculative_length
                needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
                request_blocks = [
                    b for b in range(num_blocks, num_blocks + needed_blocks)
@@ -247,15 +246,15 @@ class FlashCausalLMBatch(Batch):
                request_slots = r.slots

            block_tables.append(request_blocks)
-            slots.extend(request_slots[:total_tokens])
            num_blocks += len(request_blocks)
-            start_slots.append(cumulative_max_length)

            request_slot_indices = torch.arange(
-                cumulative_max_length,
-                cumulative_max_length + input_length,
+                len(flat_slots),
+                len(flat_slots) + input_length,
                dtype=torch.int64,
            )
+            slots.append(request_slots)
+            flat_slots.extend(request_slots)
            slot_indices.append(request_slot_indices)

            # Create tensor to slice into the kv tensor in prefill
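The from_pb changes above replace the start_slots/cumulative_max_length bookkeeping with a per-request slots list, a flattened flat_slots list, and slot indices computed from len(flat_slots) at the time each request is appended. A toy, self-contained sketch of how those pieces line up (values are made up):

# Per-request slot lists live in `slots`, their concatenation in `flat_slots`,
# and each request's prefill slot indices are offsets into `flat_slots`.
import torch

requests_slots = [[10, 11, 12, 13], [40, 41, 42]]   # hypothetical slot ids
input_lengths = [2, 3]                               # hypothetical prompt lengths

slots, flat_slots, slot_indices = [], [], []
for request_slots, input_length in zip(requests_slots, input_lengths):
    # Indices of this request's prompt tokens inside the flattened tensor
    request_slot_indices = torch.arange(
        len(flat_slots), len(flat_slots) + input_length, dtype=torch.int64
    )
    slots.append(request_slots)
    flat_slots.extend(request_slots)
    slot_indices.append(request_slot_indices)

slots_tensor = torch.tensor(flat_slots, dtype=torch.int64)
# slots_tensor[slot_indices[i]] recovers request i's prompt slots:
# tensor([10, 11]) for request 0 and tensor([40, 41, 42]) for request 1.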
@@ -289,7 +288,6 @@ class FlashCausalLMBatch(Batch):

            # Update
            cumulative_length += input_length
-            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, len(request_blocks))
            max_length = max(
@@ -299,7 +297,6 @@ class FlashCausalLMBatch(Batch):
        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device, tokenizer
        )
-        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
@@ -356,7 +353,7 @@ class FlashCausalLMBatch(Batch):
            top_n_tokens, device=device, dtype=torch.int64
        )

-        slots = torch.tensor(slots, dtype=torch.int64, device=device)
+        slots_tensor = torch.tensor(flat_slots, dtype=torch.int64, device=device)
        block_tables_tensor = torch.zeros(
            (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
        )
@@ -372,11 +369,11 @@ class FlashCausalLMBatch(Batch):
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            prefill_cache_indices=prefill_cache_indices,
-            start_slots=start_slots,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
+            slots_tensor=slots_tensor,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
@@ -423,18 +420,13 @@ class FlashCausalLMBatch(Batch):
        # Used to index into tensors
        indices = []

-        # slots to keep after filtering
-        slot_filtering_indices = torch.zeros(
-            self.slots.shape[0], dtype=torch.bool, device=device
-        )
-
-        # Create on CPU to only move to GPU once instead of at every copy
-        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
+        slot_indices = []
        max_seqlen = 0

        requests = []
-        start_slots = []
        block_tables = []
+        slots = []
+        flat_slots = []
        all_input_ids = []

        input_lengths = []
@@ -446,8 +438,6 @@ class FlashCausalLMBatch(Batch):

        num_blocks = 0
        max_blocks = 0
-        # Cumulative length
-        cumulative_max_length = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
@@ -471,27 +461,17 @@ class FlashCausalLMBatch(Batch):

            top_n_tokens.append(self.top_n_tokens[idx])

-            remaining_tokens = (
-                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
-            )
-
            request_block_table = self.block_tables[idx]
            num_blocks += len(request_block_table)
            block_tables.append(request_block_table)
-            start_slots.append(cumulative_max_length)

-            # Copy to tensor (CPU)
-            slot_indices[i] = cumulative_max_length + request_input_length - 1
+            # List of slots allocated for this request
+            request_slots = self.slots[idx]
+            slots.append(request_slots)

-            # Set slice
-            slot_filtering_indices[
-                self.start_slots[idx] : self.start_slots[idx]
-                + request_input_length
-                + remaining_tokens
-                - 1
-            ] = True
-
-            cumulative_max_length += request_input_length + remaining_tokens - 1
+            # Index
+            slot_indices.append(len(flat_slots) + request_input_length - 1)
+            flat_slots.extend(request_slots)

            max_blocks = max(max_blocks, len(request_block_table))

@@ -501,17 +481,15 @@ class FlashCausalLMBatch(Batch):
        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        input_lengths_tensor = self.input_lengths_tensor[indices]
-        slots = self.slots[slot_filtering_indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
        speculative_ids = (
            self.speculative_ids[indices] if self.speculative_ids is not None else None
        )

-        start_slots = torch.tensor(start_slots, dtype=torch.int64)
-
-        # Move to GPU now that we have the whole tensor
-        slot_indices = slot_indices.to(device)
+        # Allocate on GPU
+        slots_tensor = torch.tensor(flat_slots, dtype=torch.int64, device=device)
+        slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device)

        return type(self)(
            batch_id=self.batch_id,
@@ -521,11 +499,11 @@ class FlashCausalLMBatch(Batch):
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            prefill_cache_indices=None,
-            start_slots=start_slots,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
+            slots_tensor=slots_tensor,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
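filter() now drops the boolean slot_filtering_indices mask and rebuilds the flattened slot tensor from the kept requests' per-request slot lists, recomputing a single decode slot index per kept request. A toy sketch of that bookkeeping (hypothetical values):

# Keep only some requests, rebuild flat_slots from their slot lists, and
# recompute the decode slot index of each kept request.
import torch

all_slots = [[10, 11, 12, 13], [40, 41, 42], [70, 71, 72, 73, 74]]
input_lengths = [2, 3, 4]        # current length of each request
kept = [0, 2]                    # request 1 finished and is filtered out

slots, flat_slots, slot_indices = [], [], []
for idx in kept:
    request_slots = all_slots[idx]
    slots.append(request_slots)
    # Slot that will receive the next generated token of this request
    slot_indices.append(len(flat_slots) + input_lengths[idx] - 1)
    flat_slots.extend(request_slots)

slots_tensor = torch.tensor(flat_slots, dtype=torch.int64)
slot_indices = torch.tensor(slot_indices, dtype=torch.int64)
# slots_tensor[slot_indices] -> tensor([11, 73]): the decode slots of the
# two requests that were kept.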
@@ -560,13 +538,14 @@ class FlashCausalLMBatch(Batch):
        max_seqlen = 0
        for b in batches:
            total_batch_size += len(b)
-            total_slots += len(b.slots)
+            total_slots += len(b.slots_tensor)
            num_blocks += b.num_blocks
            speculative_length = (
                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            )
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
+            # When we filter, we do not recompute this value so we do so here
            max_length = max(
                max_length,
                max(
@@ -582,7 +561,7 @@ class FlashCausalLMBatch(Batch):

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
-        slots = batches[0].slots.new_empty(total_slots)
+        slots_tensor = batches[0].slots_tensor.new_empty(total_slots)
        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
        input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
            total_batch_size
@@ -597,7 +576,7 @@ class FlashCausalLMBatch(Batch):
            total_batch_size,
        )

-        start_slots = []
+        slots = []
        block_tables = []
        all_input_ids = []

@@ -627,7 +606,7 @@ class FlashCausalLMBatch(Batch):
            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
            slots_start_index = cumulative_slots
-            slots_end_index = cumulative_slots + len(batch.slots)
+            slots_end_index = cumulative_slots + len(batch.slots_tensor)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
@@ -635,7 +614,7 @@ class FlashCausalLMBatch(Batch):
            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
-            slots[slots_start_index:slots_end_index] = batch.slots
+            slots_tensor[slots_start_index:slots_end_index] = batch.slots_tensor

            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
@@ -645,8 +624,7 @@ class FlashCausalLMBatch(Batch):
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]

-            start_slots.append(batch.start_slots + cumulative_slots)
-
+            slots.extend(batch.slots)
            block_tables.extend(batch.block_tables)
            all_input_ids.extend(batch.all_input_ids)

@@ -662,9 +640,7 @@ class FlashCausalLMBatch(Batch):

            # Update
            cumulative_batch_size += len(batch)
-            cumulative_slots += len(batch.slots)
-
-        start_slots = torch.concat(start_slots)
+            cumulative_slots += len(batch.slots_tensor)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
@@ -688,11 +664,11 @@ class FlashCausalLMBatch(Batch):
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            prefill_cache_indices=None,
-            start_slots=start_slots,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
+            slots_tensor=slots_tensor,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
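concatenate() keeps the same idea: each batch's slots_tensor is copied into one tensor and its slot_indices are shifted by the number of slots already copied (cumulative_slots), while the per-request slots lists are simply extended. A toy sketch of the offset arithmetic (hypothetical values):

# Copy each batch's slots_tensor into one big tensor and shift its
# slot_indices so they keep pointing at the same slots.
import torch

batch_a = {"slots_tensor": torch.tensor([10, 11, 12]), "slot_indices": torch.tensor([1])}
batch_b = {"slots_tensor": torch.tensor([40, 41, 42, 43]), "slot_indices": torch.tensor([2])}

total_slots = sum(len(b["slots_tensor"]) for b in (batch_a, batch_b))
slots_tensor = batch_a["slots_tensor"].new_empty(total_slots)
slot_indices = []

cumulative_slots = 0
for batch in (batch_a, batch_b):
    n = len(batch["slots_tensor"])
    slots_tensor[cumulative_slots : cumulative_slots + n] = batch["slots_tensor"]
    # Indices were relative to the batch's own tensor: shift them
    slot_indices.append(batch["slot_indices"] + cumulative_slots)
    cumulative_slots += n

slot_indices = torch.cat(slot_indices)
# slots_tensor[slot_indices] -> tensor([11, 42]): each request still points
# at the same slot it pointed at before concatenation.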
@@ -993,7 +969,7 @@ class FlashCausalLM(Model):
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
-            slots = batch.slots[batch.slot_indices]
+            slots = batch.slots_tensor[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices
@@ -1032,7 +1008,7 @@ class FlashCausalLM(Model):
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
-            slots = batch.slots[batch.slot_indices]
+            slots = batch.slots_tensor[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices
@@ -1374,6 +1350,7 @@ class FlashCausalLM(Model):
                ),
                generated_text,
                top_tokens,
+                input_length + n_accepted_ids
            )

            generations.append(generation)
@@ -829,6 +829,7 @@ class IdeficsCausalLM(Model):
                ),
                generated_text,
                top_tokens,
+                new_input_length
            )

            generations.append(generation)
@@ -775,6 +775,7 @@ class Mamba(Model):
                ),
                generated_text,
                top_tokens,
+                new_input_length
            )

            generations.append(generation)
@@ -801,6 +801,7 @@ class Seq2SeqLM(Model):
                ),
                generated_text,
                top_tokens,
+                new_decoder_input_length,
            )

            generations.append(generation)
@@ -84,6 +84,7 @@ class Generation:
    generated_text: Optional[GeneratedText]
    # Optional for now, since it's not yet supported for every model.
    top_tokens: Optional[List[Tokens]]
+    current_length: int

    def to_pb(self) -> generate_pb2.Generation:
        return generate_pb2.Generation(
@@ -100,4 +101,5 @@ class Generation:
                if self.top_tokens is not None
                else None
            ),
+            current_length=self.current_length,
        )
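Every model now passes the request's current length as the last argument of Generation, and to_pb forwards it to the router, which compares it against the tokens allocated for the request. A minimal arithmetic sketch, reusing the names from the flash model hunk above (values are made up):

# How the new field is filled each step in the flash path.
input_length = 27        # prompt tokens + tokens generated before this step
n_accepted_ids = 2       # tokens accepted this step (1 without speculation)

current_length = input_length + n_accepted_ids
# generation = Generation(..., top_tokens, current_length)
# generation.to_pb() then carries current_length=29 back to the router.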
@@ -228,7 +228,7 @@ class VlmCausalLM(BaseFlashMistral):
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
-            slots = batch.slots[batch.slot_indices]
+            slots = batch.slots_tensor[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices
@@ -267,7 +267,7 @@ class VlmCausalLM(BaseFlashMistral):
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
-            slots = batch.slots[batch.slot_indices]
+            slots = batch.slots_tensor[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices