diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index 52a2ea61..334f004e 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -1508,7 +1508,7 @@ class FlashCausalLM(Model):
             num_blocks * BLOCK_SIZE,
         )
         self.bucketing_ctx.num_hpu_blocks = num_blocks
-        if os.getenv("SKIP_WARMUP_GRAPH", "false").lower() == "true":
+        if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
             logger.info("skip warmup hpu graph, not recommmended")
             del _batch, batch
             return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
@@ -1524,23 +1524,26 @@ class FlashCausalLM(Model):
         for i, (batch_size, seq_len) in enumerate(
             reversed(self.bucketing_ctx.prompt_buckets)
         ):
+            log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
             for index in range(warmup_times):
-                self.warmup_prefill(seq_len, batch_size)
+                self.warmup_prefill(seq_len, batch_size, batch)
         self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
         for i, (batch_size, block_num) in enumerate(
             reversed(self.bucketing_ctx.decode_buckets)
         ):
+            log_master(
+                logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
+            )
             for index in range(warmup_times):
-                self.warmup_decode(batch_size, block_num)
+                self.warmup_decode(batch_size, block_num, batch)
         synchronize(self.device)
 
-    def warmup_prefill(self, prompt_len: int, bs: int):
-        logger.info(f"warmup prefill seq {prompt_len} bs {bs}")
+    def warmup_prefill(self, prompt_len: int, bs: int, batch: FlashCausalLMBatch):
         input_ids = torch.zeros(
-            prompt_len, dtype=torch.int64, device=self.device
+            prompt_len, dtype=batch.input_ids.dtype, device=self.device
         ).repeat(bs)
         position_ids = torch.arange(
-            prompt_len, dtype=torch.int32, device=self.device
+            prompt_len, dtype=batch.position_ids.dtype, device=self.device
         ).repeat(bs)
         max_bt = (prompt_len // BLOCK_SIZE + 1) * bs
         block_tables = torch.arange(
@@ -1552,7 +1555,7 @@ class FlashCausalLM(Model):
             for b in block_tables[i]:
                 slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
             slot_acc.extend(slots[:prompt_len])
-        slots = torch.tensor(slot_acc, dtype=torch.int64, device=self.device)
+        slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device)
 
         input_lengths = (
             torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len
@@ -1581,10 +1584,13 @@ class FlashCausalLM(Model):
             hpu_attention_meta=None,
         )
 
-    def warmup_decode(self, batch_size: int, block_num: int):
-        logger.info(f"warmup decode bs {batch_size} block_num {block_num}")
-        input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device)
-        position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device)
+    def warmup_decode(self, batch_size: int, block_num: int, batch: FlashCausalLMBatch):
+        input_ids = torch.zeros(
+            batch_size, dtype=batch.input_ids.dtype, device=self.device
+        )
+        position_ids = torch.arange(
+            batch_size, dtype=batch.position_ids.dtype, device=self.device
+        )
         blocks = [block_num // batch_size for _ in range(batch_size)]
         blocks[0] += block_num % batch_size
         past_len = []
@@ -1599,7 +1605,7 @@ class FlashCausalLM(Model):
             block_tables.append(block_array)
             past_len.append(blocks[i] * BLOCK_SIZE - 1)
             start_idx += blocks[i]
-        slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
+        slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
         input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
         cache_lengths_tensor = torch.tensor(
             past_len, dtype=torch.int32, device=self.device
         )
@@ -1725,7 +1731,6 @@ class FlashCausalLM(Model):
             padded_bs = input_lengths.shape[0]
         orig_bs = input_lengths.shape[0]
         if padded_bs != input_lengths.shape[0]:
-            orig_bs = input_lengths.shape[0]
             padded_input_lengths = F.pad(
                 input_lengths,
                 (0, padded_bs - orig_bs),
@@ -1754,7 +1759,7 @@ class FlashCausalLM(Model):
                 position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1
             )
             slots = F.pad(
-                slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1
+                slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
             )
             if lm_head_indices is not None:
                 lm_head_indices = F.pad(
diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
index 2f9de99f..725e7517 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
@@ -383,9 +383,12 @@ class FlashVlmCausalLM(FlashCausalLM):
     def warmup_decode(
         self, batch_size: int, block_num: int, batch: FlashVlmCausalLMBatch
     ):
-        logger.info(f"warmup decode bs {batch_size} block_num {block_num}")
-        input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device)
-        position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device)
+        input_ids = torch.zeros(
+            batch_size, dtype=batch.input_ids.dtype, device=self.device
+        )
+        position_ids = torch.arange(
+            batch_size, dtype=batch.position_ids.dtype, device=self.device
+        )
         if batch.position_ids is not None and batch.position_ids.dim() == 2:
             # qwen2_vl and qwen2_5_vl case
             position_ids = position_ids.unsqueeze(-1).repeat(
@@ -405,7 +408,7 @@ class FlashVlmCausalLM(FlashCausalLM):
             block_tables.append(block_array)
             past_len.append(blocks[i] * BLOCK_SIZE - 1)
             start_idx += blocks[i]
-        slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
+        slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
         input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
         cache_lengths_tensor = torch.tensor(
             past_len, dtype=torch.int32, device=self.device
@@ -438,9 +441,12 @@ class FlashVlmCausalLM(FlashCausalLM):
             kv_cache=self.kv_cache,
             slots=slots,
             seqlen=trim_seqlen_metadata(seqlen),
-            lm_head_indices=None,
-            adapter_data=None,
             hpu_attention_meta=hpu_attention_meta,
+            lm_head_indices=None,
+            pixel_values=None,
+            pixel_attention_mask=None,
+            image_sizes=None,
+            image_grid_thw=None,
         )
 
     def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):
@@ -450,6 +456,9 @@ class FlashVlmCausalLM(FlashCausalLM):
         for i, (batch_size, block_num) in enumerate(
             reversed(self.bucketing_ctx.decode_buckets)
         ):
+            log_master(
+                logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
+            )
             for index in range(warmup_times):
                 self.warmup_decode(batch_size, block_num, batch)
         synchronize(self.device)
@@ -546,8 +555,8 @@ class FlashVlmCausalLM(FlashCausalLM):
                 )
         else:
             padded_bs = input_lengths.shape[0]
+        orig_bs = input_lengths.shape[0]
         if padded_bs != input_lengths.shape[0]:
-            orig_bs = input_lengths.shape[0]
             padded_input_lengths = F.pad(
                 input_lengths,
                 (0, padded_bs - orig_bs),
@@ -586,7 +595,7 @@ class FlashVlmCausalLM(FlashCausalLM):
                 value=1,
             )
             slots = F.pad(
-                slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1
+                slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
             )
             if lm_head_indices is not None:
                 lm_head_indices = F.pad(
@@ -621,4 +630,6 @@ class FlashVlmCausalLM(FlashCausalLM):
             batch.image_sizes = None
         if batch.image_grid_thw is not None:
             batch.image_grid_thw = None
-        return logits, speculative_logits
+        return logits[:orig_bs], (
+            speculative_logits[:orig_bs] if speculative_logits is not None else None
+        )
diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
index 55d80ca5..acd5d9a5 100644
--- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
@@ -27,6 +27,7 @@ from text_generation_server.utils.import_utils import (
     synchronize,
 )
 import torch.nn.functional as F
+from text_generation_server.utils.log import log_master
 
 tracer = trace.get_tracer(__name__)
 
@@ -208,9 +209,12 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
     def warmup_decode(
         self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch
     ):
-        logger.info(f"warmup decode bs {batch_size} block_num {block_num}")
-        input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device)
-        position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device)
+        input_ids = torch.zeros(
+            batch_size, dtype=batch.input_ids.dtype, device=self.device
+        )
+        position_ids = torch.arange(
+            batch_size, dtype=batch.position_ids.dtype, device=self.device
+        )
         blocks = [block_num // batch_size for _ in range(batch_size)]
         blocks[0] += block_num % batch_size
         past_len = []
@@ -225,7 +229,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             block_tables.append(block_array)
             past_len.append(blocks[i] * BLOCK_SIZE - 1)
             start_idx += blocks[i]
-        slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
+        slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
         input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
         cache_lengths_tensor = torch.tensor(
             past_len, dtype=torch.int32, device=self.device
@@ -266,12 +270,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         )
 
     def warmup_prefill(self, prompt_len: int, bs: int, batch: FlashMllamaCausalLMBatch):
-        logger.info(f"warmup prefill seq {prompt_len} bs {bs}")
         input_ids = torch.zeros(
-            prompt_len, dtype=torch.int64, device=self.device
+            prompt_len, dtype=batch.input_ids.dtype, device=self.device
         ).repeat(bs)
         position_ids = torch.arange(
-            prompt_len, dtype=torch.int32, device=self.device
+            prompt_len, dtype=batch.position_ids.dtype, device=self.device
         ).repeat(bs)
         max_bt = (prompt_len // BLOCK_SIZE + 1) * bs
         block_tables = torch.arange(
@@ -283,7 +286,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             for b in block_tables[i]:
                 slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
             slot_acc.extend(slots[:prompt_len])
-        slots = torch.tensor(slot_acc, dtype=torch.int64, device=self.device)
+        slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device)
 
         input_lengths = (
             torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len
@@ -320,12 +323,16 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         for i, (batch_size, seq_len) in enumerate(
             reversed(self.bucketing_ctx.prompt_buckets)
         ):
+            log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
             for index in range(warmup_times):
                 self.warmup_prefill(seq_len, batch_size, batch)
         self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
         for i, (batch_size, block_num) in enumerate(
             reversed(self.bucketing_ctx.decode_buckets)
         ):
+            log_master(
+                logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
+            )
             for index in range(warmup_times):
                 self.warmup_decode(batch_size, block_num, batch)
         synchronize(self.device)
@@ -425,8 +432,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
                 )
         else:
             padded_bs = input_lengths.shape[0]
+        orig_bs = input_lengths.shape[0]
        if padded_bs != input_lengths.shape[0]:
-            orig_bs = input_lengths.shape[0]
             padded_input_lengths = F.pad(
                 input_lengths,
                 (0, padded_bs - orig_bs),
@@ -455,7 +462,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
                 position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1
             )
             slots = F.pad(
-                slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1
+                slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
             )
             if lm_head_indices is not None:
                 lm_head_indices = F.pad(
@@ -484,4 +491,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             )
         if batch.pixel_values is not None:
             batch.pixel_values = None
-        return logits, speculative_logits
+        return logits[:orig_bs], (
+            speculative_logits[:orig_bs] if speculative_logits is not None else None
+        )
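Note on the pattern above (not part of the diff): the warmup helpers now receive the live batch and mirror its tensor dtypes (`batch.input_ids.dtype`, `batch.position_ids.dtype`, `batch.slots.dtype`) instead of hard-coding `torch.int64`/`torch.int32`, presumably so the HPU graphs captured during warmup match the dtypes seen on real prefill/decode steps. The sketch below shows that idea in isolation; `ReferenceBatch`, `make_warmup_decode_inputs`, the fixed `BLOCK_SIZE`, and the CPU device are hypothetical stand-ins, not names or values taken from the repository.

# Illustrative sketch only: build dummy decode inputs whose dtypes follow a
# reference batch, the way the warmup_decode methods in the diff do.
from dataclasses import dataclass

import torch

BLOCK_SIZE = 128  # assumed block size for this sketch


@dataclass
class ReferenceBatch:
    # stand-in for FlashCausalLMBatch: only the tensors whose dtypes we mirror
    input_ids: torch.Tensor
    position_ids: torch.Tensor
    slots: torch.Tensor


def make_warmup_decode_inputs(
    batch_size: int, block_num: int, batch: ReferenceBatch, device: str = "cpu"
):
    """Return (input_ids, position_ids, slots) dummies matching the batch dtypes."""
    input_ids = torch.zeros(batch_size, dtype=batch.input_ids.dtype, device=device)
    position_ids = torch.arange(
        batch_size, dtype=batch.position_ids.dtype, device=device
    )

    # spread the available blocks across the batch, remainder goes to request 0
    blocks = [block_num // batch_size] * batch_size
    blocks[0] += block_num % batch_size

    slots, start_idx = [], 0
    for n in blocks:
        block_array = list(range(start_idx, start_idx + n))
        # use the last slot of the last block assigned to this request
        slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
        start_idx += n
    slots = torch.tensor(slots, dtype=batch.slots.dtype, device=device)
    return input_ids, position_ids, slots


if __name__ == "__main__":
    ref = ReferenceBatch(
        input_ids=torch.tensor([1, 2, 3], dtype=torch.int32),
        position_ids=torch.tensor([0, 1, 2], dtype=torch.int32),
        slots=torch.tensor([0, 1, 2], dtype=torch.int64),
    )
    ids, pos, slots = make_warmup_decode_inputs(batch_size=4, block_num=10, batch=ref)
    print(ids.dtype, pos.dtype, slots.dtype)  # mirrors the reference batch dtypes

The same reasoning seems to drive the padding changes in the diff: padded rows now use slot value 0 rather than -1, and the padded outputs are trimmed back with `logits[:orig_bs]` so callers only ever see results for the original batch.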