commit ddfc02f2a4
parent d649cd8e02
repository: https://github.com/huggingface/text-generation-inference.git (mirror)

    add warmup
@@ -128,7 +128,10 @@ struct Args {
     #[clap(default_value = "1.2", long, env)]
     waiting_served_ratio: f32,

-    #[clap(default_value = "32000", long, env)]
+    /// Limits the number of tokens for the prefill operation.
+    /// Since this operation takes the most memory and is compute bound, it is interesting
+    /// to limit the number of requests that can be sent.
+    #[clap(default_value = "4096", long, env)]
     max_batch_prefill_tokens: u32,

     /// **IMPORTANT** This is one critical control to allow maximum usage
@@ -143,13 +146,6 @@ struct Args {
     /// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100`
     /// or a single query of `1000` tokens.
     ///
-    /// So you don't have to control that finely
-    /// `max_batch_size` or `max_total_tokens`. In fact you could mostly relax them if you
-    /// want maximum flexibility. However, for your users if they are asking for the full amount of
-    /// total tokens, they are likely to wait for a very long time to get a spot
-    /// in the batch (since they are going to be alone) so setting `max_batch_size`
-    /// and `max_total_tokens` can still be useful to prevent those long waiting times.
-    ///
     /// Overall this number should be the largest possible amount that fits the
     /// remaining memory (after the model is loaded). Since the actual memory overhead
     /// depends on other parameters like if you're using quantization, flash attention
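Note: the surviving doc comment above works a small budget example: with `max_batch_total_tokens=1000` the batch can hold ten requests of 100 total tokens, or one request of 1000. A minimal Python sketch of that arithmetic; `fits_in_budget` and the numbers are illustrative only, not part of the codebase:

def fits_in_budget(batch_total_tokens: int, request_total_tokens: int,
                   max_batch_total_tokens: int = 1000) -> bool:
    # A request reserves its full total_tokens (prompt + generated) for its lifetime.
    return batch_total_tokens + request_total_tokens <= max_batch_total_tokens

used = 0
admitted = 0
while fits_in_budget(used, 100):
    used += 100
    admitted += 1
assert admitted == 10                 # ten queries of total_tokens=100

assert fits_in_budget(0, 1000)        # or a single query of 1000 tokens
assert not fits_in_budget(1000, 1)    # nothing else fits after that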
@@ -448,7 +444,7 @@ fn shard_manager(

        // We received a shutdown signal
        if *shutdown.lock().unwrap() {
-            p.terminate().unwrap();
+            p.kill().unwrap();
            let _ = p.wait_timeout(Duration::from_secs(90));
            tracing::info!("Shard {rank} terminated");
            return;
@@ -11,6 +11,8 @@ service TextGenerationService {
     rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
     /// Remove requests from a cached batch
     rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);
+    /// Warmup the model and compute max cache size
+    rpc Warmup (WarmupRequest) returns (WarmupResponse);
     /// Prefill batch and decode first token
     rpc Prefill (PrefillRequest) returns (PrefillResponse);
     /// Decode token for a list of prefilled batches
@@ -192,3 +194,13 @@ message DecodeResponse {
     /// Next batch (cached)
     optional CachedBatch batch = 2;
 }
+
+message WarmupRequest {
+    /// Batch to warmup on
+    Batch batch = 1;
+    /// Maximum number of tokens that the client will send
+    uint32 max_total_tokens = 2;
+}
+
+/// Empty response
+message WarmupResponse {}
@@ -3,6 +3,7 @@ use crate::pb::generate::v1::text_generation_service_client::TextGenerationServi
 use crate::pb::generate::v1::*;
 use crate::Result;
 use grpc_metadata::InjectTelemetryContext;
+use std::cmp::min;
 use tonic::transport::{Channel, Uri};
 use tracing::instrument;

@@ -94,6 +95,63 @@ impl Client {
         Ok(filtered_batch.batch)
     }

+    /// Warmup on a max size batch
+    ///
+    /// Returns the maximum amount of tokens supported by the hardware
+    #[instrument(skip(self))]
+    pub async fn warmup(
+        &mut self,
+        max_input_length: u32,
+        max_prefill_tokens: u32,
+        max_total_tokens: u32,
+    ) -> Result<()> {
+        let mut n_tokens = 0;
+        let mut requests = Vec::new();
+
+        // Create requests
+        while n_tokens < max_prefill_tokens {
+            requests.push(Request {
+                id: 0,
+                // We truncate the input on the server side to be sure that it has the correct size
+                inputs: "test".to_string().repeat(max_input_length as usize),
+                truncate: min(max_input_length, max_prefill_tokens - n_tokens),
+                // Set sampling parameters to also take these ops into account in the max memory
+                parameters: Some(NextTokenChooserParameters {
+                    temperature: 0.9,
+                    top_k: 10,
+                    top_p: 0.9,
+                    typical_p: 0.9,
+                    do_sample: false,
+                    seed: 0,
+                    repetition_penalty: 1.2,
+                    watermark: true,
+                }),
+                stopping_parameters: Some(StoppingCriteriaParameters {
+                    max_new_tokens: 2,
+                    stop_sequences: vec![],
+                    ignore_eos_token: false,
+                }),
+                prefill_logprobs: true,
+            });
+            n_tokens += max_input_length;
+        }
+
+        let batch = Batch {
+            id: 0,
+            size: requests.len() as u32,
+            requests,
+            max_tokens: 0,
+        };
+
+        let request = tonic::Request::new(WarmupRequest {
+            batch: Some(batch),
+            max_total_tokens,
+        })
+        .inject_context();
+        self.stub.warmup(request).await?.into_inner();
+        Ok(())
+    }
+
     /// Generate one token for each request in the given batch
     ///
     /// Returns Generation for each request in batch
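Note: the `warmup` client method added above builds one oversized synthetic batch: it keeps pushing `max_input_length`-sized dummy requests until `max_prefill_tokens` is covered, truncating the last one so the prefill budget is never exceeded. A rough Python rendering of that sizing loop; the helper name `plan_warmup_truncates` is made up:

def plan_warmup_truncates(max_input_length: int, max_prefill_tokens: int):
    # Truncate value for each synthetic warmup request, mirroring the Rust loop.
    truncates = []
    n_tokens = 0
    while n_tokens < max_prefill_tokens:
        truncates.append(min(max_input_length, max_prefill_tokens - n_tokens))
        n_tokens += max_input_length
    return truncates

# max_input_length=1024, max_prefill_tokens=4096: four full-size requests
assert plan_warmup_truncates(1024, 4096) == [1024, 1024, 1024, 1024]
# max_input_length=1000, max_prefill_tokens=2500: the last request is truncated to 500
assert plan_warmup_truncates(1000, 2500) == [1000, 1000, 500]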
@@ -87,6 +87,27 @@ impl ShardedClient {
         join_all(futures).await.pop().unwrap()
     }

+    /// Warmup on a max size batch
+    ///
+    /// Returns the maximum amount of tokens supported by the hardware
+    #[instrument(skip(self))]
+    pub async fn warmup(
+        &mut self,
+        max_input_length: u32,
+        max_prefill_tokens: u32,
+        max_total_tokens: u32,
+    ) -> Result<()> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| {
+                Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
+            })
+            .collect();
+        // all shards return the same message
+        join_all(futures).await.pop().unwrap()
+    }
+
     /// Generate one token for each request in the given batch
     ///
     /// Returns Generation for each request in batch
@@ -242,6 +242,7 @@ impl Infer {
 /// Will be launched in a background Tokio task
 ///
 /// Batches requests and sends them to the inference server
+#[allow(clippy::too_many_arguments)]
 async fn batching_task(
     mut client: ShardedClient,
     waiting_served_ratio: f32,
@@ -288,7 +289,8 @@ async fn batching_task(
                 Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
             };

-            let token_budget = max_batch_total_tokens - batch_max_tokens;
+            let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
+            tracing::info!("{token_budget} {batch_max_tokens}");

            // Try to get a new batch
            if let Some((mut new_entries, new_batch, span)) = queue
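Note: `saturating_sub` replaces a plain `u32` subtraction that could underflow whenever a running batch already holds more tokens than `max_batch_total_tokens` allows. In Python terms the new line clamps at zero instead of wrapping around; `token_budget` below is only an illustration:

def token_budget(max_batch_total_tokens: int, batch_max_tokens: int) -> int:
    # Clamp at zero instead of letting the unsigned subtraction underflow.
    return max(0, max_batch_total_tokens - batch_max_tokens)

assert token_budget(32000, 12000) == 20000
assert token_budget(16000, 32000) == 0   # the old expression would underflow here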
@@ -34,7 +34,7 @@ struct Args {
     max_total_tokens: usize,
     #[clap(default_value = "1.2", long, env)]
     waiting_served_ratio: f32,
-    #[clap(default_value = "32000", long, env)]
+    #[clap(default_value = "4096", long, env)]
     max_batch_prefill_tokens: u32,
     #[clap(default_value = "32000", long, env)]
     max_batch_total_tokens: u32,
@@ -180,16 +180,23 @@ fn main() -> Result<(), std::io::Error> {
     let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
         .await
         .expect("Could not connect to server");
-    // Clear the cache; useful if the webserver rebooted
-    sharded_client
-        .clear_cache(None)
-        .await
-        .expect("Unable to clear cache");
+
     // Get info from the shard
     let shard_info = sharded_client
         .info()
         .await
         .expect("Unable to get shard info");
+
+    // Warmup model
+    tracing::info!("Warming up model");
+    sharded_client
+        .warmup(
+            max_input_length as u32,
+            max_batch_prefill_tokens,
+            max_batch_total_tokens,
+        )
+        .await
+        .expect("Unable to warmup model");
     tracing::info!("Connected");

     // Binds on localhost
@@ -19,11 +19,12 @@ class Cache:
     def delete(self, batch_id: int):
         batch = self.pop(batch_id)
         if batch is not None:
-            batch.cleanup()
+            batch.free()
             del batch

     def clear(self):
-        for k in self.cache.keys():
+        keys = list(self.cache.keys())
+        for k in keys:
             self.delete(k)

     def __len__(self):
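Note: the `Cache.clear` change above stops deleting entries while iterating the dict's live key view, which raises at runtime. A minimal reproduction of the bug and of the fix:

cache = {1: "batch-1", 2: "batch-2"}

# Buggy pattern: mutating the dict while iterating its live view raises
# "RuntimeError: dictionary changed size during iteration".
try:
    for k in cache.keys():
        del cache[k]
except RuntimeError as e:
    print(e)

# Fixed pattern, as in the diff: snapshot the keys first.
cache = {1: "batch-1", 2: "batch-2"}
for k in list(cache.keys()):
    del cache[k]
assert cache == {}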
@@ -122,7 +122,7 @@ class CausalLMBatch(Batch):
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)

-        max_tokens = len(inputs) * max_input_length + max_decode_tokens
+        max_tokens = len(inputs) * (max_input_length + max_decode_tokens)

         return cls(
             batch_id=pb.id,
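Note: the one-line `max_tokens` fix above changes the per-batch accounting so that every request reserves room for its prompt and its decode tokens, not just one shared decode allowance. Toy numbers, illustrative only:

n_requests = 4
max_input_length = 100    # longest prompt in the batch
max_decode_tokens = 20    # most new tokens any request may generate

old_max_tokens = n_requests * max_input_length + max_decode_tokens    # 420
new_max_tokens = n_requests * (max_input_length + max_decode_tokens)  # 480

# The old formula reserved decode room for only one request, under-counting
# how far the batch can actually grow.
assert new_max_tokens - old_max_tokens == (n_requests - 1) * max_decode_tokens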
@@ -1,3 +1,4 @@
+import math
 import itertools
 import torch
 import torch.distributed
@@ -5,6 +6,7 @@ import torch.distributed
 import numpy as np

 from dataclasses import dataclass
+from loguru import logger
 from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Dict
@@ -21,6 +23,7 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke

 tracer = trace.get_tracer(__name__)

 BLOCK_SIZE = 16
+# Will be set in warmup
 CACHE_MANAGER: Optional["CacheManager"] = None

@@ -35,7 +38,7 @@ class CacheManager:
         dtype: torch.dtype,
         device: torch.device,
     ):
-        self.block_size = 16
+        self.block_size = BLOCK_SIZE

         element_size = torch.tensor([], dtype=dtype).element_size()
         x = self.block_size // element_size
@@ -60,26 +63,30 @@ class CacheManager:
             0, num_blocks * self.block_size, dtype=torch.int32
         ).view(num_blocks, self.block_size)

-    def allocate(self, n_tokens: int) -> Tuple[List[int], torch.Tensor]:
-        # Number of needed block to cover all tokens
-        needed_blocks = (n_tokens // self.block_size) + 1
-
+    def allocate(self, num_blocks: int) -> Tuple[torch.Tensor, torch.Tensor]:
         # Get free blocks indices by finding values in mask that are not set to 0
         free_block_indices = self.free_block_mask.nonzero()
-        assert len(free_block_indices) >= needed_blocks, "Out of available cache blocks"
+        logger.info(f"Free blocks: {len(free_block_indices)}")
+        assert (
+            len(free_block_indices) >= num_blocks
+        ), f"Out of available cache blocks: asked {num_blocks}, only {len(free_block_indices)} free blocks"

         # Allocate the required number of blocks by setting the mask to 0
-        block_indices = free_block_indices[:needed_blocks]
+        block_indices = free_block_indices[:num_blocks]
         self.free_block_mask[block_indices] = 0

         # Get slots for the allocated blocks
-        slots = self.slots[block_indices].flatten()[:n_tokens]
+        slots = self.slots[block_indices].flatten()

-        return block_indices.flatten().tolist(), slots
+        logger.info(f"allocate {num_blocks} blocks")

-    def free(self, block_indices: List[int]):
-        # Reset mask
-        self.free_block_mask[block_indices] = 1
+        return block_indices.flatten(), slots
+
+    def free(self, block_indices: Optional[List[int]]):
+        if block_indices is not None:
+            # Reset mask
+            logger.info(f"free {len(block_indices)} blocks")
+            self.free_block_mask[block_indices] = 1


 @dataclass
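Note: after this change `CacheManager.allocate` hands out a caller-specified number of whole blocks from a free-list mask, and callers derive the block count from the token count. A stripped-down sketch of the same bookkeeping in plain Python (no torch; `ToyBlockAllocator` is an assumed illustrative name, not the real class):

import math

BLOCK_SIZE = 16

class ToyBlockAllocator:
    def __init__(self, num_blocks: int):
        self.free = [1] * num_blocks   # 1 = free, 0 = in use (like free_block_mask)

    def allocate(self, num_blocks: int):
        free_indices = [i for i, f in enumerate(self.free) if f]
        assert len(free_indices) >= num_blocks, (
            f"Out of available cache blocks: asked {num_blocks}, "
            f"only {len(free_indices)} free blocks"
        )
        taken = free_indices[:num_blocks]
        for i in taken:
            self.free[i] = 0
        return taken

    def free_blocks(self, block_indices):
        if block_indices is not None:
            for i in block_indices:
                self.free[i] = 1

# The caller now computes the block count, e.g. for a 45-token sequence:
needed = math.ceil(45 / BLOCK_SIZE)        # 3 blocks
allocator = ToyBlockAllocator(num_blocks=8)
blocks = allocator.allocate(needed)
assert blocks == [0, 1, 2]
allocator.free_blocks(blocks)
assert sum(allocator.free) == 8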
@@ -97,16 +104,25 @@ class FlashCausalLMBatch(Batch):
     start_seq_prefill: Optional[torch.Tensor]
     # tensor of length b holding ending offset of each sequence, only used in prefill
     end_seq_prefill: Optional[torch.Tensor]
-    # list of length b of list of length s_i // block_size
-    block_tables: List[List[int]]
-    # tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
-    block_tables_tensor: torch.Tensor
+
+    # Paged Attention values
+
+    # Set when creating the batch
     # CPU tensor of length b indicating the start of each sequence in slots
     start_slots: torch.Tensor
-    # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
-    slots: torch.Tensor
     # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
     slot_indices: torch.Tensor
+    # List of tuple of ints representing the number of blocks and slots needed by each sequence
+    needed_blocks_slots: Optional[List[Tuple[int, int]]]
+
+    # Set in prefill by the CacheManager
+    # list of length b of list of length s_i // block_size
+    block_tables: Optional[List[List[int]]]
+    # tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
+    block_tables_tensor: Optional[torch.Tensor]
+    # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
+    slots: Optional[torch.Tensor]

     max_seqlen: int
+
     # Prefill metadata tensors to efficiently compute logprobs
@@ -128,16 +144,17 @@ class FlashCausalLMBatch(Batch):
     next_token_chooser: HeterogeneousNextTokenChooser
     stopping_criterias: List[StoppingCriteria]

+    # Number of blocks in this batch
+    blocks: int
     # Maximum number of blocks
     max_blocks: int

     def to_pb(self) -> generate_pb2.CachedBatch:
-        global CACHE_MANAGER
         return generate_pb2.CachedBatch(
             id=self.batch_id,
             request_ids=[r.id for r in self.requests],
             size=len(self),
-            max_tokens=len(self.slots),
+            max_tokens=self.blocks * BLOCK_SIZE,
         )

     @classmethod
@@ -148,8 +165,6 @@ class FlashCausalLMBatch(Batch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
-        global CACHE_MANAGER
-
         batch_inputs = []
         max_truncation = 0
         for r in pb.requests:
@@ -163,9 +178,8 @@ class FlashCausalLMBatch(Batch):
         position_ids = []
         start_seq_prefill = []
         end_seq_prefill = []
-        block_tables = []
+        needed_blocks_slots = []
         start_slots = []
-        slots = []
         slot_indices = []

         input_lengths = []
@@ -188,6 +202,7 @@ class FlashCausalLMBatch(Batch):
         cumulative_max_length = 0
         prefill_out_cumulative_length = 0

+        blocks = 0
         max_seqlen = 0
         max_length = 0
         max_blocks = 0
@@ -228,9 +243,9 @@ class FlashCausalLMBatch(Batch):
             # Paged attention
             # Remove one as the first token does not have a past
             total_tokens = input_length + max_new_tokens - 1
-            request_blocks, request_slots = CACHE_MANAGER.allocate(total_tokens)
-            block_tables.append(request_blocks)
-            slots.extend(request_slots)
+            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
+            blocks += needed_blocks
+            needed_blocks_slots.append((needed_blocks, total_tokens))
             start_slots.append(cumulative_max_length)

             request_slot_indices = torch.arange(
@@ -264,7 +279,7 @@ class FlashCausalLMBatch(Batch):
             cumulative_length += input_length
             cumulative_max_length += total_tokens
             max_seqlen = max(max_seqlen, input_length)
-            max_blocks = max(max_blocks, len(request_blocks))
+            max_blocks = max(max_blocks, needed_blocks)
             max_length = max(max_length, input_length + max_new_tokens)

         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
@@ -272,15 +287,6 @@ class FlashCausalLMBatch(Batch):
         )
         start_slots = torch.tensor(start_slots, dtype=torch.int64)

-        # Padded block tables
-        block_tables_tensor = torch.zeros(
-            (len(pb.requests), max_blocks), dtype=torch.int32
-        )
-        for i, request_blocks in enumerate(block_tables):
-            block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
-
-        block_tables_tensor = block_tables_tensor.to(device)
-
         # Padded all_input_ids_tensor
         all_input_ids_tensor = np.zeros(
             (len(all_input_ids), max_length), dtype=np.int64
@@ -312,7 +318,6 @@ class FlashCausalLMBatch(Batch):
         position_ids = position_ids.to(device)
         slot_indices = slot_indices.to(device)
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
-        slots = torch.tensor(slots, dtype=torch.int32, device=device)
         input_lengths_tensor = torch.tensor(
             input_lengths, dtype=torch.int32, device=device
         )
@@ -339,11 +344,12 @@ class FlashCausalLMBatch(Batch):
             position_ids=position_ids,
             start_seq_prefill=start_seq_prefill,
             end_seq_prefill=end_seq_prefill,
-            block_tables=block_tables,
-            block_tables_tensor=block_tables_tensor,
             start_slots=start_slots,
-            slots=slots,
             slot_indices=slot_indices,
+            needed_blocks_slots=needed_blocks_slots,
+            block_tables=None,
+            block_tables_tensor=None,
+            slots=None,
             max_seqlen=max_seqlen,
             prefill_head_indices=prefill_head_indices,
             prefill_next_token_indices=prefill_next_token_indices,
@@ -356,12 +362,12 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
+            blocks=blocks,
             max_blocks=max_blocks,
         )

     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
-        global CACHE_MANAGER
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         # We assume that if len(requests) == len(self) then the requests are the same
@@ -396,6 +402,7 @@ class FlashCausalLMBatch(Batch):

         stopping_criterias = []

+        blocks = 0
         max_blocks = 0
         # Cumulative length
         cumulative_max_length = 0
@@ -425,6 +432,7 @@ class FlashCausalLMBatch(Batch):
             )

             request_block_table = self.block_tables[idx]
+            blocks += len(request_block_table)
             block_tables.append(request_block_table)
             start_slots.append(cumulative_max_length)

@@ -443,6 +451,7 @@ class FlashCausalLMBatch(Batch):

             max_blocks = max(max_blocks, len(request_block_table))

+        global CACHE_MANAGER
         # Iterate on all requests
         for i, r in enumerate(self.requests):
             # Filter requests that are not part of the new batch
@@ -472,11 +481,12 @@ class FlashCausalLMBatch(Batch):
             position_ids=position_ids,
             start_seq_prefill=None,
             end_seq_prefill=None,
+            start_slots=start_slots,
+            slot_indices=slot_indices,
+            needed_blocks_slots=None,
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
-            start_slots=start_slots,
             slots=slots,
-            slot_indices=slot_indices,
             max_seqlen=max_seqlen,
             prefill_head_indices=None,
             prefill_next_token_indices=None,
@@ -489,17 +499,18 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
+            blocks=blocks,
             max_blocks=max_blocks,
         )

     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
-        global CACHE_MANAGER
         # Batch attributes
         requests = []
         requests_idx_mapping = {}

+        blocks = 0
         total_batch_size = 0
         total_slots = 0
         max_blocks = 0
@@ -508,6 +519,7 @@ class FlashCausalLMBatch(Batch):
         for b in batches:
             total_batch_size += len(b)
             total_slots += len(b.slots)
+            blocks += b.blocks
             max_blocks = max(max_blocks, b.max_blocks)
             max_seqlen = max(max_seqlen, b.max_seqlen)
             max_length = max(
@@ -613,11 +625,12 @@ class FlashCausalLMBatch(Batch):
             position_ids=position_ids,
             start_seq_prefill=None,
             end_seq_prefill=None,
+            start_slots=start_slots,
+            slot_indices=slot_indices,
+            needed_blocks_slots=None,
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
-            start_slots=start_slots,
             slots=slots,
-            slot_indices=slot_indices,
             max_seqlen=max_seqlen,
             prefill_head_indices=None,
             prefill_next_token_indices=None,
@@ -630,13 +643,15 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
+            blocks=blocks,
             max_blocks=max_blocks,
         )

-    def cleanup(self):
-        global CACHE_MANAGER
-        # Free blocks
-        CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables)))
+    def free(self):
+        if self.block_tables is not None:
+            global CACHE_MANAGER
+            # Free blocks
+            CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables)))

     def __len__(self):
         return len(self.requests)
@@ -648,22 +663,17 @@ class FlashCausalLM(Model):
         model: torch.nn.Module,
         tokenizer: PreTrainedTokenizerBase,
         num_layers: int,
-        num_heads: int,
+        num_kv_heads: int,
         head_size: int,
         dtype: torch.dtype,
         device: torch.device,
         rank: int = 0,
         world_size: int = 1,
     ):
-        self.num_heads = num_heads
+        self.num_layers = num_layers
+        self.num_kv_heads = num_kv_heads
         self.head_size = head_size

-        global CACHE_MANAGER
-        torch.cuda.set_per_process_memory_fraction(1.0)
-        CACHE_MANAGER = CacheManager(
-            1000, num_layers, num_heads, head_size, dtype, device
-        )
-
         super(FlashCausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
@@ -678,6 +688,30 @@ class FlashCausalLM(Model):
     def batch_type(self) -> Type[FlashCausalLMBatch]:
         return FlashCausalLMBatch

+    def warmup(self, batch: FlashCausalLMBatch, max_total_tokens: int):
+        global CACHE_MANAGER
+
+        torch.cuda.empty_cache()
+        try:
+            CACHE_MANAGER = CacheManager(
+                # Adds some wiggle room
+                math.ceil(max_total_tokens / BLOCK_SIZE) + 10,
+                self.num_layers,
+                self.num_kv_heads,
+                self.head_size,
+                self.dtype,
+                self.device,
+            )
+            _, batch = self.generate_token(batch)
+        except Exception as e:
+            logger.error(
+                f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
+                f"prefill tokens. "
+                f"You need to decrease `--max-batch-total-tokens` and `--max-batch-prefill-tokens`"
+            )
+            raise e
+        batch.free()
+
     def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
         return self.tokenizer.decode(
             generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
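Note: `warmup` above sizes the real cache from `max_total_tokens`, runs a single prefill so an out-of-memory error surfaces at startup instead of under load, and then frees the warmup batch. The block-count arithmetic, with assumed example values:

import math

BLOCK_SIZE = 16

def warmup_cache_blocks(max_total_tokens: int, wiggle_room: int = 10) -> int:
    # Same formula as the diff: enough blocks to cover max_total_tokens, plus slack.
    return math.ceil(max_total_tokens / BLOCK_SIZE) + wiggle_room

assert warmup_cache_blocks(32000) == 2010   # 2000 blocks for the tokens + 10 spare
assert warmup_cache_blocks(16001) == 1011   # partial blocks round up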
@@ -718,6 +752,35 @@ class FlashCausalLM(Model):
         prefill = batch.start_seq_prefill is not None
         prefill_logprobs = batch.prefill_next_token_indices is not None

+        if batch.needed_blocks_slots:
+            # Padded block tables
+            block_tables_tensor = torch.zeros(
+                (len(batch), batch.max_blocks), dtype=torch.int32
+            )
+
+            # Allocate paged attention blocks
+            slots = []
+            block_tables = []
+            try:
+                for i, (needed_blocks, needed_slots) in enumerate(
+                    batch.needed_blocks_slots
+                ):
+                    allocated_blocks, allocated_slots = CACHE_MANAGER.allocate(
+                        needed_blocks
+                    )
+                    slots.append(allocated_slots[:needed_slots])
+                    block_tables.append(allocated_blocks.tolist())
+                    block_tables_tensor[i, :needed_blocks] = allocated_blocks
+            except Exception as e:
+                for blocks in block_tables:
+                    CACHE_MANAGER.free(blocks)
+                raise e
+
+            batch.needed_blocks_slots = None
+            batch.block_tables = block_tables
+            batch.block_tables_tensor = block_tables_tensor.to(self.device)
+            batch.slots = torch.concat(slots).to(self.device)
+
         out = self.forward(
             batch.input_ids,
             batch.position_ids,
@@ -931,7 +994,7 @@ class FlashCausalLM(Model):
             batch.all_input_ids[i] = all_input_ids

         if stopped:
-            batch.cleanup()
+            batch.free()
             # No need to return a batch if we know that all requests stopped
             return generations, None

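Note: prefill now allocates cache blocks lazily inside `generate_token` and, if any request in the batch cannot be served, frees whatever was already taken before re-raising. A self-contained sketch of that all-or-nothing pattern (`ToyAllocator`, `OutOfBlocks` and `allocate_batch` are made-up names):

class OutOfBlocks(Exception):
    pass

class ToyAllocator:
    def __init__(self, num_blocks: int):
        self.free = list(range(num_blocks))

    def allocate(self, n: int):
        if len(self.free) < n:
            raise OutOfBlocks(f"asked {n}, only {len(self.free)} free")
        taken, self.free = self.free[:n], self.free[n:]
        return taken

    def release(self, blocks):
        self.free.extend(blocks)

def allocate_batch(allocator, needed_blocks_per_request):
    block_tables = []
    try:
        for needed in needed_blocks_per_request:
            block_tables.append(allocator.allocate(needed))
    except OutOfBlocks:
        # Roll back partial allocations so no cache blocks leak, then re-raise.
        for blocks in block_tables:
            allocator.release(blocks)
        raise
    return block_tables

allocator = ToyAllocator(num_blocks=4)
try:
    allocate_batch(allocator, [2, 3])   # the second request cannot be satisfied
except OutOfBlocks:
    pass
assert len(allocator.free) == 4         # the two blocks taken first were returned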
@@ -68,7 +68,7 @@ class FlashLlama(FlashCausalLM):
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
-            num_heads=model.model.num_heads,
+            num_kv_heads=model.model.num_heads,
             head_size=model.model.head_size,
             dtype=dtype,
             device=device,
@@ -22,6 +22,9 @@ class Model(ABC):
         rank: int = 0,
         world_size: int = 1,
     ):
+        if torch.cuda.is_available():
+            torch.cuda.set_per_process_memory_fraction(1.0)
+
         self.model = model.eval()
         self.tokenizer = tokenizer
         self.all_special_ids = set(tokenizer.all_special_ids)
@@ -55,6 +58,9 @@ class Model(ABC):
     def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
         raise NotImplementedError

+    def warmup(self, batch: B, max_total_tokens: int):
+        self.generate_token(batch)
+
     def decode_token(
         self,
         all_input_ids: List[int],
@@ -127,7 +127,7 @@ class Seq2SeqLMBatch(Batch):
         read_offsets.append(1)
         all_decoder_input_ids = decoder_input_ids.view(-1).split(1)

-        max_tokens = len(inputs) * max_input_length + max_decode_tokens
+        max_tokens = len(inputs) * (max_input_length + max_decode_tokens)

         return cls(
             batch_id=pb.id,
@@ -35,7 +35,7 @@ class Batch(ABC):
     def concatenate(cls, batches: List["Batch"]) -> "Batch":
         raise NotImplementedError

-    def cleanup(self):
+    def free(self):
         pass

     @abstractmethod
@@ -53,12 +53,24 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

+    async def Warmup(self, request, context):
+        batch = self.model.batch_type.from_pb(
+            request.batch, self.model.tokenizer, self.model.dtype, self.model.device
+        )
+        self.model.warmup(batch, request.max_total_tokens)
+        return generate_pb2.WarmupResponse()
+
     async def Prefill(self, request, context):
         batch = self.model.batch_type.from_pb(
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device
         )

-        generations, next_batch = self.model.generate_token(batch)
+        try:
+            generations, next_batch = self.model.generate_token(batch)
+        except Exception as e:
+            batch.free()
+            raise e
+
         self.cache.set(next_batch)

         return generate_pb2.PrefillResponse(
@@ -81,11 +93,20 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             raise ValueError("All batches are empty")

         if len(batches) > 1:
-            batch = self.model.batch_type.concatenate(batches)
+            try:
+                batch = self.model.batch_type.concatenate(batches)
+            except Exception as e:
+                [batch.free() for batch in batches]
+                raise e
         else:
             batch = batches[0]

-        generations, next_batch = self.model.generate_token(batch)
+        try:
+            generations, next_batch = self.model.generate_token(batch)
+        except Exception as e:
+            batch.free()
+            raise e
+
         self.cache.set(next_batch)

         return generate_pb2.DecodeResponse(