From 67ee45a27038f261c4ad5677f51a8463088e6f60 Mon Sep 17 00:00:00 2001 From: yuanwu Date: Thu, 10 Oct 2024 07:31:50 +0000 Subject: [PATCH] Pass the max_batch_total_tokens to causal_lm refine the warmup Signed-off-by: yuanwu --- Dockerfile | 4 +- backends/client/src/v2/client.rs | 2 + backends/client/src/v2/sharded_client.rs | 2 + backends/client/src/v3/client.rs | 2 + backends/client/src/v3/sharded_client.rs | 2 + backends/v3/src/backend.rs | 5 + backends/v3/src/client/grpc_client.rs | 2 + backends/v3/src/client/sharded_client.rs | 2 + backends/v3/src/lib.rs | 3 + backends/v3/src/queue.rs | 37 ++- proto/generate.proto | 1 + proto/v3/generate.proto | 1 + server/text_generation_server/cli.py | 1 + .../models/causal_lm.py | 211 ++++++++---------- server/text_generation_server/models/model.py | 4 +- 15 files changed, 160 insertions(+), 119 deletions(-) diff --git a/Dockerfile b/Dockerfile index f4430a33..c7967bea 100644 --- a/Dockerfile +++ b/Dockerfile @@ -97,5 +97,5 @@ FROM base COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh RUN chmod +x /tgi-entrypoint.sh -#ENTRYPOINT ["/tgi-entrypoint.sh"] -# CMD ["--json-output"] +ENTRYPOINT ["/tgi-entrypoint.sh"] +CMD ["--json-output"] diff --git a/backends/client/src/v2/client.rs b/backends/client/src/v2/client.rs index 9a2e6ac7..7a6922f0 100644 --- a/backends/client/src/v2/client.rs +++ b/backends/client/src/v2/client.rs @@ -110,6 +110,7 @@ impl Client { max_input_length: u32, max_prefill_tokens: u32, max_total_tokens: u32, + max_batch_total_tokens: u32, max_batch_size: Option, ) -> Result> { let mut n_tokens = 0; @@ -175,6 +176,7 @@ impl Client { max_input_length, max_prefill_tokens, max_total_tokens, + max_batch_total_tokens, }) .inject_context(); let response = self.stub.warmup(request).await?.into_inner(); diff --git a/backends/client/src/v2/sharded_client.rs b/backends/client/src/v2/sharded_client.rs index 7b24aec3..2709ea88 100644 --- a/backends/client/src/v2/sharded_client.rs +++ b/backends/client/src/v2/sharded_client.rs @@ -104,6 +104,7 @@ impl ShardedClient { max_input_length: u32, max_prefill_tokens: u32, max_total_tokens: u32, + max_batch_total_tokens: u32, max_batch_size: Option, ) -> Result> { let futures: Vec<_> = self @@ -114,6 +115,7 @@ impl ShardedClient { max_input_length, max_prefill_tokens, max_total_tokens, + max_batch_total_tokens, max_batch_size, )) }) diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs index 479d31bf..6274c359 100644 --- a/backends/client/src/v3/client.rs +++ b/backends/client/src/v3/client.rs @@ -110,6 +110,7 @@ impl Client { max_input_length: u32, max_prefill_tokens: u32, max_total_tokens: u32, + max_batch_total_tokens: u32, max_batch_size: Option, ) -> Result> { let mut n_tokens = 0; @@ -203,6 +204,7 @@ impl Client { max_input_length, max_prefill_tokens, max_total_tokens, + max_batch_total_tokens, }) .inject_context(); let response = self.stub.warmup(request).await?.into_inner(); diff --git a/backends/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs index 645c076a..9709dfcb 100644 --- a/backends/client/src/v3/sharded_client.rs +++ b/backends/client/src/v3/sharded_client.rs @@ -104,6 +104,7 @@ impl ShardedClient { max_input_length: u32, max_prefill_tokens: u32, max_total_tokens: u32, + max_batch_total_tokens: u32, max_batch_size: Option, ) -> Result> { let futures: Vec<_> = self @@ -114,6 +115,7 @@ impl ShardedClient { max_input_length, max_prefill_tokens, max_total_tokens, + max_batch_total_tokens, max_batch_size, )) }) diff --git 
a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index f8a10ca2..122b4909 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -27,6 +27,8 @@ impl BackendV3 { pub(crate) fn new( client: ShardedClient, waiting_served_ratio: f32, + max_input_tokens: u32, + max_total_tokens: u32, max_batch_prefill_tokens: u32, max_batch_total_tokens: u32, max_waiting_tokens: usize, @@ -51,6 +53,8 @@ impl BackendV3 { prefix_caching, window_size, speculate, + max_input_tokens, + max_total_tokens, max_batch_total_tokens, ); let batching_task_notifier = Arc::new(Notify::new()); @@ -152,6 +156,7 @@ pub(crate) async fn batching_task( .await; let mut waiting_tokens = 1; + tracing::error!("Enter cached batch loop"); // We loop until we do not receive any cached batch from the inference server (== until // all requests have met their stopping criteria) while let Some(batch) = cached_batch { diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs index 648662db..4508b92d 100644 --- a/backends/v3/src/client/grpc_client.rs +++ b/backends/v3/src/client/grpc_client.rs @@ -111,6 +111,7 @@ impl Client { max_input_length: u32, max_prefill_tokens: u32, max_total_tokens: u32, + max_batch_total_tokens: u32, max_batch_size: Option, ) -> Result> { let mut n_tokens = 0; @@ -203,6 +204,7 @@ impl Client { max_input_length, max_prefill_tokens, max_total_tokens, + max_batch_total_tokens, }) .inject_context(); let response = self.stub.warmup(request).await?.into_inner(); diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs index ea77a696..0b6e916e 100644 --- a/backends/v3/src/client/sharded_client.rs +++ b/backends/v3/src/client/sharded_client.rs @@ -105,6 +105,7 @@ impl ShardedClient { max_input_length: u32, max_prefill_tokens: u32, max_total_tokens: u32, + max_batch_total_tokens: u32, max_batch_size: Option, ) -> Result> { let futures: Vec<_> = self @@ -115,6 +116,7 @@ impl ShardedClient { max_input_length, max_prefill_tokens, max_total_tokens, + max_batch_total_tokens, max_batch_size, )) }) diff --git a/backends/v3/src/lib.rs b/backends/v3/src/lib.rs index 77a9a11a..f3372923 100644 --- a/backends/v3/src/lib.rs +++ b/backends/v3/src/lib.rs @@ -94,6 +94,7 @@ pub async fn connect_backend( max_input_tokens as u32, max_batch_prefill_tokens, max_total_tokens as u32, + max_batch_total_tokens.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))), max_batch_size, ) .await @@ -114,6 +115,8 @@ pub async fn connect_backend( let backend = BackendV3::new( sharded_client, waiting_served_ratio, + max_input_tokens as u32, + max_total_tokens as u32, max_batch_prefill_tokens, max_batch_total_tokens, max_waiting_tokens, diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index f8123b57..4ce54a79 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -49,6 +49,8 @@ impl Queue { prefix_caching: bool, window_size: Option, speculate: u32, + max_input_tokens: u32, + max_total_tokens: u32, max_batch_total_tokens: u32, ) -> Self { // Create channel @@ -61,6 +63,8 @@ impl Queue { prefix_caching, window_size, speculate, + max_input_tokens, + max_total_tokens, max_batch_total_tokens, queue_receiver, )); @@ -114,6 +118,8 @@ async fn queue_task( prefix_caching: bool, window_size: Option, speculate: u32, + max_input_tokens: u32, + max_total_tokens: u32, max_batch_total_tokens: u32, mut receiver: mpsc::UnboundedReceiver, ) { @@ -123,6 +129,8 @@ async fn queue_task( prefix_caching, window_size, speculate, 
+ max_input_tokens, + max_total_tokens, max_batch_total_tokens, ); @@ -174,6 +182,15 @@ struct State { /// Paged Attention Block Allocation block_allocator: Option, + + /// Require padding + requires_padding: bool, + + /// max input tokens + max_input_tokens: u32, + + /// max total tokens, + max_total_tokens: u32, } impl State { @@ -183,6 +200,8 @@ impl State { prefix_caching: bool, window_size: Option, speculate: u32, + max_input_tokens: u32, + max_total_tokens: u32, max_batch_total_tokens: u32, ) -> Self { let block_allocator = (!requires_padding).then(|| { @@ -202,6 +221,9 @@ impl State { window_size, speculate, block_allocator, + requires_padding, + max_input_tokens, + max_total_tokens, } } @@ -272,10 +294,19 @@ impl State { None => { // We pad to max input length in the Python shards // We need to take these padding tokens into the equation - max_input_length = max_input_length.max(entry.request.input_length); - prefill_tokens = (batch.len() + 1) as u32 * max_input_length; + if self.requires_padding { + prefill_tokens = (batch.len() + 1) as u32 * self.max_input_tokens; + } else{ + max_input_length = max_input_length.max(entry.request.input_length); + prefill_tokens = (batch.len() + 1) as u32 * max_input_length; + } + + if self.requires_padding { + decode_tokens = (batch.len() + 1) as u32 * (self.max_total_tokens - self.max_input_tokens); + } else { + decode_tokens += entry.request.stopping_parameters.max_new_tokens; + } - decode_tokens += entry.request.stopping_parameters.max_new_tokens; let total_tokens = prefill_tokens + decode_tokens + self.speculate; if prefill_tokens > prefill_token_budget || total_tokens > token_budget { diff --git a/proto/generate.proto b/proto/generate.proto index 6351e37f..7154db0e 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -228,6 +228,7 @@ message WarmupRequest { uint32 max_input_length = 2; uint32 max_prefill_tokens = 3; uint32 max_total_tokens = 4; + uint32 max_batch_total_tokens = 5; } message WarmupResponse { diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index 34894bda..014ee391 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -261,6 +261,7 @@ message WarmupRequest { uint32 max_input_length = 2; uint32 max_prefill_tokens = 3; uint32 max_total_tokens = 4; + uint32 max_batch_total_tokens = 5; } message WarmupResponse { diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index baf94986..756322c9 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -97,6 +97,7 @@ def serve( # Downgrade enum into str for easier management later on quantize = None if quantize is None else quantize.value dtype = "bfloat16" if dtype is None else dtype.value + logger.info(f"quantize={quantize}") if dtype is not None and quantize not in { None, "bitsandbytes", diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 9f88c4fc..6d8092b9 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -54,18 +54,12 @@ from text_generation_server.utils.debug import dbg_trace from text_generation_server.utils.speculate import get_speculate tracer = trace.get_tracer(__name__) - -# MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 2048)) -# BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8)) -# PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128)) -# PREFILL_BATCH_BUCKET_SIZE = 
int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4)) -# CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] -# LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1)) -MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 8192)) -MAX_BATCH_TOTAL_TOKENS = int(os.environ.get('MAX_BATCH_TOTAL_TOKENS', 65536)) +MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 2048)) PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 256)) CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1)) +BATCH_BUCKET_SIZE= int(os.environ.get('BATCH_BUCKET_SIZE', 8)) +PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 2)) PREFILL_WARMUP_BATCH_SIZE_LIST = [] @@ -81,6 +75,9 @@ def torch_compile_for_eager(func): def round_up(warmup_list:list, num) : i = 0 + if len(warmup_list) == 0: + return num + for i in warmup_list: if num <= i : break @@ -525,22 +522,20 @@ class CausalLMBatch(Batch): return_token_type_ids=False, truncation=True, max_length=max_truncation, - ).to(device) + ) input_len = tokenized_inputs["input_ids"].shape[1] - # Round up sequence length bucket_size = max_input_length left_padding = max_input_length - input_len - if is_warmup is False: - if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0: - assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length" - rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1) - if rounded_seq_len <= max_input_length: - bucket_size = rounded_seq_len - 1 - else: - bucket_size = max_input_length - 1 - left_padding = bucket_size - input_len + if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0: + assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length" + rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1) + if rounded_seq_len <= max_input_length: + bucket_size = rounded_seq_len - 1 + else: + bucket_size = max_input_length - 1 + left_padding = bucket_size - input_len input_ids = tokenized_inputs["input_ids"] attention_mask = tokenized_inputs["attention_mask"] @@ -554,7 +549,7 @@ class CausalLMBatch(Batch): ) all_input_ids = torch.nn.functional.pad( input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id - ).T.split(1, dim=1)[0:len(pb.requests)] + ).T.split(1, dim=1) input_len = bucket_size for r in requests: r.input_length = input_len @@ -567,7 +562,6 @@ class CausalLMBatch(Batch): position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - htorch.core.mark_step() top_n_tokens_tensor = torch.tensor( top_n_tokens, device=device, dtype=torch.int64 @@ -908,6 +902,7 @@ class CausalLM(Model): kwargs["bypass_hpu_graphs"] = bypass_hpu_graph kwargs.update(self.kwargs) + if past_key_values is not None: return self.model.forward(**kwargs) else: @@ -972,7 +967,7 @@ class CausalLM(Model): 'top_n_tokens': batch.top_n_tokens[req_idx], 'top_token_ids': batch_top_token_ids[req_idx], 'top_token_logprobs': batch_top_token_logprobs[req_idx], - 'grammar_state': batch.next_token_chooser.fsm_grammar_states[req_idx], + 'grammar_state': batch.next_token_chooser.fsm_grammar_states[req.idx], }) @@ -1203,7 +1198,7 @@ class CausalLM(Model): decode_ns = time.time_ns() - start_decode return generations, batch if not stopped else None, (forward_ns, decode_ns) - def generate_warmup_batch(self, request, seq_len, batch_size, 
is_warmup): + def generate_warmup_batch(self, request, seq_len, batch_size): batch = copy.deepcopy(request.batch) for req in batch.requests: req.truncate = seq_len @@ -1211,11 +1206,13 @@ class CausalLM(Model): for i in range(len(batch.requests) - batch_size): batch.requests.pop() - return CausalLMBatch.from_pb(batch, self.tokenizer, self.dtype, self.device, is_warmup) + return CausalLMBatch.from_pb(batch, self.tokenizer, self.dtype, self.device, is_warmup=True) def warmup(self, request) -> None: is_warmup = True + MAX_TOTAL_TOKENS = request.max_total_tokens + MAX_BATCH_TOTAL_TOKENS = request.max_batch_total_tokens batch = CausalLMBatch.from_pb(request.batch, self.tokenizer, self.dtype, self.device, is_warmup = is_warmup) try: # max prefill batch size warmup @@ -1226,104 +1223,92 @@ class CausalLM(Model): f"You need to decrease `--max-batch-prefill-tokens`" ) - global MAX_TOTAL_TOKENS, MAX_BATCH_TOTAL_TOKENS, PREFILL_WARMUP_BATCH_SIZE_LIST, PREFILL_WARMUP_SEQLEN_LIST, DECODE_WARMUP_BATCH_SIZE_LIST - max_input_length = batch.input_ids.shape[1] + #warmup decode batch size max_prefill_batch_size = batch.input_ids.shape[0] - PREFILL_WARMUP_BATCH_SIZE_LIST = [] - batch_size = 1 - while batch_size <= max_prefill_batch_size: - PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size) - batch_size = batch_size * 2 - if PREFILL_WARMUP_BATCH_SIZE_LIST[-1] < max_prefill_batch_size : - PREFILL_WARMUP_BATCH_SIZE_LIST.append(max_prefill_batch_size) - - seq_len = PAD_SEQUENCE_TO_MULTIPLE_OF - PREFILL_WARMUP_SEQLEN_LIST = [] - i = 0 - while seq_len <= max_input_length: - PREFILL_WARMUP_SEQLEN_LIST.append(seq_len) - seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF*(2**i) - i += 1 - if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_length: - PREFILL_WARMUP_SEQLEN_LIST.append(max_input_length) - - #Prefill and decode warmup - DECODE_WARMUP_BATCH_SIZE_LIST = [] - prefill_batch = None - decode_batch = None - try: - for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST : - for seq_len in PREFILL_WARMUP_SEQLEN_LIST : - batch = self.generate_warmup_batch(request, seq_len, batch_size, is_warmup) - _, prefill_batch, _ = self.generate_token([batch], is_warmup) - _, decode_batch, _ = self.generate_token([prefill_batch], is_warmup) - - DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size) - - except: - raise RuntimeError( - f"Not enough memory to handle following prefill and decode warmup." 
- f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}" - f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}" - f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}" - f"You need to decrease `--max-batch-prefill-tokens`" - ) - - mem_stats = get_hpu_memory_stats(self.device) - logger.info( - f"\nFollowing prefill and decode warmup successfully.\n" - f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n" - f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n" - f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n" - f"Memory stats: {mem_stats} " - ) - max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS) - batch_size = max_prefill_batch_size * 2 - # Decode warmup with bigger batch_size + batch_size = max_decode_batch_size + self.limit_hpu_graph = True try: - if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size and batch_size <= max_decode_batch_size: - batches = [] - for i in range(int(batch_size/max_prefill_batch_size)) : - batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup) + while batch_size > 1: + batches= [] + iters = math.floor(batch_size/max_prefill_batch_size) + DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size) + for i in range(iters): + batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size) _, prefill_batch, _ = self.generate_token([batch], is_warmup) batches.append(prefill_batch) - while batch_size <= max_decode_batch_size: - _, decode_batch, _ = self.generate_token(batches, is_warmup) - DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size) - batch_size = batch_size * 2 - batches.clear() - for i in range(int(batch_size/max_prefill_batch_size)) : - batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup) - _, prefill_batch, _ = self.generate_token([batch], is_warmup) - batches.append(prefill_batch) - - batches.clear() - if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size: - max_decode_batch_size = math.floor( max_decode_batch_size / 2) * 2 - batch_size = max_decode_batch_size - for i in range(int(max_decode_batch_size / 2)) : - batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], 2, is_warmup) - _, prefill_batch, _ = self.generate_token([batch], is_warmup) - batches.append(prefill_batch) - _, decode_batch, _ = self.generate_token(batches, is_warmup) - DECODE_WARMUP_BATCH_SIZE_LIST.append(max_decode_batch_size) - max_batch_total_tokens = max_decode_batch_size * MAX_TOTAL_TOKENS - MAX_BATCH_TOTAL_TOKENS = max_batch_total_tokens - except : - raise RuntimeError( - f"Not enough memory to handle batch_size({batch_size}) decode warmup." 
- f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}" - f"max_decode_batch_size is {max_decode_batch_size}" - f"You need to decrease env `MAX_BATCH_TOTAL_TOKENS` or '--max_batch_total_tokens'" - ) + if batch_size % max_prefill_batch_size != 0: + batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, batch_size % max_prefill_batch_size) + _, prefill_batch, _ = self.generate_token([batch], is_warmup) + batches.append(batch) + _, decode_batch, _ = self.generate_token(batches, is_warmup) + logger.info(f"DECODE_DIVISOR={BATCH_BUCKET_SIZE}") + batch_size = math.floor(batch_size / BATCH_BUCKET_SIZE) + except: + DECODE_WARMUP_BATCH_SIZE_LIST.pop(-1) + self.model.clear_cache() + if len(DECODE_WARMUP_BATCH_SIZE_LIST) > 0: + logger.warning( + f"Not enough memory to warmup all batch size of decode." + f"You need to decrease `--max-batch-total-tokens`" + ) + else: + raise RuntimeError( + f"Not enough memory to warmup decode batch_size({max_decode_batch_size})." + f"You need to decrease `--max-batch-total-tokens`" + ) + DECODE_WARMUP_BATCH_SIZE_LIST.sort() mem_stats = get_hpu_memory_stats(self.device) logger.info( f"\nFollowing decode warmup successfully.\n" f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n" - f"Memory stats: {mem_stats}" + f"Memory stats: {mem_stats} " ) - return MAX_BATCH_TOTAL_TOKENS \ No newline at end of file + # Warmup prefill batch_size + max_input_length = request.max_input_length + max_prefill_batch_size = batch.input_ids.shape[0] + batch_size = max_prefill_batch_size + while batch_size >= 1: + PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size) + batch_size = math.floor(batch_size / PREFILL_BATCH_BUCKET_SIZE) + + seq_len = max_input_length + while seq_len >= PAD_SEQUENCE_TO_MULTIPLE_OF: + PREFILL_WARMUP_SEQLEN_LIST.append(seq_len) + seq_len = math.floor(seq_len/2) + + if PREFILL_WARMUP_SEQLEN_LIST[-1] > PAD_SEQUENCE_TO_MULTIPLE_OF: + PREFILL_WARMUP_SEQLEN_LIST.append(PAD_SEQUENCE_TO_MULTIPLE_OF) + + #Prefill and decode warmup + prefill_batch = None + PREFILL_WARMUP_BATCH_SIZE_LIST.sort() + PREFILL_WARMUP_SEQLEN_LIST.sort() + + try: + for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST : + for seq_len in PREFILL_WARMUP_SEQLEN_LIST : + batch = self.generate_warmup_batch(request, seq_len - 1, batch_size) + _, prefill_batch, _ = self.generate_token([batch], is_warmup) + except: + raise RuntimeError( + f"Not enough memory to run following prefill batch_size." 
+ f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}" + f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}" + f"You need to decrease `--max-batch-prefill-tokens`" + ) + + limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true" + if limit_hpu_graph == False: + mem_stats = get_hpu_memory_stats(self.device) + logger.info( + f"\nFollowing prefill and decode warmup successfully.\n" + f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n" + f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n" + f"Memory stats: {mem_stats} " + ) + + return MAX_BATCH_TOTAL_TOKENS diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 20402e07..4568b8dd 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -1,4 +1,5 @@ import inspect +from loguru import logger import torch from abc import ABC, abstractmethod @@ -10,7 +11,7 @@ from text_generation_server.models.types import Batch, Generation from text_generation_server.utils.speculate import get_speculate from text_generation_server.pb.generate_pb2 import InfoResponse from text_generation_server.adapters.weights import LayerAdapterWeights - +import time BASE_MODEL_ADAPTER_ID = "__base_model__" @@ -110,6 +111,7 @@ class Model(ABC): all_input_ids[prefix_offset:read_offset], skip_special_tokens=skip_special_tokens, ) + new_text = self.tokenizer.decode( all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens )