diff --git a/Dockerfile_gaudi b/Dockerfile_gaudi index eff87ab65..06073fe40 100644 --- a/Dockerfile_gaudi +++ b/Dockerfile_gaudi @@ -95,7 +95,7 @@ RUN cd server && \ pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \ BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \ pip install . --no-cache-dir - +RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git # Install benchmarker COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark # Install router diff --git a/backends/gaudi/server/text_generation_server/cli.py b/backends/gaudi/server/text_generation_server/cli.py index 700f763e9..53837ef71 100644 --- a/backends/gaudi/server/text_generation_server/cli.py +++ b/backends/gaudi/server/text_generation_server/cli.py @@ -16,15 +16,9 @@ app = typer.Typer() class Quantization(str, Enum): - bitsandbytes = "bitsandbytes" - bitsandbytes_nf4 = "bitsandbytes-nf4" - bitsandbytes_fp4 = "bitsandbytes-fp4" gptq = "gptq" awq = "awq" - eetq = "eetq" - exl2 = "exl2" fp8 = "fp8" - marlin = "marlin" class Dtype(str, Enum): @@ -105,6 +99,9 @@ def serve( "bitsandbytes", "bitsandbytes-nf4", "bitsandbytes-fp4", + "gptq", + "awq", + "fp8", }: raise RuntimeError( "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model." @@ -112,7 +109,7 @@ def serve( logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype)) - if sharded: + if sharded and os.getenv("ATTENTION", "default") not in {"paged"}: tgi_file = Path(__file__).resolve().parent / "tgi_service.py" num_shard = int(os.getenv("WORLD_SIZE", "1")) logger.info("CLI SHARDED = {}".format(num_shard)) diff --git a/backends/gaudi/server/text_generation_server/layers/attention/__init__.py b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py index 4d83a11fc..9ba9f6e08 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py @@ -1,43 +1,28 @@ -from text_generation_server.utils.import_utils import SYSTEM -import os +from .common import ( + Seqlen, + HPUPagedAttentionMetadata, + trim_attn_metadata, + trim_seqlen_metadata, +) -from .common import Seqlen +from .hpu import ( + SUPPORTS_WINDOWING, + attention, + paged_attention, +) -if os.getenv("USE_FLASH_ATTENTION", "true").lower() == "false": - raise ImportError("`USE_FLASH_ATTENTION` is false.") -if SYSTEM == "cuda": - from .cuda import ( - attention, - paged_attention, - reshape_and_cache, - SUPPORTS_WINDOWING, - PREFILL_IN_KV_CACHE, - ) -elif SYSTEM == "rocm": - from .rocm import ( - attention, - paged_attention, - reshape_and_cache, - PREFILL_IN_KV_CACHE, - SUPPORTS_WINDOWING, - ) -elif SYSTEM == "ipex": - from .ipex import ( - attention, - paged_attention, - reshape_and_cache, - PREFILL_IN_KV_CACHE, - SUPPORTS_WINDOWING, - ) -else: - raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") +# KVCache needs `reshape_and_cache`, so ensure that it is defined already. +from .kv_cache import KVCache, get_kv_scales __all__ = [ "attention", + "get_kv_scales", "paged_attention", - "reshape_and_cache", - "PREFILL_IN_KV_CACHE", "SUPPORTS_WINDOWING", + "KVCache", "Seqlen", + "HPUPagedAttentionMetadata", + "trim_seqlen_metadata", + "trim_attn_metadata", ] diff --git a/backends/gaudi/server/text_generation_server/layers/attention/common.py b/backends/gaudi/server/text_generation_server/layers/attention/common.py index d6e512c01..8ec9fb461 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/common.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/common.py @@ -1,72 +1,147 @@ from dataclasses import dataclass -from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.models.globals import ATTENTION import torch -from typing import Optional +from typing import Optional, List, Dict +import collections + +_TYPE_CACHE = {} -if ATTENTION in {"flashinfer", "flashdecoding"}: +@dataclass +class HPUPagedAttentionMetadata: + """Metadata for PagedAttention.""" - @dataclass - class Seqlen: - input_lengths: torch.Tensor - prefix_lengths: torch.Tensor - cu_seqlen_q: Optional[torch.Tensor] - cu_seqlen_k: Optional[torch.Tensor] - max_q: int - max_k: int + block_list: Optional[torch.Tensor] + block_mapping: Optional[torch.Tensor] + block_usage: Optional[torch.Tensor] + block_scales: Optional[torch.Tensor] + block_groups: Optional[torch.Tensor] + attn_bias: Optional[torch.Tensor] - def __init__( - self, - input_lengths, - prefix_lengths, - cu_seqlen_q=None, - max_q=None, - max_k=None, - ): - self.input_lengths = input_lengths - self.prefix_lengths = prefix_lengths - device = self.input_lengths.device - shape = self.input_lengths.shape - if cu_seqlen_q is None: - cu_seqlen_q = torch.arange( - shape[0] + 1, - device=device, - dtype=torch.int32, - ) - max_q = 1 - else: - assert max_q is not None - assert max_k is not None - cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32) - # cuda graphs don't like this and this is necessary to clamp within mistral - # Although FA2 might not want the clamping - # cu_seqlen_k[0] = 0 - total = self.input_lengths + self.prefix_lengths - torch.cumsum(total, -1, out=cu_seqlen_k[1:]) +def subtuple( + obj: object, + typename: str, + to_copy: List[str], + to_override: Optional[Dict[str, object]] = None, +): + if obj is None: + return None + if to_override is None: + to_override = {} + fields = set(to_copy) | set(to_override.keys()) + if isinstance(obj, dict): + values = {key: obj[key] for key in fields if key in obj} + else: + values = {f: to_override.get(f, getattr(obj, f)) for f in fields} + if typename not in _TYPE_CACHE: + _TYPE_CACHE[typename] = collections.namedtuple(typename, " ".join(fields)) + return _TYPE_CACHE[typename](**values) - self.cu_seqlen_q = cu_seqlen_q - self.cu_seqlen_k = cu_seqlen_k - self.max_q = max_q - self.max_k = max_k - def clamp(self, max): - # Flash decoding doesn't need to clamp - return self +def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object: + # NOTE(kzawora): To anyone working on this in the future: + # Trimming metadata is required when using HPUGraphs. + # Attention metadata is going to be hashed by PT bridge, and + # appropriate HPUGraphs will be matched based on all inputs' hash. -else: + # Before you put more keys in here, make sure you know their + # value type and make sure you know how it's going to be hashed. + # You can find that information in input_hash function + # in habana_frameworks/torch/hpu/graphs.py. You can also hash + # it manually with torch.hpu.graphs.input_hash(attention_metadata) - @dataclass - class Seqlen: - input_lengths: torch.Tensor - prefix_lengths: torch.Tensor - cu_seqlen_q: torch.Tensor - max_q: int - max_k: int + # If you use primitive types here - they will get hashed based + # on their value. You *will* get lots of excessive graph captures + # (and an OOM eventually) if you decide to put something like + # seq_len int here. + # If you absolutely need a scalar, put it in a tensor. Tensors + # get hashed using their metadata, not their values: + # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321)) + # input_hash(123) != input_hash(321) + # input_hash("abc") != input_hash("cba") + attention_metadata = subtuple( + metadata, + "TrimmedAttentionMetadata", + [ + "block_list", + "block_mapping", + "block_usage", + "block_scales", + "block_groups", + "attn_bias", + ], + ) + return attention_metadata - def clamp(self, max): - if SYSTEM == "rocm": - return self - raise NotImplementedError("Not implemented seqlen for paged") - return Seqlen(torch.clamp(self.input_lengths, max=max)) + +@dataclass +class Seqlen: + input_lengths: torch.Tensor + cache_lengths: torch.Tensor + cu_seqlen_q: Optional[torch.Tensor] + cu_seqlen_k: Optional[torch.Tensor] + + def __init__( + self, + input_lengths, + cache_lengths, + cu_seqlen_q=None, + ): + self.input_lengths = input_lengths + self.cache_lengths = cache_lengths + device = self.input_lengths.device + shape = self.input_lengths.shape + if cu_seqlen_q is None: + cu_seqlen_q = torch.arange( + shape[0] + 1, + device=device, + dtype=torch.int32, + ) + cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32) + + # cuda graphs don't like this and this is necessary to clamp within mistral + # Although FA2 might not want the clamping + # cu_seqlen_k[0] = 0 + total = self.input_lengths + self.cache_lengths + torch.cumsum(total, -1, out=cu_seqlen_k[1:]) + + self.cu_seqlen_q = cu_seqlen_q + self.cu_seqlen_k = cu_seqlen_k + + def clamp(self, max): + # Flash decoding doesn't need to clamp + return self + + +def trim_seqlen_metadata(metadata: Seqlen) -> object: + # NOTE(kzawora): To anyone working on this in the future: + # Trimming metadata is required when using HPUGraphs. + # Attention metadata is going to be hashed by PT bridge, and + # appropriate HPUGraphs will be matched based on all inputs' hash. + + # Before you put more keys in here, make sure you know their + # value type and make sure you know how it's going to be hashed. + # You can find that information in input_hash function + # in habana_frameworks/torch/hpu/graphs.py. You can also hash + # it manually with torch.hpu.graphs.input_hash(attention_metadata) + + # If you use primitive types here - they will get hashed based + # on their value. You *will* get lots of excessive graph captures + # (and an OOM eventually) if you decide to put something like + # seq_len int here. + # If you absolutely need a scalar, put it in a tensor. Tensors + # get hashed using their metadata, not their values: + # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321)) + # input_hash(123) != input_hash(321) + # input_hash("abc") != input_hash("cba") + attention_metadata = subtuple( + metadata, + "TrimmedSeqlen", + [ + "input_lengths", + "cache_lengths", + "cu_seqlen_q", + "cu_seqlen_k", + ], + ) + return attention_metadata diff --git a/backends/gaudi/server/text_generation_server/layers/attention/cuda.py b/backends/gaudi/server/text_generation_server/layers/attention/cuda.py deleted file mode 100644 index 51af928d5..000000000 --- a/backends/gaudi/server/text_generation_server/layers/attention/cuda.py +++ /dev/null @@ -1,357 +0,0 @@ -import torch -from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.models.globals import ( - ATTENTION, - BLOCK_SIZE, -) -from text_generation_server.layers.attention import Seqlen -from typing import Optional - -major, minor = torch.cuda.get_device_capability() -is_sm75 = major == 7 and minor == 5 -_PARTITION_SIZE = 512 - -try: - from vllm._C import cache_ops -except Exception as e: - raise ImportError( - f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}" - ) - - -def reshape_and_cache( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slots: torch.Tensor, -): - if ATTENTION in {"flashdecoding", "flashinfer"}: - shape = key_cache.shape - key_cache.view(-1, shape[-2], shape[-1])[slots] = key - value_cache.view(-1, shape[-2], shape[-1])[slots] = value - else: - cache_ops.reshape_and_cache( - key, value, key_cache, value_cache, slots, "auto", 1.0 - ) - - -def paged_attention( - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - kv_head_mapping: torch.Tensor, - softmax_scale: float, - block_tables: torch.Tensor, - seqlen: Seqlen, - max_s: int, - softcap: Optional[float] = None, -): - # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py - # Copyright 2023 The vLLM team. All rights - # reserved. - # - # Licensed under the Apache License, Version 2.0 (the "License"); - # you may not use this file except in compliance with the License. - # You may obtain a copy of the License at - # - # http://www.apache.org/licenses/LICENSE-2.0 - # - # Unless required by applicable law or agreed to in writing, software - # distributed under the License is distributed on an "AS IS" BASIS, - # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - # See the License for the specific language governing permissions and - # limitations under the License. - # - - # value_cache => [num_blocks, num_heads, head_size, block_size] - # block_size = value_cache.shape[3] - block_size = BLOCK_SIZE - num_seqs, num_heads, head_size = query.shape - max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE - - # NOTE(woosuk): We use a simple heuristic to decide whether to use - # PagedAttention V1 or V2. If the number of partitions is 1, we use - # V1 to avoid the overhead of reduction. Also, if the number of - # sequences or heads is large, we use V1 since there is enough work - # to parallelize. - if ATTENTION == "flashinfer": - from text_generation_server.layers.attention.flashinfer import decode_state - - return decode_state.get().forward( - query.contiguous(), - paged_kv_cache=(key_cache, value_cache), - logits_soft_cap=softcap, - sm_scale=softmax_scale, - ) - elif ATTENTION == "flashdecoding": - max_q = 1 - max_k = max_s - import flash_attn_2_cuda - - # TODO fixme when flash contains the fix. - # Number of splits is not correctly handled - # by the current path - # https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577 - # This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied. - if softcap is None: - softcap = 0.0 - out = flash_attn_2_cuda.varlen_fwd( - query, - key_cache, - value_cache, - None, - seqlen.cu_seqlen_q, - seqlen.cu_seqlen_k, - None, # pad_k - None, - block_tables, - None, - max_q, - max_k, - 0.0, # dropout - softmax_scale, - False, # zero_tensors - True, # causal - -1, # Window_left - -1, # Window right - softcap, - False, # return softmax - None, # generator - ) - return out[0] - else: - if softcap is not None: - raise RuntimeError("Paged attention doesn't support softcapping") - input_lengths = seqlen.input_lengths - from vllm._C import ops - - out = torch.empty_like(query) - - use_v1 = max_s <= 8192 and ( - max_num_partitions == 1 or num_seqs * num_heads > 512 - ) - if use_v1: - ops.paged_attention_v1( - out, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - 1.0, - ) - else: - # Run PagedAttention V2. - assert _PARTITION_SIZE % block_size == 0 - tmp_output = torch.empty( - size=(num_seqs, num_heads, max_num_partitions, head_size), - dtype=out.dtype, - device=out.device, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, max_num_partitions), - dtype=torch.float32, - device=out.device, - ) - max_logits = torch.empty_like(exp_sums) - - ops.paged_attention_v2( - out, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - 1.0, - ) - return out - - -try: - is_ampere_or_newer = major >= 8 and minor >= 0 - if not is_ampere_or_newer: - raise ImportError("FlashAttention only supports Ampere GPUs or newer.") - - import flash_attn_2_cuda - - V2 = True -except ImportError: - try: - import flash_attn_cuda - - V2 = False - except ImportError as e: - if major >= 8: - architecture_suffix = f"-{SYSTEM}" - raise ImportError( - "Flash Attention V2 is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`" - ) - elif is_sm75: - raise ImportError( - "Flash Attention is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - "or install flash attention with `cd server && make install install-flash-attention`" - ) from e - else: - raise ImportError( - f"GPU with CUDA capability {major} {minor} is not supported" - ) from e - - -SUPPORTS_WINDOWING = V2 - -if ATTENTION == "flashinfer": - - def attention( - q: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - seqlen: Seqlen, - block_tables: torch.Tensor, - softmax_scale, - window_size_left=-1, - causal=True, - softcap=0.0, - ): - from text_generation_server.layers.attention.flashinfer import ( - prefill_with_paged_kv_state, - ) - - return prefill_with_paged_kv_state.get().forward( - q.contiguous(), - causal=causal, - paged_kv_cache=(key_cache, value_cache), - logits_soft_cap=softcap, - sm_scale=softmax_scale, - window_left=window_size_left, - ) - -elif V2: - - def attention( - q, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - seqlen: Seqlen, - block_tables: torch.Tensor, - softmax_scale, - window_size_left=-1, - causal=True, - softcap=0.0, - ): - out = torch.empty_like(q) - if window_size_left <= 0 and window_size_left != -1: - raise ValueError("`window_size_left` must be > 0 or -1") - return flash_attn_2_cuda.varlen_fwd( - q, - key_cache, - value_cache, - out, - seqlen.cu_seqlen_q, - seqlen.cu_seqlen_k, - None, - None, - block_tables, - None, - seqlen.max_q, - seqlen.max_k, - 0.0, - softmax_scale, - False, - causal, - window_size_left, - 0, - softcap, - False, - None, - )[0] - -else: - - def attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - seqlen: Seqlen, - block_tables: torch.Tensor, - softmax_scale: float, - window_size_left: int = -1, - causal: bool = True, - softcap=None, - ): - if window_size_left != -1: - raise NotImplementedError( - "window_size_left is only available with flash attn v2" - ) - if softcap is not None: - raise NotImplementedError("softcap is only available with flash attn v2") - - # Flash attention v1 requires q, k and v to have the same number of heads - if k.shape[1] != q.shape[1]: - # MQA expand - if k.shape[1] == 1: - k = k.expand(-1, q.shape[1], -1) - # Grouped attention reshape - else: - original_shape = k.shape - k = ( - k.unsqueeze(2) - .expand(-1, -1, q.shape[1] // k.shape[1], -1) - .reshape(original_shape[0], -1, original_shape[2]) - ) - if v.shape[1] != q.shape[1]: - # MQA expand - if v.shape[1] == 1: - v = v.expand(-1, q.shape[1], -1) - # Grouped attention reshape - else: - original_shape = v.shape - v = ( - v.unsqueeze(2) - .expand(-1, -1, q.shape[1] // v.shape[1], -1) - .reshape(original_shape[0], -1, original_shape[2]) - ) - - out = torch.empty_like(q) - flash_attn_cuda.fwd( - q, - k, - v, - out, - seqlen.cu_seqlen_q, - seqlen.cu_seqlen_q, - seqlen.max_q, - seqlen.max_k, - 0.0, - softmax_scale, - False, - causal, - False, - 0, - None, - ) - return out - - -# Prefill in the cache with every kind of attention, unless we -# have a configuration that requires flash-attention v1, which -# does not support block tables. -PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2 diff --git a/backends/gaudi/server/text_generation_server/layers/attention/flash_attn_triton.py b/backends/gaudi/server/text_generation_server/layers/attention/flash_attn_triton.py deleted file mode 100644 index 3a6f9a730..000000000 --- a/backends/gaudi/server/text_generation_server/layers/attention/flash_attn_triton.py +++ /dev/null @@ -1,813 +0,0 @@ -#!/usr/bin/env python -""" -Fused Attention -=============== - -This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao -(https://tridao.me/publications/flash2/flash2.pdf) -Credits: OpenAI kernel team, AMD ML Frameworks Triton team - -Features supported: - -1) Fwd with causal masking -2) Any sequence lengths without padding (currently fwd kernel only) -3) Support for different sequence lengths for q and k -4) Nested tensor API currently does not support dropout or bias. - -Not currently supported: - -1) Non power of two head dims - -""" - -import torch -import triton -import triton.language as tl - -torch_dtype: tl.constexpr = torch.float16 - - -@triton.jit -def cdiv_fn(x, y): - return (x + y - 1) // y - - -@triton.jit -def max_fn(x, y): - return tl.math.max(x, y) - - -@triton.jit -def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): - ms = tl.arange(0, m) - ns = tl.arange(0, n) - return philox_offset + ms[:, None] * stride + ns[None, :] - - -@triton.jit -def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_offsets = dropout_offsets( - philox_seed, philox_offset, dropout_p, m, n, stride - ).to(tl.uint32) - # TODO: use tl.randint for better performance - return tl.rand(philox_seed, rng_offsets) - - -@triton.jit -def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride) - rng_keep = rng_output > dropout_p - return rng_keep - - -@triton.jit -def load_fn(block_ptr, first, second, pad): - if first and second: - tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) - elif first: - tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad) - elif second: - tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad) - else: - tensor = tl.load(block_ptr) - return tensor - - -@triton.jit -def _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - actual_seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - block_min, - block_max, - offs_n_causal, - masked_blocks, - n_extra_tokens, - bias_ptr, - IS_CAUSAL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - OFFS_M: tl.constexpr, - OFFS_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - MASK_STEPS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_ENCODED_SOFTMAX: tl.constexpr, - PADDED_HEAD: tl.constexpr, -): - # loop over k, v, and update accumulator - for start_n in range(block_min, block_max, BLOCK_N): - # For padded blocks, we will overrun the tensor size if - # we load all BLOCK_N. For others, the blocks are all within range. - k = load_fn( - K_block_ptr, - PADDED_HEAD, - MASK_STEPS and (n_extra_tokens != 0), - "zero", - ) - if PRE_LOAD_V: - v = load_fn( - V_block_ptr, - MASK_STEPS and (n_extra_tokens != 0), - PADDED_HEAD, - "zero", - ) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - # We start from end of seqlen_k so only the first iteration would need - # to be checked for padding if it is not a multiple of block_n - # TODO: This can be optimized to only be true for the padded block. - if MASK_STEPS: # noqa: SIM102 - # If this is the last block / iteration, we want to - # mask if the sequence length is not a multiple of block size - # a solution is to always do BLOCK_M // BLOCK_N + 1 steps - # if not is_modulo_mn. last step might get wasted but that is okay. - # check if this masking works for that case. - if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): - boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) - size_n = start_n + OFFS_N[None, :] - mask = size_n < boundary_m[:, None] - qk = tl.where(mask, qk, float("-inf")) - if IS_CAUSAL: - causal_boundary = start_n + offs_n_causal - causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] - qk = tl.where(causal_mask, qk, float("-inf")) - # -- compute qk ---- - qk += tl.dot(q, k) - if bias_ptr is not None: - bias = load_fn( - bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero" - ) - # While bias is added after multiplying qk with sm_scale, our - # optimization to use 2^x instead of e^x results in an additional - # scale factor of log2(e) which we must also multiply the bias with. - qk += bias * 1.44269504089 - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk = qk - m_ij[:, None] - p = tl.math.exp2(qk) - - # CAVEAT: Must update l_ij before applying dropout - l_ij = tl.sum(p, 1) - if ENABLE_DROPOUT: - philox_offset = ( - batch_philox_offset - + start_m * BLOCK_M * actual_seqlen_k - + start_n - - BLOCK_N - ) - keep = dropout_mask( - philox_seed, - philox_offset, - dropout_p, - BLOCK_M, - BLOCK_N, - actual_seqlen_k, - ) - if RETURN_ENCODED_SOFTMAX: - tl.store( - encoded_softmax_block_ptr, - tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty), - ) - p = tl.where(keep, p, 0.0) - elif RETURN_ENCODED_SOFTMAX: - tl.store( - encoded_softmax_block_ptr, - p.to(encoded_softmax_block_ptr.type.element_ty), - ) - # -- update output accumulator -- - alpha = tl.math.exp2(m_i - m_ij) - acc = acc * alpha[:, None] - if not PRE_LOAD_V: - v = load_fn( - V_block_ptr, - MASK_STEPS and (n_extra_tokens != 0), - PADDED_HEAD, - "zero", - ) - # -- update m_i and l_i - l_i = l_i * alpha + l_ij - # update m_i and l_i - m_i = m_ij - acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) - if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance( - encoded_softmax_block_ptr, (0, BLOCK_N) - ) - return acc, l_i, m_i - - -@triton.autotune( - configs=[ - triton.Config( - { - "BLOCK_M": 256, - "BLOCK_N": 64, - "waves_per_eu": 2, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=8, - ), - triton.Config( - { - "BLOCK_M": 128, - "BLOCK_N": 128, - "waves_per_eu": 2, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=4, - ), - triton.Config( - { - "BLOCK_M": 256, - "BLOCK_N": 128, - "waves_per_eu": 2, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=8, - ), - triton.Config( - { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 3, - "PRE_LOAD_V": True, - }, - num_stages=1, - num_warps=4, - ), - triton.Config( - { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 3, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=4, - ), - triton.Config( - { - "BLOCK_M": 64, - "BLOCK_N": 64, - "waves_per_eu": 4, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=8, - ), - triton.Config( - { - "BLOCK_M": 32, - "BLOCK_N": 32, - "waves_per_eu": 4, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=8, - ), - # TODO: This config fails with head_size not pow2 with data mismatches. - # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, - # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config( - { - "BLOCK_M": 16, - "BLOCK_N": 16, - "waves_per_eu": 1, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=4, - ), - triton.Config( - { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 1, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=4, - ), - ], - key=["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"], -) -@triton.jit -def attn_fwd( - Q, - K, - V, - bias, - sm_scale, - L, - Out, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vk, - stride_vn, - stride_oz, - stride_oh, - stride_om, - stride_on, - stride_bz, - stride_bh, - stride_bm, - stride_bn, - cu_seqlens_q, - cu_seqlens_k, - dropout_p, - philox_seed, - philox_offset_base, - encoded_softmax, - HQ: tl.constexpr, - HK: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, - MAX_SEQLENS_Q: tl.constexpr, - MAX_SEQLENS_K: tl.constexpr, - VARLEN: tl.constexpr, - IS_CAUSAL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - BIAS_TYPE: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_ENCODED_SOFTMAX: tl.constexpr, -): - start_m = tl.program_id(0) - off_h_q = tl.program_id(1) - off_z = tl.program_id(2) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - if VARLEN: - cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) - cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) - seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start - # We have a one-size-fits-all grid in id(0). Some seqlens might be too - # small for all start_m so for those we return early. - if start_m * BLOCK_M > seqlen_q: - return - cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) - cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) - seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start - else: - cu_seqlens_q_start = 0 - cu_seqlens_k_start = 0 - seqlen_q = MAX_SEQLENS_Q - seqlen_k = MAX_SEQLENS_K - - # Now we compute whether we need to exit early due to causal masking. - # This is because for seqlen_q > seqlen_k, M rows of the attn scores - # are completely masked, resulting in 0s written to the output, and - # inf written to LSE. We don't need to do any GEMMs in this case. - # This block of code determines what N is, and if this WG is operating - # on those M rows. - n_blocks = cdiv_fn(seqlen_k, BLOCK_N) - if IS_CAUSAL: - # If seqlen_q == seqlen_k, the attn scores are a square matrix. - # If seqlen_q != seqlen_k, attn scores are rectangular which means - # the causal mask boundary is bottom right aligned, and ends at either - # the top edge (seqlen_q < seqlen_k) or left edge. - # This captures the decrease in n_blocks if we have a rectangular attn - # matrix - n_blocks_seqlen = cdiv_fn( - (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N - ) - # This is what adjusts the block_max for the current WG, only - # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks - n_blocks = min(n_blocks, n_blocks_seqlen) - # If we have no blocks after adjusting for seqlen deltas, this WG is - # part of the blocks that are all 0. We exit early. - if n_blocks <= 0: - o_offset = ( - off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh - ) - O_block_ptr = tl.make_block_ptr( - base=Out + o_offset, - shape=(seqlen_q, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) - # We still need to write 0s to the result - # tl.store(O_block_ptr, - # acc.to(Out.type.element_ty), boundary_check=(0,1)) - # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q - # + offs_m - # We store inf to LSE, not -inf because in the bwd pass, - # we subtract this - # from qk which makes it -inf, such that exp(qk - inf) = 0 - # for these masked blocks. - # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) - # tl.store(l_ptrs, l) - # TODO: Should dropout and return encoded softmax be handled here? - return - - # If MQA / GQA, set the K and V head offsets appropriately. - GROUP_SIZE: tl.constexpr = HQ // HK - if GROUP_SIZE != 1: - off_h_k = off_h_q // GROUP_SIZE - else: - off_h_k = off_h_q - - n_extra_tokens = 0 - if seqlen_k < BLOCK_N: - n_extra_tokens = BLOCK_N - seqlen_k - elif seqlen_k % BLOCK_N: - n_extra_tokens = seqlen_k % BLOCK_N - PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL - - # Compute pointers for all the tensors used in this kernel. - q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm - Q_block_ptr = tl.make_block_ptr( - base=Q + q_offset, - shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn - K_block_ptr = tl.make_block_ptr( - base=K + k_offset, - shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), - strides=(stride_kk, stride_kn), - offsets=(0, 0), - block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1), - ) - v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk - V_block_ptr = tl.make_block_ptr( - base=V + v_offset, - shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0), - ) - if BIAS_TYPE != 0: - bias_ptr = tl.make_block_ptr( - base=bias + off_h_q * stride_bh, - shape=(seqlen_q, seqlen_k), - strides=(stride_bm, stride_bn), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0), - ) - else: - bias_ptr = None - if ENABLE_DROPOUT: - batch_philox_offset = ( - philox_offset_base + (off_z * HQ + off_h_q) * seqlen_q * seqlen_k - ) - else: - batch_philox_offset = 0 - # We can ask to return the dropout mask without actually doing any dropout. - # In this case, we return an invalid pointer so indicate the mask is not i - # valid. - # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. - if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.make_block_ptr( - base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, - shape=(seqlen_q, seqlen_k), - strides=(seqlen_k, 1), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0), - ) - else: - encoded_softmax_block_ptr = 0 - # initialize pointer to m and l - m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) - l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # scale sm_scale by log_2(e) and use 2^x in the loop as we do not - # have native e^x support in HW. - qk_scale = sm_scale * 1.44269504089 - # Q is loaded once at the beginning and shared by all N blocks. - q = load_fn(Q_block_ptr, True, PADDED_HEAD, "zero") - q = (q * qk_scale).to(Q_block_ptr.type.element_ty) - - # Here we compute how many full and masked blocks we have. - padded_block_k = n_extra_tokens != 0 - is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) - if IS_CAUSAL: - # There are always at least BLOCK_M // BLOCK_N masked blocks. - # Additionally there might be one more due to dissimilar seqlens. - masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) - else: - # Padding on Q does not need to be masked in the FA loop. - masked_blocks = padded_block_k - # if IS_CAUSAL, not is_modulo_mn does not always result in an additional - # block. In this case we might exceed n_blocks so pick the min. - masked_blocks = min(masked_blocks, n_blocks) - n_full_blocks = n_blocks - masked_blocks - block_min = 0 - block_max = n_blocks * BLOCK_N - # Compute for full blocks. Here we set causal to false regardless of its - # value because there is no masking. Similarly we do not need padding. - if n_full_blocks > 0: - block_max = (n_blocks - masked_blocks) * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ - block_min, - block_max, - 0, - 0, - 0, - bias_ptr, - # IS_CAUSAL, .... - False, - BLOCK_M, - BLOCK_DMODEL, - BLOCK_N, - offs_m, - offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, - False, - ENABLE_DROPOUT, - RETURN_ENCODED_SOFTMAX, - PADDED_HEAD, - ) - block_min = block_max - block_max = n_blocks * BLOCK_N - - tl.debug_barrier() - # Remaining blocks, if any, are full / not masked. - if masked_blocks > 0: - offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 - K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) - V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) - if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance( - encoded_softmax_block_ptr, (0, n_full_blocks) - ) - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - block_min, - block_max, - offs_n_causal, - masked_blocks, - n_extra_tokens, - bias_ptr, - IS_CAUSAL, - BLOCK_M, - BLOCK_DMODEL, - BLOCK_N, - offs_m, - offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, - True, - ENABLE_DROPOUT, - RETURN_ENCODED_SOFTMAX, - PADDED_HEAD, - ) - # epilogue - acc = acc / l_i[:, None] - if ENABLE_DROPOUT: - acc = acc / (1 - dropout_p) - # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, - # then we have one block with a row of all NaNs which come from computing - # softmax over a row of all -infs (-inf - inf = NaN). We check for that here - # and store 0s where there are NaNs as these rows should've been zeroed out. - end_m_idx = (start_m + 1) * BLOCK_M - start_m_idx = start_m * BLOCK_M - causal_start_idx = seqlen_q - seqlen_k - acc = acc.to(Out.type.element_ty) - if IS_CAUSAL: # noqa: SIM102 - if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: - out_mask_boundary = tl.full( - (BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32 - ) - mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) - out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] - z = 0.0 - acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) - # write back LSE - # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m - # If seqlen_q not multiple of BLOCK_M, we need to mask out the last - # few rows. This is only true for the last M block. For others, - # overflow_size will be -ve - # overflow_size = end_m_idx - seqlen_q - # if overflow_size > 0: - # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) - # # This is a > check because mask being 0 blocks the store. - # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) - # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) - # else: - # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) - - # write back O - o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh - O_block_ptr = tl.make_block_ptr( - base=Out + o_offset, - shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - # Need boundary check on this to make sure the padding from the - # Q and KV tensors in both dims are not part of what we store back. - # TODO: Do the boundary check optionally. - tl.store(O_block_ptr, acc, boundary_check=(0, 1)) - - -def check_args( - q, - k, - v, - o, - varlen=True, - max_seqlens=None, - cu_seqlens_q=None, - cu_seqlens_k=None, -): - assert q.dim() == k.dim() and q.dim() == v.dim() - if varlen: - assert q.dim() == 3 - total_q, nheads_q, head_size = q.shape - total_k, nheads_k, _ = k.shape - assert cu_seqlens_q is not None - assert cu_seqlens_k is not None - assert len(cu_seqlens_q) == len(cu_seqlens_k) - else: - assert q.dim() == 4 - batch, nheads_q, seqlen_q, head_size = q.shape - _, nheads_k, seqlen_k, _ = k.shape - assert max_seqlens > 0 - assert k.shape == v.shape - assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] - # TODO: Change assert if we support qkl f8 and v f16 - assert q.dtype == k.dtype and q.dtype == v.dtype - # TODO: Fix assert to check head size <=256 once supported - assert head_size <= 128 - assert o.shape == q.shape - assert (nheads_q % nheads_k) == 0 - - -class _attention(torch.autograd.Function): - - @staticmethod - def forward( - ctx, - q, - k, - v, - o, - cu_seqlens_q, - cu_seqlens_k, - max_seqlens_q, - max_seqlens_k, - causal=False, - sm_scale=1.0, - bias=None, - ): - if o is None: - o = torch.empty_like(q, dtype=v.dtype) - - check_args( - q, - k, - v, - o, - varlen=True, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - ) - if True: # varlen - total_q, nheads_q, head_size = q.shape - total_k, nheads_k, _ = k.shape - batch = len(cu_seqlens_q) - 1 - q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) - k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) - v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) - o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) - else: - batch, seqlen_q, nheads_q, head_size = q.shape - _, seqlen_k, nheads_k, _ = k.shape - q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) - k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) - v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) - o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) - - # Get closest power of 2 over or equal to 32. - padded_d_model = 1 << (head_size - 1).bit_length() - padded_d_model = max(padded_d_model, 16) - - def grid(META): - return triton.cdiv(max_seqlens_q, META["BLOCK_M"]), nheads_q, batch - - encoded_softmax = None - - # Seed the RNG so we get reproducible results for testing. - philox_seed = 0x1BF52 - philox_offset = 0x1D4B42 - - if bias is not None: - bias_strides = ( - bias.stride(0), - bias.stride(1), - bias.stride(2), - bias.stride(3), - ) - else: - bias_strides = (0, 0, 0, 0) - - attn_fwd[grid]( - q, - k, - v, - bias, - sm_scale, - None, - o, - *q_strides, - *k_strides, - *v_strides, - *o_strides, - *bias_strides, - cu_seqlens_q, - cu_seqlens_k, - dropout_p=0.0, - philox_seed=philox_seed, - philox_offset_base=philox_offset, - encoded_softmax=encoded_softmax, - HQ=nheads_q, - HK=nheads_k, - ACTUAL_BLOCK_DMODEL=head_size, - MAX_SEQLENS_Q=max_seqlens_q, - MAX_SEQLENS_K=max_seqlens_k, - IS_CAUSAL=causal, - VARLEN=True, - BLOCK_DMODEL=padded_d_model, - BIAS_TYPE=0 if bias is None else 1, - ENABLE_DROPOUT=False, - RETURN_ENCODED_SOFTMAX=False, - ) - - ctx.grid = grid - ctx.sm_scale = sm_scale - ctx.BLOCK_DMODEL = head_size - ctx.causal = causal - ctx.dropout_p = 0.0 - ctx.philox_seed = philox_seed - ctx.philox_offset = philox_offset - ctx.encoded_softmax = encoded_softmax - ctx.return_encoded_softmax = False - return o, encoded_softmax - - -triton_attention = _attention.apply diff --git a/backends/gaudi/server/text_generation_server/layers/attention/flashinfer.py b/backends/gaudi/server/text_generation_server/layers/attention/flashinfer.py deleted file mode 100644 index d603c6f5f..000000000 --- a/backends/gaudi/server/text_generation_server/layers/attention/flashinfer.py +++ /dev/null @@ -1,251 +0,0 @@ -from typing import Optional -from contextvars import ContextVar -from contextlib import contextmanager - -import flashinfer -import torch - -prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = ContextVar( - "prefill_state" -) - -prefill_with_paged_kv_state: ContextVar[ - flashinfer.BatchPrefillWithPagedKVCacheWrapper -] = ContextVar("prefill_with_paged_kv_state") - -decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar( - "decode_state" -) - -workspace: Optional[torch.Tensor] = None - - -def get_workspace(device): - """Get shared flashinfer workspace.""" - global workspace - if workspace is None: - workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) - return workspace - - -def create_prefill_with_paged_kv_state( - *, - device: torch.device, -): - """Create a prefill state that uses the KV cache.""" - workspace_buffer = get_workspace(device) - return flashinfer.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, kv_layout="NHD", use_cuda_graph=False - ) - - -@contextmanager -def use_prefill_with_paged_kv_state( - *, - state: flashinfer.BatchPrefillWithPagedKVCacheWrapper, - block_tables: torch.Tensor, - cu_seqlens: torch.Tensor, - input_lengths: torch.Tensor, - num_heads: int, - num_kv_heads: int, - head_size: int, - page_size: int, - dtype: torch.dtype, - window_left: int, -): - """ - Context manager to set the active flashinfer prefill state to the given - `state` and parameters. This state will be used by all calls to the - `attention` function while the context manager is active. - """ - - indptr = torch.zeros( - input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32 - ) - # Round up to page size and then calculate the cumulative sum to get - # the indices into the block table. - torch.add(input_lengths, page_size - 1, out=indptr[1:]) - indptr[1:].div_(page_size, rounding_mode="floor") - indptr[1:].cumsum_(-1) - - # Get the lengths of the last page in a block. - if page_size == 1: - last_page_len = torch.ones( - input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device - ) - else: - last_page_len = torch.empty( - input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device - ) - torch.sub(input_lengths, 1, out=last_page_len) - last_page_len.remainder_(page_size) - last_page_len += 1 - - token = prefill_with_paged_kv_state.set(state) - try: - state.begin_forward( - qo_indptr=cu_seqlens, - paged_kv_indptr=indptr, - paged_kv_indices=block_tables, - paged_kv_last_page_len=last_page_len, - num_qo_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_size, - q_data_type=dtype, - page_size=page_size, - window_left=window_left, - ) - yield - finally: - state.end_forward() - if token is not None: - prefill_with_paged_kv_state.reset(token) - - -def create_prefill_state( - *, - device: torch.device, -): - """Create a prefill state.""" - workspace_buffer = get_workspace(device) - return flashinfer.BatchPrefillWithRaggedKVCacheWrapper( - workspace_buffer, kv_layout="NHD", use_cuda_graph=False - ) - - -@contextmanager -def use_prefill_state( - *, - state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper, - cu_seqlens: torch.Tensor, - num_heads: int, - num_kv_heads: int, - head_size: int, - dtype: torch.dtype, - window_left: int, -): - """ - Context manager to set the active flashinfer prefill state to the given - `state` and parameters. This state will be used by all calls to the - `attention` function while the context manager is active. - """ - - token = prefill_state.set(state) - try: - state.begin_forward( - qo_indptr=cu_seqlens, - kv_indptr=cu_seqlens, - num_qo_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_size, - q_data_type=dtype, - window_left=window_left, - ) - yield - finally: - state.end_forward() - if token is not None: - prefill_state.reset(token) - - -def create_decode_state( - *, - device: torch.device, - num_heads: int, - num_kv_heads: int, -): - """Create a decode state.""" - workspace_buffer = get_workspace(device) - num_groups = num_heads // num_kv_heads - return flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, - kv_layout="NHD", - use_cuda_graph=False, - # Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60 - use_tensor_cores=num_groups not in [1, 2, 4, 8], - ) - - -def create_decode_state_cuda_graphs( - *, - device: torch.device, - block_tables: torch.Tensor, - block_tables_ptr: torch.Tensor, - last_page_len: torch.Tensor, - num_heads: int, - num_kv_heads: int, -): - """ - Create a decode state for use with CUDA Graphs. `block_tables`, - `block_tables_ptr`, and `last_page_len` are used in CUDA Graphs and are - therefore stored as part of the state. - """ - workspace_buffer = get_workspace(device) - num_groups = num_heads // num_kv_heads - return flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, - kv_layout="NHD", - use_cuda_graph=True, - paged_kv_indices_buffer=block_tables, - paged_kv_indptr_buffer=block_tables_ptr, - paged_kv_last_page_len_buffer=last_page_len, - # Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60 - use_tensor_cores=num_groups not in [1, 2, 4, 8], - ) - - -@contextmanager -def use_decode_state( - *, - state: flashinfer.BatchDecodeWithPagedKVCacheWrapper, - input_lengths: torch.Tensor, - block_tables: torch.Tensor, - num_heads: int, - num_kv_heads: int, - head_size: int, - page_size: int, - dtype: torch.dtype, - window_left: int, -): - """ - Context manager to set the active flashinfer decoding state to the given - `state` and parameters. This state will be used by all calls to the - `paged_attention` function while the context manager is active. - """ - indptr = torch.zeros( - input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32 - ) - # Round up to page size and then calculate the cumulative sum to get - # the indices into the block table. - torch.add(input_lengths, page_size - 1, out=indptr[1:]) - indptr[1:].div_(page_size, rounding_mode="floor") - indptr[1:].cumsum_(-1) - - # Get the lengths of the last page in a block. - last_page_len = torch.empty( - input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device - ) - torch.sub(input_lengths, 1, out=last_page_len) - last_page_len.remainder_(page_size) - last_page_len += 1 - - token = decode_state.set(state) - - try: - state.begin_forward( - indptr=indptr, - indices=block_tables, - last_page_len=last_page_len, - num_qo_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_size, - page_size=page_size, - data_type=dtype, - q_data_type=dtype, - window_left=window_left, - ) - yield - finally: - state.end_forward() - if token is not None: - decode_state.reset(token) diff --git a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py new file mode 100644 index 000000000..f34e93abc --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py @@ -0,0 +1,95 @@ +import torch +from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata +from typing import Optional +from text_generation_server.layers.attention.kv_cache import KVCache, KVScales +from vllm_hpu_extension import ops +from vllm_hpu_extension.utils import Matmul +from habana_frameworks.torch.hpex.kernels import FusedSDPA +from vllm_hpu_extension.utils import ModuleFusedSDPA +import os + +SUPPORTS_WINDOWING = False + + +def fetch_from_cache(cache, blocks): + if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true": + return cache[: blocks.size(0)] + else: + return cache.index_select(0, blocks) + + +def attention( + *, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: KVCache, + kv_scales: KVScales, + seqlen: Seqlen, + softmax_scale: float, + window_size_left: int = -1, + causal: bool = True, + softcap: Optional[float] = None, +): + fsdpa_op = ModuleFusedSDPA(FusedSDPA) + bs = seqlen.input_lengths.shape[0] + _, head_num, head_size = query.shape + _, kv_head_num, head_size = key.shape + query = query.view(bs, -1, head_num, head_size).transpose(1, 2) + key = key.view(bs, -1, kv_head_num, head_size).transpose(1, 2) + value = value.view(bs, -1, kv_head_num, head_size).transpose(1, 2) + attn_output = fsdpa_op( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=causal, + scale=softmax_scale, + softmax_mode="None", + recompute_mode=None, + valid_sequence_lengths=seqlen.input_lengths, + padding_side="left", + ) + attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous() + return attn_output + + +def paged_attention( + query: torch.Tensor, + kv_cache: KVCache, + kv_head_mapping: torch.Tensor, + softmax_scale: float, + seqlen: Seqlen, + *, + kv_scales: KVScales, + softcap: Optional[float] = None, + hpu_attention_meta: HPUPagedAttentionMetadata, +): + batch_size, head_num, head_size = query.shape + output = ops.flat_pa( + query=query.view(batch_size, 1, head_num * head_size), + key_cache=kv_cache.key, + value_cache=kv_cache.value, + block_list=hpu_attention_meta.block_list, + block_mapping=hpu_attention_meta.block_mapping, + block_bias=hpu_attention_meta.attn_bias, + block_scales=hpu_attention_meta.block_scales, + block_groups=hpu_attention_meta.block_groups, + scale=softmax_scale, + matmul_qk_op=Matmul(), + matmul_av_op=Matmul(), + batch2block_matmul_op=Matmul(), + block2batch_matmul_op=Matmul(), + keys_fetch_func=fetch_from_cache, + values_fetch_func=fetch_from_cache, + ) + # Reshape the output tensor. + return output.view(batch_size, head_num, head_size) + + +__all__ = [ + "SUPPORTS_WINDOWING", + "attention", + "paged_attention", +] diff --git a/backends/gaudi/server/text_generation_server/layers/attention/ipex.py b/backends/gaudi/server/text_generation_server/layers/attention/ipex.py deleted file mode 100644 index 657c90af4..000000000 --- a/backends/gaudi/server/text_generation_server/layers/attention/ipex.py +++ /dev/null @@ -1,82 +0,0 @@ -import intel_extension_for_pytorch as ipex -import torch -from text_generation_server.models.flash_causal_lm import BLOCK_SIZE -from text_generation_server.layers.attention import Seqlen -from typing import Optional - -SUPPORTS_WINDOWING = False -PREFILL_IN_KV_CACHE = False - - -def attention( - q: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - seqlen: Seqlen, - block_tables: torch.Tensor, - softmax_scale, - window_size_left=-1, - causal=True, - softcap: Optional[float] = None, -): - out = torch.empty_like(q) - - # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. - ipex.llm.functional.varlen_attention( - q.contiguous() if q.device.type == "xpu" else q, - key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache, - value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache, - out, - seqlen.cu_seqlen_q, - seqlen.cu_seqlen_q, - seqlen.max_q, - seqlen.max_q, - 0.0, - softmax_scale, - False, - causal, - False, - None, - ) - - return out - - -def reshape_and_cache( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slots: torch.Tensor, -): - ipex.llm.modules.PagedAttention.reshape_and_cache( - key, value, key_cache, value_cache, slots - ) - - -def paged_attention( - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - kv_head_mapping: torch.Tensor, - softmax_scale: float, - block_tables: torch.Tensor, - seqlen: Seqlen, - max_s: int, - softcap: Optional[float] = None, -): - out = torch.empty_like(query) - ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( - out, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - seqlen.input_lengths, - BLOCK_SIZE, - max_s, - None, - ) - return out diff --git a/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py new file mode 100644 index 000000000..d238cdb97 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py @@ -0,0 +1,139 @@ +from typing import Tuple +from dataclasses import dataclass, field + +import torch + +from text_generation_server.models.globals import BLOCK_SIZE +from text_generation_server.utils.weights import Weights +from vllm_hpu_extension import cache_ops + + +@dataclass +class KVScales: + """ + Key-value scales for FP8 KV cache. + + This data class stores key and value scales both as a GPU tensor and + as a GPU float. This inconvenience is necessary because some functions + (e.g. scaling kernels) take scales as a GPU tensor, whereas others + (e.g. flashinfer) take scales as a CPU scalar. + """ + + key_scale: torch.Tensor + value_scale: torch.Tensor + key_scale_cpu: float = field(init=False) + value_scale_cpu: float = field(init=False) + + def __post_init__(self): + if self.key_scale.numel() != 1 or self.value_scale.numel() != 1: + raise ValueError("Key and value scales must be scalar tensors.") + + self.key_scale_cpu = self.key_scale.item() + self.value_scale_cpu = self.value_scale.item() + + +class KVCache: + """ + Key-value cache for attention layers. + """ + + kv_cache: Tuple[torch.Tensor, torch.Tensor] + + def __init__( + self, + *, + num_blocks: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + ): + """Construct the key-value cache for a layer.""" + ## TODO FP8 kv cache support + + self.kv_cache = ( + torch.zeros( + (num_blocks, BLOCK_SIZE, num_heads, head_size), + dtype=dtype, + device=device, + ), + torch.zeros( + (num_blocks, BLOCK_SIZE, num_heads, head_size), + dtype=dtype, + device=device, + ), + ) + + @property + def dtype(self): + """Get the data type of the cache.""" + return self.kv_cache[0].dtype + + @property + def key(self): + """Get the key cache.""" + + return self.kv_cache[0] + + @property + def value(self): + """Get the value cache.""" + + return self.kv_cache[1] + + def store( + self, + *, + key: torch.Tensor, + value: torch.Tensor, + slots: torch.Tensor, + kv_scales: KVScales, + ): + """Store the key and value at the given slots.""" + ## TODO FP8 kv cache support + + key_cache = self.kv_cache[0] + value_cache = self.kv_cache[1] + + paged_reshape_and_cache( + key, + value, + key_cache, + value_cache, + slots, + kv_scales.key_scale_cpu, + kv_scales.value_scale_cpu, + ) + + +def paged_reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slots: torch.Tensor, + k_scale: float = 1.0, + v_scale: float = 1.0, +): + block_idx = slots // BLOCK_SIZE + block_offset = slots % BLOCK_SIZE + cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset) + cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset) + + +def get_kv_scales(weights: Weights, prefix: str) -> KVScales: + """Load KV cache scales.""" + + key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device) + value_scale = key_scale + if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor( + f"{prefix}.v_scale" + ): + key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float() + value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float() + elif weights.has_tensor(f"{prefix}.kv_scale"): + # Fall back to older more coarse-grained scale when available. + key_scale = weights.get_tensor(f"{prefix}.kv_scale").float() + value_scale = key_scale + + return KVScales(key_scale=key_scale, value_scale=value_scale) diff --git a/backends/gaudi/server/text_generation_server/layers/attention/rocm.py b/backends/gaudi/server/text_generation_server/layers/attention/rocm.py deleted file mode 100644 index 646a763d3..000000000 --- a/backends/gaudi/server/text_generation_server/layers/attention/rocm.py +++ /dev/null @@ -1,308 +0,0 @@ -import os -from typing import Optional -import torch -from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.models.globals import ATTENTION -from text_generation_server.layers.attention import Seqlen -from text_generation_server.utils.log import log_master -from loguru import logger - -major, minor = torch.cuda.get_device_capability() -is_sm75 = major == 7 and minor == 5 - -_PARTITION_SIZE_V1V2 = 512 -_PARTITION_SIZE_CUSTOM = 256 - -use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"} -ENGINE = "triton" if use_triton else "ck" - -PREFILL_IN_KV_CACHE = False - -use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0" -try: - if use_rocm_custom_paged_attn: - from vllm._custom_C import paged_attention_custom -except ImportError as e: - log_master( - logger.info, - f"Custom Paged Attention not available. Complete error: {e}", - ) - use_rocm_custom_paged_attn = False - -try: - import vllm._custom_ops as ops -except Exception as e: - raise ImportError( - f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}" - ) - - -def reshape_and_cache( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slots: torch.Tensor, -): - if ATTENTION == "flashdecoding": - shape = key_cache.shape - key_cache.view(-1, shape[-2], shape[-1])[slots] = key - value_cache.view(-1, shape[-2], shape[-1])[slots] = value - else: - ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0) - - -def paged_attention( - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - kv_head_mapping: torch.Tensor, - softmax_scale: float, - block_tables: torch.Tensor, - seqlen: Seqlen, - max_s: int, - softcap: Optional[float] = None, -): - # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py - # Copyright 2023 The vLLM team. All rights - # reserved. - # - # Licensed under the Apache License, Version 2.0 (the "License"); - # you may not use this file except in compliance with the License. - # You may obtain a copy of the License at - # - # http://www.apache.org/licenses/LICENSE-2.0 - # - # Unless required by applicable law or agreed to in writing, software - # distributed under the License is distributed on an "AS IS" BASIS, - # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - # See the License for the specific language governing permissions and - # limitations under the License. - # - - if softcap is not None: - raise RuntimeError("Paged attention doesn't support softcapping") - - # value_cache => [num_blocks, num_heads, head_size, block_size] - block_size = value_cache.shape[3] - num_seqs, num_heads, head_size = query.shape - - num_kv_heads = key_cache.shape[1] - gqa_ratio = num_heads // num_kv_heads - use_custom = ( - use_rocm_custom_paged_attn - and (query.dtype == torch.half or query.dtype == torch.bfloat16) - and (head_size == 128 or head_size == 64) - and (block_size == 16 or block_size == 32) - and (gqa_ratio >= 1 and gqa_ratio <= 16) - and max_s <= 32768 - ) - - if not use_custom: - _PARTITION_SIZE = _PARTITION_SIZE_V1V2 - else: - _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM - - max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE - input_lengths = seqlen.input_lengths - - out = torch.empty_like(query) - - # NOTE(woosuk): We use a simple heuristic to decide whether to use - # PagedAttention V1 or V2. If the number of partitions is 1, we use - # V1 to avoid the overhead of reduction. Also, if the number of - # sequences or heads is large, we use V1 since there is enough work - # to parallelize. - import vllm._custom_ops as ops - - use_v1 = ( - max_s <= 8192 - and (max_num_partitions == 1 or num_seqs * num_heads > 512) - and not use_custom - ) - if use_v1: - ops.paged_attention_v1( - out, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - 1.0, - ) - else: - # Run PagedAttention V2. - assert _PARTITION_SIZE % block_size == 0 - tmp_output = torch.empty( - size=(num_seqs, num_heads, max_num_partitions, head_size), - dtype=out.dtype, - device=out.device, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, max_num_partitions), - dtype=torch.float32, - device=out.device, - ) - max_logits = torch.empty_like(exp_sums) - - if not use_custom: - ops.paged_attention_v2( - out, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - 1.0, - ) - else: - paged_attention_custom( - out, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - ) - - return out - - -if ENGINE != "triton": - try: - import flash_attn_2_cuda - - log_master( - logger.info, - "ROCm: using Flash Attention 2 Composable Kernel implementation.", - ) - except ImportError as e: - if major >= 8: - architecture_suffix = f"-{SYSTEM}" - raise ImportError( - "Flash Attention V2 is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`" - ) - elif is_sm75: - raise ImportError( - "Flash Attention is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - "or install flash attention with `cd server && make install install-flash-attention`" - ) from e - else: - for idx in range(torch.cuda.device_count()): - name = torch.cuda.get_device_name(idx) - if "MI210" not in name and "MI250" not in name: - raise ImportError( - f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" - ) - raise ImportError( - f"AMD GPU with ROCm capability {major} {minor} is not supported" - ) from e - - -SUPPORTS_WINDOWING = False -if ENGINE == "ck": - - def attention( - q, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - seqlen: Seqlen, - block_tables: torch.Tensor, - softmax_scale: float, - window_size_left: int = -1, - causal: bool = True, - softcap: float = 0.0, - ): - if window_size_left <= 0 and window_size_left != -1: - raise ValueError("`window_size_left` must be > 0 or -1") - - out = torch.empty_like(q) - - # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. - return flash_attn_2_cuda.varlen_fwd( - q, - key_cache, - value_cache, - out, - seqlen.cu_seqlen_q, - seqlen.cu_seqlen_q, - None, - None, - None, - None, - seqlen.max_q, - seqlen.max_k, - 0.0, - softmax_scale, - False, - causal, - window_size_left, - 0, - softcap, - False, - None, - )[0] - -elif ENGINE == "triton": - from .flash_attn_triton import triton_attention - - def attention( - q, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - seqlen: Seqlen, - block_tables: torch.Tensor, - softmax_scale: float, - window_size_left: int = -1, - causal: bool = True, - softcap: Optional[float] = None, - ): - if softcap is not None: - raise NotImplementedError("softcap is only available with CK flash attn") - - out = torch.empty_like(q) - - # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. - output, _ = triton_attention( - q, - key_cache, - value_cache, - out, - seqlen.cu_seqlen_q, - seqlen.cu_seqlen_q, - seqlen.max_q, - seqlen.max_k, - causal, - softmax_scale, - ) - return output - -else: - raise RuntimeError(f"Unknown attention engine {ENGINE}") diff --git a/backends/gaudi/server/text_generation_server/layers/awq/quantize/__init__.py b/backends/gaudi/server/text_generation_server/layers/awq/quantize/__init__.py new file mode 100644 index 000000000..856d7c281 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/awq/quantize/__init__.py @@ -0,0 +1,3 @@ +from .hpu import WQLinear + +__all__ = ["WQLinear"] diff --git a/backends/gaudi/server/text_generation_server/layers/awq/quantize/hpu.py b/backends/gaudi/server/text_generation_server/layers/awq/quantize/hpu.py new file mode 100644 index 000000000..3af0131b3 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/awq/quantize/hpu.py @@ -0,0 +1,134 @@ +from typing import Optional +import torch +import torch.nn as nn + +try: + import habana_frameworks.torch.hpu # noqa: F401 + + convert_from_uint4 = torch.ops.hpu.convert_from_uint4 +except Exception as e: + hpu_import_exception = e + + def error_raiser_hpu(*args, **kwargs): + raise ValueError( + f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}" + ) + + convert_from_uint4 = error_raiser_hpu + +AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] + + +def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int): + shifts = torch.arange(0, 32, bits, device=qzeros.device) + + # unpacking columnwise + iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to( + torch.int8 # smallest dtype available + ) + iweights = iweights.view(iweights.shape[0], -1) + + # unpacking columnwise + if qzeros is not None: + izeros = torch.bitwise_right_shift( + qzeros[:, :, None], shifts[None, None, :] + ).to( + torch.int8 # smallest dtype available + ) + izeros = izeros.view(izeros.shape[0], -1) + else: + izeros = qzeros + + return iweights, izeros + + +def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int): + reverse_order_tensor = torch.arange( + iweights.shape[-1], + dtype=torch.int32, + device=izeros.device, + ) + reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits) + reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER] + reverse_order_tensor = reverse_order_tensor.view(-1) + + if izeros is not None: + izeros = izeros[:, reverse_order_tensor] + iweights = iweights[:, reverse_order_tensor] + + return iweights, izeros + + +def unpack_weight_and_zeros(qweight, qzeros, bits): + # Unpack the qweight and qzeros tensors + iweight, izeros = unpack_awq(qweight, qzeros, bits) + # Reverse the order of the iweight and izeros tensors + iweight, izeros = reverse_awq_order(iweight, izeros, bits) + + # overflow checks + iweight = torch.bitwise_and(iweight, (2**bits) - 1) + izeros = torch.bitwise_and(izeros, (2**bits) - 1) + + return iweight, izeros + + +def pack_tensor(input, bits=4): + normal = input.to(torch.int32) + q = torch.zeros( + (normal.shape[0], normal.shape[1] // 32 * bits), + dtype=torch.int32, + device=input.device, + ) + i = 0 + col = 0 + while col < q.shape[1]: + for j in range(i, i + (32 // bits)): + q[:, col] |= normal[:, j] << (bits * (j - i)) + i += 32 // bits + col += 1 + q = q.to(torch.int32) + return q + + +class WQLinear(nn.Module): + def __init__( + self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor] + ): + super().__init__() + + if w_bit not in [4]: + raise NotImplementedError("Only 4-bit are supported for now.") + + self.in_features = qweight.shape[0] + self.out_features = qweight.shape[1] * 32 // w_bit + + self.w_bit = w_bit + self.group_size = group_size if group_size != -1 else self.in_features + # quick sanity check (make sure aligment) + assert self.in_features % self.group_size == 0 + assert self.out_features % (32 // self.w_bit) == 0 + + self.qweight = qweight + self.qzeros = qzeros + self.scales = scales + self.bias = bias + self._preprocessing() + + def _preprocessing(self): + device = self.qweight.device + weight, zeros = unpack_weight_and_zeros( + self.qweight.cpu(), self.qzeros.cpu(), self.w_bit + ) + self.qweight = pack_tensor(weight).to(device) + self.qzeros = pack_tensor(zeros).to(device) + + @torch.no_grad() + def forward(self, x): + out_shape = x.shape[:-1] + (self.out_features,) + x = x.reshape(-1, x.shape[-1]) + weights = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype) + outputs = torch.matmul(x, weights) + + outputs = outputs + self.bias if self.bias is not None else outputs + outputs = outputs.reshape(out_shape) + return outputs diff --git a/backends/gaudi/server/text_generation_server/layers/awq/quantize/qmodule.py b/backends/gaudi/server/text_generation_server/layers/awq/quantize/qmodule.py deleted file mode 100644 index 391371a55..000000000 --- a/backends/gaudi/server/text_generation_server/layers/awq/quantize/qmodule.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py - -from typing import Optional -import torch -import torch.nn as nn -import awq_inference_engine # with CUDA kernels - - -# class ScaledActivation(nn.Module): -# def __init__(self, module, scales): -# super().__init__() -# self.act = module -# self.scales = nn.Parameter(scales.data) -# -# def forward(self, x): -# return self.act(x) / self.scales.view(1, 1, -1).to(x.device) - - -class WQLinear(nn.Module): - def __init__( - self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor] - ): - super().__init__() - - if w_bit not in [4]: - raise NotImplementedError("Only 4-bit are supported for now.") - - self.in_features = qweight.shape[0] - self.out_features = qweight.shape[1] * 32 // w_bit - - self.w_bit = w_bit - self.group_size = group_size if group_size != -1 else self.in_features - # quick sanity check (make sure aligment) - assert self.in_features % self.group_size == 0 - assert self.out_features % (32 // self.w_bit) == 0 - - self.qweight = qweight - self.qzeros = qzeros - self.scales = scales - self.bias = bias - - @torch.no_grad() - def forward(self, x): - out_shape = x.shape[:-1] + (self.out_features,) - out = awq_inference_engine.gemm_forward_cuda( - x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8 - ) - out = out + self.bias if self.bias is not None else out - return out.reshape(out_shape) diff --git a/backends/gaudi/server/text_generation_server/layers/eetq.py b/backends/gaudi/server/text_generation_server/layers/eetq.py deleted file mode 100644 index b1e5235a0..000000000 --- a/backends/gaudi/server/text_generation_server/layers/eetq.py +++ /dev/null @@ -1,43 +0,0 @@ -from dataclasses import dataclass - -import torch -from EETQ import quant_weights, w8_a16_gemm -from text_generation_server.utils.weights import UnquantizedWeight - - -@dataclass -class EETQWeight(UnquantizedWeight): - weight: torch.Tensor - - def get_linear(self, bias: torch.Tensor): - try: - from text_generation_server.layers.eetq import EETQLinear - - return EETQLinear(self.weight, bias) - except ImportError: - raise ImportError( - "Please install EETQ from https://github.com/NetEase-FuXi/EETQ" - ) - - -class EETQLinear(torch.nn.Module): - def __init__( - self, - weight, - bias, - ) -> None: - super().__init__() - device = weight.device - if weight.dtype != torch.float16: - weight = weight.to(dtype=torch.float16) - weight = torch.t(weight).contiguous().cpu() - weight, scale = quant_weights(weight, torch.int8, False) - - self.weight = weight.cuda(device) - self.scale = scale.cuda(device) - self.bias = bias.cuda(device) if bias is not None else None - - def forward(self, input: torch.Tensor) -> torch.Tensor: - output = w8_a16_gemm(input, self.weight, self.scale) - output = output + self.bias if self.bias is not None else output - return output diff --git a/backends/gaudi/server/text_generation_server/layers/fp8.py b/backends/gaudi/server/text_generation_server/layers/fp8.py index 61dd51151..0dc5cdafd 100644 --- a/backends/gaudi/server/text_generation_server/layers/fp8.py +++ b/backends/gaudi/server/text_generation_server/layers/fp8.py @@ -1,100 +1,152 @@ +from dataclasses import dataclass +from typing import Optional, Tuple, Type, Union, List + import torch -from dataclasses import dataclass -from typing import Optional, Union, List -from loguru import logger - -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.weights import ( Weight, WeightsLoader, UnquantizedWeight, Weights, ) -from text_generation_server.utils.log import log_master, log_once -import importlib.util + +from vllm_hpu_extension.ops import scaled_fp8_quant +from vllm_hpu_extension.scales import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2 +import habana_frameworks.torch.utils.experimental as htexp + +w8a8_block_fp8_matmul = None +per_token_group_quant_fp8 = None +quant_dtype: torch.dtype = torch.float8_e4m3fn -FBGEMM_MM_AVAILABLE = False -FBGEMM_DYN_AVAILABLE = False - - -def is_fbgemm_gpu_available(): - try: - return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None - except ModuleNotFoundError: - return False - - -if is_fbgemm_gpu_available(): - if SYSTEM == "cuda": - major, _ = torch.cuda.get_device_capability() - FBGEMM_MM_AVAILABLE = major == 9 - FBGEMM_DYN_AVAILABLE = major >= 8 -else: - log_master(logger.warning, "FBGEMM fp8 kernels are not installed.") - - -def get_fp8_linear() -> torch.nn.Module: +def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]: """ Return an FP8 linear `Module` that is compatible with the current system. """ - - if SYSTEM == "cuda": - major, _ = torch.cuda.get_device_capability() - if major == 8: - from text_generation_server.layers.marlin import GPTQMarlinFP8Linear - - return GPTQMarlinFP8Linear - # On other systems let Torch decide if the hardware supports FP8. return Fp8Linear -def fp8_quantize( - weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False -): - if FBGEMM_DYN_AVAILABLE and not scalar: - qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row( - weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype - ) - return qweight, scale +def normalize_e4m3fn_to_native_float8( + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + return weight, weight_scale, input_scale - # weight, scale = quant_weights(weight, torch.int8, False) - finfo = torch.finfo(qdtype) - # Calculate the scale as dtype max divided by absmax - scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound) - # scale and clamp the tensor to bring it to - # the representative range of float8 data type - # (as default cast is unsaturated) - qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max) - # Return both float8 data and the inverse scale (as float), - # as both required as inputs to torch._scaled_mm - qweight = qweight.to(qdtype) - scale = scale.float().reciprocal() - return qweight, scale + +def per_tensor_dequantize( + tensor: torch.Tensor, + inv_scale: Union[float, torch.Tensor], + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + device = tensor.device + dtype = torch.bfloat16 + if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2: + # dequant on cpu to avoid nan on gaudi2 + tensor = tensor.to("cpu") + + fake_qweight = tensor.to(dtype).to(device) + dq_weight = fake_qweight * inv_scale + return dq_weight + + +def requantize_with_max_scale( + weight: torch.Tensor, + weight_scale: torch.Tensor, + logical_widths: int, + dtype: torch.dtype, +) -> Tuple[torch.Tensor, torch.Tensor]: + # Max scale to be used for requanitzation. + max_w_scale = weight_scale.max() + + if is_hpu_gaudi2(): + max_w_scale = max_w_scale * get_hpu_gaudi2_scale_factor() + + start = 0 + for idx, logical_width in enumerate(logical_widths): + end = start + logical_width + weight_dq = per_tensor_dequantize( + weight[start:end, :], weight_scale[idx], dtype + ) + weight[start:end, :], max_w_scale_normalized = fp8_quantize( + weight_dq, max_w_scale + ) + start = end + + return weight, max_w_scale_normalized + + +def fp8_quantize( + weight: torch.Tensor, + scale: Optional[torch.Tensor] = None, + scale_upper_bound: Optional[torch.Tensor] = None, + qdtype: torch.dtype = torch.float8_e4m3fn, + scalar: bool = False, +): + """ + This function returns a reciprocal of the scale, so that a tensor can be unscaled + by multiplying it with the returned scale. If a scale is given through the `scale` + argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can + be used without modification). + """ + shape = weight.shape + qweight, scale = scaled_fp8_quant( + weight.reshape(-1, shape[-1]), + scale=scale, + scale_ub=scale_upper_bound, + # TODO: don't do this when we have to use the Torch kernel. + use_per_token_if_dynamic=not scalar, + ) + + return qweight.reshape(shape), scale class HybridFP8UnquantLoader(WeightsLoader): """Weight loader that loads FP8 and unquantized Torch tensors.""" - def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool): + def __init__( + self, + activation_scale_ub: Optional[float], + to_fp8: bool, + weight_block_size: Optional[List[int]] = None, + ): self.activation_scale_ub = activation_scale_ub self.to_fp8 = to_fp8 + self.weight_block_size = weight_block_size def get_weights(self, weights: "Weights", prefix: str): w = weights.get_tensor(f"{prefix}.weight") if w.dtype == torch.float8_e4m3fn: + if self.weight_block_size is not None: + scale = weights.get_tensor(f"{prefix}.weight_scale_inv") + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + weight_block_size=self.weight_block_size, + ) # FP8 branch - scale = ( - weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) - .reshape(-1) - .expand(w.shape[0]) + scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + + input_scale = None + if weights.has_tensor(f"{prefix}.input_scale"): + input_scale = ( + weights.get_tensor(f"{prefix}.input_scale", to_dtype=False) + .reshape(-1) + .max() + ) + logical_widths = [w.shape[0]] + w, scale = requantize_with_max_scale( + w, scale.unsqueeze(0), logical_widths, weights.dtype ) + return Fp8Weight( weight=w, weight_scale=scale, + input_scale=input_scale, activation_scale_ub=self.activation_scale_ub, dtype=weights.dtype, ) @@ -116,6 +168,7 @@ class HybridFP8UnquantLoader(WeightsLoader): if w.dtype == torch.float8_e4m3fn: # FP8 branch scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + if scale.numel() > 1: scale = weights.get_packed_sharded( f"{prefix}.weight_scale", @@ -123,11 +176,29 @@ class HybridFP8UnquantLoader(WeightsLoader): block_sizes=block_sizes, to_dtype=False, ) - scale = scale.reshape(-1).expand(w.shape[0]) + + input_scale = None + if weights.has_tensor(f"{prefix}.input_scale"): + input_scale = weights.get_tensor( + f"{prefix}.input_scale", to_dtype=False + ) + if input_scale.numel() > 1: + input_scale = weights.get_packed_sharded( + f"{prefix}.input_scale", + dim=0, + block_sizes=block_sizes, + to_dtype=False, + ) + input_scale = input_scale.reshape(-1).max() + logical_widths = [w.shape[0]] + w, scale = requantize_with_max_scale( + w, scale.unsqueeze(0), logical_widths, weights.dtype + ) return Fp8Weight( weight=w, weight_scale=scale, + input_scale=input_scale, activation_scale_ub=self.activation_scale_ub, dtype=weights.dtype, ) @@ -148,15 +219,48 @@ class HybridFP8UnquantLoader(WeightsLoader): # FP8 branch if w.dtype == torch.float8_e4m3fn: + if self.weight_block_size is not None: + scale = [ + weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False) + for p in prefixes + ] + scale = torch.cat(scale, dim=dim) + scale = scale.to(weights.device) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + weight_block_size=self.weight_block_size, + ) + scale = [ _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape) for p, shape in zip(prefixes, shapes) ] scale = torch.cat(scale, dim=0).reshape(-1) + input_scale = [ + _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape) + for p, shape in zip(prefixes, shapes) + if weights.has_tensor(f"{p}.input_scale") + ] + assert len(input_scale) == 0 or len(input_scale) == len(prefixes) + input_scale = ( + torch.cat(input_scale, dim=0).reshape(-1).max() + if len(input_scale) != 0 + else None + ) + + logical_widths = [x[0] for x in shapes] + w, scale = requantize_with_max_scale( + w, scale.to(weights.device), logical_widths, weights.dtype + ) + return Fp8Weight( weight=w, weight_scale=scale, + input_scale=input_scale, activation_scale_ub=self.activation_scale_ub, dtype=weights.dtype, ) @@ -169,14 +273,35 @@ class HybridFP8UnquantLoader(WeightsLoader): w = weights.get_sharded(f"{prefix}.weight", dim=1) # FP8 branch if w.dtype == torch.float8_e4m3fn: - scale = ( - weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) - .reshape(-1) - .expand(w.shape[0]) + if self.weight_block_size is not None: + # XXX: Yes the weights is named scale_inv, but corresponds to scale it seems. + scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1) + + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + weight_block_size=self.weight_block_size, + ) + + scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + + input_scale = None + if weights.has_tensor(f"{prefix}.input_scale"): + input_scale = ( + weights.get_tensor(f"{prefix}.input_scale", to_dtype=False) + .reshape(-1) + .max() + ) + logical_widths = [w.shape[0]] + w, scale = requantize_with_max_scale( + w, scale.unsqueeze(0), logical_widths, weights.dtype ) return Fp8Weight( weight=w, weight_scale=scale, + input_scale=input_scale, activation_scale_ub=self.activation_scale_ub, dtype=weights.dtype, ) @@ -191,83 +316,126 @@ class Fp8Weight(Weight): weight: torch.Tensor dtype: torch.dtype weight_scale: Optional[torch.Tensor] = None + input_scale: Optional[torch.Tensor] = None activation_scale_ub: Optional[float] = None + force_w8a16: bool = False + weight_block_size: Optional[List[int]] = None def get_linear(self, bias: torch.Tensor): if self.weight_scale is None: - return get_fp8_linear().from_unquant(self.weight, bias, self.dtype) + return get_fp8_linear(force_w8a16=self.force_w8a16).from_unquant( + self.weight, bias, self.dtype + ) # This is not checked by the fbgemm kernels, but they require contiguous # memory. Can be non-contiguous when we e.g. expand from scalars. self.weight_scale = self.weight_scale.contiguous() - return get_fp8_linear().from_fp8( - self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype + return get_fp8_linear(force_w8a16=self.force_w8a16).from_fp8( + weight=self.weight, + scale=self.weight_scale, + dtype=self.dtype, + bias=bias, + input_scale=self.input_scale, + scale_upper_bound=self.activation_scale_ub, + weight_block_size=self.weight_block_size, ) class Fp8Linear(torch.nn.Module): + _device_identity_cache = {} + def __init__( self, - qweight, - scale, - scale_upper_bound, - bias, - dtype, + qweight: torch.Tensor, + scale: torch.Tensor, + dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, + input_scale: Optional[torch.Tensor] = None, + scale_upper_bound: Optional[float] = None, + weight_block_size: Optional[List[int]] = None, ) -> None: super().__init__() - if FBGEMM_MM_AVAILABLE: - log_once(logger.info, "Using FBGEMM fp8 optimized kernels") self.dtype = dtype self.qweight = qweight - self.scale = scale - self.scale_upper_bound = ( - torch.tensor( - [scale_upper_bound], dtype=torch.float32, device=qweight.device - ) - if scale_upper_bound is not None - else None - ) + self.scale = scale.float() + self.input_scale = input_scale.float() if input_scale is not None else None + self.weight_block_size = weight_block_size + self.scale_upper_bound = scale_upper_bound self.bias = bias if bias is not None else None @classmethod def from_unquant(cls, weight, bias, dtype): - qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE) + qweight, scale = fp8_quantize(weight, scalar=True) return cls( - qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype + qweight=qweight, + scale=scale, + dtype=dtype, + bias=bias, + input_scale=None, + scale_upper_bound=None, ) @classmethod - def from_fp8(cls, weight, scale, input_scale, bias, dtype): - if FBGEMM_DYN_AVAILABLE: - # fbgemm needs float32 scales. - scale = scale.float() + def from_fp8( + cls, + weight: torch.Tensor, + scale: torch.Tensor, + dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, + **kwargs, + ) -> "Fp8Linear": + input_scale = kwargs.get("input_scale", None) + scale_upper_bound = kwargs.get("scale_upper_bound", None) + weight_block_size = kwargs.get("weight_block_size", None) + return cls( qweight=weight, scale=scale, - scale_upper_bound=input_scale, + input_scale=input_scale, + scale_upper_bound=scale_upper_bound, bias=bias, dtype=dtype, + weight_block_size=weight_block_size, ) - def forward(self, input: torch.Tensor) -> torch.Tensor: - if FBGEMM_MM_AVAILABLE: - qinput, scale = fp8_quantize( - input, scale_upper_bound=self.scale_upper_bound - ) + @classmethod + def get_shared_device_identity(cls, device): + # Input scaling factors are no longer optional in _scaled_mm starting + # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale + if device not in cls._device_identity_cache: + cls._device_identity_cache[device] = torch.ones(1, device=device) + return cls._device_identity_cache[device] - y = torch.ops.fbgemm.f8f8bf16_rowwise( + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.weight_block_size is not None: + # https://arxiv.org/pdf/2412.19437 + # At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and + # scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we + # group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output + # channels). + qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1]) + output = w8a8_block_fp8_matmul( qinput, self.qweight, scale, self.scale, - use_fast_accum=True, - bias=self.bias, + self.weight_block_size, + output_dtype=input.dtype, ) - return y.to(self.dtype) - qinput, scale = fp8_quantize(input, scalar=True) - output, _ = torch._scaled_mm( + if self.bias is not None: + output = output + self.bias + return output.to(dtype=input.dtype) + + qinput, scale = fp8_quantize( + input, + self.input_scale, + scale_upper_bound=self.scale_upper_bound, + scalar=True, + ) + + output = torch._scaled_mm( qinput, self.qweight.t(), out_dtype=self.dtype, @@ -275,11 +443,16 @@ class Fp8Linear(torch.nn.Module): scale_b=self.scale, bias=self.bias, ) + + if isinstance(output, tuple) and len(output) == 2: + output = output[0] + return output def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size): scale = weights.get_tensor(prefix, to_dtype=False) + if scale.numel() > 1: scale = weights.get_sharded(prefix, dim=0, to_dtype=False) - return scale.reshape(-1).expand(shape[0]) + return scale.reshape(-1) diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py index 505caa59a..90b8f6923 100644 --- a/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py @@ -1,14 +1,15 @@ -import os from dataclasses import dataclass from typing import List, Optional, Union import torch from loguru import logger -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.log import log_once from text_generation_server.utils.weights import Weight, Weights, WeightsLoader +from .hpu import QuantLinear + + @dataclass class GPTQWeight(Weight): qweight: torch.Tensor @@ -30,13 +31,8 @@ class GPTQWeight(Weight): def get_linear(self, bias: torch.Tensor): if self.use_awq_kernel: - if SYSTEM == "rocm": - raise NotImplementedError( - "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead " - "to use Exllama/GPTQ kernels for AWQ inference." - ) try: - from text_generation_server.layers.awq.quantize.qmodule import WQLinear + from text_generation_server.layers.awq.quantize import WQLinear return WQLinear( w_bit=self.bits, @@ -50,18 +46,7 @@ class GPTQWeight(Weight): raise NotImplementedError( "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly" ) - elif self.use_exllama: - try: - from text_generation_server.layers.gptq import ExllamaQuantLinear - except ImportError: - raise NotImplementedError( - "Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`" - ) - - return ExllamaQuantLinear(self, bias) else: - from text_generation_server.layers.gptq.quant_linear import QuantLinear - return QuantLinear( self.qweight, self.qzeros, @@ -118,23 +103,6 @@ class GPTQWeightsLoader(WeightsLoader): else: g_idx = None - from text_generation_server.layers.gptq import ( - HAS_EXLLAMA, - CAN_EXLLAMA, - GPTQWeight, - ) - - if use_exllama: - if not HAS_EXLLAMA: - if CAN_EXLLAMA: - log_once( - logger.warning, - "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True", - ) - use_exllama = False - else: - log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") - qzeros = weights.get_tensor(f"{prefix}.qzeros") scales = weights.get_tensor(f"{prefix}.scales") @@ -247,14 +215,7 @@ class GPTQWeightsLoader(WeightsLoader): [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 ) - from text_generation_server.layers.gptq import HAS_EXLLAMA - - use_exllama = ( - self.bits == 4 - and HAS_EXLLAMA - and self.quantize == "gptq" - and not self.desc_act - ) + use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act if self.quantize == "gptq" and self.quant_method == "gptq": w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] @@ -298,6 +259,7 @@ class GPTQWeightsLoader(WeightsLoader): self._get_gptq_params(weights) use_exllama = True + desc_act = self.desc_act if self.bits != 4: use_exllama = False @@ -321,7 +283,8 @@ class GPTQWeightsLoader(WeightsLoader): if g_idx is not None: if ( not torch.equal( - g_idx.cpu(), + # Remove g_idx[0] to adapt the check with TP>1. + (g_idx - g_idx[0]).cpu(), torch.tensor( [i // self.groupsize for i in range(g_idx.shape[0])], dtype=torch.int32, @@ -332,34 +295,22 @@ class GPTQWeightsLoader(WeightsLoader): # Exllama implementation does not support row tensor parallelism with act-order, as # it would require to reorder input activations that are split unto several GPUs use_exllama = False + desc_act = True from text_generation_server.layers.gptq import ( - CAN_EXLLAMA, - HAS_EXLLAMA, GPTQWeight, ) - if use_exllama: - if not HAS_EXLLAMA: - if CAN_EXLLAMA: - log_once( - logger.warning, - "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True", - ) - use_exllama = False - else: - log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") - - if use_exllama and self.groupsize != -1: + if not desc_act and self.groupsize != -1: qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0) scales = weights.get_sharded(f"{prefix}.scales", dim=0) + if g_idx is not None: + # qzeros, scales sharded, and g_idx must be adjusted accordingly + g_idx = g_idx - g_idx[0] else: qzeros = weights.get_tensor(f"{prefix}.qzeros") scales = weights.get_tensor(f"{prefix}.scales") - if use_exllama and g_idx is not None: - g_idx = g_idx - g_idx[0] - if self.quantize == "gptq" and self.quant_method == "awq": log_once( logger.info, "Converting AWQ model to Exllama/GPTQ packing format." @@ -392,7 +343,7 @@ class GPTQWeightsLoader(WeightsLoader): ) def _get_gptq_params(self, weights: Weights): - if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"): + if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"): self.bits = weights.get_tensor("gptq_bits").item() self.groupsize = weights.get_tensor("gptq_groupsize").item() self.desc_act = False @@ -400,41 +351,7 @@ class GPTQWeightsLoader(WeightsLoader): # before the `gptq_sym` setting tensor was added. self.sym = ( weights.get_tensor("gptq_sym").item() - if weights._has_tensor("gptq_sym") + if weights.has_tensor("gptq_sym") else False ) self.quant_method = "gptq" - - -# Needs to be at the end because circular import. -try: - major, _minor = torch.cuda.get_device_capability() -except Exception: - major = 1 - -HAS_EXLLAMA = False -CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm" -V2 = os.getenv("EXLLAMA_VERSION", "2") == "2" -if os.getenv("DISABLE_EXLLAMA") == "True": - HAS_EXLLAMA = False -elif CAN_EXLLAMA: - try: - if V2: - from text_generation_server.layers.gptq.exllamav2 import ( - QuantLinear as ExllamaQuantLinear, # noqa: F401 - create_exllama_buffers, # noqa: F401 - set_device, # noqa: F401 - ) - - HAS_EXLLAMA = "2" - else: - from text_generation_server.layers.gptq.exllama import ( - Ex4bitLinear as ExllamaQuantLinear, # noqa: F401 - create_exllama_buffers, # noqa: F401 - set_device, # noqa: F401 - ) - - HAS_EXLLAMA = "1" - - except ImportError: - pass diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/custom_autotune.py b/backends/gaudi/server/text_generation_server/layers/gptq/custom_autotune.py deleted file mode 100644 index 0388ef20b..000000000 --- a/backends/gaudi/server/text_generation_server/layers/gptq/custom_autotune.py +++ /dev/null @@ -1,261 +0,0 @@ -# https://github.com/fpgaminer/GPTQ-triton -""" -Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100. -""" - -import builtins -import math -import time -from typing import Dict - -import triton - - -class Autotuner(triton.KernelInterface): - def __init__( - self, - fn, - arg_names, - configs, - key, - reset_to_zero, - prune_configs_by: Dict = None, - nearest_power_of_two: bool = False, - ): - """ - :param prune_configs_by: a dict of functions that are used to prune configs, fields: - 'perf_model': performance model used to predicate running time with different configs, returns running time - 'top_k': number of configs to bench - 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. - 'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results - """ - if not configs: - self.configs = [triton.Config({}, num_warps=4, num_stages=2)] - else: - self.configs = configs - self.key_idx = [arg_names.index(k) for k in key] - self.nearest_power_of_two = nearest_power_of_two - self.cache = {} - # hook to reset all required tensor to zeros before relaunching a kernel - self.hook = lambda args: 0 - if reset_to_zero is not None: - self.reset_idx = [arg_names.index(k) for k in reset_to_zero] - - def _hook(args): - for i in self.reset_idx: - args[i].zero_() - - self.hook = _hook - self.arg_names = arg_names - # prune configs - if prune_configs_by: - perf_model, top_k = ( - prune_configs_by["perf_model"], - prune_configs_by["top_k"], - ) - if "early_config_prune" in prune_configs_by: - early_config_prune = prune_configs_by["early_config_prune"] - else: - perf_model, top_k, early_config_prune = None, None, None - self.perf_model, self.configs_top_k = perf_model, top_k - self.early_config_prune = early_config_prune - self.fn = fn - - def _bench(self, *args, config, **meta): - # check for conflicts, i.e. meta-parameters both provided - # as kwargs and by the autotuner - conflicts = meta.keys() & config.kwargs.keys() - if conflicts: - raise ValueError( - f"Conflicting meta-parameters: {', '.join(conflicts)}." - " Make sure that you don't re-define auto-tuned symbols." - ) - # augment meta-parameters with tunable ones - current = dict(meta, **config.kwargs) - - def kernel_call(): - if config.pre_hook: - config.pre_hook(self.nargs) - self.hook(args) - self.fn.run( - *args, - num_warps=config.num_warps, - num_stages=config.num_stages, - **current, - ) - - try: - # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses - # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default - return triton.testing.do_bench( - kernel_call, quantiles=(0.5, 0.2, 0.8), rep=40 - ) - except triton.OutOfResources: - return [float("inf"), float("inf"), float("inf")] - - def run(self, *args, **kwargs): - self.nargs = dict(zip(self.arg_names, args)) - if len(self.configs) > 1: - key = tuple(args[i] for i in self.key_idx) - - # This reduces the amount of autotuning by rounding the keys to the nearest power of two - # In my testing this gives decent results, and greatly reduces the amount of tuning required - if self.nearest_power_of_two: - key = tuple([2 ** int(math.log2(x) + 0.5) for x in key]) - - if key not in self.cache: - # prune configs - pruned_configs = self.prune_configs(kwargs) - bench_start = time.time() - timings = { - config: self._bench(*args, config=config, **kwargs) - for config in pruned_configs - } - bench_end = time.time() - self.bench_time = bench_end - bench_start - self.cache[key] = builtins.min(timings, key=timings.get) - self.hook(args) - self.configs_timings = timings - config = self.cache[key] - else: - config = self.configs[0] - self.best_config = config - if config.pre_hook is not None: - config.pre_hook(self.nargs) - return self.fn.run( - *args, - num_warps=config.num_warps, - num_stages=config.num_stages, - **kwargs, - **config.kwargs, - ) - - def prune_configs(self, kwargs): - pruned_configs = self.configs - if self.early_config_prune: - pruned_configs = self.early_config_prune(self.configs, self.nargs) - if self.perf_model: - top_k = self.configs_top_k - if isinstance(top_k, float) and top_k <= 1.0: - top_k = int(len(self.configs) * top_k) - if len(pruned_configs) > top_k: - est_timing = { - config: self.perf_model( - **self.nargs, - **kwargs, - **config.kwargs, - num_stages=config.num_stages, - num_warps=config.num_warps, - ) - for config in pruned_configs - } - pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[ - :top_k - ] - return pruned_configs - - def warmup(self, *args, **kwargs): - self.nargs = dict(zip(self.arg_names, args)) - for config in self.prune_configs(kwargs): - self.fn.warmup( - *args, - num_warps=config.num_warps, - num_stages=config.num_stages, - **kwargs, - **config.kwargs, - ) - self.nargs = None - - -def autotune( - configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False -): - """ - Decorator for auto-tuning a :code:`triton.jit`'d function. - .. highlight:: python - .. code-block:: python - @triton.autotune(configs=[ - triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), - triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), - ], - key=['x_size'] # the two above configs will be evaluated anytime - # the value of x_size changes - ) - @triton.jit - def kernel(x_ptr, x_size, **META): - BLOCK_SIZE = META['BLOCK_SIZE'] - :note: When all the configurations are evaluated, the kernel will run multiple time. - This means that whatever value the kernel updates will be updated multiple times. - To avoid this undesired behavior, you can use the `reset_to_zero` argument, which - reset the value of the provided tensor to `zero` before running any configuration. - :param configs: a list of :code:`triton.Config` objects - :type configs: list[triton.Config] - :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. - :type key: list[str] - :param prune_configs_by: a dict of functions that are used to prune configs, fields: - 'perf_model': performance model used to predicate running time with different configs, returns running time - 'top_k': number of configs to bench - 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs. - :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. - :type reset_to_zero: list[str] - """ - - def decorator(fn): - return Autotuner( - fn, - fn.arg_names, - configs, - key, - reset_to_zero, - prune_configs_by, - nearest_power_of_two, - ) - - return decorator - - -def matmul248_kernel_config_pruner(configs, nargs): - """ - The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller. - """ - m = max(2 ** int(math.ceil(math.log2(nargs["M"]))), 16) - n = max(2 ** int(math.ceil(math.log2(nargs["N"]))), 16) - k = max(2 ** int(math.ceil(math.log2(nargs["K"]))), 16) - - used = set() - for config in configs: - block_size_m = min(m, config.kwargs["BLOCK_SIZE_M"]) - block_size_n = min(n, config.kwargs["BLOCK_SIZE_N"]) - block_size_k = min(k, config.kwargs["BLOCK_SIZE_K"]) - group_size_m = config.kwargs["GROUP_SIZE_M"] - - if ( - block_size_m, - block_size_n, - block_size_k, - group_size_m, - config.num_stages, - config.num_warps, - ) in used: - continue - - used.add( - ( - block_size_m, - block_size_n, - block_size_k, - group_size_m, - config.num_stages, - config.num_warps, - ) - ) - yield triton.Config( - { - "BLOCK_SIZE_M": block_size_m, - "BLOCK_SIZE_N": block_size_n, - "BLOCK_SIZE_K": block_size_k, - "GROUP_SIZE_M": group_size_m, - }, - num_stages=config.num_stages, - num_warps=config.num_warps, - ) diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/exllama.py b/backends/gaudi/server/text_generation_server/layers/gptq/exllama.py deleted file mode 100644 index f27666b77..000000000 --- a/backends/gaudi/server/text_generation_server/layers/gptq/exllama.py +++ /dev/null @@ -1,134 +0,0 @@ -from text_generation_server.layers.gptq import GPTQWeight -import torch -from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params - -# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension -none_tensor = torch.empty((1, 1), device="meta") - - -def ext_make_q4(qweight, qzeros, scales, g_idx, device): - """Construct Q4Matrix, return handle""" - return make_q4( - qweight, qzeros, scales, g_idx if g_idx is not None else none_tensor, device - ) - - -def ext_q4_matmul(x, q4, q4_width): - """Matrix multiplication, returns x @ q4""" - outshape = x.shape[:-1] + (q4_width,) - x = x.view(-1, x.shape[-1]) - output = torch.empty((x.shape[0], q4_width), dtype=torch.float16, device=x.device) - - q4_matmul(x, q4, output) - - return output.view(outshape) - - -MAX_DQ = 1 -MAX_INNER = 1 -ACT_ORDER = False -DEVICE = None - -TEMP_STATE = None -TEMP_DQ = None - - -def set_device(device): - global DEVICE - DEVICE = device - - -def create_exllama_buffers(max_total_tokens: int): - global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ - - assert DEVICE is not None, "call set_device first" - - if not ACT_ORDER: - max_total_tokens = 1 - - # This temp_state buffer is required to reorder X in the act-order case. - temp_state = torch.zeros( - (max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE - ) - temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE) - - # This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. - prepare_buffers(DEVICE, temp_state, temp_dq) - - matmul_recons_thd = 8 - matmul_fused_remap = False - matmul_no_half2 = False - set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) - - TEMP_STATE, TEMP_DQ = temp_state, temp_dq - - -class Ex4bitLinear(torch.nn.Module): - """Linear layer implementation with per-group 4-bit quantization of the weights""" - - def __init__(self, weight: GPTQWeight, bias): - super().__init__() - global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE - assert weight.bits == 4 - - self.device = weight.qweight.device - self.qweight = weight.qweight - self.qzeros = weight.qzeros - self.scales = weight.scales - self.g_idx = weight.g_idx.cpu() if weight.g_idx is not None else None - self.bias = bias if bias is not None else None - - if self.g_idx is not None and ( - (self.g_idx == 0).all() - or torch.equal( - weight.g_idx.cpu(), - torch.tensor( - [i // weight.groupsize for i in range(weight.g_idx.shape[0])], - dtype=torch.int32, - ), - ) - ): - self.empty_g_idx = True - self.g_idx = None - - assert self.device.type == "cuda" - assert self.device.index is not None - - self.q4 = ext_make_q4( - self.qweight, self.qzeros, self.scales, self.g_idx, self.device.index - ) - - self.height = weight.qweight.shape[0] * 8 - self.width = weight.qweight.shape[1] - - # Infer groupsize from height of qzeros - self.groupsize = None - if self.qzeros.shape[0] > 1: - self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0]) - - if self.groupsize is not None: - assert weight.groupsize == self.groupsize - - # Handle act-order matrix - if self.g_idx is not None: - if self.groupsize is None: - raise ValueError("Found group index but no groupsize. What do?") - self.act_order = True - else: - self.act_order = False - - DEVICE = self.qweight.device - - MAX_DQ = max(MAX_DQ, self.qweight.numel() * 8) - - if self.act_order: - MAX_INNER = max(MAX_INNER, self.height, self.width) - - ACT_ORDER = True - - def forward(self, x): - out = ext_q4_matmul(x, self.q4, self.width) - - if self.bias is not None: - out.add_(self.bias) - return out diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/exllamav2.py b/backends/gaudi/server/text_generation_server/layers/gptq/exllamav2.py deleted file mode 100644 index 920a6adf4..000000000 --- a/backends/gaudi/server/text_generation_server/layers/gptq/exllamav2.py +++ /dev/null @@ -1,267 +0,0 @@ -# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2 - -from dataclasses import dataclass -from typing import Optional -import torch -import torch.nn as nn - -from loguru import logger - -from text_generation_server.layers.exl2 import Exl2Weight -from text_generation_server.layers.gptq import GPTQWeight -from text_generation_server.utils.log import log_master - -try: - from exllamav2.ext import exllamav2_ext - - make_q_matrix = exllamav2_ext.make_q_matrix - gemm_half_q_half = exllamav2_ext.gemm_half_q_half -except ImportError: - log_master(logger.warning, "exllamav2_kernels not installed.") - raise - -# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension -none_tensor = torch.empty((1, 1), device="meta") - - -@dataclass -class _ExtraTensors: - """Additional generated quantizer tensors.""" - - q_group_map: Optional[torch.Tensor] = None - q_invperm: Optional[torch.Tensor] = None - q_perm: Optional[torch.Tensor] = None - - -def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda): - """Matrix multiplication, returns x @ q4""" - output_shape = x.shape[:-1] + (q4_width,) - x = x.view(-1, x.shape[-1]) - output = torch.empty((x.shape[0], q4_width), dtype=torch.half, device=x.device) - gemm_half_q_half(x, q_handle, output, force_cuda) - return output.view(output_shape) - - -def make_group_map(q_groups: torch.Tensor, num_qrows: int): - gr = q_groups.tolist() - group_map = [] - num_groups = len(gr) // 2 - - for i in range(num_groups): - bits = gr[i * 2] - if i < num_groups - 1: - qrows = gr[i * 2 + 3] - gr[i * 2 + 1] - else: - qrows = num_qrows - gr[i * 2 + 1] - rows = qrows * 32 // bits - for j in range(rows): - group_map += [i] - group_map += [rows - j] - - return torch.tensor(group_map, dtype=torch.short, device=q_groups.device) - - -# Create Q matrix - - -def ext_make_q_matrix( - w: Exl2Weight | GPTQWeight, - extra: _ExtraTensors, - temp_dq, - key: Optional[str] = None, -): - """ - Create Q matrix - """ - # max_dq_size = 512*(1024**2) - # max_dq_rows = max_dq_size // out_features[0] - max_dq_rows = 0 - - # EXL2 - if isinstance(w, Exl2Weight): - extra.q_group_map = make_group_map(w.q_groups, w.q_weight.shape[0]) - extra.q_perm = torch.argsort(w.q_invperm).short() - - return make_q_matrix( - w.q_weight, - extra.q_perm, - w.q_invperm, - w.q_scale, - w.q_scale_max, - w.q_groups, - extra.q_group_map, - none_tensor, # zeros - none_tensor, # scales - none_tensor, # g_idx - none_tensor, # bias - temp_dq, - max_dq_rows, - ) - # GPTQ - elif isinstance(w, GPTQWeight): - if w.scales.dtype == torch.float: - w.scales = w.scales.half() - - # GPTQ with g_idx (act_order) - if w.g_idx is not None and not (w.g_idx == 0).all().item(): - extra.q_perm = torch.empty( - (w.qweight.shape[0] * 8,), - dtype=torch.short, - device=w.qweight.device, - ) - extra.q_invperm = torch.empty_like(extra.q_perm) - # make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx. - return make_q_matrix( - w.qweight, - extra.q_perm, - extra.q_invperm, - none_tensor, # q_scale - none_tensor, # q_scale_max - none_tensor, # q_groups - none_tensor, # q_group_map - w.qzeros, - w.scales, - w.g_idx.cpu(), - none_tensor, # bias - temp_dq, - max_dq_rows, - ) - # GPTQ without g_idx - else: - return make_q_matrix( - w.qweight, - none_tensor, # q_perm - none_tensor, # q_invperm - none_tensor, # q_scale - none_tensor, # q_scale_max - none_tensor, # q_groups - none_tensor, # q_group_map - w.qzeros, - w.scales, - none_tensor, # g_idx - none_tensor, # bias - temp_dq, - max_dq_rows, - ) - else: - RuntimeError("Cannot create handle") - - -DEVICE = None -LAYERS = [] - - -def set_device(device): - global DEVICE - DEVICE = device - - -def create_exllama_buffers(max_total_tokens: int): - global LAYERS, DEVICE - - # No need to initialize scratch space if there are no layers - # that use ExLLamav2. - if len(LAYERS) == 0: - return - - # Find the size of the scratch space. - scratch_bytes = max( - layer.scratch_space_fixed(max_input_len=max_total_tokens, max_batch_size=1) - for layer in LAYERS - ) - temp_dq = ExLlamaV2DeviceTensors(DEVICE, scratch_bytes) - - for layer in LAYERS: - layer.post_init(temp_dq) - - -class QuantLinear(nn.Module): - QUANT_TYPE = "exllamav2" - - """Linear layer implementation with per-group 4-bit quantization of the weights""" - - def __init__( - self, - weight: Exl2Weight | GPTQWeight, - bias: torch.Tensor, - ): - super().__init__() - - self.q_handle = None - self.q_tensors = weight - self.extra_tensors = _ExtraTensors() - - if isinstance(weight, Exl2Weight): - self.infeatures = weight.q_invperm.shape[0] - self.outfeatures = weight.q_weight.shape[1] - elif isinstance(weight, GPTQWeight): - if weight.bits != 4: - raise ValueError( - f"Exllamav2 kernel supports only bits=4, requested bits={weight.bits}. Something is wrong in the model initialization." - ) - - self.infeatures = weight.qweight.shape[0] // weight.bits * 32 - self.outfeatures = weight.qweight.shape[1] - - self.padding = -self.outfeatures % 32 - self.outfeatures = self.outfeatures + self.padding - - self.device = weight.device - self.bias = bias if bias is not None else None - - global LAYERS - LAYERS.append(self) - - def post_init(self, temp_dq): - device = self.q_tensors.device - assert device.type == "cuda" - assert device.index is not None - temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size()) - - # We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us, - # and `Memory access fault by GPU node-2` will EAT you. - self.temp_dq = temp_dq - self.q_handle = ext_make_q_matrix(self.q_tensors, self.extra_tensors, temp_dq) - - def forward(self, x, force_cuda=False): - output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda) - - if self.bias is not None: - output.add_(self.bias) - return output - - def temp_dq_size(self): - return self.infeatures * self.outfeatures * 2 + 128 - - def temp_fwd_size(self, max_input_len, max_batch_size): - return self.outfeatures * max_input_len * max_batch_size * 4 + 128 - - def scratch_space_fixed(self, max_input_len, max_batch_size): - return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size) - - -class ExLlamaV2DeviceTensors: - - device_idx: int - scratch_bytes: int - scratch_idx: int - scratch: torch.tensor = None - - def __init__(self, device, scratch_bytes): - self.device = device - self.scratch_bytes = scratch_bytes - - def prepare(self): - self.scratch = torch.empty( - (self.scratch_bytes // 2,), dtype=torch.half, device=self.device - ) - - def get_scratch_slice(self, size_bytes): - - if self.scratch is None: - self.prepare() - - size_bytes = ((size_bytes + 127) // 128) * 128 - size_half = size_bytes // 2 - scratch_slice = self.scratch.narrow(0, 0, size_half) - return scratch_slice diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py b/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py new file mode 100644 index 000000000..72944fa0e --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py @@ -0,0 +1,186 @@ +import math +import numpy as np +import torch +import torch.nn as nn + +try: + + convert_from_uint4 = torch.ops.hpu.convert_from_uint4 +except Exception as e: + hpu_import_exception = e + + def error_raiser_hpu(*args, **kwargs): + raise ValueError( + f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}" + ) + + convert_from_uint4 = error_raiser_hpu + + +def pack_tensor(input, bits=4): + normal = input.to(torch.int32) + q = torch.zeros((normal.shape[0], normal.shape[1] // 32 * bits), dtype=torch.int32) + i = 0 + col = 0 + while col < q.shape[1]: + for j in range(i, i + (32 // bits)): + q[:, col] |= normal[:, j] << (bits * (j - i)) + i += 32 // bits + col += 1 + q = q.to(torch.int32) + return q + + +class QuantLinear(nn.Module): + def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): + super().__init__() + self.register_buffer("qweight", qweight) + self.register_buffer("qzeros", qzeros) + self.register_buffer("scales", scales) + self.register_buffer("g_idx", g_idx) + if bias is not None: + self.register_buffer("bias", bias) + else: + self.bias = None + if bits not in [4]: + raise NotImplementedError("Only 4 bits are supported.") + self.bits = bits + self.maxq = 2**self.bits - 1 + self.groupsize = groupsize + + self.outfeatures = qweight.shape[1] + self.infeatures = qweight.shape[0] * 32 // bits + self.wf = torch.tensor( + list(range(0, 32, self.bits)), dtype=torch.int32 + ).unsqueeze(0) + self._preprocessing() + + def unpack_zeros_from_cuda_old_format(self): + zeros = torch.bitwise_right_shift( + torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), + self.wf.unsqueeze(0), + ).to(torch.int16 if self.bits == 8 else torch.int8) + + zeros = zeros + 1 + zeros = torch.bitwise_and(zeros, (2**self.bits) - 1).to( + self.scales.dtype + ) # NOTE: It appears that casting here after the `zeros = zeros + 1` is important. + zeros = zeros.reshape(-1, zeros.shape[1] * zeros.shape[2]) + return zeros + + def unpack_weight_from_cuda_old_format(self): + weight = torch.bitwise_right_shift( + torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1), + self.wf.unsqueeze(-1), + ).to(torch.int16 if self.bits == 8 else torch.int8) + weight = torch.bitwise_and(weight, (2**self.bits) - 1) + weight = weight.reshape((weight.shape[0] * weight.shape[1], weight.shape[2])) + return weight + + def _preprocessing(self): + orig_device = self.qweight.device + self.qweight = self.qweight.cpu() + weight = self.unpack_weight_from_cuda_old_format() + new_qweight = pack_tensor(weight) + self.qweight = new_qweight.to(orig_device) + # TODO: Support group indexing and remove the check + columns = self.qweight.shape[0] + g_idx_trivial = [i // self.groupsize for i in range(columns)] + g_idx_trivial = torch.tensor( + g_idx_trivial, dtype=torch.int32, device=self.g_idx.device + ) + assert torch.equal( + self.g_idx, g_idx_trivial + ), "Non-trivial tensor g_idx is not supported" + self.qzeros = self.qzeros.cpu() + zeros = self.unpack_zeros_from_cuda_old_format() + new_qzeros = pack_tensor(zeros) + self.qzeros = new_qzeros.to(orig_device) + + @classmethod + def new(cls, bits, groupsize, infeatures, outfeatures, bias): + if bits not in [4]: + raise NotImplementedError("Only 4 bits are supported.") + + qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32) + qzeros = torch.zeros( + (math.ceil(infeatures / groupsize), outfeatures // 32 * bits), + dtype=torch.int32, + ) + scales = torch.zeros( + (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16 + ) + g_idx = torch.tensor( + [i // groupsize for i in range(infeatures)], dtype=torch.int32 + ) + if bias: + bias = torch.zeros((outfeatures), dtype=torch.float16) + else: + bias = None + return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize) + + def pack(self, linear, scales, zeros, g_idx=None): + self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + self.scales = scales.clone().half() + if linear.bias is not None: + self.bias = linear.bias.clone().half() + + intweight = [] + for idx in range(self.infeatures): + intweight.append( + torch.round( + (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) + / self.scales[self.g_idx[idx]] + ).to(torch.int)[:, None] + ) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(np.uint32) + qweight = np.zeros( + (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32 + ) + i = 0 + row = 0 + while row < qweight.shape[0]: + if self.bits in [4]: + for j in range(i, i + (32 // self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += 32 // self.bits + row += 1 + else: + raise NotImplementedError("Only 4 bits are supported.") + + qweight = qweight.astype(np.int32) + self.qweight = torch.from_numpy(qweight) + + zeros -= 1 + zeros = zeros.numpy().astype(np.uint32) + qzeros = np.zeros( + (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32 + ) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [4]: + for j in range(i, i + (32 // self.bits)): + qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) + i += 32 // self.bits + col += 1 + else: + raise NotImplementedError("Only 4 bits are supported.") + + qzeros = qzeros.astype(np.int32) + self.qzeros = torch.from_numpy(qzeros) + + def forward(self, x): + out_shape = x.shape[:-1] + (self.outfeatures,) + x = x.reshape(-1, x.shape[-1]) + weight = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype) + out = torch.matmul(x, weight) + out = out.reshape(out_shape) + out = out + self.bias if self.bias is not None else out + return out diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/quant_linear.py b/backends/gaudi/server/text_generation_server/layers/gptq/quant_linear.py deleted file mode 100644 index 736c357b0..000000000 --- a/backends/gaudi/server/text_generation_server/layers/gptq/quant_linear.py +++ /dev/null @@ -1,359 +0,0 @@ -import math -import numpy as np -import torch -import torch.nn as nn -from torch.cuda.amp import custom_fwd - -import triton -import triton.language as tl -from . import custom_autotune - - -# code based https://github.com/fpgaminer/GPTQ-triton -@custom_autotune.autotune( - configs=[ - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=2, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - }, - num_stages=3, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 8, - }, - num_stages=2, - num_warps=4, - ), - ], - key=["M", "N", "K"], - nearest_power_of_two=True, - prune_configs_by={ - "early_config_prune": custom_autotune.matmul248_kernel_config_pruner, - "perf_model": None, - "top_k": None, - }, -) -@triton.jit -def matmul_248_kernel( - a_ptr, - b_ptr, - c_ptr, - scales_ptr, - zeros_ptr, - g_ptr, - M, - N, - K, - bits, - maxq, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_scales, - stride_zeros, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - """ - Compute the matrix multiplication C = A x B. - A is of shape (M, K) float16 - B is of shape (K//8, N) int32 - C is of shape (M, N) float16 - scales is of shape (G, N) float16 - zeros is of shape (G, N) float16 - g_ptr is of shape (K) int32 - """ - infearure_per_bits = 32 // bits - - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + ( - offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak - ) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - a_mask = offs_am[:, None] < M - # b_ptrs is set up such that it repeats elements along the K axis 8 times - b_ptrs = b_ptr + ( - (offs_k[:, None] // infearure_per_bits) * stride_bk - + offs_bn[None, :] * stride_bn - ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) - g_ptrs = g_ptr + offs_k - # shifter is used to extract the N bits of each element in the 32-bit word from B - scales_ptrs = scales_ptr + offs_bn[None, :] - zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) - - shifter = (offs_k % infearure_per_bits) * bits - zeros_shifter = (offs_bn % infearure_per_bits) * bits - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - for k in range(0, num_pid_k): - g_idx = tl.load(g_ptrs) - - # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop - scales = tl.load( - scales_ptrs + g_idx[:, None] * stride_scales - ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = tl.load( - zeros_ptrs + g_idx[:, None] * stride_zeros - ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - - zeros = (zeros >> zeros_shifter[None, :]) & maxq - zeros = (zeros + 1) & maxq # eventually avoid overflow - - a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated - - # Now we need to unpack b (which is N-bit values) into 32-bit values - b = (b >> shifter[:, None]) & maxq # Extract the N-bit values - b = (b - zeros) * scales # Scale and shift - - accumulator += tl.dot(a, b) - a_ptrs += BLOCK_SIZE_K - b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk - g_ptrs += BLOCK_SIZE_K - - c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] - c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) - - -def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): - with torch.cuda.device(input.device): - output = torch.empty( - (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16 - ) - - def grid(META): - return ( - triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) - * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]), - ) - - matmul_248_kernel[grid]( - input, - qweight, - output, - scales, - qzeros, - g_idx, - input.shape[0], - qweight.shape[1], - input.shape[1], - bits, - maxq, - input.stride(0), - input.stride(1), - qweight.stride(0), - qweight.stride(1), - output.stride(0), - output.stride(1), - scales.stride(0), - qzeros.stride(0), - ) - return output - - -class QuantLinearFunction(torch.autograd.Function): - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq): - output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq) - return output - - -class QuantLinear(nn.Module): - def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): - super().__init__() - self.register_buffer("qweight", qweight) - self.register_buffer("qzeros", qzeros) - self.register_buffer("scales", scales) - self.register_buffer("g_idx", g_idx) - if bias is not None: - self.register_buffer("bias", bias) - else: - self.bias = None - if bits not in [2, 4, 8]: - raise NotImplementedError("Only 2,4,8 bits are supported.") - self.bits = bits - self.maxq = 2**self.bits - 1 - self.groupsize = groupsize - - self.outfeatures = qweight.shape[1] - self.infeatures = qweight.shape[0] * 32 // bits - - @classmethod - def new(cls, bits, groupsize, infeatures, outfeatures, bias): - if bits not in [2, 4, 8]: - raise NotImplementedError("Only 2,4,8 bits are supported.") - - qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32) - qzeros = torch.zeros( - (math.ceil(infeatures / groupsize), outfeatures // 32 * bits), - dtype=torch.int32, - ) - scales = torch.zeros( - (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16 - ) - g_idx = torch.tensor( - [i // groupsize for i in range(infeatures)], dtype=torch.int32 - ) - if bias: - bias = torch.zeros((outfeatures), dtype=torch.float16) - else: - bias = None - return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize) - - def pack(self, linear, scales, zeros, g_idx=None): - self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx - - scales = scales.t().contiguous() - zeros = zeros.t().contiguous() - scale_zeros = zeros * scales - self.scales = scales.clone().half() - if linear.bias is not None: - self.bias = linear.bias.clone().half() - - intweight = [] - for idx in range(self.infeatures): - intweight.append( - torch.round( - (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) - / self.scales[self.g_idx[idx]] - ).to(torch.int)[:, None] - ) - intweight = torch.cat(intweight, dim=1) - intweight = intweight.t().contiguous() - intweight = intweight.numpy().astype(np.uint32) - qweight = np.zeros( - (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32 - ) - i = 0 - row = 0 - while row < qweight.shape[0]: - if self.bits in [2, 4, 8]: - for j in range(i, i + (32 // self.bits)): - qweight[row] |= intweight[j] << (self.bits * (j - i)) - i += 32 // self.bits - row += 1 - else: - raise NotImplementedError("Only 2,4,8 bits are supported.") - - qweight = qweight.astype(np.int32) - self.qweight = torch.from_numpy(qweight) - - zeros -= 1 - zeros = zeros.numpy().astype(np.uint32) - qzeros = np.zeros( - (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32 - ) - i = 0 - col = 0 - while col < qzeros.shape[1]: - if self.bits in [2, 4, 8]: - for j in range(i, i + (32 // self.bits)): - qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) - i += 32 // self.bits - col += 1 - else: - raise NotImplementedError("Only 2,4,8 bits are supported.") - - qzeros = qzeros.astype(np.int32) - self.qzeros = torch.from_numpy(qzeros) - - def forward(self, x): - out_shape = x.shape[:-1] + (self.outfeatures,) - out = QuantLinearFunction.apply( - x.reshape(-1, x.shape[-1]), - self.qweight, - self.scales, - self.qzeros, - self.g_idx, - self.bits, - self.maxq, - ) - out = out + self.bias if self.bias is not None else out - return out.reshape(out_shape) diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/quantize.py b/backends/gaudi/server/text_generation_server/layers/gptq/quantize.py index b0086ea08..aa664ea60 100644 --- a/backends/gaudi/server/text_generation_server/layers/gptq/quantize.py +++ b/backends/gaudi/server/text_generation_server/layers/gptq/quantize.py @@ -12,7 +12,7 @@ from huggingface_hub import HfApi from accelerate import init_empty_weights from text_generation_server.utils import initialize_torch_distributed, Weights from text_generation_server.utils.hub import weight_files -from text_generation_server.layers.gptq.quant_linear import QuantLinear +from text_generation_server.layers.gptq import QuantLinear from loguru import logger from typing import Optional from text_generation_server.layers.gptq.utils import torch_snr_error @@ -956,15 +956,24 @@ def quantize( pack(model, quantizers, bits, groupsize) from safetensors.torch import save_file - from transformers.modeling_utils import shard_checkpoint + from huggingface_hub import split_torch_state_dict_into_shards state_dict = model.state_dict() state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()} max_shard_size = "10GB" - shards, index = shard_checkpoint( - state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors" + state_dict_split = split_torch_state_dict_into_shards( + state_dict, + filename_pattern="model.safetensors", + max_shard_size=max_shard_size, ) + index = None + if state_dict_split.is_sharded: + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + shards = state_dict_split.filename_to_tensors os.makedirs(output_dir, exist_ok=True) for shard_file, shard in shards.items(): save_file( diff --git a/backends/gaudi/server/text_generation_server/layers/layernorm.py b/backends/gaudi/server/text_generation_server/layers/layernorm.py index ce5289f93..848787910 100644 --- a/backends/gaudi/server/text_generation_server/layers/layernorm.py +++ b/backends/gaudi/server/text_generation_server/layers/layernorm.py @@ -1,9 +1,6 @@ import torch from torch import nn from accelerate import init_empty_weights -from text_generation_server.utils.import_utils import ( - SYSTEM, -) # Monkey patching @@ -33,69 +30,14 @@ def load_layer_norm_no_bias(cls, prefix, weights, eps): torch.nn.LayerNorm.load = load_layer_norm torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias -if SYSTEM == "cuda": - import dropout_layer_norm - class FastLayerNorm(nn.LayerNorm): - def forward(self, hidden_states, residual=None): - if hidden_states.shape[-1] > 8192: - if residual is not None: - hidden_states += residual - residual = hidden_states +class FastLayerNorm(nn.LayerNorm): + def forward(self, hidden_states, residual=None): + if residual is not None: + hidden_states += residual + residual = hidden_states - return super(FastLayerNorm, self).forward(hidden_states), residual - else: - ( - normed_hidden_states, - residual, - *rest, - ) = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - residual, - self.weight, - self.bias, - None, - None, - None, - None, - 0.0, - self.eps, - 1.0, - 0, - None, - False, - False, - ) - if residual is None: - residual = hidden_states - - return normed_hidden_states, residual - -elif SYSTEM == "rocm": - from vllm._C import ops - - class FastLayerNorm(nn.LayerNorm): - def forward(self, hidden_states, residual=None): - if residual is not None: - hidden_states += residual - residual = hidden_states - - return super().forward(hidden_states), residual - -elif SYSTEM == "ipex": - import intel_extension_for_pytorch as ipex - - class FastLayerNorm(nn.LayerNorm): - def forward(self, hidden_states, residual=None): - out = ipex.llm.functional.add_layer_norm( - residual, - hidden_states, - self.weight, - self.bias, - self.eps, - residual is not None, - ) - return out, residual if residual is not None else hidden_states + return super().forward(hidden_states), residual class FastRMSNorm(nn.Module): @@ -111,74 +53,15 @@ class FastRMSNorm(nn.Module): return cls(weight, eps) def forward(self, hidden_states, residual=None): - if SYSTEM == "ipex": - out = ipex.llm.functional.add_rms_norm( - residual, - hidden_states, - self.weight, - None, - self.variance_epsilon, - residual is not None, - ) - return out, residual if residual is not None else hidden_states - elif hidden_states.shape[-1] > 8192: - if residual is not None: - hidden_states += residual - residual = hidden_states + from vllm_hpu_extension.kernels import rms_norm - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt( - variance + self.variance_epsilon - ) - - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - - return self.weight * hidden_states, residual - elif SYSTEM == "cuda": - # faster post attention rms norm - ( - normed_hidden_states, - res, - *rest, - ) = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - residual, - self.weight, - None, - None, - None, - None, - None, - 0.0, - self.variance_epsilon, - 1.0, - 0, - None, - False, - True, # Activate RMSNorm - ) - if res is None: - res = hidden_states - - return normed_hidden_states, res - elif SYSTEM == "rocm": - # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not. - if residual is not None: - hidden_states += residual - residual = hidden_states - - out = torch.empty_like(hidden_states) - ops.rms_norm( - out, - hidden_states, - self.weight.data, - self.variance_epsilon, - ) - return out, residual + orig_shape = hidden_states.shape + if residual is not None: + residual += hidden_states.view(residual.shape) else: - raise ValueError( - "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." - ) + residual = hidden_states + # Note: HPUFusedRMSNorm requires 3D tensors as inputs + if len(orig_shape) == 2: + residual = residual.unsqueeze(0) + x = rms_norm().apply(residual, self.weight, self.variance_epsilon) + return x.view(orig_shape), residual.view(orig_shape) diff --git a/backends/gaudi/server/text_generation_server/layers/linear.py b/backends/gaudi/server/text_generation_server/layers/linear.py index 08306d579..cca80c44e 100644 --- a/backends/gaudi/server/text_generation_server/layers/linear.py +++ b/backends/gaudi/server/text_generation_server/layers/linear.py @@ -1,21 +1,5 @@ import torch -from text_generation_server.utils.import_utils import SYSTEM from torch.nn import functional as F -import os - -if SYSTEM == "rocm": - ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in ( - "true", - "1", - ) - - if ROCM_USE_SKINNY_GEMM: - try: - from vllm import _custom_C - except Exception as e: - raise ImportError( - f"Could not load `vllm._custom_C` for ROCm skinny gemm. Full error: {e}" - ) class FastLinear(torch.nn.Module): @@ -44,83 +28,11 @@ class FastLinear(torch.nn.Module): return F.linear(input, self.weight, self.bias) -class FastLinearROCm(torch.nn.Module): - def __init__( - self, - weight, - bias, - ) -> None: - super().__init__() - self.weight = torch.nn.Parameter(weight) - if bias is not None: - self.bias = torch.nn.Parameter(bias) - else: - self.bias = None - - self.cu_count = torch.cuda.get_device_properties( - device="cuda" - ).multi_processor_count - self.use_skinny_gemm = ( - ROCM_USE_SKINNY_GEMM - and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName - ) - - @classmethod - def load(cls, config, prefix: str, weights, bias: bool): - weight = weights.get_tensor(f"{prefix}.weight") - if bias: - bias = weights.get_tensor(f"{prefix}.bias") - else: - bias = None - return cls(weight, bias) - - def forward(self, inp: torch.Tensor) -> torch.Tensor: - weight = self.weight - bias = self.bias - - if ( - self.use_skinny_gemm - and inp.dtype == torch.float16 - and inp.shape[-1] % 8 == 0 - ): - batched = False - inp_shape = inp.shape - - if inp.dim() == 3: - inp = inp.view(-1, inp_shape[-1]) - batched = True - - m, n, k = weight.shape[0], inp_shape[0], inp_shape[1] - if m > 8 and n <= 4: - out = torch.empty( - inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device - ) - _custom_C.wvSpltK(weight, inp, out, n, self.cu_count) - elif m % 4 == 0 and n == 1 and k <= 8192: - out = torch.empty( - inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device - ) - _custom_C.LLMM1(weight, inp, out, 4) - else: - out = F.linear(inp, weight) - - if batched: - out.view(*inp_shape[:-1], out.shape[-1]) - - if bias is not None: - out = out + bias - return out - return F.linear(inp, self.weight, self.bias) - - def get_linear(weight, bias): # Weights that are loaded through methods that are not # quantization-aware are still bare tensors. We may want # to change this in the future. if isinstance(weight, torch.Tensor): - if SYSTEM == "rocm": - return FastLinearROCm(weight, bias) - else: - return FastLinear(weight, bias) + return FastLinear(weight, bias) return weight.get_linear(bias) diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/__init__.py b/backends/gaudi/server/text_generation_server/layers/marlin/__init__.py deleted file mode 100644 index 3ff3ed58f..000000000 --- a/backends/gaudi/server/text_generation_server/layers/marlin/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear -from text_generation_server.layers.marlin.gptq import ( - GPTQMarlinWeightsLoader, - can_use_gptq_marlin, - repack_gptq_for_marlin, -) -from text_generation_server.layers.marlin.marlin import MarlinWeightsLoader - -__all__ = [ - "GPTQMarlinFP8Linear", - "GPTQMarlinWeightsLoader", - "MarlinWeightsLoader", - "can_use_gptq_marlin", - "repack_gptq_for_marlin", -] diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/fp8.py b/backends/gaudi/server/text_generation_server/layers/marlin/fp8.py deleted file mode 100644 index fe55a58a3..000000000 --- a/backends/gaudi/server/text_generation_server/layers/marlin/fp8.py +++ /dev/null @@ -1,140 +0,0 @@ -from typing import Optional - -import torch -import torch.nn as nn -from loguru import logger -from text_generation_server.layers.fp8 import fp8_quantize -from text_generation_server.layers.marlin.gptq import _check_valid_shape -from text_generation_server.layers.marlin.util import ( - _check_marlin_kernels, - permute_scales, -) -from text_generation_server.utils.log import log_once - -try: - import marlin_kernels -except ImportError: - marlin_kernels = None - - -MARLIN_TILE_SIZE = 16 - - -class GPTQMarlinFP8Linear(nn.Module): - """ - FP8 GPTQ-Marlin linear layer. - """ - - def __init__( - self, - qweight: torch.Tensor, - scales: torch.Tensor, - bias: Optional[torch.Tensor], - ) -> None: - super().__init__() - - _check_marlin_kernels() - assert marlin_kernels is not None - - log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel") - - scales = scales.unsqueeze(0) - if scales.shape[1] == 1: - out_features, in_features = qweight.shape - scales = scales.repeat(1, out_features) - qweight, scales = repack_fp8_for_marlin(qweight, scales) - - in_features = qweight.shape[0] * MARLIN_TILE_SIZE - out_features = scales.shape[1] - _check_valid_shape(in_features=in_features, out_features=out_features) - - self.qweight = qweight - self.scales = scales - self.bias = bias if bias is not None else None - - self.workspace = torch.zeros( - out_features // 64 * 16, dtype=torch.int, device=qweight.device - ) - - @classmethod - def from_unquant(cls, weight, bias, dtype): - qweight, scales = fp8_quantize(weight) - return cls(qweight=qweight, scales=scales.to(dtype), bias=bias) - - @classmethod - def from_fp8(cls, weight, scale, _input_scale, bias, dtype): - return cls(qweight=weight, scales=scale.to(dtype), bias=bias) - - def forward(self, A: torch.Tensor) -> torch.Tensor: - assert marlin_kernels is not None - - A_flat = A.view(-1, A.shape[-1]) - C = marlin_kernels.fp8_marlin_gemm( - A_flat, - self.qweight, - self.scales, - self.workspace, - 8, - A_flat.shape[0], - self.scales.shape[1], - A_flat.shape[1], - ) - C = C.reshape(A.shape[:-1] + (self.scales.shape[1],)) - - if self.bias is not None: - C += self.bias - - return C - - -def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: - """ - Repack FP8 weights to gptq format (packed int32 elements). - """ - assert fp8_tensor.dtype == torch.float8_e4m3fn - - if fp8_tensor.shape[0] % 4 != 0: - raise ValueError( - f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}" - ) - - # Reshape to prepare for packing - reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) - - # Convert fp8 to uint8 (byte) representation - byte_tensor = reshaped.view(torch.uint8) - - # Pack 4 uint8 values into one int32 - packed = torch.zeros( - fp8_tensor.shape[0] // 4, - fp8_tensor.shape[1], - dtype=torch.int32, - device=fp8_tensor.device, - ) - - for i in range(4): - packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8) - - return packed - - -def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor): - """ - Repack FP8 tensor for GPTQ-Marlin. - """ - - out_features, in_features = weight.shape - - # Torch linear layers weights with shape [out_features, in_features], - # GPTQ-quantized weights use [in_feateres/pack_factor, in_features], - # so transpose before packing. - qweight = pack_fp8_as_int32(weight.t()) - - perm = torch.empty(0, dtype=torch.int, device=qweight.device) - repacked = marlin_kernels.gptq_marlin_repack( - qweight, perm, in_features, out_features, 8 - ) - - scales = permute_scales(scales) - - return repacked, scales diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/gptq.py b/backends/gaudi/server/text_generation_server/layers/marlin/gptq.py deleted file mode 100644 index 0a785d944..000000000 --- a/backends/gaudi/server/text_generation_server/layers/marlin/gptq.py +++ /dev/null @@ -1,464 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional, Union - -import numpy -import torch -import torch.nn as nn -from loguru import logger -from text_generation_server.layers.marlin.util import ( - _check_marlin_kernels, - marlin_zero_points, - permute_scales, - unpack_cols, -) -from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.utils.log import log_once -from text_generation_server.utils.weights import Weight, Weights, WeightsLoader - -try: - import marlin_kernels -except ImportError: - marlin_kernels = None - -try: - major, _minor = torch.cuda.get_device_capability() - has_sm_8_0 = major >= 8 -except Exception: - has_sm_8_0 = False - - -GPTQ_MARLIN_BITS = [4, 8] -GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128] -MARLIN_TILE_SIZE = 16 - - -def can_use_gptq_marlin( - *, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool -) -> bool: - return ( - SYSTEM == "cuda" - and marlin_kernels is not None - and has_sm_8_0 - and quantize in {"awq", "gptq"} - and quant_method in {"awq", "gptq"} - and bits in GPTQ_MARLIN_BITS - and groupsize in GPTQ_MARLIN_GROUP_SIZES - # We only suppord asymmetric quantization for AWQ. - and (sym or quant_method == "awq") - ) - - -class GPTQMarlinWeightsLoader(WeightsLoader): - """ - Loader for using GPTQ- and AWQ-quantized weights with Marlin kernels. - """ - - def __init__( - self, - *, - bits: int, - desc_act: bool, - groupsize: int, - quant_method: str, - quantize: str, - sym: bool, - ): - self.bits = bits - self.desc_act = desc_act - self.groupsize = groupsize - self.quant_method = quant_method - self.quantize = quantize - self.sym = sym - - def get_weights(self, weights: Weights, prefix: str): - log_once(logger.info, "Using GPTQ-Marlin kernels") - try: - qweight = weights.get_tensor(f"{prefix}.qweight") - except RuntimeError: - raise RuntimeError( - f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" - ) - - if not self.sym: - qzeros = weights.get_tensor(f"{prefix}.qzeros") - else: - qzeros = None - - if self.quant_method == "awq": - g_idx = None - else: - g_idx = weights.get_tensor(f"{prefix}.g_idx") - scales = weights.get_tensor(f"{prefix}.scales") - - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - qzeros=qzeros, - g_idx=g_idx, - bits=self.bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method=self.quant_method, - sym=self.sym, - sharded_infeatures=False, - ) - - def get_weights_col_packed( - self, - weights: Weights, - prefix: str, - block_sizes: Union[int, List[int]], - ): - try: - qweight = weights.get_packed_sharded( - f"{prefix}.qweight", dim=1, block_sizes=block_sizes - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{self.quantize}` weight, make sure the model is already quantized." - ) - scales = weights.get_packed_sharded( - f"{prefix}.scales", dim=1, block_sizes=block_sizes - ) - scales = scales.to(dtype=weights.dtype) - - if not self.sym: - qzeros = weights.get_packed_sharded( - f"{prefix}.qzeros", dim=1, block_sizes=block_sizes - ) - else: - qzeros = None - - if self.quant_method == "awq": - g_idx = None - else: - g_idx = weights.get_tensor(f"{prefix}.g_idx") - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - qzeros=qzeros, - g_idx=g_idx, - bits=self.bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method=self.quant_method, - sym=self.sym, - sharded_infeatures=False, - ) - - def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): - try: - qweight = torch.cat( - [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{self.quantize}` weight, make sure the model is already quantized" - ) - - scales = torch.cat( - [weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 - ) - - if not self.sym: - qzeros = torch.cat( - [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 - ) - else: - qzeros = None - - if self.quant_method == "awq": - g_idx = None - else: - w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] - for w2 in w[1:]: - torch.testing.assert_close(w2, w[0]) - g_idx = w[0] - - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - qzeros=qzeros, - g_idx=g_idx, - bits=self.bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method=self.quant_method, - sym=self.sym, - sharded_infeatures=False, - ) - - def get_weights_row(self, weights: Weights, prefix: str): - log_once(logger.info, "Using GPTQ-Marlin kernels") - try: - qweight = weights.get_sharded(f"{prefix}.qweight", dim=0) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" - ) - - if not self.sym: - if self.desc_act or self.groupsize == -1: - qzeros = weights.get_tensor(f"{prefix}.qzeros") - else: - qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0) - else: - qzeros = None - - if self.quant_method == "awq": - g_idx = None - else: - g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0) - - if self.desc_act or self.groupsize == -1: - scales = weights.get_tensor(f"{prefix}.scales") - else: - scales = weights.get_sharded(f"{prefix}.scales", dim=0) - - sharded_in_features = weights.process_group.size() > 1 - - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - qzeros=qzeros, - g_idx=g_idx, - bits=self.bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method=self.quant_method, - sym=self.sym, - sharded_infeatures=sharded_in_features, - ) - - def _get_gptq_params(self, weights: Weights): - if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"): - self.bits = weights.get_tensor("gptq_bits").item() - self.groupsize = weights.get_tensor("gptq_groupsize").item() - self.desc_act = False - # `server quantize` used asymmetric quantization unconditionally - # before the `gptq_sym` setting tensor was added. - self.sym = ( - weights.get_tensor("gptq_sym").item() - if weights._has_tensor("gptq_sym") - else False - ) - self.quant_method = "gptq" - - -@dataclass -class GPTQMarlinWeight(Weight): - """ - Repacked GPTQ Marlin weights. - """ - - qweight: torch.Tensor - qzeros: torch.Tensor - scales: torch.Tensor - g_idx: torch.Tensor - perm: torch.Tensor - bits: int - is_full_k: bool - - def __post_init__(self): - assert self.qweight.dtype == torch.int32 - assert self.scales.dtype == torch.float16 - assert self.g_idx.dtype == torch.int32 - assert self.perm.dtype == torch.int32 - - def get_linear(self, bias: torch.Tensor): - return GPTQMarlinLinear( - weight=self, - bias=bias, - ) - - -def repack_gptq_for_marlin( - *, - qweight: torch.Tensor, - qzeros: Optional[torch.Tensor], - scales: torch.Tensor, - g_idx: Optional[torch.Tensor], - bits: int, - desc_act: bool, - groupsize: int, - quant_method: str, - sym: bool, - sharded_infeatures: bool, -) -> GPTQMarlinWeight: - """Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels.""" - _check_marlin_kernels() - assert marlin_kernels is not None - - if bits not in GPTQ_MARLIN_BITS: - supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS) - raise RuntimeError( - f"Repacking {bits}-bit GPTQ weights as Marlin is not supported, must be one of: {supported_bits}" - ) - - if groupsize not in GPTQ_MARLIN_GROUP_SIZES: - supported_sizes = ", ".join(str(b) for b in GPTQ_MARLIN_GROUP_SIZES) - raise RuntimeError( - f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}" - ) - if not (sym or quant_method == "awq"): - raise RuntimeError( - "Repacking GPTQ weights with asymmetric quantization as Marlin is not supported." - ) - - log_once(logger.info, f"Converting {quant_method} model to Marlin packing format.") - - weights_per_int = 32 // bits - in_features = qweight.shape[0] - out_features = qweight.shape[1] - - # AWQ uses column packing, GPTQ uses row packing - if quant_method == "awq": - out_features *= weights_per_int - else: - in_features *= weights_per_int - - if in_features % groupsize != 0: - raise ValueError( - f"Number of input features ({in_features}) not divisible by group size ({groupsize})" - ) - - if g_idx is not None and desc_act and groupsize != -1: - perm = torch.argsort(g_idx).to(torch.int) - g_idx = g_idx[perm] - else: - perm = torch.empty(0, dtype=torch.int, device=qweight.device) - g_idx = torch.empty(0, dtype=torch.int, device=qweight.device) - - if quant_method == "awq": - repacked = marlin_kernels.awq_marlin_repack( - qweight, in_features, out_features, bits - ) - if qzeros is not None: - qzeros = awq_to_marlin_zero_points( - qzeros, - in_features // groupsize, - out_features, - bits, - ) - - else: - repacked = marlin_kernels.gptq_marlin_repack( - qweight, perm, in_features, out_features, bits - ) - - if qzeros is None: - qzeros = torch.empty(0, dtype=torch.int, device=qweight.device) - - scales = permute_scales(scales) - - is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures) - - return GPTQMarlinWeight( - qweight=repacked, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - perm=perm, - bits=bits, - is_full_k=is_full_k, - ) - - -class GPTQMarlinLinear(nn.Module): - """ - Linear layer for GPTQ weights that were converted for the GPTQ-Marlin - kernels. - """ - - def __init__( - self, - *, - weight: GPTQMarlinWeight, - bias: Optional[torch.Tensor], - ): - super().__init__() - - _check_marlin_kernels() - assert marlin_kernels is not None - - in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE - out_features = weight.scales.shape[1] - _check_valid_shape(in_features=in_features, out_features=out_features) - - self.bits = weight.bits - self.is_full_k = weight.is_full_k - - self.qweight = weight.qweight - self.qzeros = weight.qzeros - self.scales = weight.scales - self.g_idx = weight.g_idx - self.perm = weight.perm - if bias is not None: - self.bias = bias - else: - self.bias = None - - self.workspace = torch.zeros( - out_features // 64 * 16, dtype=torch.int, device=weight.qweight.device - ) - - def forward(self, A: torch.Tensor) -> torch.Tensor: - assert marlin_kernels is not None - - A_flat = A.view(-1, A.shape[-1]) - C = marlin_kernels.gptq_marlin_gemm( - A_flat, - self.qweight, - self.scales, - self.qzeros, - self.g_idx, - self.perm, - self.workspace, - self.bits, - A_flat.shape[0], - self.scales.shape[1], - A_flat.shape[1], - self.is_full_k, - self.qzeros.numel() > 0, - True, - ) - C = C.reshape(A.shape[:-1] + (self.scales.shape[1],)) - - if self.bias is not None: - C += self.bias - - return C - - -def awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - # AWQ zero-points are quantized and packed on the column dim. - # In addition, the values are permuted based on dequantizer. - # Here we undo both of these, and then apply marlin permutation - # and pack it back. - q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) - - # Undo interleaving (use argsort(..) to get inverse perm) - if num_bits == 4: - undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) - elif num_bits == 8: - undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() - q_zp = q_zp.reshape((-1, size_n)).contiguous() - - marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) - return marlin_zp - - -def _check_valid_shape(in_features: int, out_features: int): - if (in_features % 128 != 0 or out_features % 64 != 0) and ( - in_features % 64 != 0 or out_features % 128 != 0 - ): - raise ValueError( - f"The GPTQ Marlin kernel does not have a valid thread configuration for weight matrix with shape ({out_features}, {in_features})." - " The shape elements must be divisible by (128, 64) or (64, 128)." - ) diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/marlin.py b/backends/gaudi/server/text_generation_server/layers/marlin/marlin.py deleted file mode 100644 index 89ebaca62..000000000 --- a/backends/gaudi/server/text_generation_server/layers/marlin/marlin.py +++ /dev/null @@ -1,346 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional, Union - -import torch -import torch.nn as nn -from text_generation_server.layers.marlin.util import _check_marlin_kernels -from text_generation_server.utils.weights import Weight, Weights, WeightsLoader - -try: - import marlin_kernels -except ImportError: - marlin_kernels = None - - -class MarlinWeightsLoader(WeightsLoader): - """Loader for Marlin-quantized weights.""" - - def __init__(self, *, bits: int, is_marlin_24: bool): - self.bits = bits - self.is_marlin_24 = is_marlin_24 - - def get_weights(self, weights: "Weights", prefix: str): - """ - Get weights at the given prefix and apply without tensor paralllism. - """ - is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" - if is_marlin_24: - try: - B = weights.get_tensor(f"{prefix}.B_24") - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." - ) - - B_meta = weights.get_tensor(f"{prefix}.B_meta") - s = weights.get_tensor(f"{prefix}.s") - weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) - else: - try: - B = weights.get_tensor(f"{prefix}.B") - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` weight, make sure the model is already quantized." - ) - - s = weights.get_tensor(f"{prefix}.s") - weight = MarlinWeight(B=B, s=s) - - return weight - - def get_weights_col_packed( - self, - weights: Weights, - prefix: str, - block_sizes: Union[int, List[int]], - ): - if self.is_marlin_24: - B = weights.get_packed_sharded( - f"{prefix}.B_24", dim=1, block_sizes=block_sizes - ) - B_meta = weights.get_packed_sharded( - f"{prefix}.B_meta", dim=1, block_sizes=block_sizes - ) - s = weights.get_packed_sharded( - f"{prefix}.s", dim=1, block_sizes=block_sizes - ) - - weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) - else: - B = weights.get_packed_sharded( - f"{prefix}.B", dim=1, block_sizes=block_sizes - ) - s = weights.get_packed_sharded( - f"{prefix}.s", dim=1, block_sizes=block_sizes - ) - weight = MarlinWeight(B=B, s=s) - - return weight - - def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): - if self.is_marlin_24: - try: - B = torch.cat( - [weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` weight, make sure the model is already quantized" - ) - - B_meta = torch.cat( - [weights.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1 - ) - - s = torch.cat( - [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 - ) - - weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) - else: - try: - B = torch.cat( - [weights.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` weight, make sure the model is already quantized" - ) - s = torch.cat( - [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 - ) - - weight = MarlinWeight(B=B, s=s) - - return weight - - def get_weights_row(self, weights: Weights, prefix: str): - if self.is_marlin_24: - try: - B = weights.get_sharded(f"{prefix}.B_24", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." - ) - - B_meta = weights.get_sharded(f"{prefix}.B_meta", dim=0) - num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0] - if num_groups == 1: - # The number of groups is 1 when groupsize == -1. share - # scales between all shards in this case. - s = weights.get_tensor(f"{prefix}.s") - else: - s = weights.get_sharded(f"{prefix}.s", dim=0) - - weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) - else: - try: - B = weights.get_sharded(f"{prefix}.B", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` weight, make sure the model is already quantized." - ) - - num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0] - if num_groups == 1: - # The number of groups is 1 when groupsize == -1. share - # scales between all shards in this case. - s = weights.get_tensor(f"{prefix}.s") - else: - s = weights.get_sharded(f"{prefix}.s", dim=0) - weight = MarlinWeight(B=B, s=s) - - return weight - - -@dataclass -class MarlinWeight(Weight): - """ - Marlin weights. - - Attributes: - B (torch.Tensor): int4-quantized weights packed into int32. - s (torch.Tensor): bfloat16/float16 scales. - """ - - B: torch.Tensor - s: torch.Tensor - - def __post_init__(self): - assert self.B.dtype == torch.int32 - assert self.s.dtype in [torch.float16, torch.bfloat16] - - def get_linear(self, bias: torch.Tensor): - return MarlinLinear(weight=self, bias=bias) - - -class MarlinLinear(nn.Module): - def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]): - super().__init__() - - _check_marlin_kernels() - assert marlin_kernels is not None - - in_features = weight.B.shape[0] * MARLIN_TILE_SIZE - out_features = weight.s.shape[1] - assert ( - in_features % 128 == 0 - ), f"Number of input features ({in_features}) not divisable by 128" - assert ( - out_features % 256 == 0 - ), f"Number of output features ({out_features}) not divisable by 256" - - groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0] - assert groupsize in { - -1, - 128, - }, f"Group size must be -1 or 128, was {groupsize}" - - self.B = weight.B - self.s = weight.s - if bias is not None: - self.bias = bias - else: - self.bias = None - - self.workspace = torch.zeros( - out_features // 64 * 16, dtype=torch.int, device=weight.B.device - ) - - def forward(self, A: torch.Tensor) -> torch.Tensor: - assert marlin_kernels is not None - - C = marlin_kernels.marlin_gemm( - A.view(-1, A.shape[-1]), - self.B, - self.s, - self.workspace, - A.shape[0], - self.s.shape[1], - A.shape[1], - ) - C = C.reshape(A.shape[:-1] + (self.s.shape[1],)) - - if self.bias is not None: - C += self.bias - - return C - - -GPTQ_MARLIN_24_MIN_THREAD_N = 128 -GPTQ_MARLIN_24_MIN_THREAD_K = 128 -GPTQ_MARLIN_24_MAX_PARALLEL = 64 -GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8] -GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] -MARLIN_TILE_SIZE = 16 - - -@dataclass -class GPTQMarlin24Weight: - """ - GPTQ-Marlin 2:4 weights. - - Attributes: - B (torch.Tensor): int4-quantized weights packed into int32. - B_meta (torch.Tensor): metadata for 2:4 sparsity. - s (torch.Tensor): float16 scales. - bits: quantized weight size. - """ - - B: torch.Tensor - B_meta: torch.Tensor - s: torch.Tensor - bits: int - - def __post_init__(self): - assert self.B.dtype == torch.int32 - assert self.B_meta.dtype == torch.int16 - assert self.s.dtype == torch.float16 - - def get_linear(self, bias: torch.Tensor): - return GPTQMarlin24Linear( - weight=self, - bias=bias, - ) - - -class GPTQMarlin24Linear(nn.Module): - def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]): - super().__init__() - - _check_marlin_kernels() - assert marlin_kernels is not None - - if weight.bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS: - supported_bits = ", ".join( - str(b) for b in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS - ) - raise RuntimeError( - f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}" - ) - - in_features = weight.B.shape[0] * MARLIN_TILE_SIZE * 2 - out_features = weight.s.shape[1] - groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0] - - if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES: - supported_sizes = ", ".join( - str(b) for b in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES - ) - raise RuntimeError( - f"Group size {groupsize} is not supported, must be one of: {supported_sizes}" - ) - - self.bits = weight.bits - weights_per_int32 = 32 // self.bits - - assert ( - out_features % GPTQ_MARLIN_24_MIN_THREAD_N == 0 - ), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_N} threads" - assert ( - out_features % weights_per_int32 == 0 - ), f"Number of output features ({out_features}) not divisable by weights per int32 ({weights_per_int32})" - - assert ( - in_features % GPTQ_MARLIN_24_MIN_THREAD_K == 0 - ), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_K} threads" - if groupsize != -1 and in_features % groupsize != 0: - raise ValueError( - f"Number of input features ({in_features}) not divisable by group size ({groupsize})" - ) - - self.B = weight.B - self.B_meta = weight.B_meta - self.s = weight.s - if bias is not None: - self.bias = bias - else: - self.bias = None - - self.workspace = torch.zeros( - (out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL, - dtype=torch.int, - device=weight.B.device, - ) - - def forward(self, A: torch.Tensor) -> torch.Tensor: - assert marlin_kernels is not None - - C = marlin_kernels.gptq_marlin_24_gemm( - A.view(-1, A.shape[-1]), - self.B, - self.B_meta, - self.s, - self.workspace, - self.bits, - A.shape[0], - self.s.shape[1], - A.shape[1], - ) - - C = C.reshape(A.shape[:-1] + (self.s.shape[1],)) - - if self.bias is not None: - C += self.bias - - return C diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/util.py b/backends/gaudi/server/text_generation_server/layers/marlin/util.py deleted file mode 100644 index 250d17141..000000000 --- a/backends/gaudi/server/text_generation_server/layers/marlin/util.py +++ /dev/null @@ -1,141 +0,0 @@ -import functools -from typing import List, Tuple - -import numpy -import torch -from text_generation_server.utils.import_utils import SYSTEM - -try: - import marlin_kernels -except ImportError: - marlin_kernels = None - -try: - major, _minor = torch.cuda.get_device_capability() - has_sm_8_0 = major >= 8 -except Exception: - has_sm_8_0 = False - - -def _check_marlin_kernels(): - if not (SYSTEM == "cuda" and has_sm_8_0): - raise NotImplementedError( - "Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later." - ) - - if marlin_kernels is None: - raise NotImplementedError( - "marlin is not installed, install it with: pip install server/marlin" - ) - - -# https://github.com/IST-DASLab/marlin/blob/2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c/marlin/__init__.py#L40C1-L68C54 -@functools.cache -def get_perms() -> Tuple[List[int], List[int]]: - scale_perm = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single = [] - for i in range(4): - scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -def permute_scales(scales: torch.Tensor): - scale_perm, scale_perm_single = get_perms() - out_features = scales.shape[1] - if scales.shape[0] == 1: - scales = scales.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] - else: - scales = scales.reshape((-1, len(scale_perm)))[:, scale_perm] - return scales.reshape((-1, out_features)).contiguous() - - -# Functions below are from vLLM - - -def get_pack_factor(bits: int) -> int: - if 32 % bits != 0: - raise ValueError(f"Cannot {bits} bit values into uint32") - return 32 // bits - - -def pack_cols( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) - - for i in range(pack_factor): - q_res |= q_w[:, i::pack_factor] << num_bits * i - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def unpack_cols( - packed_q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - assert packed_q_w.shape == ( - size_k, - size_n // pack_factor, - ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( - packed_q_w.shape, size_k, size_n, pack_factor - ) - - orig_device = packed_q_w.device - - packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) - q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) - - mask = (1 << num_bits) - 1 - for i in range(pack_factor): - vals = packed_q_w_cpu & mask - packed_q_w_cpu >>= num_bits - q_res[:, i::pack_factor] = vals - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def marlin_zero_points( - zp: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - scale_perm, _ = get_perms() - # Permute zero-points in a similar way to scales, but do not use the - # "single" permutation, since zero-points are applied on every MMA - zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] - - # Interleave column dim (for the dequantize code) and pack it to int32 - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() - zp = zp.reshape((-1, size_n)).contiguous() - zp = pack_cols(zp, num_bits, size_k, size_n) - - return zp diff --git a/backends/gaudi/server/text_generation_server/layers/moe/__init__.py b/backends/gaudi/server/text_generation_server/layers/moe/__init__.py index 2c46ca02a..8b9d6fcb0 100644 --- a/backends/gaudi/server/text_generation_server/layers/moe/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/moe/__init__.py @@ -10,13 +10,8 @@ from text_generation_server.layers import ( TensorParallelRowLinear, ) from text_generation_server.layers.fp8 import HybridFP8UnquantLoader -from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader -from text_generation_server.layers.moe.gptq_marlin import ( - GPTQMarlinSparseMoELayer, - can_use_marlin_moe_gemm, -) from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer from text_generation_server.utils.log import log_once from text_generation_server.utils.weights import ( DefaultWeightsLoader, @@ -24,12 +19,7 @@ from text_generation_server.utils.weights import ( UnquantizedWeight, ) -if SYSTEM == "rocm": - from .fused_moe_rocm import grouped_topk - from vllm.model_executor.layers.fused_moe import fused_topk -elif SYSTEM != "ipex": - from moe_kernels.fused_moe import fused_topk, grouped_topk - +from .fused_moe import fused_topk, grouped_topk # NOTE: we are using a protocol here, because multiple inherance is not nice. # We need `Module`, and `Module` -> some abstract class -> some concrete @@ -52,6 +42,8 @@ class MoELayer(Protocol): up_proj_name: str = "up_proj", down_proj_name: str = "down_proj", hidden_act: str = "silu", + scoring_func: Optional[str] = None, + e_score_correction_bias: Optional[float] = None, ): ... def forward( @@ -81,9 +73,14 @@ class DenseMoELayer(nn.Module): up_proj_name: str = "up_proj", down_proj_name: str = "down_proj", hidden_act: str = "silu", + scoring_func: Optional[str] = None, + e_score_correction_bias: Optional[float] = None, ): super().__init__() + assert scoring_func is None, "scoring func is not handled" + assert e_score_correction_bias is None, "scoring correction bias is not handled" + log_once( logger.info, "No fused layers are available for this model type, using (slower) dense MoE layer", @@ -199,22 +196,27 @@ class SparseMoELayer(nn.Module): topk: int, topk_group: Optional[int], weights: Weights, + scoring_func: Optional[str] = "softmax", + e_score_correction_bias: Optional[float] = None, gate_proj_name: str = "gate_proj", up_proj_name: str = "up_proj", down_proj_name: str = "down_proj", ): super().__init__() - if ( isinstance(weights.loader, DefaultWeightsLoader) and isinstance(weights.loader.weight_class, UnquantizedWeight) ) or isinstance(weights.loader, HybridFP8UnquantLoader): - cls = UnquantizedSparseMoELayer - elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym: - cls = GPTQMarlinSparseMoELayer + if ( + isinstance(weights.loader, HybridFP8UnquantLoader) + and weights.loader.to_fp8 + ): + cls = FP8SparseMoELayer + else: + cls = UnquantizedSparseMoELayer else: raise ValueError( - f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and GPTQ weights" + f"Unsupported weights loader: {type(weights.loader)}, sparse MoE is only supported for unquantized, AWQ, and GPTQ weights" ) log_once( @@ -230,6 +232,8 @@ class SparseMoELayer(nn.Module): topk=topk, topk_group=topk_group, weights=weights, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, gate_proj_name=gate_proj_name, up_proj_name=up_proj_name, down_proj_name=down_proj_name, @@ -241,17 +245,6 @@ class SparseMoELayer(nn.Module): @staticmethod def is_supported(weights: Weights) -> bool: return ( - ( - isinstance(weights.loader, DefaultWeightsLoader) - and isinstance(weights.loader.weight_class, UnquantizedWeight) - ) - or isinstance(weights.loader, HybridFP8UnquantLoader) - or ( - isinstance(weights.loader, GPTQMarlinWeightsLoader) - and can_use_marlin_moe_gemm( - quant_method=weights.loader.quant_method, - quantize=weights.loader.quantize, - sym=weights.loader.sym, - ) - ) - ) + isinstance(weights.loader, DefaultWeightsLoader) + and isinstance(weights.loader.weight_class, UnquantizedWeight) + ) or isinstance(weights.loader, HybridFP8UnquantLoader) diff --git a/backends/gaudi/server/text_generation_server/layers/moe/fp8.py b/backends/gaudi/server/text_generation_server/layers/moe/fp8.py new file mode 100644 index 000000000..071b2abee --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/moe/fp8.py @@ -0,0 +1,173 @@ +from typing import Optional + +import torch +import torch.nn as nn + +from text_generation_server.utils.weights import Weights +from text_generation_server.layers.fp8 import ( + Fp8Weight, + fp8_quantize, + quant_dtype, + normalize_e4m3fn_to_native_float8, +) + +try: + from .unquantized import fused_moe +except Exception: + fused_moe = None + + +class FP8SparseMoELayer(nn.Module): + def __init__( + self, + *, + n_expert_group: Optional[int], + n_experts: int, + prefix: str, + renormalize: bool, + topk: int, + topk_group: Optional[int], + weights: Weights, + scoring_func: Optional[str] = "softmax", + e_score_correction_bias: Optional[float] = None, + gate_proj_name: str = "gate_proj", + up_proj_name: str = "up_proj", + down_proj_name: str = "down_proj", + ): + super().__init__() + + assert (n_expert_group is None) == ( + topk_group is None + ), "n_expert_group and topk_group must both be None or have some value" + + self.n_expert_group = n_expert_group + self.topk = topk + self.topk_group = topk_group + self.renormalize = renormalize + self.weight_block_size = weights.weights_loader.weight_block_size + self.scoring_func = scoring_func + self.e_score_correction_bias = e_score_correction_bias + + ( + self.gate_up_proj, + self.gate_up_proj_weight_scale, + self.gate_up_proj_input_scale, + ) = _load_expert_multi_weights_col( + prefix=prefix, + n_experts=n_experts, + gate_proj_name=gate_proj_name, + up_proj_name=up_proj_name, + weights=weights, + ) + + self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = ( + _load_expert_weights_row( + prefix=prefix, + n_experts=n_experts, + name=down_proj_name, + weights=weights, + ) + ) + + def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: + return fused_moe( + x, + w1=self.gate_up_proj, + w2=self.down_proj, + gating_output=gating_output, + topk=self.topk, + renormalize=self.renormalize, + inplace=True, + use_grouped_topk=self.n_expert_group is not None, + num_expert_group=self.n_expert_group, + topk_group=self.topk_group, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias, + use_fp8_w8a8=True, + w1_scale=self.gate_up_proj_weight_scale, + w2_scale=self.down_proj_weight_scale, + a1_scale=self.gate_up_proj_input_scale, + a2_scale=self.down_proj_input_scale, + ) + + +def _load_expert_weights( + get_weight_fn, + *, + prefix: str, + n_experts: int, + name: str, + weights: Weights, +) -> torch.Tensor: + all_weight = None + all_weight_scales = None + max_input_scale = None + + for i in range(n_experts): + weight = get_weight_fn(prefix, i, name, weights) + + assert isinstance(weight, Fp8Weight) + + if all_weight is None: + all_weight = torch.empty( + (n_experts,) + weight.weight.shape, + dtype=quant_dtype, + device=weight.weight.device, + ) + if all_weight_scales is None: + all_weight_scales = torch.empty( + (n_experts,) + weight.weight_scale.shape, + dtype=torch.float32, + device=weight.weight.device, + ) + + if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}: + all_weight[i], all_weight_scales[i], current_input_scale = ( + normalize_e4m3fn_to_native_float8( + weight.weight, weight.weight_scale, weight.input_scale + ) + ) + if current_input_scale is not None: + if max_input_scale is None or current_input_scale > max_input_scale: + max_input_scale = current_input_scale + else: + all_weight[i], all_weight_scales[i] = fp8_quantize( + weight.weight, scalar=True + ) + + assert all_weight is not None + + return all_weight, all_weight_scales, max_input_scale + + +def _load_expert_multi_weights_col( + *, + prefix: str, + n_experts: int, + gate_proj_name: str, + up_proj_name: str, + weights: Weights, +) -> torch.Tensor: + def get_weight_fn(prefix, i, name, weights): + return weights.get_multi_weights_col( + [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0 + ) + + return _load_expert_weights( + get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights + ) + + +def _load_expert_weights_row( + *, + prefix: str, + n_experts: int, + name: str, + weights: Weights, +) -> torch.Tensor: + def get_weight_fn(prefix, i, name, weights): + return weights.get_weights_row(f"{prefix}.{i}.{name}") + + return _load_expert_weights( + get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights + ) diff --git a/backends/gaudi/server/text_generation_server/layers/moe/fused_moe_rocm.py b/backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py similarity index 80% rename from backends/gaudi/server/text_generation_server/layers/moe/fused_moe_rocm.py rename to backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py index 68accb990..e26ff8770 100644 --- a/backends/gaudi/server/text_generation_server/layers/moe/fused_moe_rocm.py +++ b/backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py @@ -16,10 +16,8 @@ from typing import Tuple import torch -import torch.distributed -# TODO: Remove the functions once moe_kernel are built for ROCM def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -50,3 +48,18 @@ def grouped_topk( topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) return topk_weights, topk_ids + + +def fused_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +) -> Tuple[torch.Tensor, torch.Tensor]: + topk_weights = torch.nn.functional.softmax( + gating_output, dim=1, dtype=torch.float32 + ) + topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1) + if renormalize: + topk_weights /= topk_weights.sum(dim=-1, keepdim=True) + return topk_weights, topk_ids diff --git a/backends/gaudi/server/text_generation_server/layers/moe/gptq_marlin.py b/backends/gaudi/server/text_generation_server/layers/moe/gptq_marlin.py deleted file mode 100644 index 3217cdc22..000000000 --- a/backends/gaudi/server/text_generation_server/layers/moe/gptq_marlin.py +++ /dev/null @@ -1,215 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional - -import torch -import torch.nn as nn - -from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.utils.weights import Weights -from text_generation_server.layers.marlin.gptq import ( - GPTQMarlinWeight, - GPTQMarlinWeightsLoader, -) - -if SYSTEM == "cuda": - from moe_kernels.fused_marlin_moe import fused_marlin_moe -else: - fused_marlin_moe = None - - -try: - major, _minor = torch.cuda.get_device_capability() - has_sm_8_0 = major >= 8 -except Exception: - has_sm_8_0 = False - - -def can_use_marlin_moe_gemm( - *, - quant_method: str, - quantize: str, - sym: bool, -): - return ( - SYSTEM == "cuda" - and fused_marlin_moe is not None - and has_sm_8_0 - and quantize == "gptq" - and quant_method == "gptq" - and sym - ) - - -@dataclass -class GPTQMarlinMoEWeight: - qweight: torch.Tensor - qzeros: torch.Tensor - scales: torch.Tensor - g_idx: torch.Tensor - perm: torch.Tensor - is_full_k: bool - - -class GPTQMarlinSparseMoELayer(nn.Module): - """ - MoE layer that uses a fused GPTQ-Marlin kernel. - """ - - def __init__( - self, - *, - n_expert_group: Optional[int], - n_experts: int, - prefix: str, - renormalize: bool, - topk: int, - topk_group: Optional[int], - weights: Weights, - gate_proj_name: str = "gate_proj", - up_proj_name: str = "up_proj", - down_proj_name: str = "down_proj", - ): - super().__init__() - - if not ( - isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym - ): - raise ValueError( - f"Unsupported weights loader: {weights.loader}, only GPTQMarlinWeightsLoader with symmetric quantization is supported" - ) - - assert (n_expert_group is None) == ( - topk_group is None - ), "n_expert_group and topk_group must both be None or have some value" - - self.n_expert_group = n_expert_group - self.topk = topk - self.topk_group = topk_group - self.renormalize = renormalize - - self.gate_up_proj = _load_expert_multi_weights_col( - prefix=prefix, - n_experts=n_experts, - names=[gate_proj_name, up_proj_name], - weights=weights, - ) - - self.down_proj = _load_expert_weights_row( - prefix=prefix, n_experts=n_experts, name=down_proj_name, weights=weights - ) - - self.bits = weights.loader.bits - - def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: - return fused_marlin_moe( - x, - w1=self.gate_up_proj.qweight, - w2=self.down_proj.qweight, - g_idx1=self.gate_up_proj.g_idx, - g_idx2=self.down_proj.g_idx, - perm1=self.gate_up_proj.perm, - perm2=self.down_proj.perm, - w1_scale=self.gate_up_proj.scales, - w2_scale=self.down_proj.scales, - is_full_k1=self.gate_up_proj.is_full_k, - is_full_k2=self.down_proj.is_full_k, - gating_output=gating_output, - topk=self.topk, - renormalize=self.renormalize, - use_grouped_topk=self.n_expert_group is not None, - num_expert_group=self.n_expert_group, - topk_group=self.topk_group, - num_bits=self.bits, - ) - - -def _load_expert_multi_weights_col( - *, - prefix: str, - n_experts: int, - names: List[str], - weights: Weights, -) -> GPTQMarlinMoEWeight: - moe_weight = None - for i in range(n_experts): - weight = weights.get_multi_weights_col( - [f"{prefix}.{i}.{name}" for name in names], 0 - ) - assert isinstance(weight, GPTQMarlinWeight) - moe_weight = _pack_weight( - n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight - ) - assert moe_weight is not None - return moe_weight - - -def _load_expert_weights_row( - *, - prefix: str, - n_experts: int, - name: str, - weights: Weights, -) -> GPTQMarlinMoEWeight: - moe_weight = None - for i in range(n_experts): - weight = weights.get_weights_row( - f"{prefix}.{i}.{name}", - ) - assert isinstance(weight, GPTQMarlinWeight) - moe_weight = _pack_weight( - n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight - ) - assert moe_weight is not None - return moe_weight - - -def _pack_weight( - *, - n_experts: int, - expert: int, - moe_weight: Optional[GPTQMarlinMoEWeight], - weight: GPTQMarlinWeight, -) -> GPTQMarlinMoEWeight: - if moe_weight is None: - qweight = torch.empty( - (n_experts,) + weight.qweight.shape, - dtype=weight.qweight.dtype, - device=weight.qweight.device, - ) - qzeros = torch.empty( - (n_experts,) + weight.qzeros.shape, - dtype=weight.qzeros.dtype, - device=weight.qzeros.device, - ) - scales = torch.empty( - (n_experts,) + weight.scales.shape, - dtype=weight.scales.dtype, - device=weight.scales.device, - ) - g_idx = torch.empty( - (n_experts,) + weight.g_idx.shape, - dtype=weight.g_idx.dtype, - device=weight.g_idx.device, - ) - perm = torch.empty( - (n_experts,) + weight.perm.shape, - dtype=weight.perm.dtype, - device=weight.perm.device, - ) - - moe_weight = GPTQMarlinMoEWeight( - qweight=qweight, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - perm=perm, - is_full_k=weight.is_full_k, - ) - - moe_weight.qweight[expert] = weight.qweight - moe_weight.qzeros[expert] = weight.qzeros - moe_weight.scales[expert] = weight.scales - moe_weight.g_idx[expert] = weight.g_idx - moe_weight.perm[expert] = weight.perm - - return moe_weight diff --git a/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py b/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py index d9d62c0ef..ec1583989 100644 --- a/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py +++ b/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py @@ -3,13 +3,8 @@ from typing import Optional import torch import torch.nn as nn -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.weights import UnquantizedWeight, Weights - -if SYSTEM == "rocm": - from vllm.model_executor.layers.fused_moe import fused_moe -elif SYSTEM != "ipex": - from moe_kernels.fused_moe import fused_moe +from vllm_hpu_extension.ops import DynamicFusedMOE class UnquantizedSparseMoELayer(nn.Module): @@ -23,6 +18,8 @@ class UnquantizedSparseMoELayer(nn.Module): topk: int, topk_group: Optional[int], weights: Weights, + scoring_func: Optional[str] = "softmax", + e_score_correction_bias: Optional[float] = None, gate_proj_name: str = "gate_proj", up_proj_name: str = "up_proj", down_proj_name: str = "down_proj", @@ -37,6 +34,9 @@ class UnquantizedSparseMoELayer(nn.Module): self.topk = topk self.topk_group = topk_group self.renormalize = renormalize + self.weight_block_size = weights.weights_loader.weight_block_size + self.scoring_func = scoring_func + self.e_score_correction_bias = e_score_correction_bias self.gate_up_proj = _load_expert_multi_weights_col( prefix=prefix, @@ -53,30 +53,13 @@ class UnquantizedSparseMoELayer(nn.Module): weights=weights, ) - def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: - if SYSTEM == "rocm": - return fused_moe( - x, - self.gate_up_proj, - self.down_proj, - gating_output, - self.topk, - renormalize=self.renormalize, - inplace=True, - ) + self.hpu_fused_moe = DynamicFusedMOE(n_experts) + for i in range(n_experts): + self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i]) + self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.down_proj[i]) - return fused_moe( - x, - w1=self.gate_up_proj, - w2=self.down_proj, - gating_output=gating_output, - topk=self.topk, - renormalize=self.renormalize, - inplace=True, - use_grouped_topk=self.n_expert_group is not None, - num_expert_group=self.n_expert_group, - topk_group=self.topk_group, - ) + def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: + return self.hpu_fused_moe(x, gating_output, self.topk) def _load_expert_multi_weights_col( diff --git a/backends/gaudi/server/text_generation_server/layers/rotary.py b/backends/gaudi/server/text_generation_server/layers/rotary.py index a2076bb20..6a83d6a57 100644 --- a/backends/gaudi/server/text_generation_server/layers/rotary.py +++ b/backends/gaudi/server/text_generation_server/layers/rotary.py @@ -2,14 +2,10 @@ import os import math import torch from torch import nn -from text_generation_server.utils.import_utils import SYSTEM - -if SYSTEM == "cuda": - import rotary_emb -elif SYSTEM == "rocm": - from vllm._C import ops -elif SYSTEM == "ipex": - import intel_extension_for_pytorch as ipex +from habana_frameworks.torch.hpex.kernels import ( + RotaryPosEmbeddingMode, + apply_rotary_pos_emb, +) def _create_inv_freq(dim, base, device): @@ -30,7 +26,7 @@ def _get_rope_config(config): class PositionRotaryEmbedding(nn.Module): - def __init__(self, inv_freq, scaling_factor): + def __init__(self, inv_freq, scaling_factor, max_position_embeddings): super().__init__() self.inv_freq = inv_freq self._seq_len_cached = 0 @@ -40,6 +36,9 @@ class PositionRotaryEmbedding(nn.Module): self._sin_k_cached = None self.scaling_factor = scaling_factor self.dynamic_args = None + self._update_cos_sin_cache( + torch.float32, inv_freq.device, max_position_embeddings + ) def forward( self, @@ -48,40 +47,41 @@ class PositionRotaryEmbedding(nn.Module): cos: torch.Tensor, sin: torch.Tensor, ): - # Such controlflows may add some overhead. - if SYSTEM == "cuda": - rotary_dim = cos.shape[-1] - q1 = query[..., :rotary_dim] - q2 = query[..., rotary_dim : 2 * rotary_dim] + num_tokens = query.shape[0] + head_size = query.shape[-1] + # HPU RoPE kernel requires hidden dimension for cos and sin to be equal + # to query hidden dimension, so the original tensors need to be + # expanded + # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE + # and expansion of cos/sin tensors via concatenation + rope_mode = RotaryPosEmbeddingMode.BLOCKWISE + cos = torch.cat((cos, cos), dim=-1) + sin = torch.cat((sin, sin), dim=-1) + rotary_dim = cos.shape[-1] + query_shape = query.shape + query = query.view(num_tokens, -1, head_size) + query_rot = query[..., :rotary_dim] + query_pass = query[..., rotary_dim:] + query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode) + query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)) - rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) - - k1 = key[..., :rotary_dim] - k2 = key[..., rotary_dim : 2 * rotary_dim] - - rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) - elif SYSTEM == "rocm": - # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. - # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773 - - head_size = query.shape[-1] - - # Inplace operation, updating query and key. - ops.rotary_embedding(query, key, head_size, cos, sin, True) - elif SYSTEM == "ipex": - ipex.llm.functional.rotary_embedding( - query, key, sin, cos, query.size(-1), True - ) - else: - raise ValueError( - "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." - ) + key_shape = key.shape + key = key.view(num_tokens, -1, head_size) + key_rot = key[..., :rotary_dim] + key_pass = key[..., rotary_dim:] + key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode) + key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)) @classmethod def static(cls, config, dim, base, device): inv_freq = _create_inv_freq(dim, base, device) scaling_factor = None rope_scaling = _get_rope_config(config) + if not hasattr(config, "max_position_embeddings") and hasattr( + config, "max_seq_len" + ): + # handling for dbrx + config.max_position_embeddings = config.max_seq_len if rope_scaling is not None: # `rope_type` is now standard in transformers, but some existing models # have `type` instead. @@ -89,6 +89,17 @@ class PositionRotaryEmbedding(nn.Module): if rope_type == "linear": pass + elif rope_type == "default": + pass + elif rope_type == "mrope": + mrope_section = rope_scaling["mrope_section"] + if mrope_section is not None: + return RotaryPositionEmbeddingMultimodalSections( + inv_freq, + scaling_factor, + mrope_section, + config.max_position_embeddings, + ) elif rope_type == "dynamic": scaling_factor = rope_scaling["factor"] return DynamicPositionRotaryEmbedding( @@ -109,7 +120,7 @@ class PositionRotaryEmbedding(nn.Module): ], ) - return cls(inv_freq, scaling_factor) + return cls(inv_freq, scaling_factor, config.max_position_embeddings) elif rope_type == "yarn": scaling_factor = rope_scaling["factor"] @@ -185,12 +196,13 @@ class PositionRotaryEmbedding(nn.Module): long_inv_freq=long_inv_freq, scaling_factor=scaling_factor, original_max_position_embeddings=original_max_position_embeddings, + max_position_embeddings=config.max_position_embeddings, ) else: raise NotImplementedError( f"rope scaling type {rope_scaling['type']} is not implemented or invalid" ) - return cls(inv_freq, scaling_factor) + return cls(inv_freq, scaling_factor, config.max_position_embeddings) @classmethod def load(cls, config, prefix, weights): @@ -236,7 +248,7 @@ class PositionRotaryEmbedding(nn.Module): raise NotImplementedError( f"rope scaling type {rope_scaling['type']} is not implemented or invalid" ) - return cls(inv_freq, scaling_factor) + return cls(inv_freq, scaling_factor, config.max_position_embeddings) def _update_cos_sin_cache(self, dtype, device, seqlen): # Reset the tables if the sequence length has changed, @@ -257,17 +269,7 @@ class PositionRotaryEmbedding(nn.Module): self._cos_cached = torch.cos(freqs).to(dtype) self._sin_cached = torch.sin(freqs).to(dtype) - def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype): - """ - Return cos and sin for the asked position ids - """ - if SYSTEM == "rocm": - # For RoCm, we always use float cos/sin to avoid a cast. - # For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26 - # But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal. - dtype = torch.float32 - - self._update_cos_sin_cache(dtype, position_ids.device, max_s) + def get_cos_sin(self, position_ids: torch.Tensor): cos = torch.index_select(self._cos_cached, 0, position_ids) sin = torch.index_select(self._sin_cached, 0, position_ids) @@ -283,6 +285,7 @@ class SuRotaryEmbedding(PositionRotaryEmbedding): long_inv_freq, scaling_factor, original_max_position_embeddings, + max_position_embeddings, ): super(PositionRotaryEmbedding, self).__init__() self.short_inv_freq = short_inv_freq @@ -295,6 +298,9 @@ class SuRotaryEmbedding(PositionRotaryEmbedding): self._cos_k_cached = None self._sin_k_cached = None self.dynamic_args = None + self._update_cos_sin_cache( + torch.float32, short_inv_freq.device, max_position_embeddings + ) def _update_cos_sin_cache(self, dtype, device, seqlen): # Reset the tables if the sequence length has changed, @@ -348,6 +354,9 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding): self._cos_k_cached = None self._sin_k_cached = None self.dynamic_args = None + self._update_cos_sin_cache( + torch.float32, short_inv_freq.device, max_position_embeddings + ) def _update_cos_sin_cache(self, dtype, device, seqlen): if ( @@ -383,7 +392,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding): class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding): def __init__(self, dim, max_position_embeddings, base, device, scaling_factor): inv_freq = _create_inv_freq(dim, base, device) - super().__init__(inv_freq, scaling_factor) + super().__init__(inv_freq, scaling_factor, max_position_embeddings) self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base @@ -461,7 +470,9 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): mscale_all_dim: float, ): inv_freq = _create_inv_freq(dim, base, device) - super().__init__(inv_freq, scaling_factor) + super().__init__( + inv_freq, scaling_factor, max_position_embeddings * self.scaling_factor + ) self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base @@ -546,3 +557,50 @@ def apply_llama3_scaling( new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq) return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) + + +class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding): + def __init__( + self, + inv_freq: torch.Tensor, + scaling_factor: float, + sections: list, + max_position_embeddings, + ): + self.sections = sections + self._cos_cached = None + self._sin_cached = None + self.section_indices = ( + torch.arange(len(self.sections)) + .repeat_interleave(torch.tensor(self.sections)) + .view(1, 1, -1) + .to(inv_freq.device) + ) + super().__init__(inv_freq, scaling_factor, max_position_embeddings) + + def _update_cos_sin_cache( + self, dtype: torch.dtype, device: torch.device, seqlen: int + ): + # always cache the cos/sin for the full sequence length to avoid + # recomputing if the sequence length is smaller than the cached one + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + self._sections = self.section_indices.expand(seqlen, -1, -1) + + def get_cos_sin( + self, + position_ids: torch.Tensor, + ): + slen = position_ids.shape[0] + + cos = self._cos_cached[position_ids].gather(1, self._sections[:slen]) + sin = self._sin_cached[position_ids].gather(1, self._sections[:slen]) + return cos, sin diff --git a/backends/gaudi/server/text_generation_server/layers/tensor_parallel.py b/backends/gaudi/server/text_generation_server/layers/tensor_parallel.py index 13f12ef1e..8f19174f8 100644 --- a/backends/gaudi/server/text_generation_server/layers/tensor_parallel.py +++ b/backends/gaudi/server/text_generation_server/layers/tensor_parallel.py @@ -2,10 +2,8 @@ import torch from torch.nn import functional as F from typing import Iterable, List from text_generation_server.layers.linear import get_linear, FastLinear -from text_generation_server.utils.import_utils import SYSTEM -if SYSTEM == "ipex": - import intel_extension_for_pytorch as ipex +import habana_frameworks.torch as htorch class LayerConcat(torch.nn.Module): @@ -90,14 +88,10 @@ class TensorParallelHead(SuperLayer): local_out = gather_input.T torch.mm(input, self.linear.weight.T, out=local_out) - if SYSTEM == "ipex": - ipex.distributed.all_gather_into_tensor( - world_out, gather_input, group=self.process_group - ) - else: - torch.distributed.all_gather_into_tensor( - world_out, gather_input, group=self.process_group - ) + htorch.core.mark_step() + torch.distributed.all_gather_into_tensor( + world_out, gather_input, group=self.process_group + ) if input.shape[0] == 1: return world_out @@ -107,10 +101,9 @@ class TensorParallelHead(SuperLayer): world_output = [ torch.empty_like(output) for _ in range(self.process_group.size()) ] - if SYSTEM == "ipex": - ipex.distributed.all_gather(world_output, output, group=self.process_group) - else: - torch.distributed.all_gather(world_output, output, group=self.process_group) + + htorch.core.mark_step() + torch.distributed.all_gather(world_output, output, group=self.process_group) world_output = torch.cat(world_output, dim=-1) return world_output @@ -202,10 +195,11 @@ class TensorParallelRowLinear(SuperLayer): def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor: out = super().forward(input) if self.process_group.size() > 1 and reduce: - if SYSTEM == "ipex": - ipex.distributed.all_reduce(out, group=self.process_group) - else: - torch.distributed.all_reduce(out, group=self.process_group) + # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge + # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used + # (which is required for tensor parallel HPUGraph inference) + htorch.core.mark_step() + torch.distributed.all_reduce(out, group=self.process_group) return out @@ -242,8 +236,9 @@ class TensorParallelEmbedding(torch.nn.Module): ) out = torch.nn.functional.embedding(input, self.weight) if self.reduce and self.process_group.size() > 1: - if SYSTEM == "ipex": - ipex.distributed.all_reduce(out, group=self.process_group) - else: - torch.distributed.all_reduce(out, group=self.process_group) + # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge + # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used + # (which is required for tensor parallel HPUGraph inference) + htorch.core.mark_step() + torch.distributed.all_reduce(out, group=self.process_group) return out diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py index 346016c21..778b14a1b 100644 --- a/backends/gaudi/server/text_generation_server/models/__init__.py +++ b/backends/gaudi/server/text_generation_server/models/__init__.py @@ -1,3 +1,5 @@ +# ruff: noqa: F821 +# the above line disables the `undefined-name` rule for the model type variables import torch import os @@ -8,6 +10,7 @@ from huggingface_hub import hf_hub_download, HfApi from typing import Optional from pathlib import Path from typing import List, Dict +import enum # Needed to properly setup habana_frameworks @@ -16,15 +19,10 @@ from text_generation_server.models.model import Model from text_generation_server.models.causal_lm import CausalLM from text_generation_server.models.bloom import BLOOM from text_generation_server.models.starcoder import StarCoder -from text_generation_server.models.vlm_causal_lm import VlmCausalLM -from text_generation_server.models.custom_modeling.mllama import ( - MllamaForConditionalGeneration, -) -from text_generation_server.models.custom_modeling.llava_next import ( - LlavaNextForConditionalGeneration, +from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import ( + PhiMoEConfig, ) -# from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch from text_generation_server.utils.adapter import ( AdapterParameters, build_layer_weight_lookup, @@ -33,9 +31,285 @@ from text_generation_server.utils.adapter import ( ) from text_generation_server.adapters.lora import LoraWeights - +from text_generation_server.utils.log import log_master from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi +__all__ = [ + "Model", + "CausalLM", + "Seq2SeqLM", + "get_model_with_lora_adapters", +] +from text_generation_server.models.globals import ATTENTION + +FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models." + +FLASH_ATTENTION = False +if ATTENTION == "paged": + FLASH_ATTENTION = True + +try: + from text_generation_server.models.flash_causal_lm import FlashCausalLM + from text_generation_server.models.flash_vlm_causal_lm import FlashVlmCausalLM + from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLM + from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import ( + FlashDeepseekV2ForCausalLM, + DeepseekV2Config, + ) + from text_generation_server.models.custom_modeling.flash_deepseek_v3_modeling import ( + FlashDeepseekV3ForCausalLM, + DeepseekV3Config, + ) + from text_generation_server.models.custom_modeling.flash_llama_modeling import ( + FlashLlamaForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_cohere_modeling import ( + FlashCohereForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( + FlashGemmaForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_gemma2_modeling import ( + FlashGemma2ForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_dbrx_modeling import ( + FlashDbrxForCausalLM, + DbrxConfig, + ) + from text_generation_server.models.custom_modeling.flash_rw_modeling import ( + RWConfig, + FlashRWForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_neox_modeling import ( + FlashGPTNeoXForCausalLM, + ) + from text_generation_server.models.pali_gemma import ( + PaliGemmaBatch, + ) + from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import ( + PaliGemmaForConditionalGeneration, + ) + from text_generation_server.models.custom_modeling.flash_phi_modeling import ( + FlashPhiForCausalLM, + ) + from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch + from text_generation_server.models.custom_modeling.flash_mllama import ( + FlashMllamaForConditionalGeneration, + ) + from text_generation_server.models.custom_modeling.flash_llava_next import ( + FlashLlavaNextForConditionalGeneration, + ) + + from text_generation_server.models.custom_modeling.flash_santacoder_modeling import ( + FlashSantacoderForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import ( + FlashStarcoder2ForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( + Qwen2ForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( + FlashMistralForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_mixtral_modeling import ( + FlashMixtralForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_gpt2_modeling import ( + FlashGPT2ForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_gptj_modeling import ( + FlashGPTJForCausalLM, + ) + from text_generation_server.models.custom_modeling.idefics2 import ( + Idefics2ForConditionalGeneration, + ) + from text_generation_server.models.custom_modeling.idefics3 import ( + Idefics3ForConditionalGeneration, + ) + from text_generation_server.models.custom_modeling.qwen2_vl import ( + Qwen2VLForConditionalGeneration, + ) + from text_generation_server.models.custom_modeling.qwen2_5_vl import ( + Qwen2_5VLForConditionalGeneration, + Qwen2_5_VLConfig, + Qwen2_5_VLProcessor, + ) + from text_generation_server.layers.attention import SUPPORTS_WINDOWING +except ImportError as e: + log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}") + SUPPORTS_WINDOWING = False + FLASH_ATTENTION = False + +if FLASH_ATTENTION: + __all__.append(FlashCausalLM) + + +class ModelType(enum.Enum): + DEEPSEEK_V2 = { + "type": "deepseek_v2", + "name": "Deepseek V2", + "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2", + } + DEEPSEEK_V3 = { + "type": "deepseek_v3", + "name": "Deepseek V3", + "url": "https://huggingface.co/deepseek-ai/DeepSeek-V3", + } + IDEFICS2 = { + "type": "idefics2", + "name": "Idefics 2", + "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b", + "multimodal": True, + } + IDEFICS3 = { + "type": "idefics3", + "name": "Idefics 3", + "url": "https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3", + "multimodal": True, + } + LLAVA_NEXT = { + "type": "llava_next", + "name": "Llava Next (1.6)", + "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf", + "multimodal": True, + } + LLAMA = { + "type": "llama", + "name": "Llama", + "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f", + } + PHI3 = { + "type": "phi3", + "name": "Phi 3", + "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct", + } + GRANITE = { + "type": "granite", + "name": "Granite", + "url": "https://huggingface.co/ibm-granite/granite-3.0-8b-instruct", + } + GEMMA = { + "type": "gemma", + "name": "Gemma", + "url": "https://huggingface.co/google/gemma-7b", + } + PALIGEMMA = { + "type": "paligemma", + "name": "PaliGemma", + "url": "https://huggingface.co/google/paligemma-3b-pt-224", + } + GEMMA2 = { + "type": "gemma2", + "name": "Gemma2", + "url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315", + } + COHERE = { + "type": "cohere", + "name": "Cohere", + "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus", + } + DBRX = { + "type": "dbrx", + "name": "Dbrx", + "url": "https://huggingface.co/databricks/dbrx-instruct", + } + MAMBA = { + "type": "mamba", + "name": "Mamba", + "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj", + } + MISTRAL = { + "type": "mistral", + "name": "Mistral", + "url": "https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407", + } + MIXTRAL = { + "type": "mixtral", + "name": "Mixtral", + "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1", + } + GPT_BIGCODE = { + "type": "gpt_bigcode", + "name": "Gpt Bigcode", + "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder", + } + PHI = { + "type": "phi", + "name": "Phi", + "url": "https://huggingface.co/microsoft/phi-1_5", + } + PHI_MOE = { + "type": "phimoe", + "name": "PhiMoe", + "url": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct", + } + BAICHUAN = { + "type": "baichuan", + "name": "Baichuan", + "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat", + } + FALCON = { + "type": "falcon", + "name": "Falcon", + "url": "https://huggingface.co/tiiuae/falcon-7b-instruct", + } + STARCODER2 = { + "type": "starcoder2", + "name": "StarCoder 2", + "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1", + } + QWEN2 = { + "type": "qwen2", + "name": "Qwen 2", + "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f", + } + QWEN2_VL = { + "type": "qwen2_vl", + "name": "Qwen 2 VL", + "url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d", + } + QWEN2_5_VL = { + "type": "qwen2_5_vl", + "name": "Qwen 2.5 VL", + "url": "https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e", + } + GALACTICA = { + "type": "galactica", + "name": "Galactica", + "url": "https://huggingface.co/facebook/galactica-120b", + } + SANTACODER = { + "type": "santacoder", + "name": "SantaCoder", + "url": "https://huggingface.co/bigcode/santacoder", + } + GPT2 = { + "type": "gpt2", + "name": "Gpt2", + "url": "https://huggingface.co/openai-community/gpt2", + } + GPT_NEOX = { + "type": "gpt_neox", + "name": "Gpt Neox", + "url": "https://huggingface.co/EleutherAI/gpt-neox-20b", + } + GPTJ = { + "type": "gptj", + "name": "Gptj", + "url": "https://huggingface.co/EleutherAI/gpt-j-6b", + } + MLLAMA = { + "type": "mllama", + "name": "Mllama", + "url": "https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct", + "multimodal": True, + } + + +__GLOBALS = locals() +for data in ModelType: + __GLOBALS[data.name] = data.value["type"] SDP_ON_BF16 = int(os.environ.get("SDP_ON_BF16", 0)) # Disable gradients @@ -53,9 +327,7 @@ def get_model( trust_remote_code: bool, max_input_tokens: int, ) -> Model: - adapt_transformers_to_gaudi() - if SDP_ON_BF16 == 1: - torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) + global FLASH_ATTENTION if speculate is not None: set_speculate(speculate) @@ -177,9 +449,393 @@ def get_model( model_type = config_dict["model_type"] + kv_cache_dtype = dtype + + if FLASH_ATTENTION: + if model_type == DEEPSEEK_V2: + head_size = max( + config_dict.get("qk_nope_dim", 128) + + config_dict.get("qk_rope_dim", 64), + config_dict.get("v_head_dim", 128), + ) + return FlashCausalLM( + model_id=model_id, + model_class=FlashDeepseekV2ForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + default_dtype=torch.bfloat16, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=DeepseekV2Config, + head_size=head_size, + ) + elif model_type == DEEPSEEK_V3: + head_size = max( + config_dict.get("qk_nope_dim", 128) + + config_dict.get("qk_rope_dim", 64), + config_dict.get("v_head_dim", 128), + ) + return FlashCausalLM( + model_id=model_id, + model_class=FlashDeepseekV3ForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + default_dtype=torch.bfloat16, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=DeepseekV3Config, + head_size=head_size, + ) + + elif ( + model_type == GPT_BIGCODE + or model_type == GPT2 + and model_id.startswith("bigcode/") + ): + return FlashCausalLM( + model_id=model_id, + model_class=FlashSantacoderForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + aliases={"transformer.wte.weight": ["lm_head.weight"]}, + num_kv_heads=1, + ) + elif model_type == GPT2: + return FlashCausalLM( + model_id=model_id, + model_class=FlashGPT2ForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == GPTJ: + return FlashCausalLM( + model_id=model_id, + model_class=FlashGPTJForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == GPT_NEOX: + from text_generation_server.models.custom_modeling.flash_neox_modeling import ( + GPTNeoXConfig, + ) + + return FlashCausalLM( + model_id=model_id, + model_class=FlashGPTNeoXForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=GPTNeoXConfig, + ) + elif model_type == PHI: + return FlashCausalLM( + model_id=model_id, + model_class=FlashPhiForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == PHI_MOE: + return FlashCausalLM( + model_id=model_id, + model_class=FlashLlamaForCausalLM, + config_class=PhiMoEConfig, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == LLAMA or model_type == PHI3 or model_type == GRANITE: + return FlashCausalLM( + model_id=model_id, + model_class=FlashLlamaForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == BAICHUAN: + return FlashCausalLM( + model_id=model_id, + model_class=FlashLlamaForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == GEMMA: + return FlashCausalLM( + model_id=model_id, + model_class=FlashGemmaForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + # Works better for these models + default_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == GEMMA2: + return FlashCausalLM( + model_id=model_id, + model_class=FlashGemma2ForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + # Works better for these models + default_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == COHERE: + return FlashCausalLM( + model_id=model_id, + model_class=FlashCohereForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == DBRX: + return FlashCausalLM( + model_id=model_id, + model_class=FlashDbrxForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + # Dbrx works better in bfloat16. + default_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=DbrxConfig, + ) + elif ( + model_type in ["RefinedWeb", "RefinedWebModel", FALCON] + and not sharded + and not config_dict.get("alibi", False) + ): + return FlashCausalLM( + model_id=model_id, + model_class=FlashRWForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + aliases={ + "lm_head.weight": ["transformer.word_embeddings.weight"], + "transformer.word_embeddings.weight": ["lm_head.weight"], + }, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=RWConfig, + ) + elif model_type == MISTRAL: + return FlashCausalLM( + model_id=model_id, + model_class=FlashMistralForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == MIXTRAL: + return FlashCausalLM( + model_id=model_id, + model_class=FlashMixtralForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == STARCODER2: + return FlashCausalLM( + model_id=model_id, + model_class=FlashStarcoder2ForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == QWEN2: + return FlashCausalLM( + model_id=model_id, + model_class=Qwen2ForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == QWEN2_VL: + return FlashVlmCausalLM( + model_id=model_id, + model_class=Qwen2VLForConditionalGeneration, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + default_dtype=torch.bfloat16, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == QWEN2_5_VL: + return FlashVlmCausalLM( + model_id=model_id, + model_class=Qwen2_5VLForConditionalGeneration, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + default_dtype=torch.bfloat16, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=Qwen2_5_VLConfig, + processor_class=Qwen2_5_VLProcessor, + ) + elif model_type == MLLAMA: + return FlashMllamaCausalLM( + model_id=model_id, + model_class=FlashMllamaForConditionalGeneration, + batch_class=FlashMllamaCausalLMBatch, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + default_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == IDEFICS2: + return FlashVlmCausalLM( + model_id=model_id, + model_class=Idefics2ForConditionalGeneration, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + # XXX: Extremely important to cap resolution in order to limit + # VRAM usage. + processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}}, + ) + elif model_type == IDEFICS3: + return FlashVlmCausalLM( + model_id=model_id, + model_class=Idefics3ForConditionalGeneration, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + default_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + # XXX: Extremely important to cap resolution in order to limit + # VRAM usage. + processor_kwargs={"size": {"longest_edge": 1456}}, + ) + elif model_type == PALIGEMMA: + return FlashVlmCausalLM( + model_id=model_id, + model_class=PaliGemmaForConditionalGeneration, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + # Works better for these models + default_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + batch_class=PaliGemmaBatch, + ) + elif model_type == LLAVA_NEXT: + return FlashVlmCausalLM( + model_class=FlashLlavaNextForConditionalGeneration, + model_id=model_id, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + ) + + from text_generation_server.models.vlm_causal_lm import VlmCausalLM + from text_generation_server.models.custom_modeling.mllama import ( + MllamaForConditionalGeneration, + ) + from text_generation_server.models.custom_modeling.llava_next import ( + LlavaNextForConditionalGeneration, + ) + + adapt_transformers_to_gaudi() + if SDP_ON_BF16 == 1: + torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) if model_type == "gpt_bigcode": return StarCoder(model_id=model_id, revision=revision, dtype=dtype) - if model_type == "bloom": return BLOOM( model_id=model_id, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/bloom_modeling.py index e2719fad2..84835ab89 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/bloom_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/bloom_modeling.py @@ -377,7 +377,7 @@ class BloomAttention(nn.Module): past_value.view(-1, *past_value.shape[-2:]), ) - if CUSTOM_KERNELS_ENABLED: + if CUSTOM_KERNELS_ENABLED and attention_mask.shape[-1] < 4096: assert self.training is False, "Only foward pass was implemented" assert ( attention_mask.shape[-1] < 4096 @@ -580,7 +580,7 @@ class BloomPreTrainedModel(PreTrainedModel): @staticmethod def _convert_to_bloom_cache( - past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]] + past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]: """ Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...])) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 30656038b..3bcc689d2 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -28,10 +28,10 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, + HPUPagedAttentionMetadata, ) -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -39,7 +39,6 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.layernorm import ( FastLayerNorm, ) @@ -47,11 +46,10 @@ from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) from text_generation_server.utils.weights import UnquantizedWeight - -if SYSTEM == "cuda": - import dropout_layer_norm -else: - dropout_layer_norm = None +from habana_frameworks.torch.hpex.kernels import ( + RotaryPosEmbeddingMode, + apply_rotary_pos_emb, +) class CohereRotary(PositionRotaryEmbedding): @@ -63,38 +61,25 @@ class CohereRotary(PositionRotaryEmbedding): sin: torch.Tensor, ): # Such controlflows may add some overhead. - if SYSTEM == "cuda": - import rotary_emb + num_tokens = query.shape[0] + head_size = query.shape[-1] + rope_mode = RotaryPosEmbeddingMode.PAIRWISE + sin = torch.repeat_interleave(sin, 2, dim=-1) + cos = torch.repeat_interleave(cos, 2, dim=-1) + rotary_dim = cos.shape[-1] + query_shape = query.shape + query = query.view(num_tokens, -1, head_size) + query_rot = query[..., :rotary_dim] + query_pass = query[..., rotary_dim:] + query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode) + query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)) - q1 = query[..., ::2] - q2 = query[..., 1::2] - - rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) - - k1 = key[..., ::2] - k2 = key[..., 1::2] - - rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) - elif SYSTEM == "rocm": - from vllm._C import ops - - # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. - # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773 - - head_size = query.shape[-1] - - # Inplace operation, updating query and key. - ops.rotary_embedding(query, key, head_size, cos, sin, False) - elif SYSTEM == "ipex": - import intel_extension_for_pytorch as ipex - - ipex.llm.functional.rotary_embedding( - query, key, sin, cos, query.size(-1), False - ) - else: - raise ValueError( - "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." - ) + key_shape = key.shape + key = key.view(num_tokens, -1, head_size) + key_rot = key[..., :rotary_dim] + key_pass = key[..., rotary_dim:] + key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode) + key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)) class CohereLayerNorm(nn.Module): @@ -107,49 +92,18 @@ class CohereLayerNorm(nn.Module): self.eps = eps def forward(self, hidden_states): - if hidden_states.shape[-1] > 8192 or SYSTEM != "cuda": - hidden_states = hidden_states.reshape( - -1, self.weight.shape[0], self.weight.shape[1] - ) - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - mean = hidden_states.mean(-1, keepdim=True) - hidden_states_minus_mean = hidden_states - mean - variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps) - hidden_states = self.weight.to(torch.float32) * hidden_states - hidden_states = hidden_states.view(-1, self.weight.shape[1]) - return hidden_states.to(input_dtype) - - ( - hidden_states, - *rest, - ) = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - None, - self.ones, - None, - None, - None, - None, - None, - 0.0, - self.eps, - 1.0, - 0, - None, - False, - False, - ) - - # Required to apply one weight matrix per head - hidden_states = hidden_states.view( + hidden_states = hidden_states.reshape( -1, self.weight.shape[0], self.weight.shape[1] ) - hidden_states = self.weight * hidden_states + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + mean = hidden_states.mean(-1, keepdim=True) + hidden_states_minus_mean = hidden_states - mean + variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps) + hidden_states = self.weight.to(torch.float32) * hidden_states hidden_states = hidden_states.view(-1, self.weight.shape[1]) - - return hidden_states + return hidden_states.to(input_dtype) def load_attention(config, prefix, weights): @@ -229,6 +183,7 @@ class FlashCohereAttention(torch.nn.Module): ) self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.use_qk_norm = config.use_qk_norm if self.use_qk_norm: @@ -264,10 +219,9 @@ class FlashCohereAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) query, key, value = qkv.split( @@ -291,30 +245,35 @@ class FlashCohereAttention(torch.nn.Module): self.rotary_emb(query, key, cos, sin) - reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=key, + value=value, + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else key, - kv_cache[1] if PREFILL_IN_KV_CACHE else value, - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=key, + value=value, + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj( @@ -386,10 +345,9 @@ class FlashCohereLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -400,10 +358,9 @@ class FlashCohereLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) mlp_output = self.mlp(normed_hidden_states) @@ -452,18 +409,15 @@ class FlashCohereModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: torch.Tensor, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None @@ -475,10 +429,9 @@ class FlashCohereModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -516,11 +469,9 @@ class FlashCohereForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -529,10 +480,9 @@ class FlashCohereForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 1137a453f..15c243c97 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -20,17 +20,14 @@ from torch import nn from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple, Any -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.layers.attention.kv_cache import get_kv_scales -if SYSTEM != "ipex": - from vllm.model_executor.layers.fused_moe import fused_moe from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, - PREFILL_IN_KV_CACHE, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( FastLinear, @@ -46,6 +43,7 @@ from text_generation_server.layers.rotary import ( from text_generation_server.layers.layernorm import ( FastLayerNorm, ) +from vllm_hpu_extension.ops import DynamicFusedMOE class DbrxAttentionConfig(PretrainedConfig): @@ -290,6 +288,7 @@ class DbrxAttention(torch.nn.Module): ) self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = TensorParallelRowLinear.load( config, @@ -309,10 +308,9 @@ class DbrxAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) if self.clip_qkv is not None: @@ -330,30 +328,35 @@ class DbrxAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -387,10 +390,9 @@ class DbrxNormAttentionNorm(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): normed_hidden_states, res = self.norm_1(hidden_states, residual) @@ -401,10 +403,9 @@ class DbrxNormAttentionNorm(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) # faster post attention rms norm @@ -482,18 +483,15 @@ class BlockSparseMoE(nn.Module): self.process_group = weights.process_group + self.hpu_fused_moe = DynamicFusedMOE(self.num_experts) + for i in range(self.num_experts): + self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.wv1[i]) + self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.w2[i]) + def forward(self, x: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits = self.gate(x) - out = fused_moe( - x, - self.wv1, - self.w2, - router_logits, - self.top_k, - renormalize=self.moe_normalize_expert_weights, - inplace=True, - ) + out = self.hpu_fused_moe(x, router_logits, self.top_k) # Reduce sum if self.process_group.size() > 1: @@ -620,10 +618,9 @@ class DbrxLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): # Self Attention attn_output, attn_res = self.attn( @@ -633,10 +630,9 @@ class DbrxLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) moe_output = self.moe(attn_output) @@ -677,18 +673,15 @@ class DbrxModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -699,10 +692,9 @@ class DbrxModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -732,11 +724,9 @@ class FlashDbrxForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -745,10 +735,9 @@ class FlashDbrxForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 88c2cf803..9d61c6941 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -33,21 +33,14 @@ from text_generation_server.layers.attention import ( Seqlen, attention, paged_attention, - reshape_and_cache, + HPUPagedAttentionMetadata, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.weights import Weights -if SYSTEM == "rocm": - try: - from vllm import _custom_C - except Exception as e: - raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") - class DeepseekV2Config(PretrainedConfig): def __init__( @@ -232,6 +225,8 @@ class DeepseekV2Attention(torch.nn.Module): ), ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") + self.kv_a_layernorm = FastRMSNorm.load( prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps ) @@ -260,11 +255,10 @@ class DeepseekV2Attention(torch.nn.Module): cos: torch.Tensor, sin: torch.Tensor, cu_seqlen_prefill: torch.Tensor, - kv_cache: Tuple[torch.Tensor, torch.Tensor], - block_tables: torch.Tensor, + kv_cache: KVCache, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): if self.q_lora_rank is None: query = self.q_proj(hidden_states) @@ -321,30 +315,35 @@ class DeepseekV2Attention(torch.nn.Module): value, (0, self.head_pad_size - self.value_head_size), value=0 ) - reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=key, + value=value, + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else key, - kv_cache[1] if PREFILL_IN_KV_CACHE else value, - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=key, + value=value, + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) # Remove padding. @@ -387,27 +386,11 @@ class DeepseekV2MLP(nn.Module): self.quantize = config.quantize def forward(self, hidden_states: torch.Tensor, reduce: bool = True): - if ( - SYSTEM == "rocm" - and self.hidden_act == "silu" - and hidden_states.dtype == torch.float16 - and hidden_states.shape[0] == 1 - and not self.quantize - ): - out = torch.empty( - hidden_states.shape[0], - self.intermediate_size, - dtype=hidden_states.dtype, - device="cuda", - ) - _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8) - return self.down_proj(out, reduce=reduce) - else: - gate_up_states = self.gate_up_proj(hidden_states) - gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) - return self.down_proj( - self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce - ) + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce + ) class DeepseekV2MoE(nn.Module): @@ -520,10 +503,9 @@ class DeepseekV2Layer(nn.Module): sin: torch.Tensor, cu_seqlen_prefill: torch.Tensor, kv_cache, - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): normed_hidden_states, residual = self.input_layernorm(hidden_states, residual) @@ -534,10 +516,9 @@ class DeepseekV2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) # faster post attention rms norm @@ -583,18 +564,15 @@ class DeepseekV2Model(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -605,10 +583,9 @@ class DeepseekV2Model(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -635,11 +612,9 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -648,10 +623,9 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py new file mode 100644 index 000000000..1a7ce5cf5 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py @@ -0,0 +1,642 @@ +# coding=utf-8 +# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple, Type + +import torch +import torch.distributed +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig + +from text_generation_server.layers import ( + FastLinear, + SpeculativeHead, + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + get_linear, +) +from text_generation_server.layers.attention import ( + Seqlen, + attention, + paged_attention, + HPUPagedAttentionMetadata, +) +from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales +from text_generation_server.layers.layernorm import FastRMSNorm +from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer +from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale +from text_generation_server.utils.weights import Weights + + +class DeepseekV3Config(PretrainedConfig): + def __init__( + self, + vocab_size=102400, + hidden_size=4096, + intermediate_size=11008, + moe_intermediate_size=1407, + num_hidden_layers=30, + num_attention_heads=32, + num_key_value_heads=32, + n_shared_experts=2, + n_routed_experts=160, + ep_size=1, + routed_scaling_factor=1.0, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + topk_method="gready", + n_group=8, + topk_group=3, + num_experts_per_tok=6, + moe_layer_freq=1, + first_k_dense_replace=0, + norm_topk_prob=False, + scoring_func="softmax", + aux_loss_alpha=0.001, + seq_aux=True, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=100000, + eos_token_id=100001, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + self.aux_loss_alpha = aux_loss_alpha + self.seq_aux = seq_aux + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + tie_word_embeddings = kwargs.pop("tie_word_embeddings", False) + if tie_word_embeddings: + raise ValueError( + "tie_word_embeddings is not supported for Deepseek V2 models." + ) + + if ep_size != 1: + raise ValueError( + f"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}" + ) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class DeepseekV3Attention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights: Weights, + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.kv_lora_rank = config.kv_lora_rank + self.q_lora_rank = config.q_lora_rank + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_rope_head_dim = config.qk_rope_head_dim + self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim + self.value_head_size = config.v_head_dim + self.head_pad_size = max(self.head_size, self.value_head_size) + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.qk_rope_head_dim, + base=config.rope_theta, + device=weights.device, + ) + + mscale = get_mscale( + self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim + ) + self.softmax_scale = self.head_size**-0.5 * mscale * mscale + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + if self.q_lora_rank is None: + self.q_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.q_proj", + weights=weights, + bias=config.attention_bias, + ) + else: + self.q_a_proj = get_linear( + weight=weights.get_weights(f"{prefix}.q_a_proj"), + bias=( + weights.get_tensor(f"{prefix}.q_a_proj.bias") + if config.attention_bias + else None + ), + ) + self.q_a_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.q_a_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.q_b_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.q_b_proj", + weights=weights, + bias=config.attention_bias, + ) + + self.kv_a_proj_with_mqa = get_linear( + weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"), + bias=( + weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias") + if config.attention_bias + else None + ), + ) + + self.kv_scales = get_kv_scales(weights, f"{prefix}") + + self.kv_a_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps + ) + + self.kv_b_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.kv_b_proj", + weights=weights, + bias=config.attention_bias, + ) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + cu_seqlen_prefill: torch.Tensor, + kv_cache: KVCache, + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + ): + if self.q_lora_rank is None: + query = self.q_proj(hidden_states) + else: + query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0]) + query = query.view(-1, self.num_heads, self.head_size) + + _, query_pe = torch.split( + query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, key_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + + key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim) + kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view( + -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size + ) + + key_nope, value = torch.split( + kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1 + ) + + batch_size, heads, head_dim = query_pe.shape + query_pe = ( + query_pe.view(batch_size, heads, head_dim // 2, 2) + .transpose(2, 3) + .reshape(batch_size, heads, head_dim) + ) + batch_size, heads, head_dim = key_pe.shape + key_pe = ( + key_pe.view(batch_size, heads, head_dim // 2, 2) + .transpose(2, 3) + .reshape(batch_size, heads, head_dim) + ) + self.rotary_emb(query_pe, key_pe, cos, sin) + + query[..., self.qk_nope_head_dim :] = query_pe + key = torch.empty_like(query) + key[..., : self.qk_nope_head_dim] = key_nope + key[..., self.qk_nope_head_dim :] = key_pe + + # We need to pad the heads because Flash Attention does not support + # qk and v with different head sizes. + query = torch.nn.functional.pad( + query, (0, self.head_pad_size - self.head_size), value=0 + ) + key = torch.nn.functional.pad( + key, (0, self.head_pad_size - self.head_size), value=0 + ) + value = torch.nn.functional.pad( + value, (0, self.head_pad_size - self.value_head_size), value=0 + ) + + kv_cache.store( + key=key, + value=value, + slots=slots, + kv_scales=self.kv_scales, + ) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attn_output = attention( + query=query, + key=key, + value=value, + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, + ) + # Decode + else: + attn_output = paged_attention( + query, + kv_cache, + self.kv_head_mapping, + self.softmax_scale, + seqlen, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, + ) + + # Remove padding. + attn_output = attn_output[..., : self.value_head_size] + + return self.o_proj( + attn_output.reshape(-1, self.num_heads * self.value_head_size) + ) + + +class DeepseekV3MLP(nn.Module): + def __init__(self, prefix: str, config, weights, intermediate_size: int): + super().__init__() + self.hidden_act = config.hidden_act + if self.hidden_act != "silu": + # Bail out because MoE only supports silu. + raise NotImplementedError( + "Currently only `silu` is supported as an activation for Deepseek V2." + ) + self.act = ACT2FN[self.hidden_act] + + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + + self.intermediate_size = intermediate_size // weights.process_group.size() + + # TODO: This is a hotfix to be removed & properly refactored. + self.quantize = config.quantize + + def forward(self, hidden_states: torch.Tensor, reduce: bool = True): + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce + ) + + +class DeepseekV3MoE(nn.Module): + def __init__( + self, + prefix, + config: DeepseekV3Config, + moe_layer_cls: Type[MoELayer], + weights, + ): + super().__init__() + + self.hidden_dim = config.hidden_size + self.moe_intermediate_size = ( + config.moe_intermediate_size // weights.process_group.size() + ) + self.routed_scaling_factor = config.routed_scaling_factor + + # Gating + self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False) + + if config.topk_method == "noaux_tc": + self.gate.e_score_correction_bias = torch.zeros( + config.n_routed_experts, device=weights.device + ) + else: + self.gate.e_score_correction_bias = None + + self.moe_layer = moe_layer_cls( + prefix=f"{prefix}.experts", + n_experts=config.n_routed_experts, + n_expert_group=config.n_group, + renormalize=config.norm_topk_prob, + topk=config.num_experts_per_tok, + topk_group=config.topk_group, + weights=weights, + scoring_func=config.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias, + ) + assert isinstance(self.moe_layer, MoELayer) + + if config.n_shared_experts is not None: + self.shared_experts = DeepseekV3MLP( + prefix=f"{prefix}.shared_experts", + config=config, + weights=weights, + intermediate_size=config.moe_intermediate_size + * config.n_shared_experts, + ) + else: + self.shared_experts = None + + self.process_group = weights.process_group + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.shared_experts is not None: + shared_output = self.shared_experts(x, reduce=False) + else: + shared_output = None + + router_logits = self.gate(x) + + out = self.moe_layer(x, gating_output=router_logits) + + if shared_output is not None: + out = out + shared_output + + # Reduce sum + if self.process_group.size() > 1: + torch.distributed.all_reduce(out, group=self.process_group) + + return out.view(*x.shape) + + +class DeepseekV3Layer(nn.Module): + def __init__(self, prefix, layer_id, config, weights): + super().__init__() + prefix = f"{prefix}.layers.{layer_id}" + + self.self_attn = DeepseekV3Attention( + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + ) + + if ( + config.n_routed_experts is not None + and layer_id >= config.first_k_dense_replace + and layer_id % config.moe_layer_freq == 0 + ): + moe_layer_cls = ( + SparseMoELayer + if SparseMoELayer.is_supported(weights) + else DenseMoELayer + ) + self.mlp = DeepseekV3MoE(f"{prefix}.mlp", config, moe_layer_cls, weights) + else: + self.mlp = DeepseekV3MLP( + prefix=f"{prefix}.mlp", + config=config, + weights=weights, + intermediate_size=config.intermediate_size, + ) + + self.input_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + cu_seqlen_prefill: torch.Tensor, + kv_cache, + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + ): + normed_hidden_states, residual = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + hpu_attention_meta, + ) + + # faster post attention rms norm + normed_attn_res_output, residual = self.post_attention_layernorm( + attn_output, residual + ) + + output = self.mlp(normed_attn_res_output) + + return output, residual + + +class DeepseekV3Model(torch.nn.Module): + def __init__(self, prefix: str, config, weights: Weights): + super().__init__() + + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens", weights=weights + ) + + self.layers = nn.ModuleList( + [ + DeepseekV3Layer( + prefix, + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = FastRMSNorm.load( + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps + ) + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + slots, + seqlen, + hpu_attention_meta, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashDeepseekV3ForCausalLM(torch.nn.Module): + def __init__(self, prefix: str, config, weights: Weights): + super().__init__() + + self.model = DeepseekV3Model( + "model" if not prefix else f"{prefix}.model", config, weights + ) + self.lm_head = SpeculativeHead.load( + config, + prefix="lm_head" if not prefix else f"{prefix}.lm_head", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + hpu_attention_meta, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 7a3d60c97..79f21b0f3 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -28,8 +28,8 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -40,7 +40,7 @@ from text_generation_server.layers import ( TensorParallelMultiAdapterLinear, TensorParallelAdapterRowLinear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, @@ -208,6 +208,7 @@ class FlashGemma2Attention(torch.nn.Module): ], process_group=weights.process_group, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") o_proj = TensorParallelRowLinear.load( config, @@ -234,11 +235,10 @@ class FlashGemma2Attention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, adapter_data, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states, adapter_data) query, kv = qkv.split( @@ -253,19 +253,24 @@ class FlashGemma2Attention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], - seqlen, - block_tables, - self.softmax_scale, - causal=self.causal, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, window_size_left=self.window_size, softcap=self.softcap, ) @@ -273,14 +278,13 @@ class FlashGemma2Attention(torch.nn.Module): else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, softcap=self.softcap, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj( @@ -390,11 +394,10 @@ class FlashGemma2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, adapter_data, + hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -405,11 +408,10 @@ class FlashGemma2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, adapter_data, + hpu_attention_meta, ) # faster post attention rms norm @@ -458,19 +460,16 @@ class FlashGemma2Model(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - adapter_data: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = inputs_embeds # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -481,11 +480,10 @@ class FlashGemma2Model(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, adapter_data, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -529,11 +527,9 @@ class FlashGemma2ForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -543,11 +539,10 @@ class FlashGemma2ForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, adapter_data, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 4c1be6f60..609f03acc 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -28,9 +28,8 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, - PREFILL_IN_KV_CACHE, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -39,6 +38,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, @@ -187,6 +187,7 @@ class FlashGemmaAttention(torch.nn.Module): ) self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = TensorParallelRowLinear.load( config, @@ -206,10 +207,9 @@ class FlashGemmaAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) query, kv = qkv.split( @@ -224,31 +224,36 @@ class FlashGemmaAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, causal=self.causal, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -317,10 +322,9 @@ class FlashGemmaLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -331,10 +335,9 @@ class FlashGemmaLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) # faster post attention rms norm @@ -379,18 +382,16 @@ class FlashGemmaModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + adapter_data: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = inputs_embeds # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -401,10 +402,9 @@ class FlashGemmaModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -446,11 +446,9 @@ class FlashGemmaForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -460,10 +458,10 @@ class FlashGemmaForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + adapter_data, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 44c015cf4..10024a6de 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -24,12 +24,11 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -38,6 +37,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales def load_qkv(config, prefix: str, weights, head_size, num_heads): @@ -47,10 +47,6 @@ def load_qkv(config, prefix: str, weights, head_size, num_heads): prefix, weights, ) - elif config.quantize == "marlin": - raise RuntimeError( - "GPT-2 models with marlin quantization are not yet supported" - ) else: return _load_qkv(config, prefix, weights, head_size, num_heads) @@ -195,6 +191,7 @@ class FlashGPT2Attention(torch.nn.Module): head_size=self.head_size, num_heads=self.num_heads, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = load_row( config, @@ -212,10 +209,9 @@ class FlashGPT2Attention(torch.nn.Module): hidden_states, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): query, key, value = self.query_key_value(hidden_states).split( self.head_size * self.num_heads, dim=1 @@ -224,30 +220,35 @@ class FlashGPT2Attention(torch.nn.Module): key = key.view(-1, self.num_heads, self.head_size) value = value.view(-1, self.num_heads, self.head_size) - reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=key, + value=value, + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else key, - kv_cache[1] if PREFILL_IN_KV_CACHE else value, - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=key, + value=value, + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -313,10 +314,9 @@ class FlashGPT2Layer(nn.Module): residual, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -326,10 +326,9 @@ class FlashGPT2Layer(nn.Module): hidden_states, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) hidden_states = attn_output + residual @@ -379,12 +378,9 @@ class FlashGPT2Model(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - true_max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = inputs_embeds @@ -395,10 +391,9 @@ class FlashGPT2Model(torch.nn.Module): residual, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) hidden_states = self.norm(hidden_states) @@ -432,11 +427,9 @@ class FlashGPT2ForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor] = None, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -448,12 +441,9 @@ class FlashGPT2ForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - true_max_s=max_s, - prefill_cache_indices=prefill_cache_indices, + hpu_attention_meta=hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index aca970044..41eeab78c 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -24,12 +24,12 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -38,13 +38,16 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) from text_generation_server.layers.layernorm import ( FastLayerNorm, ) +from habana_frameworks.torch.hpex.kernels import ( + RotaryPosEmbeddingMode, + apply_rotary_pos_emb, +) def load_attention(config, prefix: str, weights): @@ -78,39 +81,25 @@ class GPTJRotary(PositionRotaryEmbedding): cos: torch.Tensor, sin: torch.Tensor, ): - # Such controlflows may add some overhead. - if SYSTEM == "cuda": - import rotary_emb + num_tokens = query.shape[0] + head_size = query.shape[-1] + rope_mode = RotaryPosEmbeddingMode.PAIRWISE + sin = torch.repeat_interleave(sin, 2, dim=-1) + cos = torch.repeat_interleave(cos, 2, dim=-1) + rotary_dim = cos.shape[-1] + query_shape = query.shape + query = query.view(num_tokens, -1, head_size) + query_rot = query[..., :rotary_dim] + query_pass = query[..., rotary_dim:] + query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode) + query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)) - q1 = query[..., ::2] - q2 = query[..., 1::2] - - rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) - - k1 = key[..., ::2] - k2 = key[..., 1::2] - - rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) - elif SYSTEM == "rocm": - from vllm._C import ops - - # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. - # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773 - - head_size = query.shape[-1] - - # Inplace operation, updating query and key. - ops.rotary_embedding(query, key, head_size, cos, sin, False) - elif SYSTEM == "ipex": - import intel_extension_for_pytorch as ipex - - ipex.llm.functional.rotary_embedding( - query, key, sin, cos, query.size(-1), False - ) - else: - raise ValueError( - "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." - ) + key_shape = key.shape + key = key.view(num_tokens, -1, head_size) + key_rot = key[..., :rotary_dim] + key_pass = key[..., rotary_dim:] + key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode) + key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)) class FlashGPTJAttention(torch.nn.Module): @@ -140,6 +129,7 @@ class FlashGPTJAttention(torch.nn.Module): prefix=prefix, weights=weights, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = load_row( config, @@ -166,10 +156,9 @@ class FlashGPTJAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): query, key, value = self.query_key_value(hidden_states).split( self.head_size * self.num_heads, dim=1 @@ -186,30 +175,35 @@ class FlashGPTJAttention(torch.nn.Module): else: self.rotary_emb(query, key, cos, sin) - reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=key, + value=value, + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else key, - kv_cache[1] if PREFILL_IN_KV_CACHE else value, - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=key, + value=value, + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -267,10 +261,9 @@ class FlashGPTJLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): hidden_states, residual = self.input_layernorm(hidden_states, residual) # Self Attention @@ -280,10 +273,9 @@ class FlashGPTJLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) feed_forward_hidden_states = self.mlp(hidden_states) @@ -327,19 +319,15 @@ class FlashGPTJModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.wte(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -350,10 +338,9 @@ class FlashGPTJModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) hidden_states, _ = self.ln_f(hidden_states, residual) @@ -381,11 +368,9 @@ class FlashGPTJForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor] = None, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -394,11 +379,9 @@ class FlashGPTJForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - prefill_cache_indices=prefill_cache_indices, + hpu_attention_meta=hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index c9ec70cca..81af55603 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -27,14 +27,16 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention import ( + KVCache, + get_kv_scales, +) from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -57,15 +59,6 @@ from text_generation_server.utils.weights import ( ) from text_generation_server.layers.fp8 import HybridFP8UnquantLoader -if SYSTEM != "ipex": - pass - -if SYSTEM == "rocm": - try: - from vllm import _custom_C - except Exception as e: - raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") - def load_attention(config, prefix: str, weights, layer_id): # Only defined in granite. @@ -157,7 +150,10 @@ class FlashLlamaAttention(torch.nn.Module): device=weights.device, ) - self.softmax_scale = self.head_size**-0.5 + # `config.attention_multiplier` is used in Granite + self.softmax_scale = getattr( + config, "attention_multiplier", self.head_size**-0.5 + ) if self.num_heads % weights.process_group.size() != 0: raise ValueError( @@ -177,11 +173,13 @@ class FlashLlamaAttention(torch.nn.Module): self.query_key_value = load_attention(config, prefix, weights, index) self.index = index + self.kv_scales = get_kv_scales(weights, f"{prefix}") + o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o_proj", weights=weights, - bias=False, + bias=getattr(config, "attention_bias", False), ) self.o_proj = TensorParallelAdapterRowLinear.load( @@ -202,12 +200,11 @@ class FlashLlamaAttention(torch.nn.Module): cos, sin, cu_seqlen_prefill, - kv_cache, - block_tables, + kv_cache: KVCache, slots, seqlen, - max_s, adapter_data, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): qkv = self.query_key_value(hidden_states, adapter_data) query, kv = qkv.split( @@ -222,30 +219,35 @@ class FlashLlamaAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_scales=self.kv_scales, + kv_cache=kv_cache, + seqlen=seqlen, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj( @@ -363,31 +365,11 @@ class LlamaMLP(nn.Module): self.hidden_size = config.hidden_size def forward(self, hidden_states, adapter_data): - if ( - SYSTEM == "rocm" - and self.hidden_act == "silu" - and hidden_states.dtype == torch.float16 - and hidden_states.shape[0] == 1 - and not self.quantize - and self.hidden_size - != 16384 # TODO: Temporary workaround for `LLMM_Silu` kernel not working with LLama3.1 405B; needs refactoring once fixed. - ): - out = torch.empty( - hidden_states.shape[0], - self.intermediate_size, - dtype=hidden_states.dtype, - device="cuda", - ) - _custom_C.LLMM_Silu( - self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8 - ) - return self.down_proj(out, adapter_data) - else: - gate_up_states = self.gate_up_proj(hidden_states, adapter_data) - gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) - return self.down_proj( - self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data - ) + gate_up_states = self.gate_up_proj(hidden_states, adapter_data) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data + ) class FlashLlamaLayer(nn.Module): @@ -408,7 +390,7 @@ class FlashLlamaLayer(nn.Module): if SparseMoELayer.is_supported(weights) else DenseMoELayer ) - self.dense = Phi3MoE( + self.mlp = Phi3MoE( f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights ) # with moe the layernorms are are not rmsnorms and they have bias @@ -423,7 +405,7 @@ class FlashLlamaLayer(nn.Module): eps=config.rms_norm_eps, ) else: - self.dense = LlamaMLP( + self.mlp = LlamaMLP( prefix=f"{prefix}.mlp", config=config, weights=weights, index=index ) self.input_layernorm = FastRMSNorm.load( @@ -437,6 +419,11 @@ class FlashLlamaLayer(nn.Module): eps=config.rms_norm_eps, ) + # Used in Granite + # This could eventually be baked into the weights like we do for the embeddings/lm_head + # but this would mean modifying the lora code + self.residual_multiplier = getattr(config, "residual_multiplier", None) + def forward( self, hidden_states, @@ -445,12 +432,11 @@ class FlashLlamaLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, adapter_data, cross_attention_states, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -461,19 +447,21 @@ class FlashLlamaLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, adapter_data, + hpu_attention_meta=hpu_attention_meta, ) + if self.residual_multiplier is not None: + attn_output *= self.residual_multiplier - # faster post attention rms norm normed_attn_res_output, attn_res = self.post_attention_layernorm( attn_output, res ) - mlp_output = self.dense(normed_attn_res_output, adapter_data) + mlp_output = self.mlp(normed_attn_res_output, adapter_data) + if self.residual_multiplier is not None: + mlp_output *= self.residual_multiplier return mlp_output, attn_res @@ -493,9 +481,7 @@ class FlashLlamaModel(torch.nn.Module): self.layers.append( FlashLlamaLayer( index=0, - prefix=( - "model.layers.0" if not prefix else f"{prefix}.model.layers.0" - ), + prefix=f"{prefix}.layers.0", config=config, weights=weights, ) @@ -504,18 +490,14 @@ class FlashLlamaModel(torch.nn.Module): # Skip first and last layers for layer_id in range(1, config.num_hidden_layers - 1): if layer_id in self.cross_attention_layers: - from text_generation_server.models.custom_modeling.mllama import ( + from text_generation_server.models.custom_modeling.flash_mllama import ( FlashLlamaCrossLayer, ) self.layers.append( FlashLlamaCrossLayer( index=layer_id, - prefix=( - f"model.layers.{layer_id}" - if not prefix - else f"{prefix}.model.layers.{layer_id}" - ), + prefix=(f"{prefix}.layers.{layer_id}"), config=config, weights=weights, ) @@ -524,11 +506,7 @@ class FlashLlamaModel(torch.nn.Module): self.layers.append( FlashLlamaLayer( index=layer_id, - prefix=( - f"model.layers.{layer_id}" - if not prefix - else f"{prefix}.model.layers.{layer_id}" - ), + prefix=(f"{prefix}.layers.{layer_id}"), config=config, weights=weights, ) @@ -539,18 +517,14 @@ class FlashLlamaModel(torch.nn.Module): self.layers.append( FlashLlamaLayer( index=last_layer_id, - prefix=( - f"model.layers.{last_layer_id}" - if not prefix - else f"{prefix}.model.layers.{last_layer_id}" - ), + prefix=(f"{prefix}.layers.{last_layer_id}"), config=config, weights=weights, ) ) self.norm = FastRMSNorm.load( - prefix="model.norm" if not prefix else f"{prefix}.model.norm", + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps, ) @@ -567,22 +541,17 @@ class FlashLlamaModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - true_max_s: int, - prefill_cache_indices: Optional[torch.Tensor], adapter_data, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], cross_attention_states=None, ) -> torch.Tensor: hidden_states = inputs_embeds # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -593,12 +562,11 @@ class FlashLlamaModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, adapter_data, cross_attention_states, + hpu_attention_meta=hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -607,42 +575,60 @@ class FlashLlamaModel(torch.nn.Module): class FlashLlamaForCausalLM(torch.nn.Module): - def __init__(self, prefix: str, config, weights): + def __init__(self, prefix: str, config, weights, name=None): + if name is None: + name = "model" super().__init__() - with no_fp8(weights): self.embed_tokens = TensorParallelEmbedding( prefix=( - "model.embed_tokens" + f"{name}.embed_tokens" if not prefix - else f"{prefix}.model.embed_tokens" + else f"{prefix}.{name}.embed_tokens" ), weights=weights, ) - self.model = FlashLlamaModel(prefix, config, weights) + self.model = FlashLlamaModel( + prefix=name if not prefix else f"{prefix}.{name}", + config=config, + weights=weights, + ) if config.tie_word_embeddings: suffix = "model.embed_tokens" else: suffix = "lm_head" + # Used in Granite + embedding_multiplier = getattr(config, "embedding_multiplier", None) + if embedding_multiplier is not None: + self.embed_tokens.weight.data *= embedding_multiplier + prefix = suffix if not prefix or name != "model" else f"{prefix}.{suffix}" with no_fp8(weights): self.lm_head = SpeculativeHead.load( config, - prefix=suffix if not prefix else f"{prefix}.{suffix}", - weights=weights, + prefix, + weights, ) + # Used in Granite + self.logits_scaling = getattr(config, "logits_scaling", None) + if self.logits_scaling is not None and self.lm_head.head is not None: + try: + # Scale the weights directly + self.lm_head.head.linear.weight.data /= self.logits_scaling + self.logits_scaled = True + except Exception: + self.logits_scaled = False + def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor] = None, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, cross_attention_states=None, @@ -653,16 +639,20 @@ class FlashLlamaForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - true_max_s=max_s, - prefill_cache_indices=prefill_cache_indices, adapter_data=adapter_data, cross_attention_states=cross_attention_states, + hpu_attention_meta=hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits, speculative_logits = self.lm_head(hidden_states) + + # Used in Granite + if self.logits_scaling is not None and not self.logits_scaled: + logits /= self.logits_scaling + if speculative_logits is not None: + speculative_logits /= self.logits_scaling + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py new file mode 100644 index 000000000..88548042d --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py @@ -0,0 +1,285 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Llava-NeXT model.""" + +from typing import List, Optional, Tuple + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +from transformers.image_processing_utils import select_best_resolution + +from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata +from text_generation_server.models.custom_modeling.vlm import ( + load_text_model, + load_vision_model, +) +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelRowLinear, +) + + +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + """ + Calculate the shape of the image patch grid after the preprocessing for images of any resolution. + + Args: + image_size (`tuple`): + The size of the input image in the format (height, width). + grid_pinpoints (`List`): + A list containing possible resolutions. Each item in the list should be a tuple or list + of the form `(height, width)`. + patch_size (`int`): + The size of each image patch. + + Returns: + tuple: The shape of the image patch grid in the format (height, width). + """ + if not isinstance(grid_pinpoints, list): + raise ValueError("grid_pinpoints should be a list of tuples or lists") + + height, width = select_best_resolution(image_size, grid_pinpoints) + return height // patch_size, width // patch_size + + +def unpad_image(tensor, original_size): + """ + Unpads a PyTorch tensor of a padded and resized image. + + Args: + tensor (`torch.Tensor`): + The image tensor, assumed to be of shape (num_channels, height, width). + original_size (`tuple`): + The original size of the image (height, width). + + Returns: + `torch.Tensor`: The unpadded image tensor. + """ + original_height, original_width = original_size + current_height, current_width = tensor.shape[1:] + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding : current_height - padding, :] + else: + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding : current_width - padding] + + return unpadded_tensor + + +# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext +class LlavaNextMultiModalProjector(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + + self.linear_1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True + ) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class FlashLlavaNextForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + config.vision_config.quantize = config.quantize + vision_config = config.vision_config + # Instead of selecting in hidden_states[-2]. + # Instead compute only the n -2 + 1 layers and don't pool + if config.vision_feature_layer < 0: + vision_config.num_hidden_layers += config.vision_feature_layer + 1 + else: + vision_config.num_hidden_layers = config.vision_feature_layer + 1 + self.vision_tower = load_vision_model( + prefix="vision_tower" if not prefix else f"{prefix}.vision_tower", + config=config.vision_config, + weights=weights, + ) + + self.multi_modal_projector = LlavaNextMultiModalProjector( + prefix="multi_modal_projector", config=config, weights=weights + ) + + self.image_newline = weights.get_tensor("image_newline") + + self.vocab_size = config.text_config.vocab_size + self.config = config + config.text_config.quantize = config.quantize + config.text_config.speculator = config.speculator + self.text_model = load_text_model( + prefix="language_model" if not prefix else f"{prefix}.language_model", + config=config.text_config, + weights=weights, + ) + self.pad_token_id = ( + config.pad_token_id if config.pad_token_id is not None else -1 + ) + + def _merge_input_ids_with_image_features( + self, + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + image_features: torch.Tensor, + ): + """In place merges in vision_embeddings with inputs_embeds.""" + mask = torch.where(input_ids == self.config.image_token_index) + # Let's pray we have enabled enough slots ! + try: + inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) + except Exception as e: + raise RuntimeError( + f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}" + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + lm_head_indices: Optional[torch.Tensor] = None, + pixel_values: torch.FloatTensor = None, + # Unused for this model + pixel_attention_mask=None, + image_sizes: Optional[torch.LongTensor] = None, + adapter_data: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + if pixel_values is not None and len(pixel_values) > 0: + # num_special_image_tokens = (input_ids == self.config.image_token_index).sum() + # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid" + # 1. Extract the input embeddings + + # 2. Merge text and images + num_images, num_patches, channels, height, width = pixel_values.shape + pixel_values = pixel_values.view( + num_images * num_patches, channels, height, width + ) + image_features = self.vision_tower(pixel_values) + + # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer] + # Already done within the clip model + selected_image_feature = image_features.last_hidden_state + + if self.config.vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif self.config.vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise RuntimeError( + f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid." + ) + + image_features = self.multi_modal_projector(selected_image_feature) + + # split up image_features for each of the individual images + # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) + # if we assume each image has 5 image features (base image + 4 patches) + split_sizes = [num_patches] * num_images + image_features = torch.split(image_features, split_sizes, dim=0) + + # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" + height = width = ( + self.config.vision_config.image_size + // self.config.vision_config.patch_size + ) + + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + + if height * width != base_image_feature.shape[0]: + raise ValueError( + "The number of patches is not consistent with the image size." + ) + + # Dimensions are intentionally swapped to be bug-compatible with + # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59 + num_patch_width, num_patch_height = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + image_feature = image_feature.view( + num_patch_height, num_patch_width, height, width, -1 + ) + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + image_feature = torch.cat( + ( + image_feature, + self.image_newline[:, None, None].expand( + *image_feature.shape[:-1], 1 + ), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat( + (base_image_feature, image_feature), dim=0 + ) + else: + image_feature = image_feature[0] + image_feature = torch.cat( + (image_feature, self.image_newline[None]), dim=0 + ) + new_image_features.append(image_feature) + image_features = torch.stack(new_image_features, dim=0) + + inputs_embeds = self._merge_input_ids_with_image_features( + input_ids, inputs_embeds, image_features + ) + + hidden_states = self.text_model.model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + slots=slots, + seqlen=seqlen, + hpu_attention_meta=hpu_attention_meta, + adapter_data=adapter_data, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.text_model.lm_head(hidden_states) + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 341a23524..d23d4f679 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -26,12 +26,12 @@ from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -41,20 +41,12 @@ from text_generation_server.layers import ( TensorParallelMultiAdapterLinear, TensorParallelAdapterRowLinear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, ) -if SYSTEM == "rocm": - try: - from vllm import _custom_C - except Exception as e: - raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") - - class MistralConfig(PretrainedConfig): model_type = "mistral" @@ -160,6 +152,7 @@ class MistralAttention(torch.nn.Module): ], process_group=weights.process_group, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") o_proj = TensorParallelRowLinear.load( config, @@ -185,12 +178,10 @@ class MistralAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, adapter_data, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states, adapter_data) query, kv = qkv.split( @@ -205,38 +196,36 @@ class MistralAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv - - reshape_and_cache( - kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, ) # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, window_size_left=self.max_past, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj( @@ -300,29 +289,11 @@ class MistralMLP(nn.Module): self.quantize = config.quantize def forward(self, hidden_states, adapter_data): - if ( - SYSTEM == "rocm" - and self.hidden_act == "silu" - and hidden_states.dtype == torch.float16 - and hidden_states.shape[0] == 1 - and not self.quantize - ): - out = torch.empty( - hidden_states.shape[0], - self.intermediate_size, - dtype=hidden_states.dtype, - device="cuda", - ) - _custom_C.LLMM_Silu( - self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8 - ) - return self.down_proj(out, adapter_data) - else: - gate_up_states = self.gate_up_proj(hidden_states, adapter_data) - gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) - return self.down_proj( - self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data - ) + gate_up_states = self.gate_up_proj(hidden_states, adapter_data) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data + ) class MistralLayer(nn.Module): @@ -355,12 +326,10 @@ class MistralLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, adapter_data, + hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -371,12 +340,10 @@ class MistralLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, adapter_data, + hpu_attention_meta, ) # faster post attention rms norm @@ -423,20 +390,15 @@ class MistralModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - true_max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], adapter_data: Optional[torch.Tensor] = None, ): hidden_states = inputs_embeds # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, true_max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -447,12 +409,10 @@ class MistralModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, adapter_data, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -498,35 +458,21 @@ class FlashMistralForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - true_max_s = max_s - if prefill_cache_indices is not None: - # Slots also need to be sliced as it has the same size as the whole kv tensor - slots = slots[prefill_cache_indices] - elif self.max_past is not None: - # Clamp in decode mode as paged attention requires clamped values whereas the flash attention - # kernel requires the true values - seqlen = seqlen.clamp(max=self.max_past_tensor) - inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( inputs_embeds, position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - true_max_s, - prefill_cache_indices, + hpu_attention_meta, adapter_data, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 5836d30af..1ef6be481 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -37,9 +37,9 @@ from text_generation_server.layers.attention import ( Seqlen, attention, paged_attention, - reshape_and_cache, + HPUPagedAttentionMetadata, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.rotary import PositionRotaryEmbedding @@ -215,6 +215,7 @@ class MixtralAttention(torch.nn.Module): ) self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = TensorParallelRowLinear.load( config, @@ -234,11 +235,9 @@ class MixtralAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) query, kv = qkv.split( @@ -253,38 +252,36 @@ class MixtralAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv - - reshape_and_cache( - kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, ) # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, window_size_left=self.max_past, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -378,11 +375,9 @@ class MixtralLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, + hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -393,11 +388,9 @@ class MixtralLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, + hpu_attention_meta, ) # faster post attention rms norm @@ -448,20 +441,15 @@ class MixtralModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - true_max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, true_max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -472,11 +460,9 @@ class MixtralModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -507,34 +493,21 @@ class FlashMixtralForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - true_max_s = max_s - if prefill_cache_indices is not None: - # Slots also need to be sliced as it has the same size as the whole kv tensor - slots = slots[prefill_cache_indices] - elif self.max_past is not None: - # Clamp in decode mode as paged attention requires clamped values whereas the flash attention - # kernel requires the true values - seqlen = seqlen.clamp(max=self.max_past_tensor) hidden_states = self.model( input_ids, position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - true_max_s, - prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py new file mode 100644 index 000000000..216642e08 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py @@ -0,0 +1,986 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Mllama model.""" + +from typing import Optional, Tuple, List + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +import torch.nn.functional as F + +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + FastLinear, +) +from text_generation_server.layers.attention import ( + Seqlen, + HPUPagedAttentionMetadata, +) +from text_generation_server.models.custom_modeling.flash_llama_modeling import ( + FlashLlamaForCausalLM, +) +from habana_frameworks.torch.hpex.kernels import FusedSDPA +from vllm_hpu_extension.utils import ModuleFusedSDPA + + +def _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask: torch.Tensor, + num_patches: int, + target_length: int, + dtype: torch.dtype, +) -> torch.Tensor: + # Expand aspect ratio mask to target_length + batch_size, max_num_tiles = aspect_ratio_mask.shape + attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype) + attention_mask = attention_mask.repeat(1, 1, target_length, 1) + + # Mask padding patches + pad_patches = target_length - num_patches + attention_mask[:, :, -pad_patches:] = 0 + + # Invert the mask (0 -> 1, 1 -> 0) + attention_mask = 1 - attention_mask + + # Reshape to 2D and create 4D attention mask + # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length) + attention_mask = attention_mask.reshape( + batch_size, max_num_tiles * target_length, 1 + ) + attention_mask = ( + attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min + ) + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + +# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device, + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange( + target_length, device=device + ) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = ( + causal_mask.clone() + ) # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[ + :, :, :, :mask_length + ].masked_fill(padding_mask, min_dtype) + + return causal_mask + + +def _prepare_cross_attention_mask( + cross_attention_mask: torch.Tensor, + num_vision_tokens: int, + dtype: str, +) -> Tuple[torch.Tensor, torch.Tensor]: + # reshape so it can be used by attn module + batch_size, text_total_length, *_ = cross_attention_mask.shape + cross_attention_mask = cross_attention_mask.repeat_interleave( + num_vision_tokens, dim=3 + ) + cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1) + cross_attention_mask = cross_attention_mask.unsqueeze(1) + + # invert the mask + inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype) + cross_attention_mask = inverted_cross_attn_mask.masked_fill( + inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min + ) + + # apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's + # last dimension contains negative infinity values, otherwise it's 1 + negative_inf_value = torch.finfo(dtype).min + full_text_row_masked_out_mask = ( + (cross_attention_mask != negative_inf_value) + .any(dim=-1) + .type_as(cross_attention_mask)[..., None] + ) + cross_attention_mask *= full_text_row_masked_out_mask + + return cross_attention_mask, full_text_row_masked_out_mask + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision +class MllamaVisionMLP(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class MllamaVisionSdpaAttention(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + + self.embed_dim = config.hidden_size + self.head_dim = config.hidden_size // config.attention_heads + self.num_heads = config.attention_heads // weights.process_group.size() + + self.qkv_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qkv = self.qkv_proj(hidden_state) + query, key, value = qkv.split( + [ + self.head_dim * self.num_heads, + self.head_dim * self.num_heads, + self.head_dim * self.num_heads, + ], + dim=2, + ) + + batch_size, q_seq_len, _ = query.shape + _, kv_seq_len, _ = key.shape + + query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim) + key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim) + value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + attn_output = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_seq_len, -1) + + output = self.o_proj(attn_output) + return output + + +class MllamaVisionEncoderLayer(nn.Module): + def __init__(self, *, prefix, config, weights, is_gated: bool): + super().__init__() + + self.hidden_size = config.hidden_size + self.num_attention_heads = config.attention_heads + self.is_gated = is_gated + self.intermediate_size = config.intermediate_size + + self.self_attn = MllamaVisionSdpaAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.mlp = MllamaVisionMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights + ) + + self.input_layernorm = nn.LayerNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=1e-05 + ) + self.post_attention_layernorm = nn.LayerNorm.load( + prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=1e-05 + ) + + # there used to be an if else here, no code path + if is_gated: + self.gate_attn = nn.Parameter( + weights.get_tensor(f"{prefix}.gate_attn"), requires_grad=False + ) + self.gate_ffn = nn.Parameter( + weights.get_tensor(f"{prefix}.gate_ffn"), requires_grad=False + ) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + # Self Attention + residual = hidden_state + hidden_state = self.input_layernorm(hidden_state) + hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask) + gate_attn = 1 if not self.is_gated else self.gate_attn.tanh() + hidden_state = residual + gate_attn * hidden_state + + # Feed forward + residual = hidden_state + hidden_state = self.post_attention_layernorm(hidden_state) + hidden_state = self.mlp(hidden_state) + gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh() + hidden_state = residual + gate_ffn * hidden_state + return hidden_state + + +class MllamaVisionEncoder(nn.Module): + def __init__(self, *, prefix, config, weights, is_gated: bool, num_layers: int): + super().__init__() + self.config = config + self.layers = [ + MllamaVisionEncoderLayer( + prefix=f"{prefix}.layers.{i}", + config=config, + weights=weights, + is_gated=is_gated, + ) + for i in range(num_layers) + ] + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + encoder_states = [hidden_states] + for encoder_layer in self.layers: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + ) + + hidden_states = layer_outputs + encoder_states.append(hidden_states) + + return hidden_states, encoder_states + + +class MllamaPrecomputedAspectRatioEmbedding(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.max_aspect_ratio_id = config.max_aspect_ratio_id + + self.embedding = TensorParallelEmbedding( + prefix=f"{prefix}.embedding", weights=weights + ) + self.gate = nn.Parameter( + weights.get_tensor(f"{prefix}.gate"), requires_grad=False + ) + + def forward( + self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor + ) -> torch.Tensor: + embeddings = self.embedding(aspect_ratio_ids) + embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size) + + # Always gated. + embeddings = embeddings * self.gate.tanh() + + hidden_state = hidden_state + embeddings + return hidden_state + + +class MllamaPrecomputedPositionEmbedding(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.num_patches = (config.image_size // config.patch_size) ** 2 + 1 + self.hidden_size = config.hidden_size + self.scale = config.hidden_size**-0.5 + + self.gate = nn.Parameter( + weights.get_tensor(f"{prefix}.gate"), requires_grad=False + ) + + # position embedding + embedding = nn.Parameter( + weights.get_tensor(f"{prefix}.embedding"), requires_grad=False + ) + self.gated_position_embedding = (1 - self.gate.tanh()) * embedding + self.tile_embedding = TensorParallelEmbedding( + prefix=f"{prefix}.tile_embedding", weights=weights + ) + + def forward( + self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor + ) -> torch.Tensor: + # position embeddings + hidden_state = hidden_state + self.gated_position_embedding.view( + 1, 1, self.num_patches, self.hidden_size + ) + + # precomputed tile position embeddings + tile_position_embedding = self.tile_embedding(aspect_ratio_ids) + batch_size = hidden_state.shape[0] + tile_position_embedding = tile_position_embedding.reshape( + batch_size, self.max_num_tiles, self.num_patches, self.hidden_size + ) + gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding + hidden_state = hidden_state + gated_tile_position_embedding + + return hidden_state + + +class MllamaVisionModel(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.image_size = config.image_size + self.patch_size = config.patch_size + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.num_channels = config.num_channels + self.intermediate_layers_indices = config.intermediate_layers_indices + + self.num_patches = (self.image_size // self.patch_size) ** 2 + 1 + self.scale = config.hidden_size**-0.5 + self.dtype = weights.dtype + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.hidden_size, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + bias=False, + ) + self.patch_embedding.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False + ) + + self.class_embedding = nn.Parameter( + weights.get_tensor(f"{prefix}.class_embedding"), requires_grad=False + ) + + self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding( + prefix=f"{prefix}.gated_positional_embedding", + config=config, + weights=weights, + ) + + self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( + prefix=f"{prefix}.pre_tile_positional_embedding", + config=config, + weights=weights, + ) + self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( + prefix=f"{prefix}.post_tile_positional_embedding", + config=config, + weights=weights, + ) + + ## layer norms + self.layernorm_pre = nn.LayerNorm.load( + prefix=f"{prefix}.layernorm_pre", + weights=weights, + # torch default + eps=1e-05, + ) + self.layernorm_post = nn.LayerNorm.load( + prefix=f"{prefix}.layernorm_post", + weights=weights, + # torch default + eps=1e-05, + ) + + ## encoders + self.transformer = MllamaVisionEncoder( + prefix=f"{prefix}.transformer", + config=config, + weights=weights, + is_gated=False, + num_layers=config.num_hidden_layers, + ) + self.global_transformer = MllamaVisionEncoder( + prefix=f"{prefix}.global_transformer", + config=config, + weights=weights, + is_gated=True, + num_layers=config.num_global_layers, + ) + + def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: + batch_size, _, hidden_size = hidden_state.shape + class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size) + hidden_state = torch.cat([class_embedding, hidden_state], dim=1) + return hidden_state + + def forward( + self, + pixel_values: torch.Tensor, + aspect_ratio_ids: torch.Tensor, + attention_mask: torch.Tensor, + ) -> torch.Tensor: + ( + batch_size, + num_concurrent_media, + num_tiles, + num_channels, + height, + width, + ) = pixel_values.shape + + pixel_values = pixel_values.reshape( + batch_size * num_concurrent_media * num_tiles, num_channels, height, width + ) + aspect_ratio_ids = aspect_ratio_ids.reshape( + batch_size * num_concurrent_media, -1 + ) + + # patch embedding + patch_embeds = self.patch_embedding(pixel_values) + hidden_state = patch_embeds.flatten(2).transpose(1, 2) + + # tile embeddings + _, num_patches, dim = hidden_state.shape + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, -1, dim + ) + hidden_state = self.pre_tile_positional_embedding( + hidden_state, aspect_ratio_ids + ) + + # apply cls token + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media * num_tiles, num_patches, dim + ) + hidden_state = self.apply_class_embedding(hidden_state) + num_patches += 1 + + # apply position embeddings + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches, dim + ) + hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids) + + # apply encoder + hidden_state = self.layernorm_pre(hidden_state) + + # Compute the number of tokens to pad + num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 + # Compute padding tuple for pad function + padding = ( + 0, + 0, + 0, + num_padding_patches, + ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) + # Pad the tensor + hidden_state = F.pad(hidden_state, padding, mode="constant", value=0) + slice_index = -num_padding_patches if num_padding_patches > 0 else None + + if attention_mask is not None: + attention_mask = attention_mask.reshape( + batch_size * num_concurrent_media, -1 + ) + attention_mask = _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask=attention_mask, + num_patches=self.num_patches, + target_length=hidden_state.shape[2], + dtype=self.dtype, + ) + + hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim) + hidden_state, all_intermediate_hidden_states = self.transformer( + hidden_state, + attention_mask=attention_mask, + ) + intermediate_hidden_states = [ + hidden_state + for idx, hidden_state in enumerate(all_intermediate_hidden_states) + if idx in self.intermediate_layers_indices + ] + intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1) + + # apply global encoder + hidden_state = self.layernorm_post(hidden_state) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim, + ) + hidden_state = self.post_tile_positional_embedding( + hidden_state, aspect_ratio_ids + ) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles * (num_patches + num_padding_patches), + dim, + ) + hidden_state, _ = self.global_transformer( + hidden_state, attention_mask=attention_mask + ) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim, + ) + hidden_state = hidden_state[:, :, :slice_index] + + # adding intermediate layer outputs + hidden_state = hidden_state.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, dim + ) + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + -1, + ) + intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index] + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, -1 + ) + hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1) + return hidden_state + + +class MllamaTextCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, *, prefix, config, weights, layer_idx): + super().__init__() + self.config = config + self.num_heads = self.config.num_attention_heads + self.num_key_value_heads = self.config.num_key_value_heads + self.dropout = config.dropout + self.hidden_size = config.hidden_size + self.head_size = config.hidden_size // self.num_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.layer_idx = layer_idx + + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + self.num_key_value_heads // weights.process_group.size() + ) + + self.q_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.q_proj", + weights=weights, + bias=False, + ) + self.k_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.k_proj", + weights=weights, + bias=False, + ) + self.v_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.v_proj", + weights=weights, + bias=False, + ) + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + + self.q_norm = MllamaTextRMSNorm.load( + prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps + ) + self.k_norm = MllamaTextRMSNorm.load( + prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps + ) + self.softmax_scale = self.head_size**-0.5 + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + # past_key_value=None, + # attention_mask: Optional[torch.Tensor] = None, + # cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + # hidden_states = hidden_states.unsqueeze(0) + # bsz, q_len, _ = hidden_states.size() + ( + cross_attention_states, + cu_seqlen_q, + cu_seqlen_k, + indices, + ) = cross_attention_states + bs = cu_seqlen_q.size(0) - 1 + query_states = self.q_proj(hidden_states) + query_states = query_states.view(bs, -1, self.num_heads, self.head_size) + query_states = self.q_norm(query_states) + + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + key_states = key_states.view(bs, -1, self.num_key_value_heads, self.head_size) + value_states = value_states.view( + bs, -1, self.num_key_value_heads, self.head_size + ) + key_states = self.k_norm(key_states) + + # key_states = key_states.repeat(1, self.num_key_value_groups, 1) + # value_states = value_states.repeat(1, self.num_key_value_groups, 1) + + causal = False + # logger.info( + # f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}" + # ) + # execute sdpa + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + fsdpa_op = ModuleFusedSDPA(FusedSDPA) + attn_output = fsdpa_op( + query_states, + key_states, + value_states, + attn_mask=None, + dropout_p=0.0, + is_causal=causal, + scale=None, + softmax_mode="None", + recompute_mode=None, + valid_sequence_lengths=None, + ) + attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous() + attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + return attn_output + + +# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText +class MllamaTextMLP(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + shape = x.shape + gate_up_states = self.gate_up_proj(x) + gate_up_states = gate_up_states.view(*shape[:-1], 2, self.intermediate_size) + result = self.down_proj( + self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1] + ) + return result + + +class FlashLlamaCrossLayer(torch.nn.Module): + """Cross-attention transformer block with tanh-gated attention and feedforward.""" + + def __init__(self, *, prefix, config, weights, index) -> None: + layer_idx = index + super().__init__() + self.cross_attn = MllamaTextCrossAttention( + prefix=f"{prefix}.cross_attn", + config=config, + weights=weights, + layer_idx=layer_idx, + ) + + self.input_layernorm = MllamaTextRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.cross_attn_attn_gate = torch.nn.Parameter( + weights.get_tensor(f"{prefix}.cross_attn_attn_gate"), requires_grad=False + ) + + self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + self.post_attention_layernorm = MllamaTextRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.cross_attn_mlp_gate = torch.nn.Parameter( + weights.get_tensor(f"{prefix}.cross_attn_mlp_gate"), requires_grad=False + ) + self.layer_idx = layer_idx + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + adapter_data, + cross_attention_states, # [ IB, ...] + hpu_attention_meta, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if cross_attention_states is None: + return hidden_states, residual + if residual is not None: + hidden_states += residual + + indices = cross_attention_states[-1] + out_hidden_states = hidden_states[:] + if len(indices) > 0: + assert max(indices) < hidden_states.shape[0] + hidden_states = hidden_states[indices] + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.cross_attn( + hidden_states=hidden_states, + # attention_mask=cross_attention_mask, + cross_attention_states=cross_attention_states, + ) + hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states + + out_hidden_states[indices] = hidden_states + hidden_states = out_hidden_states + + return hidden_states, None + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText +class MllamaTextRMSNorm(nn.Module): + def __init__(self, weight, eps): + super().__init__() + self.weight = weight + self.variance_epsilon = eps + + @classmethod + def load(cls, *, prefix, weights, eps): + weight = nn.Parameter( + weights.get_tensor(f"{prefix}.weight"), requires_grad=False + ) + return cls(weight=weight, eps=eps) + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class FlashMllamaForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + config.vision_config.quantize = None + config.vision_config.speculator = config.speculator + config.text_config.quantize = config.quantize + config.text_config.speculator = config.speculator + config.text_config._attn_implementation = "sdpa" + self.hidden_size = config.text_config.hidden_size + self.vision_model = MllamaVisionModel( + prefix="vision_model", config=config.vision_config, weights=weights + ) + self.multi_modal_projector = FastLinear.load( + prefix="multi_modal_projector", config=config, weights=weights, bias=True + ) + self.text_model = FlashLlamaForCausalLM( + prefix="language_model", config=config.text_config, weights=weights + ) + self.config = config + self.dtype = weights.dtype + self.device = weights.device + + def vision_forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask): + if aspect_ratio_ids is None: + raise ValueError( + "`aspect_ratio_ids` must be provided if `pixel_values` is provided" + ) + # logger.info(f"PIxel values {pixel_values.shape}") + batch_size = pixel_values.shape[0] + vision_states = self.vision_model( + pixel_values, aspect_ratio_ids, aspect_ratio_mask + ) + cross_attention_states = self.multi_modal_projector(vision_states).reshape( + -1, vision_states.shape[-2], self.hidden_size + ) + _, _, h = cross_attention_states.shape + cross_attention_states = cross_attention_states.view(batch_size, -1, h) + # logger.info(f"cross {cross_attention_states.shape}") + return cross_attention_states + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + lm_head_indices: Optional[torch.Tensor], + adapter_data: Optional[torch.Tensor] = None, + # XXX: Putting these as optional so that the cuda warmup calls can go through. + cross_attention_states: Optional[torch.Tensor] = None, + image_indices=None, + ): + if cross_attention_states is not None: + seqlen_q = len(image_indices) + n_images = cross_attention_states.shape[0] + seqlen_k = cross_attention_states.shape[1] + device = cross_attention_states.device + if cu_seqlen_prefill is not None: + offset = 0 + cu_q = [] + indices = [] + for index in image_indices: + cu_q.append(offset) + length = seqlen.input_lengths[index].item() + assert index < seqlen.cu_seqlen_q.shape[0] + input_ids_offset = seqlen.cu_seqlen_q[index] + indices.extend(range(input_ids_offset, input_ids_offset + length)) + offset += length + cu_q.append(offset) + cu_seqlen_q = torch.Tensor(cu_q).to(device=device, dtype=torch.int32) + + assert max(indices) < input_ids.shape[0] + + cu_seqlen_k = ( + torch.arange( + n_images + 1, + device=device, + dtype=torch.int32, + ) + * seqlen_k + ) + else: + cu_seqlen_q = torch.arange( + seqlen_q + 1, device=device, dtype=torch.int32 + ) + seqlen_k = cross_attention_states.shape[1] + n_images = cross_attention_states.shape[0] + cu_seqlen_k = ( + torch.arange( + n_images + 1, + device=device, + dtype=torch.int32, + ) + * seqlen_k + ) + indices = image_indices[:] + + cross_attention_states = ( + cross_attention_states, + cu_seqlen_q, + cu_seqlen_k, + indices, + ) + + outputs = self.text_model( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + slots=slots, + seqlen=seqlen, + hpu_attention_meta=hpu_attention_meta, + lm_head_indices=lm_head_indices, + adapter_data=adapter_data, + cross_attention_states=cross_attention_states, + ) + + return outputs diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index ad4e382fe..33f63333a 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -29,8 +29,8 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -39,7 +39,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.layernorm import ( FastLayerNorm, ) @@ -132,6 +132,7 @@ class FlashNeoxAttention(torch.nn.Module): head_size=self.head_size, hidden_size=self.hidden_size, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.dense = load_row( config, prefix=f"{prefix}.dense", weights=weights, bias=True ) @@ -146,10 +147,9 @@ class FlashNeoxAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, 3, self.num_heads, self.head_size) @@ -165,30 +165,35 @@ class FlashNeoxAttention(torch.nn.Module): qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1) qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1) - reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=qkv[:, 1], + value=qkv[:, 2], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - qkv[:, 0], - kv_cache[0] if PREFILL_IN_KV_CACHE else qkv[:, 1], - kv_cache[1] if PREFILL_IN_KV_CACHE else qkv[:, 2], - seqlen, - block_tables, - self.softmax_scale, + query=qkv[:, 0], + key=qkv[:, 1], + value=qkv[:, 2], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( qkv[:, 0], - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) @@ -255,10 +260,9 @@ class FlashNeoXLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): if self.use_parallel_residual: ln1_hidden_states, _ = self.input_layernorm(hidden_states) @@ -269,10 +273,9 @@ class FlashNeoXLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states) @@ -293,10 +296,9 @@ class FlashNeoXLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) hidden_states, residual = self.post_attention_layernorm( @@ -347,18 +349,15 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_in(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -369,10 +368,9 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) hidden_states, _ = self.final_layer_norm(hidden_states, residual) @@ -401,11 +399,9 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -414,10 +410,9 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index 0024f2bb9..4d31d5ddf 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -19,7 +19,7 @@ from torch import nn from typing import Optional, List, Tuple from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear -from text_generation_server.layers.attention import Seqlen +from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata from text_generation_server.models.custom_modeling.vlm import ( load_text_model, load_vision_model, @@ -69,22 +69,20 @@ class PaliGemmaForConditionalGeneration(nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor] = None, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, pixel_values: torch.FloatTensor = None, # Unused here pixel_attention_mask: Optional[torch.BoolTensor] = None, image_sizes: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: inputs_embeds = self.text_model.embed_tokens(input_ids) # TODO This is odd but apparently pali gemma position ids start at 1. if cu_seqlen_prefill is not None: - max_s += 1 position_ids += 1 if pixel_values is not None: @@ -106,10 +104,10 @@ class PaliGemmaForConditionalGeneration(nn.Module): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, - block_tables=block_tables, slots=slots, seqlen=seqlen, - max_s=max_s, + hpu_attention_meta=hpu_attention_meta, + adapter_data=adapter_data, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 2a0dc6066..0c7779124 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -9,8 +9,8 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -19,7 +19,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.layernorm import ( FastLayerNorm, ) @@ -90,7 +90,7 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize not in ["gptq", "awq", "marlin"]: + if config.quantize not in ["gptq", "awq"]: weight = weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.hidden_size // config.num_attention_heads @@ -139,6 +139,7 @@ class FlashPhiAttention(torch.nn.Module): ) self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") # in llama the dense layer is called "o_proj" and has bias=False self.dense = TensorParallelRowLinear.load( @@ -159,10 +160,9 @@ class FlashPhiAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): # Compute query, key, value and split qkv = self.query_key_value(hidden_states) @@ -188,29 +188,34 @@ class FlashPhiAttention(torch.nn.Module): ) # Reshape key and value and cache - reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_scales=self.kv_scales, + kv_cache=kv_cache, + seqlen=seqlen, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) @@ -274,10 +279,9 @@ class FlashPhiLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): hidden_states, res = self.input_layernorm(hidden_states, residual) # Self Attention @@ -287,10 +291,9 @@ class FlashPhiLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) hidden_states = self.resid_dropout(attn_output).add( @@ -339,18 +342,15 @@ class FlashPhiModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -361,10 +361,9 @@ class FlashPhiModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -394,11 +393,9 @@ class FlashPhiForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -407,10 +404,9 @@ class FlashPhiForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 02c788d3e..af4b404d0 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -8,8 +8,8 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -17,7 +17,7 @@ from text_generation_server.layers import ( TensorParallelEmbedding, SpeculativeHead, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, @@ -86,6 +86,8 @@ class Qwen2Attention(torch.nn.Module): self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") + self.o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o_proj", @@ -104,11 +106,9 @@ class Qwen2Attention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) query, kv = qkv.split( @@ -123,38 +123,36 @@ class Qwen2Attention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv - - reshape_and_cache( - kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, ) # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, window_size_left=self.max_past, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -223,13 +221,11 @@ class Qwen2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, + hpu_attention_meta, ): - normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + normed_hidden_states, residual = self.input_layernorm(hidden_states) # Self Attention attn_output = self.self_attn( @@ -238,21 +234,17 @@ class Qwen2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, + hpu_attention_meta, ) + hidden_states = attn_output + residual # faster post attention rms norm - normed_attn_res_output, attn_res = self.post_attention_layernorm( - attn_output, res - ) - - mlp_output = self.mlp(normed_attn_res_output) - - return mlp_output, attn_res + hidden_states, residual = self.post_attention_layernorm(hidden_states) + mlp_output = self.mlp(hidden_states) + hidden_states = mlp_output + residual + return hidden_states class Qwen2Model(torch.nn.Module): @@ -264,9 +256,6 @@ class Qwen2Model(torch.nn.Module): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - self.embed_tokens = TensorParallelEmbedding( - prefix=f"{prefix}.embed_tokens", weights=weights - ) self.layers = nn.ModuleList( [ Qwen2Layer( @@ -290,42 +279,35 @@ class Qwen2Model(torch.nn.Module): def forward( self, - input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - true_max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + hidden_states = inputs_embeds - # Get rotary cos and sin for this forward - # Avoid to index in each layer cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, true_max_s, hidden_states.dtype + position_ids, ) residual = None for i, layer in enumerate(self.layers): - hidden_states, residual = layer( + hidden_states = layer( hidden_states, residual, cos, sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, + hpu_attention_meta, ) - hidden_states, _ = self.norm(hidden_states, residual) + hidden_states, _ = self.norm(hidden_states) return hidden_states @@ -346,6 +328,12 @@ class Qwen2ForCausalLM(torch.nn.Module): prefix=f"{prefix}.{suffix}" if prefix else suffix, weights=weights, ) + + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens" if prefix else "model.embed_tokens", + weights=weights, + ) + self.max_past = config.sliding_window self.max_past_tensor = ( torch.tensor(config.sliding_window, device=weights.device) @@ -359,34 +347,23 @@ class Qwen2ForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor] = None, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - true_max_s = max_s - if prefill_cache_indices is not None: - # Slots also need to be sliced as it has the same size as the whole kv tensor - slots = slots[prefill_cache_indices] - elif self.max_past is not None: - # Clamp in decode mode as paged attention requires clamped values whereas the flash attention - # kernel requires the true values - seqlen = seqlen.clamp(max=self.max_past_tensor) + + inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( - input_ids, + inputs_embeds, position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - true_max_s, - prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 6671d85e2..141e13a63 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -12,14 +12,14 @@ from text_generation_server.layers import ( TensorParallelRowLinear, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.layernorm import FastLayerNorm from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.attention import ( attention, paged_attention, - reshape_and_cache, Seqlen, + HPUPagedAttentionMetadata, ) @@ -79,6 +79,7 @@ class RWConfig(PretrainedConfig): self.alibi = False self.rotary = True self.rope_theta = rope_theta + self.max_position_embeddings = 2048 self.vocab_size = vocab_size # Backward compatibility with n_embed kwarg @@ -160,6 +161,7 @@ class FlashRWAttention(torch.nn.Module): weights=weights, bias=config.bias, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.dense = load_row( config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias ) @@ -180,10 +182,9 @@ class FlashRWAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) @@ -200,30 +201,35 @@ class FlashRWAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) @@ -278,6 +284,7 @@ class FlashRWLargeAttention(torch.nn.Module): weights=weights, bias=config.bias, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.dense = load_row( config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias ) @@ -293,10 +300,9 @@ class FlashRWLargeAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size) @@ -312,36 +318,35 @@ class FlashRWLargeAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin) - reshape_and_cache( - kv[:, :, 0].contiguous(), - kv[:, :, 1].contiguous(), - kv_cache[0], - kv_cache[1], - slots, + kv_cache.store( + key=kv[:, :, 0].contiguous(), + value=kv[:, :, 1].contiguous(), + slots=slots, + kv_scales=self.kv_scales, ) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(), - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(), - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, :, 0], + value=kv[:, :, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.dense( @@ -424,10 +429,9 @@ class FlashRWLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): if self.parallel_attn: ln_hidden_states, residual = self.input_layernorm(hidden_states, residual) @@ -438,10 +442,9 @@ class FlashRWLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) mlp_output = self.mlp(ln_hidden_states) @@ -460,10 +463,9 @@ class FlashRWLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) if self.post_attention_layernorm is not None: @@ -547,10 +549,9 @@ class FlashRWLargeLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): # Layer norm. ln_attn, ln_mlp, residual = self.ln_layer(hidden_states, residual) @@ -562,10 +563,9 @@ class FlashRWLargeLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) # MLP. @@ -623,18 +623,15 @@ class FlashRWModel(FlashRWPreTrainedModel): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.word_embeddings(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.h): @@ -645,10 +642,9 @@ class FlashRWModel(FlashRWPreTrainedModel): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) hidden_states, _ = self.ln_f(hidden_states, residual) @@ -675,11 +671,9 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -688,10 +682,9 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 43eb9687f..b68f47840 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -8,8 +8,8 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -18,7 +18,7 @@ from text_generation_server.layers import ( TensorParallelEmbedding, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.gptq import GPTQWeightsLoader from text_generation_server.layers.layernorm import ( FastLayerNorm, @@ -32,10 +32,6 @@ def load_multi_mqa( return _load_multi_mqa_gptq( config, prefix, weights, bias, head_size, num_heads, hidden_size ) - elif config.quantize == "marlin": - raise RuntimeError( - "santacoder models with marlin quantization are not yet supported" - ) else: return _load_multi_mqa( config, prefix, weights, bias, head_size, num_heads, hidden_size @@ -259,6 +255,7 @@ class FlashMQAttention(torch.nn.Module): self.c_proj = load_row( config, prefix=f"{prefix}.c_proj", weights=weights, bias=True ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.kv_head_mapping = torch.zeros( self.num_heads, dtype=torch.int32, device=weights.device ) @@ -268,10 +265,9 @@ class FlashMQAttention(torch.nn.Module): hidden_states, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): qkv = self.c_attn(hidden_states) @@ -284,32 +280,35 @@ class FlashMQAttention(torch.nn.Module): query = query.view(-1, self.num_heads, self.head_size) key_value = key_value.view(-1, 2, 1, self.head_size) - reshape_and_cache( - key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots + kv_cache.store( + key=key_value[:, 0], + value=key_value[:, 1], + slots=slots, + kv_scales=self.kv_scales, ) # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else key_value[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else key_value[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=key_value[:, 0], + value=key_value[:, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -371,20 +370,18 @@ class Block(nn.Module): residual, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): hidden_states, residual = self.ln_1(hidden_states, residual) hidden_states = self.self_attn( hidden_states, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) hidden_states, residual = self.ln_2(hidden_states, residual) @@ -435,10 +432,9 @@ class FlashSantacoderModel(nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.wte(input_ids) + self.wpe(position_ids) @@ -452,10 +448,9 @@ class FlashSantacoderModel(nn.Module): residual, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) hidden_states, _ = self.ln_f(hidden_states, residual) @@ -484,11 +479,9 @@ class FlashSantacoderForCausalLM(nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -497,10 +490,9 @@ class FlashSantacoderForCausalLM(nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 4975cf225..76f6f473a 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -29,17 +29,19 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( + TensorParallelMultiAdapterLinear, + TensorParallelAdapterRowLinear, TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, SpeculativeHead, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.layernorm import ( FastLayerNorm, FastRMSNorm, @@ -110,17 +112,31 @@ class Starcoder2Config(PretrainedConfig): ) -def load_attention(config, prefix, weights): +def load_attention(config, prefix, weights, layer_id): + prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"] + head_size = config.hidden_size // config.num_attention_heads + sizes = [ + head_size * config.num_attention_heads, + head_size * config.num_key_value_heads, + head_size * config.num_key_value_heads, + ] if config.num_attention_heads != config.num_key_value_heads: - return _load_gqa(config, prefix, weights) + base_layer = _load_gqa(config, prefix, weights) else: - return TensorParallelColumnLinear.load_multi( + base_layer = TensorParallelColumnLinear.load_multi( config, - prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + prefixes=prefixes, dim=0, weights=weights, bias=config.use_bias, ) + return TensorParallelMultiAdapterLinear.load( + base_layer=base_layer, + layer_id=layer_id, + layer_names=prefixes, + sizes=sizes, + process_group=weights.process_group, + ) def _load_gqa(config, prefix: str, weights): @@ -158,6 +174,7 @@ def _load_gqa(config, prefix: str, weights): class Starcoder2Attention(torch.nn.Module): def __init__( self, + index: int, prefix: str, config, weights, @@ -189,14 +206,23 @@ class Starcoder2Attention(torch.nn.Module): config.num_key_value_heads // weights.process_group.size() ) - self.query_key_value = load_attention(config, prefix, weights) + self.query_key_value = load_attention(config, prefix, weights, index) + self.kv_scales = get_kv_scales(weights, f"{prefix}") - self.o_proj = TensorParallelRowLinear.load( + o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o_proj", weights=weights, - bias=config.use_bias, + bias=getattr(config, "use_bias", False), ) + + self.o_proj = TensorParallelAdapterRowLinear.load( + o_proj, + index, + "o_proj", + process_group=weights.process_group, + ) + self.num_groups = self.num_heads // self.num_key_value_heads self.kv_head_mapping = torch.arange( 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device @@ -209,13 +235,12 @@ class Starcoder2Attention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, + adapter_data, + hpu_attention_meta, ): - qkv = self.query_key_value(hidden_states) + qkv = self.query_key_value(hidden_states, adapter_data) query, kv = qkv.split( [ self.head_size * self.num_heads, @@ -228,45 +253,45 @@ class Starcoder2Attention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv - - reshape_and_cache( - kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, ) # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, window_size_left=self.max_past, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) - return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + return self.o_proj( + attn_output.view(-1, self.num_heads * self.head_size), adapter_data + ) class Starcoder2MLP(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix, config, weights, index): super().__init__() act = config.hidden_act self.act = ( @@ -280,27 +305,42 @@ class Starcoder2MLP(nn.Module): ) ) # Fuse gate and up proj - self.c_fc = TensorParallelColumnLinear.load( + c_fc = TensorParallelColumnLinear.load( config, prefix=f"{prefix}.c_fc", weights=weights, bias=config.use_bias, ) - self.c_proj = TensorParallelRowLinear.load( + c_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.c_proj", weights=weights, bias=config.use_bias, ) - def forward(self, hidden_states): - hidden_states = self.c_fc(hidden_states) + self.c_fc = TensorParallelMultiAdapterLinear.load( + c_fc, + layer_id=index, + layer_names=[f"{prefix}.c_fc"], + sizes=[config.intermediate_size, config.intermediate_size], + process_group=weights.process_group, + ) + + self.c_proj = TensorParallelAdapterRowLinear.load( + c_proj, + index, + "c_proj", + process_group=weights.process_group, + ) + + def forward(self, hidden_states, adapter_data): + hidden_states = self.c_fc(hidden_states, adapter_data) hidden_states = self.act(hidden_states) - return self.c_proj(hidden_states) + return self.c_proj(hidden_states, adapter_data) class Starcoder2GatedMLP(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, index, prefix, config, weights): super().__init__() act = config.hidden_act self.act = ( @@ -314,27 +354,47 @@ class Starcoder2GatedMLP(nn.Module): ) ) # Fuse gate and up proj - self.gate_up_proj = TensorParallelColumnLinear.load_multi( + prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"] + sizes = [ + config.intermediate_size, + config.intermediate_size, + ] + gate_up_proj = TensorParallelColumnLinear.load_multi( config, - prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + prefixes=prefixes, weights=weights, dim=0, bias=config.use_bias, ) - self.down_proj = TensorParallelRowLinear.load( + self.gate_up_proj = TensorParallelMultiAdapterLinear.load( + gate_up_proj, + index, + layer_names=prefixes, + sizes=sizes, + process_group=weights.process_group, + ) + down_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.down_proj", weights=weights, bias=config.use_bias, ) + self.down_proj = TensorParallelAdapterRowLinear.load( + down_proj, + index, + "down_proj", + process_group=weights.process_group, + ) self.intermediate_size = ( config.intermediate_size // weights.process_group.size() ) - def forward(self, hidden_states): - gate_up_states = self.gate_up_proj(hidden_states) + def forward(self, hidden_states, adapter_data): + gate_up_states = self.gate_up_proj(hidden_states, adapter_data) gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) - return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data + ) STARCODER2_NORMALIZATION_CLASSES = { @@ -353,11 +413,11 @@ class Starcoder2Layer(nn.Module): super().__init__() prefix = f"model.layers.{layer_id}" self.self_attn = Starcoder2Attention( - prefix=f"{prefix}.self_attn", config=config, weights=weights + prefix=f"{prefix}.self_attn", config=config, weights=weights, index=layer_id ) self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type]( - prefix=f"{prefix}.mlp", config=config, weights=weights + prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id ) self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load( @@ -379,11 +439,10 @@ class Starcoder2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, + adapter_data, + hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -394,11 +453,10 @@ class Starcoder2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, + adapter_data, + hpu_attention_meta, ) # faster post attention rms norm @@ -406,7 +464,7 @@ class Starcoder2Layer(nn.Module): attn_output, res ) - mlp_output = self.mlp(normed_attn_res_output) + mlp_output = self.mlp(normed_attn_res_output, adapter_data) return mlp_output, attn_res @@ -447,20 +505,16 @@ class Starcoder2Model(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - true_max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + adapter_data, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, true_max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -471,11 +525,10 @@ class Starcoder2Model(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, + adapter_data, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -519,34 +572,22 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - true_max_s = max_s - if prefill_cache_indices is not None: - # Slots also need to be sliced as it has the same size as the whole kv tensor - slots = slots[prefill_cache_indices] - elif self.max_past is not None: - # Clamp in decode mode as paged attention requires clamped values whereas the flash attention - # kernel requires the true values - seqlen = seqlen.clamp(max=self.max_past_tensor) hidden_states = self.model( input_ids, position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - true_max_s, - prefill_cache_indices, + adapter_data, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py index a829c3741..02806ac94 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py @@ -25,7 +25,7 @@ from transformers.activations import ACT2FN from text_generation_server.models.custom_modeling.vlm import ( load_text_model, ) -from text_generation_server.layers.attention import Seqlen +from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask from text_generation_server.layers import ( @@ -728,7 +728,8 @@ class Idefics2ForConditionalGeneration(nn.Module): ): """In place merges in vision_embeddings with inputs_embeds.""" # mask = input_ids == self.config.image_token_index - mask = input_ids == self.config.image_token_id + # - replace `==` with torch.where to fix the issue in hpu graph + mask = torch.where(input_ids == self.config.image_token_id) # Let's pray we have enabled enough slots ! inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) return inputs_embeds @@ -739,17 +740,16 @@ class Idefics2ForConditionalGeneration(nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, pixel_values: torch.FloatTensor = None, pixel_attention_mask: Optional[torch.BoolTensor] = None, # Unused here image_sizes: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, ): inputs_embeds = self.text_model.embed_tokens(input_ids) if pixel_values is not None: @@ -793,6 +793,7 @@ class Idefics2ForConditionalGeneration(nn.Module): ].contiguous() patch_size = self.config.vision_config.patch_size + """ patches_subgrid = pixel_attention_mask.unfold( dimension=1, size=patch_size, step=patch_size ) @@ -800,6 +801,21 @@ class Idefics2ForConditionalGeneration(nn.Module): dimension=2, size=patch_size, step=patch_size ) patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + """ + # hpu does none support unfold + conv_kernel = torch.ones( + [1, 1, patch_size, patch_size], + dtype=pixel_values.dtype, + device=pixel_values.device, + ) + patches_subgrid = torch.nn.functional.conv2d( + pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype), + conv_kernel, + stride=patch_size, + ).squeeze(1) + patch_attention_mask = torch.eq( + patches_subgrid, (patch_size * patch_size) + ) # Get sequence from the vision encoder image_hidden_states = self.vision_model( @@ -825,12 +841,9 @@ class Idefics2ForConditionalGeneration(nn.Module): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, - block_tables=block_tables, slots=slots, seqlen=seqlen, - max_s=max_s, - true_max_s=max_s, - prefill_cache_indices=None, + hpu_attention_meta=hpu_attention_meta, adapter_data=adapter_data, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py new file mode 100644 index 000000000..964526fcf --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py @@ -0,0 +1,596 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Idefics3 model.""" + +from typing import List, Optional, Tuple + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +from text_generation_server.models.custom_modeling.vlm import ( + load_text_model, +) +from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask + +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, +) +from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Idefics3VisionEmbeddings(nn.Module): + """ + This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable + resolution. + + The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304) + which allows treating images in their native aspect ratio and without the need to resize them to the same + fixed size. In particular, we start from the original pre-trained SigLIP model + (which uses images of fixed-size square images) and adapt it by training on images of variable resolutions. + """ + + def __init__(self, prefix, config, weights): + super().__init__() + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + self.patch_embedding.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False + ) + self.patch_embedding.bias = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False + ) + + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = TensorParallelEmbedding( + prefix=f"{prefix}.position_embedding", weights=weights + ) + + def forward( + self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor + ) -> torch.Tensor: + batch_size, _, max_im_h, max_im_w = pixel_values.shape + + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + max_nb_patches_h, max_nb_patches_w = ( + max_im_h // self.patch_size, + max_im_w // self.patch_size, + ) + boundaries = torch.arange( + 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side + ) + position_ids = torch.full( + size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0 + ) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + + bucket_coords_h = torch.bucketize( + fractional_coords_h, boundaries, right=True + ) + bucket_coords_w = torch.bucketize( + fractional_coords_w, boundaries, right=True + ) + + pos_ids = ( + bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w + ).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + + position_ids = position_ids.to(self.position_embedding.weight.device) + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +class Idefics3VisionAttention(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_size = self.embed_dim // self.num_heads + if self.head_size * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_size**-0.5 + self.dropout = config.attention_dropout + + self.num_heads = self.num_heads // weights.process_group.size() + self.embed_dim = self.embed_dim // weights.process_group.size() + + self.qkv = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=True, + ) + self.out_proj = TensorParallelRowLinear.load( + config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True + ) + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, q_len, _ = hidden_states.size() + + qkv = self.qkv(hidden_states) + query_states, key_states, value_states = qkv.split( + [ + self.head_size * self.num_heads, + self.head_size * self.num_heads, + self.head_size * self.num_heads, + ], + dim=2, + ) + + query_states = query_states.view( + batch_size, q_len, self.num_heads, self.head_size + ).transpose(1, 2) + key_states = key_states.view( + batch_size, q_len, self.num_heads, self.head_size + ).transpose(1, 2) + value_states = value_states.view( + batch_size, q_len, self.num_heads, self.head_size + ).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = ( + torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + ) + + if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): + raise ValueError( + f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output + + +class Idefics3VisionMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Idefics3EncoderLayer(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = Idefics3VisionAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.layer_norm1 = nn.LayerNorm.load( + prefix=f"{prefix}.layer_norm1", eps=config.layer_norm_eps, weights=weights + ) + self.layer_norm2 = nn.LayerNorm.load( + prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights + ) + self.mlp = Idefics3VisionMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights + ) + + # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Idefics3Encoder(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [ + Idefics3EncoderLayer( + prefix=f"{prefix}.layers.{i}", config=config, weights=weights + ) + for i in range(config.num_hidden_layers) + ] + ) + + # Ignore copy + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + ): + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer( + hidden_states, + attention_mask, + ) + return hidden_states + + +class Idefics3VisionTransformer(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.embeddings = Idefics3VisionEmbeddings( + prefix=f"{prefix}.embeddings", config=config, weights=weights + ) + self.encoder = Idefics3Encoder( + prefix=f"{prefix}.encoder", config=config, weights=weights + ) + self.post_layernorm = nn.LayerNorm.load( + prefix=f"{prefix}.post_layernorm", + weights=weights, + eps=config.layer_norm_eps, + ) + + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + ): + batch_size = pixel_values.size(0) + if patch_attention_mask is None: + patch_size = self.config.patch_size + patch_attention_mask = torch.ones( + ( + batch_size, + pixel_values.size(2) // patch_size, + pixel_values.size(3) // patch_size, + ) + ) + patch_attention_mask = patch_attention_mask.to( + dtype=torch.bool, device=pixel_values.device + ) + + hidden_states = self.embeddings( + pixel_values=pixel_values, patch_attention_mask=patch_attention_mask + ) + + patch_attention_mask = patch_attention_mask.view(batch_size, -1) + # The call to `_upad_input` in `_flash_attention_forward` is expensive + # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence), + # avoiding passing the attention_mask, which is equivalent to attending to the full sequence + if not torch.any(~patch_attention_mask): + patch_attention_mask = None + else: + patch_attention_mask = _prepare_4d_attention_mask( + patch_attention_mask, hidden_states.dtype + ) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=patch_attention_mask, + ) + + last_hidden_state = encoder_outputs + last_hidden_state = self.post_layernorm(last_hidden_state) + + return last_hidden_state + + +class Idefics3SimpleMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + input_size = config.vision_config.hidden_size * (config.scale_factor**2) + output_size = config.text_config.hidden_size + proj = nn.Parameter( + weights.get_tensor(f"{prefix}.modality_projection.proj.weight"), + requires_grad=False, + ).to(weights.dtype) + self.proj = nn.Linear(input_size, output_size, bias=False) + self.proj.weight = proj + + def forward(self, x): + return self.proj(x) + + +class Idefics3Connector(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.modality_projection = Idefics3SimpleMLP(prefix, config, weights) + self.scale_factor = config.scale_factor + + def pixel_shuffle(self, x, scale_factor=2): + bsz, seq, embed_dim = x.size() + height = width = int(seq**0.5) + x = x.view(bsz, height, width, embed_dim) + x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor) + x = x.permute(0, 2, 1, 3) + x = x.reshape( + bsz, + int(width / scale_factor), + int(height / scale_factor), + embed_dim * (scale_factor**2), + ) + x = x.permute(0, 2, 1, 3) + x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2)) + return x + + def forward(self, image_hidden_states): + image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor) + image_hidden_states = self.modality_projection(image_hidden_states) + return image_hidden_states + + +class Idefics3ForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + config.vision_config.quantize = None + config.vision_config.speculator = config.speculator + config.text_config.quantize = config.quantize + config.text_config.speculator = config.speculator + # set tie_word_embeddings to True to load `.embed_tokens.weight` instead of `.lm_head.weight` + # since Idefics3 uses the `embed_tokens` for the final prediction + # config.text_config.tie_word_embeddings = True + + vision_config = config.vision_config + self.text_model = load_text_model( + prefix="model" if not prefix else f"{prefix}.model", + config=config.text_config, + weights=weights, + name="text_model", + ) + self.dtype = weights.dtype + + # The vision and connector models are not quantized. + with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)): + self.vision_model = Idefics3VisionTransformer( + prefix=( + f"{prefix}.model.vision_model" if prefix else "model.vision_model" + ), + config=vision_config, + weights=weights, + ) + + config.quantize = None + self.connector = Idefics3Connector( + prefix=f"{prefix}.model.connector" if prefix else "model.connector", + config=config, + weights=weights, + ) + + self.config = config + self.image_token_id = config.image_token_id + self.pad_token_id = ( + config.pad_token_id if config.pad_token_id is not None else -1 + ) + + def _merge_input_ids_with_image_features( + self, + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + image_features: torch.Tensor, + ): + """In place merges in vision_embeddings with inputs_embeds.""" + # mask = input_ids == self.config.image_token_index + # - replace `==` with torch.where to fix the issue in hpu graph + mask = torch.where(input_ids == self.config.image_token_id) + # Let's pray we have enabled enough slots ! + inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + lm_head_indices: Optional[torch.Tensor] = None, + pixel_values: torch.FloatTensor = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + # Unused here + image_sizes: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + image_indices=None, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + if pixel_values is not None: + batch_size, num_images, num_channels, height, width = pixel_values.shape + all_states = [] + all_pixel_values = pixel_values + all_pixel_mask = pixel_attention_mask + for i in range(batch_size): + pixel_values = all_pixel_values.to( + dtype=self.dtype + ) # fp16 compatibility + pixel_values = pixel_values[i : i + 1] + pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. + nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum( + dim=(-1, -2, -3) + ) != nb_values_per_image + pixel_values = pixel_values[real_images_inds].contiguous() + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=( + pixel_values.size(0), + pixel_values.size(2), + pixel_values.size(3), + ), + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask/pP p + pixel_attention_mask = all_pixel_mask[i : i + 1] + pixel_attention_mask = pixel_attention_mask.view( + 1 * num_images, *pixel_attention_mask.shape[2:] + ) + pixel_attention_mask = pixel_attention_mask[ + real_images_inds + ].contiguous() + + patch_size = self.config.vision_config.patch_size + """ + patches_subgrid = pixel_attention_mask.unfold( + dimension=1, size=patch_size, step=patch_size + ) + patches_subgrid = patches_subgrid.unfold( + dimension=2, size=patch_size, step=patch_size + ) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + """ + # hpu does none support unfold + conv_kernel = torch.ones( + [1, 1, patch_size, patch_size], + dtype=pixel_values.dtype, + device=pixel_values.device, + ) + patches_subgrid = torch.nn.functional.conv2d( + pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype), + conv_kernel, + stride=patch_size, + ).squeeze(1) + patch_attention_mask = torch.eq( + patches_subgrid, (patch_size * patch_size) + ) + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + # Modality projection & resampling + image_hidden_states = self.connector( + image_hidden_states, + ) + + all_states.append(image_hidden_states) + image_hidden_states = torch.stack(all_states, dim=0) + + inputs_embeds = self._merge_input_ids_with_image_features( + input_ids, inputs_embeds, image_hidden_states + ) + + hidden_states = self.text_model.model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + slots=slots, + seqlen=seqlen, + hpu_attention_meta=hpu_attention_meta, + adapter_data=adapter_data, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.text_model.lm_head(hidden_states) + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_modeling.py index fc6becc4b..a130dbc12 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_modeling.py @@ -46,15 +46,9 @@ from text_generation_server.layers import ( FastLinear, ) from text_generation_server.layers.rotary import PositionRotaryEmbedding -from text_generation_server.utils.import_utils import SYSTEM from loguru import logger -if SYSTEM == "cuda": - import dropout_layer_norm -elif SYSTEM == "rocm": - from vllm._C import ops -else: - dropout_layer_norm = None +dropout_layer_norm = None @dataclass @@ -351,94 +345,18 @@ class IdeficsRMSNorm(nn.Module): self.variance_epsilon = eps def forward(self, hidden_states, residual=None): - if SYSTEM == "ipex": - import intel_extension_for_pytorch as ipex + from vllm_hpu_extension.kernels import rms_norm - out = ipex.llm.functional.add_rms_norm( - residual, - hidden_states, - self.weight, - None, - self.variance_epsilon, - residual is not None, - ) - return out - elif hidden_states.shape[-1] > 8192: - if residual is not None: - hidden_states += residual - residual = hidden_states - - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt( - variance + self.variance_epsilon - ) - - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - - return self.weight * hidden_states - elif SYSTEM == "cuda": - # faster post attention rms norm - unwrap = False - if len(hidden_states.shape) > 2: - unwrap = True - shape = hidden_states.shape - hidden_states = hidden_states.reshape(-1, shape[-1]) - - normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - residual, - self.weight, - None, - None, - None, - None, - None, - 0.0, - self.variance_epsilon, - 1.0, - 0, - None, - False, - True, # Activate RMSNorm - ) - if res is None: - res = hidden_states - - if unwrap: - normed_hidden_states = normed_hidden_states.view(*shape) - - return normed_hidden_states - elif SYSTEM == "rocm": - # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not. - if residual is not None: - hidden_states += residual - residual = hidden_states - - unwrap = False - if len(hidden_states.shape) > 2: - unwrap = True - shape = hidden_states.shape - hidden_states = hidden_states.reshape(-1, shape[-1]) - - out = torch.empty_like(hidden_states) - ops.rms_norm( - out, - hidden_states, - self.weight.data, - self.variance_epsilon, - ) - - if unwrap: - out = out.view(*shape) - - return out + orig_shape = hidden_states.shape + if residual is not None: + residual += hidden_states.view(residual.shape) else: - raise ValueError( - "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." - ) + residual = hidden_states + # Note: HPUFusedRMSNorm requires 3D tensors as inputs + if len(orig_shape) == 2: + residual = residual.unsqueeze(0) + x = rms_norm().apply(residual, self.weight, self.variance_epsilon) + return x.view(orig_shape), residual.view(orig_shape) # this was adapted from LlamaMLP diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/mamba_modeling.py index 293051c2b..5a9c05887 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/mamba_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/mamba_modeling.py @@ -196,7 +196,10 @@ class MambaModel(nn.Module): def __init__(self, config, weights): super().__init__() prefix = "backbone" - self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights) + try: + self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embeddings", weights) + except RuntimeError: + self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights) self.blocks = nn.ModuleList( [ ResidualBlock(f"{prefix}.layers.{i}", config, weights, layer_id=i) @@ -206,7 +209,10 @@ class MambaModel(nn.Module): self.norm_f = FastRMSNorm.load( f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon ) - self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights) + try: + self.lm_head = SpeculativeHead.load(config, f"{prefix}.embeddings", weights) + except RuntimeError: + self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights) self.config = config def forward( diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/mpt_modeling.py deleted file mode 100644 index 988a74a39..000000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/mpt_modeling.py +++ /dev/null @@ -1,1215 +0,0 @@ -"""A simple, flexible implementation of a GPT model. - -Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py -""" - -import math -import warnings -from typing import List, Optional, Tuple, Union -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from einops import rearrange -from packaging import version -from text_generation_server.layers import ( - TensorParallelEmbedding, - TensorParallelColumnLinear, - TensorParallelRowLinear, - SpeculativeHead, - get_linear, -) - -EPS = 1e-5 - - -def load_col(config, prefix, weights, bias): - assert config.quantize != "gptq", NotImplementedError - slice_ = weights._get_slice(f"{prefix}.weight") - rank = weights.process_group.rank() - size = weights.process_group.size() - - h3, h = slice_.get_shape() - block_size = h // size - - q_part = slice_[rank * block_size : (rank + 1) * block_size] - k_part = slice_[h + rank * block_size : h + (rank + 1) * block_size] - v_part = slice_[2 * h + rank * block_size : 2 * h + (rank + 1) * block_size] - - weight = torch.cat([q_part, k_part, v_part], dim=0) - if weight.dtype != torch.int32: - weight = weight.to(dtype=weights.dtype) - weight = weight.to(device=weights.device) - - if bias: - bias_slice_ = weights._get_slice(f"{prefix}.bias") - bias_rank = weights.process_group.rank() - bias_size = weights.process_group.size() - - bias_h = bias_slice_.get_shape() - bias_h = bias_h[0] - bias_block_size = bias_h // bias_size - - bias_q_part = bias_slice_[ - bias_rank * bias_block_size : (bias_rank + 1) * bias_block_size - ] - bias_k_part = bias_slice_[ - bias_h - + bias_rank * bias_block_size : bias_h - + (bias_rank + 1) * bias_block_size - ] - bias_v_part = bias_slice_[ - 2 * bias_h - + bias_rank * bias_block_size : 2 * bias_h - + (bias_rank + 1) * bias_block_size - ] - - bias = torch.cat([bias_q_part, bias_k_part, bias_v_part], dim=0) - if bias.dtype != torch.int32: - bias = bias.to(dtype=weights.dtype) - bias = bias.to(device=weights.device) - else: - bias = None - linear = get_linear(weight, bias) - return TensorParallelColumnLinear(linear) - - -def _reset_is_causal( - num_query_tokens: int, num_key_tokens: int, original_is_causal: bool -): - if original_is_causal and num_query_tokens != num_key_tokens: - if num_query_tokens != 1: - raise NotImplementedError( - "MPT does not support query and key with different number of tokens, unless number of query tokens is 1." - ) - else: - return False - return original_is_causal - - -def scaled_multihead_dot_product_attention( - query, - key, - value, - n_heads, - past_key_value=None, - softmax_scale=None, - attn_bias=None, - key_padding_mask=None, - is_causal=False, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, -): - q = rearrange(query, "b s (h d) -> b h s d", h=n_heads) - kv_n_heads = 1 if multiquery else n_heads - k = rearrange(key, "b s (h d) -> b h d s", h=kv_n_heads) - v = rearrange(value, "b s (h d) -> b h s d", h=kv_n_heads) - if past_key_value is not None: - if len(past_key_value) != 0: - k = torch.cat([past_key_value[0], k], dim=3) - v = torch.cat([past_key_value[1], v], dim=2) - past_key_value = (k, v) - (b, _, s_q, d) = q.shape - s_k = k.size(-1) - attn_weight = q.matmul(k) * softmax_scale - if attn_bias is not None: - _s_q = max(0, attn_bias.size(2) - s_q) - _s_k = max(0, attn_bias.size(3) - s_k) - attn_bias = attn_bias[:, :, _s_q:, _s_k:] - if ( - attn_bias.size(-1) != 1 - and attn_bias.size(-1) != s_k - or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q) - ): - raise RuntimeError( - f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}." - ) - attn_weight = attn_weight + attn_bias - min_val = torch.finfo(q.dtype).min - if key_padding_mask is not None: - if attn_bias is not None: - warnings.warn( - "Propogating key_padding_mask to the attention module " - + "and applying it within the attention module can cause " - + "unneccessary computation/memory usage. Consider integrating " - + "into attn_bias once and passing that to each attention " - + "module instead." - ) - attn_weight = attn_weight.masked_fill( - ~key_padding_mask.view((b, 1, 1, s_k)), min_val - ) - if is_causal and (not q.size(2) == 1): - s = max(s_q, s_k) - causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16) - causal_mask = causal_mask.tril() - causal_mask = causal_mask.to(torch.bool) - causal_mask = ~causal_mask - causal_mask = causal_mask[-s_q:, -s_k:] - attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val) - attn_weight = torch.softmax(attn_weight, dim=-1) - if dropout_p: - attn_weight = torch.nn.functional.dropout( - attn_weight, p=dropout_p, training=training, inplace=True - ) - out = attn_weight.to(v.dtype).matmul(v) - out = rearrange(out, "b h s d -> b s (h d)") - if needs_weights: - return (out, attn_weight, past_key_value) - return (out, None, past_key_value) - - -def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]): - for tensor in tensors: - if tensor.dtype not in valid_dtypes: - raise TypeError( - f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}." - ) - if not tensor.is_cuda: - raise TypeError( - f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r})." - ) - - -def flash_attn_fn( - query, - key, - value, - n_heads, - past_key_value=None, - softmax_scale=None, - attn_bias=None, - key_padding_mask=None, - is_causal=False, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, -): - try: - from flash_attn import bert_padding, flash_attn_interface - except Exception: - raise RuntimeError("Please install flash-attn==1.0.3.post0") - check_valid_inputs(query, key, value) - if past_key_value is not None: - if len(past_key_value) != 0: - key = torch.cat([past_key_value[0], key], dim=1) - value = torch.cat([past_key_value[1], value], dim=1) - past_key_value = (key, value) - if attn_bias is not None: - _s_q = max(0, attn_bias.size(2) - query.size(1)) - _s_k = max(0, attn_bias.size(3) - key.size(1)) - attn_bias = attn_bias[:, :, _s_q:, _s_k:] - if attn_bias is not None: - raise NotImplementedError("attn_bias not implemented for flash attn.") - (batch_size, seqlen) = query.shape[:2] - if key_padding_mask is None: - key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) - query_padding_mask = key_padding_mask[:, -query.size(1) :] - (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input( - query, query_padding_mask - ) - query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads) - (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input( - key, key_padding_mask - ) - key_unpad = rearrange( - key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads - ) - (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask) - value_unpad = rearrange( - value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads - ) - if multiquery: - key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1)) - value_unpad = value_unpad.expand( - value_unpad.size(0), n_heads, value_unpad.size(-1) - ) - dropout_p = dropout_p if training else 0.0 - reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) - output_unpad = flash_attn_interface.flash_attn_unpadded_func( - query_unpad, - key_unpad, - value_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale=softmax_scale, - causal=reset_is_causal, - return_attn_probs=needs_weights, - ) - output = bert_padding.pad_input( - rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen - ) - return (output, None, past_key_value) - - -def triton_flash_attn_fn( - query, - key, - value, - n_heads, - past_key_value=None, - softmax_scale=None, - attn_bias=None, - key_padding_mask=None, - is_causal=False, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, -): - try: - from .flash_attn_triton import flash_attn_func - except Exception: - _installed = False - if version.parse(torch.__version__) < version.parse("2.0.0"): - _installed = True - try: - from flash_attn.flash_attn_triton import flash_attn_func - except Exception: - _installed = False - if not _installed: - raise RuntimeError( - "Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed." - ) - check_valid_inputs(query, key, value) - if past_key_value is not None: - if len(past_key_value) != 0: - key = torch.cat([past_key_value[0], key], dim=1) - value = torch.cat([past_key_value[1], value], dim=1) - past_key_value = (key, value) - if attn_bias is not None: - _s_q = max(0, attn_bias.size(2) - query.size(1)) - _s_k = max(0, attn_bias.size(3) - key.size(1)) - attn_bias = attn_bias[:, :, _s_q:, _s_k:] - if dropout_p: - raise NotImplementedError("Dropout not implemented for attn_impl: triton.") - if needs_weights: - raise NotImplementedError("attn_impl: triton cannot return attn weights.") - if key_padding_mask is not None: - warnings.warn( - "Propagating key_padding_mask to the attention module " - + "and applying it within the attention module can cause " - + "unnecessary computation/memory usage. Consider integrating " - + "into attn_bias once and passing that to each attention " - + "module instead." - ) - (b_size, s_k) = key_padding_mask.shape[:2] - if attn_bias is None: - attn_bias = query.new_zeros(b_size, 1, 1, s_k) - attn_bias = attn_bias.masked_fill( - ~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min - ) - query = rearrange(query, "b s (h d) -> b s h d", h=n_heads) - key = rearrange(key, "b s (h d) -> b s h d", h=1 if multiquery else n_heads) - value = rearrange(value, "b s (h d) -> b s h d", h=1 if multiquery else n_heads) - if multiquery: - key = key.expand(*key.shape[:2], n_heads, key.size(-1)) - value = value.expand(*value.shape[:2], n_heads, value.size(-1)) - reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) - attn_output = flash_attn_func( - query, key, value, attn_bias, reset_is_causal, softmax_scale - ) - output = attn_output.view(*attn_output.shape[:2], -1) - return (output, None, past_key_value) - - -class MultiheadAttention(nn.Module): - """Multi-head self attention. - - Using torch or triton attention implementation enables user to also use - additive bias. - """ - - def __init__( - self, - config, - prefix, - weights, - ): - super().__init__() - attn_impl = config.attn_config.attn_impl - self.attn_impl = config.attn_config.attn_impl - self.clip_qkv = config.attn_config.clip_qkv - self.qk_ln = config.attn_config.qk_ln - self.d_model = config.d_model - d_model = config.d_model - self.n_heads = config.n_heads - self.softmax_scale = config.attn_config.softmax_scale - if self.softmax_scale is None: - self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) - self.attn_dropout_p = config.attn_config.attn_pdrop - - if self.n_heads % weights.process_group.size() != 0: - raise ValueError( - f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} " - f"and `num_shards`: {weights.process_group.size()}" - ) - self.n_heads = self.n_heads // weights.process_group.size() - self.Wqkv = load_col( - config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias - ) - if self.qk_ln: - bias = not config.no_bias - hidden_size = config.d_model - head_dim = hidden_size // self.n_heads - - self.q_ln = LPLayerNorm( - d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights - ) - self.k_ln = LPLayerNorm( - self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights - ) - if self.attn_impl == "flash": - self.attn_fn = flash_attn_fn - elif self.attn_impl == "triton": - self.attn_fn = triton_flash_attn_fn - elif self.attn_impl == "torch": - self.attn_fn = scaled_multihead_dot_product_attention - else: - raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") - self.out_proj = TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.out_proj", - weights=weights, - bias=not config.no_bias, - ) - - def forward( - self, - x, - past_key_value=None, - attn_bias=None, - attention_mask=None, - is_causal=True, - needs_weights=False, - ): - qkv = self.Wqkv(x) - if self.clip_qkv: - qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) - (query, key, value) = qkv.chunk(3, dim=2) - - key_padding_mask = attention_mask - if self.qk_ln: - dtype = query.dtype - query = self.q_ln(query).to(dtype) - key = self.k_ln(key).to(dtype) - (context, attn_weights, past_key_value) = self.attn_fn( - query, - key, - value, - self.n_heads, - past_key_value=past_key_value, - softmax_scale=self.softmax_scale, - attn_bias=attn_bias, - key_padding_mask=key_padding_mask, - is_causal=is_causal, - dropout_p=self.attn_dropout_p, - training=self.training, - needs_weights=needs_weights, - ) - out = self.out_proj(context) - return (out, attn_weights, past_key_value) - - -class MultiQueryAttention(nn.Module): - """Multi-Query self attention. - - Using torch or triton attention implementation enables user to also use - additive bias. - """ - - def __init__(self, config, prefix, weights, verbose=False): - super().__init__() - attn_impl = config.attn_config.attn_impl - self.attn_impl = config.attn_config.attn_impl - self.clip_qkv = config.attn_config.clip_qkv - self.qk_ln = config.attn_config.qk_ln - self.d_model = config.d_model - d_model = config.d_model - self.n_heads = config.n_heads - self.softmax_scale = config.attn_config.softmax_scale - if self.softmax_scale is None: - self.softmax_scale = 1 / math.sqrt(self.head_dim) - self.attn_dropout_p = config.attn_config.attn_pdrop - # self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device) - self.Wqkv = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias - ) - (d_model, d_model + self.head_dim) - if self.qk_ln: - raise NotImplementedError("qk_ln not supported") - if self.attn_impl == "flash": - self.attn_fn = flash_attn_fn - elif self.attn_impl == "triton": - self.attn_fn = triton_flash_attn_fn - if verbose: - warnings.warn( - "While `attn_impl: triton` can be faster than `attn_impl: flash` " - + "it uses more memory. When training larger models this can trigger " - + "alloc retries which hurts performance. If encountered, we recommend " - + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`." - ) - elif self.attn_impl == "torch": - self.attn_fn = scaled_multihead_dot_product_attention - if torch.cuda.is_available() and verbose: - warnings.warn( - "Using `attn_impl: torch`. If your model does not use `alibi` or " - + "`prefix_lm` we recommend using `attn_impl: flash` otherwise " - + "we recommend using `attn_impl: triton`." - ) - else: - raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") - self.out_proj = TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.out_proj", - weights=weights, - bias=not config.no_bias, - ) - # self.out_proj._is_residual = True - - def forward( - self, - x, - past_key_value=None, - attn_bias=None, - attention_mask=None, - is_causal=True, - needs_weights=False, - ): - qkv = self.Wqkv(x) - if self.clip_qkv: - qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) - (query, key, value) = qkv.split( - [self.d_model, self.head_dim, self.head_dim], dim=2 - ) - key_padding_mask = attention_mask - if self.qk_ln: - dtype = query.dtype - query = self.q_ln(query).to(dtype) - key = self.k_ln(key).to(dtype) - (context, attn_weights, past_key_value) = self.attn_fn( - query, - key, - value, - self.n_heads, - past_key_value=past_key_value, - softmax_scale=self.softmax_scale, - attn_bias=attn_bias, - key_padding_mask=key_padding_mask, - is_causal=is_causal, - dropout_p=self.attn_dropout_p, - training=self.training, - needs_weights=needs_weights, - multiquery=True, - ) - return (self.out_proj(context), attn_weights, past_key_value) - - -def attn_bias_shape( - attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id -): - if attn_impl == "flash": - return None - elif attn_impl in ["torch", "triton"]: - if alibi: - if (prefix_lm or not causal) or use_sequence_id: - return (1, n_heads, seq_len, seq_len) - return (1, n_heads, 1, seq_len) - elif prefix_lm or use_sequence_id: - return (1, 1, seq_len, seq_len) - return None - else: - raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") - - -def build_attn_bias( - attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8 -): - if attn_impl == "flash": - return None - elif attn_impl in ["torch", "triton"]: - if alibi: - (device, dtype) = (attn_bias.device, attn_bias.dtype) - attn_bias = attn_bias.add( - build_alibi_bias( - n_heads, - seq_len, - full=not causal, - alibi_bias_max=alibi_bias_max, - device=device, - dtype=dtype, - ) - ) - return attn_bias - else: - raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") - - -def gen_slopes(n_heads, alibi_bias_max=8, device=None): - _n_heads = 2 ** math.ceil(math.log2(n_heads)) - m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device) - m = m.mul(alibi_bias_max / _n_heads) - slopes = 1.0 / torch.pow(2, m) - if _n_heads != n_heads: - slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads] - return slopes.view(1, n_heads, 1, 1) - - -def build_alibi_bias( - n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None -): - alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view( - 1, 1, 1, seq_len - ) - if full: - alibi_bias = alibi_bias - torch.arange( - 1 - seq_len, 1, dtype=torch.int32, device=device - ).view(1, 1, seq_len, 1) - alibi_bias = alibi_bias.abs().mul(-1) - slopes = gen_slopes(n_heads, alibi_bias_max, device=device) - alibi_bias = alibi_bias * slopes - return alibi_bias.to(dtype=dtype) - - -ATTN_CLASS_REGISTRY = { - "multihead_attention": MultiheadAttention, - "multiquery_attention": MultiQueryAttention, -} - -"""GPT Blocks used for the GPT Model.""" - - -class MPTMLP(nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - # self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) - self.up_proj = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.up_proj", weights=weights, bias=not config.no_bias - ) - self.act = nn.GELU(approximate="none") - # self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) - self.down_proj = TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.down_proj", - weights=weights, - bias=not config.no_bias, - ) - # self.down_proj._is_residual = True - - def forward(self, x): - return self.down_proj(self.act(self.up_proj(x))) - - -class MPTBlock(nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - self.prefix = prefix - if config.attn_config.attn_type != "multihead_attention": - raise NotImplementedError( - f"""Not implemented attn {config.attn_config.attn_type}""" - ) - resid_pdrop = config.resid_pdrop - if config.no_bias: - self.norm_1 = nn.LayerNorm.load_no_bias( - prefix=f"{prefix}.norm_1", weights=weights, eps=EPS - ) - self.norm_2 = nn.LayerNorm.load_no_bias( - prefix=f"{prefix}.norm_2", weights=weights, eps=EPS - ) - else: - self.norm_1 = nn.LayerNorm.load( - prefix=f"{prefix}.norm_1", weights=weights, eps=EPS - ) - self.norm_2 = nn.LayerNorm.load( - prefix=f"{prefix}.norm_2", weights=weights, eps=EPS - ) - self.attn = MultiheadAttention(config, prefix=f"{prefix}.attn", weights=weights) - self.ffn = MPTMLP(config, prefix=f"{prefix}.ffn", weights=weights) - self.resid_attn_dropout = nn.Dropout(resid_pdrop) - self.resid_ffn_dropout = nn.Dropout(resid_pdrop) - - def forward( - self, - x: torch.Tensor, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attn_bias: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.ByteTensor] = None, - is_causal: bool = True, - ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: - a = self.norm_1(x) - (b, attn_weights, past_key_value) = self.attn( - a, - past_key_value=past_key_value, - attn_bias=attn_bias, - attention_mask=attention_mask, - is_causal=is_causal, - ) - x = x + self.resid_attn_dropout(b) - m = self.norm_2(x) - n = self.ffn(m) - x = x + self.resid_ffn_dropout(n) - return (x, attn_weights, past_key_value) - - -def _cast_if_autocast_enabled(tensor): - if torch.is_autocast_enabled(): - if tensor.device.type == "cuda": - dtype = torch.get_autocast_gpu_dtype() - elif tensor.device.type == "cpu": - dtype = torch.get_autocast_cpu_dtype() - else: - raise NotImplementedError() - return tensor.to(dtype=dtype) - return tensor - - -class LPLayerNorm(torch.nn.LayerNorm): - def __init__( - self, - normalized_shape, - eps=1e-05, - elementwise_affine=True, - device=None, - dtype=None, - bias: Optional[bool] = True, - prefix=None, - weights=None, - ): - super().__init__( - normalized_shape=normalized_shape, - eps=eps, - elementwise_affine=elementwise_affine, - device=device, - dtype=dtype, - bias=bias, - ) - if weights is not None: - self.weight = nn.Parameter(weights.get_sharded(f"{prefix}.weight", dim=0)) - if bias: - self.bias = nn.Parameter(weights.get_sharded(f"{prefix}.bias", dim=0)) - self.normalized_shape = self.weight.shape - - def forward(self, x): - module_device = x.device - downcast_x = _cast_if_autocast_enabled(x) - downcast_weight = ( - _cast_if_autocast_enabled(self.weight) - if self.weight is not None - else self.weight - ) - downcast_bias = ( - _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias - ) - with torch.autocast(enabled=False, device_type=module_device.type): - return torch.nn.functional.layer_norm( - downcast_x, - self.normalized_shape, - downcast_weight, - downcast_bias, - self.eps, - ) - - -def rms_norm(x, weight=None, eps=1e-05): - output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) - if weight is not None: - return output * weight - return output - - -class RMSNorm(torch.nn.Module): - def __init__( - self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None - ): - super().__init__() - self.eps = eps - if weight: - self.weight = torch.nn.Parameter( - torch.ones(normalized_shape, dtype=dtype, device=device) - ) - else: - self.register_parameter("weight", None) - - def forward(self, x): - return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) - - -class LPRMSNorm(RMSNorm): - def __init__( - self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None - ): - super().__init__( - normalized_shape=normalized_shape, - eps=eps, - weight=weight, - dtype=dtype, - device=device, - ) - - def forward(self, x): - downcast_x = _cast_if_autocast_enabled(x) - downcast_weight = ( - _cast_if_autocast_enabled(self.weight) - if self.weight is not None - else self.weight - ) - with torch.autocast(enabled=False, device_type=x.device.type): - return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) - - -NORM_CLASS_REGISTRY = { - "layernorm": torch.nn.LayerNorm, - "low_precision_layernorm": LPLayerNorm, - "rmsnorm": RMSNorm, - "low_precision_rmsnorm": LPRMSNorm, -} - -Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] - - -class MPTPreTrainedModel(PreTrainedModel): - base_model_prefix = "model" - _no_split_modules = ["MPTBlock"] - - -class MPTModel(MPTPreTrainedModel): - def __init__(self, prefix: str, config, weights): - # config._validate_config() - super().__init__(config) - self.world_size = weights.process_group.size() - self.rank = weights.process_group.rank() - self.n_heads = config.n_heads - self.attn_impl = config.attn_config.attn_impl - self.prefix_lm = config.attn_config.prefix_lm - self.attn_uses_sequence_id = config.attn_config.attn_uses_sequence_id - self.alibi = config.attn_config.alibi - self.alibi_bias_max = config.attn_config.alibi_bias_max - if config.init_device == "mixed": - # TODO: reimplement mixed device initialization - # dist.get_local_rank() == 0: - if True: - config.init_device = "cpu" - else: - config.init_device = "meta" - if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys(): - norm_options = " | ".join(NORM_CLASS_REGISTRY.keys()) - raise NotImplementedError( - f"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options})." - ) - if config.norm_type.lower() != "low_precision_layernorm": - raise NotImplementedError( - f"Requested norm type ({config.norm_type}) is not implemented within this repo." - ) - - self.wte = TensorParallelEmbedding(f"{prefix}.wte", weights) - - if not self.alibi: - self.wpe = TensorParallelEmbedding(f"{prefix}.wpe", weights) - self.blocks = nn.ModuleList( - [ - MPTBlock(config, prefix=f"{prefix}.blocks.{i}", weights=weights) - for i in range(config.n_layers) - ] - ) - if config.no_bias: - self.norm_f = nn.LayerNorm.load_no_bias( - prefix="transformer.norm_f", weights=weights, eps=EPS - ) - else: - self.norm_f = nn.LayerNorm.load( - prefix="transformer.norm_f", weights=weights, eps=EPS - ) - self.is_causal = not self.prefix_lm - self._attn_bias_initialized = False - self.attn_bias = None - self.attn_bias_shape = attn_bias_shape( - self.attn_impl, - config.n_heads, - config.max_seq_len, - self.alibi, - prefix_lm=self.prefix_lm, - causal=self.is_causal, - use_sequence_id=self.attn_uses_sequence_id, - ) - if config.no_bias: - for module in self.modules(): - if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter): - if config.verbose: - warnings.warn(f"Removing bias ({module.bias}) from {module}.") - module.register_parameter("bias", None) - if hasattr(self.config, "verbose"): - if config.verbose and config.verbose > 2: - print(self) - if "verbose" not in self.config.init_config: - self.config.init_config["verbose"] = self.config.verbose - if self.config.init_config["verbose"] > 1: - init_fn_name = self.config.init_config["name"] - warnings.warn(f"Using {init_fn_name} initialization.") - - @torch.no_grad() - def _attn_bias( - self, - device, - dtype, - attention_mask: Optional[torch.ByteTensor] = None, - prefix_mask: Optional[torch.ByteTensor] = None, - sequence_id: Optional[torch.LongTensor] = None, - ): - if not self._attn_bias_initialized: - if self.attn_bias_shape: - self.attn_bias = torch.zeros( - self.attn_bias_shape, device=device, dtype=dtype - ) - self.attn_bias = build_attn_bias( - self.attn_impl, - self.attn_bias, - self.config.n_heads, - self.config.max_seq_len, - causal=self.is_causal, - alibi=self.alibi, - alibi_bias_max=self.alibi_bias_max, - ) - assert self.n_heads % self.world_size == 0 - block_size = self.n_heads // self.world_size - self.attn_bias = self.attn_bias[ - :, self.rank * block_size : (self.rank + 1) * block_size - ] - self._attn_bias_initialized = True - if self.attn_impl == "flash": - return (self.attn_bias, attention_mask) - if self.attn_bias is not None: - self.attn_bias = self.attn_bias.to(dtype=dtype, device=device) - attn_bias = self.attn_bias - if self.prefix_lm: - assert isinstance(attn_bias, torch.Tensor) - assert isinstance(prefix_mask, torch.Tensor) - attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask) - if self.attn_uses_sequence_id and sequence_id is not None: - assert isinstance(attn_bias, torch.Tensor) - attn_bias = self._apply_sequence_id(attn_bias, sequence_id) - if attention_mask is not None: - s_k = attention_mask.shape[-1] - if attn_bias is None: - attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype) - else: - _s_k = max(0, attn_bias.size(-1) - s_k) - attn_bias = attn_bias[:, :, :, _s_k:] - if prefix_mask is not None and attention_mask.shape != prefix_mask.shape: - raise ValueError( - f"attention_mask shape={attention_mask.shape} " - + f"and prefix_mask shape={prefix_mask.shape} are not equal." - ) - min_val = torch.finfo(attn_bias.dtype).min - attn_bias = attn_bias.masked_fill( - ~attention_mask.view(-1, 1, 1, s_k), min_val - ) - return (attn_bias, None) - - def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor): - (s_k, s_q) = attn_bias.shape[-2:] - if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len: - raise ValueError( - "attn_bias does not match the expected shape. " - + f"The last two dimensions should both be {self.config.max_length} " - + f"but are {s_k} and {s_q}." - ) - seq_len = prefix_mask.shape[-1] - if seq_len > self.config.max_seq_len: - raise ValueError( - f"prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}" - ) - attn_bias = attn_bias[..., :seq_len, :seq_len] - causal = torch.tril( - torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device) - ).view(1, 1, seq_len, seq_len) - prefix = prefix_mask.view(-1, 1, 1, seq_len) - cannot_attend = ~torch.logical_or(causal, prefix.bool()) - min_val = torch.finfo(attn_bias.dtype).min - attn_bias = attn_bias.masked_fill(cannot_attend, min_val) - return attn_bias - - def _apply_sequence_id( - self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor - ): - seq_len = sequence_id.shape[-1] - if seq_len > self.config.max_seq_len: - raise ValueError( - f"sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}" - ) - attn_bias = attn_bias[..., :seq_len, :seq_len] - cannot_attend = torch.logical_not( - torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len)) - ).unsqueeze(1) - min_val = torch.finfo(attn_bias.dtype).min - attn_bias = attn_bias.masked_fill(cannot_attend, min_val) - return attn_bias - - def forward( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, - attention_mask: Optional[torch.ByteTensor] = None, - prefix_mask: Optional[torch.ByteTensor] = None, - sequence_id: Optional[torch.LongTensor] = None, - return_dict: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - use_cache: Optional[bool] = None, - ): - return_dict = ( - return_dict if return_dict is not None else self.config.return_dict - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - if attention_mask is not None: - attention_mask = attention_mask.bool() - if prefix_mask is not None: - prefix_mask = prefix_mask.bool() - if not return_dict: - raise NotImplementedError( - "return_dict False is not implemented yet for MPT" - ) - if output_attentions: - if self.attn_impl != "torch": - raise NotImplementedError( - "output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`." - ) - if ( - attention_mask is not None - and attention_mask[:, 0].sum() != attention_mask.shape[0] - and self.training - ): - raise NotImplementedError( - "MPT does not support training with left padding." - ) - if self.prefix_lm and prefix_mask is None: - raise ValueError( - "prefix_mask is a required argument when MPT is configured with prefix_lm=True." - ) - if self.training: - if self.attn_uses_sequence_id and sequence_id is None: - raise ValueError( - "sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True " - + "and the model is in train mode." - ) - elif self.attn_uses_sequence_id is False and sequence_id is not None: - warnings.warn( - "MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. " - + "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True." - ) - S = input_ids.size(1) - assert ( - S <= self.config.max_seq_len - ), f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}" - tok_emb = self.wte(input_ids) - if self.alibi: - x = tok_emb - else: - past_position = 0 - if past_key_values is not None: - if len(past_key_values) != self.config.n_layers: - raise ValueError( - "past_key_values must provide a past_key_value for each attention " - + f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})." - ) - past_position = past_key_values[0][0].size(1) - if self.attn_impl == "torch": - past_position = past_key_values[0][0].size(3) - if S + past_position > self.config.max_seq_len: - raise ValueError( - f"Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}." - ) - pos = torch.arange( - past_position, - S + past_position, - dtype=torch.long, - device=input_ids.device, - ).unsqueeze(0) - if attention_mask is not None: - pos = torch.clamp( - pos - - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[ - :, past_position: - ], - min=0, - ) - pos_emb = self.wpe(pos) - x = tok_emb + pos_emb - (attn_bias, attention_mask) = self._attn_bias( - device=x.device, - dtype=torch.float32, - attention_mask=attention_mask, - prefix_mask=prefix_mask, - sequence_id=sequence_id, - ) - if use_cache and past_key_values is None: - past_key_values = [() for _ in range(self.config.n_layers)] - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - for b_idx, block in enumerate(self.blocks): - if output_hidden_states: - assert all_hidden_states is not None - all_hidden_states = all_hidden_states + (x,) - past_key_value = ( - past_key_values[b_idx] if past_key_values is not None else None - ) - (x, attn_weights, past_key_value) = block( - x, - past_key_value=past_key_value, - attn_bias=attn_bias, - attention_mask=attention_mask, - is_causal=self.is_causal, - ) - if past_key_values is not None: - past_key_values[b_idx] = past_key_value - if output_attentions: - assert all_self_attns is not None - all_self_attns = all_self_attns + (attn_weights,) - x = self.norm_f(x) - if output_hidden_states: - assert all_hidden_states is not None - all_hidden_states = all_hidden_states + (x,) - return BaseModelOutputWithPast( - last_hidden_state=x, - past_key_values=past_key_values, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class MPTForCausalLM(MPTPreTrainedModel): - def __init__(self, prefix: str, config, weights): - super().__init__(config) - - if not prefix: - prefix = "transformer" - else: - prefix = f"{prefix}.transformer" - - if not config.tie_word_embeddings: - raise ValueError("MPTForCausalLM only supports tied word embeddings") - self.transformer = MPTModel(prefix, config, weights) - self.lm_head = SpeculativeHead.load( - config, prefix=f"{prefix}.wte", weights=weights - ) - self.logit_scale = None - if config.logit_scale is not None: - logit_scale = config.logit_scale - if isinstance(logit_scale, str): - if logit_scale == "inv_sqrt_d_model": - logit_scale = 1 / math.sqrt(config.d_model) - else: - raise ValueError( - f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'." - ) - self.logit_scale = logit_scale - - def forward( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, - attention_mask: Optional[torch.ByteTensor] = None, - prefix_mask: Optional[torch.ByteTensor] = None, - sequence_id: Optional[torch.LongTensor] = None, - labels: Optional[torch.LongTensor] = None, - return_dict: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - use_cache: Optional[bool] = None, - ): - return_dict = ( - return_dict if return_dict is not None else self.config.return_dict - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - outputs = self.transformer( - input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - prefix_mask=prefix_mask, - sequence_id=sequence_id, - return_dict=return_dict, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - use_cache=use_cache, - ) - logits, speculative_logits = self.lm_head(outputs.last_hidden_state) - if self.logit_scale is not None: - if self.logit_scale == 0: - warnings.warn( - f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs." - ) - logits *= self.logit_scale - loss = None - if labels is not None: - labels = torch.roll(labels, shifts=-1) - labels[:, -1] = -100 - loss = F.cross_entropy( - logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1) - ) - return ( - CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ), - speculative_logits, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs - ): - if inputs_embeds is not None: - raise NotImplementedError("inputs_embeds is not implemented for MPT yet") - attention_mask = kwargs["attention_mask"].bool() - if attention_mask[:, -1].sum() != attention_mask.shape[0]: - raise NotImplementedError( - "MPT does not support generation with right padding." - ) - if self.transformer.attn_uses_sequence_id and self.training: - sequence_id = torch.zeros_like(input_ids[:1]) - else: - sequence_id = None - if past_key_values is not None: - input_ids = input_ids[:, -1].unsqueeze(-1) - if self.transformer.prefix_lm: - prefix_mask = torch.ones_like(attention_mask) - if kwargs.get("use_cache") is False: - raise NotImplementedError( - "MPT with prefix_lm=True does not support use_cache=False." - ) - else: - prefix_mask = None - return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "prefix_mask": prefix_mask, - "sequence_id": sequence_id, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache", True), - } - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - """Used by HuggingFace generate when using beam search with kv-caching. - - See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133 - for an example in transformers. - """ - reordered_past = [] - for layer_past in past_key_values: - reordered_past += [ - tuple( - (past_state.index_select(0, beam_idx) for past_state in layer_past) - ) - ] - return reordered_past diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/neox_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/neox_modeling.py deleted file mode 100644 index 06731a6f9..000000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/neox_modeling.py +++ /dev/null @@ -1,796 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" PyTorch GPTNeoX model.""" - -from typing import Optional, Tuple, Union - -import os -import torch -import torch.distributed -import torch.utils.checkpoint -from torch import nn -from torch.nn import CrossEntropyLoss - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from transformers.modeling_utils import PreTrainedModel -from text_generation_server.layers import ( - TensorParallelColumnLinear, - TensorParallelEmbedding, - TensorParallelRowLinear, - SpeculativeHead, -) - - -CUSTOM_KERNELS_ENABLED = False -if ( - torch.cuda.is_available() - and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True" -): - try: - from custom_kernels import fused_attention_cuda - - CUSTOM_KERNELS_ENABLED = True - except ImportError: - pass - - -def make_causal_mask( - input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int -) -> torch.BoolTensor: - """ - Make causal mask used for self-attention. - """ - batch_size, target_length = input_ids_shape - mask = torch.ones( - (target_length, target_length + past_key_values_length), - dtype=torch.bool, - device=device, - ) - mask = mask.triu(1 + past_key_values_length) - - expanded_mask = mask.unsqueeze(0).expand( - batch_size, target_length, target_length + past_key_values_length - ) - return expanded_mask - - -def expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor: - """ - Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`. - """ - batch_size, src_length = mask.shape - tgt_length = tgt_length if tgt_length is not None else src_length - - expanded_mask = ~(mask[:, None, :].to(torch.bool)) - return expanded_mask.expand(batch_size, tgt_length, src_length) - - -def prepare_attn_mask( - attention_mask: torch.Tensor, - input_shape: Tuple[int, int], - past_key_values_length: int, -) -> torch.BoolTensor: - # create causal mask - # [batch_size, seq_length] -> [batch_size, tgt_length, src_length] - combined_attention_mask = None - device = attention_mask.device - _, src_length = input_shape - - if src_length > 1: - combined_attention_mask = make_causal_mask( - input_shape, device=device, past_key_values_length=past_key_values_length - ) - - # [batch_size, seq_length] -> [batch_size, tgt_length, src_length] - expanded_attn_mask = expand_mask(attention_mask, tgt_length=src_length) - combined_attention_mask = ( - expanded_attn_mask - if combined_attention_mask is None - else expanded_attn_mask | combined_attention_mask - ) - - return combined_attention_mask - - -class GPTNeoXPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - -class GPTNeoXAttention(nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - self.num_attention_heads = config.num_attention_heads - self.hidden_size = config.hidden_size - self.head_size = self.hidden_size // self.num_attention_heads - self.rotary_ndims = int(self.head_size * config.rotary_pct) - # ??? TODO - # self.register_buffer( - # "bias", - # torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( - # 1, 1, max_positions, max_positions - # ), - # ) - # self.register_buffer("masked_bias", torch.tensor(-1e9)) - self.rotary_emb = RotaryEmbedding( - self.rotary_ndims, - config.max_position_embeddings, - base=config.rotary_emb_base, - ) - self.rotary_emb.inv_freq = nn.Parameter( - weights.get_tensor(f"{prefix}.rotary_emb.inv_freq") - ) - self.inv_norm_factor = 1.0 / torch.sqrt( - torch.tensor(self.head_size, dtype=torch.float32) - ).to(torch.get_default_dtype()) - - if self.num_attention_heads % weights.process_group.size() != 0: - raise ValueError( - f"`num_attention_heads` must be divisible by `num_shards` " - f"(got `num_attention_heads`: {self.num_attention_heads} " - f"and `num_shards`: {weights.process_group.size()}" - ) - self.num_attention_heads = ( - self.num_attention_heads // weights.process_group.size() - ) - self.query_key_value = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.query_key_value", weights=weights, bias=True - ) - self.dense = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.dense", weights=weights, bias=True - ) - - def forward( - self, - hidden_states, - position_ids, - attention_mask, - head_mask=None, - layer_past=None, - use_cache=False, - output_attentions=False, - ): - has_layer_past = layer_past is not None - - # Compute QKV - # Attention heads [batch, seq_len, hidden_size] - # --> [batch, seq_len, (np * 3 * head_size)] - qkv = self.query_key_value(hidden_states) - - # [batch, seq_len, (num_heads * 3 * head_size)] - # --> [batch, seq_len, num_heads, 3 * head_size] - new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size) - qkv = qkv.view(*new_qkv_shape).permute(0, 2, 1, 3) - # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size] - query, key, value = qkv.split(self.head_size, -1) - - # Compute token offset for rotary embeddings (when decoding) - seq_len = key.shape[-2] - if has_layer_past: - seq_len += layer_past[0].shape[-2] - - # Compute rotary embeddings on rotary_ndims - query_rot = query[..., : self.rotary_ndims] - key_rot = key[..., : self.rotary_ndims] - - query_rot, key_rot = self.rotary_emb(query_rot, key_rot, position_ids, seq_len) - - query[..., : self.rotary_ndims] = query_rot - key[..., : self.rotary_ndims] = key_rot - - if CUSTOM_KERNELS_ENABLED: - attn_output, present, attn_weights = fused_attention_cuda.forward( - query, - key, - value, - layer_past, - attention_mask, - head_mask, - self.inv_norm_factor, - self.num_attention_heads, - use_cache, - ) - else: - # Cache QKV values - if has_layer_past: - past_key = layer_past[0] - past_value = layer_past[1] - key = torch.cat((past_key, key), dim=-2) - value = torch.cat((past_value, value), dim=-2) - present = (key, value) if use_cache else None - - # Compute attention - attn_output, attn_weights = self._attn( - query, key, value, attention_mask, head_mask - ) - - # Reshape outputs - attn_output = self._merge_heads( - attn_output, self.num_attention_heads, self.head_size - ) - - attn_output = self.dense(attn_output) - - outputs = (attn_output, present) - if output_attentions: - outputs += (attn_weights,) - - return outputs - - @classmethod - def _split_heads(cls, tensor, num_attention_heads, attn_head_size): - """ - Splits hidden dim into attn_head_size and num_attention_heads - """ - # tensor: [bs, seq_len, hidden_size] - new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) - # -> [bs, seq_len, num_attention_heads, attn_head_size] - tensor = tensor.view(new_shape) - # -> [bs, num_attention_heads, seq_len, attn_head_size] - tensor = tensor.permute(0, 2, 1, 3) - return tensor - - @classmethod - def _merge_heads(cls, tensor, num_attention_heads, attn_head_size): - """ - Merges attn_head_size dim and num_attn_heads dim into hidden dim - """ - # tensor [bs, num_attention_heads, seq_len, attn_head_size] - tensor = tensor.permute(0, 2, 1, 3).contiguous() - # -> [bs, seq_len, num_attention_heads, attn_head_size] - tensor = tensor.view( - tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size - ) - # -> [bs, seq_len, hidden_size] - return tensor - - def _attn(self, query, key, value, attention_mask=None, head_mask=None): - # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size] - # compute causal mask from causal mask buffer - batch_size, num_attention_heads, query_length, attn_head_size = query.size() - key_length = key.size(-2) - - query = query.reshape( - batch_size * num_attention_heads, query_length, attn_head_size - ) - key = key.reshape(batch_size * num_attention_heads, key_length, attn_head_size) - attn_scores = torch.zeros( - 1, - dtype=query.dtype, - device=key.device, - ).expand(batch_size * num_attention_heads, query_length, key_length) - attn_scores = torch.baddbmm( - attn_scores, - query, - key.transpose(1, 2), - beta=1.0, - alpha=self.inv_norm_factor, - ) - - # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] - input_dtype = attn_scores.dtype - if input_dtype in [torch.float16, torch.bfloat16]: - attn_scores = attn_scores.to(torch.float) - attn_scores = torch.where( - attention_mask, torch.finfo(attn_scores.dtype).min, attn_scores - ) - attn_scores = attn_scores.view( - batch_size, num_attention_heads, query_length, key_length - ) - - attn_weights = nn.functional.softmax(attn_scores, dim=-1) - attn_weights = attn_weights.to(value.dtype) - - # Mask heads if we want to - if head_mask is not None: - attn_weights = attn_weights * head_mask - - attn_output = torch.matmul(attn_weights, value) - return attn_output, attn_weights - - -class RotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings, base=10000, device=None): - super().__init__() - self.true_inv_freq = 1.0 / ( - base ** (torch.arange(0, dim, 2).float().to(device) / dim) - ) - self.register_buffer("inv_freq", self.true_inv_freq) - - # Build here to make `torch.jit.trace` work. - self.max_seq_len_cached = max_position_embeddings - self.cos_cached = None - self.sin_cached = None - - @staticmethod - def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - @staticmethod - def _create_cos_sin(inv_freq, max_position_embeddings, dtype, device): - t = torch.arange( - max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype - ) - freqs = torch.einsum("i,j->ij", t, inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - return emb.cos().to(device).to(dtype), emb.sin().to(device).to(dtype) - - def forward(self, q, k, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if ( - seq_len > self.max_seq_len_cached - or self.cos_cached is None - or self.sin_cached is None - ): - if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - self.cos_cached, self.sin_cached = self._create_cos_sin( - self.true_inv_freq, self.max_seq_len_cached, q.dtype, q.device - ) - return rotary_forward(q, k, self.cos_cached, self.sin_cached, position_ids) - - -@torch.jit.script -def rotary_forward(q, k, cos, sin, position_ids): - cos = cos[position_ids].unsqueeze(1) - sin = sin[position_ids].unsqueeze(1) - - chunk_size = q.shape[-1] // 2 - q1, q2 = q.split(chunk_size, -1) - q_rotated = torch.cat((-q2, q1), dim=-1) - k1, k2 = k.split(chunk_size, -1) - k_rotated = torch.cat((-k2, k1), dim=-1) - - q_embed = (q * cos) + (q_rotated * sin) - k_embed = (k * cos) + (k_rotated * sin) - return q_embed, k_embed - - -class GPTNeoXMLP(nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - self.act = ( - ACT2FN[config.hidden_act] - if "gelu_fast" not in config.hidden_act - else lambda x: torch.nn.functional.gelu(x, approximate="tanh") - ) - - self.dense_h_to_4h = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True - ) - self.dense_4h_to_h = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True - ) - - def forward(self, hidden_states): - hidden_states = self.dense_h_to_4h(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.dense_4h_to_h(hidden_states) - return hidden_states - - -class GPTNeoXLayer(nn.Module): - def __init__(self, layer_id, prefix: str, config, weights): - super().__init__() - self.use_parallel_residual = config.use_parallel_residual - self.input_layernorm = nn.LayerNorm.load( - prefix=f"{prefix}.layers.{layer_id}.input_layernorm", - weights=weights, - eps=config.layer_norm_eps, - ) - self.post_attention_layernorm = nn.LayerNorm.load( - prefix=f"{prefix}.layers.{layer_id}.post_attention_layernorm", - weights=weights, - eps=config.layer_norm_eps, - ) - self.attention = GPTNeoXAttention( - config, prefix=f"{prefix}.layers.{layer_id}.attention", weights=weights - ) - self.mlp = GPTNeoXMLP( - config, prefix=f"{prefix}.layers.{layer_id}.mlp", weights=weights - ) - - def forward( - self, - hidden_states, - position_ids, - attention_mask=None, - head_mask=None, - use_cache=False, - layer_past=None, - output_attentions=False, - ): - attention_layer_outputs = self.attention( - self.input_layernorm(hidden_states), - attention_mask=attention_mask, - position_ids=position_ids, - layer_past=layer_past, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - attn_output = attention_layer_outputs[ - 0 - ] # output_attn: attn_output, present, (attn_weights) - outputs = attention_layer_outputs[1:] - - if self.use_parallel_residual: - # pseudocode: - # x = x + attn(ln1(x)) + mlp(ln2(x)) - mlp_output = self.mlp(self.post_attention_layernorm(hidden_states)) - hidden_states = mlp_output + attn_output + hidden_states - else: - # pseudocode: - # x = x + attn(ln1(x)) - # x = x + mlp(ln2(x)) - attn_output = attn_output + hidden_states - mlp_output = self.mlp(self.post_attention_layernorm(attn_output)) - hidden_states = mlp_output + attn_output - - if use_cache: - outputs = ( - hidden_states, - ) + outputs # hidden_states, present, (attn_weights) - else: - outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights) - - return outputs - - -class GPTNeoXModel(GPTNeoXPreTrainedModel): - def __init__(self, prefix: str, config, weights): - super().__init__(config) - self.config = config - - self.num_attention_heads = config.num_attention_heads - - self.embed_in = TensorParallelEmbedding( - prefix=f"{prefix}.embed_in", weights=weights - ) - self.layers = nn.ModuleList( - [ - GPTNeoXLayer(layer_id, prefix, config, weights) - for layer_id in range(config.num_hidden_layers) - ] - ) - self.final_layer_norm = nn.LayerNorm.load( - prefix=f"{prefix}.final_layer_norm", - weights=weights, - eps=config.layer_norm_eps, - ) - self.tp_world_size = weights.process_group.size() - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - position_ids=None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - r""" - past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time" - ) - elif input_ids is not None: - input_shape = input_ids.size() - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - batch_size, seq_length = input_shape - - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * self.config.num_hidden_layers) - else: - past_length = past_key_values[0][0].size(-2) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_length, seq_length + past_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_in(input_ids) - - hidden_states = inputs_embeds - - # Attention mask. - seq_length_with_past = seq_length - past_key_values_length = 0 - if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[-1] - seq_length_with_past = seq_length_with_past + past_key_values_length - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), device=hidden_states.device - ) - else: - attention_mask = attention_mask.to(hidden_states.device) - - causal_mask = prepare_attn_mask( - attention_mask, - input_shape=(batch_size, seq_length), - past_key_values_length=past_key_values_length, - ) - - assert self.num_attention_heads % self.tp_world_size == 0 - block_size = self.num_attention_heads // self.tp_world_size - causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - - presents = () if use_cache else None - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - outputs = layer( - hidden_states, - position_ids=position_ids, - attention_mask=causal_mask, - head_mask=head_mask[i], - layer_past=layer_past, - use_cache=use_cache, - output_attentions=output_attentions, - ) - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - if output_attentions: - all_attentions = all_attentions + (outputs[2 if use_cache else 1],) - - hidden_states = self.final_layer_norm(hidden_states) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [hidden_states, presents, all_hidden_states, all_attentions] - if v is not None - ) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_attentions, - ) - - -class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel): - _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] - - def __init__(self, prefix: str, config, weights): - super().__init__(config) - - if not prefix: - prefix = "gpt_neox" - else: - prefix = f"{prefix}.gpt_neox" - - self.gpt_neox = GPTNeoXModel(prefix, config, weights) - self.embed_out = SpeculativeHead.load( - config, prefix="embed_out", weights=weights - ) - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are - only required when the model is used as a decoder in a Sequence to Sequence model. - - Contains pre-computed hidden-states (key and values in the self-attention blocks that can be used (see - `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in - `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are - ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, GPTNeoXForCausalLM, GPTNeoXConfig - >>> import torch - - >>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") - >>> config = GPTNeoXConfig.from_pretrained("EleutherAI/gpt-neox-20b") - >>> config.is_decoder = True - >>> model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", config=config) - - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") - >>> outputs = model(**inputs) - - >>> prediction_logits = outputs.logits - ```""" - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - outputs = self.gpt_neox( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - lm_logits, speculative_logits = self.embed_out(hidden_states) - - lm_loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(lm_logits.device) - # we are doing next-token prediction; shift prediction scores and input ids by one - shift_logits = lm_logits[:, :-1, :].contiguous() - labels = labels[:, 1:].contiguous() - loss_fct = CrossEntropyLoss() - lm_loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1) - ) - - if not return_dict: - output = (lm_logits,) + outputs[1:] - return ((lm_loss,) + output) if lm_loss is not None else output - - return ( - CausalLMOutputWithPast( - loss=lm_loss, - logits=lm_logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ), - speculative_logits, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - **kwargs, - ): - input_shape = input_ids.shape - - # cut decoder_input_ids if past is used - if past_key_values and past_key_values[0] is not None: - input_ids = input_ids[:, -1:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - if attention_mask is None: - attention_mask = input_ids.new_ones(input_shape) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "attention_mask": attention_mask, - "past_key_values": past_key_values, - "position_ids": position_ids, - } - ) - - return model_inputs - - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple( - past_state.index_select(0, beam_idx) - for past_state in layer_past[:2] - ) - + layer_past[2:], - ) - return reordered_past diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/opt_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/opt_modeling.py deleted file mode 100644 index bd4403214..000000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/opt_modeling.py +++ /dev/null @@ -1,857 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" PyTorch OPT model.""" -import random -from typing import List, Optional, Tuple, Union - -import torch -import torch.utils.checkpoint -from torch import nn - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from transformers.modeling_utils import PreTrainedModel -from transformers import OPTConfig -from text_generation_server.layers import ( - FastLinear, - TensorParallelColumnLinear, - TensorParallelEmbedding, - TensorParallelRowLinear, - SpeculativeHead, -) - -EPS = 1e-5 - - -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, - dtype: torch.dtype, - device: torch.device, - past_key_values_length: int = 0, -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full( - (tgt_len, tgt_len), - torch.tensor(torch.finfo(dtype).min, device=device), - device=device, - ) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat( - [ - torch.zeros( - tgt_len, past_key_values_length, dtype=dtype, device=device - ), - mask, - ], - dim=-1, - ) - return mask[None, None, :, :].expand( - bsz, 1, tgt_len, tgt_len + past_key_values_length - ) - - -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill( - inverted_mask.to(torch.bool), torch.finfo(dtype).min - ) - - -class OPTLearnedPositionalEmbedding(nn.Module): - """ - This module learns positional embeddings up to a fixed maximum size. - """ - - def __init__(self, prefix: str, weights): - super().__init__() - self.offset = 2 - self.weight = nn.Parameter( - weights.get_tensor( - f"{prefix + '.' if prefix else ''}decoder.embed_positions.weight" - ) - ) - - def forward( - self, attention_mask: torch.LongTensor, past_key_values_length: int = 0 - ): - """`input_ids_shape` is expected to be [bsz x seqlen].""" - attention_mask = attention_mask.long() - - # create positions depending on attention_mask - positions = ( - torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask - ).long() - 1 - - # cut positions if `past_key_values_length` is > 0 - positions = positions[:, past_key_values_length:] - - return torch.nn.functional.embedding(positions + self.offset, self.weight) - - -class OPTAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__( - self, - config, - prefix, - weights, - is_decoder: bool = False, - bias: bool = True, - process_group=None, - ): - super().__init__() - hidden_size = config.hidden_size - num_heads = config.num_attention_heads - - self.hidden_size = hidden_size - self.num_heads = num_heads - self.dropout = config.dropout - self.head_dim = hidden_size // num_heads - - if (self.head_dim * num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {num_heads})." - ) - self.scaling = self.head_dim**-0.5 - self.is_decoder = is_decoder - - process_group = weights.process_group - if self.num_heads % weights.process_group.size() != 0: - raise ValueError( - f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " - f"and `num_shards`: {weights.process_group.size()}" - ) - self.num_heads = self.num_heads // process_group.size() - self.hidden_size = self.hidden_size // process_group.size() - - self.q_proj = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.q_proj", weights=weights, bias=bias - ) - self.k_proj = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.k_proj", weights=weights, bias=bias - ) - self.v_proj = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.v_proj", weights=weights, bias=bias - ) - self.out_proj = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.out_proj", weights=weights, bias=bias - ) - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return ( - tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - .transpose(1, 2) - .contiguous() - ) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = ( - attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - + attention_mask - ) - attn_weights = torch.max( - attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) - ) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437 - if attn_weights.dtype == torch.float16: - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(torch.float16) - else: - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view( - bsz, self.num_heads, tgt_len, src_len - ) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view( - bsz, self.num_heads, tgt_len, src_len - ) - attn_weights = attn_weights_reshaped.view( - bsz * self.num_heads, tgt_len, src_len - ) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `hidden_size` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned aross GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.hidden_size) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -class OPTDecoderLayer(nn.Module): - def __init__(self, layer_id: int, prefix: str, config: OPTConfig, weights): - super().__init__() - self.process_group = weights.process_group - self.hidden_size = config.hidden_size - prefix = f"{prefix + '.' if prefix else ''}decoder.layers.{layer_id}" - self.self_attn = OPTAttention( - config, - prefix=f"{prefix}.self_attn", - weights=weights, - is_decoder=True, - bias=config.enable_bias, - ) - self.do_layer_norm_before = config.do_layer_norm_before - self.dropout = config.dropout - self.activation_fn = ACT2FN[config.activation_function] - - self.self_attn_layer_norm = nn.LayerNorm.load( - prefix=f"{prefix}.self_attn_layer_norm", weights=weights, eps=EPS - ) - self.fc1 = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.fc1", weights=weights, bias=config.enable_bias - ) - self.fc2 = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.fc2", weights=weights, bias=config.enable_bias - ) - self.final_layer_norm = nn.LayerNorm.load( - prefix=f"{prefix}.final_layer_norm", weights=weights, eps=EPS - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - ) -> Tuple[ - torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] - ]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size - `(encoder_attention_heads,)`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention - if self.do_layer_norm_before: - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - hidden_states = nn.functional.dropout( - hidden_states, p=self.dropout, training=self.training - ) - hidden_states = residual + hidden_states - - # 350m applies layer norm AFTER attention - if not self.do_layer_norm_before: - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Fully Connected - hidden_states_shape = hidden_states.shape - hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) - residual = hidden_states - - # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention - if self.do_layer_norm_before: - hidden_states = self.final_layer_norm(hidden_states) - - hidden_states = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) - - hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout( - hidden_states, p=self.dropout, training=self.training - ) - - hidden_states = (residual + hidden_states).view(hidden_states_shape) - - # 350m applies layer norm AFTER attention - if not self.do_layer_norm_before: - hidden_states = self.final_layer_norm(hidden_states) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -class OPTPreTrainedModel(PreTrainedModel): - config_class = OPTConfig - - -class OPTDecoder(OPTPreTrainedModel): - def __init__(self, prefix: str, config: OPTConfig, weights): - super().__init__(config) - self.dropout = config.dropout - self.layerdrop = config.layerdrop - self.padding_idx = config.pad_token_id - self.max_target_positions = config.max_position_embeddings - self.vocab_size = config.vocab_size - - prefix = prefix + "." if prefix else "" - - self.embed_tokens = TensorParallelEmbedding( - prefix=f"{prefix}decoder.embed_tokens", weights=weights - ) - self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights) - - if config.word_embed_proj_dim != config.hidden_size: - self.project_out = FastLinear.load( - config, - prefix=f"{prefix}decoder.project_out", - weights=weights, - bias=False, - ) - else: - self.project_out = None - - if config.word_embed_proj_dim != config.hidden_size: - self.project_in = FastLinear.load( - config, - prefix=f"{prefix}decoder.project_in", - weights=weights, - bias=False, - ) - else: - self.project_in = None - - # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility - # with checkpoints that have been fine-tuned before transformers v4.20.1 - # see https://github.com/facebookresearch/metaseq/pull/164 - if config.do_layer_norm_before and not config._remove_final_layer_norm: - self.final_layer_norm = nn.LayerNorm.load( - prefix=f"{prefix}decoder.final_layer_norm", weights=weights, eps=EPS - ) - else: - self.final_layer_norm = None - - self.layers = nn.ModuleList( - [ - OPTDecoderLayer(layer_id, prefix, config, weights) - for layer_id in range(config.num_hidden_layers) - ] - ) - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask( - self, attention_mask, input_shape, inputs_embeds, past_key_values_length - ): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ).to(inputs_embeds.device) - combined_attention_mask = ( - expanded_attn_mask - if combined_attention_mask is None - else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the - cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those - that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of - all `decoder_input_ids` of shape `(batch_size, sequence_length)`. - - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - batch_size, seq_length = input_shape - past_key_values_length = ( - past_key_values[0][0].shape[2] if past_key_values is not None else 0 - ) - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values_length + seq_length - - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - batch_size, mask_seq_length, device=inputs_embeds.device - ) - causal_attention_mask = self._prepare_decoder_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - pos_embeds = self.embed_positions(attention_mask, past_key_values_length) - - if self.project_in is not None: - inputs_embeds = self.project_in(inputs_embeds) - - hidden_states = inputs_embeds + pos_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - # check if head_mask has a correct number of layers specified if desired - for attn_mask, mask_name in zip([head_mask], ["head_mask"]): - if attn_mask is not None: - if attn_mask.size()[0] != (len(self.layers)): - raise ValueError( - f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" - f" {head_mask.size()[0]}." - ) - - for idx, decoder_layer in enumerate(self.layers): - # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) - if output_hidden_states: - all_hidden_states += (hidden_states,) - - dropout_probability = random.uniform(0, 1) - if self.training and (dropout_probability < self.layerdrop): - continue - - past_key_value = ( - past_key_values[idx] if past_key_values is not None else None - ) - - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if self.final_layer_norm is not None: - hidden_states = self.final_layer_norm(hidden_states) - - if self.project_out is not None: - hidden_states = self.project_out(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None - ) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class OPTModel(OPTPreTrainedModel): - def __init__(self, prefix: str, config: OPTConfig, weights): - super().__init__(config) - self.decoder = OPTDecoder(prefix, config, weights) - # Initialize weights and apply final processing - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) - decoder_outputs = self.decoder( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - if not return_dict: - return decoder_outputs - - return BaseModelOutputWithPast( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - hidden_states=decoder_outputs.hidden_states, - attentions=decoder_outputs.attentions, - ) - - -class OPTForCausalLM(OPTPreTrainedModel): - def __init__(self, prefix, config, weights): - super().__init__(config) - - self.model = OPTModel(prefix, config, weights) - - self.lm_head = SpeculativeHead.load( - config, - prefix=f"{prefix + '.' if prefix else ''}decoder.embed_tokens", - weights=weights, - ) - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model.decoder( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - logits, speculative_logits = self.lm_head(outputs.last_hidden_state) - - loss = None - - return ( - CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ), - speculative_logits, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - **kwargs, - ): - if past_key_values: - input_ids = input_ids[:, -1:] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple( - past_state.index_select(0, beam_idx) for past_state in layer_past - ), - ) - return reordered_past diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/phi_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/phi_modeling.py deleted file mode 100644 index 3f2ed010f..000000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/phi_modeling.py +++ /dev/null @@ -1,336 +0,0 @@ -# imlementation of the PhiModel and PhiForCausalLM classes - -import torch -import torch.distributed - -import math -from torch import nn -from typing import Optional, List, Tuple -from transformers.configuration_utils import PretrainedConfig -from transformers.modeling_outputs import CausalLMOutputWithPast - -from text_generation_server.layers import ( - TensorParallelRowLinear, - TensorParallelColumnLinear, - TensorParallelEmbedding, - SpeculativeHead, - FastLinear, -) - - -# PhiConfig is the configuration class for the PhiModel. -class PhiConfig(PretrainedConfig): - def __init__( - self, - vocab_size=51200, - n_positions=2048, - n_embd=2560, - n_layer=32, - n_inner=None, - n_head=32, - rotary_dim=32, - layer_norm_epsilon=1e-5, - tie_word_embeddings=False, - pad_vocab_size_multiple=64, - pad_token_id=0, - bos_token_id=1, - eos_token_id=2, - no_bias=False, - **kwargs, - ): - self.vocab_size = vocab_size - self.n_positions = n_positions - self.n_embd = n_embd - self.n_layer = n_layer - self.n_inner = n_inner - self.n_head = n_head - self.rotary_dim = rotary_dim - - self.layer_norm_epsilon = layer_norm_epsilon - self.tie_word_embeddings = tie_word_embeddings - self.pad_vocab_size_multiple = pad_vocab_size_multiple - self.pad_token_id = pad_token_id - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.no_bias = no_bias - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - -# RotaryEmbedding is a class that implements the rotary embedding. -class RotaryEmbedding(nn.Module): - def __init__(self, dim, max_seq_len): - super().__init__() - inv_freq = [1.0 / 10000.0 ** (i / dim) for i in range(0, dim, 2)] - inv_freq_len = len(inv_freq) - inv_freq = torch.tensor(inv_freq).view(1, inv_freq_len) - t = torch.arange(0, max_seq_len, dtype=torch.float).view(max_seq_len, 1) - freqs = t.matmul(inv_freq) - self.sin = freqs.sin() - self.cos = freqs.cos() - - def apply_rotary_emb_qkv(self, qkv, seqlen_offset): - b_size, seqlen, three, _, _headdim = qkv.shape - if three != 3: - raise Exception("unexpected shape for qkv") - _, rotary_dim = self.cos.shape - rotary_dim = rotary_dim * 2 - q_rot = qkv[:, :, 0, :, :rotary_dim] - q_pass = qkv[:, :, 0, :, rotary_dim:] - k_rot = qkv[:, :, 1, :, :rotary_dim] - k_pass = qkv[:, :, 1, :, rotary_dim:] - q12 = torch.chunk(q_rot, 2, dim=-1) - k12 = torch.chunk(k_rot, 2, dim=-1) - q1, q2 = q12[0], q12[1] - k1, k2 = k12[0], k12[1] - c = self.cos.narrow(0, seqlen_offset, seqlen).unsqueeze(1) - s = self.sin.narrow(0, seqlen_offset, seqlen).unsqueeze(1) - q_rot = torch.cat( - [ - q1 * c - q2 * s, - q1 * s + q2 * c, - ], - dim=-1, - ) - k_rot = torch.cat( - [ - k1 * c - k2 * s, - k1 * s + k2 * c, - ], - dim=-1, - ) - q = torch.cat([q_rot, q_pass], dim=-1) - k = torch.cat([k_rot, k_pass], dim=-1) - v = qkv[:, :, 2] - return q, k, v - - -# PhiCausalLMHead is the head of the PhiModel. It is a linear layer with a layer norm. -class PhiCausalLMHead(nn.Module): - def __init__(self, config, weights): - super().__init__() - self.ln = nn.LayerNorm.load( - prefix="lm_head.ln", - weights=weights, - eps=config.layer_norm_epsilon, - ) - self.linear = SpeculativeHead.load( - config=config, prefix="lm_head.linear", weights=weights - ) - - def forward(self, hidden_states): - hidden_states = self.ln(hidden_states) - hidden_states = self.linear(hidden_states) - return hidden_states - - -# PhiMHA is a multi-head attention layer. This layer uses an attention mask to prevent tokens from attending to subsequent tokens. -class PhiMHA(nn.Module): - def __init__(self, prefix, config, weights): - super().__init__() - self.Wqkv = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias - ) - self.out_proj = TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.out_proj", - weights=weights, - bias=not config.no_bias, - ) - self.op_size = config.n_embd - self.head_dim = int(config.n_embd / config.n_head) - self.num_heads = config.n_head - self.rotary_emb = RotaryEmbedding( - config.rotary_dim, - config.n_positions, - ) - self.softmax_scale = 1.0 / math.sqrt(self.head_dim) - - def forward( - self, - hidden_states, - past_kv_cache, - attention_mask=None, - ): - b_size, seq_len, _n_embd = hidden_states.shape - qkv = self.Wqkv(hidden_states) - qkv = qkv.view(b_size, seq_len, 3, self.num_heads, self.head_dim) - seqlen_offset = 0 if past_kv_cache is None else past_kv_cache[0].shape[1] - q, k, v = self.rotary_emb.apply_rotary_emb_qkv(qkv, seqlen_offset) - - # if there is a kv_cache, then we need to concatenate - if past_kv_cache is not None: - prev_k, prev_v = past_kv_cache - k = torch.cat([prev_k, k], dim=1) - v = torch.cat([prev_v, v], dim=1) - - past_kv_cache = [k, v] - attn_weights = torch.einsum("bthd,bshd->bhts", q, k * self.softmax_scale) - - if attention_mask is not None: - seqlen_k = k.shape[1] - seqlen_q = q.shape[1] - causal_mask = torch.triu( - torch.full((seqlen_q, seqlen_k), -10000.0, device=attn_weights.device), - 1, - ) - attn_weights = attn_weights + causal_mask.to(dtype=attn_weights.dtype) - - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) - attn_output = attn_weights.matmul(v.transpose(1, 2)).squeeze(0) - attn_output = ( - attn_output.view((b_size, self.num_heads, seq_len, self.head_dim)) - .transpose(1, 2) - .flatten(-2) - ) - return self.out_proj(attn_output), past_kv_cache - - -# PhiMLP is a multi-layer perceptron. It contains two linear layers with a gelu activation function. -class PhiMLP(nn.Module): - def __init__(self, prefix, config, weights): - super().__init__() - - self.n_inner = config.n_inner - self.fc1 = FastLinear.load( - config=config, - prefix=f"{prefix}.fc1", - weights=weights, - bias=False, - ) - self.fc2 = FastLinear.load( - config=config, - prefix=f"{prefix}.fc2", - weights=weights, - bias=False, - ) - self.activation = torch.nn.functional.gelu - - def forward(self, hidden_states): - hidden_states = self.fc1(hidden_states) - hidden_states = self.activation(hidden_states) - hidden_states = self.fc2(hidden_states) - return hidden_states - - -# PhiBlock is a single transformer block. It contains a layer norm, a multi-head attention layer and an multi-layer perceptron. -class PhiBlock(nn.Module): - def __init__(self, layer_id, config, weights): - super().__init__() - self.layer_id = layer_id - self.layer_norm = nn.LayerNorm.load( - prefix=f"{layer_id}.ln", weights=weights, eps=config.layer_norm_epsilon - ) - self.mixer = PhiMHA(prefix=f"{layer_id}.mixer", config=config, weights=weights) - self.mlp = PhiMLP(prefix=f"{layer_id}.mlp", config=config, weights=weights) - - def forward( - self, - hidden_states, - kv_cache, - attention_mask, - ): - residual = hidden_states - hidden_states = self.layer_norm(hidden_states) - attn_outputs, past_kv_cache = self.mixer( - hidden_states, kv_cache, attention_mask - ) - feed_forward_hidden_states = self.mlp(hidden_states) - out = attn_outputs + feed_forward_hidden_states + residual - return out, past_kv_cache - - -# PhiModel implements the embedding layer and the transformer blocks. -class PhiModel(nn.Module): - def __init__(self, prefix: str, config, weights): - super().__init__() - self.tp_rank = weights.process_group.rank() - self.tp_world_size = weights.process_group.size() - self.embed_tokens = TensorParallelEmbedding( - prefix=f"{prefix}.embd.wte", weights=weights - ) - self.blocks = nn.ModuleList( - [ - PhiBlock(f"{prefix}.h.{layer_id}", config, weights) - for layer_id in range(config.n_layer) - ] - ) - - def forward( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, - attention_mask: Optional[torch.ByteTensor] = None, - return_dict: Optional[bool] = None, - use_cache: Optional[bool] = None, - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: - hidden_states = self.embed_tokens(input_ids) - seq_len = hidden_states.shape[1] - mask = None if seq_len <= 1 else attention_mask - - past_key_values = ( - [None] * len(self.blocks) if past_key_values is None else past_key_values - ) - - for index, block in enumerate(self.blocks): - hidden_states, new_key_values = block( - hidden_states, past_key_values[index], mask - ) - past_key_values[index] = new_key_values - - return hidden_states, past_key_values - - -# PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object. -class PhiForCausalLM(torch.nn.Module): - def __init__(self, prefix: str, config, weights): - super().__init__() - - if not prefix: - prefix = "transformer" - else: - prefix = f"{prefix}.transformer" - - self.model = PhiModel(prefix, config, weights) - self.lm_head = PhiCausalLMHead(config, weights) - - def forward( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, - attention_mask: Optional[torch.ByteTensor] = None, - return_dict: Optional[bool] = None, - use_cache: Optional[bool] = None, - labels: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: - model_output = self.model( - input_ids, past_key_values, attention_mask, return_dict, use_cache - ) - logits = self.lm_head(model_output[0]) - - loss = None - if labels is not None: - loss = nn.CrossEntropyLoss()( - logits[:, :-1].view(-1, logits.size(-1)), labels[:, 1:].view(-1) - ) - - if not return_dict: - return ( - ((loss,) + (logits,) + model_output[1:]) - if loss is not None - else (logits,) + model_output[1:] - ) - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=model_output[1], - hidden_states=None, - attentions=None, - ) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py new file mode 100644 index 000000000..441b0016e --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py @@ -0,0 +1,946 @@ +# coding=utf-8 +# Copyright 2025 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen2.5 VL model.""" + +from typing import Optional, Tuple, List + +import torch +import torch.utils.checkpoint +from torch import nn + +from habana_frameworks.torch.hpex.kernels import FusedSDPA +from vllm_hpu_extension.utils import ModuleFusedSDPA + + +import numpy as np + +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig + +import torch.nn.functional as F + +from text_generation_server.layers.layernorm import FastRMSNorm +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelRowLinear, + TensorParallelEmbedding, + SpeculativeHead, +) +from text_generation_server.layers.attention import ( + Seqlen, + HPUPagedAttentionMetadata, +) +from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( + Qwen2Model, +) + +# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +from typing import Union +from transformers.feature_extraction_utils import BatchFeature +from transformers.image_utils import ImageInput, VideoInput +from transformers.processing_utils import ( + ProcessingKwargs, + ProcessorMixin, + Unpack, + VideosKwargs, +) +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput + + +class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False): + fps: Union[List[float], float] + + +class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False): + videos_kwargs: Qwen2_5_VLVideosProcessorKwargs + _defaults = { + "text_kwargs": { + "padding": False, + }, + "videos_kwargs": {"fps": 2.0}, + } + + +class Qwen2_5_VLProcessor(ProcessorMixin): + r""" + Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor. + [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the + [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information. + Args: + image_processor ([`Qwen2VLImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template"] + + image_processor_class = "AutoImageProcessor" + tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + + def __init__( + self, image_processor=None, tokenizer=None, chat_template=None, **kwargs + ): + self.image_token = ( + "<|image_pad|>" + if not hasattr(tokenizer, "image_token") + else tokenizer.image_token + ) + self.video_token = ( + "<|video_pad|>" + if not hasattr(tokenizer, "video_token") + else tokenizer.video_token + ) + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: ImageInput = None, + text: Union[ + TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput] + ] = None, + videos: VideoInput = None, + **kwargs: Unpack[Qwen2_5_VLProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to + Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch + tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`. + - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. + - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. + - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`. + """ + output_kwargs = self._merge_kwargs( + Qwen2_5_VLProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if images is not None: + image_inputs = self.image_processor( + images=images, videos=None, **output_kwargs["images_kwargs"] + ) + image_grid_thw = image_inputs["image_grid_thw"] + else: + image_inputs = {} + image_grid_thw = None + + if videos is not None: + videos_inputs = self.image_processor( + images=None, videos=videos, **output_kwargs["images_kwargs"] + ) + video_grid_thw = videos_inputs["video_grid_thw"] + + fps = output_kwargs["videos_kwargs"].pop("fps", 2.0) + if isinstance(fps, (int, float)): + second_per_grid_ts = [ + self.image_processor.temporal_patch_size / fps + ] * len(video_grid_thw) + elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw): + second_per_grid_ts = [ + self.image_processor.temporal_patch_size / tmp for tmp in fps + ] + else: + raise ValueError( + f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number." + ) + videos_inputs.update({"second_per_grid_ts": second_per_grid_ts}) + + else: + videos_inputs = {} + video_grid_thw = None + + if not isinstance(text, list): + text = [text] + + if image_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + text[i] = text[i].replace( + self.image_token, + "<|placeholder|>" + * (image_grid_thw[index].prod() // merge_length), + 1, + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + + if video_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.video_token in text[i]: + text[i] = text[i].replace( + self.video_token, + "<|placeholder|>" + * (video_grid_thw[index].prod() // merge_length), + 1, + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.video_token) + + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + def post_process_image_text_to_text(self, generated_outputs): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + + Returns: + `List[str]`: The decoded text. + """ + return self.tokenizer.batch_decode( + generated_outputs, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + names_from_processor = list( + dict.fromkeys(tokenizer_input_names + image_processor_input_names) + ) + return names_from_processor + ["second_per_grid_ts"] + + +# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +class Qwen2_5_VLVisionConfig(PretrainedConfig): + model_type = "qwen2_5_vl" + base_config_key = "vision_config" + + def __init__( + self, + depth=32, + hidden_size=3584, + hidden_act="silu", + intermediate_size=3420, + num_heads=16, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + spatial_patch_size=14, + temporal_patch_size=2, + tokens_per_second=4, + window_size=112, + out_hidden_size=3584, + fullatt_block_indexes=[7, 15, 23, 31], + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_patch_size = spatial_patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.tokens_per_second = tokens_per_second + self.window_size = window_size + self.fullatt_block_indexes = fullatt_block_indexes + self.out_hidden_size = out_hidden_size + + +class Qwen2_5_VLConfig(PretrainedConfig): + + def __init__( + self, + vocab_size=152064, + hidden_size=8192, + intermediate_size=29568, + num_hidden_layers=80, + num_attention_heads=64, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + tie_word_embeddings=False, + rope_theta=1000000.0, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=80, + attention_dropout=0.0, + vision_config=None, + rope_scaling=None, + **kwargs, + ): + if vision_config is not None: + self.vision_config = Qwen2_5_VLVisionConfig(**vision_config) + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + # and change type from 'mrope' to 'default' because `mrope` does defeault RoPE calculations + # one can set it to "linear"/"dynamic" etc. to have scaled RoPE + # TODO: @raushan update config in the hub + if self.rope_scaling is not None and "type" in self.rope_scaling: + if self.rope_scaling["type"] == "mrope": + self.rope_scaling["type"] = "default" + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision( + tensor: torch.Tensor, freqs: torch.Tensor +) -> torch.Tensor: + orig_dtype = tensor.dtype + tensor = tensor.float() + cos = freqs.cos() + sin = freqs.sin() + cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + output = (tensor * cos) + (rotate_half(tensor) * sin) + output = output.to(orig_dtype) + return output + + +class Qwen2_5VLAttention(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.embed_dim = config.hidden_size // weights.process_group.size() + self.head_dim = config.hidden_size // config.num_heads + self.num_heads = config.num_heads // weights.process_group.size() + + self.qkv = TensorParallelColumnLinear.load_qkv( + config, + prefix=f"{prefix}.qkv", + weights=weights, + bias=False, + num_heads=self.num_heads, + num_key_value_heads=self.num_heads, + ) + self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0) + + self.proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.proj", + weights=weights, + bias=True, + ) + self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads) + + def forward( + self, + hidden_state: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: int, + ) -> torch.Tensor: + # apply the qkv linear layer to the hidden state + qkv = self.qkv(hidden_state) + query, key, value = qkv.split( + [self.embed_dim, self.embed_dim, self.embed_dim], dim=1 + ) + + # reshape the query, key, and value tensors + _shape = ( + hidden_state.shape[0], + self.num_heads, + self.embed_dim // self.num_heads, + ) + query = query.view(*_shape) + key = key.view(*_shape) + value = value.view(*_shape) + + # apply rotary positional embeddings + query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze( + 0 + ) + key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0) + + # calc maximum sequence length for any batch + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + causal = False + + # execute sdpa + query = query.unsqueeze(0).transpose(1, 2) + key = key.unsqueeze(0).transpose(1, 2) + value = value.unsqueeze(0).transpose(1, 2) + fsdpa_op = ModuleFusedSDPA(FusedSDPA) + attn_output = fsdpa_op( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=causal, + scale=None, + softmax_mode="None", + recompute_mode=None, + valid_sequence_lengths=None, + ) + attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous() + + # reshape output to original dimensions + attn_output = attn_output.reshape(hidden_state.shape[0], -1) + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2_5VLVisionMLP(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.activation_fn = ACT2FN[config.hidden_act] + + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + + self.up = TensorParallelColumnLinear.load( + prefix=f"{prefix}.up_proj", weights=weights, config=config, bias=True + ) + self.gate = TensorParallelColumnLinear.load( + prefix=f"{prefix}.gate_proj", weights=weights, config=config, bias=True + ) + self.down = TensorParallelRowLinear.load( + prefix=f"{prefix}.down_proj", weights=weights, config=config, bias=True + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + gate_states = self.gate(hidden_states) + up_states = self.up(hidden_states) + activated_states = self.activation_fn(gate_states) * up_states + down_states = self.down(activated_states) + return down_states + + +class Qwen2_5VLVisionBlock(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.attn = Qwen2_5VLAttention( + prefix=f"{prefix}.attn", + config=config, + weights=weights, + ) + self.norm1 = FastRMSNorm.load( + prefix=f"{prefix}.norm1", + weights=weights, + eps=1e-6, + ) + self.norm2 = FastRMSNorm.load( + prefix=f"{prefix}.norm2", + weights=weights, + eps=1e-6, + ) + self.mlp = Qwen2_5VLVisionMLP( + prefix=f"{prefix}.mlp", + config=config, + weights=weights, + ) + + def forward( + self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen + ) -> torch.Tensor: + norm1_out, _ = self.norm1(hidden_states) + attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen) + hidden_states = hidden_states + attn_out + norm2_out, _ = self.norm2(hidden_states) + mlp_out = self.mlp(norm2_out) + hidden_states = hidden_states + mlp_out + return hidden_states + + +class Qwen2_5VLPatchMerger(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.hidden_size = config.hidden_size * (config.spatial_merge_size**2) + self.patch_merger_ln_q = FastRMSNorm.load( + prefix=f"{prefix}.ln_q", + weights=weights, + eps=1e-6, + ) + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True + ) + + def forward(self, hidden_states) -> torch.Tensor: + hidden_states, _ = self.patch_merger_ln_q(hidden_states) + hidden_states = hidden_states.view(-1, self.hidden_size) + hidden_states = self.fc1(hidden_states) + hidden_states = F.gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Qwen2_5VisionModel(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + + self.spatial_merge_size = config.spatial_merge_size + kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size] + self.patch_embedding = nn.Conv3d( + in_channels=config.in_channels, + out_channels=config.hidden_size, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + ) + self.patch_embedding.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False + ) + head_dim = config.hidden_size // config.num_heads + + theta = 10000.0 + dim = head_dim // 2 + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + self.blocks = nn.ModuleList( + [ + Qwen2_5VLVisionBlock( + prefix=f"{prefix}.blocks.{i}", + config=config, + weights=weights, + ) + for i in range(config.depth) + ] + ) + self.merger = Qwen2_5VLPatchMerger( + prefix=f"{prefix}.merger", + config=config, + weights=weights, + ) + # import ipdb; ipdb.set_trace() + self.temporal_patch_size = config.temporal_patch_size + self.spatial_patch_size = config.spatial_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.hidden_size + self.window_size = config.window_size + self.patch_size = config.patch_size + self.spatial_merge_unit = config.spatial_merge_size * config.spatial_merge_size + self.fullatt_block_indexes = config.fullatt_block_indexes + + def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: + batch_size, _, hidden_size = hidden_state.shape + class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size) + hidden_state = torch.cat([class_embedding, hidden_state], dim=1) + return hidden_state + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = ( + self.window_size // self.spatial_merge_size // self.patch_size + ) + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.spatial_merge_size, + grid_w // self.spatial_merge_size, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( + grid_t, llm_grid_h, llm_grid_w + ) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = ( + seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + ) + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + + return window_index, cu_window_seqlens + + def forward( + self, + pixel_values: torch.Tensor, + grid_thw: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + + # reshape the input tensor for processing + shape = ( + -1, + self.in_channels, + self.temporal_patch_size, + self.spatial_patch_size, + self.spatial_patch_size, + ) + pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype) + hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim) + # TODO: revisit to see if we can avoid some of these reshapes + + # find the position ids for the input tensor based on the grid_thw + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + + pos_ids = torch.cat(pos_ids, dim=0) + + max_grid_size = grid_thw[:, 1:].max() + + # apply the positional embeddings to the position ids + seq = torch.arange( + max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) + rotary_pos_emb_full = torch.outer(seq, self.inv_freq) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + seq_len = hidden_states.shape[0] + patch_shape = (seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + og_shape = (seq_len, -1) + + hidden_states = hidden_states.view(patch_shape)[window_index, :, :].view( + og_shape + ) + rotary_pos_emb = rotary_pos_emb.view(patch_shape)[window_index, :, :].view( + og_shape + ) + + rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device) + + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device="cpu", + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens).to( + hidden_states.device + ) + + # create a cu_seqlens tensor to be used in the attention mask + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum(dim=0, dtype=torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1]) + + # iterately apply the blocks to the hidden states + for layer_num, block in enumerate(self.blocks): + # NOTE: qwen2_5_vl.py has a concept of full attention blocks + # that are applied at specific layers. + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + + hidden_states = block( + hidden_states, cu_seqlens_now, rotary_pos_emb, max_seqlen + ) + + # apply the final patch merger to the hidden states + hidden_states = self.merger(hidden_states) + reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] + return hidden_states + + +class Qwen2_5VLForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + config.vision_config.quantize = None + config.vision_config.speculator = config.speculator + # set rope_scaling.type == "mrope" since AutoConfig.from_pretrained incorrectly + # returns rope_scaling.type == "default" for Qwen2_5-VL model at the moment + if ( + hasattr(config, "rope_scaling") + and config.rope_scaling is not None + and config.rope_scaling.get("type", None) == "default" + ): + config.rope_scaling.update({"rope_type": "mrope"}) + self.hidden_size = config.hidden_size + self.vision_start_token_id = config.vision_start_token_id + self.vision_end_token_id = config.vision_end_token_id + self.image_token_id = config.image_token_id + self.video_token_id = config.video_token_id + self.spatial_merge_size = config.vision_config.spatial_merge_size + self.embed_tokens = TensorParallelEmbedding( + prefix="model.embed_tokens", weights=weights + ) + self.visual = Qwen2_5VisionModel( + prefix="visual", config=config.vision_config, weights=weights + ) + self.text_model = Qwen2Model(prefix=None, config=config, weights=weights) + if config.tie_word_embeddings: + suffix = "model.embed_tokens" + else: + suffix = "lm_head" + + self.lm_head = SpeculativeHead.load( + config, + prefix=suffix if not prefix else f"{prefix}.{suffix}", + weights=weights, + ) + self.device = weights.device + + # based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391 + # modified to first find segments then initialize position ids for each segment + # Steps: + # locate all vision and text segments + # calculate `vision_segment_lengths` for each vision segment to be use as offset + # calculate `text_segment_lengths` for each text segment to be used as offset + # create position ids for each vision segment based on the image grid + # create position ids for each text segment + # combine all the position ids + # the final segment is the difference between the last vision segment and the end of the input + # combine all the position ids and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3) + def get_position_ids( + self, + input_ids: torch.Tensor, + image_grid_thw: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if image_grid_thw is None: + return ( + torch.arange(input_ids.shape[0], device=input_ids.device) + .unsqueeze(1) + .repeat(1, 3) + ) + + spatial_merge_size = self.spatial_merge_size + vision_start_token_id = self.vision_start_token_id + vision_end_token_id = self.vision_end_token_id + device = input_ids.device + dtype = input_ids.dtype + input_ids_len = input_ids.shape[0] + + vision_starts = torch.where(input_ids == vision_start_token_id)[0] + vision_ends = torch.where(input_ids == vision_end_token_id)[0] + vision_segments = torch.stack((vision_starts, vision_ends), dim=1) + prev_vision_end = torch.cat( + [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]] + ) + text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1 + vision_widths_max = torch.cat( + [ + torch.zeros(1, device=image_grid_thw.device, dtype=dtype), + image_grid_thw[:-1, 2] // spatial_merge_size, + ] + ) + vision_segment_lengths = vision_widths_max + text_lengths_between_vision + vision_segment_lengths = vision_segment_lengths.cumsum(dim=0) + text_segment_lengths = vision_segment_lengths - text_lengths_between_vision + + # create position ids for each vision segment based on the image grid + llm_pos_ids_list = [] + for i, _ in enumerate(vision_segments): + t, h, w = ( + image_grid_thw[i][0], + image_grid_thw[i][1] // spatial_merge_size, + image_grid_thw[i][2] // spatial_merge_size, + ) + t_indices = torch.arange(t, device=device).repeat_interleave(h * w) + h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t) + w_indices = torch.arange(w, device=device).repeat(t * h) + image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0) + + # offset by the position of the last vision segment + im = image_position_ids + vision_segment_lengths[i] + llm_pos_ids_list.append(im) + + # create position ids for each text segment + text_ranges = [ + torch.arange(seq_len, device=device).view(1, -1).expand(3, -1) + + text_segment_lengths[i] + for i, seq_len in enumerate(text_lengths_between_vision) + ] + + full_llm_pos_ids_list = [ + item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist + ] + # import ipdb + + # ipdb.set_trace() + max_s = full_llm_pos_ids_list[-1].max() + 1 + final_text_len = input_ids_len - vision_ends[-1] + if final_text_len > 0: + m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1) + full_llm_pos_ids_list.append(m + max_s) + + position_ids = ( + torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1) + ) + return position_ids + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + lm_head_indices: Optional[torch.Tensor], + pixel_values: torch.FloatTensor = None, + image_grid_thw: Optional[torch.LongTensor] = None, + # Unused in this model + video_grid_thw: Optional[torch.LongTensor] = None, + pixel_attention_mask=None, + image_sizes: Optional[torch.LongTensor] = None, + adapter_data: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + image_indices=None, + ): + inputs_embeds = self.embed_tokens(input_ids) + + # apply the visual model to the pixel values if they are provided + if pixel_values is not None and len(pixel_values) > 0: + if pixel_values is not None: + image_embeds = self.visual( + pixel_values, grid_thw=image_grid_thw + ).squeeze(0) + mask = torch.where(input_ids == self.image_token_id) + inputs_embeds[mask] = image_embeds + + hidden_states = self.text_model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + slots=slots, + seqlen=seqlen, + hpu_attention_meta=hpu_attention_meta, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py new file mode 100644 index 000000000..47ae2ac94 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -0,0 +1,519 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen2 VL model.""" + +from typing import Optional, Tuple, List + +import torch +import torch.utils.checkpoint +from torch import nn + + +from habana_frameworks.torch.hpex.kernels import FusedSDPA +from vllm_hpu_extension.utils import ModuleFusedSDPA + + +import numpy as np + +from transformers.activations import ACT2FN +import torch.nn.functional as F + +from text_generation_server.layers.layernorm import FastLayerNorm, FastRMSNorm +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelRowLinear, + TensorParallelEmbedding, + SpeculativeHead, +) +from text_generation_server.layers.attention import ( + Seqlen, + HPUPagedAttentionMetadata, +) +from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( + Qwen2Model, +) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision( + tensor: torch.Tensor, freqs: torch.Tensor +) -> torch.Tensor: + orig_dtype = tensor.dtype + tensor = tensor.float() + cos = freqs.cos() + sin = freqs.sin() + cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + output = (tensor * cos) + (rotate_half(tensor) * sin) + output = output.to(orig_dtype) + return output + + +class Qwen2VLAttention(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.embed_dim = config.embed_dim // weights.process_group.size() + self.head_dim = config.hidden_size // config.num_heads + self.num_heads = config.num_heads // weights.process_group.size() + + self.qkv = TensorParallelColumnLinear.load_qkv( + config, + prefix=f"{prefix}.qkv", + weights=weights, + bias=False, + num_heads=self.num_heads, + num_key_value_heads=self.num_heads, + ) + self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0) + self.proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.proj", + weights=weights, + bias=True, + ) + self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads) + + def forward( + self, + hidden_state: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: int, + ) -> torch.Tensor: + # apply the qkv linear layer to the hidden state + qkv = self.qkv(hidden_state) + query, key, value = qkv.split( + [self.embed_dim, self.embed_dim, self.embed_dim], dim=1 + ) + + # reshape the query, key, and value tensors + _shape = ( + hidden_state.shape[0], + self.num_heads, + self.embed_dim // self.num_heads, + ) + query = query.view(*_shape) + key = key.view(*_shape) + value = value.view(*_shape) + + # apply rotary positional embeddings + query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze( + 0 + ) + key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0) + + # calc maximum sequence length for any batch + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + causal = False + + # execute sdpa + query = query.unsqueeze(0).transpose(1, 2) + key = key.unsqueeze(0).transpose(1, 2) + value = value.unsqueeze(0).transpose(1, 2) + fsdpa_op = ModuleFusedSDPA(FusedSDPA) + attn_output = fsdpa_op( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=causal, + scale=None, + softmax_mode="None", + recompute_mode=None, + valid_sequence_lengths=None, + ) + attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous() + # reshape output to original dimensions + attn_output = attn_output.reshape(hidden_state.shape[0], -1) + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2VLVisionMLP(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Qwen2VLVisionBlock(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.attn = Qwen2VLAttention( + prefix=f"{prefix}.attn", + config=config, + weights=weights, + ) + self.norm1 = FastLayerNorm.load( + prefix=f"{prefix}.norm1", + weights=weights, + eps=1e-6, + ) + self.norm2 = FastLayerNorm.load( + prefix=f"{prefix}.norm2", + weights=weights, + eps=1e-6, + ) + self.mlp = Qwen2VLVisionMLP( + prefix=f"{prefix}.mlp", + config=config, + weights=weights, + ) + + def forward( + self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen + ) -> torch.Tensor: + norm1_out, residual = self.norm1(hidden_states) + attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen) + hidden_states = attn_out + residual + norm2_out, residual = self.norm2(hidden_states) + hidden_states = hidden_states + self.mlp(norm2_out) + return hidden_states + + +class Qwen2VLPatchMerger(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.hidden_size = config.embed_dim * (config.spatial_merge_size**2) + self.patch_merger_ln_q = FastLayerNorm.load( + prefix=f"{prefix}.ln_q", + weights=weights, + eps=1e-6, + ) + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True + ) + + def forward(self, hidden_states) -> torch.Tensor: + hidden_states, _ = self.patch_merger_ln_q(hidden_states) + hidden_states = hidden_states.view(-1, self.hidden_size) + hidden_states = self.fc1(hidden_states) + hidden_states = F.gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Qwen2VisionModel(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.spatial_merge_size = config.spatial_merge_size + kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size] + self.patch_embedding = nn.Conv3d( + in_channels=config.in_chans, + out_channels=config.embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + ) + self.patch_embedding.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False + ) + head_dim = config.embed_dim // config.num_heads + # TODO: replace with static positional embeddings once implemented + theta = 10000.0 + dim = head_dim // 2 + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + self.blocks = nn.ModuleList( + [ + Qwen2VLVisionBlock( + prefix=f"{prefix}.blocks.{i}", + config=config, + weights=weights, + ) + for i in range(config.depth) + ] + ) + self.merger = Qwen2VLPatchMerger( + prefix=f"{prefix}.merger", + config=config, + weights=weights, + ) + + self.temporal_patch_size = config.temporal_patch_size + self.spatial_patch_size = config.spatial_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.embed_dim + + def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: + batch_size, _, hidden_size = hidden_state.shape + class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size) + hidden_state = torch.cat([class_embedding, hidden_state], dim=1) + return hidden_state + + def forward( + self, + pixel_values: torch.Tensor, + grid_thw: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + # reshape the input tensor for processing + shape = ( + -1, + self.in_channels, + self.temporal_patch_size, + self.spatial_patch_size, + self.spatial_patch_size, + ) + pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype) + hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim) + # TODO: revisit to see if we can avoid some of these reshapes + + # find the position ids for the input tensor based on the grid_thw + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + + # apply the positional embeddings to the position ids + seq = torch.arange( + max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) + rotary_pos_emb_full = torch.outer(seq, self.inv_freq) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype) + + # create a cu_seqlens tensor to be used in the attention mask + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum(dim=0, dtype=torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1]) + # iterately apply the blocks to the hidden states + for block in self.blocks: + hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen) + + # apply the final patch merger to the hidden states + hidden_states = self.merger(hidden_states) + return hidden_states + + +class Qwen2VLForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + config.vision_config.quantize = None + config.vision_config.speculator = config.speculator + # set rope_scaling.type == "mrope" since AutoConfig.from_pretrained incorrectly + # returns rope_scaling.type == "default" for Qwen2-VL model at the moment + if ( + hasattr(config, "rope_scaling") + and config.rope_scaling is not None + and config.rope_scaling.get("type", None) == "default" + ): + config.rope_scaling.update({"rope_type": "mrope"}) + self.hidden_size = config.hidden_size + self.vision_start_token_id = config.vision_start_token_id + self.vision_end_token_id = config.vision_end_token_id + self.image_token_id = config.image_token_id + self.video_token_id = config.video_token_id + self.spatial_merge_size = config.vision_config.spatial_merge_size + self.embed_tokens = TensorParallelEmbedding( + prefix="model.embed_tokens", weights=weights + ) + self.visual = Qwen2VisionModel( + prefix="visual", config=config.vision_config, weights=weights + ) + self.text_model = Qwen2Model(prefix=None, config=config, weights=weights) + if config.tie_word_embeddings: + suffix = "model.embed_tokens" + else: + suffix = "lm_head" + + self.lm_head = SpeculativeHead.load( + config, + prefix=suffix if not prefix else f"{prefix}.{suffix}", + weights=weights, + ) + self.norm = FastRMSNorm.load( + prefix="model.norm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.device = weights.device + + # based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391 + # modified to first find segments then initialize position ids for each segment + # Steps: + # locate all vision and text segments + # calculate `vision_segment_lengths` for each vision segment to be use as offset + # calculate `text_segment_lengths` for each text segment to be used as offset + # create position ids for each vision segment based on the image grid + # create position ids for each text segment + # combine all the position ids + # the final segment is the difference between the last vision segment and the end of the input + # combine all the position ids and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3) + def get_position_ids( + self, + input_ids: torch.Tensor, + image_grid_thw: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if image_grid_thw is None: + return ( + torch.arange(input_ids.shape[0], device=input_ids.device) + .unsqueeze(1) + .repeat(1, 3) + ) + + spatial_merge_size = self.spatial_merge_size + vision_start_token_id = self.vision_start_token_id + vision_end_token_id = self.vision_end_token_id + device = input_ids.device + dtype = input_ids.dtype + input_ids_len = input_ids.shape[0] + + vision_starts = torch.where(input_ids == vision_start_token_id)[0] + vision_ends = torch.where(input_ids == vision_end_token_id)[0] + vision_segments = torch.stack((vision_starts, vision_ends), dim=1) + prev_vision_end = torch.cat( + [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]] + ) + text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1 + vision_widths_max = torch.cat( + [ + torch.zeros(1, device=image_grid_thw.device, dtype=dtype), + image_grid_thw[:-1, 2] // spatial_merge_size, + ] + ) + vision_segment_lengths = vision_widths_max + text_lengths_between_vision + vision_segment_lengths = vision_segment_lengths.cumsum(dim=0) + text_segment_lengths = vision_segment_lengths - text_lengths_between_vision + + # create position ids for each vision segment based on the image grid + llm_pos_ids_list = [] + for i, _ in enumerate(vision_segments): + t, h, w = ( + image_grid_thw[i][0], + image_grid_thw[i][1] // spatial_merge_size, + image_grid_thw[i][2] // spatial_merge_size, + ) + t_indices = torch.arange(t, device=device).repeat_interleave(h * w) + h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t) + w_indices = torch.arange(w, device=device).repeat(t * h) + image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0) + + # offset by the position of the last vision segment + im = image_position_ids + vision_segment_lengths[i] + llm_pos_ids_list.append(im) + + # create position ids for each text segment + text_ranges = [ + torch.arange(seq_len, device=device).view(1, -1).expand(3, -1) + + text_segment_lengths[i] + for i, seq_len in enumerate(text_lengths_between_vision) + ] + + full_llm_pos_ids_list = [ + item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist + ] + max_s = full_llm_pos_ids_list[-1].max() + 1 + final_text_len = input_ids_len - vision_ends[-1] + if final_text_len > 0: + m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1) + full_llm_pos_ids_list.append(m + max_s) + + position_ids = ( + torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1) + ) + return position_ids + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + lm_head_indices: Optional[torch.Tensor], + pixel_values: torch.FloatTensor = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + pixel_attention_mask=None, + image_sizes: Optional[torch.LongTensor] = None, + adapter_data: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + image_indices=None, + ): + inputs_embeds = self.embed_tokens(input_ids) + + # apply the visual model to the pixel values if they are provided + if pixel_values is not None and len(pixel_values) > 0: + if pixel_values is not None: + image_embeds = self.visual( + pixel_values, grid_thw=image_grid_thw + ).squeeze(0) + mask = torch.where(input_ids == self.image_token_id) + inputs_embeds[mask] = image_embeds + + hidden_states = self.text_model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + slots=slots, + seqlen=seqlen, + hpu_attention_meta=hpu_attention_meta, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/t5_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/t5_modeling.py deleted file mode 100644 index e6666acd3..000000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/t5_modeling.py +++ /dev/null @@ -1,1227 +0,0 @@ -# coding=utf-8 -# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" PyTorch T5 model.""" - -import copy -import math -import warnings -from typing import Optional, Tuple, Union - -from loguru import logger - -import torch -import torch.distributed -from torch import nn -from torch.nn import CrossEntropyLoss - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import ( - BaseModelOutput, - BaseModelOutputWithPastAndCrossAttentions, - Seq2SeqLMOutput, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS -from transformers.utils import ( - is_torch_fx_proxy, -) -from transformers import T5Config -from text_generation_server.layers import ( - TensorParallelColumnLinear, - TensorParallelEmbedding, - TensorParallelRowLinear, - SpeculativeHead, -) - -# copied from https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/t5/modeling_t5.py#L1316 -# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask -__HEAD_MASK_WARNING_MSG = """ -The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, -`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. -If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, -num_heads)`. -""" - - -class PartialTPEmbedding(nn.Module): - def __init__(self, prefix: str, weights): - super().__init__() - weight = weights.get_sharded(f"{prefix}.weight", dim=1) - self.weight = nn.Parameter(weight) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return torch.nn.functional.embedding(input, self.weight) - - -@torch.jit.script -def layer_norm(hidden_states, weight, epsilon): - # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean - # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated - # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for - # half-precision inputs is done in fp32 - - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + epsilon) - - # convert into half-precision if necessary - if weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(weight.dtype) - - return weight * hidden_states - - -class T5LayerNorm(nn.Module): - def __init__(self, prefix, weights, eps=1e-6): - """ - Construct a layernorm module in the T5 style. No bias and no subtraction of mean. - """ - super().__init__() - weight = weights.get_tensor(f"{prefix}.weight") - self.weight = nn.Parameter(weight) - self.variance_epsilon = torch.tensor(eps) - - def forward(self, hidden_states): - return layer_norm(hidden_states, self.weight, self.variance_epsilon) - - -try: - from apex.normalization import FusedRMSNorm - - T5LayerNorm = FusedRMSNorm # noqa - - logger.info( - "Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm" - ) -except ImportError: - # using the normal T5LayerNorm - pass -except Exception: - logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm") - pass - -ALL_LAYERNORM_LAYERS.append(T5LayerNorm) - - -class T5DenseActDense(nn.Module): - def __init__(self, config: T5Config, prefix, weights): - super().__init__() - self.wi = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.wi", weights=weights, bias=False - ) - - ### XXX: T5 models do not handle well both f16 and quantization. - ### Overidding specifically this layer for that reason. - ### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316 - ### https://github.com/huggingface/transformers/issues/20287 - _q = config.quantize - _dtype = weights.dtype - weights.dtype = torch.float32 - config.quantize = None - self.wo_cast = (torch.float32, _dtype) - self.wo = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.wo", weights=weights, bias=False - ) - weights.dtype = _dtype - config.quantize = _q - - self.dropout = nn.Dropout(config.dropout_rate) - self.act = ( - ACT2FN[config.dense_act_fn] - if "gelu" not in config.dense_act_fn - else lambda x: torch.nn.functional.gelu(x, approximate="tanh") - ) - - def forward(self, hidden_states): - hidden_states = self.wi(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.dropout(hidden_states) - - hidden_states = hidden_states.to(dtype=self.wo_cast[0]) - hidden_states = self.wo(hidden_states) - # XXX: Recasting is already done within the layer norm. - # Casting back to float16 here modifies results - # hidden_states = hidden_states.to(dtype=self.wo_cast[1]) - return hidden_states - - -class T5DenseGatedActDense(nn.Module): - def __init__(self, config: T5Config, prefix, weights): - super().__init__() - self.wi_0 = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.wi_0", weights=weights, bias=False - ) - self.wi_1 = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.wi_1", weights=weights, bias=False - ) - ### XXX: T5 models do not handle well both f16 and quantization. - ### Overidding specifically this layer for that reason. - ### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316 - ### https://github.com/huggingface/transformers/issues/20287 - _q = config.quantize - _dtype = weights.dtype - weights.dtype = torch.float32 - config.quantize = None - self.wo_cast = (torch.float32, _dtype) - self.wo = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.wo", weights=weights, bias=False - ) - weights.dtype = _dtype - config.quantize = _q - - self.dropout = nn.Dropout(config.dropout_rate) - self.act = ( - ACT2FN[config.dense_act_fn] - if "gelu" not in config.dense_act_fn - else lambda x: torch.nn.functional.gelu(x, approximate="tanh") - ) - - def forward(self, hidden_states): - hidden_gelu = self.act(self.wi_0(hidden_states)) - hidden_linear = self.wi_1(hidden_states) - hidden_states = hidden_gelu * hidden_linear - hidden_states = self.dropout(hidden_states) - - hidden_states = hidden_states.to(dtype=self.wo_cast[0]) - hidden_states = self.wo(hidden_states) - # XXX: Recasting is already done within the layer norm. - # Casting back to float16 here modifies results - # hidden_states = hidden_states.to(dtype=self.wo_cast[1]) - return hidden_states - - -class T5LayerFF(nn.Module): - def __init__(self, config: T5Config, prefix, weights): - super().__init__() - if config.is_gated_act: - self.DenseReluDense = T5DenseGatedActDense( - config, prefix=f"{prefix}.DenseReluDense", weights=weights - ) - else: - self.DenseReluDense = T5DenseActDense( - config, prefix=f"{prefix}.DenseReluDense", weights=weights - ) - - self.layer_norm = T5LayerNorm( - prefix=f"{prefix}.layer_norm", - weights=weights, - eps=config.layer_norm_epsilon, - ) - self.dropout = nn.Dropout(config.dropout_rate) - - def forward(self, hidden_states): - forwarded_states = self.layer_norm(hidden_states) - forwarded_states = self.DenseReluDense(forwarded_states) - hidden_states = hidden_states + self.dropout(forwarded_states) - return hidden_states - - -class T5Attention(nn.Module): - def __init__( - self, config: T5Config, prefix, weights, has_relative_attention_bias=False - ): - super().__init__() - self.is_decoder = config.is_decoder - self.has_relative_attention_bias = has_relative_attention_bias - self.relative_attention_num_buckets = config.relative_attention_num_buckets - self.relative_attention_max_distance = config.relative_attention_max_distance - self.d_model = config.d_model - self.key_value_proj_dim = config.d_kv - self.n_heads = config.num_heads - self.dropout = config.dropout_rate - self.inner_dim = self.n_heads * self.key_value_proj_dim - - process_group = weights.process_group - # Mesh TensorFlow initialization to avoid scaling before softmax - assert self.n_heads % process_group.size() == 0 - self.q = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.q", weights=weights, bias=False - ) - self.k = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.k", weights=weights, bias=False - ) - self.v = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.v", weights=weights, bias=False - ) - self.o = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.o", weights=weights, bias=False - ) - if self.n_heads % weights.process_group.size() != 0: - raise ValueError( - f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} " - f"and `num_shards`: {weights.process_group.size()}" - ) - self.n_heads = self.n_heads // process_group.size() - self.inner_dim = self.inner_dim // process_group.size() - - if self.has_relative_attention_bias: - self.relative_attention_bias = PartialTPEmbedding( - prefix=f"{prefix}.relative_attention_bias", weights=weights - ) - - @staticmethod - def _relative_position_bucket( - relative_position, bidirectional=True, num_buckets=32, max_distance=128 - ): - """ - Adapted from Mesh Tensorflow: - https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 - - Translate relative position to a bucket number for relative attention. The relative position is defined as - memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to - position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for - small absolute relative_position and larger buckets for larger absolute relative_positions. All relative - positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. - This should allow for more graceful generalization to longer sequences than the model has been trained on - - Args: - relative_position: an int32 Tensor - bidirectional: a boolean - whether the attention is bidirectional - num_buckets: an integer - max_distance: an integer - - Returns: - a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) - """ - relative_buckets = 0 - if bidirectional: - num_buckets //= 2 - relative_buckets += (relative_position > 0).to(torch.long) * num_buckets - relative_position = torch.abs(relative_position) - else: - relative_position = -torch.min( - relative_position, torch.zeros_like(relative_position) - ) - # now relative_position is in the range [0, inf) - - # half of the buckets are for exact increments in positions - max_exact = num_buckets // 2 - is_small = relative_position < max_exact - - # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance - relative_position_if_large = max_exact + ( - torch.log(relative_position.float() / max_exact) - / math.log(max_distance / max_exact) - * (num_buckets - max_exact) - ).to(torch.long) - relative_position_if_large = torch.min( - relative_position_if_large, - torch.full_like(relative_position_if_large, num_buckets - 1), - ) - - relative_buckets += torch.where( - is_small, relative_position, relative_position_if_large - ) - return relative_buckets - - def compute_bias(self, query_length, key_length, device=None): - """Compute binned relative position bias""" - if device is None: - device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[ - :, None - ] - memory_position = torch.arange(key_length, dtype=torch.long, device=device)[ - None, : - ] - relative_position = ( - memory_position - context_position - ) # shape (query_length, key_length) - relative_position_bucket = self._relative_position_bucket( - relative_position, # shape (query_length, key_length) - bidirectional=(not self.is_decoder), - num_buckets=self.relative_attention_num_buckets, - max_distance=self.relative_attention_max_distance, - ) - values = self.relative_attention_bias( - relative_position_bucket - ) # shape (query_length, key_length, num_heads) - values = values.permute([2, 0, 1]).unsqueeze( - 0 - ) # shape (1, num_heads, query_length, key_length) - return values - - def forward( - self, - hidden_states, - mask=None, - key_value_states=None, - position_bias=None, - past_key_value=None, - layer_head_mask=None, - query_length=None, - use_cache=False, - output_attentions=False, - ): - """ - Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). - """ - # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) - - batch_size, seq_length = hidden_states.shape[:2] - - real_seq_length = seq_length - - if past_key_value is not None: - assert ( - len(past_key_value) == 2 - ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states" - real_seq_length += ( - past_key_value[0].shape[2] if query_length is None else query_length - ) - - key_length = ( - real_seq_length if key_value_states is None else key_value_states.shape[1] - ) - - def shape(states): - """projection""" - return states.view( - batch_size, -1, self.n_heads, self.key_value_proj_dim - ).transpose(1, 2) - - def unshape(states): - """reshape""" - return ( - states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - ) - - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - - if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - - # get query states - query_states = shape( - self.q(hidden_states) - ) # (batch_size, n_heads, seq_length, dim_per_head) - - # get key/value states - key_states = project( - hidden_states, - self.k, - key_value_states, - past_key_value[0] if past_key_value is not None else None, - ) - value_states = project( - hidden_states, - self.v, - key_value_states, - past_key_value[1] if past_key_value is not None else None, - ) - - # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 - - if position_bias is None: - if not self.has_relative_attention_bias: - position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), - device=scores.device, - dtype=scores.dtype, - ) - else: - position_bias = self.compute_bias( - real_seq_length, key_length, device=scores.device - ) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] - - if mask is not None: - position_bias = ( - position_bias + mask - ) # (batch_size, n_heads, seq_length, key_length) - - position_bias_masked = position_bias - - scores += position_bias_masked - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, n_heads, seq_length, key_length) - - # Mask heads if we want to - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask - - attn_output = unshape( - torch.matmul(attn_weights, value_states) - ) # (batch_size, seq_length, dim) - attn_output = self.o(attn_output) - - present_key_value_state = ( - (key_states, value_states) if (self.is_decoder and use_cache) else None - ) - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) - - if output_attentions: - outputs = outputs + (attn_weights,) - return outputs - - -class T5LayerSelfAttention(nn.Module): - def __init__(self, config, prefix, weights, has_relative_attention_bias=False): - super().__init__() - self.SelfAttention = T5Attention( - config, - prefix=f"{prefix}.SelfAttention", - weights=weights, - has_relative_attention_bias=has_relative_attention_bias, - ) - self.layer_norm = T5LayerNorm( - prefix=f"{prefix}.layer_norm", - weights=weights, - eps=config.layer_norm_epsilon, - ) - self.dropout = nn.Dropout(config.dropout_rate) - - def forward( - self, - hidden_states, - attention_mask=None, - position_bias=None, - layer_head_mask=None, - past_key_value=None, - use_cache=False, - output_attentions=False, - ): - normed_hidden_states = self.layer_norm(hidden_states) - attention_output = self.SelfAttention( - normed_hidden_states, - mask=attention_mask, - position_bias=position_bias, - layer_head_mask=layer_head_mask, - past_key_value=past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - ) - hidden_states = hidden_states + self.dropout(attention_output[0]) - outputs = (hidden_states,) + attention_output[ - 1: - ] # add attentions if we output them - return outputs - - -class T5LayerCrossAttention(nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - self.EncDecAttention = T5Attention( - config, - prefix=f"{prefix}.EncDecAttention", - weights=weights, - has_relative_attention_bias=False, - ) - self.layer_norm = T5LayerNorm( - prefix=f"{prefix}.layer_norm", - weights=weights, - eps=config.layer_norm_epsilon, - ) - self.dropout = nn.Dropout(config.dropout_rate) - - def forward( - self, - hidden_states, - key_value_states, - attention_mask=None, - position_bias=None, - layer_head_mask=None, - past_key_value=None, - use_cache=False, - query_length=None, - output_attentions=False, - ): - normed_hidden_states = self.layer_norm(hidden_states) - attention_output = self.EncDecAttention( - normed_hidden_states, - mask=attention_mask, - key_value_states=key_value_states, - position_bias=position_bias, - layer_head_mask=layer_head_mask, - past_key_value=past_key_value, - use_cache=use_cache, - query_length=query_length, - output_attentions=output_attentions, - ) - layer_output = hidden_states + self.dropout(attention_output[0]) - outputs = (layer_output,) + attention_output[ - 1: - ] # add attentions if we output them - return outputs - - -class T5Block(nn.Module): - def __init__(self, config, prefix, weights, has_relative_attention_bias: bool): - super().__init__() - self.is_decoder = config.is_decoder - self.layer = nn.ModuleList() - self.layer.append( - T5LayerSelfAttention( - config, - prefix=f"{prefix}.layer.0", - weights=weights, - has_relative_attention_bias=has_relative_attention_bias, - ) - ) - if self.is_decoder: - i = 2 - self.layer.append( - T5LayerCrossAttention( - config, prefix=f"{prefix}.layer.1", weights=weights - ) - ) - else: - i = 1 - - self.layer.append( - T5LayerFF(config, prefix=f"{prefix}.layer.{i}", weights=weights) - ) - - def forward( - self, - hidden_states, - attention_mask=None, - position_bias=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - encoder_decoder_position_bias=None, - layer_head_mask=None, - cross_attn_layer_head_mask=None, - past_key_value=None, - use_cache=False, - output_attentions=False, - return_dict=True, - ): - if past_key_value is not None: - if not self.is_decoder: - logger.warning( - "`past_key_values` is passed to the encoder. Please make sure this is intended." - ) - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - - self_attention_outputs = self.layer[0]( - hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, - layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - ) - hidden_states, present_key_value_state = self_attention_outputs[:2] - attention_outputs = self_attention_outputs[ - 2: - ] # Keep self-attention outputs and relative position weights - - # clamp inf values to enable fp16 training - if hidden_states.dtype == torch.float16: - clamp_value = torch.where( - torch.isinf(hidden_states).any(), - torch.finfo(hidden_states.dtype).max - 1000, - torch.finfo(hidden_states.dtype).max, - ) - hidden_states = torch.clamp( - hidden_states, min=-clamp_value, max=clamp_value - ) - - do_cross_attention = self.is_decoder and encoder_hidden_states is not None - if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None - - cross_attention_outputs = self.layer[1]( - hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - position_bias=encoder_decoder_position_bias, - layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - query_length=query_length, - use_cache=use_cache, - output_attentions=output_attentions, - ) - hidden_states = cross_attention_outputs[0] - - # clamp inf values to enable fp16 training - if hidden_states.dtype == torch.float16: - clamp_value = torch.where( - torch.isinf(hidden_states).any(), - torch.finfo(hidden_states.dtype).max - 1000, - torch.finfo(hidden_states.dtype).max, - ) - hidden_states = torch.clamp( - hidden_states, min=-clamp_value, max=clamp_value - ) - - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = ( - present_key_value_state + cross_attention_outputs[1] - ) - - # Keep cross-attention outputs and relative position weights - attention_outputs = attention_outputs + cross_attention_outputs[2:] - - # Apply Feed Forward layer - hidden_states = self.layer[-1](hidden_states) - - # clamp inf values to enable fp16 training - if hidden_states.dtype == torch.float16: - clamp_value = torch.where( - torch.isinf(hidden_states).any(), - torch.finfo(hidden_states.dtype).max - 1000, - torch.finfo(hidden_states.dtype).max, - ) - hidden_states = torch.clamp( - hidden_states, min=-clamp_value, max=clamp_value - ) - - outputs = (hidden_states,) - - if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs - else: - outputs = outputs + attention_outputs - - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) - - -class T5PreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = T5Config - - def _shift_right(self, input_ids): - decoder_start_token_id = self.config.decoder_start_token_id - pad_token_id = self.config.pad_token_id - - assert decoder_start_token_id is not None, ( - "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id." - " See T5 docs for more information" - ) - - # shift inputs to the right - if is_torch_fx_proxy(input_ids): - # Item assignment is not supported natively for proxies. - shifted_input_ids = torch.full( - input_ids.shape[:-1] + (1,), decoder_start_token_id - ) - shifted_input_ids = torch.cat( - [shifted_input_ids, input_ids[..., :-1]], dim=-1 - ) - else: - shifted_input_ids = input_ids.new_zeros(input_ids.shape) - shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() - shifted_input_ids[..., 0] = decoder_start_token_id - - assert ( - pad_token_id is not None - ), "self.model.config.pad_token_id has to be defined." - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) - - return shifted_input_ids - - -class T5Stack(T5PreTrainedModel): - def __init__(self, config, prefix, weights, embed_tokens): - super().__init__(config) - - self.is_decoder = config.is_decoder - - self.embed_tokens = embed_tokens - self.block = nn.ModuleList( - [ - T5Block( - config, - prefix=f"{prefix}.block.{layer_id}", - weights=weights, - has_relative_attention_bias=(layer_id == 0), - ) - for layer_id in range(config.num_layers) - ] - ) - self.final_layer_norm = T5LayerNorm( - prefix=f"{prefix}.final_layer_norm", - weights=weights, - eps=config.layer_norm_epsilon, - ) - self.dropout = nn.Dropout(config.dropout_rate) - - def forward( - self, - input_ids=None, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - inputs_embeds=None, - head_mask=None, - cross_attn_head_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - # Model parallel - use_cache = use_cache if use_cache is not None else self.config.use_cache - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - if input_ids is not None and inputs_embeds is not None: - err_msg_prefix = "decoder_" if self.is_decoder else "" - raise ValueError( - f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" - ) - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - err_msg_prefix = "decoder_" if self.is_decoder else "" - raise ValueError( - f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds" - ) - - if inputs_embeds is None: - assert ( - self.embed_tokens is not None - ), "You have to initialize the model with valid token embeddings" - inputs_embeds = self.embed_tokens(input_ids) - - batch_size, seq_length = input_shape - - # required mask seq length can be calculated via length of past - mask_seq_length = ( - past_key_values[0][0].shape[2] + seq_length - if past_key_values is not None - else seq_length - ) - - if use_cache is True: - assert ( - self.is_decoder - ), f"`use_cache` can only be set to `True` if {self} is used as a decoder" - - if attention_mask is None: - attention_mask = torch.ones( - batch_size, mask_seq_length, device=inputs_embeds.device - ) - if ( - self.is_decoder - and encoder_attention_mask is None - and encoder_hidden_states is not None - ): - encoder_seq_length = encoder_hidden_states.shape[1] - encoder_attention_mask = torch.ones( - batch_size, - encoder_seq_length, - device=inputs_embeds.device, - dtype=torch.long, - ) - - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask = self.get_extended_attention_mask( - attention_mask, input_shape - ) - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.is_decoder and encoder_hidden_states is not None: - ( - encoder_batch_size, - encoder_sequence_length, - _, - ) = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones( - encoder_hidden_shape, device=inputs_embeds.device - ) - encoder_extended_attention_mask = self.invert_attention_mask( - encoder_attention_mask - ) - else: - encoder_extended_attention_mask = None - - # Prepare head mask if needed - head_mask = self.get_head_mask(head_mask, self.config.num_layers) - cross_attn_head_mask = self.get_head_mask( - cross_attn_head_mask, self.config.num_layers - ) - present_key_value_states = () if use_cache else None - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - all_cross_attentions = () if (output_attentions and self.is_decoder) else None - position_bias = None - encoder_decoder_position_bias = None - - hidden_states = self.dropout(inputs_embeds) - - for i, (layer_module, past_key_value) in enumerate( - zip(self.block, past_key_values) - ): - layer_head_mask = head_mask[i] - cross_attn_layer_head_mask = cross_attn_head_mask[i] - # Model parallel - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_outputs = layer_module( - hidden_states, - attention_mask=extended_attention_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - layer_head_mask=layer_head_mask, - cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - # layer_outputs is a tuple with: - # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) - if use_cache is False: - layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - - hidden_states, present_key_value_state = layer_outputs[:2] - - # We share the position biases between the layers - the first layer store them - # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), - # (cross-attention position bias), (cross-attention weights) - position_bias = layer_outputs[2] - if self.is_decoder and encoder_hidden_states is not None: - encoder_decoder_position_bias = layer_outputs[ - 4 if output_attentions else 3 - ] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + ( - present_key_value_state, - ) - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[3],) - if self.is_decoder: - all_cross_attentions = all_cross_attentions + (layer_outputs[5],) - - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.dropout(hidden_states) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - present_key_value_states, - all_hidden_states, - all_attentions, - all_cross_attentions, - ] - if v is not None - ) - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=present_key_value_states, - hidden_states=all_hidden_states, - attentions=all_attentions, - cross_attentions=all_cross_attentions, - ) - - -class T5ForConditionalGeneration(T5PreTrainedModel): - def __init__(self, config: T5Config, weights): - super().__init__(config) - self.model_dim = config.d_model - - self.shared = TensorParallelEmbedding(prefix="shared", weights=weights) - - encoder_config = copy.deepcopy(config) - encoder_config.is_decoder = False - encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False - self.encoder = T5Stack( - config=encoder_config, - prefix="encoder", - weights=weights, - embed_tokens=self.shared, - ) - - decoder_config = copy.deepcopy(config) - decoder_config.is_decoder = True - decoder_config.is_encoder_decoder = False - decoder_config.num_layers = config.num_decoder_layers - self.decoder = T5Stack( - config=decoder_config, - prefix="decoder", - weights=weights, - embed_tokens=self.shared, - ) - - try: - self.lm_head = SpeculativeHead.load( - config, prefix="lm_head", weights=weights - ) - except RuntimeError: - # Some models like t5-small were saved with shared weights unlike flan - # Since they are declared as the same arch we have no choice but hope - # that this is OK instead of using a proper flag. - self.lm_head = SpeculativeHead.load( - config, prefix="shared", weights=weights - ) - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.BoolTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - decoder_head_mask: Optional[torch.FloatTensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - decoder_inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask - if head_mask is not None and decoder_head_mask is None: - if self.config.num_layers == self.config.num_decoder_layers: - warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) - decoder_head_mask = head_mask - - # Encode if needed (training, first prediction pass) - if encoder_outputs is None: - # Convert encoder inputs in embeddings if needed - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): - encoder_outputs = BaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, - ) - - hidden_states = encoder_outputs[0] - - if ( - labels is not None - and decoder_input_ids is None - and decoder_inputs_embeds is None - ): - # get decoder inputs from shifting lm labels to the right - decoder_input_ids = self._shift_right(labels) - - # Decode - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - inputs_embeds=decoder_inputs_embeds, - past_key_values=past_key_values, - encoder_hidden_states=hidden_states, - encoder_attention_mask=attention_mask, - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = decoder_outputs[0] - - if self.config.tie_word_embeddings: - # Rescale output before projecting on vocab - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 - sequence_output = sequence_output * (self.model_dim**-0.5) - - logits, speculative_logits = self.lm_head(sequence_output) - - loss = None - if labels is not None: - loss_fct = CrossEntropyLoss(ignore_index=-100) - # move labels to correct device to enable PP - labels = labels.to(logits.device) - loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) - # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 - - if not return_dict: - output = (logits,) + decoder_outputs[1:] + encoder_outputs - return ((loss,) + output) if loss is not None else output - - return ( - Seq2SeqLMOutput( - loss=loss, - logits=logits, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ), - speculative_logits, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - decoder_attention_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past is used - if past_key_values is not None: - input_ids = input_ids[:, -1:] - - return { - "decoder_input_ids": input_ids, - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "decoder_attention_mask": decoder_attention_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, - } - - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): - return self._shift_right(labels) - - def _reorder_cache(self, past_key_values, beam_idx): - # if decoder past is not included in output - # speedy decoding is disabled and no need to reorder - if past_key_values is None: - logger.warning( - "You might want to consider setting `use_cache=True` to speed up decoding" - ) - return past_key_values - - reordered_decoder_past = () - for layer_past_states in past_key_values: - # get the correct batch idx from layer past batch dim - # batch dim of `past` is at 2nd position - reordered_layer_past_states = () - for layer_past_state in layer_past_states: - # need to set correct `past` for each of the four key / value states - reordered_layer_past_states = reordered_layer_past_states + ( - layer_past_state.index_select( - 0, beam_idx.to(layer_past_state.device) - ), - ) - - assert reordered_layer_past_states[0].shape == layer_past_states[0].shape - assert len(reordered_layer_past_states) == len(layer_past_states) - - reordered_decoder_past = reordered_decoder_past + ( - reordered_layer_past_states, - ) - return reordered_decoder_past diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py index e5c44045a..ae704af31 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py @@ -4,7 +4,7 @@ def load_text_model(prefix, config, weights, name=None): FlashLlamaForCausalLM, ) - return FlashLlamaForCausalLM(prefix, config, weights) + return FlashLlamaForCausalLM(prefix, config, weights, name=name) elif config.model_type == "mistral": from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( FlashMistralForCausalLM, @@ -16,7 +16,13 @@ def load_text_model(prefix, config, weights, name=None): FlashGemmaForCausalLM, ) - return FlashGemmaForCausalLM(prefix, config, weights, causal=False) + return FlashGemmaForCausalLM(prefix, config, weights) + elif config.model_type == "gemma2": + from text_generation_server.models.custom_modeling.flash_gemma2_modeling import ( + FlashGemma2ForCausalLM, + ) + + return FlashGemma2ForCausalLM(prefix, config, weights) elif config.model_type == "paligemma": from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( FlashGemmaForCausalLM, diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index bc9d44a0b..a4d58596b 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -1,4 +1,3 @@ -from contextlib import nullcontext import math import os import time @@ -16,12 +15,19 @@ from transformers import ( AutoTokenizer, GenerationConfig, ) -from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict - +from typing import ( + Any, + Iterable, + Optional, + Tuple, + List, + Type, + Dict, + Union, +) +import torch.nn.functional as F from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata -from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from text_generation_server.utils.chunks import concat_text_chunks -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models import Model from text_generation_server.utils.log import log_master from text_generation_server.utils.tokens import batch_top_tokens @@ -39,27 +45,34 @@ from text_generation_server.models.types import ( ) from text_generation_server.pb import generate_pb2 from text_generation_server.models.globals import ( - MEM_POOL, - ATTENTION, BLOCK_SIZE, - CUDA_GRAPHS, + REQUEST_LOGPROBS, TGI_WIGGLE_ROOM, get_adapter_to_index, ) -from text_generation_server.layers.attention import Seqlen +from text_generation_server.layers.attention import ( + KVCache, + Seqlen, + HPUPagedAttentionMetadata, + trim_attn_metadata, + trim_seqlen_metadata, +) from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils.dist import MEMORY_FRACTION from text_generation_server.utils.quantization import get_loader from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments - from text_generation_server.utils.import_utils import ( empty_cache, synchronize, get_free_memory, ) -tracer = trace.get_tracer(__name__) +import vllm_hpu_extension.environment as environment +import habana_frameworks.torch as htorch +import itertools +from vllm_hpu_extension.ops import batch2block, block2batch +tracer = trace.get_tracer(__name__) # Will be set in init SLIDING_WINDOW: Optional[int] = None @@ -75,38 +88,75 @@ def get_sliding_windows() -> int: return SLIDING_WINDOW -def init_cpu_threads_env(rank_id: int, world_size: int): - import importlib.util +def prepare_for_decode( + dtype, use_contiguous_pa, device, slot, block_tables, batch_size +): + # Prepare values if we need to continue decoding + # need for HPUPagedAttentionMetadata preparation + def flatten(in_list): + return list(itertools.chain(*in_list)) - if importlib.util.find_spec("numa") is not None: - import numa - import psutil + def gather_list(input, indices, v): + return [input[i] if i is not None else v for i in indices] - nodes = numa.info.get_max_node() + 1 - rank_per_node = math.ceil(world_size / nodes) - num_cpus_per_nodes = int(psutil.cpu_count(logical=False) / nodes) - node_id = int(rank_id / rank_per_node) - rank_offset_per_node = rank_id % rank_per_node - if os.getenv("OMP_NUM_THREADS") is None: - num_cpus_per_rank = max(int(num_cpus_per_nodes / rank_per_node), 1) - else: - num_cpus_per_rank = int(os.getenv("OMP_NUM_THREADS")) - if len(numa.memory.get_membind_nodes()) == nodes: - numa.memory.set_membind_nodes((node_id)) - torch.set_num_threads(num_cpus_per_rank) - if len(numa.schedule.get_affinitive_cpus(0)) == psutil.cpu_count(logical=True): - cpu_start = num_cpus_per_rank * rank_offset_per_node - numa.schedule.run_on_cpus( - 0, - *( - numa.info.node_to_cpus(node_id)[ - cpu_start : cpu_start + num_cpus_per_rank - ] - ), - ) - logger.info( - f"affinity={numa.schedule.get_affinitive_cpus(0)}, membind = {numa.memory.get_membind_nodes()}" + def pad_list(input, k, v): + input_len = len(input) + target_len = (input_len + k - 1) // k * k + padding = target_len - input_len + return input + [v] * padding + + last_block_usage = slot % BLOCK_SIZE + 1 + block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)] + block_usage = [ + [BLOCK_SIZE] * (len(bt) - 1) + [lbu] + for bt, lbu in zip(block_tables, last_block_usage) + if bt + ] + + block_list = flatten(block_tables) + block_groups = flatten(block_groups) + block_usage = flatten(block_usage) + assert len(block_list) == len(block_groups) + assert len(block_list) == len(block_usage) + if use_contiguous_pa: + block_bucket_size = max(max(block_list) + 1, len(block_list)) + # block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks( + # block_bucket_size) + indices: List[Any] + indices = [None] * block_bucket_size + for i, bid in enumerate(block_list): + indices[bid] = i + block_list = gather_list(block_list, indices, 0) + block_groups = gather_list(block_groups, indices, -1) + block_usage = gather_list(block_usage, indices, 1) + else: + block_bucket_size = len(block_list) + block_list = pad_list(block_list, block_bucket_size, 0) + block_groups = pad_list(block_groups, block_bucket_size, -1) + block_usage = pad_list(block_usage, block_bucket_size, 1) + + block_list = torch.tensor(block_list, dtype=torch.int, device=device) + block_groups = torch.tensor(block_groups, dtype=torch.int, device=device) + block_usage = torch.tensor(block_usage, dtype=dtype, device=device) + block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch_size) + mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0) + mask = mask >= block_usage.unsqueeze(-1) + attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf) + ones = torch.ones( + (block_mapping.size(0),), device=device, dtype=block_mapping.dtype + ) + sums = batch2block(block2batch(ones, block_mapping), block_mapping) + block_scales = torch.reciprocal(torch.maximum(ones, sums)) + return trim_attn_metadata( + HPUPagedAttentionMetadata( + block_list=block_list, + block_groups=block_groups, + block_usage=block_usage, + block_mapping=block_mapping.to(dtype), + attn_bias=attn_bias, + block_scales=block_scales, ) + ) @dataclass @@ -117,25 +167,17 @@ class FlashCausalLMBatch(Batch): requests_idx_mapping: Dict[int, int] # Decoder values - input_ids: torch.Tensor - position_ids: torch.Tensor + # Can be a list for easy filtering + # If `input_ids` is a list, it needs to be materialized to a tensor first + input_ids: Union[torch.Tensor, List[List[int]]] + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + position_ids: Optional[torch.Tensor] speculative_ids: Optional[torch.Tensor] - # Flash Attention values - - # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill - cu_seqlen_prefill: Optional[torch.Tensor] - # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers - # as we only keep SLIDING_WINDOW values instead of the whole tensor - prefill_cache_indices: Optional[torch.Tensor] - - # Paged Attention values - # Set when creating the batch - # CPU tensor of length b indicating the start of each sequence in slots - start_slots: torch.Tensor # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode - slot_indices: torch.Tensor + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + slot_indices: Optional[torch.Tensor] # list of length b of list of length s_i // block_size block_tables: List[List[int]] @@ -143,19 +185,32 @@ class FlashCausalLMBatch(Batch): block_tables_tensor: torch.Tensor # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences slots: torch.Tensor - # size [b], containing the number of blocks that can be retrieved from the cache - prefix_lens: List[int] - prefix_lens_tensor: torch.Tensor + # list of length b + 1 containing the cumulative sequence slot lengths of the sequences in the batch + # used for filtering + cu_slots: torch.Tensor - max_seqlen: int + max_input_length: int + max_current_length: int + + # Whether this batch contains at least one request that is prefilling + prefilling: bool + # Whether each request is prefilling + prefilling_mask: List[bool] # Prefill metadata tensors to efficiently compute logprobs + # tensor of length b + 1 containing the cumulative sequence lengths of the sequences in the batch, only used in prefill + cu_seqlen_prefill: Optional[torch.Tensor] + # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers + # as we only keep SLIDING_WINDOW values instead of the whole tensor + prefill_cache_indices: Optional[torch.Tensor] + # Will be set by `generate_token` and reset after each prefill forward prefill_head_indices: Optional[torch.Tensor] + # Will be set by `generate_token` and reset after each prefill forward prefill_next_token_indices: Optional[torch.tensor] + # Will be set by `generate_token` and reset after each prefill forward prefill_cu_outlens: Optional[List[int]] - - # Prefixes - prefix_ids: List[List[int]] + # Will be set by `generate_token` and reset after each prefill forward + prefill_logprob_tokens: List[Optional[Tokens]] # All tokens all_input_ids: List[List[int]] @@ -163,7 +218,14 @@ class FlashCausalLMBatch(Batch): # Lengths of all generations present in the batch input_lengths: List[int] - input_lengths_tensor: torch.Tensor + # size [b], containing the number of blocks that can be retrieved from the cache + cache_lengths: List[int] + prompt_lengths: List[int] + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + input_lengths_tensor: Optional[torch.Tensor] + cache_lengths_tensor: Optional[torch.Tensor] + prompt_lengths_tensor: torch.Tensor + prefix_offsets: List[Optional[int]] read_offsets: List[Optional[int]] @@ -174,19 +236,27 @@ class FlashCausalLMBatch(Batch): top_n_tokens_tensor: torch.Tensor # Adapter metadata for each request - adapter_meta: AdapterBatchMetadata + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + adapter_meta: Optional[AdapterBatchMetadata] # Number of blocks in this batch num_blocks: int # Maximum number of blocks max_blocks: int + hpu_attn_meta: Optional[HPUPagedAttentionMetadata] + def to_pb(self) -> generate_pb2.CachedBatch: return generate_pb2.CachedBatch( id=self.batch_id, request_ids=[r.id for r in self.requests], size=len(self), max_tokens=self.num_blocks * BLOCK_SIZE, + current_tokens=( + sum([len(i) for i in self.input_ids]) + if isinstance(self.input_ids, list) + else len(self.input_ids) + ), ) @classmethod @@ -218,86 +288,67 @@ class FlashCausalLMBatch(Batch): dtype: torch.dtype, device: torch.device, ) -> "FlashCausalLMBatch": - sliding_window = get_sliding_windows() - position_ids = [] - cu_seqlen_prefill = [0] - start_slots = [] - slot_indices = [] - prefill_cache_indices = [] - + cache_lengths = [] input_lengths = [] + prompt_lengths = [] prefix_offsets = [] read_offsets = [] all_input_ids = [] - prefix_ids = [] + all_postfix_ids = [] requests_idx_mapping = {} - - all_prefill_logprobs = True - no_prefill_logprobs = True - prefill_head_indices = [] - prefill_next_token_indices = [] - prefill_cu_outlens = [0] + slots = [] + cu_slots = [0] next_token_chooser_parameters = [] stopping_criterias = [] top_n_tokens = [] - adapter_indices_list = [] - adapter_set = set() - - # Cumulative length - cumulative_length = 0 - cumulative_slot_tokens = 0 - prefill_out_cumulative_length = 0 - num_blocks = 0 - max_seqlen = 0 + max_input_length = 0 + max_current_length = 0 max_length = 0 max_blocks = 0 + cu_blocks = [0] block_tables = [] - slots = [] - prefix_lens = [] + block_tables_ragged = [] # Parse batch for i, (r, tokenized_input) in enumerate( zip(pb.requests, batch_tokenized_inputs) ): + ### XXX: This consumes so much memory on long requests + ### Deactivating it by default seems like the best course. + if not REQUEST_LOGPROBS: + r.prefill_logprobs = False # request id -> idx in list mapping requests_idx_mapping[r.id] = i - orig_input_length = len(tokenized_input) + prompt_length = len(tokenized_input) + prompt_lengths.append(prompt_length) + + cache_length = r.cache_len - prefix_len = r.prefix_len assert ( - prefix_len <= orig_input_length - ), f"Prefix {prefix_len} vs input {orig_input_length}" - if prefix_len == orig_input_length: - assert prefix_len > 0 - prefix_len -= 1 + cache_length <= prompt_length + ), f"Prefix {cache_length} vs input {prompt_length}" + if cache_length == prompt_length: + assert False, "unreachable" - # Commented as it's costly. - # log_master(logger.debug, "Tokenized input ids {tokenized_input}") - prefix_ids.append(tokenized_input[:prefix_len]) - tokenized_input = tokenized_input[prefix_len:] + # `chunk_len` is an optional field in the protobuf + # It is only set if the model support chunking + # Use all the remaining ids + postfix_ids = tokenized_input[cache_length:] + input_length = len(postfix_ids) - input_length = len(tokenized_input) input_lengths.append(input_length) - prefix_offsets.append(input_length - 5) - read_offsets.append(input_length) + prefix_offsets.append(prompt_length - 5) + read_offsets.append(prompt_length) + all_postfix_ids.append(postfix_ids) all_input_ids.append(tokenized_input) - # Position ids - request_position_ids = torch.arange( - prefix_len, orig_input_length, dtype=torch.int32 - ) - position_ids.append(request_position_ids) - - # Add cumulative lengths of all previous inputs - cu_seqlen_prefill.append(cumulative_length + input_length) - next_token_chooser_parameters.append(r.parameters) stopping_criteria = StoppingCriteria.from_pb( @@ -307,22 +358,13 @@ class FlashCausalLMBatch(Batch): stopping_criterias.append(stopping_criteria) top_n_tokens.append(r.top_n_tokens) - ADAPTER_TO_INDEX = get_adapter_to_index() - adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0) - adapter_indices_list.append(torch.full((input_length,), adapter_index)) - adapter_set.add(adapter_index) - # Paged attention # Remove one as the first token des not have a past speculative_length = get_speculate() speculative_length = 0 if speculative_length is None else speculative_length # Tokens that need to be mapped to blocks. - block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length - - # Tokens that need to be mapped to slots. We don't need slots for the - # cached prefix (if present). - slot_tokens = input_length + max_new_tokens - 1 + speculative_length + block_tokens = prompt_length + max_new_tokens - 1 + speculative_length # blocks and slots can be empty (for example in warmup) if not r.blocks: @@ -337,70 +379,30 @@ class FlashCausalLMBatch(Batch): ] else: request_blocks = r.blocks - request_slots = r.slots[ - prefix_len: #: orig_input_length + max_new_tokens + speculative_length - ] + request_slots = r.slots block_tables.append(request_blocks) + block_tables_ragged.extend(request_blocks) + cu_blocks.append(len(block_tables_ragged)) slots.extend(request_slots) - prefix_lens.append(prefix_len) + cu_slots.append(len(slots)) + + cache_lengths.append(cache_length) num_blocks += len(request_blocks) - start_slots.append(cumulative_slot_tokens) - - request_slot_indices = torch.arange( - cumulative_slot_tokens, - cumulative_slot_tokens + input_length, - dtype=torch.int64, - ) - slot_indices.append(request_slot_indices) - - # Create tensor to slice into the kv tensor in prefill - if sliding_window is not None: - request_prefill_cache_indices = torch.arange( - cumulative_length + max(0, input_length - sliding_window), - cumulative_length + input_length, - dtype=torch.int64, - ) - prefill_cache_indices.append(request_prefill_cache_indices) - - all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs - no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs - - if r.prefill_logprobs: - prefill_head_indices.append(request_position_ids + cumulative_length) - prefill_next_token_indices.append( - prefill_out_cumulative_length + input_length - 1 - ) - prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) - prefill_out_cumulative_length += input_length - else: - prefill_head_indices.append( - torch.tensor( - [cumulative_length + input_length - 1], dtype=torch.int32 - ) - ) - prefill_next_token_indices.append(prefill_out_cumulative_length) - prefill_cu_outlens.append(prefill_out_cumulative_length + 1) - prefill_out_cumulative_length += 1 # Update - cumulative_length += input_length - cumulative_slot_tokens += slot_tokens - max_seqlen = max(max_seqlen, input_length) max_blocks = max(max_blocks, len(request_blocks)) + max_input_length = max(max_input_length, input_length) + max_current_length = max(max_current_length, cache_length + input_length) max_length = max( - max_length, input_length + max_new_tokens + speculative_length + max_length, + prompt_length + max_new_tokens + speculative_length, ) - adapter_indices = torch.cat(adapter_indices_list).to( - dtype=torch.int64, device=device - ) - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, dtype, device, tokenizer ) - start_slots = torch.tensor(start_slots, dtype=torch.int64) # Padded all_input_ids_tensor all_input_ids_tensor = np.zeros( @@ -414,103 +416,71 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor, dtype=torch.int64, device=device ) - if len(pb.requests) > 1: - input_ids = np.concatenate(all_input_ids, dtype=np.int64) - position_ids = torch.cat(position_ids) - slot_indices = torch.cat(slot_indices) - if sliding_window is not None: - prefill_cache_indices = torch.cat(prefill_cache_indices) - else: - input_ids = all_input_ids[0] - position_ids = position_ids[0] - slot_indices = slot_indices[0] - if sliding_window is not None: - prefill_cache_indices = prefill_cache_indices[0] - - cu_seqlen_prefill = torch.tensor( - cu_seqlen_prefill, device=device, dtype=torch.int32 - ) - position_ids = position_ids.to(device) - slot_indices = slot_indices.to(device) - prefill_cache_indices = ( - prefill_cache_indices.to(device) if sliding_window is not None else None - ) - input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) - input_lengths_tensor = torch.tensor( - input_lengths, dtype=torch.int32, device=device - ) - - adapter_segments, adapter_segment_indices = find_segments(adapter_indices) - adapter_segments = torch.tensor( - adapter_segments, dtype=torch.int32, device=device - ) - - if all_prefill_logprobs: - prefill_head_indices = None - prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 - elif no_prefill_logprobs: - prefill_head_indices = cu_seqlen_prefill[1:] - 1 - prefill_next_token_indices = None - else: - prefill_head_indices = torch.tensor( - torch.cat(prefill_head_indices), dtype=torch.int64, device=device - ) - prefill_next_token_indices = torch.tensor( - prefill_next_token_indices, dtype=torch.int64, device=device - ) top_n_tokens_tensor = torch.tensor( top_n_tokens, device=device, dtype=torch.int64 ) - slots = torch.tensor(slots, dtype=torch.int64, device=device) - - block_tables_tensor = torch.zeros( - (len(block_tables), max_blocks), dtype=torch.int32, device="cpu" + block_tables_ragged = torch.tensor( + block_tables_ragged, device=device, dtype=torch.int32 ) + cu_blocks = torch.tensor(cu_blocks, device=device, dtype=torch.int64) + block_tables_tensor = torch.empty( + (len(block_tables), max_blocks), + device=device, + dtype=torch.int32, + ) + for i, request_blocks in enumerate(block_tables): block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) - block_tables_tensor = block_tables_tensor.to(device) - prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device) + + prompt_lengths_tensor = torch.tensor( + prompt_lengths, dtype=torch.int32, device=device + ) + + slots = torch.tensor(slots, dtype=torch.int64, device=device) + cu_slots = torch.tensor(cu_slots, dtype=torch.int64) return cls( batch_id=pb.id, requests=pb.requests, requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - prefill_cache_indices=prefill_cache_indices, - start_slots=start_slots, - slot_indices=slot_indices, + input_ids=all_postfix_ids, block_tables=block_tables, block_tables_tensor=block_tables_tensor, - slots=slots, - prefix_lens=prefix_lens, - prefix_lens_tensor=prefix_lens_tensor, - max_seqlen=max_seqlen, - prefill_head_indices=prefill_head_indices, - prefill_next_token_indices=prefill_next_token_indices, - prefill_cu_outlens=prefill_cu_outlens, + cache_lengths=cache_lengths, + max_input_length=max_input_length, + max_current_length=max_current_length, + prefilling=True, + prefilling_mask=[True] * len(pb.requests), + prefill_logprob_tokens=[None] * len(pb.requests), input_lengths=input_lengths, - input_lengths_tensor=input_lengths_tensor, + prompt_lengths=prompt_lengths, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, - prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, top_n_tokens_tensor=top_n_tokens_tensor, num_blocks=num_blocks, max_blocks=max_blocks, - adapter_meta=AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_segment_indices, - ), speculative_ids=None, + prompt_lengths_tensor=prompt_lengths_tensor, + # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` + position_ids=None, + cu_seqlen_prefill=None, + prefill_cache_indices=None, + slot_indices=None, + slots=slots, + cu_slots=cu_slots, + prefill_head_indices=None, + prefill_next_token_indices=None, + prefill_cu_outlens=None, + cache_lengths_tensor=None, + input_lengths_tensor=None, + adapter_meta=None, + hpu_attn_meta=None, ) @classmethod @@ -533,7 +503,7 @@ class FlashCausalLMBatch(Batch): if len(request_ids) == len(self): return self - device = self.input_ids.device + device = self.block_tables_tensor.device # New values after filtering requests_idx_mapping = {} @@ -548,18 +518,23 @@ class FlashCausalLMBatch(Batch): # Create on CPU to only move to GPU once instead of at every copy slot_indices = torch.empty(len(request_ids), dtype=torch.int64) - max_seqlen = 0 + max_input_length = 0 + max_current_length = 0 requests = [] - start_slots = [] block_tables = [] all_input_ids = [] - prefix_ids = [] + input_ids = [] + prompt_lengths = [] input_lengths = [] - prefix_lens = [] + cache_lengths = [] prefix_offsets = [] read_offsets = [] + cu_slots = [0] + + prefilling_mask = [] + prefill_logprob_tokens = [] stopping_criterias = [] top_n_tokens = [] @@ -567,8 +542,8 @@ class FlashCausalLMBatch(Batch): num_blocks = 0 max_blocks = 0 - # Cumulative length - cumulative_max_length = 0 + max_slots = 0 + cumulative_slot_tokens = 0 for i, request_id in enumerate(request_ids): idx = self.requests_idx_mapping[request_id] @@ -577,16 +552,23 @@ class FlashCausalLMBatch(Batch): requests.append(self.requests[idx]) + # Prefilling + request_prefilling = self.prefilling_mask[idx] + prefilling_mask.append(request_prefilling) + # Get length request_input_length = self.input_lengths[idx] - prefix_len = self.prefix_lens[idx] - max_seqlen = max(max_seqlen, request_input_length) + request_cache_length = self.cache_lengths[idx] + max_input_length = max(max_input_length, request_input_length) + max_current_length = max( + max_current_length, request_cache_length + request_input_length + ) all_input_ids.append(self.all_input_ids[idx]) - prefix_ids.append(self.prefix_ids[idx]) + prompt_lengths.append(self.prompt_lengths[idx]) input_lengths.append(request_input_length) - prefix_lens.append(prefix_len) + cache_lengths.append(request_cache_length) prefix_offsets.append(self.prefix_offsets[idx]) read_offsets.append(self.read_offsets[idx]) @@ -594,60 +576,78 @@ class FlashCausalLMBatch(Batch): stopping_criterias.append(stopping_criteria) top_n_tokens.append(self.top_n_tokens[idx]) + prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx]) ADAPTER_TO_INDEX = get_adapter_to_index() adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0) adapter_set.add(adapter_index) - remaining_tokens = ( - stopping_criteria.max_new_tokens - stopping_criteria.current_tokens - ) - request_block_table = self.block_tables[idx] num_blocks += len(request_block_table) block_tables.append(request_block_table) - start_slots.append(cumulative_max_length) - # Copy to tensor (CPU) - slot_indices[i] = cumulative_max_length + request_input_length - 1 + start_slot = self.cu_slots[idx] + end_slot = self.cu_slots[idx + 1] + slot_length = end_slot - start_slot # Set slice - slot_filtering_indices[ - self.start_slots[idx] : self.start_slots[idx] - + request_input_length - + remaining_tokens - - 1 - ] = True + slot_filtering_indices[start_slot:end_slot] = True - cumulative_max_length += request_input_length + remaining_tokens - 1 + cu_slots.append(cumulative_slot_tokens + slot_length) + # Input ids if the request was part of a prefilling batch + # If the batch was decoding we can index into the tensor directly later + if self.prefilling: + input_ids.append(self.input_ids[idx]) + else: + # Copy to tensor (CPU) + slot_indices[i] = cumulative_slot_tokens + request_cache_length + + cumulative_slot_tokens += slot_length max_blocks = max(max_blocks, len(request_block_table)) + max_slots = max(max_slots, slot_length) - # Index into tensors - input_ids = self.input_ids[indices] - position_ids = self.position_ids[indices] - adapter_indices = self.adapter_meta.adapter_indices[indices] all_input_ids_tensor = self.all_input_ids_tensor[indices] block_tables_tensor = self.block_tables_tensor[indices] - input_lengths_tensor = self.input_lengths_tensor[indices] - slots = self.slots[slot_filtering_indices] - prefix_lens_tensor = self.prefix_lens_tensor[indices] next_token_chooser = self.next_token_chooser.filter(indices) top_n_tokens_tensor = self.top_n_tokens_tensor[indices] speculative_ids = ( self.speculative_ids[indices] if self.speculative_ids is not None else None ) + prompt_lengths_tensor = self.prompt_lengths_tensor[indices] - start_slots = torch.tensor(start_slots, dtype=torch.int64) + cu_slots = torch.tensor(cu_slots, dtype=torch.int64) - # Move to GPU now that we have the whole tensor - slot_indices = slot_indices.to(device) + slots = self.slots[slot_filtering_indices] - adapter_segments, adapter_segment_indices = find_segments(adapter_indices) - adapter_segments = torch.tensor( - adapter_segments, dtype=torch.int32, device=device - ) - # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum() + if self.prefilling: + # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` + position_ids = None + slot_indices = None + cache_lengths_tensor = None + input_lengths_tensor = None + adapter_meta = None + else: + # Index into tensors + input_ids = self.input_ids[indices] + position_ids = self.position_ids[indices] + adapter_indices = self.adapter_meta.adapter_indices[indices] + input_lengths_tensor = self.input_lengths_tensor[indices] + cache_lengths_tensor = self.cache_lengths_tensor[indices] + + # Move to GPU now that we have the whole tensor + slot_indices = slot_indices.to(device) + + adapter_segments, adapter_segment_indices = find_segments(adapter_indices) + adapter_segments = torch.tensor( + adapter_segments, dtype=torch.int32, device=device + ) + adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ) return type(self)( batch_id=self.batch_id, @@ -657,24 +657,29 @@ class FlashCausalLMBatch(Batch): position_ids=position_ids, cu_seqlen_prefill=None, prefill_cache_indices=None, - start_slots=start_slots, slot_indices=slot_indices, block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, - max_seqlen=max_seqlen, + cu_slots=cu_slots, + max_input_length=max_input_length, + max_current_length=max_current_length, + prefilling=self.prefilling, + prefilling_mask=prefilling_mask, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, + prefill_logprob_tokens=prefill_logprob_tokens, + prompt_lengths=prompt_lengths, + prompt_lengths_tensor=prompt_lengths_tensor, input_lengths=input_lengths, input_lengths_tensor=input_lengths_tensor, - prefix_lens=prefix_lens, - prefix_lens_tensor=prefix_lens_tensor, + cache_lengths=cache_lengths, + cache_lengths_tensor=cache_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, - prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, @@ -682,12 +687,8 @@ class FlashCausalLMBatch(Batch): num_blocks=num_blocks, max_blocks=max_blocks, speculative_ids=speculative_ids, - adapter_meta=AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_segment_indices, - ), + adapter_meta=adapter_meta, + hpu_attn_meta=None, ) @classmethod @@ -697,74 +698,105 @@ class FlashCausalLMBatch(Batch): requests = [] requests_idx_mapping = {} + prefilling = False num_blocks = 0 total_batch_size = 0 total_slots = 0 max_blocks = 0 max_length = 0 - max_seqlen = 0 + max_input_length = 0 + max_current_length = 0 for b in batches: total_batch_size += len(b) + max_blocks = max(max_blocks, b.max_blocks) total_slots += len(b.slots) num_blocks += b.num_blocks speculative_length = ( b.speculative_ids.shape[1] if b.speculative_ids is not None else 0 ) - max_blocks = max(max_blocks, b.max_blocks) - max_seqlen = max(max_seqlen, b.max_seqlen) + max_input_length = max(max_input_length, b.max_input_length) + max_current_length = max(max_current_length, b.max_current_length) max_length = max( max_length, max( - input_length + prompt_length + stopping_criteria.max_new_tokens + speculative_length - - stopping_criteria.current_tokens - for input_length, stopping_criteria in zip( - b.input_lengths, b.stopping_criterias + for prompt_length, stopping_criteria in zip( + b.prompt_lengths, b.stopping_criterias ) ), ) + prefilling = prefilling or b.prefilling - input_ids = batches[0].input_ids.new_empty(total_batch_size) - position_ids = batches[0].position_ids.new_empty(total_batch_size) slots = batches[0].slots.new_empty(total_slots) - slot_indices = batches[0].slot_indices.new_empty(total_batch_size) - input_lengths_tensor = batches[0].input_lengths_tensor.new_empty( + cu_slots = torch.zeros(total_batch_size + 1, dtype=torch.int64) + if prefilling: + input_ids = [] + # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` + position_ids = None + slot_indices = None + cache_lengths_tensor = None + input_lengths_tensor = None + adapter_meta = None + adapter_segment_builder = None + else: + input_ids = batches[0].input_ids.new_empty(total_batch_size) + if ( + batches[0].position_ids is not None + and batches[0].position_ids.dim() == 2 + ): + # Qwen2_vl case: + position_ids = batches[0].position_ids.new_empty( + (total_batch_size, batches[0].position_ids.shape[-1]) + ) + else: + position_ids = batches[0].position_ids.new_empty(total_batch_size) + slot_indices = batches[0].slot_indices.new_empty(total_batch_size) + input_lengths_tensor = batches[0].input_lengths_tensor.new_empty( + total_batch_size + ) + cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty( + total_batch_size + ) + total_indices_size = sum( + b.adapter_meta.adapter_indices.shape[0] for b in batches + ) + adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty( + total_indices_size + ) + adapter_segment_builder = SegmentConcatBuilder() + adapter_set = set() + + prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty( total_batch_size ) block_tables_tensor = batches[0].block_tables_tensor.new_zeros( (total_batch_size, max_blocks) ) - prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size) all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros( (total_batch_size, max_length) ) top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( total_batch_size, ) - total_indices_size = sum( - b.adapter_meta.adapter_indices.shape[0] for b in batches - ) - adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty( - total_indices_size - ) - adapter_set = set() - adapter_segment_builder = SegmentConcatBuilder() - start_slots = [] block_tables = [] - prefix_lens = [] + cache_lengths = [] all_input_ids = [] - prefix_ids = [] + prompt_lengths = [] input_lengths = [] prefix_offsets = [] read_offsets = [] + prefill_logprob_tokens = [] + next_token_chooser_parameters = [] fsm_grammar_states = [] stopping_criterias = [] top_n_tokens = [] + prefilling_mask = [] # Cumulative length cumulative_batch_size = 0 @@ -783,32 +815,9 @@ class FlashCausalLMBatch(Batch): start_index = cumulative_batch_size end_index = cumulative_batch_size + len(batch) - slots_start_index = cumulative_slots - slots_end_index = cumulative_slots + len(batch.slots) # Copy tensors (GPU) - input_ids[start_index:end_index] = batch.input_ids - position_ids[start_index:end_index] = batch.position_ids - slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots - input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor - slots[slots_start_index:slots_end_index] = batch.slots - - # Copy over adapter indices - adapter_start_index = cumulative_adapter_indices_size - adapter_end_index = ( - cumulative_adapter_indices_size - + batch.adapter_meta.adapter_indices.shape[0] - ) - adapter_indices[adapter_start_index:adapter_end_index] = ( - batch.adapter_meta.adapter_indices - ) - cumulative_adapter_indices_size = adapter_end_index - adapter_set.update(batch.adapter_meta.adapter_set) - adapter_segment_builder.concat( - batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices - ) - all_input_ids_tensor[ start_index:end_index, : batch.all_input_ids_tensor.shape[1] ] = batch.all_input_ids_tensor[:, :max_length] @@ -816,20 +825,56 @@ class FlashCausalLMBatch(Batch): block_tables_tensor[ start_index:end_index, : batch.block_tables_tensor.shape[1] ] = batch.block_tables_tensor[:, :max_blocks] + prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor - prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor + slots_start_index = cumulative_slots + slots_end_index = cumulative_slots + len(batch.slots) + slots[slots_start_index:slots_end_index] = batch.slots + cu_slots[start_index + 1 : end_index + 1] = ( + batch.cu_slots[1:] + cumulative_slots + ) - start_slots.append(batch.start_slots + cumulative_slots) + if not prefilling: + input_ids[start_index:end_index] = batch.input_ids + position_ids[start_index:end_index] = batch.position_ids + slot_indices[start_index:end_index] = ( + batch.slot_indices + cumulative_slots + ) + input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor + cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor + # Copy over adapter indices + adapter_start_index = cumulative_adapter_indices_size + adapter_end_index = ( + cumulative_adapter_indices_size + + batch.adapter_meta.adapter_indices.shape[0] + ) + adapter_indices[adapter_start_index:adapter_end_index] = ( + batch.adapter_meta.adapter_indices + ) + cumulative_adapter_indices_size = adapter_end_index + adapter_set.update(batch.adapter_meta.adapter_set) + adapter_segment_builder.concat( + batch.adapter_meta.adapter_segments, + batch.adapter_meta.segment_indices, + ) + else: + if isinstance(batch.input_ids, torch.Tensor): + batch.input_ids = batch.input_ids.view(-1, 1).tolist() + input_ids.extend(batch.input_ids) + + prefilling_mask.extend(batch.prefilling_mask) block_tables.extend(batch.block_tables) - prefix_lens.extend(batch.prefix_lens) + cache_lengths.extend(batch.cache_lengths) all_input_ids.extend(batch.all_input_ids) - prefix_ids.extend(batch.prefix_ids) + prompt_lengths.extend(batch.prompt_lengths) input_lengths.extend(batch.input_lengths) prefix_offsets.extend(batch.prefix_offsets) read_offsets.extend(batch.read_offsets) + prefill_logprob_tokens.extend(batch.prefill_logprob_tokens) + next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states) stopping_criterias.extend(batch.stopping_criterias) @@ -837,12 +882,8 @@ class FlashCausalLMBatch(Batch): top_n_tokens.extend(batch.top_n_tokens) # Update - cumulative_batch_size += len(batch) cumulative_slots += len(batch.slots) - - start_slots = torch.concat(start_slots) - - # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum() + cumulative_batch_size += len(batch) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, @@ -852,13 +893,21 @@ class FlashCausalLMBatch(Batch): fsm_grammar_states=fsm_grammar_states, ) - speculative_ids = ( - torch.cat([b.speculative_ids for b in batches], dim=0) - if batches[0].speculative_ids is not None - else None - ) + # We skip computing the speculative_ids when the batch size is too large, so + # we must check that all batches have them, otherwise they must be discarded + if get_speculate() > 0 and all(b.speculative_ids is not None for b in batches): + speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0) + else: + speculative_ids = None - adapter_segments, adapter_segment_indices = adapter_segment_builder.build() + if adapter_segment_builder is not None: + adapter_segments, adapter_segment_indices = adapter_segment_builder.build() + adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ) return cls( batch_id=batches[0].batch_id, @@ -868,24 +917,29 @@ class FlashCausalLMBatch(Batch): position_ids=position_ids, cu_seqlen_prefill=None, prefill_cache_indices=None, - start_slots=start_slots, slot_indices=slot_indices, block_tables=block_tables, block_tables_tensor=block_tables_tensor, - prefix_lens=prefix_lens, - prefix_lens_tensor=prefix_lens_tensor, + cache_lengths=cache_lengths, + cache_lengths_tensor=cache_lengths_tensor, slots=slots, - max_seqlen=max_seqlen, + cu_slots=cu_slots, + max_input_length=max_input_length, + max_current_length=max_current_length, + prefilling=prefilling, + prefilling_mask=prefilling_mask, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, + prefill_logprob_tokens=prefill_logprob_tokens, + prompt_lengths=prompt_lengths, + prompt_lengths_tensor=prompt_lengths_tensor, input_lengths=input_lengths, input_lengths_tensor=input_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, - prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, @@ -893,12 +947,286 @@ class FlashCausalLMBatch(Batch): num_blocks=num_blocks, max_blocks=max_blocks, speculative_ids=speculative_ids, - adapter_meta=AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_segment_indices, - ), + adapter_meta=adapter_meta, + hpu_attn_meta=None, + ) + + def prepare_for_decode(self, dtype, use_contiguous_pa): + block_num = self.cache_lengths_tensor // BLOCK_SIZE + 1 + block_tables = [] + for i, bt in enumerate(self.block_tables): + block_tables.append(bt[0 : block_num[i]]) + + self.hpu_attn_meta = prepare_for_decode( + dtype, + use_contiguous_pa, + self.block_tables_tensor.device, + self.slots[self.slot_indices], + block_tables, + self.input_ids.size(0), + ) + + def prepare_for_prefill(self): + # Prepare values if we need to continue prefilling + # Speculation must be ignored while we prefill even with chunking + # it simplifies everything + assert self.speculative_ids is None + + device = self.block_tables_tensor.device + + # hpu does not support varlen for prefill, use sdpa instead. so need to pad input_tensor, position + # padding to left to work with sliding window + # use prefill_cache_indices to indicate the valid kv slot, update prefill_next_token_indices to indicate + # the right logit position + input_ids_padded_length = [] + # need extra pad to match warmup seq + extra_pad = 0 + if isinstance(self.input_ids, list) and len(self) > 1: + input_ids_padded_length = [] + input_ids = [] + for input_id in self.input_ids: + padded = self.max_input_length - len(input_id) + extra_pad + if padded > 0: + input_id = [0] * padded + input_id + input_ids.append(input_id) + input_ids_padded_length.append(padded) + input_ids = np.concatenate(input_ids, dtype=np.int64) + self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + elif isinstance(self.input_ids, list): + input_ids = self.input_ids[0] + input_ids_padded_length.append(extra_pad) + input_ids = [0] * extra_pad + input_ids + self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + else: + self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0) + input_ids_padded_length.append(extra_pad) + + self.input_lengths_tensor = torch.tensor( + self.input_lengths, dtype=torch.int32, device=device + ) + cu_seqlen_prefill = self.input_lengths_tensor.new_zeros(len(self) + 1) + torch.cumsum(self.input_lengths_tensor, out=cu_seqlen_prefill[1:], dim=0) + self.cu_seqlen_prefill = cu_seqlen_prefill.to(torch.int32) + self.cache_lengths_tensor = torch.tensor( + self.cache_lengths, dtype=torch.int32, device=device + ) + + sliding_window = get_sliding_windows() + position_ids = [] + slot_indices = [] + prefill_cache_indices = [] + all_prefill_logprobs = True + no_prefill_logprobs = True + prefill_cu_outlens = [0] + + # Cumulative length + cumulative_length = 0 + cumulative_slot_tokens = 0 + prefill_out_cumulative_length = 0 + + adapter_indices_list = [] + adapter_set = set() + + for i, ( + r, + cache_length, + input_length, + prompt_length, + request_prefilling, + blocks, + ) in enumerate( + zip( + self.requests, + self.cache_lengths, + self.input_lengths, + self.prompt_lengths, + self.prefilling_mask, + self.block_tables, + ) + ): + next_chunk_length = input_length + + # Position ids + request_position_ids = torch.arange( + cache_length, cache_length + input_length, dtype=torch.int32 + ) + request_position_ids = F.pad( + request_position_ids, (input_ids_padded_length[i], 0), value=1 + ) + position_ids.append(request_position_ids) + + if not r.slots: + request_slots = [ + s + for b in blocks + for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE) + ] + else: + request_slots = r.slots + + request_slot_indices = torch.arange( + cache_length + cumulative_slot_tokens, + cache_length + cumulative_slot_tokens + input_length, + dtype=torch.int64, + ) + + slot_indices.append(request_slot_indices) + + # Update + cumulative_slot_tokens += len(request_slots) + + # Create tensor to slice into the kv tensor in prefill + # hpu need request_prefill_cache_indices to skip padding in kv cache + sliding_window = get_sliding_windows() + if sliding_window is None: + sliding_window = input_length + cumulative_length += input_ids_padded_length[i] + if sliding_window is not None: + request_prefill_cache_indices = torch.arange( + cumulative_length + max(0, input_length - sliding_window), + cumulative_length + input_length, + dtype=torch.int64, + ) + + # Prefill logprobs is ignored if the request is done prefilling + prefill_logprobs = r.prefill_logprobs and request_prefilling + + all_prefill_logprobs = all_prefill_logprobs and prefill_logprobs + no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs + + if prefill_logprobs: + prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) + prefill_out_cumulative_length += input_length + else: + prefill_cu_outlens.append(prefill_out_cumulative_length + 1) + prefill_out_cumulative_length += 1 + + prefill_cache_indices.append(request_prefill_cache_indices) + + ADAPTER_TO_INDEX = get_adapter_to_index() + if ADAPTER_TO_INDEX: + adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0) + adapter_indices_list.append( + torch.full((next_chunk_length,), adapter_index) + ) + adapter_set.add(adapter_index) + + # Update + cumulative_length += next_chunk_length + + if not all_prefill_logprobs and not no_prefill_logprobs: + prefill_head_indices = [] + prefill_next_token_indices = [] + + # Cumulative length + cumulative_length = 0 + prefill_out_cumulative_length = 0 + + for i, ( + r, + input_length, + request_prefilling, + ) in enumerate( + zip( + self.requests, + self.input_lengths, + self.prefilling_mask, + ) + ): + # Prefill logprobs is ignored if the request is done prefilling + prefill_logprobs = r.prefill_logprobs and request_prefilling + + if prefill_logprobs: + prefill_head_indices.append( + torch.arange( + cumulative_length, + cumulative_length + input_length, + dtype=torch.int64, + ) + ) + prefill_next_token_indices.append( + prefill_out_cumulative_length + input_length - 1 + ) + prefill_out_cumulative_length += input_length + else: + prefill_head_indices.append( + torch.tensor( + [cumulative_length + input_length - 1], + dtype=torch.int64, + ) + ) + prefill_next_token_indices.append(prefill_out_cumulative_length) + prefill_out_cumulative_length += 1 + + # Update + cumulative_length += input_length + + if len(self) > 1: + if position_ids: + position_ids = torch.cat(position_ids) + if slot_indices: + slot_indices = torch.cat(slot_indices) + prefill_cache_indices = torch.cat(prefill_cache_indices) + else: + if position_ids: + position_ids = position_ids[0] + if slot_indices: + slot_indices = slot_indices[0] + prefill_cache_indices = prefill_cache_indices[0] + + self.position_ids = position_ids.to(device) + self.slot_indices = slot_indices.to(device) + + self.prefill_cu_outlens = prefill_cu_outlens + self.prefill_cache_indices = torch.zeros_like(self.input_ids, dtype=torch.bool) + self.prefill_cache_indices[prefill_cache_indices.to(device)] = True + + if all_prefill_logprobs: + prefill_head_indices = None + prefill_next_token_indices = self.cu_seqlen_prefill[1:] - 1 + elif no_prefill_logprobs: + prefill_head_indices = self.cu_seqlen_prefill[1:] - 1 + prefill_next_token_indices = None + else: + prefill_head_indices = torch.cat(prefill_head_indices).to(device) + prefill_next_token_indices = torch.tensor( + prefill_next_token_indices, dtype=torch.int64, device=device + ) + + self.prefill_head_indices = prefill_head_indices + self.prefill_next_token_indices = prefill_next_token_indices + input_ids_padded_length_tensor = torch.cumsum( + torch.tensor(input_ids_padded_length, dtype=torch.int64, device=device), + dim=-1, + ) + if self.prefill_head_indices is not None: + self.prefill_head_indices = ( + self.prefill_head_indices + input_ids_padded_length_tensor + ) + + if self.prefill_next_token_indices is not None: + self.prefill_next_token_indices = ( + self.prefill_next_token_indices + input_ids_padded_length_tensor + ) + + if adapter_set: + adapter_indices = torch.cat(adapter_indices_list).to( + dtype=torch.int64, device=device + ) + adapter_segments, adapter_segment_indices = find_segments(adapter_indices) + else: + adapter_indices = torch.zeros_like(self.input_ids) + adapter_segments = [0, len(adapter_indices)] + adapter_segment_indices = [len(adapter_indices) - 1] + + adapter_segments = torch.tensor( + adapter_segments, dtype=torch.int32, device=device + ) + self.adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, ) def __len__(self): @@ -937,23 +1265,14 @@ class FlashCausalLM(Model): # Deepseek V2 uses different QK and V dims. head_size: Optional[int] = None, skip_special_tokens: bool = True, + kv_cache_dtype: Optional[torch.dtype] = None, + support_chunking: bool = True, ): self.quantize = quantize self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = default_dtype if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = default_dtype if dtype is None else dtype - else: - device = torch.device("cpu") - # Float16 doesn't exist on target. - dtype = torch.bfloat16 if dtype is None else dtype - init_cpu_threads_env(rank_id=rank, world_size=world_size) - else: - raise NotImplementedError(f"{model_class} is only available on GPU") + + device = torch.device("hpu") + dtype = torch.bfloat16 if dtype is None else dtype tokenizer = tokenizer_class.from_pretrained( model_id, @@ -991,7 +1310,7 @@ class FlashCausalLM(Model): weights_loader=weights_loader, ) - prefix = "" + prefix = None model = model_class(prefix, config, weights) torch.distributed.barrier(group=self.process_group) @@ -1007,6 +1326,7 @@ class FlashCausalLM(Model): self.num_layers = config.num_hidden_layers self.num_heads = config.num_attention_heads // self.process_group.size() + self.config = config # Validation is done in the model itself if num_kv_heads is None: num_kv_heads = getattr(config, "num_key_value_heads", None) @@ -1034,25 +1354,14 @@ class FlashCausalLM(Model): self.cuda_graphs = {} self.kv_cache = [] + self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype - if ATTENTION == "flashinfer": - from text_generation_server.layers.attention.flashinfer import ( - create_prefill_state, - create_decode_state, - create_prefill_with_paged_kv_state, - ) - - self.prefill_state = create_prefill_state(device=device) - self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state( - device=device - ) - - self.decode_state = create_decode_state( - device=device, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - ) - + if htorch.utils.internal.is_lazy(): + htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=False) + environment.set_model_config(self.config) + self.use_contiguous_pa = ( + os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true" + ) super().__init__( model_id=model_id, model=model, @@ -1063,6 +1372,7 @@ class FlashCausalLM(Model): rank=rank, world_size=world_size, sliding_window=config.sliding_window, + support_chunking=support_chunking, ) @property @@ -1083,317 +1393,126 @@ class FlashCausalLM(Model): ): self.kv_cache = [] empty_cache() - - element_size = torch.tensor([], dtype=dtype).element_size() - if SYSTEM == "ipex" and device.type == "xpu": - x = 1 - else: - x = BLOCK_SIZE // element_size - - if ATTENTION in {"flashdecoding", "flashinfer"}: - self.kv_cache = [ - ( - torch.empty( - (num_blocks, BLOCK_SIZE, num_heads, head_size), - dtype=dtype, - device=device, - ), - torch.empty( - (num_blocks, BLOCK_SIZE, num_heads, head_size), - dtype=dtype, - device=device, - ), - ) - for _ in range(num_layers) - ] - elif SYSTEM == "ipex" and device == torch.device("cpu"): - self.kv_cache = [ - ( - torch.empty( - (num_blocks, num_heads, BLOCK_SIZE, head_size), - dtype=dtype, - device=device, - ), - torch.empty( - (num_blocks, num_heads, BLOCK_SIZE, head_size), - dtype=dtype, - device=device, - ), - ) - for _ in range(num_layers) - ] - else: - self.kv_cache = [ - ( - torch.zeros( - (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x), - dtype=dtype, - device=device, - ), - torch.zeros( - (num_blocks, num_heads, head_size, BLOCK_SIZE), - dtype=dtype, - device=device, - ), - ) - for _ in range(num_layers) - ] - - def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): - input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) - position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) - slots = torch.arange(bs, dtype=torch.int64, device=self.device) - input_lengths = [max_s] * bs - prefix_lengths = [0] * bs - input_lengths_tensor = ( - torch.ones(bs, dtype=torch.int32, device=self.device) * max_s - ) - prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device) - block_tables = torch.arange( - max_bt, dtype=torch.int32, device=self.device - ).repeat(bs) - block_tables = block_tables.reshape((bs, max_bt)) - - if ATTENTION == "flashinfer": - block_tables = block_tables_to_ragged( - block_tables=block_tables, - input_lengths=input_lengths, - prefix_lens=prefix_lengths, - ) - from text_generation_server.layers.attention.flashinfer import ( - create_decode_state_cuda_graphs, + self.kv_cache = [ + KVCache( + num_blocks=num_blocks, + num_heads=num_heads, + head_size=head_size, + dtype=dtype, + device=device, ) + for _ in range(num_layers) + ] - block_tables_ptr = torch.zeros( - bs + 1, dtype=torch.int32, device=self.device - ) - last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device) - state = create_decode_state_cuda_graphs( - device=input_ids.device, - block_tables=block_tables, - block_tables_ptr=block_tables_ptr, - last_page_len=last_page_len, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - ) - else: - state = None - - graph = torch.cuda.CUDAGraph() - self.cuda_graphs[bs] = { - "input_ids": input_ids, - "position_ids": position_ids, - "kv_cache": self.kv_cache, - "block_tables": block_tables, - "slots": slots, - "input_lengths": input_lengths_tensor, - "prefix_lengths": prefix_lengths_tensor, - "state": state, - "graph": graph, - } - - torch.cuda.synchronize() - # Run once outside to warmup - with self._forward_context( - block_tables=block_tables, - cu_seqlen_prefill=None, - input_lengths_tensor=input_lengths_tensor, - state=state, - prefix_lens_tensor=prefix_lengths_tensor, - ): - seqlen = Seqlen( - input_lengths=input_lengths_tensor, - prefix_lengths=prefix_lengths_tensor, - cu_seqlen_q=None, - max_q=1, - max_k=max_s, - ) - self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=None, - kv_cache=self.kv_cache, - block_tables=block_tables, - slots=slots, - seqlen=seqlen, - max_s=max_s, - prefill_cache_indices=None, - lm_head_indices=None, - ) - del seqlen - - torch.cuda.synchronize() - - with torch.cuda.graph(graph, pool=MEM_POOL): - seqlen = Seqlen( - input_lengths=input_lengths_tensor, - prefix_lengths=prefix_lengths_tensor, - cu_seqlen_q=None, - max_q=1, - max_k=max_s, - ) - logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=None, - kv_cache=self.kv_cache, - block_tables=block_tables, - slots=slots, - seqlen=seqlen, - max_s=max_s, - prefill_cache_indices=None, - lm_head_indices=None, - ) - self.cuda_graphs[bs]["logits"] = logits - self.cuda_graphs[bs]["speculative_logits"] = speculative_logits - torch.cuda.synchronize() - - def warmup(self, batch: FlashCausalLMBatch): + def warmup( + self, + batch: FlashCausalLMBatch, + max_input_tokens: Optional[int], + max_total_tokens: Optional[int], + ): # The warmup batch is the biggest batch we could ever receive + self.kv_cache = [] empty_cache() + # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) + # Calculate the number of blocks that can be allocated with the free memory + dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size() + cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size + total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size + try: self.init_kv_cache( batch.num_blocks, self.num_layers, self.num_kv_heads, self.head_size, - self.dtype, + self.kv_cache_dtype, self.device, ) - max_bt = batch.max_blocks - max_s = max_bt * BLOCK_SIZE - if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False): - torch.cuda.tunable.tuning_enable(False) - _, batch, _ = self.generate_token(batch) - except torch.cuda.OutOfMemoryError as e: + batch_num_blocks = batch.num_blocks + + num_tokens = batch.to_pb().current_tokens + synchronize(self.device) + free_memory = get_free_memory( + self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM + ) + real_free_memory = get_free_memory(self.device, MEMORY_FRACTION) + log_master( + logger.debug, + f"Free memory {free_memory / 1e9:.2f}GB , (real: {real_free_memory / 1e9:.2f}GB", + ) + + _, _batch, _ = self.generate_token([batch]) + except Exception: raise RuntimeError( - f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. " + f"Not enough memory to handle {num_tokens} prefill tokens. " f"You need to decrease `--max-batch-prefill-tokens`" - ) from e + ) synchronize(self.device) - - # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) - # Calculate the number of blocks that can be allocated with the free memory - dtype_size = torch.tensor([], dtype=self.dtype).element_size() - cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size - total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size - - free_memory = get_free_memory(self.device, MEMORY_FRACTION) - batch_num_blocks = batch.num_blocks if batch is not None else 0 - + free_memory = get_free_memory(self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM) + kv_memory = free_memory num_blocks = ( # Leave 5% for some wiggle room - int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size) + int(kv_memory // total_cache_size) # Add batch.num_blocks as we allocated it above, so it is included in the peak memory. + batch_num_blocks ) - del batch + log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}") + if max_total_tokens is None: + max_total_tokens = sum(batch.cache_lengths) + + if max_input_tokens is None: + max_input_tokens = max_total_tokens - 1 + + del _batch, batch + self.kv_cache = [] + empty_cache() self.init_kv_cache( num_blocks, self.num_layers, self.num_kv_heads, self.head_size, - self.dtype, + self.kv_cache_dtype, self.device, ) + return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens - if SYSTEM == "rocm": - if ( - os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None - or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1" - ): - torch.cuda.tunable.enable() + def warmup_prefill(self, prompt_len: int, bs: int): + input_ids = torch.zeros( + prompt_len, dtype=torch.int64, device=self.device + ).repeat(bs) + position_ids = torch.arange( + prompt_len, dtype=torch.int32, device=self.device + ).repeat(bs) + max_bt = (prompt_len // BLOCK_SIZE + 1) * bs + block_tables = torch.arange( + max_bt, dtype=torch.int32, device=self.device + ).reshape(bs, -1) + slot_acc = [] + for i in range(bs): + slots = [] + for b in block_tables[i]: + slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)) + slot_acc.extend(slots[:prompt_len]) + slots = torch.tensor(slot_acc, dtype=torch.int64, device=self.device) - if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0": - torch.cuda.tunable.tuning_enable(True) - - if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None: - tuning_sequences = [ - int(val) - for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",") - ] - elif CUDA_GRAPHS is not None: - tuning_sequences = CUDA_GRAPHS - else: - tuning_sequences = [1, 2, 3, 4, 5, 6, 7] - - tunableop_filepath = os.path.join( - HUGGINGFACE_HUB_CACHE, - f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv", - ) - - log_master( - logger.info, - f"PyTorch TunableOp is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.", - ) - - torch.cuda.tunable.set_filename( - tunableop_filepath, insert_device_ordinal=False - ) - - if os.path.isfile(tunableop_filepath): - log_master( - logger.info, - f"The file {tunableop_filepath} already exists and will be reused.", - ) - torch.cuda.tunable.read_file(tunableop_filepath) - - os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True) - - for seqlen in tuning_sequences: - log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}") - self.tunableop_warmup(seqlen) - torch.cuda.tunable.write_file(tunableop_filepath) - if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1": - torch.cuda.tunable.tuning_enable(False) - else: - log_master( - logger.info, - "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.", - ) - - if CUDA_GRAPHS: - try: - log_master( - logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}" - ) - # Warmup cuda graphs - for bs in CUDA_GRAPHS: - if self.speculate is None or self.speculate + 1 <= bs: - self.cuda_graph_warmup(bs, max_s, max_bt) - except torch.cuda.OutOfMemoryError: - logger.exception("Decode cuda graph warmup failed") - else: - log_master( - logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})." - ) - - return int(num_blocks * BLOCK_SIZE) - - def tunableop_warmup(self, seqlen: int): - input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device) - position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device) - slots = torch.arange(seqlen, dtype=torch.int64, device=self.device) - - # Dummy value, some models (starcoder2) don't accept `None`. - input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) - prefix_lens_tensor = torch.zeros(seqlen, dtype=torch.int32, device=self.device) - cu_seqlen_prefill = torch.tensor( - [0, seqlen], device=self.device, dtype=torch.int32 + input_lengths = ( + torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len ) - max_s = seqlen + cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device) + cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32) + torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:]) + seqlen = Seqlen( input_lengths=input_lengths, - prefix_lengths=prefix_lens_tensor, + cache_lengths=cache_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, - max_q=1, - max_k=seqlen, ) + lm_head_indices = input_lengths - 1 # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. self.model.forward( @@ -1401,12 +1520,64 @@ class FlashCausalLM(Model): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=self.kv_cache, - block_tables=None, - seqlen=seqlen, slots=slots, - max_s=max_s, + seqlen=trim_seqlen_metadata(seqlen), + lm_head_indices=lm_head_indices, + adapter_data=None, + hpu_attention_meta=None, + ) + + def warmup_decode(self, bs: int, block_num: int): + input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) + position_ids = torch.arange(bs, dtype=torch.int32, device=self.device) + block_tables = torch.arange( + start=1, end=block_num + 1, dtype=torch.int32, device=self.device + ).reshape(bs, -1) + slots = [] + past_len = ( + len(block_tables[0]) * BLOCK_SIZE - 1 + ) # for decode, we only need to pass the past token + # fetch the last blocked to warmup block num + for i in range(bs): + slots.append(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 1) + slots = torch.tensor(slots, dtype=torch.int64, device=self.device) + input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) + cache_lengths_tensor = ( + torch.ones(bs, dtype=torch.int32, device=self.device) * past_len + ) + cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32) + torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:]) + + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) + block_num = cache_lengths_tensor // BLOCK_SIZE + 1 + block_tables_valid = [] + for i, bt in enumerate(block_tables.tolist()): + block_tables_valid.append(bt[0 : block_num[i]]) + + hpu_attention_meta = prepare_for_decode( + self.dtype, + self.use_contiguous_pa, + self.device, + slots, + block_tables_valid, + bs, + ) + + # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. + self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=self.kv_cache, + slots=slots, + seqlen=trim_seqlen_metadata(seqlen), lm_head_indices=None, - prefill_cache_indices=None, + adapter_data=None, + hpu_attention_meta=hpu_attention_meta, ) def forward( @@ -1421,7 +1592,7 @@ class FlashCausalLM(Model): block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor - max_s = batch.max_seqlen + max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices speculative_ids = batch.speculative_ids @@ -1436,12 +1607,20 @@ class FlashCausalLM(Model): new_position_ids = ( position_ids.unsqueeze(-1).expand(B, new_length) + arange ).view(-1) - slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) + + # Slots can be discontiguous when prefix caching is enabled, so we need to expand the slot_indices, + # then update the slots with the additional indices to ensure we're grabbing the ones that have been + # allocated + slot_indices = ( + batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int + ).view(-1) + slots = batch.slots[slot_indices] + input_lengths = ( input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int ).view(-1) - prefix_lens_tensor = ( - batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) + cache_lengths_tensor = ( + batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length) ).reshape(-1) # Add Copy the block tables for all members @@ -1463,8 +1642,8 @@ class FlashCausalLM(Model): block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor - prefix_lens_tensor = batch.prefix_lens_tensor - max_s = batch.max_seqlen + cache_lengths_tensor = batch.cache_lengths_tensor + max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices if cu_seqlen_prefill is None and self.max_past() is not None: @@ -1473,105 +1652,48 @@ class FlashCausalLM(Model): # This makes sure the max_s for the decode pass is correct. max_s = min(self.max_past(), max_s) - bs = input_ids.shape[0] - sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) - if sorted_padded_bs: - # Get associated cuda graph - cuda_graph = self.cuda_graphs[sorted_padded_bs[0]] - else: - cuda_graph = None - - if cu_seqlen_prefill is not None or cuda_graph is None: - if ATTENTION == "flashinfer": - block_tables = block_tables_to_ragged( - block_tables=block_tables, - input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, - ) - with self._forward_context( - block_tables=block_tables, - cu_seqlen_prefill=cu_seqlen_prefill, - input_lengths_tensor=input_lengths, - prefix_lens_tensor=prefix_lens_tensor, - ): - max_k = (input_lengths + prefix_lens_tensor).max().item() - seqlen = Seqlen( - input_lengths=input_lengths, - prefix_lengths=prefix_lens_tensor, - cu_seqlen_q=cu_seqlen_prefill, - max_q=max_s, - max_k=max_k, - ) - logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, - block_tables=block_tables, - slots=slots, - seqlen=seqlen, - max_s=max_s, - prefill_cache_indices=batch.prefill_cache_indices, - lm_head_indices=lm_head_indices, - adapter_data=adapter_data, - ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None - return logits, speculative_logits - - # Copy inputs to the static inputs of the cuda graph - # Static inputs are potentially padded - cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids - cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids - if ATTENTION == "flashinfer": - block_tables = block_tables_to_ragged( - block_tables=block_tables, - input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, - ) - # assert block_tables.shape[0] >= slots.shape[0] - cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables - else: - cuda_graph["block_tables"][ - : block_tables.shape[0], : block_tables.shape[1] - ] = block_tables - - # XXX: This is working only because block 0 is reserved for the healthcheck - # so it doesn't matter if we override it with bogus values. - cuda_graph["slots"].fill_(0) - cuda_graph["slots"][: slots.shape[0]] = slots - cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths - cuda_graph["prefix_lengths"].zero_() - cuda_graph["prefix_lengths"][: prefix_lens_tensor.shape[0]] = prefix_lens_tensor - - with self._forward_context( - block_tables=cuda_graph["block_tables"], - cu_seqlen_prefill=None, - input_lengths_tensor=cuda_graph["input_lengths"], - prefix_lens_tensor=cuda_graph["prefix_lengths"], - state=cuda_graph["state"], - ): - # Replay the graph - cuda_graph["graph"].replay() - - # Slice output to the correct shape - speculative_logits = ( - cuda_graph["speculative_logits"][:bs] - if cuda_graph["speculative_logits"] is not None - else None + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) + kwargs = {} + if htorch.utils.internal.is_lazy(): + kwargs["bypass_hpu_graphs"] = False + if batch.prefill_cache_indices is not None: + slots_pad = torch.zeros_like(input_ids) + slots_pad[batch.prefill_cache_indices] = slots + slots = slots_pad + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + slots=slots, + seqlen=trim_seqlen_metadata(seqlen), + lm_head_indices=lm_head_indices, + # TODO not support adapter now, need the add in the future + adapter_data=None, + hpu_attention_meta=batch.hpu_attn_meta, + **kwargs, ) - logits = cuda_graph["logits"][:bs] return logits, speculative_logits @tracer.start_as_current_span("generate_token") def generate_token( - self, batch: FlashCausalLMBatch + self, batches: List[FlashCausalLMBatch] ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]: + if len(batches) > 1: + batch = self.batch_type.concatenate(batches) + else: + batch = batches[0] start = time.time_ns() - prefill = batch.cu_seqlen_prefill is not None + prefill = batch.prefilling + if prefill: + batch.prepare_for_prefill() + else: + batch.prepare_for_decode(self.dtype, self.use_contiguous_pa) prefill_logprobs = batch.prefill_next_token_indices is not None - # Update adapter indices for speculative tokens (if present) adapter_meta = batch.adapter_meta if batch.speculative_ids is not None: @@ -1611,13 +1733,23 @@ class FlashCausalLM(Model): if prefill_logprobs else speculative_logits ) - next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty( - len(batch) - ) - + if len(batch) > 1 and prefill_logprobs: + # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs + # When batch == 1, we will just use the batch.input_ids values directly + prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) else: + prefill_logprobs = None next_token_logits = out - next_adapter_indices = batch.adapter_meta.adapter_indices + + finished_prefilling = True + next_chunk_lengths = [] + current_prefilling_mask = batch.prefilling_mask + if prefill: + finished_prefilling = True + next_prefilling_mask = [False] * len(batch) + + batch.prefilling = not finished_prefilling + batch.prefilling_mask = next_prefilling_mask speculate = get_speculate() ( @@ -1627,7 +1759,7 @@ class FlashCausalLM(Model): accepted_ids, speculative_ids, ) = batch.next_token_chooser( - batch.all_input_ids_tensor[:, : batch.max_seqlen], + batch.all_input_ids_tensor[:, : batch.max_current_length], next_token_logits, speculate, batch.speculative_ids, @@ -1638,85 +1770,110 @@ class FlashCausalLM(Model): batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids ) - if prefill: - if len(batch) > 1 and prefill_logprobs: - # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs - # When batch == 1, we will just use the batch.input_ids values directly - prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) + # Since we are done prefilling, all the tensors that were concatenating values for all the requests + # instantly become of shape [BATCH_SIZE] + if prefill and finished_prefilling: + indices = batch.cu_seqlen_prefill[1:] - 1 + # pad in left + if batch.prefill_cache_indices is not None: + batch.position_ids = batch.position_ids[batch.prefill_cache_indices][ + indices + ] + else: + batch.position_ids = batch.position_ids[indices] - next_position_ids = batch.position_ids.new_empty(len(batch)) - batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1] - # We do not need cu_seqlen_prefill anymore - batch.cu_seqlen_prefill = None - else: - prefill_logprobs = None - next_position_ids = batch.position_ids - - # Cumulative length - cumulative_length = 0 - - # Results - generations: List[Generation] = [] - stopped = True + batch.slot_indices = batch.slot_indices[indices] + batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[ + indices + ] # Zipped iterator - iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids) + iterator = zip( + batch.requests, + batch.prompt_lengths, + batch.cache_lengths, + batch.input_lengths, + batch.all_input_ids, + accepted_ids, + current_prefilling_mask, + batch.prefilling_mask, + ) # We do two for loops as the first one can run completely asynchronously from the GPU while for the second - # one, we need to first do a GPU <-> CPU sync + # one, we need to first do a HPU <-> CPU sync # It is faster if we delay this sync for the maximum amount of time # For each member of the batch - index = 0 - for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator): - # Indexing metadata - start_index = cumulative_length - end_index = cumulative_length + input_length - - if prefill: + # Cumulative length + cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1) + torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:]) + cumulative_length = 0 + for i, ( + request, + prompt_length, + cache_length, + input_length, + all_input_ids, + n_accepted_ids, + request_was_prefilling, + request_is_prefilling, + ) in enumerate(iterator): + # Used to gather prefill logprobs + # Copy batch.all_input_ids_tensor to prefill_token_indices + if request.prefill_logprobs and request_was_prefilling: # Indexing metadata out_start_index = batch.prefill_cu_outlens[i] out_end_index = batch.prefill_cu_outlens[i + 1] - out_length = out_end_index - out_start_index - # Initialize position_ids - # In decode, we do not need this as we can just increment position ids - next_position_ids[i] = batch.position_ids[end_index - 1] - - # Initialize adapter indices - # In decode, we only have one token per row in the batch, so grab last index - next_adapter_indices[i] = batch.adapter_meta.adapter_indices[ - end_index - 1 + # Logprobs generated by the model are for the next token + # So we need to translate the id tensor by 1 + ids = batch.all_input_ids_tensor[ + i, cache_length + 1 : cache_length + input_length + 1 ] + if len(batch) > 1: + prefill_tokens_indices[out_start_index:out_end_index] = ids + else: + # Set prefill_tokens_indices to the correct slice + prefill_tokens_indices = ids - # Used to gather prefill logprobs - # Copy batch.input_ids to prefill_token_indices - if prefill_logprobs: - if len(batch) > 1: - prefill_tokens_indices[out_start_index : out_end_index - 1] = ( - batch.input_ids[start_index + 1 : start_index + out_length] - ) - else: - # Set prefill_tokens_indices to the correct slice - prefill_tokens_indices = batch.input_ids[ - start_index + 1 : start_index + out_length - ] - - for j in range(n_accepted_ids): - batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index] - index += 1 - + # If the device does not support triton, we copy one by one + if not request_is_prefilling: + # Only save tokens if we are done prefilling for this request + batch.all_input_ids_tensor[ + i, + batch.cache_lengths_tensor[i] + + batch.input_lengths[i] : batch.cache_lengths_tensor[i] + + batch.input_lengths[i] + + accepted_ids[i], + ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]] cumulative_length += input_length # Update values - batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] - batch.speculative_ids = speculative_ids - batch.position_ids = next_position_ids + accepted_ids - batch.input_lengths_tensor += accepted_ids - batch.slot_indices += accepted_ids - batch.adapter_meta.adapter_indices = next_adapter_indices + # These values can be updated without a HPU -> CPU sync + if not prefill or (prefill and finished_prefilling): + batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1] + batch.speculative_ids = speculative_ids + if batch.position_ids.dim() == 2: + # Qwen2_vl case: + batch.position_ids += accepted_ids.unsqueeze(-1) + else: + batch.position_ids += accepted_ids + batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1 + batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor) + batch.slot_indices += accepted_ids - if prefill: + if prefill and prefill_logprobs: + # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size)) + torch.log_softmax(out, -1, out=out) + prefill_logprobs_tensor = out + prefill_logprobs = torch.gather( + prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1) + ) + # HPU <-> CPU sync + prefill_logprobs = prefill_logprobs.view(-1).tolist() + + # Does a HPU <-> CPU sync internally + if prefill and finished_prefilling: # adjust segment lengths to account for all request lengths being 1 during decoding adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices) batch.adapter_meta.adapter_segments = torch.tensor( @@ -1725,192 +1882,282 @@ class FlashCausalLM(Model): device=batch.adapter_meta.adapter_segments.device, ) - if prefill and prefill_logprobs: - # Get prefill logprobs - prefill_logprobs_tensor = torch.log_softmax(out, -1) - prefill_logprobs = torch.gather( - prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1) - ) - # GPU <-> CPU sync - prefill_logprobs = prefill_logprobs.view(-1).tolist() - - # GPU <-> CPU sync + # HPU <-> CPU sync next_token_logprobs = next_token_logprobs.tolist() next_token_ids = next_input_ids.tolist() accepted_ids = accepted_ids.tolist() + + # Update values if we need to continue prefilling + # This represents the `else` case of the `Update values` if above + # but since this require the `next_token_ids` to be on CPU, it is better to do it here + if prefill and not finished_prefilling: + # Speculation must be ignored while we prefill even with chunking + # it simplifies everything + assert batch.speculative_ids is None + + all_postfix_ids = [] + for i, ( + request_prefilling, + next_token_id, + all_input_ids, + cache_length, + input_length, + next_chunk_length, + ) in enumerate( + zip( + batch.prefilling_mask, + next_token_ids, + batch.all_input_ids, + batch.cache_lengths, + batch.input_lengths, + next_chunk_lengths, + ) + ): + if request_prefilling: + next_cache_length = cache_length + input_length + # Get new prompt IDs to prefill + postfix_ids = all_input_ids[ + next_cache_length : next_cache_length + next_chunk_length + ] + else: + # This request is done prefilling, the new id is the one selected the sampling method + postfix_ids = [next_token_id] + + all_postfix_ids.append(postfix_ids) + + batch.input_ids = all_postfix_ids + start_decode = time.time_ns() + # Results + generations: List[Generation] = [] + stopped = True + # Zipped iterator iterator = zip( batch.requests, + batch.prompt_lengths, + batch.cache_lengths, batch.input_lengths, batch.prefix_offsets, batch.read_offsets, batch.stopping_criterias, batch.all_input_ids, - batch.prefix_ids, batch.next_token_chooser.do_sample, batch.next_token_chooser.seeds, batch.top_n_tokens, + current_prefilling_mask, + batch.prefilling_mask, accepted_ids, batch_top_token_ids, batch_top_token_logprobs, ) + # Reset max_input_length + batch.max_input_length = 0 # For each member of the batch index = 0 for i, ( request, + prompt_length, + cache_length, input_length, prefix_offset, read_offset, stopping_criteria, all_input_ids, - prefix_ids, do_sample, seed, top_n_tokens, + request_was_prefilling, + request_is_prefilling, n_accepted_ids, top_token_ids, top_token_logprobs, ) in enumerate(iterator): - # Append next token to all tokens - next_token_texts = [] - left = 0 - - if n_accepted_ids > 1: - log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}") - - current_stopped = False - for j in range(index, index + n_accepted_ids): - # Generated token - next_token_id = next_token_ids[j] - all_input_ids.append(next_token_id) - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids, - prefix_offset, - read_offset, - ) - next_token_texts.append(next_token_text) - - stop, reason = stopping_criteria( - next_token_id, - next_token_text, - ) - - if stop: - left = index + n_accepted_ids - j - 1 - current_stopped = True - break - else: - current_stopped = False - stopped = stopped and current_stopped - - _next_token_ids = next_token_ids[index : index + n_accepted_ids - left] - _next_token_logprobs = next_token_logprobs[ - index : index + n_accepted_ids - left - ] - index += n_accepted_ids - - # Shard generations - # All generations will be appended in the rust sharded client - if i % self.world_size == self.rank: - if stop: - # Decode generated tokens - output_text, _, _ = self.decode_token( - all_input_ids, - prefix_offset=len(all_input_ids) - - stopping_criteria.current_tokens - - 1, - read_offset=len(all_input_ids) - - stopping_criteria.current_tokens, - skip_special_tokens=True, - ) - generated_text = GeneratedText( - output_text, - stopping_criteria.current_tokens, - reason, - seed if do_sample else None, - ) - else: - generated_text = None - + # Compute logprobs first as, even though we might skip the token, + # it can still be required to compute the logprobs + # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need + # this state to be stable + if request.id % self.world_size == self.rank: # Prefill - if prefill and request.prefill_logprobs: + if request_was_prefilling and request.prefill_logprobs: out_start_index = batch.prefill_cu_outlens[i] out_end_index = batch.prefill_cu_outlens[i + 1] + if not request_is_prefilling: + # The request is dones prefilling, meaning that we started generating new tokens + # The last logprob is a logprob for a generated token that was not part of the prompt + # We need to remove it + out_end_index -= 1 + + request_prefill_logprobs = prefill_logprobs[ + out_start_index:out_end_index + ] + # Logprobs generated by the model are for the next token + # So we need to translate the id tensor by 1 + prefill_token_ids = all_input_ids[ + cache_length + 1 : cache_length + input_length + 1 + ] + + past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i] + + if past_prefill_logprob_tokens is None: + # add nan for cached prompt tokens/first token + request_prefill_logprobs = [float("nan")] * ( + cache_length + 1 + ) + request_prefill_logprobs + prefill_token_ids = ( + all_input_ids[: cache_length + 1] + prefill_token_ids + ) - # Remove generated token to only have prefill and add nan for first prompt token - request_prefill_logprobs = ( - [float("nan")] * (len(prefix_ids) + 1) - ) + prefill_logprobs[out_start_index : out_end_index - 1] - prefill_token_ids = all_input_ids[:-1] prefill_texts = self.tokenizer.batch_decode( - prefix_ids + prefill_token_ids, + prefill_token_ids, clean_up_tokenization_spaces=False, skip_special_tokens=False, ) - prefill_tokens = Tokens( - prefix_ids + prefill_token_ids, + prefill_logprob_tokens = Tokens( + prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special=[], ) - else: - prefill_tokens = None - - if top_n_tokens > 0: - all_top_tokens = [] - for top_token_ids, top_token_logprobs in zip( - top_token_ids, top_token_logprobs - ): - toptoken_texts = self.tokenizer.batch_decode( - top_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, + if past_prefill_logprob_tokens is not None: + prefill_logprob_tokens = ( + past_prefill_logprob_tokens + prefill_logprob_tokens ) - special_toptokens = [ - token_id in self.all_special_ids - for token_id in top_token_ids - ] - top_tokens = Tokens( - top_token_ids, - top_token_logprobs, - toptoken_texts, - special_toptokens, - ) - all_top_tokens.append(top_tokens) - top_tokens = all_top_tokens + + batch.prefill_logprob_tokens[i] = prefill_logprob_tokens else: - top_tokens = None + batch.prefill_logprob_tokens[i] = None - generation = Generation( - request.id, - prefill_tokens, - Tokens( - _next_token_ids, - _next_token_logprobs, - next_token_texts, - [nid in self.all_special_ids for nid in _next_token_ids], - ), - generated_text, - top_tokens, - ) + # If it is, the tokens we decoded should be ignored + if request_is_prefilling: + # Make sure that we do not stop as even though this request did not create a token, it is still + # processing + stopped = False + new_input_length = next_chunk_lengths[i] + new_cache_length = cache_length + input_length + else: + new_input_length = 1 + new_cache_length = cache_length + input_length + n_accepted_ids - 1 + # Append next token to all tokens + next_token_texts = [] + left = 0 - generations.append(generation) + if n_accepted_ids > 1: + log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}") - # accept each new token for this specific request since we may - # have more than one new token per request with speculative decoding - for next_token_id in _next_token_ids: - batch.next_token_chooser = ( - batch.next_token_chooser.advance_grammar_single(i, next_token_id) - ) + current_stopped = False + for j in range(index, index + n_accepted_ids): + # Generated token + next_token_id = next_token_ids[j] + all_input_ids.append(next_token_id) + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids, + prefix_offset, + read_offset, + ) + next_token_texts.append(next_token_text) + + stop, reason = stopping_criteria( + next_token_id, + next_token_text, + ) + + if stop: + left = index + n_accepted_ids - j - 1 + current_stopped = True + break + else: + current_stopped = False + stopped = stopped and current_stopped + + _next_token_ids = next_token_ids[index : index + n_accepted_ids - left] + _next_token_logprobs = next_token_logprobs[ + index : index + n_accepted_ids - left + ] + + # Shard generations + # All generations will be appended in the rust sharded client + if request.id % self.world_size == self.rank: + if stop: + # Decode generated tokens + output_text, _, _ = self.decode_token( + all_input_ids, + prefix_offset=len(all_input_ids) + - stopping_criteria.current_tokens + - 1, + read_offset=len(all_input_ids) + - stopping_criteria.current_tokens, + skip_special_tokens=True, + ) + generated_text = GeneratedText( + output_text, + stopping_criteria.current_tokens, + reason, + seed if do_sample else None, + ) + else: + generated_text = None + + if top_n_tokens > 0: + all_top_tokens = [] + for top_token_ids, top_token_logprobs in zip( + top_token_ids, top_token_logprobs + ): + toptoken_texts = self.tokenizer.batch_decode( + top_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + special_toptokens = [ + token_id in self.all_special_ids + for token_id in top_token_ids + ] + top_tokens = Tokens( + top_token_ids, + top_token_logprobs, + toptoken_texts, + special_toptokens, + ) + all_top_tokens.append(top_tokens) + top_tokens = all_top_tokens + else: + top_tokens = None + + generation = Generation( + request.id, + batch.prefill_logprob_tokens[i], + Tokens( + _next_token_ids, + _next_token_logprobs, + next_token_texts, + [nid in self.all_special_ids for nid in _next_token_ids], + ), + generated_text, + top_tokens, + ) + + generations.append(generation) + + # accept each new token for this specific request since we may + # have more than one new token per request with speculative decoding + for next_token_id in _next_token_ids: + batch.next_token_chooser = ( + batch.next_token_chooser.advance_grammar_single( + i, next_token_id + ) + ) # Update values - batch.input_lengths[i] = input_length + n_accepted_ids - if batch.input_lengths[i] > batch.max_seqlen: - batch.max_seqlen = batch.input_lengths[i] + index += n_accepted_ids + batch.cache_lengths[i] = new_cache_length + batch.max_input_length = max(batch.max_input_length, new_input_length) + batch.input_lengths[i] = new_input_length + current_length = new_cache_length + new_input_length + batch.max_current_length = max(batch.max_current_length, current_length) + batch.prefix_offsets[i] = prefix_offset batch.read_offsets[i] = read_offset batch.all_input_ids[i] = all_input_ids @@ -1921,83 +2168,14 @@ class FlashCausalLM(Model): decode_ns = time.time_ns() - start_decode return generations, None, (forward_ns, decode_ns) - batch.prefill_cu_outlens = None - batch.prefill_head_indices = None - batch.prefill_next_token_indices = None + if prefill and finished_prefilling: + # We do not need prefill tensors anymore + batch.cu_seqlen_prefill = None + batch.prefill_cache_indices = None + batch.prefill_cu_outlens = None + batch.prefill_head_indices = None + batch.prefill_next_token_indices = None forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode return generations, batch, (forward_ns, decode_ns) - - def _forward_context( - self, - *, - block_tables: torch.Tensor, - cu_seqlen_prefill: Optional[torch.Tensor], - input_lengths_tensor: torch.Tensor, - prefix_lens_tensor: torch.Tensor, - state: Optional[Any] = None, - ) -> ContextManager: - if ATTENTION != "flashinfer": - return nullcontext() - - from text_generation_server.layers.attention.flashinfer import ( - use_decode_state, - use_prefill_with_paged_kv_state, - ) - - # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens) - - if cu_seqlen_prefill is not None: - return use_prefill_with_paged_kv_state( - state=( - state if state is not None else self.prefill_with_paged_kv_state - ), - # block_tables=block_tables_to_ragged( - # block_tables=block_tables, - # input_lengths=input_lengths, - # prefix_lens=prefix_lens, - # ), - block_tables=block_tables, - cu_seqlens=cu_seqlen_prefill, - input_lengths=input_lengths_tensor + prefix_lens_tensor, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - head_size=self.head_size, - page_size=BLOCK_SIZE, - dtype=self.dtype, - window_left=self.sliding_window, - ) - else: - assert input_lengths_tensor is not None - return use_decode_state( - state=state if state is not None else self.decode_state, - input_lengths=input_lengths_tensor + prefix_lens_tensor, - block_tables=block_tables, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - head_size=self.head_size, - page_size=BLOCK_SIZE, - dtype=self.dtype, - window_left=self.sliding_window, - ) - - -def block_tables_to_ragged( - *, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int] -) -> torch.Tensor: - """Convert block table to ragged format compatible with FlashInfer.""" - assert len(input_lengths) == len(prefix_lens) - - total_len = sum(input_lengths) + sum(prefix_lens) - block_tables_ragged = torch.empty( - total_len, dtype=torch.int32, device=block_tables.device - ) - - offset = 0 - for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)): - seq_len = prefix_len + input_length - block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len] - offset += seq_len - - return block_tables_ragged diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py new file mode 100644 index 000000000..208ab3582 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -0,0 +1,489 @@ +import torch +from PIL import Image +from io import BytesIO + +from opentelemetry import trace +from typing import Iterable, Optional, Tuple, List, Type, Dict + +from transformers import PreTrainedTokenizerBase +from transformers.image_processing_utils import select_best_resolution +from text_generation_server.pb import generate_pb2 +from text_generation_server.models.flash_causal_lm import ( + FlashCausalLMBatch, + FlashCausalLM, +) +from text_generation_server.models.globals import PREFIX_CACHING +from loguru import logger +from text_generation_server.utils.log import log_master +from transformers import AutoProcessor +from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata +import habana_frameworks.torch as htorch + +tracer = trace.get_tracer(__name__) + +IDEFICS2_FAKE_TOKEN = "" +IDEFICS2_IMAGE_TOKEN = "" + +IDEFICS3_IMAGE_TOKEN = "" +IDEFICS3_FAKE_IMAGE_TOKEN = "" +IDEFICS3_GLOBAL_IMG_TOKEN = "" + + +# copied from: https://github.com/huggingface/transformers/blob/02ed609285c2448b3b54c31e362f2c389fa952ab/src/transformers/models/idefics3/processing_idefics3.py#L44-L60 +def _prompt_split_image( + *, + image_seq_len: int, + image_rows: int, + image_cols: int, + fake_token_around_image: str, + image_token: str, + global_img_token: str, +): + """Prompt with expanded image tokens for when the image is split into patches.""" + text_split_images = "" + for n_h in range(image_rows): + for n_w in range(image_cols): + text_split_images += ( + f"{fake_token_around_image}" + + f"" + + f"{image_token}" * image_seq_len + ) + text_split_images += "\n" + + text_split_images += ( + f"\n{fake_token_around_image}" + + f"{global_img_token}" + + f"{image_token}" * image_seq_len + + f"{fake_token_around_image}" + ) + return text_split_images + + +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + """ + Calculate the shape of the image patch grid after the preprocessing for images of any resolution. + + Args: + image_size (`tuple`): + The size of the input image in the format (height, width). + grid_pinpoints (`List`): + A list containing possible resolutions. Each item in the list should be a tuple or list + of the form `(height, width)`. + patch_size (`int`): + The size of each image patch. + + Returns: + tuple: The shape of the image patch grid in the format (width, height). + """ + if not isinstance(grid_pinpoints, list): + raise ValueError("grid_pinpoints should be a list of tuples or lists") + + height, width = select_best_resolution(image_size, grid_pinpoints) + return height // patch_size, width // patch_size + + +def image_text_replacement(processor, image_input, config, image_id: int) -> str: + if config.model_type == "idefics2": + image_seq_len = 64 + image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}" + if processor.image_processor.do_image_splitting: + image_str *= 5 + return image_str + if config.model_type == "idefics3": + # TODO: implement this in a more general way + n_rows = image_input["rows"][0][image_id] + n_cols = image_input["cols"][0][image_id] + image_seq_len = int( + ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) + / (config.scale_factor**2) + ) + image_str = _prompt_split_image( + image_seq_len=image_seq_len, + image_rows=n_rows, + image_cols=n_cols, + fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN, + image_token=IDEFICS3_IMAGE_TOKEN, + global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN, + ) + return image_str + elif config.model_type == "llava_next": + height, width = image_input["image_sizes"][image_id] + num_features = get_number_of_features(height, width, config) + + log_master( + logger.info, + f"Found {num_features} features in image of resolution {height}x{width}", + ) + return "" * num_features + + elif config.model_type == "paligemma": + return "" * config.text_config.num_image_tokens + elif config.model_type == "qwen2_vl": + grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id] + num_pads = grid_t * grid_h * grid_w // 4 + padding = "<|image_pad|>" * num_pads + return f"<|vision_start|>{padding}<|vision_end|>" + elif config.model_type == "qwen2_5_vl": + grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id] + num_pads = grid_t * grid_h * grid_w // 4 + padding = "<|image_pad|>" * num_pads + return f"<|vision_start|>{padding}<|vision_end|>" + elif config.model_type == "gemma3": + # TODO: get correct number of features via reviewing the Gemma3 architecture + # and calculating the number of image tokens + num_pads = 256 + padding = "" * num_pads + return f"\n\n{padding}\n\n" + else: + raise RuntimeError(f"Unknown config {config.model_type} for multimodal") + + +def image_text_replacement_fixup(config, text: str) -> str: + if config.model_type == "idefics2": + return text.replace( + f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN + ) + return text + + +def get_unpadded_features( + original_height: int, + original_width: int, + npatches: int, + num_patch_height: int, + num_patch_width: int, +) -> Tuple[int, int]: + current_height = npatches * num_patch_height + current_width = npatches * num_patch_width + + aspect_ratio: float = original_width / original_height + current_aspect_ratio: float = current_width / current_height + + if aspect_ratio > current_aspect_ratio: + new_height = (original_height * current_width) // original_width + padding = (current_height - new_height) // 2 + current_height = current_height - (2 * padding) + else: + new_width = (original_width * current_height) // original_height + padding = (current_width - new_width) // 2 + current_width = current_width - (2 * padding) + + unpadded_features = current_height * current_width + newline_features = current_height + return (unpadded_features, newline_features) + + +def get_number_of_features(height: int, width: int, config) -> int: + # From config + # Hardcoded for CLIP for now + # image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]] + image_grid_pinpoints = config.image_grid_pinpoints + image_size = config.vision_config.image_size + patch_size = config.vision_config.patch_size + + assert image_size % patch_size == 0 + + npatches = image_size // patch_size + + # Dimensions are intentionally swapped to be bug-compatible with + # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59 + num_patch_width, num_patch_height = get_anyres_image_grid_shape( + [height, width], + image_grid_pinpoints, + image_size, + ) + unpadded_features, newline_features = get_unpadded_features( + height, width, npatches, num_patch_height, num_patch_width + ) + # The base patch covers the entire image + base_features = npatches**2 + return unpadded_features + newline_features + base_features + + +class FlashVlmCausalLMBatch(FlashCausalLMBatch): + pixel_values: Optional[List[torch.Tensor]] + pixel_attention_mask: Optional[List[torch.Tensor]] + image_sizes: Optional[List[Tuple[int, int]]] + image_grid_thw: Optional[torch.Tensor] + + @classmethod + @tracer.start_as_current_span("concatenate") + def concatenate(cls, batches): + batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches) + batch.pixel_values = None + batch.pixel_attention_mask = None + batch.image_sizes = None + batch.image_grid_thw = None + return batch + + @tracer.start_as_current_span("filter") + def filter(self, request_ids: List[int]): + batch = super().filter(request_ids) + batch.pixel_values = None + batch.pixel_attention_mask = None + batch.image_sizes = None + batch.image_grid_thw = None + return batch + + @classmethod + def batch_tokenized_inputs( + cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config + ): + # Process images first. We need all of them so that the processor + # can make the image splits the same size. And we need the final + # sizes to insert correct number of image tokens. + images = [] + for r in requests: + for chunk in r.input_chunks.chunks: + chunk_type = chunk.WhichOneof("chunk") + if chunk_type == "text": + pass + elif chunk_type == "image": + image = Image.open(BytesIO(chunk.image.data)) + # qwen2_vl expects images to be greater than 20 pixels, this is for warmup since the + # default warmup image is 20x20 + if config.model_type in {"qwen2_vl", "qwen2_5_vl"}: + if image.width <= 20: + w = image.width * 2 + h = image.height * 2 + image = image.resize((w, h)) + + if config.model_type == "llava_next": + images.append(image) + elif config.model_type == "gemma3": + images.append(image) + else: + images.append([image]) + else: + raise RuntimeError(f"Invalid chunk type {chunk_type}") + + if images: + kwargs = {} + if ( + hasattr(processor, "image_processor_class") + and processor.image_processor_class == "Idefics3ImageProcessor" + ): + kwargs["return_row_col_info"] = True + + image_inputs = processor.image_processor( + images, return_tensors="pt", **kwargs + ) + else: + image_inputs = None + + batch_tokenized_inputs = [] + max_length = 0 + image_id = 0 + for r in requests: + full_text = "" + for chunk in r.input_chunks.chunks: + chunk_type = chunk.WhichOneof("chunk") + if chunk_type == "text": + full_text += chunk.text + elif chunk_type == "image": + full_text += image_text_replacement( + processor, image_inputs, config, image_id + ) + image_id += 1 + + full_text = image_text_replacement_fixup(config, full_text) + input_ids = tokenizer( + full_text, + truncation=True, + max_length=r.truncate, + add_special_tokens=r.add_special_tokens, + )["input_ids"] + max_length = max(max_length, len(input_ids)) + batch_tokenized_inputs.append(input_ids) + + return batch_tokenized_inputs, image_inputs + + @classmethod + def from_pb_processor( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + processor, + config, + dtype: torch.dtype, + device: torch.device, + ) -> "FlashVlmCausalLMBatch": + batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs( + pb.requests, tokenizer, processor, config + ) + batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) + if image_inputs is not None: + batch.pixel_values = image_inputs["pixel_values"].to(device=device) + if "pixel_attention_mask" in image_inputs: + batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to( + device=device + ) + else: + batch.pixel_attention_mask = None + if "image_sizes" in image_inputs: + batch.image_sizes = image_inputs["image_sizes"].to(device=device) + else: + batch.image_sizes = None + if "image_grid_thw" in image_inputs: + batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device) + else: + batch.image_grid_thw = None + else: + batch.pixel_values = None + batch.pixel_attention_mask = None + batch.image_sizes = None + batch.image_grid_thw = None + return batch + + +class FlashVlmCausalLM(FlashCausalLM): + def __init__( + self, + model_id: str, + *, + processor_class=AutoProcessor, + processor_kwargs=None, + batch_class=FlashVlmCausalLMBatch, + revision, + trust_remote_code: bool, + **kwargs, + ): + if PREFIX_CACHING: + raise NotImplementedError("Vlm do not work with prefix caching yet") + if processor_kwargs is None: + processor_kwargs = {} + self.processor = processor_class.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + **processor_kwargs, + ) + self.batch_class = batch_class + super().__init__( + model_id=model_id, + revision=revision, + trust_remote_code=trust_remote_code, + # FIXME: VLM do not work with context chunking yet + support_chunking=False, + **kwargs, + ) + + @property + def batch_type(self) -> Type[FlashVlmCausalLMBatch]: + return self.batch_class + + def max_past(self) -> Optional[int]: + return getattr(self.model.text_model, "max_past", None) + + def forward( + self, + batch: FlashVlmCausalLMBatch, + adapter_data: Optional[Dict[str, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + # Model Forward + if batch.speculative_ids is not None: + input_ids = batch.input_ids + position_ids = batch.position_ids + cu_seqlen_prefill = batch.cu_seqlen_prefill + kv_cache = self.kv_cache + block_tables = batch.block_tables_tensor + slots = batch.slots[batch.slot_indices] + input_lengths = batch.input_lengths_tensor + max_s = batch.max_current_length + lm_head_indices = batch.prefill_head_indices + + speculative_ids = batch.speculative_ids + + B, speculative_length = speculative_ids.shape + new_length = speculative_length + 1 + new_input_ids = torch.cat( + [input_ids.unsqueeze(-1), speculative_ids], dim=1 + ).reshape(-1) + arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0) + arange_int = arange.to(dtype=torch.int32) + new_position_ids = ( + position_ids.unsqueeze(-1).expand(B, new_length) + arange + ).view(-1) + slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) + input_lengths = ( + input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int + ).view(-1) + cache_lengths_tensor = ( + batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length) + ).reshape(-1) + + # Add Copy the block tables for all members + block_tables = ( + block_tables.unsqueeze(1) + .expand(B, new_length, -1) + .reshape(B * new_length, -1) + .contiguous() + ) + max_s = max_s + speculative_length + + input_ids = new_input_ids + position_ids = new_position_ids + else: + input_ids = batch.input_ids + position_ids = batch.position_ids + cu_seqlen_prefill = batch.cu_seqlen_prefill + kv_cache = self.kv_cache + block_tables = batch.block_tables_tensor + slots = batch.slots[batch.slot_indices] + input_lengths = batch.input_lengths_tensor + cache_lengths_tensor = batch.cache_lengths_tensor + max_s = batch.max_current_length + lm_head_indices = batch.prefill_head_indices + + if self.model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}: + if position_ids.dim() == 1 and batch.prefilling: + position_ids = self.model.get_position_ids( + input_ids, batch.image_grid_thw + ) + batch.position_ids = position_ids + + if cu_seqlen_prefill is None and self.max_past() is not None: + # In decode, not prefill, we're actually overwriting the KV-cache + # in a circular buffer mode. + # This makes sure the max_s for the decode pass is correct. + max_s = min(self.max_past(), max_s) + + kwargs = {} + if htorch.utils.internal.is_lazy(): + kwargs["bypass_hpu_graphs"] = False + + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) + if batch.prefill_cache_indices is not None: + slots_pad = torch.zeros_like(input_ids) + slots_pad[batch.prefill_cache_indices] = slots + slots = slots_pad + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + slots=slots, + seqlen=trim_seqlen_metadata(seqlen), + hpu_attention_meta=batch.hpu_attn_meta, + lm_head_indices=lm_head_indices, + pixel_values=batch.pixel_values, + pixel_attention_mask=batch.pixel_attention_mask, + image_sizes=batch.image_sizes, + image_grid_thw=batch.image_grid_thw, + **kwargs, + ) + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + if batch.pixel_values is not None: + batch.pixel_values = None + if batch.pixel_attention_mask is not None: + batch.pixel_attention_mask = None + if batch.image_sizes is not None: + batch.image_sizes = None + if batch.image_grid_thw is not None: + batch.image_grid_thw = None + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/globals.py b/backends/gaudi/server/text_generation_server/models/globals.py index 30a5d3da4..cd221e148 100644 --- a/backends/gaudi/server/text_generation_server/models/globals.py +++ b/backends/gaudi/server/text_generation_server/models/globals.py @@ -1,53 +1,31 @@ -import torch import os from typing import Dict, Optional from loguru import logger from text_generation_server.utils.log import log_master +REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"} ATTENTION = os.getenv("ATTENTION", "default") # default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0" PREFIX_CACHING = os.getenv("PREFIX_CACHING", "0").lower() in { "1", "true", } -PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "0").lower() in {"1", "true"} log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") -_expected = {"paged", "flashdecoding", "flashinfer", "default"} +_expected = {"paged", "default"} assert ( ATTENTION in _expected ), f"Attention is not valid {ATTENTION}, expected {_expected}" log_master(logger.info, f"Using Attention = {ATTENTION}") -if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}: - raise RuntimeError("Prefix caching is only supported with flashinfer") - -MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90")) assert TGI_WIGGLE_ROOM > 0 assert TGI_WIGGLE_ROOM < 1 # This is overridden by the cli BLOCK_SIZE: int -if ATTENTION == "flashdecoding": - BLOCK_SIZE = 256 -elif ATTENTION == "flashinfer": - BLOCK_SIZE = 1 -else: - BLOCK_SIZE = 16 -# This is overridden by the cli -cuda_graphs = os.getenv("CUDA_GRAPHS") -if cuda_graphs is not None: - try: - cuda_graphs = [int(item) for item in cuda_graphs.split(",")] - except Exception as e: - raise RuntimeError( - f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}" - ) -else: - cuda_graphs = None +BLOCK_SIZE = 128 -CUDA_GRAPHS = cuda_graphs # This is overridden at model loading. global MODEL_ID diff --git a/backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py b/backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py index 9a7a6fe15..98d7352a8 100644 --- a/backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py @@ -34,9 +34,6 @@ from text_generation_server.utils import ( ) from text_generation_server.utils.quantization import get_loader -from text_generation_server.utils.import_utils import SYSTEM - - tracer = trace.get_tracer(__name__) @@ -596,22 +593,8 @@ class IdeficsCausalLM(Model): ): self.quantize = quantize self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - # 9b seems to work correctly enough in float16, but 80b seems - # to be really saturating for f16. - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - # Float16 doesn't exist on target. - dtype = torch.bfloat16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype + device = torch.device("hpu") + dtype = torch.bfloat16 if dtype is None else dtype self.device, self.dtype = device, dtype config = AutoConfig.from_pretrained( diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index 9e19e1715..e034ed492 100644 --- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -1,28 +1,30 @@ -from io import BytesIO -from PIL import Image import torch + +import numpy as np + from typing import Iterable, Optional, Tuple, List, Dict from text_generation_server.pb.generate_pb2 import Request - +from io import BytesIO +from PIL import Image from dataclasses import dataclass from opentelemetry import trace from transformers import ( PreTrainedTokenizerBase, ) -from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM -from text_generation_server.pb import generate_pb2 -from text_generation_server.models.flash_causal_lm import ( - block_tables_to_ragged, -) -from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION -from text_generation_server.layers.attention import Seqlen +from text_generation_server.models.flash_vlm_causal_lm import ( + FlashVlmCausalLMBatch, + FlashVlmCausalLM, +) +from text_generation_server.pb import generate_pb2 +from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata +import habana_frameworks.torch as htorch tracer = trace.get_tracer(__name__) @dataclass -class MllamaCausalLMBatch(VlmCausalLMBatch): +class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch): image_indices: List[int] = 42 aspect_ratio_ids: Optional[torch.Tensor] = None aspect_ratio_mask: Optional[torch.Tensor] = None @@ -158,7 +160,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch): config, dtype: torch.dtype, device: torch.device, - ) -> "VlmCausalLMBatch": + ) -> "FlashVlmCausalLMBatch": batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs( pb.requests, tokenizer, processor, config ) @@ -167,6 +169,13 @@ class MllamaCausalLMBatch(VlmCausalLMBatch): batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp( max=config.text_config.vocab_size - 1 ) + if isinstance(batch.input_ids, list): + if len(batch) > 1: + input_ids = np.concatenate(batch.input_ids, dtype=np.int64) + else: + input_ids = batch.input_ids[0] + batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1) if image_inputs is not None: @@ -187,10 +196,10 @@ class MllamaCausalLMBatch(VlmCausalLMBatch): return batch -class MllamaCausalLM(VlmCausalLM): +class FlashMllamaCausalLM(FlashVlmCausalLM): def forward( self, - batch: VlmCausalLMBatch, + batch: FlashMllamaCausalLMBatch, adapter_data: Optional[Dict[str, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # Model Forward @@ -202,7 +211,7 @@ class MllamaCausalLM(VlmCausalLM): block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor - max_s = batch.max_seqlen + max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices speculative_ids = batch.speculative_ids @@ -221,8 +230,8 @@ class MllamaCausalLM(VlmCausalLM): input_lengths = ( input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int ).view(-1) - prefix_lens_tensor = ( - batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) + cache_lengths_tensor = ( + batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length) ).reshape(-1) # Add Copy the block tables for all members @@ -244,8 +253,8 @@ class MllamaCausalLM(VlmCausalLM): block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor - prefix_lens_tensor = batch.prefix_lens_tensor - max_s = batch.max_seqlen + cache_lengths_tensor = batch.cache_lengths_tensor + max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices if cu_seqlen_prefill is None and self.max_past() is not None: @@ -254,104 +263,46 @@ class MllamaCausalLM(VlmCausalLM): # This makes sure the max_s for the decode pass is correct. max_s = min(self.max_past(), max_s) - bs = input_ids.shape[0] - # Try to find an associated cuda graph - bs = input_ids.shape[0] - sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) - if sorted_padded_bs: - # Get associated cuda graph - cuda_graph = self.cuda_graphs[sorted_padded_bs[0]] - else: - cuda_graph = None - if ( - cu_seqlen_prefill is not None - or cuda_graph is None - # Only run cuda graphs when there's no images. - or batch.cross_attention_states is not None - ): - input_lengths = input_lengths + prefix_lens_tensor - if PREFIX_CACHING: - block_tables = block_tables_to_ragged( - block_tables=block_tables, - input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, - ) - with self._forward_context( - block_tables=block_tables, - cu_seqlen_prefill=cu_seqlen_prefill, - input_lengths_tensor=input_lengths, - prefix_lens_tensor=prefix_lens_tensor, - ): - max_k = (input_lengths + prefix_lens_tensor).max().item() - seqlen = Seqlen( - input_lengths=input_lengths, - prefix_lengths=prefix_lens_tensor, - cu_seqlen_q=cu_seqlen_prefill, - max_q=max_s, - max_k=max_k, - ) + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) - if batch.pixel_values is not None: - cross_attention_states = self.model.vision_forward( - pixel_values=batch.pixel_values, - aspect_ratio_ids=batch.aspect_ratio_ids, - aspect_ratio_mask=batch.aspect_ratio_mask, - ) - batch.cross_attention_states = cross_attention_states - - cross_attention_states = batch.cross_attention_states - - logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, - block_tables=block_tables, - slots=slots, - seqlen=seqlen, - max_s=max_s, - prefill_cache_indices=batch.prefill_cache_indices, - lm_head_indices=lm_head_indices, - cross_attention_states=cross_attention_states, - adapter_data=adapter_data, - image_indices=batch.image_indices[:], - ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None - if batch.pixel_values is not None: - batch.pixel_values = None - return logits, speculative_logits - - # Copy inputs to the static inputs of the cuda graph - # Static inputs are potentially padded - cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids - cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids - if ATTENTION == "flashinfer": - block_tables = block_tables_to_ragged( - block_tables=block_tables, - input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, + if batch.pixel_values is not None: + cross_attention_states = self.model.vision_forward( + pixel_values=batch.pixel_values, + aspect_ratio_ids=batch.aspect_ratio_ids, + aspect_ratio_mask=batch.aspect_ratio_mask, ) - cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables - else: - cuda_graph["block_tables"][ - : block_tables.shape[0], : block_tables.shape[1] - ] = block_tables - cuda_graph["slots"].fill_(0) - cuda_graph["slots"][: slots.shape[0]] = slots - cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: input_lengths.shape[0]] = ( - input_lengths + prefix_lens_tensor - ) + batch.cross_attention_states = cross_attention_states - # Replay the graph - cuda_graph["graph"].replay() + cross_attention_states = batch.cross_attention_states - # Slice output to the correct shape - speculative_logits = ( - cuda_graph["speculative_logits"][:bs] - if cuda_graph["speculative_logits"] is not None - else None + kwargs = {} + if htorch.utils.internal.is_lazy(): + kwargs["bypass_hpu_graphs"] = False + if batch.prefill_cache_indices is not None: + slots_pad = torch.zeros_like(input_ids) + slots_pad[batch.prefill_cache_indices] = slots + slots = slots_pad + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + slots=slots, + seqlen=trim_seqlen_metadata(seqlen), + hpu_attention_meta=batch.hpu_attn_meta, + lm_head_indices=lm_head_indices, + cross_attention_states=cross_attention_states, + # TODO list + adapter_data=None, + image_indices=batch.image_indices[:], + **kwargs, ) - logits = cuda_graph["logits"][:bs] + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + if batch.pixel_values is not None: + batch.pixel_values = None return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/model.py b/backends/gaudi/server/text_generation_server/models/model.py index 4fda22713..66c69bc1f 100644 --- a/backends/gaudi/server/text_generation_server/models/model.py +++ b/backends/gaudi/server/text_generation_server/models/model.py @@ -33,6 +33,7 @@ class Model(ABC): sliding_window: Optional[int] = None, speculate: Optional[int] = None, adapter_id: str = BASE_MODEL_ADAPTER_ID, + support_chunking: bool = False, ): self.model_id = model_id self.model = model.eval() diff --git a/backends/gaudi/server/text_generation_server/models/pali_gemma.py b/backends/gaudi/server/text_generation_server/models/pali_gemma.py index fe75570ea..e91aaed99 100644 --- a/backends/gaudi/server/text_generation_server/models/pali_gemma.py +++ b/backends/gaudi/server/text_generation_server/models/pali_gemma.py @@ -4,8 +4,8 @@ import torch import torch.distributed from opentelemetry import trace from typing import Iterable -from text_generation_server.models.vlm_causal_lm import ( - VlmCausalLMBatch, +from text_generation_server.models.flash_vlm_causal_lm import ( + FlashVlmCausalLMBatch, image_text_replacement, ) @@ -14,7 +14,7 @@ from text_generation_server.pb.generate_pb2 import Request tracer = trace.get_tracer(__name__) -class PaliGemmaBatch(VlmCausalLMBatch): +class PaliGemmaBatch(FlashVlmCausalLMBatch): @classmethod def batch_tokenized_inputs( cls, requests: Iterable[Request], tokenizer, processor, config diff --git a/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py b/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py index 04d4c28ba..0ee6ed167 100644 --- a/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py +++ b/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py @@ -10,7 +10,6 @@ from transformers import ( AutoConfig, ) from typing import Optional, Tuple, List, Type, Dict -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils import ( initialize_torch_distributed, weight_files, @@ -555,20 +554,9 @@ class Seq2SeqLM(Model): ): self.quantize = quantize self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = default_dtype if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = default_dtype if dtype is None else dtype - else: - device = torch.device("cpu") - # Float16 doesn't exist on target. - dtype = torch.bfloat16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype + + device = torch.device("hpu") + dtype = torch.bfloat16 if dtype is None else dtype config = config_class.from_pretrained( model_id, @@ -600,7 +588,7 @@ class Seq2SeqLM(Model): aliases=aliases, weights_loader=weights_loader, ) - if config.quantize in ["awq", "exl2", "gptq", "marlin"]: + if config.quantize in ["awq", "gptq"]: weights._set_gptq_params(model_id, revision) model = model_class(config, weights) diff --git a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py index 543b07e8e..709437d93 100644 --- a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py @@ -69,11 +69,7 @@ MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 8192)) PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 128)) CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1)) -max_batch_size_str = os.environ.get("MAX_BATCH_SIZE") -if max_batch_size_str is not None: - MAX_BATCH_SIZE = int(max_batch_size_str) -else: - raise ValueError("MAX_BATCH_SIZE is not set") + PREFILL_WARMUP_BATCH_SIZE_LIST = [] PREFILL_WARMUP_SEQLEN_LIST = [] @@ -1467,6 +1463,12 @@ class VlmCausalLM(Model): batch = self.batch_from_pb(request.batch, is_warmup=True) max_input_tokens = request.max_input_tokens max_prefill_batch_size = batch.input_ids.shape[0] + max_batch_size_str = os.environ.get("MAX_BATCH_SIZE") + if max_batch_size_str is not None: + MAX_BATCH_SIZE = int(max_batch_size_str) + else: + raise ValueError("MAX_BATCH_SIZE is not set") + try: # max prefill batch size warmup _, prefill_batch, _ = self.generate_token([batch], is_warmup=True) diff --git a/backends/gaudi/server/text_generation_server/server.py b/backends/gaudi/server/text_generation_server/server.py index 674a8aed1..5a7d21175 100644 --- a/backends/gaudi/server/text_generation_server/server.py +++ b/backends/gaudi/server/text_generation_server/server.py @@ -18,22 +18,27 @@ from text_generation_server.interceptor import ExceptionInterceptor from text_generation_server.models import Model, get_model_with_lora_adapters from text_generation_server.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor -from text_generation_server.models.globals import set_model_id +from text_generation_server.models.globals import set_model_id, ATTENTION from text_generation_server.models.globals import set_adapter_to_index from text_generation_server.utils.adapter import AdapterInfo from text_generation_server.utils.tokens import make_tokenizer_optional +from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens try: from text_generation_server.models.pali_gemma import PaliGemmaBatch + from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch from text_generation_server.models.vlm_causal_lm import ( VlmCausalLMBatch, ) - from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch + from text_generation_server.models.flash_vlm_causal_lm import ( + FlashVlmCausalLMBatch, + ) VLM_BATCH_TYPES = { PaliGemmaBatch, VlmCausalLMBatch, - IdeficsCausalLMBatch, + FlashVlmCausalLMBatch, + FlashMllamaCausalLMBatch, } except (ImportError, NotImplementedError): # These imports can fail on CPU/Non flash. @@ -103,14 +108,50 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) async def Warmup(self, request, context): - max_supported_total_tokens, max_input_tokens, max_total_tokens = ( - self.model.warmup(request) - ) + if ATTENTION == "paged": + set_max_prefill_tokens(request.max_prefill_tokens) + if ( + self.model.batch_type in VLM_BATCH_TYPES + ): # Hack, i would rather use kwargs in the `from_pb` call + batch = self.model.batch_type.from_pb_processor( + request.batch, + self.model.tokenizer, + self.model.processor, + self.model.model.config, + self.model.dtype, + self.model.device, + ) + else: + batch = self.model.batch_type.from_pb( + request.batch, + self.model.tokenizer, + self.model.dtype, + self.model.device, + ) - # W/A for the skip tokenizer path - # We need to call make_tokenizer_optional after the warmup, - # because router is not aware of that feature - make_tokenizer_optional(self.model.tokenizer) + # Override default values with None for clearer semantics. + max_input_tokens = ( + request.max_input_tokens + if request.HasField("max_input_tokens") + else None + ) + max_total_tokens = ( + request.max_total_tokens + if request.HasField("max_total_tokens") + else None + ) + max_supported_total_tokens, max_input_tokens, max_total_tokens = ( + self.model.warmup(batch, max_input_tokens, max_total_tokens) + ) + else: + max_supported_total_tokens, max_input_tokens, max_total_tokens = ( + self.model.warmup(request) + ) + + # W/A for the skip tokenizer path + # We need to call make_tokenizer_optional after the warmup, + # because router is not aware of that feature + make_tokenizer_optional(self.model.tokenizer) return generate_pb2.WarmupResponse( max_supported_total_tokens=max_supported_total_tokens, diff --git a/backends/gaudi/server/text_generation_server/utils/dist.py b/backends/gaudi/server/text_generation_server/utils/dist.py index 0e9b97fb2..1c45713e8 100644 --- a/backends/gaudi/server/text_generation_server/utils/dist.py +++ b/backends/gaudi/server/text_generation_server/utils/dist.py @@ -1,15 +1,13 @@ import os import torch - +from torch.distributed import ProcessGroup from datetime import timedelta from loguru import logger # Tensor Parallelism settings RANK = int(os.getenv("RANK", "0")) WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) - -# CUDA memory fraction -MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0")) +MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.8")) class FakeBarrier: @@ -17,10 +15,11 @@ class FakeBarrier: pass -class FakeGroup: +class FakeGroup(ProcessGroup): def __init__(self, rank, size): self._rank = rank self._size = size + super().__init__(rank, size) def allreduce(self, *args, **kwargs): return FakeBarrier() @@ -42,42 +41,11 @@ class FakeGroup: def rank(self): return self._rank + def _get_backend_name(self): + return "fake" + def initialize_torch_distributed(): - - world_size = int(os.getenv("WORLD_SIZE", "1")) - - options = None - if torch.cuda.is_available(): - from torch.distributed import ProcessGroupNCCL - - # Set the device id. - assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one gpu" - device = RANK % torch.cuda.device_count() - torch.cuda.set_device(device) - torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device) - backend = "nccl" - options = ProcessGroupNCCL.Options() - options.is_high_priority_stream = True - options._timeout = timedelta(seconds=60) - elif torch.hpu.is_available(): - backend = "hccl" - n_hpus = torch.hpu.device_count() - if world_size > n_hpus: - raise ValueError( - f"WORLD_SIZE ({world_size}) is higher than the number of available HPUs ({n_hpus})." - ) - else: - try: - import oneccl_bindings_for_pytorch # noqa: F401 - - backend = "ccl" - if os.getenv("CCL_WORKER_COUNT", None) is None: - os.environ["CCL_WORKER_COUNT"] = str(1) - except ImportError: - backend = "gloo" - options = None - if WORLD_SIZE == 1: return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE else: @@ -87,11 +55,10 @@ def initialize_torch_distributed(): if not torch.distributed.is_initialized(): # Call the init process. torch.distributed.init_process_group( - backend=backend, + backend="hccl", world_size=WORLD_SIZE, rank=RANK, - timeout=timedelta(seconds=60), - pg_options=options, + timeout=timedelta(seconds=120), ) else: logger.warning("torch.distributed is already initialized.") diff --git a/backends/gaudi/server/text_generation_server/utils/import_utils.py b/backends/gaudi/server/text_generation_server/utils/import_utils.py index 782b4f15b..22560dd7a 100644 --- a/backends/gaudi/server/text_generation_server/utils/import_utils.py +++ b/backends/gaudi/server/text_generation_server/utils/import_utils.py @@ -1,75 +1,28 @@ import torch from loguru import logger -import os -import importlib.util +def get_hpu_free_memory(device, memory_fraction): + from habana_frameworks.torch.hpu import memory_stats - -def is_ipex_available(): - return importlib.util.find_spec("intel_extension_for_pytorch") is not None - - -def get_cuda_free_memory(device, memory_fraction): - total_free_memory, _ = torch.cuda.mem_get_info(device) - total_gpu_memory = torch.cuda.get_device_properties(device).total_memory - free_memory = max(0, total_free_memory - (1 - memory_fraction) * total_gpu_memory) - return free_memory - - -def get_xpu_free_memory(device, memory_fraction): - total_memory = torch.xpu.get_device_properties(device).total_memory device_id = device.index - memory_fraction = float(os.getenv("XPU_MEMORY_FRACTION", "1.0")) + mem_stats = memory_stats(device_id) + logger.info(f"mem_stats: {mem_stats}") + total_free_memory = mem_stats["Limit"] - mem_stats["MaxInUse"] free_memory = max( - 0, - int( - total_memory * 0.9 * memory_fraction - torch.xpu.memory_reserved(device_id) - ), + 0, int(total_free_memory - (1 - memory_fraction) * mem_stats["Limit"]) ) return free_memory -def get_cpu_free_memory(device, memory_fraction): - import psutil - from text_generation_server.utils.dist import WORLD_SIZE - - mem = psutil.virtual_memory() - free_memory = int(mem.available * 0.95 / WORLD_SIZE) - return free_memory +def synchronize_hpu(device): + torch.hpu.synchronize() def noop(*args, **kwargs): pass -SYSTEM = None -if torch.version.hip is not None: - SYSTEM = "rocm" - empty_cache = torch.cuda.empty_cache - synchronize = torch.cuda.synchronize - get_free_memory = get_cuda_free_memory -elif torch.version.cuda is not None and torch.cuda.is_available(): - SYSTEM = "cuda" - empty_cache = torch.cuda.empty_cache - synchronize = torch.cuda.synchronize - get_free_memory = get_cuda_free_memory -elif is_ipex_available(): - SYSTEM = "ipex" - import intel_extension_for_pytorch # noqa: F401 - - if hasattr(torch, "xpu") and torch.xpu.is_available(): - empty_cache = torch.xpu.empty_cache - synchronize = torch.xpu.synchronize - get_free_memory = get_xpu_free_memory - else: - empty_cache = noop - synchronize = noop - get_free_memory = get_cpu_free_memory -else: - SYSTEM = "cpu" - - empty_cache = noop - synchronize = noop - get_free_memory = get_cpu_free_memory -logger.info(f"Detected system {SYSTEM}") +empty_cache = noop +synchronize = synchronize_hpu +get_free_memory = get_hpu_free_memory diff --git a/backends/gaudi/server/text_generation_server/utils/kernels.py b/backends/gaudi/server/text_generation_server/utils/kernels.py new file mode 100644 index 000000000..42745c716 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/utils/kernels.py @@ -0,0 +1,22 @@ +import importlib + +from loguru import logger +from hf_kernels import load_kernel as hf_load_kernel + +from text_generation_server.utils.log import log_once + + +def load_kernel(*, module: str, repo_id: str): + """ + Load a kernel. First try to load it as the given module (e.g. for + local development), falling back to a locked Hub kernel. + """ + try: + m = importlib.import_module(module) + log_once(logger.info, f"Using local module for `{module}`") + return m + except ModuleNotFoundError: + return hf_load_kernel(repo_id=repo_id) + + +__all__ = ["load_kernel"] diff --git a/backends/gaudi/server/text_generation_server/utils/prefill_chunking.py b/backends/gaudi/server/text_generation_server/utils/prefill_chunking.py new file mode 100644 index 000000000..c227d30f5 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/utils/prefill_chunking.py @@ -0,0 +1,24 @@ +from typing import Optional + +SUPPORT_CHUNKING: Optional[bool] = None +MAX_PREFILL_TOKENS: Optional[int] = None + + +def set_support_chunking(support_chunking: bool): + global SUPPORT_CHUNKING + SUPPORT_CHUNKING = support_chunking + + +def get_support_chunking() -> bool: + global SUPPORT_CHUNKING + return SUPPORT_CHUNKING + + +def set_max_prefill_tokens(max_prefill_tokens: int): + global MAX_PREFILL_TOKENS + MAX_PREFILL_TOKENS = max_prefill_tokens + + +def get_max_prefill_tokens() -> int: + global MAX_PREFILL_TOKENS + return MAX_PREFILL_TOKENS diff --git a/backends/gaudi/server/text_generation_server/utils/quantization.py b/backends/gaudi/server/text_generation_server/utils/quantization.py index ee561acc4..a8faf4a59 100644 --- a/backends/gaudi/server/text_generation_server/utils/quantization.py +++ b/backends/gaudi/server/text_generation_server/utils/quantization.py @@ -4,9 +4,7 @@ from dataclasses import dataclass from typing import Optional from huggingface_hub import hf_hub_download -from text_generation_server.layers.marlin.gptq import can_use_gptq_marlin from text_generation_server.utils.weights import ( - DefaultWeightsLoader, WeightsLoader, ) @@ -129,64 +127,13 @@ def get_loader( f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config." ) - if can_use_gptq_marlin( + return GPTQWeightsLoader( bits=quantizer_config.bits, + desc_act=quantizer_config.desc_act, groupsize=quantizer_config.groupsize, quant_method=quantizer_config.quant_method, quantize=quantize, sym=quantizer_config.sym, - ): - from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader - - return GPTQMarlinWeightsLoader( - bits=quantizer_config.bits, - desc_act=quantizer_config.desc_act, - groupsize=quantizer_config.groupsize, - quant_method=quantizer_config.quant_method, - quantize=quantize, - sym=quantizer_config.sym, - ) - else: - return GPTQWeightsLoader( - bits=quantizer_config.bits, - desc_act=quantizer_config.desc_act, - groupsize=quantizer_config.groupsize, - quant_method=quantizer_config.quant_method, - quantize=quantize, - sym=quantizer_config.sym, - ) - elif quantize == "bitsandbytes": - from text_generation_server.layers.bnb import BNBWeight - - return DefaultWeightsLoader(BNBWeight) - elif quantize == "bitsandbytes-fp4": - from text_generation_server.layers.bnb import BNBFP4Weight - - return DefaultWeightsLoader(BNBFP4Weight) - elif quantize == "bitsandbytes-nf4": - from text_generation_server.layers.bnb import BNBNF4Weight - - return DefaultWeightsLoader(BNBNF4Weight) - elif quantize == "eetq": - from text_generation_server.layers.eetq import EETQWeight - - return DefaultWeightsLoader(EETQWeight) - elif quantize == "exl2": - from text_generation_server.layers.exl2 import Exl2WeightsLoader - - return Exl2WeightsLoader() - elif quantize == "marlin": - from text_generation_server.layers.marlin import MarlinWeightsLoader - - # TODO: improve check once we have one config type per quantize value - if not isinstance(quantizer_config, _QuantizerConfig): - raise ValueError( - f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config." - ) - - return MarlinWeightsLoader( - bits=quantizer_config.bits, - is_marlin_24=quantizer_config.checkpoint_format == "marlin_24", ) elif quantize == "fp8" or quantize is None: from text_generation_server.layers.fp8 import HybridFP8UnquantLoader diff --git a/backends/gaudi/server/text_generation_server/utils/weights.py b/backends/gaudi/server/text_generation_server/utils/weights.py index 75e01f7ce..acd598d7a 100644 --- a/backends/gaudi/server/text_generation_server/utils/weights.py +++ b/backends/gaudi/server/text_generation_server/utils/weights.py @@ -7,8 +7,6 @@ from typing import Dict, List, Optional, Union, Type from safetensors import safe_open from dataclasses import dataclass -from text_generation_server.utils.import_utils import SYSTEM - class WeightsLoader(ABC): """ @@ -88,12 +86,9 @@ class UnquantizedWeight(Weight): weight: torch.Tensor def get_linear(self, bias: torch.Tensor): - from text_generation_server.layers.linear import FastLinear, FastLinearROCm + from text_generation_server.layers.linear import FastLinear - if SYSTEM == "rocm": - return FastLinearROCm(self.weight, bias) - else: - return FastLinear(self.weight, bias) + return FastLinear(self.weight, bias) class DefaultWeightsLoader(WeightsLoader): @@ -197,7 +192,7 @@ class Weights: slice_ = f.get_slice(tensor_name) return slice_ - def _has_tensor(self, tensor_name: str): + def has_tensor(self, tensor_name: str): try: self.get_filename(tensor_name) except Exception: @@ -207,7 +202,9 @@ class Weights: def get_shape(self, tensor_name: str): return self._get_slice(tensor_name).get_shape() - def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True): + def get_tensor( + self, tensor_name: str, to_device: bool = True, to_dtype: bool = True + ) -> torch.Tensor: filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) tensor = f.get_tensor(tensor_name) @@ -218,6 +215,7 @@ class Weights: tensor.dtype not in [ torch.float8_e4m3fn, + torch.int8, torch.int16, torch.int32, torch.int64, @@ -253,7 +251,8 @@ class Weights: # u4 which are disguised as int32. exl2 uses int16. # FP8 uses torch.float8_e4m3fn. if ( - tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32) + tensor.dtype + not in (torch.float8_e4m3fn, torch.int8, torch.int16, torch.int32) and to_dtype ): tensor = tensor.to(dtype=self.dtype) @@ -329,6 +328,7 @@ class Weights: tensor.dtype not in [ torch.float8_e4m3fn, + torch.int8, torch.int16, torch.int32, torch.int64, diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index e7f3d85a9..6da2b51da 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use tokio::sync::{mpsc, oneshot}; use crate::radix::RadixAllocator; - +use text_generation_router::usage_stats::Env; #[derive(Debug, Clone)] pub struct BlockAllocation { pub allocation_id: u64, @@ -141,6 +141,7 @@ pub struct SimpleAllocator { free_blocks: Vec, block_size: u32, window_size: Option, + is_hpu_device: bool, } impl SimpleAllocator { @@ -150,6 +151,7 @@ impl SimpleAllocator { // Block 0 is reserved for health checks free_blocks: (1..blocks).collect(), window_size, + is_hpu_device: Env::new().is_hpu_device(), } } } @@ -179,9 +181,15 @@ impl Allocator for SimpleAllocator { if required_blocks > self.free_blocks.len() as u32 { None } else { - let blocks = self + if self.is_hpu_device { + self.free_blocks.sort_by(|a, b| b.cmp(a)); + } + let mut blocks = self .free_blocks .split_off(self.free_blocks.len() - required_blocks as usize); + if self.is_hpu_device { + blocks.sort(); + } let mut slots = Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize); diff --git a/launcher/src/env_runtime.rs b/launcher/src/env_runtime.rs index d7ae11d54..d9056e413 100644 --- a/launcher/src/env_runtime.rs +++ b/launcher/src/env_runtime.rs @@ -28,8 +28,8 @@ impl Env { } } - pub fn is_hpu_device(&self) -> bool { - self.hpu_env != "N/A" + pub fn should_start_a_single_hpu_shard(&self) -> bool { + self.hpu_env != "N/A" && std::env::var("ATTENTION").as_deref() != Ok("paged") } } diff --git a/launcher/src/main.rs b/launcher/src/main.rs index acff85730..c169a78ce 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1559,7 +1559,7 @@ fn spawn_shards( ) -> Result<(), LauncherError> { // Start shard processes for rank in 0..num_shard { - if rank != 0 && env_runtime::Env::new().is_hpu_device() { + if rank != 0 && env_runtime::Env::new().should_start_a_single_hpu_shard() { tracing::info!("Running on HPU, the launcher will not do any sharding as actual sharding is done in the server"); break; } @@ -1639,7 +1639,7 @@ fn spawn_shards( if shard_ready == num_shard { break; } - if env_runtime::Env::new().is_hpu_device() { + if env_runtime::Env::new().should_start_a_single_hpu_shard() { tracing::info!("HPU detected, shard is ready"); break; } diff --git a/router/src/usage_stats.rs b/router/src/usage_stats.rs index 353e9e378..a17aade9c 100644 --- a/router/src/usage_stats.rs +++ b/router/src/usage_stats.rs @@ -157,6 +157,7 @@ pub struct Env { docker_label: &'static str, nvidia_info: Option>, xpu_info: Option>, + hpu_info: Option>, system_env: SystemInfo, } @@ -289,6 +290,60 @@ impl XpuSmiInfo { } } +#[derive(Debug, Serialize, Clone)] +struct HpuSmiInfo { + name: String, + pci_bus_id: String, + driver_version: String, + temperature: String, + utilization: String, + memory_total: String, + memory_free: String, + memory_used: String, + power_draw_instant: String, +} + +impl HpuSmiInfo { + fn new() -> Option> { + let output = Command::new("hl-smi") + .args([ + "--query-aip=name,bus_id,driver_version,temperature.aip,utilization.aip,memory.total,memory.free,memory.used,power.draw", + "--format=csv" + ]) + .output() + .ok()?; + + if !output.status.success() { + return None; + } + + let stdout = String::from_utf8(output.stdout).ok()?; + + let mut rdr = ReaderBuilder::new() + .has_headers(true) + .from_reader(stdout.as_bytes()); + + let mut infos = Vec::new(); + + for result in rdr.records() { + let record = result.ok()?; + infos.push(HpuSmiInfo { + name: record[0].to_string(), + pci_bus_id: record[1].to_string(), + driver_version: record[2].to_string(), + temperature: record[3].to_string(), + utilization: record[4].to_string(), + memory_total: record[5].to_string(), + memory_free: record[6].to_string(), + memory_used: record[7].to_string(), + power_draw_instant: record[8].to_string(), + }); + } + + Some(infos) + } +} + #[derive(Serialize, Debug, Clone)] pub struct SystemInfo { cpu_count: usize, @@ -335,10 +390,14 @@ impl Env { system_env: SystemInfo::new(), nvidia_info: NvidiaSmiInfo::new(), xpu_info: XpuSmiInfo::new(), + hpu_info: HpuSmiInfo::new(), git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"), docker_label: option_env!("DOCKER_LABEL").unwrap_or("N/A"), } } + pub fn is_hpu_device(&self) -> bool { + self.hpu_info.is_some() + } } pub fn is_container() -> io::Result {