Pass the max_batch_total_tokens to causal_lm

refine the warmup

Signed-off-by: yuanwu <yuan.wu@intel.com>
yuanwu 2024-10-10 07:31:50 +00:00
parent bab529c916
commit 67ee45a270
15 changed files with 160 additions and 119 deletions
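At a high level, the commit threads `max_batch_total_tokens` from the router through the gRPC `WarmupRequest` down to the Gaudi `causal_lm` warmup, which uses it together with `max_total_tokens` to bound the decode batch it warms up. Below is a minimal Python sketch of that sizing step; the helper name and values are illustrative, not code from the diff.

```python
import math

# Illustrative sketch: how the warmup can size its largest decode batch once
# max_batch_total_tokens is forwarded to the server.
def derive_max_decode_batch_size(max_total_tokens: int, max_batch_total_tokens: int) -> int:
    # Each decode slot may hold up to max_total_tokens, so the per-batch token
    # budget bounds how many sequences can decode together.
    return math.floor(max_batch_total_tokens / max_total_tokens)

# With the defaults visible in the diff (MAX_TOTAL_TOKENS=8192,
# MAX_BATCH_TOTAL_TOKENS=65536) this gives a decode batch of 8.
print(derive_max_decode_batch_size(8192, 65536))  # -> 8
```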

View File

@ -97,5 +97,5 @@ FROM base
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh
#ENTRYPOINT ["/tgi-entrypoint.sh"]
# CMD ["--json-output"]
ENTRYPOINT ["/tgi-entrypoint.sh"]
CMD ["--json-output"]

View File

@ -110,6 +110,7 @@ impl Client {
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let mut n_tokens = 0;
@ -175,6 +176,7 @@ impl Client {
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_total_tokens,
})
.inject_context();
let response = self.stub.warmup(request).await?.into_inner();

View File

@ -104,6 +104,7 @@ impl ShardedClient {
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let futures: Vec<_> = self
@ -114,6 +115,7 @@ impl ShardedClient {
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_total_tokens,
max_batch_size,
))
})

View File

@ -110,6 +110,7 @@ impl Client {
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let mut n_tokens = 0;
@ -203,6 +204,7 @@ impl Client {
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_total_tokens,
})
.inject_context();
let response = self.stub.warmup(request).await?.into_inner();

View File

@ -104,6 +104,7 @@ impl ShardedClient {
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let futures: Vec<_> = self
@ -114,6 +115,7 @@ impl ShardedClient {
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_total_tokens,
max_batch_size,
))
})

View File

@ -27,6 +27,8 @@ impl BackendV3 {
pub(crate) fn new(
client: ShardedClient,
waiting_served_ratio: f32,
max_input_tokens: u32,
max_total_tokens: u32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
@ -51,6 +53,8 @@ impl BackendV3 {
prefix_caching,
window_size,
speculate,
max_input_tokens,
max_total_tokens,
max_batch_total_tokens,
);
let batching_task_notifier = Arc::new(Notify::new());
@ -152,6 +156,7 @@ pub(crate) async fn batching_task(
.await;
let mut waiting_tokens = 1;
tracing::error!("Enter cached batch loop");
// We loop until we do not receive any cached batch from the inference server (== until
// all requests have met their stopping criteria)
while let Some(batch) = cached_batch {

View File

@ -111,6 +111,7 @@ impl Client {
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let mut n_tokens = 0;
@ -203,6 +204,7 @@ impl Client {
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_total_tokens,
})
.inject_context();
let response = self.stub.warmup(request).await?.into_inner();

View File

@ -105,6 +105,7 @@ impl ShardedClient {
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let futures: Vec<_> = self
@ -115,6 +116,7 @@ impl ShardedClient {
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_total_tokens,
max_batch_size,
))
})

View File

@ -94,6 +94,7 @@ pub async fn connect_backend(
max_input_tokens as u32,
max_batch_prefill_tokens,
max_total_tokens as u32,
max_batch_total_tokens.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))),
max_batch_size,
)
.await
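The `unwrap_or(...)` above supplies a warmup default when `--max-batch-total-tokens` is not set. The following Python paraphrase of that fallback is purely illustrative; the function name is not from the codebase.

```python
# Python paraphrase of the Rust fallback above; the function name is illustrative.
def warmup_batch_total_tokens(max_batch_total_tokens, max_total_tokens, max_batch_prefill_tokens):
    # When --max-batch-total-tokens is unset, warm up against at least 16000
    # tokens, or the larger of max_total_tokens / max_batch_prefill_tokens.
    if max_batch_total_tokens is not None:
        return max_batch_total_tokens
    return max(16000, max_total_tokens, max_batch_prefill_tokens)

assert warmup_batch_total_tokens(None, 4096, 4096) == 16000
assert warmup_batch_total_tokens(65536, 4096, 4096) == 65536
```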
@ -114,6 +115,8 @@ pub async fn connect_backend(
let backend = BackendV3::new(
sharded_client,
waiting_served_ratio,
max_input_tokens as u32,
max_total_tokens as u32,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,

View File

@ -49,6 +49,8 @@ impl Queue {
prefix_caching: bool,
window_size: Option<u32>,
speculate: u32,
max_input_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
) -> Self {
// Create channel
@ -61,6 +63,8 @@ impl Queue {
prefix_caching,
window_size,
speculate,
max_input_tokens,
max_total_tokens,
max_batch_total_tokens,
queue_receiver,
));
@ -114,6 +118,8 @@ async fn queue_task(
prefix_caching: bool,
window_size: Option<u32>,
speculate: u32,
max_input_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) {
@ -123,6 +129,8 @@ async fn queue_task(
prefix_caching,
window_size,
speculate,
max_input_tokens,
max_total_tokens,
max_batch_total_tokens,
);
@ -174,6 +182,15 @@ struct State {
/// Paged Attention Block Allocation
block_allocator: Option<BlockAllocator>,
/// Require padding
requires_padding: bool,
/// Max input tokens
max_input_tokens: u32,
/// Max total tokens
max_total_tokens: u32,
}
impl State {
@ -183,6 +200,8 @@ impl State {
prefix_caching: bool,
window_size: Option<u32>,
speculate: u32,
max_input_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
) -> Self {
let block_allocator = (!requires_padding).then(|| {
@ -202,6 +221,9 @@ impl State {
window_size,
speculate,
block_allocator,
requires_padding,
max_input_tokens,
max_total_tokens,
}
}
@ -272,10 +294,19 @@ impl State {
None => {
// We pad to max input length in the Python shards
// We need to take these padding tokens into the equation
if self.requires_padding {
prefill_tokens = (batch.len() + 1) as u32 * self.max_input_tokens;
} else {
max_input_length = max_input_length.max(entry.request.input_length);
prefill_tokens = (batch.len() + 1) as u32 * max_input_length;
}
if self.requires_padding {
decode_tokens = (batch.len() + 1) as u32 * (self.max_total_tokens - self.max_input_tokens);
} else {
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
}
let total_tokens = prefill_tokens + decode_tokens + self.speculate;
if prefill_tokens > prefill_token_budget || total_tokens > token_budget {
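The padding branch charges every slot in the batch for the worst case rather than the actual request lengths. A hedged Python sketch of the same accounting follows; the parameter names are illustrative, not the router's actual API.

```python
# Hedged sketch of the padding-aware token accounting above.
def fits_token_budget(num_entries, requires_padding, max_input_tokens, max_total_tokens,
                      input_lengths, max_new_tokens, speculate,
                      prefill_token_budget, token_budget):
    if requires_padding:
        # Every slot is padded, so charge the worst case per sequence.
        prefill_tokens = num_entries * max_input_tokens
        decode_tokens = num_entries * (max_total_tokens - max_input_tokens)
    else:
        # Otherwise charge the actual lengths requested by the entries.
        prefill_tokens = num_entries * max(input_lengths)
        decode_tokens = sum(max_new_tokens)
    total_tokens = prefill_tokens + decode_tokens + speculate
    return prefill_tokens <= prefill_token_budget and total_tokens <= token_budget
```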

View File

@ -228,6 +228,7 @@ message WarmupRequest {
uint32 max_input_length = 2;
uint32 max_prefill_tokens = 3;
uint32 max_total_tokens = 4;
uint32 max_batch_total_tokens = 5;
}
message WarmupResponse {

View File

@ -261,6 +261,7 @@ message WarmupRequest {
uint32 max_input_length = 2;
uint32 max_prefill_tokens = 3;
uint32 max_total_tokens = 4;
uint32 max_batch_total_tokens = 5;
}
message WarmupResponse {
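For reference, this is roughly how the extended message could be built from the generated Python stubs; the `generate_pb2` module path is assumed from the server layout, and the `batch` field (field 1) is omitted for brevity.

```python
# Illustrative only: module path assumed, batch field omitted.
from text_generation_server.pb import generate_pb2

warmup_request = generate_pb2.WarmupRequest(
    max_input_length=1024,
    max_prefill_tokens=4096,
    max_total_tokens=8192,
    max_batch_total_tokens=65536,  # new field 5 introduced by this commit
)
```

On the server side this is the value read as `request.max_batch_total_tokens` in `CausalLM.warmup`.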

View File

@ -97,6 +97,7 @@ def serve(
# Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value
dtype = "bfloat16" if dtype is None else dtype.value
logger.info(f"quantize={quantize}")
if dtype is not None and quantize not in {
None,
"bitsandbytes",

View File

@ -54,18 +54,12 @@ from text_generation_server.utils.debug import dbg_trace
from text_generation_server.utils.speculate import get_speculate
tracer = trace.get_tracer(__name__)
# MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 2048))
# BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8))
# PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128))
# PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4))
# CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
# LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1))
MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 8192))
MAX_BATCH_TOTAL_TOKENS = int(os.environ.get('MAX_BATCH_TOTAL_TOKENS', 65536))
MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 2048))
PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 256))
CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1))
BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8))
PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 2))
PREFILL_WARMUP_BATCH_SIZE_LIST = []
@ -81,6 +75,9 @@ def torch_compile_for_eager(func):
def round_up(warmup_list: list, num):
i = 0
if len(warmup_list) == 0:
return num
for i in warmup_list:
if num <= i:
break
@ -525,14 +522,12 @@ class CausalLMBatch(Batch):
return_token_type_ids=False,
truncation=True,
max_length=max_truncation,
).to(device)
)
input_len = tokenized_inputs["input_ids"].shape[1]
# Round up sequence length
bucket_size = max_input_length
left_padding = max_input_length - input_len
if is_warmup is False:
if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1)
@ -554,7 +549,7 @@ class CausalLMBatch(Batch):
)
all_input_ids = torch.nn.functional.pad(
input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id
).T.split(1, dim=1)[0:len(pb.requests)]
).T.split(1, dim=1)
input_len = bucket_size
for r in requests:
r.input_length = input_len
@ -567,7 +562,6 @@ class CausalLMBatch(Batch):
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
htorch.core.mark_step()
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
@ -908,6 +902,7 @@ class CausalLM(Model):
kwargs["bypass_hpu_graphs"] = bypass_hpu_graph
kwargs.update(self.kwargs)
if past_key_values is not None:
return self.model.forward(**kwargs)
else:
@ -972,7 +967,7 @@ class CausalLM(Model):
'top_n_tokens': batch.top_n_tokens[req_idx],
'top_token_ids': batch_top_token_ids[req_idx],
'top_token_logprobs': batch_top_token_logprobs[req_idx],
'grammar_state': batch.next_token_chooser.fsm_grammar_states[req_idx],
'grammar_state': batch.next_token_chooser.fsm_grammar_states[req.idx],
})
@ -1203,7 +1198,7 @@ class CausalLM(Model):
decode_ns = time.time_ns() - start_decode
return generations, batch if not stopped else None, (forward_ns, decode_ns)
def generate_warmup_batch(self, request, seq_len, batch_size, is_warmup):
def generate_warmup_batch(self, request, seq_len, batch_size):
batch = copy.deepcopy(request.batch)
for req in batch.requests:
req.truncate = seq_len
@ -1211,11 +1206,13 @@ class CausalLM(Model):
for i in range(len(batch.requests) - batch_size):
batch.requests.pop()
return CausalLMBatch.from_pb(batch, self.tokenizer, self.dtype, self.device, is_warmup)
return CausalLMBatch.from_pb(batch, self.tokenizer, self.dtype, self.device, is_warmup=True)
def warmup(self, request) -> None:
is_warmup = True
MAX_TOTAL_TOKENS = request.max_total_tokens
MAX_BATCH_TOTAL_TOKENS = request.max_batch_total_tokens
batch = CausalLMBatch.from_pb(request.batch, self.tokenizer, self.dtype, self.device, is_warmup=is_warmup)
try:
# max prefill batch size warmup
@ -1226,99 +1223,43 @@ class CausalLM(Model):
f"You need to decrease `--max-batch-prefill-tokens`"
)
global MAX_TOTAL_TOKENS, MAX_BATCH_TOTAL_TOKENS, PREFILL_WARMUP_BATCH_SIZE_LIST, PREFILL_WARMUP_SEQLEN_LIST, DECODE_WARMUP_BATCH_SIZE_LIST
max_input_length = batch.input_ids.shape[1]
#warmup decode batch size
max_prefill_batch_size = batch.input_ids.shape[0]
PREFILL_WARMUP_BATCH_SIZE_LIST = []
batch_size = 1
while batch_size <= max_prefill_batch_size:
PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size)
batch_size = batch_size * 2
if PREFILL_WARMUP_BATCH_SIZE_LIST[-1] < max_prefill_batch_size :
PREFILL_WARMUP_BATCH_SIZE_LIST.append(max_prefill_batch_size)
seq_len = PAD_SEQUENCE_TO_MULTIPLE_OF
PREFILL_WARMUP_SEQLEN_LIST = []
i = 0
while seq_len <= max_input_length:
PREFILL_WARMUP_SEQLEN_LIST.append(seq_len)
seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF*(2**i)
i += 1
if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_length:
PREFILL_WARMUP_SEQLEN_LIST.append(max_input_length)
#Prefill and decode warmup
DECODE_WARMUP_BATCH_SIZE_LIST = []
prefill_batch = None
decode_batch = None
try:
for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST :
for seq_len in PREFILL_WARMUP_SEQLEN_LIST :
batch = self.generate_warmup_batch(request, seq_len, batch_size, is_warmup)
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
_, decode_batch, _ = self.generate_token([prefill_batch], is_warmup)
DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
except:
raise RuntimeError(
f"Not enough memory to handle following prefill and decode warmup."
f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}"
f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}"
f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}"
f"You need to decrease `--max-batch-prefill-tokens`"
)
mem_stats = get_hpu_memory_stats(self.device)
logger.info(
f"\nFollowing prefill and decode warmup successfully.\n"
f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n"
f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n"
f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n"
f"Memory stats: {mem_stats} "
)
max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
batch_size = max_prefill_batch_size * 2
# Decode warmup with bigger batch_size
try:
if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size and batch_size <= max_decode_batch_size:
batches = []
for i in range(int(batch_size/max_prefill_batch_size)) :
batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup)
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
batches.append(prefill_batch)
while batch_size <= max_decode_batch_size:
_, decode_batch, _ = self.generate_token(batches, is_warmup)
DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
batch_size = batch_size * 2
batches.clear()
for i in range(int(batch_size/max_prefill_batch_size)) :
batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup)
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
batches.append(prefill_batch)
batches.clear()
if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size:
max_decode_batch_size = math.floor( max_decode_batch_size / 2) * 2
batch_size = max_decode_batch_size
for i in range(int(max_decode_batch_size / 2)) :
batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], 2, is_warmup)
self.limit_hpu_graph = True
try:
while batch_size > 1:
batches = []
iters = math.floor(batch_size / max_prefill_batch_size)
DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
for i in range(iters):
batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size)
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
batches.append(prefill_batch)
_, decode_batch, _ = self.generate_token(batches, is_warmup)
DECODE_WARMUP_BATCH_SIZE_LIST.append(max_decode_batch_size)
max_batch_total_tokens = max_decode_batch_size * MAX_TOTAL_TOKENS
MAX_BATCH_TOTAL_TOKENS = max_batch_total_tokens
except :
raise RuntimeError(
f"Not enough memory to handle batch_size({batch_size}) decode warmup."
f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}"
f"max_decode_batch_size is {max_decode_batch_size}"
f"You need to decrease env `MAX_BATCH_TOTAL_TOKENS` or '--max_batch_total_tokens'"
)
if batch_size % max_prefill_batch_size != 0:
batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, batch_size % max_prefill_batch_size)
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
batches.append(batch)
_, decode_batch, _ = self.generate_token(batches, is_warmup)
logger.info(f"DECODE_DIVISOR={BATCH_BUCKET_SIZE}")
batch_size = math.floor(batch_size / BATCH_BUCKET_SIZE)
except:
DECODE_WARMUP_BATCH_SIZE_LIST.pop(-1)
self.model.clear_cache()
if len(DECODE_WARMUP_BATCH_SIZE_LIST) > 0:
logger.warning(
f"Not enough memory to warmup all batch size of decode."
f"You need to decrease `--max-batch-total-tokens`"
)
else:
raise RuntimeError(
f"Not enough memory to warmup decode batch_size({max_decode_batch_size})."
f"You need to decrease `--max-batch-total-tokens`"
)
DECODE_WARMUP_BATCH_SIZE_LIST.sort()
mem_stats = get_hpu_memory_stats(self.device)
logger.info(
f"\nFollowing decode warmup successfully.\n"
@ -1326,4 +1267,48 @@ class CausalLM(Model):
f"Memory stats: {mem_stats} "
)
# Warmup prefill batch_size
max_input_length = request.max_input_length
max_prefill_batch_size = batch.input_ids.shape[0]
batch_size = max_prefill_batch_size
while batch_size >= 1:
PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size)
batch_size = math.floor(batch_size / PREFILL_BATCH_BUCKET_SIZE)
seq_len = max_input_length
while seq_len >= PAD_SEQUENCE_TO_MULTIPLE_OF:
PREFILL_WARMUP_SEQLEN_LIST.append(seq_len)
seq_len = math.floor(seq_len / 2)
if PREFILL_WARMUP_SEQLEN_LIST[-1] > PAD_SEQUENCE_TO_MULTIPLE_OF:
PREFILL_WARMUP_SEQLEN_LIST.append(PAD_SEQUENCE_TO_MULTIPLE_OF)
# Prefill and decode warmup
prefill_batch = None
PREFILL_WARMUP_BATCH_SIZE_LIST.sort()
PREFILL_WARMUP_SEQLEN_LIST.sort()
try:
for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST:
for seq_len in PREFILL_WARMUP_SEQLEN_LIST:
batch = self.generate_warmup_batch(request, seq_len - 1, batch_size)
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
except:
raise RuntimeError(
f"Not enough memory to run following prefill batch_size."
f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}"
f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}"
f"You need to decrease `--max-batch-prefill-tokens`"
)
limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
if not limit_hpu_graph:
mem_stats = get_hpu_memory_stats(self.device)
logger.info(
f"\nFollowing prefill and decode warmup successfully.\n"
f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n"
f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n"
f"Memory stats: {mem_stats} "
)
return MAX_BATCH_TOTAL_TOKENS
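The refined warmup replaces the old doubling loops with bucket lists derived by division, so only a bounded set of shapes is exercised on Gaudi. Below is a standalone sketch of that bucket generation under assumed defaults; the helper name is not in the diff.

```python
import math

# Standalone sketch under assumed defaults; the helper name is not in the diff.
def build_warmup_buckets(max_prefill_batch_size, max_input_length, max_decode_batch_size,
                         prefill_batch_bucket_size=2, batch_bucket_size=8,
                         pad_multiple=256):
    # Prefill batch sizes shrink by PREFILL_BATCH_BUCKET_SIZE until they reach 1.
    prefill_batch_sizes, bs = [], max_prefill_batch_size
    while bs >= 1:
        prefill_batch_sizes.append(bs)
        bs = math.floor(bs / prefill_batch_bucket_size)
    # Sequence lengths halve until they reach PAD_SEQUENCE_TO_MULTIPLE_OF.
    seq_lens, sl = [], max_input_length
    while sl >= pad_multiple:
        seq_lens.append(sl)
        sl = math.floor(sl / 2)
    if seq_lens and seq_lens[-1] > pad_multiple:
        seq_lens.append(pad_multiple)
    # Decode batch sizes shrink by BATCH_BUCKET_SIZE, stopping before 1.
    decode_batch_sizes, bs = [], max_decode_batch_size
    while bs > 1:
        decode_batch_sizes.append(bs)
        bs = math.floor(bs / batch_bucket_size)
    return sorted(prefill_batch_sizes), sorted(seq_lens), sorted(decode_batch_sizes)

# Example: max_prefill_batch_size=4, max_input_length=1024, max_decode_batch_size=8
# -> ([1, 2, 4], [256, 512, 1024], [8])
print(build_warmup_buckets(4, 1024, 8))
```

Keeping the bucket lists small is what lets the warmup exercise every shape the server will later see without re-capturing HPU graphs at request time.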

View File

@ -1,4 +1,5 @@
import inspect
from loguru import logger
import torch
from abc import ABC, abstractmethod
@ -10,7 +11,7 @@ from text_generation_server.models.types import Batch, Generation
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.pb.generate_pb2 import InfoResponse
from text_generation_server.adapters.weights import LayerAdapterWeights
import time
BASE_MODEL_ADAPTER_ID = "__base_model__"
@ -110,6 +111,7 @@ class Model(ABC):
all_input_ids[prefix_offset:read_offset],
skip_special_tokens=skip_special_tokens,
)
new_text = self.tokenizer.decode(
all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens
)