mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 13:52:07 +00:00
Gaudi: clean cuda/rocm code in hpu backend, enable flat_hpu (#3113)
* clean cuda/rocm code in hpu backend, enable flat_hpu Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * fix TP in pageattn Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * adjust block table in hpu to improve performance Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * enable all the models. not tested yet Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * use tensor cache in hpu graph to avoid replay issue Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * add moe support, fix qwen/mistral/mixtral crash Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * fix phimoe issue Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * gpt_bigcode could also go pageattn Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * enable dbrx remove some unused code Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * multi-modality initial PR Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * adjust warmup and enable vlm Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * fix incorrect output in qwen2 idefics if hpu graph is used Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * remove unused quantization code and enable awq/gptq int4 Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * fix gptq issue Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * enable fp8 Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * warmup prefill remove model where pageattn is not used, set block table to None since it's not used Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * add warmup_decode Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * warmup decode Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * remove block_tables and prefill_cache_indices which will lead to dynamic shape Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * fix comment Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * missing gptj change... 
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * fix some issue Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * remove torch.where to fix incorrect output in hpu graph model Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * match the latest vllm_extension ops Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> --------- Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
9a8d0462e1
commit
d62c941c56
@ -95,7 +95,7 @@ RUN cd server && \
|
|||||||
pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
|
pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
|
||||||
BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
|
BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
|
||||||
pip install . --no-cache-dir
|
pip install . --no-cache-dir
|
||||||
|
RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git
|
||||||
# Install benchmarker
|
# Install benchmarker
|
||||||
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||||
# Install router
|
# Install router
|
||||||
|
@ -16,15 +16,9 @@ app = typer.Typer()
|
|||||||
|
|
||||||
|
|
||||||
class Quantization(str, Enum):
|
class Quantization(str, Enum):
|
||||||
bitsandbytes = "bitsandbytes"
|
|
||||||
bitsandbytes_nf4 = "bitsandbytes-nf4"
|
|
||||||
bitsandbytes_fp4 = "bitsandbytes-fp4"
|
|
||||||
gptq = "gptq"
|
gptq = "gptq"
|
||||||
awq = "awq"
|
awq = "awq"
|
||||||
eetq = "eetq"
|
|
||||||
exl2 = "exl2"
|
|
||||||
fp8 = "fp8"
|
fp8 = "fp8"
|
||||||
marlin = "marlin"
|
|
||||||
|
|
||||||
|
|
||||||
class Dtype(str, Enum):
|
class Dtype(str, Enum):
|
||||||
@ -105,6 +99,9 @@ def serve(
|
|||||||
"bitsandbytes",
|
"bitsandbytes",
|
||||||
"bitsandbytes-nf4",
|
"bitsandbytes-nf4",
|
||||||
"bitsandbytes-fp4",
|
"bitsandbytes-fp4",
|
||||||
|
"gptq",
|
||||||
|
"awq",
|
||||||
|
"fp8",
|
||||||
}:
|
}:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
|
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
|
||||||
@ -112,7 +109,7 @@ def serve(
|
|||||||
|
|
||||||
logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype))
|
logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype))
|
||||||
|
|
||||||
if sharded:
|
if sharded and os.getenv("ATTENTION", "default") not in {"paged"}:
|
||||||
tgi_file = Path(__file__).resolve().parent / "tgi_service.py"
|
tgi_file = Path(__file__).resolve().parent / "tgi_service.py"
|
||||||
num_shard = int(os.getenv("WORLD_SIZE", "1"))
|
num_shard = int(os.getenv("WORLD_SIZE", "1"))
|
||||||
logger.info("CLI SHARDED = {}".format(num_shard))
|
logger.info("CLI SHARDED = {}".format(num_shard))
|
||||||
|
@ -1,43 +1,28 @@
|
|||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from .common import (
|
||||||
import os
|
Seqlen,
|
||||||
|
HPUPagedAttentionMetadata,
|
||||||
|
trim_attn_metadata,
|
||||||
|
trim_seqlen_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
from .common import Seqlen
|
from .hpu import (
|
||||||
|
SUPPORTS_WINDOWING,
|
||||||
if os.getenv("USE_FLASH_ATTENTION", "true").lower() == "false":
|
|
||||||
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
|
||||||
if SYSTEM == "cuda":
|
|
||||||
from .cuda import (
|
|
||||||
attention,
|
attention,
|
||||||
paged_attention,
|
paged_attention,
|
||||||
reshape_and_cache,
|
|
||||||
SUPPORTS_WINDOWING,
|
|
||||||
PREFILL_IN_KV_CACHE,
|
|
||||||
)
|
)
|
||||||
elif SYSTEM == "rocm":
|
|
||||||
from .rocm import (
|
|
||||||
attention,
|
|
||||||
paged_attention,
|
|
||||||
reshape_and_cache,
|
|
||||||
PREFILL_IN_KV_CACHE,
|
|
||||||
SUPPORTS_WINDOWING,
|
|
||||||
)
|
|
||||||
elif SYSTEM == "ipex":
|
|
||||||
from .ipex import (
|
|
||||||
attention,
|
|
||||||
paged_attention,
|
|
||||||
reshape_and_cache,
|
|
||||||
PREFILL_IN_KV_CACHE,
|
|
||||||
SUPPORTS_WINDOWING,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
|
|
||||||
|
|
||||||
|
|
||||||
|
# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
|
||||||
|
from .kv_cache import KVCache, get_kv_scales
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"attention",
|
"attention",
|
||||||
|
"get_kv_scales",
|
||||||
"paged_attention",
|
"paged_attention",
|
||||||
"reshape_and_cache",
|
|
||||||
"PREFILL_IN_KV_CACHE",
|
|
||||||
"SUPPORTS_WINDOWING",
|
"SUPPORTS_WINDOWING",
|
||||||
|
"KVCache",
|
||||||
"Seqlen",
|
"Seqlen",
|
||||||
|
"HPUPagedAttentionMetadata",
|
||||||
|
"trim_seqlen_metadata",
|
||||||
|
"trim_attn_metadata",
|
||||||
]
|
]
|
||||||
|
@ -1,31 +1,94 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
from text_generation_server.models.globals import ATTENTION
|
|
||||||
import torch
|
import torch
|
||||||
from typing import Optional
|
from typing import Optional, List, Dict
|
||||||
|
import collections
|
||||||
|
|
||||||
|
_TYPE_CACHE = {}
|
||||||
|
|
||||||
|
|
||||||
if ATTENTION in {"flashinfer", "flashdecoding"}:
|
@dataclass
|
||||||
|
class HPUPagedAttentionMetadata:
|
||||||
|
"""Metadata for PagedAttention."""
|
||||||
|
|
||||||
|
block_list: Optional[torch.Tensor]
|
||||||
|
block_mapping: Optional[torch.Tensor]
|
||||||
|
block_usage: Optional[torch.Tensor]
|
||||||
|
block_scales: Optional[torch.Tensor]
|
||||||
|
block_groups: Optional[torch.Tensor]
|
||||||
|
attn_bias: Optional[torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
def subtuple(
|
||||||
|
obj: object,
|
||||||
|
typename: str,
|
||||||
|
to_copy: List[str],
|
||||||
|
to_override: Optional[Dict[str, object]] = None,
|
||||||
|
):
|
||||||
|
if obj is None:
|
||||||
|
return None
|
||||||
|
if to_override is None:
|
||||||
|
to_override = {}
|
||||||
|
fields = set(to_copy) | set(to_override.keys())
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
values = {key: obj[key] for key in fields if key in obj}
|
||||||
|
else:
|
||||||
|
values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
|
||||||
|
if typename not in _TYPE_CACHE:
|
||||||
|
_TYPE_CACHE[typename] = collections.namedtuple(typename, " ".join(fields))
|
||||||
|
return _TYPE_CACHE[typename](**values)
|
||||||
|
|
||||||
|
|
||||||
|
def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
|
||||||
|
# NOTE(kzawora): To anyone working on this in the future:
|
||||||
|
# Trimming metadata is required when using HPUGraphs.
|
||||||
|
# Attention metadata is going to be hashed by PT bridge, and
|
||||||
|
# appropriate HPUGraphs will be matched based on all inputs' hash.
|
||||||
|
|
||||||
|
# Before you put more keys in here, make sure you know their
|
||||||
|
# value type and make sure you know how it's going to be hashed.
|
||||||
|
# You can find that information in input_hash function
|
||||||
|
# in habana_frameworks/torch/hpu/graphs.py. You can also hash
|
||||||
|
# it manually with torch.hpu.graphs.input_hash(attention_metadata)
|
||||||
|
|
||||||
|
# If you use primitive types here - they will get hashed based
|
||||||
|
# on their value. You *will* get lots of excessive graph captures
|
||||||
|
# (and an OOM eventually) if you decide to put something like
|
||||||
|
# seq_len int here.
|
||||||
|
# If you absolutely need a scalar, put it in a tensor. Tensors
|
||||||
|
# get hashed using their metadata, not their values:
|
||||||
|
# input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
|
||||||
|
# input_hash(123) != input_hash(321)
|
||||||
|
# input_hash("abc") != input_hash("cba")
|
||||||
|
attention_metadata = subtuple(
|
||||||
|
metadata,
|
||||||
|
"TrimmedAttentionMetadata",
|
||||||
|
[
|
||||||
|
"block_list",
|
||||||
|
"block_mapping",
|
||||||
|
"block_usage",
|
||||||
|
"block_scales",
|
||||||
|
"block_groups",
|
||||||
|
"attn_bias",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
return attention_metadata
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Seqlen:
|
class Seqlen:
|
||||||
input_lengths: torch.Tensor
|
input_lengths: torch.Tensor
|
||||||
prefix_lengths: torch.Tensor
|
cache_lengths: torch.Tensor
|
||||||
cu_seqlen_q: Optional[torch.Tensor]
|
cu_seqlen_q: Optional[torch.Tensor]
|
||||||
cu_seqlen_k: Optional[torch.Tensor]
|
cu_seqlen_k: Optional[torch.Tensor]
|
||||||
max_q: int
|
|
||||||
max_k: int
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
prefix_lengths,
|
cache_lengths,
|
||||||
cu_seqlen_q=None,
|
cu_seqlen_q=None,
|
||||||
max_q=None,
|
|
||||||
max_k=None,
|
|
||||||
):
|
):
|
||||||
self.input_lengths = input_lengths
|
self.input_lengths = input_lengths
|
||||||
self.prefix_lengths = prefix_lengths
|
self.cache_lengths = cache_lengths
|
||||||
device = self.input_lengths.device
|
device = self.input_lengths.device
|
||||||
shape = self.input_lengths.shape
|
shape = self.input_lengths.shape
|
||||||
if cu_seqlen_q is None:
|
if cu_seqlen_q is None:
|
||||||
@ -34,39 +97,51 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
|
|||||||
device=device,
|
device=device,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
)
|
)
|
||||||
max_q = 1
|
|
||||||
else:
|
|
||||||
assert max_q is not None
|
|
||||||
assert max_k is not None
|
|
||||||
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
|
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
|
||||||
|
|
||||||
# cuda graphs don't like this and this is necessary to clamp within mistral
|
# cuda graphs don't like this and this is necessary to clamp within mistral
|
||||||
# Although FA2 might not want the clamping
|
# Although FA2 might not want the clamping
|
||||||
# cu_seqlen_k[0] = 0
|
# cu_seqlen_k[0] = 0
|
||||||
total = self.input_lengths + self.prefix_lengths
|
total = self.input_lengths + self.cache_lengths
|
||||||
torch.cumsum(total, -1, out=cu_seqlen_k[1:])
|
torch.cumsum(total, -1, out=cu_seqlen_k[1:])
|
||||||
|
|
||||||
self.cu_seqlen_q = cu_seqlen_q
|
self.cu_seqlen_q = cu_seqlen_q
|
||||||
self.cu_seqlen_k = cu_seqlen_k
|
self.cu_seqlen_k = cu_seqlen_k
|
||||||
self.max_q = max_q
|
|
||||||
self.max_k = max_k
|
|
||||||
|
|
||||||
def clamp(self, max):
|
def clamp(self, max):
|
||||||
# Flash decoding doesn't need to clamp
|
# Flash decoding doesn't need to clamp
|
||||||
return self
|
return self
|
||||||
|
|
||||||
else:
|
|
||||||
|
|
||||||
@dataclass
|
def trim_seqlen_metadata(metadata: Seqlen) -> object:
|
||||||
class Seqlen:
|
# NOTE(kzawora): To anyone working on this in the future:
|
||||||
input_lengths: torch.Tensor
|
# Trimming metadata is required when using HPUGraphs.
|
||||||
prefix_lengths: torch.Tensor
|
# Attention metadata is going to be hashed by PT bridge, and
|
||||||
cu_seqlen_q: torch.Tensor
|
# appropriate HPUGraphs will be matched based on all inputs' hash.
|
||||||
max_q: int
|
|
||||||
max_k: int
|
|
||||||
|
|
||||||
def clamp(self, max):
|
# Before you put more keys in here, make sure you know their
|
||||||
if SYSTEM == "rocm":
|
# value type and make sure you know how it's going to be hashed.
|
||||||
return self
|
# You can find that information in input_hash function
|
||||||
raise NotImplementedError("Not implemented seqlen for paged")
|
# in habana_frameworks/torch/hpu/graphs.py. You can also hash
|
||||||
return Seqlen(torch.clamp(self.input_lengths, max=max))
|
# it manually with torch.hpu.graphs.input_hash(attention_metadata)
|
||||||
|
|
||||||
|
# If you use primitive types here - they will get hashed based
|
||||||
|
# on their value. You *will* get lots of excessive graph captures
|
||||||
|
# (and an OOM eventually) if you decide to put something like
|
||||||
|
# seq_len int here.
|
||||||
|
# If you absolutely need a scalar, put it in a tensor. Tensors
|
||||||
|
# get hashed using their metadata, not their values:
|
||||||
|
# input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
|
||||||
|
# input_hash(123) != input_hash(321)
|
||||||
|
# input_hash("abc") != input_hash("cba")
|
||||||
|
attention_metadata = subtuple(
|
||||||
|
metadata,
|
||||||
|
"TrimmedSeqlen",
|
||||||
|
[
|
||||||
|
"input_lengths",
|
||||||
|
"cache_lengths",
|
||||||
|
"cu_seqlen_q",
|
||||||
|
"cu_seqlen_k",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
return attention_metadata
|
||||||
|
@ -1,357 +0,0 @@
|
|||||||
import torch
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
from text_generation_server.models.globals import (
|
|
||||||
ATTENTION,
|
|
||||||
BLOCK_SIZE,
|
|
||||||
)
|
|
||||||
from text_generation_server.layers.attention import Seqlen
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
major, minor = torch.cuda.get_device_capability()
|
|
||||||
is_sm75 = major == 7 and minor == 5
|
|
||||||
_PARTITION_SIZE = 512
|
|
||||||
|
|
||||||
try:
|
|
||||||
from vllm._C import cache_ops
|
|
||||||
except Exception as e:
|
|
||||||
raise ImportError(
|
|
||||||
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def reshape_and_cache(
|
|
||||||
key: torch.Tensor,
|
|
||||||
value: torch.Tensor,
|
|
||||||
key_cache: torch.Tensor,
|
|
||||||
value_cache: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
|
||||||
):
|
|
||||||
if ATTENTION in {"flashdecoding", "flashinfer"}:
|
|
||||||
shape = key_cache.shape
|
|
||||||
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
|
||||||
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
|
||||||
else:
|
|
||||||
cache_ops.reshape_and_cache(
|
|
||||||
key, value, key_cache, value_cache, slots, "auto", 1.0
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def paged_attention(
|
|
||||||
query: torch.Tensor,
|
|
||||||
key_cache: torch.Tensor,
|
|
||||||
value_cache: torch.Tensor,
|
|
||||||
kv_head_mapping: torch.Tensor,
|
|
||||||
softmax_scale: float,
|
|
||||||
block_tables: torch.Tensor,
|
|
||||||
seqlen: Seqlen,
|
|
||||||
max_s: int,
|
|
||||||
softcap: Optional[float] = None,
|
|
||||||
):
|
|
||||||
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
|
|
||||||
# Copyright 2023 The vLLM team. All rights
|
|
||||||
# reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
#
|
|
||||||
|
|
||||||
# value_cache => [num_blocks, num_heads, head_size, block_size]
|
|
||||||
# block_size = value_cache.shape[3]
|
|
||||||
block_size = BLOCK_SIZE
|
|
||||||
num_seqs, num_heads, head_size = query.shape
|
|
||||||
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
|
||||||
|
|
||||||
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
|
||||||
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
|
||||||
# V1 to avoid the overhead of reduction. Also, if the number of
|
|
||||||
# sequences or heads is large, we use V1 since there is enough work
|
|
||||||
# to parallelize.
|
|
||||||
if ATTENTION == "flashinfer":
|
|
||||||
from text_generation_server.layers.attention.flashinfer import decode_state
|
|
||||||
|
|
||||||
return decode_state.get().forward(
|
|
||||||
query.contiguous(),
|
|
||||||
paged_kv_cache=(key_cache, value_cache),
|
|
||||||
logits_soft_cap=softcap,
|
|
||||||
sm_scale=softmax_scale,
|
|
||||||
)
|
|
||||||
elif ATTENTION == "flashdecoding":
|
|
||||||
max_q = 1
|
|
||||||
max_k = max_s
|
|
||||||
import flash_attn_2_cuda
|
|
||||||
|
|
||||||
# TODO fixme when flash contains the fix.
|
|
||||||
# Number of splits is not correctly handled
|
|
||||||
# by the current path
|
|
||||||
# https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577
|
|
||||||
# This fails because we're using causal, therefore window_right is set to 0 and the split logic is never applied.
|
|
||||||
if softcap is None:
|
|
||||||
softcap = 0.0
|
|
||||||
out = flash_attn_2_cuda.varlen_fwd(
|
|
||||||
query,
|
|
||||||
key_cache,
|
|
||||||
value_cache,
|
|
||||||
None,
|
|
||||||
seqlen.cu_seqlen_q,
|
|
||||||
seqlen.cu_seqlen_k,
|
|
||||||
None, # pad_k
|
|
||||||
None,
|
|
||||||
block_tables,
|
|
||||||
None,
|
|
||||||
max_q,
|
|
||||||
max_k,
|
|
||||||
0.0, # dropout
|
|
||||||
softmax_scale,
|
|
||||||
False, # zero_tensors
|
|
||||||
True, # causal
|
|
||||||
-1, # Window_left
|
|
||||||
-1, # Window right
|
|
||||||
softcap,
|
|
||||||
False, # return softmax
|
|
||||||
None, # generator
|
|
||||||
)
|
|
||||||
return out[0]
|
|
||||||
else:
|
|
||||||
if softcap is not None:
|
|
||||||
raise RuntimeError("Paged attention doesn't support softcapping")
|
|
||||||
input_lengths = seqlen.input_lengths
|
|
||||||
from vllm._C import ops
|
|
||||||
|
|
||||||
out = torch.empty_like(query)
|
|
||||||
|
|
||||||
use_v1 = max_s <= 8192 and (
|
|
||||||
max_num_partitions == 1 or num_seqs * num_heads > 512
|
|
||||||
)
|
|
||||||
if use_v1:
|
|
||||||
ops.paged_attention_v1(
|
|
||||||
out,
|
|
||||||
query,
|
|
||||||
key_cache,
|
|
||||||
value_cache,
|
|
||||||
kv_head_mapping,
|
|
||||||
softmax_scale,
|
|
||||||
block_tables,
|
|
||||||
input_lengths,
|
|
||||||
block_size,
|
|
||||||
max_s,
|
|
||||||
None,
|
|
||||||
"auto",
|
|
||||||
1.0,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Run PagedAttention V2.
|
|
||||||
assert _PARTITION_SIZE % block_size == 0
|
|
||||||
tmp_output = torch.empty(
|
|
||||||
size=(num_seqs, num_heads, max_num_partitions, head_size),
|
|
||||||
dtype=out.dtype,
|
|
||||||
device=out.device,
|
|
||||||
)
|
|
||||||
exp_sums = torch.empty(
|
|
||||||
size=(num_seqs, num_heads, max_num_partitions),
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=out.device,
|
|
||||||
)
|
|
||||||
max_logits = torch.empty_like(exp_sums)
|
|
||||||
|
|
||||||
ops.paged_attention_v2(
|
|
||||||
out,
|
|
||||||
exp_sums,
|
|
||||||
max_logits,
|
|
||||||
tmp_output,
|
|
||||||
query,
|
|
||||||
key_cache,
|
|
||||||
value_cache,
|
|
||||||
kv_head_mapping,
|
|
||||||
softmax_scale,
|
|
||||||
block_tables,
|
|
||||||
input_lengths,
|
|
||||||
block_size,
|
|
||||||
max_s,
|
|
||||||
None,
|
|
||||||
"auto",
|
|
||||||
1.0,
|
|
||||||
)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
is_ampere_or_newer = major >= 8 and minor >= 0
|
|
||||||
if not is_ampere_or_newer:
|
|
||||||
raise ImportError("FlashAttention only supports Ampere GPUs or newer.")
|
|
||||||
|
|
||||||
import flash_attn_2_cuda
|
|
||||||
|
|
||||||
V2 = True
|
|
||||||
except ImportError:
|
|
||||||
try:
|
|
||||||
import flash_attn_cuda
|
|
||||||
|
|
||||||
V2 = False
|
|
||||||
except ImportError as e:
|
|
||||||
if major >= 8:
|
|
||||||
architecture_suffix = f"-{SYSTEM}"
|
|
||||||
raise ImportError(
|
|
||||||
"Flash Attention V2 is not installed.\n"
|
|
||||||
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
|
||||||
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
|
|
||||||
)
|
|
||||||
elif is_sm75:
|
|
||||||
raise ImportError(
|
|
||||||
"Flash Attention is not installed.\n"
|
|
||||||
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
|
||||||
"or install flash attention with `cd server && make install install-flash-attention`"
|
|
||||||
) from e
|
|
||||||
else:
|
|
||||||
raise ImportError(
|
|
||||||
f"GPU with CUDA capability {major} {minor} is not supported"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
SUPPORTS_WINDOWING = V2
|
|
||||||
|
|
||||||
if ATTENTION == "flashinfer":
|
|
||||||
|
|
||||||
def attention(
|
|
||||||
q: torch.Tensor,
|
|
||||||
key_cache: torch.Tensor,
|
|
||||||
value_cache: torch.Tensor,
|
|
||||||
seqlen: Seqlen,
|
|
||||||
block_tables: torch.Tensor,
|
|
||||||
softmax_scale,
|
|
||||||
window_size_left=-1,
|
|
||||||
causal=True,
|
|
||||||
softcap=0.0,
|
|
||||||
):
|
|
||||||
from text_generation_server.layers.attention.flashinfer import (
|
|
||||||
prefill_with_paged_kv_state,
|
|
||||||
)
|
|
||||||
|
|
||||||
return prefill_with_paged_kv_state.get().forward(
|
|
||||||
q.contiguous(),
|
|
||||||
causal=causal,
|
|
||||||
paged_kv_cache=(key_cache, value_cache),
|
|
||||||
logits_soft_cap=softcap,
|
|
||||||
sm_scale=softmax_scale,
|
|
||||||
window_left=window_size_left,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif V2:
|
|
||||||
|
|
||||||
def attention(
|
|
||||||
q,
|
|
||||||
key_cache: torch.Tensor,
|
|
||||||
value_cache: torch.Tensor,
|
|
||||||
seqlen: Seqlen,
|
|
||||||
block_tables: torch.Tensor,
|
|
||||||
softmax_scale,
|
|
||||||
window_size_left=-1,
|
|
||||||
causal=True,
|
|
||||||
softcap=0.0,
|
|
||||||
):
|
|
||||||
out = torch.empty_like(q)
|
|
||||||
if window_size_left <= 0 and window_size_left != -1:
|
|
||||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
|
||||||
return flash_attn_2_cuda.varlen_fwd(
|
|
||||||
q,
|
|
||||||
key_cache,
|
|
||||||
value_cache,
|
|
||||||
out,
|
|
||||||
seqlen.cu_seqlen_q,
|
|
||||||
seqlen.cu_seqlen_k,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
block_tables,
|
|
||||||
None,
|
|
||||||
seqlen.max_q,
|
|
||||||
seqlen.max_k,
|
|
||||||
0.0,
|
|
||||||
softmax_scale,
|
|
||||||
False,
|
|
||||||
causal,
|
|
||||||
window_size_left,
|
|
||||||
0,
|
|
||||||
softcap,
|
|
||||||
False,
|
|
||||||
None,
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
else:
|
|
||||||
|
|
||||||
def attention(
|
|
||||||
q: torch.Tensor,
|
|
||||||
k: torch.Tensor,
|
|
||||||
v: torch.Tensor,
|
|
||||||
seqlen: Seqlen,
|
|
||||||
block_tables: torch.Tensor,
|
|
||||||
softmax_scale: float,
|
|
||||||
window_size_left: int = -1,
|
|
||||||
causal: bool = True,
|
|
||||||
softcap=None,
|
|
||||||
):
|
|
||||||
if window_size_left != -1:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"window_size_left is only available with flash attn v2"
|
|
||||||
)
|
|
||||||
if softcap is not None:
|
|
||||||
raise NotImplementedError("softcap is only available with flash attn v2")
|
|
||||||
|
|
||||||
# Flash attention v1 requires q, k and v to have the same number of heads
|
|
||||||
if k.shape[1] != q.shape[1]:
|
|
||||||
# MQA expand
|
|
||||||
if k.shape[1] == 1:
|
|
||||||
k = k.expand(-1, q.shape[1], -1)
|
|
||||||
# Grouped attention reshape
|
|
||||||
else:
|
|
||||||
original_shape = k.shape
|
|
||||||
k = (
|
|
||||||
k.unsqueeze(2)
|
|
||||||
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
|
|
||||||
.reshape(original_shape[0], -1, original_shape[2])
|
|
||||||
)
|
|
||||||
if v.shape[1] != q.shape[1]:
|
|
||||||
# MQA expand
|
|
||||||
if v.shape[1] == 1:
|
|
||||||
v = v.expand(-1, q.shape[1], -1)
|
|
||||||
# Grouped attention reshape
|
|
||||||
else:
|
|
||||||
original_shape = v.shape
|
|
||||||
v = (
|
|
||||||
v.unsqueeze(2)
|
|
||||||
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
|
|
||||||
.reshape(original_shape[0], -1, original_shape[2])
|
|
||||||
)
|
|
||||||
|
|
||||||
out = torch.empty_like(q)
|
|
||||||
flash_attn_cuda.fwd(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
out,
|
|
||||||
seqlen.cu_seqlen_q,
|
|
||||||
seqlen.cu_seqlen_q,
|
|
||||||
seqlen.max_q,
|
|
||||||
seqlen.max_k,
|
|
||||||
0.0,
|
|
||||||
softmax_scale,
|
|
||||||
False,
|
|
||||||
causal,
|
|
||||||
False,
|
|
||||||
0,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
# Prefill in the cache with every kind of attention, unless we
|
|
||||||
# have a configuration that requires flash-attention v1, which
|
|
||||||
# does not support block tables.
|
|
||||||
PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
|
|
@ -1,813 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
"""
|
|
||||||
Fused Attention
|
|
||||||
===============
|
|
||||||
|
|
||||||
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
|
|
||||||
(https://tridao.me/publications/flash2/flash2.pdf)
|
|
||||||
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
|
|
||||||
|
|
||||||
Features supported:
|
|
||||||
|
|
||||||
1) Fwd with causal masking
|
|
||||||
2) Any sequence lengths without padding (currently fwd kernel only)
|
|
||||||
3) Support for different sequence lengths for q and k
|
|
||||||
4) Nested tensor API currently does not support dropout or bias.
|
|
||||||
|
|
||||||
Not currently supported:
|
|
||||||
|
|
||||||
1) Non power of two head dims
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
|
|
||||||
torch_dtype: tl.constexpr = torch.float16
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def cdiv_fn(x, y):
|
|
||||||
return (x + y - 1) // y
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def max_fn(x, y):
|
|
||||||
return tl.math.max(x, y)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
|
|
||||||
ms = tl.arange(0, m)
|
|
||||||
ns = tl.arange(0, n)
|
|
||||||
return philox_offset + ms[:, None] * stride + ns[None, :]
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
|
|
||||||
rng_offsets = dropout_offsets(
|
|
||||||
philox_seed, philox_offset, dropout_p, m, n, stride
|
|
||||||
).to(tl.uint32)
|
|
||||||
# TODO: use tl.randint for better performance
|
|
||||||
return tl.rand(philox_seed, rng_offsets)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
|
|
||||||
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)
|
|
||||||
rng_keep = rng_output > dropout_p
|
|
||||||
return rng_keep
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def load_fn(block_ptr, first, second, pad):
|
|
||||||
if first and second:
|
|
||||||
tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
|
|
||||||
elif first:
|
|
||||||
tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad)
|
|
||||||
elif second:
|
|
||||||
tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad)
|
|
||||||
else:
|
|
||||||
tensor = tl.load(block_ptr)
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def _attn_fwd_inner(
|
|
||||||
acc,
|
|
||||||
l_i,
|
|
||||||
m_i,
|
|
||||||
q,
|
|
||||||
K_block_ptr,
|
|
||||||
V_block_ptr,
|
|
||||||
start_m,
|
|
||||||
actual_seqlen_k,
|
|
||||||
dropout_p,
|
|
||||||
philox_seed,
|
|
||||||
batch_philox_offset,
|
|
||||||
encoded_softmax_block_ptr,
|
|
||||||
block_min,
|
|
||||||
block_max,
|
|
||||||
offs_n_causal,
|
|
||||||
masked_blocks,
|
|
||||||
n_extra_tokens,
|
|
||||||
bias_ptr,
|
|
||||||
IS_CAUSAL: tl.constexpr,
|
|
||||||
BLOCK_M: tl.constexpr,
|
|
||||||
BLOCK_DMODEL: tl.constexpr,
|
|
||||||
BLOCK_N: tl.constexpr,
|
|
||||||
OFFS_M: tl.constexpr,
|
|
||||||
OFFS_N: tl.constexpr,
|
|
||||||
PRE_LOAD_V: tl.constexpr,
|
|
||||||
MASK_STEPS: tl.constexpr,
|
|
||||||
ENABLE_DROPOUT: tl.constexpr,
|
|
||||||
RETURN_ENCODED_SOFTMAX: tl.constexpr,
|
|
||||||
PADDED_HEAD: tl.constexpr,
|
|
||||||
):
|
|
||||||
# loop over k, v, and update accumulator
|
|
||||||
for start_n in range(block_min, block_max, BLOCK_N):
|
|
||||||
# For padded blocks, we will overrun the tensor size if
|
|
||||||
# we load all BLOCK_N. For others, the blocks are all within range.
|
|
||||||
k = load_fn(
|
|
||||||
K_block_ptr,
|
|
||||||
PADDED_HEAD,
|
|
||||||
MASK_STEPS and (n_extra_tokens != 0),
|
|
||||||
"zero",
|
|
||||||
)
|
|
||||||
if PRE_LOAD_V:
|
|
||||||
v = load_fn(
|
|
||||||
V_block_ptr,
|
|
||||||
MASK_STEPS and (n_extra_tokens != 0),
|
|
||||||
PADDED_HEAD,
|
|
||||||
"zero",
|
|
||||||
)
|
|
||||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
|
||||||
# We start from end of seqlen_k so only the first iteration would need
|
|
||||||
# to be checked for padding if it is not a multiple of block_n
|
|
||||||
# TODO: This can be optimized to only be true for the padded block.
|
|
||||||
if MASK_STEPS: # noqa: SIM102
|
|
||||||
# If this is the last block / iteration, we want to
|
|
||||||
# mask if the sequence length is not a multiple of block size
|
|
||||||
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps
|
|
||||||
# if not is_modulo_mn. last step might get wasted but that is okay.
|
|
||||||
# check if this masking works for that case.
|
|
||||||
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
|
|
||||||
boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32)
|
|
||||||
size_n = start_n + OFFS_N[None, :]
|
|
||||||
mask = size_n < boundary_m[:, None]
|
|
||||||
qk = tl.where(mask, qk, float("-inf"))
|
|
||||||
if IS_CAUSAL:
|
|
||||||
causal_boundary = start_n + offs_n_causal
|
|
||||||
causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
|
|
||||||
qk = tl.where(causal_mask, qk, float("-inf"))
|
|
||||||
# -- compute qk ----
|
|
||||||
qk += tl.dot(q, k)
|
|
||||||
if bias_ptr is not None:
|
|
||||||
bias = load_fn(
|
|
||||||
bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero"
|
|
||||||
)
|
|
||||||
# While bias is added after multiplying qk with sm_scale, our
|
|
||||||
# optimization to use 2^x instead of e^x results in an additional
|
|
||||||
# scale factor of log2(e) which we must also multiply the bias with.
|
|
||||||
qk += bias * 1.44269504089
|
|
||||||
m_ij = tl.maximum(m_i, tl.max(qk, 1))
|
|
||||||
qk = qk - m_ij[:, None]
|
|
||||||
p = tl.math.exp2(qk)
|
|
||||||
|
|
||||||
# CAVEAT: Must update l_ij before applying dropout
|
|
||||||
l_ij = tl.sum(p, 1)
|
|
||||||
if ENABLE_DROPOUT:
|
|
||||||
philox_offset = (
|
|
||||||
batch_philox_offset
|
|
||||||
+ start_m * BLOCK_M * actual_seqlen_k
|
|
||||||
+ start_n
|
|
||||||
- BLOCK_N
|
|
||||||
)
|
|
||||||
keep = dropout_mask(
|
|
||||||
philox_seed,
|
|
||||||
philox_offset,
|
|
||||||
dropout_p,
|
|
||||||
BLOCK_M,
|
|
||||||
BLOCK_N,
|
|
||||||
actual_seqlen_k,
|
|
||||||
)
|
|
||||||
if RETURN_ENCODED_SOFTMAX:
|
|
||||||
tl.store(
|
|
||||||
encoded_softmax_block_ptr,
|
|
||||||
tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty),
|
|
||||||
)
|
|
||||||
p = tl.where(keep, p, 0.0)
|
|
||||||
elif RETURN_ENCODED_SOFTMAX:
|
|
||||||
tl.store(
|
|
||||||
encoded_softmax_block_ptr,
|
|
||||||
p.to(encoded_softmax_block_ptr.type.element_ty),
|
|
||||||
)
|
|
||||||
# -- update output accumulator --
|
|
||||||
alpha = tl.math.exp2(m_i - m_ij)
|
|
||||||
acc = acc * alpha[:, None]
|
|
||||||
if not PRE_LOAD_V:
|
|
||||||
v = load_fn(
|
|
||||||
V_block_ptr,
|
|
||||||
MASK_STEPS and (n_extra_tokens != 0),
|
|
||||||
PADDED_HEAD,
|
|
||||||
"zero",
|
|
||||||
)
|
|
||||||
# -- update m_i and l_i
|
|
||||||
l_i = l_i * alpha + l_ij
|
|
||||||
# update m_i and l_i
|
|
||||||
m_i = m_ij
|
|
||||||
acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
|
|
||||||
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
|
|
||||||
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
|
|
||||||
if bias_ptr is not None:
|
|
||||||
bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
|
|
||||||
if RETURN_ENCODED_SOFTMAX:
|
|
||||||
encoded_softmax_block_ptr = tl.advance(
|
|
||||||
encoded_softmax_block_ptr, (0, BLOCK_N)
|
|
||||||
)
|
|
||||||
return acc, l_i, m_i
|
|
||||||
|
|
||||||
|
|
||||||
@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_M": 256,
                "BLOCK_N": 64,
                "waves_per_eu": 2,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=8,
        ),
        triton.Config(
            {
                "BLOCK_M": 128,
                "BLOCK_N": 128,
                "waves_per_eu": 2,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M": 256,
                "BLOCK_N": 128,
                "waves_per_eu": 2,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=8,
        ),
        triton.Config(
            {
                "BLOCK_M": 128,
                "BLOCK_N": 64,
                "waves_per_eu": 3,
                "PRE_LOAD_V": True,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M": 128,
                "BLOCK_N": 64,
                "waves_per_eu": 3,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M": 64,
                "BLOCK_N": 64,
                "waves_per_eu": 4,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=8,
        ),
        triton.Config(
            {
                "BLOCK_M": 32,
                "BLOCK_N": 32,
                "waves_per_eu": 4,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=8,
        ),
        # TODO: This config fails with head_size not pow2 with data mismatches.
        # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
        # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
        triton.Config(
            {
                "BLOCK_M": 16,
                "BLOCK_N": 16,
                "waves_per_eu": 1,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M": 128,
                "BLOCK_N": 64,
                "waves_per_eu": 1,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=4,
        ),
    ],
    key=["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"],
)
@triton.jit
def attn_fwd(
    Q,
    K,
    V,
    bias,
    sm_scale,
    L,
    Out,
    stride_qz,
    stride_qh,
    stride_qm,
    stride_qk,
    stride_kz,
    stride_kh,
    stride_kn,
    stride_kk,
    stride_vz,
    stride_vh,
    stride_vk,
    stride_vn,
    stride_oz,
    stride_oh,
    stride_om,
    stride_on,
    stride_bz,
    stride_bh,
    stride_bm,
    stride_bn,
    cu_seqlens_q,
    cu_seqlens_k,
    dropout_p,
    philox_seed,
    philox_offset_base,
    encoded_softmax,
    HQ: tl.constexpr,
    HK: tl.constexpr,
    ACTUAL_BLOCK_DMODEL: tl.constexpr,
    MAX_SEQLENS_Q: tl.constexpr,
    MAX_SEQLENS_K: tl.constexpr,
    VARLEN: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    BIAS_TYPE: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr,
    RETURN_ENCODED_SOFTMAX: tl.constexpr,
):
    # Flash-attention forward kernel.
    #
    # Launch grid: axis 0 selects a BLOCK_M-row tile of the query sequence,
    # axis 1 the query head and axis 2 the batch entry. Each program streams
    # over the K/V blocks of its sequence in two phases (unmasked "full"
    # blocks, then blocks that need causal and/or padding masking), keeps an
    # online softmax (running max m_i, running sum l_i, accumulator acc) and
    # finally stores its normalised output tile to `Out`.
    #
    # Notes on arguments, as used by this body:
    #   * `L` is accepted but never read or written; the LSE write-back below
    #     is commented out.
    #   * `stride_bz` is accepted but unused; the bias base pointer is only
    #     offset by head (`off_h_q * stride_bh`).
    #   * With VARLEN, per-sequence start/length come from `cu_seqlens_q` /
    #     `cu_seqlens_k`; otherwise MAX_SEQLENS_Q / MAX_SEQLENS_K are used.
    start_m = tl.program_id(0)
    off_h_q = tl.program_id(1)
    off_z = tl.program_id(2)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    if VARLEN:
        cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
        cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
        seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
        # We have a one-size-fits-all grid in id(0). Some seqlens might be too
        # small for all start_m so for those we return early.
        # NOTE(review): the comparison is strict, so a tile starting exactly
        # at seqlen_q still runs (all of its rows are out of bounds and get
        # dropped by the boundary-checked store) -- confirm this is intended.
        if start_m * BLOCK_M > seqlen_q:
            return
        cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
        cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
        seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
    else:
        cu_seqlens_q_start = 0
        cu_seqlens_k_start = 0
        seqlen_q = MAX_SEQLENS_Q
        seqlen_k = MAX_SEQLENS_K

    # Now we compute whether we need to exit early due to causal masking.
    # This is because for seqlen_q > seqlen_k, M rows of the attn scores
    # are completely masked, resulting in 0s written to the output, and
    # inf written to LSE. We don't need to do any GEMMs in this case.
    # This block of code determines what N is, and if this WG is operating
    # on those M rows.
    n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
    if IS_CAUSAL:
        # If seqlen_q == seqlen_k, the attn scores are a square matrix.
        # If seqlen_q != seqlen_k, attn scores are rectangular which means
        # the causal mask boundary is bottom right aligned, and ends at either
        # the top edge (seqlen_q < seqlen_k) or left edge.
        # This captures the decrease in n_blocks if we have a rectangular attn
        # matrix
        n_blocks_seqlen = cdiv_fn(
            (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N
        )
        # This is what adjusts the block_max for the current WG, only
        # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
        n_blocks = min(n_blocks, n_blocks_seqlen)
        # If we have no blocks after adjusting for seqlen deltas, this WG is
        # part of the blocks that are all 0. We exit early.
        if n_blocks <= 0:
            o_offset = (
                off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
            )
            O_block_ptr = tl.make_block_ptr(
                base=Out + o_offset,
                shape=(seqlen_q, BLOCK_DMODEL),
                strides=(stride_om, stride_on),
                offsets=(start_m * BLOCK_M, 0),
                block_shape=(BLOCK_M, BLOCK_DMODEL),
                order=(1, 0),
            )
            acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
            # NOTE: the zero store and the LSE store below are disabled, so
            # this early exit currently leaves `Out` untouched for this tile.
            # We still need to write 0s to the result
            # tl.store(O_block_ptr,
            # acc.to(Out.type.element_ty), boundary_check=(0,1))
            # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
            #          + offs_m
            # We store inf to LSE, not -inf because in the bwd pass,
            # we subtract this
            # from qk which makes it -inf, such that exp(qk - inf) = 0
            # for these masked blocks.
            # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
            # tl.store(l_ptrs, l)
            # TODO: Should dropout and return encoded softmax be handled here?
            return

    # If MQA / GQA, set the K and V head offsets appropriately.
    GROUP_SIZE: tl.constexpr = HQ // HK
    if GROUP_SIZE != 1:
        off_h_k = off_h_q // GROUP_SIZE
    else:
        off_h_k = off_h_q

    # Non-zero when the last K block is only partially filled. Below it is
    # only tested against 0 and forwarded to the inner loop.
    n_extra_tokens = 0
    if seqlen_k < BLOCK_N:
        n_extra_tokens = BLOCK_N - seqlen_k
    elif seqlen_k % BLOCK_N:
        n_extra_tokens = seqlen_k % BLOCK_N
    # True when the head dimension was padded up to BLOCK_DMODEL.
    PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL

    # Compute pointers for all the tensors used in this kernel.
    q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn
    # K is addressed transposed: (head_dim, seqlen_k).
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk
    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0),
    )
    if BIAS_TYPE != 0:
        bias_ptr = tl.make_block_ptr(
            base=bias + off_h_q * stride_bh,
            shape=(seqlen_q, seqlen_k),
            strides=(stride_bm, stride_bn),
            offsets=(start_m * BLOCK_M, 0),
            block_shape=(BLOCK_M, BLOCK_N),
            order=(1, 0),
        )
    else:
        bias_ptr = None
    if ENABLE_DROPOUT:
        batch_philox_offset = (
            philox_offset_base + (off_z * HQ + off_h_q) * seqlen_q * seqlen_k
        )
    else:
        batch_philox_offset = 0
    # We can ask to return the dropout mask without actually doing any dropout.
    # In this case, we return an invalid pointer to indicate the mask is not
    # valid.
    # TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
    if RETURN_ENCODED_SOFTMAX:
        encoded_softmax_block_ptr = tl.make_block_ptr(
            base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,
            shape=(seqlen_q, seqlen_k),
            strides=(seqlen_k, 1),
            offsets=(start_m * BLOCK_M, 0),
            block_shape=(BLOCK_M, BLOCK_N),
            order=(1, 0),
        )
    else:
        encoded_softmax_block_ptr = 0
    # initialize pointer to m and l
    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # scale sm_scale by log_2(e) and use 2^x in the loop as we do not
    # have native e^x support in HW.
    qk_scale = sm_scale * 1.44269504089
    # Q is loaded once at the beginning and shared by all N blocks.
    q = load_fn(Q_block_ptr, True, PADDED_HEAD, "zero")
    q = (q * qk_scale).to(Q_block_ptr.type.element_ty)

    # Here we compute how many full and masked blocks we have.
    padded_block_k = n_extra_tokens != 0
    is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
    if IS_CAUSAL:
        # There are always at least BLOCK_M // BLOCK_N masked blocks.
        # Additionally there might be one more due to dissimilar seqlens.
        masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
    else:
        # Padding on Q does not need to be masked in the FA loop.
        masked_blocks = padded_block_k
    # if IS_CAUSAL, not is_modulo_mn does not always result in an additional
    # block. In this case we might exceed n_blocks so pick the min.
    masked_blocks = min(masked_blocks, n_blocks)
    n_full_blocks = n_blocks - masked_blocks
    block_min = 0
    block_max = n_blocks * BLOCK_N
    # Compute for full blocks. Here we set causal to false regardless of its
    # value because there is no masking. Similarly we do not need padding.
    if n_full_blocks > 0:
        block_max = (n_blocks - masked_blocks) * BLOCK_N
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            q,
            K_block_ptr,
            V_block_ptr,
            start_m,
            seqlen_k,
            dropout_p,
            philox_seed,
            batch_philox_offset,
            encoded_softmax_block_ptr,
            # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
            block_min,
            block_max,
            0,
            0,
            0,
            bias_ptr,
            # IS_CAUSAL, ....
            False,
            BLOCK_M,
            BLOCK_DMODEL,
            BLOCK_N,
            offs_m,
            offs_n,
            # _, MASK_STEPS, ...
            PRE_LOAD_V,
            False,
            ENABLE_DROPOUT,
            RETURN_ENCODED_SOFTMAX,
            PADDED_HEAD,
        )
        block_min = block_max
        block_max = n_blocks * BLOCK_N

    tl.debug_barrier()
    # Remaining blocks, if any, are the ones that need masking (causal
    # boundary and/or a partially filled last K block).
    if masked_blocks > 0:
        offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0
        # The block pointers held here were not moved by the first phase (the
        # inner loop only returns acc, l_i and m_i), so skip past the full
        # blocks explicitly.
        K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))
        if bias_ptr is not None:
            bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
        if RETURN_ENCODED_SOFTMAX:
            # Offsets are in elements, not blocks: advance by the same number
            # of key columns as K / V / bias above.
            encoded_softmax_block_ptr = tl.advance(
                encoded_softmax_block_ptr, (0, n_full_blocks * BLOCK_N)
            )
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            q,
            K_block_ptr,
            V_block_ptr,
            start_m,
            seqlen_k,
            dropout_p,
            philox_seed,
            batch_philox_offset,
            encoded_softmax_block_ptr,
            block_min,
            block_max,
            offs_n_causal,
            masked_blocks,
            n_extra_tokens,
            bias_ptr,
            IS_CAUSAL,
            BLOCK_M,
            BLOCK_DMODEL,
            BLOCK_N,
            offs_m,
            offs_n,
            # _, MASK_STEPS, ...
            PRE_LOAD_V,
            True,
            ENABLE_DROPOUT,
            RETURN_ENCODED_SOFTMAX,
            PADDED_HEAD,
        )
    # epilogue
    acc = acc / l_i[:, None]
    if ENABLE_DROPOUT:
        acc = acc / (1 - dropout_p)
    # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
    # then we have one block with a row of all NaNs which come from computing
    # softmax over a row of all -infs (-inf - inf = NaN). We check for that here
    # and store 0s where there are NaNs as these rows should've been zeroed out.
    end_m_idx = (start_m + 1) * BLOCK_M
    start_m_idx = start_m * BLOCK_M
    causal_start_idx = seqlen_q - seqlen_k
    acc = acc.to(Out.type.element_ty)
    if IS_CAUSAL:  # noqa: SIM102
        if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
            out_mask_boundary = tl.full(
                (BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32
            )
            mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
            out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :]
            z = 0.0
            acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
    # write back LSE
    # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
    # If seqlen_q not multiple of BLOCK_M, we need to mask out the last
    # few rows. This is only true for the last M block. For others,
    # overflow_size will be -ve
    # overflow_size = end_m_idx - seqlen_q
    # if overflow_size > 0:
    #    boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
    #    # This is a > check because mask being 0 blocks the store.
    #    l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
    #    tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
    # else:
    #    tl.store(l_ptrs, m_i + tl.math.log2(l_i))

    # write back O
    o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
    O_block_ptr = tl.make_block_ptr(
        base=Out + o_offset,
        shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    # Need boundary check on this to make sure the padding from the
    # Q and KV tensors in both dims are not part of what we store back.
    # TODO: Do the boundary check optionally.
    tl.store(O_block_ptr, acc, boundary_check=(0, 1))
|
|
||||||
|
|
||||||
|
|
||||||
def check_args(
    q,
    k,
    v,
    o,
    varlen=True,
    max_seqlens=None,
    cu_seqlens_q=None,
    cu_seqlens_k=None,
):
    """Validate the tensors handed to the Triton attention kernel.

    Every check is an ``assert``; the function returns ``None`` when all pass.

    Args:
        q, k, v: query / key / value tensors. They must have the same rank,
            the same dtype and the same last (head) dimension; ``k`` and
            ``v`` must have identical shapes.
        o: output tensor; must have the same shape as ``q``.
        varlen: when true the inputs must be 3-D and unpack as
            ``(total_tokens, num_heads, head_size)``, and both cumulative
            sequence-length arguments are required. Otherwise the inputs
            must be 4-D and unpack as ``(batch, num_heads, seqlen,
            head_size)``, and ``max_seqlens`` is required.
        max_seqlens: positive maximum sequence length (non-varlen only).
        cu_seqlens_q, cu_seqlens_k: cumulative sequence lengths of equal
            length (varlen only).

    Raises:
        AssertionError: if any requirement above is violated, if
            ``head_size`` exceeds 128, or if the number of query heads is
            not a multiple of the number of key heads.
    """
    assert q.dim() == k.dim() and q.dim() == v.dim()
    if varlen:
        assert q.dim() == 3
        _, nheads_q, head_size = q.shape
        _, nheads_k, _ = k.shape
        assert cu_seqlens_q is not None
        assert cu_seqlens_k is not None
        assert len(cu_seqlens_q) == len(cu_seqlens_k)
    else:
        assert q.dim() == 4
        _, nheads_q, _, head_size = q.shape
        _, nheads_k, _, _ = k.shape
        # Guard against the default None explicitly: `None > 0` would raise
        # a TypeError instead of failing the assertion.
        assert max_seqlens is not None and max_seqlens > 0
    assert k.shape == v.shape
    assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
    # TODO: Change assert if we support qkl f8 and v f16
    assert q.dtype == k.dtype and q.dtype == v.dtype
    # TODO: Fix assert to check head size <=256 once supported
    assert head_size <= 128
    assert o.shape == q.shape
    assert (nheads_q % nheads_k) == 0
|
|
||||||
|
|
||||||
|
|
||||||
class _attention(torch.autograd.Function):
    """Autograd entry point that launches the Triton ``attn_fwd`` kernel.

    Only ``forward`` is defined on this class. Inputs are always treated as
    packed variable-length batches (``VARLEN=True``), and dropout and
    encoded-softmax output are always disabled.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        o,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlens_q,
        max_seqlens_k,
        causal=False,
        sm_scale=1.0,
        bias=None,
    ):
        """Run attention over packed ``(total_tokens, num_heads, head_size)`` inputs.

        Args:
            ctx: autograd context; launch parameters are stashed on it.
            q, k, v: packed query / key / value tensors (validated by
                ``check_args`` with ``varlen=True``).
            o: optional output buffer; allocated like ``q`` with ``v``'s
                dtype when ``None``.
            cu_seqlens_q, cu_seqlens_k: cumulative sequence lengths; the
                batch size is ``len(cu_seqlens_q) - 1``.
            max_seqlens_q, max_seqlens_k: maximum sequence lengths;
                ``max_seqlens_q`` sizes the first grid axis.
            causal: forwarded to the kernel as ``IS_CAUSAL``.
            sm_scale: softmax scale forwarded to the kernel.
            bias: optional 4-D bias tensor; its four strides are forwarded.

        Returns:
            ``(o, encoded_softmax)`` where ``encoded_softmax`` is always
            ``None``.
        """
        if o is None:
            o = torch.empty_like(q, dtype=v.dtype)

        check_args(
            q,
            k,
            v,
            o,
            varlen=True,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
        )

        # Packed (varlen) layout: (total_tokens, num_heads, head_size). The
        # kernel expects strides in (batch, head, seq, dim) order; the batch
        # stride is 0 because sequences are located via cu_seqlens instead.
        _, nheads_q, head_size = q.shape
        _, nheads_k, _ = k.shape
        batch = len(cu_seqlens_q) - 1
        q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
        k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
        v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
        o_strides = (0, o.stride(1), o.stride(0), o.stride(2))

        # Round head_size up to the next power of two, with a floor of 16.
        padded_d_model = 1 << (head_size - 1).bit_length()
        padded_d_model = max(padded_d_model, 16)

        def grid(META):
            # (query tiles, query heads, batch entries)
            return triton.cdiv(max_seqlens_q, META["BLOCK_M"]), nheads_q, batch

        encoded_softmax = None

        # Seed the RNG so we get reproducible results for testing.
        philox_seed = 0x1BF52
        philox_offset = 0x1D4B42

        if bias is not None:
            bias_strides = (
                bias.stride(0),
                bias.stride(1),
                bias.stride(2),
                bias.stride(3),
            )
        else:
            bias_strides = (0, 0, 0, 0)

        attn_fwd[grid](
            q,
            k,
            v,
            bias,
            sm_scale,
            None,
            o,
            *q_strides,
            *k_strides,
            *v_strides,
            *o_strides,
            *bias_strides,
            cu_seqlens_q,
            cu_seqlens_k,
            dropout_p=0.0,
            philox_seed=philox_seed,
            philox_offset_base=philox_offset,
            encoded_softmax=encoded_softmax,
            HQ=nheads_q,
            HK=nheads_k,
            ACTUAL_BLOCK_DMODEL=head_size,
            MAX_SEQLENS_Q=max_seqlens_q,
            MAX_SEQLENS_K=max_seqlens_k,
            IS_CAUSAL=causal,
            VARLEN=True,
            BLOCK_DMODEL=padded_d_model,
            BIAS_TYPE=0 if bias is None else 1,
            ENABLE_DROPOUT=False,
            RETURN_ENCODED_SOFTMAX=False,
        )

        ctx.grid = grid
        ctx.sm_scale = sm_scale
        # Note: this records the unpadded head size, not padded_d_model.
        ctx.BLOCK_DMODEL = head_size
        ctx.causal = causal
        ctx.dropout_p = 0.0
        ctx.philox_seed = philox_seed
        ctx.philox_offset = philox_offset
        ctx.encoded_softmax = encoded_softmax
        ctx.return_encoded_softmax = False
        return o, encoded_softmax


triton_attention = _attention.apply
|
|
@ -1,251 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
from contextvars import ContextVar
|
|
||||||
from contextlib import contextmanager
|
|
||||||
|
|
||||||
import flashinfer
|
|
||||||
import torch
|
|
||||||
|
|
||||||
# Context-local slot for the active ragged-KV prefill wrapper; set and reset
# by `use_prefill_state`.
prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = ContextVar(
    "prefill_state"
)

# Context-local slot for the active paged-KV prefill wrapper; set and reset
# by `use_prefill_with_paged_kv_state`.
prefill_with_paged_kv_state: ContextVar[
    flashinfer.BatchPrefillWithPagedKVCacheWrapper
] = ContextVar("prefill_with_paged_kv_state")

# Context-local slot for the active paged-KV decode wrapper; set and reset
# by `use_decode_state`.
decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar(
    "decode_state"
)

# Lazily allocated buffer shared by every wrapper; see `get_workspace`.
workspace: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_workspace(device):
    """Return the shared flashinfer workspace, allocating it on first use.

    The buffer is a 128 MiB ``uint8`` tensor stored in the module-level
    ``workspace`` variable. ``device`` is only used for the first
    allocation; later calls return the existing tensor regardless of the
    device passed in.
    """
    global workspace
    if workspace is not None:
        return workspace
    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
    return workspace
|
|
||||||
|
|
||||||
|
|
||||||
def create_prefill_with_paged_kv_state(
    *,
    device: torch.device,
):
    """Build a flashinfer prefill wrapper that reads from the paged KV cache.

    The wrapper is backed by the shared workspace for ``device``, uses the
    ``"NHD"`` KV layout and has CUDA graphs disabled.
    """
    shared_buffer = get_workspace(device)
    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        shared_buffer, kv_layout="NHD", use_cuda_graph=False
    )
    return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
def use_prefill_with_paged_kv_state(
    *,
    state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,
    block_tables: torch.Tensor,
    cu_seqlens: torch.Tensor,
    input_lengths: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    page_size: int,
    dtype: torch.dtype,
    window_left: int,
):
    """
    Activate `state` as the current paged-KV prefill state.

    While the context manager is open, `state` is stored in the
    `prefill_with_paged_kv_state` context variable and has been prepared
    with `begin_forward` using the given parameters. On exit,
    `end_forward` is called and the context variable is restored.
    """
    num_seqs = input_lengths.shape[0]
    target_device = input_lengths.device

    # Cumulative page counts: entry 0 stays 0, entry i + 1 is the number of
    # pages used by sequences 0..i, where each sequence needs
    # ceil(length / page_size) pages.
    page_indptr = torch.zeros(num_seqs + 1, device=target_device, dtype=torch.int32)
    pages_per_seq = page_indptr[1:]
    torch.add(input_lengths, page_size - 1, out=pages_per_seq)
    pages_per_seq.div_(page_size, rounding_mode="floor")
    pages_per_seq.cumsum_(-1)

    # Number of occupied slots in each sequence's last page:
    # ((length - 1) mod page_size) + 1, which is always 1 for page_size 1.
    if page_size != 1:
        tail_len = torch.empty(num_seqs, dtype=torch.int32, device=target_device)
        torch.sub(input_lengths, 1, out=tail_len)
        tail_len.remainder_(page_size)
        tail_len += 1
    else:
        tail_len = torch.ones(num_seqs, dtype=torch.int32, device=target_device)

    reset_token = prefill_with_paged_kv_state.set(state)
    try:
        state.begin_forward(
            qo_indptr=cu_seqlens,
            paged_kv_indptr=page_indptr,
            paged_kv_indices=block_tables,
            paged_kv_last_page_len=tail_len,
            num_qo_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
            q_data_type=dtype,
            page_size=page_size,
            window_left=window_left,
        )
        yield
    finally:
        state.end_forward()
        if reset_token is not None:
            prefill_with_paged_kv_state.reset(reset_token)
|
|
||||||
|
|
||||||
|
|
||||||
def create_prefill_state(
    *,
    device: torch.device,
):
    """Build a flashinfer ragged-KV prefill wrapper.

    The wrapper is backed by the shared workspace for ``device``, uses the
    ``"NHD"`` KV layout and has CUDA graphs disabled.
    """
    shared_buffer = get_workspace(device)
    wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
        shared_buffer, kv_layout="NHD", use_cuda_graph=False
    )
    return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
def use_prefill_state(
    *,
    state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper,
    cu_seqlens: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    dtype: torch.dtype,
    window_left: int,
):
    """
    Activate `state` as the current ragged-KV prefill state.

    While the context manager is open, `state` is stored in the
    `prefill_state` context variable and has been prepared with
    `begin_forward`; `cu_seqlens` is used for both the query/output and the
    key/value index pointers. On exit, `end_forward` is called and the
    context variable is restored.
    """
    reset_token = prefill_state.set(state)
    try:
        state.begin_forward(
            qo_indptr=cu_seqlens,
            kv_indptr=cu_seqlens,
            num_qo_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
            q_data_type=dtype,
            window_left=window_left,
        )
        yield
    finally:
        state.end_forward()
        if reset_token is not None:
            prefill_state.reset(reset_token)
|
|
||||||
|
|
||||||
|
|
||||||
def create_decode_state(
    *,
    device: torch.device,
    num_heads: int,
    num_kv_heads: int,
):
    """Build a flashinfer paged-KV decode wrapper without CUDA graphs.

    Tensor cores are requested whenever the query-to-KV head ratio
    (``num_heads // num_kv_heads``) is not one of 1, 2, 4 or 8.
    """
    group_size = num_heads // num_kv_heads
    # Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
    needs_tensor_cores = group_size not in [1, 2, 4, 8]
    return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        get_workspace(device),
        kv_layout="NHD",
        use_cuda_graph=False,
        use_tensor_cores=needs_tensor_cores,
    )
|
|
||||||
|
|
||||||
|
|
||||||
def create_decode_state_cuda_graphs(
    *,
    device: torch.device,
    block_tables: torch.Tensor,
    block_tables_ptr: torch.Tensor,
    last_page_len: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
):
    """
    Build a flashinfer paged-KV decode wrapper with CUDA graphs enabled.

    `block_tables`, `block_tables_ptr` and `last_page_len` are handed to the
    wrapper as its page-indices, page-indptr and last-page-length buffers,
    so they are kept as part of the returned state. Tensor cores are
    requested whenever `num_heads // num_kv_heads` is not 1, 2, 4 or 8.
    """
    group_size = num_heads // num_kv_heads
    # Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
    needs_tensor_cores = group_size not in [1, 2, 4, 8]
    return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        get_workspace(device),
        kv_layout="NHD",
        use_cuda_graph=True,
        paged_kv_indices_buffer=block_tables,
        paged_kv_indptr_buffer=block_tables_ptr,
        paged_kv_last_page_len_buffer=last_page_len,
        use_tensor_cores=needs_tensor_cores,
    )
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
def use_decode_state(
    *,
    state: flashinfer.BatchDecodeWithPagedKVCacheWrapper,
    input_lengths: torch.Tensor,
    block_tables: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    page_size: int,
    dtype: torch.dtype,
    window_left: int,
):
    """
    Activate `state` as the current flashinfer decode state.

    While the context manager is open, `state` is stored in the
    `decode_state` context variable and has been prepared with
    `begin_forward` using the given parameters. On exit, `end_forward` is
    called and the context variable is restored.
    """
    num_seqs = input_lengths.shape[0]
    target_device = input_lengths.device

    # Cumulative page counts: entry 0 stays 0, entry i + 1 is the number of
    # pages used by sequences 0..i, where each sequence needs
    # ceil(length / page_size) pages.
    page_indptr = torch.zeros(num_seqs + 1, device=target_device, dtype=torch.int32)
    pages_per_seq = page_indptr[1:]
    torch.add(input_lengths, page_size - 1, out=pages_per_seq)
    pages_per_seq.div_(page_size, rounding_mode="floor")
    pages_per_seq.cumsum_(-1)

    # Number of occupied slots in each sequence's last page:
    # ((length - 1) mod page_size) + 1.
    tail_len = torch.empty(num_seqs, dtype=torch.int32, device=target_device)
    torch.sub(input_lengths, 1, out=tail_len)
    tail_len.remainder_(page_size)
    tail_len += 1

    reset_token = decode_state.set(state)

    try:
        state.begin_forward(
            indptr=page_indptr,
            indices=block_tables,
            last_page_len=tail_len,
            num_qo_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
            page_size=page_size,
            data_type=dtype,
            q_data_type=dtype,
            window_left=window_left,
        )
        yield
    finally:
        state.end_forward()
        if reset_token is not None:
            decode_state.reset(reset_token)
|
|
@ -0,0 +1,95 @@
|
|||||||
|
import torch
|
||||||
|
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
|
||||||
|
from typing import Optional
|
||||||
|
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
|
||||||
|
from vllm_hpu_extension import ops
|
||||||
|
from vllm_hpu_extension.utils import Matmul
|
||||||
|
from habana_frameworks.torch.hpex.kernels import FusedSDPA
|
||||||
|
from vllm_hpu_extension.utils import ModuleFusedSDPA
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Capability flag exported by this attention backend.
# NOTE(review): presumably signals that sliding-window attention is not
# supported here; `attention` in this module accepts `window_size_left` but
# never reads it -- confirm against the callers of this flag.
SUPPORTS_WINDOWING = False
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_from_cache(cache, blocks):
    """Select the cache rows that correspond to `blocks`.

    When the `VLLM_CONTIGUOUS_PA` environment variable is "true"
    (case-insensitive; this is the default when it is unset), the first
    `blocks.size(0)` rows of `cache` are returned as a slice and the values
    in `blocks` are ignored. Otherwise the rows listed in `blocks` are
    gathered along dimension 0.
    """
    contiguous_pa = os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true"
    if not contiguous_pa:
        return cache.index_select(0, blocks)
    return cache[: blocks.size(0)]
|
||||||
|
|
||||||
|
|
||||||
|
def attention(
    *,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: KVCache,
    kv_scales: KVScales,
    seqlen: Seqlen,
    softmax_scale: float,
    window_size_left: int = -1,
    causal: bool = True,
    softcap: Optional[float] = None,
):
    """Run prefill attention through the Habana fused SDPA kernel.

    ``query``, ``key`` and ``value`` must be 3-D; they are viewed as
    ``(bs, -1, heads, head_size)`` with ``bs = seqlen.input_lengths.shape[0]``
    and then have dims 1 and 2 swapped before being handed to the kernel.
    The head size used for all three views is the one read from ``key``.

    ``kv_cache``, ``kv_scales``, ``window_size_left`` and ``softcap`` are part
    of the signature but are not read by this implementation.

    Returns the kernel output with dims 1 and 2 swapped back, a leading
    size-1 dim removed, and made contiguous.
    """
    # Built first, exactly as before, so construction happens even if the
    # shape unpacking below fails.
    sdpa = ModuleFusedSDPA(FusedSDPA)

    batch = seqlen.input_lengths.shape[0]
    _, num_q_heads, _ = query.shape
    _, num_kv_heads, head_dim = key.shape

    def _split_batch(tensor, heads):
        # (tokens, heads, dim) -> (batch, heads, tokens_per_seq, dim)
        return tensor.view(batch, -1, heads, head_dim).transpose(1, 2)

    fused = sdpa(
        _split_batch(query, num_q_heads),
        _split_batch(key, num_kv_heads),
        _split_batch(value, num_kv_heads),
        attn_mask=None,
        dropout_p=0.0,
        is_causal=causal,
        scale=softmax_scale,
        softmax_mode="None",
        recompute_mode=None,
        valid_sequence_lengths=seqlen.input_lengths,
        padding_side="left",
    )

    # NOTE(review): squeeze(0) only drops the batch dim when it is 1; confirm
    # what callers expect for larger batches.
    return fused.transpose(1, 2).squeeze(0).contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
def paged_attention(
    query: torch.Tensor,
    kv_cache: KVCache,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    seqlen: Seqlen,
    *,
    kv_scales: KVScales,
    softcap: Optional[float] = None,
    hpu_attention_meta: HPUPagedAttentionMetadata,
):
    """Run decode-time paged attention via ``vllm_hpu_extension.ops.flat_pa``.

    ``query`` must be 3-D ``(batch_size, head_num, head_size)``; it is
    flattened to ``(batch_size, 1, head_num * head_size)`` for the kernel and
    the kernel output is viewed back to the original 3-D shape.

    Block layout information (``block_list``, ``block_mapping``, ``attn_bias``,
    ``block_scales``, ``block_groups``) is taken from ``hpu_attention_meta``.
    ``kv_head_mapping``, ``seqlen``, ``kv_scales`` and ``softcap`` are accepted
    but not read by this implementation.
    """
    batch_size, head_num, head_size = query.shape
    meta = hpu_attention_meta

    flat_query = query.view(batch_size, 1, head_num * head_size)
    flat_output = ops.flat_pa(
        query=flat_query,
        key_cache=kv_cache.key,
        value_cache=kv_cache.value,
        block_list=meta.block_list,
        block_mapping=meta.block_mapping,
        block_bias=meta.attn_bias,
        block_scales=meta.block_scales,
        block_groups=meta.block_groups,
        scale=softmax_scale,
        # A fresh Matmul wrapper is created for each of the four matmul slots.
        matmul_qk_op=Matmul(),
        matmul_av_op=Matmul(),
        batch2block_matmul_op=Matmul(),
        block2batch_matmul_op=Matmul(),
        keys_fetch_func=fetch_from_cache,
        values_fetch_func=fetch_from_cache,
    )

    # Restore the per-head layout of the query.
    return flat_output.view(batch_size, head_num, head_size)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SUPPORTS_WINDOWING",
|
||||||
|
"attention",
|
||||||
|
"paged_attention",
|
||||||
|
]
|
@ -1,82 +0,0 @@
|
|||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
import torch
|
|
||||||
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
|
|
||||||
from text_generation_server.layers.attention import Seqlen
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
SUPPORTS_WINDOWING = False
|
|
||||||
PREFILL_IN_KV_CACHE = False
|
|
||||||
|
|
||||||
|
|
||||||
def attention(
|
|
||||||
q: torch.Tensor,
|
|
||||||
key_cache: torch.Tensor,
|
|
||||||
value_cache: torch.Tensor,
|
|
||||||
seqlen: Seqlen,
|
|
||||||
block_tables: torch.Tensor,
|
|
||||||
softmax_scale,
|
|
||||||
window_size_left=-1,
|
|
||||||
causal=True,
|
|
||||||
softcap: Optional[float] = None,
|
|
||||||
):
|
|
||||||
out = torch.empty_like(q)
|
|
||||||
|
|
||||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
|
||||||
ipex.llm.functional.varlen_attention(
|
|
||||||
q.contiguous() if q.device.type == "xpu" else q,
|
|
||||||
key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
|
|
||||||
value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
|
|
||||||
out,
|
|
||||||
seqlen.cu_seqlen_q,
|
|
||||||
seqlen.cu_seqlen_q,
|
|
||||||
seqlen.max_q,
|
|
||||||
seqlen.max_q,
|
|
||||||
0.0,
|
|
||||||
softmax_scale,
|
|
||||||
False,
|
|
||||||
causal,
|
|
||||||
False,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def reshape_and_cache(
|
|
||||||
key: torch.Tensor,
|
|
||||||
value: torch.Tensor,
|
|
||||||
key_cache: torch.Tensor,
|
|
||||||
value_cache: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
|
||||||
):
|
|
||||||
ipex.llm.modules.PagedAttention.reshape_and_cache(
|
|
||||||
key, value, key_cache, value_cache, slots
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def paged_attention(
|
|
||||||
query: torch.Tensor,
|
|
||||||
key_cache: torch.Tensor,
|
|
||||||
value_cache: torch.Tensor,
|
|
||||||
kv_head_mapping: torch.Tensor,
|
|
||||||
softmax_scale: float,
|
|
||||||
block_tables: torch.Tensor,
|
|
||||||
seqlen: Seqlen,
|
|
||||||
max_s: int,
|
|
||||||
softcap: Optional[float] = None,
|
|
||||||
):
|
|
||||||
out = torch.empty_like(query)
|
|
||||||
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
|
||||||
out,
|
|
||||||
query,
|
|
||||||
key_cache,
|
|
||||||
value_cache,
|
|
||||||
kv_head_mapping,
|
|
||||||
softmax_scale,
|
|
||||||
block_tables,
|
|
||||||
seqlen.input_lengths,
|
|
||||||
BLOCK_SIZE,
|
|
||||||
max_s,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
return out
|
|
@ -0,0 +1,139 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from text_generation_server.models.globals import BLOCK_SIZE
|
||||||
|
from text_generation_server.utils.weights import Weights
|
||||||
|
from vllm_hpu_extension import cache_ops
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class KVScales:
    """
    Key-value scales for FP8 KV cache.

    Each scale is kept twice: once as the tensor that was passed in and once
    as a plain Python float obtained with ``Tensor.item()``. Both forms are
    kept because some consumers take the scale as a tensor while others take
    it as a host-side scalar.

    Raises ``ValueError`` at construction time if either scale tensor does not
    hold exactly one element.
    """

    key_scale: torch.Tensor
    value_scale: torch.Tensor
    # Filled in by __post_init__; not accepted by the constructor.
    key_scale_cpu: float = field(init=False)
    value_scale_cpu: float = field(init=False)

    def __post_init__(self):
        for scale in (self.key_scale, self.value_scale):
            if scale.numel() != 1:
                raise ValueError("Key and value scales must be scalar tensors.")

        # .item() copies the single element out as a Python number.
        self.key_scale_cpu = self.key_scale.item()
        self.value_scale_cpu = self.value_scale.item()
|
||||||
|
|
||||||
|
|
||||||
|
class KVCache:
    """
    Key-value cache for attention layers.

    Holds one key tensor and one value tensor, each zero-initialised with
    shape ``(num_blocks, BLOCK_SIZE, num_heads, head_size)``.
    """

    kv_cache: Tuple[torch.Tensor, torch.Tensor]

    def __init__(
        self,
        *,
        num_blocks: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """Allocate the key and value caches for a single layer."""
        ## TODO FP8 kv cache support

        cache_shape = (num_blocks, BLOCK_SIZE, num_heads, head_size)
        key_store = torch.zeros(cache_shape, dtype=dtype, device=device)
        value_store = torch.zeros(cache_shape, dtype=dtype, device=device)
        self.kv_cache = (key_store, value_store)

    @property
    def dtype(self):
        """Data type of the cache (read from the key tensor)."""
        return self.kv_cache[0].dtype

    @property
    def key(self):
        """The key cache tensor."""
        return self.kv_cache[0]

    @property
    def value(self):
        """The value cache tensor."""
        return self.kv_cache[1]

    def store(
        self,
        *,
        key: torch.Tensor,
        value: torch.Tensor,
        slots: torch.Tensor,
        kv_scales: KVScales,
    ):
        """Write ``key`` and ``value`` into the cache at the given slots.

        The host-side floats from ``kv_scales`` are forwarded to
        ``paged_reshape_and_cache`` along with the cache tensors.
        """
        ## TODO FP8 kv cache support

        paged_reshape_and_cache(
            key=key,
            value=value,
            key_cache=self.kv_cache[0],
            value_cache=self.kv_cache[1],
            slots=slots,
            k_scale=kv_scales.key_scale_cpu,
            v_scale=kv_scales.value_scale_cpu,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def paged_reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slots: torch.Tensor,
    k_scale: float = 1.0,
    v_scale: float = 1.0,
):
    """Insert new keys and values into their paged caches.

    Each slot index is split into a block index (``slots // BLOCK_SIZE``) and
    an offset inside that block (``slots % BLOCK_SIZE``); the key cache is
    updated first, then the value cache, both through
    ``cache_ops.insert_or_update_cache``.

    ``k_scale`` and ``v_scale`` are accepted but not used by this
    implementation.
    """
    block_idx = slots // BLOCK_SIZE
    block_offset = slots % BLOCK_SIZE
    for fresh, cache in ((key, key_cache), (value, value_cache)):
        cache_ops.insert_or_update_cache(fresh, cache, block_idx, block_offset)
|
||||||
|
|
||||||
|
|
||||||
|
def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
    """Load KV cache scales for the layer identified by ``prefix``.

    Lookup order:

    1. ``{prefix}.k_scale`` and ``{prefix}.v_scale`` when both exist; they are
       loaded with ``to_dtype=False`` and converted to float32.
    2. otherwise ``{prefix}.kv_scale``, used for both key and value.
    3. otherwise a float32 scalar ``1.0`` on ``weights.device`` for both.
    """
    k_name = f"{prefix}.k_scale"
    v_name = f"{prefix}.v_scale"
    kv_name = f"{prefix}.kv_scale"

    # Identity scale; key and value share the same tensor object.
    unit_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device)
    key_scale, value_scale = unit_scale, unit_scale

    if weights.has_tensor(k_name) and weights.has_tensor(v_name):
        key_scale = weights.get_tensor(k_name, to_dtype=False).float()
        value_scale = weights.get_tensor(v_name, to_dtype=False).float()
    elif weights.has_tensor(kv_name):
        # Fall back to older more coarse-grained scale when available.
        key_scale = weights.get_tensor(kv_name).float()
        value_scale = key_scale

    return KVScales(key_scale=key_scale, value_scale=value_scale)
|
@ -1,308 +0,0 @@
|
|||||||
import os
|
|
||||||
from typing import Optional
|
|
||||||
import torch
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
from text_generation_server.models.globals import ATTENTION
|
|
||||||
from text_generation_server.layers.attention import Seqlen
|
|
||||||
from text_generation_server.utils.log import log_master
|
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
major, minor = torch.cuda.get_device_capability()
|
|
||||||
is_sm75 = major == 7 and minor == 5
|
|
||||||
|
|
||||||
_PARTITION_SIZE_V1V2 = 512
|
|
||||||
_PARTITION_SIZE_CUSTOM = 256
|
|
||||||
|
|
||||||
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
|
|
||||||
ENGINE = "triton" if use_triton else "ck"
|
|
||||||
|
|
||||||
PREFILL_IN_KV_CACHE = False
|
|
||||||
|
|
||||||
use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
|
|
||||||
try:
|
|
||||||
if use_rocm_custom_paged_attn:
|
|
||||||
from vllm._custom_C import paged_attention_custom
|
|
||||||
except ImportError as e:
|
|
||||||
log_master(
|
|
||||||
logger.info,
|
|
||||||
f"Custom Paged Attention not available. Complete error: {e}",
|
|
||||||
)
|
|
||||||
use_rocm_custom_paged_attn = False
|
|
||||||
|
|
||||||
try:
|
|
||||||
import vllm._custom_ops as ops
|
|
||||||
except Exception as e:
|
|
||||||
raise ImportError(
|
|
||||||
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def reshape_and_cache(
|
|
||||||
key: torch.Tensor,
|
|
||||||
value: torch.Tensor,
|
|
||||||
key_cache: torch.Tensor,
|
|
||||||
value_cache: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
|
||||||
):
|
|
||||||
if ATTENTION == "flashdecoding":
|
|
||||||
shape = key_cache.shape
|
|
||||||
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
|
||||||
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
|
||||||
else:
|
|
||||||
ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
|
|
||||||
|
|
||||||
|
|
||||||
def paged_attention(
|
|
||||||
query: torch.Tensor,
|
|
||||||
key_cache: torch.Tensor,
|
|
||||||
value_cache: torch.Tensor,
|
|
||||||
kv_head_mapping: torch.Tensor,
|
|
||||||
softmax_scale: float,
|
|
||||||
block_tables: torch.Tensor,
|
|
||||||
seqlen: Seqlen,
|
|
||||||
max_s: int,
|
|
||||||
softcap: Optional[float] = None,
|
|
||||||
):
|
|
||||||
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
|
|
||||||
# Copyright 2023 The vLLM team. All rights
|
|
||||||
# reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
#
|
|
||||||
|
|
||||||
if softcap is not None:
|
|
||||||
raise RuntimeError("Paged attention doesn't support softcapping")
|
|
||||||
|
|
||||||
# value_cache => [num_blocks, num_heads, head_size, block_size]
|
|
||||||
block_size = value_cache.shape[3]
|
|
||||||
num_seqs, num_heads, head_size = query.shape
|
|
||||||
|
|
||||||
num_kv_heads = key_cache.shape[1]
|
|
||||||
gqa_ratio = num_heads // num_kv_heads
|
|
||||||
use_custom = (
|
|
||||||
use_rocm_custom_paged_attn
|
|
||||||
and (query.dtype == torch.half or query.dtype == torch.bfloat16)
|
|
||||||
and (head_size == 128 or head_size == 64)
|
|
||||||
and (block_size == 16 or block_size == 32)
|
|
||||||
and (gqa_ratio >= 1 and gqa_ratio <= 16)
|
|
||||||
and max_s <= 32768
|
|
||||||
)
|
|
||||||
|
|
||||||
if not use_custom:
|
|
||||||
_PARTITION_SIZE = _PARTITION_SIZE_V1V2
|
|
||||||
else:
|
|
||||||
_PARTITION_SIZE = _PARTITION_SIZE_CUSTOM
|
|
||||||
|
|
||||||
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
|
||||||
input_lengths = seqlen.input_lengths
|
|
||||||
|
|
||||||
out = torch.empty_like(query)
|
|
||||||
|
|
||||||
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
|
||||||
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
|
||||||
# V1 to avoid the overhead of reduction. Also, if the number of
|
|
||||||
# sequences or heads is large, we use V1 since there is enough work
|
|
||||||
# to parallelize.
|
|
||||||
import vllm._custom_ops as ops
|
|
||||||
|
|
||||||
use_v1 = (
|
|
||||||
max_s <= 8192
|
|
||||||
and (max_num_partitions == 1 or num_seqs * num_heads > 512)
|
|
||||||
and not use_custom
|
|
||||||
)
|
|
||||||
if use_v1:
|
|
||||||
ops.paged_attention_v1(
|
|
||||||
out,
|
|
||||||
query,
|
|
||||||
key_cache,
|
|
||||||
value_cache,
|
|
||||||
kv_head_mapping,
|
|
||||||
softmax_scale,
|
|
||||||
block_tables,
|
|
||||||
input_lengths,
|
|
||||||
block_size,
|
|
||||||
max_s,
|
|
||||||
None,
|
|
||||||
"auto",
|
|
||||||
1.0,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Run PagedAttention V2.
|
|
||||||
assert _PARTITION_SIZE % block_size == 0
|
|
||||||
tmp_output = torch.empty(
|
|
||||||
size=(num_seqs, num_heads, max_num_partitions, head_size),
|
|
||||||
dtype=out.dtype,
|
|
||||||
device=out.device,
|
|
||||||
)
|
|
||||||
exp_sums = torch.empty(
|
|
||||||
size=(num_seqs, num_heads, max_num_partitions),
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=out.device,
|
|
||||||
)
|
|
||||||
max_logits = torch.empty_like(exp_sums)
|
|
||||||
|
|
||||||
if not use_custom:
|
|
||||||
ops.paged_attention_v2(
|
|
||||||
out,
|
|
||||||
exp_sums,
|
|
||||||
max_logits,
|
|
||||||
tmp_output,
|
|
||||||
query,
|
|
||||||
key_cache,
|
|
||||||
value_cache,
|
|
||||||
kv_head_mapping,
|
|
||||||
softmax_scale,
|
|
||||||
block_tables,
|
|
||||||
input_lengths,
|
|
||||||
block_size,
|
|
||||||
max_s,
|
|
||||||
None,
|
|
||||||
"auto",
|
|
||||||
1.0,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
paged_attention_custom(
|
|
||||||
out,
|
|
||||||
exp_sums,
|
|
||||||
max_logits,
|
|
||||||
tmp_output,
|
|
||||||
query,
|
|
||||||
key_cache,
|
|
||||||
value_cache,
|
|
||||||
num_kv_heads,
|
|
||||||
softmax_scale,
|
|
||||||
block_tables,
|
|
||||||
input_lengths,
|
|
||||||
block_size,
|
|
||||||
max_s,
|
|
||||||
None,
|
|
||||||
"auto",
|
|
||||||
)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
if ENGINE != "triton":
|
|
||||||
try:
|
|
||||||
import flash_attn_2_cuda
|
|
||||||
|
|
||||||
log_master(
|
|
||||||
logger.info,
|
|
||||||
"ROCm: using Flash Attention 2 Composable Kernel implementation.",
|
|
||||||
)
|
|
||||||
except ImportError as e:
|
|
||||||
if major >= 8:
|
|
||||||
architecture_suffix = f"-{SYSTEM}"
|
|
||||||
raise ImportError(
|
|
||||||
"Flash Attention V2 is not installed.\n"
|
|
||||||
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
|
||||||
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
|
|
||||||
)
|
|
||||||
elif is_sm75:
|
|
||||||
raise ImportError(
|
|
||||||
"Flash Attention is not installed.\n"
|
|
||||||
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
|
||||||
"or install flash attention with `cd server && make install install-flash-attention`"
|
|
||||||
) from e
|
|
||||||
else:
|
|
||||||
for idx in range(torch.cuda.device_count()):
|
|
||||||
name = torch.cuda.get_device_name(idx)
|
|
||||||
if "MI210" not in name and "MI250" not in name:
|
|
||||||
raise ImportError(
|
|
||||||
f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
|
|
||||||
)
|
|
||||||
raise ImportError(
|
|
||||||
f"AMD GPU with ROCm capability {major} {minor} is not supported"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
SUPPORTS_WINDOWING = False
|
|
||||||
if ENGINE == "ck":
|
|
||||||
|
|
||||||
def attention(
|
|
||||||
q,
|
|
||||||
key_cache: torch.Tensor,
|
|
||||||
value_cache: torch.Tensor,
|
|
||||||
seqlen: Seqlen,
|
|
||||||
block_tables: torch.Tensor,
|
|
||||||
softmax_scale: float,
|
|
||||||
window_size_left: int = -1,
|
|
||||||
causal: bool = True,
|
|
||||||
softcap: float = 0.0,
|
|
||||||
):
|
|
||||||
if window_size_left <= 0 and window_size_left != -1:
|
|
||||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
|
||||||
|
|
||||||
out = torch.empty_like(q)
|
|
||||||
|
|
||||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
|
||||||
return flash_attn_2_cuda.varlen_fwd(
|
|
||||||
q,
|
|
||||||
key_cache,
|
|
||||||
value_cache,
|
|
||||||
out,
|
|
||||||
seqlen.cu_seqlen_q,
|
|
||||||
seqlen.cu_seqlen_q,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
seqlen.max_q,
|
|
||||||
seqlen.max_k,
|
|
||||||
0.0,
|
|
||||||
softmax_scale,
|
|
||||||
False,
|
|
||||||
causal,
|
|
||||||
window_size_left,
|
|
||||||
0,
|
|
||||||
softcap,
|
|
||||||
False,
|
|
||||||
None,
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
elif ENGINE == "triton":
|
|
||||||
from .flash_attn_triton import triton_attention
|
|
||||||
|
|
||||||
def attention(
|
|
||||||
q,
|
|
||||||
key_cache: torch.Tensor,
|
|
||||||
value_cache: torch.Tensor,
|
|
||||||
seqlen: Seqlen,
|
|
||||||
block_tables: torch.Tensor,
|
|
||||||
softmax_scale: float,
|
|
||||||
window_size_left: int = -1,
|
|
||||||
causal: bool = True,
|
|
||||||
softcap: Optional[float] = None,
|
|
||||||
):
|
|
||||||
if softcap is not None:
|
|
||||||
raise NotImplementedError("softcap is only available with CK flash attn")
|
|
||||||
|
|
||||||
out = torch.empty_like(q)
|
|
||||||
|
|
||||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
|
||||||
output, _ = triton_attention(
|
|
||||||
q,
|
|
||||||
key_cache,
|
|
||||||
value_cache,
|
|
||||||
out,
|
|
||||||
seqlen.cu_seqlen_q,
|
|
||||||
seqlen.cu_seqlen_q,
|
|
||||||
seqlen.max_q,
|
|
||||||
seqlen.max_k,
|
|
||||||
causal,
|
|
||||||
softmax_scale,
|
|
||||||
)
|
|
||||||
return output
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"Unknown attention engine {ENGINE}")
|
|
@ -0,0 +1,3 @@
|
|||||||
|
from .hpu import WQLinear
|
||||||
|
|
||||||
|
__all__ = ["WQLinear"]
|
@ -0,0 +1,134 @@
|
|||||||
|
from typing import Optional
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
# Resolve the HPU uint4 -> float conversion op. If the Habana framework cannot
# be imported (or the op lookup fails), install a stand-in that raises on use
# so that merely importing this module stays possible.
try:
    import habana_frameworks.torch.hpu  # noqa: F401

    convert_from_uint4 = torch.ops.hpu.convert_from_uint4
except Exception as exc:
    # The name bound by `except ... as` is cleared when the handler exits, so
    # keep the exception in a module-level variable for the error message.
    hpu_import_exception = exc

    def error_raiser_hpu(*args, **kwargs):
        """Stand-in for the HPU op: always raises, reporting the import error."""
        raise ValueError(
            f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}"
        )

    convert_from_uint4 = error_raiser_hpu
|
||||||
|
|
||||||
|
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
|
||||||
|
|
||||||
|
|
||||||
|
def unpack_awq(qweight: torch.Tensor, qzeros: Optional[torch.Tensor], bits: int):
    """Unpack AWQ-packed int32 words into one value per element.

    Every element of ``qweight`` (and of ``qzeros`` when given) is right-shifted
    by ``0, bits, 2*bits, ...`` (all shifts below 32), cast to ``int8`` and the
    shifted copies are flattened into the last dimension, so the column count
    grows by ``32 // bits``. The results are NOT masked to ``bits`` bits here;
    callers are expected to apply the mask themselves.

    Args:
        qweight: 2-D packed weight tensor.
        qzeros: 2-D packed zero-point tensor, or ``None``.
        bits: width in bits of each packed value.

    Returns:
        ``(iweights, izeros)``; ``izeros`` is ``None`` when ``qzeros`` is.
    """
    # Take the device from qweight: qzeros may legitimately be None, and
    # reading qzeros.device up front made that case raise AttributeError.
    shifts = torch.arange(0, 32, bits, device=qweight.device)

    # unpacking columnwise
    iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
        torch.int8  # smallest dtype available
    )
    iweights = iweights.view(iweights.shape[0], -1)

    if qzeros is None:
        return iweights, None

    # unpacking columnwise
    zero_shifts = shifts.to(qzeros.device)  # no-op when both share a device
    izeros = torch.bitwise_right_shift(
        qzeros[:, :, None], zero_shifts[None, None, :]
    ).to(
        torch.int8  # smallest dtype available
    )
    izeros = izeros.view(izeros.shape[0], -1)

    return iweights, izeros
|
||||||
|
|
||||||
|
|
||||||
|
def reverse_awq_order(iweights: torch.Tensor, izeros: Optional[torch.Tensor], bits: int):
    """Undo the AWQ interleaved column order on unpacked tensors.

    A column permutation is built by grouping the column indices of
    ``iweights`` into rows of ``32 // bits`` entries and reordering every row
    with ``AWQ_REVERSE_ORDER``. The same permutation is then applied to the
    columns of ``iweights`` and, when present, ``izeros``.

    Args:
        iweights: 2-D unpacked weights; its column count must be divisible by
            ``32 // bits``.
        izeros: 2-D unpacked zero points, or ``None``.
        bits: width in bits of each packed value.

    Returns:
        ``(iweights, izeros)`` with permuted columns; ``izeros`` stays ``None``
        when it was ``None``.
    """
    # Use iweights' device: izeros may be None (see the guard below), and
    # reading izeros.device here made that case raise AttributeError.
    reverse_order_tensor = torch.arange(
        iweights.shape[-1],
        dtype=torch.int32,
        device=iweights.device,
    )
    reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
    reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
    reverse_order_tensor = reverse_order_tensor.view(-1)

    if izeros is not None:
        izeros = izeros[:, reverse_order_tensor]
    iweights = iweights[:, reverse_order_tensor]

    return iweights, izeros
|
||||||
|
|
||||||
|
|
||||||
|
def unpack_weight_and_zeros(qweight, qzeros, bits):
    """Unpack AWQ weights and zero points into plain per-element values.

    Runs ``unpack_awq`` followed by ``reverse_awq_order`` and finally masks
    every value to its low ``bits`` bits.

    Returns:
        ``(iweight, izeros)``; ``izeros`` is passed through as ``None`` when the
        helpers return ``None`` for it.
    """
    # Unpack the qweight and qzeros tensors
    iweight, izeros = unpack_awq(qweight, qzeros, bits)
    # Reverse the order of the iweight and izeros tensors
    iweight, izeros = reverse_awq_order(iweight, izeros, bits)

    # overflow checks: keep only the low `bits` bits of every value
    mask = (2**bits) - 1
    iweight = torch.bitwise_and(iweight, mask)
    # Both helpers treat izeros as optional, so do not mask a missing tensor.
    if izeros is not None:
        izeros = torch.bitwise_and(izeros, mask)

    return iweight, izeros
|
||||||
|
|
||||||
|
|
||||||
|
def pack_tensor(input, bits=4):
    """Pack consecutive columns of a 2-D tensor into int32 words.

    Column ``j`` is shifted left by ``bits * (j % per_word)`` and OR-ed into
    output column ``j // per_word``, where ``per_word = 32 // bits``. Values
    are not masked before shifting, so they are expected to already fit in
    ``bits`` bits.

    The output width is ``ceil(columns / per_word)``. The previous width,
    ``columns // 32 * bits``, floored before multiplying and therefore silently
    dropped data whenever the column count was not a multiple of 32 (8 columns
    produced an empty result). Results are unchanged for column counts that
    are multiples of 32.

    Args:
        input: 2-D integer tensor of unpacked values.
        bits: width in bits of each value (default 4).

    Returns:
        An ``int32`` tensor on ``input.device`` with the packed words.
    """
    normal = input.to(torch.int32)
    per_word = 32 // bits
    n_cols = normal.shape[1]
    n_words = (n_cols + per_word - 1) // per_word

    q = torch.zeros(
        (normal.shape[0], n_words),
        dtype=torch.int32,
        device=input.device,
    )
    for j in range(n_cols):
        q[:, j // per_word] |= normal[:, j] << (bits * (j % per_word))
    return q
|
||||||
|
|
||||||
|
|
||||||
|
class WQLinear(nn.Module):
    """AWQ 4-bit quantized linear layer for HPU.

    At construction the AWQ-packed ``qweight`` and ``qzeros`` are unpacked,
    put back into natural column order and re-packed with ``pack_tensor``.
    ``forward`` passes the re-packed tensors to ``convert_from_uint4`` and
    multiplies the input by the result.
    """

    def __init__(
        self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
    ):
        super().__init__()

        if w_bit != 4:
            raise NotImplementedError("Only 4-bit are supported for now.")

        self.in_features = qweight.shape[0]
        # Each int32 column of qweight carries 32 // w_bit output features.
        self.out_features = qweight.shape[1] * 32 // w_bit

        self.w_bit = w_bit
        # group_size == -1 means a single group spanning all input features.
        self.group_size = self.in_features if group_size == -1 else group_size
        # quick sanity check (make sure alignment)
        assert self.in_features % self.group_size == 0
        assert self.out_features % (32 // self.w_bit) == 0

        self.qweight = qweight
        self.qzeros = qzeros
        self.scales = scales
        self.bias = bias
        self._preprocessing()

    def _preprocessing(self):
        """Re-pack ``qweight`` and ``qzeros`` from AWQ order to natural order.

        Unpacking is done on CPU copies; the re-packed tensors are moved back
        to the device ``qweight`` was on.
        """
        target_device = self.qweight.device
        unpacked_weight, unpacked_zeros = unpack_weight_and_zeros(
            self.qweight.cpu(), self.qzeros.cpu(), self.w_bit
        )
        self.qweight = pack_tensor(unpacked_weight).to(target_device)
        self.qzeros = pack_tensor(unpacked_zeros).to(target_device)

    @torch.no_grad()
    def forward(self, x):
        """Apply the layer to ``x``, preserving all leading dimensions."""
        out_shape = x.shape[:-1] + (self.out_features,)
        flat_x = x.reshape(-1, x.shape[-1])

        # Dequantize to the activation dtype, then do a regular matmul.
        dense_weight = convert_from_uint4(
            self.qweight, self.scales, self.qzeros, flat_x.dtype
        )
        result = torch.matmul(flat_x, dense_weight)

        if self.bias is not None:
            result = result + self.bias
        return result.reshape(out_shape)
|
@ -1,49 +0,0 @@
|
|||||||
# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import awq_inference_engine # with CUDA kernels
|
|
||||||
|
|
||||||
|
|
||||||
# class ScaledActivation(nn.Module):
|
|
||||||
# def __init__(self, module, scales):
|
|
||||||
# super().__init__()
|
|
||||||
# self.act = module
|
|
||||||
# self.scales = nn.Parameter(scales.data)
|
|
||||||
#
|
|
||||||
# def forward(self, x):
|
|
||||||
# return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
|
|
||||||
|
|
||||||
|
|
||||||
class WQLinear(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
if w_bit not in [4]:
|
|
||||||
raise NotImplementedError("Only 4-bit are supported for now.")
|
|
||||||
|
|
||||||
self.in_features = qweight.shape[0]
|
|
||||||
self.out_features = qweight.shape[1] * 32 // w_bit
|
|
||||||
|
|
||||||
self.w_bit = w_bit
|
|
||||||
self.group_size = group_size if group_size != -1 else self.in_features
|
|
||||||
# quick sanity check (make sure aligment)
|
|
||||||
assert self.in_features % self.group_size == 0
|
|
||||||
assert self.out_features % (32 // self.w_bit) == 0
|
|
||||||
|
|
||||||
self.qweight = qweight
|
|
||||||
self.qzeros = qzeros
|
|
||||||
self.scales = scales
|
|
||||||
self.bias = bias
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def forward(self, x):
|
|
||||||
out_shape = x.shape[:-1] + (self.out_features,)
|
|
||||||
out = awq_inference_engine.gemm_forward_cuda(
|
|
||||||
x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
|
|
||||||
)
|
|
||||||
out = out + self.bias if self.bias is not None else out
|
|
||||||
return out.reshape(out_shape)
|
|
@ -1,43 +0,0 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from EETQ import quant_weights, w8_a16_gemm
|
|
||||||
from text_generation_server.utils.weights import UnquantizedWeight
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class EETQWeight(UnquantizedWeight):
|
|
||||||
weight: torch.Tensor
|
|
||||||
|
|
||||||
def get_linear(self, bias: torch.Tensor):
|
|
||||||
try:
|
|
||||||
from text_generation_server.layers.eetq import EETQLinear
|
|
||||||
|
|
||||||
return EETQLinear(self.weight, bias)
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(
|
|
||||||
"Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class EETQLinear(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
weight,
|
|
||||||
bias,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
device = weight.device
|
|
||||||
if weight.dtype != torch.float16:
|
|
||||||
weight = weight.to(dtype=torch.float16)
|
|
||||||
weight = torch.t(weight).contiguous().cpu()
|
|
||||||
weight, scale = quant_weights(weight, torch.int8, False)
|
|
||||||
|
|
||||||
self.weight = weight.cuda(device)
|
|
||||||
self.scale = scale.cuda(device)
|
|
||||||
self.bias = bias.cuda(device) if bias is not None else None
|
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
|
||||||
output = w8_a16_gemm(input, self.weight, self.scale)
|
|
||||||
output = output + self.bias if self.bias is not None else output
|
|
||||||
return output
|
|
@ -1,102 +1,154 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Tuple, Type, Union, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Optional, Union, List
|
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
from text_generation_server.utils.weights import (
|
from text_generation_server.utils.weights import (
|
||||||
Weight,
|
Weight,
|
||||||
WeightsLoader,
|
WeightsLoader,
|
||||||
UnquantizedWeight,
|
UnquantizedWeight,
|
||||||
Weights,
|
Weights,
|
||||||
)
|
)
|
||||||
from text_generation_server.utils.log import log_master, log_once
|
|
||||||
import importlib.util
|
from vllm_hpu_extension.ops import scaled_fp8_quant
|
||||||
|
from vllm_hpu_extension.scales import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2
|
||||||
|
import habana_frameworks.torch.utils.experimental as htexp
|
||||||
|
|
||||||
|
w8a8_block_fp8_matmul = None
|
||||||
|
per_token_group_quant_fp8 = None
|
||||||
|
quant_dtype: torch.dtype = torch.float8_e4m3fn
|
||||||
|
|
||||||
|
|
||||||
FBGEMM_MM_AVAILABLE = False
|
def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
|
||||||
FBGEMM_DYN_AVAILABLE = False
|
|
||||||
|
|
||||||
|
|
||||||
def is_fbgemm_gpu_available():
|
|
||||||
try:
|
|
||||||
return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
if is_fbgemm_gpu_available():
|
|
||||||
if SYSTEM == "cuda":
|
|
||||||
major, _ = torch.cuda.get_device_capability()
|
|
||||||
FBGEMM_MM_AVAILABLE = major == 9
|
|
||||||
FBGEMM_DYN_AVAILABLE = major >= 8
|
|
||||||
else:
|
|
||||||
log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")
|
|
||||||
|
|
||||||
|
|
||||||
def get_fp8_linear() -> torch.nn.Module:
|
|
||||||
"""
|
"""
|
||||||
Return an FP8 linear `Module` that is compatible with the current system.
|
Return an FP8 linear `Module` that is compatible with the current system.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if SYSTEM == "cuda":
|
|
||||||
major, _ = torch.cuda.get_device_capability()
|
|
||||||
if major == 8:
|
|
||||||
from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
|
|
||||||
|
|
||||||
return GPTQMarlinFP8Linear
|
|
||||||
|
|
||||||
# On other systems let Torch decide if the hardware supports FP8.
|
# On other systems let Torch decide if the hardware supports FP8.
|
||||||
return Fp8Linear
|
return Fp8Linear
|
||||||
|
|
||||||
|
|
||||||
def fp8_quantize(
|
def normalize_e4m3fn_to_native_float8(
|
||||||
weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
|
weight: torch.Tensor,
|
||||||
):
|
weight_scale: torch.Tensor,
|
||||||
if FBGEMM_DYN_AVAILABLE and not scalar:
|
input_scale: Optional[torch.Tensor] = None,
|
||||||
qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
|
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||||
weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
|
return weight, weight_scale, input_scale
|
||||||
)
|
|
||||||
return qweight, scale
|
|
||||||
|
|
||||||
# weight, scale = quant_weights(weight, torch.int8, False)
|
|
||||||
finfo = torch.finfo(qdtype)
|
def per_tensor_dequantize(
|
||||||
# Calculate the scale as dtype max divided by absmax
|
tensor: torch.Tensor,
|
||||||
scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
|
inv_scale: Union[float, torch.Tensor],
|
||||||
# scale and clamp the tensor to bring it to
|
dtype: torch.dtype = torch.float16,
|
||||||
# the representative range of float8 data type
|
) -> torch.Tensor:
|
||||||
# (as default cast is unsaturated)
|
device = tensor.device
|
||||||
qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
|
dtype = torch.bfloat16
|
||||||
# Return both float8 data and the inverse scale (as float),
|
if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
|
||||||
# as both required as inputs to torch._scaled_mm
|
# dequant on cpu to avoid nan on gaudi2
|
||||||
qweight = qweight.to(qdtype)
|
tensor = tensor.to("cpu")
|
||||||
scale = scale.float().reciprocal()
|
|
||||||
return qweight, scale
|
fake_qweight = tensor.to(dtype).to(device)
|
||||||
|
dq_weight = fake_qweight * inv_scale
|
||||||
|
return dq_weight
|
||||||
|
|
||||||
|
|
||||||
|
def requantize_with_max_scale(
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
logical_widths: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# Max scale to be used for requanitzation.
|
||||||
|
max_w_scale = weight_scale.max()
|
||||||
|
|
||||||
|
if is_hpu_gaudi2():
|
||||||
|
max_w_scale = max_w_scale * get_hpu_gaudi2_scale_factor()
|
||||||
|
|
||||||
|
start = 0
|
||||||
|
for idx, logical_width in enumerate(logical_widths):
|
||||||
|
end = start + logical_width
|
||||||
|
weight_dq = per_tensor_dequantize(
|
||||||
|
weight[start:end, :], weight_scale[idx], dtype
|
||||||
|
)
|
||||||
|
weight[start:end, :], max_w_scale_normalized = fp8_quantize(
|
||||||
|
weight_dq, max_w_scale
|
||||||
|
)
|
||||||
|
start = end
|
||||||
|
|
||||||
|
return weight, max_w_scale_normalized
|
||||||
|
|
||||||
|
|
||||||
|
def fp8_quantize(
|
||||||
|
weight: torch.Tensor,
|
||||||
|
scale: Optional[torch.Tensor] = None,
|
||||||
|
scale_upper_bound: Optional[torch.Tensor] = None,
|
||||||
|
qdtype: torch.dtype = torch.float8_e4m3fn,
|
||||||
|
scalar: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
This function returns a reciprocal of the scale, so that a tensor can be unscaled
|
||||||
|
by multiplying it with the returned scale. If a scale is given through the `scale`
|
||||||
|
argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
|
||||||
|
be used without modification).
|
||||||
|
"""
|
||||||
|
shape = weight.shape
|
||||||
|
qweight, scale = scaled_fp8_quant(
|
||||||
|
weight.reshape(-1, shape[-1]),
|
||||||
|
scale=scale,
|
||||||
|
scale_ub=scale_upper_bound,
|
||||||
|
# TODO: don't do this when we have to use the Torch kernel.
|
||||||
|
use_per_token_if_dynamic=not scalar,
|
||||||
|
)
|
||||||
|
|
||||||
|
return qweight.reshape(shape), scale
|
||||||
|
|
||||||
|
|
||||||
class HybridFP8UnquantLoader(WeightsLoader):
|
class HybridFP8UnquantLoader(WeightsLoader):
|
||||||
"""Weight loader that loads FP8 and unquantized Torch tensors."""
|
"""Weight loader that loads FP8 and unquantized Torch tensors."""
|
||||||
|
|
||||||
def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool):
|
def __init__(
|
||||||
|
self,
|
||||||
|
activation_scale_ub: Optional[float],
|
||||||
|
to_fp8: bool,
|
||||||
|
weight_block_size: Optional[List[int]] = None,
|
||||||
|
):
|
||||||
self.activation_scale_ub = activation_scale_ub
|
self.activation_scale_ub = activation_scale_ub
|
||||||
self.to_fp8 = to_fp8
|
self.to_fp8 = to_fp8
|
||||||
|
self.weight_block_size = weight_block_size
|
||||||
|
|
||||||
def get_weights(self, weights: "Weights", prefix: str):
|
def get_weights(self, weights: "Weights", prefix: str):
|
||||||
w = weights.get_tensor(f"{prefix}.weight")
|
w = weights.get_tensor(f"{prefix}.weight")
|
||||||
|
|
||||||
if w.dtype == torch.float8_e4m3fn:
|
if w.dtype == torch.float8_e4m3fn:
|
||||||
# FP8 branch
|
if self.weight_block_size is not None:
|
||||||
scale = (
|
scale = weights.get_tensor(f"{prefix}.weight_scale_inv")
|
||||||
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
|
||||||
.reshape(-1)
|
|
||||||
.expand(w.shape[0])
|
|
||||||
)
|
|
||||||
return Fp8Weight(
|
return Fp8Weight(
|
||||||
weight=w,
|
weight=w,
|
||||||
weight_scale=scale,
|
weight_scale=scale,
|
||||||
activation_scale_ub=self.activation_scale_ub,
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
dtype=weights.dtype,
|
dtype=weights.dtype,
|
||||||
|
weight_block_size=self.weight_block_size,
|
||||||
|
)
|
||||||
|
# FP8 branch
|
||||||
|
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||||
|
|
||||||
|
input_scale = None
|
||||||
|
if weights.has_tensor(f"{prefix}.input_scale"):
|
||||||
|
input_scale = (
|
||||||
|
weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
|
||||||
|
.reshape(-1)
|
||||||
|
.max()
|
||||||
|
)
|
||||||
|
logical_widths = [w.shape[0]]
|
||||||
|
w, scale = requantize_with_max_scale(
|
||||||
|
w, scale.unsqueeze(0), logical_widths, weights.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
return Fp8Weight(
|
||||||
|
weight=w,
|
||||||
|
weight_scale=scale,
|
||||||
|
input_scale=input_scale,
|
||||||
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
|
dtype=weights.dtype,
|
||||||
)
|
)
|
||||||
if self.to_fp8:
|
if self.to_fp8:
|
||||||
return Fp8Weight(weight=w, dtype=weights.dtype)
|
return Fp8Weight(weight=w, dtype=weights.dtype)
|
||||||
@ -116,6 +168,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
|||||||
if w.dtype == torch.float8_e4m3fn:
|
if w.dtype == torch.float8_e4m3fn:
|
||||||
# FP8 branch
|
# FP8 branch
|
||||||
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||||
|
|
||||||
if scale.numel() > 1:
|
if scale.numel() > 1:
|
||||||
scale = weights.get_packed_sharded(
|
scale = weights.get_packed_sharded(
|
||||||
f"{prefix}.weight_scale",
|
f"{prefix}.weight_scale",
|
||||||
@ -123,11 +176,29 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
|||||||
block_sizes=block_sizes,
|
block_sizes=block_sizes,
|
||||||
to_dtype=False,
|
to_dtype=False,
|
||||||
)
|
)
|
||||||
scale = scale.reshape(-1).expand(w.shape[0])
|
|
||||||
|
input_scale = None
|
||||||
|
if weights.has_tensor(f"{prefix}.input_scale"):
|
||||||
|
input_scale = weights.get_tensor(
|
||||||
|
f"{prefix}.input_scale", to_dtype=False
|
||||||
|
)
|
||||||
|
if input_scale.numel() > 1:
|
||||||
|
input_scale = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.input_scale",
|
||||||
|
dim=0,
|
||||||
|
block_sizes=block_sizes,
|
||||||
|
to_dtype=False,
|
||||||
|
)
|
||||||
|
input_scale = input_scale.reshape(-1).max()
|
||||||
|
logical_widths = [w.shape[0]]
|
||||||
|
w, scale = requantize_with_max_scale(
|
||||||
|
w, scale.unsqueeze(0), logical_widths, weights.dtype
|
||||||
|
)
|
||||||
|
|
||||||
return Fp8Weight(
|
return Fp8Weight(
|
||||||
weight=w,
|
weight=w,
|
||||||
weight_scale=scale,
|
weight_scale=scale,
|
||||||
|
input_scale=input_scale,
|
||||||
activation_scale_ub=self.activation_scale_ub,
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
dtype=weights.dtype,
|
dtype=weights.dtype,
|
||||||
)
|
)
|
||||||
@ -148,15 +219,48 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
|||||||
|
|
||||||
# FP8 branch
|
# FP8 branch
|
||||||
if w.dtype == torch.float8_e4m3fn:
|
if w.dtype == torch.float8_e4m3fn:
|
||||||
|
if self.weight_block_size is not None:
|
||||||
|
scale = [
|
||||||
|
weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False)
|
||||||
|
for p in prefixes
|
||||||
|
]
|
||||||
|
scale = torch.cat(scale, dim=dim)
|
||||||
|
scale = scale.to(weights.device)
|
||||||
|
return Fp8Weight(
|
||||||
|
weight=w,
|
||||||
|
weight_scale=scale,
|
||||||
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
|
dtype=weights.dtype,
|
||||||
|
weight_block_size=self.weight_block_size,
|
||||||
|
)
|
||||||
|
|
||||||
scale = [
|
scale = [
|
||||||
_load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
|
_load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
|
||||||
for p, shape in zip(prefixes, shapes)
|
for p, shape in zip(prefixes, shapes)
|
||||||
]
|
]
|
||||||
scale = torch.cat(scale, dim=0).reshape(-1)
|
scale = torch.cat(scale, dim=0).reshape(-1)
|
||||||
|
|
||||||
|
input_scale = [
|
||||||
|
_load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
|
||||||
|
for p, shape in zip(prefixes, shapes)
|
||||||
|
if weights.has_tensor(f"{p}.input_scale")
|
||||||
|
]
|
||||||
|
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
|
||||||
|
input_scale = (
|
||||||
|
torch.cat(input_scale, dim=0).reshape(-1).max()
|
||||||
|
if len(input_scale) != 0
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
logical_widths = [x[0] for x in shapes]
|
||||||
|
w, scale = requantize_with_max_scale(
|
||||||
|
w, scale.to(weights.device), logical_widths, weights.dtype
|
||||||
|
)
|
||||||
|
|
||||||
return Fp8Weight(
|
return Fp8Weight(
|
||||||
weight=w,
|
weight=w,
|
||||||
weight_scale=scale,
|
weight_scale=scale,
|
||||||
|
input_scale=input_scale,
|
||||||
activation_scale_ub=self.activation_scale_ub,
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
dtype=weights.dtype,
|
dtype=weights.dtype,
|
||||||
)
|
)
|
||||||
@ -169,14 +273,35 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
|||||||
w = weights.get_sharded(f"{prefix}.weight", dim=1)
|
w = weights.get_sharded(f"{prefix}.weight", dim=1)
|
||||||
# FP8 branch
|
# FP8 branch
|
||||||
if w.dtype == torch.float8_e4m3fn:
|
if w.dtype == torch.float8_e4m3fn:
|
||||||
scale = (
|
if self.weight_block_size is not None:
|
||||||
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
# XXX: Yes the weights is named scale_inv, but corresponds to scale it seems.
|
||||||
|
scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1)
|
||||||
|
|
||||||
|
return Fp8Weight(
|
||||||
|
weight=w,
|
||||||
|
weight_scale=scale,
|
||||||
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
|
dtype=weights.dtype,
|
||||||
|
weight_block_size=self.weight_block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||||
|
|
||||||
|
input_scale = None
|
||||||
|
if weights.has_tensor(f"{prefix}.input_scale"):
|
||||||
|
input_scale = (
|
||||||
|
weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
|
||||||
.reshape(-1)
|
.reshape(-1)
|
||||||
.expand(w.shape[0])
|
.max()
|
||||||
|
)
|
||||||
|
logical_widths = [w.shape[0]]
|
||||||
|
w, scale = requantize_with_max_scale(
|
||||||
|
w, scale.unsqueeze(0), logical_widths, weights.dtype
|
||||||
)
|
)
|
||||||
return Fp8Weight(
|
return Fp8Weight(
|
||||||
weight=w,
|
weight=w,
|
||||||
weight_scale=scale,
|
weight_scale=scale,
|
||||||
|
input_scale=input_scale,
|
||||||
activation_scale_ub=self.activation_scale_ub,
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
dtype=weights.dtype,
|
dtype=weights.dtype,
|
||||||
)
|
)
|
||||||
@ -191,83 +316,126 @@ class Fp8Weight(Weight):
|
|||||||
weight: torch.Tensor
|
weight: torch.Tensor
|
||||||
dtype: torch.dtype
|
dtype: torch.dtype
|
||||||
weight_scale: Optional[torch.Tensor] = None
|
weight_scale: Optional[torch.Tensor] = None
|
||||||
|
input_scale: Optional[torch.Tensor] = None
|
||||||
activation_scale_ub: Optional[float] = None
|
activation_scale_ub: Optional[float] = None
|
||||||
|
force_w8a16: bool = False
|
||||||
|
weight_block_size: Optional[List[int]] = None
|
||||||
|
|
||||||
def get_linear(self, bias: torch.Tensor):
|
def get_linear(self, bias: torch.Tensor):
|
||||||
if self.weight_scale is None:
|
if self.weight_scale is None:
|
||||||
return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
|
return get_fp8_linear(force_w8a16=self.force_w8a16).from_unquant(
|
||||||
|
self.weight, bias, self.dtype
|
||||||
|
)
|
||||||
# This is not checked by the fbgemm kernels, but they require contiguous
|
# This is not checked by the fbgemm kernels, but they require contiguous
|
||||||
# memory. Can be non-contiguous when we e.g. expand from scalars.
|
# memory. Can be non-contiguous when we e.g. expand from scalars.
|
||||||
self.weight_scale = self.weight_scale.contiguous()
|
self.weight_scale = self.weight_scale.contiguous()
|
||||||
return get_fp8_linear().from_fp8(
|
return get_fp8_linear(force_w8a16=self.force_w8a16).from_fp8(
|
||||||
self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
|
weight=self.weight,
|
||||||
|
scale=self.weight_scale,
|
||||||
|
dtype=self.dtype,
|
||||||
|
bias=bias,
|
||||||
|
input_scale=self.input_scale,
|
||||||
|
scale_upper_bound=self.activation_scale_ub,
|
||||||
|
weight_block_size=self.weight_block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Fp8Linear(torch.nn.Module):
|
class Fp8Linear(torch.nn.Module):
|
||||||
|
_device_identity_cache = {}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
qweight,
|
qweight: torch.Tensor,
|
||||||
scale,
|
scale: torch.Tensor,
|
||||||
scale_upper_bound,
|
dtype: torch.dtype,
|
||||||
bias,
|
bias: Optional[torch.Tensor] = None,
|
||||||
dtype,
|
input_scale: Optional[torch.Tensor] = None,
|
||||||
|
scale_upper_bound: Optional[float] = None,
|
||||||
|
weight_block_size: Optional[List[int]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if FBGEMM_MM_AVAILABLE:
|
|
||||||
log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
|
|
||||||
|
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.qweight = qweight
|
self.qweight = qweight
|
||||||
self.scale = scale
|
self.scale = scale.float()
|
||||||
self.scale_upper_bound = (
|
self.input_scale = input_scale.float() if input_scale is not None else None
|
||||||
torch.tensor(
|
self.weight_block_size = weight_block_size
|
||||||
[scale_upper_bound], dtype=torch.float32, device=qweight.device
|
self.scale_upper_bound = scale_upper_bound
|
||||||
)
|
|
||||||
if scale_upper_bound is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
self.bias = bias if bias is not None else None
|
self.bias = bias if bias is not None else None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_unquant(cls, weight, bias, dtype):
|
def from_unquant(cls, weight, bias, dtype):
|
||||||
qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
|
qweight, scale = fp8_quantize(weight, scalar=True)
|
||||||
return cls(
|
return cls(
|
||||||
qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
|
qweight=qweight,
|
||||||
|
scale=scale,
|
||||||
|
dtype=dtype,
|
||||||
|
bias=bias,
|
||||||
|
input_scale=None,
|
||||||
|
scale_upper_bound=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_fp8(cls, weight, scale, input_scale, bias, dtype):
|
def from_fp8(
|
||||||
if FBGEMM_DYN_AVAILABLE:
|
cls,
|
||||||
# fbgemm needs float32 scales.
|
weight: torch.Tensor,
|
||||||
scale = scale.float()
|
scale: torch.Tensor,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> "Fp8Linear":
|
||||||
|
input_scale = kwargs.get("input_scale", None)
|
||||||
|
scale_upper_bound = kwargs.get("scale_upper_bound", None)
|
||||||
|
weight_block_size = kwargs.get("weight_block_size", None)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
qweight=weight,
|
qweight=weight,
|
||||||
scale=scale,
|
scale=scale,
|
||||||
scale_upper_bound=input_scale,
|
input_scale=input_scale,
|
||||||
|
scale_upper_bound=scale_upper_bound,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
weight_block_size=weight_block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_shared_device_identity(cls, device):
|
||||||
|
# Input scaling factors are no longer optional in _scaled_mm starting
|
||||||
|
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
||||||
|
if device not in cls._device_identity_cache:
|
||||||
|
cls._device_identity_cache[device] = torch.ones(1, device=device)
|
||||||
|
return cls._device_identity_cache[device]
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
if FBGEMM_MM_AVAILABLE:
|
if self.weight_block_size is not None:
|
||||||
qinput, scale = fp8_quantize(
|
# https://arxiv.org/pdf/2412.19437
|
||||||
input, scale_upper_bound=self.scale_upper_bound
|
# At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and
|
||||||
)
|
# scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we
|
||||||
|
# group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output
|
||||||
y = torch.ops.fbgemm.f8f8bf16_rowwise(
|
# channels).
|
||||||
|
qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1])
|
||||||
|
output = w8a8_block_fp8_matmul(
|
||||||
qinput,
|
qinput,
|
||||||
self.qweight,
|
self.qweight,
|
||||||
scale,
|
scale,
|
||||||
self.scale,
|
self.scale,
|
||||||
use_fast_accum=True,
|
self.weight_block_size,
|
||||||
bias=self.bias,
|
output_dtype=input.dtype,
|
||||||
)
|
)
|
||||||
return y.to(self.dtype)
|
|
||||||
|
|
||||||
qinput, scale = fp8_quantize(input, scalar=True)
|
if self.bias is not None:
|
||||||
output, _ = torch._scaled_mm(
|
output = output + self.bias
|
||||||
|
return output.to(dtype=input.dtype)
|
||||||
|
|
||||||
|
qinput, scale = fp8_quantize(
|
||||||
|
input,
|
||||||
|
self.input_scale,
|
||||||
|
scale_upper_bound=self.scale_upper_bound,
|
||||||
|
scalar=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
output = torch._scaled_mm(
|
||||||
qinput,
|
qinput,
|
||||||
self.qweight.t(),
|
self.qweight.t(),
|
||||||
out_dtype=self.dtype,
|
out_dtype=self.dtype,
|
||||||
@ -275,11 +443,16 @@ class Fp8Linear(torch.nn.Module):
|
|||||||
scale_b=self.scale,
|
scale_b=self.scale,
|
||||||
bias=self.bias,
|
bias=self.bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if isinstance(output, tuple) and len(output) == 2:
|
||||||
|
output = output[0]
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
|
def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
|
||||||
scale = weights.get_tensor(prefix, to_dtype=False)
|
scale = weights.get_tensor(prefix, to_dtype=False)
|
||||||
|
|
||||||
if scale.numel() > 1:
|
if scale.numel() > 1:
|
||||||
scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
|
scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
|
||||||
return scale.reshape(-1).expand(shape[0])
|
return scale.reshape(-1)
|
||||||
|
@ -1,14 +1,15 @@
|
|||||||
import os
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
from text_generation_server.utils.log import log_once
|
from text_generation_server.utils.log import log_once
|
||||||
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
||||||
|
|
||||||
|
|
||||||
|
from .hpu import QuantLinear
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GPTQWeight(Weight):
|
class GPTQWeight(Weight):
|
||||||
qweight: torch.Tensor
|
qweight: torch.Tensor
|
||||||
@ -30,13 +31,8 @@ class GPTQWeight(Weight):
|
|||||||
|
|
||||||
def get_linear(self, bias: torch.Tensor):
|
def get_linear(self, bias: torch.Tensor):
|
||||||
if self.use_awq_kernel:
|
if self.use_awq_kernel:
|
||||||
if SYSTEM == "rocm":
|
|
||||||
raise NotImplementedError(
|
|
||||||
"AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
|
|
||||||
"to use Exllama/GPTQ kernels for AWQ inference."
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
from text_generation_server.layers.awq.quantize.qmodule import WQLinear
|
from text_generation_server.layers.awq.quantize import WQLinear
|
||||||
|
|
||||||
return WQLinear(
|
return WQLinear(
|
||||||
w_bit=self.bits,
|
w_bit=self.bits,
|
||||||
@ -50,18 +46,7 @@ class GPTQWeight(Weight):
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
|
"You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
|
||||||
)
|
)
|
||||||
elif self.use_exllama:
|
|
||||||
try:
|
|
||||||
from text_generation_server.layers.gptq import ExllamaQuantLinear
|
|
||||||
except ImportError:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
|
|
||||||
)
|
|
||||||
|
|
||||||
return ExllamaQuantLinear(self, bias)
|
|
||||||
else:
|
else:
|
||||||
from text_generation_server.layers.gptq.quant_linear import QuantLinear
|
|
||||||
|
|
||||||
return QuantLinear(
|
return QuantLinear(
|
||||||
self.qweight,
|
self.qweight,
|
||||||
self.qzeros,
|
self.qzeros,
|
||||||
@ -118,23 +103,6 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||||||
else:
|
else:
|
||||||
g_idx = None
|
g_idx = None
|
||||||
|
|
||||||
from text_generation_server.layers.gptq import (
|
|
||||||
HAS_EXLLAMA,
|
|
||||||
CAN_EXLLAMA,
|
|
||||||
GPTQWeight,
|
|
||||||
)
|
|
||||||
|
|
||||||
if use_exllama:
|
|
||||||
if not HAS_EXLLAMA:
|
|
||||||
if CAN_EXLLAMA:
|
|
||||||
log_once(
|
|
||||||
logger.warning,
|
|
||||||
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
|
|
||||||
)
|
|
||||||
use_exllama = False
|
|
||||||
else:
|
|
||||||
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
|
|
||||||
|
|
||||||
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
||||||
scales = weights.get_tensor(f"{prefix}.scales")
|
scales = weights.get_tensor(f"{prefix}.scales")
|
||||||
|
|
||||||
@ -247,14 +215,7 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||||||
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
||||||
)
|
)
|
||||||
|
|
||||||
from text_generation_server.layers.gptq import HAS_EXLLAMA
|
use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act
|
||||||
|
|
||||||
use_exllama = (
|
|
||||||
self.bits == 4
|
|
||||||
and HAS_EXLLAMA
|
|
||||||
and self.quantize == "gptq"
|
|
||||||
and not self.desc_act
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.quantize == "gptq" and self.quant_method == "gptq":
|
if self.quantize == "gptq" and self.quant_method == "gptq":
|
||||||
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
|
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||||
@ -298,6 +259,7 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||||||
self._get_gptq_params(weights)
|
self._get_gptq_params(weights)
|
||||||
|
|
||||||
use_exllama = True
|
use_exllama = True
|
||||||
|
desc_act = self.desc_act
|
||||||
if self.bits != 4:
|
if self.bits != 4:
|
||||||
use_exllama = False
|
use_exllama = False
|
||||||
|
|
||||||
@ -321,7 +283,8 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||||||
if g_idx is not None:
|
if g_idx is not None:
|
||||||
if (
|
if (
|
||||||
not torch.equal(
|
not torch.equal(
|
||||||
g_idx.cpu(),
|
# Remove g_idx[0] to adapt the check with TP>1.
|
||||||
|
(g_idx - g_idx[0]).cpu(),
|
||||||
torch.tensor(
|
torch.tensor(
|
||||||
[i // self.groupsize for i in range(g_idx.shape[0])],
|
[i // self.groupsize for i in range(g_idx.shape[0])],
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
@ -332,34 +295,22 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||||||
# Exllama implementation does not support row tensor parallelism with act-order, as
|
# Exllama implementation does not support row tensor parallelism with act-order, as
|
||||||
# it would require to reorder input activations that are split unto several GPUs
|
# it would require to reorder input activations that are split unto several GPUs
|
||||||
use_exllama = False
|
use_exllama = False
|
||||||
|
desc_act = True
|
||||||
|
|
||||||
from text_generation_server.layers.gptq import (
|
from text_generation_server.layers.gptq import (
|
||||||
CAN_EXLLAMA,
|
|
||||||
HAS_EXLLAMA,
|
|
||||||
GPTQWeight,
|
GPTQWeight,
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_exllama:
|
if not desc_act and self.groupsize != -1:
|
||||||
if not HAS_EXLLAMA:
|
|
||||||
if CAN_EXLLAMA:
|
|
||||||
log_once(
|
|
||||||
logger.warning,
|
|
||||||
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
|
|
||||||
)
|
|
||||||
use_exllama = False
|
|
||||||
else:
|
|
||||||
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
|
|
||||||
|
|
||||||
if use_exllama and self.groupsize != -1:
|
|
||||||
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
|
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
|
||||||
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
|
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
|
||||||
|
if g_idx is not None:
|
||||||
|
# qzeros, scales sharded, and g_idx must be adjusted accordingly
|
||||||
|
g_idx = g_idx - g_idx[0]
|
||||||
else:
|
else:
|
||||||
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
||||||
scales = weights.get_tensor(f"{prefix}.scales")
|
scales = weights.get_tensor(f"{prefix}.scales")
|
||||||
|
|
||||||
if use_exllama and g_idx is not None:
|
|
||||||
g_idx = g_idx - g_idx[0]
|
|
||||||
|
|
||||||
if self.quantize == "gptq" and self.quant_method == "awq":
|
if self.quantize == "gptq" and self.quant_method == "awq":
|
||||||
log_once(
|
log_once(
|
||||||
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||||
@ -392,7 +343,7 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _get_gptq_params(self, weights: Weights):
|
def _get_gptq_params(self, weights: Weights):
|
||||||
if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
|
if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"):
|
||||||
self.bits = weights.get_tensor("gptq_bits").item()
|
self.bits = weights.get_tensor("gptq_bits").item()
|
||||||
self.groupsize = weights.get_tensor("gptq_groupsize").item()
|
self.groupsize = weights.get_tensor("gptq_groupsize").item()
|
||||||
self.desc_act = False
|
self.desc_act = False
|
||||||
@ -400,41 +351,7 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||||||
# before the `gptq_sym` setting tensor was added.
|
# before the `gptq_sym` setting tensor was added.
|
||||||
self.sym = (
|
self.sym = (
|
||||||
weights.get_tensor("gptq_sym").item()
|
weights.get_tensor("gptq_sym").item()
|
||||||
if weights._has_tensor("gptq_sym")
|
if weights.has_tensor("gptq_sym")
|
||||||
else False
|
else False
|
||||||
)
|
)
|
||||||
self.quant_method = "gptq"
|
self.quant_method = "gptq"
|
||||||
|
|
||||||
|
|
||||||
# Needs to be at the end because circular import.
|
|
||||||
try:
|
|
||||||
major, _minor = torch.cuda.get_device_capability()
|
|
||||||
except Exception:
|
|
||||||
major = 1
|
|
||||||
|
|
||||||
HAS_EXLLAMA = False
|
|
||||||
CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm"
|
|
||||||
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
|
|
||||||
if os.getenv("DISABLE_EXLLAMA") == "True":
|
|
||||||
HAS_EXLLAMA = False
|
|
||||||
elif CAN_EXLLAMA:
|
|
||||||
try:
|
|
||||||
if V2:
|
|
||||||
from text_generation_server.layers.gptq.exllamav2 import (
|
|
||||||
QuantLinear as ExllamaQuantLinear, # noqa: F401
|
|
||||||
create_exllama_buffers, # noqa: F401
|
|
||||||
set_device, # noqa: F401
|
|
||||||
)
|
|
||||||
|
|
||||||
HAS_EXLLAMA = "2"
|
|
||||||
else:
|
|
||||||
from text_generation_server.layers.gptq.exllama import (
|
|
||||||
Ex4bitLinear as ExllamaQuantLinear, # noqa: F401
|
|
||||||
create_exllama_buffers, # noqa: F401
|
|
||||||
set_device, # noqa: F401
|
|
||||||
)
|
|
||||||
|
|
||||||
HAS_EXLLAMA = "1"
|
|
||||||
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
@ -1,261 +0,0 @@
|
|||||||
# https://github.com/fpgaminer/GPTQ-triton
|
|
||||||
"""
|
|
||||||
Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import builtins
|
|
||||||
import math
|
|
||||||
import time
|
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
import triton
|
|
||||||
|
|
||||||
|
|
||||||
class Autotuner(triton.KernelInterface):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
fn,
|
|
||||||
arg_names,
|
|
||||||
configs,
|
|
||||||
key,
|
|
||||||
reset_to_zero,
|
|
||||||
prune_configs_by: Dict = None,
|
|
||||||
nearest_power_of_two: bool = False,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
|
||||||
'perf_model': performance model used to predicate running time with different configs, returns running time
|
|
||||||
'top_k': number of configs to bench
|
|
||||||
'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
|
|
||||||
'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results
|
|
||||||
"""
|
|
||||||
if not configs:
|
|
||||||
self.configs = [triton.Config({}, num_warps=4, num_stages=2)]
|
|
||||||
else:
|
|
||||||
self.configs = configs
|
|
||||||
self.key_idx = [arg_names.index(k) for k in key]
|
|
||||||
self.nearest_power_of_two = nearest_power_of_two
|
|
||||||
self.cache = {}
|
|
||||||
# hook to reset all required tensor to zeros before relaunching a kernel
|
|
||||||
self.hook = lambda args: 0
|
|
||||||
if reset_to_zero is not None:
|
|
||||||
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
|
|
||||||
|
|
||||||
def _hook(args):
|
|
||||||
for i in self.reset_idx:
|
|
||||||
args[i].zero_()
|
|
||||||
|
|
||||||
self.hook = _hook
|
|
||||||
self.arg_names = arg_names
|
|
||||||
# prune configs
|
|
||||||
if prune_configs_by:
|
|
||||||
perf_model, top_k = (
|
|
||||||
prune_configs_by["perf_model"],
|
|
||||||
prune_configs_by["top_k"],
|
|
||||||
)
|
|
||||||
if "early_config_prune" in prune_configs_by:
|
|
||||||
early_config_prune = prune_configs_by["early_config_prune"]
|
|
||||||
else:
|
|
||||||
perf_model, top_k, early_config_prune = None, None, None
|
|
||||||
self.perf_model, self.configs_top_k = perf_model, top_k
|
|
||||||
self.early_config_prune = early_config_prune
|
|
||||||
self.fn = fn
|
|
||||||
|
|
||||||
def _bench(self, *args, config, **meta):
|
|
||||||
# check for conflicts, i.e. meta-parameters both provided
|
|
||||||
# as kwargs and by the autotuner
|
|
||||||
conflicts = meta.keys() & config.kwargs.keys()
|
|
||||||
if conflicts:
|
|
||||||
raise ValueError(
|
|
||||||
f"Conflicting meta-parameters: {', '.join(conflicts)}."
|
|
||||||
" Make sure that you don't re-define auto-tuned symbols."
|
|
||||||
)
|
|
||||||
# augment meta-parameters with tunable ones
|
|
||||||
current = dict(meta, **config.kwargs)
|
|
||||||
|
|
||||||
def kernel_call():
|
|
||||||
if config.pre_hook:
|
|
||||||
config.pre_hook(self.nargs)
|
|
||||||
self.hook(args)
|
|
||||||
self.fn.run(
|
|
||||||
*args,
|
|
||||||
num_warps=config.num_warps,
|
|
||||||
num_stages=config.num_stages,
|
|
||||||
**current,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses
|
|
||||||
# PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default
|
|
||||||
return triton.testing.do_bench(
|
|
||||||
kernel_call, quantiles=(0.5, 0.2, 0.8), rep=40
|
|
||||||
)
|
|
||||||
except triton.OutOfResources:
|
|
||||||
return [float("inf"), float("inf"), float("inf")]
|
|
||||||
|
|
||||||
def run(self, *args, **kwargs):
    """Launch the wrapped kernel with the best known configuration.

    When more than one config is registered, the values of the key arguments
    select a cache entry; on a cache miss every pruned config is benchmarked
    with ``self._bench`` and the fastest one is stored.

    :param args: positional kernel arguments, matched against
        ``self.arg_names`` to build ``self.nargs``.
    :param kwargs: keyword arguments forwarded to the kernel (and to the
        pruning / benchmarking helpers).
    :return: whatever ``self.fn.run`` returns for the selected config.
    """
    self.nargs = dict(zip(self.arg_names, args))
    if len(self.configs) > 1:
        key = tuple(args[i] for i in self.key_idx)

        # This reduces the amount of autotuning by rounding the keys to the nearest power of two
        # In my testing this gives decent results, and greatly reduces the amount of tuning required
        if self.nearest_power_of_two:
            # log2 is undefined for values <= 0 (e.g. an empty batch), so such
            # key values are kept as-is instead of raising a math domain error.
            key = tuple(
                2 ** int(math.log2(x) + 0.5) if x > 0 else x for x in key
            )

        if key not in self.cache:
            # prune configs
            pruned_configs = self.prune_configs(kwargs)
            bench_start = time.time()
            timings = {
                config: self._bench(*args, config=config, **kwargs)
                for config in pruned_configs
            }
            bench_end = time.time()
            self.bench_time = bench_end - bench_start
            self.cache[key] = builtins.min(timings, key=timings.get)
            # Benchmarking ran the kernel several times; reset the registered
            # tensors once more before the real launch below.
            self.hook(args)
            self.configs_timings = timings
        config = self.cache[key]
    else:
        config = self.configs[0]
    self.best_config = config
    if config.pre_hook is not None:
        config.pre_hook(self.nargs)
    return self.fn.run(
        *args,
        num_warps=config.num_warps,
        num_stages=config.num_stages,
        **kwargs,
        **config.kwargs,
    )
|
|
||||||
|
|
||||||
def prune_configs(self, kwargs):
    """Return the configs worth benchmarking for the current arguments.

    First applies ``self.early_config_prune`` (if set) to ``self.configs``,
    then, when a ``self.perf_model`` is set, keeps only the ``top_k`` configs
    with the lowest estimated running time.

    :param kwargs: keyword arguments of the kernel call, forwarded to the
        performance model together with ``self.nargs`` and the config values.
    :return: the remaining configs (a list whenever any pruning was applied).
    """
    pruned_configs = self.configs
    if self.early_config_prune:
        # The pruner may be a generator function; materialise it so that the
        # `len()` below works and the result can be iterated more than once.
        pruned_configs = list(self.early_config_prune(self.configs, self.nargs))
    if self.perf_model:
        top_k = self.configs_top_k
        # A float <= 1.0 is interpreted as a fraction of all registered configs.
        if isinstance(top_k, float) and top_k <= 1.0:
            top_k = int(len(self.configs) * top_k)
        if len(pruned_configs) > top_k:
            est_timing = {
                config: self.perf_model(
                    **self.nargs,
                    **kwargs,
                    **config.kwargs,
                    num_stages=config.num_stages,
                    num_warps=config.num_warps,
                )
                for config in pruned_configs
            }
            pruned_configs = sorted(est_timing, key=est_timing.get)[:top_k]
    return pruned_configs
|
|
||||||
|
|
||||||
def warmup(self, *args, **kwargs):
    """Call ``self.fn.warmup`` once for every config kept by ``prune_configs``.

    ``self.nargs`` is populated from ``args`` for the duration of the call
    (the pruning step reads it) and reset to ``None`` afterwards.
    """
    self.nargs = dict(zip(self.arg_names, args))
    candidates = self.prune_configs(kwargs)
    for candidate in candidates:
        self.fn.warmup(
            *args,
            num_warps=candidate.num_warps,
            num_stages=candidate.num_stages,
            **kwargs,
            **candidate.kwargs,
        )
    self.nargs = None
|
|
||||||
|
|
||||||
|
|
||||||
def autotune(
    configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False
):
    """
    Decorator that wraps a :code:`triton.jit`'d function in an :code:`Autotuner`.

    .. highlight:: python
    .. code-block:: python

        @autotune(
            configs=[
                triton.Config({'BLOCK_SIZE': 128}, num_warps=4),
                triton.Config({'BLOCK_SIZE': 1024}, num_warps=8),
            ],
            key=['x_size'],  # configs are re-evaluated whenever x_size changes
        )
        @triton.jit
        def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
            ...

    :note: While the configurations are being evaluated the kernel runs several times,
        so any value it updates is updated several times as well. Use `reset_to_zero`
        to have the listed tensors zeroed before each run.

    :param configs: a list of :code:`triton.Config` objects
    :type configs: list[triton.Config]
    :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
    :type key: list[str]
    :param prune_configs_by: a dict of functions that are used to prune configs, fields:
        'perf_model': performance model used to predict running time with different configs, returns running time
        'top_k': number of configs to bench
        'early_config_prune'(optional): a function used to do early pruning (e.g. of num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
    :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
    :type reset_to_zero: list[str]
    :param nearest_power_of_two: whether to round key arguments to the nearest power of two when caching tuning results
    :type nearest_power_of_two: bool
    """

    def decorator(fn):
        # Positional order expected by Autotuner: the function, its argument
        # names, then the tuning options in the order given below.
        tuner_args = (
            fn,
            fn.arg_names,
            configs,
            key,
            reset_to_zero,
            prune_configs_by,
            nearest_power_of_two,
        )
        return Autotuner(*tuner_args)

    return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def matmul248_kernel_config_pruner(configs, nargs):
    """
    The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller.

    :param configs: iterable of configs whose ``kwargs`` contain
        ``BLOCK_SIZE_M``, ``BLOCK_SIZE_N``, ``BLOCK_SIZE_K`` and ``GROUP_SIZE_M``.
    :param nargs: mapping of kernel argument names to values; ``M``, ``N`` and
        ``K`` are read.
    :return: generator of new ``triton.Config`` objects with clamped block
        sizes; configs that become identical after clamping are yielded once.
    """

    def _block_cap(dim):
        # Smallest power of two >= dim, but never below 16. log2 is undefined
        # for dim <= 0 (e.g. an empty batch), which simply maps to the minimum.
        if dim <= 0:
            return 16
        return max(2 ** int(math.ceil(math.log2(dim))), 16)

    m = _block_cap(nargs["M"])
    n = _block_cap(nargs["N"])
    k = _block_cap(nargs["K"])

    used = set()
    for config in configs:
        block_size_m = min(m, config.kwargs["BLOCK_SIZE_M"])
        block_size_n = min(n, config.kwargs["BLOCK_SIZE_N"])
        block_size_k = min(k, config.kwargs["BLOCK_SIZE_K"])
        group_size_m = config.kwargs["GROUP_SIZE_M"]

        # Everything that distinguishes one launch configuration from another.
        signature = (
            block_size_m,
            block_size_n,
            block_size_k,
            group_size_m,
            config.num_stages,
            config.num_warps,
        )
        if signature in used:
            continue
        used.add(signature)

        yield triton.Config(
            {
                "BLOCK_SIZE_M": block_size_m,
                "BLOCK_SIZE_N": block_size_n,
                "BLOCK_SIZE_K": block_size_k,
                "GROUP_SIZE_M": group_size_m,
            },
            num_stages=config.num_stages,
            num_warps=config.num_warps,
        )
|
|
@ -1,134 +0,0 @@
|
|||||||
from text_generation_server.layers.gptq import GPTQWeight
|
|
||||||
import torch
|
|
||||||
from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params
|
|
||||||
|
|
||||||
# Placeholder passed to the C++ extension wherever a tensor argument is absent
# (e.g. no g_idx): the extension cannot receive Python's `None`.
none_tensor = torch.empty(1, 1, device="meta")
|
|
||||||
|
|
||||||
|
|
||||||
def ext_make_q4(qweight, qzeros, scales, g_idx, device):
    """Construct Q4Matrix, return handle"""
    # The extension has no notion of `None`; substitute the module-level placeholder.
    group_index = none_tensor if g_idx is None else g_idx
    return make_q4(qweight, qzeros, scales, group_index, device)
|
|
||||||
|
|
||||||
|
|
||||||
def ext_q4_matmul(x, q4, q4_width):
    """Matrix multiplication, returns x @ q4"""
    leading_dims = x.shape[:-1]
    # The kernel works on a 2-D input: collapse all leading dimensions.
    flat = x.view(-1, x.shape[-1])
    result = torch.empty(
        (flat.shape[0], q4_width), dtype=torch.float16, device=flat.device
    )

    # The extension writes the product into `result` in place.
    q4_matmul(flat, q4, result)

    return result.view(leading_dims + (q4_width,))
|
|
||||||
|
|
||||||
|
|
||||||
# Running maxima and flags collected while Ex4bitLinear layers are built; they
# size the scratch buffers allocated later by create_exllama_buffers().
MAX_DQ = MAX_INNER = 1
ACT_ORDER = False
DEVICE = None

# Scratch buffers; populated by create_exllama_buffers().
TEMP_STATE = TEMP_DQ = None
|
|
||||||
|
|
||||||
|
|
||||||
def set_device(device):
    """Record `device` in the module-level DEVICE used when allocating buffers."""
    global DEVICE
    DEVICE = device
|
|
||||||
|
|
||||||
|
|
||||||
def create_exllama_buffers(max_total_tokens: int):
    """Allocate the shared scratch buffers and hand them to the exllama extension.

    Requires DEVICE to be set (see set_device). The buffers are sized from the
    module-level MAX_INNER / MAX_DQ values and stored in TEMP_STATE / TEMP_DQ.
    """
    global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ

    assert DEVICE is not None, "call set_device first"

    # Without act-order no layer needs the reorder buffer, so a single row suffices.
    state_rows = max_total_tokens if ACT_ORDER else 1

    # This temp_state buffer is required to reorder X in the act-order case.
    temp_state = torch.zeros(
        (state_rows, MAX_INNER), dtype=torch.float16, device=DEVICE
    )
    # This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
    temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE)

    prepare_buffers(DEVICE, temp_state, temp_dq)

    # Tuning parameters passed to the extension.
    matmul_recons_thd = 8
    matmul_fused_remap = False
    matmul_no_half2 = False
    set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)

    TEMP_STATE = temp_state
    TEMP_DQ = temp_dq
|
|
||||||
|
|
||||||
|
|
||||||
class Ex4bitLinear(torch.nn.Module):
    """Linear layer implementation with per-group 4-bit quantization of the weights"""

    def __init__(self, weight: GPTQWeight, bias):
        """Build the extension-side Q4 matrix for `weight` and update the
        module-level sizing globals (DEVICE, MAX_DQ, MAX_INNER, ACT_ORDER).

        Asserts that the weight is 4-bit and lives on an indexed CUDA device.
        """
        super().__init__()
        global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE
        assert weight.bits == 4

        self.device = weight.qweight.device
        self.qweight = weight.qweight
        self.qzeros = weight.qzeros
        self.scales = weight.scales
        self.g_idx = None if weight.g_idx is None else weight.g_idx.cpu()
        self.bias = bias if bias is not None else None

        # A g_idx that is all zeros, or that is just `i // groupsize`, carries no
        # act-order information and is dropped.
        if self.g_idx is not None:
            trivial = bool((self.g_idx == 0).all())
            if not trivial:
                sequential = torch.tensor(
                    [i // weight.groupsize for i in range(weight.g_idx.shape[0])],
                    dtype=torch.int32,
                )
                trivial = torch.equal(weight.g_idx.cpu(), sequential)
            if trivial:
                self.empty_g_idx = True
                self.g_idx = None

        assert self.device.type == "cuda"
        assert self.device.index is not None

        self.q4 = ext_make_q4(
            self.qweight, self.qzeros, self.scales, self.g_idx, self.device.index
        )

        # Each int32 row of qweight packs eight 4-bit values.
        self.height = weight.qweight.shape[0] * 8
        self.width = weight.qweight.shape[1]

        # Infer groupsize from height of qzeros
        if self.qzeros.shape[0] > 1:
            self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
            assert weight.groupsize == self.groupsize
        else:
            self.groupsize = None

        # Handle act-order matrix
        self.act_order = self.g_idx is not None
        if self.act_order and self.groupsize is None:
            raise ValueError("Found group index but no groupsize. What do?")

        DEVICE = self.qweight.device

        MAX_DQ = max(MAX_DQ, self.qweight.numel() * 8)

        if self.act_order:
            MAX_INNER = max(MAX_INNER, self.height, self.width)

            ACT_ORDER = True

    def forward(self, x):
        """Return `x @ W` (plus the bias, added in place, when one is set)."""
        out = ext_q4_matmul(x, self.q4, self.width)

        if self.bias is None:
            return out
        out.add_(self.bias)
        return out
|
|
@ -1,267 +0,0 @@
|
|||||||
# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Optional
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
from text_generation_server.layers.exl2 import Exl2Weight
|
|
||||||
from text_generation_server.layers.gptq import GPTQWeight
|
|
||||||
from text_generation_server.utils.log import log_master
|
|
||||||
|
|
||||||
try:
|
|
||||||
from exllamav2.ext import exllamav2_ext
|
|
||||||
|
|
||||||
make_q_matrix = exllamav2_ext.make_q_matrix
|
|
||||||
gemm_half_q_half = exllamav2_ext.gemm_half_q_half
|
|
||||||
except ImportError:
|
|
||||||
log_master(logger.warning, "exllamav2_kernels not installed.")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Placeholder passed to the C++ extension wherever a tensor argument is absent
# (e.g. no g_idx): the extension cannot receive Python's `None`.
none_tensor = torch.empty(1, 1, device="meta")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class _ExtraTensors:
|
|
||||||
"""Additional generated quantizer tensors."""
|
|
||||||
|
|
||||||
q_group_map: Optional[torch.Tensor] = None
|
|
||||||
q_invperm: Optional[torch.Tensor] = None
|
|
||||||
q_perm: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
|
|
||||||
def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
    """Matrix multiplication, returns x @ q4"""
    leading_dims = x.shape[:-1]
    # The kernel works on a 2-D input: collapse all leading dimensions.
    flat = x.view(-1, x.shape[-1])
    result = torch.empty(
        (flat.shape[0], q4_width), dtype=torch.half, device=flat.device
    )
    # The extension writes the product into `result` in place.
    gemm_half_q_half(flat, q_handle, result, force_cuda)
    return result.view(leading_dims + (q4_width,))
|
|
||||||
|
|
||||||
|
|
||||||
def make_group_map(q_groups: torch.Tensor, num_qrows: int):
    """Build the flattened per-row group map for an EXL2 weight.

    `q_groups` is read as consecutive (bits, first_qrow) pairs. For every
    unpacked row the result holds two entries: the index of the group the row
    belongs to and the number of rows left in that group (counting the row
    itself).

    :param q_groups: 1-D tensor of (bits, first_qrow) pairs.
    :param num_qrows: total number of packed rows; closes the last group.
    :return: `torch.short` tensor on the device of `q_groups`.
    """
    entries = q_groups.tolist()
    group_count = len(entries) // 2
    mapping = []

    for group in range(group_count):
        bits = entries[2 * group]
        first_qrow = entries[2 * group + 1]
        # A group ends where the next one starts; the last one ends at num_qrows.
        if group < group_count - 1:
            end_qrow = entries[2 * group + 3]
        else:
            end_qrow = num_qrows
        # Every packed row is 32 bits wide and therefore holds 32 // bits values.
        rows = (end_qrow - first_qrow) * 32 // bits
        for remaining in range(rows, 0, -1):
            mapping.extend((group, remaining))

    return torch.tensor(mapping, dtype=torch.short, device=q_groups.device)
|
|
||||||
|
|
||||||
|
|
||||||
# Create Q matrix
|
|
||||||
|
|
||||||
|
|
||||||
def ext_make_q_matrix(
    w: Exl2Weight | GPTQWeight,
    extra: _ExtraTensors,
    temp_dq,
    key: Optional[str] = None,
):
    """
    Create Q matrix

    :param w: quantized weight, either EXL2 or GPTQ.
    :param extra: holder for the tensors generated here (`q_group_map`,
        `q_perm`, `q_invperm`); they are stored on it so that they stay
        referenced from Python.
    :param temp_dq: scratch tensor handed to the extension.
    :param key: unused.
    :return: the handle returned by the extension's `make_q_matrix`.
    :raises RuntimeError: if `w` is neither an `Exl2Weight` nor a `GPTQWeight`.
    """
    # max_dq_size = 512*(1024**2)
    # max_dq_rows = max_dq_size // out_features[0]
    max_dq_rows = 0

    # EXL2
    if isinstance(w, Exl2Weight):
        extra.q_group_map = make_group_map(w.q_groups, w.q_weight.shape[0])
        extra.q_perm = torch.argsort(w.q_invperm).short()

        return make_q_matrix(
            w.q_weight,
            extra.q_perm,
            w.q_invperm,
            w.q_scale,
            w.q_scale_max,
            w.q_groups,
            extra.q_group_map,
            none_tensor,  # zeros
            none_tensor,  # scales
            none_tensor,  # g_idx
            none_tensor,  # bias
            temp_dq,
            max_dq_rows,
        )

    # GPTQ
    if isinstance(w, GPTQWeight):
        if w.scales.dtype == torch.float:
            w.scales = w.scales.half()

        # GPTQ with g_idx (act_order)
        if w.g_idx is not None and not (w.g_idx == 0).all().item():
            extra.q_perm = torch.empty(
                (w.qweight.shape[0] * 8,),
                dtype=torch.short,
                device=w.qweight.device,
            )
            extra.q_invperm = torch.empty_like(extra.q_perm)
            q_perm, q_invperm = extra.q_perm, extra.q_invperm
            # make_q4 segfaults if g_idx is not on cpu in the act-order case.
            g_idx = w.g_idx.cpu()
        # GPTQ without g_idx
        else:
            # In the non act-order case, None needs to be passed for g_idx.
            q_perm = q_invperm = g_idx = none_tensor

        return make_q_matrix(
            w.qweight,
            q_perm,
            q_invperm,
            none_tensor,  # q_scale
            none_tensor,  # q_scale_max
            none_tensor,  # q_groups
            none_tensor,  # q_group_map
            w.qzeros,
            w.scales,
            g_idx,
            none_tensor,  # bias
            temp_dq,
            max_dq_rows,
        )

    # Previously the exception was constructed but never raised, so callers
    # silently received None as the handle.
    raise RuntimeError("Cannot create handle")
|
|
||||||
|
|
||||||
|
|
||||||
# Device used for the shared scratch space; set through set_device().
DEVICE = None
# Every QuantLinear instance appends itself here on construction.
LAYERS = list()
|
|
||||||
|
|
||||||
|
|
||||||
def set_device(device):
    """Record `device` in the module-level DEVICE used when allocating scratch space."""
    global DEVICE
    DEVICE = device
|
|
||||||
|
|
||||||
|
|
||||||
def create_exllama_buffers(max_total_tokens: int):
    """Create one scratch area shared by all registered layers and finish
    their initialisation by calling `post_init` on each of them.

    :param max_total_tokens: used as `max_input_len` (with a batch size of 1)
        when asking each layer for its scratch requirement.
    """
    global LAYERS, DEVICE

    # No need to initialize scratch space if there are no layers
    # that use ExLLamav2.
    if not len(LAYERS):
        return

    # Find the size of the scratch space: the largest requirement of any layer.
    requirements = [
        layer.scratch_space_fixed(max_input_len=max_total_tokens, max_batch_size=1)
        for layer in LAYERS
    ]
    scratch_bytes = max(requirements)
    temp_dq = ExLlamaV2DeviceTensors(DEVICE, scratch_bytes)

    for layer in LAYERS:
        layer.post_init(temp_dq)
|
|
||||||
|
|
||||||
|
|
||||||
class QuantLinear(nn.Module):
    """Linear layer implementation with per-group 4-bit quantization of the weights"""

    QUANT_TYPE = "exllamav2"

    def __init__(
        self,
        weight: Exl2Weight | GPTQWeight,
        bias: torch.Tensor,
    ):
        """Record the layer geometry and register the layer in LAYERS.

        The extension handle is not created here; that happens in `post_init`
        once the shared scratch space exists.

        :raises ValueError: for a GPTQ weight that is not 4-bit.
        :raises TypeError: if `weight` is neither EXL2 nor GPTQ.
        """
        super().__init__()

        self.q_handle = None
        self.q_tensors = weight
        self.extra_tensors = _ExtraTensors()

        if isinstance(weight, Exl2Weight):
            self.infeatures = weight.q_invperm.shape[0]
            self.outfeatures = weight.q_weight.shape[1]
        elif isinstance(weight, GPTQWeight):
            if weight.bits != 4:
                raise ValueError(
                    f"Exllamav2 kernel supports only bits=4, requested bits={weight.bits}. Something is wrong in the model initialization."
                )

            self.infeatures = weight.qweight.shape[0] // weight.bits * 32
            self.outfeatures = weight.qweight.shape[1]
        else:
            # Fail with a clear message instead of an AttributeError on
            # `self.outfeatures` a few lines further down.
            raise TypeError(
                f"Unsupported weight type for exllamav2: {type(weight).__name__}"
            )

        # Round the output width up to the next multiple of 32.
        self.padding = -self.outfeatures % 32
        self.outfeatures = self.outfeatures + self.padding

        self.device = weight.device
        self.bias = bias if bias is not None else None

        # Registered so that create_exllama_buffers() can size the scratch
        # space and call post_init() on this layer.
        LAYERS.append(self)

    def post_init(self, temp_dq):
        """Create the extension handle using a slice of the shared scratch space.

        :param temp_dq: object providing `get_scratch_slice(size_bytes)`.
        """
        device = self.q_tensors.device
        assert device.type == "cuda"
        assert device.index is not None
        temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())

        # We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us,
        # and `Memory access fault by GPU node-2` will EAT you.
        self.temp_dq = temp_dq
        self.q_handle = ext_make_q_matrix(self.q_tensors, self.extra_tensors, temp_dq)

    def forward(self, x, force_cuda=False):
        """Return `x @ W` (plus the bias, added in place, when one is set)."""
        output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)

        if self.bias is not None:
            output.add_(self.bias)
        return output

    def temp_dq_size(self):
        """Scratch bytes requested for dequantization: 2 bytes per weight element plus 128."""
        return self.infeatures * self.outfeatures * 2 + 128

    def temp_fwd_size(self, max_input_len, max_batch_size):
        """Scratch bytes requested for the forward pass at the given input limits."""
        return self.outfeatures * max_input_len * max_batch_size * 4 + 128

    def scratch_space_fixed(self, max_input_len, max_batch_size):
        """Total scratch bytes this layer asks for: dequantization plus forward."""
        return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)
|
|
||||||
|
|
||||||
|
|
||||||
class ExLlamaV2DeviceTensors:
    """Lazily allocated half-precision scratch buffer on one device.

    Note that every slice handed out starts at offset 0, so all users share
    the same memory.
    """

    # `device` (set in __init__) is whatever the caller passes as the torch device.
    scratch_bytes: int
    scratch: Optional[torch.Tensor] = None

    def __init__(self, device, scratch_bytes):
        """
        :param device: device on which the scratch buffer will be allocated.
        :param scratch_bytes: size of the scratch buffer in bytes.
        """
        self.device = device
        self.scratch_bytes = scratch_bytes

    def prepare(self):
        """Allocate the scratch buffer (uninitialised memory)."""
        # torch.half is 2 bytes wide, hence half as many elements as bytes.
        self.scratch = torch.empty(
            (self.scratch_bytes // 2,), dtype=torch.half, device=self.device
        )

    def get_scratch_slice(self, size_bytes):
        """Return a view on the first `size_bytes` bytes of the scratch buffer,
        with the size rounded up to a multiple of 128 bytes.

        The buffer is allocated on first use.
        """
        if self.scratch is None:
            self.prepare()

        size_bytes = ((size_bytes + 127) // 128) * 128
        size_half = size_bytes // 2
        scratch_slice = self.scratch.narrow(0, 0, size_half)
        return scratch_slice
|
|
186
backends/gaudi/server/text_generation_server/layers/gptq/hpu.py
Normal file
186
backends/gaudi/server/text_generation_server/layers/gptq/hpu.py
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
# Resolve the HPU dequantization op once at import time. If the lookup fails,
# `convert_from_uint4` is bound to a stub that raises on first use, so merely
# importing this module never fails.
try:
    convert_from_uint4 = torch.ops.hpu.convert_from_uint4
except Exception as lookup_error:
    hpu_import_exception = lookup_error

    def error_raiser_hpu(*args, **kwargs):
        raise ValueError(
            f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}"
        )

    convert_from_uint4 = error_raiser_hpu
|
||||||
|
|
||||||
|
|
||||||
|
def pack_tensor(input, bits=4):
    """Pack groups of `32 // bits` neighbouring columns of `input` into int32 words.

    Column `c` of the result combines input columns
    `c * (32 // bits) ... (c + 1) * (32 // bits) - 1`, the first of them in the
    lowest `bits` bits. The result is always created on the default device.

    :param input: 2-D tensor of values that fit into `bits` bits.
    :param bits: width of one packed value.
    :return: int32 tensor of shape `(rows, cols // 32 * bits)`.
    """
    values = input.to(torch.int32)
    rows, cols = values.shape[0], values.shape[1]
    packed = torch.zeros((rows, cols // 32 * bits), dtype=torch.int32)
    per_word = 32 // bits
    for word in range(packed.shape[1]):
        first = word * per_word
        for slot in range(per_word):
            packed[:, word] |= values[:, first + slot] << (bits * slot)
    return packed.to(torch.int32)
|
||||||
|
|
||||||
|
|
||||||
|
class QuantLinear(nn.Module):
    """GPTQ 4-bit linear layer that dequantizes with `convert_from_uint4`.

    On construction the packed weights and zero points are unpacked and
    re-packed with `pack_tensor` (see `_preprocessing`).
    """

    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
        """
        :param qweight: packed int32 weights, `32 // bits` values per word along dim 0.
        :param qzeros: packed int32 zero points, `32 // bits` values per word along dim 1.
        :param scales: per-group scales.
        :param g_idx: group index of every input feature; must be the trivial
            `i // groupsize` mapping.
        :param bias: optional bias tensor.
        :param bits: quantization width; only 4 is accepted.
        :param groupsize: number of input features per quantization group.
        :raises NotImplementedError: if `bits` is not 4.
        """
        super().__init__()
        for buffer_name, tensor in (
            ("qweight", qweight),
            ("qzeros", qzeros),
            ("scales", scales),
            ("g_idx", g_idx),
        ):
            self.register_buffer(buffer_name, tensor)
        if bias is None:
            self.bias = None
        else:
            self.register_buffer("bias", bias)
        if bits not in [4]:
            raise NotImplementedError("Only 4 bits are supported.")
        self.bits = bits
        self.maxq = 2**self.bits - 1
        self.groupsize = groupsize

        self.outfeatures = qweight.shape[1]
        self.infeatures = qweight.shape[0] * 32 // bits
        # Bit offset of every packed value inside one int32 word, shape (1, 32 // bits).
        self.wf = torch.arange(0, 32, self.bits, dtype=torch.int32).unsqueeze(0)
        self._preprocessing()

    def unpack_zeros_from_cuda_old_format(self):
        """Unpack `self.qzeros`, add one to every value (modulo `2**bits`) and
        return the result in the dtype of `self.scales`."""
        per_word = 32 // self.bits
        narrow_dtype = torch.int16 if self.bits == 8 else torch.int8
        shifted = torch.bitwise_right_shift(
            torch.unsqueeze(self.qzeros, 2).expand(-1, -1, per_word),
            self.wf.unsqueeze(0),
        ).to(narrow_dtype)

        shifted = shifted + 1
        # NOTE: It appears that casting here after the `zeros = zeros + 1` is important.
        masked = torch.bitwise_and(shifted, (2**self.bits) - 1).to(self.scales.dtype)
        return masked.reshape(-1, masked.shape[1] * masked.shape[2])

    def unpack_weight_from_cuda_old_format(self):
        """Unpack `self.qweight` into one small integer per weight, shape
        `(qweight.shape[0] * 32 // bits, qweight.shape[1])`."""
        per_word = 32 // self.bits
        narrow_dtype = torch.int16 if self.bits == 8 else torch.int8
        shifted = torch.bitwise_right_shift(
            torch.unsqueeze(self.qweight, 1).expand(-1, per_word, -1),
            self.wf.unsqueeze(-1),
        ).to(narrow_dtype)
        masked = torch.bitwise_and(shifted, (2**self.bits) - 1)
        return masked.reshape((masked.shape[0] * masked.shape[1], masked.shape[2]))

    def _preprocessing(self):
        """Re-pack weights and zero points with `pack_tensor` and check `g_idx`.

        Unpacking and packing run on CPU; the results are moved back to the
        device the weights originally lived on.
        """
        orig_device = self.qweight.device

        self.qweight = self.qweight.cpu()
        unpacked_weight = self.unpack_weight_from_cuda_old_format()
        self.qweight = pack_tensor(unpacked_weight).to(orig_device)

        # TODO: Support group indexing and remove the check
        columns = self.qweight.shape[0]
        g_idx_trivial = torch.tensor(
            [i // self.groupsize for i in range(columns)],
            dtype=torch.int32,
            device=self.g_idx.device,
        )
        assert torch.equal(
            self.g_idx, g_idx_trivial
        ), "Non-trivial tensor g_idx is not supported"

        self.qzeros = self.qzeros.cpu()
        unpacked_zeros = self.unpack_zeros_from_cuda_old_format()
        self.qzeros = pack_tensor(unpacked_zeros).to(orig_device)

    @classmethod
    def new(cls, bits, groupsize, infeatures, outfeatures, bias):
        """Create a layer of the given geometry with all-zero tensors.

        :param bias: truthy to allocate a zero bias, falsy for no bias.
        :raises NotImplementedError: if `bits` is not 4.
        """
        if bits not in [4]:
            raise NotImplementedError("Only 4 bits are supported.")

        group_count = math.ceil(infeatures / groupsize)
        qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
        qzeros = torch.zeros(
            (group_count, outfeatures // 32 * bits), dtype=torch.int32
        )
        scales = torch.zeros((group_count, outfeatures), dtype=torch.float16)
        g_idx = torch.tensor(
            [i // groupsize for i in range(infeatures)], dtype=torch.int32
        )
        bias = torch.zeros((outfeatures), dtype=torch.float16) if bias else None
        return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)

    def pack(self, linear, scales, zeros, g_idx=None):
        """Quantize the weights of `linear` with the given scales and zero
        points and store the packed tensors on this layer.

        NOTE(review): the packed layout written here matches what `__init__`
        receives, and `_preprocessing` is not re-run afterwards — confirm
        before relying on `forward` after `pack`.
        """
        self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx

        scales = scales.t().contiguous()
        zeros = zeros.t().contiguous()
        scale_zeros = zeros * scales
        self.scales = scales.clone().half()
        if linear.bias is not None:
            self.bias = linear.bias.clone().half()

        # One quantized column per input feature, using the scale and zero
        # point of the group that feature belongs to.
        quantized_columns = [
            torch.round(
                (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
                / self.scales[self.g_idx[idx]]
            ).to(torch.int)[:, None]
            for idx in range(self.infeatures)
        ]
        intweight = torch.cat(quantized_columns, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.numpy().astype(np.uint32)

        qweight = np.zeros(
            (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
        )
        for row in range(qweight.shape[0]):
            if self.bits not in [4]:
                raise NotImplementedError("Only 4 bits are supported.")
            per_word = 32 // self.bits
            first = row * per_word
            for slot in range(per_word):
                qweight[row] |= intweight[first + slot] << (self.bits * slot)

        qweight = qweight.astype(np.int32)
        self.qweight = torch.from_numpy(qweight)

        # Zero points are stored decremented by one.
        zeros -= 1
        zeros = zeros.numpy().astype(np.uint32)
        qzeros = np.zeros(
            (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
        )
        for col in range(qzeros.shape[1]):
            if self.bits not in [4]:
                raise NotImplementedError("Only 4 bits are supported.")
            per_word = 32 // self.bits
            first = col * per_word
            for slot in range(per_word):
                qzeros[:, col] |= zeros[:, first + slot] << (self.bits * slot)

        qzeros = qzeros.astype(np.int32)
        self.qzeros = torch.from_numpy(qzeros)

    def forward(self, x):
        """Dequantize the full weight matrix in the dtype of `x` and return
        `x @ W`, plus the bias when one is set."""
        out_shape = x.shape[:-1] + (self.outfeatures,)
        flat = x.reshape(-1, x.shape[-1])
        weight = convert_from_uint4(self.qweight, self.scales, self.qzeros, flat.dtype)
        out = torch.matmul(flat, weight).reshape(out_shape)
        if self.bias is not None:
            out = out + self.bias
        return out
|
@ -1,359 +0,0 @@
|
|||||||
import math
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.cuda.amp import custom_fwd
|
|
||||||
|
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
from . import custom_autotune
|
|
||||||
|
|
||||||
|
|
||||||
# code based https://github.com/fpgaminer/GPTQ-triton
|
|
||||||
@custom_autotune.autotune(
|
|
||||||
configs=[
|
|
||||||
triton.Config(
|
|
||||||
{
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 256,
|
|
||||||
"BLOCK_SIZE_K": 32,
|
|
||||||
"GROUP_SIZE_M": 8,
|
|
||||||
},
|
|
||||||
num_stages=4,
|
|
||||||
num_warps=4,
|
|
||||||
),
|
|
||||||
triton.Config(
|
|
||||||
{
|
|
||||||
"BLOCK_SIZE_M": 128,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 32,
|
|
||||||
"GROUP_SIZE_M": 8,
|
|
||||||
},
|
|
||||||
num_stages=4,
|
|
||||||
num_warps=4,
|
|
||||||
),
|
|
||||||
triton.Config(
|
|
||||||
{
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 32,
|
|
||||||
"GROUP_SIZE_M": 8,
|
|
||||||
},
|
|
||||||
num_stages=4,
|
|
||||||
num_warps=4,
|
|
||||||
),
|
|
||||||
triton.Config(
|
|
||||||
{
|
|
||||||
"BLOCK_SIZE_M": 128,
|
|
||||||
"BLOCK_SIZE_N": 32,
|
|
||||||
"BLOCK_SIZE_K": 32,
|
|
||||||
"GROUP_SIZE_M": 8,
|
|
||||||
},
|
|
||||||
num_stages=4,
|
|
||||||
num_warps=4,
|
|
||||||
),
|
|
||||||
triton.Config(
|
|
||||||
{
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 64,
|
|
||||||
"BLOCK_SIZE_K": 32,
|
|
||||||
"GROUP_SIZE_M": 8,
|
|
||||||
},
|
|
||||||
num_stages=4,
|
|
||||||
num_warps=4,
|
|
||||||
),
|
|
||||||
triton.Config(
|
|
||||||
{
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 32,
|
|
||||||
"GROUP_SIZE_M": 8,
|
|
||||||
},
|
|
||||||
num_stages=2,
|
|
||||||
num_warps=8,
|
|
||||||
),
|
|
||||||
triton.Config(
|
|
||||||
{
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 64,
|
|
||||||
"BLOCK_SIZE_K": 64,
|
|
||||||
"GROUP_SIZE_M": 8,
|
|
||||||
},
|
|
||||||
num_stages=3,
|
|
||||||
num_warps=8,
|
|
||||||
),
|
|
||||||
triton.Config(
|
|
||||||
{
|
|
||||||
"BLOCK_SIZE_M": 32,
|
|
||||||
"BLOCK_SIZE_N": 32,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 8,
|
|
||||||
},
|
|
||||||
num_stages=2,
|
|
||||||
num_warps=4,
|
|
||||||
),
|
|
||||||
],
|
|
||||||
key=["M", "N", "K"],
|
|
||||||
nearest_power_of_two=True,
|
|
||||||
prune_configs_by={
|
|
||||||
"early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
|
|
||||||
"perf_model": None,
|
|
||||||
"top_k": None,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
@triton.jit
|
|
||||||
def matmul_248_kernel(
|
|
||||||
a_ptr,
|
|
||||||
b_ptr,
|
|
||||||
c_ptr,
|
|
||||||
scales_ptr,
|
|
||||||
zeros_ptr,
|
|
||||||
g_ptr,
|
|
||||||
M,
|
|
||||||
N,
|
|
||||||
K,
|
|
||||||
bits,
|
|
||||||
maxq,
|
|
||||||
stride_am,
|
|
||||||
stride_ak,
|
|
||||||
stride_bk,
|
|
||||||
stride_bn,
|
|
||||||
stride_cm,
|
|
||||||
stride_cn,
|
|
||||||
stride_scales,
|
|
||||||
stride_zeros,
|
|
||||||
BLOCK_SIZE_M: tl.constexpr,
|
|
||||||
BLOCK_SIZE_N: tl.constexpr,
|
|
||||||
BLOCK_SIZE_K: tl.constexpr,
|
|
||||||
GROUP_SIZE_M: tl.constexpr,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Compute the matrix multiplication C = A x B.
|
|
||||||
A is of shape (M, K) float16
|
|
||||||
B is of shape (K//8, N) int32
|
|
||||||
C is of shape (M, N) float16
|
|
||||||
scales is of shape (G, N) float16
|
|
||||||
zeros is of shape (G, N) float16
|
|
||||||
g_ptr is of shape (K) int32
|
|
||||||
"""
|
|
||||||
infearure_per_bits = 32 // bits
|
|
||||||
|
|
||||||
pid = tl.program_id(axis=0)
|
|
||||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
|
||||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
|
||||||
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
|
|
||||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
|
||||||
group_id = pid // num_pid_in_group
|
|
||||||
first_pid_m = group_id * GROUP_SIZE_M
|
|
||||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
|
||||||
pid_m = first_pid_m + (pid % group_size_m)
|
|
||||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
|
||||||
|
|
||||||
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
|
||||||
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
|
||||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
|
||||||
a_ptrs = a_ptr + (
|
|
||||||
offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
|
|
||||||
) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
|
||||||
a_mask = offs_am[:, None] < M
|
|
||||||
# b_ptrs is set up such that it repeats elements along the K axis 8 times
|
|
||||||
b_ptrs = b_ptr + (
|
|
||||||
(offs_k[:, None] // infearure_per_bits) * stride_bk
|
|
||||||
+ offs_bn[None, :] * stride_bn
|
|
||||||
) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
|
|
||||||
g_ptrs = g_ptr + offs_k
|
|
||||||
# shifter is used to extract the N bits of each element in the 32-bit word from B
|
|
||||||
scales_ptrs = scales_ptr + offs_bn[None, :]
|
|
||||||
zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
|
|
||||||
|
|
||||||
shifter = (offs_k % infearure_per_bits) * bits
|
|
||||||
zeros_shifter = (offs_bn % infearure_per_bits) * bits
|
|
||||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
|
||||||
|
|
||||||
for k in range(0, num_pid_k):
|
|
||||||
g_idx = tl.load(g_ptrs)
|
|
||||||
|
|
||||||
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
|
|
||||||
scales = tl.load(
|
|
||||||
scales_ptrs + g_idx[:, None] * stride_scales
|
|
||||||
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
|
||||||
zeros = tl.load(
|
|
||||||
zeros_ptrs + g_idx[:, None] * stride_zeros
|
|
||||||
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
|
||||||
|
|
||||||
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
|
||||||
zeros = (zeros + 1) & maxq # eventually avoid overflow
|
|
||||||
|
|
||||||
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
|
||||||
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
|
||||||
|
|
||||||
# Now we need to unpack b (which is N-bit values) into 32-bit values
|
|
||||||
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
|
|
||||||
b = (b - zeros) * scales # Scale and shift
|
|
||||||
|
|
||||||
accumulator += tl.dot(a, b)
|
|
||||||
a_ptrs += BLOCK_SIZE_K
|
|
||||||
b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
|
|
||||||
g_ptrs += BLOCK_SIZE_K
|
|
||||||
|
|
||||||
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
|
|
||||||
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
|
|
||||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
|
||||||
|
|
||||||
|
|
||||||
def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
|
|
||||||
with torch.cuda.device(input.device):
|
|
||||||
output = torch.empty(
|
|
||||||
(input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16
|
|
||||||
)
|
|
||||||
|
|
||||||
def grid(META):
|
|
||||||
return (
|
|
||||||
triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
|
|
||||||
* triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
|
|
||||||
)
|
|
||||||
|
|
||||||
matmul_248_kernel[grid](
|
|
||||||
input,
|
|
||||||
qweight,
|
|
||||||
output,
|
|
||||||
scales,
|
|
||||||
qzeros,
|
|
||||||
g_idx,
|
|
||||||
input.shape[0],
|
|
||||||
qweight.shape[1],
|
|
||||||
input.shape[1],
|
|
||||||
bits,
|
|
||||||
maxq,
|
|
||||||
input.stride(0),
|
|
||||||
input.stride(1),
|
|
||||||
qweight.stride(0),
|
|
||||||
qweight.stride(1),
|
|
||||||
output.stride(0),
|
|
||||||
output.stride(1),
|
|
||||||
scales.stride(0),
|
|
||||||
qzeros.stride(0),
|
|
||||||
)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class QuantLinearFunction(torch.autograd.Function):
|
|
||||||
@staticmethod
|
|
||||||
@custom_fwd(cast_inputs=torch.float16)
|
|
||||||
def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
|
|
||||||
output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class QuantLinear(nn.Module):
|
|
||||||
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
|
|
||||||
super().__init__()
|
|
||||||
self.register_buffer("qweight", qweight)
|
|
||||||
self.register_buffer("qzeros", qzeros)
|
|
||||||
self.register_buffer("scales", scales)
|
|
||||||
self.register_buffer("g_idx", g_idx)
|
|
||||||
if bias is not None:
|
|
||||||
self.register_buffer("bias", bias)
|
|
||||||
else:
|
|
||||||
self.bias = None
|
|
||||||
if bits not in [2, 4, 8]:
|
|
||||||
raise NotImplementedError("Only 2,4,8 bits are supported.")
|
|
||||||
self.bits = bits
|
|
||||||
self.maxq = 2**self.bits - 1
|
|
||||||
self.groupsize = groupsize
|
|
||||||
|
|
||||||
self.outfeatures = qweight.shape[1]
|
|
||||||
self.infeatures = qweight.shape[0] * 32 // bits
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
|
|
||||||
if bits not in [2, 4, 8]:
|
|
||||||
raise NotImplementedError("Only 2,4,8 bits are supported.")
|
|
||||||
|
|
||||||
qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
|
|
||||||
qzeros = torch.zeros(
|
|
||||||
(math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
|
|
||||||
dtype=torch.int32,
|
|
||||||
)
|
|
||||||
scales = torch.zeros(
|
|
||||||
(math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
|
|
||||||
)
|
|
||||||
g_idx = torch.tensor(
|
|
||||||
[i // groupsize for i in range(infeatures)], dtype=torch.int32
|
|
||||||
)
|
|
||||||
if bias:
|
|
||||||
bias = torch.zeros((outfeatures), dtype=torch.float16)
|
|
||||||
else:
|
|
||||||
bias = None
|
|
||||||
return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
|
|
||||||
|
|
||||||
def pack(self, linear, scales, zeros, g_idx=None):
|
|
||||||
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
|
|
||||||
|
|
||||||
scales = scales.t().contiguous()
|
|
||||||
zeros = zeros.t().contiguous()
|
|
||||||
scale_zeros = zeros * scales
|
|
||||||
self.scales = scales.clone().half()
|
|
||||||
if linear.bias is not None:
|
|
||||||
self.bias = linear.bias.clone().half()
|
|
||||||
|
|
||||||
intweight = []
|
|
||||||
for idx in range(self.infeatures):
|
|
||||||
intweight.append(
|
|
||||||
torch.round(
|
|
||||||
(linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
|
|
||||||
/ self.scales[self.g_idx[idx]]
|
|
||||||
).to(torch.int)[:, None]
|
|
||||||
)
|
|
||||||
intweight = torch.cat(intweight, dim=1)
|
|
||||||
intweight = intweight.t().contiguous()
|
|
||||||
intweight = intweight.numpy().astype(np.uint32)
|
|
||||||
qweight = np.zeros(
|
|
||||||
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
|
|
||||||
)
|
|
||||||
i = 0
|
|
||||||
row = 0
|
|
||||||
while row < qweight.shape[0]:
|
|
||||||
if self.bits in [2, 4, 8]:
|
|
||||||
for j in range(i, i + (32 // self.bits)):
|
|
||||||
qweight[row] |= intweight[j] << (self.bits * (j - i))
|
|
||||||
i += 32 // self.bits
|
|
||||||
row += 1
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("Only 2,4,8 bits are supported.")
|
|
||||||
|
|
||||||
qweight = qweight.astype(np.int32)
|
|
||||||
self.qweight = torch.from_numpy(qweight)
|
|
||||||
|
|
||||||
zeros -= 1
|
|
||||||
zeros = zeros.numpy().astype(np.uint32)
|
|
||||||
qzeros = np.zeros(
|
|
||||||
(zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
|
|
||||||
)
|
|
||||||
i = 0
|
|
||||||
col = 0
|
|
||||||
while col < qzeros.shape[1]:
|
|
||||||
if self.bits in [2, 4, 8]:
|
|
||||||
for j in range(i, i + (32 // self.bits)):
|
|
||||||
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
|
|
||||||
i += 32 // self.bits
|
|
||||||
col += 1
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("Only 2,4,8 bits are supported.")
|
|
||||||
|
|
||||||
qzeros = qzeros.astype(np.int32)
|
|
||||||
self.qzeros = torch.from_numpy(qzeros)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out_shape = x.shape[:-1] + (self.outfeatures,)
|
|
||||||
out = QuantLinearFunction.apply(
|
|
||||||
x.reshape(-1, x.shape[-1]),
|
|
||||||
self.qweight,
|
|
||||||
self.scales,
|
|
||||||
self.qzeros,
|
|
||||||
self.g_idx,
|
|
||||||
self.bits,
|
|
||||||
self.maxq,
|
|
||||||
)
|
|
||||||
out = out + self.bias if self.bias is not None else out
|
|
||||||
return out.reshape(out_shape)
|
|
@ -12,7 +12,7 @@ from huggingface_hub import HfApi
|
|||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
from text_generation_server.utils import initialize_torch_distributed, Weights
|
from text_generation_server.utils import initialize_torch_distributed, Weights
|
||||||
from text_generation_server.utils.hub import weight_files
|
from text_generation_server.utils.hub import weight_files
|
||||||
from text_generation_server.layers.gptq.quant_linear import QuantLinear
|
from text_generation_server.layers.gptq import QuantLinear
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from text_generation_server.layers.gptq.utils import torch_snr_error
|
from text_generation_server.layers.gptq.utils import torch_snr_error
|
||||||
@ -956,15 +956,24 @@ def quantize(
|
|||||||
|
|
||||||
pack(model, quantizers, bits, groupsize)
|
pack(model, quantizers, bits, groupsize)
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
from transformers.modeling_utils import shard_checkpoint
|
from huggingface_hub import split_torch_state_dict_into_shards
|
||||||
|
|
||||||
state_dict = model.state_dict()
|
state_dict = model.state_dict()
|
||||||
state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
|
state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
|
||||||
|
|
||||||
max_shard_size = "10GB"
|
max_shard_size = "10GB"
|
||||||
shards, index = shard_checkpoint(
|
state_dict_split = split_torch_state_dict_into_shards(
|
||||||
state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors"
|
state_dict,
|
||||||
|
filename_pattern="model.safetensors",
|
||||||
|
max_shard_size=max_shard_size,
|
||||||
)
|
)
|
||||||
|
index = None
|
||||||
|
if state_dict_split.is_sharded:
|
||||||
|
index = {
|
||||||
|
"metadata": state_dict_split.metadata,
|
||||||
|
"weight_map": state_dict_split.tensor_to_filename,
|
||||||
|
}
|
||||||
|
shards = state_dict_split.filename_to_tensors
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
for shard_file, shard in shards.items():
|
for shard_file, shard in shards.items():
|
||||||
save_file(
|
save_file(
|
||||||
|
@ -1,9 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
from text_generation_server.utils.import_utils import (
|
|
||||||
SYSTEM,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Monkey patching
|
# Monkey patching
|
||||||
@ -33,46 +30,6 @@ def load_layer_norm_no_bias(cls, prefix, weights, eps):
|
|||||||
torch.nn.LayerNorm.load = load_layer_norm
|
torch.nn.LayerNorm.load = load_layer_norm
|
||||||
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
|
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
|
||||||
|
|
||||||
if SYSTEM == "cuda":
|
|
||||||
import dropout_layer_norm
|
|
||||||
|
|
||||||
class FastLayerNorm(nn.LayerNorm):
|
|
||||||
def forward(self, hidden_states, residual=None):
|
|
||||||
if hidden_states.shape[-1] > 8192:
|
|
||||||
if residual is not None:
|
|
||||||
hidden_states += residual
|
|
||||||
residual = hidden_states
|
|
||||||
|
|
||||||
return super(FastLayerNorm, self).forward(hidden_states), residual
|
|
||||||
else:
|
|
||||||
(
|
|
||||||
normed_hidden_states,
|
|
||||||
residual,
|
|
||||||
*rest,
|
|
||||||
) = dropout_layer_norm.dropout_add_ln_fwd(
|
|
||||||
hidden_states,
|
|
||||||
residual,
|
|
||||||
self.weight,
|
|
||||||
self.bias,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
0.0,
|
|
||||||
self.eps,
|
|
||||||
1.0,
|
|
||||||
0,
|
|
||||||
None,
|
|
||||||
False,
|
|
||||||
False,
|
|
||||||
)
|
|
||||||
if residual is None:
|
|
||||||
residual = hidden_states
|
|
||||||
|
|
||||||
return normed_hidden_states, residual
|
|
||||||
|
|
||||||
elif SYSTEM == "rocm":
|
|
||||||
from vllm._C import ops
|
|
||||||
|
|
||||||
class FastLayerNorm(nn.LayerNorm):
|
class FastLayerNorm(nn.LayerNorm):
|
||||||
def forward(self, hidden_states, residual=None):
|
def forward(self, hidden_states, residual=None):
|
||||||
@ -82,21 +39,6 @@ elif SYSTEM == "rocm":
|
|||||||
|
|
||||||
return super().forward(hidden_states), residual
|
return super().forward(hidden_states), residual
|
||||||
|
|
||||||
elif SYSTEM == "ipex":
|
|
||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
|
|
||||||
class FastLayerNorm(nn.LayerNorm):
|
|
||||||
def forward(self, hidden_states, residual=None):
|
|
||||||
out = ipex.llm.functional.add_layer_norm(
|
|
||||||
residual,
|
|
||||||
hidden_states,
|
|
||||||
self.weight,
|
|
||||||
self.bias,
|
|
||||||
self.eps,
|
|
||||||
residual is not None,
|
|
||||||
)
|
|
||||||
return out, residual if residual is not None else hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class FastRMSNorm(nn.Module):
|
class FastRMSNorm(nn.Module):
|
||||||
def __init__(self, weight: torch.Tensor, eps: float):
|
def __init__(self, weight: torch.Tensor, eps: float):
|
||||||
@ -111,74 +53,15 @@ class FastRMSNorm(nn.Module):
|
|||||||
return cls(weight, eps)
|
return cls(weight, eps)
|
||||||
|
|
||||||
def forward(self, hidden_states, residual=None):
|
def forward(self, hidden_states, residual=None):
|
||||||
if SYSTEM == "ipex":
|
from vllm_hpu_extension.kernels import rms_norm
|
||||||
out = ipex.llm.functional.add_rms_norm(
|
|
||||||
residual,
|
orig_shape = hidden_states.shape
|
||||||
hidden_states,
|
|
||||||
self.weight,
|
|
||||||
None,
|
|
||||||
self.variance_epsilon,
|
|
||||||
residual is not None,
|
|
||||||
)
|
|
||||||
return out, residual if residual is not None else hidden_states
|
|
||||||
elif hidden_states.shape[-1] > 8192:
|
|
||||||
if residual is not None:
|
if residual is not None:
|
||||||
hidden_states += residual
|
residual += hidden_states.view(residual.shape)
|
||||||
residual = hidden_states
|
|
||||||
|
|
||||||
hidden_states = hidden_states.to(torch.float32)
|
|
||||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
|
||||||
hidden_states = hidden_states * torch.rsqrt(
|
|
||||||
variance + self.variance_epsilon
|
|
||||||
)
|
|
||||||
|
|
||||||
# convert into half-precision if necessary
|
|
||||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
|
||||||
hidden_states = hidden_states.to(self.weight.dtype)
|
|
||||||
|
|
||||||
return self.weight * hidden_states, residual
|
|
||||||
elif SYSTEM == "cuda":
|
|
||||||
# faster post attention rms norm
|
|
||||||
(
|
|
||||||
normed_hidden_states,
|
|
||||||
res,
|
|
||||||
*rest,
|
|
||||||
) = dropout_layer_norm.dropout_add_ln_fwd(
|
|
||||||
hidden_states,
|
|
||||||
residual,
|
|
||||||
self.weight,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
0.0,
|
|
||||||
self.variance_epsilon,
|
|
||||||
1.0,
|
|
||||||
0,
|
|
||||||
None,
|
|
||||||
False,
|
|
||||||
True, # Activate RMSNorm
|
|
||||||
)
|
|
||||||
if res is None:
|
|
||||||
res = hidden_states
|
|
||||||
|
|
||||||
return normed_hidden_states, res
|
|
||||||
elif SYSTEM == "rocm":
|
|
||||||
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
|
|
||||||
if residual is not None:
|
|
||||||
hidden_states += residual
|
|
||||||
residual = hidden_states
|
|
||||||
|
|
||||||
out = torch.empty_like(hidden_states)
|
|
||||||
ops.rms_norm(
|
|
||||||
out,
|
|
||||||
hidden_states,
|
|
||||||
self.weight.data,
|
|
||||||
self.variance_epsilon,
|
|
||||||
)
|
|
||||||
return out, residual
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
residual = hidden_states
|
||||||
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
|
# Note: HPUFusedRMSNorm requires 3D tensors as inputs
|
||||||
)
|
if len(orig_shape) == 2:
|
||||||
|
residual = residual.unsqueeze(0)
|
||||||
|
x = rms_norm().apply(residual, self.weight, self.variance_epsilon)
|
||||||
|
return x.view(orig_shape), residual.view(orig_shape)
|
||||||
|
@ -1,21 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
import os
|
|
||||||
|
|
||||||
if SYSTEM == "rocm":
|
|
||||||
ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in (
|
|
||||||
"true",
|
|
||||||
"1",
|
|
||||||
)
|
|
||||||
|
|
||||||
if ROCM_USE_SKINNY_GEMM:
|
|
||||||
try:
|
|
||||||
from vllm import _custom_C
|
|
||||||
except Exception as e:
|
|
||||||
raise ImportError(
|
|
||||||
f"Could not load `vllm._custom_C` for ROCm skinny gemm. Full error: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FastLinear(torch.nn.Module):
|
class FastLinear(torch.nn.Module):
|
||||||
@ -44,83 +28,11 @@ class FastLinear(torch.nn.Module):
|
|||||||
return F.linear(input, self.weight, self.bias)
|
return F.linear(input, self.weight, self.bias)
|
||||||
|
|
||||||
|
|
||||||
class FastLinearROCm(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
weight,
|
|
||||||
bias,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.weight = torch.nn.Parameter(weight)
|
|
||||||
if bias is not None:
|
|
||||||
self.bias = torch.nn.Parameter(bias)
|
|
||||||
else:
|
|
||||||
self.bias = None
|
|
||||||
|
|
||||||
self.cu_count = torch.cuda.get_device_properties(
|
|
||||||
device="cuda"
|
|
||||||
).multi_processor_count
|
|
||||||
self.use_skinny_gemm = (
|
|
||||||
ROCM_USE_SKINNY_GEMM
|
|
||||||
and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(cls, config, prefix: str, weights, bias: bool):
|
|
||||||
weight = weights.get_tensor(f"{prefix}.weight")
|
|
||||||
if bias:
|
|
||||||
bias = weights.get_tensor(f"{prefix}.bias")
|
|
||||||
else:
|
|
||||||
bias = None
|
|
||||||
return cls(weight, bias)
|
|
||||||
|
|
||||||
def forward(self, inp: torch.Tensor) -> torch.Tensor:
|
|
||||||
weight = self.weight
|
|
||||||
bias = self.bias
|
|
||||||
|
|
||||||
if (
|
|
||||||
self.use_skinny_gemm
|
|
||||||
and inp.dtype == torch.float16
|
|
||||||
and inp.shape[-1] % 8 == 0
|
|
||||||
):
|
|
||||||
batched = False
|
|
||||||
inp_shape = inp.shape
|
|
||||||
|
|
||||||
if inp.dim() == 3:
|
|
||||||
inp = inp.view(-1, inp_shape[-1])
|
|
||||||
batched = True
|
|
||||||
|
|
||||||
m, n, k = weight.shape[0], inp_shape[0], inp_shape[1]
|
|
||||||
if m > 8 and n <= 4:
|
|
||||||
out = torch.empty(
|
|
||||||
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
|
|
||||||
)
|
|
||||||
_custom_C.wvSpltK(weight, inp, out, n, self.cu_count)
|
|
||||||
elif m % 4 == 0 and n == 1 and k <= 8192:
|
|
||||||
out = torch.empty(
|
|
||||||
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
|
|
||||||
)
|
|
||||||
_custom_C.LLMM1(weight, inp, out, 4)
|
|
||||||
else:
|
|
||||||
out = F.linear(inp, weight)
|
|
||||||
|
|
||||||
if batched:
|
|
||||||
out.view(*inp_shape[:-1], out.shape[-1])
|
|
||||||
|
|
||||||
if bias is not None:
|
|
||||||
out = out + bias
|
|
||||||
return out
|
|
||||||
return F.linear(inp, self.weight, self.bias)
|
|
||||||
|
|
||||||
|
|
||||||
def get_linear(weight, bias):
|
def get_linear(weight, bias):
|
||||||
# Weights that are loaded through methods that are not
|
# Weights that are loaded through methods that are not
|
||||||
# quantization-aware are still bare tensors. We may want
|
# quantization-aware are still bare tensors. We may want
|
||||||
# to change this in the future.
|
# to change this in the future.
|
||||||
if isinstance(weight, torch.Tensor):
|
if isinstance(weight, torch.Tensor):
|
||||||
if SYSTEM == "rocm":
|
|
||||||
return FastLinearROCm(weight, bias)
|
|
||||||
else:
|
|
||||||
return FastLinear(weight, bias)
|
return FastLinear(weight, bias)
|
||||||
|
|
||||||
return weight.get_linear(bias)
|
return weight.get_linear(bias)
|
||||||
|
@ -1,15 +0,0 @@
|
|||||||
from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear
|
|
||||||
from text_generation_server.layers.marlin.gptq import (
|
|
||||||
GPTQMarlinWeightsLoader,
|
|
||||||
can_use_gptq_marlin,
|
|
||||||
repack_gptq_for_marlin,
|
|
||||||
)
|
|
||||||
from text_generation_server.layers.marlin.marlin import MarlinWeightsLoader
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"GPTQMarlinFP8Linear",
|
|
||||||
"GPTQMarlinWeightsLoader",
|
|
||||||
"MarlinWeightsLoader",
|
|
||||||
"can_use_gptq_marlin",
|
|
||||||
"repack_gptq_for_marlin",
|
|
||||||
]
|
|
@ -1,140 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from loguru import logger
|
|
||||||
from text_generation_server.layers.fp8 import fp8_quantize
|
|
||||||
from text_generation_server.layers.marlin.gptq import _check_valid_shape
|
|
||||||
from text_generation_server.layers.marlin.util import (
|
|
||||||
_check_marlin_kernels,
|
|
||||||
permute_scales,
|
|
||||||
)
|
|
||||||
from text_generation_server.utils.log import log_once
|
|
||||||
|
|
||||||
try:
|
|
||||||
import marlin_kernels
|
|
||||||
except ImportError:
|
|
||||||
marlin_kernels = None
|
|
||||||
|
|
||||||
|
|
||||||
MARLIN_TILE_SIZE = 16
|
|
||||||
|
|
||||||
|
|
||||||
class GPTQMarlinFP8Linear(nn.Module):
|
|
||||||
"""
|
|
||||||
FP8 GPTQ-Marlin linear layer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
qweight: torch.Tensor,
|
|
||||||
scales: torch.Tensor,
|
|
||||||
bias: Optional[torch.Tensor],
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
_check_marlin_kernels()
|
|
||||||
assert marlin_kernels is not None
|
|
||||||
|
|
||||||
log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
|
|
||||||
|
|
||||||
scales = scales.unsqueeze(0)
|
|
||||||
if scales.shape[1] == 1:
|
|
||||||
out_features, in_features = qweight.shape
|
|
||||||
scales = scales.repeat(1, out_features)
|
|
||||||
qweight, scales = repack_fp8_for_marlin(qweight, scales)
|
|
||||||
|
|
||||||
in_features = qweight.shape[0] * MARLIN_TILE_SIZE
|
|
||||||
out_features = scales.shape[1]
|
|
||||||
_check_valid_shape(in_features=in_features, out_features=out_features)
|
|
||||||
|
|
||||||
self.qweight = qweight
|
|
||||||
self.scales = scales
|
|
||||||
self.bias = bias if bias is not None else None
|
|
||||||
|
|
||||||
self.workspace = torch.zeros(
|
|
||||||
out_features // 64 * 16, dtype=torch.int, device=qweight.device
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_unquant(cls, weight, bias, dtype):
|
|
||||||
qweight, scales = fp8_quantize(weight)
|
|
||||||
return cls(qweight=qweight, scales=scales.to(dtype), bias=bias)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_fp8(cls, weight, scale, _input_scale, bias, dtype):
|
|
||||||
return cls(qweight=weight, scales=scale.to(dtype), bias=bias)
|
|
||||||
|
|
||||||
def forward(self, A: torch.Tensor) -> torch.Tensor:
|
|
||||||
assert marlin_kernels is not None
|
|
||||||
|
|
||||||
A_flat = A.view(-1, A.shape[-1])
|
|
||||||
C = marlin_kernels.fp8_marlin_gemm(
|
|
||||||
A_flat,
|
|
||||||
self.qweight,
|
|
||||||
self.scales,
|
|
||||||
self.workspace,
|
|
||||||
8,
|
|
||||||
A_flat.shape[0],
|
|
||||||
self.scales.shape[1],
|
|
||||||
A_flat.shape[1],
|
|
||||||
)
|
|
||||||
C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
|
|
||||||
|
|
||||||
if self.bias is not None:
|
|
||||||
C += self.bias
|
|
||||||
|
|
||||||
return C
|
|
||||||
|
|
||||||
|
|
||||||
def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Repack FP8 weights to gptq format (packed int32 elements).
|
|
||||||
"""
|
|
||||||
assert fp8_tensor.dtype == torch.float8_e4m3fn
|
|
||||||
|
|
||||||
if fp8_tensor.shape[0] % 4 != 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Reshape to prepare for packing
|
|
||||||
reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
|
|
||||||
|
|
||||||
# Convert fp8 to uint8 (byte) representation
|
|
||||||
byte_tensor = reshaped.view(torch.uint8)
|
|
||||||
|
|
||||||
# Pack 4 uint8 values into one int32
|
|
||||||
packed = torch.zeros(
|
|
||||||
fp8_tensor.shape[0] // 4,
|
|
||||||
fp8_tensor.shape[1],
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=fp8_tensor.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
for i in range(4):
|
|
||||||
packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8)
|
|
||||||
|
|
||||||
return packed
|
|
||||||
|
|
||||||
|
|
||||||
def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor):
|
|
||||||
"""
|
|
||||||
Repack FP8 tensor for GPTQ-Marlin.
|
|
||||||
"""
|
|
||||||
|
|
||||||
out_features, in_features = weight.shape
|
|
||||||
|
|
||||||
# Torch linear layers weights with shape [out_features, in_features],
|
|
||||||
# GPTQ-quantized weights use [in_feateres/pack_factor, in_features],
|
|
||||||
# so transpose before packing.
|
|
||||||
qweight = pack_fp8_as_int32(weight.t())
|
|
||||||
|
|
||||||
perm = torch.empty(0, dtype=torch.int, device=qweight.device)
|
|
||||||
repacked = marlin_kernels.gptq_marlin_repack(
|
|
||||||
qweight, perm, in_features, out_features, 8
|
|
||||||
)
|
|
||||||
|
|
||||||
scales = permute_scales(scales)
|
|
||||||
|
|
||||||
return repacked, scales
|
|
@ -1,464 +0,0 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
from typing import List, Optional, Union
|
|
||||||
|
|
||||||
import numpy
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from loguru import logger
|
|
||||||
from text_generation_server.layers.marlin.util import (
|
|
||||||
_check_marlin_kernels,
|
|
||||||
marlin_zero_points,
|
|
||||||
permute_scales,
|
|
||||||
unpack_cols,
|
|
||||||
)
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
from text_generation_server.utils.log import log_once
|
|
||||||
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
|
||||||
|
|
||||||
try:
|
|
||||||
import marlin_kernels
|
|
||||||
except ImportError:
|
|
||||||
marlin_kernels = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
major, _minor = torch.cuda.get_device_capability()
|
|
||||||
has_sm_8_0 = major >= 8
|
|
||||||
except Exception:
|
|
||||||
has_sm_8_0 = False
|
|
||||||
|
|
||||||
|
|
||||||
GPTQ_MARLIN_BITS = [4, 8]
|
|
||||||
GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
|
|
||||||
MARLIN_TILE_SIZE = 16
|
|
||||||
|
|
||||||
|
|
||||||
def can_use_gptq_marlin(
|
|
||||||
*, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool
|
|
||||||
) -> bool:
|
|
||||||
return (
|
|
||||||
SYSTEM == "cuda"
|
|
||||||
and marlin_kernels is not None
|
|
||||||
and has_sm_8_0
|
|
||||||
and quantize in {"awq", "gptq"}
|
|
||||||
and quant_method in {"awq", "gptq"}
|
|
||||||
and bits in GPTQ_MARLIN_BITS
|
|
||||||
and groupsize in GPTQ_MARLIN_GROUP_SIZES
|
|
||||||
# We only suppord asymmetric quantization for AWQ.
|
|
||||||
and (sym or quant_method == "awq")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class GPTQMarlinWeightsLoader(WeightsLoader):
|
|
||||||
"""
|
|
||||||
Loader for using GPTQ- and AWQ-quantized weights with Marlin kernels.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
bits: int,
|
|
||||||
desc_act: bool,
|
|
||||||
groupsize: int,
|
|
||||||
quant_method: str,
|
|
||||||
quantize: str,
|
|
||||||
sym: bool,
|
|
||||||
):
|
|
||||||
self.bits = bits
|
|
||||||
self.desc_act = desc_act
|
|
||||||
self.groupsize = groupsize
|
|
||||||
self.quant_method = quant_method
|
|
||||||
self.quantize = quantize
|
|
||||||
self.sym = sym
|
|
||||||
|
|
||||||
def get_weights(self, weights: Weights, prefix: str):
|
|
||||||
log_once(logger.info, "Using GPTQ-Marlin kernels")
|
|
||||||
try:
|
|
||||||
qweight = weights.get_tensor(f"{prefix}.qweight")
|
|
||||||
except RuntimeError:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.sym:
|
|
||||||
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
|
||||||
else:
|
|
||||||
qzeros = None
|
|
||||||
|
|
||||||
if self.quant_method == "awq":
|
|
||||||
g_idx = None
|
|
||||||
else:
|
|
||||||
g_idx = weights.get_tensor(f"{prefix}.g_idx")
|
|
||||||
scales = weights.get_tensor(f"{prefix}.scales")
|
|
||||||
|
|
||||||
return repack_gptq_for_marlin(
|
|
||||||
qweight=qweight,
|
|
||||||
scales=scales,
|
|
||||||
qzeros=qzeros,
|
|
||||||
g_idx=g_idx,
|
|
||||||
bits=self.bits,
|
|
||||||
desc_act=self.desc_act,
|
|
||||||
groupsize=self.groupsize,
|
|
||||||
quant_method=self.quant_method,
|
|
||||||
sym=self.sym,
|
|
||||||
sharded_infeatures=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_weights_col_packed(
|
|
||||||
self,
|
|
||||||
weights: Weights,
|
|
||||||
prefix: str,
|
|
||||||
block_sizes: Union[int, List[int]],
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
qweight = weights.get_packed_sharded(
|
|
||||||
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
|
||||||
)
|
|
||||||
except RuntimeError:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
|
|
||||||
)
|
|
||||||
scales = weights.get_packed_sharded(
|
|
||||||
f"{prefix}.scales", dim=1, block_sizes=block_sizes
|
|
||||||
)
|
|
||||||
scales = scales.to(dtype=weights.dtype)
|
|
||||||
|
|
||||||
if not self.sym:
|
|
||||||
qzeros = weights.get_packed_sharded(
|
|
||||||
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
qzeros = None
|
|
||||||
|
|
||||||
if self.quant_method == "awq":
|
|
||||||
g_idx = None
|
|
||||||
else:
|
|
||||||
g_idx = weights.get_tensor(f"{prefix}.g_idx")
|
|
||||||
return repack_gptq_for_marlin(
|
|
||||||
qweight=qweight,
|
|
||||||
scales=scales,
|
|
||||||
qzeros=qzeros,
|
|
||||||
g_idx=g_idx,
|
|
||||||
bits=self.bits,
|
|
||||||
desc_act=self.desc_act,
|
|
||||||
groupsize=self.groupsize,
|
|
||||||
quant_method=self.quant_method,
|
|
||||||
sym=self.sym,
|
|
||||||
sharded_infeatures=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
    """
    Load the quantized tensors of several layers, concatenate them along
    dimension 1 and repack the result for the GPTQ-Marlin kernels.

    Args:
        weights: checkpoint accessor the tensors are read from.
        prefixes: tensor-name prefixes of the layers to fuse.
        dim: accepted for interface compatibility; this implementation does
            not read it and always shards/concatenates along dimension 1.

    Returns:
        Whatever `repack_gptq_for_marlin` returns for the fused tensors.

    Raises:
        RuntimeError: if loading/concatenating the `qweight` tensors raises
            `RuntimeError`.
    """

    def concat_sharded(suffix: str):
        shards = [weights.get_sharded(f"{p}.{suffix}", dim=1) for p in prefixes]
        return torch.cat(shards, dim=1)

    try:
        fused_qweight = concat_sharded("qweight")
    except RuntimeError:
        raise RuntimeError(
            f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
        )

    fused_scales = concat_sharded("scales")

    # Zero points are only read when the quantization is not symmetric.
    fused_qzeros = None if self.sym else concat_sharded("qzeros")

    if self.quant_method == "awq":
        shared_g_idx = None
    else:
        # A single g_idx is passed on, so every prefix must carry the same one.
        all_g_idx = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
        for other in all_g_idx[1:]:
            torch.testing.assert_close(other, all_g_idx[0])
        shared_g_idx = all_g_idx[0]

    return repack_gptq_for_marlin(
        qweight=fused_qweight,
        scales=fused_scales,
        qzeros=fused_qzeros,
        g_idx=shared_g_idx,
        bits=self.bits,
        desc_act=self.desc_act,
        groupsize=self.groupsize,
        quant_method=self.quant_method,
        sym=self.sym,
        sharded_infeatures=False,
    )
|
|
||||||
|
|
||||||
def get_weights_row(self, weights: Weights, prefix: str):
    """
    Load the row-sharded (dimension 0) quantized tensors stored under `prefix`
    and repack them for the GPTQ-Marlin kernels.

    Args:
        weights: checkpoint accessor the tensors are read from.
        prefix: tensor-name prefix of the layer to load.

    Returns:
        Whatever `repack_gptq_for_marlin` returns for the loaded tensors.

    Raises:
        RuntimeError: if reading `{prefix}.qweight` raises `RuntimeError`.
    """
    log_once(logger.info, "Using GPTQ-Marlin kernels")

    try:
        row_qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
    except RuntimeError:
        raise RuntimeError(
            f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
        )

    def load_group_tensor(name: str):
        # With act-order (desc_act) or a single group (groupsize == -1) the
        # whole tensor is read; otherwise only this rank's slice along dim 0.
        if self.desc_act or self.groupsize == -1:
            return weights.get_tensor(name)
        return weights.get_sharded(name, dim=0)

    # Zero points are only read when the quantization is not symmetric.
    row_qzeros = None if self.sym else load_group_tensor(f"{prefix}.qzeros")

    # No g_idx tensor is read for AWQ checkpoints.
    row_g_idx = (
        None
        if self.quant_method == "awq"
        else weights.get_sharded(f"{prefix}.g_idx", dim=0)
    )

    row_scales = load_group_tensor(f"{prefix}.scales")

    # Input features count as sharded as soon as more than one rank exists.
    sharded_in_features = weights.process_group.size() > 1

    return repack_gptq_for_marlin(
        qweight=row_qweight,
        scales=row_scales,
        qzeros=row_qzeros,
        g_idx=row_g_idx,
        bits=self.bits,
        desc_act=self.desc_act,
        groupsize=self.groupsize,
        quant_method=self.quant_method,
        sym=self.sym,
        sharded_infeatures=sharded_in_features,
    )
|
|
||||||
|
|
||||||
def _get_gptq_params(self, weights: Weights):
    """
    Override the loader's quantization settings with the ones stored as
    tensors in the checkpoint, when both `gptq_bits` and `gptq_groupsize`
    are present. Does nothing otherwise.

    Args:
        weights: checkpoint accessor queried for the setting tensors.
    """
    has_embedded_settings = weights._has_tensor("gptq_bits") and weights._has_tensor(
        "gptq_groupsize"
    )
    if not has_embedded_settings:
        return

    self.bits = weights.get_tensor("gptq_bits").item()
    self.groupsize = weights.get_tensor("gptq_groupsize").item()
    self.desc_act = False

    # `server quantize` used asymmetric quantization unconditionally
    # before the `gptq_sym` setting tensor was added.
    if weights._has_tensor("gptq_sym"):
        self.sym = weights.get_tensor("gptq_sym").item()
    else:
        self.sym = False

    self.quant_method = "gptq"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class GPTQMarlinWeight(Weight):
    """
    Repacked GPTQ Marlin weights.

    The dtype of `qweight`, `scales`, `g_idx` and `perm` is validated on
    construction; `qzeros`, `bits` and `is_full_k` are stored as given.
    """

    qweight: torch.Tensor
    qzeros: torch.Tensor
    scales: torch.Tensor
    g_idx: torch.Tensor
    perm: torch.Tensor
    bits: int
    is_full_k: bool

    def __post_init__(self):
        # Checked in declaration order; the first mismatch raises AssertionError.
        required_dtypes = (
            ("qweight", torch.int32),
            ("scales", torch.float16),
            ("g_idx", torch.int32),
            ("perm", torch.int32),
        )
        for field_name, required in required_dtypes:
            assert getattr(self, field_name).dtype == required

    def get_linear(self, bias: torch.Tensor):
        """Build the `GPTQMarlinLinear` layer that wraps these weights and `bias`."""
        return GPTQMarlinLinear(weight=self, bias=bias)
|
|
||||||
|
|
||||||
|
|
||||||
def repack_gptq_for_marlin(
    *,
    qweight: torch.Tensor,
    qzeros: Optional[torch.Tensor],
    scales: torch.Tensor,
    g_idx: Optional[torch.Tensor],
    bits: int,
    desc_act: bool,
    groupsize: int,
    quant_method: str,
    sym: bool,
    sharded_infeatures: bool,
) -> GPTQMarlinWeight:
    """
    Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels.

    Args:
        qweight: packed quantized weight; `shape[0]`/`shape[1]` are read as the
            (packed) input/output feature counts.
        qzeros: packed zero points, or None.
        scales: quantization scales, permuted with `permute_scales`.
        g_idx: group index per input feature, or None.
        bits: bit width, must be in `GPTQ_MARLIN_BITS`.
        desc_act: whether act-order is enabled.
        groupsize: group size, must be in `GPTQ_MARLIN_GROUP_SIZES`.
        quant_method: `"awq"` selects the AWQ path, anything else the GPTQ path.
        sym: whether quantization is symmetric.
        sharded_infeatures: whether the input features are sharded across ranks.

    Raises:
        RuntimeError: unsupported bit width or group size, or asymmetric
            quantization with a method other than AWQ.
        ValueError: input feature count not divisible by the group size.
    """
    _check_marlin_kernels()
    assert marlin_kernels is not None

    if bits not in GPTQ_MARLIN_BITS:
        supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS)
        raise RuntimeError(
            f"Repacking {bits}-bit GPTQ weights as Marlin is not supported, must be one of: {supported_bits}"
        )

    if groupsize not in GPTQ_MARLIN_GROUP_SIZES:
        supported_sizes = ", ".join(str(b) for b in GPTQ_MARLIN_GROUP_SIZES)
        raise RuntimeError(
            f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}"
        )

    is_awq = quant_method == "awq"
    if not sym and not is_awq:
        raise RuntimeError(
            "Repacking GPTQ weights with asymmetric quantization as Marlin is not supported."
        )

    log_once(logger.info, f"Converting {quant_method} model to Marlin packing format.")

    weights_per_int = 32 // bits
    in_features = qweight.shape[0]
    out_features = qweight.shape[1]

    # AWQ uses column packing, GPTQ uses row packing
    if is_awq:
        out_features *= weights_per_int
    else:
        in_features *= weights_per_int

    if in_features % groupsize != 0:
        raise ValueError(
            f"Number of input features ({in_features}) not divisible by group size ({groupsize})"
        )

    device = qweight.device

    # A permutation is only derived when act-order is in effect; otherwise
    # empty placeholder tensors are handed on.
    if g_idx is not None and desc_act and groupsize != -1:
        perm = torch.argsort(g_idx).to(torch.int)
        g_idx = g_idx[perm]
    else:
        perm = torch.empty(0, dtype=torch.int, device=device)
        g_idx = torch.empty(0, dtype=torch.int, device=device)

    if is_awq:
        repacked = marlin_kernels.awq_marlin_repack(
            qweight, in_features, out_features, bits
        )
        if qzeros is not None:
            num_groups = in_features // groupsize
            qzeros = awq_to_marlin_zero_points(qzeros, num_groups, out_features, bits)
    else:
        repacked = marlin_kernels.gptq_marlin_repack(
            qweight, perm, in_features, out_features, bits
        )

    if qzeros is None:
        # Empty placeholder so the weight always carries a tensor.
        qzeros = torch.empty(0, dtype=torch.int, device=device)

    scales = permute_scales(scales)

    is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures)

    return GPTQMarlinWeight(
        qweight=repacked,
        qzeros=qzeros,
        scales=scales,
        g_idx=g_idx,
        perm=perm,
        bits=bits,
        is_full_k=is_full_k,
    )
|
|
||||||
|
|
||||||
|
|
||||||
class GPTQMarlinLinear(nn.Module):
    """
    Linear layer for GPTQ weights that were converted for the GPTQ-Marlin
    kernels.
    """

    def __init__(
        self,
        *,
        weight: GPTQMarlinWeight,
        bias: Optional[torch.Tensor],
    ):
        """
        Args:
            weight: repacked weights; its tensors are stored on the module as-is.
            bias: optional bias added to the output, or None.

        Raises:
            ValueError: from `_check_valid_shape` when the derived feature
                counts are rejected.
        """
        super().__init__()

        _check_marlin_kernels()
        assert marlin_kernels is not None

        # Feature counts are derived from the repacked tensors' shapes.
        in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE
        out_features = weight.scales.shape[1]
        _check_valid_shape(in_features=in_features, out_features=out_features)

        self.bits = weight.bits
        self.is_full_k = weight.is_full_k

        self.qweight = weight.qweight
        self.qzeros = weight.qzeros
        self.scales = weight.scales
        self.g_idx = weight.g_idx
        self.perm = weight.perm
        self.bias = bias

        # Integer scratch buffer handed to the GEMM kernel on every call.
        workspace_size = out_features // 64 * 16
        self.workspace = torch.zeros(
            workspace_size, dtype=torch.int, device=weight.qweight.device
        )

    def forward(self, A: torch.Tensor) -> torch.Tensor:
        """Apply the layer to `A`; leading dimensions of `A` are preserved."""
        assert marlin_kernels is not None

        # The kernel is called on a 2-D view; the output is reshaped back below.
        flat_input = A.view(-1, A.shape[-1])
        out_features = self.scales.shape[1]
        has_zero_points = self.qzeros.numel() > 0

        C = marlin_kernels.gptq_marlin_gemm(
            flat_input,
            self.qweight,
            self.scales,
            self.qzeros,
            self.g_idx,
            self.perm,
            self.workspace,
            self.bits,
            flat_input.shape[0],
            out_features,
            flat_input.shape[1],
            self.is_full_k,
            has_zero_points,
            True,
        )
        C = C.reshape(A.shape[:-1] + (out_features,))

        if self.bias is not None:
            C += self.bias

        return C
|
|
||||||
|
|
||||||
|
|
||||||
def awq_to_marlin_zero_points(
    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    """
    Convert packed AWQ zero points into the packed Marlin zero-point layout.

    Args:
        q_zp_packed: packed zero points, unpacked with `unpack_cols`.
        size_k: number of rows of the unpacked zero-point matrix.
        size_n: number of columns of the unpacked zero-point matrix.
        num_bits: bit width of one zero point; 4 or 8.

    Raises:
        Exception: if `num_bits` is neither 4 nor 8.
    """
    # AWQ zero-points are quantized and packed on the column dim.
    # In addition, the values are permuted based on dequantizer.
    # Here we undo both of these, and then apply marlin permutation
    # and pack it back.
    unpacked = unpack_cols(q_zp_packed, num_bits, size_k, size_n)

    # Interleave pattern per bit width.
    awq_interleave = {
        4: [0, 2, 4, 6, 1, 3, 5, 7],
        8: [0, 2, 1, 3],
    }
    if num_bits not in awq_interleave:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    # Undo interleaving (use argsort(..) to get inverse perm)
    undo_interleave = numpy.argsort(numpy.array(awq_interleave[num_bits]))

    deinterleaved = unpacked.reshape((-1, len(undo_interleave)))[
        :, undo_interleave
    ].ravel()
    deinterleaved = deinterleaved.reshape((-1, size_n)).contiguous()

    return marlin_zero_points(deinterleaved, size_k, size_n, num_bits)
|
|
||||||
|
|
||||||
|
|
||||||
def _check_valid_shape(in_features: int, out_features: int):
|
|
||||||
if (in_features % 128 != 0 or out_features % 64 != 0) and (
|
|
||||||
in_features % 64 != 0 or out_features % 128 != 0
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"The GPTQ Marlin kernel does not have a valid thread configuration for weight matrix with shape ({out_features}, {in_features})."
|
|
||||||
" The shape elements must be divisible by (128, 64) or (64, 128)."
|
|
||||||
)
|
|
@ -1,346 +0,0 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
from typing import List, Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from text_generation_server.layers.marlin.util import _check_marlin_kernels
|
|
||||||
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
|
||||||
|
|
||||||
try:
|
|
||||||
import marlin_kernels
|
|
||||||
except ImportError:
|
|
||||||
marlin_kernels = None
|
|
||||||
|
|
||||||
|
|
||||||
class MarlinWeightsLoader(WeightsLoader):
    """Loader for Marlin-quantized weights."""

    def __init__(self, *, bits: int, is_marlin_24: bool):
        """
        Args:
            bits: bit width passed on to `GPTQMarlin24Weight`.
            is_marlin_24: when True, the 2:4-sparsity tensors (`B_24`,
                `B_meta`, `s`) are loaded instead of the dense ones (`B`, `s`).
        """
        self.bits = bits
        self.is_marlin_24 = is_marlin_24

    def get_weights(self, weights: "Weights", prefix: str):
        """
        Get weights at the given prefix and apply without tensor parallelism.

        Raises:
            RuntimeError: if the main weight tensor cannot be read.
        """
        # Use the flag set in __init__, like every other method of this class.
        # (Previously a non-existent `gptq_checkpoint_format` attribute was
        # consulted, so the 2:4 branch could never be taken.)
        if self.is_marlin_24:
            try:
                B = weights.get_tensor(f"{prefix}.B_24")
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
                )

            B_meta = weights.get_tensor(f"{prefix}.B_meta")
            s = weights.get_tensor(f"{prefix}.s")
            weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
        else:
            try:
                B = weights.get_tensor(f"{prefix}.B")
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `marlin` weight, make sure the model is already quantized."
                )

            s = weights.get_tensor(f"{prefix}.s")
            weight = MarlinWeight(B=B, s=s)

        return weight

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        """
        Get the packed tensors at `prefix`, sharded along dimension 1;
        `block_sizes` is forwarded to `weights.get_packed_sharded`.
        """
        if self.is_marlin_24:
            B = weights.get_packed_sharded(
                f"{prefix}.B_24", dim=1, block_sizes=block_sizes
            )
            B_meta = weights.get_packed_sharded(
                f"{prefix}.B_meta", dim=1, block_sizes=block_sizes
            )
            s = weights.get_packed_sharded(
                f"{prefix}.s", dim=1, block_sizes=block_sizes
            )

            weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
        else:
            B = weights.get_packed_sharded(
                f"{prefix}.B", dim=1, block_sizes=block_sizes
            )
            s = weights.get_packed_sharded(
                f"{prefix}.s", dim=1, block_sizes=block_sizes
            )
            weight = MarlinWeight(B=B, s=s)

        return weight

    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
        """
        Get the tensors of several prefixes, each sharded along dimension 1,
        concatenated along dimension 1. `dim` is accepted for interface
        compatibility but not read.

        Raises:
            RuntimeError: if loading/concatenating the main weight tensors
                raises `RuntimeError`.
        """
        if self.is_marlin_24:
            try:
                B = torch.cat(
                    [weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
                )
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `marlin` weight, make sure the model is already quantized"
                )

            B_meta = torch.cat(
                [weights.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1
            )

            s = torch.cat(
                [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
            )

            weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
        else:
            try:
                B = torch.cat(
                    [weights.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
                )
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `marlin` weight, make sure the model is already quantized"
                )
            s = torch.cat(
                [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
            )

            weight = MarlinWeight(B=B, s=s)

        return weight

    def get_weights_row(self, weights: Weights, prefix: str):
        """
        Get the tensors at `prefix`, sharded along dimension 0.

        Raises:
            RuntimeError: if the main weight tensor cannot be read.
        """
        if self.is_marlin_24:
            try:
                B = weights.get_sharded(f"{prefix}.B_24", dim=0)
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
                )

            B_meta = weights.get_sharded(f"{prefix}.B_meta", dim=0)
            s = self._get_row_scales(weights, prefix)

            weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
        else:
            try:
                B = weights.get_sharded(f"{prefix}.B", dim=0)
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `marlin` weight, make sure the model is already quantized."
                )

            s = self._get_row_scales(weights, prefix)
            weight = MarlinWeight(B=B, s=s)

        return weight

    def _get_row_scales(self, weights: Weights, prefix: str):
        """Load `{prefix}.s` for a row-sharded layer (shared by both formats)."""
        num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
        if num_groups == 1:
            # The number of groups is 1 when groupsize == -1. share
            # scales between all shards in this case.
            return weights.get_tensor(f"{prefix}.s")
        return weights.get_sharded(f"{prefix}.s", dim=0)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class MarlinWeight(Weight):
    """
    Marlin weights.

    Attributes:
        B (torch.Tensor): int4-quantized weights packed into int32.
        s (torch.Tensor): bfloat16/float16 scales.
    """

    B: torch.Tensor
    s: torch.Tensor

    def __post_init__(self):
        # Reject tensors whose dtype does not match the documented layout.
        allowed_scale_dtypes = (torch.float16, torch.bfloat16)
        assert self.B.dtype == torch.int32
        assert self.s.dtype in allowed_scale_dtypes

    def get_linear(self, bias: torch.Tensor):
        """Build the `MarlinLinear` layer that wraps these weights and `bias`."""
        linear = MarlinLinear(weight=self, bias=bias)
        return linear
|
|
||||||
|
|
||||||
|
|
||||||
class MarlinLinear(nn.Module):
    """Linear layer that runs `marlin_kernels.marlin_gemm` on Marlin weights."""

    def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]):
        """
        Args:
            weight: Marlin weights; `B` and `s` are stored on the module as-is.
            bias: optional bias added to the output, or None.

        Raises:
            AssertionError: if the derived input features are not a multiple
                of 128, the output features not a multiple of 256, or the
                derived group size is neither -1 nor 128.
        """
        super().__init__()

        _check_marlin_kernels()
        assert marlin_kernels is not None

        # Feature counts are derived from the stored tensors' shapes.
        in_features = weight.B.shape[0] * MARLIN_TILE_SIZE
        out_features = weight.s.shape[1]
        assert (
            in_features % 128 == 0
        ), f"Number of input features ({in_features}) not divisable by 128"
        assert (
            out_features % 256 == 0
        ), f"Number of output features ({out_features}) not divisable by 256"

        # A single row of scales means one group for the whole input (-1).
        groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
        assert groupsize in {
            -1,
            128,
        }, f"Group size must be -1 or 128, was {groupsize}"

        self.B = weight.B
        self.s = weight.s
        self.bias = bias

        # Integer scratch buffer handed to the GEMM kernel on every call.
        self.workspace = torch.zeros(
            out_features // 64 * 16, dtype=torch.int, device=weight.B.device
        )

    def forward(self, A: torch.Tensor) -> torch.Tensor:
        """Apply the layer to `A`; leading dimensions of `A` are preserved."""
        assert marlin_kernels is not None

        # The kernel receives a 2-D view, so the sizes passed alongside it must
        # describe that view. Using A.shape[0]/A.shape[1] is only equivalent
        # when A is already 2-D.
        A_flat = A.view(-1, A.shape[-1])
        C = marlin_kernels.marlin_gemm(
            A_flat,
            self.B,
            self.s,
            self.workspace,
            A_flat.shape[0],
            self.s.shape[1],
            A_flat.shape[1],
        )
        C = C.reshape(A.shape[:-1] + (self.s.shape[1],))

        if self.bias is not None:
            C += self.bias

        return C
|
|
||||||
|
|
||||||
|
|
||||||
GPTQ_MARLIN_24_MIN_THREAD_N = 128
|
|
||||||
GPTQ_MARLIN_24_MIN_THREAD_K = 128
|
|
||||||
GPTQ_MARLIN_24_MAX_PARALLEL = 64
|
|
||||||
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
|
|
||||||
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
|
|
||||||
MARLIN_TILE_SIZE = 16
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class GPTQMarlin24Weight:
    """
    GPTQ-Marlin 2:4 weights.

    Attributes:
        B (torch.Tensor): int4-quantized weights packed into int32.
        B_meta (torch.Tensor): metadata for 2:4 sparsity.
        s (torch.Tensor): float16 scales.
        bits: quantized weight size.
    """

    B: torch.Tensor
    B_meta: torch.Tensor
    s: torch.Tensor
    bits: int

    def __post_init__(self):
        # Checked in declaration order; the first mismatch raises AssertionError.
        required_dtypes = (
            ("B", torch.int32),
            ("B_meta", torch.int16),
            ("s", torch.float16),
        )
        for field_name, required in required_dtypes:
            assert getattr(self, field_name).dtype == required

    def get_linear(self, bias: torch.Tensor):
        """Build the `GPTQMarlin24Linear` layer that wraps these weights and `bias`."""
        return GPTQMarlin24Linear(weight=self, bias=bias)
|
|
||||||
|
|
||||||
|
|
||||||
class GPTQMarlin24Linear(nn.Module):
    """Linear layer that runs `marlin_kernels.gptq_marlin_24_gemm` on 2:4 weights."""

    def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]):
        """
        Args:
            weight: 2:4 Marlin weights; its tensors are stored on the module as-is.
            bias: optional bias added to the output, or None.

        Raises:
            RuntimeError: unsupported bit width or derived group size.
            AssertionError: feature counts not aligned to the required multiples.
            ValueError: input features not divisible by the group size.
        """
        super().__init__()

        _check_marlin_kernels()
        assert marlin_kernels is not None

        if weight.bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
            supported_bits = ", ".join(
                str(b) for b in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
            )
            raise RuntimeError(
                f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}"
            )

        # Feature counts and group size are derived from the tensors' shapes.
        in_features = weight.B.shape[0] * MARLIN_TILE_SIZE * 2
        out_features = weight.s.shape[1]
        # A single row of scales means one group for the whole input (-1).
        groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]

        if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
            supported_sizes = ", ".join(
                str(b) for b in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
            )
            raise RuntimeError(
                f"Group size {groupsize} is not supported, must be one of: {supported_sizes}"
            )

        self.bits = weight.bits
        weights_per_int32 = 32 // self.bits

        assert (
            out_features % GPTQ_MARLIN_24_MIN_THREAD_N == 0
        ), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_N} threads"
        assert (
            out_features % weights_per_int32 == 0
        ), f"Number of output features ({out_features}) not divisable by weights per int32 ({weights_per_int32})"

        # This check is about the input features, so report them (the message
        # previously named the output features).
        assert (
            in_features % GPTQ_MARLIN_24_MIN_THREAD_K == 0
        ), f"Number of input features ({in_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_K} threads"
        if groupsize != -1 and in_features % groupsize != 0:
            raise ValueError(
                f"Number of input features ({in_features}) not divisable by group size ({groupsize})"
            )

        self.B = weight.B
        self.B_meta = weight.B_meta
        self.s = weight.s
        self.bias = bias

        # Integer scratch buffer handed to the GEMM kernel on every call.
        self.workspace = torch.zeros(
            (out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL,
            dtype=torch.int,
            device=weight.B.device,
        )

    def forward(self, A: torch.Tensor) -> torch.Tensor:
        """Apply the layer to `A`; leading dimensions of `A` are preserved."""
        assert marlin_kernels is not None

        # The kernel receives a 2-D view, so the sizes passed alongside it must
        # describe that view. Using A.shape[0]/A.shape[1] is only equivalent
        # when A is already 2-D.
        A_flat = A.view(-1, A.shape[-1])
        C = marlin_kernels.gptq_marlin_24_gemm(
            A_flat,
            self.B,
            self.B_meta,
            self.s,
            self.workspace,
            self.bits,
            A_flat.shape[0],
            self.s.shape[1],
            A_flat.shape[1],
        )

        C = C.reshape(A.shape[:-1] + (self.s.shape[1],))

        if self.bias is not None:
            C += self.bias

        return C
|
|
@ -1,141 +0,0 @@
|
|||||||
import functools
|
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
import numpy
|
|
||||||
import torch
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
|
|
||||||
try:
|
|
||||||
import marlin_kernels
|
|
||||||
except ImportError:
|
|
||||||
marlin_kernels = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
major, _minor = torch.cuda.get_device_capability()
|
|
||||||
has_sm_8_0 = major >= 8
|
|
||||||
except Exception:
|
|
||||||
has_sm_8_0 = False
|
|
||||||
|
|
||||||
|
|
||||||
def _check_marlin_kernels():
    """
    Ensure the Marlin kernels can be used in this process.

    Raises:
        NotImplementedError: if `SYSTEM` is not `"cuda"` or `has_sm_8_0` is
            false, or if the `marlin_kernels` module could not be imported.
    """
    cuda_capable = SYSTEM == "cuda" and has_sm_8_0
    if not cuda_capable:
        raise NotImplementedError(
            "Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later."
        )

    kernels_missing = marlin_kernels is None
    if kernels_missing:
        raise NotImplementedError(
            "marlin is not installed, install it with: pip install server/marlin"
        )
|
|
||||||
|
|
||||||
|
|
||||||
# https://github.com/IST-DASLab/marlin/blob/2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c/marlin/__init__.py#L40C1-L68C54
|
|
||||||
@functools.cache
def get_perms() -> Tuple[List[int], List[int]]:
    """
    Build the two index permutations used to reorder scales.

    Returns:
        `(scale_perm, scale_perm_single)`: a permutation of `range(64)` and a
        permutation of `range(32)`. The result is cached, so every call
        returns the same list objects; callers must not mutate them.
    """
    # 8 runs of 8 indices: run i holds i, i + 8, ..., i + 56.
    scale_perm = [i + 8 * j for i in range(8) for j in range(8)]

    # 4 runs of 8 indices: run i holds the fixed offsets shifted by 2 * i.
    single_offsets = (0, 1, 8, 9, 16, 17, 24, 25)
    scale_perm_single = [2 * i + j for i in range(4) for j in single_offsets]

    return scale_perm, scale_perm_single
|
|
||||||
|
|
||||||
|
|
||||||
def permute_scales(scales: torch.Tensor):
    """
    Reorder the columns of `scales` with the permutations from `get_perms`.

    The 32-entry "single" permutation is used when `scales` has exactly one
    row, the 64-entry permutation otherwise. The result is contiguous and has
    `scales.shape[1]` columns.
    """
    scale_perm, scale_perm_single = get_perms()
    out_features = scales.shape[1]

    chosen_perm = scale_perm_single if scales.shape[0] == 1 else scale_perm

    # Permute within consecutive chunks of len(chosen_perm) elements, then
    # restore the original column count.
    permuted = scales.reshape((-1, len(chosen_perm)))[:, chosen_perm]
    return permuted.reshape((-1, out_features)).contiguous()
|
|
||||||
|
|
||||||
|
|
||||||
# Functions below are from vLLM
|
|
||||||
|
|
||||||
|
|
||||||
def get_pack_factor(bits: int) -> int:
    """
    Return how many `bits`-wide values fit into one 32-bit word.

    Args:
        bits: width of a single value; must be a positive divisor of 32.

    Raises:
        ValueError: if `bits` is not positive or does not divide 32.
    """
    # Check positivity first: `32 % 0` would raise ZeroDivisionError and
    # negative divisors of 32 would yield a meaningless negative factor.
    if bits <= 0 or 32 % bits != 0:
        raise ValueError(f"Cannot pack {bits} bit values into uint32")
    return 32 // bits
|
|
||||||
|
|
||||||
|
|
||||||
def pack_cols(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """
    Pack `num_bits`-wide values along the column dimension into int32 words.

    Args:
        q_w: tensor of shape `(size_k, size_n)` holding the values to pack.
        num_bits: width of one value.
        size_k: number of rows of `q_w`.
        size_n: number of columns of `q_w`; must be a multiple of the pack factor.

    Returns:
        Contiguous int32 tensor of shape `(size_k, size_n // pack_factor)` on
        the device `q_w` came from.
    """
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0

    target_device = q_w.device

    # The bit manipulation is done with numpy on the CPU.
    values = q_w.cpu().numpy().astype(numpy.uint32)
    packed = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)

    # Column i, i + pack_factor, ... lands in bit slot i of each output word.
    for slot in range(pack_factor):
        packed |= values[:, slot::pack_factor] << num_bits * slot

    result = torch.from_numpy(packed.astype(numpy.int32)).to(target_device)
    return result.contiguous()
|
|
||||||
|
|
||||||
|
|
||||||
def unpack_cols(
    packed_q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """
    Unpack int32 words into `num_bits`-wide values along the column dimension.

    Args:
        packed_q_w: tensor of shape `(size_k, size_n // pack_factor)`.
        num_bits: width of one value.
        size_k: number of rows.
        size_n: number of unpacked columns; must be a multiple of the pack factor.

    Returns:
        Contiguous int32 tensor of shape `(size_k, size_n)` on the device
        `packed_q_w` came from.
    """
    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0
    assert packed_q_w.shape == (
        size_k,
        size_n // pack_factor,
    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
        packed_q_w.shape, size_k, size_n, pack_factor
    )

    target_device = packed_q_w.device

    # The bit manipulation is done with numpy on the CPU, on a uint32 copy.
    words = packed_q_w.cpu().numpy().astype(numpy.uint32)
    unpacked = numpy.zeros((size_k, size_n), dtype=numpy.uint32)

    mask = (1 << num_bits) - 1
    # Bit slot i of every word goes to columns i, i + pack_factor, ...
    for slot in range(pack_factor):
        unpacked[:, slot::pack_factor] = (words >> (num_bits * slot)) & mask

    result = torch.from_numpy(unpacked.astype(numpy.int32)).to(target_device)
    return result.contiguous()
|
|
||||||
|
|
||||||
|
|
||||||
def marlin_zero_points(
    zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    """
    Permute, interleave and pack zero points for the Marlin kernels.

    Args:
        zp: unpacked zero points.
        size_k: number of rows expected by `pack_cols`.
        size_n: number of columns of the zero-point matrix.
        num_bits: bit width of one zero point; 4 or 8.

    Raises:
        Exception: if `num_bits` is neither 4 nor 8.
    """
    scale_perm, _ = get_perms()
    # Permute zero-points in a similar way to scales, but do not use the
    # "single" permutation, since zero-points are applied on every MMA
    permuted = zp.reshape((-1, len(scale_perm)))[:, scale_perm]

    # Interleave column dim (for the dequantize code) and pack it to int32
    interleave_by_bits = {
        4: [0, 2, 4, 6, 1, 3, 5, 7],
        8: [0, 2, 1, 3],
    }
    if num_bits not in interleave_by_bits:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
    interleave = numpy.array(interleave_by_bits[num_bits])

    interleaved = permuted.reshape((-1, len(interleave)))[:, interleave].ravel()
    interleaved = interleaved.reshape((-1, size_n)).contiguous()

    return pack_cols(interleaved, num_bits, size_k, size_n)
|
|
@ -10,13 +10,8 @@ from text_generation_server.layers import (
|
|||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
|
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
|
||||||
from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader
|
|
||||||
from text_generation_server.layers.moe.gptq_marlin import (
|
|
||||||
GPTQMarlinSparseMoELayer,
|
|
||||||
can_use_marlin_moe_gemm,
|
|
||||||
)
|
|
||||||
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
|
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer
|
||||||
from text_generation_server.utils.log import log_once
|
from text_generation_server.utils.log import log_once
|
||||||
from text_generation_server.utils.weights import (
|
from text_generation_server.utils.weights import (
|
||||||
DefaultWeightsLoader,
|
DefaultWeightsLoader,
|
||||||
@ -24,12 +19,7 @@ from text_generation_server.utils.weights import (
|
|||||||
UnquantizedWeight,
|
UnquantizedWeight,
|
||||||
)
|
)
|
||||||
|
|
||||||
if SYSTEM == "rocm":
|
from .fused_moe import fused_topk, grouped_topk
|
||||||
from .fused_moe_rocm import grouped_topk
|
|
||||||
from vllm.model_executor.layers.fused_moe import fused_topk
|
|
||||||
elif SYSTEM != "ipex":
|
|
||||||
from moe_kernels.fused_moe import fused_topk, grouped_topk
|
|
||||||
|
|
||||||
|
|
||||||
# NOTE: we are using a protocol here, because multiple inheritance is not nice.
|
# NOTE: we are using a protocol here, because multiple inheritance is not nice.
|
||||||
# We need `Module`, and `Module` -> some abstract class -> some concrete
|
# We need `Module`, and `Module` -> some abstract class -> some concrete
|
||||||
@ -52,6 +42,8 @@ class MoELayer(Protocol):
|
|||||||
up_proj_name: str = "up_proj",
|
up_proj_name: str = "up_proj",
|
||||||
down_proj_name: str = "down_proj",
|
down_proj_name: str = "down_proj",
|
||||||
hidden_act: str = "silu",
|
hidden_act: str = "silu",
|
||||||
|
scoring_func: Optional[str] = None,
|
||||||
|
e_score_correction_bias: Optional[float] = None,
|
||||||
): ...
|
): ...
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -81,9 +73,14 @@ class DenseMoELayer(nn.Module):
|
|||||||
up_proj_name: str = "up_proj",
|
up_proj_name: str = "up_proj",
|
||||||
down_proj_name: str = "down_proj",
|
down_proj_name: str = "down_proj",
|
||||||
hidden_act: str = "silu",
|
hidden_act: str = "silu",
|
||||||
|
scoring_func: Optional[str] = None,
|
||||||
|
e_score_correction_bias: Optional[float] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
assert scoring_func is None, "scoring func is not handled"
|
||||||
|
assert e_score_correction_bias is None, "scoring correction bias is not handled"
|
||||||
|
|
||||||
log_once(
|
log_once(
|
||||||
logger.info,
|
logger.info,
|
||||||
"No fused layers are available for this model type, using (slower) dense MoE layer",
|
"No fused layers are available for this model type, using (slower) dense MoE layer",
|
||||||
@ -199,22 +196,27 @@ class SparseMoELayer(nn.Module):
|
|||||||
topk: int,
|
topk: int,
|
||||||
topk_group: Optional[int],
|
topk_group: Optional[int],
|
||||||
weights: Weights,
|
weights: Weights,
|
||||||
|
scoring_func: Optional[str] = "softmax",
|
||||||
|
e_score_correction_bias: Optional[float] = None,
|
||||||
gate_proj_name: str = "gate_proj",
|
gate_proj_name: str = "gate_proj",
|
||||||
up_proj_name: str = "up_proj",
|
up_proj_name: str = "up_proj",
|
||||||
down_proj_name: str = "down_proj",
|
down_proj_name: str = "down_proj",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if (
|
if (
|
||||||
isinstance(weights.loader, DefaultWeightsLoader)
|
isinstance(weights.loader, DefaultWeightsLoader)
|
||||||
and isinstance(weights.loader.weight_class, UnquantizedWeight)
|
and isinstance(weights.loader.weight_class, UnquantizedWeight)
|
||||||
) or isinstance(weights.loader, HybridFP8UnquantLoader):
|
) or isinstance(weights.loader, HybridFP8UnquantLoader):
|
||||||
|
if (
|
||||||
|
isinstance(weights.loader, HybridFP8UnquantLoader)
|
||||||
|
and weights.loader.to_fp8
|
||||||
|
):
|
||||||
|
cls = FP8SparseMoELayer
|
||||||
|
else:
|
||||||
cls = UnquantizedSparseMoELayer
|
cls = UnquantizedSparseMoELayer
|
||||||
elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym:
|
|
||||||
cls = GPTQMarlinSparseMoELayer
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and GPTQ weights"
|
f"Unsupported weights loader: {type(weights.loader)}, sparse MoE is only supported for unquantized, AWQ, and GPTQ weights"
|
||||||
)
|
)
|
||||||
|
|
||||||
log_once(
|
log_once(
|
||||||
@ -230,6 +232,8 @@ class SparseMoELayer(nn.Module):
|
|||||||
topk=topk,
|
topk=topk,
|
||||||
topk_group=topk_group,
|
topk_group=topk_group,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
|
scoring_func=scoring_func,
|
||||||
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
gate_proj_name=gate_proj_name,
|
gate_proj_name=gate_proj_name,
|
||||||
up_proj_name=up_proj_name,
|
up_proj_name=up_proj_name,
|
||||||
down_proj_name=down_proj_name,
|
down_proj_name=down_proj_name,
|
||||||
@ -241,17 +245,6 @@ class SparseMoELayer(nn.Module):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def is_supported(weights: Weights) -> bool:
|
def is_supported(weights: Weights) -> bool:
|
||||||
return (
|
return (
|
||||||
(
|
|
||||||
isinstance(weights.loader, DefaultWeightsLoader)
|
isinstance(weights.loader, DefaultWeightsLoader)
|
||||||
and isinstance(weights.loader.weight_class, UnquantizedWeight)
|
and isinstance(weights.loader.weight_class, UnquantizedWeight)
|
||||||
)
|
) or isinstance(weights.loader, HybridFP8UnquantLoader)
|
||||||
or isinstance(weights.loader, HybridFP8UnquantLoader)
|
|
||||||
or (
|
|
||||||
isinstance(weights.loader, GPTQMarlinWeightsLoader)
|
|
||||||
and can_use_marlin_moe_gemm(
|
|
||||||
quant_method=weights.loader.quant_method,
|
|
||||||
quantize=weights.loader.quantize,
|
|
||||||
sym=weights.loader.sym,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
173
backends/gaudi/server/text_generation_server/layers/moe/fp8.py
Normal file
173
backends/gaudi/server/text_generation_server/layers/moe/fp8.py
Normal file
@ -0,0 +1,173 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from text_generation_server.utils.weights import Weights
|
||||||
|
from text_generation_server.layers.fp8 import (
|
||||||
|
Fp8Weight,
|
||||||
|
fp8_quantize,
|
||||||
|
quant_dtype,
|
||||||
|
normalize_e4m3fn_to_native_float8,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .unquantized import fused_moe
|
||||||
|
except Exception:
|
||||||
|
fused_moe = None
|
||||||
|
|
||||||
|
|
||||||
|
class FP8SparseMoELayer(nn.Module):
    """
    Sparse MoE layer that runs FP8 expert weights through the `fused_moe` kernel.

    At construction time the per-expert gate/up and down projections are loaded
    and stacked together with their weight scales and (optional) input scales.
    `forward` passes all of them to `fused_moe` with `use_fp8_w8a8=True`.
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        scoring_func: Optional[str] = "softmax",
        e_score_correction_bias: Optional[float] = None,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
    ):
        """
        Load and stack the FP8 expert weights found under `prefix`.

        `n_expert_group` and `topk_group` must either both be `None` or both
        be set; anything else trips the assertion below. The remaining routing
        arguments are stored as-is and forwarded to `fused_moe` in `forward`.
        """
        super().__init__()

        assert (n_expert_group is None) == (
            topk_group is None
        ), "n_expert_group and topk_group must both be None or have some value"

        self.n_expert_group = n_expert_group
        self.topk = topk
        self.topk_group = topk_group
        self.renormalize = renormalize
        self.weight_block_size = weights.weights_loader.weight_block_size
        self.scoring_func = scoring_func
        self.e_score_correction_bias = e_score_correction_bias

        # Gate and up projections are fetched together for each expert.
        (
            self.gate_up_proj,
            self.gate_up_proj_weight_scale,
            self.gate_up_proj_input_scale,
        ) = _load_expert_multi_weights_col(
            prefix=prefix,
            n_experts=n_experts,
            gate_proj_name=gate_proj_name,
            up_proj_name=up_proj_name,
            weights=weights,
        )

        self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (
            _load_expert_weights_row(
                prefix=prefix,
                n_experts=n_experts,
                name=down_proj_name,
                weights=weights,
            )
        )

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        """
        Route `x` through the experts selected by `gating_output`.

        Raises:
            RuntimeError: if the `fused_moe` kernel could not be imported
                (the module-level import falls back to `fused_moe = None`).
        """
        # The module-level import is best-effort; without this guard a missing
        # kernel surfaces as "'NoneType' object is not callable".
        if fused_moe is None:
            raise RuntimeError(
                "FP8SparseMoELayer requires the `fused_moe` kernel, "
                "but it could not be imported from `.unquantized`"
            )

        return fused_moe(
            x,
            w1=self.gate_up_proj,
            w2=self.down_proj,
            gating_output=gating_output,
            topk=self.topk,
            renormalize=self.renormalize,
            inplace=True,
            use_grouped_topk=self.n_expert_group is not None,
            num_expert_group=self.n_expert_group,
            topk_group=self.topk_group,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias,
            use_fp8_w8a8=True,
            w1_scale=self.gate_up_proj_weight_scale,
            w2_scale=self.down_proj_weight_scale,
            a1_scale=self.gate_up_proj_input_scale,
            a2_scale=self.down_proj_input_scale,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def _load_expert_weights(
    get_weight_fn,
    *,
    prefix: str,
    n_experts: int,
    name: Optional[str],
    weights: Weights,
) -> tuple:
    """
    Load one FP8 weight per expert and stack them along a new leading dimension.

    Args:
        get_weight_fn: callable invoked as `get_weight_fn(prefix, i, name, weights)`
            for each expert index `i`; it must return an `Fp8Weight`.
        prefix: weight-name prefix handed through to `get_weight_fn`.
        n_experts: number of experts to load.
        name: projection name handed through to `get_weight_fn`; may be `None`
            when the callable does not need it.
        weights: weight store handed through to `get_weight_fn`.

    Returns:
        A 3-tuple `(all_weight, all_weight_scales, max_input_scale)`:
        the stacked weights (dtype `quant_dtype`), the stacked weight scales
        (float32), and the largest input scale seen across experts, or `None`
        if no expert provided one.

    Raises:
        AssertionError: if a loaded weight is not an `Fp8Weight`, or if no
            expert was loaded (e.g. `n_experts == 0`).
    """
    all_weight = None
    all_weight_scales = None
    max_input_scale = None

    for i in range(n_experts):
        weight = get_weight_fn(prefix, i, name, weights)

        assert isinstance(weight, Fp8Weight)

        if all_weight is None:
            # Size the stacked buffers from the first expert's shapes.
            all_weight = torch.empty(
                (n_experts,) + weight.weight.shape,
                dtype=quant_dtype,
                device=weight.weight.device,
            )
            all_weight_scales = torch.empty(
                (n_experts,) + weight.weight_scale.shape,
                dtype=torch.float32,
                device=weight.weight.device,
            )

        if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:
            # Checkpoint is already FP8: convert it to the native float8 format.
            all_weight[i], all_weight_scales[i], current_input_scale = (
                normalize_e4m3fn_to_native_float8(
                    weight.weight, weight.weight_scale, weight.input_scale
                )
            )
            # Keep a single input scale for all experts: the largest one seen.
            if current_input_scale is not None:
                if max_input_scale is None or current_input_scale > max_input_scale:
                    max_input_scale = current_input_scale
        else:
            # Not FP8 yet: quantize here with a scalar scale.
            all_weight[i], all_weight_scales[i] = fp8_quantize(
                weight.weight, scalar=True
            )

    assert all_weight is not None

    return all_weight, all_weight_scales, max_input_scale
|
||||||
|
|
||||||
|
|
||||||
|
def _load_expert_multi_weights_col(
    *,
    prefix: str,
    n_experts: int,
    gate_proj_name: str,
    up_proj_name: str,
    weights: Weights,
) -> tuple:
    """
    Load the gate and up projections of every expert and stack them.

    For expert `i` the weights `{prefix}.{i}.{gate_proj_name}` and
    `{prefix}.{i}.{up_proj_name}` are fetched together through
    `weights.get_multi_weights_col(..., 0)`.

    Returns:
        The tuple produced by `_load_expert_weights`:
        `(stacked_weights, stacked_weight_scales, max_input_scale)`.
    """

    def get_weight_fn(prefix, i, name, weights):
        # `name` is ignored: the two projection names come from the enclosing scope.
        return weights.get_multi_weights_col(
            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
        )

    return _load_expert_weights(
        get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _load_expert_weights_row(
    *,
    prefix: str,
    n_experts: int,
    name: str,
    weights: Weights,
) -> tuple:
    """
    Load the projection `name` of every expert and stack them.

    For expert `i` the weight `{prefix}.{i}.{name}` is fetched through
    `weights.get_weights_row`.

    Returns:
        The tuple produced by `_load_expert_weights`:
        `(stacked_weights, stacked_weight_scales, max_input_scale)`.
    """

    def get_weight_fn(prefix, i, name, weights):
        return weights.get_weights_row(f"{prefix}.{i}.{name}")

    return _load_expert_weights(
        get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights
    )
|
@ -16,10 +16,8 @@
|
|||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Remove the functions once moe_kernel are built for ROCM
|
|
||||||
def grouped_topk(
|
def grouped_topk(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
gating_output: torch.Tensor,
|
gating_output: torch.Tensor,
|
||||||
@ -50,3 +48,18 @@ def grouped_topk(
|
|||||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
return topk_weights, topk_ids
|
return topk_weights, topk_ids
|
||||||
|
|
||||||
|
|
||||||
|
def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Select the `topk` experts per token from softmax-normalized router logits.

    Args:
        hidden_states: not read by this implementation; kept so the signature
            stays compatible with existing callers.
        gating_output: router logits; the expert dimension is the last one.
        topk: number of experts to keep per token.
        renormalize: when true, rescale the kept weights so they sum to 1
            along the last dimension.

    Returns:
        `(topk_weights, topk_ids)`: the float32 routing weights of the selected
        experts and their indices along the last dimension.
    """
    # Softmax over the same (last) dimension that top-k and renormalization use.
    # For 2-D input this equals the former `dim=1`.
    routing_weights = torch.nn.functional.softmax(
        gating_output, dim=-1, dtype=torch.float32
    )
    topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
    if renormalize:
        topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
|
@ -1,215 +0,0 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
from text_generation_server.utils.weights import Weights
|
|
||||||
from text_generation_server.layers.marlin.gptq import (
|
|
||||||
GPTQMarlinWeight,
|
|
||||||
GPTQMarlinWeightsLoader,
|
|
||||||
)
|
|
||||||
|
|
||||||
if SYSTEM == "cuda":
|
|
||||||
from moe_kernels.fused_marlin_moe import fused_marlin_moe
|
|
||||||
else:
|
|
||||||
fused_marlin_moe = None
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
major, _minor = torch.cuda.get_device_capability()
|
|
||||||
has_sm_8_0 = major >= 8
|
|
||||||
except Exception:
|
|
||||||
has_sm_8_0 = False
|
|
||||||
|
|
||||||
|
|
||||||
def can_use_marlin_moe_gemm(
|
|
||||||
*,
|
|
||||||
quant_method: str,
|
|
||||||
quantize: str,
|
|
||||||
sym: bool,
|
|
||||||
):
|
|
||||||
return (
|
|
||||||
SYSTEM == "cuda"
|
|
||||||
and fused_marlin_moe is not None
|
|
||||||
and has_sm_8_0
|
|
||||||
and quantize == "gptq"
|
|
||||||
and quant_method == "gptq"
|
|
||||||
and sym
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class GPTQMarlinMoEWeight:
|
|
||||||
qweight: torch.Tensor
|
|
||||||
qzeros: torch.Tensor
|
|
||||||
scales: torch.Tensor
|
|
||||||
g_idx: torch.Tensor
|
|
||||||
perm: torch.Tensor
|
|
||||||
is_full_k: bool
|
|
||||||
|
|
||||||
|
|
||||||
class GPTQMarlinSparseMoELayer(nn.Module):
|
|
||||||
"""
|
|
||||||
MoE layer that uses a fused GPTQ-Marlin kernel.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
n_expert_group: Optional[int],
|
|
||||||
n_experts: int,
|
|
||||||
prefix: str,
|
|
||||||
renormalize: bool,
|
|
||||||
topk: int,
|
|
||||||
topk_group: Optional[int],
|
|
||||||
weights: Weights,
|
|
||||||
gate_proj_name: str = "gate_proj",
|
|
||||||
up_proj_name: str = "up_proj",
|
|
||||||
down_proj_name: str = "down_proj",
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
if not (
|
|
||||||
isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"Unsupported weights loader: {weights.loader}, only GPTQMarlinWeightsLoader with symmetric quantization is supported"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert (n_expert_group is None) == (
|
|
||||||
topk_group is None
|
|
||||||
), "n_expert_group and topk_group must both be None or have some value"
|
|
||||||
|
|
||||||
self.n_expert_group = n_expert_group
|
|
||||||
self.topk = topk
|
|
||||||
self.topk_group = topk_group
|
|
||||||
self.renormalize = renormalize
|
|
||||||
|
|
||||||
self.gate_up_proj = _load_expert_multi_weights_col(
|
|
||||||
prefix=prefix,
|
|
||||||
n_experts=n_experts,
|
|
||||||
names=[gate_proj_name, up_proj_name],
|
|
||||||
weights=weights,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.down_proj = _load_expert_weights_row(
|
|
||||||
prefix=prefix, n_experts=n_experts, name=down_proj_name, weights=weights
|
|
||||||
)
|
|
||||||
|
|
||||||
self.bits = weights.loader.bits
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
|
|
||||||
return fused_marlin_moe(
|
|
||||||
x,
|
|
||||||
w1=self.gate_up_proj.qweight,
|
|
||||||
w2=self.down_proj.qweight,
|
|
||||||
g_idx1=self.gate_up_proj.g_idx,
|
|
||||||
g_idx2=self.down_proj.g_idx,
|
|
||||||
perm1=self.gate_up_proj.perm,
|
|
||||||
perm2=self.down_proj.perm,
|
|
||||||
w1_scale=self.gate_up_proj.scales,
|
|
||||||
w2_scale=self.down_proj.scales,
|
|
||||||
is_full_k1=self.gate_up_proj.is_full_k,
|
|
||||||
is_full_k2=self.down_proj.is_full_k,
|
|
||||||
gating_output=gating_output,
|
|
||||||
topk=self.topk,
|
|
||||||
renormalize=self.renormalize,
|
|
||||||
use_grouped_topk=self.n_expert_group is not None,
|
|
||||||
num_expert_group=self.n_expert_group,
|
|
||||||
topk_group=self.topk_group,
|
|
||||||
num_bits=self.bits,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _load_expert_multi_weights_col(
|
|
||||||
*,
|
|
||||||
prefix: str,
|
|
||||||
n_experts: int,
|
|
||||||
names: List[str],
|
|
||||||
weights: Weights,
|
|
||||||
) -> GPTQMarlinMoEWeight:
|
|
||||||
moe_weight = None
|
|
||||||
for i in range(n_experts):
|
|
||||||
weight = weights.get_multi_weights_col(
|
|
||||||
[f"{prefix}.{i}.{name}" for name in names], 0
|
|
||||||
)
|
|
||||||
assert isinstance(weight, GPTQMarlinWeight)
|
|
||||||
moe_weight = _pack_weight(
|
|
||||||
n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight
|
|
||||||
)
|
|
||||||
assert moe_weight is not None
|
|
||||||
return moe_weight
|
|
||||||
|
|
||||||
|
|
||||||
def _load_expert_weights_row(
|
|
||||||
*,
|
|
||||||
prefix: str,
|
|
||||||
n_experts: int,
|
|
||||||
name: str,
|
|
||||||
weights: Weights,
|
|
||||||
) -> GPTQMarlinMoEWeight:
|
|
||||||
moe_weight = None
|
|
||||||
for i in range(n_experts):
|
|
||||||
weight = weights.get_weights_row(
|
|
||||||
f"{prefix}.{i}.{name}",
|
|
||||||
)
|
|
||||||
assert isinstance(weight, GPTQMarlinWeight)
|
|
||||||
moe_weight = _pack_weight(
|
|
||||||
n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight
|
|
||||||
)
|
|
||||||
assert moe_weight is not None
|
|
||||||
return moe_weight
|
|
||||||
|
|
||||||
|
|
||||||
def _pack_weight(
|
|
||||||
*,
|
|
||||||
n_experts: int,
|
|
||||||
expert: int,
|
|
||||||
moe_weight: Optional[GPTQMarlinMoEWeight],
|
|
||||||
weight: GPTQMarlinWeight,
|
|
||||||
) -> GPTQMarlinMoEWeight:
|
|
||||||
if moe_weight is None:
|
|
||||||
qweight = torch.empty(
|
|
||||||
(n_experts,) + weight.qweight.shape,
|
|
||||||
dtype=weight.qweight.dtype,
|
|
||||||
device=weight.qweight.device,
|
|
||||||
)
|
|
||||||
qzeros = torch.empty(
|
|
||||||
(n_experts,) + weight.qzeros.shape,
|
|
||||||
dtype=weight.qzeros.dtype,
|
|
||||||
device=weight.qzeros.device,
|
|
||||||
)
|
|
||||||
scales = torch.empty(
|
|
||||||
(n_experts,) + weight.scales.shape,
|
|
||||||
dtype=weight.scales.dtype,
|
|
||||||
device=weight.scales.device,
|
|
||||||
)
|
|
||||||
g_idx = torch.empty(
|
|
||||||
(n_experts,) + weight.g_idx.shape,
|
|
||||||
dtype=weight.g_idx.dtype,
|
|
||||||
device=weight.g_idx.device,
|
|
||||||
)
|
|
||||||
perm = torch.empty(
|
|
||||||
(n_experts,) + weight.perm.shape,
|
|
||||||
dtype=weight.perm.dtype,
|
|
||||||
device=weight.perm.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
moe_weight = GPTQMarlinMoEWeight(
|
|
||||||
qweight=qweight,
|
|
||||||
qzeros=qzeros,
|
|
||||||
scales=scales,
|
|
||||||
g_idx=g_idx,
|
|
||||||
perm=perm,
|
|
||||||
is_full_k=weight.is_full_k,
|
|
||||||
)
|
|
||||||
|
|
||||||
moe_weight.qweight[expert] = weight.qweight
|
|
||||||
moe_weight.qzeros[expert] = weight.qzeros
|
|
||||||
moe_weight.scales[expert] = weight.scales
|
|
||||||
moe_weight.g_idx[expert] = weight.g_idx
|
|
||||||
moe_weight.perm[expert] = weight.perm
|
|
||||||
|
|
||||||
return moe_weight
|
|
@ -3,13 +3,8 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
from text_generation_server.utils.weights import UnquantizedWeight, Weights
|
from text_generation_server.utils.weights import UnquantizedWeight, Weights
|
||||||
|
from vllm_hpu_extension.ops import DynamicFusedMOE
|
||||||
if SYSTEM == "rocm":
|
|
||||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
|
||||||
elif SYSTEM != "ipex":
|
|
||||||
from moe_kernels.fused_moe import fused_moe
|
|
||||||
|
|
||||||
|
|
||||||
class UnquantizedSparseMoELayer(nn.Module):
|
class UnquantizedSparseMoELayer(nn.Module):
|
||||||
@ -23,6 +18,8 @@ class UnquantizedSparseMoELayer(nn.Module):
|
|||||||
topk: int,
|
topk: int,
|
||||||
topk_group: Optional[int],
|
topk_group: Optional[int],
|
||||||
weights: Weights,
|
weights: Weights,
|
||||||
|
scoring_func: Optional[str] = "softmax",
|
||||||
|
e_score_correction_bias: Optional[float] = None,
|
||||||
gate_proj_name: str = "gate_proj",
|
gate_proj_name: str = "gate_proj",
|
||||||
up_proj_name: str = "up_proj",
|
up_proj_name: str = "up_proj",
|
||||||
down_proj_name: str = "down_proj",
|
down_proj_name: str = "down_proj",
|
||||||
@ -37,6 +34,9 @@ class UnquantizedSparseMoELayer(nn.Module):
|
|||||||
self.topk = topk
|
self.topk = topk
|
||||||
self.topk_group = topk_group
|
self.topk_group = topk_group
|
||||||
self.renormalize = renormalize
|
self.renormalize = renormalize
|
||||||
|
self.weight_block_size = weights.weights_loader.weight_block_size
|
||||||
|
self.scoring_func = scoring_func
|
||||||
|
self.e_score_correction_bias = e_score_correction_bias
|
||||||
|
|
||||||
self.gate_up_proj = _load_expert_multi_weights_col(
|
self.gate_up_proj = _load_expert_multi_weights_col(
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
@ -53,30 +53,13 @@ class UnquantizedSparseMoELayer(nn.Module):
|
|||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
|
self.hpu_fused_moe = DynamicFusedMOE(n_experts)
|
||||||
if SYSTEM == "rocm":
|
for i in range(n_experts):
|
||||||
return fused_moe(
|
self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
|
||||||
x,
|
self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.down_proj[i])
|
||||||
self.gate_up_proj,
|
|
||||||
self.down_proj,
|
|
||||||
gating_output,
|
|
||||||
self.topk,
|
|
||||||
renormalize=self.renormalize,
|
|
||||||
inplace=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
return fused_moe(
|
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
|
||||||
x,
|
return self.hpu_fused_moe(x, gating_output, self.topk)
|
||||||
w1=self.gate_up_proj,
|
|
||||||
w2=self.down_proj,
|
|
||||||
gating_output=gating_output,
|
|
||||||
topk=self.topk,
|
|
||||||
renormalize=self.renormalize,
|
|
||||||
inplace=True,
|
|
||||||
use_grouped_topk=self.n_expert_group is not None,
|
|
||||||
num_expert_group=self.n_expert_group,
|
|
||||||
topk_group=self.topk_group,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _load_expert_multi_weights_col(
|
def _load_expert_multi_weights_col(
|
||||||
|
@ -2,14 +2,10 @@ import os
|
|||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from habana_frameworks.torch.hpex.kernels import (
|
||||||
|
RotaryPosEmbeddingMode,
|
||||||
if SYSTEM == "cuda":
|
apply_rotary_pos_emb,
|
||||||
import rotary_emb
|
)
|
||||||
elif SYSTEM == "rocm":
|
|
||||||
from vllm._C import ops
|
|
||||||
elif SYSTEM == "ipex":
|
|
||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
|
|
||||||
|
|
||||||
def _create_inv_freq(dim, base, device):
|
def _create_inv_freq(dim, base, device):
|
||||||
@ -30,7 +26,7 @@ def _get_rope_config(config):
|
|||||||
|
|
||||||
|
|
||||||
class PositionRotaryEmbedding(nn.Module):
|
class PositionRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, inv_freq, scaling_factor):
|
def __init__(self, inv_freq, scaling_factor, max_position_embeddings):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.inv_freq = inv_freq
|
self.inv_freq = inv_freq
|
||||||
self._seq_len_cached = 0
|
self._seq_len_cached = 0
|
||||||
@ -40,6 +36,9 @@ class PositionRotaryEmbedding(nn.Module):
|
|||||||
self._sin_k_cached = None
|
self._sin_k_cached = None
|
||||||
self.scaling_factor = scaling_factor
|
self.scaling_factor = scaling_factor
|
||||||
self.dynamic_args = None
|
self.dynamic_args = None
|
||||||
|
self._update_cos_sin_cache(
|
||||||
|
torch.float32, inv_freq.device, max_position_embeddings
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -48,40 +47,41 @@ class PositionRotaryEmbedding(nn.Module):
|
|||||||
cos: torch.Tensor,
|
cos: torch.Tensor,
|
||||||
sin: torch.Tensor,
|
sin: torch.Tensor,
|
||||||
):
|
):
|
||||||
# Such controlflows may add some overhead.
|
num_tokens = query.shape[0]
|
||||||
if SYSTEM == "cuda":
|
|
||||||
rotary_dim = cos.shape[-1]
|
|
||||||
q1 = query[..., :rotary_dim]
|
|
||||||
q2 = query[..., rotary_dim : 2 * rotary_dim]
|
|
||||||
|
|
||||||
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
|
|
||||||
|
|
||||||
k1 = key[..., :rotary_dim]
|
|
||||||
k2 = key[..., rotary_dim : 2 * rotary_dim]
|
|
||||||
|
|
||||||
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
|
|
||||||
elif SYSTEM == "rocm":
|
|
||||||
# NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
|
|
||||||
# Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
|
|
||||||
|
|
||||||
head_size = query.shape[-1]
|
head_size = query.shape[-1]
|
||||||
|
# HPU RoPE kernel requires hidden dimension for cos and sin to be equal
|
||||||
|
# to query hidden dimension, so the original tensors need to be
|
||||||
|
# expanded
|
||||||
|
# GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
|
||||||
|
# and expansion of cos/sin tensors via concatenation
|
||||||
|
rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
|
||||||
|
cos = torch.cat((cos, cos), dim=-1)
|
||||||
|
sin = torch.cat((sin, sin), dim=-1)
|
||||||
|
rotary_dim = cos.shape[-1]
|
||||||
|
query_shape = query.shape
|
||||||
|
query = query.view(num_tokens, -1, head_size)
|
||||||
|
query_rot = query[..., :rotary_dim]
|
||||||
|
query_pass = query[..., rotary_dim:]
|
||||||
|
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
|
||||||
|
query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))
|
||||||
|
|
||||||
# Inplace operation, updating query and key.
|
key_shape = key.shape
|
||||||
ops.rotary_embedding(query, key, head_size, cos, sin, True)
|
key = key.view(num_tokens, -1, head_size)
|
||||||
elif SYSTEM == "ipex":
|
key_rot = key[..., :rotary_dim]
|
||||||
ipex.llm.functional.rotary_embedding(
|
key_pass = key[..., rotary_dim:]
|
||||||
query, key, sin, cos, query.size(-1), True
|
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
|
||||||
)
|
key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def static(cls, config, dim, base, device):
|
def static(cls, config, dim, base, device):
|
||||||
inv_freq = _create_inv_freq(dim, base, device)
|
inv_freq = _create_inv_freq(dim, base, device)
|
||||||
scaling_factor = None
|
scaling_factor = None
|
||||||
rope_scaling = _get_rope_config(config)
|
rope_scaling = _get_rope_config(config)
|
||||||
|
if not hasattr(config, "max_position_embeddings") and hasattr(
|
||||||
|
config, "max_seq_len"
|
||||||
|
):
|
||||||
|
# handling for dbrx
|
||||||
|
config.max_position_embeddings = config.max_seq_len
|
||||||
if rope_scaling is not None:
|
if rope_scaling is not None:
|
||||||
# `rope_type` is now standard in transformers, but some existing models
|
# `rope_type` is now standard in transformers, but some existing models
|
||||||
# have `type` instead.
|
# have `type` instead.
|
||||||
@ -89,6 +89,17 @@ class PositionRotaryEmbedding(nn.Module):
|
|||||||
|
|
||||||
if rope_type == "linear":
|
if rope_type == "linear":
|
||||||
pass
|
pass
|
||||||
|
elif rope_type == "default":
|
||||||
|
pass
|
||||||
|
elif rope_type == "mrope":
|
||||||
|
mrope_section = rope_scaling["mrope_section"]
|
||||||
|
if mrope_section is not None:
|
||||||
|
return RotaryPositionEmbeddingMultimodalSections(
|
||||||
|
inv_freq,
|
||||||
|
scaling_factor,
|
||||||
|
mrope_section,
|
||||||
|
config.max_position_embeddings,
|
||||||
|
)
|
||||||
elif rope_type == "dynamic":
|
elif rope_type == "dynamic":
|
||||||
scaling_factor = rope_scaling["factor"]
|
scaling_factor = rope_scaling["factor"]
|
||||||
return DynamicPositionRotaryEmbedding(
|
return DynamicPositionRotaryEmbedding(
|
||||||
@ -109,7 +120,7 @@ class PositionRotaryEmbedding(nn.Module):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
return cls(inv_freq, scaling_factor)
|
return cls(inv_freq, scaling_factor, config.max_position_embeddings)
|
||||||
|
|
||||||
elif rope_type == "yarn":
|
elif rope_type == "yarn":
|
||||||
scaling_factor = rope_scaling["factor"]
|
scaling_factor = rope_scaling["factor"]
|
||||||
@ -185,12 +196,13 @@ class PositionRotaryEmbedding(nn.Module):
|
|||||||
long_inv_freq=long_inv_freq,
|
long_inv_freq=long_inv_freq,
|
||||||
scaling_factor=scaling_factor,
|
scaling_factor=scaling_factor,
|
||||||
original_max_position_embeddings=original_max_position_embeddings,
|
original_max_position_embeddings=original_max_position_embeddings,
|
||||||
|
max_position_embeddings=config.max_position_embeddings,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
|
f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
|
||||||
)
|
)
|
||||||
return cls(inv_freq, scaling_factor)
|
return cls(inv_freq, scaling_factor, config.max_position_embeddings)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, config, prefix, weights):
|
def load(cls, config, prefix, weights):
|
||||||
@ -236,7 +248,7 @@ class PositionRotaryEmbedding(nn.Module):
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
|
f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
|
||||||
)
|
)
|
||||||
return cls(inv_freq, scaling_factor)
|
return cls(inv_freq, scaling_factor, config.max_position_embeddings)
|
||||||
|
|
||||||
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||||
# Reset the tables if the sequence length has changed,
|
# Reset the tables if the sequence length has changed,
|
||||||
@ -257,17 +269,7 @@ class PositionRotaryEmbedding(nn.Module):
|
|||||||
self._cos_cached = torch.cos(freqs).to(dtype)
|
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||||
self._sin_cached = torch.sin(freqs).to(dtype)
|
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||||
|
|
||||||
def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):
|
def get_cos_sin(self, position_ids: torch.Tensor):
|
||||||
"""
|
|
||||||
Return cos and sin for the asked position ids
|
|
||||||
"""
|
|
||||||
if SYSTEM == "rocm":
|
|
||||||
# For RoCm, we always use float cos/sin to avoid a cast.
|
|
||||||
# For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26
|
|
||||||
# But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal.
|
|
||||||
dtype = torch.float32
|
|
||||||
|
|
||||||
self._update_cos_sin_cache(dtype, position_ids.device, max_s)
|
|
||||||
|
|
||||||
cos = torch.index_select(self._cos_cached, 0, position_ids)
|
cos = torch.index_select(self._cos_cached, 0, position_ids)
|
||||||
sin = torch.index_select(self._sin_cached, 0, position_ids)
|
sin = torch.index_select(self._sin_cached, 0, position_ids)
|
||||||
@ -283,6 +285,7 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
|
|||||||
long_inv_freq,
|
long_inv_freq,
|
||||||
scaling_factor,
|
scaling_factor,
|
||||||
original_max_position_embeddings,
|
original_max_position_embeddings,
|
||||||
|
max_position_embeddings,
|
||||||
):
|
):
|
||||||
super(PositionRotaryEmbedding, self).__init__()
|
super(PositionRotaryEmbedding, self).__init__()
|
||||||
self.short_inv_freq = short_inv_freq
|
self.short_inv_freq = short_inv_freq
|
||||||
@ -295,6 +298,9 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
|
|||||||
self._cos_k_cached = None
|
self._cos_k_cached = None
|
||||||
self._sin_k_cached = None
|
self._sin_k_cached = None
|
||||||
self.dynamic_args = None
|
self.dynamic_args = None
|
||||||
|
self._update_cos_sin_cache(
|
||||||
|
torch.float32, short_inv_freq.device, max_position_embeddings
|
||||||
|
)
|
||||||
|
|
||||||
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||||
# Reset the tables if the sequence length has changed,
|
# Reset the tables if the sequence length has changed,
|
||||||
@ -348,6 +354,9 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
|
|||||||
self._cos_k_cached = None
|
self._cos_k_cached = None
|
||||||
self._sin_k_cached = None
|
self._sin_k_cached = None
|
||||||
self.dynamic_args = None
|
self.dynamic_args = None
|
||||||
|
self._update_cos_sin_cache(
|
||||||
|
torch.float32, short_inv_freq.device, max_position_embeddings
|
||||||
|
)
|
||||||
|
|
||||||
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||||
if (
|
if (
|
||||||
@ -383,7 +392,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
|
|||||||
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
|
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
|
||||||
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
|
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
|
||||||
inv_freq = _create_inv_freq(dim, base, device)
|
inv_freq = _create_inv_freq(dim, base, device)
|
||||||
super().__init__(inv_freq, scaling_factor)
|
super().__init__(inv_freq, scaling_factor, max_position_embeddings)
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.max_position_embeddings = max_position_embeddings
|
self.max_position_embeddings = max_position_embeddings
|
||||||
self.base = base
|
self.base = base
|
||||||
@ -461,7 +470,9 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
|
|||||||
mscale_all_dim: float,
|
mscale_all_dim: float,
|
||||||
):
|
):
|
||||||
inv_freq = _create_inv_freq(dim, base, device)
|
inv_freq = _create_inv_freq(dim, base, device)
|
||||||
super().__init__(inv_freq, scaling_factor)
|
super().__init__(
|
||||||
|
inv_freq, scaling_factor, max_position_embeddings * self.scaling_factor
|
||||||
|
)
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.max_position_embeddings = max_position_embeddings
|
self.max_position_embeddings = max_position_embeddings
|
||||||
self.base = base
|
self.base = base
|
||||||
@ -546,3 +557,50 @@ def apply_llama3_scaling(
|
|||||||
new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)
|
new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)
|
||||||
|
|
||||||
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
|
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
|
||||||
|
|
||||||
|
|
||||||
|
class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
inv_freq: torch.Tensor,
|
||||||
|
scaling_factor: float,
|
||||||
|
sections: list,
|
||||||
|
max_position_embeddings,
|
||||||
|
):
|
||||||
|
self.sections = sections
|
||||||
|
self._cos_cached = None
|
||||||
|
self._sin_cached = None
|
||||||
|
self.section_indices = (
|
||||||
|
torch.arange(len(self.sections))
|
||||||
|
.repeat_interleave(torch.tensor(self.sections))
|
||||||
|
.view(1, 1, -1)
|
||||||
|
.to(inv_freq.device)
|
||||||
|
)
|
||||||
|
super().__init__(inv_freq, scaling_factor, max_position_embeddings)
|
||||||
|
|
||||||
|
def _update_cos_sin_cache(
|
||||||
|
self, dtype: torch.dtype, device: torch.device, seqlen: int
|
||||||
|
):
|
||||||
|
# always cache the cos/sin for the full sequence length to avoid
|
||||||
|
# recomputing if the sequence length is smaller than the cached one
|
||||||
|
if (
|
||||||
|
seqlen > self._seq_len_cached
|
||||||
|
or self._cos_cached.device != device
|
||||||
|
or self._cos_cached.dtype != dtype
|
||||||
|
):
|
||||||
|
self._seq_len_cached = seqlen
|
||||||
|
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
|
||||||
|
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||||
|
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||||
|
self._sections = self.section_indices.expand(seqlen, -1, -1)
|
||||||
|
|
||||||
|
def get_cos_sin(
|
||||||
|
self,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
):
|
||||||
|
slen = position_ids.shape[0]
|
||||||
|
|
||||||
|
cos = self._cos_cached[position_ids].gather(1, self._sections[:slen])
|
||||||
|
sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])
|
||||||
|
return cos, sin
|
||||||
|
@ -2,10 +2,8 @@ import torch
|
|||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from typing import Iterable, List
|
from typing import Iterable, List
|
||||||
from text_generation_server.layers.linear import get_linear, FastLinear
|
from text_generation_server.layers.linear import get_linear, FastLinear
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
|
|
||||||
if SYSTEM == "ipex":
|
import habana_frameworks.torch as htorch
|
||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
|
|
||||||
|
|
||||||
class LayerConcat(torch.nn.Module):
|
class LayerConcat(torch.nn.Module):
|
||||||
@ -90,11 +88,7 @@ class TensorParallelHead(SuperLayer):
|
|||||||
local_out = gather_input.T
|
local_out = gather_input.T
|
||||||
|
|
||||||
torch.mm(input, self.linear.weight.T, out=local_out)
|
torch.mm(input, self.linear.weight.T, out=local_out)
|
||||||
if SYSTEM == "ipex":
|
htorch.core.mark_step()
|
||||||
ipex.distributed.all_gather_into_tensor(
|
|
||||||
world_out, gather_input, group=self.process_group
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
torch.distributed.all_gather_into_tensor(
|
torch.distributed.all_gather_into_tensor(
|
||||||
world_out, gather_input, group=self.process_group
|
world_out, gather_input, group=self.process_group
|
||||||
)
|
)
|
||||||
@ -107,9 +101,8 @@ class TensorParallelHead(SuperLayer):
|
|||||||
world_output = [
|
world_output = [
|
||||||
torch.empty_like(output) for _ in range(self.process_group.size())
|
torch.empty_like(output) for _ in range(self.process_group.size())
|
||||||
]
|
]
|
||||||
if SYSTEM == "ipex":
|
|
||||||
ipex.distributed.all_gather(world_output, output, group=self.process_group)
|
htorch.core.mark_step()
|
||||||
else:
|
|
||||||
torch.distributed.all_gather(world_output, output, group=self.process_group)
|
torch.distributed.all_gather(world_output, output, group=self.process_group)
|
||||||
world_output = torch.cat(world_output, dim=-1)
|
world_output = torch.cat(world_output, dim=-1)
|
||||||
return world_output
|
return world_output
|
||||||
@ -202,9 +195,10 @@ class TensorParallelRowLinear(SuperLayer):
|
|||||||
def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
|
def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
|
||||||
out = super().forward(input)
|
out = super().forward(input)
|
||||||
if self.process_group.size() > 1 and reduce:
|
if self.process_group.size() > 1 and reduce:
|
||||||
if SYSTEM == "ipex":
|
# FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
|
||||||
ipex.distributed.all_reduce(out, group=self.process_group)
|
# occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
|
||||||
else:
|
# (which is required for tensor parallel HPUGraph inference)
|
||||||
|
htorch.core.mark_step()
|
||||||
torch.distributed.all_reduce(out, group=self.process_group)
|
torch.distributed.all_reduce(out, group=self.process_group)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@ -242,8 +236,9 @@ class TensorParallelEmbedding(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
out = torch.nn.functional.embedding(input, self.weight)
|
out = torch.nn.functional.embedding(input, self.weight)
|
||||||
if self.reduce and self.process_group.size() > 1:
|
if self.reduce and self.process_group.size() > 1:
|
||||||
if SYSTEM == "ipex":
|
# FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
|
||||||
ipex.distributed.all_reduce(out, group=self.process_group)
|
# occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
|
||||||
else:
|
# (which is required for tensor parallel HPUGraph inference)
|
||||||
|
htorch.core.mark_step()
|
||||||
torch.distributed.all_reduce(out, group=self.process_group)
|
torch.distributed.all_reduce(out, group=self.process_group)
|
||||||
return out
|
return out
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# ruff: noqa: F821
|
||||||
|
# the above line disables the `undefined-name` rule for the model type variables
|
||||||
import torch
|
import torch
|
||||||
import os
|
import os
|
||||||
|
|
||||||
@ -8,6 +10,7 @@ from huggingface_hub import hf_hub_download, HfApi
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
|
import enum
|
||||||
|
|
||||||
# Needed to properly setup habana_frameworks
|
# Needed to properly setup habana_frameworks
|
||||||
|
|
||||||
@ -16,15 +19,10 @@ from text_generation_server.models.model import Model
|
|||||||
from text_generation_server.models.causal_lm import CausalLM
|
from text_generation_server.models.causal_lm import CausalLM
|
||||||
from text_generation_server.models.bloom import BLOOM
|
from text_generation_server.models.bloom import BLOOM
|
||||||
from text_generation_server.models.starcoder import StarCoder
|
from text_generation_server.models.starcoder import StarCoder
|
||||||
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
|
from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
|
||||||
from text_generation_server.models.custom_modeling.mllama import (
|
PhiMoEConfig,
|
||||||
MllamaForConditionalGeneration,
|
|
||||||
)
|
|
||||||
from text_generation_server.models.custom_modeling.llava_next import (
|
|
||||||
LlavaNextForConditionalGeneration,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
|
|
||||||
from text_generation_server.utils.adapter import (
|
from text_generation_server.utils.adapter import (
|
||||||
AdapterParameters,
|
AdapterParameters,
|
||||||
build_layer_weight_lookup,
|
build_layer_weight_lookup,
|
||||||
@ -33,9 +31,285 @@ from text_generation_server.utils.adapter import (
|
|||||||
)
|
)
|
||||||
from text_generation_server.adapters.lora import LoraWeights
|
from text_generation_server.adapters.lora import LoraWeights
|
||||||
|
|
||||||
|
from text_generation_server.utils.log import log_master
|
||||||
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
|
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Model",
|
||||||
|
"CausalLM",
|
||||||
|
"Seq2SeqLM",
|
||||||
|
"get_model_with_lora_adapters",
|
||||||
|
]
|
||||||
|
from text_generation_server.models.globals import ATTENTION
|
||||||
|
|
||||||
|
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
|
||||||
|
|
||||||
|
FLASH_ATTENTION = False
|
||||||
|
if ATTENTION == "paged":
|
||||||
|
FLASH_ATTENTION = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
||||||
|
from text_generation_server.models.flash_vlm_causal_lm import FlashVlmCausalLM
|
||||||
|
from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLM
|
||||||
|
from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
|
||||||
|
FlashDeepseekV2ForCausalLM,
|
||||||
|
DeepseekV2Config,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_deepseek_v3_modeling import (
|
||||||
|
FlashDeepseekV3ForCausalLM,
|
||||||
|
DeepseekV3Config,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
|
||||||
|
FlashLlamaForCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
|
||||||
|
FlashCohereForCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
|
||||||
|
FlashGemmaForCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
|
||||||
|
FlashGemma2ForCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
|
||||||
|
FlashDbrxForCausalLM,
|
||||||
|
DbrxConfig,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
|
||||||
|
RWConfig,
|
||||||
|
FlashRWForCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
|
||||||
|
FlashGPTNeoXForCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.pali_gemma import (
|
||||||
|
PaliGemmaBatch,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
|
||||||
|
PaliGemmaForConditionalGeneration,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
|
||||||
|
FlashPhiForCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
|
||||||
|
from text_generation_server.models.custom_modeling.flash_mllama import (
|
||||||
|
FlashMllamaForConditionalGeneration,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_llava_next import (
|
||||||
|
FlashLlavaNextForConditionalGeneration,
|
||||||
|
)
|
||||||
|
|
||||||
|
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
|
||||||
|
FlashSantacoderForCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
|
||||||
|
FlashStarcoder2ForCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
|
||||||
|
Qwen2ForCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
|
||||||
|
FlashMistralForCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
|
||||||
|
FlashMixtralForCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
|
||||||
|
FlashGPT2ForCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_gptj_modeling import (
|
||||||
|
FlashGPTJForCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.idefics2 import (
|
||||||
|
Idefics2ForConditionalGeneration,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.idefics3 import (
|
||||||
|
Idefics3ForConditionalGeneration,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.qwen2_vl import (
|
||||||
|
Qwen2VLForConditionalGeneration,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.qwen2_5_vl import (
|
||||||
|
Qwen2_5VLForConditionalGeneration,
|
||||||
|
Qwen2_5_VLConfig,
|
||||||
|
Qwen2_5_VLProcessor,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.attention import SUPPORTS_WINDOWING
|
||||||
|
except ImportError as e:
|
||||||
|
log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
|
||||||
|
SUPPORTS_WINDOWING = False
|
||||||
|
FLASH_ATTENTION = False
|
||||||
|
|
||||||
|
if FLASH_ATTENTION:
|
||||||
|
__all__.append(FlashCausalLM)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelType(enum.Enum):
|
||||||
|
DEEPSEEK_V2 = {
|
||||||
|
"type": "deepseek_v2",
|
||||||
|
"name": "Deepseek V2",
|
||||||
|
"url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
|
||||||
|
}
|
||||||
|
DEEPSEEK_V3 = {
|
||||||
|
"type": "deepseek_v3",
|
||||||
|
"name": "Deepseek V3",
|
||||||
|
"url": "https://huggingface.co/deepseek-ai/DeepSeek-V3",
|
||||||
|
}
|
||||||
|
IDEFICS2 = {
|
||||||
|
"type": "idefics2",
|
||||||
|
"name": "Idefics 2",
|
||||||
|
"url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
|
||||||
|
"multimodal": True,
|
||||||
|
}
|
||||||
|
IDEFICS3 = {
|
||||||
|
"type": "idefics3",
|
||||||
|
"name": "Idefics 3",
|
||||||
|
"url": "https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3",
|
||||||
|
"multimodal": True,
|
||||||
|
}
|
||||||
|
LLAVA_NEXT = {
|
||||||
|
"type": "llava_next",
|
||||||
|
"name": "Llava Next (1.6)",
|
||||||
|
"url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
|
||||||
|
"multimodal": True,
|
||||||
|
}
|
||||||
|
LLAMA = {
|
||||||
|
"type": "llama",
|
||||||
|
"name": "Llama",
|
||||||
|
"url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
|
||||||
|
}
|
||||||
|
PHI3 = {
|
||||||
|
"type": "phi3",
|
||||||
|
"name": "Phi 3",
|
||||||
|
"url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
|
||||||
|
}
|
||||||
|
GRANITE = {
|
||||||
|
"type": "granite",
|
||||||
|
"name": "Granite",
|
||||||
|
"url": "https://huggingface.co/ibm-granite/granite-3.0-8b-instruct",
|
||||||
|
}
|
||||||
|
GEMMA = {
|
||||||
|
"type": "gemma",
|
||||||
|
"name": "Gemma",
|
||||||
|
"url": "https://huggingface.co/google/gemma-7b",
|
||||||
|
}
|
||||||
|
PALIGEMMA = {
|
||||||
|
"type": "paligemma",
|
||||||
|
"name": "PaliGemma",
|
||||||
|
"url": "https://huggingface.co/google/paligemma-3b-pt-224",
|
||||||
|
}
|
||||||
|
GEMMA2 = {
|
||||||
|
"type": "gemma2",
|
||||||
|
"name": "Gemma2",
|
||||||
|
"url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
|
||||||
|
}
|
||||||
|
COHERE = {
|
||||||
|
"type": "cohere",
|
||||||
|
"name": "Cohere",
|
||||||
|
"url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
|
||||||
|
}
|
||||||
|
DBRX = {
|
||||||
|
"type": "dbrx",
|
||||||
|
"name": "Dbrx",
|
||||||
|
"url": "https://huggingface.co/databricks/dbrx-instruct",
|
||||||
|
}
|
||||||
|
MAMBA = {
|
||||||
|
"type": "mamba",
|
||||||
|
"name": "Mamba",
|
||||||
|
"url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
|
||||||
|
}
|
||||||
|
MISTRAL = {
|
||||||
|
"type": "mistral",
|
||||||
|
"name": "Mistral",
|
||||||
|
"url": "https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407",
|
||||||
|
}
|
||||||
|
MIXTRAL = {
|
||||||
|
"type": "mixtral",
|
||||||
|
"name": "Mixtral",
|
||||||
|
"url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
|
||||||
|
}
|
||||||
|
GPT_BIGCODE = {
|
||||||
|
"type": "gpt_bigcode",
|
||||||
|
"name": "Gpt Bigcode",
|
||||||
|
"url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
|
||||||
|
}
|
||||||
|
PHI = {
|
||||||
|
"type": "phi",
|
||||||
|
"name": "Phi",
|
||||||
|
"url": "https://huggingface.co/microsoft/phi-1_5",
|
||||||
|
}
|
||||||
|
PHI_MOE = {
|
||||||
|
"type": "phimoe",
|
||||||
|
"name": "PhiMoe",
|
||||||
|
"url": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct",
|
||||||
|
}
|
||||||
|
BAICHUAN = {
|
||||||
|
"type": "baichuan",
|
||||||
|
"name": "Baichuan",
|
||||||
|
"url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
|
||||||
|
}
|
||||||
|
FALCON = {
|
||||||
|
"type": "falcon",
|
||||||
|
"name": "Falcon",
|
||||||
|
"url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
|
||||||
|
}
|
||||||
|
STARCODER2 = {
|
||||||
|
"type": "starcoder2",
|
||||||
|
"name": "StarCoder 2",
|
||||||
|
"url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
|
||||||
|
}
|
||||||
|
QWEN2 = {
|
||||||
|
"type": "qwen2",
|
||||||
|
"name": "Qwen 2",
|
||||||
|
"url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
|
||||||
|
}
|
||||||
|
QWEN2_VL = {
|
||||||
|
"type": "qwen2_vl",
|
||||||
|
"name": "Qwen 2 VL",
|
||||||
|
"url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d",
|
||||||
|
}
|
||||||
|
QWEN2_5_VL = {
|
||||||
|
"type": "qwen2_5_vl",
|
||||||
|
"name": "Qwen 2.5 VL",
|
||||||
|
"url": "https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e",
|
||||||
|
}
|
||||||
|
GALACTICA = {
|
||||||
|
"type": "galactica",
|
||||||
|
"name": "Galactica",
|
||||||
|
"url": "https://huggingface.co/facebook/galactica-120b",
|
||||||
|
}
|
||||||
|
SANTACODER = {
|
||||||
|
"type": "santacoder",
|
||||||
|
"name": "SantaCoder",
|
||||||
|
"url": "https://huggingface.co/bigcode/santacoder",
|
||||||
|
}
|
||||||
|
GPT2 = {
|
||||||
|
"type": "gpt2",
|
||||||
|
"name": "Gpt2",
|
||||||
|
"url": "https://huggingface.co/openai-community/gpt2",
|
||||||
|
}
|
||||||
|
GPT_NEOX = {
|
||||||
|
"type": "gpt_neox",
|
||||||
|
"name": "Gpt Neox",
|
||||||
|
"url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
|
||||||
|
}
|
||||||
|
GPTJ = {
|
||||||
|
"type": "gptj",
|
||||||
|
"name": "Gptj",
|
||||||
|
"url": "https://huggingface.co/EleutherAI/gpt-j-6b",
|
||||||
|
}
|
||||||
|
MLLAMA = {
|
||||||
|
"type": "mllama",
|
||||||
|
"name": "Mllama",
|
||||||
|
"url": "https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||||
|
"multimodal": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
__GLOBALS = locals()
|
||||||
|
for data in ModelType:
|
||||||
|
__GLOBALS[data.name] = data.value["type"]
|
||||||
|
|
||||||
SDP_ON_BF16 = int(os.environ.get("SDP_ON_BF16", 0))
|
SDP_ON_BF16 = int(os.environ.get("SDP_ON_BF16", 0))
|
||||||
# Disable gradients
|
# Disable gradients
|
||||||
@ -53,9 +327,7 @@ def get_model(
|
|||||||
trust_remote_code: bool,
|
trust_remote_code: bool,
|
||||||
max_input_tokens: int,
|
max_input_tokens: int,
|
||||||
) -> Model:
|
) -> Model:
|
||||||
adapt_transformers_to_gaudi()
|
global FLASH_ATTENTION
|
||||||
if SDP_ON_BF16 == 1:
|
|
||||||
torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
|
|
||||||
|
|
||||||
if speculate is not None:
|
if speculate is not None:
|
||||||
set_speculate(speculate)
|
set_speculate(speculate)
|
||||||
@ -177,9 +449,393 @@ def get_model(
|
|||||||
|
|
||||||
model_type = config_dict["model_type"]
|
model_type = config_dict["model_type"]
|
||||||
|
|
||||||
|
kv_cache_dtype = dtype
|
||||||
|
|
||||||
|
if FLASH_ATTENTION:
|
||||||
|
if model_type == DEEPSEEK_V2:
|
||||||
|
head_size = max(
|
||||||
|
config_dict.get("qk_nope_dim", 128)
|
||||||
|
+ config_dict.get("qk_rope_dim", 64),
|
||||||
|
config_dict.get("v_head_dim", 128),
|
||||||
|
)
|
||||||
|
return FlashCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=FlashDeepseekV2ForCausalLM,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
default_dtype=torch.bfloat16,
|
||||||
|
dtype=dtype,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
config_class=DeepseekV2Config,
|
||||||
|
head_size=head_size,
|
||||||
|
)
|
||||||
|
elif model_type == DEEPSEEK_V3:
|
||||||
|
head_size = max(
|
||||||
|
config_dict.get("qk_nope_dim", 128)
|
||||||
|
+ config_dict.get("qk_rope_dim", 64),
|
||||||
|
config_dict.get("v_head_dim", 128),
|
||||||
|
)
|
||||||
|
return FlashCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=FlashDeepseekV3ForCausalLM,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
default_dtype=torch.bfloat16,
|
||||||
|
dtype=dtype,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
config_class=DeepseekV3Config,
|
||||||
|
head_size=head_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif (
|
||||||
|
model_type == GPT_BIGCODE
|
||||||
|
or model_type == GPT2
|
||||||
|
and model_id.startswith("bigcode/")
|
||||||
|
):
|
||||||
|
return FlashCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=FlashSantacoderForCausalLM,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
aliases={"transformer.wte.weight": ["lm_head.weight"]},
|
||||||
|
num_kv_heads=1,
|
||||||
|
)
|
||||||
|
elif model_type == GPT2:
|
||||||
|
return FlashCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=FlashGPT2ForCausalLM,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
)
|
||||||
|
elif model_type == GPTJ:
|
||||||
|
return FlashCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=FlashGPTJForCausalLM,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
)
|
||||||
|
elif model_type == GPT_NEOX:
|
||||||
|
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
|
||||||
|
GPTNeoXConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
return FlashCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=FlashGPTNeoXForCausalLM,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
config_class=GPTNeoXConfig,
|
||||||
|
)
|
||||||
|
elif model_type == PHI:
|
||||||
|
return FlashCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=FlashPhiForCausalLM,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
)
|
||||||
|
elif model_type == PHI_MOE:
|
||||||
|
return FlashCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=FlashLlamaForCausalLM,
|
||||||
|
config_class=PhiMoEConfig,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
)
|
||||||
|
elif model_type == LLAMA or model_type == PHI3 or model_type == GRANITE:
|
||||||
|
return FlashCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=FlashLlamaForCausalLM,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
)
|
||||||
|
elif model_type == BAICHUAN:
|
||||||
|
return FlashCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=FlashLlamaForCausalLM,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
)
|
||||||
|
elif model_type == GEMMA:
|
||||||
|
return FlashCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=FlashGemmaForCausalLM,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
# Works better for these models
|
||||||
|
default_dtype=torch.bfloat16,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
)
|
||||||
|
elif model_type == GEMMA2:
|
||||||
|
return FlashCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=FlashGemma2ForCausalLM,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
# Works better for these models
|
||||||
|
default_dtype=torch.bfloat16,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
)
|
||||||
|
elif model_type == COHERE:
|
||||||
|
return FlashCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=FlashCohereForCausalLM,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
)
|
||||||
|
elif model_type == DBRX:
|
||||||
|
return FlashCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=FlashDbrxForCausalLM,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
# Dbrx works better in bfloat16.
|
||||||
|
default_dtype=torch.bfloat16,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
config_class=DbrxConfig,
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
model_type in ["RefinedWeb", "RefinedWebModel", FALCON]
|
||||||
|
and not sharded
|
||||||
|
and not config_dict.get("alibi", False)
|
||||||
|
):
|
||||||
|
return FlashCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=FlashRWForCausalLM,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
aliases={
|
||||||
|
"lm_head.weight": ["transformer.word_embeddings.weight"],
|
||||||
|
"transformer.word_embeddings.weight": ["lm_head.weight"],
|
||||||
|
},
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
config_class=RWConfig,
|
||||||
|
)
|
||||||
|
elif model_type == MISTRAL:
|
||||||
|
return FlashCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=FlashMistralForCausalLM,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
)
|
||||||
|
elif model_type == MIXTRAL:
|
||||||
|
return FlashCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=FlashMixtralForCausalLM,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
)
|
||||||
|
elif model_type == STARCODER2:
|
||||||
|
return FlashCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=FlashStarcoder2ForCausalLM,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
)
|
||||||
|
elif model_type == QWEN2:
|
||||||
|
return FlashCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=Qwen2ForCausalLM,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
)
|
||||||
|
elif model_type == QWEN2_VL:
|
||||||
|
return FlashVlmCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=Qwen2VLForConditionalGeneration,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
default_dtype=torch.bfloat16,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
)
|
||||||
|
elif model_type == QWEN2_5_VL:
|
||||||
|
return FlashVlmCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=Qwen2_5VLForConditionalGeneration,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
default_dtype=torch.bfloat16,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
config_class=Qwen2_5_VLConfig,
|
||||||
|
processor_class=Qwen2_5_VLProcessor,
|
||||||
|
)
|
||||||
|
elif model_type == MLLAMA:
|
||||||
|
return FlashMllamaCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=FlashMllamaForConditionalGeneration,
|
||||||
|
batch_class=FlashMllamaCausalLMBatch,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
default_dtype=torch.bfloat16,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
)
|
||||||
|
elif model_type == IDEFICS2:
|
||||||
|
return FlashVlmCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=Idefics2ForConditionalGeneration,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
# XXX: Extremely important to cap resolution in order to limit
|
||||||
|
# VRAM usage.
|
||||||
|
processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
|
||||||
|
)
|
||||||
|
elif model_type == IDEFICS3:
|
||||||
|
return FlashVlmCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=Idefics3ForConditionalGeneration,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
default_dtype=torch.bfloat16,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
# XXX: Extremely important to cap resolution in order to limit
|
||||||
|
# VRAM usage.
|
||||||
|
processor_kwargs={"size": {"longest_edge": 1456}},
|
||||||
|
)
|
||||||
|
elif model_type == PALIGEMMA:
|
||||||
|
return FlashVlmCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=PaliGemmaForConditionalGeneration,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
# Works better for these models
|
||||||
|
default_dtype=torch.bfloat16,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
batch_class=PaliGemmaBatch,
|
||||||
|
)
|
||||||
|
elif model_type == LLAVA_NEXT:
|
||||||
|
return FlashVlmCausalLM(
|
||||||
|
model_class=FlashLlavaNextForConditionalGeneration,
|
||||||
|
model_id=model_id,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
|
||||||
|
from text_generation_server.models.custom_modeling.mllama import (
|
||||||
|
MllamaForConditionalGeneration,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.llava_next import (
|
||||||
|
LlavaNextForConditionalGeneration,
|
||||||
|
)
|
||||||
|
|
||||||
|
adapt_transformers_to_gaudi()
|
||||||
|
if SDP_ON_BF16 == 1:
|
||||||
|
torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
|
||||||
if model_type == "gpt_bigcode":
|
if model_type == "gpt_bigcode":
|
||||||
return StarCoder(model_id=model_id, revision=revision, dtype=dtype)
|
return StarCoder(model_id=model_id, revision=revision, dtype=dtype)
|
||||||
|
|
||||||
if model_type == "bloom":
|
if model_type == "bloom":
|
||||||
return BLOOM(
|
return BLOOM(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
|
@ -377,7 +377,7 @@ class BloomAttention(nn.Module):
|
|||||||
past_value.view(-1, *past_value.shape[-2:]),
|
past_value.view(-1, *past_value.shape[-2:]),
|
||||||
)
|
)
|
||||||
|
|
||||||
if CUSTOM_KERNELS_ENABLED:
|
if CUSTOM_KERNELS_ENABLED and attention_mask.shape[-1] < 4096:
|
||||||
assert self.training is False, "Only foward pass was implemented"
|
assert self.training is False, "Only foward pass was implemented"
|
||||||
assert (
|
assert (
|
||||||
attention_mask.shape[-1] < 4096
|
attention_mask.shape[-1] < 4096
|
||||||
@ -580,7 +580,7 @@ class BloomPreTrainedModel(PreTrainedModel):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _convert_to_bloom_cache(
|
def _convert_to_bloom_cache(
|
||||||
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
|
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
|
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
"""
|
"""
|
||||||
Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
|
Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
|
||||||
|
@ -28,10 +28,10 @@ from typing import Optional, List, Tuple
|
|||||||
from text_generation_server.layers.attention import (
|
from text_generation_server.layers.attention import (
|
||||||
paged_attention,
|
paged_attention,
|
||||||
attention,
|
attention,
|
||||||
reshape_and_cache,
|
|
||||||
Seqlen,
|
Seqlen,
|
||||||
|
HPUPagedAttentionMetadata,
|
||||||
)
|
)
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
@ -39,7 +39,6 @@ from text_generation_server.layers import (
|
|||||||
SpeculativeHead,
|
SpeculativeHead,
|
||||||
get_linear,
|
get_linear,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
|
||||||
from text_generation_server.layers.layernorm import (
|
from text_generation_server.layers.layernorm import (
|
||||||
FastLayerNorm,
|
FastLayerNorm,
|
||||||
)
|
)
|
||||||
@ -47,11 +46,10 @@ from text_generation_server.layers.rotary import (
|
|||||||
PositionRotaryEmbedding,
|
PositionRotaryEmbedding,
|
||||||
)
|
)
|
||||||
from text_generation_server.utils.weights import UnquantizedWeight
|
from text_generation_server.utils.weights import UnquantizedWeight
|
||||||
|
from habana_frameworks.torch.hpex.kernels import (
|
||||||
if SYSTEM == "cuda":
|
RotaryPosEmbeddingMode,
|
||||||
import dropout_layer_norm
|
apply_rotary_pos_emb,
|
||||||
else:
|
)
|
||||||
dropout_layer_norm = None
|
|
||||||
|
|
||||||
|
|
||||||
class CohereRotary(PositionRotaryEmbedding):
|
class CohereRotary(PositionRotaryEmbedding):
|
||||||
@ -63,38 +61,25 @@ class CohereRotary(PositionRotaryEmbedding):
|
|||||||
sin: torch.Tensor,
|
sin: torch.Tensor,
|
||||||
):
|
):
|
||||||
# Such controlflows may add some overhead.
|
# Such controlflows may add some overhead.
|
||||||
if SYSTEM == "cuda":
|
num_tokens = query.shape[0]
|
||||||
import rotary_emb
|
|
||||||
|
|
||||||
q1 = query[..., ::2]
|
|
||||||
q2 = query[..., 1::2]
|
|
||||||
|
|
||||||
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
|
|
||||||
|
|
||||||
k1 = key[..., ::2]
|
|
||||||
k2 = key[..., 1::2]
|
|
||||||
|
|
||||||
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
|
|
||||||
elif SYSTEM == "rocm":
|
|
||||||
from vllm._C import ops
|
|
||||||
|
|
||||||
# NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
|
|
||||||
# Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
|
|
||||||
|
|
||||||
head_size = query.shape[-1]
|
head_size = query.shape[-1]
|
||||||
|
rope_mode = RotaryPosEmbeddingMode.PAIRWISE
|
||||||
|
sin = torch.repeat_interleave(sin, 2, dim=-1)
|
||||||
|
cos = torch.repeat_interleave(cos, 2, dim=-1)
|
||||||
|
rotary_dim = cos.shape[-1]
|
||||||
|
query_shape = query.shape
|
||||||
|
query = query.view(num_tokens, -1, head_size)
|
||||||
|
query_rot = query[..., :rotary_dim]
|
||||||
|
query_pass = query[..., rotary_dim:]
|
||||||
|
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
|
||||||
|
query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))
|
||||||
|
|
||||||
# Inplace operation, updating query and key.
|
key_shape = key.shape
|
||||||
ops.rotary_embedding(query, key, head_size, cos, sin, False)
|
key = key.view(num_tokens, -1, head_size)
|
||||||
elif SYSTEM == "ipex":
|
key_rot = key[..., :rotary_dim]
|
||||||
import intel_extension_for_pytorch as ipex
|
key_pass = key[..., rotary_dim:]
|
||||||
|
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
|
||||||
ipex.llm.functional.rotary_embedding(
|
key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))
|
||||||
query, key, sin, cos, query.size(-1), False
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CohereLayerNorm(nn.Module):
|
class CohereLayerNorm(nn.Module):
|
||||||
@ -107,7 +92,6 @@ class CohereLayerNorm(nn.Module):
|
|||||||
self.eps = eps
|
self.eps = eps
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
if hidden_states.shape[-1] > 8192 or SYSTEM != "cuda":
|
|
||||||
hidden_states = hidden_states.reshape(
|
hidden_states = hidden_states.reshape(
|
||||||
-1, self.weight.shape[0], self.weight.shape[1]
|
-1, self.weight.shape[0], self.weight.shape[1]
|
||||||
)
|
)
|
||||||
@ -121,36 +105,6 @@ class CohereLayerNorm(nn.Module):
|
|||||||
hidden_states = hidden_states.view(-1, self.weight.shape[1])
|
hidden_states = hidden_states.view(-1, self.weight.shape[1])
|
||||||
return hidden_states.to(input_dtype)
|
return hidden_states.to(input_dtype)
|
||||||
|
|
||||||
(
|
|
||||||
hidden_states,
|
|
||||||
*rest,
|
|
||||||
) = dropout_layer_norm.dropout_add_ln_fwd(
|
|
||||||
hidden_states,
|
|
||||||
None,
|
|
||||||
self.ones,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
0.0,
|
|
||||||
self.eps,
|
|
||||||
1.0,
|
|
||||||
0,
|
|
||||||
None,
|
|
||||||
False,
|
|
||||||
False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Required to apply one weight matrix per head
|
|
||||||
hidden_states = hidden_states.view(
|
|
||||||
-1, self.weight.shape[0], self.weight.shape[1]
|
|
||||||
)
|
|
||||||
hidden_states = self.weight * hidden_states
|
|
||||||
hidden_states = hidden_states.view(-1, self.weight.shape[1])
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
def load_attention(config, prefix, weights):
|
def load_attention(config, prefix, weights):
|
||||||
if config.num_attention_heads != config.num_key_value_heads:
|
if config.num_attention_heads != config.num_key_value_heads:
|
||||||
@ -229,6 +183,7 @@ class FlashCohereAttention(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.query_key_value = load_attention(config, prefix, weights)
|
self.query_key_value = load_attention(config, prefix, weights)
|
||||||
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
|
|
||||||
self.use_qk_norm = config.use_qk_norm
|
self.use_qk_norm = config.use_qk_norm
|
||||||
if self.use_qk_norm:
|
if self.use_qk_norm:
|
||||||
@ -264,10 +219,9 @@ class FlashCohereAttention(torch.nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
):
|
):
|
||||||
qkv = self.query_key_value(hidden_states)
|
qkv = self.query_key_value(hidden_states)
|
||||||
query, key, value = qkv.split(
|
query, key, value = qkv.split(
|
||||||
@ -291,30 +245,35 @@ class FlashCohereAttention(torch.nn.Module):
|
|||||||
|
|
||||||
self.rotary_emb(query, key, cos, sin)
|
self.rotary_emb(query, key, cos, sin)
|
||||||
|
|
||||||
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
kv_cache.store(
|
||||||
|
key=key,
|
||||||
|
value=value,
|
||||||
|
slots=slots,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
)
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# sdpa
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
|
key=key,
|
||||||
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
|
value=value,
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
kv_scales=self.kv_scales,
|
||||||
self.softmax_scale,
|
seqlen=seqlen,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache,
|
||||||
kv_cache[1],
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
kv_scales=self.kv_scales,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.o_proj(
|
return self.o_proj(
|
||||||
@ -386,10 +345,9 @@ class FlashCohereLayer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
):
|
):
|
||||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
@ -400,10 +358,9 @@ class FlashCohereLayer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
mlp_output = self.mlp(normed_hidden_states)
|
mlp_output = self.mlp(normed_hidden_states)
|
||||||
@ -452,18 +409,15 @@ class FlashCohereModel(torch.nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: torch.Tensor,
|
seqlen: torch.Tensor,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
# Get rotary cos and sin for this forward
|
# Get rotary cos and sin for this forward
|
||||||
# Avoid to index in each layer
|
# Avoid to index in each layer
|
||||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||||
position_ids, max_s, hidden_states.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
residual = None
|
residual = None
|
||||||
|
|
||||||
@ -475,10 +429,9 @@ class FlashCohereModel(torch.nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache[i],
|
kv_cache[i],
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
@ -516,11 +469,9 @@ class FlashCohereForCausalLM(torch.nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
adapter_data: Optional[torch.Tensor] = None,
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
@ -529,10 +480,9 @@ class FlashCohereForCausalLM(torch.nn.Module):
|
|||||||
position_ids,
|
position_ids,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
@ -20,17 +20,14 @@ from torch import nn
|
|||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
from typing import Optional, List, Tuple, Any
|
from typing import Optional, List, Tuple, Any
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||||
|
|
||||||
if SYSTEM != "ipex":
|
|
||||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
|
||||||
|
|
||||||
from text_generation_server.layers.attention import (
|
from text_generation_server.layers.attention import (
|
||||||
paged_attention,
|
paged_attention,
|
||||||
attention,
|
attention,
|
||||||
reshape_and_cache,
|
|
||||||
Seqlen,
|
Seqlen,
|
||||||
PREFILL_IN_KV_CACHE,
|
HPUPagedAttentionMetadata,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
FastLinear,
|
FastLinear,
|
||||||
@ -46,6 +43,7 @@ from text_generation_server.layers.rotary import (
|
|||||||
from text_generation_server.layers.layernorm import (
|
from text_generation_server.layers.layernorm import (
|
||||||
FastLayerNorm,
|
FastLayerNorm,
|
||||||
)
|
)
|
||||||
|
from vllm_hpu_extension.ops import DynamicFusedMOE
|
||||||
|
|
||||||
|
|
||||||
class DbrxAttentionConfig(PretrainedConfig):
|
class DbrxAttentionConfig(PretrainedConfig):
|
||||||
@ -290,6 +288,7 @@ class DbrxAttention(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.query_key_value = load_attention(config, prefix, weights)
|
self.query_key_value = load_attention(config, prefix, weights)
|
||||||
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
|
|
||||||
self.o_proj = TensorParallelRowLinear.load(
|
self.o_proj = TensorParallelRowLinear.load(
|
||||||
config,
|
config,
|
||||||
@ -309,10 +308,9 @@ class DbrxAttention(torch.nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
):
|
):
|
||||||
qkv = self.query_key_value(hidden_states)
|
qkv = self.query_key_value(hidden_states)
|
||||||
if self.clip_qkv is not None:
|
if self.clip_qkv is not None:
|
||||||
@ -330,30 +328,35 @@ class DbrxAttention(torch.nn.Module):
|
|||||||
|
|
||||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
kv_cache.store(
|
||||||
|
key=kv[:, 0],
|
||||||
|
value=kv[:, 1],
|
||||||
|
slots=slots,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
)
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# sdpa
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
|
key=kv[:, 0],
|
||||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
|
value=kv[:, 1],
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
kv_scales=self.kv_scales,
|
||||||
self.softmax_scale,
|
seqlen=seqlen,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache,
|
||||||
kv_cache[1],
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
kv_scales=self.kv_scales,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||||
@ -387,10 +390,9 @@ class DbrxNormAttentionNorm(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
):
|
):
|
||||||
normed_hidden_states, res = self.norm_1(hidden_states, residual)
|
normed_hidden_states, res = self.norm_1(hidden_states, residual)
|
||||||
|
|
||||||
@ -401,10 +403,9 @@ class DbrxNormAttentionNorm(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
# faster post attention rms norm
|
# faster post attention rms norm
|
||||||
@ -482,18 +483,15 @@ class BlockSparseMoE(nn.Module):
|
|||||||
|
|
||||||
self.process_group = weights.process_group
|
self.process_group = weights.process_group
|
||||||
|
|
||||||
|
self.hpu_fused_moe = DynamicFusedMOE(self.num_experts)
|
||||||
|
for i in range(self.num_experts):
|
||||||
|
self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.wv1[i])
|
||||||
|
self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.w2[i])
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits = self.gate(x)
|
router_logits = self.gate(x)
|
||||||
out = fused_moe(
|
out = self.hpu_fused_moe(x, router_logits, self.top_k)
|
||||||
x,
|
|
||||||
self.wv1,
|
|
||||||
self.w2,
|
|
||||||
router_logits,
|
|
||||||
self.top_k,
|
|
||||||
renormalize=self.moe_normalize_expert_weights,
|
|
||||||
inplace=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Reduce sum
|
# Reduce sum
|
||||||
if self.process_group.size() > 1:
|
if self.process_group.size() > 1:
|
||||||
@ -620,10 +618,9 @@ class DbrxLayer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
):
|
):
|
||||||
# Self Attention
|
# Self Attention
|
||||||
attn_output, attn_res = self.attn(
|
attn_output, attn_res = self.attn(
|
||||||
@ -633,10 +630,9 @@ class DbrxLayer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
moe_output = self.moe(attn_output)
|
moe_output = self.moe(attn_output)
|
||||||
@ -677,18 +673,15 @@ class DbrxModel(torch.nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
# Get rotary cos and sin for this forward
|
# Get rotary cos and sin for this forward
|
||||||
# Avoid to index in each layer
|
# Avoid to index in each layer
|
||||||
cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(
|
cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||||
position_ids, max_s, hidden_states.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
residual = None
|
residual = None
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
@ -699,10 +692,9 @@ class DbrxModel(torch.nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache[i],
|
kv_cache[i],
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
@ -732,11 +724,9 @@ class FlashDbrxForCausalLM(torch.nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
adapter_data: Optional[torch.Tensor] = None,
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
@ -745,10 +735,9 @@ class FlashDbrxForCausalLM(torch.nn.Module):
|
|||||||
position_ids,
|
position_ids,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
@ -33,21 +33,14 @@ from text_generation_server.layers.attention import (
|
|||||||
Seqlen,
|
Seqlen,
|
||||||
attention,
|
attention,
|
||||||
paged_attention,
|
paged_attention,
|
||||||
reshape_and_cache,
|
HPUPagedAttentionMetadata,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
|
||||||
from text_generation_server.layers.layernorm import FastRMSNorm
|
from text_generation_server.layers.layernorm import FastRMSNorm
|
||||||
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
from text_generation_server.utils.weights import Weights
|
from text_generation_server.utils.weights import Weights
|
||||||
|
|
||||||
if SYSTEM == "rocm":
|
|
||||||
try:
|
|
||||||
from vllm import _custom_C
|
|
||||||
except Exception as e:
|
|
||||||
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV2Config(PretrainedConfig):
|
class DeepseekV2Config(PretrainedConfig):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -232,6 +225,8 @@ class DeepseekV2Attention(torch.nn.Module):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
|
|
||||||
self.kv_a_layernorm = FastRMSNorm.load(
|
self.kv_a_layernorm = FastRMSNorm.load(
|
||||||
prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
|
prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||||
)
|
)
|
||||||
@ -260,11 +255,10 @@ class DeepseekV2Attention(torch.nn.Module):
|
|||||||
cos: torch.Tensor,
|
cos: torch.Tensor,
|
||||||
sin: torch.Tensor,
|
sin: torch.Tensor,
|
||||||
cu_seqlen_prefill: torch.Tensor,
|
cu_seqlen_prefill: torch.Tensor,
|
||||||
kv_cache: Tuple[torch.Tensor, torch.Tensor],
|
kv_cache: KVCache,
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
):
|
):
|
||||||
if self.q_lora_rank is None:
|
if self.q_lora_rank is None:
|
||||||
query = self.q_proj(hidden_states)
|
query = self.q_proj(hidden_states)
|
||||||
@ -321,30 +315,35 @@ class DeepseekV2Attention(torch.nn.Module):
|
|||||||
value, (0, self.head_pad_size - self.value_head_size), value=0
|
value, (0, self.head_pad_size - self.value_head_size), value=0
|
||||||
)
|
)
|
||||||
|
|
||||||
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
kv_cache.store(
|
||||||
|
key=key,
|
||||||
|
value=value,
|
||||||
|
slots=slots,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
)
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
|
key=key,
|
||||||
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
|
value=value,
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
kv_scales=self.kv_scales,
|
||||||
self.softmax_scale,
|
seqlen=seqlen,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache,
|
||||||
kv_cache[1],
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
kv_scales=self.kv_scales,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Remove padding.
|
# Remove padding.
|
||||||
@ -387,22 +386,6 @@ class DeepseekV2MLP(nn.Module):
|
|||||||
self.quantize = config.quantize
|
self.quantize = config.quantize
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
|
def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
|
||||||
if (
|
|
||||||
SYSTEM == "rocm"
|
|
||||||
and self.hidden_act == "silu"
|
|
||||||
and hidden_states.dtype == torch.float16
|
|
||||||
and hidden_states.shape[0] == 1
|
|
||||||
and not self.quantize
|
|
||||||
):
|
|
||||||
out = torch.empty(
|
|
||||||
hidden_states.shape[0],
|
|
||||||
self.intermediate_size,
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
device="cuda",
|
|
||||||
)
|
|
||||||
_custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
|
|
||||||
return self.down_proj(out, reduce=reduce)
|
|
||||||
else:
|
|
||||||
gate_up_states = self.gate_up_proj(hidden_states)
|
gate_up_states = self.gate_up_proj(hidden_states)
|
||||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||||
return self.down_proj(
|
return self.down_proj(
|
||||||
@ -520,10 +503,9 @@ class DeepseekV2Layer(nn.Module):
|
|||||||
sin: torch.Tensor,
|
sin: torch.Tensor,
|
||||||
cu_seqlen_prefill: torch.Tensor,
|
cu_seqlen_prefill: torch.Tensor,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
):
|
):
|
||||||
normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
@ -534,10 +516,9 @@ class DeepseekV2Layer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
# faster post attention rms norm
|
# faster post attention rms norm
|
||||||
@ -583,18 +564,15 @@ class DeepseekV2Model(torch.nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
# Get rotary cos and sin for this forward
|
# Get rotary cos and sin for this forward
|
||||||
# Avoid to index in each layer
|
# Avoid to index in each layer
|
||||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||||
position_ids, max_s, hidden_states.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
residual = None
|
residual = None
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
@ -605,10 +583,9 @@ class DeepseekV2Model(torch.nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache[i],
|
kv_cache[i],
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
@ -635,11 +612,9 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
adapter_data: Optional[torch.Tensor] = None,
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
@ -648,10 +623,9 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
|
|||||||
position_ids,
|
position_ids,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
@ -0,0 +1,642 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import List, Optional, Tuple, Type
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
from torch import nn
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
from text_generation_server.layers import (
|
||||||
|
FastLinear,
|
||||||
|
SpeculativeHead,
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
TensorParallelEmbedding,
|
||||||
|
TensorParallelRowLinear,
|
||||||
|
get_linear,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.attention import (
|
||||||
|
Seqlen,
|
||||||
|
attention,
|
||||||
|
paged_attention,
|
||||||
|
HPUPagedAttentionMetadata,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
|
||||||
|
from text_generation_server.layers.layernorm import FastRMSNorm
|
||||||
|
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||||
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
|
||||||
|
from text_generation_server.utils.weights import Weights
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV3Config(PretrainedConfig):
    """Configuration container for DeepSeek V3 models.

    Each hyperparameter passed to the constructor is stored on the instance
    under the same name. The special-token ids, ``tie_word_embeddings`` and
    any extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Raises:
        ValueError: if ``tie_word_embeddings`` is truthy, or if ``ep_size``
            is different from 1.
    """

    def __init__(
        self,
        vocab_size=102400,
        hidden_size=4096,
        intermediate_size=11008,
        moe_intermediate_size=1407,
        num_hidden_layers=30,
        num_attention_heads=32,
        num_key_value_heads=32,
        n_shared_experts=2,
        n_routed_experts=160,
        ep_size=1,
        routed_scaling_factor=1.0,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        topk_method="gready",
        n_group=8,
        topk_group=3,
        num_experts_per_tok=6,
        moe_layer_freq=1,
        first_k_dense_replace=0,
        norm_topk_prob=False,
        scoring_func="softmax",
        aux_loss_alpha=0.001,
        seq_aux=True,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=100000,
        eos_token_id=100001,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.ep_size = ep_size
        self.routed_scaling_factor = routed_scaling_factor
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.topk_method = topk_method
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.moe_layer_freq = moe_layer_freq
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        self.aux_loss_alpha = aux_loss_alpha
        self.seq_aux = seq_aux
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        # `tie_word_embeddings` is a named parameter, so Python never places it
        # in `kwargs`; the value must be validated directly. Popping it from
        # `kwargs` always produced the default and made this check unreachable.
        if tie_word_embeddings:
            raise ValueError(
                "tie_word_embeddings is not supported for Deepseek V3 models."
            )

        if ep_size != 1:
            raise ValueError(
                f"Currently only ep_size == 1 is supported for Deepseek V3 models, was {ep_size}"
            )

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV3Attention(torch.nn.Module):
    """Attention module for DeepSeek V3 layers.

    Queries come either from a single ``q_proj`` (when ``config.q_lora_rank``
    is ``None``) or from the chain ``q_a_proj -> q_a_layernorm -> q_b_proj``.
    Keys and values are produced from ``kv_a_proj_with_mqa``, whose output is
    split into a part of width ``kv_lora_rank`` (normalised and expanded by
    ``kv_b_proj``) and a rotary part of width ``qk_rope_head_dim``.
    """

    def __init__(
        self,
        prefix: str,
        config,
        weights: Weights,
    ):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.kv_lora_rank = config.kv_lora_rank
        self.q_lora_rank = config.q_lora_rank
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.qk_rope_head_dim = config.qk_rope_head_dim
        # A query/key head is the non-rotary part followed by the rotary part.
        self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
        self.value_head_size = config.v_head_dim
        # Common width that query, key and value heads are padded to in forward().
        self.head_pad_size = max(self.head_size, self.value_head_size)

        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=self.qk_rope_head_dim,
            base=config.rope_theta,
            device=weights.device,
        )

        mscale = get_mscale(
            self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim
        )
        self.softmax_scale = self.head_size**-0.5 * mscale * mscale

        shard_count = weights.process_group.size()
        if self.num_heads % shard_count != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {shard_count}"
            )
        # From here on the head counts are per shard.
        self.num_heads = self.num_heads // shard_count
        self.num_key_value_heads = config.num_key_value_heads // shard_count

        def optional_bias(name):
            # Bias tensors are only fetched when the config enables attention bias.
            if config.attention_bias:
                return weights.get_tensor(f"{prefix}.{name}.bias")
            return None

        if self.q_lora_rank is not None:
            self.q_a_proj = get_linear(
                weight=weights.get_weights(f"{prefix}.q_a_proj"),
                bias=optional_bias("q_a_proj"),
            )
            self.q_a_layernorm = FastRMSNorm.load(
                prefix=f"{prefix}.q_a_layernorm",
                weights=weights,
                eps=config.rms_norm_eps,
            )
            self.q_b_proj = TensorParallelColumnLinear.load(
                config,
                prefix=f"{prefix}.q_b_proj",
                weights=weights,
                bias=config.attention_bias,
            )
        else:
            self.q_proj = TensorParallelColumnLinear.load(
                config,
                prefix=f"{prefix}.q_proj",
                weights=weights,
                bias=config.attention_bias,
            )

        self.kv_a_proj_with_mqa = get_linear(
            weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"),
            bias=optional_bias("kv_a_proj_with_mqa"),
        )

        self.kv_scales = get_kv_scales(weights, f"{prefix}")

        self.kv_a_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
        )

        self.kv_b_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.kv_b_proj",
            weights=weights,
            bias=config.attention_bias,
        )

        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )
        self.num_groups = self.num_heads // self.num_key_value_heads
        kv_head_ids = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        )
        # Each key/value head index is repeated once per query head in its group.
        self.kv_head_mapping = kv_head_ids.repeat_interleave(self.num_groups)

    @staticmethod
    def _pairs_to_halves(rope_part: torch.Tensor) -> torch.Tensor:
        """Regroup the last dimension from (a0, b0, a1, b1, ...) to (a0, a1, ..., b0, b1, ...)."""
        tokens, heads, width = rope_part.shape
        paired = rope_part.view(tokens, heads, width // 2, 2)
        return paired.transpose(2, 3).reshape(tokens, heads, width)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        cu_seqlen_prefill: torch.Tensor,
        kv_cache: KVCache,
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ):
        """Run attention over ``hidden_states`` and return the ``o_proj`` output.

        The new keys/values are always written to ``kv_cache``. ``attention``
        is used when ``cu_seqlen_prefill`` is not ``None`` (prefill), otherwise
        ``paged_attention`` is used (decode).
        """
        nope_dim = self.qk_nope_head_dim
        rope_dim = self.qk_rope_head_dim

        if self.q_lora_rank is None:
            query = self.q_proj(hidden_states)
        else:
            low_rank_query = self.q_a_proj(hidden_states)
            query = self.q_b_proj(self.q_a_layernorm(low_rank_query)[0])
        query = query.view(-1, self.num_heads, self.head_size)

        # Only the rotary slice of the query is needed separately.
        query_pe = torch.split(query, [nope_dim, rope_dim], dim=-1)[1]

        kv_a_out = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, key_pe = torch.split(
            kv_a_out, [self.kv_lora_rank, rope_dim], dim=-1
        )

        # The rotary key part has a single head dimension here.
        key_pe = key_pe.view(-1, 1, rope_dim)
        normed_kv = self.kv_a_layernorm(compressed_kv.contiguous())[0]
        kv = self.kv_b_proj(normed_kv).view(
            -1, self.num_key_value_heads, nope_dim + self.value_head_size
        )

        key_nope, value = torch.split(kv, [nope_dim, self.value_head_size], dim=-1)

        query_pe = self._pairs_to_halves(query_pe)
        key_pe = self._pairs_to_halves(key_pe)
        # Return value is unused; presumably rotates query_pe/key_pe in place.
        self.rotary_emb(query_pe, key_pe, cos, sin)

        # Write the rotary part back and assemble full keys.
        query[..., nope_dim:] = query_pe
        key = torch.empty_like(query)
        key[..., :nope_dim] = key_nope
        key[..., nope_dim:] = key_pe

        # Flash Attention cannot handle qk and v heads of different sizes,
        # so everything is zero-padded up to the larger of the two.
        qk_padding = (0, self.head_pad_size - self.head_size)
        v_padding = (0, self.head_pad_size - self.value_head_size)
        query = torch.nn.functional.pad(query, qk_padding, value=0)
        key = torch.nn.functional.pad(key, qk_padding, value=0)
        value = torch.nn.functional.pad(value, v_padding, value=0)

        kv_cache.store(
            key=key,
            value=value,
            slots=slots,
            kv_scales=self.kv_scales,
        )

        if cu_seqlen_prefill is None:
            # Decode
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                seqlen,
                kv_scales=self.kv_scales,
                hpu_attention_meta=hpu_attention_meta,
            )
        else:
            # Prefill: flash attention
            attn_output = attention(
                query=query,
                key=key,
                value=value,
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                softmax_scale=self.softmax_scale,
            )

        # Drop the padding columns added above.
        attn_output = attn_output[..., : self.value_head_size]
        flat_output = attn_output.reshape(-1, self.num_heads * self.value_head_size)
        return self.o_proj(flat_output)
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV3MLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act(gate) * up)``.

    The gate and up projections are loaded as one fused column-parallel
    linear layer; ``forward`` separates the two halves again.
    """

    def __init__(self, prefix: str, config, weights, intermediate_size: int):
        super().__init__()
        activation_name = config.hidden_act
        self.hidden_act = activation_name
        if activation_name != "silu":
            # Bail out because MoE only supports silu.
            raise NotImplementedError(
                "Currently only `silu` is supported as an activation for Deepseek V2."
            )
        self.act = ACT2FN[activation_name]

        fused_prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=fused_prefixes,
            weights=weights,
            dim=0,
            bias=False,
        )

        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )

        # Width of one half (gate or up) on this shard.
        self.intermediate_size = intermediate_size // weights.process_group.size()

        # TODO: This is a hotfix to be removed & properly refactored.
        self.quantize = config.quantize

    def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
        """Apply the gated MLP; ``reduce`` is passed through to ``down_proj``."""
        fused = self.gate_up_proj(hidden_states)
        # Index 0 along dim 1 is the gate half, index 1 the up half.
        fused = fused.view(-1, 2, self.intermediate_size)
        gate = fused[:, 0]
        up = fused[:, 1]
        activated = self.act(gate) * up
        return self.down_proj(activated, reduce=reduce)
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV3MoE(nn.Module):
    """Mixture-of-experts feed-forward block with optional shared experts.

    A ``gate`` linear layer produces router logits that are handed to the
    routed-expert layer (``moe_layer_cls``). When ``config.n_shared_experts``
    is not ``None``, a ``DeepseekV3MLP`` is also applied to every input and
    its output is added to the routed output.
    """

    def __init__(
        self,
        prefix,
        config: DeepseekV3Config,
        moe_layer_cls: Type[MoELayer],
        weights,
    ):
        super().__init__()

        self.hidden_dim = config.hidden_size
        shard_count = weights.process_group.size()
        self.moe_intermediate_size = config.moe_intermediate_size // shard_count
        self.routed_scaling_factor = config.routed_scaling_factor

        # Gating
        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)

        # The correction bias exists (zero-initialised) only for "noaux_tc".
        if config.topk_method != "noaux_tc":
            self.gate.e_score_correction_bias = None
        else:
            self.gate.e_score_correction_bias = torch.zeros(
                config.n_routed_experts, device=weights.device
            )

        self.moe_layer = moe_layer_cls(
            prefix=f"{prefix}.experts",
            n_experts=config.n_routed_experts,
            n_expert_group=config.n_group,
            renormalize=config.norm_topk_prob,
            topk=config.num_experts_per_tok,
            topk_group=config.topk_group,
            weights=weights,
            scoring_func=config.scoring_func,
            e_score_correction_bias=self.gate.e_score_correction_bias,
        )
        assert isinstance(self.moe_layer, MoELayer)

        if config.n_shared_experts is None:
            self.shared_experts = None
        else:
            shared_width = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = DeepseekV3MLP(
                prefix=f"{prefix}.shared_experts",
                config=config,
                weights=weights,
                intermediate_size=shared_width,
            )

        self.process_group = weights.process_group

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return routed (plus shared, if any) expert output, shaped like ``x``."""
        # The shared experts skip their own reduction (reduce=False); the sum
        # is reduced once below.
        shared_output = None
        if self.shared_experts is not None:
            shared_output = self.shared_experts(x, reduce=False)

        router_logits = self.gate(x)
        out = self.moe_layer(x, gating_output=router_logits)

        if shared_output is not None:
            out = out + shared_output

        # Reduce sum
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)

        return out.view(*x.shape)
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV3Layer(nn.Module):
    """One decoder layer: RMS norm, self-attention, RMS norm, then MLP or MoE."""

    def __init__(self, prefix, layer_id, config, weights):
        super().__init__()
        layer_prefix = f"{prefix}.layers.{layer_id}"

        self.self_attn = DeepseekV3Attention(
            prefix=f"{layer_prefix}.self_attn",
            config=config,
            weights=weights,
        )

        # A layer is MoE when routed experts are configured, the layer is at or
        # past `first_k_dense_replace`, and its index is a multiple of
        # `moe_layer_freq`; otherwise it gets a dense MLP.
        is_moe_layer = (
            config.n_routed_experts is not None
            and layer_id >= config.first_k_dense_replace
            and layer_id % config.moe_layer_freq == 0
        )
        if not is_moe_layer:
            self.mlp = DeepseekV3MLP(
                prefix=f"{layer_prefix}.mlp",
                config=config,
                weights=weights,
                intermediate_size=config.intermediate_size,
            )
        else:
            if SparseMoELayer.is_supported(weights):
                moe_layer_cls = SparseMoELayer
            else:
                moe_layer_cls = DenseMoELayer
            self.mlp = DeepseekV3MoE(
                f"{layer_prefix}.mlp", config, moe_layer_cls, weights
            )

        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{layer_prefix}.input_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        self.post_attention_layernorm = FastRMSNorm.load(
            prefix=f"{layer_prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        cu_seqlen_prefill: torch.Tensor,
        kv_cache,
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ):
        """Run the layer and return ``(mlp_output, residual)``."""
        attn_input, residual = self.input_layernorm(hidden_states, residual)

        # Self Attention
        attn_output = self.self_attn(
            attn_input,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )

        # faster post attention rms norm
        mlp_input, residual = self.post_attention_layernorm(attn_output, residual)

        return self.mlp(mlp_input), residual
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV3Model(torch.nn.Module):
    """Token embedding, a stack of ``DeepseekV3Layer`` modules and a final RMS norm."""

    def __init__(self, prefix: str, config, weights: Weights):
        super().__init__()

        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{prefix}.embed_tokens", weights=weights
        )

        decoder_layers = []
        for layer_id in range(config.num_hidden_layers):
            decoder_layers.append(DeepseekV3Layer(prefix, layer_id, config, weights))
        self.layers = nn.ModuleList(decoder_layers)

        self.norm = FastRMSNorm.load(
            prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
        )

        # Attention geometry is taken from the first layer.
        first_attn = self.layers[0].self_attn
        self.head_size = first_attn.head_size
        self.num_heads = first_attn.num_heads
        self.num_key_value_heads = first_attn.num_key_value_heads

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every layer and return the normed hidden states."""
        hidden_states = self.embed_tokens(input_ids)

        # Rotary cos/sin are looked up once here, using the first layer's
        # rotary embedding, so that layers do not index them individually.
        rotary = self.layers[0].self_attn.rotary_emb
        cos, sin = rotary.get_cos_sin(position_ids)

        residual = None
        for layer_idx, decoder_layer in enumerate(self.layers):
            hidden_states, residual = decoder_layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[layer_idx],
                slots,
                seqlen,
                hpu_attention_meta,
            )

        normed_states, _ = self.norm(hidden_states, residual)
        return normed_states
|
||||||
|
|
||||||
|
|
||||||
|
class FlashDeepseekV3ForCausalLM(torch.nn.Module):
    """``DeepseekV3Model`` followed by a ``SpeculativeHead`` language-model head."""

    def __init__(self, prefix: str, config, weights: Weights):
        super().__init__()

        # With an empty prefix the weights are looked up at the top level.
        if prefix:
            model_prefix = f"{prefix}.model"
            head_prefix = f"{prefix}.lm_head"
        else:
            model_prefix = "model"
            head_prefix = "lm_head"

        self.model = DeepseekV3Model(model_prefix, config, weights)
        self.lm_head = SpeculativeHead.load(
            config,
            prefix=head_prefix,
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(logits, speculative_logits)`` from the LM head.

        When ``lm_head_indices`` is given, only those rows of the hidden
        states are passed to the head. ``adapter_data`` is accepted but not
        used in this method.
        """
        hidden_states = self.model(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        head_output = self.lm_head(hidden_states)
        logits, speculative_logits = head_output
        return logits, speculative_logits
|
@ -28,8 +28,8 @@ from typing import Optional, List, Tuple
|
|||||||
from text_generation_server.layers.attention import (
|
from text_generation_server.layers.attention import (
|
||||||
paged_attention,
|
paged_attention,
|
||||||
attention,
|
attention,
|
||||||
reshape_and_cache,
|
|
||||||
Seqlen,
|
Seqlen,
|
||||||
|
HPUPagedAttentionMetadata,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
@ -40,7 +40,7 @@ from text_generation_server.layers import (
|
|||||||
TensorParallelMultiAdapterLinear,
|
TensorParallelMultiAdapterLinear,
|
||||||
TensorParallelAdapterRowLinear,
|
TensorParallelAdapterRowLinear,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||||
from text_generation_server.layers.layernorm import (
|
from text_generation_server.layers.layernorm import (
|
||||||
FastRMSNorm,
|
FastRMSNorm,
|
||||||
@ -208,6 +208,7 @@ class FlashGemma2Attention(torch.nn.Module):
|
|||||||
],
|
],
|
||||||
process_group=weights.process_group,
|
process_group=weights.process_group,
|
||||||
)
|
)
|
||||||
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
|
|
||||||
o_proj = TensorParallelRowLinear.load(
|
o_proj = TensorParallelRowLinear.load(
|
||||||
config,
|
config,
|
||||||
@ -234,11 +235,10 @@ class FlashGemma2Attention(torch.nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
|
||||||
adapter_data,
|
adapter_data,
|
||||||
|
hpu_attention_meta,
|
||||||
):
|
):
|
||||||
qkv = self.query_key_value(hidden_states, adapter_data)
|
qkv = self.query_key_value(hidden_states, adapter_data)
|
||||||
query, kv = qkv.split(
|
query, kv = qkv.split(
|
||||||
@ -253,19 +253,24 @@ class FlashGemma2Attention(torch.nn.Module):
|
|||||||
|
|
||||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
kv_cache.store(
|
||||||
|
key=kv[:, 0],
|
||||||
|
value=kv[:, 1],
|
||||||
|
slots=slots,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
)
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# sdpa
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
|
key=kv[:, 0],
|
||||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
|
value=kv[:, 1],
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
kv_scales=self.kv_scales,
|
||||||
self.softmax_scale,
|
seqlen=seqlen,
|
||||||
causal=self.causal,
|
softmax_scale=self.softmax_scale,
|
||||||
window_size_left=self.window_size,
|
window_size_left=self.window_size,
|
||||||
softcap=self.softcap,
|
softcap=self.softcap,
|
||||||
)
|
)
|
||||||
@ -273,14 +278,13 @@ class FlashGemma2Attention(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache,
|
||||||
kv_cache[1],
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
|
||||||
softcap=self.softcap,
|
softcap=self.softcap,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.o_proj(
|
return self.o_proj(
|
||||||
@ -390,11 +394,10 @@ class FlashGemma2Layer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
|
||||||
adapter_data,
|
adapter_data,
|
||||||
|
hpu_attention_meta,
|
||||||
):
|
):
|
||||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
@ -405,11 +408,10 @@ class FlashGemma2Layer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
|
||||||
adapter_data,
|
adapter_data,
|
||||||
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
# faster post attention rms norm
|
# faster post attention rms norm
|
||||||
@ -458,19 +460,16 @@ class FlashGemma2Model(torch.nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
adapter_data: Optional[torch.Tensor],
|
||||||
adapter_data: Optional[torch.Tensor] = None,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
# Get rotary cos and sin for this forward
|
# Get rotary cos and sin for this forward
|
||||||
# Avoid to index in each layer
|
# Avoid to index in each layer
|
||||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||||
position_ids, max_s, hidden_states.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
residual = None
|
residual = None
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
@ -481,11 +480,10 @@ class FlashGemma2Model(torch.nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache[i],
|
kv_cache[i],
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
|
||||||
adapter_data,
|
adapter_data,
|
||||||
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
@ -529,11 +527,9 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
adapter_data: Optional[torch.Tensor] = None,
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
@ -543,11 +539,10 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
|
|||||||
position_ids,
|
position_ids,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
|
||||||
adapter_data,
|
adapter_data,
|
||||||
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
@ -28,9 +28,8 @@ from typing import Optional, List, Tuple
|
|||||||
from text_generation_server.layers.attention import (
|
from text_generation_server.layers.attention import (
|
||||||
paged_attention,
|
paged_attention,
|
||||||
attention,
|
attention,
|
||||||
reshape_and_cache,
|
|
||||||
Seqlen,
|
Seqlen,
|
||||||
PREFILL_IN_KV_CACHE,
|
HPUPagedAttentionMetadata,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
@ -39,6 +38,7 @@ from text_generation_server.layers import (
|
|||||||
SpeculativeHead,
|
SpeculativeHead,
|
||||||
get_linear,
|
get_linear,
|
||||||
)
|
)
|
||||||
|
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||||
from text_generation_server.layers.layernorm import (
|
from text_generation_server.layers.layernorm import (
|
||||||
FastRMSNorm,
|
FastRMSNorm,
|
||||||
@ -187,6 +187,7 @@ class FlashGemmaAttention(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.query_key_value = load_attention(config, prefix, weights)
|
self.query_key_value = load_attention(config, prefix, weights)
|
||||||
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
|
|
||||||
self.o_proj = TensorParallelRowLinear.load(
|
self.o_proj = TensorParallelRowLinear.load(
|
||||||
config,
|
config,
|
||||||
@ -206,10 +207,9 @@ class FlashGemmaAttention(torch.nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
):
|
):
|
||||||
qkv = self.query_key_value(hidden_states)
|
qkv = self.query_key_value(hidden_states)
|
||||||
query, kv = qkv.split(
|
query, kv = qkv.split(
|
||||||
@ -224,31 +224,36 @@ class FlashGemmaAttention(torch.nn.Module):
|
|||||||
|
|
||||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
kv_cache.store(
|
||||||
|
key=kv[:, 0],
|
||||||
|
value=kv[:, 1],
|
||||||
|
slots=slots,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
)
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# sdpa
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
|
key=kv[:, 0],
|
||||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
|
value=kv[:, 1],
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
kv_scales=self.kv_scales,
|
||||||
self.softmax_scale,
|
seqlen=seqlen,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
causal=self.causal,
|
causal=self.causal,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache,
|
||||||
kv_cache[1],
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
kv_scales=self.kv_scales,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||||
@ -317,10 +322,9 @@ class FlashGemmaLayer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
):
|
):
|
||||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
@ -331,10 +335,9 @@ class FlashGemmaLayer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
# faster post attention rms norm
|
# faster post attention rms norm
|
||||||
@ -379,18 +382,16 @@ class FlashGemmaModel(torch.nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
adapter_data: Optional[torch.Tensor],
|
||||||
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
# Get rotary cos and sin for this forward
|
# Get rotary cos and sin for this forward
|
||||||
# Avoid to index in each layer
|
# Avoid to index in each layer
|
||||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||||
position_ids, max_s, hidden_states.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
residual = None
|
residual = None
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
@ -401,10 +402,9 @@ class FlashGemmaModel(torch.nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache[i],
|
kv_cache[i],
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
@ -446,11 +446,9 @@ class FlashGemmaForCausalLM(torch.nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
adapter_data: Optional[torch.Tensor] = None,
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
@ -460,10 +458,10 @@ class FlashGemmaForCausalLM(torch.nn.Module):
|
|||||||
position_ids,
|
position_ids,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
adapter_data,
|
||||||
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
@ -24,12 +24,11 @@ import torch.distributed
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
|
||||||
from text_generation_server.layers.attention import (
|
from text_generation_server.layers.attention import (
|
||||||
paged_attention,
|
paged_attention,
|
||||||
attention,
|
attention,
|
||||||
reshape_and_cache,
|
|
||||||
Seqlen,
|
Seqlen,
|
||||||
|
HPUPagedAttentionMetadata,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
@ -38,6 +37,7 @@ from text_generation_server.layers import (
|
|||||||
SpeculativeHead,
|
SpeculativeHead,
|
||||||
get_linear,
|
get_linear,
|
||||||
)
|
)
|
||||||
|
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||||
|
|
||||||
|
|
||||||
def load_qkv(config, prefix: str, weights, head_size, num_heads):
|
def load_qkv(config, prefix: str, weights, head_size, num_heads):
|
||||||
@ -47,10 +47,6 @@ def load_qkv(config, prefix: str, weights, head_size, num_heads):
|
|||||||
prefix,
|
prefix,
|
||||||
weights,
|
weights,
|
||||||
)
|
)
|
||||||
elif config.quantize == "marlin":
|
|
||||||
raise RuntimeError(
|
|
||||||
"GPT-2 models with marlin quantization are not yet supported"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return _load_qkv(config, prefix, weights, head_size, num_heads)
|
return _load_qkv(config, prefix, weights, head_size, num_heads)
|
||||||
|
|
||||||
@ -195,6 +191,7 @@ class FlashGPT2Attention(torch.nn.Module):
|
|||||||
head_size=self.head_size,
|
head_size=self.head_size,
|
||||||
num_heads=self.num_heads,
|
num_heads=self.num_heads,
|
||||||
)
|
)
|
||||||
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
|
|
||||||
self.o_proj = load_row(
|
self.o_proj = load_row(
|
||||||
config,
|
config,
|
||||||
@ -212,10 +209,9 @@ class FlashGPT2Attention(torch.nn.Module):
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
):
|
):
|
||||||
query, key, value = self.query_key_value(hidden_states).split(
|
query, key, value = self.query_key_value(hidden_states).split(
|
||||||
self.head_size * self.num_heads, dim=1
|
self.head_size * self.num_heads, dim=1
|
||||||
@ -224,30 +220,35 @@ class FlashGPT2Attention(torch.nn.Module):
|
|||||||
key = key.view(-1, self.num_heads, self.head_size)
|
key = key.view(-1, self.num_heads, self.head_size)
|
||||||
value = value.view(-1, self.num_heads, self.head_size)
|
value = value.view(-1, self.num_heads, self.head_size)
|
||||||
|
|
||||||
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
kv_cache.store(
|
||||||
|
key=key,
|
||||||
|
value=value,
|
||||||
|
slots=slots,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
)
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# sdpa
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
|
key=key,
|
||||||
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
|
value=value,
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
kv_scales=self.kv_scales,
|
||||||
self.softmax_scale,
|
seqlen=seqlen,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache,
|
||||||
kv_cache[1],
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
kv_scales=self.kv_scales,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||||
@ -313,10 +314,9 @@ class FlashGPT2Layer(nn.Module):
|
|||||||
residual,
|
residual,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
):
|
):
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
@ -326,10 +326,9 @@ class FlashGPT2Layer(nn.Module):
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = attn_output + residual
|
hidden_states = attn_output + residual
|
||||||
@ -379,12 +378,9 @@ class FlashGPT2Model(torch.nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
true_max_s: int,
|
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
@ -395,10 +391,9 @@ class FlashGPT2Model(torch.nn.Module):
|
|||||||
residual,
|
residual,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache[i],
|
kv_cache[i],
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = self.norm(hidden_states)
|
hidden_states = self.norm(hidden_states)
|
||||||
@ -432,11 +427,9 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
prefill_cache_indices: Optional[torch.Tensor] = None,
|
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
adapter_data: Optional[torch.Tensor] = None,
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
@ -448,12 +441,9 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
|
|||||||
position_ids,
|
position_ids,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
true_max_s=max_s,
|
|
||||||
prefill_cache_indices=prefill_cache_indices,
|
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
@ -24,12 +24,12 @@ import torch.distributed
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||||
from text_generation_server.layers.attention import (
|
from text_generation_server.layers.attention import (
|
||||||
paged_attention,
|
paged_attention,
|
||||||
attention,
|
attention,
|
||||||
reshape_and_cache,
|
|
||||||
Seqlen,
|
Seqlen,
|
||||||
|
HPUPagedAttentionMetadata,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
@ -38,13 +38,16 @@ from text_generation_server.layers import (
|
|||||||
SpeculativeHead,
|
SpeculativeHead,
|
||||||
get_linear,
|
get_linear,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
|
||||||
from text_generation_server.layers.rotary import (
|
from text_generation_server.layers.rotary import (
|
||||||
PositionRotaryEmbedding,
|
PositionRotaryEmbedding,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.layernorm import (
|
from text_generation_server.layers.layernorm import (
|
||||||
FastLayerNorm,
|
FastLayerNorm,
|
||||||
)
|
)
|
||||||
|
from habana_frameworks.torch.hpex.kernels import (
|
||||||
|
RotaryPosEmbeddingMode,
|
||||||
|
apply_rotary_pos_emb,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_attention(config, prefix: str, weights):
|
def load_attention(config, prefix: str, weights):
|
||||||
@ -78,39 +81,25 @@ class GPTJRotary(PositionRotaryEmbedding):
|
|||||||
cos: torch.Tensor,
|
cos: torch.Tensor,
|
||||||
sin: torch.Tensor,
|
sin: torch.Tensor,
|
||||||
):
|
):
|
||||||
# Such controlflows may add some overhead.
|
num_tokens = query.shape[0]
|
||||||
if SYSTEM == "cuda":
|
|
||||||
import rotary_emb
|
|
||||||
|
|
||||||
q1 = query[..., ::2]
|
|
||||||
q2 = query[..., 1::2]
|
|
||||||
|
|
||||||
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
|
|
||||||
|
|
||||||
k1 = key[..., ::2]
|
|
||||||
k2 = key[..., 1::2]
|
|
||||||
|
|
||||||
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
|
|
||||||
elif SYSTEM == "rocm":
|
|
||||||
from vllm._C import ops
|
|
||||||
|
|
||||||
# NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
|
|
||||||
# Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
|
|
||||||
|
|
||||||
head_size = query.shape[-1]
|
head_size = query.shape[-1]
|
||||||
|
rope_mode = RotaryPosEmbeddingMode.PAIRWISE
|
||||||
|
sin = torch.repeat_interleave(sin, 2, dim=-1)
|
||||||
|
cos = torch.repeat_interleave(cos, 2, dim=-1)
|
||||||
|
rotary_dim = cos.shape[-1]
|
||||||
|
query_shape = query.shape
|
||||||
|
query = query.view(num_tokens, -1, head_size)
|
||||||
|
query_rot = query[..., :rotary_dim]
|
||||||
|
query_pass = query[..., rotary_dim:]
|
||||||
|
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
|
||||||
|
query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))
|
||||||
|
|
||||||
# Inplace operation, updating query and key.
|
key_shape = key.shape
|
||||||
ops.rotary_embedding(query, key, head_size, cos, sin, False)
|
key = key.view(num_tokens, -1, head_size)
|
||||||
elif SYSTEM == "ipex":
|
key_rot = key[..., :rotary_dim]
|
||||||
import intel_extension_for_pytorch as ipex
|
key_pass = key[..., rotary_dim:]
|
||||||
|
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
|
||||||
ipex.llm.functional.rotary_embedding(
|
key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))
|
||||||
query, key, sin, cos, query.size(-1), False
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FlashGPTJAttention(torch.nn.Module):
|
class FlashGPTJAttention(torch.nn.Module):
|
||||||
@ -140,6 +129,7 @@ class FlashGPTJAttention(torch.nn.Module):
|
|||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
|
|
||||||
self.o_proj = load_row(
|
self.o_proj = load_row(
|
||||||
config,
|
config,
|
||||||
@ -166,10 +156,9 @@ class FlashGPTJAttention(torch.nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
):
|
):
|
||||||
query, key, value = self.query_key_value(hidden_states).split(
|
query, key, value = self.query_key_value(hidden_states).split(
|
||||||
self.head_size * self.num_heads, dim=1
|
self.head_size * self.num_heads, dim=1
|
||||||
@ -186,30 +175,35 @@ class FlashGPTJAttention(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.rotary_emb(query, key, cos, sin)
|
self.rotary_emb(query, key, cos, sin)
|
||||||
|
|
||||||
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
kv_cache.store(
|
||||||
|
key=key,
|
||||||
|
value=value,
|
||||||
|
slots=slots,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
)
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# sdpa
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
|
key=key,
|
||||||
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
|
value=value,
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
kv_scales=self.kv_scales,
|
||||||
self.softmax_scale,
|
seqlen=seqlen,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache,
|
||||||
kv_cache[1],
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
kv_scales=self.kv_scales,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||||
@ -267,10 +261,9 @@ class FlashGPTJLayer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
):
|
):
|
||||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||||
# Self Attention
|
# Self Attention
|
||||||
@ -280,10 +273,9 @@ class FlashGPTJLayer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
feed_forward_hidden_states = self.mlp(hidden_states)
|
feed_forward_hidden_states = self.mlp(hidden_states)
|
||||||
@ -327,19 +319,15 @@ class FlashGPTJModel(torch.nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.wte(input_ids)
|
hidden_states = self.wte(input_ids)
|
||||||
|
|
||||||
# Get rotary cos and sin for this forward
|
# Get rotary cos and sin for this forward
|
||||||
# Avoid to index in each layer
|
# Avoid to index in each layer
|
||||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||||
position_ids, max_s, hidden_states.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
residual = None
|
residual = None
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
@ -350,10 +338,9 @@ class FlashGPTJModel(torch.nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache[i],
|
kv_cache[i],
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.ln_f(hidden_states, residual)
|
hidden_states, _ = self.ln_f(hidden_states, residual)
|
||||||
@ -381,11 +368,9 @@ class FlashGPTJForCausalLM(torch.nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
prefill_cache_indices: Optional[torch.Tensor] = None,
|
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
adapter_data: Optional[torch.Tensor] = None,
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
@ -394,11 +379,9 @@ class FlashGPTJForCausalLM(torch.nn.Module):
|
|||||||
position_ids,
|
position_ids,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
prefill_cache_indices=prefill_cache_indices,
|
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
@ -27,14 +27,16 @@ import torch.distributed
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
|
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
from text_generation_server.layers.attention import (
|
||||||
|
KVCache,
|
||||||
|
get_kv_scales,
|
||||||
|
)
|
||||||
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
from text_generation_server.layers.attention import (
|
from text_generation_server.layers.attention import (
|
||||||
paged_attention,
|
paged_attention,
|
||||||
attention,
|
attention,
|
||||||
reshape_and_cache,
|
|
||||||
Seqlen,
|
Seqlen,
|
||||||
|
HPUPagedAttentionMetadata,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
@ -57,15 +59,6 @@ from text_generation_server.utils.weights import (
|
|||||||
)
|
)
|
||||||
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
|
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
|
||||||
|
|
||||||
if SYSTEM != "ipex":
|
|
||||||
pass
|
|
||||||
|
|
||||||
if SYSTEM == "rocm":
|
|
||||||
try:
|
|
||||||
from vllm import _custom_C
|
|
||||||
except Exception as e:
|
|
||||||
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
def load_attention(config, prefix: str, weights, layer_id):
|
def load_attention(config, prefix: str, weights, layer_id):
|
||||||
# Only defined in granite.
|
# Only defined in granite.
|
||||||
@ -157,7 +150,10 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||||||
device=weights.device,
|
device=weights.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.softmax_scale = self.head_size**-0.5
|
# `config.attention_multiplier` is used in Granite
|
||||||
|
self.softmax_scale = getattr(
|
||||||
|
config, "attention_multiplier", self.head_size**-0.5
|
||||||
|
)
|
||||||
|
|
||||||
if self.num_heads % weights.process_group.size() != 0:
|
if self.num_heads % weights.process_group.size() != 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -177,11 +173,13 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||||||
self.query_key_value = load_attention(config, prefix, weights, index)
|
self.query_key_value = load_attention(config, prefix, weights, index)
|
||||||
self.index = index
|
self.index = index
|
||||||
|
|
||||||
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
|
|
||||||
o_proj = TensorParallelRowLinear.load(
|
o_proj = TensorParallelRowLinear.load(
|
||||||
config,
|
config,
|
||||||
prefix=f"{prefix}.o_proj",
|
prefix=f"{prefix}.o_proj",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
bias=False,
|
bias=getattr(config, "attention_bias", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.o_proj = TensorParallelAdapterRowLinear.load(
|
self.o_proj = TensorParallelAdapterRowLinear.load(
|
||||||
@ -202,12 +200,11 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||||||
cos,
|
cos,
|
||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache: KVCache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
|
||||||
adapter_data,
|
adapter_data,
|
||||||
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
):
|
):
|
||||||
qkv = self.query_key_value(hidden_states, adapter_data)
|
qkv = self.query_key_value(hidden_states, adapter_data)
|
||||||
query, kv = qkv.split(
|
query, kv = qkv.split(
|
||||||
@ -222,30 +219,35 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||||||
|
|
||||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
kv_cache.store(
|
||||||
|
key=kv[:, 0],
|
||||||
|
value=kv[:, 1],
|
||||||
|
slots=slots,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
)
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# sdpa
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
|
key=kv[:, 0],
|
||||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
|
value=kv[:, 1],
|
||||||
seqlen,
|
kv_scales=self.kv_scales,
|
||||||
block_tables,
|
kv_cache=kv_cache,
|
||||||
self.softmax_scale,
|
seqlen=seqlen,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache,
|
||||||
kv_cache[1],
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
kv_scales=self.kv_scales,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.o_proj(
|
return self.o_proj(
|
||||||
@ -363,26 +365,6 @@ class LlamaMLP(nn.Module):
|
|||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
|
|
||||||
def forward(self, hidden_states, adapter_data):
|
def forward(self, hidden_states, adapter_data):
|
||||||
if (
|
|
||||||
SYSTEM == "rocm"
|
|
||||||
and self.hidden_act == "silu"
|
|
||||||
and hidden_states.dtype == torch.float16
|
|
||||||
and hidden_states.shape[0] == 1
|
|
||||||
and not self.quantize
|
|
||||||
and self.hidden_size
|
|
||||||
!= 16384 # TODO: Temporary workaround for `LLMM_Silu` kernel not working with LLama3.1 405B; needs refactoring once fixed.
|
|
||||||
):
|
|
||||||
out = torch.empty(
|
|
||||||
hidden_states.shape[0],
|
|
||||||
self.intermediate_size,
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
device="cuda",
|
|
||||||
)
|
|
||||||
_custom_C.LLMM_Silu(
|
|
||||||
self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8
|
|
||||||
)
|
|
||||||
return self.down_proj(out, adapter_data)
|
|
||||||
else:
|
|
||||||
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
||||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||||
return self.down_proj(
|
return self.down_proj(
|
||||||
@ -408,7 +390,7 @@ class FlashLlamaLayer(nn.Module):
|
|||||||
if SparseMoELayer.is_supported(weights)
|
if SparseMoELayer.is_supported(weights)
|
||||||
else DenseMoELayer
|
else DenseMoELayer
|
||||||
)
|
)
|
||||||
self.dense = Phi3MoE(
|
self.mlp = Phi3MoE(
|
||||||
f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
|
f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
|
||||||
)
|
)
|
||||||
# with moe the layernorms are are not rmsnorms and they have bias
|
# with moe the layernorms are are not rmsnorms and they have bias
|
||||||
@ -423,7 +405,7 @@ class FlashLlamaLayer(nn.Module):
|
|||||||
eps=config.rms_norm_eps,
|
eps=config.rms_norm_eps,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.dense = LlamaMLP(
|
self.mlp = LlamaMLP(
|
||||||
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
|
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
|
||||||
)
|
)
|
||||||
self.input_layernorm = FastRMSNorm.load(
|
self.input_layernorm = FastRMSNorm.load(
|
||||||
@ -437,6 +419,11 @@ class FlashLlamaLayer(nn.Module):
|
|||||||
eps=config.rms_norm_eps,
|
eps=config.rms_norm_eps,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Used in Granite
|
||||||
|
# This could eventually be baked into the weights like we do for the embeddings/lm_head
|
||||||
|
# but this would mean modifying the lora code
|
||||||
|
self.residual_multiplier = getattr(config, "residual_multiplier", None)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@ -445,12 +432,11 @@ class FlashLlamaLayer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
|
||||||
adapter_data,
|
adapter_data,
|
||||||
cross_attention_states,
|
cross_attention_states,
|
||||||
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
):
|
):
|
||||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
@ -461,19 +447,21 @@ class FlashLlamaLayer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
|
||||||
adapter_data,
|
adapter_data,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
if self.residual_multiplier is not None:
|
||||||
|
attn_output *= self.residual_multiplier
|
||||||
|
|
||||||
# faster post attention rms norm
|
|
||||||
normed_attn_res_output, attn_res = self.post_attention_layernorm(
|
normed_attn_res_output, attn_res = self.post_attention_layernorm(
|
||||||
attn_output, res
|
attn_output, res
|
||||||
)
|
)
|
||||||
|
|
||||||
mlp_output = self.dense(normed_attn_res_output, adapter_data)
|
mlp_output = self.mlp(normed_attn_res_output, adapter_data)
|
||||||
|
if self.residual_multiplier is not None:
|
||||||
|
mlp_output *= self.residual_multiplier
|
||||||
|
|
||||||
return mlp_output, attn_res
|
return mlp_output, attn_res
|
||||||
|
|
||||||
@ -493,9 +481,7 @@ class FlashLlamaModel(torch.nn.Module):
|
|||||||
self.layers.append(
|
self.layers.append(
|
||||||
FlashLlamaLayer(
|
FlashLlamaLayer(
|
||||||
index=0,
|
index=0,
|
||||||
prefix=(
|
prefix=f"{prefix}.layers.0",
|
||||||
"model.layers.0" if not prefix else f"{prefix}.model.layers.0"
|
|
||||||
),
|
|
||||||
config=config,
|
config=config,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
@ -504,18 +490,14 @@ class FlashLlamaModel(torch.nn.Module):
|
|||||||
# Skip first and last layers
|
# Skip first and last layers
|
||||||
for layer_id in range(1, config.num_hidden_layers - 1):
|
for layer_id in range(1, config.num_hidden_layers - 1):
|
||||||
if layer_id in self.cross_attention_layers:
|
if layer_id in self.cross_attention_layers:
|
||||||
from text_generation_server.models.custom_modeling.mllama import (
|
from text_generation_server.models.custom_modeling.flash_mllama import (
|
||||||
FlashLlamaCrossLayer,
|
FlashLlamaCrossLayer,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.layers.append(
|
self.layers.append(
|
||||||
FlashLlamaCrossLayer(
|
FlashLlamaCrossLayer(
|
||||||
index=layer_id,
|
index=layer_id,
|
||||||
prefix=(
|
prefix=(f"{prefix}.layers.{layer_id}"),
|
||||||
f"model.layers.{layer_id}"
|
|
||||||
if not prefix
|
|
||||||
else f"{prefix}.model.layers.{layer_id}"
|
|
||||||
),
|
|
||||||
config=config,
|
config=config,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
@ -524,11 +506,7 @@ class FlashLlamaModel(torch.nn.Module):
|
|||||||
self.layers.append(
|
self.layers.append(
|
||||||
FlashLlamaLayer(
|
FlashLlamaLayer(
|
||||||
index=layer_id,
|
index=layer_id,
|
||||||
prefix=(
|
prefix=(f"{prefix}.layers.{layer_id}"),
|
||||||
f"model.layers.{layer_id}"
|
|
||||||
if not prefix
|
|
||||||
else f"{prefix}.model.layers.{layer_id}"
|
|
||||||
),
|
|
||||||
config=config,
|
config=config,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
@ -539,18 +517,14 @@ class FlashLlamaModel(torch.nn.Module):
|
|||||||
self.layers.append(
|
self.layers.append(
|
||||||
FlashLlamaLayer(
|
FlashLlamaLayer(
|
||||||
index=last_layer_id,
|
index=last_layer_id,
|
||||||
prefix=(
|
prefix=(f"{prefix}.layers.{last_layer_id}"),
|
||||||
f"model.layers.{last_layer_id}"
|
|
||||||
if not prefix
|
|
||||||
else f"{prefix}.model.layers.{last_layer_id}"
|
|
||||||
),
|
|
||||||
config=config,
|
config=config,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.norm = FastRMSNorm.load(
|
self.norm = FastRMSNorm.load(
|
||||||
prefix="model.norm" if not prefix else f"{prefix}.model.norm",
|
prefix=f"{prefix}.norm",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
eps=config.rms_norm_eps,
|
eps=config.rms_norm_eps,
|
||||||
)
|
)
|
||||||
@ -567,22 +541,17 @@ class FlashLlamaModel(torch.nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
|
||||||
true_max_s: int,
|
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
|
||||||
adapter_data,
|
adapter_data,
|
||||||
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
cross_attention_states=None,
|
cross_attention_states=None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
# Get rotary cos and sin for this forward
|
# Get rotary cos and sin for this forward
|
||||||
# Avoid to index in each layer
|
# Avoid to index in each layer
|
||||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||||
position_ids, max_s, hidden_states.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
residual = None
|
residual = None
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
@ -593,12 +562,11 @@ class FlashLlamaModel(torch.nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache[i],
|
kv_cache[i],
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
|
||||||
adapter_data,
|
adapter_data,
|
||||||
cross_attention_states,
|
cross_attention_states,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
@ -607,42 +575,60 @@ class FlashLlamaModel(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FlashLlamaForCausalLM(torch.nn.Module):
|
class FlashLlamaForCausalLM(torch.nn.Module):
|
||||||
def __init__(self, prefix: str, config, weights):
|
def __init__(self, prefix: str, config, weights, name=None):
|
||||||
|
if name is None:
|
||||||
|
name = "model"
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
with no_fp8(weights):
|
with no_fp8(weights):
|
||||||
self.embed_tokens = TensorParallelEmbedding(
|
self.embed_tokens = TensorParallelEmbedding(
|
||||||
prefix=(
|
prefix=(
|
||||||
"model.embed_tokens"
|
f"{name}.embed_tokens"
|
||||||
if not prefix
|
if not prefix
|
||||||
else f"{prefix}.model.embed_tokens"
|
else f"{prefix}.{name}.embed_tokens"
|
||||||
),
|
),
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
self.model = FlashLlamaModel(prefix, config, weights)
|
self.model = FlashLlamaModel(
|
||||||
|
prefix=name if not prefix else f"{prefix}.{name}",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
if config.tie_word_embeddings:
|
if config.tie_word_embeddings:
|
||||||
suffix = "model.embed_tokens"
|
suffix = "model.embed_tokens"
|
||||||
else:
|
else:
|
||||||
suffix = "lm_head"
|
suffix = "lm_head"
|
||||||
|
|
||||||
|
# Used in Granite
|
||||||
|
embedding_multiplier = getattr(config, "embedding_multiplier", None)
|
||||||
|
if embedding_multiplier is not None:
|
||||||
|
self.embed_tokens.weight.data *= embedding_multiplier
|
||||||
|
prefix = suffix if not prefix or name != "model" else f"{prefix}.{suffix}"
|
||||||
with no_fp8(weights):
|
with no_fp8(weights):
|
||||||
self.lm_head = SpeculativeHead.load(
|
self.lm_head = SpeculativeHead.load(
|
||||||
config,
|
config,
|
||||||
prefix=suffix if not prefix else f"{prefix}.{suffix}",
|
prefix,
|
||||||
weights=weights,
|
weights,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Used in Granite
|
||||||
|
self.logits_scaling = getattr(config, "logits_scaling", None)
|
||||||
|
if self.logits_scaling is not None and self.lm_head.head is not None:
|
||||||
|
try:
|
||||||
|
# Scale the weights directly
|
||||||
|
self.lm_head.head.linear.weight.data /= self.logits_scaling
|
||||||
|
self.logits_scaled = True
|
||||||
|
except Exception:
|
||||||
|
self.logits_scaled = False
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
prefill_cache_indices: Optional[torch.Tensor] = None,
|
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
adapter_data: Optional[torch.Tensor] = None,
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
cross_attention_states=None,
|
cross_attention_states=None,
|
||||||
@ -653,16 +639,20 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
|||||||
position_ids,
|
position_ids,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
|
||||||
true_max_s=max_s,
|
|
||||||
prefill_cache_indices=prefill_cache_indices,
|
|
||||||
adapter_data=adapter_data,
|
adapter_data=adapter_data,
|
||||||
cross_attention_states=cross_attention_states,
|
cross_attention_states=cross_attention_states,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
logits, speculative_logits = self.lm_head(hidden_states)
|
logits, speculative_logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
|
# Used in Granite
|
||||||
|
if self.logits_scaling is not None and not self.logits_scaled:
|
||||||
|
logits /= self.logits_scaling
|
||||||
|
if speculative_logits is not None:
|
||||||
|
speculative_logits /= self.logits_scaling
|
||||||
|
|
||||||
return logits, speculative_logits
|
return logits, speculative_logits
|
||||||
|
@ -0,0 +1,285 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" PyTorch Llava-NeXT model."""
|
||||||
|
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
from transformers.image_processing_utils import select_best_resolution
|
||||||
|
|
||||||
|
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
|
||||||
|
from text_generation_server.models.custom_modeling.vlm import (
|
||||||
|
load_text_model,
|
||||||
|
load_vision_model,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers import (
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
TensorParallelRowLinear,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||||
|
"""
|
||||||
|
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_size (`tuple`):
|
||||||
|
The size of the input image in the format (height, width).
|
||||||
|
grid_pinpoints (`List`):
|
||||||
|
A list containing possible resolutions. Each item in the list should be a tuple or list
|
||||||
|
of the form `(height, width)`.
|
||||||
|
patch_size (`int`):
|
||||||
|
The size of each image patch.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: The shape of the image patch grid in the format (height, width).
|
||||||
|
"""
|
||||||
|
if not isinstance(grid_pinpoints, list):
|
||||||
|
raise ValueError("grid_pinpoints should be a list of tuples or lists")
|
||||||
|
|
||||||
|
height, width = select_best_resolution(image_size, grid_pinpoints)
|
||||||
|
return height // patch_size, width // patch_size
|
||||||
|
|
||||||
|
|
||||||
|
def unpad_image(tensor, original_size):
|
||||||
|
"""
|
||||||
|
Unpads a PyTorch tensor of a padded and resized image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor (`torch.Tensor`):
|
||||||
|
The image tensor, assumed to be of shape (num_channels, height, width).
|
||||||
|
original_size (`tuple`):
|
||||||
|
The original size of the image (height, width).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`torch.Tensor`: The unpadded image tensor.
|
||||||
|
"""
|
||||||
|
original_height, original_width = original_size
|
||||||
|
current_height, current_width = tensor.shape[1:]
|
||||||
|
|
||||||
|
original_aspect_ratio = original_width / original_height
|
||||||
|
current_aspect_ratio = current_width / current_height
|
||||||
|
|
||||||
|
if original_aspect_ratio > current_aspect_ratio:
|
||||||
|
scale_factor = current_width / original_width
|
||||||
|
new_height = int(original_height * scale_factor)
|
||||||
|
padding = (current_height - new_height) // 2
|
||||||
|
unpadded_tensor = tensor[:, padding : current_height - padding, :]
|
||||||
|
else:
|
||||||
|
scale_factor = current_height / original_height
|
||||||
|
new_width = int(original_width * scale_factor)
|
||||||
|
padding = (current_width - new_width) // 2
|
||||||
|
unpadded_tensor = tensor[:, :, padding : current_width - padding]
|
||||||
|
|
||||||
|
return unpadded_tensor
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
|
||||||
|
class LlavaNextMultiModalProjector(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.linear_1 = TensorParallelColumnLinear.load(
|
||||||
|
prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True
|
||||||
|
)
|
||||||
|
self.act = ACT2FN[config.projector_hidden_act]
|
||||||
|
self.linear_2 = TensorParallelRowLinear.load(
|
||||||
|
prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, image_features):
|
||||||
|
hidden_states = self.linear_1(image_features)
|
||||||
|
hidden_states = self.act(hidden_states)
|
||||||
|
hidden_states = self.linear_2(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class FlashLlavaNextForConditionalGeneration(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
config.vision_config.quantize = config.quantize
|
||||||
|
vision_config = config.vision_config
|
||||||
|
# Instead of selecting in hidden_states[-2].
|
||||||
|
# Instead compute only the n -2 + 1 layers and don't pool
|
||||||
|
if config.vision_feature_layer < 0:
|
||||||
|
vision_config.num_hidden_layers += config.vision_feature_layer + 1
|
||||||
|
else:
|
||||||
|
vision_config.num_hidden_layers = config.vision_feature_layer + 1
|
||||||
|
self.vision_tower = load_vision_model(
|
||||||
|
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
|
||||||
|
config=config.vision_config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.multi_modal_projector = LlavaNextMultiModalProjector(
|
||||||
|
prefix="multi_modal_projector", config=config, weights=weights
|
||||||
|
)
|
||||||
|
|
||||||
|
self.image_newline = weights.get_tensor("image_newline")
|
||||||
|
|
||||||
|
self.vocab_size = config.text_config.vocab_size
|
||||||
|
self.config = config
|
||||||
|
config.text_config.quantize = config.quantize
|
||||||
|
config.text_config.speculator = config.speculator
|
||||||
|
self.text_model = load_text_model(
|
||||||
|
prefix="language_model" if not prefix else f"{prefix}.language_model",
|
||||||
|
config=config.text_config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
self.pad_token_id = (
|
||||||
|
config.pad_token_id if config.pad_token_id is not None else -1
|
||||||
|
)
|
||||||
|
|
||||||
|
def _merge_input_ids_with_image_features(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
inputs_embeds: torch.Tensor,
|
||||||
|
image_features: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""In place merges in vision_embeddings with inputs_embeds."""
|
||||||
|
mask = torch.where(input_ids == self.config.image_token_index)
|
||||||
|
# Let's pray we have enabled enough slots !
|
||||||
|
try:
|
||||||
|
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
|
||||||
|
)
|
||||||
|
return inputs_embeds
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
slots: torch.Tensor,
|
||||||
|
seqlen: Seqlen,
|
||||||
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
|
pixel_values: torch.FloatTensor = None,
|
||||||
|
# Unused for this model
|
||||||
|
pixel_attention_mask=None,
|
||||||
|
image_sizes: Optional[torch.LongTensor] = None,
|
||||||
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
|
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||||
|
):
|
||||||
|
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||||
|
if pixel_values is not None and len(pixel_values) > 0:
|
||||||
|
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||||
|
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
|
||||||
|
# 1. Extract the input embeddings
|
||||||
|
|
||||||
|
# 2. Merge text and images
|
||||||
|
num_images, num_patches, channels, height, width = pixel_values.shape
|
||||||
|
pixel_values = pixel_values.view(
|
||||||
|
num_images * num_patches, channels, height, width
|
||||||
|
)
|
||||||
|
image_features = self.vision_tower(pixel_values)
|
||||||
|
|
||||||
|
# selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
|
||||||
|
# Already done within the clip model
|
||||||
|
selected_image_feature = image_features.last_hidden_state
|
||||||
|
|
||||||
|
if self.config.vision_feature_select_strategy == "default":
|
||||||
|
selected_image_feature = selected_image_feature[:, 1:]
|
||||||
|
elif self.config.vision_feature_select_strategy == "full":
|
||||||
|
selected_image_feature = selected_image_feature
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
|
||||||
|
)
|
||||||
|
|
||||||
|
image_features = self.multi_modal_projector(selected_image_feature)
|
||||||
|
|
||||||
|
# split up image_features for each of the individual images
|
||||||
|
# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
|
||||||
|
# if we assume each image has 5 image features (base image + 4 patches)
|
||||||
|
split_sizes = [num_patches] * num_images
|
||||||
|
image_features = torch.split(image_features, split_sizes, dim=0)
|
||||||
|
|
||||||
|
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
|
||||||
|
height = width = (
|
||||||
|
self.config.vision_config.image_size
|
||||||
|
// self.config.vision_config.patch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
new_image_features = []
|
||||||
|
for image_idx, image_feature in enumerate(image_features):
|
||||||
|
if image_feature.shape[0] > 1:
|
||||||
|
base_image_feature = image_feature[0]
|
||||||
|
image_feature = image_feature[1:]
|
||||||
|
|
||||||
|
if height * width != base_image_feature.shape[0]:
|
||||||
|
raise ValueError(
|
||||||
|
"The number of patches is not consistent with the image size."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dimensions are intentionally swapped to be bug-compatible with
|
||||||
|
# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
|
||||||
|
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
|
||||||
|
image_sizes[image_idx],
|
||||||
|
self.config.image_grid_pinpoints,
|
||||||
|
self.config.vision_config.image_size,
|
||||||
|
)
|
||||||
|
image_feature = image_feature.view(
|
||||||
|
num_patch_height, num_patch_width, height, width, -1
|
||||||
|
)
|
||||||
|
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
|
||||||
|
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
||||||
|
image_feature = unpad_image(image_feature, image_sizes[image_idx])
|
||||||
|
image_feature = torch.cat(
|
||||||
|
(
|
||||||
|
image_feature,
|
||||||
|
self.image_newline[:, None, None].expand(
|
||||||
|
*image_feature.shape[:-1], 1
|
||||||
|
),
|
||||||
|
),
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
||||||
|
image_feature = torch.cat(
|
||||||
|
(base_image_feature, image_feature), dim=0
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
image_feature = image_feature[0]
|
||||||
|
image_feature = torch.cat(
|
||||||
|
(image_feature, self.image_newline[None]), dim=0
|
||||||
|
)
|
||||||
|
new_image_features.append(image_feature)
|
||||||
|
image_features = torch.stack(new_image_features, dim=0)
|
||||||
|
|
||||||
|
inputs_embeds = self._merge_input_ids_with_image_features(
|
||||||
|
input_ids, inputs_embeds, image_features
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = self.text_model.model(
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
position_ids=position_ids,
|
||||||
|
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
slots=slots,
|
||||||
|
seqlen=seqlen,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
|
adapter_data=adapter_data,
|
||||||
|
)
|
||||||
|
if lm_head_indices is not None:
|
||||||
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
logits, speculative_logits = self.text_model.lm_head(hidden_states)
|
||||||
|
return logits, speculative_logits
|
@ -26,12 +26,12 @@ from transformers.activations import ACT2FN
|
|||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||||
from text_generation_server.layers.attention import (
|
from text_generation_server.layers.attention import (
|
||||||
paged_attention,
|
paged_attention,
|
||||||
attention,
|
attention,
|
||||||
reshape_and_cache,
|
|
||||||
Seqlen,
|
Seqlen,
|
||||||
|
HPUPagedAttentionMetadata,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
@ -41,20 +41,12 @@ from text_generation_server.layers import (
|
|||||||
TensorParallelMultiAdapterLinear,
|
TensorParallelMultiAdapterLinear,
|
||||||
TensorParallelAdapterRowLinear,
|
TensorParallelAdapterRowLinear,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
|
||||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||||
from text_generation_server.layers.layernorm import (
|
from text_generation_server.layers.layernorm import (
|
||||||
FastRMSNorm,
|
FastRMSNorm,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if SYSTEM == "rocm":
|
|
||||||
try:
|
|
||||||
from vllm import _custom_C
|
|
||||||
except Exception as e:
|
|
||||||
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
class MistralConfig(PretrainedConfig):
|
class MistralConfig(PretrainedConfig):
|
||||||
model_type = "mistral"
|
model_type = "mistral"
|
||||||
|
|
||||||
@ -160,6 +152,7 @@ class MistralAttention(torch.nn.Module):
|
|||||||
],
|
],
|
||||||
process_group=weights.process_group,
|
process_group=weights.process_group,
|
||||||
)
|
)
|
||||||
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
|
|
||||||
o_proj = TensorParallelRowLinear.load(
|
o_proj = TensorParallelRowLinear.load(
|
||||||
config,
|
config,
|
||||||
@ -185,12 +178,10 @@ class MistralAttention(torch.nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
|
||||||
prefill_cache_indices,
|
|
||||||
adapter_data,
|
adapter_data,
|
||||||
|
hpu_attention_meta,
|
||||||
):
|
):
|
||||||
qkv = self.query_key_value(hidden_states, adapter_data)
|
qkv = self.query_key_value(hidden_states, adapter_data)
|
||||||
query, kv = qkv.split(
|
query, kv = qkv.split(
|
||||||
@ -205,38 +196,36 @@ class MistralAttention(torch.nn.Module):
|
|||||||
|
|
||||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
if prefill_cache_indices is not None:
|
kv_cache.store(
|
||||||
kv_to_cache = kv[prefill_cache_indices]
|
key=kv[:, 0],
|
||||||
else:
|
value=kv[:, 1],
|
||||||
kv_to_cache = kv
|
slots=slots,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
reshape_and_cache(
|
|
||||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# sdpa
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
key=kv[:, 0],
|
||||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
value=kv[:, 1],
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
kv_scales=self.kv_scales,
|
||||||
self.softmax_scale,
|
seqlen=seqlen,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
window_size_left=self.max_past,
|
window_size_left=self.max_past,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache,
|
||||||
kv_cache[1],
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
kv_scales=self.kv_scales,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.o_proj(
|
return self.o_proj(
|
||||||
@ -300,24 +289,6 @@ class MistralMLP(nn.Module):
|
|||||||
self.quantize = config.quantize
|
self.quantize = config.quantize
|
||||||
|
|
||||||
def forward(self, hidden_states, adapter_data):
|
def forward(self, hidden_states, adapter_data):
|
||||||
if (
|
|
||||||
SYSTEM == "rocm"
|
|
||||||
and self.hidden_act == "silu"
|
|
||||||
and hidden_states.dtype == torch.float16
|
|
||||||
and hidden_states.shape[0] == 1
|
|
||||||
and not self.quantize
|
|
||||||
):
|
|
||||||
out = torch.empty(
|
|
||||||
hidden_states.shape[0],
|
|
||||||
self.intermediate_size,
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
device="cuda",
|
|
||||||
)
|
|
||||||
_custom_C.LLMM_Silu(
|
|
||||||
self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8
|
|
||||||
)
|
|
||||||
return self.down_proj(out, adapter_data)
|
|
||||||
else:
|
|
||||||
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
||||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||||
return self.down_proj(
|
return self.down_proj(
|
||||||
@ -355,12 +326,10 @@ class MistralLayer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
|
||||||
prefill_cache_indices,
|
|
||||||
adapter_data,
|
adapter_data,
|
||||||
|
hpu_attention_meta,
|
||||||
):
|
):
|
||||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
@ -371,12 +340,10 @@ class MistralLayer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
|
||||||
prefill_cache_indices,
|
|
||||||
adapter_data,
|
adapter_data,
|
||||||
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
# faster post attention rms norm
|
# faster post attention rms norm
|
||||||
@ -423,20 +390,15 @@ class MistralModel(torch.nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
true_max_s: int,
|
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
|
||||||
adapter_data: Optional[torch.Tensor] = None,
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
# Get rotary cos and sin for this forward
|
# Get rotary cos and sin for this forward
|
||||||
# Avoid to index in each layer
|
# Avoid to index in each layer
|
||||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||||
position_ids, true_max_s, hidden_states.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
residual = None
|
residual = None
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
@ -447,12 +409,10 @@ class MistralModel(torch.nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache[i],
|
kv_cache[i],
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
|
||||||
prefill_cache_indices,
|
|
||||||
adapter_data,
|
adapter_data,
|
||||||
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
@ -498,35 +458,21 @@ class FlashMistralForCausalLM(torch.nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
adapter_data: Optional[torch.Tensor] = None,
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
true_max_s = max_s
|
|
||||||
if prefill_cache_indices is not None:
|
|
||||||
# Slots also need to be sliced as it has the same size as the whole kv tensor
|
|
||||||
slots = slots[prefill_cache_indices]
|
|
||||||
elif self.max_past is not None:
|
|
||||||
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
|
||||||
# kernel requires the true values
|
|
||||||
seqlen = seqlen.clamp(max=self.max_past_tensor)
|
|
||||||
|
|
||||||
inputs_embeds = self.embed_tokens(input_ids)
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
hidden_states = self.model(
|
hidden_states = self.model(
|
||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
position_ids,
|
position_ids,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
true_max_s,
|
|
||||||
prefill_cache_indices,
|
|
||||||
adapter_data,
|
adapter_data,
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
|
@ -37,9 +37,9 @@ from text_generation_server.layers.attention import (
|
|||||||
Seqlen,
|
Seqlen,
|
||||||
attention,
|
attention,
|
||||||
paged_attention,
|
paged_attention,
|
||||||
reshape_and_cache,
|
HPUPagedAttentionMetadata,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||||
from text_generation_server.layers.layernorm import FastRMSNorm
|
from text_generation_server.layers.layernorm import FastRMSNorm
|
||||||
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||||
@ -215,6 +215,7 @@ class MixtralAttention(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.query_key_value = load_attention(config, prefix, weights)
|
self.query_key_value = load_attention(config, prefix, weights)
|
||||||
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
|
|
||||||
self.o_proj = TensorParallelRowLinear.load(
|
self.o_proj = TensorParallelRowLinear.load(
|
||||||
config,
|
config,
|
||||||
@ -234,11 +235,9 @@ class MixtralAttention(torch.nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
prefill_cache_indices,
|
|
||||||
):
|
):
|
||||||
qkv = self.query_key_value(hidden_states)
|
qkv = self.query_key_value(hidden_states)
|
||||||
query, kv = qkv.split(
|
query, kv = qkv.split(
|
||||||
@ -253,38 +252,36 @@ class MixtralAttention(torch.nn.Module):
|
|||||||
|
|
||||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
if prefill_cache_indices is not None:
|
kv_cache.store(
|
||||||
kv_to_cache = kv[prefill_cache_indices]
|
key=kv[:, 0],
|
||||||
else:
|
value=kv[:, 1],
|
||||||
kv_to_cache = kv
|
slots=slots,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
reshape_and_cache(
|
|
||||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# sdpa
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
key=kv[:, 0],
|
||||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
value=kv[:, 1],
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
kv_scales=self.kv_scales,
|
||||||
self.softmax_scale,
|
seqlen=seqlen,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
window_size_left=self.max_past,
|
window_size_left=self.max_past,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache,
|
||||||
kv_cache[1],
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
kv_scales=self.kv_scales,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||||
@ -378,11 +375,9 @@ class MixtralLayer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
prefill_cache_indices,
|
|
||||||
):
|
):
|
||||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
@ -393,11 +388,9 @@ class MixtralLayer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
prefill_cache_indices,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# faster post attention rms norm
|
# faster post attention rms norm
|
||||||
@ -448,20 +441,15 @@ class MixtralModel(torch.nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
true_max_s: int,
|
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
# Get rotary cos and sin for this forward
|
# Get rotary cos and sin for this forward
|
||||||
# Avoid to index in each layer
|
# Avoid to index in each layer
|
||||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||||
position_ids, true_max_s, hidden_states.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
residual = None
|
residual = None
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
@ -472,11 +460,9 @@ class MixtralModel(torch.nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache[i],
|
kv_cache[i],
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
prefill_cache_indices,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
@ -507,34 +493,21 @@ class FlashMixtralForCausalLM(torch.nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
adapter_data: Optional[torch.Tensor] = None,
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
true_max_s = max_s
|
|
||||||
if prefill_cache_indices is not None:
|
|
||||||
# Slots also need to be sliced as it has the same size as the whole kv tensor
|
|
||||||
slots = slots[prefill_cache_indices]
|
|
||||||
elif self.max_past is not None:
|
|
||||||
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
|
||||||
# kernel requires the true values
|
|
||||||
seqlen = seqlen.clamp(max=self.max_past_tensor)
|
|
||||||
|
|
||||||
hidden_states = self.model(
|
hidden_states = self.model(
|
||||||
input_ids,
|
input_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
true_max_s,
|
|
||||||
prefill_cache_indices,
|
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
@ -0,0 +1,986 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""PyTorch Mllama model."""
|
||||||
|
|
||||||
|
from typing import Optional, Tuple, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from text_generation_server.layers import (
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
TensorParallelEmbedding,
|
||||||
|
TensorParallelRowLinear,
|
||||||
|
FastLinear,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.attention import (
|
||||||
|
Seqlen,
|
||||||
|
HPUPagedAttentionMetadata,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
|
||||||
|
FlashLlamaForCausalLM,
|
||||||
|
)
|
||||||
|
from habana_frameworks.torch.hpex.kernels import FusedSDPA
|
||||||
|
from vllm_hpu_extension.utils import ModuleFusedSDPA
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_aspect_ratio_attention_mask(
|
||||||
|
aspect_ratio_mask: torch.Tensor,
|
||||||
|
num_patches: int,
|
||||||
|
target_length: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# Expand aspect ratio mask to target_length
|
||||||
|
batch_size, max_num_tiles = aspect_ratio_mask.shape
|
||||||
|
attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype)
|
||||||
|
attention_mask = attention_mask.repeat(1, 1, target_length, 1)
|
||||||
|
|
||||||
|
# Mask padding patches
|
||||||
|
pad_patches = target_length - num_patches
|
||||||
|
attention_mask[:, :, -pad_patches:] = 0
|
||||||
|
|
||||||
|
# Invert the mask (0 -> 1, 1 -> 0)
|
||||||
|
attention_mask = 1 - attention_mask
|
||||||
|
|
||||||
|
# Reshape to 2D and create 4D attention mask
|
||||||
|
# (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
|
||||||
|
attention_mask = attention_mask.reshape(
|
||||||
|
batch_size, max_num_tiles * target_length, 1
|
||||||
|
)
|
||||||
|
attention_mask = (
|
||||||
|
attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min
|
||||||
|
)
|
||||||
|
attention_mask = attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
|
return attention_mask
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
|
||||||
|
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
sequence_length: int,
|
||||||
|
target_length: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
min_dtype: float,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
batch_size: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||||
|
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attention_mask (`torch.Tensor`):
|
||||||
|
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||||
|
sequence_length (`int`):
|
||||||
|
The sequence length being processed.
|
||||||
|
target_length (`int`):
|
||||||
|
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||||
|
dtype (`torch.dtype`):
|
||||||
|
The dtype to use for the 4D attention mask.
|
||||||
|
device (`torch.device`):
|
||||||
|
The device to plcae the 4D attention mask on.
|
||||||
|
min_dtype (`float`):
|
||||||
|
The minimum value representable with the dtype `dtype`.
|
||||||
|
cache_position (`torch.Tensor`):
|
||||||
|
Indices depicting the position of the input sequence tokens in the sequence.
|
||||||
|
batch_size (`torch.Tensor`):
|
||||||
|
Batch size.
|
||||||
|
"""
|
||||||
|
if attention_mask is not None and attention_mask.dim() == 4:
|
||||||
|
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||||
|
causal_mask = attention_mask
|
||||||
|
else:
|
||||||
|
causal_mask = torch.full(
|
||||||
|
(sequence_length, target_length),
|
||||||
|
fill_value=min_dtype,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
if sequence_length != 1:
|
||||||
|
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||||
|
causal_mask *= torch.arange(
|
||||||
|
target_length, device=device
|
||||||
|
) > cache_position.reshape(-1, 1)
|
||||||
|
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||||
|
if attention_mask is not None:
|
||||||
|
causal_mask = (
|
||||||
|
causal_mask.clone()
|
||||||
|
) # copy to contiguous memory for in-place edit
|
||||||
|
mask_length = attention_mask.shape[-1]
|
||||||
|
padding_mask = (
|
||||||
|
causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||||
|
)
|
||||||
|
padding_mask = padding_mask == 0
|
||||||
|
causal_mask[:, :, :, :mask_length] = causal_mask[
|
||||||
|
:, :, :, :mask_length
|
||||||
|
].masked_fill(padding_mask, min_dtype)
|
||||||
|
|
||||||
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_cross_attention_mask(
|
||||||
|
cross_attention_mask: torch.Tensor,
|
||||||
|
num_vision_tokens: int,
|
||||||
|
dtype: str,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# reshape so it can be used by attn module
|
||||||
|
batch_size, text_total_length, *_ = cross_attention_mask.shape
|
||||||
|
cross_attention_mask = cross_attention_mask.repeat_interleave(
|
||||||
|
num_vision_tokens, dim=3
|
||||||
|
)
|
||||||
|
cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1)
|
||||||
|
cross_attention_mask = cross_attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
|
# invert the mask
|
||||||
|
inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype)
|
||||||
|
cross_attention_mask = inverted_cross_attn_mask.masked_fill(
|
||||||
|
inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min
|
||||||
|
)
|
||||||
|
|
||||||
|
# apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's
|
||||||
|
# last dimension contains negative infinity values, otherwise it's 1
|
||||||
|
negative_inf_value = torch.finfo(dtype).min
|
||||||
|
full_text_row_masked_out_mask = (
|
||||||
|
(cross_attention_mask != negative_inf_value)
|
||||||
|
.any(dim=-1)
|
||||||
|
.type_as(cross_attention_mask)[..., None]
|
||||||
|
)
|
||||||
|
cross_attention_mask *= full_text_row_masked_out_mask
|
||||||
|
|
||||||
|
return cross_attention_mask, full_text_row_masked_out_mask
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision
|
||||||
|
class MllamaVisionMLP(nn.Module):
    """Feed-forward block: `fc1` -> activation -> `fc2`.

    `fc1` is loaded through `TensorParallelColumnLinear` and `fc2` through
    `TensorParallelRowLinear`, both with bias. The activation is looked up in
    `ACT2FN` using `config.hidden_act`.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]

        shared_kwargs = dict(weights=weights, config=config, bias=True)
        self.fc1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.fc1", **shared_kwargs
        )
        self.fc2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.fc2", **shared_kwargs
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply both projections with the activation in between."""
        projected = self.fc1(hidden_states)
        activated = self.activation_fn(projected)
        return self.fc2(activated)
|
||||||
|
|
||||||
|
|
||||||
|
class MllamaVisionSdpaAttention(nn.Module):
    """Multi-head self-attention for the vision encoder, computed with
    `F.scaled_dot_product_attention`.

    The q/k/v projections are loaded as one fused, bias-free column-parallel
    layer; the head count kept on this rank is `config.attention_heads`
    divided by the process group size.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()

        total_heads = config.attention_heads
        self.embed_dim = config.hidden_size
        self.head_dim = config.hidden_size // total_heads
        self.num_heads = total_heads // weights.process_group.size()

        qkv_prefixes = [f"{prefix}.{name}" for name in ("q_proj", "k_proj", "v_proj")]
        self.qkv_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=qkv_prefixes,
            dim=0,
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run self-attention over a 3D `hidden_state` (the fused projection is split on dim 2).

        `attention_mask` is forwarded unchanged as `attn_mask`.
        """
        local_width = self.head_dim * self.num_heads
        query, key, value = self.qkv_proj(hidden_state).split(
            [local_width, local_width, local_width], dim=2
        )

        batch_size, q_seq_len, _ = query.shape
        _, kv_seq_len, _ = key.shape

        def to_heads(tensor: torch.Tensor, seq_len: int) -> torch.Tensor:
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            return tensor.view(
                batch_size, seq_len, self.num_heads, self.head_dim
            ).transpose(1, 2)

        attn_output = F.scaled_dot_product_attention(
            to_heads(query, q_seq_len),
            to_heads(key, kv_seq_len),
            to_heads(value, kv_seq_len),
            attn_mask=attention_mask,
        )

        # Back to (batch, seq, heads * head_dim) before the output projection.
        merged = attn_output.transpose(1, 2).contiguous()
        merged = merged.reshape(batch_size, q_seq_len, -1)

        return self.o_proj(merged)
|
||||||
|
|
||||||
|
|
||||||
|
class MllamaVisionEncoderLayer(nn.Module):
    """One pre-norm vision encoder layer: self-attention then MLP, each added
    back onto its residual.

    When `is_gated` is true, both branches are scaled by `tanh` of the
    `gate_attn` / `gate_ffn` tensors read from the checkpoint; otherwise the
    scale is 1.
    """

    def __init__(self, *, prefix, config, weights, is_gated: bool):
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.attention_heads
        self.is_gated = is_gated
        self.intermediate_size = config.intermediate_size

        self.self_attn = MllamaVisionSdpaAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.mlp = MllamaVisionMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )

        # NOTE(review): `nn.LayerNorm.load` is not part of stock PyTorch; it is
        # presumably attached by the project's layernorm module -- confirm.
        for norm_name in ("input_layernorm", "post_attention_layernorm"):
            setattr(
                self,
                norm_name,
                nn.LayerNorm.load(
                    prefix=f"{prefix}.{norm_name}", weights=weights, eps=1e-05
                ),
            )

        # there used to be an if else here, no code path
        if is_gated:
            for gate_name in ("gate_attn", "gate_ffn"):
                setattr(
                    self,
                    gate_name,
                    nn.Parameter(
                        weights.get_tensor(f"{prefix}.{gate_name}"),
                        requires_grad=False,
                    ),
                )

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Return the layer output for `hidden_state`."""
        # Self Attention
        attn_out = self.self_attn(
            self.input_layernorm(hidden_state), attention_mask=attention_mask
        )
        attn_scale = self.gate_attn.tanh() if self.is_gated else 1
        after_attn = hidden_state + attn_scale * attn_out

        # Feed forward
        ffn_out = self.mlp(self.post_attention_layernorm(after_attn))
        ffn_scale = self.gate_ffn.tanh() if self.is_gated else 1
        return after_attn + ffn_scale * ffn_out
|
||||||
|
|
||||||
|
|
||||||
|
class MllamaVisionEncoder(nn.Module):
    """Stack of `num_layers` `MllamaVisionEncoderLayer` instances applied in order."""

    def __init__(self, *, prefix, config, weights, is_gated: bool, num_layers: int):
        super().__init__()
        self.config = config

        # Stored as a plain Python list (not `nn.ModuleList`), so the layers are
        # not registered as submodules of this module.
        stack = []
        for layer_index in range(num_layers):
            stack.append(
                MllamaVisionEncoderLayer(
                    prefix=f"{prefix}.layers.{layer_index}",
                    config=config,
                    weights=weights,
                    is_gated=is_gated,
                )
            )
        self.layers = stack

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Run every layer and collect the intermediate states.

        Returns:
            `(hidden_states, encoder_states)` where `encoder_states` is a list
            holding the input followed by the output of each layer.
        """
        encoder_states = [hidden_states]
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states, attention_mask)
            encoder_states.append(hidden_states)

        return hidden_states, encoder_states
|
||||||
|
|
||||||
|
|
||||||
|
class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
    """Adds a gated, per-tile embedding selected by aspect ratio id to the hidden state."""

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.max_aspect_ratio_id = config.max_aspect_ratio_id

        self.embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.embedding", weights=weights
        )
        gate_tensor = weights.get_tensor(f"{prefix}.gate")
        self.gate = nn.Parameter(gate_tensor, requires_grad=False)

    def forward(
        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
    ) -> torch.Tensor:
        """Return `hidden_state` plus the tanh-gated aspect ratio embedding.

        The looked-up embedding is reshaped to
        `(-1, max_num_tiles, 1, hidden_size)` and broadcast onto `hidden_state`.
        """
        per_tile = self.embedding(aspect_ratio_ids).reshape(
            -1, self.max_num_tiles, 1, self.hidden_size
        )

        # Always gated.
        gated = per_tile * self.gate.tanh()

        return hidden_state + gated
|
||||||
|
|
||||||
|
|
||||||
|
class MllamaPrecomputedPositionEmbedding(nn.Module):
    """Adds two position signals to the hidden state: a shared position
    embedding scaled by `1 - tanh(gate)` (computed once at construction) and a
    per-aspect-ratio tile embedding scaled by `tanh(gate)`.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        patches_per_side = config.image_size // config.patch_size
        self.num_patches = patches_per_side**2 + 1
        self.hidden_size = config.hidden_size
        self.scale = config.hidden_size**-0.5

        self.gate = nn.Parameter(
            weights.get_tensor(f"{prefix}.gate"), requires_grad=False
        )

        # position embedding
        position_table = nn.Parameter(
            weights.get_tensor(f"{prefix}.embedding"), requires_grad=False
        )
        keep_ratio = 1 - self.gate.tanh()
        self.gated_position_embedding = keep_ratio * position_table
        self.tile_embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.tile_embedding", weights=weights
        )

    def forward(
        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
    ) -> torch.Tensor:
        """Return `hidden_state` with both gated position embeddings added.

        The tile embedding is reshaped to
        `(batch_size, max_num_tiles, num_patches, hidden_size)`, where
        `batch_size` is taken from `hidden_state.shape[0]`.
        """
        # position embeddings
        shared_positions = self.gated_position_embedding.view(
            1, 1, self.num_patches, self.hidden_size
        )
        hidden_state = hidden_state + shared_positions

        # precomputed tile position embeddings
        batch_size = hidden_state.shape[0]
        tile_positions = self.tile_embedding(aspect_ratio_ids).reshape(
            batch_size, self.max_num_tiles, self.num_patches, self.hidden_size
        )
        gated_tile_positions = self.gate.tanh() * tile_positions

        return hidden_state + gated_tile_positions
|
||||||
|
|
||||||
|
|
||||||
|
class MllamaVisionModel(nn.Module):
    """Vision tower: patch embedding, tile/position embeddings and two encoders.

    ``forward`` returns the global-encoder output concatenated, along the last
    dimension, with the outputs of the local-encoder layers selected by
    ``config.intermediate_layers_indices``.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.num_channels = config.num_channels
        self.intermediate_layers_indices = config.intermediate_layers_indices

        # One extra position on top of the patch grid (the class token).
        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
        self.scale = config.hidden_size**-0.5
        self.dtype = weights.dtype

        # Non-overlapping patches: kernel size and stride are both patch_size.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
            bias=False,
        )
        self.patch_embedding.weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
        )

        self.class_embedding = nn.Parameter(
            weights.get_tensor(f"{prefix}.class_embedding"), requires_grad=False
        )

        # Every sub-module below is built from the same config and weights.
        shared = dict(config=config, weights=weights)

        self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(
            prefix=f"{prefix}.gated_positional_embedding", **shared
        )

        self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
            prefix=f"{prefix}.pre_tile_positional_embedding", **shared
        )
        self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
            prefix=f"{prefix}.post_tile_positional_embedding", **shared
        )

        ## layer norms (eps is the torch default)
        self.layernorm_pre = nn.LayerNorm.load(
            prefix=f"{prefix}.layernorm_pre",
            weights=weights,
            eps=1e-05,
        )
        self.layernorm_post = nn.LayerNorm.load(
            prefix=f"{prefix}.layernorm_post",
            weights=weights,
            eps=1e-05,
        )

        ## encoders
        self.transformer = MllamaVisionEncoder(
            prefix=f"{prefix}.transformer",
            is_gated=False,
            num_layers=config.num_hidden_layers,
            **shared,
        )
        self.global_transformer = MllamaVisionEncoder(
            prefix=f"{prefix}.global_transformer",
            is_gated=True,
            num_layers=config.num_global_layers,
            **shared,
        )

    def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Prepend the class embedding along dim 1 of a 3-D hidden state."""
        batch_size, _, hidden_size = hidden_state.shape
        cls_token = self.class_embedding.expand(batch_size, 1, hidden_size)
        return torch.cat([cls_token, hidden_state], dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        aspect_ratio_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Encode image tiles.

        Args:
            pixel_values: 6-D tensor unpacked as ``(batch_size,
                num_concurrent_media, num_tiles, num_channels, height, width)``.
            aspect_ratio_ids: reshaped to ``(batch_size * num_concurrent_media, -1)``
                and fed to the tile/position embeddings.
            attention_mask: aspect-ratio mask, or ``None`` to run unmasked.

        Returns:
            Tensor of shape ``(batch_size, num_concurrent_media, num_tiles,
            num_patches, D)`` where ``D`` is the encoder width plus the stacked
            intermediate-layer features.
        """
        (
            batch_size,
            num_concurrent_media,
            num_tiles,
            num_channels,
            height,
            width,
        ) = pixel_values.shape
        num_media = batch_size * num_concurrent_media

        pixel_values = pixel_values.reshape(
            num_media * num_tiles, num_channels, height, width
        )
        aspect_ratio_ids = aspect_ratio_ids.reshape(num_media, -1)

        # patch embedding
        hidden_state = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)

        # tile embeddings
        _, num_patches, dim = hidden_state.shape
        hidden_state = hidden_state.reshape(num_media, num_tiles, -1, dim)
        hidden_state = self.pre_tile_positional_embedding(
            hidden_state, aspect_ratio_ids
        )

        # apply cls token
        hidden_state = hidden_state.reshape(num_media * num_tiles, num_patches, dim)
        hidden_state = self.apply_class_embedding(hidden_state)
        num_patches += 1

        # apply position embeddings
        hidden_state = hidden_state.reshape(num_media, num_tiles, num_patches, dim)
        hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)

        # apply encoder
        hidden_state = self.layernorm_pre(hidden_state)

        # Pad the patch axis (dim -2) up to the next multiple of 8; the padding
        # is sliced off again after the global encoder.
        num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
        # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
        pad_spec = (0, 0, 0, num_padding_patches)
        hidden_state = F.pad(hidden_state, pad_spec, mode="constant", value=0)
        slice_index = -num_padding_patches if num_padding_patches > 0 else None
        padded_patches = num_patches + num_padding_patches

        if attention_mask is not None:
            attention_mask = _prepare_aspect_ratio_attention_mask(
                aspect_ratio_mask=attention_mask.reshape(num_media, -1),
                num_patches=self.num_patches,
                target_length=hidden_state.shape[2],
                dtype=self.dtype,
            )

        hidden_state = hidden_state.view(num_media, -1, dim)
        hidden_state, all_intermediate_hidden_states = self.transformer(
            hidden_state,
            attention_mask=attention_mask,
        )
        # Keep only the layer outputs requested by the config.
        selected_layers = [
            layer_output
            for layer_idx, layer_output in enumerate(all_intermediate_hidden_states)
            if layer_idx in self.intermediate_layers_indices
        ]
        intermediate_hidden_states = torch.stack(selected_layers, dim=-1)

        # apply global encoder
        hidden_state = self.layernorm_post(hidden_state)
        hidden_state = hidden_state.reshape(num_media, num_tiles, padded_patches, dim)
        hidden_state = self.post_tile_positional_embedding(
            hidden_state, aspect_ratio_ids
        )
        hidden_state = hidden_state.reshape(
            num_media, num_tiles * padded_patches, dim
        )
        hidden_state, _ = self.global_transformer(
            hidden_state, attention_mask=attention_mask
        )
        hidden_state = hidden_state.reshape(num_media, num_tiles, padded_patches, dim)
        hidden_state = hidden_state[:, :, :slice_index]

        # adding intermediate layer outputs
        hidden_state = hidden_state.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, dim
        )
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            num_media, num_tiles, padded_patches, -1
        )
        intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, -1
        )
        return torch.cat([hidden_state, intermediate_hidden_states], dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class MllamaTextCrossAttention(nn.Module):
    """Cross-attention from text tokens (queries) to projected vision states (keys/values).

    Head counts are divided by the tensor-parallel world size, so each rank
    only handles its own shard of the heads.
    """

    def __init__(self, *, prefix, config, weights, layer_idx):
        super().__init__()
        self.config = config
        self.num_heads = self.config.num_attention_heads
        self.num_key_value_heads = self.config.num_key_value_heads
        self.dropout = config.dropout
        self.hidden_size = config.hidden_size
        # head_size and the group count use the *global* head counts, so they
        # must be computed before the per-rank division below.
        self.head_size = config.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.layer_idx = layer_idx

        # Per-rank head counts under tensor parallelism.
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            self.num_key_value_heads // weights.process_group.size()
        )

        self.q_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.q_proj",
            weights=weights,
            bias=False,
        )
        self.k_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.k_proj",
            weights=weights,
            bias=False,
        )
        self.v_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.v_proj",
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )

        self.q_norm = MllamaTextRMSNorm.load(
            prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps
        )
        self.k_norm = MllamaTextRMSNorm.load(
            prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps
        )
        self.softmax_scale = self.head_size**-0.5

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_attention_states: Optional[Tuple] = None,
    ) -> torch.Tensor:
        """Attend from ``hidden_states`` to the vision states.

        Args:
            hidden_states: text hidden states; after projection they are viewed
                as ``(bs, -1, num_heads, head_size)``.
            cross_attention_states: 4-tuple ``(states, cu_seqlen_q, cu_seqlen_k,
                indices)``. Only ``states`` and ``cu_seqlen_q`` are used here;
                the batch size is ``cu_seqlen_q.size(0) - 1``. Despite the
                ``None`` default (kept for signature compatibility) a tuple is
                required, since it is unpacked unconditionally.

        Returns:
            The output projection of the attention result, flattened to
            ``(-1, hidden)`` before ``o_proj``.
        """
        (
            cross_attention_states,
            cu_seqlen_q,
            _cu_seqlen_k,
            _indices,
        ) = cross_attention_states
        bs = cu_seqlen_q.size(0) - 1

        query_states = self.q_proj(hidden_states)
        query_states = query_states.view(bs, -1, self.num_heads, self.head_size)
        query_states = self.q_norm(query_states)

        key_states = self.k_proj(cross_attention_states)
        value_states = self.v_proj(cross_attention_states)
        key_states = key_states.view(bs, -1, self.num_key_value_heads, self.head_size)
        value_states = value_states.view(
            bs, -1, self.num_key_value_heads, self.head_size
        )
        key_states = self.k_norm(key_states)

        # Cross-attention is never causal.
        causal = False

        # execute sdpa: the fused kernel is called with heads on dim 1.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        fsdpa_op = ModuleFusedSDPA(FusedSDPA)
        attn_output = fsdpa_op(
            query_states,
            key_states,
            value_states,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=causal,
            scale=None,
            softmax_mode="None",
            recompute_mode=None,
            valid_sequence_lengths=None,
        )
        attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
        attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

        return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText
|
||||||
|
class MllamaTextMLP(nn.Module):
    """Gated MLP with the gate and up projections fused into one matmul."""

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        # Per-rank width of the intermediate activation under tensor parallelism.
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )
        # gate_proj and up_proj are concatenated along dim 0 of the weight.
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply ``down_proj(act_fn(gate) * up)``.

        Args:
            x: input whose last dimension is the hidden size; any number of
                leading dimensions is accepted.

        Returns:
            The result of ``down_proj`` on the gated activation.
        """
        gate_up_states = self.gate_up_proj(x)
        gate_up_states = gate_up_states.view(
            *x.shape[:-1], 2, self.intermediate_size
        )
        # Split on the gate/up axis, which is always second-to-last. Indexing
        # dim 1 (`[:, 0]`) only hits that axis when x is 2-D and silently picks
        # the wrong slice for inputs with extra leading dimensions.
        gate, up = gate_up_states.unbind(dim=-2)
        return self.down_proj(self.act_fn(gate) * up)
|
||||||
|
|
||||||
|
|
||||||
|
class FlashLlamaCrossLayer(torch.nn.Module):
    """Cross-attention transformer block with tanh-gated attention and feedforward."""

    def __init__(self, *, prefix, config, weights, index) -> None:
        super().__init__()
        self.layer_idx = index

        self.cross_attn = MllamaTextCrossAttention(
            prefix=f"{prefix}.cross_attn",
            config=config,
            weights=weights,
            layer_idx=index,
        )

        self.input_layernorm = MllamaTextRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
        attn_gate = weights.get_tensor(f"{prefix}.cross_attn_attn_gate")
        self.cross_attn_attn_gate = torch.nn.Parameter(attn_gate, requires_grad=False)

        self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
        self.post_attention_layernorm = MllamaTextRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        mlp_gate = weights.get_tensor(f"{prefix}.cross_attn_mlp_gate")
        self.cross_attn_mlp_gate = torch.nn.Parameter(mlp_gate, requires_grad=False)

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        adapter_data,
        cross_attention_states,  # [ IB, ...]
        hpu_attention_meta,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run gated cross-attention and MLP on the rows that have image context.

        ``cos``, ``sin``, ``cu_seqlen_prefill``, ``kv_cache``, ``slots``,
        ``seqlen``, ``adapter_data`` and ``hpu_attention_meta`` are accepted
        but not read by this layer.

        Args:
            hidden_states: token hidden states; updated in place.
            residual: optional residual, added into ``hidden_states`` in place.
            cross_attention_states: ``None`` to skip the layer, otherwise a
                sequence whose last element holds the row indices to update.

        Returns:
            ``(hidden_states, residual)`` unchanged when there is no
            cross-attention input, otherwise ``(updated_hidden_states, None)``.
        """
        # No image context: the layer is a no-op.
        if cross_attention_states is None:
            return hidden_states, residual

        if residual is not None:
            hidden_states += residual

        indices = cross_attention_states[-1]
        # Basic slicing gives a view, so the write-back below lands in the
        # caller's tensor.
        full_hidden_states = hidden_states[:]
        if len(indices) > 0:
            assert max(indices) < hidden_states.shape[0]

        # Attention sub-block, restricted to the selected rows.
        selected = hidden_states[indices]
        attn_out = self.cross_attn(
            hidden_states=self.input_layernorm(selected),
            cross_attention_states=cross_attention_states,
        )
        selected = selected + self.cross_attn_attn_gate.tanh() * attn_out

        # Feed-forward sub-block.
        mlp_out = self.mlp(self.post_attention_layernorm(selected))
        selected = selected + self.cross_attn_mlp_gate.tanh() * mlp_out

        full_hidden_states[indices] = selected
        return full_hidden_states, None
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText
|
||||||
|
class MllamaTextRMSNorm(nn.Module):
    """RMS normalisation over the last dimension, computed in float32."""

    def __init__(self, weight, eps):
        super().__init__()
        self.variance_epsilon = eps
        self.weight = weight

    @classmethod
    def load(cls, *, prefix, weights, eps):
        """Build the norm from the checkpoint tensor named ``{prefix}.weight``."""
        scale = weights.get_tensor(f"{prefix}.weight")
        return cls(weight=nn.Parameter(scale, requires_grad=False), eps=eps)

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` and scale by ``weight``.

        The statistics are computed in float32; the normalised values are cast
        back to the input dtype before the weight is applied.
        """
        input_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalised = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalised.to(input_dtype)

    def extra_repr(self):
        shape = tuple(self.weight.shape)
        return f"{shape}, eps={self.variance_epsilon}"
|
||||||
|
|
||||||
|
|
||||||
|
class FlashMllamaForConditionalGeneration(nn.Module):
    """Mllama: vision tower, multi-modal projector and a Llama text model."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        # The vision tower is loaded unquantized; the text model inherits the
        # top-level quantization and speculator settings.
        config.vision_config.quantize = None
        config.vision_config.speculator = config.speculator
        config.text_config.quantize = config.quantize
        config.text_config.speculator = config.speculator
        config.text_config._attn_implementation = "sdpa"
        self.hidden_size = config.text_config.hidden_size
        self.vision_model = MllamaVisionModel(
            prefix="vision_model", config=config.vision_config, weights=weights
        )
        self.multi_modal_projector = FastLinear.load(
            prefix="multi_modal_projector", config=config, weights=weights, bias=True
        )
        self.text_model = FlashLlamaForCausalLM(
            prefix="language_model", config=config.text_config, weights=weights
        )
        self.config = config
        self.dtype = weights.dtype
        self.device = weights.device

    def vision_forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask):
        """Encode images and project them to the text hidden size.

        Returns:
            Tensor of shape ``(pixel_values.shape[0], -1, hidden_size)``.

        Raises:
            ValueError: if ``aspect_ratio_ids`` is ``None``.
        """
        if aspect_ratio_ids is None:
            raise ValueError(
                "`aspect_ratio_ids` must be provided if `pixel_values` is provided"
            )
        batch_size = pixel_values.shape[0]
        vision_states = self.vision_model(
            pixel_values, aspect_ratio_ids, aspect_ratio_mask
        )
        cross_attention_states = self.multi_modal_projector(vision_states).reshape(
            -1, vision_states.shape[-2], self.hidden_size
        )
        _, _, h = cross_attention_states.shape
        cross_attention_states = cross_attention_states.view(batch_size, -1, h)
        return cross_attention_states

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor],
        adapter_data: Optional[torch.Tensor] = None,
        # XXX: Putting these as optional so that the warmup calls can go through.
        cross_attention_states: Optional[torch.Tensor] = None,
        image_indices=None,
    ):
        """Run the text model, optionally with cross-attention to image states.

        When ``cross_attention_states`` is given it is repacked as the 4-tuple
        ``(states, cu_seqlen_q, cu_seqlen_k, indices)`` that the text model
        receives as its ``cross_attention_states`` argument.

        Args:
            cross_attention_states: projected vision states, or ``None`` for a
                text-only call. Dim 0 is read as the image count and dim 1 as
                the key length per image.
            image_indices: positions in the batch of the sequences that carry
                an image; required whenever ``cross_attention_states`` is given.

        Raises:
            ValueError: if ``cross_attention_states`` is given without
                ``image_indices``.
        """
        if cross_attention_states is not None:
            if image_indices is None:
                raise ValueError(
                    "`image_indices` must be provided if `cross_attention_states` is provided"
                )
            seqlen_q = len(image_indices)
            n_images = cross_attention_states.shape[0]
            seqlen_k = cross_attention_states.shape[1]
            device = cross_attention_states.device

            # Every image contributes the same number of keys, so the
            # cumulative key lengths are multiples of seqlen_k in both the
            # prefill and the decode case.
            cu_seqlen_k = (
                torch.arange(
                    n_images + 1,
                    device=device,
                    dtype=torch.int32,
                )
                * seqlen_k
            )

            if cu_seqlen_prefill is not None:
                # Prefill: collect the flat token positions of every sequence
                # that has an image, and their cumulative lengths.
                offset = 0
                cu_q = []
                indices = []
                for index in image_indices:
                    cu_q.append(offset)
                    length = seqlen.input_lengths[index].item()
                    assert index < seqlen.cu_seqlen_q.shape[0]
                    input_ids_offset = seqlen.cu_seqlen_q[index]
                    indices.extend(range(input_ids_offset, input_ids_offset + length))
                    offset += length
                cu_q.append(offset)
                # Build as int32 directly; the legacy torch.Tensor(...)
                # constructor would route the offsets through float32.
                cu_seqlen_q = torch.tensor(cu_q, dtype=torch.int32).to(device=device)

                assert max(indices) < input_ids.shape[0]
            else:
                # Decode: one query token per image-bearing sequence.
                cu_seqlen_q = torch.arange(
                    seqlen_q + 1, device=device, dtype=torch.int32
                )
                indices = image_indices[:]

            cross_attention_states = (
                cross_attention_states,
                cu_seqlen_q,
                cu_seqlen_k,
                indices,
            )

        outputs = self.text_model(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            slots=slots,
            seqlen=seqlen,
            hpu_attention_meta=hpu_attention_meta,
            lm_head_indices=lm_head_indices,
            adapter_data=adapter_data,
            cross_attention_states=cross_attention_states,
        )

        return outputs
|
@ -29,8 +29,8 @@ from typing import Optional, List, Tuple
|
|||||||
from text_generation_server.layers.attention import (
|
from text_generation_server.layers.attention import (
|
||||||
paged_attention,
|
paged_attention,
|
||||||
attention,
|
attention,
|
||||||
reshape_and_cache,
|
|
||||||
Seqlen,
|
Seqlen,
|
||||||
|
HPUPagedAttentionMetadata,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
@ -39,7 +39,7 @@ from text_generation_server.layers import (
|
|||||||
SpeculativeHead,
|
SpeculativeHead,
|
||||||
get_linear,
|
get_linear,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||||
from text_generation_server.layers.layernorm import (
|
from text_generation_server.layers.layernorm import (
|
||||||
FastLayerNorm,
|
FastLayerNorm,
|
||||||
)
|
)
|
||||||
@ -132,6 +132,7 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||||||
head_size=self.head_size,
|
head_size=self.head_size,
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
)
|
)
|
||||||
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
self.dense = load_row(
|
self.dense = load_row(
|
||||||
config, prefix=f"{prefix}.dense", weights=weights, bias=True
|
config, prefix=f"{prefix}.dense", weights=weights, bias=True
|
||||||
)
|
)
|
||||||
@ -146,10 +147,9 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
):
|
):
|
||||||
qkv = self.query_key_value(hidden_states)
|
qkv = self.query_key_value(hidden_states)
|
||||||
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
|
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
|
||||||
@ -165,30 +165,35 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||||||
qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
|
qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
|
||||||
qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)
|
qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)
|
||||||
|
|
||||||
reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots)
|
kv_cache.store(
|
||||||
|
key=qkv[:, 1],
|
||||||
|
value=qkv[:, 2],
|
||||||
|
slots=slots,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
)
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# sdpa
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
qkv[:, 0],
|
query=qkv[:, 0],
|
||||||
kv_cache[0] if PREFILL_IN_KV_CACHE else qkv[:, 1],
|
key=qkv[:, 1],
|
||||||
kv_cache[1] if PREFILL_IN_KV_CACHE else qkv[:, 2],
|
value=qkv[:, 2],
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
kv_scales=self.kv_scales,
|
||||||
self.softmax_scale,
|
seqlen=seqlen,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
qkv[:, 0],
|
qkv[:, 0],
|
||||||
kv_cache[0],
|
kv_cache,
|
||||||
kv_cache[1],
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
kv_scales=self.kv_scales,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
|
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
|
||||||
@ -255,10 +260,9 @@ class FlashNeoXLayer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
):
|
):
|
||||||
if self.use_parallel_residual:
|
if self.use_parallel_residual:
|
||||||
ln1_hidden_states, _ = self.input_layernorm(hidden_states)
|
ln1_hidden_states, _ = self.input_layernorm(hidden_states)
|
||||||
@ -269,10 +273,9 @@ class FlashNeoXLayer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
|
ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
|
||||||
@ -293,10 +296,9 @@ class FlashNeoXLayer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, residual = self.post_attention_layernorm(
|
hidden_states, residual = self.post_attention_layernorm(
|
||||||
@ -347,18 +349,15 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.embed_in(input_ids)
|
hidden_states = self.embed_in(input_ids)
|
||||||
|
|
||||||
# Get rotary cos and sin for this forward
|
# Get rotary cos and sin for this forward
|
||||||
# Avoid to index in each layer
|
# Avoid to index in each layer
|
||||||
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(
|
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(position_ids)
|
||||||
position_ids, max_s, hidden_states.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
residual = None
|
residual = None
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
@ -369,10 +368,9 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache[i],
|
kv_cache[i],
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.final_layer_norm(hidden_states, residual)
|
hidden_states, _ = self.final_layer_norm(hidden_states, residual)
|
||||||
@ -401,11 +399,9 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
adapter_data: Optional[torch.Tensor] = None,
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@ -414,10 +410,9 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
|
|||||||
position_ids,
|
position_ids,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
@ -19,7 +19,7 @@ from torch import nn
|
|||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
|
from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
|
||||||
from text_generation_server.layers.attention import Seqlen
|
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
|
||||||
from text_generation_server.models.custom_modeling.vlm import (
|
from text_generation_server.models.custom_modeling.vlm import (
|
||||||
load_text_model,
|
load_text_model,
|
||||||
load_vision_model,
|
load_vision_model,
|
||||||
@ -69,22 +69,20 @@ class PaliGemmaForConditionalGeneration(nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
prefill_cache_indices: Optional[torch.Tensor] = None,
|
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
pixel_values: torch.FloatTensor = None,
|
pixel_values: torch.FloatTensor = None,
|
||||||
# Unused here
|
# Unused here
|
||||||
pixel_attention_mask: Optional[torch.BoolTensor] = None,
|
pixel_attention_mask: Optional[torch.BoolTensor] = None,
|
||||||
image_sizes: Optional[torch.Tensor] = None,
|
image_sizes: Optional[torch.Tensor] = None,
|
||||||
adapter_data: Optional[torch.Tensor] = None,
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
|
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||||
# TODO This is odd but apparently pali gemma position ids start at 1.
|
# TODO This is odd but apparently pali gemma position ids start at 1.
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
max_s += 1
|
|
||||||
position_ids += 1
|
position_ids += 1
|
||||||
|
|
||||||
if pixel_values is not None:
|
if pixel_values is not None:
|
||||||
@ -106,10 +104,10 @@ class PaliGemmaForConditionalGeneration(nn.Module):
|
|||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||||
kv_cache=kv_cache,
|
kv_cache=kv_cache,
|
||||||
block_tables=block_tables,
|
|
||||||
slots=slots,
|
slots=slots,
|
||||||
seqlen=seqlen,
|
seqlen=seqlen,
|
||||||
max_s=max_s,
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
|
adapter_data=adapter_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
|
@ -9,8 +9,8 @@ from typing import Optional, List, Tuple
|
|||||||
from text_generation_server.layers.attention import (
|
from text_generation_server.layers.attention import (
|
||||||
paged_attention,
|
paged_attention,
|
||||||
attention,
|
attention,
|
||||||
reshape_and_cache,
|
|
||||||
Seqlen,
|
Seqlen,
|
||||||
|
HPUPagedAttentionMetadata,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
@ -19,7 +19,7 @@ from text_generation_server.layers import (
|
|||||||
SpeculativeHead,
|
SpeculativeHead,
|
||||||
get_linear,
|
get_linear,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||||
from text_generation_server.layers.layernorm import (
|
from text_generation_server.layers.layernorm import (
|
||||||
FastLayerNorm,
|
FastLayerNorm,
|
||||||
)
|
)
|
||||||
@ -90,7 +90,7 @@ def _load_gqa(config, prefix: str, weights):
|
|||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.quantize not in ["gptq", "awq", "marlin"]:
|
if config.quantize not in ["gptq", "awq"]:
|
||||||
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
|
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||||
|
|
||||||
head_size = config.hidden_size // config.num_attention_heads
|
head_size = config.hidden_size // config.num_attention_heads
|
||||||
@ -139,6 +139,7 @@ class FlashPhiAttention(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.query_key_value = load_attention(config, prefix, weights)
|
self.query_key_value = load_attention(config, prefix, weights)
|
||||||
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
|
|
||||||
# in llama the dense layer is called "o_proj" and has bias=False
|
# in llama the dense layer is called "o_proj" and has bias=False
|
||||||
self.dense = TensorParallelRowLinear.load(
|
self.dense = TensorParallelRowLinear.load(
|
||||||
@ -159,10 +160,9 @@ class FlashPhiAttention(torch.nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
):
|
):
|
||||||
# Compute query, key, value and split
|
# Compute query, key, value and split
|
||||||
qkv = self.query_key_value(hidden_states)
|
qkv = self.query_key_value(hidden_states)
|
||||||
@ -188,29 +188,34 @@ class FlashPhiAttention(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Reshape key and value and cache
|
# Reshape key and value and cache
|
||||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
kv_cache.store(
|
||||||
|
key=kv[:, 0],
|
||||||
|
value=kv[:, 1],
|
||||||
|
slots=slots,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
)
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
|
key=kv[:, 0],
|
||||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
|
value=kv[:, 1],
|
||||||
seqlen,
|
kv_scales=self.kv_scales,
|
||||||
block_tables,
|
kv_cache=kv_cache,
|
||||||
self.softmax_scale,
|
seqlen=seqlen,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache,
|
||||||
kv_cache[1],
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
kv_scales=self.kv_scales,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
|
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
|
||||||
@ -274,10 +279,9 @@ class FlashPhiLayer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
):
|
):
|
||||||
hidden_states, res = self.input_layernorm(hidden_states, residual)
|
hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||||
# Self Attention
|
# Self Attention
|
||||||
@ -287,10 +291,9 @@ class FlashPhiLayer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = self.resid_dropout(attn_output).add(
|
hidden_states = self.resid_dropout(attn_output).add(
|
||||||
@ -339,18 +342,15 @@ class FlashPhiModel(torch.nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
# Get rotary cos and sin for this forward
|
# Get rotary cos and sin for this forward
|
||||||
# Avoid to index in each layer
|
# Avoid to index in each layer
|
||||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||||
position_ids, max_s, hidden_states.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
residual = None
|
residual = None
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
@ -361,10 +361,9 @@ class FlashPhiModel(torch.nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache[i],
|
kv_cache[i],
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
@ -394,11 +393,9 @@ class FlashPhiForCausalLM(torch.nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
adapter_data: Optional[torch.Tensor] = None,
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@ -407,10 +404,9 @@ class FlashPhiForCausalLM(torch.nn.Module):
|
|||||||
position_ids,
|
position_ids,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
@ -8,8 +8,8 @@ from typing import Optional, List, Tuple
|
|||||||
from text_generation_server.layers.attention import (
|
from text_generation_server.layers.attention import (
|
||||||
paged_attention,
|
paged_attention,
|
||||||
attention,
|
attention,
|
||||||
reshape_and_cache,
|
|
||||||
Seqlen,
|
Seqlen,
|
||||||
|
HPUPagedAttentionMetadata,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
@ -17,7 +17,7 @@ from text_generation_server.layers import (
|
|||||||
TensorParallelEmbedding,
|
TensorParallelEmbedding,
|
||||||
SpeculativeHead,
|
SpeculativeHead,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||||
from text_generation_server.layers.layernorm import (
|
from text_generation_server.layers.layernorm import (
|
||||||
FastRMSNorm,
|
FastRMSNorm,
|
||||||
@ -86,6 +86,8 @@ class Qwen2Attention(torch.nn.Module):
|
|||||||
|
|
||||||
self.query_key_value = load_attention(config, prefix, weights)
|
self.query_key_value = load_attention(config, prefix, weights)
|
||||||
|
|
||||||
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
|
|
||||||
self.o_proj = TensorParallelRowLinear.load(
|
self.o_proj = TensorParallelRowLinear.load(
|
||||||
config,
|
config,
|
||||||
prefix=f"{prefix}.o_proj",
|
prefix=f"{prefix}.o_proj",
|
||||||
@ -104,11 +106,9 @@ class Qwen2Attention(torch.nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
prefill_cache_indices,
|
|
||||||
):
|
):
|
||||||
qkv = self.query_key_value(hidden_states)
|
qkv = self.query_key_value(hidden_states)
|
||||||
query, kv = qkv.split(
|
query, kv = qkv.split(
|
||||||
@ -123,38 +123,36 @@ class Qwen2Attention(torch.nn.Module):
|
|||||||
|
|
||||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
if prefill_cache_indices is not None:
|
kv_cache.store(
|
||||||
kv_to_cache = kv[prefill_cache_indices]
|
key=kv[:, 0],
|
||||||
else:
|
value=kv[:, 1],
|
||||||
kv_to_cache = kv
|
slots=slots,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
reshape_and_cache(
|
|
||||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# sdpa
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
key=kv[:, 0],
|
||||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
value=kv[:, 1],
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
kv_scales=self.kv_scales,
|
||||||
self.softmax_scale,
|
seqlen=seqlen,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
window_size_left=self.max_past,
|
window_size_left=self.max_past,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache,
|
||||||
kv_cache[1],
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
kv_scales=self.kv_scales,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||||
@ -223,13 +221,11 @@ class Qwen2Layer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
prefill_cache_indices,
|
|
||||||
):
|
):
|
||||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
normed_hidden_states, residual = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
# Self Attention
|
# Self Attention
|
||||||
attn_output = self.self_attn(
|
attn_output = self.self_attn(
|
||||||
@ -238,21 +234,17 @@ class Qwen2Layer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
prefill_cache_indices,
|
|
||||||
)
|
)
|
||||||
|
hidden_states = attn_output + residual
|
||||||
|
|
||||||
# faster post attention rms norm
|
# faster post attention rms norm
|
||||||
normed_attn_res_output, attn_res = self.post_attention_layernorm(
|
hidden_states, residual = self.post_attention_layernorm(hidden_states)
|
||||||
attn_output, res
|
mlp_output = self.mlp(hidden_states)
|
||||||
)
|
hidden_states = mlp_output + residual
|
||||||
|
return hidden_states
|
||||||
mlp_output = self.mlp(normed_attn_res_output)
|
|
||||||
|
|
||||||
return mlp_output, attn_res
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen2Model(torch.nn.Module):
|
class Qwen2Model(torch.nn.Module):
|
||||||
@ -264,9 +256,6 @@ class Qwen2Model(torch.nn.Module):
|
|||||||
process_group = weights.process_group
|
process_group = weights.process_group
|
||||||
self.tp_rank = process_group.rank()
|
self.tp_rank = process_group.rank()
|
||||||
self.tp_world_size = process_group.size()
|
self.tp_world_size = process_group.size()
|
||||||
self.embed_tokens = TensorParallelEmbedding(
|
|
||||||
prefix=f"{prefix}.embed_tokens", weights=weights
|
|
||||||
)
|
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
Qwen2Layer(
|
Qwen2Layer(
|
||||||
@ -290,42 +279,35 @@ class Qwen2Model(torch.nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
inputs_embeds: torch.Tensor,
|
||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
true_max_s: int,
|
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
# Get rotary cos and sin for this forward
|
|
||||||
# Avoid to index in each layer
|
|
||||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
||||||
position_ids, true_max_s, hidden_states.dtype
|
position_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
residual = None
|
residual = None
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
hidden_states, residual = layer(
|
hidden_states = layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
residual,
|
residual,
|
||||||
cos,
|
cos,
|
||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache[i],
|
kv_cache[i],
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
prefill_cache_indices,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@ -346,6 +328,12 @@ class Qwen2ForCausalLM(torch.nn.Module):
|
|||||||
prefix=f"{prefix}.{suffix}" if prefix else suffix,
|
prefix=f"{prefix}.{suffix}" if prefix else suffix,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.embed_tokens = TensorParallelEmbedding(
|
||||||
|
prefix=f"{prefix}.embed_tokens" if prefix else "model.embed_tokens",
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
|
||||||
self.max_past = config.sliding_window
|
self.max_past = config.sliding_window
|
||||||
self.max_past_tensor = (
|
self.max_past_tensor = (
|
||||||
torch.tensor(config.sliding_window, device=weights.device)
|
torch.tensor(config.sliding_window, device=weights.device)
|
||||||
@ -359,34 +347,23 @@ class Qwen2ForCausalLM(torch.nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
prefill_cache_indices: Optional[torch.Tensor] = None,
|
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
adapter_data: Optional[torch.Tensor] = None,
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
true_max_s = max_s
|
|
||||||
if prefill_cache_indices is not None:
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
# Slots also need to be sliced as it has the same size as the whole kv tensor
|
|
||||||
slots = slots[prefill_cache_indices]
|
|
||||||
elif self.max_past is not None:
|
|
||||||
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
|
||||||
# kernel requires the true values
|
|
||||||
seqlen = seqlen.clamp(max=self.max_past_tensor)
|
|
||||||
|
|
||||||
hidden_states = self.model(
|
hidden_states = self.model(
|
||||||
input_ids,
|
inputs_embeds,
|
||||||
position_ids,
|
position_ids,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
true_max_s,
|
|
||||||
prefill_cache_indices,
|
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
@ -12,14 +12,14 @@ from text_generation_server.layers import (
|
|||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
get_linear,
|
get_linear,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||||
from text_generation_server.layers.layernorm import FastLayerNorm
|
from text_generation_server.layers.layernorm import FastLayerNorm
|
||||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||||
from text_generation_server.layers.attention import (
|
from text_generation_server.layers.attention import (
|
||||||
attention,
|
attention,
|
||||||
paged_attention,
|
paged_attention,
|
||||||
reshape_and_cache,
|
|
||||||
Seqlen,
|
Seqlen,
|
||||||
|
HPUPagedAttentionMetadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -79,6 +79,7 @@ class RWConfig(PretrainedConfig):
|
|||||||
self.alibi = False
|
self.alibi = False
|
||||||
self.rotary = True
|
self.rotary = True
|
||||||
self.rope_theta = rope_theta
|
self.rope_theta = rope_theta
|
||||||
|
self.max_position_embeddings = 2048
|
||||||
|
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
# Backward compatibility with n_embed kwarg
|
# Backward compatibility with n_embed kwarg
|
||||||
@ -160,6 +161,7 @@ class FlashRWAttention(torch.nn.Module):
|
|||||||
weights=weights,
|
weights=weights,
|
||||||
bias=config.bias,
|
bias=config.bias,
|
||||||
)
|
)
|
||||||
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
self.dense = load_row(
|
self.dense = load_row(
|
||||||
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
|
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
|
||||||
)
|
)
|
||||||
@ -180,10 +182,9 @@ class FlashRWAttention(torch.nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
):
|
):
|
||||||
qkv = self.query_key_value(hidden_states)
|
qkv = self.query_key_value(hidden_states)
|
||||||
|
|
||||||
@ -200,30 +201,35 @@ class FlashRWAttention(torch.nn.Module):
|
|||||||
# Inplace rotary
|
# Inplace rotary
|
||||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
kv_cache.store(
|
||||||
|
key=kv[:, 0],
|
||||||
|
value=kv[:, 1],
|
||||||
|
slots=slots,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
)
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# sdpa
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
|
key=kv[:, 0],
|
||||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
|
value=kv[:, 1],
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
kv_scales=self.kv_scales,
|
||||||
self.softmax_scale,
|
seqlen=seqlen,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache,
|
||||||
kv_cache[1],
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
kv_scales=self.kv_scales,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
|
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
|
||||||
@ -278,6 +284,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||||||
weights=weights,
|
weights=weights,
|
||||||
bias=config.bias,
|
bias=config.bias,
|
||||||
)
|
)
|
||||||
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
self.dense = load_row(
|
self.dense = load_row(
|
||||||
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
|
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
|
||||||
)
|
)
|
||||||
@ -293,10 +300,9 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
):
|
):
|
||||||
qkv = self.query_key_value(hidden_states)
|
qkv = self.query_key_value(hidden_states)
|
||||||
qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
|
qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
|
||||||
@ -312,36 +318,35 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||||||
# Inplace rotary
|
# Inplace rotary
|
||||||
self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)
|
self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)
|
||||||
|
|
||||||
reshape_and_cache(
|
kv_cache.store(
|
||||||
kv[:, :, 0].contiguous(),
|
key=kv[:, :, 0].contiguous(),
|
||||||
kv[:, :, 1].contiguous(),
|
value=kv[:, :, 1].contiguous(),
|
||||||
kv_cache[0],
|
slots=slots,
|
||||||
kv_cache[1],
|
kv_scales=self.kv_scales,
|
||||||
slots,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(),
|
key=kv[:, :, 0],
|
||||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(),
|
value=kv[:, :, 1],
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
kv_scales=self.kv_scales,
|
||||||
self.softmax_scale,
|
seqlen=seqlen,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache,
|
||||||
kv_cache[1],
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
kv_scales=self.kv_scales,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.dense(
|
return self.dense(
|
||||||
@ -424,10 +429,9 @@ class FlashRWLayer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
):
|
):
|
||||||
if self.parallel_attn:
|
if self.parallel_attn:
|
||||||
ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||||
@ -438,10 +442,9 @@ class FlashRWLayer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
mlp_output = self.mlp(ln_hidden_states)
|
mlp_output = self.mlp(ln_hidden_states)
|
||||||
@ -460,10 +463,9 @@ class FlashRWLayer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.post_attention_layernorm is not None:
|
if self.post_attention_layernorm is not None:
|
||||||
@ -547,10 +549,9 @@ class FlashRWLargeLayer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
):
|
):
|
||||||
# Layer norm.
|
# Layer norm.
|
||||||
ln_attn, ln_mlp, residual = self.ln_layer(hidden_states, residual)
|
ln_attn, ln_mlp, residual = self.ln_layer(hidden_states, residual)
|
||||||
@ -562,10 +563,9 @@ class FlashRWLargeLayer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
# MLP.
|
# MLP.
|
||||||
@ -623,18 +623,15 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.word_embeddings(input_ids)
|
hidden_states = self.word_embeddings(input_ids)
|
||||||
|
|
||||||
# Get rotary cos and sin for this forward
|
# Get rotary cos and sin for this forward
|
||||||
# Avoid to index in each layer
|
# Avoid to index in each layer
|
||||||
cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(
|
cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(position_ids)
|
||||||
position_ids, max_s, hidden_states.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
residual = None
|
residual = None
|
||||||
for i, layer in enumerate(self.h):
|
for i, layer in enumerate(self.h):
|
||||||
@ -645,10 +642,9 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache[i],
|
kv_cache[i],
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.ln_f(hidden_states, residual)
|
hidden_states, _ = self.ln_f(hidden_states, residual)
|
||||||
@ -675,11 +671,9 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
adapter_data: Optional[torch.Tensor] = None,
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@ -688,10 +682,9 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
|
|||||||
position_ids,
|
position_ids,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
@ -8,8 +8,8 @@ from typing import Optional, List, Tuple
|
|||||||
from text_generation_server.layers.attention import (
|
from text_generation_server.layers.attention import (
|
||||||
paged_attention,
|
paged_attention,
|
||||||
attention,
|
attention,
|
||||||
reshape_and_cache,
|
|
||||||
Seqlen,
|
Seqlen,
|
||||||
|
HPUPagedAttentionMetadata,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
@ -18,7 +18,7 @@ from text_generation_server.layers import (
|
|||||||
TensorParallelEmbedding,
|
TensorParallelEmbedding,
|
||||||
get_linear,
|
get_linear,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||||
from text_generation_server.layers.gptq import GPTQWeightsLoader
|
from text_generation_server.layers.gptq import GPTQWeightsLoader
|
||||||
from text_generation_server.layers.layernorm import (
|
from text_generation_server.layers.layernorm import (
|
||||||
FastLayerNorm,
|
FastLayerNorm,
|
||||||
@ -32,10 +32,6 @@ def load_multi_mqa(
|
|||||||
return _load_multi_mqa_gptq(
|
return _load_multi_mqa_gptq(
|
||||||
config, prefix, weights, bias, head_size, num_heads, hidden_size
|
config, prefix, weights, bias, head_size, num_heads, hidden_size
|
||||||
)
|
)
|
||||||
elif config.quantize == "marlin":
|
|
||||||
raise RuntimeError(
|
|
||||||
"santacoder models with marlin quantization are not yet supported"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return _load_multi_mqa(
|
return _load_multi_mqa(
|
||||||
config, prefix, weights, bias, head_size, num_heads, hidden_size
|
config, prefix, weights, bias, head_size, num_heads, hidden_size
|
||||||
@ -259,6 +255,7 @@ class FlashMQAttention(torch.nn.Module):
|
|||||||
self.c_proj = load_row(
|
self.c_proj = load_row(
|
||||||
config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
|
config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
|
||||||
)
|
)
|
||||||
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
self.kv_head_mapping = torch.zeros(
|
self.kv_head_mapping = torch.zeros(
|
||||||
self.num_heads, dtype=torch.int32, device=weights.device
|
self.num_heads, dtype=torch.int32, device=weights.device
|
||||||
)
|
)
|
||||||
@ -268,10 +265,9 @@ class FlashMQAttention(torch.nn.Module):
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
):
|
):
|
||||||
qkv = self.c_attn(hidden_states)
|
qkv = self.c_attn(hidden_states)
|
||||||
|
|
||||||
@ -284,32 +280,35 @@ class FlashMQAttention(torch.nn.Module):
|
|||||||
query = query.view(-1, self.num_heads, self.head_size)
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
key_value = key_value.view(-1, 2, 1, self.head_size)
|
key_value = key_value.view(-1, 2, 1, self.head_size)
|
||||||
|
|
||||||
reshape_and_cache(
|
kv_cache.store(
|
||||||
key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
|
key=key_value[:, 0],
|
||||||
|
value=key_value[:, 1],
|
||||||
|
slots=slots,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# sdpa
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache[0] if PREFILL_IN_KV_CACHE else key_value[:, 0],
|
key=key_value[:, 0],
|
||||||
kv_cache[1] if PREFILL_IN_KV_CACHE else key_value[:, 1],
|
value=key_value[:, 1],
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
kv_scales=self.kv_scales,
|
||||||
self.softmax_scale,
|
seqlen=seqlen,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache,
|
||||||
kv_cache[1],
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
kv_scales=self.kv_scales,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||||
@ -371,20 +370,18 @@ class Block(nn.Module):
|
|||||||
residual,
|
residual,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
):
|
):
|
||||||
hidden_states, residual = self.ln_1(hidden_states, residual)
|
hidden_states, residual = self.ln_1(hidden_states, residual)
|
||||||
hidden_states = self.self_attn(
|
hidden_states = self.self_attn(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, residual = self.ln_2(hidden_states, residual)
|
hidden_states, residual = self.ln_2(hidden_states, residual)
|
||||||
@ -435,10 +432,9 @@ class FlashSantacoderModel(nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
|
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
|
||||||
|
|
||||||
@ -452,10 +448,9 @@ class FlashSantacoderModel(nn.Module):
|
|||||||
residual,
|
residual,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache[i],
|
kv_cache[i],
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.ln_f(hidden_states, residual)
|
hidden_states, _ = self.ln_f(hidden_states, residual)
|
||||||
@ -484,11 +479,9 @@ class FlashSantacoderForCausalLM(nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
adapter_data: Optional[torch.Tensor] = None,
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@ -497,10 +490,9 @@ class FlashSantacoderForCausalLM(nn.Module):
|
|||||||
position_ids,
|
position_ids,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
@ -29,17 +29,19 @@ from typing import Optional, List, Tuple
|
|||||||
from text_generation_server.layers.attention import (
|
from text_generation_server.layers.attention import (
|
||||||
paged_attention,
|
paged_attention,
|
||||||
attention,
|
attention,
|
||||||
reshape_and_cache,
|
|
||||||
Seqlen,
|
Seqlen,
|
||||||
|
HPUPagedAttentionMetadata,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
|
TensorParallelMultiAdapterLinear,
|
||||||
|
TensorParallelAdapterRowLinear,
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
TensorParallelEmbedding,
|
TensorParallelEmbedding,
|
||||||
SpeculativeHead,
|
SpeculativeHead,
|
||||||
get_linear,
|
get_linear,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||||
from text_generation_server.layers.layernorm import (
|
from text_generation_server.layers.layernorm import (
|
||||||
FastLayerNorm,
|
FastLayerNorm,
|
||||||
FastRMSNorm,
|
FastRMSNorm,
|
||||||
@ -110,17 +112,31 @@ class Starcoder2Config(PretrainedConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_attention(config, prefix, weights):
|
def load_attention(config, prefix, weights, layer_id):
|
||||||
|
prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
|
||||||
|
head_size = config.hidden_size // config.num_attention_heads
|
||||||
|
sizes = [
|
||||||
|
head_size * config.num_attention_heads,
|
||||||
|
head_size * config.num_key_value_heads,
|
||||||
|
head_size * config.num_key_value_heads,
|
||||||
|
]
|
||||||
if config.num_attention_heads != config.num_key_value_heads:
|
if config.num_attention_heads != config.num_key_value_heads:
|
||||||
return _load_gqa(config, prefix, weights)
|
base_layer = _load_gqa(config, prefix, weights)
|
||||||
else:
|
else:
|
||||||
return TensorParallelColumnLinear.load_multi(
|
base_layer = TensorParallelColumnLinear.load_multi(
|
||||||
config,
|
config,
|
||||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
prefixes=prefixes,
|
||||||
dim=0,
|
dim=0,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
bias=config.use_bias,
|
bias=config.use_bias,
|
||||||
)
|
)
|
||||||
|
return TensorParallelMultiAdapterLinear.load(
|
||||||
|
base_layer=base_layer,
|
||||||
|
layer_id=layer_id,
|
||||||
|
layer_names=prefixes,
|
||||||
|
sizes=sizes,
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _load_gqa(config, prefix: str, weights):
|
def _load_gqa(config, prefix: str, weights):
|
||||||
@ -158,6 +174,7 @@ def _load_gqa(config, prefix: str, weights):
|
|||||||
class Starcoder2Attention(torch.nn.Module):
|
class Starcoder2Attention(torch.nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
index: int,
|
||||||
prefix: str,
|
prefix: str,
|
||||||
config,
|
config,
|
||||||
weights,
|
weights,
|
||||||
@ -189,14 +206,23 @@ class Starcoder2Attention(torch.nn.Module):
|
|||||||
config.num_key_value_heads // weights.process_group.size()
|
config.num_key_value_heads // weights.process_group.size()
|
||||||
)
|
)
|
||||||
|
|
||||||
self.query_key_value = load_attention(config, prefix, weights)
|
self.query_key_value = load_attention(config, prefix, weights, index)
|
||||||
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
|
|
||||||
self.o_proj = TensorParallelRowLinear.load(
|
o_proj = TensorParallelRowLinear.load(
|
||||||
config,
|
config,
|
||||||
prefix=f"{prefix}.o_proj",
|
prefix=f"{prefix}.o_proj",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
bias=config.use_bias,
|
bias=getattr(config, "use_bias", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.o_proj = TensorParallelAdapterRowLinear.load(
|
||||||
|
o_proj,
|
||||||
|
index,
|
||||||
|
"o_proj",
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
|
||||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||||
self.kv_head_mapping = torch.arange(
|
self.kv_head_mapping = torch.arange(
|
||||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||||
@ -209,13 +235,12 @@ class Starcoder2Attention(torch.nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
adapter_data,
|
||||||
prefill_cache_indices,
|
hpu_attention_meta,
|
||||||
):
|
):
|
||||||
qkv = self.query_key_value(hidden_states)
|
qkv = self.query_key_value(hidden_states, adapter_data)
|
||||||
query, kv = qkv.split(
|
query, kv = qkv.split(
|
||||||
[
|
[
|
||||||
self.head_size * self.num_heads,
|
self.head_size * self.num_heads,
|
||||||
@ -228,45 +253,45 @@ class Starcoder2Attention(torch.nn.Module):
|
|||||||
|
|
||||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
if prefill_cache_indices is not None:
|
kv_cache.store(
|
||||||
kv_to_cache = kv[prefill_cache_indices]
|
key=kv[:, 0],
|
||||||
else:
|
value=kv[:, 1],
|
||||||
kv_to_cache = kv
|
slots=slots,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
reshape_and_cache(
|
|
||||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# sdpa
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
key=kv[:, 0],
|
||||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
value=kv[:, 1],
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
kv_scales=self.kv_scales,
|
||||||
self.softmax_scale,
|
seqlen=seqlen,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
window_size_left=self.max_past,
|
window_size_left=self.max_past,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache,
|
||||||
kv_cache[1],
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
kv_scales=self.kv_scales,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
return self.o_proj(
|
||||||
|
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Starcoder2MLP(nn.Module):
|
class Starcoder2MLP(nn.Module):
|
||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, prefix, config, weights, index):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
act = config.hidden_act
|
act = config.hidden_act
|
||||||
self.act = (
|
self.act = (
|
||||||
@ -280,27 +305,42 @@ class Starcoder2MLP(nn.Module):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
# Fuse gate and up proj
|
# Fuse gate and up proj
|
||||||
self.c_fc = TensorParallelColumnLinear.load(
|
c_fc = TensorParallelColumnLinear.load(
|
||||||
config,
|
config,
|
||||||
prefix=f"{prefix}.c_fc",
|
prefix=f"{prefix}.c_fc",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
bias=config.use_bias,
|
bias=config.use_bias,
|
||||||
)
|
)
|
||||||
self.c_proj = TensorParallelRowLinear.load(
|
c_proj = TensorParallelRowLinear.load(
|
||||||
config,
|
config,
|
||||||
prefix=f"{prefix}.c_proj",
|
prefix=f"{prefix}.c_proj",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
bias=config.use_bias,
|
bias=config.use_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
self.c_fc = TensorParallelMultiAdapterLinear.load(
|
||||||
hidden_states = self.c_fc(hidden_states)
|
c_fc,
|
||||||
|
layer_id=index,
|
||||||
|
layer_names=[f"{prefix}.c_fc"],
|
||||||
|
sizes=[config.intermediate_size, config.intermediate_size],
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.c_proj = TensorParallelAdapterRowLinear.load(
|
||||||
|
c_proj,
|
||||||
|
index,
|
||||||
|
"c_proj",
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, adapter_data):
|
||||||
|
hidden_states = self.c_fc(hidden_states, adapter_data)
|
||||||
hidden_states = self.act(hidden_states)
|
hidden_states = self.act(hidden_states)
|
||||||
return self.c_proj(hidden_states)
|
return self.c_proj(hidden_states, adapter_data)
|
||||||
|
|
||||||
|
|
||||||
class Starcoder2GatedMLP(nn.Module):
|
class Starcoder2GatedMLP(nn.Module):
|
||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, index, prefix, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
act = config.hidden_act
|
act = config.hidden_act
|
||||||
self.act = (
|
self.act = (
|
||||||
@ -314,27 +354,47 @@ class Starcoder2GatedMLP(nn.Module):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
# Fuse gate and up proj
|
# Fuse gate and up proj
|
||||||
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
|
||||||
|
sizes = [
|
||||||
|
config.intermediate_size,
|
||||||
|
config.intermediate_size,
|
||||||
|
]
|
||||||
|
gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||||
config,
|
config,
|
||||||
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
prefixes=prefixes,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
dim=0,
|
dim=0,
|
||||||
bias=config.use_bias,
|
bias=config.use_bias,
|
||||||
)
|
)
|
||||||
self.down_proj = TensorParallelRowLinear.load(
|
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
|
||||||
|
gate_up_proj,
|
||||||
|
index,
|
||||||
|
layer_names=prefixes,
|
||||||
|
sizes=sizes,
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
down_proj = TensorParallelRowLinear.load(
|
||||||
config,
|
config,
|
||||||
prefix=f"{prefix}.down_proj",
|
prefix=f"{prefix}.down_proj",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
bias=config.use_bias,
|
bias=config.use_bias,
|
||||||
)
|
)
|
||||||
|
self.down_proj = TensorParallelAdapterRowLinear.load(
|
||||||
|
down_proj,
|
||||||
|
index,
|
||||||
|
"down_proj",
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
self.intermediate_size = (
|
self.intermediate_size = (
|
||||||
config.intermediate_size // weights.process_group.size()
|
config.intermediate_size // weights.process_group.size()
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states, adapter_data):
|
||||||
gate_up_states = self.gate_up_proj(hidden_states)
|
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
||||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||||
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
return self.down_proj(
|
||||||
|
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
STARCODER2_NORMALIZATION_CLASSES = {
|
STARCODER2_NORMALIZATION_CLASSES = {
|
||||||
@ -353,11 +413,11 @@ class Starcoder2Layer(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
prefix = f"model.layers.{layer_id}"
|
prefix = f"model.layers.{layer_id}"
|
||||||
self.self_attn = Starcoder2Attention(
|
self.self_attn = Starcoder2Attention(
|
||||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
prefix=f"{prefix}.self_attn", config=config, weights=weights, index=layer_id
|
||||||
)
|
)
|
||||||
|
|
||||||
self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](
|
self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](
|
||||||
prefix=f"{prefix}.mlp", config=config, weights=weights
|
prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id
|
||||||
)
|
)
|
||||||
|
|
||||||
self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
|
self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
|
||||||
@ -379,11 +439,10 @@ class Starcoder2Layer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
adapter_data,
|
||||||
prefill_cache_indices,
|
hpu_attention_meta,
|
||||||
):
|
):
|
||||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
@ -394,11 +453,10 @@ class Starcoder2Layer(nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
adapter_data,
|
||||||
prefill_cache_indices,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
# faster post attention rms norm
|
# faster post attention rms norm
|
||||||
@ -406,7 +464,7 @@ class Starcoder2Layer(nn.Module):
|
|||||||
attn_output, res
|
attn_output, res
|
||||||
)
|
)
|
||||||
|
|
||||||
mlp_output = self.mlp(normed_attn_res_output)
|
mlp_output = self.mlp(normed_attn_res_output, adapter_data)
|
||||||
|
|
||||||
return mlp_output, attn_res
|
return mlp_output, attn_res
|
||||||
|
|
||||||
@ -447,20 +505,16 @@ class Starcoder2Model(torch.nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
adapter_data,
|
||||||
true_max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
# Get rotary cos and sin for this forward
|
# Get rotary cos and sin for this forward
|
||||||
# Avoid to index in each layer
|
# Avoid to index in each layer
|
||||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||||
position_ids, true_max_s, hidden_states.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
residual = None
|
residual = None
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
@ -471,11 +525,10 @@ class Starcoder2Model(torch.nn.Module):
|
|||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache[i],
|
kv_cache[i],
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
adapter_data,
|
||||||
prefill_cache_indices,
|
hpu_attention_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
@ -519,34 +572,22 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
adapter_data: Optional[torch.Tensor] = None,
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
true_max_s = max_s
|
|
||||||
if prefill_cache_indices is not None:
|
|
||||||
# Slots also need to be sliced as it has the same size as the whole kv tensor
|
|
||||||
slots = slots[prefill_cache_indices]
|
|
||||||
elif self.max_past is not None:
|
|
||||||
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
|
||||||
# kernel requires the true values
|
|
||||||
seqlen = seqlen.clamp(max=self.max_past_tensor)
|
|
||||||
|
|
||||||
hidden_states = self.model(
|
hidden_states = self.model(
|
||||||
input_ids,
|
input_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
adapter_data,
|
||||||
true_max_s,
|
hpu_attention_meta,
|
||||||
prefill_cache_indices,
|
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
@ -25,7 +25,7 @@ from transformers.activations import ACT2FN
|
|||||||
from text_generation_server.models.custom_modeling.vlm import (
|
from text_generation_server.models.custom_modeling.vlm import (
|
||||||
load_text_model,
|
load_text_model,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import Seqlen
|
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
|
||||||
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
|
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||||
|
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
@ -728,7 +728,8 @@ class Idefics2ForConditionalGeneration(nn.Module):
|
|||||||
):
|
):
|
||||||
"""In place merges in vision_embeddings with inputs_embeds."""
|
"""In place merges in vision_embeddings with inputs_embeds."""
|
||||||
# mask = input_ids == self.config.image_token_index
|
# mask = input_ids == self.config.image_token_index
|
||||||
mask = input_ids == self.config.image_token_id
|
# - replace `==` with torch.where to fix the issue in hpu graph
|
||||||
|
mask = torch.where(input_ids == self.config.image_token_id)
|
||||||
# Let's pray we have enabled enough slots !
|
# Let's pray we have enabled enough slots !
|
||||||
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
|
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
|
||||||
return inputs_embeds
|
return inputs_embeds
|
||||||
@ -739,17 +740,16 @@ class Idefics2ForConditionalGeneration(nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
pixel_values: torch.FloatTensor = None,
|
pixel_values: torch.FloatTensor = None,
|
||||||
pixel_attention_mask: Optional[torch.BoolTensor] = None,
|
pixel_attention_mask: Optional[torch.BoolTensor] = None,
|
||||||
# Unused here
|
# Unused here
|
||||||
image_sizes: Optional[torch.Tensor] = None,
|
image_sizes: Optional[torch.Tensor] = None,
|
||||||
adapter_data: Optional[torch.Tensor] = None,
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
|
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||||
):
|
):
|
||||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||||
if pixel_values is not None:
|
if pixel_values is not None:
|
||||||
@ -793,6 +793,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
|
|||||||
].contiguous()
|
].contiguous()
|
||||||
|
|
||||||
patch_size = self.config.vision_config.patch_size
|
patch_size = self.config.vision_config.patch_size
|
||||||
|
"""
|
||||||
patches_subgrid = pixel_attention_mask.unfold(
|
patches_subgrid = pixel_attention_mask.unfold(
|
||||||
dimension=1, size=patch_size, step=patch_size
|
dimension=1, size=patch_size, step=patch_size
|
||||||
)
|
)
|
||||||
@ -800,6 +801,21 @@ class Idefics2ForConditionalGeneration(nn.Module):
|
|||||||
dimension=2, size=patch_size, step=patch_size
|
dimension=2, size=patch_size, step=patch_size
|
||||||
)
|
)
|
||||||
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
|
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
|
||||||
|
"""
|
||||||
|
# hpu does not support unfold
|
||||||
|
conv_kernel = torch.ones(
|
||||||
|
[1, 1, patch_size, patch_size],
|
||||||
|
dtype=pixel_values.dtype,
|
||||||
|
device=pixel_values.device,
|
||||||
|
)
|
||||||
|
patches_subgrid = torch.nn.functional.conv2d(
|
||||||
|
pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),
|
||||||
|
conv_kernel,
|
||||||
|
stride=patch_size,
|
||||||
|
).squeeze(1)
|
||||||
|
patch_attention_mask = torch.eq(
|
||||||
|
patches_subgrid, (patch_size * patch_size)
|
||||||
|
)
|
||||||
|
|
||||||
# Get sequence from the vision encoder
|
# Get sequence from the vision encoder
|
||||||
image_hidden_states = self.vision_model(
|
image_hidden_states = self.vision_model(
|
||||||
@ -825,12 +841,9 @@ class Idefics2ForConditionalGeneration(nn.Module):
|
|||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||||
kv_cache=kv_cache,
|
kv_cache=kv_cache,
|
||||||
block_tables=block_tables,
|
|
||||||
slots=slots,
|
slots=slots,
|
||||||
seqlen=seqlen,
|
seqlen=seqlen,
|
||||||
max_s=max_s,
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
true_max_s=max_s,
|
|
||||||
prefill_cache_indices=None,
|
|
||||||
adapter_data=adapter_data,
|
adapter_data=adapter_data,
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
|
@ -0,0 +1,596 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" PyTorch Idefics3 model."""
|
||||||
|
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
from text_generation_server.models.custom_modeling.vlm import (
|
||||||
|
load_text_model,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
|
||||||
|
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||||
|
|
||||||
|
from text_generation_server.layers import (
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
TensorParallelEmbedding,
|
||||||
|
TensorParallelRowLinear,
|
||||||
|
)
|
||||||
|
from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight
|
||||||
|
|
||||||
|
|
||||||
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Duplicate every key/value head `n_rep` times along the head axis.

    Gives the same result as `torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)`:
    a (batch, num_key_value_heads, seqlen, head_dim) input becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). When `n_rep == 1` the input
    tensor itself is returned.
    """
    # Unpacking first also enforces that the input is 4-dimensional.
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Add a repeat axis right after the head axis, broadcast over it, then fold it
        # back into the head axis.
        tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
        return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics3VisionEmbeddings(nn.Module):
    """
    This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings` to enable images of variable
    resolution.

    The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
    which allows treating images in their native aspect ratio and without the need to resize them to the same
    fixed size. In particular, we start from the original pre-trained SigLIP model
    (which uses images of fixed-size square images) and adapt it by training on images of variable resolutions.
    """

    def __init__(self, prefix, config, weights):
        """
        Build the patch convolution and position embedding from checkpoint tensors.

        Args:
            prefix: Checkpoint name prefix; tensors are read from
                `{prefix}.patch_embedding.*` and `{prefix}.position_embedding`.
            config: Vision config; `hidden_size`, `image_size`, `patch_size` and
                `num_channels` are read here.
            weights: Project weights loader (provides `get_tensor`).
        """
        super().__init__()
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Non-overlapping patches: kernel size and stride are both `patch_size`.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )
        # Replace the freshly initialised conv parameters with frozen checkpoint tensors.
        self.patch_embedding.weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
        )
        self.patch_embedding.bias = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False
        )

        self.num_patches_per_side = self.image_size // self.patch_size
        self.num_patches = self.num_patches_per_side**2
        self.num_positions = self.num_patches
        self.position_embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.position_embedding", weights=weights
        )

    def forward(
        self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor
    ) -> torch.Tensor:
        """
        Embed image patches and add resolution-aware position embeddings.

        Args:
            pixel_values: 4-D image batch; dims 2 and 3 are read as the padded
                height and width.
            patch_attention_mask: Per-image boolean grid of valid patches; iterated
                along its first dimension, one 2-D grid per image.
                (assumes shape (batch, max_h // patch_size, max_w // patch_size) -- TODO confirm)

        Returns:
            Patch embeddings with position embeddings added, one row per patch
            (conv output flattened over its spatial dims and transposed).
        """
        batch_size, _, max_im_h, max_im_w = pixel_values.shape

        patch_embeds = self.patch_embedding(pixel_values)
        # (batch, embed_dim, h', w') -> (batch, h' * w', embed_dim)
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        max_nb_patches_h, max_nb_patches_w = (
            max_im_h // self.patch_size,
            max_im_w // self.patch_size,
        )
        # Bucket edges that cut [0, 1) into `num_patches_per_side` equal intervals.
        # Created without a device argument, i.e. on the default device.
        boundaries = torch.arange(
            1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side
        )
        # Positions that are never assigned below (masked-out patches) keep id 0.
        position_ids = torch.full(
            size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0
        )

        for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
            # Valid patch counts are read off the first column / first row of the mask.
            nb_patches_h = p_attn_mask[:, 0].sum()
            nb_patches_w = p_attn_mask[0].sum()

            # Evenly spaced coordinates in [0, 1) for this image's own patch grid.
            fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
            fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)

            # Map each coordinate onto the row/column index of the reference grid.
            bucket_coords_h = torch.bucketize(
                fractional_coords_h, boundaries, right=True
            )
            bucket_coords_w = torch.bucketize(
                fractional_coords_w, boundaries, right=True
            )

            # Row-major index into the num_patches_per_side x num_patches_per_side table.
            pos_ids = (
                bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w
            ).flatten()
            # The mask is moved to CPU to index `position_ids`, which was built
            # without a device argument above.
            position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids

        position_ids = position_ids.to(self.position_embedding.weight.device)
        embeddings = embeddings + self.position_embedding(position_ids)
        return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics3VisionAttention(nn.Module):
    """
    Non-causal multi-head self-attention for the vision encoder.

    Q, K and V come from one fused column-parallel projection; the output goes through
    a row-parallel projection. After construction `num_heads` and `embed_dim` hold the
    per-shard values (divided by the process-group size).
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config

        hidden_size = config.hidden_size
        total_heads = config.num_attention_heads
        self.embed_dim = hidden_size
        self.num_heads = total_heads
        self.head_size = hidden_size // total_heads
        if self.head_size * total_heads != hidden_size:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_size**-0.5
        self.dropout = config.attention_dropout

        # From here on, head count and width describe this shard only.
        world_size = weights.process_group.size()
        self.num_heads = total_heads // world_size
        self.embed_dim = hidden_size // world_size

        qkv_prefixes = [f"{prefix}.{name}" for name in ("q_proj", "k_proj", "v_proj")]
        self.qkv = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=qkv_prefixes,
            dim=0,
            weights=weights,
            bias=True,
        )
        self.out_proj = TensorParallelRowLinear.load(
            config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
        )
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run attention over `hidden_states` of shape (batch, q_len, width).

        `attention_mask`, when given, must be an additive mask of shape
        (batch, 1, q_len, kv_len); it is added to the scaled scores before softmax.
        Raises ValueError if scores, mask or output have an unexpected shape.
        """
        batch_size, q_len, _ = hidden_states.size()

        # The fused projection yields [Q | K | V] along the last dim, each part
        # `head_size * num_heads` wide.
        shard_width = self.head_size * self.num_heads
        query_states, key_states, value_states = self.qkv(hidden_states).split(
            [shard_width] * 3, dim=2
        )

        def to_heads(flat: torch.Tensor) -> torch.Tensor:
            # (batch, q_len, heads * head_size) -> (batch, heads, q_len, head_size)
            return flat.view(
                batch_size, q_len, self.num_heads, self.head_size
            ).transpose(1, 2)

        query_states = to_heads(query_states)
        key_states = to_heads(key_states)
        value_states = to_heads(value_states)

        k_v_seq_len = key_states.shape[-2]
        scores = torch.matmul(query_states, key_states.transpose(2, 3))
        attn_weights = scores * self.scale

        if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # Softmax is computed in fp32, then cast back to the query dtype.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
        attn_weights = attn_weights.to(query_states.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size):
            raise ValueError(
                f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is"
                f" {attn_output.size()}"
            )

        # (batch, heads, q_len, head_size) -> (batch, q_len, heads * head_size)
        merged = attn_output.transpose(1, 2).contiguous()
        merged = merged.reshape(batch_size, q_len, self.embed_dim)

        return self.out_proj(merged)
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics3VisionMLP(nn.Module):
    """Two-layer feed-forward block of the vision encoder: fc1 -> activation -> fc2."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        # Activation is looked up by name from the config.
        self.activation_fn = ACT2FN[config.hidden_act]
        # Column-parallel in, row-parallel out.
        self.fc1 = TensorParallelColumnLinear.load(
            config=config, prefix=f"{prefix}.fc1", weights=weights, bias=True
        )
        self.fc2 = TensorParallelRowLinear.load(
            config=config, prefix=f"{prefix}.fc2", weights=weights, bias=True
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply `fc2(activation_fn(fc1(hidden_states)))`."""
        projected = self.fc1(hidden_states)
        activated = self.activation_fn(projected)
        return self.fc2(activated)
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics3EncoderLayer(nn.Module):
    """Pre-norm transformer layer: self-attention and MLP, each with a residual connection."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = Idefics3VisionAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        # NOTE(review): `nn.LayerNorm.load` is not stock PyTorch; presumably attached
        # by the project's layers package -- confirm it is imported before this runs.
        norm_eps = config.layer_norm_eps
        self.layer_norm1 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm1", eps=norm_eps, weights=weights
        )
        self.layer_norm2 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm2", eps=norm_eps, weights=weights
        )
        self.mlp = Idefics3VisionMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )

    # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Return `h + mlp(norm2(h))` where `h = x + attn(norm1(x), attention_mask)`."""
        attn_out = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
        )
        after_attn = hidden_states + attn_out

        mlp_out = self.mlp(self.layer_norm2(after_attn))
        return after_attn + mlp_out
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics3Encoder(nn.Module):
    """Stack of `config.num_hidden_layers` `Idefics3EncoderLayer` modules applied in order."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList()
        # Layer `idx` reads its tensors from `{prefix}.layers.{idx}`.
        for idx in range(config.num_hidden_layers):
            self.layers.append(
                Idefics3EncoderLayer(
                    prefix=f"{prefix}.layers.{idx}", config=config, weights=weights
                )
            )

    # Ignore copy
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Pass `inputs_embeds` through every layer with the same `attention_mask`; return the final states."""
        states = inputs_embeds
        for layer in self.layers:
            states = layer(states, attention_mask)
        return states
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics3VisionTransformer(nn.Module):
    """Vision tower: patch/position embeddings, the encoder stack, then a final layer norm."""

    def __init__(self, prefix, config, weights):
        """
        Args:
            prefix: Checkpoint name prefix; sub-modules load from `{prefix}.embeddings`,
                `{prefix}.encoder` and `{prefix}.post_layernorm`.
            config: Vision config (`patch_size` and `layer_norm_eps` are read here).
            weights: Project weights loader.
        """
        super().__init__()
        self.config = config
        self.embeddings = Idefics3VisionEmbeddings(
            prefix=f"{prefix}.embeddings", config=config, weights=weights
        )
        self.encoder = Idefics3Encoder(
            prefix=f"{prefix}.encoder", config=config, weights=weights
        )
        # NOTE(review): `nn.LayerNorm.load` is not stock PyTorch; presumably attached
        # by the project's layers package -- confirm.
        self.post_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.post_layernorm",
            weights=weights,
            eps=config.layer_norm_eps,
        )

    def forward(
        self,
        pixel_values,
        patch_attention_mask: Optional[torch.BoolTensor] = None,
    ):
        """
        Encode images into per-patch hidden states.

        Args:
            pixel_values: 4-D image batch; dims 2 and 3 are divided by `patch_size`
                to size the default mask.
            patch_attention_mask: Optional boolean grid of valid patches per image.
                When omitted, every patch is treated as valid.

        Returns:
            The encoder output after `post_layernorm`.
        """
        batch_size = pixel_values.size(0)
        if patch_attention_mask is None:
            # No mask supplied: attend to every patch of the (padded) image.
            patch_size = self.config.patch_size
            patch_attention_mask = torch.ones(
                (
                    batch_size,
                    pixel_values.size(2) // patch_size,
                    pixel_values.size(3) // patch_size,
                )
            )
            patch_attention_mask = patch_attention_mask.to(
                dtype=torch.bool, device=pixel_values.device
            )

        hidden_states = self.embeddings(
            pixel_values=pixel_values, patch_attention_mask=patch_attention_mask
        )

        # Flatten the patch grid so the mask has one entry per sequence position.
        patch_attention_mask = patch_attention_mask.view(batch_size, -1)
        # When the mask is all True (attending to the whole sequence), pass no mask
        # at all, which is equivalent. Otherwise convert it with
        # `_prepare_4d_attention_mask`.
        # NOTE(review): presumably that yields the additive (batch, 1, q_len, kv_len)
        # mask the attention layers validate -- confirm against transformers.
        # The comment inherited from upstream referred to `_upad_input` in
        # `_flash_attention_forward`, which this file does not call.
        if not torch.any(~patch_attention_mask):
            patch_attention_mask = None
        else:
            patch_attention_mask = _prepare_4d_attention_mask(
                patch_attention_mask, hidden_states.dtype
            )

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=patch_attention_mask,
        )

        last_hidden_state = encoder_outputs
        last_hidden_state = self.post_layernorm(last_hidden_state)

        return last_hidden_state
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics3SimpleMLP(nn.Module):
    """
    Bias-free linear projection from pixel-shuffled vision features to the text
    hidden size.

    Input width is `vision_config.hidden_size * scale_factor**2`; output width is
    `text_config.hidden_size`.
    """

    def __init__(self, prefix, config, weights):
        """
        Args:
            prefix: Connector prefix; the weight is read from
                `{prefix}.modality_projection.proj.weight`.
            config: Full model config (`vision_config`, `text_config`, `scale_factor`).
            weights: Project weights loader (provides `get_tensor` and `dtype`).
        """
        super().__init__()
        input_size = config.vision_config.hidden_size * (config.scale_factor**2)
        output_size = config.text_config.hidden_size
        # Cast the raw tensor BEFORE wrapping it. Calling `.to(dtype)` on an
        # `nn.Parameter` returns a plain Tensor whenever a conversion actually happens,
        # and assigning a plain Tensor to `Linear.weight` raises TypeError.
        proj = nn.Parameter(
            weights.get_tensor(f"{prefix}.modality_projection.proj.weight").to(
                weights.dtype
            ),
            requires_grad=False,
        )
        self.proj = nn.Linear(input_size, output_size, bias=False)
        self.proj.weight = proj

    def forward(self, x):
        """Project `x` (last dim = input width) to the text hidden size."""
        return self.proj(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics3Connector(nn.Module):
    """Bridges vision and text: pixel-shuffles patch features, then projects them linearly."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.modality_projection = Idefics3SimpleMLP(prefix, config, weights)
        self.scale_factor = config.scale_factor

    def pixel_shuffle(self, x, scale_factor=2):
        """
        Fold each `scale_factor x scale_factor` neighbourhood of patches into one token.

        `x` is (batch, seq, embed_dim) with `seq` read as a square grid; the result is
        (batch, seq / scale_factor**2, embed_dim * scale_factor**2).
        """
        bsz, seq, embed_dim = x.size()
        side = int(seq**0.5)
        small_side = int(side / scale_factor)
        wide_dim = embed_dim * (scale_factor**2)

        grid = x.view(bsz, side, side, embed_dim)
        # Merge `scale_factor` neighbouring columns into the channel dim.
        grid = grid.view(bsz, side, small_side, embed_dim * scale_factor)
        grid = grid.permute(0, 2, 1, 3)
        # Merge `scale_factor` neighbouring rows into the channel dim.
        grid = grid.reshape(bsz, small_side, small_side, wide_dim)
        grid = grid.permute(0, 2, 1, 3)
        return grid.reshape(bsz, int(seq / (scale_factor**2)), wide_dim)

    def forward(self, image_hidden_states):
        """Pixel-shuffle with the configured scale factor, then apply the modality projection."""
        shuffled = self.pixel_shuffle(image_hidden_states, self.scale_factor)
        return self.modality_projection(shuffled)
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics3ForConditionalGeneration(nn.Module):
    """
    Idefics3 vision-language model.

    Combines a text model loaded through `load_text_model`, an
    `Idefics3VisionTransformer` image encoder and an `Idefics3Connector` that projects
    image features into the text embedding space. Image features overwrite the token
    embeddings at every position whose id equals `config.image_token_id` before the
    text model runs.
    """

    def __init__(self, prefix, config, weights):
        """
        Args:
            prefix: Optional checkpoint prefix; when empty, tensors are read from
                `model.*`, otherwise from `{prefix}.model.*`.
            config: Model config. Mutated in place: `speculator`/`quantize` are copied
                onto `text_config`, `vision_config.quantize` is forced to None, and
                `config.quantize` is reset to None before the connector is built.
            weights: Project weights loader.
        """
        super().__init__()
        config.vision_config.quantize = None
        config.vision_config.speculator = config.speculator
        config.text_config.quantize = config.quantize
        config.text_config.speculator = config.speculator
        # set tie_word_embeddings to True to load `.embed_tokens.weight` instead of `.lm_head.weight`
        # since Idefics3 uses the `embed_tokens` for the final prediction
        # config.text_config.tie_word_embeddings = True

        vision_config = config.vision_config
        self.text_model = load_text_model(
            prefix="model" if not prefix else f"{prefix}.model",
            config=config.text_config,
            weights=weights,
            name="text_model",
        )
        self.dtype = weights.dtype

        # The vision and connector models are not quantized.
        with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):
            self.vision_model = Idefics3VisionTransformer(
                prefix=(
                    f"{prefix}.model.vision_model" if prefix else "model.vision_model"
                ),
                config=vision_config,
                weights=weights,
            )

            config.quantize = None
            self.connector = Idefics3Connector(
                prefix=f"{prefix}.model.connector" if prefix else "model.connector",
                config=config,
                weights=weights,
            )

        self.config = config
        self.image_token_id = config.image_token_id
        # -1 stands in for "no pad token".
        self.pad_token_id = (
            config.pad_token_id if config.pad_token_id is not None else -1
        )

    def _merge_input_ids_with_image_features(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor,
        image_features: torch.Tensor,
    ):
        """
        In place merges in vision_embeddings with inputs_embeds.

        Every position where `input_ids == config.image_token_id` is overwritten, in
        order, with one row of `image_features` flattened to (-1, feature_dim). The
        number of image-token positions must match the number of feature rows.
        """
        # mask = input_ids == self.config.image_token_index
        # - replace `==` with torch.where to fix the issue in hpu graph
        mask = torch.where(input_ids == self.config.image_token_id)
        # Let's pray we have enabled enough slots !
        inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        pixel_values: torch.FloatTensor = None,
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
        # Unused here
        image_sizes: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        cross_attention_states: Optional[torch.Tensor] = None,
        image_indices=None,
    ):
        """
        Embed tokens, splice in image features when `pixel_values` is given, run the
        text model and return `(logits, speculative_logits)`.

        Args:
            pixel_values: Optional 5-D tensor
                (batch, num_images, channels, height, width). Images that are entirely
                zero are treated as padding and dropped.
            pixel_attention_mask: Optional per-pixel mask, indexed per sample and
                viewed as (num_images, ...). When None, all pixels of every real image
                are treated as valid.
            lm_head_indices: Optional index selecting which hidden states reach the
                LM head.

        The remaining positional arguments are forwarded unchanged to the text model.
        `image_sizes`, `image_grid_thw`, `video_grid_thw`, `cross_attention_states`
        and `image_indices` are accepted but not read here.
        """
        inputs_embeds = self.text_model.embed_tokens(input_ids)
        if pixel_values is not None:
            # The 5-way unpack also enforces the expected rank.
            batch_size, num_images, _, _, _ = pixel_values.shape
            # Cast once for the whole batch (fp16 compatibility) rather than on
            # every loop iteration.
            all_pixel_values = pixel_values.to(dtype=self.dtype)
            all_pixel_mask = pixel_attention_mask

            patch_size = self.config.vision_config.patch_size
            # hpu does not support unfold, so the per-patch pixel count is computed
            # with a strided all-ones convolution. The kernel is loop-invariant.
            conv_kernel = torch.ones(
                [1, 1, patch_size, patch_size],
                dtype=all_pixel_values.dtype,
                device=all_pixel_values.device,
            )

            all_states = []
            for i in range(batch_size):
                pixel_values = all_pixel_values[i : i + 1]
                pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])

                # Remove padding images - padding images are full 0.
                nb_values_per_image = pixel_values.shape[1:].numel()
                real_images_inds = (pixel_values == 0.0).sum(
                    dim=(-1, -2, -3)
                ) != nb_values_per_image
                pixel_values = pixel_values[real_images_inds].contiguous()

                # Handle the vision attention mask.
                # Test the ORIGINAL argument (`all_pixel_mask`), not the loop-local
                # `pixel_attention_mask`: the local is reassigned on the first
                # iteration, so testing it would send every later sample into the
                # else-branch and index into None.
                if all_pixel_mask is None:
                    pixel_attention_mask = torch.ones(
                        size=(
                            pixel_values.size(0),
                            pixel_values.size(2),
                            pixel_values.size(3),
                        ),
                        dtype=torch.bool,
                        device=pixel_values.device,
                    )
                else:
                    # Remove padding images from the mask
                    pixel_attention_mask = all_pixel_mask[i : i + 1]
                    pixel_attention_mask = pixel_attention_mask.view(
                        1 * num_images, *pixel_attention_mask.shape[2:]
                    )
                    pixel_attention_mask = pixel_attention_mask[
                        real_images_inds
                    ].contiguous()

                # Reference (non-HPU) formulation, kept for comparison:
                #   patches_subgrid = pixel_attention_mask.unfold(
                #       dimension=1, size=patch_size, step=patch_size
                #   )
                #   patches_subgrid = patches_subgrid.unfold(
                #       dimension=2, size=patch_size, step=patch_size
                #   )
                #   patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
                # NOTE(review): the reference marks a patch valid when ANY pixel is
                # unmasked (`> 0`); the conv path below requires ALL pixels
                # (`== patch_size**2`). They differ for partially masked patches --
                # confirm which is intended.
                patches_subgrid = torch.nn.functional.conv2d(
                    pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),
                    conv_kernel,
                    stride=patch_size,
                ).squeeze(1)
                patch_attention_mask = torch.eq(
                    patches_subgrid, (patch_size * patch_size)
                )

                # Get sequence from the vision encoder
                image_hidden_states = self.vision_model(
                    pixel_values=pixel_values,
                    patch_attention_mask=patch_attention_mask,
                )

                # Modality projection & resampling
                image_hidden_states = self.connector(
                    image_hidden_states,
                )

                all_states.append(image_hidden_states)
            # Stacking requires every sample to yield the same number of real images.
            image_hidden_states = torch.stack(all_states, dim=0)

            inputs_embeds = self._merge_input_ids_with_image_features(
                input_ids, inputs_embeds, image_hidden_states
            )

        hidden_states = self.text_model.model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            slots=slots,
            seqlen=seqlen,
            hpu_attention_meta=hpu_attention_meta,
            adapter_data=adapter_data,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.text_model.lm_head(hidden_states)
        return logits, speculative_logits
|
@ -46,14 +46,8 @@ from text_generation_server.layers import (
|
|||||||
FastLinear,
|
FastLinear,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
if SYSTEM == "cuda":
|
|
||||||
import dropout_layer_norm
|
|
||||||
elif SYSTEM == "rocm":
|
|
||||||
from vllm._C import ops
|
|
||||||
else:
|
|
||||||
dropout_layer_norm = None
|
dropout_layer_norm = None
|
||||||
|
|
||||||
|
|
||||||
@ -351,94 +345,18 @@ class IdeficsRMSNorm(nn.Module):
|
|||||||
self.variance_epsilon = eps
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
def forward(self, hidden_states, residual=None):
|
def forward(self, hidden_states, residual=None):
|
||||||
if SYSTEM == "ipex":
|
from vllm_hpu_extension.kernels import rms_norm
|
||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
|
|
||||||
out = ipex.llm.functional.add_rms_norm(
|
orig_shape = hidden_states.shape
|
||||||
residual,
|
|
||||||
hidden_states,
|
|
||||||
self.weight,
|
|
||||||
None,
|
|
||||||
self.variance_epsilon,
|
|
||||||
residual is not None,
|
|
||||||
)
|
|
||||||
return out
|
|
||||||
elif hidden_states.shape[-1] > 8192:
|
|
||||||
if residual is not None:
|
if residual is not None:
|
||||||
hidden_states += residual
|
residual += hidden_states.view(residual.shape)
|
||||||
residual = hidden_states
|
|
||||||
|
|
||||||
hidden_states = hidden_states.to(torch.float32)
|
|
||||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
|
||||||
hidden_states = hidden_states * torch.rsqrt(
|
|
||||||
variance + self.variance_epsilon
|
|
||||||
)
|
|
||||||
|
|
||||||
# convert into half-precision if necessary
|
|
||||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
|
||||||
hidden_states = hidden_states.to(self.weight.dtype)
|
|
||||||
|
|
||||||
return self.weight * hidden_states
|
|
||||||
elif SYSTEM == "cuda":
|
|
||||||
# faster post attention rms norm
|
|
||||||
unwrap = False
|
|
||||||
if len(hidden_states.shape) > 2:
|
|
||||||
unwrap = True
|
|
||||||
shape = hidden_states.shape
|
|
||||||
hidden_states = hidden_states.reshape(-1, shape[-1])
|
|
||||||
|
|
||||||
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
|
||||||
hidden_states,
|
|
||||||
residual,
|
|
||||||
self.weight,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
0.0,
|
|
||||||
self.variance_epsilon,
|
|
||||||
1.0,
|
|
||||||
0,
|
|
||||||
None,
|
|
||||||
False,
|
|
||||||
True, # Activate RMSNorm
|
|
||||||
)
|
|
||||||
if res is None:
|
|
||||||
res = hidden_states
|
|
||||||
|
|
||||||
if unwrap:
|
|
||||||
normed_hidden_states = normed_hidden_states.view(*shape)
|
|
||||||
|
|
||||||
return normed_hidden_states
|
|
||||||
elif SYSTEM == "rocm":
|
|
||||||
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
|
|
||||||
if residual is not None:
|
|
||||||
hidden_states += residual
|
|
||||||
residual = hidden_states
|
|
||||||
|
|
||||||
unwrap = False
|
|
||||||
if len(hidden_states.shape) > 2:
|
|
||||||
unwrap = True
|
|
||||||
shape = hidden_states.shape
|
|
||||||
hidden_states = hidden_states.reshape(-1, shape[-1])
|
|
||||||
|
|
||||||
out = torch.empty_like(hidden_states)
|
|
||||||
ops.rms_norm(
|
|
||||||
out,
|
|
||||||
hidden_states,
|
|
||||||
self.weight.data,
|
|
||||||
self.variance_epsilon,
|
|
||||||
)
|
|
||||||
|
|
||||||
if unwrap:
|
|
||||||
out = out.view(*shape)
|
|
||||||
|
|
||||||
return out
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
residual = hidden_states
|
||||||
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
|
# Note: HPUFusedRMSNorm requires 3D tensors as inputs
|
||||||
)
|
if len(orig_shape) == 2:
|
||||||
|
residual = residual.unsqueeze(0)
|
||||||
|
x = rms_norm().apply(residual, self.weight, self.variance_epsilon)
|
||||||
|
return x.view(orig_shape), residual.view(orig_shape)
|
||||||
|
|
||||||
|
|
||||||
# this was adapted from LlamaMLP
|
# this was adapted from LlamaMLP
|
||||||
|
@ -196,6 +196,9 @@ class MambaModel(nn.Module):
|
|||||||
def __init__(self, config, weights):
|
def __init__(self, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
prefix = "backbone"
|
prefix = "backbone"
|
||||||
|
try:
|
||||||
|
self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embeddings", weights)
|
||||||
|
except RuntimeError:
|
||||||
self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
|
self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
|
||||||
self.blocks = nn.ModuleList(
|
self.blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
@ -206,6 +209,9 @@ class MambaModel(nn.Module):
|
|||||||
self.norm_f = FastRMSNorm.load(
|
self.norm_f = FastRMSNorm.load(
|
||||||
f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
|
f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
|
||||||
)
|
)
|
||||||
|
try:
|
||||||
|
self.lm_head = SpeculativeHead.load(config, f"{prefix}.embeddings", weights)
|
||||||
|
except RuntimeError:
|
||||||
self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights)
|
self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -1,796 +0,0 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
""" PyTorch GPTNeoX model."""
|
|
||||||
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import os
|
|
||||||
import torch
|
|
||||||
import torch.distributed
|
|
||||||
import torch.utils.checkpoint
|
|
||||||
from torch import nn
|
|
||||||
from torch.nn import CrossEntropyLoss
|
|
||||||
|
|
||||||
from transformers.activations import ACT2FN
|
|
||||||
from transformers.modeling_outputs import (
|
|
||||||
BaseModelOutputWithPast,
|
|
||||||
CausalLMOutputWithPast,
|
|
||||||
)
|
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
|
||||||
from text_generation_server.layers import (
|
|
||||||
TensorParallelColumnLinear,
|
|
||||||
TensorParallelEmbedding,
|
|
||||||
TensorParallelRowLinear,
|
|
||||||
SpeculativeHead,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# The fused CUDA attention kernel is used only when a CUDA device is visible,
# DISABLE_CUSTOM_KERNELS is not set to exactly "True", and the extension imports.
CUSTOM_KERNELS_ENABLED = False
if torch.cuda.is_available() and os.environ.get("DISABLE_CUSTOM_KERNELS", "False") != "True":
    try:
        from custom_kernels import fused_attention_cuda
    except ImportError:
        # Extension not built/installed: keep the pure PyTorch attention path.
        pass
    else:
        CUSTOM_KERNELS_ENABLED = True
|
|
||||||
|
|
||||||
|
|
||||||
def make_causal_mask(
    input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
) -> torch.BoolTensor:
    """
    Build the boolean causal mask used for self-attention.

    Entries that are `True` lie strictly above the diagonal shifted right by
    `past_key_values_length`. The result is an expanded view of shape
    `[batch_size, target_length, target_length + past_key_values_length]`.
    """
    batch_size, target_length = input_ids_shape
    total_length = target_length + past_key_values_length

    upper = torch.ones(
        (target_length, total_length), dtype=torch.bool, device=device
    ).triu(1 + past_key_values_length)

    # Broadcast the single 2-D pattern over the batch without copying it.
    return upper[None, :, :].expand(batch_size, target_length, total_length)
|
|
||||||
|
|
||||||
|
|
||||||
def expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
    """
    Invert a `[batch_size, src_length]` attention mask and expand it to
    `[batch_size, tgt_length, src_length]`.

    The mask is cast to bool and negated, so zero entries of the input become
    `True` in the output. When `tgt_length` is `None`, `src_length` is used.
    """
    batch_size, src_length = mask.shape
    if tgt_length is None:
        tgt_length = src_length

    inverted = ~mask.to(torch.bool)
    return inverted.unsqueeze(1).expand(batch_size, tgt_length, src_length)
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_attn_mask(
    attention_mask: torch.Tensor,
    input_shape: Tuple[int, int],
    past_key_values_length: int,
) -> torch.BoolTensor:
    """
    Combine the inverted padding mask with the causal mask.

    The padding part always comes from `expand_mask`. The causal part from
    `make_causal_mask` is OR-ed in only when more than one token is being
    processed (`input_shape[1] > 1`).
    """
    _, src_length = input_shape

    # [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
    padding_mask = expand_mask(attention_mask, tgt_length=src_length)
    if src_length <= 1:
        return padding_mask

    causal_mask = make_causal_mask(
        input_shape,
        device=attention_mask.device,
        past_key_values_length=past_key_values_length,
    )
    return padding_mask | causal_mask
|
|
||||||
|
|
||||||
|
|
||||||
class GPTNeoXPreTrainedModel(PreTrainedModel):
    """
    Shared base class for the GPT-NeoX models in this module.

    It defines no members of its own; everything comes from `PreTrainedModel`.
    """
|
|
||||||
|
|
||||||
|
|
||||||
class GPTNeoXAttention(nn.Module):
    """
    Self-attention sub-layer of a GPT-NeoX block.

    The fused `query_key_value` projection and the output `dense` projection are
    loaded through the tensor-parallel linear layers. After the divisibility check
    in `__init__`, `num_attention_heads` holds the per-shard head count
    (total heads divided by `weights.process_group.size()`).
    """

    def __init__(self, config, prefix, weights):
        """
        Args:
            config: model configuration; reads `num_attention_heads`, `hidden_size`,
                `rotary_pct`, `max_position_embeddings` and `rotary_emb_base`.
            prefix: weight-name prefix of this attention module.
            weights: weight loader exposing `get_tensor` and `process_group`.

        Raises:
            ValueError: if the head count is not divisible by the process group size.
        """
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_attention_heads
        self.rotary_ndims = int(self.head_size * config.rotary_pct)
        # TODO: the upstream causal "bias" / "masked_bias" buffers are not registered
        # here; the mask is supplied by the caller through `attention_mask` instead.
        self.rotary_emb = RotaryEmbedding(
            self.rotary_ndims,
            config.max_position_embeddings,
            base=config.rotary_emb_base,
        )
        # Replace the computed inverse frequencies with the checkpoint's tensor.
        self.rotary_emb.inv_freq = nn.Parameter(
            weights.get_tensor(f"{prefix}.rotary_emb.inv_freq")
        )
        self.inv_norm_factor = 1.0 / torch.sqrt(
            torch.tensor(self.head_size, dtype=torch.float32)
        ).to(torch.get_default_dtype())

        if self.num_attention_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_attention_heads` must be divisible by `num_shards` "
                f"(got `num_attention_heads`: {self.num_attention_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        self.num_attention_heads = (
            self.num_attention_heads // weights.process_group.size()
        )
        self.query_key_value = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.query_key_value", weights=weights, bias=True
        )
        self.dense = TensorParallelRowLinear.load(
            config, prefix=f"{prefix}.dense", weights=weights, bias=True
        )

    def forward(
        self,
        hidden_states,
        position_ids,
        attention_mask,
        head_mask=None,
        layer_past=None,
        use_cache=False,
        output_attentions=False,
    ):
        """
        Run attention over `hidden_states`.

        Args:
            hidden_states: input activations, `[batch, seq_len, hidden_size]`.
            position_ids: positions used to index the rotary cos/sin tables.
            attention_mask: boolean mask; `True` entries are masked out.
            head_mask: optional multiplier applied to the attention weights.
            layer_past: optional `(key, value)` cache; its sequence axis is `-2`.
            use_cache: whether to return the updated `(key, value)` pair.
            output_attentions: whether to append the attention weights.

        Returns:
            `(attn_output, present)` and, if `output_attentions`, `attn_weights`.
        """
        has_layer_past = layer_past is not None

        # Compute QKV
        # Attention heads [batch, seq_len, hidden_size]
        #   --> [batch, seq_len, (np * 3 * head_size)]
        qkv = self.query_key_value(hidden_states)

        # [batch, seq_len, (num_heads * 3 * head_size)]
        #   --> [batch, seq_len, num_heads, 3 * head_size]
        new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
        qkv = qkv.view(*new_qkv_shape).permute(0, 2, 1, 3)
        # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
        query, key, value = qkv.split(self.head_size, -1)

        # Compute token offset for rotary embeddings (when decoding)
        seq_len = key.shape[-2]
        if has_layer_past:
            seq_len += layer_past[0].shape[-2]

        # Compute rotary embeddings on rotary_ndims
        query_rot = query[..., : self.rotary_ndims]
        key_rot = key[..., : self.rotary_ndims]

        query_rot, key_rot = self.rotary_emb(query_rot, key_rot, position_ids, seq_len)

        # Write the rotated slice back in place; the remaining dims stay untouched.
        query[..., : self.rotary_ndims] = query_rot
        key[..., : self.rotary_ndims] = key_rot

        if CUSTOM_KERNELS_ENABLED:
            attn_output, present, attn_weights = fused_attention_cuda.forward(
                query,
                key,
                value,
                layer_past,
                attention_mask,
                head_mask,
                self.inv_norm_factor,
                self.num_attention_heads,
                use_cache,
            )
        else:
            # Cache QKV values
            if has_layer_past:
                past_key = layer_past[0]
                past_value = layer_past[1]
                key = torch.cat((past_key, key), dim=-2)
                value = torch.cat((past_value, value), dim=-2)
            present = (key, value) if use_cache else None

            # Compute attention
            attn_output, attn_weights = self._attn(
                query, key, value, attention_mask, head_mask
            )

        # Reshape outputs
        attn_output = self._merge_heads(
            attn_output, self.num_attention_heads, self.head_size
        )

        attn_output = self.dense(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs

    @classmethod
    def _split_heads(cls, tensor, num_attention_heads, attn_head_size):
        """
        Splits hidden dim into attn_head_size and num_attention_heads
        """
        # tensor: [bs, seq_len, hidden_size]
        new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
        # -> [bs, seq_len, num_attention_heads, attn_head_size]
        tensor = tensor.view(new_shape)
        # -> [bs, num_attention_heads, seq_len, attn_head_size]
        tensor = tensor.permute(0, 2, 1, 3)
        return tensor

    @classmethod
    def _merge_heads(cls, tensor, num_attention_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden dim
        """
        # tensor [bs, num_attention_heads, seq_len, attn_head_size]
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        # -> [bs, seq_len, num_attention_heads, attn_head_size]
        tensor = tensor.view(
            tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size
        )
        # -> [bs, seq_len, hidden_size]
        return tensor

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        """
        Scaled dot-product attention in plain PyTorch.

        Args:
            query, key, value: `[bs, num_attention_heads, seq_len, attn_head_size]`.
            attention_mask: optional boolean mask broadcastable to
                `[bs * num_attention_heads, query_length, key_length]`; `True`
                entries are replaced by the dtype minimum before the softmax.
                `None` means nothing is masked.
            head_mask: optional multiplier applied to the attention weights.

        Returns:
            `(attn_output, attn_weights)`.
        """
        batch_size, num_attention_heads, query_length, attn_head_size = query.size()
        key_length = key.size(-2)

        query = query.reshape(
            batch_size * num_attention_heads, query_length, attn_head_size
        )
        key = key.reshape(batch_size * num_attention_heads, key_length, attn_head_size)
        # Zero "input" for baddbmm, so the result is alpha * (query @ key^T).
        attn_scores = torch.zeros(
            1,
            dtype=query.dtype,
            device=key.device,
        ).expand(batch_size * num_attention_heads, query_length, key_length)
        attn_scores = torch.baddbmm(
            attn_scores,
            query,
            key.transpose(1, 2),
            beta=1.0,
            alpha=self.inv_norm_factor,
        )

        # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
        input_dtype = attn_scores.dtype
        if input_dtype in [torch.float16, torch.bfloat16]:
            attn_scores = attn_scores.to(torch.float)
        # The signature allows attention_mask=None; torch.where cannot take None.
        if attention_mask is not None:
            attn_scores = torch.where(
                attention_mask, torch.finfo(attn_scores.dtype).min, attn_scores
            )
        attn_scores = attn_scores.view(
            batch_size, num_attention_heads, query_length, key_length
        )

        attn_weights = nn.functional.softmax(attn_scores, dim=-1)
        attn_weights = attn_weights.to(value.dtype)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights
|
|
||||||
|
|
||||||
|
|
||||||
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embedding with lazily built cos/sin lookup tables.

    The tables are created on the first `forward` call and rebuilt whenever a
    longer `seq_len` than the cached length is requested. They are always derived
    from `true_inv_freq` (computed in `__init__`), not from the `inv_freq` buffer.
    """

    def __init__(self, dim, max_position_embeddings, base=10000, device=None):
        """
        Args:
            dim: number of rotary dimensions.
            max_position_embeddings: initial length of the cos/sin tables.
            base: base of the inverse-frequency progression.
            device: device on which the inverse frequencies are created.
        """
        super().__init__()
        self.true_inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2).float().to(device) / dim)
        )
        self.register_buffer("inv_freq", self.true_inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        self.cos_cached = None
        self.sin_cached = None

    @staticmethod
    def rotate_half(x):
        """Rotates half the hidden dims of the input."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    @staticmethod
    def _create_cos_sin(inv_freq, max_position_embeddings, dtype, device):
        """
        Build the `(cos, sin)` tables, each of shape
        `[max_position_embeddings, 2 * len(inv_freq)]`, moved to `device` and cast
        to `dtype`.
        """
        t = torch.arange(
            max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype
        )
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(device).to(dtype), emb.sin().to(device).to(dtype)

    def forward(self, q, k, position_ids, seq_len=None):
        """
        Apply the rotary embedding to `q` and `k` at `position_ids`.

        Args:
            q, k: tensors to rotate, `[bs, num_attention_heads, seq_len, head_size]`.
            position_ids: indices into the cos/sin tables.
            seq_len: required table length; `None` keeps the current cached length.

        Returns:
            The rotated `(q, k)` pair produced by `rotary_forward`.
        """
        # The signature allows seq_len=None, but `None > int` raises TypeError.
        if seq_len is None:
            seq_len = self.max_seq_len_cached

        if (
            seq_len > self.max_seq_len_cached
            or self.cos_cached is None
            or self.sin_cached is None
        ):
            if seq_len > self.max_seq_len_cached:
                self.max_seq_len_cached = seq_len
            self.cos_cached, self.sin_cached = self._create_cos_sin(
                self.true_inv_freq, self.max_seq_len_cached, q.dtype, q.device
            )
        return rotary_forward(q, k, self.cos_cached, self.sin_cached, position_ids)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
def rotary_forward(q, k, cos, sin, position_ids):
    """
    Rotate `q` and `k` using the cos/sin rows selected by `position_ids`.

    Each input is combined as `x * cos + rotate_half(x) * sin`, where
    `rotate_half` moves the negated second half of the last dimension in front
    of the first half.
    """
    # Select the table rows, then add an axis at dim 1 for broadcasting
    # (presumably the head axis - confirm against the caller).
    cos_sel = cos[position_ids].unsqueeze(1)
    sin_sel = sin[position_ids].unsqueeze(1)

    half = q.shape[-1] // 2
    q_front, q_back = q.split(half, -1)
    k_front, k_back = k.split(half, -1)

    q_embed = (q * cos_sel) + (torch.cat((-q_back, q_front), dim=-1) * sin_sel)
    k_embed = (k * cos_sel) + (torch.cat((-k_back, k_front), dim=-1) * sin_sel)
    return q_embed, k_embed
|
|
||||||
|
|
||||||
|
|
||||||
class GPTNeoXMLP(nn.Module):
    """
    Feed-forward sub-layer of a GPT-NeoX block: `dense_h_to_4h`, activation,
    then `dense_4h_to_h`, with both projections loaded as tensor-parallel layers.
    """

    def __init__(self, config, prefix, weights):
        super().__init__()
        if "gelu_fast" in config.hidden_act:
            # Any activation name containing "gelu_fast" uses the tanh approximation of GELU.
            self.act = lambda x: torch.nn.functional.gelu(x, approximate="tanh")
        else:
            self.act = ACT2FN[config.hidden_act]

        self.dense_h_to_4h = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True
        )
        self.dense_4h_to_h = TensorParallelRowLinear.load(
            config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True
        )

    def forward(self, hidden_states):
        """Apply up-projection, activation and down-projection in that order."""
        expanded = self.dense_h_to_4h(hidden_states)
        activated = self.act(expanded)
        return self.dense_4h_to_h(activated)
|
|
||||||
|
|
||||||
|
|
||||||
class GPTNeoXLayer(nn.Module):
    """
    One GPT-NeoX transformer block: two layer norms, attention and MLP, combined
    either as parallel residuals or sequentially depending on
    `config.use_parallel_residual`.
    """

    def __init__(self, layer_id, prefix: str, config, weights):
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual

        norm_eps = config.layer_norm_eps
        self.input_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.layers.{layer_id}.input_layernorm",
            weights=weights,
            eps=norm_eps,
        )
        self.post_attention_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.layers.{layer_id}.post_attention_layernorm",
            weights=weights,
            eps=norm_eps,
        )
        self.attention = GPTNeoXAttention(
            config, prefix=f"{prefix}.layers.{layer_id}.attention", weights=weights
        )
        self.mlp = GPTNeoXMLP(
            config, prefix=f"{prefix}.layers.{layer_id}.mlp", weights=weights
        )

    def forward(
        self,
        hidden_states,
        position_ids,
        attention_mask=None,
        head_mask=None,
        use_cache=False,
        layer_past=None,
        output_attentions=False,
    ):
        """
        Returns `(hidden_states, present, (attn_weights))` when `use_cache` is
        truthy, otherwise `(hidden_states, (attn_weights))`.
        """
        attn_result = self.attention(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            layer_past=layer_past,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        # attn_result: attn_output, present, (attn_weights)
        attn_output = attn_result[0]
        extras = attn_result[1:]

        if self.use_parallel_residual:
            # x = x + attn(ln1(x)) + mlp(ln2(x))
            mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
            hidden_states = mlp_output + attn_output + hidden_states
        else:
            # x = x + attn(ln1(x)), then x = x + mlp(ln2(x))
            attn_output = attn_output + hidden_states
            mlp_output = self.mlp(self.post_attention_layernorm(attn_output))
            hidden_states = mlp_output + attn_output

        if not use_cache:
            # Drop `present` so only the optional attention weights remain.
            extras = extras[1:]
        return (hidden_states,) + extras
|
|
||||||
|
|
||||||
|
|
||||||
class GPTNeoXModel(GPTNeoXPreTrainedModel):
    """
    GPT-NeoX decoder stack: token embedding, `config.num_hidden_layers`
    `GPTNeoXLayer` blocks and a final layer norm, all loaded from `weights`.
    """

    def __init__(self, prefix: str, config, weights):
        """
        Args:
            prefix: weight-name prefix of the decoder stack.
            config: model configuration.
            weights: weight loader exposing `process_group`.
        """
        super().__init__(config)
        self.config = config

        self.num_attention_heads = config.num_attention_heads

        self.embed_in = TensorParallelEmbedding(
            prefix=f"{prefix}.embed_in", weights=weights
        )
        self.layers = nn.ModuleList(
            [
                GPTNeoXLayer(layer_id, prefix, config, weights)
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.final_layer_norm = nn.LayerNorm.load(
            prefix=f"{prefix}.final_layer_norm",
            weights=weights,
            eps=config.layer_norm_eps,
        )
        self.tp_world_size = weights.process_group.size()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids=None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        r"""
        Run the decoder stack.

        Args:
            input_ids / inputs_embeds: exactly one of the two must be given.
            position_ids: optional positions; by default a range starting at the
                cached length is used.
            attention_mask: optional `[batch_size, past + seq_length]` mask; all
                ones when omitted.
            head_mask: optional head mask, expanded via `get_head_mask`.
            past_key_values: optional per-layer `(key, value)` caches. The cached
                length is read from axis `-2` of the first layer's key.
            use_cache: if truthy, the per-layer `present` caches are returned and
                can be fed back as `past_key_values`.
            output_attentions / output_hidden_states / return_dict: default to the
                corresponding `config` values when `None`.

        Returns:
            A `BaseModelOutputWithPast`, or a tuple of its non-`None` fields when
            `return_dict` is falsy.

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds` are given.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape

        if past_key_values is None:
            past_key_values = tuple([None] * self.config.num_hidden_layers)
        # Cached keys are [bs, heads, past_len, head_size] (the attention layer
        # concatenates on dim=-2), so the cached length is axis -2, not -1.
        if past_key_values[0] is None:
            past_length = 0
        else:
            past_length = past_key_values[0][0].size(-2)

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_length, seq_length + past_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_in(input_ids)

        hidden_states = inputs_embeds

        # Attention mask.
        past_key_values_length = past_length
        seq_length_with_past = seq_length + past_key_values_length
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), device=hidden_states.device
            )
        else:
            attention_mask = attention_mask.to(hidden_states.device)

        causal_mask = prepare_attn_mask(
            attention_mask,
            input_shape=(batch_size, seq_length),
            past_key_values_length=past_key_values_length,
        )

        # One mask copy per local head: [bs, tgt, src] -> [bs * local_heads, tgt, src]
        assert self.num_attention_heads % self.tp_world_size == 0
        block_size = self.num_attention_heads // self.tp_world_size
        causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        presents = () if use_cache else None
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            outputs = layer(
                hidden_states,
                position_ids=position_ids,
                attention_mask=causal_mask,
                head_mask=head_mask[i],
                layer_past=layer_past,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)
            if output_attentions:
                # The weights follow `present` only when the cache is returned.
                all_attentions = all_attentions + (outputs[2 if use_cache else 1],)

        hidden_states = self.final_layer_norm(hidden_states)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, presents, all_hidden_states, all_attentions]
                if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
|
|
||||||
|
|
||||||
|
|
||||||
class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
    """
    GPT-NeoX decoder (`gpt_neox`) followed by the `embed_out` speculative head
    that produces the language-model logits.
    """

    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, prefix: str, config, weights):
        super().__init__(config)

        # The decoder weights live under "gpt_neox", nested below `prefix` when one is given.
        prefix = f"{prefix}.gpt_neox" if prefix else "gpt_neox"

        self.gpt_neox = GPTNeoXModel(prefix, config, weights)
        self.embed_out = SpeculativeHead.load(
            config, prefix="embed_out", weights=weights
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Run the decoder and project its first output through `embed_out`.

        All arguments except `labels` are forwarded unchanged to `self.gpt_neox`.

        Args:
            labels: optional target ids. When given, a next-token cross-entropy
                loss is computed from the logits at positions `[:-1]` against the
                labels at positions `[1:]`.
            return_dict: defaults to `config.use_return_dict` when `None`.

        Returns:
            When `return_dict` is truthy: the pair
            `(CausalLMOutputWithPast, speculative_logits)`.
            Otherwise: `(lm_logits,) + outputs[1:]`, prefixed with the loss when
            one was computed.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.gpt_neox(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        lm_logits, speculative_logits = self.embed_out(outputs[0])

        lm_loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(lm_logits.device)
            # next-token prediction: logits at step t are scored against label t + 1
            shift_logits = lm_logits[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            flat_logits = shift_logits.view(-1, shift_logits.size(-1))
            lm_loss = CrossEntropyLoss()(flat_logits, labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            if lm_loss is None:
                return output
            return (lm_loss,) + output

        causal_lm_output = CausalLMOutputWithPast(
            loss=lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        return causal_lm_output, speculative_logits

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """
        Assemble the keyword arguments for one generation step.

        Returns a dict with either `inputs_embeds` (only when they are given and
        `past_key_values` is `None`) or `input_ids`, followed by `attention_mask`,
        `past_key_values` and `position_ids`.
        """
        # Captured before the ids are truncated; used for the default attention mask.
        input_shape = input_ids.shape

        # cut decoder_input_ids if past is used
        if past_key_values and past_key_values[0] is not None:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if position_ids is None and attention_mask is not None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        first_step_with_embeds = inputs_embeds is not None and past_key_values is None
        if first_step_with_embeds:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs["attention_mask"] = attention_mask
        model_inputs["past_key_values"] = past_key_values
        model_inputs["position_ids"] = position_ids
        return model_inputs

    def _reorder_cache(self, past_key_values, beam_idx):
        """
        Reorder every layer's first two cached tensors along dim 0 by `beam_idx`;
        any further entries of a layer's tuple are kept as they are.
        """
        return tuple(
            tuple(state.index_select(0, beam_idx) for state in layer_past[:2])
            + layer_past[2:]
            for layer_past in past_key_values
        )
|
|
@ -1,857 +0,0 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
""" PyTorch OPT model."""
|
|
||||||
import random
|
|
||||||
from typing import List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.utils.checkpoint
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
from transformers.activations import ACT2FN
|
|
||||||
from transformers.modeling_outputs import (
|
|
||||||
BaseModelOutputWithPast,
|
|
||||||
CausalLMOutputWithPast,
|
|
||||||
)
|
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
|
||||||
from transformers import OPTConfig
|
|
||||||
from text_generation_server.layers import (
|
|
||||||
FastLinear,
|
|
||||||
TensorParallelColumnLinear,
|
|
||||||
TensorParallelEmbedding,
|
|
||||||
TensorParallelRowLinear,
|
|
||||||
SpeculativeHead,
|
|
||||||
)
|
|
||||||
|
|
||||||
EPS = 1e-5
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
|
||||||
def _make_causal_mask(
|
|
||||||
input_ids_shape: torch.Size,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
device: torch.device,
|
|
||||||
past_key_values_length: int = 0,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Make causal mask used for bi-directional self-attention.
|
|
||||||
"""
|
|
||||||
bsz, tgt_len = input_ids_shape
|
|
||||||
mask = torch.full(
|
|
||||||
(tgt_len, tgt_len),
|
|
||||||
torch.tensor(torch.finfo(dtype).min, device=device),
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
mask_cond = torch.arange(mask.size(-1), device=device)
|
|
||||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
|
||||||
mask = mask.to(dtype)
|
|
||||||
|
|
||||||
if past_key_values_length > 0:
|
|
||||||
mask = torch.cat(
|
|
||||||
[
|
|
||||||
torch.zeros(
|
|
||||||
tgt_len, past_key_values_length, dtype=dtype, device=device
|
|
||||||
),
|
|
||||||
mask,
|
|
||||||
],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
return mask[None, None, :, :].expand(
|
|
||||||
bsz, 1, tgt_len, tgt_len + past_key_values_length
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
|
||||||
"""
|
|
||||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
|
||||||
"""
|
|
||||||
bsz, src_len = mask.size()
|
|
||||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
|
||||||
|
|
||||||
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
|
||||||
|
|
||||||
inverted_mask = 1.0 - expanded_mask
|
|
||||||
|
|
||||||
return inverted_mask.masked_fill(
|
|
||||||
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class OPTLearnedPositionalEmbedding(nn.Module):
    """
    Learned position embeddings looked up from a single weight table loaded from `weights`.
    """

    def __init__(self, prefix: str, weights):
        super().__init__()
        # Constant shift added to every position id before the table lookup in `forward`.
        self.offset = 2
        dotted_prefix = prefix + "." if prefix else ""
        table = weights.get_tensor(f"{dotted_prefix}decoder.embed_positions.weight")
        self.weight = nn.Parameter(table)

    def forward(
        self, attention_mask: torch.LongTensor, past_key_values_length: int = 0
    ):
        """
        Return the position embeddings for the non-cached part of the sequence.

        Positions are derived from `attention_mask` (indexed as `[bsz, seq_len]`): a running
        count of the unmasked tokens, minus one, with masked slots ending up at -1.
        The first `past_key_values_length` columns are dropped before the lookup.
        """
        mask = attention_mask.long()

        # Running count of real tokens, zeroed again wherever the mask is 0.
        running_count = torch.cumsum(mask, dim=1).type_as(mask)
        positions = (running_count * mask).long() - 1

        # Only the positions that are not already covered by the cache are needed.
        positions = positions[:, past_key_values_length:]

        return torch.nn.functional.embedding(positions + self.offset, self.weight)
|
|
||||||
|
|
||||||
|
|
||||||
class OPTAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Heads are split across `weights.process_group`: after construction `num_heads` and
    `hidden_size` hold the per-shard values, while `head_dim` is unchanged.
    """

    def __init__(
        self,
        config,
        prefix,
        weights,
        is_decoder: bool = False,
        bias: bool = True,
        process_group=None,
    ):
        """
        Args:
            config: model config; `hidden_size`, `num_attention_heads` and `dropout` are read.
            prefix: weight-name prefix of this attention module.
            weights: weight loader; also provides the process group used for sharding.
            is_decoder: when True, `forward` returns the key/value states for caching.
            bias: whether the four projections load a bias.
            process_group: kept for signature compatibility only; the group is always
                taken from `weights.process_group`.

        Raises:
            ValueError: if `hidden_size` is not divisible by the number of heads, or the
                number of heads is not divisible by the process-group size.
        """
        super().__init__()
        hidden_size = config.hidden_size
        num_heads = config.num_attention_heads

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.dropout = config.dropout
        self.head_dim = hidden_size // num_heads

        if (self.head_dim * num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        process_group = weights.process_group
        num_shards = process_group.size()
        if self.num_heads % num_shards != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {num_shards})"
            )
        # From here on both values describe this shard only.
        self.num_heads = self.num_heads // num_shards
        self.hidden_size = self.hidden_size // num_shards

        self.q_proj = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.q_proj", weights=weights, bias=bias
        )
        self.k_proj = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.k_proj", weights=weights, bias=bias
        )
        self.v_proj = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.v_proj", weights=weights, bias=bias
        )
        self.out_proj = TensorParallelRowLinear.load(
            config, prefix=f"{prefix}.out_proj", weights=weights, bias=bias
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape `(bsz, seq_len, heads * head_dim)` into a contiguous `(bsz, heads, seq_len, head_dim)`."""
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: query source, `(bsz, tgt_len, channels)`.
            key_value_states: when given, the layer acts as cross-attention over these states.
            past_key_value: cached `(key, value)` pair to reuse / extend.
            attention_mask: additive mask of shape `(bsz, 1, tgt_len, src_len)`.
            layer_head_mask: per-head multiplier of shape `(num_heads,)` (per-shard head count).
            output_attentions: also return the attention probabilities.

        Returns:
            `(attn_output, attn_weights or None, past_key_value)`; the last item is the
            current `(key, value)` pair when `is_decoder` is set, otherwise the value
            that was passed in.

        Raises:
            ValueError: if an intermediate tensor or one of the masks has an unexpected size.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # Hand the (possibly extended) key/value states back so the next call can
            # reuse them: cross-attention reuses them as-is (first branch above),
            # self-attention concatenates the newly projected states (third branch).
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = (
                attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
                + attention_mask
            )
            # Adding two large negative masks can overflow; clamp back to the dtype minimum.
            # The bound lives on the same device/dtype as the weights so no implicit
            # cross-device or mixed-dtype operand is involved.
            attn_weights = torch.max(
                attn_weights,
                torch.tensor(
                    torch.finfo(attn_weights.dtype).min,
                    dtype=attn_weights.dtype,
                    device=attn_weights.device,
                ),
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
        if attn_weights.dtype == torch.float16:
            attn_weights = nn.functional.softmax(
                attn_weights, dim=-1, dtype=torch.float32
            ).to(torch.float16)
        else:
            attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
                bsz, self.num_heads, tgt_len, src_len
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(
                bsz, self.num_heads, tgt_len, src_len
            )
            attn_weights = attn_weights_reshaped.view(
                bsz * self.num_heads, tgt_len, src_len
            )
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `hidden_size` stored on the module rather than the one of `hidden_states`,
        # because `attn_output` can be partitioned across devices when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.hidden_size)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
|
|
||||||
|
|
||||||
|
|
||||||
class OPTDecoderLayer(nn.Module):
    """One decoder block: self-attention followed by a two-layer feed-forward network,
    each wrapped in a residual connection with layer norm applied either before or
    after the sub-block depending on `config.do_layer_norm_before`."""

    def __init__(self, layer_id: int, prefix: str, config: OPTConfig, weights):
        super().__init__()
        self.process_group = weights.process_group
        self.hidden_size = config.hidden_size

        # All weights of this block live under "<prefix.>decoder.layers.<layer_id>".
        root = prefix + "." if prefix else ""
        prefix = f"{root}decoder.layers.{layer_id}"

        use_bias = config.enable_bias
        self.self_attn = OPTAttention(
            config,
            prefix=f"{prefix}.self_attn",
            weights=weights,
            is_decoder=True,
            bias=use_bias,
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]

        self.self_attn_layer_norm = nn.LayerNorm.load(
            prefix=f"{prefix}.self_attn_layer_norm", weights=weights, eps=EPS
        )
        self.fc1 = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.fc1", weights=weights, bias=use_bias
        )
        self.fc2 = TensorParallelRowLinear.load(
            config, prefix=f"{prefix}.fc2", weights=weights, bias=use_bias
        )
        self.final_layer_norm = nn.LayerNorm.load(
            prefix=f"{prefix}.final_layer_norm", weights=weights, eps=EPS
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Run the block on `hidden_states`.

        Args:
            hidden_states (`torch.FloatTensor`): layer input; its shape is restored on output.
            attention_mask (`torch.FloatTensor`, *optional*): additive mask forwarded to the
                attention module.
            layer_head_mask (`torch.FloatTensor`, *optional*): per-head mask forwarded to the
                attention module.
            output_attentions (`bool`, *optional*): append the attention weights to the result.
            use_cache (`bool`, *optional*): append the present key/value states to the result.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached key/value states.

        Returns:
            Tuple starting with the new hidden states, followed by the attention weights
            (if requested) and the present key/value pair (if requested), in that order.
        """
        pre_norm = self.do_layer_norm_before

        # ---- self-attention sub-block ----
        skip = hidden_states
        if pre_norm:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )
        hidden_states = skip + hidden_states

        if not pre_norm:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # ---- feed-forward sub-block (run on a flattened 2-D view) ----
        original_shape = hidden_states.shape
        hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
        skip = hidden_states

        if pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)

        hidden_states = self.fc2(self.activation_fn(self.fc1(hidden_states)))
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )
        hidden_states = (skip + hidden_states).view(original_shape)

        if not pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)

        # ---- assemble the result tuple ----
        outputs = (hidden_states,)
        if output_attentions:
            outputs = outputs + (self_attn_weights,)
        if use_cache:
            outputs = outputs + (present_key_value,)
        return outputs
|
|
||||||
|
|
||||||
|
|
||||||
class OPTPreTrainedModel(PreTrainedModel):
    # Shared base class for the OPT modules below; it only binds them to `OPTConfig`.
    config_class = OPTConfig
|
|
||||||
|
|
||||||
|
|
||||||
class OPTDecoder(OPTPreTrainedModel):
    """Stack of `config.num_hidden_layers` `OPTDecoderLayer` blocks with token and position
    embeddings, optional in/out projections and an optional final layer norm."""

    def __init__(self, prefix: str, config: OPTConfig, weights):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        # `OPTLearnedPositionalEmbedding` and `OPTDecoderLayer` append the "." separator
        # themselves, so they must receive the raw prefix; handing them the dotted one
        # would yield names such as "model..decoder.layers.0" for a non-empty prefix.
        module_prefix = prefix
        prefix = prefix + "." if prefix else ""

        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{prefix}decoder.embed_tokens", weights=weights
        )
        self.embed_positions = OPTLearnedPositionalEmbedding(module_prefix, weights)

        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = FastLinear.load(
                config,
                prefix=f"{prefix}decoder.project_out",
                weights=weights,
                bias=False,
            )
        else:
            self.project_out = None

        if config.word_embed_proj_dim != config.hidden_size:
            self.project_in = FastLinear.load(
                config,
                prefix=f"{prefix}decoder.project_in",
                weights=weights,
                bias=False,
            )
        else:
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
        # with checkpoints that have been fine-tuned before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm.load(
                prefix=f"{prefix}decoder.final_layer_norm", weights=weights, eps=EPS
            )
        else:
            self.final_layer_norm = None

        self.layers = nn.ModuleList(
            [
                OPTDecoderLayer(layer_id, module_prefix, config, weights)
                for layer_id in range(config.num_hidden_layers)
            ]
        )

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        """Combine the causal mask (only built when more than one token is processed) with
        the expanded padding mask; returns `None` when neither applies."""
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            ).to(inputs_embeds.device)
            combined_attention_mask = (
                expanded_attn_mask
                if combined_attention_mask is None
                else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Mutually exclusive with
                `inputs_embeds`.
            attention_mask (`torch.Tensor`, *optional*):
                Padding mask with 1 for tokens that are **not masked** and 0 for tokens that are
                **masked**. It has to cover the cached positions as well; when omitted an all-ones
                mask of length `past_length + sequence_length` is used.
            head_mask (`torch.Tensor`, *optional*):
                Mask to nullify selected attention heads. Its first dimension must equal the
                number of decoder layers; row `idx` is handed to layer `idx`.
            past_key_values (*optional*):
                One cached `(key, value)` pair per layer, as returned by a previous call with
                `use_cache=True`. The cached length is read from `past_key_values[0][0].shape[2]`.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Embedded inputs that bypass the token-embedding lookup.
            use_cache (`bool`, *optional*):
                Return the per-layer key/value states. Defaults to `config.use_cache`.
            output_attentions (`bool`, *optional*):
                Return the attention weights of every layer. Defaults to `config.output_attentions`.
            output_hidden_states (`bool`, *optional*):
                Return the hidden states fed to every layer plus the final ones. Defaults to
                `config.output_hidden_states`.
            return_dict (`bool`, *optional*):
                Return a `BaseModelOutputWithPast` instead of a plain tuple (from which `None`
                entries are dropped). Defaults to `config.use_return_dict`.

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds` are given, or if
                `head_mask` does not have one row per layer.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError(
                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
            )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape
        past_key_values_length = (
            past_key_values[0][0].shape[2] if past_key_values is not None else 0
        )
        # required mask seq length can be calculated via length of past
        mask_seq_length = past_key_values_length + seq_length

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                batch_size, mask_seq_length, device=inputs_embeds.device
            )
        causal_attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )
        pos_embeds = self.embed_positions(attention_mask, past_key_values_length)

        if self.project_in is not None:
            inputs_embeds = self.project_in(inputs_embeds)

        hidden_states = inputs_embeds + pos_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None and head_mask.size()[0] != len(self.layers):
            raise ValueError(
                f"The `head_mask` should be specified for {len(self.layers)} layers, but it is for"
                f" {head_mask.size()[0]}."
            )

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):
                continue

            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_attention_mask,
                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache entry comes after the attention weights when those are returned.
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)

        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
|
|
||||||
|
|
||||||
|
|
||||||
class OPTModel(OPTPreTrainedModel):
    """Bare OPT model: owns an `OPTDecoder` and forwards every call to it."""

    def __init__(self, prefix: str, config: OPTConfig, weights):
        super().__init__(config)
        self.decoder = OPTDecoder(prefix, config, weights)
        # Initialize weights and apply final processing

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder.

        Every flag left as `None` falls back to the matching value on
        `self.config`. When `return_dict` resolves to a falsy value the
        decoder's output is returned untouched; otherwise it is repackaged as
        a `BaseModelOutputWithPast`.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if use_cache is None:
            use_cache = self.config.use_cache
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if return_dict:
            return BaseModelOutputWithPast(
                last_hidden_state=decoder_outputs.last_hidden_state,
                past_key_values=decoder_outputs.past_key_values,
                hidden_states=decoder_outputs.hidden_states,
                attentions=decoder_outputs.attentions,
            )
        return decoder_outputs
|
|
||||||
|
|
||||||
|
|
||||||
class OPTForCausalLM(OPTPreTrainedModel):
    """OPT decoder stack followed by a `SpeculativeHead` language-model head."""

    def __init__(self, prefix, config, weights):
        super().__init__(config)

        self.model = OPTModel(prefix, config, weights)

        # The head is loaded from the token-embedding prefix.
        self.lm_head = SpeculativeHead.load(
            config,
            prefix=f"{prefix + '.' if prefix else ''}decoder.embed_tokens",
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the decoder and project its last hidden state to logits.

        Returns:
            A 2-tuple `(CausalLMOutputWithPast, speculative_logits)`. The loss
            field is always `None`; `labels` is accepted but not used.

        Note:
            `return_dict` is accepted for signature compatibility only. The
            result is always the structured output above, so the decoder is
            always asked for its structured output too. Forwarding a falsy
            `return_dict` made the decoder return a plain tuple, and the
            attribute accesses below then raised `AttributeError`.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        logits, speculative_logits = self.lm_head(outputs.last_hidden_state)

        loss = None

        return (
            CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            ),
            speculative_logits,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """Build the keyword arguments for one generation step.

        With a non-empty cache only the last token of `input_ids` is kept.
        `inputs_embeds` is used instead of `input_ids` only when no cache was
        passed at all (i.e. on the first step).
        """
        if past_key_values:
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Re-index every cached tensor along dim 0 with `beam_idx`.

        Returns a tuple with one inner tuple per layer.
        """
        return tuple(
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
            for layer_past in past_key_values
        )
|
|
@ -1,336 +0,0 @@
|
|||||||
# implementation of the PhiModel and PhiForCausalLM classes
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed
|
|
||||||
|
|
||||||
import math
|
|
||||||
from torch import nn
|
|
||||||
from typing import Optional, List, Tuple
|
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
||||||
|
|
||||||
from text_generation_server.layers import (
|
|
||||||
TensorParallelRowLinear,
|
|
||||||
TensorParallelColumnLinear,
|
|
||||||
TensorParallelEmbedding,
|
|
||||||
SpeculativeHead,
|
|
||||||
FastLinear,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# PhiConfig is the configuration class for the PhiModel.
|
|
||||||
class PhiConfig(PretrainedConfig):
    """Hyper-parameters of the Phi model.

    Each constructor argument is stored under the same name on the instance.
    The token ids and `tie_word_embeddings` are additionally forwarded to
    `PretrainedConfig.__init__` together with any extra keyword arguments.
    """

    def __init__(
        self,
        vocab_size=51200,
        n_positions=2048,
        n_embd=2560,
        n_layer=32,
        n_inner=None,
        n_head=32,
        rotary_dim=32,
        layer_norm_epsilon=1e-5,
        tie_word_embeddings=False,
        pad_vocab_size_multiple=64,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        no_bias=False,
        **kwargs,
    ):
        # Same attributes, same assignment order as a plain list of
        # `self.x = x` statements (dicts preserve insertion order).
        own_fields = {
            "vocab_size": vocab_size,
            "n_positions": n_positions,
            "n_embd": n_embd,
            "n_layer": n_layer,
            "n_inner": n_inner,
            "n_head": n_head,
            "rotary_dim": rotary_dim,
            "layer_norm_epsilon": layer_norm_epsilon,
            "tie_word_embeddings": tie_word_embeddings,
            "pad_vocab_size_multiple": pad_vocab_size_multiple,
            "pad_token_id": pad_token_id,
            "bos_token_id": bos_token_id,
            "eos_token_id": eos_token_id,
            "no_bias": no_bias,
        }
        for field_name, field_value in own_fields.items():
            setattr(self, field_name, field_value)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
|
|
||||||
|
|
||||||
|
|
||||||
# RotaryEmbedding is a class that implements the rotary embedding.
|
|
||||||
class RotaryEmbedding(nn.Module):
    """Rotary position embedding with cos/sin tables precomputed at construction.

    The tables have shape `(max_seq_len, dim / 2)` and are kept as plain
    float32 tensor attributes (`self.cos`, `self.sin`), not as registered
    buffers.
    """

    def __init__(self, dim, max_seq_len):
        super().__init__()
        inv_freq = [1.0 / 10000.0 ** (i / dim) for i in range(0, dim, 2)]
        inv_freq = torch.tensor(inv_freq).view(1, len(inv_freq))
        positions = torch.arange(0, max_seq_len, dtype=torch.float).view(
            max_seq_len, 1
        )
        # Outer product: one row per position, one column per frequency.
        freqs = positions.matmul(inv_freq)
        self.sin = freqs.sin()
        self.cos = freqs.cos()

    def apply_rotary_emb_qkv(self, qkv, seqlen_offset):
        """Rotate the query and key slices of a packed QKV tensor.

        Args:
            qkv: tensor of shape `(batch, seqlen, 3, heads, head_dim)`.
            seqlen_offset: position of the first token of `qkv` in the
                sequence; the tables are read from this row onwards.

        Returns:
            `(q, k, v)`, each of shape `(batch, seqlen, heads, head_dim)`.
            Only the first `2 * cos.shape[1]` features of q and k are rotated;
            the rest, and all of v, pass through unchanged.

        Raises:
            Exception: if the third dimension of `qkv` is not 3.
        """
        _b_size, seqlen, three, _heads, _headdim = qkv.shape
        if three != 3:
            raise Exception("unexpected shape for qkv")

        rotary_dim = self.cos.shape[1] * 2
        queries, keys = qkv[:, :, 0], qkv[:, :, 1]
        q_rot, q_pass = queries[..., :rotary_dim], queries[..., rotary_dim:]
        k_rot, k_pass = keys[..., :rotary_dim], keys[..., rotary_dim:]
        q1, q2 = torch.chunk(q_rot, 2, dim=-1)
        k1, k2 = torch.chunk(k_rot, 2, dim=-1)

        # The tables are float32 tensors created on the default device. Cast
        # the slices to qkv's device/dtype so that q and k keep the same dtype
        # as v and the multiplications never mix devices.
        c = (
            self.cos.narrow(0, seqlen_offset, seqlen)
            .unsqueeze(1)
            .to(device=qkv.device, dtype=qkv.dtype)
        )
        s = (
            self.sin.narrow(0, seqlen_offset, seqlen)
            .unsqueeze(1)
            .to(device=qkv.device, dtype=qkv.dtype)
        )

        q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], dim=-1)
        k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], dim=-1)

        q = torch.cat([q_rot, q_pass], dim=-1)
        k = torch.cat([k_rot, k_pass], dim=-1)
        v = qkv[:, :, 2]
        return q, k, v
|
|
||||||
|
|
||||||
|
|
||||||
# PhiCausalLMHead is the head of the PhiModel. It is a linear layer with a layer norm.
|
|
||||||
class PhiCausalLMHead(nn.Module):
    """Output head: a layer norm (`lm_head.ln`) followed by the projection
    loaded from `lm_head.linear`."""

    def __init__(self, config, weights):
        super().__init__()
        self.ln = nn.LayerNorm.load(
            prefix="lm_head.ln",
            weights=weights,
            eps=config.layer_norm_epsilon,
        )
        self.linear = SpeculativeHead.load(
            config=config, prefix="lm_head.linear", weights=weights
        )

    def forward(self, hidden_states):
        """Normalise `hidden_states`, then return whatever `self.linear` yields."""
        normed = self.ln(hidden_states)
        return self.linear(normed)
|
|
||||||
|
|
||||||
|
|
||||||
# PhiMHA is a multi-head attention layer. This layer uses an attention mask to prevent tokens from attending to subsequent tokens.
|
|
||||||
class PhiMHA(nn.Module):
    """Multi-head self-attention with rotary embeddings and a concatenating KV cache."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        use_bias = not config.no_bias
        self.Wqkv = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.Wqkv", weights=weights, bias=use_bias
        )
        self.out_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.out_proj",
            weights=weights,
            bias=use_bias,
        )
        self.op_size = config.n_embd
        self.head_dim = int(config.n_embd / config.n_head)
        self.num_heads = config.n_head
        self.rotary_emb = RotaryEmbedding(
            config.rotary_dim,
            config.n_positions,
        )
        self.softmax_scale = 1.0 / math.sqrt(self.head_dim)

    def forward(
        self,
        hidden_states,
        past_kv_cache,
        attention_mask=None,
    ):
        """Attend over the cached plus current tokens.

        Args:
            hidden_states: tensor of shape `(batch, seq_len, n_embd)`.
            past_kv_cache: `None`, or a `(k, v)` pair whose tensors are
                concatenated with the new keys/values along dim 1.
            attention_mask: only tested for presence. When it is not `None`
                a causal mask is applied; its contents are never read.

        Returns:
            `(output, [k, v])`, where `output` is the `out_proj` of the
            attention result and `[k, v]` is the updated cache.
        """
        b_size, seq_len, _n_embd = hidden_states.shape
        qkv = self.Wqkv(hidden_states).view(
            b_size, seq_len, 3, self.num_heads, self.head_dim
        )
        seqlen_offset = 0 if past_kv_cache is None else past_kv_cache[0].shape[1]
        q, k, v = self.rotary_emb.apply_rotary_emb_qkv(qkv, seqlen_offset)

        # if there is a kv_cache, then we need to concatenate
        if past_kv_cache is not None:
            prev_k, prev_v = past_kv_cache
            k = torch.cat([prev_k, k], dim=1)
            v = torch.cat([prev_v, v], dim=1)

        past_kv_cache = [k, v]
        attn_weights = torch.einsum("bthd,bshd->bhts", q, k * self.softmax_scale)

        if attention_mask is not None:
            seqlen_k = k.shape[1]
            seqlen_q = q.shape[1]
            # Query i sits at absolute position (seqlen_k - seqlen_q + i) and
            # may see every key up to and including that position. The
            # diagonal therefore has to be shifted by the number of cached
            # keys; a fixed diagonal of 1 would hide cached keys whenever
            # more than one new token is processed on top of a cache.
            causal_mask = torch.triu(
                torch.full((seqlen_q, seqlen_k), -10000.0, device=attn_weights.device),
                seqlen_k - seqlen_q + 1,
            )
            attn_weights = attn_weights + causal_mask.to(dtype=attn_weights.dtype)

        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
        # (batch, heads, seq_len, head_dim) -> (batch, seq_len, heads * head_dim)
        attn_output = attn_weights.matmul(v.transpose(1, 2))
        attn_output = attn_output.transpose(1, 2).flatten(-2)
        return self.out_proj(attn_output), past_kv_cache
|
|
||||||
|
|
||||||
|
|
||||||
# PhiMLP is a multi-layer perceptron. It contains two linear layers with a gelu activation function.
|
|
||||||
class PhiMLP(nn.Module):
    """Feed-forward sub-layer: `fc1`, GELU, then `fc2` (both loaded without bias)."""

    def __init__(self, prefix, config, weights):
        super().__init__()

        self.n_inner = config.n_inner
        # Both projections are loaded the same way; only the weight prefix differs.
        shared = dict(config=config, weights=weights, bias=False)
        self.fc1 = FastLinear.load(prefix=f"{prefix}.fc1", **shared)
        self.fc2 = FastLinear.load(prefix=f"{prefix}.fc2", **shared)
        self.activation = torch.nn.functional.gelu

    def forward(self, hidden_states):
        """Return `fc2(activation(fc1(hidden_states)))`."""
        return self.fc2(self.activation(self.fc1(hidden_states)))
|
|
||||||
|
|
||||||
|
|
||||||
# PhiBlock is a single transformer block. It contains a layer norm, a multi-head attention layer and a multi-layer perceptron.
|
|
||||||
class PhiBlock(nn.Module):
    """One transformer block.

    Attention and MLP both read the same layer-normed input, and their
    outputs are summed with the un-normalised input as a residual.
    """

    def __init__(self, layer_id, config, weights):
        super().__init__()
        self.layer_id = layer_id
        self.layer_norm = nn.LayerNorm.load(
            prefix=f"{layer_id}.ln", weights=weights, eps=config.layer_norm_epsilon
        )
        self.mixer = PhiMHA(prefix=f"{layer_id}.mixer", config=config, weights=weights)
        self.mlp = PhiMLP(prefix=f"{layer_id}.mlp", config=config, weights=weights)

    def forward(
        self,
        hidden_states,
        kv_cache,
        attention_mask,
    ):
        """Return `(block_output, updated_kv_cache)`."""
        normed = self.layer_norm(hidden_states)
        attn_out, new_cache = self.mixer(normed, kv_cache, attention_mask)
        mlp_out = self.mlp(normed)
        return attn_out + mlp_out + hidden_states, new_cache
|
|
||||||
|
|
||||||
|
|
||||||
# PhiModel implements the embedding layer and the transformer blocks.
|
|
||||||
class PhiModel(nn.Module):
    """Token embedding followed by `config.n_layer` `PhiBlock`s."""

    def __init__(self, prefix: str, config, weights):
        super().__init__()
        self.tp_rank = weights.process_group.rank()
        self.tp_world_size = weights.process_group.size()
        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{prefix}.embd.wte", weights=weights
        )
        self.blocks = nn.ModuleList(
            [
                PhiBlock(f"{prefix}.h.{layer_id}", config, weights)
                for layer_id in range(config.n_layer)
            ]
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """Embed `input_ids` and run every block in order.

        Args:
            input_ids: token ids; the embedding output's dim 1 is taken as
                the sequence length.
            past_key_values: optional per-layer cache entries (any sequence,
                list or tuple). The container passed in is never modified.
            attention_mask: forwarded to the blocks unless the sequence
                length is 1 or less, in which case `None` is forwarded.
            return_dict, use_cache: accepted but not used here.

        Returns:
            `(hidden_states, caches)`, where `caches` is a new list holding
            one updated cache entry per block.
        """
        hidden_states = self.embed_tokens(input_ids)
        seq_len = hidden_states.shape[1]
        mask = None if seq_len <= 1 else attention_mask

        # Always work on a private list. Assigning into the caller's
        # container mutated it behind their back and raised TypeError when
        # the cache was handed over as a tuple.
        if past_key_values is None:
            layer_caches = [None] * len(self.blocks)
        else:
            layer_caches = list(past_key_values)

        for index, block in enumerate(self.blocks):
            hidden_states, layer_caches[index] = block(
                hidden_states, layer_caches[index], mask
            )

        return hidden_states, layer_caches
|
|
||||||
|
|
||||||
|
|
||||||
# PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object.
|
|
||||||
class PhiForCausalLM(torch.nn.Module):
    """`PhiModel` plus `PhiCausalLMHead`, with an optional next-token loss."""

    def __init__(self, prefix: str, config, weights):
        super().__init__()

        # Weights live under "<prefix>.transformer", or "transformer" when no
        # prefix is given.
        prefix = f"{prefix}.transformer" if prefix else "transformer"

        self.model = PhiModel(prefix, config, weights)
        self.lm_head = PhiCausalLMHead(config, weights)

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run the model and the head, optionally computing a shifted loss.

        When `labels` is given, the loss is the cross-entropy between the
        logits at positions `[:-1]` and the labels at positions `[1:]`.

        Returns:
            If `return_dict` is falsy (including the default `None`): the
            tuple `(loss, logits, past_key_values)`, with `loss` omitted when
            it is `None`. Otherwise a `CausalLMOutputWithPast`.
        """
        model_output = self.model(
            input_ids, past_key_values, attention_mask, return_dict, use_cache
        )
        logits = self.lm_head(model_output[0])

        loss = None
        if labels is not None:
            # The shifted slices are non-contiguous for batch sizes above 1,
            # where `.view` raises; `reshape` copies only when it has to.
            shifted_logits = logits[:, :-1].reshape(-1, logits.size(-1))
            shifted_labels = labels[:, 1:].reshape(-1)
            loss = nn.CrossEntropyLoss()(shifted_logits, shifted_labels)

        if not return_dict:
            leading = (logits,) if loss is None else (loss, logits)
            return leading + model_output[1:]

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=model_output[1],
            hidden_states=None,
            attentions=None,
        )
|
|
@ -0,0 +1,946 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""PyTorch Qwen2.5 VL model."""
|
||||||
|
|
||||||
|
from typing import Optional, Tuple, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from habana_frameworks.torch.hpex.kernels import FusedSDPA
|
||||||
|
from vllm_hpu_extension.utils import ModuleFusedSDPA
|
||||||
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from text_generation_server.layers.layernorm import FastRMSNorm
|
||||||
|
from text_generation_server.layers import (
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
TensorParallelRowLinear,
|
||||||
|
TensorParallelEmbedding,
|
||||||
|
SpeculativeHead,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.attention import (
|
||||||
|
Seqlen,
|
||||||
|
HPUPagedAttentionMetadata,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
|
||||||
|
Qwen2Model,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py
|
||||||
|
from typing import Union
|
||||||
|
from transformers.feature_extraction_utils import BatchFeature
|
||||||
|
from transformers.image_utils import ImageInput, VideoInput
|
||||||
|
from transformers.processing_utils import (
|
||||||
|
ProcessingKwargs,
|
||||||
|
ProcessorMixin,
|
||||||
|
Unpack,
|
||||||
|
VideosKwargs,
|
||||||
|
)
|
||||||
|
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False):
    """Video keyword arguments for the Qwen2.5-VL processor, extended with `fps`."""

    # Either a single number or a list of numbers.
    fps: Union[List[float], float]
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
    """Keyword-argument groups for `Qwen2_5_VLProcessor`, with their defaults."""

    videos_kwargs: Qwen2_5_VLVideosProcessorKwargs
    # Defaults: no text padding, videos at 2 fps.
    _defaults = {
        "text_kwargs": {"padding": False},
        "videos_kwargs": {"fps": 2.0},
    }
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2_5_VLProcessor(ProcessorMixin):
    r"""
    Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor.
    [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
    [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information.
    Args:
        image_processor ([`Qwen2VLImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`Qwen2TokenizerFast`], *optional*):
            The tokenizer is a required input.
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
    """

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template"]

    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(
        self, image_processor=None, tokenizer=None, chat_template=None, **kwargs
    ):
        # Use the tokenizer's own placeholder tokens when it defines them.
        self.image_token = getattr(tokenizer, "image_token", "<|image_pad|>")
        self.video_token = getattr(tokenizer, "video_token", "<|video_pad|>")
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    def _expand_media_tokens(self, text, token, grid_thw):
        """Expand every `token` in `text` (a list, edited in place) to one
        copy per merged patch of the matching `grid_thw` entry.

        The n-th occurrence across all strings uses `grid_thw[n]` and becomes
        `grid_thw[n].prod() // merge_size**2` copies of `token`.
        """
        merge_length = self.image_processor.merge_size**2
        index = 0
        for i in range(len(text)):
            # Go through a temporary placeholder so that already expanded
            # tokens are not matched again by the `while` condition.
            while token in text[i]:
                text[i] = text[i].replace(
                    token,
                    "<|placeholder|>" * (grid_thw[index].prod() // merge_length),
                    1,
                )
                index += 1
            text[i] = text[i].replace("<|placeholder|>", token)

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[
            TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
        ] = None,
        videos: VideoInput = None,
        **kwargs: Unpack[Qwen2_5_VLProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to
        Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
                A list passed here is copied before placeholders are expanded; the caller's list is left untouched.
            videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
                tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
            - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
            - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
            - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`.

        Raises:
            ValueError: if `fps` is neither a single number nor a sequence with one entry per video.
        """
        output_kwargs = self._merge_kwargs(
            Qwen2_5_VLProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        if images is not None:
            image_inputs = self.image_processor(
                images=images, videos=None, **output_kwargs["images_kwargs"]
            )
            image_grid_thw = image_inputs["image_grid_thw"]
        else:
            image_inputs = {}
            image_grid_thw = None

        if videos is not None:
            videos_inputs = self.image_processor(
                images=None, videos=videos, **output_kwargs["images_kwargs"]
            )
            video_grid_thw = videos_inputs["video_grid_thw"]

            # Seconds covered by one temporal grid step, per video.
            fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
            if isinstance(fps, (int, float)):
                second_per_grid_ts = [
                    self.image_processor.temporal_patch_size / fps
                ] * len(video_grid_thw)
            elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
                second_per_grid_ts = [
                    self.image_processor.temporal_patch_size / tmp for tmp in fps
                ]
            else:
                raise ValueError(
                    f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
                )
            videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})

        else:
            videos_inputs = {}
            video_grid_thw = None

        if not isinstance(text, list):
            text = [text]
        else:
            # The expansion below rewrites list entries in place; work on a
            # copy so the caller's list is not modified.
            text = text.copy()

        if image_grid_thw is not None:
            self._expand_media_tokens(text, self.image_token, image_grid_thw)

        if video_grid_thw is not None:
            self._expand_media_tokens(text, self.video_token, video_grid_thw)

        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])

        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def post_process_image_text_to_text(self, generated_outputs):
        """
        Post-process the output of the model to decode the text.

        Args:
            generated_outputs (`torch.Tensor` or `np.ndarray`):
                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
                or `(sequence_length,)`.

        Returns:
            `List[str]`: The decoded text.
        """
        return self.tokenizer.batch_decode(
            generated_outputs,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

    @property
    def model_input_names(self):
        """Tokenizer and image-processor input names, de-duplicated in order,
        followed by `"second_per_grid_ts"`."""
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        names_from_processor = list(
            dict.fromkeys(tokenizer_input_names + image_processor_input_names)
        )
        return names_from_processor + ["second_per_grid_ts"]
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py
|
||||||
|
class Qwen2_5_VLVisionConfig(PretrainedConfig):
    """Configuration of the Qwen2.5-VL vision tower.

    Every constructor argument is stored under the same name on the instance;
    extra keyword arguments go to `PretrainedConfig.__init__`.
    """

    model_type = "qwen2_5_vl"
    base_config_key = "vision_config"

    def __init__(
        self,
        depth=32,
        hidden_size=3584,
        hidden_act="silu",
        intermediate_size=3420,
        num_heads=16,
        in_channels=3,
        patch_size=14,
        spatial_merge_size=2,
        spatial_patch_size=14,
        temporal_patch_size=2,
        tokens_per_second=4,
        window_size=112,
        out_hidden_size=3584,
        fullatt_block_indexes=[7, 15, 23, 31],
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.depth = depth
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.num_heads = num_heads
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.spatial_patch_size = spatial_patch_size
        self.spatial_merge_size = spatial_merge_size
        self.temporal_patch_size = temporal_patch_size
        self.tokens_per_second = tokens_per_second
        self.window_size = window_size
        # The default list in the signature is one object shared by every
        # call. Store a copy so that editing one config's list cannot change
        # the default seen by later instances (or the caller's own list).
        self.fullatt_block_indexes = (
            list(fullatt_block_indexes) if fullatt_block_indexes is not None else None
        )
        self.out_hidden_size = out_hidden_size
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2_5_VLConfig(PretrainedConfig):
    """Top-level Qwen2.5-VL configuration (language model plus vision tower).

    Text-model arguments are stored on the instance under their own names.

    Args:
        vision_config: settings of the vision tower. May be a mapping of
            ``Qwen2_5_VLVisionConfig`` keyword arguments, an already built
            ``Qwen2_5_VLVisionConfig``, or ``None`` (a default vision config
            is created so that ``self.vision_config`` always exists).
        rope_scaling: optional dict. A legacy ``"type"`` entry is copied to
            ``"rope_type"``; the value ``"mrope"`` is rewritten to
            ``"default"`` first. The dict is updated in place.
        num_key_value_heads: falls back to ``num_attention_heads`` when ``None``.
        **kwargs: forwarded to ``PretrainedConfig`` together with
            ``tie_word_embeddings``.
    """

    def __init__(
        self,
        vocab_size=152064,
        hidden_size=8192,
        intermediate_size=29568,
        num_hidden_layers=80,
        num_attention_heads=64,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-05,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=1000000.0,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=80,
        attention_dropout=0.0,
        vision_config=None,
        rope_scaling=None,
        **kwargs,
    ):
        # Always expose `vision_config`: the model code reads
        # `config.vision_config` unconditionally.
        if vision_config is None:
            self.vision_config = Qwen2_5_VLVisionConfig()
        elif isinstance(vision_config, Qwen2_5_VLVisionConfig):
            self.vision_config = vision_config
        else:
            self.vision_config = Qwen2_5_VLVisionConfig(**vision_config)

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window
        self.max_window_layers = max_window_layers

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.rope_scaling = rope_scaling

        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        # and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations
        # one can set it to "linear"/"dynamic" etc. to have scaled RoPE
        # TODO: @raushan update config in the hub
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            if self.rope_scaling["type"] == "mrope":
                self.rope_scaling["type"] = "default"
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
||||||
|
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    The last dimension is cut at ``size // 2``; the trailing part is negated
    and moved in front of the leading part.
    """
    cut = x.shape[-1] // 2
    leading, trailing = x[..., :cut], x[..., cut:]
    return torch.cat((-trailing, leading), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb_vision(
    tensor: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
    """Apply rotary position embeddings built from ``freqs`` to ``tensor``.

    The computation runs in float32 and the result is cast back to the dtype
    ``tensor`` had on entry.
    """
    # presumably `freqs` is (seq, dim // 2) and `tensor` is (1, seq, heads, dim)
    # — TODO confirm against callers.
    input_dtype = tensor.dtype
    working = tensor.float()

    def _broadcastable(table):
        # Add a singleton axis at dim 1, duplicate the last axis, add a leading axis.
        return table.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()

    cos = _broadcastable(freqs.cos())
    sin = _broadcastable(freqs.sin())
    rotated = (working * cos) + (rotate_half(working) * sin)
    return rotated.to(input_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2_5VLAttention(nn.Module):
    """Tensor-parallel self-attention of the Qwen2.5-VL vision tower.

    Attention runs through ``ModuleFusedSDPA(FusedSDPA)``. Tokens are packed
    into one sequence, and ``cu_seqlens`` marks the boundaries of the
    segments (frames or windows) that may attend to each other.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        # Per-rank sizes: width and head count are divided by the process-group size.
        self.embed_dim = config.hidden_size // weights.process_group.size()
        self.head_dim = config.hidden_size // config.num_heads
        self.num_heads = config.num_heads // weights.process_group.size()

        self.qkv = TensorParallelColumnLinear.load_qkv(
            config,
            prefix=f"{prefix}.qkv",
            weights=weights,
            bias=False,
            num_heads=self.num_heads,
            num_key_value_heads=self.num_heads,
        )
        # The fused projection is loaded without bias; the bias is attached here.
        # NOTE(review): a contiguous dim-0 shard of a fused q/k/v bias may not
        # line up with the per-projection weight shards when the process group
        # is larger than one — confirm.
        self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)

        self.proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.proj",
            weights=weights,
            bias=True,
        )
        self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads)

    @staticmethod
    def _block_diagonal_mask(cu_seqlens, seq_len, device):
        """Build a (1, 1, seq_len, seq_len) bool mask, True inside each segment."""
        if torch.is_tensor(cu_seqlens):
            bounds = cu_seqlens.tolist()
        else:
            bounds = list(cu_seqlens)
        # Filled on the host and moved once, to avoid one device op per segment.
        mask = torch.zeros((1, 1, seq_len, seq_len), dtype=torch.bool)
        for start, end in zip(bounds[:-1], bounds[1:]):
            mask[..., start:end, start:end] = True
        return mask.to(device)

    def forward(
        self,
        hidden_state: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int,
    ) -> torch.Tensor:
        """Run segment-restricted self-attention over packed vision tokens.

        Args:
            hidden_state: 2-D tensor whose first axis indexes tokens.
            cu_seqlens: cumulative segment boundaries (first entry is the
                start of the first segment); tokens only attend to tokens of
                their own segment.
            rotary_pos_emb: rotary frequencies passed to
                ``apply_rotary_pos_emb_vision``.
            max_seqlen: kept for interface compatibility; not used.

        Returns:
            The projected attention output, one row per input token.
        """
        seq_len = hidden_state.shape[0]

        # apply the qkv linear layer to the hidden state
        qkv = self.qkv(hidden_state)
        query, key, value = qkv.split(
            [self.embed_dim, self.embed_dim, self.embed_dim], dim=1
        )

        # reshape the query, key, and value tensors to (tokens, heads, head_size)
        _shape = (
            seq_len,
            self.num_heads,
            self.embed_dim // self.num_heads,
        )
        query = query.view(*_shape)
        key = key.view(*_shape)
        value = value.view(*_shape)

        # apply rotary positional embeddings
        query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
            0
        )
        key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)

        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()
        causal = False

        # execute sdpa on (1, heads, tokens, head_size) tensors
        query = query.unsqueeze(0).transpose(1, 2)
        key = key.unsqueeze(0).transpose(1, 2)
        value = value.unsqueeze(0).transpose(1, 2)

        # Without a mask every token would attend to every other token in the
        # packed sequence, across segment boundaries.
        attention_mask = self._block_diagonal_mask(cu_seqlens, seq_len, query.device)

        fsdpa_op = ModuleFusedSDPA(FusedSDPA)
        attn_output = fsdpa_op(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=causal,
            scale=None,
            softmax_mode="None",
            recompute_mode=None,
            valid_sequence_lengths=None,
        )
        attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()

        # reshape output to original dimensions
        attn_output = attn_output.reshape(seq_len, -1)
        attn_output = self.proj(attn_output)
        return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2_5VLVisionMLP(nn.Module):
    """Gated feed-forward layer of a Qwen2.5-VL vision block.

    Computes ``down(act(gate(x)) * up(x))`` with tensor-parallel linears.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.activation_fn = ACT2FN[config.hidden_act]

        world_size = weights.process_group.size()
        self.intermediate_size = config.intermediate_size // world_size

        # All three projections share these loader arguments and carry a bias.
        common = dict(weights=weights, config=config, bias=True)
        self.up = TensorParallelColumnLinear.load(prefix=f"{prefix}.up_proj", **common)
        self.gate = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.gate_proj", **common
        )
        self.down = TensorParallelRowLinear.load(prefix=f"{prefix}.down_proj", **common)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to ``hidden_states``."""
        gate_out = self.gate(hidden_states)
        up_out = self.up(hidden_states)
        return self.down(self.activation_fn(gate_out) * up_out)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2_5VLVisionBlock(nn.Module):
    """One pre-norm vision layer: RMSNorm -> attention, then RMSNorm -> MLP,
    each wrapped in a residual connection.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.attn = Qwen2_5VLAttention(
            prefix=f"{prefix}.attn", config=config, weights=weights
        )
        self.norm1 = FastRMSNorm.load(prefix=f"{prefix}.norm1", weights=weights, eps=1e-6)
        self.norm2 = FastRMSNorm.load(prefix=f"{prefix}.norm2", weights=weights, eps=1e-6)
        self.mlp = Qwen2_5VLVisionMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )

    def forward(
        self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
    ) -> torch.Tensor:
        """Return ``hidden_states`` after the attention and MLP residual updates."""
        # Only the first element returned by each norm is used here.
        normed, _ = self.norm1(hidden_states)
        hidden_states = hidden_states + self.attn(
            normed, cu_seqlens, rotary_pos_emb, max_seqlen
        )
        normed, _ = self.norm2(hidden_states)
        return hidden_states + self.mlp(normed)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2_5VLPatchMerger(nn.Module):
    """Final projection of the vision tower.

    Normalises the tokens, folds them into rows of
    ``hidden_size * spatial_merge_size**2`` features and applies a
    linear -> GELU -> linear stack.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        merge_factor = config.spatial_merge_size**2
        self.hidden_size = config.hidden_size * merge_factor
        self.patch_merger_ln_q = FastRMSNorm.load(
            prefix=f"{prefix}.ln_q", weights=weights, eps=1e-6
        )
        # Checkpoint entries "mlp.0" and "mlp.2" hold the two linear layers.
        linear_args = dict(weights=weights, config=config, bias=True)
        self.fc1 = TensorParallelColumnLinear.load(prefix=f"{prefix}.mlp.0", **linear_args)
        self.fc2 = TensorParallelRowLinear.load(prefix=f"{prefix}.mlp.2", **linear_args)

    def forward(self, hidden_states) -> torch.Tensor:
        """Merge and project ``hidden_states``."""
        normed, _ = self.patch_merger_ln_q(hidden_states)
        merged = normed.view(-1, self.hidden_size)
        return self.fc2(F.gelu(self.fc1(merged)))
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2_5VisionModel(nn.Module):
    """Vision encoder of Qwen2.5-VL.

    Pipeline implemented by ``forward``: Conv3d patch embedding, reordering of
    the tokens into attention windows, ``config.depth`` vision blocks (most of
    them windowed, those listed in ``fullatt_block_indexes`` using the
    per-frame boundaries), the patch merger, and finally restoring the
    original token order.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()

        self.spatial_merge_size = config.spatial_merge_size
        # Kernel and stride are the same, so patches do not overlap.
        kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size]
        self.patch_embedding = nn.Conv3d(
            in_channels=config.in_channels,
            out_channels=config.hidden_size,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=False,
        )
        # Replace the randomly initialised kernel with the checkpoint tensor.
        self.patch_embedding.weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False
        )
        head_dim = config.hidden_size // config.num_heads

        # Inverse frequencies for the rotary embedding; built over half of the
        # head dimension (forward() looks up two position ids per token).
        theta = 10000.0
        dim = head_dim // 2
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        self.blocks = nn.ModuleList(
            [
                Qwen2_5VLVisionBlock(
                    prefix=f"{prefix}.blocks.{i}",
                    config=config,
                    weights=weights,
                )
                for i in range(config.depth)
            ]
        )
        self.merger = Qwen2_5VLPatchMerger(
            prefix=f"{prefix}.merger",
            config=config,
            weights=weights,
        )
        self.temporal_patch_size = config.temporal_patch_size
        self.spatial_patch_size = config.spatial_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.hidden_size
        self.window_size = config.window_size
        self.patch_size = config.patch_size
        self.spatial_merge_unit = config.spatial_merge_size * config.spatial_merge_size
        self.fullatt_block_indexes = config.fullatt_block_indexes

    def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Prepend ``self.class_embedding`` to every sequence in the batch.

        NOTE(review): ``class_embedding`` is not assigned in ``__init__`` and
        ``forward`` does not call this method — it looks unused; confirm
        before relying on it.
        """
        batch_size, _, hidden_size = hidden_state.shape
        class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
        return hidden_state

    def get_window_index(self, grid_thw):
        """Compute the window ordering of merged patches.

        Args:
            grid_thw: iterable of ``(t, h, w)`` patch-grid sizes, one per
                image/video; ``.item()`` is called on their product, so the
                entries are expected to be tensors.

        Returns:
            ``(window_index, cu_window_seqlens)``: a 1-D tensor listing merged
            patch indices window by window, and a Python list of cumulative
            window lengths (in un-merged patches, starting with 0).
        """
        window_index: list = []
        cu_window_seqlens: list = [0]
        # Running offset so indices of later images continue after earlier ones.
        window_index_id = 0
        # Window edge length expressed in merged patches.
        vit_merger_window_size = (
            self.window_size // self.spatial_merge_size // self.patch_size
        )

        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h, llm_grid_w = (
                grid_h // self.spatial_merge_size,
                grid_w // self.spatial_merge_size,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
                grid_t, llm_grid_h, llm_grid_w
            )
            # Pad height and width up to a multiple of the window size; -100 marks
            # padding cells, which are filtered out again below. When the size is
            # already a multiple, a full extra window of padding is added.
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            # Bring the two window axes together so each window is contiguous.
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            # Number of real (non-padding) entries in every window.
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)
            # Each merged patch stands for `spatial_merge_unit` un-merged patches.
            cu_seqlens_tmp = (
                seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
            )
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)

        return window_index, cu_window_seqlens

    def forward(
        self,
        pixel_values: torch.Tensor,
        grid_thw: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Encode ``pixel_values`` into merged vision embeddings.

        Args:
            pixel_values: flattened patches; reshaped to
                ``(-1, in_channels, temporal_patch_size, spatial_patch_size,
                spatial_patch_size)`` before the convolution.
            grid_thw: ``(t, h, w)`` patch-grid size per image/video. Although
                the default is ``None``, the body iterates over it, so a value
                is required.

        Returns:
            The patch-merger output with rows restored to the order they had
            before window reordering.
        """

        # reshape the input tensor for processing
        shape = (
            -1,
            self.in_channels,
            self.temporal_patch_size,
            self.spatial_patch_size,
            self.spatial_patch_size,
        )
        pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype)
        hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim)
        # TODO: revisit to see if we can avoid some of these reshapes

        # find the position ids for the input tensor based on the grid_thw
        pos_ids = []
        for t, h, w in grid_thw:
            # Row index of every patch, regrouped so that the patches of one
            # spatial-merge cell are adjacent.
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            # Column index of every patch, regrouped the same way.
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            # The same (row, column) pairs are repeated for each of the t frames.
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))

        pos_ids = torch.cat(pos_ids, dim=0)

        max_grid_size = grid_thw[:, 1:].max()

        # apply the positional embeddings to the position ids
        seq = torch.arange(
            max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
        # Look up row and column frequencies and join them per token.
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        seq_len = hidden_states.shape[0]
        patch_shape = (seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        og_shape = (seq_len, -1)

        # Reorder tokens and their rotary frequencies window by window, moving
        # whole merge cells at a time.
        hidden_states = hidden_states.view(patch_shape)[window_index, :, :].view(
            og_shape
        )
        rotary_pos_emb = rotary_pos_emb.view(patch_shape)[window_index, :, :].view(
            og_shape
        )

        rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device)

        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device="cpu",
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        # Drop repeated boundaries (windows that contained only padding).
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens).to(
            hidden_states.device
        )

        # create a cu_seqlens tensor to be used in the attention mask:
        # one segment of h * w patches per frame, with a leading 0.
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])

        # iteratively apply the blocks to the hidden states
        for layer_num, block in enumerate(self.blocks):
            # NOTE: qwen2_5_vl.py has a concept of full attention blocks
            # that are applied at specific layers.
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens

            hidden_states = block(
                hidden_states, cu_seqlens_now, rotary_pos_emb, max_seqlen
            )

        # apply the final patch merger to the hidden states
        hidden_states = self.merger(hidden_states)
        # Undo the window permutation applied before the blocks.
        reverse_indices = torch.argsort(window_index)
        hidden_states = hidden_states[reverse_indices, :]
        return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2_5VLForConditionalGeneration(nn.Module):
    """Qwen2.5-VL for generation: token embeddings, vision tower, Qwen2 text
    model and (speculative) LM head.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        # The vision tower is loaded unquantized.
        config.vision_config.quantize = None
        config.vision_config.speculator = config.speculator
        # set rope_scaling.type == "mrope" since AutoConfig.from_pretrained incorrectly
        # returns rope_scaling.type == "default" for Qwen2_5-VL model at the moment
        if (
            hasattr(config, "rope_scaling")
            and config.rope_scaling is not None
            and config.rope_scaling.get("type", None) == "default"
        ):
            config.rope_scaling.update({"rope_type": "mrope"})
        self.hidden_size = config.hidden_size
        self.vision_start_token_id = config.vision_start_token_id
        self.vision_end_token_id = config.vision_end_token_id
        self.image_token_id = config.image_token_id
        self.video_token_id = config.video_token_id
        self.spatial_merge_size = config.vision_config.spatial_merge_size
        self.embed_tokens = TensorParallelEmbedding(
            prefix="model.embed_tokens", weights=weights
        )
        self.visual = Qwen2_5VisionModel(
            prefix="visual", config=config.vision_config, weights=weights
        )
        self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
        # With tied embeddings the head reuses the embedding weights.
        if config.tie_word_embeddings:
            suffix = "model.embed_tokens"
        else:
            suffix = "lm_head"

        self.lm_head = SpeculativeHead.load(
            config,
            prefix=suffix if not prefix else f"{prefix}.{suffix}",
            weights=weights,
        )
        self.device = weights.device

    # based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391
    # modified to first find segments then initialize position ids for each segment
    # Steps:
    #  locate all vision and text segments
    #  calculate `vision_segment_lengths` for each vision segment to be used as offset
    #  calculate `text_segment_lengths` for each text segment to be used as offset
    #  create position ids for each vision segment based on the image grid
    #  create position ids for each text segment
    #  combine all the position ids
    #  the final segment is the difference between the last vision segment and the end of the input
    #  combine all the position ids and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3)
    def get_position_ids(
        self,
        input_ids: torch.Tensor,
        image_grid_thw: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Build the three-component position ids for ``input_ids``.

        Args:
            input_ids: 1-D tensor of token ids.
            image_grid_thw: one ``(t, h, w)`` row per vision segment, or
                ``None`` for text-only input.

        Returns:
            A ``(len(input_ids), 3)`` tensor. For text-only input (no grid, or
            no vision start/end markers present) every row is ``[i, i, i]``.
        """

        def _text_only_positions():
            return (
                torch.arange(input_ids.shape[0], device=input_ids.device)
                .unsqueeze(1)
                .repeat(1, 3)
            )

        if image_grid_thw is None:
            return _text_only_positions()

        spatial_merge_size = self.spatial_merge_size
        vision_start_token_id = self.vision_start_token_id
        vision_end_token_id = self.vision_end_token_id
        device = input_ids.device
        dtype = input_ids.dtype
        input_ids_len = input_ids.shape[0]

        vision_starts = torch.where(input_ids == vision_start_token_id)[0]
        vision_ends = torch.where(input_ids == vision_end_token_id)[0]
        # A grid was supplied but the tokens contain no vision markers: indexing
        # the last segment below would fail, so treat the input as plain text.
        if vision_starts.numel() == 0 and vision_ends.numel() == 0:
            return _text_only_positions()

        vision_segments = torch.stack((vision_starts, vision_ends), dim=1)
        prev_vision_end = torch.cat(
            [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]
        )
        # Tokens from the previous vision end up to and including this vision start.
        text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1
        vision_widths_max = torch.cat(
            [
                torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
                image_grid_thw[:-1, 2] // spatial_merge_size,
            ]
        )
        vision_segment_lengths = vision_widths_max + text_lengths_between_vision
        vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)
        text_segment_lengths = vision_segment_lengths - text_lengths_between_vision

        # create position ids for each vision segment based on the image grid
        llm_pos_ids_list = []
        for i, _ in enumerate(vision_segments):
            t, h, w = (
                image_grid_thw[i][0],
                image_grid_thw[i][1] // spatial_merge_size,
                image_grid_thw[i][2] // spatial_merge_size,
            )
            t_indices = torch.arange(t, device=device).repeat_interleave(h * w)
            h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
            w_indices = torch.arange(w, device=device).repeat(t * h)
            image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)

            # offset by the position of the last vision segment
            im = image_position_ids + vision_segment_lengths[i]
            llm_pos_ids_list.append(im)

        # create position ids for each text segment
        text_ranges = [
            torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)
            + text_segment_lengths[i]
            for i, seq_len in enumerate(text_lengths_between_vision)
        ]

        # Interleave: text segment, vision segment, text segment, ...
        full_llm_pos_ids_list = [
            item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
        ]
        max_s = full_llm_pos_ids_list[-1].max() + 1
        # Everything from the last vision end token to the end of the input.
        final_text_len = input_ids_len - vision_ends[-1]
        if final_text_len > 0:
            m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
            full_llm_pos_ids_list.append(m + max_s)

        position_ids = (
            torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
        )
        return position_ids

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor],
        pixel_values: torch.FloatTensor = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        # Unused in this model
        video_grid_thw: Optional[torch.LongTensor] = None,
        pixel_attention_mask=None,
        image_sizes: Optional[torch.LongTensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
        cross_attention_states: Optional[torch.Tensor] = None,
        image_indices=None,
    ):
        """Run the multimodal forward pass.

        Token embeddings are computed first; when ``pixel_values`` is
        non-empty, the embeddings at positions equal to ``image_token_id`` are
        overwritten with the vision tower output. The text model then runs on
        the result.

        Returns:
            ``(logits, speculative_logits)`` from the LM head, computed only
            for ``lm_head_indices`` when that argument is given.
        """
        inputs_embeds = self.embed_tokens(input_ids)

        # apply the visual model to the pixel values if they are provided
        if pixel_values is not None and len(pixel_values) > 0:
            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
            mask = torch.where(input_ids == self.image_token_id)
            inputs_embeds[mask] = image_embeds

        hidden_states = self.text_model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            slots=slots,
            seqlen=seqlen,
            hpu_attention_meta=hpu_attention_meta,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.lm_head(hidden_states)
        return logits, speculative_logits
|
@ -0,0 +1,519 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""PyTorch Qwen2 VL model."""
|
||||||
|
|
||||||
|
from typing import Optional, Tuple, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
from habana_frameworks.torch.hpex.kernels import FusedSDPA
|
||||||
|
from vllm_hpu_extension.utils import ModuleFusedSDPA
|
||||||
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from text_generation_server.layers.layernorm import FastLayerNorm, FastRMSNorm
|
||||||
|
from text_generation_server.layers import (
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
TensorParallelRowLinear,
|
||||||
|
TensorParallelEmbedding,
|
||||||
|
SpeculativeHead,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.attention import (
|
||||||
|
Seqlen,
|
||||||
|
HPUPagedAttentionMetadata,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
|
||||||
|
Qwen2Model,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
||||||
|
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    Returns ``concat(-second_half, first_half)`` along the last dimension,
    where the split point is ``x.shape[-1] // 2``.
    """
    midpoint = x.shape[-1] // 2
    first_half = x[..., :midpoint]
    second_half = x[..., midpoint:]
    return torch.cat((second_half.neg(), first_half), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb_vision(
    tensor: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
    """Rotate ``tensor`` with the rotary embedding defined by ``freqs``.

    Works in float32 internally and returns the result in the original dtype
    of ``tensor``.
    """
    # presumably `freqs` is (seq, dim // 2) and `tensor` is (1, seq, heads, dim)
    # — TODO confirm against callers.
    source_dtype = tensor.dtype
    values = tensor.float()
    # Each table gets a singleton axis at dim 1, its last axis duplicated, and
    # a leading axis.
    cos_table = freqs.cos().unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    sin_table = freqs.sin().unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    result = (values * cos_table) + (rotate_half(values) * sin_table)
    return result.to(source_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2VLAttention(nn.Module):
    """Tensor-parallel self-attention of the Qwen2-VL vision tower.

    Attention runs through ``ModuleFusedSDPA(FusedSDPA)``. Tokens are packed
    into one sequence, and ``cu_seqlens`` marks the boundaries of the
    segments that may attend to each other.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        # Per-rank sizes: width and head count are divided by the process-group size.
        self.embed_dim = config.embed_dim // weights.process_group.size()
        # NOTE(review): derived from `hidden_size` while the other sizes use
        # `embed_dim`; `forward` does not read this attribute — confirm intent.
        self.head_dim = config.hidden_size // config.num_heads
        self.num_heads = config.num_heads // weights.process_group.size()

        self.qkv = TensorParallelColumnLinear.load_qkv(
            config,
            prefix=f"{prefix}.qkv",
            weights=weights,
            bias=False,
            num_heads=self.num_heads,
            num_key_value_heads=self.num_heads,
        )
        # The fused projection is loaded without bias; the bias is attached here.
        self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)
        self.proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.proj",
            weights=weights,
            bias=True,
        )
        self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads)

    @staticmethod
    def _block_diagonal_mask(cu_seqlens, seq_len, device):
        """Build a (1, 1, seq_len, seq_len) bool mask, True inside each segment."""
        if torch.is_tensor(cu_seqlens):
            bounds = cu_seqlens.tolist()
        else:
            bounds = list(cu_seqlens)
        # Filled on the host and moved once, to avoid one device op per segment.
        mask = torch.zeros((1, 1, seq_len, seq_len), dtype=torch.bool)
        for start, end in zip(bounds[:-1], bounds[1:]):
            mask[..., start:end, start:end] = True
        return mask.to(device)

    def forward(
        self,
        hidden_state: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int,
    ) -> torch.Tensor:
        """Run segment-restricted self-attention over packed vision tokens.

        Args:
            hidden_state: 2-D tensor whose first axis indexes tokens.
            cu_seqlens: cumulative segment boundaries (first entry is the
                start of the first segment); tokens only attend to tokens of
                their own segment.
            rotary_pos_emb: rotary frequencies passed to
                ``apply_rotary_pos_emb_vision``.
            max_seqlen: kept for interface compatibility; not used.

        Returns:
            The projected attention output, one row per input token.
        """
        seq_len = hidden_state.shape[0]

        # apply the qkv linear layer to the hidden state
        qkv = self.qkv(hidden_state)
        query, key, value = qkv.split(
            [self.embed_dim, self.embed_dim, self.embed_dim], dim=1
        )

        # reshape the query, key, and value tensors to (tokens, heads, head_size)
        _shape = (
            seq_len,
            self.num_heads,
            self.embed_dim // self.num_heads,
        )
        query = query.view(*_shape)
        key = key.view(*_shape)
        value = value.view(*_shape)

        # apply rotary positional embeddings
        query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
            0
        )
        key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)

        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()
        causal = False

        # execute sdpa on (1, heads, tokens, head_size) tensors
        query = query.unsqueeze(0).transpose(1, 2)
        key = key.unsqueeze(0).transpose(1, 2)
        value = value.unsqueeze(0).transpose(1, 2)

        # Without a mask every token would attend to every other token in the
        # packed sequence, across segment boundaries.
        attention_mask = self._block_diagonal_mask(cu_seqlens, seq_len, query.device)

        fsdpa_op = ModuleFusedSDPA(FusedSDPA)
        attn_output = fsdpa_op(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=causal,
            scale=None,
            softmax_mode="None",
            recompute_mode=None,
            valid_sequence_lengths=None,
        )
        attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
        # reshape output to original dimensions
        attn_output = attn_output.reshape(seq_len, -1)
        attn_output = self.proj(attn_output)
        return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2VLVisionMLP(nn.Module):
    """Two-layer feed-forward network used inside each vision block.

    The first projection is loaded as a column-parallel linear layer and the
    second as a row-parallel linear layer; the activation is looked up in
    ``ACT2FN`` from ``config.hidden_act``.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.activation_fn = ACT2FN[config.hidden_act]

        # Both projections carry a bias and live under `<prefix>.fc1` / `<prefix>.fc2`.
        shared_kwargs = dict(weights=weights, config=config, bias=True)
        self.fc1 = TensorParallelColumnLinear.load(prefix=f"{prefix}.fc1", **shared_kwargs)
        self.fc2 = TensorParallelRowLinear.load(prefix=f"{prefix}.fc2", **shared_kwargs)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``fc1`` -> activation -> ``fc2`` and return the result."""
        projected = self.fc1(hidden_states)
        activated = self.activation_fn(projected)
        return self.fc2(activated)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2VLVisionBlock(nn.Module):
    """One vision transformer layer: attention and MLP, each with a residual add.

    Both layer norms are called as ``out, residual = norm(x)``; only the first
    norm's residual is consumed, the second residual add uses the running
    hidden states directly.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.attn = Qwen2VLAttention(
            prefix=f"{prefix}.attn", config=config, weights=weights
        )
        self.norm1 = FastLayerNorm.load(
            prefix=f"{prefix}.norm1", weights=weights, eps=1e-6
        )
        self.norm2 = FastLayerNorm.load(
            prefix=f"{prefix}.norm2", weights=weights, eps=1e-6
        )
        self.mlp = Qwen2VLVisionMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )

    def forward(
        self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
    ) -> torch.Tensor:
        """Run attention then the MLP, adding a residual after each."""
        # Attention sub-layer: the residual comes from the norm call itself.
        attn_input, skip = self.norm1(hidden_states)
        after_attn = (
            self.attn(attn_input, cu_seqlens, rotary_pos_emb, max_seqlen) + skip
        )

        # MLP sub-layer: the residual returned by norm2 is intentionally unused.
        mlp_input, _ = self.norm2(after_attn)
        return after_attn + self.mlp(mlp_input)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2VLPatchMerger(nn.Module):
    """Layer-norm the patch features, regroup them, and project through a GELU MLP.

    ``hidden_size`` is ``config.embed_dim * config.spatial_merge_size ** 2``; the
    normalised features are viewed as rows of that width before the MLP.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        merge_factor = config.spatial_merge_size**2
        self.hidden_size = config.embed_dim * merge_factor
        self.patch_merger_ln_q = FastLayerNorm.load(
            prefix=f"{prefix}.ln_q", weights=weights, eps=1e-6
        )
        # Weights are stored under `mlp.0` and `mlp.2`; the activation between
        # them (applied in `forward`) has no parameters.
        self.fc1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True
        )
        self.fc2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True
        )

    def forward(self, hidden_states) -> torch.Tensor:
        """Return ``fc2(gelu(fc1(view(ln_q(hidden_states)))))``."""
        normed, _ = self.patch_merger_ln_q(hidden_states)
        grouped = normed.view(-1, self.hidden_size)
        return self.fc2(F.gelu(self.fc1(grouped)))
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2VisionModel(nn.Module):
    """Qwen2-VL vision tower.

    Pipeline implemented by ``forward``: 3D-convolution patch embedding ->
    rotary position table indexed by per-patch (row, column) ids -> the stack of
    ``Qwen2VLVisionBlock`` layers -> ``Qwen2VLPatchMerger``.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.spatial_merge_size = config.spatial_merge_size

        # Stride equals kernel size, so the convolution sees non-overlapping patches.
        patch_kernel = [
            config.temporal_patch_size,
            config.patch_size,
            config.patch_size,
        ]
        self.patch_embedding = nn.Conv3d(
            in_channels=config.in_chans,
            out_channels=config.embed_dim,
            kernel_size=patch_kernel,
            stride=patch_kernel,
            bias=False,
        )
        self.patch_embedding.weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False
        )

        # TODO: replace with static positional embeddings once implemented
        head_dim = config.embed_dim // config.num_heads
        rotary_dim = head_dim // 2
        theta = 10000.0
        exponents = torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim
        self.register_buffer("inv_freq", 1.0 / (theta**exponents), persistent=False)

        self.blocks = nn.ModuleList(
            [
                Qwen2VLVisionBlock(
                    prefix=f"{prefix}.blocks.{layer_idx}",
                    config=config,
                    weights=weights,
                )
                for layer_idx in range(config.depth)
            ]
        )
        self.merger = Qwen2VLPatchMerger(
            prefix=f"{prefix}.merger",
            config=config,
            weights=weights,
        )

        self.temporal_patch_size = config.temporal_patch_size
        self.spatial_patch_size = config.spatial_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.embed_dim

    def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Prepend ``self.class_embedding`` (expanded per batch) along dim 1.

        NOTE(review): ``class_embedding`` is not assigned in ``__init__`` above,
        so this only works if it is set elsewhere -- confirm before calling.
        """
        batch_size, _, hidden_size = hidden_state.shape
        cls_tokens = self.class_embedding.expand(batch_size, 1, hidden_size)
        return torch.cat([cls_tokens, hidden_state], dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        grid_thw: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Embed image patches and return the merged patch features.

        ``grid_thw`` is iterated as rows of ``(t, h, w)``; although it defaults
        to ``None`` it is required in practice because it is iterated
        unconditionally.
        """
        # reshape the flat pixel buffer into conv-ready patches
        patch_shape = (
            -1,
            self.in_channels,
            self.temporal_patch_size,
            self.spatial_patch_size,
            self.spatial_patch_size,
        )
        patches = pixel_values.view(patch_shape).to(self.patch_embedding.weight.dtype)
        hidden_states = self.patch_embedding(patches).view(-1, self.embed_dim)
        # TODO: revisit to see if we can avoid some of these reshapes

        # build (row, column) ids per patch, ordered so that the patches of one
        # merge window (merge x merge) are contiguous
        merge = self.spatial_merge_size
        per_image_pos_ids = []
        for t, h, w in grid_thw:
            window_shape = (h // merge, merge, w // merge, merge)

            row_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            row_ids = row_ids.reshape(*window_shape).permute(0, 2, 1, 3).flatten()

            col_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            col_ids = col_ids.reshape(*window_shape).permute(0, 2, 1, 3).flatten()

            # the same spatial ids are reused for every temporal step
            per_image_pos_ids.append(
                torch.stack([row_ids, col_ids], dim=-1).repeat(t, 1)
            )
        pos_ids = torch.cat(per_image_pos_ids, dim=0)

        # rotary table for every index up to the largest h/w, then gather by pos_ids
        max_grid_size = grid_thw[:, 1:].max()
        positions = torch.arange(
            max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        rotary_table = torch.outer(positions, self.inv_freq)
        rotary_pos_emb = rotary_table[pos_ids].flatten(1)
        rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype)

        # cumulative sequence lengths: one h*w segment per temporal step
        patches_per_frame = grid_thw[:, 1] * grid_thw[:, 2]
        cu_seqlens = torch.repeat_interleave(patches_per_frame, grid_thw[:, 0]).cumsum(
            dim=0, dtype=torch.int32
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])

        # iteratively apply the blocks to the hidden states
        for layer in self.blocks:
            hidden_states = layer(hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen)

        # apply the final patch merger to the hidden states
        return self.merger(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2VLForConditionalGeneration(nn.Module):
    """Qwen2-VL conditional generation model.

    Combines the token embedding table, the ``Qwen2VisionModel`` vision tower,
    the ``Qwen2Model`` text decoder and a (speculative) LM head. Image features
    replace the embeddings of the ``image_token_id`` positions before the text
    model runs.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        config.vision_config.quantize = None
        config.vision_config.speculator = config.speculator
        # AutoConfig.from_pretrained currently reports rope_scaling["type"] == "default"
        # for Qwen2-VL; in that case tag the scaling config with rope_type "mrope".
        if (
            hasattr(config, "rope_scaling")
            and config.rope_scaling is not None
            and config.rope_scaling.get("type", None) == "default"
        ):
            config.rope_scaling.update({"rope_type": "mrope"})
        self.hidden_size = config.hidden_size
        self.vision_start_token_id = config.vision_start_token_id
        self.vision_end_token_id = config.vision_end_token_id
        self.image_token_id = config.image_token_id
        self.video_token_id = config.video_token_id
        self.spatial_merge_size = config.vision_config.spatial_merge_size
        self.embed_tokens = TensorParallelEmbedding(
            prefix="model.embed_tokens", weights=weights
        )
        self.visual = Qwen2VisionModel(
            prefix="visual", config=config.vision_config, weights=weights
        )
        self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)

        # With tied embeddings the LM head reuses the embedding weights.
        if config.tie_word_embeddings:
            suffix = "model.embed_tokens"
        else:
            suffix = "lm_head"

        self.lm_head = SpeculativeHead.load(
            config,
            prefix=suffix if not prefix else f"{prefix}.{suffix}",
            weights=weights,
        )
        self.norm = FastRMSNorm.load(
            prefix="model.norm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        self.device = weights.device

    # based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391
    # modified to first find segments then initialize position ids for each segment
    # Steps:
    #  locate all vision and text segments
    #  calculate `vision_segment_lengths` for each vision segment to be used as offset
    #  calculate `text_segment_lengths` for each text segment to be used as offset
    #  create position ids for each vision segment based on the image grid
    #  create position ids for each text segment
    #  combine all the position ids
    #  the final segment is the difference between the last vision segment and the end of the input
    #  combine all the position ids and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3)
    def get_position_ids(
        self,
        input_ids: torch.Tensor,
        image_grid_thw: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Build the three-component position ids for ``input_ids``.

        Args:
            input_ids: 1-D token ids (indexed with ``shape[0]`` as its length).
            image_grid_thw: rows of ``(t, h, w)`` per image, or ``None`` for a
                text-only input.

        Returns:
            A tensor with one row per input token and three columns. Without
            images every column is ``0..len-1``. With images, text tokens get
            the same running index in all three columns and image tokens get
            ``(t, h, w)`` grid indices offset by the running position.

        NOTE(review): when ``image_grid_thw`` is given the code indexes the
        last vision segment unconditionally, so ``input_ids`` presumably must
        contain at least one vision start/end pair -- confirm against callers.
        """
        if image_grid_thw is None:
            return (
                torch.arange(input_ids.shape[0], device=input_ids.device)
                .unsqueeze(1)
                .repeat(1, 3)
            )

        spatial_merge_size = self.spatial_merge_size
        vision_start_token_id = self.vision_start_token_id
        vision_end_token_id = self.vision_end_token_id
        device = input_ids.device
        dtype = input_ids.dtype
        input_ids_len = input_ids.shape[0]

        # locate the vision segments and the text that precedes each of them
        vision_starts = torch.where(input_ids == vision_start_token_id)[0]
        vision_ends = torch.where(input_ids == vision_end_token_id)[0]
        vision_segments = torch.stack((vision_starts, vision_ends), dim=1)
        prev_vision_end = torch.cat(
            [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]
        )
        text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1
        vision_widths_max = torch.cat(
            [
                torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
                image_grid_thw[:-1, 2] // spatial_merge_size,
            ]
        )
        # running offsets at which each vision / text segment starts
        vision_segment_lengths = vision_widths_max + text_lengths_between_vision
        vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)
        text_segment_lengths = vision_segment_lengths - text_lengths_between_vision

        # create position ids for each vision segment based on the image grid
        llm_pos_ids_list = []
        for i, _ in enumerate(vision_segments):
            t, h, w = (
                image_grid_thw[i][0],
                image_grid_thw[i][1] // spatial_merge_size,
                image_grid_thw[i][2] // spatial_merge_size,
            )
            t_indices = torch.arange(t, device=device).repeat_interleave(h * w)
            h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
            w_indices = torch.arange(w, device=device).repeat(t * h)
            image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)

            # offset by the position of the last vision segment
            im = image_position_ids + vision_segment_lengths[i]
            llm_pos_ids_list.append(im)

        # create position ids for each text segment
        text_ranges = [
            torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)
            + text_segment_lengths[i]
            for i, seq_len in enumerate(text_lengths_between_vision)
        ]

        # interleave: text segment, vision segment, text segment, ...
        full_llm_pos_ids_list = [
            item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
        ]
        max_s = full_llm_pos_ids_list[-1].max() + 1
        # trailing text after the last vision segment (counted from its end token)
        final_text_len = input_ids_len - vision_ends[-1]
        if final_text_len > 0:
            m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
            full_llm_pos_ids_list.append(m + max_s)

        position_ids = (
            torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
        )
        return position_ids

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: "Seqlen",
        hpu_attention_meta: "Optional[HPUPagedAttentionMetadata]",
        lm_head_indices: Optional[torch.Tensor],
        pixel_values: torch.FloatTensor = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        pixel_attention_mask=None,
        image_sizes: Optional[torch.LongTensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
        cross_attention_states: Optional[torch.Tensor] = None,
        image_indices=None,
    ):
        """Run the model and return ``(logits, speculative_logits)``.

        When ``pixel_values`` is present and non-empty, the vision tower output
        overwrites the embeddings at every ``image_token_id`` position. The
        remaining image/video/adapter arguments are accepted for interface
        compatibility and are not used by this method.
        """
        inputs_embeds = self.embed_tokens(input_ids)

        # apply the visual model to the pixel values if they are provided
        # (a single guard suffices: the original re-checked `is not None` inside it)
        if pixel_values is not None and len(pixel_values) > 0:
            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
            mask = torch.where(input_ids == self.image_token_id)
            inputs_embeds[mask] = image_embeds

        hidden_states = self.text_model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            slots=slots,
            seqlen=seqlen,
            hpu_attention_meta=hpu_attention_meta,
        )
        # keep only the rows the LM head needs
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.lm_head(hidden_states)
        return logits, speculative_logits
|
File diff suppressed because it is too large
Load Diff
@ -4,7 +4,7 @@ def load_text_model(prefix, config, weights, name=None):
|
|||||||
FlashLlamaForCausalLM,
|
FlashLlamaForCausalLM,
|
||||||
)
|
)
|
||||||
|
|
||||||
return FlashLlamaForCausalLM(prefix, config, weights)
|
return FlashLlamaForCausalLM(prefix, config, weights, name=name)
|
||||||
elif config.model_type == "mistral":
|
elif config.model_type == "mistral":
|
||||||
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
|
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
|
||||||
FlashMistralForCausalLM,
|
FlashMistralForCausalLM,
|
||||||
@ -16,7 +16,13 @@ def load_text_model(prefix, config, weights, name=None):
|
|||||||
FlashGemmaForCausalLM,
|
FlashGemmaForCausalLM,
|
||||||
)
|
)
|
||||||
|
|
||||||
return FlashGemmaForCausalLM(prefix, config, weights, causal=False)
|
return FlashGemmaForCausalLM(prefix, config, weights)
|
||||||
|
elif config.model_type == "gemma2":
|
||||||
|
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
|
||||||
|
FlashGemma2ForCausalLM,
|
||||||
|
)
|
||||||
|
|
||||||
|
return FlashGemma2ForCausalLM(prefix, config, weights)
|
||||||
elif config.model_type == "paligemma":
|
elif config.model_type == "paligemma":
|
||||||
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
|
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
|
||||||
FlashGemmaForCausalLM,
|
FlashGemmaForCausalLM,
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,489 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
from opentelemetry import trace
|
||||||
|
from typing import Iterable, Optional, Tuple, List, Type, Dict
|
||||||
|
|
||||||
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
from transformers.image_processing_utils import select_best_resolution
|
||||||
|
from text_generation_server.pb import generate_pb2
|
||||||
|
from text_generation_server.models.flash_causal_lm import (
|
||||||
|
FlashCausalLMBatch,
|
||||||
|
FlashCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.globals import PREFIX_CACHING
|
||||||
|
from loguru import logger
|
||||||
|
from text_generation_server.utils.log import log_master
|
||||||
|
from transformers import AutoProcessor
|
||||||
|
from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
|
||||||
|
import habana_frameworks.torch as htorch
|
||||||
|
|
||||||
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
|
# Prompt placeholders used when expanding Idefics2 image chunks.
IDEFICS2_IMAGE_TOKEN = "<image>"
IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"

# Prompt placeholders used when expanding Idefics3 image chunks.
IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"
IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
IDEFICS3_IMAGE_TOKEN = "<image>"
|
||||||
|
|
||||||
|
|
||||||
|
# copied from: https://github.com/huggingface/transformers/blob/02ed609285c2448b3b54c31e362f2c389fa952ab/src/transformers/models/idefics3/processing_idefics3.py#L44-L60
|
||||||
|
def _prompt_split_image(
|
||||||
|
*,
|
||||||
|
image_seq_len: int,
|
||||||
|
image_rows: int,
|
||||||
|
image_cols: int,
|
||||||
|
fake_token_around_image: str,
|
||||||
|
image_token: str,
|
||||||
|
global_img_token: str,
|
||||||
|
):
|
||||||
|
"""Prompt with expanded image tokens for when the image is split into patches."""
|
||||||
|
text_split_images = ""
|
||||||
|
for n_h in range(image_rows):
|
||||||
|
for n_w in range(image_cols):
|
||||||
|
text_split_images += (
|
||||||
|
f"{fake_token_around_image}"
|
||||||
|
+ f"<row_{n_h + 1}_col_{n_w + 1}>"
|
||||||
|
+ f"{image_token}" * image_seq_len
|
||||||
|
)
|
||||||
|
text_split_images += "\n"
|
||||||
|
|
||||||
|
text_split_images += (
|
||||||
|
f"\n{fake_token_around_image}"
|
||||||
|
+ f"{global_img_token}"
|
||||||
|
+ f"{image_token}" * image_seq_len
|
||||||
|
+ f"{fake_token_around_image}"
|
||||||
|
)
|
||||||
|
return text_split_images
|
||||||
|
|
||||||
|
|
||||||
|
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (`tuple`):
            The size of the input image; passed unchanged to `select_best_resolution`.
        grid_pinpoints (`List`):
            A list containing possible resolutions. Each item in the list should be a tuple or list
            of the form `(height, width)`.
        patch_size (`int`):
            The size of each image patch.

    Returns:
        tuple: the two values returned by `select_best_resolution`, each floor-divided
        by `patch_size`, in the same order (the first is unpacked as the height).

    Raises:
        ValueError: if `grid_pinpoints` is not a `list`.
    """
    if not isinstance(grid_pinpoints, list):
        raise ValueError("grid_pinpoints should be a list of tuples or lists")

    best_height, best_width = select_best_resolution(image_size, grid_pinpoints)
    rows = best_height // patch_size
    cols = best_width // patch_size
    return rows, cols
|
||||||
|
|
||||||
|
|
||||||
|
def image_text_replacement(processor, image_input, config, image_id: int) -> str:
    """Return the placeholder text that stands in for image ``image_id`` in the prompt.

    The expansion depends on ``config.model_type``:

    * ``idefics2``: 64 image tokens wrapped in fake tokens, repeated 5 times when
      the image processor splits images.
    * ``idefics3``: the per-tile prompt from ``_prompt_split_image``.
    * ``llava_next``: one ``<image>`` per feature from ``get_number_of_features``.
    * ``paligemma``: ``config.text_config.num_image_tokens`` ``<image>`` tokens.
    * ``qwen2_vl`` / ``qwen2_5_vl``: ``t * h * w // 4`` pad tokens between the
      vision start/end markers.
    * ``gemma3``: 256 soft tokens between the start/end-of-image markers.

    Raises:
        RuntimeError: for any other model type.
    """
    model_type = config.model_type

    if model_type == "idefics2":
        image_seq_len = 64
        image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
        if processor.image_processor.do_image_splitting:
            image_str *= 5
        return image_str

    if model_type == "idefics3":
        # TODO: implement this in a more general way
        n_rows = image_input["rows"][0][image_id]
        n_cols = image_input["cols"][0][image_id]
        patches_per_side = (
            config.vision_config.image_size // config.vision_config.patch_size
        )
        image_seq_len = int((patches_per_side**2) / (config.scale_factor**2))
        return _prompt_split_image(
            image_seq_len=image_seq_len,
            image_rows=n_rows,
            image_cols=n_cols,
            fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN,
            image_token=IDEFICS3_IMAGE_TOKEN,
            global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
        )

    if model_type == "llava_next":
        height, width = image_input["image_sizes"][image_id]
        num_features = get_number_of_features(height, width, config)

        log_master(
            logger.info,
            f"Found {num_features} features in image of resolution {height}x{width}",
        )
        return "<image>" * num_features

    if model_type == "paligemma":
        return "<image>" * config.text_config.num_image_tokens

    if model_type in ("qwen2_vl", "qwen2_5_vl"):
        # Both Qwen variants share the same expansion.
        grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
        num_pads = grid_t * grid_h * grid_w // 4
        padding = "<|image_pad|>" * num_pads
        return f"<|vision_start|>{padding}<|vision_end|>"

    if model_type == "gemma3":
        # TODO: get correct number of features via reviewing the Gemma3 architecture
        # and calculating the number of image tokens
        num_pads = 256
        padding = "<image_soft_token>" * num_pads
        return f"\n\n<start_of_image>{padding}<end_of_image>\n\n"

    raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
|
||||||
|
|
||||||
|
|
||||||
|
def image_text_replacement_fixup(config, text: str) -> str:
    """Collapse doubled Idefics2 fake tokens; return ``text`` unchanged for other models.

    Adjacent image expansions each end/start with the fake token, which leaves
    two of them back to back; every such pair is replaced by a single token.
    """
    if config.model_type != "idefics2":
        return text

    doubled_fake_token = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}"
    return text.replace(doubled_fake_token, IDEFICS2_FAKE_TOKEN)
|
||||||
|
|
||||||
|
|
||||||
|
def get_unpadded_features(
|
||||||
|
original_height: int,
|
||||||
|
original_width: int,
|
||||||
|
npatches: int,
|
||||||
|
num_patch_height: int,
|
||||||
|
num_patch_width: int,
|
||||||
|
) -> Tuple[int, int]:
|
||||||
|
current_height = npatches * num_patch_height
|
||||||
|
current_width = npatches * num_patch_width
|
||||||
|
|
||||||
|
aspect_ratio: float = original_width / original_height
|
||||||
|
current_aspect_ratio: float = current_width / current_height
|
||||||
|
|
||||||
|
if aspect_ratio > current_aspect_ratio:
|
||||||
|
new_height = (original_height * current_width) // original_width
|
||||||
|
padding = (current_height - new_height) // 2
|
||||||
|
current_height = current_height - (2 * padding)
|
||||||
|
else:
|
||||||
|
new_width = (original_width * current_height) // original_height
|
||||||
|
padding = (current_width - new_width) // 2
|
||||||
|
current_width = current_width - (2 * padding)
|
||||||
|
|
||||||
|
unpadded_features = current_height * current_width
|
||||||
|
newline_features = current_height
|
||||||
|
return (unpadded_features, newline_features)
|
||||||
|
|
||||||
|
|
||||||
|
def get_number_of_features(height: int, width: int, config) -> int:
    """Return the number of image features for an image of ``height`` x ``width``.

    The total is the unpadded grid features, plus one newline feature per grid
    row, plus the ``npatches ** 2`` features of the base patch.
    """
    # From config
    # Hardcoded for CLIP for now
    # image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
    pinpoints = config.image_grid_pinpoints
    vision_config = config.vision_config
    image_size = vision_config.image_size
    patch_size = vision_config.patch_size

    assert image_size % patch_size == 0

    npatches = image_size // patch_size

    # Dimensions are intentionally swapped to be bug-compatible with
    # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
        [height, width], pinpoints, image_size
    )
    unpadded, newline = get_unpadded_features(
        height, width, npatches, num_patch_height, num_patch_width
    )

    # The base patch covers the entire image
    return unpadded + newline + npatches**2
|
||||||
|
|
||||||
|
|
||||||
|
class FlashVlmCausalLMBatch(FlashCausalLMBatch):
    """Causal-LM batch that additionally carries the image inputs of its requests.

    The image fields are populated by ``from_pb_processor`` and reset to
    ``None`` by ``concatenate`` and ``filter``.
    """

    # Outputs of the image processor, moved to the target device (or None).
    pixel_values: Optional[List[torch.Tensor]]
    pixel_attention_mask: Optional[List[torch.Tensor]]
    image_sizes: Optional[List[Tuple[int, int]]]
    image_grid_thw: Optional[torch.Tensor]

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches):
        """Concatenate ``batches`` and clear every image field on the result."""
        batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches)
        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        batch.image_grid_thw = None
        return batch

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]):
        """Filter the batch to ``request_ids`` and clear every image field on the result."""
        batch = super().filter(request_ids)
        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        batch.image_grid_thw = None
        return batch

    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
    ):
        """Tokenize ``requests`` with image chunks expanded to placeholder tokens.

        Returns:
            ``(batch_tokenized_inputs, image_inputs)``: the list of input ids per
            request, and the image processor output (``None`` when no request
            contains an image).

        Raises:
            RuntimeError: if a chunk is neither ``text`` nor ``image``.
        """
        # `requests` is walked twice below; materialize it so that a one-shot
        # iterable is not exhausted by the first pass.
        requests = list(requests)

        # Process images first. We need all of them so that the processor
        # can make the image splits the same size. And we need the final
        # sizes to insert correct number of image tokens.
        images = []
        for r in requests:
            for chunk in r.input_chunks.chunks:
                chunk_type = chunk.WhichOneof("chunk")
                if chunk_type == "text":
                    pass
                elif chunk_type == "image":
                    image = Image.open(BytesIO(chunk.image.data))
                    # qwen2_vl expects images to be greater than 20 pixels, this is for warmup since the
                    # default warmup image is 20x20
                    if config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
                        if image.width <= 20:
                            w = image.width * 2
                            h = image.height * 2
                            image = image.resize((w, h))

                    # llava_next and gemma3 take a flat list of images; every
                    # other model gets each image wrapped in its own list.
                    if config.model_type in {"llava_next", "gemma3"}:
                        images.append(image)
                    else:
                        images.append([image])
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk_type}")

        if images:
            kwargs = {}
            if (
                hasattr(processor, "image_processor_class")
                and processor.image_processor_class == "Idefics3ImageProcessor"
            ):
                # row/col info is read back by image_text_replacement for idefics3
                kwargs["return_row_col_info"] = True

            image_inputs = processor.image_processor(
                images, return_tensors="pt", **kwargs
            )
        else:
            image_inputs = None

        batch_tokenized_inputs = []
        # running index over all images of the batch, in request/chunk order
        image_id = 0
        for r in requests:
            full_text = ""
            for chunk in r.input_chunks.chunks:
                chunk_type = chunk.WhichOneof("chunk")
                if chunk_type == "text":
                    full_text += chunk.text
                elif chunk_type == "image":
                    full_text += image_text_replacement(
                        processor, image_inputs, config, image_id
                    )
                    image_id += 1

            full_text = image_text_replacement_fixup(config, full_text)
            input_ids = tokenizer(
                full_text,
                truncation=True,
                max_length=r.truncate,
                add_special_tokens=r.add_special_tokens,
            )["input_ids"]
            batch_tokenized_inputs.append(input_ids)

        return batch_tokenized_inputs, image_inputs

    @classmethod
    def from_pb_processor(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        processor,
        config,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashVlmCausalLMBatch":
        """Build a batch from a protobuf batch, attaching the image inputs.

        ``pixel_values`` is always taken from the image processor output when
        images are present; ``pixel_attention_mask``, ``image_sizes`` and
        ``image_grid_thw`` are copied only if the processor produced them and
        are ``None`` otherwise.
        """
        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
            pb.requests, tokenizer, processor, config
        )
        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
        if image_inputs is not None:
            batch.pixel_values = image_inputs["pixel_values"].to(device=device)
            if "pixel_attention_mask" in image_inputs:
                batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
                    device=device
                )
            else:
                batch.pixel_attention_mask = None
            if "image_sizes" in image_inputs:
                batch.image_sizes = image_inputs["image_sizes"].to(device=device)
            else:
                batch.image_sizes = None
            if "image_grid_thw" in image_inputs:
                batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device)
            else:
                batch.image_grid_thw = None
        else:
            batch.pixel_values = None
            batch.pixel_attention_mask = None
            batch.image_sizes = None
            batch.image_grid_thw = None
        return batch
|
||||||
|
|
||||||
|
|
||||||
|
class FlashVlmCausalLM(FlashCausalLM):
    """Flash causal LM wrapper for vision-language models.

    On top of ``FlashCausalLM`` it loads a multimodal processor, exposes a
    configurable batch class, and forwards the batch's image tensors to the
    underlying model.
    """

    def __init__(
        self,
        model_id: str,
        *,
        processor_class=AutoProcessor,
        processor_kwargs=None,
        batch_class=FlashVlmCausalLMBatch,
        revision,
        trust_remote_code: bool,
        **kwargs,
    ):
        """Load the processor, then initialise the base model.

        Args:
            model_id: Model identifier passed to the processor and the base class.
            processor_class: Class whose ``from_pretrained`` builds the processor.
            processor_kwargs: Extra keyword arguments for ``from_pretrained``;
                ``None`` is treated as an empty dict.
            batch_class: Batch type returned by ``batch_type``.
            revision: Model revision passed to the processor and the base class.
            trust_remote_code: Passed to the processor and the base class.
            **kwargs: Forwarded unchanged to ``FlashCausalLM.__init__``.

        Raises:
            NotImplementedError: If ``PREFIX_CACHING`` is enabled.
        """
        if PREFIX_CACHING:
            raise NotImplementedError("Vlm do not work with prefix caching yet")
        if processor_kwargs is None:
            processor_kwargs = {}
        self.processor = processor_class.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
            **processor_kwargs,
        )
        self.batch_class = batch_class
        super().__init__(
            model_id=model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
            # FIXME: VLM do not work with context chunking yet
            support_chunking=False,
            **kwargs,
        )

    @property
    def batch_type(self) -> Type[FlashVlmCausalLMBatch]:
        """Batch class configured at construction time."""
        return self.batch_class

    def max_past(self) -> Optional[int]:
        """Return the text model's ``max_past`` attribute, or ``None`` if absent."""
        return getattr(self.model.text_model, "max_past", None)

    def forward(
        self,
        batch: FlashVlmCausalLMBatch,
        adapter_data: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run one model forward pass for ``batch``.

        When the batch carries speculative ids, the per-sequence inputs are
        expanded so each sequence contributes ``speculative_length + 1``
        positions. After the model call, ``prefill_cache_indices`` and all
        image tensors on the batch are reset to ``None``.

        The block table and ``max_s`` are intentionally not computed here:
        neither is an argument of the ``self.model.forward`` call below, so
        expanding/copying them would be wasted work.

        Args:
            batch: Batch to run; mutated as described above (and
                ``position_ids`` may be replaced for Qwen2-VL style models).
            adapter_data: Accepted for interface compatibility; not used by
                this implementation.

        Returns:
            ``(logits, speculative_logits)`` as returned by the model.
        """
        # Inputs shared by the speculative and non-speculative paths.
        input_ids = batch.input_ids
        position_ids = batch.position_ids
        cu_seqlen_prefill = batch.cu_seqlen_prefill
        kv_cache = self.kv_cache
        slots = batch.slots[batch.slot_indices]
        input_lengths = batch.input_lengths_tensor
        cache_lengths_tensor = batch.cache_lengths_tensor
        lm_head_indices = batch.prefill_head_indices

        if batch.speculative_ids is not None:
            speculative_ids = batch.speculative_ids
            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1

            # Offsets 0..speculative_length, broadcast over the batch rows.
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)

            # Each row becomes [current token, speculative tokens...], flattened.
            input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)
            # Cache lengths are repeated, not offset.
            cache_lengths_tensor = (
                cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
            ).reshape(-1)

        if self.model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
            # During prefill, 1-D position ids are replaced by the ids the
            # model derives from the image grid, and stored back on the batch.
            if position_ids.dim() == 1 and batch.prefilling:
                position_ids = self.model.get_position_ids(
                    input_ids, batch.image_grid_thw
                )
                batch.position_ids = position_ids

        kwargs = {}
        if htorch.utils.internal.is_lazy():
            kwargs["bypass_hpu_graphs"] = False

        seqlen = Seqlen(
            input_lengths=input_lengths,
            cache_lengths=cache_lengths_tensor,
            cu_seqlen_q=cu_seqlen_prefill,
        )

        if batch.prefill_cache_indices is not None:
            # Scatter the real slots into a zero tensor shaped like input_ids.
            slots_pad = torch.zeros_like(input_ids)
            slots_pad[batch.prefill_cache_indices] = slots
            slots = slots_pad

        logits, speculative_logits = self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            slots=slots,
            seqlen=trim_seqlen_metadata(seqlen),
            hpu_attention_meta=batch.hpu_attn_meta,
            lm_head_indices=lm_head_indices,
            pixel_values=batch.pixel_values,
            pixel_attention_mask=batch.pixel_attention_mask,
            image_sizes=batch.image_sizes,
            image_grid_thw=batch.image_grid_thw,
            **kwargs,
        )

        # One-shot inputs: cleared so they are not passed again on the next call.
        batch.prefill_cache_indices = None
        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        batch.image_grid_thw = None
        return logits, speculative_logits
|
@ -1,53 +1,31 @@
|
|||||||
import torch
|
|
||||||
import os
|
import os
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from text_generation_server.utils.log import log_master
|
from text_generation_server.utils.log import log_master
|
||||||
|
|
||||||
|
REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
|
||||||
ATTENTION = os.getenv("ATTENTION", "default")
|
ATTENTION = os.getenv("ATTENTION", "default")
|
||||||
# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
|
# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
|
||||||
PREFIX_CACHING = os.getenv("PREFIX_CACHING", "0").lower() in {
|
PREFIX_CACHING = os.getenv("PREFIX_CACHING", "0").lower() in {
|
||||||
"1",
|
"1",
|
||||||
"true",
|
"true",
|
||||||
}
|
}
|
||||||
PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "0").lower() in {"1", "true"}
|
|
||||||
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
|
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
|
||||||
_expected = {"paged", "flashdecoding", "flashinfer", "default"}
|
_expected = {"paged", "default"}
|
||||||
assert (
|
assert (
|
||||||
ATTENTION in _expected
|
ATTENTION in _expected
|
||||||
), f"Attention is not valid {ATTENTION}, expected {_expected}"
|
), f"Attention is not valid {ATTENTION}, expected {_expected}"
|
||||||
log_master(logger.info, f"Using Attention = {ATTENTION}")
|
log_master(logger.info, f"Using Attention = {ATTENTION}")
|
||||||
|
|
||||||
if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
|
|
||||||
raise RuntimeError("Prefix caching is only supported with flashinfer")
|
|
||||||
|
|
||||||
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
|
||||||
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
|
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
|
||||||
assert TGI_WIGGLE_ROOM > 0
|
assert TGI_WIGGLE_ROOM > 0
|
||||||
assert TGI_WIGGLE_ROOM < 1
|
assert TGI_WIGGLE_ROOM < 1
|
||||||
|
|
||||||
# This is overridden by the cli
|
# This is overridden by the cli
|
||||||
BLOCK_SIZE: int
|
BLOCK_SIZE: int
|
||||||
if ATTENTION == "flashdecoding":
|
|
||||||
BLOCK_SIZE = 256
|
|
||||||
elif ATTENTION == "flashinfer":
|
|
||||||
BLOCK_SIZE = 1
|
|
||||||
else:
|
|
||||||
BLOCK_SIZE = 16
|
|
||||||
|
|
||||||
# This is overridden by the cli
|
BLOCK_SIZE = 128
|
||||||
cuda_graphs = os.getenv("CUDA_GRAPHS")
|
|
||||||
if cuda_graphs is not None:
|
|
||||||
try:
|
|
||||||
cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
cuda_graphs = None
|
|
||||||
|
|
||||||
CUDA_GRAPHS = cuda_graphs
|
|
||||||
|
|
||||||
# This is overridden at model loading.
|
# This is overridden at model loading.
|
||||||
global MODEL_ID
|
global MODEL_ID
|
||||||
|
@ -34,9 +34,6 @@ from text_generation_server.utils import (
|
|||||||
)
|
)
|
||||||
from text_generation_server.utils.quantization import get_loader
|
from text_generation_server.utils.quantization import get_loader
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
|
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -596,22 +593,8 @@ class IdeficsCausalLM(Model):
|
|||||||
):
|
):
|
||||||
self.quantize = quantize
|
self.quantize = quantize
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||||
if torch.cuda.is_available():
|
device = torch.device("hpu")
|
||||||
device = torch.device(f"cuda:{rank}")
|
|
||||||
# 9b seems to work correctly enough in float16, but 80b seems
|
|
||||||
# to be really saturating for f16.
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
elif SYSTEM == "ipex":
|
|
||||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
|
||||||
device = torch.device(f"xpu:{rank}")
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
# Float16 doesn't exist on target.
|
|
||||||
dtype = torch.bfloat16 if dtype is None else dtype
|
dtype = torch.bfloat16 if dtype is None else dtype
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
dtype = torch.float32 if dtype is None else dtype
|
|
||||||
self.device, self.dtype = device, dtype
|
self.device, self.dtype = device, dtype
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(
|
config = AutoConfig.from_pretrained(
|
||||||
|
@ -1,28 +1,30 @@
|
|||||||
from io import BytesIO
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from typing import Iterable, Optional, Tuple, List, Dict
|
from typing import Iterable, Optional, Tuple, List, Dict
|
||||||
from text_generation_server.pb.generate_pb2 import Request
|
from text_generation_server.pb.generate_pb2 import Request
|
||||||
|
from io import BytesIO
|
||||||
|
from PIL import Image
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from opentelemetry import trace
|
from opentelemetry import trace
|
||||||
from transformers import (
|
from transformers import (
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
)
|
)
|
||||||
from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM
|
|
||||||
from text_generation_server.pb import generate_pb2
|
|
||||||
from text_generation_server.models.flash_causal_lm import (
|
|
||||||
block_tables_to_ragged,
|
|
||||||
)
|
|
||||||
from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
|
|
||||||
from text_generation_server.layers.attention import Seqlen
|
|
||||||
|
|
||||||
|
from text_generation_server.models.flash_vlm_causal_lm import (
|
||||||
|
FlashVlmCausalLMBatch,
|
||||||
|
FlashVlmCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.pb import generate_pb2
|
||||||
|
from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
|
||||||
|
import habana_frameworks.torch as htorch
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MllamaCausalLMBatch(VlmCausalLMBatch):
|
class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
|
||||||
image_indices: List[int] = 42
|
image_indices: List[int] = 42
|
||||||
aspect_ratio_ids: Optional[torch.Tensor] = None
|
aspect_ratio_ids: Optional[torch.Tensor] = None
|
||||||
aspect_ratio_mask: Optional[torch.Tensor] = None
|
aspect_ratio_mask: Optional[torch.Tensor] = None
|
||||||
@ -158,7 +160,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
|
|||||||
config,
|
config,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> "VlmCausalLMBatch":
|
) -> "FlashVlmCausalLMBatch":
|
||||||
batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
|
batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
|
||||||
pb.requests, tokenizer, processor, config
|
pb.requests, tokenizer, processor, config
|
||||||
)
|
)
|
||||||
@ -167,6 +169,13 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
|
|||||||
batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(
|
batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(
|
||||||
max=config.text_config.vocab_size - 1
|
max=config.text_config.vocab_size - 1
|
||||||
)
|
)
|
||||||
|
if isinstance(batch.input_ids, list):
|
||||||
|
if len(batch) > 1:
|
||||||
|
input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
|
||||||
|
else:
|
||||||
|
input_ids = batch.input_ids[0]
|
||||||
|
batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
|
||||||
|
|
||||||
batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
|
batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
|
||||||
|
|
||||||
if image_inputs is not None:
|
if image_inputs is not None:
|
||||||
@ -187,10 +196,10 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
|
|||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
|
||||||
class MllamaCausalLM(VlmCausalLM):
|
class FlashMllamaCausalLM(FlashVlmCausalLM):
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
batch: VlmCausalLMBatch,
|
batch: FlashMllamaCausalLMBatch,
|
||||||
adapter_data: Optional[Dict[str, torch.Tensor]] = None,
|
adapter_data: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
# Model Forward
|
# Model Forward
|
||||||
@ -202,7 +211,7 @@ class MllamaCausalLM(VlmCausalLM):
|
|||||||
block_tables = batch.block_tables_tensor
|
block_tables = batch.block_tables_tensor
|
||||||
slots = batch.slots[batch.slot_indices]
|
slots = batch.slots[batch.slot_indices]
|
||||||
input_lengths = batch.input_lengths_tensor
|
input_lengths = batch.input_lengths_tensor
|
||||||
max_s = batch.max_seqlen
|
max_s = batch.max_current_length
|
||||||
lm_head_indices = batch.prefill_head_indices
|
lm_head_indices = batch.prefill_head_indices
|
||||||
|
|
||||||
speculative_ids = batch.speculative_ids
|
speculative_ids = batch.speculative_ids
|
||||||
@ -221,8 +230,8 @@ class MllamaCausalLM(VlmCausalLM):
|
|||||||
input_lengths = (
|
input_lengths = (
|
||||||
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
|
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
|
||||||
).view(-1)
|
).view(-1)
|
||||||
prefix_lens_tensor = (
|
cache_lengths_tensor = (
|
||||||
batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
|
batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
|
||||||
).reshape(-1)
|
).reshape(-1)
|
||||||
|
|
||||||
# Add Copy the block tables for all members
|
# Add Copy the block tables for all members
|
||||||
@ -244,8 +253,8 @@ class MllamaCausalLM(VlmCausalLM):
|
|||||||
block_tables = batch.block_tables_tensor
|
block_tables = batch.block_tables_tensor
|
||||||
slots = batch.slots[batch.slot_indices]
|
slots = batch.slots[batch.slot_indices]
|
||||||
input_lengths = batch.input_lengths_tensor
|
input_lengths = batch.input_lengths_tensor
|
||||||
prefix_lens_tensor = batch.prefix_lens_tensor
|
cache_lengths_tensor = batch.cache_lengths_tensor
|
||||||
max_s = batch.max_seqlen
|
max_s = batch.max_current_length
|
||||||
lm_head_indices = batch.prefill_head_indices
|
lm_head_indices = batch.prefill_head_indices
|
||||||
|
|
||||||
if cu_seqlen_prefill is None and self.max_past() is not None:
|
if cu_seqlen_prefill is None and self.max_past() is not None:
|
||||||
@ -254,41 +263,10 @@ class MllamaCausalLM(VlmCausalLM):
|
|||||||
# This makes sure the max_s for the decode pass is correct.
|
# This makes sure the max_s for the decode pass is correct.
|
||||||
max_s = min(self.max_past(), max_s)
|
max_s = min(self.max_past(), max_s)
|
||||||
|
|
||||||
bs = input_ids.shape[0]
|
|
||||||
# Try to find an associated cuda graph
|
|
||||||
bs = input_ids.shape[0]
|
|
||||||
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
|
|
||||||
if sorted_padded_bs:
|
|
||||||
# Get associated cuda graph
|
|
||||||
cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
|
|
||||||
else:
|
|
||||||
cuda_graph = None
|
|
||||||
if (
|
|
||||||
cu_seqlen_prefill is not None
|
|
||||||
or cuda_graph is None
|
|
||||||
# Only run cuda graphs when there's no images.
|
|
||||||
or batch.cross_attention_states is not None
|
|
||||||
):
|
|
||||||
input_lengths = input_lengths + prefix_lens_tensor
|
|
||||||
if PREFIX_CACHING:
|
|
||||||
block_tables = block_tables_to_ragged(
|
|
||||||
block_tables=block_tables,
|
|
||||||
input_lengths=batch.input_lengths,
|
|
||||||
prefix_lens=batch.prefix_lens,
|
|
||||||
)
|
|
||||||
with self._forward_context(
|
|
||||||
block_tables=block_tables,
|
|
||||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
|
||||||
input_lengths_tensor=input_lengths,
|
|
||||||
prefix_lens_tensor=prefix_lens_tensor,
|
|
||||||
):
|
|
||||||
max_k = (input_lengths + prefix_lens_tensor).max().item()
|
|
||||||
seqlen = Seqlen(
|
seqlen = Seqlen(
|
||||||
input_lengths=input_lengths,
|
input_lengths=input_lengths,
|
||||||
prefix_lengths=prefix_lens_tensor,
|
cache_lengths=cache_lengths_tensor,
|
||||||
cu_seqlen_q=cu_seqlen_prefill,
|
cu_seqlen_q=cu_seqlen_prefill,
|
||||||
max_q=max_s,
|
|
||||||
max_k=max_k,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if batch.pixel_values is not None:
|
if batch.pixel_values is not None:
|
||||||
@ -301,57 +279,30 @@ class MllamaCausalLM(VlmCausalLM):
|
|||||||
|
|
||||||
cross_attention_states = batch.cross_attention_states
|
cross_attention_states = batch.cross_attention_states
|
||||||
|
|
||||||
|
kwargs = {}
|
||||||
|
if htorch.utils.internal.is_lazy():
|
||||||
|
kwargs["bypass_hpu_graphs"] = False
|
||||||
|
if batch.prefill_cache_indices is not None:
|
||||||
|
slots_pad = torch.zeros_like(input_ids)
|
||||||
|
slots_pad[batch.prefill_cache_indices] = slots
|
||||||
|
slots = slots_pad
|
||||||
logits, speculative_logits = self.model.forward(
|
logits, speculative_logits = self.model.forward(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||||
kv_cache=kv_cache,
|
kv_cache=kv_cache,
|
||||||
block_tables=block_tables,
|
|
||||||
slots=slots,
|
slots=slots,
|
||||||
seqlen=seqlen,
|
seqlen=trim_seqlen_metadata(seqlen),
|
||||||
max_s=max_s,
|
hpu_attention_meta=batch.hpu_attn_meta,
|
||||||
prefill_cache_indices=batch.prefill_cache_indices,
|
|
||||||
lm_head_indices=lm_head_indices,
|
lm_head_indices=lm_head_indices,
|
||||||
cross_attention_states=cross_attention_states,
|
cross_attention_states=cross_attention_states,
|
||||||
adapter_data=adapter_data,
|
# TODO list
|
||||||
|
adapter_data=None,
|
||||||
image_indices=batch.image_indices[:],
|
image_indices=batch.image_indices[:],
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
if batch.prefill_cache_indices is not None:
|
if batch.prefill_cache_indices is not None:
|
||||||
batch.prefill_cache_indices = None
|
batch.prefill_cache_indices = None
|
||||||
if batch.pixel_values is not None:
|
if batch.pixel_values is not None:
|
||||||
batch.pixel_values = None
|
batch.pixel_values = None
|
||||||
return logits, speculative_logits
|
return logits, speculative_logits
|
||||||
|
|
||||||
# Copy inputs to the static inputs of the cuda graph
|
|
||||||
# Static inputs are potentially padded
|
|
||||||
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
|
|
||||||
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
|
|
||||||
if ATTENTION == "flashinfer":
|
|
||||||
block_tables = block_tables_to_ragged(
|
|
||||||
block_tables=block_tables,
|
|
||||||
input_lengths=batch.input_lengths,
|
|
||||||
prefix_lens=batch.prefix_lens,
|
|
||||||
)
|
|
||||||
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
|
|
||||||
else:
|
|
||||||
cuda_graph["block_tables"][
|
|
||||||
: block_tables.shape[0], : block_tables.shape[1]
|
|
||||||
] = block_tables
|
|
||||||
cuda_graph["slots"].fill_(0)
|
|
||||||
cuda_graph["slots"][: slots.shape[0]] = slots
|
|
||||||
cuda_graph["input_lengths"].zero_()
|
|
||||||
cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
|
|
||||||
input_lengths + prefix_lens_tensor
|
|
||||||
)
|
|
||||||
|
|
||||||
# Replay the graph
|
|
||||||
cuda_graph["graph"].replay()
|
|
||||||
|
|
||||||
# Slice output to the correct shape
|
|
||||||
speculative_logits = (
|
|
||||||
cuda_graph["speculative_logits"][:bs]
|
|
||||||
if cuda_graph["speculative_logits"] is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
logits = cuda_graph["logits"][:bs]
|
|
||||||
return logits, speculative_logits
|
|
||||||
|
@ -33,6 +33,7 @@ class Model(ABC):
|
|||||||
sliding_window: Optional[int] = None,
|
sliding_window: Optional[int] = None,
|
||||||
speculate: Optional[int] = None,
|
speculate: Optional[int] = None,
|
||||||
adapter_id: str = BASE_MODEL_ADAPTER_ID,
|
adapter_id: str = BASE_MODEL_ADAPTER_ID,
|
||||||
|
support_chunking: bool = False,
|
||||||
):
|
):
|
||||||
self.model_id = model_id
|
self.model_id = model_id
|
||||||
self.model = model.eval()
|
self.model = model.eval()
|
||||||
|
@ -4,8 +4,8 @@ import torch
|
|||||||
import torch.distributed
|
import torch.distributed
|
||||||
from opentelemetry import trace
|
from opentelemetry import trace
|
||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
from text_generation_server.models.vlm_causal_lm import (
|
from text_generation_server.models.flash_vlm_causal_lm import (
|
||||||
VlmCausalLMBatch,
|
FlashVlmCausalLMBatch,
|
||||||
image_text_replacement,
|
image_text_replacement,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -14,7 +14,7 @@ from text_generation_server.pb.generate_pb2 import Request
|
|||||||
tracer = trace.get_tracer(__name__)
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
|
|
||||||
class PaliGemmaBatch(VlmCausalLMBatch):
|
class PaliGemmaBatch(FlashVlmCausalLMBatch):
|
||||||
@classmethod
|
@classmethod
|
||||||
def batch_tokenized_inputs(
|
def batch_tokenized_inputs(
|
||||||
cls, requests: Iterable[Request], tokenizer, processor, config
|
cls, requests: Iterable[Request], tokenizer, processor, config
|
||||||
|
@ -10,7 +10,6 @@ from transformers import (
|
|||||||
AutoConfig,
|
AutoConfig,
|
||||||
)
|
)
|
||||||
from typing import Optional, Tuple, List, Type, Dict
|
from typing import Optional, Tuple, List, Type, Dict
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
from text_generation_server.utils import (
|
from text_generation_server.utils import (
|
||||||
initialize_torch_distributed,
|
initialize_torch_distributed,
|
||||||
weight_files,
|
weight_files,
|
||||||
@ -555,20 +554,9 @@ class Seq2SeqLM(Model):
|
|||||||
):
|
):
|
||||||
self.quantize = quantize
|
self.quantize = quantize
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device("hpu")
|
||||||
dtype = default_dtype if dtype is None else dtype
|
|
||||||
elif SYSTEM == "ipex":
|
|
||||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
|
||||||
device = torch.device(f"xpu:{rank}")
|
|
||||||
dtype = default_dtype if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
# Float16 doesn't exist on target.
|
|
||||||
dtype = torch.bfloat16 if dtype is None else dtype
|
dtype = torch.bfloat16 if dtype is None else dtype
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
dtype = torch.float32 if dtype is None else dtype
|
|
||||||
|
|
||||||
config = config_class.from_pretrained(
|
config = config_class.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
@ -600,7 +588,7 @@ class Seq2SeqLM(Model):
|
|||||||
aliases=aliases,
|
aliases=aliases,
|
||||||
weights_loader=weights_loader,
|
weights_loader=weights_loader,
|
||||||
)
|
)
|
||||||
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
|
if config.quantize in ["awq", "gptq"]:
|
||||||
weights._set_gptq_params(model_id, revision)
|
weights._set_gptq_params(model_id, revision)
|
||||||
|
|
||||||
model = model_class(config, weights)
|
model = model_class(config, weights)
|
||||||
|
@ -69,11 +69,7 @@ MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 8192))
|
|||||||
PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 128))
|
PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 128))
|
||||||
CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
|
CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
|
||||||
LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
|
LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
|
||||||
max_batch_size_str = os.environ.get("MAX_BATCH_SIZE")
|
|
||||||
if max_batch_size_str is not None:
|
|
||||||
MAX_BATCH_SIZE = int(max_batch_size_str)
|
|
||||||
else:
|
|
||||||
raise ValueError("MAX_BATCH_SIZE is not set")
|
|
||||||
|
|
||||||
PREFILL_WARMUP_BATCH_SIZE_LIST = []
|
PREFILL_WARMUP_BATCH_SIZE_LIST = []
|
||||||
PREFILL_WARMUP_SEQLEN_LIST = []
|
PREFILL_WARMUP_SEQLEN_LIST = []
|
||||||
@ -1467,6 +1463,12 @@ class VlmCausalLM(Model):
|
|||||||
batch = self.batch_from_pb(request.batch, is_warmup=True)
|
batch = self.batch_from_pb(request.batch, is_warmup=True)
|
||||||
max_input_tokens = request.max_input_tokens
|
max_input_tokens = request.max_input_tokens
|
||||||
max_prefill_batch_size = batch.input_ids.shape[0]
|
max_prefill_batch_size = batch.input_ids.shape[0]
|
||||||
|
max_batch_size_str = os.environ.get("MAX_BATCH_SIZE")
|
||||||
|
if max_batch_size_str is not None:
|
||||||
|
MAX_BATCH_SIZE = int(max_batch_size_str)
|
||||||
|
else:
|
||||||
|
raise ValueError("MAX_BATCH_SIZE is not set")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# max prefill batch size warmup
|
# max prefill batch size warmup
|
||||||
_, prefill_batch, _ = self.generate_token([batch], is_warmup=True)
|
_, prefill_batch, _ = self.generate_token([batch], is_warmup=True)
|
||||||
|
@ -18,22 +18,27 @@ from text_generation_server.interceptor import ExceptionInterceptor
|
|||||||
from text_generation_server.models import Model, get_model_with_lora_adapters
|
from text_generation_server.models import Model, get_model_with_lora_adapters
|
||||||
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
|
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
|
||||||
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
|
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
|
||||||
from text_generation_server.models.globals import set_model_id
|
from text_generation_server.models.globals import set_model_id, ATTENTION
|
||||||
from text_generation_server.models.globals import set_adapter_to_index
|
from text_generation_server.models.globals import set_adapter_to_index
|
||||||
from text_generation_server.utils.adapter import AdapterInfo
|
from text_generation_server.utils.adapter import AdapterInfo
|
||||||
from text_generation_server.utils.tokens import make_tokenizer_optional
|
from text_generation_server.utils.tokens import make_tokenizer_optional
|
||||||
|
from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from text_generation_server.models.pali_gemma import PaliGemmaBatch
|
from text_generation_server.models.pali_gemma import PaliGemmaBatch
|
||||||
|
from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
|
||||||
from text_generation_server.models.vlm_causal_lm import (
|
from text_generation_server.models.vlm_causal_lm import (
|
||||||
VlmCausalLMBatch,
|
VlmCausalLMBatch,
|
||||||
)
|
)
|
||||||
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
|
from text_generation_server.models.flash_vlm_causal_lm import (
|
||||||
|
FlashVlmCausalLMBatch,
|
||||||
|
)
|
||||||
|
|
||||||
VLM_BATCH_TYPES = {
|
VLM_BATCH_TYPES = {
|
||||||
PaliGemmaBatch,
|
PaliGemmaBatch,
|
||||||
VlmCausalLMBatch,
|
VlmCausalLMBatch,
|
||||||
IdeficsCausalLMBatch,
|
FlashVlmCausalLMBatch,
|
||||||
|
FlashMllamaCausalLMBatch,
|
||||||
}
|
}
|
||||||
except (ImportError, NotImplementedError):
|
except (ImportError, NotImplementedError):
|
||||||
# These imports can fail on CPU/Non flash.
|
# These imports can fail on CPU/Non flash.
|
||||||
@ -103,6 +108,42 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
|||||||
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
|
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
|
||||||
|
|
||||||
async def Warmup(self, request, context):
|
async def Warmup(self, request, context):
|
||||||
|
if ATTENTION == "paged":
|
||||||
|
set_max_prefill_tokens(request.max_prefill_tokens)
|
||||||
|
if (
|
||||||
|
self.model.batch_type in VLM_BATCH_TYPES
|
||||||
|
): # Hack, i would rather use kwargs in the `from_pb` call
|
||||||
|
batch = self.model.batch_type.from_pb_processor(
|
||||||
|
request.batch,
|
||||||
|
self.model.tokenizer,
|
||||||
|
self.model.processor,
|
||||||
|
self.model.model.config,
|
||||||
|
self.model.dtype,
|
||||||
|
self.model.device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
batch = self.model.batch_type.from_pb(
|
||||||
|
request.batch,
|
||||||
|
self.model.tokenizer,
|
||||||
|
self.model.dtype,
|
||||||
|
self.model.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Override default values with None for clearer semantics.
|
||||||
|
max_input_tokens = (
|
||||||
|
request.max_input_tokens
|
||||||
|
if request.HasField("max_input_tokens")
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
max_total_tokens = (
|
||||||
|
request.max_total_tokens
|
||||||
|
if request.HasField("max_total_tokens")
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
max_supported_total_tokens, max_input_tokens, max_total_tokens = (
|
||||||
|
self.model.warmup(batch, max_input_tokens, max_total_tokens)
|
||||||
|
)
|
||||||
|
else:
|
||||||
max_supported_total_tokens, max_input_tokens, max_total_tokens = (
|
max_supported_total_tokens, max_input_tokens, max_total_tokens = (
|
||||||
self.model.warmup(request)
|
self.model.warmup(request)
|
||||||
)
|
)
|
||||||
|
@ -1,15 +1,13 @@
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
|
from torch.distributed import ProcessGroup
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
# Tensor Parallelism settings
|
# Tensor Parallelism settings
|
||||||
RANK = int(os.getenv("RANK", "0"))
|
RANK = int(os.getenv("RANK", "0"))
|
||||||
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
|
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
|
||||||
|
MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.8"))
|
||||||
# CUDA memory fraction
|
|
||||||
MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))
|
|
||||||
|
|
||||||
|
|
||||||
class FakeBarrier:
|
class FakeBarrier:
|
||||||
@ -17,10 +15,11 @@ class FakeBarrier:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class FakeGroup:
|
class FakeGroup(ProcessGroup):
|
||||||
def __init__(self, rank, size):
|
def __init__(self, rank, size):
|
||||||
self._rank = rank
|
self._rank = rank
|
||||||
self._size = size
|
self._size = size
|
||||||
|
super().__init__(rank, size)
|
||||||
|
|
||||||
def allreduce(self, *args, **kwargs):
|
def allreduce(self, *args, **kwargs):
|
||||||
return FakeBarrier()
|
return FakeBarrier()
|
||||||
@ -42,42 +41,11 @@ class FakeGroup:
|
|||||||
def rank(self):
|
def rank(self):
|
||||||
return self._rank
|
return self._rank
|
||||||
|
|
||||||
|
def _get_backend_name(self):
|
||||||
|
return "fake"
|
||||||
|
|
||||||
|
|
||||||
def initialize_torch_distributed():
|
def initialize_torch_distributed():
|
||||||
|
|
||||||
world_size = int(os.getenv("WORLD_SIZE", "1"))
|
|
||||||
|
|
||||||
options = None
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
from torch.distributed import ProcessGroupNCCL
|
|
||||||
|
|
||||||
# Set the device id.
|
|
||||||
assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one gpu"
|
|
||||||
device = RANK % torch.cuda.device_count()
|
|
||||||
torch.cuda.set_device(device)
|
|
||||||
torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)
|
|
||||||
backend = "nccl"
|
|
||||||
options = ProcessGroupNCCL.Options()
|
|
||||||
options.is_high_priority_stream = True
|
|
||||||
options._timeout = timedelta(seconds=60)
|
|
||||||
elif torch.hpu.is_available():
|
|
||||||
backend = "hccl"
|
|
||||||
n_hpus = torch.hpu.device_count()
|
|
||||||
if world_size > n_hpus:
|
|
||||||
raise ValueError(
|
|
||||||
f"WORLD_SIZE ({world_size}) is higher than the number of available HPUs ({n_hpus})."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
import oneccl_bindings_for_pytorch # noqa: F401
|
|
||||||
|
|
||||||
backend = "ccl"
|
|
||||||
if os.getenv("CCL_WORKER_COUNT", None) is None:
|
|
||||||
os.environ["CCL_WORKER_COUNT"] = str(1)
|
|
||||||
except ImportError:
|
|
||||||
backend = "gloo"
|
|
||||||
options = None
|
|
||||||
|
|
||||||
if WORLD_SIZE == 1:
|
if WORLD_SIZE == 1:
|
||||||
return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
|
return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
|
||||||
else:
|
else:
|
||||||
@ -87,11 +55,10 @@ def initialize_torch_distributed():
|
|||||||
if not torch.distributed.is_initialized():
|
if not torch.distributed.is_initialized():
|
||||||
# Call the init process.
|
# Call the init process.
|
||||||
torch.distributed.init_process_group(
|
torch.distributed.init_process_group(
|
||||||
backend=backend,
|
backend="hccl",
|
||||||
world_size=WORLD_SIZE,
|
world_size=WORLD_SIZE,
|
||||||
rank=RANK,
|
rank=RANK,
|
||||||
timeout=timedelta(seconds=60),
|
timeout=timedelta(seconds=120),
|
||||||
pg_options=options,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning("torch.distributed is already initialized.")
|
logger.warning("torch.distributed is already initialized.")
|
||||||
|
@ -1,75 +1,28 @@
|
|||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
import importlib.util
|
def get_hpu_free_memory(device, memory_fraction):
|
||||||
|
from habana_frameworks.torch.hpu import memory_stats
|
||||||
|
|
||||||
|
|
||||||
def is_ipex_available():
|
|
||||||
return importlib.util.find_spec("intel_extension_for_pytorch") is not None
|
|
||||||
|
|
||||||
|
|
||||||
def get_cuda_free_memory(device, memory_fraction):
|
|
||||||
total_free_memory, _ = torch.cuda.mem_get_info(device)
|
|
||||||
total_gpu_memory = torch.cuda.get_device_properties(device).total_memory
|
|
||||||
free_memory = max(0, total_free_memory - (1 - memory_fraction) * total_gpu_memory)
|
|
||||||
return free_memory
|
|
||||||
|
|
||||||
|
|
||||||
def get_xpu_free_memory(device, memory_fraction):
|
|
||||||
total_memory = torch.xpu.get_device_properties(device).total_memory
|
|
||||||
device_id = device.index
|
device_id = device.index
|
||||||
memory_fraction = float(os.getenv("XPU_MEMORY_FRACTION", "1.0"))
|
mem_stats = memory_stats(device_id)
|
||||||
|
logger.info(f"mem_stats: {mem_stats}")
|
||||||
|
total_free_memory = mem_stats["Limit"] - mem_stats["MaxInUse"]
|
||||||
free_memory = max(
|
free_memory = max(
|
||||||
0,
|
0, int(total_free_memory - (1 - memory_fraction) * mem_stats["Limit"])
|
||||||
int(
|
|
||||||
total_memory * 0.9 * memory_fraction - torch.xpu.memory_reserved(device_id)
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
return free_memory
|
return free_memory
|
||||||
|
|
||||||
|
|
||||||
def get_cpu_free_memory(device, memory_fraction):
|
def synchronize_hpu(device):
|
||||||
import psutil
|
torch.hpu.synchronize()
|
||||||
from text_generation_server.utils.dist import WORLD_SIZE
|
|
||||||
|
|
||||||
mem = psutil.virtual_memory()
|
|
||||||
free_memory = int(mem.available * 0.95 / WORLD_SIZE)
|
|
||||||
return free_memory
|
|
||||||
|
|
||||||
|
|
||||||
def noop(*args, **kwargs):
|
def noop(*args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
SYSTEM = None
|
|
||||||
if torch.version.hip is not None:
|
|
||||||
SYSTEM = "rocm"
|
|
||||||
empty_cache = torch.cuda.empty_cache
|
|
||||||
synchronize = torch.cuda.synchronize
|
|
||||||
get_free_memory = get_cuda_free_memory
|
|
||||||
elif torch.version.cuda is not None and torch.cuda.is_available():
|
|
||||||
SYSTEM = "cuda"
|
|
||||||
empty_cache = torch.cuda.empty_cache
|
|
||||||
synchronize = torch.cuda.synchronize
|
|
||||||
get_free_memory = get_cuda_free_memory
|
|
||||||
elif is_ipex_available():
|
|
||||||
SYSTEM = "ipex"
|
|
||||||
import intel_extension_for_pytorch # noqa: F401
|
|
||||||
|
|
||||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
|
||||||
empty_cache = torch.xpu.empty_cache
|
|
||||||
synchronize = torch.xpu.synchronize
|
|
||||||
get_free_memory = get_xpu_free_memory
|
|
||||||
else:
|
|
||||||
empty_cache = noop
|
empty_cache = noop
|
||||||
synchronize = noop
|
synchronize = synchronize_hpu
|
||||||
get_free_memory = get_cpu_free_memory
|
get_free_memory = get_hpu_free_memory
|
||||||
else:
|
|
||||||
SYSTEM = "cpu"
|
|
||||||
|
|
||||||
empty_cache = noop
|
|
||||||
synchronize = noop
|
|
||||||
get_free_memory = get_cpu_free_memory
|
|
||||||
logger.info(f"Detected system {SYSTEM}")
|
|
||||||
|
@ -0,0 +1,22 @@
|
|||||||
|
import importlib
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
from hf_kernels import load_kernel as hf_load_kernel
|
||||||
|
|
||||||
|
from text_generation_server.utils.log import log_once
|
||||||
|
|
||||||
|
|
||||||
|
def load_kernel(*, module: str, repo_id: str):
|
||||||
|
"""
|
||||||
|
Load a kernel. First try to load it as the given module (e.g. for
|
||||||
|
local development), falling back to a locked Hub kernel.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
m = importlib.import_module(module)
|
||||||
|
log_once(logger.info, f"Using local module for `{module}`")
|
||||||
|
return m
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
return hf_load_kernel(repo_id=repo_id)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["load_kernel"]
|
@ -0,0 +1,24 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
SUPPORT_CHUNKING: Optional[bool] = None
|
||||||
|
MAX_PREFILL_TOKENS: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
def set_support_chunking(support_chunking: bool):
|
||||||
|
global SUPPORT_CHUNKING
|
||||||
|
SUPPORT_CHUNKING = support_chunking
|
||||||
|
|
||||||
|
|
||||||
|
def get_support_chunking() -> bool:
|
||||||
|
global SUPPORT_CHUNKING
|
||||||
|
return SUPPORT_CHUNKING
|
||||||
|
|
||||||
|
|
||||||
|
def set_max_prefill_tokens(max_prefill_tokens: int):
|
||||||
|
global MAX_PREFILL_TOKENS
|
||||||
|
MAX_PREFILL_TOKENS = max_prefill_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def get_max_prefill_tokens() -> int:
|
||||||
|
global MAX_PREFILL_TOKENS
|
||||||
|
return MAX_PREFILL_TOKENS
|
@ -4,9 +4,7 @@ from dataclasses import dataclass
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from text_generation_server.layers.marlin.gptq import can_use_gptq_marlin
|
|
||||||
from text_generation_server.utils.weights import (
|
from text_generation_server.utils.weights import (
|
||||||
DefaultWeightsLoader,
|
|
||||||
WeightsLoader,
|
WeightsLoader,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -129,24 +127,6 @@ def get_loader(
|
|||||||
f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
|
f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
|
||||||
)
|
)
|
||||||
|
|
||||||
if can_use_gptq_marlin(
|
|
||||||
bits=quantizer_config.bits,
|
|
||||||
groupsize=quantizer_config.groupsize,
|
|
||||||
quant_method=quantizer_config.quant_method,
|
|
||||||
quantize=quantize,
|
|
||||||
sym=quantizer_config.sym,
|
|
||||||
):
|
|
||||||
from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader
|
|
||||||
|
|
||||||
return GPTQMarlinWeightsLoader(
|
|
||||||
bits=quantizer_config.bits,
|
|
||||||
desc_act=quantizer_config.desc_act,
|
|
||||||
groupsize=quantizer_config.groupsize,
|
|
||||||
quant_method=quantizer_config.quant_method,
|
|
||||||
quantize=quantize,
|
|
||||||
sym=quantizer_config.sym,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return GPTQWeightsLoader(
|
return GPTQWeightsLoader(
|
||||||
bits=quantizer_config.bits,
|
bits=quantizer_config.bits,
|
||||||
desc_act=quantizer_config.desc_act,
|
desc_act=quantizer_config.desc_act,
|
||||||
@ -155,39 +135,6 @@ def get_loader(
|
|||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
sym=quantizer_config.sym,
|
sym=quantizer_config.sym,
|
||||||
)
|
)
|
||||||
elif quantize == "bitsandbytes":
|
|
||||||
from text_generation_server.layers.bnb import BNBWeight
|
|
||||||
|
|
||||||
return DefaultWeightsLoader(BNBWeight)
|
|
||||||
elif quantize == "bitsandbytes-fp4":
|
|
||||||
from text_generation_server.layers.bnb import BNBFP4Weight
|
|
||||||
|
|
||||||
return DefaultWeightsLoader(BNBFP4Weight)
|
|
||||||
elif quantize == "bitsandbytes-nf4":
|
|
||||||
from text_generation_server.layers.bnb import BNBNF4Weight
|
|
||||||
|
|
||||||
return DefaultWeightsLoader(BNBNF4Weight)
|
|
||||||
elif quantize == "eetq":
|
|
||||||
from text_generation_server.layers.eetq import EETQWeight
|
|
||||||
|
|
||||||
return DefaultWeightsLoader(EETQWeight)
|
|
||||||
elif quantize == "exl2":
|
|
||||||
from text_generation_server.layers.exl2 import Exl2WeightsLoader
|
|
||||||
|
|
||||||
return Exl2WeightsLoader()
|
|
||||||
elif quantize == "marlin":
|
|
||||||
from text_generation_server.layers.marlin import MarlinWeightsLoader
|
|
||||||
|
|
||||||
# TODO: improve check once we have one config type per quantize value
|
|
||||||
if not isinstance(quantizer_config, _QuantizerConfig):
|
|
||||||
raise ValueError(
|
|
||||||
f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
|
|
||||||
)
|
|
||||||
|
|
||||||
return MarlinWeightsLoader(
|
|
||||||
bits=quantizer_config.bits,
|
|
||||||
is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
|
|
||||||
)
|
|
||||||
elif quantize == "fp8" or quantize is None:
|
elif quantize == "fp8" or quantize is None:
|
||||||
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
|
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
|
||||||
|
|
||||||
|
@ -7,8 +7,6 @@ from typing import Dict, List, Optional, Union, Type
|
|||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
|
|
||||||
|
|
||||||
class WeightsLoader(ABC):
|
class WeightsLoader(ABC):
|
||||||
"""
|
"""
|
||||||
@ -88,11 +86,8 @@ class UnquantizedWeight(Weight):
|
|||||||
weight: torch.Tensor
|
weight: torch.Tensor
|
||||||
|
|
||||||
def get_linear(self, bias: torch.Tensor):
|
def get_linear(self, bias: torch.Tensor):
|
||||||
from text_generation_server.layers.linear import FastLinear, FastLinearROCm
|
from text_generation_server.layers.linear import FastLinear
|
||||||
|
|
||||||
if SYSTEM == "rocm":
|
|
||||||
return FastLinearROCm(self.weight, bias)
|
|
||||||
else:
|
|
||||||
return FastLinear(self.weight, bias)
|
return FastLinear(self.weight, bias)
|
||||||
|
|
||||||
|
|
||||||
@ -197,7 +192,7 @@ class Weights:
|
|||||||
slice_ = f.get_slice(tensor_name)
|
slice_ = f.get_slice(tensor_name)
|
||||||
return slice_
|
return slice_
|
||||||
|
|
||||||
def _has_tensor(self, tensor_name: str):
|
def has_tensor(self, tensor_name: str):
|
||||||
try:
|
try:
|
||||||
self.get_filename(tensor_name)
|
self.get_filename(tensor_name)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -207,7 +202,9 @@ class Weights:
|
|||||||
def get_shape(self, tensor_name: str):
|
def get_shape(self, tensor_name: str):
|
||||||
return self._get_slice(tensor_name).get_shape()
|
return self._get_slice(tensor_name).get_shape()
|
||||||
|
|
||||||
def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True):
|
def get_tensor(
|
||||||
|
self, tensor_name: str, to_device: bool = True, to_dtype: bool = True
|
||||||
|
) -> torch.Tensor:
|
||||||
filename, tensor_name = self.get_filename(tensor_name)
|
filename, tensor_name = self.get_filename(tensor_name)
|
||||||
f = self._get_handle(filename)
|
f = self._get_handle(filename)
|
||||||
tensor = f.get_tensor(tensor_name)
|
tensor = f.get_tensor(tensor_name)
|
||||||
@ -218,6 +215,7 @@ class Weights:
|
|||||||
tensor.dtype
|
tensor.dtype
|
||||||
not in [
|
not in [
|
||||||
torch.float8_e4m3fn,
|
torch.float8_e4m3fn,
|
||||||
|
torch.int8,
|
||||||
torch.int16,
|
torch.int16,
|
||||||
torch.int32,
|
torch.int32,
|
||||||
torch.int64,
|
torch.int64,
|
||||||
@ -253,7 +251,8 @@ class Weights:
|
|||||||
# u4 which are disguised as int32. exl2 uses int16.
|
# u4 which are disguised as int32. exl2 uses int16.
|
||||||
# FP8 uses torch.float8_e4m3fn.
|
# FP8 uses torch.float8_e4m3fn.
|
||||||
if (
|
if (
|
||||||
tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32)
|
tensor.dtype
|
||||||
|
not in (torch.float8_e4m3fn, torch.int8, torch.int16, torch.int32)
|
||||||
and to_dtype
|
and to_dtype
|
||||||
):
|
):
|
||||||
tensor = tensor.to(dtype=self.dtype)
|
tensor = tensor.to(dtype=self.dtype)
|
||||||
@ -329,6 +328,7 @@ class Weights:
|
|||||||
tensor.dtype
|
tensor.dtype
|
||||||
not in [
|
not in [
|
||||||
torch.float8_e4m3fn,
|
torch.float8_e4m3fn,
|
||||||
|
torch.int8,
|
||||||
torch.int16,
|
torch.int16,
|
||||||
torch.int32,
|
torch.int32,
|
||||||
torch.int64,
|
torch.int64,
|
||||||
|
@ -2,7 +2,7 @@ use std::sync::Arc;
|
|||||||
use tokio::sync::{mpsc, oneshot};
|
use tokio::sync::{mpsc, oneshot};
|
||||||
|
|
||||||
use crate::radix::RadixAllocator;
|
use crate::radix::RadixAllocator;
|
||||||
|
use text_generation_router::usage_stats::Env;
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct BlockAllocation {
|
pub struct BlockAllocation {
|
||||||
pub allocation_id: u64,
|
pub allocation_id: u64,
|
||||||
@ -141,6 +141,7 @@ pub struct SimpleAllocator {
|
|||||||
free_blocks: Vec<u32>,
|
free_blocks: Vec<u32>,
|
||||||
block_size: u32,
|
block_size: u32,
|
||||||
window_size: Option<u32>,
|
window_size: Option<u32>,
|
||||||
|
is_hpu_device: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SimpleAllocator {
|
impl SimpleAllocator {
|
||||||
@ -150,6 +151,7 @@ impl SimpleAllocator {
|
|||||||
// Block 0 is reserved for health checks
|
// Block 0 is reserved for health checks
|
||||||
free_blocks: (1..blocks).collect(),
|
free_blocks: (1..blocks).collect(),
|
||||||
window_size,
|
window_size,
|
||||||
|
is_hpu_device: Env::new().is_hpu_device(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -179,9 +181,15 @@ impl Allocator for SimpleAllocator {
|
|||||||
if required_blocks > self.free_blocks.len() as u32 {
|
if required_blocks > self.free_blocks.len() as u32 {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
let blocks = self
|
if self.is_hpu_device {
|
||||||
|
self.free_blocks.sort_by(|a, b| b.cmp(a));
|
||||||
|
}
|
||||||
|
let mut blocks = self
|
||||||
.free_blocks
|
.free_blocks
|
||||||
.split_off(self.free_blocks.len() - required_blocks as usize);
|
.split_off(self.free_blocks.len() - required_blocks as usize);
|
||||||
|
if self.is_hpu_device {
|
||||||
|
blocks.sort();
|
||||||
|
}
|
||||||
let mut slots =
|
let mut slots =
|
||||||
Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
|
Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
|
||||||
|
|
||||||
|
@ -28,8 +28,8 @@ impl Env {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_hpu_device(&self) -> bool {
|
pub fn should_start_a_single_hpu_shard(&self) -> bool {
|
||||||
self.hpu_env != "N/A"
|
self.hpu_env != "N/A" && std::env::var("ATTENTION").as_deref() != Ok("paged")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1559,7 +1559,7 @@ fn spawn_shards(
|
|||||||
) -> Result<(), LauncherError> {
|
) -> Result<(), LauncherError> {
|
||||||
// Start shard processes
|
// Start shard processes
|
||||||
for rank in 0..num_shard {
|
for rank in 0..num_shard {
|
||||||
if rank != 0 && env_runtime::Env::new().is_hpu_device() {
|
if rank != 0 && env_runtime::Env::new().should_start_a_single_hpu_shard() {
|
||||||
tracing::info!("Running on HPU, the launcher will not do any sharding as actual sharding is done in the server");
|
tracing::info!("Running on HPU, the launcher will not do any sharding as actual sharding is done in the server");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -1639,7 +1639,7 @@ fn spawn_shards(
|
|||||||
if shard_ready == num_shard {
|
if shard_ready == num_shard {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if env_runtime::Env::new().is_hpu_device() {
|
if env_runtime::Env::new().should_start_a_single_hpu_shard() {
|
||||||
tracing::info!("HPU detected, shard is ready");
|
tracing::info!("HPU detected, shard is ready");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -157,6 +157,7 @@ pub struct Env {
|
|||||||
docker_label: &'static str,
|
docker_label: &'static str,
|
||||||
nvidia_info: Option<Vec<NvidiaSmiInfo>>,
|
nvidia_info: Option<Vec<NvidiaSmiInfo>>,
|
||||||
xpu_info: Option<Vec<XpuSmiInfo>>,
|
xpu_info: Option<Vec<XpuSmiInfo>>,
|
||||||
|
hpu_info: Option<Vec<HpuSmiInfo>>,
|
||||||
system_env: SystemInfo,
|
system_env: SystemInfo,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -289,6 +290,60 @@ impl XpuSmiInfo {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Clone)]
|
||||||
|
struct HpuSmiInfo {
|
||||||
|
name: String,
|
||||||
|
pci_bus_id: String,
|
||||||
|
driver_version: String,
|
||||||
|
temperature: String,
|
||||||
|
utilization: String,
|
||||||
|
memory_total: String,
|
||||||
|
memory_free: String,
|
||||||
|
memory_used: String,
|
||||||
|
power_draw_instant: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HpuSmiInfo {
|
||||||
|
fn new() -> Option<Vec<HpuSmiInfo>> {
|
||||||
|
let output = Command::new("hl-smi")
|
||||||
|
.args([
|
||||||
|
"--query-aip=name,bus_id,driver_version,temperature.aip,utilization.aip,memory.total,memory.free,memory.used,power.draw",
|
||||||
|
"--format=csv"
|
||||||
|
])
|
||||||
|
.output()
|
||||||
|
.ok()?;
|
||||||
|
|
||||||
|
if !output.status.success() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let stdout = String::from_utf8(output.stdout).ok()?;
|
||||||
|
|
||||||
|
let mut rdr = ReaderBuilder::new()
|
||||||
|
.has_headers(true)
|
||||||
|
.from_reader(stdout.as_bytes());
|
||||||
|
|
||||||
|
let mut infos = Vec::new();
|
||||||
|
|
||||||
|
for result in rdr.records() {
|
||||||
|
let record = result.ok()?;
|
||||||
|
infos.push(HpuSmiInfo {
|
||||||
|
name: record[0].to_string(),
|
||||||
|
pci_bus_id: record[1].to_string(),
|
||||||
|
driver_version: record[2].to_string(),
|
||||||
|
temperature: record[3].to_string(),
|
||||||
|
utilization: record[4].to_string(),
|
||||||
|
memory_total: record[5].to_string(),
|
||||||
|
memory_free: record[6].to_string(),
|
||||||
|
memory_used: record[7].to_string(),
|
||||||
|
power_draw_instant: record[8].to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(infos)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Debug, Clone)]
|
#[derive(Serialize, Debug, Clone)]
|
||||||
pub struct SystemInfo {
|
pub struct SystemInfo {
|
||||||
cpu_count: usize,
|
cpu_count: usize,
|
||||||
@ -335,10 +390,14 @@ impl Env {
|
|||||||
system_env: SystemInfo::new(),
|
system_env: SystemInfo::new(),
|
||||||
nvidia_info: NvidiaSmiInfo::new(),
|
nvidia_info: NvidiaSmiInfo::new(),
|
||||||
xpu_info: XpuSmiInfo::new(),
|
xpu_info: XpuSmiInfo::new(),
|
||||||
|
hpu_info: HpuSmiInfo::new(),
|
||||||
git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"),
|
git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"),
|
||||||
docker_label: option_env!("DOCKER_LABEL").unwrap_or("N/A"),
|
docker_label: option_env!("DOCKER_LABEL").unwrap_or("N/A"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
pub fn is_hpu_device(&self) -> bool {
|
||||||
|
self.hpu_info.is_some()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_container() -> io::Result<bool> {
|
pub fn is_container() -> io::Result<bool> {
|
||||||
|
Loading…
Reference in New Issue
Block a user