Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-22 23:42:06 +00:00)
Pass the max_batch_total_tokens to causal_lm

Refine the warmup.

Signed-off-by: yuanwu <yuan.wu@intel.com>

This commit is contained in: parent bab529c916, commit 67ee45a270
@@ -97,5 +97,5 @@ FROM base
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh

#ENTRYPOINT ["/tgi-entrypoint.sh"]
# CMD ["--json-output"]
ENTRYPOINT ["/tgi-entrypoint.sh"]
CMD ["--json-output"]
@@ -110,6 +110,7 @@ impl Client {
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let mut n_tokens = 0;
@@ -175,6 +176,7 @@ impl Client {
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_total_tokens,
})
.inject_context();
let response = self.stub.warmup(request).await?.into_inner();
@@ -104,6 +104,7 @@ impl ShardedClient {
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let futures: Vec<_> = self
@@ -114,6 +115,7 @@ impl ShardedClient {
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_total_tokens,
max_batch_size,
))
})
@@ -110,6 +110,7 @@ impl Client {
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let mut n_tokens = 0;
@@ -203,6 +204,7 @@ impl Client {
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_total_tokens,
})
.inject_context();
let response = self.stub.warmup(request).await?.into_inner();
@@ -104,6 +104,7 @@ impl ShardedClient {
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let futures: Vec<_> = self
@@ -114,6 +115,7 @@ impl ShardedClient {
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_total_tokens,
max_batch_size,
))
})
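Both the v2 and v3 clients above now thread `max_batch_total_tokens` through their `warmup` calls. For orientation only, an illustrative Python-side view of the round trip; everything except the request field names is hypothetical:

    def handle_warmup(model, warmup_request):
        # The extended WarmupRequest (see the proto hunks below) carries
        # max_batch_total_tokens in addition to the existing limits.
        # The model-side warmup may adjust the value it can actually support
        # and returns it, which the router receives back as Option<u32>.
        measured_max_batch_total_tokens = model.warmup(warmup_request)
        return measured_max_batch_total_tokens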
@@ -27,6 +27,8 @@ impl BackendV3 {
pub(crate) fn new(
client: ShardedClient,
waiting_served_ratio: f32,
max_input_tokens: u32,
max_total_tokens: u32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
@@ -51,6 +53,8 @@ impl BackendV3 {
prefix_caching,
window_size,
speculate,
max_input_tokens,
max_total_tokens,
max_batch_total_tokens,
);
let batching_task_notifier = Arc::new(Notify::new());
@@ -152,6 +156,7 @@ pub(crate) async fn batching_task(
.await;
let mut waiting_tokens = 1;

tracing::error!("Enter cached batch loop");
// We loop until we do not receive any cached batch from the inference server (== until
// all requests have met their stopping criteria)
while let Some(batch) = cached_batch {
@@ -111,6 +111,7 @@ impl Client {
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let mut n_tokens = 0;
@@ -203,6 +204,7 @@ impl Client {
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_total_tokens,
})
.inject_context();
let response = self.stub.warmup(request).await?.into_inner();
@@ -105,6 +105,7 @@ impl ShardedClient {
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let futures: Vec<_> = self
@@ -115,6 +116,7 @@ impl ShardedClient {
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_total_tokens,
max_batch_size,
))
})
@@ -94,6 +94,7 @@ pub async fn connect_backend(
max_input_tokens as u32,
max_batch_prefill_tokens,
max_total_tokens as u32,
max_batch_total_tokens.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))),
max_batch_size,
)
.await
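The `unwrap_or` above gives `max_batch_total_tokens` a default when the flag is not set. A minimal Python sketch of the same fallback arithmetic, with illustrative names:

    def default_max_batch_total_tokens(cli_value, max_total_tokens, max_batch_prefill_tokens):
        # Use the CLI value when provided; otherwise fall back to at least 16000,
        # never smaller than the per-request total or the prefill batch budget.
        if cli_value is not None:
            return cli_value
        return max(16000, max_total_tokens, max_batch_prefill_tokens)

    # e.g. default_max_batch_total_tokens(None, 2048, 4096) == 16000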
@@ -114,6 +115,8 @@ pub async fn connect_backend(
let backend = BackendV3::new(
sharded_client,
waiting_served_ratio,
max_input_tokens as u32,
max_total_tokens as u32,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
@@ -49,6 +49,8 @@ impl Queue {
prefix_caching: bool,
window_size: Option<u32>,
speculate: u32,
max_input_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
) -> Self {
// Create channel
@@ -61,6 +63,8 @@ impl Queue {
prefix_caching,
window_size,
speculate,
max_input_tokens,
max_total_tokens,
max_batch_total_tokens,
queue_receiver,
));
@@ -114,6 +118,8 @@ async fn queue_task(
prefix_caching: bool,
window_size: Option<u32>,
speculate: u32,
max_input_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) {
@@ -123,6 +129,8 @@ async fn queue_task(
prefix_caching,
window_size,
speculate,
max_input_tokens,
max_total_tokens,
max_batch_total_tokens,
);
@@ -174,6 +182,15 @@ struct State {

/// Paged Attention Block Allocation
block_allocator: Option<BlockAllocator>,

/// Require padding
requires_padding: bool,

/// max input tokens
max_input_tokens: u32,

/// max total tokens,
max_total_tokens: u32,
}

impl State {
@@ -183,6 +200,8 @@ impl State {
prefix_caching: bool,
window_size: Option<u32>,
speculate: u32,
max_input_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
) -> Self {
let block_allocator = (!requires_padding).then(|| {
@@ -202,6 +221,9 @@ impl State {
window_size,
speculate,
block_allocator,
requires_padding,
max_input_tokens,
max_total_tokens,
}
}
@@ -272,10 +294,19 @@ impl State {
None => {
// We pad to max input length in the Python shards
// We need to take these padding tokens into the equation
if self.requires_padding {
prefill_tokens = (batch.len() + 1) as u32 * self.max_input_tokens;
} else {
max_input_length = max_input_length.max(entry.request.input_length);
prefill_tokens = (batch.len() + 1) as u32 * max_input_length;
}

if self.requires_padding {
decode_tokens = (batch.len() + 1) as u32 * (self.max_total_tokens - self.max_input_tokens);
} else {
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
}

let total_tokens = prefill_tokens + decode_tokens + self.speculate;

if prefill_tokens > prefill_token_budget || total_tokens > token_budget {
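When the Python shards pad, the queue now charges each entry the full padded window instead of its actual lengths. A small Python sketch of the budget check performed above, with illustrative names:

    def fits_in_budget(batch_len, max_input_tokens, max_total_tokens, speculate,
                       prefill_token_budget, token_budget):
        # With padding, every request in the candidate batch costs the full
        # max_input_tokens at prefill and the full decode window afterwards.
        prefill_tokens = (batch_len + 1) * max_input_tokens
        decode_tokens = (batch_len + 1) * (max_total_tokens - max_input_tokens)
        total_tokens = prefill_tokens + decode_tokens + speculate
        return prefill_tokens <= prefill_token_budget and total_tokens <= token_budget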
@@ -228,6 +228,7 @@ message WarmupRequest {
uint32 max_input_length = 2;
uint32 max_prefill_tokens = 3;
uint32 max_total_tokens = 4;
uint32 max_batch_total_tokens = 5;
}

message WarmupResponse {
@@ -261,6 +261,7 @@ message WarmupRequest {
uint32 max_input_length = 2;
uint32 max_prefill_tokens = 3;
uint32 max_total_tokens = 4;
uint32 max_batch_total_tokens = 5;
}

message WarmupResponse {
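Both protos gain the `max_batch_total_tokens` field. A hedged example of building the request from the generated Python bindings, assuming `WarmupRequest` is exposed by the same generated module the server already imports; the `batch` payload is elided and the values are placeholders:

    from text_generation_server.pb import generate_pb2

    # Field names follow the proto above; the batch itself is omitted here.
    request = generate_pb2.WarmupRequest(
        max_input_length=1024,
        max_prefill_tokens=4096,
        max_total_tokens=2048,
        max_batch_total_tokens=65536,
    )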
@@ -97,6 +97,7 @@ def serve(
# Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value
dtype = "bfloat16" if dtype is None else dtype.value
logger.info(f"quantize={quantize}")
if dtype is not None and quantize not in {
None,
"bitsandbytes",
@@ -54,18 +54,12 @@ from text_generation_server.utils.debug import dbg_trace
from text_generation_server.utils.speculate import get_speculate

tracer = trace.get_tracer(__name__)

# MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 2048))
# BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8))
# PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128))
# PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4))
# CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
# LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1))
MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 8192))
MAX_BATCH_TOTAL_TOKENS = int(os.environ.get('MAX_BATCH_TOTAL_TOKENS', 65536))
MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 2048))
PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 256))
CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1))
BATCH_BUCKET_SIZE= int(os.environ.get('BATCH_BUCKET_SIZE', 8))
PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 2))

PREFILL_WARMUP_BATCH_SIZE_LIST = []
@@ -81,6 +75,9 @@ def torch_compile_for_eager(func):

def round_up(warmup_list:list, num) :
i = 0
if len(warmup_list) == 0:
return num

for i in warmup_list:
if num <= i :
break
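The helper scans the warmup buckets and stops at the first one that can hold `num`. A self-contained sketch of the same idea; the behaviour when no bucket is large enough is an assumption, since the hunk is truncated:

    def round_up_to_bucket(warmup_list, num):
        # Return the smallest warmup bucket that can hold `num`; if the list is
        # empty (or, assumed here, no bucket fits) return `num` unchanged.
        if not warmup_list:
            return num
        for bucket in warmup_list:
            if num <= bucket:
                return bucket
        return num

    # e.g. round_up_to_bucket([256, 512, 1024], 300) -> 512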
@@ -525,14 +522,12 @@ class CausalLMBatch(Batch):
return_token_type_ids=False,
truncation=True,
max_length=max_truncation,
).to(device)
)

input_len = tokenized_inputs["input_ids"].shape[1]

# Round up sequence length
bucket_size = max_input_length
left_padding = max_input_length - input_len
if is_warmup is False:
if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1)
@@ -554,7 +549,7 @@
)
all_input_ids = torch.nn.functional.pad(
input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id
).T.split(1, dim=1)[0:len(pb.requests)]
).T.split(1, dim=1)
input_len = bucket_size
for r in requests:
r.input_length = input_len
@@ -567,7 +562,6 @@
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

htorch.core.mark_step()

top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
@@ -908,6 +902,7 @@ class CausalLM(Model):
kwargs["bypass_hpu_graphs"] = bypass_hpu_graph

kwargs.update(self.kwargs)

if past_key_values is not None:
return self.model.forward(**kwargs)
else:
@@ -972,7 +967,7 @@ class CausalLM(Model):
'top_n_tokens': batch.top_n_tokens[req_idx],
'top_token_ids': batch_top_token_ids[req_idx],
'top_token_logprobs': batch_top_token_logprobs[req_idx],
'grammar_state': batch.next_token_chooser.fsm_grammar_states[req_idx],
'grammar_state': batch.next_token_chooser.fsm_grammar_states[req.idx],

})
@@ -1203,7 +1198,7 @@ class CausalLM(Model):
decode_ns = time.time_ns() - start_decode
return generations, batch if not stopped else None, (forward_ns, decode_ns)

def generate_warmup_batch(self, request, seq_len, batch_size, is_warmup):
def generate_warmup_batch(self, request, seq_len, batch_size):
batch = copy.deepcopy(request.batch)
for req in batch.requests:
req.truncate = seq_len
@@ -1211,11 +1206,13 @@
for i in range(len(batch.requests) - batch_size):
batch.requests.pop()

return CausalLMBatch.from_pb(batch, self.tokenizer, self.dtype, self.device, is_warmup)
return CausalLMBatch.from_pb(batch, self.tokenizer, self.dtype, self.device, is_warmup=True)

def warmup(self, request) -> None:
is_warmup = True
MAX_TOTAL_TOKENS = request.max_total_tokens
MAX_BATCH_TOTAL_TOKENS = request.max_batch_total_tokens
batch = CausalLMBatch.from_pb(request.batch, self.tokenizer, self.dtype, self.device, is_warmup = is_warmup)
try:
# max prefill batch size warmup
@@ -1226,99 +1223,43 @@ class CausalLM(Model):
f"You need to decrease `--max-batch-prefill-tokens`"
)

global MAX_TOTAL_TOKENS, MAX_BATCH_TOTAL_TOKENS, PREFILL_WARMUP_BATCH_SIZE_LIST, PREFILL_WARMUP_SEQLEN_LIST, DECODE_WARMUP_BATCH_SIZE_LIST
max_input_length = batch.input_ids.shape[1]
#warmup decode batch size
max_prefill_batch_size = batch.input_ids.shape[0]
PREFILL_WARMUP_BATCH_SIZE_LIST = []
batch_size = 1
while batch_size <= max_prefill_batch_size:
PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size)
batch_size = batch_size * 2
if PREFILL_WARMUP_BATCH_SIZE_LIST[-1] < max_prefill_batch_size :
PREFILL_WARMUP_BATCH_SIZE_LIST.append(max_prefill_batch_size)

seq_len = PAD_SEQUENCE_TO_MULTIPLE_OF
PREFILL_WARMUP_SEQLEN_LIST = []
i = 0
while seq_len <= max_input_length:
PREFILL_WARMUP_SEQLEN_LIST.append(seq_len)
seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF*(2**i)
i += 1
if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_length:
PREFILL_WARMUP_SEQLEN_LIST.append(max_input_length)

#Prefill and decode warmup
DECODE_WARMUP_BATCH_SIZE_LIST = []
prefill_batch = None
decode_batch = None
try:
for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST :
for seq_len in PREFILL_WARMUP_SEQLEN_LIST :
batch = self.generate_warmup_batch(request, seq_len, batch_size, is_warmup)
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
_, decode_batch, _ = self.generate_token([prefill_batch], is_warmup)

DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)

except:
raise RuntimeError(
f"Not enough memory to handle following prefill and decode warmup."
f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}"
f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}"
f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}"
f"You need to decrease `--max-batch-prefill-tokens`"
)

mem_stats = get_hpu_memory_stats(self.device)
logger.info(
f"\nFollowing prefill and decode warmup successfully.\n"
f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n"
f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n"
f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n"
f"Memory stats: {mem_stats} "
)

max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
batch_size = max_prefill_batch_size * 2
# Decode warmup with bigger batch_size
try:
if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size and batch_size <= max_decode_batch_size:
batches = []
for i in range(int(batch_size/max_prefill_batch_size)) :
batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup)
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
batches.append(prefill_batch)
while batch_size <= max_decode_batch_size:
_, decode_batch, _ = self.generate_token(batches, is_warmup)
DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
batch_size = batch_size * 2
batches.clear()

for i in range(int(batch_size/max_prefill_batch_size)) :
batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup)
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
batches.append(prefill_batch)

batches.clear()
if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size:
max_decode_batch_size = math.floor( max_decode_batch_size / 2) * 2
batch_size = max_decode_batch_size
for i in range(int(max_decode_batch_size / 2)) :
batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], 2, is_warmup)
self.limit_hpu_graph = True
try:
while batch_size > 1:
batches= []
iters = math.floor(batch_size/max_prefill_batch_size)
DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
for i in range(iters):
batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size)
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
batches.append(prefill_batch)
_, decode_batch, _ = self.generate_token(batches, is_warmup)
DECODE_WARMUP_BATCH_SIZE_LIST.append(max_decode_batch_size)
max_batch_total_tokens = max_decode_batch_size * MAX_TOTAL_TOKENS
MAX_BATCH_TOTAL_TOKENS = max_batch_total_tokens
except :
raise RuntimeError(
f"Not enough memory to handle batch_size({batch_size}) decode warmup."
f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}"
f"max_decode_batch_size is {max_decode_batch_size}"
f"You need to decrease env `MAX_BATCH_TOTAL_TOKENS` or '--max_batch_total_tokens'"
)

if batch_size % max_prefill_batch_size != 0:
batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, batch_size % max_prefill_batch_size)
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
batches.append(batch)

_, decode_batch, _ = self.generate_token(batches, is_warmup)
logger.info(f"DECODE_DIVISOR={BATCH_BUCKET_SIZE}")
batch_size = math.floor(batch_size / BATCH_BUCKET_SIZE)
except:
DECODE_WARMUP_BATCH_SIZE_LIST.pop(-1)
self.model.clear_cache()
if len(DECODE_WARMUP_BATCH_SIZE_LIST) > 0:
logger.warning(
f"Not enough memory to warmup all batch size of decode."
f"You need to decrease `--max-batch-total-tokens`"
)
else:
raise RuntimeError(
f"Not enough memory to warmup decode batch_size({max_decode_batch_size})."
f"You need to decrease `--max-batch-total-tokens`"
)
DECODE_WARMUP_BATCH_SIZE_LIST.sort()
mem_stats = get_hpu_memory_stats(self.device)
logger.info(
f"\nFollowing decode warmup successfully.\n"
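The decode warmup above derives the largest batch it must cover from the two token limits, and on success reports back a batch-total budget that is a whole multiple of `MAX_TOTAL_TOKENS`. A compact restatement of that arithmetic:

    import math

    def decode_warmup_budget(max_batch_total_tokens, max_total_tokens):
        # Largest decode batch the configured budget allows...
        max_decode_batch_size = math.floor(max_batch_total_tokens / max_total_tokens)
        # ...and the budget actually reported back after warmup, rounded down
        # to a whole number of max_total_tokens windows.
        return max_decode_batch_size, max_decode_batch_size * max_total_tokens

    # e.g. decode_warmup_budget(65536, 2048) -> (32, 65536)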
@@ -1326,4 +1267,48 @@ class CausalLM(Model):
f"Memory stats: {mem_stats} "
)

# Warmup prefill batch_size
max_input_length = request.max_input_length
max_prefill_batch_size = batch.input_ids.shape[0]
batch_size = max_prefill_batch_size
while batch_size >= 1:
PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size)
batch_size = math.floor(batch_size / PREFILL_BATCH_BUCKET_SIZE)

seq_len = max_input_length
while seq_len >= PAD_SEQUENCE_TO_MULTIPLE_OF:
PREFILL_WARMUP_SEQLEN_LIST.append(seq_len)
seq_len = math.floor(seq_len/2)

if PREFILL_WARMUP_SEQLEN_LIST[-1] > PAD_SEQUENCE_TO_MULTIPLE_OF:
PREFILL_WARMUP_SEQLEN_LIST.append(PAD_SEQUENCE_TO_MULTIPLE_OF)

#Prefill and decode warmup
prefill_batch = None
PREFILL_WARMUP_BATCH_SIZE_LIST.sort()
PREFILL_WARMUP_SEQLEN_LIST.sort()

try:
for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST :
for seq_len in PREFILL_WARMUP_SEQLEN_LIST :
batch = self.generate_warmup_batch(request, seq_len - 1, batch_size)
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
except:
raise RuntimeError(
f"Not enough memory to run following prefill batch_size."
f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}"
f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}"
f"You need to decrease `--max-batch-prefill-tokens`"
)

limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
if limit_hpu_graph == False:
mem_stats = get_hpu_memory_stats(self.device)
logger.info(
f"\nFollowing prefill and decode warmup successfully.\n"
f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n"
f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n"
f"Memory stats: {mem_stats} "
)

return MAX_BATCH_TOTAL_TOKENS
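The rewritten warmup builds its prefill bucket lists top-down: batch sizes shrink by `PREFILL_BATCH_BUCKET_SIZE`, sequence lengths halve down to `PAD_SEQUENCE_TO_MULTIPLE_OF`, and both lists are then sorted ascending. A compact, illustrative reconstruction, with defaults assumed from the module constants above:

    import math

    def build_prefill_buckets(max_prefill_batch_size, max_input_length,
                              pad_multiple=256, prefill_bucket_size=2):
        # Batch-size buckets: divide by the prefill bucket factor down to 1.
        batch_sizes, seq_lens = [], []
        batch_size = max_prefill_batch_size
        while batch_size >= 1:
            batch_sizes.append(batch_size)
            batch_size = math.floor(batch_size / prefill_bucket_size)
        # Sequence-length buckets: halve until the padding multiple, adding
        # the multiple itself if the loop stopped above it.
        seq_len = max_input_length
        while seq_len >= pad_multiple:
            seq_lens.append(seq_len)
            seq_len = math.floor(seq_len / 2)
        if seq_lens and seq_lens[-1] > pad_multiple:
            seq_lens.append(pad_multiple)
        return sorted(batch_sizes), sorted(seq_lens)

    # e.g. build_prefill_buckets(8, 1024) -> ([1, 2, 4, 8], [256, 512, 1024])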
@@ -1,4 +1,5 @@
import inspect
from loguru import logger
import torch

from abc import ABC, abstractmethod
@@ -10,7 +11,7 @@ from text_generation_server.models.types import Batch, Generation
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.pb.generate_pb2 import InfoResponse
from text_generation_server.adapters.weights import LayerAdapterWeights

import time
BASE_MODEL_ADAPTER_ID = "__base_model__"

@@ -110,6 +111,7 @@ class Model(ABC):
all_input_ids[prefix_offset:read_offset],
skip_special_tokens=skip_special_tokens,
)

new_text = self.tokenizer.decode(
all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens
)