From d62c941c563c5aa31d178c5b6f60caf898ec085e Mon Sep 17 00:00:00 2001
From: "Wang, Yi"
Date: Mon, 14 Apr 2025 21:58:13 +0800
Subject: [PATCH 1/3] Gaudi: clean cuda/rocm code in hpu backend, enable flat_hpu (#3113)

* clean cuda/rocm code in hpu backend, enable flat_hpu

Signed-off-by: Wang, Yi A

* fix TP in pageattn

Signed-off-by: Wang, Yi A

* adjust block table in hpu to improve performance

Signed-off-by: Wang, Yi A

* enable all the models, not tested yet

Signed-off-by: Wang, Yi A

* use tensor cache in hpu graph to avoid replay issue

Signed-off-by: Wang, Yi A

* add moe support, fix qwen/mistral/mixtral crash

Signed-off-by: Wang, Yi A

* fix phimoe issue

Signed-off-by: Wang, Yi A

* gpt_bigcode can also use pageattn

Signed-off-by: Wang, Yi A

* enable dbrx, remove some unused code

Signed-off-by: Wang, Yi A

* multi-modality initial PR

Signed-off-by: Wang, Yi A

* adjust warmup and enable vlm

Signed-off-by: Wang, Yi A

* fix incorrect output in qwen2/idefics if hpu graph is used

Signed-off-by: Wang, Yi A

* remove unused quantization code and enable awq/gptq int4

Signed-off-by: Wang, Yi A

* fix gptq issue

Signed-off-by: Wang, Yi A

* enable fp8

Signed-off-by: Wang, Yi A

* warmup prefill: remove models where pageattn is not used, set block table
  to None since it's not used

Signed-off-by: Wang, Yi A

* add warmup_decode

Signed-off-by: Wang, Yi A

* warmup decode

Signed-off-by: Wang, Yi A

* remove block_tables and prefill_cache_indices, which would lead to dynamic
  shapes

Signed-off-by: Wang, Yi A

* fix comment

Signed-off-by: Wang, Yi A

* missing gptj change...

Signed-off-by: Wang, Yi A

* fix some issues

Signed-off-by: Wang, Yi A

* remove torch.where to fix incorrect output in hpu graph model

Signed-off-by: Wang, Yi A

* match the latest vllm_extension ops

Signed-off-by: Wang, Yi A

---------

Signed-off-by: Wang, Yi A
---
 Dockerfile_gaudi | 2 +-
 .../server/text_generation_server/cli.py | 11 +-
 .../layers/attention/__init__.py | 51 +-
 .../layers/attention/common.py | 195 +-
 .../layers/attention/cuda.py | 357 ---
 .../layers/attention/flash_attn_triton.py | 813 ------
 .../layers/attention/flashinfer.py | 251 --
 .../layers/attention/hpu.py | 95 +
 .../layers/attention/ipex.py | 82 -
 .../layers/attention/kv_cache.py | 139 +
 .../layers/attention/rocm.py | 308 ---
 .../layers/awq/quantize/__init__.py | 3 +
 .../layers/awq/quantize/hpu.py | 134 +
 .../layers/awq/quantize/qmodule.py | 49 -
 .../text_generation_server/layers/eetq.py | 43 -
 .../text_generation_server/layers/fp8.py | 381 ++-
 .../layers/gptq/__init__.py | 113 +-
 .../layers/gptq/custom_autotune.py | 261 --
 .../layers/gptq/exllama.py | 134 -
 .../layers/gptq/exllamav2.py | 267 --
 .../text_generation_server/layers/gptq/hpu.py | 186 ++
 .../layers/gptq/quant_linear.py | 359 ---
 .../layers/gptq/quantize.py | 17 +-
 .../layers/layernorm.py | 149 +-
 .../text_generation_server/layers/linear.py | 90 +-
 .../layers/marlin/__init__.py | 15 -
 .../layers/marlin/fp8.py | 140 -
 .../layers/marlin/gptq.py | 464 ----
 .../layers/marlin/marlin.py | 346 ---
 .../layers/marlin/util.py | 141 -
 .../layers/moe/__init__.py | 55 +-
 .../text_generation_server/layers/moe/fp8.py | 173 ++
 .../moe/{fused_moe_rocm.py => fused_moe.py} | 17 +-
 .../layers/moe/gptq_marlin.py | 215 --
 .../layers/moe/unquantized.py | 41 +-
 .../text_generation_server/layers/rotary.py | 162 +-
 .../layers/tensor_parallel.py | 41 +-
 .../text_generation_server/models/__init__.py | 680 ++++-
 .../models/custom_modeling/bloom_modeling.py | 4 +-
 .../custom_modeling/flash_cohere_modeling.py | 168 +-
 .../custom_modeling/flash_dbrx_modeling.py | 85 +-
 .../flash_deepseek_v2_modeling.py | 94 +-
 .../flash_deepseek_v3_modeling.py | 642 +++++
 .../custom_modeling/flash_gemma2_modeling.py | 63 +-
 .../custom_modeling/flash_gemma_modeling.py | 62 +-
 .../custom_modeling/flash_gpt2_modeling.py | 64 +-
 .../custom_modeling/flash_gptj_modeling.py | 117 +-
 .../custom_modeling/flash_llama_modeling.py | 200 +-
 .../custom_modeling/flash_llava_next.py | 285 +++
 .../custom_modeling/flash_mistral_modeling.py | 118 +-
 .../custom_modeling/flash_mixtral_modeling.py | 81 +-
 .../models/custom_modeling/flash_mllama.py | 986 +++++++
 .../custom_modeling/flash_neox_modeling.py | 63 +-
 .../flash_pali_gemma_modeling.py | 12 +-
 .../custom_modeling/flash_phi_modeling.py | 60 +-
 .../custom_modeling/flash_qwen2_modeling.py | 117 +-
 .../custom_modeling/flash_rw_modeling.py | 105 +-
 .../flash_santacoder_modeling.py | 60 +-
 .../flash_starcoder2_modeling.py | 199 +-
 .../models/custom_modeling/idefics2.py | 31 +-
 .../models/custom_modeling/idefics3.py | 596 +++++
 .../custom_modeling/idefics_modeling.py | 104 +-
 .../models/custom_modeling/mamba_modeling.py | 10 +-
 .../models/custom_modeling/mpt_modeling.py | 1215 ---------
 .../models/custom_modeling/neox_modeling.py | 796 ------
 .../models/custom_modeling/opt_modeling.py | 857 -------
 .../models/custom_modeling/phi_modeling.py | 336 ---
 .../models/custom_modeling/qwen2_5_vl.py | 946 +++++++
 .../models/custom_modeling/qwen2_vl.py | 519 ++++
 .../models/custom_modeling/t5_modeling.py | 1227 ---------
 .../models/custom_modeling/vlm.py | 10 +-
 .../models/flash_causal_lm.py | 2266 +++++++++--------
 .../models/flash_vlm_causal_lm.py | 489 ++++
 .../text_generation_server/models/globals.py | 28 +-
 .../models/idefics_causal_lm.py | 21 +-
 .../models/mllama_causal_lm.py | 179 +-
 .../text_generation_server/models/model.py | 1 +
 .../models/pali_gemma.py | 6 +-
 .../models/seq2seq_lm.py | 20 +-
 .../models/vlm_causal_lm.py | 12 +-
 .../server/text_generation_server/server.py | 61 +-
 .../text_generation_server/utils/dist.py | 51 +-
 .../utils/import_utils.py | 69 +-
 .../text_generation_server/utils/kernels.py | 22 +
 .../utils/prefill_chunking.py | 24 +
 .../utils/quantization.py | 57 +-
 .../text_generation_server/utils/weights.py | 20 +-
 backends/v3/src/block_allocator.rs | 12 +-
 launcher/src/env_runtime.rs | 4 +-
 launcher/src/main.rs | 4 +-
 router/src/usage_stats.rs | 59 +
 91 files changed, 8804 insertions(+), 11813 deletions(-)
 delete mode 100644 backends/gaudi/server/text_generation_server/layers/attention/cuda.py
 delete mode 100644 backends/gaudi/server/text_generation_server/layers/attention/flash_attn_triton.py
 delete mode 100644 backends/gaudi/server/text_generation_server/layers/attention/flashinfer.py
 create mode 100644 backends/gaudi/server/text_generation_server/layers/attention/hpu.py
 delete mode 100644 backends/gaudi/server/text_generation_server/layers/attention/ipex.py
 create mode 100644 backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py
 delete mode 100644 backends/gaudi/server/text_generation_server/layers/attention/rocm.py
 create mode 100644 backends/gaudi/server/text_generation_server/layers/awq/quantize/__init__.py
 create mode 100644 backends/gaudi/server/text_generation_server/layers/awq/quantize/hpu.py
 delete mode 100644 backends/gaudi/server/text_generation_server/layers/awq/quantize/qmodule.py
 delete mode 100644 backends/gaudi/server/text_generation_server/layers/eetq.py
 delete mode 100644 backends/gaudi/server/text_generation_server/layers/gptq/custom_autotune.py
 delete mode 100644 backends/gaudi/server/text_generation_server/layers/gptq/exllama.py
 delete mode 100644 backends/gaudi/server/text_generation_server/layers/gptq/exllamav2.py
 create mode 100644 backends/gaudi/server/text_generation_server/layers/gptq/hpu.py
 delete mode 100644 backends/gaudi/server/text_generation_server/layers/gptq/quant_linear.py
 delete mode 100644 backends/gaudi/server/text_generation_server/layers/marlin/__init__.py
 delete mode 100644 backends/gaudi/server/text_generation_server/layers/marlin/fp8.py
 delete mode 100644 backends/gaudi/server/text_generation_server/layers/marlin/gptq.py
 delete mode 100644 backends/gaudi/server/text_generation_server/layers/marlin/marlin.py
 delete mode 100644 backends/gaudi/server/text_generation_server/layers/marlin/util.py
 create mode 100644 backends/gaudi/server/text_generation_server/layers/moe/fp8.py
 rename backends/gaudi/server/text_generation_server/layers/moe/{fused_moe_rocm.py => fused_moe.py} (80%)
 delete mode 100644 backends/gaudi/server/text_generation_server/layers/moe/gptq_marlin.py
 create mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py
 create mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py
 create mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py
 create mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py
 delete mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/mpt_modeling.py
 delete mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/neox_modeling.py
 delete mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/opt_modeling.py
 delete mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/phi_modeling.py
 create mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
 create mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py
 delete mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/t5_modeling.py
 create mode 100644 backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
 create mode 100644 backends/gaudi/server/text_generation_server/utils/kernels.py
 create mode 100644 backends/gaudi/server/text_generation_server/utils/prefill_chunking.py

diff --git a/Dockerfile_gaudi b/Dockerfile_gaudi
index eff87ab65..06073fe40 100644
--- a/Dockerfile_gaudi
+++ b/Dockerfile_gaudi
@@ -95,7 +95,7 @@ RUN cd server && \
     pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
    BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
    pip install .
--no-cache-dir - +RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git # Install benchmarker COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark # Install router diff --git a/backends/gaudi/server/text_generation_server/cli.py b/backends/gaudi/server/text_generation_server/cli.py index 700f763e9..53837ef71 100644 --- a/backends/gaudi/server/text_generation_server/cli.py +++ b/backends/gaudi/server/text_generation_server/cli.py @@ -16,15 +16,9 @@ app = typer.Typer() class Quantization(str, Enum): - bitsandbytes = "bitsandbytes" - bitsandbytes_nf4 = "bitsandbytes-nf4" - bitsandbytes_fp4 = "bitsandbytes-fp4" gptq = "gptq" awq = "awq" - eetq = "eetq" - exl2 = "exl2" fp8 = "fp8" - marlin = "marlin" class Dtype(str, Enum): @@ -105,6 +99,9 @@ def serve( "bitsandbytes", "bitsandbytes-nf4", "bitsandbytes-fp4", + "gptq", + "awq", + "fp8", }: raise RuntimeError( "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model." @@ -112,7 +109,7 @@ def serve( logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype)) - if sharded: + if sharded and os.getenv("ATTENTION", "default") not in {"paged"}: tgi_file = Path(__file__).resolve().parent / "tgi_service.py" num_shard = int(os.getenv("WORLD_SIZE", "1")) logger.info("CLI SHARDED = {}".format(num_shard)) diff --git a/backends/gaudi/server/text_generation_server/layers/attention/__init__.py b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py index 4d83a11fc..9ba9f6e08 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py @@ -1,43 +1,28 @@ -from text_generation_server.utils.import_utils import SYSTEM -import os +from .common import ( + Seqlen, + HPUPagedAttentionMetadata, + trim_attn_metadata, + trim_seqlen_metadata, +) -from .common import Seqlen +from .hpu import ( + SUPPORTS_WINDOWING, + attention, + paged_attention, +) -if os.getenv("USE_FLASH_ATTENTION", "true").lower() == "false": - raise ImportError("`USE_FLASH_ATTENTION` is false.") -if SYSTEM == "cuda": - from .cuda import ( - attention, - paged_attention, - reshape_and_cache, - SUPPORTS_WINDOWING, - PREFILL_IN_KV_CACHE, - ) -elif SYSTEM == "rocm": - from .rocm import ( - attention, - paged_attention, - reshape_and_cache, - PREFILL_IN_KV_CACHE, - SUPPORTS_WINDOWING, - ) -elif SYSTEM == "ipex": - from .ipex import ( - attention, - paged_attention, - reshape_and_cache, - PREFILL_IN_KV_CACHE, - SUPPORTS_WINDOWING, - ) -else: - raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") +# KVCache needs `reshape_and_cache`, so ensure that it is defined already. 
+from .kv_cache import KVCache, get_kv_scales __all__ = [ "attention", + "get_kv_scales", "paged_attention", - "reshape_and_cache", - "PREFILL_IN_KV_CACHE", "SUPPORTS_WINDOWING", + "KVCache", "Seqlen", + "HPUPagedAttentionMetadata", + "trim_seqlen_metadata", + "trim_attn_metadata", ] diff --git a/backends/gaudi/server/text_generation_server/layers/attention/common.py b/backends/gaudi/server/text_generation_server/layers/attention/common.py index d6e512c01..8ec9fb461 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/common.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/common.py @@ -1,72 +1,147 @@ from dataclasses import dataclass -from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.models.globals import ATTENTION import torch -from typing import Optional +from typing import Optional, List, Dict +import collections + +_TYPE_CACHE = {} -if ATTENTION in {"flashinfer", "flashdecoding"}: +@dataclass +class HPUPagedAttentionMetadata: + """Metadata for PagedAttention.""" - @dataclass - class Seqlen: - input_lengths: torch.Tensor - prefix_lengths: torch.Tensor - cu_seqlen_q: Optional[torch.Tensor] - cu_seqlen_k: Optional[torch.Tensor] - max_q: int - max_k: int + block_list: Optional[torch.Tensor] + block_mapping: Optional[torch.Tensor] + block_usage: Optional[torch.Tensor] + block_scales: Optional[torch.Tensor] + block_groups: Optional[torch.Tensor] + attn_bias: Optional[torch.Tensor] - def __init__( - self, - input_lengths, - prefix_lengths, - cu_seqlen_q=None, - max_q=None, - max_k=None, - ): - self.input_lengths = input_lengths - self.prefix_lengths = prefix_lengths - device = self.input_lengths.device - shape = self.input_lengths.shape - if cu_seqlen_q is None: - cu_seqlen_q = torch.arange( - shape[0] + 1, - device=device, - dtype=torch.int32, - ) - max_q = 1 - else: - assert max_q is not None - assert max_k is not None - cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32) - # cuda graphs don't like this and this is necessary to clamp within mistral - # Although FA2 might not want the clamping - # cu_seqlen_k[0] = 0 - total = self.input_lengths + self.prefix_lengths - torch.cumsum(total, -1, out=cu_seqlen_k[1:]) +def subtuple( + obj: object, + typename: str, + to_copy: List[str], + to_override: Optional[Dict[str, object]] = None, +): + if obj is None: + return None + if to_override is None: + to_override = {} + fields = set(to_copy) | set(to_override.keys()) + if isinstance(obj, dict): + values = {key: obj[key] for key in fields if key in obj} + else: + values = {f: to_override.get(f, getattr(obj, f)) for f in fields} + if typename not in _TYPE_CACHE: + _TYPE_CACHE[typename] = collections.namedtuple(typename, " ".join(fields)) + return _TYPE_CACHE[typename](**values) - self.cu_seqlen_q = cu_seqlen_q - self.cu_seqlen_k = cu_seqlen_k - self.max_q = max_q - self.max_k = max_k - def clamp(self, max): - # Flash decoding doesn't need to clamp - return self +def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object: + # NOTE(kzawora): To anyone working on this in the future: + # Trimming metadata is required when using HPUGraphs. + # Attention metadata is going to be hashed by PT bridge, and + # appropriate HPUGraphs will be matched based on all inputs' hash. -else: + # Before you put more keys in here, make sure you know their + # value type and make sure you know how it's going to be hashed. 
+ # You can find that information in input_hash function + # in habana_frameworks/torch/hpu/graphs.py. You can also hash + # it manually with torch.hpu.graphs.input_hash(attention_metadata) - @dataclass - class Seqlen: - input_lengths: torch.Tensor - prefix_lengths: torch.Tensor - cu_seqlen_q: torch.Tensor - max_q: int - max_k: int + # If you use primitive types here - they will get hashed based + # on their value. You *will* get lots of excessive graph captures + # (and an OOM eventually) if you decide to put something like + # seq_len int here. + # If you absolutely need a scalar, put it in a tensor. Tensors + # get hashed using their metadata, not their values: + # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321)) + # input_hash(123) != input_hash(321) + # input_hash("abc") != input_hash("cba") + attention_metadata = subtuple( + metadata, + "TrimmedAttentionMetadata", + [ + "block_list", + "block_mapping", + "block_usage", + "block_scales", + "block_groups", + "attn_bias", + ], + ) + return attention_metadata - def clamp(self, max): - if SYSTEM == "rocm": - return self - raise NotImplementedError("Not implemented seqlen for paged") - return Seqlen(torch.clamp(self.input_lengths, max=max)) + +@dataclass +class Seqlen: + input_lengths: torch.Tensor + cache_lengths: torch.Tensor + cu_seqlen_q: Optional[torch.Tensor] + cu_seqlen_k: Optional[torch.Tensor] + + def __init__( + self, + input_lengths, + cache_lengths, + cu_seqlen_q=None, + ): + self.input_lengths = input_lengths + self.cache_lengths = cache_lengths + device = self.input_lengths.device + shape = self.input_lengths.shape + if cu_seqlen_q is None: + cu_seqlen_q = torch.arange( + shape[0] + 1, + device=device, + dtype=torch.int32, + ) + cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32) + + # cuda graphs don't like this and this is necessary to clamp within mistral + # Although FA2 might not want the clamping + # cu_seqlen_k[0] = 0 + total = self.input_lengths + self.cache_lengths + torch.cumsum(total, -1, out=cu_seqlen_k[1:]) + + self.cu_seqlen_q = cu_seqlen_q + self.cu_seqlen_k = cu_seqlen_k + + def clamp(self, max): + # Flash decoding doesn't need to clamp + return self + + +def trim_seqlen_metadata(metadata: Seqlen) -> object: + # NOTE(kzawora): To anyone working on this in the future: + # Trimming metadata is required when using HPUGraphs. + # Attention metadata is going to be hashed by PT bridge, and + # appropriate HPUGraphs will be matched based on all inputs' hash. + + # Before you put more keys in here, make sure you know their + # value type and make sure you know how it's going to be hashed. + # You can find that information in input_hash function + # in habana_frameworks/torch/hpu/graphs.py. You can also hash + # it manually with torch.hpu.graphs.input_hash(attention_metadata) + + # If you use primitive types here - they will get hashed based + # on their value. You *will* get lots of excessive graph captures + # (and an OOM eventually) if you decide to put something like + # seq_len int here. + # If you absolutely need a scalar, put it in a tensor. 
Tensors + # get hashed using their metadata, not their values: + # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321)) + # input_hash(123) != input_hash(321) + # input_hash("abc") != input_hash("cba") + attention_metadata = subtuple( + metadata, + "TrimmedSeqlen", + [ + "input_lengths", + "cache_lengths", + "cu_seqlen_q", + "cu_seqlen_k", + ], + ) + return attention_metadata diff --git a/backends/gaudi/server/text_generation_server/layers/attention/cuda.py b/backends/gaudi/server/text_generation_server/layers/attention/cuda.py deleted file mode 100644 index 51af928d5..000000000 --- a/backends/gaudi/server/text_generation_server/layers/attention/cuda.py +++ /dev/null @@ -1,357 +0,0 @@ -import torch -from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.models.globals import ( - ATTENTION, - BLOCK_SIZE, -) -from text_generation_server.layers.attention import Seqlen -from typing import Optional - -major, minor = torch.cuda.get_device_capability() -is_sm75 = major == 7 and minor == 5 -_PARTITION_SIZE = 512 - -try: - from vllm._C import cache_ops -except Exception as e: - raise ImportError( - f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}" - ) - - -def reshape_and_cache( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slots: torch.Tensor, -): - if ATTENTION in {"flashdecoding", "flashinfer"}: - shape = key_cache.shape - key_cache.view(-1, shape[-2], shape[-1])[slots] = key - value_cache.view(-1, shape[-2], shape[-1])[slots] = value - else: - cache_ops.reshape_and_cache( - key, value, key_cache, value_cache, slots, "auto", 1.0 - ) - - -def paged_attention( - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - kv_head_mapping: torch.Tensor, - softmax_scale: float, - block_tables: torch.Tensor, - seqlen: Seqlen, - max_s: int, - softcap: Optional[float] = None, -): - # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py - # Copyright 2023 The vLLM team. All rights - # reserved. - # - # Licensed under the Apache License, Version 2.0 (the "License"); - # you may not use this file except in compliance with the License. - # You may obtain a copy of the License at - # - # http://www.apache.org/licenses/LICENSE-2.0 - # - # Unless required by applicable law or agreed to in writing, software - # distributed under the License is distributed on an "AS IS" BASIS, - # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - # See the License for the specific language governing permissions and - # limitations under the License. - # - - # value_cache => [num_blocks, num_heads, head_size, block_size] - # block_size = value_cache.shape[3] - block_size = BLOCK_SIZE - num_seqs, num_heads, head_size = query.shape - max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE - - # NOTE(woosuk): We use a simple heuristic to decide whether to use - # PagedAttention V1 or V2. If the number of partitions is 1, we use - # V1 to avoid the overhead of reduction. Also, if the number of - # sequences or heads is large, we use V1 since there is enough work - # to parallelize. 
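To make the HPU-graph hashing notes in `trim_attn_metadata`/`trim_seqlen_metadata` from common.py above concrete, here is a small, self-contained sketch. It runs on plain PyTorch (no HPU required); the `Meta` dataclass and its fields are invented stand-ins for the real attention metadata:

    import collections
    from dataclasses import dataclass

    import torch

    _TYPE_CACHE = {}


    def subtuple(obj, typename, to_copy):
        # Same idea as the helper in common.py: copy only the listed
        # fields into a cached namedtuple type.
        values = {f: getattr(obj, f) for f in to_copy}
        if typename not in _TYPE_CACHE:
            _TYPE_CACHE[typename] = collections.namedtuple(typename, " ".join(to_copy))
        return _TYPE_CACHE[typename](**values)


    @dataclass
    class Meta:  # hypothetical stand-in for HPUPagedAttentionMetadata
        block_list: torch.Tensor
        step: int  # primitive field: hashed by value, one graph capture per value


    m1 = Meta(block_list=torch.tensor([0, 1]), step=1)
    m2 = Meta(block_list=torch.tensor([2, 3]), step=2)

    # Trimming drops `step`, so both objects expose only tensors. Tensors
    # are hashed by their metadata (shape/dtype/device), not their values,
    # so a single captured HPU graph can be replayed for both iterations.
    t1 = subtuple(m1, "TrimmedMeta", ["block_list"])
    t2 = subtuple(m2, "TrimmedMeta", ["block_list"])
    print(t1, t2)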
- if ATTENTION == "flashinfer": - from text_generation_server.layers.attention.flashinfer import decode_state - - return decode_state.get().forward( - query.contiguous(), - paged_kv_cache=(key_cache, value_cache), - logits_soft_cap=softcap, - sm_scale=softmax_scale, - ) - elif ATTENTION == "flashdecoding": - max_q = 1 - max_k = max_s - import flash_attn_2_cuda - - # TODO fixme when flash contains the fix. - # Number of splits is not correctly handled - # by the current path - # https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577 - # This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied. - if softcap is None: - softcap = 0.0 - out = flash_attn_2_cuda.varlen_fwd( - query, - key_cache, - value_cache, - None, - seqlen.cu_seqlen_q, - seqlen.cu_seqlen_k, - None, # pad_k - None, - block_tables, - None, - max_q, - max_k, - 0.0, # dropout - softmax_scale, - False, # zero_tensors - True, # causal - -1, # Window_left - -1, # Window right - softcap, - False, # return softmax - None, # generator - ) - return out[0] - else: - if softcap is not None: - raise RuntimeError("Paged attention doesn't support softcapping") - input_lengths = seqlen.input_lengths - from vllm._C import ops - - out = torch.empty_like(query) - - use_v1 = max_s <= 8192 and ( - max_num_partitions == 1 or num_seqs * num_heads > 512 - ) - if use_v1: - ops.paged_attention_v1( - out, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - 1.0, - ) - else: - # Run PagedAttention V2. - assert _PARTITION_SIZE % block_size == 0 - tmp_output = torch.empty( - size=(num_seqs, num_heads, max_num_partitions, head_size), - dtype=out.dtype, - device=out.device, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, max_num_partitions), - dtype=torch.float32, - device=out.device, - ) - max_logits = torch.empty_like(exp_sums) - - ops.paged_attention_v2( - out, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - 1.0, - ) - return out - - -try: - is_ampere_or_newer = major >= 8 and minor >= 0 - if not is_ampere_or_newer: - raise ImportError("FlashAttention only supports Ampere GPUs or newer.") - - import flash_attn_2_cuda - - V2 = True -except ImportError: - try: - import flash_attn_cuda - - V2 = False - except ImportError as e: - if major >= 8: - architecture_suffix = f"-{SYSTEM}" - raise ImportError( - "Flash Attention V2 is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`" - ) - elif is_sm75: - raise ImportError( - "Flash Attention is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - "or install flash attention with `cd server && make install install-flash-attention`" - ) from e - else: - raise ImportError( - f"GPU with CUDA capability {major} {minor} is not supported" - ) from e - - -SUPPORTS_WINDOWING = V2 - -if ATTENTION == "flashinfer": - - def attention( - q: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - seqlen: Seqlen, - block_tables: torch.Tensor, - softmax_scale, - window_size_left=-1, - causal=True, - softcap=0.0, - ): - from 
text_generation_server.layers.attention.flashinfer import ( - prefill_with_paged_kv_state, - ) - - return prefill_with_paged_kv_state.get().forward( - q.contiguous(), - causal=causal, - paged_kv_cache=(key_cache, value_cache), - logits_soft_cap=softcap, - sm_scale=softmax_scale, - window_left=window_size_left, - ) - -elif V2: - - def attention( - q, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - seqlen: Seqlen, - block_tables: torch.Tensor, - softmax_scale, - window_size_left=-1, - causal=True, - softcap=0.0, - ): - out = torch.empty_like(q) - if window_size_left <= 0 and window_size_left != -1: - raise ValueError("`window_size_left` must be > 0 or -1") - return flash_attn_2_cuda.varlen_fwd( - q, - key_cache, - value_cache, - out, - seqlen.cu_seqlen_q, - seqlen.cu_seqlen_k, - None, - None, - block_tables, - None, - seqlen.max_q, - seqlen.max_k, - 0.0, - softmax_scale, - False, - causal, - window_size_left, - 0, - softcap, - False, - None, - )[0] - -else: - - def attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - seqlen: Seqlen, - block_tables: torch.Tensor, - softmax_scale: float, - window_size_left: int = -1, - causal: bool = True, - softcap=None, - ): - if window_size_left != -1: - raise NotImplementedError( - "window_size_left is only available with flash attn v2" - ) - if softcap is not None: - raise NotImplementedError("softcap is only available with flash attn v2") - - # Flash attention v1 requires q, k and v to have the same number of heads - if k.shape[1] != q.shape[1]: - # MQA expand - if k.shape[1] == 1: - k = k.expand(-1, q.shape[1], -1) - # Grouped attention reshape - else: - original_shape = k.shape - k = ( - k.unsqueeze(2) - .expand(-1, -1, q.shape[1] // k.shape[1], -1) - .reshape(original_shape[0], -1, original_shape[2]) - ) - if v.shape[1] != q.shape[1]: - # MQA expand - if v.shape[1] == 1: - v = v.expand(-1, q.shape[1], -1) - # Grouped attention reshape - else: - original_shape = v.shape - v = ( - v.unsqueeze(2) - .expand(-1, -1, q.shape[1] // v.shape[1], -1) - .reshape(original_shape[0], -1, original_shape[2]) - ) - - out = torch.empty_like(q) - flash_attn_cuda.fwd( - q, - k, - v, - out, - seqlen.cu_seqlen_q, - seqlen.cu_seqlen_q, - seqlen.max_q, - seqlen.max_k, - 0.0, - softmax_scale, - False, - causal, - False, - 0, - None, - ) - return out - - -# Prefill in the cache with every kind of attention, unless we -# have a configuration that requires flash-attention v1, which -# does not support block tables. -PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2 diff --git a/backends/gaudi/server/text_generation_server/layers/attention/flash_attn_triton.py b/backends/gaudi/server/text_generation_server/layers/attention/flash_attn_triton.py deleted file mode 100644 index 3a6f9a730..000000000 --- a/backends/gaudi/server/text_generation_server/layers/attention/flash_attn_triton.py +++ /dev/null @@ -1,813 +0,0 @@ -#!/usr/bin/env python -""" -Fused Attention -=============== - -This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao -(https://tridao.me/publications/flash2/flash2.pdf) -Credits: OpenAI kernel team, AMD ML Frameworks Triton team - -Features supported: - -1) Fwd with causal masking -2) Any sequence lengths without padding (currently fwd kernel only) -3) Support for different sequence lengths for q and k -4) Nested tensor API currently does not support dropout or bias. 
- -Not currently supported: - -1) Non power of two head dims - -""" - -import torch -import triton -import triton.language as tl - -torch_dtype: tl.constexpr = torch.float16 - - -@triton.jit -def cdiv_fn(x, y): - return (x + y - 1) // y - - -@triton.jit -def max_fn(x, y): - return tl.math.max(x, y) - - -@triton.jit -def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): - ms = tl.arange(0, m) - ns = tl.arange(0, n) - return philox_offset + ms[:, None] * stride + ns[None, :] - - -@triton.jit -def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_offsets = dropout_offsets( - philox_seed, philox_offset, dropout_p, m, n, stride - ).to(tl.uint32) - # TODO: use tl.randint for better performance - return tl.rand(philox_seed, rng_offsets) - - -@triton.jit -def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride) - rng_keep = rng_output > dropout_p - return rng_keep - - -@triton.jit -def load_fn(block_ptr, first, second, pad): - if first and second: - tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) - elif first: - tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad) - elif second: - tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad) - else: - tensor = tl.load(block_ptr) - return tensor - - -@triton.jit -def _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - actual_seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - block_min, - block_max, - offs_n_causal, - masked_blocks, - n_extra_tokens, - bias_ptr, - IS_CAUSAL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - OFFS_M: tl.constexpr, - OFFS_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - MASK_STEPS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_ENCODED_SOFTMAX: tl.constexpr, - PADDED_HEAD: tl.constexpr, -): - # loop over k, v, and update accumulator - for start_n in range(block_min, block_max, BLOCK_N): - # For padded blocks, we will overrun the tensor size if - # we load all BLOCK_N. For others, the blocks are all within range. - k = load_fn( - K_block_ptr, - PADDED_HEAD, - MASK_STEPS and (n_extra_tokens != 0), - "zero", - ) - if PRE_LOAD_V: - v = load_fn( - V_block_ptr, - MASK_STEPS and (n_extra_tokens != 0), - PADDED_HEAD, - "zero", - ) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - # We start from end of seqlen_k so only the first iteration would need - # to be checked for padding if it is not a multiple of block_n - # TODO: This can be optimized to only be true for the padded block. - if MASK_STEPS: # noqa: SIM102 - # If this is the last block / iteration, we want to - # mask if the sequence length is not a multiple of block size - # a solution is to always do BLOCK_M // BLOCK_N + 1 steps - # if not is_modulo_mn. last step might get wasted but that is okay. - # check if this masking works for that case. 
- if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): - boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) - size_n = start_n + OFFS_N[None, :] - mask = size_n < boundary_m[:, None] - qk = tl.where(mask, qk, float("-inf")) - if IS_CAUSAL: - causal_boundary = start_n + offs_n_causal - causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] - qk = tl.where(causal_mask, qk, float("-inf")) - # -- compute qk ---- - qk += tl.dot(q, k) - if bias_ptr is not None: - bias = load_fn( - bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero" - ) - # While bias is added after multiplying qk with sm_scale, our - # optimization to use 2^x instead of e^x results in an additional - # scale factor of log2(e) which we must also multiply the bias with. - qk += bias * 1.44269504089 - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk = qk - m_ij[:, None] - p = tl.math.exp2(qk) - - # CAVEAT: Must update l_ij before applying dropout - l_ij = tl.sum(p, 1) - if ENABLE_DROPOUT: - philox_offset = ( - batch_philox_offset - + start_m * BLOCK_M * actual_seqlen_k - + start_n - - BLOCK_N - ) - keep = dropout_mask( - philox_seed, - philox_offset, - dropout_p, - BLOCK_M, - BLOCK_N, - actual_seqlen_k, - ) - if RETURN_ENCODED_SOFTMAX: - tl.store( - encoded_softmax_block_ptr, - tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty), - ) - p = tl.where(keep, p, 0.0) - elif RETURN_ENCODED_SOFTMAX: - tl.store( - encoded_softmax_block_ptr, - p.to(encoded_softmax_block_ptr.type.element_ty), - ) - # -- update output accumulator -- - alpha = tl.math.exp2(m_i - m_ij) - acc = acc * alpha[:, None] - if not PRE_LOAD_V: - v = load_fn( - V_block_ptr, - MASK_STEPS and (n_extra_tokens != 0), - PADDED_HEAD, - "zero", - ) - # -- update m_i and l_i - l_i = l_i * alpha + l_ij - # update m_i and l_i - m_i = m_ij - acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) - if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance( - encoded_softmax_block_ptr, (0, BLOCK_N) - ) - return acc, l_i, m_i - - -@triton.autotune( - configs=[ - triton.Config( - { - "BLOCK_M": 256, - "BLOCK_N": 64, - "waves_per_eu": 2, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=8, - ), - triton.Config( - { - "BLOCK_M": 128, - "BLOCK_N": 128, - "waves_per_eu": 2, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=4, - ), - triton.Config( - { - "BLOCK_M": 256, - "BLOCK_N": 128, - "waves_per_eu": 2, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=8, - ), - triton.Config( - { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 3, - "PRE_LOAD_V": True, - }, - num_stages=1, - num_warps=4, - ), - triton.Config( - { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 3, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=4, - ), - triton.Config( - { - "BLOCK_M": 64, - "BLOCK_N": 64, - "waves_per_eu": 4, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=8, - ), - triton.Config( - { - "BLOCK_M": 32, - "BLOCK_N": 32, - "waves_per_eu": 4, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=8, - ), - # TODO: This config fails with head_size not pow2 with data mismatches. 
- # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, - # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config( - { - "BLOCK_M": 16, - "BLOCK_N": 16, - "waves_per_eu": 1, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=4, - ), - triton.Config( - { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 1, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=4, - ), - ], - key=["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"], -) -@triton.jit -def attn_fwd( - Q, - K, - V, - bias, - sm_scale, - L, - Out, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vk, - stride_vn, - stride_oz, - stride_oh, - stride_om, - stride_on, - stride_bz, - stride_bh, - stride_bm, - stride_bn, - cu_seqlens_q, - cu_seqlens_k, - dropout_p, - philox_seed, - philox_offset_base, - encoded_softmax, - HQ: tl.constexpr, - HK: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, - MAX_SEQLENS_Q: tl.constexpr, - MAX_SEQLENS_K: tl.constexpr, - VARLEN: tl.constexpr, - IS_CAUSAL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - BIAS_TYPE: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_ENCODED_SOFTMAX: tl.constexpr, -): - start_m = tl.program_id(0) - off_h_q = tl.program_id(1) - off_z = tl.program_id(2) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - if VARLEN: - cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) - cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) - seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start - # We have a one-size-fits-all grid in id(0). Some seqlens might be too - # small for all start_m so for those we return early. - if start_m * BLOCK_M > seqlen_q: - return - cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) - cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) - seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start - else: - cu_seqlens_q_start = 0 - cu_seqlens_k_start = 0 - seqlen_q = MAX_SEQLENS_Q - seqlen_k = MAX_SEQLENS_K - - # Now we compute whether we need to exit early due to causal masking. - # This is because for seqlen_q > seqlen_k, M rows of the attn scores - # are completely masked, resulting in 0s written to the output, and - # inf written to LSE. We don't need to do any GEMMs in this case. - # This block of code determines what N is, and if this WG is operating - # on those M rows. - n_blocks = cdiv_fn(seqlen_k, BLOCK_N) - if IS_CAUSAL: - # If seqlen_q == seqlen_k, the attn scores are a square matrix. - # If seqlen_q != seqlen_k, attn scores are rectangular which means - # the causal mask boundary is bottom right aligned, and ends at either - # the top edge (seqlen_q < seqlen_k) or left edge. - # This captures the decrease in n_blocks if we have a rectangular attn - # matrix - n_blocks_seqlen = cdiv_fn( - (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N - ) - # This is what adjusts the block_max for the current WG, only - # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks - n_blocks = min(n_blocks, n_blocks_seqlen) - # If we have no blocks after adjusting for seqlen deltas, this WG is - # part of the blocks that are all 0. We exit early. 
- if n_blocks <= 0: - o_offset = ( - off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh - ) - O_block_ptr = tl.make_block_ptr( - base=Out + o_offset, - shape=(seqlen_q, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) - # We still need to write 0s to the result - # tl.store(O_block_ptr, - # acc.to(Out.type.element_ty), boundary_check=(0,1)) - # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q - # + offs_m - # We store inf to LSE, not -inf because in the bwd pass, - # we subtract this - # from qk which makes it -inf, such that exp(qk - inf) = 0 - # for these masked blocks. - # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) - # tl.store(l_ptrs, l) - # TODO: Should dropout and return encoded softmax be handled here? - return - - # If MQA / GQA, set the K and V head offsets appropriately. - GROUP_SIZE: tl.constexpr = HQ // HK - if GROUP_SIZE != 1: - off_h_k = off_h_q // GROUP_SIZE - else: - off_h_k = off_h_q - - n_extra_tokens = 0 - if seqlen_k < BLOCK_N: - n_extra_tokens = BLOCK_N - seqlen_k - elif seqlen_k % BLOCK_N: - n_extra_tokens = seqlen_k % BLOCK_N - PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL - - # Compute pointers for all the tensors used in this kernel. - q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm - Q_block_ptr = tl.make_block_ptr( - base=Q + q_offset, - shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn - K_block_ptr = tl.make_block_ptr( - base=K + k_offset, - shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), - strides=(stride_kk, stride_kn), - offsets=(0, 0), - block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1), - ) - v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk - V_block_ptr = tl.make_block_ptr( - base=V + v_offset, - shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0), - ) - if BIAS_TYPE != 0: - bias_ptr = tl.make_block_ptr( - base=bias + off_h_q * stride_bh, - shape=(seqlen_q, seqlen_k), - strides=(stride_bm, stride_bn), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0), - ) - else: - bias_ptr = None - if ENABLE_DROPOUT: - batch_philox_offset = ( - philox_offset_base + (off_z * HQ + off_h_q) * seqlen_q * seqlen_k - ) - else: - batch_philox_offset = 0 - # We can ask to return the dropout mask without actually doing any dropout. - # In this case, we return an invalid pointer so indicate the mask is not i - # valid. - # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. 
- if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.make_block_ptr( - base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, - shape=(seqlen_q, seqlen_k), - strides=(seqlen_k, 1), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0), - ) - else: - encoded_softmax_block_ptr = 0 - # initialize pointer to m and l - m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) - l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # scale sm_scale by log_2(e) and use 2^x in the loop as we do not - # have native e^x support in HW. - qk_scale = sm_scale * 1.44269504089 - # Q is loaded once at the beginning and shared by all N blocks. - q = load_fn(Q_block_ptr, True, PADDED_HEAD, "zero") - q = (q * qk_scale).to(Q_block_ptr.type.element_ty) - - # Here we compute how many full and masked blocks we have. - padded_block_k = n_extra_tokens != 0 - is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) - if IS_CAUSAL: - # There are always at least BLOCK_M // BLOCK_N masked blocks. - # Additionally there might be one more due to dissimilar seqlens. - masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) - else: - # Padding on Q does not need to be masked in the FA loop. - masked_blocks = padded_block_k - # if IS_CAUSAL, not is_modulo_mn does not always result in an additional - # block. In this case we might exceed n_blocks so pick the min. - masked_blocks = min(masked_blocks, n_blocks) - n_full_blocks = n_blocks - masked_blocks - block_min = 0 - block_max = n_blocks * BLOCK_N - # Compute for full blocks. Here we set causal to false regardless of its - # value because there is no masking. Similarly we do not need padding. - if n_full_blocks > 0: - block_max = (n_blocks - masked_blocks) * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ - block_min, - block_max, - 0, - 0, - 0, - bias_ptr, - # IS_CAUSAL, .... - False, - BLOCK_M, - BLOCK_DMODEL, - BLOCK_N, - offs_m, - offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, - False, - ENABLE_DROPOUT, - RETURN_ENCODED_SOFTMAX, - PADDED_HEAD, - ) - block_min = block_max - block_max = n_blocks * BLOCK_N - - tl.debug_barrier() - # Remaining blocks, if any, are full / not masked. - if masked_blocks > 0: - offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 - K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) - V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) - if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance( - encoded_softmax_block_ptr, (0, n_full_blocks) - ) - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - block_min, - block_max, - offs_n_causal, - masked_blocks, - n_extra_tokens, - bias_ptr, - IS_CAUSAL, - BLOCK_M, - BLOCK_DMODEL, - BLOCK_N, - offs_m, - offs_n, - # _, MASK_STEPS, ... 
- PRE_LOAD_V, - True, - ENABLE_DROPOUT, - RETURN_ENCODED_SOFTMAX, - PADDED_HEAD, - ) - # epilogue - acc = acc / l_i[:, None] - if ENABLE_DROPOUT: - acc = acc / (1 - dropout_p) - # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, - # then we have one block with a row of all NaNs which come from computing - # softmax over a row of all -infs (-inf - inf = NaN). We check for that here - # and store 0s where there are NaNs as these rows should've been zeroed out. - end_m_idx = (start_m + 1) * BLOCK_M - start_m_idx = start_m * BLOCK_M - causal_start_idx = seqlen_q - seqlen_k - acc = acc.to(Out.type.element_ty) - if IS_CAUSAL: # noqa: SIM102 - if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: - out_mask_boundary = tl.full( - (BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32 - ) - mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) - out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] - z = 0.0 - acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) - # write back LSE - # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m - # If seqlen_q not multiple of BLOCK_M, we need to mask out the last - # few rows. This is only true for the last M block. For others, - # overflow_size will be -ve - # overflow_size = end_m_idx - seqlen_q - # if overflow_size > 0: - # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) - # # This is a > check because mask being 0 blocks the store. - # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) - # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) - # else: - # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) - - # write back O - o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh - O_block_ptr = tl.make_block_ptr( - base=Out + o_offset, - shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - # Need boundary check on this to make sure the padding from the - # Q and KV tensors in both dims are not part of what we store back. - # TODO: Do the boundary check optionally. 
- tl.store(O_block_ptr, acc, boundary_check=(0, 1)) - - -def check_args( - q, - k, - v, - o, - varlen=True, - max_seqlens=None, - cu_seqlens_q=None, - cu_seqlens_k=None, -): - assert q.dim() == k.dim() and q.dim() == v.dim() - if varlen: - assert q.dim() == 3 - total_q, nheads_q, head_size = q.shape - total_k, nheads_k, _ = k.shape - assert cu_seqlens_q is not None - assert cu_seqlens_k is not None - assert len(cu_seqlens_q) == len(cu_seqlens_k) - else: - assert q.dim() == 4 - batch, nheads_q, seqlen_q, head_size = q.shape - _, nheads_k, seqlen_k, _ = k.shape - assert max_seqlens > 0 - assert k.shape == v.shape - assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] - # TODO: Change assert if we support qkl f8 and v f16 - assert q.dtype == k.dtype and q.dtype == v.dtype - # TODO: Fix assert to check head size <=256 once supported - assert head_size <= 128 - assert o.shape == q.shape - assert (nheads_q % nheads_k) == 0 - - -class _attention(torch.autograd.Function): - - @staticmethod - def forward( - ctx, - q, - k, - v, - o, - cu_seqlens_q, - cu_seqlens_k, - max_seqlens_q, - max_seqlens_k, - causal=False, - sm_scale=1.0, - bias=None, - ): - if o is None: - o = torch.empty_like(q, dtype=v.dtype) - - check_args( - q, - k, - v, - o, - varlen=True, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - ) - if True: # varlen - total_q, nheads_q, head_size = q.shape - total_k, nheads_k, _ = k.shape - batch = len(cu_seqlens_q) - 1 - q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) - k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) - v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) - o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) - else: - batch, seqlen_q, nheads_q, head_size = q.shape - _, seqlen_k, nheads_k, _ = k.shape - q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) - k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) - v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) - o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) - - # Get closest power of 2 over or equal to 32. - padded_d_model = 1 << (head_size - 1).bit_length() - padded_d_model = max(padded_d_model, 16) - - def grid(META): - return triton.cdiv(max_seqlens_q, META["BLOCK_M"]), nheads_q, batch - - encoded_softmax = None - - # Seed the RNG so we get reproducible results for testing. 
- philox_seed = 0x1BF52 - philox_offset = 0x1D4B42 - - if bias is not None: - bias_strides = ( - bias.stride(0), - bias.stride(1), - bias.stride(2), - bias.stride(3), - ) - else: - bias_strides = (0, 0, 0, 0) - - attn_fwd[grid]( - q, - k, - v, - bias, - sm_scale, - None, - o, - *q_strides, - *k_strides, - *v_strides, - *o_strides, - *bias_strides, - cu_seqlens_q, - cu_seqlens_k, - dropout_p=0.0, - philox_seed=philox_seed, - philox_offset_base=philox_offset, - encoded_softmax=encoded_softmax, - HQ=nheads_q, - HK=nheads_k, - ACTUAL_BLOCK_DMODEL=head_size, - MAX_SEQLENS_Q=max_seqlens_q, - MAX_SEQLENS_K=max_seqlens_k, - IS_CAUSAL=causal, - VARLEN=True, - BLOCK_DMODEL=padded_d_model, - BIAS_TYPE=0 if bias is None else 1, - ENABLE_DROPOUT=False, - RETURN_ENCODED_SOFTMAX=False, - ) - - ctx.grid = grid - ctx.sm_scale = sm_scale - ctx.BLOCK_DMODEL = head_size - ctx.causal = causal - ctx.dropout_p = 0.0 - ctx.philox_seed = philox_seed - ctx.philox_offset = philox_offset - ctx.encoded_softmax = encoded_softmax - ctx.return_encoded_softmax = False - return o, encoded_softmax - - -triton_attention = _attention.apply diff --git a/backends/gaudi/server/text_generation_server/layers/attention/flashinfer.py b/backends/gaudi/server/text_generation_server/layers/attention/flashinfer.py deleted file mode 100644 index d603c6f5f..000000000 --- a/backends/gaudi/server/text_generation_server/layers/attention/flashinfer.py +++ /dev/null @@ -1,251 +0,0 @@ -from typing import Optional -from contextvars import ContextVar -from contextlib import contextmanager - -import flashinfer -import torch - -prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = ContextVar( - "prefill_state" -) - -prefill_with_paged_kv_state: ContextVar[ - flashinfer.BatchPrefillWithPagedKVCacheWrapper -] = ContextVar("prefill_with_paged_kv_state") - -decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar( - "decode_state" -) - -workspace: Optional[torch.Tensor] = None - - -def get_workspace(device): - """Get shared flashinfer workspace.""" - global workspace - if workspace is None: - workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) - return workspace - - -def create_prefill_with_paged_kv_state( - *, - device: torch.device, -): - """Create a prefill state that uses the KV cache.""" - workspace_buffer = get_workspace(device) - return flashinfer.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, kv_layout="NHD", use_cuda_graph=False - ) - - -@contextmanager -def use_prefill_with_paged_kv_state( - *, - state: flashinfer.BatchPrefillWithPagedKVCacheWrapper, - block_tables: torch.Tensor, - cu_seqlens: torch.Tensor, - input_lengths: torch.Tensor, - num_heads: int, - num_kv_heads: int, - head_size: int, - page_size: int, - dtype: torch.dtype, - window_left: int, -): - """ - Context manager to set the active flashinfer prefill state to the given - `state` and parameters. This state will be used by all calls to the - `attention` function while the context manager is active. - """ - - indptr = torch.zeros( - input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32 - ) - # Round up to page size and then calculate the cumulative sum to get - # the indices into the block table. - torch.add(input_lengths, page_size - 1, out=indptr[1:]) - indptr[1:].div_(page_size, rounding_mode="floor") - indptr[1:].cumsum_(-1) - - # Get the lengths of the last page in a block. 
- if page_size == 1: - last_page_len = torch.ones( - input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device - ) - else: - last_page_len = torch.empty( - input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device - ) - torch.sub(input_lengths, 1, out=last_page_len) - last_page_len.remainder_(page_size) - last_page_len += 1 - - token = prefill_with_paged_kv_state.set(state) - try: - state.begin_forward( - qo_indptr=cu_seqlens, - paged_kv_indptr=indptr, - paged_kv_indices=block_tables, - paged_kv_last_page_len=last_page_len, - num_qo_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_size, - q_data_type=dtype, - page_size=page_size, - window_left=window_left, - ) - yield - finally: - state.end_forward() - if token is not None: - prefill_with_paged_kv_state.reset(token) - - -def create_prefill_state( - *, - device: torch.device, -): - """Create a prefill state.""" - workspace_buffer = get_workspace(device) - return flashinfer.BatchPrefillWithRaggedKVCacheWrapper( - workspace_buffer, kv_layout="NHD", use_cuda_graph=False - ) - - -@contextmanager -def use_prefill_state( - *, - state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper, - cu_seqlens: torch.Tensor, - num_heads: int, - num_kv_heads: int, - head_size: int, - dtype: torch.dtype, - window_left: int, -): - """ - Context manager to set the active flashinfer prefill state to the given - `state` and parameters. This state will be used by all calls to the - `attention` function while the context manager is active. - """ - - token = prefill_state.set(state) - try: - state.begin_forward( - qo_indptr=cu_seqlens, - kv_indptr=cu_seqlens, - num_qo_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_size, - q_data_type=dtype, - window_left=window_left, - ) - yield - finally: - state.end_forward() - if token is not None: - prefill_state.reset(token) - - -def create_decode_state( - *, - device: torch.device, - num_heads: int, - num_kv_heads: int, -): - """Create a decode state.""" - workspace_buffer = get_workspace(device) - num_groups = num_heads // num_kv_heads - return flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, - kv_layout="NHD", - use_cuda_graph=False, - # Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60 - use_tensor_cores=num_groups not in [1, 2, 4, 8], - ) - - -def create_decode_state_cuda_graphs( - *, - device: torch.device, - block_tables: torch.Tensor, - block_tables_ptr: torch.Tensor, - last_page_len: torch.Tensor, - num_heads: int, - num_kv_heads: int, -): - """ - Create a decode state for use with CUDA Graphs. `block_tables`, - `block_tables_ptr`, and `last_page_len` are used in CUDA Graphs and are - therefore stored as part of the state. 
- """ - workspace_buffer = get_workspace(device) - num_groups = num_heads // num_kv_heads - return flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, - kv_layout="NHD", - use_cuda_graph=True, - paged_kv_indices_buffer=block_tables, - paged_kv_indptr_buffer=block_tables_ptr, - paged_kv_last_page_len_buffer=last_page_len, - # Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60 - use_tensor_cores=num_groups not in [1, 2, 4, 8], - ) - - -@contextmanager -def use_decode_state( - *, - state: flashinfer.BatchDecodeWithPagedKVCacheWrapper, - input_lengths: torch.Tensor, - block_tables: torch.Tensor, - num_heads: int, - num_kv_heads: int, - head_size: int, - page_size: int, - dtype: torch.dtype, - window_left: int, -): - """ - Context manager to set the active flashinfer decoding state to the given - `state` and parameters. This state will be used by all calls to the - `paged_attention` function while the context manager is active. - """ - indptr = torch.zeros( - input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32 - ) - # Round up to page size and then calculate the cumulative sum to get - # the indices into the block table. - torch.add(input_lengths, page_size - 1, out=indptr[1:]) - indptr[1:].div_(page_size, rounding_mode="floor") - indptr[1:].cumsum_(-1) - - # Get the lengths of the last page in a block. - last_page_len = torch.empty( - input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device - ) - torch.sub(input_lengths, 1, out=last_page_len) - last_page_len.remainder_(page_size) - last_page_len += 1 - - token = decode_state.set(state) - - try: - state.begin_forward( - indptr=indptr, - indices=block_tables, - last_page_len=last_page_len, - num_qo_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_size, - page_size=page_size, - data_type=dtype, - q_data_type=dtype, - window_left=window_left, - ) - yield - finally: - state.end_forward() - if token is not None: - decode_state.reset(token) diff --git a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py new file mode 100644 index 000000000..f34e93abc --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py @@ -0,0 +1,95 @@ +import torch +from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata +from typing import Optional +from text_generation_server.layers.attention.kv_cache import KVCache, KVScales +from vllm_hpu_extension import ops +from vllm_hpu_extension.utils import Matmul +from habana_frameworks.torch.hpex.kernels import FusedSDPA +from vllm_hpu_extension.utils import ModuleFusedSDPA +import os + +SUPPORTS_WINDOWING = False + + +def fetch_from_cache(cache, blocks): + if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true": + return cache[: blocks.size(0)] + else: + return cache.index_select(0, blocks) + + +def attention( + *, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: KVCache, + kv_scales: KVScales, + seqlen: Seqlen, + softmax_scale: float, + window_size_left: int = -1, + causal: bool = True, + softcap: Optional[float] = None, +): + fsdpa_op = ModuleFusedSDPA(FusedSDPA) + bs = seqlen.input_lengths.shape[0] + _, head_num, head_size = query.shape + _, kv_head_num, head_size = key.shape + query = query.view(bs, -1, head_num, head_size).transpose(1, 2) + key = key.view(bs, -1, kv_head_num, 
head_size).transpose(1, 2) + value = value.view(bs, -1, kv_head_num, head_size).transpose(1, 2) + attn_output = fsdpa_op( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=causal, + scale=softmax_scale, + softmax_mode="None", + recompute_mode=None, + valid_sequence_lengths=seqlen.input_lengths, + padding_side="left", + ) + attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous() + return attn_output + + +def paged_attention( + query: torch.Tensor, + kv_cache: KVCache, + kv_head_mapping: torch.Tensor, + softmax_scale: float, + seqlen: Seqlen, + *, + kv_scales: KVScales, + softcap: Optional[float] = None, + hpu_attention_meta: HPUPagedAttentionMetadata, +): + batch_size, head_num, head_size = query.shape + output = ops.flat_pa( + query=query.view(batch_size, 1, head_num * head_size), + key_cache=kv_cache.key, + value_cache=kv_cache.value, + block_list=hpu_attention_meta.block_list, + block_mapping=hpu_attention_meta.block_mapping, + block_bias=hpu_attention_meta.attn_bias, + block_scales=hpu_attention_meta.block_scales, + block_groups=hpu_attention_meta.block_groups, + scale=softmax_scale, + matmul_qk_op=Matmul(), + matmul_av_op=Matmul(), + batch2block_matmul_op=Matmul(), + block2batch_matmul_op=Matmul(), + keys_fetch_func=fetch_from_cache, + values_fetch_func=fetch_from_cache, + ) + # Reshape the output tensor. + return output.view(batch_size, head_num, head_size) + + +__all__ = [ + "SUPPORTS_WINDOWING", + "attention", + "paged_attention", +] diff --git a/backends/gaudi/server/text_generation_server/layers/attention/ipex.py b/backends/gaudi/server/text_generation_server/layers/attention/ipex.py deleted file mode 100644 index 657c90af4..000000000 --- a/backends/gaudi/server/text_generation_server/layers/attention/ipex.py +++ /dev/null @@ -1,82 +0,0 @@ -import intel_extension_for_pytorch as ipex -import torch -from text_generation_server.models.flash_causal_lm import BLOCK_SIZE -from text_generation_server.layers.attention import Seqlen -from typing import Optional - -SUPPORTS_WINDOWING = False -PREFILL_IN_KV_CACHE = False - - -def attention( - q: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - seqlen: Seqlen, - block_tables: torch.Tensor, - softmax_scale, - window_size_left=-1, - causal=True, - softcap: Optional[float] = None, -): - out = torch.empty_like(q) - - # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. 
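As a side note on the new `hpu.py` above: the `VLLM_CONTIGUOUS_PA` environment check only changes how cached blocks are fetched. A minimal runnable sketch of that selection logic with toy tensors (the `contiguous_pa` flag here stands in for the environment check, and the shapes are made up for illustration):

    import torch

    def fetch_from_cache(cache, blocks, contiguous_pa=True):
        # Contiguous PA assumes the listed blocks occupy a prefix of the
        # cache, so a cheap slice is enough; otherwise gather explicitly.
        if contiguous_pa:
            return cache[: blocks.size(0)]
        return cache.index_select(0, blocks)

    cache = torch.arange(6).reshape(6, 1)          # toy (num_blocks, 1) cache
    blocks = torch.tensor([3, 0, 5])
    print(fetch_from_cache(cache, blocks, True))   # rows 0, 1, 2 (prefix slice)
    print(fetch_from_cache(cache, blocks, False))  # rows 3, 0, 5 (explicit gather)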
-    ipex.llm.functional.varlen_attention(
-        q.contiguous() if q.device.type == "xpu" else q,
-        key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
-        value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
-        out,
-        seqlen.cu_seqlen_q,
-        seqlen.cu_seqlen_q,
-        seqlen.max_q,
-        seqlen.max_q,
-        0.0,
-        softmax_scale,
-        False,
-        causal,
-        False,
-        None,
-    )
-
-    return out
-
-
-def reshape_and_cache(
-    key: torch.Tensor,
-    value: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    slots: torch.Tensor,
-):
-    ipex.llm.modules.PagedAttention.reshape_and_cache(
-        key, value, key_cache, value_cache, slots
-    )
-
-
-def paged_attention(
-    query: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    kv_head_mapping: torch.Tensor,
-    softmax_scale: float,
-    block_tables: torch.Tensor,
-    seqlen: Seqlen,
-    max_s: int,
-    softcap: Optional[float] = None,
-):
-    out = torch.empty_like(query)
-    ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
-        out,
-        query,
-        key_cache,
-        value_cache,
-        kv_head_mapping,
-        softmax_scale,
-        block_tables,
-        seqlen.input_lengths,
-        BLOCK_SIZE,
-        max_s,
-        None,
-    )
-    return out
diff --git a/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py
new file mode 100644
index 000000000..d238cdb97
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py
@@ -0,0 +1,139 @@
+from typing import Tuple
+from dataclasses import dataclass, field
+
+import torch
+
+from text_generation_server.models.globals import BLOCK_SIZE
+from text_generation_server.utils.weights import Weights
+from vllm_hpu_extension import cache_ops
+
+
+@dataclass
+class KVScales:
+    """
+    Key-value scales for FP8 KV cache.
+
+    This data class stores key and value scales both as a GPU tensor and
+    as a CPU float. This inconvenience is necessary because some functions
+    (e.g. scaling kernels) take scales as a GPU tensor, whereas others
+    (e.g. flashinfer) take scales as a CPU scalar.
+    """
+
+    key_scale: torch.Tensor
+    value_scale: torch.Tensor
+    key_scale_cpu: float = field(init=False)
+    value_scale_cpu: float = field(init=False)
+
+    def __post_init__(self):
+        if self.key_scale.numel() != 1 or self.value_scale.numel() != 1:
+            raise ValueError("Key and value scales must be scalar tensors.")
+
+        self.key_scale_cpu = self.key_scale.item()
+        self.value_scale_cpu = self.value_scale.item()
+
+
+class KVCache:
+    """
+    Key-value cache for attention layers.
+ """ + + kv_cache: Tuple[torch.Tensor, torch.Tensor] + + def __init__( + self, + *, + num_blocks: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + ): + """Construct the key-value cache for a layer.""" + ## TODO FP8 kv cache support + + self.kv_cache = ( + torch.zeros( + (num_blocks, BLOCK_SIZE, num_heads, head_size), + dtype=dtype, + device=device, + ), + torch.zeros( + (num_blocks, BLOCK_SIZE, num_heads, head_size), + dtype=dtype, + device=device, + ), + ) + + @property + def dtype(self): + """Get the data type of the cache.""" + return self.kv_cache[0].dtype + + @property + def key(self): + """Get the key cache.""" + + return self.kv_cache[0] + + @property + def value(self): + """Get the value cache.""" + + return self.kv_cache[1] + + def store( + self, + *, + key: torch.Tensor, + value: torch.Tensor, + slots: torch.Tensor, + kv_scales: KVScales, + ): + """Store the key and value at the given slots.""" + ## TODO FP8 kv cache support + + key_cache = self.kv_cache[0] + value_cache = self.kv_cache[1] + + paged_reshape_and_cache( + key, + value, + key_cache, + value_cache, + slots, + kv_scales.key_scale_cpu, + kv_scales.value_scale_cpu, + ) + + +def paged_reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slots: torch.Tensor, + k_scale: float = 1.0, + v_scale: float = 1.0, +): + block_idx = slots // BLOCK_SIZE + block_offset = slots % BLOCK_SIZE + cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset) + cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset) + + +def get_kv_scales(weights: Weights, prefix: str) -> KVScales: + """Load KV cache scales.""" + + key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device) + value_scale = key_scale + if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor( + f"{prefix}.v_scale" + ): + key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float() + value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float() + elif weights.has_tensor(f"{prefix}.kv_scale"): + # Fall back to older more coarse-grained scale when available. 
+ key_scale = weights.get_tensor(f"{prefix}.kv_scale").float() + value_scale = key_scale + + return KVScales(key_scale=key_scale, value_scale=value_scale) diff --git a/backends/gaudi/server/text_generation_server/layers/attention/rocm.py b/backends/gaudi/server/text_generation_server/layers/attention/rocm.py deleted file mode 100644 index 646a763d3..000000000 --- a/backends/gaudi/server/text_generation_server/layers/attention/rocm.py +++ /dev/null @@ -1,308 +0,0 @@ -import os -from typing import Optional -import torch -from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.models.globals import ATTENTION -from text_generation_server.layers.attention import Seqlen -from text_generation_server.utils.log import log_master -from loguru import logger - -major, minor = torch.cuda.get_device_capability() -is_sm75 = major == 7 and minor == 5 - -_PARTITION_SIZE_V1V2 = 512 -_PARTITION_SIZE_CUSTOM = 256 - -use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"} -ENGINE = "triton" if use_triton else "ck" - -PREFILL_IN_KV_CACHE = False - -use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0" -try: - if use_rocm_custom_paged_attn: - from vllm._custom_C import paged_attention_custom -except ImportError as e: - log_master( - logger.info, - f"Custom Paged Attention not available. Complete error: {e}", - ) - use_rocm_custom_paged_attn = False - -try: - import vllm._custom_ops as ops -except Exception as e: - raise ImportError( - f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}" - ) - - -def reshape_and_cache( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slots: torch.Tensor, -): - if ATTENTION == "flashdecoding": - shape = key_cache.shape - key_cache.view(-1, shape[-2], shape[-1])[slots] = key - value_cache.view(-1, shape[-2], shape[-1])[slots] = value - else: - ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0) - - -def paged_attention( - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - kv_head_mapping: torch.Tensor, - softmax_scale: float, - block_tables: torch.Tensor, - seqlen: Seqlen, - max_s: int, - softcap: Optional[float] = None, -): - # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py - # Copyright 2023 The vLLM team. All rights - # reserved. - # - # Licensed under the Apache License, Version 2.0 (the "License"); - # you may not use this file except in compliance with the License. - # You may obtain a copy of the License at - # - # http://www.apache.org/licenses/LICENSE-2.0 - # - # Unless required by applicable law or agreed to in writing, software - # distributed under the License is distributed on an "AS IS" BASIS, - # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - # See the License for the specific language governing permissions and - # limitations under the License. 
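For reference, the slot bookkeeping done by `paged_reshape_and_cache` in the new `kv_cache.py` above, restated as a small runnable sketch (the `BLOCK_SIZE = 16` here is illustrative; the real value comes from `models.globals`):

    import torch

    BLOCK_SIZE = 16
    slots = torch.tensor([0, 15, 16, 35])           # flat slot ids for 4 tokens

    block_idx = slots // BLOCK_SIZE                 # tensor([0, 0, 1, 2])
    block_offset = slots % BLOCK_SIZE               # tensor([0, 15, 0, 3])
    # insert_or_update_cache(key, key_cache, block_idx, block_offset) then
    # writes each token's K/V into block `block_idx` at position `block_offset`.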
- # - - if softcap is not None: - raise RuntimeError("Paged attention doesn't support softcapping") - - # value_cache => [num_blocks, num_heads, head_size, block_size] - block_size = value_cache.shape[3] - num_seqs, num_heads, head_size = query.shape - - num_kv_heads = key_cache.shape[1] - gqa_ratio = num_heads // num_kv_heads - use_custom = ( - use_rocm_custom_paged_attn - and (query.dtype == torch.half or query.dtype == torch.bfloat16) - and (head_size == 128 or head_size == 64) - and (block_size == 16 or block_size == 32) - and (gqa_ratio >= 1 and gqa_ratio <= 16) - and max_s <= 32768 - ) - - if not use_custom: - _PARTITION_SIZE = _PARTITION_SIZE_V1V2 - else: - _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM - - max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE - input_lengths = seqlen.input_lengths - - out = torch.empty_like(query) - - # NOTE(woosuk): We use a simple heuristic to decide whether to use - # PagedAttention V1 or V2. If the number of partitions is 1, we use - # V1 to avoid the overhead of reduction. Also, if the number of - # sequences or heads is large, we use V1 since there is enough work - # to parallelize. - import vllm._custom_ops as ops - - use_v1 = ( - max_s <= 8192 - and (max_num_partitions == 1 or num_seqs * num_heads > 512) - and not use_custom - ) - if use_v1: - ops.paged_attention_v1( - out, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - 1.0, - ) - else: - # Run PagedAttention V2. - assert _PARTITION_SIZE % block_size == 0 - tmp_output = torch.empty( - size=(num_seqs, num_heads, max_num_partitions, head_size), - dtype=out.dtype, - device=out.device, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, max_num_partitions), - dtype=torch.float32, - device=out.device, - ) - max_logits = torch.empty_like(exp_sums) - - if not use_custom: - ops.paged_attention_v2( - out, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - 1.0, - ) - else: - paged_attention_custom( - out, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - ) - - return out - - -if ENGINE != "triton": - try: - import flash_attn_2_cuda - - log_master( - logger.info, - "ROCm: using Flash Attention 2 Composable Kernel implementation.", - ) - except ImportError as e: - if major >= 8: - architecture_suffix = f"-{SYSTEM}" - raise ImportError( - "Flash Attention V2 is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`" - ) - elif is_sm75: - raise ImportError( - "Flash Attention is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - "or install flash attention with `cd server && make install install-flash-attention`" - ) from e - else: - for idx in range(torch.cuda.device_count()): - name = torch.cuda.get_device_name(idx) - if "MI210" not in name and "MI250" not in name: - raise ImportError( - f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" - ) - raise ImportError( - f"AMD GPU with ROCm capability {major} {minor} is not supported" - ) from e - - -SUPPORTS_WINDOWING = 
False -if ENGINE == "ck": - - def attention( - q, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - seqlen: Seqlen, - block_tables: torch.Tensor, - softmax_scale: float, - window_size_left: int = -1, - causal: bool = True, - softcap: float = 0.0, - ): - if window_size_left <= 0 and window_size_left != -1: - raise ValueError("`window_size_left` must be > 0 or -1") - - out = torch.empty_like(q) - - # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. - return flash_attn_2_cuda.varlen_fwd( - q, - key_cache, - value_cache, - out, - seqlen.cu_seqlen_q, - seqlen.cu_seqlen_q, - None, - None, - None, - None, - seqlen.max_q, - seqlen.max_k, - 0.0, - softmax_scale, - False, - causal, - window_size_left, - 0, - softcap, - False, - None, - )[0] - -elif ENGINE == "triton": - from .flash_attn_triton import triton_attention - - def attention( - q, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - seqlen: Seqlen, - block_tables: torch.Tensor, - softmax_scale: float, - window_size_left: int = -1, - causal: bool = True, - softcap: Optional[float] = None, - ): - if softcap is not None: - raise NotImplementedError("softcap is only available with CK flash attn") - - out = torch.empty_like(q) - - # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. - output, _ = triton_attention( - q, - key_cache, - value_cache, - out, - seqlen.cu_seqlen_q, - seqlen.cu_seqlen_q, - seqlen.max_q, - seqlen.max_k, - causal, - softmax_scale, - ) - return output - -else: - raise RuntimeError(f"Unknown attention engine {ENGINE}") diff --git a/backends/gaudi/server/text_generation_server/layers/awq/quantize/__init__.py b/backends/gaudi/server/text_generation_server/layers/awq/quantize/__init__.py new file mode 100644 index 000000000..856d7c281 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/awq/quantize/__init__.py @@ -0,0 +1,3 @@ +from .hpu import WQLinear + +__all__ = ["WQLinear"] diff --git a/backends/gaudi/server/text_generation_server/layers/awq/quantize/hpu.py b/backends/gaudi/server/text_generation_server/layers/awq/quantize/hpu.py new file mode 100644 index 000000000..3af0131b3 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/awq/quantize/hpu.py @@ -0,0 +1,134 @@ +from typing import Optional +import torch +import torch.nn as nn + +try: + import habana_frameworks.torch.hpu # noqa: F401 + + convert_from_uint4 = torch.ops.hpu.convert_from_uint4 +except Exception as e: + hpu_import_exception = e + + def error_raiser_hpu(*args, **kwargs): + raise ValueError( + f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}" + ) + + convert_from_uint4 = error_raiser_hpu + +AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] + + +def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int): + shifts = torch.arange(0, 32, bits, device=qzeros.device) + + # unpacking columnwise + iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to( + torch.int8 # smallest dtype available + ) + iweights = iweights.view(iweights.shape[0], -1) + + # unpacking columnwise + if qzeros is not None: + izeros = torch.bitwise_right_shift( + qzeros[:, :, None], shifts[None, None, :] + ).to( + torch.int8 # smallest dtype available + ) + izeros = izeros.view(izeros.shape[0], -1) + else: + izeros = qzeros + + return iweights, izeros + + +def reverse_awq_order(iweights: torch.Tensor, izeros: 
torch.Tensor, bits: int):
+    reverse_order_tensor = torch.arange(
+        iweights.shape[-1],
+        dtype=torch.int32,
+        device=izeros.device,
+    )
+    reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
+    reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
+    reverse_order_tensor = reverse_order_tensor.view(-1)
+
+    if izeros is not None:
+        izeros = izeros[:, reverse_order_tensor]
+    iweights = iweights[:, reverse_order_tensor]
+
+    return iweights, izeros
+
+
+def unpack_weight_and_zeros(qweight, qzeros, bits):
+    # Unpack the qweight and qzeros tensors
+    iweight, izeros = unpack_awq(qweight, qzeros, bits)
+    # Reverse the order of the iweight and izeros tensors
+    iweight, izeros = reverse_awq_order(iweight, izeros, bits)
+
+    # overflow checks
+    iweight = torch.bitwise_and(iweight, (2**bits) - 1)
+    izeros = torch.bitwise_and(izeros, (2**bits) - 1)
+
+    return iweight, izeros
+
+
+def pack_tensor(input, bits=4):
+    normal = input.to(torch.int32)
+    q = torch.zeros(
+        (normal.shape[0], normal.shape[1] // 32 * bits),
+        dtype=torch.int32,
+        device=input.device,
+    )
+    i = 0
+    col = 0
+    while col < q.shape[1]:
+        for j in range(i, i + (32 // bits)):
+            q[:, col] |= normal[:, j] << (bits * (j - i))
+        i += 32 // bits
+        col += 1
+    q = q.to(torch.int32)
+    return q
+
+
+class WQLinear(nn.Module):
+    def __init__(
+        self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
+    ):
+        super().__init__()
+
+        if w_bit not in [4]:
+            raise NotImplementedError("Only 4-bit is supported for now.")
+
+        self.in_features = qweight.shape[0]
+        self.out_features = qweight.shape[1] * 32 // w_bit
+
+        self.w_bit = w_bit
+        self.group_size = group_size if group_size != -1 else self.in_features
+        # quick sanity check (make sure alignment)
+        assert self.in_features % self.group_size == 0
+        assert self.out_features % (32 // self.w_bit) == 0
+
+        self.qweight = qweight
+        self.qzeros = qzeros
+        self.scales = scales
+        self.bias = bias
+        self._preprocessing()
+
+    def _preprocessing(self):
+        device = self.qweight.device
+        weight, zeros = unpack_weight_and_zeros(
+            self.qweight.cpu(), self.qzeros.cpu(), self.w_bit
+        )
+        self.qweight = pack_tensor(weight).to(device)
+        self.qzeros = pack_tensor(zeros).to(device)
+
+    @torch.no_grad()
+    def forward(self, x):
+        out_shape = x.shape[:-1] + (self.out_features,)
+        x = x.reshape(-1, x.shape[-1])
+        weights = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype)
+        outputs = torch.matmul(x, weights)
+
+        outputs = outputs + self.bias if self.bias is not None else outputs
+        outputs = outputs.reshape(out_shape)
+        return outputs
diff --git a/backends/gaudi/server/text_generation_server/layers/awq/quantize/qmodule.py b/backends/gaudi/server/text_generation_server/layers/awq/quantize/qmodule.py
deleted file mode 100644
index 391371a55..000000000
--- a/backends/gaudi/server/text_generation_server/layers/awq/quantize/qmodule.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py
-
-from typing import Optional
-import torch
-import torch.nn as nn
-import awq_inference_engine  # with CUDA kernels
-
-
-# class ScaledActivation(nn.Module):
-#     def __init__(self, module, scales):
-#         super().__init__()
-#         self.act = module
-#         self.scales = nn.Parameter(scales.data)
-#
-#     def forward(self, x):
-#         return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
-
-
-class WQLinear(nn.Module):
-    def __init__(
-        self, w_bit, group_size, qweight, qzeros, scales, bias:
Optional[torch.Tensor] - ): - super().__init__() - - if w_bit not in [4]: - raise NotImplementedError("Only 4-bit are supported for now.") - - self.in_features = qweight.shape[0] - self.out_features = qweight.shape[1] * 32 // w_bit - - self.w_bit = w_bit - self.group_size = group_size if group_size != -1 else self.in_features - # quick sanity check (make sure aligment) - assert self.in_features % self.group_size == 0 - assert self.out_features % (32 // self.w_bit) == 0 - - self.qweight = qweight - self.qzeros = qzeros - self.scales = scales - self.bias = bias - - @torch.no_grad() - def forward(self, x): - out_shape = x.shape[:-1] + (self.out_features,) - out = awq_inference_engine.gemm_forward_cuda( - x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8 - ) - out = out + self.bias if self.bias is not None else out - return out.reshape(out_shape) diff --git a/backends/gaudi/server/text_generation_server/layers/eetq.py b/backends/gaudi/server/text_generation_server/layers/eetq.py deleted file mode 100644 index b1e5235a0..000000000 --- a/backends/gaudi/server/text_generation_server/layers/eetq.py +++ /dev/null @@ -1,43 +0,0 @@ -from dataclasses import dataclass - -import torch -from EETQ import quant_weights, w8_a16_gemm -from text_generation_server.utils.weights import UnquantizedWeight - - -@dataclass -class EETQWeight(UnquantizedWeight): - weight: torch.Tensor - - def get_linear(self, bias: torch.Tensor): - try: - from text_generation_server.layers.eetq import EETQLinear - - return EETQLinear(self.weight, bias) - except ImportError: - raise ImportError( - "Please install EETQ from https://github.com/NetEase-FuXi/EETQ" - ) - - -class EETQLinear(torch.nn.Module): - def __init__( - self, - weight, - bias, - ) -> None: - super().__init__() - device = weight.device - if weight.dtype != torch.float16: - weight = weight.to(dtype=torch.float16) - weight = torch.t(weight).contiguous().cpu() - weight, scale = quant_weights(weight, torch.int8, False) - - self.weight = weight.cuda(device) - self.scale = scale.cuda(device) - self.bias = bias.cuda(device) if bias is not None else None - - def forward(self, input: torch.Tensor) -> torch.Tensor: - output = w8_a16_gemm(input, self.weight, self.scale) - output = output + self.bias if self.bias is not None else output - return output diff --git a/backends/gaudi/server/text_generation_server/layers/fp8.py b/backends/gaudi/server/text_generation_server/layers/fp8.py index 61dd51151..0dc5cdafd 100644 --- a/backends/gaudi/server/text_generation_server/layers/fp8.py +++ b/backends/gaudi/server/text_generation_server/layers/fp8.py @@ -1,100 +1,152 @@ +from dataclasses import dataclass +from typing import Optional, Tuple, Type, Union, List + import torch -from dataclasses import dataclass -from typing import Optional, Union, List -from loguru import logger - -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.weights import ( Weight, WeightsLoader, UnquantizedWeight, Weights, ) -from text_generation_server.utils.log import log_master, log_once -import importlib.util + +from vllm_hpu_extension.ops import scaled_fp8_quant +from vllm_hpu_extension.scales import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2 +import habana_frameworks.torch.utils.experimental as htexp + +w8a8_block_fp8_matmul = None +per_token_group_quant_fp8 = None +quant_dtype: torch.dtype = torch.float8_e4m3fn -FBGEMM_MM_AVAILABLE = False -FBGEMM_DYN_AVAILABLE = False - - -def is_fbgemm_gpu_available(): - try: - return 
importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None - except ModuleNotFoundError: - return False - - -if is_fbgemm_gpu_available(): - if SYSTEM == "cuda": - major, _ = torch.cuda.get_device_capability() - FBGEMM_MM_AVAILABLE = major == 9 - FBGEMM_DYN_AVAILABLE = major >= 8 -else: - log_master(logger.warning, "FBGEMM fp8 kernels are not installed.") - - -def get_fp8_linear() -> torch.nn.Module: +def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]: """ Return an FP8 linear `Module` that is compatible with the current system. """ - - if SYSTEM == "cuda": - major, _ = torch.cuda.get_device_capability() - if major == 8: - from text_generation_server.layers.marlin import GPTQMarlinFP8Linear - - return GPTQMarlinFP8Linear - # On other systems let Torch decide if the hardware supports FP8. return Fp8Linear -def fp8_quantize( - weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False -): - if FBGEMM_DYN_AVAILABLE and not scalar: - qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row( - weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype - ) - return qweight, scale +def normalize_e4m3fn_to_native_float8( + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + return weight, weight_scale, input_scale - # weight, scale = quant_weights(weight, torch.int8, False) - finfo = torch.finfo(qdtype) - # Calculate the scale as dtype max divided by absmax - scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound) - # scale and clamp the tensor to bring it to - # the representative range of float8 data type - # (as default cast is unsaturated) - qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max) - # Return both float8 data and the inverse scale (as float), - # as both required as inputs to torch._scaled_mm - qweight = qweight.to(qdtype) - scale = scale.float().reciprocal() - return qweight, scale + +def per_tensor_dequantize( + tensor: torch.Tensor, + inv_scale: Union[float, torch.Tensor], + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + device = tensor.device + dtype = torch.bfloat16 + if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2: + # dequant on cpu to avoid nan on gaudi2 + tensor = tensor.to("cpu") + + fake_qweight = tensor.to(dtype).to(device) + dq_weight = fake_qweight * inv_scale + return dq_weight + + +def requantize_with_max_scale( + weight: torch.Tensor, + weight_scale: torch.Tensor, + logical_widths: int, + dtype: torch.dtype, +) -> Tuple[torch.Tensor, torch.Tensor]: + # Max scale to be used for requanitzation. + max_w_scale = weight_scale.max() + + if is_hpu_gaudi2(): + max_w_scale = max_w_scale * get_hpu_gaudi2_scale_factor() + + start = 0 + for idx, logical_width in enumerate(logical_widths): + end = start + logical_width + weight_dq = per_tensor_dequantize( + weight[start:end, :], weight_scale[idx], dtype + ) + weight[start:end, :], max_w_scale_normalized = fp8_quantize( + weight_dq, max_w_scale + ) + start = end + + return weight, max_w_scale_normalized + + +def fp8_quantize( + weight: torch.Tensor, + scale: Optional[torch.Tensor] = None, + scale_upper_bound: Optional[torch.Tensor] = None, + qdtype: torch.dtype = torch.float8_e4m3fn, + scalar: bool = False, +): + """ + This function returns a reciprocal of the scale, so that a tensor can be unscaled + by multiplying it with the returned scale. 
If a scale is given through the `scale` + argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can + be used without modification). + """ + shape = weight.shape + qweight, scale = scaled_fp8_quant( + weight.reshape(-1, shape[-1]), + scale=scale, + scale_ub=scale_upper_bound, + # TODO: don't do this when we have to use the Torch kernel. + use_per_token_if_dynamic=not scalar, + ) + + return qweight.reshape(shape), scale class HybridFP8UnquantLoader(WeightsLoader): """Weight loader that loads FP8 and unquantized Torch tensors.""" - def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool): + def __init__( + self, + activation_scale_ub: Optional[float], + to_fp8: bool, + weight_block_size: Optional[List[int]] = None, + ): self.activation_scale_ub = activation_scale_ub self.to_fp8 = to_fp8 + self.weight_block_size = weight_block_size def get_weights(self, weights: "Weights", prefix: str): w = weights.get_tensor(f"{prefix}.weight") if w.dtype == torch.float8_e4m3fn: + if self.weight_block_size is not None: + scale = weights.get_tensor(f"{prefix}.weight_scale_inv") + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + weight_block_size=self.weight_block_size, + ) # FP8 branch - scale = ( - weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) - .reshape(-1) - .expand(w.shape[0]) + scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + + input_scale = None + if weights.has_tensor(f"{prefix}.input_scale"): + input_scale = ( + weights.get_tensor(f"{prefix}.input_scale", to_dtype=False) + .reshape(-1) + .max() + ) + logical_widths = [w.shape[0]] + w, scale = requantize_with_max_scale( + w, scale.unsqueeze(0), logical_widths, weights.dtype ) + return Fp8Weight( weight=w, weight_scale=scale, + input_scale=input_scale, activation_scale_ub=self.activation_scale_ub, dtype=weights.dtype, ) @@ -116,6 +168,7 @@ class HybridFP8UnquantLoader(WeightsLoader): if w.dtype == torch.float8_e4m3fn: # FP8 branch scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + if scale.numel() > 1: scale = weights.get_packed_sharded( f"{prefix}.weight_scale", @@ -123,11 +176,29 @@ class HybridFP8UnquantLoader(WeightsLoader): block_sizes=block_sizes, to_dtype=False, ) - scale = scale.reshape(-1).expand(w.shape[0]) + + input_scale = None + if weights.has_tensor(f"{prefix}.input_scale"): + input_scale = weights.get_tensor( + f"{prefix}.input_scale", to_dtype=False + ) + if input_scale.numel() > 1: + input_scale = weights.get_packed_sharded( + f"{prefix}.input_scale", + dim=0, + block_sizes=block_sizes, + to_dtype=False, + ) + input_scale = input_scale.reshape(-1).max() + logical_widths = [w.shape[0]] + w, scale = requantize_with_max_scale( + w, scale.unsqueeze(0), logical_widths, weights.dtype + ) return Fp8Weight( weight=w, weight_scale=scale, + input_scale=input_scale, activation_scale_ub=self.activation_scale_ub, dtype=weights.dtype, ) @@ -148,15 +219,48 @@ class HybridFP8UnquantLoader(WeightsLoader): # FP8 branch if w.dtype == torch.float8_e4m3fn: + if self.weight_block_size is not None: + scale = [ + weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False) + for p in prefixes + ] + scale = torch.cat(scale, dim=dim) + scale = scale.to(weights.device) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + weight_block_size=self.weight_block_size, + ) + scale = [ _load_scalar_or_matrix_scale(weights, 
f"{p}.weight_scale", shape) for p, shape in zip(prefixes, shapes) ] scale = torch.cat(scale, dim=0).reshape(-1) + input_scale = [ + _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape) + for p, shape in zip(prefixes, shapes) + if weights.has_tensor(f"{p}.input_scale") + ] + assert len(input_scale) == 0 or len(input_scale) == len(prefixes) + input_scale = ( + torch.cat(input_scale, dim=0).reshape(-1).max() + if len(input_scale) != 0 + else None + ) + + logical_widths = [x[0] for x in shapes] + w, scale = requantize_with_max_scale( + w, scale.to(weights.device), logical_widths, weights.dtype + ) + return Fp8Weight( weight=w, weight_scale=scale, + input_scale=input_scale, activation_scale_ub=self.activation_scale_ub, dtype=weights.dtype, ) @@ -169,14 +273,35 @@ class HybridFP8UnquantLoader(WeightsLoader): w = weights.get_sharded(f"{prefix}.weight", dim=1) # FP8 branch if w.dtype == torch.float8_e4m3fn: - scale = ( - weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) - .reshape(-1) - .expand(w.shape[0]) + if self.weight_block_size is not None: + # XXX: Yes the weights is named scale_inv, but corresponds to scale it seems. + scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1) + + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + weight_block_size=self.weight_block_size, + ) + + scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + + input_scale = None + if weights.has_tensor(f"{prefix}.input_scale"): + input_scale = ( + weights.get_tensor(f"{prefix}.input_scale", to_dtype=False) + .reshape(-1) + .max() + ) + logical_widths = [w.shape[0]] + w, scale = requantize_with_max_scale( + w, scale.unsqueeze(0), logical_widths, weights.dtype ) return Fp8Weight( weight=w, weight_scale=scale, + input_scale=input_scale, activation_scale_ub=self.activation_scale_ub, dtype=weights.dtype, ) @@ -191,83 +316,126 @@ class Fp8Weight(Weight): weight: torch.Tensor dtype: torch.dtype weight_scale: Optional[torch.Tensor] = None + input_scale: Optional[torch.Tensor] = None activation_scale_ub: Optional[float] = None + force_w8a16: bool = False + weight_block_size: Optional[List[int]] = None def get_linear(self, bias: torch.Tensor): if self.weight_scale is None: - return get_fp8_linear().from_unquant(self.weight, bias, self.dtype) + return get_fp8_linear(force_w8a16=self.force_w8a16).from_unquant( + self.weight, bias, self.dtype + ) # This is not checked by the fbgemm kernels, but they require contiguous # memory. Can be non-contiguous when we e.g. expand from scalars. 
self.weight_scale = self.weight_scale.contiguous() - return get_fp8_linear().from_fp8( - self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype + return get_fp8_linear(force_w8a16=self.force_w8a16).from_fp8( + weight=self.weight, + scale=self.weight_scale, + dtype=self.dtype, + bias=bias, + input_scale=self.input_scale, + scale_upper_bound=self.activation_scale_ub, + weight_block_size=self.weight_block_size, ) class Fp8Linear(torch.nn.Module): + _device_identity_cache = {} + def __init__( self, - qweight, - scale, - scale_upper_bound, - bias, - dtype, + qweight: torch.Tensor, + scale: torch.Tensor, + dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, + input_scale: Optional[torch.Tensor] = None, + scale_upper_bound: Optional[float] = None, + weight_block_size: Optional[List[int]] = None, ) -> None: super().__init__() - if FBGEMM_MM_AVAILABLE: - log_once(logger.info, "Using FBGEMM fp8 optimized kernels") self.dtype = dtype self.qweight = qweight - self.scale = scale - self.scale_upper_bound = ( - torch.tensor( - [scale_upper_bound], dtype=torch.float32, device=qweight.device - ) - if scale_upper_bound is not None - else None - ) + self.scale = scale.float() + self.input_scale = input_scale.float() if input_scale is not None else None + self.weight_block_size = weight_block_size + self.scale_upper_bound = scale_upper_bound self.bias = bias if bias is not None else None @classmethod def from_unquant(cls, weight, bias, dtype): - qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE) + qweight, scale = fp8_quantize(weight, scalar=True) return cls( - qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype + qweight=qweight, + scale=scale, + dtype=dtype, + bias=bias, + input_scale=None, + scale_upper_bound=None, ) @classmethod - def from_fp8(cls, weight, scale, input_scale, bias, dtype): - if FBGEMM_DYN_AVAILABLE: - # fbgemm needs float32 scales. - scale = scale.float() + def from_fp8( + cls, + weight: torch.Tensor, + scale: torch.Tensor, + dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, + **kwargs, + ) -> "Fp8Linear": + input_scale = kwargs.get("input_scale", None) + scale_upper_bound = kwargs.get("scale_upper_bound", None) + weight_block_size = kwargs.get("weight_block_size", None) + return cls( qweight=weight, scale=scale, - scale_upper_bound=input_scale, + input_scale=input_scale, + scale_upper_bound=scale_upper_bound, bias=bias, dtype=dtype, + weight_block_size=weight_block_size, ) - def forward(self, input: torch.Tensor) -> torch.Tensor: - if FBGEMM_MM_AVAILABLE: - qinput, scale = fp8_quantize( - input, scale_upper_bound=self.scale_upper_bound - ) + @classmethod + def get_shared_device_identity(cls, device): + # Input scaling factors are no longer optional in _scaled_mm starting + # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale + if device not in cls._device_identity_cache: + cls._device_identity_cache[device] = torch.ones(1, device=device) + return cls._device_identity_cache[device] - y = torch.ops.fbgemm.f8f8bf16_rowwise( + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.weight_block_size is not None: + # https://arxiv.org/pdf/2412.19437 + # At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and + # scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we + # group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output + # channels). 
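+            # For example, with weight_block_size == [128, 128], a (4096, 4096)
+            # weight carries a (32, 32) grid of block scales, while an activation
+            # of shape (num_tokens, 4096) is quantized per token in groups of
+            # 128 channels, yielding (num_tokens, 32) scales.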
+ qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1]) + output = w8a8_block_fp8_matmul( qinput, self.qweight, scale, self.scale, - use_fast_accum=True, - bias=self.bias, + self.weight_block_size, + output_dtype=input.dtype, ) - return y.to(self.dtype) - qinput, scale = fp8_quantize(input, scalar=True) - output, _ = torch._scaled_mm( + if self.bias is not None: + output = output + self.bias + return output.to(dtype=input.dtype) + + qinput, scale = fp8_quantize( + input, + self.input_scale, + scale_upper_bound=self.scale_upper_bound, + scalar=True, + ) + + output = torch._scaled_mm( qinput, self.qweight.t(), out_dtype=self.dtype, @@ -275,11 +443,16 @@ class Fp8Linear(torch.nn.Module): scale_b=self.scale, bias=self.bias, ) + + if isinstance(output, tuple) and len(output) == 2: + output = output[0] + return output def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size): scale = weights.get_tensor(prefix, to_dtype=False) + if scale.numel() > 1: scale = weights.get_sharded(prefix, dim=0, to_dtype=False) - return scale.reshape(-1).expand(shape[0]) + return scale.reshape(-1) diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py index 505caa59a..90b8f6923 100644 --- a/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py @@ -1,14 +1,15 @@ -import os from dataclasses import dataclass from typing import List, Optional, Union import torch from loguru import logger -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.log import log_once from text_generation_server.utils.weights import Weight, Weights, WeightsLoader +from .hpu import QuantLinear + + @dataclass class GPTQWeight(Weight): qweight: torch.Tensor @@ -30,13 +31,8 @@ class GPTQWeight(Weight): def get_linear(self, bias: torch.Tensor): if self.use_awq_kernel: - if SYSTEM == "rocm": - raise NotImplementedError( - "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead " - "to use Exllama/GPTQ kernels for AWQ inference." - ) try: - from text_generation_server.layers.awq.quantize.qmodule import WQLinear + from text_generation_server.layers.awq.quantize import WQLinear return WQLinear( w_bit=self.bits, @@ -50,18 +46,7 @@ class GPTQWeight(Weight): raise NotImplementedError( "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly" ) - elif self.use_exllama: - try: - from text_generation_server.layers.gptq import ExllamaQuantLinear - except ImportError: - raise NotImplementedError( - "Exllama gptq kernels are not installed. 
Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`" - ) - - return ExllamaQuantLinear(self, bias) else: - from text_generation_server.layers.gptq.quant_linear import QuantLinear - return QuantLinear( self.qweight, self.qzeros, @@ -118,23 +103,6 @@ class GPTQWeightsLoader(WeightsLoader): else: g_idx = None - from text_generation_server.layers.gptq import ( - HAS_EXLLAMA, - CAN_EXLLAMA, - GPTQWeight, - ) - - if use_exllama: - if not HAS_EXLLAMA: - if CAN_EXLLAMA: - log_once( - logger.warning, - "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True", - ) - use_exllama = False - else: - log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") - qzeros = weights.get_tensor(f"{prefix}.qzeros") scales = weights.get_tensor(f"{prefix}.scales") @@ -247,14 +215,7 @@ class GPTQWeightsLoader(WeightsLoader): [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 ) - from text_generation_server.layers.gptq import HAS_EXLLAMA - - use_exllama = ( - self.bits == 4 - and HAS_EXLLAMA - and self.quantize == "gptq" - and not self.desc_act - ) + use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act if self.quantize == "gptq" and self.quant_method == "gptq": w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] @@ -298,6 +259,7 @@ class GPTQWeightsLoader(WeightsLoader): self._get_gptq_params(weights) use_exllama = True + desc_act = self.desc_act if self.bits != 4: use_exllama = False @@ -321,7 +283,8 @@ class GPTQWeightsLoader(WeightsLoader): if g_idx is not None: if ( not torch.equal( - g_idx.cpu(), + # Remove g_idx[0] to adapt the check with TP>1. + (g_idx - g_idx[0]).cpu(), torch.tensor( [i // self.groupsize for i in range(g_idx.shape[0])], dtype=torch.int32, @@ -332,34 +295,22 @@ class GPTQWeightsLoader(WeightsLoader): # Exllama implementation does not support row tensor parallelism with act-order, as # it would require to reorder input activations that are split unto several GPUs use_exllama = False + desc_act = True from text_generation_server.layers.gptq import ( - CAN_EXLLAMA, - HAS_EXLLAMA, GPTQWeight, ) - if use_exllama: - if not HAS_EXLLAMA: - if CAN_EXLLAMA: - log_once( - logger.warning, - "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True", - ) - use_exllama = False - else: - log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") - - if use_exllama and self.groupsize != -1: + if not desc_act and self.groupsize != -1: qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0) scales = weights.get_sharded(f"{prefix}.scales", dim=0) + if g_idx is not None: + # qzeros, scales sharded, and g_idx must be adjusted accordingly + g_idx = g_idx - g_idx[0] else: qzeros = weights.get_tensor(f"{prefix}.qzeros") scales = weights.get_tensor(f"{prefix}.scales") - if use_exllama and g_idx is not None: - g_idx = g_idx - g_idx[0] - if self.quantize == "gptq" and self.quant_method == "awq": log_once( logger.info, "Converting AWQ model to Exllama/GPTQ packing format." 
@@ -392,7 +343,7 @@ class GPTQWeightsLoader(WeightsLoader): ) def _get_gptq_params(self, weights: Weights): - if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"): + if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"): self.bits = weights.get_tensor("gptq_bits").item() self.groupsize = weights.get_tensor("gptq_groupsize").item() self.desc_act = False @@ -400,41 +351,7 @@ class GPTQWeightsLoader(WeightsLoader): # before the `gptq_sym` setting tensor was added. self.sym = ( weights.get_tensor("gptq_sym").item() - if weights._has_tensor("gptq_sym") + if weights.has_tensor("gptq_sym") else False ) self.quant_method = "gptq" - - -# Needs to be at the end because circular import. -try: - major, _minor = torch.cuda.get_device_capability() -except Exception: - major = 1 - -HAS_EXLLAMA = False -CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm" -V2 = os.getenv("EXLLAMA_VERSION", "2") == "2" -if os.getenv("DISABLE_EXLLAMA") == "True": - HAS_EXLLAMA = False -elif CAN_EXLLAMA: - try: - if V2: - from text_generation_server.layers.gptq.exllamav2 import ( - QuantLinear as ExllamaQuantLinear, # noqa: F401 - create_exllama_buffers, # noqa: F401 - set_device, # noqa: F401 - ) - - HAS_EXLLAMA = "2" - else: - from text_generation_server.layers.gptq.exllama import ( - Ex4bitLinear as ExllamaQuantLinear, # noqa: F401 - create_exllama_buffers, # noqa: F401 - set_device, # noqa: F401 - ) - - HAS_EXLLAMA = "1" - - except ImportError: - pass diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/custom_autotune.py b/backends/gaudi/server/text_generation_server/layers/gptq/custom_autotune.py deleted file mode 100644 index 0388ef20b..000000000 --- a/backends/gaudi/server/text_generation_server/layers/gptq/custom_autotune.py +++ /dev/null @@ -1,261 +0,0 @@ -# https://github.com/fpgaminer/GPTQ-triton -""" -Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100. -""" - -import builtins -import math -import time -from typing import Dict - -import triton - - -class Autotuner(triton.KernelInterface): - def __init__( - self, - fn, - arg_names, - configs, - key, - reset_to_zero, - prune_configs_by: Dict = None, - nearest_power_of_two: bool = False, - ): - """ - :param prune_configs_by: a dict of functions that are used to prune configs, fields: - 'perf_model': performance model used to predicate running time with different configs, returns running time - 'top_k': number of configs to bench - 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. 
- 'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results - """ - if not configs: - self.configs = [triton.Config({}, num_warps=4, num_stages=2)] - else: - self.configs = configs - self.key_idx = [arg_names.index(k) for k in key] - self.nearest_power_of_two = nearest_power_of_two - self.cache = {} - # hook to reset all required tensor to zeros before relaunching a kernel - self.hook = lambda args: 0 - if reset_to_zero is not None: - self.reset_idx = [arg_names.index(k) for k in reset_to_zero] - - def _hook(args): - for i in self.reset_idx: - args[i].zero_() - - self.hook = _hook - self.arg_names = arg_names - # prune configs - if prune_configs_by: - perf_model, top_k = ( - prune_configs_by["perf_model"], - prune_configs_by["top_k"], - ) - if "early_config_prune" in prune_configs_by: - early_config_prune = prune_configs_by["early_config_prune"] - else: - perf_model, top_k, early_config_prune = None, None, None - self.perf_model, self.configs_top_k = perf_model, top_k - self.early_config_prune = early_config_prune - self.fn = fn - - def _bench(self, *args, config, **meta): - # check for conflicts, i.e. meta-parameters both provided - # as kwargs and by the autotuner - conflicts = meta.keys() & config.kwargs.keys() - if conflicts: - raise ValueError( - f"Conflicting meta-parameters: {', '.join(conflicts)}." - " Make sure that you don't re-define auto-tuned symbols." - ) - # augment meta-parameters with tunable ones - current = dict(meta, **config.kwargs) - - def kernel_call(): - if config.pre_hook: - config.pre_hook(self.nargs) - self.hook(args) - self.fn.run( - *args, - num_warps=config.num_warps, - num_stages=config.num_stages, - **current, - ) - - try: - # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses - # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default - return triton.testing.do_bench( - kernel_call, quantiles=(0.5, 0.2, 0.8), rep=40 - ) - except triton.OutOfResources: - return [float("inf"), float("inf"), float("inf")] - - def run(self, *args, **kwargs): - self.nargs = dict(zip(self.arg_names, args)) - if len(self.configs) > 1: - key = tuple(args[i] for i in self.key_idx) - - # This reduces the amount of autotuning by rounding the keys to the nearest power of two - # In my testing this gives decent results, and greatly reduces the amount of tuning required - if self.nearest_power_of_two: - key = tuple([2 ** int(math.log2(x) + 0.5) for x in key]) - - if key not in self.cache: - # prune configs - pruned_configs = self.prune_configs(kwargs) - bench_start = time.time() - timings = { - config: self._bench(*args, config=config, **kwargs) - for config in pruned_configs - } - bench_end = time.time() - self.bench_time = bench_end - bench_start - self.cache[key] = builtins.min(timings, key=timings.get) - self.hook(args) - self.configs_timings = timings - config = self.cache[key] - else: - config = self.configs[0] - self.best_config = config - if config.pre_hook is not None: - config.pre_hook(self.nargs) - return self.fn.run( - *args, - num_warps=config.num_warps, - num_stages=config.num_stages, - **kwargs, - **config.kwargs, - ) - - def prune_configs(self, kwargs): - pruned_configs = self.configs - if self.early_config_prune: - pruned_configs = self.early_config_prune(self.configs, self.nargs) - if self.perf_model: - top_k = self.configs_top_k - if isinstance(top_k, float) and top_k <= 1.0: - top_k = int(len(self.configs) * top_k) - if 
len(pruned_configs) > top_k: - est_timing = { - config: self.perf_model( - **self.nargs, - **kwargs, - **config.kwargs, - num_stages=config.num_stages, - num_warps=config.num_warps, - ) - for config in pruned_configs - } - pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[ - :top_k - ] - return pruned_configs - - def warmup(self, *args, **kwargs): - self.nargs = dict(zip(self.arg_names, args)) - for config in self.prune_configs(kwargs): - self.fn.warmup( - *args, - num_warps=config.num_warps, - num_stages=config.num_stages, - **kwargs, - **config.kwargs, - ) - self.nargs = None - - -def autotune( - configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False -): - """ - Decorator for auto-tuning a :code:`triton.jit`'d function. - .. highlight:: python - .. code-block:: python - @triton.autotune(configs=[ - triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), - triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), - ], - key=['x_size'] # the two above configs will be evaluated anytime - # the value of x_size changes - ) - @triton.jit - def kernel(x_ptr, x_size, **META): - BLOCK_SIZE = META['BLOCK_SIZE'] - :note: When all the configurations are evaluated, the kernel will run multiple time. - This means that whatever value the kernel updates will be updated multiple times. - To avoid this undesired behavior, you can use the `reset_to_zero` argument, which - reset the value of the provided tensor to `zero` before running any configuration. - :param configs: a list of :code:`triton.Config` objects - :type configs: list[triton.Config] - :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. - :type key: list[str] - :param prune_configs_by: a dict of functions that are used to prune configs, fields: - 'perf_model': performance model used to predicate running time with different configs, returns running time - 'top_k': number of configs to bench - 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs. - :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. - :type reset_to_zero: list[str] - """ - - def decorator(fn): - return Autotuner( - fn, - fn.arg_names, - configs, - key, - reset_to_zero, - prune_configs_by, - nearest_power_of_two, - ) - - return decorator - - -def matmul248_kernel_config_pruner(configs, nargs): - """ - The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller. 
- """ - m = max(2 ** int(math.ceil(math.log2(nargs["M"]))), 16) - n = max(2 ** int(math.ceil(math.log2(nargs["N"]))), 16) - k = max(2 ** int(math.ceil(math.log2(nargs["K"]))), 16) - - used = set() - for config in configs: - block_size_m = min(m, config.kwargs["BLOCK_SIZE_M"]) - block_size_n = min(n, config.kwargs["BLOCK_SIZE_N"]) - block_size_k = min(k, config.kwargs["BLOCK_SIZE_K"]) - group_size_m = config.kwargs["GROUP_SIZE_M"] - - if ( - block_size_m, - block_size_n, - block_size_k, - group_size_m, - config.num_stages, - config.num_warps, - ) in used: - continue - - used.add( - ( - block_size_m, - block_size_n, - block_size_k, - group_size_m, - config.num_stages, - config.num_warps, - ) - ) - yield triton.Config( - { - "BLOCK_SIZE_M": block_size_m, - "BLOCK_SIZE_N": block_size_n, - "BLOCK_SIZE_K": block_size_k, - "GROUP_SIZE_M": group_size_m, - }, - num_stages=config.num_stages, - num_warps=config.num_warps, - ) diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/exllama.py b/backends/gaudi/server/text_generation_server/layers/gptq/exllama.py deleted file mode 100644 index f27666b77..000000000 --- a/backends/gaudi/server/text_generation_server/layers/gptq/exllama.py +++ /dev/null @@ -1,134 +0,0 @@ -from text_generation_server.layers.gptq import GPTQWeight -import torch -from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params - -# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension -none_tensor = torch.empty((1, 1), device="meta") - - -def ext_make_q4(qweight, qzeros, scales, g_idx, device): - """Construct Q4Matrix, return handle""" - return make_q4( - qweight, qzeros, scales, g_idx if g_idx is not None else none_tensor, device - ) - - -def ext_q4_matmul(x, q4, q4_width): - """Matrix multiplication, returns x @ q4""" - outshape = x.shape[:-1] + (q4_width,) - x = x.view(-1, x.shape[-1]) - output = torch.empty((x.shape[0], q4_width), dtype=torch.float16, device=x.device) - - q4_matmul(x, q4, output) - - return output.view(outshape) - - -MAX_DQ = 1 -MAX_INNER = 1 -ACT_ORDER = False -DEVICE = None - -TEMP_STATE = None -TEMP_DQ = None - - -def set_device(device): - global DEVICE - DEVICE = device - - -def create_exllama_buffers(max_total_tokens: int): - global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ - - assert DEVICE is not None, "call set_device first" - - if not ACT_ORDER: - max_total_tokens = 1 - - # This temp_state buffer is required to reorder X in the act-order case. - temp_state = torch.zeros( - (max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE - ) - temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE) - - # This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. 
- prepare_buffers(DEVICE, temp_state, temp_dq) - - matmul_recons_thd = 8 - matmul_fused_remap = False - matmul_no_half2 = False - set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) - - TEMP_STATE, TEMP_DQ = temp_state, temp_dq - - -class Ex4bitLinear(torch.nn.Module): - """Linear layer implementation with per-group 4-bit quantization of the weights""" - - def __init__(self, weight: GPTQWeight, bias): - super().__init__() - global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE - assert weight.bits == 4 - - self.device = weight.qweight.device - self.qweight = weight.qweight - self.qzeros = weight.qzeros - self.scales = weight.scales - self.g_idx = weight.g_idx.cpu() if weight.g_idx is not None else None - self.bias = bias if bias is not None else None - - if self.g_idx is not None and ( - (self.g_idx == 0).all() - or torch.equal( - weight.g_idx.cpu(), - torch.tensor( - [i // weight.groupsize for i in range(weight.g_idx.shape[0])], - dtype=torch.int32, - ), - ) - ): - self.empty_g_idx = True - self.g_idx = None - - assert self.device.type == "cuda" - assert self.device.index is not None - - self.q4 = ext_make_q4( - self.qweight, self.qzeros, self.scales, self.g_idx, self.device.index - ) - - self.height = weight.qweight.shape[0] * 8 - self.width = weight.qweight.shape[1] - - # Infer groupsize from height of qzeros - self.groupsize = None - if self.qzeros.shape[0] > 1: - self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0]) - - if self.groupsize is not None: - assert weight.groupsize == self.groupsize - - # Handle act-order matrix - if self.g_idx is not None: - if self.groupsize is None: - raise ValueError("Found group index but no groupsize. What do?") - self.act_order = True - else: - self.act_order = False - - DEVICE = self.qweight.device - - MAX_DQ = max(MAX_DQ, self.qweight.numel() * 8) - - if self.act_order: - MAX_INNER = max(MAX_INNER, self.height, self.width) - - ACT_ORDER = True - - def forward(self, x): - out = ext_q4_matmul(x, self.q4, self.width) - - if self.bias is not None: - out.add_(self.bias) - return out diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/exllamav2.py b/backends/gaudi/server/text_generation_server/layers/gptq/exllamav2.py deleted file mode 100644 index 920a6adf4..000000000 --- a/backends/gaudi/server/text_generation_server/layers/gptq/exllamav2.py +++ /dev/null @@ -1,267 +0,0 @@ -# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2 - -from dataclasses import dataclass -from typing import Optional -import torch -import torch.nn as nn - -from loguru import logger - -from text_generation_server.layers.exl2 import Exl2Weight -from text_generation_server.layers.gptq import GPTQWeight -from text_generation_server.utils.log import log_master - -try: - from exllamav2.ext import exllamav2_ext - - make_q_matrix = exllamav2_ext.make_q_matrix - gemm_half_q_half = exllamav2_ext.gemm_half_q_half -except ImportError: - log_master(logger.warning, "exllamav2_kernels not installed.") - raise - -# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension -none_tensor = torch.empty((1, 1), device="meta") - - -@dataclass -class _ExtraTensors: - """Additional generated quantizer tensors.""" - - q_group_map: Optional[torch.Tensor] = None - q_invperm: Optional[torch.Tensor] = None - q_perm: Optional[torch.Tensor] = None - - -def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda): - """Matrix multiplication, returns x @ q4""" - output_shape = x.shape[:-1] + (q4_width,) - x = 
x.view(-1, x.shape[-1]) - output = torch.empty((x.shape[0], q4_width), dtype=torch.half, device=x.device) - gemm_half_q_half(x, q_handle, output, force_cuda) - return output.view(output_shape) - - -def make_group_map(q_groups: torch.Tensor, num_qrows: int): - gr = q_groups.tolist() - group_map = [] - num_groups = len(gr) // 2 - - for i in range(num_groups): - bits = gr[i * 2] - if i < num_groups - 1: - qrows = gr[i * 2 + 3] - gr[i * 2 + 1] - else: - qrows = num_qrows - gr[i * 2 + 1] - rows = qrows * 32 // bits - for j in range(rows): - group_map += [i] - group_map += [rows - j] - - return torch.tensor(group_map, dtype=torch.short, device=q_groups.device) - - -# Create Q matrix - - -def ext_make_q_matrix( - w: Exl2Weight | GPTQWeight, - extra: _ExtraTensors, - temp_dq, - key: Optional[str] = None, -): - """ - Create Q matrix - """ - # max_dq_size = 512*(1024**2) - # max_dq_rows = max_dq_size // out_features[0] - max_dq_rows = 0 - - # EXL2 - if isinstance(w, Exl2Weight): - extra.q_group_map = make_group_map(w.q_groups, w.q_weight.shape[0]) - extra.q_perm = torch.argsort(w.q_invperm).short() - - return make_q_matrix( - w.q_weight, - extra.q_perm, - w.q_invperm, - w.q_scale, - w.q_scale_max, - w.q_groups, - extra.q_group_map, - none_tensor, # zeros - none_tensor, # scales - none_tensor, # g_idx - none_tensor, # bias - temp_dq, - max_dq_rows, - ) - # GPTQ - elif isinstance(w, GPTQWeight): - if w.scales.dtype == torch.float: - w.scales = w.scales.half() - - # GPTQ with g_idx (act_order) - if w.g_idx is not None and not (w.g_idx == 0).all().item(): - extra.q_perm = torch.empty( - (w.qweight.shape[0] * 8,), - dtype=torch.short, - device=w.qweight.device, - ) - extra.q_invperm = torch.empty_like(extra.q_perm) - # make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx. - return make_q_matrix( - w.qweight, - extra.q_perm, - extra.q_invperm, - none_tensor, # q_scale - none_tensor, # q_scale_max - none_tensor, # q_groups - none_tensor, # q_group_map - w.qzeros, - w.scales, - w.g_idx.cpu(), - none_tensor, # bias - temp_dq, - max_dq_rows, - ) - # GPTQ without g_idx - else: - return make_q_matrix( - w.qweight, - none_tensor, # q_perm - none_tensor, # q_invperm - none_tensor, # q_scale - none_tensor, # q_scale_max - none_tensor, # q_groups - none_tensor, # q_group_map - w.qzeros, - w.scales, - none_tensor, # g_idx - none_tensor, # bias - temp_dq, - max_dq_rows, - ) - else: - RuntimeError("Cannot create handle") - - -DEVICE = None -LAYERS = [] - - -def set_device(device): - global DEVICE - DEVICE = device - - -def create_exllama_buffers(max_total_tokens: int): - global LAYERS, DEVICE - - # No need to initialize scratch space if there are no layers - # that use ExLLamav2. - if len(LAYERS) == 0: - return - - # Find the size of the scratch space. 
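The per-layer accounting behind scratch_space_fixed mirrors the two methods defined on QuantLinear below; in bytes, with each term padded to 128:

    def scratch_space_fixed(in_features, out_features, max_input_len, max_batch_size):
        # fp16 dequantized weight scratch (2 bytes per element) ...
        temp_dq_size = in_features * out_features * 2 + 128
        # ... plus forward-pass output scratch (4 bytes per output element).
        temp_fwd_size = out_features * max_input_len * max_batch_size * 4 + 128
        return temp_dq_size + temp_fwd_size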
- scratch_bytes = max( - layer.scratch_space_fixed(max_input_len=max_total_tokens, max_batch_size=1) - for layer in LAYERS - ) - temp_dq = ExLlamaV2DeviceTensors(DEVICE, scratch_bytes) - - for layer in LAYERS: - layer.post_init(temp_dq) - - -class QuantLinear(nn.Module): - QUANT_TYPE = "exllamav2" - - """Linear layer implementation with per-group 4-bit quantization of the weights""" - - def __init__( - self, - weight: Exl2Weight | GPTQWeight, - bias: torch.Tensor, - ): - super().__init__() - - self.q_handle = None - self.q_tensors = weight - self.extra_tensors = _ExtraTensors() - - if isinstance(weight, Exl2Weight): - self.infeatures = weight.q_invperm.shape[0] - self.outfeatures = weight.q_weight.shape[1] - elif isinstance(weight, GPTQWeight): - if weight.bits != 4: - raise ValueError( - f"Exllamav2 kernel supports only bits=4, requested bits={weight.bits}. Something is wrong in the model initialization." - ) - - self.infeatures = weight.qweight.shape[0] // weight.bits * 32 - self.outfeatures = weight.qweight.shape[1] - - self.padding = -self.outfeatures % 32 - self.outfeatures = self.outfeatures + self.padding - - self.device = weight.device - self.bias = bias if bias is not None else None - - global LAYERS - LAYERS.append(self) - - def post_init(self, temp_dq): - device = self.q_tensors.device - assert device.type == "cuda" - assert device.index is not None - temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size()) - - # We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us, - # and `Memory access fault by GPU node-2` will EAT you. - self.temp_dq = temp_dq - self.q_handle = ext_make_q_matrix(self.q_tensors, self.extra_tensors, temp_dq) - - def forward(self, x, force_cuda=False): - output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda) - - if self.bias is not None: - output.add_(self.bias) - return output - - def temp_dq_size(self): - return self.infeatures * self.outfeatures * 2 + 128 - - def temp_fwd_size(self, max_input_len, max_batch_size): - return self.outfeatures * max_input_len * max_batch_size * 4 + 128 - - def scratch_space_fixed(self, max_input_len, max_batch_size): - return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size) - - -class ExLlamaV2DeviceTensors: - - device_idx: int - scratch_bytes: int - scratch_idx: int - scratch: torch.tensor = None - - def __init__(self, device, scratch_bytes): - self.device = device - self.scratch_bytes = scratch_bytes - - def prepare(self): - self.scratch = torch.empty( - (self.scratch_bytes // 2,), dtype=torch.half, device=self.device - ) - - def get_scratch_slice(self, size_bytes): - - if self.scratch is None: - self.prepare() - - size_bytes = ((size_bytes + 127) // 128) * 128 - size_half = size_bytes // 2 - scratch_slice = self.scratch.narrow(0, 0, size_half) - return scratch_slice diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py b/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py new file mode 100644 index 000000000..72944fa0e --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py @@ -0,0 +1,186 @@ +import math +import numpy as np +import torch +import torch.nn as nn + +try: + + convert_from_uint4 = torch.ops.hpu.convert_from_uint4 +except Exception as e: + hpu_import_exception = e + + def error_raiser_hpu(*args, **kwargs): + raise ValueError( + f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}" + ) + + convert_from_uint4 = 
error_raiser_hpu + + +def pack_tensor(input, bits=4): + normal = input.to(torch.int32) + q = torch.zeros((normal.shape[0], normal.shape[1] // 32 * bits), dtype=torch.int32) + i = 0 + col = 0 + while col < q.shape[1]: + for j in range(i, i + (32 // bits)): + q[:, col] |= normal[:, j] << (bits * (j - i)) + i += 32 // bits + col += 1 + q = q.to(torch.int32) + return q + + +class QuantLinear(nn.Module): + def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): + super().__init__() + self.register_buffer("qweight", qweight) + self.register_buffer("qzeros", qzeros) + self.register_buffer("scales", scales) + self.register_buffer("g_idx", g_idx) + if bias is not None: + self.register_buffer("bias", bias) + else: + self.bias = None + if bits not in [4]: + raise NotImplementedError("Only 4 bits are supported.") + self.bits = bits + self.maxq = 2**self.bits - 1 + self.groupsize = groupsize + + self.outfeatures = qweight.shape[1] + self.infeatures = qweight.shape[0] * 32 // bits + self.wf = torch.tensor( + list(range(0, 32, self.bits)), dtype=torch.int32 + ).unsqueeze(0) + self._preprocessing() + + def unpack_zeros_from_cuda_old_format(self): + zeros = torch.bitwise_right_shift( + torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), + self.wf.unsqueeze(0), + ).to(torch.int16 if self.bits == 8 else torch.int8) + + zeros = zeros + 1 + zeros = torch.bitwise_and(zeros, (2**self.bits) - 1).to( + self.scales.dtype + ) # NOTE: It appears that casting here after the `zeros = zeros + 1` is important. + zeros = zeros.reshape(-1, zeros.shape[1] * zeros.shape[2]) + return zeros + + def unpack_weight_from_cuda_old_format(self): + weight = torch.bitwise_right_shift( + torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1), + self.wf.unsqueeze(-1), + ).to(torch.int16 if self.bits == 8 else torch.int8) + weight = torch.bitwise_and(weight, (2**self.bits) - 1) + weight = weight.reshape((weight.shape[0] * weight.shape[1], weight.shape[2])) + return weight + + def _preprocessing(self): + orig_device = self.qweight.device + self.qweight = self.qweight.cpu() + weight = self.unpack_weight_from_cuda_old_format() + new_qweight = pack_tensor(weight) + self.qweight = new_qweight.to(orig_device) + # TODO: Support group indexing and remove the check + columns = self.qweight.shape[0] + g_idx_trivial = [i // self.groupsize for i in range(columns)] + g_idx_trivial = torch.tensor( + g_idx_trivial, dtype=torch.int32, device=self.g_idx.device + ) + assert torch.equal( + self.g_idx, g_idx_trivial + ), "Non-trivial tensor g_idx is not supported" + self.qzeros = self.qzeros.cpu() + zeros = self.unpack_zeros_from_cuda_old_format() + new_qzeros = pack_tensor(zeros) + self.qzeros = new_qzeros.to(orig_device) + + @classmethod + def new(cls, bits, groupsize, infeatures, outfeatures, bias): + if bits not in [4]: + raise NotImplementedError("Only 4 bits are supported.") + + qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32) + qzeros = torch.zeros( + (math.ceil(infeatures / groupsize), outfeatures // 32 * bits), + dtype=torch.int32, + ) + scales = torch.zeros( + (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16 + ) + g_idx = torch.tensor( + [i // groupsize for i in range(infeatures)], dtype=torch.int32 + ) + if bias: + bias = torch.zeros((outfeatures), dtype=torch.float16) + else: + bias = None + return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize) + + def pack(self, linear, scales, zeros, g_idx=None): + self.g_idx = g_idx.clone() if g_idx is 
not None else self.g_idx + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + self.scales = scales.clone().half() + if linear.bias is not None: + self.bias = linear.bias.clone().half() + + intweight = [] + for idx in range(self.infeatures): + intweight.append( + torch.round( + (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) + / self.scales[self.g_idx[idx]] + ).to(torch.int)[:, None] + ) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(np.uint32) + qweight = np.zeros( + (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32 + ) + i = 0 + row = 0 + while row < qweight.shape[0]: + if self.bits in [4]: + for j in range(i, i + (32 // self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += 32 // self.bits + row += 1 + else: + raise NotImplementedError("Only 4 bits are supported.") + + qweight = qweight.astype(np.int32) + self.qweight = torch.from_numpy(qweight) + + zeros -= 1 + zeros = zeros.numpy().astype(np.uint32) + qzeros = np.zeros( + (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32 + ) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [4]: + for j in range(i, i + (32 // self.bits)): + qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) + i += 32 // self.bits + col += 1 + else: + raise NotImplementedError("Only 4 bits are supported.") + + qzeros = qzeros.astype(np.int32) + self.qzeros = torch.from_numpy(qzeros) + + def forward(self, x): + out_shape = x.shape[:-1] + (self.outfeatures,) + x = x.reshape(-1, x.shape[-1]) + weight = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype) + out = torch.matmul(x, weight) + out = out.reshape(out_shape) + out = out + self.bias if self.bias is not None else out + return out diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/quant_linear.py b/backends/gaudi/server/text_generation_server/layers/gptq/quant_linear.py deleted file mode 100644 index 736c357b0..000000000 --- a/backends/gaudi/server/text_generation_server/layers/gptq/quant_linear.py +++ /dev/null @@ -1,359 +0,0 @@ -import math -import numpy as np -import torch -import torch.nn as nn -from torch.cuda.amp import custom_fwd - -import triton -import triton.language as tl -from . 
import custom_autotune - - -# code based https://github.com/fpgaminer/GPTQ-triton -@custom_autotune.autotune( - configs=[ - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=2, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - }, - num_stages=3, - num_warps=8, - ), - triton.Config( - { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 8, - }, - num_stages=2, - num_warps=4, - ), - ], - key=["M", "N", "K"], - nearest_power_of_two=True, - prune_configs_by={ - "early_config_prune": custom_autotune.matmul248_kernel_config_pruner, - "perf_model": None, - "top_k": None, - }, -) -@triton.jit -def matmul_248_kernel( - a_ptr, - b_ptr, - c_ptr, - scales_ptr, - zeros_ptr, - g_ptr, - M, - N, - K, - bits, - maxq, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_scales, - stride_zeros, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - """ - Compute the matrix multiplication C = A x B. 
- A is of shape (M, K) float16 - B is of shape (K//8, N) int32 - C is of shape (M, N) float16 - scales is of shape (G, N) float16 - zeros is of shape (G, N) float16 - g_ptr is of shape (K) int32 - """ - infearure_per_bits = 32 // bits - - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + ( - offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak - ) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - a_mask = offs_am[:, None] < M - # b_ptrs is set up such that it repeats elements along the K axis 8 times - b_ptrs = b_ptr + ( - (offs_k[:, None] // infearure_per_bits) * stride_bk - + offs_bn[None, :] * stride_bn - ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) - g_ptrs = g_ptr + offs_k - # shifter is used to extract the N bits of each element in the 32-bit word from B - scales_ptrs = scales_ptr + offs_bn[None, :] - zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) - - shifter = (offs_k % infearure_per_bits) * bits - zeros_shifter = (offs_bn % infearure_per_bits) * bits - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - for k in range(0, num_pid_k): - g_idx = tl.load(g_ptrs) - - # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop - scales = tl.load( - scales_ptrs + g_idx[:, None] * stride_scales - ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = tl.load( - zeros_ptrs + g_idx[:, None] * stride_zeros - ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - - zeros = (zeros >> zeros_shifter[None, :]) & maxq - zeros = (zeros + 1) & maxq # eventually avoid overflow - - a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated - - # Now we need to unpack b (which is N-bit values) into 32-bit values - b = (b >> shifter[:, None]) & maxq # Extract the N-bit values - b = (b - zeros) * scales # Scale and shift - - accumulator += tl.dot(a, b) - a_ptrs += BLOCK_SIZE_K - b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk - g_ptrs += BLOCK_SIZE_K - - c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] - c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) - - -def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): - with torch.cuda.device(input.device): - output = torch.empty( - (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16 - ) - - def grid(META): - return ( - triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) - * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]), - ) - - matmul_248_kernel[grid]( - input, - qweight, - output, - scales, - qzeros, - g_idx, - input.shape[0], - qweight.shape[1], - input.shape[1], - bits, - maxq, - input.stride(0), - input.stride(1), - qweight.stride(0), - qweight.stride(1), - output.stride(0), - output.stride(1), - scales.stride(0), - qzeros.stride(0), - ) - return output - - -class QuantLinearFunction(torch.autograd.Function): - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward(ctx, 
input, qweight, scales, qzeros, g_idx, bits, maxq): - output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq) - return output - - -class QuantLinear(nn.Module): - def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): - super().__init__() - self.register_buffer("qweight", qweight) - self.register_buffer("qzeros", qzeros) - self.register_buffer("scales", scales) - self.register_buffer("g_idx", g_idx) - if bias is not None: - self.register_buffer("bias", bias) - else: - self.bias = None - if bits not in [2, 4, 8]: - raise NotImplementedError("Only 2,4,8 bits are supported.") - self.bits = bits - self.maxq = 2**self.bits - 1 - self.groupsize = groupsize - - self.outfeatures = qweight.shape[1] - self.infeatures = qweight.shape[0] * 32 // bits - - @classmethod - def new(cls, bits, groupsize, infeatures, outfeatures, bias): - if bits not in [2, 4, 8]: - raise NotImplementedError("Only 2,4,8 bits are supported.") - - qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32) - qzeros = torch.zeros( - (math.ceil(infeatures / groupsize), outfeatures // 32 * bits), - dtype=torch.int32, - ) - scales = torch.zeros( - (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16 - ) - g_idx = torch.tensor( - [i // groupsize for i in range(infeatures)], dtype=torch.int32 - ) - if bias: - bias = torch.zeros((outfeatures), dtype=torch.float16) - else: - bias = None - return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize) - - def pack(self, linear, scales, zeros, g_idx=None): - self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx - - scales = scales.t().contiguous() - zeros = zeros.t().contiguous() - scale_zeros = zeros * scales - self.scales = scales.clone().half() - if linear.bias is not None: - self.bias = linear.bias.clone().half() - - intweight = [] - for idx in range(self.infeatures): - intweight.append( - torch.round( - (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) - / self.scales[self.g_idx[idx]] - ).to(torch.int)[:, None] - ) - intweight = torch.cat(intweight, dim=1) - intweight = intweight.t().contiguous() - intweight = intweight.numpy().astype(np.uint32) - qweight = np.zeros( - (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32 - ) - i = 0 - row = 0 - while row < qweight.shape[0]: - if self.bits in [2, 4, 8]: - for j in range(i, i + (32 // self.bits)): - qweight[row] |= intweight[j] << (self.bits * (j - i)) - i += 32 // self.bits - row += 1 - else: - raise NotImplementedError("Only 2,4,8 bits are supported.") - - qweight = qweight.astype(np.int32) - self.qweight = torch.from_numpy(qweight) - - zeros -= 1 - zeros = zeros.numpy().astype(np.uint32) - qzeros = np.zeros( - (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32 - ) - i = 0 - col = 0 - while col < qzeros.shape[1]: - if self.bits in [2, 4, 8]: - for j in range(i, i + (32 // self.bits)): - qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) - i += 32 // self.bits - col += 1 - else: - raise NotImplementedError("Only 2,4,8 bits are supported.") - - qzeros = qzeros.astype(np.int32) - self.qzeros = torch.from_numpy(qzeros) - - def forward(self, x): - out_shape = x.shape[:-1] + (self.outfeatures,) - out = QuantLinearFunction.apply( - x.reshape(-1, x.shape[-1]), - self.qweight, - self.scales, - self.qzeros, - self.g_idx, - self.bits, - self.maxq, - ) - out = out + self.bias if self.bias is not None else out - return out.reshape(out_shape) diff --git 
a/backends/gaudi/server/text_generation_server/layers/gptq/quantize.py b/backends/gaudi/server/text_generation_server/layers/gptq/quantize.py index b0086ea08..aa664ea60 100644 --- a/backends/gaudi/server/text_generation_server/layers/gptq/quantize.py +++ b/backends/gaudi/server/text_generation_server/layers/gptq/quantize.py @@ -12,7 +12,7 @@ from huggingface_hub import HfApi from accelerate import init_empty_weights from text_generation_server.utils import initialize_torch_distributed, Weights from text_generation_server.utils.hub import weight_files -from text_generation_server.layers.gptq.quant_linear import QuantLinear +from text_generation_server.layers.gptq import QuantLinear from loguru import logger from typing import Optional from text_generation_server.layers.gptq.utils import torch_snr_error @@ -956,15 +956,24 @@ def quantize( pack(model, quantizers, bits, groupsize) from safetensors.torch import save_file - from transformers.modeling_utils import shard_checkpoint + from huggingface_hub import split_torch_state_dict_into_shards state_dict = model.state_dict() state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()} max_shard_size = "10GB" - shards, index = shard_checkpoint( - state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors" + state_dict_split = split_torch_state_dict_into_shards( + state_dict, + filename_pattern="model.safetensors", + max_shard_size=max_shard_size, ) + index = None + if state_dict_split.is_sharded: + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + shards = state_dict_split.filename_to_tensors os.makedirs(output_dir, exist_ok=True) for shard_file, shard in shards.items(): save_file( diff --git a/backends/gaudi/server/text_generation_server/layers/layernorm.py b/backends/gaudi/server/text_generation_server/layers/layernorm.py index ce5289f93..848787910 100644 --- a/backends/gaudi/server/text_generation_server/layers/layernorm.py +++ b/backends/gaudi/server/text_generation_server/layers/layernorm.py @@ -1,9 +1,6 @@ import torch from torch import nn from accelerate import init_empty_weights -from text_generation_server.utils.import_utils import ( - SYSTEM, -) # Monkey patching @@ -33,69 +30,14 @@ def load_layer_norm_no_bias(cls, prefix, weights, eps): torch.nn.LayerNorm.load = load_layer_norm torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias -if SYSTEM == "cuda": - import dropout_layer_norm - class FastLayerNorm(nn.LayerNorm): - def forward(self, hidden_states, residual=None): - if hidden_states.shape[-1] > 8192: - if residual is not None: - hidden_states += residual - residual = hidden_states +class FastLayerNorm(nn.LayerNorm): + def forward(self, hidden_states, residual=None): + if residual is not None: + hidden_states += residual + residual = hidden_states - return super(FastLayerNorm, self).forward(hidden_states), residual - else: - ( - normed_hidden_states, - residual, - *rest, - ) = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - residual, - self.weight, - self.bias, - None, - None, - None, - None, - 0.0, - self.eps, - 1.0, - 0, - None, - False, - False, - ) - if residual is None: - residual = hidden_states - - return normed_hidden_states, residual - -elif SYSTEM == "rocm": - from vllm._C import ops - - class FastLayerNorm(nn.LayerNorm): - def forward(self, hidden_states, residual=None): - if residual is not None: - hidden_states += residual - residual = hidden_states - - return super().forward(hidden_states), residual - -elif SYSTEM == 
"ipex": - import intel_extension_for_pytorch as ipex - - class FastLayerNorm(nn.LayerNorm): - def forward(self, hidden_states, residual=None): - out = ipex.llm.functional.add_layer_norm( - residual, - hidden_states, - self.weight, - self.bias, - self.eps, - residual is not None, - ) - return out, residual if residual is not None else hidden_states + return super().forward(hidden_states), residual class FastRMSNorm(nn.Module): @@ -111,74 +53,15 @@ class FastRMSNorm(nn.Module): return cls(weight, eps) def forward(self, hidden_states, residual=None): - if SYSTEM == "ipex": - out = ipex.llm.functional.add_rms_norm( - residual, - hidden_states, - self.weight, - None, - self.variance_epsilon, - residual is not None, - ) - return out, residual if residual is not None else hidden_states - elif hidden_states.shape[-1] > 8192: - if residual is not None: - hidden_states += residual - residual = hidden_states + from vllm_hpu_extension.kernels import rms_norm - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt( - variance + self.variance_epsilon - ) - - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - - return self.weight * hidden_states, residual - elif SYSTEM == "cuda": - # faster post attention rms norm - ( - normed_hidden_states, - res, - *rest, - ) = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - residual, - self.weight, - None, - None, - None, - None, - None, - 0.0, - self.variance_epsilon, - 1.0, - 0, - None, - False, - True, # Activate RMSNorm - ) - if res is None: - res = hidden_states - - return normed_hidden_states, res - elif SYSTEM == "rocm": - # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not. - if residual is not None: - hidden_states += residual - residual = hidden_states - - out = torch.empty_like(hidden_states) - ops.rms_norm( - out, - hidden_states, - self.weight.data, - self.variance_epsilon, - ) - return out, residual + orig_shape = hidden_states.shape + if residual is not None: + residual += hidden_states.view(residual.shape) else: - raise ValueError( - "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." - ) + residual = hidden_states + # Note: HPUFusedRMSNorm requires 3D tensors as inputs + if len(orig_shape) == 2: + residual = residual.unsqueeze(0) + x = rms_norm().apply(residual, self.weight, self.variance_epsilon) + return x.view(orig_shape), residual.view(orig_shape) diff --git a/backends/gaudi/server/text_generation_server/layers/linear.py b/backends/gaudi/server/text_generation_server/layers/linear.py index 08306d579..cca80c44e 100644 --- a/backends/gaudi/server/text_generation_server/layers/linear.py +++ b/backends/gaudi/server/text_generation_server/layers/linear.py @@ -1,21 +1,5 @@ import torch -from text_generation_server.utils.import_utils import SYSTEM from torch.nn import functional as F -import os - -if SYSTEM == "rocm": - ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in ( - "true", - "1", - ) - - if ROCM_USE_SKINNY_GEMM: - try: - from vllm import _custom_C - except Exception as e: - raise ImportError( - f"Could not load `vllm._custom_C` for ROCm skinny gemm. 
Full error: {e}" - ) class FastLinear(torch.nn.Module): @@ -44,83 +28,11 @@ class FastLinear(torch.nn.Module): return F.linear(input, self.weight, self.bias) -class FastLinearROCm(torch.nn.Module): - def __init__( - self, - weight, - bias, - ) -> None: - super().__init__() - self.weight = torch.nn.Parameter(weight) - if bias is not None: - self.bias = torch.nn.Parameter(bias) - else: - self.bias = None - - self.cu_count = torch.cuda.get_device_properties( - device="cuda" - ).multi_processor_count - self.use_skinny_gemm = ( - ROCM_USE_SKINNY_GEMM - and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName - ) - - @classmethod - def load(cls, config, prefix: str, weights, bias: bool): - weight = weights.get_tensor(f"{prefix}.weight") - if bias: - bias = weights.get_tensor(f"{prefix}.bias") - else: - bias = None - return cls(weight, bias) - - def forward(self, inp: torch.Tensor) -> torch.Tensor: - weight = self.weight - bias = self.bias - - if ( - self.use_skinny_gemm - and inp.dtype == torch.float16 - and inp.shape[-1] % 8 == 0 - ): - batched = False - inp_shape = inp.shape - - if inp.dim() == 3: - inp = inp.view(-1, inp_shape[-1]) - batched = True - - m, n, k = weight.shape[0], inp_shape[0], inp_shape[1] - if m > 8 and n <= 4: - out = torch.empty( - inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device - ) - _custom_C.wvSpltK(weight, inp, out, n, self.cu_count) - elif m % 4 == 0 and n == 1 and k <= 8192: - out = torch.empty( - inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device - ) - _custom_C.LLMM1(weight, inp, out, 4) - else: - out = F.linear(inp, weight) - - if batched: - out.view(*inp_shape[:-1], out.shape[-1]) - - if bias is not None: - out = out + bias - return out - return F.linear(inp, self.weight, self.bias) - - def get_linear(weight, bias): # Weights that are loaded through methods that are not # quantization-aware are still bare tensors. We may want # to change this in the future. 
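With the ROCm branch removed, the dispatch below has exactly two cases. A minimal usage sketch (gptq_weight is a stand-in name for any quantization-aware weight object that exposes get_linear):

    import torch

    # Bare tensors become FastLinear, a thin F.linear wrapper.
    w = torch.randn(16, 32, dtype=torch.bfloat16)
    layer = get_linear(w, bias=None)
    y = layer(torch.randn(4, 32, dtype=torch.bfloat16))  # shape: (4, 16)

    # Quantization-aware weights build their own layer:
    # layer = get_linear(gptq_weight, bias)  # -> gptq_weight.get_linear(bias)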
if isinstance(weight, torch.Tensor): - if SYSTEM == "rocm": - return FastLinearROCm(weight, bias) - else: - return FastLinear(weight, bias) + return FastLinear(weight, bias) return weight.get_linear(bias) diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/__init__.py b/backends/gaudi/server/text_generation_server/layers/marlin/__init__.py deleted file mode 100644 index 3ff3ed58f..000000000 --- a/backends/gaudi/server/text_generation_server/layers/marlin/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear -from text_generation_server.layers.marlin.gptq import ( - GPTQMarlinWeightsLoader, - can_use_gptq_marlin, - repack_gptq_for_marlin, -) -from text_generation_server.layers.marlin.marlin import MarlinWeightsLoader - -__all__ = [ - "GPTQMarlinFP8Linear", - "GPTQMarlinWeightsLoader", - "MarlinWeightsLoader", - "can_use_gptq_marlin", - "repack_gptq_for_marlin", -] diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/fp8.py b/backends/gaudi/server/text_generation_server/layers/marlin/fp8.py deleted file mode 100644 index fe55a58a3..000000000 --- a/backends/gaudi/server/text_generation_server/layers/marlin/fp8.py +++ /dev/null @@ -1,140 +0,0 @@ -from typing import Optional - -import torch -import torch.nn as nn -from loguru import logger -from text_generation_server.layers.fp8 import fp8_quantize -from text_generation_server.layers.marlin.gptq import _check_valid_shape -from text_generation_server.layers.marlin.util import ( - _check_marlin_kernels, - permute_scales, -) -from text_generation_server.utils.log import log_once - -try: - import marlin_kernels -except ImportError: - marlin_kernels = None - - -MARLIN_TILE_SIZE = 16 - - -class GPTQMarlinFP8Linear(nn.Module): - """ - FP8 GPTQ-Marlin linear layer. 
- """ - - def __init__( - self, - qweight: torch.Tensor, - scales: torch.Tensor, - bias: Optional[torch.Tensor], - ) -> None: - super().__init__() - - _check_marlin_kernels() - assert marlin_kernels is not None - - log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel") - - scales = scales.unsqueeze(0) - if scales.shape[1] == 1: - out_features, in_features = qweight.shape - scales = scales.repeat(1, out_features) - qweight, scales = repack_fp8_for_marlin(qweight, scales) - - in_features = qweight.shape[0] * MARLIN_TILE_SIZE - out_features = scales.shape[1] - _check_valid_shape(in_features=in_features, out_features=out_features) - - self.qweight = qweight - self.scales = scales - self.bias = bias if bias is not None else None - - self.workspace = torch.zeros( - out_features // 64 * 16, dtype=torch.int, device=qweight.device - ) - - @classmethod - def from_unquant(cls, weight, bias, dtype): - qweight, scales = fp8_quantize(weight) - return cls(qweight=qweight, scales=scales.to(dtype), bias=bias) - - @classmethod - def from_fp8(cls, weight, scale, _input_scale, bias, dtype): - return cls(qweight=weight, scales=scale.to(dtype), bias=bias) - - def forward(self, A: torch.Tensor) -> torch.Tensor: - assert marlin_kernels is not None - - A_flat = A.view(-1, A.shape[-1]) - C = marlin_kernels.fp8_marlin_gemm( - A_flat, - self.qweight, - self.scales, - self.workspace, - 8, - A_flat.shape[0], - self.scales.shape[1], - A_flat.shape[1], - ) - C = C.reshape(A.shape[:-1] + (self.scales.shape[1],)) - - if self.bias is not None: - C += self.bias - - return C - - -def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: - """ - Repack FP8 weights to gptq format (packed int32 elements). - """ - assert fp8_tensor.dtype == torch.float8_e4m3fn - - if fp8_tensor.shape[0] % 4 != 0: - raise ValueError( - f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}" - ) - - # Reshape to prepare for packing - reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) - - # Convert fp8 to uint8 (byte) representation - byte_tensor = reshaped.view(torch.uint8) - - # Pack 4 uint8 values into one int32 - packed = torch.zeros( - fp8_tensor.shape[0] // 4, - fp8_tensor.shape[1], - dtype=torch.int32, - device=fp8_tensor.device, - ) - - for i in range(4): - packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8) - - return packed - - -def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor): - """ - Repack FP8 tensor for GPTQ-Marlin. - """ - - out_features, in_features = weight.shape - - # Torch linear layers weights with shape [out_features, in_features], - # GPTQ-quantized weights use [in_feateres/pack_factor, in_features], - # so transpose before packing. 
- qweight = pack_fp8_as_int32(weight.t()) - - perm = torch.empty(0, dtype=torch.int, device=qweight.device) - repacked = marlin_kernels.gptq_marlin_repack( - qweight, perm, in_features, out_features, 8 - ) - - scales = permute_scales(scales) - - return repacked, scales diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/gptq.py b/backends/gaudi/server/text_generation_server/layers/marlin/gptq.py deleted file mode 100644 index 0a785d944..000000000 --- a/backends/gaudi/server/text_generation_server/layers/marlin/gptq.py +++ /dev/null @@ -1,464 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional, Union - -import numpy -import torch -import torch.nn as nn -from loguru import logger -from text_generation_server.layers.marlin.util import ( - _check_marlin_kernels, - marlin_zero_points, - permute_scales, - unpack_cols, -) -from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.utils.log import log_once -from text_generation_server.utils.weights import Weight, Weights, WeightsLoader - -try: - import marlin_kernels -except ImportError: - marlin_kernels = None - -try: - major, _minor = torch.cuda.get_device_capability() - has_sm_8_0 = major >= 8 -except Exception: - has_sm_8_0 = False - - -GPTQ_MARLIN_BITS = [4, 8] -GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128] -MARLIN_TILE_SIZE = 16 - - -def can_use_gptq_marlin( - *, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool -) -> bool: - return ( - SYSTEM == "cuda" - and marlin_kernels is not None - and has_sm_8_0 - and quantize in {"awq", "gptq"} - and quant_method in {"awq", "gptq"} - and bits in GPTQ_MARLIN_BITS - and groupsize in GPTQ_MARLIN_GROUP_SIZES - # We only suppord asymmetric quantization for AWQ. - and (sym or quant_method == "awq") - ) - - -class GPTQMarlinWeightsLoader(WeightsLoader): - """ - Loader for using GPTQ- and AWQ-quantized weights with Marlin kernels. - """ - - def __init__( - self, - *, - bits: int, - desc_act: bool, - groupsize: int, - quant_method: str, - quantize: str, - sym: bool, - ): - self.bits = bits - self.desc_act = desc_act - self.groupsize = groupsize - self.quant_method = quant_method - self.quantize = quantize - self.sym = sym - - def get_weights(self, weights: Weights, prefix: str): - log_once(logger.info, "Using GPTQ-Marlin kernels") - try: - qweight = weights.get_tensor(f"{prefix}.qweight") - except RuntimeError: - raise RuntimeError( - f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" - ) - - if not self.sym: - qzeros = weights.get_tensor(f"{prefix}.qzeros") - else: - qzeros = None - - if self.quant_method == "awq": - g_idx = None - else: - g_idx = weights.get_tensor(f"{prefix}.g_idx") - scales = weights.get_tensor(f"{prefix}.scales") - - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - qzeros=qzeros, - g_idx=g_idx, - bits=self.bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method=self.quant_method, - sym=self.sym, - sharded_infeatures=False, - ) - - def get_weights_col_packed( - self, - weights: Weights, - prefix: str, - block_sizes: Union[int, List[int]], - ): - try: - qweight = weights.get_packed_sharded( - f"{prefix}.qweight", dim=1, block_sizes=block_sizes - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{self.quantize}` weight, make sure the model is already quantized." 
- ) - scales = weights.get_packed_sharded( - f"{prefix}.scales", dim=1, block_sizes=block_sizes - ) - scales = scales.to(dtype=weights.dtype) - - if not self.sym: - qzeros = weights.get_packed_sharded( - f"{prefix}.qzeros", dim=1, block_sizes=block_sizes - ) - else: - qzeros = None - - if self.quant_method == "awq": - g_idx = None - else: - g_idx = weights.get_tensor(f"{prefix}.g_idx") - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - qzeros=qzeros, - g_idx=g_idx, - bits=self.bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method=self.quant_method, - sym=self.sym, - sharded_infeatures=False, - ) - - def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): - try: - qweight = torch.cat( - [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{self.quantize}` weight, make sure the model is already quantized" - ) - - scales = torch.cat( - [weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 - ) - - if not self.sym: - qzeros = torch.cat( - [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 - ) - else: - qzeros = None - - if self.quant_method == "awq": - g_idx = None - else: - w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] - for w2 in w[1:]: - torch.testing.assert_close(w2, w[0]) - g_idx = w[0] - - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - qzeros=qzeros, - g_idx=g_idx, - bits=self.bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method=self.quant_method, - sym=self.sym, - sharded_infeatures=False, - ) - - def get_weights_row(self, weights: Weights, prefix: str): - log_once(logger.info, "Using GPTQ-Marlin kernels") - try: - qweight = weights.get_sharded(f"{prefix}.qweight", dim=0) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" - ) - - if not self.sym: - if self.desc_act or self.groupsize == -1: - qzeros = weights.get_tensor(f"{prefix}.qzeros") - else: - qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0) - else: - qzeros = None - - if self.quant_method == "awq": - g_idx = None - else: - g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0) - - if self.desc_act or self.groupsize == -1: - scales = weights.get_tensor(f"{prefix}.scales") - else: - scales = weights.get_sharded(f"{prefix}.scales", dim=0) - - sharded_in_features = weights.process_group.size() > 1 - - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - qzeros=qzeros, - g_idx=g_idx, - bits=self.bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - quant_method=self.quant_method, - sym=self.sym, - sharded_infeatures=sharded_in_features, - ) - - def _get_gptq_params(self, weights: Weights): - if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"): - self.bits = weights.get_tensor("gptq_bits").item() - self.groupsize = weights.get_tensor("gptq_groupsize").item() - self.desc_act = False - # `server quantize` used asymmetric quantization unconditionally - # before the `gptq_sym` setting tensor was added. - self.sym = ( - weights.get_tensor("gptq_sym").item() - if weights._has_tensor("gptq_sym") - else False - ) - self.quant_method = "gptq" - - -@dataclass -class GPTQMarlinWeight(Weight): - """ - Repacked GPTQ Marlin weights. 
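    Produced by repack_gptq_for_marlin below: qweight holds the Marlin-repacked
    int32 weights, perm holds the activation permutation for act-order
    checkpoints (empty otherwise), and is_full_k records whether this shard
    covers the full input dimension.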
- """ - - qweight: torch.Tensor - qzeros: torch.Tensor - scales: torch.Tensor - g_idx: torch.Tensor - perm: torch.Tensor - bits: int - is_full_k: bool - - def __post_init__(self): - assert self.qweight.dtype == torch.int32 - assert self.scales.dtype == torch.float16 - assert self.g_idx.dtype == torch.int32 - assert self.perm.dtype == torch.int32 - - def get_linear(self, bias: torch.Tensor): - return GPTQMarlinLinear( - weight=self, - bias=bias, - ) - - -def repack_gptq_for_marlin( - *, - qweight: torch.Tensor, - qzeros: Optional[torch.Tensor], - scales: torch.Tensor, - g_idx: Optional[torch.Tensor], - bits: int, - desc_act: bool, - groupsize: int, - quant_method: str, - sym: bool, - sharded_infeatures: bool, -) -> GPTQMarlinWeight: - """Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels.""" - _check_marlin_kernels() - assert marlin_kernels is not None - - if bits not in GPTQ_MARLIN_BITS: - supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS) - raise RuntimeError( - f"Repacking {bits}-bit GPTQ weights as Marlin is not supported, must be one of: {supported_bits}" - ) - - if groupsize not in GPTQ_MARLIN_GROUP_SIZES: - supported_sizes = ", ".join(str(b) for b in GPTQ_MARLIN_GROUP_SIZES) - raise RuntimeError( - f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}" - ) - if not (sym or quant_method == "awq"): - raise RuntimeError( - "Repacking GPTQ weights with asymmetric quantization as Marlin is not supported." - ) - - log_once(logger.info, f"Converting {quant_method} model to Marlin packing format.") - - weights_per_int = 32 // bits - in_features = qweight.shape[0] - out_features = qweight.shape[1] - - # AWQ uses column packing, GPTQ uses row packing - if quant_method == "awq": - out_features *= weights_per_int - else: - in_features *= weights_per_int - - if in_features % groupsize != 0: - raise ValueError( - f"Number of input features ({in_features}) not divisible by group size ({groupsize})" - ) - - if g_idx is not None and desc_act and groupsize != -1: - perm = torch.argsort(g_idx).to(torch.int) - g_idx = g_idx[perm] - else: - perm = torch.empty(0, dtype=torch.int, device=qweight.device) - g_idx = torch.empty(0, dtype=torch.int, device=qweight.device) - - if quant_method == "awq": - repacked = marlin_kernels.awq_marlin_repack( - qweight, in_features, out_features, bits - ) - if qzeros is not None: - qzeros = awq_to_marlin_zero_points( - qzeros, - in_features // groupsize, - out_features, - bits, - ) - - else: - repacked = marlin_kernels.gptq_marlin_repack( - qweight, perm, in_features, out_features, bits - ) - - if qzeros is None: - qzeros = torch.empty(0, dtype=torch.int, device=qweight.device) - - scales = permute_scales(scales) - - is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures) - - return GPTQMarlinWeight( - qweight=repacked, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - perm=perm, - bits=bits, - is_full_k=is_full_k, - ) - - -class GPTQMarlinLinear(nn.Module): - """ - Linear layer for GPTQ weights that were converted for the GPTQ-Marlin - kernels. 
- """ - - def __init__( - self, - *, - weight: GPTQMarlinWeight, - bias: Optional[torch.Tensor], - ): - super().__init__() - - _check_marlin_kernels() - assert marlin_kernels is not None - - in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE - out_features = weight.scales.shape[1] - _check_valid_shape(in_features=in_features, out_features=out_features) - - self.bits = weight.bits - self.is_full_k = weight.is_full_k - - self.qweight = weight.qweight - self.qzeros = weight.qzeros - self.scales = weight.scales - self.g_idx = weight.g_idx - self.perm = weight.perm - if bias is not None: - self.bias = bias - else: - self.bias = None - - self.workspace = torch.zeros( - out_features // 64 * 16, dtype=torch.int, device=weight.qweight.device - ) - - def forward(self, A: torch.Tensor) -> torch.Tensor: - assert marlin_kernels is not None - - A_flat = A.view(-1, A.shape[-1]) - C = marlin_kernels.gptq_marlin_gemm( - A_flat, - self.qweight, - self.scales, - self.qzeros, - self.g_idx, - self.perm, - self.workspace, - self.bits, - A_flat.shape[0], - self.scales.shape[1], - A_flat.shape[1], - self.is_full_k, - self.qzeros.numel() > 0, - True, - ) - C = C.reshape(A.shape[:-1] + (self.scales.shape[1],)) - - if self.bias is not None: - C += self.bias - - return C - - -def awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - # AWQ zero-points are quantized and packed on the column dim. - # In addition, the values are permuted based on dequantizer. - # Here we undo both of these, and then apply marlin permutation - # and pack it back. - q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) - - # Undo interleaving (use argsort(..) to get inverse perm) - if num_bits == 4: - undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) - elif num_bits == 8: - undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() - q_zp = q_zp.reshape((-1, size_n)).contiguous() - - marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) - return marlin_zp - - -def _check_valid_shape(in_features: int, out_features: int): - if (in_features % 128 != 0 or out_features % 64 != 0) and ( - in_features % 64 != 0 or out_features % 128 != 0 - ): - raise ValueError( - f"The GPTQ Marlin kernel does not have a valid thread configuration for weight matrix with shape ({out_features}, {in_features})." - " The shape elements must be divisible by (128, 64) or (64, 128)." 
- ) diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/marlin.py b/backends/gaudi/server/text_generation_server/layers/marlin/marlin.py deleted file mode 100644 index 89ebaca62..000000000 --- a/backends/gaudi/server/text_generation_server/layers/marlin/marlin.py +++ /dev/null @@ -1,346 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional, Union - -import torch -import torch.nn as nn -from text_generation_server.layers.marlin.util import _check_marlin_kernels -from text_generation_server.utils.weights import Weight, Weights, WeightsLoader - -try: - import marlin_kernels -except ImportError: - marlin_kernels = None - - -class MarlinWeightsLoader(WeightsLoader): - """Loader for Marlin-quantized weights.""" - - def __init__(self, *, bits: int, is_marlin_24: bool): - self.bits = bits - self.is_marlin_24 = is_marlin_24 - - def get_weights(self, weights: "Weights", prefix: str): - """ - Get weights at the given prefix and apply without tensor paralllism. - """ - is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" - if is_marlin_24: - try: - B = weights.get_tensor(f"{prefix}.B_24") - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." - ) - - B_meta = weights.get_tensor(f"{prefix}.B_meta") - s = weights.get_tensor(f"{prefix}.s") - weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) - else: - try: - B = weights.get_tensor(f"{prefix}.B") - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` weight, make sure the model is already quantized." - ) - - s = weights.get_tensor(f"{prefix}.s") - weight = MarlinWeight(B=B, s=s) - - return weight - - def get_weights_col_packed( - self, - weights: Weights, - prefix: str, - block_sizes: Union[int, List[int]], - ): - if self.is_marlin_24: - B = weights.get_packed_sharded( - f"{prefix}.B_24", dim=1, block_sizes=block_sizes - ) - B_meta = weights.get_packed_sharded( - f"{prefix}.B_meta", dim=1, block_sizes=block_sizes - ) - s = weights.get_packed_sharded( - f"{prefix}.s", dim=1, block_sizes=block_sizes - ) - - weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) - else: - B = weights.get_packed_sharded( - f"{prefix}.B", dim=1, block_sizes=block_sizes - ) - s = weights.get_packed_sharded( - f"{prefix}.s", dim=1, block_sizes=block_sizes - ) - weight = MarlinWeight(B=B, s=s) - - return weight - - def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): - if self.is_marlin_24: - try: - B = torch.cat( - [weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` weight, make sure the model is already quantized" - ) - - B_meta = torch.cat( - [weights.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1 - ) - - s = torch.cat( - [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 - ) - - weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) - else: - try: - B = torch.cat( - [weights.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` weight, make sure the model is already quantized" - ) - s = torch.cat( - [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 - ) - - weight = MarlinWeight(B=B, s=s) - - return weight - - def get_weights_row(self, weights: Weights, prefix: str): - if self.is_marlin_24: - try: - B = weights.get_sharded(f"{prefix}.B_24", dim=0) - 
except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." - ) - - B_meta = weights.get_sharded(f"{prefix}.B_meta", dim=0) - num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0] - if num_groups == 1: - # The number of groups is 1 when groupsize == -1. share - # scales between all shards in this case. - s = weights.get_tensor(f"{prefix}.s") - else: - s = weights.get_sharded(f"{prefix}.s", dim=0) - - weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) - else: - try: - B = weights.get_sharded(f"{prefix}.B", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` weight, make sure the model is already quantized." - ) - - num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0] - if num_groups == 1: - # The number of groups is 1 when groupsize == -1. share - # scales between all shards in this case. - s = weights.get_tensor(f"{prefix}.s") - else: - s = weights.get_sharded(f"{prefix}.s", dim=0) - weight = MarlinWeight(B=B, s=s) - - return weight - - -@dataclass -class MarlinWeight(Weight): - """ - Marlin weights. - - Attributes: - B (torch.Tensor): int4-quantized weights packed into int32. - s (torch.Tensor): bfloat16/float16 scales. - """ - - B: torch.Tensor - s: torch.Tensor - - def __post_init__(self): - assert self.B.dtype == torch.int32 - assert self.s.dtype in [torch.float16, torch.bfloat16] - - def get_linear(self, bias: torch.Tensor): - return MarlinLinear(weight=self, bias=bias) - - -class MarlinLinear(nn.Module): - def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]): - super().__init__() - - _check_marlin_kernels() - assert marlin_kernels is not None - - in_features = weight.B.shape[0] * MARLIN_TILE_SIZE - out_features = weight.s.shape[1] - assert ( - in_features % 128 == 0 - ), f"Number of input features ({in_features}) not divisable by 128" - assert ( - out_features % 256 == 0 - ), f"Number of output features ({out_features}) not divisable by 256" - - groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0] - assert groupsize in { - -1, - 128, - }, f"Group size must be -1 or 128, was {groupsize}" - - self.B = weight.B - self.s = weight.s - if bias is not None: - self.bias = bias - else: - self.bias = None - - self.workspace = torch.zeros( - out_features // 64 * 16, dtype=torch.int, device=weight.B.device - ) - - def forward(self, A: torch.Tensor) -> torch.Tensor: - assert marlin_kernels is not None - - C = marlin_kernels.marlin_gemm( - A.view(-1, A.shape[-1]), - self.B, - self.s, - self.workspace, - A.shape[0], - self.s.shape[1], - A.shape[1], - ) - C = C.reshape(A.shape[:-1] + (self.s.shape[1],)) - - if self.bias is not None: - C += self.bias - - return C - - -GPTQ_MARLIN_24_MIN_THREAD_N = 128 -GPTQ_MARLIN_24_MIN_THREAD_K = 128 -GPTQ_MARLIN_24_MAX_PARALLEL = 64 -GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8] -GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] -MARLIN_TILE_SIZE = 16 - - -@dataclass -class GPTQMarlin24Weight: - """ - GPTQ-Marlin 2:4 weights. - - Attributes: - B (torch.Tensor): int4-quantized weights packed into int32. - B_meta (torch.Tensor): metadata for 2:4 sparsity. - s (torch.Tensor): float16 scales. - bits: quantized weight size. 
- """ - - B: torch.Tensor - B_meta: torch.Tensor - s: torch.Tensor - bits: int - - def __post_init__(self): - assert self.B.dtype == torch.int32 - assert self.B_meta.dtype == torch.int16 - assert self.s.dtype == torch.float16 - - def get_linear(self, bias: torch.Tensor): - return GPTQMarlin24Linear( - weight=self, - bias=bias, - ) - - -class GPTQMarlin24Linear(nn.Module): - def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]): - super().__init__() - - _check_marlin_kernels() - assert marlin_kernels is not None - - if weight.bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS: - supported_bits = ", ".join( - str(b) for b in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS - ) - raise RuntimeError( - f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}" - ) - - in_features = weight.B.shape[0] * MARLIN_TILE_SIZE * 2 - out_features = weight.s.shape[1] - groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0] - - if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES: - supported_sizes = ", ".join( - str(b) for b in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES - ) - raise RuntimeError( - f"Group size {groupsize} is not supported, must be one of: {supported_sizes}" - ) - - self.bits = weight.bits - weights_per_int32 = 32 // self.bits - - assert ( - out_features % GPTQ_MARLIN_24_MIN_THREAD_N == 0 - ), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_N} threads" - assert ( - out_features % weights_per_int32 == 0 - ), f"Number of output features ({out_features}) not divisable by weights per int32 ({weights_per_int32})" - - assert ( - in_features % GPTQ_MARLIN_24_MIN_THREAD_K == 0 - ), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_K} threads" - if groupsize != -1 and in_features % groupsize != 0: - raise ValueError( - f"Number of input features ({in_features}) not divisable by group size ({groupsize})" - ) - - self.B = weight.B - self.B_meta = weight.B_meta - self.s = weight.s - if bias is not None: - self.bias = bias - else: - self.bias = None - - self.workspace = torch.zeros( - (out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL, - dtype=torch.int, - device=weight.B.device, - ) - - def forward(self, A: torch.Tensor) -> torch.Tensor: - assert marlin_kernels is not None - - C = marlin_kernels.gptq_marlin_24_gemm( - A.view(-1, A.shape[-1]), - self.B, - self.B_meta, - self.s, - self.workspace, - self.bits, - A.shape[0], - self.s.shape[1], - A.shape[1], - ) - - C = C.reshape(A.shape[:-1] + (self.s.shape[1],)) - - if self.bias is not None: - C += self.bias - - return C diff --git a/backends/gaudi/server/text_generation_server/layers/marlin/util.py b/backends/gaudi/server/text_generation_server/layers/marlin/util.py deleted file mode 100644 index 250d17141..000000000 --- a/backends/gaudi/server/text_generation_server/layers/marlin/util.py +++ /dev/null @@ -1,141 +0,0 @@ -import functools -from typing import List, Tuple - -import numpy -import torch -from text_generation_server.utils.import_utils import SYSTEM - -try: - import marlin_kernels -except ImportError: - marlin_kernels = None - -try: - major, _minor = torch.cuda.get_device_capability() - has_sm_8_0 = major >= 8 -except Exception: - has_sm_8_0 = False - - -def _check_marlin_kernels(): - if not (SYSTEM == "cuda" and has_sm_8_0): - raise NotImplementedError( - "Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later." 
- ) - - if marlin_kernels is None: - raise NotImplementedError( - "marlin is not installed, install it with: pip install server/marlin" - ) - - -# https://github.com/IST-DASLab/marlin/blob/2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c/marlin/__init__.py#L40C1-L68C54 -@functools.cache -def get_perms() -> Tuple[List[int], List[int]]: - scale_perm = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single = [] - for i in range(4): - scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -def permute_scales(scales: torch.Tensor): - scale_perm, scale_perm_single = get_perms() - out_features = scales.shape[1] - if scales.shape[0] == 1: - scales = scales.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] - else: - scales = scales.reshape((-1, len(scale_perm)))[:, scale_perm] - return scales.reshape((-1, out_features)).contiguous() - - -# Functions below are from vLLM - - -def get_pack_factor(bits: int) -> int: - if 32 % bits != 0: - raise ValueError(f"Cannot {bits} bit values into uint32") - return 32 // bits - - -def pack_cols( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) - - for i in range(pack_factor): - q_res |= q_w[:, i::pack_factor] << num_bits * i - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def unpack_cols( - packed_q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - assert packed_q_w.shape == ( - size_k, - size_n // pack_factor, - ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( - packed_q_w.shape, size_k, size_n, pack_factor - ) - - orig_device = packed_q_w.device - - packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) - q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) - - mask = (1 << num_bits) - 1 - for i in range(pack_factor): - vals = packed_q_w_cpu & mask - packed_q_w_cpu >>= num_bits - q_res[:, i::pack_factor] = vals - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def marlin_zero_points( - zp: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - scale_perm, _ = get_perms() - # Permute zero-points in a similar way to scales, but do not use the - # "single" permutation, since zero-points are applied on every MMA - zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] - - # Interleave column dim (for the dequantize code) and pack it to int32 - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() - zp = zp.reshape((-1, size_n)).contiguous() - zp = pack_cols(zp, num_bits, size_k, size_n) - - return zp diff --git a/backends/gaudi/server/text_generation_server/layers/moe/__init__.py b/backends/gaudi/server/text_generation_server/layers/moe/__init__.py index 2c46ca02a..8b9d6fcb0 100644 --- 
a/backends/gaudi/server/text_generation_server/layers/moe/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/moe/__init__.py @@ -10,13 +10,8 @@ from text_generation_server.layers import ( TensorParallelRowLinear, ) from text_generation_server.layers.fp8 import HybridFP8UnquantLoader -from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader -from text_generation_server.layers.moe.gptq_marlin import ( - GPTQMarlinSparseMoELayer, - can_use_marlin_moe_gemm, -) from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer from text_generation_server.utils.log import log_once from text_generation_server.utils.weights import ( DefaultWeightsLoader, @@ -24,12 +19,7 @@ from text_generation_server.utils.weights import ( UnquantizedWeight, ) -if SYSTEM == "rocm": - from .fused_moe_rocm import grouped_topk - from vllm.model_executor.layers.fused_moe import fused_topk -elif SYSTEM != "ipex": - from moe_kernels.fused_moe import fused_topk, grouped_topk - +from .fused_moe import fused_topk, grouped_topk # NOTE: we are using a protocol here, because multiple inherance is not nice. # We need `Module`, and `Module` -> some abstract class -> some concrete @@ -52,6 +42,8 @@ class MoELayer(Protocol): up_proj_name: str = "up_proj", down_proj_name: str = "down_proj", hidden_act: str = "silu", + scoring_func: Optional[str] = None, + e_score_correction_bias: Optional[float] = None, ): ... def forward( @@ -81,9 +73,14 @@ class DenseMoELayer(nn.Module): up_proj_name: str = "up_proj", down_proj_name: str = "down_proj", hidden_act: str = "silu", + scoring_func: Optional[str] = None, + e_score_correction_bias: Optional[float] = None, ): super().__init__() + assert scoring_func is None, "scoring_func is not handled" + assert e_score_correction_bias is None, "e_score_correction_bias is not handled" + log_once( logger.info, "No fused layers are available for this model type, using (slower) dense MoE layer", @@ -199,22 +196,27 @@ class SparseMoELayer(nn.Module): topk: int, topk_group: Optional[int], weights: Weights, + scoring_func: Optional[str] = "softmax", + e_score_correction_bias: Optional[float] = None, gate_proj_name: str = "gate_proj", up_proj_name: str = "up_proj", down_proj_name: str = "down_proj", ): super().__init__() - if ( isinstance(weights.loader, DefaultWeightsLoader) and isinstance(weights.loader.weight_class, UnquantizedWeight) ) or isinstance(weights.loader, HybridFP8UnquantLoader): - cls = UnquantizedSparseMoELayer - elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym: - cls = GPTQMarlinSparseMoELayer + if ( + isinstance(weights.loader, HybridFP8UnquantLoader) + and weights.loader.to_fp8 + ): + cls = FP8SparseMoELayer + else: + cls = UnquantizedSparseMoELayer else: raise ValueError( - f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and GPTQ weights" + f"Unsupported weights loader: {type(weights.loader)}, sparse MoE is only supported for unquantized and FP8 weights" ) log_once( @@ -230,6 +232,8 @@ class SparseMoELayer(nn.Module): topk=topk, topk_group=topk_group, weights=weights, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, gate_proj_name=gate_proj_name, up_proj_name=up_proj_name, down_proj_name=down_proj_name, @@ -241,17 +245,6 @@ class SparseMoELayer(nn.Module): @staticmethod def is_supported(weights: 
Weights) -> bool: return ( - ( - isinstance(weights.loader, DefaultWeightsLoader) - and isinstance(weights.loader.weight_class, UnquantizedWeight) - ) - or isinstance(weights.loader, HybridFP8UnquantLoader) - or ( - isinstance(weights.loader, GPTQMarlinWeightsLoader) - and can_use_marlin_moe_gemm( - quant_method=weights.loader.quant_method, - quantize=weights.loader.quantize, - sym=weights.loader.sym, - ) - ) - ) + isinstance(weights.loader, DefaultWeightsLoader) + and isinstance(weights.loader.weight_class, UnquantizedWeight) + ) or isinstance(weights.loader, HybridFP8UnquantLoader) diff --git a/backends/gaudi/server/text_generation_server/layers/moe/fp8.py b/backends/gaudi/server/text_generation_server/layers/moe/fp8.py new file mode 100644 index 000000000..071b2abee --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/moe/fp8.py @@ -0,0 +1,173 @@ +from typing import Optional + +import torch +import torch.nn as nn + +from text_generation_server.utils.weights import Weights +from text_generation_server.layers.fp8 import ( + Fp8Weight, + fp8_quantize, + quant_dtype, + normalize_e4m3fn_to_native_float8, +) + +try: + from .unquantized import fused_moe +except Exception: + fused_moe = None + + +class FP8SparseMoELayer(nn.Module): + def __init__( + self, + *, + n_expert_group: Optional[int], + n_experts: int, + prefix: str, + renormalize: bool, + topk: int, + topk_group: Optional[int], + weights: Weights, + scoring_func: Optional[str] = "softmax", + e_score_correction_bias: Optional[float] = None, + gate_proj_name: str = "gate_proj", + up_proj_name: str = "up_proj", + down_proj_name: str = "down_proj", + ): + super().__init__() + + assert (n_expert_group is None) == ( + topk_group is None + ), "n_expert_group and topk_group must both be None or have some value" + + self.n_expert_group = n_expert_group + self.topk = topk + self.topk_group = topk_group + self.renormalize = renormalize + self.weight_block_size = weights.weights_loader.weight_block_size + self.scoring_func = scoring_func + self.e_score_correction_bias = e_score_correction_bias + + ( + self.gate_up_proj, + self.gate_up_proj_weight_scale, + self.gate_up_proj_input_scale, + ) = _load_expert_multi_weights_col( + prefix=prefix, + n_experts=n_experts, + gate_proj_name=gate_proj_name, + up_proj_name=up_proj_name, + weights=weights, + ) + + self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = ( + _load_expert_weights_row( + prefix=prefix, + n_experts=n_experts, + name=down_proj_name, + weights=weights, + ) + ) + + def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: + return fused_moe( + x, + w1=self.gate_up_proj, + w2=self.down_proj, + gating_output=gating_output, + topk=self.topk, + renormalize=self.renormalize, + inplace=True, + use_grouped_topk=self.n_expert_group is not None, + num_expert_group=self.n_expert_group, + topk_group=self.topk_group, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias, + use_fp8_w8a8=True, + w1_scale=self.gate_up_proj_weight_scale, + w2_scale=self.down_proj_weight_scale, + a1_scale=self.gate_up_proj_input_scale, + a2_scale=self.down_proj_input_scale, + ) + + +def _load_expert_weights( + get_weight_fn, + *, + prefix: str, + n_experts: int, + name: str, + weights: Weights, +) -> torch.Tensor: + all_weight = None + all_weight_scales = None + max_input_scale = None + + for i in range(n_experts): + weight = get_weight_fn(prefix, i, name, weights) + + assert isinstance(weight, Fp8Weight) + + if all_weight 
is None: + all_weight = torch.empty( + (n_experts,) + weight.weight.shape, + dtype=quant_dtype, + device=weight.weight.device, + ) + if all_weight_scales is None: + all_weight_scales = torch.empty( + (n_experts,) + weight.weight_scale.shape, + dtype=torch.float32, + device=weight.weight.device, + ) + + if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}: + all_weight[i], all_weight_scales[i], current_input_scale = ( + normalize_e4m3fn_to_native_float8( + weight.weight, weight.weight_scale, weight.input_scale + ) + ) + if current_input_scale is not None: + if max_input_scale is None or current_input_scale > max_input_scale: + max_input_scale = current_input_scale + else: + all_weight[i], all_weight_scales[i] = fp8_quantize( + weight.weight, scalar=True + ) + + assert all_weight is not None + + return all_weight, all_weight_scales, max_input_scale + + +def _load_expert_multi_weights_col( + *, + prefix: str, + n_experts: int, + gate_proj_name: str, + up_proj_name: str, + weights: Weights, +) -> torch.Tensor: + def get_weight_fn(prefix, i, name, weights): + return weights.get_multi_weights_col( + [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0 + ) + + return _load_expert_weights( + get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights + ) + + +def _load_expert_weights_row( + *, + prefix: str, + n_experts: int, + name: str, + weights: Weights, +) -> torch.Tensor: + def get_weight_fn(prefix, i, name, weights): + return weights.get_weights_row(f"{prefix}.{i}.{name}") + + return _load_expert_weights( + get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights + ) diff --git a/backends/gaudi/server/text_generation_server/layers/moe/fused_moe_rocm.py b/backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py similarity index 80% rename from backends/gaudi/server/text_generation_server/layers/moe/fused_moe_rocm.py rename to backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py index 68accb990..e26ff8770 100644 --- a/backends/gaudi/server/text_generation_server/layers/moe/fused_moe_rocm.py +++ b/backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py @@ -16,10 +16,8 @@ from typing import Tuple import torch -import torch.distributed -# TODO: Remove the functions once moe_kernel are built for ROCM def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -50,3 +48,18 @@ def grouped_topk( topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) return topk_weights, topk_ids + + +def fused_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +) -> Tuple[torch.Tensor, torch.Tensor]: + topk_weights = torch.nn.functional.softmax( + gating_output, dim=1, dtype=torch.float32 + ) + topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1) + if renormalize: + topk_weights /= topk_weights.sum(dim=-1, keepdim=True) + return topk_weights, topk_ids diff --git a/backends/gaudi/server/text_generation_server/layers/moe/gptq_marlin.py b/backends/gaudi/server/text_generation_server/layers/moe/gptq_marlin.py deleted file mode 100644 index 3217cdc22..000000000 --- a/backends/gaudi/server/text_generation_server/layers/moe/gptq_marlin.py +++ /dev/null @@ -1,215 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional - -import torch -import torch.nn as nn - -from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.utils.weights import Weights -from 
text_generation_server.layers.marlin.gptq import ( - GPTQMarlinWeight, - GPTQMarlinWeightsLoader, -) - -if SYSTEM == "cuda": - from moe_kernels.fused_marlin_moe import fused_marlin_moe -else: - fused_marlin_moe = None - - -try: - major, _minor = torch.cuda.get_device_capability() - has_sm_8_0 = major >= 8 -except Exception: - has_sm_8_0 = False - - -def can_use_marlin_moe_gemm( - *, - quant_method: str, - quantize: str, - sym: bool, -): - return ( - SYSTEM == "cuda" - and fused_marlin_moe is not None - and has_sm_8_0 - and quantize == "gptq" - and quant_method == "gptq" - and sym - ) - - -@dataclass -class GPTQMarlinMoEWeight: - qweight: torch.Tensor - qzeros: torch.Tensor - scales: torch.Tensor - g_idx: torch.Tensor - perm: torch.Tensor - is_full_k: bool - - -class GPTQMarlinSparseMoELayer(nn.Module): - """ - MoE layer that uses a fused GPTQ-Marlin kernel. - """ - - def __init__( - self, - *, - n_expert_group: Optional[int], - n_experts: int, - prefix: str, - renormalize: bool, - topk: int, - topk_group: Optional[int], - weights: Weights, - gate_proj_name: str = "gate_proj", - up_proj_name: str = "up_proj", - down_proj_name: str = "down_proj", - ): - super().__init__() - - if not ( - isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym - ): - raise ValueError( - f"Unsupported weights loader: {weights.loader}, only GPTQMarlinWeightsLoader with symmetric quantization is supported" - ) - - assert (n_expert_group is None) == ( - topk_group is None - ), "n_expert_group and topk_group must both be None or have some value" - - self.n_expert_group = n_expert_group - self.topk = topk - self.topk_group = topk_group - self.renormalize = renormalize - - self.gate_up_proj = _load_expert_multi_weights_col( - prefix=prefix, - n_experts=n_experts, - names=[gate_proj_name, up_proj_name], - weights=weights, - ) - - self.down_proj = _load_expert_weights_row( - prefix=prefix, n_experts=n_experts, name=down_proj_name, weights=weights - ) - - self.bits = weights.loader.bits - - def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: - return fused_marlin_moe( - x, - w1=self.gate_up_proj.qweight, - w2=self.down_proj.qweight, - g_idx1=self.gate_up_proj.g_idx, - g_idx2=self.down_proj.g_idx, - perm1=self.gate_up_proj.perm, - perm2=self.down_proj.perm, - w1_scale=self.gate_up_proj.scales, - w2_scale=self.down_proj.scales, - is_full_k1=self.gate_up_proj.is_full_k, - is_full_k2=self.down_proj.is_full_k, - gating_output=gating_output, - topk=self.topk, - renormalize=self.renormalize, - use_grouped_topk=self.n_expert_group is not None, - num_expert_group=self.n_expert_group, - topk_group=self.topk_group, - num_bits=self.bits, - ) - - -def _load_expert_multi_weights_col( - *, - prefix: str, - n_experts: int, - names: List[str], - weights: Weights, -) -> GPTQMarlinMoEWeight: - moe_weight = None - for i in range(n_experts): - weight = weights.get_multi_weights_col( - [f"{prefix}.{i}.{name}" for name in names], 0 - ) - assert isinstance(weight, GPTQMarlinWeight) - moe_weight = _pack_weight( - n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight - ) - assert moe_weight is not None - return moe_weight - - -def _load_expert_weights_row( - *, - prefix: str, - n_experts: int, - name: str, - weights: Weights, -) -> GPTQMarlinMoEWeight: - moe_weight = None - for i in range(n_experts): - weight = weights.get_weights_row( - f"{prefix}.{i}.{name}", - ) - assert isinstance(weight, GPTQMarlinWeight) - moe_weight = _pack_weight( - n_experts=n_experts, expert=i, 
weight=weight, moe_weight=moe_weight - ) - assert moe_weight is not None - return moe_weight - - -def _pack_weight( - *, - n_experts: int, - expert: int, - moe_weight: Optional[GPTQMarlinMoEWeight], - weight: GPTQMarlinWeight, -) -> GPTQMarlinMoEWeight: - if moe_weight is None: - qweight = torch.empty( - (n_experts,) + weight.qweight.shape, - dtype=weight.qweight.dtype, - device=weight.qweight.device, - ) - qzeros = torch.empty( - (n_experts,) + weight.qzeros.shape, - dtype=weight.qzeros.dtype, - device=weight.qzeros.device, - ) - scales = torch.empty( - (n_experts,) + weight.scales.shape, - dtype=weight.scales.dtype, - device=weight.scales.device, - ) - g_idx = torch.empty( - (n_experts,) + weight.g_idx.shape, - dtype=weight.g_idx.dtype, - device=weight.g_idx.device, - ) - perm = torch.empty( - (n_experts,) + weight.perm.shape, - dtype=weight.perm.dtype, - device=weight.perm.device, - ) - - moe_weight = GPTQMarlinMoEWeight( - qweight=qweight, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - perm=perm, - is_full_k=weight.is_full_k, - ) - - moe_weight.qweight[expert] = weight.qweight - moe_weight.qzeros[expert] = weight.qzeros - moe_weight.scales[expert] = weight.scales - moe_weight.g_idx[expert] = weight.g_idx - moe_weight.perm[expert] = weight.perm - - return moe_weight diff --git a/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py b/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py index d9d62c0ef..ec1583989 100644 --- a/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py +++ b/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py @@ -3,13 +3,8 @@ from typing import Optional import torch import torch.nn as nn -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.weights import UnquantizedWeight, Weights - -if SYSTEM == "rocm": - from vllm.model_executor.layers.fused_moe import fused_moe -elif SYSTEM != "ipex": - from moe_kernels.fused_moe import fused_moe +from vllm_hpu_extension.ops import DynamicFusedMOE class UnquantizedSparseMoELayer(nn.Module): @@ -23,6 +18,8 @@ class UnquantizedSparseMoELayer(nn.Module): topk: int, topk_group: Optional[int], weights: Weights, + scoring_func: Optional[str] = "softmax", + e_score_correction_bias: Optional[float] = None, gate_proj_name: str = "gate_proj", up_proj_name: str = "up_proj", down_proj_name: str = "down_proj", @@ -37,6 +34,9 @@ class UnquantizedSparseMoELayer(nn.Module): self.topk = topk self.topk_group = topk_group self.renormalize = renormalize + self.weight_block_size = weights.weights_loader.weight_block_size + self.scoring_func = scoring_func + self.e_score_correction_bias = e_score_correction_bias self.gate_up_proj = _load_expert_multi_weights_col( prefix=prefix, @@ -53,30 +53,13 @@ class UnquantizedSparseMoELayer(nn.Module): weights=weights, ) - def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: - if SYSTEM == "rocm": - return fused_moe( - x, - self.gate_up_proj, - self.down_proj, - gating_output, - self.topk, - renormalize=self.renormalize, - inplace=True, - ) + self.hpu_fused_moe = DynamicFusedMOE(n_experts) + for i in range(n_experts): + self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i]) + self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.down_proj[i]) - return fused_moe( - x, - w1=self.gate_up_proj, - w2=self.down_proj, - gating_output=gating_output, - topk=self.topk, - renormalize=self.renormalize, - inplace=True, - use_grouped_topk=self.n_expert_group is 
not None, - num_expert_group=self.n_expert_group, - topk_group=self.topk_group, - ) + def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: + return self.hpu_fused_moe(x, gating_output, self.topk) def _load_expert_multi_weights_col( diff --git a/backends/gaudi/server/text_generation_server/layers/rotary.py b/backends/gaudi/server/text_generation_server/layers/rotary.py index a2076bb20..6a83d6a57 100644 --- a/backends/gaudi/server/text_generation_server/layers/rotary.py +++ b/backends/gaudi/server/text_generation_server/layers/rotary.py @@ -2,14 +2,10 @@ import os import math import torch from torch import nn -from text_generation_server.utils.import_utils import SYSTEM - -if SYSTEM == "cuda": - import rotary_emb -elif SYSTEM == "rocm": - from vllm._C import ops -elif SYSTEM == "ipex": - import intel_extension_for_pytorch as ipex +from habana_frameworks.torch.hpex.kernels import ( + RotaryPosEmbeddingMode, + apply_rotary_pos_emb, +) def _create_inv_freq(dim, base, device): @@ -30,7 +26,7 @@ def _get_rope_config(config): class PositionRotaryEmbedding(nn.Module): - def __init__(self, inv_freq, scaling_factor): + def __init__(self, inv_freq, scaling_factor, max_position_embeddings): super().__init__() self.inv_freq = inv_freq self._seq_len_cached = 0 @@ -40,6 +36,9 @@ class PositionRotaryEmbedding(nn.Module): self._sin_k_cached = None self.scaling_factor = scaling_factor self.dynamic_args = None + self._update_cos_sin_cache( + torch.float32, inv_freq.device, max_position_embeddings + ) def forward( self, @@ -48,40 +47,41 @@ class PositionRotaryEmbedding(nn.Module): cos: torch.Tensor, sin: torch.Tensor, ): - # Such controlflows may add some overhead. - if SYSTEM == "cuda": - rotary_dim = cos.shape[-1] - q1 = query[..., :rotary_dim] - q2 = query[..., rotary_dim : 2 * rotary_dim] + num_tokens = query.shape[0] + head_size = query.shape[-1] + # HPU RoPE kernel requires hidden dimension for cos and sin to be equal + # to query hidden dimension, so the original tensors need to be + # expanded + # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE + # and expansion of cos/sin tensors via concatenation + rope_mode = RotaryPosEmbeddingMode.BLOCKWISE + cos = torch.cat((cos, cos), dim=-1) + sin = torch.cat((sin, sin), dim=-1) + rotary_dim = cos.shape[-1] + query_shape = query.shape + query = query.view(num_tokens, -1, head_size) + query_rot = query[..., :rotary_dim] + query_pass = query[..., rotary_dim:] + query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode) + query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)) - rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) - - k1 = key[..., :rotary_dim] - k2 = key[..., rotary_dim : 2 * rotary_dim] - - rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) - elif SYSTEM == "rocm": - # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. - # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773 - - head_size = query.shape[-1] - - # Inplace operation, updating query and key. 
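For reference, the RotaryPosEmbeddingMode.BLOCKWISE path added above follows the GPT-NeoX rotate-half convention: cos/sin are duplicated along the last dimension and the first half of each head rotates against the second half. A minimal plain-torch sketch of that convention on synthetic shapes (an illustration of the math only, not the Habana kernel):

    import torch

    def blockwise_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        # x: [num_tokens, num_heads, head_size]; cos/sin: [num_tokens, head_size],
        # already duplicated as in the torch.cat((cos, cos), dim=-1) above.
        half = x.shape[-1] // 2
        # Rotate-half: element i pairs with element i + half.
        rotated = torch.cat((-x[..., half:], x[..., :half]), dim=-1)
        return x * cos.unsqueeze(1) + rotated * sin.unsqueeze(1)

    num_tokens, num_heads, head_size = 4, 2, 8
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_size, 2).float() / head_size))
    freqs = torch.outer(torch.arange(num_tokens).float(), inv_freq)
    cos = torch.cat((freqs.cos(), freqs.cos()), dim=-1)
    sin = torch.cat((freqs.sin(), freqs.sin()), dim=-1)
    q = torch.randn(num_tokens, num_heads, head_size)
    print(blockwise_rope(q, cos, sin).shape)  # torch.Size([4, 2, 8])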
- ops.rotary_embedding(query, key, head_size, cos, sin, True) - elif SYSTEM == "ipex": - ipex.llm.functional.rotary_embedding( - query, key, sin, cos, query.size(-1), True - ) - else: - raise ValueError( - "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." - ) + key_shape = key.shape + key = key.view(num_tokens, -1, head_size) + key_rot = key[..., :rotary_dim] + key_pass = key[..., rotary_dim:] + key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode) + key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)) @classmethod def static(cls, config, dim, base, device): inv_freq = _create_inv_freq(dim, base, device) scaling_factor = None rope_scaling = _get_rope_config(config) + if not hasattr(config, "max_position_embeddings") and hasattr( + config, "max_seq_len" + ): + # handling for dbrx + config.max_position_embeddings = config.max_seq_len if rope_scaling is not None: # `rope_type` is now standard in transformers, but some existing models # have `type` instead. @@ -89,6 +89,17 @@ class PositionRotaryEmbedding(nn.Module): if rope_type == "linear": pass + elif rope_type == "default": + pass + elif rope_type == "mrope": + mrope_section = rope_scaling["mrope_section"] + if mrope_section is not None: + return RotaryPositionEmbeddingMultimodalSections( + inv_freq, + scaling_factor, + mrope_section, + config.max_position_embeddings, + ) elif rope_type == "dynamic": scaling_factor = rope_scaling["factor"] return DynamicPositionRotaryEmbedding( @@ -109,7 +120,7 @@ class PositionRotaryEmbedding(nn.Module): ], ) - return cls(inv_freq, scaling_factor) + return cls(inv_freq, scaling_factor, config.max_position_embeddings) elif rope_type == "yarn": scaling_factor = rope_scaling["factor"] @@ -185,12 +196,13 @@ class PositionRotaryEmbedding(nn.Module): long_inv_freq=long_inv_freq, scaling_factor=scaling_factor, original_max_position_embeddings=original_max_position_embeddings, + max_position_embeddings=config.max_position_embeddings, ) else: raise NotImplementedError( f"rope scaling type {rope_scaling['type']} is not implemented or invalid" ) - return cls(inv_freq, scaling_factor) + return cls(inv_freq, scaling_factor, config.max_position_embeddings) @classmethod def load(cls, config, prefix, weights): @@ -236,7 +248,7 @@ class PositionRotaryEmbedding(nn.Module): raise NotImplementedError( f"rope scaling type {rope_scaling['type']} is not implemented or invalid" ) - return cls(inv_freq, scaling_factor) + return cls(inv_freq, scaling_factor, config.max_position_embeddings) def _update_cos_sin_cache(self, dtype, device, seqlen): # Reset the tables if the sequence length has changed, @@ -257,17 +269,7 @@ class PositionRotaryEmbedding(nn.Module): self._cos_cached = torch.cos(freqs).to(dtype) self._sin_cached = torch.sin(freqs).to(dtype) - def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype): - """ - Return cos and sin for the asked position ids - """ - if SYSTEM == "rocm": - # For RoCm, we always use float cos/sin to avoid a cast. 
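Note that the HPU rewrite builds the cos/sin tables once at construction time for the full max_position_embeddings, so get_cos_sin reduces to an index_select whose shapes are fixed by position_ids alone. A standalone sketch of that caching scheme, with hypothetical sizes:

    import torch

    class CosSinCache:
        # Sketch of the precomputed rotary cache: built once up front,
        # then only indexed at runtime.
        def __init__(self, dim: int, base: float, max_position_embeddings: int):
            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
            t = torch.arange(max_position_embeddings, dtype=inv_freq.dtype)
            freqs = torch.outer(t, inv_freq)
            self._cos_cached = freqs.cos()
            self._sin_cached = freqs.sin()

        def get_cos_sin(self, position_ids: torch.Tensor):
            # Same lookup as the get_cos_sin above: a gather along dim 0,
            # no dtype- or length-dependent recomputation at runtime.
            cos = torch.index_select(self._cos_cached, 0, position_ids)
            sin = torch.index_select(self._sin_cached, 0, position_ids)
            return cos, sin

    cache = CosSinCache(dim=64, base=10000.0, max_position_embeddings=4096)
    cos, sin = cache.get_cos_sin(torch.tensor([0, 1, 7, 42]))
    print(cos.shape, sin.shape)  # torch.Size([4, 32]) torch.Size([4, 32])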
- # For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26 - # But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal. - dtype = torch.float32 - - self._update_cos_sin_cache(dtype, position_ids.device, max_s) + def get_cos_sin(self, position_ids: torch.Tensor): cos = torch.index_select(self._cos_cached, 0, position_ids) sin = torch.index_select(self._sin_cached, 0, position_ids) @@ -283,6 +285,7 @@ class SuRotaryEmbedding(PositionRotaryEmbedding): long_inv_freq, scaling_factor, original_max_position_embeddings, + max_position_embeddings, ): super(PositionRotaryEmbedding, self).__init__() self.short_inv_freq = short_inv_freq @@ -295,6 +298,9 @@ class SuRotaryEmbedding(PositionRotaryEmbedding): self._cos_k_cached = None self._sin_k_cached = None self.dynamic_args = None + self._update_cos_sin_cache( + torch.float32, short_inv_freq.device, max_position_embeddings + ) def _update_cos_sin_cache(self, dtype, device, seqlen): # Reset the tables if the sequence length has changed, @@ -348,6 +354,9 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding): self._cos_k_cached = None self._sin_k_cached = None self.dynamic_args = None + self._update_cos_sin_cache( + torch.float32, short_inv_freq.device, max_position_embeddings + ) def _update_cos_sin_cache(self, dtype, device, seqlen): if ( @@ -383,7 +392,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding): class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding): def __init__(self, dim, max_position_embeddings, base, device, scaling_factor): inv_freq = _create_inv_freq(dim, base, device) - super().__init__(inv_freq, scaling_factor) + super().__init__(inv_freq, scaling_factor, max_position_embeddings) self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base @@ -461,7 +470,9 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): mscale_all_dim: float, ): inv_freq = _create_inv_freq(dim, base, device) - super().__init__(inv_freq, scaling_factor) + super().__init__( + inv_freq, scaling_factor, max_position_embeddings * scaling_factor + ) self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base @@ -546,3 +557,50 @@ def apply_llama3_scaling( new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq) return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) + + +class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding): + def __init__( + self, + inv_freq: torch.Tensor, + scaling_factor: float, + sections: list, + max_position_embeddings, + ): + self.sections = sections + self._cos_cached = None + self._sin_cached = None + self.section_indices = ( + torch.arange(len(self.sections)) + .repeat_interleave(torch.tensor(self.sections)) + .view(1, 1, -1) + .to(inv_freq.device) + ) + super().__init__(inv_freq, scaling_factor, max_position_embeddings) + + def _update_cos_sin_cache( + self, dtype: torch.dtype, device: torch.device, seqlen: int + ): + # always cache the cos/sin for the full sequence length to avoid + # recomputing if the sequence length is smaller than the cached one + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + 
self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + self._sections = self.section_indices.expand(seqlen, -1, -1) + + def get_cos_sin( + self, + position_ids: torch.Tensor, + ): + slen = position_ids.shape[0] + + cos = self._cos_cached[position_ids].gather(1, self._sections[:slen]) + sin = self._sin_cached[position_ids].gather(1, self._sections[:slen]) + return cos, sin diff --git a/backends/gaudi/server/text_generation_server/layers/tensor_parallel.py b/backends/gaudi/server/text_generation_server/layers/tensor_parallel.py index 13f12ef1e..8f19174f8 100644 --- a/backends/gaudi/server/text_generation_server/layers/tensor_parallel.py +++ b/backends/gaudi/server/text_generation_server/layers/tensor_parallel.py @@ -2,10 +2,8 @@ import torch from torch.nn import functional as F from typing import Iterable, List from text_generation_server.layers.linear import get_linear, FastLinear -from text_generation_server.utils.import_utils import SYSTEM -if SYSTEM == "ipex": - import intel_extension_for_pytorch as ipex +import habana_frameworks.torch as htorch class LayerConcat(torch.nn.Module): @@ -90,14 +88,10 @@ class TensorParallelHead(SuperLayer): local_out = gather_input.T torch.mm(input, self.linear.weight.T, out=local_out) - if SYSTEM == "ipex": - ipex.distributed.all_gather_into_tensor( - world_out, gather_input, group=self.process_group - ) - else: - torch.distributed.all_gather_into_tensor( - world_out, gather_input, group=self.process_group - ) + htorch.core.mark_step() + torch.distributed.all_gather_into_tensor( + world_out, gather_input, group=self.process_group + ) if input.shape[0] == 1: return world_out @@ -107,10 +101,9 @@ class TensorParallelHead(SuperLayer): world_output = [ torch.empty_like(output) for _ in range(self.process_group.size()) ] - if SYSTEM == "ipex": - ipex.distributed.all_gather(world_output, output, group=self.process_group) - else: - torch.distributed.all_gather(world_output, output, group=self.process_group) + + htorch.core.mark_step() + torch.distributed.all_gather(world_output, output, group=self.process_group) world_output = torch.cat(world_output, dim=-1) return world_output @@ -202,10 +195,11 @@ class TensorParallelRowLinear(SuperLayer): def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor: out = super().forward(input) if self.process_group.size() > 1 and reduce: - if SYSTEM == "ipex": - ipex.distributed.all_reduce(out, group=self.process_group) - else: - torch.distributed.all_reduce(out, group=self.process_group) + # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge + # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used + # (which is required for tensor parallel HPUGraph inference) + htorch.core.mark_step() + torch.distributed.all_reduce(out, group=self.process_group) return out @@ -242,8 +236,9 @@ class TensorParallelEmbedding(torch.nn.Module): ) out = torch.nn.functional.embedding(input, self.weight) if self.reduce and self.process_group.size() > 1: - if SYSTEM == "ipex": - ipex.distributed.all_reduce(out, group=self.process_group) - else: - torch.distributed.all_reduce(out, group=self.process_group) + # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge + # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used + # (which is required for tensor parallel HPUGraph inference) + 
htorch.core.mark_step() + torch.distributed.all_reduce(out, group=self.process_group) return out diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py index 346016c21..778b14a1b 100644 --- a/backends/gaudi/server/text_generation_server/models/__init__.py +++ b/backends/gaudi/server/text_generation_server/models/__init__.py @@ -1,3 +1,5 @@ +# ruff: noqa: F821 +# the above line disables the `undefined-name` rule for the model type variables import torch import os @@ -8,6 +10,7 @@ from huggingface_hub import hf_hub_download, HfApi from typing import Optional from pathlib import Path from typing import List, Dict +import enum # Needed to properly setup habana_frameworks @@ -16,15 +19,10 @@ from text_generation_server.models.model import Model from text_generation_server.models.causal_lm import CausalLM from text_generation_server.models.bloom import BLOOM from text_generation_server.models.starcoder import StarCoder -from text_generation_server.models.vlm_causal_lm import VlmCausalLM -from text_generation_server.models.custom_modeling.mllama import ( - MllamaForConditionalGeneration, -) -from text_generation_server.models.custom_modeling.llava_next import ( - LlavaNextForConditionalGeneration, +from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import ( + PhiMoEConfig, ) -# from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch from text_generation_server.utils.adapter import ( AdapterParameters, build_layer_weight_lookup, @@ -33,9 +31,285 @@ from text_generation_server.utils.adapter import ( ) from text_generation_server.adapters.lora import LoraWeights - +from text_generation_server.utils.log import log_master from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi +__all__ = [ + "Model", + "CausalLM", + "Seq2SeqLM", + "get_model_with_lora_adapters", +] +from text_generation_server.models.globals import ATTENTION + +FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models." 
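The block that follows uses a single guarded import for every flash-attention model class, clearing the capability flags on failure so that get_model can fall back to the non-flash implementations. A minimal sketch of the pattern (module and class names here are placeholders, not the real TGI imports):

    import importlib

    FLASH_OK = True
    try:
        # Hypothetical optional backend, standing in for the model imports below.
        _backend = importlib.import_module("some_flash_backend")
    except ImportError as err:
        print(f"Could not import Flash Attention enabled models: {err}")
        FLASH_OK = False

    def resolve_model_class(default=object):
        # Prefer the flash implementation when the import succeeded,
        # otherwise fall back to the default implementation.
        return _backend.FlashModel if FLASH_OK else default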
+ +FLASH_ATTENTION = False +if ATTENTION == "paged": + FLASH_ATTENTION = True + +try: + from text_generation_server.models.flash_causal_lm import FlashCausalLM + from text_generation_server.models.flash_vlm_causal_lm import FlashVlmCausalLM + from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLM + from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import ( + FlashDeepseekV2ForCausalLM, + DeepseekV2Config, + ) + from text_generation_server.models.custom_modeling.flash_deepseek_v3_modeling import ( + FlashDeepseekV3ForCausalLM, + DeepseekV3Config, + ) + from text_generation_server.models.custom_modeling.flash_llama_modeling import ( + FlashLlamaForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_cohere_modeling import ( + FlashCohereForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( + FlashGemmaForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_gemma2_modeling import ( + FlashGemma2ForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_dbrx_modeling import ( + FlashDbrxForCausalLM, + DbrxConfig, + ) + from text_generation_server.models.custom_modeling.flash_rw_modeling import ( + RWConfig, + FlashRWForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_neox_modeling import ( + FlashGPTNeoXForCausalLM, + ) + from text_generation_server.models.pali_gemma import ( + PaliGemmaBatch, + ) + from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import ( + PaliGemmaForConditionalGeneration, + ) + from text_generation_server.models.custom_modeling.flash_phi_modeling import ( + FlashPhiForCausalLM, + ) + from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch + from text_generation_server.models.custom_modeling.flash_mllama import ( + FlashMllamaForConditionalGeneration, + ) + from text_generation_server.models.custom_modeling.flash_llava_next import ( + FlashLlavaNextForConditionalGeneration, + ) + + from text_generation_server.models.custom_modeling.flash_santacoder_modeling import ( + FlashSantacoderForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import ( + FlashStarcoder2ForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( + Qwen2ForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( + FlashMistralForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_mixtral_modeling import ( + FlashMixtralForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_gpt2_modeling import ( + FlashGPT2ForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_gptj_modeling import ( + FlashGPTJForCausalLM, + ) + from text_generation_server.models.custom_modeling.idefics2 import ( + Idefics2ForConditionalGeneration, + ) + from text_generation_server.models.custom_modeling.idefics3 import ( + Idefics3ForConditionalGeneration, + ) + from text_generation_server.models.custom_modeling.qwen2_vl import ( + Qwen2VLForConditionalGeneration, + ) + from text_generation_server.models.custom_modeling.qwen2_5_vl import ( + Qwen2_5VLForConditionalGeneration, + Qwen2_5_VLConfig, + Qwen2_5_VLProcessor, + ) + from text_generation_server.layers.attention import SUPPORTS_WINDOWING +except ImportError as e: + log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}") + 
SUPPORTS_WINDOWING = False + FLASH_ATTENTION = False + +if FLASH_ATTENTION: + __all__.append("FlashCausalLM") + + +class ModelType(enum.Enum): + DEEPSEEK_V2 = { + "type": "deepseek_v2", + "name": "Deepseek V2", + "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2", + } + DEEPSEEK_V3 = { + "type": "deepseek_v3", + "name": "Deepseek V3", + "url": "https://huggingface.co/deepseek-ai/DeepSeek-V3", + } + IDEFICS2 = { + "type": "idefics2", + "name": "Idefics 2", + "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b", + "multimodal": True, + } + IDEFICS3 = { + "type": "idefics3", + "name": "Idefics 3", + "url": "https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3", + "multimodal": True, + } + LLAVA_NEXT = { + "type": "llava_next", + "name": "Llava Next (1.6)", + "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf", + "multimodal": True, + } + LLAMA = { + "type": "llama", + "name": "Llama", + "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f", + } + PHI3 = { + "type": "phi3", + "name": "Phi 3", + "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct", + } + GRANITE = { + "type": "granite", + "name": "Granite", + "url": "https://huggingface.co/ibm-granite/granite-3.0-8b-instruct", + } + GEMMA = { + "type": "gemma", + "name": "Gemma", + "url": "https://huggingface.co/google/gemma-7b", + } + PALIGEMMA = { + "type": "paligemma", + "name": "PaliGemma", + "url": "https://huggingface.co/google/paligemma-3b-pt-224", + } + GEMMA2 = { + "type": "gemma2", + "name": "Gemma2", + "url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315", + } + COHERE = { + "type": "cohere", + "name": "Cohere", + "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus", + } + DBRX = { + "type": "dbrx", + "name": "Dbrx", + "url": "https://huggingface.co/databricks/dbrx-instruct", + } + MAMBA = { + "type": "mamba", + "name": "Mamba", + "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj", + } + MISTRAL = { + "type": "mistral", + "name": "Mistral", + "url": "https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407", + } + MIXTRAL = { + "type": "mixtral", + "name": "Mixtral", + "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1", + } + GPT_BIGCODE = { + "type": "gpt_bigcode", + "name": "Gpt Bigcode", + "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder", + } + PHI = { + "type": "phi", + "name": "Phi", + "url": "https://huggingface.co/microsoft/phi-1_5", + } + PHI_MOE = { + "type": "phimoe", + "name": "PhiMoe", + "url": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct", + } + BAICHUAN = { + "type": "baichuan", + "name": "Baichuan", + "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat", + } + FALCON = { + "type": "falcon", + "name": "Falcon", + "url": "https://huggingface.co/tiiuae/falcon-7b-instruct", + } + STARCODER2 = { + "type": "starcoder2", + "name": "StarCoder 2", + "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1", + } + QWEN2 = { + "type": "qwen2", + "name": "Qwen 2", + "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f", + } + QWEN2_VL = { + "type": "qwen2_vl", + "name": "Qwen 2 VL", + "url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d", + } + QWEN2_5_VL = { + "type": "qwen2_5_vl", + "name": "Qwen 2.5 VL", + "url": "https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e", + } + GALACTICA = { + "type": "galactica", + "name": "Galactica", + "url": 
"https://huggingface.co/facebook/galactica-120b", + } + SANTACODER = { + "type": "santacoder", + "name": "SantaCoder", + "url": "https://huggingface.co/bigcode/santacoder", + } + GPT2 = { + "type": "gpt2", + "name": "Gpt2", + "url": "https://huggingface.co/openai-community/gpt2", + } + GPT_NEOX = { + "type": "gpt_neox", + "name": "Gpt Neox", + "url": "https://huggingface.co/EleutherAI/gpt-neox-20b", + } + GPTJ = { + "type": "gptj", + "name": "Gptj", + "url": "https://huggingface.co/EleutherAI/gpt-j-6b", + } + MLLAMA = { + "type": "mllama", + "name": "Mllama", + "url": "https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct", + "multimodal": True, + } + + +__GLOBALS = locals() +for data in ModelType: + __GLOBALS[data.name] = data.value["type"] SDP_ON_BF16 = int(os.environ.get("SDP_ON_BF16", 0)) # Disable gradients @@ -53,9 +327,7 @@ def get_model( trust_remote_code: bool, max_input_tokens: int, ) -> Model: - adapt_transformers_to_gaudi() - if SDP_ON_BF16 == 1: - torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) + global FLASH_ATTENTION if speculate is not None: set_speculate(speculate) @@ -177,9 +449,393 @@ def get_model( model_type = config_dict["model_type"] + kv_cache_dtype = dtype + + if FLASH_ATTENTION: + if model_type == DEEPSEEK_V2: + head_size = max( + config_dict.get("qk_nope_dim", 128) + + config_dict.get("qk_rope_dim", 64), + config_dict.get("v_head_dim", 128), + ) + return FlashCausalLM( + model_id=model_id, + model_class=FlashDeepseekV2ForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + default_dtype=torch.bfloat16, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=DeepseekV2Config, + head_size=head_size, + ) + elif model_type == DEEPSEEK_V3: + head_size = max( + config_dict.get("qk_nope_dim", 128) + + config_dict.get("qk_rope_dim", 64), + config_dict.get("v_head_dim", 128), + ) + return FlashCausalLM( + model_id=model_id, + model_class=FlashDeepseekV3ForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + default_dtype=torch.bfloat16, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=DeepseekV3Config, + head_size=head_size, + ) + + elif ( + model_type == GPT_BIGCODE + or model_type == GPT2 + and model_id.startswith("bigcode/") + ): + return FlashCausalLM( + model_id=model_id, + model_class=FlashSantacoderForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + aliases={"transformer.wte.weight": ["lm_head.weight"]}, + num_kv_heads=1, + ) + elif model_type == GPT2: + return FlashCausalLM( + model_id=model_id, + model_class=FlashGPT2ForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == GPTJ: + return FlashCausalLM( + model_id=model_id, + model_class=FlashGPTJForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == GPT_NEOX: + from text_generation_server.models.custom_modeling.flash_neox_modeling import ( + GPTNeoXConfig, + ) + + return FlashCausalLM( 
+ model_id=model_id, + model_class=FlashGPTNeoXForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=GPTNeoXConfig, + ) + elif model_type == PHI: + return FlashCausalLM( + model_id=model_id, + model_class=FlashPhiForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == PHI_MOE: + return FlashCausalLM( + model_id=model_id, + model_class=FlashLlamaForCausalLM, + config_class=PhiMoEConfig, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == LLAMA or model_type == PHI3 or model_type == GRANITE: + return FlashCausalLM( + model_id=model_id, + model_class=FlashLlamaForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == BAICHUAN: + return FlashCausalLM( + model_id=model_id, + model_class=FlashLlamaForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == GEMMA: + return FlashCausalLM( + model_id=model_id, + model_class=FlashGemmaForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + # Works better for these models + default_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == GEMMA2: + return FlashCausalLM( + model_id=model_id, + model_class=FlashGemma2ForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + # Works better for these models + default_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == COHERE: + return FlashCausalLM( + model_id=model_id, + model_class=FlashCohereForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == DBRX: + return FlashCausalLM( + model_id=model_id, + model_class=FlashDbrxForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + # Dbrx works better in bfloat16. 
+ default_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=DbrxConfig, + ) + elif ( + model_type in ["RefinedWeb", "RefinedWebModel", FALCON] + and not sharded + and not config_dict.get("alibi", False) + ): + return FlashCausalLM( + model_id=model_id, + model_class=FlashRWForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + aliases={ + "lm_head.weight": ["transformer.word_embeddings.weight"], + "transformer.word_embeddings.weight": ["lm_head.weight"], + }, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=RWConfig, + ) + elif model_type == MISTRAL: + return FlashCausalLM( + model_id=model_id, + model_class=FlashMistralForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == MIXTRAL: + return FlashCausalLM( + model_id=model_id, + model_class=FlashMixtralForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == STARCODER2: + return FlashCausalLM( + model_id=model_id, + model_class=FlashStarcoder2ForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == QWEN2: + return FlashCausalLM( + model_id=model_id, + model_class=Qwen2ForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == QWEN2_VL: + return FlashVlmCausalLM( + model_id=model_id, + model_class=Qwen2VLForConditionalGeneration, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + default_dtype=torch.bfloat16, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == QWEN2_5_VL: + return FlashVlmCausalLM( + model_id=model_id, + model_class=Qwen2_5VLForConditionalGeneration, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + default_dtype=torch.bfloat16, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=Qwen2_5_VLConfig, + processor_class=Qwen2_5_VLProcessor, + ) + elif model_type == MLLAMA: + return FlashMllamaCausalLM( + model_id=model_id, + model_class=FlashMllamaForConditionalGeneration, + batch_class=FlashMllamaCausalLMBatch, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + default_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == IDEFICS2: + return FlashVlmCausalLM( + model_id=model_id, + model_class=Idefics2ForConditionalGeneration, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + # XXX: Extremely important to cap resolution in order to limit + # VRAM usage. 
+ processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}}, + ) + elif model_type == IDEFICS3: + return FlashVlmCausalLM( + model_id=model_id, + model_class=Idefics3ForConditionalGeneration, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + default_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + # XXX: Extremely important to cap resolution in order to limit + # VRAM usage. + processor_kwargs={"size": {"longest_edge": 1456}}, + ) + elif model_type == PALIGEMMA: + return FlashVlmCausalLM( + model_id=model_id, + model_class=PaliGemmaForConditionalGeneration, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + # Works better for these models + default_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + batch_class=PaliGemmaBatch, + ) + elif model_type == LLAVA_NEXT: + return FlashVlmCausalLM( + model_class=FlashLlavaNextForConditionalGeneration, + model_id=model_id, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + ) + + from text_generation_server.models.vlm_causal_lm import VlmCausalLM + from text_generation_server.models.custom_modeling.mllama import ( + MllamaForConditionalGeneration, + ) + from text_generation_server.models.custom_modeling.llava_next import ( + LlavaNextForConditionalGeneration, + ) + + adapt_transformers_to_gaudi() + if SDP_ON_BF16 == 1: + torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) if model_type == "gpt_bigcode": return StarCoder(model_id=model_id, revision=revision, dtype=dtype) - if model_type == "bloom": return BLOOM( model_id=model_id, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/bloom_modeling.py index e2719fad2..84835ab89 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/bloom_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/bloom_modeling.py @@ -377,7 +377,7 @@ class BloomAttention(nn.Module): past_value.view(-1, *past_value.shape[-2:]), ) - if CUSTOM_KERNELS_ENABLED: + if CUSTOM_KERNELS_ENABLED and attention_mask.shape[-1] < 4096: assert self.training is False, "Only foward pass was implemented" assert ( attention_mask.shape[-1] < 4096 @@ -580,7 +580,7 @@ class BloomPreTrainedModel(PreTrainedModel): @staticmethod def _convert_to_bloom_cache( - past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]] + past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]: """ Converts the cache to the format expected by Bloom, i.e. 
to tuple(tuple([batch_size * num_heads, ...])) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 30656038b..3bcc689d2 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -28,10 +28,10 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, + HPUPagedAttentionMetadata, ) -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -39,7 +39,6 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.layernorm import ( FastLayerNorm, ) @@ -47,11 +46,10 @@ from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) from text_generation_server.utils.weights import UnquantizedWeight - -if SYSTEM == "cuda": - import dropout_layer_norm -else: - dropout_layer_norm = None +from habana_frameworks.torch.hpex.kernels import ( + RotaryPosEmbeddingMode, + apply_rotary_pos_emb, +) class CohereRotary(PositionRotaryEmbedding): @@ -63,38 +61,25 @@ class CohereRotary(PositionRotaryEmbedding): sin: torch.Tensor, ): # Such controlflows may add some overhead. - if SYSTEM == "cuda": - import rotary_emb + num_tokens = query.shape[0] + head_size = query.shape[-1] + rope_mode = RotaryPosEmbeddingMode.PAIRWISE + sin = torch.repeat_interleave(sin, 2, dim=-1) + cos = torch.repeat_interleave(cos, 2, dim=-1) + rotary_dim = cos.shape[-1] + query_shape = query.shape + query = query.view(num_tokens, -1, head_size) + query_rot = query[..., :rotary_dim] + query_pass = query[..., rotary_dim:] + query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode) + query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)) - q1 = query[..., ::2] - q2 = query[..., 1::2] - - rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) - - k1 = key[..., ::2] - k2 = key[..., 1::2] - - rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) - elif SYSTEM == "rocm": - from vllm._C import ops - - # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. - # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773 - - head_size = query.shape[-1] - - # Inplace operation, updating query and key. - ops.rotary_embedding(query, key, head_size, cos, sin, False) - elif SYSTEM == "ipex": - import intel_extension_for_pytorch as ipex - - ipex.llm.functional.rotary_embedding( - query, key, sin, cos, query.size(-1), False - ) - else: - raise ValueError( - "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." 
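The RotaryPosEmbeddingMode.PAIRWISE path added above is the interleaved (GPT-J style) convention used by Cohere: element 2i pairs with element 2i+1, which is why cos/sin are expanded with repeat_interleave here rather than the concatenation used in the blockwise case. A minimal plain-torch sketch of that convention (synthetic shapes, illustration only):

    import torch

    def pairwise_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        # x: [num_tokens, head_size]; cos/sin: per-pair values repeated with
        # torch.repeat_interleave(_, 2, dim=-1) as in CohereRotary.forward above.
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        # Interleaved rotation: (x1, x2) -> (x1*cos - x2*sin, x2*cos + x1*sin).
        rotated = torch.stack((-x2, x1), dim=-1).flatten(-2)
        return x * cos + rotated * sin

    num_tokens, head_size = 3, 8
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_size, 2).float() / head_size))
    freqs = torch.outer(torch.arange(num_tokens).float(), inv_freq)
    cos = torch.repeat_interleave(freqs.cos(), 2, dim=-1)
    sin = torch.repeat_interleave(freqs.sin(), 2, dim=-1)
    q = torch.randn(num_tokens, head_size)
    print(pairwise_rope(q, cos, sin).shape)  # torch.Size([3, 8])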
- ) + key_shape = key.shape + key = key.view(num_tokens, -1, head_size) + key_rot = key[..., :rotary_dim] + key_pass = key[..., rotary_dim:] + key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode) + key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)) class CohereLayerNorm(nn.Module): @@ -107,49 +92,18 @@ class CohereLayerNorm(nn.Module): self.eps = eps def forward(self, hidden_states): - if hidden_states.shape[-1] > 8192 or SYSTEM != "cuda": - hidden_states = hidden_states.reshape( - -1, self.weight.shape[0], self.weight.shape[1] - ) - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - mean = hidden_states.mean(-1, keepdim=True) - hidden_states_minus_mean = hidden_states - mean - variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps) - hidden_states = self.weight.to(torch.float32) * hidden_states - hidden_states = hidden_states.view(-1, self.weight.shape[1]) - return hidden_states.to(input_dtype) - - ( - hidden_states, - *rest, - ) = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - None, - self.ones, - None, - None, - None, - None, - None, - 0.0, - self.eps, - 1.0, - 0, - None, - False, - False, - ) - - # Required to apply one weight matrix per head - hidden_states = hidden_states.view( + hidden_states = hidden_states.reshape( -1, self.weight.shape[0], self.weight.shape[1] ) - hidden_states = self.weight * hidden_states + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + mean = hidden_states.mean(-1, keepdim=True) + hidden_states_minus_mean = hidden_states - mean + variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps) + hidden_states = self.weight.to(torch.float32) * hidden_states hidden_states = hidden_states.view(-1, self.weight.shape[1]) - - return hidden_states + return hidden_states.to(input_dtype) def load_attention(config, prefix, weights): @@ -229,6 +183,7 @@ class FlashCohereAttention(torch.nn.Module): ) self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.use_qk_norm = config.use_qk_norm if self.use_qk_norm: @@ -264,10 +219,9 @@ class FlashCohereAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) query, key, value = qkv.split( @@ -291,30 +245,35 @@ class FlashCohereAttention(torch.nn.Module): self.rotary_emb(query, key, cos, sin) - reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=key, + value=value, + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else key, - kv_cache[1] if PREFILL_IN_KV_CACHE else value, - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=key, + value=value, + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj( @@ -386,10 +345,9 @@ class FlashCohereLayer(nn.Module): sin, cu_seqlen_prefill, 
kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -400,10 +358,9 @@ class FlashCohereLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) mlp_output = self.mlp(normed_hidden_states) @@ -452,18 +409,15 @@ class FlashCohereModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: torch.Tensor, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None @@ -475,10 +429,9 @@ class FlashCohereModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -516,11 +469,9 @@ class FlashCohereForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -529,10 +480,9 @@ class FlashCohereForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 1137a453f..15c243c97 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -20,17 +20,14 @@ from torch import nn from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple, Any -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.layers.attention.kv_cache import get_kv_scales -if SYSTEM != "ipex": - from vllm.model_executor.layers.fused_moe import fused_moe from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, - PREFILL_IN_KV_CACHE, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( FastLinear, @@ -46,6 +43,7 @@ from text_generation_server.layers.rotary import ( from text_generation_server.layers.layernorm import ( FastLayerNorm, ) +from vllm_hpu_extension.ops import DynamicFusedMOE class DbrxAttentionConfig(PretrainedConfig): @@ -290,6 +288,7 @@ class DbrxAttention(torch.nn.Module): ) self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = TensorParallelRowLinear.load( config, @@ -309,10 +308,9 
@@ class DbrxAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) if self.clip_qkv is not None: @@ -330,30 +328,35 @@ class DbrxAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -387,10 +390,9 @@ class DbrxNormAttentionNorm(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): normed_hidden_states, res = self.norm_1(hidden_states, residual) @@ -401,10 +403,9 @@ class DbrxNormAttentionNorm(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) # faster post attention rms norm @@ -482,18 +483,15 @@ class BlockSparseMoE(nn.Module): self.process_group = weights.process_group + self.hpu_fused_moe = DynamicFusedMOE(self.num_experts) + for i in range(self.num_experts): + self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.wv1[i]) + self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.w2[i]) + def forward(self, x: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits = self.gate(x) - out = fused_moe( - x, - self.wv1, - self.w2, - router_logits, - self.top_k, - renormalize=self.moe_normalize_expert_weights, - inplace=True, - ) + out = self.hpu_fused_moe(x, router_logits, self.top_k) # Reduce sum if self.process_group.size() > 1: @@ -620,10 +618,9 @@ class DbrxLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): # Self Attention attn_output, attn_res = self.attn( @@ -633,10 +630,9 @@ class DbrxLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) moe_output = self.moe(attn_output) @@ -677,18 +673,15 @@ class DbrxModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -699,10 +692,9 @@ class DbrxModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, 
- max_s, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -732,11 +724,9 @@ class FlashDbrxForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -745,10 +735,9 @@ class FlashDbrxForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 88c2cf803..9d61c6941 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -33,21 +33,14 @@ from text_generation_server.layers.attention import ( Seqlen, attention, paged_attention, - reshape_and_cache, + HPUPagedAttentionMetadata, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.weights import Weights -if SYSTEM == "rocm": - try: - from vllm import _custom_C - except Exception as e: - raise ImportError(f"Could not load `vllm._custom_C`. 
Full error: {e}") - class DeepseekV2Config(PretrainedConfig): def __init__( @@ -232,6 +225,8 @@ class DeepseekV2Attention(torch.nn.Module): ), ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") + self.kv_a_layernorm = FastRMSNorm.load( prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps ) @@ -260,11 +255,10 @@ class DeepseekV2Attention(torch.nn.Module): cos: torch.Tensor, sin: torch.Tensor, cu_seqlen_prefill: torch.Tensor, - kv_cache: Tuple[torch.Tensor, torch.Tensor], - block_tables: torch.Tensor, + kv_cache: KVCache, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): if self.q_lora_rank is None: query = self.q_proj(hidden_states) @@ -321,30 +315,35 @@ class DeepseekV2Attention(torch.nn.Module): value, (0, self.head_pad_size - self.value_head_size), value=0 ) - reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=key, + value=value, + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else key, - kv_cache[1] if PREFILL_IN_KV_CACHE else value, - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=key, + value=value, + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) # Remove padding. @@ -387,27 +386,11 @@ class DeepseekV2MLP(nn.Module): self.quantize = config.quantize def forward(self, hidden_states: torch.Tensor, reduce: bool = True): - if ( - SYSTEM == "rocm" - and self.hidden_act == "silu" - and hidden_states.dtype == torch.float16 - and hidden_states.shape[0] == 1 - and not self.quantize - ): - out = torch.empty( - hidden_states.shape[0], - self.intermediate_size, - dtype=hidden_states.dtype, - device="cuda", - ) - _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8) - return self.down_proj(out, reduce=reduce) - else: - gate_up_states = self.gate_up_proj(hidden_states) - gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) - return self.down_proj( - self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce - ) + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce + ) class DeepseekV2MoE(nn.Module): @@ -520,10 +503,9 @@ class DeepseekV2Layer(nn.Module): sin: torch.Tensor, cu_seqlen_prefill: torch.Tensor, kv_cache, - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): normed_hidden_states, residual = self.input_layernorm(hidden_states, residual) @@ -534,10 +516,9 @@ class DeepseekV2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) # faster post attention rms norm @@ -583,18 +564,15 @@ class DeepseekV2Model(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: 
Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -605,10 +583,9 @@ class DeepseekV2Model(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -635,11 +612,9 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -648,10 +623,9 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py new file mode 100644 index 000000000..1a7ce5cf5 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py @@ -0,0 +1,642 @@ +# coding=utf-8 +# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
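
Every model ported in this patch funnels attention through the same three calls: `KVCache.store` persists the step's keys/values into the paged cache, `attention` handles prefill, and `paged_attention` handles decode, driven by `HPUPagedAttentionMetadata`. The new DeepseekV3 module below follows the same flow. A condensed sketch of the pattern, assuming the signatures introduced by this patch (`hpu_attention_flow` is an illustrative name, not part of the patch; tensor shapes in the comments are indicative only):

```python
from typing import Optional

import torch

from text_generation_server.layers.attention import attention, paged_attention


def hpu_attention_flow(
    query: torch.Tensor,   # (num_tokens, num_heads, head_size)
    key: torch.Tensor,     # (num_tokens, num_kv_heads, head_size)
    value: torch.Tensor,   # (num_tokens, num_kv_heads, head_size)
    kv_cache,              # KVCache from layers.attention.kv_cache
    kv_scales,             # result of get_kv_scales(weights, prefix)
    slots: torch.Tensor,
    seqlen,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    cu_seqlen_prefill: Optional[torch.Tensor],
    hpu_attention_meta,    # Optional[HPUPagedAttentionMetadata]
) -> torch.Tensor:
    # Always write this step's K/V into the paged cache first.
    kv_cache.store(key=key, value=value, slots=slots, kv_scales=kv_scales)

    if cu_seqlen_prefill is not None:
        # Prefill: attend over the freshly computed K/V (sdpa path).
        return attention(
            query=query,
            key=key,
            value=value,
            kv_cache=kv_cache,
            kv_scales=kv_scales,
            seqlen=seqlen,
            softmax_scale=softmax_scale,
        )
    # Decode: paged attention over the cache, steered by the HPU metadata.
    return paged_attention(
        query,
        kv_cache,
        kv_head_mapping,
        softmax_scale,
        seqlen,
        kv_scales=kv_scales,
        hpu_attention_meta=hpu_attention_meta,
    )
```

Folding `block_tables` and `max_s` into metadata that is precomputed once per batch keeps the per-layer call signatures static, which is presumably friendlier to HPU graph capture and replay than passing dynamically shaped tensors through every forward.
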
+ +from typing import List, Optional, Tuple, Type + +import torch +import torch.distributed +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig + +from text_generation_server.layers import ( + FastLinear, + SpeculativeHead, + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + get_linear, +) +from text_generation_server.layers.attention import ( + Seqlen, + attention, + paged_attention, + HPUPagedAttentionMetadata, +) +from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales +from text_generation_server.layers.layernorm import FastRMSNorm +from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer +from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale +from text_generation_server.utils.weights import Weights + + +class DeepseekV3Config(PretrainedConfig): + def __init__( + self, + vocab_size=102400, + hidden_size=4096, + intermediate_size=11008, + moe_intermediate_size=1407, + num_hidden_layers=30, + num_attention_heads=32, + num_key_value_heads=32, + n_shared_experts=2, + n_routed_experts=160, + ep_size=1, + routed_scaling_factor=1.0, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + topk_method="gready", + n_group=8, + topk_group=3, + num_experts_per_tok=6, + moe_layer_freq=1, + first_k_dense_replace=0, + norm_topk_prob=False, + scoring_func="softmax", + aux_loss_alpha=0.001, + seq_aux=True, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=100000, + eos_token_id=100001, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + self.aux_loss_alpha = aux_loss_alpha + self.seq_aux = seq_aux + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + tie_word_embeddings = kwargs.pop("tie_word_embeddings", False) + if tie_word_embeddings: + raise ValueError( + 
"tie_word_embeddings is not supported for Deepseek V2 models." + ) + + if ep_size != 1: + raise ValueError( + f"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}" + ) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class DeepseekV3Attention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights: Weights, + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.kv_lora_rank = config.kv_lora_rank + self.q_lora_rank = config.q_lora_rank + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_rope_head_dim = config.qk_rope_head_dim + self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim + self.value_head_size = config.v_head_dim + self.head_pad_size = max(self.head_size, self.value_head_size) + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.qk_rope_head_dim, + base=config.rope_theta, + device=weights.device, + ) + + mscale = get_mscale( + self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim + ) + self.softmax_scale = self.head_size**-0.5 * mscale * mscale + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + if self.q_lora_rank is None: + self.q_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.q_proj", + weights=weights, + bias=config.attention_bias, + ) + else: + self.q_a_proj = get_linear( + weight=weights.get_weights(f"{prefix}.q_a_proj"), + bias=( + weights.get_tensor(f"{prefix}.q_a_proj.bias") + if config.attention_bias + else None + ), + ) + self.q_a_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.q_a_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.q_b_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.q_b_proj", + weights=weights, + bias=config.attention_bias, + ) + + self.kv_a_proj_with_mqa = get_linear( + weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"), + bias=( + weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias") + if config.attention_bias + else None + ), + ) + + self.kv_scales = get_kv_scales(weights, f"{prefix}") + + self.kv_a_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps + ) + + self.kv_b_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.kv_b_proj", + weights=weights, + bias=config.attention_bias, + ) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + cu_seqlen_prefill: torch.Tensor, + kv_cache: KVCache, + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + ): + if self.q_lora_rank is None: + query = self.q_proj(hidden_states) + else: + query = 
self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0]) + query = query.view(-1, self.num_heads, self.head_size) + + _, query_pe = torch.split( + query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, key_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + + key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim) + kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view( + -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size + ) + + key_nope, value = torch.split( + kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1 + ) + + batch_size, heads, head_dim = query_pe.shape + query_pe = ( + query_pe.view(batch_size, heads, head_dim // 2, 2) + .transpose(2, 3) + .reshape(batch_size, heads, head_dim) + ) + batch_size, heads, head_dim = key_pe.shape + key_pe = ( + key_pe.view(batch_size, heads, head_dim // 2, 2) + .transpose(2, 3) + .reshape(batch_size, heads, head_dim) + ) + self.rotary_emb(query_pe, key_pe, cos, sin) + + query[..., self.qk_nope_head_dim :] = query_pe + key = torch.empty_like(query) + key[..., : self.qk_nope_head_dim] = key_nope + key[..., self.qk_nope_head_dim :] = key_pe + + # We need to pad the heads because Flash Attention does not support + # qk and v with different head sizes. + query = torch.nn.functional.pad( + query, (0, self.head_pad_size - self.head_size), value=0 + ) + key = torch.nn.functional.pad( + key, (0, self.head_pad_size - self.head_size), value=0 + ) + value = torch.nn.functional.pad( + value, (0, self.head_pad_size - self.value_head_size), value=0 + ) + + kv_cache.store( + key=key, + value=value, + slots=slots, + kv_scales=self.kv_scales, + ) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attn_output = attention( + query=query, + key=key, + value=value, + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, + ) + # Decode + else: + attn_output = paged_attention( + query, + kv_cache, + self.kv_head_mapping, + self.softmax_scale, + seqlen, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, + ) + + # Remove padding. + attn_output = attn_output[..., : self.value_head_size] + + return self.o_proj( + attn_output.reshape(-1, self.num_heads * self.value_head_size) + ) + + +class DeepseekV3MLP(nn.Module): + def __init__(self, prefix: str, config, weights, intermediate_size: int): + super().__init__() + self.hidden_act = config.hidden_act + if self.hidden_act != "silu": + # Bail out because MoE only supports silu. + raise NotImplementedError( + "Currently only `silu` is supported as an activation for Deepseek V2." + ) + self.act = ACT2FN[self.hidden_act] + + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + + self.intermediate_size = intermediate_size // weights.process_group.size() + + # TODO: This is a hotfix to be removed & properly refactored. 
+ self.quantize = config.quantize + + def forward(self, hidden_states: torch.Tensor, reduce: bool = True): + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce + ) + + +class DeepseekV3MoE(nn.Module): + def __init__( + self, + prefix, + config: DeepseekV3Config, + moe_layer_cls: Type[MoELayer], + weights, + ): + super().__init__() + + self.hidden_dim = config.hidden_size + self.moe_intermediate_size = ( + config.moe_intermediate_size // weights.process_group.size() + ) + self.routed_scaling_factor = config.routed_scaling_factor + + # Gating + self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False) + + if config.topk_method == "noaux_tc": + self.gate.e_score_correction_bias = torch.zeros( + config.n_routed_experts, device=weights.device + ) + else: + self.gate.e_score_correction_bias = None + + self.moe_layer = moe_layer_cls( + prefix=f"{prefix}.experts", + n_experts=config.n_routed_experts, + n_expert_group=config.n_group, + renormalize=config.norm_topk_prob, + topk=config.num_experts_per_tok, + topk_group=config.topk_group, + weights=weights, + scoring_func=config.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias, + ) + assert isinstance(self.moe_layer, MoELayer) + + if config.n_shared_experts is not None: + self.shared_experts = DeepseekV3MLP( + prefix=f"{prefix}.shared_experts", + config=config, + weights=weights, + intermediate_size=config.moe_intermediate_size + * config.n_shared_experts, + ) + else: + self.shared_experts = None + + self.process_group = weights.process_group + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.shared_experts is not None: + shared_output = self.shared_experts(x, reduce=False) + else: + shared_output = None + + router_logits = self.gate(x) + + out = self.moe_layer(x, gating_output=router_logits) + + if shared_output is not None: + out = out + shared_output + + # Reduce sum + if self.process_group.size() > 1: + torch.distributed.all_reduce(out, group=self.process_group) + + return out.view(*x.shape) + + +class DeepseekV3Layer(nn.Module): + def __init__(self, prefix, layer_id, config, weights): + super().__init__() + prefix = f"{prefix}.layers.{layer_id}" + + self.self_attn = DeepseekV3Attention( + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + ) + + if ( + config.n_routed_experts is not None + and layer_id >= config.first_k_dense_replace + and layer_id % config.moe_layer_freq == 0 + ): + moe_layer_cls = ( + SparseMoELayer + if SparseMoELayer.is_supported(weights) + else DenseMoELayer + ) + self.mlp = DeepseekV3MoE(f"{prefix}.mlp", config, moe_layer_cls, weights) + else: + self.mlp = DeepseekV3MLP( + prefix=f"{prefix}.mlp", + config=config, + weights=weights, + intermediate_size=config.intermediate_size, + ) + + self.input_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + cu_seqlen_prefill: torch.Tensor, + kv_cache, + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + ): + normed_hidden_states, residual = 
self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + hpu_attention_meta, + ) + + # faster post attention rms norm + normed_attn_res_output, residual = self.post_attention_layernorm( + attn_output, residual + ) + + output = self.mlp(normed_attn_res_output) + + return output, residual + + +class DeepseekV3Model(torch.nn.Module): + def __init__(self, prefix: str, config, weights: Weights): + super().__init__() + + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens", weights=weights + ) + + self.layers = nn.ModuleList( + [ + DeepseekV3Layer( + prefix, + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = FastRMSNorm.load( + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps + ) + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + slots, + seqlen, + hpu_attention_meta, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashDeepseekV3ForCausalLM(torch.nn.Module): + def __init__(self, prefix: str, config, weights: Weights): + super().__init__() + + self.model = DeepseekV3Model( + "model" if not prefix else f"{prefix}.model", config, weights + ) + self.lm_head = SpeculativeHead.load( + config, + prefix="lm_head" if not prefix else f"{prefix}.lm_head", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + hpu_attention_meta, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 7a3d60c97..79f21b0f3 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -28,8 +28,8 @@ from typing import Optional, List, Tuple from 
text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -40,7 +40,7 @@ from text_generation_server.layers import ( TensorParallelMultiAdapterLinear, TensorParallelAdapterRowLinear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, @@ -208,6 +208,7 @@ class FlashGemma2Attention(torch.nn.Module): ], process_group=weights.process_group, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") o_proj = TensorParallelRowLinear.load( config, @@ -234,11 +235,10 @@ class FlashGemma2Attention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, adapter_data, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states, adapter_data) query, kv = qkv.split( @@ -253,19 +253,24 @@ class FlashGemma2Attention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], - seqlen, - block_tables, - self.softmax_scale, - causal=self.causal, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, window_size_left=self.window_size, softcap=self.softcap, ) @@ -273,14 +278,13 @@ class FlashGemma2Attention(torch.nn.Module): else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, softcap=self.softcap, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj( @@ -390,11 +394,10 @@ class FlashGemma2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, adapter_data, + hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -405,11 +408,10 @@ class FlashGemma2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, adapter_data, + hpu_attention_meta, ) # faster post attention rms norm @@ -458,19 +460,16 @@ class FlashGemma2Model(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - adapter_data: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = inputs_embeds # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -481,11 +480,10 @@ class FlashGemma2Model(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - 
block_tables, slots, seqlen, - max_s, adapter_data, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -529,11 +527,9 @@ class FlashGemma2ForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -543,11 +539,10 @@ class FlashGemma2ForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, adapter_data, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 4c1be6f60..609f03acc 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -28,9 +28,8 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, - PREFILL_IN_KV_CACHE, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -39,6 +38,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, @@ -187,6 +187,7 @@ class FlashGemmaAttention(torch.nn.Module): ) self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = TensorParallelRowLinear.load( config, @@ -206,10 +207,9 @@ class FlashGemmaAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) query, kv = qkv.split( @@ -224,31 +224,36 @@ class FlashGemmaAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, causal=self.causal, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -317,10 +322,9 @@ class FlashGemmaLayer(nn.Module): sin, cu_seqlen_prefill, 
kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -331,10 +335,9 @@ class FlashGemmaLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) # faster post attention rms norm @@ -379,18 +382,16 @@ class FlashGemmaModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + adapter_data: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = inputs_embeds # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -401,10 +402,9 @@ class FlashGemmaModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -446,11 +446,9 @@ class FlashGemmaForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -460,10 +458,10 @@ class FlashGemmaForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + adapter_data, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 44c015cf4..10024a6de 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -24,12 +24,11 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -38,6 +37,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales def load_qkv(config, prefix: str, weights, head_size, num_heads): @@ -47,10 +47,6 @@ def load_qkv(config, prefix: str, weights, head_size, num_heads): prefix, weights, ) - elif config.quantize == "marlin": - raise RuntimeError( - "GPT-2 models with marlin quantization are not yet supported" - ) else: return _load_qkv(config, prefix, weights, head_size, num_heads) @@ -195,6 +191,7 @@ class FlashGPT2Attention(torch.nn.Module): head_size=self.head_size, 
num_heads=self.num_heads, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = load_row( config, @@ -212,10 +209,9 @@ class FlashGPT2Attention(torch.nn.Module): hidden_states, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): query, key, value = self.query_key_value(hidden_states).split( self.head_size * self.num_heads, dim=1 @@ -224,30 +220,35 @@ class FlashGPT2Attention(torch.nn.Module): key = key.view(-1, self.num_heads, self.head_size) value = value.view(-1, self.num_heads, self.head_size) - reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=key, + value=value, + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else key, - kv_cache[1] if PREFILL_IN_KV_CACHE else value, - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=key, + value=value, + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -313,10 +314,9 @@ class FlashGPT2Layer(nn.Module): residual, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -326,10 +326,9 @@ class FlashGPT2Layer(nn.Module): hidden_states, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) hidden_states = attn_output + residual @@ -379,12 +378,9 @@ class FlashGPT2Model(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - true_max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = inputs_embeds @@ -395,10 +391,9 @@ class FlashGPT2Model(torch.nn.Module): residual, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) hidden_states = self.norm(hidden_states) @@ -432,11 +427,9 @@ class FlashGPT2ForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor] = None, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -448,12 +441,9 @@ class FlashGPT2ForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - true_max_s=max_s, - prefill_cache_indices=prefill_cache_indices, + hpu_attention_meta=hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py 
b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index aca970044..41eeab78c 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -24,12 +24,12 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -38,13 +38,16 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) from text_generation_server.layers.layernorm import ( FastLayerNorm, ) +from habana_frameworks.torch.hpex.kernels import ( + RotaryPosEmbeddingMode, + apply_rotary_pos_emb, +) def load_attention(config, prefix: str, weights): @@ -78,39 +81,25 @@ class GPTJRotary(PositionRotaryEmbedding): cos: torch.Tensor, sin: torch.Tensor, ): - # Such controlflows may add some overhead. - if SYSTEM == "cuda": - import rotary_emb + num_tokens = query.shape[0] + head_size = query.shape[-1] + rope_mode = RotaryPosEmbeddingMode.PAIRWISE + sin = torch.repeat_interleave(sin, 2, dim=-1) + cos = torch.repeat_interleave(cos, 2, dim=-1) + rotary_dim = cos.shape[-1] + query_shape = query.shape + query = query.view(num_tokens, -1, head_size) + query_rot = query[..., :rotary_dim] + query_pass = query[..., rotary_dim:] + query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode) + query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)) - q1 = query[..., ::2] - q2 = query[..., 1::2] - - rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) - - k1 = key[..., ::2] - k2 = key[..., 1::2] - - rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) - elif SYSTEM == "rocm": - from vllm._C import ops - - # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. - # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773 - - head_size = query.shape[-1] - - # Inplace operation, updating query and key. - ops.rotary_embedding(query, key, head_size, cos, sin, False) - elif SYSTEM == "ipex": - import intel_extension_for_pytorch as ipex - - ipex.llm.functional.rotary_embedding( - query, key, sin, cos, query.size(-1), False - ) - else: - raise ValueError( - "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." 
- ) + key_shape = key.shape + key = key.view(num_tokens, -1, head_size) + key_rot = key[..., :rotary_dim] + key_pass = key[..., rotary_dim:] + key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode) + key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)) class FlashGPTJAttention(torch.nn.Module): @@ -140,6 +129,7 @@ class FlashGPTJAttention(torch.nn.Module): prefix=prefix, weights=weights, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = load_row( config, @@ -166,10 +156,9 @@ class FlashGPTJAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): query, key, value = self.query_key_value(hidden_states).split( self.head_size * self.num_heads, dim=1 @@ -186,30 +175,35 @@ class FlashGPTJAttention(torch.nn.Module): else: self.rotary_emb(query, key, cos, sin) - reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=key, + value=value, + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else key, - kv_cache[1] if PREFILL_IN_KV_CACHE else value, - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=key, + value=value, + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -267,10 +261,9 @@ class FlashGPTJLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): hidden_states, residual = self.input_layernorm(hidden_states, residual) # Self Attention @@ -280,10 +273,9 @@ class FlashGPTJLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) feed_forward_hidden_states = self.mlp(hidden_states) @@ -327,19 +319,15 @@ class FlashGPTJModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.wte(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -350,10 +338,9 @@ class FlashGPTJModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) hidden_states, _ = self.ln_f(hidden_states, residual) @@ -381,11 +368,9 @@ class FlashGPTJForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor] = None, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: 
Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -394,11 +379,9 @@ class FlashGPTJForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - prefill_cache_indices=prefill_cache_indices, + hpu_attention_meta=hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index c9ec70cca..81af55603 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -27,14 +27,16 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention import ( + KVCache, + get_kv_scales, +) from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -57,15 +59,6 @@ from text_generation_server.utils.weights import ( ) from text_generation_server.layers.fp8 import HybridFP8UnquantLoader -if SYSTEM != "ipex": - pass - -if SYSTEM == "rocm": - try: - from vllm import _custom_C - except Exception as e: - raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") - def load_attention(config, prefix: str, weights, layer_id): # Only defined in granite. 
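The hunks above establish the attention-path rewrite that every model file in this patch repeats: `reshape_and_cache`, `block_tables`, `max_s`, and `prefill_cache_indices` give way to a `KVCache` object, per-layer `kv_scales`, and a precomputed `HPUPagedAttentionMetadata`. Below is a minimal sketch of the resulting call shape, using only names and signatures visible in these hunks; the wrapper function itself is illustrative and not part of the patch:

    from text_generation_server.layers.attention import attention, paged_attention

    def run_attention(layer, query, key, value, kv_cache, slots,
                      cu_seqlen_prefill, seqlen, hpu_attention_meta):
        # New key/value pairs are always written to the paged cache first;
        # layer.kv_scales carries the per-layer scales used by FP8 KV caches.
        kv_cache.store(key=key, value=value, slots=slots, kv_scales=layer.kv_scales)
        if cu_seqlen_prefill is not None:
            # Prefill: SDPA over the freshly computed key/value tensors.
            return attention(
                query=query,
                key=key,
                value=value,
                kv_cache=kv_cache,
                kv_scales=layer.kv_scales,
                seqlen=seqlen,
                softmax_scale=layer.softmax_scale,
            )
        # Decode: paged attention reads key/value back from the cache; the
        # HPU metadata replaces the block tables and max_s that the old
        # CUDA/ROCm kernels consumed.
        return paged_attention(
            query,
            kv_cache,
            layer.kv_head_mapping,
            layer.softmax_scale,
            seqlen,
            kv_scales=layer.kv_scales,
            hpu_attention_meta=hpu_attention_meta,
        )

Sliding-window models (Mistral, Mixtral) additionally pass `window_size_left=self.max_past` to `attention` in the prefill branch, as their hunks further below show.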
@@ -157,7 +150,10 @@ class FlashLlamaAttention(torch.nn.Module): device=weights.device, ) - self.softmax_scale = self.head_size**-0.5 + # `config.attention_multiplier` is used in Granite + self.softmax_scale = getattr( + config, "attention_multiplier", self.head_size**-0.5 + ) if self.num_heads % weights.process_group.size() != 0: raise ValueError( @@ -177,11 +173,13 @@ class FlashLlamaAttention(torch.nn.Module): self.query_key_value = load_attention(config, prefix, weights, index) self.index = index + self.kv_scales = get_kv_scales(weights, f"{prefix}") + o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o_proj", weights=weights, - bias=False, + bias=getattr(config, "attention_bias", False), ) self.o_proj = TensorParallelAdapterRowLinear.load( @@ -202,12 +200,11 @@ class FlashLlamaAttention(torch.nn.Module): cos, sin, cu_seqlen_prefill, - kv_cache, - block_tables, + kv_cache: KVCache, slots, seqlen, - max_s, adapter_data, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): qkv = self.query_key_value(hidden_states, adapter_data) query, kv = qkv.split( @@ -222,30 +219,35 @@ class FlashLlamaAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_scales=self.kv_scales, + kv_cache=kv_cache, + seqlen=seqlen, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj( @@ -363,31 +365,11 @@ class LlamaMLP(nn.Module): self.hidden_size = config.hidden_size def forward(self, hidden_states, adapter_data): - if ( - SYSTEM == "rocm" - and self.hidden_act == "silu" - and hidden_states.dtype == torch.float16 - and hidden_states.shape[0] == 1 - and not self.quantize - and self.hidden_size - != 16384 # TODO: Temporary workaround for `LLMM_Silu` kernel not working with LLama3.1 405B; needs refactoring once fixed. 
- ): - out = torch.empty( - hidden_states.shape[0], - self.intermediate_size, - dtype=hidden_states.dtype, - device="cuda", - ) - _custom_C.LLMM_Silu( - self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8 - ) - return self.down_proj(out, adapter_data) - else: - gate_up_states = self.gate_up_proj(hidden_states, adapter_data) - gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) - return self.down_proj( - self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data - ) + gate_up_states = self.gate_up_proj(hidden_states, adapter_data) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data + ) class FlashLlamaLayer(nn.Module): @@ -408,7 +390,7 @@ class FlashLlamaLayer(nn.Module): if SparseMoELayer.is_supported(weights) else DenseMoELayer ) - self.dense = Phi3MoE( + self.mlp = Phi3MoE( f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights ) # with moe the layernorms are are not rmsnorms and they have bias @@ -423,7 +405,7 @@ class FlashLlamaLayer(nn.Module): eps=config.rms_norm_eps, ) else: - self.dense = LlamaMLP( + self.mlp = LlamaMLP( prefix=f"{prefix}.mlp", config=config, weights=weights, index=index ) self.input_layernorm = FastRMSNorm.load( @@ -437,6 +419,11 @@ class FlashLlamaLayer(nn.Module): eps=config.rms_norm_eps, ) + # Used in Granite + # This could eventually be baked into the weights like we do for the embeddings/lm_head + # but this would mean modifying the lora code + self.residual_multiplier = getattr(config, "residual_multiplier", None) + def forward( self, hidden_states, @@ -445,12 +432,11 @@ class FlashLlamaLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, adapter_data, cross_attention_states, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -461,19 +447,21 @@ class FlashLlamaLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, adapter_data, + hpu_attention_meta=hpu_attention_meta, ) + if self.residual_multiplier is not None: + attn_output *= self.residual_multiplier - # faster post attention rms norm normed_attn_res_output, attn_res = self.post_attention_layernorm( attn_output, res ) - mlp_output = self.dense(normed_attn_res_output, adapter_data) + mlp_output = self.mlp(normed_attn_res_output, adapter_data) + if self.residual_multiplier is not None: + mlp_output *= self.residual_multiplier return mlp_output, attn_res @@ -493,9 +481,7 @@ class FlashLlamaModel(torch.nn.Module): self.layers.append( FlashLlamaLayer( index=0, - prefix=( - "model.layers.0" if not prefix else f"{prefix}.model.layers.0" - ), + prefix=f"{prefix}.layers.0", config=config, weights=weights, ) @@ -504,18 +490,14 @@ class FlashLlamaModel(torch.nn.Module): # Skip first and last layers for layer_id in range(1, config.num_hidden_layers - 1): if layer_id in self.cross_attention_layers: - from text_generation_server.models.custom_modeling.mllama import ( + from text_generation_server.models.custom_modeling.flash_mllama import ( FlashLlamaCrossLayer, ) self.layers.append( FlashLlamaCrossLayer( index=layer_id, - prefix=( - f"model.layers.{layer_id}" - if not prefix - else f"{prefix}.model.layers.{layer_id}" - ), + prefix=(f"{prefix}.layers.{layer_id}"), config=config, weights=weights, ) @@ -524,11 +506,7 @@ class FlashLlamaModel(torch.nn.Module): self.layers.append( FlashLlamaLayer( 
index=layer_id, - prefix=( - f"model.layers.{layer_id}" - if not prefix - else f"{prefix}.model.layers.{layer_id}" - ), + prefix=(f"{prefix}.layers.{layer_id}"), config=config, weights=weights, ) @@ -539,18 +517,14 @@ class FlashLlamaModel(torch.nn.Module): self.layers.append( FlashLlamaLayer( index=last_layer_id, - prefix=( - f"model.layers.{last_layer_id}" - if not prefix - else f"{prefix}.model.layers.{last_layer_id}" - ), + prefix=(f"{prefix}.layers.{last_layer_id}"), config=config, weights=weights, ) ) self.norm = FastRMSNorm.load( - prefix="model.norm" if not prefix else f"{prefix}.model.norm", + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps, ) @@ -567,22 +541,17 @@ class FlashLlamaModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - true_max_s: int, - prefill_cache_indices: Optional[torch.Tensor], adapter_data, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], cross_attention_states=None, ) -> torch.Tensor: hidden_states = inputs_embeds # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -593,12 +562,11 @@ class FlashLlamaModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, adapter_data, cross_attention_states, + hpu_attention_meta=hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -607,42 +575,60 @@ class FlashLlamaModel(torch.nn.Module): class FlashLlamaForCausalLM(torch.nn.Module): - def __init__(self, prefix: str, config, weights): + def __init__(self, prefix: str, config, weights, name=None): + if name is None: + name = "model" super().__init__() - with no_fp8(weights): self.embed_tokens = TensorParallelEmbedding( prefix=( - "model.embed_tokens" + f"{name}.embed_tokens" if not prefix - else f"{prefix}.model.embed_tokens" + else f"{prefix}.{name}.embed_tokens" ), weights=weights, ) - self.model = FlashLlamaModel(prefix, config, weights) + self.model = FlashLlamaModel( + prefix=name if not prefix else f"{prefix}.{name}", + config=config, + weights=weights, + ) if config.tie_word_embeddings: suffix = "model.embed_tokens" else: suffix = "lm_head" + # Used in Granite + embedding_multiplier = getattr(config, "embedding_multiplier", None) + if embedding_multiplier is not None: + self.embed_tokens.weight.data *= embedding_multiplier + prefix = suffix if not prefix or name != "model" else f"{prefix}.{suffix}" with no_fp8(weights): self.lm_head = SpeculativeHead.load( config, - prefix=suffix if not prefix else f"{prefix}.{suffix}", - weights=weights, + prefix, + weights, ) + # Used in Granite + self.logits_scaling = getattr(config, "logits_scaling", None) + if self.logits_scaling is not None and self.lm_head.head is not None: + try: + # Scale the weights directly + self.lm_head.head.linear.weight.data /= self.logits_scaling + self.logits_scaled = True + except Exception: + self.logits_scaled = False + def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - 
prefill_cache_indices: Optional[torch.Tensor] = None, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, cross_attention_states=None, @@ -653,16 +639,20 @@ class FlashLlamaForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - true_max_s=max_s, - prefill_cache_indices=prefill_cache_indices, adapter_data=adapter_data, cross_attention_states=cross_attention_states, + hpu_attention_meta=hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits, speculative_logits = self.lm_head(hidden_states) + + # Used in Granite + if self.logits_scaling is not None and not self.logits_scaled: + logits /= self.logits_scaling + if speculative_logits is not None: + speculative_logits /= self.logits_scaling + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py new file mode 100644 index 000000000..88548042d --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py @@ -0,0 +1,285 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Llava-NeXT model.""" + +from typing import List, Optional, Tuple + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +from transformers.image_processing_utils import select_best_resolution + +from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata +from text_generation_server.models.custom_modeling.vlm import ( + load_text_model, + load_vision_model, +) +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelRowLinear, +) + + +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + """ + Calculate the shape of the image patch grid after the preprocessing for images of any resolution. + + Args: + image_size (`tuple`): + The size of the input image in the format (height, width). + grid_pinpoints (`List`): + A list containing possible resolutions. Each item in the list should be a tuple or list + of the form `(height, width)`. + patch_size (`int`): + The size of each image patch. + + Returns: + tuple: The shape of the image patch grid in the format (height, width). + """ + if not isinstance(grid_pinpoints, list): + raise ValueError("grid_pinpoints should be a list of tuples or lists") + + height, width = select_best_resolution(image_size, grid_pinpoints) + return height // patch_size, width // patch_size + + +def unpad_image(tensor, original_size): + """ + Unpads a PyTorch tensor of a padded and resized image. + + Args: + tensor (`torch.Tensor`): + The image tensor, assumed to be of shape (num_channels, height, width). 
+ original_size (`tuple`): + The original size of the image (height, width). + + Returns: + `torch.Tensor`: The unpadded image tensor. + """ + original_height, original_width = original_size + current_height, current_width = tensor.shape[1:] + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding : current_height - padding, :] + else: + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding : current_width - padding] + + return unpadded_tensor + + +# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext +class LlavaNextMultiModalProjector(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + + self.linear_1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True + ) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class FlashLlavaNextForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + config.vision_config.quantize = config.quantize + vision_config = config.vision_config + # Instead of selecting in hidden_states[-2]. + # Instead compute only the n -2 + 1 layers and don't pool + if config.vision_feature_layer < 0: + vision_config.num_hidden_layers += config.vision_feature_layer + 1 + else: + vision_config.num_hidden_layers = config.vision_feature_layer + 1 + self.vision_tower = load_vision_model( + prefix="vision_tower" if not prefix else f"{prefix}.vision_tower", + config=config.vision_config, + weights=weights, + ) + + self.multi_modal_projector = LlavaNextMultiModalProjector( + prefix="multi_modal_projector", config=config, weights=weights + ) + + self.image_newline = weights.get_tensor("image_newline") + + self.vocab_size = config.text_config.vocab_size + self.config = config + config.text_config.quantize = config.quantize + config.text_config.speculator = config.speculator + self.text_model = load_text_model( + prefix="language_model" if not prefix else f"{prefix}.language_model", + config=config.text_config, + weights=weights, + ) + self.pad_token_id = ( + config.pad_token_id if config.pad_token_id is not None else -1 + ) + + def _merge_input_ids_with_image_features( + self, + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + image_features: torch.Tensor, + ): + """In place merges in vision_embeddings with inputs_embeds.""" + mask = torch.where(input_ids == self.config.image_token_index) + # Let's pray we have enabled enough slots ! + try: + inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) + except Exception as e: + raise RuntimeError( + f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. 
If error happens at regular runtime, please fill in an issue: {e}" + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + lm_head_indices: Optional[torch.Tensor] = None, + pixel_values: torch.FloatTensor = None, + # Unused for this model + pixel_attention_mask=None, + image_sizes: Optional[torch.LongTensor] = None, + adapter_data: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + if pixel_values is not None and len(pixel_values) > 0: + # num_special_image_tokens = (input_ids == self.config.image_token_index).sum() + # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid" + # 1. Extract the input embeddings + + # 2. Merge text and images + num_images, num_patches, channels, height, width = pixel_values.shape + pixel_values = pixel_values.view( + num_images * num_patches, channels, height, width + ) + image_features = self.vision_tower(pixel_values) + + # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer] + # Already done within the clip model + selected_image_feature = image_features.last_hidden_state + + if self.config.vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif self.config.vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise RuntimeError( + f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid." + ) + + image_features = self.multi_modal_projector(selected_image_feature) + + # split up image_features for each of the individual images + # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) + # if we assume each image has 5 image features (base image + 4 patches) + split_sizes = [num_patches] * num_images + image_features = torch.split(image_features, split_sizes, dim=0) + + # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" + height = width = ( + self.config.vision_config.image_size + // self.config.vision_config.patch_size + ) + + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + + if height * width != base_image_feature.shape[0]: + raise ValueError( + "The number of patches is not consistent with the image size." 
+ ) + + # Dimensions are intentionally swapped to be bug-compatible with + # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59 + num_patch_width, num_patch_height = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + image_feature = image_feature.view( + num_patch_height, num_patch_width, height, width, -1 + ) + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + image_feature = torch.cat( + ( + image_feature, + self.image_newline[:, None, None].expand( + *image_feature.shape[:-1], 1 + ), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat( + (base_image_feature, image_feature), dim=0 + ) + else: + image_feature = image_feature[0] + image_feature = torch.cat( + (image_feature, self.image_newline[None]), dim=0 + ) + new_image_features.append(image_feature) + image_features = torch.stack(new_image_features, dim=0) + + inputs_embeds = self._merge_input_ids_with_image_features( + input_ids, inputs_embeds, image_features + ) + + hidden_states = self.text_model.model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + slots=slots, + seqlen=seqlen, + hpu_attention_meta=hpu_attention_meta, + adapter_data=adapter_data, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.text_model.lm_head(hidden_states) + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 341a23524..d23d4f679 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -26,12 +26,12 @@ from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -41,20 +41,12 @@ from text_generation_server.layers import ( TensorParallelMultiAdapterLinear, TensorParallelAdapterRowLinear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, ) -if SYSTEM == "rocm": - try: - from vllm import _custom_C - except Exception as e: - raise ImportError(f"Could not load `vllm._custom_C`. 
Full error: {e}") - - class MistralConfig(PretrainedConfig): model_type = "mistral" @@ -160,6 +152,7 @@ class MistralAttention(torch.nn.Module): ], process_group=weights.process_group, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") o_proj = TensorParallelRowLinear.load( config, @@ -185,12 +178,10 @@ class MistralAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, adapter_data, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states, adapter_data) query, kv = qkv.split( @@ -205,38 +196,36 @@ class MistralAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv - - reshape_and_cache( - kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, ) # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, window_size_left=self.max_past, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj( @@ -300,29 +289,11 @@ class MistralMLP(nn.Module): self.quantize = config.quantize def forward(self, hidden_states, adapter_data): - if ( - SYSTEM == "rocm" - and self.hidden_act == "silu" - and hidden_states.dtype == torch.float16 - and hidden_states.shape[0] == 1 - and not self.quantize - ): - out = torch.empty( - hidden_states.shape[0], - self.intermediate_size, - dtype=hidden_states.dtype, - device="cuda", - ) - _custom_C.LLMM_Silu( - self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8 - ) - return self.down_proj(out, adapter_data) - else: - gate_up_states = self.gate_up_proj(hidden_states, adapter_data) - gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) - return self.down_proj( - self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data - ) + gate_up_states = self.gate_up_proj(hidden_states, adapter_data) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data + ) class MistralLayer(nn.Module): @@ -355,12 +326,10 @@ class MistralLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, adapter_data, + hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -371,12 +340,10 @@ class MistralLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, adapter_data, + hpu_attention_meta, ) # faster post attention rms norm @@ -423,20 +390,15 @@ class MistralModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - true_max_s: 
int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], adapter_data: Optional[torch.Tensor] = None, ): hidden_states = inputs_embeds # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, true_max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -447,12 +409,10 @@ class MistralModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, adapter_data, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -498,35 +458,21 @@ class FlashMistralForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - true_max_s = max_s - if prefill_cache_indices is not None: - # Slots also need to be sliced as it has the same size as the whole kv tensor - slots = slots[prefill_cache_indices] - elif self.max_past is not None: - # Clamp in decode mode as paged attention requires clamped values whereas the flash attention - # kernel requires the true values - seqlen = seqlen.clamp(max=self.max_past_tensor) - inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( inputs_embeds, position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - true_max_s, - prefill_cache_indices, + hpu_attention_meta, adapter_data, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 5836d30af..1ef6be481 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -37,9 +37,9 @@ from text_generation_server.layers.attention import ( Seqlen, attention, paged_attention, - reshape_and_cache, + HPUPagedAttentionMetadata, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.rotary import PositionRotaryEmbedding @@ -215,6 +215,7 @@ class MixtralAttention(torch.nn.Module): ) self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = TensorParallelRowLinear.load( config, @@ -234,11 +235,9 @@ class MixtralAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) query, kv = qkv.split( @@ -253,38 +252,36 @@ class MixtralAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - if prefill_cache_indices is not 
None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv - - reshape_and_cache( - kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, ) # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, window_size_left=self.max_past, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -378,11 +375,9 @@ class MixtralLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, + hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -393,11 +388,9 @@ class MixtralLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, + hpu_attention_meta, ) # faster post attention rms norm @@ -448,20 +441,15 @@ class MixtralModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - true_max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, true_max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -472,11 +460,9 @@ class MixtralModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -507,34 +493,21 @@ class FlashMixtralForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - true_max_s = max_s - if prefill_cache_indices is not None: - # Slots also need to be sliced as it has the same size as the whole kv tensor - slots = slots[prefill_cache_indices] - elif self.max_past is not None: - # Clamp in decode mode as paged attention requires clamped values whereas the flash attention - # kernel requires the true values - seqlen = seqlen.clamp(max=self.max_past_tensor) hidden_states = self.model( input_ids, position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - true_max_s, - 
prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py new file mode 100644 index 000000000..216642e08 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py @@ -0,0 +1,986 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Mllama model.""" + +from typing import Optional, Tuple, List + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +import torch.nn.functional as F + +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + FastLinear, +) +from text_generation_server.layers.attention import ( + Seqlen, + HPUPagedAttentionMetadata, +) +from text_generation_server.models.custom_modeling.flash_llama_modeling import ( + FlashLlamaForCausalLM, +) +from habana_frameworks.torch.hpex.kernels import FusedSDPA +from vllm_hpu_extension.utils import ModuleFusedSDPA + + +def _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask: torch.Tensor, + num_patches: int, + target_length: int, + dtype: torch.dtype, +) -> torch.Tensor: + # Expand aspect ratio mask to target_length + batch_size, max_num_tiles = aspect_ratio_mask.shape + attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype) + attention_mask = attention_mask.repeat(1, 1, target_length, 1) + + # Mask padding patches + pad_patches = target_length - num_patches + attention_mask[:, :, -pad_patches:] = 0 + + # Invert the mask (0 -> 1, 1 -> 0) + attention_mask = 1 - attention_mask + + # Reshape to 2D and create 4D attention mask + # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length) + attention_mask = attention_mask.reshape( + batch_size, max_num_tiles * target_length, 1 + ) + attention_mask = ( + attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min + ) + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + +# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. 
+ + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`int`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device, + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange( + target_length, device=device + ) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = ( + causal_mask.clone() + ) # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[ + :, :, :, :mask_length + ].masked_fill(padding_mask, min_dtype) + + return causal_mask + + +def _prepare_cross_attention_mask( + cross_attention_mask: torch.Tensor, + num_vision_tokens: int, + dtype: torch.dtype, +) -> Tuple[torch.Tensor, torch.Tensor]: + # reshape so it can be used by attn module + batch_size, text_total_length, *_ = cross_attention_mask.shape + cross_attention_mask = cross_attention_mask.repeat_interleave( + num_vision_tokens, dim=3 + ) + cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1) + cross_attention_mask = cross_attention_mask.unsqueeze(1) + + # invert the mask + inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype) + cross_attention_mask = inverted_cross_attn_mask.masked_fill( + inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min + ) + + # apply full-row bias, which returns a 4D tensor of shape [B, H, S1, 1] where the value is 0 if a full row in the cross attn mask's + # last dimension contains negative infinity values, otherwise it's 1 + negative_inf_value = torch.finfo(dtype).min + full_text_row_masked_out_mask = ( + (cross_attention_mask != negative_inf_value) + .any(dim=-1) + .type_as(cross_attention_mask)[..., None] + ) + cross_attention_mask *= full_text_row_masked_out_mask + + return cross_attention_mask, full_text_row_masked_out_mask + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision +class MllamaVisionMLP(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True + ) + self.fc2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class MllamaVisionSdpaAttention(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + + self.embed_dim = config.hidden_size + self.head_dim = config.hidden_size // config.attention_heads + self.num_heads = config.attention_heads // weights.process_group.size() + + self.qkv_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qkv = self.qkv_proj(hidden_state) + query, key, value = qkv.split( + [ + self.head_dim * self.num_heads, + self.head_dim * self.num_heads, + self.head_dim * self.num_heads, + ], + dim=2, + ) + + batch_size, q_seq_len, _ = query.shape + _, kv_seq_len, _ = key.shape + + query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim) + key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim) + value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + attn_output = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_seq_len, -1) + + output = self.o_proj(attn_output) + return output + + +class MllamaVisionEncoderLayer(nn.Module): + def __init__(self, *, prefix, config, weights, is_gated: bool): + super().__init__() + + self.hidden_size = config.hidden_size + self.num_attention_heads = config.attention_heads + self.is_gated = is_gated + self.intermediate_size = config.intermediate_size + + self.self_attn = MllamaVisionSdpaAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.mlp = MllamaVisionMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights + ) + + self.input_layernorm = nn.LayerNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=1e-05 + ) + self.post_attention_layernorm = nn.LayerNorm.load( + prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=1e-05 + ) + + # there used to be an if else here, no code path + if is_gated: + self.gate_attn = nn.Parameter( + weights.get_tensor(f"{prefix}.gate_attn"), requires_grad=False + ) + self.gate_ffn = nn.Parameter( + weights.get_tensor(f"{prefix}.gate_ffn"), requires_grad=False + ) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + # Self Attention + residual = hidden_state + hidden_state = self.input_layernorm(hidden_state) + hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask) + gate_attn = 1 if not self.is_gated else self.gate_attn.tanh() + hidden_state = residual + gate_attn * hidden_state + + # Feed forward + residual = hidden_state + hidden_state = self.post_attention_layernorm(hidden_state) + hidden_state = self.mlp(hidden_state) + gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh() + hidden_state = residual 
+ gate_ffn * hidden_state + return hidden_state + + +class MllamaVisionEncoder(nn.Module): + def __init__(self, *, prefix, config, weights, is_gated: bool, num_layers: int): + super().__init__() + self.config = config + self.layers = [ + MllamaVisionEncoderLayer( + prefix=f"{prefix}.layers.{i}", + config=config, + weights=weights, + is_gated=is_gated, + ) + for i in range(num_layers) + ] + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + encoder_states = [hidden_states] + for encoder_layer in self.layers: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + ) + + hidden_states = layer_outputs + encoder_states.append(hidden_states) + + return hidden_states, encoder_states + + +class MllamaPrecomputedAspectRatioEmbedding(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.max_aspect_ratio_id = config.max_aspect_ratio_id + + self.embedding = TensorParallelEmbedding( + prefix=f"{prefix}.embedding", weights=weights + ) + self.gate = nn.Parameter( + weights.get_tensor(f"{prefix}.gate"), requires_grad=False + ) + + def forward( + self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor + ) -> torch.Tensor: + embeddings = self.embedding(aspect_ratio_ids) + embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size) + + # Always gated. + embeddings = embeddings * self.gate.tanh() + + hidden_state = hidden_state + embeddings + return hidden_state + + +class MllamaPrecomputedPositionEmbedding(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.num_patches = (config.image_size // config.patch_size) ** 2 + 1 + self.hidden_size = config.hidden_size + self.scale = config.hidden_size**-0.5 + + self.gate = nn.Parameter( + weights.get_tensor(f"{prefix}.gate"), requires_grad=False + ) + + # position embedding + embedding = nn.Parameter( + weights.get_tensor(f"{prefix}.embedding"), requires_grad=False + ) + self.gated_position_embedding = (1 - self.gate.tanh()) * embedding + self.tile_embedding = TensorParallelEmbedding( + prefix=f"{prefix}.tile_embedding", weights=weights + ) + + def forward( + self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor + ) -> torch.Tensor: + # position embeddings + hidden_state = hidden_state + self.gated_position_embedding.view( + 1, 1, self.num_patches, self.hidden_size + ) + + # precomputed tile position embeddings + tile_position_embedding = self.tile_embedding(aspect_ratio_ids) + batch_size = hidden_state.shape[0] + tile_position_embedding = tile_position_embedding.reshape( + batch_size, self.max_num_tiles, self.num_patches, self.hidden_size + ) + gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding + hidden_state = hidden_state + gated_tile_position_embedding + + return hidden_state + + +class MllamaVisionModel(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.image_size = config.image_size + self.patch_size = config.patch_size + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.num_channels = config.num_channels + self.intermediate_layers_indices = config.intermediate_layers_indices + + self.num_patches = (self.image_size // self.patch_size) ** 2 + 1 + self.scale = config.hidden_size**-0.5 + self.dtype = 
weights.dtype + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.hidden_size, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + bias=False, + ) + self.patch_embedding.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False + ) + + self.class_embedding = nn.Parameter( + weights.get_tensor(f"{prefix}.class_embedding"), requires_grad=False + ) + + self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding( + prefix=f"{prefix}.gated_positional_embedding", + config=config, + weights=weights, + ) + + self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( + prefix=f"{prefix}.pre_tile_positional_embedding", + config=config, + weights=weights, + ) + self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( + prefix=f"{prefix}.post_tile_positional_embedding", + config=config, + weights=weights, + ) + + ## layer norms + self.layernorm_pre = nn.LayerNorm.load( + prefix=f"{prefix}.layernorm_pre", + weights=weights, + # torch default + eps=1e-05, + ) + self.layernorm_post = nn.LayerNorm.load( + prefix=f"{prefix}.layernorm_post", + weights=weights, + # torch default + eps=1e-05, + ) + + ## encoders + self.transformer = MllamaVisionEncoder( + prefix=f"{prefix}.transformer", + config=config, + weights=weights, + is_gated=False, + num_layers=config.num_hidden_layers, + ) + self.global_transformer = MllamaVisionEncoder( + prefix=f"{prefix}.global_transformer", + config=config, + weights=weights, + is_gated=True, + num_layers=config.num_global_layers, + ) + + def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: + batch_size, _, hidden_size = hidden_state.shape + class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size) + hidden_state = torch.cat([class_embedding, hidden_state], dim=1) + return hidden_state + + def forward( + self, + pixel_values: torch.Tensor, + aspect_ratio_ids: torch.Tensor, + attention_mask: torch.Tensor, + ) -> torch.Tensor: + ( + batch_size, + num_concurrent_media, + num_tiles, + num_channels, + height, + width, + ) = pixel_values.shape + + pixel_values = pixel_values.reshape( + batch_size * num_concurrent_media * num_tiles, num_channels, height, width + ) + aspect_ratio_ids = aspect_ratio_ids.reshape( + batch_size * num_concurrent_media, -1 + ) + + # patch embedding + patch_embeds = self.patch_embedding(pixel_values) + hidden_state = patch_embeds.flatten(2).transpose(1, 2) + + # tile embeddings + _, num_patches, dim = hidden_state.shape + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, -1, dim + ) + hidden_state = self.pre_tile_positional_embedding( + hidden_state, aspect_ratio_ids + ) + + # apply cls token + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media * num_tiles, num_patches, dim + ) + hidden_state = self.apply_class_embedding(hidden_state) + num_patches += 1 + + # apply position embeddings + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches, dim + ) + hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids) + + # apply encoder + hidden_state = self.layernorm_pre(hidden_state) + + # Compute the number of tokens to pad + num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 + # Compute padding tuple for pad function + padding = ( + 0, + 0, + 0, + num_padding_patches, + ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) + # Pad 
the tensor + hidden_state = F.pad(hidden_state, padding, mode="constant", value=0) + slice_index = -num_padding_patches if num_padding_patches > 0 else None + + if attention_mask is not None: + attention_mask = attention_mask.reshape( + batch_size * num_concurrent_media, -1 + ) + attention_mask = _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask=attention_mask, + num_patches=self.num_patches, + target_length=hidden_state.shape[2], + dtype=self.dtype, + ) + + hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim) + hidden_state, all_intermediate_hidden_states = self.transformer( + hidden_state, + attention_mask=attention_mask, + ) + intermediate_hidden_states = [ + hidden_state + for idx, hidden_state in enumerate(all_intermediate_hidden_states) + if idx in self.intermediate_layers_indices + ] + intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1) + + # apply global encoder + hidden_state = self.layernorm_post(hidden_state) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim, + ) + hidden_state = self.post_tile_positional_embedding( + hidden_state, aspect_ratio_ids + ) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles * (num_patches + num_padding_patches), + dim, + ) + hidden_state, _ = self.global_transformer( + hidden_state, attention_mask=attention_mask + ) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim, + ) + hidden_state = hidden_state[:, :, :slice_index] + + # adding intermediate layer outputs + hidden_state = hidden_state.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, dim + ) + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + -1, + ) + intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index] + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, -1 + ) + hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1) + return hidden_state + + +class MllamaTextCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, *, prefix, config, weights, layer_idx): + super().__init__() + self.config = config + self.num_heads = self.config.num_attention_heads + self.num_key_value_heads = self.config.num_key_value_heads + self.dropout = config.dropout + self.hidden_size = config.hidden_size + self.head_size = config.hidden_size // self.num_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.layer_idx = layer_idx + + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + self.num_key_value_heads // weights.process_group.size() + ) + + self.q_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.q_proj", + weights=weights, + bias=False, + ) + self.k_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.k_proj", + weights=weights, + bias=False, + ) + self.v_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.v_proj", + weights=weights, + bias=False, + ) + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + + self.q_norm = MllamaTextRMSNorm.load( + 
prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps + ) + self.k_norm = MllamaTextRMSNorm.load( + prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps + ) + self.softmax_scale = self.head_size**-0.5 + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + # past_key_value=None, + # attention_mask: Optional[torch.Tensor] = None, + # cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + # hidden_states = hidden_states.unsqueeze(0) + # bsz, q_len, _ = hidden_states.size() + ( + cross_attention_states, + cu_seqlen_q, + cu_seqlen_k, + indices, + ) = cross_attention_states + bs = cu_seqlen_q.size(0) - 1 + query_states = self.q_proj(hidden_states) + query_states = query_states.view(bs, -1, self.num_heads, self.head_size) + query_states = self.q_norm(query_states) + + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + key_states = key_states.view(bs, -1, self.num_key_value_heads, self.head_size) + value_states = value_states.view( + bs, -1, self.num_key_value_heads, self.head_size + ) + key_states = self.k_norm(key_states) + + # key_states = key_states.repeat(1, self.num_key_value_groups, 1) + # value_states = value_states.repeat(1, self.num_key_value_groups, 1) + + causal = False + # logger.info( + # f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}" + # ) + # execute sdpa + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + fsdpa_op = ModuleFusedSDPA(FusedSDPA) + attn_output = fsdpa_op( + query_states, + key_states, + value_states, + attn_mask=None, + dropout_p=0.0, + is_causal=causal, + scale=None, + softmax_mode="None", + recompute_mode=None, + valid_sequence_lengths=None, + ) + attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous() + attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + return attn_output + + +# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText +class MllamaTextMLP(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + shape = x.shape + gate_up_states = self.gate_up_proj(x) + gate_up_states = gate_up_states.view(*shape[:-1], 2, self.intermediate_size) + result = self.down_proj( + self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1] + ) + return result + + +class FlashLlamaCrossLayer(torch.nn.Module): + """Cross-attention transformer block with tanh-gated attention and feedforward.""" + + def __init__(self, *, prefix, config, weights, index) -> None: + layer_idx = index + super().__init__() + self.cross_attn = MllamaTextCrossAttention( + prefix=f"{prefix}.cross_attn", + config=config, + weights=weights, + layer_idx=layer_idx, + ) + + self.input_layernorm = MllamaTextRMSNorm.load( + 
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.cross_attn_attn_gate = torch.nn.Parameter( + weights.get_tensor(f"{prefix}.cross_attn_attn_gate"), requires_grad=False + ) + + self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + self.post_attention_layernorm = MllamaTextRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.cross_attn_mlp_gate = torch.nn.Parameter( + weights.get_tensor(f"{prefix}.cross_attn_mlp_gate"), requires_grad=False + ) + self.layer_idx = layer_idx + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + adapter_data, + cross_attention_states, # [ IB, ...] + hpu_attention_meta, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if cross_attention_states is None: + return hidden_states, residual + if residual is not None: + hidden_states += residual + + indices = cross_attention_states[-1] + out_hidden_states = hidden_states[:] + if len(indices) > 0: + assert max(indices) < hidden_states.shape[0] + hidden_states = hidden_states[indices] + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.cross_attn( + hidden_states=hidden_states, + # attention_mask=cross_attention_mask, + cross_attention_states=cross_attention_states, + ) + hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states + + out_hidden_states[indices] = hidden_states + hidden_states = out_hidden_states + + return hidden_states, None + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText +class MllamaTextRMSNorm(nn.Module): + def __init__(self, weight, eps): + super().__init__() + self.weight = weight + self.variance_epsilon = eps + + @classmethod + def load(cls, *, prefix, weights, eps): + weight = nn.Parameter( + weights.get_tensor(f"{prefix}.weight"), requires_grad=False + ) + return cls(weight=weight, eps=eps) + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class FlashMllamaForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + config.vision_config.quantize = None + config.vision_config.speculator = config.speculator + config.text_config.quantize = config.quantize + config.text_config.speculator = config.speculator + config.text_config._attn_implementation = "sdpa" + self.hidden_size = config.text_config.hidden_size + self.vision_model = MllamaVisionModel( + prefix="vision_model", config=config.vision_config, weights=weights + ) + self.multi_modal_projector = FastLinear.load( + prefix="multi_modal_projector", config=config, weights=weights, bias=True + ) + self.text_model = FlashLlamaForCausalLM( + prefix="language_model", config=config.text_config, weights=weights + ) + self.config = config + self.dtype = weights.dtype + self.device = weights.device + + def vision_forward(self, pixel_values, 
aspect_ratio_ids, aspect_ratio_mask): + if aspect_ratio_ids is None: + raise ValueError( + "`aspect_ratio_ids` must be provided if `pixel_values` is provided" + ) + batch_size = pixel_values.shape[0] + vision_states = self.vision_model( + pixel_values, aspect_ratio_ids, aspect_ratio_mask + ) + cross_attention_states = self.multi_modal_projector(vision_states).reshape( + -1, vision_states.shape[-2], self.hidden_size + ) + _, _, h = cross_attention_states.shape + cross_attention_states = cross_attention_states.view(batch_size, -1, h) + return cross_attention_states + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + lm_head_indices: Optional[torch.Tensor], + adapter_data: Optional[torch.Tensor] = None, + # XXX: Putting these as optional so that the warmup calls can go through. + cross_attention_states: Optional[torch.Tensor] = None, + image_indices=None, + ): + if cross_attention_states is not None: + seqlen_q = len(image_indices) + n_images = cross_attention_states.shape[0] + seqlen_k = cross_attention_states.shape[1] + device = cross_attention_states.device + if cu_seqlen_prefill is not None: + offset = 0 + cu_q = [] + indices = [] + for index in image_indices: + cu_q.append(offset) + length = seqlen.input_lengths[index].item() + assert index < seqlen.cu_seqlen_q.shape[0] + input_ids_offset = seqlen.cu_seqlen_q[index] + indices.extend(range(input_ids_offset, input_ids_offset + length)) + offset += length + cu_q.append(offset) + cu_seqlen_q = torch.tensor(cu_q, device=device, dtype=torch.int32) + + assert max(indices) < input_ids.shape[0] + + cu_seqlen_k = ( + torch.arange( + n_images + 1, + device=device, + dtype=torch.int32, + ) + * seqlen_k + ) + else: + cu_seqlen_q = torch.arange( + seqlen_q + 1, device=device, dtype=torch.int32 + ) + seqlen_k = cross_attention_states.shape[1] + n_images = cross_attention_states.shape[0] + cu_seqlen_k = ( + torch.arange( + n_images + 1, + device=device, + dtype=torch.int32, + ) + * seqlen_k + ) + indices = image_indices[:] + + cross_attention_states = ( + cross_attention_states, + cu_seqlen_q, + cu_seqlen_k, + indices, + ) + + outputs = self.text_model( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + slots=slots, + seqlen=seqlen, + hpu_attention_meta=hpu_attention_meta, + lm_head_indices=lm_head_indices, + adapter_data=adapter_data, + cross_attention_states=cross_attention_states, + ) + + return outputs diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index ad4e382fe..33f63333a 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -29,8 +29,8 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -39,7 +39,7 @@ from text_generation_server.layers import (
SpeculativeHead, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.layernorm import ( FastLayerNorm, ) @@ -132,6 +132,7 @@ class FlashNeoxAttention(torch.nn.Module): head_size=self.head_size, hidden_size=self.hidden_size, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.dense = load_row( config, prefix=f"{prefix}.dense", weights=weights, bias=True ) @@ -146,10 +147,9 @@ class FlashNeoxAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, 3, self.num_heads, self.head_size) @@ -165,30 +165,35 @@ class FlashNeoxAttention(torch.nn.Module): qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1) qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1) - reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=qkv[:, 1], + value=qkv[:, 2], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - qkv[:, 0], - kv_cache[0] if PREFILL_IN_KV_CACHE else qkv[:, 1], - kv_cache[1] if PREFILL_IN_KV_CACHE else qkv[:, 2], - seqlen, - block_tables, - self.softmax_scale, + query=qkv[:, 0], + key=qkv[:, 1], + value=qkv[:, 2], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( qkv[:, 0], - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) @@ -255,10 +260,9 @@ class FlashNeoXLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): if self.use_parallel_residual: ln1_hidden_states, _ = self.input_layernorm(hidden_states) @@ -269,10 +273,9 @@ class FlashNeoXLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states) @@ -293,10 +296,9 @@ class FlashNeoXLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) hidden_states, residual = self.post_attention_layernorm( @@ -347,18 +349,15 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_in(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -369,10 +368,9 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) hidden_states, _ = self.final_layer_norm(hidden_states, residual) @@ -401,11 +399,9 @@ class 
FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -414,10 +410,9 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index 0024f2bb9..4d31d5ddf 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -19,7 +19,7 @@ from torch import nn from typing import Optional, List, Tuple from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear -from text_generation_server.layers.attention import Seqlen +from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata from text_generation_server.models.custom_modeling.vlm import ( load_text_model, load_vision_model, @@ -69,22 +69,20 @@ class PaliGemmaForConditionalGeneration(nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor] = None, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, pixel_values: torch.FloatTensor = None, # Unused here pixel_attention_mask: Optional[torch.BoolTensor] = None, image_sizes: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: inputs_embeds = self.text_model.embed_tokens(input_ids) # TODO This is odd but apparently pali gemma position ids start at 1. 
if cu_seqlen_prefill is not None: - max_s += 1 position_ids += 1 if pixel_values is not None: @@ -106,10 +104,10 @@ class PaliGemmaForConditionalGeneration(nn.Module): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, - block_tables=block_tables, slots=slots, seqlen=seqlen, - max_s=max_s, + hpu_attention_meta=hpu_attention_meta, + adapter_data=adapter_data, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 2a0dc6066..0c7779124 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -9,8 +9,8 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -19,7 +19,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.layernorm import ( FastLayerNorm, ) @@ -90,7 +90,7 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize not in ["gptq", "awq", "marlin"]: + if config.quantize not in ["gptq", "awq"]: weight = weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.hidden_size // config.num_attention_heads @@ -139,6 +139,7 @@ class FlashPhiAttention(torch.nn.Module): ) self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") # in llama the dense layer is called "o_proj" and has bias=False self.dense = TensorParallelRowLinear.load( @@ -159,10 +160,9 @@ class FlashPhiAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): # Compute query, key, value and split qkv = self.query_key_value(hidden_states) @@ -188,29 +188,34 @@ class FlashPhiAttention(torch.nn.Module): ) # Reshape key and value and cache - reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_scales=self.kv_scales, + kv_cache=kv_cache, + seqlen=seqlen, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) @@ -274,10 +279,9 @@ class FlashPhiLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): hidden_states, res = self.input_layernorm(hidden_states, residual) # Self Attention @@ -287,10 +291,9 @@ class FlashPhiLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - 
block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) hidden_states = self.resid_dropout(attn_output).add( @@ -339,18 +342,15 @@ class FlashPhiModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -361,10 +361,9 @@ class FlashPhiModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -394,11 +393,9 @@ class FlashPhiForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -407,10 +404,9 @@ class FlashPhiForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 02c788d3e..af4b404d0 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -8,8 +8,8 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -17,7 +17,7 @@ from text_generation_server.layers import ( TensorParallelEmbedding, SpeculativeHead, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, @@ -86,6 +86,8 @@ class Qwen2Attention(torch.nn.Module): self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") + self.o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o_proj", @@ -104,11 +106,9 @@ class Qwen2Attention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) query, kv = qkv.split( @@ -123,38 +123,36 @@ class Qwen2Attention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = 
kv[prefill_cache_indices] - else: - kv_to_cache = kv - - reshape_and_cache( - kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, ) # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, window_size_left=self.max_past, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -223,13 +221,11 @@ class Qwen2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, + hpu_attention_meta, ): - normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + normed_hidden_states, residual = self.input_layernorm(hidden_states) # Self Attention attn_output = self.self_attn( @@ -238,21 +234,17 @@ class Qwen2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, + hpu_attention_meta, ) + hidden_states = attn_output + residual # faster post attention rms norm - normed_attn_res_output, attn_res = self.post_attention_layernorm( - attn_output, res - ) - - mlp_output = self.mlp(normed_attn_res_output) - - return mlp_output, attn_res + hidden_states, residual = self.post_attention_layernorm(hidden_states) + mlp_output = self.mlp(hidden_states) + hidden_states = mlp_output + residual + return hidden_states class Qwen2Model(torch.nn.Module): @@ -264,9 +256,6 @@ class Qwen2Model(torch.nn.Module): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - self.embed_tokens = TensorParallelEmbedding( - prefix=f"{prefix}.embed_tokens", weights=weights - ) self.layers = nn.ModuleList( [ Qwen2Layer( @@ -290,42 +279,35 @@ class Qwen2Model(torch.nn.Module): def forward( self, - input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - true_max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + hidden_states = inputs_embeds - # Get rotary cos and sin for this forward - # Avoid to index in each layer cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, true_max_s, hidden_states.dtype + position_ids, ) residual = None for i, layer in enumerate(self.layers): - hidden_states, residual = layer( + hidden_states = layer( hidden_states, residual, cos, sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, + hpu_attention_meta, ) - hidden_states, _ = self.norm(hidden_states, residual) + hidden_states, _ = self.norm(hidden_states) return hidden_states @@ -346,6 +328,12 @@ class 
Qwen2ForCausalLM(torch.nn.Module): prefix=f"{prefix}.{suffix}" if prefix else suffix, weights=weights, ) + + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens" if prefix else "model.embed_tokens", + weights=weights, + ) + self.max_past = config.sliding_window self.max_past_tensor = ( torch.tensor(config.sliding_window, device=weights.device) @@ -359,34 +347,23 @@ class Qwen2ForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor] = None, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - true_max_s = max_s - if prefill_cache_indices is not None: - # Slots also need to be sliced as it has the same size as the whole kv tensor - slots = slots[prefill_cache_indices] - elif self.max_past is not None: - # Clamp in decode mode as paged attention requires clamped values whereas the flash attention - # kernel requires the true values - seqlen = seqlen.clamp(max=self.max_past_tensor) + + inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( - input_ids, + inputs_embeds, position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - true_max_s, - prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 6671d85e2..141e13a63 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -12,14 +12,14 @@ from text_generation_server.layers import ( TensorParallelRowLinear, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.layernorm import FastLayerNorm from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.attention import ( attention, paged_attention, - reshape_and_cache, Seqlen, + HPUPagedAttentionMetadata, ) @@ -79,6 +79,7 @@ class RWConfig(PretrainedConfig): self.alibi = False self.rotary = True self.rope_theta = rope_theta + self.max_position_embeddings = 2048 self.vocab_size = vocab_size # Backward compatibility with n_embed kwarg @@ -160,6 +161,7 @@ class FlashRWAttention(torch.nn.Module): weights=weights, bias=config.bias, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.dense = load_row( config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias ) @@ -180,10 +182,9 @@ class FlashRWAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) @@ -200,30 +201,35 @@ class FlashRWAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if 
cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) @@ -278,6 +284,7 @@ class FlashRWLargeAttention(torch.nn.Module): weights=weights, bias=config.bias, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.dense = load_row( config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias ) @@ -293,10 +300,9 @@ class FlashRWLargeAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size) @@ -312,36 +318,35 @@ class FlashRWLargeAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin) - reshape_and_cache( - kv[:, :, 0].contiguous(), - kv[:, :, 1].contiguous(), - kv_cache[0], - kv_cache[1], - slots, + kv_cache.store( + key=kv[:, :, 0].contiguous(), + value=kv[:, :, 1].contiguous(), + slots=slots, + kv_scales=self.kv_scales, ) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(), - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(), - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, :, 0], + value=kv[:, :, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.dense( @@ -424,10 +429,9 @@ class FlashRWLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): if self.parallel_attn: ln_hidden_states, residual = self.input_layernorm(hidden_states, residual) @@ -438,10 +442,9 @@ class FlashRWLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) mlp_output = self.mlp(ln_hidden_states) @@ -460,10 +463,9 @@ class FlashRWLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) if self.post_attention_layernorm is not None: @@ -547,10 +549,9 @@ class FlashRWLargeLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): # Layer norm. ln_attn, ln_mlp, residual = self.ln_layer(hidden_states, residual) @@ -562,10 +563,9 @@ class FlashRWLargeLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) # MLP. 
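Note: every attention module rewritten in this patch converges on the same three-step HPU pattern: the freshly projected key/value pairs are written into the paged cache with kv_cache.store(...), prefill runs the SDPA attention(...) path when cu_seqlen_prefill is set, and decode calls paged_attention(...) with the new HPUPagedAttentionMetadata in place of the old block_tables/max_s arguments. Below is a minimal sketch of that call pattern, assuming only the signatures visible in these hunks; the attend helper itself and the [tokens, 2, kv_heads, head_size] layout of kv are illustrative, not part of the patch.

from text_generation_server.layers.attention import attention, paged_attention

def attend(self, query, kv, kv_cache, slots, cu_seqlen_prefill, seqlen, hpu_attention_meta):
    # kv[:, 0] holds the keys and kv[:, 1] the values for the current tokens.
    kv_cache.store(
        key=kv[:, 0],
        value=kv[:, 1],
        slots=slots,
        kv_scales=self.kv_scales,  # loaded once per layer via get_kv_scales(weights, prefix)
    )
    if cu_seqlen_prefill is not None:
        # Prefill: SDPA over the current keys/values.
        return attention(
            query=query,
            key=kv[:, 0],
            value=kv[:, 1],
            kv_cache=kv_cache,
            kv_scales=self.kv_scales,
            seqlen=seqlen,
            softmax_scale=self.softmax_scale,
        )
    # Decode: read keys/values back from the paged cache; the HPU metadata
    # replaces the former block_tables and max_s arguments.
    return paged_attention(
        query,
        kv_cache,
        self.kv_head_mapping,
        self.softmax_scale,
        seqlen,
        kv_scales=self.kv_scales,
        hpu_attention_meta=hpu_attention_meta,
    )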
@@ -623,18 +623,15 @@ class FlashRWModel(FlashRWPreTrainedModel): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.word_embeddings(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.h): @@ -645,10 +642,9 @@ class FlashRWModel(FlashRWPreTrainedModel): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) hidden_states, _ = self.ln_f(hidden_states, residual) @@ -675,11 +671,9 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -688,10 +682,9 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 43eb9687f..b68f47840 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -8,8 +8,8 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -18,7 +18,7 @@ from text_generation_server.layers import ( TensorParallelEmbedding, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.gptq import GPTQWeightsLoader from text_generation_server.layers.layernorm import ( FastLayerNorm, @@ -32,10 +32,6 @@ def load_multi_mqa( return _load_multi_mqa_gptq( config, prefix, weights, bias, head_size, num_heads, hidden_size ) - elif config.quantize == "marlin": - raise RuntimeError( - "santacoder models with marlin quantization are not yet supported" - ) else: return _load_multi_mqa( config, prefix, weights, bias, head_size, num_heads, hidden_size @@ -259,6 +255,7 @@ class FlashMQAttention(torch.nn.Module): self.c_proj = load_row( config, prefix=f"{prefix}.c_proj", weights=weights, bias=True ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.kv_head_mapping = torch.zeros( self.num_heads, dtype=torch.int32, device=weights.device ) @@ -268,10 +265,9 @@ class FlashMQAttention(torch.nn.Module): hidden_states, cu_seqlen_prefill, kv_cache, - 
block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): qkv = self.c_attn(hidden_states) @@ -284,32 +280,35 @@ class FlashMQAttention(torch.nn.Module): query = query.view(-1, self.num_heads, self.head_size) key_value = key_value.view(-1, 2, 1, self.head_size) - reshape_and_cache( - key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots + kv_cache.store( + key=key_value[:, 0], + value=key_value[:, 1], + slots=slots, + kv_scales=self.kv_scales, ) # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - query, - kv_cache[0] if PREFILL_IN_KV_CACHE else key_value[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else key_value[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=key_value[:, 0], + value=key_value[:, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -371,20 +370,18 @@ class Block(nn.Module): residual, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ): hidden_states, residual = self.ln_1(hidden_states, residual) hidden_states = self.self_attn( hidden_states, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) hidden_states, residual = self.ln_2(hidden_states, residual) @@ -435,10 +432,9 @@ class FlashSantacoderModel(nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.wte(input_ids) + self.wpe(position_ids) @@ -452,10 +448,9 @@ class FlashSantacoderModel(nn.Module): residual, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) hidden_states, _ = self.ln_f(hidden_states, residual) @@ -484,11 +479,9 @@ class FlashSantacoderForCausalLM(nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -497,10 +490,9 @@ class FlashSantacoderForCausalLM(nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 4975cf225..76f6f473a 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -29,17 +29,19 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention 
import ( paged_attention, attention, - reshape_and_cache, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( + TensorParallelMultiAdapterLinear, + TensorParallelAdapterRowLinear, TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, SpeculativeHead, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.layernorm import ( FastLayerNorm, FastRMSNorm, @@ -110,17 +112,31 @@ class Starcoder2Config(PretrainedConfig): ) -def load_attention(config, prefix, weights): +def load_attention(config, prefix, weights, layer_id): + prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"] + head_size = config.hidden_size // config.num_attention_heads + sizes = [ + head_size * config.num_attention_heads, + head_size * config.num_key_value_heads, + head_size * config.num_key_value_heads, + ] if config.num_attention_heads != config.num_key_value_heads: - return _load_gqa(config, prefix, weights) + base_layer = _load_gqa(config, prefix, weights) else: - return TensorParallelColumnLinear.load_multi( + base_layer = TensorParallelColumnLinear.load_multi( config, - prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + prefixes=prefixes, dim=0, weights=weights, bias=config.use_bias, ) + return TensorParallelMultiAdapterLinear.load( + base_layer=base_layer, + layer_id=layer_id, + layer_names=prefixes, + sizes=sizes, + process_group=weights.process_group, + ) def _load_gqa(config, prefix: str, weights): @@ -158,6 +174,7 @@ def _load_gqa(config, prefix: str, weights): class Starcoder2Attention(torch.nn.Module): def __init__( self, + index: int, prefix: str, config, weights, @@ -189,14 +206,23 @@ class Starcoder2Attention(torch.nn.Module): config.num_key_value_heads // weights.process_group.size() ) - self.query_key_value = load_attention(config, prefix, weights) + self.query_key_value = load_attention(config, prefix, weights, index) + self.kv_scales = get_kv_scales(weights, f"{prefix}") - self.o_proj = TensorParallelRowLinear.load( + o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o_proj", weights=weights, - bias=config.use_bias, + bias=getattr(config, "use_bias", False), ) + + self.o_proj = TensorParallelAdapterRowLinear.load( + o_proj, + index, + "o_proj", + process_group=weights.process_group, + ) + self.num_groups = self.num_heads // self.num_key_value_heads self.kv_head_mapping = torch.arange( 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device @@ -209,13 +235,12 @@ class Starcoder2Attention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, + adapter_data, + hpu_attention_meta, ): - qkv = self.query_key_value(hidden_states) + qkv = self.query_key_value(hidden_states, adapter_data) query, kv = qkv.split( [ self.head_size * self.num_heads, @@ -228,45 +253,45 @@ class Starcoder2Attention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv - - reshape_and_cache( - kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, ) # Prefill if cu_seqlen_prefill is not None: - # flash attention + # sdpa attn_output = attention( - query, - kv_cache[0] if 
PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, window_size_left=self.max_past, ) # Decode else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, - max_s, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) - return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + return self.o_proj( + attn_output.view(-1, self.num_heads * self.head_size), adapter_data + ) class Starcoder2MLP(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix, config, weights, index): super().__init__() act = config.hidden_act self.act = ( @@ -280,27 +305,42 @@ class Starcoder2MLP(nn.Module): ) ) # Fuse gate and up proj - self.c_fc = TensorParallelColumnLinear.load( + c_fc = TensorParallelColumnLinear.load( config, prefix=f"{prefix}.c_fc", weights=weights, bias=config.use_bias, ) - self.c_proj = TensorParallelRowLinear.load( + c_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.c_proj", weights=weights, bias=config.use_bias, ) - def forward(self, hidden_states): - hidden_states = self.c_fc(hidden_states) + self.c_fc = TensorParallelMultiAdapterLinear.load( + c_fc, + layer_id=index, + layer_names=[f"{prefix}.c_fc"], + sizes=[config.intermediate_size, config.intermediate_size], + process_group=weights.process_group, + ) + + self.c_proj = TensorParallelAdapterRowLinear.load( + c_proj, + index, + "c_proj", + process_group=weights.process_group, + ) + + def forward(self, hidden_states, adapter_data): + hidden_states = self.c_fc(hidden_states, adapter_data) hidden_states = self.act(hidden_states) - return self.c_proj(hidden_states) + return self.c_proj(hidden_states, adapter_data) class Starcoder2GatedMLP(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, index, prefix, config, weights): super().__init__() act = config.hidden_act self.act = ( @@ -314,27 +354,47 @@ class Starcoder2GatedMLP(nn.Module): ) ) # Fuse gate and up proj - self.gate_up_proj = TensorParallelColumnLinear.load_multi( + prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"] + sizes = [ + config.intermediate_size, + config.intermediate_size, + ] + gate_up_proj = TensorParallelColumnLinear.load_multi( config, - prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + prefixes=prefixes, weights=weights, dim=0, bias=config.use_bias, ) - self.down_proj = TensorParallelRowLinear.load( + self.gate_up_proj = TensorParallelMultiAdapterLinear.load( + gate_up_proj, + index, + layer_names=prefixes, + sizes=sizes, + process_group=weights.process_group, + ) + down_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.down_proj", weights=weights, bias=config.use_bias, ) + self.down_proj = TensorParallelAdapterRowLinear.load( + down_proj, + index, + "down_proj", + process_group=weights.process_group, + ) self.intermediate_size = ( config.intermediate_size // weights.process_group.size() ) - def forward(self, hidden_states): - gate_up_states = self.gate_up_proj(hidden_states) + def forward(self, hidden_states, adapter_data): + gate_up_states = self.gate_up_proj(hidden_states, adapter_data) gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) - return 
self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data + ) STARCODER2_NORMALIZATION_CLASSES = { @@ -353,11 +413,11 @@ class Starcoder2Layer(nn.Module): super().__init__() prefix = f"model.layers.{layer_id}" self.self_attn = Starcoder2Attention( - prefix=f"{prefix}.self_attn", config=config, weights=weights + prefix=f"{prefix}.self_attn", config=config, weights=weights, index=layer_id ) self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type]( - prefix=f"{prefix}.mlp", config=config, weights=weights + prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id ) self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load( @@ -379,11 +439,10 @@ class Starcoder2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, + adapter_data, + hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -394,11 +453,10 @@ class Starcoder2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, + adapter_data, + hpu_attention_meta, ) # faster post attention rms norm @@ -406,7 +464,7 @@ class Starcoder2Layer(nn.Module): attn_output, res ) - mlp_output = self.mlp(normed_attn_res_output) + mlp_output = self.mlp(normed_attn_res_output, adapter_data) return mlp_output, attn_res @@ -447,20 +505,16 @@ class Starcoder2Model(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - true_max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + adapter_data, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, true_max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -471,11 +525,10 @@ class Starcoder2Model(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - max_s, - prefill_cache_indices, + adapter_data, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -519,34 +572,22 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - true_max_s = max_s - if prefill_cache_indices is not None: - # Slots also need to be sliced as it has the same size as the whole kv tensor - slots = slots[prefill_cache_indices] - elif self.max_past is not None: - # Clamp in decode mode as paged attention requires clamped values whereas the flash attention - # kernel requires the true values - seqlen = seqlen.clamp(max=self.max_past_tensor) hidden_states = self.model( input_ids, position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - max_s, - true_max_s, - 
prefill_cache_indices, + adapter_data, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py index a829c3741..02806ac94 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py @@ -25,7 +25,7 @@ from transformers.activations import ACT2FN from text_generation_server.models.custom_modeling.vlm import ( load_text_model, ) -from text_generation_server.layers.attention import Seqlen +from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask from text_generation_server.layers import ( @@ -728,7 +728,8 @@ class Idefics2ForConditionalGeneration(nn.Module): ): """In place merges in vision_embeddings with inputs_embeds.""" # mask = input_ids == self.config.image_token_index - mask = input_ids == self.config.image_token_id + # replace `==` with torch.where to fix incorrect output when hpu graphs are used + mask = torch.where(input_ids == self.config.image_token_id) # Let's pray we have enabled enough slots ! inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) return inputs_embeds @@ -739,17 +740,16 @@ class Idefics2ForConditionalGeneration(nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, pixel_values: torch.FloatTensor = None, pixel_attention_mask: Optional[torch.BoolTensor] = None, # Unused here image_sizes: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, ): inputs_embeds = self.text_model.embed_tokens(input_ids) if pixel_values is not None: @@ -793,6 +793,7 @@ class Idefics2ForConditionalGeneration(nn.Module): ].contiguous() patch_size = self.config.vision_config.patch_size + """ patches_subgrid = pixel_attention_mask.unfold( dimension=1, size=patch_size, step=patch_size ) @@ -800,6 +801,21 @@ dimension=2, size=patch_size, step=patch_size ) patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + """ + # hpu does not support unfold + conv_kernel = torch.ones( + [1, 1, patch_size, patch_size], + dtype=pixel_values.dtype, + device=pixel_values.device, + ) + patches_subgrid = torch.nn.functional.conv2d( + pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype), + conv_kernel, + stride=patch_size, + ).squeeze(1) + patch_attention_mask = torch.eq( + patches_subgrid, (patch_size * patch_size) + ) # Get sequence from the vision encoder image_hidden_states = self.vision_model( @@ -825,12 +841,9 @@ position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, - block_tables=block_tables, slots=slots, seqlen=seqlen, - max_s=max_s, - true_max_s=max_s, - prefill_cache_indices=None, + hpu_attention_meta=hpu_attention_meta, adapter_data=adapter_data, ) if lm_head_indices is not None: diff --git
a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py new file mode 100644 index 000000000..964526fcf --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py @@ -0,0 +1,596 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Idefics3 model.""" + +from typing import List, Optional, Tuple + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +from text_generation_server.models.custom_modeling.vlm import ( + load_text_model, +) +from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask + +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, +) +from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Idefics3VisionEmbeddings(nn.Module): + """ + This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings` to enable images of variable + resolution. + + The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304) + which allows treating images in their native aspect ratio and without the need to resize them to the same + fixed size. In particular, we start from the original pre-trained SigLIP model + (which uses fixed-size square images) and adapt it by training on images of variable resolutions.
+ """ + + def __init__(self, prefix, config, weights): + super().__init__() + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + self.patch_embedding.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False + ) + self.patch_embedding.bias = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False + ) + + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = TensorParallelEmbedding( + prefix=f"{prefix}.position_embedding", weights=weights + ) + + def forward( + self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor + ) -> torch.Tensor: + batch_size, _, max_im_h, max_im_w = pixel_values.shape + + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + max_nb_patches_h, max_nb_patches_w = ( + max_im_h // self.patch_size, + max_im_w // self.patch_size, + ) + boundaries = torch.arange( + 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side + ) + position_ids = torch.full( + size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0 + ) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + + bucket_coords_h = torch.bucketize( + fractional_coords_h, boundaries, right=True + ) + bucket_coords_w = torch.bucketize( + fractional_coords_w, boundaries, right=True + ) + + pos_ids = ( + bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w + ).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + + position_ids = position_ids.to(self.position_embedding.weight.device) + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +class Idefics3VisionAttention(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_size = self.embed_dim // self.num_heads + if self.head_size * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.scale = self.head_size**-0.5 + self.dropout = config.attention_dropout + + self.num_heads = self.num_heads // weights.process_group.size() + self.embed_dim = self.embed_dim // weights.process_group.size() + + self.qkv = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=True, + ) + self.out_proj = TensorParallelRowLinear.load( + config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True + ) + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, q_len, _ = hidden_states.size() + + qkv = self.qkv(hidden_states) + query_states, key_states, value_states = qkv.split( + [ + self.head_size * self.num_heads, + self.head_size * self.num_heads, + self.head_size * self.num_heads, + ], + dim=2, + ) + + query_states = query_states.view( + batch_size, q_len, self.num_heads, self.head_size + ).transpose(1, 2) + key_states = key_states.view( + batch_size, q_len, self.num_heads, self.head_size + ).transpose(1, 2) + value_states = value_states.view( + batch_size, q_len, self.num_heads, self.head_size + ).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = ( + torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + ) + + if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): + raise ValueError( + f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output + + +class Idefics3VisionMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Idefics3EncoderLayer(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = Idefics3VisionAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.layer_norm1 = 
nn.LayerNorm.load( + prefix=f"{prefix}.layer_norm1", eps=config.layer_norm_eps, weights=weights + ) + self.layer_norm2 = nn.LayerNorm.load( + prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights + ) + self.mlp = Idefics3VisionMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights + ) + + # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Idefics3Encoder(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [ + Idefics3EncoderLayer( + prefix=f"{prefix}.layers.{i}", config=config, weights=weights + ) + for i in range(config.num_hidden_layers) + ] + ) + + # Ignore copy + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + ): + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer( + hidden_states, + attention_mask, + ) + return hidden_states + + +class Idefics3VisionTransformer(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.embeddings = Idefics3VisionEmbeddings( + prefix=f"{prefix}.embeddings", config=config, weights=weights + ) + self.encoder = Idefics3Encoder( + prefix=f"{prefix}.encoder", config=config, weights=weights + ) + self.post_layernorm = nn.LayerNorm.load( + prefix=f"{prefix}.post_layernorm", + weights=weights, + eps=config.layer_norm_eps, + ) + + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + ): + batch_size = pixel_values.size(0) + if patch_attention_mask is None: + patch_size = self.config.patch_size + patch_attention_mask = torch.ones( + ( + batch_size, + pixel_values.size(2) // patch_size, + pixel_values.size(3) // patch_size, + ) + ) + patch_attention_mask = patch_attention_mask.to( + dtype=torch.bool, device=pixel_values.device + ) + + hidden_states = self.embeddings( + pixel_values=pixel_values, patch_attention_mask=patch_attention_mask + ) + + patch_attention_mask = patch_attention_mask.view(batch_size, -1) + # The call to `_upad_input` in `_flash_attention_forward` is expensive + # So when the `patch_attention_mask` is full of 1s (i.e. 
attending to the whole sequence),
+        # we avoid passing the attention_mask, which is equivalent to attending to the full sequence
+        if not torch.any(~patch_attention_mask):
+            patch_attention_mask = None
+        else:
+            patch_attention_mask = _prepare_4d_attention_mask(
+                patch_attention_mask, hidden_states.dtype
+            )
+
+        encoder_outputs = self.encoder(
+            inputs_embeds=hidden_states,
+            attention_mask=patch_attention_mask,
+        )
+
+        last_hidden_state = encoder_outputs
+        last_hidden_state = self.post_layernorm(last_hidden_state)
+
+        return last_hidden_state
+
+
+class Idefics3SimpleMLP(nn.Module):
+    def __init__(self, prefix, config, weights):
+        super().__init__()
+        input_size = config.vision_config.hidden_size * (config.scale_factor**2)
+        output_size = config.text_config.hidden_size
+        proj = nn.Parameter(
+            weights.get_tensor(f"{prefix}.modality_projection.proj.weight"),
+            requires_grad=False,
+        ).to(weights.dtype)
+        self.proj = nn.Linear(input_size, output_size, bias=False)
+        self.proj.weight = proj
+
+    def forward(self, x):
+        return self.proj(x)
+
+
+class Idefics3Connector(nn.Module):
+    def __init__(self, prefix, config, weights):
+        super().__init__()
+        self.modality_projection = Idefics3SimpleMLP(prefix, config, weights)
+        self.scale_factor = config.scale_factor
+
+    def pixel_shuffle(self, x, scale_factor=2):
+        # Space-to-depth on the patch grid: fold each (scale_factor x scale_factor)
+        # block of patches into the channel dim, so the sequence length shrinks by
+        # scale_factor**2 while embed_dim grows by the same factor.
+        bsz, seq, embed_dim = x.size()
+        height = width = int(seq**0.5)
+        x = x.view(bsz, height, width, embed_dim)
+        x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor)
+        x = x.permute(0, 2, 1, 3)
+        x = x.reshape(
+            bsz,
+            int(width / scale_factor),
+            int(height / scale_factor),
+            embed_dim * (scale_factor**2),
+        )
+        x = x.permute(0, 2, 1, 3)
+        x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))
+        return x
+
+    def forward(self, image_hidden_states):
+        image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor)
+        image_hidden_states = self.modality_projection(image_hidden_states)
+        return image_hidden_states
+
+
+class Idefics3ForConditionalGeneration(nn.Module):
+    def __init__(self, prefix, config, weights):
+        super().__init__()
+        config.vision_config.quantize = None
+        config.vision_config.speculator = config.speculator
+        config.text_config.quantize = config.quantize
+        config.text_config.speculator = config.speculator
+        # set tie_word_embeddings to True to load `.embed_tokens.weight` instead of `.lm_head.weight`
+        # since Idefics3 uses the `embed_tokens` for the final prediction
+        # config.text_config.tie_word_embeddings = True
+
+        vision_config = config.vision_config
+        self.text_model = load_text_model(
+            prefix="model" if not prefix else f"{prefix}.model",
+            config=config.text_config,
+            weights=weights,
+            name="text_model",
+        )
+        self.dtype = weights.dtype
+
+        # The vision and connector models are not quantized.
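+        # Only the text model honours `config.quantize`; the loader below forces
+        # unquantized weights for the vision tower and the connector.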
+        with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):
+            self.vision_model = Idefics3VisionTransformer(
+                prefix=(
+                    f"{prefix}.model.vision_model" if prefix else "model.vision_model"
+                ),
+                config=vision_config,
+                weights=weights,
+            )
+
+            config.quantize = None
+            self.connector = Idefics3Connector(
+                prefix=f"{prefix}.model.connector" if prefix else "model.connector",
+                config=config,
+                weights=weights,
+            )
+
+        self.config = config
+        self.image_token_id = config.image_token_id
+        self.pad_token_id = (
+            config.pad_token_id if config.pad_token_id is not None else -1
+        )
+
+    def _merge_input_ids_with_image_features(
+        self,
+        input_ids: torch.Tensor,
+        inputs_embeds: torch.Tensor,
+        image_features: torch.Tensor,
+    ):
+        """In-place merge of vision embeddings into inputs_embeds."""
+        # mask = input_ids == self.config.image_token_id
+        # - boolean `==` masking is replaced with torch.where to fix incorrect
+        #   output when the model runs under HPU graphs
+        mask = torch.where(input_ids == self.config.image_token_id)
+        # Let's pray we have enabled enough slots!
+        inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
+        return inputs_embeds
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        slots: torch.Tensor,
+        seqlen: Seqlen,
+        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+        lm_head_indices: Optional[torch.Tensor] = None,
+        pixel_values: torch.FloatTensor = None,
+        pixel_attention_mask: Optional[torch.BoolTensor] = None,
+        # Unused here
+        image_sizes: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        cross_attention_states: Optional[torch.Tensor] = None,
+        image_indices=None,
+    ):
+        inputs_embeds = self.text_model.embed_tokens(input_ids)
+        if pixel_values is not None:
+            batch_size, num_images, num_channels, height, width = pixel_values.shape
+            all_states = []
+            all_pixel_values = pixel_values
+            all_pixel_mask = pixel_attention_mask
+            for i in range(batch_size):
+                pixel_values = all_pixel_values.to(
+                    dtype=self.dtype
+                )  # fp16 compatibility
+                pixel_values = pixel_values[i : i + 1]
+                pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
+
+                # Remove padding images - padding images are full 0.
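+                # An image is real iff its zero-value count differs from the
+                # per-image element count (padding images are exactly all zeros).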
+                nb_values_per_image = pixel_values.shape[1:].numel()
+                real_images_inds = (pixel_values == 0.0).sum(
+                    dim=(-1, -2, -3)
+                ) != nb_values_per_image
+                pixel_values = pixel_values[real_images_inds].contiguous()
+                # Handle the vision attention mask
+                if pixel_attention_mask is None:
+                    pixel_attention_mask = torch.ones(
+                        size=(
+                            pixel_values.size(0),
+                            pixel_values.size(2),
+                            pixel_values.size(3),
+                        ),
+                        dtype=torch.bool,
+                        device=pixel_values.device,
+                    )
+                else:
+                    # Remove padding images from the mask
+                    pixel_attention_mask = all_pixel_mask[i : i + 1]
+                    pixel_attention_mask = pixel_attention_mask.view(
+                        1 * num_images, *pixel_attention_mask.shape[2:]
+                    )
+                    pixel_attention_mask = pixel_attention_mask[
+                        real_images_inds
+                    ].contiguous()
+
+                patch_size = self.config.vision_config.patch_size
+                """
+                patches_subgrid = pixel_attention_mask.unfold(
+                    dimension=1, size=patch_size, step=patch_size
+                )
+                patches_subgrid = patches_subgrid.unfold(
+                    dimension=2, size=patch_size, step=patch_size
+                )
+                patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+                """
+                # HPU does not support unfold, so emulate the subgrid reduction
+                # above with a stride-`patch_size` conv2d: each output element
+                # counts the mask pixels inside one patch, and a patch is then
+                # kept when every pixel in it is set.
+                conv_kernel = torch.ones(
+                    [1, 1, patch_size, patch_size],
+                    dtype=pixel_values.dtype,
+                    device=pixel_values.device,
+                )
+                patches_subgrid = torch.nn.functional.conv2d(
+                    pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),
+                    conv_kernel,
+                    stride=patch_size,
+                ).squeeze(1)
+                patch_attention_mask = torch.eq(
+                    patches_subgrid, (patch_size * patch_size)
+                )
+
+                # Get sequence from the vision encoder
+                image_hidden_states = self.vision_model(
+                    pixel_values=pixel_values,
+                    patch_attention_mask=patch_attention_mask,
+                )
+
+                # Modality projection & resampling
+                image_hidden_states = self.connector(
+                    image_hidden_states,
+                )
+
+                all_states.append(image_hidden_states)
+            image_hidden_states = torch.stack(all_states, dim=0)
+
+            inputs_embeds = self._merge_input_ids_with_image_features(
+                input_ids, inputs_embeds, image_hidden_states
+            )
+
+        hidden_states = self.text_model.model(
+            inputs_embeds=inputs_embeds,
+            position_ids=position_ids,
+            cu_seqlen_prefill=cu_seqlen_prefill,
+            kv_cache=kv_cache,
+            slots=slots,
+            seqlen=seqlen,
+            hpu_attention_meta=hpu_attention_meta,
+            adapter_data=adapter_data,
+        )
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
+        logits, speculative_logits = self.text_model.lm_head(hidden_states)
+        return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_modeling.py
index fc6becc4b..a130dbc12 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_modeling.py
@@ -46,15 +46,9 @@ from text_generation_server.layers import (
     FastLinear,
 )
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
-from text_generation_server.utils.import_utils import SYSTEM
 from loguru import logger
 
-if SYSTEM == "cuda":
-    import dropout_layer_norm
-elif SYSTEM == "rocm":
-    from vllm._C import ops
-else:
-    dropout_layer_norm = None
+dropout_layer_norm = None
 
 
 @dataclass
@@ -351,94 +345,18 @@ class IdeficsRMSNorm(nn.Module):
         self.variance_epsilon = eps
 
     def forward(self, hidden_states, residual=None):
-        if SYSTEM == "ipex":
-            import intel_extension_for_pytorch as ipex
+        from vllm_hpu_extension.kernels import rms_norm
 
-            out = ipex.llm.functional.add_rms_norm(
-                residual,
- hidden_states, - self.weight, - None, - self.variance_epsilon, - residual is not None, - ) - return out - elif hidden_states.shape[-1] > 8192: - if residual is not None: - hidden_states += residual - residual = hidden_states - - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt( - variance + self.variance_epsilon - ) - - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - - return self.weight * hidden_states - elif SYSTEM == "cuda": - # faster post attention rms norm - unwrap = False - if len(hidden_states.shape) > 2: - unwrap = True - shape = hidden_states.shape - hidden_states = hidden_states.reshape(-1, shape[-1]) - - normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - residual, - self.weight, - None, - None, - None, - None, - None, - 0.0, - self.variance_epsilon, - 1.0, - 0, - None, - False, - True, # Activate RMSNorm - ) - if res is None: - res = hidden_states - - if unwrap: - normed_hidden_states = normed_hidden_states.view(*shape) - - return normed_hidden_states - elif SYSTEM == "rocm": - # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not. - if residual is not None: - hidden_states += residual - residual = hidden_states - - unwrap = False - if len(hidden_states.shape) > 2: - unwrap = True - shape = hidden_states.shape - hidden_states = hidden_states.reshape(-1, shape[-1]) - - out = torch.empty_like(hidden_states) - ops.rms_norm( - out, - hidden_states, - self.weight.data, - self.variance_epsilon, - ) - - if unwrap: - out = out.view(*shape) - - return out + orig_shape = hidden_states.shape + if residual is not None: + residual += hidden_states.view(residual.shape) else: - raise ValueError( - "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." 
- ) + residual = hidden_states + # Note: HPUFusedRMSNorm requires 3D tensors as inputs + if len(orig_shape) == 2: + residual = residual.unsqueeze(0) + x = rms_norm().apply(residual, self.weight, self.variance_epsilon) + return x.view(orig_shape), residual.view(orig_shape) # this was adapted from LlamaMLP diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/mamba_modeling.py index 293051c2b..5a9c05887 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/mamba_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/mamba_modeling.py @@ -196,7 +196,10 @@ class MambaModel(nn.Module): def __init__(self, config, weights): super().__init__() prefix = "backbone" - self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights) + try: + self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embeddings", weights) + except RuntimeError: + self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights) self.blocks = nn.ModuleList( [ ResidualBlock(f"{prefix}.layers.{i}", config, weights, layer_id=i) @@ -206,7 +209,10 @@ class MambaModel(nn.Module): self.norm_f = FastRMSNorm.load( f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon ) - self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights) + try: + self.lm_head = SpeculativeHead.load(config, f"{prefix}.embeddings", weights) + except RuntimeError: + self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights) self.config = config def forward( diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/mpt_modeling.py deleted file mode 100644 index 988a74a39..000000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/mpt_modeling.py +++ /dev/null @@ -1,1215 +0,0 @@ -"""A simple, flexible implementation of a GPT model. 
- -Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py -""" - -import math -import warnings -from typing import List, Optional, Tuple, Union -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from einops import rearrange -from packaging import version -from text_generation_server.layers import ( - TensorParallelEmbedding, - TensorParallelColumnLinear, - TensorParallelRowLinear, - SpeculativeHead, - get_linear, -) - -EPS = 1e-5 - - -def load_col(config, prefix, weights, bias): - assert config.quantize != "gptq", NotImplementedError - slice_ = weights._get_slice(f"{prefix}.weight") - rank = weights.process_group.rank() - size = weights.process_group.size() - - h3, h = slice_.get_shape() - block_size = h // size - - q_part = slice_[rank * block_size : (rank + 1) * block_size] - k_part = slice_[h + rank * block_size : h + (rank + 1) * block_size] - v_part = slice_[2 * h + rank * block_size : 2 * h + (rank + 1) * block_size] - - weight = torch.cat([q_part, k_part, v_part], dim=0) - if weight.dtype != torch.int32: - weight = weight.to(dtype=weights.dtype) - weight = weight.to(device=weights.device) - - if bias: - bias_slice_ = weights._get_slice(f"{prefix}.bias") - bias_rank = weights.process_group.rank() - bias_size = weights.process_group.size() - - bias_h = bias_slice_.get_shape() - bias_h = bias_h[0] - bias_block_size = bias_h // bias_size - - bias_q_part = bias_slice_[ - bias_rank * bias_block_size : (bias_rank + 1) * bias_block_size - ] - bias_k_part = bias_slice_[ - bias_h - + bias_rank * bias_block_size : bias_h - + (bias_rank + 1) * bias_block_size - ] - bias_v_part = bias_slice_[ - 2 * bias_h - + bias_rank * bias_block_size : 2 * bias_h - + (bias_rank + 1) * bias_block_size - ] - - bias = torch.cat([bias_q_part, bias_k_part, bias_v_part], dim=0) - if bias.dtype != torch.int32: - bias = bias.to(dtype=weights.dtype) - bias = bias.to(device=weights.device) - else: - bias = None - linear = get_linear(weight, bias) - return TensorParallelColumnLinear(linear) - - -def _reset_is_causal( - num_query_tokens: int, num_key_tokens: int, original_is_causal: bool -): - if original_is_causal and num_query_tokens != num_key_tokens: - if num_query_tokens != 1: - raise NotImplementedError( - "MPT does not support query and key with different number of tokens, unless number of query tokens is 1." 
- ) - else: - return False - return original_is_causal - - -def scaled_multihead_dot_product_attention( - query, - key, - value, - n_heads, - past_key_value=None, - softmax_scale=None, - attn_bias=None, - key_padding_mask=None, - is_causal=False, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, -): - q = rearrange(query, "b s (h d) -> b h s d", h=n_heads) - kv_n_heads = 1 if multiquery else n_heads - k = rearrange(key, "b s (h d) -> b h d s", h=kv_n_heads) - v = rearrange(value, "b s (h d) -> b h s d", h=kv_n_heads) - if past_key_value is not None: - if len(past_key_value) != 0: - k = torch.cat([past_key_value[0], k], dim=3) - v = torch.cat([past_key_value[1], v], dim=2) - past_key_value = (k, v) - (b, _, s_q, d) = q.shape - s_k = k.size(-1) - attn_weight = q.matmul(k) * softmax_scale - if attn_bias is not None: - _s_q = max(0, attn_bias.size(2) - s_q) - _s_k = max(0, attn_bias.size(3) - s_k) - attn_bias = attn_bias[:, :, _s_q:, _s_k:] - if ( - attn_bias.size(-1) != 1 - and attn_bias.size(-1) != s_k - or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q) - ): - raise RuntimeError( - f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}." - ) - attn_weight = attn_weight + attn_bias - min_val = torch.finfo(q.dtype).min - if key_padding_mask is not None: - if attn_bias is not None: - warnings.warn( - "Propogating key_padding_mask to the attention module " - + "and applying it within the attention module can cause " - + "unneccessary computation/memory usage. Consider integrating " - + "into attn_bias once and passing that to each attention " - + "module instead." - ) - attn_weight = attn_weight.masked_fill( - ~key_padding_mask.view((b, 1, 1, s_k)), min_val - ) - if is_causal and (not q.size(2) == 1): - s = max(s_q, s_k) - causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16) - causal_mask = causal_mask.tril() - causal_mask = causal_mask.to(torch.bool) - causal_mask = ~causal_mask - causal_mask = causal_mask[-s_q:, -s_k:] - attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val) - attn_weight = torch.softmax(attn_weight, dim=-1) - if dropout_p: - attn_weight = torch.nn.functional.dropout( - attn_weight, p=dropout_p, training=training, inplace=True - ) - out = attn_weight.to(v.dtype).matmul(v) - out = rearrange(out, "b h s d -> b s (h d)") - if needs_weights: - return (out, attn_weight, past_key_value) - return (out, None, past_key_value) - - -def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]): - for tensor in tensors: - if tensor.dtype not in valid_dtypes: - raise TypeError( - f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}." - ) - if not tensor.is_cuda: - raise TypeError( - f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r})." 
- ) - - -def flash_attn_fn( - query, - key, - value, - n_heads, - past_key_value=None, - softmax_scale=None, - attn_bias=None, - key_padding_mask=None, - is_causal=False, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, -): - try: - from flash_attn import bert_padding, flash_attn_interface - except Exception: - raise RuntimeError("Please install flash-attn==1.0.3.post0") - check_valid_inputs(query, key, value) - if past_key_value is not None: - if len(past_key_value) != 0: - key = torch.cat([past_key_value[0], key], dim=1) - value = torch.cat([past_key_value[1], value], dim=1) - past_key_value = (key, value) - if attn_bias is not None: - _s_q = max(0, attn_bias.size(2) - query.size(1)) - _s_k = max(0, attn_bias.size(3) - key.size(1)) - attn_bias = attn_bias[:, :, _s_q:, _s_k:] - if attn_bias is not None: - raise NotImplementedError("attn_bias not implemented for flash attn.") - (batch_size, seqlen) = query.shape[:2] - if key_padding_mask is None: - key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) - query_padding_mask = key_padding_mask[:, -query.size(1) :] - (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input( - query, query_padding_mask - ) - query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads) - (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input( - key, key_padding_mask - ) - key_unpad = rearrange( - key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads - ) - (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask) - value_unpad = rearrange( - value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads - ) - if multiquery: - key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1)) - value_unpad = value_unpad.expand( - value_unpad.size(0), n_heads, value_unpad.size(-1) - ) - dropout_p = dropout_p if training else 0.0 - reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) - output_unpad = flash_attn_interface.flash_attn_unpadded_func( - query_unpad, - key_unpad, - value_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale=softmax_scale, - causal=reset_is_causal, - return_attn_probs=needs_weights, - ) - output = bert_padding.pad_input( - rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen - ) - return (output, None, past_key_value) - - -def triton_flash_attn_fn( - query, - key, - value, - n_heads, - past_key_value=None, - softmax_scale=None, - attn_bias=None, - key_padding_mask=None, - is_causal=False, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, -): - try: - from .flash_attn_triton import flash_attn_func - except Exception: - _installed = False - if version.parse(torch.__version__) < version.parse("2.0.0"): - _installed = True - try: - from flash_attn.flash_attn_triton import flash_attn_func - except Exception: - _installed = False - if not _installed: - raise RuntimeError( - "Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed." 
- ) - check_valid_inputs(query, key, value) - if past_key_value is not None: - if len(past_key_value) != 0: - key = torch.cat([past_key_value[0], key], dim=1) - value = torch.cat([past_key_value[1], value], dim=1) - past_key_value = (key, value) - if attn_bias is not None: - _s_q = max(0, attn_bias.size(2) - query.size(1)) - _s_k = max(0, attn_bias.size(3) - key.size(1)) - attn_bias = attn_bias[:, :, _s_q:, _s_k:] - if dropout_p: - raise NotImplementedError("Dropout not implemented for attn_impl: triton.") - if needs_weights: - raise NotImplementedError("attn_impl: triton cannot return attn weights.") - if key_padding_mask is not None: - warnings.warn( - "Propagating key_padding_mask to the attention module " - + "and applying it within the attention module can cause " - + "unnecessary computation/memory usage. Consider integrating " - + "into attn_bias once and passing that to each attention " - + "module instead." - ) - (b_size, s_k) = key_padding_mask.shape[:2] - if attn_bias is None: - attn_bias = query.new_zeros(b_size, 1, 1, s_k) - attn_bias = attn_bias.masked_fill( - ~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min - ) - query = rearrange(query, "b s (h d) -> b s h d", h=n_heads) - key = rearrange(key, "b s (h d) -> b s h d", h=1 if multiquery else n_heads) - value = rearrange(value, "b s (h d) -> b s h d", h=1 if multiquery else n_heads) - if multiquery: - key = key.expand(*key.shape[:2], n_heads, key.size(-1)) - value = value.expand(*value.shape[:2], n_heads, value.size(-1)) - reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) - attn_output = flash_attn_func( - query, key, value, attn_bias, reset_is_causal, softmax_scale - ) - output = attn_output.view(*attn_output.shape[:2], -1) - return (output, None, past_key_value) - - -class MultiheadAttention(nn.Module): - """Multi-head self attention. - - Using torch or triton attention implementation enables user to also use - additive bias. 
- """ - - def __init__( - self, - config, - prefix, - weights, - ): - super().__init__() - attn_impl = config.attn_config.attn_impl - self.attn_impl = config.attn_config.attn_impl - self.clip_qkv = config.attn_config.clip_qkv - self.qk_ln = config.attn_config.qk_ln - self.d_model = config.d_model - d_model = config.d_model - self.n_heads = config.n_heads - self.softmax_scale = config.attn_config.softmax_scale - if self.softmax_scale is None: - self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) - self.attn_dropout_p = config.attn_config.attn_pdrop - - if self.n_heads % weights.process_group.size() != 0: - raise ValueError( - f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} " - f"and `num_shards`: {weights.process_group.size()}" - ) - self.n_heads = self.n_heads // weights.process_group.size() - self.Wqkv = load_col( - config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias - ) - if self.qk_ln: - bias = not config.no_bias - hidden_size = config.d_model - head_dim = hidden_size // self.n_heads - - self.q_ln = LPLayerNorm( - d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights - ) - self.k_ln = LPLayerNorm( - self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights - ) - if self.attn_impl == "flash": - self.attn_fn = flash_attn_fn - elif self.attn_impl == "triton": - self.attn_fn = triton_flash_attn_fn - elif self.attn_impl == "torch": - self.attn_fn = scaled_multihead_dot_product_attention - else: - raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") - self.out_proj = TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.out_proj", - weights=weights, - bias=not config.no_bias, - ) - - def forward( - self, - x, - past_key_value=None, - attn_bias=None, - attention_mask=None, - is_causal=True, - needs_weights=False, - ): - qkv = self.Wqkv(x) - if self.clip_qkv: - qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) - (query, key, value) = qkv.chunk(3, dim=2) - - key_padding_mask = attention_mask - if self.qk_ln: - dtype = query.dtype - query = self.q_ln(query).to(dtype) - key = self.k_ln(key).to(dtype) - (context, attn_weights, past_key_value) = self.attn_fn( - query, - key, - value, - self.n_heads, - past_key_value=past_key_value, - softmax_scale=self.softmax_scale, - attn_bias=attn_bias, - key_padding_mask=key_padding_mask, - is_causal=is_causal, - dropout_p=self.attn_dropout_p, - training=self.training, - needs_weights=needs_weights, - ) - out = self.out_proj(context) - return (out, attn_weights, past_key_value) - - -class MultiQueryAttention(nn.Module): - """Multi-Query self attention. - - Using torch or triton attention implementation enables user to also use - additive bias. 
- """ - - def __init__(self, config, prefix, weights, verbose=False): - super().__init__() - attn_impl = config.attn_config.attn_impl - self.attn_impl = config.attn_config.attn_impl - self.clip_qkv = config.attn_config.clip_qkv - self.qk_ln = config.attn_config.qk_ln - self.d_model = config.d_model - d_model = config.d_model - self.n_heads = config.n_heads - self.softmax_scale = config.attn_config.softmax_scale - if self.softmax_scale is None: - self.softmax_scale = 1 / math.sqrt(self.head_dim) - self.attn_dropout_p = config.attn_config.attn_pdrop - # self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device) - self.Wqkv = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias - ) - (d_model, d_model + self.head_dim) - if self.qk_ln: - raise NotImplementedError("qk_ln not supported") - if self.attn_impl == "flash": - self.attn_fn = flash_attn_fn - elif self.attn_impl == "triton": - self.attn_fn = triton_flash_attn_fn - if verbose: - warnings.warn( - "While `attn_impl: triton` can be faster than `attn_impl: flash` " - + "it uses more memory. When training larger models this can trigger " - + "alloc retries which hurts performance. If encountered, we recommend " - + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`." - ) - elif self.attn_impl == "torch": - self.attn_fn = scaled_multihead_dot_product_attention - if torch.cuda.is_available() and verbose: - warnings.warn( - "Using `attn_impl: torch`. If your model does not use `alibi` or " - + "`prefix_lm` we recommend using `attn_impl: flash` otherwise " - + "we recommend using `attn_impl: triton`." - ) - else: - raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") - self.out_proj = TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.out_proj", - weights=weights, - bias=not config.no_bias, - ) - # self.out_proj._is_residual = True - - def forward( - self, - x, - past_key_value=None, - attn_bias=None, - attention_mask=None, - is_causal=True, - needs_weights=False, - ): - qkv = self.Wqkv(x) - if self.clip_qkv: - qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) - (query, key, value) = qkv.split( - [self.d_model, self.head_dim, self.head_dim], dim=2 - ) - key_padding_mask = attention_mask - if self.qk_ln: - dtype = query.dtype - query = self.q_ln(query).to(dtype) - key = self.k_ln(key).to(dtype) - (context, attn_weights, past_key_value) = self.attn_fn( - query, - key, - value, - self.n_heads, - past_key_value=past_key_value, - softmax_scale=self.softmax_scale, - attn_bias=attn_bias, - key_padding_mask=key_padding_mask, - is_causal=is_causal, - dropout_p=self.attn_dropout_p, - training=self.training, - needs_weights=needs_weights, - multiquery=True, - ) - return (self.out_proj(context), attn_weights, past_key_value) - - -def attn_bias_shape( - attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id -): - if attn_impl == "flash": - return None - elif attn_impl in ["torch", "triton"]: - if alibi: - if (prefix_lm or not causal) or use_sequence_id: - return (1, n_heads, seq_len, seq_len) - return (1, n_heads, 1, seq_len) - elif prefix_lm or use_sequence_id: - return (1, 1, seq_len, seq_len) - return None - else: - raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") - - -def build_attn_bias( - attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8 -): - if attn_impl == "flash": - return None - elif attn_impl in ["torch", "triton"]: - if alibi: - (device, dtype) = 
(attn_bias.device, attn_bias.dtype) - attn_bias = attn_bias.add( - build_alibi_bias( - n_heads, - seq_len, - full=not causal, - alibi_bias_max=alibi_bias_max, - device=device, - dtype=dtype, - ) - ) - return attn_bias - else: - raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") - - -def gen_slopes(n_heads, alibi_bias_max=8, device=None): - _n_heads = 2 ** math.ceil(math.log2(n_heads)) - m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device) - m = m.mul(alibi_bias_max / _n_heads) - slopes = 1.0 / torch.pow(2, m) - if _n_heads != n_heads: - slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads] - return slopes.view(1, n_heads, 1, 1) - - -def build_alibi_bias( - n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None -): - alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view( - 1, 1, 1, seq_len - ) - if full: - alibi_bias = alibi_bias - torch.arange( - 1 - seq_len, 1, dtype=torch.int32, device=device - ).view(1, 1, seq_len, 1) - alibi_bias = alibi_bias.abs().mul(-1) - slopes = gen_slopes(n_heads, alibi_bias_max, device=device) - alibi_bias = alibi_bias * slopes - return alibi_bias.to(dtype=dtype) - - -ATTN_CLASS_REGISTRY = { - "multihead_attention": MultiheadAttention, - "multiquery_attention": MultiQueryAttention, -} - -"""GPT Blocks used for the GPT Model.""" - - -class MPTMLP(nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - # self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) - self.up_proj = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.up_proj", weights=weights, bias=not config.no_bias - ) - self.act = nn.GELU(approximate="none") - # self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) - self.down_proj = TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.down_proj", - weights=weights, - bias=not config.no_bias, - ) - # self.down_proj._is_residual = True - - def forward(self, x): - return self.down_proj(self.act(self.up_proj(x))) - - -class MPTBlock(nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - self.prefix = prefix - if config.attn_config.attn_type != "multihead_attention": - raise NotImplementedError( - f"""Not implemented attn {config.attn_config.attn_type}""" - ) - resid_pdrop = config.resid_pdrop - if config.no_bias: - self.norm_1 = nn.LayerNorm.load_no_bias( - prefix=f"{prefix}.norm_1", weights=weights, eps=EPS - ) - self.norm_2 = nn.LayerNorm.load_no_bias( - prefix=f"{prefix}.norm_2", weights=weights, eps=EPS - ) - else: - self.norm_1 = nn.LayerNorm.load( - prefix=f"{prefix}.norm_1", weights=weights, eps=EPS - ) - self.norm_2 = nn.LayerNorm.load( - prefix=f"{prefix}.norm_2", weights=weights, eps=EPS - ) - self.attn = MultiheadAttention(config, prefix=f"{prefix}.attn", weights=weights) - self.ffn = MPTMLP(config, prefix=f"{prefix}.ffn", weights=weights) - self.resid_attn_dropout = nn.Dropout(resid_pdrop) - self.resid_ffn_dropout = nn.Dropout(resid_pdrop) - - def forward( - self, - x: torch.Tensor, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attn_bias: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.ByteTensor] = None, - is_causal: bool = True, - ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: - a = self.norm_1(x) - (b, attn_weights, past_key_value) = self.attn( - a, - past_key_value=past_key_value, - attn_bias=attn_bias, - attention_mask=attention_mask, - is_causal=is_causal, - ) - x = x + self.resid_attn_dropout(b) - m = 
self.norm_2(x) - n = self.ffn(m) - x = x + self.resid_ffn_dropout(n) - return (x, attn_weights, past_key_value) - - -def _cast_if_autocast_enabled(tensor): - if torch.is_autocast_enabled(): - if tensor.device.type == "cuda": - dtype = torch.get_autocast_gpu_dtype() - elif tensor.device.type == "cpu": - dtype = torch.get_autocast_cpu_dtype() - else: - raise NotImplementedError() - return tensor.to(dtype=dtype) - return tensor - - -class LPLayerNorm(torch.nn.LayerNorm): - def __init__( - self, - normalized_shape, - eps=1e-05, - elementwise_affine=True, - device=None, - dtype=None, - bias: Optional[bool] = True, - prefix=None, - weights=None, - ): - super().__init__( - normalized_shape=normalized_shape, - eps=eps, - elementwise_affine=elementwise_affine, - device=device, - dtype=dtype, - bias=bias, - ) - if weights is not None: - self.weight = nn.Parameter(weights.get_sharded(f"{prefix}.weight", dim=0)) - if bias: - self.bias = nn.Parameter(weights.get_sharded(f"{prefix}.bias", dim=0)) - self.normalized_shape = self.weight.shape - - def forward(self, x): - module_device = x.device - downcast_x = _cast_if_autocast_enabled(x) - downcast_weight = ( - _cast_if_autocast_enabled(self.weight) - if self.weight is not None - else self.weight - ) - downcast_bias = ( - _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias - ) - with torch.autocast(enabled=False, device_type=module_device.type): - return torch.nn.functional.layer_norm( - downcast_x, - self.normalized_shape, - downcast_weight, - downcast_bias, - self.eps, - ) - - -def rms_norm(x, weight=None, eps=1e-05): - output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) - if weight is not None: - return output * weight - return output - - -class RMSNorm(torch.nn.Module): - def __init__( - self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None - ): - super().__init__() - self.eps = eps - if weight: - self.weight = torch.nn.Parameter( - torch.ones(normalized_shape, dtype=dtype, device=device) - ) - else: - self.register_parameter("weight", None) - - def forward(self, x): - return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) - - -class LPRMSNorm(RMSNorm): - def __init__( - self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None - ): - super().__init__( - normalized_shape=normalized_shape, - eps=eps, - weight=weight, - dtype=dtype, - device=device, - ) - - def forward(self, x): - downcast_x = _cast_if_autocast_enabled(x) - downcast_weight = ( - _cast_if_autocast_enabled(self.weight) - if self.weight is not None - else self.weight - ) - with torch.autocast(enabled=False, device_type=x.device.type): - return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) - - -NORM_CLASS_REGISTRY = { - "layernorm": torch.nn.LayerNorm, - "low_precision_layernorm": LPLayerNorm, - "rmsnorm": RMSNorm, - "low_precision_rmsnorm": LPRMSNorm, -} - -Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] - - -class MPTPreTrainedModel(PreTrainedModel): - base_model_prefix = "model" - _no_split_modules = ["MPTBlock"] - - -class MPTModel(MPTPreTrainedModel): - def __init__(self, prefix: str, config, weights): - # config._validate_config() - super().__init__(config) - self.world_size = weights.process_group.size() - self.rank = weights.process_group.rank() - self.n_heads = config.n_heads - self.attn_impl = config.attn_config.attn_impl - self.prefix_lm = config.attn_config.prefix_lm - self.attn_uses_sequence_id = config.attn_config.attn_uses_sequence_id - self.alibi = 
config.attn_config.alibi - self.alibi_bias_max = config.attn_config.alibi_bias_max - if config.init_device == "mixed": - # TODO: reimplement mixed device initialization - # dist.get_local_rank() == 0: - if True: - config.init_device = "cpu" - else: - config.init_device = "meta" - if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys(): - norm_options = " | ".join(NORM_CLASS_REGISTRY.keys()) - raise NotImplementedError( - f"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options})." - ) - if config.norm_type.lower() != "low_precision_layernorm": - raise NotImplementedError( - f"Requested norm type ({config.norm_type}) is not implemented within this repo." - ) - - self.wte = TensorParallelEmbedding(f"{prefix}.wte", weights) - - if not self.alibi: - self.wpe = TensorParallelEmbedding(f"{prefix}.wpe", weights) - self.blocks = nn.ModuleList( - [ - MPTBlock(config, prefix=f"{prefix}.blocks.{i}", weights=weights) - for i in range(config.n_layers) - ] - ) - if config.no_bias: - self.norm_f = nn.LayerNorm.load_no_bias( - prefix="transformer.norm_f", weights=weights, eps=EPS - ) - else: - self.norm_f = nn.LayerNorm.load( - prefix="transformer.norm_f", weights=weights, eps=EPS - ) - self.is_causal = not self.prefix_lm - self._attn_bias_initialized = False - self.attn_bias = None - self.attn_bias_shape = attn_bias_shape( - self.attn_impl, - config.n_heads, - config.max_seq_len, - self.alibi, - prefix_lm=self.prefix_lm, - causal=self.is_causal, - use_sequence_id=self.attn_uses_sequence_id, - ) - if config.no_bias: - for module in self.modules(): - if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter): - if config.verbose: - warnings.warn(f"Removing bias ({module.bias}) from {module}.") - module.register_parameter("bias", None) - if hasattr(self.config, "verbose"): - if config.verbose and config.verbose > 2: - print(self) - if "verbose" not in self.config.init_config: - self.config.init_config["verbose"] = self.config.verbose - if self.config.init_config["verbose"] > 1: - init_fn_name = self.config.init_config["name"] - warnings.warn(f"Using {init_fn_name} initialization.") - - @torch.no_grad() - def _attn_bias( - self, - device, - dtype, - attention_mask: Optional[torch.ByteTensor] = None, - prefix_mask: Optional[torch.ByteTensor] = None, - sequence_id: Optional[torch.LongTensor] = None, - ): - if not self._attn_bias_initialized: - if self.attn_bias_shape: - self.attn_bias = torch.zeros( - self.attn_bias_shape, device=device, dtype=dtype - ) - self.attn_bias = build_attn_bias( - self.attn_impl, - self.attn_bias, - self.config.n_heads, - self.config.max_seq_len, - causal=self.is_causal, - alibi=self.alibi, - alibi_bias_max=self.alibi_bias_max, - ) - assert self.n_heads % self.world_size == 0 - block_size = self.n_heads // self.world_size - self.attn_bias = self.attn_bias[ - :, self.rank * block_size : (self.rank + 1) * block_size - ] - self._attn_bias_initialized = True - if self.attn_impl == "flash": - return (self.attn_bias, attention_mask) - if self.attn_bias is not None: - self.attn_bias = self.attn_bias.to(dtype=dtype, device=device) - attn_bias = self.attn_bias - if self.prefix_lm: - assert isinstance(attn_bias, torch.Tensor) - assert isinstance(prefix_mask, torch.Tensor) - attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask) - if self.attn_uses_sequence_id and sequence_id is not None: - assert isinstance(attn_bias, torch.Tensor) - attn_bias = self._apply_sequence_id(attn_bias, sequence_id) - if attention_mask is not None: - 
s_k = attention_mask.shape[-1] - if attn_bias is None: - attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype) - else: - _s_k = max(0, attn_bias.size(-1) - s_k) - attn_bias = attn_bias[:, :, :, _s_k:] - if prefix_mask is not None and attention_mask.shape != prefix_mask.shape: - raise ValueError( - f"attention_mask shape={attention_mask.shape} " - + f"and prefix_mask shape={prefix_mask.shape} are not equal." - ) - min_val = torch.finfo(attn_bias.dtype).min - attn_bias = attn_bias.masked_fill( - ~attention_mask.view(-1, 1, 1, s_k), min_val - ) - return (attn_bias, None) - - def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor): - (s_k, s_q) = attn_bias.shape[-2:] - if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len: - raise ValueError( - "attn_bias does not match the expected shape. " - + f"The last two dimensions should both be {self.config.max_length} " - + f"but are {s_k} and {s_q}." - ) - seq_len = prefix_mask.shape[-1] - if seq_len > self.config.max_seq_len: - raise ValueError( - f"prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}" - ) - attn_bias = attn_bias[..., :seq_len, :seq_len] - causal = torch.tril( - torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device) - ).view(1, 1, seq_len, seq_len) - prefix = prefix_mask.view(-1, 1, 1, seq_len) - cannot_attend = ~torch.logical_or(causal, prefix.bool()) - min_val = torch.finfo(attn_bias.dtype).min - attn_bias = attn_bias.masked_fill(cannot_attend, min_val) - return attn_bias - - def _apply_sequence_id( - self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor - ): - seq_len = sequence_id.shape[-1] - if seq_len > self.config.max_seq_len: - raise ValueError( - f"sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}" - ) - attn_bias = attn_bias[..., :seq_len, :seq_len] - cannot_attend = torch.logical_not( - torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len)) - ).unsqueeze(1) - min_val = torch.finfo(attn_bias.dtype).min - attn_bias = attn_bias.masked_fill(cannot_attend, min_val) - return attn_bias - - def forward( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, - attention_mask: Optional[torch.ByteTensor] = None, - prefix_mask: Optional[torch.ByteTensor] = None, - sequence_id: Optional[torch.LongTensor] = None, - return_dict: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - use_cache: Optional[bool] = None, - ): - return_dict = ( - return_dict if return_dict is not None else self.config.return_dict - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - if attention_mask is not None: - attention_mask = attention_mask.bool() - if prefix_mask is not None: - prefix_mask = prefix_mask.bool() - if not return_dict: - raise NotImplementedError( - "return_dict False is not implemented yet for MPT" - ) - if output_attentions: - if self.attn_impl != "torch": - raise NotImplementedError( - "output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`." - ) - if ( - attention_mask is not None - and attention_mask[:, 0].sum() != attention_mask.shape[0] - and self.training - ): - raise NotImplementedError( - "MPT does not support training with left padding." - ) - if self.prefix_lm and prefix_mask is None: - raise ValueError( - "prefix_mask is a required argument when MPT is configured with prefix_lm=True." 
- ) - if self.training: - if self.attn_uses_sequence_id and sequence_id is None: - raise ValueError( - "sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True " - + "and the model is in train mode." - ) - elif self.attn_uses_sequence_id is False and sequence_id is not None: - warnings.warn( - "MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. " - + "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True." - ) - S = input_ids.size(1) - assert ( - S <= self.config.max_seq_len - ), f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}" - tok_emb = self.wte(input_ids) - if self.alibi: - x = tok_emb - else: - past_position = 0 - if past_key_values is not None: - if len(past_key_values) != self.config.n_layers: - raise ValueError( - "past_key_values must provide a past_key_value for each attention " - + f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})." - ) - past_position = past_key_values[0][0].size(1) - if self.attn_impl == "torch": - past_position = past_key_values[0][0].size(3) - if S + past_position > self.config.max_seq_len: - raise ValueError( - f"Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}." - ) - pos = torch.arange( - past_position, - S + past_position, - dtype=torch.long, - device=input_ids.device, - ).unsqueeze(0) - if attention_mask is not None: - pos = torch.clamp( - pos - - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[ - :, past_position: - ], - min=0, - ) - pos_emb = self.wpe(pos) - x = tok_emb + pos_emb - (attn_bias, attention_mask) = self._attn_bias( - device=x.device, - dtype=torch.float32, - attention_mask=attention_mask, - prefix_mask=prefix_mask, - sequence_id=sequence_id, - ) - if use_cache and past_key_values is None: - past_key_values = [() for _ in range(self.config.n_layers)] - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - for b_idx, block in enumerate(self.blocks): - if output_hidden_states: - assert all_hidden_states is not None - all_hidden_states = all_hidden_states + (x,) - past_key_value = ( - past_key_values[b_idx] if past_key_values is not None else None - ) - (x, attn_weights, past_key_value) = block( - x, - past_key_value=past_key_value, - attn_bias=attn_bias, - attention_mask=attention_mask, - is_causal=self.is_causal, - ) - if past_key_values is not None: - past_key_values[b_idx] = past_key_value - if output_attentions: - assert all_self_attns is not None - all_self_attns = all_self_attns + (attn_weights,) - x = self.norm_f(x) - if output_hidden_states: - assert all_hidden_states is not None - all_hidden_states = all_hidden_states + (x,) - return BaseModelOutputWithPast( - last_hidden_state=x, - past_key_values=past_key_values, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class MPTForCausalLM(MPTPreTrainedModel): - def __init__(self, prefix: str, config, weights): - super().__init__(config) - - if not prefix: - prefix = "transformer" - else: - prefix = f"{prefix}.transformer" - - if not config.tie_word_embeddings: - raise ValueError("MPTForCausalLM only supports tied word embeddings") - self.transformer = MPTModel(prefix, config, weights) - self.lm_head = SpeculativeHead.load( - 
config, prefix=f"{prefix}.wte", weights=weights - ) - self.logit_scale = None - if config.logit_scale is not None: - logit_scale = config.logit_scale - if isinstance(logit_scale, str): - if logit_scale == "inv_sqrt_d_model": - logit_scale = 1 / math.sqrt(config.d_model) - else: - raise ValueError( - f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'." - ) - self.logit_scale = logit_scale - - def forward( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, - attention_mask: Optional[torch.ByteTensor] = None, - prefix_mask: Optional[torch.ByteTensor] = None, - sequence_id: Optional[torch.LongTensor] = None, - labels: Optional[torch.LongTensor] = None, - return_dict: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - use_cache: Optional[bool] = None, - ): - return_dict = ( - return_dict if return_dict is not None else self.config.return_dict - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - outputs = self.transformer( - input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - prefix_mask=prefix_mask, - sequence_id=sequence_id, - return_dict=return_dict, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - use_cache=use_cache, - ) - logits, speculative_logits = self.lm_head(outputs.last_hidden_state) - if self.logit_scale is not None: - if self.logit_scale == 0: - warnings.warn( - f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs." - ) - logits *= self.logit_scale - loss = None - if labels is not None: - labels = torch.roll(labels, shifts=-1) - labels[:, -1] = -100 - loss = F.cross_entropy( - logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1) - ) - return ( - CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ), - speculative_logits, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs - ): - if inputs_embeds is not None: - raise NotImplementedError("inputs_embeds is not implemented for MPT yet") - attention_mask = kwargs["attention_mask"].bool() - if attention_mask[:, -1].sum() != attention_mask.shape[0]: - raise NotImplementedError( - "MPT does not support generation with right padding." - ) - if self.transformer.attn_uses_sequence_id and self.training: - sequence_id = torch.zeros_like(input_ids[:1]) - else: - sequence_id = None - if past_key_values is not None: - input_ids = input_ids[:, -1].unsqueeze(-1) - if self.transformer.prefix_lm: - prefix_mask = torch.ones_like(attention_mask) - if kwargs.get("use_cache") is False: - raise NotImplementedError( - "MPT with prefix_lm=True does not support use_cache=False." - ) - else: - prefix_mask = None - return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "prefix_mask": prefix_mask, - "sequence_id": sequence_id, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache", True), - } - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - """Used by HuggingFace generate when using beam search with kv-caching. 
- - See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133 - for an example in transformers. - """ - reordered_past = [] - for layer_past in past_key_values: - reordered_past += [ - tuple( - (past_state.index_select(0, beam_idx) for past_state in layer_past) - ) - ] - return reordered_past diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/neox_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/neox_modeling.py deleted file mode 100644 index 06731a6f9..000000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/neox_modeling.py +++ /dev/null @@ -1,796 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" PyTorch GPTNeoX model.""" - -from typing import Optional, Tuple, Union - -import os -import torch -import torch.distributed -import torch.utils.checkpoint -from torch import nn -from torch.nn import CrossEntropyLoss - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from transformers.modeling_utils import PreTrainedModel -from text_generation_server.layers import ( - TensorParallelColumnLinear, - TensorParallelEmbedding, - TensorParallelRowLinear, - SpeculativeHead, -) - - -CUSTOM_KERNELS_ENABLED = False -if ( - torch.cuda.is_available() - and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True" -): - try: - from custom_kernels import fused_attention_cuda - - CUSTOM_KERNELS_ENABLED = True - except ImportError: - pass - - -def make_causal_mask( - input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int -) -> torch.BoolTensor: - """ - Make causal mask used for self-attention. - """ - batch_size, target_length = input_ids_shape - mask = torch.ones( - (target_length, target_length + past_key_values_length), - dtype=torch.bool, - device=device, - ) - mask = mask.triu(1 + past_key_values_length) - - expanded_mask = mask.unsqueeze(0).expand( - batch_size, target_length, target_length + past_key_values_length - ) - return expanded_mask - - -def expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor: - """ - Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`. 
- """ - batch_size, src_length = mask.shape - tgt_length = tgt_length if tgt_length is not None else src_length - - expanded_mask = ~(mask[:, None, :].to(torch.bool)) - return expanded_mask.expand(batch_size, tgt_length, src_length) - - -def prepare_attn_mask( - attention_mask: torch.Tensor, - input_shape: Tuple[int, int], - past_key_values_length: int, -) -> torch.BoolTensor: - # create causal mask - # [batch_size, seq_length] -> [batch_size, tgt_length, src_length] - combined_attention_mask = None - device = attention_mask.device - _, src_length = input_shape - - if src_length > 1: - combined_attention_mask = make_causal_mask( - input_shape, device=device, past_key_values_length=past_key_values_length - ) - - # [batch_size, seq_length] -> [batch_size, tgt_length, src_length] - expanded_attn_mask = expand_mask(attention_mask, tgt_length=src_length) - combined_attention_mask = ( - expanded_attn_mask - if combined_attention_mask is None - else expanded_attn_mask | combined_attention_mask - ) - - return combined_attention_mask - - -class GPTNeoXPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - -class GPTNeoXAttention(nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - self.num_attention_heads = config.num_attention_heads - self.hidden_size = config.hidden_size - self.head_size = self.hidden_size // self.num_attention_heads - self.rotary_ndims = int(self.head_size * config.rotary_pct) - # ??? TODO - # self.register_buffer( - # "bias", - # torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( - # 1, 1, max_positions, max_positions - # ), - # ) - # self.register_buffer("masked_bias", torch.tensor(-1e9)) - self.rotary_emb = RotaryEmbedding( - self.rotary_ndims, - config.max_position_embeddings, - base=config.rotary_emb_base, - ) - self.rotary_emb.inv_freq = nn.Parameter( - weights.get_tensor(f"{prefix}.rotary_emb.inv_freq") - ) - self.inv_norm_factor = 1.0 / torch.sqrt( - torch.tensor(self.head_size, dtype=torch.float32) - ).to(torch.get_default_dtype()) - - if self.num_attention_heads % weights.process_group.size() != 0: - raise ValueError( - f"`num_attention_heads` must be divisible by `num_shards` " - f"(got `num_attention_heads`: {self.num_attention_heads} " - f"and `num_shards`: {weights.process_group.size()}" - ) - self.num_attention_heads = ( - self.num_attention_heads // weights.process_group.size() - ) - self.query_key_value = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.query_key_value", weights=weights, bias=True - ) - self.dense = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.dense", weights=weights, bias=True - ) - - def forward( - self, - hidden_states, - position_ids, - attention_mask, - head_mask=None, - layer_past=None, - use_cache=False, - output_attentions=False, - ): - has_layer_past = layer_past is not None - - # Compute QKV - # Attention heads [batch, seq_len, hidden_size] - # --> [batch, seq_len, (np * 3 * head_size)] - qkv = self.query_key_value(hidden_states) - - # [batch, seq_len, (num_heads * 3 * head_size)] - # --> [batch, seq_len, num_heads, 3 * head_size] - new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size) - qkv = qkv.view(*new_qkv_shape).permute(0, 2, 1, 3) - # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size] - query, key, value = qkv.split(self.head_size, -1) - - # 
Compute token offset for rotary embeddings (when decoding) - seq_len = key.shape[-2] - if has_layer_past: - seq_len += layer_past[0].shape[-2] - - # Compute rotary embeddings on rotary_ndims - query_rot = query[..., : self.rotary_ndims] - key_rot = key[..., : self.rotary_ndims] - - query_rot, key_rot = self.rotary_emb(query_rot, key_rot, position_ids, seq_len) - - query[..., : self.rotary_ndims] = query_rot - key[..., : self.rotary_ndims] = key_rot - - if CUSTOM_KERNELS_ENABLED: - attn_output, present, attn_weights = fused_attention_cuda.forward( - query, - key, - value, - layer_past, - attention_mask, - head_mask, - self.inv_norm_factor, - self.num_attention_heads, - use_cache, - ) - else: - # Cache QKV values - if has_layer_past: - past_key = layer_past[0] - past_value = layer_past[1] - key = torch.cat((past_key, key), dim=-2) - value = torch.cat((past_value, value), dim=-2) - present = (key, value) if use_cache else None - - # Compute attention - attn_output, attn_weights = self._attn( - query, key, value, attention_mask, head_mask - ) - - # Reshape outputs - attn_output = self._merge_heads( - attn_output, self.num_attention_heads, self.head_size - ) - - attn_output = self.dense(attn_output) - - outputs = (attn_output, present) - if output_attentions: - outputs += (attn_weights,) - - return outputs - - @classmethod - def _split_heads(cls, tensor, num_attention_heads, attn_head_size): - """ - Splits hidden dim into attn_head_size and num_attention_heads - """ - # tensor: [bs, seq_len, hidden_size] - new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) - # -> [bs, seq_len, num_attention_heads, attn_head_size] - tensor = tensor.view(new_shape) - # -> [bs, num_attention_heads, seq_len, attn_head_size] - tensor = tensor.permute(0, 2, 1, 3) - return tensor - - @classmethod - def _merge_heads(cls, tensor, num_attention_heads, attn_head_size): - """ - Merges attn_head_size dim and num_attn_heads dim into hidden dim - """ - # tensor [bs, num_attention_heads, seq_len, attn_head_size] - tensor = tensor.permute(0, 2, 1, 3).contiguous() - # -> [bs, seq_len, num_attention_heads, attn_head_size] - tensor = tensor.view( - tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size - ) - # -> [bs, seq_len, hidden_size] - return tensor - - def _attn(self, query, key, value, attention_mask=None, head_mask=None): - # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size] - # compute causal mask from causal mask buffer - batch_size, num_attention_heads, query_length, attn_head_size = query.size() - key_length = key.size(-2) - - query = query.reshape( - batch_size * num_attention_heads, query_length, attn_head_size - ) - key = key.reshape(batch_size * num_attention_heads, key_length, attn_head_size) - attn_scores = torch.zeros( - 1, - dtype=query.dtype, - device=key.device, - ).expand(batch_size * num_attention_heads, query_length, key_length) - attn_scores = torch.baddbmm( - attn_scores, - query, - key.transpose(1, 2), - beta=1.0, - alpha=self.inv_norm_factor, - ) - - # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] - input_dtype = attn_scores.dtype - if input_dtype in [torch.float16, torch.bfloat16]: - attn_scores = attn_scores.to(torch.float) - attn_scores = torch.where( - attention_mask, torch.finfo(attn_scores.dtype).min, attn_scores - ) - attn_scores = attn_scores.view( - batch_size, num_attention_heads, query_length, key_length - ) - - attn_weights = nn.functional.softmax(attn_scores, 
dim=-1) - attn_weights = attn_weights.to(value.dtype) - - # Mask heads if we want to - if head_mask is not None: - attn_weights = attn_weights * head_mask - - attn_output = torch.matmul(attn_weights, value) - return attn_output, attn_weights - - -class RotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings, base=10000, device=None): - super().__init__() - self.true_inv_freq = 1.0 / ( - base ** (torch.arange(0, dim, 2).float().to(device) / dim) - ) - self.register_buffer("inv_freq", self.true_inv_freq) - - # Build here to make `torch.jit.trace` work. - self.max_seq_len_cached = max_position_embeddings - self.cos_cached = None - self.sin_cached = None - - @staticmethod - def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - @staticmethod - def _create_cos_sin(inv_freq, max_position_embeddings, dtype, device): - t = torch.arange( - max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype - ) - freqs = torch.einsum("i,j->ij", t, inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - return emb.cos().to(device).to(dtype), emb.sin().to(device).to(dtype) - - def forward(self, q, k, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if ( - seq_len > self.max_seq_len_cached - or self.cos_cached is None - or self.sin_cached is None - ): - if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - self.cos_cached, self.sin_cached = self._create_cos_sin( - self.true_inv_freq, self.max_seq_len_cached, q.dtype, q.device - ) - return rotary_forward(q, k, self.cos_cached, self.sin_cached, position_ids) - - -@torch.jit.script -def rotary_forward(q, k, cos, sin, position_ids): - cos = cos[position_ids].unsqueeze(1) - sin = sin[position_ids].unsqueeze(1) - - chunk_size = q.shape[-1] // 2 - q1, q2 = q.split(chunk_size, -1) - q_rotated = torch.cat((-q2, q1), dim=-1) - k1, k2 = k.split(chunk_size, -1) - k_rotated = torch.cat((-k2, k1), dim=-1) - - q_embed = (q * cos) + (q_rotated * sin) - k_embed = (k * cos) + (k_rotated * sin) - return q_embed, k_embed - - -class GPTNeoXMLP(nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - self.act = ( - ACT2FN[config.hidden_act] - if "gelu_fast" not in config.hidden_act - else lambda x: torch.nn.functional.gelu(x, approximate="tanh") - ) - - self.dense_h_to_4h = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True - ) - self.dense_4h_to_h = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True - ) - - def forward(self, hidden_states): - hidden_states = self.dense_h_to_4h(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.dense_4h_to_h(hidden_states) - return hidden_states - - -class GPTNeoXLayer(nn.Module): - def __init__(self, layer_id, prefix: str, config, weights): - super().__init__() - self.use_parallel_residual = config.use_parallel_residual - self.input_layernorm = nn.LayerNorm.load( - prefix=f"{prefix}.layers.{layer_id}.input_layernorm", - weights=weights, - eps=config.layer_norm_eps, - ) - self.post_attention_layernorm = nn.LayerNorm.load( - prefix=f"{prefix}.layers.{layer_id}.post_attention_layernorm", - weights=weights, - eps=config.layer_norm_eps, - ) - self.attention = GPTNeoXAttention( - 
config, prefix=f"{prefix}.layers.{layer_id}.attention", weights=weights - ) - self.mlp = GPTNeoXMLP( - config, prefix=f"{prefix}.layers.{layer_id}.mlp", weights=weights - ) - - def forward( - self, - hidden_states, - position_ids, - attention_mask=None, - head_mask=None, - use_cache=False, - layer_past=None, - output_attentions=False, - ): - attention_layer_outputs = self.attention( - self.input_layernorm(hidden_states), - attention_mask=attention_mask, - position_ids=position_ids, - layer_past=layer_past, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - attn_output = attention_layer_outputs[ - 0 - ] # output_attn: attn_output, present, (attn_weights) - outputs = attention_layer_outputs[1:] - - if self.use_parallel_residual: - # pseudocode: - # x = x + attn(ln1(x)) + mlp(ln2(x)) - mlp_output = self.mlp(self.post_attention_layernorm(hidden_states)) - hidden_states = mlp_output + attn_output + hidden_states - else: - # pseudocode: - # x = x + attn(ln1(x)) - # x = x + mlp(ln2(x)) - attn_output = attn_output + hidden_states - mlp_output = self.mlp(self.post_attention_layernorm(attn_output)) - hidden_states = mlp_output + attn_output - - if use_cache: - outputs = ( - hidden_states, - ) + outputs # hidden_states, present, (attn_weights) - else: - outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights) - - return outputs - - -class GPTNeoXModel(GPTNeoXPreTrainedModel): - def __init__(self, prefix: str, config, weights): - super().__init__(config) - self.config = config - - self.num_attention_heads = config.num_attention_heads - - self.embed_in = TensorParallelEmbedding( - prefix=f"{prefix}.embed_in", weights=weights - ) - self.layers = nn.ModuleList( - [ - GPTNeoXLayer(layer_id, prefix, config, weights) - for layer_id in range(config.num_hidden_layers) - ] - ) - self.final_layer_norm = nn.LayerNorm.load( - prefix=f"{prefix}.final_layer_norm", - weights=weights, - eps=config.layer_norm_eps, - ) - self.tp_world_size = weights.process_group.size() - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - position_ids=None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - r""" - past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). 
- """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time" - ) - elif input_ids is not None: - input_shape = input_ids.size() - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - batch_size, seq_length = input_shape - - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * self.config.num_hidden_layers) - else: - past_length = past_key_values[0][0].size(-2) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_length, seq_length + past_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_in(input_ids) - - hidden_states = inputs_embeds - - # Attention mask. - seq_length_with_past = seq_length - past_key_values_length = 0 - if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[-1] - seq_length_with_past = seq_length_with_past + past_key_values_length - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), device=hidden_states.device - ) - else: - attention_mask = attention_mask.to(hidden_states.device) - - causal_mask = prepare_attn_mask( - attention_mask, - input_shape=(batch_size, seq_length), - past_key_values_length=past_key_values_length, - ) - - assert self.num_attention_heads % self.tp_world_size == 0 - block_size = self.num_attention_heads // self.tp_world_size - causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - - presents = () if use_cache else None - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - outputs = layer( - hidden_states, - position_ids=position_ids, - attention_mask=causal_mask, - head_mask=head_mask[i], - layer_past=layer_past, - use_cache=use_cache, - output_attentions=output_attentions, - ) - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - if output_attentions: - all_attentions = all_attentions + (outputs[2 if use_cache else 1],) - - hidden_states = self.final_layer_norm(hidden_states) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if 
not return_dict: - return tuple( - v - for v in [hidden_states, presents, all_hidden_states, all_attentions] - if v is not None - ) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_attentions, - ) - - -class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel): - _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] - - def __init__(self, prefix: str, config, weights): - super().__init__(config) - - if not prefix: - prefix = "gpt_neox" - else: - prefix = f"{prefix}.gpt_neox" - - self.gpt_neox = GPTNeoXModel(prefix, config, weights) - self.embed_out = SpeculativeHead.load( - config, prefix="embed_out", weights=weights - ) - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are - only required when the model is used as a decoder in a Sequence to Sequence model. - - Contains pre-computed hidden-states (key and values in the self-attention blocks that can be used (see - `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in - `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are - ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`).
- - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, GPTNeoXForCausalLM, GPTNeoXConfig - >>> import torch - - >>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") - >>> config = GPTNeoXConfig.from_pretrained("EleutherAI/gpt-neox-20b") - >>> config.is_decoder = True - >>> model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", config=config) - - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") - >>> outputs = model(**inputs) - - >>> prediction_logits = outputs.logits - ```""" - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - outputs = self.gpt_neox( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - lm_logits, speculative_logits = self.embed_out(hidden_states) - - lm_loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(lm_logits.device) - # we are doing next-token prediction; shift prediction scores and input ids by one - shift_logits = lm_logits[:, :-1, :].contiguous() - labels = labels[:, 1:].contiguous() - loss_fct = CrossEntropyLoss() - lm_loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1) - ) - - if not return_dict: - output = (lm_logits,) + outputs[1:] - return ((lm_loss,) + output) if lm_loss is not None else output - - return ( - CausalLMOutputWithPast( - loss=lm_loss, - logits=lm_logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ), - speculative_logits, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - **kwargs, - ): - input_shape = input_ids.shape - - # cut decoder_input_ids if past is used - if past_key_values and past_key_values[0] is not None: - input_ids = input_ids[:, -1:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - if attention_mask is None: - attention_mask = input_ids.new_ones(input_shape) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "attention_mask": attention_mask, - "past_key_values": past_key_values, - "position_ids": position_ids, - } - ) - - return model_inputs - - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple( - past_state.index_select(0, beam_idx) - for past_state in layer_past[:2] - ) - + layer_past[2:], - ) - return reordered_past diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/opt_modeling.py 
b/backends/gaudi/server/text_generation_server/models/custom_modeling/opt_modeling.py deleted file mode 100644 index bd4403214..000000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/opt_modeling.py +++ /dev/null @@ -1,857 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" PyTorch OPT model.""" -import random -from typing import List, Optional, Tuple, Union - -import torch -import torch.utils.checkpoint -from torch import nn - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from transformers.modeling_utils import PreTrainedModel -from transformers import OPTConfig -from text_generation_server.layers import ( - FastLinear, - TensorParallelColumnLinear, - TensorParallelEmbedding, - TensorParallelRowLinear, - SpeculativeHead, -) - -EPS = 1e-5 - - -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, - dtype: torch.dtype, - device: torch.device, - past_key_values_length: int = 0, -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full( - (tgt_len, tgt_len), - torch.tensor(torch.finfo(dtype).min, device=device), - device=device, - ) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat( - [ - torch.zeros( - tgt_len, past_key_values_length, dtype=dtype, device=device - ), - mask, - ], - dim=-1, - ) - return mask[None, None, :, :].expand( - bsz, 1, tgt_len, tgt_len + past_key_values_length - ) - - -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill( - inverted_mask.to(torch.bool), torch.finfo(dtype).min - ) - - -class OPTLearnedPositionalEmbedding(nn.Module): - """ - This module learns positional embeddings up to a fixed maximum size. - """ - - def __init__(self, prefix: str, weights): - super().__init__() - self.offset = 2 - self.weight = nn.Parameter( - weights.get_tensor( - f"{prefix + '.' 
if prefix else ''}decoder.embed_positions.weight" - ) - ) - - def forward( - self, attention_mask: torch.LongTensor, past_key_values_length: int = 0 - ): - """`input_ids_shape` is expected to be [bsz x seqlen].""" - attention_mask = attention_mask.long() - - # create positions depending on attention_mask - positions = ( - torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask - ).long() - 1 - - # cut positions if `past_key_values_length` is > 0 - positions = positions[:, past_key_values_length:] - - return torch.nn.functional.embedding(positions + self.offset, self.weight) - - -class OPTAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__( - self, - config, - prefix, - weights, - is_decoder: bool = False, - bias: bool = True, - process_group=None, - ): - super().__init__() - hidden_size = config.hidden_size - num_heads = config.num_attention_heads - - self.hidden_size = hidden_size - self.num_heads = num_heads - self.dropout = config.dropout - self.head_dim = hidden_size // num_heads - - if (self.head_dim * num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {num_heads})." - ) - self.scaling = self.head_dim**-0.5 - self.is_decoder = is_decoder - - process_group = weights.process_group - if self.num_heads % weights.process_group.size() != 0: - raise ValueError( - f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " - f"and `num_shards`: {weights.process_group.size()}" - ) - self.num_heads = self.num_heads // process_group.size() - self.hidden_size = self.hidden_size // process_group.size() - - self.q_proj = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.q_proj", weights=weights, bias=bias - ) - self.k_proj = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.k_proj", weights=weights, bias=bias - ) - self.v_proj = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.v_proj", weights=weights, bias=bias - ) - self.out_proj = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.out_proj", weights=weights, bias=bias - ) - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return ( - tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - .transpose(1, 2) - .contiguous() - ) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = 
self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = ( - attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - + attention_mask - ) - attn_weights = torch.max( - attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) - ) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437 - if attn_weights.dtype == torch.float16: - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(torch.float16) - else: - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view( - bsz, self.num_heads, tgt_len, src_len - ) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. 
- - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view( - bsz, self.num_heads, tgt_len, src_len - ) - attn_weights = attn_weights_reshaped.view( - bsz * self.num_heads, tgt_len, src_len - ) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `hidden_size` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.hidden_size) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -class OPTDecoderLayer(nn.Module): - def __init__(self, layer_id: int, prefix: str, config: OPTConfig, weights): - super().__init__() - self.process_group = weights.process_group - self.hidden_size = config.hidden_size - prefix = f"{prefix + '.' if prefix else ''}decoder.layers.{layer_id}" - self.self_attn = OPTAttention( - config, - prefix=f"{prefix}.self_attn", - weights=weights, - is_decoder=True, - bias=config.enable_bias, - ) - self.do_layer_norm_before = config.do_layer_norm_before - self.dropout = config.dropout - self.activation_fn = ACT2FN[config.activation_function] - - self.self_attn_layer_norm = nn.LayerNorm.load( - prefix=f"{prefix}.self_attn_layer_norm", weights=weights, eps=EPS - ) - self.fc1 = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.fc1", weights=weights, bias=config.enable_bias - ) - self.fc2 = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.fc2", weights=weights, bias=config.enable_bias - ) - self.final_layer_norm = nn.LayerNorm.load( - prefix=f"{prefix}.final_layer_norm", weights=weights, eps=EPS - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - ) -> Tuple[ - torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] - ]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size - `(encoder_attention_heads,)`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`).
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention - if self.do_layer_norm_before: - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - hidden_states = nn.functional.dropout( - hidden_states, p=self.dropout, training=self.training - ) - hidden_states = residual + hidden_states - - # 350m applies layer norm AFTER attention - if not self.do_layer_norm_before: - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Fully Connected - hidden_states_shape = hidden_states.shape - hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) - residual = hidden_states - - # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention - if self.do_layer_norm_before: - hidden_states = self.final_layer_norm(hidden_states) - - hidden_states = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) - - hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout( - hidden_states, p=self.dropout, training=self.training - ) - - hidden_states = (residual + hidden_states).view(hidden_states_shape) - - # 350m applies layer norm AFTER attention - if not self.do_layer_norm_before: - hidden_states = self.final_layer_norm(hidden_states) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -class OPTPreTrainedModel(PreTrainedModel): - config_class = OPTConfig - - -class OPTDecoder(OPTPreTrainedModel): - def __init__(self, prefix: str, config: OPTConfig, weights): - super().__init__(config) - self.dropout = config.dropout - self.layerdrop = config.layerdrop - self.padding_idx = config.pad_token_id - self.max_target_positions = config.max_position_embeddings - self.vocab_size = config.vocab_size - - prefix = prefix + "." 
if prefix else "" - - self.embed_tokens = TensorParallelEmbedding( - prefix=f"{prefix}decoder.embed_tokens", weights=weights - ) - self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights) - - if config.word_embed_proj_dim != config.hidden_size: - self.project_out = FastLinear.load( - config, - prefix=f"{prefix}decoder.project_out", - weights=weights, - bias=False, - ) - else: - self.project_out = None - - if config.word_embed_proj_dim != config.hidden_size: - self.project_in = FastLinear.load( - config, - prefix=f"{prefix}decoder.project_in", - weights=weights, - bias=False, - ) - else: - self.project_in = None - - # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility - # with checkpoints that have been fine-tuned before transformers v4.20.1 - # see https://github.com/facebookresearch/metaseq/pull/164 - if config.do_layer_norm_before and not config._remove_final_layer_norm: - self.final_layer_norm = nn.LayerNorm.load( - prefix=f"{prefix}decoder.final_layer_norm", weights=weights, eps=EPS - ) - else: - self.final_layer_norm = None - - self.layers = nn.ModuleList( - [ - OPTDecoderLayer(layer_id, prefix, config, weights) - for layer_id in range(config.num_hidden_layers) - ] - ) - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask( - self, attention_mask, input_shape, inputs_embeds, past_key_values_length - ): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ).to(inputs_embeds.device) - combined_attention_mask = ( - expanded_attn_mask - if combined_attention_mask is None - else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. 
Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the - cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those - that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of - all `decoder_input_ids` of shape `(batch_size, sequence_length)`. - - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - batch_size, seq_length = input_shape - past_key_values_length = ( - past_key_values[0][0].shape[2] if past_key_values is not None else 0 - ) - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values_length + seq_length - - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - batch_size, mask_seq_length, device=inputs_embeds.device - ) - causal_attention_mask = self._prepare_decoder_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - pos_embeds = self.embed_positions(attention_mask, past_key_values_length) - - if self.project_in is not None: - inputs_embeds = self.project_in(inputs_embeds) - - hidden_states = inputs_embeds + pos_embeds - - 
# decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - # check if head_mask has a correct number of layers specified if desired - for attn_mask, mask_name in zip([head_mask], ["head_mask"]): - if attn_mask is not None: - if attn_mask.size()[0] != (len(self.layers)): - raise ValueError( - f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" - f" {head_mask.size()[0]}." - ) - - for idx, decoder_layer in enumerate(self.layers): - # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) - if output_hidden_states: - all_hidden_states += (hidden_states,) - - dropout_probability = random.uniform(0, 1) - if self.training and (dropout_probability < self.layerdrop): - continue - - past_key_value = ( - past_key_values[idx] if past_key_values is not None else None - ) - - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if self.final_layer_norm is not None: - hidden_states = self.final_layer_norm(hidden_states) - - if self.project_out is not None: - hidden_states = self.project_out(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None - ) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class OPTModel(OPTPreTrainedModel): - def __init__(self, prefix: str, config: OPTConfig, weights): - super().__init__(config) - self.decoder = OPTDecoder(prefix, config, weights) - # Initialize weights and apply final processing - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) - decoder_outputs = self.decoder( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - 
output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - if not return_dict: - return decoder_outputs - - return BaseModelOutputWithPast( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - hidden_states=decoder_outputs.hidden_states, - attentions=decoder_outputs.attentions, - ) - - -class OPTForCausalLM(OPTPreTrainedModel): - def __init__(self, prefix, config, weights): - super().__init__(config) - - self.model = OPTModel(prefix, config, weights) - - self.lm_head = SpeculativeHead.load( - config, - prefix=f"{prefix + '.' if prefix else ''}decoder.embed_tokens", - weights=weights, - ) - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model.decoder( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - logits, speculative_logits = self.lm_head(outputs.last_hidden_state) - - loss = None - - return ( - CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ), - speculative_logits, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - **kwargs, - ): - if past_key_values: - input_ids = input_ids[:, -1:] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple( - past_state.index_select(0, beam_idx) for past_state in layer_past - ), - ) - return reordered_past diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/phi_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/phi_modeling.py deleted file mode 100644 index 3f2ed010f..000000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/phi_modeling.py +++ /dev/null @@ -1,336 +0,0 @@ -# implementation of the PhiModel and PhiForCausalLM classes - -import
torch -import torch.distributed - -import math -from torch import nn -from typing import Optional, List, Tuple -from transformers.configuration_utils import PretrainedConfig -from transformers.modeling_outputs import CausalLMOutputWithPast - -from text_generation_server.layers import ( - TensorParallelRowLinear, - TensorParallelColumnLinear, - TensorParallelEmbedding, - SpeculativeHead, - FastLinear, -) - - -# PhiConfig is the configuration class for the PhiModel. -class PhiConfig(PretrainedConfig): - def __init__( - self, - vocab_size=51200, - n_positions=2048, - n_embd=2560, - n_layer=32, - n_inner=None, - n_head=32, - rotary_dim=32, - layer_norm_epsilon=1e-5, - tie_word_embeddings=False, - pad_vocab_size_multiple=64, - pad_token_id=0, - bos_token_id=1, - eos_token_id=2, - no_bias=False, - **kwargs, - ): - self.vocab_size = vocab_size - self.n_positions = n_positions - self.n_embd = n_embd - self.n_layer = n_layer - self.n_inner = n_inner - self.n_head = n_head - self.rotary_dim = rotary_dim - - self.layer_norm_epsilon = layer_norm_epsilon - self.tie_word_embeddings = tie_word_embeddings - self.pad_vocab_size_multiple = pad_vocab_size_multiple - self.pad_token_id = pad_token_id - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.no_bias = no_bias - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - -# RotaryEmbedding is a class that implements the rotary embedding. -class RotaryEmbedding(nn.Module): - def __init__(self, dim, max_seq_len): - super().__init__() - inv_freq = [1.0 / 10000.0 ** (i / dim) for i in range(0, dim, 2)] - inv_freq_len = len(inv_freq) - inv_freq = torch.tensor(inv_freq).view(1, inv_freq_len) - t = torch.arange(0, max_seq_len, dtype=torch.float).view(max_seq_len, 1) - freqs = t.matmul(inv_freq) - self.sin = freqs.sin() - self.cos = freqs.cos() - - def apply_rotary_emb_qkv(self, qkv, seqlen_offset): - b_size, seqlen, three, _, _headdim = qkv.shape - if three != 3: - raise Exception("unexpected shape for qkv") - _, rotary_dim = self.cos.shape - rotary_dim = rotary_dim * 2 - q_rot = qkv[:, :, 0, :, :rotary_dim] - q_pass = qkv[:, :, 0, :, rotary_dim:] - k_rot = qkv[:, :, 1, :, :rotary_dim] - k_pass = qkv[:, :, 1, :, rotary_dim:] - q12 = torch.chunk(q_rot, 2, dim=-1) - k12 = torch.chunk(k_rot, 2, dim=-1) - q1, q2 = q12[0], q12[1] - k1, k2 = k12[0], k12[1] - c = self.cos.narrow(0, seqlen_offset, seqlen).unsqueeze(1) - s = self.sin.narrow(0, seqlen_offset, seqlen).unsqueeze(1) - q_rot = torch.cat( - [ - q1 * c - q2 * s, - q1 * s + q2 * c, - ], - dim=-1, - ) - k_rot = torch.cat( - [ - k1 * c - k2 * s, - k1 * s + k2 * c, - ], - dim=-1, - ) - q = torch.cat([q_rot, q_pass], dim=-1) - k = torch.cat([k_rot, k_pass], dim=-1) - v = qkv[:, :, 2] - return q, k, v - - -# PhiCausalLMHead is the head of the PhiModel. It is a linear layer with a layer norm. -class PhiCausalLMHead(nn.Module): - def __init__(self, config, weights): - super().__init__() - self.ln = nn.LayerNorm.load( - prefix="lm_head.ln", - weights=weights, - eps=config.layer_norm_epsilon, - ) - self.linear = SpeculativeHead.load( - config=config, prefix="lm_head.linear", weights=weights - ) - - def forward(self, hidden_states): - hidden_states = self.ln(hidden_states) - hidden_states = self.linear(hidden_states) - return hidden_states - - -# PhiMHA is a multi-head attention layer. This layer uses an attention mask to prevent tokens from attending to subsequent tokens. 
-class PhiMHA(nn.Module): - def __init__(self, prefix, config, weights): - super().__init__() - self.Wqkv = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias - ) - self.out_proj = TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.out_proj", - weights=weights, - bias=not config.no_bias, - ) - self.op_size = config.n_embd - self.head_dim = int(config.n_embd / config.n_head) - self.num_heads = config.n_head - self.rotary_emb = RotaryEmbedding( - config.rotary_dim, - config.n_positions, - ) - self.softmax_scale = 1.0 / math.sqrt(self.head_dim) - - def forward( - self, - hidden_states, - past_kv_cache, - attention_mask=None, - ): - b_size, seq_len, _n_embd = hidden_states.shape - qkv = self.Wqkv(hidden_states) - qkv = qkv.view(b_size, seq_len, 3, self.num_heads, self.head_dim) - seqlen_offset = 0 if past_kv_cache is None else past_kv_cache[0].shape[1] - q, k, v = self.rotary_emb.apply_rotary_emb_qkv(qkv, seqlen_offset) - - # if there is a kv_cache, then we need to concatenate - if past_kv_cache is not None: - prev_k, prev_v = past_kv_cache - k = torch.cat([prev_k, k], dim=1) - v = torch.cat([prev_v, v], dim=1) - - past_kv_cache = [k, v] - attn_weights = torch.einsum("bthd,bshd->bhts", q, k * self.softmax_scale) - - if attention_mask is not None: - seqlen_k = k.shape[1] - seqlen_q = q.shape[1] - causal_mask = torch.triu( - torch.full((seqlen_q, seqlen_k), -10000.0, device=attn_weights.device), - 1, - ) - attn_weights = attn_weights + causal_mask.to(dtype=attn_weights.dtype) - - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) - attn_output = attn_weights.matmul(v.transpose(1, 2)).squeeze(0) - attn_output = ( - attn_output.view((b_size, self.num_heads, seq_len, self.head_dim)) - .transpose(1, 2) - .flatten(-2) - ) - return self.out_proj(attn_output), past_kv_cache - - -# PhiMLP is a multi-layer perceptron. It contains two linear layers with a gelu activation function. -class PhiMLP(nn.Module): - def __init__(self, prefix, config, weights): - super().__init__() - - self.n_inner = config.n_inner - self.fc1 = FastLinear.load( - config=config, - prefix=f"{prefix}.fc1", - weights=weights, - bias=False, - ) - self.fc2 = FastLinear.load( - config=config, - prefix=f"{prefix}.fc2", - weights=weights, - bias=False, - ) - self.activation = torch.nn.functional.gelu - - def forward(self, hidden_states): - hidden_states = self.fc1(hidden_states) - hidden_states = self.activation(hidden_states) - hidden_states = self.fc2(hidden_states) - return hidden_states - - -# PhiBlock is a single transformer block. It contains a layer norm, a multi-head attention layer and an multi-layer perceptron. 
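The masked attention in `PhiMHA` above is plain scaled dot-product attention with an additive causal bias; a self-contained sketch of the same einsum/`triu` pattern (illustrative sizes):

    import math
    import torch

    b, s, h, d = 1, 5, 2, 8
    q, k, v = (torch.randn(b, s, h, d) for _ in range(3))
    scale = 1.0 / math.sqrt(d)
    scores = torch.einsum("bthd,bshd->bhts", q, k * scale)
    mask = torch.triu(torch.full((s, s), -10000.0), 1)  # large negative above the diagonal
    probs = torch.softmax(scores + mask, dim=-1)
    out = probs.matmul(v.transpose(1, 2))  # (b, h, s, d)
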
-class PhiBlock(nn.Module): - def __init__(self, layer_id, config, weights): - super().__init__() - self.layer_id = layer_id - self.layer_norm = nn.LayerNorm.load( - prefix=f"{layer_id}.ln", weights=weights, eps=config.layer_norm_epsilon - ) - self.mixer = PhiMHA(prefix=f"{layer_id}.mixer", config=config, weights=weights) - self.mlp = PhiMLP(prefix=f"{layer_id}.mlp", config=config, weights=weights) - - def forward( - self, - hidden_states, - kv_cache, - attention_mask, - ): - residual = hidden_states - hidden_states = self.layer_norm(hidden_states) - attn_outputs, past_kv_cache = self.mixer( - hidden_states, kv_cache, attention_mask - ) - feed_forward_hidden_states = self.mlp(hidden_states) - out = attn_outputs + feed_forward_hidden_states + residual - return out, past_kv_cache - - -# PhiModel implements the embedding layer and the transformer blocks. -class PhiModel(nn.Module): - def __init__(self, prefix: str, config, weights): - super().__init__() - self.tp_rank = weights.process_group.rank() - self.tp_world_size = weights.process_group.size() - self.embed_tokens = TensorParallelEmbedding( - prefix=f"{prefix}.embd.wte", weights=weights - ) - self.blocks = nn.ModuleList( - [ - PhiBlock(f"{prefix}.h.{layer_id}", config, weights) - for layer_id in range(config.n_layer) - ] - ) - - def forward( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, - attention_mask: Optional[torch.ByteTensor] = None, - return_dict: Optional[bool] = None, - use_cache: Optional[bool] = None, - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: - hidden_states = self.embed_tokens(input_ids) - seq_len = hidden_states.shape[1] - mask = None if seq_len <= 1 else attention_mask - - past_key_values = ( - [None] * len(self.blocks) if past_key_values is None else past_key_values - ) - - for index, block in enumerate(self.blocks): - hidden_states, new_key_values = block( - hidden_states, past_key_values[index], mask - ) - past_key_values[index] = new_key_values - - return hidden_states, past_key_values - - -# PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object. 
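`PhiBlock` above uses a parallel residual layout: attention and MLP both read the same pre-norm activations and their outputs are summed with the residual, rather than running sequentially. A toy sketch of the dataflow (stand-in linear layers, not the real sublayers):

    import torch

    x = torch.randn(2, 4, 16)
    ln = torch.nn.LayerNorm(16)
    attn = torch.nn.Linear(16, 16)  # stand-in for PhiMHA
    mlp = torch.nn.Linear(16, 16)   # stand-in for PhiMLP
    h = ln(x)
    out = attn(h) + mlp(h) + x      # one shared norm, branches summed with the residual
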
-class PhiForCausalLM(torch.nn.Module): - def __init__(self, prefix: str, config, weights): - super().__init__() - - if not prefix: - prefix = "transformer" - else: - prefix = f"{prefix}.transformer" - - self.model = PhiModel(prefix, config, weights) - self.lm_head = PhiCausalLMHead(config, weights) - - def forward( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, - attention_mask: Optional[torch.ByteTensor] = None, - return_dict: Optional[bool] = None, - use_cache: Optional[bool] = None, - labels: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: - model_output = self.model( - input_ids, past_key_values, attention_mask, return_dict, use_cache - ) - logits = self.lm_head(model_output[0]) - - loss = None - if labels is not None: - loss = nn.CrossEntropyLoss()( - logits[:, :-1].view(-1, logits.size(-1)), labels[:, 1:].view(-1) - ) - - if not return_dict: - return ( - ((loss,) + (logits,) + model_output[1:]) - if loss is not None - else (logits,) + model_output[1:] - ) - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=model_output[1], - hidden_states=None, - attentions=None, - ) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py new file mode 100644 index 000000000..441b0016e --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py @@ -0,0 +1,946 @@ +# coding=utf-8 +# Copyright 2025 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Qwen2.5 VL model.""" + +from typing import Optional, Tuple, List + +import torch +import torch.utils.checkpoint +from torch import nn + +from habana_frameworks.torch.hpex.kernels import FusedSDPA +from vllm_hpu_extension.utils import ModuleFusedSDPA + + +import numpy as np + +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig + +import torch.nn.functional as F + +from text_generation_server.layers.layernorm import FastRMSNorm +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelRowLinear, + TensorParallelEmbedding, + SpeculativeHead, +) +from text_generation_server.layers.attention import ( + Seqlen, + HPUPagedAttentionMetadata, +) +from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( + Qwen2Model, +) + +# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +from typing import Union +from transformers.feature_extraction_utils import BatchFeature +from transformers.image_utils import ImageInput, VideoInput +from transformers.processing_utils import ( + ProcessingKwargs, + ProcessorMixin, + Unpack, + VideosKwargs, +) +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput + + +class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False): + fps: Union[List[float], float] + + +class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False): + videos_kwargs: Qwen2_5_VLVideosProcessorKwargs + _defaults = { + "text_kwargs": { + "padding": False, + }, + "videos_kwargs": {"fps": 2.0}, + } + + +class Qwen2_5_VLProcessor(ProcessorMixin): + r""" + Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor. + [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the + [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information. + Args: + image_processor ([`Qwen2VLImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template"] + + image_processor_class = "AutoImageProcessor" + tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + + def __init__( + self, image_processor=None, tokenizer=None, chat_template=None, **kwargs + ): + self.image_token = ( + "<|image_pad|>" + if not hasattr(tokenizer, "image_token") + else tokenizer.image_token + ) + self.video_token = ( + "<|video_pad|>" + if not hasattr(tokenizer, "video_token") + else tokenizer.video_token + ) + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: ImageInput = None, + text: Union[ + TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput] + ] = None, + videos: VideoInput = None, + **kwargs: Unpack[Qwen2_5_VLProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode + the text. 
To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to + Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch + tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`. + - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. + - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. + - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`. + """ + output_kwargs = self._merge_kwargs( + Qwen2_5_VLProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if images is not None: + image_inputs = self.image_processor( + images=images, videos=None, **output_kwargs["images_kwargs"] + ) + image_grid_thw = image_inputs["image_grid_thw"] + else: + image_inputs = {} + image_grid_thw = None + + if videos is not None: + videos_inputs = self.image_processor( + images=None, videos=videos, **output_kwargs["images_kwargs"] + ) + video_grid_thw = videos_inputs["video_grid_thw"] + + fps = output_kwargs["videos_kwargs"].pop("fps", 2.0) + if isinstance(fps, (int, float)): + second_per_grid_ts = [ + self.image_processor.temporal_patch_size / fps + ] * len(video_grid_thw) + elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw): + second_per_grid_ts = [ + self.image_processor.temporal_patch_size / tmp for tmp in fps + ] + else: + raise ValueError( + f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number." 
+ ) + videos_inputs.update({"second_per_grid_ts": second_per_grid_ts}) + + else: + videos_inputs = {} + video_grid_thw = None + + if not isinstance(text, list): + text = [text] + + if image_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + text[i] = text[i].replace( + self.image_token, + "<|placeholder|>" + * (image_grid_thw[index].prod() // merge_length), + 1, + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + + if video_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.video_token in text[i]: + text[i] = text[i].replace( + self.video_token, + "<|placeholder|>" + * (video_grid_thw[index].prod() // merge_length), + 1, + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.video_token) + + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + def post_process_image_text_to_text(self, generated_outputs): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + + Returns: + `List[str]`: The decoded text. 
+ """ + return self.tokenizer.batch_decode( + generated_outputs, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + names_from_processor = list( + dict.fromkeys(tokenizer_input_names + image_processor_input_names) + ) + return names_from_processor + ["second_per_grid_ts"] + + +# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +class Qwen2_5_VLVisionConfig(PretrainedConfig): + model_type = "qwen2_5_vl" + base_config_key = "vision_config" + + def __init__( + self, + depth=32, + hidden_size=3584, + hidden_act="silu", + intermediate_size=3420, + num_heads=16, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + spatial_patch_size=14, + temporal_patch_size=2, + tokens_per_second=4, + window_size=112, + out_hidden_size=3584, + fullatt_block_indexes=[7, 15, 23, 31], + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_patch_size = spatial_patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.tokens_per_second = tokens_per_second + self.window_size = window_size + self.fullatt_block_indexes = fullatt_block_indexes + self.out_hidden_size = out_hidden_size + + +class Qwen2_5_VLConfig(PretrainedConfig): + + def __init__( + self, + vocab_size=152064, + hidden_size=8192, + intermediate_size=29568, + num_hidden_layers=80, + num_attention_heads=64, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + tie_word_embeddings=False, + rope_theta=1000000.0, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=80, + attention_dropout=0.0, + vision_config=None, + rope_scaling=None, + **kwargs, + ): + if vision_config is not None: + self.vision_config = Qwen2_5_VLVisionConfig(**vision_config) + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + # and change type from 'mrope' to 'default' because `mrope` does defeault RoPE calculations + # one can set it to "linear"/"dynamic" etc. 
to have scaled RoPE + # TODO: @raushan update config in the hub + if self.rope_scaling is not None and "type" in self.rope_scaling: + if self.rope_scaling["type"] == "mrope": + self.rope_scaling["type"] = "default" + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision( + tensor: torch.Tensor, freqs: torch.Tensor +) -> torch.Tensor: + orig_dtype = tensor.dtype + tensor = tensor.float() + cos = freqs.cos() + sin = freqs.sin() + cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + output = (tensor * cos) + (rotate_half(tensor) * sin) + output = output.to(orig_dtype) + return output + + +class Qwen2_5VLAttention(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.embed_dim = config.hidden_size // weights.process_group.size() + self.head_dim = config.hidden_size // config.num_heads + self.num_heads = config.num_heads // weights.process_group.size() + + self.qkv = TensorParallelColumnLinear.load_qkv( + config, + prefix=f"{prefix}.qkv", + weights=weights, + bias=False, + num_heads=self.num_heads, + num_key_value_heads=self.num_heads, + ) + self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0) + + self.proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.proj", + weights=weights, + bias=True, + ) + self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads) + + def forward( + self, + hidden_state: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: int, + ) -> torch.Tensor: + # apply the qkv linear layer to the hidden state + qkv = self.qkv(hidden_state) + query, key, value = qkv.split( + [self.embed_dim, self.embed_dim, self.embed_dim], dim=1 + ) + + # reshape the query, key, and value tensors + _shape = ( + hidden_state.shape[0], + self.num_heads, + self.embed_dim // self.num_heads, + ) + query = query.view(*_shape) + key = key.view(*_shape) + value = value.view(*_shape) + + # apply rotary positional embeddings + query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze( + 0 + ) + key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0) + + # calc maximum sequence length for any batch + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + causal = False + + # execute sdpa + query = query.unsqueeze(0).transpose(1, 2) + key = key.unsqueeze(0).transpose(1, 2) + value = value.unsqueeze(0).transpose(1, 2) + fsdpa_op = ModuleFusedSDPA(FusedSDPA) + attn_output = fsdpa_op( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=causal, + scale=None, + softmax_mode="None", + recompute_mode=None, + valid_sequence_lengths=None, + ) + attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous() + + # reshape output to original dimensions + attn_output = attn_output.reshape(hidden_state.shape[0], -1) + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2_5VLVisionMLP(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.activation_fn = ACT2FN[config.hidden_act] + + self.intermediate_size = ( + 
config.intermediate_size // weights.process_group.size() + ) + + self.up = TensorParallelColumnLinear.load( + prefix=f"{prefix}.up_proj", weights=weights, config=config, bias=True + ) + self.gate = TensorParallelColumnLinear.load( + prefix=f"{prefix}.gate_proj", weights=weights, config=config, bias=True + ) + self.down = TensorParallelRowLinear.load( + prefix=f"{prefix}.down_proj", weights=weights, config=config, bias=True + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + gate_states = self.gate(hidden_states) + up_states = self.up(hidden_states) + activated_states = self.activation_fn(gate_states) * up_states + down_states = self.down(activated_states) + return down_states + + +class Qwen2_5VLVisionBlock(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.attn = Qwen2_5VLAttention( + prefix=f"{prefix}.attn", + config=config, + weights=weights, + ) + self.norm1 = FastRMSNorm.load( + prefix=f"{prefix}.norm1", + weights=weights, + eps=1e-6, + ) + self.norm2 = FastRMSNorm.load( + prefix=f"{prefix}.norm2", + weights=weights, + eps=1e-6, + ) + self.mlp = Qwen2_5VLVisionMLP( + prefix=f"{prefix}.mlp", + config=config, + weights=weights, + ) + + def forward( + self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen + ) -> torch.Tensor: + norm1_out, _ = self.norm1(hidden_states) + attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen) + hidden_states = hidden_states + attn_out + norm2_out, _ = self.norm2(hidden_states) + mlp_out = self.mlp(norm2_out) + hidden_states = hidden_states + mlp_out + return hidden_states + + +class Qwen2_5VLPatchMerger(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.hidden_size = config.hidden_size * (config.spatial_merge_size**2) + self.patch_merger_ln_q = FastRMSNorm.load( + prefix=f"{prefix}.ln_q", + weights=weights, + eps=1e-6, + ) + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True + ) + + def forward(self, hidden_states) -> torch.Tensor: + hidden_states, _ = self.patch_merger_ln_q(hidden_states) + hidden_states = hidden_states.view(-1, self.hidden_size) + hidden_states = self.fc1(hidden_states) + hidden_states = F.gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Qwen2_5VisionModel(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + + self.spatial_merge_size = config.spatial_merge_size + kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size] + self.patch_embedding = nn.Conv3d( + in_channels=config.in_channels, + out_channels=config.hidden_size, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + ) + self.patch_embedding.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False + ) + head_dim = config.hidden_size // config.num_heads + + theta = 10000.0 + dim = head_dim // 2 + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + self.blocks = nn.ModuleList( + [ + Qwen2_5VLVisionBlock( + prefix=f"{prefix}.blocks.{i}", + config=config, + weights=weights, + ) + for i in range(config.depth) + ] + ) + self.merger = Qwen2_5VLPatchMerger( + prefix=f"{prefix}.merger", + config=config, + weights=weights, + ) + # import ipdb; 
ipdb.set_trace() + self.temporal_patch_size = config.temporal_patch_size + self.spatial_patch_size = config.spatial_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.hidden_size + self.window_size = config.window_size + self.patch_size = config.patch_size + self.spatial_merge_unit = config.spatial_merge_size * config.spatial_merge_size + self.fullatt_block_indexes = config.fullatt_block_indexes + + def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: + batch_size, _, hidden_size = hidden_state.shape + class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size) + hidden_state = torch.cat([class_embedding, hidden_state], dim=1) + return hidden_state + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = ( + self.window_size // self.spatial_merge_size // self.patch_size + ) + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.spatial_merge_size, + grid_w // self.spatial_merge_size, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( + grid_t, llm_grid_h, llm_grid_w + ) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = ( + seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + ) + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + + return window_index, cu_window_seqlens + + def forward( + self, + pixel_values: torch.Tensor, + grid_thw: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + + # reshape the input tensor for processing + shape = ( + -1, + self.in_channels, + self.temporal_patch_size, + self.spatial_patch_size, + self.spatial_patch_size, + ) + pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype) + hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim) + # TODO: revisit to see if we can avoid some of these reshapes + + # find the position ids for the input tensor based on the grid_thw + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + 
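A worked example of the reshape/permute pattern above (toy values, assumed for illustration): for a 4x4 patch grid with spatial_merge_size=2, the row indices are reordered so the four patches of each 2x2 merge window become adjacent:

    import torch

    h = w = 4
    merge = 2
    hpos = torch.arange(h).unsqueeze(1).expand(-1, w)
    hpos = hpos.reshape(h // merge, merge, w // merge, merge)
    hpos = hpos.permute(0, 2, 1, 3).flatten()
    # hpos -> tensor([0, 0, 1, 1, 0, 0, 1, 1, 2, 2, 3, 3, 2, 2, 3, 3])
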
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + + pos_ids = torch.cat(pos_ids, dim=0) + + max_grid_size = grid_thw[:, 1:].max() + + # apply the positional embeddings to the position ids + seq = torch.arange( + max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) + rotary_pos_emb_full = torch.outer(seq, self.inv_freq) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + seq_len = hidden_states.shape[0] + patch_shape = (seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + og_shape = (seq_len, -1) + + hidden_states = hidden_states.view(patch_shape)[window_index, :, :].view( + og_shape + ) + rotary_pos_emb = rotary_pos_emb.view(patch_shape)[window_index, :, :].view( + og_shape + ) + + rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device) + + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device="cpu", + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens).to( + hidden_states.device + ) + + # create a cu_seqlens tensor to be used in the attention mask + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum(dim=0, dtype=torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1]) + + # iterately apply the blocks to the hidden states + for layer_num, block in enumerate(self.blocks): + # NOTE: qwen2_5_vl.py has a concept of full attention blocks + # that are applied at specific layers. + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + + hidden_states = block( + hidden_states, cu_seqlens_now, rotary_pos_emb, max_seqlen + ) + + # apply the final patch merger to the hidden states + hidden_states = self.merger(hidden_states) + reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] + return hidden_states + + +class Qwen2_5VLForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + config.vision_config.quantize = None + config.vision_config.speculator = config.speculator + # set rope_scaling.type == "mrope" since AutoConfig.from_pretrained incorrectly + # returns rope_scaling.type == "default" for Qwen2_5-VL model at the moment + if ( + hasattr(config, "rope_scaling") + and config.rope_scaling is not None + and config.rope_scaling.get("type", None) == "default" + ): + config.rope_scaling.update({"rope_type": "mrope"}) + self.hidden_size = config.hidden_size + self.vision_start_token_id = config.vision_start_token_id + self.vision_end_token_id = config.vision_end_token_id + self.image_token_id = config.image_token_id + self.video_token_id = config.video_token_id + self.spatial_merge_size = config.vision_config.spatial_merge_size + self.embed_tokens = TensorParallelEmbedding( + prefix="model.embed_tokens", weights=weights + ) + self.visual = Qwen2_5VisionModel( + prefix="visual", config=config.vision_config, weights=weights + ) + self.text_model = Qwen2Model(prefix=None, config=config, weights=weights) + if config.tie_word_embeddings: + suffix = "model.embed_tokens" + else: + suffix = "lm_head" + + self.lm_head = SpeculativeHead.load( + config, + prefix=suffix if not prefix else f"{prefix}.{suffix}", + weights=weights, + ) + self.device = weights.device + + # based on 
https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391 + # modified to first find segments then initialize position ids for each segment + # Steps: + # locate all vision and text segments + # calculate `vision_segment_lengths` for each vision segment to be use as offset + # calculate `text_segment_lengths` for each text segment to be used as offset + # create position ids for each vision segment based on the image grid + # create position ids for each text segment + # combine all the position ids + # the final segment is the difference between the last vision segment and the end of the input + # combine all the position ids and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3) + def get_position_ids( + self, + input_ids: torch.Tensor, + image_grid_thw: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if image_grid_thw is None: + return ( + torch.arange(input_ids.shape[0], device=input_ids.device) + .unsqueeze(1) + .repeat(1, 3) + ) + + spatial_merge_size = self.spatial_merge_size + vision_start_token_id = self.vision_start_token_id + vision_end_token_id = self.vision_end_token_id + device = input_ids.device + dtype = input_ids.dtype + input_ids_len = input_ids.shape[0] + + vision_starts = torch.where(input_ids == vision_start_token_id)[0] + vision_ends = torch.where(input_ids == vision_end_token_id)[0] + vision_segments = torch.stack((vision_starts, vision_ends), dim=1) + prev_vision_end = torch.cat( + [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]] + ) + text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1 + vision_widths_max = torch.cat( + [ + torch.zeros(1, device=image_grid_thw.device, dtype=dtype), + image_grid_thw[:-1, 2] // spatial_merge_size, + ] + ) + vision_segment_lengths = vision_widths_max + text_lengths_between_vision + vision_segment_lengths = vision_segment_lengths.cumsum(dim=0) + text_segment_lengths = vision_segment_lengths - text_lengths_between_vision + + # create position ids for each vision segment based on the image grid + llm_pos_ids_list = [] + for i, _ in enumerate(vision_segments): + t, h, w = ( + image_grid_thw[i][0], + image_grid_thw[i][1] // spatial_merge_size, + image_grid_thw[i][2] // spatial_merge_size, + ) + t_indices = torch.arange(t, device=device).repeat_interleave(h * w) + h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t) + w_indices = torch.arange(w, device=device).repeat(t * h) + image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0) + + # offset by the position of the last vision segment + im = image_position_ids + vision_segment_lengths[i] + llm_pos_ids_list.append(im) + + # create position ids for each text segment + text_ranges = [ + torch.arange(seq_len, device=device).view(1, -1).expand(3, -1) + + text_segment_lengths[i] + for i, seq_len in enumerate(text_lengths_between_vision) + ] + + full_llm_pos_ids_list = [ + item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist + ] + # import ipdb + + # ipdb.set_trace() + max_s = full_llm_pos_ids_list[-1].max() + 1 + final_text_len = input_ids_len - vision_ends[-1] + if final_text_len > 0: + m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1) + full_llm_pos_ids_list.append(m + max_s) + + position_ids = ( + torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1) + ) + return position_ids + + def forward( + self, + input_ids: torch.Tensor, 
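A toy illustration of the 3D (mrope) position ids that `get_position_ids` above produces: text tokens advance all three components in lockstep, while an image segment gets separate t/h/w indices offset by the running segment length (simplified sketch; the real code derives offsets from `vision_segment_lengths`):

    import torch

    text1 = torch.arange(3).view(1, -1).expand(3, -1)         # 3 text tokens
    t, h, w = 1, 2, 2                                         # one merged 2x2 image
    t_idx = torch.arange(t).repeat_interleave(h * w)
    h_idx = torch.arange(h).repeat_interleave(w).repeat(t)
    w_idx = torch.arange(w).repeat(t * h)
    image = torch.stack([t_idx, h_idx, w_idx]) + 3            # offset past the text
    text2 = torch.arange(2).view(1, -1).expand(3, -1) + 5     # resumes after max + 1
    position_ids = torch.cat([text1, image, text2], dim=1).T  # (seq_len, 3)
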
+ position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + lm_head_indices: Optional[torch.Tensor], + pixel_values: torch.FloatTensor = None, + image_grid_thw: Optional[torch.LongTensor] = None, + # Unused in this model + video_grid_thw: Optional[torch.LongTensor] = None, + pixel_attention_mask=None, + image_sizes: Optional[torch.LongTensor] = None, + adapter_data: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + image_indices=None, + ): + inputs_embeds = self.embed_tokens(input_ids) + + # apply the visual model to the pixel values if they are provided + if pixel_values is not None and len(pixel_values) > 0: + if pixel_values is not None: + image_embeds = self.visual( + pixel_values, grid_thw=image_grid_thw + ).squeeze(0) + mask = torch.where(input_ids == self.image_token_id) + inputs_embeds[mask] = image_embeds + + hidden_states = self.text_model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + slots=slots, + seqlen=seqlen, + hpu_attention_meta=hpu_attention_meta, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py new file mode 100644 index 000000000..47ae2ac94 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -0,0 +1,519 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
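The vision tower in qwen2_5_vl.py above, and the one in qwen2_vl.py below, both delimit variable-length image sequences for attention with a `cu_seqlens` prefix sum; a small worked example of that construction (toy grid values):

    import torch
    import torch.nn.functional as F

    grid_thw = torch.tensor([[1, 4, 6], [2, 2, 2]])  # (t, h, w) per image/video
    tokens_per_frame = grid_thw[:, 1] * grid_thw[:, 2]
    cu_seqlens = torch.repeat_interleave(tokens_per_frame, grid_thw[:, 0])
    cu_seqlens = cu_seqlens.cumsum(0, dtype=torch.int32)
    cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
    # cu_seqlens -> tensor([ 0, 24, 28, 32], dtype=torch.int32)
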
+"""PyTorch Qwen2 VL model.""" + +from typing import Optional, Tuple, List + +import torch +import torch.utils.checkpoint +from torch import nn + + +from habana_frameworks.torch.hpex.kernels import FusedSDPA +from vllm_hpu_extension.utils import ModuleFusedSDPA + + +import numpy as np + +from transformers.activations import ACT2FN +import torch.nn.functional as F + +from text_generation_server.layers.layernorm import FastLayerNorm, FastRMSNorm +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelRowLinear, + TensorParallelEmbedding, + SpeculativeHead, +) +from text_generation_server.layers.attention import ( + Seqlen, + HPUPagedAttentionMetadata, +) +from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( + Qwen2Model, +) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision( + tensor: torch.Tensor, freqs: torch.Tensor +) -> torch.Tensor: + orig_dtype = tensor.dtype + tensor = tensor.float() + cos = freqs.cos() + sin = freqs.sin() + cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + output = (tensor * cos) + (rotate_half(tensor) * sin) + output = output.to(orig_dtype) + return output + + +class Qwen2VLAttention(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.embed_dim = config.embed_dim // weights.process_group.size() + self.head_dim = config.hidden_size // config.num_heads + self.num_heads = config.num_heads // weights.process_group.size() + + self.qkv = TensorParallelColumnLinear.load_qkv( + config, + prefix=f"{prefix}.qkv", + weights=weights, + bias=False, + num_heads=self.num_heads, + num_key_value_heads=self.num_heads, + ) + self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0) + self.proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.proj", + weights=weights, + bias=True, + ) + self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads) + + def forward( + self, + hidden_state: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: int, + ) -> torch.Tensor: + # apply the qkv linear layer to the hidden state + qkv = self.qkv(hidden_state) + query, key, value = qkv.split( + [self.embed_dim, self.embed_dim, self.embed_dim], dim=1 + ) + + # reshape the query, key, and value tensors + _shape = ( + hidden_state.shape[0], + self.num_heads, + self.embed_dim // self.num_heads, + ) + query = query.view(*_shape) + key = key.view(*_shape) + value = value.view(*_shape) + + # apply rotary positional embeddings + query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze( + 0 + ) + key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0) + + # calc maximum sequence length for any batch + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + causal = False + + # execute sdpa + query = query.unsqueeze(0).transpose(1, 2) + key = key.unsqueeze(0).transpose(1, 2) + value = value.unsqueeze(0).transpose(1, 2) + fsdpa_op = ModuleFusedSDPA(FusedSDPA) + attn_output = fsdpa_op( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=causal, + scale=None, + softmax_mode="None", + recompute_mode=None, + valid_sequence_lengths=None, + ) + 
attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous() + # reshape output to original dimensions + attn_output = attn_output.reshape(hidden_state.shape[0], -1) + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2VLVisionMLP(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Qwen2VLVisionBlock(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.attn = Qwen2VLAttention( + prefix=f"{prefix}.attn", + config=config, + weights=weights, + ) + self.norm1 = FastLayerNorm.load( + prefix=f"{prefix}.norm1", + weights=weights, + eps=1e-6, + ) + self.norm2 = FastLayerNorm.load( + prefix=f"{prefix}.norm2", + weights=weights, + eps=1e-6, + ) + self.mlp = Qwen2VLVisionMLP( + prefix=f"{prefix}.mlp", + config=config, + weights=weights, + ) + + def forward( + self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen + ) -> torch.Tensor: + norm1_out, residual = self.norm1(hidden_states) + attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen) + hidden_states = attn_out + residual + norm2_out, residual = self.norm2(hidden_states) + hidden_states = hidden_states + self.mlp(norm2_out) + return hidden_states + + +class Qwen2VLPatchMerger(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.hidden_size = config.embed_dim * (config.spatial_merge_size**2) + self.patch_merger_ln_q = FastLayerNorm.load( + prefix=f"{prefix}.ln_q", + weights=weights, + eps=1e-6, + ) + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True + ) + + def forward(self, hidden_states) -> torch.Tensor: + hidden_states, _ = self.patch_merger_ln_q(hidden_states) + hidden_states = hidden_states.view(-1, self.hidden_size) + hidden_states = self.fc1(hidden_states) + hidden_states = F.gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Qwen2VisionModel(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.spatial_merge_size = config.spatial_merge_size + kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size] + self.patch_embedding = nn.Conv3d( + in_channels=config.in_chans, + out_channels=config.embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + ) + self.patch_embedding.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False + ) + head_dim = config.embed_dim // config.num_heads + # TODO: replace with static positional embeddings once implemented + theta = 10000.0 + dim = head_dim // 2 + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + self.blocks = nn.ModuleList( + [ + Qwen2VLVisionBlock( + prefix=f"{prefix}.blocks.{i}", + config=config, + 
weights=weights, + ) + for i in range(config.depth) + ] + ) + self.merger = Qwen2VLPatchMerger( + prefix=f"{prefix}.merger", + config=config, + weights=weights, + ) + + self.temporal_patch_size = config.temporal_patch_size + self.spatial_patch_size = config.spatial_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.embed_dim + + def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: + batch_size, _, hidden_size = hidden_state.shape + class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size) + hidden_state = torch.cat([class_embedding, hidden_state], dim=1) + return hidden_state + + def forward( + self, + pixel_values: torch.Tensor, + grid_thw: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + # reshape the input tensor for processing + shape = ( + -1, + self.in_channels, + self.temporal_patch_size, + self.spatial_patch_size, + self.spatial_patch_size, + ) + pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype) + hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim) + # TODO: revisit to see if we can avoid some of these reshapes + + # find the position ids for the input tensor based on the grid_thw + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + + # apply the positional embeddings to the position ids + seq = torch.arange( + max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) + rotary_pos_emb_full = torch.outer(seq, self.inv_freq) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype) + + # create a cu_seqlens tensor to be used in the attention mask + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum(dim=0, dtype=torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1]) + # iterately apply the blocks to the hidden states + for block in self.blocks: + hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen) + + # apply the final patch merger to the hidden states + hidden_states = self.merger(hidden_states) + return hidden_states + + +class Qwen2VLForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + config.vision_config.quantize = None + config.vision_config.speculator = config.speculator + # set rope_scaling.type == "mrope" since AutoConfig.from_pretrained incorrectly + # returns rope_scaling.type == "default" for Qwen2-VL model at the moment + if ( + hasattr(config, "rope_scaling") + and config.rope_scaling is not None + and config.rope_scaling.get("type", None) == "default" + ): + config.rope_scaling.update({"rope_type": "mrope"}) + self.hidden_size = 
config.hidden_size + self.vision_start_token_id = config.vision_start_token_id + self.vision_end_token_id = config.vision_end_token_id + self.image_token_id = config.image_token_id + self.video_token_id = config.video_token_id + self.spatial_merge_size = config.vision_config.spatial_merge_size + self.embed_tokens = TensorParallelEmbedding( + prefix="model.embed_tokens", weights=weights + ) + self.visual = Qwen2VisionModel( + prefix="visual", config=config.vision_config, weights=weights + ) + self.text_model = Qwen2Model(prefix=None, config=config, weights=weights) + if config.tie_word_embeddings: + suffix = "model.embed_tokens" + else: + suffix = "lm_head" + + self.lm_head = SpeculativeHead.load( + config, + prefix=suffix if not prefix else f"{prefix}.{suffix}", + weights=weights, + ) + self.norm = FastRMSNorm.load( + prefix="model.norm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.device = weights.device + + # based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391 + # modified to first find segments then initialize position ids for each segment + # Steps: + # locate all vision and text segments + # calculate `vision_segment_lengths` for each vision segment to be use as offset + # calculate `text_segment_lengths` for each text segment to be used as offset + # create position ids for each vision segment based on the image grid + # create position ids for each text segment + # combine all the position ids + # the final segment is the difference between the last vision segment and the end of the input + # combine all the position ids and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3) + def get_position_ids( + self, + input_ids: torch.Tensor, + image_grid_thw: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if image_grid_thw is None: + return ( + torch.arange(input_ids.shape[0], device=input_ids.device) + .unsqueeze(1) + .repeat(1, 3) + ) + + spatial_merge_size = self.spatial_merge_size + vision_start_token_id = self.vision_start_token_id + vision_end_token_id = self.vision_end_token_id + device = input_ids.device + dtype = input_ids.dtype + input_ids_len = input_ids.shape[0] + + vision_starts = torch.where(input_ids == vision_start_token_id)[0] + vision_ends = torch.where(input_ids == vision_end_token_id)[0] + vision_segments = torch.stack((vision_starts, vision_ends), dim=1) + prev_vision_end = torch.cat( + [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]] + ) + text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1 + vision_widths_max = torch.cat( + [ + torch.zeros(1, device=image_grid_thw.device, dtype=dtype), + image_grid_thw[:-1, 2] // spatial_merge_size, + ] + ) + vision_segment_lengths = vision_widths_max + text_lengths_between_vision + vision_segment_lengths = vision_segment_lengths.cumsum(dim=0) + text_segment_lengths = vision_segment_lengths - text_lengths_between_vision + + # create position ids for each vision segment based on the image grid + llm_pos_ids_list = [] + for i, _ in enumerate(vision_segments): + t, h, w = ( + image_grid_thw[i][0], + image_grid_thw[i][1] // spatial_merge_size, + image_grid_thw[i][2] // spatial_merge_size, + ) + t_indices = torch.arange(t, device=device).repeat_interleave(h * w) + h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t) + w_indices = torch.arange(w, device=device).repeat(t * h) + image_position_ids = torch.stack([t_indices, 
h_indices, w_indices], dim=0) + + # offset by the position of the last vision segment + im = image_position_ids + vision_segment_lengths[i] + llm_pos_ids_list.append(im) + + # create position ids for each text segment + text_ranges = [ + torch.arange(seq_len, device=device).view(1, -1).expand(3, -1) + + text_segment_lengths[i] + for i, seq_len in enumerate(text_lengths_between_vision) + ] + + full_llm_pos_ids_list = [ + item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist + ] + max_s = full_llm_pos_ids_list[-1].max() + 1 + final_text_len = input_ids_len - vision_ends[-1] + if final_text_len > 0: + m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1) + full_llm_pos_ids_list.append(m + max_s) + + position_ids = ( + torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1) + ) + return position_ids + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + lm_head_indices: Optional[torch.Tensor], + pixel_values: torch.FloatTensor = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + pixel_attention_mask=None, + image_sizes: Optional[torch.LongTensor] = None, + adapter_data: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + image_indices=None, + ): + inputs_embeds = self.embed_tokens(input_ids) + + # apply the visual model to the pixel values if they are provided + if pixel_values is not None and len(pixel_values) > 0: + if pixel_values is not None: + image_embeds = self.visual( + pixel_values, grid_thw=image_grid_thw + ).squeeze(0) + mask = torch.where(input_ids == self.image_token_id) + inputs_embeds[mask] = image_embeds + + hidden_states = self.text_model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + slots=slots, + seqlen=seqlen, + hpu_attention_meta=hpu_attention_meta, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/t5_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/t5_modeling.py deleted file mode 100644 index e6666acd3..000000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/t5_modeling.py +++ /dev/null @@ -1,1227 +0,0 @@ -# coding=utf-8 -# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" PyTorch T5 model.""" - -import copy -import math -import warnings -from typing import Optional, Tuple, Union - -from loguru import logger - -import torch -import torch.distributed -from torch import nn -from torch.nn import CrossEntropyLoss - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import ( - BaseModelOutput, - BaseModelOutputWithPastAndCrossAttentions, - Seq2SeqLMOutput, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS -from transformers.utils import ( - is_torch_fx_proxy, -) -from transformers import T5Config -from text_generation_server.layers import ( - TensorParallelColumnLinear, - TensorParallelEmbedding, - TensorParallelRowLinear, - SpeculativeHead, -) - -# copied from https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/t5/modeling_t5.py#L1316 -# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask -__HEAD_MASK_WARNING_MSG = """ -The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, -`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. -If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, -num_heads)`. -""" - - -class PartialTPEmbedding(nn.Module): - def __init__(self, prefix: str, weights): - super().__init__() - weight = weights.get_sharded(f"{prefix}.weight", dim=1) - self.weight = nn.Parameter(weight) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return torch.nn.functional.embedding(input, self.weight) - - -@torch.jit.script -def layer_norm(hidden_states, weight, epsilon): - # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean - # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated - # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for - # half-precision inputs is done in fp32 - - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + epsilon) - - # convert into half-precision if necessary - if weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(weight.dtype) - - return weight * hidden_states - - -class T5LayerNorm(nn.Module): - def __init__(self, prefix, weights, eps=1e-6): - """ - Construct a layernorm module in the T5 style. No bias and no subtraction of mean. 
- """ - super().__init__() - weight = weights.get_tensor(f"{prefix}.weight") - self.weight = nn.Parameter(weight) - self.variance_epsilon = torch.tensor(eps) - - def forward(self, hidden_states): - return layer_norm(hidden_states, self.weight, self.variance_epsilon) - - -try: - from apex.normalization import FusedRMSNorm - - T5LayerNorm = FusedRMSNorm # noqa - - logger.info( - "Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm" - ) -except ImportError: - # using the normal T5LayerNorm - pass -except Exception: - logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm") - pass - -ALL_LAYERNORM_LAYERS.append(T5LayerNorm) - - -class T5DenseActDense(nn.Module): - def __init__(self, config: T5Config, prefix, weights): - super().__init__() - self.wi = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.wi", weights=weights, bias=False - ) - - ### XXX: T5 models do not handle well both f16 and quantization. - ### Overidding specifically this layer for that reason. - ### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316 - ### https://github.com/huggingface/transformers/issues/20287 - _q = config.quantize - _dtype = weights.dtype - weights.dtype = torch.float32 - config.quantize = None - self.wo_cast = (torch.float32, _dtype) - self.wo = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.wo", weights=weights, bias=False - ) - weights.dtype = _dtype - config.quantize = _q - - self.dropout = nn.Dropout(config.dropout_rate) - self.act = ( - ACT2FN[config.dense_act_fn] - if "gelu" not in config.dense_act_fn - else lambda x: torch.nn.functional.gelu(x, approximate="tanh") - ) - - def forward(self, hidden_states): - hidden_states = self.wi(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.dropout(hidden_states) - - hidden_states = hidden_states.to(dtype=self.wo_cast[0]) - hidden_states = self.wo(hidden_states) - # XXX: Recasting is already done within the layer norm. - # Casting back to float16 here modifies results - # hidden_states = hidden_states.to(dtype=self.wo_cast[1]) - return hidden_states - - -class T5DenseGatedActDense(nn.Module): - def __init__(self, config: T5Config, prefix, weights): - super().__init__() - self.wi_0 = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.wi_0", weights=weights, bias=False - ) - self.wi_1 = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.wi_1", weights=weights, bias=False - ) - ### XXX: T5 models do not handle well both f16 and quantization. - ### Overidding specifically this layer for that reason. 
- ### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316 - ### https://github.com/huggingface/transformers/issues/20287 - _q = config.quantize - _dtype = weights.dtype - weights.dtype = torch.float32 - config.quantize = None - self.wo_cast = (torch.float32, _dtype) - self.wo = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.wo", weights=weights, bias=False - ) - weights.dtype = _dtype - config.quantize = _q - - self.dropout = nn.Dropout(config.dropout_rate) - self.act = ( - ACT2FN[config.dense_act_fn] - if "gelu" not in config.dense_act_fn - else lambda x: torch.nn.functional.gelu(x, approximate="tanh") - ) - - def forward(self, hidden_states): - hidden_gelu = self.act(self.wi_0(hidden_states)) - hidden_linear = self.wi_1(hidden_states) - hidden_states = hidden_gelu * hidden_linear - hidden_states = self.dropout(hidden_states) - - hidden_states = hidden_states.to(dtype=self.wo_cast[0]) - hidden_states = self.wo(hidden_states) - # XXX: Recasting is already done within the layer norm. - # Casting back to float16 here modifies results - # hidden_states = hidden_states.to(dtype=self.wo_cast[1]) - return hidden_states - - -class T5LayerFF(nn.Module): - def __init__(self, config: T5Config, prefix, weights): - super().__init__() - if config.is_gated_act: - self.DenseReluDense = T5DenseGatedActDense( - config, prefix=f"{prefix}.DenseReluDense", weights=weights - ) - else: - self.DenseReluDense = T5DenseActDense( - config, prefix=f"{prefix}.DenseReluDense", weights=weights - ) - - self.layer_norm = T5LayerNorm( - prefix=f"{prefix}.layer_norm", - weights=weights, - eps=config.layer_norm_epsilon, - ) - self.dropout = nn.Dropout(config.dropout_rate) - - def forward(self, hidden_states): - forwarded_states = self.layer_norm(hidden_states) - forwarded_states = self.DenseReluDense(forwarded_states) - hidden_states = hidden_states + self.dropout(forwarded_states) - return hidden_states - - -class T5Attention(nn.Module): - def __init__( - self, config: T5Config, prefix, weights, has_relative_attention_bias=False - ): - super().__init__() - self.is_decoder = config.is_decoder - self.has_relative_attention_bias = has_relative_attention_bias - self.relative_attention_num_buckets = config.relative_attention_num_buckets - self.relative_attention_max_distance = config.relative_attention_max_distance - self.d_model = config.d_model - self.key_value_proj_dim = config.d_kv - self.n_heads = config.num_heads - self.dropout = config.dropout_rate - self.inner_dim = self.n_heads * self.key_value_proj_dim - - process_group = weights.process_group - # Mesh TensorFlow initialization to avoid scaling before softmax - assert self.n_heads % process_group.size() == 0 - self.q = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.q", weights=weights, bias=False - ) - self.k = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.k", weights=weights, bias=False - ) - self.v = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.v", weights=weights, bias=False - ) - self.o = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.o", weights=weights, bias=False - ) - if self.n_heads % weights.process_group.size() != 0: - raise ValueError( - f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} " - f"and `num_shards`: {weights.process_group.size()}" - ) - self.n_heads = self.n_heads // process_group.size() - self.inner_dim = self.inner_dim // process_group.size() - - if self.has_relative_attention_bias: - 
self.relative_attention_bias = PartialTPEmbedding( - prefix=f"{prefix}.relative_attention_bias", weights=weights - ) - - @staticmethod - def _relative_position_bucket( - relative_position, bidirectional=True, num_buckets=32, max_distance=128 - ): - """ - Adapted from Mesh Tensorflow: - https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 - - Translate relative position to a bucket number for relative attention. The relative position is defined as - memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to - position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for - small absolute relative_position and larger buckets for larger absolute relative_positions. All relative - positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. - This should allow for more graceful generalization to longer sequences than the model has been trained on - - Args: - relative_position: an int32 Tensor - bidirectional: a boolean - whether the attention is bidirectional - num_buckets: an integer - max_distance: an integer - - Returns: - a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) - """ - relative_buckets = 0 - if bidirectional: - num_buckets //= 2 - relative_buckets += (relative_position > 0).to(torch.long) * num_buckets - relative_position = torch.abs(relative_position) - else: - relative_position = -torch.min( - relative_position, torch.zeros_like(relative_position) - ) - # now relative_position is in the range [0, inf) - - # half of the buckets are for exact increments in positions - max_exact = num_buckets // 2 - is_small = relative_position < max_exact - - # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance - relative_position_if_large = max_exact + ( - torch.log(relative_position.float() / max_exact) - / math.log(max_distance / max_exact) - * (num_buckets - max_exact) - ).to(torch.long) - relative_position_if_large = torch.min( - relative_position_if_large, - torch.full_like(relative_position_if_large, num_buckets - 1), - ) - - relative_buckets += torch.where( - is_small, relative_position, relative_position_if_large - ) - return relative_buckets - - def compute_bias(self, query_length, key_length, device=None): - """Compute binned relative position bias""" - if device is None: - device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[ - :, None - ] - memory_position = torch.arange(key_length, dtype=torch.long, device=device)[ - None, : - ] - relative_position = ( - memory_position - context_position - ) # shape (query_length, key_length) - relative_position_bucket = self._relative_position_bucket( - relative_position, # shape (query_length, key_length) - bidirectional=(not self.is_decoder), - num_buckets=self.relative_attention_num_buckets, - max_distance=self.relative_attention_max_distance, - ) - values = self.relative_attention_bias( - relative_position_bucket - ) # shape (query_length, key_length, num_heads) - values = values.permute([2, 0, 1]).unsqueeze( - 0 - ) # shape (1, num_heads, query_length, key_length) - return values - - def forward( - self, - hidden_states, - mask=None, - key_value_states=None, - position_bias=None, - past_key_value=None, - layer_head_mask=None, - 
query_length=None, - use_cache=False, - output_attentions=False, - ): - """ - Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). - """ - # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) - - batch_size, seq_length = hidden_states.shape[:2] - - real_seq_length = seq_length - - if past_key_value is not None: - assert ( - len(past_key_value) == 2 - ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states" - real_seq_length += ( - past_key_value[0].shape[2] if query_length is None else query_length - ) - - key_length = ( - real_seq_length if key_value_states is None else key_value_states.shape[1] - ) - - def shape(states): - """projection""" - return states.view( - batch_size, -1, self.n_heads, self.key_value_proj_dim - ).transpose(1, 2) - - def unshape(states): - """reshape""" - return ( - states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - ) - - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - - if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - - # get query states - query_states = shape( - self.q(hidden_states) - ) # (batch_size, n_heads, seq_length, dim_per_head) - - # get key/value states - key_states = project( - hidden_states, - self.k, - key_value_states, - past_key_value[0] if past_key_value is not None else None, - ) - value_states = project( - hidden_states, - self.v, - key_value_states, - past_key_value[1] if past_key_value is not None else None, - ) - - # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 - - if position_bias is None: - if not self.has_relative_attention_bias: - position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), - device=scores.device, - dtype=scores.dtype, - ) - else: - position_bias = self.compute_bias( - real_seq_length, key_length, device=scores.device - ) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] - - if mask is not None: - position_bias = ( - position_bias + mask - ) # (batch_size, n_heads, seq_length, key_length) - - position_bias_masked = position_bias - - scores += position_bias_masked - attn_weights = nn.functional.softmax(scores.float(), 
dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, n_heads, seq_length, key_length) - - # Mask heads if we want to - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask - - attn_output = unshape( - torch.matmul(attn_weights, value_states) - ) # (batch_size, seq_length, dim) - attn_output = self.o(attn_output) - - present_key_value_state = ( - (key_states, value_states) if (self.is_decoder and use_cache) else None - ) - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) - - if output_attentions: - outputs = outputs + (attn_weights,) - return outputs - - -class T5LayerSelfAttention(nn.Module): - def __init__(self, config, prefix, weights, has_relative_attention_bias=False): - super().__init__() - self.SelfAttention = T5Attention( - config, - prefix=f"{prefix}.SelfAttention", - weights=weights, - has_relative_attention_bias=has_relative_attention_bias, - ) - self.layer_norm = T5LayerNorm( - prefix=f"{prefix}.layer_norm", - weights=weights, - eps=config.layer_norm_epsilon, - ) - self.dropout = nn.Dropout(config.dropout_rate) - - def forward( - self, - hidden_states, - attention_mask=None, - position_bias=None, - layer_head_mask=None, - past_key_value=None, - use_cache=False, - output_attentions=False, - ): - normed_hidden_states = self.layer_norm(hidden_states) - attention_output = self.SelfAttention( - normed_hidden_states, - mask=attention_mask, - position_bias=position_bias, - layer_head_mask=layer_head_mask, - past_key_value=past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - ) - hidden_states = hidden_states + self.dropout(attention_output[0]) - outputs = (hidden_states,) + attention_output[ - 1: - ] # add attentions if we output them - return outputs - - -class T5LayerCrossAttention(nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - self.EncDecAttention = T5Attention( - config, - prefix=f"{prefix}.EncDecAttention", - weights=weights, - has_relative_attention_bias=False, - ) - self.layer_norm = T5LayerNorm( - prefix=f"{prefix}.layer_norm", - weights=weights, - eps=config.layer_norm_epsilon, - ) - self.dropout = nn.Dropout(config.dropout_rate) - - def forward( - self, - hidden_states, - key_value_states, - attention_mask=None, - position_bias=None, - layer_head_mask=None, - past_key_value=None, - use_cache=False, - query_length=None, - output_attentions=False, - ): - normed_hidden_states = self.layer_norm(hidden_states) - attention_output = self.EncDecAttention( - normed_hidden_states, - mask=attention_mask, - key_value_states=key_value_states, - position_bias=position_bias, - layer_head_mask=layer_head_mask, - past_key_value=past_key_value, - use_cache=use_cache, - query_length=query_length, - output_attentions=output_attentions, - ) - layer_output = hidden_states + self.dropout(attention_output[0]) - outputs = (layer_output,) + attention_output[ - 1: - ] # add attentions if we output them - return outputs - - -class T5Block(nn.Module): - def __init__(self, config, prefix, weights, has_relative_attention_bias: bool): - super().__init__() - self.is_decoder = config.is_decoder - self.layer = nn.ModuleList() - self.layer.append( - T5LayerSelfAttention( - config, - prefix=f"{prefix}.layer.0", - weights=weights, - has_relative_attention_bias=has_relative_attention_bias, - ) - ) - if self.is_decoder: - i = 2 - self.layer.append( - 
T5LayerCrossAttention( - config, prefix=f"{prefix}.layer.1", weights=weights - ) - ) - else: - i = 1 - - self.layer.append( - T5LayerFF(config, prefix=f"{prefix}.layer.{i}", weights=weights) - ) - - def forward( - self, - hidden_states, - attention_mask=None, - position_bias=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - encoder_decoder_position_bias=None, - layer_head_mask=None, - cross_attn_layer_head_mask=None, - past_key_value=None, - use_cache=False, - output_attentions=False, - return_dict=True, - ): - if past_key_value is not None: - if not self.is_decoder: - logger.warning( - "`past_key_values` is passed to the encoder. Please make sure this is intended." - ) - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - - self_attention_outputs = self.layer[0]( - hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, - layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - ) - hidden_states, present_key_value_state = self_attention_outputs[:2] - attention_outputs = self_attention_outputs[ - 2: - ] # Keep self-attention outputs and relative position weights - - # clamp inf values to enable fp16 training - if hidden_states.dtype == torch.float16: - clamp_value = torch.where( - torch.isinf(hidden_states).any(), - torch.finfo(hidden_states.dtype).max - 1000, - torch.finfo(hidden_states.dtype).max, - ) - hidden_states = torch.clamp( - hidden_states, min=-clamp_value, max=clamp_value - ) - - do_cross_attention = self.is_decoder and encoder_hidden_states is not None - if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. 
Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None - - cross_attention_outputs = self.layer[1]( - hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - position_bias=encoder_decoder_position_bias, - layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - query_length=query_length, - use_cache=use_cache, - output_attentions=output_attentions, - ) - hidden_states = cross_attention_outputs[0] - - # clamp inf values to enable fp16 training - if hidden_states.dtype == torch.float16: - clamp_value = torch.where( - torch.isinf(hidden_states).any(), - torch.finfo(hidden_states.dtype).max - 1000, - torch.finfo(hidden_states.dtype).max, - ) - hidden_states = torch.clamp( - hidden_states, min=-clamp_value, max=clamp_value - ) - - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = ( - present_key_value_state + cross_attention_outputs[1] - ) - - # Keep cross-attention outputs and relative position weights - attention_outputs = attention_outputs + cross_attention_outputs[2:] - - # Apply Feed Forward layer - hidden_states = self.layer[-1](hidden_states) - - # clamp inf values to enable fp16 training - if hidden_states.dtype == torch.float16: - clamp_value = torch.where( - torch.isinf(hidden_states).any(), - torch.finfo(hidden_states.dtype).max - 1000, - torch.finfo(hidden_states.dtype).max, - ) - hidden_states = torch.clamp( - hidden_states, min=-clamp_value, max=clamp_value - ) - - outputs = (hidden_states,) - - if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs - else: - outputs = outputs + attention_outputs - - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) - - -class T5PreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = T5Config - - def _shift_right(self, input_ids): - decoder_start_token_id = self.config.decoder_start_token_id - pad_token_id = self.config.pad_token_id - - assert decoder_start_token_id is not None, ( - "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id." - " See T5 docs for more information" - ) - - # shift inputs to the right - if is_torch_fx_proxy(input_ids): - # Item assignment is not supported natively for proxies. - shifted_input_ids = torch.full( - input_ids.shape[:-1] + (1,), decoder_start_token_id - ) - shifted_input_ids = torch.cat( - [shifted_input_ids, input_ids[..., :-1]], dim=-1 - ) - else: - shifted_input_ids = input_ids.new_zeros(input_ids.shape) - shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() - shifted_input_ids[..., 0] = decoder_start_token_id - - assert ( - pad_token_id is not None - ), "self.model.config.pad_token_id has to be defined." 
- # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) - - return shifted_input_ids - - -class T5Stack(T5PreTrainedModel): - def __init__(self, config, prefix, weights, embed_tokens): - super().__init__(config) - - self.is_decoder = config.is_decoder - - self.embed_tokens = embed_tokens - self.block = nn.ModuleList( - [ - T5Block( - config, - prefix=f"{prefix}.block.{layer_id}", - weights=weights, - has_relative_attention_bias=(layer_id == 0), - ) - for layer_id in range(config.num_layers) - ] - ) - self.final_layer_norm = T5LayerNorm( - prefix=f"{prefix}.final_layer_norm", - weights=weights, - eps=config.layer_norm_epsilon, - ) - self.dropout = nn.Dropout(config.dropout_rate) - - def forward( - self, - input_ids=None, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - inputs_embeds=None, - head_mask=None, - cross_attn_head_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - # Model parallel - use_cache = use_cache if use_cache is not None else self.config.use_cache - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - if input_ids is not None and inputs_embeds is not None: - err_msg_prefix = "decoder_" if self.is_decoder else "" - raise ValueError( - f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" - ) - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - err_msg_prefix = "decoder_" if self.is_decoder else "" - raise ValueError( - f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds" - ) - - if inputs_embeds is None: - assert ( - self.embed_tokens is not None - ), "You have to initialize the model with valid token embeddings" - inputs_embeds = self.embed_tokens(input_ids) - - batch_size, seq_length = input_shape - - # required mask seq length can be calculated via length of past - mask_seq_length = ( - past_key_values[0][0].shape[2] + seq_length - if past_key_values is not None - else seq_length - ) - - if use_cache is True: - assert ( - self.is_decoder - ), f"`use_cache` can only be set to `True` if {self} is used as a decoder" - - if attention_mask is None: - attention_mask = torch.ones( - batch_size, mask_seq_length, device=inputs_embeds.device - ) - if ( - self.is_decoder - and encoder_attention_mask is None - and encoder_hidden_states is not None - ): - encoder_seq_length = encoder_hidden_states.shape[1] - encoder_attention_mask = torch.ones( - batch_size, - encoder_seq_length, - device=inputs_embeds.device, - dtype=torch.long, - ) - - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. 
- extended_attention_mask = self.get_extended_attention_mask( - attention_mask, input_shape - ) - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.is_decoder and encoder_hidden_states is not None: - ( - encoder_batch_size, - encoder_sequence_length, - _, - ) = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones( - encoder_hidden_shape, device=inputs_embeds.device - ) - encoder_extended_attention_mask = self.invert_attention_mask( - encoder_attention_mask - ) - else: - encoder_extended_attention_mask = None - - # Prepare head mask if needed - head_mask = self.get_head_mask(head_mask, self.config.num_layers) - cross_attn_head_mask = self.get_head_mask( - cross_attn_head_mask, self.config.num_layers - ) - present_key_value_states = () if use_cache else None - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - all_cross_attentions = () if (output_attentions and self.is_decoder) else None - position_bias = None - encoder_decoder_position_bias = None - - hidden_states = self.dropout(inputs_embeds) - - for i, (layer_module, past_key_value) in enumerate( - zip(self.block, past_key_values) - ): - layer_head_mask = head_mask[i] - cross_attn_layer_head_mask = cross_attn_head_mask[i] - # Model parallel - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_outputs = layer_module( - hidden_states, - attention_mask=extended_attention_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - layer_head_mask=layer_head_mask, - cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - # layer_outputs is a tuple with: - # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) - if use_cache is False: - layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - - hidden_states, present_key_value_state = layer_outputs[:2] - - # We share the position biases between the layers - the first layer store them - # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), - # (cross-attention position bias), (cross-attention weights) - position_bias = layer_outputs[2] - if self.is_decoder and encoder_hidden_states is not None: - encoder_decoder_position_bias = layer_outputs[ - 4 if output_attentions else 3 - ] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + ( - present_key_value_state, - ) - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[3],) - if self.is_decoder: - all_cross_attentions = all_cross_attentions + (layer_outputs[5],) - - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.dropout(hidden_states) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - present_key_value_states, - all_hidden_states, - all_attentions, - all_cross_attentions, - ] 
- if v is not None - ) - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=present_key_value_states, - hidden_states=all_hidden_states, - attentions=all_attentions, - cross_attentions=all_cross_attentions, - ) - - -class T5ForConditionalGeneration(T5PreTrainedModel): - def __init__(self, config: T5Config, weights): - super().__init__(config) - self.model_dim = config.d_model - - self.shared = TensorParallelEmbedding(prefix="shared", weights=weights) - - encoder_config = copy.deepcopy(config) - encoder_config.is_decoder = False - encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False - self.encoder = T5Stack( - config=encoder_config, - prefix="encoder", - weights=weights, - embed_tokens=self.shared, - ) - - decoder_config = copy.deepcopy(config) - decoder_config.is_decoder = True - decoder_config.is_encoder_decoder = False - decoder_config.num_layers = config.num_decoder_layers - self.decoder = T5Stack( - config=decoder_config, - prefix="decoder", - weights=weights, - embed_tokens=self.shared, - ) - - try: - self.lm_head = SpeculativeHead.load( - config, prefix="lm_head", weights=weights - ) - except RuntimeError: - # Some models like t5-small were saved with shared weights unlike flan - # Since they are declared as the same arch we have no choice but hope - # that this is OK instead of using a proper flag. - self.lm_head = SpeculativeHead.load( - config, prefix="shared", weights=weights - ) - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.BoolTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - decoder_head_mask: Optional[torch.FloatTensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - decoder_inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask - if head_mask is not None and decoder_head_mask is None: - if self.config.num_layers == self.config.num_decoder_layers: - warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) - decoder_head_mask = head_mask - - # Encode if needed (training, first prediction pass) - if encoder_outputs is None: - # Convert encoder inputs in embeddings if needed - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): - encoder_outputs = BaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if 
len(encoder_outputs) > 2 else None, - ) - - hidden_states = encoder_outputs[0] - - if ( - labels is not None - and decoder_input_ids is None - and decoder_inputs_embeds is None - ): - # get decoder inputs from shifting lm labels to the right - decoder_input_ids = self._shift_right(labels) - - # Decode - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - inputs_embeds=decoder_inputs_embeds, - past_key_values=past_key_values, - encoder_hidden_states=hidden_states, - encoder_attention_mask=attention_mask, - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = decoder_outputs[0] - - if self.config.tie_word_embeddings: - # Rescale output before projecting on vocab - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 - sequence_output = sequence_output * (self.model_dim**-0.5) - - logits, speculative_logits = self.lm_head(sequence_output) - - loss = None - if labels is not None: - loss_fct = CrossEntropyLoss(ignore_index=-100) - # move labels to correct device to enable PP - labels = labels.to(logits.device) - loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) - # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 - - if not return_dict: - output = (logits,) + decoder_outputs[1:] + encoder_outputs - return ((loss,) + output) if loss is not None else output - - return ( - Seq2SeqLMOutput( - loss=loss, - logits=logits, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ), - speculative_logits, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - decoder_attention_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past is used - if past_key_values is not None: - input_ids = input_ids[:, -1:] - - return { - "decoder_input_ids": input_ids, - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "decoder_attention_mask": decoder_attention_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, - } - - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): - return self._shift_right(labels) - - def _reorder_cache(self, past_key_values, beam_idx): - # if decoder past is not included in output - # speedy decoding is disabled and no need to reorder - if past_key_values is None: - logger.warning( - "You might want to consider setting `use_cache=True` to speed up decoding" - ) - return past_key_values - - reordered_decoder_past = () - for layer_past_states in past_key_values: - # get the correct batch idx from layer past batch dim - # batch dim of `past` is at 2nd position - reordered_layer_past_states = () - for layer_past_state in 
layer_past_states: - # need to set correct `past` for each of the four key / value states - reordered_layer_past_states = reordered_layer_past_states + ( - layer_past_state.index_select( - 0, beam_idx.to(layer_past_state.device) - ), - ) - - assert reordered_layer_past_states[0].shape == layer_past_states[0].shape - assert len(reordered_layer_past_states) == len(layer_past_states) - - reordered_decoder_past = reordered_decoder_past + ( - reordered_layer_past_states, - ) - return reordered_decoder_past diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py index e5c44045a..ae704af31 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py @@ -4,7 +4,7 @@ def load_text_model(prefix, config, weights, name=None): FlashLlamaForCausalLM, ) - return FlashLlamaForCausalLM(prefix, config, weights) + return FlashLlamaForCausalLM(prefix, config, weights, name=name) elif config.model_type == "mistral": from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( FlashMistralForCausalLM, @@ -16,7 +16,13 @@ def load_text_model(prefix, config, weights, name=None): FlashGemmaForCausalLM, ) - return FlashGemmaForCausalLM(prefix, config, weights, causal=False) + return FlashGemmaForCausalLM(prefix, config, weights) + elif config.model_type == "gemma2": + from text_generation_server.models.custom_modeling.flash_gemma2_modeling import ( + FlashGemma2ForCausalLM, + ) + + return FlashGemma2ForCausalLM(prefix, config, weights) elif config.model_type == "paligemma": from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( FlashGemmaForCausalLM, diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index bc9d44a0b..a4d58596b 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -1,4 +1,3 @@ -from contextlib import nullcontext import math import os import time @@ -16,12 +15,19 @@ from transformers import ( AutoTokenizer, GenerationConfig, ) -from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict - +from typing import ( + Any, + Iterable, + Optional, + Tuple, + List, + Type, + Dict, + Union, +) +import torch.nn.functional as F from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata -from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from text_generation_server.utils.chunks import concat_text_chunks -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models import Model from text_generation_server.utils.log import log_master from text_generation_server.utils.tokens import batch_top_tokens @@ -39,27 +45,34 @@ from text_generation_server.models.types import ( ) from text_generation_server.pb import generate_pb2 from text_generation_server.models.globals import ( - MEM_POOL, - ATTENTION, BLOCK_SIZE, - CUDA_GRAPHS, + REQUEST_LOGPROBS, TGI_WIGGLE_ROOM, get_adapter_to_index, ) -from text_generation_server.layers.attention import Seqlen +from text_generation_server.layers.attention import ( + KVCache, + Seqlen, + HPUPagedAttentionMetadata, + trim_attn_metadata, + trim_seqlen_metadata, +) from text_generation_server.utils import StoppingCriteria, 
HeterogeneousNextTokenChooser from text_generation_server.utils.dist import MEMORY_FRACTION from text_generation_server.utils.quantization import get_loader from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments - from text_generation_server.utils.import_utils import ( empty_cache, synchronize, get_free_memory, ) -tracer = trace.get_tracer(__name__) +import vllm_hpu_extension.environment as environment +import habana_frameworks.torch as htorch +import itertools +from vllm_hpu_extension.ops import batch2block, block2batch +tracer = trace.get_tracer(__name__) # Will be set in init SLIDING_WINDOW: Optional[int] = None @@ -75,38 +88,75 @@ def get_sliding_windows() -> int: return SLIDING_WINDOW -def init_cpu_threads_env(rank_id: int, world_size: int): - import importlib.util +def prepare_for_decode( + dtype, use_contiguous_pa, device, slot, block_tables, batch_size +): + # Prepare values if we need to continue decoding + # needed for HPUPagedAttentionMetadata preparation + def flatten(in_list): + return list(itertools.chain(*in_list)) - if importlib.util.find_spec("numa") is not None: - import numa - import psutil + def gather_list(input, indices, v): + return [input[i] if i is not None else v for i in indices] - nodes = numa.info.get_max_node() + 1 - rank_per_node = math.ceil(world_size / nodes) - num_cpus_per_nodes = int(psutil.cpu_count(logical=False) / nodes) - node_id = int(rank_id / rank_per_node) - rank_offset_per_node = rank_id % rank_per_node - if os.getenv("OMP_NUM_THREADS") is None: - num_cpus_per_rank = max(int(num_cpus_per_nodes / rank_per_node), 1) - else: - num_cpus_per_rank = int(os.getenv("OMP_NUM_THREADS")) - if len(numa.memory.get_membind_nodes()) == nodes: - numa.memory.set_membind_nodes((node_id)) - torch.set_num_threads(num_cpus_per_rank) - if len(numa.schedule.get_affinitive_cpus(0)) == psutil.cpu_count(logical=True): - cpu_start = num_cpus_per_rank * rank_offset_per_node - numa.schedule.run_on_cpus( - 0, - *( - numa.info.node_to_cpus(node_id)[ - cpu_start : cpu_start + num_cpus_per_rank - ] - ), - ) - logger.info( - f"affinity={numa.schedule.get_affinitive_cpus(0)}, membind = {numa.memory.get_membind_nodes()}" + def pad_list(input, k, v): + input_len = len(input) + target_len = (input_len + k - 1) // k * k + padding = target_len - input_len + return input + [v] * padding + + last_block_usage = slot % BLOCK_SIZE + 1 + block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)] + block_usage = [ + [BLOCK_SIZE] * (len(bt) - 1) + [lbu] + for bt, lbu in zip(block_tables, last_block_usage) + if bt + ] + + block_list = flatten(block_tables) + block_groups = flatten(block_groups) + block_usage = flatten(block_usage) + assert len(block_list) == len(block_groups) + assert len(block_list) == len(block_usage) + if use_contiguous_pa: + block_bucket_size = max(max(block_list) + 1, len(block_list)) + # block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks( + # block_bucket_size) + indices: List[Any] + indices = [None] * block_bucket_size + for i, bid in enumerate(block_list): + indices[bid] = i + block_list = gather_list(block_list, indices, 0) + block_groups = gather_list(block_groups, indices, -1) + block_usage = gather_list(block_usage, indices, 1) + else: + block_bucket_size = len(block_list) + block_list = pad_list(block_list, block_bucket_size, 0) + block_groups = pad_list(block_groups, block_bucket_size, -1) + block_usage = pad_list(block_usage, block_bucket_size, 1) + + block_list = torch.tensor(block_list, dtype=torch.int,
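+ # materialize the padded, bucketed metadata as fixed-shape device tensors; keeping these shapes bucketed (rather than exact) presumably lets HPU graph replay avoid dynamic-shape recompilation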
device=device) + block_groups = torch.tensor(block_groups, dtype=torch.int, device=device) + block_usage = torch.tensor(block_usage, dtype=dtype, device=device) + block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch_size) + mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0) + mask = mask >= block_usage.unsqueeze(-1) + attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf) + ones = torch.ones( + (block_mapping.size(0),), device=device, dtype=block_mapping.dtype + ) + sums = batch2block(block2batch(ones, block_mapping), block_mapping) + block_scales = torch.reciprocal(torch.maximum(ones, sums)) + return trim_attn_metadata( + HPUPagedAttentionMetadata( + block_list=block_list, + block_groups=block_groups, + block_usage=block_usage, + block_mapping=block_mapping.to(dtype), + attn_bias=attn_bias, + block_scales=block_scales, ) + ) @dataclass @@ -117,25 +167,17 @@ class FlashCausalLMBatch(Batch): requests_idx_mapping: Dict[int, int] # Decoder values - input_ids: torch.Tensor - position_ids: torch.Tensor + # Can be a list for easy filtering + # If `input_ids` is a list, it needs to be materialized to a tensor first + input_ids: Union[torch.Tensor, List[List[int]]] + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + position_ids: Optional[torch.Tensor] speculative_ids: Optional[torch.Tensor] - # Flash Attention values - - # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill - cu_seqlen_prefill: Optional[torch.Tensor] - # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers - # as we only keep SLIDING_WINDOW values instead of the whole tensor - prefill_cache_indices: Optional[torch.Tensor] - - # Paged Attention values - # Set when creating the batch - # CPU tensor of length b indicating the start of each sequence in slots - start_slots: torch.Tensor # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode - slot_indices: torch.Tensor + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + slot_indices: Optional[torch.Tensor] # list of length b of list of length s_i // block_size block_tables: List[List[int]] @@ -143,19 +185,32 @@ class FlashCausalLMBatch(Batch): block_tables_tensor: torch.Tensor # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences slots: torch.Tensor - # size [b], containing the number of blocks that can be retrieved from the cache - prefix_lens: List[int] - prefix_lens_tensor: torch.Tensor + # list of length b + 1 containing the cumulative sequence slot lengths of the sequences in the batch + # used for filtering + cu_slots: torch.Tensor - max_seqlen: int + max_input_length: int + max_current_length: int + + # Whether this batch contains at least one request that is prefilling + prefilling: bool + # Whether each request is prefilling + prefilling_mask: List[bool] # Prefill metadata tensors to efficiently compute logprobs + # tensor of length b + 1 containing the cumulative sequence lengths of the sequences in the batch, only used in prefill + cu_seqlen_prefill: Optional[torch.Tensor] + # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers + # as we only keep SLIDING_WINDOW values instead of the whole tensor + prefill_cache_indices: 
Optional[torch.Tensor] + # Will be set by `generate_token` and reset after each prefill forward prefill_head_indices: Optional[torch.Tensor] + # Will be set by `generate_token` and reset after each prefill forward prefill_next_token_indices: Optional[torch.tensor] + # Will be set by `generate_token` and reset after each prefill forward prefill_cu_outlens: Optional[List[int]] - - # Prefixes - prefix_ids: List[List[int]] + # Will be set by `generate_token` and reset after each prefill forward + prefill_logprob_tokens: List[Optional[Tokens]] # All tokens all_input_ids: List[List[int]] @@ -163,7 +218,14 @@ class FlashCausalLMBatch(Batch): # Lengths of all generations present in the batch input_lengths: List[int] - input_lengths_tensor: torch.Tensor + # size [b], containing the number of blocks that can be retrieved from the cache + cache_lengths: List[int] + prompt_lengths: List[int] + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + input_lengths_tensor: Optional[torch.Tensor] + cache_lengths_tensor: Optional[torch.Tensor] + prompt_lengths_tensor: torch.Tensor + prefix_offsets: List[Optional[int]] read_offsets: List[Optional[int]] @@ -174,19 +236,27 @@ class FlashCausalLMBatch(Batch): top_n_tokens_tensor: torch.Tensor # Adapter metadata for each request - adapter_meta: AdapterBatchMetadata + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + adapter_meta: Optional[AdapterBatchMetadata] # Number of blocks in this batch num_blocks: int # Maximum number of blocks max_blocks: int + hpu_attn_meta: Optional[HPUPagedAttentionMetadata] + def to_pb(self) -> generate_pb2.CachedBatch: return generate_pb2.CachedBatch( id=self.batch_id, request_ids=[r.id for r in self.requests], size=len(self), max_tokens=self.num_blocks * BLOCK_SIZE, + current_tokens=( + sum([len(i) for i in self.input_ids]) + if isinstance(self.input_ids, list) + else len(self.input_ids) + ), ) @classmethod @@ -218,86 +288,67 @@ class FlashCausalLMBatch(Batch): dtype: torch.dtype, device: torch.device, ) -> "FlashCausalLMBatch": - sliding_window = get_sliding_windows() - position_ids = [] - cu_seqlen_prefill = [0] - start_slots = [] - slot_indices = [] - prefill_cache_indices = [] - + cache_lengths = [] input_lengths = [] + prompt_lengths = [] prefix_offsets = [] read_offsets = [] all_input_ids = [] - prefix_ids = [] + all_postfix_ids = [] requests_idx_mapping = {} - - all_prefill_logprobs = True - no_prefill_logprobs = True - prefill_head_indices = [] - prefill_next_token_indices = [] - prefill_cu_outlens = [0] + slots = [] + cu_slots = [0] next_token_chooser_parameters = [] stopping_criterias = [] top_n_tokens = [] - adapter_indices_list = [] - adapter_set = set() - - # Cumulative length - cumulative_length = 0 - cumulative_slot_tokens = 0 - prefill_out_cumulative_length = 0 - num_blocks = 0 - max_seqlen = 0 + max_input_length = 0 + max_current_length = 0 max_length = 0 max_blocks = 0 + cu_blocks = [0] block_tables = [] - slots = [] - prefix_lens = [] + block_tables_ragged = [] # Parse batch for i, (r, tokenized_input) in enumerate( zip(pb.requests, batch_tokenized_inputs) ): + ### XXX: This consumes so much memory on long requests + ### Deactivating it by default seems like the best course. 
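+ ### (`REQUEST_LOGPROBS` is the module-level flag imported from text_generation_server.models.globals; when it is unset, prefill logprobs are forced off for every request.)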
+ if not REQUEST_LOGPROBS: + r.prefill_logprobs = False # request id -> idx in list mapping requests_idx_mapping[r.id] = i - orig_input_length = len(tokenized_input) + prompt_length = len(tokenized_input) + prompt_lengths.append(prompt_length) + + cache_length = r.cache_len - prefix_len = r.prefix_len assert ( - prefix_len <= orig_input_length - ), f"Prefix {prefix_len} vs input {orig_input_length}" - if prefix_len == orig_input_length: - assert prefix_len > 0 - prefix_len -= 1 + cache_length <= prompt_length + ), f"Prefix {cache_length} vs input {prompt_length}" + if cache_length == prompt_length: + assert False, "unreachable" - # Commented as it's costly. - # log_master(logger.debug, "Tokenized input ids {tokenized_input}") - prefix_ids.append(tokenized_input[:prefix_len]) - tokenized_input = tokenized_input[prefix_len:] + # `chunk_len` is an optional field in the protobuf + # It is only set if the model supports chunking + # Use all the remaining ids + postfix_ids = tokenized_input[cache_length:] + input_length = len(postfix_ids) - input_length = len(tokenized_input) input_lengths.append(input_length) - prefix_offsets.append(input_length - 5) - read_offsets.append(input_length) + prefix_offsets.append(prompt_length - 5) + read_offsets.append(prompt_length) + all_postfix_ids.append(postfix_ids) all_input_ids.append(tokenized_input) - # Position ids - request_position_ids = torch.arange( - prefix_len, orig_input_length, dtype=torch.int32 - ) - position_ids.append(request_position_ids) - - # Add cumulative lengths of all previous inputs - cu_seqlen_prefill.append(cumulative_length + input_length) - next_token_chooser_parameters.append(r.parameters) stopping_criteria = StoppingCriteria.from_pb( @@ -307,22 +358,13 @@ class FlashCausalLMBatch(Batch): stopping_criterias.append(stopping_criteria) top_n_tokens.append(r.top_n_tokens) - ADAPTER_TO_INDEX = get_adapter_to_index() - adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0) - adapter_indices_list.append(torch.full((input_length,), adapter_index)) - adapter_set.add(adapter_index) - # Paged attention # Remove one as the first token does not have a past speculative_length = get_speculate() speculative_length = 0 if speculative_length is None else speculative_length # Tokens that need to be mapped to blocks. - block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length - - # Tokens that need to be mapped to slots. We don't need slots for the - # cached prefix (if present).
- slot_tokens = input_length + max_new_tokens - 1 + speculative_length + block_tokens = prompt_length + max_new_tokens - 1 + speculative_length # blocks and slots can be empty (for example in warmup) if not r.blocks: @@ -337,70 +379,30 @@ class FlashCausalLMBatch(Batch): ] else: request_blocks = r.blocks - request_slots = r.slots[ - prefix_len: #: orig_input_length + max_new_tokens + speculative_length - ] + request_slots = r.slots block_tables.append(request_blocks) + block_tables_ragged.extend(request_blocks) + cu_blocks.append(len(block_tables_ragged)) slots.extend(request_slots) - prefix_lens.append(prefix_len) + cu_slots.append(len(slots)) + + cache_lengths.append(cache_length) num_blocks += len(request_blocks) - start_slots.append(cumulative_slot_tokens) - - request_slot_indices = torch.arange( - cumulative_slot_tokens, - cumulative_slot_tokens + input_length, - dtype=torch.int64, - ) - slot_indices.append(request_slot_indices) - - # Create tensor to slice into the kv tensor in prefill - if sliding_window is not None: - request_prefill_cache_indices = torch.arange( - cumulative_length + max(0, input_length - sliding_window), - cumulative_length + input_length, - dtype=torch.int64, - ) - prefill_cache_indices.append(request_prefill_cache_indices) - - all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs - no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs - - if r.prefill_logprobs: - prefill_head_indices.append(request_position_ids + cumulative_length) - prefill_next_token_indices.append( - prefill_out_cumulative_length + input_length - 1 - ) - prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) - prefill_out_cumulative_length += input_length - else: - prefill_head_indices.append( - torch.tensor( - [cumulative_length + input_length - 1], dtype=torch.int32 - ) - ) - prefill_next_token_indices.append(prefill_out_cumulative_length) - prefill_cu_outlens.append(prefill_out_cumulative_length + 1) - prefill_out_cumulative_length += 1 # Update - cumulative_length += input_length - cumulative_slot_tokens += slot_tokens - max_seqlen = max(max_seqlen, input_length) max_blocks = max(max_blocks, len(request_blocks)) + max_input_length = max(max_input_length, input_length) + max_current_length = max(max_current_length, cache_length + input_length) max_length = max( - max_length, input_length + max_new_tokens + speculative_length + max_length, + prompt_length + max_new_tokens + speculative_length, ) - adapter_indices = torch.cat(adapter_indices_list).to( - dtype=torch.int64, device=device - ) - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, dtype, device, tokenizer ) - start_slots = torch.tensor(start_slots, dtype=torch.int64) # Padded all_input_ids_tensor all_input_ids_tensor = np.zeros( @@ -414,103 +416,71 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor, dtype=torch.int64, device=device ) - if len(pb.requests) > 1: - input_ids = np.concatenate(all_input_ids, dtype=np.int64) - position_ids = torch.cat(position_ids) - slot_indices = torch.cat(slot_indices) - if sliding_window is not None: - prefill_cache_indices = torch.cat(prefill_cache_indices) - else: - input_ids = all_input_ids[0] - position_ids = position_ids[0] - slot_indices = slot_indices[0] - if sliding_window is not None: - prefill_cache_indices = prefill_cache_indices[0] - - cu_seqlen_prefill = torch.tensor( - cu_seqlen_prefill, device=device, dtype=torch.int32 - ) - position_ids = position_ids.to(device) - slot_indices = 
slot_indices.to(device) - prefill_cache_indices = ( - prefill_cache_indices.to(device) if sliding_window is not None else None - ) - input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) - input_lengths_tensor = torch.tensor( - input_lengths, dtype=torch.int32, device=device - ) - - adapter_segments, adapter_segment_indices = find_segments(adapter_indices) - adapter_segments = torch.tensor( - adapter_segments, dtype=torch.int32, device=device - ) - - if all_prefill_logprobs: - prefill_head_indices = None - prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 - elif no_prefill_logprobs: - prefill_head_indices = cu_seqlen_prefill[1:] - 1 - prefill_next_token_indices = None - else: - prefill_head_indices = torch.tensor( - torch.cat(prefill_head_indices), dtype=torch.int64, device=device - ) - prefill_next_token_indices = torch.tensor( - prefill_next_token_indices, dtype=torch.int64, device=device - ) top_n_tokens_tensor = torch.tensor( top_n_tokens, device=device, dtype=torch.int64 ) - slots = torch.tensor(slots, dtype=torch.int64, device=device) - - block_tables_tensor = torch.zeros( - (len(block_tables), max_blocks), dtype=torch.int32, device="cpu" + block_tables_ragged = torch.tensor( + block_tables_ragged, device=device, dtype=torch.int32 ) + cu_blocks = torch.tensor(cu_blocks, device=device, dtype=torch.int64) + block_tables_tensor = torch.empty( + (len(block_tables), max_blocks), + device=device, + dtype=torch.int32, + ) + for i, request_blocks in enumerate(block_tables): block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) - block_tables_tensor = block_tables_tensor.to(device) - prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device) + + prompt_lengths_tensor = torch.tensor( + prompt_lengths, dtype=torch.int32, device=device + ) + + slots = torch.tensor(slots, dtype=torch.int64, device=device) + cu_slots = torch.tensor(cu_slots, dtype=torch.int64) return cls( batch_id=pb.id, requests=pb.requests, requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - prefill_cache_indices=prefill_cache_indices, - start_slots=start_slots, - slot_indices=slot_indices, + input_ids=all_postfix_ids, block_tables=block_tables, block_tables_tensor=block_tables_tensor, - slots=slots, - prefix_lens=prefix_lens, - prefix_lens_tensor=prefix_lens_tensor, - max_seqlen=max_seqlen, - prefill_head_indices=prefill_head_indices, - prefill_next_token_indices=prefill_next_token_indices, - prefill_cu_outlens=prefill_cu_outlens, + cache_lengths=cache_lengths, + max_input_length=max_input_length, + max_current_length=max_current_length, + prefilling=True, + prefilling_mask=[True] * len(pb.requests), + prefill_logprob_tokens=[None] * len(pb.requests), input_lengths=input_lengths, - input_lengths_tensor=input_lengths_tensor, + prompt_lengths=prompt_lengths, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, - prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, top_n_tokens_tensor=top_n_tokens_tensor, num_blocks=num_blocks, max_blocks=max_blocks, - adapter_meta=AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_segment_indices, - ), speculative_ids=None, + prompt_lengths_tensor=prompt_lengths_tensor, + # These values will be set by 
`FlashCausalLMBatch.prepare_for_prefill` + position_ids=None, + cu_seqlen_prefill=None, + prefill_cache_indices=None, + slot_indices=None, + slots=slots, + cu_slots=cu_slots, + prefill_head_indices=None, + prefill_next_token_indices=None, + prefill_cu_outlens=None, + cache_lengths_tensor=None, + input_lengths_tensor=None, + adapter_meta=None, + hpu_attn_meta=None, ) @classmethod @@ -533,7 +503,7 @@ class FlashCausalLMBatch(Batch): if len(request_ids) == len(self): return self - device = self.input_ids.device + device = self.block_tables_tensor.device # New values after filtering requests_idx_mapping = {} @@ -548,18 +518,23 @@ class FlashCausalLMBatch(Batch): # Create on CPU to only move to GPU once instead of at every copy slot_indices = torch.empty(len(request_ids), dtype=torch.int64) - max_seqlen = 0 + max_input_length = 0 + max_current_length = 0 requests = [] - start_slots = [] block_tables = [] all_input_ids = [] - prefix_ids = [] + input_ids = [] + prompt_lengths = [] input_lengths = [] - prefix_lens = [] + cache_lengths = [] prefix_offsets = [] read_offsets = [] + cu_slots = [0] + + prefilling_mask = [] + prefill_logprob_tokens = [] stopping_criterias = [] top_n_tokens = [] @@ -567,8 +542,8 @@ class FlashCausalLMBatch(Batch): num_blocks = 0 max_blocks = 0 - # Cumulative length - cumulative_max_length = 0 + max_slots = 0 + cumulative_slot_tokens = 0 for i, request_id in enumerate(request_ids): idx = self.requests_idx_mapping[request_id] @@ -577,16 +552,23 @@ class FlashCausalLMBatch(Batch): requests.append(self.requests[idx]) + # Prefilling + request_prefilling = self.prefilling_mask[idx] + prefilling_mask.append(request_prefilling) + # Get length request_input_length = self.input_lengths[idx] - prefix_len = self.prefix_lens[idx] - max_seqlen = max(max_seqlen, request_input_length) + request_cache_length = self.cache_lengths[idx] + max_input_length = max(max_input_length, request_input_length) + max_current_length = max( + max_current_length, request_cache_length + request_input_length + ) all_input_ids.append(self.all_input_ids[idx]) - prefix_ids.append(self.prefix_ids[idx]) + prompt_lengths.append(self.prompt_lengths[idx]) input_lengths.append(request_input_length) - prefix_lens.append(prefix_len) + cache_lengths.append(request_cache_length) prefix_offsets.append(self.prefix_offsets[idx]) read_offsets.append(self.read_offsets[idx]) @@ -594,60 +576,78 @@ class FlashCausalLMBatch(Batch): stopping_criterias.append(stopping_criteria) top_n_tokens.append(self.top_n_tokens[idx]) + prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx]) ADAPTER_TO_INDEX = get_adapter_to_index() adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0) adapter_set.add(adapter_index) - remaining_tokens = ( - stopping_criteria.max_new_tokens - stopping_criteria.current_tokens - ) - request_block_table = self.block_tables[idx] num_blocks += len(request_block_table) block_tables.append(request_block_table) - start_slots.append(cumulative_max_length) - # Copy to tensor (CPU) - slot_indices[i] = cumulative_max_length + request_input_length - 1 + start_slot = self.cu_slots[idx] + end_slot = self.cu_slots[idx + 1] + slot_length = end_slot - start_slot # Set slice - slot_filtering_indices[ - self.start_slots[idx] : self.start_slots[idx] - + request_input_length - + remaining_tokens - - 1 - ] = True + slot_filtering_indices[start_slot:end_slot] = True - cumulative_max_length += request_input_length + remaining_tokens - 1 + cu_slots.append(cumulative_slot_tokens + slot_length) + # Input ids if 
the request was part of a prefilling batch + # If the batch was decoding we can index into the tensor directly later + if self.prefilling: + input_ids.append(self.input_ids[idx]) + else: + # Copy to tensor (CPU) + slot_indices[i] = cumulative_slot_tokens + request_cache_length + + cumulative_slot_tokens += slot_length max_blocks = max(max_blocks, len(request_block_table)) + max_slots = max(max_slots, slot_length) - # Index into tensors - input_ids = self.input_ids[indices] - position_ids = self.position_ids[indices] - adapter_indices = self.adapter_meta.adapter_indices[indices] all_input_ids_tensor = self.all_input_ids_tensor[indices] block_tables_tensor = self.block_tables_tensor[indices] - input_lengths_tensor = self.input_lengths_tensor[indices] - slots = self.slots[slot_filtering_indices] - prefix_lens_tensor = self.prefix_lens_tensor[indices] next_token_chooser = self.next_token_chooser.filter(indices) top_n_tokens_tensor = self.top_n_tokens_tensor[indices] speculative_ids = ( self.speculative_ids[indices] if self.speculative_ids is not None else None ) + prompt_lengths_tensor = self.prompt_lengths_tensor[indices] - start_slots = torch.tensor(start_slots, dtype=torch.int64) + cu_slots = torch.tensor(cu_slots, dtype=torch.int64) - # Move to GPU now that we have the whole tensor - slot_indices = slot_indices.to(device) + slots = self.slots[slot_filtering_indices] - adapter_segments, adapter_segment_indices = find_segments(adapter_indices) - adapter_segments = torch.tensor( - adapter_segments, dtype=torch.int32, device=device - ) - # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum() + if self.prefilling: + # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` + position_ids = None + slot_indices = None + cache_lengths_tensor = None + input_lengths_tensor = None + adapter_meta = None + else: + # Index into tensors + input_ids = self.input_ids[indices] + position_ids = self.position_ids[indices] + adapter_indices = self.adapter_meta.adapter_indices[indices] + input_lengths_tensor = self.input_lengths_tensor[indices] + cache_lengths_tensor = self.cache_lengths_tensor[indices] + + # Move to GPU now that we have the whole tensor + slot_indices = slot_indices.to(device) + + adapter_segments, adapter_segment_indices = find_segments(adapter_indices) + adapter_segments = torch.tensor( + adapter_segments, dtype=torch.int32, device=device + ) + adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ) return type(self)( batch_id=self.batch_id, @@ -657,24 +657,29 @@ class FlashCausalLMBatch(Batch): position_ids=position_ids, cu_seqlen_prefill=None, prefill_cache_indices=None, - start_slots=start_slots, slot_indices=slot_indices, block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, - max_seqlen=max_seqlen, + cu_slots=cu_slots, + max_input_length=max_input_length, + max_current_length=max_current_length, + prefilling=self.prefilling, + prefilling_mask=prefilling_mask, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, + prefill_logprob_tokens=prefill_logprob_tokens, + prompt_lengths=prompt_lengths, + prompt_lengths_tensor=prompt_lengths_tensor, input_lengths=input_lengths, input_lengths_tensor=input_lengths_tensor, - prefix_lens=prefix_lens, - prefix_lens_tensor=prefix_lens_tensor, + cache_lengths=cache_lengths, + cache_lengths_tensor=cache_lengths_tensor, 
prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, - prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, @@ -682,12 +687,8 @@ class FlashCausalLMBatch(Batch): num_blocks=num_blocks, max_blocks=max_blocks, speculative_ids=speculative_ids, - adapter_meta=AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_segment_indices, - ), + adapter_meta=adapter_meta, + hpu_attn_meta=None, ) @classmethod @@ -697,74 +698,105 @@ class FlashCausalLMBatch(Batch): requests = [] requests_idx_mapping = {} + prefilling = False num_blocks = 0 total_batch_size = 0 total_slots = 0 max_blocks = 0 max_length = 0 - max_seqlen = 0 + max_input_length = 0 + max_current_length = 0 for b in batches: total_batch_size += len(b) + max_blocks = max(max_blocks, b.max_blocks) total_slots += len(b.slots) num_blocks += b.num_blocks speculative_length = ( b.speculative_ids.shape[1] if b.speculative_ids is not None else 0 ) - max_blocks = max(max_blocks, b.max_blocks) - max_seqlen = max(max_seqlen, b.max_seqlen) + max_input_length = max(max_input_length, b.max_input_length) + max_current_length = max(max_current_length, b.max_current_length) max_length = max( max_length, max( - input_length + prompt_length + stopping_criteria.max_new_tokens + speculative_length - - stopping_criteria.current_tokens - for input_length, stopping_criteria in zip( - b.input_lengths, b.stopping_criterias + for prompt_length, stopping_criteria in zip( + b.prompt_lengths, b.stopping_criterias ) ), ) + prefilling = prefilling or b.prefilling - input_ids = batches[0].input_ids.new_empty(total_batch_size) - position_ids = batches[0].position_ids.new_empty(total_batch_size) slots = batches[0].slots.new_empty(total_slots) - slot_indices = batches[0].slot_indices.new_empty(total_batch_size) - input_lengths_tensor = batches[0].input_lengths_tensor.new_empty( + cu_slots = torch.zeros(total_batch_size + 1, dtype=torch.int64) + if prefilling: + input_ids = [] + # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` + position_ids = None + slot_indices = None + cache_lengths_tensor = None + input_lengths_tensor = None + adapter_meta = None + adapter_segment_builder = None + else: + input_ids = batches[0].input_ids.new_empty(total_batch_size) + if ( + batches[0].position_ids is not None + and batches[0].position_ids.dim() == 2 + ): + # Qwen2_vl case: + position_ids = batches[0].position_ids.new_empty( + (total_batch_size, batches[0].position_ids.shape[-1]) + ) + else: + position_ids = batches[0].position_ids.new_empty(total_batch_size) + slot_indices = batches[0].slot_indices.new_empty(total_batch_size) + input_lengths_tensor = batches[0].input_lengths_tensor.new_empty( + total_batch_size + ) + cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty( + total_batch_size + ) + total_indices_size = sum( + b.adapter_meta.adapter_indices.shape[0] for b in batches + ) + adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty( + total_indices_size + ) + adapter_segment_builder = SegmentConcatBuilder() + adapter_set = set() + + prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty( total_batch_size ) block_tables_tensor = batches[0].block_tables_tensor.new_zeros( (total_batch_size, max_blocks) ) - prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size) 
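+        # Layout note: `slots` is the ragged concatenation of every request's slot
+        # list and `cu_slots` holds the cumulative slot counts (with a leading 0),
+        # so request `i` owns `slots[cu_slots[i] : cu_slots[i + 1]]`; e.g. slot
+        # counts [3, 2] give cu_slots == [0, 3, 5].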
all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros( (total_batch_size, max_length) ) top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( total_batch_size, ) - total_indices_size = sum( - b.adapter_meta.adapter_indices.shape[0] for b in batches - ) - adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty( - total_indices_size - ) - adapter_set = set() - adapter_segment_builder = SegmentConcatBuilder() - start_slots = [] block_tables = [] - prefix_lens = [] + cache_lengths = [] all_input_ids = [] - prefix_ids = [] + prompt_lengths = [] input_lengths = [] prefix_offsets = [] read_offsets = [] + prefill_logprob_tokens = [] + next_token_chooser_parameters = [] fsm_grammar_states = [] stopping_criterias = [] top_n_tokens = [] + prefilling_mask = [] # Cumulative length cumulative_batch_size = 0 @@ -783,32 +815,9 @@ class FlashCausalLMBatch(Batch): start_index = cumulative_batch_size end_index = cumulative_batch_size + len(batch) - slots_start_index = cumulative_slots - slots_end_index = cumulative_slots + len(batch.slots) # Copy tensors (GPU) - input_ids[start_index:end_index] = batch.input_ids - position_ids[start_index:end_index] = batch.position_ids - slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots - input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor - slots[slots_start_index:slots_end_index] = batch.slots - - # Copy over adapter indices - adapter_start_index = cumulative_adapter_indices_size - adapter_end_index = ( - cumulative_adapter_indices_size - + batch.adapter_meta.adapter_indices.shape[0] - ) - adapter_indices[adapter_start_index:adapter_end_index] = ( - batch.adapter_meta.adapter_indices - ) - cumulative_adapter_indices_size = adapter_end_index - adapter_set.update(batch.adapter_meta.adapter_set) - adapter_segment_builder.concat( - batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices - ) - all_input_ids_tensor[ start_index:end_index, : batch.all_input_ids_tensor.shape[1] ] = batch.all_input_ids_tensor[:, :max_length] @@ -816,20 +825,56 @@ class FlashCausalLMBatch(Batch): block_tables_tensor[ start_index:end_index, : batch.block_tables_tensor.shape[1] ] = batch.block_tables_tensor[:, :max_blocks] + prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor - prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor + slots_start_index = cumulative_slots + slots_end_index = cumulative_slots + len(batch.slots) + slots[slots_start_index:slots_end_index] = batch.slots + cu_slots[start_index + 1 : end_index + 1] = ( + batch.cu_slots[1:] + cumulative_slots + ) - start_slots.append(batch.start_slots + cumulative_slots) + if not prefilling: + input_ids[start_index:end_index] = batch.input_ids + position_ids[start_index:end_index] = batch.position_ids + slot_indices[start_index:end_index] = ( + batch.slot_indices + cumulative_slots + ) + input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor + cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor + # Copy over adapter indices + adapter_start_index = cumulative_adapter_indices_size + adapter_end_index = ( + cumulative_adapter_indices_size + + batch.adapter_meta.adapter_indices.shape[0] + ) + adapter_indices[adapter_start_index:adapter_end_index] = ( + batch.adapter_meta.adapter_indices + ) + cumulative_adapter_indices_size = adapter_end_index + adapter_set.update(batch.adapter_meta.adapter_set) + 
adapter_segment_builder.concat( + batch.adapter_meta.adapter_segments, + batch.adapter_meta.segment_indices, + ) + else: + if isinstance(batch.input_ids, torch.Tensor): + batch.input_ids = batch.input_ids.view(-1, 1).tolist() + input_ids.extend(batch.input_ids) + + prefilling_mask.extend(batch.prefilling_mask) block_tables.extend(batch.block_tables) - prefix_lens.extend(batch.prefix_lens) + cache_lengths.extend(batch.cache_lengths) all_input_ids.extend(batch.all_input_ids) - prefix_ids.extend(batch.prefix_ids) + prompt_lengths.extend(batch.prompt_lengths) input_lengths.extend(batch.input_lengths) prefix_offsets.extend(batch.prefix_offsets) read_offsets.extend(batch.read_offsets) + prefill_logprob_tokens.extend(batch.prefill_logprob_tokens) + next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states) stopping_criterias.extend(batch.stopping_criterias) @@ -837,12 +882,8 @@ class FlashCausalLMBatch(Batch): top_n_tokens.extend(batch.top_n_tokens) # Update - cumulative_batch_size += len(batch) cumulative_slots += len(batch.slots) - - start_slots = torch.concat(start_slots) - - # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum() + cumulative_batch_size += len(batch) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, @@ -852,13 +893,21 @@ class FlashCausalLMBatch(Batch): fsm_grammar_states=fsm_grammar_states, ) - speculative_ids = ( - torch.cat([b.speculative_ids for b in batches], dim=0) - if batches[0].speculative_ids is not None - else None - ) + # We skip computing the speculative_ids when the batch size is too large, so + # we must check that all batches have them, otherwise they must be discarded + if get_speculate() > 0 and all(b.speculative_ids is not None for b in batches): + speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0) + else: + speculative_ids = None - adapter_segments, adapter_segment_indices = adapter_segment_builder.build() + if adapter_segment_builder is not None: + adapter_segments, adapter_segment_indices = adapter_segment_builder.build() + adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ) return cls( batch_id=batches[0].batch_id, @@ -868,24 +917,29 @@ class FlashCausalLMBatch(Batch): position_ids=position_ids, cu_seqlen_prefill=None, prefill_cache_indices=None, - start_slots=start_slots, slot_indices=slot_indices, block_tables=block_tables, block_tables_tensor=block_tables_tensor, - prefix_lens=prefix_lens, - prefix_lens_tensor=prefix_lens_tensor, + cache_lengths=cache_lengths, + cache_lengths_tensor=cache_lengths_tensor, slots=slots, - max_seqlen=max_seqlen, + cu_slots=cu_slots, + max_input_length=max_input_length, + max_current_length=max_current_length, + prefilling=prefilling, + prefilling_mask=prefilling_mask, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, + prefill_logprob_tokens=prefill_logprob_tokens, + prompt_lengths=prompt_lengths, + prompt_lengths_tensor=prompt_lengths_tensor, input_lengths=input_lengths, input_lengths_tensor=input_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, - prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, @@ -893,12 
+947,286 @@ class FlashCausalLMBatch(Batch):
             num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=speculative_ids,
-            adapter_meta=AdapterBatchMetadata(
-                adapter_indices=adapter_indices,
-                adapter_set=adapter_set,
-                adapter_segments=adapter_segments,
-                segment_indices=adapter_segment_indices,
-            ),
+            adapter_meta=adapter_meta,
+            hpu_attn_meta=None,
+        )
+
+    def prepare_for_decode(self, dtype, use_contiguous_pa):
+        block_num = self.cache_lengths_tensor // BLOCK_SIZE + 1
+        block_tables = []
+        for i, bt in enumerate(self.block_tables):
+            block_tables.append(bt[0 : block_num[i]])
+
+        self.hpu_attn_meta = prepare_for_decode(
+            dtype,
+            use_contiguous_pa,
+            self.block_tables_tensor.device,
+            self.slots[self.slot_indices],
+            block_tables,
+            self.input_ids.size(0),
+        )
+
+    def prepare_for_prefill(self):
+        # Prepare values if we need to continue prefilling
+        # Speculation must be ignored while we prefill, even with chunking:
+        # it simplifies everything
+        assert self.speculative_ids is None
+
+        device = self.block_tables_tensor.device
+
+        # HPU does not support varlen for prefill, so SDPA is used instead, which
+        # requires padding the input ids and position ids. We pad on the left so
+        # sliding windows keep working. `prefill_cache_indices` marks the valid KV
+        # slots, and `prefill_next_token_indices` is updated to point at the
+        # correct logit positions.
+        input_ids_padded_length = []
+        # extra pad to match the warmup sequence length
+        extra_pad = 0
+        if isinstance(self.input_ids, list) and len(self) > 1:
+            input_ids_padded_length = []
+            input_ids = []
+            for input_id in self.input_ids:
+                padded = self.max_input_length - len(input_id) + extra_pad
+                if padded > 0:
+                    input_id = [0] * padded + input_id
+                input_ids.append(input_id)
+                input_ids_padded_length.append(padded)
+            input_ids = np.concatenate(input_ids, dtype=np.int64)
+            self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
+        elif isinstance(self.input_ids, list):
+            input_ids = self.input_ids[0]
+            input_ids_padded_length.append(extra_pad)
+            input_ids = [0] * extra_pad + input_ids
+            self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
+        else:
+            self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0)
+            input_ids_padded_length.append(extra_pad)
+
+        self.input_lengths_tensor = torch.tensor(
+            self.input_lengths, dtype=torch.int32, device=device
+        )
+        cu_seqlen_prefill = self.input_lengths_tensor.new_zeros(len(self) + 1)
+        torch.cumsum(self.input_lengths_tensor, out=cu_seqlen_prefill[1:], dim=0)
+        self.cu_seqlen_prefill = cu_seqlen_prefill.to(torch.int32)
+        self.cache_lengths_tensor = torch.tensor(
+            self.cache_lengths, dtype=torch.int32, device=device
+        )
+
+        sliding_window = get_sliding_windows()
+        position_ids = []
+        slot_indices = []
+        prefill_cache_indices = []
+        all_prefill_logprobs = True
+        no_prefill_logprobs = True
+        prefill_cu_outlens = [0]
+
+        # Cumulative length
+        cumulative_length = 0
+        cumulative_slot_tokens = 0
+        prefill_out_cumulative_length = 0
+
+        adapter_indices_list = []
+        adapter_set = set()
+
+        for i, (
+            r,
+            cache_length,
+            input_length,
+            prompt_length,
+            request_prefilling,
+            blocks,
+        ) in enumerate(
+            zip(
+                self.requests,
+                self.cache_lengths,
+                self.input_lengths,
+                self.prompt_lengths,
+                self.prefilling_mask,
+                self.block_tables,
+            )
+        ):
+            next_chunk_length = input_length
+
+            # Position ids
+            request_position_ids = torch.arange(
+                cache_length, cache_length + input_length, dtype=torch.int32
+            )
+            request_position_ids = F.pad(
+                request_position_ids, (input_ids_padded_length[i], 0), value=1
+            )
+            position_ids.append(request_position_ids)
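+            # Left-padded positions receive a dummy position id of 1; their
+            # KV-cache writes are redirected to slot 0 in `forward` through the
+            # boolean `prefill_cache_indices` mask built at the end of this method.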
+
+            if not r.slots:
+                request_slots = [
+                    s
+                    for b in blocks
+                    for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
+                ]
+            else:
+                request_slots = r.slots
+
+            request_slot_indices = torch.arange(
+                cache_length + cumulative_slot_tokens,
+                cache_length + cumulative_slot_tokens + input_length,
+                dtype=torch.int64,
+            )
+
+            slot_indices.append(request_slot_indices)
+
+            # Update
+            cumulative_slot_tokens += len(request_slots)
+
+            # Create tensor to slice into the kv tensor in prefill
+            # HPU needs `request_prefill_cache_indices` to skip padding in the KV cache
+            sliding_window = get_sliding_windows()
+            if sliding_window is None:
+                sliding_window = input_length
+            cumulative_length += input_ids_padded_length[i]
+            if sliding_window is not None:
+                request_prefill_cache_indices = torch.arange(
+                    cumulative_length + max(0, input_length - sliding_window),
+                    cumulative_length + input_length,
+                    dtype=torch.int64,
+                )
+
+            # Prefill logprobs are ignored if the request is done prefilling
+            prefill_logprobs = r.prefill_logprobs and request_prefilling
+
+            all_prefill_logprobs = all_prefill_logprobs and prefill_logprobs
+            no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs
+
+            if prefill_logprobs:
+                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
+                prefill_out_cumulative_length += input_length
+            else:
+                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
+                prefill_out_cumulative_length += 1
+
+            prefill_cache_indices.append(request_prefill_cache_indices)
+
+            ADAPTER_TO_INDEX = get_adapter_to_index()
+            if ADAPTER_TO_INDEX:
+                adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
+                adapter_indices_list.append(
+                    torch.full((next_chunk_length,), adapter_index)
+                )
+                adapter_set.add(adapter_index)
+
+            # Update
+            cumulative_length += next_chunk_length
+
+        if not all_prefill_logprobs and not no_prefill_logprobs:
+            prefill_head_indices = []
+            prefill_next_token_indices = []
+
+            # Cumulative length
+            cumulative_length = 0
+            prefill_out_cumulative_length = 0
+
+            for i, (
+                r,
+                input_length,
+                request_prefilling,
+            ) in enumerate(
+                zip(
+                    self.requests,
+                    self.input_lengths,
+                    self.prefilling_mask,
+                )
+            ):
+                # Prefill logprobs are ignored if the request is done prefilling
+                prefill_logprobs = r.prefill_logprobs and request_prefilling
+
+                if prefill_logprobs:
+                    prefill_head_indices.append(
+                        torch.arange(
+                            cumulative_length,
+                            cumulative_length + input_length,
+                            dtype=torch.int64,
+                        )
+                    )
+                    prefill_next_token_indices.append(
+                        prefill_out_cumulative_length + input_length - 1
+                    )
+                    prefill_out_cumulative_length += input_length
+                else:
+                    prefill_head_indices.append(
+                        torch.tensor(
+                            [cumulative_length + input_length - 1],
+                            dtype=torch.int64,
+                        )
+                    )
+                    prefill_next_token_indices.append(prefill_out_cumulative_length)
+                    prefill_out_cumulative_length += 1
+
+                # Update
+                cumulative_length += input_length
+
+        if len(self) > 1:
+            if position_ids:
+                position_ids = torch.cat(position_ids)
+            if slot_indices:
+                slot_indices = torch.cat(slot_indices)
+            prefill_cache_indices = torch.cat(prefill_cache_indices)
+        else:
+            if position_ids:
+                position_ids = position_ids[0]
+            if slot_indices:
+                slot_indices = slot_indices[0]
+            prefill_cache_indices = prefill_cache_indices[0]
+
+        self.position_ids = position_ids.to(device)
+        self.slot_indices = slot_indices.to(device)
+
+        self.prefill_cu_outlens = prefill_cu_outlens
+        self.prefill_cache_indices = torch.zeros_like(self.input_ids, dtype=torch.bool)
+        self.prefill_cache_indices[prefill_cache_indices.to(device)] 
= True + + if all_prefill_logprobs: + prefill_head_indices = None + prefill_next_token_indices = self.cu_seqlen_prefill[1:] - 1 + elif no_prefill_logprobs: + prefill_head_indices = self.cu_seqlen_prefill[1:] - 1 + prefill_next_token_indices = None + else: + prefill_head_indices = torch.cat(prefill_head_indices).to(device) + prefill_next_token_indices = torch.tensor( + prefill_next_token_indices, dtype=torch.int64, device=device + ) + + self.prefill_head_indices = prefill_head_indices + self.prefill_next_token_indices = prefill_next_token_indices + input_ids_padded_length_tensor = torch.cumsum( + torch.tensor(input_ids_padded_length, dtype=torch.int64, device=device), + dim=-1, + ) + if self.prefill_head_indices is not None: + self.prefill_head_indices = ( + self.prefill_head_indices + input_ids_padded_length_tensor + ) + + if self.prefill_next_token_indices is not None: + self.prefill_next_token_indices = ( + self.prefill_next_token_indices + input_ids_padded_length_tensor + ) + + if adapter_set: + adapter_indices = torch.cat(adapter_indices_list).to( + dtype=torch.int64, device=device + ) + adapter_segments, adapter_segment_indices = find_segments(adapter_indices) + else: + adapter_indices = torch.zeros_like(self.input_ids) + adapter_segments = [0, len(adapter_indices)] + adapter_segment_indices = [len(adapter_indices) - 1] + + adapter_segments = torch.tensor( + adapter_segments, dtype=torch.int32, device=device + ) + self.adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, ) def __len__(self): @@ -937,23 +1265,14 @@ class FlashCausalLM(Model): # Deepseek V2 uses different QK and V dims. head_size: Optional[int] = None, skip_special_tokens: bool = True, + kv_cache_dtype: Optional[torch.dtype] = None, + support_chunking: bool = True, ): self.quantize = quantize self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = default_dtype if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = default_dtype if dtype is None else dtype - else: - device = torch.device("cpu") - # Float16 doesn't exist on target. 
- dtype = torch.bfloat16 if dtype is None else dtype - init_cpu_threads_env(rank_id=rank, world_size=world_size) - else: - raise NotImplementedError(f"{model_class} is only available on GPU") + + device = torch.device("hpu") + dtype = torch.bfloat16 if dtype is None else dtype tokenizer = tokenizer_class.from_pretrained( model_id, @@ -991,7 +1310,7 @@ class FlashCausalLM(Model): weights_loader=weights_loader, ) - prefix = "" + prefix = None model = model_class(prefix, config, weights) torch.distributed.barrier(group=self.process_group) @@ -1007,6 +1326,7 @@ class FlashCausalLM(Model): self.num_layers = config.num_hidden_layers self.num_heads = config.num_attention_heads // self.process_group.size() + self.config = config # Validation is done in the model itself if num_kv_heads is None: num_kv_heads = getattr(config, "num_key_value_heads", None) @@ -1034,25 +1354,14 @@ class FlashCausalLM(Model): self.cuda_graphs = {} self.kv_cache = [] + self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype - if ATTENTION == "flashinfer": - from text_generation_server.layers.attention.flashinfer import ( - create_prefill_state, - create_decode_state, - create_prefill_with_paged_kv_state, - ) - - self.prefill_state = create_prefill_state(device=device) - self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state( - device=device - ) - - self.decode_state = create_decode_state( - device=device, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - ) - + if htorch.utils.internal.is_lazy(): + htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=False) + environment.set_model_config(self.config) + self.use_contiguous_pa = ( + os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true" + ) super().__init__( model_id=model_id, model=model, @@ -1063,6 +1372,7 @@ class FlashCausalLM(Model): rank=rank, world_size=world_size, sliding_window=config.sliding_window, + support_chunking=support_chunking, ) @property @@ -1083,317 +1393,126 @@ class FlashCausalLM(Model): ): self.kv_cache = [] empty_cache() - - element_size = torch.tensor([], dtype=dtype).element_size() - if SYSTEM == "ipex" and device.type == "xpu": - x = 1 - else: - x = BLOCK_SIZE // element_size - - if ATTENTION in {"flashdecoding", "flashinfer"}: - self.kv_cache = [ - ( - torch.empty( - (num_blocks, BLOCK_SIZE, num_heads, head_size), - dtype=dtype, - device=device, - ), - torch.empty( - (num_blocks, BLOCK_SIZE, num_heads, head_size), - dtype=dtype, - device=device, - ), - ) - for _ in range(num_layers) - ] - elif SYSTEM == "ipex" and device == torch.device("cpu"): - self.kv_cache = [ - ( - torch.empty( - (num_blocks, num_heads, BLOCK_SIZE, head_size), - dtype=dtype, - device=device, - ), - torch.empty( - (num_blocks, num_heads, BLOCK_SIZE, head_size), - dtype=dtype, - device=device, - ), - ) - for _ in range(num_layers) - ] - else: - self.kv_cache = [ - ( - torch.zeros( - (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x), - dtype=dtype, - device=device, - ), - torch.zeros( - (num_blocks, num_heads, head_size, BLOCK_SIZE), - dtype=dtype, - device=device, - ), - ) - for _ in range(num_layers) - ] - - def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): - input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) - position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) - slots = torch.arange(bs, dtype=torch.int64, device=self.device) - input_lengths = [max_s] * bs - prefix_lengths = [0] * bs - input_lengths_tensor = ( - torch.ones(bs, dtype=torch.int32, device=self.device) 
* max_s - ) - prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device) - block_tables = torch.arange( - max_bt, dtype=torch.int32, device=self.device - ).repeat(bs) - block_tables = block_tables.reshape((bs, max_bt)) - - if ATTENTION == "flashinfer": - block_tables = block_tables_to_ragged( - block_tables=block_tables, - input_lengths=input_lengths, - prefix_lens=prefix_lengths, - ) - from text_generation_server.layers.attention.flashinfer import ( - create_decode_state_cuda_graphs, + self.kv_cache = [ + KVCache( + num_blocks=num_blocks, + num_heads=num_heads, + head_size=head_size, + dtype=dtype, + device=device, ) + for _ in range(num_layers) + ] - block_tables_ptr = torch.zeros( - bs + 1, dtype=torch.int32, device=self.device - ) - last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device) - state = create_decode_state_cuda_graphs( - device=input_ids.device, - block_tables=block_tables, - block_tables_ptr=block_tables_ptr, - last_page_len=last_page_len, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - ) - else: - state = None - - graph = torch.cuda.CUDAGraph() - self.cuda_graphs[bs] = { - "input_ids": input_ids, - "position_ids": position_ids, - "kv_cache": self.kv_cache, - "block_tables": block_tables, - "slots": slots, - "input_lengths": input_lengths_tensor, - "prefix_lengths": prefix_lengths_tensor, - "state": state, - "graph": graph, - } - - torch.cuda.synchronize() - # Run once outside to warmup - with self._forward_context( - block_tables=block_tables, - cu_seqlen_prefill=None, - input_lengths_tensor=input_lengths_tensor, - state=state, - prefix_lens_tensor=prefix_lengths_tensor, - ): - seqlen = Seqlen( - input_lengths=input_lengths_tensor, - prefix_lengths=prefix_lengths_tensor, - cu_seqlen_q=None, - max_q=1, - max_k=max_s, - ) - self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=None, - kv_cache=self.kv_cache, - block_tables=block_tables, - slots=slots, - seqlen=seqlen, - max_s=max_s, - prefill_cache_indices=None, - lm_head_indices=None, - ) - del seqlen - - torch.cuda.synchronize() - - with torch.cuda.graph(graph, pool=MEM_POOL): - seqlen = Seqlen( - input_lengths=input_lengths_tensor, - prefix_lengths=prefix_lengths_tensor, - cu_seqlen_q=None, - max_q=1, - max_k=max_s, - ) - logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=None, - kv_cache=self.kv_cache, - block_tables=block_tables, - slots=slots, - seqlen=seqlen, - max_s=max_s, - prefill_cache_indices=None, - lm_head_indices=None, - ) - self.cuda_graphs[bs]["logits"] = logits - self.cuda_graphs[bs]["speculative_logits"] = speculative_logits - torch.cuda.synchronize() - - def warmup(self, batch: FlashCausalLMBatch): + def warmup( + self, + batch: FlashCausalLMBatch, + max_input_tokens: Optional[int], + max_total_tokens: Optional[int], + ): # The warmup batch is the biggest batch we could ever receive + self.kv_cache = [] empty_cache() + # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) + # Calculate the number of blocks that can be allocated with the free memory + dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size() + cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size + total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size + try: self.init_kv_cache( batch.num_blocks, self.num_layers, self.num_kv_heads, self.head_size, - self.dtype, + self.kv_cache_dtype, self.device, ) - max_bt 
= batch.max_blocks
-            max_s = max_bt * BLOCK_SIZE
-            if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
-                torch.cuda.tunable.tuning_enable(False)
-            _, batch, _ = self.generate_token(batch)
-        except torch.cuda.OutOfMemoryError as e:
+            batch_num_blocks = batch.num_blocks
+
+            num_tokens = batch.to_pb().current_tokens
+            synchronize(self.device)
+            free_memory = get_free_memory(
+                self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM
+            )
+            real_free_memory = get_free_memory(self.device, MEMORY_FRACTION)
+            log_master(
+                logger.debug,
+                f"Free memory {free_memory / 1e9:.2f}GB (real: {real_free_memory / 1e9:.2f}GB)",
+            )
+
+            _, _batch, _ = self.generate_token([batch])
+        except Exception:
             raise RuntimeError(
-                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
+                f"Not enough memory to handle {num_tokens} prefill tokens. "
                 f"You need to decrease `--max-batch-prefill-tokens`"
-            ) from e
+            )

         synchronize(self.device)
-
-        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
-        # Calculate the number of blocks that can be allocated with the free memory
-        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
-        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
-        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
-
-        free_memory = get_free_memory(self.device, MEMORY_FRACTION)
-        batch_num_blocks = batch.num_blocks if batch is not None else 0
-
+        free_memory = get_free_memory(self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM)
+        kv_memory = free_memory
         num_blocks = (
             # Leave 5% for some wiggle room
-            int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size)
+            int(kv_memory // total_cache_size)
             # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
+ batch_num_blocks ) - del batch + log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}") + if max_total_tokens is None: + max_total_tokens = sum(batch.cache_lengths) + + if max_input_tokens is None: + max_input_tokens = max_total_tokens - 1 + + del _batch, batch + self.kv_cache = [] + empty_cache() self.init_kv_cache( num_blocks, self.num_layers, self.num_kv_heads, self.head_size, - self.dtype, + self.kv_cache_dtype, self.device, ) + return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens - if SYSTEM == "rocm": - if ( - os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None - or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1" - ): - torch.cuda.tunable.enable() + def warmup_prefill(self, prompt_len: int, bs: int): + input_ids = torch.zeros( + prompt_len, dtype=torch.int64, device=self.device + ).repeat(bs) + position_ids = torch.arange( + prompt_len, dtype=torch.int32, device=self.device + ).repeat(bs) + max_bt = (prompt_len // BLOCK_SIZE + 1) * bs + block_tables = torch.arange( + max_bt, dtype=torch.int32, device=self.device + ).reshape(bs, -1) + slot_acc = [] + for i in range(bs): + slots = [] + for b in block_tables[i]: + slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)) + slot_acc.extend(slots[:prompt_len]) + slots = torch.tensor(slot_acc, dtype=torch.int64, device=self.device) - if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0": - torch.cuda.tunable.tuning_enable(True) - - if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None: - tuning_sequences = [ - int(val) - for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",") - ] - elif CUDA_GRAPHS is not None: - tuning_sequences = CUDA_GRAPHS - else: - tuning_sequences = [1, 2, 3, 4, 5, 6, 7] - - tunableop_filepath = os.path.join( - HUGGINGFACE_HUB_CACHE, - f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv", - ) - - log_master( - logger.info, - f"PyTorch TunableOp is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.", - ) - - torch.cuda.tunable.set_filename( - tunableop_filepath, insert_device_ordinal=False - ) - - if os.path.isfile(tunableop_filepath): - log_master( - logger.info, - f"The file {tunableop_filepath} already exists and will be reused.", - ) - torch.cuda.tunable.read_file(tunableop_filepath) - - os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True) - - for seqlen in tuning_sequences: - log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}") - self.tunableop_warmup(seqlen) - torch.cuda.tunable.write_file(tunableop_filepath) - if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1": - torch.cuda.tunable.tuning_enable(False) - else: - log_master( - logger.info, - "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. 
If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.", - ) - - if CUDA_GRAPHS: - try: - log_master( - logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}" - ) - # Warmup cuda graphs - for bs in CUDA_GRAPHS: - if self.speculate is None or self.speculate + 1 <= bs: - self.cuda_graph_warmup(bs, max_s, max_bt) - except torch.cuda.OutOfMemoryError: - logger.exception("Decode cuda graph warmup failed") - else: - log_master( - logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})." - ) - - return int(num_blocks * BLOCK_SIZE) - - def tunableop_warmup(self, seqlen: int): - input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device) - position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device) - slots = torch.arange(seqlen, dtype=torch.int64, device=self.device) - - # Dummy value, some models (starcoder2) don't accept `None`. - input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) - prefix_lens_tensor = torch.zeros(seqlen, dtype=torch.int32, device=self.device) - cu_seqlen_prefill = torch.tensor( - [0, seqlen], device=self.device, dtype=torch.int32 + input_lengths = ( + torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len ) - max_s = seqlen + cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device) + cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32) + torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:]) + seqlen = Seqlen( input_lengths=input_lengths, - prefix_lengths=prefix_lens_tensor, + cache_lengths=cache_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, - max_q=1, - max_k=seqlen, ) + lm_head_indices = input_lengths - 1 # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. 
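+        # A single dummy forward per (prompt_len, bs) shape; under lazy mode this
+        # records the HPU graph that real prefills of the same shape can replay.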
        self.model.forward(
@@ -1401,12 +1520,64 @@ class FlashCausalLM(Model):
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=self.kv_cache,
-            block_tables=None,
-            seqlen=seqlen,
             slots=slots,
-            max_s=max_s,
+            seqlen=trim_seqlen_metadata(seqlen),
+            lm_head_indices=lm_head_indices,
+            adapter_data=None,
+            hpu_attention_meta=None,
+        )
+
+    def warmup_decode(self, bs: int, block_num: int):
+        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
+        position_ids = torch.arange(bs, dtype=torch.int32, device=self.device)
+        block_tables = torch.arange(
+            start=1, end=block_num + 1, dtype=torch.int32, device=self.device
+        ).reshape(bs, -1)
+        slots = []
+        past_len = (
+            len(block_tables[0]) * BLOCK_SIZE - 1
+        )  # for decode, only one new token is fed; everything before it is past
+        # fetch the last block of each table to warm up this block count
+        for i in range(bs):
+            slots.append(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 1)
+        slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
+        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device)
+        cache_lengths_tensor = (
+            torch.ones(bs, dtype=torch.int32, device=self.device) * past_len
+        )
+        cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32)
+        torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
+
+        seqlen = Seqlen(
+            input_lengths=input_lengths,
+            cache_lengths=cache_lengths_tensor,
+            cu_seqlen_q=cu_seqlen_prefill,
+        )
+        block_num = cache_lengths_tensor // BLOCK_SIZE + 1
+        block_tables_valid = []
+        for i, bt in enumerate(block_tables.tolist()):
+            block_tables_valid.append(bt[0 : block_num[i]])
+
+        hpu_attention_meta = prepare_for_decode(
+            self.dtype,
+            self.use_contiguous_pa,
+            self.device,
+            slots,
+            block_tables_valid,
+            bs,
+        )
+
+        # The decode path is exercised here: `cu_seqlen_prefill` is None in the
+        # forward call below, and attention reads the HPU metadata prepared above.
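+        # Replaying this (bs, block_num) shape later should hit the graph recorded
+        # here instead of triggering a recompilation.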
+        self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=None,
+            kv_cache=self.kv_cache,
+            slots=slots,
+            seqlen=trim_seqlen_metadata(seqlen),
             lm_head_indices=None,
-            prefill_cache_indices=None,
+            adapter_data=None,
+            hpu_attention_meta=hpu_attention_meta,
         )

     def forward(
@@ -1421,7 +1592,7 @@ class FlashCausalLM(Model):
         block_tables = batch.block_tables_tensor
         slots = batch.slots[batch.slot_indices]
         input_lengths = batch.input_lengths_tensor
-        max_s = batch.max_seqlen
+        max_s = batch.max_current_length
         lm_head_indices = batch.prefill_head_indices

         speculative_ids = batch.speculative_ids
@@ -1436,12 +1607,20 @@ class FlashCausalLM(Model):
             new_position_ids = (
                 position_ids.unsqueeze(-1).expand(B, new_length) + arange
             ).view(-1)
-            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+
+            # Slots can be discontiguous when prefix caching is enabled, so we need
+            # to expand the slot_indices, then update the slots with the additional
+            # indices to ensure we're grabbing the ones that have been allocated
+            slot_indices = (
+                batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int
+            ).view(-1)
+            slots = batch.slots[slot_indices]
+
             input_lengths = (
                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
-            prefix_lens_tensor = (
-                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+            cache_lengths_tensor = (
+                batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
             ).reshape(-1)

             # Copy the block tables for all members
@@ -1463,8 +1642,8 @@ class FlashCausalLM(Model):
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
-            prefix_lens_tensor = batch.prefix_lens_tensor
-            max_s = batch.max_seqlen
+            cache_lengths_tensor = batch.cache_lengths_tensor
+            max_s = batch.max_current_length
             lm_head_indices = batch.prefill_head_indices

         if cu_seqlen_prefill is None and self.max_past() is not None:
@@ -1473,105 +1652,48 @@ class FlashCausalLM(Model):
             # This makes sure the max_s for the decode pass is correct.
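+            # Clamping is safe for sliding-window models: positions older than
+            # the window are never attended to, so they cannot affect the output.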
max_s = min(self.max_past(), max_s) - bs = input_ids.shape[0] - sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) - if sorted_padded_bs: - # Get associated cuda graph - cuda_graph = self.cuda_graphs[sorted_padded_bs[0]] - else: - cuda_graph = None - - if cu_seqlen_prefill is not None or cuda_graph is None: - if ATTENTION == "flashinfer": - block_tables = block_tables_to_ragged( - block_tables=block_tables, - input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, - ) - with self._forward_context( - block_tables=block_tables, - cu_seqlen_prefill=cu_seqlen_prefill, - input_lengths_tensor=input_lengths, - prefix_lens_tensor=prefix_lens_tensor, - ): - max_k = (input_lengths + prefix_lens_tensor).max().item() - seqlen = Seqlen( - input_lengths=input_lengths, - prefix_lengths=prefix_lens_tensor, - cu_seqlen_q=cu_seqlen_prefill, - max_q=max_s, - max_k=max_k, - ) - logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, - block_tables=block_tables, - slots=slots, - seqlen=seqlen, - max_s=max_s, - prefill_cache_indices=batch.prefill_cache_indices, - lm_head_indices=lm_head_indices, - adapter_data=adapter_data, - ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None - return logits, speculative_logits - - # Copy inputs to the static inputs of the cuda graph - # Static inputs are potentially padded - cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids - cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids - if ATTENTION == "flashinfer": - block_tables = block_tables_to_ragged( - block_tables=block_tables, - input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, - ) - # assert block_tables.shape[0] >= slots.shape[0] - cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables - else: - cuda_graph["block_tables"][ - : block_tables.shape[0], : block_tables.shape[1] - ] = block_tables - - # XXX: This is working only because block 0 is reserved for the healthcheck - # so it doesn't matter if we override it with bogus values. 
- cuda_graph["slots"].fill_(0) - cuda_graph["slots"][: slots.shape[0]] = slots - cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths - cuda_graph["prefix_lengths"].zero_() - cuda_graph["prefix_lengths"][: prefix_lens_tensor.shape[0]] = prefix_lens_tensor - - with self._forward_context( - block_tables=cuda_graph["block_tables"], - cu_seqlen_prefill=None, - input_lengths_tensor=cuda_graph["input_lengths"], - prefix_lens_tensor=cuda_graph["prefix_lengths"], - state=cuda_graph["state"], - ): - # Replay the graph - cuda_graph["graph"].replay() - - # Slice output to the correct shape - speculative_logits = ( - cuda_graph["speculative_logits"][:bs] - if cuda_graph["speculative_logits"] is not None - else None + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) + kwargs = {} + if htorch.utils.internal.is_lazy(): + kwargs["bypass_hpu_graphs"] = False + if batch.prefill_cache_indices is not None: + slots_pad = torch.zeros_like(input_ids) + slots_pad[batch.prefill_cache_indices] = slots + slots = slots_pad + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + slots=slots, + seqlen=trim_seqlen_metadata(seqlen), + lm_head_indices=lm_head_indices, + # TODO not support adapter now, need the add in the future + adapter_data=None, + hpu_attention_meta=batch.hpu_attn_meta, + **kwargs, ) - logits = cuda_graph["logits"][:bs] return logits, speculative_logits @tracer.start_as_current_span("generate_token") def generate_token( - self, batch: FlashCausalLMBatch + self, batches: List[FlashCausalLMBatch] ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]: + if len(batches) > 1: + batch = self.batch_type.concatenate(batches) + else: + batch = batches[0] start = time.time_ns() - prefill = batch.cu_seqlen_prefill is not None + prefill = batch.prefilling + if prefill: + batch.prepare_for_prefill() + else: + batch.prepare_for_decode(self.dtype, self.use_contiguous_pa) prefill_logprobs = batch.prefill_next_token_indices is not None - # Update adapter indices for speculative tokens (if present) adapter_meta = batch.adapter_meta if batch.speculative_ids is not None: @@ -1611,13 +1733,23 @@ class FlashCausalLM(Model): if prefill_logprobs else speculative_logits ) - next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty( - len(batch) - ) - + if len(batch) > 1 and prefill_logprobs: + # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs + # When batch == 1, we will just use the batch.input_ids values directly + prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) else: + prefill_logprobs = None next_token_logits = out - next_adapter_indices = batch.adapter_meta.adapter_indices + + finished_prefilling = True + next_chunk_lengths = [] + current_prefilling_mask = batch.prefilling_mask + if prefill: + finished_prefilling = True + next_prefilling_mask = [False] * len(batch) + + batch.prefilling = not finished_prefilling + batch.prefilling_mask = next_prefilling_mask speculate = get_speculate() ( @@ -1627,7 +1759,7 @@ class FlashCausalLM(Model): accepted_ids, speculative_ids, ) = batch.next_token_chooser( - batch.all_input_ids_tensor[:, : batch.max_seqlen], + batch.all_input_ids_tensor[:, : batch.max_current_length], next_token_logits, speculate, batch.speculative_ids, @@ -1638,85 +1770,110 @@ class 
FlashCausalLM(Model):
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
         )

-        if prefill:
-            if len(batch) > 1 and prefill_logprobs:
-                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
-                # When batch == 1, we will just use the batch.input_ids values directly
-                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
+        # Since we are done prefilling, all the tensors that were concatenating values for all the requests
+        # instantly become of shape [BATCH_SIZE]
+        if prefill and finished_prefilling:
+            indices = batch.cu_seqlen_prefill[1:] - 1
+            # padded on the left
+            if batch.prefill_cache_indices is not None:
+                batch.position_ids = batch.position_ids[batch.prefill_cache_indices][
+                    indices
+                ]
+            else:
+                batch.position_ids = batch.position_ids[indices]

-            next_position_ids = batch.position_ids.new_empty(len(batch))
-            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
-            # We do not need cu_seqlen_prefill anymore
-            batch.cu_seqlen_prefill = None
-        else:
-            prefill_logprobs = None
-            next_position_ids = batch.position_ids
-
-        # Cumulative length
-        cumulative_length = 0
-
-        # Results
-        generations: List[Generation] = []
-        stopped = True
+            batch.slot_indices = batch.slot_indices[indices]
+            batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
+                indices
+            ]

         # Zipped iterator
         iterator = zip(
+            batch.requests,
+            batch.prompt_lengths,
+            batch.cache_lengths,
             batch.input_lengths,
             batch.all_input_ids,
             accepted_ids,
+            current_prefilling_mask,
+            batch.prefilling_mask,
         )

         # We do two for loops as the first one can run completely asynchronously from the HPU while for the second
-        # one, we need to first do a GPU <-> CPU sync
+        # one, we need to first do a HPU <-> CPU sync
         # It is faster if we delay this sync for the maximum amount of time

         # For each member of the batch
-        index = 0
-        for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
-            # Indexing metadata
-            start_index = cumulative_length
-            end_index = cumulative_length + input_length
-
-            if prefill:
+        # Cumulative length
+        cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
+        torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
+        cumulative_length = 0
+        for i, (
+            request,
+            prompt_length,
+            cache_length,
+            input_length,
+            all_input_ids,
+            n_accepted_ids,
+            request_was_prefilling,
+            request_is_prefilling,
+        ) in enumerate(iterator):
+            # Used to gather prefill logprobs
+            # Copy batch.all_input_ids_tensor to prefill_token_indices
+            if request.prefill_logprobs and request_was_prefilling:
                 # Indexing metadata
                 out_start_index = batch.prefill_cu_outlens[i]
                 out_end_index = batch.prefill_cu_outlens[i + 1]
-                out_length = out_end_index - out_start_index

-                # Initialize position_ids
-                # In decode, we do not need this as we can just increment position ids
-                next_position_ids[i] = batch.position_ids[end_index - 1]
-
-                # Initialize adapter indices
-                # In decode, we only have one token per row in the batch, so grab last index
-                next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
-                    end_index - 1
+                # Logprobs generated by the model are for the next token
+                # So we need to translate the id tensor by 1
+                ids = batch.all_input_ids_tensor[
+                    i, cache_length + 1 : cache_length + input_length + 1
                 ]
+                if len(batch) > 1:
+                    prefill_tokens_indices[out_start_index:out_end_index] = ids
+                else:
+                    # Set prefill_tokens_indices to the correct slice
+                    prefill_tokens_indices = ids
prefill_tokens_indices = ids - # Used to gather prefill logprobs - # Copy batch.input_ids to prefill_token_indices - if prefill_logprobs: - if len(batch) > 1: - prefill_tokens_indices[out_start_index : out_end_index - 1] = ( - batch.input_ids[start_index + 1 : start_index + out_length] - ) - else: - # Set prefill_tokens_indices to the correct slice - prefill_tokens_indices = batch.input_ids[ - start_index + 1 : start_index + out_length - ] - - for j in range(n_accepted_ids): - batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index] - index += 1 - + # Copy the accepted tokens one request at a time + if not request_is_prefilling: + # Only save tokens if we are done prefilling for this request + batch.all_input_ids_tensor[ + i, + batch.cache_lengths_tensor[i] + + batch.input_lengths[i] : batch.cache_lengths_tensor[i] + + batch.input_lengths[i] + + accepted_ids[i], + ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]] cumulative_length += input_length # Update values - batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] - batch.speculative_ids = speculative_ids - batch.position_ids = next_position_ids + accepted_ids - batch.input_lengths_tensor += accepted_ids - batch.slot_indices += accepted_ids - batch.adapter_meta.adapter_indices = next_adapter_indices + # These values can be updated without an HPU -> CPU sync + if not prefill or (prefill and finished_prefilling): + batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1] + batch.speculative_ids = speculative_ids + if batch.position_ids.dim() == 2: + # Qwen2_vl case: + batch.position_ids += accepted_ids.unsqueeze(-1) + else: + batch.position_ids += accepted_ids + batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1 + batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor) + batch.slot_indices += accepted_ids - if prefill: + if prefill and prefill_logprobs: + # Get prefill logprobs with in-place softmax (avoids copying the `out` tensor, whose shape is max_batch_prefill_tokens * vocab_size) + torch.log_softmax(out, -1, out=out) + prefill_logprobs_tensor = out + prefill_logprobs = torch.gather( + prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1) + ) + # HPU <-> CPU sync + prefill_logprobs = prefill_logprobs.view(-1).tolist() + + # Does an HPU <-> CPU sync internally + if prefill and finished_prefilling: # adjust segment lengths to account for all request lengths being 1 during decoding adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices) batch.adapter_meta.adapter_segments = torch.tensor( @@ -1725,192 +1882,282 @@ class FlashCausalLM(Model): device=batch.adapter_meta.adapter_segments.device, ) - if prefill and prefill_logprobs: - # Get prefill logprobs - prefill_logprobs_tensor = torch.log_softmax(out, -1) - prefill_logprobs = torch.gather( - prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1) - ) - # GPU <-> CPU sync - prefill_logprobs = prefill_logprobs.view(-1).tolist() - - # GPU <-> CPU sync + # HPU <-> CPU sync next_token_logprobs = next_token_logprobs.tolist() next_token_ids = next_input_ids.tolist() accepted_ids = accepted_ids.tolist() + + # Update values if we need to continue prefilling + # This represents the `else` case of the `Update values` if above + # but since this requires the `next_token_ids` to be on CPU, it is better to do it here + if prefill and not finished_prefilling: + # Speculation must be ignored while we prefill even with chunking; + # it simplifies everything + assert batch.speculative_ids is None +
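To make the cumulative-offset bookkeeping above concrete: with a variable number of accepted speculative tokens per request, `cu_accepted_ids` turns the per-request counts into offsets into the flat `next_input_ids`, and `cu_accepted_ids[1:] - 1` picks the last accepted token each request continues from. A minimal, self-contained sketch (the tensors and their values are illustrative, not taken from the patch):

import torch

accepted_ids = torch.tensor([2, 1, 3])     # tokens accepted per request
next_input_ids = torch.arange(100, 106)    # flat ids, length = accepted_ids.sum()

cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])  # tensor([0, 2, 3, 6])

# Slice the flat tensor per request, as the per-request copy loop above does
per_request = [
    next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]].tolist()
    for i in range(len(accepted_ids))
]
assert per_request == [[100, 101], [102], [103, 104, 105]]

# Each request continues from its last accepted token (the batch.input_ids update)
assert next_input_ids[cu_accepted_ids[1:] - 1].tolist() == [101, 102, 105]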
+ all_postfix_ids = [] + for i, ( + request_prefilling, + next_token_id, + all_input_ids, + cache_length, + input_length, + next_chunk_length, + ) in enumerate( + zip( + batch.prefilling_mask, + next_token_ids, + batch.all_input_ids, + batch.cache_lengths, + batch.input_lengths, + next_chunk_lengths, + ) + ): + if request_prefilling: + next_cache_length = cache_length + input_length + # Get new prompt IDs to prefill + postfix_ids = all_input_ids[ + next_cache_length : next_cache_length + next_chunk_length + ] + else: + # This request is done prefilling, the new id is the one selected by the sampling method + postfix_ids = [next_token_id] + + all_postfix_ids.append(postfix_ids) + + batch.input_ids = all_postfix_ids + start_decode = time.time_ns() + # Results + generations: List[Generation] = [] + stopped = True + # Zipped iterator iterator = zip( batch.requests, + batch.prompt_lengths, + batch.cache_lengths, batch.input_lengths, batch.prefix_offsets, batch.read_offsets, batch.stopping_criterias, batch.all_input_ids, - batch.prefix_ids, batch.next_token_chooser.do_sample, batch.next_token_chooser.seeds, batch.top_n_tokens, + current_prefilling_mask, + batch.prefilling_mask, accepted_ids, batch_top_token_ids, batch_top_token_logprobs, ) + # Reset max_input_length + batch.max_input_length = 0 # For each member of the batch index = 0 for i, ( request, + prompt_length, + cache_length, input_length, prefix_offset, read_offset, stopping_criteria, all_input_ids, - prefix_ids, do_sample, seed, top_n_tokens, + request_was_prefilling, + request_is_prefilling, n_accepted_ids, top_token_ids, top_token_logprobs, ) in enumerate(iterator): - # Append next token to all tokens - next_token_texts = [] - left = 0 - - if n_accepted_ids > 1: - log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}") - - current_stopped = False - for j in range(index, index + n_accepted_ids): - # Generated token - next_token_id = next_token_ids[j] - all_input_ids.append(next_token_id) - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids, - prefix_offset, - read_offset, - ) - next_token_texts.append(next_token_text) - - stop, reason = stopping_criteria( - next_token_id, - next_token_text, - ) - - if stop: - left = index + n_accepted_ids - j - 1 - current_stopped = True - break - else: - current_stopped = False - stopped = stopped and current_stopped - - _next_token_ids = next_token_ids[index : index + n_accepted_ids - left] - _next_token_logprobs = next_token_logprobs[ - index : index + n_accepted_ids - left - ] - index += n_accepted_ids - - # Shard generations - # All generations will be appended in the rust sharded client - if i % self.world_size == self.rank: - if stop: - # Decode generated tokens - output_text, _, _ = self.decode_token( - all_input_ids, - prefix_offset=len(all_input_ids) - - stopping_criteria.current_tokens - - 1, - read_offset=len(all_input_ids) - - stopping_criteria.current_tokens, - skip_special_tokens=True, - ) - generated_text = GeneratedText( - output_text, - stopping_criteria.current_tokens, - reason, - seed if do_sample else None, - ) - else: - generated_text = None - + # Compute logprobs first as, even though we might skip the token, + # it can still be required to compute the logprobs + # modulo on request.id as it is robust to batch.filter, whereas the index in the batch is not, and we need + # this state to be stable + if request.id % self.world_size == self.rank: # Prefill - if prefill and request.prefill_logprobs: + if request_was_prefilling and
request.prefill_logprobs: out_start_index = batch.prefill_cu_outlens[i] out_end_index = batch.prefill_cu_outlens[i + 1] + if not request_is_prefilling: + # The request is done prefilling, meaning that we started generating new tokens + # The last logprob is a logprob for a generated token that was not part of the prompt + # We need to remove it + out_end_index -= 1 + + request_prefill_logprobs = prefill_logprobs[ + out_start_index:out_end_index + ] + # Logprobs generated by the model are for the next token + # So we need to translate the id tensor by 1 + prefill_token_ids = all_input_ids[ + cache_length + 1 : cache_length + input_length + 1 + ] + + past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i] + + if past_prefill_logprob_tokens is None: + # add nan for cached prompt tokens/first token + request_prefill_logprobs = [float("nan")] * ( + cache_length + 1 + ) + request_prefill_logprobs + prefill_token_ids = ( + all_input_ids[: cache_length + 1] + prefill_token_ids + ) - # Remove generated token to only have prefill and add nan for first prompt token - request_prefill_logprobs = ( - [float("nan")] * (len(prefix_ids) + 1) - ) + prefill_logprobs[out_start_index : out_end_index - 1] - prefill_token_ids = all_input_ids[:-1] prefill_texts = self.tokenizer.batch_decode( - prefix_ids + prefill_token_ids, + prefill_token_ids, clean_up_tokenization_spaces=False, skip_special_tokens=False, ) - prefill_tokens = Tokens( - prefix_ids + prefill_token_ids, + prefill_logprob_tokens = Tokens( + prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special=[], ) - else: - prefill_tokens = None - - if top_n_tokens > 0: - all_top_tokens = [] - for top_token_ids, top_token_logprobs in zip( - top_token_ids, top_token_logprobs - ): - toptoken_texts = self.tokenizer.batch_decode( - top_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, + if past_prefill_logprob_tokens is not None: + prefill_logprob_tokens = ( + past_prefill_logprob_tokens + prefill_logprob_tokens ) - special_toptokens = [ - token_id in self.all_special_ids - for token_id in top_token_ids - ] - top_tokens = Tokens( - top_token_ids, - top_token_logprobs, - toptoken_texts, - special_toptokens, - ) - all_top_tokens.append(top_tokens) - top_tokens = all_top_tokens + + batch.prefill_logprob_tokens[i] = prefill_logprob_tokens else: - top_tokens = None + batch.prefill_logprob_tokens[i] = None - generation = Generation( - request.id, - prefill_tokens, - Tokens( - _next_token_ids, - _next_token_logprobs, - next_token_texts, - [nid in self.all_special_ids for nid in _next_token_ids], - ), - generated_text, - top_tokens, - ) + # If the request is still prefilling, the tokens we decoded should be ignored + if request_is_prefilling: + # Make sure that we do not stop: even though this request did not create a token, it is still + # processing + stopped = False + new_input_length = next_chunk_lengths[i] + new_cache_length = cache_length + input_length + else: + new_input_length = 1 + new_cache_length = cache_length + input_length + n_accepted_ids - 1 + # Append next token to all tokens + next_token_texts = [] + left = 0 - generations.append(generation) + if n_accepted_ids > 1: + log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}") - # accept each new token for this specific request since we may - # have more than one new token per request with speculative decoding - for next_token_id in _next_token_ids: - batch.next_token_chooser = ( - batch.next_token_chooser.advance_grammar_single(i, next_token_id) - ) + current_stopped =
False + for j in range(index, index + n_accepted_ids): + # Generated token + next_token_id = next_token_ids[j] + all_input_ids.append(next_token_id) + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids, + prefix_offset, + read_offset, + ) + next_token_texts.append(next_token_text) + + stop, reason = stopping_criteria( + next_token_id, + next_token_text, + ) + + if stop: + left = index + n_accepted_ids - j - 1 + current_stopped = True + break + else: + current_stopped = False + stopped = stopped and current_stopped + + _next_token_ids = next_token_ids[index : index + n_accepted_ids - left] + _next_token_logprobs = next_token_logprobs[ + index : index + n_accepted_ids - left + ] + + # Shard generations + # All generations will be appended in the rust sharded client + if request.id % self.world_size == self.rank: + if stop: + # Decode generated tokens + output_text, _, _ = self.decode_token( + all_input_ids, + prefix_offset=len(all_input_ids) + - stopping_criteria.current_tokens + - 1, + read_offset=len(all_input_ids) + - stopping_criteria.current_tokens, + skip_special_tokens=True, + ) + generated_text = GeneratedText( + output_text, + stopping_criteria.current_tokens, + reason, + seed if do_sample else None, + ) + else: + generated_text = None + + if top_n_tokens > 0: + all_top_tokens = [] + for top_token_ids, top_token_logprobs in zip( + top_token_ids, top_token_logprobs + ): + toptoken_texts = self.tokenizer.batch_decode( + top_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + special_toptokens = [ + token_id in self.all_special_ids + for token_id in top_token_ids + ] + top_tokens = Tokens( + top_token_ids, + top_token_logprobs, + toptoken_texts, + special_toptokens, + ) + all_top_tokens.append(top_tokens) + top_tokens = all_top_tokens + else: + top_tokens = None + + generation = Generation( + request.id, + batch.prefill_logprob_tokens[i], + Tokens( + _next_token_ids, + _next_token_logprobs, + next_token_texts, + [nid in self.all_special_ids for nid in _next_token_ids], + ), + generated_text, + top_tokens, + ) + + generations.append(generation) + + # accept each new token for this specific request since we may + # have more than one new token per request with speculative decoding + for next_token_id in _next_token_ids: + batch.next_token_chooser = ( + batch.next_token_chooser.advance_grammar_single( + i, next_token_id + ) + ) # Update values - batch.input_lengths[i] = input_length + n_accepted_ids - if batch.input_lengths[i] > batch.max_seqlen: - batch.max_seqlen = batch.input_lengths[i] + index += n_accepted_ids + batch.cache_lengths[i] = new_cache_length + batch.max_input_length = max(batch.max_input_length, new_input_length) + batch.input_lengths[i] = new_input_length + current_length = new_cache_length + new_input_length + batch.max_current_length = max(batch.max_current_length, current_length) + batch.prefix_offsets[i] = prefix_offset batch.read_offsets[i] = read_offset batch.all_input_ids[i] = all_input_ids @@ -1921,83 +2168,14 @@ class FlashCausalLM(Model): decode_ns = time.time_ns() - start_decode return generations, None, (forward_ns, decode_ns) - batch.prefill_cu_outlens = None - batch.prefill_head_indices = None - batch.prefill_next_token_indices = None + if prefill and finished_prefilling: + # We do not need prefill tensors anymore + batch.cu_seqlen_prefill = None + batch.prefill_cache_indices = None + batch.prefill_cu_outlens = None + batch.prefill_head_indices = None + batch.prefill_next_token_indices = None 
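The length bookkeeping above is easiest to follow with concrete numbers. A small sketch with hypothetical request sizes (note that this backend sets `finished_prefilling = True` on every prefill pass, so the chunk-continuation branch mirrors upstream but is effectively idle here):

# Hypothetical request: 768-token prompt prefilled in a 512-token chunk,
# then a 256-token chunk, then regular decoding with speculation.
cache_length, input_length = 0, 512

# Request still prefilling: the processed chunk moves into the cache and
# the next chunk becomes the new input (next_chunk_lengths[i] == 256 here).
new_cache_length = cache_length + input_length  # 512
new_input_length = 256

# Request done prefilling: with n_accepted_ids speculative tokens accepted,
# all but the last accepted token join the cache, and decode reads one token.
cache_length, input_length, n_accepted_ids = 512, 256, 3
new_cache_length = cache_length + input_length + n_accepted_ids - 1  # 770
new_input_length = 1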
forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode return generations, batch, (forward_ns, decode_ns) - - def _forward_context( - self, - *, - block_tables: torch.Tensor, - cu_seqlen_prefill: Optional[torch.Tensor], - input_lengths_tensor: torch.Tensor, - prefix_lens_tensor: torch.Tensor, - state: Optional[Any] = None, - ) -> ContextManager: - if ATTENTION != "flashinfer": - return nullcontext() - - from text_generation_server.layers.attention.flashinfer import ( - use_decode_state, - use_prefill_with_paged_kv_state, - ) - - # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens) - - if cu_seqlen_prefill is not None: - return use_prefill_with_paged_kv_state( - state=( - state if state is not None else self.prefill_with_paged_kv_state - ), - # block_tables=block_tables_to_ragged( - # block_tables=block_tables, - # input_lengths=input_lengths, - # prefix_lens=prefix_lens, - # ), - block_tables=block_tables, - cu_seqlens=cu_seqlen_prefill, - input_lengths=input_lengths_tensor + prefix_lens_tensor, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - head_size=self.head_size, - page_size=BLOCK_SIZE, - dtype=self.dtype, - window_left=self.sliding_window, - ) - else: - assert input_lengths_tensor is not None - return use_decode_state( - state=state if state is not None else self.decode_state, - input_lengths=input_lengths_tensor + prefix_lens_tensor, - block_tables=block_tables, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - head_size=self.head_size, - page_size=BLOCK_SIZE, - dtype=self.dtype, - window_left=self.sliding_window, - ) - - -def block_tables_to_ragged( - *, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int] -) -> torch.Tensor: - """Convert block table to ragged format compatible with FlashInfer.""" - assert len(input_lengths) == len(prefix_lens) - - total_len = sum(input_lengths) + sum(prefix_lens) - block_tables_ragged = torch.empty( - total_len, dtype=torch.int32, device=block_tables.device - ) - - offset = 0 - for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)): - seq_len = prefix_len + input_length - block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len] - offset += seq_len - - return block_tables_ragged diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py new file mode 100644 index 000000000..208ab3582 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -0,0 +1,489 @@ +import torch +from PIL import Image +from io import BytesIO + +from opentelemetry import trace +from typing import Iterable, Optional, Tuple, List, Type, Dict + +from transformers import PreTrainedTokenizerBase +from transformers.image_processing_utils import select_best_resolution +from text_generation_server.pb import generate_pb2 +from text_generation_server.models.flash_causal_lm import ( + FlashCausalLMBatch, + FlashCausalLM, +) +from text_generation_server.models.globals import PREFIX_CACHING +from loguru import logger +from text_generation_server.utils.log import log_master +from transformers import AutoProcessor +from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata +import habana_frameworks.torch as htorch + +tracer = trace.get_tracer(__name__) + +IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>" +IDEFICS2_IMAGE_TOKEN = "<image>" + +IDEFICS3_IMAGE_TOKEN = "<image>" +IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>" +IDEFICS3_GLOBAL_IMG_TOKEN
= "" + + +# copied from: https://github.com/huggingface/transformers/blob/02ed609285c2448b3b54c31e362f2c389fa952ab/src/transformers/models/idefics3/processing_idefics3.py#L44-L60 +def _prompt_split_image( + *, + image_seq_len: int, + image_rows: int, + image_cols: int, + fake_token_around_image: str, + image_token: str, + global_img_token: str, +): + """Prompt with expanded image tokens for when the image is split into patches.""" + text_split_images = "" + for n_h in range(image_rows): + for n_w in range(image_cols): + text_split_images += ( + f"{fake_token_around_image}" + + f"" + + f"{image_token}" * image_seq_len + ) + text_split_images += "\n" + + text_split_images += ( + f"\n{fake_token_around_image}" + + f"{global_img_token}" + + f"{image_token}" * image_seq_len + + f"{fake_token_around_image}" + ) + return text_split_images + + +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + """ + Calculate the shape of the image patch grid after the preprocessing for images of any resolution. + + Args: + image_size (`tuple`): + The size of the input image in the format (height, width). + grid_pinpoints (`List`): + A list containing possible resolutions. Each item in the list should be a tuple or list + of the form `(height, width)`. + patch_size (`int`): + The size of each image patch. + + Returns: + tuple: The shape of the image patch grid in the format (width, height). + """ + if not isinstance(grid_pinpoints, list): + raise ValueError("grid_pinpoints should be a list of tuples or lists") + + height, width = select_best_resolution(image_size, grid_pinpoints) + return height // patch_size, width // patch_size + + +def image_text_replacement(processor, image_input, config, image_id: int) -> str: + if config.model_type == "idefics2": + image_seq_len = 64 + image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}" + if processor.image_processor.do_image_splitting: + image_str *= 5 + return image_str + if config.model_type == "idefics3": + # TODO: implement this in a more general way + n_rows = image_input["rows"][0][image_id] + n_cols = image_input["cols"][0][image_id] + image_seq_len = int( + ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) + / (config.scale_factor**2) + ) + image_str = _prompt_split_image( + image_seq_len=image_seq_len, + image_rows=n_rows, + image_cols=n_cols, + fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN, + image_token=IDEFICS3_IMAGE_TOKEN, + global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN, + ) + return image_str + elif config.model_type == "llava_next": + height, width = image_input["image_sizes"][image_id] + num_features = get_number_of_features(height, width, config) + + log_master( + logger.info, + f"Found {num_features} features in image of resolution {height}x{width}", + ) + return "" * num_features + + elif config.model_type == "paligemma": + return "" * config.text_config.num_image_tokens + elif config.model_type == "qwen2_vl": + grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id] + num_pads = grid_t * grid_h * grid_w // 4 + padding = "<|image_pad|>" * num_pads + return f"<|vision_start|>{padding}<|vision_end|>" + elif config.model_type == "qwen2_5_vl": + grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id] + num_pads = grid_t * grid_h * grid_w // 4 + padding = "<|image_pad|>" * num_pads + return f"<|vision_start|>{padding}<|vision_end|>" + elif config.model_type == "gemma3": + # TODO: get correct number of features via reviewing the Gemma3 architecture 
+ # and calculating the number of image tokens + num_pads = 256 + padding = "<image_soft_token>" * num_pads + return f"\n\n<start_of_image>{padding}<end_of_image>\n\n" + else: + raise RuntimeError(f"Unknown config {config.model_type} for multimodal") + + +def image_text_replacement_fixup(config, text: str) -> str: + if config.model_type == "idefics2": + return text.replace( + f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN + ) + return text + + +def get_unpadded_features( + original_height: int, + original_width: int, + npatches: int, + num_patch_height: int, + num_patch_width: int, +) -> Tuple[int, int]: + current_height = npatches * num_patch_height + current_width = npatches * num_patch_width + + aspect_ratio: float = original_width / original_height + current_aspect_ratio: float = current_width / current_height + + if aspect_ratio > current_aspect_ratio: + new_height = (original_height * current_width) // original_width + padding = (current_height - new_height) // 2 + current_height = current_height - (2 * padding) + else: + new_width = (original_width * current_height) // original_height + padding = (current_width - new_width) // 2 + current_width = current_width - (2 * padding) + + unpadded_features = current_height * current_width + newline_features = current_height + return (unpadded_features, newline_features) + + +def get_number_of_features(height: int, width: int, config) -> int: + # From config + # Hardcoded for CLIP for now + # image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]] + image_grid_pinpoints = config.image_grid_pinpoints + image_size = config.vision_config.image_size + patch_size = config.vision_config.patch_size + + assert image_size % patch_size == 0 + + npatches = image_size // patch_size + + # Dimensions are intentionally swapped to be bug-compatible with + # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59 + num_patch_width, num_patch_height = get_anyres_image_grid_shape( + [height, width], + image_grid_pinpoints, + image_size, + ) + unpadded_features, newline_features = get_unpadded_features( + height, width, npatches, num_patch_height, num_patch_width + ) + # The base patch covers the entire image + base_features = npatches**2 + return unpadded_features + newline_features + base_features + + +class FlashVlmCausalLMBatch(FlashCausalLMBatch): + pixel_values: Optional[List[torch.Tensor]] + pixel_attention_mask: Optional[List[torch.Tensor]] + image_sizes: Optional[List[Tuple[int, int]]] + image_grid_thw: Optional[torch.Tensor] + + @classmethod + @tracer.start_as_current_span("concatenate") + def concatenate(cls, batches): + batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches) + batch.pixel_values = None + batch.pixel_attention_mask = None + batch.image_sizes = None + batch.image_grid_thw = None + return batch + + @tracer.start_as_current_span("filter") + def filter(self, request_ids: List[int]): + batch = super().filter(request_ids) + batch.pixel_values = None + batch.pixel_attention_mask = None + batch.image_sizes = None + batch.image_grid_thw = None + return batch + + @classmethod + def batch_tokenized_inputs( + cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config + ): + # Process images first. We need all of them so that the processor + # can make the image splits the same size. And we need the final + sizes to insert the correct number of image tokens.
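A worked example of the llava_next feature count above, assuming the usual CLIP vision settings (`image_size=336`, `patch_size=14`) and a 672x672 input that `select_best_resolution` maps onto a 2x2 grid of 336px tiles; the grid and sizes are illustrative, not taken from the patch:

# Assumed CLIP-style vision tower: 336px input, 14px patches -> 24 patches/side
npatches = 336 // 14                          # 24
num_patch_height = num_patch_width = 2        # a 672x672 image -> 2x2 grid

current_height = npatches * num_patch_height  # 48
current_width = npatches * num_patch_width    # 48
# Aspect ratios match (1.0 vs 1.0), so get_unpadded_features crops nothing
unpadded_features = current_height * current_width  # 2304
newline_features = current_height                   # 48, one newline feature per row
base_features = npatches**2                         # 576 for the downscaled base image
assert unpadded_features + newline_features + base_features == 2928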
+ images = [] + for r in requests: + for chunk in r.input_chunks.chunks: + chunk_type = chunk.WhichOneof("chunk") + if chunk_type == "text": + pass + elif chunk_type == "image": + image = Image.open(BytesIO(chunk.image.data)) + # qwen2_vl expects images to be greater than 20 pixels, this is for warmup since the + # default warmup image is 20x20 + if config.model_type in {"qwen2_vl", "qwen2_5_vl"}: + if image.width <= 20: + w = image.width * 2 + h = image.height * 2 + image = image.resize((w, h)) + + if config.model_type == "llava_next": + images.append(image) + elif config.model_type == "gemma3": + images.append(image) + else: + images.append([image]) + else: + raise RuntimeError(f"Invalid chunk type {chunk_type}") + + if images: + kwargs = {} + if ( + hasattr(processor, "image_processor_class") + and processor.image_processor_class == "Idefics3ImageProcessor" + ): + kwargs["return_row_col_info"] = True + + image_inputs = processor.image_processor( + images, return_tensors="pt", **kwargs + ) + else: + image_inputs = None + + batch_tokenized_inputs = [] + max_length = 0 + image_id = 0 + for r in requests: + full_text = "" + for chunk in r.input_chunks.chunks: + chunk_type = chunk.WhichOneof("chunk") + if chunk_type == "text": + full_text += chunk.text + elif chunk_type == "image": + full_text += image_text_replacement( + processor, image_inputs, config, image_id + ) + image_id += 1 + + full_text = image_text_replacement_fixup(config, full_text) + input_ids = tokenizer( + full_text, + truncation=True, + max_length=r.truncate, + add_special_tokens=r.add_special_tokens, + )["input_ids"] + max_length = max(max_length, len(input_ids)) + batch_tokenized_inputs.append(input_ids) + + return batch_tokenized_inputs, image_inputs + + @classmethod + def from_pb_processor( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + processor, + config, + dtype: torch.dtype, + device: torch.device, + ) -> "FlashVlmCausalLMBatch": + batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs( + pb.requests, tokenizer, processor, config + ) + batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) + if image_inputs is not None: + batch.pixel_values = image_inputs["pixel_values"].to(device=device) + if "pixel_attention_mask" in image_inputs: + batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to( + device=device + ) + else: + batch.pixel_attention_mask = None + if "image_sizes" in image_inputs: + batch.image_sizes = image_inputs["image_sizes"].to(device=device) + else: + batch.image_sizes = None + if "image_grid_thw" in image_inputs: + batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device) + else: + batch.image_grid_thw = None + else: + batch.pixel_values = None + batch.pixel_attention_mask = None + batch.image_sizes = None + batch.image_grid_thw = None + return batch + + +class FlashVlmCausalLM(FlashCausalLM): + def __init__( + self, + model_id: str, + *, + processor_class=AutoProcessor, + processor_kwargs=None, + batch_class=FlashVlmCausalLMBatch, + revision, + trust_remote_code: bool, + **kwargs, + ): + if PREFIX_CACHING: + raise NotImplementedError("Vlm do not work with prefix caching yet") + if processor_kwargs is None: + processor_kwargs = {} + self.processor = processor_class.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + **processor_kwargs, + ) + self.batch_class = batch_class + super().__init__( + model_id=model_id, + revision=revision, + trust_remote_code=trust_remote_code, + # 
FIXME: VLM do not work with context chunking yet + support_chunking=False, + **kwargs, + ) + + @property + def batch_type(self) -> Type[FlashVlmCausalLMBatch]: + return self.batch_class + + def max_past(self) -> Optional[int]: + return getattr(self.model.text_model, "max_past", None) + + def forward( + self, + batch: FlashVlmCausalLMBatch, + adapter_data: Optional[Dict[str, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + # Model Forward + if batch.speculative_ids is not None: + input_ids = batch.input_ids + position_ids = batch.position_ids + cu_seqlen_prefill = batch.cu_seqlen_prefill + kv_cache = self.kv_cache + block_tables = batch.block_tables_tensor + slots = batch.slots[batch.slot_indices] + input_lengths = batch.input_lengths_tensor + max_s = batch.max_current_length + lm_head_indices = batch.prefill_head_indices + + speculative_ids = batch.speculative_ids + + B, speculative_length = speculative_ids.shape + new_length = speculative_length + 1 + new_input_ids = torch.cat( + [input_ids.unsqueeze(-1), speculative_ids], dim=1 + ).reshape(-1) + arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0) + arange_int = arange.to(dtype=torch.int32) + new_position_ids = ( + position_ids.unsqueeze(-1).expand(B, new_length) + arange + ).view(-1) + slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) + input_lengths = ( + input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int + ).view(-1) + cache_lengths_tensor = ( + batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length) + ).reshape(-1) + + # Add Copy the block tables for all members + block_tables = ( + block_tables.unsqueeze(1) + .expand(B, new_length, -1) + .reshape(B * new_length, -1) + .contiguous() + ) + max_s = max_s + speculative_length + + input_ids = new_input_ids + position_ids = new_position_ids + else: + input_ids = batch.input_ids + position_ids = batch.position_ids + cu_seqlen_prefill = batch.cu_seqlen_prefill + kv_cache = self.kv_cache + block_tables = batch.block_tables_tensor + slots = batch.slots[batch.slot_indices] + input_lengths = batch.input_lengths_tensor + cache_lengths_tensor = batch.cache_lengths_tensor + max_s = batch.max_current_length + lm_head_indices = batch.prefill_head_indices + + if self.model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}: + if position_ids.dim() == 1 and batch.prefilling: + position_ids = self.model.get_position_ids( + input_ids, batch.image_grid_thw + ) + batch.position_ids = position_ids + + if cu_seqlen_prefill is None and self.max_past() is not None: + # In decode, not prefill, we're actually overwriting the KV-cache + # in a circular buffer mode. + # This makes sure the max_s for the decode pass is correct. 
+ max_s = min(self.max_past(), max_s) + + kwargs = {} + if htorch.utils.internal.is_lazy(): + kwargs["bypass_hpu_graphs"] = False + + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) + if batch.prefill_cache_indices is not None: + slots_pad = torch.zeros_like(input_ids) + slots_pad[batch.prefill_cache_indices] = slots + slots = slots_pad + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + slots=slots, + seqlen=trim_seqlen_metadata(seqlen), + hpu_attention_meta=batch.hpu_attn_meta, + lm_head_indices=lm_head_indices, + pixel_values=batch.pixel_values, + pixel_attention_mask=batch.pixel_attention_mask, + image_sizes=batch.image_sizes, + image_grid_thw=batch.image_grid_thw, + **kwargs, + ) + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + if batch.pixel_values is not None: + batch.pixel_values = None + if batch.pixel_attention_mask is not None: + batch.pixel_attention_mask = None + if batch.image_sizes is not None: + batch.image_sizes = None + if batch.image_grid_thw is not None: + batch.image_grid_thw = None + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/globals.py b/backends/gaudi/server/text_generation_server/models/globals.py index 30a5d3da4..cd221e148 100644 --- a/backends/gaudi/server/text_generation_server/models/globals.py +++ b/backends/gaudi/server/text_generation_server/models/globals.py @@ -1,53 +1,31 @@ -import torch import os from typing import Dict, Optional from loguru import logger from text_generation_server.utils.log import log_master +REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"} ATTENTION = os.getenv("ATTENTION", "default") # default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0" PREFIX_CACHING = os.getenv("PREFIX_CACHING", "0").lower() in { "1", "true", } -PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "0").lower() in {"1", "true"} log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") -_expected = {"paged", "flashdecoding", "flashinfer", "default"} +_expected = {"paged", "default"} assert ( ATTENTION in _expected ), f"Attention is not valid {ATTENTION}, expected {_expected}" log_master(logger.info, f"Using Attention = {ATTENTION}") -if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}: - raise RuntimeError("Prefix caching is only supported with flashinfer") - -MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90")) assert TGI_WIGGLE_ROOM > 0 assert TGI_WIGGLE_ROOM < 1 # This is overridden by the cli BLOCK_SIZE: int -if ATTENTION == "flashdecoding": - BLOCK_SIZE = 256 -elif ATTENTION == "flashinfer": - BLOCK_SIZE = 1 -else: - BLOCK_SIZE = 16 -# This is overridden by the cli -cuda_graphs = os.getenv("CUDA_GRAPHS") -if cuda_graphs is not None: - try: - cuda_graphs = [int(item) for item in cuda_graphs.split(",")] - except Exception as e: - raise RuntimeError( - f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}" - ) -else: - cuda_graphs = None +BLOCK_SIZE = 128 -CUDA_GRAPHS = cuda_graphs # This is overridden at model loading. 
global MODEL_ID diff --git a/backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py b/backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py index 9a7a6fe15..98d7352a8 100644 --- a/backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py @@ -34,9 +34,6 @@ from text_generation_server.utils import ( ) from text_generation_server.utils.quantization import get_loader -from text_generation_server.utils.import_utils import SYSTEM - - tracer = trace.get_tracer(__name__) @@ -596,22 +593,8 @@ class IdeficsCausalLM(Model): ): self.quantize = quantize self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - # 9b seems to work correctly enough in float16, but 80b seems - # to be really saturating for f16. - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - # Float16 doesn't exist on target. - dtype = torch.bfloat16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype + device = torch.device("hpu") + dtype = torch.bfloat16 if dtype is None else dtype self.device, self.dtype = device, dtype config = AutoConfig.from_pretrained( diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index 9e19e1715..e034ed492 100644 --- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -1,28 +1,30 @@ -from io import BytesIO -from PIL import Image import torch + +import numpy as np + from typing import Iterable, Optional, Tuple, List, Dict from text_generation_server.pb.generate_pb2 import Request - +from io import BytesIO +from PIL import Image from dataclasses import dataclass from opentelemetry import trace from transformers import ( PreTrainedTokenizerBase, ) -from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM -from text_generation_server.pb import generate_pb2 -from text_generation_server.models.flash_causal_lm import ( - block_tables_to_ragged, -) -from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION -from text_generation_server.layers.attention import Seqlen +from text_generation_server.models.flash_vlm_causal_lm import ( + FlashVlmCausalLMBatch, + FlashVlmCausalLM, +) +from text_generation_server.pb import generate_pb2 +from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata +import habana_frameworks.torch as htorch tracer = trace.get_tracer(__name__) @dataclass -class MllamaCausalLMBatch(VlmCausalLMBatch): +class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch): image_indices: List[int] = 42 aspect_ratio_ids: Optional[torch.Tensor] = None aspect_ratio_mask: Optional[torch.Tensor] = None @@ -158,7 +160,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch): config, dtype: torch.dtype, device: torch.device, - ) -> "VlmCausalLMBatch": + ) -> "FlashVlmCausalLMBatch": batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs( pb.requests, tokenizer, processor, config ) @@ -167,6 +169,13 @@ class MllamaCausalLMBatch(VlmCausalLMBatch): 
batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp( max=config.text_config.vocab_size - 1 ) + if isinstance(batch.input_ids, list): + if len(batch) > 1: + input_ids = np.concatenate(batch.input_ids, dtype=np.int64) + else: + input_ids = batch.input_ids[0] + batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1) if image_inputs is not None: @@ -187,10 +196,10 @@ class MllamaCausalLMBatch(VlmCausalLMBatch): return batch -class MllamaCausalLM(VlmCausalLM): +class FlashMllamaCausalLM(FlashVlmCausalLM): def forward( self, - batch: VlmCausalLMBatch, + batch: FlashMllamaCausalLMBatch, adapter_data: Optional[Dict[str, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # Model Forward @@ -202,7 +211,7 @@ class MllamaCausalLM(VlmCausalLM): block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor - max_s = batch.max_seqlen + max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices speculative_ids = batch.speculative_ids @@ -221,8 +230,8 @@ class MllamaCausalLM(VlmCausalLM): input_lengths = ( input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int ).view(-1) - prefix_lens_tensor = ( - batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) + cache_lengths_tensor = ( + batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length) ).reshape(-1) # Add Copy the block tables for all members @@ -244,8 +253,8 @@ class MllamaCausalLM(VlmCausalLM): block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor - prefix_lens_tensor = batch.prefix_lens_tensor - max_s = batch.max_seqlen + cache_lengths_tensor = batch.cache_lengths_tensor + max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices if cu_seqlen_prefill is None and self.max_past() is not None: @@ -254,104 +263,46 @@ class MllamaCausalLM(VlmCausalLM): # This makes sure the max_s for the decode pass is correct. max_s = min(self.max_past(), max_s) - bs = input_ids.shape[0] - # Try to find an associated cuda graph - bs = input_ids.shape[0] - sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) - if sorted_padded_bs: - # Get associated cuda graph - cuda_graph = self.cuda_graphs[sorted_padded_bs[0]] - else: - cuda_graph = None - if ( - cu_seqlen_prefill is not None - or cuda_graph is None - # Only run cuda graphs when there's no images. 
- or batch.cross_attention_states is not None - ): - input_lengths = input_lengths + prefix_lens_tensor - if PREFIX_CACHING: - block_tables = block_tables_to_ragged( - block_tables=block_tables, - input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, - ) - with self._forward_context( - block_tables=block_tables, - cu_seqlen_prefill=cu_seqlen_prefill, - input_lengths_tensor=input_lengths, - prefix_lens_tensor=prefix_lens_tensor, - ): - max_k = (input_lengths + prefix_lens_tensor).max().item() - seqlen = Seqlen( - input_lengths=input_lengths, - prefix_lengths=prefix_lens_tensor, - cu_seqlen_q=cu_seqlen_prefill, - max_q=max_s, - max_k=max_k, - ) + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + ) - if batch.pixel_values is not None: - cross_attention_states = self.model.vision_forward( - pixel_values=batch.pixel_values, - aspect_ratio_ids=batch.aspect_ratio_ids, - aspect_ratio_mask=batch.aspect_ratio_mask, - ) - batch.cross_attention_states = cross_attention_states - - cross_attention_states = batch.cross_attention_states - - logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, - block_tables=block_tables, - slots=slots, - seqlen=seqlen, - max_s=max_s, - prefill_cache_indices=batch.prefill_cache_indices, - lm_head_indices=lm_head_indices, - cross_attention_states=cross_attention_states, - adapter_data=adapter_data, - image_indices=batch.image_indices[:], - ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None - if batch.pixel_values is not None: - batch.pixel_values = None - return logits, speculative_logits - - # Copy inputs to the static inputs of the cuda graph - # Static inputs are potentially padded - cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids - cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids - if ATTENTION == "flashinfer": - block_tables = block_tables_to_ragged( - block_tables=block_tables, - input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, + if batch.pixel_values is not None: + cross_attention_states = self.model.vision_forward( + pixel_values=batch.pixel_values, + aspect_ratio_ids=batch.aspect_ratio_ids, + aspect_ratio_mask=batch.aspect_ratio_mask, ) - cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables - else: - cuda_graph["block_tables"][ - : block_tables.shape[0], : block_tables.shape[1] - ] = block_tables - cuda_graph["slots"].fill_(0) - cuda_graph["slots"][: slots.shape[0]] = slots - cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: input_lengths.shape[0]] = ( - input_lengths + prefix_lens_tensor - ) + batch.cross_attention_states = cross_attention_states - # Replay the graph - cuda_graph["graph"].replay() + cross_attention_states = batch.cross_attention_states - # Slice output to the correct shape - speculative_logits = ( - cuda_graph["speculative_logits"][:bs] - if cuda_graph["speculative_logits"] is not None - else None + kwargs = {} + if htorch.utils.internal.is_lazy(): + kwargs["bypass_hpu_graphs"] = False + if batch.prefill_cache_indices is not None: + slots_pad = torch.zeros_like(input_ids) + slots_pad[batch.prefill_cache_indices] = slots + slots = slots_pad + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + slots=slots, + 
seqlen=trim_seqlen_metadata(seqlen), + hpu_attention_meta=batch.hpu_attn_meta, + lm_head_indices=lm_head_indices, + cross_attention_states=cross_attention_states, + # TODO list + adapter_data=None, + image_indices=batch.image_indices[:], + **kwargs, ) - logits = cuda_graph["logits"][:bs] + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + if batch.pixel_values is not None: + batch.pixel_values = None return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/model.py b/backends/gaudi/server/text_generation_server/models/model.py index 4fda22713..66c69bc1f 100644 --- a/backends/gaudi/server/text_generation_server/models/model.py +++ b/backends/gaudi/server/text_generation_server/models/model.py @@ -33,6 +33,7 @@ class Model(ABC): sliding_window: Optional[int] = None, speculate: Optional[int] = None, adapter_id: str = BASE_MODEL_ADAPTER_ID, + support_chunking: bool = False, ): self.model_id = model_id self.model = model.eval() diff --git a/backends/gaudi/server/text_generation_server/models/pali_gemma.py b/backends/gaudi/server/text_generation_server/models/pali_gemma.py index fe75570ea..e91aaed99 100644 --- a/backends/gaudi/server/text_generation_server/models/pali_gemma.py +++ b/backends/gaudi/server/text_generation_server/models/pali_gemma.py @@ -4,8 +4,8 @@ import torch import torch.distributed from opentelemetry import trace from typing import Iterable -from text_generation_server.models.vlm_causal_lm import ( - VlmCausalLMBatch, +from text_generation_server.models.flash_vlm_causal_lm import ( + FlashVlmCausalLMBatch, image_text_replacement, ) @@ -14,7 +14,7 @@ from text_generation_server.pb.generate_pb2 import Request tracer = trace.get_tracer(__name__) -class PaliGemmaBatch(VlmCausalLMBatch): +class PaliGemmaBatch(FlashVlmCausalLMBatch): @classmethod def batch_tokenized_inputs( cls, requests: Iterable[Request], tokenizer, processor, config diff --git a/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py b/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py index 04d4c28ba..0ee6ed167 100644 --- a/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py +++ b/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py @@ -10,7 +10,6 @@ from transformers import ( AutoConfig, ) from typing import Optional, Tuple, List, Type, Dict -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils import ( initialize_torch_distributed, weight_files, @@ -555,20 +554,9 @@ class Seq2SeqLM(Model): ): self.quantize = quantize self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = default_dtype if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = default_dtype if dtype is None else dtype - else: - device = torch.device("cpu") - # Float16 doesn't exist on target. 
- dtype = torch.bfloat16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype + + device = torch.device("hpu") + dtype = torch.bfloat16 if dtype is None else dtype config = config_class.from_pretrained( model_id, @@ -600,7 +588,7 @@ class Seq2SeqLM(Model): aliases=aliases, weights_loader=weights_loader, ) - if config.quantize in ["awq", "exl2", "gptq", "marlin"]: + if config.quantize in ["awq", "gptq"]: weights._set_gptq_params(model_id, revision) model = model_class(config, weights) diff --git a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py index 543b07e8e..709437d93 100644 --- a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py @@ -69,11 +69,7 @@ MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 8192)) PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 128)) CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1)) -max_batch_size_str = os.environ.get("MAX_BATCH_SIZE") -if max_batch_size_str is not None: - MAX_BATCH_SIZE = int(max_batch_size_str) -else: - raise ValueError("MAX_BATCH_SIZE is not set") + PREFILL_WARMUP_BATCH_SIZE_LIST = [] PREFILL_WARMUP_SEQLEN_LIST = [] @@ -1467,6 +1463,12 @@ class VlmCausalLM(Model): batch = self.batch_from_pb(request.batch, is_warmup=True) max_input_tokens = request.max_input_tokens max_prefill_batch_size = batch.input_ids.shape[0] + max_batch_size_str = os.environ.get("MAX_BATCH_SIZE") + if max_batch_size_str is not None: + MAX_BATCH_SIZE = int(max_batch_size_str) + else: + raise ValueError("MAX_BATCH_SIZE is not set") + try: # max prefill batch size warmup _, prefill_batch, _ = self.generate_token([batch], is_warmup=True) diff --git a/backends/gaudi/server/text_generation_server/server.py b/backends/gaudi/server/text_generation_server/server.py index 674a8aed1..5a7d21175 100644 --- a/backends/gaudi/server/text_generation_server/server.py +++ b/backends/gaudi/server/text_generation_server/server.py @@ -18,22 +18,27 @@ from text_generation_server.interceptor import ExceptionInterceptor from text_generation_server.models import Model, get_model_with_lora_adapters from text_generation_server.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor -from text_generation_server.models.globals import set_model_id +from text_generation_server.models.globals import set_model_id, ATTENTION from text_generation_server.models.globals import set_adapter_to_index from text_generation_server.utils.adapter import AdapterInfo from text_generation_server.utils.tokens import make_tokenizer_optional +from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens try: from text_generation_server.models.pali_gemma import PaliGemmaBatch + from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch from text_generation_server.models.vlm_causal_lm import ( VlmCausalLMBatch, ) - from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch + from text_generation_server.models.flash_vlm_causal_lm import ( + FlashVlmCausalLMBatch, + ) VLM_BATCH_TYPES = { PaliGemmaBatch, VlmCausalLMBatch, - IdeficsCausalLMBatch, + FlashVlmCausalLMBatch, + FlashMllamaCausalLMBatch, } except (ImportError, NotImplementedError): 
# These imports can fail on CPU/Non flash. @@ -103,14 +108,50 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) async def Warmup(self, request, context): - max_supported_total_tokens, max_input_tokens, max_total_tokens = ( - self.model.warmup(request) - ) + if ATTENTION == "paged": + set_max_prefill_tokens(request.max_prefill_tokens) + if ( + self.model.batch_type in VLM_BATCH_TYPES + ): # Hack, i would rather use kwargs in the `from_pb` call + batch = self.model.batch_type.from_pb_processor( + request.batch, + self.model.tokenizer, + self.model.processor, + self.model.model.config, + self.model.dtype, + self.model.device, + ) + else: + batch = self.model.batch_type.from_pb( + request.batch, + self.model.tokenizer, + self.model.dtype, + self.model.device, + ) - # W/A for the skip tokenizer path - # We need to call make_tokenizer_optional after the warmup, - # because router is not aware of that feature - make_tokenizer_optional(self.model.tokenizer) + # Override default values with None for clearer semantics. + max_input_tokens = ( + request.max_input_tokens + if request.HasField("max_input_tokens") + else None + ) + max_total_tokens = ( + request.max_total_tokens + if request.HasField("max_total_tokens") + else None + ) + max_supported_total_tokens, max_input_tokens, max_total_tokens = ( + self.model.warmup(batch, max_input_tokens, max_total_tokens) + ) + else: + max_supported_total_tokens, max_input_tokens, max_total_tokens = ( + self.model.warmup(request) + ) + + # W/A for the skip tokenizer path + # We need to call make_tokenizer_optional after the warmup, + # because router is not aware of that feature + make_tokenizer_optional(self.model.tokenizer) return generate_pb2.WarmupResponse( max_supported_total_tokens=max_supported_total_tokens, diff --git a/backends/gaudi/server/text_generation_server/utils/dist.py b/backends/gaudi/server/text_generation_server/utils/dist.py index 0e9b97fb2..1c45713e8 100644 --- a/backends/gaudi/server/text_generation_server/utils/dist.py +++ b/backends/gaudi/server/text_generation_server/utils/dist.py @@ -1,15 +1,13 @@ import os import torch - +from torch.distributed import ProcessGroup from datetime import timedelta from loguru import logger # Tensor Parallelism settings RANK = int(os.getenv("RANK", "0")) WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) - -# CUDA memory fraction -MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0")) +MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.8")) class FakeBarrier: @@ -17,10 +15,11 @@ class FakeBarrier: pass -class FakeGroup: +class FakeGroup(ProcessGroup): def __init__(self, rank, size): self._rank = rank self._size = size + super().__init__(rank, size) def allreduce(self, *args, **kwargs): return FakeBarrier() @@ -42,42 +41,11 @@ class FakeGroup: def rank(self): return self._rank + def _get_backend_name(self): + return "fake" + def initialize_torch_distributed(): - - world_size = int(os.getenv("WORLD_SIZE", "1")) - - options = None - if torch.cuda.is_available(): - from torch.distributed import ProcessGroupNCCL - - # Set the device id. 
- assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one gpu" - device = RANK % torch.cuda.device_count() - torch.cuda.set_device(device) - torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device) - backend = "nccl" - options = ProcessGroupNCCL.Options() - options.is_high_priority_stream = True - options._timeout = timedelta(seconds=60) - elif torch.hpu.is_available(): - backend = "hccl" - n_hpus = torch.hpu.device_count() - if world_size > n_hpus: - raise ValueError( - f"WORLD_SIZE ({world_size}) is higher than the number of available HPUs ({n_hpus})." - ) - else: - try: - import oneccl_bindings_for_pytorch # noqa: F401 - - backend = "ccl" - if os.getenv("CCL_WORKER_COUNT", None) is None: - os.environ["CCL_WORKER_COUNT"] = str(1) - except ImportError: - backend = "gloo" - options = None - if WORLD_SIZE == 1: return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE else: @@ -87,11 +55,10 @@ def initialize_torch_distributed(): if not torch.distributed.is_initialized(): # Call the init process. torch.distributed.init_process_group( - backend=backend, + backend="hccl", world_size=WORLD_SIZE, rank=RANK, - timeout=timedelta(seconds=60), - pg_options=options, + timeout=timedelta(seconds=120), ) else: logger.warning("torch.distributed is already initialized.") diff --git a/backends/gaudi/server/text_generation_server/utils/import_utils.py b/backends/gaudi/server/text_generation_server/utils/import_utils.py index 782b4f15b..22560dd7a 100644 --- a/backends/gaudi/server/text_generation_server/utils/import_utils.py +++ b/backends/gaudi/server/text_generation_server/utils/import_utils.py @@ -1,75 +1,28 @@ import torch from loguru import logger -import os -import importlib.util +def get_hpu_free_memory(device, memory_fraction): + from habana_frameworks.torch.hpu import memory_stats - -def is_ipex_available(): - return importlib.util.find_spec("intel_extension_for_pytorch") is not None - - -def get_cuda_free_memory(device, memory_fraction): - total_free_memory, _ = torch.cuda.mem_get_info(device) - total_gpu_memory = torch.cuda.get_device_properties(device).total_memory - free_memory = max(0, total_free_memory - (1 - memory_fraction) * total_gpu_memory) - return free_memory - - -def get_xpu_free_memory(device, memory_fraction): - total_memory = torch.xpu.get_device_properties(device).total_memory device_id = device.index - memory_fraction = float(os.getenv("XPU_MEMORY_FRACTION", "1.0")) + mem_stats = memory_stats(device_id) + logger.info(f"mem_stats: {mem_stats}") + total_free_memory = mem_stats["Limit"] - mem_stats["MaxInUse"] free_memory = max( - 0, - int( - total_memory * 0.9 * memory_fraction - torch.xpu.memory_reserved(device_id) - ), + 0, int(total_free_memory - (1 - memory_fraction) * mem_stats["Limit"]) ) return free_memory -def get_cpu_free_memory(device, memory_fraction): - import psutil - from text_generation_server.utils.dist import WORLD_SIZE - - mem = psutil.virtual_memory() - free_memory = int(mem.available * 0.95 / WORLD_SIZE) - return free_memory +def synchronize_hpu(device): + torch.hpu.synchronize() def noop(*args, **kwargs): pass -SYSTEM = None -if torch.version.hip is not None: - SYSTEM = "rocm" - empty_cache = torch.cuda.empty_cache - synchronize = torch.cuda.synchronize - get_free_memory = get_cuda_free_memory -elif torch.version.cuda is not None and torch.cuda.is_available(): - SYSTEM = "cuda" - empty_cache = torch.cuda.empty_cache - synchronize = torch.cuda.synchronize - get_free_memory = get_cuda_free_memory -elif is_ipex_available(): - SYSTEM = "ipex" - 
import intel_extension_for_pytorch # noqa: F401 - - if hasattr(torch, "xpu") and torch.xpu.is_available(): - empty_cache = torch.xpu.empty_cache - synchronize = torch.xpu.synchronize - get_free_memory = get_xpu_free_memory - else: - empty_cache = noop - synchronize = noop - get_free_memory = get_cpu_free_memory -else: - SYSTEM = "cpu" - - empty_cache = noop - synchronize = noop - get_free_memory = get_cpu_free_memory -logger.info(f"Detected system {SYSTEM}") +empty_cache = noop +synchronize = synchronize_hpu +get_free_memory = get_hpu_free_memory diff --git a/backends/gaudi/server/text_generation_server/utils/kernels.py b/backends/gaudi/server/text_generation_server/utils/kernels.py new file mode 100644 index 000000000..42745c716 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/utils/kernels.py @@ -0,0 +1,22 @@ +import importlib + +from loguru import logger +from hf_kernels import load_kernel as hf_load_kernel + +from text_generation_server.utils.log import log_once + + +def load_kernel(*, module: str, repo_id: str): + """ + Load a kernel. First try to load it as the given module (e.g. for + local development), falling back to a locked Hub kernel. + """ + try: + m = importlib.import_module(module) + log_once(logger.info, f"Using local module for `{module}`") + return m + except ModuleNotFoundError: + return hf_load_kernel(repo_id=repo_id) + + +__all__ = ["load_kernel"] diff --git a/backends/gaudi/server/text_generation_server/utils/prefill_chunking.py b/backends/gaudi/server/text_generation_server/utils/prefill_chunking.py new file mode 100644 index 000000000..c227d30f5 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/utils/prefill_chunking.py @@ -0,0 +1,24 @@ +from typing import Optional + +SUPPORT_CHUNKING: Optional[bool] = None +MAX_PREFILL_TOKENS: Optional[int] = None + + +def set_support_chunking(support_chunking: bool): + global SUPPORT_CHUNKING + SUPPORT_CHUNKING = support_chunking + + +def get_support_chunking() -> bool: + global SUPPORT_CHUNKING + return SUPPORT_CHUNKING + + +def set_max_prefill_tokens(max_prefill_tokens: int): + global MAX_PREFILL_TOKENS + MAX_PREFILL_TOKENS = max_prefill_tokens + + +def get_max_prefill_tokens() -> int: + global MAX_PREFILL_TOKENS + return MAX_PREFILL_TOKENS diff --git a/backends/gaudi/server/text_generation_server/utils/quantization.py b/backends/gaudi/server/text_generation_server/utils/quantization.py index ee561acc4..a8faf4a59 100644 --- a/backends/gaudi/server/text_generation_server/utils/quantization.py +++ b/backends/gaudi/server/text_generation_server/utils/quantization.py @@ -4,9 +4,7 @@ from dataclasses import dataclass from typing import Optional from huggingface_hub import hf_hub_download -from text_generation_server.layers.marlin.gptq import can_use_gptq_marlin from text_generation_server.utils.weights import ( - DefaultWeightsLoader, WeightsLoader, ) @@ -129,64 +127,13 @@ def get_loader( f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config." 
) - if can_use_gptq_marlin( + return GPTQWeightsLoader( bits=quantizer_config.bits, + desc_act=quantizer_config.desc_act, groupsize=quantizer_config.groupsize, quant_method=quantizer_config.quant_method, quantize=quantize, sym=quantizer_config.sym, - ): - from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader - - return GPTQMarlinWeightsLoader( - bits=quantizer_config.bits, - desc_act=quantizer_config.desc_act, - groupsize=quantizer_config.groupsize, - quant_method=quantizer_config.quant_method, - quantize=quantize, - sym=quantizer_config.sym, - ) - else: - return GPTQWeightsLoader( - bits=quantizer_config.bits, - desc_act=quantizer_config.desc_act, - groupsize=quantizer_config.groupsize, - quant_method=quantizer_config.quant_method, - quantize=quantize, - sym=quantizer_config.sym, - ) - elif quantize == "bitsandbytes": - from text_generation_server.layers.bnb import BNBWeight - - return DefaultWeightsLoader(BNBWeight) - elif quantize == "bitsandbytes-fp4": - from text_generation_server.layers.bnb import BNBFP4Weight - - return DefaultWeightsLoader(BNBFP4Weight) - elif quantize == "bitsandbytes-nf4": - from text_generation_server.layers.bnb import BNBNF4Weight - - return DefaultWeightsLoader(BNBNF4Weight) - elif quantize == "eetq": - from text_generation_server.layers.eetq import EETQWeight - - return DefaultWeightsLoader(EETQWeight) - elif quantize == "exl2": - from text_generation_server.layers.exl2 import Exl2WeightsLoader - - return Exl2WeightsLoader() - elif quantize == "marlin": - from text_generation_server.layers.marlin import MarlinWeightsLoader - - # TODO: improve check once we have one config type per quantize value - if not isinstance(quantizer_config, _QuantizerConfig): - raise ValueError( - f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config." 
- ) - - return MarlinWeightsLoader( - bits=quantizer_config.bits, - is_marlin_24=quantizer_config.checkpoint_format == "marlin_24", ) elif quantize == "fp8" or quantize is None: from text_generation_server.layers.fp8 import HybridFP8UnquantLoader diff --git a/backends/gaudi/server/text_generation_server/utils/weights.py b/backends/gaudi/server/text_generation_server/utils/weights.py index 75e01f7ce..acd598d7a 100644 --- a/backends/gaudi/server/text_generation_server/utils/weights.py +++ b/backends/gaudi/server/text_generation_server/utils/weights.py @@ -7,8 +7,6 @@ from typing import Dict, List, Optional, Union, Type from safetensors import safe_open from dataclasses import dataclass -from text_generation_server.utils.import_utils import SYSTEM - class WeightsLoader(ABC): """ @@ -88,12 +86,9 @@ class UnquantizedWeight(Weight): weight: torch.Tensor def get_linear(self, bias: torch.Tensor): - from text_generation_server.layers.linear import FastLinear, FastLinearROCm + from text_generation_server.layers.linear import FastLinear - if SYSTEM == "rocm": - return FastLinearROCm(self.weight, bias) - else: - return FastLinear(self.weight, bias) + return FastLinear(self.weight, bias) class DefaultWeightsLoader(WeightsLoader): @@ -197,7 +192,7 @@ class Weights: slice_ = f.get_slice(tensor_name) return slice_ - def _has_tensor(self, tensor_name: str): + def has_tensor(self, tensor_name: str): try: self.get_filename(tensor_name) except Exception: @@ -207,7 +202,9 @@ class Weights: def get_shape(self, tensor_name: str): return self._get_slice(tensor_name).get_shape() - def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True): + def get_tensor( + self, tensor_name: str, to_device: bool = True, to_dtype: bool = True + ) -> torch.Tensor: filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) tensor = f.get_tensor(tensor_name) @@ -218,6 +215,7 @@ class Weights: tensor.dtype not in [ torch.float8_e4m3fn, + torch.int8, torch.int16, torch.int32, torch.int64, @@ -253,7 +251,8 @@ class Weights: # u4 which are disguised as int32. exl2 uses int16. # FP8 uses torch.float8_e4m3fn. 
if ( - tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32) + tensor.dtype + not in (torch.float8_e4m3fn, torch.int8, torch.int16, torch.int32) and to_dtype ): tensor = tensor.to(dtype=self.dtype) @@ -329,6 +328,7 @@ class Weights: tensor.dtype not in [ torch.float8_e4m3fn, + torch.int8, torch.int16, torch.int32, torch.int64, diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index e7f3d85a9..6da2b51da 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use tokio::sync::{mpsc, oneshot}; use crate::radix::RadixAllocator; - +use text_generation_router::usage_stats::Env; #[derive(Debug, Clone)] pub struct BlockAllocation { pub allocation_id: u64, @@ -141,6 +141,7 @@ pub struct SimpleAllocator { free_blocks: Vec<u32>, block_size: u32, window_size: Option<u32>, + is_hpu_device: bool, } impl SimpleAllocator { @@ -150,6 +151,7 @@ impl SimpleAllocator { // Block 0 is reserved for health checks free_blocks: (1..blocks).collect(), window_size, + is_hpu_device: Env::new().is_hpu_device(), } } } @@ -179,9 +181,15 @@ impl Allocator for SimpleAllocator { if required_blocks > self.free_blocks.len() as u32 { None } else { - let blocks = self + if self.is_hpu_device { + self.free_blocks.sort_by(|a, b| b.cmp(a)); + } + let mut blocks = self .free_blocks .split_off(self.free_blocks.len() - required_blocks as usize); + if self.is_hpu_device { + blocks.sort(); + } let mut slots = Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize); diff --git a/launcher/src/env_runtime.rs b/launcher/src/env_runtime.rs index d7ae11d54..d9056e413 100644 --- a/launcher/src/env_runtime.rs +++ b/launcher/src/env_runtime.rs @@ -28,8 +28,8 @@ impl Env { } } - pub fn is_hpu_device(&self) -> bool { - self.hpu_env != "N/A" + pub fn should_start_a_single_hpu_shard(&self) -> bool { + self.hpu_env != "N/A" && std::env::var("ATTENTION").as_deref() != Ok("paged") } } diff --git a/launcher/src/main.rs b/launcher/src/main.rs index acff85730..c169a78ce 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1559,7 +1559,7 @@ fn spawn_shards( ) -> Result<(), LauncherError> { // Start shard processes for rank in 0..num_shard { - if rank != 0 && env_runtime::Env::new().is_hpu_device() { + if rank != 0 && env_runtime::Env::new().should_start_a_single_hpu_shard() { tracing::info!("Running on HPU, the launcher will not do any sharding as actual sharding is done in the server"); break; } @@ -1639,7 +1639,7 @@ fn spawn_shards( if shard_ready == num_shard { break; } - if env_runtime::Env::new().is_hpu_device() { + if env_runtime::Env::new().should_start_a_single_hpu_shard() { tracing::info!("HPU detected, shard is ready"); break; } diff --git a/router/src/usage_stats.rs b/router/src/usage_stats.rs index 353e9e378..a17aade9c 100644 --- a/router/src/usage_stats.rs +++ b/router/src/usage_stats.rs @@ -157,6 +157,7 @@ pub struct Env { docker_label: &'static str, nvidia_info: Option<Vec<NvidiaSmiInfo>>, xpu_info: Option<Vec<XpuSmiInfo>>, + hpu_info: Option<Vec<HpuSmiInfo>>, system_env: SystemInfo, } @@ -289,6 +290,60 @@ impl XpuSmiInfo { } } +#[derive(Debug, Serialize, Clone)] +struct HpuSmiInfo { + name: String, + pci_bus_id: String, + driver_version: String, + temperature: String, + utilization: String, + memory_total: String, + memory_free: String, + memory_used: String, + power_draw_instant: String, +} + +impl HpuSmiInfo { + fn new() -> Option<Vec<HpuSmiInfo>> { + let output = Command::new("hl-smi") + .args([ + 
"--query-aip=name,bus_id,driver_version,temperature.aip,utilization.aip,memory.total,memory.free,memory.used,power.draw", + "--format=csv" + ]) + .output() + .ok()?; + + if !output.status.success() { + return None; + } + + let stdout = String::from_utf8(output.stdout).ok()?; + + let mut rdr = ReaderBuilder::new() + .has_headers(true) + .from_reader(stdout.as_bytes()); + + let mut infos = Vec::new(); + + for result in rdr.records() { + let record = result.ok()?; + infos.push(HpuSmiInfo { + name: record[0].to_string(), + pci_bus_id: record[1].to_string(), + driver_version: record[2].to_string(), + temperature: record[3].to_string(), + utilization: record[4].to_string(), + memory_total: record[5].to_string(), + memory_free: record[6].to_string(), + memory_used: record[7].to_string(), + power_draw_instant: record[8].to_string(), + }); + } + + Some(infos) + } +} + #[derive(Serialize, Debug, Clone)] pub struct SystemInfo { cpu_count: usize, @@ -335,10 +390,14 @@ impl Env { system_env: SystemInfo::new(), nvidia_info: NvidiaSmiInfo::new(), xpu_info: XpuSmiInfo::new(), + hpu_info: HpuSmiInfo::new(), git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"), docker_label: option_env!("DOCKER_LABEL").unwrap_or("N/A"), } } + pub fn is_hpu_device(&self) -> bool { + self.hpu_info.is_some() + } } pub fn is_container() -> io::Result { From fe56f760df3230dd9a24c4ed25741bb576904621 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 14 Apr 2025 17:18:43 +0200 Subject: [PATCH 2/3] Upgrading the python client deps (still deprecated, but used for integration-tests) --- clients/python/poetry.lock | 1872 ++++++++++++++++++--------------- clients/python/pyproject.toml | 12 +- 2 files changed, 1047 insertions(+), 837 deletions(-) diff --git a/clients/python/poetry.lock b/clients/python/poetry.lock index 148d99065..36e82f2a0 100644 --- a/clients/python/poetry.lock +++ b/clients/python/poetry.lock @@ -1,124 +1,131 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.0.0 and should not be changed by hand. 
+ +[[package]] +name = "aiohappyeyeballs" +version = "2.6.1" +description = "Happy Eyeballs for asyncio" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "aiohappyeyeballs-2.6.1-py3-none-any.whl", hash = "sha256:f349ba8f4b75cb25c99c5c2d84e997e485204d2902a9597802b0371f09331fb8"}, + {file = "aiohappyeyeballs-2.6.1.tar.gz", hash = "sha256:c3f9d0113123803ccadfdf3f0faa505bc78e6a72d1cc4806cbd719826e943558"}, +] [[package]] name = "aiohttp" -version = "3.8.5" +version = "3.11.16" description = "Async http client/server framework (asyncio)" optional = false -python-versions = ">=3.6" +python-versions = ">=3.9" +groups = ["main"] files = [ - {file = "aiohttp-3.8.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a94159871304770da4dd371f4291b20cac04e8c94f11bdea1c3478e557fbe0d8"}, - {file = "aiohttp-3.8.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:13bf85afc99ce6f9ee3567b04501f18f9f8dbbb2ea11ed1a2e079670403a7c84"}, - {file = "aiohttp-3.8.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2ce2ac5708501afc4847221a521f7e4b245abf5178cf5ddae9d5b3856ddb2f3a"}, - {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:96943e5dcc37a6529d18766597c491798b7eb7a61d48878611298afc1fca946c"}, - {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2ad5c3c4590bb3cc28b4382f031f3783f25ec223557124c68754a2231d989e2b"}, - {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0c413c633d0512df4dc7fd2373ec06cc6a815b7b6d6c2f208ada7e9e93a5061d"}, - {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df72ac063b97837a80d80dec8d54c241af059cc9bb42c4de68bd5b61ceb37caa"}, - {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c48c5c0271149cfe467c0ff8eb941279fd6e3f65c9a388c984e0e6cf57538e14"}, - {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:368a42363c4d70ab52c2c6420a57f190ed3dfaca6a1b19afda8165ee16416a82"}, - {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:7607ec3ce4993464368505888af5beb446845a014bc676d349efec0e05085905"}, - {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:0d21c684808288a98914e5aaf2a7c6a3179d4df11d249799c32d1808e79503b5"}, - {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:312fcfbacc7880a8da0ae8b6abc6cc7d752e9caa0051a53d217a650b25e9a691"}, - {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ad093e823df03bb3fd37e7dec9d4670c34f9e24aeace76808fc20a507cace825"}, - {file = "aiohttp-3.8.5-cp310-cp310-win32.whl", hash = "sha256:33279701c04351a2914e1100b62b2a7fdb9a25995c4a104259f9a5ead7ed4802"}, - {file = "aiohttp-3.8.5-cp310-cp310-win_amd64.whl", hash = "sha256:6e4a280e4b975a2e7745573e3fc9c9ba0d1194a3738ce1cbaa80626cc9b4f4df"}, - {file = "aiohttp-3.8.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ae871a964e1987a943d83d6709d20ec6103ca1eaf52f7e0d36ee1b5bebb8b9b9"}, - {file = "aiohttp-3.8.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:461908b2578955045efde733719d62f2b649c404189a09a632d245b445c9c975"}, - {file = "aiohttp-3.8.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:72a860c215e26192379f57cae5ab12b168b75db8271f111019509a1196dfc780"}, - {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:cc14be025665dba6202b6a71cfcdb53210cc498e50068bc088076624471f8bb9"}, - {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8af740fc2711ad85f1a5c034a435782fbd5b5f8314c9a3ef071424a8158d7f6b"}, - {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:841cd8233cbd2111a0ef0a522ce016357c5e3aff8a8ce92bcfa14cef890d698f"}, - {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ed1c46fb119f1b59304b5ec89f834f07124cd23ae5b74288e364477641060ff"}, - {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:84f8ae3e09a34f35c18fa57f015cc394bd1389bce02503fb30c394d04ee6b938"}, - {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:62360cb771707cb70a6fd114b9871d20d7dd2163a0feafe43fd115cfe4fe845e"}, - {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:23fb25a9f0a1ca1f24c0a371523546366bb642397c94ab45ad3aedf2941cec6a"}, - {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:b0ba0d15164eae3d878260d4c4df859bbdc6466e9e6689c344a13334f988bb53"}, - {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:5d20003b635fc6ae3f96d7260281dfaf1894fc3aa24d1888a9b2628e97c241e5"}, - {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0175d745d9e85c40dcc51c8f88c74bfbaef9e7afeeeb9d03c37977270303064c"}, - {file = "aiohttp-3.8.5-cp311-cp311-win32.whl", hash = "sha256:2e1b1e51b0774408f091d268648e3d57f7260c1682e7d3a63cb00d22d71bb945"}, - {file = "aiohttp-3.8.5-cp311-cp311-win_amd64.whl", hash = "sha256:043d2299f6dfdc92f0ac5e995dfc56668e1587cea7f9aa9d8a78a1b6554e5755"}, - {file = "aiohttp-3.8.5-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:cae533195e8122584ec87531d6df000ad07737eaa3c81209e85c928854d2195c"}, - {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f21e83f355643c345177a5d1d8079f9f28b5133bcd154193b799d380331d5d3"}, - {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a7a75ef35f2df54ad55dbf4b73fe1da96f370e51b10c91f08b19603c64004acc"}, - {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2e2e9839e14dd5308ee773c97115f1e0a1cb1d75cbeeee9f33824fa5144c7634"}, - {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c44e65da1de4403d0576473e2344828ef9c4c6244d65cf4b75549bb46d40b8dd"}, - {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:78d847e4cde6ecc19125ccbc9bfac4a7ab37c234dd88fbb3c5c524e8e14da543"}, - {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:c7a815258e5895d8900aec4454f38dca9aed71085f227537208057853f9d13f2"}, - {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:8b929b9bd7cd7c3939f8bcfffa92fae7480bd1aa425279d51a89327d600c704d"}, - {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:5db3a5b833764280ed7618393832e0853e40f3d3e9aa128ac0ba0f8278d08649"}, - {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:a0215ce6041d501f3155dc219712bc41252d0ab76474615b9700d63d4d9292af"}, - {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:fd1ed388ea7fbed22c4968dd64bab0198de60750a25fe8c0c9d4bef5abe13824"}, - {file = 
"aiohttp-3.8.5-cp36-cp36m-win32.whl", hash = "sha256:6e6783bcc45f397fdebc118d772103d751b54cddf5b60fbcc958382d7dd64f3e"}, - {file = "aiohttp-3.8.5-cp36-cp36m-win_amd64.whl", hash = "sha256:b5411d82cddd212644cf9360879eb5080f0d5f7d809d03262c50dad02f01421a"}, - {file = "aiohttp-3.8.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:01d4c0c874aa4ddfb8098e85d10b5e875a70adc63db91f1ae65a4b04d3344cda"}, - {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5980a746d547a6ba173fd5ee85ce9077e72d118758db05d229044b469d9029a"}, - {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2a482e6da906d5e6e653be079b29bc173a48e381600161c9932d89dfae5942ef"}, - {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80bd372b8d0715c66c974cf57fe363621a02f359f1ec81cba97366948c7fc873"}, - {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1161b345c0a444ebcf46bf0a740ba5dcf50612fd3d0528883fdc0eff578006a"}, - {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cd56db019015b6acfaaf92e1ac40eb8434847d9bf88b4be4efe5bfd260aee692"}, - {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:153c2549f6c004d2754cc60603d4668899c9895b8a89397444a9c4efa282aaf4"}, - {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4a01951fabc4ce26ab791da5f3f24dca6d9a6f24121746eb19756416ff2d881b"}, - {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:bfb9162dcf01f615462b995a516ba03e769de0789de1cadc0f916265c257e5d8"}, - {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:7dde0009408969a43b04c16cbbe252c4f5ef4574ac226bc8815cd7342d2028b6"}, - {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:4149d34c32f9638f38f544b3977a4c24052042affa895352d3636fa8bffd030a"}, - {file = "aiohttp-3.8.5-cp37-cp37m-win32.whl", hash = "sha256:68c5a82c8779bdfc6367c967a4a1b2aa52cd3595388bf5961a62158ee8a59e22"}, - {file = "aiohttp-3.8.5-cp37-cp37m-win_amd64.whl", hash = "sha256:2cf57fb50be5f52bda004b8893e63b48530ed9f0d6c96c84620dc92fe3cd9b9d"}, - {file = "aiohttp-3.8.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:eca4bf3734c541dc4f374ad6010a68ff6c6748f00451707f39857f429ca36ced"}, - {file = "aiohttp-3.8.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1274477e4c71ce8cfe6c1ec2f806d57c015ebf84d83373676036e256bc55d690"}, - {file = "aiohttp-3.8.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:28c543e54710d6158fc6f439296c7865b29e0b616629767e685a7185fab4a6b9"}, - {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:910bec0c49637d213f5d9877105d26e0c4a4de2f8b1b29405ff37e9fc0ad52b8"}, - {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5443910d662db951b2e58eb70b0fbe6b6e2ae613477129a5805d0b66c54b6cb7"}, - {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2e460be6978fc24e3df83193dc0cc4de46c9909ed92dd47d349a452ef49325b7"}, - {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb1558def481d84f03b45888473fc5a1f35747b5f334ef4e7a571bc0dfcb11f8"}, - {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:34dd0c107799dcbbf7d48b53be761a013c0adf5571bf50c4ecad5643fe9cfcd0"}, - {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:aa1990247f02a54185dc0dff92a6904521172a22664c863a03ff64c42f9b5410"}, - {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:0e584a10f204a617d71d359fe383406305a4b595b333721fa50b867b4a0a1548"}, - {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:a3cf433f127efa43fee6b90ea4c6edf6c4a17109d1d037d1a52abec84d8f2e42"}, - {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:c11f5b099adafb18e65c2c997d57108b5bbeaa9eeee64a84302c0978b1ec948b"}, - {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:84de26ddf621d7ac4c975dbea4c945860e08cccde492269db4e1538a6a6f3c35"}, - {file = "aiohttp-3.8.5-cp38-cp38-win32.whl", hash = "sha256:ab88bafedc57dd0aab55fa728ea10c1911f7e4d8b43e1d838a1739f33712921c"}, - {file = "aiohttp-3.8.5-cp38-cp38-win_amd64.whl", hash = "sha256:5798a9aad1879f626589f3df0f8b79b3608a92e9beab10e5fda02c8a2c60db2e"}, - {file = "aiohttp-3.8.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:a6ce61195c6a19c785df04e71a4537e29eaa2c50fe745b732aa937c0c77169f3"}, - {file = "aiohttp-3.8.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:773dd01706d4db536335fcfae6ea2440a70ceb03dd3e7378f3e815b03c97ab51"}, - {file = "aiohttp-3.8.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f83a552443a526ea38d064588613aca983d0ee0038801bc93c0c916428310c28"}, - {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f7372f7341fcc16f57b2caded43e81ddd18df53320b6f9f042acad41f8e049a"}, - {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ea353162f249c8097ea63c2169dd1aa55de1e8fecbe63412a9bc50816e87b761"}, - {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e5d47ae48db0b2dcf70bc8a3bc72b3de86e2a590fc299fdbbb15af320d2659de"}, - {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d827176898a2b0b09694fbd1088c7a31836d1a505c243811c87ae53a3f6273c1"}, - {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3562b06567c06439d8b447037bb655ef69786c590b1de86c7ab81efe1c9c15d8"}, - {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4e874cbf8caf8959d2adf572a78bba17cb0e9d7e51bb83d86a3697b686a0ab4d"}, - {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6809a00deaf3810e38c628e9a33271892f815b853605a936e2e9e5129762356c"}, - {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:33776e945d89b29251b33a7e7d006ce86447b2cfd66db5e5ded4e5cd0340585c"}, - {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:eaeed7abfb5d64c539e2db173f63631455f1196c37d9d8d873fc316470dfbacd"}, - {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e91d635961bec2d8f19dfeb41a539eb94bd073f075ca6dae6c8dc0ee89ad6f91"}, - {file = "aiohttp-3.8.5-cp39-cp39-win32.whl", hash = "sha256:00ad4b6f185ec67f3e6562e8a1d2b69660be43070bd0ef6fcec5211154c7df67"}, - {file = "aiohttp-3.8.5-cp39-cp39-win_amd64.whl", hash = "sha256:c0a9034379a37ae42dea7ac1e048352d96286626251862e448933c0f59cbd79c"}, - {file = "aiohttp-3.8.5.tar.gz", hash = "sha256:b9552ec52cc147dbf1944ac7ac98af7602e51ea2dcd076ed194ca3c0d1c7d0bc"}, + {file = "aiohttp-3.11.16-cp310-cp310-macosx_10_9_universal2.whl", hash = 
"sha256:fb46bb0f24813e6cede6cc07b1961d4b04f331f7112a23b5e21f567da4ee50aa"}, + {file = "aiohttp-3.11.16-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:54eb3aead72a5c19fad07219acd882c1643a1027fbcdefac9b502c267242f955"}, + {file = "aiohttp-3.11.16-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:38bea84ee4fe24ebcc8edeb7b54bf20f06fd53ce4d2cc8b74344c5b9620597fd"}, + {file = "aiohttp-3.11.16-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0666afbe984f6933fe72cd1f1c3560d8c55880a0bdd728ad774006eb4241ecd"}, + {file = "aiohttp-3.11.16-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7ba92a2d9ace559a0a14b03d87f47e021e4fa7681dc6970ebbc7b447c7d4b7cd"}, + {file = "aiohttp-3.11.16-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3ad1d59fd7114e6a08c4814983bb498f391c699f3c78712770077518cae63ff7"}, + {file = "aiohttp-3.11.16-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98b88a2bf26965f2015a771381624dd4b0839034b70d406dc74fd8be4cc053e3"}, + {file = "aiohttp-3.11.16-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:576f5ca28d1b3276026f7df3ec841ae460e0fc3aac2a47cbf72eabcfc0f102e1"}, + {file = "aiohttp-3.11.16-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a2a450bcce4931b295fc0848f384834c3f9b00edfc2150baafb4488c27953de6"}, + {file = "aiohttp-3.11.16-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:37dcee4906454ae377be5937ab2a66a9a88377b11dd7c072df7a7c142b63c37c"}, + {file = "aiohttp-3.11.16-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:4d0c970c0d602b1017e2067ff3b7dac41c98fef4f7472ec2ea26fd8a4e8c2149"}, + {file = "aiohttp-3.11.16-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:004511d3413737700835e949433536a2fe95a7d0297edd911a1e9705c5b5ea43"}, + {file = "aiohttp-3.11.16-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:c15b2271c44da77ee9d822552201180779e5e942f3a71fb74e026bf6172ff287"}, + {file = "aiohttp-3.11.16-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ad9509ffb2396483ceacb1eee9134724443ee45b92141105a4645857244aecc8"}, + {file = "aiohttp-3.11.16-cp310-cp310-win32.whl", hash = "sha256:634d96869be6c4dc232fc503e03e40c42d32cfaa51712aee181e922e61d74814"}, + {file = "aiohttp-3.11.16-cp310-cp310-win_amd64.whl", hash = "sha256:938f756c2b9374bbcc262a37eea521d8a0e6458162f2a9c26329cc87fdf06534"}, + {file = "aiohttp-3.11.16-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8cb0688a8d81c63d716e867d59a9ccc389e97ac7037ebef904c2b89334407180"}, + {file = "aiohttp-3.11.16-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0ad1fb47da60ae1ddfb316f0ff16d1f3b8e844d1a1e154641928ea0583d486ed"}, + {file = "aiohttp-3.11.16-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:df7db76400bf46ec6a0a73192b14c8295bdb9812053f4fe53f4e789f3ea66bbb"}, + {file = "aiohttp-3.11.16-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc3a145479a76ad0ed646434d09216d33d08eef0d8c9a11f5ae5cdc37caa3540"}, + {file = "aiohttp-3.11.16-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d007aa39a52d62373bd23428ba4a2546eed0e7643d7bf2e41ddcefd54519842c"}, + {file = "aiohttp-3.11.16-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f6ddd90d9fb4b501c97a4458f1c1720e42432c26cb76d28177c5b5ad4e332601"}, + {file = "aiohttp-3.11.16-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a2f451849e6b39e5c226803dcacfa9c7133e9825dcefd2f4e837a2ec5a3bb98"}, + 
{file = "aiohttp-3.11.16-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8df6612df74409080575dca38a5237282865408016e65636a76a2eb9348c2567"}, + {file = "aiohttp-3.11.16-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:78e6e23b954644737e385befa0deb20233e2dfddf95dd11e9db752bdd2a294d3"}, + {file = "aiohttp-3.11.16-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:696ef00e8a1f0cec5e30640e64eca75d8e777933d1438f4facc9c0cdf288a810"}, + {file = "aiohttp-3.11.16-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e3538bc9fe1b902bef51372462e3d7c96fce2b566642512138a480b7adc9d508"}, + {file = "aiohttp-3.11.16-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:3ab3367bb7f61ad18793fea2ef71f2d181c528c87948638366bf1de26e239183"}, + {file = "aiohttp-3.11.16-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:56a3443aca82abda0e07be2e1ecb76a050714faf2be84256dae291182ba59049"}, + {file = "aiohttp-3.11.16-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:61c721764e41af907c9d16b6daa05a458f066015abd35923051be8705108ed17"}, + {file = "aiohttp-3.11.16-cp311-cp311-win32.whl", hash = "sha256:3e061b09f6fa42997cf627307f220315e313ece74907d35776ec4373ed718b86"}, + {file = "aiohttp-3.11.16-cp311-cp311-win_amd64.whl", hash = "sha256:745f1ed5e2c687baefc3c5e7b4304e91bf3e2f32834d07baaee243e349624b24"}, + {file = "aiohttp-3.11.16-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:911a6e91d08bb2c72938bc17f0a2d97864c531536b7832abee6429d5296e5b27"}, + {file = "aiohttp-3.11.16-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:6ac13b71761e49d5f9e4d05d33683bbafef753e876e8e5a7ef26e937dd766713"}, + {file = "aiohttp-3.11.16-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fd36c119c5d6551bce374fcb5c19269638f8d09862445f85a5a48596fd59f4bb"}, + {file = "aiohttp-3.11.16-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d489d9778522fbd0f8d6a5c6e48e3514f11be81cb0a5954bdda06f7e1594b321"}, + {file = "aiohttp-3.11.16-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:69a2cbd61788d26f8f1e626e188044834f37f6ae3f937bd9f08b65fc9d7e514e"}, + {file = "aiohttp-3.11.16-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd464ba806e27ee24a91362ba3621bfc39dbbb8b79f2e1340201615197370f7c"}, + {file = "aiohttp-3.11.16-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ce63ae04719513dd2651202352a2beb9f67f55cb8490c40f056cea3c5c355ce"}, + {file = "aiohttp-3.11.16-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09b00dd520d88eac9d1768439a59ab3d145065c91a8fab97f900d1b5f802895e"}, + {file = "aiohttp-3.11.16-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:7f6428fee52d2bcf96a8aa7b62095b190ee341ab0e6b1bcf50c615d7966fd45b"}, + {file = "aiohttp-3.11.16-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:13ceac2c5cdcc3f64b9015710221ddf81c900c5febc505dbd8f810e770011540"}, + {file = "aiohttp-3.11.16-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:fadbb8f1d4140825069db3fedbbb843290fd5f5bc0a5dbd7eaf81d91bf1b003b"}, + {file = "aiohttp-3.11.16-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:6a792ce34b999fbe04a7a71a90c74f10c57ae4c51f65461a411faa70e154154e"}, + {file = "aiohttp-3.11.16-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:f4065145bf69de124accdd17ea5f4dc770da0a6a6e440c53f6e0a8c27b3e635c"}, + {file = "aiohttp-3.11.16-cp312-cp312-musllinux_1_2_x86_64.whl", hash = 
"sha256:fa73e8c2656a3653ae6c307b3f4e878a21f87859a9afab228280ddccd7369d71"}, + {file = "aiohttp-3.11.16-cp312-cp312-win32.whl", hash = "sha256:f244b8e541f414664889e2c87cac11a07b918cb4b540c36f7ada7bfa76571ea2"}, + {file = "aiohttp-3.11.16-cp312-cp312-win_amd64.whl", hash = "sha256:23a15727fbfccab973343b6d1b7181bfb0b4aa7ae280f36fd2f90f5476805682"}, + {file = "aiohttp-3.11.16-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a3814760a1a700f3cfd2f977249f1032301d0a12c92aba74605cfa6ce9f78489"}, + {file = "aiohttp-3.11.16-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:9b751a6306f330801665ae69270a8a3993654a85569b3469662efaad6cf5cc50"}, + {file = "aiohttp-3.11.16-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ad497f38a0d6c329cb621774788583ee12321863cd4bd9feee1effd60f2ad133"}, + {file = "aiohttp-3.11.16-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca37057625693d097543bd88076ceebeb248291df9d6ca8481349efc0b05dcd0"}, + {file = "aiohttp-3.11.16-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a5abcbba9f4b463a45c8ca8b7720891200658f6f46894f79517e6cd11f3405ca"}, + {file = "aiohttp-3.11.16-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f420bfe862fb357a6d76f2065447ef6f484bc489292ac91e29bc65d2d7a2c84d"}, + {file = "aiohttp-3.11.16-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58ede86453a6cf2d6ce40ef0ca15481677a66950e73b0a788917916f7e35a0bb"}, + {file = "aiohttp-3.11.16-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6fdec0213244c39973674ca2a7f5435bf74369e7d4e104d6c7473c81c9bcc8c4"}, + {file = "aiohttp-3.11.16-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:72b1b03fb4655c1960403c131740755ec19c5898c82abd3961c364c2afd59fe7"}, + {file = "aiohttp-3.11.16-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:780df0d837276276226a1ff803f8d0fa5f8996c479aeef52eb040179f3156cbd"}, + {file = "aiohttp-3.11.16-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:ecdb8173e6c7aa09eee342ac62e193e6904923bd232e76b4157ac0bfa670609f"}, + {file = "aiohttp-3.11.16-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:a6db7458ab89c7d80bc1f4e930cc9df6edee2200127cfa6f6e080cf619eddfbd"}, + {file = "aiohttp-3.11.16-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:2540ddc83cc724b13d1838026f6a5ad178510953302a49e6d647f6e1de82bc34"}, + {file = "aiohttp-3.11.16-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:3b4e6db8dc4879015b9955778cfb9881897339c8fab7b3676f8433f849425913"}, + {file = "aiohttp-3.11.16-cp313-cp313-win32.whl", hash = "sha256:493910ceb2764f792db4dc6e8e4b375dae1b08f72e18e8f10f18b34ca17d0979"}, + {file = "aiohttp-3.11.16-cp313-cp313-win_amd64.whl", hash = "sha256:42864e70a248f5f6a49fdaf417d9bc62d6e4d8ee9695b24c5916cb4bb666c802"}, + {file = "aiohttp-3.11.16-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:bbcba75fe879ad6fd2e0d6a8d937f34a571f116a0e4db37df8079e738ea95c71"}, + {file = "aiohttp-3.11.16-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:87a6e922b2b2401e0b0cf6b976b97f11ec7f136bfed445e16384fbf6fd5e8602"}, + {file = "aiohttp-3.11.16-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ccf10f16ab498d20e28bc2b5c1306e9c1512f2840f7b6a67000a517a4b37d5ee"}, + {file = "aiohttp-3.11.16-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb3d0cc5cdb926090748ea60172fa8a213cec728bd6c54eae18b96040fcd6227"}, + {file = "aiohttp-3.11.16-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:d07502cc14ecd64f52b2a74ebbc106893d9a9717120057ea9ea1fd6568a747e7"}, + {file = "aiohttp-3.11.16-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:776c8e959a01e5e8321f1dec77964cb6101020a69d5a94cd3d34db6d555e01f7"}, + {file = "aiohttp-3.11.16-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0902e887b0e1d50424112f200eb9ae3dfed6c0d0a19fc60f633ae5a57c809656"}, + {file = "aiohttp-3.11.16-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e87fd812899aa78252866ae03a048e77bd11b80fb4878ce27c23cade239b42b2"}, + {file = "aiohttp-3.11.16-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0a950c2eb8ff17361abd8c85987fd6076d9f47d040ebffce67dce4993285e973"}, + {file = "aiohttp-3.11.16-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:c10d85e81d0b9ef87970ecbdbfaeec14a361a7fa947118817fcea8e45335fa46"}, + {file = "aiohttp-3.11.16-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:7951decace76a9271a1ef181b04aa77d3cc309a02a51d73826039003210bdc86"}, + {file = "aiohttp-3.11.16-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:14461157d8426bcb40bd94deb0450a6fa16f05129f7da546090cebf8f3123b0f"}, + {file = "aiohttp-3.11.16-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:9756d9b9d4547e091f99d554fbba0d2a920aab98caa82a8fb3d3d9bee3c9ae85"}, + {file = "aiohttp-3.11.16-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:87944bd16b7fe6160607f6a17808abd25f17f61ae1e26c47a491b970fb66d8cb"}, + {file = "aiohttp-3.11.16-cp39-cp39-win32.whl", hash = "sha256:92b7ee222e2b903e0a4b329a9943d432b3767f2d5029dbe4ca59fb75223bbe2e"}, + {file = "aiohttp-3.11.16-cp39-cp39-win_amd64.whl", hash = "sha256:17ae4664031aadfbcb34fd40ffd90976671fa0c0286e6c4113989f78bebab37a"}, + {file = "aiohttp-3.11.16.tar.gz", hash = "sha256:16f8a2c9538c14a557b4d309ed4d0a7c60f0253e8ed7b6c9a2859a7582f8b1b8"}, ] [package.dependencies] +aiohappyeyeballs = ">=2.3.0" aiosignal = ">=1.1.2" -async-timeout = ">=4.0.0a3,<5.0" -asynctest = {version = "0.13.0", markers = "python_version < \"3.8\""} +async-timeout = {version = ">=4.0,<6.0", markers = "python_version < \"3.11\""} attrs = ">=17.3.0" -charset-normalizer = ">=2.0,<4.0" frozenlist = ">=1.1.1" multidict = ">=4.5,<7.0" -typing-extensions = {version = ">=3.7.4", markers = "python_version < \"3.8\""} -yarl = ">=1.0,<2.0" +propcache = ">=0.2.0" +yarl = ">=1.17.0,<2.0" [package.extras] -speedups = ["Brotli", "aiodns", "cchardet"] +speedups = ["Brotli", "aiodns (>=3.2.0)", "brotlicffi"] [[package]] name = "aiosignal" -version = "1.3.1" +version = "1.3.2" description = "aiosignal: a list of registered asynchronous callbacks" optional = false -python-versions = ">=3.7" +python-versions = ">=3.9" +groups = ["main"] files = [ - {file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"}, - {file = "aiosignal-1.3.1.tar.gz", hash = "sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc"}, + {file = "aiosignal-1.3.2-py2.py3-none-any.whl", hash = "sha256:45cde58e409a301715980c2b01d0c28bdde3770d8290b5eb2173759d9acb31a5"}, + {file = "aiosignal-1.3.2.tar.gz", hash = "sha256:a8c255c66fafb1e499c9351d0bf32ff2d8a0321595ebac3b93713656d2436f54"}, ] [package.dependencies] @@ -126,167 +133,161 @@ frozenlist = ">=1.1.0" [[package]] name = "annotated-types" -version = "0.5.0" +version = "0.7.0" description = "Reusable constraint types to use with typing.Annotated" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" +groups = 
["main"] files = [ - {file = "annotated_types-0.5.0-py3-none-any.whl", hash = "sha256:58da39888f92c276ad970249761ebea80ba544b77acddaa1a4d6cf78287d45fd"}, - {file = "annotated_types-0.5.0.tar.gz", hash = "sha256:47cdc3490d9ac1506ce92c7aaa76c579dc3509ff11e098fc867e5130ab7be802"}, + {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"}, + {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, ] -[package.dependencies] -typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""} - [[package]] name = "async-timeout" -version = "4.0.3" +version = "5.0.1" description = "Timeout context manager for asyncio programs" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" +groups = ["main"] +markers = "python_version < \"3.11\"" files = [ - {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"}, - {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, -] - -[package.dependencies] -typing-extensions = {version = ">=3.6.5", markers = "python_version < \"3.8\""} - -[[package]] -name = "asynctest" -version = "0.13.0" -description = "Enhance the standard unittest package with features for testing asyncio libraries" -optional = false -python-versions = ">=3.5" -files = [ - {file = "asynctest-0.13.0-py3-none-any.whl", hash = "sha256:5da6118a7e6d6b54d83a8f7197769d046922a44d2a99c21382f0a6e4fadae676"}, - {file = "asynctest-0.13.0.tar.gz", hash = "sha256:c27862842d15d83e6a34eb0b2866c323880eb3a75e4485b079ea11748fd77fac"}, -] - -[[package]] -name = "atomicwrites" -version = "1.4.1" -description = "Atomic file writes." 
-optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -files = [ - {file = "atomicwrites-1.4.1.tar.gz", hash = "sha256:81b2c9071a49367a7f770170e5eec8cb66567cfbbc8c73d20ce5ca4a8d71cf11"}, + {file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"}, + {file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"}, ] [[package]] name = "attrs" -version = "23.1.0" +version = "25.3.0" description = "Classes Without Boilerplate" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" +groups = ["main"] files = [ - {file = "attrs-23.1.0-py3-none-any.whl", hash = "sha256:1f28b4522cdc2fb4256ac1a020c78acf9cba2c6b461ccd2c126f3aa8e8335d04"}, - {file = "attrs-23.1.0.tar.gz", hash = "sha256:6279836d581513a26f1bf235f9acd333bc9115683f14f7e8fae46c98fc50e015"}, + {file = "attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3"}, + {file = "attrs-25.3.0.tar.gz", hash = "sha256:75d7cefc7fb576747b2c81b4442d4d4a1ce0900973527c011d1030fd3bf4af1b"}, ] -[package.dependencies] -importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} - [package.extras] -cov = ["attrs[tests]", "coverage[toml] (>=5.3)"] -dev = ["attrs[docs,tests]", "pre-commit"] -docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"] -tests = ["attrs[tests-no-zope]", "zope-interface"] -tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit-uv", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier"] +tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] [[package]] name = "certifi" -version = "2023.7.22" +version = "2025.1.31" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" +groups = ["main"] files = [ - {file = "certifi-2023.7.22-py3-none-any.whl", hash = "sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9"}, - {file = "certifi-2023.7.22.tar.gz", hash = "sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082"}, + {file = "certifi-2025.1.31-py3-none-any.whl", hash = "sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe"}, + {file = "certifi-2025.1.31.tar.gz", hash = "sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651"}, ] [[package]] name = "charset-normalizer" -version = "3.2.0" +version = "3.4.1" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." 
optional = false -python-versions = ">=3.7.0" +python-versions = ">=3.7" +groups = ["main"] files = [ - {file = "charset-normalizer-3.2.0.tar.gz", hash = "sha256:3bb3d25a8e6c0aedd251753a79ae98a093c7e7b471faa3aa9a93a81431987ace"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b87549028f680ca955556e3bd57013ab47474c3124dc069faa0b6545b6c9710"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7c70087bfee18a42b4040bb9ec1ca15a08242cf5867c58726530bdf3945672ed"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a103b3a7069b62f5d4890ae1b8f0597618f628b286b03d4bc9195230b154bfa9"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:94aea8eff76ee6d1cdacb07dd2123a68283cb5569e0250feab1240058f53b623"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:db901e2ac34c931d73054d9797383d0f8009991e723dab15109740a63e7f902a"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b0dac0ff919ba34d4df1b6131f59ce95b08b9065233446be7e459f95554c0dc8"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:193cbc708ea3aca45e7221ae58f0fd63f933753a9bfb498a3b474878f12caaad"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09393e1b2a9461950b1c9a45d5fd251dc7c6f228acab64da1c9c0165d9c7765c"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:baacc6aee0b2ef6f3d308e197b5d7a81c0e70b06beae1f1fcacffdbd124fe0e3"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:bf420121d4c8dce6b889f0e8e4ec0ca34b7f40186203f06a946fa0276ba54029"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:c04a46716adde8d927adb9457bbe39cf473e1e2c2f5d0a16ceb837e5d841ad4f"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:aaf63899c94de41fe3cf934601b0f7ccb6b428c6e4eeb80da72c58eab077b19a"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d62e51710986674142526ab9f78663ca2b0726066ae26b78b22e0f5e571238dd"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-win32.whl", hash = "sha256:04e57ab9fbf9607b77f7d057974694b4f6b142da9ed4a199859d9d4d5c63fe96"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:48021783bdf96e3d6de03a6e39a1171ed5bd7e8bb93fc84cc649d11490f87cea"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:4957669ef390f0e6719db3613ab3a7631e68424604a7b448f079bee145da6e09"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:46fb8c61d794b78ec7134a715a3e564aafc8f6b5e338417cb19fe9f57a5a9bf2"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f779d3ad205f108d14e99bb3859aa7dd8e9c68874617c72354d7ecaec2a054ac"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f25c229a6ba38a35ae6e25ca1264621cc25d4d38dca2942a7fce0b67a4efe918"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2efb1bd13885392adfda4614c33d3b68dee4921fd0ac1d3988f8cbb7d589e72a"}, - 
{file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1f30b48dd7fa1474554b0b0f3fdfdd4c13b5c737a3c6284d3cdc424ec0ffff3a"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:246de67b99b6851627d945db38147d1b209a899311b1305dd84916f2b88526c6"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bd9b3b31adcb054116447ea22caa61a285d92e94d710aa5ec97992ff5eb7cf3"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:8c2f5e83493748286002f9369f3e6607c565a6a90425a3a1fef5ae32a36d749d"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:3170c9399da12c9dc66366e9d14da8bf7147e1e9d9ea566067bbce7bb74bd9c2"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:7a4826ad2bd6b07ca615c74ab91f32f6c96d08f6fcc3902ceeedaec8cdc3bcd6"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:3b1613dd5aee995ec6d4c69f00378bbd07614702a315a2cf6c1d21461fe17c23"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9e608aafdb55eb9f255034709e20d5a83b6d60c054df0802fa9c9883d0a937aa"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-win32.whl", hash = "sha256:f2a1d0fd4242bd8643ce6f98927cf9c04540af6efa92323e9d3124f57727bfc1"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:681eb3d7e02e3c3655d1b16059fbfb605ac464c834a0c629048a30fad2b27489"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c57921cda3a80d0f2b8aec7e25c8aa14479ea92b5b51b6876d975d925a2ea346"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41b25eaa7d15909cf3ac4c96088c1f266a9a93ec44f87f1d13d4a0e86c81b982"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f058f6963fd82eb143c692cecdc89e075fa0828db2e5b291070485390b2f1c9c"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7647ebdfb9682b7bb97e2a5e7cb6ae735b1c25008a70b906aecca294ee96cf4"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eef9df1eefada2c09a5e7a40991b9fc6ac6ef20b1372abd48d2794a316dc0449"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e03b8895a6990c9ab2cdcd0f2fe44088ca1c65ae592b8f795c3294af00a461c3"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:ee4006268ed33370957f55bf2e6f4d263eaf4dc3cfc473d1d90baff6ed36ce4a"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c4983bf937209c57240cff65906b18bb35e64ae872da6a0db937d7b4af845dd7"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:3bb7fda7260735efe66d5107fb7e6af6a7c04c7fce9b2514e04b7a74b06bf5dd"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:72814c01533f51d68702802d74f77ea026b5ec52793c791e2da806a3844a46c3"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:70c610f6cbe4b9fce272c407dd9d07e33e6bf7b4aa1b7ffb6f6ded8e634e3592"}, - {file = 
"charset_normalizer-3.2.0-cp37-cp37m-win32.whl", hash = "sha256:a401b4598e5d3f4a9a811f3daf42ee2291790c7f9d74b18d75d6e21dda98a1a1"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-win_amd64.whl", hash = "sha256:c0b21078a4b56965e2b12f247467b234734491897e99c1d51cee628da9786959"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:95eb302ff792e12aba9a8b8f8474ab229a83c103d74a750ec0bd1c1eea32e669"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1a100c6d595a7f316f1b6f01d20815d916e75ff98c27a01ae817439ea7726329"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6339d047dab2780cc6220f46306628e04d9750f02f983ddb37439ca47ced7149"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4b749b9cc6ee664a3300bb3a273c1ca8068c46be705b6c31cf5d276f8628a94"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a38856a971c602f98472050165cea2cdc97709240373041b69030be15047691f"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f87f746ee241d30d6ed93969de31e5ffd09a2961a051e60ae6bddde9ec3583aa"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89f1b185a01fe560bc8ae5f619e924407efca2191b56ce749ec84982fc59a32a"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e1c8a2f4c69e08e89632defbfabec2feb8a8d99edc9f89ce33c4b9e36ab63037"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:2f4ac36d8e2b4cc1aa71df3dd84ff8efbe3bfb97ac41242fbcfc053c67434f46"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a386ebe437176aab38c041de1260cd3ea459c6ce5263594399880bbc398225b2"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:ccd16eb18a849fd8dcb23e23380e2f0a354e8daa0c984b8a732d9cfaba3a776d"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:e6a5bf2cba5ae1bb80b154ed68a3cfa2fa00fde979a7f50d6598d3e17d9ac20c"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:45de3f87179c1823e6d9e32156fb14c1927fcc9aba21433f088fdfb555b77c10"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-win32.whl", hash = "sha256:1000fba1057b92a65daec275aec30586c3de2401ccdcd41f8a5c1e2c87078706"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:8b2c760cfc7042b27ebdb4a43a4453bd829a5742503599144d54a032c5dc7e9e"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:855eafa5d5a2034b4621c74925d89c5efef61418570e5ef9b37717d9c796419c"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:203f0c8871d5a7987be20c72442488a0b8cfd0f43b7973771640fc593f56321f"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e857a2232ba53ae940d3456f7533ce6ca98b81917d47adc3c7fd55dad8fab858"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e86d77b090dbddbe78867a0275cb4df08ea195e660f1f7f13435a4649e954e5"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:c4fb39a81950ec280984b3a44f5bd12819953dc5fa3a7e6fa7a80db5ee853952"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2dee8e57f052ef5353cf608e0b4c871aee320dd1b87d351c28764fc0ca55f9f4"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8700f06d0ce6f128de3ccdbc1acaea1ee264d2caa9ca05daaf492fde7c2a7200"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1920d4ff15ce893210c1f0c0e9d19bfbecb7983c76b33f046c13a8ffbd570252"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c1c76a1743432b4b60ab3358c937a3fe1341c828ae6194108a94c69028247f22"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f7560358a6811e52e9c4d142d497f1a6e10103d3a6881f18d04dbce3729c0e2c"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:c8063cf17b19661471ecbdb3df1c84f24ad2e389e326ccaf89e3fb2484d8dd7e"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:cd6dbe0238f7743d0efe563ab46294f54f9bc8f4b9bcf57c3c666cc5bc9d1299"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:1249cbbf3d3b04902ff081ffbb33ce3377fa6e4c7356f759f3cd076cc138d020"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-win32.whl", hash = "sha256:6c409c0deba34f147f77efaa67b8e4bb83d2f11c8806405f76397ae5b8c0d1c9"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:7095f6fbfaa55defb6b733cfeb14efaae7a29f0b59d8cf213be4e7ca0b857b80"}, - {file = "charset_normalizer-3.2.0-py3-none-any.whl", hash = "sha256:8e098148dd37b4ce3baca71fb394c81dc5d9c7728c95df695d2dca218edf40e6"}, + {file = "charset_normalizer-3.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:91b36a978b5ae0ee86c394f5a54d6ef44db1de0815eb43de826d41d21e4af3de"}, + {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7461baadb4dc00fd9e0acbe254e3d7d2112e7f92ced2adc96e54ef6501c5f176"}, + {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e218488cd232553829be0664c2292d3af2eeeb94b32bea483cf79ac6a694e037"}, + {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80ed5e856eb7f30115aaf94e4a08114ccc8813e6ed1b5efa74f9f82e8509858f"}, + {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b010a7a4fd316c3c484d482922d13044979e78d1861f0e0650423144c616a46a"}, + {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4532bff1b8421fd0a320463030c7520f56a79c9024a4e88f01c537316019005a"}, + {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d973f03c0cb71c5ed99037b870f2be986c3c05e63622c017ea9816881d2dd247"}, + {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:3a3bd0dcd373514dcec91c411ddb9632c0d7d92aed7093b8c3bbb6d69ca74408"}, + {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:d9c3cdf5390dcd29aa8056d13e8e99526cda0305acc038b96b30352aff5ff2bb"}, + {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = 
"sha256:2bdfe3ac2e1bbe5b59a1a63721eb3b95fc9b6817ae4a46debbb4e11f6232428d"}, + {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:eab677309cdb30d047996b36d34caeda1dc91149e4fdca0b1a039b3f79d9a807"}, + {file = "charset_normalizer-3.4.1-cp310-cp310-win32.whl", hash = "sha256:c0429126cf75e16c4f0ad00ee0eae4242dc652290f940152ca8c75c3a4b6ee8f"}, + {file = "charset_normalizer-3.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:9f0b8b1c6d84c8034a44893aba5e767bf9c7a211e313a9605d9c617d7083829f"}, + {file = "charset_normalizer-3.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8bfa33f4f2672964266e940dd22a195989ba31669bd84629f05fab3ef4e2d125"}, + {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28bf57629c75e810b6ae989f03c0828d64d6b26a5e205535585f96093e405ed1"}, + {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f08ff5e948271dc7e18a35641d2f11a4cd8dfd5634f55228b691e62b37125eb3"}, + {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:234ac59ea147c59ee4da87a0c0f098e9c8d169f4dc2a159ef720f1a61bbe27cd"}, + {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd4ec41f914fa74ad1b8304bbc634b3de73d2a0889bd32076342a573e0779e00"}, + {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eea6ee1db730b3483adf394ea72f808b6e18cf3cb6454b4d86e04fa8c4327a12"}, + {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c96836c97b1238e9c9e3fe90844c947d5afbf4f4c92762679acfe19927d81d77"}, + {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4d86f7aff21ee58f26dcf5ae81a9addbd914115cdebcbb2217e4f0ed8982e146"}, + {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:09b5e6733cbd160dcc09589227187e242a30a49ca5cefa5a7edd3f9d19ed53fd"}, + {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:5777ee0881f9499ed0f71cc82cf873d9a0ca8af166dfa0af8ec4e675b7df48e6"}, + {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:237bdbe6159cff53b4f24f397d43c6336c6b0b42affbe857970cefbb620911c8"}, + {file = "charset_normalizer-3.4.1-cp311-cp311-win32.whl", hash = "sha256:8417cb1f36cc0bc7eaba8ccb0e04d55f0ee52df06df3ad55259b9a323555fc8b"}, + {file = "charset_normalizer-3.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:d7f50a1f8c450f3925cb367d011448c39239bb3eb4117c36a6d354794de4ce76"}, + {file = "charset_normalizer-3.4.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:73d94b58ec7fecbc7366247d3b0b10a21681004153238750bb67bd9012414545"}, + {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dad3e487649f498dd991eeb901125411559b22e8d7ab25d3aeb1af367df5efd7"}, + {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c30197aa96e8eed02200a83fba2657b4c3acd0f0aa4bdc9f6c1af8e8962e0757"}, + {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2369eea1ee4a7610a860d88f268eb39b95cb588acd7235e02fd5a5601773d4fa"}, + {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:bc2722592d8998c870fa4e290c2eec2c1569b87fe58618e67d38b4665dfa680d"}, + {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ffc9202a29ab3920fa812879e95a9e78b2465fd10be7fcbd042899695d75e616"}, + {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:804a4d582ba6e5b747c625bf1255e6b1507465494a40a2130978bda7b932c90b"}, + {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:0f55e69f030f7163dffe9fd0752b32f070566451afe180f99dbeeb81f511ad8d"}, + {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:c4c3e6da02df6fa1410a7680bd3f63d4f710232d3139089536310d027950696a"}, + {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:5df196eb874dae23dcfb968c83d4f8fdccb333330fe1fc278ac5ceeb101003a9"}, + {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e358e64305fe12299a08e08978f51fc21fac060dcfcddd95453eabe5b93ed0e1"}, + {file = "charset_normalizer-3.4.1-cp312-cp312-win32.whl", hash = "sha256:9b23ca7ef998bc739bf6ffc077c2116917eabcc901f88da1b9856b210ef63f35"}, + {file = "charset_normalizer-3.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:6ff8a4a60c227ad87030d76e99cd1698345d4491638dfa6673027c48b3cd395f"}, + {file = "charset_normalizer-3.4.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:aabfa34badd18f1da5ec1bc2715cadc8dca465868a4e73a0173466b688f29dda"}, + {file = "charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:22e14b5d70560b8dd51ec22863f370d1e595ac3d024cb8ad7d308b4cd95f8313"}, + {file = "charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8436c508b408b82d87dc5f62496973a1805cd46727c34440b0d29d8a2f50a6c9"}, + {file = "charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2d074908e1aecee37a7635990b2c6d504cd4766c7bc9fc86d63f9c09af3fa11b"}, + {file = "charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:955f8851919303c92343d2f66165294848d57e9bba6cf6e3625485a70a038d11"}, + {file = "charset_normalizer-3.4.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:44ecbf16649486d4aebafeaa7ec4c9fed8b88101f4dd612dcaf65d5e815f837f"}, + {file = "charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:0924e81d3d5e70f8126529951dac65c1010cdf117bb75eb02dd12339b57749dd"}, + {file = "charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:2967f74ad52c3b98de4c3b32e1a44e32975e008a9cd2a8cc8966d6a5218c5cb2"}, + {file = "charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:c75cb2a3e389853835e84a2d8fb2b81a10645b503eca9bcb98df6b5a43eb8886"}, + {file = "charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:09b26ae6b1abf0d27570633b2b078a2a20419c99d66fb2823173d73f188ce601"}, + {file = "charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:fa88b843d6e211393a37219e6a1c1df99d35e8fd90446f1118f4216e307e48cd"}, + {file = "charset_normalizer-3.4.1-cp313-cp313-win32.whl", hash = "sha256:eb8178fe3dba6450a3e024e95ac49ed3400e506fd4e9e5c32d30adda88cbd407"}, + {file = "charset_normalizer-3.4.1-cp313-cp313-win_amd64.whl", hash = "sha256:b1ac5992a838106edb89654e0aebfc24f5848ae2547d22c2c3f66454daa11971"}, + {file = 
"charset_normalizer-3.4.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f30bf9fd9be89ecb2360c7d94a711f00c09b976258846efe40db3d05828e8089"}, + {file = "charset_normalizer-3.4.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:97f68b8d6831127e4787ad15e6757232e14e12060bec17091b85eb1486b91d8d"}, + {file = "charset_normalizer-3.4.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7974a0b5ecd505609e3b19742b60cee7aa2aa2fb3151bc917e6e2646d7667dcf"}, + {file = "charset_normalizer-3.4.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc54db6c8593ef7d4b2a331b58653356cf04f67c960f584edb7c3d8c97e8f39e"}, + {file = "charset_normalizer-3.4.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:311f30128d7d333eebd7896965bfcfbd0065f1716ec92bd5638d7748eb6f936a"}, + {file = "charset_normalizer-3.4.1-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:7d053096f67cd1241601111b698f5cad775f97ab25d81567d3f59219b5f1adbd"}, + {file = "charset_normalizer-3.4.1-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:807f52c1f798eef6cf26beb819eeb8819b1622ddfeef9d0977a8502d4db6d534"}, + {file = "charset_normalizer-3.4.1-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:dccbe65bd2f7f7ec22c4ff99ed56faa1e9f785482b9bbd7c717e26fd723a1d1e"}, + {file = "charset_normalizer-3.4.1-cp37-cp37m-musllinux_1_2_s390x.whl", hash = "sha256:2fb9bd477fdea8684f78791a6de97a953c51831ee2981f8e4f583ff3b9d9687e"}, + {file = "charset_normalizer-3.4.1-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:01732659ba9b5b873fc117534143e4feefecf3b2078b0a6a2e925271bb6f4cfa"}, + {file = "charset_normalizer-3.4.1-cp37-cp37m-win32.whl", hash = "sha256:7a4f97a081603d2050bfaffdefa5b02a9ec823f8348a572e39032caa8404a487"}, + {file = "charset_normalizer-3.4.1-cp37-cp37m-win_amd64.whl", hash = "sha256:7b1bef6280950ee6c177b326508f86cad7ad4dff12454483b51d8b7d673a2c5d"}, + {file = "charset_normalizer-3.4.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ecddf25bee22fe4fe3737a399d0d177d72bc22be6913acfab364b40bce1ba83c"}, + {file = "charset_normalizer-3.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c60ca7339acd497a55b0ea5d506b2a2612afb2826560416f6894e8b5770d4a9"}, + {file = "charset_normalizer-3.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b7b2d86dd06bfc2ade3312a83a5c364c7ec2e3498f8734282c6c3d4b07b346b8"}, + {file = "charset_normalizer-3.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dd78cfcda14a1ef52584dbb008f7ac81c1328c0f58184bf9a84c49c605002da6"}, + {file = "charset_normalizer-3.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e27f48bcd0957c6d4cb9d6fa6b61d192d0b13d5ef563e5f2ae35feafc0d179c"}, + {file = "charset_normalizer-3.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:01ad647cdd609225c5350561d084b42ddf732f4eeefe6e678765636791e78b9a"}, + {file = "charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:619a609aa74ae43d90ed2e89bdd784765de0a25ca761b93e196d938b8fd1dbbd"}, + {file = "charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:89149166622f4db9b4b6a449256291dc87a99ee53151c74cbd82a53c8c2f6ccd"}, + {file = "charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:7709f51f5f7c853f0fb938bcd3bc59cdfdc5203635ffd18bf354f6967ea0f824"}, + {file = 
"charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:345b0426edd4e18138d6528aed636de7a9ed169b4aaf9d61a8c19e39d26838ca"}, + {file = "charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:0907f11d019260cdc3f94fbdb23ff9125f6b5d1039b76003b5b0ac9d6a6c9d5b"}, + {file = "charset_normalizer-3.4.1-cp38-cp38-win32.whl", hash = "sha256:ea0d8d539afa5eb2728aa1932a988a9a7af94f18582ffae4bc10b3fbdad0626e"}, + {file = "charset_normalizer-3.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:329ce159e82018d646c7ac45b01a430369d526569ec08516081727a20e9e4af4"}, + {file = "charset_normalizer-3.4.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:b97e690a2118911e39b4042088092771b4ae3fc3aa86518f84b8cf6888dbdb41"}, + {file = "charset_normalizer-3.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78baa6d91634dfb69ec52a463534bc0df05dbd546209b79a3880a34487f4b84f"}, + {file = "charset_normalizer-3.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1a2bc9f351a75ef49d664206d51f8e5ede9da246602dc2d2726837620ea034b2"}, + {file = "charset_normalizer-3.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:75832c08354f595c760a804588b9357d34ec00ba1c940c15e31e96d902093770"}, + {file = "charset_normalizer-3.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0af291f4fe114be0280cdd29d533696a77b5b49cfde5467176ecab32353395c4"}, + {file = "charset_normalizer-3.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0167ddc8ab6508fe81860a57dd472b2ef4060e8d378f0cc555707126830f2537"}, + {file = "charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2a75d49014d118e4198bcee5ee0a6f25856b29b12dbf7cd012791f8a6cc5c496"}, + {file = "charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:363e2f92b0f0174b2f8238240a1a30142e3db7b957a5dd5689b0e75fb717cc78"}, + {file = "charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ab36c8eb7e454e34e60eb55ca5d241a5d18b2c6244f6827a30e451c42410b5f7"}, + {file = "charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:4c0907b1928a36d5a998d72d64d8eaa7244989f7aaaf947500d3a800c83a3fd6"}, + {file = "charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:04432ad9479fa40ec0f387795ddad4437a2b50417c69fa275e212933519ff294"}, + {file = "charset_normalizer-3.4.1-cp39-cp39-win32.whl", hash = "sha256:3bed14e9c89dcb10e8f3a29f9ccac4955aebe93c71ae803af79265c9ca5644c5"}, + {file = "charset_normalizer-3.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:49402233c892a461407c512a19435d1ce275543138294f7ef013f0b63d5d3765"}, + {file = "charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85"}, + {file = "charset_normalizer-3.4.1.tar.gz", hash = "sha256:44251f18cd68a75b56585dd00dae26183e102cd5e0f9f1466e6df5da2ed64ea3"}, ] [[package]] @@ -295,78 +296,84 @@ version = "0.4.6" description = "Cross-platform colored terminal text." 
optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +groups = ["main", "dev"] files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +markers = {main = "platform_system == \"Windows\"", dev = "sys_platform == \"win32\""} [[package]] name = "coverage" -version = "7.2.7" +version = "7.8.0" description = "Code coverage measurement for Python" optional = false -python-versions = ">=3.7" +python-versions = ">=3.9" +groups = ["dev"] files = [ - {file = "coverage-7.2.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d39b5b4f2a66ccae8b7263ac3c8170994b65266797fb96cbbfd3fb5b23921db8"}, - {file = "coverage-7.2.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6d040ef7c9859bb11dfeb056ff5b3872436e3b5e401817d87a31e1750b9ae2fb"}, - {file = "coverage-7.2.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba90a9563ba44a72fda2e85302c3abc71c5589cea608ca16c22b9804262aaeb6"}, - {file = "coverage-7.2.7-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e7d9405291c6928619403db1d10bd07888888ec1abcbd9748fdaa971d7d661b2"}, - {file = "coverage-7.2.7-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31563e97dae5598556600466ad9beea39fb04e0229e61c12eaa206e0aa202063"}, - {file = "coverage-7.2.7-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:ebba1cd308ef115925421d3e6a586e655ca5a77b5bf41e02eb0e4562a111f2d1"}, - {file = "coverage-7.2.7-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:cb017fd1b2603ef59e374ba2063f593abe0fc45f2ad9abdde5b4d83bd922a353"}, - {file = "coverage-7.2.7-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d62a5c7dad11015c66fbb9d881bc4caa5b12f16292f857842d9d1871595f4495"}, - {file = "coverage-7.2.7-cp310-cp310-win32.whl", hash = "sha256:ee57190f24fba796e36bb6d3aa8a8783c643d8fa9760c89f7a98ab5455fbf818"}, - {file = "coverage-7.2.7-cp310-cp310-win_amd64.whl", hash = "sha256:f75f7168ab25dd93110c8a8117a22450c19976afbc44234cbf71481094c1b850"}, - {file = "coverage-7.2.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:06a9a2be0b5b576c3f18f1a241f0473575c4a26021b52b2a85263a00f034d51f"}, - {file = "coverage-7.2.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5baa06420f837184130752b7c5ea0808762083bf3487b5038d68b012e5937dbe"}, - {file = "coverage-7.2.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdec9e8cbf13a5bf63290fc6013d216a4c7232efb51548594ca3631a7f13c3a3"}, - {file = "coverage-7.2.7-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:52edc1a60c0d34afa421c9c37078817b2e67a392cab17d97283b64c5833f427f"}, - {file = "coverage-7.2.7-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63426706118b7f5cf6bb6c895dc215d8a418d5952544042c8a2d9fe87fcf09cb"}, - {file = "coverage-7.2.7-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:afb17f84d56068a7c29f5fa37bfd38d5aba69e3304af08ee94da8ed5b0865833"}, - {file = "coverage-7.2.7-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:48c19d2159d433ccc99e729ceae7d5293fbffa0bdb94952d3579983d1c8c9d97"}, - {file = "coverage-7.2.7-cp311-cp311-musllinux_1_1_x86_64.whl", hash = 
"sha256:0e1f928eaf5469c11e886fe0885ad2bf1ec606434e79842a879277895a50942a"}, - {file = "coverage-7.2.7-cp311-cp311-win32.whl", hash = "sha256:33d6d3ea29d5b3a1a632b3c4e4f4ecae24ef170b0b9ee493883f2df10039959a"}, - {file = "coverage-7.2.7-cp311-cp311-win_amd64.whl", hash = "sha256:5b7540161790b2f28143191f5f8ec02fb132660ff175b7747b95dcb77ac26562"}, - {file = "coverage-7.2.7-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f2f67fe12b22cd130d34d0ef79206061bfb5eda52feb6ce0dba0644e20a03cf4"}, - {file = "coverage-7.2.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a342242fe22407f3c17f4b499276a02b01e80f861f1682ad1d95b04018e0c0d4"}, - {file = "coverage-7.2.7-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:171717c7cb6b453aebac9a2ef603699da237f341b38eebfee9be75d27dc38e01"}, - {file = "coverage-7.2.7-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49969a9f7ffa086d973d91cec8d2e31080436ef0fb4a359cae927e742abfaaa6"}, - {file = "coverage-7.2.7-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b46517c02ccd08092f4fa99f24c3b83d8f92f739b4657b0f146246a0ca6a831d"}, - {file = "coverage-7.2.7-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:a3d33a6b3eae87ceaefa91ffdc130b5e8536182cd6dfdbfc1aa56b46ff8c86de"}, - {file = "coverage-7.2.7-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:976b9c42fb2a43ebf304fa7d4a310e5f16cc99992f33eced91ef6f908bd8f33d"}, - {file = "coverage-7.2.7-cp312-cp312-win32.whl", hash = "sha256:8de8bb0e5ad103888d65abef8bca41ab93721647590a3f740100cd65c3b00511"}, - {file = "coverage-7.2.7-cp312-cp312-win_amd64.whl", hash = "sha256:9e31cb64d7de6b6f09702bb27c02d1904b3aebfca610c12772452c4e6c21a0d3"}, - {file = "coverage-7.2.7-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:58c2ccc2f00ecb51253cbe5d8d7122a34590fac9646a960d1430d5b15321d95f"}, - {file = "coverage-7.2.7-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d22656368f0e6189e24722214ed8d66b8022db19d182927b9a248a2a8a2f67eb"}, - {file = "coverage-7.2.7-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a895fcc7b15c3fc72beb43cdcbdf0ddb7d2ebc959edac9cef390b0d14f39f8a9"}, - {file = "coverage-7.2.7-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e84606b74eb7de6ff581a7915e2dab7a28a0517fbe1c9239eb227e1354064dcd"}, - {file = "coverage-7.2.7-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:0a5f9e1dbd7fbe30196578ca36f3fba75376fb99888c395c5880b355e2875f8a"}, - {file = "coverage-7.2.7-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:419bfd2caae268623dd469eff96d510a920c90928b60f2073d79f8fe2bbc5959"}, - {file = "coverage-7.2.7-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:2aee274c46590717f38ae5e4650988d1af340fe06167546cc32fe2f58ed05b02"}, - {file = "coverage-7.2.7-cp37-cp37m-win32.whl", hash = "sha256:61b9a528fb348373c433e8966535074b802c7a5d7f23c4f421e6c6e2f1697a6f"}, - {file = "coverage-7.2.7-cp37-cp37m-win_amd64.whl", hash = "sha256:b1c546aca0ca4d028901d825015dc8e4d56aac4b541877690eb76490f1dc8ed0"}, - {file = "coverage-7.2.7-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:54b896376ab563bd38453cecb813c295cf347cf5906e8b41d340b0321a5433e5"}, - {file = "coverage-7.2.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3d376df58cc111dc8e21e3b6e24606b5bb5dee6024f46a5abca99124b2229ef5"}, - {file = 
"coverage-7.2.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e330fc79bd7207e46c7d7fd2bb4af2963f5f635703925543a70b99574b0fea9"}, - {file = "coverage-7.2.7-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e9d683426464e4a252bf70c3498756055016f99ddaec3774bf368e76bbe02b6"}, - {file = "coverage-7.2.7-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d13c64ee2d33eccf7437961b6ea7ad8673e2be040b4f7fd4fd4d4d28d9ccb1e"}, - {file = "coverage-7.2.7-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b7aa5f8a41217360e600da646004f878250a0d6738bcdc11a0a39928d7dc2050"}, - {file = "coverage-7.2.7-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8fa03bce9bfbeeef9f3b160a8bed39a221d82308b4152b27d82d8daa7041fee5"}, - {file = "coverage-7.2.7-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:245167dd26180ab4c91d5e1496a30be4cd721a5cf2abf52974f965f10f11419f"}, - {file = "coverage-7.2.7-cp38-cp38-win32.whl", hash = "sha256:d2c2db7fd82e9b72937969bceac4d6ca89660db0a0967614ce2481e81a0b771e"}, - {file = "coverage-7.2.7-cp38-cp38-win_amd64.whl", hash = "sha256:2e07b54284e381531c87f785f613b833569c14ecacdcb85d56b25c4622c16c3c"}, - {file = "coverage-7.2.7-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:537891ae8ce59ef63d0123f7ac9e2ae0fc8b72c7ccbe5296fec45fd68967b6c9"}, - {file = "coverage-7.2.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:06fb182e69f33f6cd1d39a6c597294cff3143554b64b9825d1dc69d18cc2fff2"}, - {file = "coverage-7.2.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:201e7389591af40950a6480bd9edfa8ed04346ff80002cec1a66cac4549c1ad7"}, - {file = "coverage-7.2.7-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f6951407391b639504e3b3be51b7ba5f3528adbf1a8ac3302b687ecababf929e"}, - {file = "coverage-7.2.7-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f48351d66575f535669306aa7d6d6f71bc43372473b54a832222803eb956fd1"}, - {file = "coverage-7.2.7-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b29019c76039dc3c0fd815c41392a044ce555d9bcdd38b0fb60fb4cd8e475ba9"}, - {file = "coverage-7.2.7-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:81c13a1fc7468c40f13420732805a4c38a105d89848b7c10af65a90beff25250"}, - {file = "coverage-7.2.7-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:975d70ab7e3c80a3fe86001d8751f6778905ec723f5b110aed1e450da9d4b7f2"}, - {file = "coverage-7.2.7-cp39-cp39-win32.whl", hash = "sha256:7ee7d9d4822c8acc74a5e26c50604dff824710bc8de424904c0982e25c39c6cb"}, - {file = "coverage-7.2.7-cp39-cp39-win_amd64.whl", hash = "sha256:eb393e5ebc85245347950143969b241d08b52b88a3dc39479822e073a1a8eb27"}, - {file = "coverage-7.2.7-pp37.pp38.pp39-none-any.whl", hash = "sha256:b7b4c971f05e6ae490fef852c218b0e79d4e52f79ef0c8475566584a8fb3e01d"}, - {file = "coverage-7.2.7.tar.gz", hash = "sha256:924d94291ca674905fe9481f12294eb11f2d3d3fd1adb20314ba89e94f44ed59"}, + {file = "coverage-7.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2931f66991175369859b5fd58529cd4b73582461877ecfd859b6549869287ffe"}, + {file = "coverage-7.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:52a523153c568d2c0ef8826f6cc23031dc86cffb8c6aeab92c4ff776e7951b28"}, + {file = "coverage-7.8.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c8a5c139aae4c35cbd7cadca1df02ea8cf28a911534fc1b0456acb0b14234f3"}, 
+ {file = "coverage-7.8.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5a26c0c795c3e0b63ec7da6efded5f0bc856d7c0b24b2ac84b4d1d7bc578d676"}, + {file = "coverage-7.8.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:821f7bcbaa84318287115d54becb1915eece6918136c6f91045bb84e2f88739d"}, + {file = "coverage-7.8.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a321c61477ff8ee705b8a5fed370b5710c56b3a52d17b983d9215861e37b642a"}, + {file = "coverage-7.8.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:ed2144b8a78f9d94d9515963ed273d620e07846acd5d4b0a642d4849e8d91a0c"}, + {file = "coverage-7.8.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:042e7841a26498fff7a37d6fda770d17519982f5b7d8bf5278d140b67b61095f"}, + {file = "coverage-7.8.0-cp310-cp310-win32.whl", hash = "sha256:f9983d01d7705b2d1f7a95e10bbe4091fabc03a46881a256c2787637b087003f"}, + {file = "coverage-7.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:5a570cd9bd20b85d1a0d7b009aaf6c110b52b5755c17be6962f8ccd65d1dbd23"}, + {file = "coverage-7.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e7ac22a0bb2c7c49f441f7a6d46c9c80d96e56f5a8bc6972529ed43c8b694e27"}, + {file = "coverage-7.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bf13d564d310c156d1c8e53877baf2993fb3073b2fc9f69790ca6a732eb4bfea"}, + {file = "coverage-7.8.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5761c70c017c1b0d21b0815a920ffb94a670c8d5d409d9b38857874c21f70d7"}, + {file = "coverage-7.8.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5ff52d790c7e1628241ffbcaeb33e07d14b007b6eb00a19320c7b8a7024c040"}, + {file = "coverage-7.8.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d39fc4817fd67b3915256af5dda75fd4ee10621a3d484524487e33416c6f3543"}, + {file = "coverage-7.8.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b44674870709017e4b4036e3d0d6c17f06a0e6d4436422e0ad29b882c40697d2"}, + {file = "coverage-7.8.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:8f99eb72bf27cbb167b636eb1726f590c00e1ad375002230607a844d9e9a2318"}, + {file = "coverage-7.8.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b571bf5341ba8c6bc02e0baeaf3b061ab993bf372d982ae509807e7f112554e9"}, + {file = "coverage-7.8.0-cp311-cp311-win32.whl", hash = "sha256:e75a2ad7b647fd8046d58c3132d7eaf31b12d8a53c0e4b21fa9c4d23d6ee6d3c"}, + {file = "coverage-7.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:3043ba1c88b2139126fc72cb48574b90e2e0546d4c78b5299317f61b7f718b78"}, + {file = "coverage-7.8.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:bbb5cc845a0292e0c520656d19d7ce40e18d0e19b22cb3e0409135a575bf79fc"}, + {file = "coverage-7.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4dfd9a93db9e78666d178d4f08a5408aa3f2474ad4d0e0378ed5f2ef71640cb6"}, + {file = "coverage-7.8.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f017a61399f13aa6d1039f75cd467be388d157cd81f1a119b9d9a68ba6f2830d"}, + {file = "coverage-7.8.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0915742f4c82208ebf47a2b154a5334155ed9ef9fe6190674b8a46c2fb89cb05"}, + {file = "coverage-7.8.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:8a40fcf208e021eb14b0fac6bdb045c0e0cab53105f93ba0d03fd934c956143a"}, + {file = "coverage-7.8.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a1f406a8e0995d654b2ad87c62caf6befa767885301f3b8f6f73e6f3c31ec3a6"}, + {file = "coverage-7.8.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:77af0f6447a582fdc7de5e06fa3757a3ef87769fbb0fdbdeba78c23049140a47"}, + {file = "coverage-7.8.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:f2d32f95922927186c6dbc8bc60df0d186b6edb828d299ab10898ef3f40052fe"}, + {file = "coverage-7.8.0-cp312-cp312-win32.whl", hash = "sha256:769773614e676f9d8e8a0980dd7740f09a6ea386d0f383db6821df07d0f08545"}, + {file = "coverage-7.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:e5d2b9be5b0693cf21eb4ce0ec8d211efb43966f6657807f6859aab3814f946b"}, + {file = "coverage-7.8.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5ac46d0c2dd5820ce93943a501ac5f6548ea81594777ca585bf002aa8854cacd"}, + {file = "coverage-7.8.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:771eb7587a0563ca5bb6f622b9ed7f9d07bd08900f7589b4febff05f469bea00"}, + {file = "coverage-7.8.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42421e04069fb2cbcbca5a696c4050b84a43b05392679d4068acbe65449b5c64"}, + {file = "coverage-7.8.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:554fec1199d93ab30adaa751db68acec2b41c5602ac944bb19187cb9a41a8067"}, + {file = "coverage-7.8.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5aaeb00761f985007b38cf463b1d160a14a22c34eb3f6a39d9ad6fc27cb73008"}, + {file = "coverage-7.8.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:581a40c7b94921fffd6457ffe532259813fc68eb2bdda60fa8cc343414ce3733"}, + {file = "coverage-7.8.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:f319bae0321bc838e205bf9e5bc28f0a3165f30c203b610f17ab5552cff90323"}, + {file = "coverage-7.8.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:04bfec25a8ef1c5f41f5e7e5c842f6b615599ca8ba8391ec33a9290d9d2db3a3"}, + {file = "coverage-7.8.0-cp313-cp313-win32.whl", hash = "sha256:dd19608788b50eed889e13a5d71d832edc34fc9dfce606f66e8f9f917eef910d"}, + {file = "coverage-7.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:a9abbccd778d98e9c7e85038e35e91e67f5b520776781d9a1e2ee9d400869487"}, + {file = "coverage-7.8.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:18c5ae6d061ad5b3e7eef4363fb27a0576012a7447af48be6c75b88494c6cf25"}, + {file = "coverage-7.8.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:95aa6ae391a22bbbce1b77ddac846c98c5473de0372ba5c463480043a07bff42"}, + {file = "coverage-7.8.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e013b07ba1c748dacc2a80e69a46286ff145935f260eb8c72df7185bf048f502"}, + {file = "coverage-7.8.0-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d766a4f0e5aa1ba056ec3496243150698dc0481902e2b8559314368717be82b1"}, + {file = "coverage-7.8.0-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad80e6b4a0c3cb6f10f29ae4c60e991f424e6b14219d46f1e7d442b938ee68a4"}, + {file = "coverage-7.8.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:b87eb6fc9e1bb8f98892a2458781348fa37e6925f35bb6ceb9d4afd54ba36c73"}, + {file = "coverage-7.8.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:d1ba00ae33be84066cfbe7361d4e04dec78445b2b88bdb734d0d1cbab916025a"}, + 
{file = "coverage-7.8.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f3c38e4e5ccbdc9198aecc766cedbb134b2d89bf64533973678dfcf07effd883"}, + {file = "coverage-7.8.0-cp313-cp313t-win32.whl", hash = "sha256:379fe315e206b14e21db5240f89dc0774bdd3e25c3c58c2c733c99eca96f1ada"}, + {file = "coverage-7.8.0-cp313-cp313t-win_amd64.whl", hash = "sha256:2e4b6b87bb0c846a9315e3ab4be2d52fac905100565f4b92f02c445c8799e257"}, + {file = "coverage-7.8.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fa260de59dfb143af06dcf30c2be0b200bed2a73737a8a59248fcb9fa601ef0f"}, + {file = "coverage-7.8.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:96121edfa4c2dfdda409877ea8608dd01de816a4dc4a0523356067b305e4e17a"}, + {file = "coverage-7.8.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b8af63b9afa1031c0ef05b217faa598f3069148eeee6bb24b79da9012423b82"}, + {file = "coverage-7.8.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:89b1f4af0d4afe495cd4787a68e00f30f1d15939f550e869de90a86efa7e0814"}, + {file = "coverage-7.8.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94ec0be97723ae72d63d3aa41961a0b9a6f5a53ff599813c324548d18e3b9e8c"}, + {file = "coverage-7.8.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:8a1d96e780bdb2d0cbb297325711701f7c0b6f89199a57f2049e90064c29f6bd"}, + {file = "coverage-7.8.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:f1d8a2a57b47142b10374902777e798784abf400a004b14f1b0b9eaf1e528ba4"}, + {file = "coverage-7.8.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:cf60dd2696b457b710dd40bf17ad269d5f5457b96442f7f85722bdb16fa6c899"}, + {file = "coverage-7.8.0-cp39-cp39-win32.whl", hash = "sha256:be945402e03de47ba1872cd5236395e0f4ad635526185a930735f66710e1bd3f"}, + {file = "coverage-7.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:90e7fbc6216ecaffa5a880cdc9c77b7418c1dcb166166b78dbc630d07f278cc3"}, + {file = "coverage-7.8.0-pp39.pp310.pp311-none-any.whl", hash = "sha256:b8194fb8e50d556d5849753de991d390c5a1edeeba50f68e3a9253fbd8bf8ccd"}, + {file = "coverage-7.8.0-py3-none-any.whl", hash = "sha256:dbf364b4c5e7bae9250528167dfe40219b62e2d573c854d74be213e1e52069f7"}, + {file = "coverage-7.8.0.tar.gz", hash = "sha256:7a3d62b3b03b4b6fd41a085f3574874cf946cb4604d2b4d3e8dca8cd570ca501"}, ] [package.dependencies] @@ -376,112 +383,150 @@ tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.1 toml = ["tomli"] [[package]] -name = "filelock" -version = "3.12.2" -description = "A platform independent file lock." 
+name = "exceptiongroup" +version = "1.2.2" +description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" +groups = ["dev"] +markers = "python_version < \"3.11\"" files = [ - {file = "filelock-3.12.2-py3-none-any.whl", hash = "sha256:cbb791cdea2a72f23da6ac5b5269ab0a0d161e9ef0100e653b69049a7706d1ec"}, - {file = "filelock-3.12.2.tar.gz", hash = "sha256:002740518d8aa59a26b0c76e10fb8c6e15eae825d34b6fdf670333fd7b938d81"}, + {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, + {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, ] [package.extras] -docs = ["furo (>=2023.5.20)", "sphinx (>=7.0.1)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "diff-cover (>=7.5)", "pytest (>=7.3.1)", "pytest-cov (>=4.1)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"] +test = ["pytest (>=6)"] + +[[package]] +name = "filelock" +version = "3.18.0" +description = "A platform independent file lock." +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "filelock-3.18.0-py3-none-any.whl", hash = "sha256:c401f4f8377c4464e6db25fff06205fd89bdd83b65eb0488ed1b160f780e21de"}, + {file = "filelock-3.18.0.tar.gz", hash = "sha256:adbc88eabb99d2fec8c9c1b229b171f18afa655400173ddc653d5d01501fb9f2"}, +] + +[package.extras] +docs = ["furo (>=2024.8.6)", "sphinx (>=8.1.3)", "sphinx-autodoc-typehints (>=3)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.6.10)", "diff-cover (>=9.2.1)", "pytest (>=8.3.4)", "pytest-asyncio (>=0.25.2)", "pytest-cov (>=6)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.28.1)"] +typing = ["typing-extensions (>=4.12.2)"] [[package]] name = "frozenlist" -version = "1.3.3" +version = "1.5.0" description = "A list-like structure which implements collections.abc.MutableSequence" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" +groups = ["main"] files = [ - {file = "frozenlist-1.3.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ff8bf625fe85e119553b5383ba0fb6aa3d0ec2ae980295aaefa552374926b3f4"}, - {file = "frozenlist-1.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dfbac4c2dfcc082fcf8d942d1e49b6aa0766c19d3358bd86e2000bf0fa4a9cf0"}, - {file = "frozenlist-1.3.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b1c63e8d377d039ac769cd0926558bb7068a1f7abb0f003e3717ee003ad85530"}, - {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7fdfc24dcfce5b48109867c13b4cb15e4660e7bd7661741a391f821f23dfdca7"}, - {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2c926450857408e42f0bbc295e84395722ce74bae69a3b2aa2a65fe22cb14b99"}, - {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1841e200fdafc3d51f974d9d377c079a0694a8f06de2e67b48150328d66d5483"}, - {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f470c92737afa7d4c3aacc001e335062d582053d4dbe73cda126f2d7031068dd"}, - {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:783263a4eaad7c49983fe4b2e7b53fa9770c136c270d2d4bbb6d2192bf4d9caf"}, - {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = 
"sha256:924620eef691990dfb56dc4709f280f40baee568c794b5c1885800c3ecc69816"}, - {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:ae4dc05c465a08a866b7a1baf360747078b362e6a6dbeb0c57f234db0ef88ae0"}, - {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:bed331fe18f58d844d39ceb398b77d6ac0b010d571cba8267c2e7165806b00ce"}, - {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:02c9ac843e3390826a265e331105efeab489ffaf4dd86384595ee8ce6d35ae7f"}, - {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9545a33965d0d377b0bc823dcabf26980e77f1b6a7caa368a365a9497fb09420"}, - {file = "frozenlist-1.3.3-cp310-cp310-win32.whl", hash = "sha256:d5cd3ab21acbdb414bb6c31958d7b06b85eeb40f66463c264a9b343a4e238642"}, - {file = "frozenlist-1.3.3-cp310-cp310-win_amd64.whl", hash = "sha256:b756072364347cb6aa5b60f9bc18e94b2f79632de3b0190253ad770c5df17db1"}, - {file = "frozenlist-1.3.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:b4395e2f8d83fbe0c627b2b696acce67868793d7d9750e90e39592b3626691b7"}, - {file = "frozenlist-1.3.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:14143ae966a6229350021384870458e4777d1eae4c28d1a7aa47f24d030e6678"}, - {file = "frozenlist-1.3.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5d8860749e813a6f65bad8285a0520607c9500caa23fea6ee407e63debcdbef6"}, - {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23d16d9f477bb55b6154654e0e74557040575d9d19fe78a161bd33d7d76808e8"}, - {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eb82dbba47a8318e75f679690190c10a5e1f447fbf9df41cbc4c3afd726d88cb"}, - {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9309869032abb23d196cb4e4db574232abe8b8be1339026f489eeb34a4acfd91"}, - {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a97b4fe50b5890d36300820abd305694cb865ddb7885049587a5678215782a6b"}, - {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c188512b43542b1e91cadc3c6c915a82a5eb95929134faf7fd109f14f9892ce4"}, - {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:303e04d422e9b911a09ad499b0368dc551e8c3cd15293c99160c7f1f07b59a48"}, - {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:0771aed7f596c7d73444c847a1c16288937ef988dc04fb9f7be4b2aa91db609d"}, - {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:66080ec69883597e4d026f2f71a231a1ee9887835902dbe6b6467d5a89216cf6"}, - {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:41fe21dc74ad3a779c3d73a2786bdf622ea81234bdd4faf90b8b03cad0c2c0b4"}, - {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f20380df709d91525e4bee04746ba612a4df0972c1b8f8e1e8af997e678c7b81"}, - {file = "frozenlist-1.3.3-cp311-cp311-win32.whl", hash = "sha256:f30f1928162e189091cf4d9da2eac617bfe78ef907a761614ff577ef4edfb3c8"}, - {file = "frozenlist-1.3.3-cp311-cp311-win_amd64.whl", hash = "sha256:a6394d7dadd3cfe3f4b3b186e54d5d8504d44f2d58dcc89d693698e8b7132b32"}, - {file = "frozenlist-1.3.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8df3de3a9ab8325f94f646609a66cbeeede263910c5c0de0101079ad541af332"}, - {file = 
"frozenlist-1.3.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0693c609e9742c66ba4870bcee1ad5ff35462d5ffec18710b4ac89337ff16e27"}, - {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cd4210baef299717db0a600d7a3cac81d46ef0e007f88c9335db79f8979c0d3d"}, - {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:394c9c242113bfb4b9aa36e2b80a05ffa163a30691c7b5a29eba82e937895d5e"}, - {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6327eb8e419f7d9c38f333cde41b9ae348bec26d840927332f17e887a8dcb70d"}, - {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e24900aa13212e75e5b366cb9065e78bbf3893d4baab6052d1aca10d46d944c"}, - {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:3843f84a6c465a36559161e6c59dce2f2ac10943040c2fd021cfb70d58c4ad56"}, - {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:84610c1502b2461255b4c9b7d5e9c48052601a8957cd0aea6ec7a7a1e1fb9420"}, - {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:c21b9aa40e08e4f63a2f92ff3748e6b6c84d717d033c7b3438dd3123ee18f70e"}, - {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:efce6ae830831ab6a22b9b4091d411698145cb9b8fc869e1397ccf4b4b6455cb"}, - {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:40de71985e9042ca00b7953c4f41eabc3dc514a2d1ff534027f091bc74416401"}, - {file = "frozenlist-1.3.3-cp37-cp37m-win32.whl", hash = "sha256:180c00c66bde6146a860cbb81b54ee0df350d2daf13ca85b275123bbf85de18a"}, - {file = "frozenlist-1.3.3-cp37-cp37m-win_amd64.whl", hash = "sha256:9bbbcedd75acdfecf2159663b87f1bb5cfc80e7cd99f7ddd9d66eb98b14a8411"}, - {file = "frozenlist-1.3.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:034a5c08d36649591be1cbb10e09da9f531034acfe29275fc5454a3b101ce41a"}, - {file = "frozenlist-1.3.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ba64dc2b3b7b158c6660d49cdb1d872d1d0bf4e42043ad8d5006099479a194e5"}, - {file = "frozenlist-1.3.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:47df36a9fe24054b950bbc2db630d508cca3aa27ed0566c0baf661225e52c18e"}, - {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:008a054b75d77c995ea26629ab3a0c0d7281341f2fa7e1e85fa6153ae29ae99c"}, - {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:841ea19b43d438a80b4de62ac6ab21cfe6827bb8a9dc62b896acc88eaf9cecba"}, - {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e235688f42b36be2b6b06fc37ac2126a73b75fb8d6bc66dd632aa35286238703"}, - {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca713d4af15bae6e5d79b15c10c8522859a9a89d3b361a50b817c98c2fb402a2"}, - {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ac5995f2b408017b0be26d4a1d7c61bce106ff3d9e3324374d66b5964325448"}, - {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:a4ae8135b11652b08a8baf07631d3ebfe65a4c87909dbef5fa0cdde440444ee4"}, - {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_i686.whl", hash = 
"sha256:4ea42116ceb6bb16dbb7d526e242cb6747b08b7710d9782aa3d6732bd8d27649"}, - {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:810860bb4bdce7557bc0febb84bbd88198b9dbc2022d8eebe5b3590b2ad6c842"}, - {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:ee78feb9d293c323b59a6f2dd441b63339a30edf35abcb51187d2fc26e696d13"}, - {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:0af2e7c87d35b38732e810befb9d797a99279cbb85374d42ea61c1e9d23094b3"}, - {file = "frozenlist-1.3.3-cp38-cp38-win32.whl", hash = "sha256:899c5e1928eec13fd6f6d8dc51be23f0d09c5281e40d9cf4273d188d9feeaf9b"}, - {file = "frozenlist-1.3.3-cp38-cp38-win_amd64.whl", hash = "sha256:7f44e24fa70f6fbc74aeec3e971f60a14dde85da364aa87f15d1be94ae75aeef"}, - {file = "frozenlist-1.3.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:2b07ae0c1edaa0a36339ec6cce700f51b14a3fc6545fdd32930d2c83917332cf"}, - {file = "frozenlist-1.3.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ebb86518203e12e96af765ee89034a1dbb0c3c65052d1b0c19bbbd6af8a145e1"}, - {file = "frozenlist-1.3.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5cf820485f1b4c91e0417ea0afd41ce5cf5965011b3c22c400f6d144296ccbc0"}, - {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c11e43016b9024240212d2a65043b70ed8dfd3b52678a1271972702d990ac6d"}, - {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8fa3c6e3305aa1146b59a09b32b2e04074945ffcfb2f0931836d103a2c38f936"}, - {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:352bd4c8c72d508778cf05ab491f6ef36149f4d0cb3c56b1b4302852255d05d5"}, - {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:65a5e4d3aa679610ac6e3569e865425b23b372277f89b5ef06cf2cdaf1ebf22b"}, - {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1e2c1185858d7e10ff045c496bbf90ae752c28b365fef2c09cf0fa309291669"}, - {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f163d2fd041c630fed01bc48d28c3ed4a3b003c00acd396900e11ee5316b56bb"}, - {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:05cdb16d09a0832eedf770cb7bd1fe57d8cf4eaf5aced29c4e41e3f20b30a784"}, - {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:8bae29d60768bfa8fb92244b74502b18fae55a80eac13c88eb0b496d4268fd2d"}, - {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:eedab4c310c0299961ac285591acd53dc6723a1ebd90a57207c71f6e0c2153ab"}, - {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3bbdf44855ed8f0fbcd102ef05ec3012d6a4fd7c7562403f76ce6a52aeffb2b1"}, - {file = "frozenlist-1.3.3-cp39-cp39-win32.whl", hash = "sha256:efa568b885bca461f7c7b9e032655c0c143d305bf01c30caf6db2854a4532b38"}, - {file = "frozenlist-1.3.3-cp39-cp39-win_amd64.whl", hash = "sha256:cfe33efc9cb900a4c46f91a5ceba26d6df370ffddd9ca386eb1d4f0ad97b9ea9"}, - {file = "frozenlist-1.3.3.tar.gz", hash = "sha256:58bcc55721e8a90b88332d6cd441261ebb22342e238296bb330968952fbb3a6a"}, + {file = "frozenlist-1.5.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:5b6a66c18b5b9dd261ca98dffcb826a525334b2f29e7caa54e182255c5f6a65a"}, + {file = "frozenlist-1.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = 
"sha256:d1b3eb7b05ea246510b43a7e53ed1653e55c2121019a97e60cad7efb881a97bb"}, + {file = "frozenlist-1.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:15538c0cbf0e4fa11d1e3a71f823524b0c46299aed6e10ebb4c2089abd8c3bec"}, + {file = "frozenlist-1.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e79225373c317ff1e35f210dd5f1344ff31066ba8067c307ab60254cd3a78ad5"}, + {file = "frozenlist-1.5.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9272fa73ca71266702c4c3e2d4a28553ea03418e591e377a03b8e3659d94fa76"}, + {file = "frozenlist-1.5.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:498524025a5b8ba81695761d78c8dd7382ac0b052f34e66939c42df860b8ff17"}, + {file = "frozenlist-1.5.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:92b5278ed9d50fe610185ecd23c55d8b307d75ca18e94c0e7de328089ac5dcba"}, + {file = "frozenlist-1.5.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f3c8c1dacd037df16e85227bac13cca58c30da836c6f936ba1df0c05d046d8d"}, + {file = "frozenlist-1.5.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f2ac49a9bedb996086057b75bf93538240538c6d9b38e57c82d51f75a73409d2"}, + {file = "frozenlist-1.5.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e66cc454f97053b79c2ab09c17fbe3c825ea6b4de20baf1be28919460dd7877f"}, + {file = "frozenlist-1.5.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:5a3ba5f9a0dfed20337d3e966dc359784c9f96503674c2faf015f7fe8e96798c"}, + {file = "frozenlist-1.5.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:6321899477db90bdeb9299ac3627a6a53c7399c8cd58d25da094007402b039ab"}, + {file = "frozenlist-1.5.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:76e4753701248476e6286f2ef492af900ea67d9706a0155335a40ea21bf3b2f5"}, + {file = "frozenlist-1.5.0-cp310-cp310-win32.whl", hash = "sha256:977701c081c0241d0955c9586ffdd9ce44f7a7795df39b9151cd9a6fd0ce4cfb"}, + {file = "frozenlist-1.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:189f03b53e64144f90990d29a27ec4f7997d91ed3d01b51fa39d2dbe77540fd4"}, + {file = "frozenlist-1.5.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:fd74520371c3c4175142d02a976aee0b4cb4a7cc912a60586ffd8d5929979b30"}, + {file = "frozenlist-1.5.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2f3f7a0fbc219fb4455264cae4d9f01ad41ae6ee8524500f381de64ffaa077d5"}, + {file = "frozenlist-1.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f47c9c9028f55a04ac254346e92977bf0f166c483c74b4232bee19a6697e4778"}, + {file = "frozenlist-1.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0996c66760924da6e88922756d99b47512a71cfd45215f3570bf1e0b694c206a"}, + {file = "frozenlist-1.5.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a2fe128eb4edeabe11896cb6af88fca5346059f6c8d807e3b910069f39157869"}, + {file = "frozenlist-1.5.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1a8ea951bbb6cacd492e3948b8da8c502a3f814f5d20935aae74b5df2b19cf3d"}, + {file = "frozenlist-1.5.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de537c11e4aa01d37db0d403b57bd6f0546e71a82347a97c6a9f0dcc532b3a45"}, + {file = "frozenlist-1.5.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:9c2623347b933fcb9095841f1cc5d4ff0b278addd743e0e966cb3d460278840d"}, + {file = "frozenlist-1.5.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:cee6798eaf8b1416ef6909b06f7dc04b60755206bddc599f52232606e18179d3"}, + {file = "frozenlist-1.5.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:f5f9da7f5dbc00a604fe74aa02ae7c98bcede8a3b8b9666f9f86fc13993bc71a"}, + {file = "frozenlist-1.5.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:90646abbc7a5d5c7c19461d2e3eeb76eb0b204919e6ece342feb6032c9325ae9"}, + {file = "frozenlist-1.5.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:bdac3c7d9b705d253b2ce370fde941836a5f8b3c5c2b8fd70940a3ea3af7f4f2"}, + {file = "frozenlist-1.5.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:03d33c2ddbc1816237a67f66336616416e2bbb6beb306e5f890f2eb22b959cdf"}, + {file = "frozenlist-1.5.0-cp311-cp311-win32.whl", hash = "sha256:237f6b23ee0f44066219dae14c70ae38a63f0440ce6750f868ee08775073f942"}, + {file = "frozenlist-1.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:0cc974cc93d32c42e7b0f6cf242a6bd941c57c61b618e78b6c0a96cb72788c1d"}, + {file = "frozenlist-1.5.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:31115ba75889723431aa9a4e77d5f398f5cf976eea3bdf61749731f62d4a4a21"}, + {file = "frozenlist-1.5.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7437601c4d89d070eac8323f121fcf25f88674627505334654fd027b091db09d"}, + {file = "frozenlist-1.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7948140d9f8ece1745be806f2bfdf390127cf1a763b925c4a805c603df5e697e"}, + {file = "frozenlist-1.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:feeb64bc9bcc6b45c6311c9e9b99406660a9c05ca8a5b30d14a78555088b0b3a"}, + {file = "frozenlist-1.5.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:683173d371daad49cffb8309779e886e59c2f369430ad28fe715f66d08d4ab1a"}, + {file = "frozenlist-1.5.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7d57d8f702221405a9d9b40f9da8ac2e4a1a8b5285aac6100f3393675f0a85ee"}, + {file = "frozenlist-1.5.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:30c72000fbcc35b129cb09956836c7d7abf78ab5416595e4857d1cae8d6251a6"}, + {file = "frozenlist-1.5.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:000a77d6034fbad9b6bb880f7ec073027908f1b40254b5d6f26210d2dab1240e"}, + {file = "frozenlist-1.5.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5d7f5a50342475962eb18b740f3beecc685a15b52c91f7d975257e13e029eca9"}, + {file = "frozenlist-1.5.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:87f724d055eb4785d9be84e9ebf0f24e392ddfad00b3fe036e43f489fafc9039"}, + {file = "frozenlist-1.5.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:6e9080bb2fb195a046e5177f10d9d82b8a204c0736a97a153c2466127de87784"}, + {file = "frozenlist-1.5.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:9b93d7aaa36c966fa42efcaf716e6b3900438632a626fb09c049f6a2f09fc631"}, + {file = "frozenlist-1.5.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:52ef692a4bc60a6dd57f507429636c2af8b6046db8b31b18dac02cbc8f507f7f"}, + {file = "frozenlist-1.5.0-cp312-cp312-win32.whl", hash = "sha256:29d94c256679247b33a3dc96cce0f93cbc69c23bf75ff715919332fdbb6a32b8"}, + {file = "frozenlist-1.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:8969190d709e7c48ea386db202d708eb94bdb29207a1f269bab1196ce0dcca1f"}, + {file = 
"frozenlist-1.5.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:7a1a048f9215c90973402e26c01d1cff8a209e1f1b53f72b95c13db61b00f953"}, + {file = "frozenlist-1.5.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:dd47a5181ce5fcb463b5d9e17ecfdb02b678cca31280639255ce9d0e5aa67af0"}, + {file = "frozenlist-1.5.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1431d60b36d15cda188ea222033eec8e0eab488f39a272461f2e6d9e1a8e63c2"}, + {file = "frozenlist-1.5.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6482a5851f5d72767fbd0e507e80737f9c8646ae7fd303def99bfe813f76cf7f"}, + {file = "frozenlist-1.5.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:44c49271a937625619e862baacbd037a7ef86dd1ee215afc298a417ff3270608"}, + {file = "frozenlist-1.5.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:12f78f98c2f1c2429d42e6a485f433722b0061d5c0b0139efa64f396efb5886b"}, + {file = "frozenlist-1.5.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce3aa154c452d2467487765e3adc730a8c153af77ad84096bc19ce19a2400840"}, + {file = "frozenlist-1.5.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9b7dc0c4338e6b8b091e8faf0db3168a37101943e687f373dce00959583f7439"}, + {file = "frozenlist-1.5.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:45e0896250900b5aa25180f9aec243e84e92ac84bd4a74d9ad4138ef3f5c97de"}, + {file = "frozenlist-1.5.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:561eb1c9579d495fddb6da8959fd2a1fca2c6d060d4113f5844b433fc02f2641"}, + {file = "frozenlist-1.5.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:df6e2f325bfee1f49f81aaac97d2aa757c7646534a06f8f577ce184afe2f0a9e"}, + {file = "frozenlist-1.5.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:140228863501b44b809fb39ec56b5d4071f4d0aa6d216c19cbb08b8c5a7eadb9"}, + {file = "frozenlist-1.5.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:7707a25d6a77f5d27ea7dc7d1fc608aa0a478193823f88511ef5e6b8a48f9d03"}, + {file = "frozenlist-1.5.0-cp313-cp313-win32.whl", hash = "sha256:31a9ac2b38ab9b5a8933b693db4939764ad3f299fcaa931a3e605bc3460e693c"}, + {file = "frozenlist-1.5.0-cp313-cp313-win_amd64.whl", hash = "sha256:11aabdd62b8b9c4b84081a3c246506d1cddd2dd93ff0ad53ede5defec7886b28"}, + {file = "frozenlist-1.5.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:dd94994fc91a6177bfaafd7d9fd951bc8689b0a98168aa26b5f543868548d3ca"}, + {file = "frozenlist-1.5.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2d0da8bbec082bf6bf18345b180958775363588678f64998c2b7609e34719b10"}, + {file = "frozenlist-1.5.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:73f2e31ea8dd7df61a359b731716018c2be196e5bb3b74ddba107f694fbd7604"}, + {file = "frozenlist-1.5.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:828afae9f17e6de596825cf4228ff28fbdf6065974e5ac1410cecc22f699d2b3"}, + {file = "frozenlist-1.5.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f1577515d35ed5649d52ab4319db757bb881ce3b2b796d7283e6634d99ace307"}, + {file = "frozenlist-1.5.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2150cc6305a2c2ab33299453e2968611dacb970d2283a14955923062c8d00b10"}, + {file = "frozenlist-1.5.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a72b7a6e3cd2725eff67cd64c8f13335ee18fc3c7befc05aed043d24c7b9ccb9"}, + 
{file = "frozenlist-1.5.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c16d2fa63e0800723139137d667e1056bee1a1cf7965153d2d104b62855e9b99"}, + {file = "frozenlist-1.5.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:17dcc32fc7bda7ce5875435003220a457bcfa34ab7924a49a1c19f55b6ee185c"}, + {file = "frozenlist-1.5.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:97160e245ea33d8609cd2b8fd997c850b56db147a304a262abc2b3be021a9171"}, + {file = "frozenlist-1.5.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:f1e6540b7fa044eee0bb5111ada694cf3dc15f2b0347ca125ee9ca984d5e9e6e"}, + {file = "frozenlist-1.5.0-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:91d6c171862df0a6c61479d9724f22efb6109111017c87567cfeb7b5d1449fdf"}, + {file = "frozenlist-1.5.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:c1fac3e2ace2eb1052e9f7c7db480818371134410e1f5c55d65e8f3ac6d1407e"}, + {file = "frozenlist-1.5.0-cp38-cp38-win32.whl", hash = "sha256:b97f7b575ab4a8af9b7bc1d2ef7f29d3afee2226bd03ca3875c16451ad5a7723"}, + {file = "frozenlist-1.5.0-cp38-cp38-win_amd64.whl", hash = "sha256:374ca2dabdccad8e2a76d40b1d037f5bd16824933bf7bcea3e59c891fd4a0923"}, + {file = "frozenlist-1.5.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:9bbcdfaf4af7ce002694a4e10a0159d5a8d20056a12b05b45cea944a4953f972"}, + {file = "frozenlist-1.5.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1893f948bf6681733aaccf36c5232c231e3b5166d607c5fa77773611df6dc336"}, + {file = "frozenlist-1.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2b5e23253bb709ef57a8e95e6ae48daa9ac5f265637529e4ce6b003a37b2621f"}, + {file = "frozenlist-1.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f253985bb515ecd89629db13cb58d702035ecd8cfbca7d7a7e29a0e6d39af5f"}, + {file = "frozenlist-1.5.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:04a5c6babd5e8fb7d3c871dc8b321166b80e41b637c31a995ed844a6139942b6"}, + {file = "frozenlist-1.5.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a9fe0f1c29ba24ba6ff6abf688cb0b7cf1efab6b6aa6adc55441773c252f7411"}, + {file = "frozenlist-1.5.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:226d72559fa19babe2ccd920273e767c96a49b9d3d38badd7c91a0fdeda8ea08"}, + {file = "frozenlist-1.5.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15b731db116ab3aedec558573c1a5eec78822b32292fe4f2f0345b7f697745c2"}, + {file = "frozenlist-1.5.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:366d8f93e3edfe5a918c874702f78faac300209a4d5bf38352b2c1bdc07a766d"}, + {file = "frozenlist-1.5.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:1b96af8c582b94d381a1c1f51ffaedeb77c821c690ea5f01da3d70a487dd0a9b"}, + {file = "frozenlist-1.5.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:c03eff4a41bd4e38415cbed054bbaff4a075b093e2394b6915dca34a40d1e38b"}, + {file = "frozenlist-1.5.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:50cf5e7ee9b98f22bdecbabf3800ae78ddcc26e4a435515fc72d97903e8488e0"}, + {file = "frozenlist-1.5.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1e76bfbc72353269c44e0bc2cfe171900fbf7f722ad74c9a7b638052afe6a00c"}, + {file = "frozenlist-1.5.0-cp39-cp39-win32.whl", hash = "sha256:666534d15ba8f0fda3f53969117383d5dc021266b3c1a42c9ec4855e4b58b9d3"}, + {file = "frozenlist-1.5.0-cp39-cp39-win_amd64.whl", hash = 
"sha256:5c28f4b5dbef8a0d8aad0d4de24d1e9e981728628afaf4ea0792f5d0939372f0"}, + {file = "frozenlist-1.5.0-py3-none-any.whl", hash = "sha256:d994863bba198a4a518b467bb971c56e1db3f180a25c6cf7bb1949c267f748c3"}, + {file = "frozenlist-1.5.0.tar.gz", hash = "sha256:81d5af29e61b9c8348e876d442253723928dce6433e0e76cd925cd83f1b4b817"}, ] [[package]] name = "fsspec" -version = "2023.1.0" +version = "2025.3.2" description = "File-system specification" optional = false -python-versions = ">=3.7" +python-versions = ">=3.9" +groups = ["main"] files = [ - {file = "fsspec-2023.1.0-py3-none-any.whl", hash = "sha256:b833e2e541e9e8cde0ab549414187871243177feb3d344f9d27b25a93f5d8139"}, - {file = "fsspec-2023.1.0.tar.gz", hash = "sha256:fbae7f20ff801eb5f7d0bedf81f25c787c0dfac5e982d98fa3884a9cde2b5411"}, + {file = "fsspec-2025.3.2-py3-none-any.whl", hash = "sha256:2daf8dc3d1dfa65b6aa37748d112773a7a08416f6c70d96b264c96476ecaf711"}, + {file = "fsspec-2025.3.2.tar.gz", hash = "sha256:e52c77ef398680bbd6a98c0e628fbc469491282981209907bbc8aea76a04fdc6"}, ] [package.extras] @@ -489,8 +534,10 @@ abfs = ["adlfs"] adl = ["adlfs"] arrow = ["pyarrow (>=1)"] dask = ["dask", "distributed"] +dev = ["pre-commit", "ruff"] +doc = ["numpydoc", "sphinx", "sphinx-design", "sphinx-rtd-theme", "yarl"] dropbox = ["dropbox", "dropboxdrivefs", "requests"] -entrypoints = ["importlib-metadata"] +full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"] fuse = ["fusepy"] gcs = ["gcsfs"] git = ["pygit2"] @@ -498,30 +545,33 @@ github = ["requests"] gs = ["gcsfs"] gui = ["panel"] hdfs = ["pyarrow (>=1)"] -http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "requests"] +http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)"] libarchive = ["libarchive-c"] oci = ["ocifs"] s3 = ["s3fs"] sftp = ["paramiko"] smb = ["smbprotocol"] ssh = ["paramiko"] +test = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "numpy", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "requests"] +test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask[dataframe,test]", "moto[server] (>4,<5)", "pytest-timeout", "xarray"] +test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"] tqdm = ["tqdm"] [[package]] name = "huggingface-hub" -version = "0.16.4" +version = "0.30.2" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" optional = false -python-versions = ">=3.7.0" +python-versions = ">=3.8.0" +groups = ["main"] files = [ - {file = "huggingface_hub-0.16.4-py3-none-any.whl", hash = "sha256:0d3df29932f334fead024afc7cb4cc5149d955238b8b5e42dcf9740d6995a349"}, - {file = "huggingface_hub-0.16.4.tar.gz", hash = "sha256:608c7d4f3d368b326d1747f91523dbd1f692871e8e2e7a4750314a2dd8b63e14"}, + {file = "huggingface_hub-0.30.2-py3-none-any.whl", hash = "sha256:68ff05969927058cfa41df4f2155d4bb48f5f54f719dd0390103eefa9b191e28"}, + {file = 
"huggingface_hub-0.30.2.tar.gz", hash = "sha256:9a7897c5b6fd9dad3168a794a8998d6378210f5b9688d0dfc180b1a228dc2466"}, ] [package.dependencies] filelock = "*" -fsspec = "*" -importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} +fsspec = ">=2023.5.0" packaging = ">=20.9" pyyaml = ">=5.1" requests = "*" @@ -529,314 +579,429 @@ tqdm = ">=4.42.1" typing-extensions = ">=3.7.4.3" [package.extras] -all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "numpy", "pydantic", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)"] +all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] cli = ["InquirerPy (==0.3.4)"] -dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "numpy", "pydantic", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)"] +dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] -inference = ["aiohttp", "pydantic"] -quality = ["black (>=23.1,<24.0)", "mypy (==0.982)", "ruff (>=0.0.241)"] +hf-transfer = ["hf-transfer (>=0.1.4)"] +hf-xet = ["hf-xet (>=0.1.4)"] +inference = ["aiohttp"] +quality = ["libcst (==1.4.0)", "mypy (==1.5.1)", "ruff (>=0.9.0)"] tensorflow = ["graphviz", "pydot", "tensorflow"] -testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] -torch = ["torch"] -typing = ["pydantic", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"] +tensorflow-testing = ["keras (<3.0)", "tensorflow"] +testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] +torch = ["safetensors[torch]", "torch"] +typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] [[package]] name = "idna" -version = "3.4" +version = "3.10" description = "Internationalized Domain Names in Applications (IDNA)" optional 
= false -python-versions = ">=3.5" +python-versions = ">=3.6" +groups = ["main"] files = [ - {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"}, - {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"}, + {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, + {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, ] -[[package]] -name = "importlib-metadata" -version = "6.7.0" -description = "Read metadata from Python packages" -optional = false -python-versions = ">=3.7" -files = [ - {file = "importlib_metadata-6.7.0-py3-none-any.whl", hash = "sha256:cb52082e659e97afc5dac71e79de97d8681de3aa07ff18578330904a9d18e5b5"}, - {file = "importlib_metadata-6.7.0.tar.gz", hash = "sha256:1aaf550d4f73e5d6783e7acb77aec43d49da8017410afae93822cc9cca98c4d4"}, -] - -[package.dependencies] -typing-extensions = {version = ">=3.6.4", markers = "python_version < \"3.8\""} -zipp = ">=0.5" - [package.extras] -docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -perf = ["ipython"] -testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] +all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"] [[package]] name = "iniconfig" -version = "2.0.0" +version = "2.1.0" description = "brain-dead simple config-ini parsing" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" +groups = ["dev"] files = [ - {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, - {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, + {file = "iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760"}, + {file = "iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7"}, ] [[package]] name = "multidict" -version = "6.0.4" +version = "6.4.3" description = "multidict implementation" optional = false -python-versions = ">=3.7" +python-versions = ">=3.9" +groups = ["main"] files = [ - {file = "multidict-6.0.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b1a97283e0c85772d613878028fec909f003993e1007eafa715b24b377cb9b8"}, - {file = "multidict-6.0.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:eeb6dcc05e911516ae3d1f207d4b0520d07f54484c49dfc294d6e7d63b734171"}, - {file = "multidict-6.0.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d6d635d5209b82a3492508cf5b365f3446afb65ae7ebd755e70e18f287b0adf7"}, - {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c048099e4c9e9d615545e2001d3d8a4380bd403e1a0578734e0d31703d1b0c0b"}, - {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ea20853c6dbbb53ed34cb4d080382169b6f4554d394015f1bef35e881bf83547"}, - {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:16d232d4e5396c2efbbf4f6d4df89bfa905eb0d4dc5b3549d872ab898451f569"}, - {file = 
"multidict-6.0.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36c63aaa167f6c6b04ef2c85704e93af16c11d20de1d133e39de6a0e84582a93"}, - {file = "multidict-6.0.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:64bdf1086b6043bf519869678f5f2757f473dee970d7abf6da91ec00acb9cb98"}, - {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:43644e38f42e3af682690876cff722d301ac585c5b9e1eacc013b7a3f7b696a0"}, - {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:7582a1d1030e15422262de9f58711774e02fa80df0d1578995c76214f6954988"}, - {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:ddff9c4e225a63a5afab9dd15590432c22e8057e1a9a13d28ed128ecf047bbdc"}, - {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:ee2a1ece51b9b9e7752e742cfb661d2a29e7bcdba2d27e66e28a99f1890e4fa0"}, - {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a2e4369eb3d47d2034032a26c7a80fcb21a2cb22e1173d761a162f11e562caa5"}, - {file = "multidict-6.0.4-cp310-cp310-win32.whl", hash = "sha256:574b7eae1ab267e5f8285f0fe881f17efe4b98c39a40858247720935b893bba8"}, - {file = "multidict-6.0.4-cp310-cp310-win_amd64.whl", hash = "sha256:4dcbb0906e38440fa3e325df2359ac6cb043df8e58c965bb45f4e406ecb162cc"}, - {file = "multidict-6.0.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:0dfad7a5a1e39c53ed00d2dd0c2e36aed4650936dc18fd9a1826a5ae1cad6f03"}, - {file = "multidict-6.0.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:64da238a09d6039e3bd39bb3aee9c21a5e34f28bfa5aa22518581f910ff94af3"}, - {file = "multidict-6.0.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ff959bee35038c4624250473988b24f846cbeb2c6639de3602c073f10410ceba"}, - {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:01a3a55bd90018c9c080fbb0b9f4891db37d148a0a18722b42f94694f8b6d4c9"}, - {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c5cb09abb18c1ea940fb99360ea0396f34d46566f157122c92dfa069d3e0e982"}, - {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:666daae833559deb2d609afa4490b85830ab0dfca811a98b70a205621a6109fe"}, - {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11bdf3f5e1518b24530b8241529d2050014c884cf18b6fc69c0c2b30ca248710"}, - {file = "multidict-6.0.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d18748f2d30f94f498e852c67d61261c643b349b9d2a581131725595c45ec6c"}, - {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:458f37be2d9e4c95e2d8866a851663cbc76e865b78395090786f6cd9b3bbf4f4"}, - {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:b1a2eeedcead3a41694130495593a559a668f382eee0727352b9a41e1c45759a"}, - {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:7d6ae9d593ef8641544d6263c7fa6408cc90370c8cb2bbb65f8d43e5b0351d9c"}, - {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:5979b5632c3e3534e42ca6ff856bb24b2e3071b37861c2c727ce220d80eee9ed"}, - {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:dcfe792765fab89c365123c81046ad4103fcabbc4f56d1c1997e6715e8015461"}, - {file = "multidict-6.0.4-cp311-cp311-win32.whl", hash = 
"sha256:3601a3cece3819534b11d4efc1eb76047488fddd0c85a3948099d5da4d504636"}, - {file = "multidict-6.0.4-cp311-cp311-win_amd64.whl", hash = "sha256:81a4f0b34bd92df3da93315c6a59034df95866014ac08535fc819f043bfd51f0"}, - {file = "multidict-6.0.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:67040058f37a2a51ed8ea8f6b0e6ee5bd78ca67f169ce6122f3e2ec80dfe9b78"}, - {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:853888594621e6604c978ce2a0444a1e6e70c8d253ab65ba11657659dcc9100f"}, - {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:39ff62e7d0f26c248b15e364517a72932a611a9b75f35b45be078d81bdb86603"}, - {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:af048912e045a2dc732847d33821a9d84ba553f5c5f028adbd364dd4765092ac"}, - {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1e8b901e607795ec06c9e42530788c45ac21ef3aaa11dbd0c69de543bfb79a9"}, - {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62501642008a8b9871ddfccbf83e4222cf8ac0d5aeedf73da36153ef2ec222d2"}, - {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:99b76c052e9f1bc0721f7541e5e8c05db3941eb9ebe7b8553c625ef88d6eefde"}, - {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:509eac6cf09c794aa27bcacfd4d62c885cce62bef7b2c3e8b2e49d365b5003fe"}, - {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:21a12c4eb6ddc9952c415f24eef97e3e55ba3af61f67c7bc388dcdec1404a067"}, - {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:5cad9430ab3e2e4fa4a2ef4450f548768400a2ac635841bc2a56a2052cdbeb87"}, - {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ab55edc2e84460694295f401215f4a58597f8f7c9466faec545093045476327d"}, - {file = "multidict-6.0.4-cp37-cp37m-win32.whl", hash = "sha256:5a4dcf02b908c3b8b17a45fb0f15b695bf117a67b76b7ad18b73cf8e92608775"}, - {file = "multidict-6.0.4-cp37-cp37m-win_amd64.whl", hash = "sha256:6ed5f161328b7df384d71b07317f4d8656434e34591f20552c7bcef27b0ab88e"}, - {file = "multidict-6.0.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5fc1b16f586f049820c5c5b17bb4ee7583092fa0d1c4e28b5239181ff9532e0c"}, - {file = "multidict-6.0.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1502e24330eb681bdaa3eb70d6358e818e8e8f908a22a1851dfd4e15bc2f8161"}, - {file = "multidict-6.0.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b692f419760c0e65d060959df05f2a531945af31fda0c8a3b3195d4efd06de11"}, - {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45e1ecb0379bfaab5eef059f50115b54571acfbe422a14f668fc8c27ba410e7e"}, - {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ddd3915998d93fbcd2566ddf9cf62cdb35c9e093075f862935573d265cf8f65d"}, - {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:59d43b61c59d82f2effb39a93c48b845efe23a3852d201ed2d24ba830d0b4cf2"}, - {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc8e1d0c705233c5dd0c5e6460fbad7827d5d36f310a0fadfd45cc3029762258"}, - {file = "multidict-6.0.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:d6aa0418fcc838522256761b3415822626f866758ee0bc6632c9486b179d0b52"}, - {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6748717bb10339c4760c1e63da040f5f29f5ed6e59d76daee30305894069a660"}, - {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:4d1a3d7ef5e96b1c9e92f973e43aa5e5b96c659c9bc3124acbbd81b0b9c8a951"}, - {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:4372381634485bec7e46718edc71528024fcdc6f835baefe517b34a33c731d60"}, - {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:fc35cb4676846ef752816d5be2193a1e8367b4c1397b74a565a9d0389c433a1d"}, - {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:4b9d9e4e2b37daddb5c23ea33a3417901fa7c7b3dee2d855f63ee67a0b21e5b1"}, - {file = "multidict-6.0.4-cp38-cp38-win32.whl", hash = "sha256:e41b7e2b59679edfa309e8db64fdf22399eec4b0b24694e1b2104fb789207779"}, - {file = "multidict-6.0.4-cp38-cp38-win_amd64.whl", hash = "sha256:d6c254ba6e45d8e72739281ebc46ea5eb5f101234f3ce171f0e9f5cc86991480"}, - {file = "multidict-6.0.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:16ab77bbeb596e14212e7bab8429f24c1579234a3a462105cda4a66904998664"}, - {file = "multidict-6.0.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bc779e9e6f7fda81b3f9aa58e3a6091d49ad528b11ed19f6621408806204ad35"}, - {file = "multidict-6.0.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4ceef517eca3e03c1cceb22030a3e39cb399ac86bff4e426d4fc6ae49052cc60"}, - {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:281af09f488903fde97923c7744bb001a9b23b039a909460d0f14edc7bf59706"}, - {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:52f2dffc8acaba9a2f27174c41c9e57f60b907bb9f096b36b1a1f3be71c6284d"}, - {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b41156839806aecb3641f3208c0dafd3ac7775b9c4c422d82ee2a45c34ba81ca"}, - {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5e3fc56f88cc98ef8139255cf8cd63eb2c586531e43310ff859d6bb3a6b51f1"}, - {file = "multidict-6.0.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8316a77808c501004802f9beebde51c9f857054a0c871bd6da8280e718444449"}, - {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f70b98cd94886b49d91170ef23ec5c0e8ebb6f242d734ed7ed677b24d50c82cf"}, - {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:bf6774e60d67a9efe02b3616fee22441d86fab4c6d335f9d2051d19d90a40063"}, - {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:e69924bfcdda39b722ef4d9aa762b2dd38e4632b3641b1d9a57ca9cd18f2f83a"}, - {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:6b181d8c23da913d4ff585afd1155a0e1194c0b50c54fcfe286f70cdaf2b7176"}, - {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:52509b5be062d9eafc8170e53026fbc54cf3b32759a23d07fd935fb04fc22d95"}, - {file = "multidict-6.0.4-cp39-cp39-win32.whl", hash = "sha256:27c523fbfbdfd19c6867af7346332b62b586eed663887392cff78d614f9ec313"}, - {file = "multidict-6.0.4-cp39-cp39-win_amd64.whl", hash = "sha256:33029f5734336aa0d4c0384525da0387ef89148dc7191aae00ca5fb23d7aafc2"}, - {file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"}, + {file = 
"multidict-6.4.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:32a998bd8a64ca48616eac5a8c1cc4fa38fb244a3facf2eeb14abe186e0f6cc5"}, + {file = "multidict-6.4.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a54ec568f1fc7f3c313c2f3b16e5db346bf3660e1309746e7fccbbfded856188"}, + {file = "multidict-6.4.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a7be07e5df178430621c716a63151165684d3e9958f2bbfcb644246162007ab7"}, + {file = "multidict-6.4.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b128dbf1c939674a50dd0b28f12c244d90e5015e751a4f339a96c54f7275e291"}, + {file = "multidict-6.4.3-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:b9cb19dfd83d35b6ff24a4022376ea6e45a2beba8ef3f0836b8a4b288b6ad685"}, + {file = "multidict-6.4.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3cf62f8e447ea2c1395afa289b332e49e13d07435369b6f4e41f887db65b40bf"}, + {file = "multidict-6.4.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:909f7d43ff8f13d1adccb6a397094adc369d4da794407f8dd592c51cf0eae4b1"}, + {file = "multidict-6.4.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0bb8f8302fbc7122033df959e25777b0b7659b1fd6bcb9cb6bed76b5de67afef"}, + {file = "multidict-6.4.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:224b79471b4f21169ea25ebc37ed6f058040c578e50ade532e2066562597b8a9"}, + {file = "multidict-6.4.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a7bd27f7ab3204f16967a6f899b3e8e9eb3362c0ab91f2ee659e0345445e0078"}, + {file = "multidict-6.4.3-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:99592bd3162e9c664671fd14e578a33bfdba487ea64bcb41d281286d3c870ad7"}, + {file = "multidict-6.4.3-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:a62d78a1c9072949018cdb05d3c533924ef8ac9bcb06cbf96f6d14772c5cd451"}, + {file = "multidict-6.4.3-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:3ccdde001578347e877ca4f629450973c510e88e8865d5aefbcb89b852ccc666"}, + {file = "multidict-6.4.3-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:eccb67b0e78aa2e38a04c5ecc13bab325a43e5159a181a9d1a6723db913cbb3c"}, + {file = "multidict-6.4.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8b6fcf6054fc4114a27aa865f8840ef3d675f9316e81868e0ad5866184a6cba5"}, + {file = "multidict-6.4.3-cp310-cp310-win32.whl", hash = "sha256:f92c7f62d59373cd93bc9969d2da9b4b21f78283b1379ba012f7ee8127b3152e"}, + {file = "multidict-6.4.3-cp310-cp310-win_amd64.whl", hash = "sha256:b57e28dbc031d13916b946719f213c494a517b442d7b48b29443e79610acd887"}, + {file = "multidict-6.4.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f6f19170197cc29baccd33ccc5b5d6a331058796485857cf34f7635aa25fb0cd"}, + {file = "multidict-6.4.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f2882bf27037eb687e49591690e5d491e677272964f9ec7bc2abbe09108bdfb8"}, + {file = "multidict-6.4.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fbf226ac85f7d6b6b9ba77db4ec0704fde88463dc17717aec78ec3c8546c70ad"}, + {file = "multidict-6.4.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e329114f82ad4b9dd291bef614ea8971ec119ecd0f54795109976de75c9a852"}, + {file = "multidict-6.4.3-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:1f4e0334d7a555c63f5c8952c57ab6f1c7b4f8c7f3442df689fc9f03df315c08"}, + {file = 
"multidict-6.4.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:740915eb776617b57142ce0bb13b7596933496e2f798d3d15a20614adf30d229"}, + {file = "multidict-6.4.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:255dac25134d2b141c944b59a0d2f7211ca12a6d4779f7586a98b4b03ea80508"}, + {file = "multidict-6.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4e8535bd4d741039b5aad4285ecd9b902ef9e224711f0b6afda6e38d7ac02c7"}, + {file = "multidict-6.4.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:30c433a33be000dd968f5750722eaa0991037be0be4a9d453eba121774985bc8"}, + {file = "multidict-6.4.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:4eb33b0bdc50acd538f45041f5f19945a1f32b909b76d7b117c0c25d8063df56"}, + {file = "multidict-6.4.3-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:75482f43465edefd8a5d72724887ccdcd0c83778ded8f0cb1e0594bf71736cc0"}, + {file = "multidict-6.4.3-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:ce5b3082e86aee80b3925ab4928198450d8e5b6466e11501fe03ad2191c6d777"}, + {file = "multidict-6.4.3-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:e413152e3212c4d39f82cf83c6f91be44bec9ddea950ce17af87fbf4e32ca6b2"}, + {file = "multidict-6.4.3-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:8aac2eeff69b71f229a405c0a4b61b54bade8e10163bc7b44fcd257949620618"}, + {file = "multidict-6.4.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ab583ac203af1d09034be41458feeab7863c0635c650a16f15771e1386abf2d7"}, + {file = "multidict-6.4.3-cp311-cp311-win32.whl", hash = "sha256:1b2019317726f41e81154df636a897de1bfe9228c3724a433894e44cd2512378"}, + {file = "multidict-6.4.3-cp311-cp311-win_amd64.whl", hash = "sha256:43173924fa93c7486402217fab99b60baf78d33806af299c56133a3755f69589"}, + {file = "multidict-6.4.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:1f1c2f58f08b36f8475f3ec6f5aeb95270921d418bf18f90dffd6be5c7b0e676"}, + {file = "multidict-6.4.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:26ae9ad364fc61b936fb7bf4c9d8bd53f3a5b4417142cd0be5c509d6f767e2f1"}, + {file = "multidict-6.4.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:659318c6c8a85f6ecfc06b4e57529e5a78dfdd697260cc81f683492ad7e9435a"}, + {file = "multidict-6.4.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e1eb72c741fd24d5a28242ce72bb61bc91f8451877131fa3fe930edb195f7054"}, + {file = "multidict-6.4.3-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3cd06d88cb7398252284ee75c8db8e680aa0d321451132d0dba12bc995f0adcc"}, + {file = "multidict-6.4.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4543d8dc6470a82fde92b035a92529317191ce993533c3c0c68f56811164ed07"}, + {file = "multidict-6.4.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:30a3ebdc068c27e9d6081fca0e2c33fdf132ecea703a72ea216b81a66860adde"}, + {file = "multidict-6.4.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b038f10e23f277153f86f95c777ba1958bcd5993194fda26a1d06fae98b2f00c"}, + {file = "multidict-6.4.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c605a2b2dc14282b580454b9b5d14ebe0668381a3a26d0ac39daa0ca115eb2ae"}, + {file = "multidict-6.4.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8bd2b875f4ca2bb527fe23e318ddd509b7df163407b0fb717df229041c6df5d3"}, + {file 
= "multidict-6.4.3-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:c2e98c840c9c8e65c0e04b40c6c5066c8632678cd50c8721fdbcd2e09f21a507"}, + {file = "multidict-6.4.3-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:66eb80dd0ab36dbd559635e62fba3083a48a252633164857a1d1684f14326427"}, + {file = "multidict-6.4.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:c23831bdee0a2a3cf21be057b5e5326292f60472fb6c6f86392bbf0de70ba731"}, + {file = "multidict-6.4.3-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:1535cec6443bfd80d028052e9d17ba6ff8a5a3534c51d285ba56c18af97e9713"}, + {file = "multidict-6.4.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3b73e7227681f85d19dec46e5b881827cd354aabe46049e1a61d2f9aaa4e285a"}, + {file = "multidict-6.4.3-cp312-cp312-win32.whl", hash = "sha256:8eac0c49df91b88bf91f818e0a24c1c46f3622978e2c27035bfdca98e0e18124"}, + {file = "multidict-6.4.3-cp312-cp312-win_amd64.whl", hash = "sha256:11990b5c757d956cd1db7cb140be50a63216af32cd6506329c2c59d732d802db"}, + {file = "multidict-6.4.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:7a76534263d03ae0cfa721fea40fd2b5b9d17a6f85e98025931d41dc49504474"}, + {file = "multidict-6.4.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:805031c2f599eee62ac579843555ed1ce389ae00c7e9f74c2a1b45e0564a88dd"}, + {file = "multidict-6.4.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c56c179839d5dcf51d565132185409d1d5dd8e614ba501eb79023a6cab25576b"}, + {file = "multidict-6.4.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c64f4ddb3886dd8ab71b68a7431ad4aa01a8fa5be5b11543b29674f29ca0ba3"}, + {file = "multidict-6.4.3-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3002a856367c0b41cad6784f5b8d3ab008eda194ed7864aaa58f65312e2abcac"}, + {file = "multidict-6.4.3-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3d75e621e7d887d539d6e1d789f0c64271c250276c333480a9e1de089611f790"}, + {file = "multidict-6.4.3-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:995015cf4a3c0d72cbf453b10a999b92c5629eaf3a0c3e1efb4b5c1f602253bb"}, + {file = "multidict-6.4.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2b0fabae7939d09d7d16a711468c385272fa1b9b7fb0d37e51143585d8e72e0"}, + {file = "multidict-6.4.3-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:61ed4d82f8a1e67eb9eb04f8587970d78fe7cddb4e4d6230b77eda23d27938f9"}, + {file = "multidict-6.4.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:062428944a8dc69df9fdc5d5fc6279421e5f9c75a9ee3f586f274ba7b05ab3c8"}, + {file = "multidict-6.4.3-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:b90e27b4674e6c405ad6c64e515a505c6d113b832df52fdacb6b1ffd1fa9a1d1"}, + {file = "multidict-6.4.3-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:7d50d4abf6729921e9613d98344b74241572b751c6b37feed75fb0c37bd5a817"}, + {file = "multidict-6.4.3-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:43fe10524fb0a0514be3954be53258e61d87341008ce4914f8e8b92bee6f875d"}, + {file = "multidict-6.4.3-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:236966ca6c472ea4e2d3f02f6673ebfd36ba3f23159c323f5a496869bc8e47c9"}, + {file = "multidict-6.4.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:422a5ec315018e606473ba1f5431e064cf8b2a7468019233dcf8082fabad64c8"}, + {file = "multidict-6.4.3-cp313-cp313-win32.whl", hash = 
"sha256:f901a5aace8e8c25d78960dcc24c870c8d356660d3b49b93a78bf38eb682aac3"}, + {file = "multidict-6.4.3-cp313-cp313-win_amd64.whl", hash = "sha256:1c152c49e42277bc9a2f7b78bd5fa10b13e88d1b0328221e7aef89d5c60a99a5"}, + {file = "multidict-6.4.3-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:be8751869e28b9c0d368d94f5afcb4234db66fe8496144547b4b6d6a0645cfc6"}, + {file = "multidict-6.4.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0d4b31f8a68dccbcd2c0ea04f0e014f1defc6b78f0eb8b35f2265e8716a6df0c"}, + {file = "multidict-6.4.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:032efeab3049e37eef2ff91271884303becc9e54d740b492a93b7e7266e23756"}, + {file = "multidict-6.4.3-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9e78006af1a7c8a8007e4f56629d7252668344442f66982368ac06522445e375"}, + {file = "multidict-6.4.3-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:daeac9dd30cda8703c417e4fddccd7c4dc0c73421a0b54a7da2713be125846be"}, + {file = "multidict-6.4.3-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1f6f90700881438953eae443a9c6f8a509808bc3b185246992c4233ccee37fea"}, + {file = "multidict-6.4.3-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f84627997008390dd15762128dcf73c3365f4ec0106739cde6c20a07ed198ec8"}, + {file = "multidict-6.4.3-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3307b48cd156153b117c0ea54890a3bdbf858a5b296ddd40dc3852e5f16e9b02"}, + {file = "multidict-6.4.3-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ead46b0fa1dcf5af503a46e9f1c2e80b5d95c6011526352fa5f42ea201526124"}, + {file = "multidict-6.4.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:1748cb2743bedc339d63eb1bca314061568793acd603a6e37b09a326334c9f44"}, + {file = "multidict-6.4.3-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:acc9fa606f76fc111b4569348cc23a771cb52c61516dcc6bcef46d612edb483b"}, + {file = "multidict-6.4.3-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:31469d5832b5885adeb70982e531ce86f8c992334edd2f2254a10fa3182ac504"}, + {file = "multidict-6.4.3-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:ba46b51b6e51b4ef7bfb84b82f5db0dc5e300fb222a8a13b8cd4111898a869cf"}, + {file = "multidict-6.4.3-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:389cfefb599edf3fcfd5f64c0410da686f90f5f5e2c4d84e14f6797a5a337af4"}, + {file = "multidict-6.4.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:64bc2bbc5fba7b9db5c2c8d750824f41c6994e3882e6d73c903c2afa78d091e4"}, + {file = "multidict-6.4.3-cp313-cp313t-win32.whl", hash = "sha256:0ecdc12ea44bab2807d6b4a7e5eef25109ab1c82a8240d86d3c1fc9f3b72efd5"}, + {file = "multidict-6.4.3-cp313-cp313t-win_amd64.whl", hash = "sha256:7146a8742ea71b5d7d955bffcef58a9e6e04efba704b52a460134fefd10a8208"}, + {file = "multidict-6.4.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5427a2679e95a642b7f8b0f761e660c845c8e6fe3141cddd6b62005bd133fc21"}, + {file = "multidict-6.4.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:24a8caa26521b9ad09732972927d7b45b66453e6ebd91a3c6a46d811eeb7349b"}, + {file = "multidict-6.4.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6b5a272bc7c36a2cd1b56ddc6bff02e9ce499f9f14ee4a45c45434ef083f2459"}, + {file = "multidict-6.4.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edf74dc5e212b8c75165b435c43eb0d5e81b6b300a938a4eb82827119115e840"}, + {file = 
"multidict-6.4.3-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:9f35de41aec4b323c71f54b0ca461ebf694fb48bec62f65221f52e0017955b39"}, + {file = "multidict-6.4.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ae93e0ff43b6f6892999af64097b18561691ffd835e21a8348a441e256592e1f"}, + {file = "multidict-6.4.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5e3929269e9d7eff905d6971d8b8c85e7dbc72c18fb99c8eae6fe0a152f2e343"}, + {file = "multidict-6.4.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb6214fe1750adc2a1b801a199d64b5a67671bf76ebf24c730b157846d0e90d2"}, + {file = "multidict-6.4.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6d79cf5c0c6284e90f72123f4a3e4add52d6c6ebb4a9054e88df15b8d08444c6"}, + {file = "multidict-6.4.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2427370f4a255262928cd14533a70d9738dfacadb7563bc3b7f704cc2360fc4e"}, + {file = "multidict-6.4.3-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:fbd8d737867912b6c5f99f56782b8cb81f978a97b4437a1c476de90a3e41c9a1"}, + {file = "multidict-6.4.3-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:0ee1bf613c448997f73fc4efb4ecebebb1c02268028dd4f11f011f02300cf1e8"}, + {file = "multidict-6.4.3-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:578568c4ba5f2b8abd956baf8b23790dbfdc953e87d5b110bce343b4a54fc9e7"}, + {file = "multidict-6.4.3-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:a059ad6b80de5b84b9fa02a39400319e62edd39d210b4e4f8c4f1243bdac4752"}, + {file = "multidict-6.4.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:dd53893675b729a965088aaadd6a1f326a72b83742b056c1065bdd2e2a42b4df"}, + {file = "multidict-6.4.3-cp39-cp39-win32.whl", hash = "sha256:abcfed2c4c139f25c2355e180bcc077a7cae91eefbb8b3927bb3f836c9586f1f"}, + {file = "multidict-6.4.3-cp39-cp39-win_amd64.whl", hash = "sha256:b1b389ae17296dd739015d5ddb222ee99fd66adeae910de21ac950e00979d897"}, + {file = "multidict-6.4.3-py3-none-any.whl", hash = "sha256:59fe01ee8e2a1e8ceb3f6dbb216b09c8d9f4ef1c22c4fc825d045a147fa2ebc9"}, + {file = "multidict-6.4.3.tar.gz", hash = "sha256:3ada0b058c9f213c5f95ba301f922d402ac234f1111a7d8fd70f1b99f3c281ec"}, ] +[package.dependencies] +typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.11\""} + [[package]] name = "packaging" -version = "23.1" +version = "24.2" description = "Core utilities for Python packages" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" +groups = ["main", "dev"] files = [ - {file = "packaging-23.1-py3-none-any.whl", hash = "sha256:994793af429502c4ea2ebf6bf664629d07c1a9fe974af92966e4b8d2df7edc61"}, - {file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"}, + {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"}, + {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"}, ] [[package]] name = "pluggy" -version = "1.2.0" +version = "1.5.0" description = "plugin and hook calling mechanisms for python" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" +groups = ["dev"] files = [ - {file = "pluggy-1.2.0-py3-none-any.whl", hash = "sha256:c2fd55a7d7a3863cba1a013e4e2414658b1d07b6bc57b3919e0c63c9abb99849"}, - {file = "pluggy-1.2.0.tar.gz", hash = 
"sha256:d12f0c4b579b15f5e054301bb226ee85eeeba08ffec228092f8defbaa3a4c4b3"}, + {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, + {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, ] -[package.dependencies] -importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""} - [package.extras] dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] [[package]] -name = "py" -version = "1.11.0" -description = "library with cross-python path, ini-parsing, io, code, log facilities" +name = "propcache" +version = "0.3.1" +description = "Accelerated property cache" optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +python-versions = ">=3.9" +groups = ["main"] files = [ - {file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"}, - {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"}, + {file = "propcache-0.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f27785888d2fdd918bc36de8b8739f2d6c791399552333721b58193f68ea3e98"}, + {file = "propcache-0.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4e89cde74154c7b5957f87a355bb9c8ec929c167b59c83d90654ea36aeb6180"}, + {file = "propcache-0.3.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:730178f476ef03d3d4d255f0c9fa186cb1d13fd33ffe89d39f2cda4da90ceb71"}, + {file = "propcache-0.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:967a8eec513dbe08330f10137eacb427b2ca52118769e82ebcfcab0fba92a649"}, + {file = "propcache-0.3.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5b9145c35cc87313b5fd480144f8078716007656093d23059e8993d3a8fa730f"}, + {file = "propcache-0.3.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9e64e948ab41411958670f1093c0a57acfdc3bee5cf5b935671bbd5313bcf229"}, + {file = "propcache-0.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:319fa8765bfd6a265e5fa661547556da381e53274bc05094fc9ea50da51bfd46"}, + {file = "propcache-0.3.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c66d8ccbc902ad548312b96ed8d5d266d0d2c6d006fd0f66323e9d8f2dd49be7"}, + {file = "propcache-0.3.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:2d219b0dbabe75e15e581fc1ae796109b07c8ba7d25b9ae8d650da582bed01b0"}, + {file = "propcache-0.3.1-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:cd6a55f65241c551eb53f8cf4d2f4af33512c39da5d9777694e9d9c60872f519"}, + {file = "propcache-0.3.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:9979643ffc69b799d50d3a7b72b5164a2e97e117009d7af6dfdd2ab906cb72cd"}, + {file = "propcache-0.3.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:4cf9e93a81979f1424f1a3d155213dc928f1069d697e4353edb8a5eba67c6259"}, + {file = "propcache-0.3.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:2fce1df66915909ff6c824bbb5eb403d2d15f98f1518e583074671a30fe0c21e"}, + {file = "propcache-0.3.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:4d0dfdd9a2ebc77b869a0b04423591ea8823f791293b527dc1bb896c1d6f1136"}, + {file = "propcache-0.3.1-cp310-cp310-win32.whl", hash = "sha256:1f6cc0ad7b4560e5637eb2c994e97b4fa41ba8226069c9277eb5ea7101845b42"}, + {file = "propcache-0.3.1-cp310-cp310-win_amd64.whl", hash = 
"sha256:47ef24aa6511e388e9894ec16f0fbf3313a53ee68402bc428744a367ec55b833"}, + {file = "propcache-0.3.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:7f30241577d2fef2602113b70ef7231bf4c69a97e04693bde08ddab913ba0ce5"}, + {file = "propcache-0.3.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:43593c6772aa12abc3af7784bff4a41ffa921608dd38b77cf1dfd7f5c4e71371"}, + {file = "propcache-0.3.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a75801768bbe65499495660b777e018cbe90c7980f07f8aa57d6be79ea6f71da"}, + {file = "propcache-0.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f6f1324db48f001c2ca26a25fa25af60711e09b9aaf4b28488602776f4f9a744"}, + {file = "propcache-0.3.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5cdb0f3e1eb6dfc9965d19734d8f9c481b294b5274337a8cb5cb01b462dcb7e0"}, + {file = "propcache-0.3.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1eb34d90aac9bfbced9a58b266f8946cb5935869ff01b164573a7634d39fbcb5"}, + {file = "propcache-0.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f35c7070eeec2cdaac6fd3fe245226ed2a6292d3ee8c938e5bb645b434c5f256"}, + {file = "propcache-0.3.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b23c11c2c9e6d4e7300c92e022046ad09b91fd00e36e83c44483df4afa990073"}, + {file = "propcache-0.3.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:3e19ea4ea0bf46179f8a3652ac1426e6dcbaf577ce4b4f65be581e237340420d"}, + {file = "propcache-0.3.1-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:bd39c92e4c8f6cbf5f08257d6360123af72af9f4da75a690bef50da77362d25f"}, + {file = "propcache-0.3.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:b0313e8b923b3814d1c4a524c93dfecea5f39fa95601f6a9b1ac96cd66f89ea0"}, + {file = "propcache-0.3.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:e861ad82892408487be144906a368ddbe2dc6297074ade2d892341b35c59844a"}, + {file = "propcache-0.3.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:61014615c1274df8da5991a1e5da85a3ccb00c2d4701ac6f3383afd3ca47ab0a"}, + {file = "propcache-0.3.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:71ebe3fe42656a2328ab08933d420df5f3ab121772eef78f2dc63624157f0ed9"}, + {file = "propcache-0.3.1-cp311-cp311-win32.whl", hash = "sha256:58aa11f4ca8b60113d4b8e32d37e7e78bd8af4d1a5b5cb4979ed856a45e62005"}, + {file = "propcache-0.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:9532ea0b26a401264b1365146c440a6d78269ed41f83f23818d4b79497aeabe7"}, + {file = "propcache-0.3.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:f78eb8422acc93d7b69964012ad7048764bb45a54ba7a39bb9e146c72ea29723"}, + {file = "propcache-0.3.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:89498dd49c2f9a026ee057965cdf8192e5ae070ce7d7a7bd4b66a8e257d0c976"}, + {file = "propcache-0.3.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:09400e98545c998d57d10035ff623266927cb784d13dd2b31fd33b8a5316b85b"}, + {file = "propcache-0.3.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aa8efd8c5adc5a2c9d3b952815ff8f7710cefdcaf5f2c36d26aff51aeca2f12f"}, + {file = "propcache-0.3.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c2fe5c910f6007e716a06d269608d307b4f36e7babee5f36533722660e8c4a70"}, + {file = "propcache-0.3.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a0ab8cf8cdd2194f8ff979a43ab43049b1df0b37aa64ab7eca04ac14429baeb7"}, + {file = 
"propcache-0.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:563f9d8c03ad645597b8d010ef4e9eab359faeb11a0a2ac9f7b4bc8c28ebef25"}, + {file = "propcache-0.3.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fb6e0faf8cb6b4beea5d6ed7b5a578254c6d7df54c36ccd3d8b3eb00d6770277"}, + {file = "propcache-0.3.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1c5c7ab7f2bb3f573d1cb921993006ba2d39e8621019dffb1c5bc94cdbae81e8"}, + {file = "propcache-0.3.1-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:050b571b2e96ec942898f8eb46ea4bfbb19bd5502424747e83badc2d4a99a44e"}, + {file = "propcache-0.3.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:e1c4d24b804b3a87e9350f79e2371a705a188d292fd310e663483af6ee6718ee"}, + {file = "propcache-0.3.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:e4fe2a6d5ce975c117a6bb1e8ccda772d1e7029c1cca1acd209f91d30fa72815"}, + {file = "propcache-0.3.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:feccd282de1f6322f56f6845bf1207a537227812f0a9bf5571df52bb418d79d5"}, + {file = "propcache-0.3.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ec314cde7314d2dd0510c6787326bbffcbdc317ecee6b7401ce218b3099075a7"}, + {file = "propcache-0.3.1-cp312-cp312-win32.whl", hash = "sha256:7d2d5a0028d920738372630870e7d9644ce437142197f8c827194fca404bf03b"}, + {file = "propcache-0.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:88c423efef9d7a59dae0614eaed718449c09a5ac79a5f224a8b9664d603f04a3"}, + {file = "propcache-0.3.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:f1528ec4374617a7a753f90f20e2f551121bb558fcb35926f99e3c42367164b8"}, + {file = "propcache-0.3.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:dc1915ec523b3b494933b5424980831b636fe483d7d543f7afb7b3bf00f0c10f"}, + {file = "propcache-0.3.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a110205022d077da24e60b3df8bcee73971be9575dec5573dd17ae5d81751111"}, + {file = "propcache-0.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d249609e547c04d190e820d0d4c8ca03ed4582bcf8e4e160a6969ddfb57b62e5"}, + {file = "propcache-0.3.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5ced33d827625d0a589e831126ccb4f5c29dfdf6766cac441d23995a65825dcb"}, + {file = "propcache-0.3.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4114c4ada8f3181af20808bedb250da6bae56660e4b8dfd9cd95d4549c0962f7"}, + {file = "propcache-0.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:975af16f406ce48f1333ec5e912fe11064605d5c5b3f6746969077cc3adeb120"}, + {file = "propcache-0.3.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a34aa3a1abc50740be6ac0ab9d594e274f59960d3ad253cd318af76b996dd654"}, + {file = "propcache-0.3.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9cec3239c85ed15bfaded997773fdad9fb5662b0a7cbc854a43f291eb183179e"}, + {file = "propcache-0.3.1-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:05543250deac8e61084234d5fc54f8ebd254e8f2b39a16b1dce48904f45b744b"}, + {file = "propcache-0.3.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:5cb5918253912e088edbf023788de539219718d3b10aef334476b62d2b53de53"}, + {file = "propcache-0.3.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f3bbecd2f34d0e6d3c543fdb3b15d6b60dd69970c2b4c822379e5ec8f6f621d5"}, + {file = "propcache-0.3.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = 
"sha256:aca63103895c7d960a5b9b044a83f544b233c95e0dcff114389d64d762017af7"}, + {file = "propcache-0.3.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5a0a9898fdb99bf11786265468571e628ba60af80dc3f6eb89a3545540c6b0ef"}, + {file = "propcache-0.3.1-cp313-cp313-win32.whl", hash = "sha256:3a02a28095b5e63128bcae98eb59025924f121f048a62393db682f049bf4ac24"}, + {file = "propcache-0.3.1-cp313-cp313-win_amd64.whl", hash = "sha256:813fbb8b6aea2fc9659815e585e548fe706d6f663fa73dff59a1677d4595a037"}, + {file = "propcache-0.3.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:a444192f20f5ce8a5e52761a031b90f5ea6288b1eef42ad4c7e64fef33540b8f"}, + {file = "propcache-0.3.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0fbe94666e62ebe36cd652f5fc012abfbc2342de99b523f8267a678e4dfdee3c"}, + {file = "propcache-0.3.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:f011f104db880f4e2166bcdcf7f58250f7a465bc6b068dc84c824a3d4a5c94dc"}, + {file = "propcache-0.3.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e584b6d388aeb0001d6d5c2bd86b26304adde6d9bb9bfa9c4889805021b96de"}, + {file = "propcache-0.3.1-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8a17583515a04358b034e241f952f1715243482fc2c2945fd99a1b03a0bd77d6"}, + {file = "propcache-0.3.1-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5aed8d8308215089c0734a2af4f2e95eeb360660184ad3912686c181e500b2e7"}, + {file = "propcache-0.3.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d8e309ff9a0503ef70dc9a0ebd3e69cf7b3894c9ae2ae81fc10943c37762458"}, + {file = "propcache-0.3.1-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b655032b202028a582d27aeedc2e813299f82cb232f969f87a4fde491a233f11"}, + {file = "propcache-0.3.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:9f64d91b751df77931336b5ff7bafbe8845c5770b06630e27acd5dbb71e1931c"}, + {file = "propcache-0.3.1-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:19a06db789a4bd896ee91ebc50d059e23b3639c25d58eb35be3ca1cbe967c3bf"}, + {file = "propcache-0.3.1-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:bef100c88d8692864651b5f98e871fb090bd65c8a41a1cb0ff2322db39c96c27"}, + {file = "propcache-0.3.1-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:87380fb1f3089d2a0b8b00f006ed12bd41bd858fabfa7330c954c70f50ed8757"}, + {file = "propcache-0.3.1-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:e474fc718e73ba5ec5180358aa07f6aded0ff5f2abe700e3115c37d75c947e18"}, + {file = "propcache-0.3.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:17d1c688a443355234f3c031349da69444be052613483f3e4158eef751abcd8a"}, + {file = "propcache-0.3.1-cp313-cp313t-win32.whl", hash = "sha256:359e81a949a7619802eb601d66d37072b79b79c2505e6d3fd8b945538411400d"}, + {file = "propcache-0.3.1-cp313-cp313t-win_amd64.whl", hash = "sha256:e7fb9a84c9abbf2b2683fa3e7b0d7da4d8ecf139a1c635732a8bda29c5214b0e"}, + {file = "propcache-0.3.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:ed5f6d2edbf349bd8d630e81f474d33d6ae5d07760c44d33cd808e2f5c8f4ae6"}, + {file = "propcache-0.3.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:668ddddc9f3075af019f784456267eb504cb77c2c4bd46cc8402d723b4d200bf"}, + {file = "propcache-0.3.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0c86e7ceea56376216eba345aa1fc6a8a6b27ac236181f840d1d7e6a1ea9ba5c"}, + {file = "propcache-0.3.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:83be47aa4e35b87c106fc0c84c0fc069d3f9b9b06d3c494cd404ec6747544894"}, + {file = "propcache-0.3.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:27c6ac6aa9fc7bc662f594ef380707494cb42c22786a558d95fcdedb9aa5d035"}, + {file = "propcache-0.3.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64a956dff37080b352c1c40b2966b09defb014347043e740d420ca1eb7c9b908"}, + {file = "propcache-0.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:82de5da8c8893056603ac2d6a89eb8b4df49abf1a7c19d536984c8dd63f481d5"}, + {file = "propcache-0.3.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0c3c3a203c375b08fd06a20da3cf7aac293b834b6f4f4db71190e8422750cca5"}, + {file = "propcache-0.3.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:b303b194c2e6f171cfddf8b8ba30baefccf03d36a4d9cab7fd0bb68ba476a3d7"}, + {file = "propcache-0.3.1-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:916cd229b0150129d645ec51614d38129ee74c03293a9f3f17537be0029a9641"}, + {file = "propcache-0.3.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:a461959ead5b38e2581998700b26346b78cd98540b5524796c175722f18b0294"}, + {file = "propcache-0.3.1-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:069e7212890b0bcf9b2be0a03afb0c2d5161d91e1bf51569a64f629acc7defbf"}, + {file = "propcache-0.3.1-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:ef2e4e91fb3945769e14ce82ed53007195e616a63aa43b40fb7ebaaf907c8d4c"}, + {file = "propcache-0.3.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:8638f99dca15b9dff328fb6273e09f03d1c50d9b6512f3b65a4154588a7595fe"}, + {file = "propcache-0.3.1-cp39-cp39-win32.whl", hash = "sha256:6f173bbfe976105aaa890b712d1759de339d8a7cef2fc0a1714cc1a1e1c47f64"}, + {file = "propcache-0.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:603f1fe4144420374f1a69b907494c3acbc867a581c2d49d4175b0de7cc64566"}, + {file = "propcache-0.3.1-py3-none-any.whl", hash = "sha256:9a8ecf38de50a7f518c21568c80f985e776397b902f1ce0b01f799aba1608b40"}, + {file = "propcache-0.3.1.tar.gz", hash = "sha256:40d980c33765359098837527e18eddefc9a24cea5b45e078a7f3bb5b032c6ecf"}, ] [[package]] name = "pydantic" -version = "2.5.3" +version = "2.11.3" description = "Data validation using Python type hints" optional = false -python-versions = ">=3.7" +python-versions = ">=3.9" +groups = ["main"] files = [ - {file = "pydantic-2.5.3-py3-none-any.whl", hash = "sha256:d0caf5954bee831b6bfe7e338c32b9e30c85dfe080c843680783ac2b631673b4"}, - {file = "pydantic-2.5.3.tar.gz", hash = "sha256:b3ef57c62535b0941697cce638c08900d87fcb67e29cfa99e8a68f747f393f7a"}, + {file = "pydantic-2.11.3-py3-none-any.whl", hash = "sha256:a082753436a07f9ba1289c6ffa01cd93db3548776088aa917cc43b63f68fa60f"}, + {file = "pydantic-2.11.3.tar.gz", hash = "sha256:7471657138c16adad9322fe3070c0116dd6c3ad8d649300e3cbdfe91f4db4ec3"}, ] [package.dependencies] -annotated-types = ">=0.4.0" -importlib-metadata = {version = "*", markers = "python_version == \"3.7\""} -pydantic-core = "2.14.6" -typing-extensions = ">=4.6.1" +annotated-types = ">=0.6.0" +pydantic-core = "2.33.1" +typing-extensions = ">=4.12.2" +typing-inspection = ">=0.4.0" [package.extras] email = ["email-validator (>=2.0.0)"] +timezone = ["tzdata"] [[package]] name = "pydantic-core" -version = "2.14.6" -description = "" +version = "2.33.1" +description = "Core functionality for Pydantic validation and serialization" optional = false -python-versions = ">=3.7" +python-versions = ">=3.9" +groups = ["main"] files = 
[ - {file = "pydantic_core-2.14.6-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:72f9a942d739f09cd42fffe5dc759928217649f070056f03c70df14f5770acf9"}, - {file = "pydantic_core-2.14.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6a31d98c0d69776c2576dda4b77b8e0c69ad08e8b539c25c7d0ca0dc19a50d6c"}, - {file = "pydantic_core-2.14.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5aa90562bc079c6c290f0512b21768967f9968e4cfea84ea4ff5af5d917016e4"}, - {file = "pydantic_core-2.14.6-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:370ffecb5316ed23b667d99ce4debe53ea664b99cc37bfa2af47bc769056d534"}, - {file = "pydantic_core-2.14.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f85f3843bdb1fe80e8c206fe6eed7a1caeae897e496542cee499c374a85c6e08"}, - {file = "pydantic_core-2.14.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9862bf828112e19685b76ca499b379338fd4c5c269d897e218b2ae8fcb80139d"}, - {file = "pydantic_core-2.14.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:036137b5ad0cb0004c75b579445a1efccd072387a36c7f217bb8efd1afbe5245"}, - {file = "pydantic_core-2.14.6-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:92879bce89f91f4b2416eba4429c7b5ca22c45ef4a499c39f0c5c69257522c7c"}, - {file = "pydantic_core-2.14.6-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0c08de15d50fa190d577e8591f0329a643eeaed696d7771760295998aca6bc66"}, - {file = "pydantic_core-2.14.6-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:36099c69f6b14fc2c49d7996cbf4f87ec4f0e66d1c74aa05228583225a07b590"}, - {file = "pydantic_core-2.14.6-cp310-none-win32.whl", hash = "sha256:7be719e4d2ae6c314f72844ba9d69e38dff342bc360379f7c8537c48e23034b7"}, - {file = "pydantic_core-2.14.6-cp310-none-win_amd64.whl", hash = "sha256:36fa402dcdc8ea7f1b0ddcf0df4254cc6b2e08f8cd80e7010d4c4ae6e86b2a87"}, - {file = "pydantic_core-2.14.6-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:dea7fcd62915fb150cdc373212141a30037e11b761fbced340e9db3379b892d4"}, - {file = "pydantic_core-2.14.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ffff855100bc066ff2cd3aa4a60bc9534661816b110f0243e59503ec2df38421"}, - {file = "pydantic_core-2.14.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b027c86c66b8627eb90e57aee1f526df77dc6d8b354ec498be9a757d513b92b"}, - {file = "pydantic_core-2.14.6-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:00b1087dabcee0b0ffd104f9f53d7d3eaddfaa314cdd6726143af6bc713aa27e"}, - {file = "pydantic_core-2.14.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:75ec284328b60a4e91010c1acade0c30584f28a1f345bc8f72fe8b9e46ec6a96"}, - {file = "pydantic_core-2.14.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7e1f4744eea1501404b20b0ac059ff7e3f96a97d3e3f48ce27a139e053bb370b"}, - {file = "pydantic_core-2.14.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2602177668f89b38b9f84b7b3435d0a72511ddef45dc14446811759b82235a1"}, - {file = "pydantic_core-2.14.6-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6c8edaea3089bf908dd27da8f5d9e395c5b4dc092dbcce9b65e7156099b4b937"}, - {file = "pydantic_core-2.14.6-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:478e9e7b360dfec451daafe286998d4a1eeaecf6d69c427b834ae771cad4b622"}, - {file = "pydantic_core-2.14.6-cp311-cp311-musllinux_1_1_x86_64.whl", hash = 
"sha256:b6ca36c12a5120bad343eef193cc0122928c5c7466121da7c20f41160ba00ba2"}, - {file = "pydantic_core-2.14.6-cp311-none-win32.whl", hash = "sha256:2b8719037e570639e6b665a4050add43134d80b687288ba3ade18b22bbb29dd2"}, - {file = "pydantic_core-2.14.6-cp311-none-win_amd64.whl", hash = "sha256:78ee52ecc088c61cce32b2d30a826f929e1708f7b9247dc3b921aec367dc1b23"}, - {file = "pydantic_core-2.14.6-cp311-none-win_arm64.whl", hash = "sha256:a19b794f8fe6569472ff77602437ec4430f9b2b9ec7a1105cfd2232f9ba355e6"}, - {file = "pydantic_core-2.14.6-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:667aa2eac9cd0700af1ddb38b7b1ef246d8cf94c85637cbb03d7757ca4c3fdec"}, - {file = "pydantic_core-2.14.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:cdee837710ef6b56ebd20245b83799fce40b265b3b406e51e8ccc5b85b9099b7"}, - {file = "pydantic_core-2.14.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c5bcf3414367e29f83fd66f7de64509a8fd2368b1edf4351e862910727d3e51"}, - {file = "pydantic_core-2.14.6-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:26a92ae76f75d1915806b77cf459811e772d8f71fd1e4339c99750f0e7f6324f"}, - {file = "pydantic_core-2.14.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a983cca5ed1dd9a35e9e42ebf9f278d344603bfcb174ff99a5815f953925140a"}, - {file = "pydantic_core-2.14.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cb92f9061657287eded380d7dc455bbf115430b3aa4741bdc662d02977e7d0af"}, - {file = "pydantic_core-2.14.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4ace1e220b078c8e48e82c081e35002038657e4b37d403ce940fa679e57113b"}, - {file = "pydantic_core-2.14.6-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ef633add81832f4b56d3b4c9408b43d530dfca29e68fb1b797dcb861a2c734cd"}, - {file = "pydantic_core-2.14.6-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7e90d6cc4aad2cc1f5e16ed56e46cebf4877c62403a311af20459c15da76fd91"}, - {file = "pydantic_core-2.14.6-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e8a5ac97ea521d7bde7621d86c30e86b798cdecd985723c4ed737a2aa9e77d0c"}, - {file = "pydantic_core-2.14.6-cp312-none-win32.whl", hash = "sha256:f27207e8ca3e5e021e2402ba942e5b4c629718e665c81b8b306f3c8b1ddbb786"}, - {file = "pydantic_core-2.14.6-cp312-none-win_amd64.whl", hash = "sha256:b3e5fe4538001bb82e2295b8d2a39356a84694c97cb73a566dc36328b9f83b40"}, - {file = "pydantic_core-2.14.6-cp312-none-win_arm64.whl", hash = "sha256:64634ccf9d671c6be242a664a33c4acf12882670b09b3f163cd00a24cffbd74e"}, - {file = "pydantic_core-2.14.6-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:24368e31be2c88bd69340fbfe741b405302993242ccb476c5c3ff48aeee1afe0"}, - {file = "pydantic_core-2.14.6-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:e33b0834f1cf779aa839975f9d8755a7c2420510c0fa1e9fa0497de77cd35d2c"}, - {file = "pydantic_core-2.14.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6af4b3f52cc65f8a0bc8b1cd9676f8c21ef3e9132f21fed250f6958bd7223bed"}, - {file = "pydantic_core-2.14.6-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d15687d7d7f40333bd8266f3814c591c2e2cd263fa2116e314f60d82086e353a"}, - {file = "pydantic_core-2.14.6-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:095b707bb287bfd534044166ab767bec70a9bba3175dcdc3371782175c14e43c"}, - {file = "pydantic_core-2.14.6-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:94fc0e6621e07d1e91c44e016cc0b189b48db053061cc22d6298a611de8071bb"}, - {file = "pydantic_core-2.14.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ce830e480f6774608dedfd4a90c42aac4a7af0a711f1b52f807130c2e434c06"}, - {file = "pydantic_core-2.14.6-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a306cdd2ad3a7d795d8e617a58c3a2ed0f76c8496fb7621b6cd514eb1532cae8"}, - {file = "pydantic_core-2.14.6-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:2f5fa187bde8524b1e37ba894db13aadd64faa884657473b03a019f625cee9a8"}, - {file = "pydantic_core-2.14.6-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:438027a975cc213a47c5d70672e0d29776082155cfae540c4e225716586be75e"}, - {file = "pydantic_core-2.14.6-cp37-none-win32.whl", hash = "sha256:f96ae96a060a8072ceff4cfde89d261837b4294a4f28b84a28765470d502ccc6"}, - {file = "pydantic_core-2.14.6-cp37-none-win_amd64.whl", hash = "sha256:e646c0e282e960345314f42f2cea5e0b5f56938c093541ea6dbf11aec2862391"}, - {file = "pydantic_core-2.14.6-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:db453f2da3f59a348f514cfbfeb042393b68720787bbef2b4c6068ea362c8149"}, - {file = "pydantic_core-2.14.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3860c62057acd95cc84044e758e47b18dcd8871a328ebc8ccdefd18b0d26a21b"}, - {file = "pydantic_core-2.14.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36026d8f99c58d7044413e1b819a67ca0e0b8ebe0f25e775e6c3d1fabb3c38fb"}, - {file = "pydantic_core-2.14.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8ed1af8692bd8d2a29d702f1a2e6065416d76897d726e45a1775b1444f5928a7"}, - {file = "pydantic_core-2.14.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:314ccc4264ce7d854941231cf71b592e30d8d368a71e50197c905874feacc8a8"}, - {file = "pydantic_core-2.14.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:982487f8931067a32e72d40ab6b47b1628a9c5d344be7f1a4e668fb462d2da42"}, - {file = "pydantic_core-2.14.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2dbe357bc4ddda078f79d2a36fc1dd0494a7f2fad83a0a684465b6f24b46fe80"}, - {file = "pydantic_core-2.14.6-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2f6ffc6701a0eb28648c845f4945a194dc7ab3c651f535b81793251e1185ac3d"}, - {file = "pydantic_core-2.14.6-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:7f5025db12fc6de7bc1104d826d5aee1d172f9ba6ca936bf6474c2148ac336c1"}, - {file = "pydantic_core-2.14.6-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:dab03ed811ed1c71d700ed08bde8431cf429bbe59e423394f0f4055f1ca0ea60"}, - {file = "pydantic_core-2.14.6-cp38-none-win32.whl", hash = "sha256:dfcbebdb3c4b6f739a91769aea5ed615023f3c88cb70df812849aef634c25fbe"}, - {file = "pydantic_core-2.14.6-cp38-none-win_amd64.whl", hash = "sha256:99b14dbea2fdb563d8b5a57c9badfcd72083f6006caf8e126b491519c7d64ca8"}, - {file = "pydantic_core-2.14.6-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:4ce8299b481bcb68e5c82002b96e411796b844d72b3e92a3fbedfe8e19813eab"}, - {file = "pydantic_core-2.14.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b9a9d92f10772d2a181b5ca339dee066ab7d1c9a34ae2421b2a52556e719756f"}, - {file = "pydantic_core-2.14.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd9e98b408384989ea4ab60206b8e100d8687da18b5c813c11e92fd8212a98e0"}, - {file = "pydantic_core-2.14.6-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = 
"sha256:4f86f1f318e56f5cbb282fe61eb84767aee743ebe32c7c0834690ebea50c0a6b"}, - {file = "pydantic_core-2.14.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:86ce5fcfc3accf3a07a729779d0b86c5d0309a4764c897d86c11089be61da160"}, - {file = "pydantic_core-2.14.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3dcf1978be02153c6a31692d4fbcc2a3f1db9da36039ead23173bc256ee3b91b"}, - {file = "pydantic_core-2.14.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eedf97be7bc3dbc8addcef4142f4b4164066df0c6f36397ae4aaed3eb187d8ab"}, - {file = "pydantic_core-2.14.6-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d5f916acf8afbcab6bacbb376ba7dc61f845367901ecd5e328fc4d4aef2fcab0"}, - {file = "pydantic_core-2.14.6-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:8a14c192c1d724c3acbfb3f10a958c55a2638391319ce8078cb36c02283959b9"}, - {file = "pydantic_core-2.14.6-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:0348b1dc6b76041516e8a854ff95b21c55f5a411c3297d2ca52f5528e49d8411"}, - {file = "pydantic_core-2.14.6-cp39-none-win32.whl", hash = "sha256:de2a0645a923ba57c5527497daf8ec5df69c6eadf869e9cd46e86349146e5975"}, - {file = "pydantic_core-2.14.6-cp39-none-win_amd64.whl", hash = "sha256:aca48506a9c20f68ee61c87f2008f81f8ee99f8d7f0104bff3c47e2d148f89d9"}, - {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:d5c28525c19f5bb1e09511669bb57353d22b94cf8b65f3a8d141c389a55dec95"}, - {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:78d0768ee59baa3de0f4adac9e3748b4b1fffc52143caebddfd5ea2961595277"}, - {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b93785eadaef932e4fe9c6e12ba67beb1b3f1e5495631419c784ab87e975670"}, - {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a874f21f87c485310944b2b2734cd6d318765bcbb7515eead33af9641816506e"}, - {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b89f4477d915ea43b4ceea6756f63f0288941b6443a2b28c69004fe07fde0d0d"}, - {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:172de779e2a153d36ee690dbc49c6db568d7b33b18dc56b69a7514aecbcf380d"}, - {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:dfcebb950aa7e667ec226a442722134539e77c575f6cfaa423f24371bb8d2e94"}, - {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:55a23dcd98c858c0db44fc5c04fc7ed81c4b4d33c653a7c45ddaebf6563a2f66"}, - {file = "pydantic_core-2.14.6-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:4241204e4b36ab5ae466ecec5c4c16527a054c69f99bba20f6f75232a6a534e2"}, - {file = "pydantic_core-2.14.6-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e574de99d735b3fc8364cba9912c2bec2da78775eba95cbb225ef7dda6acea24"}, - {file = "pydantic_core-2.14.6-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1302a54f87b5cd8528e4d6d1bf2133b6aa7c6122ff8e9dc5220fbc1e07bffebd"}, - {file = "pydantic_core-2.14.6-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f8e81e4b55930e5ffab4a68db1af431629cf2e4066dbdbfef65348b8ab804ea8"}, - {file = "pydantic_core-2.14.6-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:c99462ffc538717b3e60151dfaf91125f637e801f5ab008f81c402f1dff0cd0f"}, - {file = 
"pydantic_core-2.14.6-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:e4cf2d5829f6963a5483ec01578ee76d329eb5caf330ecd05b3edd697e7d768a"}, - {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:cf10b7d58ae4a1f07fccbf4a0a956d705356fea05fb4c70608bb6fa81d103cda"}, - {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:399ac0891c284fa8eb998bcfa323f2234858f5d2efca3950ae58c8f88830f145"}, - {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c6a5c79b28003543db3ba67d1df336f253a87d3112dac3a51b94f7d48e4c0e1"}, - {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:599c87d79cab2a6a2a9df4aefe0455e61e7d2aeede2f8577c1b7c0aec643ee8e"}, - {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:43e166ad47ba900f2542a80d83f9fc65fe99eb63ceec4debec160ae729824052"}, - {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:3a0b5db001b98e1c649dd55afa928e75aa4087e587b9524a4992316fa23c9fba"}, - {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:747265448cb57a9f37572a488a57d873fd96bf51e5bb7edb52cfb37124516da4"}, - {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:7ebe3416785f65c28f4f9441e916bfc8a54179c8dea73c23023f7086fa601c5d"}, - {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:86c963186ca5e50d5c8287b1d1c9d3f8f024cbe343d048c5bd282aec2d8641f2"}, - {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:e0641b506486f0b4cd1500a2a65740243e8670a2549bb02bc4556a83af84ae03"}, - {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:71d72ca5eaaa8d38c8df16b7deb1a2da4f650c41b58bb142f3fb75d5ad4a611f"}, - {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:27e524624eace5c59af499cd97dc18bb201dc6a7a2da24bfc66ef151c69a5f2a"}, - {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a3dde6cac75e0b0902778978d3b1646ca9f438654395a362cb21d9ad34b24acf"}, - {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:00646784f6cd993b1e1c0e7b0fdcbccc375d539db95555477771c27555e3c556"}, - {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:23598acb8ccaa3d1d875ef3b35cb6376535095e9405d91a3d57a8c7db5d29341"}, - {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7f41533d7e3cf9520065f610b41ac1c76bc2161415955fbcead4981b22c7611e"}, - {file = "pydantic_core-2.14.6.tar.gz", hash = "sha256:1fd0c1d395372843fba13a51c28e3bb9d59bd7aebfeb17358ffaaa1e4dbbe948"}, + {file = "pydantic_core-2.33.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3077cfdb6125cc8dab61b155fdd714663e401f0e6883f9632118ec12cf42df26"}, + {file = "pydantic_core-2.33.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8ffab8b2908d152e74862d276cf5017c81a2f3719f14e8e3e8d6b83fda863927"}, + {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5183e4f6a2d468787243ebcd70cf4098c247e60d73fb7d68d5bc1e1beaa0c4db"}, + {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = 
"sha256:398a38d323f37714023be1e0285765f0a27243a8b1506b7b7de87b647b517e48"}, + {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:87d3776f0001b43acebfa86f8c64019c043b55cc5a6a2e313d728b5c95b46969"}, + {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c566dd9c5f63d22226409553531f89de0cac55397f2ab8d97d6f06cfce6d947e"}, + {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0d5f3acc81452c56895e90643a625302bd6be351e7010664151cc55b7b97f89"}, + {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d3a07fadec2a13274a8d861d3d37c61e97a816beae717efccaa4b36dfcaadcde"}, + {file = "pydantic_core-2.33.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:f99aeda58dce827f76963ee87a0ebe75e648c72ff9ba1174a253f6744f518f65"}, + {file = "pydantic_core-2.33.1-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:902dbc832141aa0ec374f4310f1e4e7febeebc3256f00dc359a9ac3f264a45dc"}, + {file = "pydantic_core-2.33.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fe44d56aa0b00d66640aa84a3cbe80b7a3ccdc6f0b1ca71090696a6d4777c091"}, + {file = "pydantic_core-2.33.1-cp310-cp310-win32.whl", hash = "sha256:ed3eb16d51257c763539bde21e011092f127a2202692afaeaccb50db55a31383"}, + {file = "pydantic_core-2.33.1-cp310-cp310-win_amd64.whl", hash = "sha256:694ad99a7f6718c1a498dc170ca430687a39894a60327f548e02a9c7ee4b6504"}, + {file = "pydantic_core-2.33.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:6e966fc3caaf9f1d96b349b0341c70c8d6573bf1bac7261f7b0ba88f96c56c24"}, + {file = "pydantic_core-2.33.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bfd0adeee563d59c598ceabddf2c92eec77abcb3f4a391b19aa7366170bd9e30"}, + {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91815221101ad3c6b507804178a7bb5cb7b2ead9ecd600041669c8d805ebd595"}, + {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9fea9c1869bb4742d174a57b4700c6dadea951df8b06de40c2fedb4f02931c2e"}, + {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1d20eb4861329bb2484c021b9d9a977566ab16d84000a57e28061151c62b349a"}, + {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0fb935c5591573ae3201640579f30128ccc10739b45663f93c06796854405505"}, + {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c964fd24e6166420d18fb53996d8c9fd6eac9bf5ae3ec3d03015be4414ce497f"}, + {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:681d65e9011f7392db5aa002b7423cc442d6a673c635668c227c6c8d0e5a4f77"}, + {file = "pydantic_core-2.33.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e100c52f7355a48413e2999bfb4e139d2977a904495441b374f3d4fb4a170961"}, + {file = "pydantic_core-2.33.1-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:048831bd363490be79acdd3232f74a0e9951b11b2b4cc058aeb72b22fdc3abe1"}, + {file = "pydantic_core-2.33.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:bdc84017d28459c00db6f918a7272a5190bec3090058334e43a76afb279eac7c"}, + {file = "pydantic_core-2.33.1-cp311-cp311-win32.whl", hash = "sha256:32cd11c5914d1179df70406427097c7dcde19fddf1418c787540f4b730289896"}, + {file = "pydantic_core-2.33.1-cp311-cp311-win_amd64.whl", hash = 
"sha256:2ea62419ba8c397e7da28a9170a16219d310d2cf4970dbc65c32faf20d828c83"}, + {file = "pydantic_core-2.33.1-cp311-cp311-win_arm64.whl", hash = "sha256:fc903512177361e868bc1f5b80ac8c8a6e05fcdd574a5fb5ffeac5a9982b9e89"}, + {file = "pydantic_core-2.33.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:1293d7febb995e9d3ec3ea09caf1a26214eec45b0f29f6074abb004723fc1de8"}, + {file = "pydantic_core-2.33.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:99b56acd433386c8f20be5c4000786d1e7ca0523c8eefc995d14d79c7a081498"}, + {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:35a5ec3fa8c2fe6c53e1b2ccc2454398f95d5393ab398478f53e1afbbeb4d939"}, + {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b172f7b9d2f3abc0efd12e3386f7e48b576ef309544ac3a63e5e9cdd2e24585d"}, + {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9097b9f17f91eea659b9ec58148c0747ec354a42f7389b9d50701610d86f812e"}, + {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cc77ec5b7e2118b152b0d886c7514a4653bcb58c6b1d760134a9fab915f777b3"}, + {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5e3d15245b08fa4a84cefc6c9222e6f37c98111c8679fbd94aa145f9a0ae23d"}, + {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ef99779001d7ac2e2461d8ab55d3373fe7315caefdbecd8ced75304ae5a6fc6b"}, + {file = "pydantic_core-2.33.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:fc6bf8869e193855e8d91d91f6bf59699a5cdfaa47a404e278e776dd7f168b39"}, + {file = "pydantic_core-2.33.1-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:b1caa0bc2741b043db7823843e1bde8aaa58a55a58fda06083b0569f8b45693a"}, + {file = "pydantic_core-2.33.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ec259f62538e8bf364903a7d0d0239447059f9434b284f5536e8402b7dd198db"}, + {file = "pydantic_core-2.33.1-cp312-cp312-win32.whl", hash = "sha256:e14f369c98a7c15772b9da98987f58e2b509a93235582838bd0d1d8c08b68fda"}, + {file = "pydantic_core-2.33.1-cp312-cp312-win_amd64.whl", hash = "sha256:1c607801d85e2e123357b3893f82c97a42856192997b95b4d8325deb1cd0c5f4"}, + {file = "pydantic_core-2.33.1-cp312-cp312-win_arm64.whl", hash = "sha256:8d13f0276806ee722e70a1c93da19748594f19ac4299c7e41237fc791d1861ea"}, + {file = "pydantic_core-2.33.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:70af6a21237b53d1fe7b9325b20e65cbf2f0a848cf77bed492b029139701e66a"}, + {file = "pydantic_core-2.33.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:282b3fe1bbbe5ae35224a0dbd05aed9ccabccd241e8e6b60370484234b456266"}, + {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b315e596282bbb5822d0c7ee9d255595bd7506d1cb20c2911a4da0b970187d3"}, + {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1dfae24cf9921875ca0ca6a8ecb4bb2f13c855794ed0d468d6abbec6e6dcd44a"}, + {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6dd8ecfde08d8bfadaea669e83c63939af76f4cf5538a72597016edfa3fad516"}, + {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2f593494876eae852dc98c43c6f260f45abdbfeec9e4324e31a481d948214764"}, + {file = 
"pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:948b73114f47fd7016088e5186d13faf5e1b2fe83f5e320e371f035557fd264d"}, + {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e11f3864eb516af21b01e25fac915a82e9ddad3bb0fb9e95a246067398b435a4"}, + {file = "pydantic_core-2.33.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:549150be302428b56fdad0c23c2741dcdb5572413776826c965619a25d9c6bde"}, + {file = "pydantic_core-2.33.1-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:495bc156026efafd9ef2d82372bd38afce78ddd82bf28ef5276c469e57c0c83e"}, + {file = "pydantic_core-2.33.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ec79de2a8680b1a67a07490bddf9636d5c2fab609ba8c57597e855fa5fa4dacd"}, + {file = "pydantic_core-2.33.1-cp313-cp313-win32.whl", hash = "sha256:ee12a7be1742f81b8a65b36c6921022301d466b82d80315d215c4c691724986f"}, + {file = "pydantic_core-2.33.1-cp313-cp313-win_amd64.whl", hash = "sha256:ede9b407e39949d2afc46385ce6bd6e11588660c26f80576c11c958e6647bc40"}, + {file = "pydantic_core-2.33.1-cp313-cp313-win_arm64.whl", hash = "sha256:aa687a23d4b7871a00e03ca96a09cad0f28f443690d300500603bd0adba4b523"}, + {file = "pydantic_core-2.33.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:401d7b76e1000d0dd5538e6381d28febdcacb097c8d340dde7d7fc6e13e9f95d"}, + {file = "pydantic_core-2.33.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7aeb055a42d734c0255c9e489ac67e75397d59c6fbe60d155851e9782f276a9c"}, + {file = "pydantic_core-2.33.1-cp313-cp313t-win_amd64.whl", hash = "sha256:338ea9b73e6e109f15ab439e62cb3b78aa752c7fd9536794112e14bee02c8d18"}, + {file = "pydantic_core-2.33.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:5ab77f45d33d264de66e1884fca158bc920cb5e27fd0764a72f72f5756ae8bdb"}, + {file = "pydantic_core-2.33.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e7aaba1b4b03aaea7bb59e1b5856d734be011d3e6d98f5bcaa98cb30f375f2ad"}, + {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7fb66263e9ba8fea2aa85e1e5578980d127fb37d7f2e292773e7bc3a38fb0c7b"}, + {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3f2648b9262607a7fb41d782cc263b48032ff7a03a835581abbf7a3bec62bcf5"}, + {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:723c5630c4259400818b4ad096735a829074601805d07f8cafc366d95786d331"}, + {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d100e3ae783d2167782391e0c1c7a20a31f55f8015f3293647544df3f9c67824"}, + {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:177d50460bc976a0369920b6c744d927b0ecb8606fb56858ff542560251b19e5"}, + {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a3edde68d1a1f9af1273b2fe798997b33f90308fb6d44d8550c89fc6a3647cf6"}, + {file = "pydantic_core-2.33.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a62c3c3ef6a7e2c45f7853b10b5bc4ddefd6ee3cd31024754a1a5842da7d598d"}, + {file = "pydantic_core-2.33.1-cp39-cp39-musllinux_1_1_armv7l.whl", hash = "sha256:c91dbb0ab683fa0cd64a6e81907c8ff41d6497c346890e26b23de7ee55353f96"}, + {file = "pydantic_core-2.33.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:9f466e8bf0a62dc43e068c12166281c2eca72121dd2adc1040f3aa1e21ef8599"}, + {file = "pydantic_core-2.33.1-cp39-cp39-win32.whl", hash = 
"sha256:ab0277cedb698749caada82e5d099dc9fed3f906a30d4c382d1a21725777a1e5"}, + {file = "pydantic_core-2.33.1-cp39-cp39-win_amd64.whl", hash = "sha256:5773da0ee2d17136b1f1c6fbde543398d452a6ad2a7b54ea1033e2daa739b8d2"}, + {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5c834f54f8f4640fd7e4b193f80eb25a0602bba9e19b3cd2fc7ffe8199f5ae02"}, + {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:049e0de24cf23766f12cc5cc71d8abc07d4a9deb9061b334b62093dedc7cb068"}, + {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a28239037b3d6f16916a4c831a5a0eadf856bdd6d2e92c10a0da3a59eadcf3e"}, + {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d3da303ab5f378a268fa7d45f37d7d85c3ec19769f28d2cc0c61826a8de21fe"}, + {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:25626fb37b3c543818c14821afe0fd3830bc327a43953bc88db924b68c5723f1"}, + {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:3ab2d36e20fbfcce8f02d73c33a8a7362980cff717926bbae030b93ae46b56c7"}, + {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:2f9284e11c751b003fd4215ad92d325d92c9cb19ee6729ebd87e3250072cdcde"}, + {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:048c01eee07d37cbd066fc512b9d8b5ea88ceeb4e629ab94b3e56965ad655add"}, + {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:5ccd429694cf26af7997595d627dd2637e7932214486f55b8a357edaac9dae8c"}, + {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3a371dc00282c4b84246509a5ddc808e61b9864aa1eae9ecc92bb1268b82db4a"}, + {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:f59295ecc75a1788af8ba92f2e8c6eeaa5a94c22fc4d151e8d9638814f85c8fc"}, + {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08530b8ac922003033f399128505f513e30ca770527cc8bbacf75a84fcc2c74b"}, + {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bae370459da6a5466978c0eacf90690cb57ec9d533f8e63e564ef3822bfa04fe"}, + {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e3de2777e3b9f4d603112f78006f4ae0acb936e95f06da6cb1a45fbad6bdb4b5"}, + {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:3a64e81e8cba118e108d7126362ea30e021291b7805d47e4896e52c791be2761"}, + {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:52928d8c1b6bda03cc6d811e8923dffc87a2d3c8b3bfd2ce16471c7147a24850"}, + {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:1b30d92c9412beb5ac6b10a3eb7ef92ccb14e3f2a8d7732e2d739f58b3aa7544"}, + {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:f995719707e0e29f0f41a8aa3bcea6e761a36c9136104d3189eafb83f5cec5e5"}, + {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:7edbc454a29fc6aeae1e1eecba4f07b63b8d76e76a748532233c4c167b4cb9ea"}, + {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:ad05b683963f69a1d5d2c2bdab1274a31221ca737dbbceaa32bcb67359453cdd"}, + {file = 
"pydantic_core-2.33.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:df6a94bf9452c6da9b5d76ed229a5683d0306ccb91cca8e1eea883189780d568"}, + {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7965c13b3967909a09ecc91f21d09cfc4576bf78140b988904e94f130f188396"}, + {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3f1fdb790440a34f6ecf7679e1863b825cb5ffde858a9197f851168ed08371e5"}, + {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:5277aec8d879f8d05168fdd17ae811dd313b8ff894aeeaf7cd34ad28b4d77e33"}, + {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:8ab581d3530611897d863d1a649fb0644b860286b4718db919bfd51ece41f10b"}, + {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:0483847fa9ad5e3412265c1bd72aad35235512d9ce9d27d81a56d935ef489672"}, + {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:de9e06abe3cc5ec6a2d5f75bc99b0bdca4f5c719a5b34026f8c57efbdecd2ee3"}, + {file = "pydantic_core-2.33.1.tar.gz", hash = "sha256:bcc9c6fdb0ced789245b02b7d6603e17d1563064ddcfc36f046b61c0c05dd9df"}, ] [package.dependencies] @@ -844,134 +1009,139 @@ typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" [[package]] name = "pytest" -version = "6.2.5" +version = "8.3.5" description = "pytest: simple powerful testing with Python" optional = false -python-versions = ">=3.6" +python-versions = ">=3.8" +groups = ["dev"] files = [ - {file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"}, - {file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"}, + {file = "pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820"}, + {file = "pytest-8.3.5.tar.gz", hash = "sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845"}, ] [package.dependencies] -atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""} -attrs = ">=19.2.0" colorama = {version = "*", markers = "sys_platform == \"win32\""} -importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} iniconfig = "*" packaging = "*" -pluggy = ">=0.12,<2.0" -py = ">=1.8.2" -toml = "*" +pluggy = ">=1.5,<2" +tomli = {version = ">=1", markers = "python_version < \"3.11\""} [package.extras] -testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"] +dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] [[package]] name = "pytest-asyncio" -version = "0.17.2" +version = "0.26.0" description = "Pytest support for asyncio" optional = false -python-versions = ">=3.7" +python-versions = ">=3.9" +groups = ["dev"] files = [ - {file = "pytest-asyncio-0.17.2.tar.gz", hash = "sha256:6d895b02432c028e6957d25fc936494e78c6305736e785d9fee408b1efbc7ff4"}, - {file = "pytest_asyncio-0.17.2-py3-none-any.whl", hash = "sha256:e0fe5dbea40516b661ef1bcfe0bd9461c2847c4ef4bb40012324f2454fb7d56d"}, + {file = "pytest_asyncio-0.26.0-py3-none-any.whl", hash = "sha256:7b51ed894f4fbea1340262bdae5135797ebbe21d8638978e35d31c6d19f72fb0"}, + {file = "pytest_asyncio-0.26.0.tar.gz", hash = 
"sha256:c4df2a697648241ff39e7f0e4a73050b03f123f760673956cf0d72a4990e312f"}, ] [package.dependencies] -pytest = ">=6.1.0" -typing-extensions = {version = ">=4.0", markers = "python_version < \"3.8\""} +pytest = ">=8.2,<9" +typing-extensions = {version = ">=4.12", markers = "python_version < \"3.10\""} [package.extras] -testing = ["coverage (==6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (==0.931)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] name = "pytest-cov" -version = "3.0.0" +version = "6.1.1" description = "Pytest plugin for measuring coverage." optional = false -python-versions = ">=3.6" +python-versions = ">=3.9" +groups = ["dev"] files = [ - {file = "pytest-cov-3.0.0.tar.gz", hash = "sha256:e7f0f5b1617d2210a2cabc266dfe2f4c75a8d32fb89eafb7ad9d06f6d076d470"}, - {file = "pytest_cov-3.0.0-py3-none-any.whl", hash = "sha256:578d5d15ac4a25e5f961c938b85a05b09fdaae9deef3bb6de9a6e766622ca7a6"}, + {file = "pytest_cov-6.1.1-py3-none-any.whl", hash = "sha256:bddf29ed2d0ab6f4df17b4c55b0a657287db8684af9c42ea546b21b1041b3dde"}, + {file = "pytest_cov-6.1.1.tar.gz", hash = "sha256:46935f7aaefba760e716c2ebfbe1c216240b9592966e7da99ea8292d4d3e2a0a"}, ] [package.dependencies] -coverage = {version = ">=5.2.1", extras = ["toml"]} +coverage = {version = ">=7.5", extras = ["toml"]} pytest = ">=4.6" [package.extras] -testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"] +testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"] [[package]] name = "pyyaml" -version = "6.0.1" +version = "6.0.2" description = "YAML parser and emitter for Python" optional = false -python-versions = ">=3.6" +python-versions = ">=3.8" +groups = ["main"] files = [ - {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, - {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, - {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, - {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, - {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, - {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, - {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, - {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, - {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, - {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, - {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, - {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, - {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, - {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, - {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, - {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, - {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, - {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, - {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, - {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, - {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, - {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, - {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, - {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, - {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, - {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, - {file = 
"PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, - {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, - {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, - {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, - {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, - {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, - {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, - {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, - {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, - {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, - {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, - {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, - {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, - {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, + {file = "PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086"}, + {file = "PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed"}, + {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180"}, + {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68"}, + {file = "PyYAML-6.0.2-cp310-cp310-win32.whl", hash = "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99"}, 
+ {file = "PyYAML-6.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e"}, + {file = "PyYAML-6.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774"}, + {file = "PyYAML-6.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85"}, + {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4"}, + {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e"}, + {file = "PyYAML-6.0.2-cp311-cp311-win32.whl", hash = "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5"}, + {file = "PyYAML-6.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44"}, + {file = "PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab"}, + {file = "PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476"}, + {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48"}, + {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b"}, + {file = "PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4"}, + {file = "PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8"}, + {file = "PyYAML-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba"}, + {file = "PyYAML-6.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5"}, + {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc"}, + {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652"}, + {file = "PyYAML-6.0.2-cp313-cp313-win32.whl", hash = "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183"}, + {file = "PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563"}, + {file = "PyYAML-6.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:24471b829b3bf607e04e88d79542a9d48bb037c2267d7927a874e6c205ca7e9a"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7fded462629cfa4b685c5416b949ebad6cec74af5e2d42905d41e257e0869f5"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d84a1718ee396f54f3a086ea0a66d8e552b2ab2017ef8b420e92edbc841c352d"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9056c1ecd25795207ad294bcf39f2db3d845767be0ea6e6a34d856f006006083"}, + {file = "PyYAML-6.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:82d09873e40955485746739bcb8b4586983670466c23382c19cffecbf1fd8706"}, + {file = "PyYAML-6.0.2-cp38-cp38-win32.whl", hash = "sha256:43fa96a3ca0d6b1812e01ced1044a003533c47f6ee8aca31724f78e93ccc089a"}, + {file = "PyYAML-6.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff"}, + {file = "PyYAML-6.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d"}, + {file = "PyYAML-6.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19"}, + {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e"}, + {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725"}, + {file = "PyYAML-6.0.2-cp39-cp39-win32.whl", hash = "sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631"}, + {file = "PyYAML-6.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8"}, + {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"}, ] [[package]] name = "requests" -version = "2.31.0" +version = "2.32.3" description = "Python HTTP for Humans." 
optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" +groups = ["main"] files = [ - {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, - {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"}, + {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"}, + {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"}, ] [package.dependencies] @@ -984,180 +1154,220 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] -[[package]] -name = "toml" -version = "0.10.2" -description = "Python Library for Tom's Obvious, Minimal Language" -optional = false -python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" -files = [ - {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"}, - {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, -] - [[package]] name = "tomli" -version = "2.0.1" +version = "2.2.1" description = "A lil' TOML parser" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" +groups = ["dev"] +markers = "python_full_version <= \"3.11.0a6\"" files = [ - {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, - {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, + {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, + {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"}, + {file = "tomli-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a"}, + {file = "tomli-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee"}, + {file = "tomli-2.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e"}, + {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4"}, + {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106"}, + {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8"}, + {file = "tomli-2.2.1-cp311-cp311-win32.whl", hash = "sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff"}, + {file = "tomli-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b"}, + {file = "tomli-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4a8f6e44de52d5e6c657c9fe83b562f5f4256d8ebbfe4ff922c495620a7f6cea"}, + {file = "tomli-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d57ca8095a641b8237d5b079147646153d22552f1c637fd3ba7f4b0b29167a8"}, + {file = 
"tomli-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e340144ad7ae1533cb897d406382b4b6fede8890a03738ff1683af800d54192"}, + {file = "tomli-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db2b95f9de79181805df90bedc5a5ab4c165e6ec3fe99f970d0e302f384ad222"}, + {file = "tomli-2.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40741994320b232529c802f8bc86da4e1aa9f413db394617b9a256ae0f9a7f77"}, + {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:400e720fe168c0f8521520190686ef8ef033fb19fc493da09779e592861b78c6"}, + {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:02abe224de6ae62c19f090f68da4e27b10af2b93213d36cf44e6e1c5abd19fdd"}, + {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b82ebccc8c8a36f2094e969560a1b836758481f3dc360ce9a3277c65f374285e"}, + {file = "tomli-2.2.1-cp312-cp312-win32.whl", hash = "sha256:889f80ef92701b9dbb224e49ec87c645ce5df3fa2cc548664eb8a25e03127a98"}, + {file = "tomli-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:7fc04e92e1d624a4a63c76474610238576942d6b8950a2d7f908a340494e67e4"}, + {file = "tomli-2.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f4039b9cbc3048b2416cc57ab3bda989a6fcf9b36cf8937f01a6e731b64f80d7"}, + {file = "tomli-2.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:286f0ca2ffeeb5b9bd4fcc8d6c330534323ec51b2f52da063b11c502da16f30c"}, + {file = "tomli-2.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a92ef1a44547e894e2a17d24e7557a5e85a9e1d0048b0b5e7541f76c5032cb13"}, + {file = "tomli-2.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9316dc65bed1684c9a98ee68759ceaed29d229e985297003e494aa825ebb0281"}, + {file = "tomli-2.2.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e85e99945e688e32d5a35c1ff38ed0b3f41f43fad8df0bdf79f72b2ba7bc5272"}, + {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ac065718db92ca818f8d6141b5f66369833d4a80a9d74435a268c52bdfa73140"}, + {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:d920f33822747519673ee656a4b6ac33e382eca9d331c87770faa3eef562aeb2"}, + {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a198f10c4d1b1375d7687bc25294306e551bf1abfa4eace6650070a5c1ae2744"}, + {file = "tomli-2.2.1-cp313-cp313-win32.whl", hash = "sha256:d3f5614314d758649ab2ab3a62d4f2004c825922f9e370b29416484086b264ec"}, + {file = "tomli-2.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:a38aa0308e754b0e3c67e344754dff64999ff9b513e691d0e786265c93583c69"}, + {file = "tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc"}, + {file = "tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff"}, ] [[package]] name = "tqdm" -version = "4.66.1" +version = "4.67.1" description = "Fast, Extensible Progress Meter" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ - {file = "tqdm-4.66.1-py3-none-any.whl", hash = "sha256:d302b3c5b53d47bce91fea46679d9c3c6508cf6332229aa1e7d8653723793386"}, - {file = "tqdm-4.66.1.tar.gz", hash = "sha256:d88e651f9db8d8551a62556d3cff9e3034274ca5d66e93197cf2490e2dcb69c7"}, + {file = "tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2"}, + {file = 
"tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2"}, ] [package.dependencies] colorama = {version = "*", markers = "platform_system == \"Windows\""} [package.extras] -dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"] +dev = ["nbval", "pytest (>=6)", "pytest-asyncio (>=0.24)", "pytest-cov", "pytest-timeout"] +discord = ["requests"] notebook = ["ipywidgets (>=6)"] slack = ["slack-sdk"] telegram = ["requests"] [[package]] name = "typing-extensions" -version = "4.7.1" -description = "Backported and Experimental Type Hints for Python 3.7+" +version = "4.13.2" +description = "Backported and Experimental Type Hints for Python 3.8+" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" +groups = ["main", "dev"] files = [ - {file = "typing_extensions-4.7.1-py3-none-any.whl", hash = "sha256:440d5dd3af93b060174bf433bccd69b0babc3b15b1a8dca43789fd7f61514b36"}, - {file = "typing_extensions-4.7.1.tar.gz", hash = "sha256:b75ddc264f0ba5615db7ba217daeb99701ad295353c45f9e95963337ceeeffb2"}, + {file = "typing_extensions-4.13.2-py3-none-any.whl", hash = "sha256:a439e7c04b49fec3e5d3e2beaa21755cadbbdc391694e28ccdd36ca4a1408f8c"}, + {file = "typing_extensions-4.13.2.tar.gz", hash = "sha256:e6c81219bd689f51865d9e372991c540bda33a0379d5573cddb9a3a23f7caaef"}, ] +markers = {dev = "python_version < \"3.10\""} + +[[package]] +name = "typing-inspection" +version = "0.4.0" +description = "Runtime typing introspection tools" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "typing_inspection-0.4.0-py3-none-any.whl", hash = "sha256:50e72559fcd2a6367a19f7a7e610e6afcb9fac940c650290eed893d61386832f"}, + {file = "typing_inspection-0.4.0.tar.gz", hash = "sha256:9765c87de36671694a67904bf2c96e395be9c6439bb6c87b5142569dcdd65122"}, +] + +[package.dependencies] +typing-extensions = ">=4.12.0" [[package]] name = "urllib3" -version = "2.0.5" +version = "2.4.0" description = "HTTP library with thread-safe connection pooling, file post, and more." 
optional = false -python-versions = ">=3.7" +python-versions = ">=3.9" +groups = ["main"] files = [ - {file = "urllib3-2.0.5-py3-none-any.whl", hash = "sha256:ef16afa8ba34a1f989db38e1dbbe0c302e4289a47856990d0682e374563ce35e"}, - {file = "urllib3-2.0.5.tar.gz", hash = "sha256:13abf37382ea2ce6fb744d4dad67838eec857c9f4f57009891805e0b5e123594"}, + {file = "urllib3-2.4.0-py3-none-any.whl", hash = "sha256:4e16665048960a0900c702d4a66415956a584919c03361cac9f1df5c5dd7e813"}, + {file = "urllib3-2.4.0.tar.gz", hash = "sha256:414bc6535b787febd7567804cc015fee39daab8ad86268f1310a9250697de466"}, ] [package.extras] brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] -secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"] +h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] [[package]] name = "yarl" -version = "1.9.2" +version = "1.19.0" description = "Yet another URL library" optional = false -python-versions = ">=3.7" +python-versions = ">=3.9" +groups = ["main"] files = [ - {file = "yarl-1.9.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8c2ad583743d16ddbdf6bb14b5cd76bf43b0d0006e918809d5d4ddf7bde8dd82"}, - {file = "yarl-1.9.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:82aa6264b36c50acfb2424ad5ca537a2060ab6de158a5bd2a72a032cc75b9eb8"}, - {file = "yarl-1.9.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c0c77533b5ed4bcc38e943178ccae29b9bcf48ffd1063f5821192f23a1bd27b9"}, - {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee4afac41415d52d53a9833ebae7e32b344be72835bbb589018c9e938045a560"}, - {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9bf345c3a4f5ba7f766430f97f9cc1320786f19584acc7086491f45524a551ac"}, - {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a96c19c52ff442a808c105901d0bdfd2e28575b3d5f82e2f5fd67e20dc5f4ea"}, - {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:891c0e3ec5ec881541f6c5113d8df0315ce5440e244a716b95f2525b7b9f3608"}, - {file = "yarl-1.9.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c3a53ba34a636a256d767c086ceb111358876e1fb6b50dfc4d3f4951d40133d5"}, - {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:566185e8ebc0898b11f8026447eacd02e46226716229cea8db37496c8cdd26e0"}, - {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:2b0738fb871812722a0ac2154be1f049c6223b9f6f22eec352996b69775b36d4"}, - {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:32f1d071b3f362c80f1a7d322bfd7b2d11e33d2adf395cc1dd4df36c9c243095"}, - {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:e9fdc7ac0d42bc3ea78818557fab03af6181e076a2944f43c38684b4b6bed8e3"}, - {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:56ff08ab5df8429901ebdc5d15941b59f6253393cb5da07b4170beefcf1b2528"}, - {file = "yarl-1.9.2-cp310-cp310-win32.whl", hash = "sha256:8ea48e0a2f931064469bdabca50c2f578b565fc446f302a79ba6cc0ee7f384d3"}, - {file = "yarl-1.9.2-cp310-cp310-win_amd64.whl", hash = "sha256:50f33040f3836e912ed16d212f6cc1efb3231a8a60526a407aeb66c1c1956dde"}, - {file = "yarl-1.9.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:646d663eb2232d7909e6601f1a9107e66f9791f290a1b3dc7057818fe44fc2b6"}, - {file = "yarl-1.9.2-cp311-cp311-macosx_10_9_x86_64.whl", 
hash = "sha256:aff634b15beff8902d1f918012fc2a42e0dbae6f469fce134c8a0dc51ca423bb"}, - {file = "yarl-1.9.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a83503934c6273806aed765035716216cc9ab4e0364f7f066227e1aaea90b8d0"}, - {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b25322201585c69abc7b0e89e72790469f7dad90d26754717f3310bfe30331c2"}, - {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:22a94666751778629f1ec4280b08eb11815783c63f52092a5953faf73be24191"}, - {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8ec53a0ea2a80c5cd1ab397925f94bff59222aa3cf9c6da938ce05c9ec20428d"}, - {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:159d81f22d7a43e6eabc36d7194cb53f2f15f498dbbfa8edc8a3239350f59fe7"}, - {file = "yarl-1.9.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:832b7e711027c114d79dffb92576acd1bd2decc467dec60e1cac96912602d0e6"}, - {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:95d2ecefbcf4e744ea952d073c6922e72ee650ffc79028eb1e320e732898d7e8"}, - {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d4e2c6d555e77b37288eaf45b8f60f0737c9efa3452c6c44626a5455aeb250b9"}, - {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:783185c75c12a017cc345015ea359cc801c3b29a2966c2655cd12b233bf5a2be"}, - {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:b8cc1863402472f16c600e3e93d542b7e7542a540f95c30afd472e8e549fc3f7"}, - {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:822b30a0f22e588b32d3120f6d41e4ed021806418b4c9f0bc3048b8c8cb3f92a"}, - {file = "yarl-1.9.2-cp311-cp311-win32.whl", hash = "sha256:a60347f234c2212a9f0361955007fcf4033a75bf600a33c88a0a8e91af77c0e8"}, - {file = "yarl-1.9.2-cp311-cp311-win_amd64.whl", hash = "sha256:be6b3fdec5c62f2a67cb3f8c6dbf56bbf3f61c0f046f84645cd1ca73532ea051"}, - {file = "yarl-1.9.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:38a3928ae37558bc1b559f67410df446d1fbfa87318b124bf5032c31e3447b74"}, - {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac9bb4c5ce3975aeac288cfcb5061ce60e0d14d92209e780c93954076c7c4367"}, - {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3da8a678ca8b96c8606bbb8bfacd99a12ad5dd288bc6f7979baddd62f71c63ef"}, - {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:13414591ff516e04fcdee8dc051c13fd3db13b673c7a4cb1350e6b2ad9639ad3"}, - {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf74d08542c3a9ea97bb8f343d4fcbd4d8f91bba5ec9d5d7f792dbe727f88938"}, - {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6e7221580dc1db478464cfeef9b03b95c5852cc22894e418562997df0d074ccc"}, - {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:494053246b119b041960ddcd20fd76224149cfea8ed8777b687358727911dd33"}, - {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:52a25809fcbecfc63ac9ba0c0fb586f90837f5425edfd1ec9f3372b119585e45"}, - {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:e65610c5792870d45d7b68c677681376fcf9cc1c289f23e8e8b39c1485384185"}, - {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_s390x.whl", 
hash = "sha256:1b1bba902cba32cdec51fca038fd53f8beee88b77efc373968d1ed021024cc04"}, - {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:662e6016409828ee910f5d9602a2729a8a57d74b163c89a837de3fea050c7582"}, - {file = "yarl-1.9.2-cp37-cp37m-win32.whl", hash = "sha256:f364d3480bffd3aa566e886587eaca7c8c04d74f6e8933f3f2c996b7f09bee1b"}, - {file = "yarl-1.9.2-cp37-cp37m-win_amd64.whl", hash = "sha256:6a5883464143ab3ae9ba68daae8e7c5c95b969462bbe42e2464d60e7e2698368"}, - {file = "yarl-1.9.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5610f80cf43b6202e2c33ba3ec2ee0a2884f8f423c8f4f62906731d876ef4fac"}, - {file = "yarl-1.9.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b9a4e67ad7b646cd6f0938c7ebfd60e481b7410f574c560e455e938d2da8e0f4"}, - {file = "yarl-1.9.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:83fcc480d7549ccebe9415d96d9263e2d4226798c37ebd18c930fce43dfb9574"}, - {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5fcd436ea16fee7d4207c045b1e340020e58a2597301cfbcfdbe5abd2356c2fb"}, - {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84e0b1599334b1e1478db01b756e55937d4614f8654311eb26012091be109d59"}, - {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3458a24e4ea3fd8930e934c129b676c27452e4ebda80fbe47b56d8c6c7a63a9e"}, - {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:838162460b3a08987546e881a2bfa573960bb559dfa739e7800ceeec92e64417"}, - {file = "yarl-1.9.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f4e2d08f07a3d7d3e12549052eb5ad3eab1c349c53ac51c209a0e5991bbada78"}, - {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:de119f56f3c5f0e2fb4dee508531a32b069a5f2c6e827b272d1e0ff5ac040333"}, - {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:149ddea5abf329752ea5051b61bd6c1d979e13fbf122d3a1f9f0c8be6cb6f63c"}, - {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:674ca19cbee4a82c9f54e0d1eee28116e63bc6fd1e96c43031d11cbab8b2afd5"}, - {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:9b3152f2f5677b997ae6c804b73da05a39daa6a9e85a512e0e6823d81cdad7cc"}, - {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5415d5a4b080dc9612b1b63cba008db84e908b95848369aa1da3686ae27b6d2b"}, - {file = "yarl-1.9.2-cp38-cp38-win32.whl", hash = "sha256:f7a3d8146575e08c29ed1cd287068e6d02f1c7bdff8970db96683b9591b86ee7"}, - {file = "yarl-1.9.2-cp38-cp38-win_amd64.whl", hash = "sha256:63c48f6cef34e6319a74c727376e95626f84ea091f92c0250a98e53e62c77c72"}, - {file = "yarl-1.9.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:75df5ef94c3fdc393c6b19d80e6ef1ecc9ae2f4263c09cacb178d871c02a5ba9"}, - {file = "yarl-1.9.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c027a6e96ef77d401d8d5a5c8d6bc478e8042f1e448272e8d9752cb0aff8b5c8"}, - {file = "yarl-1.9.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f3b078dbe227f79be488ffcfc7a9edb3409d018e0952cf13f15fd6512847f3f7"}, - {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:59723a029760079b7d991a401386390c4be5bfec1e7dd83e25a6a0881859e716"}, - {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b03917871bf859a81ccb180c9a2e6c1e04d2f6a51d953e6a5cdd70c93d4e5a2a"}, - {file = 
"yarl-1.9.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c1012fa63eb6c032f3ce5d2171c267992ae0c00b9e164efe4d73db818465fac3"}, - {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a74dcbfe780e62f4b5a062714576f16c2f3493a0394e555ab141bf0d746bb955"}, - {file = "yarl-1.9.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8c56986609b057b4839968ba901944af91b8e92f1725d1a2d77cbac6972b9ed1"}, - {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:2c315df3293cd521033533d242d15eab26583360b58f7ee5d9565f15fee1bef4"}, - {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:b7232f8dfbd225d57340e441d8caf8652a6acd06b389ea2d3222b8bc89cbfca6"}, - {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:53338749febd28935d55b41bf0bcc79d634881195a39f6b2f767870b72514caf"}, - {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:066c163aec9d3d073dc9ffe5dd3ad05069bcb03fcaab8d221290ba99f9f69ee3"}, - {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8288d7cd28f8119b07dd49b7230d6b4562f9b61ee9a4ab02221060d21136be80"}, - {file = "yarl-1.9.2-cp39-cp39-win32.whl", hash = "sha256:b124e2a6d223b65ba8768d5706d103280914d61f5cae3afbc50fc3dfcc016623"}, - {file = "yarl-1.9.2-cp39-cp39-win_amd64.whl", hash = "sha256:61016e7d582bc46a5378ffdd02cd0314fb8ba52f40f9cf4d9a5e7dbef88dee18"}, - {file = "yarl-1.9.2.tar.gz", hash = "sha256:04ab9d4b9f587c06d801c2abfe9317b77cdf996c65a90d5e84ecc45010823571"}, + {file = "yarl-1.19.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0bae32f8ebd35c04d6528cedb4a26b8bf25339d3616b04613b97347f919b76d3"}, + {file = "yarl-1.19.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8015a076daf77823e7ebdcba474156587391dab4e70c732822960368c01251e6"}, + {file = "yarl-1.19.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9973ac95327f5d699eb620286c39365990b240031672b5c436a4cd00539596c5"}, + {file = "yarl-1.19.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd4b5fbd7b9dde785cfeb486b8cca211a0b138d4f3a7da27db89a25b3c482e5c"}, + {file = "yarl-1.19.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:75460740005de5a912b19f657848aef419387426a40f581b1dc9fac0eb9addb5"}, + {file = "yarl-1.19.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:57abd66ca913f2cfbb51eb3dbbbac3648f1f6983f614a4446e0802e241441d2a"}, + {file = "yarl-1.19.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:46ade37911b7c99ce28a959147cb28bffbd14cea9e7dd91021e06a8d2359a5aa"}, + {file = "yarl-1.19.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8346ec72ada749a6b5d82bff7be72578eab056ad7ec38c04f668a685abde6af0"}, + {file = "yarl-1.19.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7e4cb14a6ee5b6649ccf1c6d648b4da9220e8277d4d4380593c03cc08d8fe937"}, + {file = "yarl-1.19.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:66fc1c2926a73a2fb46e4b92e3a6c03904d9bc3a0b65e01cb7d2b84146a8bd3b"}, + {file = "yarl-1.19.0-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:5a70201dd1e0a4304849b6445a9891d7210604c27e67da59091d5412bc19e51c"}, + {file = "yarl-1.19.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e4807aab1bdeab6ae6f296be46337a260ae4b1f3a8c2fcd373e236b4b2b46efd"}, + {file = 
"yarl-1.19.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:ae584afe81a1de4c1bb06672481050f0d001cad13163e3c019477409f638f9b7"}, + {file = "yarl-1.19.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:30eaf4459df6e91f21b2999d1ee18f891bcd51e3cbe1de301b4858c84385895b"}, + {file = "yarl-1.19.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:0e617d45d03c8dec0dfce6f51f3e1b8a31aa81aaf4a4d1442fdb232bcf0c6d8c"}, + {file = "yarl-1.19.0-cp310-cp310-win32.whl", hash = "sha256:32ba32d0fa23893fd8ea8d05bdb05de6eb19d7f2106787024fd969f4ba5466cb"}, + {file = "yarl-1.19.0-cp310-cp310-win_amd64.whl", hash = "sha256:545575ecfcd465891b51546c2bcafdde0acd2c62c2097d8d71902050b20e4922"}, + {file = "yarl-1.19.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:163ff326680de5f6d4966954cf9e3fe1bf980f5fee2255e46e89b8cf0f3418b5"}, + {file = "yarl-1.19.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a626c4d9cca298d1be8625cff4b17004a9066330ac82d132bbda64a4c17c18d3"}, + {file = "yarl-1.19.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:961c3e401ea7f13d02b8bb7cb0c709152a632a6e14cdc8119e9c6ee5596cd45d"}, + {file = "yarl-1.19.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a39d7b807ab58e633ed760f80195cbd145b58ba265436af35f9080f1810dfe64"}, + {file = "yarl-1.19.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:c4228978fb59c6b10f60124ba8e311c26151e176df364e996f3f8ff8b93971b5"}, + {file = "yarl-1.19.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ba536b17ecf3c74a94239ec1137a3ad3caea8c0e4deb8c8d2ffe847d870a8c5"}, + {file = "yarl-1.19.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a251e00e445d2e9df7b827c9843c0b87f58a3254aaa3f162fb610747491fe00f"}, + {file = "yarl-1.19.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9b92431d8b4d4ca5ccbfdbac95b05a3a6cd70cd73aa62f32f9627acfde7549c"}, + {file = "yarl-1.19.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ec2f56edaf476f70b5831bbd59700b53d9dd011b1f77cd4846b5ab5c5eafdb3f"}, + {file = "yarl-1.19.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:acf9b92c4245ac8b59bc7ec66a38d3dcb8d1f97fac934672529562bb824ecadb"}, + {file = "yarl-1.19.0-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:57711f1465c06fee8825b95c0b83e82991e6d9425f9a042c3c19070a70ac92bf"}, + {file = "yarl-1.19.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:528e86f5b1de0ad8dd758ddef4e0ed24f5d946d4a1cef80ffb2d4fca4e10f122"}, + {file = "yarl-1.19.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:3b77173663e075d9e5a57e09d711e9da2f3266be729ecca0b8ae78190990d260"}, + {file = "yarl-1.19.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:d8717924cf0a825b62b1a96fc7d28aab7f55a81bf5338b8ef41d7a76ab9223e9"}, + {file = "yarl-1.19.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0df9f0221a78d858793f40cbea3915c29f969c11366646a92ca47e080a14f881"}, + {file = "yarl-1.19.0-cp311-cp311-win32.whl", hash = "sha256:8b3ade62678ee2c7c10dcd6be19045135e9badad53108f7d2ed14896ee396045"}, + {file = "yarl-1.19.0-cp311-cp311-win_amd64.whl", hash = "sha256:0626ee31edb23ac36bdffe607231de2cca055ad3a5e2dc5da587ef8bc6a321bc"}, + {file = "yarl-1.19.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:7b687c334da3ff8eab848c9620c47a253d005e78335e9ce0d6868ed7e8fd170b"}, + {file = "yarl-1.19.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = 
"sha256:b0fe766febcf523a2930b819c87bb92407ae1368662c1bc267234e79b20ff894"}, + {file = "yarl-1.19.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:742ceffd3c7beeb2b20d47cdb92c513eef83c9ef88c46829f88d5b06be6734ee"}, + {file = "yarl-1.19.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2af682a1e97437382ee0791eacbf540318bd487a942e068e7e0a6c571fadbbd3"}, + {file = "yarl-1.19.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:63702f1a098d0eaaea755e9c9d63172be1acb9e2d4aeb28b187092bcc9ca2d17"}, + {file = "yarl-1.19.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3560dcba3c71ae7382975dc1e912ee76e50b4cd7c34b454ed620d55464f11876"}, + {file = "yarl-1.19.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:68972df6a0cc47c8abaf77525a76ee5c5f6ea9bbdb79b9565b3234ded3c5e675"}, + {file = "yarl-1.19.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5684e7ff93ea74e47542232bd132f608df4d449f8968fde6b05aaf9e08a140f9"}, + {file = "yarl-1.19.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8182ad422bfacdebd4759ce3adc6055c0c79d4740aea1104e05652a81cd868c6"}, + {file = "yarl-1.19.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:aee5b90a5a9b71ac57400a7bdd0feaa27c51e8f961decc8d412e720a004a1791"}, + {file = "yarl-1.19.0-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:8c0b2371858d5a814b08542d5d548adb03ff2d7ab32f23160e54e92250961a72"}, + {file = "yarl-1.19.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:cd430c2b7df4ae92498da09e9b12cad5bdbb140d22d138f9e507de1aa3edfea3"}, + {file = "yarl-1.19.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:a93208282c0ccdf73065fd76c6c129bd428dba5ff65d338ae7d2ab27169861a0"}, + {file = "yarl-1.19.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:b8179280cdeb4c36eb18d6534a328f9d40da60d2b96ac4a295c5f93e2799e9d9"}, + {file = "yarl-1.19.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:eda3c2b42dc0c389b7cfda2c4df81c12eeb552019e0de28bde8f913fc3d1fcf3"}, + {file = "yarl-1.19.0-cp312-cp312-win32.whl", hash = "sha256:57f3fed859af367b9ca316ecc05ce79ce327d6466342734305aa5cc380e4d8be"}, + {file = "yarl-1.19.0-cp312-cp312-win_amd64.whl", hash = "sha256:5507c1f7dd3d41251b67eecba331c8b2157cfd324849879bebf74676ce76aff7"}, + {file = "yarl-1.19.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:59281b9ed27bc410e0793833bcbe7fc149739d56ffa071d1e0fe70536a4f7b61"}, + {file = "yarl-1.19.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:d27a6482ad5e05e8bafd47bf42866f8a1c0c3345abcb48d4511b3c29ecc197dc"}, + {file = "yarl-1.19.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7a8e19fd5a6fdf19a91f2409665c7a089ffe7b9b5394ab33c0eec04cbecdd01f"}, + {file = "yarl-1.19.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cda34ab19099c3a1685ad48fe45172536610c312b993310b5f1ca3eb83453b36"}, + {file = "yarl-1.19.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:7908a25d33f94852b479910f9cae6cdb9e2a509894e8d5f416c8342c0253c397"}, + {file = "yarl-1.19.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e66c14d162bac94973e767b24de5d7e6c5153f7305a64ff4fcba701210bcd638"}, + {file = "yarl-1.19.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c03607bf932aa4cfae371e2dc9ca8b76faf031f106dac6a6ff1458418140c165"}, + {file = 
"yarl-1.19.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9931343d1c1f4e77421687b6b94bbebd8a15a64ab8279adf6fbb047eff47e536"}, + {file = "yarl-1.19.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:262087a8a0d73e1d169d45c2baf968126f93c97cf403e1af23a7d5455d52721f"}, + {file = "yarl-1.19.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:70f384921c24e703d249a6ccdabeb57dd6312b568b504c69e428a8dd3e8e68ca"}, + {file = "yarl-1.19.0-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:756b9ea5292a2c180d1fe782a377bc4159b3cfefaca7e41b5b0a00328ef62fa9"}, + {file = "yarl-1.19.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:cbeb9c145d534c240a63b6ecc8a8dd451faeb67b3dc61d729ec197bb93e29497"}, + {file = "yarl-1.19.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:087ae8f8319848c18e0d114d0f56131a9c017f29200ab1413b0137ad7c83e2ae"}, + {file = "yarl-1.19.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:362f5480ba527b6c26ff58cff1f229afe8b7fdd54ee5ffac2ab827c1a75fc71c"}, + {file = "yarl-1.19.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f408d4b4315e814e5c3668094e33d885f13c7809cbe831cbdc5b1bb8c7a448f4"}, + {file = "yarl-1.19.0-cp313-cp313-win32.whl", hash = "sha256:24e4c367ad69988a2283dd45ea88172561ca24b2326b9781e164eb46eea68345"}, + {file = "yarl-1.19.0-cp313-cp313-win_amd64.whl", hash = "sha256:0110f91c57ab43d1538dfa92d61c45e33b84df9257bd08fcfcda90cce931cbc9"}, + {file = "yarl-1.19.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:85ac908cd5a97bbd3048cca9f1bf37b932ea26c3885099444f34b0bf5d5e9fa6"}, + {file = "yarl-1.19.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6ba0931b559f1345df48a78521c31cfe356585670e8be22af84a33a39f7b9221"}, + {file = "yarl-1.19.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5bc503e1c1fee1b86bcb58db67c032957a52cae39fe8ddd95441f414ffbab83e"}, + {file = "yarl-1.19.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d995122dcaf180fd4830a9aa425abddab7c0246107c21ecca2fa085611fa7ce9"}, + {file = "yarl-1.19.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:217f69e60a14da4eed454a030ea8283f8fbd01a7d6d81e57efb865856822489b"}, + {file = "yarl-1.19.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aad67c8f13a4b79990082f72ef09c078a77de2b39899aabf3960a48069704973"}, + {file = "yarl-1.19.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dff065a1a8ed051d7e641369ba1ad030d5a707afac54cf4ede7069b959898835"}, + {file = "yarl-1.19.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ada882e26b16ee651ab6544ce956f2f4beaed38261238f67c2a96db748e17741"}, + {file = "yarl-1.19.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:67a56b1acc7093451ea2de0687aa3bd4e58d6b4ef6cbeeaad137b45203deaade"}, + {file = "yarl-1.19.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:e97d2f0a06b39e231e59ebab0e6eec45c7683b339e8262299ac952707bdf7688"}, + {file = "yarl-1.19.0-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:a5288adb7c59d0f54e4ad58d86fb06d4b26e08a59ed06d00a1aac978c0e32884"}, + {file = "yarl-1.19.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:1efbf4d03e6eddf5da27752e0b67a8e70599053436e9344d0969532baa99df53"}, + {file = "yarl-1.19.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:f228f42f29cc87db67020f7d71624102b2c837686e55317b16e1d3ef2747a993"}, + {file = 
"yarl-1.19.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:c515f7dd60ca724e4c62b34aeaa603188964abed2eb66bb8e220f7f104d5a187"}, + {file = "yarl-1.19.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:4815ec6d3d68a96557fa71bd36661b45ac773fb50e5cfa31a7e843edb098f060"}, + {file = "yarl-1.19.0-cp39-cp39-win32.whl", hash = "sha256:9fac2dd1c5ecb921359d9546bc23a6dcc18c6acd50c6d96f118188d68010f497"}, + {file = "yarl-1.19.0-cp39-cp39-win_amd64.whl", hash = "sha256:5864f539ce86b935053bfa18205fa08ce38e9a40ea4d51b19ce923345f0ed5db"}, + {file = "yarl-1.19.0-py3-none-any.whl", hash = "sha256:a727101eb27f66727576630d02985d8a065d09cd0b5fcbe38a5793f71b2a97ef"}, + {file = "yarl-1.19.0.tar.gz", hash = "sha256:01e02bb80ae0dbed44273c304095295106e1d9470460e773268a27d11e594892"}, ] [package.dependencies] idna = ">=2.0" multidict = ">=4.0" -typing-extensions = {version = ">=3.7.4", markers = "python_version < \"3.8\""} - -[[package]] -name = "zipp" -version = "3.15.0" -description = "Backport of pathlib-compatible object wrapper for zip files" -optional = false -python-versions = ">=3.7" -files = [ - {file = "zipp-3.15.0-py3-none-any.whl", hash = "sha256:48904fc76a60e542af151aded95726c1a5c34ed43ab4134b597665c86d7ad556"}, - {file = "zipp-3.15.0.tar.gz", hash = "sha256:112929ad649da941c23de50f356a2b5570c954b65150642bccdd66bf194d224b"}, -] - -[package.extras] -docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] +propcache = ">=0.2.1" [metadata] -lock-version = "2.0" -python-versions = "^3.7" -content-hash = "b7fab8703967f2616ea59a98a437cd30f97f0c8d2a06e399d688814a2a2c64f8" +lock-version = "2.1" +python-versions = "^3.9" +content-hash = "f136e898d37b7c7db1ccceb1822ade280d3542ca19cdd9dcf583cb9aefef11c6" diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 47ef9d717..1448d7618 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -11,15 +11,15 @@ repository = "https://github.com/huggingface/text-generation-inference" [tool.poetry.dependencies] -python = "^3.7" +python = "^3.9" pydantic = "> 2, < 3" -aiohttp = "^3.8" +aiohttp = "^3.11" huggingface-hub = ">= 0.12, < 1.0" -[tool.poetry.dev-dependencies] -pytest = "^6.2.5" -pytest-asyncio = "^0.17.2" -pytest-cov = "^3.0.0" +[tool.poetry.group.dev.dependencies] +pytest = "^8" +pytest-asyncio = "^0.26" +pytest-cov = "^6.0.0" [tool.pytest.ini_options] asyncio_mode = "auto" From 73e797528df1ceadabf687f1f15882f759f12a62 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Mon, 14 Apr 2025 22:13:53 +0530 Subject: [PATCH 3/3] L4 fixes (#3161) add fix --- router/src/config.rs | 9 ++++++--- router/src/lib.rs | 2 +- router/src/validation.rs | 6 +++++- server/text_generation_server/models/__init__.py | 1 - 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/router/src/config.rs b/router/src/config.rs index fcda26122..93b6f4fa4 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -229,10 +229,13 @@ impl Llama4 { pub fn pixel_shuffle_ratio(&self) -> f64 { self.vision_config.pixel_shuffle_ratio } - pub fn get_aspect_ratios(&self, height: usize, width: usize) -> (usize, usize) { + pub fn get_aspect_ratios( + &self, + height: usize, + width: usize, + max_chunks: usize, + ) -> (usize, usize) { let 
patch_size = self.vision_config.image_size; - // How to avoid hardcoding this? - let max_chunks = 15; let supported = find_supported_resolutions(max_chunks, patch_size); let (target_h, target_w) = get_best_fit(height, width, &supported, false); (target_h / patch_size, target_w / patch_size) diff --git a/router/src/lib.rs b/router/src/lib.rs index 50adb5cf6..3c1a01b3c 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -204,7 +204,7 @@ pub struct Gemma3Processor { #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Llama4Processor { #[serde(default)] - do_image_splitting: bool, + max_patches: usize, } #[derive(Debug, Clone, Deserialize, Default)] diff --git a/router/src/validation.rs b/router/src/validation.rs index dfe9dd4d2..b29391b77 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -698,10 +698,14 @@ fn image_tokens( let image_height = config.image_size(); let patch_size = config.patch_size(); let pixel_shuffle_ratio = config.pixel_shuffle_ratio(); + let max_patches = match preprocessor_config { + Some(HubPreprocessorConfig::Llama4Processor(cfg)) => cfg.max_patches, + _ => panic!("Expected Llama4Processor in preprocessor_config"), + }; let downsample_ratio = (1.0 / (pixel_shuffle_ratio * pixel_shuffle_ratio)).round() as usize; - let (ratio_h, ratio_w) = config.get_aspect_ratios(height, width); + let (ratio_h, ratio_w) = config.get_aspect_ratios(height, width, max_patches); let image_width = image_height; // Assuming pixel shape: [H][W][C] let num_patches_per_chunk = diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 84437bf32..291ee5fba 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -1041,7 +1041,6 @@ def get_model( trust_remote_code=trust_remote_code, processor_kwargs={ "use_fast": True, - "size": {"height": 336, "width": 336}, }, ) elif model_type == BAICHUAN: