From d34ffc4fe9f468436d6f81b6334454e24af31ba5 Mon Sep 17 00:00:00 2001
From: yuanwu
Date: Fri, 2 Aug 2024 04:36:59 +0000
Subject: [PATCH] Refine the HPU warmup

Signed-off-by: yuanwu
---
 .../models/custom_modeling/llava_next.py      |   1 -
 .../models/vlm_causal_lm.py                   | 329 ++++++++++++++----
 server/text_generation_server/server.py       |   7 +-
 3 files changed, 269 insertions(+), 68 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py
index 4268cc9ba..319a6d28e 100644
--- a/server/text_generation_server/models/custom_modeling/llava_next.py
+++ b/server/text_generation_server/models/custom_modeling/llava_next.py
@@ -26,7 +26,6 @@ from transformers.models.llava_next.modeling_llava_next import (
 )
 from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration
 from transformers.image_processing_utils import select_best_resolution
-from loguru import logger

 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     """
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 51c4b3409..5c6c90c69 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -10,6 +10,7 @@ import numpy
 from opentelemetry import trace
 from loguru import logger
 from typing import Optional, Tuple, List, Type, Dict
+import itertools
 import tempfile
 import copy
 from text_generation_server.models import Model
@@ -20,9 +21,10 @@ from text_generation_server.pb import generate_pb2
 from text_generation_server.models.causal_lm import (
     CausalLMBatch,
     CausalLMRequest,
-    round_up,
-    remove_kv_cache_from_output
+    remove_kv_cache_from_output,
+    biggest_single_chunk,
 )
+
 from transformers.models.llava_next.modeling_llava_next import (
     get_anyres_image_grid_shape,
 )
@@ -43,6 +45,7 @@ from text_generation_server.utils import (
 import habana_frameworks.torch as htorch
 from optimum.habana.utils import HabanaProfile
 from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
+from optimum.habana.utils import get_hpu_memory_stats

 from transformers import (
     AutoTokenizer,
@@ -70,18 +73,20 @@ tracer = trace.get_tracer(__name__)
 IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
 BASE_IMAGE_TOKENS = int(os.environ.get('BASE_IMAGE_TOKENS', 2048))
 MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 8192))
-BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 1))
-PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128))
-PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 1))
+MAX_BATCH_TOTAL_TOKENS = int(os.environ.get('MAX_BATCH_TOTAL_TOKENS', 131072))
+PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 256))
 CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
 LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1))
-PREFILL_GRAPH_NUM = int(os.environ.get('PREFILL_GRAPH_NUM', 16))
-os.environ['MAX_TOTAL_TOKENS'] = str(MAX_TOTAL_TOKENS)
-os.environ['BATCH_BUCKET_SIZE'] = str(BATCH_BUCKET_SIZE)
-os.environ['PAD_SEQUENCE_TO_MULTIPLE_OF'] = str(PAD_SEQUENCE_TO_MULTIPLE_OF)
-os.environ['PREFILL_BATCH_BUCKET_SIZE'] = str(PREFILL_BATCH_BUCKET_SIZE)
-os.environ['LAZY_MODE'] = str(LAZY_MODE)
+PREFILL_WARMUP_BATCH_SIZE_LIST = []
+PREFILL_WARMUP_SEQLEN_LIST = []
+DECODE_WARMUP_BATCH_SIZE_LIST = []
+
+def round_up(warmup_list: list, num):
+    i = 0
+    for i in warmup_list:
+        if num <= i:
+            break
+    return i
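+# Illustrative behavior of the bucketing helper above (values assumed,
+# not from this patch): with warmup_list = [1, 2, 4, 8],
+# round_up(warmup_list, 3) -> 4 and round_up(warmup_list, 8) -> 8, while
+# round_up(warmup_list, 9) falls through the loop and returns 8, the
+# largest warmed-up bucket. An empty warmup_list returns 0.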

 def split(string) -> List[Dict[str, str]]:
     parts = []
@@ -186,6 +191,7 @@ class VlmCausalLMBatch(CausalLMBatch):
         batch_tokenized_inputs,
         dtype: torch.dtype,
         device: torch.device,
+        is_warmup: bool = False,
     ) -> "VlmCausalLMBatch":

         dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
@@ -200,7 +206,7 @@ class VlmCausalLMBatch(CausalLMBatch):
         # TODO: by tokenizing all inputs at once we loose information on actual input lengths
         # this means that we cannot shift inputs to the left after a long input sequence
         # was filtered out
-        new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
+        new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests))
         parameters = [r.parameters for r in pb.requests]
         # append the dummy parameters for dummy request
         parameters = pad_next_token_chooser_parameters(parameters, new_bs)
@@ -217,14 +223,14 @@ class VlmCausalLMBatch(CausalLMBatch):
             bucket_size = max_input_length
             left_padding = max_input_length - input_len
-            if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
-                assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
-                rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
-                if rounded_seq_len <= max_input_length:
-                    bucket_size = rounded_seq_len - 1
-                else:
-                    bucket_size = max_input_length - 1
-                left_padding = bucket_size - input_len
+            if is_warmup is False:
+                if input_len < max_input_length:
+                    rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1)
+                    if rounded_seq_len <= max_input_length:
+                        bucket_size = rounded_seq_len - 1
+                    else:
+                        bucket_size = max_input_length - 1
+                    left_padding = bucket_size - input_len

             input_ids = tokenized_inputs["input_ids"]
             attention_mask = tokenized_inputs["attention_mask"]
@@ -269,7 +275,7 @@ class VlmCausalLMBatch(CausalLMBatch):
         )

     @classmethod
-    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
+    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config, is_warmup):
         batch_inputs = []
         image_inputs = []
         max_truncation = 0
@@ -301,17 +307,19 @@ class VlmCausalLMBatch(CausalLMBatch):
             batch_inputs.append(full_text)
             max_truncation = max(max_truncation, r.truncate)

-        new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
-        missing_inputs = new_bs - len(requests)
-        dummy_images = []
-        dummy_inputs = []
-        if len(batch_inputs) > 0 and len(image_inputs) > 0:
-            dummy_inputs = [batch_inputs[0]] * missing_inputs
-            dummy_images = [image_inputs[0]] * missing_inputs
+        if is_warmup is False:
+            new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests))
+            missing_inputs = new_bs - len(requests)
+            dummy_images = []
+            dummy_inputs = []
+            if len(batch_inputs) > 0 and len(image_inputs) > 0:
+                dummy_inputs = [batch_inputs[0]] * missing_inputs
+                dummy_images = [image_inputs[0]] * missing_inputs
+            image_inputs += dummy_images
+            batch_inputs += dummy_inputs
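+            # Example (assumed sizes, not from this patch): 3 live requests
+            # with PREFILL_WARMUP_BATCH_SIZE_LIST = [1, 2, 4] round up to
+            # new_bs = 4, so one dummy copy of the first prompt and image
+            # is appended above so tokenization hits a warmed-up bucket.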

-        image_inputs += dummy_images
         batch_tokenized_inputs = tokenizer(
-            batch_inputs + dummy_inputs,
+            batch_inputs,
             truncation=True,
             max_length=max_truncation,
             return_tensors="pt",
@@ -347,9 +355,10 @@ class VlmCausalLMBatch(CausalLMBatch):
         config,
         dtype: torch.dtype,
         device: torch.device,
+        is_warmup: bool = False,
     ) -> "VlmCausalLMBatch":
         batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
-            pb.requests, tokenizer, processor, config
+            pb.requests, tokenizer, processor, config, is_warmup
         )
         batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
         if image_inputs is not None:
@@ -370,6 +379,114 @@ class VlmCausalLMBatch(CausalLMBatch):
             batch.image_sizes = None
         return batch

+    @classmethod
+    @tracer.start_as_current_span("concatenate")
+    def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0, is_warmup: bool = False) -> "CausalLMBatch":
+        return cls.recombine(batches, pad_token_id, is_warmup)
+
+    @classmethod
+    def recombine(cls, batches: List["VlmCausalLMBatch"], pad_token_id: int, is_warmup: bool = False) -> "VlmCausalLMBatch":
+        if not all(b.past_key_values is not None for b in batches):
+            raise ValueError("KV cache not allocated! Cannot recombine before prefill!")
+
+        total_requests = sum(len(b) for b in batches)
+        new_bs = total_requests
+        if is_warmup is False:
+            new_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, total_requests)
+        batch_id = batches[0].batch_id
+        device = batches[0].input_ids.device
+
+        input_lengths = [b.input_length for b in batches]
+        max_input_length = max(input_lengths)
+        offsets = [max_input_length - b.input_length for b in batches]
+
+        cur_padding = [b.right_padding for b in batches]
+        # For prefill there is a space allocated only for first token
+        # Need to add padding to the max total tokens before first decode
+
+        moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches]
+        dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]
+        reshape = (batches[dst_batch_idx].batch_size < new_bs)

+        # TODO: Add support for changing max seq len, i.e. due to output length bucketing
+        # FIXME: max_seq_len for non optimized code
+        if len(batches) > 1:
+            scenario = 'CONCAT'
+        elif reshape:
+            scenario = 'RESHAPE'
+        elif cur_padding[dst_batch_idx] <= 0:
+            scenario = 'SHIFT'
+            offsets = [biggest_single_chunk(b.max_input_length - max_input_length) for b in batches]
+            max_input_length = max_input_length + offsets[dst_batch_idx]
+        else:
+            # Nothing to do
+            return batches[0]
+
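+        # Worked example of the scenarios above (sizes assumed): merging
+        # batches of 3 and 2 requests with DECODE_WARMUP_BATCH_SIZE_LIST =
+        # [1, 2, 4, 8] picks new_bs = round_up(..., 5) = 8 and takes the
+        # CONCAT path; a lone batch whose bucket grew takes RESHAPE, and a
+        # lone batch with no right padding left takes SHIFT.
+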
+        dbg_trace(
+            scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs}'
+                      f' reqs:{[len(b) for b in batches]}'
+                      f' offsets:{offsets}'
+                      f' input_lengths:{input_lengths}'
+                      f' cur_padding:{cur_padding}'
+                      f' dst_batch:{dst_batch_idx}')
+
+        grouped_requests = [[req for req in batch.requests] for batch in batches]
+        flat_requests = list(itertools.chain(*grouped_requests))
+
+        for i in range(len(batches)):
+            target_bs = new_bs if i == dst_batch_idx else batches[i].batch_size
+            batches[i].merge_kv_cache_if_needed(target_bs, offsets[i])
+            batches[i].realign(target_bs, offsets[i], pad_token_id)
+            batches[i].split_kv_cache_if_needed(i == dst_batch_idx)
+        batches[dst_batch_idx].expand_bs(new_bs)
+        batches[dst_batch_idx].move_data([batches[i] for i in range(len(batches)) if i != dst_batch_idx])
+
+        top_n_tokens = [r.data.top_n_tokens for r in flat_requests]
+        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
+
+        parameters = [r.data.parameters for r in flat_requests]
+        # append the dummy parameters for dummy requests
+        batch_size = batches[dst_batch_idx].batch_size
+        parameters = pad_next_token_chooser_parameters(parameters, batch_size)
+
+        # update past grammar states
+        fsm_grammar_states = [0] * batch_size
+        for batch in batches:
+            for i, req in enumerate(batch.requests):
+                fsm_grammar_states[req.idx] = batch.next_token_chooser.fsm_grammar_states[i]
+
+        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+            parameters,
+            batches[dst_batch_idx].next_token_chooser.dtype,
+            batches[dst_batch_idx].next_token_chooser.device,
+            batches[dst_batch_idx].next_token_chooser.tokenizer,
+            fsm_grammar_states,
+            quantization_enabled=hq_env.is_quantization_enabled,
+        )
+
+        input_ids = batches[dst_batch_idx].input_ids
+        attention_mask = batches[dst_batch_idx].attention_mask
+        position_ids = batches[dst_batch_idx].position_ids
+        past_key_values = batches[dst_batch_idx].past_key_values
+        input_length = max_input_length
+
+        htorch.core.mark_step()
+
+        return cls(
+            batch_id=batch_id,
+            requests=flat_requests,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            merged_kv_cache=False,
+            next_token_chooser=next_token_chooser,
+            top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
+            input_length=input_length,
+        )

 class VlmCausalLM(Model):
     def __init__(
@@ -672,7 +789,7 @@ class VlmCausalLM(Model):

     @tracer.start_as_current_span("generate_token")
     def generate_token(
-        self, batches: List[VlmCausalLMBatch]
+        self, batches: List[VlmCausalLMBatch], is_warmup: bool = False
     ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]:
         start = time.time_ns()
         # Results
@@ -755,7 +872,7 @@ class VlmCausalLM(Model):

         # Stage 2. Prepare new batch for speculative scheduling
         if len(batches) > 1:
-            batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id)
+            batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id, is_warmup)
         else:
             batch = batches[0]

@@ -763,12 +880,12 @@ class VlmCausalLM(Model):

         # Check if we need to do any bookkeeping first
         if not prefill:
-            batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id)
+            batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id, is_warmup)

         scenario = 'PREFILL' if prefill else 'GENERATE'
-        if self.enable_hpu_graph and self.limit_hpu_graph and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs:
+        if self.enable_hpu_graph and self.limit_hpu_graph and round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) != self.prev_bs:
             self.model.clear_cache()
-            self.prev_bs = round_up(batch.batch_size, BATCH_BUCKET_SIZE)
+            self.prev_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size)
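+        # Editorial note (my reading, not asserted by the patch): with
+        # limit_hpu_graph set, cached HPU graphs are tied to the bucketed
+        # batch size, so crossing into a different decode bucket (e.g.
+        # 4 -> 8) clears the old graphs before new ones are captured.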
         dbg_trace(
             scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}')
         #assert batch.right_padding > 0, 'No more room for next token!'
@@ -959,17 +1076,18 @@ class VlmCausalLM(Model):
         decode_ns = time.time_ns() - start_decode
         return generations, batch if not stopped else None, (forward_ns, decode_ns)

-    def batch_from_pb(self, batch):
+    def batch_from_pb(self, batch, is_warmup):
         return VlmCausalLMBatch.from_pb_processor(
             batch,
             self.tokenizer,
             self.processor,
             self.model.config,
             self.dtype,
-            self.device
+            self.device,
+            is_warmup
         )

-    def generate_warmup_batch(self, request, seq_len, batch_size):
+    def generate_warmup_batch(self, request, seq_len, batch_size, is_warmup):
         batch = copy.deepcopy(request.batches[0])
         for req in batch.requests:
             req.truncate = seq_len

         for i in range(len(batch.requests) - batch_size):
             batch.requests.pop()

-        return self.batch_from_pb(batch)
+        return self.batch_from_pb(batch, is_warmup)

     def warmup(self, request) -> None:
-        batches = [self.batch_from_pb(batch) for batch in request.batches]
+        is_warmup = True
+        batches = [self.batch_from_pb(batch, is_warmup) for batch in request.batches]

         try:
-            # prefill
-            _, prefill_batch, _ = self.generate_token([batches[0]])
-        except torch.cuda.OutOfMemoryError as e:
+            # max prefill batch size warmup
+            _, prefill_batch, _ = self.generate_token([batches[0]], is_warmup)
+        except Exception as e:
             raise RuntimeError(
                 f"Not enough memory to handle {len(batches[0].input_ids)} prefill tokens. "
                 f"You need to decrease `--max-batch-prefill-tokens`"
             ) from e

-        global BASE_IMAGE_TOKENS, PAD_SEQUENCE_TO_MULTIPLE_OF, PREFILL_BATCH_BUCKET_SIZE, PREFILL_GRAPH_NUM
+        self.model.clear_inputs()
+        global BASE_IMAGE_TOKENS, MAX_TOTAL_TOKENS, MAX_BATCH_TOTAL_TOKENS, PREFILL_WARMUP_BATCH_SIZE_LIST, PREFILL_WARMUP_SEQLEN_LIST, DECODE_WARMUP_BATCH_SIZE_LIST
         max_input_length = batches[0].input_ids.shape[1]
-        max_batch_size = batches[0].input_ids.shape[0]
-        seq_num = (max_input_length - BASE_IMAGE_TOKENS) / PAD_SEQUENCE_TO_MULTIPLE_OF
-        batch_num = max_batch_size / PREFILL_BATCH_BUCKET_SIZE
-        while batch_num > PREFILL_GRAPH_NUM:
-            PREFILL_BATCH_BUCKET_SIZE = PREFILL_BATCH_BUCKET_SIZE * 2
-            os.environ['PREFILL_BATCH_BUCKET_SIZE'] = str(PREFILL_BATCH_BUCKET_SIZE)
-            batch_num = max_batch_size / PREFILL_BATCH_BUCKET_SIZE
+        max_prefill_batch_size = batches[0].input_ids.shape[0]
+        PREFILL_WARMUP_BATCH_SIZE_LIST = []
+        batch_size = 1
+        while batch_size <= max_prefill_batch_size:
+            PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size)
+            batch_size = batch_size * 2
+        if PREFILL_WARMUP_BATCH_SIZE_LIST[-1] < max_prefill_batch_size:
+            PREFILL_WARMUP_BATCH_SIZE_LIST.append(max_prefill_batch_size)

-        while seq_num * batch_num >= PREFILL_GRAPH_NUM:
-            PAD_SEQUENCE_TO_MULTIPLE_OF = PAD_SEQUENCE_TO_MULTIPLE_OF * 2
-            os.environ['PAD_SEQUENCE_TO_MULTIPLE_OF'] = str(PAD_SEQUENCE_TO_MULTIPLE_OF)
-            seq_num = (max_input_length - BASE_IMAGE_TOKENS) / PAD_SEQUENCE_TO_MULTIPLE_OF
+        seq_len = BASE_IMAGE_TOKENS
+        PREFILL_WARMUP_SEQLEN_LIST = []
+        i = 0
+        while seq_len <= max_input_length:
+            PREFILL_WARMUP_SEQLEN_LIST.append(seq_len)
+            seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF * (2**i)
+            i += 1
+        if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_length:
+            PREFILL_WARMUP_SEQLEN_LIST.append(max_input_length)
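+        # Worked example (assumed values): BASE_IMAGE_TOKENS = 2048,
+        # PAD_SEQUENCE_TO_MULTIPLE_OF = 256 and max_input_length = 4096
+        # give PREFILL_WARMUP_SEQLEN_LIST = [2048, 2304, 2816, 3840, 4096];
+        # the step doubles each round, so the number of warmed-up shapes
+        # grows logarithmically rather than linearly.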

-        seq_lens_list = numpy.arange(BASE_IMAGE_TOKENS + PAD_SEQUENCE_TO_MULTIPLE_OF, max_input_length + 1, PAD_SEQUENCE_TO_MULTIPLE_OF).tolist()
-        batch_sizes_list = numpy.arange(PREFILL_BATCH_BUCKET_SIZE, max_batch_size + 1, PREFILL_BATCH_BUCKET_SIZE).tolist()
-        for seq_len in seq_lens_list:
-            for batch_size in batch_sizes_list:
-                batch = self.generate_warmup_batch(request, seq_len, batch_size)
-                _, prefill_batch, _ = self.generate_token([batch])
-                _, decode_batch, _ = self.generate_token([prefill_batch])
+        # Prefill and decode warmup
+        DECODE_WARMUP_BATCH_SIZE_LIST = []
+        prefill_batch = None
+        decode_batch = None
+        try:
+            for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST:
+                for seq_len in PREFILL_WARMUP_SEQLEN_LIST:
+                    batch = self.generate_warmup_batch(request, seq_len, batch_size, is_warmup)
+                    _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+                    _, decode_batch, _ = self.generate_token([prefill_batch], is_warmup)
+
+                DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
+
+        except Exception as e:
+            raise RuntimeError(
+                f"Not enough memory to handle the following prefill and decode warmup.\n"
+                f"Prefill batch size list: {PREFILL_WARMUP_BATCH_SIZE_LIST}\n"
+                f"Prefill sequence length list: {PREFILL_WARMUP_SEQLEN_LIST}\n"
+                f"Decode batch size list: {DECODE_WARMUP_BATCH_SIZE_LIST}\n"
+                f"You need to decrease `--max-batch-prefill-tokens`"
+            ) from e
+
+        mem_stats = get_hpu_memory_stats(self.device)
+        logger.info(
+            f"\nPrefill and decode warmup finished successfully.\n"
+            f"Prefill batch size list: {PREFILL_WARMUP_BATCH_SIZE_LIST}\n"
+            f"Prefill sequence length list: {PREFILL_WARMUP_SEQLEN_LIST}\n"
+            f"Decode batch size list: {DECODE_WARMUP_BATCH_SIZE_LIST}\n"
+            f"Memory stats: {mem_stats}"
+        )
+
+        self.model.clear_inputs()
+        max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
+        batch_size = max_prefill_batch_size * 2
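+        # Example (assumed defaults): MAX_BATCH_TOTAL_TOKENS = 131072 and
+        # MAX_TOTAL_TOKENS = 8192 give max_decode_batch_size = 16; with
+        # max_prefill_batch_size = 4 the loop below warms decode at batch
+        # sizes 8 and 16 by stacking prefilled batches together.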
+        # Decode warmup with bigger batch_size
+        try:
+            if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size and batch_size <= max_decode_batch_size:
+                batches = []
+                for i in range(int(batch_size / max_prefill_batch_size)):
+                    batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup)
+                    _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+                    batches.append(prefill_batch)
+                while batch_size <= max_decode_batch_size:
+                    _, decode_batch, _ = self.generate_token(batches, is_warmup)
+                    DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
+                    batch_size = batch_size * 2
+                    batches.clear()
+
+                    for i in range(int(batch_size / max_prefill_batch_size)):
+                        batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup)
+                        _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+                        batches.append(prefill_batch)
+
+                batches.clear()
+                if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size:
+                    max_decode_batch_size = math.floor(max_decode_batch_size / 2) * 2
+                    batch_size = max_decode_batch_size
+                    for i in range(int(max_decode_batch_size / 2)):
+                        batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], 2, is_warmup)
+                        _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+                        batches.append(prefill_batch)
+                    _, decode_batch, _ = self.generate_token(batches, is_warmup)
+                    DECODE_WARMUP_BATCH_SIZE_LIST.append(max_decode_batch_size)
+                max_batch_total_tokens = max_decode_batch_size * MAX_TOTAL_TOKENS
+                MAX_BATCH_TOTAL_TOKENS = max_batch_total_tokens
+        except Exception as e:
+            raise RuntimeError(
+                f"Not enough memory to handle batch_size({batch_size}) decode warmup.\n"
+                f"Decode batch size list: {DECODE_WARMUP_BATCH_SIZE_LIST}\n"
+                f"max_decode_batch_size is {max_decode_batch_size}.\n"
+                f"You need to decrease the `MAX_BATCH_TOTAL_TOKENS` environment variable or `--max-batch-total-tokens`"
+            ) from e
+
+        mem_stats = get_hpu_memory_stats(self.device)
+        logger.info(
+            f"\nDecode warmup finished successfully.\n"
+            f"Decode batch size list: {DECODE_WARMUP_BATCH_SIZE_LIST}\n"
+            f"Memory stats: {mem_stats}"
+        )
+
+        self.model.clear_inputs()
+        return MAX_BATCH_TOTAL_TOKENS
\ No newline at end of file
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 0b5e9e035..4cb7fb24f 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -97,12 +97,13 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         )

         if self.model.batch_type in VLM_BATCH_TYPES :
-            self.model.warmup(request)
+            max_supported_total_tokens = self.model.warmup(request)
+            return generate_pb2.WarmupResponse(max_supported_total_tokens=max_supported_total_tokens)
         else:
             batches = [batch_from_pb(batch) for batch in request.batches]
             self.model.warmup(batches)
+            return generate_pb2.WarmupResponse()

-        return generate_pb2.WarmupResponse()

     async def Prefill(self, request, context):
         start = time.time_ns()
@@ -171,7 +172,7 @@ def serve(
     uds_path: Path,
 ):
     # Remove default handler
-    logger.remove()
+    #logger.remove()
     logger.add(
         sys.stdout,
         format="{message}",