Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 03:14:53 +00:00)
refine free memory and bypass graph logic
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in: parent 1cda91135e, commit ff5bc1bbd1
@@ -1445,9 +1445,10 @@ class FlashCausalLM(Model):
         self.use_contiguous_pa = (
             os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true"
         )
-        self.limit_hpu_graphs = (
-            os.environ.get("LIMIT_HPU_GRAPHS", "false").lower() == "true"
+        self.limit_hpu_graph = (
+            os.environ.get("LIMIT_HPU_GRAPH", "true").lower() == "true"
         )
+        self.max_seq_len_to_capture = 8192
         super().__init__(
             model_id=model_id,
             model=model,
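The hunk above renames the flag from LIMIT_HPU_GRAPHS (default "false") to LIMIT_HPU_GRAPH (default "true") and introduces an 8192-token capture threshold. A minimal standalone sketch of how these settings resolve, with defaults taken from the diff (the helper name read_bool_env is illustrative, not part of the codebase):

import os

def read_bool_env(name: str, default: str) -> bool:
    # "true" (any casing) enables the flag; every other value disables it
    return os.environ.get(name, default).lower() == "true"

limit_hpu_graph = read_bool_env("LIMIT_HPU_GRAPH", "true")       # was LIMIT_HPU_GRAPHS, default "false"
use_contiguous_pa = read_bool_env("VLLM_CONTIGUOUS_PA", "true")
max_seq_len_to_capture = 8192  # prefills longer than this will bypass HPU graph capture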
@@ -1564,6 +1565,7 @@ class FlashCausalLM(Model):
             self.kv_cache_dtype,
             self.device,
         )
+        self.max_batch_prefill_tokens = max_input_tokens * len(batch)
 
         max_num_seqs = int(os.getenv("MAX_BATCH_SIZE", 128))
         if os.getenv("VLLM_PROMPT_SEQ_BUCKET_MAX") is None:
@@ -1592,15 +1594,23 @@ class FlashCausalLM(Model):
 
         return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
 
+    def bypass_hpu_graphs(self, prefill, max_seq_len_to_capture):
+        if self.limit_hpu_graph:
+            return prefill and max_seq_len_to_capture > self.max_seq_len_to_capture
+        return False
+
     def warmup_hpu_graph(self, batch):
         warmup_times = 3
         self.bucketing_ctx.generate_prompt_buckets()
         for i, (batch_size, seq_len) in enumerate(
             reversed(self.bucketing_ctx.prompt_buckets)
         ):
+            if batch_size * seq_len > self.max_batch_prefill_tokens:
+                continue
             log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
             for index in range(warmup_times):
                 self.warmup_prefill(seq_len, batch_size, batch)
+            synchronize(self.device)
 
         self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
         for i, (batch_size, block_num) in enumerate(
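For reference, a minimal sketch of the bucket filter introduced above: prompt buckets whose token volume (batch_size * seq_len) exceeds the prefill budget are skipped during warmup. The bucket values and budget below are illustrative, not taken from the commit:

# Hypothetical prompt buckets (batch_size, seq_len) from the bucketing context.
prompt_buckets = [(1, 1024), (4, 2048), (8, 4096), (16, 8192)]
max_batch_prefill_tokens = 16384  # e.g. max_input_tokens * len(batch)

for batch_size, seq_len in reversed(prompt_buckets):
    if batch_size * seq_len > max_batch_prefill_tokens:
        # Bucket is larger than any prefill the server will actually run; skip its warmup.
        continue
    print(f"warmup prefill seq {seq_len} bs {batch_size}")
# Only (4, 2048) and (1, 1024) are warmed up; (8, 4096) and (16, 8192) are skipped.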
@@ -1644,7 +1654,9 @@ class FlashCausalLM(Model):
         lm_head_indices = input_lengths - 1
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs
+            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
+                True, input_ids.shape[0]
+            )
 
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
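The call sites above all delegate to the new bypass_hpu_graphs helper, so the decision of whether a forward pass should skip HPU graph capture lives in one place. A standalone sketch of that decision, assuming the names and the 8192 threshold from the diff:

MAX_SEQ_LEN_TO_CAPTURE = 8192  # mirrors self.max_seq_len_to_capture from the diff

def bypass_hpu_graphs(prefill: bool, seq_len: int, limit_hpu_graph: bool = True) -> bool:
    # Only a prefill longer than the capture threshold runs eagerly;
    # decode steps and short prefills are left to HPU graph execution.
    # With limit_hpu_graph disabled, nothing is ever bypassed.
    if limit_hpu_graph:
        return prefill and seq_len > MAX_SEQ_LEN_TO_CAPTURE
    return False

assert bypass_hpu_graphs(prefill=True, seq_len=16384)        # long prefill: bypass graphs
assert not bypass_hpu_graphs(prefill=True, seq_len=2048)     # short prefill: keep graphs
assert not bypass_hpu_graphs(prefill=False, seq_len=16384)   # decode: keep graphs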
@@ -1793,8 +1805,8 @@ class FlashCausalLM(Model):
 
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = (
-                batch.prefilling if self.limit_hpu_graphs else False
+            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
+                batch.prefilling, input_ids.shape[0]
             )
 
         logits, speculative_logits = self.model.forward(
@@ -325,7 +325,9 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         )
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs
+            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
+                True, input_ids.shape[0]
+            )
         self.model.forward(
             input_ids=_async_h2d_tensor_copy(input_ids),
             position_ids=_async_h2d_tensor_copy(position_ids),
@@ -348,9 +350,12 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         for i, (batch_size, seq_len) in enumerate(
             reversed(self.bucketing_ctx.prompt_buckets)
         ):
+            if batch_size * seq_len > self.max_batch_prefill_tokens:
+                continue
             log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
             for index in range(warmup_times):
                 self.warmup_prefill(seq_len, batch_size, batch)
+            synchronize(self.device)
         self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
         for i, (batch_size, block_num) in enumerate(
             reversed(self.bucketing_ctx.decode_buckets)
@@ -438,8 +443,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
 
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = (
-                batch.prefilling if self.limit_hpu_graphs else False
+            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
+                batch.prefilling, input_ids.shape[0]
             )
         if batch.prefill_cache_indices is not None:
             slots_pad = torch.zeros_like(input_ids)
@@ -7,7 +7,7 @@ from loguru import logger
 # Tensor Parallelism settings
 RANK = int(os.getenv("RANK", "0"))
 WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
-MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.8"))
+MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.9"))
 
 
 class FakeBarrier:
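The default HPU memory fraction moves from 0.8 to 0.9; deployments that want more headroom can still override it through the environment. A small sketch of the resolution order (the 0.85 value is only an example override):

import os

# Environment variable wins; otherwise fall back to the new 0.9 default.
MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.9"))

# Example override when launching the server:
#   HPU_MEMORY_FRACTION=0.85 <server launch command>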
@@ -1,17 +1,19 @@
 import torch
 from loguru import logger
+import habana_frameworks.torch as htorch
+import os
 
 
 def get_hpu_free_memory(device, memory_fraction):
-    from habana_frameworks.torch.hpu import memory_stats
-
-    device_id = device.index
-    mem_stats = memory_stats(device_id)
-    logger.info(f"mem_stats: {mem_stats}")
-    total_free_memory = mem_stats["Limit"] - mem_stats["MaxInUse"]
-    free_memory = max(
-        0, int(total_free_memory - (1 - memory_fraction) * mem_stats["Limit"])
+    graph_reserved_mem = (
+        float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1"))
+        if htorch.utils.internal.is_lazy()
+        else 0
+    )
+    free_memory = int(
+        torch.hpu.mem_get_info()[0] * memory_fraction * (1 - graph_reserved_mem)
     )
     logger.info(f"Free memory on device {device}: {free_memory} bytes, ")
     return free_memory
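The rewritten helper budgets from torch.hpu.mem_get_info() (free device memory) instead of the Limit/MaxInUse counters, and in lazy mode it reserves a TGI_GRAPH_RESERVED_MEM slice (default 0.1) for HPU graphs. A worked example with assumed numbers, using the defaults from this diff and a device reporting 96 GiB free:

GiB = 1024**3

reported_free = 96 * GiB   # assumed result of torch.hpu.mem_get_info()[0]
memory_fraction = 0.9      # HPU_MEMORY_FRACTION default after this commit
graph_reserved_mem = 0.1   # TGI_GRAPH_RESERVED_MEM default (lazy mode only)

free_memory = int(reported_free * memory_fraction * (1 - graph_reserved_mem))
print(free_memory / GiB)   # ~77.8 GiB reported back as usable free memory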