From d658b5def3fe6c32b09b4ffe36f770ba2aa959b4 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Mon, 19 May 2025 22:36:39 +0800 Subject: [PATCH] Deepseek R1 for Gaudi backend (#3211) Signed-off-by: Wang, Yi A --- Dockerfile_gaudi | 5 +- .../server/text_generation_server/cli.py | 10 +- .../text_generation_server/layers/__init__.py | 2 + .../layers/attention/__init__.py | 5 +- .../layers/attention/hpu.py | 110 +++++++- .../layers/attention/kv_cache.py | 80 +++++- .../text_generation_server/layers/fp8.py | 243 ++++++++++++++++-- .../layers/gptq/__init__.py | 26 +- .../layers/layernorm.py | 17 +- .../text_generation_server/layers/moe/fp8.py | 143 +++++++++-- .../layers/moe/fused_moe.py | 82 +++++- .../layers/moe/unquantized.py | 28 +- .../text_generation_server/layers/rotary.py | 4 +- .../text_generation_server/models/__init__.py | 10 +- .../custom_modeling/flash_cohere_modeling.py | 8 +- .../custom_modeling/flash_dbrx_modeling.py | 7 +- .../flash_deepseek_v2_modeling.py | 6 + .../flash_deepseek_v3_modeling.py | 157 ++++++++--- .../custom_modeling/flash_gemma2_modeling.py | 7 + .../custom_modeling/flash_gemma_modeling.py | 6 + .../custom_modeling/flash_gpt2_modeling.py | 7 + .../custom_modeling/flash_gptj_modeling.py | 6 + .../custom_modeling/flash_llama_modeling.py | 7 +- .../custom_modeling/flash_mistral_modeling.py | 6 + .../custom_modeling/flash_mixtral_modeling.py | 6 + .../custom_modeling/flash_neox_modeling.py | 6 + .../custom_modeling/flash_phi_modeling.py | 6 + .../custom_modeling/flash_phi_moe_modeling.py | 1 - .../custom_modeling/flash_qwen2_modeling.py | 6 + .../custom_modeling/flash_rw_modeling.py | 6 + .../flash_santacoder_modeling.py | 6 + .../flash_starcoder2_modeling.py | 6 + .../models/flash_causal_lm.py | 157 +++++++---- .../models/flash_vlm_causal_lm.py | 21 +- .../models/mllama_causal_lm.py | 44 +++- .../server/text_generation_server/server.py | 4 + .../text_generation_server/utils/dist.py | 2 +- .../utils/import_utils.py | 18 +- .../utils/quantization.py | 65 +++-- .../text_generation_server/utils/weights.py | 15 ++ backends/v3/src/queue.rs | 20 ++ 41 files changed, 1133 insertions(+), 238 deletions(-) diff --git a/Dockerfile_gaudi b/Dockerfile_gaudi index 06073fe4..54a0bb7c 100644 --- a/Dockerfile_gaudi +++ b/Dockerfile_gaudi @@ -60,6 +60,8 @@ FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytor ENV ATTENTION=default ENV PREFIX_CACHING=0 ENV PREFILL_CHUNKING=0 +ENV PT_HPU_LAZY_MODE=1 +ENV PT_HPU_WEIGHT_SHARING=0 # Text Generation Inference base env ENV HF_HOME=/data \ @@ -95,7 +97,8 @@ RUN cd server && \ pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \ BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \ pip install . 
--no-cache-dir -RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git +RUN pip install git+https://github.com/HabanaAI/vllm-hpu-extension.git@a060794 + # Install benchmarker COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark # Install router diff --git a/backends/gaudi/server/text_generation_server/cli.py b/backends/gaudi/server/text_generation_server/cli.py index 53837ef7..b1a41534 100644 --- a/backends/gaudi/server/text_generation_server/cli.py +++ b/backends/gaudi/server/text_generation_server/cli.py @@ -26,6 +26,11 @@ class Dtype(str, Enum): bloat16 = "bfloat16" +class KVCacheDtype(str, Enum): + fp8_e4m3fn = "fp8_e4m3fn" + fp8_e5m2 = "fp8_e5m2" + + @app.command() def serve( model_id: str, @@ -34,6 +39,7 @@ def serve( quantize: Optional[Quantization] = None, speculate: Optional[int] = None, dtype: Optional[Dtype] = None, + kv_cache_dtype: Optional[KVCacheDtype] = None, trust_remote_code: bool = False, uds_path: Path = "/tmp/text-generation-server", logger_level: str = "INFO", @@ -93,7 +99,8 @@ def serve( # Downgrade enum into str for easier management later on quantize = None if quantize is None else quantize.value dtype = "bfloat16" if dtype is None else dtype.value - logger.info(f"quantize={quantize}") + kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value + logger.info(f"quantize={quantize} kv_cache_dtype={kv_cache_dtype}") if dtype is not None and quantize not in { None, "bitsandbytes", @@ -175,6 +182,7 @@ def serve( quantize, speculate, dtype, + kv_cache_dtype, trust_remote_code, uds_path, max_input_tokens, diff --git a/backends/gaudi/server/text_generation_server/layers/__init__.py b/backends/gaudi/server/text_generation_server/layers/__init__.py index 0000ca91..fd146728 100644 --- a/backends/gaudi/server/text_generation_server/layers/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/__init__.py @@ -12,6 +12,7 @@ from text_generation_server.layers.speculative import SpeculativeHead # Just to add the `load` methods. from text_generation_server.layers.layernorm import load_layer_norm from text_generation_server.layers.conv import load_conv2d +from text_generation_server.layers.fp8 import Fp8Linear from text_generation_server.layers.lora import ( LoraLinear, @@ -27,6 +28,7 @@ __all__ = [ "TensorParallelEmbedding", "SpeculativeHead", "LoraLinear", + "Fp8Linear", "TensorParallelMultiAdapterLinear", "TensorParallelAdapterRowLinear", "load_layer_norm", diff --git a/backends/gaudi/server/text_generation_server/layers/attention/__init__.py b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py index 89a43d65..370e05bc 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py @@ -10,18 +10,21 @@ from .hpu import ( SUPPORTS_WINDOWING, attention, paged_attention, + paged_attention_mla, ) # KVCache needs `reshape_and_cache`, so ensure that it is defined already. 
-from .kv_cache import KVCache, get_kv_scales +from .kv_cache import KVCache, get_kv_scales, KVCompressCache __all__ = [ "attention", "get_kv_scales", "paged_attention", + "paged_attention_mla", "SUPPORTS_WINDOWING", "KVCache", + "KVCompressCache", "Seqlen", "HPUPagedAttentionMetadata", "trim_seqlen_metadata", diff --git a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py index 1d73dcb3..1c2e37c7 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py @@ -11,11 +11,61 @@ import os SUPPORTS_WINDOWING = False -def fetch_from_cache(cache, blocks): - if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true": - return cache[: blocks.size(0)] - else: - return cache.index_select(0, blocks) +class FP8Matmul(torch.nn.Module): + + def __init__(self, scale_other): + super().__init__() + self.scale_input = torch.tensor(1.0, dtype=torch.bfloat16, device="hpu") + self.scale_other = scale_other + + def quant_input(self, x, scale): + return torch.ops.hpu.cast_to_fp8_v2( + x, scale, False, False, torch.float8_e4m3fn + )[0] + + def matmul_fp8( + self, x, other, out_dtype, scale_input_inv=None, scale_other_inv=None + ): + return torch.ops.hpu.fp8_gemm_v2( + A=x, + trans_A=False, + B=other, + trans_B=False, + D=None, + out_dtype=out_dtype, + A_scale_inv=scale_input_inv, + B_scale_inv=scale_other_inv, + bias=None, + accumulate=False, + ) + + def forward(self, input, other): + qinput = self.quant_input(input, self.scale_input) + qother = self.quant_input(other, self.scale_other) + output = self.matmul_fp8( + qinput, + qother, + out_dtype=torch.bfloat16, + scale_input_inv=1.0 / self.scale_input, + scale_other_inv=1.0 / self.scale_other, + ) + return output + + +class FetchFromCache(torch.nn.Module): + + def __init__(self, scale_inv): + super().__init__() + self.scale_inv = scale_inv + + def forward(self, cache, blocks): + if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true": + out = cache[: blocks.size(0)] + else: + out = cache.index_select(0, blocks) + if out.dtype == torch.float8_e4m3fn: + out = torch.ops.hpu.cast_from_fp8(out, self.scale_inv, torch.bfloat16) + return out def attention( @@ -67,6 +117,7 @@ def paged_attention( hpu_attention_meta: HPUPagedAttentionMetadata, ): batch_size, head_num, head_size = query.shape + fp8_kv = kv_cache.dtype == torch.float8_e4m3fn output = ops.flat_pa( query=query.view(batch_size, 1, head_num * head_size), key_cache=kv_cache.key, @@ -76,19 +127,50 @@ def paged_attention( block_bias=hpu_attention_meta.attn_bias, block_groups=hpu_attention_meta.block_groups, scale=softmax_scale, - matmul_qk_op=Matmul(), - matmul_av_op=Matmul(), + matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(), + matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(), batch2block_matmul_op=Matmul(), block2batch_matmul_op=Matmul(), - keys_fetch_func=fetch_from_cache, - values_fetch_func=fetch_from_cache, + keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu), + values_fetch_func=FetchFromCache(1.0 / kv_scales.value_scale_cpu), ) # Reshape the output tensor. 
return output.view(batch_size, head_num, head_size) -__all__ = [ - "SUPPORTS_WINDOWING", - "attention", - "paged_attention", -] +def paged_attention_mla( + query: torch.Tensor, + kv_cache: KVCache, + kv_head_mapping: torch.Tensor, + softmax_scale: float, + seqlen: Seqlen, + *, + kv_scales: KVScales, + softcap: Optional[float] = None, + hpu_attention_meta: HPUPagedAttentionMetadata, + kv_lora_rank: int = 0, +): + batch_size, head_num, head_size = query.shape + fp8_kv = kv_cache.dtype == torch.float8_e4m3fn + output = ops.flat_pa_mla( + query=query, + key_cache=kv_cache.key, + value_cache=None, + block_list=hpu_attention_meta.block_list, + block_mapping=hpu_attention_meta.block_mapping, + block_bias=hpu_attention_meta.attn_bias, + block_groups=hpu_attention_meta.block_groups, + scale=softmax_scale, + matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(), + matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(), + batch2block_matmul_op=Matmul(), + block2batch_matmul_op=Matmul(), + keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu), + values_fetch_func=None, + kv_lora_rank=kv_lora_rank, + ) + # Reshape the output tensor. + return output.view(batch_size, head_num, -1) + + +__all__ = ["SUPPORTS_WINDOWING", "attention", "paged_attention", "paged_attention_mla"] diff --git a/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py index d238cdb9..cdd1e1d7 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py @@ -50,6 +50,8 @@ class KVCache: ): """Construct the key-value cache for a layer.""" ## TODO FP8 kv cache support + if dtype is torch.float8_e5m2: + raise ValueError("torch.float8_e5m2 is not supported in hpu. ") self.kv_cache = ( torch.zeros( @@ -101,22 +103,92 @@ class KVCache: key_cache, value_cache, slots, - kv_scales.key_scale_cpu, - kv_scales.value_scale_cpu, + kv_scales.key_scale, + kv_scales.value_scale, ) +class KVCompressCache(KVCache): + """ + Key-value cache for attention layers. + """ + + kv_cache: torch.Tensor + + def __init__( + self, + *, + num_blocks: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + ): + """Construct the key-value cache for a layer.""" + ## TODO FP8 kv cache support + if dtype is torch.float8_e5m2: + raise ValueError("torch.float8_e5m2 is not supported in hpu. 
") + + self.kv_cache = torch.zeros( + (num_blocks, BLOCK_SIZE, 1, head_size), + dtype=dtype, + device=device, + ) + + @property + def dtype(self): + """Get the data type of the cache.""" + return self.kv_cache.dtype + + @property + def key(self): + """Get the key cache.""" + + return self.kv_cache + + @property + def value(self): + """Get the value cache.""" + + return self.kv_cache + + def store( + self, + *, + key: torch.Tensor, + value: torch.Tensor, + slots: torch.Tensor, + kv_scales: KVScales, + ): + """Store the key and value at the given slots.""" + ## TODO FP8 kv cache support + + block_idx = slots // BLOCK_SIZE + block_offset = slots % BLOCK_SIZE + if self.kv_cache.dtype == torch.float8_e4m3fn: + key = torch.ops.hpu.cast_to_fp8_v2( + key, kv_scales.key_scale, False, False, torch.float8_e4m3fn + )[0] + cache_ops.insert_or_update_cache(key, self.kv_cache, block_idx, block_offset) + + def paged_reshape_and_cache( key: torch.Tensor, value: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, slots: torch.Tensor, - k_scale: float = 1.0, - v_scale: float = 1.0, + k_scale: torch.Tensor, + v_scale: torch.Tensor, ): block_idx = slots // BLOCK_SIZE block_offset = slots % BLOCK_SIZE + if key_cache.dtype == torch.float8_e4m3fn: + key = torch.ops.hpu.cast_to_fp8_v2( + key, k_scale, False, False, torch.float8_e4m3fn + )[0] + value = torch.ops.hpu.cast_to_fp8_v2( + value, v_scale, False, False, torch.float8_e4m3fn + )[0] cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset) cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset) diff --git a/backends/gaudi/server/text_generation_server/layers/fp8.py b/backends/gaudi/server/text_generation_server/layers/fp8.py index 0dc5cdaf..44d30202 100644 --- a/backends/gaudi/server/text_generation_server/layers/fp8.py +++ b/backends/gaudi/server/text_generation_server/layers/fp8.py @@ -12,11 +12,151 @@ from text_generation_server.utils.weights import ( from vllm_hpu_extension.ops import scaled_fp8_quant from vllm_hpu_extension.scales import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2 -import habana_frameworks.torch.utils.experimental as htexp -w8a8_block_fp8_matmul = None -per_token_group_quant_fp8 = None quant_dtype: torch.dtype = torch.float8_e4m3fn +FP8_MAX = torch.finfo(torch.float8_e4m3fn).max +if is_hpu_gaudi2(): + FP8_MAX = torch.finfo(torch.float8_e4m3fnuz).max + + +def pad_weight(weight, block_size): + """Pads a matrix to make its dimensions multiples of block_size.""" + M, N = weight.shape[-2:] + block_size_m, block_size_n = block_size + pad_M = (block_size_m - M % block_size_m) % block_size_m + pad_N = (block_size_n - N % block_size_n) % block_size_n + + if pad_M == 0 and pad_N == 0: + return weight, M, N # No padding needed + padded_weight = torch.nn.functional.pad( + weight, (0, pad_N, 0, pad_M), mode="constant", value=0 + ) + return padded_weight, M, N # Return original dimensions for unpadding + + +def unpad_weight(weight, original_M, original_N, keep_first_dim=False): + """Removes padding from the matrix to restore its original shape.""" + if (weight.shape[-2] == original_M) and (weight.shape[-1] == original_N): + return weight + if keep_first_dim: + return weight[:, :original_M, :original_N] + else: + return weight[:original_M, :original_N] + + +def pad_block_fp8_weight_naive(weight, weight_scale, block_size): + + assert len(block_size) == 2 + + block_size_m, block_size_n = block_size + weight_scale_m, weight_scale_n = weight_scale.shape[-2:] + + weight, orig_M, orig_N = pad_weight(weight, 
block_size) + M, N = weight.shape[-2:] + + assert weight_scale_m == M // block_size_m + assert weight_scale_n == N // block_size_n + + return weight, orig_M, orig_N + + +def dynamic_quant(data, single_scale=False): + if single_scale: + scale = ((torch.abs(data)).max() + 1e-8) / FP8_MAX + else: + scale = ((torch.abs(data)).max(dim=-1).values + 1e-8) / FP8_MAX + scale = scale.unsqueeze(-1) + data_fp8 = torch.ops.hpu.cast_to_fp8_v2( + data, 1.0 / scale, False, False, torch.float8_e4m3fn + )[0] + return data_fp8, scale.float() + + +def dequant_block_fp8_weight_naive( + weight, + weight_scale, + block_size, + dtype=torch.bfloat16, + original_M=None, + original_N=None, + do_unpad=False, +): + if weight_scale is None: + return weight + assert len(block_size) == 2 + + weight_shape_len = len(weight.shape) + + block_size_m, block_size_n = block_size + + # mul scale + if weight_shape_len == 2: + weight_scale_m, weight_scale_n = weight_scale.shape + weight_scale = weight_scale.view(weight_scale_m, 1, weight_scale_n, 1) + weight = weight.view(weight_scale_m, block_size_m, weight_scale_n, block_size_n) + if is_hpu_gaudi2(): + fake_weight = weight.cpu().to(dtype).to(weight.device) + dequant_weight = fake_weight * weight_scale.to(dtype) + else: + dequant_weight = weight.to(dtype) * weight_scale.to(dtype) + dequant_weight = dequant_weight.view( + weight_scale_m * block_size_m, weight_scale_n * block_size_n + ) + keep_first_dim = False + elif weight_shape_len == 3: + fd, weight_scale_m, weight_scale_n = weight_scale.shape + weight_scale = weight_scale.view(fd, weight_scale_m, 1, weight_scale_n, 1) + weight = weight.view( + fd, weight_scale_m, block_size_m, weight_scale_n, block_size_n + ) + if is_hpu_gaudi2(): + fake_weight = weight.cpu().to(dtype).to(weight.device) + dequant_weight = fake_weight * weight_scale.to(dtype) + else: + dequant_weight = weight.to(dtype) * weight_scale.to(dtype) + dequant_weight = dequant_weight.view( + fd, weight_scale_m * block_size_m, weight_scale_n * block_size_n + ) + keep_first_dim = True + else: + raise ValueError("Only support original weight shape is either 2 or 3") + + if do_unpad: + dequant_weight = unpad_weight( + dequant_weight, original_M, original_N, keep_first_dim=keep_first_dim + ) + + return dequant_weight + + +def apply_block_fp8_linear_hpu_dynamic( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + # View input as 2D matrix for fp8 methods + input_2d = input.view(-1, input.shape[-1]) + output_shape = [*input.shape[:-1], weight.shape[0]] + + x_fp8, x_scale = dynamic_quant(input_2d) + + output = torch.ops.hpu.fp8_gemm_v2( + x_fp8, + False, + weight, + True, + None, + torch.bfloat16, + x_scale, + weight_scale, + None, + False, + ) + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]: @@ -42,7 +182,7 @@ def per_tensor_dequantize( ) -> torch.Tensor: device = tensor.device dtype = torch.bfloat16 - if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2: + if is_hpu_gaudi2(): # dequant on cpu to avoid nan on gaudi2 tensor = tensor.to("cpu") @@ -269,6 +409,66 @@ class HybridFP8UnquantLoader(WeightsLoader): return UnquantizedWeight(w) + def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int): + # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet + w 
= [weights.get_tensor(f"{p}.weight", to_device=False) for p in prefixes] + shapes = [x.shape for x in w] + + # Concat then send to the device + w = torch.cat(w, dim=dim).to(weights.device) + + # FP8 branch + if w.dtype == torch.float8_e4m3fn: + if self.weight_block_size is not None: + scale = [ + weights.get_tensor(f"{p}.weight_scale_inv", to_device=False) + for p in prefixes + ] + scale = torch.cat(scale, dim=dim) + scale = scale.to(weights.device) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + weight_block_size=self.weight_block_size, + ) + + scale = [ + weights.get_tensor(f"{p}.weight_scale", to_dtype=False).reshape(-1) + for p in prefixes + ] + scale = torch.cat(scale, dim=0).reshape(-1) + + input_scale = [ + weights.get_tensor(f"{p}.input_scale", to_dtype=False).reshape(-1) + for p in prefixes + if weights.has_tensor(f"{p}.input_scale") + ] + assert len(input_scale) == 0 or len(input_scale) == len(prefixes) + input_scale = ( + torch.cat(input_scale, dim=0).reshape(-1).max() + if len(input_scale) != 0 + else None + ) + + logical_widths = [x[0] for x in shapes] + w, scale = requantize_with_max_scale( + w, scale.to(weights.device), logical_widths, weights.dtype + ) + + return Fp8Weight( + weight=w, + weight_scale=scale, + input_scale=input_scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + if self.to_fp8: + return Fp8Weight(weight=w, dtype=weights.dtype) + + return UnquantizedWeight(w) + def get_weights_row(self, weights: "Weights", prefix: str): w = weights.get_sharded(f"{prefix}.weight", dim=1) # FP8 branch @@ -389,6 +589,22 @@ class Fp8Linear(torch.nn.Module): scale_upper_bound = kwargs.get("scale_upper_bound", None) weight_block_size = kwargs.get("weight_block_size", None) + if weight_block_size is not None: + weight, orig_M, orig_N = pad_block_fp8_weight_naive( + weight, scale, weight_block_size + ) + weight, scale = dynamic_quant( + dequant_block_fp8_weight_naive( + weight, + scale, + weight_block_size, + original_M=orig_M, + original_N=orig_N, + do_unpad=True, + ) + ) + scale = scale.squeeze(-1) + return cls( qweight=weight, scale=scale, @@ -409,25 +625,10 @@ class Fp8Linear(torch.nn.Module): def forward(self, input: torch.Tensor) -> torch.Tensor: if self.weight_block_size is not None: - # https://arxiv.org/pdf/2412.19437 - # At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and - # scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we - # group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output - # channels). 
- qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1]) - output = w8a8_block_fp8_matmul( - qinput, - self.qweight, - scale, - self.scale, - self.weight_block_size, - output_dtype=input.dtype, + return apply_block_fp8_linear_hpu_dynamic( + input, self.qweight, self.scale, self.input_scale, self.bias ) - if self.bias is not None: - output = output + self.bias - return output.to(dtype=input.dtype) - qinput, scale = fp8_quantize( input, self.input_scale, diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py index 90b8f692..babf3d4b 100644 --- a/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py @@ -4,7 +4,12 @@ from typing import List, Optional, Union import torch from loguru import logger from text_generation_server.utils.log import log_once -from text_generation_server.utils.weights import Weight, Weights, WeightsLoader +from text_generation_server.utils.weights import ( + Weight, + Weights, + WeightsLoader, + DefaultWeightsLoader, +) from .hpu import QuantLinear @@ -72,6 +77,7 @@ class GPTQWeightsLoader(WeightsLoader): quant_method: str, quantize: str, sym: bool, + modules_to_not_convert: List[str], ): self.bits = bits self.desc_act = desc_act @@ -79,6 +85,12 @@ class GPTQWeightsLoader(WeightsLoader): self.quant_method = quant_method self.quantize = quantize self.sym = sym + self.modules_to_not_convert = modules_to_not_convert + + def is_layer_skipped_quantization( + self, prefix: str, modules_to_not_convert: List[str] + ): + return any(module_name in prefix for module_name in modules_to_not_convert) def get_weights(self, weights: Weights, prefix: str): self._get_gptq_params(weights) @@ -91,6 +103,9 @@ class GPTQWeightsLoader(WeightsLoader): log_once(logger.warning, "Disabling exllama because desc_act=True") use_exllama = False + if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert): + return DefaultWeightsLoader.get_weights(weights, prefix) + try: qweight = weights.get_tensor(f"{prefix}.qweight") except RuntimeError: @@ -145,6 +160,10 @@ class GPTQWeightsLoader(WeightsLoader): prefix: str, block_sizes: Union[int, List[int]], ): + if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert): + return DefaultWeightsLoader.get_weights_col_packed( + weights, prefix, block_sizes + ) try: qweight = weights.get_packed_sharded( f"{prefix}.qweight", dim=1, block_sizes=block_sizes @@ -196,6 +215,8 @@ class GPTQWeightsLoader(WeightsLoader): ) def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert): + return DefaultWeightsLoader.get_multi_weights_col(weights, prefixes, dim) try: qweight = torch.cat( [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 @@ -263,6 +284,9 @@ class GPTQWeightsLoader(WeightsLoader): if self.bits != 4: use_exllama = False + if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert): + return DefaultWeightsLoader.get_weights_row(weights, prefix) + if self.desc_act: log_once(logger.warning, "Disabling exllama because desc_act=True") use_exllama = False diff --git a/backends/gaudi/server/text_generation_server/layers/layernorm.py b/backends/gaudi/server/text_generation_server/layers/layernorm.py index 84878791..4bbb6c1f 100644 --- 
a/backends/gaudi/server/text_generation_server/layers/layernorm.py +++ b/backends/gaudi/server/text_generation_server/layers/layernorm.py @@ -53,15 +53,10 @@ class FastRMSNorm(nn.Module): return cls(weight, eps) def forward(self, hidden_states, residual=None): - from vllm_hpu_extension.kernels import rms_norm - - orig_shape = hidden_states.shape if residual is not None: - residual += hidden_states.view(residual.shape) - else: - residual = hidden_states - # Note: HPUFusedRMSNorm requires 3D tensors as inputs - if len(orig_shape) == 2: - residual = residual.unsqueeze(0) - x = rms_norm().apply(residual, self.weight, self.variance_epsilon) - return x.view(orig_shape), residual.view(orig_shape) + hidden_states += residual + residual = hidden_states + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(self.weight.dtype), residual diff --git a/backends/gaudi/server/text_generation_server/layers/moe/fp8.py b/backends/gaudi/server/text_generation_server/layers/moe/fp8.py index 071b2abe..5365f24f 100644 --- a/backends/gaudi/server/text_generation_server/layers/moe/fp8.py +++ b/backends/gaudi/server/text_generation_server/layers/moe/fp8.py @@ -2,6 +2,7 @@ from typing import Optional import torch import torch.nn as nn +import os from text_generation_server.utils.weights import Weights from text_generation_server.layers.fp8 import ( @@ -9,12 +10,11 @@ from text_generation_server.layers.fp8 import ( fp8_quantize, quant_dtype, normalize_e4m3fn_to_native_float8, + dynamic_quant, + dequant_block_fp8_weight_naive, ) - -try: - from .unquantized import fused_moe -except Exception: - fused_moe = None +from text_generation_server.layers.moe.fused_moe import select_experts +import habana_frameworks.torch as htorch class FP8SparseMoELayer(nn.Module): @@ -47,6 +47,16 @@ class FP8SparseMoELayer(nn.Module): self.weight_block_size = weights.weights_loader.weight_block_size self.scoring_func = scoring_func self.e_score_correction_bias = e_score_correction_bias + self.world_size = weights.process_group.size() + self.rank = weights.process_group.rank() + self.ep_rank = self.rank + self.use_ep = os.getenv("USE_EXPERT_PARALLEL", "true").lower() == "true" + + if self.use_ep: + n_experts = (n_experts + self.world_size - 1) // self.world_size + self.ep_offset = self.ep_rank * n_experts + else: + self.ep_offset = 0 ( self.gate_up_proj, @@ -58,6 +68,8 @@ class FP8SparseMoELayer(nn.Module): gate_proj_name=gate_proj_name, up_proj_name=up_proj_name, weights=weights, + use_ep=self.use_ep, + ep_offset=self.ep_offset, ) self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = ( @@ -66,29 +78,89 @@ class FP8SparseMoELayer(nn.Module): n_experts=n_experts, name=down_proj_name, weights=weights, + use_ep=self.use_ep, + ep_offset=self.ep_offset, ) ) + if self.weight_block_size is not None: + self.gate_up_proj, self.gate_up_proj_weight_scale = dynamic_quant( + dequant_block_fp8_weight_naive( + self.gate_up_proj, + self.gate_up_proj_weight_scale, + self.weight_block_size, + ) + ) + self.down_proj, self.down_proj_weight_scale = dynamic_quant( + dequant_block_fp8_weight_naive( + self.down_proj, self.down_proj_weight_scale, self.weight_block_size + ) + ) + self.gate_up_proj_weight_scale, self.down_proj_weight_scale = ( + self.gate_up_proj_weight_scale.squeeze(-1), + self.down_proj_weight_scale.squeeze(-1), + ) def forward(self, x: torch.Tensor, *, 
gating_output: torch.Tensor) -> torch.Tensor: - return fused_moe( - x, - w1=self.gate_up_proj, - w2=self.down_proj, - gating_output=gating_output, - topk=self.topk, - renormalize=self.renormalize, - inplace=True, + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=gating_output, use_grouped_topk=self.n_expert_group is not None, - num_expert_group=self.n_expert_group, + top_k=self.topk, + renormalize=self.renormalize, topk_group=self.topk_group, + num_expert_group=self.n_expert_group, scoring_func=self.scoring_func, e_score_correction_bias=self.e_score_correction_bias, - use_fp8_w8a8=True, - w1_scale=self.gate_up_proj_weight_scale, - w2_scale=self.down_proj_weight_scale, - a1_scale=self.gate_up_proj_input_scale, - a2_scale=self.down_proj_input_scale, ) + total_num_experts = gating_output.size(-1) + x_fp8, x_scale = dynamic_quant(x, single_scale=True) + + if self.use_ep: + moe_n_slice = 1 + n_expert_slice = ( + total_num_experts + self.world_size - 1 + ) // self.world_size + else: + moe_n_slice = 1 + n_expert_slice = (total_num_experts + moe_n_slice - 1) // moe_n_slice + for i in range(moe_n_slice): + min_expert = i * n_expert_slice + max_expert = min((i + 1) * n_expert_slice, total_num_experts) + w13_list_slice = [ + self.gate_up_proj[j, ...] for j in range(min_expert, max_expert) + ] + w2_list_slice = [ + self.down_proj[j, ...] for j in range(min_expert, max_expert) + ] + w13_weight_scale = [ + self.gate_up_proj_weight_scale[j, ...] + for j in range(min_expert, max_expert) + ] + w2_weight_scale = [ + self.down_proj_weight_scale[j, ...] + for j in range(min_expert, max_expert) + ] + + current_hidden_states = torch.ops.hpu.mixture_of_experts( + hidden_states=x_fp8, + expert_routing_table=topk_ids.to(torch.int64), + router_weights=topk_weights.to(x.dtype), + w12=w13_list_slice, + w3=w2_list_slice, + d_scale_hidden_states=x_scale, + d_scale_w12=w13_weight_scale, + d_scale_w3=w2_weight_scale, + permuted_weights=True, + activation="silu", + experts_min=min_expert + self.ep_offset, + experts_max=max_expert + self.ep_offset - 1, + ) + htorch.core.mark_step() + if i == 0: + final_hidden_states = current_hidden_states + else: + final_hidden_states.add_(current_hidden_states) + return final_hidden_states def _load_expert_weights( @@ -98,13 +170,14 @@ def _load_expert_weights( n_experts: int, name: str, weights: Weights, + ep_offset: int = 0, ) -> torch.Tensor: all_weight = None all_weight_scales = None max_input_scale = None for i in range(n_experts): - weight = get_weight_fn(prefix, i, name, weights) + weight = get_weight_fn(prefix, i + ep_offset, name, weights) assert isinstance(weight, Fp8Weight) @@ -147,14 +220,26 @@ def _load_expert_multi_weights_col( gate_proj_name: str, up_proj_name: str, weights: Weights, + use_ep: bool = False, + ep_offset: int = 0, ) -> torch.Tensor: - def get_weight_fn(prefix, i, name, weights): + def get_weight_fn_sharded(prefix, i, name, weights): return weights.get_multi_weights_col( [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0 ) + def get_weight_fn(prefix, i, name, weights): + return weights.get_multi_weights( + [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0 + ) + return _load_expert_weights( - get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights + get_weight_fn if use_ep else get_weight_fn_sharded, + prefix=prefix, + n_experts=n_experts, + name=None, + weights=weights, + ep_offset=ep_offset if use_ep else 0, ) @@ -164,10 +249,20 @@ def _load_expert_weights_row( 
n_experts: int, name: str, weights: Weights, + use_ep: bool = False, + ep_offset: int = 0, ) -> torch.Tensor: - def get_weight_fn(prefix, i, name, weights): + def get_weight_fn_sharded(prefix, i, name, weights): return weights.get_weights_row(f"{prefix}.{i}.{name}") + def get_weight_fn(prefix, i, name, weights): + return weights.get_weights(f"{prefix}.{i}.{name}") + return _load_expert_weights( - get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights + get_weight_fn if use_ep else get_weight_fn_sharded, + prefix=prefix, + n_experts=n_experts, + name=name, + weights=weights, + ep_offset=ep_offset if use_ep else 0, ) diff --git a/backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py b/backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py index e26ff877..1987f0ed 100644 --- a/backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py +++ b/backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple +from typing import Tuple, Optional import torch @@ -25,12 +25,36 @@ def grouped_topk( renormalize: bool, num_expert_group: int = 0, topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - scores = torch.softmax(gating_output, dim=-1) + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" + + gating_output = gating_output.float() + if e_score_correction_bias is not None: + e_score_correction_bias = e_score_correction_bias.float() + + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + num_token = scores.shape[0] - group_scores = ( - scores.view(num_token, num_expert_group, -1).max(dim=-1).values - ) # [n, n_group] + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. 
We use biased + # scores for expert selection but original scores for routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + group_scores = ( + scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1) + ) + else: + group_scores = ( + scores.view(num_token, num_expert_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ 1 ] # [n, top_k_group] @@ -41,13 +65,19 @@ def grouped_topk( .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group) .reshape(num_token, -1) ) # [n, e] - tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] - topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e] + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) def fused_topk( @@ -63,3 +93,39 @@ def fused_topk( if renormalize: topk_weights /= topk_weights.sum(dim=-1, keepdim=True) return topk_weights, topk_ids + + +def select_experts( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, +): + + # DeekSeekv2 uses grouped_top_k + if use_grouped_topk: + assert topk_group is not None + assert num_expert_group is not None + topk_weights, topk_ids = grouped_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + else: + topk_weights, topk_ids = fused_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + ) + return topk_weights, topk_ids diff --git a/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py b/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py index ec158398..58709ec3 100644 --- a/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py +++ b/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py @@ -4,7 +4,9 @@ import torch import torch.nn as nn from text_generation_server.utils.weights import UnquantizedWeight, Weights -from vllm_hpu_extension.ops import DynamicFusedMOE +from vllm_hpu_extension.ops import VllmMixtureOfExpertsOp +import habana_frameworks.torch as htorch +import torch.nn.functional as F class UnquantizedSparseMoELayer(nn.Module): @@ -53,13 +55,29 @@ class UnquantizedSparseMoELayer(nn.Module): weights=weights, ) - self.hpu_fused_moe = DynamicFusedMOE(n_experts) + self.MoeOp = VllmMixtureOfExpertsOp(n_experts, 0, n_experts - 1) for i in range(n_experts): - self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i]) - self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.down_proj[i]) + self.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i]) + 
self.MoeOp.w2_list[i].set_weight(self.down_proj[i]) def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: - return self.hpu_fused_moe(x, gating_output, self.topk) + htorch.core.mark_step() + routing_weights = F.softmax(gating_output, dim=1, dtype=torch.float32) + routing_weights, selected_experts = torch.topk( + routing_weights, self.topk, dim=-1 + ) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(x.dtype) + + final_hidden_states = self.MoeOp( + hidden_states=x, + expert_routing_table=selected_experts, + router_weights=routing_weights, + permuted_weights=True, + activation="silu", + ) + + return final_hidden_states.view(-1, x.shape[1]) def _load_expert_multi_weights_col( diff --git a/backends/gaudi/server/text_generation_server/layers/rotary.py b/backends/gaudi/server/text_generation_server/layers/rotary.py index 6a83d6a5..7e740e5f 100644 --- a/backends/gaudi/server/text_generation_server/layers/rotary.py +++ b/backends/gaudi/server/text_generation_server/layers/rotary.py @@ -470,9 +470,6 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): mscale_all_dim: float, ): inv_freq = _create_inv_freq(dim, base, device) - super().__init__( - inv_freq, scaling_factor, max_position_embeddings * self.scaling_factor - ) self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base @@ -487,6 +484,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): / get_mscale(self.scaling_factor, mscale_all_dim) * self.attn_factor ) # Get n-d magnitude scaling corrected for interpolation + super().__init__(inv_freq, scaling_factor, max_position_embeddings) def _update_cos_sin_cache(self, dtype, device, seqlen): # Reset the tables if the sequence length has changed, diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py index 6ca7b567..a9a1d0b7 100644 --- a/backends/gaudi/server/text_generation_server/models/__init__.py +++ b/backends/gaudi/server/text_generation_server/models/__init__.py @@ -343,6 +343,7 @@ def get_model( quantize: Optional[str], speculate: Optional[int], dtype: Optional[torch.dtype], + kv_cache_dtype: Optional[str], trust_remote_code: bool, max_input_tokens: int, ) -> Model: @@ -468,7 +469,12 @@ def get_model( model_type = config_dict["model_type"] - kv_cache_dtype = dtype + if kv_cache_dtype == "fp8_e4m3fn": + kv_cache_dtype = torch.float8_e4m3fn + elif kv_cache_dtype == "fp8_e5m2": + kv_cache_dtype = torch.float8_e5m2 + else: + kv_cache_dtype = dtype if FLASH_ATTENTION: if model_type == DEEPSEEK_V2: @@ -934,6 +940,7 @@ def get_model_with_lora_adapters( quantize: Optional[str], speculate: Optional[int], dtype: Optional[torch.dtype], + kv_cache_dtype: Optional[str], trust_remote_code: bool, max_input_tokens: int, adapter_to_index: Dict[str, int], @@ -947,6 +954,7 @@ def get_model_with_lora_adapters( quantize, speculate, dtype, + kv_cache_dtype, trust_remote_code, max_input_tokens, ) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 3bcc689d..801ae09e 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -51,6 +51,8 @@ from habana_frameworks.torch.hpex.kernels import ( 
apply_rotary_pos_emb, ) +import habana_frameworks.torch as htorch + class CohereRotary(PositionRotaryEmbedding): def forward( @@ -420,7 +422,9 @@ class FlashCohereModel(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None - + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -433,6 +437,8 @@ class FlashCohereModel(torch.nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 15c243c9..76972d38 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -44,6 +44,7 @@ from text_generation_server.layers.layernorm import ( FastLayerNorm, ) from vllm_hpu_extension.ops import DynamicFusedMOE +import habana_frameworks.torch as htorch class DbrxAttentionConfig(PretrainedConfig): @@ -682,8 +683,10 @@ class DbrxModel(torch.nn.Module): # Get rotary cos and sin for this forward # Avoid to index in each layer cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(position_ids) - residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -696,6 +699,8 @@ class DbrxModel(torch.nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 9d61c694..6ac7fc1a 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -40,6 +40,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale from text_generation_server.utils.weights import Weights +import habana_frameworks.torch as htorch class DeepseekV2Config(PretrainedConfig): @@ -575,6 +576,9 @@ class DeepseekV2Model(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -587,6 +591,8 @@ class DeepseekV2Model(torch.nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py index 1a7ce5cf..e0481691 100644 --- 
a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py @@ -28,11 +28,12 @@ from text_generation_server.layers import ( TensorParallelEmbedding, TensorParallelRowLinear, get_linear, + Fp8Linear, ) from text_generation_server.layers.attention import ( Seqlen, attention, - paged_attention, + paged_attention_mla, HPUPagedAttentionMetadata, ) from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales @@ -40,6 +41,19 @@ from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale from text_generation_server.utils.weights import Weights +import habana_frameworks.torch as htorch + + +def get_and_maybe_dequant_weights(layer: torch.nn.Module) -> torch.Tensor: + if isinstance(layer, Fp8Linear): + eye = torch.eye( + layer.qweight.shape[-1], dtype=torch.bfloat16, device=layer.qweight.device + ) + dequant_weights = layer(eye) + del eye + # standardize to (output, input) + return dequant_weights.T + return layer.weight class DeepseekV3Config(PretrainedConfig): @@ -249,6 +263,44 @@ class DeepseekV3Attention(torch.nn.Module): 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device ).repeat_interleave(self.num_groups) + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj.linear).T + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.value_head_size, + ) + W_UK, W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.value_head_size], dim=-1 + ) + # Convert from (L, N, V) to (N, L, V) + self.W_UV = W_UV.transpose(0, 1) + # Convert from (L, N, P) to (N, P, L) + self.W_UK_T = W_UK.permute(1, 2, 0) + + def _q_proj_and_k_up_proj(self, x): + q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj + q_nope, q_pe = ( + q_proj(x) + .view(-1, self.num_heads, self.head_size) + .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + ) + + # Convert from (B, N, P) to (N, B, P) + q_nope = q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + ql_nope = torch.bmm(q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + return ql_nope.transpose(0, 1), q_pe + + def _v_up_proj_and_o_proj(self, x): + # Convert from (B, N, L) to (N, B, L) + x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + # Multiply (N, B, L) x (N, L, V) -> (N, B, V) + x = torch.bmm(x, self.W_UV) + # Convert from (N, B, V) to (B, N * V) + x = x.transpose(0, 1).reshape(-1, self.num_heads * self.value_head_size) + return self.o_proj(x) + def forward( self, hidden_states: torch.Tensor, @@ -261,14 +313,9 @@ class DeepseekV3Attention(torch.nn.Module): hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): if self.q_lora_rank is None: - query = self.q_proj(hidden_states) + hidden_states_or_q_c = hidden_states else: - query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0]) - query = query.view(-1, self.num_heads, self.head_size) - - _, query_pe = torch.split( - query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 - ) + hidden_states_or_q_c = self.q_a_layernorm(self.q_a_proj(hidden_states))[0] compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv, key_pe = torch.split( @@ -276,13 +323,18 @@ class 
DeepseekV3Attention(torch.nn.Module): ) key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim) - kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view( - -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size - ) + kv_c_normed = self.kv_a_layernorm(compressed_kv.contiguous())[0] - key_nope, value = torch.split( - kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1 - ) + # Prefill + if cu_seqlen_prefill is not None: + q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj + query = q_proj(hidden_states_or_q_c) + query = query.view(-1, self.num_heads, self.head_size) + query_nope, query_pe = torch.split( + query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + else: + query_nope, query_pe = self._q_proj_and_k_up_proj(hidden_states_or_q_c) batch_size, heads, head_dim = query_pe.shape query_pe = ( @@ -297,33 +349,47 @@ class DeepseekV3Attention(torch.nn.Module): .reshape(batch_size, heads, head_dim) ) self.rotary_emb(query_pe, key_pe, cos, sin) + latent_vec_k = torch.concat( + (kv_c_normed, key_pe.view(-1, self.qk_rope_head_dim)), dim=-1 + ) + latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank) - query[..., self.qk_nope_head_dim :] = query_pe - key = torch.empty_like(query) - key[..., : self.qk_nope_head_dim] = key_nope - key[..., self.qk_nope_head_dim :] = key_pe - - # We need to pad the heads because Flash Attention does not support - # qk and v with different head sizes. - query = torch.nn.functional.pad( - query, (0, self.head_pad_size - self.head_size), value=0 - ) - key = torch.nn.functional.pad( - key, (0, self.head_pad_size - self.head_size), value=0 - ) - value = torch.nn.functional.pad( - value, (0, self.head_pad_size - self.value_head_size), value=0 - ) + latent_vec_k = latent_vec_k.unflatten(0, (slots.size(0), -1)) kv_cache.store( - key=key, - value=value, + key=latent_vec_k, + value=None, slots=slots, kv_scales=self.kv_scales, ) - # Prefill if cu_seqlen_prefill is not None: + kv = self.kv_b_proj(kv_c_normed).view( + -1, + self.num_key_value_heads, + self.qk_nope_head_dim + self.value_head_size, + ) + + key_nope, value = torch.split( + kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1 + ) + query[..., self.qk_nope_head_dim :] = query_pe + key = torch.empty_like(query) + key[..., : self.qk_nope_head_dim] = key_nope + key[..., self.qk_nope_head_dim :] = key_pe + + # We need to pad the heads because Flash Attention does not support + # qk and v with different head sizes. + query = torch.nn.functional.pad( + query, (0, self.head_pad_size - self.head_size), value=0 + ) + key = torch.nn.functional.pad( + key, (0, self.head_pad_size - self.head_size), value=0 + ) + value = torch.nn.functional.pad( + value, (0, self.head_pad_size - self.value_head_size), value=0 + ) + # flash attention attn_output = attention( query=query, @@ -334,9 +400,15 @@ class DeepseekV3Attention(torch.nn.Module): seqlen=seqlen, softmax_scale=self.softmax_scale, ) - # Decode + attn_output = attn_output[..., : self.value_head_size] + + return self.o_proj( + attn_output.reshape(-1, self.num_heads * self.value_head_size) + ) else: - attn_output = paged_attention( + # Decode + query = torch.cat([query_nope, query_pe], dim=-1) + attn_output = paged_attention_mla( query, kv_cache, self.kv_head_mapping, @@ -344,14 +416,10 @@ class DeepseekV3Attention(torch.nn.Module): seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, + kv_lora_rank=self.kv_lora_rank, ) - - # Remove padding. 
- attn_output = attn_output[..., : self.value_head_size] - - return self.o_proj( - attn_output.reshape(-1, self.num_heads * self.value_head_size) - ) + attn_output = self._v_up_proj_and_o_proj(attn_output) + return attn_output class DeepseekV3MLP(nn.Module): @@ -584,6 +652,9 @@ class DeepseekV3Model(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -596,6 +667,8 @@ class DeepseekV3Model(torch.nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 79f21b0f..a5860823 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -46,6 +46,7 @@ from text_generation_server.layers.layernorm import ( FastRMSNorm, ) from text_generation_server.utils.weights import UnquantizedWeight +import habana_frameworks.torch as htorch class Gemma2Config(PretrainedConfig): @@ -472,6 +473,10 @@ class FlashGemma2Model(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() + for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -485,6 +490,8 @@ class FlashGemma2Model(torch.nn.Module): adapter_data, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 609f03ac..3d678df1 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -44,6 +44,7 @@ from text_generation_server.layers.layernorm import ( FastRMSNorm, ) from text_generation_server.utils.weights import UnquantizedWeight +import habana_frameworks.torch as htorch class GemmaConfig(PretrainedConfig): @@ -394,6 +395,9 @@ class FlashGemmaModel(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -406,6 +410,8 @@ class FlashGemmaModel(torch.nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 10024a6d..ed413662 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ 
-38,6 +38,7 @@ from text_generation_server.layers import ( get_linear, ) from text_generation_server.layers.attention.kv_cache import get_kv_scales +import habana_frameworks.torch as htorch def load_qkv(config, prefix: str, weights, head_size, num_heads): @@ -385,6 +386,10 @@ class FlashGPT2Model(torch.nn.Module): hidden_states = inputs_embeds residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() + for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -395,6 +400,8 @@ class FlashGPT2Model(torch.nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states = self.norm(hidden_states) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index 41eeab78..cde03a00 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -48,6 +48,7 @@ from habana_frameworks.torch.hpex.kernels import ( RotaryPosEmbeddingMode, apply_rotary_pos_emb, ) +import habana_frameworks.torch as htorch def load_attention(config, prefix: str, weights): @@ -330,6 +331,9 @@ class FlashGPTJModel(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -342,6 +346,8 @@ class FlashGPTJModel(torch.nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.ln_f(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 81af5560..0edea03a 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -26,7 +26,7 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN - +import habana_frameworks.torch as htorch from text_generation_server.layers.attention import ( KVCache, get_kv_scales, @@ -554,6 +554,9 @@ class FlashLlamaModel(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -568,6 +571,8 @@ class FlashLlamaModel(torch.nn.Module): cross_attention_states, hpu_attention_meta=hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index d23d4f67..75d9d360 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -45,6 +45,7 @@ from text_generation_server.layers.rotary 
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index d23d4f67..75d9d360 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -45,6 +45,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
    FastRMSNorm,
)
+import habana_frameworks.torch as htorch


class MistralConfig(PretrainedConfig):
@@ -401,6 +402,9 @@ class MistralModel(torch.nn.Module):
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

        residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
@@ -414,6 +418,8 @@ class MistralModel(torch.nn.Module):
                adapter_data,
                hpu_attention_meta,
            )
+            if lazy_mode:
+                htorch.core.mark_step()

        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states

diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index 1ef6be48..f47986d8 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -44,6 +44,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.utils.weights import UnquantizedWeight
+import habana_frameworks.torch as htorch


class MixtralConfig(PretrainedConfig):
@@ -452,6 +453,9 @@ class MixtralModel(torch.nn.Module):
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

        residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
@@ -464,6 +468,8 @@ class MixtralModel(torch.nn.Module):
                seqlen,
                hpu_attention_meta,
            )
+            if lazy_mode:
+                htorch.core.mark_step()

        hidden_states, _ = self.norm(hidden_states, residual)

diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 33f63333..29620826 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -47,6 +47,7 @@ from text_generation_server.layers.rotary import (
    PositionRotaryEmbedding,
)
from text_generation_server.utils.weights import UnquantizedWeight
+import habana_frameworks.torch as htorch


class GPTNeoXConfig(TransformersGPTNeoXConfig):
@@ -360,6 +361,9 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
        cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(position_ids)

        residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
@@ -372,6 +376,8 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                seqlen,
                hpu_attention_meta,
            )
+            if lazy_mode:
+                htorch.core.mark_step()

        hidden_states, _ = self.final_layer_norm(hidden_states, residual)
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
index 0c777912..12830991 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -26,6 +26,7 @@ from text_generation_server.layers.layernorm import (
from text_generation_server.layers.rotary import (
    PositionRotaryEmbedding,
)
+import habana_frameworks.torch as htorch


class PhiConfig(PretrainedConfig):
@@ -353,6 +354,9 @@ class FlashPhiModel(torch.nn.Module):
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

        residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
@@ -365,6 +369,8 @@ class FlashPhiModel(torch.nn.Module):
                seqlen,
                hpu_attention_meta,
            )
+            if lazy_mode:
+                htorch.core.mark_step()

        hidden_states, _ = self.norm(hidden_states, residual)

diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py
index bb585cc4..c28f3aee 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py
@@ -18,7 +18,6 @@ from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

-
logger = logging.get_logger(__name__)

diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index af4b404d..7c7ac03e 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -22,6 +22,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
    FastRMSNorm,
)
+import habana_frameworks.torch as htorch


def load_attention(config, prefix, weights):
@@ -294,6 +295,9 @@ class Qwen2Model(torch.nn.Module):
        )

        residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
        for i, layer in enumerate(self.layers):
            hidden_states = layer(
                hidden_states,
@@ -306,6 +310,8 @@ class Qwen2Model(torch.nn.Module):
                seqlen,
                hpu_attention_meta,
            )
+            if lazy_mode:
+                htorch.core.mark_step()

        hidden_states, _ = self.norm(hidden_states)
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 141e13a6..76a2cd01 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -21,6 +21,7 @@ from text_generation_server.layers.attention import (
    Seqlen,
    HPUPagedAttentionMetadata,
)
+import habana_frameworks.torch as htorch


def load_row(config, prefix: str, weights, bias: bool):
@@ -634,6 +635,9 @@ class FlashRWModel(FlashRWPreTrainedModel):
        cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(position_ids)

        residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
        for i, layer in enumerate(self.h):
            hidden_states, residual = layer(
                hidden_states,
@@ -646,6 +650,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
                seqlen,
                hpu_attention_meta,
            )
+            if lazy_mode:
+                htorch.core.mark_step()

        hidden_states, _ = self.ln_f(hidden_states, residual)

diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index b68f4784..c64b2ff7 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -23,6 +23,7 @@ from text_generation_server.layers.gptq import GPTQWeightsLoader
from text_generation_server.layers.layernorm import (
    FastLayerNorm,
)
+import habana_frameworks.torch as htorch


def load_multi_mqa(
@@ -442,6 +443,9 @@ class FlashSantacoderModel(nn.Module):
            torch.distributed.all_reduce(hidden_states, group=self.process_group)

        residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
@@ -452,6 +456,8 @@ class FlashSantacoderModel(nn.Module):
                seqlen,
                hpu_attention_meta,
            )
+            if lazy_mode:
+                htorch.core.mark_step()

        hidden_states, _ = self.ln_f(hidden_states, residual)

diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
index 76f6f473..94c60eb6 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
@@ -50,6 +50,7 @@ from text_generation_server.layers.rotary import (
    PositionRotaryEmbedding,
)
from text_generation_server.utils.weights import UnquantizedWeight
+import habana_frameworks.torch as htorch


class Starcoder2Config(PretrainedConfig):
@@ -517,6 +518,9 @@ class Starcoder2Model(torch.nn.Module):
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

        residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
@@ -530,6 +534,8 @@ class Starcoder2Model(torch.nn.Module):
                adapter_data,
                hpu_attention_meta,
            )
+            if lazy_mode:
+                htorch.core.mark_step()

        hidden_states, _ = self.norm(hidden_states, residual)
diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index ad585172..eb0f7454 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -53,6 +53,7 @@ from text_generation_server.models.globals import (
)
from text_generation_server.layers.attention import (
    KVCache,
+    KVCompressCache,
    Seqlen,
    HPUPagedAttentionMetadata,
    trim_attn_metadata,
@@ -68,11 +69,13 @@ from text_generation_server.utils.import_utils import (
    synchronize,
    get_free_memory,
)
-
+from text_generation_server.utils.prefill_chunking import (
+    get_max_prefill_tokens,
+)
import vllm_hpu_extension.environment as environment
import habana_frameworks.torch as htorch
import itertools
-from vllm_hpu_extension.bucketing import HPUBucketingContext
+from vllm_hpu_extension.bucketing.common import get_bucketing_context

tracer = trace.get_tracer(__name__)
@@ -153,7 +156,7 @@ def prepare_for_decode(
        block_groups_device, num_classes=batch_size
    )
    mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
-    mask = mask >= block_usage.unsqueeze(-1)
+    mask = mask >= block_usage_device.unsqueeze(-1)
    attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
    return trim_attn_metadata(
        HPUPagedAttentionMetadata(
@@ -425,7 +428,9 @@ class FlashCausalLMBatch(Batch):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
-        all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64)
+        all_input_ids_tensor = torch.tensor(
+            all_input_ids_tensor, dtype=torch.int64, device=device
+        )

        top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)
@@ -1438,15 +1443,17 @@ class FlashCausalLM(Model):
        self.kv_cache = []
        self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
        self.bucketing_ctx = None
+        htorch.core.hpu_set_env()
        if htorch.utils.internal.is_lazy():
            htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True)
        environment.set_model_config(self.config)
        self.use_contiguous_pa = (
            os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true"
        )
-        self.limit_hpu_graphs = (
-            os.environ.get("LIMIT_HPU_GRAPHS", "false").lower() == "true"
+        self.limit_hpu_graph = (
+            os.environ.get("LIMIT_HPU_GRAPH", "false").lower() == "true"
        )
+        self.max_seq_len_to_capture = 8192
        super().__init__(
            model_id=model_id,
            model=model,
@@ -1478,16 +1485,27 @@ class FlashCausalLM(Model):
    ):
        self.kv_cache = []
        empty_cache()
-        self.kv_cache = [
-            KVCache(
-                num_blocks=num_blocks,
-                num_heads=num_heads,
-                head_size=head_size,
-                dtype=dtype,
-                device=device,
-            )
-            for _ in range(num_layers)
-        ]
+        if self.config.model_type == "deepseek_v3":
+            self.kv_cache = [
+                KVCompressCache(
+                    num_blocks=num_blocks,
+                    head_size=self.config.kv_lora_rank + self.config.qk_rope_head_dim,
+                    dtype=dtype,
+                    device=device,
+                )
+                for _ in range(num_layers)
+            ]
+        else:
+            self.kv_cache = [
+                KVCache(
+                    num_blocks=num_blocks,
+                    num_heads=num_heads,
+                    head_size=head_size,
+                    dtype=dtype,
+                    device=device,
+                )
+                for _ in range(num_layers)
+            ]

    def warmup(
        self,
@@ -1495,6 +1513,11 @@ class FlashCausalLM(Model):
        max_input_tokens: Optional[int],
        max_total_tokens: Optional[int],
    ):
+        if os.environ.get("MAX_BATCH_SIZE") is None:
+            raise RuntimeError(
+                "MAX_BATCH_SIZE is not set, it should be set in the launcher "
+                "using `--max-batch-size xxx`"
+            )
        # The warmup batch is the biggest batch we could ever receive
        self.kv_cache = []
        empty_cache()
@@ -1502,8 +1525,14 @@ class FlashCausalLM(Model):
        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory
        dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
-        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
-        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
+        if self.config.model_type == "deepseek_v3":
+            cache_block_size = BLOCK_SIZE * (
+                self.config.kv_lora_rank + self.config.qk_rope_head_dim
+            )
+        else:
+            cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
+            cache_block_size = cache_block_size * 2
+        total_cache_size = self.num_layers * cache_block_size * dtype_size

        try:
            self.init_kv_cache(
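For DeepSeek-V3/R1 the compressed MLA cache stores a single latent per token (`kv_lora_rank + qk_rope_head_dim` elements) rather than separate K and V tensors, which is why the factor of 2 above only applies to the regular `KVCache` path. A rough sketch of the sizing arithmetic follows; the function name and the example shapes are illustrative only and are not values taken from this patch.

```python
def kv_cache_block_bytes(
    block_size: int,
    num_layers: int,
    dtype_bytes: int,
    *,
    mla: bool,
    num_kv_heads: int = 0,
    head_size: int = 0,
    kv_lora_rank: int = 0,
    qk_rope_head_dim: int = 0,
) -> int:
    """Bytes needed to hold one paged-attention block across all layers."""
    if mla:
        # Compressed cache: one latent vector per token, no separate K/V copies.
        per_token = kv_lora_rank + qk_rope_head_dim
    else:
        # Regular cache: K and V are stored separately, hence the factor 2.
        per_token = num_kv_heads * head_size * 2
    return num_layers * block_size * per_token * dtype_bytes


# Made-up shapes: 61 layers, 2-byte dtype, 128-token blocks.
regular = kv_cache_block_bytes(128, 61, 2, mla=False, num_kv_heads=8, head_size=128)
compressed = kv_cache_block_bytes(128, 61, 2, mla=True, kv_lora_rank=512, qk_rope_head_dim=64)
print(regular, compressed)  # the compressed MLA block is a fraction of the regular one
```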
@@ -1563,25 +1592,33 @@ class FlashCausalLM(Model):
                self.kv_cache_dtype,
                self.device,
            )
-
-        max_num_seqs = int(os.getenv("MAX_BATCH_SIZE", 128))
-        if os.getenv("VLLM_PROMPT_SEQ_BUCKET_MAX") is None:
-            os.environ["VLLM_PROMPT_SEQ_BUCKET_MAX"] = str(max_input_tokens)
-        if os.getenv("VLLM_DECODE_BLOCK_BUCKET_MAX") is None:
-            max_total_blocks = (
-                math.ceil(max_total_tokens / BLOCK_SIZE) * max_num_seqs + 1
-            )
-            os.environ["VLLM_DECODE_BLOCK_BUCKET_MAX"] = str(max_total_blocks)
-
+        self.max_batch_prefill_tokens = get_max_prefill_tokens()
+        max_num_seqs = int(os.getenv("MAX_BATCH_SIZE"))
+        HPUBucketingContext = get_bucketing_context()
+        max_total_tokens_aligned = math.ceil(max_total_tokens / BLOCK_SIZE) * BLOCK_SIZE
+        model_max_length = self.tokenizer.model_max_length
+        max_position_embeddings = getattr(
+            self.config, "max_position_embeddings", model_max_length
+        )
        self.bucketing_ctx = HPUBucketingContext(
            max_num_seqs,
-            os.getenv("PREFILL_MAX_BS", 64),  # self.max_num_prefill_seqs, #TODO
+            max_num_seqs,  # self.max_num_prefill_seqs, #TODO
            BLOCK_SIZE,
-            num_blocks * BLOCK_SIZE,
+            max_num_seqs * max_total_tokens_aligned,
            False,
+            min(model_max_length, max_position_embeddings),
+            max_input_tokens,
+            max_total_tokens_aligned,
        )
-        self.bucketing_ctx.num_hpu_blocks = num_blocks
+        max_blocks = (
+            max(BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE) + 1
+        )
+        self.bucketing_ctx.num_hpu_blocks = min(max_blocks, num_blocks)
        if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
+            self.bucketing_ctx.generate_prompt_buckets()
+            self.bucketing_ctx.generate_decode_buckets(
+                self.bucketing_ctx.num_hpu_blocks
+            )
            logger.info("skip warmup hpu graph, not recommmended")
            del _batch, batch
            return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
@@ -1591,28 +1628,55 @@ class FlashCausalLM(Model):

        return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens

+    def bypass_hpu_graphs(self, prefill, max_seq_len_to_capture):
+        if self.limit_hpu_graph:
+            return prefill
+        else:
+            return prefill and max_seq_len_to_capture > self.max_seq_len_to_capture
+
    def warmup_hpu_graph(self, batch):
+        start_time = time.time()
+        warmup_shape_count = 0
        warmup_times = 3
        self.bucketing_ctx.generate_prompt_buckets()
-        for i, (batch_size, seq_len) in enumerate(
-            reversed(self.bucketing_ctx.prompt_buckets)
-        ):
+
+        def ordering_function_min_tokens(b):
+            return (b[0] * b[1], b[1], b[0])
+
+        buckets = list(
+            sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
+        )
+
+        for i, (batch_size, seq_len) in enumerate(buckets):
+            if batch_size * seq_len > self.max_batch_prefill_tokens:
+                continue
+            warmup_shape_count += 1
            log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
            for index in range(warmup_times):
                self.warmup_prefill(seq_len, batch_size, batch)
+            synchronize(self.device)
+
+        def ordering_function_max_bs(b):
+            return (-b[0], b[1])
+
        self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
-        for i, (batch_size, block_num) in enumerate(
-            reversed(self.bucketing_ctx.decode_buckets)
-        ):
+        buckets = list(
+            sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
+        )
+        for i, (batch_size, block_num) in enumerate(buckets):
            if batch_size > block_num:
                continue
+            warmup_shape_count += 1
            log_master(
                logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
            )
            for index in range(warmup_times):
                self.warmup_decode(batch_size, block_num, batch)
-        synchronize(self.device)
+            synchronize(self.device)
+        log_master(
+            logger.info,
+            f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
+        )

    def warmup_prefill(
        self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch
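Warmup now walks the prompt buckets from the smallest total token count upwards and the decode buckets from the largest batch size downwards, skipping prompt shapes that exceed the prefill token budget. A stripped-down sketch of that ordering and filtering logic is below; the bucket lists and the budget value are invented examples, not output of the real bucketing context.

```python
# Hypothetical buckets: (batch_size, seq_len) for prefill, (batch_size, block_num) for decode.
prompt_buckets = [(1, 2048), (4, 512), (2, 128), (8, 1024)]
decode_buckets = [(8, 64), (32, 256), (16, 128)]
max_batch_prefill_tokens = 4096


def ordering_function_min_tokens(b):
    # Smallest batch_size * seq_len first; ties broken by seq_len, then batch size.
    return (b[0] * b[1], b[1], b[0])


def ordering_function_max_bs(b):
    # Largest batch size first, then fewest blocks.
    return (-b[0], b[1])


for batch_size, seq_len in sorted(prompt_buckets, key=ordering_function_min_tokens):
    if batch_size * seq_len > max_batch_prefill_tokens:
        # Shapes beyond the prefill budget are never executed, so skip warming them up.
        continue
    print(f"warmup prefill seq {seq_len} bs {batch_size}")

for batch_size, block_num in sorted(decode_buckets, key=ordering_function_max_bs):
    if batch_size > block_num:
        continue
    print(f"warmup decode bs {batch_size} block_num {block_num}")
```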
@@ -1643,7 +1707,9 @@ class FlashCausalLM(Model):
        lm_head_indices = input_lengths - 1
        kwargs = {}
        if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs
+            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
+                True, input_ids.shape[0]
+            )

        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
        self.model.forward(
@@ -1792,8 +1858,8 @@ class FlashCausalLM(Model):

        kwargs = {}
        if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = (
-                batch.prefilling if self.limit_hpu_graphs else False
+            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
+                batch.prefilling, input_ids.shape[0]
            )

        logits, speculative_logits = self.model.forward(
@@ -1836,9 +1902,7 @@ class FlashCausalLM(Model):
                accepted_ids,
                speculative_ids,
            ) = batch.next_token_chooser(
-                _async_h2d_tensor_copy(
-                    batch.all_input_ids_tensor[:, : batch.max_current_length]
-                ),
+                batch.all_input_ids_tensor[:, : batch.max_current_length],
                batch.next_token_logits,
                speculate,
                batch.speculative_ids,
@@ -1852,7 +1916,6 @@ class FlashCausalLM(Model):
                accepted_ids,
            )
            if batch.valid_indices is not None:
-                next_input_ids = next_input_ids.cpu()
                next_token_logprobs = next_token_logprobs.cpu()
                accepted_ids = accepted_ids.cpu()
                batch.all_input_ids_tensor = batch.all_input_ids_tensor[
@@ -1902,7 +1965,6 @@ class FlashCausalLM(Model):
            accepted_ids = accepted_ids.cpu()
            cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
            torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
-            next_input_ids = next_input_ids.cpu()
            if batch.speculative_logits is not None:
                for i in range(len(batch)):
                    batch.all_input_ids_tensor[
@@ -1914,7 +1976,7 @@ class FlashCausalLM(Model):
                    ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
            else:
                index = batch.cache_lengths_tensor + batch.input_lengths_tensor
-                index = index.to(batch.all_input_ids_tensor)
+                index = index.to(batch.all_input_ids_tensor.device)
                batch_idx = torch.arange(
                    0,
                    batch.all_input_ids_tensor.shape[0],
@@ -1924,6 +1986,7 @@ class FlashCausalLM(Model):
                batch.all_input_ids_tensor.index_put_(
                    (batch_idx, index.long()), next_input_ids
                )
+            next_input_ids = next_input_ids.cpu()
            batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
            batch.speculative_ids = speculative_ids
            if batch.position_ids.dim() == 2:
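The new `bypass_hpu_graphs` helper decides, per forward call, whether the wrapped model should skip HPU-graph execution: with `LIMIT_HPU_GRAPH` set every prefill bypasses graphs, otherwise only prefills longer than `max_seq_len_to_capture` run eagerly, and decode always stays inside graphs. A condensed sketch of that decision follows; `GraphPolicy` is a stand-in class, not the real `FlashCausalLM`.

```python
class GraphPolicy:
    """Illustrative stand-in for the graph-bypass decision in FlashCausalLM."""

    def __init__(self, limit_hpu_graph: bool, max_seq_len_to_capture: int = 8192):
        self.limit_hpu_graph = limit_hpu_graph
        self.max_seq_len_to_capture = max_seq_len_to_capture

    def bypass_hpu_graphs(self, prefill: bool, seq_len: int) -> bool:
        if self.limit_hpu_graph:
            # With LIMIT_HPU_GRAPH=true every prefill bypasses HPU graphs.
            return prefill
        # Otherwise only unusually long prefills are executed eagerly.
        return prefill and seq_len > self.max_seq_len_to_capture


policy = GraphPolicy(limit_hpu_graph=False)
print(policy.bypass_hpu_graphs(prefill=True, seq_len=4096))   # False: captured as a graph
print(policy.bypass_hpu_graphs(prefill=True, seq_len=16384))  # True: run eagerly
print(policy.bypass_hpu_graphs(prefill=False, seq_len=1))     # False: decode uses graphs
```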
diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
index b44ae03f..d9c57f20 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
@@ -23,6 +23,7 @@ from text_generation_server.layers.attention import (
    _async_h2d_tensor_copy,
)
import habana_frameworks.torch as htorch
+import time
from text_generation_server.utils.import_utils import (
    synchronize,
)
@@ -486,20 +487,32 @@ class FlashVlmCausalLM(FlashCausalLM):
        )

    def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):
+        start_time = time.time()
+        warmup_shape_count = 0
        warmup_times = 3
+        # only warmup decode, for prefill, image pixal size may change, make the warmup useless
+        def ordering_function_max_bs(b):
+            return (-b[0], b[1])
+
        self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
-        for i, (batch_size, block_num) in enumerate(
-            reversed(self.bucketing_ctx.decode_buckets)
-        ):
+        buckets = list(
+            sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
+        )
+        for i, (batch_size, block_num) in enumerate(buckets):
            if batch_size > block_num:
                continue
+            warmup_shape_count += 1
            log_master(
                logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
            )
            for index in range(warmup_times):
                self.warmup_decode(batch_size, block_num, batch)
-        synchronize(self.device)
+            synchronize(self.device)
+        log_master(
+            logger.info,
+            f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
+        )

    def forward(
        self,

diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
index 5de9bca8..0e5544f2 100644
--- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
@@ -32,6 +32,7 @@ from text_generation_server.utils.import_utils import (
)
import torch.nn.functional as F
from text_generation_server.utils.log import log_master
+import time

tracer = trace.get_tracer(__name__)
@@ -325,7 +326,9 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
        )
        kwargs = {}
        if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs
+            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
+                True, input_ids.shape[0]
+            )
        self.model.forward(
            input_ids=_async_h2d_tensor_copy(input_ids),
            position_ids=_async_h2d_tensor_copy(position_ids),
@@ -343,26 +346,47 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
        )

    def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch):
+        start_time = time.time()
+        warmup_shape_count = 0
        warmup_times = 3
        self.bucketing_ctx.generate_prompt_buckets()
-        for i, (batch_size, seq_len) in enumerate(
-            reversed(self.bucketing_ctx.prompt_buckets)
-        ):
+
+        def ordering_function_min_tokens(b):
+            return (b[0] * b[1], b[1], b[0])
+
+        buckets = list(
+            sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
+        )
+        for i, (batch_size, seq_len) in enumerate(buckets):
+            if batch_size * seq_len > self.max_batch_prefill_tokens:
+                continue
+            warmup_shape_count += 1
            log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
            for index in range(warmup_times):
                self.warmup_prefill(seq_len, batch_size, batch)
+            synchronize(self.device)
+
+        def ordering_function_max_bs(b):
+            return (-b[0], b[1])
+
        self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
-        for i, (batch_size, block_num) in enumerate(
-            reversed(self.bucketing_ctx.decode_buckets)
-        ):
+        buckets = list(
+            sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
+        )
+        for i, (batch_size, block_num) in enumerate(buckets):
            if batch_size > block_num:
                continue
+            warmup_shape_count += 1
            log_master(
                logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
            )
            for index in range(warmup_times):
                self.warmup_decode(batch_size, block_num, batch)
-        synchronize(self.device)
+            synchronize(self.device)
+        log_master(
+            logger.info,
+            f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
+        )

    def forward(
        self,
@@ -438,8 +462,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):

        kwargs = {}
        if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = (
-                batch.prefilling if self.limit_hpu_graphs else False
+            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
+                batch.prefilling, input_ids.shape[0]
            )
        if batch.prefill_cache_indices is not None:
            slots_pad = torch.zeros_like(input_ids)
diff --git a/backends/gaudi/server/text_generation_server/server.py b/backends/gaudi/server/text_generation_server/server.py
index f9250115..f5080ec3 100644
--- a/backends/gaudi/server/text_generation_server/server.py
+++ b/backends/gaudi/server/text_generation_server/server.py
@@ -206,6 +206,7 @@ def serve(
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
+    kv_cache_dtype: Optional[str],
    trust_remote_code: bool,
    uds_path: Path,
    max_input_tokens: int,
@@ -218,6 +219,7 @@ def serve(
        quantize: Optional[str] = None,
        speculate: Optional[int] = None,
        dtype: Optional[str] = None,
+        kv_cache_dtype: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        if not is_driver_compatible():
@@ -261,6 +263,7 @@ def serve(
                quantize,
                speculate,
                data_type,
+                kv_cache_dtype,
                trust_remote_code,
                max_input_tokens,
                adapter_to_index,
@@ -308,6 +311,7 @@ def serve(
            quantize,
            speculate,
            dtype,
+            kv_cache_dtype,
            trust_remote_code,
        )
    )

diff --git a/backends/gaudi/server/text_generation_server/utils/dist.py b/backends/gaudi/server/text_generation_server/utils/dist.py
index 1c45713e..9866710b 100644
--- a/backends/gaudi/server/text_generation_server/utils/dist.py
+++ b/backends/gaudi/server/text_generation_server/utils/dist.py
@@ -7,7 +7,7 @@ from loguru import logger
# Tensor Parallelism settings
RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
-MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.8"))
+MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.9"))


class FakeBarrier:

diff --git a/backends/gaudi/server/text_generation_server/utils/import_utils.py b/backends/gaudi/server/text_generation_server/utils/import_utils.py
index 22560dd7..d25484d6 100644
--- a/backends/gaudi/server/text_generation_server/utils/import_utils.py
+++ b/backends/gaudi/server/text_generation_server/utils/import_utils.py
@@ -1,17 +1,19 @@
import torch
from loguru import logger
+import habana_frameworks.torch as htorch
+import os


def get_hpu_free_memory(device, memory_fraction):
-    from habana_frameworks.torch.hpu import memory_stats
-
-    device_id = device.index
-    mem_stats = memory_stats(device_id)
-    logger.info(f"mem_stats: {mem_stats}")
-    total_free_memory = mem_stats["Limit"] - mem_stats["MaxInUse"]
-    free_memory = max(
-        0, int(total_free_memory - (1 - memory_fraction) * mem_stats["Limit"])
+    graph_reserved_mem = (
+        float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1"))
+        if htorch.utils.internal.is_lazy()
+        else 0
    )
+    free_memory = int(
+        torch.hpu.mem_get_info()[0] * memory_fraction * (1 - graph_reserved_mem)
+    )
+    logger.info(f"Free memory on device {device}: {free_memory} bytes.")
    return free_memory
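Free HPU memory is now derived from `torch.hpu.mem_get_info()`, with a slice held back for graph compilation when lazy mode is active, and the default `HPU_MEMORY_FRACTION` rises to 0.9. A small sketch of the resulting budget arithmetic is below; the function name and the byte figure are invented for illustration.

```python
def usable_kv_cache_bytes(
    device_free_bytes: int,
    memory_fraction: float = 0.9,     # HPU_MEMORY_FRACTION default in this patch
    graph_reserved_mem: float = 0.1,  # TGI_GRAPH_RESERVED_MEM default under lazy mode
    lazy_mode: bool = True,
) -> int:
    """Portion of free device memory the server may hand to the KV cache."""
    reserve = graph_reserved_mem if lazy_mode else 0.0
    return int(device_free_bytes * memory_fraction * (1 - reserve))


# With 96 GiB free (a made-up number), lazy mode leaves roughly 77.8 GiB for the cache.
print(usable_kv_cache_bytes(96 * 1024**3))
```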
diff --git a/backends/gaudi/server/text_generation_server/utils/quantization.py b/backends/gaudi/server/text_generation_server/utils/quantization.py
index a8faf4a5..022a4897 100644
--- a/backends/gaudi/server/text_generation_server/utils/quantization.py
+++ b/backends/gaudi/server/text_generation_server/utils/quantization.py
@@ -1,7 +1,7 @@
import json
import os
from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, List

from huggingface_hub import hf_hub_download
from text_generation_server.utils.weights import (
@@ -18,6 +18,8 @@ class _QuantizerConfig:
    groupsize: int
    quant_method: str
    sym: bool
+    weight_block_size: Optional[List[int]]
+    modules_to_not_convert: List[str]


@dataclass
@@ -25,7 +27,20 @@ class _FP8QuantizerConfig:
    activation_scale_ub: float


-# We should probably do this with Pytantic JSON deserialization,
+def _get_config_json(model_id: str, revision: Optional[str], filename: str):
+    if os.path.exists(
+        os.path.join(
+            model_id,
+        )
+    ):
+        filename = os.path.join(model_id, filename)
+    else:
+        filename = hf_hub_download(model_id, filename=filename, revision=revision)
+    with open(filename, "r") as f:
+        return json.load(f)
+
+
+# We should probably do this with Pydantic JSON deserialization,
# but for now we'll stay close to the old _set_gptq_params.
def _get_quantizer_config(model_id, revision):
    bits = 4
@@ -34,21 +49,18 @@ def _get_quantizer_config(model_id, revision):
    checkpoint_format = None
    sym = False
    desc_act = False
+    weight_block_size = None
+    modules_to_not_convert = []

    filename = "config.json"
    try:
-        if os.path.exists(os.path.join(model_id, filename)):
-            filename = os.path.join(model_id, filename)
-        else:
-            filename = hf_hub_download(model_id, filename=filename, revision=revision)
-        with open(filename, "r") as f:
-            data = json.load(f)
-
+        data = _get_config_json(model_id, revision, filename)
        # FP8 config
        if data["quantization_config"]["quant_method"] == "fbgemm_fp8":
            return _FP8QuantizerConfig(
                activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
            )
+        weight_block_size = data["quantization_config"].get("weight_block_size", None)

        if "zero_point" in data["quantization_config"]:
            sym = not data["quantization_config"]["zero_point"]
@@ -61,18 +73,16 @@ def _get_quantizer_config(model_id, revision):
        # Order is important here, desc_act is missing on some real models
        quant_method = data["quantization_config"]["quant_method"]
        checkpoint_format = data["quantization_config"].get("checkpoint_format")
-        desc_act = data["quantization_config"]["desc_act"]
+        desc_act = data["quantization_config"].get("desc_act", False)
+        modules_to_not_convert = data["quantization_config"].get(
+            "modules_to_not_convert", []
+        )
+        if modules_to_not_convert is None:
+            modules_to_not_convert = []
    except Exception:
        filename = "quantize_config.json"
        try:
-            if os.path.exists(os.path.join(model_id, filename)):
-                filename = os.path.join(model_id, filename)
-            else:
-                filename = hf_hub_download(
-                    model_id, filename=filename, revision=revision
-                )
-            with open(filename, "r") as f:
-                data = json.load(f)
+            data = _get_config_json(model_id, revision, filename)

            bits = data["bits"]
            groupsize = data["group_size"]
@@ -88,14 +98,7 @@ def _get_quantizer_config(model_id, revision):
        except Exception:
            filename = "quant_config.json"
            try:
-                if os.path.exists(os.path.join(model_id, filename)):
-                    filename = os.path.join(model_id, filename)
-                else:
-                    filename = hf_hub_download(
-                        model_id, filename=filename, revision=revision
-                    )
-                with open(filename, "r") as f:
-                    data = json.load(f)
+                data = _get_config_json(model_id, revision, filename)
                bits = data["w_bit"]
                groupsize = data["q_group_size"]
                desc_act = data["desc_act"]
@@ -111,6 +114,8 @@ def _get_quantizer_config(model_id, revision):
        checkpoint_format=checkpoint_format,
        sym=sym,
        desc_act=desc_act,
+        weight_block_size=weight_block_size,
+        modules_to_not_convert=modules_to_not_convert,
    )
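`_get_quantizer_config` now also surfaces `weight_block_size` and `modules_to_not_convert` from `quantization_config`, which is how block-quantized FP8 checkpoints in the DeepSeek-R1 style describe their layout. A hedged sketch of the kind of config fragment this code expects and how the two new fields are read; the JSON values are illustrative, not copied from any particular checkpoint.

```python
import json

# Hypothetical config.json fragment in the style of a block-quantized FP8 checkpoint.
config = json.loads("""
{
  "quantization_config": {
    "quant_method": "fp8",
    "weight_block_size": [128, 128],
    "modules_to_not_convert": ["lm_head"]
  }
}
""")

qcfg = config["quantization_config"]
weight_block_size = qcfg.get("weight_block_size", None)
modules_to_not_convert = qcfg.get("modules_to_not_convert", []) or []

print(weight_block_size)       # [128, 128] -> per-block FP8 scales
print(modules_to_not_convert)  # layers the loader should keep unquantized
```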
@@ -134,6 +139,7 @@ def get_loader(
            quant_method=quantizer_config.quant_method,
            quantize=quantize,
            sym=quantizer_config.sym,
+            modules_to_not_convert=quantizer_config.modules_to_not_convert,
        )
    elif quantize == "fp8" or quantize is None:
        from text_generation_server.layers.fp8 import HybridFP8UnquantLoader

@@ -141,9 +147,14 @@ def get_loader(
        # Since the default for the quantize config is _QuantizerConfig,
        # we need to add this check to not get an attribute error
        activation_scale_ub = None
+        weight_block_size = quantizer_config.weight_block_size
        if isinstance(quantizer_config, _FP8QuantizerConfig):
            activation_scale_ub = quantizer_config.activation_scale_ub

-        return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8")
+        return HybridFP8UnquantLoader(
+            activation_scale_ub,
+            to_fp8=quantize == "fp8",
+            weight_block_size=weight_block_size,
+        )
    else:
        raise ValueError(f"Unknown quantization method: {quantize}")

diff --git a/backends/gaudi/server/text_generation_server/utils/weights.py b/backends/gaudi/server/text_generation_server/utils/weights.py
index dec22942..4edae0d4 100644
--- a/backends/gaudi/server/text_generation_server/utils/weights.py
+++ b/backends/gaudi/server/text_generation_server/utils/weights.py
@@ -62,6 +62,14 @@ class WeightsLoader(ABC):
        """
        ...

+    @abstractmethod
+    def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
+        """
+        Get the weights at the given prefixes, column-split them for tensor
+        parallelim, and then concatenate the weights along the given dimension.
+        """
+        ...
+
    @abstractmethod
    def get_weights_row(self, weights: "Weights", prefix: str):
        """
@@ -130,6 +138,10 @@ class DefaultWeightsLoader(WeightsLoader):
            weights.get_sharded(f"{prefix}.weight", dim=1),
        )

+    def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
+        w = [weights.get_tensor(f"{p}.weight") for p in prefixes]
+        return self.weight_class(torch.cat(w, dim=dim))
+

class Weights:
    def __init__(
@@ -393,6 +405,9 @@ class Weights:
    def get_weights_row(self, prefix: str):
        return self.weights_loader.get_weights_row(self, prefix)

+    def get_multi_weights(self, prefixes: List[str], dim: int):
+        return self.weights_loader.get_multi_weights(self, prefixes, dim)
+
    @contextmanager
    def use_loader(self, weights_loader: WeightsLoader):
        """

diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs
index d3bf4b9c..8cfee3a5 100644
--- a/backends/v3/src/queue.rs
+++ b/backends/v3/src/queue.rs
@@ -8,6 +8,7 @@ use std::cmp::max;
use std::collections::VecDeque;
use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse;
+use text_generation_router::usage_stats::Env;
use text_generation_router::validation::{
    Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
    ValidStoppingParameters,
@@ -185,6 +186,9 @@ struct State {

    /// Paged Attention Block Allocation
    block_allocator: Option<BlockAllocator>,
+
+    /// indicate if it's hpu device, the hpu device needs padding to generate first token.
+    is_hpu_device: bool,
}

impl State {
@@ -214,6 +218,7 @@ impl State {
            speculate,
            support_chunking,
            block_allocator,
+            is_hpu_device: Env::new().is_hpu_device(),
        }
    }

@@ -368,6 +373,21 @@ impl State {
                    }
                }

+                if self.is_hpu_device {
+                    //HPU needs to pad for the prefill
+                    max_input_length = max_input_length.max(entry.request.input_length);
+                    let actual_prefill_tokens_for_hpu =
+                        (batch.len() + 1) as u32 * max_input_length;
+
+                    if actual_prefill_tokens_for_hpu > prefill_token_budget {
+                        // Entry is over budget
+                        // Add it back to the front
+                        tracing::debug!("Over budget: prefill_tokens={actual_prefill_tokens_for_hpu} > {prefill_token_budget}");
+                        self.entries.push_front((id, entry));
+                        break 'entry_loop;
+                    }
+                }
+
                prefill_tokens += postfix_len;

                Some(block_allocation)