Deepseek R1 for Gaudi backend (#3211)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi 2025-05-19 22:36:39 +08:00 committed by GitHub
parent 58934c8b61
commit d658b5def3
41 changed files with 1133 additions and 238 deletions

View File

@@ -60,6 +60,8 @@ FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytor
ENV ATTENTION=default
ENV PREFIX_CACHING=0
ENV PREFILL_CHUNKING=0
ENV PT_HPU_LAZY_MODE=1
ENV PT_HPU_WEIGHT_SHARING=0
# Text Generation Inference base env
ENV HF_HOME=/data \
@@ -95,7 +97,8 @@ RUN cd server && \
pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
pip install . --no-cache-dir
- RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git
RUN pip install git+https://github.com/HabanaAI/vllm-hpu-extension.git@a060794
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router

View File

@@ -26,6 +26,11 @@ class Dtype(str, Enum):
bloat16 = "bfloat16"
class KVCacheDtype(str, Enum):
fp8_e4m3fn = "fp8_e4m3fn"
fp8_e5m2 = "fp8_e5m2"
@app.command()
def serve(
model_id: str,
@@ -34,6 +39,7 @@ def serve(
quantize: Optional[Quantization] = None,
speculate: Optional[int] = None,
dtype: Optional[Dtype] = None,
kv_cache_dtype: Optional[KVCacheDtype] = None,
trust_remote_code: bool = False,
uds_path: Path = "/tmp/text-generation-server",
logger_level: str = "INFO",
@@ -93,7 +99,8 @@ def serve(
# Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value
dtype = "bfloat16" if dtype is None else dtype.value
- logger.info(f"quantize={quantize}")
kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value
logger.info(f"quantize={quantize} kv_cache_dtype={kv_cache_dtype}")
if dtype is not None and quantize not in {
None,
"bitsandbytes",
@@ -175,6 +182,7 @@ def serve(
quantize,
speculate,
dtype,
kv_cache_dtype,
trust_remote_code,
uds_path,
max_input_tokens,

View File

@@ -12,6 +12,7 @@ from text_generation_server.layers.speculative import SpeculativeHead
# Just to add the `load` methods.
from text_generation_server.layers.layernorm import load_layer_norm
from text_generation_server.layers.conv import load_conv2d
from text_generation_server.layers.fp8 import Fp8Linear
from text_generation_server.layers.lora import (
LoraLinear,
@@ -27,6 +28,7 @@ __all__ = [
"TensorParallelEmbedding",
"SpeculativeHead",
"LoraLinear",
"Fp8Linear",
"TensorParallelMultiAdapterLinear",
"TensorParallelAdapterRowLinear",
"load_layer_norm",

View File

@@ -10,18 +10,21 @@ from .hpu import (
SUPPORTS_WINDOWING,
attention,
paged_attention,
paged_attention_mla,
)
# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
- from .kv_cache import KVCache, get_kv_scales
from .kv_cache import KVCache, get_kv_scales, KVCompressCache
__all__ = [
"attention",
"get_kv_scales",
"paged_attention",
"paged_attention_mla",
"SUPPORTS_WINDOWING",
"KVCache",
"KVCompressCache",
"Seqlen",
"HPUPagedAttentionMetadata",
"trim_seqlen_metadata",

View File

@@ -11,11 +11,61 @@ import os
SUPPORTS_WINDOWING = False
- def fetch_from_cache(cache, blocks):
- if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
- return cache[: blocks.size(0)]
- else:
- return cache.index_select(0, blocks)
class FP8Matmul(torch.nn.Module):
def __init__(self, scale_other):
super().__init__()
self.scale_input = torch.tensor(1.0, dtype=torch.bfloat16, device="hpu")
self.scale_other = scale_other
def quant_input(self, x, scale):
return torch.ops.hpu.cast_to_fp8_v2(
x, scale, False, False, torch.float8_e4m3fn
)[0]
def matmul_fp8(
self, x, other, out_dtype, scale_input_inv=None, scale_other_inv=None
):
return torch.ops.hpu.fp8_gemm_v2(
A=x,
trans_A=False,
B=other,
trans_B=False,
D=None,
out_dtype=out_dtype,
A_scale_inv=scale_input_inv,
B_scale_inv=scale_other_inv,
bias=None,
accumulate=False,
)
def forward(self, input, other):
qinput = self.quant_input(input, self.scale_input)
qother = self.quant_input(other, self.scale_other)
output = self.matmul_fp8(
qinput,
qother,
out_dtype=torch.bfloat16,
scale_input_inv=1.0 / self.scale_input,
scale_other_inv=1.0 / self.scale_other,
)
return output
class FetchFromCache(torch.nn.Module):
def __init__(self, scale_inv):
super().__init__()
self.scale_inv = scale_inv
def forward(self, cache, blocks):
if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
out = cache[: blocks.size(0)]
else:
out = cache.index_select(0, blocks)
if out.dtype == torch.float8_e4m3fn:
out = torch.ops.hpu.cast_from_fp8(out, self.scale_inv, torch.bfloat16)
return out
def attention(
@@ -67,6 +117,7 @@ def paged_attention(
hpu_attention_meta: HPUPagedAttentionMetadata,
):
batch_size, head_num, head_size = query.shape
fp8_kv = kv_cache.dtype == torch.float8_e4m3fn
output = ops.flat_pa(
query=query.view(batch_size, 1, head_num * head_size),
key_cache=kv_cache.key,
@@ -76,19 +127,50 @@ def paged_attention(
block_bias=hpu_attention_meta.attn_bias,
block_groups=hpu_attention_meta.block_groups,
scale=softmax_scale,
- matmul_qk_op=Matmul(),
matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
- matmul_av_op=Matmul(),
matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
batch2block_matmul_op=Matmul(),
block2batch_matmul_op=Matmul(),
- keys_fetch_func=fetch_from_cache,
keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
- values_fetch_func=fetch_from_cache,
values_fetch_func=FetchFromCache(1.0 / kv_scales.value_scale_cpu),
)
# Reshape the output tensor.
return output.view(batch_size, head_num, head_size)
- __all__ = [
- "SUPPORTS_WINDOWING",
- "attention",
- "paged_attention",
- ]
def paged_attention_mla(
query: torch.Tensor,
kv_cache: KVCache,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
seqlen: Seqlen,
*,
kv_scales: KVScales,
softcap: Optional[float] = None,
hpu_attention_meta: HPUPagedAttentionMetadata,
kv_lora_rank: int = 0,
):
batch_size, head_num, head_size = query.shape
fp8_kv = kv_cache.dtype == torch.float8_e4m3fn
output = ops.flat_pa_mla(
query=query,
key_cache=kv_cache.key,
value_cache=None,
block_list=hpu_attention_meta.block_list,
block_mapping=hpu_attention_meta.block_mapping,
block_bias=hpu_attention_meta.attn_bias,
block_groups=hpu_attention_meta.block_groups,
scale=softmax_scale,
matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
batch2block_matmul_op=Matmul(),
block2batch_matmul_op=Matmul(),
keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
values_fetch_func=None,
kv_lora_rank=kv_lora_rank,
)
# Reshape the output tensor.
return output.view(batch_size, head_num, -1)
__all__ = ["SUPPORTS_WINDOWING", "attention", "paged_attention", "paged_attention_mla"]

View File

@@ -50,6 +50,8 @@ class KVCache:
):
"""Construct the key-value cache for a layer."""
## TODO FP8 kv cache support
if dtype is torch.float8_e5m2:
raise ValueError("torch.float8_e5m2 is not supported in hpu. ")
self.kv_cache = (
torch.zeros(
@@ -101,22 +103,92 @@ class KVCache:
key_cache,
value_cache,
slots,
- kv_scales.key_scale_cpu,
kv_scales.key_scale,
- kv_scales.value_scale_cpu,
kv_scales.value_scale,
)
class KVCompressCache(KVCache):
"""
Key-value cache for attention layers.
"""
kv_cache: torch.Tensor
def __init__(
self,
*,
num_blocks: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
):
"""Construct the key-value cache for a layer."""
## TODO FP8 kv cache support
if dtype is torch.float8_e5m2:
raise ValueError("torch.float8_e5m2 is not supported in hpu. ")
self.kv_cache = torch.zeros(
(num_blocks, BLOCK_SIZE, 1, head_size),
dtype=dtype,
device=device,
)
@property
def dtype(self):
"""Get the data type of the cache."""
return self.kv_cache.dtype
@property
def key(self):
"""Get the key cache."""
return self.kv_cache
@property
def value(self):
"""Get the value cache."""
return self.kv_cache
def store(
self,
*,
key: torch.Tensor,
value: torch.Tensor,
slots: torch.Tensor,
kv_scales: KVScales,
):
"""Store the key and value at the given slots."""
## TODO FP8 kv cache support
block_idx = slots // BLOCK_SIZE
block_offset = slots % BLOCK_SIZE
if self.kv_cache.dtype == torch.float8_e4m3fn:
key = torch.ops.hpu.cast_to_fp8_v2(
key, kv_scales.key_scale, False, False, torch.float8_e4m3fn
)[0]
cache_ops.insert_or_update_cache(key, self.kv_cache, block_idx, block_offset)
def paged_reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
- k_scale: float = 1.0,
k_scale: torch.Tensor,
- v_scale: float = 1.0,
v_scale: torch.Tensor,
):
block_idx = slots // BLOCK_SIZE
block_offset = slots % BLOCK_SIZE
if key_cache.dtype == torch.float8_e4m3fn:
key = torch.ops.hpu.cast_to_fp8_v2(
key, k_scale, False, False, torch.float8_e4m3fn
)[0]
value = torch.ops.hpu.cast_to_fp8_v2(
value, v_scale, False, False, torch.float8_e4m3fn
)[0]
cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
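The store paths above (KVCache, KVCompressCache and paged_reshape_and_cache) all address the paged cache through flat slot ids and split them into a block index plus an in-block offset. A small self-contained illustration of that mapping (BLOCK_SIZE is picked arbitrarily for the example; the real value comes from the backend constant):

    import torch

    BLOCK_SIZE = 128  # example value only

    slots = torch.tensor([0, 1, 127, 128, 300])
    block_idx = slots // BLOCK_SIZE     # which page of the cache the entry lives in
    block_offset = slots % BLOCK_SIZE   # position inside that page
    print(list(zip(block_idx.tolist(), block_offset.tolist())))
    # [(0, 0), (0, 1), (0, 127), (1, 0), (2, 44)]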

View File

@@ -12,11 +12,151 @@ from text_generation_server.utils.weights import (
from vllm_hpu_extension.ops import scaled_fp8_quant
from vllm_hpu_extension.scales import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2
import habana_frameworks.torch.utils.experimental as htexp
w8a8_block_fp8_matmul = None
per_token_group_quant_fp8 = None
quant_dtype: torch.dtype = torch.float8_e4m3fn
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
if is_hpu_gaudi2():
FP8_MAX = torch.finfo(torch.float8_e4m3fnuz).max
def pad_weight(weight, block_size):
"""Pads a matrix to make its dimensions multiples of block_size."""
M, N = weight.shape[-2:]
block_size_m, block_size_n = block_size
pad_M = (block_size_m - M % block_size_m) % block_size_m
pad_N = (block_size_n - N % block_size_n) % block_size_n
if pad_M == 0 and pad_N == 0:
return weight, M, N # No padding needed
padded_weight = torch.nn.functional.pad(
weight, (0, pad_N, 0, pad_M), mode="constant", value=0
)
return padded_weight, M, N # Return original dimensions for unpadding
def unpad_weight(weight, original_M, original_N, keep_first_dim=False):
"""Removes padding from the matrix to restore its original shape."""
if (weight.shape[-2] == original_M) and (weight.shape[-1] == original_N):
return weight
if keep_first_dim:
return weight[:, :original_M, :original_N]
else:
return weight[:original_M, :original_N]
def pad_block_fp8_weight_naive(weight, weight_scale, block_size):
assert len(block_size) == 2
block_size_m, block_size_n = block_size
weight_scale_m, weight_scale_n = weight_scale.shape[-2:]
weight, orig_M, orig_N = pad_weight(weight, block_size)
M, N = weight.shape[-2:]
assert weight_scale_m == M // block_size_m
assert weight_scale_n == N // block_size_n
return weight, orig_M, orig_N
def dynamic_quant(data, single_scale=False):
if single_scale:
scale = ((torch.abs(data)).max() + 1e-8) / FP8_MAX
else:
scale = ((torch.abs(data)).max(dim=-1).values + 1e-8) / FP8_MAX
scale = scale.unsqueeze(-1)
data_fp8 = torch.ops.hpu.cast_to_fp8_v2(
data, 1.0 / scale, False, False, torch.float8_e4m3fn
)[0]
return data_fp8, scale.float()
def dequant_block_fp8_weight_naive(
weight,
weight_scale,
block_size,
dtype=torch.bfloat16,
original_M=None,
original_N=None,
do_unpad=False,
):
if weight_scale is None:
return weight
assert len(block_size) == 2
weight_shape_len = len(weight.shape)
block_size_m, block_size_n = block_size
# mul scale
if weight_shape_len == 2:
weight_scale_m, weight_scale_n = weight_scale.shape
weight_scale = weight_scale.view(weight_scale_m, 1, weight_scale_n, 1)
weight = weight.view(weight_scale_m, block_size_m, weight_scale_n, block_size_n)
if is_hpu_gaudi2():
fake_weight = weight.cpu().to(dtype).to(weight.device)
dequant_weight = fake_weight * weight_scale.to(dtype)
else:
dequant_weight = weight.to(dtype) * weight_scale.to(dtype)
dequant_weight = dequant_weight.view(
weight_scale_m * block_size_m, weight_scale_n * block_size_n
)
keep_first_dim = False
elif weight_shape_len == 3:
fd, weight_scale_m, weight_scale_n = weight_scale.shape
weight_scale = weight_scale.view(fd, weight_scale_m, 1, weight_scale_n, 1)
weight = weight.view(
fd, weight_scale_m, block_size_m, weight_scale_n, block_size_n
)
if is_hpu_gaudi2():
fake_weight = weight.cpu().to(dtype).to(weight.device)
dequant_weight = fake_weight * weight_scale.to(dtype)
else:
dequant_weight = weight.to(dtype) * weight_scale.to(dtype)
dequant_weight = dequant_weight.view(
fd, weight_scale_m * block_size_m, weight_scale_n * block_size_n
)
keep_first_dim = True
else:
raise ValueError("Only support original weight shape is either 2 or 3")
if do_unpad:
dequant_weight = unpad_weight(
dequant_weight, original_M, original_N, keep_first_dim=keep_first_dim
)
return dequant_weight
def apply_block_fp8_linear_hpu_dynamic(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
x_fp8, x_scale = dynamic_quant(input_2d)
output = torch.ops.hpu.fp8_gemm_v2(
x_fp8,
False,
weight,
True,
None,
torch.bfloat16,
x_scale,
weight_scale,
None,
False,
)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
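The helpers above (pad_weight, dequant_block_fp8_weight_naive, dynamic_quant) all revolve around 128x128 scale blocks: the weight is padded up to a block multiple, and one scale per block is broadcast back over the block when dequantizing. The arithmetic is easy to check in isolation (sizes below are made up for the example):

    import torch

    M, N, block = 300, 512, 128
    pad_M = (block - M % block) % block   # 300 % 128 = 44, so pad 84 rows
    pad_N = (block - N % block) % block   # 512 is already a multiple, so pad 0
    w = torch.randn(M, N)
    w_padded = torch.nn.functional.pad(w, (0, pad_N, 0, pad_M))
    print(w_padded.shape)                 # torch.Size([384, 512])

    # One scale per 128x128 block, broadcast over the block on dequant.
    scale = torch.rand(384 // block, 512 // block)          # shape (3, 4)
    dequant = (w_padded.view(3, block, 4, block) * scale.view(3, 1, 4, 1)).view(384, 512)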
def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
@@ -42,7 +182,7 @@ def per_tensor_dequantize(
) -> torch.Tensor:
device = tensor.device
dtype = torch.bfloat16
- if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
if is_hpu_gaudi2():
# dequant on cpu to avoid nan on gaudi2
tensor = tensor.to("cpu")
@@ -269,6 +409,66 @@ class HybridFP8UnquantLoader(WeightsLoader):
return UnquantizedWeight(w)
def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
# FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
w = [weights.get_tensor(f"{p}.weight", to_device=False) for p in prefixes]
shapes = [x.shape for x in w]
# Concat then send to the device
w = torch.cat(w, dim=dim).to(weights.device)
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
if self.weight_block_size is not None:
scale = [
weights.get_tensor(f"{p}.weight_scale_inv", to_device=False)
for p in prefixes
]
scale = torch.cat(scale, dim=dim)
scale = scale.to(weights.device)
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
weight_block_size=self.weight_block_size,
)
scale = [
weights.get_tensor(f"{p}.weight_scale", to_dtype=False).reshape(-1)
for p in prefixes
]
scale = torch.cat(scale, dim=0).reshape(-1)
input_scale = [
weights.get_tensor(f"{p}.input_scale", to_dtype=False).reshape(-1)
for p in prefixes
if weights.has_tensor(f"{p}.input_scale")
]
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
input_scale = (
torch.cat(input_scale, dim=0).reshape(-1).max()
if len(input_scale) != 0
else None
)
logical_widths = [x[0] for x in shapes]
w, scale = requantize_with_max_scale(
w, scale.to(weights.device), logical_widths, weights.dtype
)
return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
if self.to_fp8:
return Fp8Weight(weight=w, dtype=weights.dtype)
return UnquantizedWeight(w)
def get_weights_row(self, weights: "Weights", prefix: str):
w = weights.get_sharded(f"{prefix}.weight", dim=1)
# FP8 branch
@@ -389,6 +589,22 @@ class Fp8Linear(torch.nn.Module):
scale_upper_bound = kwargs.get("scale_upper_bound", None)
weight_block_size = kwargs.get("weight_block_size", None)
if weight_block_size is not None:
weight, orig_M, orig_N = pad_block_fp8_weight_naive(
weight, scale, weight_block_size
)
weight, scale = dynamic_quant(
dequant_block_fp8_weight_naive(
weight,
scale,
weight_block_size,
original_M=orig_M,
original_N=orig_N,
do_unpad=True,
)
)
scale = scale.squeeze(-1)
return cls(
qweight=weight,
scale=scale,
@@ -409,25 +625,10 @@ class Fp8Linear(torch.nn.Module):
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.weight_block_size is not None:
- # https://arxiv.org/pdf/2412.19437
- # At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and
- # scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we
- # group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output
- # channels).
- qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1])
- output = w8a8_block_fp8_matmul(
- qinput,
- self.qweight,
- scale,
- self.scale,
- self.weight_block_size,
- output_dtype=input.dtype,
- )
- if self.bias is not None:
- output = output + self.bias
- return output.to(dtype=input.dtype)
return apply_block_fp8_linear_hpu_dynamic(
input, self.qweight, self.scale, self.input_scale, self.bias
)
qinput, scale = fp8_quantize(
input,
self.input_scale,

View File

@@ -4,7 +4,12 @@
import torch
from loguru import logger
from text_generation_server.utils.log import log_once
- from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
from text_generation_server.utils.weights import (
Weight,
Weights,
WeightsLoader,
DefaultWeightsLoader,
)
from .hpu import QuantLinear
@@ -72,6 +77,7 @@ class GPTQWeightsLoader(WeightsLoader):
quant_method: str,
quantize: str,
sym: bool,
modules_to_not_convert: List[str],
):
self.bits = bits
self.desc_act = desc_act
@@ -79,6 +85,12 @@ class GPTQWeightsLoader(WeightsLoader):
self.quant_method = quant_method
self.quantize = quantize
self.sym = sym
self.modules_to_not_convert = modules_to_not_convert
def is_layer_skipped_quantization(
self, prefix: str, modules_to_not_convert: List[str]
):
return any(module_name in prefix for module_name in modules_to_not_convert)
def get_weights(self, weights: Weights, prefix: str):
self._get_gptq_params(weights)
@@ -91,6 +103,9 @@ class GPTQWeightsLoader(WeightsLoader):
log_once(logger.warning, "Disabling exllama because desc_act=True")
use_exllama = False
if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
return DefaultWeightsLoader.get_weights(weights, prefix)
try:
qweight = weights.get_tensor(f"{prefix}.qweight")
except RuntimeError:
@@ -145,6 +160,10 @@ class GPTQWeightsLoader(WeightsLoader):
prefix: str,
block_sizes: Union[int, List[int]],
):
if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
return DefaultWeightsLoader.get_weights_col_packed(
weights, prefix, block_sizes
)
try:
qweight = weights.get_packed_sharded(
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
@@ -196,6 +215,8 @@ class GPTQWeightsLoader(WeightsLoader):
)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
return DefaultWeightsLoader.get_multi_weights_col(weights, prefixes, dim)
try:
qweight = torch.cat(
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
@@ -263,6 +284,9 @@ class GPTQWeightsLoader(WeightsLoader):
if self.bits != 4:
use_exllama = False
if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
return DefaultWeightsLoader.get_weights_row(weights, prefix)
if self.desc_act:
log_once(logger.warning, "Disabling exllama because desc_act=True")
use_exllama = False
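is_layer_skipped_quantization above is a plain substring match of the weight prefix against modules_to_not_convert, which is how mixed checkpoints (quantized experts, unquantized attention or head) fall back to the default loader. A standalone check with hypothetical module names:

    from typing import List

    def is_layer_skipped_quantization(prefix: str, modules_to_not_convert: List[str]) -> bool:
        return any(module_name in prefix for module_name in modules_to_not_convert)

    skip = ["lm_head", "self_attn.kv_a_proj_with_mqa"]  # hypothetical example list
    print(is_layer_skipped_quantization("model.layers.3.self_attn.kv_a_proj_with_mqa", skip))  # True
    print(is_layer_skipped_quantization("model.layers.3.mlp.experts.0.gate_proj", skip))       # False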

View File

@@ -53,15 +53,10 @@ class FastRMSNorm(nn.Module):
return cls(weight, eps)
def forward(self, hidden_states, residual=None):
- from vllm_hpu_extension.kernels import rms_norm
- orig_shape = hidden_states.shape
if residual is not None:
- residual += hidden_states.view(residual.shape)
- else:
- residual = hidden_states
- # Note: HPUFusedRMSNorm requires 3D tensors as inputs
- if len(orig_shape) == 2:
- residual = residual.unsqueeze(0)
- x = rms_norm().apply(residual, self.weight, self.variance_epsilon)
- return x.view(orig_shape), residual.view(orig_shape)
hidden_states += residual
residual = hidden_states
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(self.weight.dtype), residual
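The rewritten forward is the textbook RMSNorm recipe: fold the residual in, compute the variance in float32, scale by the reciprocal root, then re-apply the learned weight. The same computation as a small free function, for reference (assuming a trailing hidden dimension):

    import torch

    def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        h = x.to(torch.float32)
        variance = h.pow(2).mean(-1, keepdim=True)
        h = h * torch.rsqrt(variance + eps)
        return weight * h.to(weight.dtype)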

View File

@@ -2,6 +2,7 @@ from typing import Optional
import torch
import torch.nn as nn
import os
from text_generation_server.utils.weights import Weights
from text_generation_server.layers.fp8 import (
@@ -9,12 +10,11 @@ from text_generation_server.layers.fp8 import (
fp8_quantize,
quant_dtype,
normalize_e4m3fn_to_native_float8,
dynamic_quant,
dequant_block_fp8_weight_naive,
)
from text_generation_server.layers.moe.fused_moe import select_experts
- try:
- from .unquantized import fused_moe
- except Exception:
- fused_moe = None
import habana_frameworks.torch as htorch
class FP8SparseMoELayer(nn.Module):
@@ -47,6 +47,16 @@ class FP8SparseMoELayer(nn.Module):
self.weight_block_size = weights.weights_loader.weight_block_size
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
self.world_size = weights.process_group.size()
self.rank = weights.process_group.rank()
self.ep_rank = self.rank
self.use_ep = os.getenv("USE_EXPERT_PARALLEL", "true").lower() == "true"
if self.use_ep:
n_experts = (n_experts + self.world_size - 1) // self.world_size
self.ep_offset = self.ep_rank * n_experts
else:
self.ep_offset = 0
(
self.gate_up_proj,
@@ -58,6 +68,8 @@ class FP8SparseMoELayer(nn.Module):
gate_proj_name=gate_proj_name,
up_proj_name=up_proj_name,
weights=weights,
use_ep=self.use_ep,
ep_offset=self.ep_offset,
)
self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (
@@ -66,29 +78,89 @@ class FP8SparseMoELayer(nn.Module):
n_experts=n_experts,
name=down_proj_name,
weights=weights,
use_ep=self.use_ep,
ep_offset=self.ep_offset,
)
)
if self.weight_block_size is not None:
self.gate_up_proj, self.gate_up_proj_weight_scale = dynamic_quant(
dequant_block_fp8_weight_naive(
self.gate_up_proj,
self.gate_up_proj_weight_scale,
self.weight_block_size,
)
)
self.down_proj, self.down_proj_weight_scale = dynamic_quant(
dequant_block_fp8_weight_naive(
self.down_proj, self.down_proj_weight_scale, self.weight_block_size
)
)
self.gate_up_proj_weight_scale, self.down_proj_weight_scale = (
self.gate_up_proj_weight_scale.squeeze(-1),
self.down_proj_weight_scale.squeeze(-1),
)
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
- return fused_moe(
- x,
- w1=self.gate_up_proj,
- w2=self.down_proj,
- gating_output=gating_output,
- topk=self.topk,
- renormalize=self.renormalize,
- inplace=True,
- use_grouped_topk=self.n_expert_group is not None,
- num_expert_group=self.n_expert_group,
- topk_group=self.topk_group,
- scoring_func=self.scoring_func,
- e_score_correction_bias=self.e_score_correction_bias,
- use_fp8_w8a8=True,
- w1_scale=self.gate_up_proj_weight_scale,
- w2_scale=self.down_proj_weight_scale,
- a1_scale=self.gate_up_proj_input_scale,
- a2_scale=self.down_proj_input_scale,
- )
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=gating_output,
use_grouped_topk=self.n_expert_group is not None,
top_k=self.topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.n_expert_group,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
)
total_num_experts = gating_output.size(-1)
x_fp8, x_scale = dynamic_quant(x, single_scale=True)
if self.use_ep:
moe_n_slice = 1
n_expert_slice = (
total_num_experts + self.world_size - 1
) // self.world_size
else:
moe_n_slice = 1
n_expert_slice = (total_num_experts + moe_n_slice - 1) // moe_n_slice
for i in range(moe_n_slice):
min_expert = i * n_expert_slice
max_expert = min((i + 1) * n_expert_slice, total_num_experts)
w13_list_slice = [
self.gate_up_proj[j, ...] for j in range(min_expert, max_expert)
]
w2_list_slice = [
self.down_proj[j, ...] for j in range(min_expert, max_expert)
]
w13_weight_scale = [
self.gate_up_proj_weight_scale[j, ...]
for j in range(min_expert, max_expert)
]
w2_weight_scale = [
self.down_proj_weight_scale[j, ...]
for j in range(min_expert, max_expert)
]
current_hidden_states = torch.ops.hpu.mixture_of_experts(
hidden_states=x_fp8,
expert_routing_table=topk_ids.to(torch.int64),
router_weights=topk_weights.to(x.dtype),
w12=w13_list_slice,
w3=w2_list_slice,
d_scale_hidden_states=x_scale,
d_scale_w12=w13_weight_scale,
d_scale_w3=w2_weight_scale,
permuted_weights=True,
activation="silu",
experts_min=min_expert + self.ep_offset,
experts_max=max_expert + self.ep_offset - 1,
)
htorch.core.mark_step()
if i == 0:
final_hidden_states = current_hidden_states
else:
final_hidden_states.add_(current_hidden_states)
return final_hidden_states
def _load_expert_weights(
@@ -98,13 +170,14 @@ def _load_expert_weights(
n_experts: int,
name: str,
weights: Weights,
ep_offset: int = 0,
) -> torch.Tensor:
all_weight = None
all_weight_scales = None
max_input_scale = None
for i in range(n_experts):
- weight = get_weight_fn(prefix, i, name, weights)
weight = get_weight_fn(prefix, i + ep_offset, name, weights)
assert isinstance(weight, Fp8Weight)
@@ -147,14 +220,26 @@ def _load_expert_multi_weights_col(
gate_proj_name: str,
up_proj_name: str,
weights: Weights,
use_ep: bool = False,
ep_offset: int = 0,
) -> torch.Tensor:
- def get_weight_fn(prefix, i, name, weights):
def get_weight_fn_sharded(prefix, i, name, weights):
return weights.get_multi_weights_col(
[f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
)
def get_weight_fn(prefix, i, name, weights):
return weights.get_multi_weights(
[f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
)
return _load_expert_weights(
- get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights
get_weight_fn if use_ep else get_weight_fn_sharded,
prefix=prefix,
n_experts=n_experts,
name=None,
weights=weights,
ep_offset=ep_offset if use_ep else 0,
)
@@ -164,10 +249,20 @@ def _load_expert_weights_row(
n_experts: int,
name: str,
weights: Weights,
use_ep: bool = False,
ep_offset: int = 0,
) -> torch.Tensor:
- def get_weight_fn(prefix, i, name, weights):
def get_weight_fn_sharded(prefix, i, name, weights):
return weights.get_weights_row(f"{prefix}.{i}.{name}")
def get_weight_fn(prefix, i, name, weights):
return weights.get_weights(f"{prefix}.{i}.{name}")
return _load_expert_weights(
- get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights
get_weight_fn if use_ep else get_weight_fn_sharded,
prefix=prefix,
n_experts=n_experts,
name=name,
weights=weights,
ep_offset=ep_offset if use_ep else 0,
)
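With USE_EXPERT_PARALLEL enabled, FP8SparseMoELayer keeps only a contiguous slice of experts per rank and offsets the checkpoint indices by ep_offset. The partitioning arithmetic used above boils down to a ceil-divide (numbers are illustrative):

    def expert_slice(n_experts: int, world_size: int, rank: int):
        # Ceil-divide so every rank is assigned the same slice width.
        n_local = (n_experts + world_size - 1) // world_size
        ep_offset = rank * n_local
        return ep_offset, min(ep_offset + n_local, n_experts)

    print([expert_slice(256, 8, r) for r in range(8)])
    # [(0, 32), (32, 64), (64, 96), ..., (224, 256)]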

View File

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
- from typing import Tuple
from typing import Tuple, Optional
import torch
@@ -25,12 +25,36 @@ def grouped_topk(
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
- scores = torch.softmax(gating_output, dim=-1)
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
gating_output = gating_output.float()
if e_score_correction_bias is not None:
e_score_correction_bias = e_score_correction_bias.float()
if scoring_func == "softmax":
scores = torch.softmax(gating_output, dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
num_token = scores.shape[0]
- group_scores = (
- scores.view(num_token, num_expert_group, -1).max(dim=-1).values
- ) # [n, n_group]
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
original_scores = scores
scores = scores + e_score_correction_bias.unsqueeze(0)
group_scores = (
scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
)
else:
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
1
] # [n, top_k_group]
@@ -41,13 +65,19 @@ def grouped_topk(
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
.reshape(num_token, -1)
) # [n, e]
- tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e]
- topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
if e_score_correction_bias is not None:
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_scores.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
- return topk_weights, topk_ids
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def fused_topk(
@@ -63,3 +93,39 @@ def fused_topk(
if renormalize:
topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
):
# DeekSeekv2 uses grouped_top_k
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
topk_weights, topk_ids = grouped_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
else:
topk_weights, topk_ids = fused_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
)
return topk_weights, topk_ids
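select_experts simply dispatches between grouped_topk (DeepSeek-style grouped routing, optionally with sigmoid scoring and a correction bias) and plain fused_topk. A toy call showing the returned shapes, assuming the Gaudi backend package is importable in the environment (values are random and only illustrate shapes):

    import torch
    from text_generation_server.layers.moe.fused_moe import select_experts

    tokens, n_experts, n_groups, top_k = 4, 16, 4, 2
    hidden_states = torch.randn(tokens, 32)
    router_logits = torch.randn(tokens, n_experts)

    topk_weights, topk_ids = select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
        top_k=top_k,
        use_grouped_topk=True,
        renormalize=True,
        topk_group=2,
        num_expert_group=n_groups,
        scoring_func="sigmoid",
    )
    print(topk_weights.shape, topk_ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])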

View File

@@ -4,7 +4,9 @@ import torch
import torch.nn as nn
from text_generation_server.utils.weights import UnquantizedWeight, Weights
- from vllm_hpu_extension.ops import DynamicFusedMOE
from vllm_hpu_extension.ops import VllmMixtureOfExpertsOp
import habana_frameworks.torch as htorch
import torch.nn.functional as F
class UnquantizedSparseMoELayer(nn.Module):
@@ -53,15 +55,29 @@ class UnquantizedSparseMoELayer(nn.Module):
weights=weights,
)
- self.hpu_fused_moe = DynamicFusedMOE(n_experts)
self.MoeOp = VllmMixtureOfExpertsOp(n_experts, 0, n_experts - 1)
for i in range(n_experts):
- self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
self.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
- self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.down_proj[i])
self.MoeOp.w2_list[i].set_weight(self.down_proj[i])
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
- return self.hpu_fused_moe(x, gating_output, self.topk)
htorch.core.mark_step()
routing_weights = F.softmax(gating_output, dim=1, dtype=torch.float32)
routing_weights, selected_experts = torch.topk(
routing_weights, self.topk, dim=-1
)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(x.dtype)
final_hidden_states = self.MoeOp(
hidden_states=x,
expert_routing_table=selected_experts,
router_weights=routing_weights,
permuted_weights=True,
activation="silu",
)
return final_hidden_states.view(-1, x.shape[1])
def _load_expert_multi_weights_col(

View File

@@ -470,9 +470,6 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
mscale_all_dim: float,
):
inv_freq = _create_inv_freq(dim, base, device)
- super().__init__(
- inv_freq, scaling_factor, max_position_embeddings * self.scaling_factor
- )
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
@@ -487,6 +484,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
/ get_mscale(self.scaling_factor, mscale_all_dim)
* self.attn_factor
) # Get n-d magnitude scaling corrected for interpolation
super().__init__(inv_freq, scaling_factor, max_position_embeddings)
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,

View File

@@ -343,6 +343,7 @@ def get_model(
quantize: Optional[str],
speculate: Optional[int],
dtype: Optional[torch.dtype],
kv_cache_dtype: Optional[str],
trust_remote_code: bool,
max_input_tokens: int,
) -> Model:
@@ -468,7 +469,12 @@ def get_model(
model_type = config_dict["model_type"]
- kv_cache_dtype = dtype
if kv_cache_dtype == "fp8_e4m3fn":
kv_cache_dtype = torch.float8_e4m3fn
elif kv_cache_dtype == "fp8_e5m2":
kv_cache_dtype = torch.float8_e5m2
else:
kv_cache_dtype = dtype
if FLASH_ATTENTION:
if model_type == DEEPSEEK_V2:
@@ -934,6 +940,7 @@ def get_model_with_lora_adapters(
quantize: Optional[str],
speculate: Optional[int],
dtype: Optional[torch.dtype],
kv_cache_dtype: Optional[str],
trust_remote_code: bool,
max_input_tokens: int,
adapter_to_index: Dict[str, int],
@@ -947,6 +954,7 @@ def get_model_with_lora_adapters(
quantize,
speculate,
dtype,
kv_cache_dtype,
trust_remote_code,
uds_path,
max_input_tokens,

View File

@@ -51,6 +51,8 @@ from habana_frameworks.torch.hpex.kernels import (
apply_rotary_pos_emb,
)
import habana_frameworks.torch as htorch
class CohereRotary(PositionRotaryEmbedding):
def forward(
@@ -420,7 +422,9 @@ class FlashCohereModel(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@@ -433,6 +437,8 @@ class FlashCohereModel(torch.nn.Module):
seqlen,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual)

View File

@@ -44,6 +44,7 @@ from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
from vllm_hpu_extension.ops import DynamicFusedMOE
import habana_frameworks.torch as htorch
class DbrxAttentionConfig(PretrainedConfig):
@@ -682,8 +683,10 @@ class DbrxModel(torch.nn.Module):
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@@ -696,6 +699,8 @@ class DbrxModel(torch.nn.Module):
seqlen,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual)

View File

@@ -40,6 +40,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.weights import Weights
import habana_frameworks.torch as htorch
class DeepseekV2Config(PretrainedConfig):
@@ -575,6 +576,9 @@ class DeepseekV2Model(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@@ -587,6 +591,8 @@ class DeepseekV2Model(torch.nn.Module):
seqlen,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual)

View File

@@ -28,11 +28,12 @@ from text_generation_server.layers import (
TensorParallelEmbedding,
TensorParallelRowLinear,
get_linear,
Fp8Linear,
)
from text_generation_server.layers.attention import (
Seqlen,
attention,
- paged_attention,
paged_attention_mla,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
@@ -40,6 +41,19 @@ from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.weights import Weights
import habana_frameworks.torch as htorch
def get_and_maybe_dequant_weights(layer: torch.nn.Module) -> torch.Tensor:
if isinstance(layer, Fp8Linear):
eye = torch.eye(
layer.qweight.shape[-1], dtype=torch.bfloat16, device=layer.qweight.device
)
dequant_weights = layer(eye)
del eye
# standardize to (output, input)
return dequant_weights.T
return layer.weight
class DeepseekV3Config(PretrainedConfig):
@@ -249,6 +263,44 @@ class DeepseekV3Attention(torch.nn.Module):
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj.linear).T
kv_b_proj_weight = kv_b_proj_weight.view(
self.kv_lora_rank,
self.num_heads,
self.qk_nope_head_dim + self.value_head_size,
)
W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.value_head_size], dim=-1
)
# Convert from (L, N, V) to (N, L, V)
self.W_UV = W_UV.transpose(0, 1)
# Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0)
def _q_proj_and_k_up_proj(self, x):
q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
q_nope, q_pe = (
q_proj(x)
.view(-1, self.num_heads, self.head_size)
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
)
# Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
return ql_nope.transpose(0, 1), q_pe
def _v_up_proj_and_o_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.value_head_size)
return self.o_proj(x)
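The two helpers above implement the MLA weight-absorption trick for decode: W_UK is folded into the query side so attention can run directly in the kv_lora_rank latent space, and W_UV is applied only after attention. A shape-only sketch of those batched matmuls with made-up sizes (B tokens, N heads, P nope dim, L latent rank, V value dim):

    import torch

    B, N, P, L, V = 2, 128, 128, 512, 128   # illustrative sizes only

    q_nope = torch.randn(B, N, P)
    W_UK_T = torch.randn(N, P, L)            # stored as (N, P, L)
    W_UV = torch.randn(N, L, V)              # stored as (N, L, V)

    # Project the query into the latent space once, instead of expanding K per head.
    ql_nope = torch.bmm(q_nope.transpose(0, 1), W_UK_T).transpose(0, 1)   # (B, N, L)

    attn_latent = torch.randn(B, N, L)       # what paged_attention_mla returns per head
    out = torch.bmm(attn_latent.transpose(0, 1), W_UV)                    # (N, B, V)
    out = out.transpose(0, 1).reshape(B, N * V)                           # input to o_proj
    print(ql_nope.shape, out.shape)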
def forward(
self,
hidden_states: torch.Tensor,
@@ -261,14 +313,9 @@ class DeepseekV3Attention(torch.nn.Module):
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
if self.q_lora_rank is None:
- query = self.q_proj(hidden_states)
hidden_states_or_q_c = hidden_states
else:
- query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
hidden_states_or_q_c = self.q_a_layernorm(self.q_a_proj(hidden_states))[0]
- query = query.view(-1, self.num_heads, self.head_size)
- _, query_pe = torch.split(
- query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
- )
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, key_pe = torch.split(
@@ -276,13 +323,18 @@
)
key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
- kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
- -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
- )
kv_c_normed = self.kv_a_layernorm(compressed_kv.contiguous())[0]
- key_nope, value = torch.split(
- kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
- )
# Prefill
if cu_seqlen_prefill is not None:
q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
query = q_proj(hidden_states_or_q_c)
query = query.view(-1, self.num_heads, self.head_size)
query_nope, query_pe = torch.split(
query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
else:
query_nope, query_pe = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
batch_size, heads, head_dim = query_pe.shape
query_pe = (
@@ -297,33 +349,47 @@
.reshape(batch_size, heads, head_dim)
)
self.rotary_emb(query_pe, key_pe, cos, sin)
latent_vec_k = torch.concat(
(kv_c_normed, key_pe.view(-1, self.qk_rope_head_dim)), dim=-1
)
latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)
- query[..., self.qk_nope_head_dim :] = query_pe
- key = torch.empty_like(query)
- key[..., : self.qk_nope_head_dim] = key_nope
- key[..., self.qk_nope_head_dim :] = key_pe
- # We need to pad the heads because Flash Attention does not support
- # qk and v with different head sizes.
- query = torch.nn.functional.pad(
- query, (0, self.head_pad_size - self.head_size), value=0
- )
- key = torch.nn.functional.pad(
- key, (0, self.head_pad_size - self.head_size), value=0
- )
- value = torch.nn.functional.pad(
- value, (0, self.head_pad_size - self.value_head_size), value=0
- )
latent_vec_k = latent_vec_k.unflatten(0, (slots.size(0), -1))
kv_cache.store(
- key=key,
key=latent_vec_k,
- value=value,
value=None,
slots=slots,
kv_scales=self.kv_scales,
)
- # Prefill
if cu_seqlen_prefill is not None:
kv = self.kv_b_proj(kv_c_normed).view(
-1,
self.num_key_value_heads,
self.qk_nope_head_dim + self.value_head_size,
)
key_nope, value = torch.split(
kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
)
query[..., self.qk_nope_head_dim :] = query_pe
key = torch.empty_like(query)
key[..., : self.qk_nope_head_dim] = key_nope
key[..., self.qk_nope_head_dim :] = key_pe
# We need to pad the heads because Flash Attention does not support
# qk and v with different head sizes.
query = torch.nn.functional.pad(
query, (0, self.head_pad_size - self.head_size), value=0
)
key = torch.nn.functional.pad(
key, (0, self.head_pad_size - self.head_size), value=0
)
value = torch.nn.functional.pad(
value, (0, self.head_pad_size - self.value_head_size), value=0
)
# flash attention
attn_output = attention(
query=query,
@@ -334,9 +400,15 @@ class DeepseekV3Attention(torch.nn.Module):
seqlen=seqlen,
softmax_scale=self.softmax_scale,
)
- # Decode
attn_output = attn_output[..., : self.value_head_size]
return self.o_proj(
attn_output.reshape(-1, self.num_heads * self.value_head_size)
)
else:
- attn_output = paged_attention(
# Decode
query = torch.cat([query_nope, query_pe], dim=-1)
attn_output = paged_attention_mla(
query,
kv_cache,
self.kv_head_mapping,
@@ -344,14 +416,10 @@ class DeepseekV3Attention(torch.nn.Module):
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
kv_lora_rank=self.kv_lora_rank,
)
attn_output = self._v_up_proj_and_o_proj(attn_output)
return attn_output
- # Remove padding.
- attn_output = attn_output[..., : self.value_head_size]
- return self.o_proj(
- attn_output.reshape(-1, self.num_heads * self.value_head_size)
- )
class DeepseekV3MLP(nn.Module):
@@ -584,6 +652,9 @@ class DeepseekV3Model(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@@ -596,6 +667,8 @@ class DeepseekV3Model(torch.nn.Module):
seqlen,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual)

View File

@@ -46,6 +46,7 @@ from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.utils.weights import UnquantizedWeight
import habana_frameworks.torch as htorch
class Gemma2Config(PretrainedConfig):
@@ -472,6 +473,10 @@ class FlashGemma2Model(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@@ -485,6 +490,8 @@ class FlashGemma2Model(torch.nn.Module):
adapter_data,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual)

View File

@@ -44,6 +44,7 @@ from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.utils.weights import UnquantizedWeight
import habana_frameworks.torch as htorch
class GemmaConfig(PretrainedConfig):
@@ -394,6 +395,9 @@ class FlashGemmaModel(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@@ -406,6 +410,8 @@ class FlashGemmaModel(torch.nn.Module):
seqlen,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual)

View File

@@ -38,6 +38,7 @@ from text_generation_server.layers import (
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
import habana_frameworks.torch as htorch
def load_qkv(config, prefix: str, weights, head_size, num_heads):
@@ -385,6 +386,10 @@ class FlashGPT2Model(torch.nn.Module):
hidden_states = inputs_embeds
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@@ -395,6 +400,8 @@ class FlashGPT2Model(torch.nn.Module):
seqlen,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states = self.norm(hidden_states)

View File

@ -48,6 +48,7 @@ from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode, RotaryPosEmbeddingMode,
apply_rotary_pos_emb, apply_rotary_pos_emb,
) )
import habana_frameworks.torch as htorch
def load_attention(config, prefix: str, weights): def load_attention(config, prefix: str, weights):
@ -330,6 +331,9 @@ class FlashGPTJModel(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
@ -342,6 +346,8 @@ class FlashGPTJModel(torch.nn.Module):
seqlen, seqlen,
hpu_attention_meta, hpu_attention_meta,
) )
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.ln_f(hidden_states, residual) hidden_states, _ = self.ln_f(hidden_states, residual)

View File

@ -26,7 +26,7 @@ import torch.distributed
from torch import nn from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
import habana_frameworks.torch as htorch
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
KVCache, KVCache,
get_kv_scales, get_kv_scales,
@ -554,6 +554,9 @@ class FlashLlamaModel(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
@ -568,6 +571,8 @@ class FlashLlamaModel(torch.nn.Module):
cross_attention_states, cross_attention_states,
hpu_attention_meta=hpu_attention_meta, hpu_attention_meta=hpu_attention_meta,
) )
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)

View File

@ -45,6 +45,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastRMSNorm, FastRMSNorm,
) )
import habana_frameworks.torch as htorch
class MistralConfig(PretrainedConfig): class MistralConfig(PretrainedConfig):
@ -401,6 +402,9 @@ class MistralModel(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
@ -414,6 +418,8 @@ class MistralModel(torch.nn.Module):
adapter_data, adapter_data,
hpu_attention_meta, hpu_attention_meta,
) )
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states

View File

@ -44,6 +44,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.utils.weights import UnquantizedWeight from text_generation_server.utils.weights import UnquantizedWeight
import habana_frameworks.torch as htorch
class MixtralConfig(PretrainedConfig): class MixtralConfig(PretrainedConfig):
@ -452,6 +453,9 @@ class MixtralModel(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
@ -464,6 +468,8 @@ class MixtralModel(torch.nn.Module):
seqlen, seqlen,
hpu_attention_meta, hpu_attention_meta,
) )
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)

View File

@ -47,6 +47,7 @@ from text_generation_server.layers.rotary import (
PositionRotaryEmbedding, PositionRotaryEmbedding,
) )
from text_generation_server.utils.weights import UnquantizedWeight from text_generation_server.utils.weights import UnquantizedWeight
import habana_frameworks.torch as htorch
class GPTNeoXConfig(TransformersGPTNeoXConfig): class GPTNeoXConfig(TransformersGPTNeoXConfig):
@ -360,6 +361,9 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(position_ids) cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(position_ids)
residual = None residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
@ -372,6 +376,8 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
seqlen, seqlen,
hpu_attention_meta, hpu_attention_meta,
) )
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.final_layer_norm(hidden_states, residual) hidden_states, _ = self.final_layer_norm(hidden_states, residual)

View File

@ -26,6 +26,7 @@ from text_generation_server.layers.layernorm import (
from text_generation_server.layers.rotary import ( from text_generation_server.layers.rotary import (
PositionRotaryEmbedding, PositionRotaryEmbedding,
) )
import habana_frameworks.torch as htorch
class PhiConfig(PretrainedConfig): class PhiConfig(PretrainedConfig):
@ -353,6 +354,9 @@ class FlashPhiModel(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
@ -365,6 +369,8 @@ class FlashPhiModel(torch.nn.Module):
seqlen, seqlen,
hpu_attention_meta, hpu_attention_meta,
) )
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)

View File

@ -18,7 +18,6 @@
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging from transformers.utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)

View File

@ -22,6 +22,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastRMSNorm, FastRMSNorm,
) )
import habana_frameworks.torch as htorch
def load_attention(config, prefix, weights): def load_attention(config, prefix, weights):
@ -294,6 +295,9 @@ class Qwen2Model(torch.nn.Module):
) )
residual = None residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states = layer( hidden_states = layer(
hidden_states, hidden_states,
@ -306,6 +310,8 @@ class Qwen2Model(torch.nn.Module):
seqlen, seqlen,
hpu_attention_meta, hpu_attention_meta,
) )
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states) hidden_states, _ = self.norm(hidden_states)

View File

@ -21,6 +21,7 @@ from text_generation_server.layers.attention import (
Seqlen, Seqlen,
HPUPagedAttentionMetadata, HPUPagedAttentionMetadata,
) )
import habana_frameworks.torch as htorch
def load_row(config, prefix: str, weights, bias: bool): def load_row(config, prefix: str, weights, bias: bool):
@ -634,6 +635,9 @@ class FlashRWModel(FlashRWPreTrainedModel):
cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(position_ids) cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(position_ids)
residual = None residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.h): for i, layer in enumerate(self.h):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
@ -646,6 +650,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
seqlen, seqlen,
hpu_attention_meta, hpu_attention_meta,
) )
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.ln_f(hidden_states, residual) hidden_states, _ = self.ln_f(hidden_states, residual)

View File

@ -23,6 +23,7 @@ from text_generation_server.layers.gptq import GPTQWeightsLoader
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastLayerNorm, FastLayerNorm,
) )
import habana_frameworks.torch as htorch
def load_multi_mqa( def load_multi_mqa(
@ -442,6 +443,9 @@ class FlashSantacoderModel(nn.Module):
torch.distributed.all_reduce(hidden_states, group=self.process_group) torch.distributed.all_reduce(hidden_states, group=self.process_group)
residual = None residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
@ -452,6 +456,8 @@ class FlashSantacoderModel(nn.Module):
seqlen, seqlen,
hpu_attention_meta, hpu_attention_meta,
) )
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.ln_f(hidden_states, residual) hidden_states, _ = self.ln_f(hidden_states, residual)

View File

@ -50,6 +50,7 @@ from text_generation_server.layers.rotary import (
PositionRotaryEmbedding, PositionRotaryEmbedding,
) )
from text_generation_server.utils.weights import UnquantizedWeight from text_generation_server.utils.weights import UnquantizedWeight
import habana_frameworks.torch as htorch
class Starcoder2Config(PretrainedConfig): class Starcoder2Config(PretrainedConfig):
@ -517,6 +518,9 @@ class Starcoder2Model(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
@ -530,6 +534,8 @@ class Starcoder2Model(torch.nn.Module):
adapter_data, adapter_data,
hpu_attention_meta, hpu_attention_meta,
) )
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)

View File

@ -53,6 +53,7 @@ from text_generation_server.models.globals import (
) )
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
KVCache, KVCache,
KVCompressCache,
Seqlen, Seqlen,
HPUPagedAttentionMetadata, HPUPagedAttentionMetadata,
trim_attn_metadata, trim_attn_metadata,
@ -68,11 +69,13 @@ from text_generation_server.utils.import_utils import (
synchronize, synchronize,
get_free_memory, get_free_memory,
) )
from text_generation_server.utils.prefill_chunking import (
get_max_prefill_tokens,
)
import vllm_hpu_extension.environment as environment import vllm_hpu_extension.environment as environment
import habana_frameworks.torch as htorch import habana_frameworks.torch as htorch
import itertools import itertools
from vllm_hpu_extension.bucketing import HPUBucketingContext from vllm_hpu_extension.bucketing.common import get_bucketing_context
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -153,7 +156,7 @@ def prepare_for_decode(
block_groups_device, num_classes=batch_size block_groups_device, num_classes=batch_size
) )
mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0) mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
mask = mask >= block_usage.unsqueeze(-1) mask = mask >= block_usage_device.unsqueeze(-1)
attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf) attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
return trim_attn_metadata( return trim_attn_metadata(
HPUPagedAttentionMetadata( HPUPagedAttentionMetadata(
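The decode mask is now built from the device-resident block usage tensor (block_usage_device), so no host copy is needed in prepare_for_decode. A small self-contained restatement of the bias construction, with BLOCK_SIZE chosen arbitrarily:

import math
import torch

BLOCK_SIZE = 128  # illustrative; the real value comes from the paged-attention config

def decode_attn_bias(block_usage_device: torch.Tensor, dtype=torch.bfloat16):
    # Slots past the used length of each block get -inf so paged attention
    # ignores the padded tail of every KV-cache block.
    device = block_usage_device.device
    mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
    mask = mask >= block_usage_device.unsqueeze(-1)
    return torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)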
@ -425,7 +428,9 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor[i, : len(input_ids)] = input_ids all_input_ids_tensor[i, : len(input_ids)] = input_ids
# Create tensors on device # Create tensors on device
all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64) all_input_ids_tensor = torch.tensor(
all_input_ids_tensor, dtype=torch.int64, device=device
)
top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64) top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)
@ -1438,15 +1443,17 @@ class FlashCausalLM(Model):
self.kv_cache = [] self.kv_cache = []
self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
self.bucketing_ctx = None self.bucketing_ctx = None
htorch.core.hpu_set_env()
if htorch.utils.internal.is_lazy(): if htorch.utils.internal.is_lazy():
htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True) htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True)
environment.set_model_config(self.config) environment.set_model_config(self.config)
self.use_contiguous_pa = ( self.use_contiguous_pa = (
os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true" os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true"
) )
self.limit_hpu_graphs = ( self.limit_hpu_graph = (
os.environ.get("LIMIT_HPU_GRAPHS", "false").lower() == "true" os.environ.get("LIMIT_HPU_GRAPH", "false").lower() == "true"
) )
self.max_seq_len_to_capture = 8192
super().__init__( super().__init__(
model_id=model_id, model_id=model_id,
model=model, model=model,
@ -1478,16 +1485,27 @@ class FlashCausalLM(Model):
): ):
self.kv_cache = [] self.kv_cache = []
empty_cache() empty_cache()
self.kv_cache = [ if self.config.model_type == "deepseek_v3":
KVCache( self.kv_cache = [
num_blocks=num_blocks, KVCompressCache(
num_heads=num_heads, num_blocks=num_blocks,
head_size=head_size, head_size=self.config.kv_lora_rank + self.config.qk_rope_head_dim,
dtype=dtype, dtype=dtype,
device=device, device=device,
) )
for _ in range(num_layers) for _ in range(num_layers)
] ]
else:
self.kv_cache = [
KVCache(
num_blocks=num_blocks,
num_heads=num_heads,
head_size=head_size,
dtype=dtype,
device=device,
)
for _ in range(num_layers)
]
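DeepSeek-V3/R1 uses multi-head latent attention, so its cache stores one compressed latent of width kv_lora_rank + qk_rope_head_dim per token instead of per-head K/V tensors; that is what KVCompressCache holds. A condensed sketch of the branch above, assuming the attention-layer imports shown earlier in this file:

from text_generation_server.layers.attention import KVCache, KVCompressCache

def build_kv_cache(config, num_blocks, num_layers, num_heads, head_size, dtype, device):
    if config.model_type == "deepseek_v3":
        # one latent vector per token (MLA), no per-head K/V
        latent_size = config.kv_lora_rank + config.qk_rope_head_dim
        return [
            KVCompressCache(num_blocks=num_blocks, head_size=latent_size,
                            dtype=dtype, device=device)
            for _ in range(num_layers)
        ]
    return [
        KVCache(num_blocks=num_blocks, num_heads=num_heads, head_size=head_size,
                dtype=dtype, device=device)
        for _ in range(num_layers)
    ]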
def warmup( def warmup(
self, self,
@ -1495,6 +1513,11 @@ class FlashCausalLM(Model):
max_input_tokens: Optional[int], max_input_tokens: Optional[int],
max_total_tokens: Optional[int], max_total_tokens: Optional[int],
): ):
if os.environ.get("MAX_BATCH_SIZE") is None:
raise RuntimeError(
"MAX_BATCH_SIZE is not set, it should be set in the launcher "
"using `--max-batch-size xxx`"
)
# The warmup batch is the biggest batch we could ever receive # The warmup batch is the biggest batch we could ever receive
self.kv_cache = [] self.kv_cache = []
empty_cache() empty_cache()
@ -1502,8 +1525,14 @@ class FlashCausalLM(Model):
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the free memory # Calculate the number of blocks that can be allocated with the free memory
dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size() dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size if self.config.model_type == "deepseek_v3":
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size cache_block_size = BLOCK_SIZE * (
self.config.kv_lora_rank + self.config.qk_rope_head_dim
)
else:
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
cache_block_size = cache_block_size * 2
total_cache_size = self.num_layers * cache_block_size * dtype_size
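The sizing now mirrors the cache layout: for deepseek_v3 a block holds BLOCK_SIZE compressed latents, while the conventional path keeps the factor of 2 for separate K and V tensors. A hedged restatement with a worked call (the config object and the numbers are illustrative):

import torch

BLOCK_SIZE = 128  # illustrative

def kv_cache_bytes(config, num_kv_heads, head_size, num_layers, kv_dtype):
    dtype_size = torch.tensor([], dtype=kv_dtype).element_size()
    if config.model_type == "deepseek_v3":
        # MLA: one latent of width kv_lora_rank + qk_rope_head_dim per token
        cache_block_size = BLOCK_SIZE * (config.kv_lora_rank + config.qk_rope_head_dim)
    else:
        cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
        cache_block_size *= 2  # K and V
    return num_layers * cache_block_size * dtype_size

# e.g. DeepSeek-R1-like numbers: 61 layers, kv_lora_rank=512, qk_rope_head_dim=64, bf16
# -> 61 * 128 * (512 + 64) * 2 bytes, roughly 8.6 MiB per block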
try: try:
self.init_kv_cache( self.init_kv_cache(
@ -1563,25 +1592,33 @@ class FlashCausalLM(Model):
self.kv_cache_dtype, self.kv_cache_dtype,
self.device, self.device,
) )
self.max_batch_prefill_tokens = get_max_prefill_tokens()
max_num_seqs = int(os.getenv("MAX_BATCH_SIZE", 128)) max_num_seqs = int(os.getenv("MAX_BATCH_SIZE"))
if os.getenv("VLLM_PROMPT_SEQ_BUCKET_MAX") is None: HPUBucketingContext = get_bucketing_context()
os.environ["VLLM_PROMPT_SEQ_BUCKET_MAX"] = str(max_input_tokens) max_total_tokens_aligned = math.ceil(max_total_tokens / BLOCK_SIZE) * BLOCK_SIZE
if os.getenv("VLLM_DECODE_BLOCK_BUCKET_MAX") is None: model_max_length = self.tokenizer.model_max_length
max_total_blocks = ( max_position_embeddings = getattr(
math.ceil(max_total_tokens / BLOCK_SIZE) * max_num_seqs + 1 self.config, "max_position_embeddings", model_max_length
) )
os.environ["VLLM_DECODE_BLOCK_BUCKET_MAX"] = str(max_total_blocks)
self.bucketing_ctx = HPUBucketingContext( self.bucketing_ctx = HPUBucketingContext(
max_num_seqs, max_num_seqs,
os.getenv("PREFILL_MAX_BS", 64), # self.max_num_prefill_seqs, #TODO max_num_seqs, # self.max_num_prefill_seqs, #TODO
BLOCK_SIZE, BLOCK_SIZE,
num_blocks * BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned,
False, False,
min(model_max_length, max_position_embeddings),
max_input_tokens,
max_total_tokens_aligned,
) )
self.bucketing_ctx.num_hpu_blocks = num_blocks max_blocks = (
max(BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE) + 1
)
self.bucketing_ctx.num_hpu_blocks = min(max_blocks, num_blocks)
if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true": if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
self.bucketing_ctx.generate_prompt_buckets()
self.bucketing_ctx.generate_decode_buckets(
self.bucketing_ctx.num_hpu_blocks
)
logger.info("skip warmup hpu graph, not recommmended") logger.info("skip warmup hpu graph, not recommmended")
del _batch, batch del _batch, batch
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
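The bucketing context now comes from vllm_hpu_extension via get_bucketing_context() and is fed limits derived from the launcher settings instead of VLLM_* environment variables. A sketch of just the arithmetic used above:

import math

BLOCK_SIZE = 128  # illustrative

def derive_bucketing_limits(max_total_tokens, max_num_seqs, num_blocks):
    # Align the per-sequence token budget to the block size, then cap the number
    # of HPU blocks the bucketing context is allowed to plan for.
    max_total_tokens_aligned = math.ceil(max_total_tokens / BLOCK_SIZE) * BLOCK_SIZE
    max_blocks = max(BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE) + 1
    return max_total_tokens_aligned, min(max_blocks, num_blocks)

print(derive_bucketing_limits(max_total_tokens=2048, max_num_seqs=32, num_blocks=4096))
# (2048, 513)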
@ -1591,28 +1628,55 @@ class FlashCausalLM(Model):
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
def bypass_hpu_graphs(self, prefill, max_seq_len_to_capture):
if self.limit_hpu_graph:
return prefill
else:
return prefill and max_seq_len_to_capture > self.max_seq_len_to_capture
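With LIMIT_HPU_GRAPH set, every prefill bypasses HPU graph capture; otherwise only prefills whose flattened length exceeds max_seq_len_to_capture (8192) run eagerly, and decode always uses graphs. A tiny standalone illustration of the policy:

class GraphPolicy:
    # Illustration only; mirrors bypass_hpu_graphs above.
    def __init__(self, limit_hpu_graph: bool, max_seq_len_to_capture: int = 8192):
        self.limit_hpu_graph = limit_hpu_graph
        self.max_seq_len_to_capture = max_seq_len_to_capture

    def bypass_hpu_graphs(self, prefill: bool, seq_len: int) -> bool:
        if self.limit_hpu_graph:
            return prefill
        return prefill and seq_len > self.max_seq_len_to_capture

assert GraphPolicy(True).bypass_hpu_graphs(True, 1024) is True    # prefill always eager
assert GraphPolicy(False).bypass_hpu_graphs(True, 1024) is False  # short prefill captured
assert GraphPolicy(False).bypass_hpu_graphs(True, 16384) is True  # too long to capture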
def warmup_hpu_graph(self, batch): def warmup_hpu_graph(self, batch):
start_time = time.time()
warmup_shape_count = 0
warmup_times = 3 warmup_times = 3
self.bucketing_ctx.generate_prompt_buckets() self.bucketing_ctx.generate_prompt_buckets()
for i, (batch_size, seq_len) in enumerate(
reversed(self.bucketing_ctx.prompt_buckets) def ordering_function_min_tokens(b):
): return (b[0] * b[1], b[1], b[0])
buckets = list(
sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
)
for i, (batch_size, seq_len) in enumerate(buckets):
if batch_size * seq_len > self.max_batch_prefill_tokens:
continue
warmup_shape_count += 1
log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}") log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
for index in range(warmup_times): for index in range(warmup_times):
self.warmup_prefill(seq_len, batch_size, batch) self.warmup_prefill(seq_len, batch_size, batch)
synchronize(self.device)
def ordering_function_max_bs(b):
return (-b[0], b[1])
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks) self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
for i, (batch_size, block_num) in enumerate( buckets = list(
reversed(self.bucketing_ctx.decode_buckets) sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
): )
for i, (batch_size, block_num) in enumerate(buckets):
if batch_size > block_num: if batch_size > block_num:
continue continue
warmup_shape_count += 1
log_master( log_master(
logger.info, f"warmup decode bs {batch_size} block_num {block_num}" logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
) )
for index in range(warmup_times): for index in range(warmup_times):
self.warmup_decode(batch_size, block_num, batch) self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device) synchronize(self.device)
log_master(
logger.info,
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
)
def warmup_prefill( def warmup_prefill(
self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch
@ -1643,7 +1707,9 @@ class FlashCausalLM(Model):
lm_head_indices = input_lengths - 1 lm_head_indices = input_lengths - 1
kwargs = {} kwargs = {}
if htorch.utils.internal.is_lazy(): if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
True, input_ids.shape[0]
)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward( self.model.forward(
@ -1792,8 +1858,8 @@ class FlashCausalLM(Model):
kwargs = {} kwargs = {}
if htorch.utils.internal.is_lazy(): if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = ( kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
batch.prefilling if self.limit_hpu_graphs else False batch.prefilling, input_ids.shape[0]
) )
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
@ -1836,9 +1902,7 @@ class FlashCausalLM(Model):
accepted_ids, accepted_ids,
speculative_ids, speculative_ids,
) = batch.next_token_chooser( ) = batch.next_token_chooser(
_async_h2d_tensor_copy( batch.all_input_ids_tensor[:, : batch.max_current_length],
batch.all_input_ids_tensor[:, : batch.max_current_length]
),
batch.next_token_logits, batch.next_token_logits,
speculate, speculate,
batch.speculative_ids, batch.speculative_ids,
@ -1852,7 +1916,6 @@ class FlashCausalLM(Model):
accepted_ids, accepted_ids,
) )
if batch.valid_indices is not None: if batch.valid_indices is not None:
next_input_ids = next_input_ids.cpu()
next_token_logprobs = next_token_logprobs.cpu() next_token_logprobs = next_token_logprobs.cpu()
accepted_ids = accepted_ids.cpu() accepted_ids = accepted_ids.cpu()
batch.all_input_ids_tensor = batch.all_input_ids_tensor[ batch.all_input_ids_tensor = batch.all_input_ids_tensor[
@ -1902,7 +1965,6 @@ class FlashCausalLM(Model):
accepted_ids = accepted_ids.cpu() accepted_ids = accepted_ids.cpu()
cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1) cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:]) torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
next_input_ids = next_input_ids.cpu()
if batch.speculative_logits is not None: if batch.speculative_logits is not None:
for i in range(len(batch)): for i in range(len(batch)):
batch.all_input_ids_tensor[ batch.all_input_ids_tensor[
@ -1914,7 +1976,7 @@ class FlashCausalLM(Model):
] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]] ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
else: else:
index = batch.cache_lengths_tensor + batch.input_lengths_tensor index = batch.cache_lengths_tensor + batch.input_lengths_tensor
index = index.to(batch.all_input_ids_tensor) index = index.to(batch.all_input_ids_tensor.device)
batch_idx = torch.arange( batch_idx = torch.arange(
0, 0,
batch.all_input_ids_tensor.shape[0], batch.all_input_ids_tensor.shape[0],
@ -1924,6 +1986,7 @@ class FlashCausalLM(Model):
batch.all_input_ids_tensor.index_put_( batch.all_input_ids_tensor.index_put_(
(batch_idx, index.long()), next_input_ids (batch_idx, index.long()), next_input_ids
) )
next_input_ids = next_input_ids.cpu()
batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1] batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
batch.speculative_ids = speculative_ids batch.speculative_ids = speculative_ids
if batch.position_ids.dim() == 2: if batch.position_ids.dim() == 2:
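With sampling kept on the HPU, the scatter index is moved to the device of all_input_ids_tensor and next_input_ids is only copied back to the host afterwards. A small torch restatement of the non-speculative scatter (shapes and values are illustrative):

import torch

all_input_ids_tensor = torch.zeros(4, 32, dtype=torch.int64)  # (batch, max_total_tokens)
cache_lengths_tensor = torch.tensor([10, 3, 7, 0])
input_lengths_tensor = torch.tensor([1, 1, 1, 1])
next_input_ids = torch.tensor([101, 102, 103, 104])

index = (cache_lengths_tensor + input_lengths_tensor).to(all_input_ids_tensor.device)
batch_idx = torch.arange(0, all_input_ids_tensor.shape[0], dtype=torch.long)
# write each freshly sampled token at position cache_length + input_length of its row
all_input_ids_tensor.index_put_((batch_idx, index.long()), next_input_ids)
next_input_ids = next_input_ids.cpu()  # host copy happens only after the scatter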

View File

@ -23,6 +23,7 @@ from text_generation_server.layers.attention import (
_async_h2d_tensor_copy, _async_h2d_tensor_copy,
) )
import habana_frameworks.torch as htorch import habana_frameworks.torch as htorch
import time
from text_generation_server.utils.import_utils import ( from text_generation_server.utils.import_utils import (
synchronize, synchronize,
) )
@ -486,20 +487,32 @@ class FlashVlmCausalLM(FlashCausalLM):
) )
def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch): def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):
start_time = time.time()
warmup_shape_count = 0
warmup_times = 3 warmup_times = 3
# only warmup decode; for prefill, the image pixel size may change, making the warmup useless # only warmup decode; for prefill, the image pixel size may change, making the warmup useless
def ordering_function_max_bs(b):
return (-b[0], b[1])
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks) self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
for i, (batch_size, block_num) in enumerate( buckets = list(
reversed(self.bucketing_ctx.decode_buckets) sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
): )
for i, (batch_size, block_num) in enumerate(buckets):
if batch_size > block_num: if batch_size > block_num:
continue continue
warmup_shape_count += 1
log_master( log_master(
logger.info, f"warmup decode bs {batch_size} block_num {block_num}" logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
) )
for index in range(warmup_times): for index in range(warmup_times):
self.warmup_decode(batch_size, block_num, batch) self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device) synchronize(self.device)
log_master(
logger.info,
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
)
def forward( def forward(
self, self,

View File

@ -32,6 +32,7 @@ from text_generation_server.utils.import_utils import (
) )
import torch.nn.functional as F import torch.nn.functional as F
from text_generation_server.utils.log import log_master from text_generation_server.utils.log import log_master
import time
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -325,7 +326,9 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
) )
kwargs = {} kwargs = {}
if htorch.utils.internal.is_lazy(): if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
True, input_ids.shape[0]
)
self.model.forward( self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids), input_ids=_async_h2d_tensor_copy(input_ids),
position_ids=_async_h2d_tensor_copy(position_ids), position_ids=_async_h2d_tensor_copy(position_ids),
@ -343,26 +346,47 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
) )
def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch): def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch):
start_time = time.time()
warmup_shape_count = 0
warmup_times = 3 warmup_times = 3
self.bucketing_ctx.generate_prompt_buckets() self.bucketing_ctx.generate_prompt_buckets()
for i, (batch_size, seq_len) in enumerate(
reversed(self.bucketing_ctx.prompt_buckets) def ordering_function_min_tokens(b):
): return (b[0] * b[1], b[1], b[0])
buckets = list(
sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
)
for i, (batch_size, seq_len) in enumerate(buckets):
if batch_size * seq_len > self.max_batch_prefill_tokens:
continue
warmup_shape_count += 1
log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}") log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
for index in range(warmup_times): for index in range(warmup_times):
self.warmup_prefill(seq_len, batch_size, batch) self.warmup_prefill(seq_len, batch_size, batch)
synchronize(self.device)
def ordering_function_max_bs(b):
return (-b[0], b[1])
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks) self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
for i, (batch_size, block_num) in enumerate( buckets = list(
reversed(self.bucketing_ctx.decode_buckets) sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
): )
for i, (batch_size, block_num) in enumerate(buckets):
if batch_size > block_num: if batch_size > block_num:
continue continue
warmup_shape_count += 1
log_master( log_master(
logger.info, f"warmup decode bs {batch_size} block_num {block_num}" logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
) )
for index in range(warmup_times): for index in range(warmup_times):
self.warmup_decode(batch_size, block_num, batch) self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device) synchronize(self.device)
log_master(
logger.info,
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
)
def forward( def forward(
self, self,
@ -438,8 +462,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
kwargs = {} kwargs = {}
if htorch.utils.internal.is_lazy(): if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = ( kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
batch.prefilling if self.limit_hpu_graphs else False batch.prefilling, input_ids.shape[0]
) )
if batch.prefill_cache_indices is not None: if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids) slots_pad = torch.zeros_like(input_ids)

View File

@ -206,6 +206,7 @@ def serve(
quantize: Optional[str], quantize: Optional[str],
speculate: Optional[int], speculate: Optional[int],
dtype: Optional[str], dtype: Optional[str],
kv_cache_dtype: Optional[str],
trust_remote_code: bool, trust_remote_code: bool,
uds_path: Path, uds_path: Path,
max_input_tokens: int, max_input_tokens: int,
@ -218,6 +219,7 @@ def serve(
quantize: Optional[str] = None, quantize: Optional[str] = None,
speculate: Optional[int] = None, speculate: Optional[int] = None,
dtype: Optional[str] = None, dtype: Optional[str] = None,
kv_cache_dtype: Optional[str] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
if not is_driver_compatible(): if not is_driver_compatible():
@ -261,6 +263,7 @@ def serve(
quantize, quantize,
speculate, speculate,
data_type, data_type,
kv_cache_dtype,
trust_remote_code, trust_remote_code,
max_input_tokens, max_input_tokens,
adapter_to_index, adapter_to_index,
@ -308,6 +311,7 @@ def serve(
quantize, quantize,
speculate, speculate,
dtype, dtype,
kv_cache_dtype,
trust_remote_code, trust_remote_code,
) )
) )

View File

@ -7,7 +7,7 @@ from loguru import logger
# Tensor Parallelism settings # Tensor Parallelism settings
RANK = int(os.getenv("RANK", "0")) RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.8")) MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.9"))
class FakeBarrier: class FakeBarrier:

View File

@ -1,17 +1,19 @@
import torch import torch
from loguru import logger from loguru import logger
import habana_frameworks.torch as htorch
import os
def get_hpu_free_memory(device, memory_fraction): def get_hpu_free_memory(device, memory_fraction):
from habana_frameworks.torch.hpu import memory_stats graph_reserved_mem = (
float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1"))
device_id = device.index if htorch.utils.internal.is_lazy()
mem_stats = memory_stats(device_id) else 0
logger.info(f"mem_stats: {mem_stats}")
total_free_memory = mem_stats["Limit"] - mem_stats["MaxInUse"]
free_memory = max(
0, int(total_free_memory - (1 - memory_fraction) * mem_stats["Limit"])
) )
free_memory = int(
torch.hpu.mem_get_info()[0] * memory_fraction * (1 - graph_reserved_mem)
)
logger.info(f"Free memory on device {device}: {free_memory} bytes.")
return free_memory return free_memory
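The free-memory estimate now queries torch.hpu.mem_get_info directly and, in lazy mode, holds back a TGI_GRAPH_RESERVED_MEM fraction (default 0.1) for HPU graph capture. A sketch, assuming mem_get_info follows the same (free, total) convention as its CUDA counterpart:

import os
import torch
import habana_frameworks.torch as htorch

def hpu_free_memory(memory_fraction: float) -> int:
    graph_reserved_mem = (
        float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1"))
        if htorch.utils.internal.is_lazy()
        else 0
    )
    free_bytes, _total_bytes = torch.hpu.mem_get_info()
    # keep memory_fraction of what is free, minus the slice reserved for graphs
    return int(free_bytes * memory_fraction * (1 - graph_reserved_mem))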

View File

@ -1,7 +1,7 @@
import json import json
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional, List
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from text_generation_server.utils.weights import ( from text_generation_server.utils.weights import (
@ -18,6 +18,8 @@ class _QuantizerConfig:
groupsize: int groupsize: int
quant_method: str quant_method: str
sym: bool sym: bool
weight_block_size: Optional[List[int]]
modules_to_not_convert: List[str]
@dataclass @dataclass
@ -25,7 +27,20 @@ class _FP8QuantizerConfig:
activation_scale_ub: float activation_scale_ub: float
# We should probably do this with Pytantic JSON deserialization, def _get_config_json(model_id: str, revision: Optional[str], filename: str):
if os.path.exists(
os.path.join(
model_id,
)
):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(model_id, filename=filename, revision=revision)
with open(filename, "r") as f:
return json.load(f)
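The three near-identical local-path/hub-download blocks below collapse into this helper: it prefers a file inside a local model directory and otherwise pulls it from the Hub. A hedged usage sketch (the model path and revision are examples):

# local checkout: reads ./my-model/config.json; otherwise downloads it from the Hub
data = _get_config_json("my-model", None, "config.json")
quant_cfg = data.get("quantization_config", {})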
# We should probably do this with Pydantic JSON deserialization,
# but for now we'll stay close to the old _set_gptq_params. # but for now we'll stay close to the old _set_gptq_params.
def _get_quantizer_config(model_id, revision): def _get_quantizer_config(model_id, revision):
bits = 4 bits = 4
@ -34,21 +49,18 @@ def _get_quantizer_config(model_id, revision):
checkpoint_format = None checkpoint_format = None
sym = False sym = False
desc_act = False desc_act = False
weight_block_size = None
modules_to_not_convert = []
filename = "config.json" filename = "config.json"
try: try:
if os.path.exists(os.path.join(model_id, filename)): data = _get_config_json(model_id, revision, filename)
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(model_id, filename=filename, revision=revision)
with open(filename, "r") as f:
data = json.load(f)
# FP8 config # FP8 config
if data["quantization_config"]["quant_method"] == "fbgemm_fp8": if data["quantization_config"]["quant_method"] == "fbgemm_fp8":
return _FP8QuantizerConfig( return _FP8QuantizerConfig(
activation_scale_ub=data["quantization_config"]["activation_scale_ub"] activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
) )
weight_block_size = data["quantization_config"].get("weight_block_size", None)
if "zero_point" in data["quantization_config"]: if "zero_point" in data["quantization_config"]:
sym = not data["quantization_config"]["zero_point"] sym = not data["quantization_config"]["zero_point"]
@ -61,18 +73,16 @@ def _get_quantizer_config(model_id, revision):
# Order is important here, desc_act is missing on some real models # Order is important here, desc_act is missing on some real models
quant_method = data["quantization_config"]["quant_method"] quant_method = data["quantization_config"]["quant_method"]
checkpoint_format = data["quantization_config"].get("checkpoint_format") checkpoint_format = data["quantization_config"].get("checkpoint_format")
desc_act = data["quantization_config"]["desc_act"] desc_act = data["quantization_config"].get("desc_act", False)
modules_to_not_convert = data["quantization_config"].get(
"modules_to_not_convert", []
)
if modules_to_not_convert is None:
modules_to_not_convert = []
except Exception: except Exception:
filename = "quantize_config.json" filename = "quantize_config.json"
try: try:
if os.path.exists(os.path.join(model_id, filename)): data = _get_config_json(model_id, revision, filename)
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(
model_id, filename=filename, revision=revision
)
with open(filename, "r") as f:
data = json.load(f)
bits = data["bits"] bits = data["bits"]
groupsize = data["group_size"] groupsize = data["group_size"]
@ -88,14 +98,7 @@ def _get_quantizer_config(model_id, revision):
except Exception: except Exception:
filename = "quant_config.json" filename = "quant_config.json"
try: try:
if os.path.exists(os.path.join(model_id, filename)): data = _get_config_json(model_id, revision, filename)
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(
model_id, filename=filename, revision=revision
)
with open(filename, "r") as f:
data = json.load(f)
bits = data["w_bit"] bits = data["w_bit"]
groupsize = data["q_group_size"] groupsize = data["q_group_size"]
desc_act = data["desc_act"] desc_act = data["desc_act"]
@ -111,6 +114,8 @@ def _get_quantizer_config(model_id, revision):
checkpoint_format=checkpoint_format, checkpoint_format=checkpoint_format,
sym=sym, sym=sym,
desc_act=desc_act, desc_act=desc_act,
weight_block_size=weight_block_size,
modules_to_not_convert=modules_to_not_convert,
) )
@ -134,6 +139,7 @@ def get_loader(
quant_method=quantizer_config.quant_method, quant_method=quantizer_config.quant_method,
quantize=quantize, quantize=quantize,
sym=quantizer_config.sym, sym=quantizer_config.sym,
modules_to_not_convert=quantizer_config.modules_to_not_convert,
) )
elif quantize == "fp8" or quantize is None: elif quantize == "fp8" or quantize is None:
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
@ -141,9 +147,14 @@ def get_loader(
# Since the default for the quantize config is _QuantizerConfig, # Since the default for the quantize config is _QuantizerConfig,
# we need to add this check to not get an attribute error # we need to add this check to not get an attribute error
activation_scale_ub = None activation_scale_ub = None
weight_block_size = quantizer_config.weight_block_size
if isinstance(quantizer_config, _FP8QuantizerConfig): if isinstance(quantizer_config, _FP8QuantizerConfig):
activation_scale_ub = quantizer_config.activation_scale_ub activation_scale_ub = quantizer_config.activation_scale_ub
return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8") return HybridFP8UnquantLoader(
activation_scale_ub,
to_fp8=quantize == "fp8",
weight_block_size=weight_block_size,
)
else: else:
raise ValueError(f"Unknown quantization method: {quantize}") raise ValueError(f"Unknown quantization method: {quantize}")

View File

@ -62,6 +62,14 @@ class WeightsLoader(ABC):
""" """
... ...
@abstractmethod
def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
"""
Get the weights at the given prefixes, column-split them for tensor
parallelism, and then concatenate the weights along the given dimension.
"""
...
@abstractmethod @abstractmethod
def get_weights_row(self, weights: "Weights", prefix: str): def get_weights_row(self, weights: "Weights", prefix: str):
""" """
@ -130,6 +138,10 @@ class DefaultWeightsLoader(WeightsLoader):
weights.get_sharded(f"{prefix}.weight", dim=1), weights.get_sharded(f"{prefix}.weight", dim=1),
) )
def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
w = [weights.get_tensor(f"{p}.weight") for p in prefixes]
return self.weight_class(torch.cat(w, dim=dim))
class Weights: class Weights:
def __init__( def __init__(
@ -393,6 +405,9 @@ class Weights:
def get_weights_row(self, prefix: str): def get_weights_row(self, prefix: str):
return self.weights_loader.get_weights_row(self, prefix) return self.weights_loader.get_weights_row(self, prefix)
def get_multi_weights(self, prefixes: List[str], dim: int):
return self.weights_loader.get_multi_weights(self, prefixes, dim)
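get_multi_weights loads the tensors for a list of prefixes (the default loader reads them whole) and concatenates them along one dimension, e.g. to build a fused projection in a single load. A hedged call-site sketch; the prefixes are hypothetical and `weights` is a Weights instance created by the model loader:

w = weights.get_multi_weights(
    ["model.layers.0.mlp.gate_proj", "model.layers.0.mlp.up_proj"],
    dim=0,
)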
@contextmanager @contextmanager
def use_loader(self, weights_loader: WeightsLoader): def use_loader(self, weights_loader: WeightsLoader):
""" """

View File

@ -8,6 +8,7 @@ use std::cmp::max;
use std::collections::VecDeque; use std::collections::VecDeque;
use text_generation_router::infer::InferError; use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse; use text_generation_router::infer::InferStreamResponse;
use text_generation_router::usage_stats::Env;
use text_generation_router::validation::{ use text_generation_router::validation::{
Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
ValidStoppingParameters, ValidStoppingParameters,
@ -185,6 +186,9 @@ struct State {
/// Paged Attention Block Allocation /// Paged Attention Block Allocation
block_allocator: Option<BlockAllocator>, block_allocator: Option<BlockAllocator>,
/// Indicates whether this is an HPU device; HPU needs the prefill padded to generate the first token.
is_hpu_device: bool,
} }
impl State { impl State {
@ -214,6 +218,7 @@ impl State {
speculate, speculate,
support_chunking, support_chunking,
block_allocator, block_allocator,
is_hpu_device: Env::new().is_hpu_device(),
} }
} }
@ -368,6 +373,21 @@ impl State {
} }
} }
if self.is_hpu_device {
// HPU needs to pad the prefill, so the budget is checked against the longest input in the batch
max_input_length = max_input_length.max(entry.request.input_length);
let actual_prefill_tokens_for_hpu =
(batch.len() + 1) as u32 * max_input_length;
if actual_prefill_tokens_for_hpu > prefill_token_budget {
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: prefill_tokens={actual_prefill_tokens_for_hpu} > {prefill_token_budget}");
self.entries.push_front((id, entry));
break 'entry_loop;
}
}
prefill_tokens += postfix_len; prefill_tokens += postfix_len;
Some(block_allocation) Some(block_allocation)
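On HPU every prefill in a batch is padded to the longest input, so the scheduler compares the padded size, (batch.len() + 1) * max_input_length, against prefill_token_budget before admitting an entry. The same check restated in Python with made-up numbers (the router code above is Rust):

def fits_hpu_prefill_budget(batch_input_lengths, new_input_length, prefill_token_budget):
    # padded prefill: every sequence is padded to the longest input in the batch
    max_input_length = max(batch_input_lengths + [new_input_length])
    padded_tokens = (len(batch_input_lengths) + 1) * max_input_length
    return padded_tokens <= prefill_token_budget

print(fits_hpu_prefill_budget([120, 80], 200, 4096))  # 3 * 200 = 600 <= 4096 -> True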