Refine warmup and upgrade to SynapseAI 1.21.0 (#3234)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi 2025-05-20 16:22:43 +08:00 committed by GitHub
parent d658b5def3
commit 000e313a92
7 changed files with 278 additions and 76 deletions
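
Note: the central warmup change is an explicit memory budget: a runtime reserve derived from HPU_MEMORY_FRACTION (default lowered to 0.8 here), a graph-capture slice derived from TGI_GRAPH_RESERVED_MEM when lazy mode is active, and the remainder handed to the KV cache. Below is a minimal sketch of that arithmetic; the env-var names mirror the diff, while the helper name, parameters, and example values are assumptions, not the TGI implementation.

# Illustrative sketch only; not the actual TGI code.
import os

def plan_kv_cache_blocks(free_memory: int, cache_block_size_bytes: int) -> int:
    memory_fraction = float(os.environ.get("HPU_MEMORY_FRACTION", "0.8"))
    graph_reserved = float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1"))
    # Keep (1 - HPU_MEMORY_FRACTION) of free memory as a runtime reserve.
    mem_reserved = int(free_memory * (1 - memory_fraction))
    # Out of what remains, set aside a slice for HPU-graph capture.
    mem_used_from_graph = int((free_memory - mem_reserved) * graph_reserved)
    # Whatever is left becomes KV cache, counted in whole blocks.
    kv_memory = free_memory - mem_reserved - mem_used_from_graph
    return kv_memory // cache_block_size_bytes

# Example: 64 GiB free and a hypothetical 2 MiB per KV-cache block.
print(plan_kv_cache_blocks(64 * 2**30, 2 * 2**20))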

View File

@@ -1,5 +1,5 @@
# Those arguments are required to build the image
ARG HABANA_VERSION=1.20.0
ARG HABANA_VERSION=1.21.0
ARG PYTORCH_VERSION=2.6.0
# Rust builder
@@ -62,6 +62,7 @@ ENV PREFIX_CACHING=0
ENV PREFILL_CHUNKING=0
ENV PT_HPU_LAZY_MODE=1
ENV PT_HPU_WEIGHT_SHARING=0
ENV VLLM_EXPONENTIAL_BUCKETING=true
# Text Generation Inference base env
ENV HF_HOME=/data \

View File

@@ -2,7 +2,7 @@ mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
mkfile_dir := $(dir $(mkfile_path))
root_dir := ${mkfile_dir}/../..
HABANA_VERSION := 1.20.0
HABANA_VERSION := 1.21.0
PYTORCH_VERSION := 2.6.0
.PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install

View File

@@ -76,6 +76,7 @@ import vllm_hpu_extension.environment as environment
import habana_frameworks.torch as htorch
import itertools
from vllm_hpu_extension.bucketing.common import get_bucketing_context
from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
tracer = trace.get_tracer(__name__)
@@ -1357,6 +1358,8 @@ class FlashCausalLM(Model):
):
self.quantize = quantize
self.process_group, rank, world_size = initialize_torch_distributed()
if world_size > 1:
self.process_group_cpu = torch.distributed.new_group(backend="gloo")
device = torch.device("hpu")
dtype = torch.bfloat16 if dtype is None else dtype
@@ -1453,6 +1456,7 @@ class FlashCausalLM(Model):
self.limit_hpu_graph = (
os.environ.get("LIMIT_HPU_GRAPH", "false").lower() == "true"
)
self.skip_warmup = os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true"
self.max_seq_len_to_capture = 8192
super().__init__(
model_id=model_id,
@@ -1521,7 +1525,7 @@ class FlashCausalLM(Model):
# The warmup batch is the biggest batch we could ever receive
self.kv_cache = []
empty_cache()
self.graphed_buckets = set()
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the free memory
dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
@@ -1533,7 +1537,20 @@ class FlashCausalLM(Model):
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
cache_block_size = cache_block_size * 2
total_cache_size = self.num_layers * cache_block_size * dtype_size
free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM)
self.mem_reserved = int(free_memory * (1 - MEMORY_FRACTION))
graph_reserved_mem = (
float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1"))
if htorch.utils.internal.is_lazy()
else 0
)
mem_used_from_graph = int(
(free_memory - self.mem_reserved) * graph_reserved_mem
)
log_master(
logger.info,
f"Free memory on device {self.device}: {format_bytes(free_memory)} used_for_graph: {format_bytes(mem_used_from_graph)} ratio {graph_reserved_mem} reserved_for_runtime: {format_bytes(self.mem_reserved)}",
)
try:
self.init_kv_cache(
batch.num_blocks,
@@ -1548,15 +1565,6 @@ class FlashCausalLM(Model):
num_tokens = batch.to_pb().current_tokens
synchronize(self.device)
free_memory = get_free_memory(
self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM
)
real_free_memory = get_free_memory(self.device, MEMORY_FRACTION)
log_master(
logger.debug,
f"Free memory {free_memory / 1e9:.2f}GB , (real: {real_free_memory / 1e9:.2f}GB",
)
_, _batch, _ = self.generate_token([batch])
except Exception:
raise RuntimeError(
@@ -1565,8 +1573,9 @@ class FlashCausalLM(Model):
)
synchronize(self.device)
free_memory = get_free_memory(self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM)
kv_memory = free_memory
free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM)
kv_memory = free_memory - self.mem_reserved - mem_used_from_graph
num_blocks = (
# Leave 5% for some wiggle room
int(kv_memory // total_cache_size)
@@ -1583,7 +1592,6 @@ class FlashCausalLM(Model):
self.kv_cache = []
empty_cache()
self.init_kv_cache(
num_blocks,
self.num_layers,
@@ -1595,11 +1603,16 @@ class FlashCausalLM(Model):
self.max_batch_prefill_tokens = get_max_prefill_tokens()
max_num_seqs = int(os.getenv("MAX_BATCH_SIZE"))
HPUBucketingContext = get_bucketing_context()
max_total_tokens_aligned = math.ceil(max_total_tokens / BLOCK_SIZE) * BLOCK_SIZE
# need to warm up one more step since blocks are allocated starting from 1
block_step = int(os.getenv("VLLM_DECODE_BLOCK_BUCKET_STEP", BLOCK_SIZE))
max_total_tokens_aligned = math.ceil(
max_total_tokens / BLOCK_SIZE
) * BLOCK_SIZE + math.ceil(block_step * BLOCK_SIZE / max_num_seqs)
model_max_length = self.tokenizer.model_max_length
max_position_embeddings = getattr(
self.config, "max_position_embeddings", model_max_length
)
self.bucketing_ctx = HPUBucketingContext(
max_num_seqs,
max_num_seqs, # self.max_num_prefill_seqs, #TODO
@@ -1610,31 +1623,75 @@ class FlashCausalLM(Model):
max_input_tokens,
max_total_tokens_aligned,
)
max_blocks = (
max(BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE) + 1
max_blocks = max(
BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE
)
self.bucketing_ctx.num_hpu_blocks = min(max_blocks, num_blocks)
if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
synchronize(self.device)
if self.skip_warmup:
self.bucketing_ctx.generate_prompt_buckets()
self.bucketing_ctx.generate_decode_buckets(
self.bucketing_ctx.num_hpu_blocks
)
logger.info("skip warmup hpu graph, not recommended")
log_master(
logger.info, "skip warmup hpu graph, not recommended, may cause OOM"
)
del _batch, batch
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
self.warmup_hpu_graph(batch)
del _batch, batch
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
def bypass_hpu_graphs(self, prefill, max_seq_len_to_capture):
if self.limit_hpu_graph:
return prefill
else:
return prefill and max_seq_len_to_capture > self.max_seq_len_to_capture
def log_warmup(self, prefilling, i, max_i, batch_size, seq_len):
free_mem = format_bytes(HabanaMemoryProfiler.current_free_device_memory())
phase = "Prompt" if prefilling else "Decode"
dim = "seq_len" if prefilling else "num_blocks"
graphed_bucket = (batch_size, seq_len, prefilling)
bypass = graphed_bucket not in self.graphed_buckets
msg = (
f"[Warmup][{phase}][{i+1}/{max_i}] "
f"batch_size:{batch_size} "
f"{dim}:{seq_len} "
f"bypass:{bypass} "
f"free_mem:{free_mem}"
)
log_master(logger.info, msg)
def use_graphs(self, prefill, seq_len, batch_size):
if self.limit_hpu_graph and prefill:
return False
if self.skip_warmup:
return True
return (batch_size, seq_len, prefill) in self.graphed_buckets
def align_workers(self, value, op):
if self.world_size <= 1:
return value
value_t = torch.tensor(value, device="cpu")
torch.distributed.all_reduce(value_t, op=op, group=self.process_group_cpu)
return value_t.item()
def warmup_hpu_graph(self, batch):
prompt_graph_mem_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3"))
free_mem = HabanaMemoryProfiler.current_free_device_memory()
graph_free_mem = free_mem - self.mem_reserved
graph_free_mem = self.align_workers(
graph_free_mem, torch.distributed.ReduceOp.MIN
)
prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
decode_available_memory = graph_free_mem - prompt_available_memory
msg = (
f"Using {format_bytes(graph_free_mem)}"
f"/{format_bytes(free_mem)} "
"of free device memory for HPUGraphs, "
f"{format_bytes(prompt_available_memory)} for prompt and "
f"{format_bytes(decode_available_memory)} for decode "
f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})"
)
log_master(logger.info, msg)
start_time = time.time()
warmup_shape_count = 0
warmup_times = 3
@@ -1646,15 +1703,34 @@ class FlashCausalLM(Model):
buckets = list(
sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
)
total_batch_seq = 0.001
total_mem = 0
available_mem = prompt_available_memory
for i, (batch_size, seq_len) in enumerate(buckets):
if batch_size * seq_len > self.max_batch_prefill_tokens:
continue
# Graph memory usage is proportional to seq dimension in a batch
batch_seq = batch_size * seq_len
mem_estimate = batch_seq / total_batch_seq * total_mem
graphed_bucket = (batch_size, seq_len, True)
if not (
mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture
):
if graphed_bucket not in self.graphed_buckets:
self.graphed_buckets.add(graphed_bucket)
warmup_shape_count += 1
log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
for index in range(warmup_times):
self.warmup_prefill(seq_len, batch_size, batch)
synchronize(self.device)
self.log_warmup(True, i, len(buckets), batch_size, seq_len)
with HabanaMemoryProfiler() as mem_prof:
for index in range(warmup_times):
self.warmup_prefill(seq_len, batch_size, batch)
synchronize(self.device)
used_mem = self.align_workers(
mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
)
if graphed_bucket in self.graphed_buckets:
available_mem -= used_mem
total_mem += used_mem
total_batch_seq += batch_seq
def ordering_function_max_bs(b):
return (-b[0], b[1])
@@ -1663,16 +1739,34 @@ class FlashCausalLM(Model):
buckets = list(
sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
)
free_mem = HabanaMemoryProfiler.current_free_device_memory()
total_batch_seq = 0.001
total_mem = 0
available_mem = free_mem - self.mem_reserved
for i, (batch_size, block_num) in enumerate(buckets):
if batch_size > block_num:
continue
# Graph memory usage is proportional to seq dimension in a batch
batch_seq = batch_size
mem_estimate = batch_seq / total_batch_seq * total_mem
graphed_bucket = (batch_size, block_num, False)
if not mem_estimate >= available_mem:
if graphed_bucket not in self.graphed_buckets:
self.graphed_buckets.add(graphed_bucket)
warmup_shape_count += 1
log_master(
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
self.log_warmup(False, i, len(buckets), batch_size, block_num)
with HabanaMemoryProfiler() as mem_prof:
for index in range(warmup_times):
self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device)
used_mem = self.align_workers(
mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
)
for index in range(warmup_times):
self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device)
if graphed_bucket in self.graphed_buckets:
available_mem -= used_mem
total_mem += used_mem
total_batch_seq += batch_seq
log_master(
logger.info,
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
@@ -1707,8 +1801,8 @@ class FlashCausalLM(Model):
lm_head_indices = input_lengths - 1
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
True, input_ids.shape[0]
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
True, prompt_len, batch_size
)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
@@ -1762,7 +1856,9 @@ class FlashCausalLM(Model):
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = False
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
False, hpu_attention_meta.block_list.shape[0], batch_size
)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids),
@@ -1858,8 +1954,14 @@ class FlashCausalLM(Model):
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
batch.prefilling, input_ids.shape[0]
batch_size = input_lengths.shape[0]
prompt_len = (
input_ids.shape[0] // batch_size
if batch.prefilling
else batch.hpu_attn_meta.block_list.shape[0]
)
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
batch.prefilling, prompt_len, batch_size
)
logits, speculative_logits = self.model.forward(

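Note: instead of a blanket bypass rule, warmup now decides per bucket whether to capture an HPU graph. Each bucket's graph cost is estimated proportionally from what previously profiled buckets consumed, and the bucket joins graphed_buckets only while that estimate still fits the remaining budget; the shape is warmed up either way. Below is a stripped-down, hypothetical sketch of that bookkeeping, where measure_capture stands in for the HabanaMemoryProfiler-wrapped warmup runs and the MAX reduction across workers.

# Illustrative sketch of the proportional graph-memory bookkeeping; not the TGI class.
from typing import Callable, Iterable, Set, Tuple

def select_graphed_buckets(
    buckets: Iterable[Tuple[int, int]],          # (batch_size, seq_len) or (batch_size, block_num)
    available_mem: float,                        # graph budget in bytes for this phase
    measure_capture: Callable[[int, int], int],  # warms the shape up, returns bytes consumed
    prefill: bool,
    max_seq_len_to_capture: int = 8192,
) -> Set[Tuple[int, int, bool]]:
    graphed: Set[Tuple[int, int, bool]] = set()
    total_batch_seq = 0.001   # avoids division by zero before the first measurement
    total_mem = 0
    for batch_size, second_dim in buckets:
        batch_seq = batch_size * second_dim if prefill else batch_size
        # Graph memory usage is assumed roughly proportional to batch_seq.
        mem_estimate = batch_seq / total_batch_seq * total_mem
        fits = mem_estimate < available_mem and not (
            prefill and batch_seq > max_seq_len_to_capture
        )
        if fits:
            graphed.add((batch_size, second_dim, prefill))
        used_mem = measure_capture(batch_size, second_dim)  # shape is warmed up either way
        if fits:
            available_mem -= used_mem
        total_mem += used_mem
        total_batch_seq += batch_seq
    return graphed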
View File

@@ -27,6 +27,7 @@ import time
from text_generation_server.utils.import_utils import (
synchronize,
)
from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
tracer = trace.get_tracer(__name__)
@@ -487,6 +488,19 @@ class FlashVlmCausalLM(FlashCausalLM):
)
def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):
free_mem = HabanaMemoryProfiler.current_free_device_memory()
graph_free_mem = free_mem - self.mem_reserved
graph_free_mem = self.align_workers(
graph_free_mem, torch.distributed.ReduceOp.MIN
)
decode_available_memory = graph_free_mem
msg = (
f"Using {format_bytes(graph_free_mem)}"
f"/{format_bytes(free_mem)} "
"of free device memory for HPUGraphs, "
f"{format_bytes(decode_available_memory)} for decode "
)
log_master(logger.info, msg)
start_time = time.time()
warmup_shape_count = 0
warmup_times = 3
@@ -499,16 +513,34 @@ class FlashVlmCausalLM(FlashCausalLM):
buckets = list(
sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
)
total_batch_seq = 0.001
total_mem = 0
available_mem = decode_available_memory
for i, (batch_size, block_num) in enumerate(buckets):
if batch_size > block_num:
continue
# Graph memory usage is proportional to seq dimension in a batch
batch_seq = batch_size
mem_estimate = batch_seq / total_batch_seq * total_mem
graphed_bucket = (batch_size, block_num, False)
if not mem_estimate >= available_mem:
if graphed_bucket not in self.graphed_buckets:
self.graphed_buckets.add(graphed_bucket)
warmup_shape_count += 1
log_master(
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
self.log_warmup(False, i, len(buckets), batch_size, block_num)
with HabanaMemoryProfiler() as mem_prof:
for index in range(warmup_times):
self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device)
used_mem = self.align_workers(
mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
)
for index in range(warmup_times):
self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device)
if graphed_bucket in self.graphed_buckets:
available_mem -= used_mem
total_mem += used_mem
total_batch_seq += batch_seq
log_master(
logger.info,
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
@@ -585,8 +617,15 @@ class FlashVlmCausalLM(FlashCausalLM):
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = batch.prefilling
batch_size = input_lengths.shape[0]
seqlen = (
input_ids.shape[0] // batch_size
if batch.prefilling
else batch.hpu_attn_meta.block_list.shape[0]
)
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
batch.prefilling, seqlen, batch_size
)
if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids)
slots_pad[batch.prefill_cache_indices] = slots

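Note: at forward time, bypass_hpu_graphs is now the negation of a lookup keyed on (batch_size, seq_len or block_num, prefill) against the set built during warmup, with LIMIT_HPU_GRAPH and VLLM_SKIP_WARMUP as overrides. A small standalone sketch of that decision follows; the class below is illustrative, not the TGI model class.

# Hypothetical stand-alone version of the per-shape graph decision.
from typing import Set, Tuple

class GraphPolicy:
    def __init__(self, graphed_buckets: Set[Tuple[int, int, bool]],
                 limit_hpu_graph: bool = False, skip_warmup: bool = False):
        self.graphed_buckets = graphed_buckets
        self.limit_hpu_graph = limit_hpu_graph
        self.skip_warmup = skip_warmup

    def use_graphs(self, prefill: bool, seq_len: int, batch_size: int) -> bool:
        if self.limit_hpu_graph and prefill:
            return False   # LIMIT_HPU_GRAPH: never run prefill through a captured graph
        if self.skip_warmup:
            return True    # warmup skipped, so there is no bucket set to consult
        return (batch_size, seq_len, prefill) in self.graphed_buckets

policy = GraphPolicy({(8, 1024, True)})
# The model call then sets bypass_hpu_graphs to the negation of this decision:
bypass_hpu_graphs = not policy.use_graphs(prefill=True, seq_len=1024, batch_size=8)
print(bypass_hpu_graphs)  # False: this shape was captured during warmup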
View File

@@ -33,6 +33,8 @@ from text_generation_server.utils.import_utils import (
import torch.nn.functional as F
from text_generation_server.utils.log import log_master
import time
import os
from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
tracer = trace.get_tracer(__name__)
@@ -268,6 +270,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
cross_attention_states, image_indices, input_lengths, 1, False
)
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
False, hpu_attention_meta.block_list.shape[0], batch_size
)
self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids),
position_ids=_async_h2d_tensor_copy(position_ids),
@@ -281,6 +288,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
cross_attention_states=cross_attention_states,
indices=_async_h2d_tensor_copy(indices),
cross_attention_len=_async_h2d_tensor_copy(cross_attention_len),
**kwargs,
)
def warmup_prefill(
@@ -326,8 +334,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
)
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
True, input_ids.shape[0]
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
True, prompt_len, batch_size
)
self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids),
@@ -346,6 +354,23 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
)
def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch):
prompt_graph_mem_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3"))
free_mem = HabanaMemoryProfiler.current_free_device_memory()
graph_free_mem = free_mem - self.mem_reserved
graph_free_mem = self.align_workers(
graph_free_mem, torch.distributed.ReduceOp.MIN
)
prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
decode_available_memory = graph_free_mem - prompt_available_memory
msg = (
f"Using {format_bytes(graph_free_mem)}"
f"/{format_bytes(free_mem)} "
"of free device memory for HPUGraphs, "
f"{format_bytes(prompt_available_memory)} for prompt and "
f"{format_bytes(decode_available_memory)} for decode "
f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})"
)
log_master(logger.info, msg)
start_time = time.time()
warmup_shape_count = 0
warmup_times = 3
@@ -357,14 +382,35 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
buckets = list(
sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
)
total_batch_seq = 0.001
total_mem = 0
available_mem = prompt_available_memory
for i, (batch_size, seq_len) in enumerate(buckets):
if batch_size * seq_len > self.max_batch_prefill_tokens:
continue
# Graph memory usage is proportional to seq dimension in a batch
batch_seq = batch_size * seq_len
mem_estimate = batch_seq / total_batch_seq * total_mem
graphed_bucket = (batch_size, seq_len, True)
if not (
mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture
):
if graphed_bucket not in self.graphed_buckets:
self.graphed_buckets.add(graphed_bucket)
warmup_shape_count += 1
log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
for index in range(warmup_times):
self.warmup_prefill(seq_len, batch_size, batch)
synchronize(self.device)
self.log_warmup(True, i, len(buckets), batch_size, seq_len)
with HabanaMemoryProfiler() as mem_prof:
for index in range(warmup_times):
self.warmup_prefill(seq_len, batch_size, batch)
synchronize(self.device)
used_mem = self.align_workers(
mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
)
if graphed_bucket in self.graphed_buckets:
available_mem -= used_mem
total_mem += used_mem
total_batch_seq += batch_seq
def ordering_function_max_bs(b):
return (-b[0], b[1])
@@ -373,16 +419,34 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
buckets = list(
sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
)
free_mem = HabanaMemoryProfiler.current_free_device_memory()
total_batch_seq = 0.001
total_mem = 0
available_mem = free_mem - self.mem_reserved
for i, (batch_size, block_num) in enumerate(buckets):
if batch_size > block_num:
continue
# Graph memory usage is proportional to seq dimension in a batch
batch_seq = batch_size
mem_estimate = batch_seq / total_batch_seq * total_mem
graphed_bucket = (batch_size, block_num, False)
if not mem_estimate >= available_mem:
if graphed_bucket not in self.graphed_buckets:
self.graphed_buckets.add(graphed_bucket)
warmup_shape_count += 1
log_master(
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
self.log_warmup(False, i, len(buckets), batch_size, block_num)
with HabanaMemoryProfiler() as mem_prof:
for index in range(warmup_times):
self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device)
used_mem = self.align_workers(
mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
)
for index in range(warmup_times):
self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device)
if graphed_bucket in self.graphed_buckets:
available_mem -= used_mem
total_mem += used_mem
total_batch_seq += batch_seq
log_master(
logger.info,
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
@@ -462,9 +526,16 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
batch.prefilling, input_ids.shape[0]
batch_size = input_lengths.shape[0]
seqlen = (
input_ids.shape[0] // batch_size
if batch.prefilling
else batch.hpu_attn_meta.block_list.shape[0]
)
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
batch.prefilling, seqlen, batch_size
)
if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids)
slots_pad[batch.prefill_cache_indices] = slots

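Note: since each rank can report slightly different free-memory and profiler figures, the commit also opens a CPU "gloo" group next to the HPU process group and reduces such scalars across workers: MIN for budgets, MAX for measured consumption. Below is a hedged sketch of that pattern, not the TGI implementation, written to degrade to a no-op when torch.distributed is not initialized.

# Illustrative cross-rank scalar alignment over a CPU gloo group.
import torch
import torch.distributed as dist

def align_workers(value: float, op, group=None) -> float:
    """Return the reduced value so every rank plans with the same number."""
    if not (dist.is_available() and dist.is_initialized()):
        return value                      # single-process: nothing to align
    if dist.get_world_size(group=group) <= 1:
        return value
    value_t = torch.tensor(value, device="cpu", dtype=torch.float64)
    dist.all_reduce(value_t, op=op, group=group)  # group created with backend="gloo"
    return value_t.item()

if __name__ == "__main__":
    local_free_mem = 60.5e9               # hypothetical per-rank figure, in bytes
    # Budgets take the MIN across ranks; profiled usage would take the MAX.
    op = dist.ReduceOp.MIN if dist.is_available() else None
    print(align_workers(local_free_mem, op))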
View File

@@ -7,7 +7,7 @@ from loguru import logger
# Tensor Parallelism settings
RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.9"))
MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.8"))
class FakeBarrier:

View File

@@ -1,20 +1,9 @@
import torch
from loguru import logger
import habana_frameworks.torch as htorch
import os
def get_hpu_free_memory(device, memory_fraction):
graph_reserved_mem = (
float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1"))
if htorch.utils.internal.is_lazy()
else 0
)
free_memory = int(
torch.hpu.mem_get_info()[0] * memory_fraction * (1 - graph_reserved_mem)
)
logger.info(f"Free memory on device {device}: {free_memory} bytes.")
return free_memory
free_hpu_memory, _ = torch.hpu.mem_get_info()
return free_hpu_memory
def synchronize_hpu(device):