Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-05-23 12:02:09 +00:00.
Refine warmup and upgrade to synapse AI 1.21.0 (#3234)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent d658b5def3
commit 000e313a92
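
The warmup refactor in the hunks below splits free HPU memory into three budgets: a runtime reserve controlled by HPU_MEMORY_FRACTION, a slice for HPU graph capture controlled by TGI_GRAPH_RESERVED_MEM, and the remainder for the KV cache. A rough, self-contained sketch of that arithmetic follows; the helper name and the example numbers are illustrative, only the environment variables and formulas are taken from the diff.

import os

def split_memory_budget(free_memory: int, total_cache_size: int, lazy_mode: bool = True):
    """Sketch of the budgeting done in FlashCausalLM.warmup (names illustrative)."""
    memory_fraction = float(os.getenv("HPU_MEMORY_FRACTION", "0.8"))
    graph_ratio = float(os.getenv("TGI_GRAPH_RESERVED_MEM", "0.1")) if lazy_mode else 0.0

    # Held back for runtime allocations outside the KV cache.
    mem_reserved = int(free_memory * (1 - memory_fraction))
    # Carved out of the remainder for HPU graph capture during warmup.
    mem_used_from_graph = int((free_memory - mem_reserved) * graph_ratio)
    # Whatever is left is handed to the paged KV cache.
    kv_memory = free_memory - mem_reserved - mem_used_from_graph
    num_blocks = kv_memory // total_cache_size
    return mem_reserved, mem_used_from_graph, num_blocks

# Example: ~94 GiB free and 2 MiB per cache block (both made-up figures).
print(split_memory_budget(94 * 1024**3, 2 * 1024**2))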
@@ -1,5 +1,5 @@
# Those arguments are required to build the image
-ARG HABANA_VERSION=1.20.0
+ARG HABANA_VERSION=1.21.0
ARG PYTORCH_VERSION=2.6.0

# Rust builder
@@ -62,6 +62,7 @@ ENV PREFIX_CACHING=0
ENV PREFILL_CHUNKING=0
ENV PT_HPU_LAZY_MODE=1
ENV PT_HPU_WEIGHT_SHARING=0
+ENV VLLM_EXPONENTIAL_BUCKETING=true

# Text Generation Inference base env
ENV HF_HOME=/data \
@@ -2,7 +2,7 @@ mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
mkfile_dir := $(dir $(mkfile_path))
root_dir := ${mkfile_dir}/../..

-HABANA_VERSION := 1.20.0
+HABANA_VERSION := 1.21.0
PYTORCH_VERSION := 2.6.0

.PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install
@@ -76,6 +76,7 @@ import vllm_hpu_extension.environment as environment
import habana_frameworks.torch as htorch
import itertools
from vllm_hpu_extension.bucketing.common import get_bucketing_context
+from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes

tracer = trace.get_tracer(__name__)

@@ -1357,6 +1358,8 @@ class FlashCausalLM(Model):
):
self.quantize = quantize
self.process_group, rank, world_size = initialize_torch_distributed()
+if world_size > 1:
+self.process_group_cpu = torch.distributed.new_group(backend="gloo")

device = torch.device("hpu")
dtype = torch.bfloat16 if dtype is None else dtype
@@ -1453,6 +1456,7 @@ class FlashCausalLM(Model):
self.limit_hpu_graph = (
os.environ.get("LIMIT_HPU_GRAPH", "false").lower() == "true"
)
+self.skip_warmup = os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true"
self.max_seq_len_to_capture = 8192
super().__init__(
model_id=model_id,
@@ -1521,7 +1525,7 @@ class FlashCausalLM(Model):
# The warmup batch is the biggest batch we could ever receive
self.kv_cache = []
empty_cache()
-
+self.graphed_buckets = set()
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the free memory
dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
@@ -1533,7 +1537,20 @@ class FlashCausalLM(Model):
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
cache_block_size = cache_block_size * 2
total_cache_size = self.num_layers * cache_block_size * dtype_size
-
+free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM)
+self.mem_reserved = int(free_memory * (1 - MEMORY_FRACTION))
+graph_reserved_mem = (
+float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1"))
+if htorch.utils.internal.is_lazy()
+else 0
+)
+mem_used_from_graph = int(
+(free_memory - self.mem_reserved) * graph_reserved_mem
+)
+log_master(
+logger.info,
+f"Free memory on device {self.device}: {format_bytes(free_memory)} used_for_graph: {format_bytes(mem_used_from_graph)} ratio {graph_reserved_mem} reserved_for_runtime: {format_bytes(self.mem_reserved)}",
+)
try:
self.init_kv_cache(
batch.num_blocks,
@@ -1548,15 +1565,6 @@ class FlashCausalLM(Model):

num_tokens = batch.to_pb().current_tokens
-synchronize(self.device)
-free_memory = get_free_memory(
-self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM
-)
-real_free_memory = get_free_memory(self.device, MEMORY_FRACTION)
-log_master(
-logger.debug,
-f"Free memory {free_memory / 1e9:.2f}GB , (real: {real_free_memory / 1e9:.2f}GB",
-)

_, _batch, _ = self.generate_token([batch])
except Exception:
raise RuntimeError(
@@ -1565,8 +1573,9 @@ class FlashCausalLM(Model):
)

synchronize(self.device)
-free_memory = get_free_memory(self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM)
-kv_memory = free_memory
+free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM)
+
+kv_memory = free_memory - self.mem_reserved - mem_used_from_graph
num_blocks = (
# Leave 5% for some wiggle room
int(kv_memory // total_cache_size)
@@ -1583,7 +1592,6 @@ class FlashCausalLM(Model):

self.kv_cache = []
empty_cache()
-
self.init_kv_cache(
num_blocks,
self.num_layers,
@@ -1595,11 +1603,16 @@ class FlashCausalLM(Model):
self.max_batch_prefill_tokens = get_max_prefill_tokens()
max_num_seqs = int(os.getenv("MAX_BATCH_SIZE"))
HPUBucketingContext = get_bucketing_context()
-max_total_tokens_aligned = math.ceil(max_total_tokens / BLOCK_SIZE) * BLOCK_SIZE
+# need to warmup one more step since block is allocated from 1
+block_step = os.getenv("VLLM_DECODE_BLOCK_BUCKET_STEP", BLOCK_SIZE)
+max_total_tokens_aligned = math.ceil(
+max_total_tokens / BLOCK_SIZE
+) * BLOCK_SIZE + math.ceil(block_step * BLOCK_SIZE / max_num_seqs)
model_max_length = self.tokenizer.model_max_length
max_position_embeddings = getattr(
self.config, "max_position_embeddings", model_max_length
)

self.bucketing_ctx = HPUBucketingContext(
max_num_seqs,
max_num_seqs, # self.max_num_prefill_seqs, #TODO
@@ -1610,31 +1623,75 @@ class FlashCausalLM(Model):
max_input_tokens,
max_total_tokens_aligned,
)
-max_blocks = (
-max(BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE) + 1
+max_blocks = max(
+BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE
)
self.bucketing_ctx.num_hpu_blocks = min(max_blocks, num_blocks)
-if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
+synchronize(self.device)
+if self.skip_warmup:
self.bucketing_ctx.generate_prompt_buckets()
self.bucketing_ctx.generate_decode_buckets(
self.bucketing_ctx.num_hpu_blocks
)
-logger.info("skip warmup hpu graph, not recommmended")
+log_master(
+logger.info, "skip warmup hpu graph, not recommmended, may cause OOM"
+)
del _batch, batch
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens

self.warmup_hpu_graph(batch)
del _batch, batch

return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens

-def bypass_hpu_graphs(self, prefill, max_seq_len_to_capture):
-if self.limit_hpu_graph:
-return prefill
-else:
-return prefill and max_seq_len_to_capture > self.max_seq_len_to_capture
+def log_warmup(self, prefilling, i, max_i, batch_size, seq_len):
+free_mem = format_bytes(HabanaMemoryProfiler.current_free_device_memory())
+phase = "Prompt" if prefilling else "Decode"
+dim = "seq_len" if prefilling else "num_blocks"
+graphed_bucket = (batch_size, seq_len, prefilling)
+bypass = graphed_bucket not in self.graphed_buckets
+msg = (
+f"[Warmup][{phase}][{i+1}/{max_i}] "
+f"batch_size:{batch_size} "
+f"{dim}:{seq_len} "
+f"bypass:{bypass} "
+f"free_mem:{free_mem}"
+)
+log_master(logger.info, msg)
+
+def use_graphs(self, prefill, seq_len, batch_size):
+if self.limit_hpu_graph and prefill:
+return False
+
+if self.skip_warmup:
+return True
+
+return (batch_size, seq_len, prefill) in self.graphed_buckets
+
+def align_workers(self, value, op):
+if self.world_size <= 1:
+return value
+value_t = torch.tensor(value, device="cpu")
+torch.distributed.all_reduce(value_t, op=op, group=self.process_group_cpu)
+return value_t.item()

def warmup_hpu_graph(self, batch):
+prompt_graph_mem_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3"))
+free_mem = HabanaMemoryProfiler.current_free_device_memory()
+graph_free_mem = free_mem - self.mem_reserved
+graph_free_mem = self.align_workers(
+graph_free_mem, torch.distributed.ReduceOp.MIN
+)
+prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
+decode_available_memory = graph_free_mem - prompt_available_memory
+msg = (
+f"Using {format_bytes(graph_free_mem)}"
+f"/{format_bytes(free_mem)} "
+"of free device memory for HPUGraphs, "
+f"{format_bytes(prompt_available_memory)} for prompt and "
+f"{format_bytes(decode_available_memory)} for decode "
+f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})"
+)
+log_master(logger.info, msg)
start_time = time.time()
warmup_shape_count = 0
warmup_times = 3
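
For orientation, here is a stripped-down sketch of the bucket bookkeeping the new helpers above implement: VLLM_GRAPH_PROMPT_RATIO splits the graph budget between prompt and decode shapes, and use_graphs reports whether a (batch_size, seq_len, prefill) bucket was captured during warmup, which is what drives bypass_hpu_graphs in the forward passes further down. The planner class is a hypothetical stand-in, not the TGI model class.

import os

class GraphBucketPlanner:
    """Stand-in for the graph-bucket bookkeeping added to FlashCausalLM (illustrative)."""

    def __init__(self, graph_free_mem: int):
        prompt_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3"))
        self.prompt_available = prompt_ratio * graph_free_mem
        self.decode_available = graph_free_mem - self.prompt_available
        self.graphed_buckets = set()

    def use_graphs(self, prefill: bool, seq_len: int, batch_size: int) -> bool:
        # A shape runs under an HPU graph only if its bucket was warmed up.
        return (batch_size, seq_len, prefill) in self.graphed_buckets

planner = GraphBucketPlanner(graph_free_mem=10 * 1024**3)
planner.graphed_buckets.add((8, 1024, True))   # captured during prompt warmup
print(planner.use_graphs(True, 1024, 8))       # True  -> bypass_hpu_graphs=False
print(planner.use_graphs(False, 256, 8))       # False -> bypass_hpu_graphs=True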
@@ -1646,15 +1703,34 @@ class FlashCausalLM(Model):
buckets = list(
sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
)

+total_batch_seq = 0.001
+total_mem = 0
+available_mem = prompt_available_memory
for i, (batch_size, seq_len) in enumerate(buckets):
if batch_size * seq_len > self.max_batch_prefill_tokens:
continue
+# Graph memory usage is proportional to seq dimension in a batch
+batch_seq = batch_size * seq_len
+mem_estimate = batch_seq / total_batch_seq * total_mem
+graphed_bucket = (batch_size, seq_len, True)
+if not (
+mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture
+):
+if graphed_bucket not in self.graphed_buckets:
+self.graphed_buckets.add(graphed_bucket)
+warmup_shape_count += 1
-log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
-for index in range(warmup_times):
-self.warmup_prefill(seq_len, batch_size, batch)
-synchronize(self.device)
+self.log_warmup(True, i, len(buckets), batch_size, seq_len)
+with HabanaMemoryProfiler() as mem_prof:
+for index in range(warmup_times):
+self.warmup_prefill(seq_len, batch_size, batch)
+synchronize(self.device)
+used_mem = self.align_workers(
+mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
+)
+if graphed_bucket in self.graphed_buckets:
+available_mem -= used_mem
+total_mem += used_mem
+total_batch_seq += batch_seq

def ordering_function_max_bs(b):
return (-b[0], b[1])
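
The prompt-warmup loop above decides whether to capture a graph for a bucket from a running estimate: graph memory is assumed to grow with batch_size * seq_len, so the predicted cost is batch_seq / total_batch_seq * total_mem over what has been measured so far, and buckets whose estimate exceeds the remaining budget (or whose batch_seq exceeds max_seq_len_to_capture) are not graphed. A small illustration of that heuristic, with made-up buckets and measurements:

# Illustrative only: fake buckets and fake measured memory per capture.
buckets = [(1, 256), (2, 256), (4, 512), (8, 1024)]
measured = {(1, 256): 1.0, (2, 256): 1.9, (4, 512): 7.5, (8, 1024): 30.0}  # GiB

available_mem = 12.0   # graph budget in GiB (made up)
total_batch_seq = 0.001
total_mem = 0.0
captured = []

for batch_size, seq_len in buckets:
    batch_seq = batch_size * seq_len
    # Predict cost from the average cost-per-token of buckets captured so far.
    mem_estimate = batch_seq / total_batch_seq * total_mem
    if mem_estimate < available_mem:
        captured.append((batch_size, seq_len))
        used = measured[(batch_size, seq_len)]
        available_mem -= used
        total_mem += used
        total_batch_seq += batch_seq

print(captured)  # buckets whose graphs would be captured under this budget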
@@ -1663,16 +1739,34 @@ class FlashCausalLM(Model):
buckets = list(
sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
)
+free_mem = HabanaMemoryProfiler.current_free_device_memory()
+total_batch_seq = 0.001
+total_mem = 0
+available_mem = free_mem - self.mem_reserved
for i, (batch_size, block_num) in enumerate(buckets):
if batch_size > block_num:
continue
+# Graph memory usage is proportional to seq dimension in a batch
+batch_seq = batch_size
+mem_estimate = batch_seq / total_batch_seq * total_mem
+graphed_bucket = (batch_size, block_num, False)
+if not mem_estimate >= available_mem:
+if graphed_bucket not in self.graphed_buckets:
+self.graphed_buckets.add(graphed_bucket)
+warmup_shape_count += 1
-log_master(
-logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
+self.log_warmup(False, i, len(buckets), batch_size, block_num)
+with HabanaMemoryProfiler() as mem_prof:
+for index in range(warmup_times):
+self.warmup_decode(batch_size, block_num, batch)
+synchronize(self.device)
+used_mem = self.align_workers(
+mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
)
-for index in range(warmup_times):
-self.warmup_decode(batch_size, block_num, batch)
-synchronize(self.device)
+if graphed_bucket in self.graphed_buckets:
+available_mem -= used_mem
+total_mem += used_mem
+total_batch_seq += batch_seq

log_master(
logger.info,
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
@@ -1707,8 +1801,8 @@ class FlashCausalLM(Model):
lm_head_indices = input_lengths - 1
kwargs = {}
if htorch.utils.internal.is_lazy():
-kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
-True, input_ids.shape[0]
+kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+True, prompt_len, batch_size
)

# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
@@ -1762,7 +1856,9 @@ class FlashCausalLM(Model):
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
kwargs = {}
if htorch.utils.internal.is_lazy():
-kwargs["bypass_hpu_graphs"] = False
+kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+False, hpu_attention_meta.block_list.shape[0], batch_size
+)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids),
@@ -1858,8 +1954,14 @@ class FlashCausalLM(Model):

kwargs = {}
if htorch.utils.internal.is_lazy():
-kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
-batch.prefilling, input_ids.shape[0]
+batch_size = input_lengths.shape[0]
+prompt_len = (
+input_ids.shape[0] // batch_size
+if batch.prefilling
+else batch.hpu_attn_meta.block_list.shape[0]
+)
+kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+batch.prefilling, prompt_len, batch_size
)

logits, speculative_logits = self.model.forward(
@@ -27,6 +27,7 @@ import time
from text_generation_server.utils.import_utils import (
synchronize,
)
+from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes

tracer = trace.get_tracer(__name__)

@@ -487,6 +488,19 @@ class FlashVlmCausalLM(FlashCausalLM):
)

def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):
+free_mem = HabanaMemoryProfiler.current_free_device_memory()
+graph_free_mem = free_mem - self.mem_reserved
+graph_free_mem = self.align_workers(
+graph_free_mem, torch.distributed.ReduceOp.MIN
+)
+decode_available_memory = graph_free_mem
+msg = (
+f"Using {format_bytes(graph_free_mem)}"
+f"/{format_bytes(free_mem)} "
+"of free device memory for HPUGraphs, "
+f"{format_bytes(decode_available_memory)} for decode "
+)
+log_master(logger.info, msg)
start_time = time.time()
warmup_shape_count = 0
warmup_times = 3
@@ -499,16 +513,34 @@ class FlashVlmCausalLM(FlashCausalLM):
buckets = list(
sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
)
+total_batch_seq = 0.001
+total_mem = 0
+available_mem = decode_available_memory
for i, (batch_size, block_num) in enumerate(buckets):
if batch_size > block_num:
continue
+# Graph memory usage is proportional to seq dimension in a batch
+batch_seq = batch_size
+mem_estimate = batch_seq / total_batch_seq * total_mem
+graphed_bucket = (batch_size, block_num, False)
+if not mem_estimate >= available_mem:
+if graphed_bucket not in self.graphed_buckets:
+self.graphed_buckets.add(graphed_bucket)
+warmup_shape_count += 1
-log_master(
-logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
+self.log_warmup(False, i, len(buckets), batch_size, block_num)
+with HabanaMemoryProfiler() as mem_prof:
+for index in range(warmup_times):
+self.warmup_decode(batch_size, block_num, batch)
+synchronize(self.device)
+used_mem = self.align_workers(
+mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
)
-for index in range(warmup_times):
-self.warmup_decode(batch_size, block_num, batch)
-synchronize(self.device)
+if graphed_bucket in self.graphed_buckets:
+
+available_mem -= used_mem
+total_mem += used_mem
+total_batch_seq += batch_seq

log_master(
logger.info,
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
@@ -585,8 +617,15 @@ class FlashVlmCausalLM(FlashCausalLM):

kwargs = {}
if htorch.utils.internal.is_lazy():
-kwargs["bypass_hpu_graphs"] = batch.prefilling
-
+batch_size = input_lengths.shape[0]
+seqlen = (
+input_ids.shape[0] // batch_size
+if batch.prefilling
+else batch.hpu_attn_meta.block_list.shape[0]
+)
+kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+batch.prefilling, seqlen, batch_size
+)
if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids)
slots_pad[batch.prefill_cache_indices] = slots
@@ -33,6 +33,8 @@ from text_generation_server.utils.import_utils import (
import torch.nn.functional as F
from text_generation_server.utils.log import log_master
+import time
+import os
from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes

tracer = trace.get_tracer(__name__)

@@ -268,6 +270,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
cross_attention_states, image_indices, input_lengths, 1, False
)
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
+kwargs = {}
+if htorch.utils.internal.is_lazy():
+kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+False, hpu_attention_meta.block_list.shape[0], batch_size
+)
self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids),
position_ids=_async_h2d_tensor_copy(position_ids),
@@ -281,6 +288,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
cross_attention_states=cross_attention_states,
indices=_async_h2d_tensor_copy(indices),
cross_attention_len=_async_h2d_tensor_copy(cross_attention_len),
+**kwargs,
)

def warmup_prefill(
@@ -326,8 +334,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
)
kwargs = {}
if htorch.utils.internal.is_lazy():
-kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
-True, input_ids.shape[0]
+kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+True, prompt_len, batch_size
)
self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids),
@@ -346,6 +354,23 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
)

def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch):
+prompt_graph_mem_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3"))
+free_mem = HabanaMemoryProfiler.current_free_device_memory()
+graph_free_mem = free_mem - self.mem_reserved
+graph_free_mem = self.align_workers(
+graph_free_mem, torch.distributed.ReduceOp.MIN
+)
+prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
+decode_available_memory = graph_free_mem - prompt_available_memory
+msg = (
+f"Using {format_bytes(graph_free_mem)}"
+f"/{format_bytes(free_mem)} "
+"of free device memory for HPUGraphs, "
+f"{format_bytes(prompt_available_memory)} for prompt and "
+f"{format_bytes(decode_available_memory)} for decode "
+f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})"
+)
+log_master(logger.info, msg)
start_time = time.time()
warmup_shape_count = 0
warmup_times = 3
@@ -357,14 +382,35 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
buckets = list(
sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
)
graph_free_mem
+total_batch_seq = 0.001
+total_mem = 0
+available_mem = prompt_available_memory
for i, (batch_size, seq_len) in enumerate(buckets):
if batch_size * seq_len > self.max_batch_prefill_tokens:
continue
+# Graph memory usage is proportional to seq dimension in a batch
+batch_seq = batch_size * seq_len
+mem_estimate = batch_seq / total_batch_seq * total_mem
+graphed_bucket = (batch_size, seq_len, True)
+if not (
+mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture
+):
+if graphed_bucket not in self.graphed_buckets:
+self.graphed_buckets.add(graphed_bucket)
+warmup_shape_count += 1
-log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
-for index in range(warmup_times):
-self.warmup_prefill(seq_len, batch_size, batch)
-synchronize(self.device)
+self.log_warmup(True, i, len(buckets), batch_size, seq_len)
+with HabanaMemoryProfiler() as mem_prof:
+for index in range(warmup_times):
+self.warmup_prefill(seq_len, batch_size, batch)
+synchronize(self.device)
+used_mem = self.align_workers(
+mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
+)
+if graphed_bucket in self.graphed_buckets:
+available_mem -= used_mem
+total_mem += used_mem
+total_batch_seq += batch_seq

def ordering_function_max_bs(b):
return (-b[0], b[1])
@@ -373,16 +419,34 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
buckets = list(
sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
)
+free_mem = HabanaMemoryProfiler.current_free_device_memory()
+total_batch_seq = 0.001
+total_mem = 0
+available_mem = free_mem - self.mem_reserved
for i, (batch_size, block_num) in enumerate(buckets):
if batch_size > block_num:
continue
+# Graph memory usage is proportional to seq dimension in a batch
+batch_seq = batch_size
+mem_estimate = batch_seq / total_batch_seq * total_mem
+graphed_bucket = (batch_size, block_num, False)
+if not mem_estimate >= available_mem:
+if graphed_bucket not in self.graphed_buckets:
+self.graphed_buckets.add(graphed_bucket)
+warmup_shape_count += 1
-log_master(
-logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
+self.log_warmup(False, i, len(buckets), batch_size, block_num)
+with HabanaMemoryProfiler() as mem_prof:
+for index in range(warmup_times):
+self.warmup_decode(batch_size, block_num, batch)
+synchronize(self.device)
+used_mem = self.align_workers(
+mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
)
-for index in range(warmup_times):
-self.warmup_decode(batch_size, block_num, batch)
-synchronize(self.device)
+if graphed_bucket in self.graphed_buckets:
+available_mem -= used_mem
+total_mem += used_mem
+total_batch_seq += batch_seq

log_master(
logger.info,
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
@@ -462,9 +526,16 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):

kwargs = {}
if htorch.utils.internal.is_lazy():
-kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
-batch.prefilling, input_ids.shape[0]
+batch_size = input_lengths.shape[0]
+seqlen = (
+input_ids.shape[0] // batch_size
+if batch.prefilling
+else batch.hpu_attn_meta.block_list.shape[0]
+)
+kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+batch.prefilling, seqlen, batch_size
)
+
if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids)
slots_pad[batch.prefill_cache_indices] = slots
@@ -7,7 +7,7 @@ from loguru import logger
# Tensor Parallelism settings
RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
-MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.9"))
+MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.8"))


class FakeBarrier:
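
Since the warmup now holds back free_memory * (1 - MEMORY_FRACTION) as a runtime reserve, lowering the default HPU_MEMORY_FRACTION from 0.9 to 0.8 roughly doubles what is withheld from the KV-cache and graph budgets. A quick illustrative check; the ~94 GiB free-memory figure is made up.

def reserved_bytes(free_memory: int, memory_fraction: float) -> int:
    # Mirrors: self.mem_reserved = int(free_memory * (1 - MEMORY_FRACTION))
    return int(free_memory * (1 - memory_fraction))

free = 94 * 1024**3  # ~94 GiB reported free (illustrative)
print(reserved_bytes(free, 0.9) / 1024**3)  # ~9.4 GiB reserved before
print(reserved_bytes(free, 0.8) / 1024**3)  # ~18.8 GiB reserved after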
@@ -1,20 +1,9 @@
import torch
-from loguru import logger
-import habana_frameworks.torch as htorch
-import os


def get_hpu_free_memory(device, memory_fraction):
-graph_reserved_mem = (
-float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1"))
-if htorch.utils.internal.is_lazy()
-else 0
-)
-free_memory = int(
-torch.hpu.mem_get_info()[0] * memory_fraction * (1 - graph_reserved_mem)
-)
-logger.info(f"Free memory on device {device}: {free_memory} bytes.")
-return free_memory
+free_hpu_memory, _ = torch.hpu.mem_get_info()
+return free_hpu_memory


def synchronize_hpu(device):