update to latest vllm extension ops

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Author: Wang, Yi A <yi.a.wang@intel.com>
Date:   2025-05-08 23:43:41 -07:00
parent 316cb087f3
commit 249ccfc939
3 changed files with 25 additions and 19 deletions


@@ -61,6 +61,7 @@ ENV ATTENTION=default
ENV PREFIX_CACHING=0
ENV PREFILL_CHUNKING=0
ENV PT_HPU_LAZY_MODE=1
ENV PT_HPU_WEIGHT_SHARING=0
# Text Generation Inference base env
ENV HF_HOME=/data \
@@ -96,7 +97,8 @@ RUN cd server && \
pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
pip install . --no-cache-dir
RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git
RUN pip install git+https://github.com/HabanaAI/vllm-hpu-extension.git@a060794
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
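
The extension now comes from the upstream HabanaAI repository pinned to a060794 instead of the sywangyi fork. As a quick sanity check of the built image, a small script like the sketch below could confirm the pinned revision exposes the bucketing entry point the server code expects (the script and its name are illustrative, not part of this commit):

# sanity_check_vllm_hpu_extension.py (hypothetical helper, not added by this commit).
# Verifies that the pinned vllm-hpu-extension exposes the bucketing factory that
# the server code below imports from vllm_hpu_extension.bucketing.common.
from importlib import import_module


def main() -> None:
    common = import_module("vllm_hpu_extension.bucketing.common")
    if not hasattr(common, "get_bucketing_context"):
        raise SystemExit(
            "vllm-hpu-extension is too old: expected "
            "vllm_hpu_extension.bucketing.common.get_bucketing_context"
        )
    print("vllm-hpu-extension bucketing API looks compatible")


if __name__ == "__main__":
    main()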


@@ -72,7 +72,7 @@ from text_generation_server.utils.import_utils import (
import vllm_hpu_extension.environment as environment
import habana_frameworks.torch as htorch
import itertools
from vllm_hpu_extension.bucketing import HPUBucketingContext
from vllm_hpu_extension.bucketing.common import get_bucketing_context
tracer = trace.get_tracer(__name__)
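
The diff swaps the direct class import for a factory call. A minimal sketch of the new pattern, assuming nothing beyond the two import lines above:

# Old pattern (removed line): import the class directly.
# from vllm_hpu_extension.bucketing import HPUBucketingContext
#
# New pattern: obtain the concrete bucketing context class at runtime.
from vllm_hpu_extension.bucketing.common import get_bucketing_context

HPUBucketingContext = get_bucketing_context()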
@@ -1497,6 +1497,11 @@ class FlashCausalLM(Model):
        max_input_tokens: Optional[int],
        max_total_tokens: Optional[int],
    ):
        if os.environ.get("MAX_BATCH_SIZE") is None:
            raise RuntimeError(
                "MAX_BATCH_SIZE is not set, it should be set in the launcher "
                "using `--max-batch-size xxx`"
            )
        # The warmup batch is the biggest batch we could ever receive
        self.kv_cache = []
        empty_cache()
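
The new guard pairs with the change further down, where max_num_seqs is read from MAX_BATCH_SIZE without a fallback default, so a missing value would otherwise only surface as an int(None) crash. A small illustration of the intended failure mode (values and flow are made up for the example):

# Illustration only, not part of the diff: the guard turns a missing
# MAX_BATCH_SIZE into an actionable error instead of a later int(None) TypeError.
import os

os.environ.pop("MAX_BATCH_SIZE", None)  # simulate a launcher that omitted --max-batch-size
try:
    if os.environ.get("MAX_BATCH_SIZE") is None:
        raise RuntimeError(
            "MAX_BATCH_SIZE is not set, it should be set in the launcher "
            "using `--max-batch-size xxx`"
        )
    max_num_seqs = int(os.getenv("MAX_BATCH_SIZE"))  # this is what the guard protects
except RuntimeError as err:
    print(f"fails fast: {err}")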
@@ -1566,25 +1571,28 @@
            self.device,
        )
        self.max_batch_prefill_tokens = max_input_tokens * len(batch)
        max_num_seqs = int(os.getenv("MAX_BATCH_SIZE", 128))
        if os.getenv("VLLM_PROMPT_SEQ_BUCKET_MAX") is None:
            os.environ["VLLM_PROMPT_SEQ_BUCKET_MAX"] = str(max_input_tokens)
        if os.getenv("VLLM_DECODE_BLOCK_BUCKET_MAX") is None:
            max_total_blocks = (
                math.ceil(max_total_tokens / BLOCK_SIZE) * max_num_seqs + 1
            )
            os.environ["VLLM_DECODE_BLOCK_BUCKET_MAX"] = str(max_total_blocks)
        max_num_seqs = int(os.getenv("MAX_BATCH_SIZE"))
        HPUBucketingContext = get_bucketing_context()
        max_total_tokens_aligned = math.ceil(max_total_tokens / BLOCK_SIZE) * BLOCK_SIZE
        self.bucketing_ctx = HPUBucketingContext(
            max_num_seqs,
            os.getenv("PREFILL_MAX_BS", 64), # self.max_num_prefill_seqs, #TODO
            max_num_seqs, # self.max_num_prefill_seqs, #TODO
            BLOCK_SIZE,
            num_blocks * BLOCK_SIZE,
            max_num_seqs * max_total_tokens_aligned,
            False,
            self.tokenizer.model_max_length,
            max_input_tokens,
            max_total_tokens_aligned,
        )
        self.bucketing_ctx.num_hpu_blocks = num_blocks
        max_blocks = max(
            BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE
        )
        self.bucketing_ctx.num_hpu_blocks = max_blocks
        if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
            self.bucketing_ctx.generate_prompt_buckets()
            self.bucketing_ctx.generate_decode_buckets(
                self.bucketing_ctx.num_hpu_blocks
            )
            logger.info("skip warmup hpu graph, not recommended")
        del _batch, batch
        return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
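
The decode capacity is no longer communicated through VLLM_PROMPT_SEQ_BUCKET_MAX / VLLM_DECODE_BLOCK_BUCKET_MAX; the context is now sized directly from MAX_BATCH_SIZE and a block-aligned max_total_tokens. A worked example of the new arithmetic follows; BLOCK_SIZE = 128 and every other concrete number are assumptions for illustration, not values taken from this diff:

# Worked example of the new bucketing-context sizing (all numbers assumed).
import math

BLOCK_SIZE = 128          # assumed; the real constant comes from the server
max_num_seqs = 32         # from MAX_BATCH_SIZE
max_total_tokens = 4097   # deliberately not a multiple of BLOCK_SIZE

# round max_total_tokens up to a whole number of blocks
max_total_tokens_aligned = math.ceil(max_total_tokens / BLOCK_SIZE) * BLOCK_SIZE
assert max_total_tokens_aligned == 33 * BLOCK_SIZE  # 4224

# value now passed to the context where `num_blocks * BLOCK_SIZE` used to be
batched_token_budget = max_num_seqs * max_total_tokens_aligned  # 135168

# enough KV-cache blocks for every sequence to reach max_total_tokens_aligned
max_blocks = max(BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE)
assert max_blocks == 32 * 33  # 1056 blocks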
@@ -1606,8 +1614,6 @@
        for i, (batch_size, seq_len) in enumerate(
            reversed(self.bucketing_ctx.prompt_buckets)
        ):
            if batch_size * seq_len > self.max_batch_prefill_tokens:
                continue
            log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
            for index in range(warmup_times):
                self.warmup_prefill(seq_len, batch_size, batch)
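
With the skip condition removed, every prompt bucket is warmed, not only those fitting under max_batch_prefill_tokens. A compact sketch of the resulting loop shape (the bucket list and counts are made up):

# Sketch only: prompt warmup now visits every (batch_size, seq_len) bucket.
prompt_buckets = [(1, 128), (2, 256), (4, 1024)]  # hypothetical buckets
warmup_times = 3

for batch_size, seq_len in reversed(prompt_buckets):
    # previously, buckets with batch_size * seq_len > max_batch_prefill_tokens were skipped
    print(f"warmup prefill seq {seq_len} bs {batch_size}")
    for _ in range(warmup_times):
        pass  # self.warmup_prefill(seq_len, batch_size, batch) in the real code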


@@ -350,8 +350,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
        for i, (batch_size, seq_len) in enumerate(
            reversed(self.bucketing_ctx.prompt_buckets)
        ):
            if batch_size * seq_len > self.max_batch_prefill_tokens:
                continue
            log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
            for index in range(warmup_times):
                self.warmup_prefill(seq_len, batch_size, batch)