mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-08 19:04:52 +00:00
Deepseek R1 for Gaudi backend (#3211)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
58934c8b61
commit
d658b5def3
@ -60,6 +60,8 @@ FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytor
|
||||
ENV ATTENTION=default
|
||||
ENV PREFIX_CACHING=0
|
||||
ENV PREFILL_CHUNKING=0
|
||||
ENV PT_HPU_LAZY_MODE=1
|
||||
ENV PT_HPU_WEIGHT_SHARING=0
|
||||
|
||||
# Text Generation Inference base env
|
||||
ENV HF_HOME=/data \
|
||||
@ -95,7 +97,8 @@ RUN cd server && \
|
||||
pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
|
||||
BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
|
||||
pip install . --no-cache-dir
|
||||
RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git
|
||||
RUN pip install git+https://github.com/HabanaAI/vllm-hpu-extension.git@a060794
|
||||
|
||||
# Install benchmarker
|
||||
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||
# Install router
|
||||
|
@ -26,6 +26,11 @@ class Dtype(str, Enum):
|
||||
bloat16 = "bfloat16"
|
||||
|
||||
|
||||
class KVCacheDtype(str, Enum):
|
||||
fp8_e4m3fn = "fp8_e4m3fn"
|
||||
fp8_e5m2 = "fp8_e5m2"
|
||||
|
||||
|
||||
@app.command()
|
||||
def serve(
|
||||
model_id: str,
|
||||
@ -34,6 +39,7 @@ def serve(
|
||||
quantize: Optional[Quantization] = None,
|
||||
speculate: Optional[int] = None,
|
||||
dtype: Optional[Dtype] = None,
|
||||
kv_cache_dtype: Optional[KVCacheDtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
uds_path: Path = "/tmp/text-generation-server",
|
||||
logger_level: str = "INFO",
|
||||
@ -93,7 +99,8 @@ def serve(
|
||||
# Downgrade enum into str for easier management later on
|
||||
quantize = None if quantize is None else quantize.value
|
||||
dtype = "bfloat16" if dtype is None else dtype.value
|
||||
logger.info(f"quantize={quantize}")
|
||||
kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value
|
||||
logger.info(f"quantize={quantize} kv_cache_dtype={kv_cache_dtype}")
|
||||
if dtype is not None and quantize not in {
|
||||
None,
|
||||
"bitsandbytes",
|
||||
@ -175,6 +182,7 @@ def serve(
|
||||
quantize,
|
||||
speculate,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
trust_remote_code,
|
||||
uds_path,
|
||||
max_input_tokens,
|
||||
|
@ -12,6 +12,7 @@ from text_generation_server.layers.speculative import SpeculativeHead
|
||||
# Just to add the `load` methods.
|
||||
from text_generation_server.layers.layernorm import load_layer_norm
|
||||
from text_generation_server.layers.conv import load_conv2d
|
||||
from text_generation_server.layers.fp8 import Fp8Linear
|
||||
|
||||
from text_generation_server.layers.lora import (
|
||||
LoraLinear,
|
||||
@ -27,6 +28,7 @@ __all__ = [
|
||||
"TensorParallelEmbedding",
|
||||
"SpeculativeHead",
|
||||
"LoraLinear",
|
||||
"Fp8Linear",
|
||||
"TensorParallelMultiAdapterLinear",
|
||||
"TensorParallelAdapterRowLinear",
|
||||
"load_layer_norm",
|
||||
|
@ -10,18 +10,21 @@ from .hpu import (
|
||||
SUPPORTS_WINDOWING,
|
||||
attention,
|
||||
paged_attention,
|
||||
paged_attention_mla,
|
||||
)
|
||||
|
||||
|
||||
# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
|
||||
from .kv_cache import KVCache, get_kv_scales
|
||||
from .kv_cache import KVCache, get_kv_scales, KVCompressCache
|
||||
|
||||
__all__ = [
|
||||
"attention",
|
||||
"get_kv_scales",
|
||||
"paged_attention",
|
||||
"paged_attention_mla",
|
||||
"SUPPORTS_WINDOWING",
|
||||
"KVCache",
|
||||
"KVCompressCache",
|
||||
"Seqlen",
|
||||
"HPUPagedAttentionMetadata",
|
||||
"trim_seqlen_metadata",
|
||||
|
@ -11,11 +11,61 @@ import os
|
||||
SUPPORTS_WINDOWING = False
|
||||
|
||||
|
||||
def fetch_from_cache(cache, blocks):
|
||||
if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
|
||||
return cache[: blocks.size(0)]
|
||||
else:
|
||||
return cache.index_select(0, blocks)
|
||||
class FP8Matmul(torch.nn.Module):
|
||||
|
||||
def __init__(self, scale_other):
|
||||
super().__init__()
|
||||
self.scale_input = torch.tensor(1.0, dtype=torch.bfloat16, device="hpu")
|
||||
self.scale_other = scale_other
|
||||
|
||||
def quant_input(self, x, scale):
|
||||
return torch.ops.hpu.cast_to_fp8_v2(
|
||||
x, scale, False, False, torch.float8_e4m3fn
|
||||
)[0]
|
||||
|
||||
def matmul_fp8(
|
||||
self, x, other, out_dtype, scale_input_inv=None, scale_other_inv=None
|
||||
):
|
||||
return torch.ops.hpu.fp8_gemm_v2(
|
||||
A=x,
|
||||
trans_A=False,
|
||||
B=other,
|
||||
trans_B=False,
|
||||
D=None,
|
||||
out_dtype=out_dtype,
|
||||
A_scale_inv=scale_input_inv,
|
||||
B_scale_inv=scale_other_inv,
|
||||
bias=None,
|
||||
accumulate=False,
|
||||
)
|
||||
|
||||
def forward(self, input, other):
|
||||
qinput = self.quant_input(input, self.scale_input)
|
||||
qother = self.quant_input(other, self.scale_other)
|
||||
output = self.matmul_fp8(
|
||||
qinput,
|
||||
qother,
|
||||
out_dtype=torch.bfloat16,
|
||||
scale_input_inv=1.0 / self.scale_input,
|
||||
scale_other_inv=1.0 / self.scale_other,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
class FetchFromCache(torch.nn.Module):
|
||||
|
||||
def __init__(self, scale_inv):
|
||||
super().__init__()
|
||||
self.scale_inv = scale_inv
|
||||
|
||||
def forward(self, cache, blocks):
|
||||
if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
|
||||
out = cache[: blocks.size(0)]
|
||||
else:
|
||||
out = cache.index_select(0, blocks)
|
||||
if out.dtype == torch.float8_e4m3fn:
|
||||
out = torch.ops.hpu.cast_from_fp8(out, self.scale_inv, torch.bfloat16)
|
||||
return out
|
||||
|
||||
|
||||
def attention(
|
||||
@ -67,6 +117,7 @@ def paged_attention(
|
||||
hpu_attention_meta: HPUPagedAttentionMetadata,
|
||||
):
|
||||
batch_size, head_num, head_size = query.shape
|
||||
fp8_kv = kv_cache.dtype == torch.float8_e4m3fn
|
||||
output = ops.flat_pa(
|
||||
query=query.view(batch_size, 1, head_num * head_size),
|
||||
key_cache=kv_cache.key,
|
||||
@ -76,19 +127,50 @@ def paged_attention(
|
||||
block_bias=hpu_attention_meta.attn_bias,
|
||||
block_groups=hpu_attention_meta.block_groups,
|
||||
scale=softmax_scale,
|
||||
matmul_qk_op=Matmul(),
|
||||
matmul_av_op=Matmul(),
|
||||
matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
|
||||
matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
|
||||
batch2block_matmul_op=Matmul(),
|
||||
block2batch_matmul_op=Matmul(),
|
||||
keys_fetch_func=fetch_from_cache,
|
||||
values_fetch_func=fetch_from_cache,
|
||||
keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
|
||||
values_fetch_func=FetchFromCache(1.0 / kv_scales.value_scale_cpu),
|
||||
)
|
||||
# Reshape the output tensor.
|
||||
return output.view(batch_size, head_num, head_size)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"SUPPORTS_WINDOWING",
|
||||
"attention",
|
||||
"paged_attention",
|
||||
]
|
||||
def paged_attention_mla(
|
||||
query: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
kv_head_mapping: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
seqlen: Seqlen,
|
||||
*,
|
||||
kv_scales: KVScales,
|
||||
softcap: Optional[float] = None,
|
||||
hpu_attention_meta: HPUPagedAttentionMetadata,
|
||||
kv_lora_rank: int = 0,
|
||||
):
|
||||
batch_size, head_num, head_size = query.shape
|
||||
fp8_kv = kv_cache.dtype == torch.float8_e4m3fn
|
||||
output = ops.flat_pa_mla(
|
||||
query=query,
|
||||
key_cache=kv_cache.key,
|
||||
value_cache=None,
|
||||
block_list=hpu_attention_meta.block_list,
|
||||
block_mapping=hpu_attention_meta.block_mapping,
|
||||
block_bias=hpu_attention_meta.attn_bias,
|
||||
block_groups=hpu_attention_meta.block_groups,
|
||||
scale=softmax_scale,
|
||||
matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
|
||||
matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
|
||||
batch2block_matmul_op=Matmul(),
|
||||
block2batch_matmul_op=Matmul(),
|
||||
keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
|
||||
values_fetch_func=None,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
)
|
||||
# Reshape the output tensor.
|
||||
return output.view(batch_size, head_num, -1)
|
||||
|
||||
|
||||
__all__ = ["SUPPORTS_WINDOWING", "attention", "paged_attention", "paged_attention_mla"]
|
||||
|
@ -50,6 +50,8 @@ class KVCache:
|
||||
):
|
||||
"""Construct the key-value cache for a layer."""
|
||||
## TODO FP8 kv cache support
|
||||
if dtype is torch.float8_e5m2:
|
||||
raise ValueError("torch.float8_e5m2 is not supported in hpu. ")
|
||||
|
||||
self.kv_cache = (
|
||||
torch.zeros(
|
||||
@ -101,22 +103,92 @@ class KVCache:
|
||||
key_cache,
|
||||
value_cache,
|
||||
slots,
|
||||
kv_scales.key_scale_cpu,
|
||||
kv_scales.value_scale_cpu,
|
||||
kv_scales.key_scale,
|
||||
kv_scales.value_scale,
|
||||
)
|
||||
|
||||
|
||||
class KVCompressCache(KVCache):
|
||||
"""
|
||||
Key-value cache for attention layers.
|
||||
"""
|
||||
|
||||
kv_cache: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
num_blocks: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
):
|
||||
"""Construct the key-value cache for a layer."""
|
||||
## TODO FP8 kv cache support
|
||||
if dtype is torch.float8_e5m2:
|
||||
raise ValueError("torch.float8_e5m2 is not supported in hpu. ")
|
||||
|
||||
self.kv_cache = torch.zeros(
|
||||
(num_blocks, BLOCK_SIZE, 1, head_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
"""Get the data type of the cache."""
|
||||
return self.kv_cache.dtype
|
||||
|
||||
@property
|
||||
def key(self):
|
||||
"""Get the key cache."""
|
||||
|
||||
return self.kv_cache
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
"""Get the value cache."""
|
||||
|
||||
return self.kv_cache
|
||||
|
||||
def store(
|
||||
self,
|
||||
*,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
kv_scales: KVScales,
|
||||
):
|
||||
"""Store the key and value at the given slots."""
|
||||
## TODO FP8 kv cache support
|
||||
|
||||
block_idx = slots // BLOCK_SIZE
|
||||
block_offset = slots % BLOCK_SIZE
|
||||
if self.kv_cache.dtype == torch.float8_e4m3fn:
|
||||
key = torch.ops.hpu.cast_to_fp8_v2(
|
||||
key, kv_scales.key_scale, False, False, torch.float8_e4m3fn
|
||||
)[0]
|
||||
cache_ops.insert_or_update_cache(key, self.kv_cache, block_idx, block_offset)
|
||||
|
||||
|
||||
def paged_reshape_and_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
):
|
||||
block_idx = slots // BLOCK_SIZE
|
||||
block_offset = slots % BLOCK_SIZE
|
||||
if key_cache.dtype == torch.float8_e4m3fn:
|
||||
key = torch.ops.hpu.cast_to_fp8_v2(
|
||||
key, k_scale, False, False, torch.float8_e4m3fn
|
||||
)[0]
|
||||
value = torch.ops.hpu.cast_to_fp8_v2(
|
||||
value, v_scale, False, False, torch.float8_e4m3fn
|
||||
)[0]
|
||||
cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
|
||||
cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
|
||||
|
||||
|
@ -12,11 +12,151 @@ from text_generation_server.utils.weights import (
|
||||
|
||||
from vllm_hpu_extension.ops import scaled_fp8_quant
|
||||
from vllm_hpu_extension.scales import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2
|
||||
import habana_frameworks.torch.utils.experimental as htexp
|
||||
|
||||
w8a8_block_fp8_matmul = None
|
||||
per_token_group_quant_fp8 = None
|
||||
quant_dtype: torch.dtype = torch.float8_e4m3fn
|
||||
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
if is_hpu_gaudi2():
|
||||
FP8_MAX = torch.finfo(torch.float8_e4m3fnuz).max
|
||||
|
||||
|
||||
def pad_weight(weight, block_size):
|
||||
"""Pads a matrix to make its dimensions multiples of block_size."""
|
||||
M, N = weight.shape[-2:]
|
||||
block_size_m, block_size_n = block_size
|
||||
pad_M = (block_size_m - M % block_size_m) % block_size_m
|
||||
pad_N = (block_size_n - N % block_size_n) % block_size_n
|
||||
|
||||
if pad_M == 0 and pad_N == 0:
|
||||
return weight, M, N # No padding needed
|
||||
padded_weight = torch.nn.functional.pad(
|
||||
weight, (0, pad_N, 0, pad_M), mode="constant", value=0
|
||||
)
|
||||
return padded_weight, M, N # Return original dimensions for unpadding
|
||||
|
||||
|
||||
def unpad_weight(weight, original_M, original_N, keep_first_dim=False):
|
||||
"""Removes padding from the matrix to restore its original shape."""
|
||||
if (weight.shape[-2] == original_M) and (weight.shape[-1] == original_N):
|
||||
return weight
|
||||
if keep_first_dim:
|
||||
return weight[:, :original_M, :original_N]
|
||||
else:
|
||||
return weight[:original_M, :original_N]
|
||||
|
||||
|
||||
def pad_block_fp8_weight_naive(weight, weight_scale, block_size):
|
||||
|
||||
assert len(block_size) == 2
|
||||
|
||||
block_size_m, block_size_n = block_size
|
||||
weight_scale_m, weight_scale_n = weight_scale.shape[-2:]
|
||||
|
||||
weight, orig_M, orig_N = pad_weight(weight, block_size)
|
||||
M, N = weight.shape[-2:]
|
||||
|
||||
assert weight_scale_m == M // block_size_m
|
||||
assert weight_scale_n == N // block_size_n
|
||||
|
||||
return weight, orig_M, orig_N
|
||||
|
||||
|
||||
def dynamic_quant(data, single_scale=False):
|
||||
if single_scale:
|
||||
scale = ((torch.abs(data)).max() + 1e-8) / FP8_MAX
|
||||
else:
|
||||
scale = ((torch.abs(data)).max(dim=-1).values + 1e-8) / FP8_MAX
|
||||
scale = scale.unsqueeze(-1)
|
||||
data_fp8 = torch.ops.hpu.cast_to_fp8_v2(
|
||||
data, 1.0 / scale, False, False, torch.float8_e4m3fn
|
||||
)[0]
|
||||
return data_fp8, scale.float()
|
||||
|
||||
|
||||
def dequant_block_fp8_weight_naive(
|
||||
weight,
|
||||
weight_scale,
|
||||
block_size,
|
||||
dtype=torch.bfloat16,
|
||||
original_M=None,
|
||||
original_N=None,
|
||||
do_unpad=False,
|
||||
):
|
||||
if weight_scale is None:
|
||||
return weight
|
||||
assert len(block_size) == 2
|
||||
|
||||
weight_shape_len = len(weight.shape)
|
||||
|
||||
block_size_m, block_size_n = block_size
|
||||
|
||||
# mul scale
|
||||
if weight_shape_len == 2:
|
||||
weight_scale_m, weight_scale_n = weight_scale.shape
|
||||
weight_scale = weight_scale.view(weight_scale_m, 1, weight_scale_n, 1)
|
||||
weight = weight.view(weight_scale_m, block_size_m, weight_scale_n, block_size_n)
|
||||
if is_hpu_gaudi2():
|
||||
fake_weight = weight.cpu().to(dtype).to(weight.device)
|
||||
dequant_weight = fake_weight * weight_scale.to(dtype)
|
||||
else:
|
||||
dequant_weight = weight.to(dtype) * weight_scale.to(dtype)
|
||||
dequant_weight = dequant_weight.view(
|
||||
weight_scale_m * block_size_m, weight_scale_n * block_size_n
|
||||
)
|
||||
keep_first_dim = False
|
||||
elif weight_shape_len == 3:
|
||||
fd, weight_scale_m, weight_scale_n = weight_scale.shape
|
||||
weight_scale = weight_scale.view(fd, weight_scale_m, 1, weight_scale_n, 1)
|
||||
weight = weight.view(
|
||||
fd, weight_scale_m, block_size_m, weight_scale_n, block_size_n
|
||||
)
|
||||
if is_hpu_gaudi2():
|
||||
fake_weight = weight.cpu().to(dtype).to(weight.device)
|
||||
dequant_weight = fake_weight * weight_scale.to(dtype)
|
||||
else:
|
||||
dequant_weight = weight.to(dtype) * weight_scale.to(dtype)
|
||||
dequant_weight = dequant_weight.view(
|
||||
fd, weight_scale_m * block_size_m, weight_scale_n * block_size_n
|
||||
)
|
||||
keep_first_dim = True
|
||||
else:
|
||||
raise ValueError("Only support original weight shape is either 2 or 3")
|
||||
|
||||
if do_unpad:
|
||||
dequant_weight = unpad_weight(
|
||||
dequant_weight, original_M, original_N, keep_first_dim=keep_first_dim
|
||||
)
|
||||
|
||||
return dequant_weight
|
||||
|
||||
|
||||
def apply_block_fp8_linear_hpu_dynamic(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# View input as 2D matrix for fp8 methods
|
||||
input_2d = input.view(-1, input.shape[-1])
|
||||
output_shape = [*input.shape[:-1], weight.shape[0]]
|
||||
|
||||
x_fp8, x_scale = dynamic_quant(input_2d)
|
||||
|
||||
output = torch.ops.hpu.fp8_gemm_v2(
|
||||
x_fp8,
|
||||
False,
|
||||
weight,
|
||||
True,
|
||||
None,
|
||||
torch.bfloat16,
|
||||
x_scale,
|
||||
weight_scale,
|
||||
None,
|
||||
False,
|
||||
)
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output.to(dtype=input.dtype).view(*output_shape)
|
||||
|
||||
|
||||
def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
|
||||
@ -42,7 +182,7 @@ def per_tensor_dequantize(
|
||||
) -> torch.Tensor:
|
||||
device = tensor.device
|
||||
dtype = torch.bfloat16
|
||||
if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
|
||||
if is_hpu_gaudi2():
|
||||
# dequant on cpu to avoid nan on gaudi2
|
||||
tensor = tensor.to("cpu")
|
||||
|
||||
@ -269,6 +409,66 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
||||
|
||||
return UnquantizedWeight(w)
|
||||
|
||||
def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
|
||||
# FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
|
||||
w = [weights.get_tensor(f"{p}.weight", to_device=False) for p in prefixes]
|
||||
shapes = [x.shape for x in w]
|
||||
|
||||
# Concat then send to the device
|
||||
w = torch.cat(w, dim=dim).to(weights.device)
|
||||
|
||||
# FP8 branch
|
||||
if w.dtype == torch.float8_e4m3fn:
|
||||
if self.weight_block_size is not None:
|
||||
scale = [
|
||||
weights.get_tensor(f"{p}.weight_scale_inv", to_device=False)
|
||||
for p in prefixes
|
||||
]
|
||||
scale = torch.cat(scale, dim=dim)
|
||||
scale = scale.to(weights.device)
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=scale,
|
||||
activation_scale_ub=self.activation_scale_ub,
|
||||
dtype=weights.dtype,
|
||||
weight_block_size=self.weight_block_size,
|
||||
)
|
||||
|
||||
scale = [
|
||||
weights.get_tensor(f"{p}.weight_scale", to_dtype=False).reshape(-1)
|
||||
for p in prefixes
|
||||
]
|
||||
scale = torch.cat(scale, dim=0).reshape(-1)
|
||||
|
||||
input_scale = [
|
||||
weights.get_tensor(f"{p}.input_scale", to_dtype=False).reshape(-1)
|
||||
for p in prefixes
|
||||
if weights.has_tensor(f"{p}.input_scale")
|
||||
]
|
||||
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
|
||||
input_scale = (
|
||||
torch.cat(input_scale, dim=0).reshape(-1).max()
|
||||
if len(input_scale) != 0
|
||||
else None
|
||||
)
|
||||
|
||||
logical_widths = [x[0] for x in shapes]
|
||||
w, scale = requantize_with_max_scale(
|
||||
w, scale.to(weights.device), logical_widths, weights.dtype
|
||||
)
|
||||
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=scale,
|
||||
input_scale=input_scale,
|
||||
activation_scale_ub=self.activation_scale_ub,
|
||||
dtype=weights.dtype,
|
||||
)
|
||||
if self.to_fp8:
|
||||
return Fp8Weight(weight=w, dtype=weights.dtype)
|
||||
|
||||
return UnquantizedWeight(w)
|
||||
|
||||
def get_weights_row(self, weights: "Weights", prefix: str):
|
||||
w = weights.get_sharded(f"{prefix}.weight", dim=1)
|
||||
# FP8 branch
|
||||
@ -389,6 +589,22 @@ class Fp8Linear(torch.nn.Module):
|
||||
scale_upper_bound = kwargs.get("scale_upper_bound", None)
|
||||
weight_block_size = kwargs.get("weight_block_size", None)
|
||||
|
||||
if weight_block_size is not None:
|
||||
weight, orig_M, orig_N = pad_block_fp8_weight_naive(
|
||||
weight, scale, weight_block_size
|
||||
)
|
||||
weight, scale = dynamic_quant(
|
||||
dequant_block_fp8_weight_naive(
|
||||
weight,
|
||||
scale,
|
||||
weight_block_size,
|
||||
original_M=orig_M,
|
||||
original_N=orig_N,
|
||||
do_unpad=True,
|
||||
)
|
||||
)
|
||||
scale = scale.squeeze(-1)
|
||||
|
||||
return cls(
|
||||
qweight=weight,
|
||||
scale=scale,
|
||||
@ -409,25 +625,10 @@ class Fp8Linear(torch.nn.Module):
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
if self.weight_block_size is not None:
|
||||
# https://arxiv.org/pdf/2412.19437
|
||||
# At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and
|
||||
# scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we
|
||||
# group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output
|
||||
# channels).
|
||||
qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1])
|
||||
output = w8a8_block_fp8_matmul(
|
||||
qinput,
|
||||
self.qweight,
|
||||
scale,
|
||||
self.scale,
|
||||
self.weight_block_size,
|
||||
output_dtype=input.dtype,
|
||||
return apply_block_fp8_linear_hpu_dynamic(
|
||||
input, self.qweight, self.scale, self.input_scale, self.bias
|
||||
)
|
||||
|
||||
if self.bias is not None:
|
||||
output = output + self.bias
|
||||
return output.to(dtype=input.dtype)
|
||||
|
||||
qinput, scale = fp8_quantize(
|
||||
input,
|
||||
self.input_scale,
|
||||
|
@ -4,7 +4,12 @@ from typing import List, Optional, Union
|
||||
import torch
|
||||
from loguru import logger
|
||||
from text_generation_server.utils.log import log_once
|
||||
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
||||
from text_generation_server.utils.weights import (
|
||||
Weight,
|
||||
Weights,
|
||||
WeightsLoader,
|
||||
DefaultWeightsLoader,
|
||||
)
|
||||
|
||||
|
||||
from .hpu import QuantLinear
|
||||
@ -72,6 +77,7 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||
quant_method: str,
|
||||
quantize: str,
|
||||
sym: bool,
|
||||
modules_to_not_convert: List[str],
|
||||
):
|
||||
self.bits = bits
|
||||
self.desc_act = desc_act
|
||||
@ -79,6 +85,12 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||
self.quant_method = quant_method
|
||||
self.quantize = quantize
|
||||
self.sym = sym
|
||||
self.modules_to_not_convert = modules_to_not_convert
|
||||
|
||||
def is_layer_skipped_quantization(
|
||||
self, prefix: str, modules_to_not_convert: List[str]
|
||||
):
|
||||
return any(module_name in prefix for module_name in modules_to_not_convert)
|
||||
|
||||
def get_weights(self, weights: Weights, prefix: str):
|
||||
self._get_gptq_params(weights)
|
||||
@ -91,6 +103,9 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||
log_once(logger.warning, "Disabling exllama because desc_act=True")
|
||||
use_exllama = False
|
||||
|
||||
if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
|
||||
return DefaultWeightsLoader.get_weights(weights, prefix)
|
||||
|
||||
try:
|
||||
qweight = weights.get_tensor(f"{prefix}.qweight")
|
||||
except RuntimeError:
|
||||
@ -145,6 +160,10 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
|
||||
return DefaultWeightsLoader.get_weights_col_packed(
|
||||
weights, prefix, block_sizes
|
||||
)
|
||||
try:
|
||||
qweight = weights.get_packed_sharded(
|
||||
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
||||
@ -196,6 +215,8 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||
)
|
||||
|
||||
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||
if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
|
||||
return DefaultWeightsLoader.get_multi_weights_col(weights, prefixes, dim)
|
||||
try:
|
||||
qweight = torch.cat(
|
||||
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
|
||||
@ -263,6 +284,9 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||
if self.bits != 4:
|
||||
use_exllama = False
|
||||
|
||||
if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
|
||||
return DefaultWeightsLoader.get_weights_row(weights, prefix)
|
||||
|
||||
if self.desc_act:
|
||||
log_once(logger.warning, "Disabling exllama because desc_act=True")
|
||||
use_exllama = False
|
||||
|
@ -53,15 +53,10 @@ class FastRMSNorm(nn.Module):
|
||||
return cls(weight, eps)
|
||||
|
||||
def forward(self, hidden_states, residual=None):
|
||||
from vllm_hpu_extension.kernels import rms_norm
|
||||
|
||||
orig_shape = hidden_states.shape
|
||||
if residual is not None:
|
||||
residual += hidden_states.view(residual.shape)
|
||||
else:
|
||||
residual = hidden_states
|
||||
# Note: HPUFusedRMSNorm requires 3D tensors as inputs
|
||||
if len(orig_shape) == 2:
|
||||
residual = residual.unsqueeze(0)
|
||||
x = rms_norm().apply(residual, self.weight, self.variance_epsilon)
|
||||
return x.view(orig_shape), residual.view(orig_shape)
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(self.weight.dtype), residual
|
||||
|
@ -2,6 +2,7 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import os
|
||||
|
||||
from text_generation_server.utils.weights import Weights
|
||||
from text_generation_server.layers.fp8 import (
|
||||
@ -9,12 +10,11 @@ from text_generation_server.layers.fp8 import (
|
||||
fp8_quantize,
|
||||
quant_dtype,
|
||||
normalize_e4m3fn_to_native_float8,
|
||||
dynamic_quant,
|
||||
dequant_block_fp8_weight_naive,
|
||||
)
|
||||
|
||||
try:
|
||||
from .unquantized import fused_moe
|
||||
except Exception:
|
||||
fused_moe = None
|
||||
from text_generation_server.layers.moe.fused_moe import select_experts
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
class FP8SparseMoELayer(nn.Module):
|
||||
@ -47,6 +47,16 @@ class FP8SparseMoELayer(nn.Module):
|
||||
self.weight_block_size = weights.weights_loader.weight_block_size
|
||||
self.scoring_func = scoring_func
|
||||
self.e_score_correction_bias = e_score_correction_bias
|
||||
self.world_size = weights.process_group.size()
|
||||
self.rank = weights.process_group.rank()
|
||||
self.ep_rank = self.rank
|
||||
self.use_ep = os.getenv("USE_EXPERT_PARALLEL", "true").lower() == "true"
|
||||
|
||||
if self.use_ep:
|
||||
n_experts = (n_experts + self.world_size - 1) // self.world_size
|
||||
self.ep_offset = self.ep_rank * n_experts
|
||||
else:
|
||||
self.ep_offset = 0
|
||||
|
||||
(
|
||||
self.gate_up_proj,
|
||||
@ -58,6 +68,8 @@ class FP8SparseMoELayer(nn.Module):
|
||||
gate_proj_name=gate_proj_name,
|
||||
up_proj_name=up_proj_name,
|
||||
weights=weights,
|
||||
use_ep=self.use_ep,
|
||||
ep_offset=self.ep_offset,
|
||||
)
|
||||
|
||||
self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (
|
||||
@ -66,29 +78,89 @@ class FP8SparseMoELayer(nn.Module):
|
||||
n_experts=n_experts,
|
||||
name=down_proj_name,
|
||||
weights=weights,
|
||||
use_ep=self.use_ep,
|
||||
ep_offset=self.ep_offset,
|
||||
)
|
||||
)
|
||||
if self.weight_block_size is not None:
|
||||
self.gate_up_proj, self.gate_up_proj_weight_scale = dynamic_quant(
|
||||
dequant_block_fp8_weight_naive(
|
||||
self.gate_up_proj,
|
||||
self.gate_up_proj_weight_scale,
|
||||
self.weight_block_size,
|
||||
)
|
||||
)
|
||||
self.down_proj, self.down_proj_weight_scale = dynamic_quant(
|
||||
dequant_block_fp8_weight_naive(
|
||||
self.down_proj, self.down_proj_weight_scale, self.weight_block_size
|
||||
)
|
||||
)
|
||||
self.gate_up_proj_weight_scale, self.down_proj_weight_scale = (
|
||||
self.gate_up_proj_weight_scale.squeeze(-1),
|
||||
self.down_proj_weight_scale.squeeze(-1),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
|
||||
return fused_moe(
|
||||
x,
|
||||
w1=self.gate_up_proj,
|
||||
w2=self.down_proj,
|
||||
gating_output=gating_output,
|
||||
topk=self.topk,
|
||||
renormalize=self.renormalize,
|
||||
inplace=True,
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=gating_output,
|
||||
use_grouped_topk=self.n_expert_group is not None,
|
||||
num_expert_group=self.n_expert_group,
|
||||
top_k=self.topk,
|
||||
renormalize=self.renormalize,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.n_expert_group,
|
||||
scoring_func=self.scoring_func,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=self.gate_up_proj_weight_scale,
|
||||
w2_scale=self.down_proj_weight_scale,
|
||||
a1_scale=self.gate_up_proj_input_scale,
|
||||
a2_scale=self.down_proj_input_scale,
|
||||
)
|
||||
total_num_experts = gating_output.size(-1)
|
||||
x_fp8, x_scale = dynamic_quant(x, single_scale=True)
|
||||
|
||||
if self.use_ep:
|
||||
moe_n_slice = 1
|
||||
n_expert_slice = (
|
||||
total_num_experts + self.world_size - 1
|
||||
) // self.world_size
|
||||
else:
|
||||
moe_n_slice = 1
|
||||
n_expert_slice = (total_num_experts + moe_n_slice - 1) // moe_n_slice
|
||||
for i in range(moe_n_slice):
|
||||
min_expert = i * n_expert_slice
|
||||
max_expert = min((i + 1) * n_expert_slice, total_num_experts)
|
||||
w13_list_slice = [
|
||||
self.gate_up_proj[j, ...] for j in range(min_expert, max_expert)
|
||||
]
|
||||
w2_list_slice = [
|
||||
self.down_proj[j, ...] for j in range(min_expert, max_expert)
|
||||
]
|
||||
w13_weight_scale = [
|
||||
self.gate_up_proj_weight_scale[j, ...]
|
||||
for j in range(min_expert, max_expert)
|
||||
]
|
||||
w2_weight_scale = [
|
||||
self.down_proj_weight_scale[j, ...]
|
||||
for j in range(min_expert, max_expert)
|
||||
]
|
||||
|
||||
current_hidden_states = torch.ops.hpu.mixture_of_experts(
|
||||
hidden_states=x_fp8,
|
||||
expert_routing_table=topk_ids.to(torch.int64),
|
||||
router_weights=topk_weights.to(x.dtype),
|
||||
w12=w13_list_slice,
|
||||
w3=w2_list_slice,
|
||||
d_scale_hidden_states=x_scale,
|
||||
d_scale_w12=w13_weight_scale,
|
||||
d_scale_w3=w2_weight_scale,
|
||||
permuted_weights=True,
|
||||
activation="silu",
|
||||
experts_min=min_expert + self.ep_offset,
|
||||
experts_max=max_expert + self.ep_offset - 1,
|
||||
)
|
||||
htorch.core.mark_step()
|
||||
if i == 0:
|
||||
final_hidden_states = current_hidden_states
|
||||
else:
|
||||
final_hidden_states.add_(current_hidden_states)
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
def _load_expert_weights(
|
||||
@ -98,13 +170,14 @@ def _load_expert_weights(
|
||||
n_experts: int,
|
||||
name: str,
|
||||
weights: Weights,
|
||||
ep_offset: int = 0,
|
||||
) -> torch.Tensor:
|
||||
all_weight = None
|
||||
all_weight_scales = None
|
||||
max_input_scale = None
|
||||
|
||||
for i in range(n_experts):
|
||||
weight = get_weight_fn(prefix, i, name, weights)
|
||||
weight = get_weight_fn(prefix, i + ep_offset, name, weights)
|
||||
|
||||
assert isinstance(weight, Fp8Weight)
|
||||
|
||||
@ -147,14 +220,26 @@ def _load_expert_multi_weights_col(
|
||||
gate_proj_name: str,
|
||||
up_proj_name: str,
|
||||
weights: Weights,
|
||||
use_ep: bool = False,
|
||||
ep_offset: int = 0,
|
||||
) -> torch.Tensor:
|
||||
def get_weight_fn(prefix, i, name, weights):
|
||||
def get_weight_fn_sharded(prefix, i, name, weights):
|
||||
return weights.get_multi_weights_col(
|
||||
[f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
|
||||
)
|
||||
|
||||
def get_weight_fn(prefix, i, name, weights):
|
||||
return weights.get_multi_weights(
|
||||
[f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
|
||||
)
|
||||
|
||||
return _load_expert_weights(
|
||||
get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights
|
||||
get_weight_fn if use_ep else get_weight_fn_sharded,
|
||||
prefix=prefix,
|
||||
n_experts=n_experts,
|
||||
name=None,
|
||||
weights=weights,
|
||||
ep_offset=ep_offset if use_ep else 0,
|
||||
)
|
||||
|
||||
|
||||
@ -164,10 +249,20 @@ def _load_expert_weights_row(
|
||||
n_experts: int,
|
||||
name: str,
|
||||
weights: Weights,
|
||||
use_ep: bool = False,
|
||||
ep_offset: int = 0,
|
||||
) -> torch.Tensor:
|
||||
def get_weight_fn(prefix, i, name, weights):
|
||||
def get_weight_fn_sharded(prefix, i, name, weights):
|
||||
return weights.get_weights_row(f"{prefix}.{i}.{name}")
|
||||
|
||||
def get_weight_fn(prefix, i, name, weights):
|
||||
return weights.get_weights(f"{prefix}.{i}.{name}")
|
||||
|
||||
return _load_expert_weights(
|
||||
get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights
|
||||
get_weight_fn if use_ep else get_weight_fn_sharded,
|
||||
prefix=prefix,
|
||||
n_experts=n_experts,
|
||||
name=name,
|
||||
weights=weights,
|
||||
ep_offset=ep_offset if use_ep else 0,
|
||||
)
|
||||
|
@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Tuple
|
||||
from typing import Tuple, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -25,12 +25,36 @@ def grouped_topk(
|
||||
renormalize: bool,
|
||||
num_expert_group: int = 0,
|
||||
topk_group: int = 0,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
scores = torch.softmax(gating_output, dim=-1)
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||
|
||||
gating_output = gating_output.float()
|
||||
if e_score_correction_bias is not None:
|
||||
e_score_correction_bias = e_score_correction_bias.float()
|
||||
|
||||
if scoring_func == "softmax":
|
||||
scores = torch.softmax(gating_output, dim=-1)
|
||||
elif scoring_func == "sigmoid":
|
||||
scores = gating_output.sigmoid()
|
||||
else:
|
||||
raise ValueError(f"Unsupported scoring function: {scoring_func}")
|
||||
|
||||
num_token = scores.shape[0]
|
||||
group_scores = (
|
||||
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
|
||||
) # [n, n_group]
|
||||
if e_score_correction_bias is not None:
|
||||
# Store original scores before applying correction bias. We use biased
|
||||
# scores for expert selection but original scores for routing weights
|
||||
original_scores = scores
|
||||
scores = scores + e_score_correction_bias.unsqueeze(0)
|
||||
group_scores = (
|
||||
scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
|
||||
)
|
||||
else:
|
||||
group_scores = (
|
||||
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
|
||||
) # [n, n_group]
|
||||
|
||||
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
|
||||
1
|
||||
] # [n, top_k_group]
|
||||
@ -41,13 +65,19 @@ def grouped_topk(
|
||||
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
|
||||
.reshape(num_token, -1)
|
||||
) # [n, e]
|
||||
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
||||
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
|
||||
tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e]
|
||||
|
||||
if e_score_correction_bias is not None:
|
||||
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
|
||||
# Use original unbiased scores for the routing weights
|
||||
topk_weights = original_scores.gather(1, topk_ids)
|
||||
else:
|
||||
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
|
||||
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
||||
|
||||
|
||||
def fused_topk(
|
||||
@ -63,3 +93,39 @@ def fused_topk(
|
||||
if renormalize:
|
||||
topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
def select_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
):
|
||||
|
||||
# DeekSeekv2 uses grouped_top_k
|
||||
if use_grouped_topk:
|
||||
assert topk_group is not None
|
||||
assert num_expert_group is not None
|
||||
topk_weights, topk_ids = grouped_topk(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = fused_topk(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize,
|
||||
)
|
||||
return topk_weights, topk_ids
|
||||
|
@ -4,7 +4,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from text_generation_server.utils.weights import UnquantizedWeight, Weights
|
||||
from vllm_hpu_extension.ops import DynamicFusedMOE
|
||||
from vllm_hpu_extension.ops import VllmMixtureOfExpertsOp
|
||||
import habana_frameworks.torch as htorch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class UnquantizedSparseMoELayer(nn.Module):
|
||||
@ -53,13 +55,29 @@ class UnquantizedSparseMoELayer(nn.Module):
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.hpu_fused_moe = DynamicFusedMOE(n_experts)
|
||||
self.MoeOp = VllmMixtureOfExpertsOp(n_experts, 0, n_experts - 1)
|
||||
for i in range(n_experts):
|
||||
self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
|
||||
self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.down_proj[i])
|
||||
self.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
|
||||
self.MoeOp.w2_list[i].set_weight(self.down_proj[i])
|
||||
|
||||
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
|
||||
return self.hpu_fused_moe(x, gating_output, self.topk)
|
||||
htorch.core.mark_step()
|
||||
routing_weights = F.softmax(gating_output, dim=1, dtype=torch.float32)
|
||||
routing_weights, selected_experts = torch.topk(
|
||||
routing_weights, self.topk, dim=-1
|
||||
)
|
||||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||
routing_weights = routing_weights.to(x.dtype)
|
||||
|
||||
final_hidden_states = self.MoeOp(
|
||||
hidden_states=x,
|
||||
expert_routing_table=selected_experts,
|
||||
router_weights=routing_weights,
|
||||
permuted_weights=True,
|
||||
activation="silu",
|
||||
)
|
||||
|
||||
return final_hidden_states.view(-1, x.shape[1])
|
||||
|
||||
|
||||
def _load_expert_multi_weights_col(
|
||||
|
@ -470,9 +470,6 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
|
||||
mscale_all_dim: float,
|
||||
):
|
||||
inv_freq = _create_inv_freq(dim, base, device)
|
||||
super().__init__(
|
||||
inv_freq, scaling_factor, max_position_embeddings * self.scaling_factor
|
||||
)
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
@ -487,6 +484,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
|
||||
/ get_mscale(self.scaling_factor, mscale_all_dim)
|
||||
* self.attn_factor
|
||||
) # Get n-d magnitude scaling corrected for interpolation
|
||||
super().__init__(inv_freq, scaling_factor, max_position_embeddings)
|
||||
|
||||
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||
# Reset the tables if the sequence length has changed,
|
||||
|
@ -343,6 +343,7 @@ def get_model(
|
||||
quantize: Optional[str],
|
||||
speculate: Optional[int],
|
||||
dtype: Optional[torch.dtype],
|
||||
kv_cache_dtype: Optional[str],
|
||||
trust_remote_code: bool,
|
||||
max_input_tokens: int,
|
||||
) -> Model:
|
||||
@ -468,7 +469,12 @@ def get_model(
|
||||
|
||||
model_type = config_dict["model_type"]
|
||||
|
||||
kv_cache_dtype = dtype
|
||||
if kv_cache_dtype == "fp8_e4m3fn":
|
||||
kv_cache_dtype = torch.float8_e4m3fn
|
||||
elif kv_cache_dtype == "fp8_e5m2":
|
||||
kv_cache_dtype = torch.float8_e5m2
|
||||
else:
|
||||
kv_cache_dtype = dtype
|
||||
|
||||
if FLASH_ATTENTION:
|
||||
if model_type == DEEPSEEK_V2:
|
||||
@ -934,6 +940,7 @@ def get_model_with_lora_adapters(
|
||||
quantize: Optional[str],
|
||||
speculate: Optional[int],
|
||||
dtype: Optional[torch.dtype],
|
||||
kv_cache_dtype: Optional[str],
|
||||
trust_remote_code: bool,
|
||||
max_input_tokens: int,
|
||||
adapter_to_index: Dict[str, int],
|
||||
@ -947,6 +954,7 @@ def get_model_with_lora_adapters(
|
||||
quantize,
|
||||
speculate,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
trust_remote_code,
|
||||
max_input_tokens,
|
||||
)
|
||||
|
@ -51,6 +51,8 @@ from habana_frameworks.torch.hpex.kernels import (
|
||||
apply_rotary_pos_emb,
|
||||
)
|
||||
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
class CohereRotary(PositionRotaryEmbedding):
|
||||
def forward(
|
||||
@ -420,7 +422,9 @@ class FlashCohereModel(torch.nn.Module):
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||
|
||||
residual = None
|
||||
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -433,6 +437,8 @@ class FlashCohereModel(torch.nn.Module):
|
||||
seqlen,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
|
@ -44,6 +44,7 @@ from text_generation_server.layers.layernorm import (
|
||||
FastLayerNorm,
|
||||
)
|
||||
from vllm_hpu_extension.ops import DynamicFusedMOE
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
class DbrxAttentionConfig(PretrainedConfig):
|
||||
@ -682,8 +683,10 @@ class DbrxModel(torch.nn.Module):
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid to index in each layer
|
||||
cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -696,6 +699,8 @@ class DbrxModel(torch.nn.Module):
|
||||
seqlen,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
|
@ -40,6 +40,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm
|
||||
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
|
||||
from text_generation_server.utils.weights import Weights
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
class DeepseekV2Config(PretrainedConfig):
|
||||
@ -575,6 +576,9 @@ class DeepseekV2Model(torch.nn.Module):
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -587,6 +591,8 @@ class DeepseekV2Model(torch.nn.Module):
|
||||
seqlen,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
|
@ -28,11 +28,12 @@ from text_generation_server.layers import (
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
get_linear,
|
||||
Fp8Linear,
|
||||
)
|
||||
from text_generation_server.layers.attention import (
|
||||
Seqlen,
|
||||
attention,
|
||||
paged_attention,
|
||||
paged_attention_mla,
|
||||
HPUPagedAttentionMetadata,
|
||||
)
|
||||
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
|
||||
@ -40,6 +41,19 @@ from text_generation_server.layers.layernorm import FastRMSNorm
|
||||
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
|
||||
from text_generation_server.utils.weights import Weights
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
def get_and_maybe_dequant_weights(layer: torch.nn.Module) -> torch.Tensor:
|
||||
if isinstance(layer, Fp8Linear):
|
||||
eye = torch.eye(
|
||||
layer.qweight.shape[-1], dtype=torch.bfloat16, device=layer.qweight.device
|
||||
)
|
||||
dequant_weights = layer(eye)
|
||||
del eye
|
||||
# standardize to (output, input)
|
||||
return dequant_weights.T
|
||||
return layer.weight
|
||||
|
||||
|
||||
class DeepseekV3Config(PretrainedConfig):
|
||||
@ -249,6 +263,44 @@ class DeepseekV3Attention(torch.nn.Module):
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
).repeat_interleave(self.num_groups)
|
||||
|
||||
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj.linear).T
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads,
|
||||
self.qk_nope_head_dim + self.value_head_size,
|
||||
)
|
||||
W_UK, W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.value_head_size], dim=-1
|
||||
)
|
||||
# Convert from (L, N, V) to (N, L, V)
|
||||
self.W_UV = W_UV.transpose(0, 1)
|
||||
# Convert from (L, N, P) to (N, P, L)
|
||||
self.W_UK_T = W_UK.permute(1, 2, 0)
|
||||
|
||||
def _q_proj_and_k_up_proj(self, x):
|
||||
q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
|
||||
q_nope, q_pe = (
|
||||
q_proj(x)
|
||||
.view(-1, self.num_heads, self.head_size)
|
||||
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
)
|
||||
|
||||
# Convert from (B, N, P) to (N, B, P)
|
||||
q_nope = q_nope.transpose(0, 1)
|
||||
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
|
||||
ql_nope = torch.bmm(q_nope, self.W_UK_T)
|
||||
# Convert from (N, B, L) to (B, N, L)
|
||||
return ql_nope.transpose(0, 1), q_pe
|
||||
|
||||
def _v_up_proj_and_o_proj(self, x):
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
||||
x = torch.bmm(x, self.W_UV)
|
||||
# Convert from (N, B, V) to (B, N * V)
|
||||
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.value_head_size)
|
||||
return self.o_proj(x)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -261,14 +313,9 @@ class DeepseekV3Attention(torch.nn.Module):
|
||||
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||
):
|
||||
if self.q_lora_rank is None:
|
||||
query = self.q_proj(hidden_states)
|
||||
hidden_states_or_q_c = hidden_states
|
||||
else:
|
||||
query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
|
||||
_, query_pe = torch.split(
|
||||
query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
hidden_states_or_q_c = self.q_a_layernorm(self.q_a_proj(hidden_states))[0]
|
||||
|
||||
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
|
||||
compressed_kv, key_pe = torch.split(
|
||||
@ -276,13 +323,18 @@ class DeepseekV3Attention(torch.nn.Module):
|
||||
)
|
||||
|
||||
key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
|
||||
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
|
||||
-1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
|
||||
)
|
||||
kv_c_normed = self.kv_a_layernorm(compressed_kv.contiguous())[0]
|
||||
|
||||
key_nope, value = torch.split(
|
||||
kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
|
||||
)
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
|
||||
query = q_proj(hidden_states_or_q_c)
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
query_nope, query_pe = torch.split(
|
||||
query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
else:
|
||||
query_nope, query_pe = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
|
||||
|
||||
batch_size, heads, head_dim = query_pe.shape
|
||||
query_pe = (
|
||||
@ -297,33 +349,47 @@ class DeepseekV3Attention(torch.nn.Module):
|
||||
.reshape(batch_size, heads, head_dim)
|
||||
)
|
||||
self.rotary_emb(query_pe, key_pe, cos, sin)
|
||||
latent_vec_k = torch.concat(
|
||||
(kv_c_normed, key_pe.view(-1, self.qk_rope_head_dim)), dim=-1
|
||||
)
|
||||
latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)
|
||||
|
||||
query[..., self.qk_nope_head_dim :] = query_pe
|
||||
key = torch.empty_like(query)
|
||||
key[..., : self.qk_nope_head_dim] = key_nope
|
||||
key[..., self.qk_nope_head_dim :] = key_pe
|
||||
|
||||
# We need to pad the heads because Flash Attention does not support
|
||||
# qk and v with different head sizes.
|
||||
query = torch.nn.functional.pad(
|
||||
query, (0, self.head_pad_size - self.head_size), value=0
|
||||
)
|
||||
key = torch.nn.functional.pad(
|
||||
key, (0, self.head_pad_size - self.head_size), value=0
|
||||
)
|
||||
value = torch.nn.functional.pad(
|
||||
value, (0, self.head_pad_size - self.value_head_size), value=0
|
||||
)
|
||||
latent_vec_k = latent_vec_k.unflatten(0, (slots.size(0), -1))
|
||||
|
||||
kv_cache.store(
|
||||
key=key,
|
||||
value=value,
|
||||
key=latent_vec_k,
|
||||
value=None,
|
||||
slots=slots,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
kv = self.kv_b_proj(kv_c_normed).view(
|
||||
-1,
|
||||
self.num_key_value_heads,
|
||||
self.qk_nope_head_dim + self.value_head_size,
|
||||
)
|
||||
|
||||
key_nope, value = torch.split(
|
||||
kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
|
||||
)
|
||||
query[..., self.qk_nope_head_dim :] = query_pe
|
||||
key = torch.empty_like(query)
|
||||
key[..., : self.qk_nope_head_dim] = key_nope
|
||||
key[..., self.qk_nope_head_dim :] = key_pe
|
||||
|
||||
# We need to pad the heads because Flash Attention does not support
|
||||
# qk and v with different head sizes.
|
||||
query = torch.nn.functional.pad(
|
||||
query, (0, self.head_pad_size - self.head_size), value=0
|
||||
)
|
||||
key = torch.nn.functional.pad(
|
||||
key, (0, self.head_pad_size - self.head_size), value=0
|
||||
)
|
||||
value = torch.nn.functional.pad(
|
||||
value, (0, self.head_pad_size - self.value_head_size), value=0
|
||||
)
|
||||
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query=query,
|
||||
@ -334,9 +400,15 @@ class DeepseekV3Attention(torch.nn.Module):
|
||||
seqlen=seqlen,
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
attn_output = attn_output[..., : self.value_head_size]
|
||||
|
||||
return self.o_proj(
|
||||
attn_output.reshape(-1, self.num_heads * self.value_head_size)
|
||||
)
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
# Decode
|
||||
query = torch.cat([query_nope, query_pe], dim=-1)
|
||||
attn_output = paged_attention_mla(
|
||||
query,
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
@ -344,14 +416,10 @@ class DeepseekV3Attention(torch.nn.Module):
|
||||
seqlen,
|
||||
kv_scales=self.kv_scales,
|
||||
hpu_attention_meta=hpu_attention_meta,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
)
|
||||
|
||||
# Remove padding.
|
||||
attn_output = attn_output[..., : self.value_head_size]
|
||||
|
||||
return self.o_proj(
|
||||
attn_output.reshape(-1, self.num_heads * self.value_head_size)
|
||||
)
|
||||
attn_output = self._v_up_proj_and_o_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
class DeepseekV3MLP(nn.Module):
|
||||
@ -584,6 +652,9 @@ class DeepseekV3Model(torch.nn.Module):
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -596,6 +667,8 @@ class DeepseekV3Model(torch.nn.Module):
|
||||
seqlen,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
|
@ -46,6 +46,7 @@ from text_generation_server.layers.layernorm import (
|
||||
FastRMSNorm,
|
||||
)
|
||||
from text_generation_server.utils.weights import UnquantizedWeight
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
class Gemma2Config(PretrainedConfig):
|
||||
@ -472,6 +473,10 @@ class FlashGemma2Model(torch.nn.Module):
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -485,6 +490,8 @@ class FlashGemma2Model(torch.nn.Module):
|
||||
adapter_data,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
|
@ -44,6 +44,7 @@ from text_generation_server.layers.layernorm import (
|
||||
FastRMSNorm,
|
||||
)
|
||||
from text_generation_server.utils.weights import UnquantizedWeight
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
class GemmaConfig(PretrainedConfig):
|
||||
@ -394,6 +395,9 @@ class FlashGemmaModel(torch.nn.Module):
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -406,6 +410,8 @@ class FlashGemmaModel(torch.nn.Module):
|
||||
seqlen,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
|
@ -38,6 +38,7 @@ from text_generation_server.layers import (
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
def load_qkv(config, prefix: str, weights, head_size, num_heads):
|
||||
@ -385,6 +386,10 @@ class FlashGPT2Model(torch.nn.Module):
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -395,6 +400,8 @@ class FlashGPT2Model(torch.nn.Module):
|
||||
seqlen,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
|
@ -48,6 +48,7 @@ from habana_frameworks.torch.hpex.kernels import (
|
||||
RotaryPosEmbeddingMode,
|
||||
apply_rotary_pos_emb,
|
||||
)
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
def load_attention(config, prefix: str, weights):
|
||||
@ -330,6 +331,9 @@ class FlashGPTJModel(torch.nn.Module):
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -342,6 +346,8 @@ class FlashGPTJModel(torch.nn.Module):
|
||||
seqlen,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.ln_f(hidden_states, residual)
|
||||
|
||||
|
@ -26,7 +26,7 @@ import torch.distributed
|
||||
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
|
||||
import habana_frameworks.torch as htorch
|
||||
from text_generation_server.layers.attention import (
|
||||
KVCache,
|
||||
get_kv_scales,
|
||||
@ -554,6 +554,9 @@ class FlashLlamaModel(torch.nn.Module):
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -568,6 +571,8 @@ class FlashLlamaModel(torch.nn.Module):
|
||||
cross_attention_states,
|
||||
hpu_attention_meta=hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
|
@ -45,6 +45,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastRMSNorm,
|
||||
)
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
class MistralConfig(PretrainedConfig):
|
||||
@ -401,6 +402,9 @@ class MistralModel(torch.nn.Module):
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -414,6 +418,8 @@ class MistralModel(torch.nn.Module):
|
||||
adapter_data,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
@ -44,6 +44,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm
|
||||
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.utils.weights import UnquantizedWeight
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
class MixtralConfig(PretrainedConfig):
|
||||
@ -452,6 +453,9 @@ class MixtralModel(torch.nn.Module):
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -464,6 +468,8 @@ class MixtralModel(torch.nn.Module):
|
||||
seqlen,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
|
@ -47,6 +47,7 @@ from text_generation_server.layers.rotary import (
|
||||
PositionRotaryEmbedding,
|
||||
)
|
||||
from text_generation_server.utils.weights import UnquantizedWeight
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
class GPTNeoXConfig(TransformersGPTNeoXConfig):
|
||||
@ -360,6 +361,9 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
||||
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(position_ids)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -372,6 +376,8 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
||||
seqlen,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.final_layer_norm(hidden_states, residual)
|
||||
|
||||
|
@ -26,6 +26,7 @@ from text_generation_server.layers.layernorm import (
|
||||
from text_generation_server.layers.rotary import (
|
||||
PositionRotaryEmbedding,
|
||||
)
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
class PhiConfig(PretrainedConfig):
|
||||
@ -353,6 +354,9 @@ class FlashPhiModel(torch.nn.Module):
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -365,6 +369,8 @@ class FlashPhiModel(torch.nn.Module):
|
||||
seqlen,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
|
@ -18,7 +18,6 @@
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
|
@ -22,6 +22,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastRMSNorm,
|
||||
)
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
def load_attention(config, prefix, weights):
|
||||
@ -294,6 +295,9 @@ class Qwen2Model(torch.nn.Module):
|
||||
)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states = layer(
|
||||
hidden_states,
|
||||
@ -306,6 +310,8 @@ class Qwen2Model(torch.nn.Module):
|
||||
seqlen,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states)
|
||||
|
||||
|
@ -21,6 +21,7 @@ from text_generation_server.layers.attention import (
|
||||
Seqlen,
|
||||
HPUPagedAttentionMetadata,
|
||||
)
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
def load_row(config, prefix: str, weights, bias: bool):
|
||||
@ -634,6 +635,9 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
||||
cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(position_ids)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for i, layer in enumerate(self.h):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -646,6 +650,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
||||
seqlen,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.ln_f(hidden_states, residual)
|
||||
|
||||
|
@ -23,6 +23,7 @@ from text_generation_server.layers.gptq import GPTQWeightsLoader
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastLayerNorm,
|
||||
)
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
def load_multi_mqa(
|
||||
@ -442,6 +443,9 @@ class FlashSantacoderModel(nn.Module):
|
||||
torch.distributed.all_reduce(hidden_states, group=self.process_group)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -452,6 +456,8 @@ class FlashSantacoderModel(nn.Module):
|
||||
seqlen,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.ln_f(hidden_states, residual)
|
||||
|
||||
|
@ -50,6 +50,7 @@ from text_generation_server.layers.rotary import (
|
||||
PositionRotaryEmbedding,
|
||||
)
|
||||
from text_generation_server.utils.weights import UnquantizedWeight
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
class Starcoder2Config(PretrainedConfig):
|
||||
@ -517,6 +518,9 @@ class Starcoder2Model(torch.nn.Module):
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||
|
||||
residual = None
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
@ -530,6 +534,8 @@ class Starcoder2Model(torch.nn.Module):
|
||||
adapter_data,
|
||||
hpu_attention_meta,
|
||||
)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
|
@ -53,6 +53,7 @@ from text_generation_server.models.globals import (
|
||||
)
|
||||
from text_generation_server.layers.attention import (
|
||||
KVCache,
|
||||
KVCompressCache,
|
||||
Seqlen,
|
||||
HPUPagedAttentionMetadata,
|
||||
trim_attn_metadata,
|
||||
@ -68,11 +69,13 @@ from text_generation_server.utils.import_utils import (
|
||||
synchronize,
|
||||
get_free_memory,
|
||||
)
|
||||
|
||||
from text_generation_server.utils.prefill_chunking import (
|
||||
get_max_prefill_tokens,
|
||||
)
|
||||
import vllm_hpu_extension.environment as environment
|
||||
import habana_frameworks.torch as htorch
|
||||
import itertools
|
||||
from vllm_hpu_extension.bucketing import HPUBucketingContext
|
||||
from vllm_hpu_extension.bucketing.common import get_bucketing_context
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
@ -153,7 +156,7 @@ def prepare_for_decode(
|
||||
block_groups_device, num_classes=batch_size
|
||||
)
|
||||
mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
|
||||
mask = mask >= block_usage.unsqueeze(-1)
|
||||
mask = mask >= block_usage_device.unsqueeze(-1)
|
||||
attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
|
||||
return trim_attn_metadata(
|
||||
HPUPagedAttentionMetadata(
|
||||
@ -425,7 +428,9 @@ class FlashCausalLMBatch(Batch):
|
||||
all_input_ids_tensor[i, : len(input_ids)] = input_ids
|
||||
|
||||
# Create tensors on device
|
||||
all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64)
|
||||
all_input_ids_tensor = torch.tensor(
|
||||
all_input_ids_tensor, dtype=torch.int64, device=device
|
||||
)
|
||||
|
||||
top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)
|
||||
|
||||
@ -1438,15 +1443,17 @@ class FlashCausalLM(Model):
|
||||
self.kv_cache = []
|
||||
self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
|
||||
self.bucketing_ctx = None
|
||||
htorch.core.hpu_set_env()
|
||||
if htorch.utils.internal.is_lazy():
|
||||
htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True)
|
||||
environment.set_model_config(self.config)
|
||||
self.use_contiguous_pa = (
|
||||
os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true"
|
||||
)
|
||||
self.limit_hpu_graphs = (
|
||||
os.environ.get("LIMIT_HPU_GRAPHS", "false").lower() == "true"
|
||||
self.limit_hpu_graph = (
|
||||
os.environ.get("LIMIT_HPU_GRAPH", "false").lower() == "true"
|
||||
)
|
||||
self.max_seq_len_to_capture = 8192
|
||||
super().__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
@ -1478,16 +1485,27 @@ class FlashCausalLM(Model):
|
||||
):
|
||||
self.kv_cache = []
|
||||
empty_cache()
|
||||
self.kv_cache = [
|
||||
KVCache(
|
||||
num_blocks=num_blocks,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
if self.config.model_type == "deepseek_v3":
|
||||
self.kv_cache = [
|
||||
KVCompressCache(
|
||||
num_blocks=num_blocks,
|
||||
head_size=self.config.kv_lora_rank + self.config.qk_rope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
else:
|
||||
self.kv_cache = [
|
||||
KVCache(
|
||||
num_blocks=num_blocks,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
|
||||
def warmup(
|
||||
self,
|
||||
@ -1495,6 +1513,11 @@ class FlashCausalLM(Model):
|
||||
max_input_tokens: Optional[int],
|
||||
max_total_tokens: Optional[int],
|
||||
):
|
||||
if os.environ.get("MAX_BATCH_SIZE") is None:
|
||||
raise RuntimeError(
|
||||
"MAX_BATCH_SIZE is not set, it should be set in the launcher "
|
||||
"using `--max-batch-size xxx`"
|
||||
)
|
||||
# The warmup batch is the biggest batch we could ever receive
|
||||
self.kv_cache = []
|
||||
empty_cache()
|
||||
@ -1502,8 +1525,14 @@ class FlashCausalLM(Model):
|
||||
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
|
||||
# Calculate the number of blocks that can be allocated with the free memory
|
||||
dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
|
||||
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
|
||||
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
|
||||
if self.config.model_type == "deepseek_v3":
|
||||
cache_block_size = BLOCK_SIZE * (
|
||||
self.config.kv_lora_rank + self.config.qk_rope_head_dim
|
||||
)
|
||||
else:
|
||||
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
|
||||
cache_block_size = cache_block_size * 2
|
||||
total_cache_size = self.num_layers * cache_block_size * dtype_size
|
||||
|
||||
try:
|
||||
self.init_kv_cache(
|
||||
@ -1563,25 +1592,33 @@ class FlashCausalLM(Model):
|
||||
self.kv_cache_dtype,
|
||||
self.device,
|
||||
)
|
||||
|
||||
max_num_seqs = int(os.getenv("MAX_BATCH_SIZE", 128))
|
||||
if os.getenv("VLLM_PROMPT_SEQ_BUCKET_MAX") is None:
|
||||
os.environ["VLLM_PROMPT_SEQ_BUCKET_MAX"] = str(max_input_tokens)
|
||||
if os.getenv("VLLM_DECODE_BLOCK_BUCKET_MAX") is None:
|
||||
max_total_blocks = (
|
||||
math.ceil(max_total_tokens / BLOCK_SIZE) * max_num_seqs + 1
|
||||
)
|
||||
os.environ["VLLM_DECODE_BLOCK_BUCKET_MAX"] = str(max_total_blocks)
|
||||
|
||||
self.max_batch_prefill_tokens = get_max_prefill_tokens()
|
||||
max_num_seqs = int(os.getenv("MAX_BATCH_SIZE"))
|
||||
HPUBucketingContext = get_bucketing_context()
|
||||
max_total_tokens_aligned = math.ceil(max_total_tokens / BLOCK_SIZE) * BLOCK_SIZE
|
||||
model_max_length = self.tokenizer.model_max_length
|
||||
max_position_embeddings = getattr(
|
||||
self.config, "max_position_embeddings", model_max_length
|
||||
)
|
||||
self.bucketing_ctx = HPUBucketingContext(
|
||||
max_num_seqs,
|
||||
os.getenv("PREFILL_MAX_BS", 64), # self.max_num_prefill_seqs, #TODO
|
||||
max_num_seqs, # self.max_num_prefill_seqs, #TODO
|
||||
BLOCK_SIZE,
|
||||
num_blocks * BLOCK_SIZE,
|
||||
max_num_seqs * max_total_tokens_aligned,
|
||||
False,
|
||||
min(model_max_length, max_position_embeddings),
|
||||
max_input_tokens,
|
||||
max_total_tokens_aligned,
|
||||
)
|
||||
self.bucketing_ctx.num_hpu_blocks = num_blocks
|
||||
max_blocks = (
|
||||
max(BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE) + 1
|
||||
)
|
||||
self.bucketing_ctx.num_hpu_blocks = min(max_blocks, num_blocks)
|
||||
if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
|
||||
self.bucketing_ctx.generate_prompt_buckets()
|
||||
self.bucketing_ctx.generate_decode_buckets(
|
||||
self.bucketing_ctx.num_hpu_blocks
|
||||
)
|
||||
logger.info("skip warmup hpu graph, not recommmended")
|
||||
del _batch, batch
|
||||
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
|
||||
@ -1591,28 +1628,55 @@ class FlashCausalLM(Model):
|
||||
|
||||
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
|
||||
|
||||
def bypass_hpu_graphs(self, prefill, max_seq_len_to_capture):
|
||||
if self.limit_hpu_graph:
|
||||
return prefill
|
||||
else:
|
||||
return prefill and max_seq_len_to_capture > self.max_seq_len_to_capture
|
||||
|
||||
def warmup_hpu_graph(self, batch):
|
||||
start_time = time.time()
|
||||
warmup_shape_count = 0
|
||||
warmup_times = 3
|
||||
self.bucketing_ctx.generate_prompt_buckets()
|
||||
for i, (batch_size, seq_len) in enumerate(
|
||||
reversed(self.bucketing_ctx.prompt_buckets)
|
||||
):
|
||||
|
||||
def ordering_function_min_tokens(b):
|
||||
return (b[0] * b[1], b[1], b[0])
|
||||
|
||||
buckets = list(
|
||||
sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
|
||||
)
|
||||
|
||||
for i, (batch_size, seq_len) in enumerate(buckets):
|
||||
if batch_size * seq_len > self.max_batch_prefill_tokens:
|
||||
continue
|
||||
warmup_shape_count += 1
|
||||
log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
|
||||
for index in range(warmup_times):
|
||||
self.warmup_prefill(seq_len, batch_size, batch)
|
||||
synchronize(self.device)
|
||||
|
||||
def ordering_function_max_bs(b):
|
||||
return (-b[0], b[1])
|
||||
|
||||
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
|
||||
for i, (batch_size, block_num) in enumerate(
|
||||
reversed(self.bucketing_ctx.decode_buckets)
|
||||
):
|
||||
buckets = list(
|
||||
sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
|
||||
)
|
||||
for i, (batch_size, block_num) in enumerate(buckets):
|
||||
if batch_size > block_num:
|
||||
continue
|
||||
warmup_shape_count += 1
|
||||
log_master(
|
||||
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
|
||||
)
|
||||
for index in range(warmup_times):
|
||||
self.warmup_decode(batch_size, block_num, batch)
|
||||
synchronize(self.device)
|
||||
synchronize(self.device)
|
||||
log_master(
|
||||
logger.info,
|
||||
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
|
||||
)
|
||||
|
||||
def warmup_prefill(
|
||||
self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch
|
||||
@ -1643,7 +1707,9 @@ class FlashCausalLM(Model):
|
||||
lm_head_indices = input_lengths - 1
|
||||
kwargs = {}
|
||||
if htorch.utils.internal.is_lazy():
|
||||
kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs
|
||||
kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
|
||||
True, input_ids.shape[0]
|
||||
)
|
||||
|
||||
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
|
||||
self.model.forward(
|
||||
@ -1792,8 +1858,8 @@ class FlashCausalLM(Model):
|
||||
|
||||
kwargs = {}
|
||||
if htorch.utils.internal.is_lazy():
|
||||
kwargs["bypass_hpu_graphs"] = (
|
||||
batch.prefilling if self.limit_hpu_graphs else False
|
||||
kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
|
||||
batch.prefilling, input_ids.shape[0]
|
||||
)
|
||||
|
||||
logits, speculative_logits = self.model.forward(
|
||||
@ -1836,9 +1902,7 @@ class FlashCausalLM(Model):
|
||||
accepted_ids,
|
||||
speculative_ids,
|
||||
) = batch.next_token_chooser(
|
||||
_async_h2d_tensor_copy(
|
||||
batch.all_input_ids_tensor[:, : batch.max_current_length]
|
||||
),
|
||||
batch.all_input_ids_tensor[:, : batch.max_current_length],
|
||||
batch.next_token_logits,
|
||||
speculate,
|
||||
batch.speculative_ids,
|
||||
@ -1852,7 +1916,6 @@ class FlashCausalLM(Model):
|
||||
accepted_ids,
|
||||
)
|
||||
if batch.valid_indices is not None:
|
||||
next_input_ids = next_input_ids.cpu()
|
||||
next_token_logprobs = next_token_logprobs.cpu()
|
||||
accepted_ids = accepted_ids.cpu()
|
||||
batch.all_input_ids_tensor = batch.all_input_ids_tensor[
|
||||
@ -1902,7 +1965,6 @@ class FlashCausalLM(Model):
|
||||
accepted_ids = accepted_ids.cpu()
|
||||
cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
|
||||
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
|
||||
next_input_ids = next_input_ids.cpu()
|
||||
if batch.speculative_logits is not None:
|
||||
for i in range(len(batch)):
|
||||
batch.all_input_ids_tensor[
|
||||
@ -1914,7 +1976,7 @@ class FlashCausalLM(Model):
|
||||
] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
|
||||
else:
|
||||
index = batch.cache_lengths_tensor + batch.input_lengths_tensor
|
||||
index = index.to(batch.all_input_ids_tensor)
|
||||
index = index.to(batch.all_input_ids_tensor.device)
|
||||
batch_idx = torch.arange(
|
||||
0,
|
||||
batch.all_input_ids_tensor.shape[0],
|
||||
@ -1924,6 +1986,7 @@ class FlashCausalLM(Model):
|
||||
batch.all_input_ids_tensor.index_put_(
|
||||
(batch_idx, index.long()), next_input_ids
|
||||
)
|
||||
next_input_ids = next_input_ids.cpu()
|
||||
batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
|
||||
batch.speculative_ids = speculative_ids
|
||||
if batch.position_ids.dim() == 2:
|
||||
|
@ -23,6 +23,7 @@ from text_generation_server.layers.attention import (
|
||||
_async_h2d_tensor_copy,
|
||||
)
|
||||
import habana_frameworks.torch as htorch
|
||||
import time
|
||||
from text_generation_server.utils.import_utils import (
|
||||
synchronize,
|
||||
)
|
||||
@ -486,20 +487,32 @@ class FlashVlmCausalLM(FlashCausalLM):
|
||||
)
|
||||
|
||||
def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):
|
||||
start_time = time.time()
|
||||
warmup_shape_count = 0
|
||||
warmup_times = 3
|
||||
|
||||
# only warmup decode, for prefill, image pixal size may change, make the warmup useless
|
||||
def ordering_function_max_bs(b):
|
||||
return (-b[0], b[1])
|
||||
|
||||
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
|
||||
for i, (batch_size, block_num) in enumerate(
|
||||
reversed(self.bucketing_ctx.decode_buckets)
|
||||
):
|
||||
buckets = list(
|
||||
sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
|
||||
)
|
||||
for i, (batch_size, block_num) in enumerate(buckets):
|
||||
if batch_size > block_num:
|
||||
continue
|
||||
warmup_shape_count += 1
|
||||
log_master(
|
||||
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
|
||||
)
|
||||
for index in range(warmup_times):
|
||||
self.warmup_decode(batch_size, block_num, batch)
|
||||
synchronize(self.device)
|
||||
synchronize(self.device)
|
||||
log_master(
|
||||
logger.info,
|
||||
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -32,6 +32,7 @@ from text_generation_server.utils.import_utils import (
|
||||
)
|
||||
import torch.nn.functional as F
|
||||
from text_generation_server.utils.log import log_master
|
||||
import time
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
@ -325,7 +326,9 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
|
||||
)
|
||||
kwargs = {}
|
||||
if htorch.utils.internal.is_lazy():
|
||||
kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs
|
||||
kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
|
||||
True, input_ids.shape[0]
|
||||
)
|
||||
self.model.forward(
|
||||
input_ids=_async_h2d_tensor_copy(input_ids),
|
||||
position_ids=_async_h2d_tensor_copy(position_ids),
|
||||
@ -343,26 +346,47 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
|
||||
)
|
||||
|
||||
def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch):
|
||||
start_time = time.time()
|
||||
warmup_shape_count = 0
|
||||
warmup_times = 3
|
||||
self.bucketing_ctx.generate_prompt_buckets()
|
||||
for i, (batch_size, seq_len) in enumerate(
|
||||
reversed(self.bucketing_ctx.prompt_buckets)
|
||||
):
|
||||
|
||||
def ordering_function_min_tokens(b):
|
||||
return (b[0] * b[1], b[1], b[0])
|
||||
|
||||
buckets = list(
|
||||
sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
|
||||
)
|
||||
for i, (batch_size, seq_len) in enumerate(buckets):
|
||||
if batch_size * seq_len > self.max_batch_prefill_tokens:
|
||||
continue
|
||||
warmup_shape_count += 1
|
||||
log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
|
||||
for index in range(warmup_times):
|
||||
self.warmup_prefill(seq_len, batch_size, batch)
|
||||
synchronize(self.device)
|
||||
|
||||
def ordering_function_max_bs(b):
|
||||
return (-b[0], b[1])
|
||||
|
||||
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
|
||||
for i, (batch_size, block_num) in enumerate(
|
||||
reversed(self.bucketing_ctx.decode_buckets)
|
||||
):
|
||||
buckets = list(
|
||||
sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
|
||||
)
|
||||
for i, (batch_size, block_num) in enumerate(buckets):
|
||||
if batch_size > block_num:
|
||||
continue
|
||||
warmup_shape_count += 1
|
||||
log_master(
|
||||
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
|
||||
)
|
||||
for index in range(warmup_times):
|
||||
self.warmup_decode(batch_size, block_num, batch)
|
||||
synchronize(self.device)
|
||||
synchronize(self.device)
|
||||
log_master(
|
||||
logger.info,
|
||||
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -438,8 +462,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
|
||||
|
||||
kwargs = {}
|
||||
if htorch.utils.internal.is_lazy():
|
||||
kwargs["bypass_hpu_graphs"] = (
|
||||
batch.prefilling if self.limit_hpu_graphs else False
|
||||
kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
|
||||
batch.prefilling, input_ids.shape[0]
|
||||
)
|
||||
if batch.prefill_cache_indices is not None:
|
||||
slots_pad = torch.zeros_like(input_ids)
|
||||
|
@ -206,6 +206,7 @@ def serve(
|
||||
quantize: Optional[str],
|
||||
speculate: Optional[int],
|
||||
dtype: Optional[str],
|
||||
kv_cache_dtype: Optional[str],
|
||||
trust_remote_code: bool,
|
||||
uds_path: Path,
|
||||
max_input_tokens: int,
|
||||
@ -218,6 +219,7 @@ def serve(
|
||||
quantize: Optional[str] = None,
|
||||
speculate: Optional[int] = None,
|
||||
dtype: Optional[str] = None,
|
||||
kv_cache_dtype: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
if not is_driver_compatible():
|
||||
@ -261,6 +263,7 @@ def serve(
|
||||
quantize,
|
||||
speculate,
|
||||
data_type,
|
||||
kv_cache_dtype,
|
||||
trust_remote_code,
|
||||
max_input_tokens,
|
||||
adapter_to_index,
|
||||
@ -308,6 +311,7 @@ def serve(
|
||||
quantize,
|
||||
speculate,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
trust_remote_code,
|
||||
)
|
||||
)
|
||||
|
@ -7,7 +7,7 @@ from loguru import logger
|
||||
# Tensor Parallelism settings
|
||||
RANK = int(os.getenv("RANK", "0"))
|
||||
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
|
||||
MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.8"))
|
||||
MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.9"))
|
||||
|
||||
|
||||
class FakeBarrier:
|
||||
|
@ -1,17 +1,19 @@
|
||||
import torch
|
||||
from loguru import logger
|
||||
import habana_frameworks.torch as htorch
|
||||
import os
|
||||
|
||||
|
||||
def get_hpu_free_memory(device, memory_fraction):
|
||||
from habana_frameworks.torch.hpu import memory_stats
|
||||
|
||||
device_id = device.index
|
||||
mem_stats = memory_stats(device_id)
|
||||
logger.info(f"mem_stats: {mem_stats}")
|
||||
total_free_memory = mem_stats["Limit"] - mem_stats["MaxInUse"]
|
||||
free_memory = max(
|
||||
0, int(total_free_memory - (1 - memory_fraction) * mem_stats["Limit"])
|
||||
graph_reserved_mem = (
|
||||
float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1"))
|
||||
if htorch.utils.internal.is_lazy()
|
||||
else 0
|
||||
)
|
||||
free_memory = int(
|
||||
torch.hpu.mem_get_info()[0] * memory_fraction * (1 - graph_reserved_mem)
|
||||
)
|
||||
logger.info(f"Free memory on device {device}: {free_memory} bytes.")
|
||||
return free_memory
|
||||
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from text_generation_server.utils.weights import (
|
||||
@ -18,6 +18,8 @@ class _QuantizerConfig:
|
||||
groupsize: int
|
||||
quant_method: str
|
||||
sym: bool
|
||||
weight_block_size: Optional[List[int]]
|
||||
modules_to_not_convert: List[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -25,7 +27,20 @@ class _FP8QuantizerConfig:
|
||||
activation_scale_ub: float
|
||||
|
||||
|
||||
# We should probably do this with Pytantic JSON deserialization,
|
||||
def _get_config_json(model_id: str, revision: Optional[str], filename: str):
|
||||
if os.path.exists(
|
||||
os.path.join(
|
||||
model_id,
|
||||
)
|
||||
):
|
||||
filename = os.path.join(model_id, filename)
|
||||
else:
|
||||
filename = hf_hub_download(model_id, filename=filename, revision=revision)
|
||||
with open(filename, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
# We should probably do this with Pydantic JSON deserialization,
|
||||
# but for now we'll stay close to the old _set_gptq_params.
|
||||
def _get_quantizer_config(model_id, revision):
|
||||
bits = 4
|
||||
@ -34,21 +49,18 @@ def _get_quantizer_config(model_id, revision):
|
||||
checkpoint_format = None
|
||||
sym = False
|
||||
desc_act = False
|
||||
weight_block_size = None
|
||||
modules_to_not_convert = []
|
||||
|
||||
filename = "config.json"
|
||||
try:
|
||||
if os.path.exists(os.path.join(model_id, filename)):
|
||||
filename = os.path.join(model_id, filename)
|
||||
else:
|
||||
filename = hf_hub_download(model_id, filename=filename, revision=revision)
|
||||
with open(filename, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
data = _get_config_json(model_id, revision, filename)
|
||||
# FP8 config
|
||||
if data["quantization_config"]["quant_method"] == "fbgemm_fp8":
|
||||
return _FP8QuantizerConfig(
|
||||
activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
|
||||
)
|
||||
weight_block_size = data["quantization_config"].get("weight_block_size", None)
|
||||
|
||||
if "zero_point" in data["quantization_config"]:
|
||||
sym = not data["quantization_config"]["zero_point"]
|
||||
@ -61,18 +73,16 @@ def _get_quantizer_config(model_id, revision):
|
||||
# Order is important here, desc_act is missing on some real models
|
||||
quant_method = data["quantization_config"]["quant_method"]
|
||||
checkpoint_format = data["quantization_config"].get("checkpoint_format")
|
||||
desc_act = data["quantization_config"]["desc_act"]
|
||||
desc_act = data["quantization_config"].get("desc_act", False)
|
||||
modules_to_not_convert = data["quantization_config"].get(
|
||||
"modules_to_not_convert", []
|
||||
)
|
||||
if modules_to_not_convert is None:
|
||||
modules_to_not_convert = []
|
||||
except Exception:
|
||||
filename = "quantize_config.json"
|
||||
try:
|
||||
if os.path.exists(os.path.join(model_id, filename)):
|
||||
filename = os.path.join(model_id, filename)
|
||||
else:
|
||||
filename = hf_hub_download(
|
||||
model_id, filename=filename, revision=revision
|
||||
)
|
||||
with open(filename, "r") as f:
|
||||
data = json.load(f)
|
||||
data = _get_config_json(model_id, revision, filename)
|
||||
bits = data["bits"]
|
||||
groupsize = data["group_size"]
|
||||
|
||||
@ -88,14 +98,7 @@ def _get_quantizer_config(model_id, revision):
|
||||
except Exception:
|
||||
filename = "quant_config.json"
|
||||
try:
|
||||
if os.path.exists(os.path.join(model_id, filename)):
|
||||
filename = os.path.join(model_id, filename)
|
||||
else:
|
||||
filename = hf_hub_download(
|
||||
model_id, filename=filename, revision=revision
|
||||
)
|
||||
with open(filename, "r") as f:
|
||||
data = json.load(f)
|
||||
data = _get_config_json(model_id, revision, filename)
|
||||
bits = data["w_bit"]
|
||||
groupsize = data["q_group_size"]
|
||||
desc_act = data["desc_act"]
|
||||
@ -111,6 +114,8 @@ def _get_quantizer_config(model_id, revision):
|
||||
checkpoint_format=checkpoint_format,
|
||||
sym=sym,
|
||||
desc_act=desc_act,
|
||||
weight_block_size=weight_block_size,
|
||||
modules_to_not_convert=modules_to_not_convert,
|
||||
)
|
||||
|
||||
|
||||
@ -134,6 +139,7 @@ def get_loader(
|
||||
quant_method=quantizer_config.quant_method,
|
||||
quantize=quantize,
|
||||
sym=quantizer_config.sym,
|
||||
modules_to_not_convert=quantizer_config.modules_to_not_convert,
|
||||
)
|
||||
elif quantize == "fp8" or quantize is None:
|
||||
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
|
||||
@ -141,9 +147,14 @@ def get_loader(
|
||||
# Since the default for the quantize config is _QuantizerConfig,
|
||||
# we need to add this check to not get an attribute error
|
||||
activation_scale_ub = None
|
||||
weight_block_size = quantizer_config.weight_block_size
|
||||
if isinstance(quantizer_config, _FP8QuantizerConfig):
|
||||
activation_scale_ub = quantizer_config.activation_scale_ub
|
||||
|
||||
return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8")
|
||||
return HybridFP8UnquantLoader(
|
||||
activation_scale_ub,
|
||||
to_fp8=quantize == "fp8",
|
||||
weight_block_size=weight_block_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown quantization method: {quantize}")
|
||||
|
@ -62,6 +62,14 @@ class WeightsLoader(ABC):
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
|
||||
"""
|
||||
Get the weights at the given prefixes, column-split them for tensor
|
||||
parallelim, and then concatenate the weights along the given dimension.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_weights_row(self, weights: "Weights", prefix: str):
|
||||
"""
|
||||
@ -130,6 +138,10 @@ class DefaultWeightsLoader(WeightsLoader):
|
||||
weights.get_sharded(f"{prefix}.weight", dim=1),
|
||||
)
|
||||
|
||||
def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
|
||||
w = [weights.get_tensor(f"{p}.weight") for p in prefixes]
|
||||
return self.weight_class(torch.cat(w, dim=dim))
|
||||
|
||||
|
||||
class Weights:
|
||||
def __init__(
|
||||
@ -393,6 +405,9 @@ class Weights:
|
||||
def get_weights_row(self, prefix: str):
|
||||
return self.weights_loader.get_weights_row(self, prefix)
|
||||
|
||||
def get_multi_weights(self, prefixes: List[str], dim: int):
|
||||
return self.weights_loader.get_multi_weights(self, prefixes, dim)
|
||||
|
||||
@contextmanager
|
||||
def use_loader(self, weights_loader: WeightsLoader):
|
||||
"""
|
||||
|
@ -8,6 +8,7 @@ use std::cmp::max;
|
||||
use std::collections::VecDeque;
|
||||
use text_generation_router::infer::InferError;
|
||||
use text_generation_router::infer::InferStreamResponse;
|
||||
use text_generation_router::usage_stats::Env;
|
||||
use text_generation_router::validation::{
|
||||
Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
|
||||
ValidStoppingParameters,
|
||||
@ -185,6 +186,9 @@ struct State {
|
||||
|
||||
/// Paged Attention Block Allocation
|
||||
block_allocator: Option<BlockAllocator>,
|
||||
|
||||
/// indicate if it's hpu device, the hpu device needs padding to generate first token.
|
||||
is_hpu_device: bool,
|
||||
}
|
||||
|
||||
impl State {
|
||||
@ -214,6 +218,7 @@ impl State {
|
||||
speculate,
|
||||
support_chunking,
|
||||
block_allocator,
|
||||
is_hpu_device: Env::new().is_hpu_device(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -368,6 +373,21 @@ impl State {
|
||||
}
|
||||
}
|
||||
|
||||
if self.is_hpu_device {
|
||||
//HPU needs to pad for the prefill
|
||||
max_input_length = max_input_length.max(entry.request.input_length);
|
||||
let actual_prefill_tokens_for_hpu =
|
||||
(batch.len() + 1) as u32 * max_input_length;
|
||||
|
||||
if actual_prefill_tokens_for_hpu > prefill_token_budget {
|
||||
// Entry is over budget
|
||||
// Add it back to the front
|
||||
tracing::debug!("Over budget: prefill_tokens={actual_prefill_tokens_for_hpu} > {prefill_token_budget}");
|
||||
self.entries.push_front((id, entry));
|
||||
break 'entry_loop;
|
||||
}
|
||||
}
|
||||
|
||||
prefill_tokens += postfix_len;
|
||||
|
||||
Some(block_allocation)
|
||||
|
Loading…
Reference in New Issue
Block a user