add warmup

OlivierDehaene 2023-06-29 15:50:44 +02:00
parent d649cd8e02
commit ddfc02f2a4
14 changed files with 272 additions and 85 deletions

View File

@ -128,7 +128,10 @@ struct Args {
#[clap(default_value = "1.2", long, env)] #[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32, waiting_served_ratio: f32,
#[clap(default_value = "32000", long, env)] /// Limits the number of tokens for the prefill operation.
/// Since this operation takes the most memory and is compute bound, it is useful
/// to limit the number of requests that can be sent.
#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32, max_batch_prefill_tokens: u32,
/// **IMPORTANT** This is one critical control to allow maximum usage /// **IMPORTANT** This is one critical control to allow maximum usage
@ -143,13 +146,6 @@ struct Args {
/// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100` /// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100`
/// or a single query of `1000` tokens. /// or a single query of `1000` tokens.
/// ///
/// So you don't have to control `max_batch_size` or `max_total_tokens` that finely.
/// In fact you can mostly relax them if you want maximum flexibility.
/// However, if your users ask for the full amount of total tokens, they are likely
/// to wait a very long time to get a spot in the batch (since they will be alone),
/// so setting `max_batch_size` and `max_total_tokens` can still be useful to prevent
/// those long waiting times.
///
/// Overall this number should be the largest possible amount that fits the /// Overall this number should be the largest possible amount that fits the
/// remaining memory (after the model is loaded). Since the actual memory overhead /// remaining memory (after the model is loaded). Since the actual memory overhead
/// depends on other parameters like if you're using quantization, flash attention /// depends on other parameters like if you're using quantization, flash attention
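The two limits documented above interact: `max_batch_prefill_tokens` caps how many tokens a single prefill pass may process (prefill is the compute-bound, memory-hungry step), while `max_batch_total_tokens` caps the input plus generated tokens resident in the running batch. A small plain-Python sketch of that budgeting arithmetic, using the illustrative numbers from the doc comment rather than any real defaults:

# Budgeting arithmetic behind the two limits described above; numbers are
# illustrative, mirroring the `max_batch_total_tokens=1000` example in the docs.

def max_concurrent_requests(max_batch_total_tokens: int, tokens_per_request: int) -> int:
    # How many requests of a given total size fit in the running batch.
    return max_batch_total_tokens // tokens_per_request

def prefill_passes(prompt_tokens: int, batch_size: int, max_batch_prefill_tokens: int) -> int:
    # How many prefill passes are needed to admit `batch_size` prompts,
    # each pass processing at most `max_batch_prefill_tokens` tokens.
    total = prompt_tokens * batch_size
    return -(-total // max_batch_prefill_tokens)  # ceil division

print(max_concurrent_requests(1000, 100))   # 10 queries of total_tokens=100
print(max_concurrent_requests(1000, 1000))  # or a single query of 1000 tokens
print(prefill_passes(1024, 8, 4096))        # 8 prompts of 1024 tokens -> 2 passes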
@ -448,7 +444,7 @@ fn shard_manager(
// We received a shutdown signal // We received a shutdown signal
if *shutdown.lock().unwrap() { if *shutdown.lock().unwrap() {
p.terminate().unwrap(); p.kill().unwrap();
let _ = p.wait_timeout(Duration::from_secs(90)); let _ = p.wait_timeout(Duration::from_secs(90));
tracing::info!("Shard {rank} terminated"); tracing::info!("Shard {rank} terminated");
return; return;

View File

@ -11,6 +11,8 @@ service TextGenerationService {
rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse); rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
/// Remove requests from a cached batch /// Remove requests from a cached batch
rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse); rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);
/// Warmup the model and compute max cache size
rpc Warmup (WarmupRequest) returns (WarmupResponse);
/// Prefill batch and decode first token /// Prefill batch and decode first token
rpc Prefill (PrefillRequest) returns (PrefillResponse); rpc Prefill (PrefillRequest) returns (PrefillResponse);
/// Decode token for a list of prefilled batches /// Decode token for a list of prefilled batches
@ -192,3 +194,13 @@ message DecodeResponse {
/// Next batch (cached) /// Next batch (cached)
optional CachedBatch batch = 2; optional CachedBatch batch = 2;
} }
message WarmupRequest {
/// Batch to warmup on
Batch batch = 1;
/// Maximum number of tokens that the client will send
uint32 max_total_tokens = 2;
}
/// Empty response
message WarmupResponse {}

View File

@ -3,6 +3,7 @@ use crate::pb::generate::v1::text_generation_service_client::TextGenerationServi
use crate::pb::generate::v1::*; use crate::pb::generate::v1::*;
use crate::Result; use crate::Result;
use grpc_metadata::InjectTelemetryContext; use grpc_metadata::InjectTelemetryContext;
use std::cmp::min;
use tonic::transport::{Channel, Uri}; use tonic::transport::{Channel, Uri};
use tracing::instrument; use tracing::instrument;
@ -94,6 +95,63 @@ impl Client {
Ok(filtered_batch.batch) Ok(filtered_batch.batch)
} }
/// Warmup on a max size batch
///
/// Returns the maximum amount of tokens supported by the hardware
#[instrument(skip(self))]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
) -> Result<()> {
let mut n_tokens = 0;
let mut requests = Vec::new();
// Create requests
while n_tokens < max_prefill_tokens {
requests.push(Request {
id: 0,
// We truncate the input on the server side to be sure that it has the correct size
inputs: "test".to_string().repeat(max_input_length as usize),
truncate: min(max_input_length, max_prefill_tokens - n_tokens),
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
top_k: 10,
top_p: 0.9,
typical_p: 0.9,
do_sample: false,
seed: 0,
repetition_penalty: 1.2,
watermark: true,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 2,
stop_sequences: vec![],
ignore_eos_token: false,
}),
prefill_logprobs: true,
});
n_tokens += max_input_length;
}
let batch = Batch {
id: 0,
size: requests.len() as u32,
requests,
max_tokens: 0,
};
let request = tonic::Request::new(WarmupRequest {
batch: Some(batch),
max_total_tokens,
})
.inject_context();
self.stub.warmup(request).await?.into_inner();
Ok(())
}
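The loop above builds the warmup batch by appending dummy requests of `max_input_length` tokens until `max_prefill_tokens` is reached, relying on server-side truncation so the final request does not overshoot. A minimal Python sketch of that sizing logic (illustrative names; the Rust code above is the actual implementation):

# Sizing of the warmup batch built above: keep adding dummy requests of
# max_input_length tokens, truncating the last one so the sum never exceeds
# max_prefill_tokens. Mirrors the Rust loop conceptually; names are illustrative.

def warmup_truncations(max_input_length: int, max_prefill_tokens: int) -> list[int]:
    truncations = []
    n_tokens = 0
    while n_tokens < max_prefill_tokens:
        truncations.append(min(max_input_length, max_prefill_tokens - n_tokens))
        n_tokens += max_input_length
    return truncations

print(warmup_truncations(1024, 4096))  # [1024, 1024, 1024, 1024]
print(warmup_truncations(1000, 4096))  # [1000, 1000, 1000, 1000, 96]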
/// Generate one token for each request in the given batch /// Generate one token for each request in the given batch
/// ///
/// Returns Generation for each request in batch /// Returns Generation for each request in batch

View File

@ -87,6 +87,27 @@ impl ShardedClient {
join_all(futures).await.pop().unwrap() join_all(futures).await.pop().unwrap()
} }
/// Warmup on a max size batch
///
/// Returns the maximum amount of tokens supported by the hardware
#[instrument(skip(self))]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
) -> Result<()> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| {
Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
})
.collect();
// all shards return the same message
join_all(futures).await.pop().unwrap()
}
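`ShardedClient::warmup` fans the same call out to every shard and keeps a single reply, since all shards answer identically. A hedged asyncio sketch of that fan-out pattern; `shard_warmup` is a hypothetical stand-in for the per-shard gRPC call, not the real stub:

# Fan-out pattern of the sharded client: send the same RPC to every shard
# concurrently and keep one reply, since all shards answer alike.
# `shard_warmup` is a hypothetical stand-in for the per-shard gRPC call.
import asyncio

async def shard_warmup(shard_id: int, max_prefill_tokens: int) -> str:
    await asyncio.sleep(0)  # stand-in for the network round trip
    return f"warmed with {max_prefill_tokens} prefill tokens"

async def warmup_all(num_shards: int, max_prefill_tokens: int) -> str:
    replies = await asyncio.gather(
        *(shard_warmup(i, max_prefill_tokens) for i in range(num_shards))
    )
    return replies[-1]  # all replies carry the same information

print(asyncio.run(warmup_all(4, 4096)))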
/// Generate one token for each request in the given batch /// Generate one token for each request in the given batch
/// ///
/// Returns Generation for each request in batch /// Returns Generation for each request in batch

View File

@ -242,6 +242,7 @@ impl Infer {
/// Will be launched in a background Tokio task /// Will be launched in a background Tokio task
/// ///
/// Batches requests and sends them to the inference server /// Batches requests and sends them to the inference server
#[allow(clippy::too_many_arguments)]
async fn batching_task( async fn batching_task(
mut client: ShardedClient, mut client: ShardedClient,
waiting_served_ratio: f32, waiting_served_ratio: f32,
@ -288,7 +289,8 @@ async fn batching_task(
Some((batch_size as f32 * waiting_served_ratio).floor() as usize) Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
}; };
let token_budget = max_batch_total_tokens - batch_max_tokens; let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
tracing::info!("{token_budget} {batch_max_tokens}");
// Try to get a new batch // Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = queue if let Some((mut new_entries, new_batch, span)) = queue
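The switch to `saturating_sub` clamps the token budget at zero when the running batch already holds at least `max_batch_total_tokens`, instead of underflowing an unsigned subtraction. The Python analogue of that clamp, with illustrative numbers:

# Saturating subtraction: the token budget clamps at zero instead of
# underflowing (Rust u32 subtraction would panic in debug or wrap in release).
def saturating_sub(a: int, b: int) -> int:
    return max(0, a - b)

assert saturating_sub(32_000, 12_000) == 20_000
assert saturating_sub(10_000, 12_000) == 0  # over-full batch: no room for new requests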

View File

@ -34,7 +34,7 @@ struct Args {
max_total_tokens: usize, max_total_tokens: usize,
#[clap(default_value = "1.2", long, env)] #[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32, waiting_served_ratio: f32,
#[clap(default_value = "32000", long, env)] #[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32, max_batch_prefill_tokens: u32,
#[clap(default_value = "32000", long, env)] #[clap(default_value = "32000", long, env)]
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
@ -180,16 +180,23 @@ fn main() -> Result<(), std::io::Error> {
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await .await
.expect("Could not connect to server"); .expect("Could not connect to server");
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache(None)
.await
.expect("Unable to clear cache");
// Get info from the shard // Get info from the shard
let shard_info = sharded_client let shard_info = sharded_client
.info() .info()
.await .await
.expect("Unable to get shard info"); .expect("Unable to get shard info");
// Warmup model
tracing::info!("Warming up model");
sharded_client
.warmup(
max_input_length as u32,
max_batch_prefill_tokens,
max_batch_total_tokens,
)
.await
.expect("Unable to warmup model");
tracing::info!("Connected"); tracing::info!("Connected");
// Binds on localhost // Binds on localhost

View File

@ -19,11 +19,12 @@ class Cache:
def delete(self, batch_id: int): def delete(self, batch_id: int):
batch = self.pop(batch_id) batch = self.pop(batch_id)
if batch is not None: if batch is not None:
batch.cleanup() batch.free()
del batch del batch
def clear(self): def clear(self):
for k in self.cache.keys(): keys = list(self.cache.keys())
for k in keys:
self.delete(k) self.delete(k)
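The patched `clear` snapshots the keys first because `delete` removes entries from `self.cache` while the loop runs, and mutating a dict during direct iteration over its keys raises `RuntimeError` in Python 3. A minimal reproduction:

# Why clear() snapshots the keys: deleting from a dict while iterating over it
# directly raises "RuntimeError: dictionary changed size during iteration".
cache = {1: "a", 2: "b"}
try:
    for k in cache.keys():
        del cache[k]
except RuntimeError as e:
    print(e)

cache = {1: "a", 2: "b"}
for k in list(cache.keys()):  # snapshot, as in the patched clear()
    del cache[k]
assert cache == {}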
def __len__(self): def __len__(self):

View File

@ -122,7 +122,7 @@ class CausalLMBatch(Batch):
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
max_tokens = len(inputs) * max_input_length + max_decode_tokens max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
return cls( return cls(
batch_id=pb.id, batch_id=pb.id,
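The one-line fix above changes the grouping so `max_decode_tokens` is multiplied by the batch size along with `max_input_length`. How `max_decode_tokens` is accumulated is not visible in this hunk, so the numbers below only illustrate the arithmetic difference between the two expressions:

# Arithmetic difference between the old and new expressions; how
# max_decode_tokens is accumulated is not shown in this hunk, so the
# numbers only illustrate the change in grouping.
n, max_input_length, max_decode_tokens = 4, 100, 20
old = n * max_input_length + max_decode_tokens    # 420
new = n * (max_input_length + max_decode_tokens)  # 480
print(old, new)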

View File

@ -1,3 +1,4 @@
import math
import itertools import itertools
import torch import torch
import torch.distributed import torch.distributed
@ -5,6 +6,7 @@ import torch.distributed
import numpy as np import numpy as np
from dataclasses import dataclass from dataclasses import dataclass
from loguru import logger
from opentelemetry import trace from opentelemetry import trace
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict from typing import Optional, Tuple, List, Type, Union, Dict
@ -21,6 +23,7 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
BLOCK_SIZE = 16
# Will be set in warmup # Will be set in warmup
CACHE_MANAGER: Optional["CacheManager"] = None CACHE_MANAGER: Optional["CacheManager"] = None
@ -35,7 +38,7 @@ class CacheManager:
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
): ):
self.block_size = 16 self.block_size = BLOCK_SIZE
element_size = torch.tensor([], dtype=dtype).element_size() element_size = torch.tensor([], dtype=dtype).element_size()
x = self.block_size // element_size x = self.block_size // element_size
@ -60,26 +63,30 @@ class CacheManager:
0, num_blocks * self.block_size, dtype=torch.int32 0, num_blocks * self.block_size, dtype=torch.int32
).view(num_blocks, self.block_size) ).view(num_blocks, self.block_size)
def allocate(self, n_tokens: int) -> Tuple[List[int], torch.Tensor]: def allocate(self, num_blocks: int) -> Tuple[torch.Tensor, torch.Tensor]:
# Number of blocks needed to cover all tokens
needed_blocks = (n_tokens // self.block_size) + 1
# Get free blocks indices by finding values in mask that are not set to 0 # Get free blocks indices by finding values in mask that are not set to 0
free_block_indices = self.free_block_mask.nonzero() free_block_indices = self.free_block_mask.nonzero()
assert len(free_block_indices) >= needed_blocks, "Out of available cache blocks" logger.info(f"Free blocks: {len(free_block_indices)}")
assert (
len(free_block_indices) >= num_blocks
), f"Out of available cache blocks: asked {num_blocks}, only {len(free_block_indices)} free blocks"
# Allocate the required number of blocks by setting the mask to 0 # Allocate the required number of blocks by setting the mask to 0
block_indices = free_block_indices[:needed_blocks] block_indices = free_block_indices[:num_blocks]
self.free_block_mask[block_indices] = 0 self.free_block_mask[block_indices] = 0
# Get slots for the allocated blocks # Get slots for the allocated blocks
slots = self.slots[block_indices].flatten()[:n_tokens] slots = self.slots[block_indices].flatten()
return block_indices.flatten().tolist(), slots logger.info(f"allocate {num_blocks} blocks")
def free(self, block_indices: List[int]): return block_indices.flatten(), slots
# Reset mask
self.free_block_mask[block_indices] = 1 def free(self, block_indices: Optional[List[int]]):
if block_indices is not None:
# Reset mask
logger.info(f"free {len(block_indices)} blocks")
self.free_block_mask[block_indices] = 1
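The hunk above rewrites the allocator around fixed-size blocks: a mask tracks free blocks, `allocate(num_blocks)` hands out block indices plus their flattened slots, and `free` returns them. A consolidated, framework-free sketch of that bookkeeping, using NumPy in place of torch tensors (a minimal model of the class above, not the exact implementation):

# Consolidated sketch of the block allocator above, with NumPy standing in
# for torch tensors. The key/value cache tensors themselves are omitted.
import numpy as np

BLOCK_SIZE = 16

class TinyBlockAllocator:
    def __init__(self, num_blocks: int):
        self.block_size = BLOCK_SIZE
        self.free_block_mask = np.ones(num_blocks, dtype=np.int64)
        # slot i of block b has the global index b * BLOCK_SIZE + i
        self.slots = np.arange(num_blocks * BLOCK_SIZE, dtype=np.int64).reshape(
            num_blocks, BLOCK_SIZE
        )

    def allocate(self, num_blocks: int):
        free = np.flatnonzero(self.free_block_mask)
        assert len(free) >= num_blocks, (
            f"Out of cache blocks: asked {num_blocks}, only {len(free)} free"
        )
        block_indices = free[:num_blocks]
        self.free_block_mask[block_indices] = 0  # mark as used
        return block_indices, self.slots[block_indices].reshape(-1)

    def free(self, block_indices):
        if block_indices is not None:
            self.free_block_mask[block_indices] = 1  # hand the blocks back

allocator = TinyBlockAllocator(num_blocks=8)
blocks, slots = allocator.allocate(3)
print(blocks.tolist(), len(slots))  # [0, 1, 2] 48
allocator.free(blocks)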
@dataclass @dataclass
@ -97,16 +104,25 @@ class FlashCausalLMBatch(Batch):
start_seq_prefill: Optional[torch.Tensor] start_seq_prefill: Optional[torch.Tensor]
# tensor of length b holding ending offset of each sequence, only used in prefill # tensor of length b holding ending offset of each sequence, only used in prefill
end_seq_prefill: Optional[torch.Tensor] end_seq_prefill: Optional[torch.Tensor]
# list of length b of list of length s_i // block_size
block_tables: List[List[int]] # Paged Attention values
# tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
block_tables_tensor: torch.Tensor # Set when creating the batch
# CPU tensor of length b indicating the start of each sequence in slots # CPU tensor of length b indicating the start of each sequence in slots
start_slots: torch.Tensor start_slots: torch.Tensor
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
slots: torch.Tensor
# tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
slot_indices: torch.Tensor slot_indices: torch.Tensor
# List of tuples of ints representing the number of blocks and slots needed by each sequence
needed_blocks_slots: Optional[List[Tuple[int, int]]]
# Set in prefill by the CacheManager
# list of length b of list of length s_i // block_size
block_tables: Optional[List[List[int]]]
# tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
block_tables_tensor: Optional[torch.Tensor]
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
slots: Optional[torch.Tensor]
max_seqlen: int max_seqlen: int
# Prefill metadata tensors to efficiently compute logprobs # Prefill metadata tensors to efficiently compute logprobs
@ -128,16 +144,17 @@ class FlashCausalLMBatch(Batch):
next_token_chooser: HeterogeneousNextTokenChooser next_token_chooser: HeterogeneousNextTokenChooser
stopping_criterias: List[StoppingCriteria] stopping_criterias: List[StoppingCriteria]
# Number of blocks in this batch
blocks: int
# Maximum number of blocks # Maximum number of blocks
max_blocks: int max_blocks: int
def to_pb(self) -> generate_pb2.CachedBatch: def to_pb(self) -> generate_pb2.CachedBatch:
global CACHE_MANAGER
return generate_pb2.CachedBatch( return generate_pb2.CachedBatch(
id=self.batch_id, id=self.batch_id,
request_ids=[r.id for r in self.requests], request_ids=[r.id for r in self.requests],
size=len(self), size=len(self),
max_tokens=len(self.slots), max_tokens=self.blocks * BLOCK_SIZE,
) )
@classmethod @classmethod
@ -148,8 +165,6 @@ class FlashCausalLMBatch(Batch):
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
) -> "FlashCausalLMBatch": ) -> "FlashCausalLMBatch":
global CACHE_MANAGER
batch_inputs = [] batch_inputs = []
max_truncation = 0 max_truncation = 0
for r in pb.requests: for r in pb.requests:
@ -163,9 +178,8 @@ class FlashCausalLMBatch(Batch):
position_ids = [] position_ids = []
start_seq_prefill = [] start_seq_prefill = []
end_seq_prefill = [] end_seq_prefill = []
block_tables = [] needed_blocks_slots = []
start_slots = [] start_slots = []
slots = []
slot_indices = [] slot_indices = []
input_lengths = [] input_lengths = []
@ -188,6 +202,7 @@ class FlashCausalLMBatch(Batch):
cumulative_max_length = 0 cumulative_max_length = 0
prefill_out_cumulative_length = 0 prefill_out_cumulative_length = 0
blocks = 0
max_seqlen = 0 max_seqlen = 0
max_length = 0 max_length = 0
max_blocks = 0 max_blocks = 0
@ -228,9 +243,9 @@ class FlashCausalLMBatch(Batch):
# Paged attention # Paged attention
# Remove one as the first token does not have a past # Remove one as the first token does not have a past
total_tokens = input_length + max_new_tokens - 1 total_tokens = input_length + max_new_tokens - 1
request_blocks, request_slots = CACHE_MANAGER.allocate(total_tokens) needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
block_tables.append(request_blocks) blocks += needed_blocks
slots.extend(request_slots) needed_blocks_slots.append((needed_blocks, total_tokens))
start_slots.append(cumulative_max_length) start_slots.append(cumulative_max_length)
request_slot_indices = torch.arange( request_slot_indices = torch.arange(
@ -264,7 +279,7 @@ class FlashCausalLMBatch(Batch):
cumulative_length += input_length cumulative_length += input_length
cumulative_max_length += total_tokens cumulative_max_length += total_tokens
max_seqlen = max(max_seqlen, input_length) max_seqlen = max(max_seqlen, input_length)
max_blocks = max(max_blocks, len(request_blocks)) max_blocks = max(max_blocks, needed_blocks)
max_length = max(max_length, input_length + max_new_tokens) max_length = max(max_length, input_length + max_new_tokens)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
@ -272,15 +287,6 @@ class FlashCausalLMBatch(Batch):
) )
start_slots = torch.tensor(start_slots, dtype=torch.int64) start_slots = torch.tensor(start_slots, dtype=torch.int64)
# Padded block tables
block_tables_tensor = torch.zeros(
(len(pb.requests), max_blocks), dtype=torch.int32
)
for i, request_blocks in enumerate(block_tables):
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
block_tables_tensor = block_tables_tensor.to(device)
# Padded all_input_ids_tensor # Padded all_input_ids_tensor
all_input_ids_tensor = np.zeros( all_input_ids_tensor = np.zeros(
(len(all_input_ids), max_length), dtype=np.int64 (len(all_input_ids), max_length), dtype=np.int64
@ -312,7 +318,6 @@ class FlashCausalLMBatch(Batch):
position_ids = position_ids.to(device) position_ids = position_ids.to(device)
slot_indices = slot_indices.to(device) slot_indices = slot_indices.to(device)
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
slots = torch.tensor(slots, dtype=torch.int32, device=device)
input_lengths_tensor = torch.tensor( input_lengths_tensor = torch.tensor(
input_lengths, dtype=torch.int32, device=device input_lengths, dtype=torch.int32, device=device
) )
@ -339,11 +344,12 @@ class FlashCausalLMBatch(Batch):
position_ids=position_ids, position_ids=position_ids,
start_seq_prefill=start_seq_prefill, start_seq_prefill=start_seq_prefill,
end_seq_prefill=end_seq_prefill, end_seq_prefill=end_seq_prefill,
block_tables=block_tables,
block_tables_tensor=block_tables_tensor,
start_slots=start_slots, start_slots=start_slots,
slots=slots,
slot_indices=slot_indices, slot_indices=slot_indices,
needed_blocks_slots=needed_blocks_slots,
block_tables=None,
block_tables_tensor=None,
slots=None,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
prefill_head_indices=prefill_head_indices, prefill_head_indices=prefill_head_indices,
prefill_next_token_indices=prefill_next_token_indices, prefill_next_token_indices=prefill_next_token_indices,
@ -356,12 +362,12 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser, next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
blocks=blocks,
max_blocks=max_blocks, max_blocks=max_blocks,
) )
@tracer.start_as_current_span("filter") @tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
global CACHE_MANAGER
if len(request_ids) == 0: if len(request_ids) == 0:
raise ValueError("Batch must have at least one request") raise ValueError("Batch must have at least one request")
# We assume that if len(requests) == len(self) then the requests are the same # We assume that if len(requests) == len(self) then the requests are the same
@ -396,6 +402,7 @@ class FlashCausalLMBatch(Batch):
stopping_criterias = [] stopping_criterias = []
blocks = 0
max_blocks = 0 max_blocks = 0
# Cumulative length # Cumulative length
cumulative_max_length = 0 cumulative_max_length = 0
@ -425,6 +432,7 @@ class FlashCausalLMBatch(Batch):
) )
request_block_table = self.block_tables[idx] request_block_table = self.block_tables[idx]
blocks += len(request_block_table)
block_tables.append(request_block_table) block_tables.append(request_block_table)
start_slots.append(cumulative_max_length) start_slots.append(cumulative_max_length)
@ -443,6 +451,7 @@ class FlashCausalLMBatch(Batch):
max_blocks = max(max_blocks, len(request_block_table)) max_blocks = max(max_blocks, len(request_block_table))
global CACHE_MANAGER
# Iterate on all requests # Iterate on all requests
for i, r in enumerate(self.requests): for i, r in enumerate(self.requests):
# Filter requests that are not part of the new batch # Filter requests that are not part of the new batch
@ -472,11 +481,12 @@ class FlashCausalLMBatch(Batch):
position_ids=position_ids, position_ids=position_ids,
start_seq_prefill=None, start_seq_prefill=None,
end_seq_prefill=None, end_seq_prefill=None,
start_slots=start_slots,
slot_indices=slot_indices,
needed_blocks_slots=None,
block_tables=block_tables, block_tables=block_tables,
block_tables_tensor=block_tables_tensor, block_tables_tensor=block_tables_tensor,
start_slots=start_slots,
slots=slots, slots=slots,
slot_indices=slot_indices,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
prefill_head_indices=None, prefill_head_indices=None,
prefill_next_token_indices=None, prefill_next_token_indices=None,
@ -489,17 +499,18 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser, next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
blocks=blocks,
max_blocks=max_blocks, max_blocks=max_blocks,
) )
@classmethod @classmethod
@tracer.start_as_current_span("concatenate") @tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch": def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
global CACHE_MANAGER
# Batch attributes # Batch attributes
requests = [] requests = []
requests_idx_mapping = {} requests_idx_mapping = {}
blocks = 0
total_batch_size = 0 total_batch_size = 0
total_slots = 0 total_slots = 0
max_blocks = 0 max_blocks = 0
@ -508,6 +519,7 @@ class FlashCausalLMBatch(Batch):
for b in batches: for b in batches:
total_batch_size += len(b) total_batch_size += len(b)
total_slots += len(b.slots) total_slots += len(b.slots)
blocks += b.blocks
max_blocks = max(max_blocks, b.max_blocks) max_blocks = max(max_blocks, b.max_blocks)
max_seqlen = max(max_seqlen, b.max_seqlen) max_seqlen = max(max_seqlen, b.max_seqlen)
max_length = max( max_length = max(
@ -613,11 +625,12 @@ class FlashCausalLMBatch(Batch):
position_ids=position_ids, position_ids=position_ids,
start_seq_prefill=None, start_seq_prefill=None,
end_seq_prefill=None, end_seq_prefill=None,
start_slots=start_slots,
slot_indices=slot_indices,
needed_blocks_slots=None,
block_tables=block_tables, block_tables=block_tables,
block_tables_tensor=block_tables_tensor, block_tables_tensor=block_tables_tensor,
start_slots=start_slots,
slots=slots, slots=slots,
slot_indices=slot_indices,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
prefill_head_indices=None, prefill_head_indices=None,
prefill_next_token_indices=None, prefill_next_token_indices=None,
@ -630,13 +643,15 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser, next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
blocks=blocks,
max_blocks=max_blocks, max_blocks=max_blocks,
) )
def cleanup(self): def free(self):
global CACHE_MANAGER if self.block_tables is not None:
# Free blocks global CACHE_MANAGER
CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables))) # Free blocks
CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables)))
def __len__(self): def __len__(self):
return len(self.requests) return len(self.requests)
@ -648,22 +663,17 @@ class FlashCausalLM(Model):
model: torch.nn.Module, model: torch.nn.Module,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
num_layers: int, num_layers: int,
num_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
rank: int = 0, rank: int = 0,
world_size: int = 1, world_size: int = 1,
): ):
self.num_heads = num_heads self.num_layers = num_layers
self.num_kv_heads = num_kv_heads
self.head_size = head_size self.head_size = head_size
global CACHE_MANAGER
torch.cuda.set_per_process_memory_fraction(1.0)
CACHE_MANAGER = CacheManager(
1000, num_layers, num_heads, head_size, dtype, device
)
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
@ -678,6 +688,30 @@ class FlashCausalLM(Model):
def batch_type(self) -> Type[FlashCausalLMBatch]: def batch_type(self) -> Type[FlashCausalLMBatch]:
return FlashCausalLMBatch return FlashCausalLMBatch
def warmup(self, batch: FlashCausalLMBatch, max_total_tokens: int):
global CACHE_MANAGER
torch.cuda.empty_cache()
try:
CACHE_MANAGER = CacheManager(
# Adds some wiggle room
math.ceil(max_total_tokens / BLOCK_SIZE) + 10,
self.num_layers,
self.num_kv_heads,
self.head_size,
self.dtype,
self.device,
)
_, batch = self.generate_token(batch)
except Exception as e:
logger.error(
f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
f"prefill tokens. "
f"You need to decrease `--max-batch-total-tokens` and `--max-batch-prefill-tokens`"
)
raise e
batch.free()
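`warmup` sizes the paged KV cache from `max_total_tokens` (ceil-divided by `BLOCK_SIZE`, plus a small margin) and runs one full `generate_token` pass so an out-of-memory condition surfaces at startup rather than under load. The sketch below shows the sizing arithmetic and a rough memory estimate; the memory formula (two tensors per layer of `BLOCK_SIZE * num_kv_heads * head_size` elements per block) is an assumed standard paged-KV layout, not something stated in this diff:

# Sizing arithmetic of the warmup above, plus a rough KV-cache memory estimate.
# The memory formula is an assumed standard paged-KV layout, not read from this diff.
import math

BLOCK_SIZE = 16

def num_cache_blocks(max_total_tokens: int, margin: int = 10) -> int:
    return math.ceil(max_total_tokens / BLOCK_SIZE) + margin  # "wiggle room"

def kv_cache_bytes(num_blocks: int, num_layers: int, num_kv_heads: int,
                   head_size: int, dtype_bytes: int = 2) -> int:
    per_block_per_layer = 2 * BLOCK_SIZE * num_kv_heads * head_size * dtype_bytes
    return num_blocks * num_layers * per_block_per_layer

blocks = num_cache_blocks(32_000)  # 2010 blocks for a 32000-token budget
# illustrative 7B-LLaMA-like config: 32 layers, 32 KV heads, head size 128, fp16
gib = kv_cache_bytes(blocks, num_layers=32, num_kv_heads=32, head_size=128) / 2**30
print(blocks, f"{gib:.1f} GiB")    # roughly 15.7 GiB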
def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str: def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
return self.tokenizer.decode( return self.tokenizer.decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
@ -718,6 +752,35 @@ class FlashCausalLM(Model):
prefill = batch.start_seq_prefill is not None prefill = batch.start_seq_prefill is not None
prefill_logprobs = batch.prefill_next_token_indices is not None prefill_logprobs = batch.prefill_next_token_indices is not None
if batch.needed_blocks_slots:
# Padded block tables
block_tables_tensor = torch.zeros(
(len(batch), batch.max_blocks), dtype=torch.int32
)
# Allocate paged attention blocks
slots = []
block_tables = []
try:
for i, (needed_blocks, needed_slots) in enumerate(
batch.needed_blocks_slots
):
allocated_blocks, allocated_slots = CACHE_MANAGER.allocate(
needed_blocks
)
slots.append(allocated_slots[:needed_slots])
block_tables.append(allocated_blocks.tolist())
block_tables_tensor[i, :needed_blocks] = allocated_blocks
except Exception as e:
for blocks in block_tables:
CACHE_MANAGER.free(blocks)
raise e
batch.needed_blocks_slots = None
batch.block_tables = block_tables
batch.block_tables_tensor = block_tables_tensor.to(self.device)
batch.slots = torch.concat(slots).to(self.device)
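The prefill path above allocates blocks lazily, one request at a time, and rolls back on failure: if any allocation raises, every block handed out so far is freed before the exception propagates, so a failed batch cannot leak cache blocks. A compact sketch of that try/rollback pattern; `allocator` stands for any object with the `allocate`/`free` interface sketched earlier:

# Rollback pattern used during prefill above: allocate blocks request by request;
# if anything fails, free what was already handed out, then re-raise.
# `allocator` is any object exposing allocate(n) -> (blocks, slots) and free(blocks),
# e.g. the TinyBlockAllocator sketched earlier.

def allocate_batch(allocator, needed_blocks_slots):
    slots, block_tables = [], []
    try:
        for needed_blocks, needed_slots in needed_blocks_slots:
            blocks, block_slots = allocator.allocate(needed_blocks)
            slots.append(block_slots[:needed_slots])
            block_tables.append(list(blocks))
    except Exception:
        for blocks in block_tables:
            allocator.free(blocks)  # give back partial allocations
        raise
    return block_tables, slots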
out = self.forward( out = self.forward(
batch.input_ids, batch.input_ids,
batch.position_ids, batch.position_ids,
@ -931,7 +994,7 @@ class FlashCausalLM(Model):
batch.all_input_ids[i] = all_input_ids batch.all_input_ids[i] = all_input_ids
if stopped: if stopped:
batch.cleanup() batch.free()
# No need to return a batch if we know that all requests stopped # No need to return a batch if we know that all requests stopped
return generations, None return generations, None

View File

@ -68,7 +68,7 @@ class FlashLlama(FlashCausalLM):
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
num_layers=len(model.model.layers), num_layers=len(model.model.layers),
num_heads=model.model.num_heads, num_kv_heads=model.model.num_heads,
head_size=model.model.head_size, head_size=model.model.head_size,
dtype=dtype, dtype=dtype,
device=device, device=device,

View File

@ -22,6 +22,9 @@ class Model(ABC):
rank: int = 0, rank: int = 0,
world_size: int = 1, world_size: int = 1,
): ):
if torch.cuda.is_available():
torch.cuda.set_per_process_memory_fraction(1.0)
self.model = model.eval() self.model = model.eval()
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.all_special_ids = set(tokenizer.all_special_ids) self.all_special_ids = set(tokenizer.all_special_ids)
@ -55,6 +58,9 @@ class Model(ABC):
def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]: def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
raise NotImplementedError raise NotImplementedError
def warmup(self, batch: B, max_total_tokens: int):
self.generate_token(batch)
def decode_token( def decode_token(
self, self,
all_input_ids: List[int], all_input_ids: List[int],

View File

@ -127,7 +127,7 @@ class Seq2SeqLMBatch(Batch):
read_offsets.append(1) read_offsets.append(1)
all_decoder_input_ids = decoder_input_ids.view(-1).split(1) all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
max_tokens = len(inputs) * max_input_length + max_decode_tokens max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
return cls( return cls(
batch_id=pb.id, batch_id=pb.id,

View File

@ -35,7 +35,7 @@ class Batch(ABC):
def concatenate(cls, batches: List["Batch"]) -> "Batch": def concatenate(cls, batches: List["Batch"]) -> "Batch":
raise NotImplementedError raise NotImplementedError
def cleanup(self): def free(self):
pass pass
@abstractmethod @abstractmethod

View File

@ -53,12 +53,24 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
async def Warmup(self, request, context):
batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
)
self.model.warmup(batch, request.max_total_tokens)
return generate_pb2.WarmupResponse()
async def Prefill(self, request, context): async def Prefill(self, request, context):
batch = self.model.batch_type.from_pb( batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.dtype, self.model.device request.batch, self.model.tokenizer, self.model.dtype, self.model.device
) )
generations, next_batch = self.model.generate_token(batch) try:
generations, next_batch = self.model.generate_token(batch)
except Exception as e:
batch.free()
raise e
self.cache.set(next_batch) self.cache.set(next_batch)
return generate_pb2.PrefillResponse( return generate_pb2.PrefillResponse(
@ -81,11 +93,20 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
raise ValueError("All batches are empty") raise ValueError("All batches are empty")
if len(batches) > 1: if len(batches) > 1:
batch = self.model.batch_type.concatenate(batches) try:
batch = self.model.batch_type.concatenate(batches)
except Exception as e:
[batch.free() for batch in batches]
raise e
else: else:
batch = batches[0] batch = batches[0]
generations, next_batch = self.model.generate_token(batch) try:
generations, next_batch = self.model.generate_token(batch)
except Exception as e:
batch.free()
raise e
self.cache.set(next_batch) self.cache.set(next_batch)
return generate_pb2.DecodeResponse( return generate_pb2.DecodeResponse(
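The handlers above wrap `concatenate` and `generate_token` in try/except so a failing batch releases its paged-attention blocks (`batch.free()`) before the gRPC error propagates; without that, an exception mid-prefill or mid-decode would leak blocks for the life of the process. A minimal sketch of the same guard factored into a helper (`run_or_free` is a hypothetical name, not part of the server):

# The guard used in Prefill/Decode above, factored into a helper.
# `run_or_free` is a hypothetical name, not part of the actual server.
def run_or_free(fn, batches_to_free):
    try:
        return fn()
    except Exception:
        for batch in batches_to_free:
            batch.free()  # release paged-attention blocks before re-raising
        raise

# Usage mirroring the handlers above (illustrative):
#   batch = run_or_free(lambda: model.batch_type.concatenate(batches), batches)
#   generations, next_batch = run_or_free(lambda: model.generate_token(batch), [batch])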