Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-23 07:52:06 +00:00)
Pass the max_batch_total_tokens to causal_lm

Refine the warmup.

Signed-off-by: yuanwu <yuan.wu@intel.com>

parent bab529c916
commit 67ee45a270
@@ -97,5 +97,5 @@ FROM base
 COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh
 
-#ENTRYPOINT ["/tgi-entrypoint.sh"]
-# CMD ["--json-output"]
+ENTRYPOINT ["/tgi-entrypoint.sh"]
+CMD ["--json-output"]
@@ -110,6 +110,7 @@ impl Client {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
+        max_batch_total_tokens: u32,
         max_batch_size: Option<usize>,
     ) -> Result<Option<u32>> {
         let mut n_tokens = 0;
@@ -175,6 +176,7 @@ impl Client {
             max_input_length,
             max_prefill_tokens,
             max_total_tokens,
+            max_batch_total_tokens,
         })
         .inject_context();
         let response = self.stub.warmup(request).await?.into_inner();
@@ -104,6 +104,7 @@ impl ShardedClient {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
+        max_batch_total_tokens: u32,
         max_batch_size: Option<usize>,
     ) -> Result<Option<u32>> {
         let futures: Vec<_> = self
@@ -114,6 +115,7 @@ impl ShardedClient {
                     max_input_length,
                     max_prefill_tokens,
                     max_total_tokens,
+                    max_batch_total_tokens,
                     max_batch_size,
                 ))
             })
@@ -110,6 +110,7 @@ impl Client {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
+        max_batch_total_tokens: u32,
         max_batch_size: Option<usize>,
     ) -> Result<Option<u32>> {
         let mut n_tokens = 0;
@@ -203,6 +204,7 @@ impl Client {
             max_input_length,
             max_prefill_tokens,
             max_total_tokens,
+            max_batch_total_tokens,
         })
         .inject_context();
         let response = self.stub.warmup(request).await?.into_inner();
@@ -104,6 +104,7 @@ impl ShardedClient {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
+        max_batch_total_tokens: u32,
         max_batch_size: Option<usize>,
     ) -> Result<Option<u32>> {
         let futures: Vec<_> = self
@@ -114,6 +115,7 @@ impl ShardedClient {
                     max_input_length,
                     max_prefill_tokens,
                     max_total_tokens,
+                    max_batch_total_tokens,
                     max_batch_size,
                 ))
             })
@@ -27,6 +27,8 @@ impl BackendV3 {
     pub(crate) fn new(
         client: ShardedClient,
         waiting_served_ratio: f32,
+        max_input_tokens: u32,
+        max_total_tokens: u32,
         max_batch_prefill_tokens: u32,
         max_batch_total_tokens: u32,
         max_waiting_tokens: usize,
@@ -51,6 +53,8 @@ impl BackendV3 {
             prefix_caching,
             window_size,
             speculate,
+            max_input_tokens,
+            max_total_tokens,
             max_batch_total_tokens,
         );
         let batching_task_notifier = Arc::new(Notify::new());
@@ -152,6 +156,7 @@ pub(crate) async fn batching_task(
         .await;
     let mut waiting_tokens = 1;
 
+    tracing::error!("Enter cached batch loop");
     // We loop until we do not receive any cached batch from the inference server (== until
     // all requests have met their stopping criteria)
    while let Some(batch) = cached_batch {
@@ -111,6 +111,7 @@ impl Client {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
+        max_batch_total_tokens: u32,
         max_batch_size: Option<usize>,
     ) -> Result<Option<u32>> {
         let mut n_tokens = 0;
@@ -203,6 +204,7 @@ impl Client {
             max_input_length,
             max_prefill_tokens,
             max_total_tokens,
+            max_batch_total_tokens,
         })
         .inject_context();
         let response = self.stub.warmup(request).await?.into_inner();
@@ -105,6 +105,7 @@ impl ShardedClient {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
+        max_batch_total_tokens: u32,
         max_batch_size: Option<usize>,
     ) -> Result<Option<u32>> {
         let futures: Vec<_> = self
@@ -115,6 +116,7 @@ impl ShardedClient {
                     max_input_length,
                     max_prefill_tokens,
                     max_total_tokens,
+                    max_batch_total_tokens,
                     max_batch_size,
                 ))
             })
@@ -94,6 +94,7 @@ pub async fn connect_backend(
             max_input_tokens as u32,
             max_batch_prefill_tokens,
             max_total_tokens as u32,
+            max_batch_total_tokens.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))),
             max_batch_size,
         )
         .await
@@ -114,6 +115,8 @@ pub async fn connect_backend(
     let backend = BackendV3::new(
         sharded_client,
         waiting_served_ratio,
+        max_input_tokens as u32,
+        max_total_tokens as u32,
         max_batch_prefill_tokens,
         max_batch_total_tokens,
         max_waiting_tokens,
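When `--max-batch-total-tokens` is not provided, the warmup call above falls back to the largest of 16000, `max_total_tokens`, and `max_batch_prefill_tokens`. A minimal Python sketch of that fallback, for illustration only (the function name and values are placeholders, not part of this commit):

    def default_max_batch_total_tokens(max_batch_total_tokens,
                                       max_total_tokens,
                                       max_batch_prefill_tokens):
        # Mirrors the unwrap_or(...) expression in the Rust hunk above.
        if max_batch_total_tokens is not None:
            return max_batch_total_tokens
        return max(16000, max_total_tokens, max_batch_prefill_tokens)

    # With max_total_tokens=8192 and max_batch_prefill_tokens=4096 the warmup
    # request is sent with max_batch_total_tokens=16000.
    assert default_max_batch_total_tokens(None, 8192, 4096) == 16000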
@@ -49,6 +49,8 @@ impl Queue {
         prefix_caching: bool,
         window_size: Option<u32>,
         speculate: u32,
+        max_input_tokens: u32,
+        max_total_tokens: u32,
         max_batch_total_tokens: u32,
     ) -> Self {
         // Create channel
@@ -61,6 +63,8 @@ impl Queue {
             prefix_caching,
             window_size,
             speculate,
+            max_input_tokens,
+            max_total_tokens,
             max_batch_total_tokens,
             queue_receiver,
         ));
@@ -114,6 +118,8 @@ async fn queue_task(
     prefix_caching: bool,
     window_size: Option<u32>,
     speculate: u32,
+    max_input_tokens: u32,
+    max_total_tokens: u32,
     max_batch_total_tokens: u32,
     mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
 ) {
@@ -123,6 +129,8 @@ async fn queue_task(
         prefix_caching,
         window_size,
         speculate,
+        max_input_tokens,
+        max_total_tokens,
         max_batch_total_tokens,
     );
 
@@ -174,6 +182,15 @@ struct State {
 
     /// Paged Attention Block Allocation
     block_allocator: Option<BlockAllocator>,
+
+    /// Require padding
+    requires_padding: bool,
+
+    /// max input tokens
+    max_input_tokens: u32,
+
+    /// max total tokens,
+    max_total_tokens: u32,
 }
 
 impl State {
@@ -183,6 +200,8 @@ impl State {
         prefix_caching: bool,
         window_size: Option<u32>,
         speculate: u32,
+        max_input_tokens: u32,
+        max_total_tokens: u32,
         max_batch_total_tokens: u32,
     ) -> Self {
         let block_allocator = (!requires_padding).then(|| {
@@ -202,6 +221,9 @@ impl State {
             window_size,
             speculate,
             block_allocator,
+            requires_padding,
+            max_input_tokens,
+            max_total_tokens,
         }
     }
 
@@ -272,10 +294,19 @@ impl State {
                 None => {
                     // We pad to max input length in the Python shards
                     // We need to take these padding tokens into the equation
+                    if self.requires_padding {
+                        prefill_tokens = (batch.len() + 1) as u32 * self.max_input_tokens;
+                    } else{
                     max_input_length = max_input_length.max(entry.request.input_length);
                     prefill_tokens = (batch.len() + 1) as u32 * max_input_length;
+                    }
 
+                    if self.requires_padding {
+                        decode_tokens = (batch.len() + 1) as u32 * (self.max_total_tokens - self.max_input_tokens);
+                    } else {
                     decode_tokens += entry.request.stopping_parameters.max_new_tokens;
+                    }
 
                     let total_tokens = prefill_tokens + decode_tokens + self.speculate;
 
                     if prefill_tokens > prefill_token_budget || total_tokens > token_budget {
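The queue change above makes the router's token accounting padding-aware: when the backend pads every sequence to the maximum input length, the budget is charged at the worst case instead of per-request lengths. A minimal Python sketch of the accounting, for illustration only (names and values are assumptions, not the router code):

    def batch_token_budget(batch_len, requires_padding,
                           max_input_tokens, max_total_tokens,
                           input_lengths, max_new_tokens, speculate):
        if requires_padding:
            # Every sequence is padded, so charge the worst case per slot.
            prefill_tokens = (batch_len + 1) * max_input_tokens
            decode_tokens = (batch_len + 1) * (max_total_tokens - max_input_tokens)
        else:
            prefill_tokens = (batch_len + 1) * max(input_lengths)
            decode_tokens = sum(max_new_tokens)
        return prefill_tokens + decode_tokens + speculate

    # 3 queued requests plus the candidate on a padded backend with
    # max_input_tokens=1024 and max_total_tokens=2048, no speculation:
    print(batch_token_budget(3, True, 1024, 2048, [], [], 0))  # 8192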
@@ -228,6 +228,7 @@ message WarmupRequest {
     uint32 max_input_length = 2;
     uint32 max_prefill_tokens = 3;
     uint32 max_total_tokens = 4;
+    uint32 max_batch_total_tokens = 5;
 }
 
 message WarmupResponse {
@@ -261,6 +261,7 @@ message WarmupRequest {
     uint32 max_input_length = 2;
     uint32 max_prefill_tokens = 3;
     uint32 max_total_tokens = 4;
+    uint32 max_batch_total_tokens = 5;
 }
 
 message WarmupResponse {
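With the new field 5 in both proto files, the router simply sets one more scalar on the warmup request. A hedged Python sketch of constructing the message after regenerating the stubs (values are made up and the batch field is omitted for brevity):

    from text_generation_server.pb.generate_pb2 import WarmupRequest

    request = WarmupRequest(
        max_input_length=1024,
        max_prefill_tokens=4096,
        max_total_tokens=2048,
        max_batch_total_tokens=65536,  # new field introduced by this commit
    )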
@@ -97,6 +97,7 @@ def serve(
     # Downgrade enum into str for easier management later on
     quantize = None if quantize is None else quantize.value
     dtype = "bfloat16" if dtype is None else dtype.value
+    logger.info(f"quantize={quantize}")
     if dtype is not None and quantize not in {
         None,
         "bitsandbytes",
@@ -54,18 +54,12 @@ from text_generation_server.utils.debug import dbg_trace
 from text_generation_server.utils.speculate import get_speculate
 
 tracer = trace.get_tracer(__name__)
-# MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 2048))
-# BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8))
-# PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128))
-# PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4))
-# CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
-# LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1))
-MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 8192))
-MAX_BATCH_TOTAL_TOKENS = int(os.environ.get('MAX_BATCH_TOTAL_TOKENS', 65536))
+MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 2048))
 PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 256))
 CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
 LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1))
+BATCH_BUCKET_SIZE= int(os.environ.get('BATCH_BUCKET_SIZE', 8))
+PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 2))
 
 
 PREFILL_WARMUP_BATCH_SIZE_LIST = []
@@ -81,6 +75,9 @@ def torch_compile_for_eager(func):
 
 def round_up(warmup_list:list, num) :
     i = 0
+    if len(warmup_list) == 0:
+        return num
+
     for i in warmup_list:
         if num <= i :
             break
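The early return above lets round_up fall back to the raw value before the warmup lists are populated. A standalone copy with a usage example (the trailing return is implied by the surrounding code and shown here only for illustration):

    def round_up(warmup_list: list, num):
        # Return the first bucket >= num, or num itself if no buckets exist yet.
        i = 0
        if len(warmup_list) == 0:
            return num
        for i in warmup_list:
            if num <= i:
                break
        return i

    PREFILL_WARMUP_SEQLEN_LIST = [256, 512, 1024]
    print(round_up(PREFILL_WARMUP_SEQLEN_LIST, 300))  # 512
    print(round_up([], 300))                          # 300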
@@ -525,14 +522,12 @@ class CausalLMBatch(Batch):
             return_token_type_ids=False,
             truncation=True,
             max_length=max_truncation,
-        ).to(device)
+        )
 
         input_len = tokenized_inputs["input_ids"].shape[1]
 
         # Round up sequence length
         bucket_size = max_input_length
         left_padding = max_input_length - input_len
-        if is_warmup is False:
         if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
             assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
             rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1)
@@ -554,7 +549,7 @@ class CausalLMBatch(Batch):
         )
         all_input_ids = torch.nn.functional.pad(
             input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id
-        ).T.split(1, dim=1)[0:len(pb.requests)]
+        ).T.split(1, dim=1)
         input_len = bucket_size
         for r in requests:
             r.input_length = input_len
@@ -567,7 +562,6 @@ class CausalLMBatch(Batch):
         position_ids = attention_mask.long().cumsum(-1) - 1
         position_ids.masked_fill_(attention_mask == 0, 1)
 
-        htorch.core.mark_step()
 
         top_n_tokens_tensor = torch.tensor(
             top_n_tokens, device=device, dtype=torch.int64
@@ -908,6 +902,7 @@ class CausalLM(Model):
             kwargs["bypass_hpu_graphs"] = bypass_hpu_graph
 
         kwargs.update(self.kwargs)
+
         if past_key_values is not None:
             return self.model.forward(**kwargs)
         else:
@@ -972,7 +967,7 @@ class CausalLM(Model):
                 'top_n_tokens': batch.top_n_tokens[req_idx],
                 'top_token_ids': batch_top_token_ids[req_idx],
                 'top_token_logprobs': batch_top_token_logprobs[req_idx],
-                'grammar_state': batch.next_token_chooser.fsm_grammar_states[req_idx],
+                'grammar_state': batch.next_token_chooser.fsm_grammar_states[req.idx],
 
             })
 
@@ -1203,7 +1198,7 @@ class CausalLM(Model):
         decode_ns = time.time_ns() - start_decode
         return generations, batch if not stopped else None, (forward_ns, decode_ns)
 
-    def generate_warmup_batch(self, request, seq_len, batch_size, is_warmup):
+    def generate_warmup_batch(self, request, seq_len, batch_size):
         batch = copy.deepcopy(request.batch)
         for req in batch.requests:
             req.truncate = seq_len
@@ -1211,11 +1206,13 @@ class CausalLM(Model):
         for i in range(len(batch.requests) - batch_size):
             batch.requests.pop()
 
-        return CausalLMBatch.from_pb(batch, self.tokenizer, self.dtype, self.device, is_warmup)
+        return CausalLMBatch.from_pb(batch, self.tokenizer, self.dtype, self.device, is_warmup=True)
 
 
     def warmup(self, request) -> None:
         is_warmup = True
+        MAX_TOTAL_TOKENS = request.max_total_tokens
+        MAX_BATCH_TOTAL_TOKENS = request.max_batch_total_tokens
         batch = CausalLMBatch.from_pb(request.batch, self.tokenizer, self.dtype, self.device, is_warmup = is_warmup)
         try:
             # max prefill batch size warmup
@@ -1226,99 +1223,43 @@ class CausalLM(Model):
                 f"You need to decrease `--max-batch-prefill-tokens`"
             )
 
-        global MAX_TOTAL_TOKENS, MAX_BATCH_TOTAL_TOKENS, PREFILL_WARMUP_BATCH_SIZE_LIST, PREFILL_WARMUP_SEQLEN_LIST, DECODE_WARMUP_BATCH_SIZE_LIST
-        max_input_length = batch.input_ids.shape[1]
+        #warmup decode batch size
         max_prefill_batch_size = batch.input_ids.shape[0]
-        PREFILL_WARMUP_BATCH_SIZE_LIST = []
-        batch_size = 1
-        while batch_size <= max_prefill_batch_size:
-            PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size)
-            batch_size = batch_size * 2
-        if PREFILL_WARMUP_BATCH_SIZE_LIST[-1] < max_prefill_batch_size :
-            PREFILL_WARMUP_BATCH_SIZE_LIST.append(max_prefill_batch_size)
-
-        seq_len = PAD_SEQUENCE_TO_MULTIPLE_OF
-        PREFILL_WARMUP_SEQLEN_LIST = []
-        i = 0
-        while seq_len <= max_input_length:
-            PREFILL_WARMUP_SEQLEN_LIST.append(seq_len)
-            seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF*(2**i)
-            i += 1
-        if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_length:
-            PREFILL_WARMUP_SEQLEN_LIST.append(max_input_length)
-
-        #Prefill and decode warmup
-        DECODE_WARMUP_BATCH_SIZE_LIST = []
-        prefill_batch = None
-        decode_batch = None
-        try:
-            for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST :
-                for seq_len in PREFILL_WARMUP_SEQLEN_LIST :
-                    batch = self.generate_warmup_batch(request, seq_len, batch_size, is_warmup)
-                    _, prefill_batch, _ = self.generate_token([batch], is_warmup)
-                    _, decode_batch, _ = self.generate_token([prefill_batch], is_warmup)
-
-                DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
-
-        except:
-            raise RuntimeError(
-                f"Not enough memory to handle following prefill and decode warmup."
-                f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}"
-                f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}"
-                f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}"
-                f"You need to decrease `--max-batch-prefill-tokens`"
-            )
-
-        mem_stats = get_hpu_memory_stats(self.device)
-        logger.info(
-            f"\nFollowing prefill and decode warmup successfully.\n"
-            f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n"
-            f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n"
-            f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n"
-            f"Memory stats: {mem_stats} "
-        )
-
         max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
-        batch_size = max_prefill_batch_size * 2
-        # Decode warmup with bigger batch_size
-        try:
-            if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size and batch_size <= max_decode_batch_size:
-                batches = []
-                for i in range(int(batch_size/max_prefill_batch_size)) :
-                    batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup)
-                    _, prefill_batch, _ = self.generate_token([batch], is_warmup)
-                    batches.append(prefill_batch)
-                while batch_size <= max_decode_batch_size:
-                    _, decode_batch, _ = self.generate_token(batches, is_warmup)
-                    DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
-                    batch_size = batch_size * 2
-                    batches.clear()
-
-                    for i in range(int(batch_size/max_prefill_batch_size)) :
-                        batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup)
-                        _, prefill_batch, _ = self.generate_token([batch], is_warmup)
-                        batches.append(prefill_batch)
-
-                batches.clear()
-                if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size:
-                    max_decode_batch_size = math.floor( max_decode_batch_size / 2) * 2
         batch_size = max_decode_batch_size
-                    for i in range(int(max_decode_batch_size / 2)) :
-                        batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], 2, is_warmup)
+        self.limit_hpu_graph = True
+        try:
+            while batch_size > 1:
+                batches= []
+                iters = math.floor(batch_size/max_prefill_batch_size)
+                DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
+                for i in range(iters):
+                    batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size)
                     _, prefill_batch, _ = self.generate_token([batch], is_warmup)
                     batches.append(prefill_batch)
-                _, decode_batch, _ = self.generate_token(batches, is_warmup)
-            DECODE_WARMUP_BATCH_SIZE_LIST.append(max_decode_batch_size)
-            max_batch_total_tokens = max_decode_batch_size * MAX_TOTAL_TOKENS
-            MAX_BATCH_TOTAL_TOKENS = max_batch_total_tokens
-        except :
-            raise RuntimeError(
-                f"Not enough memory to handle batch_size({batch_size}) decode warmup."
-                f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}"
-                f"max_decode_batch_size is {max_decode_batch_size}"
-                f"You need to decrease env `MAX_BATCH_TOTAL_TOKENS` or '--max_batch_total_tokens'"
-            )
-
+                if batch_size % max_prefill_batch_size != 0:
+                    batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, batch_size % max_prefill_batch_size)
+                    _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+                    batches.append(batch)
+
+                _, decode_batch, _ = self.generate_token(batches, is_warmup)
+                logger.info(f"DECODE_DIVISOR={BATCH_BUCKET_SIZE}")
+                batch_size = math.floor(batch_size / BATCH_BUCKET_SIZE)
+        except:
+            DECODE_WARMUP_BATCH_SIZE_LIST.pop(-1)
+            self.model.clear_cache()
+            if len(DECODE_WARMUP_BATCH_SIZE_LIST) > 0:
+                logger.warning(
+                    f"Not enough memory to warmup all batch size of decode."
+                    f"You need to decrease `--max-batch-total-tokens`"
+                )
+            else:
+                raise RuntimeError(
+                    f"Not enough memory to warmup decode batch_size({max_decode_batch_size})."
+                    f"You need to decrease `--max-batch-total-tokens`"
+                )
+        DECODE_WARMUP_BATCH_SIZE_LIST.sort()
         mem_stats = get_hpu_memory_stats(self.device)
         logger.info(
             f"\nFollowing decode warmup successfully.\n"
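The reworked decode warmup starts from the largest decode batch the budget allows, floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS), and walks down by dividing by BATCH_BUCKET_SIZE until it reaches 1, dropping the last entry if memory runs out. A minimal Python sketch of just the batch-size schedule (illustration only, values are assumptions):

    import math

    def decode_warmup_schedule(max_batch_total_tokens, max_total_tokens, batch_bucket_size):
        # Largest decode batch size the token budget allows.
        batch_size = math.floor(max_batch_total_tokens / max_total_tokens)
        schedule = []
        while batch_size > 1:
            schedule.append(batch_size)
            batch_size = math.floor(batch_size / batch_bucket_size)
        return sorted(schedule)

    # MAX_BATCH_TOTAL_TOKENS=65536, MAX_TOTAL_TOKENS=2048, BATCH_BUCKET_SIZE=8
    print(decode_warmup_schedule(65536, 2048, 8))  # [4, 32]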
@@ -1326,4 +1267,48 @@ class CausalLM(Model):
             f"Memory stats: {mem_stats} "
         )
 
+        # Warmup prefill batch_size
+        max_input_length = request.max_input_length
+        max_prefill_batch_size = batch.input_ids.shape[0]
+        batch_size = max_prefill_batch_size
+        while batch_size >= 1:
+            PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size)
+            batch_size = math.floor(batch_size / PREFILL_BATCH_BUCKET_SIZE)
+
+        seq_len = max_input_length
+        while seq_len >= PAD_SEQUENCE_TO_MULTIPLE_OF:
+            PREFILL_WARMUP_SEQLEN_LIST.append(seq_len)
+            seq_len = math.floor(seq_len/2)
+
+        if PREFILL_WARMUP_SEQLEN_LIST[-1] > PAD_SEQUENCE_TO_MULTIPLE_OF:
+            PREFILL_WARMUP_SEQLEN_LIST.append(PAD_SEQUENCE_TO_MULTIPLE_OF)
+
+        #Prefill and decode warmup
+        prefill_batch = None
+        PREFILL_WARMUP_BATCH_SIZE_LIST.sort()
+        PREFILL_WARMUP_SEQLEN_LIST.sort()
+
+        try:
+            for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST :
+                for seq_len in PREFILL_WARMUP_SEQLEN_LIST :
+                    batch = self.generate_warmup_batch(request, seq_len - 1, batch_size)
+                    _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+        except:
+            raise RuntimeError(
+                f"Not enough memory to run following prefill batch_size."
+                f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}"
+                f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}"
+                f"You need to decrease `--max-batch-prefill-tokens`"
+            )
+
+        limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
+        if limit_hpu_graph == False:
+            mem_stats = get_hpu_memory_stats(self.device)
+            logger.info(
+                f"\nFollowing prefill and decode warmup successfully.\n"
+                f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n"
+                f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n"
+                f"Memory stats: {mem_stats} "
+            )
+
         return MAX_BATCH_TOTAL_TOKENS
@@ -1,4 +1,5 @@
 import inspect
+from loguru import logger
 import torch
 
 from abc import ABC, abstractmethod
@@ -10,7 +11,7 @@ from text_generation_server.models.types import Batch, Generation
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.pb.generate_pb2 import InfoResponse
 from text_generation_server.adapters.weights import LayerAdapterWeights
+import time
 BASE_MODEL_ADAPTER_ID = "__base_model__"
 
 
@@ -110,6 +111,7 @@ class Model(ABC):
             all_input_ids[prefix_offset:read_offset],
             skip_special_tokens=skip_special_tokens,
         )
+
        new_text = self.tokenizer.decode(
            all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens
        )