From 201dc6294fb44da349449226865226247de84a27 Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Thu, 13 Mar 2025 19:21:44 -0700
Subject: [PATCH 01/35] clean cuda/rocm code in hpu backend, enable flat_hpu

Signed-off-by: Wang, Yi A
---
 Dockerfile_gaudi | 2 +-
 .../layers/attention/__init__.py | 51 +-
 .../layers/attention/common.py | 195 +-
 .../layers/attention/cuda.py | 357 ---
 .../layers/attention/flash_attn_triton.py | 813 ------
 .../layers/attention/flashinfer.py | 251 --
 .../layers/attention/hpu.py | 97 +
 .../layers/attention/ipex.py | 82 -
 .../layers/attention/kv_cache.py | 141 ++
 .../layers/attention/rocm.py | 308 ---
 .../layers/awq/quantize/__init__.py | 3 +
 .../awq/quantize/{qmodule.py => hpu.py} | 0
 .../layers/compressed_tensors/__init__.py | 3 +
 .../layers/compressed_tensors/loader.py | 196 ++
 .../layers/compressed_tensors/w8a8_int.py | 239 ++
 .../layers/compressed_tensors/w8an_fp.py | 168 ++
 .../layers/compressed_tensors/wna16_int.py | 188 ++
 .../layers/compressed_tensors/wna16_int_24.py | 101 +
 .../text_generation_server/layers/eetq.py | 43 -
 .../text_generation_server/layers/fp8.py | 380 ++-
 .../layers/gptq/__init__.py | 101 +-
 .../layers/gptq/custom_autotune.py | 261 --
 .../layers/gptq/exllama.py | 134 -
 .../layers/gptq/exllamav2.py | 267 --
 .../layers/gptq/ipex.py | 125 +
 .../layers/gptq/quant_linear.py | 359 ---
 .../layers/gptq/quantize.py | 17 +-
 .../layers/layernorm.py | 149 +-
 .../text_generation_server/layers/linear.py | 90 +-
 .../layers/marlin/fp8.py | 27 +-
 .../layers/marlin/gptq.py | 57 +-
 .../layers/marlin/marlin.py | 83 +-
 .../layers/marlin/util.py | 16 +-
 .../layers/moe/__init__.py | 55 +-
 .../text_generation_server/layers/moe/fp8.py | 173 ++
 .../{fused_moe_rocm.py => fused_moe_ipex.py} | 17 +-
 .../layers/moe/gptq_marlin.py | 215 --
 .../layers/moe/unquantized.py | 133 +-
 .../text_generation_server/layers/rotary.py | 140 +-
 .../layers/tensor_parallel.py | 41 +-
 .../text_generation_server/models/__init__.py | 693 +++++-
 .../models/custom_modeling/bloom_modeling.py | 4 +-
 .../custom_modeling/flash_cohere_modeling.py | 137 +-
 .../custom_modeling/flash_dbrx_modeling.py | 38 +-
 .../flash_deepseek_v2_modeling.py | 65 +-
 .../flash_deepseek_v3_modeling.py | 653 +++++
 .../custom_modeling/flash_gemma2_modeling.py | 30 +-
 .../custom_modeling/flash_gemma_modeling.py | 29 +-
 .../custom_modeling/flash_gpt2_modeling.py | 29 +-
 .../custom_modeling/flash_gptj_modeling.py | 84 +-
 .../custom_modeling/flash_llama_modeling.py | 195 +-
 .../custom_modeling/flash_mistral_modeling.py | 65 +-
 .../custom_modeling/flash_mixtral_modeling.py | 29 +-
 .../custom_modeling/flash_neox_modeling.py | 29 +-
 .../flash_pali_gemma_modeling.py | 1 +
 .../custom_modeling/flash_phi_modeling.py | 29 +-
 .../custom_modeling/flash_qwen2_modeling.py | 71 +-
 .../custom_modeling/flash_rw_modeling.py | 60 +-
 .../flash_santacoder_modeling.py | 29 +-
 .../flash_starcoder2_modeling.py | 147 +-
 .../models/custom_modeling/idefics2.py | 1 +
 .../models/custom_modeling/idefics3.py | 584 +++++
 .../custom_modeling/idefics_modeling.py | 104 +-
 .../models/custom_modeling/llava_next.py | 484 ++--
 .../models/custom_modeling/mamba_modeling.py | 10 +-
 .../models/custom_modeling/mllama.py | 51 +-
 .../models/custom_modeling/opt_modeling.py | 15 +-
 .../models/custom_modeling/qwen2_5_vl.py | 947 +++++++
 .../models/custom_modeling/qwen2_vl.py | 522 ++++
 .../models/custom_modeling/vlm.py | 8 +-
 .../models/flash_causal_lm.py | 2169 +++++++++---------
 .../text_generation_server/models/globals.py | 28 +-
 .../models/idefics_causal_lm.py | 21 +-
.../models/mllama_causal_lm.py | 133 +- .../text_generation_server/models/model.py | 1 + .../models/seq2seq_lm.py | 18 +- .../text_generation_server/utils/dist.py | 51 +- .../utils/import_utils.py | 69 +- .../text_generation_server/utils/kernels.py | 22 + .../text_generation_server/utils/weights.py | 20 +- 80 files changed, 7786 insertions(+), 5967 deletions(-) delete mode 100644 backends/gaudi/server/text_generation_server/layers/attention/cuda.py delete mode 100644 backends/gaudi/server/text_generation_server/layers/attention/flash_attn_triton.py delete mode 100644 backends/gaudi/server/text_generation_server/layers/attention/flashinfer.py create mode 100644 backends/gaudi/server/text_generation_server/layers/attention/hpu.py delete mode 100644 backends/gaudi/server/text_generation_server/layers/attention/ipex.py create mode 100644 backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py delete mode 100644 backends/gaudi/server/text_generation_server/layers/attention/rocm.py create mode 100644 backends/gaudi/server/text_generation_server/layers/awq/quantize/__init__.py rename backends/gaudi/server/text_generation_server/layers/awq/quantize/{qmodule.py => hpu.py} (100%) create mode 100644 backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py create mode 100644 backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py create mode 100644 backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8a8_int.py create mode 100644 backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8an_fp.py create mode 100644 backends/gaudi/server/text_generation_server/layers/compressed_tensors/wna16_int.py create mode 100644 backends/gaudi/server/text_generation_server/layers/compressed_tensors/wna16_int_24.py delete mode 100644 backends/gaudi/server/text_generation_server/layers/eetq.py delete mode 100644 backends/gaudi/server/text_generation_server/layers/gptq/custom_autotune.py delete mode 100644 backends/gaudi/server/text_generation_server/layers/gptq/exllama.py delete mode 100644 backends/gaudi/server/text_generation_server/layers/gptq/exllamav2.py create mode 100644 backends/gaudi/server/text_generation_server/layers/gptq/ipex.py delete mode 100644 backends/gaudi/server/text_generation_server/layers/gptq/quant_linear.py create mode 100644 backends/gaudi/server/text_generation_server/layers/moe/fp8.py rename backends/gaudi/server/text_generation_server/layers/moe/{fused_moe_rocm.py => fused_moe_ipex.py} (80%) delete mode 100644 backends/gaudi/server/text_generation_server/layers/moe/gptq_marlin.py create mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py create mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py create mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py create mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py create mode 100644 backends/gaudi/server/text_generation_server/utils/kernels.py diff --git a/Dockerfile_gaudi b/Dockerfile_gaudi index 6d37c6ae..14c507d0 100644 --- a/Dockerfile_gaudi +++ b/Dockerfile_gaudi @@ -96,7 +96,7 @@ RUN cd server && \ pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \ BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \ pip install . 
--no-cache-dir - +RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git # Install benchmarker COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark # Install router diff --git a/backends/gaudi/server/text_generation_server/layers/attention/__init__.py b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py index 4d83a11f..9ba9f6e0 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py @@ -1,43 +1,28 @@ -from text_generation_server.utils.import_utils import SYSTEM -import os +from .common import ( + Seqlen, + HPUPagedAttentionMetadata, + trim_attn_metadata, + trim_seqlen_metadata, +) -from .common import Seqlen +from .hpu import ( + SUPPORTS_WINDOWING, + attention, + paged_attention, +) -if os.getenv("USE_FLASH_ATTENTION", "true").lower() == "false": - raise ImportError("`USE_FLASH_ATTENTION` is false.") -if SYSTEM == "cuda": - from .cuda import ( - attention, - paged_attention, - reshape_and_cache, - SUPPORTS_WINDOWING, - PREFILL_IN_KV_CACHE, - ) -elif SYSTEM == "rocm": - from .rocm import ( - attention, - paged_attention, - reshape_and_cache, - PREFILL_IN_KV_CACHE, - SUPPORTS_WINDOWING, - ) -elif SYSTEM == "ipex": - from .ipex import ( - attention, - paged_attention, - reshape_and_cache, - PREFILL_IN_KV_CACHE, - SUPPORTS_WINDOWING, - ) -else: - raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") +# KVCache needs `reshape_and_cache`, so ensure that it is defined already. +from .kv_cache import KVCache, get_kv_scales __all__ = [ "attention", + "get_kv_scales", "paged_attention", - "reshape_and_cache", - "PREFILL_IN_KV_CACHE", "SUPPORTS_WINDOWING", + "KVCache", "Seqlen", + "HPUPagedAttentionMetadata", + "trim_seqlen_metadata", + "trim_attn_metadata", ] diff --git a/backends/gaudi/server/text_generation_server/layers/attention/common.py b/backends/gaudi/server/text_generation_server/layers/attention/common.py index d6e512c0..8ec9fb46 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/common.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/common.py @@ -1,72 +1,147 @@ from dataclasses import dataclass -from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.models.globals import ATTENTION import torch -from typing import Optional +from typing import Optional, List, Dict +import collections + +_TYPE_CACHE = {} -if ATTENTION in {"flashinfer", "flashdecoding"}: +@dataclass +class HPUPagedAttentionMetadata: + """Metadata for PagedAttention.""" - @dataclass - class Seqlen: - input_lengths: torch.Tensor - prefix_lengths: torch.Tensor - cu_seqlen_q: Optional[torch.Tensor] - cu_seqlen_k: Optional[torch.Tensor] - max_q: int - max_k: int + block_list: Optional[torch.Tensor] + block_mapping: Optional[torch.Tensor] + block_usage: Optional[torch.Tensor] + block_scales: Optional[torch.Tensor] + block_groups: Optional[torch.Tensor] + attn_bias: Optional[torch.Tensor] - def __init__( - self, - input_lengths, - prefix_lengths, - cu_seqlen_q=None, - max_q=None, - max_k=None, - ): - self.input_lengths = input_lengths - self.prefix_lengths = prefix_lengths - device = self.input_lengths.device - shape = self.input_lengths.shape - if cu_seqlen_q is None: - cu_seqlen_q = torch.arange( - shape[0] + 1, - device=device, - dtype=torch.int32, - ) - max_q = 1 - else: - assert max_q is not None - assert max_k is 
not None - cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32) - # cuda graphs don't like this and this is necessary to clamp within mistral - # Although FA2 might not want the clamping - # cu_seqlen_k[0] = 0 - total = self.input_lengths + self.prefix_lengths - torch.cumsum(total, -1, out=cu_seqlen_k[1:]) +def subtuple( + obj: object, + typename: str, + to_copy: List[str], + to_override: Optional[Dict[str, object]] = None, +): + if obj is None: + return None + if to_override is None: + to_override = {} + fields = set(to_copy) | set(to_override.keys()) + if isinstance(obj, dict): + values = {key: obj[key] for key in fields if key in obj} + else: + values = {f: to_override.get(f, getattr(obj, f)) for f in fields} + if typename not in _TYPE_CACHE: + _TYPE_CACHE[typename] = collections.namedtuple(typename, " ".join(fields)) + return _TYPE_CACHE[typename](**values) - self.cu_seqlen_q = cu_seqlen_q - self.cu_seqlen_k = cu_seqlen_k - self.max_q = max_q - self.max_k = max_k - def clamp(self, max): - # Flash decoding doesn't need to clamp - return self +def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object: + # NOTE(kzawora): To anyone working on this in the future: + # Trimming metadata is required when using HPUGraphs. + # Attention metadata is going to be hashed by PT bridge, and + # appropriate HPUGraphs will be matched based on all inputs' hash. -else: + # Before you put more keys in here, make sure you know their + # value type and make sure you know how it's going to be hashed. + # You can find that information in input_hash function + # in habana_frameworks/torch/hpu/graphs.py. You can also hash + # it manually with torch.hpu.graphs.input_hash(attention_metadata) - @dataclass - class Seqlen: - input_lengths: torch.Tensor - prefix_lengths: torch.Tensor - cu_seqlen_q: torch.Tensor - max_q: int - max_k: int + # If you use primitive types here - they will get hashed based + # on their value. You *will* get lots of excessive graph captures + # (and an OOM eventually) if you decide to put something like + # seq_len int here. + # If you absolutely need a scalar, put it in a tensor. 
Tensors + # get hashed using their metadata, not their values: + # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321)) + # input_hash(123) != input_hash(321) + # input_hash("abc") != input_hash("cba") + attention_metadata = subtuple( + metadata, + "TrimmedAttentionMetadata", + [ + "block_list", + "block_mapping", + "block_usage", + "block_scales", + "block_groups", + "attn_bias", + ], + ) + return attention_metadata - def clamp(self, max): - if SYSTEM == "rocm": - return self - raise NotImplementedError("Not implemented seqlen for paged") - return Seqlen(torch.clamp(self.input_lengths, max=max)) + +@dataclass +class Seqlen: + input_lengths: torch.Tensor + cache_lengths: torch.Tensor + cu_seqlen_q: Optional[torch.Tensor] + cu_seqlen_k: Optional[torch.Tensor] + + def __init__( + self, + input_lengths, + cache_lengths, + cu_seqlen_q=None, + ): + self.input_lengths = input_lengths + self.cache_lengths = cache_lengths + device = self.input_lengths.device + shape = self.input_lengths.shape + if cu_seqlen_q is None: + cu_seqlen_q = torch.arange( + shape[0] + 1, + device=device, + dtype=torch.int32, + ) + cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32) + + # cuda graphs don't like this and this is necessary to clamp within mistral + # Although FA2 might not want the clamping + # cu_seqlen_k[0] = 0 + total = self.input_lengths + self.cache_lengths + torch.cumsum(total, -1, out=cu_seqlen_k[1:]) + + self.cu_seqlen_q = cu_seqlen_q + self.cu_seqlen_k = cu_seqlen_k + + def clamp(self, max): + # Flash decoding doesn't need to clamp + return self + + +def trim_seqlen_metadata(metadata: Seqlen) -> object: + # NOTE(kzawora): To anyone working on this in the future: + # Trimming metadata is required when using HPUGraphs. + # Attention metadata is going to be hashed by PT bridge, and + # appropriate HPUGraphs will be matched based on all inputs' hash. + + # Before you put more keys in here, make sure you know their + # value type and make sure you know how it's going to be hashed. + # You can find that information in input_hash function + # in habana_frameworks/torch/hpu/graphs.py. You can also hash + # it manually with torch.hpu.graphs.input_hash(attention_metadata) + + # If you use primitive types here - they will get hashed based + # on their value. You *will* get lots of excessive graph captures + # (and an OOM eventually) if you decide to put something like + # seq_len int here. + # If you absolutely need a scalar, put it in a tensor. 
Tensors + # get hashed using their metadata, not their values: + # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321)) + # input_hash(123) != input_hash(321) + # input_hash("abc") != input_hash("cba") + attention_metadata = subtuple( + metadata, + "TrimmedSeqlen", + [ + "input_lengths", + "cache_lengths", + "cu_seqlen_q", + "cu_seqlen_k", + ], + ) + return attention_metadata diff --git a/backends/gaudi/server/text_generation_server/layers/attention/cuda.py b/backends/gaudi/server/text_generation_server/layers/attention/cuda.py deleted file mode 100644 index 51af928d..00000000 --- a/backends/gaudi/server/text_generation_server/layers/attention/cuda.py +++ /dev/null @@ -1,357 +0,0 @@ -import torch -from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.models.globals import ( - ATTENTION, - BLOCK_SIZE, -) -from text_generation_server.layers.attention import Seqlen -from typing import Optional - -major, minor = torch.cuda.get_device_capability() -is_sm75 = major == 7 and minor == 5 -_PARTITION_SIZE = 512 - -try: - from vllm._C import cache_ops -except Exception as e: - raise ImportError( - f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}" - ) - - -def reshape_and_cache( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slots: torch.Tensor, -): - if ATTENTION in {"flashdecoding", "flashinfer"}: - shape = key_cache.shape - key_cache.view(-1, shape[-2], shape[-1])[slots] = key - value_cache.view(-1, shape[-2], shape[-1])[slots] = value - else: - cache_ops.reshape_and_cache( - key, value, key_cache, value_cache, slots, "auto", 1.0 - ) - - -def paged_attention( - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - kv_head_mapping: torch.Tensor, - softmax_scale: float, - block_tables: torch.Tensor, - seqlen: Seqlen, - max_s: int, - softcap: Optional[float] = None, -): - # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py - # Copyright 2023 The vLLM team. All rights - # reserved. - # - # Licensed under the Apache License, Version 2.0 (the "License"); - # you may not use this file except in compliance with the License. - # You may obtain a copy of the License at - # - # http://www.apache.org/licenses/LICENSE-2.0 - # - # Unless required by applicable law or agreed to in writing, software - # distributed under the License is distributed on an "AS IS" BASIS, - # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - # See the License for the specific language governing permissions and - # limitations under the License. - # - - # value_cache => [num_blocks, num_heads, head_size, block_size] - # block_size = value_cache.shape[3] - block_size = BLOCK_SIZE - num_seqs, num_heads, head_size = query.shape - max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE - - # NOTE(woosuk): We use a simple heuristic to decide whether to use - # PagedAttention V1 or V2. If the number of partitions is 1, we use - # V1 to avoid the overhead of reduction. Also, if the number of - # sequences or heads is large, we use V1 since there is enough work - # to parallelize. 
- if ATTENTION == "flashinfer": - from text_generation_server.layers.attention.flashinfer import decode_state - - return decode_state.get().forward( - query.contiguous(), - paged_kv_cache=(key_cache, value_cache), - logits_soft_cap=softcap, - sm_scale=softmax_scale, - ) - elif ATTENTION == "flashdecoding": - max_q = 1 - max_k = max_s - import flash_attn_2_cuda - - # TODO fixme when flash contains the fix. - # Number of splits is not correctly handled - # by the current path - # https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577 - # This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied. - if softcap is None: - softcap = 0.0 - out = flash_attn_2_cuda.varlen_fwd( - query, - key_cache, - value_cache, - None, - seqlen.cu_seqlen_q, - seqlen.cu_seqlen_k, - None, # pad_k - None, - block_tables, - None, - max_q, - max_k, - 0.0, # dropout - softmax_scale, - False, # zero_tensors - True, # causal - -1, # Window_left - -1, # Window right - softcap, - False, # return softmax - None, # generator - ) - return out[0] - else: - if softcap is not None: - raise RuntimeError("Paged attention doesn't support softcapping") - input_lengths = seqlen.input_lengths - from vllm._C import ops - - out = torch.empty_like(query) - - use_v1 = max_s <= 8192 and ( - max_num_partitions == 1 or num_seqs * num_heads > 512 - ) - if use_v1: - ops.paged_attention_v1( - out, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - 1.0, - ) - else: - # Run PagedAttention V2. - assert _PARTITION_SIZE % block_size == 0 - tmp_output = torch.empty( - size=(num_seqs, num_heads, max_num_partitions, head_size), - dtype=out.dtype, - device=out.device, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, max_num_partitions), - dtype=torch.float32, - device=out.device, - ) - max_logits = torch.empty_like(exp_sums) - - ops.paged_attention_v2( - out, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - 1.0, - ) - return out - - -try: - is_ampere_or_newer = major >= 8 and minor >= 0 - if not is_ampere_or_newer: - raise ImportError("FlashAttention only supports Ampere GPUs or newer.") - - import flash_attn_2_cuda - - V2 = True -except ImportError: - try: - import flash_attn_cuda - - V2 = False - except ImportError as e: - if major >= 8: - architecture_suffix = f"-{SYSTEM}" - raise ImportError( - "Flash Attention V2 is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`" - ) - elif is_sm75: - raise ImportError( - "Flash Attention is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - "or install flash attention with `cd server && make install install-flash-attention`" - ) from e - else: - raise ImportError( - f"GPU with CUDA capability {major} {minor} is not supported" - ) from e - - -SUPPORTS_WINDOWING = V2 - -if ATTENTION == "flashinfer": - - def attention( - q: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - seqlen: Seqlen, - block_tables: torch.Tensor, - softmax_scale, - window_size_left=-1, - causal=True, - softcap=0.0, - ): - from 
text_generation_server.layers.attention.flashinfer import ( - prefill_with_paged_kv_state, - ) - - return prefill_with_paged_kv_state.get().forward( - q.contiguous(), - causal=causal, - paged_kv_cache=(key_cache, value_cache), - logits_soft_cap=softcap, - sm_scale=softmax_scale, - window_left=window_size_left, - ) - -elif V2: - - def attention( - q, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - seqlen: Seqlen, - block_tables: torch.Tensor, - softmax_scale, - window_size_left=-1, - causal=True, - softcap=0.0, - ): - out = torch.empty_like(q) - if window_size_left <= 0 and window_size_left != -1: - raise ValueError("`window_size_left` must be > 0 or -1") - return flash_attn_2_cuda.varlen_fwd( - q, - key_cache, - value_cache, - out, - seqlen.cu_seqlen_q, - seqlen.cu_seqlen_k, - None, - None, - block_tables, - None, - seqlen.max_q, - seqlen.max_k, - 0.0, - softmax_scale, - False, - causal, - window_size_left, - 0, - softcap, - False, - None, - )[0] - -else: - - def attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - seqlen: Seqlen, - block_tables: torch.Tensor, - softmax_scale: float, - window_size_left: int = -1, - causal: bool = True, - softcap=None, - ): - if window_size_left != -1: - raise NotImplementedError( - "window_size_left is only available with flash attn v2" - ) - if softcap is not None: - raise NotImplementedError("softcap is only available with flash attn v2") - - # Flash attention v1 requires q, k and v to have the same number of heads - if k.shape[1] != q.shape[1]: - # MQA expand - if k.shape[1] == 1: - k = k.expand(-1, q.shape[1], -1) - # Grouped attention reshape - else: - original_shape = k.shape - k = ( - k.unsqueeze(2) - .expand(-1, -1, q.shape[1] // k.shape[1], -1) - .reshape(original_shape[0], -1, original_shape[2]) - ) - if v.shape[1] != q.shape[1]: - # MQA expand - if v.shape[1] == 1: - v = v.expand(-1, q.shape[1], -1) - # Grouped attention reshape - else: - original_shape = v.shape - v = ( - v.unsqueeze(2) - .expand(-1, -1, q.shape[1] // v.shape[1], -1) - .reshape(original_shape[0], -1, original_shape[2]) - ) - - out = torch.empty_like(q) - flash_attn_cuda.fwd( - q, - k, - v, - out, - seqlen.cu_seqlen_q, - seqlen.cu_seqlen_q, - seqlen.max_q, - seqlen.max_k, - 0.0, - softmax_scale, - False, - causal, - False, - 0, - None, - ) - return out - - -# Prefill in the cache with every kind of attention, unless we -# have a configuration that requires flash-attention v1, which -# does not support block tables. -PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2 diff --git a/backends/gaudi/server/text_generation_server/layers/attention/flash_attn_triton.py b/backends/gaudi/server/text_generation_server/layers/attention/flash_attn_triton.py deleted file mode 100644 index 3a6f9a73..00000000 --- a/backends/gaudi/server/text_generation_server/layers/attention/flash_attn_triton.py +++ /dev/null @@ -1,813 +0,0 @@ -#!/usr/bin/env python -""" -Fused Attention -=============== - -This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao -(https://tridao.me/publications/flash2/flash2.pdf) -Credits: OpenAI kernel team, AMD ML Frameworks Triton team - -Features supported: - -1) Fwd with causal masking -2) Any sequence lengths without padding (currently fwd kernel only) -3) Support for different sequence lengths for q and k -4) Nested tensor API currently does not support dropout or bias. 
- -Not currently supported: - -1) Non power of two head dims - -""" - -import torch -import triton -import triton.language as tl - -torch_dtype: tl.constexpr = torch.float16 - - -@triton.jit -def cdiv_fn(x, y): - return (x + y - 1) // y - - -@triton.jit -def max_fn(x, y): - return tl.math.max(x, y) - - -@triton.jit -def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): - ms = tl.arange(0, m) - ns = tl.arange(0, n) - return philox_offset + ms[:, None] * stride + ns[None, :] - - -@triton.jit -def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_offsets = dropout_offsets( - philox_seed, philox_offset, dropout_p, m, n, stride - ).to(tl.uint32) - # TODO: use tl.randint for better performance - return tl.rand(philox_seed, rng_offsets) - - -@triton.jit -def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride) - rng_keep = rng_output > dropout_p - return rng_keep - - -@triton.jit -def load_fn(block_ptr, first, second, pad): - if first and second: - tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) - elif first: - tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad) - elif second: - tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad) - else: - tensor = tl.load(block_ptr) - return tensor - - -@triton.jit -def _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - actual_seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - block_min, - block_max, - offs_n_causal, - masked_blocks, - n_extra_tokens, - bias_ptr, - IS_CAUSAL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - OFFS_M: tl.constexpr, - OFFS_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - MASK_STEPS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_ENCODED_SOFTMAX: tl.constexpr, - PADDED_HEAD: tl.constexpr, -): - # loop over k, v, and update accumulator - for start_n in range(block_min, block_max, BLOCK_N): - # For padded blocks, we will overrun the tensor size if - # we load all BLOCK_N. For others, the blocks are all within range. - k = load_fn( - K_block_ptr, - PADDED_HEAD, - MASK_STEPS and (n_extra_tokens != 0), - "zero", - ) - if PRE_LOAD_V: - v = load_fn( - V_block_ptr, - MASK_STEPS and (n_extra_tokens != 0), - PADDED_HEAD, - "zero", - ) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - # We start from end of seqlen_k so only the first iteration would need - # to be checked for padding if it is not a multiple of block_n - # TODO: This can be optimized to only be true for the padded block. - if MASK_STEPS: # noqa: SIM102 - # If this is the last block / iteration, we want to - # mask if the sequence length is not a multiple of block size - # a solution is to always do BLOCK_M // BLOCK_N + 1 steps - # if not is_modulo_mn. last step might get wasted but that is okay. - # check if this masking works for that case. 
- if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): - boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) - size_n = start_n + OFFS_N[None, :] - mask = size_n < boundary_m[:, None] - qk = tl.where(mask, qk, float("-inf")) - if IS_CAUSAL: - causal_boundary = start_n + offs_n_causal - causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] - qk = tl.where(causal_mask, qk, float("-inf")) - # -- compute qk ---- - qk += tl.dot(q, k) - if bias_ptr is not None: - bias = load_fn( - bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero" - ) - # While bias is added after multiplying qk with sm_scale, our - # optimization to use 2^x instead of e^x results in an additional - # scale factor of log2(e) which we must also multiply the bias with. - qk += bias * 1.44269504089 - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk = qk - m_ij[:, None] - p = tl.math.exp2(qk) - - # CAVEAT: Must update l_ij before applying dropout - l_ij = tl.sum(p, 1) - if ENABLE_DROPOUT: - philox_offset = ( - batch_philox_offset - + start_m * BLOCK_M * actual_seqlen_k - + start_n - - BLOCK_N - ) - keep = dropout_mask( - philox_seed, - philox_offset, - dropout_p, - BLOCK_M, - BLOCK_N, - actual_seqlen_k, - ) - if RETURN_ENCODED_SOFTMAX: - tl.store( - encoded_softmax_block_ptr, - tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty), - ) - p = tl.where(keep, p, 0.0) - elif RETURN_ENCODED_SOFTMAX: - tl.store( - encoded_softmax_block_ptr, - p.to(encoded_softmax_block_ptr.type.element_ty), - ) - # -- update output accumulator -- - alpha = tl.math.exp2(m_i - m_ij) - acc = acc * alpha[:, None] - if not PRE_LOAD_V: - v = load_fn( - V_block_ptr, - MASK_STEPS and (n_extra_tokens != 0), - PADDED_HEAD, - "zero", - ) - # -- update m_i and l_i - l_i = l_i * alpha + l_ij - # update m_i and l_i - m_i = m_ij - acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) - if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance( - encoded_softmax_block_ptr, (0, BLOCK_N) - ) - return acc, l_i, m_i - - -@triton.autotune( - configs=[ - triton.Config( - { - "BLOCK_M": 256, - "BLOCK_N": 64, - "waves_per_eu": 2, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=8, - ), - triton.Config( - { - "BLOCK_M": 128, - "BLOCK_N": 128, - "waves_per_eu": 2, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=4, - ), - triton.Config( - { - "BLOCK_M": 256, - "BLOCK_N": 128, - "waves_per_eu": 2, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=8, - ), - triton.Config( - { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 3, - "PRE_LOAD_V": True, - }, - num_stages=1, - num_warps=4, - ), - triton.Config( - { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 3, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=4, - ), - triton.Config( - { - "BLOCK_M": 64, - "BLOCK_N": 64, - "waves_per_eu": 4, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=8, - ), - triton.Config( - { - "BLOCK_M": 32, - "BLOCK_N": 32, - "waves_per_eu": 4, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=8, - ), - # TODO: This config fails with head_size not pow2 with data mismatches. 
- # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, - # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config( - { - "BLOCK_M": 16, - "BLOCK_N": 16, - "waves_per_eu": 1, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=4, - ), - triton.Config( - { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 1, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=4, - ), - ], - key=["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"], -) -@triton.jit -def attn_fwd( - Q, - K, - V, - bias, - sm_scale, - L, - Out, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vk, - stride_vn, - stride_oz, - stride_oh, - stride_om, - stride_on, - stride_bz, - stride_bh, - stride_bm, - stride_bn, - cu_seqlens_q, - cu_seqlens_k, - dropout_p, - philox_seed, - philox_offset_base, - encoded_softmax, - HQ: tl.constexpr, - HK: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, - MAX_SEQLENS_Q: tl.constexpr, - MAX_SEQLENS_K: tl.constexpr, - VARLEN: tl.constexpr, - IS_CAUSAL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - BIAS_TYPE: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_ENCODED_SOFTMAX: tl.constexpr, -): - start_m = tl.program_id(0) - off_h_q = tl.program_id(1) - off_z = tl.program_id(2) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - if VARLEN: - cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) - cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) - seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start - # We have a one-size-fits-all grid in id(0). Some seqlens might be too - # small for all start_m so for those we return early. - if start_m * BLOCK_M > seqlen_q: - return - cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) - cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) - seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start - else: - cu_seqlens_q_start = 0 - cu_seqlens_k_start = 0 - seqlen_q = MAX_SEQLENS_Q - seqlen_k = MAX_SEQLENS_K - - # Now we compute whether we need to exit early due to causal masking. - # This is because for seqlen_q > seqlen_k, M rows of the attn scores - # are completely masked, resulting in 0s written to the output, and - # inf written to LSE. We don't need to do any GEMMs in this case. - # This block of code determines what N is, and if this WG is operating - # on those M rows. - n_blocks = cdiv_fn(seqlen_k, BLOCK_N) - if IS_CAUSAL: - # If seqlen_q == seqlen_k, the attn scores are a square matrix. - # If seqlen_q != seqlen_k, attn scores are rectangular which means - # the causal mask boundary is bottom right aligned, and ends at either - # the top edge (seqlen_q < seqlen_k) or left edge. - # This captures the decrease in n_blocks if we have a rectangular attn - # matrix - n_blocks_seqlen = cdiv_fn( - (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N - ) - # This is what adjusts the block_max for the current WG, only - # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks - n_blocks = min(n_blocks, n_blocks_seqlen) - # If we have no blocks after adjusting for seqlen deltas, this WG is - # part of the blocks that are all 0. We exit early. 
- if n_blocks <= 0: - o_offset = ( - off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh - ) - O_block_ptr = tl.make_block_ptr( - base=Out + o_offset, - shape=(seqlen_q, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) - # We still need to write 0s to the result - # tl.store(O_block_ptr, - # acc.to(Out.type.element_ty), boundary_check=(0,1)) - # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q - # + offs_m - # We store inf to LSE, not -inf because in the bwd pass, - # we subtract this - # from qk which makes it -inf, such that exp(qk - inf) = 0 - # for these masked blocks. - # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) - # tl.store(l_ptrs, l) - # TODO: Should dropout and return encoded softmax be handled here? - return - - # If MQA / GQA, set the K and V head offsets appropriately. - GROUP_SIZE: tl.constexpr = HQ // HK - if GROUP_SIZE != 1: - off_h_k = off_h_q // GROUP_SIZE - else: - off_h_k = off_h_q - - n_extra_tokens = 0 - if seqlen_k < BLOCK_N: - n_extra_tokens = BLOCK_N - seqlen_k - elif seqlen_k % BLOCK_N: - n_extra_tokens = seqlen_k % BLOCK_N - PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL - - # Compute pointers for all the tensors used in this kernel. - q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm - Q_block_ptr = tl.make_block_ptr( - base=Q + q_offset, - shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn - K_block_ptr = tl.make_block_ptr( - base=K + k_offset, - shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), - strides=(stride_kk, stride_kn), - offsets=(0, 0), - block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1), - ) - v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk - V_block_ptr = tl.make_block_ptr( - base=V + v_offset, - shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0), - ) - if BIAS_TYPE != 0: - bias_ptr = tl.make_block_ptr( - base=bias + off_h_q * stride_bh, - shape=(seqlen_q, seqlen_k), - strides=(stride_bm, stride_bn), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0), - ) - else: - bias_ptr = None - if ENABLE_DROPOUT: - batch_philox_offset = ( - philox_offset_base + (off_z * HQ + off_h_q) * seqlen_q * seqlen_k - ) - else: - batch_philox_offset = 0 - # We can ask to return the dropout mask without actually doing any dropout. - # In this case, we return an invalid pointer so indicate the mask is not i - # valid. - # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. 
- if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.make_block_ptr( - base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, - shape=(seqlen_q, seqlen_k), - strides=(seqlen_k, 1), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0), - ) - else: - encoded_softmax_block_ptr = 0 - # initialize pointer to m and l - m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) - l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # scale sm_scale by log_2(e) and use 2^x in the loop as we do not - # have native e^x support in HW. - qk_scale = sm_scale * 1.44269504089 - # Q is loaded once at the beginning and shared by all N blocks. - q = load_fn(Q_block_ptr, True, PADDED_HEAD, "zero") - q = (q * qk_scale).to(Q_block_ptr.type.element_ty) - - # Here we compute how many full and masked blocks we have. - padded_block_k = n_extra_tokens != 0 - is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) - if IS_CAUSAL: - # There are always at least BLOCK_M // BLOCK_N masked blocks. - # Additionally there might be one more due to dissimilar seqlens. - masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) - else: - # Padding on Q does not need to be masked in the FA loop. - masked_blocks = padded_block_k - # if IS_CAUSAL, not is_modulo_mn does not always result in an additional - # block. In this case we might exceed n_blocks so pick the min. - masked_blocks = min(masked_blocks, n_blocks) - n_full_blocks = n_blocks - masked_blocks - block_min = 0 - block_max = n_blocks * BLOCK_N - # Compute for full blocks. Here we set causal to false regardless of its - # value because there is no masking. Similarly we do not need padding. - if n_full_blocks > 0: - block_max = (n_blocks - masked_blocks) * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ - block_min, - block_max, - 0, - 0, - 0, - bias_ptr, - # IS_CAUSAL, .... - False, - BLOCK_M, - BLOCK_DMODEL, - BLOCK_N, - offs_m, - offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, - False, - ENABLE_DROPOUT, - RETURN_ENCODED_SOFTMAX, - PADDED_HEAD, - ) - block_min = block_max - block_max = n_blocks * BLOCK_N - - tl.debug_barrier() - # Remaining blocks, if any, are full / not masked. - if masked_blocks > 0: - offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 - K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) - V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) - if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance( - encoded_softmax_block_ptr, (0, n_full_blocks) - ) - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - block_min, - block_max, - offs_n_causal, - masked_blocks, - n_extra_tokens, - bias_ptr, - IS_CAUSAL, - BLOCK_M, - BLOCK_DMODEL, - BLOCK_N, - offs_m, - offs_n, - # _, MASK_STEPS, ... 
- PRE_LOAD_V, - True, - ENABLE_DROPOUT, - RETURN_ENCODED_SOFTMAX, - PADDED_HEAD, - ) - # epilogue - acc = acc / l_i[:, None] - if ENABLE_DROPOUT: - acc = acc / (1 - dropout_p) - # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, - # then we have one block with a row of all NaNs which come from computing - # softmax over a row of all -infs (-inf - inf = NaN). We check for that here - # and store 0s where there are NaNs as these rows should've been zeroed out. - end_m_idx = (start_m + 1) * BLOCK_M - start_m_idx = start_m * BLOCK_M - causal_start_idx = seqlen_q - seqlen_k - acc = acc.to(Out.type.element_ty) - if IS_CAUSAL: # noqa: SIM102 - if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: - out_mask_boundary = tl.full( - (BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32 - ) - mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) - out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] - z = 0.0 - acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) - # write back LSE - # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m - # If seqlen_q not multiple of BLOCK_M, we need to mask out the last - # few rows. This is only true for the last M block. For others, - # overflow_size will be -ve - # overflow_size = end_m_idx - seqlen_q - # if overflow_size > 0: - # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) - # # This is a > check because mask being 0 blocks the store. - # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) - # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) - # else: - # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) - - # write back O - o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh - O_block_ptr = tl.make_block_ptr( - base=Out + o_offset, - shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - # Need boundary check on this to make sure the padding from the - # Q and KV tensors in both dims are not part of what we store back. - # TODO: Do the boundary check optionally. 
- tl.store(O_block_ptr, acc, boundary_check=(0, 1)) - - -def check_args( - q, - k, - v, - o, - varlen=True, - max_seqlens=None, - cu_seqlens_q=None, - cu_seqlens_k=None, -): - assert q.dim() == k.dim() and q.dim() == v.dim() - if varlen: - assert q.dim() == 3 - total_q, nheads_q, head_size = q.shape - total_k, nheads_k, _ = k.shape - assert cu_seqlens_q is not None - assert cu_seqlens_k is not None - assert len(cu_seqlens_q) == len(cu_seqlens_k) - else: - assert q.dim() == 4 - batch, nheads_q, seqlen_q, head_size = q.shape - _, nheads_k, seqlen_k, _ = k.shape - assert max_seqlens > 0 - assert k.shape == v.shape - assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] - # TODO: Change assert if we support qkl f8 and v f16 - assert q.dtype == k.dtype and q.dtype == v.dtype - # TODO: Fix assert to check head size <=256 once supported - assert head_size <= 128 - assert o.shape == q.shape - assert (nheads_q % nheads_k) == 0 - - -class _attention(torch.autograd.Function): - - @staticmethod - def forward( - ctx, - q, - k, - v, - o, - cu_seqlens_q, - cu_seqlens_k, - max_seqlens_q, - max_seqlens_k, - causal=False, - sm_scale=1.0, - bias=None, - ): - if o is None: - o = torch.empty_like(q, dtype=v.dtype) - - check_args( - q, - k, - v, - o, - varlen=True, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - ) - if True: # varlen - total_q, nheads_q, head_size = q.shape - total_k, nheads_k, _ = k.shape - batch = len(cu_seqlens_q) - 1 - q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) - k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) - v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) - o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) - else: - batch, seqlen_q, nheads_q, head_size = q.shape - _, seqlen_k, nheads_k, _ = k.shape - q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) - k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) - v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) - o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) - - # Get closest power of 2 over or equal to 32. - padded_d_model = 1 << (head_size - 1).bit_length() - padded_d_model = max(padded_d_model, 16) - - def grid(META): - return triton.cdiv(max_seqlens_q, META["BLOCK_M"]), nheads_q, batch - - encoded_softmax = None - - # Seed the RNG so we get reproducible results for testing. 
- philox_seed = 0x1BF52 - philox_offset = 0x1D4B42 - - if bias is not None: - bias_strides = ( - bias.stride(0), - bias.stride(1), - bias.stride(2), - bias.stride(3), - ) - else: - bias_strides = (0, 0, 0, 0) - - attn_fwd[grid]( - q, - k, - v, - bias, - sm_scale, - None, - o, - *q_strides, - *k_strides, - *v_strides, - *o_strides, - *bias_strides, - cu_seqlens_q, - cu_seqlens_k, - dropout_p=0.0, - philox_seed=philox_seed, - philox_offset_base=philox_offset, - encoded_softmax=encoded_softmax, - HQ=nheads_q, - HK=nheads_k, - ACTUAL_BLOCK_DMODEL=head_size, - MAX_SEQLENS_Q=max_seqlens_q, - MAX_SEQLENS_K=max_seqlens_k, - IS_CAUSAL=causal, - VARLEN=True, - BLOCK_DMODEL=padded_d_model, - BIAS_TYPE=0 if bias is None else 1, - ENABLE_DROPOUT=False, - RETURN_ENCODED_SOFTMAX=False, - ) - - ctx.grid = grid - ctx.sm_scale = sm_scale - ctx.BLOCK_DMODEL = head_size - ctx.causal = causal - ctx.dropout_p = 0.0 - ctx.philox_seed = philox_seed - ctx.philox_offset = philox_offset - ctx.encoded_softmax = encoded_softmax - ctx.return_encoded_softmax = False - return o, encoded_softmax - - -triton_attention = _attention.apply diff --git a/backends/gaudi/server/text_generation_server/layers/attention/flashinfer.py b/backends/gaudi/server/text_generation_server/layers/attention/flashinfer.py deleted file mode 100644 index d603c6f5..00000000 --- a/backends/gaudi/server/text_generation_server/layers/attention/flashinfer.py +++ /dev/null @@ -1,251 +0,0 @@ -from typing import Optional -from contextvars import ContextVar -from contextlib import contextmanager - -import flashinfer -import torch - -prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = ContextVar( - "prefill_state" -) - -prefill_with_paged_kv_state: ContextVar[ - flashinfer.BatchPrefillWithPagedKVCacheWrapper -] = ContextVar("prefill_with_paged_kv_state") - -decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar( - "decode_state" -) - -workspace: Optional[torch.Tensor] = None - - -def get_workspace(device): - """Get shared flashinfer workspace.""" - global workspace - if workspace is None: - workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) - return workspace - - -def create_prefill_with_paged_kv_state( - *, - device: torch.device, -): - """Create a prefill state that uses the KV cache.""" - workspace_buffer = get_workspace(device) - return flashinfer.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, kv_layout="NHD", use_cuda_graph=False - ) - - -@contextmanager -def use_prefill_with_paged_kv_state( - *, - state: flashinfer.BatchPrefillWithPagedKVCacheWrapper, - block_tables: torch.Tensor, - cu_seqlens: torch.Tensor, - input_lengths: torch.Tensor, - num_heads: int, - num_kv_heads: int, - head_size: int, - page_size: int, - dtype: torch.dtype, - window_left: int, -): - """ - Context manager to set the active flashinfer prefill state to the given - `state` and parameters. This state will be used by all calls to the - `attention` function while the context manager is active. - """ - - indptr = torch.zeros( - input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32 - ) - # Round up to page size and then calculate the cumulative sum to get - # the indices into the block table. - torch.add(input_lengths, page_size - 1, out=indptr[1:]) - indptr[1:].div_(page_size, rounding_mode="floor") - indptr[1:].cumsum_(-1) - - # Get the lengths of the last page in a block. 
- if page_size == 1: - last_page_len = torch.ones( - input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device - ) - else: - last_page_len = torch.empty( - input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device - ) - torch.sub(input_lengths, 1, out=last_page_len) - last_page_len.remainder_(page_size) - last_page_len += 1 - - token = prefill_with_paged_kv_state.set(state) - try: - state.begin_forward( - qo_indptr=cu_seqlens, - paged_kv_indptr=indptr, - paged_kv_indices=block_tables, - paged_kv_last_page_len=last_page_len, - num_qo_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_size, - q_data_type=dtype, - page_size=page_size, - window_left=window_left, - ) - yield - finally: - state.end_forward() - if token is not None: - prefill_with_paged_kv_state.reset(token) - - -def create_prefill_state( - *, - device: torch.device, -): - """Create a prefill state.""" - workspace_buffer = get_workspace(device) - return flashinfer.BatchPrefillWithRaggedKVCacheWrapper( - workspace_buffer, kv_layout="NHD", use_cuda_graph=False - ) - - -@contextmanager -def use_prefill_state( - *, - state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper, - cu_seqlens: torch.Tensor, - num_heads: int, - num_kv_heads: int, - head_size: int, - dtype: torch.dtype, - window_left: int, -): - """ - Context manager to set the active flashinfer prefill state to the given - `state` and parameters. This state will be used by all calls to the - `attention` function while the context manager is active. - """ - - token = prefill_state.set(state) - try: - state.begin_forward( - qo_indptr=cu_seqlens, - kv_indptr=cu_seqlens, - num_qo_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_size, - q_data_type=dtype, - window_left=window_left, - ) - yield - finally: - state.end_forward() - if token is not None: - prefill_state.reset(token) - - -def create_decode_state( - *, - device: torch.device, - num_heads: int, - num_kv_heads: int, -): - """Create a decode state.""" - workspace_buffer = get_workspace(device) - num_groups = num_heads // num_kv_heads - return flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, - kv_layout="NHD", - use_cuda_graph=False, - # Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60 - use_tensor_cores=num_groups not in [1, 2, 4, 8], - ) - - -def create_decode_state_cuda_graphs( - *, - device: torch.device, - block_tables: torch.Tensor, - block_tables_ptr: torch.Tensor, - last_page_len: torch.Tensor, - num_heads: int, - num_kv_heads: int, -): - """ - Create a decode state for use with CUDA Graphs. `block_tables`, - `block_tables_ptr`, and `last_page_len` are used in CUDA Graphs and are - therefore stored as part of the state. 
- """ - workspace_buffer = get_workspace(device) - num_groups = num_heads // num_kv_heads - return flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, - kv_layout="NHD", - use_cuda_graph=True, - paged_kv_indices_buffer=block_tables, - paged_kv_indptr_buffer=block_tables_ptr, - paged_kv_last_page_len_buffer=last_page_len, - # Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60 - use_tensor_cores=num_groups not in [1, 2, 4, 8], - ) - - -@contextmanager -def use_decode_state( - *, - state: flashinfer.BatchDecodeWithPagedKVCacheWrapper, - input_lengths: torch.Tensor, - block_tables: torch.Tensor, - num_heads: int, - num_kv_heads: int, - head_size: int, - page_size: int, - dtype: torch.dtype, - window_left: int, -): - """ - Context manager to set the active flashinfer decoding state to the given - `state` and parameters. This state will be used by all calls to the - `paged_attention` function while the context manager is active. - """ - indptr = torch.zeros( - input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32 - ) - # Round up to page size and then calculate the cumulative sum to get - # the indices into the block table. - torch.add(input_lengths, page_size - 1, out=indptr[1:]) - indptr[1:].div_(page_size, rounding_mode="floor") - indptr[1:].cumsum_(-1) - - # Get the lengths of the last page in a block. - last_page_len = torch.empty( - input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device - ) - torch.sub(input_lengths, 1, out=last_page_len) - last_page_len.remainder_(page_size) - last_page_len += 1 - - token = decode_state.set(state) - - try: - state.begin_forward( - indptr=indptr, - indices=block_tables, - last_page_len=last_page_len, - num_qo_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_size, - page_size=page_size, - data_type=dtype, - q_data_type=dtype, - window_left=window_left, - ) - yield - finally: - state.end_forward() - if token is not None: - decode_state.reset(token) diff --git a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py new file mode 100644 index 00000000..56143541 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py @@ -0,0 +1,97 @@ +import torch +from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata +from typing import Optional +from text_generation_server.layers.attention.kv_cache import KVCache, KVScales +from vllm_hpu_extension import ops +from vllm_hpu_extension.utils import Matmul +from habana_frameworks.torch.hpex.kernels import FusedSDPA +from vllm_hpu_extension.utils import ModuleFusedSDPA +import os + +SUPPORTS_WINDOWING = False + + +def fetch_from_cache(cache, blocks): + if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true": + return cache[: blocks.size(0)] + else: + return cache.index_select(0, blocks) + + +def attention( + *, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: KVCache, + kv_scales: KVScales, + seqlen: Seqlen, + block_tables: torch.Tensor, + softmax_scale: float, + window_size_left: int = -1, + causal: bool = True, + softcap: Optional[float] = None, +): + fsdpa_op = ModuleFusedSDPA(FusedSDPA) + bs = seqlen.input_lengths.shape[0] + _, head_num, head_size = query.shape + _, kv_head_num, head_size = key.shape + query = query.view(bs, -1, head_num, head_size).transpose(1, 2) + key = key.view(bs, -1, 
kv_head_num, head_size).transpose(1, 2) + value = value.view(bs, -1, kv_head_num, head_size).transpose(1, 2) + attn_output = fsdpa_op( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=causal, + scale=softmax_scale, + softmax_mode="None", + recompute_mode=None, + valid_sequence_lengths=seqlen.input_lengths, + padding_side="left", + ) + attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous() + return attn_output + + +def paged_attention( + query: torch.Tensor, + kv_cache: KVCache, + kv_head_mapping: torch.Tensor, + softmax_scale: float, + block_tables: torch.Tensor, + seqlen: Seqlen, + *, + kv_scales: KVScales, + softcap: Optional[float] = None, + hpu_attention_meta: HPUPagedAttentionMetadata, +): + batch_size, head_num, head_size = query.shape + output = ops.flat_pa( + query=query, + key_cache=kv_cache.key, + value_cache=kv_cache.value, + block_list=hpu_attention_meta.block_list, + block_mapping=hpu_attention_meta.block_mapping, + block_bias=hpu_attention_meta.attn_bias, + block_scales=hpu_attention_meta.block_scales, + block_groups=hpu_attention_meta.block_groups, + scale=softmax_scale, + matmul_qk_op=Matmul(), + matmul_av_op=Matmul(), + batch2block_matmul_op=Matmul(), + block2batch_matmul_op=Matmul(), + keys_fetch_func=fetch_from_cache, + values_fetch_func=fetch_from_cache, + ) + # Reshape the output tensor. + return output.view(batch_size, head_num, head_size) + + +__all__ = [ + "SUPPORTS_WINDOWING", + "attention", + "paged_attention", +] diff --git a/backends/gaudi/server/text_generation_server/layers/attention/ipex.py b/backends/gaudi/server/text_generation_server/layers/attention/ipex.py deleted file mode 100644 index 657c90af..00000000 --- a/backends/gaudi/server/text_generation_server/layers/attention/ipex.py +++ /dev/null @@ -1,82 +0,0 @@ -import intel_extension_for_pytorch as ipex -import torch -from text_generation_server.models.flash_causal_lm import BLOCK_SIZE -from text_generation_server.layers.attention import Seqlen -from typing import Optional - -SUPPORTS_WINDOWING = False -PREFILL_IN_KV_CACHE = False - - -def attention( - q: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - seqlen: Seqlen, - block_tables: torch.Tensor, - softmax_scale, - window_size_left=-1, - causal=True, - softcap: Optional[float] = None, -): - out = torch.empty_like(q) - - # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. 
- ipex.llm.functional.varlen_attention( - q.contiguous() if q.device.type == "xpu" else q, - key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache, - value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache, - out, - seqlen.cu_seqlen_q, - seqlen.cu_seqlen_q, - seqlen.max_q, - seqlen.max_q, - 0.0, - softmax_scale, - False, - causal, - False, - None, - ) - - return out - - -def reshape_and_cache( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slots: torch.Tensor, -): - ipex.llm.modules.PagedAttention.reshape_and_cache( - key, value, key_cache, value_cache, slots - ) - - -def paged_attention( - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - kv_head_mapping: torch.Tensor, - softmax_scale: float, - block_tables: torch.Tensor, - seqlen: Seqlen, - max_s: int, - softcap: Optional[float] = None, -): - out = torch.empty_like(query) - ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( - out, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - seqlen.input_lengths, - BLOCK_SIZE, - max_s, - None, - ) - return out diff --git a/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py new file mode 100644 index 00000000..26c80c70 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py @@ -0,0 +1,141 @@ +from typing import Tuple +from dataclasses import dataclass, field + +import torch + +from text_generation_server.models.globals import BLOCK_SIZE +from text_generation_server.utils.weights import Weights + + +@dataclass +class KVScales: + """ + Key-value scales for FP8 KV cache. + + This data class stores key and value scales both as a GPU tensor and + as a GPU float. This inconvenience is necessary because some functions + (e.g. scaling kernels) take scales as a GPU tensor, whereas others + (e.g. flashinfer) take scales as a CPU scalar. + """ + + key_scale: torch.Tensor + value_scale: torch.Tensor + key_scale_cpu: float = field(init=False) + value_scale_cpu: float = field(init=False) + + def __post_init__(self): + if self.key_scale.numel() != 1 or self.value_scale.numel() != 1: + raise ValueError("Key and value scales must be scalar tensors.") + + self.key_scale_cpu = self.key_scale.item() + self.value_scale_cpu = self.value_scale.item() + + +class KVCache: + """ + Key-value cache for attention layers. 
+ """ + + kv_cache: Tuple[torch.Tensor, torch.Tensor] + + def __init__( + self, + *, + num_blocks: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + ): + """Construct the key-value cache for a layer.""" + ## TODO FP8 kv cache support + + self.kv_cache = ( + torch.zeros( + (num_blocks, BLOCK_SIZE, num_heads, head_size), + dtype=dtype, + device=device, + ), + torch.zeros( + (num_blocks, BLOCK_SIZE, num_heads, head_size), + dtype=dtype, + device=device, + ), + ) + + @property + def dtype(self): + """Get the data type of the cache.""" + return self.kv_cache[0].dtype + + @property + def key(self): + """Get the key cache.""" + + return self.kv_cache[0] + + @property + def value(self): + """Get the value cache.""" + + return self.kv_cache[1] + + def store( + self, + *, + key: torch.Tensor, + value: torch.Tensor, + slots: torch.Tensor, + kv_scales: KVScales, + ): + """Store the key and value at the given slots.""" + ## TODO FP8 kv cache support + + key_cache = self.kv_cache[0] + value_cache = self.kv_cache[1] + + paged_reshape_and_cache( + key, + value, + key_cache, + value_cache, + slots, + kv_scales.key_scale_cpu, + kv_scales.value_scale_cpu, + ) + + +def paged_reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slots: torch.Tensor, + k_scale: float = 1.0, + v_scale: float = 1.0, +): + + from vllm_hpu_extension import cache_ops + + block_idx = slots // BLOCK_SIZE + block_offset = slots % BLOCK_SIZE + cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset) + cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset) + + +def get_kv_scales(weights: Weights, prefix: str) -> KVScales: + """Load KV cache scales.""" + + key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device) + value_scale = key_scale + if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor( + f"{prefix}.v_scale" + ): + key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float() + value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float() + elif weights.has_tensor(f"{prefix}.kv_scale"): + # Fall back to older more coarse-grained scale when available. 
+ key_scale = weights.get_tensor(f"{prefix}.kv_scale").float() + value_scale = key_scale + + return KVScales(key_scale=key_scale, value_scale=value_scale) diff --git a/backends/gaudi/server/text_generation_server/layers/attention/rocm.py b/backends/gaudi/server/text_generation_server/layers/attention/rocm.py deleted file mode 100644 index 646a763d..00000000 --- a/backends/gaudi/server/text_generation_server/layers/attention/rocm.py +++ /dev/null @@ -1,308 +0,0 @@ -import os -from typing import Optional -import torch -from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.models.globals import ATTENTION -from text_generation_server.layers.attention import Seqlen -from text_generation_server.utils.log import log_master -from loguru import logger - -major, minor = torch.cuda.get_device_capability() -is_sm75 = major == 7 and minor == 5 - -_PARTITION_SIZE_V1V2 = 512 -_PARTITION_SIZE_CUSTOM = 256 - -use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"} -ENGINE = "triton" if use_triton else "ck" - -PREFILL_IN_KV_CACHE = False - -use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0" -try: - if use_rocm_custom_paged_attn: - from vllm._custom_C import paged_attention_custom -except ImportError as e: - log_master( - logger.info, - f"Custom Paged Attention not available. Complete error: {e}", - ) - use_rocm_custom_paged_attn = False - -try: - import vllm._custom_ops as ops -except Exception as e: - raise ImportError( - f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}" - ) - - -def reshape_and_cache( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slots: torch.Tensor, -): - if ATTENTION == "flashdecoding": - shape = key_cache.shape - key_cache.view(-1, shape[-2], shape[-1])[slots] = key - value_cache.view(-1, shape[-2], shape[-1])[slots] = value - else: - ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0) - - -def paged_attention( - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - kv_head_mapping: torch.Tensor, - softmax_scale: float, - block_tables: torch.Tensor, - seqlen: Seqlen, - max_s: int, - softcap: Optional[float] = None, -): - # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py - # Copyright 2023 The vLLM team. All rights - # reserved. - # - # Licensed under the Apache License, Version 2.0 (the "License"); - # you may not use this file except in compliance with the License. - # You may obtain a copy of the License at - # - # http://www.apache.org/licenses/LICENSE-2.0 - # - # Unless required by applicable law or agreed to in writing, software - # distributed under the License is distributed on an "AS IS" BASIS, - # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - # See the License for the specific language governing permissions and - # limitations under the License. 
- # - - if softcap is not None: - raise RuntimeError("Paged attention doesn't support softcapping") - - # value_cache => [num_blocks, num_heads, head_size, block_size] - block_size = value_cache.shape[3] - num_seqs, num_heads, head_size = query.shape - - num_kv_heads = key_cache.shape[1] - gqa_ratio = num_heads // num_kv_heads - use_custom = ( - use_rocm_custom_paged_attn - and (query.dtype == torch.half or query.dtype == torch.bfloat16) - and (head_size == 128 or head_size == 64) - and (block_size == 16 or block_size == 32) - and (gqa_ratio >= 1 and gqa_ratio <= 16) - and max_s <= 32768 - ) - - if not use_custom: - _PARTITION_SIZE = _PARTITION_SIZE_V1V2 - else: - _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM - - max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE - input_lengths = seqlen.input_lengths - - out = torch.empty_like(query) - - # NOTE(woosuk): We use a simple heuristic to decide whether to use - # PagedAttention V1 or V2. If the number of partitions is 1, we use - # V1 to avoid the overhead of reduction. Also, if the number of - # sequences or heads is large, we use V1 since there is enough work - # to parallelize. - import vllm._custom_ops as ops - - use_v1 = ( - max_s <= 8192 - and (max_num_partitions == 1 or num_seqs * num_heads > 512) - and not use_custom - ) - if use_v1: - ops.paged_attention_v1( - out, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - 1.0, - ) - else: - # Run PagedAttention V2. - assert _PARTITION_SIZE % block_size == 0 - tmp_output = torch.empty( - size=(num_seqs, num_heads, max_num_partitions, head_size), - dtype=out.dtype, - device=out.device, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, max_num_partitions), - dtype=torch.float32, - device=out.device, - ) - max_logits = torch.empty_like(exp_sums) - - if not use_custom: - ops.paged_attention_v2( - out, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - 1.0, - ) - else: - paged_attention_custom( - out, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - ) - - return out - - -if ENGINE != "triton": - try: - import flash_attn_2_cuda - - log_master( - logger.info, - "ROCm: using Flash Attention 2 Composable Kernel implementation.", - ) - except ImportError as e: - if major >= 8: - architecture_suffix = f"-{SYSTEM}" - raise ImportError( - "Flash Attention V2 is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`" - ) - elif is_sm75: - raise ImportError( - "Flash Attention is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - "or install flash attention with `cd server && make install install-flash-attention`" - ) from e - else: - for idx in range(torch.cuda.device_count()): - name = torch.cuda.get_device_name(idx) - if "MI210" not in name and "MI250" not in name: - raise ImportError( - f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" - ) - raise ImportError( - f"AMD GPU with ROCm capability {major} {minor} is not supported" - ) from e - - -SUPPORTS_WINDOWING = 
False -if ENGINE == "ck": - - def attention( - q, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - seqlen: Seqlen, - block_tables: torch.Tensor, - softmax_scale: float, - window_size_left: int = -1, - causal: bool = True, - softcap: float = 0.0, - ): - if window_size_left <= 0 and window_size_left != -1: - raise ValueError("`window_size_left` must be > 0 or -1") - - out = torch.empty_like(q) - - # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. - return flash_attn_2_cuda.varlen_fwd( - q, - key_cache, - value_cache, - out, - seqlen.cu_seqlen_q, - seqlen.cu_seqlen_q, - None, - None, - None, - None, - seqlen.max_q, - seqlen.max_k, - 0.0, - softmax_scale, - False, - causal, - window_size_left, - 0, - softcap, - False, - None, - )[0] - -elif ENGINE == "triton": - from .flash_attn_triton import triton_attention - - def attention( - q, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - seqlen: Seqlen, - block_tables: torch.Tensor, - softmax_scale: float, - window_size_left: int = -1, - causal: bool = True, - softcap: Optional[float] = None, - ): - if softcap is not None: - raise NotImplementedError("softcap is only available with CK flash attn") - - out = torch.empty_like(q) - - # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. - output, _ = triton_attention( - q, - key_cache, - value_cache, - out, - seqlen.cu_seqlen_q, - seqlen.cu_seqlen_q, - seqlen.max_q, - seqlen.max_k, - causal, - softmax_scale, - ) - return output - -else: - raise RuntimeError(f"Unknown attention engine {ENGINE}") diff --git a/backends/gaudi/server/text_generation_server/layers/awq/quantize/__init__.py b/backends/gaudi/server/text_generation_server/layers/awq/quantize/__init__.py new file mode 100644 index 00000000..856d7c28 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/awq/quantize/__init__.py @@ -0,0 +1,3 @@ +from .hpu import WQLinear + +__all__ = ["WQLinear"] diff --git a/backends/gaudi/server/text_generation_server/layers/awq/quantize/qmodule.py b/backends/gaudi/server/text_generation_server/layers/awq/quantize/hpu.py similarity index 100% rename from backends/gaudi/server/text_generation_server/layers/awq/quantize/qmodule.py rename to backends/gaudi/server/text_generation_server/layers/awq/quantize/hpu.py diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py new file mode 100644 index 00000000..507af706 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py @@ -0,0 +1,3 @@ +from .loader import CompressedTensorsLoader + +__all__ = ["CompressedTensorsLoader"] diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py new file mode 100644 index 00000000..17d0224e --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py @@ -0,0 +1,196 @@ +from typing import Any, Dict, List, Union + +from compressed_tensors import QuantizationConfig, QuantizationStatus +from compressed_tensors.config import CompressionFormat +from compressed_tensors.quantization import ( + QuantizationScheme, + QuantizationType, + find_name_or_class_matches, +) +from loguru import logger +from pydantic import ValidationError +from torch import nn + 
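+# Scheme-specific weight loaders; CompressedTensorsLoader dispatches to one of these per target.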
+from text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader +from text_generation_server.layers.compressed_tensors.w8a8_int import W8A8IntLoader +from text_generation_server.layers.compressed_tensors.wna16_int_24 import ( + WNA16Int24Loader, +) +from text_generation_server.layers.compressed_tensors.wna16_int import WNA16IntLoader +from text_generation_server.utils.log import log_once +from text_generation_server.utils.weights import ( + DefaultWeightsLoader, + UnquantizedWeight, + Weights, + WeightsLoader, +) + +# compressed-tensors can match modules as quantization targets. However, +# they need to be objects rather than classes or class names. Since we +# need to match `Linear` targets, make an instance that can be re-used. +_EMPTY_LINEAR: nn.Module = nn.Linear(0, 0) + + +class CompressedTensorsLoader(WeightsLoader): + """Loader for checkpoints stored in the compressed-tensors format.""" + + def __init__(self, config: Dict[str, Any]): + quantization_config_raw = config.get("quantization_config") + if quantization_config_raw is None: + # `compression_config` was renamed to `quantization_config`; support + # retained for backward compatibility. + quantization_config_raw = config.get("compression_config") + if quantization_config_raw is None: + raise ValueError( + "Checkpoint does not have compressed-tensors configuration" + ) + + try: + quantization_config = QuantizationConfig.model_validate( + quantization_config_raw + ) + except ValidationError as e: + raise ValueError("Cannot parse compressed-tensors configuration") from e + + if quantization_config.quantization_status not in ( + QuantizationStatus.COMPRESSED, + QuantizationStatus.FROZEN, + ): + raise ValueError( + f"Model quantization was not finished, status was: {quantization_config.quantization_status}" + ) + + self.ignore = ( + quantization_config.ignore if quantization_config.ignore is not None else [] + ) + self.loaders = self._get_target_loaders(quantization_config) + + for target, loader in self.loaders.items(): + log_once( + logger.info, + f"Using {loader} for compressed-tensors target '{target}'", + ) + + def get_weights(self, weights: Weights, prefix: str): + loader = self._lookup_loader(prefix) + return loader.get_weights(weights, prefix) + + def get_weights_col_packed( + self, + weights: "Weights", + prefix: str, + block_sizes: Union[int, List[int]], + ): + loader = self._lookup_loader(prefix) + return loader.get_weights_col_packed(weights, prefix, block_sizes) + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + loader = self._lookup_loader(prefixes[0]) + return loader.get_multi_weights_col(weights, prefixes, dim) + + def get_weights_row(self, weights: Weights, prefix: str): + loader = self._lookup_loader(prefix) + return loader.get_weights_row(weights, prefix) + + def _get_target_loaders( + self, quantization_config: QuantizationConfig + ) -> Dict[str, WeightsLoader]: + """ + A compressed-tensors checkpoint can use different quantizations + for different targets. This method returns a dictionary with a + loader per target. + """ + + loaders: Dict[str, WeightsLoader] = {} + + format = quantization_config.format + + for group_name, group in quantization_config.config_groups.items(): + # The group configuration can be a string, but does that ever + # happen in a serialized quantization config? 
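+ # Guard against that case here: only QuantizationScheme objects are handled below.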
+ assert isinstance(group, QuantizationScheme) + + loader = self._create_loader_for_group(format, group_name, group) + + # A quantized parameter group can have multiple targets, add the + # loader for all the targets. + for target in group.targets: + if target in loaders: + raise ValueError( + f"Target '{target} has multiple configured loaders'" + ) + loaders[target] = loader + + return loaders + + def _create_loader_for_group( + self, format: str, group_name: str, group: QuantizationScheme + ) -> WeightsLoader: + """ + Find and create a loader for the group with the given quantization + scheme. + """ + # NOTE: we ignore group.output_activations because we don't support + # output quantization yet. + + input_activations = group.input_activations + weights = group.weights + if ( + format + in { + CompressionFormat.float_quantized.value, + CompressionFormat.naive_quantized.value, + } + and weights is not None + and weights.type == QuantizationType.FLOAT + and weights.num_bits == 8 + ): + # FP W8A8 or W8A16. + return W8ANFpLoader(input_activations=input_activations, weights=weights) + elif ( + format == CompressionFormat.pack_quantized.value + and weights is not None + and weights.type == QuantizationType.INT + and weights.num_bits in (4, 8) + ): + # INT W4A16 or W8A16 (GPTQ/AWQ-like). + return WNA16IntLoader(weights) + elif ( + format == CompressionFormat.marlin_24.value + and weights is not None + and weights.type == QuantizationType.INT + and weights.num_bits in (4, 8) + ): + return WNA16Int24Loader(weights) + elif ( + format + in { + CompressionFormat.int_quantized.value, + CompressionFormat.naive_quantized.value, + } + and weights is not None + and weights.type == QuantizationType.INT + and weights.num_bits == 8 + ): + return W8A8IntLoader(input_args=input_activations, weight_args=weights) + else: + raise ValueError( + f"Group '{group_name}' has unsupported compressed-tensors configurtion" + ) + + def _lookup_loader(self, prefix: str) -> WeightsLoader: + """ + Look up the loader to use for a given parameter name (prefix). + """ + + if len(find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.ignore)) > 0: + return DefaultWeightsLoader(UnquantizedWeight) + + # We currently only handle linear layers, so unconditionally pass + # a `Linear` instance. + targets = find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.loaders.keys()) + if len(targets) == 0: + raise ValueError( + f"Cannot find compressed-tensors target for prefix: {prefix}" + ) + return self.loaders[targets[0]] diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8a8_int.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8a8_int.py new file mode 100644 index 00000000..fff0c765 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8a8_int.py @@ -0,0 +1,239 @@ +from typing import List, Optional, Union, TypeVar +from dataclasses import dataclass + +from loguru import logger +import torch +from compressed_tensors.quantization import QuantizationArgs, QuantizationType + +from text_generation_server.layers.fp8 import _load_scalar_or_matrix_scale +from text_generation_server.utils.log import log_once +from text_generation_server.utils.weights import Weight, Weights, WeightsLoader + + +quantization = None + + +class W8A8IntLoader(WeightsLoader): + """ + Loader for w8a8 integer compressed-tensors parameters. 
+ """ + + def __init__( + self, + *, + input_args: Optional[QuantizationArgs], + weight_args: QuantizationArgs, + ): + if weight_args.type != QuantizationType.INT and weight_args.num_bits != 8: + raise ValueError( + f"{type(self).__name__} only supports w8a8 int checkpoints" + ) + + if not weight_args.symmetric: + raise ValueError("Checkpoints with asymmetric weights are not supported") + + self.load_weight_scale = not weight_args.dynamic + + if input_args is not None: + self.input_symmetric = input_args.symmetric + + if not input_args.dynamic: + log_once( + logger.warning, + "Forcing dynamic input quantization for compressed_tensors w8a8 int checkpoint (for better accuracy).", + ) + else: + self.input_symmetric = True + + def __str__(self) -> str: + def scale_to_str(scale): + return "static" if scale else "dynamic" + + def symmetric_to_str(symmetric): + return "symmetric" if symmetric else "asymmetric" + + return f"{self.__class__.__name__} (w8a8 int, input: dynamic/{symmetric_to_str(self.input_symmetric)}, weight: {scale_to_str(self.load_weight_scale)}/symmetric))" + + def get_weights(self, weights: "Weights", prefix: str): + w = weights.get_tensor(f"{prefix}.weight", to_dtype=False) + + weight_scale = None + if self.load_weight_scale: + weight_scale = weights.get_tensor( + f"{prefix}.weight_scale", to_dtype=False + ).reshape(-1) + + return Int8Weight( + input_symmetric=self.input_symmetric, + weight=w, + weight_scale=weight_scale, + ) + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + w = weights.get_packed_sharded( + f"{prefix}.weight", dim=0, block_sizes=block_sizes, to_dtype=False + ) + + weight_scale = None + if self.load_weight_scale: + weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + if weight_scale.numel() > 1: + weight_scale = weights.get_packed_sharded( + f"{prefix}.weight_scale", + dim=0, + block_sizes=block_sizes, + to_dtype=False, + ) + weight_scale = weight_scale.reshape(-1) + + return Int8Weight( + input_symmetric=self.input_symmetric, + weight=w, + weight_scale=weight_scale, + ) + + def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): + w = [ + weights.get_sharded(f"{p}.weight", dim=0, to_dtype=False) for p in prefixes + ] + shapes = [x.shape for x in w] + + w = torch.cat(w, dim=dim) + + weight_scale = None + if self.load_weight_scale: + weight_scale = [ + _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape) + for p, shape in zip(prefixes, shapes) + ] + weight_scale = torch.cat(weight_scale, dim=0).reshape(-1, 1) + + return Int8Weight( + input_symmetric=self.input_symmetric, + weight=w, + weight_scale=weight_scale, + ) + + def get_weights_row(self, weights: "Weights", prefix: str): + w = weights.get_sharded(f"{prefix}.weight", dim=1, to_dtype=False) + + weight_scale = None + if self.load_weight_scale: + weight_scale = weights.get_tensor( + f"{prefix}.weight_scale", to_dtype=False + ).reshape(-1) + + return Int8Weight( + input_symmetric=self.input_symmetric, + weight=w, + weight_scale=weight_scale, + ) + + +OtherT = TypeVar("OtherT") + + +def _get_tensor_or_else( + weights: Weights, prefix: str, other: OtherT +) -> Union[torch.Tensor, OtherT]: + # Even if a checkpoint uses e.g. 
zero-points, they can be elided: + # https://github.com/neuralmagic/compressed-tensors/blob/db6ccb25b265e8370813ecab5e95714a6728b5a6/src/compressed_tensors/compressors/quantized_compressors/base.py#L105 + if weights.has_tensor(prefix): + return weights.get_tensor(prefix, to_dtype=False) + else: + return other + + +@dataclass +class Int8Weight(Weight): + input_symmetric: bool + weight: torch.Tensor + weight_scale: Optional[torch.Tensor] + + def get_linear(self, bias: torch.Tensor): + if self.weight_scale is None: + assert quantization is not None + qweight, weight_scale, _ = quantization.scaled_int8_quant(self.weight) + return W8A8IntLinear( + bias=bias, + input_symmetric=self.input_symmetric, + weight=qweight, + weight_scale=weight_scale, + ) + else: + return W8A8IntLinear( + bias=bias, + input_symmetric=self.input_symmetric, + weight=self.weight, + weight_scale=self.weight_scale, + ) + + +class W8A8IntLinear(torch.nn.Module): + def __init__( + self, + *, + bias: Optional[torch.Tensor], + input_symmetric: bool, + weight: torch.Tensor, + weight_scale: torch.Tensor, + ): + super().__init__() + + weight_scale = weight_scale.to(torch.float32) + + self.bias = bias + self.input_symmetric = input_symmetric + # cutlass kernels require transposed weights. + self.weight = weight.t() + self.weight_scale = weight_scale + + if input_symmetric: + self.zero_point_adj = None + else: + # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md#scaledepilogueazp + self.zero_point_adj = self.weight.sum( + dim=0, keepdim=True, dtype=torch.int32 + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + assert quantization is not None + + qinput, input_scale, input_zero_point = quantization.scaled_int8_quant( + input=input, + scale=None, + azp=None, + symmetric=self.input_symmetric, + ) + + if self.input_symmetric: + return quantization.cutlass_scaled_mm( + a=qinput, + b=self.weight, + scale_a=input_scale, + scale_b=self.weight_scale, + out_dtype=input.dtype, + bias=self.bias, + ) + else: + assert ( + self.zero_point_adj is not None + and input_scale is not None + and (self.input_symmetric or input_zero_point is not None) + ) + + return quantization.cutlass_scaled_mm_azp( + a=qinput, + b=self.weight, + scale_a=input_scale, + scale_b=self.weight_scale, + out_dtype=input.dtype, + azp_adj=self.zero_point_adj, + azp=input_zero_point, + bias=self.bias, + ) diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8an_fp.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8an_fp.py new file mode 100644 index 00000000..ed63806e --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8an_fp.py @@ -0,0 +1,168 @@ +from typing import List, Optional, Union + +import torch +from compressed_tensors.quantization import QuantizationArgs, QuantizationType + +from text_generation_server.layers.fp8 import ( + Fp8Weight, + _load_scalar_or_matrix_scale, +) +from text_generation_server.utils.weights import Weights, WeightsLoader + + +class W8ANFpLoader(WeightsLoader): + """ + Loader for W8A8/W8A16 FP compressed-tensors parameters. + """ + + def __init__( + self, + *, + input_activations: Optional[QuantizationArgs], + weights: QuantizationArgs, + ): + assert weights.type == QuantizationType.FLOAT and weights.num_bits == 8 + + # We ignore the `strategy` option which sets the scales to be + # per-tensor, per-channel or per-token. 
What scales are supported + # is dependent on the kernels used (e.g. cutlass can do tokenwise, + # Torch cannot, and FP8-Marlin does not quantize inputs at all). + # So, instead we try to use the best-possible configuration. + + self.load_weight_scale = not weights.dynamic + self.load_input_scale = ( + input_activations is not None and not input_activations.dynamic + ) + self.force_w8a16 = ( + input_activations is not None and input_activations.num_bits == 16 + ) + + def __str__(self) -> str: + def scale_to_str(scale): + return "static" if scale else "dynamic" + + quantization_type = f"W8A{16 if self.force_w8a16 else 8}" + + return f"{self.__class__.__name__} ({quantization_type}, weight: {scale_to_str(self.load_weight_scale)}, input: {scale_to_str(self.load_input_scale)})" + + def get_weights(self, weights: "Weights", prefix: str): + w = weights.get_tensor(f"{prefix}.weight") + + weight_scale = None + if self.load_weight_scale: + weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + + input_scale = None + if self.load_input_scale: + input_scale = weights.get_tensor( + f"{prefix}.input_scale", to_dtype=False + ).reshape(-1) + + return Fp8Weight( + weight=w, + weight_scale=weight_scale, + input_scale=input_scale, + dtype=weights.dtype, + force_w8a16=self.force_w8a16, + ) + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + w = weights.get_packed_sharded( + f"{prefix}.weight", dim=0, block_sizes=block_sizes + ) + + weight_scale = None + if self.load_weight_scale: + weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + if weight_scale.numel() > 1: + weight_scale = weights.get_packed_sharded( + f"{prefix}.weight_scale", + dim=0, + block_sizes=block_sizes, + to_dtype=False, + ) + + input_scale = None + if self.load_input_scale: + input_scale = weights.get_tensor(f"{prefix}.input_scale", to_dtype=False) + if input_scale.numel() > 1: + input_scale = weights.get_packed_sharded( + f"{prefix}.input_scale", + dim=0, + block_sizes=block_sizes, + to_dtype=False, + ) + input_scale = input_scale.reshape(-1).max() + + return Fp8Weight( + weight=w, + weight_scale=weight_scale, + input_scale=input_scale, + dtype=weights.dtype, + force_w8a16=self.force_w8a16, + ) + + def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): + # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet + w = [ + weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes + ] + shapes = [x.shape for x in w] + + # Concat then send to the device + w = torch.cat(w, dim=dim).to(weights.device) + + weight_scale = None + if self.load_weight_scale: + weight_scale = [ + _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape) + for p, shape in zip(prefixes, shapes) + ] + weight_scale = torch.cat(weight_scale, dim=0).reshape(-1) + + input_scale = None + if self.load_input_scale: + input_scale = [ + _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape) + for p, shape in zip(prefixes, shapes) + if weights.has_tensor(f"{p}.input_scale") + ] + assert len(input_scale) == 0 or len(input_scale) == len(prefixes) + input_scale = ( + torch.cat(input_scale, dim=0).reshape(-1).max() + if len(input_scale) != 0 + else None + ) + + return Fp8Weight( + weight=w, + weight_scale=weight_scale, + input_scale=input_scale, + dtype=weights.dtype, + force_w8a16=self.force_w8a16, + ) + + def get_weights_row(self, weights: "Weights", prefix: str): + w = 
weights.get_sharded(f"{prefix}.weight", dim=1) + weight_scale = None + if self.load_weight_scale: + weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + + input_scale = None + if self.load_input_scale: + input_scale = weights.get_tensor( + f"{prefix}.input_scale", to_dtype=False + ).reshape(-1) + + return Fp8Weight( + weight=w, + weight_scale=weight_scale, + input_scale=input_scale, + dtype=weights.dtype, + force_w8a16=self.force_w8a16, + ) diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/wna16_int.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/wna16_int.py new file mode 100644 index 00000000..bb69c6b5 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/wna16_int.py @@ -0,0 +1,188 @@ +from typing import List, Union + +import torch +from compressed_tensors.quantization import ActivationOrdering, QuantizationArgs +from loguru import logger + +from text_generation_server.layers.marlin.gptq import repack_gptq_for_marlin +from text_generation_server.utils.log import log_once +from text_generation_server.utils.weights import Weights, WeightsLoader + + +class WNA16IntLoader(WeightsLoader): + """ + Loader for W4A16/W8A16 INT compressed-tensors parameters. + """ + + def __init__(self, weights: QuantizationArgs): + self.weights = weights + self.desc_act = self.weights.actorder == ActivationOrdering.GROUP + self.groupsize = ( + -1 if self.weights.group_size is None else self.weights.group_size + ) + + def __str__(self) -> str: + quantization_type = f"W{self.weights.num_bits}A16" + + return f"{self.__class__.__name__} ({quantization_type})" + + def get_weights(self, weights: Weights, prefix: str): + log_once(logger.info, "Using GPTQ-Marlin kernels") + try: + weight_packed = weights.get_tensor(f"{prefix}.weight_packed").t() + except RuntimeError: + raise RuntimeError( + f"Cannot load w{self.weights.num_bits}a16 weight, make sure the model is already quantized" + ) + + zero_point = None + if not self.weights.symmetric: + zero_point = weights.get_tensor(f"{prefix}.weight_zero_point").t() + + g_idx = None + if self.desc_act: + g_idx = weights.get_tensor(f"{prefix}.weight_g_idx") + + scales = weights.get_tensor(f"{prefix}.weight.scales").t() + + return repack_gptq_for_marlin( + qweight=weight_packed.contiguous(), + scales=scales, + qzeros=zero_point, + g_idx=g_idx, + bits=self.weights.num_bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + quant_method="compressed-tensors", + sym=self.weights.symmetric, + sharded_infeatures=False, + ) + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + try: + weight_packed = weights.get_packed_sharded( + f"{prefix}.weight_packed", dim=0, block_sizes=block_sizes + ).t() + except RuntimeError: + raise RuntimeError( + f"Cannot load w{self.weights.num_bits}a16 weight, make sure the model is already quantized" + ) + scales = weights.get_packed_sharded( + f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes + ).t() + scales = scales.to(dtype=weights.dtype) + + zero_point = None + if not self.weights.symmetric: + zero_point = weights.get_packed_sharded( + f"{prefix}.qzeros", dim=0, block_sizes=block_sizes + ).t() + + g_idx = None + if self.desc_act: + g_idx = weights.get_tensor(f"{prefix}.g_idx") + + return repack_gptq_for_marlin( + qweight=weight_packed.contiguous(), + scales=scales, + qzeros=zero_point, + g_idx=g_idx, + bits=self.weights.num_bits, + 
desc_act=self.desc_act, + groupsize=self.groupsize, + quant_method="compressed-tensors", + sym=self.weights.symmetric, + sharded_infeatures=False, + ) + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + try: + weight_packed = torch.cat( + [ + weights.get_sharded(f"{p}.weight_packed", dim=0).t() + for p in prefixes + ], + dim=1, + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load w{self.weights.num_bits}a16 weight, make sure the model is already quantized" + ) + + scales = torch.cat( + [weights.get_sharded(f"{p}.weight_scale", dim=0).t() for p in prefixes], + dim=1, + ) + + zero_point = None + if not self.weights.symmetric: + zero_point = torch.cat( + [weights.get_sharded(f"{p}.qzeros", dim=0).t() for p in prefixes], dim=1 + ).t() + + g_idx = None + if self.desc_act: + w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] + + return repack_gptq_for_marlin( + qweight=weight_packed.contiguous(), + scales=scales, + qzeros=zero_point, + g_idx=g_idx, + bits=self.weights.num_bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + quant_method="compressed-tensors", + sym=self.weights.symmetric, + sharded_infeatures=False, + ) + + def get_weights_row(self, weights: Weights, prefix: str): + log_once(logger.info, "Using GPTQ-Marlin kernels") + try: + weight_packed = weights.get_sharded(f"{prefix}.weight_packed", dim=1).t() + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight, make sure the model is already quantized." + ) + + zero_point = None + if not self.weights.symmetric: + if self.desc_act or self.groupsize == -1: + zero_point = weights.get_tensor(f"{prefix}.weight_zero_point").t() + else: + zero_point = weights.get_sharded( + f"{prefix}.weight_zero_point", dim=1 + ).t() + + g_idx = None + if self.desc_act: + g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0) + + if self.desc_act or self.groupsize == -1: + scales = weights.get_tensor(f"{prefix}.weight_scale").t() + else: + scales = weights.get_sharded(f"{prefix}.weight_scale", dim=1).t() + + sharded_in_features = weights.process_group.size() > 1 + + return repack_gptq_for_marlin( + qweight=weight_packed.contiguous(), + scales=scales, + qzeros=zero_point, + g_idx=g_idx, + bits=self.weights.num_bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + quant_method="compressed-tensors", + sym=self.weights.symmetric, + sharded_infeatures=sharded_in_features, + ) diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/wna16_int_24.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/wna16_int_24.py new file mode 100644 index 00000000..27b8614c --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/wna16_int_24.py @@ -0,0 +1,101 @@ +from typing import List, Union + +import torch + + +from compressed_tensors.quantization import QuantizationArgs, QuantizationType +from text_generation_server.layers.marlin.marlin import GPTQMarlin24Weight +from text_generation_server.utils.weights import Weights, WeightsLoader + + +class WNA16Int24Loader(WeightsLoader): + """ + Loader for W4A16/W8A16 INT 2:4 sparsity compressed-tensors checkpoints. 
+ """ + + def __init__(self, weight_args: QuantizationArgs): + super().__init__() + + if weight_args.type != QuantizationType.INT: + raise ValueError( + f"{type(self).__name__} only supports wNa8 int checkpoints" + ) + + if weight_args.strategy == "group" and weight_args.group_size is None: + raise ValueError("`group_size` must be set when `actorder` is `group`") + + self.bits = weight_args.num_bits + self.group_size = weight_args.group_size + + def __str__(self) -> str: + quantization_type = f"W{self.bits}A16 2:4 sparsity" + + return f"{self.__class__.__name__} ({quantization_type})" + + def get_weights(self, weights: Weights, prefix: str): + """ + Get weights at the given prefix and apply without tensor paralllism. + """ + weight_packed = weights.get_tensor(f"{prefix}.weight_packed") + meta = weights.get_tensor(f"{prefix}.meta") + scale_packed = weights.get_tensor(f"{prefix}.scale_packed") + return GPTQMarlin24Weight( + weight_packed=weight_packed, + meta=meta, + scale_packed=scale_packed, + bits=self.bits, + ) + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + weight_packed = weights.get_packed_sharded( + f"{prefix}.weight_packed", dim=1, block_sizes=block_sizes + ) + meta = weights.get_packed_sharded( + f"{prefix}.meta", dim=1, block_sizes=block_sizes + ) + scale_packed = weights.get_packed_sharded( + f"{prefix}.scale_packed", dim=1, block_sizes=block_sizes + ) + return GPTQMarlin24Weight( + weight_packed=weight_packed, + meta=meta, + scale_packed=scale_packed, + bits=self.bits, + ) + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + weight_packed = torch.cat( + [weights.get_sharded(f"{p}.weight_packed", dim=1) for p in prefixes], dim=1 + ) + meta = torch.cat( + [weights.get_sharded(f"{p}.meta", dim=1) for p in prefixes], dim=1 + ) + scale_packed = torch.cat( + [weights.get_sharded(f"{p}.scale_packed", dim=1) for p in prefixes], dim=1 + ) + return GPTQMarlin24Weight( + weight_packed=weight_packed, + meta=meta, + scale_packed=scale_packed, + bits=self.bits, + ) + + def get_weights_row(self, weights: Weights, prefix: str): + weight_packed = weights.get_sharded(f"{prefix}.weight_packed", dim=0) + meta = weights.get_sharded(f"{prefix}.meta", dim=0) + if self.group_size is None: + scale_packed = weights.get_tensor(f"{prefix}.scale_packed") + else: + scale_packed = weights.get_sharded(f"{prefix}.scale_packed", dim=0) + + return GPTQMarlin24Weight( + weight_packed=weight_packed, + meta=meta, + scale_packed=scale_packed, + bits=self.bits, + ) diff --git a/backends/gaudi/server/text_generation_server/layers/eetq.py b/backends/gaudi/server/text_generation_server/layers/eetq.py deleted file mode 100644 index b1e5235a..00000000 --- a/backends/gaudi/server/text_generation_server/layers/eetq.py +++ /dev/null @@ -1,43 +0,0 @@ -from dataclasses import dataclass - -import torch -from EETQ import quant_weights, w8_a16_gemm -from text_generation_server.utils.weights import UnquantizedWeight - - -@dataclass -class EETQWeight(UnquantizedWeight): - weight: torch.Tensor - - def get_linear(self, bias: torch.Tensor): - try: - from text_generation_server.layers.eetq import EETQLinear - - return EETQLinear(self.weight, bias) - except ImportError: - raise ImportError( - "Please install EETQ from https://github.com/NetEase-FuXi/EETQ" - ) - - -class EETQLinear(torch.nn.Module): - def __init__( - self, - weight, - bias, - ) -> None: - super().__init__() - device = weight.device - if weight.dtype != torch.float16: 
- weight = weight.to(dtype=torch.float16) - weight = torch.t(weight).contiguous().cpu() - weight, scale = quant_weights(weight, torch.int8, False) - - self.weight = weight.cuda(device) - self.scale = scale.cuda(device) - self.bias = bias.cuda(device) if bias is not None else None - - def forward(self, input: torch.Tensor) -> torch.Tensor: - output = w8_a16_gemm(input, self.weight, self.scale) - output = output + self.bias if self.bias is not None else output - return output diff --git a/backends/gaudi/server/text_generation_server/layers/fp8.py b/backends/gaudi/server/text_generation_server/layers/fp8.py index 61dd5115..e37c4983 100644 --- a/backends/gaudi/server/text_generation_server/layers/fp8.py +++ b/backends/gaudi/server/text_generation_server/layers/fp8.py @@ -1,100 +1,163 @@ -import torch - from dataclasses import dataclass -from typing import Optional, Union, List +from typing import Optional, Tuple, Type, Union, List + +import torch from loguru import logger -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.weights import ( Weight, WeightsLoader, UnquantizedWeight, Weights, ) -from text_generation_server.utils.log import log_master, log_once -import importlib.util +from text_generation_server.utils.log import log_once + +quantization = None +w8a8_block_fp8_matmul = None +per_token_group_quant_fp8 = None + +quant_dtype: torch.dtype = torch.float8_e4m3fn -FBGEMM_MM_AVAILABLE = False -FBGEMM_DYN_AVAILABLE = False +CUTLASS_FP8_AVAILABLE = False -def is_fbgemm_gpu_available(): - try: - return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None - except ModuleNotFoundError: - return False - - -if is_fbgemm_gpu_available(): - if SYSTEM == "cuda": - major, _ = torch.cuda.get_device_capability() - FBGEMM_MM_AVAILABLE = major == 9 - FBGEMM_DYN_AVAILABLE = major >= 8 -else: - log_master(logger.warning, "FBGEMM fp8 kernels are not installed.") - - -def get_fp8_linear() -> torch.nn.Module: +def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]: """ Return an FP8 linear `Module` that is compatible with the current system. """ - - if SYSTEM == "cuda": - major, _ = torch.cuda.get_device_capability() - if major == 8: - from text_generation_server.layers.marlin import GPTQMarlinFP8Linear - - return GPTQMarlinFP8Linear - # On other systems let Torch decide if the hardware supports FP8. return Fp8Linear -def fp8_quantize( - weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False -): - if FBGEMM_DYN_AVAILABLE and not scalar: - qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row( - weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype - ) - return qweight, scale +def normalize_e4m3fn_to_native_float8( + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + return weight, weight_scale, input_scale + + +def per_tensor_dequantize( + tensor: torch.Tensor, + inv_scale: Union[float, torch.Tensor], + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + fake_qweight = tensor.to(dtype) + dq_weight = fake_qweight * inv_scale + return dq_weight + + +def requantize_with_max_scale( + weight: torch.Tensor, + weight_scale: torch.Tensor, + logical_widths: int, + dtype: torch.dtype, +) -> Tuple[torch.Tensor, torch.Tensor]: + # Max scale to be used for requanitzation. 
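+ # Dequantize each logical shard with its own scale, then requantize it with this shared maximum.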
+ max_w_scale = weight_scale.max().float() + + start = 0 + for idx, logical_width in enumerate(logical_widths): + end = start + logical_width + weight_dq = per_tensor_dequantize( + weight[start:end, :], weight_scale[idx], dtype + ) + weight[start:end, :], max_w_scale_normalized = fp8_quantize( + weight_dq, max_w_scale + ) + start = end + + return weight, max_w_scale_normalized + + +def fp8_quantize( + weight: torch.Tensor, + scale: Optional[torch.Tensor] = None, + scale_upper_bound: Optional[torch.Tensor] = None, + qdtype: torch.dtype = torch.float8_e4m3fn, + scalar: bool = False, +): + """ + This function returns a reciprocal of the scale, so that a tensor can be unscaled + by multiplying it with the returned scale. If a scale is given through the `scale` + argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can + be used without modification). + """ + if quantization is not None: + shape = weight.shape + qweight, scale = quantization.scaled_fp8_quant( + weight.reshape(-1, shape[-1]), + scale=scale, + scale_ub=scale_upper_bound, + # TODO: don't do this when we have to use the Torch kernel. + use_per_token_if_dynamic=not scalar, + ) + + return qweight.reshape(shape), scale - # weight, scale = quant_weights(weight, torch.int8, False) finfo = torch.finfo(qdtype) - # Calculate the scale as dtype max divided by absmax - scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound) - # scale and clamp the tensor to bring it to - # the representative range of float8 data type - # (as default cast is unsaturated) - qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max) + + if scale is None: + # Calculate the scale as dtype max divided by absmax + scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound) + # scale and clamp the tensor to bring it to + # the representative range of float8 data type + # (as default cast is unsaturated) + qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max) + scale = scale.float().reciprocal() + else: + # Use reciprocal to avoid more expensive division. 
+ qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max) + # Return both float8 data and the inverse scale (as float), # as both required as inputs to torch._scaled_mm qweight = qweight.to(qdtype) - scale = scale.float().reciprocal() + return qweight, scale class HybridFP8UnquantLoader(WeightsLoader): """Weight loader that loads FP8 and unquantized Torch tensors.""" - def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool): + def __init__( + self, + activation_scale_ub: Optional[float], + to_fp8: bool, + weight_block_size: Optional[List[int]] = None, + ): self.activation_scale_ub = activation_scale_ub self.to_fp8 = to_fp8 + self.weight_block_size = weight_block_size def get_weights(self, weights: "Weights", prefix: str): w = weights.get_tensor(f"{prefix}.weight") if w.dtype == torch.float8_e4m3fn: + if self.weight_block_size is not None: + scale = weights.get_tensor(f"{prefix}.weight_scale_inv") + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + weight_block_size=self.weight_block_size, + ) # FP8 branch - scale = ( - weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) - .reshape(-1) - .expand(w.shape[0]) - ) + scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + + input_scale = None + if weights.has_tensor(f"{prefix}.input_scale"): + input_scale = ( + weights.get_tensor(f"{prefix}.input_scale", to_dtype=False) + .reshape(-1) + .max() + ) + return Fp8Weight( weight=w, weight_scale=scale, + input_scale=input_scale, activation_scale_ub=self.activation_scale_ub, dtype=weights.dtype, ) @@ -116,6 +179,7 @@ class HybridFP8UnquantLoader(WeightsLoader): if w.dtype == torch.float8_e4m3fn: # FP8 branch scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + if scale.numel() > 1: scale = weights.get_packed_sharded( f"{prefix}.weight_scale", @@ -123,11 +187,25 @@ class HybridFP8UnquantLoader(WeightsLoader): block_sizes=block_sizes, to_dtype=False, ) - scale = scale.reshape(-1).expand(w.shape[0]) + + input_scale = None + if weights.has_tensor(f"{prefix}.input_scale"): + input_scale = weights.get_tensor( + f"{prefix}.input_scale", to_dtype=False + ) + if input_scale.numel() > 1: + input_scale = weights.get_packed_sharded( + f"{prefix}.input_scale", + dim=0, + block_sizes=block_sizes, + to_dtype=False, + ) + input_scale = input_scale.reshape(-1).max() return Fp8Weight( weight=w, weight_scale=scale, + input_scale=input_scale, activation_scale_ub=self.activation_scale_ub, dtype=weights.dtype, ) @@ -148,15 +226,43 @@ class HybridFP8UnquantLoader(WeightsLoader): # FP8 branch if w.dtype == torch.float8_e4m3fn: + if self.weight_block_size is not None: + scale = [ + weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False) + for p in prefixes + ] + scale = torch.cat(scale, dim=dim) + scale = scale.to(weights.device) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + weight_block_size=self.weight_block_size, + ) + scale = [ _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape) for p, shape in zip(prefixes, shapes) ] scale = torch.cat(scale, dim=0).reshape(-1) + input_scale = [ + _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape) + for p, shape in zip(prefixes, shapes) + if weights.has_tensor(f"{p}.input_scale") + ] + assert len(input_scale) == 0 or len(input_scale) == len(prefixes) + input_scale = ( + torch.cat(input_scale, dim=0).reshape(-1).max() + 
if len(input_scale) != 0 + else None + ) + return Fp8Weight( weight=w, weight_scale=scale, + input_scale=input_scale, activation_scale_ub=self.activation_scale_ub, dtype=weights.dtype, ) @@ -169,14 +275,32 @@ class HybridFP8UnquantLoader(WeightsLoader): w = weights.get_sharded(f"{prefix}.weight", dim=1) # FP8 branch if w.dtype == torch.float8_e4m3fn: - scale = ( - weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) - .reshape(-1) - .expand(w.shape[0]) - ) + if self.weight_block_size is not None: + # XXX: Yes the weights is named scale_inv, but corresponds to scale it seems. + scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1) + + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + weight_block_size=self.weight_block_size, + ) + + scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + + input_scale = None + if weights.has_tensor(f"{prefix}.input_scale"): + input_scale = ( + weights.get_tensor(f"{prefix}.input_scale", to_dtype=False) + .reshape(-1) + .max() + ) + return Fp8Weight( weight=w, weight_scale=scale, + input_scale=input_scale, activation_scale_ub=self.activation_scale_ub, dtype=weights.dtype, ) @@ -191,83 +315,142 @@ class Fp8Weight(Weight): weight: torch.Tensor dtype: torch.dtype weight_scale: Optional[torch.Tensor] = None + input_scale: Optional[torch.Tensor] = None activation_scale_ub: Optional[float] = None + force_w8a16: bool = False + weight_block_size: Optional[List[int]] = None def get_linear(self, bias: torch.Tensor): if self.weight_scale is None: - return get_fp8_linear().from_unquant(self.weight, bias, self.dtype) + return get_fp8_linear(force_w8a16=self.force_w8a16).from_unquant( + self.weight, bias, self.dtype + ) # This is not checked by the fbgemm kernels, but they require contiguous # memory. Can be non-contiguous when we e.g. expand from scalars. 
self.weight_scale = self.weight_scale.contiguous() - return get_fp8_linear().from_fp8( - self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype + return get_fp8_linear(force_w8a16=self.force_w8a16).from_fp8( + weight=self.weight, + scale=self.weight_scale, + dtype=self.dtype, + bias=bias, + input_scale=self.input_scale, + scale_upper_bound=self.activation_scale_ub, + weight_block_size=self.weight_block_size, ) class Fp8Linear(torch.nn.Module): + _device_identity_cache = {} + def __init__( self, - qweight, - scale, - scale_upper_bound, - bias, - dtype, + qweight: torch.Tensor, + scale: torch.Tensor, + dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, + input_scale: Optional[torch.Tensor] = None, + scale_upper_bound: Optional[float] = None, + weight_block_size: Optional[List[int]] = None, ) -> None: super().__init__() - if FBGEMM_MM_AVAILABLE: - log_once(logger.info, "Using FBGEMM fp8 optimized kernels") + if CUTLASS_FP8_AVAILABLE: + log_once(logger.info, "Using cutlass w8a8 kernels") self.dtype = dtype self.qweight = qweight - self.scale = scale - self.scale_upper_bound = ( - torch.tensor( - [scale_upper_bound], dtype=torch.float32, device=qweight.device + self.scale = scale.float() + self.input_scale = input_scale.float() if input_scale is not None else None + self.weight_block_size = weight_block_size + + if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None: + self.scale_upper_bound = torch.tensor( + scale_upper_bound, dtype=torch.float32, device=qweight.device ) - if scale_upper_bound is not None - else None - ) + else: + self.scale_upper_bound = scale_upper_bound self.bias = bias if bias is not None else None @classmethod def from_unquant(cls, weight, bias, dtype): - qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE) + qweight, scale = fp8_quantize(weight, scalar=not CUTLASS_FP8_AVAILABLE) return cls( - qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype + qweight=qweight, + scale=scale, + dtype=dtype, + bias=bias, + input_scale=None, + scale_upper_bound=None, ) @classmethod - def from_fp8(cls, weight, scale, input_scale, bias, dtype): - if FBGEMM_DYN_AVAILABLE: - # fbgemm needs float32 scales. - scale = scale.float() + def from_fp8( + cls, + weight: torch.Tensor, + scale: torch.Tensor, + dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, + **kwargs, + ) -> "Fp8Linear": + input_scale = kwargs.get("input_scale", None) + scale_upper_bound = kwargs.get("scale_upper_bound", None) + weight_block_size = kwargs.get("weight_block_size", None) + return cls( qweight=weight, scale=scale, - scale_upper_bound=input_scale, + input_scale=input_scale, + scale_upper_bound=scale_upper_bound, bias=bias, dtype=dtype, + weight_block_size=weight_block_size, ) - def forward(self, input: torch.Tensor) -> torch.Tensor: - if FBGEMM_MM_AVAILABLE: - qinput, scale = fp8_quantize( - input, scale_upper_bound=self.scale_upper_bound - ) + @classmethod + def get_shared_device_identity(cls, device): + # Input scaling factors are no longer optional in _scaled_mm starting + # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale + if device not in cls._device_identity_cache: + cls._device_identity_cache[device] = torch.ones(1, device=device) + return cls._device_identity_cache[device] - y = torch.ops.fbgemm.f8f8bf16_rowwise( + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.weight_block_size is not None: + # https://arxiv.org/pdf/2412.19437 + # At a more granular level. 
As illustrated in Figure 7 (a), (1) for activations, we group and + # scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we + # group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output + # channels). + qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1]) + output = w8a8_block_fp8_matmul( qinput, self.qweight, scale, self.scale, - use_fast_accum=True, - bias=self.bias, + self.weight_block_size, + output_dtype=input.dtype, ) - return y.to(self.dtype) - qinput, scale = fp8_quantize(input, scalar=True) - output, _ = torch._scaled_mm( + if self.bias is not None: + output = output + self.bias + return output.to(dtype=input.dtype) + if CUTLASS_FP8_AVAILABLE: + # cutlass FP8 supports per-token scales, so get non-scalar scales. + qinput, scale = fp8_quantize( + input, scale_upper_bound=self.scale_upper_bound, scalar=False + ) + return quantization.cutlass_scaled_mm( + qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias + ) + + qinput, scale = fp8_quantize( + input, + self.input_scale, + scale_upper_bound=self.scale_upper_bound, + scalar=True, + ) + + output = torch._scaled_mm( qinput, self.qweight.t(), out_dtype=self.dtype, @@ -275,11 +458,16 @@ class Fp8Linear(torch.nn.Module): scale_b=self.scale, bias=self.bias, ) + + if isinstance(output, tuple) and len(output) == 2: + output = output[0] + return output def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size): scale = weights.get_tensor(prefix, to_dtype=False) + if scale.numel() > 1: scale = weights.get_sharded(prefix, dim=0, to_dtype=False) return scale.reshape(-1).expand(shape[0]) diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py index 505caa59..e62a334c 100644 --- a/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py @@ -1,14 +1,15 @@ -import os from dataclasses import dataclass from typing import List, Optional, Union import torch from loguru import logger -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.log import log_once from text_generation_server.utils.weights import Weight, Weights, WeightsLoader +QuantLinear = None + + @dataclass class GPTQWeight(Weight): qweight: torch.Tensor @@ -30,13 +31,8 @@ class GPTQWeight(Weight): def get_linear(self, bias: torch.Tensor): if self.use_awq_kernel: - if SYSTEM == "rocm": - raise NotImplementedError( - "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead " - "to use Exllama/GPTQ kernels for AWQ inference." - ) try: - from text_generation_server.layers.awq.quantize.qmodule import WQLinear + from text_generation_server.layers.awq.quantize import WQLinear return WQLinear( w_bit=self.bits, @@ -50,18 +46,7 @@ class GPTQWeight(Weight): raise NotImplementedError( "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly" ) - elif self.use_exllama: - try: - from text_generation_server.layers.gptq import ExllamaQuantLinear - except ImportError: - raise NotImplementedError( - "Exllama gptq kernels are not installed. 
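The block-wise branch in `Fp8Linear.forward` above quantizes activations per token per 128 channels and scales weights per 128x128 block, following the DeepSeek-V3 reference linked in the comment. A rough pure-PyTorch sketch of the activation grouping (illustrative only; the real path uses the `per_token_group_quant_fp8` and `w8a8_block_fp8_matmul` kernels):

import torch

def per_token_group_quant_ref(x: torch.Tensor, group_size: int = 128,
                              qdtype=torch.float8_e4m3fn):
    # One scale per token per `group_size` contiguous channels; the hidden
    # dimension is assumed to be divisible by `group_size`.
    finfo = torch.finfo(qdtype)
    tokens, hidden = x.shape
    grouped = x.view(tokens, hidden // group_size, group_size)
    scale = grouped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / finfo.max
    q = (grouped / scale).clamp(finfo.min, finfo.max).to(qdtype)
    return q.view(tokens, hidden), scale.squeeze(-1).float()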
Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`" - ) - - return ExllamaQuantLinear(self, bias) else: - from text_generation_server.layers.gptq.quant_linear import QuantLinear - return QuantLinear( self.qweight, self.qzeros, @@ -118,23 +103,6 @@ class GPTQWeightsLoader(WeightsLoader): else: g_idx = None - from text_generation_server.layers.gptq import ( - HAS_EXLLAMA, - CAN_EXLLAMA, - GPTQWeight, - ) - - if use_exllama: - if not HAS_EXLLAMA: - if CAN_EXLLAMA: - log_once( - logger.warning, - "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True", - ) - use_exllama = False - else: - log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") - qzeros = weights.get_tensor(f"{prefix}.qzeros") scales = weights.get_tensor(f"{prefix}.scales") @@ -298,6 +266,7 @@ class GPTQWeightsLoader(WeightsLoader): self._get_gptq_params(weights) use_exllama = True + desc_act = self.desc_act if self.bits != 4: use_exllama = False @@ -321,7 +290,8 @@ class GPTQWeightsLoader(WeightsLoader): if g_idx is not None: if ( not torch.equal( - g_idx.cpu(), + # Remove g_idx[0] to adapt the check with TP>1. + (g_idx - g_idx[0]).cpu(), torch.tensor( [i // self.groupsize for i in range(g_idx.shape[0])], dtype=torch.int32, @@ -332,34 +302,22 @@ class GPTQWeightsLoader(WeightsLoader): # Exllama implementation does not support row tensor parallelism with act-order, as # it would require to reorder input activations that are split unto several GPUs use_exllama = False + desc_act = True from text_generation_server.layers.gptq import ( - CAN_EXLLAMA, - HAS_EXLLAMA, GPTQWeight, ) - if use_exllama: - if not HAS_EXLLAMA: - if CAN_EXLLAMA: - log_once( - logger.warning, - "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True", - ) - use_exllama = False - else: - log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") - - if use_exllama and self.groupsize != -1: + if not desc_act and self.groupsize != -1: qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0) scales = weights.get_sharded(f"{prefix}.scales", dim=0) + if g_idx is not None: + # qzeros, scales sharded, and g_idx must be adjusted accordingly + g_idx = g_idx - g_idx[0] else: qzeros = weights.get_tensor(f"{prefix}.qzeros") scales = weights.get_tensor(f"{prefix}.scales") - if use_exllama and g_idx is not None: - g_idx = g_idx - g_idx[0] - if self.quantize == "gptq" and self.quant_method == "awq": log_once( logger.info, "Converting AWQ model to Exllama/GPTQ packing format." @@ -392,7 +350,7 @@ class GPTQWeightsLoader(WeightsLoader): ) def _get_gptq_params(self, weights: Weights): - if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"): + if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"): self.bits = weights.get_tensor("gptq_bits").item() self.groupsize = weights.get_tensor("gptq_groupsize").item() self.desc_act = False @@ -400,41 +358,10 @@ class GPTQWeightsLoader(WeightsLoader): # before the `gptq_sym` setting tensor was added. self.sym = ( weights.get_tensor("gptq_sym").item() - if weights._has_tensor("gptq_sym") + if weights.has_tensor("gptq_sym") else False ) self.quant_method = "gptq" -# Needs to be at the end because circular import. 
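The `(g_idx - g_idx[0])` adjustments above matter for row tensor parallelism: each rank only sees a slice of the group index, and that slice starts at whatever group its shard begins in. Subtracting the first element renormalizes the slice so the contiguity check, and the indexing into the sharded `scales`/`qzeros`, start again at group 0. A small worked example with hypothetical shapes:

import torch

# Hypothetical layer: groupsize=128 and 512 input features, so the full
# g_idx is [0]*128 + [1]*128 + [2]*128 + [3]*128.
groupsize = 128
full_g_idx = torch.arange(512, dtype=torch.int32) // groupsize

# Rank 1 of a 2-way row shard receives the second half: groups 2 and 3.
shard = full_g_idx[256:]
expected = torch.arange(shard.shape[0], dtype=torch.int32) // groupsize

assert not torch.equal(shard, expected)         # the raw shard starts at group 2
assert torch.equal(shard - shard[0], expected)  # renormalized shard passes the check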
-try: - major, _minor = torch.cuda.get_device_capability() -except Exception: - major = 1 - HAS_EXLLAMA = False -CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm" -V2 = os.getenv("EXLLAMA_VERSION", "2") == "2" -if os.getenv("DISABLE_EXLLAMA") == "True": - HAS_EXLLAMA = False -elif CAN_EXLLAMA: - try: - if V2: - from text_generation_server.layers.gptq.exllamav2 import ( - QuantLinear as ExllamaQuantLinear, # noqa: F401 - create_exllama_buffers, # noqa: F401 - set_device, # noqa: F401 - ) - - HAS_EXLLAMA = "2" - else: - from text_generation_server.layers.gptq.exllama import ( - Ex4bitLinear as ExllamaQuantLinear, # noqa: F401 - create_exllama_buffers, # noqa: F401 - set_device, # noqa: F401 - ) - - HAS_EXLLAMA = "1" - - except ImportError: - pass diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/custom_autotune.py b/backends/gaudi/server/text_generation_server/layers/gptq/custom_autotune.py deleted file mode 100644 index 0388ef20..00000000 --- a/backends/gaudi/server/text_generation_server/layers/gptq/custom_autotune.py +++ /dev/null @@ -1,261 +0,0 @@ -# https://github.com/fpgaminer/GPTQ-triton -""" -Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100. -""" - -import builtins -import math -import time -from typing import Dict - -import triton - - -class Autotuner(triton.KernelInterface): - def __init__( - self, - fn, - arg_names, - configs, - key, - reset_to_zero, - prune_configs_by: Dict = None, - nearest_power_of_two: bool = False, - ): - """ - :param prune_configs_by: a dict of functions that are used to prune configs, fields: - 'perf_model': performance model used to predicate running time with different configs, returns running time - 'top_k': number of configs to bench - 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. - 'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results - """ - if not configs: - self.configs = [triton.Config({}, num_warps=4, num_stages=2)] - else: - self.configs = configs - self.key_idx = [arg_names.index(k) for k in key] - self.nearest_power_of_two = nearest_power_of_two - self.cache = {} - # hook to reset all required tensor to zeros before relaunching a kernel - self.hook = lambda args: 0 - if reset_to_zero is not None: - self.reset_idx = [arg_names.index(k) for k in reset_to_zero] - - def _hook(args): - for i in self.reset_idx: - args[i].zero_() - - self.hook = _hook - self.arg_names = arg_names - # prune configs - if prune_configs_by: - perf_model, top_k = ( - prune_configs_by["perf_model"], - prune_configs_by["top_k"], - ) - if "early_config_prune" in prune_configs_by: - early_config_prune = prune_configs_by["early_config_prune"] - else: - perf_model, top_k, early_config_prune = None, None, None - self.perf_model, self.configs_top_k = perf_model, top_k - self.early_config_prune = early_config_prune - self.fn = fn - - def _bench(self, *args, config, **meta): - # check for conflicts, i.e. meta-parameters both provided - # as kwargs and by the autotuner - conflicts = meta.keys() & config.kwargs.keys() - if conflicts: - raise ValueError( - f"Conflicting meta-parameters: {', '.join(conflicts)}." - " Make sure that you don't re-define auto-tuned symbols." 
- ) - # augment meta-parameters with tunable ones - current = dict(meta, **config.kwargs) - - def kernel_call(): - if config.pre_hook: - config.pre_hook(self.nargs) - self.hook(args) - self.fn.run( - *args, - num_warps=config.num_warps, - num_stages=config.num_stages, - **current, - ) - - try: - # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses - # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default - return triton.testing.do_bench( - kernel_call, quantiles=(0.5, 0.2, 0.8), rep=40 - ) - except triton.OutOfResources: - return [float("inf"), float("inf"), float("inf")] - - def run(self, *args, **kwargs): - self.nargs = dict(zip(self.arg_names, args)) - if len(self.configs) > 1: - key = tuple(args[i] for i in self.key_idx) - - # This reduces the amount of autotuning by rounding the keys to the nearest power of two - # In my testing this gives decent results, and greatly reduces the amount of tuning required - if self.nearest_power_of_two: - key = tuple([2 ** int(math.log2(x) + 0.5) for x in key]) - - if key not in self.cache: - # prune configs - pruned_configs = self.prune_configs(kwargs) - bench_start = time.time() - timings = { - config: self._bench(*args, config=config, **kwargs) - for config in pruned_configs - } - bench_end = time.time() - self.bench_time = bench_end - bench_start - self.cache[key] = builtins.min(timings, key=timings.get) - self.hook(args) - self.configs_timings = timings - config = self.cache[key] - else: - config = self.configs[0] - self.best_config = config - if config.pre_hook is not None: - config.pre_hook(self.nargs) - return self.fn.run( - *args, - num_warps=config.num_warps, - num_stages=config.num_stages, - **kwargs, - **config.kwargs, - ) - - def prune_configs(self, kwargs): - pruned_configs = self.configs - if self.early_config_prune: - pruned_configs = self.early_config_prune(self.configs, self.nargs) - if self.perf_model: - top_k = self.configs_top_k - if isinstance(top_k, float) and top_k <= 1.0: - top_k = int(len(self.configs) * top_k) - if len(pruned_configs) > top_k: - est_timing = { - config: self.perf_model( - **self.nargs, - **kwargs, - **config.kwargs, - num_stages=config.num_stages, - num_warps=config.num_warps, - ) - for config in pruned_configs - } - pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[ - :top_k - ] - return pruned_configs - - def warmup(self, *args, **kwargs): - self.nargs = dict(zip(self.arg_names, args)) - for config in self.prune_configs(kwargs): - self.fn.warmup( - *args, - num_warps=config.num_warps, - num_stages=config.num_stages, - **kwargs, - **config.kwargs, - ) - self.nargs = None - - -def autotune( - configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False -): - """ - Decorator for auto-tuning a :code:`triton.jit`'d function. - .. highlight:: python - .. code-block:: python - @triton.autotune(configs=[ - triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), - triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), - ], - key=['x_size'] # the two above configs will be evaluated anytime - # the value of x_size changes - ) - @triton.jit - def kernel(x_ptr, x_size, **META): - BLOCK_SIZE = META['BLOCK_SIZE'] - :note: When all the configurations are evaluated, the kernel will run multiple time. - This means that whatever value the kernel updates will be updated multiple times. 
- To avoid this undesired behavior, you can use the `reset_to_zero` argument, which - reset the value of the provided tensor to `zero` before running any configuration. - :param configs: a list of :code:`triton.Config` objects - :type configs: list[triton.Config] - :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. - :type key: list[str] - :param prune_configs_by: a dict of functions that are used to prune configs, fields: - 'perf_model': performance model used to predicate running time with different configs, returns running time - 'top_k': number of configs to bench - 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs. - :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. - :type reset_to_zero: list[str] - """ - - def decorator(fn): - return Autotuner( - fn, - fn.arg_names, - configs, - key, - reset_to_zero, - prune_configs_by, - nearest_power_of_two, - ) - - return decorator - - -def matmul248_kernel_config_pruner(configs, nargs): - """ - The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller. - """ - m = max(2 ** int(math.ceil(math.log2(nargs["M"]))), 16) - n = max(2 ** int(math.ceil(math.log2(nargs["N"]))), 16) - k = max(2 ** int(math.ceil(math.log2(nargs["K"]))), 16) - - used = set() - for config in configs: - block_size_m = min(m, config.kwargs["BLOCK_SIZE_M"]) - block_size_n = min(n, config.kwargs["BLOCK_SIZE_N"]) - block_size_k = min(k, config.kwargs["BLOCK_SIZE_K"]) - group_size_m = config.kwargs["GROUP_SIZE_M"] - - if ( - block_size_m, - block_size_n, - block_size_k, - group_size_m, - config.num_stages, - config.num_warps, - ) in used: - continue - - used.add( - ( - block_size_m, - block_size_n, - block_size_k, - group_size_m, - config.num_stages, - config.num_warps, - ) - ) - yield triton.Config( - { - "BLOCK_SIZE_M": block_size_m, - "BLOCK_SIZE_N": block_size_n, - "BLOCK_SIZE_K": block_size_k, - "GROUP_SIZE_M": group_size_m, - }, - num_stages=config.num_stages, - num_warps=config.num_warps, - ) diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/exllama.py b/backends/gaudi/server/text_generation_server/layers/gptq/exllama.py deleted file mode 100644 index f27666b7..00000000 --- a/backends/gaudi/server/text_generation_server/layers/gptq/exllama.py +++ /dev/null @@ -1,134 +0,0 @@ -from text_generation_server.layers.gptq import GPTQWeight -import torch -from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params - -# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension -none_tensor = torch.empty((1, 1), device="meta") - - -def ext_make_q4(qweight, qzeros, scales, g_idx, device): - """Construct Q4Matrix, return handle""" - return make_q4( - qweight, qzeros, scales, g_idx if g_idx is not None else none_tensor, device - ) - - -def ext_q4_matmul(x, q4, q4_width): - """Matrix multiplication, returns x @ q4""" - outshape = x.shape[:-1] + (q4_width,) - x = x.view(-1, x.shape[-1]) - output = torch.empty((x.shape[0], q4_width), dtype=torch.float16, device=x.device) - - q4_matmul(x, q4, output) - - return output.view(outshape) - - -MAX_DQ = 1 -MAX_INNER = 1 -ACT_ORDER = False -DEVICE = None - -TEMP_STATE = None -TEMP_DQ = None - - -def set_device(device): - global DEVICE - DEVICE = device - - -def create_exllama_buffers(max_total_tokens: 
int): - global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ - - assert DEVICE is not None, "call set_device first" - - if not ACT_ORDER: - max_total_tokens = 1 - - # This temp_state buffer is required to reorder X in the act-order case. - temp_state = torch.zeros( - (max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE - ) - temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE) - - # This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. - prepare_buffers(DEVICE, temp_state, temp_dq) - - matmul_recons_thd = 8 - matmul_fused_remap = False - matmul_no_half2 = False - set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) - - TEMP_STATE, TEMP_DQ = temp_state, temp_dq - - -class Ex4bitLinear(torch.nn.Module): - """Linear layer implementation with per-group 4-bit quantization of the weights""" - - def __init__(self, weight: GPTQWeight, bias): - super().__init__() - global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE - assert weight.bits == 4 - - self.device = weight.qweight.device - self.qweight = weight.qweight - self.qzeros = weight.qzeros - self.scales = weight.scales - self.g_idx = weight.g_idx.cpu() if weight.g_idx is not None else None - self.bias = bias if bias is not None else None - - if self.g_idx is not None and ( - (self.g_idx == 0).all() - or torch.equal( - weight.g_idx.cpu(), - torch.tensor( - [i // weight.groupsize for i in range(weight.g_idx.shape[0])], - dtype=torch.int32, - ), - ) - ): - self.empty_g_idx = True - self.g_idx = None - - assert self.device.type == "cuda" - assert self.device.index is not None - - self.q4 = ext_make_q4( - self.qweight, self.qzeros, self.scales, self.g_idx, self.device.index - ) - - self.height = weight.qweight.shape[0] * 8 - self.width = weight.qweight.shape[1] - - # Infer groupsize from height of qzeros - self.groupsize = None - if self.qzeros.shape[0] > 1: - self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0]) - - if self.groupsize is not None: - assert weight.groupsize == self.groupsize - - # Handle act-order matrix - if self.g_idx is not None: - if self.groupsize is None: - raise ValueError("Found group index but no groupsize. 
What do?") - self.act_order = True - else: - self.act_order = False - - DEVICE = self.qweight.device - - MAX_DQ = max(MAX_DQ, self.qweight.numel() * 8) - - if self.act_order: - MAX_INNER = max(MAX_INNER, self.height, self.width) - - ACT_ORDER = True - - def forward(self, x): - out = ext_q4_matmul(x, self.q4, self.width) - - if self.bias is not None: - out.add_(self.bias) - return out diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/exllamav2.py b/backends/gaudi/server/text_generation_server/layers/gptq/exllamav2.py deleted file mode 100644 index 920a6adf..00000000 --- a/backends/gaudi/server/text_generation_server/layers/gptq/exllamav2.py +++ /dev/null @@ -1,267 +0,0 @@ -# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2 - -from dataclasses import dataclass -from typing import Optional -import torch -import torch.nn as nn - -from loguru import logger - -from text_generation_server.layers.exl2 import Exl2Weight -from text_generation_server.layers.gptq import GPTQWeight -from text_generation_server.utils.log import log_master - -try: - from exllamav2.ext import exllamav2_ext - - make_q_matrix = exllamav2_ext.make_q_matrix - gemm_half_q_half = exllamav2_ext.gemm_half_q_half -except ImportError: - log_master(logger.warning, "exllamav2_kernels not installed.") - raise - -# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension -none_tensor = torch.empty((1, 1), device="meta") - - -@dataclass -class _ExtraTensors: - """Additional generated quantizer tensors.""" - - q_group_map: Optional[torch.Tensor] = None - q_invperm: Optional[torch.Tensor] = None - q_perm: Optional[torch.Tensor] = None - - -def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda): - """Matrix multiplication, returns x @ q4""" - output_shape = x.shape[:-1] + (q4_width,) - x = x.view(-1, x.shape[-1]) - output = torch.empty((x.shape[0], q4_width), dtype=torch.half, device=x.device) - gemm_half_q_half(x, q_handle, output, force_cuda) - return output.view(output_shape) - - -def make_group_map(q_groups: torch.Tensor, num_qrows: int): - gr = q_groups.tolist() - group_map = [] - num_groups = len(gr) // 2 - - for i in range(num_groups): - bits = gr[i * 2] - if i < num_groups - 1: - qrows = gr[i * 2 + 3] - gr[i * 2 + 1] - else: - qrows = num_qrows - gr[i * 2 + 1] - rows = qrows * 32 // bits - for j in range(rows): - group_map += [i] - group_map += [rows - j] - - return torch.tensor(group_map, dtype=torch.short, device=q_groups.device) - - -# Create Q matrix - - -def ext_make_q_matrix( - w: Exl2Weight | GPTQWeight, - extra: _ExtraTensors, - temp_dq, - key: Optional[str] = None, -): - """ - Create Q matrix - """ - # max_dq_size = 512*(1024**2) - # max_dq_rows = max_dq_size // out_features[0] - max_dq_rows = 0 - - # EXL2 - if isinstance(w, Exl2Weight): - extra.q_group_map = make_group_map(w.q_groups, w.q_weight.shape[0]) - extra.q_perm = torch.argsort(w.q_invperm).short() - - return make_q_matrix( - w.q_weight, - extra.q_perm, - w.q_invperm, - w.q_scale, - w.q_scale_max, - w.q_groups, - extra.q_group_map, - none_tensor, # zeros - none_tensor, # scales - none_tensor, # g_idx - none_tensor, # bias - temp_dq, - max_dq_rows, - ) - # GPTQ - elif isinstance(w, GPTQWeight): - if w.scales.dtype == torch.float: - w.scales = w.scales.half() - - # GPTQ with g_idx (act_order) - if w.g_idx is not None and not (w.g_idx == 0).all().item(): - extra.q_perm = torch.empty( - (w.qweight.shape[0] * 8,), - dtype=torch.short, - device=w.qweight.device, - ) - extra.q_invperm = 
torch.empty_like(extra.q_perm) - # make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx. - return make_q_matrix( - w.qweight, - extra.q_perm, - extra.q_invperm, - none_tensor, # q_scale - none_tensor, # q_scale_max - none_tensor, # q_groups - none_tensor, # q_group_map - w.qzeros, - w.scales, - w.g_idx.cpu(), - none_tensor, # bias - temp_dq, - max_dq_rows, - ) - # GPTQ without g_idx - else: - return make_q_matrix( - w.qweight, - none_tensor, # q_perm - none_tensor, # q_invperm - none_tensor, # q_scale - none_tensor, # q_scale_max - none_tensor, # q_groups - none_tensor, # q_group_map - w.qzeros, - w.scales, - none_tensor, # g_idx - none_tensor, # bias - temp_dq, - max_dq_rows, - ) - else: - RuntimeError("Cannot create handle") - - -DEVICE = None -LAYERS = [] - - -def set_device(device): - global DEVICE - DEVICE = device - - -def create_exllama_buffers(max_total_tokens: int): - global LAYERS, DEVICE - - # No need to initialize scratch space if there are no layers - # that use ExLLamav2. - if len(LAYERS) == 0: - return - - # Find the size of the scratch space. - scratch_bytes = max( - layer.scratch_space_fixed(max_input_len=max_total_tokens, max_batch_size=1) - for layer in LAYERS - ) - temp_dq = ExLlamaV2DeviceTensors(DEVICE, scratch_bytes) - - for layer in LAYERS: - layer.post_init(temp_dq) - - -class QuantLinear(nn.Module): - QUANT_TYPE = "exllamav2" - - """Linear layer implementation with per-group 4-bit quantization of the weights""" - - def __init__( - self, - weight: Exl2Weight | GPTQWeight, - bias: torch.Tensor, - ): - super().__init__() - - self.q_handle = None - self.q_tensors = weight - self.extra_tensors = _ExtraTensors() - - if isinstance(weight, Exl2Weight): - self.infeatures = weight.q_invperm.shape[0] - self.outfeatures = weight.q_weight.shape[1] - elif isinstance(weight, GPTQWeight): - if weight.bits != 4: - raise ValueError( - f"Exllamav2 kernel supports only bits=4, requested bits={weight.bits}. Something is wrong in the model initialization." - ) - - self.infeatures = weight.qweight.shape[0] // weight.bits * 32 - self.outfeatures = weight.qweight.shape[1] - - self.padding = -self.outfeatures % 32 - self.outfeatures = self.outfeatures + self.padding - - self.device = weight.device - self.bias = bias if bias is not None else None - - global LAYERS - LAYERS.append(self) - - def post_init(self, temp_dq): - device = self.q_tensors.device - assert device.type == "cuda" - assert device.index is not None - temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size()) - - # We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us, - # and `Memory access fault by GPU node-2` will EAT you. 
- self.temp_dq = temp_dq - self.q_handle = ext_make_q_matrix(self.q_tensors, self.extra_tensors, temp_dq) - - def forward(self, x, force_cuda=False): - output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda) - - if self.bias is not None: - output.add_(self.bias) - return output - - def temp_dq_size(self): - return self.infeatures * self.outfeatures * 2 + 128 - - def temp_fwd_size(self, max_input_len, max_batch_size): - return self.outfeatures * max_input_len * max_batch_size * 4 + 128 - - def scratch_space_fixed(self, max_input_len, max_batch_size): - return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size) - - -class ExLlamaV2DeviceTensors: - - device_idx: int - scratch_bytes: int - scratch_idx: int - scratch: torch.tensor = None - - def __init__(self, device, scratch_bytes): - self.device = device - self.scratch_bytes = scratch_bytes - - def prepare(self): - self.scratch = torch.empty( - (self.scratch_bytes // 2,), dtype=torch.half, device=self.device - ) - - def get_scratch_slice(self, size_bytes): - - if self.scratch is None: - self.prepare() - - size_bytes = ((size_bytes + 127) // 128) * 128 - size_half = size_bytes // 2 - scratch_slice = self.scratch.narrow(0, 0, size_half) - return scratch_slice diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/ipex.py b/backends/gaudi/server/text_generation_server/layers/gptq/ipex.py new file mode 100644 index 00000000..48584e90 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/gptq/ipex.py @@ -0,0 +1,125 @@ +import math +import numpy as np +import torch +import torch.nn as nn + +import intel_extension_for_pytorch as ipex + + +class QuantLinear(nn.Module): + def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): + super().__init__() + self.register_buffer("qweight", qweight) + self.register_buffer("qzeros", qzeros) + self.register_buffer("scales", scales) + self.register_buffer("g_idx", g_idx) + if bias is not None: + self.register_buffer("bias", bias) + else: + self.bias = None + if bits not in [4]: + raise NotImplementedError("Only 4 bits are supported.") + self.bits = bits + self.maxq = 2**self.bits - 1 + self.groupsize = groupsize + + self.outfeatures = qweight.shape[1] + self.infeatures = qweight.shape[0] * 32 // bits + self.woq_linear = ( + ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight( + self.qweight, + self.scales, + self.qzeros, + self.infeatures, + self.outfeatures, + bias=self.bias, + group_size=self.groupsize, + g_idx=g_idx, + quant_method=ipex.llm.quantization.QuantMethod.GPTQ_GEMM, + dtype=ipex.llm.quantization.QuantDtype.INT4, + ) + ) + + @classmethod + def new(cls, bits, groupsize, infeatures, outfeatures, bias): + if bits not in [4]: + raise NotImplementedError("Only 4 bits are supported.") + + qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32) + qzeros = torch.zeros( + (math.ceil(infeatures / groupsize), outfeatures // 32 * bits), + dtype=torch.int32, + ) + scales = torch.zeros( + (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16 + ) + g_idx = torch.tensor( + [i // groupsize for i in range(infeatures)], dtype=torch.int32 + ) + if bias: + bias = torch.zeros((outfeatures), dtype=torch.float16) + else: + bias = None + return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize) + + def pack(self, linear, scales, zeros, g_idx=None): + self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx + + scales = scales.t().contiguous() + zeros = 
zeros.t().contiguous() + scale_zeros = zeros * scales + self.scales = scales.clone().half() + if linear.bias is not None: + self.bias = linear.bias.clone().half() + + intweight = [] + for idx in range(self.infeatures): + intweight.append( + torch.round( + (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) + / self.scales[self.g_idx[idx]] + ).to(torch.int)[:, None] + ) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(np.uint32) + qweight = np.zeros( + (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32 + ) + i = 0 + row = 0 + while row < qweight.shape[0]: + if self.bits in [4]: + for j in range(i, i + (32 // self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += 32 // self.bits + row += 1 + else: + raise NotImplementedError("Only 4 bits are supported.") + + qweight = qweight.astype(np.int32) + self.qweight = torch.from_numpy(qweight) + + zeros -= 1 + zeros = zeros.numpy().astype(np.uint32) + qzeros = np.zeros( + (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32 + ) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [4]: + for j in range(i, i + (32 // self.bits)): + qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) + i += 32 // self.bits + col += 1 + else: + raise NotImplementedError("Only 4 bits are supported.") + + qzeros = qzeros.astype(np.int32) + self.qzeros = torch.from_numpy(qzeros) + + def forward(self, x): + out_shape = x.shape[:-1] + (self.outfeatures,) + out = self.woq_linear(x.reshape(-1, x.shape[-1])) + return out.reshape(out_shape) diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/quant_linear.py b/backends/gaudi/server/text_generation_server/layers/gptq/quant_linear.py deleted file mode 100644 index 736c357b..00000000 --- a/backends/gaudi/server/text_generation_server/layers/gptq/quant_linear.py +++ /dev/null @@ -1,359 +0,0 @@ -import math -import numpy as np -import torch -import torch.nn as nn -from torch.cuda.amp import custom_fwd - -import triton -import triton.language as tl -from . 
import custom_autotune - - -# code based https://github.com/fpgaminer/GPTQ-triton -@custom_autotune.autotune( - configs=[ - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=2, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - }, - num_stages=3, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 8, - }, - num_stages=2, - num_warps=4, - ), - ], - key=["M", "N", "K"], - nearest_power_of_two=True, - prune_configs_by={ - "early_config_prune": custom_autotune.matmul248_kernel_config_pruner, - "perf_model": None, - "top_k": None, - }, -) -@triton.jit -def matmul_248_kernel( - a_ptr, - b_ptr, - c_ptr, - scales_ptr, - zeros_ptr, - g_ptr, - M, - N, - K, - bits, - maxq, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_scales, - stride_zeros, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - """ - Compute the matrix multiplication C = A x B. 
- A is of shape (M, K) float16 - B is of shape (K//8, N) int32 - C is of shape (M, N) float16 - scales is of shape (G, N) float16 - zeros is of shape (G, N) float16 - g_ptr is of shape (K) int32 - """ - infearure_per_bits = 32 // bits - - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + ( - offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak - ) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - a_mask = offs_am[:, None] < M - # b_ptrs is set up such that it repeats elements along the K axis 8 times - b_ptrs = b_ptr + ( - (offs_k[:, None] // infearure_per_bits) * stride_bk - + offs_bn[None, :] * stride_bn - ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) - g_ptrs = g_ptr + offs_k - # shifter is used to extract the N bits of each element in the 32-bit word from B - scales_ptrs = scales_ptr + offs_bn[None, :] - zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) - - shifter = (offs_k % infearure_per_bits) * bits - zeros_shifter = (offs_bn % infearure_per_bits) * bits - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - for k in range(0, num_pid_k): - g_idx = tl.load(g_ptrs) - - # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop - scales = tl.load( - scales_ptrs + g_idx[:, None] * stride_scales - ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = tl.load( - zeros_ptrs + g_idx[:, None] * stride_zeros - ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - - zeros = (zeros >> zeros_shifter[None, :]) & maxq - zeros = (zeros + 1) & maxq # eventually avoid overflow - - a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated - - # Now we need to unpack b (which is N-bit values) into 32-bit values - b = (b >> shifter[:, None]) & maxq # Extract the N-bit values - b = (b - zeros) * scales # Scale and shift - - accumulator += tl.dot(a, b) - a_ptrs += BLOCK_SIZE_K - b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk - g_ptrs += BLOCK_SIZE_K - - c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] - c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) - - -def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): - with torch.cuda.device(input.device): - output = torch.empty( - (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16 - ) - - def grid(META): - return ( - triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) - * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]), - ) - - matmul_248_kernel[grid]( - input, - qweight, - output, - scales, - qzeros, - g_idx, - input.shape[0], - qweight.shape[1], - input.shape[1], - bits, - maxq, - input.stride(0), - input.stride(1), - qweight.stride(0), - qweight.stride(1), - output.stride(0), - output.stride(1), - scales.stride(0), - qzeros.stride(0), - ) - return output - - -class QuantLinearFunction(torch.autograd.Function): - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward(ctx, 
input, qweight, scales, qzeros, g_idx, bits, maxq): - output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq) - return output - - -class QuantLinear(nn.Module): - def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): - super().__init__() - self.register_buffer("qweight", qweight) - self.register_buffer("qzeros", qzeros) - self.register_buffer("scales", scales) - self.register_buffer("g_idx", g_idx) - if bias is not None: - self.register_buffer("bias", bias) - else: - self.bias = None - if bits not in [2, 4, 8]: - raise NotImplementedError("Only 2,4,8 bits are supported.") - self.bits = bits - self.maxq = 2**self.bits - 1 - self.groupsize = groupsize - - self.outfeatures = qweight.shape[1] - self.infeatures = qweight.shape[0] * 32 // bits - - @classmethod - def new(cls, bits, groupsize, infeatures, outfeatures, bias): - if bits not in [2, 4, 8]: - raise NotImplementedError("Only 2,4,8 bits are supported.") - - qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32) - qzeros = torch.zeros( - (math.ceil(infeatures / groupsize), outfeatures // 32 * bits), - dtype=torch.int32, - ) - scales = torch.zeros( - (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16 - ) - g_idx = torch.tensor( - [i // groupsize for i in range(infeatures)], dtype=torch.int32 - ) - if bias: - bias = torch.zeros((outfeatures), dtype=torch.float16) - else: - bias = None - return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize) - - def pack(self, linear, scales, zeros, g_idx=None): - self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx - - scales = scales.t().contiguous() - zeros = zeros.t().contiguous() - scale_zeros = zeros * scales - self.scales = scales.clone().half() - if linear.bias is not None: - self.bias = linear.bias.clone().half() - - intweight = [] - for idx in range(self.infeatures): - intweight.append( - torch.round( - (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) - / self.scales[self.g_idx[idx]] - ).to(torch.int)[:, None] - ) - intweight = torch.cat(intweight, dim=1) - intweight = intweight.t().contiguous() - intweight = intweight.numpy().astype(np.uint32) - qweight = np.zeros( - (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32 - ) - i = 0 - row = 0 - while row < qweight.shape[0]: - if self.bits in [2, 4, 8]: - for j in range(i, i + (32 // self.bits)): - qweight[row] |= intweight[j] << (self.bits * (j - i)) - i += 32 // self.bits - row += 1 - else: - raise NotImplementedError("Only 2,4,8 bits are supported.") - - qweight = qweight.astype(np.int32) - self.qweight = torch.from_numpy(qweight) - - zeros -= 1 - zeros = zeros.numpy().astype(np.uint32) - qzeros = np.zeros( - (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32 - ) - i = 0 - col = 0 - while col < qzeros.shape[1]: - if self.bits in [2, 4, 8]: - for j in range(i, i + (32 // self.bits)): - qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) - i += 32 // self.bits - col += 1 - else: - raise NotImplementedError("Only 2,4,8 bits are supported.") - - qzeros = qzeros.astype(np.int32) - self.qzeros = torch.from_numpy(qzeros) - - def forward(self, x): - out_shape = x.shape[:-1] + (self.outfeatures,) - out = QuantLinearFunction.apply( - x.reshape(-1, x.shape[-1]), - self.qweight, - self.scales, - self.qzeros, - self.g_idx, - self.bits, - self.maxq, - ) - out = out + self.bias if self.bias is not None else out - return out.reshape(out_shape) diff --git 
a/backends/gaudi/server/text_generation_server/layers/gptq/quantize.py b/backends/gaudi/server/text_generation_server/layers/gptq/quantize.py index b0086ea0..aa664ea6 100644 --- a/backends/gaudi/server/text_generation_server/layers/gptq/quantize.py +++ b/backends/gaudi/server/text_generation_server/layers/gptq/quantize.py @@ -12,7 +12,7 @@ from huggingface_hub import HfApi from accelerate import init_empty_weights from text_generation_server.utils import initialize_torch_distributed, Weights from text_generation_server.utils.hub import weight_files -from text_generation_server.layers.gptq.quant_linear import QuantLinear +from text_generation_server.layers.gptq import QuantLinear from loguru import logger from typing import Optional from text_generation_server.layers.gptq.utils import torch_snr_error @@ -956,15 +956,24 @@ def quantize( pack(model, quantizers, bits, groupsize) from safetensors.torch import save_file - from transformers.modeling_utils import shard_checkpoint + from huggingface_hub import split_torch_state_dict_into_shards state_dict = model.state_dict() state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()} max_shard_size = "10GB" - shards, index = shard_checkpoint( - state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors" + state_dict_split = split_torch_state_dict_into_shards( + state_dict, + filename_pattern="model.safetensors", + max_shard_size=max_shard_size, ) + index = None + if state_dict_split.is_sharded: + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + shards = state_dict_split.filename_to_tensors os.makedirs(output_dir, exist_ok=True) for shard_file, shard in shards.items(): save_file( diff --git a/backends/gaudi/server/text_generation_server/layers/layernorm.py b/backends/gaudi/server/text_generation_server/layers/layernorm.py index ce5289f9..84878791 100644 --- a/backends/gaudi/server/text_generation_server/layers/layernorm.py +++ b/backends/gaudi/server/text_generation_server/layers/layernorm.py @@ -1,9 +1,6 @@ import torch from torch import nn from accelerate import init_empty_weights -from text_generation_server.utils.import_utils import ( - SYSTEM, -) # Monkey patching @@ -33,69 +30,14 @@ def load_layer_norm_no_bias(cls, prefix, weights, eps): torch.nn.LayerNorm.load = load_layer_norm torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias -if SYSTEM == "cuda": - import dropout_layer_norm - class FastLayerNorm(nn.LayerNorm): - def forward(self, hidden_states, residual=None): - if hidden_states.shape[-1] > 8192: - if residual is not None: - hidden_states += residual - residual = hidden_states +class FastLayerNorm(nn.LayerNorm): + def forward(self, hidden_states, residual=None): + if residual is not None: + hidden_states += residual + residual = hidden_states - return super(FastLayerNorm, self).forward(hidden_states), residual - else: - ( - normed_hidden_states, - residual, - *rest, - ) = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - residual, - self.weight, - self.bias, - None, - None, - None, - None, - 0.0, - self.eps, - 1.0, - 0, - None, - False, - False, - ) - if residual is None: - residual = hidden_states - - return normed_hidden_states, residual - -elif SYSTEM == "rocm": - from vllm._C import ops - - class FastLayerNorm(nn.LayerNorm): - def forward(self, hidden_states, residual=None): - if residual is not None: - hidden_states += residual - residual = hidden_states - - return super().forward(hidden_states), residual - -elif SYSTEM == "ipex": - 
import intel_extension_for_pytorch as ipex - - class FastLayerNorm(nn.LayerNorm): - def forward(self, hidden_states, residual=None): - out = ipex.llm.functional.add_layer_norm( - residual, - hidden_states, - self.weight, - self.bias, - self.eps, - residual is not None, - ) - return out, residual if residual is not None else hidden_states + return super().forward(hidden_states), residual class FastRMSNorm(nn.Module): @@ -111,74 +53,15 @@ class FastRMSNorm(nn.Module): return cls(weight, eps) def forward(self, hidden_states, residual=None): - if SYSTEM == "ipex": - out = ipex.llm.functional.add_rms_norm( - residual, - hidden_states, - self.weight, - None, - self.variance_epsilon, - residual is not None, - ) - return out, residual if residual is not None else hidden_states - elif hidden_states.shape[-1] > 8192: - if residual is not None: - hidden_states += residual - residual = hidden_states + from vllm_hpu_extension.kernels import rms_norm - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt( - variance + self.variance_epsilon - ) - - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - - return self.weight * hidden_states, residual - elif SYSTEM == "cuda": - # faster post attention rms norm - ( - normed_hidden_states, - res, - *rest, - ) = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - residual, - self.weight, - None, - None, - None, - None, - None, - 0.0, - self.variance_epsilon, - 1.0, - 0, - None, - False, - True, # Activate RMSNorm - ) - if res is None: - res = hidden_states - - return normed_hidden_states, res - elif SYSTEM == "rocm": - # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not. - if residual is not None: - hidden_states += residual - residual = hidden_states - - out = torch.empty_like(hidden_states) - ops.rms_norm( - out, - hidden_states, - self.weight.data, - self.variance_epsilon, - ) - return out, residual + orig_shape = hidden_states.shape + if residual is not None: + residual += hidden_states.view(residual.shape) else: - raise ValueError( - "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." - ) + residual = hidden_states + # Note: HPUFusedRMSNorm requires 3D tensors as inputs + if len(orig_shape) == 2: + residual = residual.unsqueeze(0) + x = rms_norm().apply(residual, self.weight, self.variance_epsilon) + return x.view(orig_shape), residual.view(orig_shape) diff --git a/backends/gaudi/server/text_generation_server/layers/linear.py b/backends/gaudi/server/text_generation_server/layers/linear.py index 08306d57..cca80c44 100644 --- a/backends/gaudi/server/text_generation_server/layers/linear.py +++ b/backends/gaudi/server/text_generation_server/layers/linear.py @@ -1,21 +1,5 @@ import torch -from text_generation_server.utils.import_utils import SYSTEM from torch.nn import functional as F -import os - -if SYSTEM == "rocm": - ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in ( - "true", - "1", - ) - - if ROCM_USE_SKINNY_GEMM: - try: - from vllm import _custom_C - except Exception as e: - raise ImportError( - f"Could not load `vllm._custom_C` for ROCm skinny gemm. 
Full error: {e}" - ) class FastLinear(torch.nn.Module): @@ -44,83 +28,11 @@ class FastLinear(torch.nn.Module): return F.linear(input, self.weight, self.bias) -class FastLinearROCm(torch.nn.Module): - def __init__( - self, - weight, - bias, - ) -> None: - super().__init__() - self.weight = torch.nn.Parameter(weight) - if bias is not None: - self.bias = torch.nn.Parameter(bias) - else: - self.bias = None - - self.cu_count = torch.cuda.get_device_properties( - device="cuda" - ).multi_processor_count - self.use_skinny_gemm = ( - ROCM_USE_SKINNY_GEMM - and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName - ) - - @classmethod - def load(cls, config, prefix: str, weights, bias: bool): - weight = weights.get_tensor(f"{prefix}.weight") - if bias: - bias = weights.get_tensor(f"{prefix}.bias") - else: - bias = None - return cls(weight, bias) - - def forward(self, inp: torch.Tensor) -> torch.Tensor: - weight = self.weight - bias = self.bias - - if ( - self.use_skinny_gemm - and inp.dtype == torch.float16 - and inp.shape[-1] % 8 == 0 - ): - batched = False - inp_shape = inp.shape - - if inp.dim() == 3: - inp = inp.view(-1, inp_shape[-1]) - batched = True - - m, n, k = weight.shape[0], inp_shape[0], inp_shape[1] - if m > 8 and n <= 4: - out = torch.empty( - inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device - ) - _custom_C.wvSpltK(weight, inp, out, n, self.cu_count) - elif m % 4 == 0 and n == 1 and k <= 8192: - out = torch.empty( - inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device - ) - _custom_C.LLMM1(weight, inp, out, 4) - else: - out = F.linear(inp, weight) - - if batched: - out.view(*inp_shape[:-1], out.shape[-1]) - - if bias is not None: - out = out + bias - return out - return F.linear(inp, self.weight, self.bias) - - def get_linear(weight, bias): # Weights that are loaded through methods that are not # quantization-aware are still bare tensors. We may want # to change this in the future. 
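The new `FastRMSNorm.forward` above delegates to the fused HPU kernel from `vllm_hpu_extension`, which folds the residual addition into the normalization and expects 3D inputs. An unfused functional sketch of what that computes (a reference, not the kernel itself):

import torch

def rms_norm_ref(hidden_states, weight, eps, residual=None):
    # Optional residual add, then RMS normalization over the last dimension.
    if residual is not None:
        hidden_states = hidden_states + residual
    residual = hidden_states
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    normed = hidden_states.to(torch.float32) * torch.rsqrt(variance + eps)
    return weight * normed.to(weight.dtype), residual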
if isinstance(weight, torch.Tensor): - if SYSTEM == "rocm": - return FastLinearROCm(weight, bias) - else: - return FastLinear(weight, bias) + return FastLinear(weight, bias) return weight.get_linear(bias) diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/fp8.py b/backends/gaudi/server/text_generation_server/layers/marlin/fp8.py index fe55a58a..c2666d2b 100644 --- a/backends/gaudi/server/text_generation_server/layers/marlin/fp8.py +++ b/backends/gaudi/server/text_generation_server/layers/marlin/fp8.py @@ -2,19 +2,15 @@ from typing import Optional import torch import torch.nn as nn -from loguru import logger from text_generation_server.layers.fp8 import fp8_quantize from text_generation_server.layers.marlin.gptq import _check_valid_shape from text_generation_server.layers.marlin.util import ( _check_marlin_kernels, permute_scales, ) -from text_generation_server.utils.log import log_once -try: - import marlin_kernels -except ImportError: - marlin_kernels = None + +quantization = None MARLIN_TILE_SIZE = 16 @@ -34,9 +30,7 @@ class GPTQMarlinFP8Linear(nn.Module): super().__init__() _check_marlin_kernels() - assert marlin_kernels is not None - - log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel") + assert quantization is not None scales = scales.unsqueeze(0) if scales.shape[1] == 1: @@ -62,14 +56,21 @@ class GPTQMarlinFP8Linear(nn.Module): return cls(qweight=qweight, scales=scales.to(dtype), bias=bias) @classmethod - def from_fp8(cls, weight, scale, _input_scale, bias, dtype): + def from_fp8( + cls, + weight: torch.Tensor, + scale: torch.Tensor, + bias: torch.Tensor, + dtype: torch.dtype, + **kwargs, + ): return cls(qweight=weight, scales=scale.to(dtype), bias=bias) def forward(self, A: torch.Tensor) -> torch.Tensor: - assert marlin_kernels is not None + assert quantization is not None A_flat = A.view(-1, A.shape[-1]) - C = marlin_kernels.fp8_marlin_gemm( + C = quantization.fp8_marlin_gemm( A_flat, self.qweight, self.scales, @@ -131,7 +132,7 @@ def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor): qweight = pack_fp8_as_int32(weight.t()) perm = torch.empty(0, dtype=torch.int, device=qweight.device) - repacked = marlin_kernels.gptq_marlin_repack( + repacked = quantization.gptq_marlin_repack( qweight, perm, in_features, out_features, 8 ) diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/gptq.py b/backends/gaudi/server/text_generation_server/layers/marlin/gptq.py index 0a785d94..185a6d77 100644 --- a/backends/gaudi/server/text_generation_server/layers/marlin/gptq.py +++ b/backends/gaudi/server/text_generation_server/layers/marlin/gptq.py @@ -11,14 +11,12 @@ from text_generation_server.layers.marlin.util import ( permute_scales, unpack_cols, ) -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.log import log_once from text_generation_server.utils.weights import Weight, Weights, WeightsLoader -try: - import marlin_kernels -except ImportError: - marlin_kernels = None + +quantization = None + try: major, _minor = torch.cuda.get_device_capability() @@ -35,17 +33,7 @@ MARLIN_TILE_SIZE = 16 def can_use_gptq_marlin( *, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool ) -> bool: - return ( - SYSTEM == "cuda" - and marlin_kernels is not None - and has_sm_8_0 - and quantize in {"awq", "gptq"} - and quant_method in {"awq", "gptq"} - and bits in GPTQ_MARLIN_BITS - and groupsize in GPTQ_MARLIN_GROUP_SIZES - # We only suppord asymmetric quantization for AWQ. 
- and (sym or quant_method == "awq") - ) + return False class GPTQMarlinWeightsLoader(WeightsLoader): @@ -231,7 +219,7 @@ class GPTQMarlinWeightsLoader(WeightsLoader): ) def _get_gptq_params(self, weights: Weights): - if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"): + if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"): self.bits = weights.get_tensor("gptq_bits").item() self.groupsize = weights.get_tensor("gptq_groupsize").item() self.desc_act = False @@ -239,7 +227,7 @@ class GPTQMarlinWeightsLoader(WeightsLoader): # before the `gptq_sym` setting tensor was added. self.sym = ( weights.get_tensor("gptq_sym").item() - if weights._has_tensor("gptq_sym") + if weights.has_tensor("gptq_sym") else False ) self.quant_method = "gptq" @@ -261,7 +249,7 @@ class GPTQMarlinWeight(Weight): def __post_init__(self): assert self.qweight.dtype == torch.int32 - assert self.scales.dtype == torch.float16 + assert self.scales.dtype in (torch.float16, torch.bfloat16) assert self.g_idx.dtype == torch.int32 assert self.perm.dtype == torch.int32 @@ -287,7 +275,7 @@ def repack_gptq_for_marlin( ) -> GPTQMarlinWeight: """Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels.""" _check_marlin_kernels() - assert marlin_kernels is not None + assert quantization is not None if bits not in GPTQ_MARLIN_BITS: supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS) @@ -300,7 +288,7 @@ def repack_gptq_for_marlin( raise RuntimeError( f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}" ) - if not (sym or quant_method == "awq"): + if not (sym or quant_method == "awq" or quant_method == "compressed-tensors"): raise RuntimeError( "Repacking GPTQ weights with asymmetric quantization as Marlin is not supported." 
) @@ -330,7 +318,7 @@ def repack_gptq_for_marlin( g_idx = torch.empty(0, dtype=torch.int, device=qweight.device) if quant_method == "awq": - repacked = marlin_kernels.awq_marlin_repack( + repacked = quantization.awq_marlin_repack( qweight, in_features, out_features, bits ) if qzeros is not None: @@ -342,7 +330,7 @@ def repack_gptq_for_marlin( ) else: - repacked = marlin_kernels.gptq_marlin_repack( + repacked = quantization.gptq_marlin_repack( qweight, perm, in_features, out_features, bits ) @@ -379,13 +367,26 @@ class GPTQMarlinLinear(nn.Module): super().__init__() _check_marlin_kernels() - assert marlin_kernels is not None + assert quantization is not None in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE out_features = weight.scales.shape[1] _check_valid_shape(in_features=in_features, out_features=out_features) - self.bits = weight.bits + if weight.bits not in (4, 8): + raise ValueError("GPTQMarlinLinear only supports 4 and 8-bit quantization") + + if weight.qzeros.numel() > 0: + if weight.bits == 4: + self.quant_type = quantization.scalar_types.uint4 + else: + self.quant_type = quantization.scalar_types.uint8 + else: + if weight.bits == 4: + self.quant_type = quantization.scalar_types.uint4b8 + else: + self.quant_type = quantization.scalar_types.uint8b128 + self.is_full_k = weight.is_full_k self.qweight = weight.qweight @@ -403,10 +404,10 @@ class GPTQMarlinLinear(nn.Module): ) def forward(self, A: torch.Tensor) -> torch.Tensor: - assert marlin_kernels is not None + assert quantization is not None A_flat = A.view(-1, A.shape[-1]) - C = marlin_kernels.gptq_marlin_gemm( + C = quantization.gptq_marlin_gemm( A_flat, self.qweight, self.scales, @@ -414,7 +415,7 @@ class GPTQMarlinLinear(nn.Module): self.g_idx, self.perm, self.workspace, - self.bits, + self.quant_type, A_flat.shape[0], self.scales.shape[1], A_flat.shape[1], diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/marlin.py b/backends/gaudi/server/text_generation_server/layers/marlin/marlin.py index 89ebaca6..2ffbcf33 100644 --- a/backends/gaudi/server/text_generation_server/layers/marlin/marlin.py +++ b/backends/gaudi/server/text_generation_server/layers/marlin/marlin.py @@ -3,13 +3,11 @@ from typing import List, Optional, Union import torch import torch.nn as nn + from text_generation_server.layers.marlin.util import _check_marlin_kernels from text_generation_server.utils.weights import Weight, Weights, WeightsLoader -try: - import marlin_kernels -except ImportError: - marlin_kernels = None +quantization = None class MarlinWeightsLoader(WeightsLoader): @@ -34,7 +32,9 @@ class MarlinWeightsLoader(WeightsLoader): B_meta = weights.get_tensor(f"{prefix}.B_meta") s = weights.get_tensor(f"{prefix}.s") - weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + weight = GPTQMarlin24Weight( + weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits + ) else: try: B = weights.get_tensor(f"{prefix}.B") @@ -65,7 +65,9 @@ class MarlinWeightsLoader(WeightsLoader): f"{prefix}.s", dim=1, block_sizes=block_sizes ) - weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + weight = GPTQMarlin24Weight( + weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits + ) else: B = weights.get_packed_sharded( f"{prefix}.B", dim=1, block_sizes=block_sizes @@ -96,7 +98,9 @@ class MarlinWeightsLoader(WeightsLoader): [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 ) - weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + weight = GPTQMarlin24Weight( + 
weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits + ) else: try: B = torch.cat( @@ -132,7 +136,9 @@ class MarlinWeightsLoader(WeightsLoader): else: s = weights.get_sharded(f"{prefix}.s", dim=0) - weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + weight = GPTQMarlin24Weight( + weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits + ) else: try: B = weights.get_sharded(f"{prefix}.B", dim=0) @@ -179,7 +185,7 @@ class MarlinLinear(nn.Module): super().__init__() _check_marlin_kernels() - assert marlin_kernels is not None + assert quantization is not None in_features = weight.B.shape[0] * MARLIN_TILE_SIZE out_features = weight.s.shape[1] @@ -208,9 +214,9 @@ class MarlinLinear(nn.Module): ) def forward(self, A: torch.Tensor) -> torch.Tensor: - assert marlin_kernels is not None + assert quantization is not None - C = marlin_kernels.marlin_gemm( + C = quantization.marlin_gemm( A.view(-1, A.shape[-1]), self.B, self.s, @@ -247,15 +253,15 @@ class GPTQMarlin24Weight: bits: quantized weight size. """ - B: torch.Tensor - B_meta: torch.Tensor - s: torch.Tensor + weight_packed: torch.Tensor + meta: torch.Tensor + scale_packed: torch.Tensor bits: int def __post_init__(self): - assert self.B.dtype == torch.int32 - assert self.B_meta.dtype == torch.int16 - assert self.s.dtype == torch.float16 + assert self.weight_packed.dtype == torch.int32 + assert self.meta.dtype == torch.int16 + assert self.scale_packed.dtype == torch.float16 def get_linear(self, bias: torch.Tensor): return GPTQMarlin24Linear( @@ -269,7 +275,7 @@ class GPTQMarlin24Linear(nn.Module): super().__init__() _check_marlin_kernels() - assert marlin_kernels is not None + assert quantization is not None if weight.bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS: supported_bits = ", ".join( @@ -279,9 +285,13 @@ class GPTQMarlin24Linear(nn.Module): f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}" ) - in_features = weight.B.shape[0] * MARLIN_TILE_SIZE * 2 - out_features = weight.s.shape[1] - groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0] + in_features = weight.weight_packed.shape[0] * MARLIN_TILE_SIZE * 2 + out_features = weight.scale_packed.shape[1] + groupsize = ( + -1 + if weight.scale_packed.shape[0] == 1 + else in_features // weight.scale_packed.shape[0] + ) if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES: supported_sizes = ", ".join( @@ -291,8 +301,11 @@ class GPTQMarlin24Linear(nn.Module): f"Group size {groupsize} is not supported, must be one of: {supported_sizes}" ) - self.bits = weight.bits - weights_per_int32 = 32 // self.bits + if weight.bits == 4: + self.quant_type = quantization.scalar_types.uint4b8 + else: + self.quant_type = quantization.scalar_types.uint8b128 + weights_per_int32 = 32 // weight.bits assert ( out_features % GPTQ_MARLIN_24_MIN_THREAD_N == 0 @@ -309,9 +322,9 @@ class GPTQMarlin24Linear(nn.Module): f"Number of input features ({in_features}) not divisable by group size ({groupsize})" ) - self.B = weight.B - self.B_meta = weight.B_meta - self.s = weight.s + self.weight_packed = weight.weight_packed + self.meta = weight.meta + self.scale_packed = weight.scale_packed if bias is not None: self.bias = bias else: @@ -320,25 +333,25 @@ class GPTQMarlin24Linear(nn.Module): self.workspace = torch.zeros( (out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL, dtype=torch.int, - device=weight.B.device, + device=weight.weight_packed.device, ) def forward(self, A: torch.Tensor) -> torch.Tensor: 
- assert marlin_kernels is not None + assert quantization is not None - C = marlin_kernels.gptq_marlin_24_gemm( + C = quantization.gptq_marlin_24_gemm( A.view(-1, A.shape[-1]), - self.B, - self.B_meta, - self.s, + self.weight_packed, + self.meta, + self.scale_packed, self.workspace, - self.bits, + self.quant_type, A.shape[0], - self.s.shape[1], + self.scale_packed.shape[1], A.shape[1], ) - C = C.reshape(A.shape[:-1] + (self.s.shape[1],)) + C = C.reshape(A.shape[:-1] + (self.scale_packed.shape[1],)) if self.bias is not None: C += self.bias diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/util.py b/backends/gaudi/server/text_generation_server/layers/marlin/util.py index 250d1714..9f52340f 100644 --- a/backends/gaudi/server/text_generation_server/layers/marlin/util.py +++ b/backends/gaudi/server/text_generation_server/layers/marlin/util.py @@ -3,12 +3,9 @@ from typing import List, Tuple import numpy import torch -from text_generation_server.utils.import_utils import SYSTEM -try: - import marlin_kernels -except ImportError: - marlin_kernels = None + +quantization = None try: major, _minor = torch.cuda.get_device_capability() @@ -18,12 +15,11 @@ except Exception: def _check_marlin_kernels(): - if not (SYSTEM == "cuda" and has_sm_8_0): - raise NotImplementedError( - "Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later." - ) + raise NotImplementedError( + "Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later." + ) - if marlin_kernels is None: + if quantization is None: raise NotImplementedError( "marlin is not installed, install it with: pip install server/marlin" ) diff --git a/backends/gaudi/server/text_generation_server/layers/moe/__init__.py b/backends/gaudi/server/text_generation_server/layers/moe/__init__.py index 2c46ca02..cba81407 100644 --- a/backends/gaudi/server/text_generation_server/layers/moe/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/moe/__init__.py @@ -10,13 +10,8 @@ from text_generation_server.layers import ( TensorParallelRowLinear, ) from text_generation_server.layers.fp8 import HybridFP8UnquantLoader -from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader -from text_generation_server.layers.moe.gptq_marlin import ( - GPTQMarlinSparseMoELayer, - can_use_marlin_moe_gemm, -) from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer from text_generation_server.utils.log import log_once from text_generation_server.utils.weights import ( DefaultWeightsLoader, @@ -24,12 +19,7 @@ from text_generation_server.utils.weights import ( UnquantizedWeight, ) -if SYSTEM == "rocm": - from .fused_moe_rocm import grouped_topk - from vllm.model_executor.layers.fused_moe import fused_topk -elif SYSTEM != "ipex": - from moe_kernels.fused_moe import fused_topk, grouped_topk - +from .fused_moe_ipex import fused_topk, grouped_topk # NOTE: we are using a protocol here, because multiple inherance is not nice. # We need `Module`, and `Module` -> some abstract class -> some concrete @@ -52,6 +42,8 @@ class MoELayer(Protocol): up_proj_name: str = "up_proj", down_proj_name: str = "down_proj", hidden_act: str = "silu", + scoring_func: Optional[str] = None, + e_score_correction_bias: Optional[float] = None, ): ... 
def forward( @@ -81,9 +73,14 @@ class DenseMoELayer(nn.Module): up_proj_name: str = "up_proj", down_proj_name: str = "down_proj", hidden_act: str = "silu", + scoring_func: Optional[str] = None, + e_score_correction_bias: Optional[float] = None, ): super().__init__() + assert scoring_func is None, "scoring func is not handled" + assert e_score_correction_bias is None, "scoring correction bias is not handled" + log_once( logger.info, "No fused layers are available for this model type, using (slower) dense MoE layer", @@ -199,22 +196,27 @@ class SparseMoELayer(nn.Module): topk: int, topk_group: Optional[int], weights: Weights, + scoring_func: Optional[str] = "softmax", + e_score_correction_bias: Optional[float] = None, gate_proj_name: str = "gate_proj", up_proj_name: str = "up_proj", down_proj_name: str = "down_proj", ): super().__init__() - if ( isinstance(weights.loader, DefaultWeightsLoader) and isinstance(weights.loader.weight_class, UnquantizedWeight) ) or isinstance(weights.loader, HybridFP8UnquantLoader): - cls = UnquantizedSparseMoELayer - elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym: - cls = GPTQMarlinSparseMoELayer + if ( + isinstance(weights.loader, HybridFP8UnquantLoader) + and weights.loader.to_fp8 + ): + cls = FP8SparseMoELayer + else: + cls = UnquantizedSparseMoELayer else: raise ValueError( - f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and GPTQ weights" + f"Unsupported weights loader: {type(weights.loader)}, sparse MoE is only supported for unquantized, AWQ, and GPTQ weights" ) log_once( @@ -230,6 +232,8 @@ class SparseMoELayer(nn.Module): topk=topk, topk_group=topk_group, weights=weights, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, gate_proj_name=gate_proj_name, up_proj_name=up_proj_name, down_proj_name=down_proj_name, @@ -241,17 +245,6 @@ class SparseMoELayer(nn.Module): @staticmethod def is_supported(weights: Weights) -> bool: return ( - ( - isinstance(weights.loader, DefaultWeightsLoader) - and isinstance(weights.loader.weight_class, UnquantizedWeight) - ) - or isinstance(weights.loader, HybridFP8UnquantLoader) - or ( - isinstance(weights.loader, GPTQMarlinWeightsLoader) - and can_use_marlin_moe_gemm( - quant_method=weights.loader.quant_method, - quantize=weights.loader.quantize, - sym=weights.loader.sym, - ) - ) - ) + isinstance(weights.loader, DefaultWeightsLoader) + and isinstance(weights.loader.weight_class, UnquantizedWeight) + ) or isinstance(weights.loader, HybridFP8UnquantLoader) diff --git a/backends/gaudi/server/text_generation_server/layers/moe/fp8.py b/backends/gaudi/server/text_generation_server/layers/moe/fp8.py new file mode 100644 index 00000000..071b2abe --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/moe/fp8.py @@ -0,0 +1,173 @@ +from typing import Optional + +import torch +import torch.nn as nn + +from text_generation_server.utils.weights import Weights +from text_generation_server.layers.fp8 import ( + Fp8Weight, + fp8_quantize, + quant_dtype, + normalize_e4m3fn_to_native_float8, +) + +try: + from .unquantized import fused_moe +except Exception: + fused_moe = None + + +class FP8SparseMoELayer(nn.Module): + def __init__( + self, + *, + n_expert_group: Optional[int], + n_experts: int, + prefix: str, + renormalize: bool, + topk: int, + topk_group: Optional[int], + weights: Weights, + scoring_func: Optional[str] = "softmax", + e_score_correction_bias: Optional[float] = None, + gate_proj_name: str = "gate_proj", + 
up_proj_name: str = "up_proj", + down_proj_name: str = "down_proj", + ): + super().__init__() + + assert (n_expert_group is None) == ( + topk_group is None + ), "n_expert_group and topk_group must both be None or have some value" + + self.n_expert_group = n_expert_group + self.topk = topk + self.topk_group = topk_group + self.renormalize = renormalize + self.weight_block_size = weights.weights_loader.weight_block_size + self.scoring_func = scoring_func + self.e_score_correction_bias = e_score_correction_bias + + ( + self.gate_up_proj, + self.gate_up_proj_weight_scale, + self.gate_up_proj_input_scale, + ) = _load_expert_multi_weights_col( + prefix=prefix, + n_experts=n_experts, + gate_proj_name=gate_proj_name, + up_proj_name=up_proj_name, + weights=weights, + ) + + self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = ( + _load_expert_weights_row( + prefix=prefix, + n_experts=n_experts, + name=down_proj_name, + weights=weights, + ) + ) + + def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: + return fused_moe( + x, + w1=self.gate_up_proj, + w2=self.down_proj, + gating_output=gating_output, + topk=self.topk, + renormalize=self.renormalize, + inplace=True, + use_grouped_topk=self.n_expert_group is not None, + num_expert_group=self.n_expert_group, + topk_group=self.topk_group, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias, + use_fp8_w8a8=True, + w1_scale=self.gate_up_proj_weight_scale, + w2_scale=self.down_proj_weight_scale, + a1_scale=self.gate_up_proj_input_scale, + a2_scale=self.down_proj_input_scale, + ) + + +def _load_expert_weights( + get_weight_fn, + *, + prefix: str, + n_experts: int, + name: str, + weights: Weights, +) -> torch.Tensor: + all_weight = None + all_weight_scales = None + max_input_scale = None + + for i in range(n_experts): + weight = get_weight_fn(prefix, i, name, weights) + + assert isinstance(weight, Fp8Weight) + + if all_weight is None: + all_weight = torch.empty( + (n_experts,) + weight.weight.shape, + dtype=quant_dtype, + device=weight.weight.device, + ) + if all_weight_scales is None: + all_weight_scales = torch.empty( + (n_experts,) + weight.weight_scale.shape, + dtype=torch.float32, + device=weight.weight.device, + ) + + if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}: + all_weight[i], all_weight_scales[i], current_input_scale = ( + normalize_e4m3fn_to_native_float8( + weight.weight, weight.weight_scale, weight.input_scale + ) + ) + if current_input_scale is not None: + if max_input_scale is None or current_input_scale > max_input_scale: + max_input_scale = current_input_scale + else: + all_weight[i], all_weight_scales[i] = fp8_quantize( + weight.weight, scalar=True + ) + + assert all_weight is not None + + return all_weight, all_weight_scales, max_input_scale + + +def _load_expert_multi_weights_col( + *, + prefix: str, + n_experts: int, + gate_proj_name: str, + up_proj_name: str, + weights: Weights, +) -> torch.Tensor: + def get_weight_fn(prefix, i, name, weights): + return weights.get_multi_weights_col( + [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0 + ) + + return _load_expert_weights( + get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights + ) + + +def _load_expert_weights_row( + *, + prefix: str, + n_experts: int, + name: str, + weights: Weights, +) -> torch.Tensor: + def get_weight_fn(prefix, i, name, weights): + return weights.get_weights_row(f"{prefix}.{i}.{name}") + + return 
_load_expert_weights( + get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights + ) diff --git a/backends/gaudi/server/text_generation_server/layers/moe/fused_moe_rocm.py b/backends/gaudi/server/text_generation_server/layers/moe/fused_moe_ipex.py similarity index 80% rename from backends/gaudi/server/text_generation_server/layers/moe/fused_moe_rocm.py rename to backends/gaudi/server/text_generation_server/layers/moe/fused_moe_ipex.py index 68accb99..e26ff877 100644 --- a/backends/gaudi/server/text_generation_server/layers/moe/fused_moe_rocm.py +++ b/backends/gaudi/server/text_generation_server/layers/moe/fused_moe_ipex.py @@ -16,10 +16,8 @@ from typing import Tuple import torch -import torch.distributed -# TODO: Remove the functions once moe_kernel are built for ROCM def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -50,3 +48,18 @@ def grouped_topk( topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) return topk_weights, topk_ids + + +def fused_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +) -> Tuple[torch.Tensor, torch.Tensor]: + topk_weights = torch.nn.functional.softmax( + gating_output, dim=1, dtype=torch.float32 + ) + topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1) + if renormalize: + topk_weights /= topk_weights.sum(dim=-1, keepdim=True) + return topk_weights, topk_ids diff --git a/backends/gaudi/server/text_generation_server/layers/moe/gptq_marlin.py b/backends/gaudi/server/text_generation_server/layers/moe/gptq_marlin.py deleted file mode 100644 index 3217cdc2..00000000 --- a/backends/gaudi/server/text_generation_server/layers/moe/gptq_marlin.py +++ /dev/null @@ -1,215 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional - -import torch -import torch.nn as nn - -from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.utils.weights import Weights -from text_generation_server.layers.marlin.gptq import ( - GPTQMarlinWeight, - GPTQMarlinWeightsLoader, -) - -if SYSTEM == "cuda": - from moe_kernels.fused_marlin_moe import fused_marlin_moe -else: - fused_marlin_moe = None - - -try: - major, _minor = torch.cuda.get_device_capability() - has_sm_8_0 = major >= 8 -except Exception: - has_sm_8_0 = False - - -def can_use_marlin_moe_gemm( - *, - quant_method: str, - quantize: str, - sym: bool, -): - return ( - SYSTEM == "cuda" - and fused_marlin_moe is not None - and has_sm_8_0 - and quantize == "gptq" - and quant_method == "gptq" - and sym - ) - - -@dataclass -class GPTQMarlinMoEWeight: - qweight: torch.Tensor - qzeros: torch.Tensor - scales: torch.Tensor - g_idx: torch.Tensor - perm: torch.Tensor - is_full_k: bool - - -class GPTQMarlinSparseMoELayer(nn.Module): - """ - MoE layer that uses a fused GPTQ-Marlin kernel. 
- """ - - def __init__( - self, - *, - n_expert_group: Optional[int], - n_experts: int, - prefix: str, - renormalize: bool, - topk: int, - topk_group: Optional[int], - weights: Weights, - gate_proj_name: str = "gate_proj", - up_proj_name: str = "up_proj", - down_proj_name: str = "down_proj", - ): - super().__init__() - - if not ( - isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym - ): - raise ValueError( - f"Unsupported weights loader: {weights.loader}, only GPTQMarlinWeightsLoader with symmetric quantization is supported" - ) - - assert (n_expert_group is None) == ( - topk_group is None - ), "n_expert_group and topk_group must both be None or have some value" - - self.n_expert_group = n_expert_group - self.topk = topk - self.topk_group = topk_group - self.renormalize = renormalize - - self.gate_up_proj = _load_expert_multi_weights_col( - prefix=prefix, - n_experts=n_experts, - names=[gate_proj_name, up_proj_name], - weights=weights, - ) - - self.down_proj = _load_expert_weights_row( - prefix=prefix, n_experts=n_experts, name=down_proj_name, weights=weights - ) - - self.bits = weights.loader.bits - - def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: - return fused_marlin_moe( - x, - w1=self.gate_up_proj.qweight, - w2=self.down_proj.qweight, - g_idx1=self.gate_up_proj.g_idx, - g_idx2=self.down_proj.g_idx, - perm1=self.gate_up_proj.perm, - perm2=self.down_proj.perm, - w1_scale=self.gate_up_proj.scales, - w2_scale=self.down_proj.scales, - is_full_k1=self.gate_up_proj.is_full_k, - is_full_k2=self.down_proj.is_full_k, - gating_output=gating_output, - topk=self.topk, - renormalize=self.renormalize, - use_grouped_topk=self.n_expert_group is not None, - num_expert_group=self.n_expert_group, - topk_group=self.topk_group, - num_bits=self.bits, - ) - - -def _load_expert_multi_weights_col( - *, - prefix: str, - n_experts: int, - names: List[str], - weights: Weights, -) -> GPTQMarlinMoEWeight: - moe_weight = None - for i in range(n_experts): - weight = weights.get_multi_weights_col( - [f"{prefix}.{i}.{name}" for name in names], 0 - ) - assert isinstance(weight, GPTQMarlinWeight) - moe_weight = _pack_weight( - n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight - ) - assert moe_weight is not None - return moe_weight - - -def _load_expert_weights_row( - *, - prefix: str, - n_experts: int, - name: str, - weights: Weights, -) -> GPTQMarlinMoEWeight: - moe_weight = None - for i in range(n_experts): - weight = weights.get_weights_row( - f"{prefix}.{i}.{name}", - ) - assert isinstance(weight, GPTQMarlinWeight) - moe_weight = _pack_weight( - n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight - ) - assert moe_weight is not None - return moe_weight - - -def _pack_weight( - *, - n_experts: int, - expert: int, - moe_weight: Optional[GPTQMarlinMoEWeight], - weight: GPTQMarlinWeight, -) -> GPTQMarlinMoEWeight: - if moe_weight is None: - qweight = torch.empty( - (n_experts,) + weight.qweight.shape, - dtype=weight.qweight.dtype, - device=weight.qweight.device, - ) - qzeros = torch.empty( - (n_experts,) + weight.qzeros.shape, - dtype=weight.qzeros.dtype, - device=weight.qzeros.device, - ) - scales = torch.empty( - (n_experts,) + weight.scales.shape, - dtype=weight.scales.dtype, - device=weight.scales.device, - ) - g_idx = torch.empty( - (n_experts,) + weight.g_idx.shape, - dtype=weight.g_idx.dtype, - device=weight.g_idx.device, - ) - perm = torch.empty( - (n_experts,) + weight.perm.shape, - dtype=weight.perm.dtype, - 
device=weight.perm.device, - ) - - moe_weight = GPTQMarlinMoEWeight( - qweight=qweight, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - perm=perm, - is_full_k=weight.is_full_k, - ) - - moe_weight.qweight[expert] = weight.qweight - moe_weight.qzeros[expert] = weight.qzeros - moe_weight.scales[expert] = weight.scales - moe_weight.g_idx[expert] = weight.g_idx - moe_weight.perm[expert] = weight.perm - - return moe_weight diff --git a/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py b/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py index d9d62c0e..8cb27879 100644 --- a/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py +++ b/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py @@ -1,15 +1,11 @@ -from typing import Optional +from typing import Callable, List, Optional import torch import torch.nn as nn -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.weights import UnquantizedWeight, Weights -if SYSTEM == "rocm": - from vllm.model_executor.layers.fused_moe import fused_moe -elif SYSTEM != "ipex": - from moe_kernels.fused_moe import fused_moe +moe_kernels = None class UnquantizedSparseMoELayer(nn.Module): @@ -23,6 +19,8 @@ class UnquantizedSparseMoELayer(nn.Module): topk: int, topk_group: Optional[int], weights: Weights, + scoring_func: Optional[str] = "softmax", + e_score_correction_bias: Optional[float] = None, gate_proj_name: str = "gate_proj", up_proj_name: str = "up_proj", down_proj_name: str = "down_proj", @@ -37,6 +35,9 @@ class UnquantizedSparseMoELayer(nn.Module): self.topk = topk self.topk_group = topk_group self.renormalize = renormalize + self.weight_block_size = weights.weights_loader.weight_block_size + self.scoring_func = scoring_func + self.e_score_correction_bias = e_score_correction_bias self.gate_up_proj = _load_expert_multi_weights_col( prefix=prefix, @@ -54,17 +55,6 @@ class UnquantizedSparseMoELayer(nn.Module): ) def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: - if SYSTEM == "rocm": - return fused_moe( - x, - self.gate_up_proj, - self.down_proj, - gating_output, - self.topk, - renormalize=self.renormalize, - inplace=True, - ) - return fused_moe( x, w1=self.gate_up_proj, @@ -76,6 +66,8 @@ class UnquantizedSparseMoELayer(nn.Module): use_grouped_topk=self.n_expert_group is not None, num_expert_group=self.n_expert_group, topk_group=self.topk_group, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias, ) @@ -136,3 +128,110 @@ def _load_expert_weights_row( assert all_weight is not None return all_weight + + +def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + inplace: bool = False, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating 
mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - num_expert_group: Optional[int]: additional parameter for grouped_topk + - topk_group: Optional[int]: additional parameter for grouped_topk + - use_grouped_topk: If True, use grouped_topk instead of fused_topk + note: Deepseekv2 model uses grouped_topk + - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. + - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for + a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for + a2. + - block_shape: (Optional[List[int]]): Optional block size for block-wise + quantization. + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + + if use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + from loguru import logger + import inspect + + logger.info(f"{inspect.signature(moe_kernels.grouped_topk)}") + topk_weights, topk_ids = moe_kernels.grouped_topk( + hidden_states, + gating_output, + topk, + renormalize, + num_expert_group, + topk_group, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + elif custom_routing_function is None: + topk_weights, topk_ids = moe_kernels.fused_topk( + hidden_states, gating_output, topk, renormalize + ) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states, gating_output, topk, renormalize + ) + + return moe_kernels.fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace=inplace, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_shape, + ) diff --git a/backends/gaudi/server/text_generation_server/layers/rotary.py b/backends/gaudi/server/text_generation_server/layers/rotary.py index a2076bb2..5b6cad5c 100644 --- a/backends/gaudi/server/text_generation_server/layers/rotary.py +++ b/backends/gaudi/server/text_generation_server/layers/rotary.py @@ -2,14 +2,10 @@ import os import math import torch from torch import nn -from text_generation_server.utils.import_utils import SYSTEM - -if SYSTEM == "cuda": - import rotary_emb -elif SYSTEM == "rocm": - from vllm._C import ops -elif SYSTEM == "ipex": - import intel_extension_for_pytorch as ipex +from habana_frameworks.torch.hpex.kernels import ( + RotaryPosEmbeddingMode, + apply_rotary_pos_emb, +) def _create_inv_freq(dim, base, device): @@ -30,7 +26,7 @@ def _get_rope_config(config): class 
PositionRotaryEmbedding(nn.Module): - def __init__(self, inv_freq, scaling_factor): + def __init__(self, inv_freq, scaling_factor, max_position_embeddings): super().__init__() self.inv_freq = inv_freq self._seq_len_cached = 0 @@ -40,6 +36,9 @@ class PositionRotaryEmbedding(nn.Module): self._sin_k_cached = None self.scaling_factor = scaling_factor self.dynamic_args = None + self._update_cos_sin_cache( + torch.float32, inv_freq.device, max_position_embeddings + ) def forward( self, @@ -48,34 +47,30 @@ class PositionRotaryEmbedding(nn.Module): cos: torch.Tensor, sin: torch.Tensor, ): - # Such controlflows may add some overhead. - if SYSTEM == "cuda": - rotary_dim = cos.shape[-1] - q1 = query[..., :rotary_dim] - q2 = query[..., rotary_dim : 2 * rotary_dim] + num_tokens = query.shape[0] + head_size = query.shape[-1] + # HPU RoPE kernel requires hidden dimension for cos and sin to be equal + # to query hidden dimension, so the original tensors need to be + # expanded + # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE + # and expansion of cos/sin tensors via concatenation + rope_mode = RotaryPosEmbeddingMode.BLOCKWISE + cos = torch.cat((cos, cos), dim=-1) + sin = torch.cat((sin, sin), dim=-1) + rotary_dim = cos.shape[-1] + query_shape = query.shape + query = query.view(num_tokens, -1, head_size) + query_rot = query[..., :rotary_dim] + query_pass = query[..., rotary_dim:] + query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode) + query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)) - rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) - - k1 = key[..., :rotary_dim] - k2 = key[..., rotary_dim : 2 * rotary_dim] - - rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) - elif SYSTEM == "rocm": - # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. - # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773 - - head_size = query.shape[-1] - - # Inplace operation, updating query and key. - ops.rotary_embedding(query, key, head_size, cos, sin, True) - elif SYSTEM == "ipex": - ipex.llm.functional.rotary_embedding( - query, key, sin, cos, query.size(-1), True - ) - else: - raise ValueError( - "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." 
- ) + key_shape = key.shape + key = key.view(num_tokens, -1, head_size) + key_rot = key[..., :rotary_dim] + key_pass = key[..., rotary_dim:] + key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode) + key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)) @classmethod def static(cls, config, dim, base, device): @@ -89,6 +84,14 @@ class PositionRotaryEmbedding(nn.Module): if rope_type == "linear": pass + elif rope_type == "default": + pass + elif rope_type == "mrope": + mrope_section = rope_scaling["mrope_section"] + if mrope_section is not None: + return RotaryPositionEmbeddingMultimodalSections( + inv_freq, scaling_factor, mrope_section + ) elif rope_type == "dynamic": scaling_factor = rope_scaling["factor"] return DynamicPositionRotaryEmbedding( @@ -109,7 +112,7 @@ class PositionRotaryEmbedding(nn.Module): ], ) - return cls(inv_freq, scaling_factor) + return cls(inv_freq, scaling_factor, config.max_position_embeddings) elif rope_type == "yarn": scaling_factor = rope_scaling["factor"] @@ -190,7 +193,7 @@ class PositionRotaryEmbedding(nn.Module): raise NotImplementedError( f"rope scaling type {rope_scaling['type']} is not implemented or invalid" ) - return cls(inv_freq, scaling_factor) + return cls(inv_freq, scaling_factor, config.max_position_embeddings) @classmethod def load(cls, config, prefix, weights): @@ -236,7 +239,7 @@ class PositionRotaryEmbedding(nn.Module): raise NotImplementedError( f"rope scaling type {rope_scaling['type']} is not implemented or invalid" ) - return cls(inv_freq, scaling_factor) + return cls(inv_freq, scaling_factor, config.max_position_embeddings) def _update_cos_sin_cache(self, dtype, device, seqlen): # Reset the tables if the sequence length has changed, @@ -257,17 +260,7 @@ class PositionRotaryEmbedding(nn.Module): self._cos_cached = torch.cos(freqs).to(dtype) self._sin_cached = torch.sin(freqs).to(dtype) - def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype): - """ - Return cos and sin for the asked position ids - """ - if SYSTEM == "rocm": - # For RoCm, we always use float cos/sin to avoid a cast. - # For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26 - # But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal. 
- dtype = torch.float32 - - self._update_cos_sin_cache(dtype, position_ids.device, max_s) + def get_cos_sin(self, position_ids: torch.Tensor): cos = torch.index_select(self._cos_cached, 0, position_ids) sin = torch.index_select(self._sin_cached, 0, position_ids) @@ -383,7 +376,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding): class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding): def __init__(self, dim, max_position_embeddings, base, device, scaling_factor): inv_freq = _create_inv_freq(dim, base, device) - super().__init__(inv_freq, scaling_factor) + super().__init__(inv_freq, scaling_factor, max_position_embeddings) self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base @@ -461,7 +454,9 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): mscale_all_dim: float, ): inv_freq = _create_inv_freq(dim, base, device) - super().__init__(inv_freq, scaling_factor) + super().__init__( + inv_freq, scaling_factor, max_position_embeddings * self.scaling_factor + ) self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base @@ -546,3 +541,44 @@ def apply_llama3_scaling( new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq) return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) + + +class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding): + def __init__(self, inv_freq: torch.Tensor, scaling_factor: float, sections: list): + super().__init__(inv_freq, scaling_factor) + self.sections = sections + self._cos_cached = None + self._sin_cached = None + self.section_indices = ( + torch.arange(len(self.sections)) + .repeat_interleave(torch.tensor(self.sections)) + .view(1, 1, -1) + .to(inv_freq.device) + ) + + def _update_cos_sin_cache( + self, dtype: torch.dtype, device: torch.device, seqlen: int + ): + # always cache the cos/sin for the full sequence length to avoid + # recomputing if the sequence length is smaller than the cached one + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + self._sections = self.section_indices.expand(seqlen, -1, -1) + + def get_cos_sin( + self, + position_ids: torch.Tensor, + ): + slen = position_ids.shape[0] + + cos = self._cos_cached[position_ids].gather(1, self._sections[:slen]) + sin = self._sin_cached[position_ids].gather(1, self._sections[:slen]) + return cos, sin diff --git a/backends/gaudi/server/text_generation_server/layers/tensor_parallel.py b/backends/gaudi/server/text_generation_server/layers/tensor_parallel.py index 13f12ef1..8f19174f 100644 --- a/backends/gaudi/server/text_generation_server/layers/tensor_parallel.py +++ b/backends/gaudi/server/text_generation_server/layers/tensor_parallel.py @@ -2,10 +2,8 @@ import torch from torch.nn import functional as F from typing import Iterable, List from text_generation_server.layers.linear import get_linear, FastLinear -from text_generation_server.utils.import_utils import SYSTEM -if SYSTEM == "ipex": - import intel_extension_for_pytorch as ipex +import habana_frameworks.torch as htorch class LayerConcat(torch.nn.Module): @@ -90,14 +88,10 @@ class TensorParallelHead(SuperLayer): local_out = gather_input.T torch.mm(input, self.linear.weight.T, out=local_out) - 
if SYSTEM == "ipex": - ipex.distributed.all_gather_into_tensor( - world_out, gather_input, group=self.process_group - ) - else: - torch.distributed.all_gather_into_tensor( - world_out, gather_input, group=self.process_group - ) + htorch.core.mark_step() + torch.distributed.all_gather_into_tensor( + world_out, gather_input, group=self.process_group + ) if input.shape[0] == 1: return world_out @@ -107,10 +101,9 @@ class TensorParallelHead(SuperLayer): world_output = [ torch.empty_like(output) for _ in range(self.process_group.size()) ] - if SYSTEM == "ipex": - ipex.distributed.all_gather(world_output, output, group=self.process_group) - else: - torch.distributed.all_gather(world_output, output, group=self.process_group) + + htorch.core.mark_step() + torch.distributed.all_gather(world_output, output, group=self.process_group) world_output = torch.cat(world_output, dim=-1) return world_output @@ -202,10 +195,11 @@ class TensorParallelRowLinear(SuperLayer): def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor: out = super().forward(input) if self.process_group.size() > 1 and reduce: - if SYSTEM == "ipex": - ipex.distributed.all_reduce(out, group=self.process_group) - else: - torch.distributed.all_reduce(out, group=self.process_group) + # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge + # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used + # (which is required for tensor parallel HPUGraph inference) + htorch.core.mark_step() + torch.distributed.all_reduce(out, group=self.process_group) return out @@ -242,8 +236,9 @@ class TensorParallelEmbedding(torch.nn.Module): ) out = torch.nn.functional.embedding(input, self.weight) if self.reduce and self.process_group.size() > 1: - if SYSTEM == "ipex": - ipex.distributed.all_reduce(out, group=self.process_group) - else: - torch.distributed.all_reduce(out, group=self.process_group) + # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge + # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used + # (which is required for tensor parallel HPUGraph inference) + htorch.core.mark_step() + torch.distributed.all_reduce(out, group=self.process_group) return out diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py index 651b71ec..926fb57a 100644 --- a/backends/gaudi/server/text_generation_server/models/__init__.py +++ b/backends/gaudi/server/text_generation_server/models/__init__.py @@ -1,3 +1,5 @@ +# ruff: noqa: F821 +# the above line disables the `undefined-name` rule for the model type variables import torch import os @@ -8,6 +10,7 @@ from huggingface_hub import hf_hub_download, HfApi from typing import Optional from pathlib import Path from typing import List, Dict +import enum # Needed to properly setup habana_frameworks @@ -35,9 +38,313 @@ from text_generation_server.utils.adapter import ( ) from text_generation_server.adapters.lora import LoraWeights - +from text_generation_server.utils.log import log_master from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi +__all__ = [ + "Model", + "CausalLM", + "Seq2SeqLM", + "get_model_with_lora_adapters", +] +from text_generation_server.models.globals import ATTENTION + +FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models." 
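As context for the tensor_parallel.py hunks above: every torch.distributed collective there is now preceded by htorch.core.mark_step(), per the FIXME about the Habana PT bridge issue seen with PT_HPU_ENABLE_LAZY_COLLECTIVES=true. A minimal standalone sketch of that pattern follows; the function name and process-group handling are illustrative assumptions, not part of this patch.

import torch
import habana_frameworks.torch as htorch

def hpu_all_reduce(out: torch.Tensor, process_group) -> torch.Tensor:
    # Flush queued lazy-mode ops to the device before entering the collective,
    # mirroring the workaround described in the FIXME comments above.
    htorch.core.mark_step()
    torch.distributed.all_reduce(out, group=process_group)
    return out

The same mark_step-before-collective pattern is applied in TensorParallelHead, TensorParallelRowLinear, and TensorParallelEmbedding above.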
+ +FLASH_ATTENTION = False +if ATTENTION == "paged": + FLASH_ATTENTION = True + +try: + from text_generation_server.models.flash_causal_lm import FlashCausalLM + from text_generation_server.models.vlm_causal_lm import VlmCausalLM + from text_generation_server.models.mllama_causal_lm import MllamaCausalLM + from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import ( + FlashDeepseekV2ForCausalLM, + DeepseekV2Config, + ) + from text_generation_server.models.custom_modeling.flash_deepseek_v3_modeling import ( + FlashDeepseekV3ForCausalLM, + DeepseekV3Config, + ) + from text_generation_server.models.custom_modeling.flash_llama_modeling import ( + FlashLlamaForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_cohere_modeling import ( + FlashCohereForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( + FlashGemmaForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_gemma2_modeling import ( + FlashGemma2ForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_dbrx_modeling import ( + FlashDbrxForCausalLM, + DbrxConfig, + ) + from text_generation_server.models.custom_modeling.flash_rw_modeling import ( + RWConfig, + FlashRWForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_neox_modeling import ( + FlashGPTNeoXForCausalLM, + ) + from text_generation_server.models.pali_gemma import ( + PaliGemmaBatch, + ) + from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import ( + PaliGemmaForConditionalGeneration, + ) + from text_generation_server.models.custom_modeling.flash_phi_modeling import ( + FlashPhiForCausalLM, + ) + from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM + from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch + from text_generation_server.models.custom_modeling.mllama import ( + MllamaForConditionalGeneration, + ) + from text_generation_server.models.custom_modeling.llava_next import ( + LlavaNextForConditionalGeneration, + ) + + from text_generation_server.models.custom_modeling.flash_santacoder_modeling import ( + FlashSantacoderForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import ( + FlashStarcoder2ForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( + Qwen2ForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( + FlashMistralForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_mixtral_modeling import ( + FlashMixtralForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_gpt2_modeling import ( + FlashGPT2ForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_gptj_modeling import ( + FlashGPTJForCausalLM, + ) + from text_generation_server.models.custom_modeling.idefics2 import ( + Idefics2ForConditionalGeneration, + ) + from text_generation_server.models.custom_modeling.idefics3 import ( + Idefics3ForConditionalGeneration, + ) + from text_generation_server.models.custom_modeling.qwen2_vl import ( + Qwen2VLForConditionalGeneration, + ) + from text_generation_server.models.custom_modeling.qwen2_5_vl import ( + Qwen2_5VLForConditionalGeneration, + Qwen2_5_VLConfig, + Qwen2_5_VLProcessor, + ) + from text_generation_server.layers.attention import SUPPORTS_WINDOWING +except ImportError as e: + log_master(logger.warning, f"Could not import Flash Attention 
enabled models: {e}") + SUPPORTS_WINDOWING = False + FLASH_ATTENTION = False + +if FLASH_ATTENTION: + __all__.append(FlashCausalLM) + __all__.append(IdeficsCausalLM) + + +class ModelType(enum.Enum): + DEEPSEEK_V2 = { + "type": "deepseek_v2", + "name": "Deepseek V2", + "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2", + } + DEEPSEEK_V3 = { + "type": "deepseek_v3", + "name": "Deepseek V3", + "url": "https://huggingface.co/deepseek-ai/DeepSeek-V3", + } + IDEFICS2 = { + "type": "idefics2", + "name": "Idefics 2", + "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b", + "multimodal": True, + } + IDEFICS3 = { + "type": "idefics3", + "name": "Idefics 3", + "url": "https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3", + "multimodal": True, + } + LLAVA_NEXT = { + "type": "llava_next", + "name": "Llava Next (1.6)", + "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf", + "multimodal": True, + } + LLAMA = { + "type": "llama", + "name": "Llama", + "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f", + } + PHI3 = { + "type": "phi3", + "name": "Phi 3", + "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct", + } + GRANITE = { + "type": "granite", + "name": "Granite", + "url": "https://huggingface.co/ibm-granite/granite-3.0-8b-instruct", + } + GEMMA = { + "type": "gemma", + "name": "Gemma", + "url": "https://huggingface.co/google/gemma-7b", + } + PALIGEMMA = { + "type": "paligemma", + "name": "PaliGemma", + "url": "https://huggingface.co/google/paligemma-3b-pt-224", + } + GEMMA2 = { + "type": "gemma2", + "name": "Gemma2", + "url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315", + } + COHERE = { + "type": "cohere", + "name": "Cohere", + "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus", + } + DBRX = { + "type": "dbrx", + "name": "Dbrx", + "url": "https://huggingface.co/databricks/dbrx-instruct", + } + MAMBA = { + "type": "mamba", + "name": "Mamba", + "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj", + } + MISTRAL = { + "type": "mistral", + "name": "Mistral", + "url": "https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407", + } + MIXTRAL = { + "type": "mixtral", + "name": "Mixtral", + "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1", + } + GPT_BIGCODE = { + "type": "gpt_bigcode", + "name": "Gpt Bigcode", + "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder", + } + PHI = { + "type": "phi", + "name": "Phi", + "url": "https://huggingface.co/microsoft/phi-1_5", + } + PHI_MOE = { + "type": "phimoe", + "name": "PhiMoe", + "url": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct", + } + BAICHUAN = { + "type": "baichuan", + "name": "Baichuan", + "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat", + } + FALCON = { + "type": "falcon", + "name": "Falcon", + "url": "https://huggingface.co/tiiuae/falcon-7b-instruct", + } + STARCODER2 = { + "type": "starcoder2", + "name": "StarCoder 2", + "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1", + } + QWEN2 = { + "type": "qwen2", + "name": "Qwen 2", + "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f", + } + QWEN2_VL = { + "type": "qwen2_vl", + "name": "Qwen 2 VL", + "url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d", + } + QWEN2_5_VL = { + "type": "qwen2_5_vl", + "name": "Qwen 2.5 VL", + "url": "https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e", + } + OPT = { + "type": "opt", + 
"name": "Opt", + "url": "https://huggingface.co/facebook/opt-6.7b", + } + T5 = { + "type": "t5", + "name": "T5", + "url": "https://huggingface.co/google/flan-t5-xxl", + } + GALACTICA = { + "type": "galactica", + "name": "Galactica", + "url": "https://huggingface.co/facebook/galactica-120b", + } + SANTACODER = { + "type": "santacoder", + "name": "SantaCoder", + "url": "https://huggingface.co/bigcode/santacoder", + } + BLOOM = { + "type": "bloom", + "name": "Bloom", + "url": "https://huggingface.co/bigscience/bloom-560m", + } + MPT = { + "type": "mpt", + "name": "Mpt", + "url": "https://huggingface.co/mosaicml/mpt-7b-instruct", + } + GPT2 = { + "type": "gpt2", + "name": "Gpt2", + "url": "https://huggingface.co/openai-community/gpt2", + } + GPT_NEOX = { + "type": "gpt_neox", + "name": "Gpt Neox", + "url": "https://huggingface.co/EleutherAI/gpt-neox-20b", + } + GPTJ = { + "type": "gptj", + "name": "Gptj", + "url": "https://huggingface.co/EleutherAI/gpt-j-6b", + } + IDEFICS = { + "type": "idefics", + "name": "Idefics", + "url": "https://huggingface.co/HuggingFaceM4/idefics-9b", + "multimodal": True, + } + MLLAMA = { + "type": "mllama", + "name": "Mllama", + "url": "https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct", + "multimodal": True, + } + + +__GLOBALS = locals() +for data in ModelType: + __GLOBALS[data.name] = data.value["type"] # Disable gradients torch.set_grad_enabled(False) @@ -54,7 +361,7 @@ def get_model( trust_remote_code: bool, max_input_tokens: int, ) -> Model: - adapt_transformers_to_gaudi() + global FLASH_ATTENTION if speculate is not None: set_speculate(speculate) @@ -178,7 +485,389 @@ def get_model( if model_type == "gpt_bigcode": return StarCoder(model_id=model_id, revision=revision, dtype=dtype) + kv_cache_dtype = dtype + if FLASH_ATTENTION: + if model_type == DEEPSEEK_V2: + head_size = max( + config_dict.get("qk_nope_dim", 128) + + config_dict.get("qk_rope_dim", 64), + config_dict.get("v_head_dim", 128), + ) + return FlashCausalLM( + model_id=model_id, + model_class=FlashDeepseekV2ForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + default_dtype=torch.bfloat16, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=DeepseekV2Config, + head_size=head_size, + ) + elif model_type == DEEPSEEK_V3: + head_size = max( + config_dict.get("qk_nope_dim", 128) + + config_dict.get("qk_rope_dim", 64), + config_dict.get("v_head_dim", 128), + ) + return FlashCausalLM( + model_id=model_id, + model_class=FlashDeepseekV3ForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + default_dtype=torch.bfloat16, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=DeepseekV3Config, + head_size=head_size, + ) + + elif ( + model_type == GPT_BIGCODE + or model_type == GPT2 + and model_id.startswith("bigcode/") + ): + return FlashCausalLM( + model_id=model_id, + model_class=FlashSantacoderForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + aliases={"transformer.wte.weight": ["lm_head.weight"]}, + num_kv_heads=1, + ) + elif model_type == GPT2: + return FlashCausalLM( + model_id=model_id, + model_class=FlashGPT2ForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + 
kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == GPTJ: + return FlashCausalLM( + model_id=model_id, + model_class=FlashGPTJForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == GPT_NEOX: + from text_generation_server.models.custom_modeling.flash_neox_modeling import ( + GPTNeoXConfig, + ) + + return FlashCausalLM( + model_id=model_id, + model_class=FlashGPTNeoXForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=GPTNeoXConfig, + ) + elif model_type == PHI: + return FlashCausalLM( + model_id=model_id, + model_class=FlashPhiForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == PHI_MOE: + return FlashCausalLM( + model_id=model_id, + model_class=FlashLlamaForCausalLM, + config_class=PhiMoEConfig, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == LLAMA or model_type == PHI3 or model_type == GRANITE: + return FlashCausalLM( + model_id=model_id, + model_class=FlashLlamaForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == BAICHUAN: + return FlashCausalLM( + model_id=model_id, + model_class=FlashLlamaForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == GEMMA: + return FlashCausalLM( + model_id=model_id, + model_class=FlashGemmaForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + # Works better for these models + default_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == GEMMA2: + return FlashCausalLM( + model_id=model_id, + model_class=FlashGemma2ForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + # Works better for these models + default_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == COHERE: + return FlashCausalLM( + model_id=model_id, + model_class=FlashCohereForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == DBRX: + return FlashCausalLM( + model_id=model_id, + model_class=FlashDbrxForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + # Dbrx works better in bfloat16. 
+ default_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=DbrxConfig, + ) + elif ( + model_type in ["RefinedWeb", "RefinedWebModel", FALCON] + and not sharded + and not config_dict.get("alibi", False) + ): + return FlashCausalLM( + model_id=model_id, + model_class=FlashRWForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + aliases={ + "lm_head.weight": ["transformer.word_embeddings.weight"], + "transformer.word_embeddings.weight": ["lm_head.weight"], + }, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=RWConfig, + ) + elif model_type == MISTRAL: + return FlashCausalLM( + model_id=model_id, + model_class=FlashMistralForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == MIXTRAL: + return FlashCausalLM( + model_id=model_id, + model_class=FlashMixtralForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == STARCODER2: + return FlashCausalLM( + model_id=model_id, + model_class=FlashStarcoder2ForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == QWEN2: + return FlashCausalLM( + model_id=model_id, + model_class=Qwen2ForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == IDEFICS: + return IdeficsCausalLM( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + elif model_type == QWEN2_VL: + return VlmCausalLM( + model_id=model_id, + model_class=Qwen2VLForConditionalGeneration, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + default_dtype=torch.bfloat16, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == QWEN2_5_VL: + return VlmCausalLM( + model_id=model_id, + model_class=Qwen2_5VLForConditionalGeneration, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + default_dtype=torch.bfloat16, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=Qwen2_5_VLConfig, + processor_class=Qwen2_5_VLProcessor, + ) + elif model_type == MLLAMA: + return MllamaCausalLM( + model_id=model_id, + model_class=MllamaForConditionalGeneration, + batch_class=MllamaCausalLMBatch, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + default_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == IDEFICS2: + return VlmCausalLM( + model_id=model_id, + model_class=Idefics2ForConditionalGeneration, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + 
lora_adapter_ids=lora_adapter_ids, + # XXX: Extremely important to cap resolution in order to limit + # VRAM usage. + processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}}, + ) + elif model_type == IDEFICS3: + return VlmCausalLM( + model_id=model_id, + model_class=Idefics3ForConditionalGeneration, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + default_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + # XXX: Extremely important to cap resolution in order to limit + # VRAM usage. + processor_kwargs={"size": {"longest_edge": 1456}}, + ) + elif model_type == PALIGEMMA: + return VlmCausalLM( + model_id=model_id, + model_class=PaliGemmaForConditionalGeneration, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + # Works better for these models + default_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + batch_class=PaliGemmaBatch, + ) + elif model_type == LLAVA_NEXT: + return VlmCausalLM( + model_class=LlavaNextForConditionalGeneration, + model_id=model_id, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + ) + adapt_transformers_to_gaudi() if model_type == "bloom": return BLOOM( model_id=model_id, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/bloom_modeling.py index e2719fad..84835ab8 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/bloom_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/bloom_modeling.py @@ -377,7 +377,7 @@ class BloomAttention(nn.Module): past_value.view(-1, *past_value.shape[-2:]), ) - if CUSTOM_KERNELS_ENABLED: + if CUSTOM_KERNELS_ENABLED and attention_mask.shape[-1] < 4096: assert self.training is False, "Only foward pass was implemented" assert ( attention_mask.shape[-1] < 4096 @@ -580,7 +580,7 @@ class BloomPreTrainedModel(PreTrainedModel): @staticmethod def _convert_to_bloom_cache( - past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]] + past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]: """ Converts the cache to the format expected by Bloom, i.e. 
to tuple(tuple([batch_size * num_heads, ...])) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 30656038..44df7964 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -28,10 +28,9 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, ) -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -39,7 +38,6 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.layernorm import ( FastLayerNorm, ) @@ -47,11 +45,10 @@ from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) from text_generation_server.utils.weights import UnquantizedWeight - -if SYSTEM == "cuda": - import dropout_layer_norm -else: - dropout_layer_norm = None +from habana_frameworks.torch.hpex.kernels import ( + RotaryPosEmbeddingMode, + apply_rotary_pos_emb, +) class CohereRotary(PositionRotaryEmbedding): @@ -63,38 +60,25 @@ class CohereRotary(PositionRotaryEmbedding): sin: torch.Tensor, ): # Such controlflows may add some overhead. - if SYSTEM == "cuda": - import rotary_emb + num_tokens = query.shape[0] + head_size = query.shape[-1] + rope_mode = RotaryPosEmbeddingMode.PAIRWISE + sin = torch.repeat_interleave(sin, 2, dim=-1) + cos = torch.repeat_interleave(cos, 2, dim=-1) + rotary_dim = cos.shape[-1] + query_shape = query.shape + query = query.view(num_tokens, -1, head_size) + query_rot = query[..., :rotary_dim] + query_pass = query[..., rotary_dim:] + query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode) + query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)) - q1 = query[..., ::2] - q2 = query[..., 1::2] - - rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) - - k1 = key[..., ::2] - k2 = key[..., 1::2] - - rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) - elif SYSTEM == "rocm": - from vllm._C import ops - - # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. - # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773 - - head_size = query.shape[-1] - - # Inplace operation, updating query and key. - ops.rotary_embedding(query, key, head_size, cos, sin, False) - elif SYSTEM == "ipex": - import intel_extension_for_pytorch as ipex - - ipex.llm.functional.rotary_embedding( - query, key, sin, cos, query.size(-1), False - ) - else: - raise ValueError( - "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." 
- ) + key_shape = key.shape + key = key.view(num_tokens, -1, head_size) + key_rot = key[..., :rotary_dim] + key_pass = key[..., rotary_dim:] + key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode) + key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)) class CohereLayerNorm(nn.Module): @@ -107,49 +91,18 @@ class CohereLayerNorm(nn.Module): self.eps = eps def forward(self, hidden_states): - if hidden_states.shape[-1] > 8192 or SYSTEM != "cuda": - hidden_states = hidden_states.reshape( - -1, self.weight.shape[0], self.weight.shape[1] - ) - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - mean = hidden_states.mean(-1, keepdim=True) - hidden_states_minus_mean = hidden_states - mean - variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps) - hidden_states = self.weight.to(torch.float32) * hidden_states - hidden_states = hidden_states.view(-1, self.weight.shape[1]) - return hidden_states.to(input_dtype) - - ( - hidden_states, - *rest, - ) = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - None, - self.ones, - None, - None, - None, - None, - None, - 0.0, - self.eps, - 1.0, - 0, - None, - False, - False, - ) - - # Required to apply one weight matrix per head - hidden_states = hidden_states.view( + hidden_states = hidden_states.reshape( -1, self.weight.shape[0], self.weight.shape[1] ) - hidden_states = self.weight * hidden_states + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + mean = hidden_states.mean(-1, keepdim=True) + hidden_states_minus_mean = hidden_states - mean + variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps) + hidden_states = self.weight.to(torch.float32) * hidden_states hidden_states = hidden_states.view(-1, self.weight.shape[1]) - - return hidden_states + return hidden_states.to(input_dtype) def load_attention(config, prefix, weights): @@ -229,6 +182,7 @@ class FlashCohereAttention(torch.nn.Module): ) self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.use_qk_norm = config.use_qk_norm if self.use_qk_norm: @@ -291,30 +245,37 @@ class FlashCohereAttention(torch.nn.Module): self.rotary_emb(query, key, cos, sin) - reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=key, + value=value, + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else key, - kv_cache[1] if PREFILL_IN_KV_CACHE else value, - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=key, + value=value, + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.o_proj( diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 1137a453..ba86f579 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ 
b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -20,17 +20,13 @@ from torch import nn from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple, Any -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.layers.attention.kv_cache import get_kv_scales -if SYSTEM != "ipex": - from vllm.model_executor.layers.fused_moe import fused_moe from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, - PREFILL_IN_KV_CACHE, ) from text_generation_server.layers import ( FastLinear, @@ -48,6 +44,9 @@ from text_generation_server.layers.layernorm import ( ) +moe_kernels = None + + class DbrxAttentionConfig(PretrainedConfig): def __init__( self, @@ -290,6 +289,7 @@ class DbrxAttention(torch.nn.Module): ) self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = TensorParallelRowLinear.load( config, @@ -330,30 +330,37 @@ class DbrxAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -485,7 +492,8 @@ class BlockSparseMoE(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits = self.gate(x) - out = fused_moe( + + out = moe_kernels.fused_moe( x, self.wv1, self.w2, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 88c2cf80..e30510b4 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -33,21 +33,13 @@ from text_generation_server.layers.attention import ( Seqlen, attention, paged_attention, - reshape_and_cache, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.weights import Weights -if SYSTEM == "rocm": - try: - from vllm import _custom_C - except Exception as e: - raise ImportError(f"Could not 
load `vllm._custom_C`. Full error: {e}") - class DeepseekV2Config(PretrainedConfig): def __init__( @@ -232,6 +224,8 @@ class DeepseekV2Attention(torch.nn.Module): ), ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") + self.kv_a_layernorm = FastRMSNorm.load( prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps ) @@ -260,7 +254,7 @@ class DeepseekV2Attention(torch.nn.Module): cos: torch.Tensor, sin: torch.Tensor, cu_seqlen_prefill: torch.Tensor, - kv_cache: Tuple[torch.Tensor, torch.Tensor], + kv_cache: KVCache, block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, @@ -321,30 +315,37 @@ class DeepseekV2Attention(torch.nn.Module): value, (0, self.head_pad_size - self.value_head_size), value=0 ) - reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=key, + value=value, + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else key, - kv_cache[1] if PREFILL_IN_KV_CACHE else value, - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=key, + value=value, + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) # Remove padding. @@ -387,27 +388,11 @@ class DeepseekV2MLP(nn.Module): self.quantize = config.quantize def forward(self, hidden_states: torch.Tensor, reduce: bool = True): - if ( - SYSTEM == "rocm" - and self.hidden_act == "silu" - and hidden_states.dtype == torch.float16 - and hidden_states.shape[0] == 1 - and not self.quantize - ): - out = torch.empty( - hidden_states.shape[0], - self.intermediate_size, - dtype=hidden_states.dtype, - device="cuda", - ) - _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8) - return self.down_proj(out, reduce=reduce) - else: - gate_up_states = self.gate_up_proj(hidden_states) - gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) - return self.down_proj( - self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce - ) + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce + ) class DeepseekV2MoE(nn.Module): diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py new file mode 100644 index 00000000..452fe3f2 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py @@ -0,0 +1,653 @@ +# coding=utf-8 +# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple, Type + +import torch +import torch.distributed +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig + +from text_generation_server.layers import ( + FastLinear, + SpeculativeHead, + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + get_linear, +) +from text_generation_server.layers.attention import ( + Seqlen, + attention, + paged_attention, +) +from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales +from text_generation_server.layers.layernorm import FastRMSNorm +from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer +from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale +from text_generation_server.utils.weights import Weights + + +class DeepseekV3Config(PretrainedConfig): + def __init__( + self, + vocab_size=102400, + hidden_size=4096, + intermediate_size=11008, + moe_intermediate_size=1407, + num_hidden_layers=30, + num_attention_heads=32, + num_key_value_heads=32, + n_shared_experts=2, + n_routed_experts=160, + ep_size=1, + routed_scaling_factor=1.0, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + topk_method="gready", + n_group=8, + topk_group=3, + num_experts_per_tok=6, + moe_layer_freq=1, + first_k_dense_replace=0, + norm_topk_prob=False, + scoring_func="softmax", + aux_loss_alpha=0.001, + seq_aux=True, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=100000, + eos_token_id=100001, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + self.aux_loss_alpha = aux_loss_alpha + self.seq_aux = seq_aux + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + tie_word_embeddings = 
kwargs.pop("tie_word_embeddings", False) + if tie_word_embeddings: + raise ValueError( + "tie_word_embeddings is not supported for Deepseek V2 models." + ) + + if ep_size != 1: + raise ValueError( + f"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}" + ) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class DeepseekV3Attention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights: Weights, + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.kv_lora_rank = config.kv_lora_rank + self.q_lora_rank = config.q_lora_rank + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_rope_head_dim = config.qk_rope_head_dim + self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim + self.value_head_size = config.v_head_dim + self.head_pad_size = max(self.head_size, self.value_head_size) + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.qk_rope_head_dim, + base=config.rope_theta, + device=weights.device, + ) + + mscale = get_mscale( + self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim + ) + self.softmax_scale = self.head_size**-0.5 * mscale * mscale + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + if self.q_lora_rank is None: + self.q_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.q_proj", + weights=weights, + bias=config.attention_bias, + ) + else: + self.q_a_proj = get_linear( + weight=weights.get_weights(f"{prefix}.q_a_proj"), + bias=( + weights.get_tensor(f"{prefix}.q_a_proj.bias") + if config.attention_bias + else None + ), + ) + self.q_a_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.q_a_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.q_b_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.q_b_proj", + weights=weights, + bias=config.attention_bias, + ) + + self.kv_a_proj_with_mqa = get_linear( + weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"), + bias=( + weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias") + if config.attention_bias + else None + ), + ) + + self.kv_scales = get_kv_scales(weights, f"{prefix}") + + self.kv_a_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps + ) + + self.kv_b_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.kv_b_proj", + weights=weights, + bias=config.attention_bias, + ) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + cu_seqlen_prefill: torch.Tensor, + kv_cache: KVCache, + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + ): + if self.q_lora_rank is None: + query = 
self.q_proj(hidden_states) + else: + query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0]) + query = query.view(-1, self.num_heads, self.head_size) + + _, query_pe = torch.split( + query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, key_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + + key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim) + kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view( + -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size + ) + + key_nope, value = torch.split( + kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1 + ) + + batch_size, heads, head_dim = query_pe.shape + query_pe = ( + query_pe.view(batch_size, heads, head_dim // 2, 2) + .transpose(2, 3) + .reshape(batch_size, heads, head_dim) + ) + batch_size, heads, head_dim = key_pe.shape + key_pe = ( + key_pe.view(batch_size, heads, head_dim // 2, 2) + .transpose(2, 3) + .reshape(batch_size, heads, head_dim) + ) + self.rotary_emb(query_pe, key_pe, cos, sin) + + query[..., self.qk_nope_head_dim :] = query_pe + key = torch.empty_like(query) + key[..., : self.qk_nope_head_dim] = key_nope + key[..., self.qk_nope_head_dim :] = key_pe + + # We need to pad the heads because Flash Attention does not support + # qk and v with different head sizes. + query = torch.nn.functional.pad( + query, (0, self.head_pad_size - self.head_size), value=0 + ) + key = torch.nn.functional.pad( + key, (0, self.head_pad_size - self.head_size), value=0 + ) + value = torch.nn.functional.pad( + value, (0, self.head_pad_size - self.value_head_size), value=0 + ) + + kv_cache.store( + key=key, + value=value, + slots=slots, + kv_scales=self.kv_scales, + ) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attn_output = attention( + query=query, + key=key, + value=value, + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, + ) + # Decode + else: + attn_output = paged_attention( + query, + kv_cache, + self.kv_head_mapping, + self.softmax_scale, + block_tables, + seqlen, + max_s, + kv_scales=self.kv_scales, + ) + + # Remove padding. + attn_output = attn_output[..., : self.value_head_size] + + return self.o_proj( + attn_output.reshape(-1, self.num_heads * self.value_head_size) + ) + + +class DeepseekV3MLP(nn.Module): + def __init__(self, prefix: str, config, weights, intermediate_size: int): + super().__init__() + self.hidden_act = config.hidden_act + if self.hidden_act != "silu": + # Bail out because MoE only supports silu. + raise NotImplementedError( + "Currently only `silu` is supported as an activation for Deepseek V2." + ) + self.act = ACT2FN[self.hidden_act] + + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + + self.intermediate_size = intermediate_size // weights.process_group.size() + + # TODO: This is a hotfix to be removed & properly refactored. 
+ self.quantize = config.quantize + + def forward(self, hidden_states: torch.Tensor, reduce: bool = True): + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce + ) + + +class DeepseekV3MoE(nn.Module): + def __init__( + self, + prefix, + config: DeepseekV3Config, + moe_layer_cls: Type[MoELayer], + weights, + ): + super().__init__() + + self.hidden_dim = config.hidden_size + self.moe_intermediate_size = ( + config.moe_intermediate_size // weights.process_group.size() + ) + self.routed_scaling_factor = config.routed_scaling_factor + + # Gating + self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False) + + if config.topk_method == "noaux_tc": + self.gate.e_score_correction_bias = torch.zeros( + config.n_routed_experts, device=weights.device + ) + else: + self.gate.e_score_correction_bias = None + + self.moe_layer = moe_layer_cls( + prefix=f"{prefix}.experts", + n_experts=config.n_routed_experts, + n_expert_group=config.n_group, + renormalize=config.norm_topk_prob, + topk=config.num_experts_per_tok, + topk_group=config.topk_group, + weights=weights, + scoring_func=config.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias, + ) + assert isinstance(self.moe_layer, MoELayer) + + if config.n_shared_experts is not None: + self.shared_experts = DeepseekV3MLP( + prefix=f"{prefix}.shared_experts", + config=config, + weights=weights, + intermediate_size=config.moe_intermediate_size + * config.n_shared_experts, + ) + else: + self.shared_experts = None + + self.process_group = weights.process_group + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.shared_experts is not None: + shared_output = self.shared_experts(x, reduce=False) + else: + shared_output = None + + router_logits = self.gate(x) + + out = self.moe_layer(x, gating_output=router_logits) + + if shared_output is not None: + out = out + shared_output + + # Reduce sum + if self.process_group.size() > 1: + torch.distributed.all_reduce(out, group=self.process_group) + + return out.view(*x.shape) + + +class DeepseekV3Layer(nn.Module): + def __init__(self, prefix, layer_id, config, weights): + super().__init__() + prefix = f"{prefix}.layers.{layer_id}" + + self.self_attn = DeepseekV3Attention( + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + ) + + if ( + config.n_routed_experts is not None + and layer_id >= config.first_k_dense_replace + and layer_id % config.moe_layer_freq == 0 + ): + moe_layer_cls = ( + SparseMoELayer + if SparseMoELayer.is_supported(weights) + else DenseMoELayer + ) + self.mlp = DeepseekV3MoE(f"{prefix}.mlp", config, moe_layer_cls, weights) + else: + self.mlp = DeepseekV3MLP( + prefix=f"{prefix}.mlp", + config=config, + weights=weights, + intermediate_size=config.intermediate_size, + ) + + self.input_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + cu_seqlen_prefill: torch.Tensor, + kv_cache, + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + ): + normed_hidden_states, residual = self.input_layernorm(hidden_states, residual) + 
+ # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ) + + # faster post attention rms norm + normed_attn_res_output, residual = self.post_attention_layernorm( + attn_output, residual + ) + + output = self.mlp(normed_attn_res_output) + + return output, residual + + +class DeepseekV3Model(torch.nn.Module): + def __init__(self, prefix: str, config, weights: Weights): + super().__init__() + + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens", weights=weights + ) + + self.layers = nn.ModuleList( + [ + DeepseekV3Layer( + prefix, + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = FastRMSNorm.load( + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps + ) + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + seqlen, + max_s, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashDeepseekV3ForCausalLM(torch.nn.Module): + def __init__(self, prefix: str, config, weights: Weights): + super().__init__() + + self.model = DeepseekV3Model( + "model" if not prefix else f"{prefix}.model", config, weights + ) + self.lm_head = SpeculativeHead.load( + config, + prefix="lm_head" if not prefix else f"{prefix}.lm_head", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 7a3d60c9..ebf1b80e 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -28,7 +28,6 @@ from typing import Optional, List, Tuple from 
text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, ) from text_generation_server.layers import ( @@ -40,7 +39,7 @@ from text_generation_server.layers import ( TensorParallelMultiAdapterLinear, TensorParallelAdapterRowLinear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, @@ -208,6 +207,7 @@ class FlashGemma2Attention(torch.nn.Module): ], process_group=weights.process_group, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") o_proj = TensorParallelRowLinear.load( config, @@ -253,19 +253,25 @@ class FlashGemma2Attention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], - seqlen, - block_tables, - self.softmax_scale, - causal=self.causal, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, window_size_left=self.window_size, softcap=self.softcap, ) @@ -273,14 +279,14 @@ class FlashGemma2Attention(torch.nn.Module): else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, seqlen, max_s, softcap=self.softcap, + kv_scales=self.kv_scales, ) return self.o_proj( diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 4c1be6f6..ad3be80e 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -28,9 +28,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, - PREFILL_IN_KV_CACHE, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -39,6 +37,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, @@ -187,6 +186,7 @@ class FlashGemmaAttention(torch.nn.Module): ) self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = TensorParallelRowLinear.load( config, @@ -224,31 +224,38 @@ class FlashGemmaAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache[0] if 
PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, causal=self.causal, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 44c015cf..906b34c1 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -24,11 +24,9 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, ) from text_generation_server.layers import ( @@ -38,6 +36,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales def load_qkv(config, prefix: str, weights, head_size, num_heads): @@ -195,6 +194,7 @@ class FlashGPT2Attention(torch.nn.Module): head_size=self.head_size, num_heads=self.num_heads, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = load_row( config, @@ -224,30 +224,37 @@ class FlashGPT2Attention(torch.nn.Module): key = key.view(-1, self.num_heads, self.head_size) value = value.view(-1, self.num_heads, self.head_size) - reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=key, + value=value, + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else key, - kv_cache[1] if PREFILL_IN_KV_CACHE else value, - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=key, + value=value, + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index aca97004..c23aa07f 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -24,11 +24,10 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.layers.attention.kv_cache 
import get_kv_scales from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, ) from text_generation_server.layers import ( @@ -38,13 +37,16 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) from text_generation_server.layers.layernorm import ( FastLayerNorm, ) +from habana_frameworks.torch.hpex.kernels import ( + RotaryPosEmbeddingMode, + apply_rotary_pos_emb, +) def load_attention(config, prefix: str, weights): @@ -78,39 +80,25 @@ class GPTJRotary(PositionRotaryEmbedding): cos: torch.Tensor, sin: torch.Tensor, ): - # Such controlflows may add some overhead. - if SYSTEM == "cuda": - import rotary_emb + num_tokens = query.shape[0] + head_size = query.shape[-1] + rope_mode = RotaryPosEmbeddingMode.PAIRWISE + sin = torch.repeat_interleave(sin, 2, dim=-1) + cos = torch.repeat_interleave(cos, 2, dim=-1) + rotary_dim = cos.shape[-1] + query_shape = query.shape + query = query.view(num_tokens, -1, head_size) + query_rot = query[..., :rotary_dim] + query_pass = query[..., rotary_dim:] + query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode) + query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)) - q1 = query[..., ::2] - q2 = query[..., 1::2] - - rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) - - k1 = key[..., ::2] - k2 = key[..., 1::2] - - rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) - elif SYSTEM == "rocm": - from vllm._C import ops - - # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. - # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773 - - head_size = query.shape[-1] - - # Inplace operation, updating query and key. - ops.rotary_embedding(query, key, head_size, cos, sin, False) - elif SYSTEM == "ipex": - import intel_extension_for_pytorch as ipex - - ipex.llm.functional.rotary_embedding( - query, key, sin, cos, query.size(-1), False - ) - else: - raise ValueError( - "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." 
- ) + key_shape = key.shape + key = key.view(num_tokens, -1, head_size) + key_rot = key[..., :rotary_dim] + key_pass = key[..., rotary_dim:] + key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode) + key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)) class FlashGPTJAttention(torch.nn.Module): @@ -140,6 +128,7 @@ class FlashGPTJAttention(torch.nn.Module): prefix=prefix, weights=weights, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = load_row( config, @@ -186,30 +175,37 @@ class FlashGPTJAttention(torch.nn.Module): else: self.rotary_emb(query, key, cos, sin) - reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=key, + value=value, + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else key, - kv_cache[1] if PREFILL_IN_KV_CACHE else value, - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=key, + value=value, + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index c9ec70cc..a118ace5 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -27,14 +27,16 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention import ( + KVCache, + get_kv_scales, +) from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -57,15 +59,6 @@ from text_generation_server.utils.weights import ( ) from text_generation_server.layers.fp8 import HybridFP8UnquantLoader -if SYSTEM != "ipex": - pass - -if SYSTEM == "rocm": - try: - from vllm import _custom_C - except Exception as e: - raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") - def load_attention(config, prefix: str, weights, layer_id): # Only defined in granite. 
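
Note: the hunks in this file, like the model files above, apply one recurring migration: per-layer scales from `get_kv_scales(weights, prefix)`, `kv_cache.store(...)` in place of `reshape_and_cache(...)`, keyword-style `attention(...)` on the prefill path, and `paged_attention(..., kv_scales=...)` on the decode path. The following is only a minimal sketch of that calling convention, assuming the Gaudi backend modules this patch introduces; the helper name `attend` is illustrative and not part of the patch.

    from typing import Optional
    import torch
    from text_generation_server.layers.attention import attention, paged_attention, Seqlen
    from text_generation_server.layers.attention.kv_cache import KVCache

    def attend(query, key, value, kv_cache: KVCache, kv_scales, slots,
               cu_seqlen_prefill: Optional[torch.Tensor], block_tables,
               seqlen: Seqlen, softmax_scale: float, kv_head_mapping, max_s: int):
        # The cache object now owns layout and quantization scales,
        # replacing reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots).
        kv_cache.store(key=key, value=value, slots=slots, kv_scales=kv_scales)
        if cu_seqlen_prefill is not None:
            # Prefill: flash attention over the freshly computed K/V.
            return attention(query=query, key=key, value=value, kv_cache=kv_cache,
                             kv_scales=kv_scales, seqlen=seqlen, block_tables=block_tables,
                             softmax_scale=softmax_scale)
        # Decode: paged attention reads K/V back from the cache.
        return paged_attention(query, kv_cache, kv_head_mapping, softmax_scale,
                               block_tables, seqlen, max_s, kv_scales=kv_scales)
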
@@ -157,7 +150,10 @@ class FlashLlamaAttention(torch.nn.Module): device=weights.device, ) - self.softmax_scale = self.head_size**-0.5 + # `config.attention_multiplier` is used in Granite + self.softmax_scale = getattr( + config, "attention_multiplier", self.head_size**-0.5 + ) if self.num_heads % weights.process_group.size() != 0: raise ValueError( @@ -177,11 +173,13 @@ class FlashLlamaAttention(torch.nn.Module): self.query_key_value = load_attention(config, prefix, weights, index) self.index = index + self.kv_scales = get_kv_scales(weights, f"{prefix}") + o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o_proj", weights=weights, - bias=False, + bias=getattr(config, "attention_bias", False), ) self.o_proj = TensorParallelAdapterRowLinear.load( @@ -202,12 +200,13 @@ class FlashLlamaAttention(torch.nn.Module): cos, sin, cu_seqlen_prefill, - kv_cache, + kv_cache: KVCache, block_tables, slots, seqlen, - max_s, adapter_data, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None, ): qkv = self.query_key_value(hidden_states, adapter_data) query, kv = qkv.split( @@ -222,30 +221,42 @@ class FlashLlamaAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv + + kv_cache.store( + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_scales=self.kv_scales, + kv_cache=kv_cache, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj( @@ -363,31 +374,11 @@ class LlamaMLP(nn.Module): self.hidden_size = config.hidden_size def forward(self, hidden_states, adapter_data): - if ( - SYSTEM == "rocm" - and self.hidden_act == "silu" - and hidden_states.dtype == torch.float16 - and hidden_states.shape[0] == 1 - and not self.quantize - and self.hidden_size - != 16384 # TODO: Temporary workaround for `LLMM_Silu` kernel not working with LLama3.1 405B; needs refactoring once fixed. 
- ): - out = torch.empty( - hidden_states.shape[0], - self.intermediate_size, - dtype=hidden_states.dtype, - device="cuda", - ) - _custom_C.LLMM_Silu( - self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8 - ) - return self.down_proj(out, adapter_data) - else: - gate_up_states = self.gate_up_proj(hidden_states, adapter_data) - gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) - return self.down_proj( - self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data - ) + gate_up_states = self.gate_up_proj(hidden_states, adapter_data) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data + ) class FlashLlamaLayer(nn.Module): @@ -408,7 +399,7 @@ class FlashLlamaLayer(nn.Module): if SparseMoELayer.is_supported(weights) else DenseMoELayer ) - self.dense = Phi3MoE( + self.mlp = Phi3MoE( f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights ) # with moe the layernorms are are not rmsnorms and they have bias @@ -423,7 +414,7 @@ class FlashLlamaLayer(nn.Module): eps=config.rms_norm_eps, ) else: - self.dense = LlamaMLP( + self.mlp = LlamaMLP( prefix=f"{prefix}.mlp", config=config, weights=weights, index=index ) self.input_layernorm = FastRMSNorm.load( @@ -437,6 +428,11 @@ class FlashLlamaLayer(nn.Module): eps=config.rms_norm_eps, ) + # Used in Granite + # This could eventually be baked into the weights like we do for the embeddings/lm_head + # but this would mean modifying the lora code + self.residual_multiplier = getattr(config, "residual_multiplier", None) + def forward( self, hidden_states, @@ -448,9 +444,10 @@ class FlashLlamaLayer(nn.Module): block_tables, slots, seqlen, - max_s, adapter_data, cross_attention_states, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -464,16 +461,20 @@ class FlashLlamaLayer(nn.Module): block_tables, slots, seqlen, - max_s, adapter_data, + prefill_cache_indices, + hpu_attention_meta=hpu_attention_meta, ) + if self.residual_multiplier is not None: + attn_output *= self.residual_multiplier - # faster post attention rms norm normed_attn_res_output, attn_res = self.post_attention_layernorm( attn_output, res ) - mlp_output = self.dense(normed_attn_res_output, adapter_data) + mlp_output = self.mlp(normed_attn_res_output, adapter_data) + if self.residual_multiplier is not None: + mlp_output *= self.residual_multiplier return mlp_output, attn_res @@ -493,9 +494,7 @@ class FlashLlamaModel(torch.nn.Module): self.layers.append( FlashLlamaLayer( index=0, - prefix=( - "model.layers.0" if not prefix else f"{prefix}.model.layers.0" - ), + prefix=f"{prefix}.layers.0", config=config, weights=weights, ) @@ -511,11 +510,7 @@ class FlashLlamaModel(torch.nn.Module): self.layers.append( FlashLlamaCrossLayer( index=layer_id, - prefix=( - f"model.layers.{layer_id}" - if not prefix - else f"{prefix}.model.layers.{layer_id}" - ), + prefix=(f"{prefix}.layers.{layer_id}"), config=config, weights=weights, ) @@ -524,11 +519,7 @@ class FlashLlamaModel(torch.nn.Module): self.layers.append( FlashLlamaLayer( index=layer_id, - prefix=( - f"model.layers.{layer_id}" - if not prefix - else f"{prefix}.model.layers.{layer_id}" - ), + prefix=(f"{prefix}.layers.{layer_id}"), config=config, weights=weights, ) @@ -539,18 +530,14 @@ class FlashLlamaModel(torch.nn.Module): self.layers.append( FlashLlamaLayer( 
index=last_layer_id, - prefix=( - f"model.layers.{last_layer_id}" - if not prefix - else f"{prefix}.model.layers.{last_layer_id}" - ), + prefix=(f"{prefix}.layers.{last_layer_id}"), config=config, weights=weights, ) ) self.norm = FastRMSNorm.load( - prefix="model.norm" if not prefix else f"{prefix}.model.norm", + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps, ) @@ -570,19 +557,16 @@ class FlashLlamaModel(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], adapter_data, cross_attention_states=None, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None, ) -> torch.Tensor: hidden_states = inputs_embeds # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -596,9 +580,10 @@ class FlashLlamaModel(torch.nn.Module): block_tables, slots, seqlen, - max_s, adapter_data, cross_attention_states, + prefill_cache_indices, + hpu_attention_meta=hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -607,31 +592,51 @@ class FlashLlamaModel(torch.nn.Module): class FlashLlamaForCausalLM(torch.nn.Module): - def __init__(self, prefix: str, config, weights): + def __init__(self, prefix: str, config, weights, name=None): + if name is None: + name = "model" super().__init__() - with no_fp8(weights): self.embed_tokens = TensorParallelEmbedding( prefix=( - "model.embed_tokens" + f"{name}.embed_tokens" if not prefix - else f"{prefix}.model.embed_tokens" + else f"{prefix}.{name}.embed_tokens" ), weights=weights, ) - self.model = FlashLlamaModel(prefix, config, weights) + self.model = FlashLlamaModel( + prefix=name if not prefix else f"{prefix}.{name}", + config=config, + weights=weights, + ) if config.tie_word_embeddings: suffix = "model.embed_tokens" else: suffix = "lm_head" + # Used in Granite + embedding_multiplier = getattr(config, "embedding_multiplier", None) + if embedding_multiplier is not None: + self.embed_tokens.weight.data *= embedding_multiplier + prefix = suffix if not prefix or name != "model" else f"{prefix}.{suffix}" with no_fp8(weights): self.lm_head = SpeculativeHead.load( config, - prefix=suffix if not prefix else f"{prefix}.{suffix}", - weights=weights, + prefix, + weights, ) + # Used in Granite + self.logits_scaling = getattr(config, "logits_scaling", None) + if self.logits_scaling is not None and self.lm_head.head is not None: + try: + # Scale the weights directly + self.lm_head.head.linear.weight.data /= self.logits_scaling + self.logits_scaled = True + except Exception: + self.logits_scaled = False + def forward( self, input_ids: torch.Tensor, @@ -641,11 +646,11 @@ class FlashLlamaForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, cross_attention_states=None, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( @@ -656,13 +661,19 @@ class FlashLlamaForCausalLM(torch.nn.Module): block_tables, slots, seqlen, - max_s, - true_max_s=max_s, 
prefill_cache_indices=prefill_cache_indices, adapter_data=adapter_data, cross_attention_states=cross_attention_states, + hpu_attention_meta=hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits, speculative_logits = self.lm_head(hidden_states) + + # Used in Granite + if self.logits_scaling is not None and not self.logits_scaled: + logits /= self.logits_scaling + if speculative_logits is not None: + speculative_logits /= self.logits_scaling + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 341a2352..a0116297 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -26,11 +26,10 @@ from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, ) from text_generation_server.layers import ( @@ -41,20 +40,12 @@ from text_generation_server.layers import ( TensorParallelMultiAdapterLinear, TensorParallelAdapterRowLinear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, ) -if SYSTEM == "rocm": - try: - from vllm import _custom_C - except Exception as e: - raise ImportError(f"Could not load `vllm._custom_C`. 
Full error: {e}") - - class MistralConfig(PretrainedConfig): model_type = "mistral" @@ -160,6 +151,7 @@ class MistralAttention(torch.nn.Module): ], process_group=weights.process_group, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") o_proj = TensorParallelRowLinear.load( config, @@ -210,33 +202,38 @@ class MistralAttention(torch.nn.Module): else: kv_to_cache = kv - reshape_and_cache( - kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots + kv_cache.store( + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], + slots=slots, + kv_scales=self.kv_scales, ) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, window_size_left=self.max_past, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.o_proj( @@ -300,29 +297,11 @@ class MistralMLP(nn.Module): self.quantize = config.quantize def forward(self, hidden_states, adapter_data): - if ( - SYSTEM == "rocm" - and self.hidden_act == "silu" - and hidden_states.dtype == torch.float16 - and hidden_states.shape[0] == 1 - and not self.quantize - ): - out = torch.empty( - hidden_states.shape[0], - self.intermediate_size, - dtype=hidden_states.dtype, - device="cuda", - ) - _custom_C.LLMM_Silu( - self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8 - ) - return self.down_proj(out, adapter_data) - else: - gate_up_states = self.gate_up_proj(hidden_states, adapter_data) - gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) - return self.down_proj( - self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data - ) + gate_up_states = self.gate_up_proj(hidden_states, adapter_data) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data + ) class MistralLayer(nn.Module): diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 5836d30a..a45dd1e6 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -37,9 +37,8 @@ from text_generation_server.layers.attention import ( Seqlen, attention, paged_attention, - reshape_and_cache, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.rotary import PositionRotaryEmbedding @@ -215,6 +214,7 @@ class MixtralAttention(torch.nn.Module): ) self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = TensorParallelRowLinear.load( config, @@ -258,33 +258,38 @@ 
class MixtralAttention(torch.nn.Module): else: kv_to_cache = kv - reshape_and_cache( - kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots + kv_cache.store( + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], + slots=slots, + kv_scales=self.kv_scales, ) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, window_size_left=self.max_past, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index ad4e382f..2301b63c 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -29,7 +29,6 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, ) from text_generation_server.layers import ( @@ -39,7 +38,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.layernorm import ( FastLayerNorm, ) @@ -132,6 +131,7 @@ class FlashNeoxAttention(torch.nn.Module): head_size=self.head_size, hidden_size=self.hidden_size, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.dense = load_row( config, prefix=f"{prefix}.dense", weights=weights, bias=True ) @@ -165,30 +165,37 @@ class FlashNeoxAttention(torch.nn.Module): qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1) qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1) - reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=qkv[:, 1], + value=qkv[:, 2], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - qkv[:, 0], - kv_cache[0] if PREFILL_IN_KV_CACHE else qkv[:, 1], - kv_cache[1] if PREFILL_IN_KV_CACHE else qkv[:, 2], - seqlen, - block_tables, - self.softmax_scale, + query=qkv[:, 0], + key=qkv[:, 1], + value=qkv[:, 2], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( qkv[:, 0], - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index 0024f2bb..b1f89eff 
100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -80,6 +80,7 @@ class PaliGemmaForConditionalGeneration(nn.Module): pixel_attention_mask: Optional[torch.BoolTensor] = None, image_sizes: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: inputs_embeds = self.text_model.embed_tokens(input_ids) # TODO This is odd but apparently pali gemma position ids start at 1. diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 2a0dc606..7382a7cb 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -9,7 +9,6 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, ) from text_generation_server.layers import ( @@ -19,7 +18,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.layernorm import ( FastLayerNorm, ) @@ -139,6 +138,7 @@ class FlashPhiAttention(torch.nn.Module): ) self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") # in llama the dense layer is called "o_proj" and has bias=False self.dense = TensorParallelRowLinear.load( @@ -188,29 +188,36 @@ class FlashPhiAttention(torch.nn.Module): ) # Reshape key and value and cache - reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_scales=self.kv_scales, + kv_cache=kv_cache, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 02c788d3..d6569a1d 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -8,7 +8,6 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, ) from text_generation_server.layers import ( @@ -17,7 +16,7 @@ from text_generation_server.layers import ( 
TensorParallelEmbedding, SpeculativeHead, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, @@ -86,6 +85,8 @@ class Qwen2Attention(torch.nn.Module): self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") + self.o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o_proj", @@ -128,33 +129,38 @@ class Qwen2Attention(torch.nn.Module): else: kv_to_cache = kv - reshape_and_cache( - kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots + kv_cache.store( + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], + slots=slots, + kv_scales=self.kv_scales, ) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, window_size_left=self.max_past, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -229,7 +235,7 @@ class Qwen2Layer(nn.Module): max_s, prefill_cache_indices, ): - normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + normed_hidden_states, residual = self.input_layernorm(hidden_states) # Self Attention attn_output = self.self_attn( @@ -244,15 +250,13 @@ class Qwen2Layer(nn.Module): max_s, prefill_cache_indices, ) + hidden_states = attn_output + residual # faster post attention rms norm - normed_attn_res_output, attn_res = self.post_attention_layernorm( - attn_output, res - ) - - mlp_output = self.mlp(normed_attn_res_output) - - return mlp_output, attn_res + hidden_states, residual = self.post_attention_layernorm(hidden_states) + mlp_output = self.mlp(hidden_states) + hidden_states = mlp_output + residual + return hidden_states class Qwen2Model(torch.nn.Module): @@ -264,9 +268,6 @@ class Qwen2Model(torch.nn.Module): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - self.embed_tokens = TensorParallelEmbedding( - prefix=f"{prefix}.embed_tokens", weights=weights - ) self.layers = nn.ModuleList( [ Qwen2Layer( @@ -290,7 +291,7 @@ class Qwen2Model(torch.nn.Module): def forward( self, - input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], @@ -301,17 +302,17 @@ class Qwen2Model(torch.nn.Module): true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + hidden_states = inputs_embeds - # Get rotary cos and sin for this forward - # Avoid to index in each layer cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, true_max_s, hidden_states.dtype + position_ids, + true_max_s, + hidden_states.dtype, ) residual = None for i, layer in enumerate(self.layers): - 
hidden_states, residual = layer( + hidden_states = layer( hidden_states, residual, cos, @@ -325,7 +326,7 @@ class Qwen2Model(torch.nn.Module): prefill_cache_indices, ) - hidden_states, _ = self.norm(hidden_states, residual) + hidden_states, _ = self.norm(hidden_states) return hidden_states @@ -346,6 +347,12 @@ class Qwen2ForCausalLM(torch.nn.Module): prefix=f"{prefix}.{suffix}" if prefix else suffix, weights=weights, ) + + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens" if prefix else "model.embed_tokens", + weights=weights, + ) + self.max_past = config.sliding_window self.max_past_tensor = ( torch.tensor(config.sliding_window, device=weights.device) @@ -376,8 +383,10 @@ class Qwen2ForCausalLM(torch.nn.Module): # kernel requires the true values seqlen = seqlen.clamp(max=self.max_past_tensor) + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = self.model( - input_ids, + inputs_embeds, position_ids, cu_seqlen_prefill, kv_cache, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 6671d85e..fbf1a597 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -12,13 +12,12 @@ from text_generation_server.layers import ( TensorParallelRowLinear, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.layernorm import FastLayerNorm from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.attention import ( attention, paged_attention, - reshape_and_cache, Seqlen, ) @@ -79,6 +78,7 @@ class RWConfig(PretrainedConfig): self.alibi = False self.rotary = True self.rope_theta = rope_theta + self.max_position_embeddings = 2048 self.vocab_size = vocab_size # Backward compatibility with n_embed kwarg @@ -160,6 +160,7 @@ class FlashRWAttention(torch.nn.Module): weights=weights, bias=config.bias, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.dense = load_row( config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias ) @@ -200,30 +201,37 @@ class FlashRWAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) @@ -278,6 +286,7 @@ class FlashRWLargeAttention(torch.nn.Module): weights=weights, bias=config.bias, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") 
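+        # get_kv_scales loads this layer's key/value cache scaling factors
+        # (identity scales unless the checkpoint provides explicit K/V scales,
+        # e.g. for FP8 KV caches); they are passed to kv_cache.store() and
+        # paged_attention() below.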
self.dense = load_row( config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias ) @@ -312,36 +321,37 @@ class FlashRWLargeAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin) - reshape_and_cache( - kv[:, :, 0].contiguous(), - kv[:, :, 1].contiguous(), - kv_cache[0], - kv_cache[1], - slots, + kv_cache.store( + key=kv[:, :, 0].contiguous(), + value=kv[:, :, 1].contiguous(), + slots=slots, + kv_scales=self.kv_scales, ) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(), - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(), - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, :, 0], + value=kv[:, :, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.dense( diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 43eb9687..ed053eb6 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -8,7 +8,6 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, ) from text_generation_server.layers import ( @@ -18,7 +17,7 @@ from text_generation_server.layers import ( TensorParallelEmbedding, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.gptq import GPTQWeightsLoader from text_generation_server.layers.layernorm import ( FastLayerNorm, @@ -259,6 +258,7 @@ class FlashMQAttention(torch.nn.Module): self.c_proj = load_row( config, prefix=f"{prefix}.c_proj", weights=weights, bias=True ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.kv_head_mapping = torch.zeros( self.num_heads, dtype=torch.int32, device=weights.device ) @@ -284,32 +284,37 @@ class FlashMQAttention(torch.nn.Module): query = query.view(-1, self.num_heads, self.head_size) key_value = key_value.view(-1, 2, 1, self.head_size) - reshape_and_cache( - key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots + kv_cache.store( + key=key_value[:, 0], + value=key_value[:, 1], + slots=slots, + kv_scales=self.kv_scales, ) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else key_value[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else key_value[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=key_value[:, 0], + value=key_value[:, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return 
self.c_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 4975cf22..5e090369 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -29,17 +29,18 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, ) from text_generation_server.layers import ( + TensorParallelMultiAdapterLinear, + TensorParallelAdapterRowLinear, TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, SpeculativeHead, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.layernorm import ( FastLayerNorm, FastRMSNorm, @@ -110,17 +111,31 @@ class Starcoder2Config(PretrainedConfig): ) -def load_attention(config, prefix, weights): +def load_attention(config, prefix, weights, layer_id): + prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"] + head_size = config.hidden_size // config.num_attention_heads + sizes = [ + head_size * config.num_attention_heads, + head_size * config.num_key_value_heads, + head_size * config.num_key_value_heads, + ] if config.num_attention_heads != config.num_key_value_heads: - return _load_gqa(config, prefix, weights) + base_layer = _load_gqa(config, prefix, weights) else: - return TensorParallelColumnLinear.load_multi( + base_layer = TensorParallelColumnLinear.load_multi( config, - prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + prefixes=prefixes, dim=0, weights=weights, bias=config.use_bias, ) + return TensorParallelMultiAdapterLinear.load( + base_layer=base_layer, + layer_id=layer_id, + layer_names=prefixes, + sizes=sizes, + process_group=weights.process_group, + ) def _load_gqa(config, prefix: str, weights): @@ -158,6 +173,7 @@ def _load_gqa(config, prefix: str, weights): class Starcoder2Attention(torch.nn.Module): def __init__( self, + index: int, prefix: str, config, weights, @@ -189,14 +205,23 @@ class Starcoder2Attention(torch.nn.Module): config.num_key_value_heads // weights.process_group.size() ) - self.query_key_value = load_attention(config, prefix, weights) + self.query_key_value = load_attention(config, prefix, weights, index) + self.kv_scales = get_kv_scales(weights, f"{prefix}") - self.o_proj = TensorParallelRowLinear.load( + o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o_proj", weights=weights, - bias=config.use_bias, + bias=getattr(config, "use_bias", False), ) + + self.o_proj = TensorParallelAdapterRowLinear.load( + o_proj, + index, + "o_proj", + process_group=weights.process_group, + ) + self.num_groups = self.num_heads // self.num_key_value_heads self.kv_head_mapping = torch.arange( 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device @@ -214,8 +239,9 @@ class Starcoder2Attention(torch.nn.Module): seqlen, max_s, prefill_cache_indices, + adapter_data, ): - qkv = self.query_key_value(hidden_states) + qkv = self.query_key_value(hidden_states, adapter_data) query, kv = qkv.split( [ self.head_size * self.num_heads, @@ -233,40 +259,47 @@ class 
Starcoder2Attention(torch.nn.Module): else: kv_to_cache = kv - reshape_and_cache( - kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots + kv_cache.store( + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], + slots=slots, + kv_scales=self.kv_scales, ) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, window_size_left=self.max_past, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) - return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + return self.o_proj( + attn_output.view(-1, self.num_heads * self.head_size), adapter_data + ) class Starcoder2MLP(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix, config, weights, index): super().__init__() act = config.hidden_act self.act = ( @@ -280,27 +313,42 @@ class Starcoder2MLP(nn.Module): ) ) # Fuse gate and up proj - self.c_fc = TensorParallelColumnLinear.load( + c_fc = TensorParallelColumnLinear.load( config, prefix=f"{prefix}.c_fc", weights=weights, bias=config.use_bias, ) - self.c_proj = TensorParallelRowLinear.load( + c_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.c_proj", weights=weights, bias=config.use_bias, ) - def forward(self, hidden_states): - hidden_states = self.c_fc(hidden_states) + self.c_fc = TensorParallelMultiAdapterLinear.load( + c_fc, + layer_id=index, + layer_names=[f"{prefix}.c_fc"], + sizes=[config.intermediate_size, config.intermediate_size], + process_group=weights.process_group, + ) + + self.c_proj = TensorParallelAdapterRowLinear.load( + c_proj, + index, + "c_proj", + process_group=weights.process_group, + ) + + def forward(self, hidden_states, adapter_data): + hidden_states = self.c_fc(hidden_states, adapter_data) hidden_states = self.act(hidden_states) - return self.c_proj(hidden_states) + return self.c_proj(hidden_states, adapter_data) class Starcoder2GatedMLP(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, index, prefix, config, weights): super().__init__() act = config.hidden_act self.act = ( @@ -314,27 +362,47 @@ class Starcoder2GatedMLP(nn.Module): ) ) # Fuse gate and up proj - self.gate_up_proj = TensorParallelColumnLinear.load_multi( + prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"] + sizes = [ + config.intermediate_size, + config.intermediate_size, + ] + gate_up_proj = TensorParallelColumnLinear.load_multi( config, - prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + prefixes=prefixes, weights=weights, dim=0, bias=config.use_bias, ) - self.down_proj = TensorParallelRowLinear.load( + self.gate_up_proj = TensorParallelMultiAdapterLinear.load( + gate_up_proj, + index, + layer_names=prefixes, + sizes=sizes, + process_group=weights.process_group, + ) + down_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.down_proj", weights=weights, bias=config.use_bias, ) + self.down_proj = TensorParallelAdapterRowLinear.load( + down_proj, + index, + "down_proj", + process_group=weights.process_group, + ) self.intermediate_size 
= ( config.intermediate_size // weights.process_group.size() ) - def forward(self, hidden_states): - gate_up_states = self.gate_up_proj(hidden_states) + def forward(self, hidden_states, adapter_data): + gate_up_states = self.gate_up_proj(hidden_states, adapter_data) gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) - return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data + ) STARCODER2_NORMALIZATION_CLASSES = { @@ -353,11 +421,11 @@ class Starcoder2Layer(nn.Module): super().__init__() prefix = f"model.layers.{layer_id}" self.self_attn = Starcoder2Attention( - prefix=f"{prefix}.self_attn", config=config, weights=weights + prefix=f"{prefix}.self_attn", config=config, weights=weights, index=layer_id ) self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type]( - prefix=f"{prefix}.mlp", config=config, weights=weights + prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id ) self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load( @@ -384,6 +452,7 @@ class Starcoder2Layer(nn.Module): seqlen, max_s, prefill_cache_indices, + adapter_data, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -399,6 +468,7 @@ class Starcoder2Layer(nn.Module): seqlen, max_s, prefill_cache_indices, + adapter_data, ) # faster post attention rms norm @@ -406,7 +476,7 @@ class Starcoder2Layer(nn.Module): attn_output, res ) - mlp_output = self.mlp(normed_attn_res_output) + mlp_output = self.mlp(normed_attn_res_output, adapter_data) return mlp_output, attn_res @@ -453,6 +523,7 @@ class Starcoder2Model(torch.nn.Module): max_s: int, true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], + adapter_data, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -476,6 +547,7 @@ class Starcoder2Model(torch.nn.Module): seqlen, max_s, prefill_cache_indices, + adapter_data, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -547,6 +619,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): max_s, true_max_s, prefill_cache_indices, + adapter_data, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py index a829c374..923123d6 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py @@ -750,6 +750,7 @@ class Idefics2ForConditionalGeneration(nn.Module): # Unused here image_sizes: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, ): inputs_embeds = self.text_model.embed_tokens(input_ids) if pixel_values is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py new file mode 100644 index 00000000..580398cb --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py @@ -0,0 +1,584 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Idefics3 model.""" + +from typing import List, Optional, Tuple + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +from text_generation_server.models.custom_modeling.vlm import ( + load_text_model, +) +from text_generation_server.layers.attention import Seqlen +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask + +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, +) +from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Idefics3VisionEmbeddings(nn.Module): + """ + This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable + resolution. + + The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304) + which allows treating images in their native aspect ratio and without the need to resize them to the same + fixed size. In particular, we start from the original pre-trained SigLIP model + (which uses images of fixed-size square images) and adapt it by training on images of variable resolutions. 
+ """ + + def __init__(self, prefix, config, weights): + super().__init__() + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + self.patch_embedding.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False + ) + self.patch_embedding.bias = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False + ) + + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = TensorParallelEmbedding( + prefix=f"{prefix}.position_embedding", weights=weights + ) + + def forward( + self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor + ) -> torch.Tensor: + batch_size, _, max_im_h, max_im_w = pixel_values.shape + + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + max_nb_patches_h, max_nb_patches_w = ( + max_im_h // self.patch_size, + max_im_w // self.patch_size, + ) + boundaries = torch.arange( + 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side + ) + position_ids = torch.full( + size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0 + ) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + + bucket_coords_h = torch.bucketize( + fractional_coords_h, boundaries, right=True + ) + bucket_coords_w = torch.bucketize( + fractional_coords_w, boundaries, right=True + ) + + pos_ids = ( + bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w + ).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + + position_ids = position_ids.to(self.position_embedding.weight.device) + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +class Idefics3VisionAttention(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_size = self.embed_dim // self.num_heads + if self.head_size * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.scale = self.head_size**-0.5 + self.dropout = config.attention_dropout + + self.num_heads = self.num_heads // weights.process_group.size() + self.embed_dim = self.embed_dim // weights.process_group.size() + + self.qkv = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=True, + ) + self.out_proj = TensorParallelRowLinear.load( + config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True + ) + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, q_len, _ = hidden_states.size() + + qkv = self.qkv(hidden_states) + query_states, key_states, value_states = qkv.split( + [ + self.head_size * self.num_heads, + self.head_size * self.num_heads, + self.head_size * self.num_heads, + ], + dim=2, + ) + + query_states = query_states.view( + batch_size, q_len, self.num_heads, self.head_size + ).transpose(1, 2) + key_states = key_states.view( + batch_size, q_len, self.num_heads, self.head_size + ).transpose(1, 2) + value_states = value_states.view( + batch_size, q_len, self.num_heads, self.head_size + ).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = ( + torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + ) + + if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): + raise ValueError( + f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output + + +class Idefics3VisionMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Idefics3EncoderLayer(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = Idefics3VisionAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.layer_norm1 = 
nn.LayerNorm.load( + prefix=f"{prefix}.layer_norm1", eps=config.layer_norm_eps, weights=weights + ) + self.layer_norm2 = nn.LayerNorm.load( + prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights + ) + self.mlp = Idefics3VisionMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights + ) + + # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Idefics3Encoder(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [ + Idefics3EncoderLayer( + prefix=f"{prefix}.layers.{i}", config=config, weights=weights + ) + for i in range(config.num_hidden_layers) + ] + ) + + # Ignore copy + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + ): + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer( + hidden_states, + attention_mask, + ) + return hidden_states + + +class Idefics3VisionTransformer(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.embeddings = Idefics3VisionEmbeddings( + prefix=f"{prefix}.embeddings", config=config, weights=weights + ) + self.encoder = Idefics3Encoder( + prefix=f"{prefix}.encoder", config=config, weights=weights + ) + self.post_layernorm = nn.LayerNorm.load( + prefix=f"{prefix}.post_layernorm", + weights=weights, + eps=config.layer_norm_eps, + ) + + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + ): + batch_size = pixel_values.size(0) + if patch_attention_mask is None: + patch_size = self.config.patch_size + patch_attention_mask = torch.ones( + ( + batch_size, + pixel_values.size(2) // patch_size, + pixel_values.size(3) // patch_size, + ) + ) + patch_attention_mask = patch_attention_mask.to( + dtype=torch.bool, device=pixel_values.device + ) + + hidden_states = self.embeddings( + pixel_values=pixel_values, patch_attention_mask=patch_attention_mask + ) + + patch_attention_mask = patch_attention_mask.view(batch_size, -1) + # The call to `_upad_input` in `_flash_attention_forward` is expensive + # So when the `patch_attention_mask` is full of 1s (i.e. 
attending to the whole sequence), + # avoiding passing the attention_mask, which is equivalent to attending to the full sequence + if not torch.any(~patch_attention_mask): + patch_attention_mask = None + else: + patch_attention_mask = _prepare_4d_attention_mask( + patch_attention_mask, hidden_states.dtype + ) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=patch_attention_mask, + ) + + last_hidden_state = encoder_outputs + last_hidden_state = self.post_layernorm(last_hidden_state) + + return last_hidden_state + + +class Idefics3SimpleMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + input_size = config.vision_config.hidden_size * (config.scale_factor**2) + output_size = config.text_config.hidden_size + proj = nn.Parameter( + weights.get_tensor(f"{prefix}.modality_projection.proj.weight"), + requires_grad=False, + ).to(weights.dtype) + self.proj = nn.Linear(input_size, output_size, bias=False) + self.proj.weight = proj + + def forward(self, x): + return self.proj(x) + + +class Idefics3Connector(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.modality_projection = Idefics3SimpleMLP(prefix, config, weights) + self.scale_factor = config.scale_factor + + def pixel_shuffle(self, x, scale_factor=2): + bsz, seq, embed_dim = x.size() + height = width = int(seq**0.5) + x = x.view(bsz, height, width, embed_dim) + x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor) + x = x.permute(0, 2, 1, 3) + x = x.reshape( + bsz, + int(width / scale_factor), + int(height / scale_factor), + embed_dim * (scale_factor**2), + ) + x = x.permute(0, 2, 1, 3) + x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2)) + return x + + def forward(self, image_hidden_states): + image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor) + image_hidden_states = self.modality_projection(image_hidden_states) + return image_hidden_states + + +class Idefics3ForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + config.vision_config.quantize = None + config.vision_config.speculator = config.speculator + config.text_config.quantize = config.quantize + config.text_config.speculator = config.speculator + # set tie_word_embeddings to True to load `.embed_tokens.weight` instead of `.lm_head.weight` + # since Idefics3 uses the `embed_tokens` for the final prediction + # config.text_config.tie_word_embeddings = True + + vision_config = config.vision_config + self.text_model = load_text_model( + prefix="model" if not prefix else f"{prefix}.model", + config=config.text_config, + weights=weights, + name="text_model", + ) + self.dtype = weights.dtype + + # The vision and connector models are not quantized. 
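+        # use_loader() temporarily swaps the active weights loader, so the vision
+        # tower below is loaded as plain unquantized tensors even when the text
+        # model itself is quantized.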
+ with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)): + self.vision_model = Idefics3VisionTransformer( + prefix=( + f"{prefix}.model.vision_model" if prefix else "model.vision_model" + ), + config=vision_config, + weights=weights, + ) + + config.quantize = None + self.connector = Idefics3Connector( + prefix=f"{prefix}.model.connector" if prefix else "model.connector", + config=config, + weights=weights, + ) + + self.config = config + self.image_token_id = config.image_token_id + self.pad_token_id = ( + config.pad_token_id if config.pad_token_id is not None else -1 + ) + + def _merge_input_ids_with_image_features( + self, + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + image_features: torch.Tensor, + ): + """In place merges in vision_embeddings with inputs_embeds.""" + # mask = input_ids == self.config.image_token_index + mask = input_ids == self.config.image_token_id + # Let's pray we have enabled enough slots ! + inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + pixel_values: torch.FloatTensor = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + # Unused here + image_sizes: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + image_indices=None, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + if pixel_values is not None: + batch_size, num_images, num_channels, height, width = pixel_values.shape + all_states = [] + all_pixel_values = pixel_values + all_pixel_mask = pixel_attention_mask + for i in range(batch_size): + pixel_values = all_pixel_values.to( + dtype=self.dtype + ) # fp16 compatibility + pixel_values = pixel_values[i : i + 1] + pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. 
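+                # A padded image slot is an all-zero tensor; real_images_inds below
+                # keeps only the images that are not entirely zero.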
+ nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum( + dim=(-1, -2, -3) + ) != nb_values_per_image + pixel_values = pixel_values[real_images_inds].contiguous() + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=( + pixel_values.size(0), + pixel_values.size(2), + pixel_values.size(3), + ), + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask/pP p + pixel_attention_mask = all_pixel_mask[i : i + 1] + pixel_attention_mask = pixel_attention_mask.view( + 1 * num_images, *pixel_attention_mask.shape[2:] + ) + pixel_attention_mask = pixel_attention_mask[ + real_images_inds + ].contiguous() + + patch_size = self.config.vision_config.patch_size + patches_subgrid = pixel_attention_mask.unfold( + dimension=1, size=patch_size, step=patch_size + ) + patches_subgrid = patches_subgrid.unfold( + dimension=2, size=patch_size, step=patch_size + ) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + # Modality projection & resampling + image_hidden_states = self.connector( + image_hidden_states, + ) + + all_states.append(image_hidden_states) + image_hidden_states = torch.stack(all_states, dim=0) + + inputs_embeds = self._merge_input_ids_with_image_features( + input_ids, inputs_embeds, image_hidden_states + ) + + hidden_states = self.text_model.model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + true_max_s=max_s, + prefill_cache_indices=None, + adapter_data=adapter_data, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.text_model.lm_head(hidden_states) + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_modeling.py index fc6becc4..a130dbc1 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_modeling.py @@ -46,15 +46,9 @@ from text_generation_server.layers import ( FastLinear, ) from text_generation_server.layers.rotary import PositionRotaryEmbedding -from text_generation_server.utils.import_utils import SYSTEM from loguru import logger -if SYSTEM == "cuda": - import dropout_layer_norm -elif SYSTEM == "rocm": - from vllm._C import ops -else: - dropout_layer_norm = None +dropout_layer_norm = None @dataclass @@ -351,94 +345,18 @@ class IdeficsRMSNorm(nn.Module): self.variance_epsilon = eps def forward(self, hidden_states, residual=None): - if SYSTEM == "ipex": - import intel_extension_for_pytorch as ipex + from vllm_hpu_extension.kernels import rms_norm - out = ipex.llm.functional.add_rms_norm( - residual, - hidden_states, - self.weight, - None, - self.variance_epsilon, - residual is not None, - ) - return out - elif hidden_states.shape[-1] > 8192: - if residual is not None: - hidden_states += residual - residual = hidden_states - - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = 
hidden_states * torch.rsqrt( - variance + self.variance_epsilon - ) - - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - - return self.weight * hidden_states - elif SYSTEM == "cuda": - # faster post attention rms norm - unwrap = False - if len(hidden_states.shape) > 2: - unwrap = True - shape = hidden_states.shape - hidden_states = hidden_states.reshape(-1, shape[-1]) - - normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - residual, - self.weight, - None, - None, - None, - None, - None, - 0.0, - self.variance_epsilon, - 1.0, - 0, - None, - False, - True, # Activate RMSNorm - ) - if res is None: - res = hidden_states - - if unwrap: - normed_hidden_states = normed_hidden_states.view(*shape) - - return normed_hidden_states - elif SYSTEM == "rocm": - # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not. - if residual is not None: - hidden_states += residual - residual = hidden_states - - unwrap = False - if len(hidden_states.shape) > 2: - unwrap = True - shape = hidden_states.shape - hidden_states = hidden_states.reshape(-1, shape[-1]) - - out = torch.empty_like(hidden_states) - ops.rms_norm( - out, - hidden_states, - self.weight.data, - self.variance_epsilon, - ) - - if unwrap: - out = out.view(*shape) - - return out + orig_shape = hidden_states.shape + if residual is not None: + residual += hidden_states.view(residual.shape) else: - raise ValueError( - "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." - ) + residual = hidden_states + # Note: HPUFusedRMSNorm requires 3D tensors as inputs + if len(orig_shape) == 2: + residual = residual.unsqueeze(0) + x = rms_norm().apply(residual, self.weight, self.variance_epsilon) + return x.view(orig_shape), residual.view(orig_shape) # this was adapted from LlamaMLP diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py index f98dab91..df7366ea 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py @@ -14,17 +14,25 @@ # limitations under the License. """ PyTorch Llava-NeXT model.""" -from typing import List, Optional +from typing import List, Optional, Tuple import torch import torch.utils.checkpoint +from torch import nn -from transformers.models.llava_next.modeling_llava_next import ( - unpad_image, -) -from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration +from transformers.activations import ACT2FN from transformers.image_processing_utils import select_best_resolution +from text_generation_server.layers.attention import Seqlen +from text_generation_server.models.custom_modeling.vlm import ( + load_text_model, + load_vision_model, +) +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelRowLinear, +) + def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): """ @@ -32,7 +40,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): Args: image_size (`tuple`): - The size of the input image in the format (width, height). + The size of the input image in the format (height, width). 
grid_pinpoints (`List`): A list containing possible resolutions. Each item in the list should be a tuple or list of the form `(height, width)`. @@ -40,7 +48,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): The size of each image patch. Returns: - tuple: The shape of the image patch grid in the format (width, height). + tuple: The shape of the image patch grid in the format (height, width). """ if not isinstance(grid_pinpoints, list): raise ValueError("grid_pinpoints should be a list of tuples or lists") @@ -49,13 +57,100 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): return height // patch_size, width // patch_size -class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): +def unpad_image(tensor, original_size): + """ + Unpads a PyTorch tensor of a padded and resized image. + + Args: + tensor (`torch.Tensor`): + The image tensor, assumed to be of shape (num_channels, height, width). + original_size (`tuple`): + The original size of the image (height, width). + + Returns: + `torch.Tensor`: The unpadded image tensor. + """ + original_height, original_width = original_size + current_height, current_width = tensor.shape[1:] + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding : current_height - padding, :] + else: + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding : current_width - padding] + + return unpadded_tensor + + +# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext +class LlavaNextMultiModalProjector(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + + self.linear_1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True + ) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class LlavaNextForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + config.vision_config.quantize = config.quantize + vision_config = config.vision_config + # Instead of selecting in hidden_states[-2]. 
+ # Instead compute only the n -2 + 1 layers and don't pool + if config.vision_feature_layer < 0: + vision_config.num_hidden_layers += config.vision_feature_layer + 1 + else: + vision_config.num_hidden_layers = config.vision_feature_layer + 1 + self.vision_tower = load_vision_model( + prefix="vision_tower" if not prefix else f"{prefix}.vision_tower", + config=config.vision_config, + weights=weights, + ) + + self.multi_modal_projector = LlavaNextMultiModalProjector( + prefix="multi_modal_projector", config=config, weights=weights + ) + + self.image_newline = weights.get_tensor("image_newline") + + self.vocab_size = config.text_config.vocab_size + self.config = config + config.text_config.quantize = config.quantize + config.text_config.speculator = config.speculator + self.text_model = load_text_model( + prefix="language_model" if not prefix else f"{prefix}.language_model", + config=config.text_config, + weights=weights, + ) + self.pad_token_id = ( + config.pad_token_id if config.pad_token_id is not None else -1 + ) def _merge_input_ids_with_image_features( self, + input_ids: torch.Tensor, inputs_embeds: torch.Tensor, image_features: torch.Tensor, - input_ids: torch.Tensor, ): """In place merges in vision_embeddings with inputs_embeds.""" mask = input_ids == self.config.image_token_index @@ -70,273 +165,126 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): def forward( self, - input_ids: torch.LongTensor = None, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, pixel_values: torch.FloatTensor = None, + # Unused for this model + pixel_attention_mask=None, image_sizes: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - vision_feature_layer: Optional[int] = None, - vision_feature_select_strategy: Optional[str] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - token_idx: Optional[torch.Tensor] = None, - use_flash_attention: Optional[bool] = True, - flash_attention_recompute: Optional[bool] = True, + adapter_data: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + if pixel_values is not None and len(pixel_values) > 0: + # num_special_image_tokens = (input_ids == self.config.image_token_index).sum() + # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid" + # 1. Extract the input embeddings - if token_idx is not None: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions + # 2. 
Merge text and images + num_images, num_patches, channels, height, width = pixel_values.shape + pixel_values = pixel_values.view( + num_images * num_patches, channels, height, width ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) + image_features = self.vision_tower(pixel_values) - outputs = self.language_model( - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - token_idx=token_idx, - use_flash_attention=use_flash_attention, - flash_attention_recompute=flash_attention_recompute, - ) + # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer] + # Already done within the clip model + selected_image_feature = image_features.last_hidden_state - logits = outputs[0] - - if not return_dict: - output = (logits,) + outputs[1:] - return output - - return outputs - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - inputs_embeds=None, - pixel_values=None, - image_sizes=None, - attention_mask=None, - **kwargs, - ): - """ - Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635 - The only differences are: - - add new args token_idx - - add the process of merging images into inputs_embeds - """ - token_idx = kwargs.get("token_idx", None) - if token_idx is None: - return super().prepare_inputs_for_generation( - input_ids=input_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - pixel_values=pixel_values, - image_sizes=image_sizes, - attention_mask=attention_mask, - **kwargs, - ) - else: - use_flash_attention = kwargs.get("use_flash_attention", True) - flash_attention_recompute = kwargs.get("flash_attention_recompute", True) - - position_ids = kwargs.get("position_ids", None) - labels = kwargs.get("labels", None) - if ( - past_key_values is None - and pixel_values is not None - and input_ids.shape[1] != 1 - ): - vision_feature_select_strategy = kwargs.get( - "vision_feature_select_strategy", None - ) - vision_feature_layer = kwargs.get("vision_feature_layer", None) - vision_feature_select_strategy = ( - vision_feature_select_strategy - if vision_feature_select_strategy is not None - else self.config.vision_feature_select_strategy - ) - vision_feature_layer = ( - vision_feature_layer - if vision_feature_layer is not None - else self.config.vision_feature_layer - ) - - # 1. Extract the input embeddings - inputs_embeds = self.get_input_embeddings()(input_ids) - # 2. 
Merge text and images - batch_size, num_patches, num_channels, height, width = ( - pixel_values.shape - ) - reshaped_pixel_values = pixel_values.view( - batch_size * num_patches, num_channels, height, width - ) - image_features = self.vision_tower( - reshaped_pixel_values, - output_hidden_states=True, - use_flash_attention=use_flash_attention, - flash_attention_recompute=flash_attention_recompute, - ) - - selected_image_feature = image_features.hidden_states[ - vision_feature_layer - ] - - if vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] - elif vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature - - image_features = self.multi_modal_projector(selected_image_feature) - - # split up image_features for each of the individual images - # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) - # if we assume each image has 5 image features (base image + 4 patches) - split_sizes = [image.shape[0] for image in pixel_values] - image_features = torch.split(image_features, split_sizes, dim=0) - - # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" - height = width = ( - self.config.vision_config.image_size - // self.config.vision_config.patch_size - ) - - new_image_features = [] - for image_idx, image_feature in enumerate(image_features): - if image_feature.shape[0] > 1: - base_image_feature = image_feature[0] - image_feature = image_feature[1:] - - if height * width != base_image_feature.shape[0]: - raise ValueError( - "The number of patches is not consistent with the image size." - ) - - num_patch_height, num_patch_width = get_anyres_image_grid_shape( - image_sizes[image_idx].tolist(), - self.config.image_grid_pinpoints, - self.config.vision_config.image_size, - ) - - image_feature = image_feature.view( - num_patch_height, num_patch_width, height, width, -1 - ) - image_feature = image_feature.permute( - 4, 0, 2, 1, 3 - ).contiguous() - image_feature = image_feature.flatten(1, 2).flatten(2, 3) - image_feature = unpad_image( - image_feature, image_sizes[image_idx] - ) - image_feature = torch.cat( - ( - image_feature, - self.image_newline[:, None, None].expand( - *image_feature.shape[:-1], 1 - ), - ), - dim=-1, - ) - image_feature = image_feature.flatten(1, 2).transpose(0, 1) - image_feature = torch.cat( - (base_image_feature, image_feature), dim=0 - ) - else: - image_feature = image_feature[0] - image_feature = torch.cat( - (image_feature, self.image_newline[None]), dim=0 - ) - new_image_features.append(image_feature) - image_features = torch.stack(new_image_features, dim=0) - inputs_embeds = self._merge_input_ids_with_image_features( - inputs_embeds, image_features, input_ids - ) - self.image_offset = ( - image_features.shape[1] - 1 - ) # image_token has occupied 1 token position. 
- # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of - # generation with cache - elif past_key_values is not None: - seq_len = input_ids.shape[1] - pad_len = seq_len - token_idx.item() - input_ids = torch.index_select(input_ids, 1, token_idx - 1) - # Retrieve the first layer to inspect the logits and mask out the hidden states - # that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where( - first_layer_past_key_value.float().sum(-2) == 0 - ) - - # Get the target length - past_length = first_layer_past_key_value.shape[-1] - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - - attention_mask = extended_attention_mask - attention_mask[:, -pad_len:] = 0 - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - if token_idx is not None: - position_ids = ( - torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - ) - else: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} + if self.config.vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif self.config.vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature else: - model_inputs = {"input_ids": input_ids} + raise RuntimeError( + f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid." 
+ ) - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - "token_idx": token_idx, - "labels": labels, - "use_flash_attention": use_flash_attention, - "flash_attention_recompute": flash_attention_recompute, - } + image_features = self.multi_modal_projector(selected_image_feature) + + # split up image_features for each of the individual images + # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) + # if we assume each image has 5 image features (base image + 4 patches) + split_sizes = [num_patches] * num_images + image_features = torch.split(image_features, split_sizes, dim=0) + + # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" + height = width = ( + self.config.vision_config.image_size + // self.config.vision_config.patch_size ) - return model_inputs + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + + if height * width != base_image_feature.shape[0]: + raise ValueError( + "The number of patches is not consistent with the image size." + ) + + # Dimensions are intentionally swapped to be bug-compatible with + # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59 + num_patch_width, num_patch_height = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + image_feature = image_feature.view( + num_patch_height, num_patch_width, height, width, -1 + ) + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + image_feature = torch.cat( + ( + image_feature, + self.image_newline[:, None, None].expand( + *image_feature.shape[:-1], 1 + ), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat( + (base_image_feature, image_feature), dim=0 + ) + else: + image_feature = image_feature[0] + image_feature = torch.cat( + (image_feature, self.image_newline[None]), dim=0 + ) + new_image_features.append(image_feature) + image_features = torch.stack(new_image_features, dim=0) + + inputs_embeds = self._merge_input_ids_with_image_features( + input_ids, inputs_embeds, image_features + ) + + hidden_states = self.text_model.model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + true_max_s=max_s, + prefill_cache_indices=None, + adapter_data=adapter_data, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.text_model.lm_head(hidden_states) + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/mamba_modeling.py index 293051c2..5a9c0588 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/mamba_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/mamba_modeling.py @@ -196,7 +196,10 @@ class MambaModel(nn.Module): def __init__(self, config, weights): super().__init__() prefix = "backbone" - 
self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights) + try: + self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embeddings", weights) + except RuntimeError: + self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights) self.blocks = nn.ModuleList( [ ResidualBlock(f"{prefix}.layers.{i}", config, weights, layer_id=i) @@ -206,7 +209,10 @@ class MambaModel(nn.Module): self.norm_f = FastRMSNorm.load( f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon ) - self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights) + try: + self.lm_head = SpeculativeHead.load(config, f"{prefix}.embeddings", weights) + except RuntimeError: + self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights) self.config = config def forward( diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/mllama.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/mllama.py index 73536bd6..e040a542 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/mllama.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/mllama.py @@ -19,7 +19,10 @@ from typing import Optional, Tuple, List import torch import torch.utils.checkpoint from torch import nn -import flash_attn_2_cuda + +from habana_frameworks.torch.hpex.kernels import FusedSDPA +from vllm_hpu_extension.utils import ModuleFusedSDPA + from transformers.activations import ACT2FN import torch.nn.functional as F @@ -488,9 +491,14 @@ class MllamaVisionModel(nn.Module): aspect_ratio_ids: torch.Tensor, attention_mask: torch.Tensor, ) -> torch.Tensor: - batch_size, num_concurrent_media, num_tiles, num_channels, height, width = ( - pixel_values.shape - ) + ( + batch_size, + num_concurrent_media, + num_tiles, + num_channels, + height, + width, + ) = pixel_values.shape pixel_values = pixel_values.reshape( batch_size * num_concurrent_media * num_tiles, num_channels, height, width @@ -698,29 +706,24 @@ class MllamaTextCrossAttention(nn.Module): # logger.info( # f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}" # ) - attn_output = flash_attn_2_cuda.varlen_fwd( + query_states = query_states.unsqueeze(0).transpose(1, 2) + key_states = key_states.unsqueeze(0).transpose(1, 2) + value_states = value_states.unsqueeze(0).transpose(1, 2) + fsdpa_op = ModuleFusedSDPA(FusedSDPA) + attn_output = fsdpa_op( query_states, key_states, value_states, - None, - cu_seqlen_q, - cu_seqlen_k, - None, - None, - None, # block_tables - None, - max_q, - max_k, - 0.0, - self.softmax_scale, - False, - causal, # Causal - -1, # window_size_left, - -1, - 0.0, # softcap - False, - None, - )[0] + attn_mask=None, + dropout_p=0.0, + is_causal=causal, + scale=None, + softmax_mode="None", + recompute_mode=None, + valid_sequence_lengths=None, + ) + attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous() + attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) return attn_output diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/opt_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/opt_modeling.py index bd440321..db73ae84 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/opt_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/opt_modeling.py @@ -12,7 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch OPT model.""" +"""PyTorch OPT model.""" + import random from typing import List, Optional, Tuple, Union @@ -99,7 +100,7 @@ class OPTLearnedPositionalEmbedding(nn.Module): self.offset = 2 self.weight = nn.Parameter( weights.get_tensor( - f"{prefix + '.' if prefix else ''}decoder.embed_positions.weight" + f"{prefix if prefix else ''}decoder.embed_positions.weight" ) ) @@ -317,7 +318,6 @@ class OPTDecoderLayer(nn.Module): super().__init__() self.process_group = weights.process_group self.hidden_size = config.hidden_size - prefix = f"{prefix + '.' if prefix else ''}decoder.layers.{layer_id}" self.self_attn = OPTAttention( config, prefix=f"{prefix}.self_attn", @@ -478,7 +478,12 @@ class OPTDecoder(OPTPreTrainedModel): self.layers = nn.ModuleList( [ - OPTDecoderLayer(layer_id, prefix, config, weights) + OPTDecoderLayer( + layer_id, + prefix=f"{prefix}decoder.layers.{layer_id}", + config=config, + weights=weights, + ) for layer_id in range(config.num_hidden_layers) ] ) @@ -755,6 +760,8 @@ class OPTModel(OPTPreTrainedModel): class OPTForCausalLM(OPTPreTrainedModel): def __init__(self, prefix, config, weights): super().__init__(config) + if not prefix and any(s.startswith("model") for s in weights.routing.keys()): + prefix = "model" self.model = OPTModel(prefix, config, weights) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py new file mode 100644 index 00000000..efd9cccd --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py @@ -0,0 +1,947 @@ +# coding=utf-8 +# Copyright 2025 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
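+# NOTE: the processor and config classes below are copied from upstream
+# transformers (see the "Copied from" links inline); the vision tower runs its
+# attention through HPU FusedSDPA and the language model reuses Qwen2Model from
+# flash_qwen2_modeling.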
+"""PyTorch Qwen2.5 VL model.""" + +from typing import Optional, Tuple, List + +import torch +import torch.utils.checkpoint +from torch import nn + +from habana_frameworks.torch.hpex.kernels import FusedSDPA +from vllm_hpu_extension.utils import ModuleFusedSDPA + + +import numpy as np + +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig + +import torch.nn.functional as F + +from text_generation_server.layers.layernorm import FastRMSNorm +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelRowLinear, + TensorParallelEmbedding, + SpeculativeHead, +) +from text_generation_server.layers.attention import ( + Seqlen, +) +from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( + Qwen2Model, +) + +# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +from typing import Union +from transformers.feature_extraction_utils import BatchFeature +from transformers.image_utils import ImageInput, VideoInput +from transformers.processing_utils import ( + ProcessingKwargs, + ProcessorMixin, + Unpack, + VideosKwargs, +) +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput + + +class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False): + fps: Union[List[float], float] + + +class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False): + videos_kwargs: Qwen2_5_VLVideosProcessorKwargs + _defaults = { + "text_kwargs": { + "padding": False, + }, + "videos_kwargs": {"fps": 2.0}, + } + + +class Qwen2_5_VLProcessor(ProcessorMixin): + r""" + Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor. + [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the + [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information. + Args: + image_processor ([`Qwen2VLImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template"] + + image_processor_class = "AutoImageProcessor" + tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + + def __init__( + self, image_processor=None, tokenizer=None, chat_template=None, **kwargs + ): + self.image_token = ( + "<|image_pad|>" + if not hasattr(tokenizer, "image_token") + else tokenizer.image_token + ) + self.video_token = ( + "<|video_pad|>" + if not hasattr(tokenizer, "video_token") + else tokenizer.video_token + ) + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: ImageInput = None, + text: Union[ + TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput] + ] = None, + videos: VideoInput = None, + **kwargs: Unpack[Qwen2_5_VLProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode + the text. 
To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to + Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch + tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`. + - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. + - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. + - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`. + """ + output_kwargs = self._merge_kwargs( + Qwen2_5_VLProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if images is not None: + image_inputs = self.image_processor( + images=images, videos=None, **output_kwargs["images_kwargs"] + ) + image_grid_thw = image_inputs["image_grid_thw"] + else: + image_inputs = {} + image_grid_thw = None + + if videos is not None: + videos_inputs = self.image_processor( + images=None, videos=videos, **output_kwargs["images_kwargs"] + ) + video_grid_thw = videos_inputs["video_grid_thw"] + + fps = output_kwargs["videos_kwargs"].pop("fps", 2.0) + if isinstance(fps, (int, float)): + second_per_grid_ts = [ + self.image_processor.temporal_patch_size / fps + ] * len(video_grid_thw) + elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw): + second_per_grid_ts = [ + self.image_processor.temporal_patch_size / tmp for tmp in fps + ] + else: + raise ValueError( + f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number." 
+ ) + videos_inputs.update({"second_per_grid_ts": second_per_grid_ts}) + + else: + videos_inputs = {} + video_grid_thw = None + + if not isinstance(text, list): + text = [text] + + if image_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + text[i] = text[i].replace( + self.image_token, + "<|placeholder|>" + * (image_grid_thw[index].prod() // merge_length), + 1, + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + + if video_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.video_token in text[i]: + text[i] = text[i].replace( + self.video_token, + "<|placeholder|>" + * (video_grid_thw[index].prod() // merge_length), + 1, + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.video_token) + + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + def post_process_image_text_to_text(self, generated_outputs): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + + Returns: + `List[str]`: The decoded text. 
+ """ + return self.tokenizer.batch_decode( + generated_outputs, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + names_from_processor = list( + dict.fromkeys(tokenizer_input_names + image_processor_input_names) + ) + return names_from_processor + ["second_per_grid_ts"] + + +# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +class Qwen2_5_VLVisionConfig(PretrainedConfig): + model_type = "qwen2_5_vl" + base_config_key = "vision_config" + + def __init__( + self, + depth=32, + hidden_size=3584, + hidden_act="silu", + intermediate_size=3420, + num_heads=16, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + spatial_patch_size=14, + temporal_patch_size=2, + tokens_per_second=4, + window_size=112, + out_hidden_size=3584, + fullatt_block_indexes=[7, 15, 23, 31], + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_patch_size = spatial_patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.tokens_per_second = tokens_per_second + self.window_size = window_size + self.fullatt_block_indexes = fullatt_block_indexes + self.out_hidden_size = out_hidden_size + + +class Qwen2_5_VLConfig(PretrainedConfig): + + def __init__( + self, + vocab_size=152064, + hidden_size=8192, + intermediate_size=29568, + num_hidden_layers=80, + num_attention_heads=64, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + tie_word_embeddings=False, + rope_theta=1000000.0, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=80, + attention_dropout=0.0, + vision_config=None, + rope_scaling=None, + **kwargs, + ): + if vision_config is not None: + self.vision_config = Qwen2_5_VLVisionConfig(**vision_config) + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + # and change type from 'mrope' to 'default' because `mrope` does defeault RoPE calculations + # one can set it to "linear"/"dynamic" etc. 
to have scaled RoPE + # TODO: @raushan update config in the hub + if self.rope_scaling is not None and "type" in self.rope_scaling: + if self.rope_scaling["type"] == "mrope": + self.rope_scaling["type"] = "default" + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision( + tensor: torch.Tensor, freqs: torch.Tensor +) -> torch.Tensor: + orig_dtype = tensor.dtype + tensor = tensor.float() + cos = freqs.cos() + sin = freqs.sin() + cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + output = (tensor * cos) + (rotate_half(tensor) * sin) + output = output.to(orig_dtype) + return output + + +class Qwen2_5VLAttention(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.embed_dim = config.hidden_size // weights.process_group.size() + self.head_dim = config.hidden_size // config.num_heads + self.num_heads = config.num_heads // weights.process_group.size() + + self.qkv = TensorParallelColumnLinear.load_qkv( + config, + prefix=f"{prefix}.qkv", + weights=weights, + bias=False, + num_heads=self.num_heads, + num_key_value_heads=self.num_heads, + ) + self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0) + + self.proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.proj", + weights=weights, + bias=True, + ) + self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads) + + def forward( + self, + hidden_state: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: int, + ) -> torch.Tensor: + # apply the qkv linear layer to the hidden state + qkv = self.qkv(hidden_state) + query, key, value = qkv.split( + [self.embed_dim, self.embed_dim, self.embed_dim], dim=1 + ) + + # reshape the query, key, and value tensors + _shape = ( + hidden_state.shape[0], + self.num_heads, + self.embed_dim // self.num_heads, + ) + query = query.view(*_shape) + key = key.view(*_shape) + value = value.view(*_shape) + + # apply rotary positional embeddings + query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze( + 0 + ) + key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0) + + # calc maximum sequence length for any batch + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + causal = False + + # execute sdpa + query = query.unsqueeze(0).transpose(1, 2) + key = key.unsqueeze(0).transpose(1, 2) + value = value.unsqueeze(0).transpose(1, 2) + fsdpa_op = ModuleFusedSDPA(FusedSDPA) + attn_output = fsdpa_op( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=causal, + scale=None, + softmax_mode="None", + recompute_mode=None, + valid_sequence_lengths=None, + ) + attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous() + + # reshape output to original dimensions + attn_output = attn_output.reshape(hidden_state.shape[0], -1) + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2_5VLVisionMLP(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.activation_fn = ACT2FN[config.hidden_act] + + self.intermediate_size = ( + 
config.intermediate_size // weights.process_group.size() + ) + + self.up = TensorParallelColumnLinear.load( + prefix=f"{prefix}.up_proj", weights=weights, config=config, bias=True + ) + self.gate = TensorParallelColumnLinear.load( + prefix=f"{prefix}.gate_proj", weights=weights, config=config, bias=True + ) + self.down = TensorParallelRowLinear.load( + prefix=f"{prefix}.down_proj", weights=weights, config=config, bias=True + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + gate_states = self.gate(hidden_states) + up_states = self.up(hidden_states) + activated_states = self.activation_fn(gate_states) * up_states + down_states = self.down(activated_states) + return down_states + + +class Qwen2_5VLVisionBlock(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.attn = Qwen2_5VLAttention( + prefix=f"{prefix}.attn", + config=config, + weights=weights, + ) + self.norm1 = FastRMSNorm.load( + prefix=f"{prefix}.norm1", + weights=weights, + eps=1e-6, + ) + self.norm2 = FastRMSNorm.load( + prefix=f"{prefix}.norm2", + weights=weights, + eps=1e-6, + ) + self.mlp = Qwen2_5VLVisionMLP( + prefix=f"{prefix}.mlp", + config=config, + weights=weights, + ) + + def forward( + self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen + ) -> torch.Tensor: + norm1_out, _ = self.norm1(hidden_states) + attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen) + hidden_states = hidden_states + attn_out + norm2_out, _ = self.norm2(hidden_states) + mlp_out = self.mlp(norm2_out) + hidden_states = hidden_states + mlp_out + return hidden_states + + +class Qwen2_5VLPatchMerger(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.hidden_size = config.hidden_size * (config.spatial_merge_size**2) + self.patch_merger_ln_q = FastRMSNorm.load( + prefix=f"{prefix}.ln_q", + weights=weights, + eps=1e-6, + ) + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True + ) + + def forward(self, hidden_states) -> torch.Tensor: + hidden_states, _ = self.patch_merger_ln_q(hidden_states) + hidden_states = hidden_states.view(-1, self.hidden_size) + hidden_states = self.fc1(hidden_states) + hidden_states = F.gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Qwen2_5VisionModel(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + + self.spatial_merge_size = config.spatial_merge_size + kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size] + self.patch_embedding = nn.Conv3d( + in_channels=config.in_channels, + out_channels=config.hidden_size, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + ) + self.patch_embedding.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False + ) + head_dim = config.hidden_size // config.num_heads + + theta = 10000.0 + dim = head_dim // 2 + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + self.blocks = nn.ModuleList( + [ + Qwen2_5VLVisionBlock( + prefix=f"{prefix}.blocks.{i}", + config=config, + weights=weights, + ) + for i in range(config.depth) + ] + ) + self.merger = Qwen2_5VLPatchMerger( + prefix=f"{prefix}.merger", + config=config, + weights=weights, + ) + # import ipdb; 
ipdb.set_trace() + self.temporal_patch_size = config.temporal_patch_size + self.spatial_patch_size = config.spatial_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.hidden_size + self.window_size = config.window_size + self.patch_size = config.patch_size + self.spatial_merge_unit = config.spatial_merge_size * config.spatial_merge_size + self.fullatt_block_indexes = config.fullatt_block_indexes + + def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: + batch_size, _, hidden_size = hidden_state.shape + class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size) + hidden_state = torch.cat([class_embedding, hidden_state], dim=1) + return hidden_state + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = ( + self.window_size // self.spatial_merge_size // self.patch_size + ) + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.spatial_merge_size, + grid_w // self.spatial_merge_size, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( + grid_t, llm_grid_h, llm_grid_w + ) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = ( + seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + ) + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + + return window_index, cu_window_seqlens + + def forward( + self, + pixel_values: torch.Tensor, + grid_thw: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + + # reshape the input tensor for processing + shape = ( + -1, + self.in_channels, + self.temporal_patch_size, + self.spatial_patch_size, + self.spatial_patch_size, + ) + pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype) + hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim) + # TODO: revisit to see if we can avoid some of these reshapes + + # find the position ids for the input tensor based on the grid_thw + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + 
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + + pos_ids = torch.cat(pos_ids, dim=0) + + max_grid_size = grid_thw[:, 1:].max() + + # apply the positional embeddings to the position ids + seq = torch.arange( + max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) + rotary_pos_emb_full = torch.outer(seq, self.inv_freq) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + seq_len = hidden_states.shape[0] + patch_shape = (seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + og_shape = (seq_len, -1) + + hidden_states = hidden_states.view(patch_shape)[window_index, :, :].view( + og_shape + ) + rotary_pos_emb = rotary_pos_emb.view(patch_shape)[window_index, :, :].view( + og_shape + ) + + rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device) + + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=hidden_states.device, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + # create a cu_seqlens tensor to be used in the attention mask + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum(dim=0, dtype=torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1]) + + # iterately apply the blocks to the hidden states + for layer_num, block in enumerate(self.blocks): + # NOTE: qwen2_5_vl.py has a concept of full attention blocks + # that are applied at specific layers. + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + + hidden_states = block( + hidden_states, cu_seqlens_now, rotary_pos_emb, max_seqlen + ) + + # apply the final patch merger to the hidden states + hidden_states = self.merger(hidden_states) + reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] + return hidden_states + + +class Qwen2_5VLForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + config.vision_config.quantize = None + config.vision_config.speculator = config.speculator + # set rope_scaling.type == "mrope" since AutoConfig.from_pretrained incorrectly + # returns rope_scaling.type == "default" for Qwen2_5-VL model at the moment + if ( + hasattr(config, "rope_scaling") + and config.rope_scaling is not None + and config.rope_scaling.get("type", None) == "default" + ): + config.rope_scaling.update({"rope_type": "mrope"}) + self.hidden_size = config.hidden_size + self.vision_start_token_id = config.vision_start_token_id + self.vision_end_token_id = config.vision_end_token_id + self.image_token_id = config.image_token_id + self.video_token_id = config.video_token_id + self.spatial_merge_size = config.vision_config.spatial_merge_size + self.embed_tokens = TensorParallelEmbedding( + prefix="model.embed_tokens", weights=weights + ) + self.visual = Qwen2_5VisionModel( + prefix="visual", config=config.vision_config, weights=weights + ) + self.text_model = Qwen2Model(prefix=None, config=config, weights=weights) + if config.tie_word_embeddings: + suffix = "model.embed_tokens" + else: + suffix = "lm_head" + + self.lm_head = SpeculativeHead.load( + config, + prefix=suffix if not prefix else f"{prefix}.{suffix}", + weights=weights, + ) + self.device = weights.device + + # based on 
https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391 + # modified to first find segments then initialize position ids for each segment + # Steps: + # locate all vision and text segments + # calculate `vision_segment_lengths` for each vision segment to be use as offset + # calculate `text_segment_lengths` for each text segment to be used as offset + # create position ids for each vision segment based on the image grid + # create position ids for each text segment + # combine all the position ids + # the final segment is the difference between the last vision segment and the end of the input + # combine all the position ids and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3) + def get_position_ids( + self, + input_ids: torch.Tensor, + image_grid_thw: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if image_grid_thw is None: + return ( + torch.arange(input_ids.shape[0], device=input_ids.device) + .unsqueeze(1) + .repeat(1, 3) + ) + + spatial_merge_size = self.spatial_merge_size + vision_start_token_id = self.vision_start_token_id + vision_end_token_id = self.vision_end_token_id + device = input_ids.device + dtype = input_ids.dtype + input_ids_len = input_ids.shape[0] + + vision_starts = torch.where(input_ids == vision_start_token_id)[0] + vision_ends = torch.where(input_ids == vision_end_token_id)[0] + vision_segments = torch.stack((vision_starts, vision_ends), dim=1) + prev_vision_end = torch.cat( + [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]] + ) + text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1 + vision_widths_max = torch.cat( + [ + torch.zeros(1, device=image_grid_thw.device, dtype=dtype), + image_grid_thw[:-1, 2] // spatial_merge_size, + ] + ) + vision_segment_lengths = vision_widths_max + text_lengths_between_vision + vision_segment_lengths = vision_segment_lengths.cumsum(dim=0) + text_segment_lengths = vision_segment_lengths - text_lengths_between_vision + + # create position ids for each vision segment based on the image grid + llm_pos_ids_list = [] + for i, _ in enumerate(vision_segments): + t, h, w = ( + image_grid_thw[i][0], + image_grid_thw[i][1] // spatial_merge_size, + image_grid_thw[i][2] // spatial_merge_size, + ) + t_indices = torch.arange(t, device=device).repeat_interleave(h * w) + h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t) + w_indices = torch.arange(w, device=device).repeat(t * h) + image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0) + + # offset by the position of the last vision segment + im = image_position_ids + vision_segment_lengths[i] + llm_pos_ids_list.append(im) + + # create position ids for each text segment + text_ranges = [ + torch.arange(seq_len, device=device).view(1, -1).expand(3, -1) + + text_segment_lengths[i] + for i, seq_len in enumerate(text_lengths_between_vision) + ] + + full_llm_pos_ids_list = [ + item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist + ] + # import ipdb + + # ipdb.set_trace() + max_s = full_llm_pos_ids_list[-1].max() + 1 + final_text_len = input_ids_len - vision_ends[-1] + if final_text_len > 0: + m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1) + full_llm_pos_ids_list.append(m + max_s) + + position_ids = ( + torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1) + ) + return position_ids + + def forward( + self, + input_ids: torch.Tensor, 
+ position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor], + pixel_values: torch.FloatTensor = None, + image_grid_thw: Optional[torch.LongTensor] = None, + # Unused in this model + video_grid_thw: Optional[torch.LongTensor] = None, + pixel_attention_mask=None, + image_sizes: Optional[torch.LongTensor] = None, + adapter_data: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + image_indices=None, + ): + inputs_embeds = self.embed_tokens(input_ids) + + # apply the visual model to the pixel values if they are provided + if pixel_values is not None and len(pixel_values) > 0: + if pixel_values is not None: + image_embeds = self.visual( + pixel_values, grid_thw=image_grid_thw + ).squeeze(0) + inputs_embeds[input_ids == self.image_token_id] = image_embeds + + hidden_states = self.text_model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + true_max_s=max_s, + prefill_cache_indices=prefill_cache_indices, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py new file mode 100644 index 00000000..b32ab577 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -0,0 +1,522 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
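+# NOTE: unlike the Qwen2.5-VL module above, the Qwen2-VL vision blocks below use
+# LayerNorm (FastLayerNorm) rather than RMSNorm and attend over each full
+# image/video grid (no windowed attention); attention likewise runs through HPU
+# FusedSDPA.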
+"""PyTorch Qwen2 VL model.""" + +from typing import Optional, Tuple, List + +import torch +import torch.utils.checkpoint +from torch import nn + + +from habana_frameworks.torch.hpex.kernels import FusedSDPA +from vllm_hpu_extension.utils import ModuleFusedSDPA + + +import numpy as np + +from transformers.activations import ACT2FN +import torch.nn.functional as F + +from text_generation_server.layers.layernorm import FastLayerNorm, FastRMSNorm +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelRowLinear, + TensorParallelEmbedding, + SpeculativeHead, +) +from text_generation_server.layers.attention import ( + Seqlen, +) +from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( + Qwen2Model, +) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision( + tensor: torch.Tensor, freqs: torch.Tensor +) -> torch.Tensor: + orig_dtype = tensor.dtype + tensor = tensor.float() + cos = freqs.cos() + sin = freqs.sin() + cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + output = (tensor * cos) + (rotate_half(tensor) * sin) + output = output.to(orig_dtype) + return output + + +class Qwen2VLAttention(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.embed_dim = config.embed_dim // weights.process_group.size() + self.head_dim = config.hidden_size // config.num_heads + self.num_heads = config.num_heads // weights.process_group.size() + + self.qkv = TensorParallelColumnLinear.load_qkv( + config, + prefix=f"{prefix}.qkv", + weights=weights, + bias=False, + num_heads=self.num_heads, + num_key_value_heads=self.num_heads, + ) + self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0) + self.proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.proj", + weights=weights, + bias=True, + ) + self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads) + + def forward( + self, + hidden_state: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: int, + ) -> torch.Tensor: + # apply the qkv linear layer to the hidden state + qkv = self.qkv(hidden_state) + query, key, value = qkv.split( + [self.embed_dim, self.embed_dim, self.embed_dim], dim=1 + ) + + # reshape the query, key, and value tensors + _shape = ( + hidden_state.shape[0], + self.num_heads, + self.embed_dim // self.num_heads, + ) + query = query.view(*_shape) + key = key.view(*_shape) + value = value.view(*_shape) + + # apply rotary positional embeddings + query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze( + 0 + ) + key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0) + + # calc maximum sequence length for any batch + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + causal = False + + # execute sdpa + query = query.unsqueeze(0).transpose(1, 2) + key = key.unsqueeze(0).transpose(1, 2) + value = value.unsqueeze(0).transpose(1, 2) + fsdpa_op = ModuleFusedSDPA(FusedSDPA) + attn_output = fsdpa_op( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=causal, + scale=None, + softmax_mode="None", + recompute_mode=None, + valid_sequence_lengths=None, + ) + attn_output = attn_output.transpose(1, 
2).squeeze(0).contiguous() + # reshape output to original dimensions + attn_output = attn_output.reshape(hidden_state.shape[0], -1) + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2VLVisionMLP(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Qwen2VLVisionBlock(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.attn = Qwen2VLAttention( + prefix=f"{prefix}.attn", + config=config, + weights=weights, + ) + self.norm1 = FastLayerNorm.load( + prefix=f"{prefix}.norm1", + weights=weights, + eps=1e-6, + ) + self.norm2 = FastLayerNorm.load( + prefix=f"{prefix}.norm2", + weights=weights, + eps=1e-6, + ) + self.mlp = Qwen2VLVisionMLP( + prefix=f"{prefix}.mlp", + config=config, + weights=weights, + ) + + def forward( + self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen + ) -> torch.Tensor: + norm1_out, residual = self.norm1(hidden_states) + attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen) + hidden_states = attn_out + residual + norm2_out, residual = self.norm2(hidden_states) + hidden_states = hidden_states + self.mlp(norm2_out) + return hidden_states + + +class Qwen2VLPatchMerger(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.hidden_size = config.embed_dim * (config.spatial_merge_size**2) + self.patch_merger_ln_q = FastLayerNorm.load( + prefix=f"{prefix}.ln_q", + weights=weights, + eps=1e-6, + ) + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True + ) + + def forward(self, hidden_states) -> torch.Tensor: + hidden_states, _ = self.patch_merger_ln_q(hidden_states) + hidden_states = hidden_states.view(-1, self.hidden_size) + hidden_states = self.fc1(hidden_states) + hidden_states = F.gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Qwen2VisionModel(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.spatial_merge_size = config.spatial_merge_size + kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size] + self.patch_embedding = nn.Conv3d( + in_channels=config.in_chans, + out_channels=config.embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + ) + self.patch_embedding.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False + ) + head_dim = config.embed_dim // config.num_heads + # TODO: replace with static positional embeddings once implemented + theta = 10000.0 + dim = head_dim // 2 + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + self.blocks = nn.ModuleList( + [ + Qwen2VLVisionBlock( + prefix=f"{prefix}.blocks.{i}", + config=config, + weights=weights, + ) + for i in 
range(config.depth) + ] + ) + self.merger = Qwen2VLPatchMerger( + prefix=f"{prefix}.merger", + config=config, + weights=weights, + ) + + self.temporal_patch_size = config.temporal_patch_size + self.spatial_patch_size = config.spatial_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.embed_dim + + def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: + batch_size, _, hidden_size = hidden_state.shape + class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size) + hidden_state = torch.cat([class_embedding, hidden_state], dim=1) + return hidden_state + + def forward( + self, + pixel_values: torch.Tensor, + grid_thw: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + # reshape the input tensor for processing + shape = ( + -1, + self.in_channels, + self.temporal_patch_size, + self.spatial_patch_size, + self.spatial_patch_size, + ) + pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype) + hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim) + # TODO: revisit to see if we can avoid some of these reshapes + + # find the position ids for the input tensor based on the grid_thw + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + + # apply the positional embeddings to the position ids + seq = torch.arange( + max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) + rotary_pos_emb_full = torch.outer(seq, self.inv_freq) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype) + + # create a cu_seqlens tensor to be used in the attention mask + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum(dim=0, dtype=torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1]) + # iterately apply the blocks to the hidden states + for block in self.blocks: + hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen) + + # apply the final patch merger to the hidden states + hidden_states = self.merger(hidden_states) + return hidden_states + + +class Qwen2VLForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + config.vision_config.quantize = None + config.vision_config.speculator = config.speculator + # set rope_scaling.type == "mrope" since AutoConfig.from_pretrained incorrectly + # returns rope_scaling.type == "default" for Qwen2-VL model at the moment + if ( + hasattr(config, "rope_scaling") + and config.rope_scaling is not None + and config.rope_scaling.get("type", None) == "default" + ): + config.rope_scaling.update({"rope_type": "mrope"}) + self.hidden_size = config.hidden_size + 
self.vision_start_token_id = config.vision_start_token_id + self.vision_end_token_id = config.vision_end_token_id + self.image_token_id = config.image_token_id + self.video_token_id = config.video_token_id + self.spatial_merge_size = config.vision_config.spatial_merge_size + self.embed_tokens = TensorParallelEmbedding( + prefix="model.embed_tokens", weights=weights + ) + self.visual = Qwen2VisionModel( + prefix="visual", config=config.vision_config, weights=weights + ) + self.text_model = Qwen2Model(prefix=None, config=config, weights=weights) + if config.tie_word_embeddings: + suffix = "model.embed_tokens" + else: + suffix = "lm_head" + + self.lm_head = SpeculativeHead.load( + config, + prefix=suffix if not prefix else f"{prefix}.{suffix}", + weights=weights, + ) + self.norm = FastRMSNorm.load( + prefix="model.norm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.device = weights.device + + # based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391 + # modified to first find segments then initialize position ids for each segment + # Steps: + # locate all vision and text segments + # calculate `vision_segment_lengths` for each vision segment to be use as offset + # calculate `text_segment_lengths` for each text segment to be used as offset + # create position ids for each vision segment based on the image grid + # create position ids for each text segment + # combine all the position ids + # the final segment is the difference between the last vision segment and the end of the input + # combine all the position ids and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3) + def get_position_ids( + self, + input_ids: torch.Tensor, + image_grid_thw: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if image_grid_thw is None: + return ( + torch.arange(input_ids.shape[0], device=input_ids.device) + .unsqueeze(1) + .repeat(1, 3) + ) + + spatial_merge_size = self.spatial_merge_size + vision_start_token_id = self.vision_start_token_id + vision_end_token_id = self.vision_end_token_id + device = input_ids.device + dtype = input_ids.dtype + input_ids_len = input_ids.shape[0] + + vision_starts = torch.where(input_ids == vision_start_token_id)[0] + vision_ends = torch.where(input_ids == vision_end_token_id)[0] + vision_segments = torch.stack((vision_starts, vision_ends), dim=1) + prev_vision_end = torch.cat( + [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]] + ) + text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1 + vision_widths_max = torch.cat( + [ + torch.zeros(1, device=image_grid_thw.device, dtype=dtype), + image_grid_thw[:-1, 2] // spatial_merge_size, + ] + ) + vision_segment_lengths = vision_widths_max + text_lengths_between_vision + vision_segment_lengths = vision_segment_lengths.cumsum(dim=0) + text_segment_lengths = vision_segment_lengths - text_lengths_between_vision + + # create position ids for each vision segment based on the image grid + llm_pos_ids_list = [] + for i, _ in enumerate(vision_segments): + t, h, w = ( + image_grid_thw[i][0], + image_grid_thw[i][1] // spatial_merge_size, + image_grid_thw[i][2] // spatial_merge_size, + ) + t_indices = torch.arange(t, device=device).repeat_interleave(h * w) + h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t) + w_indices = torch.arange(w, device=device).repeat(t * h) + image_position_ids = torch.stack([t_indices, h_indices, w_indices], 
dim=0) + + # offset by the position of the last vision segment + im = image_position_ids + vision_segment_lengths[i] + llm_pos_ids_list.append(im) + + # create position ids for each text segment + text_ranges = [ + torch.arange(seq_len, device=device).view(1, -1).expand(3, -1) + + text_segment_lengths[i] + for i, seq_len in enumerate(text_lengths_between_vision) + ] + + full_llm_pos_ids_list = [ + item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist + ] + max_s = full_llm_pos_ids_list[-1].max() + 1 + final_text_len = input_ids_len - vision_ends[-1] + if final_text_len > 0: + m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1) + full_llm_pos_ids_list.append(m + max_s) + + position_ids = ( + torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1) + ) + return position_ids + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor], + pixel_values: torch.FloatTensor = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + pixel_attention_mask=None, + image_sizes: Optional[torch.LongTensor] = None, + adapter_data: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + image_indices=None, + ): + inputs_embeds = self.embed_tokens(input_ids) + + # apply the visual model to the pixel values if they are provided + if pixel_values is not None and len(pixel_values) > 0: + if pixel_values is not None: + image_embeds = self.visual( + pixel_values, grid_thw=image_grid_thw + ).squeeze(0) + inputs_embeds[input_ids == self.image_token_id] = image_embeds + + hidden_states = self.text_model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + true_max_s=max_s, + prefill_cache_indices=prefill_cache_indices, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py index e5c44045..94b8522d 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py @@ -4,7 +4,7 @@ def load_text_model(prefix, config, weights, name=None): FlashLlamaForCausalLM, ) - return FlashLlamaForCausalLM(prefix, config, weights) + return FlashLlamaForCausalLM(prefix, config, weights, name=name) elif config.model_type == "mistral": from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( FlashMistralForCausalLM, @@ -17,6 +17,12 @@ def load_text_model(prefix, config, weights, name=None): ) return FlashGemmaForCausalLM(prefix, config, weights, causal=False) + elif config.model_type == "gemma2": + from text_generation_server.models.custom_modeling.flash_gemma2_modeling import ( + FlashGemma2ForCausalLM, + ) + + return FlashGemma2ForCausalLM(prefix, config, weights) elif config.model_type == "paligemma": from 
text_generation_server.models.custom_modeling.flash_gemma_modeling import ( FlashGemmaForCausalLM, diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index bc9d44a0..49313c83 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -1,4 +1,3 @@ -from contextlib import nullcontext import math import os import time @@ -16,12 +15,19 @@ from transformers import ( AutoTokenizer, GenerationConfig, ) -from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict +from typing import ( + Any, + Iterable, + Optional, + Tuple, + List, + Type, + Dict, + Union, +) from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata -from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from text_generation_server.utils.chunks import concat_text_chunks -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models import Model from text_generation_server.utils.log import log_master from text_generation_server.utils.tokens import batch_top_tokens @@ -39,14 +45,18 @@ from text_generation_server.models.types import ( ) from text_generation_server.pb import generate_pb2 from text_generation_server.models.globals import ( - MEM_POOL, - ATTENTION, BLOCK_SIZE, - CUDA_GRAPHS, + REQUEST_LOGPROBS, TGI_WIGGLE_ROOM, get_adapter_to_index, ) -from text_generation_server.layers.attention import Seqlen +from text_generation_server.layers.attention import ( + KVCache, + Seqlen, + HPUPagedAttentionMetadata, + trim_attn_metadata, + trim_seqlen_metadata, +) from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils.dist import MEMORY_FRACTION from text_generation_server.utils.quantization import get_loader @@ -58,13 +68,19 @@ from text_generation_server.utils.import_utils import ( get_free_memory, ) -tracer = trace.get_tracer(__name__) +import vllm_hpu_extension.environment as environment +import habana_frameworks.torch as htorch +tracer = trace.get_tracer(__name__) # Will be set in init SLIDING_WINDOW: Optional[int] = None +def small_power_of_2(n: int): + return 1 << ((n - 1).bit_length() - 1) + + def set_sliding_window(sliding_window: int): global SLIDING_WINDOW SLIDING_WINDOW = sliding_window @@ -117,25 +133,17 @@ class FlashCausalLMBatch(Batch): requests_idx_mapping: Dict[int, int] # Decoder values - input_ids: torch.Tensor - position_ids: torch.Tensor + # Can be a list for easy filtering + # If `input_ids` is a list, it needs to be materialized to a tensor first + input_ids: Union[torch.Tensor, List[List[int]]] + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + position_ids: Optional[torch.Tensor] speculative_ids: Optional[torch.Tensor] - # Flash Attention values - - # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill - cu_seqlen_prefill: Optional[torch.Tensor] - # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers - # as we only keep SLIDING_WINDOW values instead of the whole tensor - prefill_cache_indices: Optional[torch.Tensor] - - # Paged Attention values - # Set when creating the batch - # CPU tensor of length b indicating the start of each sequence in slots - start_slots: torch.Tensor # tensor 
of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode - slot_indices: torch.Tensor + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + slot_indices: Optional[torch.Tensor] # list of length b of list of length s_i // block_size block_tables: List[List[int]] @@ -143,19 +151,32 @@ class FlashCausalLMBatch(Batch): block_tables_tensor: torch.Tensor # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences slots: torch.Tensor - # size [b], containing the number of blocks that can be retrieved from the cache - prefix_lens: List[int] - prefix_lens_tensor: torch.Tensor + # list of length b + 1 containing the cumulative sequence slot lengths of the sequences in the batch + # used for filtering + cu_slots: torch.Tensor - max_seqlen: int + max_input_length: int + max_current_length: int + + # Whether this batch contains at least one request that is prefilling + prefilling: bool + # Whether each request is prefilling + prefilling_mask: List[bool] # Prefill metadata tensors to efficiently compute logprobs + # tensor of length b + 1 containing the cumulative sequence lengths of the sequences in the batch, only used in prefill + cu_seqlen_prefill: Optional[torch.Tensor] + # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers + # as we only keep SLIDING_WINDOW values instead of the whole tensor + prefill_cache_indices: Optional[torch.Tensor] + # Will be set by `generate_token` and reset after each prefill forward prefill_head_indices: Optional[torch.Tensor] + # Will be set by `generate_token` and reset after each prefill forward prefill_next_token_indices: Optional[torch.tensor] + # Will be set by `generate_token` and reset after each prefill forward prefill_cu_outlens: Optional[List[int]] - - # Prefixes - prefix_ids: List[List[int]] + # Will be set by `generate_token` and reset after each prefill forward + prefill_logprob_tokens: List[Optional[Tokens]] # All tokens all_input_ids: List[List[int]] @@ -163,7 +184,14 @@ class FlashCausalLMBatch(Batch): # Lengths of all generations present in the batch input_lengths: List[int] - input_lengths_tensor: torch.Tensor + # size [b], containing the number of blocks that can be retrieved from the cache + cache_lengths: List[int] + prompt_lengths: List[int] + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + input_lengths_tensor: Optional[torch.Tensor] + cache_lengths_tensor: Optional[torch.Tensor] + prompt_lengths_tensor: torch.Tensor + prefix_offsets: List[Optional[int]] read_offsets: List[Optional[int]] @@ -174,19 +202,27 @@ class FlashCausalLMBatch(Batch): top_n_tokens_tensor: torch.Tensor # Adapter metadata for each request - adapter_meta: AdapterBatchMetadata + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + adapter_meta: Optional[AdapterBatchMetadata] # Number of blocks in this batch num_blocks: int # Maximum number of blocks max_blocks: int + hpu_attn_meta: Optional[HPUPagedAttentionMetadata] + def to_pb(self) -> generate_pb2.CachedBatch: return generate_pb2.CachedBatch( id=self.batch_id, request_ids=[r.id for r in self.requests], size=len(self), max_tokens=self.num_blocks * BLOCK_SIZE, + current_tokens=( + sum([len(i) for i in self.input_ids]) + if isinstance(self.input_ids, list) + else len(self.input_ids) + ), ) @classmethod @@ -218,86 +254,67 @@ class 
FlashCausalLMBatch(Batch): dtype: torch.dtype, device: torch.device, ) -> "FlashCausalLMBatch": - sliding_window = get_sliding_windows() - position_ids = [] - cu_seqlen_prefill = [0] - start_slots = [] - slot_indices = [] - prefill_cache_indices = [] - + cache_lengths = [] input_lengths = [] + prompt_lengths = [] prefix_offsets = [] read_offsets = [] all_input_ids = [] - prefix_ids = [] + all_postfix_ids = [] requests_idx_mapping = {} - - all_prefill_logprobs = True - no_prefill_logprobs = True - prefill_head_indices = [] - prefill_next_token_indices = [] - prefill_cu_outlens = [0] + slots = [] + cu_slots = [0] next_token_chooser_parameters = [] stopping_criterias = [] top_n_tokens = [] - adapter_indices_list = [] - adapter_set = set() - - # Cumulative length - cumulative_length = 0 - cumulative_slot_tokens = 0 - prefill_out_cumulative_length = 0 - num_blocks = 0 - max_seqlen = 0 + max_input_length = 0 + max_current_length = 0 max_length = 0 max_blocks = 0 + cu_blocks = [0] block_tables = [] - slots = [] - prefix_lens = [] + block_tables_ragged = [] # Parse batch for i, (r, tokenized_input) in enumerate( zip(pb.requests, batch_tokenized_inputs) ): + ### XXX: This consumes so much memory on long requests + ### Deactivating it by default seems like the best course. + if not REQUEST_LOGPROBS: + r.prefill_logprobs = False # request id -> idx in list mapping requests_idx_mapping[r.id] = i - orig_input_length = len(tokenized_input) + prompt_length = len(tokenized_input) + prompt_lengths.append(prompt_length) + + cache_length = r.cache_len - prefix_len = r.prefix_len assert ( - prefix_len <= orig_input_length - ), f"Prefix {prefix_len} vs input {orig_input_length}" - if prefix_len == orig_input_length: - assert prefix_len > 0 - prefix_len -= 1 + cache_length <= prompt_length + ), f"Prefix {cache_length} vs input {prompt_length}" + if cache_length == prompt_length: + assert False, "unreachable" - # Commented as it's costly. 
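+        # --- Editor's note: illustrative sketch, not part of the patch. ---
+        # The rewritten `from_pb` in this hunk keeps only the "postfix" of each
+        # request: the token ids not already covered by the KV cache (`r.cache_len`).
+        # The helper below is made up for illustration; only the slicing mirrors the
+        # `postfix_ids = tokenized_input[cache_length:]` line of this hunk.
+        #
+        #     from typing import List
+        #
+        #     def split_postfix(tokenized_input: List[int], cache_length: int) -> List[int]:
+        #         # Tokens [0, cache_length) are assumed to already sit in the KV cache,
+        #         # so only the remaining tokens still need to go through prefill.
+        #         assert cache_length < len(tokenized_input), "at least one token must remain"
+        #         return tokenized_input[cache_length:]
+        #
+        #     # A 6-token prompt with 4 tokens already cached leaves a 2-token postfix;
+        #     # with an empty cache the whole prompt is prefilled.
+        #     assert split_postfix([11, 12, 13, 14, 15, 16], 4) == [15, 16]
+        #     assert split_postfix([11, 12, 13, 14, 15, 16], 0) == [11, 12, 13, 14, 15, 16]
+        # --- end editor's note ---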
- # log_master(logger.debug, "Tokenized input ids {tokenized_input}") - prefix_ids.append(tokenized_input[:prefix_len]) - tokenized_input = tokenized_input[prefix_len:] + # `chunk_len` is an optional field in the protobuf + # It is only set if the model support chunking + # Use all the remaining ids + postfix_ids = tokenized_input[cache_length:] + input_length = len(postfix_ids) - input_length = len(tokenized_input) input_lengths.append(input_length) - prefix_offsets.append(input_length - 5) - read_offsets.append(input_length) + prefix_offsets.append(prompt_length - 5) + read_offsets.append(prompt_length) + all_postfix_ids.append(postfix_ids) all_input_ids.append(tokenized_input) - # Position ids - request_position_ids = torch.arange( - prefix_len, orig_input_length, dtype=torch.int32 - ) - position_ids.append(request_position_ids) - - # Add cumulative lengths of all previous inputs - cu_seqlen_prefill.append(cumulative_length + input_length) - next_token_chooser_parameters.append(r.parameters) stopping_criteria = StoppingCriteria.from_pb( @@ -307,22 +324,13 @@ class FlashCausalLMBatch(Batch): stopping_criterias.append(stopping_criteria) top_n_tokens.append(r.top_n_tokens) - ADAPTER_TO_INDEX = get_adapter_to_index() - adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0) - adapter_indices_list.append(torch.full((input_length,), adapter_index)) - adapter_set.add(adapter_index) - # Paged attention # Remove one as the first token des not have a past speculative_length = get_speculate() speculative_length = 0 if speculative_length is None else speculative_length # Tokens that need to be mapped to blocks. - block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length - - # Tokens that need to be mapped to slots. We don't need slots for the - # cached prefix (if present). 
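+            # --- Editor's note: illustrative sketch, not part of the patch. ---
+            # `block_tokens` in this hunk counts every position that may ever need a KV
+            # slot for a request: the full prompt plus the tokens it may still generate
+            # (and any speculative tokens), minus one, keeping the off-by-one convention
+            # already noted in the surrounding comment (the first token does not have a
+            # past). With paged attention that count is, in effect, what decides how many
+            # fixed-size blocks the request occupies. A toy calculation, assuming a block
+            # size of 16:
+            #
+            #     import math
+            #
+            #     BLOCK_SIZE = 16  # assumed for the example only
+            #
+            #     def blocks_needed(prompt_length: int, max_new_tokens: int, speculative_length: int) -> int:
+            #         block_tokens = prompt_length + max_new_tokens - 1 + speculative_length
+            #         return math.ceil(block_tokens / BLOCK_SIZE)
+            #
+            #     # A 20-token prompt allowed to generate 30 tokens with no speculation
+            #     # needs 20 + 30 - 1 = 49 slots, i.e. 4 blocks of 16.
+            #     assert blocks_needed(20, 30, 0) == 4
+            # --- end editor's note ---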
- slot_tokens = input_length + max_new_tokens - 1 + speculative_length + block_tokens = prompt_length + max_new_tokens - 1 + speculative_length # blocks and slots can be empty (for example in warmup) if not r.blocks: @@ -337,70 +345,30 @@ class FlashCausalLMBatch(Batch): ] else: request_blocks = r.blocks - request_slots = r.slots[ - prefix_len: #: orig_input_length + max_new_tokens + speculative_length - ] + request_slots = r.slots block_tables.append(request_blocks) + block_tables_ragged.extend(request_blocks) + cu_blocks.append(len(block_tables_ragged)) slots.extend(request_slots) - prefix_lens.append(prefix_len) + cu_slots.append(len(slots)) + + cache_lengths.append(cache_length) num_blocks += len(request_blocks) - start_slots.append(cumulative_slot_tokens) - - request_slot_indices = torch.arange( - cumulative_slot_tokens, - cumulative_slot_tokens + input_length, - dtype=torch.int64, - ) - slot_indices.append(request_slot_indices) - - # Create tensor to slice into the kv tensor in prefill - if sliding_window is not None: - request_prefill_cache_indices = torch.arange( - cumulative_length + max(0, input_length - sliding_window), - cumulative_length + input_length, - dtype=torch.int64, - ) - prefill_cache_indices.append(request_prefill_cache_indices) - - all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs - no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs - - if r.prefill_logprobs: - prefill_head_indices.append(request_position_ids + cumulative_length) - prefill_next_token_indices.append( - prefill_out_cumulative_length + input_length - 1 - ) - prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) - prefill_out_cumulative_length += input_length - else: - prefill_head_indices.append( - torch.tensor( - [cumulative_length + input_length - 1], dtype=torch.int32 - ) - ) - prefill_next_token_indices.append(prefill_out_cumulative_length) - prefill_cu_outlens.append(prefill_out_cumulative_length + 1) - prefill_out_cumulative_length += 1 # Update - cumulative_length += input_length - cumulative_slot_tokens += slot_tokens - max_seqlen = max(max_seqlen, input_length) max_blocks = max(max_blocks, len(request_blocks)) + max_input_length = max(max_input_length, input_length) + max_current_length = max(max_current_length, cache_length + input_length) max_length = max( - max_length, input_length + max_new_tokens + speculative_length + max_length, + prompt_length + max_new_tokens + speculative_length, ) - adapter_indices = torch.cat(adapter_indices_list).to( - dtype=torch.int64, device=device - ) - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, dtype, device, tokenizer ) - start_slots = torch.tensor(start_slots, dtype=torch.int64) # Padded all_input_ids_tensor all_input_ids_tensor = np.zeros( @@ -414,103 +382,71 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor, dtype=torch.int64, device=device ) - if len(pb.requests) > 1: - input_ids = np.concatenate(all_input_ids, dtype=np.int64) - position_ids = torch.cat(position_ids) - slot_indices = torch.cat(slot_indices) - if sliding_window is not None: - prefill_cache_indices = torch.cat(prefill_cache_indices) - else: - input_ids = all_input_ids[0] - position_ids = position_ids[0] - slot_indices = slot_indices[0] - if sliding_window is not None: - prefill_cache_indices = prefill_cache_indices[0] - - cu_seqlen_prefill = torch.tensor( - cu_seqlen_prefill, device=device, dtype=torch.int32 - ) - position_ids = position_ids.to(device) - slot_indices = 
slot_indices.to(device) - prefill_cache_indices = ( - prefill_cache_indices.to(device) if sliding_window is not None else None - ) - input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) - input_lengths_tensor = torch.tensor( - input_lengths, dtype=torch.int32, device=device - ) - - adapter_segments, adapter_segment_indices = find_segments(adapter_indices) - adapter_segments = torch.tensor( - adapter_segments, dtype=torch.int32, device=device - ) - - if all_prefill_logprobs: - prefill_head_indices = None - prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 - elif no_prefill_logprobs: - prefill_head_indices = cu_seqlen_prefill[1:] - 1 - prefill_next_token_indices = None - else: - prefill_head_indices = torch.tensor( - torch.cat(prefill_head_indices), dtype=torch.int64, device=device - ) - prefill_next_token_indices = torch.tensor( - prefill_next_token_indices, dtype=torch.int64, device=device - ) top_n_tokens_tensor = torch.tensor( top_n_tokens, device=device, dtype=torch.int64 ) - slots = torch.tensor(slots, dtype=torch.int64, device=device) - - block_tables_tensor = torch.zeros( - (len(block_tables), max_blocks), dtype=torch.int32, device="cpu" + block_tables_ragged = torch.tensor( + block_tables_ragged, device=device, dtype=torch.int32 ) + cu_blocks = torch.tensor(cu_blocks, device=device, dtype=torch.int64) + block_tables_tensor = torch.empty( + (len(block_tables), max_blocks), + device=device, + dtype=torch.int32, + ) + for i, request_blocks in enumerate(block_tables): block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) - block_tables_tensor = block_tables_tensor.to(device) - prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device) + + prompt_lengths_tensor = torch.tensor( + prompt_lengths, dtype=torch.int32, device=device + ) + + slots = torch.tensor(slots, dtype=torch.int64, device=device) + cu_slots = torch.tensor(cu_slots, dtype=torch.int64) return cls( batch_id=pb.id, requests=pb.requests, requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - prefill_cache_indices=prefill_cache_indices, - start_slots=start_slots, - slot_indices=slot_indices, + input_ids=all_postfix_ids, block_tables=block_tables, block_tables_tensor=block_tables_tensor, - slots=slots, - prefix_lens=prefix_lens, - prefix_lens_tensor=prefix_lens_tensor, - max_seqlen=max_seqlen, - prefill_head_indices=prefill_head_indices, - prefill_next_token_indices=prefill_next_token_indices, - prefill_cu_outlens=prefill_cu_outlens, + cache_lengths=cache_lengths, + max_input_length=max_input_length, + max_current_length=max_current_length, + prefilling=True, + prefilling_mask=[True] * len(pb.requests), + prefill_logprob_tokens=[None] * len(pb.requests), input_lengths=input_lengths, - input_lengths_tensor=input_lengths_tensor, + prompt_lengths=prompt_lengths, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, - prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, top_n_tokens_tensor=top_n_tokens_tensor, num_blocks=num_blocks, max_blocks=max_blocks, - adapter_meta=AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_segment_indices, - ), speculative_ids=None, + prompt_lengths_tensor=prompt_lengths_tensor, + # These values will be set by 
`FlashCausalLMBatch.prepare_for_prefill` + position_ids=None, + cu_seqlen_prefill=None, + prefill_cache_indices=None, + slot_indices=None, + slots=slots, + cu_slots=cu_slots, + prefill_head_indices=None, + prefill_next_token_indices=None, + prefill_cu_outlens=None, + cache_lengths_tensor=None, + input_lengths_tensor=None, + adapter_meta=None, + hpu_attn_meta=None, ) @classmethod @@ -533,7 +469,7 @@ class FlashCausalLMBatch(Batch): if len(request_ids) == len(self): return self - device = self.input_ids.device + device = self.block_tables_tensor.device # New values after filtering requests_idx_mapping = {} @@ -548,18 +484,23 @@ class FlashCausalLMBatch(Batch): # Create on CPU to only move to GPU once instead of at every copy slot_indices = torch.empty(len(request_ids), dtype=torch.int64) - max_seqlen = 0 + max_input_length = 0 + max_current_length = 0 requests = [] - start_slots = [] block_tables = [] all_input_ids = [] - prefix_ids = [] + input_ids = [] + prompt_lengths = [] input_lengths = [] - prefix_lens = [] + cache_lengths = [] prefix_offsets = [] read_offsets = [] + cu_slots = [0] + + prefilling_mask = [] + prefill_logprob_tokens = [] stopping_criterias = [] top_n_tokens = [] @@ -567,8 +508,8 @@ class FlashCausalLMBatch(Batch): num_blocks = 0 max_blocks = 0 - # Cumulative length - cumulative_max_length = 0 + max_slots = 0 + cumulative_slot_tokens = 0 for i, request_id in enumerate(request_ids): idx = self.requests_idx_mapping[request_id] @@ -577,16 +518,23 @@ class FlashCausalLMBatch(Batch): requests.append(self.requests[idx]) + # Prefilling + request_prefilling = self.prefilling_mask[idx] + prefilling_mask.append(request_prefilling) + # Get length request_input_length = self.input_lengths[idx] - prefix_len = self.prefix_lens[idx] - max_seqlen = max(max_seqlen, request_input_length) + request_cache_length = self.cache_lengths[idx] + max_input_length = max(max_input_length, request_input_length) + max_current_length = max( + max_current_length, request_cache_length + request_input_length + ) all_input_ids.append(self.all_input_ids[idx]) - prefix_ids.append(self.prefix_ids[idx]) + prompt_lengths.append(self.prompt_lengths[idx]) input_lengths.append(request_input_length) - prefix_lens.append(prefix_len) + cache_lengths.append(request_cache_length) prefix_offsets.append(self.prefix_offsets[idx]) read_offsets.append(self.read_offsets[idx]) @@ -594,60 +542,78 @@ class FlashCausalLMBatch(Batch): stopping_criterias.append(stopping_criteria) top_n_tokens.append(self.top_n_tokens[idx]) + prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx]) ADAPTER_TO_INDEX = get_adapter_to_index() adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0) adapter_set.add(adapter_index) - remaining_tokens = ( - stopping_criteria.max_new_tokens - stopping_criteria.current_tokens - ) - request_block_table = self.block_tables[idx] num_blocks += len(request_block_table) block_tables.append(request_block_table) - start_slots.append(cumulative_max_length) - # Copy to tensor (CPU) - slot_indices[i] = cumulative_max_length + request_input_length - 1 + start_slot = self.cu_slots[idx] + end_slot = self.cu_slots[idx + 1] + slot_length = end_slot - start_slot # Set slice - slot_filtering_indices[ - self.start_slots[idx] : self.start_slots[idx] - + request_input_length - + remaining_tokens - - 1 - ] = True + slot_filtering_indices[start_slot:end_slot] = True - cumulative_max_length += request_input_length + remaining_tokens - 1 + cu_slots.append(cumulative_slot_tokens + slot_length) + # Input ids if 
the request was part of a prefilling batch + # If the batch was decoding we can index into the tensor directly later + if self.prefilling: + input_ids.append(self.input_ids[idx]) + else: + # Copy to tensor (CPU) + slot_indices[i] = cumulative_slot_tokens + request_cache_length + + cumulative_slot_tokens += slot_length max_blocks = max(max_blocks, len(request_block_table)) + max_slots = max(max_slots, slot_length) - # Index into tensors - input_ids = self.input_ids[indices] - position_ids = self.position_ids[indices] - adapter_indices = self.adapter_meta.adapter_indices[indices] all_input_ids_tensor = self.all_input_ids_tensor[indices] block_tables_tensor = self.block_tables_tensor[indices] - input_lengths_tensor = self.input_lengths_tensor[indices] - slots = self.slots[slot_filtering_indices] - prefix_lens_tensor = self.prefix_lens_tensor[indices] next_token_chooser = self.next_token_chooser.filter(indices) top_n_tokens_tensor = self.top_n_tokens_tensor[indices] speculative_ids = ( self.speculative_ids[indices] if self.speculative_ids is not None else None ) + prompt_lengths_tensor = self.prompt_lengths_tensor[indices] - start_slots = torch.tensor(start_slots, dtype=torch.int64) + cu_slots = torch.tensor(cu_slots, dtype=torch.int64) - # Move to GPU now that we have the whole tensor - slot_indices = slot_indices.to(device) + slots = self.slots[slot_filtering_indices] - adapter_segments, adapter_segment_indices = find_segments(adapter_indices) - adapter_segments = torch.tensor( - adapter_segments, dtype=torch.int32, device=device - ) - # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum() + if self.prefilling: + # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` + position_ids = None + slot_indices = None + cache_lengths_tensor = None + input_lengths_tensor = None + adapter_meta = None + else: + # Index into tensors + input_ids = self.input_ids[indices] + position_ids = self.position_ids[indices] + adapter_indices = self.adapter_meta.adapter_indices[indices] + input_lengths_tensor = self.input_lengths_tensor[indices] + cache_lengths_tensor = self.cache_lengths_tensor[indices] + + # Move to GPU now that we have the whole tensor + slot_indices = slot_indices.to(device) + + adapter_segments, adapter_segment_indices = find_segments(adapter_indices) + adapter_segments = torch.tensor( + adapter_segments, dtype=torch.int32, device=device + ) + adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ) return type(self)( batch_id=self.batch_id, @@ -657,24 +623,29 @@ class FlashCausalLMBatch(Batch): position_ids=position_ids, cu_seqlen_prefill=None, prefill_cache_indices=None, - start_slots=start_slots, slot_indices=slot_indices, block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, - max_seqlen=max_seqlen, + cu_slots=cu_slots, + max_input_length=max_input_length, + max_current_length=max_current_length, + prefilling=self.prefilling, + prefilling_mask=prefilling_mask, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, + prefill_logprob_tokens=prefill_logprob_tokens, + prompt_lengths=prompt_lengths, + prompt_lengths_tensor=prompt_lengths_tensor, input_lengths=input_lengths, input_lengths_tensor=input_lengths_tensor, - prefix_lens=prefix_lens, - prefix_lens_tensor=prefix_lens_tensor, + cache_lengths=cache_lengths, + cache_lengths_tensor=cache_lengths_tensor, 
prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, - prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, @@ -682,12 +653,8 @@ class FlashCausalLMBatch(Batch): num_blocks=num_blocks, max_blocks=max_blocks, speculative_ids=speculative_ids, - adapter_meta=AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_segment_indices, - ), + adapter_meta=adapter_meta, + hpu_attn_meta=None, ) @classmethod @@ -697,74 +664,105 @@ class FlashCausalLMBatch(Batch): requests = [] requests_idx_mapping = {} + prefilling = False num_blocks = 0 total_batch_size = 0 total_slots = 0 max_blocks = 0 max_length = 0 - max_seqlen = 0 + max_input_length = 0 + max_current_length = 0 for b in batches: total_batch_size += len(b) + max_blocks = max(max_blocks, b.max_blocks) total_slots += len(b.slots) num_blocks += b.num_blocks speculative_length = ( b.speculative_ids.shape[1] if b.speculative_ids is not None else 0 ) - max_blocks = max(max_blocks, b.max_blocks) - max_seqlen = max(max_seqlen, b.max_seqlen) + max_input_length = max(max_input_length, b.max_input_length) + max_current_length = max(max_current_length, b.max_current_length) max_length = max( max_length, max( - input_length + prompt_length + stopping_criteria.max_new_tokens + speculative_length - - stopping_criteria.current_tokens - for input_length, stopping_criteria in zip( - b.input_lengths, b.stopping_criterias + for prompt_length, stopping_criteria in zip( + b.prompt_lengths, b.stopping_criterias ) ), ) + prefilling = prefilling or b.prefilling - input_ids = batches[0].input_ids.new_empty(total_batch_size) - position_ids = batches[0].position_ids.new_empty(total_batch_size) slots = batches[0].slots.new_empty(total_slots) - slot_indices = batches[0].slot_indices.new_empty(total_batch_size) - input_lengths_tensor = batches[0].input_lengths_tensor.new_empty( + cu_slots = torch.zeros(total_batch_size + 1, dtype=torch.int64) + if prefilling: + input_ids = [] + # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` + position_ids = None + slot_indices = None + cache_lengths_tensor = None + input_lengths_tensor = None + adapter_meta = None + adapter_segment_builder = None + else: + input_ids = batches[0].input_ids.new_empty(total_batch_size) + if ( + batches[0].position_ids is not None + and batches[0].position_ids.dim() == 2 + ): + # Qwen2_vl case: + position_ids = batches[0].position_ids.new_empty( + (total_batch_size, batches[0].position_ids.shape[-1]) + ) + else: + position_ids = batches[0].position_ids.new_empty(total_batch_size) + slot_indices = batches[0].slot_indices.new_empty(total_batch_size) + input_lengths_tensor = batches[0].input_lengths_tensor.new_empty( + total_batch_size + ) + cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty( + total_batch_size + ) + total_indices_size = sum( + b.adapter_meta.adapter_indices.shape[0] for b in batches + ) + adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty( + total_indices_size + ) + adapter_segment_builder = SegmentConcatBuilder() + adapter_set = set() + + prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty( total_batch_size ) block_tables_tensor = batches[0].block_tables_tensor.new_zeros( (total_batch_size, max_blocks) ) - prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size) 
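+        # --- Editor's note: illustrative sketch, not part of the patch. ---
+        # When batches are concatenated, every slot-bookkeeping tensor of a later batch
+        # is shifted by the number of slots already consumed by the earlier batches
+        # (`cumulative_slots` in the copy loop that follows). The helper name below is
+        # made up for illustration; the arithmetic mirrors
+        # `cu_slots[start_index + 1 : end_index + 1] = batch.cu_slots[1:] + cumulative_slots`.
+        #
+        #     import torch
+        #
+        #     def merge_cu_slots(cu_slots_a: torch.Tensor, cu_slots_b: torch.Tensor) -> torch.Tensor:
+        #         # Both inputs are cumulative slot counts of shape [batch_size + 1], starting at 0.
+        #         # Batch A is kept as-is; batch B is offset by A's total slot count.
+        #         offset = cu_slots_a[-1]
+        #         return torch.cat([cu_slots_a, cu_slots_b[1:] + offset])
+        #
+        #     a = torch.tensor([0, 3, 7])   # batch A: two requests using 3 and 4 slots
+        #     b = torch.tensor([0, 5])      # batch B: one request using 5 slots
+        #     assert merge_cu_slots(a, b).tolist() == [0, 3, 7, 12]
+        # --- end editor's note ---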
all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros( (total_batch_size, max_length) ) top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( total_batch_size, ) - total_indices_size = sum( - b.adapter_meta.adapter_indices.shape[0] for b in batches - ) - adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty( - total_indices_size - ) - adapter_set = set() - adapter_segment_builder = SegmentConcatBuilder() - start_slots = [] block_tables = [] - prefix_lens = [] + cache_lengths = [] all_input_ids = [] - prefix_ids = [] + prompt_lengths = [] input_lengths = [] prefix_offsets = [] read_offsets = [] + prefill_logprob_tokens = [] + next_token_chooser_parameters = [] fsm_grammar_states = [] stopping_criterias = [] top_n_tokens = [] + prefilling_mask = [] # Cumulative length cumulative_batch_size = 0 @@ -783,32 +781,9 @@ class FlashCausalLMBatch(Batch): start_index = cumulative_batch_size end_index = cumulative_batch_size + len(batch) - slots_start_index = cumulative_slots - slots_end_index = cumulative_slots + len(batch.slots) # Copy tensors (GPU) - input_ids[start_index:end_index] = batch.input_ids - position_ids[start_index:end_index] = batch.position_ids - slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots - input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor - slots[slots_start_index:slots_end_index] = batch.slots - - # Copy over adapter indices - adapter_start_index = cumulative_adapter_indices_size - adapter_end_index = ( - cumulative_adapter_indices_size - + batch.adapter_meta.adapter_indices.shape[0] - ) - adapter_indices[adapter_start_index:adapter_end_index] = ( - batch.adapter_meta.adapter_indices - ) - cumulative_adapter_indices_size = adapter_end_index - adapter_set.update(batch.adapter_meta.adapter_set) - adapter_segment_builder.concat( - batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices - ) - all_input_ids_tensor[ start_index:end_index, : batch.all_input_ids_tensor.shape[1] ] = batch.all_input_ids_tensor[:, :max_length] @@ -816,20 +791,56 @@ class FlashCausalLMBatch(Batch): block_tables_tensor[ start_index:end_index, : batch.block_tables_tensor.shape[1] ] = batch.block_tables_tensor[:, :max_blocks] + prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor - prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor + slots_start_index = cumulative_slots + slots_end_index = cumulative_slots + len(batch.slots) + slots[slots_start_index:slots_end_index] = batch.slots + cu_slots[start_index + 1 : end_index + 1] = ( + batch.cu_slots[1:] + cumulative_slots + ) - start_slots.append(batch.start_slots + cumulative_slots) + if not prefilling: + input_ids[start_index:end_index] = batch.input_ids + position_ids[start_index:end_index] = batch.position_ids + slot_indices[start_index:end_index] = ( + batch.slot_indices + cumulative_slots + ) + input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor + cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor + # Copy over adapter indices + adapter_start_index = cumulative_adapter_indices_size + adapter_end_index = ( + cumulative_adapter_indices_size + + batch.adapter_meta.adapter_indices.shape[0] + ) + adapter_indices[adapter_start_index:adapter_end_index] = ( + batch.adapter_meta.adapter_indices + ) + cumulative_adapter_indices_size = adapter_end_index + adapter_set.update(batch.adapter_meta.adapter_set) + 
adapter_segment_builder.concat( + batch.adapter_meta.adapter_segments, + batch.adapter_meta.segment_indices, + ) + else: + if isinstance(batch.input_ids, torch.Tensor): + batch.input_ids = batch.input_ids.view(-1, 1).tolist() + input_ids.extend(batch.input_ids) + + prefilling_mask.extend(batch.prefilling_mask) block_tables.extend(batch.block_tables) - prefix_lens.extend(batch.prefix_lens) + cache_lengths.extend(batch.cache_lengths) all_input_ids.extend(batch.all_input_ids) - prefix_ids.extend(batch.prefix_ids) + prompt_lengths.extend(batch.prompt_lengths) input_lengths.extend(batch.input_lengths) prefix_offsets.extend(batch.prefix_offsets) read_offsets.extend(batch.read_offsets) + prefill_logprob_tokens.extend(batch.prefill_logprob_tokens) + next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states) stopping_criterias.extend(batch.stopping_criterias) @@ -837,12 +848,8 @@ class FlashCausalLMBatch(Batch): top_n_tokens.extend(batch.top_n_tokens) # Update - cumulative_batch_size += len(batch) cumulative_slots += len(batch.slots) - - start_slots = torch.concat(start_slots) - - # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum() + cumulative_batch_size += len(batch) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, @@ -852,13 +859,21 @@ class FlashCausalLMBatch(Batch): fsm_grammar_states=fsm_grammar_states, ) - speculative_ids = ( - torch.cat([b.speculative_ids for b in batches], dim=0) - if batches[0].speculative_ids is not None - else None - ) + # We skip computing the speculative_ids when the batch size is too large, so + # we must check that all batches have them, otherwise they must be discarded + if get_speculate() > 0 and all(b.speculative_ids is not None for b in batches): + speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0) + else: + speculative_ids = None - adapter_segments, adapter_segment_indices = adapter_segment_builder.build() + if adapter_segment_builder is not None: + adapter_segments, adapter_segment_indices = adapter_segment_builder.build() + adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ) return cls( batch_id=batches[0].batch_id, @@ -868,24 +883,29 @@ class FlashCausalLMBatch(Batch): position_ids=position_ids, cu_seqlen_prefill=None, prefill_cache_indices=None, - start_slots=start_slots, slot_indices=slot_indices, block_tables=block_tables, block_tables_tensor=block_tables_tensor, - prefix_lens=prefix_lens, - prefix_lens_tensor=prefix_lens_tensor, + cache_lengths=cache_lengths, + cache_lengths_tensor=cache_lengths_tensor, slots=slots, - max_seqlen=max_seqlen, + cu_slots=cu_slots, + max_input_length=max_input_length, + max_current_length=max_current_length, + prefilling=prefilling, + prefilling_mask=prefilling_mask, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, + prefill_logprob_tokens=prefill_logprob_tokens, + prompt_lengths=prompt_lengths, + prompt_lengths_tensor=prompt_lengths_tensor, input_lengths=input_lengths, input_lengths_tensor=input_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, - prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, @@ -893,12 
+913,363 @@ class FlashCausalLMBatch(Batch): num_blocks=num_blocks, max_blocks=max_blocks, speculative_ids=speculative_ids, - adapter_meta=AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_segment_indices, - ), + adapter_meta=adapter_meta, + hpu_attn_meta=None, + ) + + def prepare_for_decode(self, dtype, use_contiguous_pa): + # Prepare values if we need to continue decoding + # need for HPUPagedAttentionMetadata preparation + import itertools + from vllm_hpu_extension.ops import batch2block, block2batch + + def flatten(in_list): + return list(itertools.chain(*in_list)) + + def gather_list(input, indices, v): + return [input[i] if i is not None else v for i in indices] + + def pad_list(input, k, v): + input_len = len(input) + target_len = (input_len + k - 1) // k * k + padding = target_len - input_len + return input + [v] * padding + + device = self.block_tables_tensor.device + last_block_usage = self.slots[self.slot_indices] % BLOCK_SIZE + 1 + block_num = self.cache_lengths_tensor // BLOCK_SIZE + 1 + block_tables = [] + for i, bt in enumerate(self.block_tables): + block_tables.append(bt[0 : block_num[i]]) + block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)] + block_usage = [ + [BLOCK_SIZE] * (len(bt) - 1) + [lbu] + for bt, lbu in zip(block_tables, last_block_usage) + if bt + ] + + block_list = flatten(block_tables) + block_groups = flatten(block_groups) + block_usage = flatten(block_usage) + batch = self.input_ids.size(0) + + assert len(block_list) == len(block_groups) + assert len(block_list) == len(block_usage) + if use_contiguous_pa: + block_bucket_size = max(max(block_list) + 1, len(block_list)) + # block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks( + # block_bucket_size) + indices: List[Any] + indices = [None] * block_bucket_size + for i, bid in enumerate(block_list): + indices[bid] = i + block_list = gather_list(block_list, indices, 0) + block_groups = gather_list(block_groups, indices, -1) + block_usage = gather_list(block_usage, indices, 1) + else: + block_bucket_size = len(block_list) + block_list = pad_list(block_list, block_bucket_size, 0) + block_groups = pad_list(block_groups, block_bucket_size, -1) + block_usage = pad_list(block_usage, block_bucket_size, 1) + + block_list = torch.tensor(block_list, dtype=torch.int, device=device) + block_groups = torch.tensor(block_groups, dtype=torch.int, device=device) + block_usage = torch.tensor(block_usage, dtype=dtype, device=device) + block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch) + mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze( + 0 + ) + mask = mask >= block_usage.unsqueeze(-1) + attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf) + ones = torch.ones( + (block_mapping.size(0),), device=device, dtype=block_mapping.dtype + ) + sums = batch2block(block2batch(ones, block_mapping), block_mapping) + block_scales = torch.reciprocal(torch.maximum(ones, sums)) + self.hpu_attn_meta = trim_attn_metadata( + HPUPagedAttentionMetadata( + block_list=block_list, + block_groups=block_groups, + block_usage=block_usage, + block_mapping=block_mapping.to(dtype), + attn_bias=attn_bias, + block_scales=block_scales, + ) + ) + + def prepare_for_prefill(self): + # Prepare values if we need to continue prefilling + # Speculation must be ignored while we prefill even with chunking + # it simplifies everything + assert self.speculative_ids is 
None + + device = self.block_tables_tensor.device + + # hpu does not support varlen for prefill, use sdpa instead. so need to pad input_tensor, position + # padding to left to work with sliding window + # use prefill_cache_indices to indicate the valid kv slot, update prefill_next_token_indices to indicate + # the right logit position + + if device.type == "hpu": + input_ids_padded = None + input_ids_padded_length = None + if isinstance(self.input_ids, list) and len(self) > 1: + input_ids_padded = [] + input_ids_padded_length = [] + for input_id in self.input_ids: + padded = self.max_input_length - len(input_id) + input_id_padded = input_id + if padded > 0: + input_id_padded = [0] * padded + input_id_padded + input_ids_padded.append(input_id_padded) + input_ids_padded_length.append(padded) + input_ids_padded = np.concatenate(input_ids_padded, dtype=np.int64) + input_ids_padded = torch.tensor( + input_ids_padded, dtype=torch.int64, device=device + ) + + if isinstance(self.input_ids, list): + if len(self) > 1: + input_ids = np.concatenate(self.input_ids, dtype=np.int64) + else: + input_ids = self.input_ids[0] + self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + + self.input_lengths_tensor = torch.tensor( + self.input_lengths, dtype=torch.int32, device=device + ) + cu_seqlen_prefill = self.input_lengths_tensor.new_zeros(len(self) + 1) + torch.cumsum(self.input_lengths_tensor, out=cu_seqlen_prefill[1:], dim=0) + self.cu_seqlen_prefill = cu_seqlen_prefill.to(torch.int32) + self.cache_lengths_tensor = torch.tensor( + self.cache_lengths, dtype=torch.int32, device=device + ) + + sliding_window = get_sliding_windows() + position_ids = [] + slot_indices = [] + prefill_cache_indices = [] + all_prefill_logprobs = True + no_prefill_logprobs = True + prefill_cu_outlens = [0] + + # Cumulative length + cumulative_length = 0 + cumulative_slot_tokens = 0 + prefill_out_cumulative_length = 0 + + adapter_indices_list = [] + adapter_set = set() + + for i, ( + r, + cache_length, + input_length, + prompt_length, + request_prefilling, + blocks, + ) in enumerate( + zip( + self.requests, + self.cache_lengths, + self.input_lengths, + self.prompt_lengths, + self.prefilling_mask, + self.block_tables, + ) + ): + next_chunk_length = input_length + + # Position ids + request_position_ids = torch.arange( + cache_length, cache_length + input_length, dtype=torch.int32 + ) + if device.type == "hpu" and input_ids_padded is not None: + position_ids.append( + torch.ones(input_ids_padded_length[i], dtype=torch.int32) + ) + position_ids.append(request_position_ids) + + if not r.slots: + request_slots = [ + s + for b in blocks + for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE) + ] + else: + request_slots = r.slots + + request_slot_indices = torch.arange( + cache_length + cumulative_slot_tokens, + cache_length + cumulative_slot_tokens + input_length, + dtype=torch.int64, + ) + + slot_indices.append(request_slot_indices) + + # Update + cumulative_slot_tokens += len(request_slots) + + # Create tensor to slice into the kv tensor in prefill + if device.type == "hpu" and input_ids_padded is not None: + # hpu need request_prefill_cache_indices to skip padding in kv cache + sliding_window = get_sliding_windows() + if sliding_window is None: + sliding_window = input_length + cumulative_length += input_ids_padded_length[i] + if sliding_window is not None: + request_prefill_cache_indices = torch.arange( + cumulative_length + max(0, input_length - sliding_window), + cumulative_length + input_length, + 
dtype=torch.int64, + ) + + # Prefill logprobs is ignored if the request is done prefilling + prefill_logprobs = r.prefill_logprobs and request_prefilling + + all_prefill_logprobs = all_prefill_logprobs and prefill_logprobs + no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs + + if prefill_logprobs: + prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) + prefill_out_cumulative_length += input_length + else: + prefill_cu_outlens.append(prefill_out_cumulative_length + 1) + prefill_out_cumulative_length += 1 + + if sliding_window is not None: + prefill_cache_indices.append(request_prefill_cache_indices) + + ADAPTER_TO_INDEX = get_adapter_to_index() + if ADAPTER_TO_INDEX: + adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0) + adapter_indices_list.append( + torch.full((next_chunk_length,), adapter_index) + ) + adapter_set.add(adapter_index) + + # Update + cumulative_length += next_chunk_length + + if not all_prefill_logprobs and not no_prefill_logprobs: + prefill_head_indices = [] + prefill_next_token_indices = [] + + # Cumulative length + cumulative_length = 0 + prefill_out_cumulative_length = 0 + + for i, ( + r, + input_length, + request_prefilling, + ) in enumerate( + zip( + self.requests, + self.input_lengths, + self.prefilling_mask, + ) + ): + # Prefill logprobs is ignored if the request is done prefilling + prefill_logprobs = r.prefill_logprobs and request_prefilling + + if prefill_logprobs: + prefill_head_indices.append( + torch.arange( + cumulative_length, + cumulative_length + input_length, + dtype=torch.int64, + ) + ) + prefill_next_token_indices.append( + prefill_out_cumulative_length + input_length - 1 + ) + prefill_out_cumulative_length += input_length + else: + prefill_head_indices.append( + torch.tensor( + [cumulative_length + input_length - 1], + dtype=torch.int64, + ) + ) + prefill_next_token_indices.append(prefill_out_cumulative_length) + prefill_out_cumulative_length += 1 + + # Update + cumulative_length += input_length + + if len(self) > 1: + if position_ids: + position_ids = torch.cat(position_ids) + if slot_indices: + slot_indices = torch.cat(slot_indices) + if sliding_window is not None: + prefill_cache_indices = torch.cat(prefill_cache_indices) + else: + if position_ids: + position_ids = position_ids[0] + if slot_indices: + slot_indices = slot_indices[0] + if sliding_window is not None: + prefill_cache_indices = prefill_cache_indices[0] + + self.position_ids = position_ids.to(device) + self.slot_indices = slot_indices.to(device) + + self.prefill_cu_outlens = prefill_cu_outlens + self.prefill_cache_indices = ( + prefill_cache_indices.to(device) if sliding_window is not None else None + ) + + if all_prefill_logprobs: + prefill_head_indices = None + prefill_next_token_indices = self.cu_seqlen_prefill[1:] - 1 + elif no_prefill_logprobs: + prefill_head_indices = self.cu_seqlen_prefill[1:] - 1 + prefill_next_token_indices = None + else: + prefill_head_indices = torch.cat(prefill_head_indices).to(device) + prefill_next_token_indices = torch.tensor( + prefill_next_token_indices, dtype=torch.int64, device=device + ) + + self.prefill_head_indices = prefill_head_indices + self.prefill_next_token_indices = prefill_next_token_indices + if device.type == "hpu" and input_ids_padded is not None: + self.input_ids = input_ids_padded + input_ids_padded_length_tensor = torch.cumsum( + torch.tensor(input_ids_padded_length, dtype=torch.int64, device=device), + dim=-1, + ) + if self.prefill_head_indices is not None: + self.prefill_head_indices = ( + 
self.prefill_head_indices + input_ids_padded_length_tensor + ) + + if self.prefill_next_token_indices is not None: + self.prefill_next_token_indices = ( + self.prefill_next_token_indices + input_ids_padded_length_tensor + ) + + if adapter_set: + adapter_indices = torch.cat(adapter_indices_list).to( + dtype=torch.int64, device=device + ) + adapter_segments, adapter_segment_indices = find_segments(adapter_indices) + else: + adapter_indices = torch.zeros_like(self.input_ids) + adapter_segments = [0, len(adapter_indices)] + adapter_segment_indices = [len(adapter_indices) - 1] + + adapter_segments = torch.tensor( + adapter_segments, dtype=torch.int32, device=device + ) + + self.adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, ) def __len__(self): @@ -937,23 +1308,14 @@ class FlashCausalLM(Model): # Deepseek V2 uses different QK and V dims. head_size: Optional[int] = None, skip_special_tokens: bool = True, + kv_cache_dtype: Optional[torch.dtype] = None, + support_chunking: bool = True, ): self.quantize = quantize self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = default_dtype if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = default_dtype if dtype is None else dtype - else: - device = torch.device("cpu") - # Float16 doesn't exist on target. - dtype = torch.bfloat16 if dtype is None else dtype - init_cpu_threads_env(rank_id=rank, world_size=world_size) - else: - raise NotImplementedError(f"{model_class} is only available on GPU") + + device = torch.device("hpu") + dtype = torch.bfloat16 if dtype is None else dtype tokenizer = tokenizer_class.from_pretrained( model_id, @@ -991,7 +1353,7 @@ class FlashCausalLM(Model): weights_loader=weights_loader, ) - prefix = "" + prefix = None model = model_class(prefix, config, weights) torch.distributed.barrier(group=self.process_group) @@ -1007,6 +1369,7 @@ class FlashCausalLM(Model): self.num_layers = config.num_hidden_layers self.num_heads = config.num_attention_heads // self.process_group.size() + self.config = config # Validation is done in the model itself if num_kv_heads is None: num_kv_heads = getattr(config, "num_key_value_heads", None) @@ -1034,25 +1397,14 @@ class FlashCausalLM(Model): self.cuda_graphs = {} self.kv_cache = [] + self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype - if ATTENTION == "flashinfer": - from text_generation_server.layers.attention.flashinfer import ( - create_prefill_state, - create_decode_state, - create_prefill_with_paged_kv_state, - ) - - self.prefill_state = create_prefill_state(device=device) - self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state( - device=device - ) - - self.decode_state = create_decode_state( - device=device, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - ) - + if htorch.utils.internal.is_lazy(): + htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True) + environment.set_model_config(self.config) + self.use_contiguous_pa = ( + os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true" + ) super().__init__( model_id=model_id, model=model, @@ -1063,6 +1415,7 @@ class FlashCausalLM(Model): rank=rank, world_size=world_size, sliding_window=config.sliding_window, + support_chunking=support_chunking, ) 
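Note on the HPU prefill path added above: prompts are left-padded to the longest prompt in the batch, so every index into the flattened token stream (for example the logit positions in prefill_head_indices) has to be shifted by the cumulative padding that precedes it, which is what the torch.cumsum shift at the end of prepare_for_prefill does. A minimal standalone sketch of that bookkeeping, assuming a toy two-request batch; pad_left_and_shift is a hypothetical helper used only for illustration, not an API introduced by this patch:

import torch

def pad_left_and_shift(prompts, head_indices):
    # prompts: per-request token-id lists; head_indices: per-request flat indices
    # into the unpadded concatenation of all prompts.
    max_len = max(len(p) for p in prompts)
    padding = [max_len - len(p) for p in prompts]
    padded = [[0] * pad + p for p, pad in zip(prompts, padding)]
    # Cumulative padding inserted before (and including) each request's tokens.
    cum_pad = torch.cumsum(torch.tensor(padding, dtype=torch.int64), dim=-1)
    shifted = torch.tensor(head_indices, dtype=torch.int64) + cum_pad
    return padded, shifted

prompts = [[11, 12, 13, 14], [21, 22]]
head_indices = [3, 5]  # last-token position of each request in the unpadded stream
padded, shifted = pad_left_and_shift(prompts, head_indices)
# padded  -> [[11, 12, 13, 14], [0, 0, 21, 22]]
# shifted -> tensor([3, 7]); the second request's logit position moves right by its padding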
@property @@ -1083,174 +1436,35 @@ class FlashCausalLM(Model): ): self.kv_cache = [] empty_cache() - - element_size = torch.tensor([], dtype=dtype).element_size() - if SYSTEM == "ipex" and device.type == "xpu": - x = 1 - else: - x = BLOCK_SIZE // element_size - - if ATTENTION in {"flashdecoding", "flashinfer"}: - self.kv_cache = [ - ( - torch.empty( - (num_blocks, BLOCK_SIZE, num_heads, head_size), - dtype=dtype, - device=device, - ), - torch.empty( - (num_blocks, BLOCK_SIZE, num_heads, head_size), - dtype=dtype, - device=device, - ), - ) - for _ in range(num_layers) - ] - elif SYSTEM == "ipex" and device == torch.device("cpu"): - self.kv_cache = [ - ( - torch.empty( - (num_blocks, num_heads, BLOCK_SIZE, head_size), - dtype=dtype, - device=device, - ), - torch.empty( - (num_blocks, num_heads, BLOCK_SIZE, head_size), - dtype=dtype, - device=device, - ), - ) - for _ in range(num_layers) - ] - else: - self.kv_cache = [ - ( - torch.zeros( - (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x), - dtype=dtype, - device=device, - ), - torch.zeros( - (num_blocks, num_heads, head_size, BLOCK_SIZE), - dtype=dtype, - device=device, - ), - ) - for _ in range(num_layers) - ] - - def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): - input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) - position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) - slots = torch.arange(bs, dtype=torch.int64, device=self.device) - input_lengths = [max_s] * bs - prefix_lengths = [0] * bs - input_lengths_tensor = ( - torch.ones(bs, dtype=torch.int32, device=self.device) * max_s - ) - prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device) - block_tables = torch.arange( - max_bt, dtype=torch.int32, device=self.device - ).repeat(bs) - block_tables = block_tables.reshape((bs, max_bt)) - - if ATTENTION == "flashinfer": - block_tables = block_tables_to_ragged( - block_tables=block_tables, - input_lengths=input_lengths, - prefix_lens=prefix_lengths, - ) - from text_generation_server.layers.attention.flashinfer import ( - create_decode_state_cuda_graphs, + self.kv_cache = [ + KVCache( + num_blocks=num_blocks, + num_heads=num_heads, + head_size=head_size, + dtype=dtype, + device=device, ) + for _ in range(num_layers) + ] - block_tables_ptr = torch.zeros( - bs + 1, dtype=torch.int32, device=self.device - ) - last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device) - state = create_decode_state_cuda_graphs( - device=input_ids.device, - block_tables=block_tables, - block_tables_ptr=block_tables_ptr, - last_page_len=last_page_len, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - ) - else: - state = None - - graph = torch.cuda.CUDAGraph() - self.cuda_graphs[bs] = { - "input_ids": input_ids, - "position_ids": position_ids, - "kv_cache": self.kv_cache, - "block_tables": block_tables, - "slots": slots, - "input_lengths": input_lengths_tensor, - "prefix_lengths": prefix_lengths_tensor, - "state": state, - "graph": graph, - } - - torch.cuda.synchronize() - # Run once outside to warmup - with self._forward_context( - block_tables=block_tables, - cu_seqlen_prefill=None, - input_lengths_tensor=input_lengths_tensor, - state=state, - prefix_lens_tensor=prefix_lengths_tensor, - ): - seqlen = Seqlen( - input_lengths=input_lengths_tensor, - prefix_lengths=prefix_lengths_tensor, - cu_seqlen_q=None, - max_q=1, - max_k=max_s, - ) - self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=None, - kv_cache=self.kv_cache, - 
block_tables=block_tables, - slots=slots, - seqlen=seqlen, - max_s=max_s, - prefill_cache_indices=None, - lm_head_indices=None, - ) - del seqlen - - torch.cuda.synchronize() - - with torch.cuda.graph(graph, pool=MEM_POOL): - seqlen = Seqlen( - input_lengths=input_lengths_tensor, - prefix_lengths=prefix_lengths_tensor, - cu_seqlen_q=None, - max_q=1, - max_k=max_s, - ) - logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=None, - kv_cache=self.kv_cache, - block_tables=block_tables, - slots=slots, - seqlen=seqlen, - max_s=max_s, - prefill_cache_indices=None, - lm_head_indices=None, - ) - self.cuda_graphs[bs]["logits"] = logits - self.cuda_graphs[bs]["speculative_logits"] = speculative_logits - torch.cuda.synchronize() - - def warmup(self, batch: FlashCausalLMBatch): + def warmup( + self, + request: generate_pb2.WarmupRequest, + ): # The warmup batch is the biggest batch we could ever receive + self.kv_cache = [] empty_cache() + max_input_tokens = request.max_input_tokens + max_total_tokens = request.max_total_tokens + batch = self.batch_type.from_pb( + request.batch, self.tokenizer, self.dtype, self.device + ) + + # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) + # Calculate the number of blocks that can be allocated with the free memory + dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size() + cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size + total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size try: self.init_kv_cache( @@ -1258,141 +1472,84 @@ class FlashCausalLM(Model): self.num_layers, self.num_kv_heads, self.head_size, - self.dtype, + self.kv_cache_dtype, self.device, ) - max_bt = batch.max_blocks - max_s = max_bt * BLOCK_SIZE - if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False): - torch.cuda.tunable.tuning_enable(False) - _, batch, _ = self.generate_token(batch) - except torch.cuda.OutOfMemoryError as e: + batch_num_blocks = batch.num_blocks + + num_tokens = batch.to_pb().current_tokens + synchronize(self.device) + free_memory = get_free_memory( + self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM + ) + real_free_memory = get_free_memory(self.device, MEMORY_FRACTION) + log_master( + logger.debug, + f"Free memory {free_memory / 1e9:.2f}GB , (real: {real_free_memory / 1e9:.2f}GB", + ) + + _, _batch, _ = self.generate_token([batch]) + except Exception: raise RuntimeError( - f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. " + f"Not enough memory to handle {num_tokens} prefill tokens. 
" f"You need to decrease `--max-batch-prefill-tokens`" - ) from e + ) synchronize(self.device) - - # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) - # Calculate the number of blocks that can be allocated with the free memory - dtype_size = torch.tensor([], dtype=self.dtype).element_size() - cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size - total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size - - free_memory = get_free_memory(self.device, MEMORY_FRACTION) - batch_num_blocks = batch.num_blocks if batch is not None else 0 - + free_memory = get_free_memory(self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM) + kv_memory = free_memory num_blocks = ( # Leave 5% for some wiggle room - int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size) + int(kv_memory // total_cache_size) # Add batch.num_blocks as we allocated it above, so it is included in the peak memory. + batch_num_blocks ) - del batch + log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}") + if max_total_tokens is None or max_total_tokens == 0: + max_total_tokens = sum(batch.cache_lengths) + + if max_input_tokens is None or max_input_tokens == 0: + max_input_tokens = max_total_tokens - 1 + + del _batch, batch + self.kv_cache = [] + empty_cache() self.init_kv_cache( num_blocks, self.num_layers, self.num_kv_heads, self.head_size, - self.dtype, + self.kv_cache_dtype, self.device, ) - if SYSTEM == "rocm": - if ( - os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None - or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1" - ): - torch.cuda.tunable.enable() + return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens - if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0": - torch.cuda.tunable.tuning_enable(True) - - if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None: - tuning_sequences = [ - int(val) - for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",") - ] - elif CUDA_GRAPHS is not None: - tuning_sequences = CUDA_GRAPHS - else: - tuning_sequences = [1, 2, 3, 4, 5, 6, 7] - - tunableop_filepath = os.path.join( - HUGGINGFACE_HUB_CACHE, - f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv", - ) - - log_master( - logger.info, - f"PyTorch TunableOp is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.", - ) - - torch.cuda.tunable.set_filename( - tunableop_filepath, insert_device_ordinal=False - ) - - if os.path.isfile(tunableop_filepath): - log_master( - logger.info, - f"The file {tunableop_filepath} already exists and will be reused.", - ) - torch.cuda.tunable.read_file(tunableop_filepath) - - os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True) - - for seqlen in tuning_sequences: - log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}") - self.tunableop_warmup(seqlen) - torch.cuda.tunable.write_file(tunableop_filepath) - if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1": - torch.cuda.tunable.tuning_enable(False) - else: - log_master( - logger.info, - "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. 
TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.", - ) - - if CUDA_GRAPHS: - try: - log_master( - logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}" - ) - # Warmup cuda graphs - for bs in CUDA_GRAPHS: - if self.speculate is None or self.speculate + 1 <= bs: - self.cuda_graph_warmup(bs, max_s, max_bt) - except torch.cuda.OutOfMemoryError: - logger.exception("Decode cuda graph warmup failed") - else: - log_master( - logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})." - ) - - return int(num_blocks * BLOCK_SIZE) - - def tunableop_warmup(self, seqlen: int): + def tunableop_warmup(self, seqlen: int, max_bt: int): input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device) position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device) slots = torch.arange(seqlen, dtype=torch.int64, device=self.device) # Dummy value, some models (starcoder2) don't accept `None`. input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) - prefix_lens_tensor = torch.zeros(seqlen, dtype=torch.int32, device=self.device) + cache_lengths_tensor = torch.zeros( + seqlen, dtype=torch.int32, device=self.device + ) cu_seqlen_prefill = torch.tensor( [0, seqlen], device=self.device, dtype=torch.int32 ) - max_s = seqlen + + block_tables = torch.arange( + max_bt, dtype=torch.int32, device=self.device + ).repeat(seqlen) + block_tables = block_tables.reshape((seqlen, max_bt)) + seqlen = Seqlen( input_lengths=input_lengths, - prefix_lengths=prefix_lens_tensor, - cu_seqlen_q=cu_seqlen_prefill, - max_q=1, - max_k=seqlen, + cache_lengths=cache_lengths_tensor, ) # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. 
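The warmup changes above drop the CUDA/ROCm-specific tuning paths and keep only the free-memory based KV-cache sizing. A worked example of that arithmetic, with hypothetical model dimensions and memory figures (only the formula itself, bytes per block = num_layers * BLOCK_SIZE * num_kv_heads * head_size * 2 * dtype_size for key plus value, comes from the code):

import torch

BLOCK_SIZE = 128                    # Gaudi backend default (see globals.py below)
num_layers, num_kv_heads, head_size = 32, 8, 128          # hypothetical model
dtype_size = torch.tensor([], dtype=torch.bfloat16).element_size()  # 2 bytes

cache_block_size = BLOCK_SIZE * num_kv_heads * head_size  # elements per layer, per block
total_cache_size = num_layers * cache_block_size * 2 * dtype_size   # bytes per block (K and V)

free_memory = 20 * 1024**3          # assume 20 GiB of HPU memory left after the warmup pass
batch_num_blocks = 16               # blocks already allocated for the warmup batch
num_blocks = int(free_memory // total_cache_size) + batch_num_blocks

print(total_cache_size)             # 16_777_216 bytes, i.e. 16 MiB per block
print(num_blocks)                   # 1280 + 16 = 1296 blocks
print(num_blocks * BLOCK_SIZE)      # 165_888 cacheable tokens reported back to the router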
@@ -1401,10 +1558,9 @@ class FlashCausalLM(Model): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=self.kv_cache, - block_tables=None, + block_tables=block_tables, seqlen=seqlen, slots=slots, - max_s=max_s, lm_head_indices=None, prefill_cache_indices=None, ) @@ -1421,7 +1577,7 @@ class FlashCausalLM(Model): block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor - max_s = batch.max_seqlen + max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices speculative_ids = batch.speculative_ids @@ -1436,12 +1592,20 @@ class FlashCausalLM(Model): new_position_ids = ( position_ids.unsqueeze(-1).expand(B, new_length) + arange ).view(-1) - slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) + + # Slots can be discontiguous when prefix caching is enabled, so we need to expand the slot_indices, + # then update the slots with the additional indices to ensure we're grabbing the ones that have been + # allocated + slot_indices = ( + batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int + ).view(-1) + slots = batch.slots[slot_indices] + input_lengths = ( input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int ).view(-1) - prefix_lens_tensor = ( - batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) + cache_lengths_tensor = ( + batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length) ).reshape(-1) # Add Copy the block tables for all members @@ -1463,8 +1627,8 @@ class FlashCausalLM(Model): block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor - prefix_lens_tensor = batch.prefix_lens_tensor - max_s = batch.max_seqlen + cache_lengths_tensor = batch.cache_lengths_tensor + max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices if cu_seqlen_prefill is None and self.max_past() is not None: @@ -1473,103 +1637,51 @@ class FlashCausalLM(Model): # This makes sure the max_s for the decode pass is correct. 
max_s = min(self.max_past(), max_s) - bs = input_ids.shape[0] - sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) - if sorted_padded_bs: - # Get associated cuda graph - cuda_graph = self.cuda_graphs[sorted_padded_bs[0]] - else: - cuda_graph = None - - if cu_seqlen_prefill is not None or cuda_graph is None: - if ATTENTION == "flashinfer": - block_tables = block_tables_to_ragged( - block_tables=block_tables, - input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, - ) - with self._forward_context( - block_tables=block_tables, - cu_seqlen_prefill=cu_seqlen_prefill, - input_lengths_tensor=input_lengths, - prefix_lens_tensor=prefix_lens_tensor, - ): - max_k = (input_lengths + prefix_lens_tensor).max().item() - seqlen = Seqlen( - input_lengths=input_lengths, - prefix_lengths=prefix_lens_tensor, - cu_seqlen_q=cu_seqlen_prefill, - max_q=max_s, - max_k=max_k, - ) - logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, - block_tables=block_tables, - slots=slots, - seqlen=seqlen, - max_s=max_s, - prefill_cache_indices=batch.prefill_cache_indices, - lm_head_indices=lm_head_indices, - adapter_data=adapter_data, - ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None - return logits, speculative_logits - - # Copy inputs to the static inputs of the cuda graph - # Static inputs are potentially padded - cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids - cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids - if ATTENTION == "flashinfer": - block_tables = block_tables_to_ragged( - block_tables=block_tables, - input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, - ) - # assert block_tables.shape[0] >= slots.shape[0] - cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables - else: - cuda_graph["block_tables"][ - : block_tables.shape[0], : block_tables.shape[1] - ] = block_tables - - # XXX: This is working only because block 0 is reserved for the healthcheck - # so it doesn't matter if we override it with bogus values. 
- cuda_graph["slots"].fill_(0) - cuda_graph["slots"][: slots.shape[0]] = slots - cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths - cuda_graph["prefix_lengths"].zero_() - cuda_graph["prefix_lengths"][: prefix_lens_tensor.shape[0]] = prefix_lens_tensor - - with self._forward_context( - block_tables=cuda_graph["block_tables"], - cu_seqlen_prefill=None, - input_lengths_tensor=cuda_graph["input_lengths"], - prefix_lens_tensor=cuda_graph["prefix_lengths"], - state=cuda_graph["state"], - ): - # Replay the graph - cuda_graph["graph"].replay() - - # Slice output to the correct shape - speculative_logits = ( - cuda_graph["speculative_logits"][:bs] - if cuda_graph["speculative_logits"] is not None - else None + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, ) - logits = cuda_graph["logits"][:bs] + kwargs = {} + if htorch.utils.internal.is_lazy(): + kwargs["bypass_hpu_graphs"] = False + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=trim_seqlen_metadata(seqlen), + prefill_cache_indices=batch.prefill_cache_indices, + lm_head_indices=lm_head_indices, + # TODO not support adapter now, need the add in the future + adapter_data=None, + hpu_attention_meta=batch.hpu_attn_meta, + **kwargs, + ) + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + # fix following runtime error in graph replay + # RuntimeError: Neither storage attached to input tensor, not its view + htorch.core.mark_step() return logits, speculative_logits @tracer.start_as_current_span("generate_token") def generate_token( - self, batch: FlashCausalLMBatch + self, batches: List[FlashCausalLMBatch] ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]: + if len(batches) > 1: + batch = self.batch_type.concatenate(batches) + else: + batch = batches[0] start = time.time_ns() - prefill = batch.cu_seqlen_prefill is not None + prefill = batch.prefilling + if prefill: + batch.prepare_for_prefill() + else: + batch.prepare_for_decode(self.dtype, self.use_contiguous_pa) + prefill_logprobs = batch.prefill_next_token_indices is not None # Update adapter indices for speculative tokens (if present) @@ -1611,13 +1723,23 @@ class FlashCausalLM(Model): if prefill_logprobs else speculative_logits ) - next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty( - len(batch) - ) - + if len(batch) > 1 and prefill_logprobs: + # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs + # When batch == 1, we will just use the batch.input_ids values directly + prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) else: + prefill_logprobs = None next_token_logits = out - next_adapter_indices = batch.adapter_meta.adapter_indices + + finished_prefilling = True + next_chunk_lengths = [] + current_prefilling_mask = batch.prefilling_mask + if prefill: + finished_prefilling = True + next_prefilling_mask = [False] * len(batch) + + batch.prefilling = not finished_prefilling + batch.prefilling_mask = next_prefilling_mask speculate = get_speculate() ( @@ -1627,7 +1749,7 @@ class FlashCausalLM(Model): accepted_ids, speculative_ids, ) = batch.next_token_chooser( - batch.all_input_ids_tensor[:, : batch.max_seqlen], + batch.all_input_ids_tensor[:, : batch.max_current_length], 
next_token_logits, speculate, batch.speculative_ids, @@ -1638,85 +1760,103 @@ class FlashCausalLM(Model): batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids ) - if prefill: - if len(batch) > 1 and prefill_logprobs: - # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs - # When batch == 1, we will just use the batch.input_ids values directly - prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) - - next_position_ids = batch.position_ids.new_empty(len(batch)) - batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1] - # We do not need cu_seqlen_prefill anymore - batch.cu_seqlen_prefill = None - else: - prefill_logprobs = None - next_position_ids = batch.position_ids - - # Cumulative length - cumulative_length = 0 - - # Results - generations: List[Generation] = [] - stopped = True + # Since we are done prefilling, all the tensors that were concatenating values for all the requests + # instantly become of shape [BATCH_SIZE] + if prefill and finished_prefilling: + indices = batch.cu_seqlen_prefill[1:] - 1 + batch.position_ids = batch.position_ids[indices] + batch.slot_indices = batch.slot_indices[indices] + batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[ + indices + ] # Zipped iterator - iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids) + iterator = zip( + batch.requests, + batch.prompt_lengths, + batch.cache_lengths, + batch.input_lengths, + batch.all_input_ids, + accepted_ids, + current_prefilling_mask, + batch.prefilling_mask, + ) # We do two for loops as the first one can run completely asynchronously from the GPU while for the second - # one, we need to first do a GPU <-> CPU sync + # one, we need to first do a HPU <-> CPU sync # It is faster if we delay this sync for the maximum amount of time # For each member of the batch - index = 0 - for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator): - # Indexing metadata - start_index = cumulative_length - end_index = cumulative_length + input_length - - if prefill: + # Cumulative length + cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1) + torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:]) + cumulative_length = 0 + for i, ( + request, + prompt_length, + cache_length, + input_length, + all_input_ids, + n_accepted_ids, + request_was_prefilling, + request_is_prefilling, + ) in enumerate(iterator): + # Used to gather prefill logprobs + # Copy batch.all_input_ids_tensor to prefill_token_indices + if request.prefill_logprobs and request_was_prefilling: # Indexing metadata out_start_index = batch.prefill_cu_outlens[i] out_end_index = batch.prefill_cu_outlens[i + 1] - out_length = out_end_index - out_start_index - # Initialize position_ids - # In decode, we do not need this as we can just increment position ids - next_position_ids[i] = batch.position_ids[end_index - 1] - - # Initialize adapter indices - # In decode, we only have one token per row in the batch, so grab last index - next_adapter_indices[i] = batch.adapter_meta.adapter_indices[ - end_index - 1 + # Logprobs generated by the model are for the next token + # So we need to translate the id tensor by 1 + ids = batch.all_input_ids_tensor[ + i, cache_length + 1 : cache_length + input_length + 1 ] + if len(batch) > 1: + prefill_tokens_indices[out_start_index:out_end_index] = ids + else: + # Set prefill_tokens_indices to the correct slice + prefill_tokens_indices = ids - # Used to gather prefill logprobs - # Copy 
batch.input_ids to prefill_token_indices - if prefill_logprobs: - if len(batch) > 1: - prefill_tokens_indices[out_start_index : out_end_index - 1] = ( - batch.input_ids[start_index + 1 : start_index + out_length] - ) - else: - # Set prefill_tokens_indices to the correct slice - prefill_tokens_indices = batch.input_ids[ - start_index + 1 : start_index + out_length - ] - - for j in range(n_accepted_ids): - batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index] - index += 1 - + # If the device does not support triton, we copy one by one + if not request_is_prefilling: + # Only save tokens if we are done prefilling for this request + batch.all_input_ids_tensor[ + i, + batch.cache_lengths_tensor[i] + + batch.input_lengths[i] : batch.cache_lengths_tensor[i] + + batch.input_lengths[i] + + accepted_ids[i], + ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]] cumulative_length += input_length # Update values - batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] - batch.speculative_ids = speculative_ids - batch.position_ids = next_position_ids + accepted_ids - batch.input_lengths_tensor += accepted_ids - batch.slot_indices += accepted_ids - batch.adapter_meta.adapter_indices = next_adapter_indices + # These values can be updated without a HPU -> CPU sync + if not prefill or (prefill and finished_prefilling): + batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1] + batch.speculative_ids = speculative_ids + if batch.position_ids.dim() == 2: + # Qwen2_vl case: + batch.position_ids += accepted_ids.unsqueeze(-1) + else: + batch.position_ids += accepted_ids + batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1 + batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor) + batch.slot_indices += accepted_ids - if prefill: + if prefill and prefill_logprobs: + # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size)) + torch.log_softmax(out, -1, out=out) + prefill_logprobs_tensor = out + prefill_logprobs = torch.gather( + prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1) + ) + # HPU <-> CPU sync + prefill_logprobs = prefill_logprobs.view(-1).tolist() + + # Does a HPU <-> CPU sync internally + if prefill and finished_prefilling: # adjust segment lengths to account for all request lengths being 1 during decoding adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices) batch.adapter_meta.adapter_segments = torch.tensor( @@ -1725,192 +1865,282 @@ class FlashCausalLM(Model): device=batch.adapter_meta.adapter_segments.device, ) - if prefill and prefill_logprobs: - # Get prefill logprobs - prefill_logprobs_tensor = torch.log_softmax(out, -1) - prefill_logprobs = torch.gather( - prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1) - ) - # GPU <-> CPU sync - prefill_logprobs = prefill_logprobs.view(-1).tolist() - - # GPU <-> CPU sync + # HPU <-> CPU sync next_token_logprobs = next_token_logprobs.tolist() next_token_ids = next_input_ids.tolist() accepted_ids = accepted_ids.tolist() + + # Update values if we need to continue prefilling + # This represents the `else` case of the `Update values` if above + # but since this require the `next_token_ids` to be on CPU, it is better to do it here + if prefill and not finished_prefilling: + # Speculation must be ignored while we prefill even with chunking + # it simplifies everything + assert batch.speculative_ids is None + + all_postfix_ids = [] + for i, ( + request_prefilling, + next_token_id, + 
all_input_ids, + cache_length, + input_length, + next_chunk_length, + ) in enumerate( + zip( + batch.prefilling_mask, + next_token_ids, + batch.all_input_ids, + batch.cache_lengths, + batch.input_lengths, + next_chunk_lengths, + ) + ): + if request_prefilling: + next_cache_length = cache_length + input_length + # Get new prompt IDs to prefill + postfix_ids = all_input_ids[ + next_cache_length : next_cache_length + next_chunk_length + ] + else: + # This request is done prefilling, the new id is the one selected the sampling method + postfix_ids = [next_token_id] + + all_postfix_ids.append(postfix_ids) + + batch.input_ids = all_postfix_ids + start_decode = time.time_ns() + # Results + generations: List[Generation] = [] + stopped = True + # Zipped iterator iterator = zip( batch.requests, + batch.prompt_lengths, + batch.cache_lengths, batch.input_lengths, batch.prefix_offsets, batch.read_offsets, batch.stopping_criterias, batch.all_input_ids, - batch.prefix_ids, batch.next_token_chooser.do_sample, batch.next_token_chooser.seeds, batch.top_n_tokens, + current_prefilling_mask, + batch.prefilling_mask, accepted_ids, batch_top_token_ids, batch_top_token_logprobs, ) + # Reset max_input_length + batch.max_input_length = 0 # For each member of the batch index = 0 for i, ( request, + prompt_length, + cache_length, input_length, prefix_offset, read_offset, stopping_criteria, all_input_ids, - prefix_ids, do_sample, seed, top_n_tokens, + request_was_prefilling, + request_is_prefilling, n_accepted_ids, top_token_ids, top_token_logprobs, ) in enumerate(iterator): - # Append next token to all tokens - next_token_texts = [] - left = 0 - - if n_accepted_ids > 1: - log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}") - - current_stopped = False - for j in range(index, index + n_accepted_ids): - # Generated token - next_token_id = next_token_ids[j] - all_input_ids.append(next_token_id) - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids, - prefix_offset, - read_offset, - ) - next_token_texts.append(next_token_text) - - stop, reason = stopping_criteria( - next_token_id, - next_token_text, - ) - - if stop: - left = index + n_accepted_ids - j - 1 - current_stopped = True - break - else: - current_stopped = False - stopped = stopped and current_stopped - - _next_token_ids = next_token_ids[index : index + n_accepted_ids - left] - _next_token_logprobs = next_token_logprobs[ - index : index + n_accepted_ids - left - ] - index += n_accepted_ids - - # Shard generations - # All generations will be appended in the rust sharded client - if i % self.world_size == self.rank: - if stop: - # Decode generated tokens - output_text, _, _ = self.decode_token( - all_input_ids, - prefix_offset=len(all_input_ids) - - stopping_criteria.current_tokens - - 1, - read_offset=len(all_input_ids) - - stopping_criteria.current_tokens, - skip_special_tokens=True, - ) - generated_text = GeneratedText( - output_text, - stopping_criteria.current_tokens, - reason, - seed if do_sample else None, - ) - else: - generated_text = None - + # Compute logprobs first as, even though we might skip the token, + # it can still be required to compute the logprobs + # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need + # this state to be stable + if request.id % self.world_size == self.rank: # Prefill - if prefill and request.prefill_logprobs: + if request_was_prefilling and request.prefill_logprobs: out_start_index = batch.prefill_cu_outlens[i] out_end_index = 
batch.prefill_cu_outlens[i + 1] + if not request_is_prefilling: + # The request is dones prefilling, meaning that we started generating new tokens + # The last logprob is a logprob for a generated token that was not part of the prompt + # We need to remove it + out_end_index -= 1 + + request_prefill_logprobs = prefill_logprobs[ + out_start_index:out_end_index + ] + # Logprobs generated by the model are for the next token + # So we need to translate the id tensor by 1 + prefill_token_ids = all_input_ids[ + cache_length + 1 : cache_length + input_length + 1 + ] + + past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i] + + if past_prefill_logprob_tokens is None: + # add nan for cached prompt tokens/first token + request_prefill_logprobs = [float("nan")] * ( + cache_length + 1 + ) + request_prefill_logprobs + prefill_token_ids = ( + all_input_ids[: cache_length + 1] + prefill_token_ids + ) - # Remove generated token to only have prefill and add nan for first prompt token - request_prefill_logprobs = ( - [float("nan")] * (len(prefix_ids) + 1) - ) + prefill_logprobs[out_start_index : out_end_index - 1] - prefill_token_ids = all_input_ids[:-1] prefill_texts = self.tokenizer.batch_decode( - prefix_ids + prefill_token_ids, + prefill_token_ids, clean_up_tokenization_spaces=False, skip_special_tokens=False, ) - prefill_tokens = Tokens( - prefix_ids + prefill_token_ids, + prefill_logprob_tokens = Tokens( + prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special=[], ) - else: - prefill_tokens = None - - if top_n_tokens > 0: - all_top_tokens = [] - for top_token_ids, top_token_logprobs in zip( - top_token_ids, top_token_logprobs - ): - toptoken_texts = self.tokenizer.batch_decode( - top_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, + if past_prefill_logprob_tokens is not None: + prefill_logprob_tokens = ( + past_prefill_logprob_tokens + prefill_logprob_tokens ) - special_toptokens = [ - token_id in self.all_special_ids - for token_id in top_token_ids - ] - top_tokens = Tokens( - top_token_ids, - top_token_logprobs, - toptoken_texts, - special_toptokens, - ) - all_top_tokens.append(top_tokens) - top_tokens = all_top_tokens + + batch.prefill_logprob_tokens[i] = prefill_logprob_tokens else: - top_tokens = None + batch.prefill_logprob_tokens[i] = None - generation = Generation( - request.id, - prefill_tokens, - Tokens( - _next_token_ids, - _next_token_logprobs, - next_token_texts, - [nid in self.all_special_ids for nid in _next_token_ids], - ), - generated_text, - top_tokens, - ) + # If it is, the tokens we decoded should be ignored + if request_is_prefilling: + # Make sure that we do not stop as even though this request did not create a token, it is still + # processing + stopped = False + new_input_length = next_chunk_lengths[i] + new_cache_length = cache_length + input_length + else: + new_input_length = 1 + new_cache_length = cache_length + input_length + n_accepted_ids - 1 + # Append next token to all tokens + next_token_texts = [] + left = 0 - generations.append(generation) + if n_accepted_ids > 1: + log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}") - # accept each new token for this specific request since we may - # have more than one new token per request with speculative decoding - for next_token_id in _next_token_ids: - batch.next_token_chooser = ( - batch.next_token_chooser.advance_grammar_single(i, next_token_id) - ) + current_stopped = False + for j in range(index, index + n_accepted_ids): + # Generated token + next_token_id 
= next_token_ids[j] + all_input_ids.append(next_token_id) + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids, + prefix_offset, + read_offset, + ) + next_token_texts.append(next_token_text) + + stop, reason = stopping_criteria( + next_token_id, + next_token_text, + ) + + if stop: + left = index + n_accepted_ids - j - 1 + current_stopped = True + break + else: + current_stopped = False + stopped = stopped and current_stopped + + _next_token_ids = next_token_ids[index : index + n_accepted_ids - left] + _next_token_logprobs = next_token_logprobs[ + index : index + n_accepted_ids - left + ] + + # Shard generations + # All generations will be appended in the rust sharded client + if request.id % self.world_size == self.rank: + if stop: + # Decode generated tokens + output_text, _, _ = self.decode_token( + all_input_ids, + prefix_offset=len(all_input_ids) + - stopping_criteria.current_tokens + - 1, + read_offset=len(all_input_ids) + - stopping_criteria.current_tokens, + skip_special_tokens=True, + ) + generated_text = GeneratedText( + output_text, + stopping_criteria.current_tokens, + reason, + seed if do_sample else None, + ) + else: + generated_text = None + + if top_n_tokens > 0: + all_top_tokens = [] + for top_token_ids, top_token_logprobs in zip( + top_token_ids, top_token_logprobs + ): + toptoken_texts = self.tokenizer.batch_decode( + top_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + special_toptokens = [ + token_id in self.all_special_ids + for token_id in top_token_ids + ] + top_tokens = Tokens( + top_token_ids, + top_token_logprobs, + toptoken_texts, + special_toptokens, + ) + all_top_tokens.append(top_tokens) + top_tokens = all_top_tokens + else: + top_tokens = None + + generation = Generation( + request.id, + batch.prefill_logprob_tokens[i], + Tokens( + _next_token_ids, + _next_token_logprobs, + next_token_texts, + [nid in self.all_special_ids for nid in _next_token_ids], + ), + generated_text, + top_tokens, + ) + + generations.append(generation) + + # accept each new token for this specific request since we may + # have more than one new token per request with speculative decoding + for next_token_id in _next_token_ids: + batch.next_token_chooser = ( + batch.next_token_chooser.advance_grammar_single( + i, next_token_id + ) + ) # Update values - batch.input_lengths[i] = input_length + n_accepted_ids - if batch.input_lengths[i] > batch.max_seqlen: - batch.max_seqlen = batch.input_lengths[i] + index += n_accepted_ids + batch.cache_lengths[i] = new_cache_length + batch.max_input_length = max(batch.max_input_length, new_input_length) + batch.input_lengths[i] = new_input_length + current_length = new_cache_length + new_input_length + batch.max_current_length = max(batch.max_current_length, current_length) + batch.prefix_offsets[i] = prefix_offset batch.read_offsets[i] = read_offset batch.all_input_ids[i] = all_input_ids @@ -1921,83 +2151,14 @@ class FlashCausalLM(Model): decode_ns = time.time_ns() - start_decode return generations, None, (forward_ns, decode_ns) - batch.prefill_cu_outlens = None - batch.prefill_head_indices = None - batch.prefill_next_token_indices = None + if prefill and finished_prefilling: + # We do not need prefill tensors anymore + batch.cu_seqlen_prefill = None + batch.prefill_cache_indices = None + batch.prefill_cu_outlens = None + batch.prefill_head_indices = None + batch.prefill_next_token_indices = None forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode return 
generations, batch, (forward_ns, decode_ns) - - def _forward_context( - self, - *, - block_tables: torch.Tensor, - cu_seqlen_prefill: Optional[torch.Tensor], - input_lengths_tensor: torch.Tensor, - prefix_lens_tensor: torch.Tensor, - state: Optional[Any] = None, - ) -> ContextManager: - if ATTENTION != "flashinfer": - return nullcontext() - - from text_generation_server.layers.attention.flashinfer import ( - use_decode_state, - use_prefill_with_paged_kv_state, - ) - - # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens) - - if cu_seqlen_prefill is not None: - return use_prefill_with_paged_kv_state( - state=( - state if state is not None else self.prefill_with_paged_kv_state - ), - # block_tables=block_tables_to_ragged( - # block_tables=block_tables, - # input_lengths=input_lengths, - # prefix_lens=prefix_lens, - # ), - block_tables=block_tables, - cu_seqlens=cu_seqlen_prefill, - input_lengths=input_lengths_tensor + prefix_lens_tensor, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - head_size=self.head_size, - page_size=BLOCK_SIZE, - dtype=self.dtype, - window_left=self.sliding_window, - ) - else: - assert input_lengths_tensor is not None - return use_decode_state( - state=state if state is not None else self.decode_state, - input_lengths=input_lengths_tensor + prefix_lens_tensor, - block_tables=block_tables, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - head_size=self.head_size, - page_size=BLOCK_SIZE, - dtype=self.dtype, - window_left=self.sliding_window, - ) - - -def block_tables_to_ragged( - *, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int] -) -> torch.Tensor: - """Convert block table to ragged format compatible with FlashInfer.""" - assert len(input_lengths) == len(prefix_lens) - - total_len = sum(input_lengths) + sum(prefix_lens) - block_tables_ragged = torch.empty( - total_len, dtype=torch.int32, device=block_tables.device - ) - - offset = 0 - for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)): - seq_len = prefix_len + input_length - block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len] - offset += seq_len - - return block_tables_ragged diff --git a/backends/gaudi/server/text_generation_server/models/globals.py b/backends/gaudi/server/text_generation_server/models/globals.py index 30a5d3da..cd221e14 100644 --- a/backends/gaudi/server/text_generation_server/models/globals.py +++ b/backends/gaudi/server/text_generation_server/models/globals.py @@ -1,53 +1,31 @@ -import torch import os from typing import Dict, Optional from loguru import logger from text_generation_server.utils.log import log_master +REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"} ATTENTION = os.getenv("ATTENTION", "default") # default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0" PREFIX_CACHING = os.getenv("PREFIX_CACHING", "0").lower() in { "1", "true", } -PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "0").lower() in {"1", "true"} log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") -_expected = {"paged", "flashdecoding", "flashinfer", "default"} +_expected = {"paged", "default"} assert ( ATTENTION in _expected ), f"Attention is not valid {ATTENTION}, expected {_expected}" log_master(logger.info, f"Using Attention = {ATTENTION}") -if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}: - raise RuntimeError("Prefix caching is only supported with flashinfer") - -MEM_POOL = 
torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90")) assert TGI_WIGGLE_ROOM > 0 assert TGI_WIGGLE_ROOM < 1 # This is overridden by the cli BLOCK_SIZE: int -if ATTENTION == "flashdecoding": - BLOCK_SIZE = 256 -elif ATTENTION == "flashinfer": - BLOCK_SIZE = 1 -else: - BLOCK_SIZE = 16 -# This is overridden by the cli -cuda_graphs = os.getenv("CUDA_GRAPHS") -if cuda_graphs is not None: - try: - cuda_graphs = [int(item) for item in cuda_graphs.split(",")] - except Exception as e: - raise RuntimeError( - f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}" - ) -else: - cuda_graphs = None +BLOCK_SIZE = 128 -CUDA_GRAPHS = cuda_graphs # This is overridden at model loading. global MODEL_ID diff --git a/backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py b/backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py index 9a7a6fe1..98d7352a 100644 --- a/backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py @@ -34,9 +34,6 @@ from text_generation_server.utils import ( ) from text_generation_server.utils.quantization import get_loader -from text_generation_server.utils.import_utils import SYSTEM - - tracer = trace.get_tracer(__name__) @@ -596,22 +593,8 @@ class IdeficsCausalLM(Model): ): self.quantize = quantize self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - # 9b seems to work correctly enough in float16, but 80b seems - # to be really saturating for f16. - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - # Float16 doesn't exist on target. - dtype = torch.bfloat16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype + device = torch.device("hpu") + dtype = torch.bfloat16 if dtype is None else dtype self.device, self.dtype = device, dtype config = AutoConfig.from_pretrained( diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index 9e19e171..507dabee 100644 --- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -11,10 +11,6 @@ from transformers import ( ) from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM from text_generation_server.pb import generate_pb2 -from text_generation_server.models.flash_causal_lm import ( - block_tables_to_ragged, -) -from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION from text_generation_server.layers.attention import Seqlen @@ -254,104 +250,43 @@ class MllamaCausalLM(VlmCausalLM): # This makes sure the max_s for the decode pass is correct. 
max_s = min(self.max_past(), max_s) - bs = input_ids.shape[0] - # Try to find an associated cuda graph - bs = input_ids.shape[0] - sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) - if sorted_padded_bs: - # Get associated cuda graph - cuda_graph = self.cuda_graphs[sorted_padded_bs[0]] - else: - cuda_graph = None - if ( - cu_seqlen_prefill is not None - or cuda_graph is None - # Only run cuda graphs when there's no images. - or batch.cross_attention_states is not None - ): - input_lengths = input_lengths + prefix_lens_tensor - if PREFIX_CACHING: - block_tables = block_tables_to_ragged( - block_tables=block_tables, - input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, - ) - with self._forward_context( - block_tables=block_tables, - cu_seqlen_prefill=cu_seqlen_prefill, - input_lengths_tensor=input_lengths, - prefix_lens_tensor=prefix_lens_tensor, - ): - max_k = (input_lengths + prefix_lens_tensor).max().item() - seqlen = Seqlen( - input_lengths=input_lengths, - prefix_lengths=prefix_lens_tensor, - cu_seqlen_q=cu_seqlen_prefill, - max_q=max_s, - max_k=max_k, - ) + input_lengths = input_lengths + prefix_lens_tensor + max_k = (input_lengths + prefix_lens_tensor).max().item() + seqlen = Seqlen( + input_lengths=input_lengths, + prefix_lengths=prefix_lens_tensor, + cu_seqlen_q=cu_seqlen_prefill, + max_q=max_s, + max_k=max_k, + ) - if batch.pixel_values is not None: - cross_attention_states = self.model.vision_forward( - pixel_values=batch.pixel_values, - aspect_ratio_ids=batch.aspect_ratio_ids, - aspect_ratio_mask=batch.aspect_ratio_mask, - ) - batch.cross_attention_states = cross_attention_states - - cross_attention_states = batch.cross_attention_states - - logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, - block_tables=block_tables, - slots=slots, - seqlen=seqlen, - max_s=max_s, - prefill_cache_indices=batch.prefill_cache_indices, - lm_head_indices=lm_head_indices, - cross_attention_states=cross_attention_states, - adapter_data=adapter_data, - image_indices=batch.image_indices[:], - ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None - if batch.pixel_values is not None: - batch.pixel_values = None - return logits, speculative_logits - - # Copy inputs to the static inputs of the cuda graph - # Static inputs are potentially padded - cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids - cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids - if ATTENTION == "flashinfer": - block_tables = block_tables_to_ragged( - block_tables=block_tables, - input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, + if batch.pixel_values is not None: + cross_attention_states = self.model.vision_forward( + pixel_values=batch.pixel_values, + aspect_ratio_ids=batch.aspect_ratio_ids, + aspect_ratio_mask=batch.aspect_ratio_mask, ) - cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables - else: - cuda_graph["block_tables"][ - : block_tables.shape[0], : block_tables.shape[1] - ] = block_tables - cuda_graph["slots"].fill_(0) - cuda_graph["slots"][: slots.shape[0]] = slots - cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: input_lengths.shape[0]] = ( - input_lengths + prefix_lens_tensor - ) + batch.cross_attention_states = cross_attention_states - # Replay the graph - cuda_graph["graph"].replay() + cross_attention_states = batch.cross_attention_states - # Slice output to the 
correct shape - speculative_logits = ( - cuda_graph["speculative_logits"][:bs] - if cuda_graph["speculative_logits"] is not None - else None + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + prefill_cache_indices=batch.prefill_cache_indices, + lm_head_indices=lm_head_indices, + cross_attention_states=cross_attention_states, + adapter_data=adapter_data, + image_indices=batch.image_indices[:], ) - logits = cuda_graph["logits"][:bs] + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + if batch.pixel_values is not None: + batch.pixel_values = None return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/model.py b/backends/gaudi/server/text_generation_server/models/model.py index 4fda2271..66c69bc1 100644 --- a/backends/gaudi/server/text_generation_server/models/model.py +++ b/backends/gaudi/server/text_generation_server/models/model.py @@ -33,6 +33,7 @@ class Model(ABC): sliding_window: Optional[int] = None, speculate: Optional[int] = None, adapter_id: str = BASE_MODEL_ADAPTER_ID, + support_chunking: bool = False, ): self.model_id = model_id self.model = model.eval() diff --git a/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py b/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py index 04d4c28b..7a63d4dd 100644 --- a/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py +++ b/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py @@ -10,7 +10,6 @@ from transformers import ( AutoConfig, ) from typing import Optional, Tuple, List, Type, Dict -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils import ( initialize_torch_distributed, weight_files, @@ -555,20 +554,9 @@ class Seq2SeqLM(Model): ): self.quantize = quantize self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = default_dtype if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = default_dtype if dtype is None else dtype - else: - device = torch.device("cpu") - # Float16 doesn't exist on target. 
- dtype = torch.bfloat16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype + + device = torch.device("hpu") + dtype = torch.bfloat16 if dtype is None else dtype config = config_class.from_pretrained( model_id, diff --git a/backends/gaudi/server/text_generation_server/utils/dist.py b/backends/gaudi/server/text_generation_server/utils/dist.py index 0e9b97fb..1c45713e 100644 --- a/backends/gaudi/server/text_generation_server/utils/dist.py +++ b/backends/gaudi/server/text_generation_server/utils/dist.py @@ -1,15 +1,13 @@ import os import torch - +from torch.distributed import ProcessGroup from datetime import timedelta from loguru import logger # Tensor Parallelism settings RANK = int(os.getenv("RANK", "0")) WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) - -# CUDA memory fraction -MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0")) +MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.8")) class FakeBarrier: @@ -17,10 +15,11 @@ class FakeBarrier: pass -class FakeGroup: +class FakeGroup(ProcessGroup): def __init__(self, rank, size): self._rank = rank self._size = size + super().__init__(rank, size) def allreduce(self, *args, **kwargs): return FakeBarrier() @@ -42,42 +41,11 @@ class FakeGroup: def rank(self): return self._rank + def _get_backend_name(self): + return "fake" + def initialize_torch_distributed(): - - world_size = int(os.getenv("WORLD_SIZE", "1")) - - options = None - if torch.cuda.is_available(): - from torch.distributed import ProcessGroupNCCL - - # Set the device id. - assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one gpu" - device = RANK % torch.cuda.device_count() - torch.cuda.set_device(device) - torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device) - backend = "nccl" - options = ProcessGroupNCCL.Options() - options.is_high_priority_stream = True - options._timeout = timedelta(seconds=60) - elif torch.hpu.is_available(): - backend = "hccl" - n_hpus = torch.hpu.device_count() - if world_size > n_hpus: - raise ValueError( - f"WORLD_SIZE ({world_size}) is higher than the number of available HPUs ({n_hpus})." - ) - else: - try: - import oneccl_bindings_for_pytorch # noqa: F401 - - backend = "ccl" - if os.getenv("CCL_WORKER_COUNT", None) is None: - os.environ["CCL_WORKER_COUNT"] = str(1) - except ImportError: - backend = "gloo" - options = None - if WORLD_SIZE == 1: return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE else: @@ -87,11 +55,10 @@ def initialize_torch_distributed(): if not torch.distributed.is_initialized(): # Call the init process. 
torch.distributed.init_process_group( - backend=backend, + backend="hccl", world_size=WORLD_SIZE, rank=RANK, - timeout=timedelta(seconds=60), - pg_options=options, + timeout=timedelta(seconds=120), ) else: logger.warning("torch.distributed is already initialized.") diff --git a/backends/gaudi/server/text_generation_server/utils/import_utils.py b/backends/gaudi/server/text_generation_server/utils/import_utils.py index 782b4f15..22560dd7 100644 --- a/backends/gaudi/server/text_generation_server/utils/import_utils.py +++ b/backends/gaudi/server/text_generation_server/utils/import_utils.py @@ -1,75 +1,28 @@ import torch from loguru import logger -import os -import importlib.util +def get_hpu_free_memory(device, memory_fraction): + from habana_frameworks.torch.hpu import memory_stats - -def is_ipex_available(): - return importlib.util.find_spec("intel_extension_for_pytorch") is not None - - -def get_cuda_free_memory(device, memory_fraction): - total_free_memory, _ = torch.cuda.mem_get_info(device) - total_gpu_memory = torch.cuda.get_device_properties(device).total_memory - free_memory = max(0, total_free_memory - (1 - memory_fraction) * total_gpu_memory) - return free_memory - - -def get_xpu_free_memory(device, memory_fraction): - total_memory = torch.xpu.get_device_properties(device).total_memory device_id = device.index - memory_fraction = float(os.getenv("XPU_MEMORY_FRACTION", "1.0")) + mem_stats = memory_stats(device_id) + logger.info(f"mem_stats: {mem_stats}") + total_free_memory = mem_stats["Limit"] - mem_stats["MaxInUse"] free_memory = max( - 0, - int( - total_memory * 0.9 * memory_fraction - torch.xpu.memory_reserved(device_id) - ), + 0, int(total_free_memory - (1 - memory_fraction) * mem_stats["Limit"]) ) return free_memory -def get_cpu_free_memory(device, memory_fraction): - import psutil - from text_generation_server.utils.dist import WORLD_SIZE - - mem = psutil.virtual_memory() - free_memory = int(mem.available * 0.95 / WORLD_SIZE) - return free_memory +def synchronize_hpu(device): + torch.hpu.synchronize() def noop(*args, **kwargs): pass -SYSTEM = None -if torch.version.hip is not None: - SYSTEM = "rocm" - empty_cache = torch.cuda.empty_cache - synchronize = torch.cuda.synchronize - get_free_memory = get_cuda_free_memory -elif torch.version.cuda is not None and torch.cuda.is_available(): - SYSTEM = "cuda" - empty_cache = torch.cuda.empty_cache - synchronize = torch.cuda.synchronize - get_free_memory = get_cuda_free_memory -elif is_ipex_available(): - SYSTEM = "ipex" - import intel_extension_for_pytorch # noqa: F401 - - if hasattr(torch, "xpu") and torch.xpu.is_available(): - empty_cache = torch.xpu.empty_cache - synchronize = torch.xpu.synchronize - get_free_memory = get_xpu_free_memory - else: - empty_cache = noop - synchronize = noop - get_free_memory = get_cpu_free_memory -else: - SYSTEM = "cpu" - - empty_cache = noop - synchronize = noop - get_free_memory = get_cpu_free_memory -logger.info(f"Detected system {SYSTEM}") +empty_cache = noop +synchronize = synchronize_hpu +get_free_memory = get_hpu_free_memory diff --git a/backends/gaudi/server/text_generation_server/utils/kernels.py b/backends/gaudi/server/text_generation_server/utils/kernels.py new file mode 100644 index 00000000..42745c71 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/utils/kernels.py @@ -0,0 +1,22 @@ +import importlib + +from loguru import logger +from hf_kernels import load_kernel as hf_load_kernel + +from text_generation_server.utils.log import log_once + + +def load_kernel(*, module: str, 
repo_id: str): + """ + Load a kernel. First try to load it as the given module (e.g. for + local development), falling back to a locked Hub kernel. + """ + try: + m = importlib.import_module(module) + log_once(logger.info, f"Using local module for `{module}`") + return m + except ModuleNotFoundError: + return hf_load_kernel(repo_id=repo_id) + + +__all__ = ["load_kernel"] diff --git a/backends/gaudi/server/text_generation_server/utils/weights.py b/backends/gaudi/server/text_generation_server/utils/weights.py index 75e01f7c..acd598d7 100644 --- a/backends/gaudi/server/text_generation_server/utils/weights.py +++ b/backends/gaudi/server/text_generation_server/utils/weights.py @@ -7,8 +7,6 @@ from typing import Dict, List, Optional, Union, Type from safetensors import safe_open from dataclasses import dataclass -from text_generation_server.utils.import_utils import SYSTEM - class WeightsLoader(ABC): """ @@ -88,12 +86,9 @@ class UnquantizedWeight(Weight): weight: torch.Tensor def get_linear(self, bias: torch.Tensor): - from text_generation_server.layers.linear import FastLinear, FastLinearROCm + from text_generation_server.layers.linear import FastLinear - if SYSTEM == "rocm": - return FastLinearROCm(self.weight, bias) - else: - return FastLinear(self.weight, bias) + return FastLinear(self.weight, bias) class DefaultWeightsLoader(WeightsLoader): @@ -197,7 +192,7 @@ class Weights: slice_ = f.get_slice(tensor_name) return slice_ - def _has_tensor(self, tensor_name: str): + def has_tensor(self, tensor_name: str): try: self.get_filename(tensor_name) except Exception: @@ -207,7 +202,9 @@ class Weights: def get_shape(self, tensor_name: str): return self._get_slice(tensor_name).get_shape() - def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True): + def get_tensor( + self, tensor_name: str, to_device: bool = True, to_dtype: bool = True + ) -> torch.Tensor: filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) tensor = f.get_tensor(tensor_name) @@ -218,6 +215,7 @@ class Weights: tensor.dtype not in [ torch.float8_e4m3fn, + torch.int8, torch.int16, torch.int32, torch.int64, @@ -253,7 +251,8 @@ class Weights: # u4 which are disguised as int32. exl2 uses int16. # FP8 uses torch.float8_e4m3fn. 
if ( - tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32) + tensor.dtype + not in (torch.float8_e4m3fn, torch.int8, torch.int16, torch.int32) and to_dtype ): tensor = tensor.to(dtype=self.dtype) @@ -329,6 +328,7 @@ class Weights: tensor.dtype not in [ torch.float8_e4m3fn, + torch.int8, torch.int16, torch.int32, torch.int64, From b7fea6fc2f58dc06076f1018a4d887ad65f13357 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Fri, 14 Mar 2025 18:01:58 -0700 Subject: [PATCH 02/35] fix TP in pageattn Signed-off-by: Wang, Yi A --- backends/gaudi/server/text_generation_server/cli.py | 2 +- launcher/src/main.rs | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/cli.py b/backends/gaudi/server/text_generation_server/cli.py index 700f763e..24d1d748 100644 --- a/backends/gaudi/server/text_generation_server/cli.py +++ b/backends/gaudi/server/text_generation_server/cli.py @@ -112,7 +112,7 @@ def serve( logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype)) - if sharded: + if sharded and os.getenv("ATTENTION", "default") not in {"paged"}: tgi_file = Path(__file__).resolve().parent / "tgi_service.py" num_shard = int(os.getenv("WORLD_SIZE", "1")) logger.info("CLI SHARDED = {}".format(num_shard)) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 321d7c69..d9c41346 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1532,7 +1532,10 @@ fn spawn_shards( ) -> Result<(), LauncherError> { // Start shard processes for rank in 0..num_shard { - if rank != 0 && env_runtime::Env::new().is_hpu_device() { + if rank != 0 + && env_runtime::Env::new().is_hpu_device() + && std::env::var("ATTENTION").as_deref() != Ok("paged") + { tracing::info!("Running on HPU, the launcher will not do any sharding as actual sharding is done in the server"); break; } From 5d3653943c7796303972d93dd323a1d137adf222 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Sun, 16 Mar 2025 19:40:40 -0700 Subject: [PATCH 03/35] adjust block table in hpu to improve performance Signed-off-by: Wang, Yi A --- backends/v3/src/block_allocator.rs | 12 +++++- router/src/usage_stats.rs | 59 ++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index e7f3d85a..6da2b51d 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use tokio::sync::{mpsc, oneshot}; use crate::radix::RadixAllocator; - +use text_generation_router::usage_stats::Env; #[derive(Debug, Clone)] pub struct BlockAllocation { pub allocation_id: u64, @@ -141,6 +141,7 @@ pub struct SimpleAllocator { free_blocks: Vec, block_size: u32, window_size: Option, + is_hpu_device: bool, } impl SimpleAllocator { @@ -150,6 +151,7 @@ impl SimpleAllocator { // Block 0 is reserved for health checks free_blocks: (1..blocks).collect(), window_size, + is_hpu_device: Env::new().is_hpu_device(), } } } @@ -179,9 +181,15 @@ impl Allocator for SimpleAllocator { if required_blocks > self.free_blocks.len() as u32 { None } else { - let blocks = self + if self.is_hpu_device { + self.free_blocks.sort_by(|a, b| b.cmp(a)); + } + let mut blocks = self .free_blocks .split_off(self.free_blocks.len() - required_blocks as usize); + if self.is_hpu_device { + blocks.sort(); + } let mut slots = Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize); diff --git a/router/src/usage_stats.rs b/router/src/usage_stats.rs index 
353e9e37..a17aade9 100644 --- a/router/src/usage_stats.rs +++ b/router/src/usage_stats.rs @@ -157,6 +157,7 @@ pub struct Env { docker_label: &'static str, nvidia_info: Option>, xpu_info: Option>, + hpu_info: Option>, system_env: SystemInfo, } @@ -289,6 +290,60 @@ impl XpuSmiInfo { } } +#[derive(Debug, Serialize, Clone)] +struct HpuSmiInfo { + name: String, + pci_bus_id: String, + driver_version: String, + temperature: String, + utilization: String, + memory_total: String, + memory_free: String, + memory_used: String, + power_draw_instant: String, +} + +impl HpuSmiInfo { + fn new() -> Option> { + let output = Command::new("hl-smi") + .args([ + "--query-aip=name,bus_id,driver_version,temperature.aip,utilization.aip,memory.total,memory.free,memory.used,power.draw", + "--format=csv" + ]) + .output() + .ok()?; + + if !output.status.success() { + return None; + } + + let stdout = String::from_utf8(output.stdout).ok()?; + + let mut rdr = ReaderBuilder::new() + .has_headers(true) + .from_reader(stdout.as_bytes()); + + let mut infos = Vec::new(); + + for result in rdr.records() { + let record = result.ok()?; + infos.push(HpuSmiInfo { + name: record[0].to_string(), + pci_bus_id: record[1].to_string(), + driver_version: record[2].to_string(), + temperature: record[3].to_string(), + utilization: record[4].to_string(), + memory_total: record[5].to_string(), + memory_free: record[6].to_string(), + memory_used: record[7].to_string(), + power_draw_instant: record[8].to_string(), + }); + } + + Some(infos) + } +} + #[derive(Serialize, Debug, Clone)] pub struct SystemInfo { cpu_count: usize, @@ -335,10 +390,14 @@ impl Env { system_env: SystemInfo::new(), nvidia_info: NvidiaSmiInfo::new(), xpu_info: XpuSmiInfo::new(), + hpu_info: HpuSmiInfo::new(), git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"), docker_label: option_env!("DOCKER_LABEL").unwrap_or("N/A"), } } + pub fn is_hpu_device(&self) -> bool { + self.hpu_info.is_some() + } } pub fn is_container() -> io::Result { From a07e7437b6281fb104b15553a4696a2450a6a201 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Sun, 16 Mar 2025 22:37:34 -0700 Subject: [PATCH 04/35] enable all the model. 
not testet yet Signed-off-by: Wang, Yi A --- .../custom_modeling/flash_cohere_modeling.py | 38 ++++++++---- .../custom_modeling/flash_dbrx_modeling.py | 41 ++++++++----- .../flash_deepseek_v2_modeling.py | 38 +++++++----- .../flash_deepseek_v3_modeling.py | 37 ++++++++---- .../custom_modeling/flash_gemma2_modeling.py | 37 +++++++----- .../custom_modeling/flash_gemma_modeling.py | 36 +++++++---- .../custom_modeling/flash_gpt2_modeling.py | 34 +++++++---- .../custom_modeling/flash_gptj_modeling.py | 38 +++++++----- .../custom_modeling/flash_llama_modeling.py | 8 +-- .../custom_modeling/flash_mistral_modeling.py | 28 ++++----- .../custom_modeling/flash_mixtral_modeling.py | 28 ++++----- .../custom_modeling/flash_neox_modeling.py | 38 +++++++----- .../custom_modeling/flash_phi_modeling.py | 35 +++++++---- .../custom_modeling/flash_qwen2_modeling.py | 29 +++++---- .../custom_modeling/flash_rw_modeling.py | 59 ++++++++++++------- .../flash_santacoder_modeling.py | 32 ++++++---- .../flash_starcoder2_modeling.py | 28 ++++----- .../models/flash_causal_lm.py | 40 ++++++------- 18 files changed, 374 insertions(+), 250 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 44df7964..8d32032d 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -29,6 +29,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers import ( @@ -221,7 +222,8 @@ class FlashCohereAttention(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) query, key, value = qkv.split( @@ -245,9 +247,16 @@ class FlashCohereAttention(torch.nn.Module): self.rotary_emb(query, key, cos, sin) + if prefill_cache_indices is not None: + key_to_cache = key[prefill_cache_indices] + value_to_cache = value[prefill_cache_indices] + else: + key_to_cache = key + value_to_cache = value + kv_cache.store( - key=key, - value=value, + key=key_to_cache, + value=value_to_cache, slots=slots, kv_scales=self.kv_scales, ) @@ -274,8 +283,8 @@ class FlashCohereAttention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj( @@ -350,7 +359,8 @@ class FlashCohereLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -364,7 +374,8 @@ class FlashCohereLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) mlp_output = self.mlp(normed_hidden_states) @@ -416,15 +427,14 @@ class FlashCohereModel(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: torch.Tensor, - max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, 
max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None @@ -439,7 +449,8 @@ class FlashCohereModel(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -480,8 +491,8 @@ class FlashCohereForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -493,7 +504,8 @@ class FlashCohereForCausalLM(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index ba86f579..c01bd1bc 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -27,6 +27,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( FastLinear, @@ -312,7 +313,8 @@ class DbrxAttention(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) if self.clip_qkv is not None: @@ -329,10 +331,14 @@ class DbrxAttention(torch.nn.Module): kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv kv_cache.store( - key=kv[:, 0], - value=kv[:, 1], + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -359,8 +365,8 @@ class DbrxAttention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -397,7 +403,8 @@ class DbrxNormAttentionNorm(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): normed_hidden_states, res = self.norm_1(hidden_states, residual) @@ -411,7 +418,8 @@ class DbrxNormAttentionNorm(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) # faster post attention rms norm @@ -631,7 +639,8 @@ class DbrxLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): # Self Attention attn_output, attn_res = self.attn( @@ -644,7 +653,8 @@ class DbrxLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) moe_output = self.moe(attn_output) @@ -688,15 +698,14 @@ class DbrxModel(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Get rotary 
cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -710,7 +719,8 @@ class DbrxModel(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -743,8 +753,8 @@ class FlashDbrxForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -756,7 +766,8 @@ class FlashDbrxForCausalLM(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index e30510b4..3298a30a 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -33,6 +33,7 @@ from text_generation_server.layers.attention import ( Seqlen, attention, paged_attention, + HPUPagedAttentionMetadata, ) from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales from text_generation_server.layers.layernorm import FastRMSNorm @@ -258,7 +259,8 @@ class DeepseekV2Attention(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): if self.q_lora_rank is None: query = self.q_proj(hidden_states) @@ -314,10 +316,15 @@ class DeepseekV2Attention(torch.nn.Module): value = torch.nn.functional.pad( value, (0, self.head_pad_size - self.value_head_size), value=0 ) - + if prefill_cache_indices is not None: + key_to_cache = key[prefill_cache_indices] + value_to_cache = value[prefill_cache_indices] + else: + key_to_cache = key + value_to_cache = value kv_cache.store( - key=key, - value=value, + key=key_to_cache, + value=value_to_cache, slots=slots, kv_scales=self.kv_scales, ) @@ -344,8 +351,8 @@ class DeepseekV2Attention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) # Remove padding. 
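The attention rewrites in this patch (Cohere, DBRX, DeepseekV2/V3, and the remaining flash models below) all follow one pattern: forward() takes prefill_cache_indices and hpu_attention_meta instead of max_s, and only the tokens selected by prefill_cache_indices are written to the paged KV cache. A minimal standalone sketch of that caching step, assuming kv_cache is simply any object exposing a store(key=..., value=..., slots=...) method rather than the real text_generation_server KVCache:

    from typing import Optional

    import torch


    def store_prefill_kv(
        kv_cache,                      # stand-in for the paged KV cache object
        key: torch.Tensor,             # [num_tokens, num_kv_heads, head_size]
        value: torch.Tensor,           # [num_tokens, num_kv_heads, head_size]
        slots: torch.Tensor,           # cache slot assigned to each cached token
        prefill_cache_indices: Optional[torch.Tensor] = None,
    ) -> None:
        # When set, prefill_cache_indices selects the subset of prefill tokens
        # whose K/V should be persisted (e.g. only the sliding-window tail).
        if prefill_cache_indices is not None:
            key = key[prefill_cache_indices]
            value = value[prefill_cache_indices]
        kv_cache.store(key=key, value=value, slots=slots)

Slicing before the store keeps the cache write proportional to the number of tokens that actually need caching rather than the full prefill length.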
@@ -508,7 +515,8 @@ class DeepseekV2Layer(nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): normed_hidden_states, residual = self.input_layernorm(hidden_states, residual) @@ -522,7 +530,8 @@ class DeepseekV2Layer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) # faster post attention rms norm @@ -571,15 +580,14 @@ class DeepseekV2Model(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -593,7 +601,8 @@ class DeepseekV2Model(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -623,8 +632,8 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -636,7 +645,8 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py index 452fe3f2..736e0c9a 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py @@ -33,6 +33,7 @@ from text_generation_server.layers.attention import ( Seqlen, attention, paged_attention, + HPUPagedAttentionMetadata, ) from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales from text_generation_server.layers.layernorm import FastRMSNorm @@ -258,7 +259,8 @@ class DeepseekV3Attention(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): if self.q_lora_rank is None: query = self.q_proj(hidden_states) @@ -315,9 +317,15 @@ class DeepseekV3Attention(torch.nn.Module): value, (0, self.head_pad_size - self.value_head_size), value=0 ) + if prefill_cache_indices is not None: + key_to_cache = key[prefill_cache_indices] + value_to_cache = value[prefill_cache_indices] + else: + key_to_cache = key + value_to_cache = value kv_cache.store( - key=key, - value=value, + key=key_to_cache, + value=value_to_cache, slots=slots, kv_scales=self.kv_scales, ) @@ -344,8 +352,8 @@ class DeepseekV3Attention(torch.nn.Module): self.softmax_scale, 
block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) # Remove padding. @@ -517,7 +525,8 @@ class DeepseekV3Layer(nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): normed_hidden_states, residual = self.input_layernorm(hidden_states, residual) @@ -531,7 +540,8 @@ class DeepseekV3Layer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) # faster post attention rms norm @@ -580,15 +590,14 @@ class DeepseekV3Model(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -602,7 +611,8 @@ class DeepseekV3Model(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -632,8 +642,8 @@ class FlashDeepseekV3ForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -645,7 +655,8 @@ class FlashDeepseekV3ForCausalLM(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index ebf1b80e..5b7adad1 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -29,6 +29,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -237,8 +238,9 @@ class FlashGemma2Attention(torch.nn.Module): block_tables, slots, seqlen, - max_s, adapter_data, + prefill_cache_indices, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states, adapter_data) query, kv = qkv.split( @@ -252,10 +254,14 @@ class FlashGemma2Attention(torch.nn.Module): kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv kv_cache.store( - key=kv[:, 0], - value=kv[:, 1], + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -284,9 +290,9 @@ class FlashGemma2Attention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, softcap=self.softcap, 
kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj( @@ -399,8 +405,9 @@ class FlashGemma2Layer(nn.Module): block_tables, slots, seqlen, - max_s, adapter_data, + prefill_cache_indices, + hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -414,8 +421,9 @@ class FlashGemma2Layer(nn.Module): block_tables, slots, seqlen, - max_s, adapter_data, + prefill_cache_indices, + hpu_attention_meta, ) # faster post attention rms norm @@ -467,16 +475,15 @@ class FlashGemma2Model(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - adapter_data: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor], + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = inputs_embeds # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -490,8 +497,9 @@ class FlashGemma2Model(torch.nn.Module): block_tables, slots, seqlen, - max_s, adapter_data, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -538,8 +546,8 @@ class FlashGemma2ForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -552,8 +560,9 @@ class FlashGemma2ForCausalLM(torch.nn.Module): block_tables, slots, seqlen, - max_s, adapter_data, + prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index ad3be80e..d26184b6 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -29,6 +29,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -209,7 +210,8 @@ class FlashGemmaAttention(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) query, kv = qkv.split( @@ -224,9 +226,14 @@ class FlashGemmaAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv + kv_cache.store( - key=kv[:, 0], - value=kv[:, 1], + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -254,8 +261,8 @@ class FlashGemmaAttention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj(attn_output.view(-1, self.num_heads * 
self.head_size)) @@ -327,7 +334,8 @@ class FlashGemmaLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -341,7 +349,8 @@ class FlashGemmaLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) # faster post attention rms norm @@ -389,15 +398,14 @@ class FlashGemmaModel(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = inputs_embeds # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -411,7 +419,8 @@ class FlashGemmaModel(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -456,8 +465,8 @@ class FlashGemmaForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -470,7 +479,8 @@ class FlashGemmaForCausalLM(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 906b34c1..a6e0a7de 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -28,6 +28,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -215,7 +216,8 @@ class FlashGPT2Attention(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): query, key, value = self.query_key_value(hidden_states).split( self.head_size * self.num_heads, dim=1 @@ -224,9 +226,16 @@ class FlashGPT2Attention(torch.nn.Module): key = key.view(-1, self.num_heads, self.head_size) value = value.view(-1, self.num_heads, self.head_size) + if prefill_cache_indices is not None: + key_to_cache = key[prefill_cache_indices] + value_to_cache = value[prefill_cache_indices] + else: + key_to_cache = key + value_to_cache = value + kv_cache.store( - key=key, - value=value, + key=key_to_cache, + value=value_to_cache, slots=slots, kv_scales=self.kv_scales, ) @@ -253,8 +262,8 @@ class FlashGPT2Attention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -323,7 +332,8 @@ class FlashGPT2Layer(nn.Module): block_tables, slots, seqlen, - 
max_s, + prefill_cache_indices, + hpu_attention_meta, ): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -336,7 +346,8 @@ class FlashGPT2Layer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states = attn_output + residual @@ -389,9 +400,8 @@ class FlashGPT2Model(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = inputs_embeds @@ -405,7 +415,8 @@ class FlashGPT2Model(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states = self.norm(hidden_states) @@ -442,7 +453,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, @@ -458,9 +469,8 @@ class FlashGPT2ForCausalLM(torch.nn.Module): block_tables, slots, seqlen, - max_s, - true_max_s=max_s, prefill_cache_indices=prefill_cache_indices, + hpu_attention_meta=hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index c23aa07f..9229a453 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -29,6 +29,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -158,7 +159,8 @@ class FlashGPTJAttention(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): query, key, value = self.query_key_value(hidden_states).split( self.head_size * self.num_heads, dim=1 @@ -175,9 +177,16 @@ class FlashGPTJAttention(torch.nn.Module): else: self.rotary_emb(query, key, cos, sin) + if prefill_cache_indices is not None: + key_to_cache = key[prefill_cache_indices] + value_to_cache = value[prefill_cache_indices] + else: + key_to_cache = key + value_to_cache = value + kv_cache.store( - key=key, - value=value, + key=key_to_cache, + value=value_to_cache, slots=slots, kv_scales=self.kv_scales, ) @@ -204,8 +213,8 @@ class FlashGPTJAttention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -266,7 +275,8 @@ class FlashGPTJLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): hidden_states, residual = self.input_layernorm(hidden_states, residual) # Self Attention @@ -279,7 +289,8 @@ class FlashGPTJLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) feed_forward_hidden_states = self.mlp(hidden_states) @@ -326,16 +337,14 @@ class FlashGPTJModel(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, 
seqlen: Seqlen, - max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.wte(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -349,7 +358,8 @@ class FlashGPTJModel(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.ln_f(hidden_states, residual) @@ -380,8 +390,8 @@ class FlashGPTJForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor] = None, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -393,8 +403,8 @@ class FlashGPTJForCausalLM(torch.nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices=prefill_cache_indices, + hpu_attention_meta=hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index a118ace5..857e1757 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -206,7 +206,7 @@ class FlashLlamaAttention(torch.nn.Module): seqlen, adapter_data, prefill_cache_indices: Optional[torch.Tensor], - hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): qkv = self.query_key_value(hidden_states, adapter_data) query, kv = qkv.split( @@ -447,7 +447,7 @@ class FlashLlamaLayer(nn.Module): adapter_data, cross_attention_states, prefill_cache_indices: Optional[torch.Tensor], - hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -559,8 +559,8 @@ class FlashLlamaModel(torch.nn.Module): seqlen: Seqlen, prefill_cache_indices: Optional[torch.Tensor], adapter_data, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], cross_attention_states=None, - hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None, ) -> torch.Tensor: hidden_states = inputs_embeds @@ -646,11 +646,11 @@ class FlashLlamaForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, cross_attention_states=None, - hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py 
b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index a0116297..8214b6b7 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -31,6 +31,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -180,9 +181,9 @@ class MistralAttention(torch.nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, adapter_data, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states, adapter_data) query, kv = qkv.split( @@ -232,8 +233,8 @@ class MistralAttention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj( @@ -337,9 +338,9 @@ class MistralLayer(nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, adapter_data, + hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -353,9 +354,9 @@ class MistralLayer(nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, adapter_data, + hpu_attention_meta, ) # faster post attention rms norm @@ -405,17 +406,14 @@ class MistralModel(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], adapter_data: Optional[torch.Tensor] = None, ): hidden_states = inputs_embeds # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, true_max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -429,9 +427,9 @@ class MistralModel(torch.nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, adapter_data, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -480,13 +478,14 @@ class FlashMistralForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - true_max_s = max_s - if prefill_cache_indices is not None: + if prefill_cache_indices is not None and slots.size( + 0 + ) != prefill_cache_indices.size(0): # Slots also need to be sliced as it has the same size as the whole kv tensor slots = slots[prefill_cache_indices] elif self.max_past is not None: @@ -503,9 +502,8 @@ class FlashMistralForCausalLM(torch.nn.Module): block_tables, slots, seqlen, - max_s, - true_max_s, prefill_cache_indices, + hpu_attention_meta, adapter_data, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index a45dd1e6..18ffe060 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ 
b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -37,6 +37,7 @@ from text_generation_server.layers.attention import ( Seqlen, attention, paged_attention, + HPUPagedAttentionMetadata, ) from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.layernorm import FastRMSNorm @@ -237,8 +238,8 @@ class MixtralAttention(torch.nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) query, kv = qkv.split( @@ -288,8 +289,8 @@ class MixtralAttention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -386,8 +387,8 @@ class MixtralLayer(nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, + hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -401,8 +402,8 @@ class MixtralLayer(nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, + hpu_attention_meta, ) # faster post attention rms norm @@ -456,17 +457,14 @@ class MixtralModel(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, true_max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -480,8 +478,8 @@ class MixtralModel(torch.nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -515,13 +513,14 @@ class FlashMixtralForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - true_max_s = max_s - if prefill_cache_indices is not None: + if prefill_cache_indices is not None and slots.size( + 0 + ) != prefill_cache_indices.size(0): # Slots also need to be sliced as it has the same size as the whole kv tensor slots = slots[prefill_cache_indices] elif self.max_past is not None: @@ -537,9 +536,8 @@ class FlashMixtralForCausalLM(torch.nn.Module): block_tables, slots, seqlen, - max_s, - true_max_s, prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 2301b63c..76269f22 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -30,6 +30,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, Seqlen, + HPUPagedAttentionMetadata, ) from 
text_generation_server.layers import ( TensorParallelRowLinear, @@ -149,7 +150,8 @@ class FlashNeoxAttention(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, 3, self.num_heads, self.head_size) @@ -164,10 +166,14 @@ class FlashNeoxAttention(torch.nn.Module): self.rotary_emb(query_rot, key_rot, cos, sin) qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1) qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1) + if prefill_cache_indices is not None: + qkv_to_cache = qkv[prefill_cache_indices] + else: + qkv_to_cache = qkv kv_cache.store( - key=qkv[:, 1], - value=qkv[:, 2], + key=qkv_to_cache[:, 1], + value=qkv_to_cache[:, 2], slots=slots, kv_scales=self.kv_scales, ) @@ -194,8 +200,8 @@ class FlashNeoxAttention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) @@ -265,7 +271,8 @@ class FlashNeoXLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): if self.use_parallel_residual: ln1_hidden_states, _ = self.input_layernorm(hidden_states) @@ -279,7 +286,8 @@ class FlashNeoXLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states) @@ -303,7 +311,8 @@ class FlashNeoXLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states, residual = self.post_attention_layernorm( @@ -357,15 +366,14 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_in(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -379,7 +387,8 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.final_layer_norm(hidden_states, residual) @@ -411,7 +420,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, @@ -424,7 +433,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 7382a7cb..21c4bc71 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ 
b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -10,6 +10,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -162,7 +163,8 @@ class FlashPhiAttention(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): # Compute query, key, value and split qkv = self.query_key_value(hidden_states) @@ -188,9 +190,13 @@ class FlashPhiAttention(torch.nn.Module): ) # Reshape key and value and cache + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv kv_cache.store( - key=kv[:, 0], - value=kv[:, 1], + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -216,8 +222,8 @@ class FlashPhiAttention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) @@ -284,7 +290,8 @@ class FlashPhiLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): hidden_states, res = self.input_layernorm(hidden_states, residual) # Self Attention @@ -297,7 +304,8 @@ class FlashPhiLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states = self.resid_dropout(attn_output).add( @@ -349,15 +357,14 @@ class FlashPhiModel(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -371,7 +378,8 @@ class FlashPhiModel(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -404,8 +412,8 @@ class FlashPhiForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -417,7 +425,8 @@ class FlashPhiForCausalLM(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index d6569a1d..c62435fe 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -9,6 +9,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, Seqlen, + HPUPagedAttentionMetadata, ) from 
text_generation_server.layers import ( TensorParallelRowLinear, @@ -108,8 +109,8 @@ class Qwen2Attention(torch.nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) query, kv = qkv.split( @@ -159,8 +160,8 @@ class Qwen2Attention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -232,8 +233,8 @@ class Qwen2Layer(nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, + hpu_attention_meta, ): normed_hidden_states, residual = self.input_layernorm(hidden_states) @@ -247,8 +248,8 @@ class Qwen2Layer(nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, + hpu_attention_meta, ) hidden_states = attn_output + residual @@ -298,16 +299,13 @@ class Qwen2Model(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = inputs_embeds cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( position_ids, - true_max_s, - hidden_states.dtype, ) residual = None @@ -322,8 +320,8 @@ class Qwen2Model(torch.nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states) @@ -369,13 +367,15 @@ class Qwen2ForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor] = None, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - true_max_s = max_s - if prefill_cache_indices is not None: + + if prefill_cache_indices is not None and prefill_cache_indices.size( + 0 + ) != slots.size(0): # Slots also need to be sliced as it has the same size as the whole kv tensor slots = slots[prefill_cache_indices] elif self.max_past is not None: @@ -393,9 +393,8 @@ class Qwen2ForCausalLM(torch.nn.Module): block_tables, slots, seqlen, - max_s, - true_max_s, prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index fbf1a597..c6034bf0 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -19,6 +19,7 @@ from text_generation_server.layers.attention import ( attention, paged_attention, Seqlen, + HPUPagedAttentionMetadata, ) @@ -184,7 +185,8 @@ class FlashRWAttention(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) @@ -201,9 +203,14 @@ class FlashRWAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv + kv_cache.store( - key=kv[:, 0], - value=kv[:, 1], + 
key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -230,8 +237,8 @@ class FlashRWAttention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) @@ -305,7 +312,8 @@ class FlashRWLargeAttention(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size) @@ -321,9 +329,14 @@ class FlashRWLargeAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin) + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv + kv_cache.store( - key=kv[:, :, 0].contiguous(), - value=kv[:, :, 1].contiguous(), + key=kv_to_cache[:, :, 0].contiguous(), + value=kv_to_cache[:, :, 1].contiguous(), slots=slots, kv_scales=self.kv_scales, ) @@ -350,8 +363,8 @@ class FlashRWLargeAttention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.dense( @@ -437,7 +450,8 @@ class FlashRWLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): if self.parallel_attn: ln_hidden_states, residual = self.input_layernorm(hidden_states, residual) @@ -451,7 +465,8 @@ class FlashRWLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) mlp_output = self.mlp(ln_hidden_states) @@ -473,7 +488,8 @@ class FlashRWLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) if self.post_attention_layernorm is not None: @@ -560,7 +576,8 @@ class FlashRWLargeLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): # Layer norm. ln_attn, ln_mlp, residual = self.ln_layer(hidden_states, residual) @@ -575,7 +592,8 @@ class FlashRWLargeLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) # MLP. 
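Note: the hunks above repeat one pattern in every attention module this series touches: when prefill_cache_indices is set (sliding-window prefill with left padding), only the valid token positions of the freshly computed key/value tensor are written to the paged KV cache. A minimal standalone sketch of that gather-before-store step, using toy tensors and a stand-in store function rather than the real KVCache API (store_kv_for_prefill and the cache tensors below are hypothetical names for illustration only), could look like this:

    import torch

    def store_kv_for_prefill(kv, slots, cache_k, cache_v, prefill_cache_indices=None):
        # kv: [num_tokens, 2, num_kv_heads, head_size] computed for this forward pass
        # slots: [num_valid_tokens] flat cache slot per token that should be cached
        # prefill_cache_indices: indices of non-padded tokens; None means store everything
        if prefill_cache_indices is not None:
            kv_to_cache = kv[prefill_cache_indices]
        else:
            kv_to_cache = kv
        # stand-in for kv_cache.store(key=..., value=..., slots=..., kv_scales=...)
        cache_k[slots] = kv_to_cache[:, 0]
        cache_v[slots] = kv_to_cache[:, 1]

    # toy usage: 6 tokens, the first 2 are left padding that must not reach the cache
    num_tokens, num_kv_heads, head_size, num_slots = 6, 1, 4, 16
    kv = torch.randn(num_tokens, 2, num_kv_heads, head_size)
    cache_k = torch.zeros(num_slots, num_kv_heads, head_size)
    cache_v = torch.zeros(num_slots, num_kv_heads, head_size)
    prefill_cache_indices = torch.tensor([2, 3, 4, 5])  # skip the 2 padded positions
    slots = torch.tensor([0, 1, 2, 3])
    store_kv_for_prefill(kv, slots, cache_k, cache_v, prefill_cache_indices)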
@@ -636,15 +654,14 @@ class FlashRWModel(FlashRWPreTrainedModel): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.word_embeddings(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.h): @@ -658,7 +675,8 @@ class FlashRWModel(FlashRWPreTrainedModel): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.ln_f(hidden_states, residual) @@ -688,8 +706,8 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -701,7 +719,8 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index ed053eb6..9b24e8ba 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -9,6 +9,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -271,7 +272,8 @@ class FlashMQAttention(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): qkv = self.c_attn(hidden_states) @@ -284,9 +286,14 @@ class FlashMQAttention(torch.nn.Module): query = query.view(-1, self.num_heads, self.head_size) key_value = key_value.view(-1, 2, 1, self.head_size) + if prefill_cache_indices is not None: + key_value_to_cache = key_value[prefill_cache_indices] + else: + key_value_to_cache = key_value + kv_cache.store( - key=key_value[:, 0], - value=key_value[:, 1], + key=key_value_to_cache[:, 0], + value=key_value_to_cache[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -313,8 +320,8 @@ class FlashMQAttention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -379,7 +386,8 @@ class Block(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): hidden_states, residual = self.ln_1(hidden_states, residual) hidden_states = self.self_attn( @@ -389,7 +397,8 @@ class Block(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states, residual = self.ln_2(hidden_states, residual) @@ -443,7 +452,8 @@ class FlashSantacoderModel(nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - 
max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.wte(input_ids) + self.wpe(position_ids) @@ -460,7 +470,8 @@ class FlashSantacoderModel(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.ln_f(hidden_states, residual) @@ -492,7 +503,7 @@ class FlashSantacoderForCausalLM(nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, @@ -505,7 +516,8 @@ class FlashSantacoderForCausalLM(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 5e090369..d12bee5c 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -30,6 +30,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelMultiAdapterLinear, @@ -237,9 +238,9 @@ class Starcoder2Attention(torch.nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, adapter_data, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states, adapter_data) query, kv = qkv.split( @@ -289,8 +290,8 @@ class Starcoder2Attention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj( @@ -450,9 +451,9 @@ class Starcoder2Layer(nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, adapter_data, + hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -466,9 +467,9 @@ class Starcoder2Layer(nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, adapter_data, + hpu_attention_meta, ) # faster post attention rms norm @@ -520,18 +521,15 @@ class Starcoder2Model(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], adapter_data, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, true_max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -545,9 +543,9 @@ class Starcoder2Model(torch.nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, adapter_data, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -594,13 +592,14 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: 
Optional[HPUPagedAttentionMetadata], prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - true_max_s = max_s - if prefill_cache_indices is not None: + if prefill_cache_indices is not None and slots.size( + 0 + ) != prefill_cache_indices.size(0): # Slots also need to be sliced as it has the same size as the whole kv tensor slots = slots[prefill_cache_indices] elif self.max_past is not None: @@ -616,10 +615,9 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): block_tables, slots, seqlen, - max_s, - true_max_s, prefill_cache_indices, adapter_data, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index 49313c83..27e1c672 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -1009,24 +1009,22 @@ class FlashCausalLMBatch(Batch): # padding to left to work with sliding window # use prefill_cache_indices to indicate the valid kv slot, update prefill_next_token_indices to indicate # the right logit position - - if device.type == "hpu": - input_ids_padded = None - input_ids_padded_length = None - if isinstance(self.input_ids, list) and len(self) > 1: - input_ids_padded = [] - input_ids_padded_length = [] - for input_id in self.input_ids: - padded = self.max_input_length - len(input_id) - input_id_padded = input_id - if padded > 0: - input_id_padded = [0] * padded + input_id_padded - input_ids_padded.append(input_id_padded) - input_ids_padded_length.append(padded) - input_ids_padded = np.concatenate(input_ids_padded, dtype=np.int64) - input_ids_padded = torch.tensor( - input_ids_padded, dtype=torch.int64, device=device - ) + input_ids_padded = None + input_ids_padded_length = None + if isinstance(self.input_ids, list) and len(self) > 1: + input_ids_padded = [] + input_ids_padded_length = [] + for input_id in self.input_ids: + padded = self.max_input_length - len(input_id) + input_id_padded = input_id + if padded > 0: + input_id_padded = [0] * padded + input_id_padded + input_ids_padded.append(input_id_padded) + input_ids_padded_length.append(padded) + input_ids_padded = np.concatenate(input_ids_padded, dtype=np.int64) + input_ids_padded = torch.tensor( + input_ids_padded, dtype=torch.int64, device=device + ) if isinstance(self.input_ids, list): if len(self) > 1: @@ -1084,7 +1082,7 @@ class FlashCausalLMBatch(Batch): request_position_ids = torch.arange( cache_length, cache_length + input_length, dtype=torch.int32 ) - if device.type == "hpu" and input_ids_padded is not None: + if input_ids_padded is not None: position_ids.append( torch.ones(input_ids_padded_length[i], dtype=torch.int32) ) @@ -1111,7 +1109,7 @@ class FlashCausalLMBatch(Batch): cumulative_slot_tokens += len(request_slots) # Create tensor to slice into the kv tensor in prefill - if device.type == "hpu" and input_ids_padded is not None: + if input_ids_padded is not None: # hpu need request_prefill_cache_indices to skip padding in kv cache sliding_window = get_sliding_windows() if sliding_window is None: @@ -1235,7 +1233,7 @@ class FlashCausalLMBatch(Batch): self.prefill_head_indices = prefill_head_indices self.prefill_next_token_indices = prefill_next_token_indices - if device.type == "hpu" and input_ids_padded is 
not None: + if input_ids_padded is not None: self.input_ids = input_ids_padded input_ids_padded_length_tensor = torch.cumsum( torch.tensor(input_ids_padded_length, dtype=torch.int64, device=device), From 6bbe24d9743e4a2c7e8a02890cd3aef9cea08c1d Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Mon, 17 Mar 2025 01:36:49 -0700 Subject: [PATCH 05/35] use tensor cache in hpu graph to avoid replay issue Signed-off-by: Wang, Yi A --- .../server/text_generation_server/models/flash_causal_lm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index 27e1c672..3a0dc15e 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -1398,7 +1398,7 @@ class FlashCausalLM(Model): self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype if htorch.utils.internal.is_lazy(): - htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True) + htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=False) environment.set_model_config(self.config) self.use_contiguous_pa = ( os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true" From 5cd1c93cad96fa0c00deb7be26becfce2854084b Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Tue, 18 Mar 2025 00:45:15 -0700 Subject: [PATCH 06/35] add moe support, fix qwen/mistral/mixtral crash Signed-off-by: Wang, Yi A --- .../layers/moe/__init__.py | 2 +- .../moe/{fused_moe_ipex.py => fused_moe.py} | 0 .../layers/moe/unquantized.py | 132 ++---------------- .../custom_modeling/flash_cohere_modeling.py | 2 +- .../custom_modeling/flash_dbrx_modeling.py | 2 +- .../custom_modeling/flash_gemma2_modeling.py | 2 +- .../custom_modeling/flash_gemma_modeling.py | 2 +- .../custom_modeling/flash_gpt2_modeling.py | 2 +- .../custom_modeling/flash_gptj_modeling.py | 2 +- .../custom_modeling/flash_llama_modeling.py | 7 +- .../custom_modeling/flash_mistral_modeling.py | 10 +- .../custom_modeling/flash_mixtral_modeling.py | 10 +- .../custom_modeling/flash_neox_modeling.py | 2 +- .../custom_modeling/flash_qwen2_modeling.py | 10 +- .../custom_modeling/flash_rw_modeling.py | 2 +- .../flash_santacoder_modeling.py | 2 +- .../flash_starcoder2_modeling.py | 10 +- 17 files changed, 36 insertions(+), 163 deletions(-) rename backends/gaudi/server/text_generation_server/layers/moe/{fused_moe_ipex.py => fused_moe.py} (100%) diff --git a/backends/gaudi/server/text_generation_server/layers/moe/__init__.py b/backends/gaudi/server/text_generation_server/layers/moe/__init__.py index cba81407..8b9d6fcb 100644 --- a/backends/gaudi/server/text_generation_server/layers/moe/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/moe/__init__.py @@ -19,7 +19,7 @@ from text_generation_server.utils.weights import ( UnquantizedWeight, ) -from .fused_moe_ipex import fused_topk, grouped_topk +from .fused_moe import fused_topk, grouped_topk # NOTE: we are using a protocol here, because multiple inherance is not nice. 
# We need `Module`, and `Module` -> some abstract class -> some concrete diff --git a/backends/gaudi/server/text_generation_server/layers/moe/fused_moe_ipex.py b/backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py similarity index 100% rename from backends/gaudi/server/text_generation_server/layers/moe/fused_moe_ipex.py rename to backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py diff --git a/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py b/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py index 8cb27879..ec158398 100644 --- a/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py +++ b/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py @@ -1,11 +1,10 @@ -from typing import Callable, List, Optional +from typing import Optional import torch import torch.nn as nn from text_generation_server.utils.weights import UnquantizedWeight, Weights - -moe_kernels = None +from vllm_hpu_extension.ops import DynamicFusedMOE class UnquantizedSparseMoELayer(nn.Module): @@ -54,21 +53,13 @@ class UnquantizedSparseMoELayer(nn.Module): weights=weights, ) + self.hpu_fused_moe = DynamicFusedMOE(n_experts) + for i in range(n_experts): + self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i]) + self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.down_proj[i]) + def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: - return fused_moe( - x, - w1=self.gate_up_proj, - w2=self.down_proj, - gating_output=gating_output, - topk=self.topk, - renormalize=self.renormalize, - inplace=True, - use_grouped_topk=self.n_expert_group is not None, - num_expert_group=self.n_expert_group, - topk_group=self.topk_group, - scoring_func=self.scoring_func, - e_score_correction_bias=self.e_score_correction_bias, - ) + return self.hpu_fused_moe(x, gating_output, self.topk) def _load_expert_multi_weights_col( @@ -128,110 +119,3 @@ def _load_expert_weights_row( assert all_weight is not None return all_weight - - -def fused_moe( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - inplace: bool = False, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, - use_fp8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None, -) -> torch.Tensor: - """ - This function computes a Mixture of Experts (MoE) layer using two sets of - weights, w1 and w2, and top-k gating mechanism. - - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - w1 (torch.Tensor): The first set of expert weights. - - w2 (torch.Tensor): The second set of expert weights. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk (int): The number of top-k experts to select. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. 
- - num_expert_group: Optional[int]: additional parameter for grouped_topk - - topk_group: Optional[int]: additional parameter for grouped_topk - - use_grouped_topk: If True, use grouped_topk instead of fused_topk - note: Deepseekv2 model uses grouped_topk - - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. - - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. - - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 - activation to compute the inner products for w1 and w2. - Defaults to False. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for - w2. - - a1_scale (Optional[torch.Tensor]): Optional scale to be used for - a1. - - a2_scale (Optional[torch.Tensor]): Optional scale to be used for - a2. - - block_shape: (Optional[List[int]]): Optional block size for block-wise - quantization. - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. - """ - # Check constraints. - assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" - - if use_grouped_topk: - assert num_expert_group is not None and topk_group is not None - from loguru import logger - import inspect - - logger.info(f"{inspect.signature(moe_kernels.grouped_topk)}") - topk_weights, topk_ids = moe_kernels.grouped_topk( - hidden_states, - gating_output, - topk, - renormalize, - num_expert_group, - topk_group, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - ) - elif custom_routing_function is None: - topk_weights, topk_ids = moe_kernels.fused_topk( - hidden_states, gating_output, topk, renormalize - ) - else: - topk_weights, topk_ids = custom_routing_function( - hidden_states, gating_output, topk, renormalize - ) - - return moe_kernels.fused_experts( - hidden_states, - w1, - w2, - topk_weights, - topk_ids, - inplace=inplace, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=block_shape, - ) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 8d32032d..77dec80d 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -263,7 +263,7 @@ class FlashCohereAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( query=query, key=key, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index c01bd1bc..0f1338ca 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -345,7 +345,7 @@ class DbrxAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( query=query, key=kv[:, 0], diff --git 
a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 5b7adad1..632e8017 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -268,7 +268,7 @@ class FlashGemma2Attention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( query=query, key=kv[:, 0], diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index d26184b6..d832fb00 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -240,7 +240,7 @@ class FlashGemmaAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( query=query, key=kv[:, 0], diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index a6e0a7de..80236fe8 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -242,7 +242,7 @@ class FlashGPT2Attention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( query=query, key=key, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index 9229a453..3135acde 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -193,7 +193,7 @@ class FlashGPTJAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( query=query, key=key, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 857e1757..a0c4fb8c 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -235,7 +235,7 @@ class FlashLlamaAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( query=query, key=kv[:, 0], @@ -652,6 +652,11 @@ class FlashLlamaForCausalLM(torch.nn.Module): adapter_data: Optional[torch.Tensor] = None, cross_attention_states=None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if prefill_cache_indices is not None and slots.size( + 0 + ) != prefill_cache_indices.size(0): + # Slots also need to be sliced as it has the same size as the whole kv tensor + slots = slots[prefill_cache_indices] inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( inputs_embeds, diff 
--git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 8214b6b7..38eba082 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -212,11 +212,11 @@ class MistralAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( query=query, - key=kv_to_cache[:, 0], - value=kv_to_cache[:, 1], + key=kv[:, 0], + value=kv[:, 1], kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, @@ -488,10 +488,6 @@ class FlashMistralForCausalLM(torch.nn.Module): ) != prefill_cache_indices.size(0): # Slots also need to be sliced as it has the same size as the whole kv tensor slots = slots[prefill_cache_indices] - elif self.max_past is not None: - # Clamp in decode mode as paged attention requires clamped values whereas the flash attention - # kernel requires the true values - seqlen = seqlen.clamp(max=self.max_past_tensor) inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 18ffe060..fbcb0970 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -268,11 +268,11 @@ class MixtralAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( query=query, - key=kv_to_cache[:, 0], - value=kv_to_cache[:, 1], + key=kv[:, 0], + value=kv[:, 1], kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, @@ -523,10 +523,6 @@ class FlashMixtralForCausalLM(torch.nn.Module): ) != prefill_cache_indices.size(0): # Slots also need to be sliced as it has the same size as the whole kv tensor slots = slots[prefill_cache_indices] - elif self.max_past is not None: - # Clamp in decode mode as paged attention requires clamped values whereas the flash attention - # kernel requires the true values - seqlen = seqlen.clamp(max=self.max_past_tensor) hidden_states = self.model( input_ids, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 76269f22..d1904c03 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -180,7 +180,7 @@ class FlashNeoxAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( query=qkv[:, 0], key=qkv[:, 1], diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index c62435fe..480a17d1 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -139,11 +139,11 @@ 
class Qwen2Attention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( query=query, - key=kv_to_cache[:, 0], - value=kv_to_cache[:, 1], + key=kv[:, 0], + value=kv[:, 1], kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, @@ -378,10 +378,6 @@ class Qwen2ForCausalLM(torch.nn.Module): ) != slots.size(0): # Slots also need to be sliced as it has the same size as the whole kv tensor slots = slots[prefill_cache_indices] - elif self.max_past is not None: - # Clamp in decode mode as paged attention requires clamped values whereas the flash attention - # kernel requires the true values - seqlen = seqlen.clamp(max=self.max_past_tensor) inputs_embeds = self.embed_tokens(input_ids) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index c6034bf0..e7c4b2b6 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -217,7 +217,7 @@ class FlashRWAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( query=query, key=kv[:, 0], diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 9b24e8ba..57d4ee64 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -300,7 +300,7 @@ class FlashMQAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( query=query, key=key_value[:, 0], diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index d12bee5c..082e5d82 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -269,11 +269,11 @@ class Starcoder2Attention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( query=query, - key=kv_to_cache[:, 0], - value=kv_to_cache[:, 1], + key=kv[:, 0], + value=kv[:, 1], kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, @@ -602,10 +602,6 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): ) != prefill_cache_indices.size(0): # Slots also need to be sliced as it has the same size as the whole kv tensor slots = slots[prefill_cache_indices] - elif self.max_past is not None: - # Clamp in decode mode as paged attention requires clamped values whereas the flash attention - # kernel requires the true values - seqlen = seqlen.clamp(max=self.max_past_tensor) hidden_states = self.model( input_ids, From 073f79397629e0f49fb449e463ade2829072e85c Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Tue, 18 Mar 2025 23:11:01 -0700 Subject: [PATCH 07/35] fix phimoe issue Signed-off-by: Wang, Yi A --- .../gaudi/server/text_generation_server/layers/rotary.py | 8 ++++++++ 
.../server/text_generation_server/models/__init__.py | 3 +++ 2 files changed, 11 insertions(+) diff --git a/backends/gaudi/server/text_generation_server/layers/rotary.py b/backends/gaudi/server/text_generation_server/layers/rotary.py index 5b6cad5c..1f8a6bd1 100644 --- a/backends/gaudi/server/text_generation_server/layers/rotary.py +++ b/backends/gaudi/server/text_generation_server/layers/rotary.py @@ -188,6 +188,7 @@ class PositionRotaryEmbedding(nn.Module): long_inv_freq=long_inv_freq, scaling_factor=scaling_factor, original_max_position_embeddings=original_max_position_embeddings, + max_position_embeddings=config.max_position_embeddings, ) else: raise NotImplementedError( @@ -276,6 +277,7 @@ class SuRotaryEmbedding(PositionRotaryEmbedding): long_inv_freq, scaling_factor, original_max_position_embeddings, + max_position_embeddings, ): super(PositionRotaryEmbedding, self).__init__() self.short_inv_freq = short_inv_freq @@ -288,6 +290,9 @@ class SuRotaryEmbedding(PositionRotaryEmbedding): self._cos_k_cached = None self._sin_k_cached = None self.dynamic_args = None + self._update_cos_sin_cache( + torch.float32, short_inv_freq.device, max_position_embeddings + ) def _update_cos_sin_cache(self, dtype, device, seqlen): # Reset the tables if the sequence length has changed, @@ -341,6 +346,9 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding): self._cos_k_cached = None self._sin_k_cached = None self.dynamic_args = None + self._update_cos_sin_cache( + torch.float32, short_inv_freq.device, max_position_embeddings + ) def _update_cos_sin_cache(self, dtype, device, seqlen): if ( diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py index 926fb57a..7144542f 100644 --- a/backends/gaudi/server/text_generation_server/models/__init__.py +++ b/backends/gaudi/server/text_generation_server/models/__init__.py @@ -25,6 +25,9 @@ from text_generation_server.models.vlm_causal_lm import VlmCausalLM from text_generation_server.models.custom_modeling.llava_next import ( LlavaNextForConditionalGeneration, ) +from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import ( + PhiMoEConfig, +) # from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch # from text_generation_server.models.custom_modeling.mllama import ( From 2cde30de246ea91335539f6b3bba903a29c70021 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Tue, 18 Mar 2025 23:59:31 -0700 Subject: [PATCH 08/35] gpt_bigcode could also go pageattn Signed-off-by: Wang, Yi A --- .../gaudi/server/text_generation_server/models/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py index 7144542f..9229bcf2 100644 --- a/backends/gaudi/server/text_generation_server/models/__init__.py +++ b/backends/gaudi/server/text_generation_server/models/__init__.py @@ -486,8 +486,6 @@ def get_model( model_type = config_dict["model_type"] - if model_type == "gpt_bigcode": - return StarCoder(model_id=model_id, revision=revision, dtype=dtype) kv_cache_dtype = dtype if FLASH_ATTENTION: @@ -871,6 +869,8 @@ def get_model( trust_remote_code=trust_remote_code, ) adapt_transformers_to_gaudi() + if model_type == "gpt_bigcode": + return StarCoder(model_id=model_id, revision=revision, dtype=dtype) if model_type == "bloom": return BLOOM( model_id=model_id, From 2074d0516b5b8cad80e6ee858d770aaf4a358ce7 
Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Wed, 19 Mar 2025 03:16:41 -0700 Subject: [PATCH 09/35] enable dbrx remove some unused code Signed-off-by: Wang, Yi A --- .../text_generation_server/layers/rotary.py | 5 + .../text_generation_server/models/__init__.py | 20 - .../custom_modeling/flash_dbrx_modeling.py | 20 +- .../models/custom_modeling/mpt_modeling.py | 1215 ---------------- .../models/custom_modeling/neox_modeling.py | 796 ----------- .../models/custom_modeling/opt_modeling.py | 864 ------------ .../models/custom_modeling/phi_modeling.py | 336 ----- .../models/custom_modeling/t5_modeling.py | 1227 ----------------- 8 files changed, 12 insertions(+), 4471 deletions(-) delete mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/mpt_modeling.py delete mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/neox_modeling.py delete mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/opt_modeling.py delete mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/phi_modeling.py delete mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/t5_modeling.py diff --git a/backends/gaudi/server/text_generation_server/layers/rotary.py b/backends/gaudi/server/text_generation_server/layers/rotary.py index 1f8a6bd1..b25f9fab 100644 --- a/backends/gaudi/server/text_generation_server/layers/rotary.py +++ b/backends/gaudi/server/text_generation_server/layers/rotary.py @@ -77,6 +77,11 @@ class PositionRotaryEmbedding(nn.Module): inv_freq = _create_inv_freq(dim, base, device) scaling_factor = None rope_scaling = _get_rope_config(config) + if not hasattr(config, "max_position_embeddings") and hasattr( + config, "max_seq_len" + ): + # handling for dbrx + config.max_position_embeddings = config.max_seq_len if rope_scaling is not None: # `rope_type` is now standard in transformers, but some existing models # have `type` instead. 
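Note: the dbrx handling added to rotary.py above is an attribute fallback on the HF config object: dbrx-style configs expose max_seq_len instead of max_position_embeddings, so the rotary setup normalizes the name before reading it. A tiny sketch of the same normalization, with a SimpleNamespace standing in for the transformers config object (the helper name below is illustrative, not part of the patch), would be:

    from types import SimpleNamespace

    def normalize_max_position(config):
        # Some configs (e.g. dbrx-style) carry `max_seq_len` instead of
        # `max_position_embeddings`; copy it over so downstream rotary code
        # can always read `config.max_position_embeddings`.
        if not hasattr(config, "max_position_embeddings") and hasattr(
            config, "max_seq_len"
        ):
            config.max_position_embeddings = config.max_seq_len
        return config

    cfg = normalize_max_position(SimpleNamespace(max_seq_len=32768))
    assert cfg.max_position_embeddings == 32768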
diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py index 9229bcf2..dfdec9dc 100644 --- a/backends/gaudi/server/text_generation_server/models/__init__.py +++ b/backends/gaudi/server/text_generation_server/models/__init__.py @@ -286,16 +286,6 @@ class ModelType(enum.Enum): "name": "Qwen 2.5 VL", "url": "https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e", } - OPT = { - "type": "opt", - "name": "Opt", - "url": "https://huggingface.co/facebook/opt-6.7b", - } - T5 = { - "type": "t5", - "name": "T5", - "url": "https://huggingface.co/google/flan-t5-xxl", - } GALACTICA = { "type": "galactica", "name": "Galactica", @@ -306,16 +296,6 @@ class ModelType(enum.Enum): "name": "SantaCoder", "url": "https://huggingface.co/bigcode/santacoder", } - BLOOM = { - "type": "bloom", - "name": "Bloom", - "url": "https://huggingface.co/bigscience/bloom-560m", - } - MPT = { - "type": "mpt", - "name": "Mpt", - "url": "https://huggingface.co/mosaicml/mpt-7b-instruct", - } GPT2 = { "type": "gpt2", "name": "Gpt2", diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 0f1338ca..b335a81f 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -43,9 +43,7 @@ from text_generation_server.layers.rotary import ( from text_generation_server.layers.layernorm import ( FastLayerNorm, ) - - -moe_kernels = None +from vllm_hpu_extension.ops import DynamicFusedMOE class DbrxAttentionConfig(PretrainedConfig): @@ -497,19 +495,15 @@ class BlockSparseMoE(nn.Module): self.process_group = weights.process_group + self.hpu_fused_moe = DynamicFusedMOE(self.num_experts) + for i in range(self.num_experts): + self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.wv1[i]) + self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.w2[i]) + def forward(self, x: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits = self.gate(x) - - out = moe_kernels.fused_moe( - x, - self.wv1, - self.w2, - router_logits, - self.top_k, - renormalize=self.moe_normalize_expert_weights, - inplace=True, - ) + out = self.hpu_fused_moe(x, router_logits, self.top_k) # Reduce sum if self.process_group.size() > 1: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/mpt_modeling.py deleted file mode 100644 index 988a74a3..00000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/mpt_modeling.py +++ /dev/null @@ -1,1215 +0,0 @@ -"""A simple, flexible implementation of a GPT model. 
- -Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py -""" - -import math -import warnings -from typing import List, Optional, Tuple, Union -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from einops import rearrange -from packaging import version -from text_generation_server.layers import ( - TensorParallelEmbedding, - TensorParallelColumnLinear, - TensorParallelRowLinear, - SpeculativeHead, - get_linear, -) - -EPS = 1e-5 - - -def load_col(config, prefix, weights, bias): - assert config.quantize != "gptq", NotImplementedError - slice_ = weights._get_slice(f"{prefix}.weight") - rank = weights.process_group.rank() - size = weights.process_group.size() - - h3, h = slice_.get_shape() - block_size = h // size - - q_part = slice_[rank * block_size : (rank + 1) * block_size] - k_part = slice_[h + rank * block_size : h + (rank + 1) * block_size] - v_part = slice_[2 * h + rank * block_size : 2 * h + (rank + 1) * block_size] - - weight = torch.cat([q_part, k_part, v_part], dim=0) - if weight.dtype != torch.int32: - weight = weight.to(dtype=weights.dtype) - weight = weight.to(device=weights.device) - - if bias: - bias_slice_ = weights._get_slice(f"{prefix}.bias") - bias_rank = weights.process_group.rank() - bias_size = weights.process_group.size() - - bias_h = bias_slice_.get_shape() - bias_h = bias_h[0] - bias_block_size = bias_h // bias_size - - bias_q_part = bias_slice_[ - bias_rank * bias_block_size : (bias_rank + 1) * bias_block_size - ] - bias_k_part = bias_slice_[ - bias_h - + bias_rank * bias_block_size : bias_h - + (bias_rank + 1) * bias_block_size - ] - bias_v_part = bias_slice_[ - 2 * bias_h - + bias_rank * bias_block_size : 2 * bias_h - + (bias_rank + 1) * bias_block_size - ] - - bias = torch.cat([bias_q_part, bias_k_part, bias_v_part], dim=0) - if bias.dtype != torch.int32: - bias = bias.to(dtype=weights.dtype) - bias = bias.to(device=weights.device) - else: - bias = None - linear = get_linear(weight, bias) - return TensorParallelColumnLinear(linear) - - -def _reset_is_causal( - num_query_tokens: int, num_key_tokens: int, original_is_causal: bool -): - if original_is_causal and num_query_tokens != num_key_tokens: - if num_query_tokens != 1: - raise NotImplementedError( - "MPT does not support query and key with different number of tokens, unless number of query tokens is 1." 
- ) - else: - return False - return original_is_causal - - -def scaled_multihead_dot_product_attention( - query, - key, - value, - n_heads, - past_key_value=None, - softmax_scale=None, - attn_bias=None, - key_padding_mask=None, - is_causal=False, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, -): - q = rearrange(query, "b s (h d) -> b h s d", h=n_heads) - kv_n_heads = 1 if multiquery else n_heads - k = rearrange(key, "b s (h d) -> b h d s", h=kv_n_heads) - v = rearrange(value, "b s (h d) -> b h s d", h=kv_n_heads) - if past_key_value is not None: - if len(past_key_value) != 0: - k = torch.cat([past_key_value[0], k], dim=3) - v = torch.cat([past_key_value[1], v], dim=2) - past_key_value = (k, v) - (b, _, s_q, d) = q.shape - s_k = k.size(-1) - attn_weight = q.matmul(k) * softmax_scale - if attn_bias is not None: - _s_q = max(0, attn_bias.size(2) - s_q) - _s_k = max(0, attn_bias.size(3) - s_k) - attn_bias = attn_bias[:, :, _s_q:, _s_k:] - if ( - attn_bias.size(-1) != 1 - and attn_bias.size(-1) != s_k - or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q) - ): - raise RuntimeError( - f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}." - ) - attn_weight = attn_weight + attn_bias - min_val = torch.finfo(q.dtype).min - if key_padding_mask is not None: - if attn_bias is not None: - warnings.warn( - "Propogating key_padding_mask to the attention module " - + "and applying it within the attention module can cause " - + "unneccessary computation/memory usage. Consider integrating " - + "into attn_bias once and passing that to each attention " - + "module instead." - ) - attn_weight = attn_weight.masked_fill( - ~key_padding_mask.view((b, 1, 1, s_k)), min_val - ) - if is_causal and (not q.size(2) == 1): - s = max(s_q, s_k) - causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16) - causal_mask = causal_mask.tril() - causal_mask = causal_mask.to(torch.bool) - causal_mask = ~causal_mask - causal_mask = causal_mask[-s_q:, -s_k:] - attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val) - attn_weight = torch.softmax(attn_weight, dim=-1) - if dropout_p: - attn_weight = torch.nn.functional.dropout( - attn_weight, p=dropout_p, training=training, inplace=True - ) - out = attn_weight.to(v.dtype).matmul(v) - out = rearrange(out, "b h s d -> b s (h d)") - if needs_weights: - return (out, attn_weight, past_key_value) - return (out, None, past_key_value) - - -def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]): - for tensor in tensors: - if tensor.dtype not in valid_dtypes: - raise TypeError( - f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}." - ) - if not tensor.is_cuda: - raise TypeError( - f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r})." 
- ) - - -def flash_attn_fn( - query, - key, - value, - n_heads, - past_key_value=None, - softmax_scale=None, - attn_bias=None, - key_padding_mask=None, - is_causal=False, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, -): - try: - from flash_attn import bert_padding, flash_attn_interface - except Exception: - raise RuntimeError("Please install flash-attn==1.0.3.post0") - check_valid_inputs(query, key, value) - if past_key_value is not None: - if len(past_key_value) != 0: - key = torch.cat([past_key_value[0], key], dim=1) - value = torch.cat([past_key_value[1], value], dim=1) - past_key_value = (key, value) - if attn_bias is not None: - _s_q = max(0, attn_bias.size(2) - query.size(1)) - _s_k = max(0, attn_bias.size(3) - key.size(1)) - attn_bias = attn_bias[:, :, _s_q:, _s_k:] - if attn_bias is not None: - raise NotImplementedError("attn_bias not implemented for flash attn.") - (batch_size, seqlen) = query.shape[:2] - if key_padding_mask is None: - key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) - query_padding_mask = key_padding_mask[:, -query.size(1) :] - (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input( - query, query_padding_mask - ) - query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads) - (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input( - key, key_padding_mask - ) - key_unpad = rearrange( - key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads - ) - (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask) - value_unpad = rearrange( - value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads - ) - if multiquery: - key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1)) - value_unpad = value_unpad.expand( - value_unpad.size(0), n_heads, value_unpad.size(-1) - ) - dropout_p = dropout_p if training else 0.0 - reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) - output_unpad = flash_attn_interface.flash_attn_unpadded_func( - query_unpad, - key_unpad, - value_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale=softmax_scale, - causal=reset_is_causal, - return_attn_probs=needs_weights, - ) - output = bert_padding.pad_input( - rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen - ) - return (output, None, past_key_value) - - -def triton_flash_attn_fn( - query, - key, - value, - n_heads, - past_key_value=None, - softmax_scale=None, - attn_bias=None, - key_padding_mask=None, - is_causal=False, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, -): - try: - from .flash_attn_triton import flash_attn_func - except Exception: - _installed = False - if version.parse(torch.__version__) < version.parse("2.0.0"): - _installed = True - try: - from flash_attn.flash_attn_triton import flash_attn_func - except Exception: - _installed = False - if not _installed: - raise RuntimeError( - "Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed." 
- ) - check_valid_inputs(query, key, value) - if past_key_value is not None: - if len(past_key_value) != 0: - key = torch.cat([past_key_value[0], key], dim=1) - value = torch.cat([past_key_value[1], value], dim=1) - past_key_value = (key, value) - if attn_bias is not None: - _s_q = max(0, attn_bias.size(2) - query.size(1)) - _s_k = max(0, attn_bias.size(3) - key.size(1)) - attn_bias = attn_bias[:, :, _s_q:, _s_k:] - if dropout_p: - raise NotImplementedError("Dropout not implemented for attn_impl: triton.") - if needs_weights: - raise NotImplementedError("attn_impl: triton cannot return attn weights.") - if key_padding_mask is not None: - warnings.warn( - "Propagating key_padding_mask to the attention module " - + "and applying it within the attention module can cause " - + "unnecessary computation/memory usage. Consider integrating " - + "into attn_bias once and passing that to each attention " - + "module instead." - ) - (b_size, s_k) = key_padding_mask.shape[:2] - if attn_bias is None: - attn_bias = query.new_zeros(b_size, 1, 1, s_k) - attn_bias = attn_bias.masked_fill( - ~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min - ) - query = rearrange(query, "b s (h d) -> b s h d", h=n_heads) - key = rearrange(key, "b s (h d) -> b s h d", h=1 if multiquery else n_heads) - value = rearrange(value, "b s (h d) -> b s h d", h=1 if multiquery else n_heads) - if multiquery: - key = key.expand(*key.shape[:2], n_heads, key.size(-1)) - value = value.expand(*value.shape[:2], n_heads, value.size(-1)) - reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) - attn_output = flash_attn_func( - query, key, value, attn_bias, reset_is_causal, softmax_scale - ) - output = attn_output.view(*attn_output.shape[:2], -1) - return (output, None, past_key_value) - - -class MultiheadAttention(nn.Module): - """Multi-head self attention. - - Using torch or triton attention implementation enables user to also use - additive bias. 
- """ - - def __init__( - self, - config, - prefix, - weights, - ): - super().__init__() - attn_impl = config.attn_config.attn_impl - self.attn_impl = config.attn_config.attn_impl - self.clip_qkv = config.attn_config.clip_qkv - self.qk_ln = config.attn_config.qk_ln - self.d_model = config.d_model - d_model = config.d_model - self.n_heads = config.n_heads - self.softmax_scale = config.attn_config.softmax_scale - if self.softmax_scale is None: - self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) - self.attn_dropout_p = config.attn_config.attn_pdrop - - if self.n_heads % weights.process_group.size() != 0: - raise ValueError( - f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} " - f"and `num_shards`: {weights.process_group.size()}" - ) - self.n_heads = self.n_heads // weights.process_group.size() - self.Wqkv = load_col( - config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias - ) - if self.qk_ln: - bias = not config.no_bias - hidden_size = config.d_model - head_dim = hidden_size // self.n_heads - - self.q_ln = LPLayerNorm( - d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights - ) - self.k_ln = LPLayerNorm( - self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights - ) - if self.attn_impl == "flash": - self.attn_fn = flash_attn_fn - elif self.attn_impl == "triton": - self.attn_fn = triton_flash_attn_fn - elif self.attn_impl == "torch": - self.attn_fn = scaled_multihead_dot_product_attention - else: - raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") - self.out_proj = TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.out_proj", - weights=weights, - bias=not config.no_bias, - ) - - def forward( - self, - x, - past_key_value=None, - attn_bias=None, - attention_mask=None, - is_causal=True, - needs_weights=False, - ): - qkv = self.Wqkv(x) - if self.clip_qkv: - qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) - (query, key, value) = qkv.chunk(3, dim=2) - - key_padding_mask = attention_mask - if self.qk_ln: - dtype = query.dtype - query = self.q_ln(query).to(dtype) - key = self.k_ln(key).to(dtype) - (context, attn_weights, past_key_value) = self.attn_fn( - query, - key, - value, - self.n_heads, - past_key_value=past_key_value, - softmax_scale=self.softmax_scale, - attn_bias=attn_bias, - key_padding_mask=key_padding_mask, - is_causal=is_causal, - dropout_p=self.attn_dropout_p, - training=self.training, - needs_weights=needs_weights, - ) - out = self.out_proj(context) - return (out, attn_weights, past_key_value) - - -class MultiQueryAttention(nn.Module): - """Multi-Query self attention. - - Using torch or triton attention implementation enables user to also use - additive bias. 
- """ - - def __init__(self, config, prefix, weights, verbose=False): - super().__init__() - attn_impl = config.attn_config.attn_impl - self.attn_impl = config.attn_config.attn_impl - self.clip_qkv = config.attn_config.clip_qkv - self.qk_ln = config.attn_config.qk_ln - self.d_model = config.d_model - d_model = config.d_model - self.n_heads = config.n_heads - self.softmax_scale = config.attn_config.softmax_scale - if self.softmax_scale is None: - self.softmax_scale = 1 / math.sqrt(self.head_dim) - self.attn_dropout_p = config.attn_config.attn_pdrop - # self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device) - self.Wqkv = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias - ) - (d_model, d_model + self.head_dim) - if self.qk_ln: - raise NotImplementedError("qk_ln not supported") - if self.attn_impl == "flash": - self.attn_fn = flash_attn_fn - elif self.attn_impl == "triton": - self.attn_fn = triton_flash_attn_fn - if verbose: - warnings.warn( - "While `attn_impl: triton` can be faster than `attn_impl: flash` " - + "it uses more memory. When training larger models this can trigger " - + "alloc retries which hurts performance. If encountered, we recommend " - + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`." - ) - elif self.attn_impl == "torch": - self.attn_fn = scaled_multihead_dot_product_attention - if torch.cuda.is_available() and verbose: - warnings.warn( - "Using `attn_impl: torch`. If your model does not use `alibi` or " - + "`prefix_lm` we recommend using `attn_impl: flash` otherwise " - + "we recommend using `attn_impl: triton`." - ) - else: - raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") - self.out_proj = TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.out_proj", - weights=weights, - bias=not config.no_bias, - ) - # self.out_proj._is_residual = True - - def forward( - self, - x, - past_key_value=None, - attn_bias=None, - attention_mask=None, - is_causal=True, - needs_weights=False, - ): - qkv = self.Wqkv(x) - if self.clip_qkv: - qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) - (query, key, value) = qkv.split( - [self.d_model, self.head_dim, self.head_dim], dim=2 - ) - key_padding_mask = attention_mask - if self.qk_ln: - dtype = query.dtype - query = self.q_ln(query).to(dtype) - key = self.k_ln(key).to(dtype) - (context, attn_weights, past_key_value) = self.attn_fn( - query, - key, - value, - self.n_heads, - past_key_value=past_key_value, - softmax_scale=self.softmax_scale, - attn_bias=attn_bias, - key_padding_mask=key_padding_mask, - is_causal=is_causal, - dropout_p=self.attn_dropout_p, - training=self.training, - needs_weights=needs_weights, - multiquery=True, - ) - return (self.out_proj(context), attn_weights, past_key_value) - - -def attn_bias_shape( - attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id -): - if attn_impl == "flash": - return None - elif attn_impl in ["torch", "triton"]: - if alibi: - if (prefix_lm or not causal) or use_sequence_id: - return (1, n_heads, seq_len, seq_len) - return (1, n_heads, 1, seq_len) - elif prefix_lm or use_sequence_id: - return (1, 1, seq_len, seq_len) - return None - else: - raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") - - -def build_attn_bias( - attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8 -): - if attn_impl == "flash": - return None - elif attn_impl in ["torch", "triton"]: - if alibi: - (device, dtype) = 
(attn_bias.device, attn_bias.dtype) - attn_bias = attn_bias.add( - build_alibi_bias( - n_heads, - seq_len, - full=not causal, - alibi_bias_max=alibi_bias_max, - device=device, - dtype=dtype, - ) - ) - return attn_bias - else: - raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") - - -def gen_slopes(n_heads, alibi_bias_max=8, device=None): - _n_heads = 2 ** math.ceil(math.log2(n_heads)) - m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device) - m = m.mul(alibi_bias_max / _n_heads) - slopes = 1.0 / torch.pow(2, m) - if _n_heads != n_heads: - slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads] - return slopes.view(1, n_heads, 1, 1) - - -def build_alibi_bias( - n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None -): - alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view( - 1, 1, 1, seq_len - ) - if full: - alibi_bias = alibi_bias - torch.arange( - 1 - seq_len, 1, dtype=torch.int32, device=device - ).view(1, 1, seq_len, 1) - alibi_bias = alibi_bias.abs().mul(-1) - slopes = gen_slopes(n_heads, alibi_bias_max, device=device) - alibi_bias = alibi_bias * slopes - return alibi_bias.to(dtype=dtype) - - -ATTN_CLASS_REGISTRY = { - "multihead_attention": MultiheadAttention, - "multiquery_attention": MultiQueryAttention, -} - -"""GPT Blocks used for the GPT Model.""" - - -class MPTMLP(nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - # self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) - self.up_proj = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.up_proj", weights=weights, bias=not config.no_bias - ) - self.act = nn.GELU(approximate="none") - # self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) - self.down_proj = TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.down_proj", - weights=weights, - bias=not config.no_bias, - ) - # self.down_proj._is_residual = True - - def forward(self, x): - return self.down_proj(self.act(self.up_proj(x))) - - -class MPTBlock(nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - self.prefix = prefix - if config.attn_config.attn_type != "multihead_attention": - raise NotImplementedError( - f"""Not implemented attn {config.attn_config.attn_type}""" - ) - resid_pdrop = config.resid_pdrop - if config.no_bias: - self.norm_1 = nn.LayerNorm.load_no_bias( - prefix=f"{prefix}.norm_1", weights=weights, eps=EPS - ) - self.norm_2 = nn.LayerNorm.load_no_bias( - prefix=f"{prefix}.norm_2", weights=weights, eps=EPS - ) - else: - self.norm_1 = nn.LayerNorm.load( - prefix=f"{prefix}.norm_1", weights=weights, eps=EPS - ) - self.norm_2 = nn.LayerNorm.load( - prefix=f"{prefix}.norm_2", weights=weights, eps=EPS - ) - self.attn = MultiheadAttention(config, prefix=f"{prefix}.attn", weights=weights) - self.ffn = MPTMLP(config, prefix=f"{prefix}.ffn", weights=weights) - self.resid_attn_dropout = nn.Dropout(resid_pdrop) - self.resid_ffn_dropout = nn.Dropout(resid_pdrop) - - def forward( - self, - x: torch.Tensor, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attn_bias: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.ByteTensor] = None, - is_causal: bool = True, - ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: - a = self.norm_1(x) - (b, attn_weights, past_key_value) = self.attn( - a, - past_key_value=past_key_value, - attn_bias=attn_bias, - attention_mask=attention_mask, - is_causal=is_causal, - ) - x = x + self.resid_attn_dropout(b) - m = 
self.norm_2(x) - n = self.ffn(m) - x = x + self.resid_ffn_dropout(n) - return (x, attn_weights, past_key_value) - - -def _cast_if_autocast_enabled(tensor): - if torch.is_autocast_enabled(): - if tensor.device.type == "cuda": - dtype = torch.get_autocast_gpu_dtype() - elif tensor.device.type == "cpu": - dtype = torch.get_autocast_cpu_dtype() - else: - raise NotImplementedError() - return tensor.to(dtype=dtype) - return tensor - - -class LPLayerNorm(torch.nn.LayerNorm): - def __init__( - self, - normalized_shape, - eps=1e-05, - elementwise_affine=True, - device=None, - dtype=None, - bias: Optional[bool] = True, - prefix=None, - weights=None, - ): - super().__init__( - normalized_shape=normalized_shape, - eps=eps, - elementwise_affine=elementwise_affine, - device=device, - dtype=dtype, - bias=bias, - ) - if weights is not None: - self.weight = nn.Parameter(weights.get_sharded(f"{prefix}.weight", dim=0)) - if bias: - self.bias = nn.Parameter(weights.get_sharded(f"{prefix}.bias", dim=0)) - self.normalized_shape = self.weight.shape - - def forward(self, x): - module_device = x.device - downcast_x = _cast_if_autocast_enabled(x) - downcast_weight = ( - _cast_if_autocast_enabled(self.weight) - if self.weight is not None - else self.weight - ) - downcast_bias = ( - _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias - ) - with torch.autocast(enabled=False, device_type=module_device.type): - return torch.nn.functional.layer_norm( - downcast_x, - self.normalized_shape, - downcast_weight, - downcast_bias, - self.eps, - ) - - -def rms_norm(x, weight=None, eps=1e-05): - output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) - if weight is not None: - return output * weight - return output - - -class RMSNorm(torch.nn.Module): - def __init__( - self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None - ): - super().__init__() - self.eps = eps - if weight: - self.weight = torch.nn.Parameter( - torch.ones(normalized_shape, dtype=dtype, device=device) - ) - else: - self.register_parameter("weight", None) - - def forward(self, x): - return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) - - -class LPRMSNorm(RMSNorm): - def __init__( - self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None - ): - super().__init__( - normalized_shape=normalized_shape, - eps=eps, - weight=weight, - dtype=dtype, - device=device, - ) - - def forward(self, x): - downcast_x = _cast_if_autocast_enabled(x) - downcast_weight = ( - _cast_if_autocast_enabled(self.weight) - if self.weight is not None - else self.weight - ) - with torch.autocast(enabled=False, device_type=x.device.type): - return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) - - -NORM_CLASS_REGISTRY = { - "layernorm": torch.nn.LayerNorm, - "low_precision_layernorm": LPLayerNorm, - "rmsnorm": RMSNorm, - "low_precision_rmsnorm": LPRMSNorm, -} - -Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] - - -class MPTPreTrainedModel(PreTrainedModel): - base_model_prefix = "model" - _no_split_modules = ["MPTBlock"] - - -class MPTModel(MPTPreTrainedModel): - def __init__(self, prefix: str, config, weights): - # config._validate_config() - super().__init__(config) - self.world_size = weights.process_group.size() - self.rank = weights.process_group.rank() - self.n_heads = config.n_heads - self.attn_impl = config.attn_config.attn_impl - self.prefix_lm = config.attn_config.prefix_lm - self.attn_uses_sequence_id = config.attn_config.attn_uses_sequence_id - self.alibi = 
config.attn_config.alibi - self.alibi_bias_max = config.attn_config.alibi_bias_max - if config.init_device == "mixed": - # TODO: reimplement mixed device initialization - # dist.get_local_rank() == 0: - if True: - config.init_device = "cpu" - else: - config.init_device = "meta" - if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys(): - norm_options = " | ".join(NORM_CLASS_REGISTRY.keys()) - raise NotImplementedError( - f"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options})." - ) - if config.norm_type.lower() != "low_precision_layernorm": - raise NotImplementedError( - f"Requested norm type ({config.norm_type}) is not implemented within this repo." - ) - - self.wte = TensorParallelEmbedding(f"{prefix}.wte", weights) - - if not self.alibi: - self.wpe = TensorParallelEmbedding(f"{prefix}.wpe", weights) - self.blocks = nn.ModuleList( - [ - MPTBlock(config, prefix=f"{prefix}.blocks.{i}", weights=weights) - for i in range(config.n_layers) - ] - ) - if config.no_bias: - self.norm_f = nn.LayerNorm.load_no_bias( - prefix="transformer.norm_f", weights=weights, eps=EPS - ) - else: - self.norm_f = nn.LayerNorm.load( - prefix="transformer.norm_f", weights=weights, eps=EPS - ) - self.is_causal = not self.prefix_lm - self._attn_bias_initialized = False - self.attn_bias = None - self.attn_bias_shape = attn_bias_shape( - self.attn_impl, - config.n_heads, - config.max_seq_len, - self.alibi, - prefix_lm=self.prefix_lm, - causal=self.is_causal, - use_sequence_id=self.attn_uses_sequence_id, - ) - if config.no_bias: - for module in self.modules(): - if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter): - if config.verbose: - warnings.warn(f"Removing bias ({module.bias}) from {module}.") - module.register_parameter("bias", None) - if hasattr(self.config, "verbose"): - if config.verbose and config.verbose > 2: - print(self) - if "verbose" not in self.config.init_config: - self.config.init_config["verbose"] = self.config.verbose - if self.config.init_config["verbose"] > 1: - init_fn_name = self.config.init_config["name"] - warnings.warn(f"Using {init_fn_name} initialization.") - - @torch.no_grad() - def _attn_bias( - self, - device, - dtype, - attention_mask: Optional[torch.ByteTensor] = None, - prefix_mask: Optional[torch.ByteTensor] = None, - sequence_id: Optional[torch.LongTensor] = None, - ): - if not self._attn_bias_initialized: - if self.attn_bias_shape: - self.attn_bias = torch.zeros( - self.attn_bias_shape, device=device, dtype=dtype - ) - self.attn_bias = build_attn_bias( - self.attn_impl, - self.attn_bias, - self.config.n_heads, - self.config.max_seq_len, - causal=self.is_causal, - alibi=self.alibi, - alibi_bias_max=self.alibi_bias_max, - ) - assert self.n_heads % self.world_size == 0 - block_size = self.n_heads // self.world_size - self.attn_bias = self.attn_bias[ - :, self.rank * block_size : (self.rank + 1) * block_size - ] - self._attn_bias_initialized = True - if self.attn_impl == "flash": - return (self.attn_bias, attention_mask) - if self.attn_bias is not None: - self.attn_bias = self.attn_bias.to(dtype=dtype, device=device) - attn_bias = self.attn_bias - if self.prefix_lm: - assert isinstance(attn_bias, torch.Tensor) - assert isinstance(prefix_mask, torch.Tensor) - attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask) - if self.attn_uses_sequence_id and sequence_id is not None: - assert isinstance(attn_bias, torch.Tensor) - attn_bias = self._apply_sequence_id(attn_bias, sequence_id) - if attention_mask is not None: - 
s_k = attention_mask.shape[-1] - if attn_bias is None: - attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype) - else: - _s_k = max(0, attn_bias.size(-1) - s_k) - attn_bias = attn_bias[:, :, :, _s_k:] - if prefix_mask is not None and attention_mask.shape != prefix_mask.shape: - raise ValueError( - f"attention_mask shape={attention_mask.shape} " - + f"and prefix_mask shape={prefix_mask.shape} are not equal." - ) - min_val = torch.finfo(attn_bias.dtype).min - attn_bias = attn_bias.masked_fill( - ~attention_mask.view(-1, 1, 1, s_k), min_val - ) - return (attn_bias, None) - - def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor): - (s_k, s_q) = attn_bias.shape[-2:] - if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len: - raise ValueError( - "attn_bias does not match the expected shape. " - + f"The last two dimensions should both be {self.config.max_seq_len} " - + f"but are {s_k} and {s_q}." - ) - seq_len = prefix_mask.shape[-1] - if seq_len > self.config.max_seq_len: - raise ValueError( - f"prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}" - ) - attn_bias = attn_bias[..., :seq_len, :seq_len] - causal = torch.tril( - torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device) - ).view(1, 1, seq_len, seq_len) - prefix = prefix_mask.view(-1, 1, 1, seq_len) - cannot_attend = ~torch.logical_or(causal, prefix.bool()) - min_val = torch.finfo(attn_bias.dtype).min - attn_bias = attn_bias.masked_fill(cannot_attend, min_val) - return attn_bias - - def _apply_sequence_id( - self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor - ): - seq_len = sequence_id.shape[-1] - if seq_len > self.config.max_seq_len: - raise ValueError( - f"sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}" - ) - attn_bias = attn_bias[..., :seq_len, :seq_len] - cannot_attend = torch.logical_not( - torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len)) - ).unsqueeze(1) - min_val = torch.finfo(attn_bias.dtype).min - attn_bias = attn_bias.masked_fill(cannot_attend, min_val) - return attn_bias - - def forward( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, - attention_mask: Optional[torch.ByteTensor] = None, - prefix_mask: Optional[torch.ByteTensor] = None, - sequence_id: Optional[torch.LongTensor] = None, - return_dict: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - use_cache: Optional[bool] = None, - ): - return_dict = ( - return_dict if return_dict is not None else self.config.return_dict - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - if attention_mask is not None: - attention_mask = attention_mask.bool() - if prefix_mask is not None: - prefix_mask = prefix_mask.bool() - if not return_dict: - raise NotImplementedError( - "return_dict False is not implemented yet for MPT" - ) - if output_attentions: - if self.attn_impl != "torch": - raise NotImplementedError( - "output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`." - ) - if ( - attention_mask is not None - and attention_mask[:, 0].sum() != attention_mask.shape[0] - and self.training - ): - raise NotImplementedError( - "MPT does not support training with left padding." - ) - if self.prefix_lm and prefix_mask is None: - raise ValueError( - "prefix_mask is a required argument when MPT is configured with prefix_lm=True."
- ) - if self.training: - if self.attn_uses_sequence_id and sequence_id is None: - raise ValueError( - "sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True " - + "and the model is in train mode." - ) - elif self.attn_uses_sequence_id is False and sequence_id is not None: - warnings.warn( - "MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. " - + "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True." - ) - S = input_ids.size(1) - assert ( - S <= self.config.max_seq_len - ), f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}" - tok_emb = self.wte(input_ids) - if self.alibi: - x = tok_emb - else: - past_position = 0 - if past_key_values is not None: - if len(past_key_values) != self.config.n_layers: - raise ValueError( - "past_key_values must provide a past_key_value for each attention " - + f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})." - ) - past_position = past_key_values[0][0].size(1) - if self.attn_impl == "torch": - past_position = past_key_values[0][0].size(3) - if S + past_position > self.config.max_seq_len: - raise ValueError( - f"Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}." - ) - pos = torch.arange( - past_position, - S + past_position, - dtype=torch.long, - device=input_ids.device, - ).unsqueeze(0) - if attention_mask is not None: - pos = torch.clamp( - pos - - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[ - :, past_position: - ], - min=0, - ) - pos_emb = self.wpe(pos) - x = tok_emb + pos_emb - (attn_bias, attention_mask) = self._attn_bias( - device=x.device, - dtype=torch.float32, - attention_mask=attention_mask, - prefix_mask=prefix_mask, - sequence_id=sequence_id, - ) - if use_cache and past_key_values is None: - past_key_values = [() for _ in range(self.config.n_layers)] - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - for b_idx, block in enumerate(self.blocks): - if output_hidden_states: - assert all_hidden_states is not None - all_hidden_states = all_hidden_states + (x,) - past_key_value = ( - past_key_values[b_idx] if past_key_values is not None else None - ) - (x, attn_weights, past_key_value) = block( - x, - past_key_value=past_key_value, - attn_bias=attn_bias, - attention_mask=attention_mask, - is_causal=self.is_causal, - ) - if past_key_values is not None: - past_key_values[b_idx] = past_key_value - if output_attentions: - assert all_self_attns is not None - all_self_attns = all_self_attns + (attn_weights,) - x = self.norm_f(x) - if output_hidden_states: - assert all_hidden_states is not None - all_hidden_states = all_hidden_states + (x,) - return BaseModelOutputWithPast( - last_hidden_state=x, - past_key_values=past_key_values, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class MPTForCausalLM(MPTPreTrainedModel): - def __init__(self, prefix: str, config, weights): - super().__init__(config) - - if not prefix: - prefix = "transformer" - else: - prefix = f"{prefix}.transformer" - - if not config.tie_word_embeddings: - raise ValueError("MPTForCausalLM only supports tied word embeddings") - self.transformer = MPTModel(prefix, config, weights) - self.lm_head = SpeculativeHead.load( - 
config, prefix=f"{prefix}.wte", weights=weights - ) - self.logit_scale = None - if config.logit_scale is not None: - logit_scale = config.logit_scale - if isinstance(logit_scale, str): - if logit_scale == "inv_sqrt_d_model": - logit_scale = 1 / math.sqrt(config.d_model) - else: - raise ValueError( - f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'." - ) - self.logit_scale = logit_scale - - def forward( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, - attention_mask: Optional[torch.ByteTensor] = None, - prefix_mask: Optional[torch.ByteTensor] = None, - sequence_id: Optional[torch.LongTensor] = None, - labels: Optional[torch.LongTensor] = None, - return_dict: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - use_cache: Optional[bool] = None, - ): - return_dict = ( - return_dict if return_dict is not None else self.config.return_dict - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - outputs = self.transformer( - input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - prefix_mask=prefix_mask, - sequence_id=sequence_id, - return_dict=return_dict, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - use_cache=use_cache, - ) - logits, speculative_logits = self.lm_head(outputs.last_hidden_state) - if self.logit_scale is not None: - if self.logit_scale == 0: - warnings.warn( - f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs." - ) - logits *= self.logit_scale - loss = None - if labels is not None: - labels = torch.roll(labels, shifts=-1) - labels[:, -1] = -100 - loss = F.cross_entropy( - logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1) - ) - return ( - CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ), - speculative_logits, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs - ): - if inputs_embeds is not None: - raise NotImplementedError("inputs_embeds is not implemented for MPT yet") - attention_mask = kwargs["attention_mask"].bool() - if attention_mask[:, -1].sum() != attention_mask.shape[0]: - raise NotImplementedError( - "MPT does not support generation with right padding." - ) - if self.transformer.attn_uses_sequence_id and self.training: - sequence_id = torch.zeros_like(input_ids[:1]) - else: - sequence_id = None - if past_key_values is not None: - input_ids = input_ids[:, -1].unsqueeze(-1) - if self.transformer.prefix_lm: - prefix_mask = torch.ones_like(attention_mask) - if kwargs.get("use_cache") is False: - raise NotImplementedError( - "MPT with prefix_lm=True does not support use_cache=False." - ) - else: - prefix_mask = None - return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "prefix_mask": prefix_mask, - "sequence_id": sequence_id, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache", True), - } - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - """Used by HuggingFace generate when using beam search with kv-caching. 
- - See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133 - for an example in transformers. - """ - reordered_past = [] - for layer_past in past_key_values: - reordered_past += [ - tuple( - (past_state.index_select(0, beam_idx) for past_state in layer_past) - ) - ] - return reordered_past diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/neox_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/neox_modeling.py deleted file mode 100644 index 06731a6f..00000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/neox_modeling.py +++ /dev/null @@ -1,796 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" PyTorch GPTNeoX model.""" - -from typing import Optional, Tuple, Union - -import os -import torch -import torch.distributed -import torch.utils.checkpoint -from torch import nn -from torch.nn import CrossEntropyLoss - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from transformers.modeling_utils import PreTrainedModel -from text_generation_server.layers import ( - TensorParallelColumnLinear, - TensorParallelEmbedding, - TensorParallelRowLinear, - SpeculativeHead, -) - - -CUSTOM_KERNELS_ENABLED = False -if ( - torch.cuda.is_available() - and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True" -): - try: - from custom_kernels import fused_attention_cuda - - CUSTOM_KERNELS_ENABLED = True - except ImportError: - pass - - -def make_causal_mask( - input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int -) -> torch.BoolTensor: - """ - Make causal mask used for self-attention. - """ - batch_size, target_length = input_ids_shape - mask = torch.ones( - (target_length, target_length + past_key_values_length), - dtype=torch.bool, - device=device, - ) - mask = mask.triu(1 + past_key_values_length) - - expanded_mask = mask.unsqueeze(0).expand( - batch_size, target_length, target_length + past_key_values_length - ) - return expanded_mask - - -def expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor: - """ - Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`. 
- """ - batch_size, src_length = mask.shape - tgt_length = tgt_length if tgt_length is not None else src_length - - expanded_mask = ~(mask[:, None, :].to(torch.bool)) - return expanded_mask.expand(batch_size, tgt_length, src_length) - - -def prepare_attn_mask( - attention_mask: torch.Tensor, - input_shape: Tuple[int, int], - past_key_values_length: int, -) -> torch.BoolTensor: - # create causal mask - # [batch_size, seq_length] -> [batch_size, tgt_length, src_length] - combined_attention_mask = None - device = attention_mask.device - _, src_length = input_shape - - if src_length > 1: - combined_attention_mask = make_causal_mask( - input_shape, device=device, past_key_values_length=past_key_values_length - ) - - # [batch_size, seq_length] -> [batch_size, tgt_length, src_length] - expanded_attn_mask = expand_mask(attention_mask, tgt_length=src_length) - combined_attention_mask = ( - expanded_attn_mask - if combined_attention_mask is None - else expanded_attn_mask | combined_attention_mask - ) - - return combined_attention_mask - - -class GPTNeoXPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - -class GPTNeoXAttention(nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - self.num_attention_heads = config.num_attention_heads - self.hidden_size = config.hidden_size - self.head_size = self.hidden_size // self.num_attention_heads - self.rotary_ndims = int(self.head_size * config.rotary_pct) - # ??? TODO - # self.register_buffer( - # "bias", - # torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( - # 1, 1, max_positions, max_positions - # ), - # ) - # self.register_buffer("masked_bias", torch.tensor(-1e9)) - self.rotary_emb = RotaryEmbedding( - self.rotary_ndims, - config.max_position_embeddings, - base=config.rotary_emb_base, - ) - self.rotary_emb.inv_freq = nn.Parameter( - weights.get_tensor(f"{prefix}.rotary_emb.inv_freq") - ) - self.inv_norm_factor = 1.0 / torch.sqrt( - torch.tensor(self.head_size, dtype=torch.float32) - ).to(torch.get_default_dtype()) - - if self.num_attention_heads % weights.process_group.size() != 0: - raise ValueError( - f"`num_attention_heads` must be divisible by `num_shards` " - f"(got `num_attention_heads`: {self.num_attention_heads} " - f"and `num_shards`: {weights.process_group.size()}" - ) - self.num_attention_heads = ( - self.num_attention_heads // weights.process_group.size() - ) - self.query_key_value = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.query_key_value", weights=weights, bias=True - ) - self.dense = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.dense", weights=weights, bias=True - ) - - def forward( - self, - hidden_states, - position_ids, - attention_mask, - head_mask=None, - layer_past=None, - use_cache=False, - output_attentions=False, - ): - has_layer_past = layer_past is not None - - # Compute QKV - # Attention heads [batch, seq_len, hidden_size] - # --> [batch, seq_len, (np * 3 * head_size)] - qkv = self.query_key_value(hidden_states) - - # [batch, seq_len, (num_heads * 3 * head_size)] - # --> [batch, seq_len, num_heads, 3 * head_size] - new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size) - qkv = qkv.view(*new_qkv_shape).permute(0, 2, 1, 3) - # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size] - query, key, value = qkv.split(self.head_size, -1) - - # 
Compute token offset for rotary embeddings (when decoding) - seq_len = key.shape[-2] - if has_layer_past: - seq_len += layer_past[0].shape[-2] - - # Compute rotary embeddings on rotary_ndims - query_rot = query[..., : self.rotary_ndims] - key_rot = key[..., : self.rotary_ndims] - - query_rot, key_rot = self.rotary_emb(query_rot, key_rot, position_ids, seq_len) - - query[..., : self.rotary_ndims] = query_rot - key[..., : self.rotary_ndims] = key_rot - - if CUSTOM_KERNELS_ENABLED: - attn_output, present, attn_weights = fused_attention_cuda.forward( - query, - key, - value, - layer_past, - attention_mask, - head_mask, - self.inv_norm_factor, - self.num_attention_heads, - use_cache, - ) - else: - # Cache QKV values - if has_layer_past: - past_key = layer_past[0] - past_value = layer_past[1] - key = torch.cat((past_key, key), dim=-2) - value = torch.cat((past_value, value), dim=-2) - present = (key, value) if use_cache else None - - # Compute attention - attn_output, attn_weights = self._attn( - query, key, value, attention_mask, head_mask - ) - - # Reshape outputs - attn_output = self._merge_heads( - attn_output, self.num_attention_heads, self.head_size - ) - - attn_output = self.dense(attn_output) - - outputs = (attn_output, present) - if output_attentions: - outputs += (attn_weights,) - - return outputs - - @classmethod - def _split_heads(cls, tensor, num_attention_heads, attn_head_size): - """ - Splits hidden dim into attn_head_size and num_attention_heads - """ - # tensor: [bs, seq_len, hidden_size] - new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) - # -> [bs, seq_len, num_attention_heads, attn_head_size] - tensor = tensor.view(new_shape) - # -> [bs, num_attention_heads, seq_len, attn_head_size] - tensor = tensor.permute(0, 2, 1, 3) - return tensor - - @classmethod - def _merge_heads(cls, tensor, num_attention_heads, attn_head_size): - """ - Merges attn_head_size dim and num_attn_heads dim into hidden dim - """ - # tensor [bs, num_attention_heads, seq_len, attn_head_size] - tensor = tensor.permute(0, 2, 1, 3).contiguous() - # -> [bs, seq_len, num_attention_heads, attn_head_size] - tensor = tensor.view( - tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size - ) - # -> [bs, seq_len, hidden_size] - return tensor - - def _attn(self, query, key, value, attention_mask=None, head_mask=None): - # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size] - # compute causal mask from causal mask buffer - batch_size, num_attention_heads, query_length, attn_head_size = query.size() - key_length = key.size(-2) - - query = query.reshape( - batch_size * num_attention_heads, query_length, attn_head_size - ) - key = key.reshape(batch_size * num_attention_heads, key_length, attn_head_size) - attn_scores = torch.zeros( - 1, - dtype=query.dtype, - device=key.device, - ).expand(batch_size * num_attention_heads, query_length, key_length) - attn_scores = torch.baddbmm( - attn_scores, - query, - key.transpose(1, 2), - beta=1.0, - alpha=self.inv_norm_factor, - ) - - # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] - input_dtype = attn_scores.dtype - if input_dtype in [torch.float16, torch.bfloat16]: - attn_scores = attn_scores.to(torch.float) - attn_scores = torch.where( - attention_mask, torch.finfo(attn_scores.dtype).min, attn_scores - ) - attn_scores = attn_scores.view( - batch_size, num_attention_heads, query_length, key_length - ) - - attn_weights = nn.functional.softmax(attn_scores, 
dim=-1) - attn_weights = attn_weights.to(value.dtype) - - # Mask heads if we want to - if head_mask is not None: - attn_weights = attn_weights * head_mask - - attn_output = torch.matmul(attn_weights, value) - return attn_output, attn_weights - - -class RotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings, base=10000, device=None): - super().__init__() - self.true_inv_freq = 1.0 / ( - base ** (torch.arange(0, dim, 2).float().to(device) / dim) - ) - self.register_buffer("inv_freq", self.true_inv_freq) - - # Build here to make `torch.jit.trace` work. - self.max_seq_len_cached = max_position_embeddings - self.cos_cached = None - self.sin_cached = None - - @staticmethod - def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - @staticmethod - def _create_cos_sin(inv_freq, max_position_embeddings, dtype, device): - t = torch.arange( - max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype - ) - freqs = torch.einsum("i,j->ij", t, inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - return emb.cos().to(device).to(dtype), emb.sin().to(device).to(dtype) - - def forward(self, q, k, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if ( - seq_len > self.max_seq_len_cached - or self.cos_cached is None - or self.sin_cached is None - ): - if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - self.cos_cached, self.sin_cached = self._create_cos_sin( - self.true_inv_freq, self.max_seq_len_cached, q.dtype, q.device - ) - return rotary_forward(q, k, self.cos_cached, self.sin_cached, position_ids) - - -@torch.jit.script -def rotary_forward(q, k, cos, sin, position_ids): - cos = cos[position_ids].unsqueeze(1) - sin = sin[position_ids].unsqueeze(1) - - chunk_size = q.shape[-1] // 2 - q1, q2 = q.split(chunk_size, -1) - q_rotated = torch.cat((-q2, q1), dim=-1) - k1, k2 = k.split(chunk_size, -1) - k_rotated = torch.cat((-k2, k1), dim=-1) - - q_embed = (q * cos) + (q_rotated * sin) - k_embed = (k * cos) + (k_rotated * sin) - return q_embed, k_embed - - -class GPTNeoXMLP(nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - self.act = ( - ACT2FN[config.hidden_act] - if "gelu_fast" not in config.hidden_act - else lambda x: torch.nn.functional.gelu(x, approximate="tanh") - ) - - self.dense_h_to_4h = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True - ) - self.dense_4h_to_h = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True - ) - - def forward(self, hidden_states): - hidden_states = self.dense_h_to_4h(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.dense_4h_to_h(hidden_states) - return hidden_states - - -class GPTNeoXLayer(nn.Module): - def __init__(self, layer_id, prefix: str, config, weights): - super().__init__() - self.use_parallel_residual = config.use_parallel_residual - self.input_layernorm = nn.LayerNorm.load( - prefix=f"{prefix}.layers.{layer_id}.input_layernorm", - weights=weights, - eps=config.layer_norm_eps, - ) - self.post_attention_layernorm = nn.LayerNorm.load( - prefix=f"{prefix}.layers.{layer_id}.post_attention_layernorm", - weights=weights, - eps=config.layer_norm_eps, - ) - self.attention = GPTNeoXAttention( - 
config, prefix=f"{prefix}.layers.{layer_id}.attention", weights=weights - ) - self.mlp = GPTNeoXMLP( - config, prefix=f"{prefix}.layers.{layer_id}.mlp", weights=weights - ) - - def forward( - self, - hidden_states, - position_ids, - attention_mask=None, - head_mask=None, - use_cache=False, - layer_past=None, - output_attentions=False, - ): - attention_layer_outputs = self.attention( - self.input_layernorm(hidden_states), - attention_mask=attention_mask, - position_ids=position_ids, - layer_past=layer_past, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - attn_output = attention_layer_outputs[ - 0 - ] # output_attn: attn_output, present, (attn_weights) - outputs = attention_layer_outputs[1:] - - if self.use_parallel_residual: - # pseudocode: - # x = x + attn(ln1(x)) + mlp(ln2(x)) - mlp_output = self.mlp(self.post_attention_layernorm(hidden_states)) - hidden_states = mlp_output + attn_output + hidden_states - else: - # pseudocode: - # x = x + attn(ln1(x)) - # x = x + mlp(ln2(x)) - attn_output = attn_output + hidden_states - mlp_output = self.mlp(self.post_attention_layernorm(attn_output)) - hidden_states = mlp_output + attn_output - - if use_cache: - outputs = ( - hidden_states, - ) + outputs # hidden_states, present, (attn_weights) - else: - outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights) - - return outputs - - -class GPTNeoXModel(GPTNeoXPreTrainedModel): - def __init__(self, prefix: str, config, weights): - super().__init__(config) - self.config = config - - self.num_attention_heads = config.num_attention_heads - - self.embed_in = TensorParallelEmbedding( - prefix=f"{prefix}.embed_in", weights=weights - ) - self.layers = nn.ModuleList( - [ - GPTNeoXLayer(layer_id, prefix, config, weights) - for layer_id in range(config.num_hidden_layers) - ] - ) - self.final_layer_norm = nn.LayerNorm.load( - prefix=f"{prefix}.final_layer_norm", - weights=weights, - eps=config.layer_norm_eps, - ) - self.tp_world_size = weights.process_group.size() - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - position_ids=None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - r""" - past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). 
- """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time" - ) - elif input_ids is not None: - input_shape = input_ids.size() - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - batch_size, seq_length = input_shape - - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * self.config.num_hidden_layers) - else: - past_length = past_key_values[0][0].size(-2) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_length, seq_length + past_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_in(input_ids) - - hidden_states = inputs_embeds - - # Attention mask. - seq_length_with_past = seq_length - past_key_values_length = 0 - if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[-1] - seq_length_with_past = seq_length_with_past + past_key_values_length - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), device=hidden_states.device - ) - else: - attention_mask = attention_mask.to(hidden_states.device) - - causal_mask = prepare_attn_mask( - attention_mask, - input_shape=(batch_size, seq_length), - past_key_values_length=past_key_values_length, - ) - - assert self.num_attention_heads % self.tp_world_size == 0 - block_size = self.num_attention_heads // self.tp_world_size - causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - - presents = () if use_cache else None - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - outputs = layer( - hidden_states, - position_ids=position_ids, - attention_mask=causal_mask, - head_mask=head_mask[i], - layer_past=layer_past, - use_cache=use_cache, - output_attentions=output_attentions, - ) - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - if output_attentions: - all_attentions = all_attentions + (outputs[2 if use_cache else 1],) - - hidden_states = self.final_layer_norm(hidden_states) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if 
not return_dict: - return tuple( - v - for v in [hidden_states, presents, all_hidden_states, all_attentions] - if v is not None - ) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_attentions, - ) - - -class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel): - _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] - - def __init__(self, prefix: str, config, weights): - super().__init__(config) - - if not prefix: - prefix = "gpt_neox" - else: - prefix = f"{prefix}.gpt_neox" - - self.gpt_neox = GPTNeoXModel(prefix, config, weights) - self.embed_out = SpeculativeHead.load( - config, prefix="embed_out", weights=weights - ) - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are - only required when the model is used as a decoder in a Sequence to Sequence model. - - Contains pre-computed hidden-states (key and values in the self-attention blocks that can be used (see - `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in - `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are - ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`).
- - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, GPTNeoXForCausalLM, GPTNeoXConfig - >>> import torch - - >>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") - >>> config = GPTNeoXConfig.from_pretrained("EleutherAI/gpt-neox-20b") - >>> config.is_decoder = True - >>> model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", config=config) - - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") - >>> outputs = model(**inputs) - - >>> prediction_logits = outputs.logits - ```""" - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - outputs = self.gpt_neox( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - lm_logits, speculative_logits = self.embed_out(hidden_states) - - lm_loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(lm_logits.device) - # we are doing next-token prediction; shift prediction scores and input ids by one - shift_logits = lm_logits[:, :-1, :].contiguous() - labels = labels[:, 1:].contiguous() - loss_fct = CrossEntropyLoss() - lm_loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1) - ) - - if not return_dict: - output = (lm_logits,) + outputs[1:] - return ((lm_loss,) + output) if lm_loss is not None else output - - return ( - CausalLMOutputWithPast( - loss=lm_loss, - logits=lm_logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ), - speculative_logits, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - **kwargs, - ): - input_shape = input_ids.shape - - # cut decoder_input_ids if past is used - if past_key_values and past_key_values[0] is not None: - input_ids = input_ids[:, -1:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - if attention_mask is None: - attention_mask = input_ids.new_ones(input_shape) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "attention_mask": attention_mask, - "past_key_values": past_key_values, - "position_ids": position_ids, - } - ) - - return model_inputs - - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple( - past_state.index_select(0, beam_idx) - for past_state in layer_past[:2] - ) - + layer_past[2:], - ) - return reordered_past diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/opt_modeling.py 
b/backends/gaudi/server/text_generation_server/models/custom_modeling/opt_modeling.py deleted file mode 100644 index db73ae84..00000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/opt_modeling.py +++ /dev/null @@ -1,864 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch OPT model.""" - -import random -from typing import List, Optional, Tuple, Union - -import torch -import torch.utils.checkpoint -from torch import nn - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from transformers.modeling_utils import PreTrainedModel -from transformers import OPTConfig -from text_generation_server.layers import ( - FastLinear, - TensorParallelColumnLinear, - TensorParallelEmbedding, - TensorParallelRowLinear, - SpeculativeHead, -) - -EPS = 1e-5 - - -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, - dtype: torch.dtype, - device: torch.device, - past_key_values_length: int = 0, -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full( - (tgt_len, tgt_len), - torch.tensor(torch.finfo(dtype).min, device=device), - device=device, - ) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat( - [ - torch.zeros( - tgt_len, past_key_values_length, dtype=dtype, device=device - ), - mask, - ], - dim=-1, - ) - return mask[None, None, :, :].expand( - bsz, 1, tgt_len, tgt_len + past_key_values_length - ) - - -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill( - inverted_mask.to(torch.bool), torch.finfo(dtype).min - ) - - -class OPTLearnedPositionalEmbedding(nn.Module): - """ - This module learns positional embeddings up to a fixed maximum size. 
- """ - - def __init__(self, prefix: str, weights): - super().__init__() - self.offset = 2 - self.weight = nn.Parameter( - weights.get_tensor( - f"{prefix if prefix else ''}decoder.embed_positions.weight" - ) - ) - - def forward( - self, attention_mask: torch.LongTensor, past_key_values_length: int = 0 - ): - """`input_ids_shape` is expected to be [bsz x seqlen].""" - attention_mask = attention_mask.long() - - # create positions depending on attention_mask - positions = ( - torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask - ).long() - 1 - - # cut positions if `past_key_values_length` is > 0 - positions = positions[:, past_key_values_length:] - - return torch.nn.functional.embedding(positions + self.offset, self.weight) - - -class OPTAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__( - self, - config, - prefix, - weights, - is_decoder: bool = False, - bias: bool = True, - process_group=None, - ): - super().__init__() - hidden_size = config.hidden_size - num_heads = config.num_attention_heads - - self.hidden_size = hidden_size - self.num_heads = num_heads - self.dropout = config.dropout - self.head_dim = hidden_size // num_heads - - if (self.head_dim * num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {num_heads})." - ) - self.scaling = self.head_dim**-0.5 - self.is_decoder = is_decoder - - process_group = weights.process_group - if self.num_heads % weights.process_group.size() != 0: - raise ValueError( - f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " - f"and `num_shards`: {weights.process_group.size()}" - ) - self.num_heads = self.num_heads // process_group.size() - self.hidden_size = self.hidden_size // process_group.size() - - self.q_proj = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.q_proj", weights=weights, bias=bias - ) - self.k_proj = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.k_proj", weights=weights, bias=bias - ) - self.v_proj = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.v_proj", weights=weights, bias=bias - ) - self.out_proj = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.out_proj", weights=weights, bias=bias - ) - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return ( - tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - .transpose(1, 2) - .contiguous() - ) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = 
self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = ( - attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - + attention_mask - ) - attn_weights = torch.max( - attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) - ) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437 - if attn_weights.dtype == torch.float16: - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(torch.float16) - else: - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view( - bsz, self.num_heads, tgt_len, src_len - ) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. 
- # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view( - bsz, self.num_heads, tgt_len, src_len - ) - attn_weights = attn_weights_reshaped.view( - bsz * self.num_heads, tgt_len, src_len - ) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `hidden_size` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned aross GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.hidden_size) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -class OPTDecoderLayer(nn.Module): - def __init__(self, layer_id: int, prefix: str, config: OPTConfig, weights): - super().__init__() - self.process_group = weights.process_group - self.hidden_size = config.hidden_size - self.self_attn = OPTAttention( - config, - prefix=f"{prefix}.self_attn", - weights=weights, - is_decoder=True, - bias=config.enable_bias, - ) - self.do_layer_norm_before = config.do_layer_norm_before - self.dropout = config.dropout - self.activation_fn = ACT2FN[config.activation_function] - - self.self_attn_layer_norm = nn.LayerNorm.load( - prefix=f"{prefix}.self_attn_layer_norm", weights=weights, eps=EPS - ) - self.fc1 = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.fc1", weights=weights, bias=config.enable_bias - ) - self.fc2 = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.fc2", weights=weights, bias=config.enable_bias - ) - self.final_layer_norm = nn.LayerNorm.load( - prefix=f"{prefix}.final_layer_norm", weights=weights, eps=EPS - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - ) -> Tuple[ - torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] - ]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size - `(encoder_attention_heads,)`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). 
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention - if self.do_layer_norm_before: - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - hidden_states = nn.functional.dropout( - hidden_states, p=self.dropout, training=self.training - ) - hidden_states = residual + hidden_states - - # 350m applies layer norm AFTER attention - if not self.do_layer_norm_before: - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Fully Connected - hidden_states_shape = hidden_states.shape - hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) - residual = hidden_states - - # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention - if self.do_layer_norm_before: - hidden_states = self.final_layer_norm(hidden_states) - - hidden_states = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) - - hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout( - hidden_states, p=self.dropout, training=self.training - ) - - hidden_states = (residual + hidden_states).view(hidden_states_shape) - - # 350m applies layer norm AFTER attention - if not self.do_layer_norm_before: - hidden_states = self.final_layer_norm(hidden_states) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -class OPTPreTrainedModel(PreTrainedModel): - config_class = OPTConfig - - -class OPTDecoder(OPTPreTrainedModel): - def __init__(self, prefix: str, config: OPTConfig, weights): - super().__init__(config) - self.dropout = config.dropout - self.layerdrop = config.layerdrop - self.padding_idx = config.pad_token_id - self.max_target_positions = config.max_position_embeddings - self.vocab_size = config.vocab_size - - prefix = prefix + "." 
if prefix else "" - - self.embed_tokens = TensorParallelEmbedding( - prefix=f"{prefix}decoder.embed_tokens", weights=weights - ) - self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights) - - if config.word_embed_proj_dim != config.hidden_size: - self.project_out = FastLinear.load( - config, - prefix=f"{prefix}decoder.project_out", - weights=weights, - bias=False, - ) - else: - self.project_out = None - - if config.word_embed_proj_dim != config.hidden_size: - self.project_in = FastLinear.load( - config, - prefix=f"{prefix}decoder.project_in", - weights=weights, - bias=False, - ) - else: - self.project_in = None - - # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility - # with checkpoints that have been fine-tuned before transformers v4.20.1 - # see https://github.com/facebookresearch/metaseq/pull/164 - if config.do_layer_norm_before and not config._remove_final_layer_norm: - self.final_layer_norm = nn.LayerNorm.load( - prefix=f"{prefix}decoder.final_layer_norm", weights=weights, eps=EPS - ) - else: - self.final_layer_norm = None - - self.layers = nn.ModuleList( - [ - OPTDecoderLayer( - layer_id, - prefix=f"{prefix}decoder.layers.{layer_id}", - config=config, - weights=weights, - ) - for layer_id in range(config.num_hidden_layers) - ] - ) - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask( - self, attention_mask, input_shape, inputs_embeds, past_key_values_length - ): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ).to(inputs_embeds.device) - combined_attention_mask = ( - expanded_attn_mask - if combined_attention_mask is None - else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. 
- - [What are attention masks?](../glossary#attention-mask) - head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the - cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those - that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of - all `decoder_input_ids` of shape `(batch_size, sequence_length)`. - - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
- """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - batch_size, seq_length = input_shape - past_key_values_length = ( - past_key_values[0][0].shape[2] if past_key_values is not None else 0 - ) - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values_length + seq_length - - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - batch_size, mask_seq_length, device=inputs_embeds.device - ) - causal_attention_mask = self._prepare_decoder_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - pos_embeds = self.embed_positions(attention_mask, past_key_values_length) - - if self.project_in is not None: - inputs_embeds = self.project_in(inputs_embeds) - - hidden_states = inputs_embeds + pos_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - # check if head_mask has a correct number of layers specified if desired - for attn_mask, mask_name in zip([head_mask], ["head_mask"]): - if attn_mask is not None: - if attn_mask.size()[0] != (len(self.layers)): - raise ValueError( - f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" - f" {head_mask.size()[0]}." 
- ) - - for idx, decoder_layer in enumerate(self.layers): - # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) - if output_hidden_states: - all_hidden_states += (hidden_states,) - - dropout_probability = random.uniform(0, 1) - if self.training and (dropout_probability < self.layerdrop): - continue - - past_key_value = ( - past_key_values[idx] if past_key_values is not None else None - ) - - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if self.final_layer_norm is not None: - hidden_states = self.final_layer_norm(hidden_states) - - if self.project_out is not None: - hidden_states = self.project_out(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None - ) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class OPTModel(OPTPreTrainedModel): - def __init__(self, prefix: str, config: OPTConfig, weights): - super().__init__(config) - self.decoder = OPTDecoder(prefix, config, weights) - # Initialize weights and apply final processing - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) - decoder_outputs = self.decoder( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - if not return_dict: - return decoder_outputs - - return BaseModelOutputWithPast( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - hidden_states=decoder_outputs.hidden_states, - attentions=decoder_outputs.attentions, - ) - - -class OPTForCausalLM(OPTPreTrainedModel): - def __init__(self, prefix, config, weights): - super().__init__(config) - if not prefix and any(s.startswith("model") for s in 
weights.routing.keys()): - prefix = "model" - - self.model = OPTModel(prefix, config, weights) - - self.lm_head = SpeculativeHead.load( - config, - prefix=f"{prefix + '.' if prefix else ''}decoder.embed_tokens", - weights=weights, - ) - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model.decoder( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - logits, speculative_logits = self.lm_head(outputs.last_hidden_state) - - loss = None - - return ( - CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ), - speculative_logits, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - **kwargs, - ): - if past_key_values: - input_ids = input_ids[:, -1:] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple( - past_state.index_select(0, beam_idx) for past_state in layer_past - ), - ) - return reordered_past diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/phi_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/phi_modeling.py deleted file mode 100644 index 3f2ed010..00000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/phi_modeling.py +++ /dev/null @@ -1,336 +0,0 @@ -# imlementation of the PhiModel and PhiForCausalLM classes - -import torch -import torch.distributed - -import math -from torch import nn -from typing import Optional, List, Tuple -from transformers.configuration_utils import PretrainedConfig -from transformers.modeling_outputs import CausalLMOutputWithPast - -from text_generation_server.layers import ( - TensorParallelRowLinear, - TensorParallelColumnLinear, - TensorParallelEmbedding, - SpeculativeHead, - FastLinear, -) - - -# PhiConfig is the 
configuration class for the PhiModel. -class PhiConfig(PretrainedConfig): - def __init__( - self, - vocab_size=51200, - n_positions=2048, - n_embd=2560, - n_layer=32, - n_inner=None, - n_head=32, - rotary_dim=32, - layer_norm_epsilon=1e-5, - tie_word_embeddings=False, - pad_vocab_size_multiple=64, - pad_token_id=0, - bos_token_id=1, - eos_token_id=2, - no_bias=False, - **kwargs, - ): - self.vocab_size = vocab_size - self.n_positions = n_positions - self.n_embd = n_embd - self.n_layer = n_layer - self.n_inner = n_inner - self.n_head = n_head - self.rotary_dim = rotary_dim - - self.layer_norm_epsilon = layer_norm_epsilon - self.tie_word_embeddings = tie_word_embeddings - self.pad_vocab_size_multiple = pad_vocab_size_multiple - self.pad_token_id = pad_token_id - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.no_bias = no_bias - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - -# RotaryEmbedding is a class that implements the rotary embedding. -class RotaryEmbedding(nn.Module): - def __init__(self, dim, max_seq_len): - super().__init__() - inv_freq = [1.0 / 10000.0 ** (i / dim) for i in range(0, dim, 2)] - inv_freq_len = len(inv_freq) - inv_freq = torch.tensor(inv_freq).view(1, inv_freq_len) - t = torch.arange(0, max_seq_len, dtype=torch.float).view(max_seq_len, 1) - freqs = t.matmul(inv_freq) - self.sin = freqs.sin() - self.cos = freqs.cos() - - def apply_rotary_emb_qkv(self, qkv, seqlen_offset): - b_size, seqlen, three, _, _headdim = qkv.shape - if three != 3: - raise Exception("unexpected shape for qkv") - _, rotary_dim = self.cos.shape - rotary_dim = rotary_dim * 2 - q_rot = qkv[:, :, 0, :, :rotary_dim] - q_pass = qkv[:, :, 0, :, rotary_dim:] - k_rot = qkv[:, :, 1, :, :rotary_dim] - k_pass = qkv[:, :, 1, :, rotary_dim:] - q12 = torch.chunk(q_rot, 2, dim=-1) - k12 = torch.chunk(k_rot, 2, dim=-1) - q1, q2 = q12[0], q12[1] - k1, k2 = k12[0], k12[1] - c = self.cos.narrow(0, seqlen_offset, seqlen).unsqueeze(1) - s = self.sin.narrow(0, seqlen_offset, seqlen).unsqueeze(1) - q_rot = torch.cat( - [ - q1 * c - q2 * s, - q1 * s + q2 * c, - ], - dim=-1, - ) - k_rot = torch.cat( - [ - k1 * c - k2 * s, - k1 * s + k2 * c, - ], - dim=-1, - ) - q = torch.cat([q_rot, q_pass], dim=-1) - k = torch.cat([k_rot, k_pass], dim=-1) - v = qkv[:, :, 2] - return q, k, v - - -# PhiCausalLMHead is the head of the PhiModel. It is a linear layer with a layer norm. -class PhiCausalLMHead(nn.Module): - def __init__(self, config, weights): - super().__init__() - self.ln = nn.LayerNorm.load( - prefix="lm_head.ln", - weights=weights, - eps=config.layer_norm_epsilon, - ) - self.linear = SpeculativeHead.load( - config=config, prefix="lm_head.linear", weights=weights - ) - - def forward(self, hidden_states): - hidden_states = self.ln(hidden_states) - hidden_states = self.linear(hidden_states) - return hidden_states - - -# PhiMHA is a multi-head attention layer. This layer uses an attention mask to prevent tokens from attending to subsequent tokens. 
-class PhiMHA(nn.Module): - def __init__(self, prefix, config, weights): - super().__init__() - self.Wqkv = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias - ) - self.out_proj = TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.out_proj", - weights=weights, - bias=not config.no_bias, - ) - self.op_size = config.n_embd - self.head_dim = int(config.n_embd / config.n_head) - self.num_heads = config.n_head - self.rotary_emb = RotaryEmbedding( - config.rotary_dim, - config.n_positions, - ) - self.softmax_scale = 1.0 / math.sqrt(self.head_dim) - - def forward( - self, - hidden_states, - past_kv_cache, - attention_mask=None, - ): - b_size, seq_len, _n_embd = hidden_states.shape - qkv = self.Wqkv(hidden_states) - qkv = qkv.view(b_size, seq_len, 3, self.num_heads, self.head_dim) - seqlen_offset = 0 if past_kv_cache is None else past_kv_cache[0].shape[1] - q, k, v = self.rotary_emb.apply_rotary_emb_qkv(qkv, seqlen_offset) - - # if there is a kv_cache, then we need to concatenate - if past_kv_cache is not None: - prev_k, prev_v = past_kv_cache - k = torch.cat([prev_k, k], dim=1) - v = torch.cat([prev_v, v], dim=1) - - past_kv_cache = [k, v] - attn_weights = torch.einsum("bthd,bshd->bhts", q, k * self.softmax_scale) - - if attention_mask is not None: - seqlen_k = k.shape[1] - seqlen_q = q.shape[1] - causal_mask = torch.triu( - torch.full((seqlen_q, seqlen_k), -10000.0, device=attn_weights.device), - 1, - ) - attn_weights = attn_weights + causal_mask.to(dtype=attn_weights.dtype) - - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) - attn_output = attn_weights.matmul(v.transpose(1, 2)).squeeze(0) - attn_output = ( - attn_output.view((b_size, self.num_heads, seq_len, self.head_dim)) - .transpose(1, 2) - .flatten(-2) - ) - return self.out_proj(attn_output), past_kv_cache - - -# PhiMLP is a multi-layer perceptron. It contains two linear layers with a gelu activation function. -class PhiMLP(nn.Module): - def __init__(self, prefix, config, weights): - super().__init__() - - self.n_inner = config.n_inner - self.fc1 = FastLinear.load( - config=config, - prefix=f"{prefix}.fc1", - weights=weights, - bias=False, - ) - self.fc2 = FastLinear.load( - config=config, - prefix=f"{prefix}.fc2", - weights=weights, - bias=False, - ) - self.activation = torch.nn.functional.gelu - - def forward(self, hidden_states): - hidden_states = self.fc1(hidden_states) - hidden_states = self.activation(hidden_states) - hidden_states = self.fc2(hidden_states) - return hidden_states - - -# PhiBlock is a single transformer block. It contains a layer norm, a multi-head attention layer and an multi-layer perceptron. 
-class PhiBlock(nn.Module): - def __init__(self, layer_id, config, weights): - super().__init__() - self.layer_id = layer_id - self.layer_norm = nn.LayerNorm.load( - prefix=f"{layer_id}.ln", weights=weights, eps=config.layer_norm_epsilon - ) - self.mixer = PhiMHA(prefix=f"{layer_id}.mixer", config=config, weights=weights) - self.mlp = PhiMLP(prefix=f"{layer_id}.mlp", config=config, weights=weights) - - def forward( - self, - hidden_states, - kv_cache, - attention_mask, - ): - residual = hidden_states - hidden_states = self.layer_norm(hidden_states) - attn_outputs, past_kv_cache = self.mixer( - hidden_states, kv_cache, attention_mask - ) - feed_forward_hidden_states = self.mlp(hidden_states) - out = attn_outputs + feed_forward_hidden_states + residual - return out, past_kv_cache - - -# PhiModel implements the embedding layer and the transformer blocks. -class PhiModel(nn.Module): - def __init__(self, prefix: str, config, weights): - super().__init__() - self.tp_rank = weights.process_group.rank() - self.tp_world_size = weights.process_group.size() - self.embed_tokens = TensorParallelEmbedding( - prefix=f"{prefix}.embd.wte", weights=weights - ) - self.blocks = nn.ModuleList( - [ - PhiBlock(f"{prefix}.h.{layer_id}", config, weights) - for layer_id in range(config.n_layer) - ] - ) - - def forward( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, - attention_mask: Optional[torch.ByteTensor] = None, - return_dict: Optional[bool] = None, - use_cache: Optional[bool] = None, - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: - hidden_states = self.embed_tokens(input_ids) - seq_len = hidden_states.shape[1] - mask = None if seq_len <= 1 else attention_mask - - past_key_values = ( - [None] * len(self.blocks) if past_key_values is None else past_key_values - ) - - for index, block in enumerate(self.blocks): - hidden_states, new_key_values = block( - hidden_states, past_key_values[index], mask - ) - past_key_values[index] = new_key_values - - return hidden_states, past_key_values - - -# PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object. 
-class PhiForCausalLM(torch.nn.Module): - def __init__(self, prefix: str, config, weights): - super().__init__() - - if not prefix: - prefix = "transformer" - else: - prefix = f"{prefix}.transformer" - - self.model = PhiModel(prefix, config, weights) - self.lm_head = PhiCausalLMHead(config, weights) - - def forward( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, - attention_mask: Optional[torch.ByteTensor] = None, - return_dict: Optional[bool] = None, - use_cache: Optional[bool] = None, - labels: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: - model_output = self.model( - input_ids, past_key_values, attention_mask, return_dict, use_cache - ) - logits = self.lm_head(model_output[0]) - - loss = None - if labels is not None: - loss = nn.CrossEntropyLoss()( - logits[:, :-1].view(-1, logits.size(-1)), labels[:, 1:].view(-1) - ) - - if not return_dict: - return ( - ((loss,) + (logits,) + model_output[1:]) - if loss is not None - else (logits,) + model_output[1:] - ) - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=model_output[1], - hidden_states=None, - attentions=None, - ) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/t5_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/t5_modeling.py deleted file mode 100644 index e6666acd..00000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/t5_modeling.py +++ /dev/null @@ -1,1227 +0,0 @@ -# coding=utf-8 -# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" PyTorch T5 model.""" - -import copy -import math -import warnings -from typing import Optional, Tuple, Union - -from loguru import logger - -import torch -import torch.distributed -from torch import nn -from torch.nn import CrossEntropyLoss - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import ( - BaseModelOutput, - BaseModelOutputWithPastAndCrossAttentions, - Seq2SeqLMOutput, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS -from transformers.utils import ( - is_torch_fx_proxy, -) -from transformers import T5Config -from text_generation_server.layers import ( - TensorParallelColumnLinear, - TensorParallelEmbedding, - TensorParallelRowLinear, - SpeculativeHead, -) - -# copied from https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/t5/modeling_t5.py#L1316 -# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask -__HEAD_MASK_WARNING_MSG = """ -The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. 
Currently, -`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. -If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, -num_heads)`. -""" - - -class PartialTPEmbedding(nn.Module): - def __init__(self, prefix: str, weights): - super().__init__() - weight = weights.get_sharded(f"{prefix}.weight", dim=1) - self.weight = nn.Parameter(weight) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return torch.nn.functional.embedding(input, self.weight) - - -@torch.jit.script -def layer_norm(hidden_states, weight, epsilon): - # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean - # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated - # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for - # half-precision inputs is done in fp32 - - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + epsilon) - - # convert into half-precision if necessary - if weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(weight.dtype) - - return weight * hidden_states - - -class T5LayerNorm(nn.Module): - def __init__(self, prefix, weights, eps=1e-6): - """ - Construct a layernorm module in the T5 style. No bias and no subtraction of mean. - """ - super().__init__() - weight = weights.get_tensor(f"{prefix}.weight") - self.weight = nn.Parameter(weight) - self.variance_epsilon = torch.tensor(eps) - - def forward(self, hidden_states): - return layer_norm(hidden_states, self.weight, self.variance_epsilon) - - -try: - from apex.normalization import FusedRMSNorm - - T5LayerNorm = FusedRMSNorm # noqa - - logger.info( - "Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm" - ) -except ImportError: - # using the normal T5LayerNorm - pass -except Exception: - logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm") - pass - -ALL_LAYERNORM_LAYERS.append(T5LayerNorm) - - -class T5DenseActDense(nn.Module): - def __init__(self, config: T5Config, prefix, weights): - super().__init__() - self.wi = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.wi", weights=weights, bias=False - ) - - ### XXX: T5 models do not handle well both f16 and quantization. - ### Overidding specifically this layer for that reason. - ### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316 - ### https://github.com/huggingface/transformers/issues/20287 - _q = config.quantize - _dtype = weights.dtype - weights.dtype = torch.float32 - config.quantize = None - self.wo_cast = (torch.float32, _dtype) - self.wo = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.wo", weights=weights, bias=False - ) - weights.dtype = _dtype - config.quantize = _q - - self.dropout = nn.Dropout(config.dropout_rate) - self.act = ( - ACT2FN[config.dense_act_fn] - if "gelu" not in config.dense_act_fn - else lambda x: torch.nn.functional.gelu(x, approximate="tanh") - ) - - def forward(self, hidden_states): - hidden_states = self.wi(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.dropout(hidden_states) - - hidden_states = hidden_states.to(dtype=self.wo_cast[0]) - hidden_states = self.wo(hidden_states) - # XXX: Recasting is already done within the layer norm. 
- # Casting back to float16 here modifies results - # hidden_states = hidden_states.to(dtype=self.wo_cast[1]) - return hidden_states - - -class T5DenseGatedActDense(nn.Module): - def __init__(self, config: T5Config, prefix, weights): - super().__init__() - self.wi_0 = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.wi_0", weights=weights, bias=False - ) - self.wi_1 = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.wi_1", weights=weights, bias=False - ) - ### XXX: T5 models do not handle well both f16 and quantization. - ### Overidding specifically this layer for that reason. - ### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316 - ### https://github.com/huggingface/transformers/issues/20287 - _q = config.quantize - _dtype = weights.dtype - weights.dtype = torch.float32 - config.quantize = None - self.wo_cast = (torch.float32, _dtype) - self.wo = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.wo", weights=weights, bias=False - ) - weights.dtype = _dtype - config.quantize = _q - - self.dropout = nn.Dropout(config.dropout_rate) - self.act = ( - ACT2FN[config.dense_act_fn] - if "gelu" not in config.dense_act_fn - else lambda x: torch.nn.functional.gelu(x, approximate="tanh") - ) - - def forward(self, hidden_states): - hidden_gelu = self.act(self.wi_0(hidden_states)) - hidden_linear = self.wi_1(hidden_states) - hidden_states = hidden_gelu * hidden_linear - hidden_states = self.dropout(hidden_states) - - hidden_states = hidden_states.to(dtype=self.wo_cast[0]) - hidden_states = self.wo(hidden_states) - # XXX: Recasting is already done within the layer norm. - # Casting back to float16 here modifies results - # hidden_states = hidden_states.to(dtype=self.wo_cast[1]) - return hidden_states - - -class T5LayerFF(nn.Module): - def __init__(self, config: T5Config, prefix, weights): - super().__init__() - if config.is_gated_act: - self.DenseReluDense = T5DenseGatedActDense( - config, prefix=f"{prefix}.DenseReluDense", weights=weights - ) - else: - self.DenseReluDense = T5DenseActDense( - config, prefix=f"{prefix}.DenseReluDense", weights=weights - ) - - self.layer_norm = T5LayerNorm( - prefix=f"{prefix}.layer_norm", - weights=weights, - eps=config.layer_norm_epsilon, - ) - self.dropout = nn.Dropout(config.dropout_rate) - - def forward(self, hidden_states): - forwarded_states = self.layer_norm(hidden_states) - forwarded_states = self.DenseReluDense(forwarded_states) - hidden_states = hidden_states + self.dropout(forwarded_states) - return hidden_states - - -class T5Attention(nn.Module): - def __init__( - self, config: T5Config, prefix, weights, has_relative_attention_bias=False - ): - super().__init__() - self.is_decoder = config.is_decoder - self.has_relative_attention_bias = has_relative_attention_bias - self.relative_attention_num_buckets = config.relative_attention_num_buckets - self.relative_attention_max_distance = config.relative_attention_max_distance - self.d_model = config.d_model - self.key_value_proj_dim = config.d_kv - self.n_heads = config.num_heads - self.dropout = config.dropout_rate - self.inner_dim = self.n_heads * self.key_value_proj_dim - - process_group = weights.process_group - # Mesh TensorFlow initialization to avoid scaling before softmax - assert self.n_heads % process_group.size() == 0 - self.q = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.q", weights=weights, bias=False - ) - self.k = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.k", weights=weights, 
bias=False - ) - self.v = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.v", weights=weights, bias=False - ) - self.o = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.o", weights=weights, bias=False - ) - if self.n_heads % weights.process_group.size() != 0: - raise ValueError( - f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} " - f"and `num_shards`: {weights.process_group.size()}" - ) - self.n_heads = self.n_heads // process_group.size() - self.inner_dim = self.inner_dim // process_group.size() - - if self.has_relative_attention_bias: - self.relative_attention_bias = PartialTPEmbedding( - prefix=f"{prefix}.relative_attention_bias", weights=weights - ) - - @staticmethod - def _relative_position_bucket( - relative_position, bidirectional=True, num_buckets=32, max_distance=128 - ): - """ - Adapted from Mesh Tensorflow: - https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 - - Translate relative position to a bucket number for relative attention. The relative position is defined as - memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to - position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for - small absolute relative_position and larger buckets for larger absolute relative_positions. All relative - positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. - This should allow for more graceful generalization to longer sequences than the model has been trained on - - Args: - relative_position: an int32 Tensor - bidirectional: a boolean - whether the attention is bidirectional - num_buckets: an integer - max_distance: an integer - - Returns: - a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) - """ - relative_buckets = 0 - if bidirectional: - num_buckets //= 2 - relative_buckets += (relative_position > 0).to(torch.long) * num_buckets - relative_position = torch.abs(relative_position) - else: - relative_position = -torch.min( - relative_position, torch.zeros_like(relative_position) - ) - # now relative_position is in the range [0, inf) - - # half of the buckets are for exact increments in positions - max_exact = num_buckets // 2 - is_small = relative_position < max_exact - - # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance - relative_position_if_large = max_exact + ( - torch.log(relative_position.float() / max_exact) - / math.log(max_distance / max_exact) - * (num_buckets - max_exact) - ).to(torch.long) - relative_position_if_large = torch.min( - relative_position_if_large, - torch.full_like(relative_position_if_large, num_buckets - 1), - ) - - relative_buckets += torch.where( - is_small, relative_position, relative_position_if_large - ) - return relative_buckets - - def compute_bias(self, query_length, key_length, device=None): - """Compute binned relative position bias""" - if device is None: - device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[ - :, None - ] - memory_position = torch.arange(key_length, dtype=torch.long, device=device)[ - None, : - ] - relative_position = ( - memory_position - context_position - ) # shape (query_length, key_length) - relative_position_bucket = self._relative_position_bucket( - 
relative_position, # shape (query_length, key_length) - bidirectional=(not self.is_decoder), - num_buckets=self.relative_attention_num_buckets, - max_distance=self.relative_attention_max_distance, - ) - values = self.relative_attention_bias( - relative_position_bucket - ) # shape (query_length, key_length, num_heads) - values = values.permute([2, 0, 1]).unsqueeze( - 0 - ) # shape (1, num_heads, query_length, key_length) - return values - - def forward( - self, - hidden_states, - mask=None, - key_value_states=None, - position_bias=None, - past_key_value=None, - layer_head_mask=None, - query_length=None, - use_cache=False, - output_attentions=False, - ): - """ - Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). - """ - # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) - - batch_size, seq_length = hidden_states.shape[:2] - - real_seq_length = seq_length - - if past_key_value is not None: - assert ( - len(past_key_value) == 2 - ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states" - real_seq_length += ( - past_key_value[0].shape[2] if query_length is None else query_length - ) - - key_length = ( - real_seq_length if key_value_states is None else key_value_states.shape[1] - ) - - def shape(states): - """projection""" - return states.view( - batch_size, -1, self.n_heads, self.key_value_proj_dim - ).transpose(1, 2) - - def unshape(states): - """reshape""" - return ( - states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - ) - - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - - if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - - # get query states - query_states = shape( - self.q(hidden_states) - ) # (batch_size, n_heads, seq_length, dim_per_head) - - # get key/value states - key_states = project( - hidden_states, - self.k, - key_value_states, - past_key_value[0] if past_key_value is not None else None, - ) - value_states = project( - hidden_states, - self.v, - key_value_states, - past_key_value[1] if past_key_value is not None else None, - ) - - # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 - - if position_bias is None: - if not self.has_relative_attention_bias: - position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), - 
device=scores.device, - dtype=scores.dtype, - ) - else: - position_bias = self.compute_bias( - real_seq_length, key_length, device=scores.device - ) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] - - if mask is not None: - position_bias = ( - position_bias + mask - ) # (batch_size, n_heads, seq_length, key_length) - - position_bias_masked = position_bias - - scores += position_bias_masked - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, n_heads, seq_length, key_length) - - # Mask heads if we want to - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask - - attn_output = unshape( - torch.matmul(attn_weights, value_states) - ) # (batch_size, seq_length, dim) - attn_output = self.o(attn_output) - - present_key_value_state = ( - (key_states, value_states) if (self.is_decoder and use_cache) else None - ) - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) - - if output_attentions: - outputs = outputs + (attn_weights,) - return outputs - - -class T5LayerSelfAttention(nn.Module): - def __init__(self, config, prefix, weights, has_relative_attention_bias=False): - super().__init__() - self.SelfAttention = T5Attention( - config, - prefix=f"{prefix}.SelfAttention", - weights=weights, - has_relative_attention_bias=has_relative_attention_bias, - ) - self.layer_norm = T5LayerNorm( - prefix=f"{prefix}.layer_norm", - weights=weights, - eps=config.layer_norm_epsilon, - ) - self.dropout = nn.Dropout(config.dropout_rate) - - def forward( - self, - hidden_states, - attention_mask=None, - position_bias=None, - layer_head_mask=None, - past_key_value=None, - use_cache=False, - output_attentions=False, - ): - normed_hidden_states = self.layer_norm(hidden_states) - attention_output = self.SelfAttention( - normed_hidden_states, - mask=attention_mask, - position_bias=position_bias, - layer_head_mask=layer_head_mask, - past_key_value=past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - ) - hidden_states = hidden_states + self.dropout(attention_output[0]) - outputs = (hidden_states,) + attention_output[ - 1: - ] # add attentions if we output them - return outputs - - -class T5LayerCrossAttention(nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - self.EncDecAttention = T5Attention( - config, - prefix=f"{prefix}.EncDecAttention", - weights=weights, - has_relative_attention_bias=False, - ) - self.layer_norm = T5LayerNorm( - prefix=f"{prefix}.layer_norm", - weights=weights, - eps=config.layer_norm_epsilon, - ) - self.dropout = nn.Dropout(config.dropout_rate) - - def forward( - self, - hidden_states, - key_value_states, - attention_mask=None, - position_bias=None, - layer_head_mask=None, - past_key_value=None, - use_cache=False, - query_length=None, - output_attentions=False, - ): - normed_hidden_states = self.layer_norm(hidden_states) - attention_output = self.EncDecAttention( - normed_hidden_states, - mask=attention_mask, - key_value_states=key_value_states, - position_bias=position_bias, - layer_head_mask=layer_head_mask, - past_key_value=past_key_value, - use_cache=use_cache, - query_length=query_length, - output_attentions=output_attentions, - ) - layer_output = 
hidden_states + self.dropout(attention_output[0]) - outputs = (layer_output,) + attention_output[ - 1: - ] # add attentions if we output them - return outputs - - -class T5Block(nn.Module): - def __init__(self, config, prefix, weights, has_relative_attention_bias: bool): - super().__init__() - self.is_decoder = config.is_decoder - self.layer = nn.ModuleList() - self.layer.append( - T5LayerSelfAttention( - config, - prefix=f"{prefix}.layer.0", - weights=weights, - has_relative_attention_bias=has_relative_attention_bias, - ) - ) - if self.is_decoder: - i = 2 - self.layer.append( - T5LayerCrossAttention( - config, prefix=f"{prefix}.layer.1", weights=weights - ) - ) - else: - i = 1 - - self.layer.append( - T5LayerFF(config, prefix=f"{prefix}.layer.{i}", weights=weights) - ) - - def forward( - self, - hidden_states, - attention_mask=None, - position_bias=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - encoder_decoder_position_bias=None, - layer_head_mask=None, - cross_attn_layer_head_mask=None, - past_key_value=None, - use_cache=False, - output_attentions=False, - return_dict=True, - ): - if past_key_value is not None: - if not self.is_decoder: - logger.warning( - "`past_key_values` is passed to the encoder. Please make sure this is intended." - ) - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - - self_attention_outputs = self.layer[0]( - hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, - layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - ) - hidden_states, present_key_value_state = self_attention_outputs[:2] - attention_outputs = self_attention_outputs[ - 2: - ] # Keep self-attention outputs and relative position weights - - # clamp inf values to enable fp16 training - if hidden_states.dtype == torch.float16: - clamp_value = torch.where( - torch.isinf(hidden_states).any(), - torch.finfo(hidden_states.dtype).max - 1000, - torch.finfo(hidden_states.dtype).max, - ) - hidden_states = torch.clamp( - hidden_states, min=-clamp_value, max=clamp_value - ) - - do_cross_attention = self.is_decoder and encoder_hidden_states is not None - if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. 
Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None - - cross_attention_outputs = self.layer[1]( - hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - position_bias=encoder_decoder_position_bias, - layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - query_length=query_length, - use_cache=use_cache, - output_attentions=output_attentions, - ) - hidden_states = cross_attention_outputs[0] - - # clamp inf values to enable fp16 training - if hidden_states.dtype == torch.float16: - clamp_value = torch.where( - torch.isinf(hidden_states).any(), - torch.finfo(hidden_states.dtype).max - 1000, - torch.finfo(hidden_states.dtype).max, - ) - hidden_states = torch.clamp( - hidden_states, min=-clamp_value, max=clamp_value - ) - - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = ( - present_key_value_state + cross_attention_outputs[1] - ) - - # Keep cross-attention outputs and relative position weights - attention_outputs = attention_outputs + cross_attention_outputs[2:] - - # Apply Feed Forward layer - hidden_states = self.layer[-1](hidden_states) - - # clamp inf values to enable fp16 training - if hidden_states.dtype == torch.float16: - clamp_value = torch.where( - torch.isinf(hidden_states).any(), - torch.finfo(hidden_states.dtype).max - 1000, - torch.finfo(hidden_states.dtype).max, - ) - hidden_states = torch.clamp( - hidden_states, min=-clamp_value, max=clamp_value - ) - - outputs = (hidden_states,) - - if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs - else: - outputs = outputs + attention_outputs - - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) - - -class T5PreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = T5Config - - def _shift_right(self, input_ids): - decoder_start_token_id = self.config.decoder_start_token_id - pad_token_id = self.config.pad_token_id - - assert decoder_start_token_id is not None, ( - "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id." - " See T5 docs for more information" - ) - - # shift inputs to the right - if is_torch_fx_proxy(input_ids): - # Item assignment is not supported natively for proxies. - shifted_input_ids = torch.full( - input_ids.shape[:-1] + (1,), decoder_start_token_id - ) - shifted_input_ids = torch.cat( - [shifted_input_ids, input_ids[..., :-1]], dim=-1 - ) - else: - shifted_input_ids = input_ids.new_zeros(input_ids.shape) - shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() - shifted_input_ids[..., 0] = decoder_start_token_id - - assert ( - pad_token_id is not None - ), "self.model.config.pad_token_id has to be defined." 
- # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) - - return shifted_input_ids - - -class T5Stack(T5PreTrainedModel): - def __init__(self, config, prefix, weights, embed_tokens): - super().__init__(config) - - self.is_decoder = config.is_decoder - - self.embed_tokens = embed_tokens - self.block = nn.ModuleList( - [ - T5Block( - config, - prefix=f"{prefix}.block.{layer_id}", - weights=weights, - has_relative_attention_bias=(layer_id == 0), - ) - for layer_id in range(config.num_layers) - ] - ) - self.final_layer_norm = T5LayerNorm( - prefix=f"{prefix}.final_layer_norm", - weights=weights, - eps=config.layer_norm_epsilon, - ) - self.dropout = nn.Dropout(config.dropout_rate) - - def forward( - self, - input_ids=None, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - inputs_embeds=None, - head_mask=None, - cross_attn_head_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - # Model parallel - use_cache = use_cache if use_cache is not None else self.config.use_cache - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - if input_ids is not None and inputs_embeds is not None: - err_msg_prefix = "decoder_" if self.is_decoder else "" - raise ValueError( - f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" - ) - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - err_msg_prefix = "decoder_" if self.is_decoder else "" - raise ValueError( - f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds" - ) - - if inputs_embeds is None: - assert ( - self.embed_tokens is not None - ), "You have to initialize the model with valid token embeddings" - inputs_embeds = self.embed_tokens(input_ids) - - batch_size, seq_length = input_shape - - # required mask seq length can be calculated via length of past - mask_seq_length = ( - past_key_values[0][0].shape[2] + seq_length - if past_key_values is not None - else seq_length - ) - - if use_cache is True: - assert ( - self.is_decoder - ), f"`use_cache` can only be set to `True` if {self} is used as a decoder" - - if attention_mask is None: - attention_mask = torch.ones( - batch_size, mask_seq_length, device=inputs_embeds.device - ) - if ( - self.is_decoder - and encoder_attention_mask is None - and encoder_hidden_states is not None - ): - encoder_seq_length = encoder_hidden_states.shape[1] - encoder_attention_mask = torch.ones( - batch_size, - encoder_seq_length, - device=inputs_embeds.device, - dtype=torch.long, - ) - - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. 
- extended_attention_mask = self.get_extended_attention_mask( - attention_mask, input_shape - ) - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.is_decoder and encoder_hidden_states is not None: - ( - encoder_batch_size, - encoder_sequence_length, - _, - ) = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones( - encoder_hidden_shape, device=inputs_embeds.device - ) - encoder_extended_attention_mask = self.invert_attention_mask( - encoder_attention_mask - ) - else: - encoder_extended_attention_mask = None - - # Prepare head mask if needed - head_mask = self.get_head_mask(head_mask, self.config.num_layers) - cross_attn_head_mask = self.get_head_mask( - cross_attn_head_mask, self.config.num_layers - ) - present_key_value_states = () if use_cache else None - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - all_cross_attentions = () if (output_attentions and self.is_decoder) else None - position_bias = None - encoder_decoder_position_bias = None - - hidden_states = self.dropout(inputs_embeds) - - for i, (layer_module, past_key_value) in enumerate( - zip(self.block, past_key_values) - ): - layer_head_mask = head_mask[i] - cross_attn_layer_head_mask = cross_attn_head_mask[i] - # Model parallel - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_outputs = layer_module( - hidden_states, - attention_mask=extended_attention_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - layer_head_mask=layer_head_mask, - cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - # layer_outputs is a tuple with: - # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) - if use_cache is False: - layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - - hidden_states, present_key_value_state = layer_outputs[:2] - - # We share the position biases between the layers - the first layer store them - # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), - # (cross-attention position bias), (cross-attention weights) - position_bias = layer_outputs[2] - if self.is_decoder and encoder_hidden_states is not None: - encoder_decoder_position_bias = layer_outputs[ - 4 if output_attentions else 3 - ] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + ( - present_key_value_state, - ) - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[3],) - if self.is_decoder: - all_cross_attentions = all_cross_attentions + (layer_outputs[5],) - - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.dropout(hidden_states) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - present_key_value_states, - all_hidden_states, - all_attentions, - all_cross_attentions, - ] 
- if v is not None - ) - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=present_key_value_states, - hidden_states=all_hidden_states, - attentions=all_attentions, - cross_attentions=all_cross_attentions, - ) - - -class T5ForConditionalGeneration(T5PreTrainedModel): - def __init__(self, config: T5Config, weights): - super().__init__(config) - self.model_dim = config.d_model - - self.shared = TensorParallelEmbedding(prefix="shared", weights=weights) - - encoder_config = copy.deepcopy(config) - encoder_config.is_decoder = False - encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False - self.encoder = T5Stack( - config=encoder_config, - prefix="encoder", - weights=weights, - embed_tokens=self.shared, - ) - - decoder_config = copy.deepcopy(config) - decoder_config.is_decoder = True - decoder_config.is_encoder_decoder = False - decoder_config.num_layers = config.num_decoder_layers - self.decoder = T5Stack( - config=decoder_config, - prefix="decoder", - weights=weights, - embed_tokens=self.shared, - ) - - try: - self.lm_head = SpeculativeHead.load( - config, prefix="lm_head", weights=weights - ) - except RuntimeError: - # Some models like t5-small were saved with shared weights unlike flan - # Since they are declared as the same arch we have no choice but hope - # that this is OK instead of using a proper flag. - self.lm_head = SpeculativeHead.load( - config, prefix="shared", weights=weights - ) - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.BoolTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - decoder_head_mask: Optional[torch.FloatTensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - decoder_inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask - if head_mask is not None and decoder_head_mask is None: - if self.config.num_layers == self.config.num_decoder_layers: - warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) - decoder_head_mask = head_mask - - # Encode if needed (training, first prediction pass) - if encoder_outputs is None: - # Convert encoder inputs in embeddings if needed - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): - encoder_outputs = BaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if 
len(encoder_outputs) > 2 else None, - ) - - hidden_states = encoder_outputs[0] - - if ( - labels is not None - and decoder_input_ids is None - and decoder_inputs_embeds is None - ): - # get decoder inputs from shifting lm labels to the right - decoder_input_ids = self._shift_right(labels) - - # Decode - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - inputs_embeds=decoder_inputs_embeds, - past_key_values=past_key_values, - encoder_hidden_states=hidden_states, - encoder_attention_mask=attention_mask, - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = decoder_outputs[0] - - if self.config.tie_word_embeddings: - # Rescale output before projecting on vocab - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 - sequence_output = sequence_output * (self.model_dim**-0.5) - - logits, speculative_logits = self.lm_head(sequence_output) - - loss = None - if labels is not None: - loss_fct = CrossEntropyLoss(ignore_index=-100) - # move labels to correct device to enable PP - labels = labels.to(logits.device) - loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) - # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 - - if not return_dict: - output = (logits,) + decoder_outputs[1:] + encoder_outputs - return ((loss,) + output) if loss is not None else output - - return ( - Seq2SeqLMOutput( - loss=loss, - logits=logits, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ), - speculative_logits, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - decoder_attention_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past is used - if past_key_values is not None: - input_ids = input_ids[:, -1:] - - return { - "decoder_input_ids": input_ids, - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "decoder_attention_mask": decoder_attention_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, - } - - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): - return self._shift_right(labels) - - def _reorder_cache(self, past_key_values, beam_idx): - # if decoder past is not included in output - # speedy decoding is disabled and no need to reorder - if past_key_values is None: - logger.warning( - "You might want to consider setting `use_cache=True` to speed up decoding" - ) - return past_key_values - - reordered_decoder_past = () - for layer_past_states in past_key_values: - # get the correct batch idx from layer past batch dim - # batch dim of `past` is at 2nd position - reordered_layer_past_states = () - for layer_past_state in 
layer_past_states: - # need to set correct `past` for each of the four key / value states - reordered_layer_past_states = reordered_layer_past_states + ( - layer_past_state.index_select( - 0, beam_idx.to(layer_past_state.device) - ), - ) - - assert reordered_layer_past_states[0].shape == layer_past_states[0].shape - assert len(reordered_layer_past_states) == len(layer_past_states) - - reordered_decoder_past = reordered_decoder_past + ( - reordered_layer_past_states, - ) - return reordered_decoder_past From f95aa426602fcfdf589eee55ea2e19eeafe438ba Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Wed, 19 Mar 2025 23:27:27 -0700 Subject: [PATCH 10/35] multi-modality initial PR Signed-off-by: Wang, Yi A --- .../text_generation_server/models/__init__.py | 51 +- .../custom_modeling/flash_llama_modeling.py | 2 +- .../custom_modeling/flash_llava_next.py | 290 +++++ .../models/custom_modeling/flash_mllama.py | 996 ++++++++++++++++++ .../models/flash_vlm_causal_lm.py | 480 +++++++++ .../models/mllama_causal_lm.py | 45 +- .../models/pali_gemma.py | 6 +- .../server/text_generation_server/server.py | 6 + 8 files changed, 1829 insertions(+), 47 deletions(-) create mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py create mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py create mode 100644 backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py index 6833ecce..7dac910e 100644 --- a/backends/gaudi/server/text_generation_server/models/__init__.py +++ b/backends/gaudi/server/text_generation_server/models/__init__.py @@ -19,18 +19,10 @@ from text_generation_server.models.model import Model from text_generation_server.models.causal_lm import CausalLM from text_generation_server.models.bloom import BLOOM from text_generation_server.models.starcoder import StarCoder -from text_generation_server.models.vlm_causal_lm import VlmCausalLM -from text_generation_server.models.custom_modeling.mllama import ( - MllamaForConditionalGeneration, -) -from text_generation_server.models.custom_modeling.llava_next import ( - LlavaNextForConditionalGeneration, -) from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import ( PhiMoEConfig, ) -# from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch from text_generation_server.utils.adapter import ( AdapterParameters, build_layer_weight_lookup, @@ -58,8 +50,8 @@ if ATTENTION == "paged": try: from text_generation_server.models.flash_causal_lm import FlashCausalLM - from text_generation_server.models.vlm_causal_lm import VlmCausalLM - from text_generation_server.models.mllama_causal_lm import MllamaCausalLM + from text_generation_server.models.flash_vlm_causal_lm import FlashVlmCausalLM + from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLM from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import ( FlashDeepseekV2ForCausalLM, DeepseekV2Config, @@ -101,12 +93,12 @@ try: FlashPhiForCausalLM, ) from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM - from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch - from text_generation_server.models.custom_modeling.mllama import ( - MllamaForConditionalGeneration, + from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch + 
from text_generation_server.models.custom_modeling.flash_mllama import ( + FlashMllamaForConditionalGeneration, ) - from text_generation_server.models.custom_modeling.llava_next import ( - LlavaNextForConditionalGeneration, + from text_generation_server.models.custom_modeling.flash_llava_next import ( + FlashLlavaNextForConditionalGeneration, ) from text_generation_server.models.custom_modeling.flash_santacoder_modeling import ( @@ -751,7 +743,7 @@ def get_model( trust_remote_code=trust_remote_code, ) elif model_type == QWEN2_VL: - return VlmCausalLM( + return FlashVlmCausalLM( model_id=model_id, model_class=Qwen2VLForConditionalGeneration, revision=revision, @@ -764,7 +756,7 @@ def get_model( lora_adapter_ids=lora_adapter_ids, ) elif model_type == QWEN2_5_VL: - return VlmCausalLM( + return FlashVlmCausalLM( model_id=model_id, model_class=Qwen2_5VLForConditionalGeneration, revision=revision, @@ -779,10 +771,10 @@ def get_model( processor_class=Qwen2_5_VLProcessor, ) elif model_type == MLLAMA: - return MllamaCausalLM( + return FlashMllamaCausalLM( model_id=model_id, - model_class=MllamaForConditionalGeneration, - batch_class=MllamaCausalLMBatch, + model_class=FlashMllamaForConditionalGeneration, + batch_class=FlashMllamaCausalLMBatch, revision=revision, quantize=quantize, speculator=speculator, @@ -792,7 +784,7 @@ def get_model( lora_adapter_ids=lora_adapter_ids, ) elif model_type == IDEFICS2: - return VlmCausalLM( + return FlashVlmCausalLM( model_id=model_id, model_class=Idefics2ForConditionalGeneration, revision=revision, @@ -807,7 +799,7 @@ def get_model( processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}}, ) elif model_type == IDEFICS3: - return VlmCausalLM( + return FlashVlmCausalLM( model_id=model_id, model_class=Idefics3ForConditionalGeneration, revision=revision, @@ -822,7 +814,7 @@ def get_model( processor_kwargs={"size": {"longest_edge": 1456}}, ) elif model_type == PALIGEMMA: - return VlmCausalLM( + return FlashVlmCausalLM( model_id=model_id, model_class=PaliGemmaForConditionalGeneration, revision=revision, @@ -837,8 +829,8 @@ def get_model( batch_class=PaliGemmaBatch, ) elif model_type == LLAVA_NEXT: - return VlmCausalLM( - model_class=LlavaNextForConditionalGeneration, + return FlashVlmCausalLM( + model_class=FlashLlavaNextForConditionalGeneration, model_id=model_id, revision=revision, quantize=quantize, @@ -847,6 +839,15 @@ def get_model( kv_cache_dtype=kv_cache_dtype, trust_remote_code=trust_remote_code, ) + + from text_generation_server.models.vlm_causal_lm import VlmCausalLM + from text_generation_server.models.custom_modeling.mllama import ( + MllamaForConditionalGeneration, + ) + from text_generation_server.models.custom_modeling.llava_next import ( + LlavaNextForConditionalGeneration, + ) + adapt_transformers_to_gaudi() if SDP_ON_BF16 == 1: torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index a0c4fb8c..7deb6cbf 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -503,7 +503,7 @@ class FlashLlamaModel(torch.nn.Module): # Skip first and last layers for layer_id in range(1, config.num_hidden_layers - 1): if layer_id in self.cross_attention_layers: - from 
text_generation_server.models.custom_modeling.mllama import ( + from text_generation_server.models.custom_modeling.flash_mllama import ( FlashLlamaCrossLayer, ) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py new file mode 100644 index 00000000..0e3487dc --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py @@ -0,0 +1,290 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Llava-NeXT model.""" + +from typing import List, Optional, Tuple + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +from transformers.image_processing_utils import select_best_resolution + +from text_generation_server.layers.attention import Seqlen +from text_generation_server.models.custom_modeling.vlm import ( + load_text_model, + load_vision_model, +) +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelRowLinear, +) + + +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + """ + Calculate the shape of the image patch grid after the preprocessing for images of any resolution. + + Args: + image_size (`tuple`): + The size of the input image in the format (height, width). + grid_pinpoints (`List`): + A list containing possible resolutions. Each item in the list should be a tuple or list + of the form `(height, width)`. + patch_size (`int`): + The size of each image patch. + + Returns: + tuple: The shape of the image patch grid in the format (height, width). + """ + if not isinstance(grid_pinpoints, list): + raise ValueError("grid_pinpoints should be a list of tuples or lists") + + height, width = select_best_resolution(image_size, grid_pinpoints) + return height // patch_size, width // patch_size + + +def unpad_image(tensor, original_size): + """ + Unpads a PyTorch tensor of a padded and resized image. + + Args: + tensor (`torch.Tensor`): + The image tensor, assumed to be of shape (num_channels, height, width). + original_size (`tuple`): + The original size of the image (height, width). + + Returns: + `torch.Tensor`: The unpadded image tensor. 
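+
+    Example (illustrative):
+        A (3, 336, 336) tensor that came from a 100x200 (height x width)
+        image keeps its full width; 84 rows of padding are stripped from the
+        top and bottom:
+
+        >>> t = torch.zeros(3, 336, 336)
+        >>> unpad_image(t, original_size=(100, 200)).shape
+        torch.Size([3, 168, 336])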
+ """ + original_height, original_width = original_size + current_height, current_width = tensor.shape[1:] + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding : current_height - padding, :] + else: + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding : current_width - padding] + + return unpadded_tensor + + +# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext +class LlavaNextMultiModalProjector(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + + self.linear_1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True + ) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class FlashLlavaNextForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + config.vision_config.quantize = config.quantize + vision_config = config.vision_config + # Instead of selecting in hidden_states[-2]. + # Instead compute only the n -2 + 1 layers and don't pool + if config.vision_feature_layer < 0: + vision_config.num_hidden_layers += config.vision_feature_layer + 1 + else: + vision_config.num_hidden_layers = config.vision_feature_layer + 1 + self.vision_tower = load_vision_model( + prefix="vision_tower" if not prefix else f"{prefix}.vision_tower", + config=config.vision_config, + weights=weights, + ) + + self.multi_modal_projector = LlavaNextMultiModalProjector( + prefix="multi_modal_projector", config=config, weights=weights + ) + + self.image_newline = weights.get_tensor("image_newline") + + self.vocab_size = config.text_config.vocab_size + self.config = config + config.text_config.quantize = config.quantize + config.text_config.speculator = config.speculator + self.text_model = load_text_model( + prefix="language_model" if not prefix else f"{prefix}.language_model", + config=config.text_config, + weights=weights, + ) + self.pad_token_id = ( + config.pad_token_id if config.pad_token_id is not None else -1 + ) + + def _merge_input_ids_with_image_features( + self, + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + image_features: torch.Tensor, + ): + """In place merges in vision_embeddings with inputs_embeds.""" + mask = input_ids == self.config.image_token_index + # Let's pray we have enabled enough slots ! + try: + inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) + except Exception as e: + raise RuntimeError( + f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. 
If error happens at regular runtime, please fill in an issue: {e}" + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + pixel_values: torch.FloatTensor = None, + # Unused for this model + pixel_attention_mask=None, + image_sizes: Optional[torch.LongTensor] = None, + adapter_data: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + if pixel_values is not None and len(pixel_values) > 0: + # num_special_image_tokens = (input_ids == self.config.image_token_index).sum() + # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid" + # 1. Extract the input embeddings + + # 2. Merge text and images + num_images, num_patches, channels, height, width = pixel_values.shape + pixel_values = pixel_values.view( + num_images * num_patches, channels, height, width + ) + image_features = self.vision_tower(pixel_values) + + # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer] + # Already done within the clip model + selected_image_feature = image_features.last_hidden_state + + if self.config.vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif self.config.vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise RuntimeError( + f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid." + ) + + image_features = self.multi_modal_projector(selected_image_feature) + + # split up image_features for each of the individual images + # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) + # if we assume each image has 5 image features (base image + 4 patches) + split_sizes = [num_patches] * num_images + image_features = torch.split(image_features, split_sizes, dim=0) + + # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" + height = width = ( + self.config.vision_config.image_size + // self.config.vision_config.patch_size + ) + + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + + if height * width != base_image_feature.shape[0]: + raise ValueError( + "The number of patches is not consistent with the image size." 
+ ) + + # Dimensions are intentionally swapped to be bug-compatible with + # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59 + num_patch_width, num_patch_height = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + image_feature = image_feature.view( + num_patch_height, num_patch_width, height, width, -1 + ) + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + image_feature = torch.cat( + ( + image_feature, + self.image_newline[:, None, None].expand( + *image_feature.shape[:-1], 1 + ), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat( + (base_image_feature, image_feature), dim=0 + ) + else: + image_feature = image_feature[0] + image_feature = torch.cat( + (image_feature, self.image_newline[None]), dim=0 + ) + new_image_features.append(image_feature) + image_features = torch.stack(new_image_features, dim=0) + + inputs_embeds = self._merge_input_ids_with_image_features( + input_ids, inputs_embeds, image_features + ) + + hidden_states = self.text_model.model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + true_max_s=max_s, + prefill_cache_indices=None, + adapter_data=adapter_data, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.text_model.lm_head(hidden_states) + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py new file mode 100644 index 00000000..cf47208b --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py @@ -0,0 +1,996 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
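+# NOTE (Gaudi port): the vision tower below keeps the upstream
+# torch.nn.functional.scaled_dot_product_attention path, while the text-side
+# cross-attention (MllamaTextCrossAttention) is routed through Habana's
+# FusedSDPA kernel via ModuleFusedSDPA from vllm_hpu_extension.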
+"""PyTorch Mllama model.""" + +from typing import Optional, Tuple, List + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +import torch.nn.functional as F + +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + FastLinear, +) +from text_generation_server.layers.attention import ( + Seqlen, +) +from text_generation_server.models.custom_modeling.flash_llama_modeling import ( + FlashLlamaForCausalLM, +) +from habana_frameworks.torch.hpex.kernels import FusedSDPA +from vllm_hpu_extension.utils import ModuleFusedSDPA + + +def _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask: torch.Tensor, + num_patches: int, + target_length: int, + dtype: torch.dtype, +) -> torch.Tensor: + # Expand aspect ratio mask to target_length + batch_size, max_num_tiles = aspect_ratio_mask.shape + attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype) + attention_mask = attention_mask.repeat(1, 1, target_length, 1) + + # Mask padding patches + pad_patches = target_length - num_patches + attention_mask[:, :, -pad_patches:] = 0 + + # Invert the mask (0 -> 1, 1 -> 0) + attention_mask = 1 - attention_mask + + # Reshape to 2D and create 4D attention mask + # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length) + attention_mask = attention_mask.reshape( + batch_size, max_num_tiles * target_length, 1 + ) + attention_mask = ( + attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min + ) + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + +# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device, + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange( + target_length, device=device + ) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = ( + causal_mask.clone() + ) # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[ + :, :, :, :mask_length + ].masked_fill(padding_mask, min_dtype) + + return causal_mask + + +def _prepare_cross_attention_mask( + cross_attention_mask: torch.Tensor, + num_vision_tokens: int, + dtype: str, +) -> Tuple[torch.Tensor, torch.Tensor]: + # reshape so it can be used by attn module + batch_size, text_total_length, *_ = cross_attention_mask.shape + cross_attention_mask = cross_attention_mask.repeat_interleave( + num_vision_tokens, dim=3 + ) + cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1) + cross_attention_mask = cross_attention_mask.unsqueeze(1) + + # invert the mask + inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype) + cross_attention_mask = inverted_cross_attn_mask.masked_fill( + inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min + ) + + # apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's + # last dimension contains negative infinity values, otherwise it's 1 + negative_inf_value = torch.finfo(dtype).min + full_text_row_masked_out_mask = ( + (cross_attention_mask != negative_inf_value) + .any(dim=-1) + .type_as(cross_attention_mask)[..., None] + ) + cross_attention_mask *= full_text_row_masked_out_mask + + return cross_attention_mask, full_text_row_masked_out_mask + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision +class MllamaVisionMLP(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class MllamaVisionSdpaAttention(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + + self.embed_dim = config.hidden_size + self.head_dim = config.hidden_size // config.attention_heads + self.num_heads = config.attention_heads // weights.process_group.size() + + self.qkv_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: 
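+        # Fused QKV projection over the vision token sequence: split into
+        # per-head query/key/value, run non-causal scaled dot-product
+        # attention with the optional tile/padding mask, then merge the
+        # heads back through o_proj.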
+ qkv = self.qkv_proj(hidden_state) + query, key, value = qkv.split( + [ + self.head_dim * self.num_heads, + self.head_dim * self.num_heads, + self.head_dim * self.num_heads, + ], + dim=2, + ) + + batch_size, q_seq_len, _ = query.shape + _, kv_seq_len, _ = key.shape + + query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim) + key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim) + value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + attn_output = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_seq_len, -1) + + output = self.o_proj(attn_output) + return output + + +class MllamaVisionEncoderLayer(nn.Module): + def __init__(self, *, prefix, config, weights, is_gated: bool): + super().__init__() + + self.hidden_size = config.hidden_size + self.num_attention_heads = config.attention_heads + self.is_gated = is_gated + self.intermediate_size = config.intermediate_size + + self.self_attn = MllamaVisionSdpaAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.mlp = MllamaVisionMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights + ) + + self.input_layernorm = nn.LayerNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=1e-05 + ) + self.post_attention_layernorm = nn.LayerNorm.load( + prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=1e-05 + ) + + # there used to be an if else here, no code path + if is_gated: + self.gate_attn = nn.Parameter( + weights.get_tensor(f"{prefix}.gate_attn"), requires_grad=False + ) + self.gate_ffn = nn.Parameter( + weights.get_tensor(f"{prefix}.gate_ffn"), requires_grad=False + ) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + # Self Attention + residual = hidden_state + hidden_state = self.input_layernorm(hidden_state) + hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask) + gate_attn = 1 if not self.is_gated else self.gate_attn.tanh() + hidden_state = residual + gate_attn * hidden_state + + # Feed forward + residual = hidden_state + hidden_state = self.post_attention_layernorm(hidden_state) + hidden_state = self.mlp(hidden_state) + gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh() + hidden_state = residual + gate_ffn * hidden_state + return hidden_state + + +class MllamaVisionEncoder(nn.Module): + def __init__(self, *, prefix, config, weights, is_gated: bool, num_layers: int): + super().__init__() + self.config = config + self.layers = [ + MllamaVisionEncoderLayer( + prefix=f"{prefix}.layers.{i}", + config=config, + weights=weights, + is_gated=is_gated, + ) + for i in range(num_layers) + ] + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + encoder_states = [hidden_states] + for encoder_layer in self.layers: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + ) + + hidden_states = layer_outputs + encoder_states.append(hidden_states) + + return hidden_states, encoder_states + + +class MllamaPrecomputedAspectRatioEmbedding(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.max_aspect_ratio_id = 
config.max_aspect_ratio_id + + self.embedding = TensorParallelEmbedding( + prefix=f"{prefix}.embedding", weights=weights + ) + self.gate = nn.Parameter( + weights.get_tensor(f"{prefix}.gate"), requires_grad=False + ) + + def forward( + self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor + ) -> torch.Tensor: + embeddings = self.embedding(aspect_ratio_ids) + embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size) + + # Always gated. + embeddings = embeddings * self.gate.tanh() + + hidden_state = hidden_state + embeddings + return hidden_state + + +class MllamaPrecomputedPositionEmbedding(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.num_patches = (config.image_size // config.patch_size) ** 2 + 1 + self.hidden_size = config.hidden_size + self.scale = config.hidden_size**-0.5 + + self.gate = nn.Parameter( + weights.get_tensor(f"{prefix}.gate"), requires_grad=False + ) + + # position embedding + embedding = nn.Parameter( + weights.get_tensor(f"{prefix}.embedding"), requires_grad=False + ) + self.gated_position_embedding = (1 - self.gate.tanh()) * embedding + self.tile_embedding = TensorParallelEmbedding( + prefix=f"{prefix}.tile_embedding", weights=weights + ) + + def forward( + self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor + ) -> torch.Tensor: + # position embeddings + hidden_state = hidden_state + self.gated_position_embedding.view( + 1, 1, self.num_patches, self.hidden_size + ) + + # precomputed tile position embeddings + tile_position_embedding = self.tile_embedding(aspect_ratio_ids) + batch_size = hidden_state.shape[0] + tile_position_embedding = tile_position_embedding.reshape( + batch_size, self.max_num_tiles, self.num_patches, self.hidden_size + ) + gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding + hidden_state = hidden_state + gated_tile_position_embedding + + return hidden_state + + +class MllamaVisionModel(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.image_size = config.image_size + self.patch_size = config.patch_size + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.num_channels = config.num_channels + self.intermediate_layers_indices = config.intermediate_layers_indices + + self.num_patches = (self.image_size // self.patch_size) ** 2 + 1 + self.scale = config.hidden_size**-0.5 + self.dtype = weights.dtype + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.hidden_size, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + bias=False, + ) + self.patch_embedding.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False + ) + + self.class_embedding = nn.Parameter( + weights.get_tensor(f"{prefix}.class_embedding"), requires_grad=False + ) + + self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding( + prefix=f"{prefix}.gated_positional_embedding", + config=config, + weights=weights, + ) + + self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( + prefix=f"{prefix}.pre_tile_positional_embedding", + config=config, + weights=weights, + ) + self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( + prefix=f"{prefix}.post_tile_positional_embedding", + config=config, + weights=weights, + ) + + ## layer norms + self.layernorm_pre 
= nn.LayerNorm.load( + prefix=f"{prefix}.layernorm_pre", + weights=weights, + # torch default + eps=1e-05, + ) + self.layernorm_post = nn.LayerNorm.load( + prefix=f"{prefix}.layernorm_post", + weights=weights, + # torch default + eps=1e-05, + ) + + ## encoders + self.transformer = MllamaVisionEncoder( + prefix=f"{prefix}.transformer", + config=config, + weights=weights, + is_gated=False, + num_layers=config.num_hidden_layers, + ) + self.global_transformer = MllamaVisionEncoder( + prefix=f"{prefix}.global_transformer", + config=config, + weights=weights, + is_gated=True, + num_layers=config.num_global_layers, + ) + + def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: + batch_size, _, hidden_size = hidden_state.shape + class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size) + hidden_state = torch.cat([class_embedding, hidden_state], dim=1) + return hidden_state + + def forward( + self, + pixel_values: torch.Tensor, + aspect_ratio_ids: torch.Tensor, + attention_mask: torch.Tensor, + ) -> torch.Tensor: + ( + batch_size, + num_concurrent_media, + num_tiles, + num_channels, + height, + width, + ) = pixel_values.shape + + pixel_values = pixel_values.reshape( + batch_size * num_concurrent_media * num_tiles, num_channels, height, width + ) + aspect_ratio_ids = aspect_ratio_ids.reshape( + batch_size * num_concurrent_media, -1 + ) + + # patch embedding + patch_embeds = self.patch_embedding(pixel_values) + hidden_state = patch_embeds.flatten(2).transpose(1, 2) + + # tile embeddings + _, num_patches, dim = hidden_state.shape + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, -1, dim + ) + hidden_state = self.pre_tile_positional_embedding( + hidden_state, aspect_ratio_ids + ) + + # apply cls token + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media * num_tiles, num_patches, dim + ) + hidden_state = self.apply_class_embedding(hidden_state) + num_patches += 1 + + # apply position embeddings + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches, dim + ) + hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids) + + # apply encoder + hidden_state = self.layernorm_pre(hidden_state) + + # Compute the number of tokens to pad + num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 + # Compute padding tuple for pad function + padding = ( + 0, + 0, + 0, + num_padding_patches, + ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) + # Pad the tensor + hidden_state = F.pad(hidden_state, padding, mode="constant", value=0) + slice_index = -num_padding_patches if num_padding_patches > 0 else None + + if attention_mask is not None: + attention_mask = attention_mask.reshape( + batch_size * num_concurrent_media, -1 + ) + attention_mask = _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask=attention_mask, + num_patches=self.num_patches, + target_length=hidden_state.shape[2], + dtype=self.dtype, + ) + + hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim) + hidden_state, all_intermediate_hidden_states = self.transformer( + hidden_state, + attention_mask=attention_mask, + ) + intermediate_hidden_states = [ + hidden_state + for idx, hidden_state in enumerate(all_intermediate_hidden_states) + if idx in self.intermediate_layers_indices + ] + intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1) + + # apply global encoder + hidden_state = self.layernorm_post(hidden_state) + hidden_state 
= hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim, + ) + hidden_state = self.post_tile_positional_embedding( + hidden_state, aspect_ratio_ids + ) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles * (num_patches + num_padding_patches), + dim, + ) + hidden_state, _ = self.global_transformer( + hidden_state, attention_mask=attention_mask + ) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim, + ) + hidden_state = hidden_state[:, :, :slice_index] + + # adding intermediate layer outputs + hidden_state = hidden_state.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, dim + ) + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + -1, + ) + intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index] + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, -1 + ) + hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1) + return hidden_state + + +class MllamaTextCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, *, prefix, config, weights, layer_idx): + super().__init__() + self.config = config + self.num_heads = self.config.num_attention_heads + self.num_key_value_heads = self.config.num_key_value_heads + self.dropout = config.dropout + self.hidden_size = config.hidden_size + self.head_size = config.hidden_size // self.num_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.layer_idx = layer_idx + + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + self.num_key_value_heads // weights.process_group.size() + ) + + self.q_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.q_proj", + weights=weights, + bias=False, + ) + self.k_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.k_proj", + weights=weights, + bias=False, + ) + self.v_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.v_proj", + weights=weights, + bias=False, + ) + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + + self.q_norm = MllamaTextRMSNorm.load( + prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps + ) + self.k_norm = MllamaTextRMSNorm.load( + prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps + ) + self.softmax_scale = self.head_size**-0.5 + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + # past_key_value=None, + # attention_mask: Optional[torch.Tensor] = None, + # cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + # hidden_states = hidden_states.unsqueeze(0) + # bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + query_states = query_states.view(-1, self.num_heads, self.head_size) + query_states = self.q_norm(query_states) + + ( + cross_attention_states, + cu_seqlen_q, + cu_seqlen_k, + max_q, + max_k, + indices, + ) = cross_attention_states + + key_states = self.k_proj(cross_attention_states) + 
value_states = self.v_proj(cross_attention_states) + key_states = key_states.view(-1, self.num_key_value_heads, self.head_size) + value_states = value_states.view(-1, self.num_key_value_heads, self.head_size) + key_states = self.k_norm(key_states) + + # key_states = key_states.repeat(1, self.num_key_value_groups, 1) + # value_states = value_states.repeat(1, self.num_key_value_groups, 1) + + causal = False + # logger.info( + # f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}" + # ) + # execute sdpa + query_states = query_states.unsqueeze(0).transpose(1, 2) + key_states = key_states.unsqueeze(0).transpose(1, 2) + value_states = value_states.unsqueeze(0).transpose(1, 2) + fsdpa_op = ModuleFusedSDPA(FusedSDPA) + attn_output = fsdpa_op( + query_states, + key_states, + value_states, + attn_mask=None, + dropout_p=0.0, + is_causal=causal, + scale=None, + softmax_mode="None", + recompute_mode=None, + valid_sequence_lengths=None, + ) + attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous() + attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + return attn_output + + +# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText +class MllamaTextMLP(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + shape = x.shape + gate_up_states = self.gate_up_proj(x) + gate_up_states = gate_up_states.view(*shape[:-1], 2, self.intermediate_size) + result = self.down_proj( + self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1] + ) + return result + + +class FlashLlamaCrossLayer(torch.nn.Module): + """Cross-attention transformer block with tanh-gated attention and feedforward.""" + + def __init__(self, *, prefix, config, weights, index) -> None: + layer_idx = index + super().__init__() + self.cross_attn = MllamaTextCrossAttention( + prefix=f"{prefix}.cross_attn", + config=config, + weights=weights, + layer_idx=layer_idx, + ) + + self.input_layernorm = MllamaTextRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.cross_attn_attn_gate = torch.nn.Parameter( + weights.get_tensor(f"{prefix}.cross_attn_attn_gate"), requires_grad=False + ) + + self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + self.post_attention_layernorm = MllamaTextRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.cross_attn_mlp_gate = torch.nn.Parameter( + weights.get_tensor(f"{prefix}.cross_attn_mlp_gate"), requires_grad=False + ) + self.layer_idx = layer_idx + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + adapter_data, + cross_attention_states, # [ IB, ...] 
+ ) -> Tuple[torch.Tensor, torch.Tensor]: + if cross_attention_states is None: + return hidden_states, residual + if residual is not None: + hidden_states += residual + + indices = cross_attention_states[-1] + out_hidden_states = hidden_states[:] + if len(indices) > 0: + assert max(indices) < hidden_states.shape[0] + hidden_states = hidden_states[indices] + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.cross_attn( + hidden_states=hidden_states, + # attention_mask=cross_attention_mask, + cross_attention_states=cross_attention_states, + ) + hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states + + out_hidden_states[indices] = hidden_states + hidden_states = out_hidden_states + + return hidden_states, None + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText +class MllamaTextRMSNorm(nn.Module): + def __init__(self, weight, eps): + super().__init__() + self.weight = weight + self.variance_epsilon = eps + + @classmethod + def load(cls, *, prefix, weights, eps): + weight = nn.Parameter( + weights.get_tensor(f"{prefix}.weight"), requires_grad=False + ) + return cls(weight=weight, eps=eps) + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class FlashMllamaForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + config.vision_config.quantize = None + config.vision_config.speculator = config.speculator + config.text_config.quantize = config.quantize + config.text_config.speculator = config.speculator + config.text_config._attn_implementation = "sdpa" + self.hidden_size = config.text_config.hidden_size + self.vision_model = MllamaVisionModel( + prefix="vision_model", config=config.vision_config, weights=weights + ) + self.multi_modal_projector = FastLinear.load( + prefix="multi_modal_projector", config=config, weights=weights, bias=True + ) + self.text_model = FlashLlamaForCausalLM( + prefix="language_model", config=config.text_config, weights=weights + ) + self.config = config + self.dtype = weights.dtype + self.device = weights.device + + def vision_forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask): + if aspect_ratio_ids is None: + raise ValueError( + "`aspect_ratio_ids` must be provided if `pixel_values` is provided" + ) + # logger.info(f"PIxel values {pixel_values.shape}") + batch_size = pixel_values.shape[0] + vision_states = self.vision_model( + pixel_values, aspect_ratio_ids, aspect_ratio_mask + ) + cross_attention_states = self.multi_modal_projector(vision_states).reshape( + -1, vision_states.shape[-2], self.hidden_size + ) + _, _, h = cross_attention_states.shape + cross_attention_states = cross_attention_states.view(batch_size, -1, h) + # logger.info(f"cross {cross_attention_states.shape}") + return cross_attention_states + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: 
Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        seqlen: Seqlen,
+        max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
+        lm_head_indices: Optional[torch.Tensor],
+        adapter_data: Optional[torch.Tensor] = None,
+        # XXX: Putting these as optional so that the cuda warmup calls can go through.
+        cross_attention_states: Optional[torch.Tensor] = None,
+        image_indices=None,
+    ):
+        if cross_attention_states is not None:
+            seqlen_q = len(image_indices)
+            n_images = cross_attention_states.shape[0]
+            seqlen_k = cross_attention_states.shape[1]
+            device = cross_attention_states.device
+            if cu_seqlen_prefill is not None:
+                offset = 0
+                cu_q = []
+                indices = []
+                for index in image_indices:
+                    cu_q.append(offset)
+                    length = seqlen.input_lengths[index].item()
+                    assert index < seqlen.cu_seqlen_q.shape[0]
+                    input_ids_offset = seqlen.cu_seqlen_q[index]
+                    indices.extend(range(input_ids_offset, input_ids_offset + length))
+                    offset += length
+                cu_q.append(offset)
+                cu_seqlen_q = torch.Tensor(cu_q).to(device=device, dtype=torch.int32)
+
+                assert max(indices) < input_ids.shape[0]
+
+                cu_seqlen_k = (
+                    torch.arange(
+                        n_images + 1,
+                        device=device,
+                        dtype=torch.int32,
+                    )
+                    * seqlen_k
+                )
+                max_q = cu_seqlen_q[-1].item()
+                max_k = seqlen_k
+            else:
+                cu_seqlen_q = torch.arange(
+                    seqlen_q + 1, device=device, dtype=torch.int32
+                )
+                seqlen_k = cross_attention_states.shape[1]
+                n_images = cross_attention_states.shape[0]
+                cu_seqlen_k = (
+                    torch.arange(
+                        n_images + 1,
+                        device=device,
+                        dtype=torch.int32,
+                    )
+                    * seqlen_k
+                )
+                max_q = seqlen_q
+                max_k = seqlen_k
+                indices = image_indices[:]
+
+            cross_attention_states = (
+                cross_attention_states,
+                cu_seqlen_q,
+                cu_seqlen_k,
+                max_q,
+                max_k,
+                indices,
+            )
+
+        outputs = self.text_model(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=cu_seqlen_prefill,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            seqlen=seqlen,
+            max_s=max_s,
+            prefill_cache_indices=prefill_cache_indices,
+            lm_head_indices=lm_head_indices,
+            adapter_data=adapter_data,
+            cross_attention_states=cross_attention_states,
+        )
+
+        return outputs
diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
new file mode 100644
index 00000000..5d4d68fd
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
@@ -0,0 +1,480 @@
+import torch
+from PIL import Image
+from io import BytesIO
+
+from opentelemetry import trace
+from typing import Iterable, Optional, Tuple, List, Type, Dict
+
+from transformers import PreTrainedTokenizerBase
+from transformers.image_processing_utils import select_best_resolution
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.flash_causal_lm import (
+    FlashCausalLMBatch,
+    FlashCausalLM,
+)
+from text_generation_server.models.globals import PREFIX_CACHING
+from loguru import logger
+from text_generation_server.utils.log import log_master
+from transformers import AutoProcessor
+from text_generation_server.layers.attention import Seqlen
+
+tracer = trace.get_tracer(__name__)
+
+IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"
+IDEFICS2_IMAGE_TOKEN = "<image>"
+
+IDEFICS3_IMAGE_TOKEN = "<image>"
+IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
+IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"
+
+
+# copied from: https://github.com/huggingface/transformers/blob/02ed609285c2448b3b54c31e362f2c389fa952ab/src/transformers/models/idefics3/processing_idefics3.py#L44-L60
+def _prompt_split_image(
+    *,
+    image_seq_len: int,
+    image_rows: int,
+    image_cols: int,
+    fake_token_around_image: str,
+    image_token: str,
+    global_img_token: str,
+):
+    """Prompt with expanded image tokens for when the image is split into patches."""
+    text_split_images = ""
+    for n_h in range(image_rows):
+        for n_w in range(image_cols):
+            text_split_images += (
+                f"{fake_token_around_image}"
+                + f"<row_{n_h + 1}_col_{n_w + 1}>"
+                + f"{image_token}" * image_seq_len
+            )
+        text_split_images += "\n"
+
+    text_split_images += (
+        f"\n{fake_token_around_image}"
+        + f"{global_img_token}"
+        + f"{image_token}" * image_seq_len
+        + f"{fake_token_around_image}"
+    )
+    return text_split_images
+
+
+def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
+    """
+    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
+
+    Args:
+        image_size (`tuple`):
+            The size of the input image in the format (height, width).
+        grid_pinpoints (`List`):
+            A list containing possible resolutions. Each item in the list should be a tuple or list
+            of the form `(height, width)`.
+        patch_size (`int`):
+            The size of each image patch.
+
+    Returns:
+        tuple: The shape of the image patch grid in the format (width, height).
+    """
+    if not isinstance(grid_pinpoints, list):
+        raise ValueError("grid_pinpoints should be a list of tuples or lists")
+
+    height, width = select_best_resolution(image_size, grid_pinpoints)
+    return height // patch_size, width // patch_size
+
+
+def image_text_replacement(processor, image_input, config, image_id: int) -> str:
+    if config.model_type == "idefics2":
+        image_seq_len = 64
+        image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
+        if processor.image_processor.do_image_splitting:
+            image_str *= 5
+        return image_str
+    if config.model_type == "idefics3":
+        # TODO: implement this in a more general way
+        n_rows = image_input["rows"][0][image_id]
+        n_cols = image_input["cols"][0][image_id]
+        image_seq_len = int(
+            ((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
+            / (config.scale_factor**2)
+        )
+        image_str = _prompt_split_image(
+            image_seq_len=image_seq_len,
+            image_rows=n_rows,
+            image_cols=n_cols,
+            fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN,
+            image_token=IDEFICS3_IMAGE_TOKEN,
+            global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
+        )
+        return image_str
+    elif config.model_type == "llava_next":
+        height, width = image_input["image_sizes"][image_id]
+        num_features = get_number_of_features(height, width, config)
+
+        log_master(
+            logger.info,
+            f"Found {num_features} features in image of resolution {height}x{width}",
+        )
+        return "<image>" * num_features
+
+    elif config.model_type == "paligemma":
+        return "<image>" * config.text_config.num_image_tokens
+    elif config.model_type == "qwen2_vl":
+        grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
+        num_pads = grid_t * grid_h * grid_w // 4
+        padding = "<|image_pad|>" * num_pads
+        return f"<|vision_start|>{padding}<|vision_end|>"
+    elif config.model_type == "qwen2_5_vl":
+        grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
+        num_pads = grid_t * grid_h * grid_w // 4
+        padding = "<|image_pad|>" * num_pads
+        return f"<|vision_start|>{padding}<|vision_end|>"
+    elif config.model_type == "gemma3":
+        # TODO: get correct number of features via reviewing the Gemma3 architecture
+        # and calculating the number of image tokens
+        num_pads = 256
+        padding = "<image_soft_token>" * num_pads
+        return f"\n\n{padding}\n\n"
+    else:
+        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
+
+
+def image_text_replacement_fixup(config, text: str) -> str:
+    if config.model_type == "idefics2":
+        return text.replace(
+            f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
+        )
+    return text
+
+
+def get_unpadded_features(
+    original_height: int,
+    original_width: int,
+    npatches: int,
+    num_patch_height: int,
+    num_patch_width: int,
+) -> Tuple[int, int]:
+    current_height = npatches * num_patch_height
+    current_width = npatches * num_patch_width
+
+    aspect_ratio: float = original_width / original_height
+    current_aspect_ratio: float = current_width / current_height
+
+    if aspect_ratio > current_aspect_ratio:
+        new_height = (original_height * current_width) // original_width
+        padding = (current_height - new_height) // 2
+        current_height = current_height - (2 * padding)
+    else:
+        new_width = (original_width * current_height) // original_height
+        padding = (current_width - new_width) // 2
+        current_width = current_width - (2 * padding)
+
+    unpadded_features = current_height * current_width
+    newline_features = current_height
+    return (unpadded_features, newline_features)
+
+
+def get_number_of_features(height: int, width: int, config) -> int:
+    # From config
+    # Hardcoded for CLIP for now
+    # image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
+    image_grid_pinpoints = config.image_grid_pinpoints
+    image_size = config.vision_config.image_size
+    patch_size = config.vision_config.patch_size
+
+    assert image_size % patch_size == 0
+
+    npatches = image_size // patch_size
+
+    # Dimensions are intentionally swapped to be bug-compatible with
+    # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
+    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+        [height, width],
+        image_grid_pinpoints,
+        image_size,
+    )
+    unpadded_features, newline_features = get_unpadded_features(
+        height, width, npatches, num_patch_height, num_patch_width
+    )
+    # The base patch covers the entire image
+    base_features = npatches**2
+    return unpadded_features + newline_features + base_features
+
+
+class FlashVlmCausalLMBatch(FlashCausalLMBatch):
+    pixel_values: Optional[List[torch.Tensor]]
+    pixel_attention_mask: Optional[List[torch.Tensor]]
+    image_sizes: Optional[List[Tuple[int, int]]]
+    image_grid_thw: Optional[torch.Tensor]
+
+    @classmethod
+    @tracer.start_as_current_span("concatenate")
+    def concatenate(cls, batches):
+        batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches)
+        batch.pixel_values = None
+        batch.pixel_attention_mask = None
+        batch.image_sizes = None
+        batch.image_grid_thw = None
+        return batch
+
+    @tracer.start_as_current_span("filter")
+    def filter(self, request_ids: List[int]):
+        batch = super().filter(request_ids)
+        batch.pixel_values = None
+        batch.pixel_attention_mask = None
+        batch.image_sizes = None
+        batch.image_grid_thw = None
+        return batch
+
+    @classmethod
+    def batch_tokenized_inputs(
+        cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
+    ):
+        # Process images first. We need all of them so that the processor
+        # can make the image splits the same size. And we need the final
+        # sizes to insert correct number of image tokens.
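For illustration only (not part of the patch): a minimal standalone sketch of the LLaVA-NeXT token count that get_number_of_features above would produce, re-deriving the arithmetic inline for one assumed configuration (a 336-pixel CLIP tower with 14-pixel patches and a best-fit 2x2 tile grid for a 640x480 image; all of these values are hypothetical).

# Hypothetical LLaVA-NeXT-style setup, chosen only to make the arithmetic concrete.
npatches = 336 // 14                         # 24 patches per side of one tile
num_patch_height, num_patch_width = 2, 2     # assumed best-fit tile grid
original_height, original_width = 640, 480

current_height = npatches * num_patch_height  # 48
current_width = npatches * num_patch_width    # 48

# Same unpadding rule as get_unpadded_features above.
if original_width / original_height > current_width / current_height:
    new_height = (original_height * current_width) // original_width
    pad = (current_height - new_height) // 2
    current_height -= 2 * pad
else:
    new_width = (original_width * current_height) // original_height
    pad = (current_width - new_width) // 2
    current_width -= 2 * pad                  # 48 -> 36

unpadded_features = current_height * current_width  # 1728
newline_features = current_height                   # 48
base_features = npatches**2                          # 576
assert unpadded_features + newline_features + base_features == 2352

Under these assumptions the prompt is expanded with 2352 image placeholder tokens for that single image.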
+ images = [] + for r in requests: + for chunk in r.input_chunks.chunks: + chunk_type = chunk.WhichOneof("chunk") + if chunk_type == "text": + pass + elif chunk_type == "image": + image = Image.open(BytesIO(chunk.image.data)) + # qwen2_vl expects images to be greater than 20 pixels, this is for warmup since the + # default warmup image is 20x20 + if config.model_type in {"qwen2_vl", "qwen2_5_vl"}: + if image.width <= 20: + w = image.width * 2 + h = image.height * 2 + image = image.resize((w, h)) + + if config.model_type == "llava_next": + images.append(image) + elif config.model_type == "gemma3": + images.append(image) + else: + images.append([image]) + else: + raise RuntimeError(f"Invalid chunk type {chunk_type}") + + if images: + kwargs = {} + if ( + hasattr(processor, "image_processor_class") + and processor.image_processor_class == "Idefics3ImageProcessor" + ): + kwargs["return_row_col_info"] = True + + image_inputs = processor.image_processor( + images, return_tensors="pt", **kwargs + ) + else: + image_inputs = None + + batch_tokenized_inputs = [] + max_length = 0 + image_id = 0 + for r in requests: + full_text = "" + for chunk in r.input_chunks.chunks: + chunk_type = chunk.WhichOneof("chunk") + if chunk_type == "text": + full_text += chunk.text + elif chunk_type == "image": + full_text += image_text_replacement( + processor, image_inputs, config, image_id + ) + image_id += 1 + + full_text = image_text_replacement_fixup(config, full_text) + input_ids = tokenizer( + full_text, + truncation=True, + max_length=r.truncate, + add_special_tokens=r.add_special_tokens, + )["input_ids"] + max_length = max(max_length, len(input_ids)) + batch_tokenized_inputs.append(input_ids) + + return batch_tokenized_inputs, image_inputs + + @classmethod + def from_pb_processor( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + processor, + config, + dtype: torch.dtype, + device: torch.device, + ) -> "FlashVlmCausalLMBatch": + batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs( + pb.requests, tokenizer, processor, config + ) + batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) + if image_inputs is not None: + batch.pixel_values = image_inputs["pixel_values"].to(device=device) + if "pixel_attention_mask" in image_inputs: + batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to( + device=device + ) + else: + batch.pixel_attention_mask = None + if "image_sizes" in image_inputs: + batch.image_sizes = image_inputs["image_sizes"].to(device=device) + else: + batch.image_sizes = None + if "image_grid_thw" in image_inputs: + batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device) + else: + batch.image_grid_thw = None + else: + batch.pixel_values = None + batch.pixel_attention_mask = None + batch.image_sizes = None + batch.image_grid_thw = None + return batch + + +class FlashVlmCausalLM(FlashCausalLM): + def __init__( + self, + model_id: str, + *, + processor_class=AutoProcessor, + processor_kwargs=None, + batch_class=FlashVlmCausalLMBatch, + revision, + trust_remote_code: bool, + **kwargs, + ): + if PREFIX_CACHING: + raise NotImplementedError("Vlm do not work with prefix caching yet") + if processor_kwargs is None: + processor_kwargs = {} + self.processor = processor_class.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + **processor_kwargs, + ) + self.batch_class = batch_class + super().__init__( + model_id=model_id, + revision=revision, + trust_remote_code=trust_remote_code, + # 
FIXME: VLM do not work with context chunking yet + support_chunking=False, + **kwargs, + ) + + @property + def batch_type(self) -> Type[FlashVlmCausalLMBatch]: + return self.batch_class + + def max_past(self) -> Optional[int]: + return getattr(self.model.text_model, "max_past", None) + + def forward( + self, + batch: FlashVlmCausalLMBatch, + adapter_data: Optional[Dict[str, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + # Model Forward + if batch.speculative_ids is not None: + input_ids = batch.input_ids + position_ids = batch.position_ids + cu_seqlen_prefill = batch.cu_seqlen_prefill + kv_cache = self.kv_cache + block_tables = batch.block_tables_tensor + slots = batch.slots[batch.slot_indices] + input_lengths = batch.input_lengths_tensor + max_s = batch.max_current_length + lm_head_indices = batch.prefill_head_indices + + speculative_ids = batch.speculative_ids + + B, speculative_length = speculative_ids.shape + new_length = speculative_length + 1 + new_input_ids = torch.cat( + [input_ids.unsqueeze(-1), speculative_ids], dim=1 + ).reshape(-1) + arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0) + arange_int = arange.to(dtype=torch.int32) + new_position_ids = ( + position_ids.unsqueeze(-1).expand(B, new_length) + arange + ).view(-1) + slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) + input_lengths = ( + input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int + ).view(-1) + cache_lengths_tensor = ( + batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length) + ).reshape(-1) + + # Add Copy the block tables for all members + block_tables = ( + block_tables.unsqueeze(1) + .expand(B, new_length, -1) + .reshape(B * new_length, -1) + .contiguous() + ) + max_s = max_s + speculative_length + + input_ids = new_input_ids + position_ids = new_position_ids + else: + input_ids = batch.input_ids + position_ids = batch.position_ids + cu_seqlen_prefill = batch.cu_seqlen_prefill + kv_cache = self.kv_cache + block_tables = batch.block_tables_tensor + slots = batch.slots[batch.slot_indices] + input_lengths = batch.input_lengths_tensor + cache_lengths_tensor = batch.cache_lengths_tensor + max_s = batch.max_current_length + lm_head_indices = batch.prefill_head_indices + + if self.model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}: + if position_ids.dim() == 1 and batch.prefilling: + position_ids = self.model.get_position_ids( + input_ids, batch.image_grid_thw + ) + batch.position_ids = position_ids + + if cu_seqlen_prefill is None and self.max_past() is not None: + # In decode, not prefill, we're actually overwriting the KV-cache + # in a circular buffer mode. + # This makes sure the max_s for the decode pass is correct. 
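For illustration only (not part of the patch): a toy sketch of the speculative-decoding expansion performed in the forward above, showing how each sequence's current token is concatenated with its speculative ids and how the position ids are extended. The tensor values are made up; the server's real shapes come from the batch.

import torch

# Two sequences (B=2), two speculative tokens each, so every sequence
# contributes new_length = 3 positions to the flattened batch.
input_ids = torch.tensor([11, 22])
speculative_ids = torch.tensor([[101, 102], [201, 202]])
position_ids = torch.tensor([7, 3])

B, speculative_length = speculative_ids.shape
new_length = speculative_length + 1

new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
arange = torch.arange(new_length).unsqueeze(0)
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)

assert new_input_ids.tolist() == [11, 101, 102, 22, 201, 202]
assert new_position_ids.tolist() == [7, 8, 9, 3, 4, 5]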
+ max_s = min(self.max_past(), max_s) + + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + prefill_cache_indices=batch.prefill_cache_indices, + lm_head_indices=lm_head_indices, + pixel_values=batch.pixel_values, + pixel_attention_mask=batch.pixel_attention_mask, + image_sizes=batch.image_sizes, + image_grid_thw=batch.image_grid_thw, + ) + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + if batch.pixel_values is not None: + batch.pixel_values = None + if batch.pixel_attention_mask is not None: + batch.pixel_attention_mask = None + if batch.image_sizes is not None: + batch.image_sizes = None + if batch.image_grid_thw is not None: + batch.image_grid_thw = None + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index 507dabee..f149d462 100644 --- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -1,15 +1,21 @@ -from io import BytesIO -from PIL import Image import torch + +import numpy as np + from typing import Iterable, Optional, Tuple, List, Dict from text_generation_server.pb.generate_pb2 import Request - +from io import BytesIO +from PIL import Image from dataclasses import dataclass from opentelemetry import trace from transformers import ( PreTrainedTokenizerBase, ) -from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM + +from text_generation_server.models.flash_vlm_causal_lm import ( + FlashVlmCausalLMBatch, + FlashVlmCausalLM, +) from text_generation_server.pb import generate_pb2 from text_generation_server.layers.attention import Seqlen @@ -18,7 +24,7 @@ tracer = trace.get_tracer(__name__) @dataclass -class MllamaCausalLMBatch(VlmCausalLMBatch): +class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch): image_indices: List[int] = 42 aspect_ratio_ids: Optional[torch.Tensor] = None aspect_ratio_mask: Optional[torch.Tensor] = None @@ -154,7 +160,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch): config, dtype: torch.dtype, device: torch.device, - ) -> "VlmCausalLMBatch": + ) -> "FlashVlmCausalLMBatch": batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs( pb.requests, tokenizer, processor, config ) @@ -163,6 +169,13 @@ class MllamaCausalLMBatch(VlmCausalLMBatch): batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp( max=config.text_config.vocab_size - 1 ) + if isinstance(batch.input_ids, list): + if len(batch) > 1: + input_ids = np.concatenate(batch.input_ids, dtype=np.int64) + else: + input_ids = batch.input_ids[0] + batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1) if image_inputs is not None: @@ -183,10 +196,10 @@ class MllamaCausalLMBatch(VlmCausalLMBatch): return batch -class MllamaCausalLM(VlmCausalLM): +class FlashMllamaCausalLM(FlashVlmCausalLM): def forward( self, - batch: VlmCausalLMBatch, + batch: FlashMllamaCausalLMBatch, adapter_data: Optional[Dict[str, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # Model Forward @@ 
-198,7 +211,7 @@ class MllamaCausalLM(VlmCausalLM): block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor - max_s = batch.max_seqlen + max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices speculative_ids = batch.speculative_ids @@ -217,8 +230,8 @@ class MllamaCausalLM(VlmCausalLM): input_lengths = ( input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int ).view(-1) - prefix_lens_tensor = ( - batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) + cache_lengths_tensor = ( + batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length) ).reshape(-1) # Add Copy the block tables for all members @@ -240,8 +253,8 @@ class MllamaCausalLM(VlmCausalLM): block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor - prefix_lens_tensor = batch.prefix_lens_tensor - max_s = batch.max_seqlen + cache_lengths_tensor = batch.cache_lengths_tensor + max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices if cu_seqlen_prefill is None and self.max_past() is not None: @@ -250,14 +263,10 @@ class MllamaCausalLM(VlmCausalLM): # This makes sure the max_s for the decode pass is correct. max_s = min(self.max_past(), max_s) - input_lengths = input_lengths + prefix_lens_tensor - max_k = (input_lengths + prefix_lens_tensor).max().item() seqlen = Seqlen( input_lengths=input_lengths, - prefix_lengths=prefix_lens_tensor, + cache_lengths=cache_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, - max_q=max_s, - max_k=max_k, ) if batch.pixel_values is not None: diff --git a/backends/gaudi/server/text_generation_server/models/pali_gemma.py b/backends/gaudi/server/text_generation_server/models/pali_gemma.py index fe75570e..e91aaed9 100644 --- a/backends/gaudi/server/text_generation_server/models/pali_gemma.py +++ b/backends/gaudi/server/text_generation_server/models/pali_gemma.py @@ -4,8 +4,8 @@ import torch import torch.distributed from opentelemetry import trace from typing import Iterable -from text_generation_server.models.vlm_causal_lm import ( - VlmCausalLMBatch, +from text_generation_server.models.flash_vlm_causal_lm import ( + FlashVlmCausalLMBatch, image_text_replacement, ) @@ -14,7 +14,7 @@ from text_generation_server.pb.generate_pb2 import Request tracer = trace.get_tracer(__name__) -class PaliGemmaBatch(VlmCausalLMBatch): +class PaliGemmaBatch(FlashVlmCausalLMBatch): @classmethod def batch_tokenized_inputs( cls, requests: Iterable[Request], tokenizer, processor, config diff --git a/backends/gaudi/server/text_generation_server/server.py b/backends/gaudi/server/text_generation_server/server.py index 674a8aed..7a8a51d6 100644 --- a/backends/gaudi/server/text_generation_server/server.py +++ b/backends/gaudi/server/text_generation_server/server.py @@ -25,15 +25,21 @@ from text_generation_server.utils.tokens import make_tokenizer_optional try: from text_generation_server.models.pali_gemma import PaliGemmaBatch + from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch from text_generation_server.models.vlm_causal_lm import ( VlmCausalLMBatch, ) + from text_generation_server.models.flash_vlm_causal_lm import ( + FlashVlmCausalLMBatch, + ) from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch VLM_BATCH_TYPES = { PaliGemmaBatch, VlmCausalLMBatch, + FlashVlmCausalLMBatch, IdeficsCausalLMBatch, + FlashMllamaCausalLMBatch, } except (ImportError, NotImplementedError): # These imports can 
fail on CPU/Non flash. From 36b6612f9799a65bb6cca9042a5f9a247796d15b Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Thu, 20 Mar 2025 01:09:58 -0700 Subject: [PATCH 11/35] adjust warmup and enable vlm Signed-off-by: Wang, Yi A --- .../custom_modeling/flash_llava_next.py | 7 ++- .../models/custom_modeling/flash_mllama.py | 36 ++++++------- .../flash_pali_gemma_modeling.py | 7 ++- .../models/custom_modeling/idefics2.py | 7 ++- .../models/custom_modeling/idefics3.py | 7 ++- .../models/custom_modeling/qwen2_5_vl.py | 6 +-- .../models/custom_modeling/qwen2_vl.py | 6 +-- .../models/flash_causal_lm.py | 52 ++---------------- .../models/flash_vlm_causal_lm.py | 11 +++- .../models/mllama_causal_lm.py | 16 ++++-- .../models/vlm_causal_lm.py | 11 ++-- .../server/text_generation_server/server.py | 53 ++++++++++++++++--- .../utils/prefill_chunking.py | 24 +++++++++ 13 files changed, 134 insertions(+), 109 deletions(-) create mode 100644 backends/gaudi/server/text_generation_server/utils/prefill_chunking.py diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py index 0e3487dc..3bdfdd83 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py @@ -23,7 +23,7 @@ from torch import nn from transformers.activations import ACT2FN from transformers.image_processing_utils import select_best_resolution -from text_generation_server.layers.attention import Seqlen +from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata from text_generation_server.models.custom_modeling.vlm import ( load_text_model, load_vision_model, @@ -172,7 +172,7 @@ class FlashLlavaNextForConditionalGeneration(nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, pixel_values: torch.FloatTensor = None, @@ -279,8 +279,7 @@ class FlashLlavaNextForConditionalGeneration(nn.Module): block_tables=block_tables, slots=slots, seqlen=seqlen, - max_s=max_s, - true_max_s=max_s, + hpu_attention_meta=hpu_attention_meta, prefill_cache_indices=None, adapter_data=adapter_data, ) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py index cf47208b..b26adad7 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py @@ -31,6 +31,7 @@ from text_generation_server.layers import ( ) from text_generation_server.layers.attention import ( Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.models.custom_modeling.flash_llama_modeling import ( FlashLlamaForCausalLM, @@ -678,23 +679,23 @@ class MllamaTextCrossAttention(nn.Module): """Input shape: Batch x Time x Channel""" # hidden_states = hidden_states.unsqueeze(0) # bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states) - query_states = query_states.view(-1, self.num_heads, self.head_size) - query_states = self.q_norm(query_states) - ( cross_attention_states, cu_seqlen_q, cu_seqlen_k, - max_q, - max_k, indices, ) = cross_attention_states + 
bs = cu_seqlen_q.size(0) - 1 + query_states = self.q_proj(hidden_states) + query_states = query_states.view(bs, -1, self.num_heads, self.head_size) + query_states = self.q_norm(query_states) key_states = self.k_proj(cross_attention_states) value_states = self.v_proj(cross_attention_states) - key_states = key_states.view(-1, self.num_key_value_heads, self.head_size) - value_states = value_states.view(-1, self.num_key_value_heads, self.head_size) + key_states = key_states.view(bs, -1, self.num_key_value_heads, self.head_size) + value_states = value_states.view( + bs, -1, self.num_key_value_heads, self.head_size + ) key_states = self.k_norm(key_states) # key_states = key_states.repeat(1, self.num_key_value_groups, 1) @@ -705,9 +706,9 @@ class MllamaTextCrossAttention(nn.Module): # f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}" # ) # execute sdpa - query_states = query_states.unsqueeze(0).transpose(1, 2) - key_states = key_states.unsqueeze(0).transpose(1, 2) - value_states = value_states.unsqueeze(0).transpose(1, 2) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) fsdpa_op = ModuleFusedSDPA(FusedSDPA) attn_output = fsdpa_op( query_states, @@ -803,9 +804,10 @@ class FlashLlamaCrossLayer(torch.nn.Module): block_tables, slots, seqlen, - max_s, adapter_data, cross_attention_states, # [ IB, ...] + prefill_cache_indices, + hpu_attention_meta, ) -> Tuple[torch.Tensor, torch.Tensor]: if cross_attention_states is None: return hidden_states, residual @@ -912,7 +914,7 @@ class FlashMllamaForConditionalGeneration(nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor], adapter_data: Optional[torch.Tensor] = None, @@ -949,8 +951,6 @@ class FlashMllamaForConditionalGeneration(nn.Module): ) * seqlen_k ) - max_q = cu_seqlen_q[-1].item() - max_k = seqlen_k else: cu_seqlen_q = torch.arange( seqlen_q + 1, device=device, dtype=torch.int32 @@ -965,16 +965,12 @@ class FlashMllamaForConditionalGeneration(nn.Module): ) * seqlen_k ) - max_q = seqlen_q - max_k = seqlen_k indices = image_indices[:] cross_attention_states = ( cross_attention_states, cu_seqlen_q, cu_seqlen_k, - max_q, - max_k, indices, ) @@ -986,7 +982,7 @@ class FlashMllamaForConditionalGeneration(nn.Module): block_tables=block_tables, slots=slots, seqlen=seqlen, - max_s=max_s, + hpu_attention_meta=hpu_attention_meta, prefill_cache_indices=prefill_cache_indices, lm_head_indices=lm_head_indices, adapter_data=adapter_data, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index b1f89eff..532f118f 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -19,7 +19,7 @@ from torch import nn from typing import Optional, List, Tuple from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear -from text_generation_server.layers.attention import Seqlen +from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata from text_generation_server.models.custom_modeling.vlm import ( load_text_model, load_vision_model, @@ 
-72,7 +72,7 @@ class PaliGemmaForConditionalGeneration(nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, pixel_values: torch.FloatTensor = None, @@ -85,7 +85,6 @@ class PaliGemmaForConditionalGeneration(nn.Module): inputs_embeds = self.text_model.embed_tokens(input_ids) # TODO This is odd but apparently pali gemma position ids start at 1. if cu_seqlen_prefill is not None: - max_s += 1 position_ids += 1 if pixel_values is not None: @@ -110,7 +109,7 @@ class PaliGemmaForConditionalGeneration(nn.Module): block_tables=block_tables, slots=slots, seqlen=seqlen, - max_s=max_s, + hpu_attention_meta=hpu_attention_meta, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py index 923123d6..31a01d7c 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py @@ -25,7 +25,7 @@ from transformers.activations import ACT2FN from text_generation_server.models.custom_modeling.vlm import ( load_text_model, ) -from text_generation_server.layers.attention import Seqlen +from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask from text_generation_server.layers import ( @@ -742,7 +742,7 @@ class Idefics2ForConditionalGeneration(nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, pixel_values: torch.FloatTensor = None, @@ -829,8 +829,7 @@ class Idefics2ForConditionalGeneration(nn.Module): block_tables=block_tables, slots=slots, seqlen=seqlen, - max_s=max_s, - true_max_s=max_s, + hpu_attention_meta=hpu_attention_meta, prefill_cache_indices=None, adapter_data=adapter_data, ) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py index 580398cb..ce5e8115 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py @@ -24,7 +24,7 @@ from transformers.activations import ACT2FN from text_generation_server.models.custom_modeling.vlm import ( load_text_model, ) -from text_generation_server.layers.attention import Seqlen +from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask from text_generation_server.layers import ( @@ -485,7 +485,7 @@ class Idefics3ForConditionalGeneration(nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, pixel_values: torch.FloatTensor = None, @@ -573,8 +573,7 @@ class Idefics3ForConditionalGeneration(nn.Module): block_tables=block_tables, slots=slots, seqlen=seqlen, - max_s=max_s, - true_max_s=max_s, + 
hpu_attention_meta=hpu_attention_meta, prefill_cache_indices=None, adapter_data=adapter_data, ) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py index efd9cccd..832efdfa 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py @@ -40,6 +40,7 @@ from text_generation_server.layers import ( ) from text_generation_server.layers.attention import ( Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( Qwen2Model, @@ -906,7 +907,7 @@ class Qwen2_5VLForConditionalGeneration(nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor], pixel_values: torch.FloatTensor = None, @@ -937,8 +938,7 @@ class Qwen2_5VLForConditionalGeneration(nn.Module): block_tables=block_tables, slots=slots, seqlen=seqlen, - max_s=max_s, - true_max_s=max_s, + hpu_attention_meta=hpu_attention_meta, prefill_cache_indices=prefill_cache_indices, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py index b32ab577..856635fd 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -39,6 +39,7 @@ from text_generation_server.layers import ( ) from text_generation_server.layers.attention import ( Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( Qwen2Model, @@ -482,7 +483,7 @@ class Qwen2VLForConditionalGeneration(nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor], pixel_values: torch.FloatTensor = None, @@ -512,8 +513,7 @@ class Qwen2VLForConditionalGeneration(nn.Module): block_tables=block_tables, slots=slots, seqlen=seqlen, - max_s=max_s, - true_max_s=max_s, + hpu_attention_meta=hpu_attention_meta, prefill_cache_indices=prefill_cache_indices, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index 3a0dc15e..4cdf2628 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -61,7 +61,6 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke from text_generation_server.utils.dist import MEMORY_FRACTION from text_generation_server.utils.quantization import get_loader from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments - from text_generation_server.utils.import_utils import ( empty_cache, synchronize, @@ -77,10 +76,6 @@ tracer = trace.get_tracer(__name__) SLIDING_WINDOW: Optional[int] = None -def small_power_of_2(n: int): - return 1 << ((n - 1).bit_length() - 1) - - def set_sliding_window(sliding_window: int): 
global SLIDING_WINDOW SLIDING_WINDOW = sliding_window @@ -91,40 +86,6 @@ def get_sliding_windows() -> int: return SLIDING_WINDOW -def init_cpu_threads_env(rank_id: int, world_size: int): - import importlib.util - - if importlib.util.find_spec("numa") is not None: - import numa - import psutil - - nodes = numa.info.get_max_node() + 1 - rank_per_node = math.ceil(world_size / nodes) - num_cpus_per_nodes = int(psutil.cpu_count(logical=False) / nodes) - node_id = int(rank_id / rank_per_node) - rank_offset_per_node = rank_id % rank_per_node - if os.getenv("OMP_NUM_THREADS") is None: - num_cpus_per_rank = max(int(num_cpus_per_nodes / rank_per_node), 1) - else: - num_cpus_per_rank = int(os.getenv("OMP_NUM_THREADS")) - if len(numa.memory.get_membind_nodes()) == nodes: - numa.memory.set_membind_nodes((node_id)) - torch.set_num_threads(num_cpus_per_rank) - if len(numa.schedule.get_affinitive_cpus(0)) == psutil.cpu_count(logical=True): - cpu_start = num_cpus_per_rank * rank_offset_per_node - numa.schedule.run_on_cpus( - 0, - *( - numa.info.node_to_cpus(node_id)[ - cpu_start : cpu_start + num_cpus_per_rank - ] - ), - ) - logger.info( - f"affinity={numa.schedule.get_affinitive_cpus(0)}, membind = {numa.memory.get_membind_nodes()}" - ) - - @dataclass class FlashCausalLMBatch(Batch): batch_id: int @@ -1447,16 +1408,13 @@ class FlashCausalLM(Model): def warmup( self, - request: generate_pb2.WarmupRequest, + batch: FlashCausalLMBatch, + max_input_tokens: Optional[int], + max_total_tokens: Optional[int], ): # The warmup batch is the biggest batch we could ever receive self.kv_cache = [] empty_cache() - max_input_tokens = request.max_input_tokens - max_total_tokens = request.max_total_tokens - batch = self.batch_type.from_pb( - request.batch, self.tokenizer, self.dtype, self.device - ) # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) # Calculate the number of blocks that can be allocated with the free memory @@ -1505,10 +1463,10 @@ class FlashCausalLM(Model): ) log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}") - if max_total_tokens is None or max_total_tokens == 0: + if max_total_tokens is None: max_total_tokens = sum(batch.cache_lengths) - if max_input_tokens is None or max_input_tokens == 0: + if max_input_tokens is None: max_input_tokens = max_total_tokens - 1 del _batch, batch diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py index 5d4d68fd..7cff7797 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -16,7 +16,8 @@ from text_generation_server.models.globals import PREFIX_CACHING from loguru import logger from text_generation_server.utils.log import log_master from transformers import AutoProcessor -from text_generation_server.layers.attention import Seqlen +from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata +import habana_frameworks.torch as htorch tracer = trace.get_tracer(__name__) @@ -447,6 +448,10 @@ class FlashVlmCausalLM(FlashCausalLM): # This makes sure the max_s for the decode pass is correct. 
max_s = min(self.max_past(), max_s) + kwargs = {} + if htorch.utils.internal.is_lazy(): + kwargs["bypass_hpu_graphs"] = False + seqlen = Seqlen( input_lengths=input_lengths, cache_lengths=cache_lengths_tensor, @@ -459,13 +464,15 @@ class FlashVlmCausalLM(FlashCausalLM): kv_cache=kv_cache, block_tables=block_tables, slots=slots, - seqlen=seqlen, + seqlen=trim_seqlen_metadata(seqlen), + hpu_attention_meta=batch.hpu_attn_meta, prefill_cache_indices=batch.prefill_cache_indices, lm_head_indices=lm_head_indices, pixel_values=batch.pixel_values, pixel_attention_mask=batch.pixel_attention_mask, image_sizes=batch.image_sizes, image_grid_thw=batch.image_grid_thw, + **kwargs, ) if batch.prefill_cache_indices is not None: batch.prefill_cache_indices = None diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index f149d462..be67b6ae 100644 --- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -17,8 +17,8 @@ from text_generation_server.models.flash_vlm_causal_lm import ( FlashVlmCausalLM, ) from text_generation_server.pb import generate_pb2 -from text_generation_server.layers.attention import Seqlen - +from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata +import habana_frameworks.torch as htorch tracer = trace.get_tracer(__name__) @@ -279,6 +279,10 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): cross_attention_states = batch.cross_attention_states + kwargs = {} + if htorch.utils.internal.is_lazy(): + kwargs["bypass_hpu_graphs"] = False + logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -286,13 +290,15 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): kv_cache=kv_cache, block_tables=block_tables, slots=slots, - seqlen=seqlen, - max_s=max_s, + seqlen=trim_seqlen_metadata(seqlen), + hpu_attention_meta=batch.hpu_attn_meta, prefill_cache_indices=batch.prefill_cache_indices, lm_head_indices=lm_head_indices, cross_attention_states=cross_attention_states, - adapter_data=adapter_data, + # TODO list + adapter_data=None, image_indices=batch.image_indices[:], + **kwargs, ) if batch.prefill_cache_indices is not None: batch.prefill_cache_indices = None diff --git a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py index 66e00171..1c7b12b8 100644 --- a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py @@ -69,11 +69,7 @@ MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 8192)) PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 128)) CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1)) -max_batch_size_str = os.environ.get("MAX_BATCH_SIZE") -if max_batch_size_str is not None: - MAX_BATCH_SIZE = int(max_batch_size_str) -else: - raise ValueError("MAX_BATCH_SIZE is not set") + PREFILL_WARMUP_BATCH_SIZE_LIST = [] PREFILL_WARMUP_SEQLEN_LIST = [] @@ -1464,6 +1460,11 @@ class VlmCausalLM(Model): batch = self.batch_from_pb(request.batch, is_warmup=True) max_input_tokens = request.max_input_tokens max_prefill_batch_size = batch.input_ids.shape[0] + max_batch_size_str = os.environ.get("MAX_BATCH_SIZE") + if max_batch_size_str is not None: + MAX_BATCH_SIZE = 
int(max_batch_size_str) + else: + raise ValueError("MAX_BATCH_SIZE is not set") try: # max prefill batch size warmup diff --git a/backends/gaudi/server/text_generation_server/server.py b/backends/gaudi/server/text_generation_server/server.py index 7a8a51d6..6e470361 100644 --- a/backends/gaudi/server/text_generation_server/server.py +++ b/backends/gaudi/server/text_generation_server/server.py @@ -18,10 +18,11 @@ from text_generation_server.interceptor import ExceptionInterceptor from text_generation_server.models import Model, get_model_with_lora_adapters from text_generation_server.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor -from text_generation_server.models.globals import set_model_id +from text_generation_server.models.globals import set_model_id, ATTENTION from text_generation_server.models.globals import set_adapter_to_index from text_generation_server.utils.adapter import AdapterInfo from text_generation_server.utils.tokens import make_tokenizer_optional +from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens try: from text_generation_server.models.pali_gemma import PaliGemmaBatch @@ -109,14 +110,50 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) async def Warmup(self, request, context): - max_supported_total_tokens, max_input_tokens, max_total_tokens = ( - self.model.warmup(request) - ) + if ATTENTION == "paged": + set_max_prefill_tokens(request.max_prefill_tokens) + if ( + self.model.batch_type in VLM_BATCH_TYPES + ): # Hack, i would rather use kwargs in the `from_pb` call + batch = self.model.batch_type.from_pb_processor( + request.batch, + self.model.tokenizer, + self.model.processor, + self.model.model.config, + self.model.dtype, + self.model.device, + ) + else: + batch = self.model.batch_type.from_pb( + request.batch, + self.model.tokenizer, + self.model.dtype, + self.model.device, + ) - # W/A for the skip tokenizer path - # We need to call make_tokenizer_optional after the warmup, - # because router is not aware of that feature - make_tokenizer_optional(self.model.tokenizer) + # Override default values with None for clearer semantics. 
+ max_input_tokens = ( + request.max_input_tokens + if request.HasField("max_input_tokens") + else None + ) + max_total_tokens = ( + request.max_total_tokens + if request.HasField("max_total_tokens") + else None + ) + max_supported_total_tokens, max_input_tokens, max_total_tokens = ( + self.model.warmup(batch, max_input_tokens, max_total_tokens) + ) + else: + max_supported_total_tokens, max_input_tokens, max_total_tokens = ( + self.model.warmup(request) + ) + + # W/A for the skip tokenizer path + # We need to call make_tokenizer_optional after the warmup, + # because router is not aware of that feature + make_tokenizer_optional(self.model.tokenizer) return generate_pb2.WarmupResponse( max_supported_total_tokens=max_supported_total_tokens, diff --git a/backends/gaudi/server/text_generation_server/utils/prefill_chunking.py b/backends/gaudi/server/text_generation_server/utils/prefill_chunking.py new file mode 100644 index 00000000..c227d30f --- /dev/null +++ b/backends/gaudi/server/text_generation_server/utils/prefill_chunking.py @@ -0,0 +1,24 @@ +from typing import Optional + +SUPPORT_CHUNKING: Optional[bool] = None +MAX_PREFILL_TOKENS: Optional[int] = None + + +def set_support_chunking(support_chunking: bool): + global SUPPORT_CHUNKING + SUPPORT_CHUNKING = support_chunking + + +def get_support_chunking() -> bool: + global SUPPORT_CHUNKING + return SUPPORT_CHUNKING + + +def set_max_prefill_tokens(max_prefill_tokens: int): + global MAX_PREFILL_TOKENS + MAX_PREFILL_TOKENS = max_prefill_tokens + + +def get_max_prefill_tokens() -> int: + global MAX_PREFILL_TOKENS + return MAX_PREFILL_TOKENS From fdf0733f564159ea939ebc57d9e473562cb9f658 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Fri, 21 Mar 2025 01:01:37 -0700 Subject: [PATCH 12/35] fix incorrect output in qwen2 idefics if hpu graph is used Signed-off-by: Wang, Yi A --- .../text_generation_server/layers/rotary.py | 15 ++++++++++++--- .../flash_pali_gemma_modeling.py | 1 + .../models/custom_modeling/idefics2.py | 19 ++++++++++++++++++- .../models/custom_modeling/idefics3.py | 19 ++++++++++++++++++- .../models/custom_modeling/qwen2_5_vl.py | 9 ++++++--- .../models/custom_modeling/qwen2_vl.py | 3 ++- .../models/custom_modeling/vlm.py | 2 +- 7 files changed, 58 insertions(+), 10 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/layers/rotary.py b/backends/gaudi/server/text_generation_server/layers/rotary.py index b25f9fab..6a83d6a5 100644 --- a/backends/gaudi/server/text_generation_server/layers/rotary.py +++ b/backends/gaudi/server/text_generation_server/layers/rotary.py @@ -95,7 +95,10 @@ class PositionRotaryEmbedding(nn.Module): mrope_section = rope_scaling["mrope_section"] if mrope_section is not None: return RotaryPositionEmbeddingMultimodalSections( - inv_freq, scaling_factor, mrope_section + inv_freq, + scaling_factor, + mrope_section, + config.max_position_embeddings, ) elif rope_type == "dynamic": scaling_factor = rope_scaling["factor"] @@ -557,8 +560,13 @@ def apply_llama3_scaling( class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding): - def __init__(self, inv_freq: torch.Tensor, scaling_factor: float, sections: list): - super().__init__(inv_freq, scaling_factor) + def __init__( + self, + inv_freq: torch.Tensor, + scaling_factor: float, + sections: list, + max_position_embeddings, + ): self.sections = sections self._cos_cached = None self._sin_cached = None @@ -568,6 +576,7 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding): .view(1, 1, -1) .to(inv_freq.device) ) + 
super().__init__(inv_freq, scaling_factor, max_position_embeddings) def _update_cos_sin_cache( self, dtype: torch.dtype, device: torch.device, seqlen: int diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index 532f118f..af0f8f89 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -110,6 +110,7 @@ class PaliGemmaForConditionalGeneration(nn.Module): slots=slots, seqlen=seqlen, hpu_attention_meta=hpu_attention_meta, + prefill_cache_indices=None, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py index 31a01d7c..0a4305ec 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py @@ -728,7 +728,8 @@ class Idefics2ForConditionalGeneration(nn.Module): ): """In place merges in vision_embeddings with inputs_embeds.""" # mask = input_ids == self.config.image_token_index - mask = input_ids == self.config.image_token_id + # - replace `==` with torch.where to fix the issue in hpu graph + mask = torch.where(input_ids == self.config.image_token_id) # Let's pray we have enabled enough slots ! inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) return inputs_embeds @@ -794,6 +795,7 @@ class Idefics2ForConditionalGeneration(nn.Module): ].contiguous() patch_size = self.config.vision_config.patch_size + """ patches_subgrid = pixel_attention_mask.unfold( dimension=1, size=patch_size, step=patch_size ) @@ -801,6 +803,21 @@ class Idefics2ForConditionalGeneration(nn.Module): dimension=2, size=patch_size, step=patch_size ) patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + """ + # hpu does none support unfold + conv_kernel = torch.ones( + [1, 1, patch_size, patch_size], + dtype=pixel_values.dtype, + device=pixel_values.device, + ) + patches_subgrid = torch.nn.functional.conv2d( + pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype), + conv_kernel, + stride=patch_size, + ).squeeze(1) + patch_attention_mask = torch.eq( + patches_subgrid, (patch_size * patch_size) + ) # Get sequence from the vision encoder image_hidden_states = self.vision_model( diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py index ce5e8115..9278a86a 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py @@ -471,7 +471,8 @@ class Idefics3ForConditionalGeneration(nn.Module): ): """In place merges in vision_embeddings with inputs_embeds.""" # mask = input_ids == self.config.image_token_index - mask = input_ids == self.config.image_token_id + # - replace `==` with torch.where to fix the issue in hpu graph + mask = torch.where(input_ids == self.config.image_token_id) # Let's pray we have enabled enough slots ! 
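For illustration only (not part of the patch): a standalone comparison of the unfold-based patch attention mask that the idefics2/idefics3 code replaces and the conv2d variant introduced for HPU, where unfold is unsupported. The toy mask below is padded on patch boundaries. Note that the conv2d form marks a patch valid only when every pixel in it is attended, while the original unfold form accepted a patch with any attended pixel, so the two coincide only when padding is patch-aligned as assumed here.

import torch
import torch.nn.functional as F

patch_size = 4
# Toy pixel attention mask: batch of 2, 8x8 pixels, padding aligned to 4-pixel patches.
pixel_attention_mask = torch.zeros(2, 8, 8, dtype=torch.bool)
pixel_attention_mask[:, :, :4] = True  # left half attended

# Original formulation (unfold is not available on HPU).
subgrid = pixel_attention_mask.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
mask_any = (subgrid.sum(dim=(-1, -2)) > 0).bool()

# conv2d variant from the patch: a ones kernel counts attended pixels per patch.
kernel = torch.ones(1, 1, patch_size, patch_size, dtype=torch.float32)
counts = F.conv2d(
    pixel_attention_mask.unsqueeze(1).float(), kernel, stride=patch_size
).squeeze(1)
mask_all = torch.eq(counts, patch_size * patch_size)

assert torch.equal(mask_any, mask_all)  # holds here because padding is patch-aligned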
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) return inputs_embeds @@ -539,6 +540,7 @@ class Idefics3ForConditionalGeneration(nn.Module): ].contiguous() patch_size = self.config.vision_config.patch_size + """ patches_subgrid = pixel_attention_mask.unfold( dimension=1, size=patch_size, step=patch_size ) @@ -546,6 +548,21 @@ class Idefics3ForConditionalGeneration(nn.Module): dimension=2, size=patch_size, step=patch_size ) patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + """ + # hpu does none support unfold + conv_kernel = torch.ones( + [1, 1, patch_size, patch_size], + dtype=pixel_values.dtype, + device=pixel_values.device, + ) + patches_subgrid = torch.nn.functional.conv2d( + pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype), + conv_kernel, + stride=patch_size, + ).squeeze(1) + patch_attention_mask = torch.eq( + patches_subgrid, (patch_size * patch_size) + ) # Get sequence from the vision encoder image_hidden_states = self.vision_model( diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py index 832efdfa..75dd2b40 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py @@ -739,10 +739,12 @@ class Qwen2_5VisionModel(nn.Module): cu_window_seqlens = torch.tensor( cu_window_seqlens, - device=hidden_states.device, + device="cpu", dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, ) - cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens).to( + hidden_states.device + ) # create a cu_seqlens tensor to be used in the attention mask cu_seqlens = torch.repeat_interleave( @@ -928,7 +930,8 @@ class Qwen2_5VLForConditionalGeneration(nn.Module): image_embeds = self.visual( pixel_values, grid_thw=image_grid_thw ).squeeze(0) - inputs_embeds[input_ids == self.image_token_id] = image_embeds + mask = torch.where(input_ids == self.image_token_id) + inputs_embeds[mask] = image_embeds hidden_states = self.text_model( inputs_embeds=inputs_embeds, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py index 856635fd..3b4965a2 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -503,7 +503,8 @@ class Qwen2VLForConditionalGeneration(nn.Module): image_embeds = self.visual( pixel_values, grid_thw=image_grid_thw ).squeeze(0) - inputs_embeds[input_ids == self.image_token_id] = image_embeds + mask = torch.where(input_ids == self.image_token_id) + inputs_embeds[mask] = image_embeds hidden_states = self.text_model( inputs_embeds=inputs_embeds, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py index 94b8522d..ae704af3 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py @@ -16,7 +16,7 @@ def load_text_model(prefix, config, weights, name=None): FlashGemmaForCausalLM, ) - return FlashGemmaForCausalLM(prefix, config, weights, causal=False) + return 
FlashGemmaForCausalLM(prefix, config, weights) elif config.model_type == "gemma2": from text_generation_server.models.custom_modeling.flash_gemma2_modeling import ( FlashGemma2ForCausalLM, From 9914ffe1f195e8dba790d91db6ba931fa4bbf329 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Fri, 21 Mar 2025 18:28:58 -0700 Subject: [PATCH 13/35] remove unused quantization code and enable awq/gptq int4 Signed-off-by: Wang, Yi A --- .../gaudi/server/tests/utils/test_weights.py | 137 ------ .../server/text_generation_server/cli.py | 6 - .../layers/awq/quantize/hpu.py | 117 ++++- .../layers/compressed_tensors/__init__.py | 3 - .../layers/compressed_tensors/loader.py | 196 -------- .../layers/compressed_tensors/w8a8_int.py | 239 --------- .../layers/compressed_tensors/w8an_fp.py | 168 ------- .../layers/compressed_tensors/wna16_int.py | 188 ------- .../layers/compressed_tensors/wna16_int_24.py | 101 ---- .../layers/gptq/__init__.py | 14 +- .../layers/gptq/{ipex.py => hpu.py} | 306 +++++++----- .../layers/marlin/__init__.py | 15 - .../layers/marlin/fp8.py | 141 ------ .../layers/marlin/gptq.py | 465 ------------------ .../layers/marlin/marlin.py | 359 -------------- .../layers/marlin/util.py | 137 ------ .../custom_modeling/flash_gemma_modeling.py | 2 + .../custom_modeling/flash_gpt2_modeling.py | 4 - .../flash_pali_gemma_modeling.py | 1 + .../custom_modeling/flash_phi_modeling.py | 2 +- .../flash_santacoder_modeling.py | 4 - .../models/seq2seq_lm.py | 2 +- .../utils/quantization.py | 57 +-- 23 files changed, 291 insertions(+), 2373 deletions(-) delete mode 100644 backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py delete mode 100644 backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py delete mode 100644 backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8a8_int.py delete mode 100644 backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8an_fp.py delete mode 100644 backends/gaudi/server/text_generation_server/layers/compressed_tensors/wna16_int.py delete mode 100644 backends/gaudi/server/text_generation_server/layers/compressed_tensors/wna16_int_24.py rename backends/gaudi/server/text_generation_server/layers/gptq/{ipex.py => hpu.py} (57%) delete mode 100644 backends/gaudi/server/text_generation_server/layers/marlin/__init__.py delete mode 100644 backends/gaudi/server/text_generation_server/layers/marlin/fp8.py delete mode 100644 backends/gaudi/server/text_generation_server/layers/marlin/gptq.py delete mode 100644 backends/gaudi/server/text_generation_server/layers/marlin/marlin.py delete mode 100644 backends/gaudi/server/text_generation_server/layers/marlin/util.py diff --git a/backends/gaudi/server/tests/utils/test_weights.py b/backends/gaudi/server/tests/utils/test_weights.py index 556fcea1..c301e50e 100644 --- a/backends/gaudi/server/tests/utils/test_weights.py +++ b/backends/gaudi/server/tests/utils/test_weights.py @@ -7,10 +7,6 @@ from text_generation_server.utils.weights import ( ) from text_generation_server.layers.gptq import GPTQWeight, GPTQWeightsLoader from text_generation_server.layers.exl2 import Exl2Weight, Exl2WeightsLoader -from text_generation_server.layers.marlin.marlin import ( - MarlinWeight, - MarlinWeightsLoader, -) from types import SimpleNamespace from typing import List, Optional, Dict, Union from pathlib import Path @@ -40,11 +36,6 @@ def gptq_weights_loader_awq(): ) -@pytest.fixture -def marlin_weights_loader(): - return MarlinWeightsLoader(bits=4, is_marlin_24=False) - - 
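
# For reference, the qwen2_5_vl hunk in the preceding commit (PATCH 12/35) builds
# cu_window_seqlens on CPU, drops repeated boundaries with torch.unique_consecutive,
# and only then moves the tensor to the device. A toy illustration with made-up
# boundary values:
import torch

cu_window_seqlens = torch.tensor([0, 16, 16, 32, 48, 48], dtype=torch.int32)
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
# -> tensor([ 0, 16, 32, 48], dtype=torch.int32); the transfer to HPU happens after this step
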
dummy_file_system = { "test_weights": { "layer.0.weight": torch.tensor( @@ -125,10 +116,6 @@ dummy_file_system = { "gptq_bits": torch.tensor([8], dtype=torch.float32), "gptq_groupsize": torch.tensor([2], dtype=torch.float32), }, - "test_get_weights_col_marlin": { - "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), - "weight.s": torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), - }, "test_get_weights_row_gptq": { "weight.qweight": torch.tensor( [ @@ -273,18 +260,6 @@ dummy_file_system = { "weight.q_scale_max": torch.tensor([100], dtype=torch.float16), "weight.q_groups": torch.tensor([4], dtype=torch.int16), }, - "test_get_weights_row_marlin": { - "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), - "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16), - }, - "test_get_multi_weights_col_marlin": { - "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), - "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16), - }, - "test_get_weights_col_packed_marlin": { - "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), - "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16), - }, } @@ -718,33 +693,6 @@ def test_get_weights_col_exl2(): assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" -def test_get_weights_col_marlin(marlin_weights_loader): - weights = MockWeights( - [ - "test_get_weights_col_marlin", - ], - device="cpu", - dtype=torch.float16, - process_group=dummy_process_group, - dummy_fs=dummy_file_system, - weights_loader=marlin_weights_loader, - ) - - prefix = "weight" - - w = weights.get_weights_col( - prefix=prefix, - ) - - expected_weight = MarlinWeight( - B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), - s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), - ) - - assert torch.allclose(w.B, expected_weight.B), "B mismatch" - assert torch.allclose(w.s, expected_weight.s), "s mismatch" - - # test_get_weights_col_packed @@ -868,36 +816,6 @@ def test_get_weights_col_packed_gptq(gptq_weights_loader): assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_weights_col_packed_marlin(marlin_weights_loader): - weights = MockWeights( - [ - "test_get_weights_col_packed_marlin", - ], - device="cpu", - dtype=torch.float16, - process_group=dummy_process_group, - dummy_fs=dummy_file_system, - weights_loader=marlin_weights_loader, - ) - - prefix = "weight" - - w = weights.get_multi_weights_col( - prefixes=[prefix], - dim=0, - ) - - expected_weight = MarlinWeight( - B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), - s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), - ) - - print(expected_weight) - - assert torch.allclose(w.B, expected_weight.B), "B mismatch" - assert torch.allclose(w.s, expected_weight.s), "s mismatch" - - # test_get_multi_weights_col @@ -1004,34 +922,6 @@ def test_get_multi_weights_col_gptq(gptq_weights_loader): assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_multi_weights_col_marlin(marlin_weights_loader): - weights = MockWeights( - [ - "test_get_multi_weights_col_marlin", - ], - device="cpu", - dtype=torch.float16, - process_group=dummy_process_group, - dummy_fs=dummy_file_system, - weights_loader=marlin_weights_loader, - ) - - prefix = "weight" - - w = weights.get_multi_weights_col( - prefixes=[prefix], - dim=0, - ) - - expected_weight = MarlinWeight( - B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), - s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), - ) - - assert 
torch.allclose(w.B, expected_weight.B), "B mismatch" - assert torch.allclose(w.s, expected_weight.s), "s mismatch" - - # test_get_weights_row @@ -1148,30 +1038,3 @@ def test_get_weights_row_gptq(gptq_weights_loader): assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" - - -def test_get_weights_row_marlin(marlin_weights_loader): - weights = MockWeights( - [ - "test_get_weights_row_marlin", - ], - device="cpu", - dtype=torch.float16, - process_group=dummy_process_group, - dummy_fs=dummy_file_system, - weights_loader=marlin_weights_loader, - ) - - prefix = "weight" - - w = weights.get_weights_row( - prefix=prefix, - ) - - expected_weight = MarlinWeight( - B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), - s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), - ) - - assert torch.allclose(w.B, expected_weight.B), "B mismatch" - assert torch.allclose(w.s, expected_weight.s), "s mismatch" diff --git a/backends/gaudi/server/text_generation_server/cli.py b/backends/gaudi/server/text_generation_server/cli.py index 24d1d748..e1c0298d 100644 --- a/backends/gaudi/server/text_generation_server/cli.py +++ b/backends/gaudi/server/text_generation_server/cli.py @@ -16,15 +16,9 @@ app = typer.Typer() class Quantization(str, Enum): - bitsandbytes = "bitsandbytes" - bitsandbytes_nf4 = "bitsandbytes-nf4" - bitsandbytes_fp4 = "bitsandbytes-fp4" gptq = "gptq" awq = "awq" - eetq = "eetq" - exl2 = "exl2" fp8 = "fp8" - marlin = "marlin" class Dtype(str, Enum): diff --git a/backends/gaudi/server/text_generation_server/layers/awq/quantize/hpu.py b/backends/gaudi/server/text_generation_server/layers/awq/quantize/hpu.py index 391371a5..3af0131b 100644 --- a/backends/gaudi/server/text_generation_server/layers/awq/quantize/hpu.py +++ b/backends/gaudi/server/text_generation_server/layers/awq/quantize/hpu.py @@ -1,19 +1,93 @@ -# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py - from typing import Optional import torch import torch.nn as nn -import awq_inference_engine # with CUDA kernels + +try: + import habana_frameworks.torch.hpu # noqa: F401 + + convert_from_uint4 = torch.ops.hpu.convert_from_uint4 +except Exception as e: + hpu_import_exception = e + + def error_raiser_hpu(*args, **kwargs): + raise ValueError( + f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}" + ) + + convert_from_uint4 = error_raiser_hpu + +AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] -# class ScaledActivation(nn.Module): -# def __init__(self, module, scales): -# super().__init__() -# self.act = module -# self.scales = nn.Parameter(scales.data) -# -# def forward(self, x): -# return self.act(x) / self.scales.view(1, 1, -1).to(x.device) +def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int): + shifts = torch.arange(0, 32, bits, device=qzeros.device) + + # unpacking columnwise + iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to( + torch.int8 # smallest dtype available + ) + iweights = iweights.view(iweights.shape[0], -1) + + # unpacking columnwise + if qzeros is not None: + izeros = torch.bitwise_right_shift( + qzeros[:, :, None], shifts[None, None, :] + ).to( + torch.int8 # smallest dtype available + ) + izeros = izeros.view(izeros.shape[0], -1) + else: + izeros = qzeros + + return 
iweights, izeros + + +def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int): + reverse_order_tensor = torch.arange( + iweights.shape[-1], + dtype=torch.int32, + device=izeros.device, + ) + reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits) + reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER] + reverse_order_tensor = reverse_order_tensor.view(-1) + + if izeros is not None: + izeros = izeros[:, reverse_order_tensor] + iweights = iweights[:, reverse_order_tensor] + + return iweights, izeros + + +def unpack_weight_and_zeros(qweight, qzeros, bits): + # Unpack the qweight and qzeros tensors + iweight, izeros = unpack_awq(qweight, qzeros, bits) + # Reverse the order of the iweight and izeros tensors + iweight, izeros = reverse_awq_order(iweight, izeros, bits) + + # overflow checks + iweight = torch.bitwise_and(iweight, (2**bits) - 1) + izeros = torch.bitwise_and(izeros, (2**bits) - 1) + + return iweight, izeros + + +def pack_tensor(input, bits=4): + normal = input.to(torch.int32) + q = torch.zeros( + (normal.shape[0], normal.shape[1] // 32 * bits), + dtype=torch.int32, + device=input.device, + ) + i = 0 + col = 0 + while col < q.shape[1]: + for j in range(i, i + (32 // bits)): + q[:, col] |= normal[:, j] << (bits * (j - i)) + i += 32 // bits + col += 1 + q = q.to(torch.int32) + return q class WQLinear(nn.Module): @@ -38,12 +112,23 @@ class WQLinear(nn.Module): self.qzeros = qzeros self.scales = scales self.bias = bias + self._preprocessing() + + def _preprocessing(self): + device = self.qweight.device + weight, zeros = unpack_weight_and_zeros( + self.qweight.cpu(), self.qzeros.cpu(), self.w_bit + ) + self.qweight = pack_tensor(weight).to(device) + self.qzeros = pack_tensor(zeros).to(device) @torch.no_grad() def forward(self, x): out_shape = x.shape[:-1] + (self.out_features,) - out = awq_inference_engine.gemm_forward_cuda( - x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8 - ) - out = out + self.bias if self.bias is not None else out - return out.reshape(out_shape) + x = x.reshape(-1, x.shape[-1]) + weights = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype) + outputs = torch.matmul(x, weights) + + outputs = outputs + self.bias if self.bias is not None else outputs + outputs = outputs.reshape(out_shape) + return outputs diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py deleted file mode 100644 index 507af706..00000000 --- a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .loader import CompressedTensorsLoader - -__all__ = ["CompressedTensorsLoader"] diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py deleted file mode 100644 index 17d0224e..00000000 --- a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py +++ /dev/null @@ -1,196 +0,0 @@ -from typing import Any, Dict, List, Union - -from compressed_tensors import QuantizationConfig, QuantizationStatus -from compressed_tensors.config import CompressionFormat -from compressed_tensors.quantization import ( - QuantizationScheme, - QuantizationType, - find_name_or_class_matches, -) -from loguru import logger -from pydantic import ValidationError -from torch import nn - -from 
text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader -from text_generation_server.layers.compressed_tensors.w8a8_int import W8A8IntLoader -from text_generation_server.layers.compressed_tensors.wna16_int_24 import ( - WNA16Int24Loader, -) -from text_generation_server.layers.compressed_tensors.wna16_int import WNA16IntLoader -from text_generation_server.utils.log import log_once -from text_generation_server.utils.weights import ( - DefaultWeightsLoader, - UnquantizedWeight, - Weights, - WeightsLoader, -) - -# compressed-tensors can match modules as quantization targets. However, -# they need to be objects rather than classes or class names. Since we -# need to match `Linear` targets, make an instance that can be re-used. -_EMPTY_LINEAR: nn.Module = nn.Linear(0, 0) - - -class CompressedTensorsLoader(WeightsLoader): - """Loader for checkpoints stored in the compressed-tensors format.""" - - def __init__(self, config: Dict[str, Any]): - quantization_config_raw = config.get("quantization_config") - if quantization_config_raw is None: - # `compression_config` was renamed to `quantization_config`; support - # retained for backward compatibility. - quantization_config_raw = config.get("compression_config") - if quantization_config_raw is None: - raise ValueError( - "Checkpoint does not have compressed-tensors configuration" - ) - - try: - quantization_config = QuantizationConfig.model_validate( - quantization_config_raw - ) - except ValidationError as e: - raise ValueError("Cannot parse compressed-tensors configuration") from e - - if quantization_config.quantization_status not in ( - QuantizationStatus.COMPRESSED, - QuantizationStatus.FROZEN, - ): - raise ValueError( - f"Model quantization was not finished, status was: {quantization_config.quantization_status}" - ) - - self.ignore = ( - quantization_config.ignore if quantization_config.ignore is not None else [] - ) - self.loaders = self._get_target_loaders(quantization_config) - - for target, loader in self.loaders.items(): - log_once( - logger.info, - f"Using {loader} for compressed-tensors target '{target}'", - ) - - def get_weights(self, weights: Weights, prefix: str): - loader = self._lookup_loader(prefix) - return loader.get_weights(weights, prefix) - - def get_weights_col_packed( - self, - weights: "Weights", - prefix: str, - block_sizes: Union[int, List[int]], - ): - loader = self._lookup_loader(prefix) - return loader.get_weights_col_packed(weights, prefix, block_sizes) - - def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): - loader = self._lookup_loader(prefixes[0]) - return loader.get_multi_weights_col(weights, prefixes, dim) - - def get_weights_row(self, weights: Weights, prefix: str): - loader = self._lookup_loader(prefix) - return loader.get_weights_row(weights, prefix) - - def _get_target_loaders( - self, quantization_config: QuantizationConfig - ) -> Dict[str, WeightsLoader]: - """ - A compressed-tensors checkpoint can use different quantizations - for different targets. This method returns a dictionary with a - loader per target. - """ - - loaders: Dict[str, WeightsLoader] = {} - - format = quantization_config.format - - for group_name, group in quantization_config.config_groups.items(): - # The group configuration can be a string, but does that ever - # happen in a serialized quantization config? 
- assert isinstance(group, QuantizationScheme) - - loader = self._create_loader_for_group(format, group_name, group) - - # A quantized parameter group can have multiple targets, add the - # loader for all the targets. - for target in group.targets: - if target in loaders: - raise ValueError( - f"Target '{target} has multiple configured loaders'" - ) - loaders[target] = loader - - return loaders - - def _create_loader_for_group( - self, format: str, group_name: str, group: QuantizationScheme - ) -> WeightsLoader: - """ - Find and create a loader for the group with the given quantization - scheme. - """ - # NOTE: we ignore group.output_activations because we don't support - # output quantization yet. - - input_activations = group.input_activations - weights = group.weights - if ( - format - in { - CompressionFormat.float_quantized.value, - CompressionFormat.naive_quantized.value, - } - and weights is not None - and weights.type == QuantizationType.FLOAT - and weights.num_bits == 8 - ): - # FP W8A8 or W8A16. - return W8ANFpLoader(input_activations=input_activations, weights=weights) - elif ( - format == CompressionFormat.pack_quantized.value - and weights is not None - and weights.type == QuantizationType.INT - and weights.num_bits in (4, 8) - ): - # INT W4A16 or W8A16 (GPTQ/AWQ-like). - return WNA16IntLoader(weights) - elif ( - format == CompressionFormat.marlin_24.value - and weights is not None - and weights.type == QuantizationType.INT - and weights.num_bits in (4, 8) - ): - return WNA16Int24Loader(weights) - elif ( - format - in { - CompressionFormat.int_quantized.value, - CompressionFormat.naive_quantized.value, - } - and weights is not None - and weights.type == QuantizationType.INT - and weights.num_bits == 8 - ): - return W8A8IntLoader(input_args=input_activations, weight_args=weights) - else: - raise ValueError( - f"Group '{group_name}' has unsupported compressed-tensors configurtion" - ) - - def _lookup_loader(self, prefix: str) -> WeightsLoader: - """ - Look up the loader to use for a given parameter name (prefix). - """ - - if len(find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.ignore)) > 0: - return DefaultWeightsLoader(UnquantizedWeight) - - # We currently only handle linear layers, so unconditionally pass - # a `Linear` instance. - targets = find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.loaders.keys()) - if len(targets) == 0: - raise ValueError( - f"Cannot find compressed-tensors target for prefix: {prefix}" - ) - return self.loaders[targets[0]] diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8a8_int.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8a8_int.py deleted file mode 100644 index fff0c765..00000000 --- a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8a8_int.py +++ /dev/null @@ -1,239 +0,0 @@ -from typing import List, Optional, Union, TypeVar -from dataclasses import dataclass - -from loguru import logger -import torch -from compressed_tensors.quantization import QuantizationArgs, QuantizationType - -from text_generation_server.layers.fp8 import _load_scalar_or_matrix_scale -from text_generation_server.utils.log import log_once -from text_generation_server.utils.weights import Weight, Weights, WeightsLoader - - -quantization = None - - -class W8A8IntLoader(WeightsLoader): - """ - Loader for w8a8 integer compressed-tensors parameters. 
- """ - - def __init__( - self, - *, - input_args: Optional[QuantizationArgs], - weight_args: QuantizationArgs, - ): - if weight_args.type != QuantizationType.INT and weight_args.num_bits != 8: - raise ValueError( - f"{type(self).__name__} only supports w8a8 int checkpoints" - ) - - if not weight_args.symmetric: - raise ValueError("Checkpoints with asymmetric weights are not supported") - - self.load_weight_scale = not weight_args.dynamic - - if input_args is not None: - self.input_symmetric = input_args.symmetric - - if not input_args.dynamic: - log_once( - logger.warning, - "Forcing dynamic input quantization for compressed_tensors w8a8 int checkpoint (for better accuracy).", - ) - else: - self.input_symmetric = True - - def __str__(self) -> str: - def scale_to_str(scale): - return "static" if scale else "dynamic" - - def symmetric_to_str(symmetric): - return "symmetric" if symmetric else "asymmetric" - - return f"{self.__class__.__name__} (w8a8 int, input: dynamic/{symmetric_to_str(self.input_symmetric)}, weight: {scale_to_str(self.load_weight_scale)}/symmetric))" - - def get_weights(self, weights: "Weights", prefix: str): - w = weights.get_tensor(f"{prefix}.weight", to_dtype=False) - - weight_scale = None - if self.load_weight_scale: - weight_scale = weights.get_tensor( - f"{prefix}.weight_scale", to_dtype=False - ).reshape(-1) - - return Int8Weight( - input_symmetric=self.input_symmetric, - weight=w, - weight_scale=weight_scale, - ) - - def get_weights_col_packed( - self, - weights: Weights, - prefix: str, - block_sizes: Union[int, List[int]], - ): - w = weights.get_packed_sharded( - f"{prefix}.weight", dim=0, block_sizes=block_sizes, to_dtype=False - ) - - weight_scale = None - if self.load_weight_scale: - weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) - if weight_scale.numel() > 1: - weight_scale = weights.get_packed_sharded( - f"{prefix}.weight_scale", - dim=0, - block_sizes=block_sizes, - to_dtype=False, - ) - weight_scale = weight_scale.reshape(-1) - - return Int8Weight( - input_symmetric=self.input_symmetric, - weight=w, - weight_scale=weight_scale, - ) - - def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): - w = [ - weights.get_sharded(f"{p}.weight", dim=0, to_dtype=False) for p in prefixes - ] - shapes = [x.shape for x in w] - - w = torch.cat(w, dim=dim) - - weight_scale = None - if self.load_weight_scale: - weight_scale = [ - _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape) - for p, shape in zip(prefixes, shapes) - ] - weight_scale = torch.cat(weight_scale, dim=0).reshape(-1, 1) - - return Int8Weight( - input_symmetric=self.input_symmetric, - weight=w, - weight_scale=weight_scale, - ) - - def get_weights_row(self, weights: "Weights", prefix: str): - w = weights.get_sharded(f"{prefix}.weight", dim=1, to_dtype=False) - - weight_scale = None - if self.load_weight_scale: - weight_scale = weights.get_tensor( - f"{prefix}.weight_scale", to_dtype=False - ).reshape(-1) - - return Int8Weight( - input_symmetric=self.input_symmetric, - weight=w, - weight_scale=weight_scale, - ) - - -OtherT = TypeVar("OtherT") - - -def _get_tensor_or_else( - weights: Weights, prefix: str, other: OtherT -) -> Union[torch.Tensor, OtherT]: - # Even if a checkpoint uses e.g. 
zero-points, they can be elided: - # https://github.com/neuralmagic/compressed-tensors/blob/db6ccb25b265e8370813ecab5e95714a6728b5a6/src/compressed_tensors/compressors/quantized_compressors/base.py#L105 - if weights.has_tensor(prefix): - return weights.get_tensor(prefix, to_dtype=False) - else: - return other - - -@dataclass -class Int8Weight(Weight): - input_symmetric: bool - weight: torch.Tensor - weight_scale: Optional[torch.Tensor] - - def get_linear(self, bias: torch.Tensor): - if self.weight_scale is None: - assert quantization is not None - qweight, weight_scale, _ = quantization.scaled_int8_quant(self.weight) - return W8A8IntLinear( - bias=bias, - input_symmetric=self.input_symmetric, - weight=qweight, - weight_scale=weight_scale, - ) - else: - return W8A8IntLinear( - bias=bias, - input_symmetric=self.input_symmetric, - weight=self.weight, - weight_scale=self.weight_scale, - ) - - -class W8A8IntLinear(torch.nn.Module): - def __init__( - self, - *, - bias: Optional[torch.Tensor], - input_symmetric: bool, - weight: torch.Tensor, - weight_scale: torch.Tensor, - ): - super().__init__() - - weight_scale = weight_scale.to(torch.float32) - - self.bias = bias - self.input_symmetric = input_symmetric - # cutlass kernels require transposed weights. - self.weight = weight.t() - self.weight_scale = weight_scale - - if input_symmetric: - self.zero_point_adj = None - else: - # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md#scaledepilogueazp - self.zero_point_adj = self.weight.sum( - dim=0, keepdim=True, dtype=torch.int32 - ) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - assert quantization is not None - - qinput, input_scale, input_zero_point = quantization.scaled_int8_quant( - input=input, - scale=None, - azp=None, - symmetric=self.input_symmetric, - ) - - if self.input_symmetric: - return quantization.cutlass_scaled_mm( - a=qinput, - b=self.weight, - scale_a=input_scale, - scale_b=self.weight_scale, - out_dtype=input.dtype, - bias=self.bias, - ) - else: - assert ( - self.zero_point_adj is not None - and input_scale is not None - and (self.input_symmetric or input_zero_point is not None) - ) - - return quantization.cutlass_scaled_mm_azp( - a=qinput, - b=self.weight, - scale_a=input_scale, - scale_b=self.weight_scale, - out_dtype=input.dtype, - azp_adj=self.zero_point_adj, - azp=input_zero_point, - bias=self.bias, - ) diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8an_fp.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8an_fp.py deleted file mode 100644 index ed63806e..00000000 --- a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8an_fp.py +++ /dev/null @@ -1,168 +0,0 @@ -from typing import List, Optional, Union - -import torch -from compressed_tensors.quantization import QuantizationArgs, QuantizationType - -from text_generation_server.layers.fp8 import ( - Fp8Weight, - _load_scalar_or_matrix_scale, -) -from text_generation_server.utils.weights import Weights, WeightsLoader - - -class W8ANFpLoader(WeightsLoader): - """ - Loader for W8A8/W8A16 FP compressed-tensors parameters. - """ - - def __init__( - self, - *, - input_activations: Optional[QuantizationArgs], - weights: QuantizationArgs, - ): - assert weights.type == QuantizationType.FLOAT and weights.num_bits == 8 - - # We ignore the `strategy` option which sets the scales to be - # per-tensor, per-channel or per-token. 
What scales are supported - # is dependent on the kernels used (e.g. cutlass can do tokenwise, - # Torch cannot, and FP8-Marlin does not quantize inputs at all). - # So, instead we try to use the best-possible configuration. - - self.load_weight_scale = not weights.dynamic - self.load_input_scale = ( - input_activations is not None and not input_activations.dynamic - ) - self.force_w8a16 = ( - input_activations is not None and input_activations.num_bits == 16 - ) - - def __str__(self) -> str: - def scale_to_str(scale): - return "static" if scale else "dynamic" - - quantization_type = f"W8A{16 if self.force_w8a16 else 8}" - - return f"{self.__class__.__name__} ({quantization_type}, weight: {scale_to_str(self.load_weight_scale)}, input: {scale_to_str(self.load_input_scale)})" - - def get_weights(self, weights: "Weights", prefix: str): - w = weights.get_tensor(f"{prefix}.weight") - - weight_scale = None - if self.load_weight_scale: - weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) - - input_scale = None - if self.load_input_scale: - input_scale = weights.get_tensor( - f"{prefix}.input_scale", to_dtype=False - ).reshape(-1) - - return Fp8Weight( - weight=w, - weight_scale=weight_scale, - input_scale=input_scale, - dtype=weights.dtype, - force_w8a16=self.force_w8a16, - ) - - def get_weights_col_packed( - self, - weights: Weights, - prefix: str, - block_sizes: Union[int, List[int]], - ): - w = weights.get_packed_sharded( - f"{prefix}.weight", dim=0, block_sizes=block_sizes - ) - - weight_scale = None - if self.load_weight_scale: - weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) - if weight_scale.numel() > 1: - weight_scale = weights.get_packed_sharded( - f"{prefix}.weight_scale", - dim=0, - block_sizes=block_sizes, - to_dtype=False, - ) - - input_scale = None - if self.load_input_scale: - input_scale = weights.get_tensor(f"{prefix}.input_scale", to_dtype=False) - if input_scale.numel() > 1: - input_scale = weights.get_packed_sharded( - f"{prefix}.input_scale", - dim=0, - block_sizes=block_sizes, - to_dtype=False, - ) - input_scale = input_scale.reshape(-1).max() - - return Fp8Weight( - weight=w, - weight_scale=weight_scale, - input_scale=input_scale, - dtype=weights.dtype, - force_w8a16=self.force_w8a16, - ) - - def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): - # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet - w = [ - weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes - ] - shapes = [x.shape for x in w] - - # Concat then send to the device - w = torch.cat(w, dim=dim).to(weights.device) - - weight_scale = None - if self.load_weight_scale: - weight_scale = [ - _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape) - for p, shape in zip(prefixes, shapes) - ] - weight_scale = torch.cat(weight_scale, dim=0).reshape(-1) - - input_scale = None - if self.load_input_scale: - input_scale = [ - _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape) - for p, shape in zip(prefixes, shapes) - if weights.has_tensor(f"{p}.input_scale") - ] - assert len(input_scale) == 0 or len(input_scale) == len(prefixes) - input_scale = ( - torch.cat(input_scale, dim=0).reshape(-1).max() - if len(input_scale) != 0 - else None - ) - - return Fp8Weight( - weight=w, - weight_scale=weight_scale, - input_scale=input_scale, - dtype=weights.dtype, - force_w8a16=self.force_w8a16, - ) - - def get_weights_row(self, weights: "Weights", prefix: str): - w = 
weights.get_sharded(f"{prefix}.weight", dim=1) - weight_scale = None - if self.load_weight_scale: - weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) - - input_scale = None - if self.load_input_scale: - input_scale = weights.get_tensor( - f"{prefix}.input_scale", to_dtype=False - ).reshape(-1) - - return Fp8Weight( - weight=w, - weight_scale=weight_scale, - input_scale=input_scale, - dtype=weights.dtype, - force_w8a16=self.force_w8a16, - ) diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/wna16_int.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/wna16_int.py deleted file mode 100644 index bb69c6b5..00000000 --- a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/wna16_int.py +++ /dev/null @@ -1,188 +0,0 @@ -from typing import List, Union - -import torch -from compressed_tensors.quantization import ActivationOrdering, QuantizationArgs -from loguru import logger - -from text_generation_server.layers.marlin.gptq import repack_gptq_for_marlin -from text_generation_server.utils.log import log_once -from text_generation_server.utils.weights import Weights, WeightsLoader - - -class WNA16IntLoader(WeightsLoader): - """ - Loader for W4A16/W8A16 INT compressed-tensors parameters. - """ - - def __init__(self, weights: QuantizationArgs): - self.weights = weights - self.desc_act = self.weights.actorder == ActivationOrdering.GROUP - self.groupsize = ( - -1 if self.weights.group_size is None else self.weights.group_size - ) - - def __str__(self) -> str: - quantization_type = f"W{self.weights.num_bits}A16" - - return f"{self.__class__.__name__} ({quantization_type})" - - def get_weights(self, weights: Weights, prefix: str): - log_once(logger.info, "Using GPTQ-Marlin kernels") - try: - weight_packed = weights.get_tensor(f"{prefix}.weight_packed").t() - except RuntimeError: - raise RuntimeError( - f"Cannot load w{self.weights.num_bits}a16 weight, make sure the model is already quantized" - ) - - zero_point = None - if not self.weights.symmetric: - zero_point = weights.get_tensor(f"{prefix}.weight_zero_point").t() - - g_idx = None - if self.desc_act: - g_idx = weights.get_tensor(f"{prefix}.weight_g_idx") - - scales = weights.get_tensor(f"{prefix}.weight.scales").t() - - return repack_gptq_for_marlin( - qweight=weight_packed.contiguous(), - scales=scales, - qzeros=zero_point, - g_idx=g_idx, - bits=self.weights.num_bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method="compressed-tensors", - sym=self.weights.symmetric, - sharded_infeatures=False, - ) - - def get_weights_col_packed( - self, - weights: Weights, - prefix: str, - block_sizes: Union[int, List[int]], - ): - try: - weight_packed = weights.get_packed_sharded( - f"{prefix}.weight_packed", dim=0, block_sizes=block_sizes - ).t() - except RuntimeError: - raise RuntimeError( - f"Cannot load w{self.weights.num_bits}a16 weight, make sure the model is already quantized" - ) - scales = weights.get_packed_sharded( - f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes - ).t() - scales = scales.to(dtype=weights.dtype) - - zero_point = None - if not self.weights.symmetric: - zero_point = weights.get_packed_sharded( - f"{prefix}.qzeros", dim=0, block_sizes=block_sizes - ).t() - - g_idx = None - if self.desc_act: - g_idx = weights.get_tensor(f"{prefix}.g_idx") - - return repack_gptq_for_marlin( - qweight=weight_packed.contiguous(), - scales=scales, - qzeros=zero_point, - g_idx=g_idx, - bits=self.weights.num_bits, - 
desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method="compressed-tensors", - sym=self.weights.symmetric, - sharded_infeatures=False, - ) - - def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): - try: - weight_packed = torch.cat( - [ - weights.get_sharded(f"{p}.weight_packed", dim=0).t() - for p in prefixes - ], - dim=1, - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load w{self.weights.num_bits}a16 weight, make sure the model is already quantized" - ) - - scales = torch.cat( - [weights.get_sharded(f"{p}.weight_scale", dim=0).t() for p in prefixes], - dim=1, - ) - - zero_point = None - if not self.weights.symmetric: - zero_point = torch.cat( - [weights.get_sharded(f"{p}.qzeros", dim=0).t() for p in prefixes], dim=1 - ).t() - - g_idx = None - if self.desc_act: - w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] - for w2 in w[1:]: - torch.testing.assert_close(w2, w[0]) - g_idx = w[0] - - return repack_gptq_for_marlin( - qweight=weight_packed.contiguous(), - scales=scales, - qzeros=zero_point, - g_idx=g_idx, - bits=self.weights.num_bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method="compressed-tensors", - sym=self.weights.symmetric, - sharded_infeatures=False, - ) - - def get_weights_row(self, weights: Weights, prefix: str): - log_once(logger.info, "Using GPTQ-Marlin kernels") - try: - weight_packed = weights.get_sharded(f"{prefix}.weight_packed", dim=1).t() - except RuntimeError: - raise RuntimeError( - f"Cannot load `{self.quantize}` weight, make sure the model is already quantized." - ) - - zero_point = None - if not self.weights.symmetric: - if self.desc_act or self.groupsize == -1: - zero_point = weights.get_tensor(f"{prefix}.weight_zero_point").t() - else: - zero_point = weights.get_sharded( - f"{prefix}.weight_zero_point", dim=1 - ).t() - - g_idx = None - if self.desc_act: - g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0) - - if self.desc_act or self.groupsize == -1: - scales = weights.get_tensor(f"{prefix}.weight_scale").t() - else: - scales = weights.get_sharded(f"{prefix}.weight_scale", dim=1).t() - - sharded_in_features = weights.process_group.size() > 1 - - return repack_gptq_for_marlin( - qweight=weight_packed.contiguous(), - scales=scales, - qzeros=zero_point, - g_idx=g_idx, - bits=self.weights.num_bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method="compressed-tensors", - sym=self.weights.symmetric, - sharded_infeatures=sharded_in_features, - ) diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/wna16_int_24.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/wna16_int_24.py deleted file mode 100644 index 27b8614c..00000000 --- a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/wna16_int_24.py +++ /dev/null @@ -1,101 +0,0 @@ -from typing import List, Union - -import torch - - -from compressed_tensors.quantization import QuantizationArgs, QuantizationType -from text_generation_server.layers.marlin.marlin import GPTQMarlin24Weight -from text_generation_server.utils.weights import Weights, WeightsLoader - - -class WNA16Int24Loader(WeightsLoader): - """ - Loader for W4A16/W8A16 INT 2:4 sparsity compressed-tensors checkpoints. 
- """ - - def __init__(self, weight_args: QuantizationArgs): - super().__init__() - - if weight_args.type != QuantizationType.INT: - raise ValueError( - f"{type(self).__name__} only supports wNa8 int checkpoints" - ) - - if weight_args.strategy == "group" and weight_args.group_size is None: - raise ValueError("`group_size` must be set when `actorder` is `group`") - - self.bits = weight_args.num_bits - self.group_size = weight_args.group_size - - def __str__(self) -> str: - quantization_type = f"W{self.bits}A16 2:4 sparsity" - - return f"{self.__class__.__name__} ({quantization_type})" - - def get_weights(self, weights: Weights, prefix: str): - """ - Get weights at the given prefix and apply without tensor paralllism. - """ - weight_packed = weights.get_tensor(f"{prefix}.weight_packed") - meta = weights.get_tensor(f"{prefix}.meta") - scale_packed = weights.get_tensor(f"{prefix}.scale_packed") - return GPTQMarlin24Weight( - weight_packed=weight_packed, - meta=meta, - scale_packed=scale_packed, - bits=self.bits, - ) - - def get_weights_col_packed( - self, - weights: Weights, - prefix: str, - block_sizes: Union[int, List[int]], - ): - weight_packed = weights.get_packed_sharded( - f"{prefix}.weight_packed", dim=1, block_sizes=block_sizes - ) - meta = weights.get_packed_sharded( - f"{prefix}.meta", dim=1, block_sizes=block_sizes - ) - scale_packed = weights.get_packed_sharded( - f"{prefix}.scale_packed", dim=1, block_sizes=block_sizes - ) - return GPTQMarlin24Weight( - weight_packed=weight_packed, - meta=meta, - scale_packed=scale_packed, - bits=self.bits, - ) - - def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): - weight_packed = torch.cat( - [weights.get_sharded(f"{p}.weight_packed", dim=1) for p in prefixes], dim=1 - ) - meta = torch.cat( - [weights.get_sharded(f"{p}.meta", dim=1) for p in prefixes], dim=1 - ) - scale_packed = torch.cat( - [weights.get_sharded(f"{p}.scale_packed", dim=1) for p in prefixes], dim=1 - ) - return GPTQMarlin24Weight( - weight_packed=weight_packed, - meta=meta, - scale_packed=scale_packed, - bits=self.bits, - ) - - def get_weights_row(self, weights: Weights, prefix: str): - weight_packed = weights.get_sharded(f"{prefix}.weight_packed", dim=0) - meta = weights.get_sharded(f"{prefix}.meta", dim=0) - if self.group_size is None: - scale_packed = weights.get_tensor(f"{prefix}.scale_packed") - else: - scale_packed = weights.get_sharded(f"{prefix}.scale_packed", dim=0) - - return GPTQMarlin24Weight( - weight_packed=weight_packed, - meta=meta, - scale_packed=scale_packed, - bits=self.bits, - ) diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py index e62a334c..90b8f692 100644 --- a/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py @@ -7,7 +7,7 @@ from text_generation_server.utils.log import log_once from text_generation_server.utils.weights import Weight, Weights, WeightsLoader -QuantLinear = None +from .hpu import QuantLinear @dataclass @@ -215,14 +215,7 @@ class GPTQWeightsLoader(WeightsLoader): [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 ) - from text_generation_server.layers.gptq import HAS_EXLLAMA - - use_exllama = ( - self.bits == 4 - and HAS_EXLLAMA - and self.quantize == "gptq" - and not self.desc_act - ) + use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act if self.quantize == "gptq" and 
self.quant_method == "gptq": w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] @@ -362,6 +355,3 @@ class GPTQWeightsLoader(WeightsLoader): else False ) self.quant_method = "gptq" - - -HAS_EXLLAMA = False diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/ipex.py b/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py similarity index 57% rename from backends/gaudi/server/text_generation_server/layers/gptq/ipex.py rename to backends/gaudi/server/text_generation_server/layers/gptq/hpu.py index 48584e90..25d5c3d2 100644 --- a/backends/gaudi/server/text_generation_server/layers/gptq/ipex.py +++ b/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py @@ -1,125 +1,181 @@ -import math -import numpy as np -import torch -import torch.nn as nn - -import intel_extension_for_pytorch as ipex - - -class QuantLinear(nn.Module): - def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): - super().__init__() - self.register_buffer("qweight", qweight) - self.register_buffer("qzeros", qzeros) - self.register_buffer("scales", scales) - self.register_buffer("g_idx", g_idx) - if bias is not None: - self.register_buffer("bias", bias) - else: - self.bias = None - if bits not in [4]: - raise NotImplementedError("Only 4 bits are supported.") - self.bits = bits - self.maxq = 2**self.bits - 1 - self.groupsize = groupsize - - self.outfeatures = qweight.shape[1] - self.infeatures = qweight.shape[0] * 32 // bits - self.woq_linear = ( - ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight( - self.qweight, - self.scales, - self.qzeros, - self.infeatures, - self.outfeatures, - bias=self.bias, - group_size=self.groupsize, - g_idx=g_idx, - quant_method=ipex.llm.quantization.QuantMethod.GPTQ_GEMM, - dtype=ipex.llm.quantization.QuantDtype.INT4, - ) - ) - - @classmethod - def new(cls, bits, groupsize, infeatures, outfeatures, bias): - if bits not in [4]: - raise NotImplementedError("Only 4 bits are supported.") - - qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32) - qzeros = torch.zeros( - (math.ceil(infeatures / groupsize), outfeatures // 32 * bits), - dtype=torch.int32, - ) - scales = torch.zeros( - (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16 - ) - g_idx = torch.tensor( - [i // groupsize for i in range(infeatures)], dtype=torch.int32 - ) - if bias: - bias = torch.zeros((outfeatures), dtype=torch.float16) - else: - bias = None - return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize) - - def pack(self, linear, scales, zeros, g_idx=None): - self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx - - scales = scales.t().contiguous() - zeros = zeros.t().contiguous() - scale_zeros = zeros * scales - self.scales = scales.clone().half() - if linear.bias is not None: - self.bias = linear.bias.clone().half() - - intweight = [] - for idx in range(self.infeatures): - intweight.append( - torch.round( - (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) - / self.scales[self.g_idx[idx]] - ).to(torch.int)[:, None] - ) - intweight = torch.cat(intweight, dim=1) - intweight = intweight.t().contiguous() - intweight = intweight.numpy().astype(np.uint32) - qweight = np.zeros( - (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32 - ) - i = 0 - row = 0 - while row < qweight.shape[0]: - if self.bits in [4]: - for j in range(i, i + (32 // self.bits)): - qweight[row] |= intweight[j] << (self.bits * (j - i)) - i += 32 // self.bits - row += 1 - else: - raise 
NotImplementedError("Only 4 bits are supported.") - - qweight = qweight.astype(np.int32) - self.qweight = torch.from_numpy(qweight) - - zeros -= 1 - zeros = zeros.numpy().astype(np.uint32) - qzeros = np.zeros( - (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32 - ) - i = 0 - col = 0 - while col < qzeros.shape[1]: - if self.bits in [4]: - for j in range(i, i + (32 // self.bits)): - qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) - i += 32 // self.bits - col += 1 - else: - raise NotImplementedError("Only 4 bits are supported.") - - qzeros = qzeros.astype(np.int32) - self.qzeros = torch.from_numpy(qzeros) - - def forward(self, x): - out_shape = x.shape[:-1] + (self.outfeatures,) - out = self.woq_linear(x.reshape(-1, x.shape[-1])) - return out.reshape(out_shape) +import math +import numpy as np +import torch +import torch.nn as nn + +try: + + convert_from_uint4 = torch.ops.hpu.convert_from_uint4 +except Exception as e: + hpu_import_exception = e + + def error_raiser_hpu(*args, **kwargs): + raise ValueError( + f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}" + ) + + convert_from_uint4 = error_raiser_hpu + + +def pack_tensor(input, bits=4): + normal = input.to(torch.int32) + q = torch.zeros((normal.shape[0], normal.shape[1] // 32 * bits), dtype=torch.int32) + i = 0 + col = 0 + while col < q.shape[1]: + for j in range(i, i + (32 // bits)): + q[:, col] |= normal[:, j] << (bits * (j - i)) + i += 32 // bits + col += 1 + q = q.to(torch.int32) + return q + + +class QuantLinear(nn.Module): + def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): + super().__init__() + self.register_buffer("qweight", qweight) + self.register_buffer("qzeros", qzeros) + self.register_buffer("scales", scales) + self.register_buffer("g_idx", g_idx) + if bias is not None: + self.register_buffer("bias", bias) + else: + self.bias = None + if bits not in [4]: + raise NotImplementedError("Only 4 bits are supported.") + self.bits = bits + self.maxq = 2**self.bits - 1 + self.groupsize = groupsize + + self.outfeatures = qweight.shape[1] + self.infeatures = qweight.shape[0] * 32 // bits + self._preprocessing() + + def unpack_zeros_from_cuda_old_format(self): + zeros = torch.bitwise_right_shift( + torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), + self.wf.unsqueeze(0), + ).to(torch.int16 if self.bits == 8 else torch.int8) + + zeros = zeros + 1 + zeros = torch.bitwise_and(zeros, (2**self.bits) - 1).to( + self.scales.dtype + ) # NOTE: It appears that casting here after the `zeros = zeros + 1` is important. 
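
# The _preprocessing() path of this QuantLinear unpacks the CUDA-era GPTQ layout
# (eight 4-bit values per int32 along the in-feature axis) and repacks it column-wise
# with pack_tensor() for the HPU kernel. The unpack helpers use self.wf, which is not
# shown in the hunk above; it is presumably the usual GPTQ shift tensor
# torch.arange(0, 32, bits). A self-contained sketch of the unpack step under that
# assumption:
import torch

bits = 4
wf = torch.arange(0, 32, bits, dtype=torch.int32)  # bit offsets of the 8 nibbles in an int32

def unpack_qweight(qweight: torch.Tensor) -> torch.Tensor:
    # qweight: (in_features // 8, out_features) in the old CUDA GPTQ layout
    w = torch.bitwise_right_shift(
        qweight.unsqueeze(1).expand(-1, 32 // bits, -1), wf.view(1, -1, 1)
    )
    w = torch.bitwise_and(w, (1 << bits) - 1).to(torch.int8)
    return w.reshape(-1, qweight.shape[-1])  # (in_features, out_features)

# round trip on a toy column: nibbles 0..7 packed into a single int32
nibbles = torch.arange(8, dtype=torch.int32).view(-1, 1)
packed = torch.zeros(1, 1, dtype=torch.int32)
for j in range(8):
    packed[0] |= nibbles[j] << (bits * j)
assert torch.equal(unpack_qweight(packed), nibbles.to(torch.int8))
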
+ zeros = zeros.reshape(-1, zeros.shape[1] * zeros.shape[2]) + return zeros + + def unpack_weight_from_cuda_old_format(self): + weight = torch.bitwise_right_shift( + torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1), + self.wf.unsqueeze(-1), + ).to(torch.int16 if self.bits == 8 else torch.int8) + weight = torch.bitwise_and(weight, (2**self.bits) - 1) + weight = weight.reshape((weight.shape[0] * weight.shape[1], weight.shape[2])) + return weight + + def _preprocessing(self): + self.qweight = self.qweight.cpu() + weight = self.unpack_weight_from_cuda_old_format() + new_qweight = pack_tensor(weight) + self.qweight = new_qweight.to("hpu") + + # TODO: Support group indexing and remove the check + columns = self.qweight.shape[0] + g_idx_trivial = [i // self.group_size for i in range(columns)] + g_idx_trivial = torch.tensor(g_idx_trivial, dtype=torch.int32) + assert torch.equal( + self.g_idx, g_idx_trivial + ), "Non-trivial tensor g_idx is not supported" + + zeros = self.unpack_zeros_from_cuda_old_format().cpu() + new_qzeros = pack_tensor(zeros) + self.qzeros = new_qzeros.to("hpu") + + @classmethod + def new(cls, bits, groupsize, infeatures, outfeatures, bias): + if bits not in [4]: + raise NotImplementedError("Only 4 bits are supported.") + + qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32) + qzeros = torch.zeros( + (math.ceil(infeatures / groupsize), outfeatures // 32 * bits), + dtype=torch.int32, + ) + scales = torch.zeros( + (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16 + ) + g_idx = torch.tensor( + [i // groupsize for i in range(infeatures)], dtype=torch.int32 + ) + if bias: + bias = torch.zeros((outfeatures), dtype=torch.float16) + else: + bias = None + return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize) + + def pack(self, linear, scales, zeros, g_idx=None): + self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + self.scales = scales.clone().half() + if linear.bias is not None: + self.bias = linear.bias.clone().half() + + intweight = [] + for idx in range(self.infeatures): + intweight.append( + torch.round( + (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) + / self.scales[self.g_idx[idx]] + ).to(torch.int)[:, None] + ) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(np.uint32) + qweight = np.zeros( + (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32 + ) + i = 0 + row = 0 + while row < qweight.shape[0]: + if self.bits in [4]: + for j in range(i, i + (32 // self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += 32 // self.bits + row += 1 + else: + raise NotImplementedError("Only 4 bits are supported.") + + qweight = qweight.astype(np.int32) + self.qweight = torch.from_numpy(qweight) + + zeros -= 1 + zeros = zeros.numpy().astype(np.uint32) + qzeros = np.zeros( + (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32 + ) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [4]: + for j in range(i, i + (32 // self.bits)): + qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) + i += 32 // self.bits + col += 1 + else: + raise NotImplementedError("Only 4 bits are supported.") + + qzeros = qzeros.astype(np.int32) + self.qzeros = torch.from_numpy(qzeros) + + def forward(self, x): + out_shape = x.shape[:-1] + (self.outfeatures,) + x = 
x.reshape(-1, x.shape[-1]) + weight = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype) + out = torch.matmul(x, weight) + out = out.reshape(out_shape) + out = out + self.bias if self.bias is not None else out + return out diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/__init__.py b/backends/gaudi/server/text_generation_server/layers/marlin/__init__.py deleted file mode 100644 index 3ff3ed58..00000000 --- a/backends/gaudi/server/text_generation_server/layers/marlin/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear -from text_generation_server.layers.marlin.gptq import ( - GPTQMarlinWeightsLoader, - can_use_gptq_marlin, - repack_gptq_for_marlin, -) -from text_generation_server.layers.marlin.marlin import MarlinWeightsLoader - -__all__ = [ - "GPTQMarlinFP8Linear", - "GPTQMarlinWeightsLoader", - "MarlinWeightsLoader", - "can_use_gptq_marlin", - "repack_gptq_for_marlin", -] diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/fp8.py b/backends/gaudi/server/text_generation_server/layers/marlin/fp8.py deleted file mode 100644 index c2666d2b..00000000 --- a/backends/gaudi/server/text_generation_server/layers/marlin/fp8.py +++ /dev/null @@ -1,141 +0,0 @@ -from typing import Optional - -import torch -import torch.nn as nn -from text_generation_server.layers.fp8 import fp8_quantize -from text_generation_server.layers.marlin.gptq import _check_valid_shape -from text_generation_server.layers.marlin.util import ( - _check_marlin_kernels, - permute_scales, -) - - -quantization = None - - -MARLIN_TILE_SIZE = 16 - - -class GPTQMarlinFP8Linear(nn.Module): - """ - FP8 GPTQ-Marlin linear layer. - """ - - def __init__( - self, - qweight: torch.Tensor, - scales: torch.Tensor, - bias: Optional[torch.Tensor], - ) -> None: - super().__init__() - - _check_marlin_kernels() - assert quantization is not None - - scales = scales.unsqueeze(0) - if scales.shape[1] == 1: - out_features, in_features = qweight.shape - scales = scales.repeat(1, out_features) - qweight, scales = repack_fp8_for_marlin(qweight, scales) - - in_features = qweight.shape[0] * MARLIN_TILE_SIZE - out_features = scales.shape[1] - _check_valid_shape(in_features=in_features, out_features=out_features) - - self.qweight = qweight - self.scales = scales - self.bias = bias if bias is not None else None - - self.workspace = torch.zeros( - out_features // 64 * 16, dtype=torch.int, device=qweight.device - ) - - @classmethod - def from_unquant(cls, weight, bias, dtype): - qweight, scales = fp8_quantize(weight) - return cls(qweight=qweight, scales=scales.to(dtype), bias=bias) - - @classmethod - def from_fp8( - cls, - weight: torch.Tensor, - scale: torch.Tensor, - bias: torch.Tensor, - dtype: torch.dtype, - **kwargs, - ): - return cls(qweight=weight, scales=scale.to(dtype), bias=bias) - - def forward(self, A: torch.Tensor) -> torch.Tensor: - assert quantization is not None - - A_flat = A.view(-1, A.shape[-1]) - C = quantization.fp8_marlin_gemm( - A_flat, - self.qweight, - self.scales, - self.workspace, - 8, - A_flat.shape[0], - self.scales.shape[1], - A_flat.shape[1], - ) - C = C.reshape(A.shape[:-1] + (self.scales.shape[1],)) - - if self.bias is not None: - C += self.bias - - return C - - -def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: - """ - Repack FP8 weights to gptq format (packed int32 elements). 
- """ - assert fp8_tensor.dtype == torch.float8_e4m3fn - - if fp8_tensor.shape[0] % 4 != 0: - raise ValueError( - f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}" - ) - - # Reshape to prepare for packing - reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) - - # Convert fp8 to uint8 (byte) representation - byte_tensor = reshaped.view(torch.uint8) - - # Pack 4 uint8 values into one int32 - packed = torch.zeros( - fp8_tensor.shape[0] // 4, - fp8_tensor.shape[1], - dtype=torch.int32, - device=fp8_tensor.device, - ) - - for i in range(4): - packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8) - - return packed - - -def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor): - """ - Repack FP8 tensor for GPTQ-Marlin. - """ - - out_features, in_features = weight.shape - - # Torch linear layers weights with shape [out_features, in_features], - # GPTQ-quantized weights use [in_feateres/pack_factor, in_features], - # so transpose before packing. - qweight = pack_fp8_as_int32(weight.t()) - - perm = torch.empty(0, dtype=torch.int, device=qweight.device) - repacked = quantization.gptq_marlin_repack( - qweight, perm, in_features, out_features, 8 - ) - - scales = permute_scales(scales) - - return repacked, scales diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/gptq.py b/backends/gaudi/server/text_generation_server/layers/marlin/gptq.py deleted file mode 100644 index 185a6d77..00000000 --- a/backends/gaudi/server/text_generation_server/layers/marlin/gptq.py +++ /dev/null @@ -1,465 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional, Union - -import numpy -import torch -import torch.nn as nn -from loguru import logger -from text_generation_server.layers.marlin.util import ( - _check_marlin_kernels, - marlin_zero_points, - permute_scales, - unpack_cols, -) -from text_generation_server.utils.log import log_once -from text_generation_server.utils.weights import Weight, Weights, WeightsLoader - - -quantization = None - - -try: - major, _minor = torch.cuda.get_device_capability() - has_sm_8_0 = major >= 8 -except Exception: - has_sm_8_0 = False - - -GPTQ_MARLIN_BITS = [4, 8] -GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128] -MARLIN_TILE_SIZE = 16 - - -def can_use_gptq_marlin( - *, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool -) -> bool: - return False - - -class GPTQMarlinWeightsLoader(WeightsLoader): - """ - Loader for using GPTQ- and AWQ-quantized weights with Marlin kernels. 
- """ - - def __init__( - self, - *, - bits: int, - desc_act: bool, - groupsize: int, - quant_method: str, - quantize: str, - sym: bool, - ): - self.bits = bits - self.desc_act = desc_act - self.groupsize = groupsize - self.quant_method = quant_method - self.quantize = quantize - self.sym = sym - - def get_weights(self, weights: Weights, prefix: str): - log_once(logger.info, "Using GPTQ-Marlin kernels") - try: - qweight = weights.get_tensor(f"{prefix}.qweight") - except RuntimeError: - raise RuntimeError( - f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" - ) - - if not self.sym: - qzeros = weights.get_tensor(f"{prefix}.qzeros") - else: - qzeros = None - - if self.quant_method == "awq": - g_idx = None - else: - g_idx = weights.get_tensor(f"{prefix}.g_idx") - scales = weights.get_tensor(f"{prefix}.scales") - - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - qzeros=qzeros, - g_idx=g_idx, - bits=self.bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method=self.quant_method, - sym=self.sym, - sharded_infeatures=False, - ) - - def get_weights_col_packed( - self, - weights: Weights, - prefix: str, - block_sizes: Union[int, List[int]], - ): - try: - qweight = weights.get_packed_sharded( - f"{prefix}.qweight", dim=1, block_sizes=block_sizes - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{self.quantize}` weight, make sure the model is already quantized." - ) - scales = weights.get_packed_sharded( - f"{prefix}.scales", dim=1, block_sizes=block_sizes - ) - scales = scales.to(dtype=weights.dtype) - - if not self.sym: - qzeros = weights.get_packed_sharded( - f"{prefix}.qzeros", dim=1, block_sizes=block_sizes - ) - else: - qzeros = None - - if self.quant_method == "awq": - g_idx = None - else: - g_idx = weights.get_tensor(f"{prefix}.g_idx") - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - qzeros=qzeros, - g_idx=g_idx, - bits=self.bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method=self.quant_method, - sym=self.sym, - sharded_infeatures=False, - ) - - def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): - try: - qweight = torch.cat( - [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{self.quantize}` weight, make sure the model is already quantized" - ) - - scales = torch.cat( - [weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 - ) - - if not self.sym: - qzeros = torch.cat( - [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 - ) - else: - qzeros = None - - if self.quant_method == "awq": - g_idx = None - else: - w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] - for w2 in w[1:]: - torch.testing.assert_close(w2, w[0]) - g_idx = w[0] - - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - qzeros=qzeros, - g_idx=g_idx, - bits=self.bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method=self.quant_method, - sym=self.sym, - sharded_infeatures=False, - ) - - def get_weights_row(self, weights: Weights, prefix: str): - log_once(logger.info, "Using GPTQ-Marlin kernels") - try: - qweight = weights.get_sharded(f"{prefix}.qweight", dim=0) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" - ) - - if not self.sym: - if self.desc_act or self.groupsize == -1: 
- qzeros = weights.get_tensor(f"{prefix}.qzeros") - else: - qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0) - else: - qzeros = None - - if self.quant_method == "awq": - g_idx = None - else: - g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0) - - if self.desc_act or self.groupsize == -1: - scales = weights.get_tensor(f"{prefix}.scales") - else: - scales = weights.get_sharded(f"{prefix}.scales", dim=0) - - sharded_in_features = weights.process_group.size() > 1 - - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - qzeros=qzeros, - g_idx=g_idx, - bits=self.bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method=self.quant_method, - sym=self.sym, - sharded_infeatures=sharded_in_features, - ) - - def _get_gptq_params(self, weights: Weights): - if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"): - self.bits = weights.get_tensor("gptq_bits").item() - self.groupsize = weights.get_tensor("gptq_groupsize").item() - self.desc_act = False - # `server quantize` used asymmetric quantization unconditionally - # before the `gptq_sym` setting tensor was added. - self.sym = ( - weights.get_tensor("gptq_sym").item() - if weights.has_tensor("gptq_sym") - else False - ) - self.quant_method = "gptq" - - -@dataclass -class GPTQMarlinWeight(Weight): - """ - Repacked GPTQ Marlin weights. - """ - - qweight: torch.Tensor - qzeros: torch.Tensor - scales: torch.Tensor - g_idx: torch.Tensor - perm: torch.Tensor - bits: int - is_full_k: bool - - def __post_init__(self): - assert self.qweight.dtype == torch.int32 - assert self.scales.dtype in (torch.float16, torch.bfloat16) - assert self.g_idx.dtype == torch.int32 - assert self.perm.dtype == torch.int32 - - def get_linear(self, bias: torch.Tensor): - return GPTQMarlinLinear( - weight=self, - bias=bias, - ) - - -def repack_gptq_for_marlin( - *, - qweight: torch.Tensor, - qzeros: Optional[torch.Tensor], - scales: torch.Tensor, - g_idx: Optional[torch.Tensor], - bits: int, - desc_act: bool, - groupsize: int, - quant_method: str, - sym: bool, - sharded_infeatures: bool, -) -> GPTQMarlinWeight: - """Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels.""" - _check_marlin_kernels() - assert quantization is not None - - if bits not in GPTQ_MARLIN_BITS: - supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS) - raise RuntimeError( - f"Repacking {bits}-bit GPTQ weights as Marlin is not supported, must be one of: {supported_bits}" - ) - - if groupsize not in GPTQ_MARLIN_GROUP_SIZES: - supported_sizes = ", ".join(str(b) for b in GPTQ_MARLIN_GROUP_SIZES) - raise RuntimeError( - f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}" - ) - if not (sym or quant_method == "awq" or quant_method == "compressed-tensors"): - raise RuntimeError( - "Repacking GPTQ weights with asymmetric quantization as Marlin is not supported." 
- ) - - log_once(logger.info, f"Converting {quant_method} model to Marlin packing format.") - - weights_per_int = 32 // bits - in_features = qweight.shape[0] - out_features = qweight.shape[1] - - # AWQ uses column packing, GPTQ uses row packing - if quant_method == "awq": - out_features *= weights_per_int - else: - in_features *= weights_per_int - - if in_features % groupsize != 0: - raise ValueError( - f"Number of input features ({in_features}) not divisible by group size ({groupsize})" - ) - - if g_idx is not None and desc_act and groupsize != -1: - perm = torch.argsort(g_idx).to(torch.int) - g_idx = g_idx[perm] - else: - perm = torch.empty(0, dtype=torch.int, device=qweight.device) - g_idx = torch.empty(0, dtype=torch.int, device=qweight.device) - - if quant_method == "awq": - repacked = quantization.awq_marlin_repack( - qweight, in_features, out_features, bits - ) - if qzeros is not None: - qzeros = awq_to_marlin_zero_points( - qzeros, - in_features // groupsize, - out_features, - bits, - ) - - else: - repacked = quantization.gptq_marlin_repack( - qweight, perm, in_features, out_features, bits - ) - - if qzeros is None: - qzeros = torch.empty(0, dtype=torch.int, device=qweight.device) - - scales = permute_scales(scales) - - is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures) - - return GPTQMarlinWeight( - qweight=repacked, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - perm=perm, - bits=bits, - is_full_k=is_full_k, - ) - - -class GPTQMarlinLinear(nn.Module): - """ - Linear layer for GPTQ weights that were converted for the GPTQ-Marlin - kernels. - """ - - def __init__( - self, - *, - weight: GPTQMarlinWeight, - bias: Optional[torch.Tensor], - ): - super().__init__() - - _check_marlin_kernels() - assert quantization is not None - - in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE - out_features = weight.scales.shape[1] - _check_valid_shape(in_features=in_features, out_features=out_features) - - if weight.bits not in (4, 8): - raise ValueError("GPTQMarlinLinear only supports 4 and 8-bit quantization") - - if weight.qzeros.numel() > 0: - if weight.bits == 4: - self.quant_type = quantization.scalar_types.uint4 - else: - self.quant_type = quantization.scalar_types.uint8 - else: - if weight.bits == 4: - self.quant_type = quantization.scalar_types.uint4b8 - else: - self.quant_type = quantization.scalar_types.uint8b128 - - self.is_full_k = weight.is_full_k - - self.qweight = weight.qweight - self.qzeros = weight.qzeros - self.scales = weight.scales - self.g_idx = weight.g_idx - self.perm = weight.perm - if bias is not None: - self.bias = bias - else: - self.bias = None - - self.workspace = torch.zeros( - out_features // 64 * 16, dtype=torch.int, device=weight.qweight.device - ) - - def forward(self, A: torch.Tensor) -> torch.Tensor: - assert quantization is not None - - A_flat = A.view(-1, A.shape[-1]) - C = quantization.gptq_marlin_gemm( - A_flat, - self.qweight, - self.scales, - self.qzeros, - self.g_idx, - self.perm, - self.workspace, - self.quant_type, - A_flat.shape[0], - self.scales.shape[1], - A_flat.shape[1], - self.is_full_k, - self.qzeros.numel() > 0, - True, - ) - C = C.reshape(A.shape[:-1] + (self.scales.shape[1],)) - - if self.bias is not None: - C += self.bias - - return C - - -def awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - # AWQ zero-points are quantized and packed on the column dim. - # In addition, the values are permuted based on dequantizer. 
- # Here we undo both of these, and then apply marlin permutation - # and pack it back. - q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) - - # Undo interleaving (use argsort(..) to get inverse perm) - if num_bits == 4: - undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) - elif num_bits == 8: - undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() - q_zp = q_zp.reshape((-1, size_n)).contiguous() - - marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) - return marlin_zp - - -def _check_valid_shape(in_features: int, out_features: int): - if (in_features % 128 != 0 or out_features % 64 != 0) and ( - in_features % 64 != 0 or out_features % 128 != 0 - ): - raise ValueError( - f"The GPTQ Marlin kernel does not have a valid thread configuration for weight matrix with shape ({out_features}, {in_features})." - " The shape elements must be divisible by (128, 64) or (64, 128)." - ) diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/marlin.py b/backends/gaudi/server/text_generation_server/layers/marlin/marlin.py deleted file mode 100644 index 2ffbcf33..00000000 --- a/backends/gaudi/server/text_generation_server/layers/marlin/marlin.py +++ /dev/null @@ -1,359 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional, Union - -import torch -import torch.nn as nn - -from text_generation_server.layers.marlin.util import _check_marlin_kernels -from text_generation_server.utils.weights import Weight, Weights, WeightsLoader - -quantization = None - - -class MarlinWeightsLoader(WeightsLoader): - """Loader for Marlin-quantized weights.""" - - def __init__(self, *, bits: int, is_marlin_24: bool): - self.bits = bits - self.is_marlin_24 = is_marlin_24 - - def get_weights(self, weights: "Weights", prefix: str): - """ - Get weights at the given prefix and apply without tensor paralllism. - """ - is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" - if is_marlin_24: - try: - B = weights.get_tensor(f"{prefix}.B_24") - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." - ) - - B_meta = weights.get_tensor(f"{prefix}.B_meta") - s = weights.get_tensor(f"{prefix}.s") - weight = GPTQMarlin24Weight( - weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits - ) - else: - try: - B = weights.get_tensor(f"{prefix}.B") - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` weight, make sure the model is already quantized." 
- ) - - s = weights.get_tensor(f"{prefix}.s") - weight = MarlinWeight(B=B, s=s) - - return weight - - def get_weights_col_packed( - self, - weights: Weights, - prefix: str, - block_sizes: Union[int, List[int]], - ): - if self.is_marlin_24: - B = weights.get_packed_sharded( - f"{prefix}.B_24", dim=1, block_sizes=block_sizes - ) - B_meta = weights.get_packed_sharded( - f"{prefix}.B_meta", dim=1, block_sizes=block_sizes - ) - s = weights.get_packed_sharded( - f"{prefix}.s", dim=1, block_sizes=block_sizes - ) - - weight = GPTQMarlin24Weight( - weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits - ) - else: - B = weights.get_packed_sharded( - f"{prefix}.B", dim=1, block_sizes=block_sizes - ) - s = weights.get_packed_sharded( - f"{prefix}.s", dim=1, block_sizes=block_sizes - ) - weight = MarlinWeight(B=B, s=s) - - return weight - - def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): - if self.is_marlin_24: - try: - B = torch.cat( - [weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` weight, make sure the model is already quantized" - ) - - B_meta = torch.cat( - [weights.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1 - ) - - s = torch.cat( - [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 - ) - - weight = GPTQMarlin24Weight( - weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits - ) - else: - try: - B = torch.cat( - [weights.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` weight, make sure the model is already quantized" - ) - s = torch.cat( - [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 - ) - - weight = MarlinWeight(B=B, s=s) - - return weight - - def get_weights_row(self, weights: Weights, prefix: str): - if self.is_marlin_24: - try: - B = weights.get_sharded(f"{prefix}.B_24", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." - ) - - B_meta = weights.get_sharded(f"{prefix}.B_meta", dim=0) - num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0] - if num_groups == 1: - # The number of groups is 1 when groupsize == -1. share - # scales between all shards in this case. - s = weights.get_tensor(f"{prefix}.s") - else: - s = weights.get_sharded(f"{prefix}.s", dim=0) - - weight = GPTQMarlin24Weight( - weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits - ) - else: - try: - B = weights.get_sharded(f"{prefix}.B", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` weight, make sure the model is already quantized." - ) - - num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0] - if num_groups == 1: - # The number of groups is 1 when groupsize == -1. share - # scales between all shards in this case. - s = weights.get_tensor(f"{prefix}.s") - else: - s = weights.get_sharded(f"{prefix}.s", dim=0) - weight = MarlinWeight(B=B, s=s) - - return weight - - -@dataclass -class MarlinWeight(Weight): - """ - Marlin weights. - - Attributes: - B (torch.Tensor): int4-quantized weights packed into int32. - s (torch.Tensor): bfloat16/float16 scales. 
- """ - - B: torch.Tensor - s: torch.Tensor - - def __post_init__(self): - assert self.B.dtype == torch.int32 - assert self.s.dtype in [torch.float16, torch.bfloat16] - - def get_linear(self, bias: torch.Tensor): - return MarlinLinear(weight=self, bias=bias) - - -class MarlinLinear(nn.Module): - def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]): - super().__init__() - - _check_marlin_kernels() - assert quantization is not None - - in_features = weight.B.shape[0] * MARLIN_TILE_SIZE - out_features = weight.s.shape[1] - assert ( - in_features % 128 == 0 - ), f"Number of input features ({in_features}) not divisable by 128" - assert ( - out_features % 256 == 0 - ), f"Number of output features ({out_features}) not divisable by 256" - - groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0] - assert groupsize in { - -1, - 128, - }, f"Group size must be -1 or 128, was {groupsize}" - - self.B = weight.B - self.s = weight.s - if bias is not None: - self.bias = bias - else: - self.bias = None - - self.workspace = torch.zeros( - out_features // 64 * 16, dtype=torch.int, device=weight.B.device - ) - - def forward(self, A: torch.Tensor) -> torch.Tensor: - assert quantization is not None - - C = quantization.marlin_gemm( - A.view(-1, A.shape[-1]), - self.B, - self.s, - self.workspace, - A.shape[0], - self.s.shape[1], - A.shape[1], - ) - C = C.reshape(A.shape[:-1] + (self.s.shape[1],)) - - if self.bias is not None: - C += self.bias - - return C - - -GPTQ_MARLIN_24_MIN_THREAD_N = 128 -GPTQ_MARLIN_24_MIN_THREAD_K = 128 -GPTQ_MARLIN_24_MAX_PARALLEL = 64 -GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8] -GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] -MARLIN_TILE_SIZE = 16 - - -@dataclass -class GPTQMarlin24Weight: - """ - GPTQ-Marlin 2:4 weights. - - Attributes: - B (torch.Tensor): int4-quantized weights packed into int32. - B_meta (torch.Tensor): metadata for 2:4 sparsity. - s (torch.Tensor): float16 scales. - bits: quantized weight size. 
- """ - - weight_packed: torch.Tensor - meta: torch.Tensor - scale_packed: torch.Tensor - bits: int - - def __post_init__(self): - assert self.weight_packed.dtype == torch.int32 - assert self.meta.dtype == torch.int16 - assert self.scale_packed.dtype == torch.float16 - - def get_linear(self, bias: torch.Tensor): - return GPTQMarlin24Linear( - weight=self, - bias=bias, - ) - - -class GPTQMarlin24Linear(nn.Module): - def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]): - super().__init__() - - _check_marlin_kernels() - assert quantization is not None - - if weight.bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS: - supported_bits = ", ".join( - str(b) for b in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS - ) - raise RuntimeError( - f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}" - ) - - in_features = weight.weight_packed.shape[0] * MARLIN_TILE_SIZE * 2 - out_features = weight.scale_packed.shape[1] - groupsize = ( - -1 - if weight.scale_packed.shape[0] == 1 - else in_features // weight.scale_packed.shape[0] - ) - - if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES: - supported_sizes = ", ".join( - str(b) for b in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES - ) - raise RuntimeError( - f"Group size {groupsize} is not supported, must be one of: {supported_sizes}" - ) - - if weight.bits == 4: - self.quant_type = quantization.scalar_types.uint4b8 - else: - self.quant_type = quantization.scalar_types.uint8b128 - weights_per_int32 = 32 // weight.bits - - assert ( - out_features % GPTQ_MARLIN_24_MIN_THREAD_N == 0 - ), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_N} threads" - assert ( - out_features % weights_per_int32 == 0 - ), f"Number of output features ({out_features}) not divisable by weights per int32 ({weights_per_int32})" - - assert ( - in_features % GPTQ_MARLIN_24_MIN_THREAD_K == 0 - ), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_K} threads" - if groupsize != -1 and in_features % groupsize != 0: - raise ValueError( - f"Number of input features ({in_features}) not divisable by group size ({groupsize})" - ) - - self.weight_packed = weight.weight_packed - self.meta = weight.meta - self.scale_packed = weight.scale_packed - if bias is not None: - self.bias = bias - else: - self.bias = None - - self.workspace = torch.zeros( - (out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL, - dtype=torch.int, - device=weight.weight_packed.device, - ) - - def forward(self, A: torch.Tensor) -> torch.Tensor: - assert quantization is not None - - C = quantization.gptq_marlin_24_gemm( - A.view(-1, A.shape[-1]), - self.weight_packed, - self.meta, - self.scale_packed, - self.workspace, - self.quant_type, - A.shape[0], - self.scale_packed.shape[1], - A.shape[1], - ) - - C = C.reshape(A.shape[:-1] + (self.scale_packed.shape[1],)) - - if self.bias is not None: - C += self.bias - - return C diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/util.py b/backends/gaudi/server/text_generation_server/layers/marlin/util.py deleted file mode 100644 index 9f52340f..00000000 --- a/backends/gaudi/server/text_generation_server/layers/marlin/util.py +++ /dev/null @@ -1,137 +0,0 @@ -import functools -from typing import List, Tuple - -import numpy -import torch - - -quantization = None - -try: - major, _minor = torch.cuda.get_device_capability() - has_sm_8_0 = major >= 8 -except Exception: - has_sm_8_0 = False - - -def _check_marlin_kernels(): 
- raise NotImplementedError( - "Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later." - ) - - if quantization is None: - raise NotImplementedError( - "marlin is not installed, install it with: pip install server/marlin" - ) - - -# https://github.com/IST-DASLab/marlin/blob/2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c/marlin/__init__.py#L40C1-L68C54 -@functools.cache -def get_perms() -> Tuple[List[int], List[int]]: - scale_perm = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single = [] - for i in range(4): - scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -def permute_scales(scales: torch.Tensor): - scale_perm, scale_perm_single = get_perms() - out_features = scales.shape[1] - if scales.shape[0] == 1: - scales = scales.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] - else: - scales = scales.reshape((-1, len(scale_perm)))[:, scale_perm] - return scales.reshape((-1, out_features)).contiguous() - - -# Functions below are from vLLM - - -def get_pack_factor(bits: int) -> int: - if 32 % bits != 0: - raise ValueError(f"Cannot {bits} bit values into uint32") - return 32 // bits - - -def pack_cols( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) - - for i in range(pack_factor): - q_res |= q_w[:, i::pack_factor] << num_bits * i - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def unpack_cols( - packed_q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - assert packed_q_w.shape == ( - size_k, - size_n // pack_factor, - ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( - packed_q_w.shape, size_k, size_n, pack_factor - ) - - orig_device = packed_q_w.device - - packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) - q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) - - mask = (1 << num_bits) - 1 - for i in range(pack_factor): - vals = packed_q_w_cpu & mask - packed_q_w_cpu >>= num_bits - q_res[:, i::pack_factor] = vals - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def marlin_zero_points( - zp: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - scale_perm, _ = get_perms() - # Permute zero-points in a similar way to scales, but do not use the - # "single" permutation, since zero-points are applied on every MMA - zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] - - # Interleave column dim (for the dequantize code) and pack it to int32 - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() - zp = zp.reshape((-1, size_n)).contiguous() - zp = pack_cols(zp, num_bits, size_k, size_n) - - return zp diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py 
b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index d832fb00..c3e5727b 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -398,6 +398,7 @@ class FlashGemmaModel(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, + adapter_data: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: @@ -479,6 +480,7 @@ class FlashGemmaForCausalLM(torch.nn.Module): block_tables, slots, seqlen, + adapter_data, prefill_cache_indices, hpu_attention_meta, ) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 80236fe8..a7a85d3a 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -47,10 +47,6 @@ def load_qkv(config, prefix: str, weights, head_size, num_heads): prefix, weights, ) - elif config.quantize == "marlin": - raise RuntimeError( - "GPT-2 models with marlin quantization are not yet supported" - ) else: return _load_qkv(config, prefix, weights, head_size, num_heads) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index af0f8f89..2b67501d 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -111,6 +111,7 @@ class PaliGemmaForConditionalGeneration(nn.Module): seqlen=seqlen, hpu_attention_meta=hpu_attention_meta, prefill_cache_indices=None, + adapter_data=adapter_data, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 21c4bc71..cf7c9a79 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -90,7 +90,7 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize not in ["gptq", "awq", "marlin"]: + if config.quantize not in ["gptq", "awq"]: weight = weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.hidden_size // config.num_attention_heads diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 57d4ee64..a41518d7 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -32,10 +32,6 @@ def load_multi_mqa( return _load_multi_mqa_gptq( config, prefix, weights, bias, head_size, num_heads, hidden_size ) - elif config.quantize == "marlin": - raise RuntimeError( - "santacoder models with marlin 
quantization are not yet supported" - ) else: return _load_multi_mqa( config, prefix, weights, bias, head_size, num_heads, hidden_size diff --git a/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py b/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py index 7a63d4dd..0ee6ed16 100644 --- a/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py +++ b/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py @@ -588,7 +588,7 @@ class Seq2SeqLM(Model): aliases=aliases, weights_loader=weights_loader, ) - if config.quantize in ["awq", "exl2", "gptq", "marlin"]: + if config.quantize in ["awq", "gptq"]: weights._set_gptq_params(model_id, revision) model = model_class(config, weights) diff --git a/backends/gaudi/server/text_generation_server/utils/quantization.py b/backends/gaudi/server/text_generation_server/utils/quantization.py index ee561acc..a8faf4a5 100644 --- a/backends/gaudi/server/text_generation_server/utils/quantization.py +++ b/backends/gaudi/server/text_generation_server/utils/quantization.py @@ -4,9 +4,7 @@ from dataclasses import dataclass from typing import Optional from huggingface_hub import hf_hub_download -from text_generation_server.layers.marlin.gptq import can_use_gptq_marlin from text_generation_server.utils.weights import ( - DefaultWeightsLoader, WeightsLoader, ) @@ -129,64 +127,13 @@ def get_loader( f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config." ) - if can_use_gptq_marlin( + return GPTQWeightsLoader( bits=quantizer_config.bits, + desc_act=quantizer_config.desc_act, groupsize=quantizer_config.groupsize, quant_method=quantizer_config.quant_method, quantize=quantize, sym=quantizer_config.sym, - ): - from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader - - return GPTQMarlinWeightsLoader( - bits=quantizer_config.bits, - desc_act=quantizer_config.desc_act, - groupsize=quantizer_config.groupsize, - quant_method=quantizer_config.quant_method, - quantize=quantize, - sym=quantizer_config.sym, - ) - else: - return GPTQWeightsLoader( - bits=quantizer_config.bits, - desc_act=quantizer_config.desc_act, - groupsize=quantizer_config.groupsize, - quant_method=quantizer_config.quant_method, - quantize=quantize, - sym=quantizer_config.sym, - ) - elif quantize == "bitsandbytes": - from text_generation_server.layers.bnb import BNBWeight - - return DefaultWeightsLoader(BNBWeight) - elif quantize == "bitsandbytes-fp4": - from text_generation_server.layers.bnb import BNBFP4Weight - - return DefaultWeightsLoader(BNBFP4Weight) - elif quantize == "bitsandbytes-nf4": - from text_generation_server.layers.bnb import BNBNF4Weight - - return DefaultWeightsLoader(BNBNF4Weight) - elif quantize == "eetq": - from text_generation_server.layers.eetq import EETQWeight - - return DefaultWeightsLoader(EETQWeight) - elif quantize == "exl2": - from text_generation_server.layers.exl2 import Exl2WeightsLoader - - return Exl2WeightsLoader() - elif quantize == "marlin": - from text_generation_server.layers.marlin import MarlinWeightsLoader - - # TODO: improve check once we have one config type per quantize value - if not isinstance(quantizer_config, _QuantizerConfig): - raise ValueError( - f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config." 
- ) - - return MarlinWeightsLoader( - bits=quantizer_config.bits, - is_marlin_24=quantizer_config.checkpoint_format == "marlin_24", ) elif quantize == "fp8" or quantize is None: from text_generation_server.layers.fp8 import HybridFP8UnquantLoader From 8d221b7b7931b6ba287513285311c37183da57b6 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Sat, 22 Mar 2025 20:58:37 -0700 Subject: [PATCH 14/35] fix gptq issue Signed-off-by: Wang, Yi A --- .../server/text_generation_server/cli.py | 2 ++ .../text_generation_server/layers/gptq/hpu.py | 19 ++++++++++++------- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/cli.py b/backends/gaudi/server/text_generation_server/cli.py index e1c0298d..569e2e5b 100644 --- a/backends/gaudi/server/text_generation_server/cli.py +++ b/backends/gaudi/server/text_generation_server/cli.py @@ -99,6 +99,8 @@ def serve( "bitsandbytes", "bitsandbytes-nf4", "bitsandbytes-fp4", + "gptq", + "awq", }: raise RuntimeError( "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model." diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py b/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py index 25d5c3d2..72944fa0 100644 --- a/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py +++ b/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py @@ -50,6 +50,9 @@ class QuantLinear(nn.Module): self.outfeatures = qweight.shape[1] self.infeatures = qweight.shape[0] * 32 // bits + self.wf = torch.tensor( + list(range(0, 32, self.bits)), dtype=torch.int32 + ).unsqueeze(0) self._preprocessing() def unpack_zeros_from_cuda_old_format(self): @@ -75,22 +78,24 @@ class QuantLinear(nn.Module): return weight def _preprocessing(self): + orig_device = self.qweight.device self.qweight = self.qweight.cpu() weight = self.unpack_weight_from_cuda_old_format() new_qweight = pack_tensor(weight) - self.qweight = new_qweight.to("hpu") - + self.qweight = new_qweight.to(orig_device) # TODO: Support group indexing and remove the check columns = self.qweight.shape[0] - g_idx_trivial = [i // self.group_size for i in range(columns)] - g_idx_trivial = torch.tensor(g_idx_trivial, dtype=torch.int32) + g_idx_trivial = [i // self.groupsize for i in range(columns)] + g_idx_trivial = torch.tensor( + g_idx_trivial, dtype=torch.int32, device=self.g_idx.device + ) assert torch.equal( self.g_idx, g_idx_trivial ), "Non-trivial tensor g_idx is not supported" - - zeros = self.unpack_zeros_from_cuda_old_format().cpu() + self.qzeros = self.qzeros.cpu() + zeros = self.unpack_zeros_from_cuda_old_format() new_qzeros = pack_tensor(zeros) - self.qzeros = new_qzeros.to("hpu") + self.qzeros = new_qzeros.to(orig_device) @classmethod def new(cls, bits, groupsize, infeatures, outfeatures, bias): From 69773767c50a19f6b288fe0ee63ca8f782dd1dd3 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Mon, 24 Mar 2025 20:21:45 -0700 Subject: [PATCH 15/35] enable fp8 Signed-off-by: Wang, Yi A --- .../server/text_generation_server/cli.py | 1 + .../text_generation_server/layers/fp8.py | 103 ++++++++---------- 2 files changed, 45 insertions(+), 59 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/cli.py b/backends/gaudi/server/text_generation_server/cli.py index 569e2e5b..53837ef7 100644 --- a/backends/gaudi/server/text_generation_server/cli.py +++ b/backends/gaudi/server/text_generation_server/cli.py @@ -101,6 +101,7 @@ def serve( "bitsandbytes-fp4", "gptq", "awq", + "fp8", }: raise 
RuntimeError( "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model." diff --git a/backends/gaudi/server/text_generation_server/layers/fp8.py b/backends/gaudi/server/text_generation_server/layers/fp8.py index e37c4983..6c8d637e 100644 --- a/backends/gaudi/server/text_generation_server/layers/fp8.py +++ b/backends/gaudi/server/text_generation_server/layers/fp8.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from typing import Optional, Tuple, Type, Union, List import torch -from loguru import logger from text_generation_server.utils.weights import ( Weight, @@ -10,18 +9,16 @@ from text_generation_server.utils.weights import ( UnquantizedWeight, Weights, ) -from text_generation_server.utils.log import log_once -quantization = None +from vllm_hpu_extension.ops import scaled_fp8_quant +from vllm_hpu_extension.ops import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2 +import habana_frameworks.torch.utils.experimental as htexp + w8a8_block_fp8_matmul = None per_token_group_quant_fp8 = None - quant_dtype: torch.dtype = torch.float8_e4m3fn -CUTLASS_FP8_AVAILABLE = False - - def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]: """ Return an FP8 linear `Module` that is compatible with the current system. @@ -43,7 +40,13 @@ def per_tensor_dequantize( inv_scale: Union[float, torch.Tensor], dtype: torch.dtype = torch.float16, ) -> torch.Tensor: - fake_qweight = tensor.to(dtype) + device = tensor.device + dtype = torch.bfloat16 + if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2: + # dequant on cpu to avoid nan on gaudi2 + tensor = tensor.to("cpu") + + fake_qweight = tensor.to(dtype).to(device) dq_weight = fake_qweight * inv_scale return dq_weight @@ -55,7 +58,10 @@ def requantize_with_max_scale( dtype: torch.dtype, ) -> Tuple[torch.Tensor, torch.Tensor]: # Max scale to be used for requanitzation. - max_w_scale = weight_scale.max().float() + max_w_scale = weight_scale.max() + + if is_hpu_gaudi2(): + max_w_scale = max_w_scale * get_hpu_gaudi2_scale_factor() start = 0 for idx, logical_width in enumerate(logical_widths): @@ -84,37 +90,16 @@ def fp8_quantize( argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can be used without modification). """ - if quantization is not None: - shape = weight.shape - qweight, scale = quantization.scaled_fp8_quant( - weight.reshape(-1, shape[-1]), - scale=scale, - scale_ub=scale_upper_bound, - # TODO: don't do this when we have to use the Torch kernel. - use_per_token_if_dynamic=not scalar, - ) + shape = weight.shape + qweight, scale = scaled_fp8_quant( + weight.reshape(-1, shape[-1]), + scale=scale, + scale_ub=scale_upper_bound, + # TODO: don't do this when we have to use the Torch kernel. + use_per_token_if_dynamic=not scalar, + ) - return qweight.reshape(shape), scale - - finfo = torch.finfo(qdtype) - - if scale is None: - # Calculate the scale as dtype max divided by absmax - scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound) - # scale and clamp the tensor to bring it to - # the representative range of float8 data type - # (as default cast is unsaturated) - qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max) - scale = scale.float().reciprocal() - else: - # Use reciprocal to avoid more expensive division. 
- qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max) - - # Return both float8 data and the inverse scale (as float), - # as both required as inputs to torch._scaled_mm - qweight = qweight.to(qdtype) - - return qweight, scale + return qweight.reshape(shape), scale class HybridFP8UnquantLoader(WeightsLoader): @@ -153,6 +138,10 @@ class HybridFP8UnquantLoader(WeightsLoader): .reshape(-1) .max() ) + logical_widths = [w.shape[0]] + w, scale = requantize_with_max_scale( + w, scale.unsqueeze(0), logical_widths, weights.dtype + ) return Fp8Weight( weight=w, @@ -201,6 +190,10 @@ class HybridFP8UnquantLoader(WeightsLoader): to_dtype=False, ) input_scale = input_scale.reshape(-1).max() + logical_widths = [w.shape[0]] + w, scale = requantize_with_max_scale( + w, scale.unsqueeze(0), logical_widths, weights.dtype + ) return Fp8Weight( weight=w, @@ -259,6 +252,11 @@ class HybridFP8UnquantLoader(WeightsLoader): else None ) + logical_widths = [x[0] for x in shapes] + w, scale = requantize_with_max_scale( + w, scale.to(weights.device), logical_widths, weights.dtype + ) + return Fp8Weight( weight=w, weight_scale=scale, @@ -296,7 +294,10 @@ class HybridFP8UnquantLoader(WeightsLoader): .reshape(-1) .max() ) - + logical_widths = [w.shape[0]] + w, scale = requantize_with_max_scale( + w, scale.unsqueeze(0), logical_widths, weights.dtype + ) return Fp8Weight( weight=w, weight_scale=scale, @@ -353,27 +354,19 @@ class Fp8Linear(torch.nn.Module): weight_block_size: Optional[List[int]] = None, ) -> None: super().__init__() - if CUTLASS_FP8_AVAILABLE: - log_once(logger.info, "Using cutlass w8a8 kernels") self.dtype = dtype self.qweight = qweight self.scale = scale.float() self.input_scale = input_scale.float() if input_scale is not None else None self.weight_block_size = weight_block_size - - if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None: - self.scale_upper_bound = torch.tensor( - scale_upper_bound, dtype=torch.float32, device=qweight.device - ) - else: - self.scale_upper_bound = scale_upper_bound + self.scale_upper_bound = scale_upper_bound self.bias = bias if bias is not None else None @classmethod def from_unquant(cls, weight, bias, dtype): - qweight, scale = fp8_quantize(weight, scalar=not CUTLASS_FP8_AVAILABLE) + qweight, scale = fp8_quantize(weight, scalar=True) return cls( qweight=qweight, scale=scale, @@ -434,14 +427,6 @@ class Fp8Linear(torch.nn.Module): if self.bias is not None: output = output + self.bias return output.to(dtype=input.dtype) - if CUTLASS_FP8_AVAILABLE: - # cutlass FP8 supports per-token scales, so get non-scalar scales. 
- qinput, scale = fp8_quantize( - input, scale_upper_bound=self.scale_upper_bound, scalar=False - ) - return quantization.cutlass_scaled_mm( - qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias - ) qinput, scale = fp8_quantize( input, @@ -470,4 +455,4 @@ def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Siz if scale.numel() > 1: scale = weights.get_sharded(prefix, dim=0, to_dtype=False) - return scale.reshape(-1).expand(shape[0]) + return scale.reshape(-1) From fd70ad703e960bd4589c4d04c8ca725d15179d0b Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Tue, 25 Mar 2025 22:21:44 -0700 Subject: [PATCH 16/35] warmup prefill remove model where pageattn is not used, set block table to None since it's not used Signed-off-by: Wang, Yi A --- .../text_generation_server/models/__init__.py | 17 -- .../models/flash_causal_lm.py | 209 ++++++++++-------- .../models/flash_vlm_causal_lm.py | 2 +- .../models/mllama_causal_lm.py | 2 +- .../server/text_generation_server/server.py | 2 - 5 files changed, 117 insertions(+), 115 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py index 7dac910e..778b14a1 100644 --- a/backends/gaudi/server/text_generation_server/models/__init__.py +++ b/backends/gaudi/server/text_generation_server/models/__init__.py @@ -92,7 +92,6 @@ try: from text_generation_server.models.custom_modeling.flash_phi_modeling import ( FlashPhiForCausalLM, ) - from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch from text_generation_server.models.custom_modeling.flash_mllama import ( FlashMllamaForConditionalGeneration, @@ -144,7 +143,6 @@ except ImportError as e: if FLASH_ATTENTION: __all__.append(FlashCausalLM) - __all__.append(IdeficsCausalLM) class ModelType(enum.Enum): @@ -301,12 +299,6 @@ class ModelType(enum.Enum): "name": "Gptj", "url": "https://huggingface.co/EleutherAI/gpt-j-6b", } - IDEFICS = { - "type": "idefics", - "name": "Idefics", - "url": "https://huggingface.co/HuggingFaceM4/idefics-9b", - "multimodal": True, - } MLLAMA = { "type": "mllama", "name": "Mllama", @@ -733,15 +725,6 @@ def get_model( trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, ) - elif model_type == IDEFICS: - return IdeficsCausalLM( - model_id, - revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) elif model_type == QWEN2_VL: return FlashVlmCausalLM( model_id=model_id, diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index 4cdf2628..b26184e4 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -69,6 +69,8 @@ from text_generation_server.utils.import_utils import ( import vllm_hpu_extension.environment as environment import habana_frameworks.torch as htorch +import itertools +from vllm_hpu_extension.ops import batch2block, block2batch tracer = trace.get_tracer(__name__) @@ -86,6 +88,78 @@ def get_sliding_windows() -> int: return SLIDING_WINDOW +def prepare_for_decode( + dtype, use_contiguous_pa, device, slot, block_tables, batch_size +): + # Prepare values if we need to continue decoding + # need for HPUPagedAttentionMetadata preparation + def flatten(in_list): + return 
list(itertools.chain(*in_list)) + + def gather_list(input, indices, v): + return [input[i] if i is not None else v for i in indices] + + def pad_list(input, k, v): + input_len = len(input) + target_len = (input_len + k - 1) // k * k + padding = target_len - input_len + return input + [v] * padding + + last_block_usage = slot % BLOCK_SIZE + 1 + block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)] + block_usage = [ + [BLOCK_SIZE] * (len(bt) - 1) + [lbu] + for bt, lbu in zip(block_tables, last_block_usage) + if bt + ] + + block_list = flatten(block_tables) + block_groups = flatten(block_groups) + block_usage = flatten(block_usage) + + assert len(block_list) == len(block_groups) + assert len(block_list) == len(block_usage) + if use_contiguous_pa: + block_bucket_size = max(max(block_list) + 1, len(block_list)) + # block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks( + # block_bucket_size) + indices: List[Any] + indices = [None] * block_bucket_size + for i, bid in enumerate(block_list): + indices[bid] = i + block_list = gather_list(block_list, indices, 0) + block_groups = gather_list(block_groups, indices, -1) + block_usage = gather_list(block_usage, indices, 1) + else: + block_bucket_size = len(block_list) + block_list = pad_list(block_list, block_bucket_size, 0) + block_groups = pad_list(block_groups, block_bucket_size, -1) + block_usage = pad_list(block_usage, block_bucket_size, 1) + + block_list = torch.tensor(block_list, dtype=torch.int, device=device) + block_groups = torch.tensor(block_groups, dtype=torch.int, device=device) + block_usage = torch.tensor(block_usage, dtype=dtype, device=device) + block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch_size) + mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0) + mask = mask >= block_usage.unsqueeze(-1) + attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf) + ones = torch.ones( + (block_mapping.size(0),), device=device, dtype=block_mapping.dtype + ) + sums = batch2block(block2batch(ones, block_mapping), block_mapping) + block_scales = torch.reciprocal(torch.maximum(ones, sums)) + return trim_attn_metadata( + HPUPagedAttentionMetadata( + block_list=block_list, + block_groups=block_groups, + block_usage=block_usage, + block_mapping=block_mapping.to(dtype), + attn_bias=attn_bias, + block_scales=block_scales, + ) + ) + + @dataclass class FlashCausalLMBatch(Batch): batch_id: int @@ -879,83 +953,18 @@ class FlashCausalLMBatch(Batch): ) def prepare_for_decode(self, dtype, use_contiguous_pa): - # Prepare values if we need to continue decoding - # need for HPUPagedAttentionMetadata preparation - import itertools - from vllm_hpu_extension.ops import batch2block, block2batch - - def flatten(in_list): - return list(itertools.chain(*in_list)) - - def gather_list(input, indices, v): - return [input[i] if i is not None else v for i in indices] - - def pad_list(input, k, v): - input_len = len(input) - target_len = (input_len + k - 1) // k * k - padding = target_len - input_len - return input + [v] * padding - - device = self.block_tables_tensor.device - last_block_usage = self.slots[self.slot_indices] % BLOCK_SIZE + 1 block_num = self.cache_lengths_tensor // BLOCK_SIZE + 1 block_tables = [] for i, bt in enumerate(self.block_tables): block_tables.append(bt[0 : block_num[i]]) - block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)] - block_usage = [ - [BLOCK_SIZE] * (len(bt) - 1) + [lbu] - for bt, lbu in zip(block_tables, last_block_usage) - 
if bt - ] - block_list = flatten(block_tables) - block_groups = flatten(block_groups) - block_usage = flatten(block_usage) - batch = self.input_ids.size(0) - - assert len(block_list) == len(block_groups) - assert len(block_list) == len(block_usage) - if use_contiguous_pa: - block_bucket_size = max(max(block_list) + 1, len(block_list)) - # block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks( - # block_bucket_size) - indices: List[Any] - indices = [None] * block_bucket_size - for i, bid in enumerate(block_list): - indices[bid] = i - block_list = gather_list(block_list, indices, 0) - block_groups = gather_list(block_groups, indices, -1) - block_usage = gather_list(block_usage, indices, 1) - else: - block_bucket_size = len(block_list) - block_list = pad_list(block_list, block_bucket_size, 0) - block_groups = pad_list(block_groups, block_bucket_size, -1) - block_usage = pad_list(block_usage, block_bucket_size, 1) - - block_list = torch.tensor(block_list, dtype=torch.int, device=device) - block_groups = torch.tensor(block_groups, dtype=torch.int, device=device) - block_usage = torch.tensor(block_usage, dtype=dtype, device=device) - block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch) - mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze( - 0 - ) - mask = mask >= block_usage.unsqueeze(-1) - attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf) - ones = torch.ones( - (block_mapping.size(0),), device=device, dtype=block_mapping.dtype - ) - sums = batch2block(block2batch(ones, block_mapping), block_mapping) - block_scales = torch.reciprocal(torch.maximum(ones, sums)) - self.hpu_attn_meta = trim_attn_metadata( - HPUPagedAttentionMetadata( - block_list=block_list, - block_groups=block_groups, - block_usage=block_usage, - block_mapping=block_mapping.to(dtype), - attn_bias=attn_bias, - block_scales=block_scales, - ) + self.hpu_attn_meta = prepare_for_decode( + dtype, + use_contiguous_pa, + self.block_tables_tensor.device, + self.slots[self.slot_indices], + block_tables, + self.input_ids.size(0), ) def prepare_for_prefill(self): @@ -1481,32 +1490,44 @@ class FlashCausalLM(Model): self.kv_cache_dtype, self.device, ) + for bs in [1, 2, 4, 8]: + for seqlen in [32, 64, 128, 256, 512, 1024]: + self.warmup_prefill(seqlen, bs) return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens - def tunableop_warmup(self, seqlen: int, max_bt: int): - input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device) - position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device) - slots = torch.arange(seqlen, dtype=torch.int64, device=self.device) - - # Dummy value, some models (starcoder2) don't accept `None`. 
- input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) - cache_lengths_tensor = torch.zeros( - seqlen, dtype=torch.int32, device=self.device - ) - cu_seqlen_prefill = torch.tensor( - [0, seqlen], device=self.device, dtype=torch.int32 - ) - + def warmup_prefill(self, prompt_len: int, bs: int): + input_ids = torch.zeros( + prompt_len, dtype=torch.int64, device=self.device + ).repeat(bs) + position_ids = torch.arange( + prompt_len, dtype=torch.int32, device=self.device + ).repeat(bs) + max_bt = (prompt_len // BLOCK_SIZE + 1) * bs block_tables = torch.arange( max_bt, dtype=torch.int32, device=self.device - ).repeat(seqlen) - block_tables = block_tables.reshape((seqlen, max_bt)) + ).reshape(bs, -1) + slot_acc = [] + for i in range(bs): + slots = [] + for b in block_tables[i]: + slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)) + slot_acc.extend(slots[:prompt_len]) + slots = torch.tensor(slot_acc, dtype=torch.int64, device=self.device) + + input_lengths = ( + torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len + ) + cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device) + cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32) + torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:]) seqlen = Seqlen( input_lengths=input_lengths, cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, ) + lm_head_indices = input_lengths - 1 # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. self.model.forward( @@ -1514,11 +1535,13 @@ class FlashCausalLM(Model): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=self.kv_cache, - block_tables=block_tables, - seqlen=seqlen, + block_tables=None, # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph slots=slots, - lm_head_indices=None, + seqlen=trim_seqlen_metadata(seqlen), prefill_cache_indices=None, + lm_head_indices=lm_head_indices, + adapter_data=None, + hpu_attention_meta=None, ) def forward( @@ -1606,7 +1629,7 @@ class FlashCausalLM(Model): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, - block_tables=block_tables, + block_tables=None, slots=slots, seqlen=trim_seqlen_metadata(seqlen), prefill_cache_indices=batch.prefill_cache_indices, @@ -1637,9 +1660,7 @@ class FlashCausalLM(Model): batch.prepare_for_prefill() else: batch.prepare_for_decode(self.dtype, self.use_contiguous_pa) - prefill_logprobs = batch.prefill_next_token_indices is not None - # Update adapter indices for speculative tokens (if present) adapter_meta = batch.adapter_meta if batch.speculative_ids is not None: diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py index 7cff7797..48bfce89 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -462,7 +462,7 @@ class FlashVlmCausalLM(FlashCausalLM): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, - block_tables=block_tables, + block_tables=None, # block_table is not used in hpu pageattn. 
remove it to avoid shape change in hpu graph slots=slots, seqlen=trim_seqlen_metadata(seqlen), hpu_attention_meta=batch.hpu_attn_meta, diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index be67b6ae..4471aab3 100644 --- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -288,7 +288,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, - block_tables=block_tables, + block_tables=None, # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph slots=slots, seqlen=trim_seqlen_metadata(seqlen), hpu_attention_meta=batch.hpu_attn_meta, diff --git a/backends/gaudi/server/text_generation_server/server.py b/backends/gaudi/server/text_generation_server/server.py index 6e470361..5a7d2117 100644 --- a/backends/gaudi/server/text_generation_server/server.py +++ b/backends/gaudi/server/text_generation_server/server.py @@ -33,13 +33,11 @@ try: from text_generation_server.models.flash_vlm_causal_lm import ( FlashVlmCausalLMBatch, ) - from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch VLM_BATCH_TYPES = { PaliGemmaBatch, VlmCausalLMBatch, FlashVlmCausalLMBatch, - IdeficsCausalLMBatch, FlashMllamaCausalLMBatch, } except (ImportError, NotImplementedError): From ba7a131e041519326767953dae5d956022423593 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Wed, 26 Mar 2025 17:39:26 -0700 Subject: [PATCH 17/35] add warmup_decode Signed-off-by: Wang, Yi A --- .../models/flash_causal_lm.py | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index b26184e4..cb879c9c 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -1494,6 +1494,10 @@ class FlashCausalLM(Model): for seqlen in [32, 64, 128, 256, 512, 1024]: self.warmup_prefill(seqlen, bs) + for bs in [1, 2, 4, 8]: + for block_num in [1, 2, 4, 8, 16]: + self.warmup_decode(bs, block_num * bs) + return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens def warmup_prefill(self, prompt_len: int, bs: int): @@ -1544,6 +1548,57 @@ class FlashCausalLM(Model): hpu_attention_meta=None, ) + def warmup_decode(self, bs: int, block_num: int): + input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) + position_ids = torch.arange(bs, dtype=torch.int32, device=self.device) + block_tables = torch.arange( + block_num, dtype=torch.int32, device=self.device + ).reshape(bs, -1) + slots = [] + past_len = ( + len(block_tables[0]) * BLOCK_SIZE - 3 + ) # for decode, we only need to pass the past token + # fetch the last blocked to warmup block num + for i in range(bs): + slots.extend(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 2) + slots = torch.tensor(slots, dtype=torch.int64, device=self.device) + + input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) + cache_lengths_tensor = ( + torch.ones(bs, dtype=torch.int32, device=self.device) * past_len + ) + cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32) + torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:]) + + seqlen = Seqlen( + 
input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) + hpu_attention_meta = prepare_for_decode( + self.dtype, + self.use_contiguous_pa, + self.device, + slots, + block_tables, + bs, + ) + + # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. + self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=self.kv_cache, + block_tables=None, # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph + slots=slots, + seqlen=trim_seqlen_metadata(seqlen), + prefill_cache_indices=None, + lm_head_indices=None, + adapter_data=None, + hpu_attention_meta=hpu_attention_meta, + ) + def forward( self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: From 7900be5ac3af19a66a69046ab2b13c8412449553 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Wed, 26 Mar 2025 20:19:13 -0700 Subject: [PATCH 18/35] warmup decode Signed-off-by: Wang, Yi A --- .../models/flash_causal_lm.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index cb879c9c..e032242c 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -1552,15 +1552,15 @@ class FlashCausalLM(Model): input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) position_ids = torch.arange(bs, dtype=torch.int32, device=self.device) block_tables = torch.arange( - block_num, dtype=torch.int32, device=self.device + start=1, end=block_num + 1, dtype=torch.int32, device=self.device ).reshape(bs, -1) slots = [] past_len = ( - len(block_tables[0]) * BLOCK_SIZE - 3 + len(block_tables[0]) * BLOCK_SIZE - 1 ) # for decode, we only need to pass the past token # fetch the last blocked to warmup block num for i in range(bs): - slots.extend(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 2) + slots.append(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 1) slots = torch.tensor(slots, dtype=torch.int64, device=self.device) input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) @@ -1575,12 +1575,17 @@ class FlashCausalLM(Model): cache_lengths=cache_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, ) + + block_num = cache_lengths_tensor // BLOCK_SIZE + 1 + block_tables_valid = [] + for i, bt in enumerate(block_tables.tolist()): + block_tables_valid.append(bt[0 : block_num[i]]) hpu_attention_meta = prepare_for_decode( self.dtype, self.use_contiguous_pa, self.device, slots, - block_tables, + block_tables_valid, bs, ) From 1508ee8de125d97a305807553537a9b5487e70d5 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Thu, 27 Mar 2025 22:51:21 -0700 Subject: [PATCH 19/35] remove block_tables and prefill_cache_indices which will lead to dynamic shape Signed-off-by: Wang, Yi A --- .../layers/attention/hpu.py | 2 - .../custom_modeling/flash_cohere_modeling.py | 27 +--- .../custom_modeling/flash_dbrx_modeling.py | 28 +--- .../flash_deepseek_v2_modeling.py | 27 +--- .../flash_deepseek_v3_modeling.py | 26 +--- .../custom_modeling/flash_gemma2_modeling.py | 24 +--- .../custom_modeling/flash_gemma_modeling.py | 25 +--- .../custom_modeling/flash_gpt2_modeling.py | 27 +--- .../custom_modeling/flash_llama_modeling.py | 30 +---- 
.../custom_modeling/flash_llava_next.py | 4 - .../custom_modeling/flash_mistral_modeling.py | 31 +---- .../custom_modeling/flash_mixtral_modeling.py | 30 +---- .../models/custom_modeling/flash_mllama.py | 6 - .../custom_modeling/flash_neox_modeling.py | 26 +--- .../flash_pali_gemma_modeling.py | 4 - .../custom_modeling/flash_phi_modeling.py | 24 +--- .../custom_modeling/flash_qwen2_modeling.py | 31 +---- .../custom_modeling/flash_rw_modeling.py | 44 +----- .../flash_santacoder_modeling.py | 25 +--- .../flash_starcoder2_modeling.py | 30 +---- .../models/custom_modeling/idefics2.py | 4 - .../models/custom_modeling/idefics3.py | 4 - .../models/custom_modeling/qwen2_5_vl.py | 4 - .../models/custom_modeling/qwen2_vl.py | 4 - .../models/flash_causal_lm.py | 127 ++++++++---------- .../models/flash_vlm_causal_lm.py | 2 - .../models/mllama_causal_lm.py | 2 - 27 files changed, 88 insertions(+), 530 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py index 56143541..526dbcec 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py @@ -26,7 +26,6 @@ def attention( kv_cache: KVCache, kv_scales: KVScales, seqlen: Seqlen, - block_tables: torch.Tensor, softmax_scale: float, window_size_left: int = -1, causal: bool = True, @@ -61,7 +60,6 @@ def paged_attention( kv_cache: KVCache, kv_head_mapping: torch.Tensor, softmax_scale: float, - block_tables: torch.Tensor, seqlen: Seqlen, *, kv_scales: KVScales, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 77dec80d..3bcc689d 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -219,10 +219,8 @@ class FlashCohereAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) @@ -247,16 +245,9 @@ class FlashCohereAttention(torch.nn.Module): self.rotary_emb(query, key, cos, sin) - if prefill_cache_indices is not None: - key_to_cache = key[prefill_cache_indices] - value_to_cache = value[prefill_cache_indices] - else: - key_to_cache = key - value_to_cache = value - kv_cache.store( - key=key_to_cache, - value=value_to_cache, + key=key, + value=value, slots=slots, kv_scales=self.kv_scales, ) @@ -271,7 +262,6 @@ class FlashCohereAttention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, ) # Decode @@ -281,7 +271,6 @@ class FlashCohereAttention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -356,10 +345,8 @@ class FlashCohereLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -371,10 +358,8 @@ class FlashCohereLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -424,10 +409,8 @@ class 
FlashCohereModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: torch.Tensor, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -446,10 +429,8 @@ class FlashCohereModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -488,10 +469,8 @@ class FlashCohereForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, @@ -501,10 +480,8 @@ class FlashCohereForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index b335a81f..15c243c9 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -308,10 +308,8 @@ class DbrxAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) @@ -329,14 +327,10 @@ class DbrxAttention(torch.nn.Module): kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv kv_cache.store( - key=kv_to_cache[:, 0], - value=kv_to_cache[:, 1], + key=kv[:, 0], + value=kv[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -351,7 +345,6 @@ class DbrxAttention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, ) # Decode @@ -361,7 +354,6 @@ class DbrxAttention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -398,10 +390,8 @@ class DbrxNormAttentionNorm(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): normed_hidden_states, res = self.norm_1(hidden_states, residual) @@ -413,10 +403,8 @@ class DbrxNormAttentionNorm(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -630,10 +618,8 @@ class DbrxLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): # Self Attention @@ -644,10 +630,8 @@ class DbrxLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -689,10 +673,8 @@ class DbrxModel(torch.nn.Module): position_ids: 
torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -710,10 +692,8 @@ class DbrxModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -744,10 +724,8 @@ class FlashDbrxForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, @@ -757,10 +735,8 @@ class FlashDbrxForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 3298a30a..9d61c694 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -256,10 +256,8 @@ class DeepseekV2Attention(torch.nn.Module): sin: torch.Tensor, cu_seqlen_prefill: torch.Tensor, kv_cache: KVCache, - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): if self.q_lora_rank is None: @@ -316,15 +314,10 @@ class DeepseekV2Attention(torch.nn.Module): value = torch.nn.functional.pad( value, (0, self.head_pad_size - self.value_head_size), value=0 ) - if prefill_cache_indices is not None: - key_to_cache = key[prefill_cache_indices] - value_to_cache = value[prefill_cache_indices] - else: - key_to_cache = key - value_to_cache = value + kv_cache.store( - key=key_to_cache, - value=value_to_cache, + key=key, + value=value, slots=slots, kv_scales=self.kv_scales, ) @@ -339,7 +332,6 @@ class DeepseekV2Attention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, ) # Decode @@ -349,7 +341,6 @@ class DeepseekV2Attention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -512,10 +503,8 @@ class DeepseekV2Layer(nn.Module): sin: torch.Tensor, cu_seqlen_prefill: torch.Tensor, kv_cache, - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): normed_hidden_states, residual = self.input_layernorm(hidden_states, residual) @@ -527,10 +516,8 @@ class DeepseekV2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -577,10 +564,8 @@ class DeepseekV2Model(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: 
List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -598,10 +583,8 @@ class DeepseekV2Model(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -629,10 +612,8 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, @@ -642,10 +623,8 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py index 736e0c9a..1a7ce5cf 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py @@ -256,10 +256,8 @@ class DeepseekV3Attention(torch.nn.Module): sin: torch.Tensor, cu_seqlen_prefill: torch.Tensor, kv_cache: KVCache, - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): if self.q_lora_rank is None: @@ -317,15 +315,9 @@ class DeepseekV3Attention(torch.nn.Module): value, (0, self.head_pad_size - self.value_head_size), value=0 ) - if prefill_cache_indices is not None: - key_to_cache = key[prefill_cache_indices] - value_to_cache = value[prefill_cache_indices] - else: - key_to_cache = key - value_to_cache = value kv_cache.store( - key=key_to_cache, - value=value_to_cache, + key=key, + value=value, slots=slots, kv_scales=self.kv_scales, ) @@ -340,7 +332,6 @@ class DeepseekV3Attention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, ) # Decode @@ -350,7 +341,6 @@ class DeepseekV3Attention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -522,10 +512,8 @@ class DeepseekV3Layer(nn.Module): sin: torch.Tensor, cu_seqlen_prefill: torch.Tensor, kv_cache, - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): normed_hidden_states, residual = self.input_layernorm(hidden_states, residual) @@ -537,10 +525,8 @@ class DeepseekV3Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -587,10 +573,8 @@ class DeepseekV3Model(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, 
seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -608,10 +592,8 @@ class DeepseekV3Model(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -639,10 +621,8 @@ class FlashDeepseekV3ForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, @@ -652,10 +632,8 @@ class FlashDeepseekV3ForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 632e8017..79f21b0f 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -235,11 +235,9 @@ class FlashGemma2Attention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, adapter_data, - prefill_cache_indices, hpu_attention_meta, ): qkv = self.query_key_value(hidden_states, adapter_data) @@ -254,14 +252,10 @@ class FlashGemma2Attention(torch.nn.Module): kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv kv_cache.store( - key=kv_to_cache[:, 0], - value=kv_to_cache[:, 1], + key=kv[:, 0], + value=kv[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -276,7 +270,6 @@ class FlashGemma2Attention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, window_size_left=self.window_size, softcap=self.softcap, @@ -288,7 +281,6 @@ class FlashGemma2Attention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, softcap=self.softcap, kv_scales=self.kv_scales, @@ -402,11 +394,9 @@ class FlashGemma2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, adapter_data, - prefill_cache_indices, hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -418,11 +408,9 @@ class FlashGemma2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, adapter_data, - prefill_cache_indices, hpu_attention_meta, ) @@ -472,11 +460,9 @@ class FlashGemma2Model(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, adapter_data: Optional[torch.Tensor], - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = inputs_embeds @@ -494,11 +480,9 @@ class FlashGemma2Model(torch.nn.Module): 
sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, adapter_data, - prefill_cache_indices, hpu_attention_meta, ) @@ -543,10 +527,8 @@ class FlashGemma2ForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, @@ -557,11 +539,9 @@ class FlashGemma2ForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, adapter_data, - prefill_cache_indices, hpu_attention_meta, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index c3e5727b..609f03ac 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -207,10 +207,8 @@ class FlashGemmaAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) @@ -226,14 +224,9 @@ class FlashGemmaAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv - kv_cache.store( - key=kv_to_cache[:, 0], - value=kv_to_cache[:, 1], + key=kv[:, 0], + value=kv[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -248,7 +241,6 @@ class FlashGemmaAttention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, causal=self.causal, ) @@ -259,7 +251,6 @@ class FlashGemmaAttention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -331,10 +322,8 @@ class FlashGemmaLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -346,10 +335,8 @@ class FlashGemmaLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -395,11 +382,9 @@ class FlashGemmaModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, adapter_data: Optional[torch.Tensor], - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = inputs_embeds @@ -417,10 +402,8 @@ class FlashGemmaModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -463,10 +446,8 @@ class FlashGemmaForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - 
prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, @@ -477,11 +458,9 @@ class FlashGemmaForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, adapter_data, - prefill_cache_indices, hpu_attention_meta, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index a7a85d3a..10024a6d 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -209,10 +209,8 @@ class FlashGPT2Attention(torch.nn.Module): hidden_states, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): query, key, value = self.query_key_value(hidden_states).split( @@ -222,16 +220,9 @@ class FlashGPT2Attention(torch.nn.Module): key = key.view(-1, self.num_heads, self.head_size) value = value.view(-1, self.num_heads, self.head_size) - if prefill_cache_indices is not None: - key_to_cache = key[prefill_cache_indices] - value_to_cache = value[prefill_cache_indices] - else: - key_to_cache = key - value_to_cache = value - kv_cache.store( - key=key_to_cache, - value=value_to_cache, + key=key, + value=value, slots=slots, kv_scales=self.kv_scales, ) @@ -246,7 +237,6 @@ class FlashGPT2Attention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, ) # Decode @@ -256,7 +246,6 @@ class FlashGPT2Attention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -325,10 +314,8 @@ class FlashGPT2Layer(nn.Module): residual, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): residual = hidden_states @@ -339,10 +326,8 @@ class FlashGPT2Layer(nn.Module): hidden_states, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -393,10 +378,8 @@ class FlashGPT2Model(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = inputs_embeds @@ -408,10 +391,8 @@ class FlashGPT2Model(torch.nn.Module): residual, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -446,11 +427,9 @@ class FlashGPT2ForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], - prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -462,10 +441,8 @@ class FlashGPT2ForCausalLM(torch.nn.Module): position_ids, 
cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices=prefill_cache_indices, hpu_attention_meta=hpu_attention_meta, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 7deb6cbf..81af5560 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -201,11 +201,9 @@ class FlashLlamaAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache: KVCache, - block_tables, slots, seqlen, adapter_data, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): qkv = self.query_key_value(hidden_states, adapter_data) @@ -221,14 +219,9 @@ class FlashLlamaAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv - kv_cache.store( - key=kv_to_cache[:, 0], - value=kv_to_cache[:, 1], + key=kv[:, 0], + value=kv[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -243,7 +236,6 @@ class FlashLlamaAttention(torch.nn.Module): kv_scales=self.kv_scales, kv_cache=kv_cache, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, ) # Decode @@ -253,7 +245,6 @@ class FlashLlamaAttention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -441,12 +432,10 @@ class FlashLlamaLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, adapter_data, cross_attention_states, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -458,11 +447,9 @@ class FlashLlamaLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, adapter_data, - prefill_cache_indices, hpu_attention_meta=hpu_attention_meta, ) if self.residual_multiplier is not None: @@ -554,10 +541,8 @@ class FlashLlamaModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], adapter_data, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], cross_attention_states=None, @@ -577,12 +562,10 @@ class FlashLlamaModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, adapter_data, cross_attention_states, - prefill_cache_indices, hpu_attention_meta=hpu_attention_meta, ) @@ -643,30 +626,21 @@ class FlashLlamaForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], - prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, cross_attention_states=None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - if prefill_cache_indices is not None and slots.size( - 0 - ) != 
prefill_cache_indices.size(0): - # Slots also need to be sliced as it has the same size as the whole kv tensor - slots = slots[prefill_cache_indices] inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( inputs_embeds, position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices=prefill_cache_indices, adapter_data=adapter_data, cross_attention_states=cross_attention_states, hpu_attention_meta=hpu_attention_meta, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py index 3bdfdd83..62e8470c 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py @@ -169,11 +169,9 @@ class FlashLlavaNextForConditionalGeneration(nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], - prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, pixel_values: torch.FloatTensor = None, # Unused for this model @@ -276,11 +274,9 @@ class FlashLlavaNextForConditionalGeneration(nn.Module): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, - block_tables=block_tables, slots=slots, seqlen=seqlen, hpu_attention_meta=hpu_attention_meta, - prefill_cache_indices=None, adapter_data=adapter_data, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 38eba082..d23d4f67 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -178,10 +178,8 @@ class MistralAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, adapter_data, hpu_attention_meta, ): @@ -198,14 +196,9 @@ class MistralAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv - kv_cache.store( - key=kv_to_cache[:, 0], - value=kv_to_cache[:, 1], + key=kv[:, 0], + value=kv[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -220,7 +213,6 @@ class MistralAttention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, window_size_left=self.max_past, ) @@ -231,7 +223,6 @@ class MistralAttention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -335,10 +326,8 @@ class MistralLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, adapter_data, hpu_attention_meta, ): @@ -351,10 +340,8 @@ class MistralLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, adapter_data, hpu_attention_meta, ) @@ -403,10 +390,8 @@ class 
MistralModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], adapter_data: Optional[torch.Tensor] = None, ): @@ -424,10 +409,8 @@ class MistralModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, adapter_data, hpu_attention_meta, ) @@ -475,30 +458,20 @@ class FlashMistralForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if prefill_cache_indices is not None and slots.size( - 0 - ) != prefill_cache_indices.size(0): - # Slots also need to be sliced as it has the same size as the whole kv tensor - slots = slots[prefill_cache_indices] - inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( inputs_embeds, position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, adapter_data, ) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index fbcb0970..1ef6be48 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -235,10 +235,8 @@ class MixtralAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) @@ -254,14 +252,9 @@ class MixtralAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv - kv_cache.store( - key=kv_to_cache[:, 0], - value=kv_to_cache[:, 1], + key=kv[:, 0], + value=kv[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -276,7 +269,6 @@ class MixtralAttention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, window_size_left=self.max_past, ) @@ -287,7 +279,6 @@ class MixtralAttention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -384,10 +375,8 @@ class MixtralLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -399,10 +388,8 @@ class MixtralLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -454,10 +441,8 @@ class MixtralModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: 
torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -475,10 +460,8 @@ class MixtralModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -510,29 +493,20 @@ class FlashMixtralForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if prefill_cache_indices is not None and slots.size( - 0 - ) != prefill_cache_indices.size(0): - # Slots also need to be sliced as it has the same size as the whole kv tensor - slots = slots[prefill_cache_indices] hidden_states = self.model( input_ids, position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py index b26adad7..216642e0 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py @@ -801,12 +801,10 @@ class FlashLlamaCrossLayer(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, adapter_data, cross_attention_states, # [ IB, ...] - prefill_cache_indices, hpu_attention_meta, ) -> Tuple[torch.Tensor, torch.Tensor]: if cross_attention_states is None: @@ -911,11 +909,9 @@ class FlashMllamaForConditionalGeneration(nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], - prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor], adapter_data: Optional[torch.Tensor] = None, # XXX: Putting these as optional so that the cuda warmup calls can go through. 
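# A small sketch of the shape issue that the commit message above ("remove
# block_tables and prefill_cache_indices which will lead to dynamic shape") refers
# to; the block size and sequence lengths here are illustrative assumptions, not
# values from the patch. block_tables has one row per sequence and as many columns
# as the longest sequence needs, so its shape drifts from step to step, while the
# metadata built by prepare_for_decode is padded to a block-bucket size and keeps
# shapes far more stable for HPU graph reuse.
from typing import List

import torch

BLOCK_SIZE = 128  # assumption: paging granularity used only for this illustration


def block_table_shape(seq_lens: List[int]) -> torch.Size:
    # One row per sequence, padded to the block count of the longest sequence.
    max_blocks = max((length + BLOCK_SIZE - 1) // BLOCK_SIZE for length in seq_lens)
    return torch.Size([len(seq_lens), max_blocks])


# Same batch size, different sequence lengths -> different tensor shapes.
assert block_table_shape([100, 200]) != block_table_shape([1000, 2000])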
@@ -979,11 +975,9 @@ class FlashMllamaForConditionalGeneration(nn.Module): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, - block_tables=block_tables, slots=slots, seqlen=seqlen, hpu_attention_meta=hpu_attention_meta, - prefill_cache_indices=prefill_cache_indices, lm_head_indices=lm_head_indices, adapter_data=adapter_data, cross_attention_states=cross_attention_states, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index d1904c03..33f63333 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -147,10 +147,8 @@ class FlashNeoxAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) @@ -166,14 +164,10 @@ class FlashNeoxAttention(torch.nn.Module): self.rotary_emb(query_rot, key_rot, cos, sin) qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1) qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1) - if prefill_cache_indices is not None: - qkv_to_cache = qkv[prefill_cache_indices] - else: - qkv_to_cache = qkv kv_cache.store( - key=qkv_to_cache[:, 1], - value=qkv_to_cache[:, 2], + key=qkv[:, 1], + value=qkv[:, 2], slots=slots, kv_scales=self.kv_scales, ) @@ -188,7 +182,6 @@ class FlashNeoxAttention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, ) # Decode @@ -198,7 +191,6 @@ class FlashNeoxAttention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -268,10 +260,8 @@ class FlashNeoXLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): if self.use_parallel_residual: @@ -283,10 +273,8 @@ class FlashNeoXLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -308,10 +296,8 @@ class FlashNeoXLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -363,10 +349,8 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_in(input_ids) @@ -384,10 +368,8 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -417,11 +399,9 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], - prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> 
torch.Tensor: @@ -430,10 +410,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index 2b67501d..4d31d5dd 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -69,11 +69,9 @@ class PaliGemmaForConditionalGeneration(nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], - prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, pixel_values: torch.FloatTensor = None, # Unused here @@ -106,11 +104,9 @@ class PaliGemmaForConditionalGeneration(nn.Module): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, - block_tables=block_tables, slots=slots, seqlen=seqlen, hpu_attention_meta=hpu_attention_meta, - prefill_cache_indices=None, adapter_data=adapter_data, ) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index cf7c9a79..0c777912 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -160,10 +160,8 @@ class FlashPhiAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): # Compute query, key, value and split @@ -190,13 +188,9 @@ class FlashPhiAttention(torch.nn.Module): ) # Reshape key and value and cache - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv kv_cache.store( - key=kv_to_cache[:, 0], - value=kv_to_cache[:, 1], + key=kv[:, 0], + value=kv[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -210,7 +204,6 @@ class FlashPhiAttention(torch.nn.Module): kv_scales=self.kv_scales, kv_cache=kv_cache, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, ) # Decode @@ -220,7 +213,6 @@ class FlashPhiAttention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -287,10 +279,8 @@ class FlashPhiLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -301,10 +291,8 @@ class FlashPhiLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -354,10 +342,8 @@ class FlashPhiModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: 
Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -375,10 +361,8 @@ class FlashPhiModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -409,10 +393,8 @@ class FlashPhiForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, @@ -422,10 +404,8 @@ class FlashPhiForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 480a17d1..af4b404d 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -106,10 +106,8 @@ class Qwen2Attention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) @@ -125,14 +123,9 @@ class Qwen2Attention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv - kv_cache.store( - key=kv_to_cache[:, 0], - value=kv_to_cache[:, 1], + key=kv[:, 0], + value=kv[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -147,7 +140,6 @@ class Qwen2Attention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, window_size_left=self.max_past, ) @@ -158,7 +150,6 @@ class Qwen2Attention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -230,10 +221,8 @@ class Qwen2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): normed_hidden_states, residual = self.input_layernorm(hidden_states) @@ -245,10 +234,8 @@ class Qwen2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) hidden_states = attn_output + residual @@ -296,10 +283,8 @@ class Qwen2Model(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = inputs_embeds @@ -317,10 +302,8 @@ class Qwen2Model(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -364,21 +347,13 @@ class Qwen2ForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: 
Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if prefill_cache_indices is not None and prefill_cache_indices.size( - 0 - ) != slots.size(0): - # Slots also need to be sliced as it has the same size as the whole kv tensor - slots = slots[prefill_cache_indices] - inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( @@ -386,10 +361,8 @@ class Qwen2ForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index e7c4b2b6..141e13a6 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -182,10 +182,8 @@ class FlashRWAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) @@ -203,14 +201,9 @@ class FlashRWAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv - kv_cache.store( - key=kv_to_cache[:, 0], - value=kv_to_cache[:, 1], + key=kv[:, 0], + value=kv[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -225,7 +218,6 @@ class FlashRWAttention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, ) # Decode @@ -235,7 +227,6 @@ class FlashRWAttention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -309,10 +300,8 @@ class FlashRWLargeAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) @@ -329,14 +318,9 @@ class FlashRWLargeAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv - kv_cache.store( - key=kv_to_cache[:, :, 0].contiguous(), - value=kv_to_cache[:, :, 1].contiguous(), + key=kv[:, :, 0].contiguous(), + value=kv[:, :, 1].contiguous(), slots=slots, kv_scales=self.kv_scales, ) @@ -351,7 +335,6 @@ class FlashRWLargeAttention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, ) # Decode @@ -361,7 +344,6 @@ class FlashRWLargeAttention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -447,10 +429,8 @@ class FlashRWLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, 
): if self.parallel_attn: @@ -462,10 +442,8 @@ class FlashRWLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -485,10 +463,8 @@ class FlashRWLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -573,10 +549,8 @@ class FlashRWLargeLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): # Layer norm. @@ -589,10 +563,8 @@ class FlashRWLargeLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -651,10 +623,8 @@ class FlashRWModel(FlashRWPreTrainedModel): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.word_embeddings(input_ids) @@ -672,10 +642,8 @@ class FlashRWModel(FlashRWPreTrainedModel): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -703,10 +671,8 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, @@ -716,10 +682,8 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index a41518d7..b68f4784 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -265,10 +265,8 @@ class FlashMQAttention(torch.nn.Module): hidden_states, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): qkv = self.c_attn(hidden_states) @@ -282,14 +280,9 @@ class FlashMQAttention(torch.nn.Module): query = query.view(-1, self.num_heads, self.head_size) key_value = key_value.view(-1, 2, 1, self.head_size) - if prefill_cache_indices is not None: - key_value_to_cache = key_value[prefill_cache_indices] - else: - key_value_to_cache = key_value - kv_cache.store( - key=key_value_to_cache[:, 0], - value=key_value_to_cache[:, 1], + key=key_value[:, 0], + value=key_value[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -304,7 +297,6 @@ class FlashMQAttention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, ) # Decode @@ -314,7 +306,6 @@ class FlashMQAttention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -379,10 
+370,8 @@ class Block(nn.Module): residual, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): hidden_states, residual = self.ln_1(hidden_states, residual) @@ -390,10 +379,8 @@ class Block(nn.Module): hidden_states, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -445,10 +432,8 @@ class FlashSantacoderModel(nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.wte(input_ids) + self.wpe(position_ids) @@ -463,10 +448,8 @@ class FlashSantacoderModel(nn.Module): residual, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -496,11 +479,9 @@ class FlashSantacoderForCausalLM(nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], - prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -509,10 +490,8 @@ class FlashSantacoderForCausalLM(nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 082e5d82..76f6f473 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -235,10 +235,8 @@ class Starcoder2Attention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, adapter_data, hpu_attention_meta, ): @@ -255,14 +253,9 @@ class Starcoder2Attention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv - kv_cache.store( - key=kv_to_cache[:, 0], - value=kv_to_cache[:, 1], + key=kv[:, 0], + value=kv[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -277,7 +270,6 @@ class Starcoder2Attention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, window_size_left=self.max_past, ) @@ -288,7 +280,6 @@ class Starcoder2Attention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -448,10 +439,8 @@ class Starcoder2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, adapter_data, hpu_attention_meta, ): @@ -464,10 +453,8 @@ class Starcoder2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, adapter_data, hpu_attention_meta, ) @@ -518,10 +505,8 @@ class 
Starcoder2Model(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], adapter_data, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: @@ -540,10 +525,8 @@ class Starcoder2Model(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, adapter_data, hpu_attention_meta, ) @@ -589,29 +572,20 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], - prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if prefill_cache_indices is not None and slots.size( - 0 - ) != prefill_cache_indices.size(0): - # Slots also need to be sliced as it has the same size as the whole kv tensor - slots = slots[prefill_cache_indices] hidden_states = self.model( input_ids, position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, adapter_data, hpu_attention_meta, ) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py index 0a4305ec..02806ac9 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py @@ -740,11 +740,9 @@ class Idefics2ForConditionalGeneration(nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], - prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, pixel_values: torch.FloatTensor = None, pixel_attention_mask: Optional[torch.BoolTensor] = None, @@ -843,11 +841,9 @@ class Idefics2ForConditionalGeneration(nn.Module): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, - block_tables=block_tables, slots=slots, seqlen=seqlen, hpu_attention_meta=hpu_attention_meta, - prefill_cache_indices=None, adapter_data=adapter_data, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py index 9278a86a..964526fc 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py @@ -483,11 +483,9 @@ class Idefics3ForConditionalGeneration(nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], - prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, pixel_values: torch.FloatTensor = None, pixel_attention_mask: Optional[torch.BoolTensor] = None, @@ 
-587,11 +585,9 @@ class Idefics3ForConditionalGeneration(nn.Module): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, - block_tables=block_tables, slots=slots, seqlen=seqlen, hpu_attention_meta=hpu_attention_meta, - prefill_cache_indices=None, adapter_data=adapter_data, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py index 75dd2b40..441b0016 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py @@ -906,11 +906,9 @@ class Qwen2_5VLForConditionalGeneration(nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], - prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor], pixel_values: torch.FloatTensor = None, image_grid_thw: Optional[torch.LongTensor] = None, @@ -938,11 +936,9 @@ class Qwen2_5VLForConditionalGeneration(nn.Module): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, - block_tables=block_tables, slots=slots, seqlen=seqlen, hpu_attention_meta=hpu_attention_meta, - prefill_cache_indices=prefill_cache_indices, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py index 3b4965a2..47ae2ac9 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -480,11 +480,9 @@ class Qwen2VLForConditionalGeneration(nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], - prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor], pixel_values: torch.FloatTensor = None, image_grid_thw: Optional[torch.LongTensor] = None, @@ -511,11 +509,9 @@ class Qwen2VLForConditionalGeneration(nn.Module): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, - block_tables=block_tables, slots=slots, seqlen=seqlen, hpu_attention_meta=hpu_attention_meta, - prefill_cache_indices=prefill_cache_indices, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index e032242c..b0859c3d 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -25,7 +25,7 @@ from typing import ( Dict, Union, ) - +import torch.nn.functional as F from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.models import Model @@ -116,7 +116,6 @@ def prepare_for_decode( block_list = 
flatten(block_tables) block_groups = flatten(block_groups) block_usage = flatten(block_usage) - assert len(block_list) == len(block_groups) assert len(block_list) == len(block_usage) if use_contiguous_pa: @@ -979,29 +978,27 @@ class FlashCausalLMBatch(Batch): # padding to left to work with sliding window # use prefill_cache_indices to indicate the valid kv slot, update prefill_next_token_indices to indicate # the right logit position - input_ids_padded = None - input_ids_padded_length = None + input_ids_padded_length = [] + # need extra pad to match warmup seq + extra_pad = 0 if isinstance(self.input_ids, list) and len(self) > 1: - input_ids_padded = [] input_ids_padded_length = [] + input_ids = [] for input_id in self.input_ids: - padded = self.max_input_length - len(input_id) - input_id_padded = input_id + padded = self.max_input_length - len(input_id) + extra_pad if padded > 0: - input_id_padded = [0] * padded + input_id_padded - input_ids_padded.append(input_id_padded) + input_id = [0] * padded + input_id + input_ids.append(input_id) input_ids_padded_length.append(padded) - input_ids_padded = np.concatenate(input_ids_padded, dtype=np.int64) - input_ids_padded = torch.tensor( - input_ids_padded, dtype=torch.int64, device=device - ) - - if isinstance(self.input_ids, list): - if len(self) > 1: - input_ids = np.concatenate(self.input_ids, dtype=np.int64) - else: - input_ids = self.input_ids[0] + input_ids = np.concatenate(input_ids, dtype=np.int64) self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + elif isinstance(self.input_ids, list): + input_ids = self.input_ids[0] + input_ids_padded_length.append(extra_pad) + input_ids = [0] * extra_pad + input_ids + self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + else: + logger.error("should not be here, prefill self.input_ids is a tensor") self.input_lengths_tensor = torch.tensor( self.input_lengths, dtype=torch.int32, device=device @@ -1052,10 +1049,9 @@ class FlashCausalLMBatch(Batch): request_position_ids = torch.arange( cache_length, cache_length + input_length, dtype=torch.int32 ) - if input_ids_padded is not None: - position_ids.append( - torch.ones(input_ids_padded_length[i], dtype=torch.int32) - ) + request_position_ids = F.pad( + request_position_ids, (input_ids_padded_length[i], 0), value=1 + ) position_ids.append(request_position_ids) if not r.slots: @@ -1079,12 +1075,11 @@ class FlashCausalLMBatch(Batch): cumulative_slot_tokens += len(request_slots) # Create tensor to slice into the kv tensor in prefill - if input_ids_padded is not None: - # hpu need request_prefill_cache_indices to skip padding in kv cache - sliding_window = get_sliding_windows() - if sliding_window is None: - sliding_window = input_length - cumulative_length += input_ids_padded_length[i] + # hpu need request_prefill_cache_indices to skip padding in kv cache + sliding_window = get_sliding_windows() + if sliding_window is None: + sliding_window = input_length + cumulative_length += input_ids_padded_length[i] if sliding_window is not None: request_prefill_cache_indices = torch.arange( cumulative_length + max(0, input_length - sliding_window), @@ -1105,8 +1100,7 @@ class FlashCausalLMBatch(Batch): prefill_cu_outlens.append(prefill_out_cumulative_length + 1) prefill_out_cumulative_length += 1 - if sliding_window is not None: - prefill_cache_indices.append(request_prefill_cache_indices) + prefill_cache_indices.append(request_prefill_cache_indices) ADAPTER_TO_INDEX = get_adapter_to_index() if ADAPTER_TO_INDEX: @@ -1171,23 
+1165,20 @@ class FlashCausalLMBatch(Batch): position_ids = torch.cat(position_ids) if slot_indices: slot_indices = torch.cat(slot_indices) - if sliding_window is not None: - prefill_cache_indices = torch.cat(prefill_cache_indices) + prefill_cache_indices = torch.cat(prefill_cache_indices) else: if position_ids: position_ids = position_ids[0] if slot_indices: slot_indices = slot_indices[0] - if sliding_window is not None: - prefill_cache_indices = prefill_cache_indices[0] + prefill_cache_indices = prefill_cache_indices[0] self.position_ids = position_ids.to(device) self.slot_indices = slot_indices.to(device) self.prefill_cu_outlens = prefill_cu_outlens - self.prefill_cache_indices = ( - prefill_cache_indices.to(device) if sliding_window is not None else None - ) + self.prefill_cache_indices = torch.zeros_like(self.input_ids, dtype=torch.bool) + self.prefill_cache_indices[prefill_cache_indices.to(device)] = True if all_prefill_logprobs: prefill_head_indices = None @@ -1203,21 +1194,19 @@ class FlashCausalLMBatch(Batch): self.prefill_head_indices = prefill_head_indices self.prefill_next_token_indices = prefill_next_token_indices - if input_ids_padded is not None: - self.input_ids = input_ids_padded - input_ids_padded_length_tensor = torch.cumsum( - torch.tensor(input_ids_padded_length, dtype=torch.int64, device=device), - dim=-1, + input_ids_padded_length_tensor = torch.cumsum( + torch.tensor(input_ids_padded_length, dtype=torch.int64, device=device), + dim=-1, + ) + if self.prefill_head_indices is not None: + self.prefill_head_indices = ( + self.prefill_head_indices + input_ids_padded_length_tensor ) - if self.prefill_head_indices is not None: - self.prefill_head_indices = ( - self.prefill_head_indices + input_ids_padded_length_tensor - ) - if self.prefill_next_token_indices is not None: - self.prefill_next_token_indices = ( - self.prefill_next_token_indices + input_ids_padded_length_tensor - ) + if self.prefill_next_token_indices is not None: + self.prefill_next_token_indices = ( + self.prefill_next_token_indices + input_ids_padded_length_tensor + ) if adapter_set: adapter_indices = torch.cat(adapter_indices_list).to( @@ -1232,7 +1221,6 @@ class FlashCausalLMBatch(Batch): adapter_segments = torch.tensor( adapter_segments, dtype=torch.int32, device=device ) - self.adapter_meta = AdapterBatchMetadata( adapter_indices=adapter_indices, adapter_set=adapter_set, @@ -1490,14 +1478,6 @@ class FlashCausalLM(Model): self.kv_cache_dtype, self.device, ) - for bs in [1, 2, 4, 8]: - for seqlen in [32, 64, 128, 256, 512, 1024]: - self.warmup_prefill(seqlen, bs) - - for bs in [1, 2, 4, 8]: - for block_num in [1, 2, 4, 8, 16]: - self.warmup_decode(bs, block_num * bs) - return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens def warmup_prefill(self, prompt_len: int, bs: int): @@ -1539,10 +1519,8 @@ class FlashCausalLM(Model): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=self.kv_cache, - block_tables=None, # block_table is not used in hpu pageattn. 
remove it to avoid shape change in hpu graph slots=slots, seqlen=trim_seqlen_metadata(seqlen), - prefill_cache_indices=None, lm_head_indices=lm_head_indices, adapter_data=None, hpu_attention_meta=None, @@ -1562,7 +1540,6 @@ class FlashCausalLM(Model): for i in range(bs): slots.append(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 1) slots = torch.tensor(slots, dtype=torch.int64, device=self.device) - input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) cache_lengths_tensor = ( torch.ones(bs, dtype=torch.int32, device=self.device) * past_len @@ -1575,11 +1552,11 @@ class FlashCausalLM(Model): cache_lengths=cache_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, ) - block_num = cache_lengths_tensor // BLOCK_SIZE + 1 block_tables_valid = [] for i, bt in enumerate(block_tables.tolist()): block_tables_valid.append(bt[0 : block_num[i]]) + hpu_attention_meta = prepare_for_decode( self.dtype, self.use_contiguous_pa, @@ -1595,10 +1572,8 @@ class FlashCausalLM(Model): position_ids=position_ids, cu_seqlen_prefill=None, kv_cache=self.kv_cache, - block_tables=None, # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph slots=slots, seqlen=trim_seqlen_metadata(seqlen), - prefill_cache_indices=None, lm_head_indices=None, adapter_data=None, hpu_attention_meta=hpu_attention_meta, @@ -1684,26 +1659,23 @@ class FlashCausalLM(Model): kwargs = {} if htorch.utils.internal.is_lazy(): kwargs["bypass_hpu_graphs"] = False + if batch.prefill_cache_indices is not None: + slots_pad = torch.zeros_like(input_ids) + slots_pad[batch.prefill_cache_indices] = slots + slots = slots_pad logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, - block_tables=None, slots=slots, seqlen=trim_seqlen_metadata(seqlen), - prefill_cache_indices=batch.prefill_cache_indices, lm_head_indices=lm_head_indices, # TODO not support adapter now, need the add in the future adapter_data=None, hpu_attention_meta=batch.hpu_attn_meta, **kwargs, ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None - # fix following runtime error in graph replay - # RuntimeError: Neither storage attached to input tensor, not its view - htorch.core.mark_step() return logits, speculative_logits @tracer.start_as_current_span("generate_token") @@ -1801,7 +1773,14 @@ class FlashCausalLM(Model): # instantly become of shape [BATCH_SIZE] if prefill and finished_prefilling: indices = batch.cu_seqlen_prefill[1:] - 1 - batch.position_ids = batch.position_ids[indices] + # pad in left + if batch.prefill_cache_indices is not None: + batch.position_ids = batch.position_ids[batch.prefill_cache_indices][ + indices + ] + else: + batch.position_ids = batch.position_ids[indices] + batch.slot_indices = batch.slot_indices[indices] batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[ indices diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py index 48bfce89..b5d93cbc 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -462,11 +462,9 @@ class FlashVlmCausalLM(FlashCausalLM): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, - block_tables=None, # block_table is not used in hpu pageattn. 
remove it to avoid shape change in hpu graph slots=slots, seqlen=trim_seqlen_metadata(seqlen), hpu_attention_meta=batch.hpu_attn_meta, - prefill_cache_indices=batch.prefill_cache_indices, lm_head_indices=lm_head_indices, pixel_values=batch.pixel_values, pixel_attention_mask=batch.pixel_attention_mask, diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index 4471aab3..eabbe247 100644 --- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -288,11 +288,9 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, - block_tables=None, # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph slots=slots, seqlen=trim_seqlen_metadata(seqlen), hpu_attention_meta=batch.hpu_attn_meta, - prefill_cache_indices=batch.prefill_cache_indices, lm_head_indices=lm_head_indices, cross_attention_states=cross_attention_states, # TODO list From 787dbe98a8cc92c234b9911a5e2e7f600397996b Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Fri, 28 Mar 2025 00:09:26 -0700 Subject: [PATCH 20/35] fix comment Signed-off-by: Wang, Yi A --- launcher/src/env_runtime.rs | 4 ++-- launcher/src/main.rs | 7 ++----- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/launcher/src/env_runtime.rs b/launcher/src/env_runtime.rs index d7ae11d5..d9056e41 100644 --- a/launcher/src/env_runtime.rs +++ b/launcher/src/env_runtime.rs @@ -28,8 +28,8 @@ impl Env { } } - pub fn is_hpu_device(&self) -> bool { - self.hpu_env != "N/A" + pub fn should_start_a_single_hpu_shard(&self) -> bool { + self.hpu_env != "N/A" && std::env::var("ATTENTION").as_deref() != Ok("paged") } } diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 23288b20..86d8714a 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1559,10 +1559,7 @@ fn spawn_shards( ) -> Result<(), LauncherError> { // Start shard processes for rank in 0..num_shard { - if rank != 0 - && env_runtime::Env::new().is_hpu_device() - && std::env::var("ATTENTION").as_deref() != Ok("paged") - { + if rank != 0 && env_runtime::Env::new().should_start_a_single_hpu_shard() { tracing::info!("Running on HPU, the launcher will not do any sharding as actual sharding is done in the server"); break; } @@ -1642,7 +1639,7 @@ fn spawn_shards( if shard_ready == num_shard { break; } - if env_runtime::Env::new().is_hpu_device() { + if env_runtime::Env::new().should_start_a_single_hpu_shard() { tracing::info!("HPU detected, shard is ready"); break; } From 376e0507b7fda4c63113b64a117711e7a73ccc68 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Fri, 28 Mar 2025 01:08:40 -0700 Subject: [PATCH 21/35] missing gptj change... 
Signed-off-by: Wang, Yi A --- .../custom_modeling/flash_gptj_modeling.py | 27 ++----------------- 1 file changed, 2 insertions(+), 25 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index 3135acde..41eeab78 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -156,10 +156,8 @@ class FlashGPTJAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): query, key, value = self.query_key_value(hidden_states).split( @@ -177,16 +175,9 @@ class FlashGPTJAttention(torch.nn.Module): else: self.rotary_emb(query, key, cos, sin) - if prefill_cache_indices is not None: - key_to_cache = key[prefill_cache_indices] - value_to_cache = value[prefill_cache_indices] - else: - key_to_cache = key - value_to_cache = value - kv_cache.store( - key=key_to_cache, - value=value_to_cache, + key=key, + value=value, slots=slots, kv_scales=self.kv_scales, ) @@ -201,7 +192,6 @@ class FlashGPTJAttention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, ) # Decode @@ -211,7 +201,6 @@ class FlashGPTJAttention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -272,10 +261,8 @@ class FlashGPTJLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): hidden_states, residual = self.input_layernorm(hidden_states, residual) @@ -286,10 +273,8 @@ class FlashGPTJLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -334,10 +319,8 @@ class FlashGPTJModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.wte(input_ids) @@ -355,10 +338,8 @@ class FlashGPTJModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -387,10 +368,8 @@ class FlashGPTJForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, @@ -400,10 +379,8 @@ class FlashGPTJForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices=prefill_cache_indices, hpu_attention_meta=hpu_attention_meta, ) if lm_head_indices is not None: From f0e5faec1a77b3d6719d8363113db3361a314b32 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Fri, 28 Mar 2025 07:01:06 -0700 Subject: [PATCH 22/35] fix some issue Signed-off-by: Wang, Yi A --- 
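Illustrative sketch (assumption: not part of the patch; plain PyTorch on CPU, with vllm_hpu_extension's insert_or_update_cache replaced by an indexed write). It shows the slot convention this change adopts: padded prefill positions carry slot -1, and the paged KV-cache writer drops them before computing block coordinates.

import torch

BLOCK_SIZE = 4                        # toy value; the real one comes from models.globals
num_blocks, heads, head_dim = 8, 2, 3

key = torch.randn(6, heads, head_dim)                        # 6 padded token positions
valid = torch.tensor([0, 0, 1, 1, 1, 1], dtype=torch.bool)   # prefill_cache_indices mask
real_slots = torch.tensor([8, 9, 10, 11])                    # one slot per valid token

# Scatter the real slots into a padded tensor; padded positions stay at -1.
slots = torch.full((key.shape[0],), -1, dtype=torch.long)
slots[valid] = real_slots

# Cache write: drop the -1 entries, then split each slot into (block, offset).
key_cache = torch.zeros(num_blocks, BLOCK_SIZE, heads, head_dim)
mask = slots != -1
block_idx = slots[mask] // BLOCK_SIZE
block_offset = slots[mask] % BLOCK_SIZE
key_cache[block_idx, block_offset] = key[mask]   # stands in for insert_or_update_cache

The next patch in the series removes this mask again (keeping 0-filled slot padding) to avoid incorrect output when replaying HPU graphs.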
.../text_generation_server/layers/attention/kv_cache.py | 9 +++++---- .../models/custom_modeling/flash_llava_next.py | 2 +- .../text_generation_server/models/flash_causal_lm.py | 5 +++-- .../text_generation_server/models/flash_vlm_causal_lm.py | 4 ++++ .../text_generation_server/models/mllama_causal_lm.py | 5 ++++- 5 files changed, 17 insertions(+), 8 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py index 26c80c70..cdd0458b 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py @@ -5,6 +5,7 @@ import torch from text_generation_server.models.globals import BLOCK_SIZE from text_generation_server.utils.weights import Weights +from vllm_hpu_extension import cache_ops @dataclass @@ -115,12 +116,12 @@ def paged_reshape_and_cache( v_scale: float = 1.0, ): - from vllm_hpu_extension import cache_ops - + mask = torch.where(slots != -1) + slots = slots[mask] block_idx = slots // BLOCK_SIZE block_offset = slots % BLOCK_SIZE - cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset) - cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset) + cache_ops.insert_or_update_cache(key[mask], key_cache, block_idx, block_offset) + cache_ops.insert_or_update_cache(value[mask], value_cache, block_idx, block_offset) def get_kv_scales(weights: Weights, prefix: str) -> KVScales: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py index 62e8470c..88548042 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py @@ -153,7 +153,7 @@ class FlashLlavaNextForConditionalGeneration(nn.Module): image_features: torch.Tensor, ): """In place merges in vision_embeddings with inputs_embeds.""" - mask = input_ids == self.config.image_token_index + mask = torch.where(input_ids == self.config.image_token_index) # Let's pray we have enabled enough slots ! 
try: inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index b0859c3d..816f05d0 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -998,7 +998,8 @@ class FlashCausalLMBatch(Batch): input_ids = [0] * extra_pad + input_ids self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) else: - logger.error("should not be here, prefill self.input_ids is a tensor") + self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0) + input_ids_padded_length.append(extra_pad) self.input_lengths_tensor = torch.tensor( self.input_lengths, dtype=torch.int32, device=device @@ -1660,7 +1661,7 @@ class FlashCausalLM(Model): if htorch.utils.internal.is_lazy(): kwargs["bypass_hpu_graphs"] = False if batch.prefill_cache_indices is not None: - slots_pad = torch.zeros_like(input_ids) + slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1 slots_pad[batch.prefill_cache_indices] = slots slots = slots_pad logits, speculative_logits = self.model.forward( diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py index b5d93cbc..f630a85a 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -457,6 +457,10 @@ class FlashVlmCausalLM(FlashCausalLM): cache_lengths=cache_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, ) + if batch.prefill_cache_indices is not None: + slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1 + slots_pad[batch.prefill_cache_indices] = slots + slots = slots_pad logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index eabbe247..bd123725 100644 --- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -282,7 +282,10 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): kwargs = {} if htorch.utils.internal.is_lazy(): kwargs["bypass_hpu_graphs"] = False - + if batch.prefill_cache_indices is not None: + slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1 + slots_pad[batch.prefill_cache_indices] = slots + slots = slots_pad logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, From c55a8caea27b28817a18e1b2bbf7afea3f6146b3 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Mon, 31 Mar 2025 22:51:54 -0700 Subject: [PATCH 23/35] remove torch.where to fix incorrect output in hpu graph model Signed-off-by: Wang, Yi A --- .../text_generation_server/layers/attention/kv_cache.py | 7 ++----- .../text_generation_server/models/flash_causal_lm.py | 2 +- .../text_generation_server/models/flash_vlm_causal_lm.py | 2 +- .../text_generation_server/models/mllama_causal_lm.py | 2 +- 4 files changed, 5 insertions(+), 8 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py index cdd0458b..d238cdb9 100644 --- 
a/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py @@ -115,13 +115,10 @@ def paged_reshape_and_cache( k_scale: float = 1.0, v_scale: float = 1.0, ): - - mask = torch.where(slots != -1) - slots = slots[mask] block_idx = slots // BLOCK_SIZE block_offset = slots % BLOCK_SIZE - cache_ops.insert_or_update_cache(key[mask], key_cache, block_idx, block_offset) - cache_ops.insert_or_update_cache(value[mask], value_cache, block_idx, block_offset) + cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset) + cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset) def get_kv_scales(weights: Weights, prefix: str) -> KVScales: diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index 816f05d0..a4d58596 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -1661,7 +1661,7 @@ class FlashCausalLM(Model): if htorch.utils.internal.is_lazy(): kwargs["bypass_hpu_graphs"] = False if batch.prefill_cache_indices is not None: - slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1 + slots_pad = torch.zeros_like(input_ids) slots_pad[batch.prefill_cache_indices] = slots slots = slots_pad logits, speculative_logits = self.model.forward( diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py index f630a85a..208ab358 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -458,7 +458,7 @@ class FlashVlmCausalLM(FlashCausalLM): cu_seqlen_q=cu_seqlen_prefill, ) if batch.prefill_cache_indices is not None: - slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1 + slots_pad = torch.zeros_like(input_ids) slots_pad[batch.prefill_cache_indices] = slots slots = slots_pad logits, speculative_logits = self.model.forward( diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index bd123725..e034ed49 100644 --- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -283,7 +283,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): if htorch.utils.internal.is_lazy(): kwargs["bypass_hpu_graphs"] = False if batch.prefill_cache_indices is not None: - slots_pad = torch.ones_like(input_ids, dtype=torch.long) * -1 + slots_pad = torch.zeros_like(input_ids) slots_pad[batch.prefill_cache_indices] = slots slots = slots_pad logits, speculative_logits = self.model.forward( From 9d85ac948549e1d19cb0a9705c5a066fd1ca8918 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Sun, 30 Mar 2025 20:20:09 -0700 Subject: [PATCH 24/35] LLM warmup logic Signed-off-by: Wang, Yi A --- .../models/flash_causal_lm.py | 208 +++++++++++++----- 1 file changed, 156 insertions(+), 52 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index a4d58596..ed8b658a 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ 
b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -71,6 +71,7 @@ import vllm_hpu_extension.environment as environment import habana_frameworks.torch as htorch import itertools from vllm_hpu_extension.ops import batch2block, block2batch +from vllm_hpu_extension.bucketing import HPUBucketingContext tracer = trace.get_tracer(__name__) @@ -89,7 +90,7 @@ def get_sliding_windows() -> int: def prepare_for_decode( - dtype, use_contiguous_pa, device, slot, block_tables, batch_size + dtype, use_contiguous_pa, device, slot, block_tables, batch_size, bucketing_ctx ): # Prepare values if we need to continue decoding # need for HPUPagedAttentionMetadata preparation @@ -120,8 +121,10 @@ def prepare_for_decode( assert len(block_list) == len(block_usage) if use_contiguous_pa: block_bucket_size = max(max(block_list) + 1, len(block_list)) - # block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks( - # block_bucket_size) + if bucketing_ctx is not None: + block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks( + block_bucket_size + ) indices: List[Any] indices = [None] * block_bucket_size for i, bid in enumerate(block_list): @@ -131,6 +134,10 @@ def prepare_for_decode( block_usage = gather_list(block_usage, indices, 1) else: block_bucket_size = len(block_list) + if bucketing_ctx is not None: + block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks( + block_bucket_size + ) block_list = pad_list(block_list, block_bucket_size, 0) block_groups = pad_list(block_groups, block_bucket_size, -1) block_usage = pad_list(block_usage, block_bucket_size, 1) @@ -835,15 +842,16 @@ class FlashCausalLMBatch(Batch): ) if not prefilling: - input_ids[start_index:end_index] = batch.input_ids - position_ids[start_index:end_index] = batch.position_ids - slot_indices[start_index:end_index] = ( - batch.slot_indices + cumulative_slots + index = torch.tensor( + list(range(start_index, end_index)), device=batch.input_ids.device ) - input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor - cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor - - # Copy over adapter indices + input_ids.index_copy_(0, index, batch.input_ids) + position_ids.index_copy_(0, index, batch.position_ids) + slot_indices.index_copy_( + 0, index, batch.slot_indices + cumulative_slots + ) + input_lengths_tensor.index_copy_(0, index, batch.input_lengths_tensor) + cache_lengths_tensor.index_copy_(0, index, batch.cache_lengths_tensor) adapter_start_index = cumulative_adapter_indices_size adapter_end_index = ( cumulative_adapter_indices_size @@ -951,22 +959,34 @@ class FlashCausalLMBatch(Batch): hpu_attn_meta=None, ) - def prepare_for_decode(self, dtype, use_contiguous_pa): + def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx): block_num = self.cache_lengths_tensor // BLOCK_SIZE + 1 block_tables = [] for i, bt in enumerate(self.block_tables): block_tables.append(bt[0 : block_num[i]]) + if bucketing_ctx is not None: + padded_bs = bucketing_ctx.get_padded_decode_batch_size( + self.input_ids.shape[0] + ) + else: + padded_bs = self.input_ids.shape[0] + slots = self.slots[self.slot_indices] + extra_pad = padded_bs - self.input_ids.shape[0] + if extra_pad != 0: + slots = F.pad(slots, (0, extra_pad), value=0) + block_tables.extend([[0]] * extra_pad) self.hpu_attn_meta = prepare_for_decode( dtype, use_contiguous_pa, self.block_tables_tensor.device, - self.slots[self.slot_indices], + slots, block_tables, - self.input_ids.size(0), + padded_bs, + bucketing_ctx, ) - def 
prepare_for_prefill(self): + def prepare_for_prefill(self, max_padded_input_len): # Prepare values if we need to continue prefilling # Speculation must be ignored while we prefill even with chunking # it simplifies everything @@ -980,7 +1000,7 @@ class FlashCausalLMBatch(Batch): # the right logit position input_ids_padded_length = [] # need extra pad to match warmup seq - extra_pad = 0 + extra_pad = max_padded_input_len - self.max_input_length if isinstance(self.input_ids, list) and len(self) > 1: input_ids_padded_length = [] input_ids = [] @@ -1355,9 +1375,9 @@ class FlashCausalLM(Model): self.cuda_graphs = {} self.kv_cache = [] self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype - + self.bucketing_ctx = None if htorch.utils.internal.is_lazy(): - htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=False) + htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True) environment.set_model_config(self.config) self.use_contiguous_pa = ( os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true" @@ -1479,9 +1499,31 @@ class FlashCausalLM(Model): self.kv_cache_dtype, self.device, ) + self.bucketing_ctx = HPUBucketingContext( + os.getenv("DECODE_MAX_BS", 128), # self.max_num_seqs, #TODO + os.getenv("PREFILL_MAX_BS", 16), # self.max_num_prefill_seqs, #TODO + BLOCK_SIZE, + num_blocks * BLOCK_SIZE, + ) + self.bucketing_ctx.num_hpu_blocks = num_blocks + warmup_times = 3 + self.bucketing_ctx.generate_prompt_buckets() + for i, (batch_size, seq_len) in enumerate( + reversed(self.bucketing_ctx.prompt_buckets) + ): + for index in range(warmup_times): + self.warmup_prefill(seq_len, batch_size) + self.bucketing_ctx.generate_decode_buckets(num_blocks) + for i, (batch_size, block_num) in enumerate( + reversed(self.bucketing_ctx.decode_buckets) + ): + for index in range(warmup_times): + self.warmup_decode(batch_size, block_num) + synchronize(self.device) return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens def warmup_prefill(self, prompt_len: int, bs: int): + logger.info(f"warmup prefill seq {prompt_len} bs {bs}") input_ids = torch.zeros( prompt_len, dtype=torch.int64, device=self.device ).repeat(bs) @@ -1527,25 +1569,32 @@ class FlashCausalLM(Model): hpu_attention_meta=None, ) - def warmup_decode(self, bs: int, block_num: int): - input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) - position_ids = torch.arange(bs, dtype=torch.int32, device=self.device) - block_tables = torch.arange( - start=1, end=block_num + 1, dtype=torch.int32, device=self.device - ).reshape(bs, -1) + def warmup_decode(self, batch_size: int, block_num: int): + logger.info(f"warmup decode bs {batch_size} block_num {block_num}") + input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device) + position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device) + blocks = [block_num // batch_size for _ in range(batch_size)] + blocks[0] += block_num % batch_size + past_len = [] + block_tables = [] slots = [] - past_len = ( - len(block_tables[0]) * BLOCK_SIZE - 1 - ) # for decode, we only need to pass the past token + start_idx = 0 + # fetch the last blocked to warmup block num - for i in range(bs): - slots.append(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 1) + for i in range(batch_size): + block_array = list(range(start_idx, start_idx + blocks[i])) + slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1) + block_tables.append(block_array) + past_len.append(blocks[i] * BLOCK_SIZE - 1) + start_idx += blocks[i] slots = torch.tensor(slots, 
dtype=torch.int64, device=self.device) - input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) - cache_lengths_tensor = ( - torch.ones(bs, dtype=torch.int32, device=self.device) * past_len + input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device) + cache_lengths_tensor = torch.tensor( + past_len, dtype=torch.int32, device=self.device + ) + cu_seqlen_prefill = torch.zeros( + batch_size + 1, device=self.device, dtype=torch.int32 ) - cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32) torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:]) seqlen = Seqlen( @@ -1553,20 +1602,16 @@ class FlashCausalLM(Model): cache_lengths=cache_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, ) - block_num = cache_lengths_tensor // BLOCK_SIZE + 1 - block_tables_valid = [] - for i, bt in enumerate(block_tables.tolist()): - block_tables_valid.append(bt[0 : block_num[i]]) hpu_attention_meta = prepare_for_decode( self.dtype, self.use_contiguous_pa, self.device, slots, - block_tables_valid, - bs, + block_tables, + batch_size, + bucketing_ctx=None, ) - # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. self.model.forward( input_ids=input_ids, @@ -1651,19 +1696,69 @@ class FlashCausalLM(Model): # in a circular buffer mode. # This makes sure the max_s for the decode pass is correct. max_s = min(self.max_past(), max_s) - - seqlen = Seqlen( - input_lengths=input_lengths, - cache_lengths=cache_lengths_tensor, - cu_seqlen_q=cu_seqlen_prefill, - ) - kwargs = {} - if htorch.utils.internal.is_lazy(): - kwargs["bypass_hpu_graphs"] = False if batch.prefill_cache_indices is not None: slots_pad = torch.zeros_like(input_ids) slots_pad[batch.prefill_cache_indices] = slots slots = slots_pad + if self.bucketing_ctx is not None: + if batch.prefilling: + padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size( + input_lengths.shape[0] + ) + else: + padded_bs = self.bucketing_ctx.get_padded_decode_batch_size( + input_lengths.shape[0] + ) + else: + padded_bs = input_lengths.shape[0] + orig_bs = input_lengths.shape[0] + if padded_bs != input_lengths.shape[0]: + orig_bs = input_lengths.shape[0] + padded_input_lengths = F.pad( + input_lengths, + (0, padded_bs - orig_bs), + value=0, + ) + padded_cache_lengths_tensor = F.pad( + cache_lengths_tensor, + (0, padded_bs - orig_bs), + value=0, + ) + if cu_seqlen_prefill is not None: + cu_seqlen_prefill = torch.zeros( + padded_bs + 1, device=self.device, dtype=torch.int32 + ) + torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:]) + seqlen = Seqlen( + input_lengths=padded_input_lengths, + cache_lengths=padded_cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) + input_seq = input_ids.view(orig_bs, -1) + input_ids = F.pad( + input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0 + ) + position_ids = F.pad( + position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1 + ) + slots = F.pad( + slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1 + ) + if lm_head_indices is not None: + lm_head_indices = F.pad( + lm_head_indices, (0, padded_bs - orig_bs), value=0 + ) + else: + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) + + kwargs = {} + if htorch.utils.internal.is_lazy(): + kwargs["bypass_hpu_graphs"] = False + logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -1677,7 +1772,9 @@ class 
FlashCausalLM(Model): hpu_attention_meta=batch.hpu_attn_meta, **kwargs, ) - return logits, speculative_logits + return logits[:orig_bs], ( + speculative_logits[:orig_bs] if speculative_logits is not None else None + ) @tracer.start_as_current_span("generate_token") def generate_token( @@ -1690,9 +1787,16 @@ class FlashCausalLM(Model): start = time.time_ns() prefill = batch.prefilling if prefill: - batch.prepare_for_prefill() + if self.bucketing_ctx is not None: + batch.prepare_for_prefill( + self.bucketing_ctx.get_padded_prompt_seq_len(batch.max_input_length) + ) + else: + batch.prepare_for_prefill(batch.max_input_length) else: - batch.prepare_for_decode(self.dtype, self.use_contiguous_pa) + batch.prepare_for_decode( + self.dtype, self.use_contiguous_pa, self.bucketing_ctx + ) prefill_logprobs = batch.prefill_next_token_indices is not None # Update adapter indices for speculative tokens (if present) adapter_meta = batch.adapter_meta From 705cc0b6195c7a7572d85d7a3acff563e9af32d1 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Tue, 1 Apr 2025 23:57:07 -0700 Subject: [PATCH 25/35] multi-modality warmup Signed-off-by: Wang, Yi A --- .../models/flash_causal_lm.py | 16 +- .../models/flash_vlm_causal_lm.py | 153 +++++++++++++- .../models/mllama_causal_lm.py | 197 +++++++++++++++++- 3 files changed, 345 insertions(+), 21 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index ed8b658a..48165256 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -1487,7 +1487,6 @@ class FlashCausalLM(Model): if max_input_tokens is None: max_input_tokens = max_total_tokens - 1 - del _batch, batch self.kv_cache = [] empty_cache() @@ -1499,6 +1498,7 @@ class FlashCausalLM(Model): self.kv_cache_dtype, self.device, ) + self.bucketing_ctx = HPUBucketingContext( os.getenv("DECODE_MAX_BS", 128), # self.max_num_seqs, #TODO os.getenv("PREFILL_MAX_BS", 16), # self.max_num_prefill_seqs, #TODO @@ -1506,6 +1506,17 @@ class FlashCausalLM(Model): num_blocks * BLOCK_SIZE, ) self.bucketing_ctx.num_hpu_blocks = num_blocks + if os.getenv("SKIP_WARMUP_GRAPH", "false").lower() == "true": + logger.info("skip warmup hpu graph, not recommmended") + del _batch, batch + return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens + + self.warmup_hpu_graph(batch) + del _batch, batch + + return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens + + def warmup_hpu_graph(self, batch): warmup_times = 3 self.bucketing_ctx.generate_prompt_buckets() for i, (batch_size, seq_len) in enumerate( @@ -1513,14 +1524,13 @@ class FlashCausalLM(Model): ): for index in range(warmup_times): self.warmup_prefill(seq_len, batch_size) - self.bucketing_ctx.generate_decode_buckets(num_blocks) + self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks) for i, (batch_size, block_num) in enumerate( reversed(self.bucketing_ctx.decode_buckets) ): for index in range(warmup_times): self.warmup_decode(batch_size, block_num) synchronize(self.device) - return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens def warmup_prefill(self, prompt_len: int, bs: int): logger.info(f"warmup prefill seq {prompt_len} bs {bs}") diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py index 
208ab358..2f9de99f 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -11,13 +11,18 @@ from text_generation_server.pb import generate_pb2 from text_generation_server.models.flash_causal_lm import ( FlashCausalLMBatch, FlashCausalLM, + prepare_for_decode, ) -from text_generation_server.models.globals import PREFIX_CACHING +from text_generation_server.models.globals import PREFIX_CACHING, BLOCK_SIZE from loguru import logger from text_generation_server.utils.log import log_master from transformers import AutoProcessor from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata import habana_frameworks.torch as htorch +from text_generation_server.utils.import_utils import ( + synchronize, +) +import torch.nn.functional as F tracer = trace.get_tracer(__name__) @@ -375,6 +380,80 @@ class FlashVlmCausalLM(FlashCausalLM): def max_past(self) -> Optional[int]: return getattr(self.model.text_model, "max_past", None) + def warmup_decode( + self, batch_size: int, block_num: int, batch: FlashVlmCausalLMBatch + ): + logger.info(f"warmup decode bs {batch_size} block_num {block_num}") + input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device) + position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device) + if batch.position_ids is not None and batch.position_ids.dim() == 2: + # qwen2_vl and qwen2_5_vl case + position_ids = position_ids.unsqueeze(-1).repeat( + (1, batch.position_ids.shape[-1]) + ) + blocks = [block_num // batch_size for _ in range(batch_size)] + blocks[0] += block_num % batch_size + past_len = [] + block_tables = [] + slots = [] + start_idx = 0 + + # fetch the last blocked to warmup block num + for i in range(batch_size): + block_array = list(range(start_idx, start_idx + blocks[i])) + slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1) + block_tables.append(block_array) + past_len.append(blocks[i] * BLOCK_SIZE - 1) + start_idx += blocks[i] + slots = torch.tensor(slots, dtype=torch.int64, device=self.device) + input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device) + cache_lengths_tensor = torch.tensor( + past_len, dtype=torch.int32, device=self.device + ) + cu_seqlen_prefill = torch.zeros( + batch_size + 1, device=self.device, dtype=torch.int32 + ) + torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:]) + + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) + + hpu_attention_meta = prepare_for_decode( + self.dtype, + self.use_contiguous_pa, + self.device, + slots, + block_tables, + batch_size, + bucketing_ctx=None, + ) + # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. 
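# --- Illustrative sketch, not part of the diff above (toy BLOCK_SIZE assumed) --
# How warmup_decode spreads `block_num` KV-cache blocks over `batch_size` dummy
# requests: every request gets block_num // batch_size blocks, the remainder
# goes to request 0, and each request's last slot and past length follow from
# its final block id.
BLOCK_SIZE = 4
batch_size, block_num = 3, 8

blocks = [block_num // batch_size for _ in range(batch_size)]
blocks[0] += block_num % batch_size                      # -> [4, 2, 2]

block_tables, slots, past_len, start_idx = [], [], [], 0
for i in range(batch_size):
    block_array = list(range(start_idx, start_idx + blocks[i]))
    block_tables.append(block_array)                     # contiguous block ids
    slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)  # last slot of last block
    past_len.append(blocks[i] * BLOCK_SIZE - 1)          # tokens already in cache
    start_idx += blocks[i]

assert block_tables == [[0, 1, 2, 3], [4, 5], [6, 7]]
assert slots == [15, 23, 31] and past_len == [15, 7, 7]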
+ self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=self.kv_cache, + slots=slots, + seqlen=trim_seqlen_metadata(seqlen), + lm_head_indices=None, + adapter_data=None, + hpu_attention_meta=hpu_attention_meta, + ) + + def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch): + warmup_times = 3 + # only warmup decode, for prefill, image pixal size may change, make the warmup useless + self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks) + for i, (batch_size, block_num) in enumerate( + reversed(self.bucketing_ctx.decode_buckets) + ): + for index in range(warmup_times): + self.warmup_decode(batch_size, block_num, batch) + synchronize(self.device) + def forward( self, batch: FlashVlmCausalLMBatch, @@ -450,17 +529,75 @@ class FlashVlmCausalLM(FlashCausalLM): kwargs = {} if htorch.utils.internal.is_lazy(): - kwargs["bypass_hpu_graphs"] = False + kwargs["bypass_hpu_graphs"] = batch.prefilling - seqlen = Seqlen( - input_lengths=input_lengths, - cache_lengths=cache_lengths_tensor, - cu_seqlen_q=cu_seqlen_prefill, - ) if batch.prefill_cache_indices is not None: slots_pad = torch.zeros_like(input_ids) slots_pad[batch.prefill_cache_indices] = slots slots = slots_pad + if self.bucketing_ctx is not None: + if batch.prefilling: + padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size( + input_lengths.shape[0] + ) + else: + padded_bs = self.bucketing_ctx.get_padded_decode_batch_size( + input_lengths.shape[0] + ) + else: + padded_bs = input_lengths.shape[0] + if padded_bs != input_lengths.shape[0]: + orig_bs = input_lengths.shape[0] + padded_input_lengths = F.pad( + input_lengths, + (0, padded_bs - orig_bs), + value=0, + ) + padded_cache_lengths_tensor = F.pad( + cache_lengths_tensor, + (0, padded_bs - orig_bs), + value=0, + ) + if cu_seqlen_prefill is not None: + cu_seqlen_prefill = torch.zeros( + padded_bs + 1, device=self.device, dtype=torch.int32 + ) + torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:]) + seqlen = Seqlen( + input_lengths=padded_input_lengths, + cache_lengths=padded_cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) + input_seq = input_ids.view(orig_bs, -1) + input_ids = F.pad( + input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0 + ) + if position_ids.dim() == 2: + # qwen2_vl and qwen2_5_vl case + position_ids = F.pad( + position_ids, + (0, 0, 0, (padded_bs - orig_bs) * input_seq.shape[-1]), + value=1, + ) + else: + position_ids = F.pad( + position_ids, + (0, (padded_bs - orig_bs) * input_seq.shape[-1]), + value=1, + ) + slots = F.pad( + slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1 + ) + if lm_head_indices is not None: + lm_head_indices = F.pad( + lm_head_indices, (0, padded_bs - orig_bs), value=0 + ) + else: + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -476,8 +613,6 @@ class FlashVlmCausalLM(FlashCausalLM): image_grid_thw=batch.image_grid_thw, **kwargs, ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None if batch.pixel_values is not None: batch.pixel_values = None if batch.pixel_attention_mask is not None: diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index e034ed49..55d80ca5 100644 --- 
a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -11,7 +11,9 @@ from opentelemetry import trace from transformers import ( PreTrainedTokenizerBase, ) - +from text_generation_server.models.flash_causal_lm import ( + prepare_for_decode, +) from text_generation_server.models.flash_vlm_causal_lm import ( FlashVlmCausalLMBatch, FlashVlmCausalLM, @@ -19,6 +21,12 @@ from text_generation_server.models.flash_vlm_causal_lm import ( from text_generation_server.pb import generate_pb2 from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata import habana_frameworks.torch as htorch +from loguru import logger +from text_generation_server.models.globals import BLOCK_SIZE +from text_generation_server.utils.import_utils import ( + synchronize, +) +import torch.nn.functional as F tracer = trace.get_tracer(__name__) @@ -197,6 +205,131 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch): class FlashMllamaCausalLM(FlashVlmCausalLM): + def warmup_decode( + self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch + ): + logger.info(f"warmup decode bs {batch_size} block_num {block_num}") + input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device) + position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device) + blocks = [block_num // batch_size for _ in range(batch_size)] + blocks[0] += block_num % batch_size + past_len = [] + block_tables = [] + slots = [] + start_idx = 0 + + # fetch the last blocked to warmup block num + for i in range(batch_size): + block_array = list(range(start_idx, start_idx + blocks[i])) + slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1) + block_tables.append(block_array) + past_len.append(blocks[i] * BLOCK_SIZE - 1) + start_idx += blocks[i] + slots = torch.tensor(slots, dtype=torch.int64, device=self.device) + input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device) + cache_lengths_tensor = torch.tensor( + past_len, dtype=torch.int32, device=self.device + ) + cu_seqlen_prefill = torch.zeros( + batch_size + 1, device=self.device, dtype=torch.int32 + ) + torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:]) + + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) + + hpu_attention_meta = prepare_for_decode( + self.dtype, + self.use_contiguous_pa, + self.device, + slots, + block_tables, + batch_size, + bucketing_ctx=None, + ) + # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. 
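# A minimal illustrative sketch (not part of the diff): how `cu_seqlen_prefill`
# above is derived from the per-request input lengths with a prefix sum. During
# decode warmup each request contributes a single token, so entry i..i+1 simply
# brackets request i in the flattened batch.
import torch

input_lengths = torch.ones(3, dtype=torch.int32)      # 3 requests, example value
cu_seqlen = torch.zeros(3 + 1, dtype=torch.int32)
torch.cumsum(input_lengths, -1, out=cu_seqlen[1:])
assert cu_seqlen.tolist() == [0, 1, 2, 3]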
+ self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=self.kv_cache, + slots=slots, + seqlen=trim_seqlen_metadata(seqlen), + lm_head_indices=None, + adapter_data=None, + hpu_attention_meta=hpu_attention_meta, + cross_attention_states=batch.cross_attention_states, + image_indices=batch.image_indices[:], + ) + + def warmup_prefill(self, prompt_len: int, bs: int, batch: FlashMllamaCausalLMBatch): + logger.info(f"warmup prefill seq {prompt_len} bs {bs}") + input_ids = torch.zeros( + prompt_len, dtype=torch.int64, device=self.device + ).repeat(bs) + position_ids = torch.arange( + prompt_len, dtype=torch.int32, device=self.device + ).repeat(bs) + max_bt = (prompt_len // BLOCK_SIZE + 1) * bs + block_tables = torch.arange( + max_bt, dtype=torch.int32, device=self.device + ).reshape(bs, -1) + slot_acc = [] + for i in range(bs): + slots = [] + for b in block_tables[i]: + slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)) + slot_acc.extend(slots[:prompt_len]) + slots = torch.tensor(slot_acc, dtype=torch.int64, device=self.device) + + input_lengths = ( + torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len + ) + cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device) + cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32) + torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:]) + + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) + lm_head_indices = input_lengths - 1 + + # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. + self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=self.kv_cache, + slots=slots, + seqlen=trim_seqlen_metadata(seqlen), + lm_head_indices=lm_head_indices, + cross_attention_states=batch.cross_attention_states, + adapter_data=None, + hpu_attention_meta=None, + image_indices=batch.image_indices[:], + ) + + def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch): + warmup_times = 3 + self.bucketing_ctx.generate_prompt_buckets() + for i, (batch_size, seq_len) in enumerate( + reversed(self.bucketing_ctx.prompt_buckets) + ): + for index in range(warmup_times): + self.warmup_prefill(seq_len, batch_size, batch) + self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks) + for i, (batch_size, block_num) in enumerate( + reversed(self.bucketing_ctx.decode_buckets) + ): + for index in range(warmup_times): + self.warmup_decode(batch_size, block_num, batch) + synchronize(self.device) + def forward( self, batch: FlashMllamaCausalLMBatch, @@ -263,12 +396,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): # This makes sure the max_s for the decode pass is correct. 
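# A minimal illustrative sketch (not part of the diff): the slot layout produced
# by the prefill warmup above. prompt_len=256 and batch_size=2 are example
# values; BLOCK_SIZE=128 is assumed.
BLOCK_SIZE, prompt_len, batch_size = 128, 256, 2
blocks_per_req = prompt_len // BLOCK_SIZE + 1          # 3 blocks reserved per request
slot_acc = []
for i in range(batch_size):
    blocks = range(i * blocks_per_req, (i + 1) * blocks_per_req)
    slots = [s for b in blocks for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)]
    slot_acc.extend(slots[:prompt_len])                # only the first prompt_len slots get written
assert len(slot_acc) == batch_size * prompt_len
assert slot_acc[255] == 255 and slot_acc[256] == 384   # request 1 starts at block 3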
max_s = min(self.max_past(), max_s) - seqlen = Seqlen( - input_lengths=input_lengths, - cache_lengths=cache_lengths_tensor, - cu_seqlen_q=cu_seqlen_prefill, - ) - if batch.pixel_values is not None: cross_attention_states = self.model.vision_forward( pixel_values=batch.pixel_values, @@ -286,6 +413,60 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): slots_pad = torch.zeros_like(input_ids) slots_pad[batch.prefill_cache_indices] = slots slots = slots_pad + + if self.bucketing_ctx is not None: + if batch.prefilling: + padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size( + input_lengths.shape[0] + ) + else: + padded_bs = self.bucketing_ctx.get_padded_decode_batch_size( + input_lengths.shape[0] + ) + else: + padded_bs = input_lengths.shape[0] + if padded_bs != input_lengths.shape[0]: + orig_bs = input_lengths.shape[0] + padded_input_lengths = F.pad( + input_lengths, + (0, padded_bs - orig_bs), + value=0, + ) + padded_cache_lengths_tensor = F.pad( + cache_lengths_tensor, + (0, padded_bs - orig_bs), + value=0, + ) + if cu_seqlen_prefill is not None: + cu_seqlen_prefill = torch.zeros( + padded_bs + 1, device=self.device, dtype=torch.int32 + ) + torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:]) + seqlen = Seqlen( + input_lengths=padded_input_lengths, + cache_lengths=padded_cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) + input_seq = input_ids.view(orig_bs, -1) + input_ids = F.pad( + input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0 + ) + position_ids = F.pad( + position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1 + ) + slots = F.pad( + slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1 + ) + if lm_head_indices is not None: + lm_head_indices = F.pad( + lm_head_indices, (0, padded_bs - orig_bs), value=0 + ) + else: + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -301,8 +482,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): image_indices=batch.image_indices[:], **kwargs, ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None if batch.pixel_values is not None: batch.pixel_values = None return logits, speculative_logits From a84da5b698d6c20c5b18d983f1fc8de959d6635d Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Wed, 2 Apr 2025 00:56:15 -0700 Subject: [PATCH 26/35] optimize code Signed-off-by: Wang, Yi A --- .../models/flash_causal_lm.py | 75 ++++--------------- 1 file changed, 15 insertions(+), 60 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index 48165256..52a2ea61 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -328,6 +328,8 @@ class FlashCausalLMBatch(Batch): ### Deactivating it by default seems like the best course. 
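# A minimal illustrative sketch (not part of the diff): the bucket padding that
# the forward passes above apply before running the HPU graph. A decode batch of
# 3 requests is padded to an assumed bucket size of 4 so tensor shapes stay
# static; the caller later drops the padded rows via `logits[:orig_bs]`.
import torch
import torch.nn.functional as F

input_lengths = torch.ones(3, dtype=torch.int32)
input_ids = torch.tensor([11, 12, 13], dtype=torch.int64)
padded_bs, orig_bs = 4, input_lengths.shape[0]
padded_input_lengths = F.pad(input_lengths, (0, padded_bs - orig_bs), value=0)
padded_input_ids = F.pad(input_ids, (0, padded_bs - orig_bs), value=0)
assert padded_input_lengths.tolist() == [1, 1, 1, 0]
assert padded_input_ids.tolist() == [11, 12, 13, 0]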
if not REQUEST_LOGPROBS: r.prefill_logprobs = False + else: + assert False, "prefill_logprobs not supported yet" # request id -> idx in list mapping requests_idx_mapping[r.id] = i @@ -1847,10 +1849,6 @@ class FlashCausalLM(Model): if prefill_logprobs else speculative_logits ) - if len(batch) > 1 and prefill_logprobs: - # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs - # When batch == 1, we will just use the batch.input_ids values directly - prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) else: prefill_logprobs = None next_token_logits = out @@ -1900,19 +1898,6 @@ class FlashCausalLM(Model): batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[ indices ] - - # Zipped iterator - iterator = zip( - batch.requests, - batch.prompt_lengths, - batch.cache_lengths, - batch.input_lengths, - batch.all_input_ids, - accepted_ids, - current_prefilling_mask, - batch.prefilling_mask, - ) - # We do two for loops as the first one can run completely asynchronously from the GPU while for the second # one, we need to first do a HPU <-> CPU sync # It is faster if we delay this sync for the maximum amount of time @@ -1921,38 +1906,8 @@ class FlashCausalLM(Model): # Cumulative length cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1) torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:]) - cumulative_length = 0 - for i, ( - request, - prompt_length, - cache_length, - input_length, - all_input_ids, - n_accepted_ids, - request_was_prefilling, - request_is_prefilling, - ) in enumerate(iterator): - # Used to gather prefill logprobs - # Copy batch.all_input_ids_tensor to prefill_token_indices - if request.prefill_logprobs and request_was_prefilling: - # Indexing metadata - out_start_index = batch.prefill_cu_outlens[i] - out_end_index = batch.prefill_cu_outlens[i + 1] - - # Logprobs generated by the model are for the next token - # So we need to translate the id tensor by 1 - ids = batch.all_input_ids_tensor[ - i, cache_length + 1 : cache_length + input_length + 1 - ] - if len(batch) > 1: - prefill_tokens_indices[out_start_index:out_end_index] = ids - else: - # Set prefill_tokens_indices to the correct slice - prefill_tokens_indices = ids - - # If the device does not support triton, we copy one by one - if not request_is_prefilling: - # Only save tokens if we are done prefilling for this request + if speculative_logits is not None: + for i in range(len(batch)): batch.all_input_ids_tensor[ i, batch.cache_lengths_tensor[i] @@ -1960,7 +1915,17 @@ class FlashCausalLM(Model): + batch.input_lengths[i] + accepted_ids[i], ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]] - cumulative_length += input_length + else: + index = batch.cache_lengths_tensor + batch.input_lengths_tensor + batch_idx = torch.arange( + 0, + batch.all_input_ids_tensor.shape[0], + dtype=torch.long, + device=batch.input_lengths_tensor.device, + ) + batch.all_input_ids_tensor.index_put_( + (batch_idx, index.long()), next_input_ids + ) # Update values # These values can be updated without a HPU -> CPU sync @@ -1976,16 +1941,6 @@ class FlashCausalLM(Model): batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor) batch.slot_indices += accepted_ids - if prefill and prefill_logprobs: - # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size)) - torch.log_softmax(out, -1, out=out) - prefill_logprobs_tensor = out - prefill_logprobs = torch.gather( - prefill_logprobs_tensor, 1, 
prefill_tokens_indices.view(-1, 1) - ) - # HPU <-> CPU sync - prefill_logprobs = prefill_logprobs.view(-1).tolist() - # Does a HPU <-> CPU sync internally if prefill and finished_prefilling: # adjust segment lengths to account for all request lengths being 1 during decoding From 8591687561e197cd6722b5277b77d151642a71c5 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Wed, 2 Apr 2025 19:11:35 -0700 Subject: [PATCH 27/35] refine log and fix some issue Signed-off-by: Wang, Yi A --- .../models/flash_causal_lm.py | 35 +++++++++++-------- .../models/flash_vlm_causal_lm.py | 29 ++++++++++----- .../models/mllama_causal_lm.py | 31 ++++++++++------ 3 files changed, 60 insertions(+), 35 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index 52a2ea61..334f004e 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -1508,7 +1508,7 @@ class FlashCausalLM(Model): num_blocks * BLOCK_SIZE, ) self.bucketing_ctx.num_hpu_blocks = num_blocks - if os.getenv("SKIP_WARMUP_GRAPH", "false").lower() == "true": + if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true": logger.info("skip warmup hpu graph, not recommmended") del _batch, batch return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens @@ -1524,23 +1524,26 @@ class FlashCausalLM(Model): for i, (batch_size, seq_len) in enumerate( reversed(self.bucketing_ctx.prompt_buckets) ): + log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}") for index in range(warmup_times): - self.warmup_prefill(seq_len, batch_size) + self.warmup_prefill(seq_len, batch_size, batch) self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks) for i, (batch_size, block_num) in enumerate( reversed(self.bucketing_ctx.decode_buckets) ): + log_master( + logger.info, f"warmup decode bs {batch_size} block_num {block_num}" + ) for index in range(warmup_times): - self.warmup_decode(batch_size, block_num) + self.warmup_decode(batch_size, block_num, batch) synchronize(self.device) - def warmup_prefill(self, prompt_len: int, bs: int): - logger.info(f"warmup prefill seq {prompt_len} bs {bs}") + def warmup_prefill(self, prompt_len: int, bs: int, batch: FlashCausalLMBatch): input_ids = torch.zeros( - prompt_len, dtype=torch.int64, device=self.device + prompt_len, dtype=batch.input_ids.dtype, device=self.device ).repeat(bs) position_ids = torch.arange( - prompt_len, dtype=torch.int32, device=self.device + prompt_len, dtype=batch.position_ids.dtype, device=self.device ).repeat(bs) max_bt = (prompt_len // BLOCK_SIZE + 1) * bs block_tables = torch.arange( @@ -1552,7 +1555,7 @@ class FlashCausalLM(Model): for b in block_tables[i]: slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)) slot_acc.extend(slots[:prompt_len]) - slots = torch.tensor(slot_acc, dtype=torch.int64, device=self.device) + slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device) input_lengths = ( torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len @@ -1581,10 +1584,13 @@ class FlashCausalLM(Model): hpu_attention_meta=None, ) - def warmup_decode(self, batch_size: int, block_num: int): - logger.info(f"warmup decode bs {batch_size} block_num {block_num}") - input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device) - position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device) + def 
warmup_decode(self, batch_size: int, block_num: int, batch: FlashCausalLMBatch): + input_ids = torch.zeros( + batch_size, dtype=batch.input_ids.dtype, device=self.device + ) + position_ids = torch.arange( + batch_size, dtype=batch.position_ids.dtype, device=self.device + ) blocks = [block_num // batch_size for _ in range(batch_size)] blocks[0] += block_num % batch_size past_len = [] @@ -1599,7 +1605,7 @@ class FlashCausalLM(Model): block_tables.append(block_array) past_len.append(blocks[i] * BLOCK_SIZE - 1) start_idx += blocks[i] - slots = torch.tensor(slots, dtype=torch.int64, device=self.device) + slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device) input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device) cache_lengths_tensor = torch.tensor( past_len, dtype=torch.int32, device=self.device @@ -1725,7 +1731,6 @@ class FlashCausalLM(Model): padded_bs = input_lengths.shape[0] orig_bs = input_lengths.shape[0] if padded_bs != input_lengths.shape[0]: - orig_bs = input_lengths.shape[0] padded_input_lengths = F.pad( input_lengths, (0, padded_bs - orig_bs), @@ -1754,7 +1759,7 @@ class FlashCausalLM(Model): position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1 ) slots = F.pad( - slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1 + slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0 ) if lm_head_indices is not None: lm_head_indices = F.pad( diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py index 2f9de99f..725e7517 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -383,9 +383,12 @@ class FlashVlmCausalLM(FlashCausalLM): def warmup_decode( self, batch_size: int, block_num: int, batch: FlashVlmCausalLMBatch ): - logger.info(f"warmup decode bs {batch_size} block_num {block_num}") - input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device) - position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device) + input_ids = torch.zeros( + batch_size, dtype=batch.input_ids.dtype, device=self.device + ) + position_ids = torch.arange( + batch_size, dtype=batch.position_ids.dtype, device=self.device + ) if batch.position_ids is not None and batch.position_ids.dim() == 2: # qwen2_vl and qwen2_5_vl case position_ids = position_ids.unsqueeze(-1).repeat( @@ -405,7 +408,7 @@ class FlashVlmCausalLM(FlashCausalLM): block_tables.append(block_array) past_len.append(blocks[i] * BLOCK_SIZE - 1) start_idx += blocks[i] - slots = torch.tensor(slots, dtype=torch.int64, device=self.device) + slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device) input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device) cache_lengths_tensor = torch.tensor( past_len, dtype=torch.int32, device=self.device @@ -438,9 +441,12 @@ class FlashVlmCausalLM(FlashCausalLM): kv_cache=self.kv_cache, slots=slots, seqlen=trim_seqlen_metadata(seqlen), - lm_head_indices=None, - adapter_data=None, hpu_attention_meta=hpu_attention_meta, + lm_head_indices=None, + pixel_values=None, + pixel_attention_mask=None, + image_sizes=None, + image_grid_thw=None, ) def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch): @@ -450,6 +456,9 @@ class FlashVlmCausalLM(FlashCausalLM): for i, (batch_size, block_num) in enumerate( reversed(self.bucketing_ctx.decode_buckets) ): + log_master( + 
logger.info, f"warmup decode bs {batch_size} block_num {block_num}" + ) for index in range(warmup_times): self.warmup_decode(batch_size, block_num, batch) synchronize(self.device) @@ -546,8 +555,8 @@ class FlashVlmCausalLM(FlashCausalLM): ) else: padded_bs = input_lengths.shape[0] + orig_bs = input_lengths.shape[0] if padded_bs != input_lengths.shape[0]: - orig_bs = input_lengths.shape[0] padded_input_lengths = F.pad( input_lengths, (0, padded_bs - orig_bs), @@ -586,7 +595,7 @@ class FlashVlmCausalLM(FlashCausalLM): value=1, ) slots = F.pad( - slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1 + slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0 ) if lm_head_indices is not None: lm_head_indices = F.pad( @@ -621,4 +630,6 @@ class FlashVlmCausalLM(FlashCausalLM): batch.image_sizes = None if batch.image_grid_thw is not None: batch.image_grid_thw = None - return logits, speculative_logits + return logits[:orig_bs], ( + speculative_logits[:orig_bs] if speculative_logits is not None else None + ) diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index 55d80ca5..acd5d9a5 100644 --- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -27,6 +27,7 @@ from text_generation_server.utils.import_utils import ( synchronize, ) import torch.nn.functional as F +from text_generation_server.utils.log import log_master tracer = trace.get_tracer(__name__) @@ -208,9 +209,12 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): def warmup_decode( self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch ): - logger.info(f"warmup decode bs {batch_size} block_num {block_num}") - input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device) - position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device) + input_ids = torch.zeros( + batch_size, dtype=batch.input_ids.dtype, device=self.device + ) + position_ids = torch.arange( + batch_size, dtype=batch.position_ids.dtype, device=self.device + ) blocks = [block_num // batch_size for _ in range(batch_size)] blocks[0] += block_num % batch_size past_len = [] @@ -225,7 +229,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): block_tables.append(block_array) past_len.append(blocks[i] * BLOCK_SIZE - 1) start_idx += blocks[i] - slots = torch.tensor(slots, dtype=torch.int64, device=self.device) + slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device) input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device) cache_lengths_tensor = torch.tensor( past_len, dtype=torch.int32, device=self.device @@ -266,12 +270,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): ) def warmup_prefill(self, prompt_len: int, bs: int, batch: FlashMllamaCausalLMBatch): - logger.info(f"warmup prefill seq {prompt_len} bs {bs}") input_ids = torch.zeros( - prompt_len, dtype=torch.int64, device=self.device + prompt_len, dtype=batch.input_ids.dtype, device=self.device ).repeat(bs) position_ids = torch.arange( - prompt_len, dtype=torch.int32, device=self.device + prompt_len, dtype=batch.position_ids.dtype, device=self.device ).repeat(bs) max_bt = (prompt_len // BLOCK_SIZE + 1) * bs block_tables = torch.arange( @@ -283,7 +286,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): for b in block_tables[i]: slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)) slot_acc.extend(slots[:prompt_len]) - slots = 
torch.tensor(slot_acc, dtype=torch.int64, device=self.device) + slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device) input_lengths = ( torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len @@ -320,12 +323,16 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): for i, (batch_size, seq_len) in enumerate( reversed(self.bucketing_ctx.prompt_buckets) ): + log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}") for index in range(warmup_times): self.warmup_prefill(seq_len, batch_size, batch) self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks) for i, (batch_size, block_num) in enumerate( reversed(self.bucketing_ctx.decode_buckets) ): + log_master( + logger.info, f"warmup decode bs {batch_size} block_num {block_num}" + ) for index in range(warmup_times): self.warmup_decode(batch_size, block_num, batch) synchronize(self.device) @@ -425,8 +432,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): ) else: padded_bs = input_lengths.shape[0] + orig_bs = input_lengths.shape[0] if padded_bs != input_lengths.shape[0]: - orig_bs = input_lengths.shape[0] padded_input_lengths = F.pad( input_lengths, (0, padded_bs - orig_bs), @@ -455,7 +462,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1 ) slots = F.pad( - slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1 + slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0 ) if lm_head_indices is not None: lm_head_indices = F.pad( @@ -484,4 +491,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): ) if batch.pixel_values is not None: batch.pixel_values = None - return logits, speculative_logits + return logits[:orig_bs], ( + speculative_logits[:orig_bs] if speculative_logits is not None else None + ) From 29703dbd274d0c86f06fe7dac59dbbd71719cea9 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Fri, 4 Apr 2025 05:42:59 -0700 Subject: [PATCH 28/35] fix warmup issue for mllama Signed-off-by: Wang, Yi A --- .../models/custom_modeling/flash_mllama.py | 58 +--------- .../models/flash_causal_lm.py | 24 ++-- .../models/mllama_causal_lm.py | 104 ++++++++++++++---- 3 files changed, 100 insertions(+), 86 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py index 216642e0..421a0a65 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py @@ -681,11 +681,10 @@ class MllamaTextCrossAttention(nn.Module): # bsz, q_len, _ = hidden_states.size() ( cross_attention_states, - cu_seqlen_q, - cu_seqlen_k, + cross_attention_len, indices, ) = cross_attention_states - bs = cu_seqlen_q.size(0) - 1 + bs = cross_attention_len.size(0) query_states = self.q_proj(hidden_states) query_states = query_states.view(bs, -1, self.num_heads, self.head_size) query_states = self.q_norm(query_states) @@ -814,8 +813,6 @@ class FlashLlamaCrossLayer(torch.nn.Module): indices = cross_attention_states[-1] out_hidden_states = hidden_states[:] - if len(indices) > 0: - assert max(indices) < hidden_states.shape[0] hidden_states = hidden_states[indices] residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -914,59 +911,14 @@ class FlashMllamaForConditionalGeneration(nn.Module): hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: 
Optional[torch.Tensor], adapter_data: Optional[torch.Tensor] = None, - # XXX: Putting these as optional so that the cuda warmup calls can go through. cross_attention_states: Optional[torch.Tensor] = None, - image_indices=None, + indices=None, + cross_attention_len: Optional[torch.Tensor] = None, ): if cross_attention_states is not None: - seqlen_q = len(image_indices) - n_images = cross_attention_states.shape[0] - seqlen_k = cross_attention_states.shape[1] - device = cross_attention_states.device - if cu_seqlen_prefill is not None: - offset = 0 - cu_q = [] - indices = [] - for index in image_indices: - cu_q.append(offset) - length = seqlen.input_lengths[index].item() - assert index < seqlen.cu_seqlen_q.shape[0] - input_ids_offset = seqlen.cu_seqlen_q[index] - indices.extend(range(input_ids_offset, input_ids_offset + length)) - offset += length - cu_q.append(offset) - cu_seqlen_q = torch.Tensor(cu_q).to(device=device, dtype=torch.int32) - - assert max(indices) < input_ids.shape[0] - - cu_seqlen_k = ( - torch.arange( - n_images + 1, - device=device, - dtype=torch.int32, - ) - * seqlen_k - ) - else: - cu_seqlen_q = torch.arange( - seqlen_q + 1, device=device, dtype=torch.int32 - ) - seqlen_k = cross_attention_states.shape[1] - n_images = cross_attention_states.shape[0] - cu_seqlen_k = ( - torch.arange( - n_images + 1, - device=device, - dtype=torch.int32, - ) - * seqlen_k - ) - indices = image_indices[:] - cross_attention_states = ( cross_attention_states, - cu_seqlen_q, - cu_seqlen_k, + cross_attention_len, indices, ) diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index 334f004e..23a40016 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -1538,19 +1538,21 @@ class FlashCausalLM(Model): self.warmup_decode(batch_size, block_num, batch) synchronize(self.device) - def warmup_prefill(self, prompt_len: int, bs: int, batch: FlashCausalLMBatch): + def warmup_prefill( + self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch + ): input_ids = torch.zeros( prompt_len, dtype=batch.input_ids.dtype, device=self.device - ).repeat(bs) + ).repeat(batch_size) position_ids = torch.arange( prompt_len, dtype=batch.position_ids.dtype, device=self.device - ).repeat(bs) - max_bt = (prompt_len // BLOCK_SIZE + 1) * bs + ).repeat(batch_size) + max_bt = (prompt_len // BLOCK_SIZE + 1) * batch_size block_tables = torch.arange( max_bt, dtype=torch.int32, device=self.device - ).reshape(bs, -1) + ).reshape(batch_size, -1) slot_acc = [] - for i in range(bs): + for i in range(batch_size): slots = [] for b in block_tables[i]: slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)) @@ -1558,10 +1560,14 @@ class FlashCausalLM(Model): slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device) input_lengths = ( - torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len + torch.ones(batch_size, dtype=torch.int32, device=self.device) * prompt_len + ) + cache_lengths_tensor = torch.zeros( + batch_size, dtype=torch.int32, device=self.device + ) + cu_seqlen_prefill = torch.zeros( + batch_size + 1, device=self.device, dtype=torch.int32 ) - cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device) - cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32) torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:]) seqlen = Seqlen( diff 
--git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index acd5d9a5..940ee1b0 100644 --- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -205,6 +205,24 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch): return batch +def generate_cross_attention_states( + cross_attention_states, image_indices, seqlen, pad_seq_len, prefilling +): + if cross_attention_states is None: + return None, None, None + device = cross_attention_states.device + indices_list = [] + if prefilling: + for i in image_indices: + indices_list.append( + torch.arange(pad_seq_len * i, pad_seq_len * (i + 1), device=device) + ) + indices = torch.cat(indices_list, dim=0) + else: + indices = image_indices[:] + return indices, seqlen.input_lengths.index_select(0, image_indices) + + class FlashMllamaCausalLM(FlashVlmCausalLM): def warmup_decode( self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch @@ -255,6 +273,12 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): bucketing_ctx=None, ) # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. + image_indices = torch.tensor(batch.image_indices, device=self.device) + image_indices = image_indices.repeat(batch_size) + cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1) + indices, cross_attention_len = generate_cross_attention_states( + cross_attention_states, image_indices, seqlen, 1, False + ) self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -262,26 +286,29 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): kv_cache=self.kv_cache, slots=slots, seqlen=trim_seqlen_metadata(seqlen), + hpu_attention_meta=hpu_attention_meta, lm_head_indices=None, adapter_data=None, - hpu_attention_meta=hpu_attention_meta, - cross_attention_states=batch.cross_attention_states, - image_indices=batch.image_indices[:], + cross_attention_states=cross_attention_states, + indices=indices, + cross_attention_len=cross_attention_len, ) - def warmup_prefill(self, prompt_len: int, bs: int, batch: FlashMllamaCausalLMBatch): + def warmup_prefill( + self, prompt_len: int, batch_size: int, batch: FlashMllamaCausalLMBatch + ): input_ids = torch.zeros( prompt_len, dtype=batch.input_ids.dtype, device=self.device - ).repeat(bs) + ).repeat(batch_size) position_ids = torch.arange( prompt_len, dtype=batch.position_ids.dtype, device=self.device - ).repeat(bs) - max_bt = (prompt_len // BLOCK_SIZE + 1) * bs + ).repeat(batch_size) + max_bt = (prompt_len // BLOCK_SIZE + 1) * batch_size block_tables = torch.arange( max_bt, dtype=torch.int32, device=self.device - ).reshape(bs, -1) + ).reshape(batch_size, -1) slot_acc = [] - for i in range(bs): + for i in range(batch_size): slots = [] for b in block_tables[i]: slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)) @@ -289,10 +316,14 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device) input_lengths = ( - torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len + torch.ones(batch_size, dtype=torch.int32, device=self.device) * prompt_len + ) + cache_lengths_tensor = torch.zeros( + batch_size, dtype=torch.int32, device=self.device + ) + cu_seqlen_prefill = torch.zeros( + batch_size + 1, device=self.device, dtype=torch.int32 ) - cache_lengths_tensor = torch.zeros(bs, 
dtype=torch.int32, device=self.device) - cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32) torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:]) seqlen = Seqlen( @@ -303,6 +334,12 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): lm_head_indices = input_lengths - 1 # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. + image_indices = torch.tensor(batch.image_indices, device=self.device) + image_indices = image_indices.repeat(batch_size) + cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1) + indices, cross_attention_len = generate_cross_attention_states( + cross_attention_states, image_indices, seqlen, prompt_len, True + ) self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -310,11 +347,12 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): kv_cache=self.kv_cache, slots=slots, seqlen=trim_seqlen_metadata(seqlen), - lm_head_indices=lm_head_indices, - cross_attention_states=batch.cross_attention_states, - adapter_data=None, hpu_attention_meta=None, - image_indices=batch.image_indices[:], + lm_head_indices=lm_head_indices, + adapter_data=None, + cross_attention_states=cross_attention_states, + indices=indices, + cross_attention_len=cross_attention_len, ) def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch): @@ -433,6 +471,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): else: padded_bs = input_lengths.shape[0] orig_bs = input_lengths.shape[0] + padded_input_len = input_ids.view(orig_bs, -1).shape[-1] + image_indices = torch.tensor(batch.image_indices, device=self.device) if padded_bs != input_lengths.shape[0]: padded_input_lengths = F.pad( input_lengths, @@ -454,26 +494,41 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): cache_lengths=padded_cache_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, ) - input_seq = input_ids.view(orig_bs, -1) + input_ids = F.pad( - input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0 + input_ids, (0, (padded_bs - orig_bs) * padded_input_len), value=0 ) position_ids = F.pad( - position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1 - ) - slots = F.pad( - slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0 + position_ids, (0, (padded_bs - orig_bs) * padded_input_len), value=1 ) + slots = F.pad(slots, (0, (padded_bs - orig_bs) * padded_input_len), value=0) if lm_head_indices is not None: lm_head_indices = F.pad( lm_head_indices, (0, padded_bs - orig_bs), value=0 ) + if cross_attention_states is not None: + cross_attention_states = F.pad( + cross_attention_states, + (0, 0, 0, 0, 0, (padded_bs - orig_bs)), + value=0, + ) + if len(image_indices) != 0: + pad_indices = torch.arange(orig_bs, padded_bs, device=self.device) + image_indices = torch.cat((image_indices, pad_indices), dim=0) else: seqlen = Seqlen( input_lengths=input_lengths, cache_lengths=cache_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, ) + + indices, cross_attention_len = generate_cross_attention_states( + cross_attention_states, + image_indices, + seqlen, + padded_input_len, + batch.prefilling, + ) logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -483,10 +538,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): seqlen=trim_seqlen_metadata(seqlen), hpu_attention_meta=batch.hpu_attn_meta, lm_head_indices=lm_head_indices, - cross_attention_states=cross_attention_states, # TODO list adapter_data=None, - image_indices=batch.image_indices[:], + 
cross_attention_states=cross_attention_states, + indices=indices, + cross_attention_len=cross_attention_len, **kwargs, ) if batch.pixel_values is not None: From cd900c3b729c9fecea930b34ce26972d2e527e82 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Tue, 8 Apr 2025 19:56:10 -0700 Subject: [PATCH 29/35] pingpong optimization Signed-off-by: Wang, Yi A --- .../models/flash_causal_lm.py | 615 ++++++++---------- .../models/flash_vlm_causal_lm.py | 2 + .../models/mllama_causal_lm.py | 2 + 3 files changed, 274 insertions(+), 345 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index 23a40016..51adffc7 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -253,6 +253,9 @@ class FlashCausalLMBatch(Batch): hpu_attn_meta: Optional[HPUPagedAttentionMetadata] + next_token_logits: Optional[torch.Tensor] + speculative_logits: Optional[torch.Tensor] + def to_pb(self) -> generate_pb2.CachedBatch: return generate_pb2.CachedBatch( id=self.batch_id, @@ -490,6 +493,8 @@ class FlashCausalLMBatch(Batch): input_lengths_tensor=None, adapter_meta=None, hpu_attn_meta=None, + next_token_logits=None, + speculative_logits=None, ) @classmethod @@ -698,6 +703,8 @@ class FlashCausalLMBatch(Batch): speculative_ids=speculative_ids, adapter_meta=adapter_meta, hpu_attn_meta=None, + next_token_logits=None, + speculative_logits=None, ) @classmethod @@ -959,6 +966,8 @@ class FlashCausalLMBatch(Batch): speculative_ids=speculative_ids, adapter_meta=adapter_meta, hpu_attn_meta=None, + next_token_logits=None, + speculative_logits=None, ) def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx): @@ -1484,7 +1493,7 @@ class FlashCausalLM(Model): log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}") if max_total_tokens is None: - max_total_tokens = sum(batch.cache_lengths) + max_total_tokens = sum(batch.input_lengths) if max_input_tokens is None: max_input_tokens = max_total_tokens - 1 @@ -1531,6 +1540,8 @@ class FlashCausalLM(Model): for i, (batch_size, block_num) in enumerate( reversed(self.bucketing_ctx.decode_buckets) ): + if batch_size > block_num: + continue log_master( logger.info, f"warmup decode bs {batch_size} block_num {block_num}" ) @@ -1803,6 +1814,144 @@ class FlashCausalLM(Model): def generate_token( self, batches: List[FlashCausalLMBatch] ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]: + + # In order to pipeline any actions on CPU we perform the operation in 3 main stages: + # Stage 1. 
Collect next token ids of any previously started generations + prev_batches = [] + requests_to_generate = [] + for batch_id, batch in enumerate(batches): + if batch.next_token_logits is not None: + prefill = batch.prefilling + if batch.prefilling: + batch.prefilling = False + batch.prefilling_mask = [False] * len(batch) + + speculate = get_speculate() + ( + next_input_ids, + next_token_logprobs, + logprobs, + accepted_ids, + speculative_ids, + ) = batch.next_token_chooser( + batch.all_input_ids_tensor[:, : batch.max_current_length], + batch.next_token_logits, + speculate, + batch.speculative_ids, + batch.speculative_logits, + ) + + batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( + batch.top_n_tokens, + batch.top_n_tokens_tensor, + logprobs, + accepted_ids, + ) + + # Since we are done prefilling, all the tensors that were concatenating values for all the requests + # instantly become of shape [BATCH_SIZE] + if prefill: + indices = batch.cu_seqlen_prefill[1:] - 1 + # pad in left + if batch.prefill_cache_indices is not None: + batch.position_ids = batch.position_ids[ + batch.prefill_cache_indices + ][indices] + else: + batch.position_ids = batch.position_ids[indices] + + batch.slot_indices = batch.slot_indices[indices] + batch.adapter_meta.adapter_indices = ( + batch.adapter_meta.adapter_indices[indices] + ) + # For each member of the batch + # Cumulative length + cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1) + torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:]) + if batch.speculative_logits is not None: + for i in range(len(batch)): + batch.all_input_ids_tensor[ + i, + batch.cache_lengths_tensor[i] + + batch.input_lengths[i] : batch.cache_lengths_tensor[i] + + batch.input_lengths[i] + + accepted_ids[i], + ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]] + else: + index = batch.cache_lengths_tensor + batch.input_lengths_tensor + batch_idx = torch.arange( + 0, + batch.all_input_ids_tensor.shape[0], + dtype=torch.long, + device=batch.input_lengths_tensor.device, + ) + batch.all_input_ids_tensor.index_put_( + (batch_idx, index.long()), next_input_ids + ) + + batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1] + batch.speculative_ids = speculative_ids + if batch.position_ids.dim() == 2: + # Qwen2_vl case: + batch.position_ids += accepted_ids.unsqueeze(-1) + else: + batch.position_ids += accepted_ids + batch.cache_lengths_tensor += ( + batch.input_lengths_tensor + accepted_ids - 1 + ) + batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor) + batch.slot_indices += accepted_ids + + # Does a HPU <-> CPU sync internally + if prefill: + # adjust segment lengths to account for all request lengths being 1 during decoding + adapter_segments, _ = find_segments( + batch.adapter_meta.adapter_indices + ) + batch.adapter_meta.adapter_segments = torch.tensor( + adapter_segments, + dtype=torch.int32, + device=batch.adapter_meta.adapter_segments.device, + ) + prev_batches.append( + { + "next_token_ids": next_input_ids, + "next_token_logprobs": next_token_logprobs, + "accepted_ids": accepted_ids, + } + ) + idx = len(prev_batches) - 1 + + for req_idx, req in enumerate(batch.requests): + requests_to_generate.append( + { + "idx": idx, + "request_id": req.id, + "cache_length": batch.cache_lengths[req_idx], + "input_length": batch.input_lengths[req_idx], + "prefix_offset": batch.prefix_offsets[req_idx], + "read_offset": batch.read_offsets[req_idx], + "stopping_criteria": batch.stopping_criterias[req_idx], + "all_input_ids": 
batch.all_input_ids[req_idx], + "do_sample": batch.next_token_chooser.do_sample[req_idx], + "seed": batch.next_token_chooser.seeds[req_idx], + "top_n_tokens": batch.top_n_tokens[req_idx], + "top_token_ids": batch_top_token_ids[req_idx], + "top_token_logprobs": batch_top_token_logprobs[req_idx], + } + ) + if prefill: + # We do not need prefill tensors anymore + batch.cu_seqlen_prefill = None + batch.prefill_cache_indices = None + batch.prefill_cu_outlens = None + batch.prefill_head_indices = None + batch.prefill_next_token_indices = None + batch.next_token_logits = None + batch.speculative_ids = None + + htorch.core.mark_step() + # Stage 2. Prepare new batch for speculative scheduling if len(batches) > 1: batch = self.batch_type.concatenate(batches) else: @@ -1851,7 +2000,7 @@ class FlashCausalLM(Model): out, speculative_logits = self.forward(batch, adapter_data) if prefill: - next_token_logits = ( + batch.next_token_logits = ( out[batch.prefill_next_token_indices] if prefill_logprobs else out ) if speculative_logits is not None: @@ -1862,364 +2011,147 @@ class FlashCausalLM(Model): ) else: prefill_logprobs = None - next_token_logits = out + batch.next_token_logits = out + batch.speculative_logits = speculative_logits - finished_prefilling = True - next_chunk_lengths = [] - current_prefilling_mask = batch.prefilling_mask - if prefill: - finished_prefilling = True - next_prefilling_mask = [False] * len(batch) - - batch.prefilling = not finished_prefilling - batch.prefilling_mask = next_prefilling_mask - - speculate = get_speculate() - ( - next_input_ids, - next_token_logprobs, - logprobs, - accepted_ids, - speculative_ids, - ) = batch.next_token_chooser( - batch.all_input_ids_tensor[:, : batch.max_current_length], - next_token_logits, - speculate, - batch.speculative_ids, - speculative_logits, - ) - - batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( - batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids - ) - - # Since we are done prefilling, all the tensors that were concatenating values for all the requests - # instantly become of shape [BATCH_SIZE] - if prefill and finished_prefilling: - indices = batch.cu_seqlen_prefill[1:] - 1 - # pad in left - if batch.prefill_cache_indices is not None: - batch.position_ids = batch.position_ids[batch.prefill_cache_indices][ - indices - ] - else: - batch.position_ids = batch.position_ids[indices] - - batch.slot_indices = batch.slot_indices[indices] - batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[ - indices - ] - # We do two for loops as the first one can run completely asynchronously from the GPU while for the second - # one, we need to first do a HPU <-> CPU sync - # It is faster if we delay this sync for the maximum amount of time - - # For each member of the batch - # Cumulative length - cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1) - torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:]) - if speculative_logits is not None: - for i in range(len(batch)): - batch.all_input_ids_tensor[ - i, - batch.cache_lengths_tensor[i] - + batch.input_lengths[i] : batch.cache_lengths_tensor[i] - + batch.input_lengths[i] - + accepted_ids[i], - ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]] - else: - index = batch.cache_lengths_tensor + batch.input_lengths_tensor - batch_idx = torch.arange( - 0, - batch.all_input_ids_tensor.shape[0], - dtype=torch.long, - device=batch.input_lengths_tensor.device, - ) - batch.all_input_ids_tensor.index_put_( - (batch_idx, 
index.long()), next_input_ids - ) - - # Update values - # These values can be updated without a HPU -> CPU sync - if not prefill or (prefill and finished_prefilling): - batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1] - batch.speculative_ids = speculative_ids - if batch.position_ids.dim() == 2: - # Qwen2_vl case: - batch.position_ids += accepted_ids.unsqueeze(-1) - else: - batch.position_ids += accepted_ids - batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1 - batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor) - batch.slot_indices += accepted_ids - - # Does a HPU <-> CPU sync internally - if prefill and finished_prefilling: - # adjust segment lengths to account for all request lengths being 1 during decoding - adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices) - batch.adapter_meta.adapter_segments = torch.tensor( - adapter_segments, - dtype=torch.int32, - device=batch.adapter_meta.adapter_segments.device, - ) - - # HPU <-> CPU sync - next_token_logprobs = next_token_logprobs.tolist() - next_token_ids = next_input_ids.tolist() - accepted_ids = accepted_ids.tolist() - - # Update values if we need to continue prefilling - # This represents the `else` case of the `Update values` if above - # but since this require the `next_token_ids` to be on CPU, it is better to do it here - if prefill and not finished_prefilling: - # Speculation must be ignored while we prefill even with chunking - # it simplifies everything - assert batch.speculative_ids is None - - all_postfix_ids = [] - for i, ( - request_prefilling, - next_token_id, - all_input_ids, - cache_length, - input_length, - next_chunk_length, - ) in enumerate( - zip( - batch.prefilling_mask, - next_token_ids, - batch.all_input_ids, - batch.cache_lengths, - batch.input_lengths, - next_chunk_lengths, - ) - ): - if request_prefilling: - next_cache_length = cache_length + input_length - # Get new prompt IDs to prefill - postfix_ids = all_input_ids[ - next_cache_length : next_cache_length + next_chunk_length - ] - else: - # This request is done prefilling, the new id is the one selected the sampling method - postfix_ids = [next_token_id] - - all_postfix_ids.append(postfix_ids) - - batch.input_ids = all_postfix_ids + # HPU->CPU sync + for prev_batch in prev_batches: + prev_batch["next_token_logprobs"] = prev_batch[ + "next_token_logprobs" + ].tolist() + prev_batch["next_token_ids"] = prev_batch["next_token_ids"].tolist() + prev_batch["accepted_ids"] = prev_batch["accepted_ids"].tolist() start_decode = time.time_ns() - + # Stage 3. 
Finish and return previous generations # Results generations: List[Generation] = [] - stopped = True - - # Zipped iterator - iterator = zip( - batch.requests, - batch.prompt_lengths, - batch.cache_lengths, - batch.input_lengths, - batch.prefix_offsets, - batch.read_offsets, - batch.stopping_criterias, - batch.all_input_ids, - batch.next_token_chooser.do_sample, - batch.next_token_chooser.seeds, - batch.top_n_tokens, - current_prefilling_mask, - batch.prefilling_mask, - accepted_ids, - batch_top_token_ids, - batch_top_token_logprobs, - ) - + stopped = len(requests_to_generate) > 0 # Reset max_input_length batch.max_input_length = 0 # For each member of the batch - index = 0 - for i, ( - request, - prompt_length, - cache_length, - input_length, - prefix_offset, - read_offset, - stopping_criteria, - all_input_ids, - do_sample, - seed, - top_n_tokens, - request_was_prefilling, - request_is_prefilling, - n_accepted_ids, - top_token_ids, - top_token_logprobs, - ) in enumerate(iterator): - # Compute logprobs first as, even though we might skip the token, - # it can still be required to compute the logprobs - # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need - # this state to be stable - if request.id % self.world_size == self.rank: - # Prefill - if request_was_prefilling and request.prefill_logprobs: - out_start_index = batch.prefill_cu_outlens[i] - out_end_index = batch.prefill_cu_outlens[i + 1] - if not request_is_prefilling: - # The request is dones prefilling, meaning that we started generating new tokens - # The last logprob is a logprob for a generated token that was not part of the prompt - # We need to remove it - out_end_index -= 1 + indexs = [0] * len(prev_batches) + idx_accept_ids = [0] * len(prev_batches) + for i, req_data in enumerate(requests_to_generate): + idx = req_data["idx"] + request_id = req_data["request_id"] + cache_length = req_data["cache_length"] + input_length = req_data["input_length"] + prefix_offset = req_data["prefix_offset"] + read_offset = req_data["read_offset"] + stopping_criteria = req_data["stopping_criteria"] + all_input_ids = req_data["all_input_ids"] + do_sample = req_data["do_sample"] + seed = req_data["seed"] + top_n_tokens = req_data["top_n_tokens"] + n_accepted_ids = prev_batches[idx]["accepted_ids"][idx_accept_ids[idx]] + top_token_ids = req_data["top_token_ids"] + top_token_logprobs = req_data["top_token_logprobs"] - request_prefill_logprobs = prefill_logprobs[ - out_start_index:out_end_index - ] - # Logprobs generated by the model are for the next token - # So we need to translate the id tensor by 1 - prefill_token_ids = all_input_ids[ - cache_length + 1 : cache_length + input_length + 1 - ] + new_input_length = 1 + new_cache_length = cache_length + input_length + n_accepted_ids - 1 + # Append next token to all tokens + next_token_texts = [] + left = 0 - past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i] + if n_accepted_ids > 1: + log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}") - if past_prefill_logprob_tokens is None: - # add nan for cached prompt tokens/first token - request_prefill_logprobs = [float("nan")] * ( - cache_length + 1 - ) + request_prefill_logprobs - prefill_token_ids = ( - all_input_ids[: cache_length + 1] + prefill_token_ids - ) + current_stopped = False + index = indexs[idx] + for j in range(index, index + n_accepted_ids): + # Generated token + next_token_id = prev_batches[idx]["next_token_ids"][j] + all_input_ids.append(next_token_id) + 
next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids, + prefix_offset, + read_offset, + ) + next_token_texts.append(next_token_text) - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) + stop, reason = stopping_criteria( + next_token_id, + next_token_text, + ) - prefill_logprob_tokens = Tokens( - prefill_token_ids, - request_prefill_logprobs, - prefill_texts, - is_special=[], - ) - if past_prefill_logprob_tokens is not None: - prefill_logprob_tokens = ( - past_prefill_logprob_tokens + prefill_logprob_tokens - ) - - batch.prefill_logprob_tokens[i] = prefill_logprob_tokens + if stop: + left = index + n_accepted_ids - j - 1 + current_stopped = True + break else: - batch.prefill_logprob_tokens[i] = None + current_stopped = False + stopped = stopped and current_stopped - # If it is, the tokens we decoded should be ignored - if request_is_prefilling: - # Make sure that we do not stop as even though this request did not create a token, it is still - # processing - stopped = False - new_input_length = next_chunk_lengths[i] - new_cache_length = cache_length + input_length - else: - new_input_length = 1 - new_cache_length = cache_length + input_length + n_accepted_ids - 1 - # Append next token to all tokens - next_token_texts = [] - left = 0 + _next_token_ids = prev_batches[idx]["next_token_ids"][ + index : index + n_accepted_ids - left + ] + _next_token_logprobs = prev_batches[idx]["next_token_logprobs"][ + index : index + n_accepted_ids - left + ] - if n_accepted_ids > 1: - log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}") - - current_stopped = False - for j in range(index, index + n_accepted_ids): - # Generated token - next_token_id = next_token_ids[j] - all_input_ids.append(next_token_id) - next_token_text, prefix_offset, read_offset = self.decode_token( + # Shard generations + # All generations will be appended in the rust sharded client + if request_id % self.world_size == self.rank: + if stop: + # Decode generated tokens + output_text, _, _ = self.decode_token( all_input_ids, - prefix_offset, - read_offset, + prefix_offset=len(all_input_ids) + - stopping_criteria.current_tokens + - 1, + read_offset=len(all_input_ids) + - stopping_criteria.current_tokens, + skip_special_tokens=True, ) - next_token_texts.append(next_token_text) - - stop, reason = stopping_criteria( - next_token_id, - next_token_text, + generated_text = GeneratedText( + output_text, + stopping_criteria.current_tokens, + reason, + seed if do_sample else None, ) + else: + generated_text = None - if stop: - left = index + n_accepted_ids - j - 1 - current_stopped = True - break - else: - current_stopped = False - stopped = stopped and current_stopped - - _next_token_ids = next_token_ids[index : index + n_accepted_ids - left] - _next_token_logprobs = next_token_logprobs[ - index : index + n_accepted_ids - left - ] - - # Shard generations - # All generations will be appended in the rust sharded client - if request.id % self.world_size == self.rank: - if stop: - # Decode generated tokens - output_text, _, _ = self.decode_token( - all_input_ids, - prefix_offset=len(all_input_ids) - - stopping_criteria.current_tokens - - 1, - read_offset=len(all_input_ids) - - stopping_criteria.current_tokens, - skip_special_tokens=True, + if top_n_tokens > 0: + all_top_tokens = [] + for top_token_ids, top_token_logprobs in zip( + top_token_ids, top_token_logprobs + ): + toptoken_texts = self.tokenizer.batch_decode( + 
top_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, ) - generated_text = GeneratedText( - output_text, - stopping_criteria.current_tokens, - reason, - seed if do_sample else None, + special_toptokens = [ + token_id in self.all_special_ids + for token_id in top_token_ids + ] + top_tokens = Tokens( + top_token_ids, + top_token_logprobs, + toptoken_texts, + special_toptokens, ) - else: - generated_text = None + all_top_tokens.append(top_tokens) + top_tokens = all_top_tokens + else: + top_tokens = None - if top_n_tokens > 0: - all_top_tokens = [] - for top_token_ids, top_token_logprobs in zip( - top_token_ids, top_token_logprobs - ): - toptoken_texts = self.tokenizer.batch_decode( - top_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - special_toptokens = [ - token_id in self.all_special_ids - for token_id in top_token_ids - ] - top_tokens = Tokens( - top_token_ids, - top_token_logprobs, - toptoken_texts, - special_toptokens, - ) - all_top_tokens.append(top_tokens) - top_tokens = all_top_tokens - else: - top_tokens = None + generation = Generation( + request_id, + None, + Tokens( + _next_token_ids, + _next_token_logprobs, + next_token_texts, + [nid in self.all_special_ids for nid in _next_token_ids], + ), + generated_text, + top_tokens, + ) - generation = Generation( - request.id, - batch.prefill_logprob_tokens[i], - Tokens( - _next_token_ids, - _next_token_logprobs, - next_token_texts, - [nid in self.all_special_ids for nid in _next_token_ids], - ), - generated_text, - top_tokens, - ) - - generations.append(generation) + generations.append(generation) # accept each new token for this specific request since we may # have more than one new token per request with speculative decoding @@ -2231,7 +2163,8 @@ class FlashCausalLM(Model): ) # Update values - index += n_accepted_ids + indexs[idx] += n_accepted_ids + idx_accept_ids[idx] += 1 batch.cache_lengths[i] = new_cache_length batch.max_input_length = max(batch.max_input_length, new_input_length) batch.input_lengths[i] = new_input_length @@ -2248,14 +2181,6 @@ class FlashCausalLM(Model): decode_ns = time.time_ns() - start_decode return generations, None, (forward_ns, decode_ns) - if prefill and finished_prefilling: - # We do not need prefill tensors anymore - batch.cu_seqlen_prefill = None - batch.prefill_cache_indices = None - batch.prefill_cu_outlens = None - batch.prefill_head_indices = None - batch.prefill_next_token_indices = None - forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode return generations, batch, (forward_ns, decode_ns) diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py index 725e7517..cdda751a 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -456,6 +456,8 @@ class FlashVlmCausalLM(FlashCausalLM): for i, (batch_size, block_num) in enumerate( reversed(self.bucketing_ctx.decode_buckets) ): + if batch_size > block_num: + continue log_master( logger.info, f"warmup decode bs {batch_size} block_num {block_num}" ) diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index 940ee1b0..d21cc39d 100644 --- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ 
b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -368,6 +368,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): for i, (batch_size, block_num) in enumerate( reversed(self.bucketing_ctx.decode_buckets) ): + if batch_size > block_num: + continue log_master( logger.info, f"warmup decode bs {batch_size} block_num {block_num}" ) From 4cdc34ec4dc46f56b4745e0b8b25faef35c486da Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Thu, 10 Apr 2025 19:32:32 -0700 Subject: [PATCH 30/35] match the latest vllm_extension ops Signed-off-by: Wang, Yi A --- .../gaudi/server/text_generation_server/layers/attention/hpu.py | 2 +- backends/gaudi/server/text_generation_server/layers/fp8.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py index 526dbcec..f34e93ab 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py @@ -68,7 +68,7 @@ def paged_attention( ): batch_size, head_num, head_size = query.shape output = ops.flat_pa( - query=query, + query=query.view(batch_size, 1, head_num * head_size), key_cache=kv_cache.key, value_cache=kv_cache.value, block_list=hpu_attention_meta.block_list, diff --git a/backends/gaudi/server/text_generation_server/layers/fp8.py b/backends/gaudi/server/text_generation_server/layers/fp8.py index 6c8d637e..0dc5cdaf 100644 --- a/backends/gaudi/server/text_generation_server/layers/fp8.py +++ b/backends/gaudi/server/text_generation_server/layers/fp8.py @@ -11,7 +11,7 @@ from text_generation_server.utils.weights import ( ) from vllm_hpu_extension.ops import scaled_fp8_quant -from vllm_hpu_extension.ops import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2 +from vllm_hpu_extension.scales import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2 import habana_frameworks.torch.utils.experimental as htexp w8a8_block_fp8_matmul = None From a83e9fe003c4a3373f97deebf590fc32d637724e Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Thu, 10 Apr 2025 19:56:58 -0700 Subject: [PATCH 31/35] work with the latest vllm extension ops Signed-off-by: Wang, Yi A --- .../server/text_generation_server/models/flash_causal_lm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index 51adffc7..a2cbf30c 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -1512,9 +1512,10 @@ class FlashCausalLM(Model): self.bucketing_ctx = HPUBucketingContext( os.getenv("DECODE_MAX_BS", 128), # self.max_num_seqs, #TODO - os.getenv("PREFILL_MAX_BS", 16), # self.max_num_prefill_seqs, #TODO + os.getenv("PREFILL_MAX_BS", 64), # self.max_num_prefill_seqs, #TODO BLOCK_SIZE, num_blocks * BLOCK_SIZE, + False, ) self.bucketing_ctx.num_hpu_blocks = num_blocks if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true": From 76cc12979620fe73d2e51ffb44d13b5f4af7c8a0 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Fri, 11 Apr 2025 01:27:49 -0700 Subject: [PATCH 32/35] remove block_scales which is not needed anymore Signed-off-by: Wang, Yi A --- .../text_generation_server/layers/attention/common.py | 2 -- .../server/text_generation_server/layers/attention/hpu.py | 1 - 
.../text_generation_server/models/flash_causal_lm.py | 7 ------- 3 files changed, 10 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/layers/attention/common.py b/backends/gaudi/server/text_generation_server/layers/attention/common.py index 8ec9fb46..34c77040 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/common.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/common.py @@ -13,7 +13,6 @@ class HPUPagedAttentionMetadata: block_list: Optional[torch.Tensor] block_mapping: Optional[torch.Tensor] block_usage: Optional[torch.Tensor] - block_scales: Optional[torch.Tensor] block_groups: Optional[torch.Tensor] attn_bias: Optional[torch.Tensor] @@ -66,7 +65,6 @@ def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object: "block_list", "block_mapping", "block_usage", - "block_scales", "block_groups", "attn_bias", ], diff --git a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py index f34e93ab..1d73dcb3 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py @@ -74,7 +74,6 @@ def paged_attention( block_list=hpu_attention_meta.block_list, block_mapping=hpu_attention_meta.block_mapping, block_bias=hpu_attention_meta.attn_bias, - block_scales=hpu_attention_meta.block_scales, block_groups=hpu_attention_meta.block_groups, scale=softmax_scale, matmul_qk_op=Matmul(), diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index a2cbf30c..d4ff3f70 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -70,7 +70,6 @@ from text_generation_server.utils.import_utils import ( import vllm_hpu_extension.environment as environment import habana_frameworks.torch as htorch import itertools -from vllm_hpu_extension.ops import batch2block, block2batch from vllm_hpu_extension.bucketing import HPUBucketingContext tracer = trace.get_tracer(__name__) @@ -149,11 +148,6 @@ def prepare_for_decode( mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0) mask = mask >= block_usage.unsqueeze(-1) attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf) - ones = torch.ones( - (block_mapping.size(0),), device=device, dtype=block_mapping.dtype - ) - sums = batch2block(block2batch(ones, block_mapping), block_mapping) - block_scales = torch.reciprocal(torch.maximum(ones, sums)) return trim_attn_metadata( HPUPagedAttentionMetadata( block_list=block_list, @@ -161,7 +155,6 @@ def prepare_for_decode( block_usage=block_usage, block_mapping=block_mapping.to(dtype), attn_bias=attn_bias, - block_scales=block_scales, ) ) From ba049c9d49465a362d5ecfa4a58366bcc635696d Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Fri, 11 Apr 2025 06:10:17 -0700 Subject: [PATCH 33/35] improve performance Signed-off-by: Wang, Yi A --- .../models/flash_causal_lm.py | 52 ++++++++++++------- .../models/flash_vlm_causal_lm.py | 4 +- .../models/mllama_causal_lm.py | 4 +- 3 files changed, 36 insertions(+), 24 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index d4ff3f70..5c7b8bc0 100644 --- 
a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -89,7 +89,7 @@ def get_sliding_windows() -> int: def prepare_for_decode( - dtype, use_contiguous_pa, device, slot, block_tables, batch_size, bucketing_ctx + dtype, use_contiguous_pa, device, slots, block_tables, batch_size, bucketing_ctx ): # Prepare values if we need to continue decoding # need for HPUPagedAttentionMetadata preparation @@ -105,7 +105,7 @@ def prepare_for_decode( padding = target_len - input_len return input + [v] * padding - last_block_usage = slot % BLOCK_SIZE + 1 + last_block_usage = [slot % BLOCK_SIZE + 1 for slot in slots] block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)] block_usage = [ [BLOCK_SIZE] * (len(bt) - 1) + [lbu] @@ -964,7 +964,7 @@ class FlashCausalLMBatch(Batch): ) def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx): - block_num = self.cache_lengths_tensor // BLOCK_SIZE + 1 + block_num = [length // BLOCK_SIZE + 1 for length in self.cache_lengths] block_tables = [] for i, bt in enumerate(self.block_tables): block_tables.append(bt[0 : block_num[i]]) @@ -984,7 +984,7 @@ class FlashCausalLMBatch(Batch): dtype, use_contiguous_pa, self.block_tables_tensor.device, - slots, + slots.cpu(), block_tables, padded_bs, bucketing_ctx, @@ -1616,7 +1616,6 @@ class FlashCausalLM(Model): block_tables.append(block_array) past_len.append(blocks[i] * BLOCK_SIZE - 1) start_idx += blocks[i] - slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device) input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device) cache_lengths_tensor = torch.tensor( past_len, dtype=torch.int32, device=self.device @@ -1641,13 +1640,14 @@ class FlashCausalLM(Model): batch_size, bucketing_ctx=None, ) + slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device) # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. 
self.model.forward( input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=None, kv_cache=self.kv_cache, - slots=slots, + slots=slots_tensor, seqlen=trim_seqlen_metadata(seqlen), lm_head_indices=None, adapter_data=None, @@ -1866,8 +1866,8 @@ class FlashCausalLM(Model): for i in range(len(batch)): batch.all_input_ids_tensor[ i, - batch.cache_lengths_tensor[i] - + batch.input_lengths[i] : batch.cache_lengths_tensor[i] + batch.cache_lengths[i] + + batch.input_lengths[i] : batch.cache_lengths[i] + batch.input_lengths[i] + accepted_ids[i], ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]] @@ -1915,14 +1915,36 @@ class FlashCausalLM(Model): } ) idx = len(prev_batches) - 1 + if batch.speculative_logits is not None: + accepted_ids_cpu = accepted_ids.cpu() for req_idx, req in enumerate(batch.requests): + new_input_length = 1 + if batch.speculative_logits is not None: + new_cache_length = ( + batch.cache_lengths[req_idx] + + batch.input_lengths[req_idx] + + accepted_ids_cpu[req_idx] + - 1 + ) + else: + new_cache_length = ( + batch.cache_lengths[req_idx] + batch.input_lengths[req_idx] + ) + batch.cache_lengths[req_idx] = new_cache_length + batch.max_input_length = max( + batch.max_input_length, new_input_length + ) + batch.input_lengths[req_idx] = new_input_length + current_length = new_cache_length + new_input_length + batch.max_current_length = max( + batch.max_current_length, current_length + ) + requests_to_generate.append( { "idx": idx, "request_id": req.id, - "cache_length": batch.cache_lengths[req_idx], - "input_length": batch.input_lengths[req_idx], "prefix_offset": batch.prefix_offsets[req_idx], "read_offset": batch.read_offsets[req_idx], "stopping_criteria": batch.stopping_criterias[req_idx], @@ -2029,8 +2051,6 @@ class FlashCausalLM(Model): for i, req_data in enumerate(requests_to_generate): idx = req_data["idx"] request_id = req_data["request_id"] - cache_length = req_data["cache_length"] - input_length = req_data["input_length"] prefix_offset = req_data["prefix_offset"] read_offset = req_data["read_offset"] stopping_criteria = req_data["stopping_criteria"] @@ -2041,9 +2061,6 @@ class FlashCausalLM(Model): n_accepted_ids = prev_batches[idx]["accepted_ids"][idx_accept_ids[idx]] top_token_ids = req_data["top_token_ids"] top_token_logprobs = req_data["top_token_logprobs"] - - new_input_length = 1 - new_cache_length = cache_length + input_length + n_accepted_ids - 1 # Append next token to all tokens next_token_texts = [] left = 0 @@ -2159,11 +2176,6 @@ class FlashCausalLM(Model): # Update values indexs[idx] += n_accepted_ids idx_accept_ids[idx] += 1 - batch.cache_lengths[i] = new_cache_length - batch.max_input_length = max(batch.max_input_length, new_input_length) - batch.input_lengths[i] = new_input_length - current_length = new_cache_length + new_input_length - batch.max_current_length = max(batch.max_current_length, current_length) batch.prefix_offsets[i] = prefix_offset batch.read_offsets[i] = read_offset diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py index cdda751a..c885816b 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -408,7 +408,6 @@ class FlashVlmCausalLM(FlashCausalLM): block_tables.append(block_array) past_len.append(blocks[i] * BLOCK_SIZE - 1) start_idx += blocks[i] - slots = torch.tensor(slots, 
dtype=batch.slots.dtype, device=self.device) input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device) cache_lengths_tensor = torch.tensor( past_len, dtype=torch.int32, device=self.device @@ -433,13 +432,14 @@ class FlashVlmCausalLM(FlashCausalLM): batch_size, bucketing_ctx=None, ) + slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device) # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. self.model.forward( input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=None, kv_cache=self.kv_cache, - slots=slots, + slots=slots_tensor, seqlen=trim_seqlen_metadata(seqlen), hpu_attention_meta=hpu_attention_meta, lm_head_indices=None, diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index d21cc39d..6a066185 100644 --- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -247,7 +247,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): block_tables.append(block_array) past_len.append(blocks[i] * BLOCK_SIZE - 1) start_idx += blocks[i] - slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device) input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device) cache_lengths_tensor = torch.tensor( past_len, dtype=torch.int32, device=self.device @@ -279,12 +278,13 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): indices, cross_attention_len = generate_cross_attention_states( cross_attention_states, image_indices, seqlen, 1, False ) + slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device) self.model.forward( input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=None, kv_cache=self.kv_cache, - slots=slots, + slots=slots_tensor, seqlen=trim_seqlen_metadata(seqlen), hpu_attention_meta=hpu_attention_meta, lm_head_indices=None, From 5ec7f15d0c61ea80eba606cfa18af3db0555cb1a Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Tue, 15 Apr 2025 00:27:07 -0700 Subject: [PATCH 34/35] prefill bypass graph Signed-off-by: Wang, Yi A --- .../server/text_generation_server/models/flash_causal_lm.py | 2 +- .../server/text_generation_server/models/mllama_causal_lm.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index 5c7b8bc0..8a5668a5 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -1785,7 +1785,7 @@ class FlashCausalLM(Model): kwargs = {} if htorch.utils.internal.is_lazy(): - kwargs["bypass_hpu_graphs"] = False + kwargs["bypass_hpu_graphs"] = batch.prefilling logits, speculative_logits = self.model.forward( input_ids=input_ids, diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index 6a066185..c1ea36f2 100644 --- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -455,7 +455,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): kwargs = {} if htorch.utils.internal.is_lazy(): - kwargs["bypass_hpu_graphs"] = False + kwargs["bypass_hpu_graphs"] = batch.prefilling if 
batch.prefill_cache_indices is not None: slots_pad = torch.zeros_like(input_ids) slots_pad[batch.prefill_cache_indices] = slots From bf3987e25e5e1bcfd7699ce46708b49b606e36c0 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Tue, 15 Apr 2025 21:56:51 -0700 Subject: [PATCH 35/35] pingpong optimization issue fix Signed-off-by: Wang, Yi A --- .../models/flash_causal_lm.py | 22 ++++++++++++------- backends/v3/src/block_allocator.rs | 4 +++- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index 8a5668a5..ecedd4aa 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -615,6 +615,12 @@ class FlashCausalLMBatch(Batch): max_slots = max(max_slots, slot_length) all_input_ids_tensor = self.all_input_ids_tensor[indices] + next_token_logits = self.next_token_logits[indices] + speculative_logits = ( + self.speculative_logits[indices] + if self.speculative_logits is not None + else None + ) block_tables_tensor = self.block_tables_tensor[indices] next_token_chooser = self.next_token_chooser.filter(indices) top_n_tokens_tensor = self.top_n_tokens_tensor[indices] @@ -696,8 +702,8 @@ class FlashCausalLMBatch(Batch): speculative_ids=speculative_ids, adapter_meta=adapter_meta, hpu_attn_meta=None, - next_token_logits=None, - speculative_logits=None, + next_token_logits=next_token_logits, + speculative_logits=speculative_logits, ) @classmethod @@ -825,8 +831,11 @@ class FlashCausalLMBatch(Batch): start_index = cumulative_batch_size end_index = cumulative_batch_size + len(batch) - # Copy tensors (GPU) - top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor + # Copy tensors (HPU) + index = torch.tensor( + list(range(start_index, end_index)), device=batch.input_ids.device + ) + top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor) all_input_ids_tensor[ start_index:end_index, : batch.all_input_ids_tensor.shape[1] ] = batch.all_input_ids_tensor[:, :max_length] @@ -834,7 +843,7 @@ class FlashCausalLMBatch(Batch): block_tables_tensor[ start_index:end_index, : batch.block_tables_tensor.shape[1] ] = batch.block_tables_tensor[:, :max_blocks] - prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor + prompt_lengths_tensor.index_copy_(0, index, batch.prompt_lengths_tensor) slots_start_index = cumulative_slots slots_end_index = cumulative_slots + len(batch.slots) @@ -844,9 +853,6 @@ class FlashCausalLMBatch(Batch): ) if not prefilling: - index = torch.tensor( - list(range(start_index, end_index)), device=batch.input_ids.device - ) input_ids.index_copy_(0, index, batch.input_ids) position_ids.index_copy_(0, index, batch.position_ids) slot_indices.index_copy_( diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index 6da2b51d..1628a00b 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -177,7 +177,7 @@ impl Allocator for SimpleAllocator { (required_blocks, repeats) }; - let tokens = tokens as usize; + let mut tokens = tokens as usize; if required_blocks > self.free_blocks.len() as u32 { None } else { @@ -189,6 +189,8 @@ impl Allocator for SimpleAllocator { .split_off(self.free_blocks.len() - required_blocks as usize); if self.is_hpu_device { blocks.sort(); + // need 1 slot for ping-pong optimization + tokens += 1; } let mut slots = 
Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
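
A note on the decode-path bookkeeping introduced in the patches above: the sketch below is a simplified, self-contained illustration (not the actual FlashCausalLM code) of how the per-previous-batch cursors `indexs` and `idx_accept_ids` drain accepted speculative tokens from the flat `next_token_ids` buffer of each previous batch. The `prev_batches` and `requests_to_generate` structures here are toy stand-ins with made-up token ids, and the early-stop/`left` handling is omitted.

    # Each previous batch carries a flat next_token_ids list; accepted_ids[k] says how
    # many of those tokens belong to the k-th request of that batch.
    prev_batches = [
        {"next_token_ids": [11, 12, 21], "accepted_ids": [2, 1]},  # two requests: 2 and 1 accepted tokens
        {"next_token_ids": [31, 32, 33], "accepted_ids": [3]},     # one request: 3 accepted tokens
    ]
    requests_to_generate = [
        {"idx": 0, "request_id": "a"},
        {"idx": 0, "request_id": "b"},
        {"idx": 1, "request_id": "c"},
    ]

    indexs = [0] * len(prev_batches)          # read cursor into next_token_ids, per previous batch
    idx_accept_ids = [0] * len(prev_batches)  # which request of that batch is being processed

    for req in requests_to_generate:
        idx = req["idx"]
        n_accepted_ids = prev_batches[idx]["accepted_ids"][idx_accept_ids[idx]]
        start = indexs[idx]
        tokens = prev_batches[idx]["next_token_ids"][start : start + n_accepted_ids]
        print(req["request_id"], tokens)
        # advance the cursors exactly as the patch does after each request
        indexs[idx] += n_accepted_ids
        idx_accept_ids[idx] += 1

Running this prints a [11, 12], b [21], c [31, 32, 33]; in the real server each drained token is then passed through decode_token and the stopping criteria as the hunks above show.

The last patch also replaces contiguous slice assignment with index_copy_ when concatenating batches on HPU. The standalone snippet below (CPU tensors, purely illustrative) shows that the two forms produce the same result for a contiguous range, which is what the concatenate path relies on.

    import torch

    dst = torch.zeros(6, dtype=torch.int64)
    src = torch.tensor([7, 8, 9])
    start_index, end_index = 2, 5

    # index_copy_ with an explicit index tensor, as used for top_n_tokens_tensor and
    # prompt_lengths_tensor in the patch; equivalent to dst[start_index:end_index] = src.
    index = torch.tensor(list(range(start_index, end_index)), device=src.device)
    dst.index_copy_(0, index, src)
    assert torch.equal(dst, torch.tensor([0, 0, 7, 8, 9, 0]))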