diff --git a/Dockerfile_gaudi b/Dockerfile_gaudi
index 54a0bb7c..bd6c58b4 100644
--- a/Dockerfile_gaudi
+++ b/Dockerfile_gaudi
@@ -1,5 +1,5 @@
 # Those arguments are required to build the image
-ARG HABANA_VERSION=1.20.0
+ARG HABANA_VERSION=1.21.0
 ARG PYTORCH_VERSION=2.6.0
 
 # Rust builder
@@ -62,6 +62,7 @@ ENV PREFIX_CACHING=0
 ENV PREFILL_CHUNKING=0
 ENV PT_HPU_LAZY_MODE=1
 ENV PT_HPU_WEIGHT_SHARING=0
+ENV VLLM_EXPONENTIAL_BUCKETING=true
 
 # Text Generation Inference base env
 ENV HF_HOME=/data \
diff --git a/backends/gaudi/Makefile b/backends/gaudi/Makefile
index c153a5ff..77581517 100644
--- a/backends/gaudi/Makefile
+++ b/backends/gaudi/Makefile
@@ -2,7 +2,7 @@ mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
 mkfile_dir := $(dir $(mkfile_path))
 root_dir := ${mkfile_dir}/../..
 
-HABANA_VERSION := 1.20.0
+HABANA_VERSION := 1.21.0
 PYTORCH_VERSION := 2.6.0
 
 .PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install
diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index eb0f7454..bc0d240e 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -76,6 +76,7 @@ import vllm_hpu_extension.environment as environment
 import habana_frameworks.torch as htorch
 import itertools
 from vllm_hpu_extension.bucketing.common import get_bucketing_context
+from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
 
 tracer = trace.get_tracer(__name__)
 
@@ -1357,6 +1358,8 @@ class FlashCausalLM(Model):
     ):
         self.quantize = quantize
         self.process_group, rank, world_size = initialize_torch_distributed()
+        if world_size > 1:
+            self.process_group_cpu = torch.distributed.new_group(backend="gloo")
 
         device = torch.device("hpu")
         dtype = torch.bfloat16 if dtype is None else dtype
@@ -1453,6 +1456,7 @@ class FlashCausalLM(Model):
         self.limit_hpu_graph = (
             os.environ.get("LIMIT_HPU_GRAPH", "false").lower() == "true"
         )
+        self.skip_warmup = os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true"
         self.max_seq_len_to_capture = 8192
         super().__init__(
             model_id=model_id,
@@ -1521,7 +1525,7 @@ class FlashCausalLM(Model):
         # The warmup batch is the biggest batch we could ever receive
         self.kv_cache = []
         empty_cache()
-
+        self.graphed_buckets = set()
         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
         dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
@@ -1533,7 +1537,20 @@ class FlashCausalLM(Model):
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         cache_block_size = cache_block_size * 2
         total_cache_size = self.num_layers * cache_block_size * dtype_size
-
+        free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM)
+        self.mem_reserved = int(free_memory * (1 - MEMORY_FRACTION))
+        graph_reserved_mem = (
+            float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1"))
+            if htorch.utils.internal.is_lazy()
+            else 0
+        )
+        mem_used_from_graph = int(
+            (free_memory - self.mem_reserved) * graph_reserved_mem
+        )
+        log_master(
+            logger.info,
+            f"Free memory on device {self.device}: {format_bytes(free_memory)} used_for_graph: {format_bytes(mem_used_from_graph)} ratio {graph_reserved_mem} reserved_for_runtime: {format_bytes(self.mem_reserved)}",
+        )
         try:
             self.init_kv_cache(
                 batch.num_blocks,
@@ -1548,15 +1565,6 @@ class FlashCausalLM(Model):
             num_tokens = batch.to_pb().current_tokens
             synchronize(self.device)
 
-            free_memory = get_free_memory(
-                self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM
-            )
-            real_free_memory = get_free_memory(self.device, MEMORY_FRACTION)
-            log_master(
-                logger.debug,
-                f"Free memory {free_memory / 1e9:.2f}GB , (real: {real_free_memory / 1e9:.2f}GB",
-            )
-
             _, _batch, _ = self.generate_token([batch])
         except Exception:
             raise RuntimeError(
@@ -1565,8 +1573,9 @@ class FlashCausalLM(Model):
             )
 
         synchronize(self.device)
-        free_memory = get_free_memory(self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM)
-        kv_memory = free_memory
+        free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM)
+
+        kv_memory = free_memory - self.mem_reserved - mem_used_from_graph
         num_blocks = (
             # Leave 5% for some wiggle room
             int(kv_memory // total_cache_size)
@@ -1583,7 +1592,6 @@ class FlashCausalLM(Model):
 
         self.kv_cache = []
         empty_cache()
-
         self.init_kv_cache(
             num_blocks,
             self.num_layers,
@@ -1595,11 +1603,16 @@ class FlashCausalLM(Model):
         self.max_batch_prefill_tokens = get_max_prefill_tokens()
         max_num_seqs = int(os.getenv("MAX_BATCH_SIZE"))
         HPUBucketingContext = get_bucketing_context()
-        max_total_tokens_aligned = math.ceil(max_total_tokens / BLOCK_SIZE) * BLOCK_SIZE
+        # need to warm up one more step since blocks are allocated starting from 1
+        block_step = os.getenv("VLLM_DECODE_BLOCK_BUCKET_STEP", BLOCK_SIZE)
+        max_total_tokens_aligned = math.ceil(
+            max_total_tokens / BLOCK_SIZE
+        ) * BLOCK_SIZE + math.ceil(block_step * BLOCK_SIZE / max_num_seqs)
         model_max_length = self.tokenizer.model_max_length
         max_position_embeddings = getattr(
             self.config, "max_position_embeddings", model_max_length
         )
+
         self.bucketing_ctx = HPUBucketingContext(
             max_num_seqs,
             max_num_seqs,  # self.max_num_prefill_seqs, #TODO
@@ -1610,31 +1623,75 @@ class FlashCausalLM(Model):
             max_input_tokens,
             max_total_tokens_aligned,
         )
-        max_blocks = (
-            max(BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE) + 1
+        max_blocks = max(
+            BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE
         )
         self.bucketing_ctx.num_hpu_blocks = min(max_blocks, num_blocks)
-        if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
+        synchronize(self.device)
+        if self.skip_warmup:
             self.bucketing_ctx.generate_prompt_buckets()
             self.bucketing_ctx.generate_decode_buckets(
                 self.bucketing_ctx.num_hpu_blocks
             )
-            logger.info("skip warmup hpu graph, not recommmended")
+            log_master(
+                logger.info, "skip warmup hpu graph, not recommended, may cause OOM"
+            )
             del _batch, batch
             return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
-
         self.warmup_hpu_graph(batch)
         del _batch, batch
         return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
 
-    def bypass_hpu_graphs(self, prefill, max_seq_len_to_capture):
-        if self.limit_hpu_graph:
-            return prefill
-        else:
-            return prefill and max_seq_len_to_capture > self.max_seq_len_to_capture
+    def log_warmup(self, prefilling, i, max_i, batch_size, seq_len):
+        free_mem = format_bytes(HabanaMemoryProfiler.current_free_device_memory())
+        phase = "Prompt" if prefilling else "Decode"
+        dim = "seq_len" if prefilling else "num_blocks"
+        graphed_bucket = (batch_size, seq_len, prefilling)
+        bypass = graphed_bucket not in self.graphed_buckets
+        msg = (
+            f"[Warmup][{phase}][{i+1}/{max_i}] "
+            f"batch_size:{batch_size} "
+            f"{dim}:{seq_len} "
+            f"bypass:{bypass} "
+            f"free_mem:{free_mem}"
+        )
+        log_master(logger.info, msg)
+
+    def use_graphs(self, prefill, seq_len, batch_size):
+        if self.limit_hpu_graph and prefill:
+            return False
+
+        if self.skip_warmup:
+            return True
+
+        return (batch_size, seq_len, prefill) in self.graphed_buckets
+
+    def align_workers(self, value, op):
+        if self.world_size <= 1:
+            return value
+        value_t = torch.tensor(value, device="cpu")
+        torch.distributed.all_reduce(value_t, op=op, group=self.process_group_cpu)
+        return value_t.item()
 
     def warmup_hpu_graph(self, batch):
+        prompt_graph_mem_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3"))
+        free_mem = HabanaMemoryProfiler.current_free_device_memory()
+        graph_free_mem = free_mem - self.mem_reserved
+        graph_free_mem = self.align_workers(
+            graph_free_mem, torch.distributed.ReduceOp.MIN
+        )
+        prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
+        decode_available_memory = graph_free_mem - prompt_available_memory
+        msg = (
+            f"Using {format_bytes(graph_free_mem)}"
+            f"/{format_bytes(free_mem)} "
+            "of free device memory for HPUGraphs, "
+            f"{format_bytes(prompt_available_memory)} for prompt and "
+            f"{format_bytes(decode_available_memory)} for decode "
+            f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})"
+        )
+        log_master(logger.info, msg)
         start_time = time.time()
         warmup_shape_count = 0
         warmup_times = 3
@@ -1646,15 +1703,34 @@ class FlashCausalLM(Model):
         buckets = list(
             sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
         )
-
+        total_batch_seq = 0.001
+        total_mem = 0
+        available_mem = prompt_available_memory
         for i, (batch_size, seq_len) in enumerate(buckets):
             if batch_size * seq_len > self.max_batch_prefill_tokens:
                 continue
+            # Graph memory usage is proportional to seq dimension in a batch
+            batch_seq = batch_size * seq_len
+            mem_estimate = batch_seq / total_batch_seq * total_mem
+            graphed_bucket = (batch_size, seq_len, True)
+            if not (
+                mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture
+            ):
+                if graphed_bucket not in self.graphed_buckets:
+                    self.graphed_buckets.add(graphed_bucket)
             warmup_shape_count += 1
-            log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
-            for index in range(warmup_times):
-                self.warmup_prefill(seq_len, batch_size, batch)
-            synchronize(self.device)
+            self.log_warmup(True, i, len(buckets), batch_size, seq_len)
+            with HabanaMemoryProfiler() as mem_prof:
+                for index in range(warmup_times):
+                    self.warmup_prefill(seq_len, batch_size, batch)
+                synchronize(self.device)
+            used_mem = self.align_workers(
+                mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
+            )
+            if graphed_bucket in self.graphed_buckets:
+                available_mem -= used_mem
+            total_mem += used_mem
+            total_batch_seq += batch_seq
 
         def ordering_function_max_bs(b):
             return (-b[0], b[1])
@@ -1663,16 +1739,34 @@ class FlashCausalLM(Model):
         buckets = list(
             sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
         )
+        free_mem = HabanaMemoryProfiler.current_free_device_memory()
+        total_batch_seq = 0.001
+        total_mem = 0
+        available_mem = free_mem - self.mem_reserved
         for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue
+            # Graph memory usage is proportional to seq dimension in a batch
+            batch_seq = batch_size
+            mem_estimate = batch_seq / total_batch_seq * total_mem
+            graphed_bucket = (batch_size, block_num, False)
+            if not mem_estimate >= available_mem:
+                if graphed_bucket not in self.graphed_buckets:
+                    self.graphed_buckets.add(graphed_bucket)
             warmup_shape_count += 1
-            log_master(
-                logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
+            self.log_warmup(False, i, len(buckets), batch_size, block_num)
+            with HabanaMemoryProfiler() as mem_prof:
+                for index in range(warmup_times):
+                    self.warmup_decode(batch_size, block_num, batch)
+                synchronize(self.device)
+            used_mem = self.align_workers(
+                mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
             )
-            for index in range(warmup_times):
-                self.warmup_decode(batch_size, block_num, batch)
-            synchronize(self.device)
+            if graphed_bucket in self.graphed_buckets:
+                available_mem -= used_mem
+            total_mem += used_mem
+            total_batch_seq += batch_seq
+
         log_master(
             logger.info,
             f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
@@ -1707,8 +1801,8 @@ class FlashCausalLM(Model):
         lm_head_indices = input_lengths - 1
 
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
-                True, input_ids.shape[0]
+            kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+                True, prompt_len, batch_size
             )
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
@@ -1762,7 +1856,9 @@ class FlashCausalLM(Model):
         slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = False
+            kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+                False, hpu_attention_meta.block_list.shape[0], batch_size
+            )
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=_async_h2d_tensor_copy(input_ids),
@@ -1858,8 +1954,14 @@ class FlashCausalLM(Model):
 
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
-                batch.prefilling, input_ids.shape[0]
+            batch_size = input_lengths.shape[0]
+            prompt_len = (
+                input_ids.shape[0] // batch_size
+                if batch.prefilling
+                else batch.hpu_attn_meta.block_list.shape[0]
+            )
+            kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+                batch.prefilling, prompt_len, batch_size
             )
 
         logits, speculative_logits = self.model.forward(
diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
index d9c57f20..fd239b3e 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
@@ -27,6 +27,7 @@ import time
 from text_generation_server.utils.import_utils import (
     synchronize,
 )
+from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
 
 tracer = trace.get_tracer(__name__)
 
@@ -487,6 +488,19 @@ class FlashVlmCausalLM(FlashCausalLM):
         )
 
     def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):
+        free_mem = HabanaMemoryProfiler.current_free_device_memory()
+        graph_free_mem = free_mem - self.mem_reserved
+        graph_free_mem = self.align_workers(
+            graph_free_mem, torch.distributed.ReduceOp.MIN
+        )
+        decode_available_memory = graph_free_mem
+        msg = (
+            f"Using {format_bytes(graph_free_mem)}"
+            f"/{format_bytes(free_mem)} "
+            "of free device memory for HPUGraphs, "
+            f"{format_bytes(decode_available_memory)} for decode "
+        )
+        log_master(logger.info, msg)
         start_time = time.time()
         warmup_shape_count = 0
         warmup_times = 3
@@ -499,16 +513,34 @@ class FlashVlmCausalLM(FlashCausalLM):
         buckets = list(
             sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
         )
+        total_batch_seq = 0.001
+        total_mem = 0
+        available_mem = decode_available_memory
         for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue
+            # Graph memory usage is proportional to seq dimension in a batch
+            batch_seq = batch_size
+            mem_estimate = batch_seq / total_batch_seq * total_mem
+            graphed_bucket = (batch_size, block_num, False)
+            if not mem_estimate >= available_mem:
+                if graphed_bucket not in self.graphed_buckets:
+                    self.graphed_buckets.add(graphed_bucket)
             warmup_shape_count += 1
-            log_master(
-                logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
+            self.log_warmup(False, i, len(buckets), batch_size, block_num)
+            with HabanaMemoryProfiler() as mem_prof:
+                for index in range(warmup_times):
+                    self.warmup_decode(batch_size, block_num, batch)
+                synchronize(self.device)
+            used_mem = self.align_workers(
+                mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
             )
-            for index in range(warmup_times):
-                self.warmup_decode(batch_size, block_num, batch)
-            synchronize(self.device)
+            if graphed_bucket in self.graphed_buckets:
+
+                available_mem -= used_mem
+            total_mem += used_mem
+            total_batch_seq += batch_seq
+
         log_master(
             logger.info,
             f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
@@ -585,8 +617,15 @@ class FlashVlmCausalLM(FlashCausalLM):
 
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = batch.prefilling
-
+            batch_size = input_lengths.shape[0]
+            seqlen = (
+                input_ids.shape[0] // batch_size
+                if batch.prefilling
+                else batch.hpu_attn_meta.block_list.shape[0]
+            )
+            kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+                batch.prefilling, seqlen, batch_size
+            )
         if batch.prefill_cache_indices is not None:
             slots_pad = torch.zeros_like(input_ids)
             slots_pad[batch.prefill_cache_indices] = slots
diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
index 0e5544f2..db3904a2 100644
--- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
@@ -33,6 +33,8 @@ from text_generation_server.utils.import_utils import (
 import torch.nn.functional as F
 from text_generation_server.utils.log import log_master
 import time
+import os
+from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
 
 tracer = trace.get_tracer(__name__)
 
@@ -268,6 +270,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             cross_attention_states, image_indices, input_lengths, 1, False
         )
         slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
+        kwargs = {}
+        if htorch.utils.internal.is_lazy():
+            kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+                False, hpu_attention_meta.block_list.shape[0], batch_size
+            )
         self.model.forward(
             input_ids=_async_h2d_tensor_copy(input_ids),
             position_ids=_async_h2d_tensor_copy(position_ids),
@@ -281,6 +288,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             cross_attention_states=cross_attention_states,
             indices=_async_h2d_tensor_copy(indices),
             cross_attention_len=_async_h2d_tensor_copy(cross_attention_len),
+            **kwargs,
         )
 
     def warmup_prefill(
@@ -326,8 +334,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         )
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
-                True, input_ids.shape[0]
+            kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+                True, prompt_len, batch_size
             )
         self.model.forward(
             input_ids=_async_h2d_tensor_copy(input_ids),
@@ -346,6 +354,23 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         )
 
     def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch):
+        prompt_graph_mem_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3"))
+        free_mem = HabanaMemoryProfiler.current_free_device_memory()
+        graph_free_mem = free_mem - self.mem_reserved
+        graph_free_mem = self.align_workers(
+            graph_free_mem, torch.distributed.ReduceOp.MIN
+        )
+        prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
+        decode_available_memory = graph_free_mem - prompt_available_memory
+        msg = (
+            f"Using {format_bytes(graph_free_mem)}"
+            f"/{format_bytes(free_mem)} "
+            "of free device memory for HPUGraphs, "
+            f"{format_bytes(prompt_available_memory)} for prompt and "
+            f"{format_bytes(decode_available_memory)} for decode "
+            f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})"
+        )
+        log_master(logger.info, msg)
         start_time = time.time()
         warmup_shape_count = 0
         warmup_times = 3
@@ -357,14 +382,35 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         buckets = list(
             sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
         )
+        graph_free_mem
+        total_batch_seq = 0.001
+        total_mem = 0
+        available_mem = prompt_available_memory
         for i, (batch_size, seq_len) in enumerate(buckets):
             if batch_size * seq_len > self.max_batch_prefill_tokens:
                 continue
+            # Graph memory usage is proportional to seq dimension in a batch
+            batch_seq = batch_size * seq_len
+            mem_estimate = batch_seq / total_batch_seq * total_mem
+            graphed_bucket = (batch_size, seq_len, True)
+            if not (
+                mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture
+            ):
+                if graphed_bucket not in self.graphed_buckets:
+                    self.graphed_buckets.add(graphed_bucket)
             warmup_shape_count += 1
-            log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
-            for index in range(warmup_times):
-                self.warmup_prefill(seq_len, batch_size, batch)
-            synchronize(self.device)
+            self.log_warmup(True, i, len(buckets), batch_size, seq_len)
+            with HabanaMemoryProfiler() as mem_prof:
+                for index in range(warmup_times):
+                    self.warmup_prefill(seq_len, batch_size, batch)
+                synchronize(self.device)
+            used_mem = self.align_workers(
+                mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
+            )
+            if graphed_bucket in self.graphed_buckets:
+                available_mem -= used_mem
+            total_mem += used_mem
+            total_batch_seq += batch_seq
 
         def ordering_function_max_bs(b):
             return (-b[0], b[1])
@@ -373,16 +419,34 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         buckets = list(
             sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
         )
+        free_mem = HabanaMemoryProfiler.current_free_device_memory()
+        total_batch_seq = 0.001
+        total_mem = 0
+        available_mem = free_mem - self.mem_reserved
         for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue
+            # Graph memory usage is proportional to seq dimension in a batch
+            batch_seq = batch_size
+            mem_estimate = batch_seq / total_batch_seq * total_mem
+            graphed_bucket = (batch_size, block_num, False)
+            if not mem_estimate >= available_mem:
+                if graphed_bucket not in self.graphed_buckets:
+                    self.graphed_buckets.add(graphed_bucket)
             warmup_shape_count += 1
-            log_master(
-                logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
+            self.log_warmup(False, i, len(buckets), batch_size, block_num)
+            with HabanaMemoryProfiler() as mem_prof:
+                for index in range(warmup_times):
+                    self.warmup_decode(batch_size, block_num, batch)
+                synchronize(self.device)
+            used_mem = self.align_workers(
+                mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
             )
-            for index in range(warmup_times):
-                self.warmup_decode(batch_size, block_num, batch)
-            synchronize(self.device)
+            if graphed_bucket in self.graphed_buckets:
+                available_mem -= used_mem
+            total_mem += used_mem
+            total_batch_seq += batch_seq
+
         log_master(
             logger.info,
             f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
@@ -462,9 +526,16 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
 
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
-                batch.prefilling, input_ids.shape[0]
+            batch_size = input_lengths.shape[0]
+            seqlen = (
+                input_ids.shape[0] // batch_size
+                if batch.prefilling
+                else batch.hpu_attn_meta.block_list.shape[0]
             )
+            kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+                batch.prefilling, seqlen, batch_size
+            )
+
         if batch.prefill_cache_indices is not None:
             slots_pad = torch.zeros_like(input_ids)
             slots_pad[batch.prefill_cache_indices] = slots
diff --git a/backends/gaudi/server/text_generation_server/utils/dist.py b/backends/gaudi/server/text_generation_server/utils/dist.py
index 9866710b..1c45713e 100644
--- a/backends/gaudi/server/text_generation_server/utils/dist.py
+++ b/backends/gaudi/server/text_generation_server/utils/dist.py
@@ -7,7 +7,7 @@ from loguru import logger
 # Tensor Parallelism settings
 RANK = int(os.getenv("RANK", "0"))
 WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
-MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.9"))
+MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.8"))
 
 
 class FakeBarrier:
diff --git a/backends/gaudi/server/text_generation_server/utils/import_utils.py b/backends/gaudi/server/text_generation_server/utils/import_utils.py
index d25484d6..bdcfc9fa 100644
--- a/backends/gaudi/server/text_generation_server/utils/import_utils.py
+++ b/backends/gaudi/server/text_generation_server/utils/import_utils.py
@@ -1,20 +1,9 @@
 import torch
-from loguru import logger
-import habana_frameworks.torch as htorch
-import os
 
 
 def get_hpu_free_memory(device, memory_fraction):
-    graph_reserved_mem = (
-        float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1"))
-        if htorch.utils.internal.is_lazy()
-        else 0
-    )
-    free_memory = int(
-        torch.hpu.mem_get_info()[0] * memory_fraction * (1 - graph_reserved_mem)
-    )
-    logger.info(f"Free memory on device {device}: {free_memory} bytes.")
-    return free_memory
+    free_hpu_memory, _ = torch.hpu.mem_get_info()
+    return free_hpu_memory
 
 
 def synchronize_hpu(device):